feat: 缓存worker内的句柄
This commit is contained in:
@@ -139,6 +139,7 @@ def main(cfg: DictConfig):
|
|||||||
shuffle=True,
|
shuffle=True,
|
||||||
num_workers=cfg.train.num_workers,
|
num_workers=cfg.train.num_workers,
|
||||||
pin_memory=(cfg.train.device != "cpu"),
|
pin_memory=(cfg.train.device != "cpu"),
|
||||||
|
persistent_workers=(cfg.train.num_workers > 0),
|
||||||
drop_last=True # 丢弃不完整批次以稳定训练
|
drop_last=True # 丢弃不完整批次以稳定训练
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -150,6 +151,7 @@ def main(cfg: DictConfig):
|
|||||||
shuffle=False,
|
shuffle=False,
|
||||||
num_workers=cfg.train.num_workers,
|
num_workers=cfg.train.num_workers,
|
||||||
pin_memory=(cfg.train.device != "cpu"),
|
pin_memory=(cfg.train.device != "cpu"),
|
||||||
|
persistent_workers=(cfg.train.num_workers > 0),
|
||||||
drop_last=False
|
drop_last=False
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ import h5py
|
|||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
from typing import List, Dict, Union
|
from typing import List, Dict, Union
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from collections import OrderedDict
|
||||||
|
|
||||||
|
|
||||||
class SimpleRobotDataset(Dataset):
|
class SimpleRobotDataset(Dataset):
|
||||||
@@ -21,6 +22,7 @@ class SimpleRobotDataset(Dataset):
|
|||||||
obs_horizon: int = 2,
|
obs_horizon: int = 2,
|
||||||
pred_horizon: int = 8,
|
pred_horizon: int = 8,
|
||||||
camera_names: List[str] = None,
|
camera_names: List[str] = None,
|
||||||
|
max_open_files: int = 64,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@@ -28,6 +30,7 @@ class SimpleRobotDataset(Dataset):
|
|||||||
obs_horizon: 观察过去多少帧
|
obs_horizon: 观察过去多少帧
|
||||||
pred_horizon: 预测未来多少帧动作
|
pred_horizon: 预测未来多少帧动作
|
||||||
camera_names: 相机名称列表,如 ["r_vis", "top", "front"]
|
camera_names: 相机名称列表,如 ["r_vis", "top", "front"]
|
||||||
|
max_open_files: 每个 worker 最多缓存的 HDF5 文件句柄数
|
||||||
|
|
||||||
HDF5 文件格式:
|
HDF5 文件格式:
|
||||||
- action: [T, action_dim]
|
- action: [T, action_dim]
|
||||||
@@ -37,6 +40,8 @@ class SimpleRobotDataset(Dataset):
|
|||||||
self.obs_horizon = obs_horizon
|
self.obs_horizon = obs_horizon
|
||||||
self.pred_horizon = pred_horizon
|
self.pred_horizon = pred_horizon
|
||||||
self.camera_names = camera_names or []
|
self.camera_names = camera_names or []
|
||||||
|
self.max_open_files = max(1, int(max_open_files))
|
||||||
|
self._file_cache: "OrderedDict[str, h5py.File]" = OrderedDict()
|
||||||
|
|
||||||
self.dataset_dir = Path(dataset_dir)
|
self.dataset_dir = Path(dataset_dir)
|
||||||
if not self.dataset_dir.exists():
|
if not self.dataset_dir.exists():
|
||||||
@@ -69,10 +74,41 @@ class SimpleRobotDataset(Dataset):
|
|||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.frame_meta)
|
return len(self.frame_meta)
|
||||||
|
|
||||||
|
def _close_all_files(self) -> None:
|
||||||
|
"""关闭当前 worker 内缓存的所有 HDF5 文件句柄。"""
|
||||||
|
for f in self._file_cache.values():
|
||||||
|
try:
|
||||||
|
f.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
self._file_cache.clear()
|
||||||
|
|
||||||
|
def _get_h5_file(self, hdf5_path: Union[str, Path]) -> h5py.File:
|
||||||
|
"""
|
||||||
|
获取 HDF5 文件句柄(worker 内 LRU 缓存)。
|
||||||
|
注意:缓存的是文件句柄,不是帧数据本身。
|
||||||
|
"""
|
||||||
|
key = str(hdf5_path)
|
||||||
|
if key in self._file_cache:
|
||||||
|
self._file_cache.move_to_end(key)
|
||||||
|
return self._file_cache[key]
|
||||||
|
|
||||||
|
# 超过上限时淘汰最久未使用的句柄
|
||||||
|
if len(self._file_cache) >= self.max_open_files:
|
||||||
|
_, old_file = self._file_cache.popitem(last=False)
|
||||||
|
try:
|
||||||
|
old_file.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
f = h5py.File(key, 'r')
|
||||||
|
self._file_cache[key] = f
|
||||||
|
return f
|
||||||
|
|
||||||
def _load_frame(self, idx: int) -> Dict:
|
def _load_frame(self, idx: int) -> Dict:
|
||||||
"""从 HDF5 文件懒加载单帧数据"""
|
"""从 HDF5 文件懒加载单帧数据"""
|
||||||
meta = self.frame_meta[idx]
|
meta = self.frame_meta[idx]
|
||||||
with h5py.File(meta["hdf5_path"], 'r') as f:
|
f = self._get_h5_file(meta["hdf5_path"])
|
||||||
frame = {
|
frame = {
|
||||||
"episode_index": meta["ep_idx"],
|
"episode_index": meta["ep_idx"],
|
||||||
"frame_index": meta["frame_idx"],
|
"frame_index": meta["frame_idx"],
|
||||||
@@ -201,3 +237,6 @@ class SimpleRobotDataset(Dataset):
|
|||||||
"dtype": str(sample[key].dtype),
|
"dtype": str(sample[key].dtype),
|
||||||
}
|
}
|
||||||
return info
|
return info
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
self._close_all_files()
|
||||||
|
|||||||
Reference in New Issue
Block a user