feat: 缓存worker内的句柄
This commit is contained in:
@@ -139,6 +139,7 @@ def main(cfg: DictConfig):
|
||||
shuffle=True,
|
||||
num_workers=cfg.train.num_workers,
|
||||
pin_memory=(cfg.train.device != "cpu"),
|
||||
persistent_workers=(cfg.train.num_workers > 0),
|
||||
drop_last=True # 丢弃不完整批次以稳定训练
|
||||
)
|
||||
|
||||
@@ -150,6 +151,7 @@ def main(cfg: DictConfig):
|
||||
shuffle=False,
|
||||
num_workers=cfg.train.num_workers,
|
||||
pin_memory=(cfg.train.device != "cpu"),
|
||||
persistent_workers=(cfg.train.num_workers > 0),
|
||||
drop_last=False
|
||||
)
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ import h5py
|
||||
from torch.utils.data import Dataset
|
||||
from typing import List, Dict, Union
|
||||
from pathlib import Path
|
||||
from collections import OrderedDict
|
||||
|
||||
|
||||
class SimpleRobotDataset(Dataset):
|
||||
@@ -21,6 +22,7 @@ class SimpleRobotDataset(Dataset):
|
||||
obs_horizon: int = 2,
|
||||
pred_horizon: int = 8,
|
||||
camera_names: List[str] = None,
|
||||
max_open_files: int = 64,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
@@ -28,6 +30,7 @@ class SimpleRobotDataset(Dataset):
|
||||
obs_horizon: 观察过去多少帧
|
||||
pred_horizon: 预测未来多少帧动作
|
||||
camera_names: 相机名称列表,如 ["r_vis", "top", "front"]
|
||||
max_open_files: 每个 worker 最多缓存的 HDF5 文件句柄数
|
||||
|
||||
HDF5 文件格式:
|
||||
- action: [T, action_dim]
|
||||
@@ -37,6 +40,8 @@ class SimpleRobotDataset(Dataset):
|
||||
self.obs_horizon = obs_horizon
|
||||
self.pred_horizon = pred_horizon
|
||||
self.camera_names = camera_names or []
|
||||
self.max_open_files = max(1, int(max_open_files))
|
||||
self._file_cache: "OrderedDict[str, h5py.File]" = OrderedDict()
|
||||
|
||||
self.dataset_dir = Path(dataset_dir)
|
||||
if not self.dataset_dir.exists():
|
||||
@@ -69,10 +74,41 @@ class SimpleRobotDataset(Dataset):
|
||||
def __len__(self):
|
||||
return len(self.frame_meta)
|
||||
|
||||
def _close_all_files(self) -> None:
|
||||
"""关闭当前 worker 内缓存的所有 HDF5 文件句柄。"""
|
||||
for f in self._file_cache.values():
|
||||
try:
|
||||
f.close()
|
||||
except Exception:
|
||||
pass
|
||||
self._file_cache.clear()
|
||||
|
||||
def _get_h5_file(self, hdf5_path: Union[str, Path]) -> h5py.File:
|
||||
"""
|
||||
获取 HDF5 文件句柄(worker 内 LRU 缓存)。
|
||||
注意:缓存的是文件句柄,不是帧数据本身。
|
||||
"""
|
||||
key = str(hdf5_path)
|
||||
if key in self._file_cache:
|
||||
self._file_cache.move_to_end(key)
|
||||
return self._file_cache[key]
|
||||
|
||||
# 超过上限时淘汰最久未使用的句柄
|
||||
if len(self._file_cache) >= self.max_open_files:
|
||||
_, old_file = self._file_cache.popitem(last=False)
|
||||
try:
|
||||
old_file.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
f = h5py.File(key, 'r')
|
||||
self._file_cache[key] = f
|
||||
return f
|
||||
|
||||
def _load_frame(self, idx: int) -> Dict:
|
||||
"""从 HDF5 文件懒加载单帧数据"""
|
||||
meta = self.frame_meta[idx]
|
||||
with h5py.File(meta["hdf5_path"], 'r') as f:
|
||||
f = self._get_h5_file(meta["hdf5_path"])
|
||||
frame = {
|
||||
"episode_index": meta["ep_idx"],
|
||||
"frame_index": meta["frame_idx"],
|
||||
@@ -201,3 +237,6 @@ class SimpleRobotDataset(Dataset):
|
||||
"dtype": str(sample[key].dtype),
|
||||
}
|
||||
return info
|
||||
|
||||
def __del__(self):
|
||||
self._close_all_files()
|
||||
|
||||
Reference in New Issue
Block a user