跑通配置和训练脚本

This commit is contained in:
gouhanke
2026-02-03 16:51:04 +08:00
parent bd8bbb0cfc
commit 3b58760469
5 changed files with 227 additions and 90 deletions

View File

@@ -1,6 +1,8 @@
import h5py
import torch
import numpy as np
import os
import glob
from torch.utils.data import Dataset
from typing import Dict, List, Any
@@ -10,72 +12,108 @@ class VLAChunkedDataset(Dataset):
data_path: str,
pred_horizon: int = 16,
obs_horizon: int = 2,
obs_keys: List[str] = ["top", "angle"]
obs_keys: List[str] = ["top"] # 默认只用 top
):
self.data_path = data_path
self.pred_horizon = pred_horizon
self.obs_horizon = obs_horizon
self.obs_keys = obs_keys
self.file_handle = None
with h5py.File(self.data_path, 'r') as f:
self.total_len = f["action"].shape[0]
# 尝试从属性或特定路径读取语言指令
# 假设你的格式中语言存在根目录属性里,或者你手动指定
self.lang_instruction = f.attrs.get("language", "执行任务")
if isinstance(self.lang_instruction, bytes):
self.lang_instruction = self.lang_instruction.decode("utf-8")
# --- 1. 扫描文件 ---
if os.path.isdir(data_path):
# 如果是文件夹,读取所有 episode_*.hdf5
self.file_paths = sorted(glob.glob(os.path.join(data_path, "*.hdf5")))
else:
# 如果是单文件
self.file_paths = [data_path]
def _get_handle(self):
if self.file_handle is None:
self.file_handle = h5py.File(self.data_path, 'r', swmr=True)
return self.file_handle
if len(self.file_paths) == 0:
raise ValueError(f"No .hdf5 files found in {data_path}")
print(f"Found {len(self.file_paths)} episodes. Indexing...")
# --- 2. 建立全局索引 (Episode, Time) ---
# 我们需要知道 global_index=1000 对应的是哪个文件的第几帧
self.index_map = [] # [(file_idx, start_time), ...]
for i, path in enumerate(self.file_paths):
with h5py.File(path, 'r') as f:
# 假设所有文件的 action 长度就是 episode 长度
total_len = f["action"].shape[0]
# 有效的起始点:从 0 到 total_len - 1
# 即使到了最后几帧,因为有 padding,所以也是有效的 sample
for t in range(total_len):
self.index_map.append((i, t))
print(f"✅ Indexed {len(self.index_map)} total samples.")
def __len__(self):
return self.total_len
return len(self.index_map)
def __getitem__(self, idx: int) -> Dict[str, Any]:
f = self._get_handle()
t_start = idx
# --- 1. 定位文件 ---
file_idx, t_start = self.index_map[idx]
file_path = self.file_paths[file_idx]
# --- 1. 动作与掩码 (Action & Mask) ---
t_end = min(t_start + self.pred_horizon, self.total_len)
actual_len = t_end - t_start
actions_np = f["action"][t_start:t_end]
# 创建掩码:1 表示真实数据,0 表示 Padding
# 这是为了在计算 Loss 时屏蔽掉末端重复的动作
action_mask = torch.ones(self.pred_horizon, dtype=torch.float32)
if actual_len < self.pred_horizon:
pad_len = self.pred_horizon - actual_len
# 填充最后一个有效动作
pad_block = np.tile(actions_np[-1], (pad_len, 1))
actions_np = np.concatenate([actions_np, pad_block], axis=0)
# 将填充部分的掩码置为 0
action_mask[actual_len:] = 0.0
# --- 2. 观察值 (Observations) ---
obs_dict = {}
for key in self.obs_keys:
imgs = []
for i in range(self.obs_horizon):
t_query = max(0, t_start - (self.obs_horizon - 1) + i)
imgs.append(f[f"observations/images/{key}"][t_query])
# 每次读取打开文件 (Lazy Loading),读取完自动关闭
# 这种方式对多进程 DataLoader 最安全
with h5py.File(file_path, 'r') as f:
total_len = f["action"].shape[0]
img_stack = np.stack(imgs).astype(np.float32) / 255.0
img_stack = img_stack.transpose(0, 3, 1, 2)
obs_dict[key] = torch.from_numpy(img_stack)
# --- 2. 动作 (Action) ---
t_end = min(t_start + self.pred_horizon, total_len)
# 读取动作片段
actions_np = f["action"][t_start:t_end] # (L, 16)
# Padding 处理
actual_len = actions_np.shape[0]
action_mask = torch.ones(self.pred_horizon, dtype=torch.float32)
if actual_len < self.pred_horizon:
pad_len = self.pred_horizon - actual_len
# 重复最后一帧动作进行填充
pad_block = np.tile(actions_np[-1], (pad_len, 1))
actions_np = np.concatenate([actions_np, pad_block], axis=0)
# 标记 Padding 部分为 0
action_mask[actual_len:] = 0.0
# --- 3. 图像 (Images) ---
obs_dict = {}
for key in self.obs_keys:
imgs = []
# 处理观测历史 (Obs Horizon)
# 如果 t_start=0, obs_horizon=2, 我们需要读取 t=0 和 t=0 (重复第一帧)
for i in range(self.obs_horizon):
# 倒序读取:当前帧,前一帧...
# 注意:这里逻辑是 [t_start - (obs_horizon-1) + i]
# 比如 horizon=2, t=10. i=0 -> t=9; i=1 -> t=10.
query_t = t_start - (self.obs_horizon - 1) + i
query_t = max(0, query_t) # 边界保护
imgs.append(f[f"observations/images/{key}"][query_t])
# Stack -> (Obs_Horizon, H, W, C)
img_stack = np.stack(imgs)
# Normalize & Permute -> (Obs_Horizon, C, H, W)
img_stack = img_stack.astype(np.float32) / 255.0
img_stack = np.transpose(img_stack, (0, 3, 1, 2))
obs_dict[key] = torch.from_numpy(img_stack)
# --- 3. 状态值 (Low-dim State) ---
# 对应你文件里的 qpos
qpos = f["observations/qpos"][t_start].astype(np.float32)
# --- 4. QPos ---
qpos = f["observations/qpos"][t_start].astype(np.float32)
# --- 5. Language ---
# 暂时写死或从 attrs 读取
lang = f.attrs.get("language", "task instruction placeholder")
if isinstance(lang, bytes):
lang = lang.decode("utf-8")
return {
"obs": obs_dict, # 视觉输入
"qpos": torch.from_numpy(qpos), # 本体感受 (关节角)
"obs": obs_dict,
"qpos": torch.from_numpy(qpos),
"actions": torch.from_numpy(actions_np).float(),
"action_mask": action_mask, # Loss 掩码
"language": self.lang_instruction # 文本指令
"action_mask": action_mask,
"language": lang
}