feat(vla): align transformer training stack and rollout validation
This commit is contained in:
@@ -105,7 +105,7 @@ class SimpleRobotDataset(Dataset):
|
||||
self._file_cache[key] = f
|
||||
return f
|
||||
|
||||
def _load_frame(self, idx: int) -> Dict:
|
||||
def _load_frame(self, idx: int, *, load_images: bool = True) -> Dict:
|
||||
"""从 HDF5 文件懒加载单帧数据"""
|
||||
meta = self.frame_meta[idx]
|
||||
f = self._get_h5_file(meta["hdf5_path"])
|
||||
@@ -118,21 +118,22 @@ class SimpleRobotDataset(Dataset):
|
||||
}
|
||||
|
||||
# 加载图像数据: observations/images/{cam_name} -> observation.{cam_name}
|
||||
for cam_name in self.camera_names:
|
||||
h5_path = f'observations/images/{cam_name}'
|
||||
if h5_path in f:
|
||||
img = f[h5_path][meta["frame_idx"]]
|
||||
# Resize图像到224x224(减少内存和I/O负担)
|
||||
import cv2
|
||||
img = cv2.resize(img, (224, 224), interpolation=cv2.INTER_LINEAR)
|
||||
# 转换为float并归一化到 [0, 1]
|
||||
img = torch.from_numpy(img).float() / 255.0
|
||||
frame[f"observation.{cam_name}"] = img.permute(2, 0, 1) # HWC -> CHW
|
||||
if load_images:
|
||||
for cam_name in self.camera_names:
|
||||
h5_path = f'observations/images/{cam_name}'
|
||||
if h5_path in f:
|
||||
img = f[h5_path][meta["frame_idx"]]
|
||||
# Resize图像到224x224(减少内存和I/O负担)
|
||||
import cv2
|
||||
img = cv2.resize(img, (224, 224), interpolation=cv2.INTER_LINEAR)
|
||||
# 转换为float并归一化到 [0, 1]
|
||||
img = torch.from_numpy(img).float() / 255.0
|
||||
frame[f"observation.{cam_name}"] = img.permute(2, 0, 1) # HWC -> CHW
|
||||
|
||||
return frame
|
||||
|
||||
def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
|
||||
frame = self._load_frame(idx)
|
||||
frame = self._load_frame(idx, load_images=False)
|
||||
ep_idx = frame["episode_index"]
|
||||
|
||||
# 获取当前 episode 的帧索引范围
|
||||
@@ -186,10 +187,10 @@ class SimpleRobotDataset(Dataset):
|
||||
target_idx = idx + delta
|
||||
|
||||
if target_idx <= ep_end:
|
||||
actions.append(self._load_frame(target_idx)["action"])
|
||||
actions.append(self._load_frame(target_idx, load_images=False)["action"])
|
||||
action_is_pad.append(False)
|
||||
else:
|
||||
actions.append(self._load_frame(ep_end)["action"])
|
||||
actions.append(self._load_frame(ep_end, load_images=False)["action"])
|
||||
action_is_pad.append(True)
|
||||
|
||||
# ============================================
|
||||
|
||||
Reference in New Issue
Block a user