feat(vla): align transformer training stack and rollout validation

Logic
2026-03-31 15:39:20 +08:00
parent 424c265823
commit d84bc6876e
25 changed files with 4043 additions and 706 deletions


@@ -105,7 +105,7 @@ class SimpleRobotDataset(Dataset):
self._file_cache[key] = f
return f
def _load_frame(self, idx: int) -> Dict:
def _load_frame(self, idx: int, *, load_images: bool = True) -> Dict:
"""从 HDF5 文件懒加载单帧数据"""
meta = self.frame_meta[idx]
f = self._get_h5_file(meta["hdf5_path"])
@@ -118,21 +118,22 @@ class SimpleRobotDataset(Dataset):
}
# Load image data: observations/images/{cam_name} -> observation.{cam_name}
for cam_name in self.camera_names:
h5_path = f'observations/images/{cam_name}'
if h5_path in f:
img = f[h5_path][meta["frame_idx"]]
# Resize the image to 224x224 to reduce memory and I/O load
import cv2
img = cv2.resize(img, (224, 224), interpolation=cv2.INTER_LINEAR)
# Convert to float and normalize to [0, 1]
img = torch.from_numpy(img).float() / 255.0
frame[f"observation.{cam_name}"] = img.permute(2, 0, 1) # HWC -> CHW
if load_images:
for cam_name in self.camera_names:
h5_path = f'observations/images/{cam_name}'
if h5_path in f:
img = f[h5_path][meta["frame_idx"]]
# Resize the image to 224x224 to reduce memory and I/O load
import cv2
img = cv2.resize(img, (224, 224), interpolation=cv2.INTER_LINEAR)
# Convert to float and normalize to [0, 1]
img = torch.from_numpy(img).float() / 255.0
frame[f"observation.{cam_name}"] = img.permute(2, 0, 1) # HWC -> CHW
return frame
def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
frame = self._load_frame(idx)
frame = self._load_frame(idx, load_images=False)
ep_idx = frame["episode_index"]
# 获取当前 episode 的帧索引范围
@@ -186,10 +187,10 @@ class SimpleRobotDataset(Dataset):
target_idx = idx + delta
if target_idx <= ep_end:
actions.append(self._load_frame(target_idx)["action"])
actions.append(self._load_frame(target_idx, load_images=False)["action"])
action_is_pad.append(False)
else:
actions.append(self._load_frame(ep_end)["action"])
actions.append(self._load_frame(ep_end, load_images=False)["action"])
action_is_pad.append(True)
# ============================================
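
For context, a minimal standalone sketch of the pattern these hunks introduce: `_load_frame` gains a keyword-only `load_images` flag, so callers that only need low-dimensional fields (for example the future `action` vectors appended while building the action chunk) can skip the per-camera read, resize, and normalize work entirely. The helper name and the `action` dataset path below are assumptions for illustration; only the `observations/images/{cam_name}` layout, the 224x224 resize, and the HWC -> CHW float conversion come from the diff.

```python
from typing import Dict

import h5py
import torch


def load_frame(f: h5py.File, frame_idx: int, camera_names, *, load_images: bool = True) -> Dict:
    """Illustrative loader; the 'action' dataset path is an assumption, not repository code."""
    frame = {"action": torch.as_tensor(f["action"][frame_idx], dtype=torch.float32)}
    if load_images:
        import cv2  # imported lazily, as in the diff
        for cam_name in camera_names:
            key = f"observations/images/{cam_name}"
            if key in f:
                # Resize to 224x224 to reduce memory and I/O, then convert to float CHW in [0, 1]
                img = cv2.resize(f[key][frame_idx], (224, 224), interpolation=cv2.INTER_LINEAR)
                frame[f"observation.{cam_name}"] = torch.from_numpy(img).float().div_(255.0).permute(2, 0, 1)
    return frame
```

With this flag, the future steps of an action chunk can be fetched as `load_frame(f, t, camera_names, load_images=False)`, so chunk construction and padding never touch the image datasets.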