feat(train): get the training script running
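The change below drops the per-timestep index list used to gather observations. That list can contain repeated entries such as [0, 0] when padding at the start of an episode, which h5py's fancy indexing does not accept, so the new code slices the HDF5 datasets once and front-pads explicitly. A minimal standalone sketch of that windowing logic on a plain NumPy array; padded_obs_window is a hypothetical helper name, the dataset code in the diff inlines the same steps:

import numpy as np

def padded_obs_window(data, start_ts, obs_horizon):
    """Return the (obs_horizon, ...) window ending at start_ts,
    front-padded with the first frame when start_ts < obs_horizon - 1."""
    start_idx_raw = start_ts - (obs_horizon - 1)
    start_idx = max(start_idx_raw, 0)
    pad_len = max(0, -start_idx_raw)
    window = data[start_idx:start_ts + 1]
    if pad_len > 0:
        padding = np.repeat(window[:1], pad_len, axis=0)
        window = np.concatenate([padding, window], axis=0)
    return window

# obs_horizon=2, start_ts=0 -> frame 0 repeated twice; obs_horizon=2, start_ts=5 -> frames [4, 5]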
@@ -90,52 +90,48 @@ class RobotDiffusionDataset(Dataset):
-        # If obs_horizon=2, current_ts=0 -> indices=[0, 0] (padding)
-        # If obs_horizon=2, current_ts=5 -> indices=[4, 5]
-
-        indices = []
-        for i in range(self.obs_horizon):
-            # t - (To - 1) + i
-            query_ts = start_ts - (self.obs_horizon - 1) + i
-            # Boundary handling (pad with the first frame)
-            query_ts = max(query_ts, 0)
-            indices.append(query_ts)
-
-        # Read qpos (proprioception)
-        qpos_data = root['observations/qpos']
-        qpos = qpos_data[indices] # smart indexing
-        if self.stats:
-            qpos = self._normalize_data(qpos, self.stats['qpos'])
-
-        # Read images
-        # There are three camera views: angle, r_vis, top
-        # Suggestion: return them separately, or concatenate them inside the Dataset
+        start_idx_raw = start_ts - (self.obs_horizon - 1)
+        start_idx = max(start_idx_raw, 0)
+        end_idx = start_ts + 1
+        pad_len = max(0, -start_idx_raw)
+
+        # Qpos
+        qpos_data = root['observations/qpos']
+        qpos_val = qpos_data[start_idx:end_idx]
+
+        if pad_len > 0:
+            first_frame = qpos_val[0]
+            padding = np.repeat(first_frame[np.newaxis, :], pad_len, axis=0)
+            qpos_val = np.concatenate([padding, qpos_val], axis=0)
+
+        if self.stats:
+            qpos_val = self._normalize_data(qpos_val, self.stats['qpos'])
+
+        # Images
         image_dict = {}
         for cam_name in self.camera_names:
-            # HDF5 dataset
             img_dset = root['observations']['images'][cam_name]
-
-            imgs = []
-            for t in indices:
-                img = img_dset[t] # (480, 640, 3) uint8
-                img = torch.from_numpy(img).permute(2, 0, 1).float() / 255.0 # (C, H, W)
-                imgs.append(img)
-
-            # Stack time dimension: (obs_horizon, 3, H, W)
-            image_dict[cam_name] = torch.stack(imgs)
+            imgs_np = img_dset[start_idx:end_idx] # (T, H, W, C)
+
+            if pad_len > 0:
+                first_frame = imgs_np[0]
+                padding = np.repeat(first_frame[np.newaxis, ...], pad_len, axis=0)
+                imgs_np = np.concatenate([padding, imgs_np], axis=0)
+
+            # Convert to tensor: (T, H, W, C) -> (T, C, H, W)
+            imgs_tensor = torch.from_numpy(imgs_np).float() / 255.0
+            imgs_tensor = torch.einsum('thwc->tchw', imgs_tensor)
+            image_dict[cam_name] = imgs_tensor
 
-        # -----------------------------
-        # 4. Assemble batch
-        # -----------------------------
+        # ==============================
+        # 3. Assemble batch
+        # ==============================
         data_batch = {
-            'action': torch.from_numpy(actions).float(), # (Tp, 16)
-            'qpos': torch.from_numpy(qpos).float(), # (To, 16)
+            'action': torch.from_numpy(actions).float(),
+            'qpos': torch.from_numpy(qpos_val).float(),
         }
-        # Put the images into the batch
         for cam_name, img_tensor in image_dict.items():
-            data_batch[f'image_{cam_name}'] = img_tensor # (To, 3, H, W)
-
-        # TODO: add a language instruction
-        # If all episodes share the same task, this can be a fixed embedding
-        # If each episode has a different task, you need an extra meta json mapping file_path -> text
-        # data_batch['lang_text'] = "pick up the red cube"
+            data_batch[f'image_{cam_name}'] = img_tensor
 
         return data_batch
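Note: the diff calls self._normalize_data(qpos_val, self.stats['qpos']) but that helper lives outside this hunk. A common choice for diffusion-policy inputs is per-dimension min-max scaling to [-1, 1]; the sketch below is only an assumption about what it does, including the 'min'/'max' keys of the stats entry:

import numpy as np

def _normalize_data(data, stats):
    # Assumed implementation: per-dimension min-max scaling to [-1, 1].
    rng = np.maximum(stats['max'] - stats['min'], 1e-8)  # avoid division by zero
    ndata = (data - stats['min']) / rng                   # -> [0, 1]
    return ndata * 2.0 - 1.0                              # -> [-1, 1]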
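For reference, the torch.einsum('thwc->tchw', ...) conversion in the new image path is the same reordering as permute(0, 3, 1, 2); the runnable check below uses the 480x640 RGB shape mentioned in the removed comments:

import torch

x = torch.rand(2, 480, 640, 3)        # (T, H, W, C), as returned by the HDF5 slice
a = torch.einsum('thwc->tchw', x)     # conversion used in the diff
b = x.permute(0, 3, 1, 2)             # equivalent index reordering
assert a.shape == (2, 3, 480, 640) and torch.equal(a, b)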
||||