feat(train): 跑通训练脚本

This commit is contained in:
gouhanke
2026-02-05 14:08:43 +08:00
parent dd2749cb12
commit b0a944f7aa
17 changed files with 1002 additions and 464 deletions

View File

@@ -90,52 +90,48 @@ class RobotDiffusionDataset(Dataset):
# Example: obs_horizon=2, current timestep (start_ts)=0 -> indices=[0, 0] (first frame repeated as padding)
# Example: obs_horizon=2, current timestep (start_ts)=5 -> indices=[4, 5]
indices = []
for i in range(self.obs_horizon):
# t - (To - 1) + i
query_ts = start_ts - (self.obs_horizon - 1) + i
# Boundary handling: clamp negative timesteps to 0 so the first frame is reused as padding
query_ts = max(query_ts, 0)
indices.append(query_ts)
# Read qpos (proprioception)
qpos_data = root['observations/qpos']
qpos = qpos_data[indices] # smart indexing
if self.stats:
qpos = self._normalize_data(qpos, self.stats['qpos'])
start_idx_raw = start_ts - (self.obs_horizon - 1)
start_idx = max(start_idx_raw, 0)
end_idx = start_ts + 1
pad_len = max(0, -start_idx_raw)
# Read images
# There are three camera views: angle, r_vis, top
# Suggestion: either return them separately or concatenate them inside the Dataset
# Qpos
qpos_data = root['observations/qpos']
qpos_val = qpos_data[start_idx:end_idx]
if pad_len > 0:
first_frame = qpos_val[0]
padding = np.repeat(first_frame[np.newaxis, :], pad_len, axis=0)
qpos_val = np.concatenate([padding, qpos_val], axis=0)
if self.stats:
qpos_val = self._normalize_data(qpos_val, self.stats['qpos'])
# Images
image_dict = {}
for cam_name in self.camera_names:
# HDF5 dataset
img_dset = root['observations']['images'][cam_name]
imgs_np = img_dset[start_idx:end_idx] # (T, H, W, C)
imgs = []
for t in indices:
img = img_dset[t] # (480, 640, 3) uint8
img = torch.from_numpy(img).permute(2, 0, 1).float() / 255.0 # (C, H, W)
imgs.append(img)
if pad_len > 0:
first_frame = imgs_np[0]
padding = np.repeat(first_frame[np.newaxis, ...], pad_len, axis=0)
imgs_np = np.concatenate([padding, imgs_np], axis=0)
# Stack time dimension: (obs_horizon, 3, H, W)
image_dict[cam_name] = torch.stack(imgs)
# Convert to a Tensor and reorder axes: (T, H, W, C) -> (T, C, H, W)
imgs_tensor = torch.from_numpy(imgs_np).float() / 255.0
imgs_tensor = torch.einsum('thwc->tchw', imgs_tensor)
image_dict[cam_name] = imgs_tensor
# -----------------------------
# 4. Assemble the batch
# -----------------------------
# ==============================
# 3. Assemble the batch
# ==============================
data_batch = {
'action': torch.from_numpy(actions).float(), # (Tp, 16)
'qpos': torch.from_numpy(qpos).float(), # (To, 16)
'action': torch.from_numpy(actions).float(),
'qpos': torch.from_numpy(qpos_val).float(),
}
# Add each camera's image tensor to the batch
for cam_name, img_tensor in image_dict.items():
data_batch[f'image_{cam_name}'] = img_tensor # (To, 3, H, W)
# TODO: add a language instruction
# If all episodes share the same task, this can be a fixed embedding
# If each episode has a different task, an extra meta JSON is needed to map file_path -> text
# data_batch['lang_text'] = "pick up the red cube"
data_batch[f'image_{cam_name}'] = img_tensor
return data_batch