feat: 缓存worker内的句柄

This commit is contained in:
gouhanke
2026-03-04 10:49:41 +08:00
parent 8bcad5844e
commit 7d39933a5b
2 changed files with 60 additions and 19 deletions

View File

@@ -139,6 +139,7 @@ def main(cfg: DictConfig):
shuffle=True, shuffle=True,
num_workers=cfg.train.num_workers, num_workers=cfg.train.num_workers,
pin_memory=(cfg.train.device != "cpu"), pin_memory=(cfg.train.device != "cpu"),
persistent_workers=(cfg.train.num_workers > 0),
drop_last=True # 丢弃不完整批次以稳定训练 drop_last=True # 丢弃不完整批次以稳定训练
) )
@@ -150,6 +151,7 @@ def main(cfg: DictConfig):
shuffle=False, shuffle=False,
num_workers=cfg.train.num_workers, num_workers=cfg.train.num_workers,
pin_memory=(cfg.train.device != "cpu"), pin_memory=(cfg.train.device != "cpu"),
persistent_workers=(cfg.train.num_workers > 0),
drop_last=False drop_last=False
) )

View File

@@ -3,6 +3,7 @@ import h5py
from torch.utils.data import Dataset from torch.utils.data import Dataset
from typing import List, Dict, Union from typing import List, Dict, Union
from pathlib import Path from pathlib import Path
from collections import OrderedDict
class SimpleRobotDataset(Dataset): class SimpleRobotDataset(Dataset):
@@ -21,6 +22,7 @@ class SimpleRobotDataset(Dataset):
obs_horizon: int = 2, obs_horizon: int = 2,
pred_horizon: int = 8, pred_horizon: int = 8,
camera_names: List[str] = None, camera_names: List[str] = None,
max_open_files: int = 64,
): ):
""" """
Args: Args:
@@ -28,6 +30,7 @@ class SimpleRobotDataset(Dataset):
obs_horizon: 观察过去多少帧 obs_horizon: 观察过去多少帧
pred_horizon: 预测未来多少帧动作 pred_horizon: 预测未来多少帧动作
camera_names: 相机名称列表,如 ["r_vis", "top", "front"] camera_names: 相机名称列表,如 ["r_vis", "top", "front"]
max_open_files: 每个 worker 最多缓存的 HDF5 文件句柄数
HDF5 文件格式: HDF5 文件格式:
- action: [T, action_dim] - action: [T, action_dim]
@@ -37,6 +40,8 @@ class SimpleRobotDataset(Dataset):
self.obs_horizon = obs_horizon self.obs_horizon = obs_horizon
self.pred_horizon = pred_horizon self.pred_horizon = pred_horizon
self.camera_names = camera_names or [] self.camera_names = camera_names or []
self.max_open_files = max(1, int(max_open_files))
self._file_cache: "OrderedDict[str, h5py.File]" = OrderedDict()
self.dataset_dir = Path(dataset_dir) self.dataset_dir = Path(dataset_dir)
if not self.dataset_dir.exists(): if not self.dataset_dir.exists():
@@ -69,10 +74,41 @@ class SimpleRobotDataset(Dataset):
def __len__(self): def __len__(self):
return len(self.frame_meta) return len(self.frame_meta)
def _close_all_files(self) -> None:
    """Close every HDF5 handle cached by this worker and leave the cache empty."""
    while self._file_cache:
        # Drain oldest-first; popping also removes the entry from the cache.
        _, handle = self._file_cache.popitem(last=False)
        try:
            handle.close()
        except Exception:
            # Best-effort teardown: a handle that fails to close is dropped anyway.
            pass
def _get_h5_file(self, hdf5_path: Union[str, Path]) -> h5py.File:
    """
    Return a read-only ``h5py.File`` for *hdf5_path*, served from a
    per-worker LRU cache of open handles.

    Note: only the file handle is cached here — frame data itself is not.
    """
    path_key = str(hdf5_path)
    try:
        handle = self._file_cache[path_key]
    except KeyError:
        pass
    else:
        # Cache hit: promote to most-recently-used position.
        self._file_cache.move_to_end(path_key)
        return handle
    # Cache miss: evict the least-recently-used handle once at capacity.
    if len(self._file_cache) >= self.max_open_files:
        _, evicted = self._file_cache.popitem(last=False)
        try:
            evicted.close()
        except Exception:
            pass
    handle = h5py.File(path_key, 'r')
    self._file_cache[path_key] = handle
    return handle
def _load_frame(self, idx: int) -> Dict: def _load_frame(self, idx: int) -> Dict:
"""从 HDF5 文件懒加载单帧数据""" """从 HDF5 文件懒加载单帧数据"""
meta = self.frame_meta[idx] meta = self.frame_meta[idx]
with h5py.File(meta["hdf5_path"], 'r') as f: f = self._get_h5_file(meta["hdf5_path"])
frame = { frame = {
"episode_index": meta["ep_idx"], "episode_index": meta["ep_idx"],
"frame_index": meta["frame_idx"], "frame_index": meta["frame_idx"],
@@ -201,3 +237,6 @@ class SimpleRobotDataset(Dataset):
"dtype": str(sample[key].dtype), "dtype": str(sample[key].dtype),
} }
return info return info
def __del__(self):
    """Release cached HDF5 handles when this dataset object is finalized.

    Guarded so that finalizing a partially-initialised instance (e.g. when
    ``__init__`` raised before ``_file_cache`` was assigned) or running
    during interpreter shutdown never raises out of ``__del__`` — such
    exceptions cannot be caught by callers and only produce
    "Exception ignored in" noise.
    """
    try:
        self._close_all_files()
    except Exception:
        # Never propagate exceptions from a finalizer.
        pass