Files
roboimi/roboimi/vla/scripts/visualize_data.py

135 lines
5.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import os
import cv2
import torch
import numpy as np
import argparse
from torch.utils.data import DataLoader
from roboimi.vla.data.dataset import VLAChunkedDataset
# Color constants in BGR order (OpenCV drawing convention)
COLOR_TEXT = (255, 255, 255)
COLOR_VALID = (0, 255, 0) # valid action steps are drawn in green
COLOR_PAD = (0, 0, 255) # padding action steps are drawn in red
def render_text_block(canvas_width, text_lines):
    """Render a list of text lines onto a black image strip.

    Args:
        canvas_width: width of the produced block in pixels.
        text_lines: strings to draw, one per row, top to bottom.

    Returns:
        A ``(30 * len(text_lines) + 20, canvas_width, 3)`` uint8 BGR image.
    """
    row_height = 30
    height = len(text_lines) * row_height + 20
    canvas = np.zeros((height, canvas_width, 3), dtype=np.uint8)
    baseline = 30
    for line in text_lines:
        cv2.putText(canvas, line, (10, baseline),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, COLOR_TEXT, 1)
        baseline += row_height
    return canvas
def visualize_dataset(data_path: str, output_dir: str):
    """Render sanity-check images for a ``VLAChunkedDataset``.

    For the first 5 and last 5 samples, a composite PNG is saved that shows
    the observation frames, the language instruction, the current qpos, and
    the future action chunk with its validity mask (padding steps flagged
    with a red bar, valid steps with a green bar).

    Args:
        data_path: path to the HDF5 episode / dataset file.
        output_dir: directory where ``check_XXXX.png`` images are written
            (created if missing).
    """
    os.makedirs(output_dir, exist_ok=True)
    # 1. Instantiate the dataset with the same settings used for training.
    dataset = VLAChunkedDataset(
        data_path=data_path,
        pred_horizon=16,           # predict 16 future action steps
        obs_horizon=2,             # observe the 2 most recent frames
        obs_keys=["top", "angle"]  # the two camera views
    )
    # A batch_size=1 DataLoader mimics the read pattern used at train time.
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False)
    print(f"[VISUALIZE] 开始生成样本检查图至: {output_dir}")
    print(f" - 数据总长: {len(dataset)}")
    for i, batch in enumerate(dataloader):
        # Inspect only the first 5 and the last 5 samples; the tail is where
        # the action-padding mask must become active.
        is_start = i < 5
        is_end = i > (len(dataset) - 6)
        if not (is_start or is_end):
            continue
        # --- Unpack (batch size is 1, so take index 0) ---
        obs = batch['obs']                      # dict: view name -> image tensor
        qpos = batch['qpos'][0].numpy()         # [state_dim]
        actions = batch['actions'][0].numpy()   # [pred_horizon, action_dim]
        mask = batch['action_mask'][0].numpy()  # [pred_horizon]
        lang = batch['language'][0]             # instruction string
        # --- 1. Render observation images ---
        # History frames of one view are concatenated horizontally; the
        # different views are then stacked vertically.
        view_blocks = []
        for key in dataset.obs_keys:
            # tensor: [1, T, C, H, W] -> [T, C, H, W]
            imgs_tensor = obs[key][0]
            T = imgs_tensor.shape[0]
            frame_list = []
            for t in range(T):
                # [C, H, W] float tensor (presumably scaled to [0, 1] —
                # TODO confirm against the dataset) -> uint8 HWC BGR.
                img_np = imgs_tensor[t].permute(1, 2, 0).numpy()
                img_np = (img_np * 255).astype(np.uint8)
                img_bgr = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
                # Label each frame with its relative time step (t-1, t-0, ...).
                label = f"{key} (t - {T-1-t})"
                cv2.putText(img_bgr, label, (10, 30),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2)
                frame_list.append(img_bgr)
            view_blocks.append(np.hstack(frame_list))
        visual_block = np.vstack(view_blocks)
        W_vis = visual_block.shape[1]
        # --- 2. Render text info (language instruction & qpos) ---
        info_lines = [
            f"Sample Index: {i} {'(TRAJECTORY END)' if is_end else ''}",
            f"Language: {lang}",
            f"Current QPos (First 6): {np.round(qpos[:6], 3)}"
        ]
        info_block = render_text_block(W_vis, info_lines)
        # --- 3. Render the action chunk values with their validity mask ---
        action_lines = ["Future Action Chunk (Pred Horizon=16):"]
        for t_act in range(len(actions)):
            is_valid = mask[t_act] > 0.5
            status = "[VALID]" if is_valid else "[PAD] "
            vals = np.round(actions[t_act][:6], 3)  # show first 6 dims only
            line = f" t+{t_act:02d} {status} {vals}"
            action_lines.append(line)
        action_block = render_text_block(W_vis, action_lines)
        # Draw a small colored indicator bar beside each action row:
        # green = valid step, red = padding.
        line_h = 30
        start_y = 50  # vertical offset of the first action row
        for t_act in range(len(actions)):
            is_valid = mask[t_act] > 0.5
            color = COLOR_VALID if is_valid else COLOR_PAD
            cv2.rectangle(action_block,
                          (0, start_y + t_act * line_h - 20),
                          (5, start_y + t_act * line_h - 5),
                          color, -1)
        # --- 4. Compose the final image and save ---
        final_img = np.vstack([info_block, visual_block, action_block])
        save_path = os.path.join(output_dir, f"check_{i:04d}.png")
        cv2.imwrite(save_path, final_img)
    print(f"\n[SUCCESS] 可视化完成。请重点检查 {output_dir} 中的最后几张图 (Mask 是否变红)。")
if __name__ == "__main__":
    # CLI entry point: choose the episode file and the output folder.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--data", type=str,
        default="roboimi/demos/dataset/sim_transfer/episode_0.hdf5",
        help="数据路径")
    cli.add_argument(
        "--out", type=str,
        default="vla_debug_vis",
        help="输出目录")
    opts = cli.parse_args()
    visualize_dataset(opts.data, opts.out)