135 lines
5.5 KiB
Python
135 lines
5.5 KiB
Python
import os
|
||
import cv2
|
||
import torch
|
||
import numpy as np
|
||
import argparse
|
||
from torch.utils.data import DataLoader
|
||
from roboimi.vla.data.dataset import VLAChunkedDataset
|
||
|
||
# OpenCV colour tuples, in BGR channel order.
COLOR_TEXT = (255, 255, 255)  # white: plain overlay text
# Action-row indicators: green marks a real (valid) action, red marks padding.
COLOR_VALID, COLOR_PAD = (0, 255, 0), (0, 0, 255)


def render_text_block(canvas_width, text_lines):
    """Render ``text_lines`` as white text on a black strip ``canvas_width`` px wide.

    Every line gets a 30 px row. The first baseline is at y=30, and the strip is
    20 px taller than ``30 * len(text_lines)``, leaving room below the last line.

    Args:
        canvas_width: Width in pixels of the returned image.
        text_lines: Sequence of strings, drawn top to bottom starting at x=10.

    Returns:
        A ``uint8`` array of shape ``(30 * len(text_lines) + 20, canvas_width, 3)``.
    """
    row_height = 30
    canvas = np.zeros((row_height * len(text_lines) + 20, canvas_width, 3), dtype=np.uint8)
    baseline_y = row_height  # first baseline at y=30, then one row further per line
    for text in text_lines:
        cv2.putText(canvas, text, (10, baseline_y),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, COLOR_TEXT, 1)
        baseline_y += row_height
    return canvas


def _select_sample_indices(num_items, num_edge):
    """Return the sorted, de-duplicated indices of the first and last ``num_edge`` items.

    Indices are clamped to ``[0, num_items)``. A dataset shorter than
    ``2 * num_edge`` therefore yields no negative, out-of-range or duplicate indices.
    """
    head = range(min(num_edge, num_items))
    tail = range(max(num_items - num_edge, 0), num_items)
    return sorted(set(head) | set(tail))


def _render_observation_views(obs, obs_keys):
    """Tile observation history: frames of one view side by side, views stacked vertically."""
    view_rows = []
    for key in obs_keys:
        # Batch size is 1: [1, T, C, H, W] -> [T, C, H, W].
        frames = obs[key][0]
        num_frames, _, _, _ = frames.shape  # unpacking also enforces a 4-D tensor
        tiles = []
        for t in range(num_frames):
            # [C, H, W] -> [H, W, C]. Assumes float RGB pixels in [0, 1]; clip so
            # values slightly outside that range don't wrap around on the uint8 cast.
            img = frames[t].permute(1, 2, 0).numpy()
            img = np.clip(img * 255, 0, 255).astype(np.uint8)
            img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
            # Label each frame with its offset from the newest frame (t - 1, t - 0).
            label = f"{key} (t - {num_frames - 1 - t})"
            cv2.putText(img, label, (10, 30),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2)
            tiles.append(img)
        # NOTE(review): np.hstack/np.vstack require every view to share H and the
        # same frame count; confirm if views can differ in resolution.
        view_rows.append(np.hstack(tiles))
    return np.vstack(view_rows)


def _render_action_block(actions, mask, width):
    """Render the action chunk as text rows, each with a green/red validity marker."""
    num_actions = len(actions)
    valid = [bool(mask[t] > 0.5) for t in range(num_actions)]

    lines = [f"Future Action Chunk (Pred Horizon={num_actions}):"]
    for t, is_valid in enumerate(valid):
        status = "[VALID]" if is_valid else "[PAD] "
        vals = np.round(actions[t][:6], 3)  # only the first 6 dims are shown
        lines.append(f" t+{t:02d} {status} {vals}")
    block = render_text_block(width, lines)

    # render_text_block draws line i with its baseline at y = 30 + 30 * i, using
    # FONT_HERSHEY_SIMPLEX at scale 0.6. Action t is on line t + 1 because the
    # header takes line 0, so the marker spans the glyph height above that baseline.
    first_baseline, line_h = 30, 30
    (_, text_h), _ = cv2.getTextSize("t+00", cv2.FONT_HERSHEY_SIMPLEX, 0.6, 1)
    for t, is_valid in enumerate(valid):
        baseline = first_baseline + (t + 1) * line_h
        color = COLOR_VALID if is_valid else COLOR_PAD
        cv2.rectangle(block, (0, baseline - text_h), (5, baseline), color, -1)
    return block


