diff --git a/roboimi/vla/data/dataset.py b/roboimi/vla/data/dataset.py
index 43bdd53..a3eceb5 100644
--- a/roboimi/vla/data/dataset.py
+++ b/roboimi/vla/data/dataset.py
@@ -2,87 +2,80 @@ import h5py
 import torch
 import numpy as np
 from torch.utils.data import Dataset
+from typing import Dict, List, Any, Optional
 
-class VLAHDF5Dataset(Dataset):
-    def __init__(self,
-                 dataset_path: str,
-                 pred_horizon: int = 16,
-                 obs_horizon: int = 2,
-                 transform=None):
-        self.dataset_path = dataset_path
+class VLAChunkedDataset(Dataset):
+    def __init__(
+        self,
+        data_path: str,
+        pred_horizon: int = 16,
+        obs_horizon: int = 2,
+        obs_keys: Optional[List[str]] = None  # None instead of a mutable default argument
+    ):
+        self.data_path = data_path
         self.pred_horizon = pred_horizon
         self.obs_horizon = obs_horizon
-        self.transform = transform
+        self.obs_keys = obs_keys if obs_keys is not None else ["top", "angle"]
+        self.file_handle = None
 
-        # 1. At init time we only read the data's metadata (shapes, lengths),
-        # never the contents. This is fast and costs no memory.
-        with h5py.File(self.dataset_path, 'r') as root:
-            self.demo_keys = list(root['data'].keys())
+        with h5py.File(self.data_path, 'r') as f:
+            self.total_len = f["action"].shape[0]
+            # Try to read the language instruction from the file attributes.
+            # This assumes your format stores it as a root-level attribute;
+            # otherwise, specify it manually.
+            self.lang_instruction = f.attrs.get("language", "perform the task")
+            if isinstance(self.lang_instruction, bytes):
+                self.lang_instruction = self.lang_instruction.decode("utf-8")
 
-            # Build an index table of (demo_key, start_time) pairs
-            self.indices = []
-            for key in self.demo_keys:
-                demo = root['data'][key]
-                L = demo['actions'].shape[0]
-                # Enumerate every timestep of this trajectory
-                for t in range(L):
-                    self.indices.append((key, t))
-
+    def _get_handle(self):
+        # Open the file lazily, on first access inside __getitem__. Each
+        # DataLoader worker process therefore opens its own handle; h5py
+        # file handles must not be shared across forked workers.
+        if self.file_handle is None:
+            self.file_handle = h5py.File(self.data_path, 'r', swmr=True)
+        return self.file_handle
+
     def __len__(self):
-        return len(self.indices)
+        return self.total_len
 
-    def __getitem__(self, idx):
-        key, t_start = self.indices[idx]
+    def __getitem__(self, idx: int) -> Dict[str, Any]:
+        f = self._get_handle()
+        t_start = idx
 
-        # 2. [Key point] Open the file inside __getitem__.
-        # This guarantees that every DataLoader worker gets its own file handle.
-        with h5py.File(self.dataset_path, 'r') as root:
-            demo = root['data'][key]
-
-            # Total length of this trajectory
-            L = demo['actions'].shape[0]
-
-            # --- Read actions ---
-            t_end = min(t_start + self.pred_horizon, L)
-            # HDF5 supports direct slice reads, which are very fast
-            actions = demo['actions'][t_start : t_end]
-
-            # Handle padding (when the chunk runs past the trajectory end)
-            if len(actions) < self.pred_horizon:
-                # Convert to a tensor before padding
-                actions = torch.from_numpy(actions)
-                pad_len = self.pred_horizon - len(actions)
-                last_action = actions[-1].unsqueeze(0)
-                actions = torch.cat([actions, last_action.repeat(pad_len, 1)])
-                action_mask = torch.cat([torch.ones(len(actions)-pad_len), torch.zeros(pad_len)])
-            else:
-                actions = torch.from_numpy(actions)
-                action_mask = torch.ones(self.pred_horizon)
-
-            # --- Read images ---
-            # Pad the observation history when t_start < obs_horizon
-            images_list = []
+        # --- 1. Actions & mask ---
+        t_end = min(t_start + self.pred_horizon, self.total_len)
+        actual_len = t_end - t_start
+
+        actions_np = f["action"][t_start:t_end]
+
+        # Build the mask: 1 marks real data, 0 marks padding.
+        # It lets the loss ignore the repeated actions at the trajectory end.
+        action_mask = torch.ones(self.pred_horizon, dtype=torch.float32)
+
+        if actual_len < self.pred_horizon:
+            pad_len = self.pred_horizon - actual_len
+            # Pad by repeating the last valid action
+            pad_block = np.tile(actions_np[-1], (pad_len, 1))
+            actions_np = np.concatenate([actions_np, pad_block], axis=0)
+            # Zero out the mask over the padded tail
+            action_mask[actual_len:] = 0.0
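+
+        # NOTE (usage sketch, not part of this class): downstream, the trainer
+        # is expected to apply the mask when reducing the loss, roughly:
+        #   per_step = F.mse_loss(pred, actions, reduction="none").mean(-1)
+        #   loss = (per_step * action_mask).sum() / action_mask.sum().clamp(min=1)
+        # The actual training loop lives elsewhere and may differ.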
+
+        # --- 2. Observations ---
+        obs_dict = {}
+        for key in self.obs_keys:
+            imgs = []
             for i in range(self.obs_horizon):
-                t_read = max(0, t_start - self.obs_horizon + 1 + i)
-                # Read a single frame
-                img = demo['obs']['agentview_rgb'][t_read]
-                images_list.append(img)
+                t_query = max(0, t_start - (self.obs_horizon - 1) + i)
+                imgs.append(f[f"observations/images/{key}"][t_query])
 
-            # Stack and convert to a tensor: [T, H, W, C] -> [T, C, H, W]
-            images = np.stack(images_list)
-            images = torch.from_numpy(images).permute(0, 3, 1, 2).float() / 255.0
-
-            # --- Read the language instruction ---
-            # Assume the language is stored in the demo attributes (Robomimic style)
-            lang_text = demo.attrs.get("model_file", "")  # or a custom field
+            img_stack = np.stack(imgs).astype(np.float32) / 255.0
+            img_stack = img_stack.transpose(0, 3, 1, 2)
+            obs_dict[key] = torch.from_numpy(img_stack)
 
-        # 3. Apply image augmentation
-        if self.transform:
-            images = self.transform(images)
+        # --- 3. Low-dim state ---
+        # Maps to the qpos dataset in your file
+        qpos = f["observations/qpos"][t_start].astype(np.float32)
 
         return {
-            "images": images,
-            "text": lang_text,  # tokenized later in the collate_fn
-            "actions": actions,
-            "action_mask": action_mask
+            "obs": obs_dict,                                  # visual input
+            "qpos": torch.from_numpy(qpos),                   # proprioception (joint angles)
+            "actions": torch.from_numpy(actions_np).float(),
+            "action_mask": action_mask,                       # loss mask
+            "language": self.lang_instruction                 # text instruction
         }
\ No newline at end of file
diff --git a/roboimi/vla/scripts/visualize_data.py b/roboimi/vla/scripts/visualize_data.py
index 1a439cf..10ad1dd 100644
--- a/roboimi/vla/scripts/visualize_data.py
+++ b/roboimi/vla/scripts/visualize_data.py
@@ -1 +1,135 @@
-# Check that the Dataset reads correctly
+import os
+import cv2
+import torch
+import numpy as np
+import argparse
+from torch.utils.data import DataLoader
+from roboimi.vla.data.dataset import VLAChunkedDataset
+
+# Color constants (BGR)
+COLOR_TEXT = (255, 255, 255)
+COLOR_VALID = (0, 255, 0)   # valid actions are drawn in green
+COLOR_PAD = (0, 0, 255)     # padded actions are drawn in red
+
+def render_text_block(canvas_width, text_lines):
+    """Create an image block that displays lines of text."""
+    h_per_line = 30
+    h = len(text_lines) * h_per_line + 20
+    block = np.zeros((h, canvas_width, 3), dtype=np.uint8)
+    for i, line in enumerate(text_lines):
+        cv2.putText(block, line, (10, 30 + i * h_per_line),
+                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, COLOR_TEXT, 1)
+    return block
+
+def visualize_dataset(data_path: str, output_dir: str):
+    os.makedirs(output_dir, exist_ok=True)
+
+    # 1. Instantiate the Dataset (using its latest definition)
+    dataset = VLAChunkedDataset(
+        data_path=data_path,
+        pred_horizon=16,           # predict 16 steps into the future
+        obs_horizon=2,             # observe the past 2 frames
+        obs_keys=["top", "angle"]  # the two camera views
+    )
+
+    # Use a DataLoader to mimic how data is read during training
+    dataloader = DataLoader(dataset, batch_size=1, shuffle=False)
+
+    print(f"[VISUALIZE] Writing sample-check images to: {output_dir}")
+    print(f"  - Total dataset length: {len(dataset)}")
+
+    # Sample the first few and last few indices (the end exercises the mask logic)
+    indices_to_check = set(range(0, 5)) | set(range(len(dataset) - 5, len(dataset)))
+
+    for i, batch in enumerate(dataloader):
+        is_end = i >= len(dataset) - 5
+        if i not in indices_to_check:
+            continue
+
+        # --- Unpack the batch ---
+        # Batch size is 1, so take index 0
+        obs = batch['obs']                      # Dict
+        qpos = batch['qpos'][0].numpy()         # [State_Dim]
+        actions = batch['actions'][0].numpy()   # [Pred_Horizon, Action_Dim]
+        mask = batch['action_mask'][0].numpy()  # [Pred_Horizon]
+        lang = batch['language'][0]             # String
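+
+        # Cheap sanity checks before rendering (assumes the shapes produced
+        # by VLAChunkedDataset above); catching format drift here is easier
+        # than debugging a corrupted canvas later.
+        assert actions.shape[0] == dataset.pred_horizon, actions.shape
+        assert mask.shape == (dataset.pred_horizon,), mask.shape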
+
+        # --- 1. Image rendering (obs) ---
+        # Layout: concatenate history frames horizontally, camera views vertically
+        view_blocks = []
+        for key in dataset.obs_keys:
+            # tensor: [1, T, C, H, W] -> [T, C, H, W]
+            imgs_tensor = obs[key][0]
+            T, C, H, W = imgs_tensor.shape
+
+            frame_list = []
+            for t in range(T):
+                # [C, H, W] -> [H, W, C] -> numpy
+                img_np = imgs_tensor[t].permute(1, 2, 0).numpy()
+                img_np = (img_np * 255).astype(np.uint8)
+                img_bgr = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
+
+                # Label the timestep (t-1, t-0)
+                label = f"{key} (t - {T-1-t})"
+                cv2.putText(img_bgr, label, (10, 30),
+                            cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2)
+                frame_list.append(img_bgr)
+
+            # Concatenate history frames horizontally
+            view_blocks.append(np.hstack(frame_list))
+
+        # Stack the camera views vertically
+        visual_block = np.vstack(view_blocks)
+        H_vis, W_vis, _ = visual_block.shape
+
+        # --- 2. Text info rendering (Language & QPos) ---
+        info_lines = [
+            f"Sample Index: {i} {'(TRAJECTORY END)' if is_end else ''}",
+            f"Language: {lang}",
+            f"Current QPos (First 6): {np.round(qpos[:6], 3)}"
+        ]
+        info_block = render_text_block(W_vis, info_lines)
+
+        # --- 3. Action chunk rendering (Action Chunk & Mask) ---
+        # A dedicated block showing the values and validity of every action
+        action_lines = [f"Future Action Chunk (Pred Horizon={dataset.pred_horizon}):"]
+        for t_act in range(len(actions)):
+            # Check the mask
+            is_valid = mask[t_act] > 0.5
+            status = "[VALID]" if is_valid else "[PAD]  "
+            vals = np.round(actions[t_act][:6], 3)  # show only the first 6 dims
+            line = f" t+{t_act:02d} {status} {vals}"
+            action_lines.append(line)
+
+        # Per-line text colors would be fiddly, so keep a plain text block
+        # and draw color bars on top of it instead
+        action_block = render_text_block(W_vis, action_lines)
+
+        # Add color markers to the action block
+        # Simple scheme: a red bar left of PAD lines, a green bar left of VALID lines
+        line_h = 30
+        # Action line t_act sits at baseline 30 + (t_act + 1) * line_h
+        # (line 0 is the header), so offset the bars from y = 60
+        start_y = 60
+        for t_act in range(len(actions)):
+            is_valid = mask[t_act] > 0.5
+            color = COLOR_VALID if is_valid else COLOR_PAD
+            # Draw a small rectangular indicator
+            cv2.rectangle(action_block,
+                          (0, start_y + t_act*line_h - 20),
+                          (5, start_y + t_act*line_h - 5), color, -1)
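+
+        # All three blocks were rendered at the shared width W_vis, so the
+        # vertical stack below is shape-safe; np.vstack raises otherwise.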
+
+        # --- 4. Final composite ---
+        final_img = np.vstack([info_block, visual_block, action_block])
+
+        save_path = os.path.join(output_dir, f"check_{i:04d}.png")
+        cv2.imwrite(save_path, final_img)
+
+    print(f"\n[SUCCESS] Visualization done. Inspect the last few images in {output_dir} (the mask bars should turn red).")
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--data", type=str, default="roboimi/demos/dataset/sim_transfer/episode_0.hdf5", help="Path to the data")
+    parser.add_argument("--out", type=str, default="vla_debug_vis", help="Output directory")
+    args = parser.parse_args()
+
+    visualize_dataset(args.data, args.out)
\ No newline at end of file
diff --git a/roboimi/vla/scripts/visualize_episode.py b/roboimi/vla/scripts/visualize_episode.py
new file mode 100644
index 0000000..605be3d
--- /dev/null
+++ b/roboimi/vla/scripts/visualize_episode.py
@@ -0,0 +1,89 @@
+import h5py
+import cv2
+import numpy as np
+import argparse
+import os
+from tqdm import tqdm
+
+def visualize_episode(hdf5_path: str, output_path: str, fps: int = 30):
+    """
+    Convert a single episode_x.hdf5 into a visualization video with telemetry overlaid.
+    """
+    if not os.path.exists(hdf5_path):
+        print(f"Error: file not found: {hdf5_path}")
+        return
+
+    # If output_path is a directory, derive the file name automatically
+    if os.path.isdir(output_path) or not output_path.endswith('.mp4'):
+        os.makedirs(output_path, exist_ok=True)
+        base_name = os.path.splitext(os.path.basename(hdf5_path))[0]
+        output_path = os.path.join(output_path, f"{base_name}.mp4")
+    else:
+        # Make sure the output directory exists
+        output_dir = os.path.dirname(output_path)
+        if output_dir:
+            os.makedirs(output_dir, exist_ok=True)
+
+    with h5py.File(hdf5_path, 'r') as f:
+        # Fetch the base data
+        images_grp = f['observations/images']
+        qpos = f['observations/qpos'][:]
+        actions = f['action'][:]
+
+        # List of camera views
+        views = list(images_grp.keys())  # ['angle', 'r_vis', 'top']
+        num_steps = images_grp[views[0]].shape[0]
+
+        # Video parameters
+        # The views are concatenated horizontally: (H, W * num_views, 3)
+        h, w, _ = images_grp[views[0]][0].shape
+        out_w = w * len(views)
+        out_h = h + 150  # reserve 150 px at the bottom for telemetry text
+
+        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+        video_writer = cv2.VideoWriter(output_path, fourcc, fps, (out_w, out_h))
+
+        print(f"Processing {num_steps} frames...")
+        for t in tqdm(range(num_steps)):
+            # 1. Concatenate the camera views
+            frame_views = []
+            for view_name in views:
+                img = images_grp[view_name][t]
+                # HDF5 usually stores RGB; OpenCV expects BGR
+                img_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
+                # Label the view name in the top-left corner
+                cv2.putText(img_bgr, view_name, (20, 40),
+                            cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
+                frame_views.append(img_bgr)
+
+            combined_img = np.hstack(frame_views)
+
+            # 2. Create the bottom info bar
+            info_bar = np.zeros((150, out_w, 3), dtype=np.uint8)
+
+            # 3. Render the telemetry text (qpos and action)
+            # Show the first 7 dims as representatives (typically 6 arm DoF + gripper)
+            qpos_str = "qpos (0-6): " + " ".join([f"{x:.2f}" for x in qpos[t][:7]])
+            act_str = "action(0-6): " + " ".join([f"{x:.2f}" for x in actions[t][:7]])
+
+            cv2.putText(info_bar, qpos_str, (20, 50),
+                        cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
+            cv2.putText(info_bar, act_str, (20, 100),
+                        cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2)
+            cv2.putText(info_bar, f"Step: {t}/{num_steps}", (out_w - 200, 75),
+                        cv2.FONT_HERSHEY_SIMPLEX, 0.7, (150, 150, 150), 2)
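+
+            # final_frame must come out exactly (out_h, out_w, 3): OpenCV's
+            # VideoWriter silently drops frames whose size differs from the
+            # size given at construction, yielding an empty or corrupt video.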
+
+            # 4. Merge the image strip with the info bar
+            final_frame = np.vstack([combined_img, info_bar])
+            video_writer.write(final_frame)
+
+    video_writer.release()
+    print(f"\n[SUCCESS] Visualization video saved to: {output_path}")
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Visualize a single episode HDF5 file")
+    parser.add_argument("--input", type=str, required=True, help="Input hdf5 path")
+    parser.add_argument("--output", type=str, default="debug_episode.mp4", help="Output video path")
+    args = parser.parse_args()
+
+    visualize_episode(args.input, args.output)
\ No newline at end of file