feat: 缓存worker内的句柄

This commit is contained in:
gouhanke
2026-03-04 10:49:41 +08:00
parent 8bcad5844e
commit 7d39933a5b
2 changed files with 60 additions and 19 deletions

View File

@@ -139,6 +139,7 @@ def main(cfg: DictConfig):
shuffle=True,
num_workers=cfg.train.num_workers,
pin_memory=(cfg.train.device != "cpu"),
persistent_workers=(cfg.train.num_workers > 0),
drop_last=True # 丢弃不完整批次以稳定训练
)
@@ -150,6 +151,7 @@ def main(cfg: DictConfig):
shuffle=False,
num_workers=cfg.train.num_workers,
pin_memory=(cfg.train.device != "cpu"),
persistent_workers=(cfg.train.num_workers > 0),
drop_last=False
)

View File

@@ -3,6 +3,7 @@ import h5py
from torch.utils.data import Dataset
from typing import List, Dict, Union
from pathlib import Path
from collections import OrderedDict
class SimpleRobotDataset(Dataset):
@@ -21,6 +22,7 @@ class SimpleRobotDataset(Dataset):
obs_horizon: int = 2,
pred_horizon: int = 8,
camera_names: List[str] = None,
max_open_files: int = 64,
):
"""
Args:
@@ -28,6 +30,7 @@ class SimpleRobotDataset(Dataset):
obs_horizon: 观察过去多少帧
pred_horizon: 预测未来多少帧动作
camera_names: 相机名称列表,如 ["r_vis", "top", "front"]
max_open_files: 每个 worker 最多缓存的 HDF5 文件句柄数
HDF5 文件格式:
- action: [T, action_dim]
@@ -37,6 +40,8 @@ class SimpleRobotDataset(Dataset):
self.obs_horizon = obs_horizon
self.pred_horizon = pred_horizon
self.camera_names = camera_names or []
self.max_open_files = max(1, int(max_open_files))
self._file_cache: "OrderedDict[str, h5py.File]" = OrderedDict()
self.dataset_dir = Path(dataset_dir)
if not self.dataset_dir.exists():
@@ -69,29 +74,60 @@ class SimpleRobotDataset(Dataset):
def __len__(self):
return len(self.frame_meta)
def _close_all_files(self) -> None:
"""关闭当前 worker 内缓存的所有 HDF5 文件句柄。"""
for f in self._file_cache.values():
try:
f.close()
except Exception:
pass
self._file_cache.clear()
def _get_h5_file(self, hdf5_path: Union[str, Path]) -> h5py.File:
    """Return an open read-only HDF5 handle for *hdf5_path*.

    Handles live in a per-worker LRU cache (``self._file_cache``) capped at
    ``self.max_open_files``. Only file handles are cached, never frame data.
    """
    key = str(hdf5_path)
    cached = self._file_cache.get(key)
    if cached is not None:
        # Cache hit: promote to most-recently-used.
        self._file_cache.move_to_end(key)
        return cached
    # Evict least-recently-used handles until there is room for one more.
    while len(self._file_cache) >= self.max_open_files:
        _, evicted = self._file_cache.popitem(last=False)
        try:
            evicted.close()
        except Exception:
            pass
    handle = h5py.File(key, 'r')
    self._file_cache[key] = handle
    return handle
def _load_frame(self, idx: int) -> Dict:
    """Lazily load a single frame from its HDF5 file.

    Uses the worker-local handle cache (``_get_h5_file``) instead of
    reopening the file per sample.

    Args:
        idx: Flat frame index into ``self.frame_meta``.

    Returns:
        Dict with episode/frame indices, task string, state and action
        tensors, plus one ``observation.{cam}`` CHW float tensor per
        camera image present in the file.
    """
    meta = self.frame_meta[idx]
    f = self._get_h5_file(meta["hdf5_path"])
    frame = {
        "episode_index": meta["ep_idx"],
        "frame_index": meta["frame_idx"],
        "task": f.get('task', [b"unknown"])[0].decode() if 'task' in f else "unknown",
        "observation.state": torch.from_numpy(f['observations/qpos'][meta["frame_idx"]]).float(),
        "action": torch.from_numpy(f['action'][meta["frame_idx"]]).float(),
    }
    # Load image data: observations/images/{cam_name} -> observation.{cam_name}
    for cam_name in self.camera_names:
        h5_path = f'observations/images/{cam_name}'
        if h5_path in f:
            img = f[h5_path][meta["frame_idx"]]
            # Deferred import so image-free datasets do not require OpenCV.
            import cv2
            # Resize to 224x224 to reduce memory and I/O load.
            img = cv2.resize(img, (224, 224), interpolation=cv2.INTER_LINEAR)
            # Convert to float and normalize to [0, 1].
            img = torch.from_numpy(img).float() / 255.0
            frame[f"observation.{cam_name}"] = img.permute(2, 0, 1)  # HWC -> CHW
    return frame
@@ -201,3 +237,6 @@ class SimpleRobotDataset(Dataset):
"dtype": str(sample[key].dtype),
}
return info
def __del__(self):
self._close_all_files()