feat(dataset): define the VLAChunkedDataset class and build data visualization tools

gouhanke
2026-02-03 15:24:09 +08:00
parent 57acfd645f
commit d3863ea1dd
3 changed files with 287 additions and 71 deletions


@@ -2,87 +2,80 @@ import h5py
 import torch
 import numpy as np
 from torch.utils.data import Dataset
+from typing import Dict, List, Any
 
-class VLAHDF5Dataset(Dataset):
-    def __init__(self,
-                 dataset_path: str,
-                 pred_horizon: int = 16,
-                 obs_horizon: int = 2,
-                 transform=None):
-        self.dataset_path = dataset_path
+class VLAChunkedDataset(Dataset):
+    def __init__(
+        self,
+        data_path: str,
+        pred_horizon: int = 16,
+        obs_horizon: int = 2,
+        obs_keys: List[str] = ["top", "angle"]
+    ):
+        self.data_path = data_path
         self.pred_horizon = pred_horizon
         self.obs_horizon = obs_horizon
-        self.transform = transform
+        self.obs_keys = obs_keys
+        self.file_handle = None
 
-        # 1. At init time we only read the data's "metadata" (shapes, lengths),
-        #    never the contents. This step is fast and uses almost no memory.
-        with h5py.File(self.dataset_path, 'r') as root:
-            self.demo_keys = list(root['data'].keys())
-
-            # Build the index table: (demo_key, start_time)
-            self.indices = []
-            for key in self.demo_keys:
-                demo = root['data'][key]
-                L = demo['actions'].shape[0]
-                # Enumerate every timestep of this trajectory
-                for t in range(L):
-                    self.indices.append((key, t))
+        with h5py.File(self.data_path, 'r') as f:
+            self.total_len = f["action"].shape[0]
+            # Try to read the language instruction from the attributes or a
+            # dedicated path; here it is assumed to live in the root attrs
+            # (or be specified manually), with a fixed fallback.
+            self.lang_instruction = f.attrs.get("language", "执行任务")  # fallback: "perform the task"
+            if isinstance(self.lang_instruction, bytes):
+                self.lang_instruction = self.lang_instruction.decode("utf-8")
+
+    def _get_handle(self):
+        if self.file_handle is None:
+            self.file_handle = h5py.File(self.data_path, 'r', swmr=True)
+        return self.file_handle
 
     def __len__(self):
-        return len(self.indices)
+        return self.total_len
 
-    def __getitem__(self, idx):
-        key, t_start = self.indices[idx]
-
-        # 2. [Key] Open the file inside __getitem__ so that every DataLoader
-        #    worker gets its own independent file handle.
-        with h5py.File(self.dataset_path, 'r') as root:
-            demo = root['data'][key]
-            # Total length of this trajectory
-            L = demo['actions'].shape[0]
-
-            # --- Actions ---
-            t_end = min(t_start + self.pred_horizon, L)
-            # HDF5 supports direct slice reads, which are very fast
-            actions = demo['actions'][t_start : t_end]
-
-            # Handle padding (if the action chunk is too short)
-            if len(actions) < self.pred_horizon:
-                # Convert to a Tensor to handle the padding
-                actions = torch.from_numpy(actions)
-                pad_len = self.pred_horizon - len(actions)
-                last_action = actions[-1].unsqueeze(0)
-                actions = torch.cat([actions, last_action.repeat(pad_len, 1)])
-                action_mask = torch.cat([torch.ones(len(actions) - pad_len), torch.zeros(pad_len)])
-            else:
-                actions = torch.from_numpy(actions)
-                action_mask = torch.ones(self.pred_horizon)
-
-            # --- Images ---
-            # Handle observation-history padding (when t_start < obs_horizon)
-            images_list = []
+    def __getitem__(self, idx: int) -> Dict[str, Any]:
+        f = self._get_handle()
+        t_start = idx
+
+        # --- 1. Action chunk & mask ---
+        t_end = min(t_start + self.pred_horizon, self.total_len)
+        actual_len = t_end - t_start
+        actions_np = f["action"][t_start:t_end]
+
+        # Build the mask: 1 marks real data, 0 marks padding, so the loss can
+        # ignore the repeated actions at the end of the episode.
+        action_mask = torch.ones(self.pred_horizon, dtype=torch.float32)
+        if actual_len < self.pred_horizon:
+            pad_len = self.pred_horizon - actual_len
+            # Pad by repeating the last valid action
+            pad_block = np.tile(actions_np[-1], (pad_len, 1))
+            actions_np = np.concatenate([actions_np, pad_block], axis=0)
+            # Zero the mask over the padded steps
+            action_mask[actual_len:] = 0.0
+
+        # --- 2. Observations ---
+        obs_dict = {}
+        for key in self.obs_keys:
+            imgs = []
             for i in range(self.obs_horizon):
-                t_read = max(0, t_start - self.obs_horizon + 1 + i)
-                # Read a single frame
-                img = demo['obs']['agentview_rgb'][t_read]
-                images_list.append(img)
+                t_query = max(0, t_start - (self.obs_horizon - 1) + i)
+                imgs.append(f[f"observations/images/{key}"][t_query])
             # Stack and convert to a Tensor: [T, H, W, C] -> [T, C, H, W]
-            images = np.stack(images_list)
-            images = torch.from_numpy(images).permute(0, 3, 1, 2).float() / 255.0
-
-            # --- Language instruction ---
-            # Assumes the language is stored in the demo attrs (Robomimic style)
-            lang_text = demo.attrs.get("model_file", "")  # or a custom field
-
-        # 3. Apply image augmentations
-        if self.transform:
-            images = self.transform(images)
+            img_stack = np.stack(imgs).astype(np.float32) / 255.0
+            img_stack = img_stack.transpose(0, 3, 1, 2)
+            obs_dict[key] = torch.from_numpy(img_stack)
+
+        # --- 3. Low-dimensional state ---
+        # Corresponds to qpos in the file
+        qpos = f["observations/qpos"][t_start].astype(np.float32)
 
         return {
-            "images": images,
-            "text": lang_text,  # tokenized later in the collate_fn
-            "actions": actions,
-            "action_mask": action_mask
+            "obs": obs_dict,                                  # visual inputs
+            "qpos": torch.from_numpy(qpos),                   # proprioception (joint angles)
+            "actions": torch.from_numpy(actions_np).float(),
+            "action_mask": action_mask,                       # loss mask
+            "language": self.lang_instruction                 # text instruction
         }
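
Two things about the new class are worth spelling out. `_get_handle()` opens the HDF5 file lazily, so with a multi-worker DataLoader each worker process ends up with its own independent handle; and `action_mask` is meant to be multiplied into the training loss so the repeated end-of-episode actions don't contribute. A minimal sketch of both, assuming the example episode path used by the debug scripts below (the noisy `pred` tensor is a stand-in for a real policy output):

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from roboimi.vla.data.dataset import VLAChunkedDataset

dataset = VLAChunkedDataset("roboimi/demos/dataset/sim_transfer/episode_0.hdf5")
# Each worker lazily opens its own h5py handle via _get_handle(),
# so num_workers > 0 is safe here.
loader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=4)

batch = next(iter(loader))
pred = batch["actions"] + 0.01 * torch.randn_like(batch["actions"])  # stand-in for a policy output
# Per-step MSE, averaged over action dims: [B, pred_horizon]
err = F.mse_loss(pred, batch["actions"], reduction="none").mean(dim=-1)
# Zero the padded steps so repeated end-of-episode actions don't skew the loss
loss = (err * batch["action_mask"]).sum() / batch["action_mask"].sum()
print(loss.item())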


@@ -1 +1,135 @@
# Check that the Dataset reads correctly
import os
import cv2
import torch
import numpy as np
import argparse
from torch.utils.data import DataLoader
from roboimi.vla.data.dataset import VLAChunkedDataset

# Color constants (BGR)
COLOR_TEXT = (255, 255, 255)
COLOR_VALID = (0, 255, 0)  # valid actions are drawn in green
COLOR_PAD = (0, 0, 255)    # padded actions are drawn in red


def render_text_block(canvas_width, text_lines):
    """Create an image block that displays lines of text."""
    h_per_line = 30
    h = len(text_lines) * h_per_line + 20
    block = np.zeros((h, canvas_width, 3), dtype=np.uint8)
    for i, line in enumerate(text_lines):
        cv2.putText(block, line, (10, 30 + i * h_per_line),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, COLOR_TEXT, 1)
    return block


def visualize_dataset(data_path: str, output_dir: str):
    os.makedirs(output_dir, exist_ok=True)

    # 1. Instantiate the Dataset (using the latest definition)
    dataset = VLAChunkedDataset(
        data_path=data_path,
        pred_horizon=16,           # predict 16 future steps
        obs_horizon=2,             # observe the past 2 frames
        obs_keys=["top", "angle"]  # the two camera views
    )

    # Use a DataLoader to mimic how training reads the data
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False)

    print(f"[VISUALIZE] Writing sample-check images to: {output_dir}")
    print(f" - Total dataset length: {len(dataset)}")

    for i, batch in enumerate(dataloader):
        # Only render the first 5 and last 5 samples; the samples near the
        # end of the trajectory are the ones that exercise the mask logic.
        is_start = i < 5
        is_end = i > (len(dataset) - 6)
        if not (is_start or is_end):
            continue

        # --- Unpack the batch (batch size is 1, so take index 0) ---
        obs = batch['obs']                      # Dict
        qpos = batch['qpos'][0].numpy()         # [State_Dim]
        actions = batch['actions'][0].numpy()   # [Pred_Horizon, Action_Dim]
        mask = batch['action_mask'][0].numpy()  # [Pred_Horizon]
        lang = batch['language'][0]             # String

        # --- 1. Render the images (obs) ---
        # Layout: concatenate history frames horizontally, views vertically
        view_blocks = []
        for key in dataset.obs_keys:
            # tensor: [1, T, C, H, W] -> [T, C, H, W]
            imgs_tensor = obs[key][0]
            T, C, H, W = imgs_tensor.shape
            frame_list = []
            for t in range(T):
                # [C, H, W] -> [H, W, C] -> numpy
                img_np = imgs_tensor[t].permute(1, 2, 0).numpy()
                img_np = (img_np * 255).astype(np.uint8)
                img_bgr = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
                # Label the timestep (t-1, t-0)
                label = f"{key} (t - {T-1-t})"
                cv2.putText(img_bgr, label, (10, 30),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2)
                frame_list.append(img_bgr)
            # Concatenate history frames horizontally
            view_blocks.append(np.hstack(frame_list))
        # Concatenate the different views vertically
        visual_block = np.vstack(view_blocks)
        H_vis, W_vis, _ = visual_block.shape

        # --- 2. Render the text info (language & qpos) ---
        info_lines = [
            f"Sample Index: {i} {'(TRAJECTORY END)' if is_end else ''}",
            f"Language: {lang}",
            f"Current QPos (First 6): {np.round(qpos[:6], 3)}"
        ]
        info_block = render_text_block(W_vis, info_lines)

        # --- 3. Render the action chunk & mask ---
        # A dedicated block showing the values and validity of the 16 actions
        action_lines = ["Future Action Chunk (Pred Horizon=16):"]
        for t_act in range(len(actions)):
            # Check the mask
            is_valid = mask[t_act] > 0.5
            status = "[VALID]" if is_valid else "[PAD]  "
            vals = np.round(actions[t_act][:6], 3)  # show only the first 6 dims
            line = f" t+{t_act:02d} {status} {vals}"
            action_lines.append(line)
        # Per-line text colors would be fiddly, so keep a plain text block and
        # draw a color bar next to each line instead: red for PAD, green for VALID
        action_block = render_text_block(W_vis, action_lines)
        line_h = 30
        start_y = 50  # vertical offset of the first action line
        for t_act in range(len(actions)):
            is_valid = mask[t_act] > 0.5
            color = COLOR_VALID if is_valid else COLOR_PAD
            # Draw a small rectangle indicator
            cv2.rectangle(action_block, (0, start_y + t_act * line_h - 20),
                          (5, start_y + t_act * line_h - 5), color, -1)

        # --- 4. Compose the final image ---
        final_img = np.vstack([info_block, visual_block, action_block])
        save_path = os.path.join(output_dir, f"check_{i:04d}.png")
        cv2.imwrite(save_path, final_img)

    print(f"\n[SUCCESS] Visualization done. Pay special attention to the last "
          f"few images in {output_dir} (the mask bars should turn red).")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--data", type=str,
                        default="roboimi/demos/dataset/sim_transfer/episode_0.hdf5",
                        help="path to the data")
    parser.add_argument("--out", type=str, default="vla_debug_vis",
                        help="output directory")
    args = parser.parse_args()
    visualize_dataset(args.data, args.out)
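
As a quick complement to the rendered images, the padding behavior this script visualizes can also be asserted directly. A minimal sketch, assuming the same default episode path:

from roboimi.vla.data.dataset import VLAChunkedDataset

ds = VLAChunkedDataset("roboimi/demos/dataset/sim_transfer/episode_0.hdf5")
sample = ds[len(ds) - 1]  # last possible start index
# Only the first step is real data; the remaining pred_horizon - 1 steps
# must be padded copies of it with a zeroed mask.
assert sample["action_mask"][0] == 1.0
assert sample["action_mask"][1:].sum() == 0.0
assert (sample["actions"][1:] == sample["actions"][0]).all()
print("padding/mask behaves as expected")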


@@ -0,0 +1,89 @@
import h5py
import cv2
import numpy as np
import argparse
import os
from tqdm import tqdm


def visualize_episode(hdf5_path: str, output_path: str, fps: int = 30):
    """
    Convert a single episode_x.hdf5 into a visualization video with
    telemetry overlays.
    """
    if not os.path.exists(hdf5_path):
        print(f"Error: file not found: {hdf5_path}")
        return

    # If output_path is a directory, derive the file name automatically
    if os.path.isdir(output_path) or not output_path.endswith('.mp4'):
        os.makedirs(output_path, exist_ok=True)
        base_name = os.path.splitext(os.path.basename(hdf5_path))[0]
        output_path = os.path.join(output_path, f"{base_name}.mp4")
    else:
        # Make sure the output directory exists
        output_dir = os.path.dirname(output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

    with h5py.File(hdf5_path, 'r') as f:
        # Base data
        images_grp = f['observations/images']
        qpos = f['observations/qpos'][:]
        actions = f['action'][:]

        # Camera views
        views = list(images_grp.keys())  # ['angle', 'r_vis', 'top']
        num_steps = images_grp[views[0]].shape[0]

        # Video parameters: the views are concatenated horizontally
        # into an (H, W * num_views, 3) frame
        h, w, _ = images_grp[views[0]][0].shape
        out_w = w * len(views)
        out_h = h + 150  # reserve 150 px at the bottom for the telemetry text

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        video_writer = cv2.VideoWriter(output_path, fourcc, fps, (out_w, out_h))

        print(f"Processing {num_steps} frames...")
        for t in tqdm(range(num_steps)):
            # 1. Concatenate the camera views
            frame_views = []
            for view_name in views:
                img = images_grp[view_name][t]
                # HDF5 images are usually stored as RGB; OpenCV expects BGR
                img_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
                # Label the view name in the top-left corner
                cv2.putText(img_bgr, view_name, (20, 40),
                            cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
                frame_views.append(img_bgr)
            combined_img = np.hstack(frame_views)

            # 2. Create the bottom info bar
            info_bar = np.zeros((150, out_w, 3), dtype=np.uint8)

            # 3. Render the telemetry text (qpos and action); show the first
            # 7 dims as a representative slice (typically 6 arm DoF + gripper)
            qpos_str = "qpos (0-6): " + " ".join([f"{x:.2f}" for x in qpos[t][:7]])
            act_str = "action(0-6): " + " ".join([f"{x:.2f}" for x in actions[t][:7]])
            cv2.putText(info_bar, qpos_str, (20, 50),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
            cv2.putText(info_bar, act_str, (20, 100),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2)
            cv2.putText(info_bar, f"Step: {t}/{num_steps}", (out_w - 200, 75),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.7, (150, 150, 150), 2)

            # 4. Stack the camera frame on top of the info bar
            final_frame = np.vstack([combined_img, info_bar])
            video_writer.write(final_frame)

        video_writer.release()
    print(f"\n[SUCCESS] Visualization video saved to: {output_path}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Visualize a single episode HDF5 file")
    parser.add_argument("--input", type=str, required=True, help="input hdf5 path")
    parser.add_argument("--output", type=str, default="debug_episode.mp4",
                        help="output video path")
    args = parser.parse_args()
    visualize_episode(args.input, args.output)
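
This tool assumes the episode layout used throughout the commit (observations/images/<view>, observations/qpos, action). When a file renders wrong, a quick first check is to dump its tree and confirm it matches that layout. A small sketch, again assuming the example episode path:

import h5py

with h5py.File("roboimi/demos/dataset/sim_transfer/episode_0.hdf5", "r") as f:
    # Print every group/dataset name, with shapes for datasets
    f.visititems(lambda name, obj: print(name, getattr(obj, "shape", "")))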