diff --git a/roboimi/demos/vla_scripts/train_vla.py b/roboimi/demos/vla_scripts/train_vla.py index 473c01f..c4656ca 100644 --- a/roboimi/demos/vla_scripts/train_vla.py +++ b/roboimi/demos/vla_scripts/train_vla.py @@ -139,6 +139,7 @@ def main(cfg: DictConfig): shuffle=True, num_workers=cfg.train.num_workers, pin_memory=(cfg.train.device != "cpu"), + persistent_workers=(cfg.train.num_workers > 0), drop_last=True # 丢弃不完整批次以稳定训练 ) @@ -150,6 +151,7 @@ def main(cfg: DictConfig): shuffle=False, num_workers=cfg.train.num_workers, pin_memory=(cfg.train.device != "cpu"), + persistent_workers=(cfg.train.num_workers > 0), drop_last=False ) diff --git a/roboimi/vla/data/simpe_robot_dataset.py b/roboimi/vla/data/simpe_robot_dataset.py index 7b2fef3..83c995f 100644 --- a/roboimi/vla/data/simpe_robot_dataset.py +++ b/roboimi/vla/data/simpe_robot_dataset.py @@ -3,6 +3,7 @@ import h5py from torch.utils.data import Dataset from typing import List, Dict, Union from pathlib import Path +from collections import OrderedDict class SimpleRobotDataset(Dataset): @@ -21,6 +22,7 @@ class SimpleRobotDataset(Dataset): obs_horizon: int = 2, pred_horizon: int = 8, camera_names: List[str] = None, + max_open_files: int = 64, ): """ Args: @@ -28,6 +30,7 @@ class SimpleRobotDataset(Dataset): obs_horizon: 观察过去多少帧 pred_horizon: 预测未来多少帧动作 camera_names: 相机名称列表,如 ["r_vis", "top", "front"] + max_open_files: 每个 worker 最多缓存的 HDF5 文件句柄数 HDF5 文件格式: - action: [T, action_dim] @@ -37,6 +40,8 @@ class SimpleRobotDataset(Dataset): self.obs_horizon = obs_horizon self.pred_horizon = pred_horizon self.camera_names = camera_names or [] + self.max_open_files = max(1, int(max_open_files)) + self._file_cache: "OrderedDict[str, h5py.File]" = OrderedDict() self.dataset_dir = Path(dataset_dir) if not self.dataset_dir.exists(): @@ -69,29 +74,60 @@ class SimpleRobotDataset(Dataset): def __len__(self): return len(self.frame_meta) + def _close_all_files(self) -> None: + """关闭当前 worker 内缓存的所有 HDF5 文件句柄。""" + for f in self._file_cache.values(): + try: + f.close() + except Exception: + pass + self._file_cache.clear() + + def _get_h5_file(self, hdf5_path: Union[str, Path]) -> h5py.File: + """ + 获取 HDF5 文件句柄(worker 内 LRU 缓存)。 + 注意:缓存的是文件句柄,不是帧数据本身。 + """ + key = str(hdf5_path) + if key in self._file_cache: + self._file_cache.move_to_end(key) + return self._file_cache[key] + + # 超过上限时淘汰最久未使用的句柄 + if len(self._file_cache) >= self.max_open_files: + _, old_file = self._file_cache.popitem(last=False) + try: + old_file.close() + except Exception: + pass + + f = h5py.File(key, 'r') + self._file_cache[key] = f + return f + def _load_frame(self, idx: int) -> Dict: """从 HDF5 文件懒加载单帧数据""" meta = self.frame_meta[idx] - with h5py.File(meta["hdf5_path"], 'r') as f: - frame = { - "episode_index": meta["ep_idx"], - "frame_index": meta["frame_idx"], - "task": f.get('task', [b"unknown"])[0].decode() if 'task' in f else "unknown", - "observation.state": torch.from_numpy(f['observations/qpos'][meta["frame_idx"]]).float(), - "action": torch.from_numpy(f['action'][meta["frame_idx"]]).float(), - } + f = self._get_h5_file(meta["hdf5_path"]) + frame = { + "episode_index": meta["ep_idx"], + "frame_index": meta["frame_idx"], + "task": f.get('task', [b"unknown"])[0].decode() if 'task' in f else "unknown", + "observation.state": torch.from_numpy(f['observations/qpos'][meta["frame_idx"]]).float(), + "action": torch.from_numpy(f['action'][meta["frame_idx"]]).float(), + } - # 加载图像数据: observations/images/{cam_name} -> observation.{cam_name} - for cam_name in self.camera_names: - h5_path = f'observations/images/{cam_name}' - if h5_path in f: - img = f[h5_path][meta["frame_idx"]] - # Resize图像到224x224(减少内存和I/O负担) - import cv2 - img = cv2.resize(img, (224, 224), interpolation=cv2.INTER_LINEAR) - # 转换为float并归一化到 [0, 1] - img = torch.from_numpy(img).float() / 255.0 - frame[f"observation.{cam_name}"] = img.permute(2, 0, 1) # HWC -> CHW + # 加载图像数据: observations/images/{cam_name} -> observation.{cam_name} + for cam_name in self.camera_names: + h5_path = f'observations/images/{cam_name}' + if h5_path in f: + img = f[h5_path][meta["frame_idx"]] + # Resize图像到224x224(减少内存和I/O负担) + import cv2 + img = cv2.resize(img, (224, 224), interpolation=cv2.INTER_LINEAR) + # 转换为float并归一化到 [0, 1] + img = torch.from_numpy(img).float() / 255.0 + frame[f"observation.{cam_name}"] = img.permute(2, 0, 1) # HWC -> CHW return frame @@ -201,3 +237,6 @@ class SimpleRobotDataset(Dataset): "dtype": str(sample[key].dtype), } return info + + def __del__(self): + self._close_all_files()