def visualize_dataset(data_path: str, output_dir: str, *, num_edge_samples: int = 5):
    """Write annotated sanity-check images for the first and last samples of a dataset.

    For every selected sample, one PNG ``check_XXXX.png`` (XXXX = dataset index) is
    written to ``output_dir``. It stacks, top to bottom:

    - a text header with the index, the language string and the first 6 qpos values;
    - the observation frames of every camera view;
    - the action chunk, with a green (valid) or red (padding) marker per row.

    This lets the action-mask behaviour near the end of the data be checked by eye.

    Args:
        data_path: Path passed to ``VLAChunkedDataset``.
        output_dir: Directory for the PNGs; created if missing.
        num_edge_samples: How many samples to take from each end of the dataset.
    """
    os.makedirs(output_dir, exist_ok=True)

    # 1. Build the dataset with the current training configuration.
    dataset = VLAChunkedDataset(
        data_path=data_path,
        pred_horizon=16,            # predict 16 future steps
        obs_horizon=2,              # observe the 2 most recent frames
        obs_keys=["top", "angle"],  # the two camera views
    )
    num_items = len(dataset)

    # Only the head and tail samples are inspected; the tail exercises the mask logic.
    indices = _select_sample_indices(num_items, num_edge_samples)

    # Go through a DataLoader so samples are collated exactly as in training, but
    # fetch only the selected indices instead of decoding the whole dataset.
    dataloader = DataLoader(dataset, batch_size=1, sampler=indices)

    print(f"[VISUALIZE] 开始生成样本检查图至: {output_dir}")
    print(f" - 数据总长: {len(dataset)}")

    for idx, batch in zip(indices, dataloader):
        is_end = idx >= num_items - num_edge_samples

        # --- Unpack batch (batch size 1, take element 0) ---
        qpos = batch['qpos'][0].numpy()           # [State_Dim]
        actions = batch['actions'][0].numpy()     # [Pred_Horizon, Action_Dim]
        mask = batch['action_mask'][0].numpy()    # [Pred_Horizon]
        lang = batch['language'][0]               # string

        # --- 1. Observation images ---
        visual_block = _render_observation_views(batch['obs'], dataset.obs_keys)
        width = visual_block.shape[1]

        # --- 2. Text info (language & qpos) ---
        info_lines = [
            f"Sample Index: {idx} {'(TRAJECTORY END)' if is_end else ''}",
            f"Language: {lang}",
            f"Current QPos (First 6): {np.round(qpos[:6], 3)}",
        ]
        info_block = render_text_block(width, info_lines)

        # --- 3. Action chunk & mask ---
        action_block = _render_action_block(actions, mask, width)

        # --- 4. Compose and save ---
        final_img = np.vstack([info_block, visual_block, action_block])
        save_path = os.path.join(output_dir, f"check_{idx:04d}.png")
        # cv2.imwrite reports failure via its return value rather than raising.
        if not cv2.imwrite(save_path, final_img):
            print(f"[WARN] failed to write {save_path}")

    print(f"\n[SUCCESS] 可视化完成。请重点检查 {output_dir} 中的最后几张图 (Mask 是否变红)。")


if __name__ == "__main__":
    # Command-line entry point: render check images for one dataset file.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--data",
        type=str,
        default="roboimi/demos/dataset/sim_transfer/episode_0.hdf5",
        help="数据路径",
    )
    cli.add_argument("--out", type=str, default="vla_debug_vis", help="输出目录")
    cli_args = cli.parse_args()

    visualize_dataset(data_path=cli_args.data, output_dir=cli_args.out)