添加pad_loss

This commit is contained in:
gouhanke
2026-02-11 20:33:26 +08:00
parent eeb07cad15
commit 83cd55e67b
5 changed files with 27 additions and 8 deletions

View File

@@ -86,8 +86,8 @@ class SimpleRobotDataset(Dataset):
h5_path = f'observations/images/{cam_name}'
if h5_path in f:
img = f[h5_path][meta["frame_idx"]]
img = torch.from_numpy(img)
# 保持 uint8 格式以节省传输带宽,归一化移至 GPU (在 train_vla.py 中处理)
# 转换为float并归一化到 [0, 1]
img = torch.from_numpy(img).float() / 255.0
frame[f"observation.{cam_name}"] = img.permute(2, 0, 1) # HWC -> CHW
return frame