#!/usr/bin/env python3
"""
Inspect specific frames of recorded episodes to verify data-recording issues.

Features:
1. Extract frames 0, 1 and 2 of every episode.
2. Compare the same frame number across different episodes.
3. Save the images so they can be inspected manually.
"""

import glob
import os
import sys

import cv2
import h5py
import numpy as np

DEFAULT_DATASET_DIR = "roboimi/demos/dataset/sim_transfer"
DEFAULT_OUTPUT_DIR = "/tmp/dataset_frames"

# Pixel layout of the comparison sheet: every grid row gets its own header
# strip for the episode label, plus a small margin below the last row.
_HEADER_HEIGHT = 30
_BOTTOM_MARGIN = 20


def _episode_number(path):
    """Return the integer N of ``.../episode_N.hdf5``, or None if N is not an integer."""
    stem = os.path.splitext(os.path.basename(path))[0]
    try:
        return int(stem.rsplit('_', 1)[-1])
    except ValueError:
        return None


def _list_episode_files(dataset_dir):
    """Return the ``episode_*.hdf5`` files in *dataset_dir*, sorted by episode number.

    Files whose suffix is not an integer (e.g. ``episode_backup.hdf5``) are
    left out instead of aborting the whole run.
    """
    numbered = []
    for path in glob.glob(os.path.join(dataset_dir, "episode_*.hdf5")):
        number = _episode_number(path)
        if number is not None:
            numbered.append((number, path))
    # Sorting the (number, path) pairs gives numeric order with a
    # deterministic lexicographic tie-break.
    return [path for _, path in sorted(numbered)]


def check_specific_frames(frame_indices=(0, 1, 2), camera='top', num_episodes=10,
                          dataset_dir=DEFAULT_DATASET_DIR,
                          output_dir=DEFAULT_OUTPUT_DIR):
    """
    Save selected frames of the first episodes as PNG files and print their qpos.

    Args:
        frame_indices: Iterable of frame indices to inspect.
        camera: Camera name, i.e. the dataset under ``/observations/images/``.
        num_episodes: Maximum number of episodes to inspect.
        dataset_dir: Directory containing the ``episode_*.hdf5`` files.
        output_dir: Directory the PNG files are written to (created if missing).

    Per-episode errors are reported on stdout and do not stop the run.
    """
    # Materialise once: the indices are printed, looped per episode and
    # reused for the viewing hints at the end.
    frame_indices = list(frame_indices)
    episode_files = _list_episode_files(dataset_dir)
    num_check = min(num_episodes, len(episode_files))

    os.makedirs(output_dir, exist_ok=True)

    print(f"检查前 {num_check} 个 episode 的特定帧")
    print(f"帧索引: {frame_indices}")
    print(f"相机: {camera}")
    print("="*80)

    for ep_idx, ep_file in enumerate(episode_files[:max(num_check, 0)]):
        ep_name = os.path.basename(ep_file).replace('.hdf5', '')

        try:
            with h5py.File(ep_file, 'r') as f:
                # Keep dataset handles and index single frames below, so
                # only the requested frames are read from disk.
                qpos = f['/observations/qpos']

                img_path = f'/observations/images/{camera}'
                if img_path not in f:
                    print(f"Episode {ep_name}: 相机 {camera} 不存在")
                    continue

                images = f[img_path]
                num_frames = images.shape[0]

                print(f"\nEpisode {ep_name}:")
                print(f" 总帧数: {num_frames}")

                for frame_idx in frame_indices:
                    if frame_idx >= num_frames:
                        print(f" 帧 {frame_idx}: 超出范围")
                        continue

                    # NOTE(review): cv2.imwrite treats 3-channel arrays as BGR;
                    # if the dataset stores RGB the saved colours are swapped
                    # - confirm against the recorder.
                    filename = f"{output_dir}/ep{ep_idx:02d}_frame{frame_idx:03d}.png"
                    if not cv2.imwrite(filename, images[frame_idx]):
                        print(f" 帧 {frame_idx}: 保存失败 → {filename}")
                        continue

                    q = qpos[frame_idx]
                    print(f" 帧 {frame_idx}: qpos[0:3]=[{q[0]:6.2f}, {q[1]:6.2f}, {q[2]:6.2f}], qpos[3]={q[3]:6.2f} → {filename}")

        except Exception as e:
            # Deliberate best-effort: one unreadable episode must not abort
            # the inspection of the remaining ones.
            print(f"Episode {ep_name}: 错误 - {e}")

    print("\n" + "="*80)
    print(f"✅ 所有图像已保存到: {output_dir}")
    print("\n查看方法:")
    print(f" eog {output_dir}/*.png")
    print(" ")
    print(" # 或对比特定帧:")
    for frame_idx in frame_indices:
        print(f" eog {output_dir}/*_frame{frame_idx:03d}.png # 所有 episode 的第 {frame_idx} 帧")


def compare_frame_across_episodes(frame_idx=0, camera='top', num_episodes=10,
                                  dataset_dir=DEFAULT_DATASET_DIR,
                                  output_dir=DEFAULT_OUTPUT_DIR, cols=5):
    """
    Build one contact sheet showing frame *frame_idx* of every episode.

    Args:
        frame_idx: Frame index to extract from each episode.
        camera: Camera name, i.e. the dataset under ``/observations/images/``.
        num_episodes: Maximum number of episodes to include.
        dataset_dir: Directory containing the ``episode_*.hdf5`` files.
        output_dir: Directory the sheet is written to (created if missing).
        cols: Number of tiles per grid row.

    Raises:
        ValueError: If *cols* is smaller than 1.

    Episodes that lack the camera or the frame are skipped silently; each
    tile is labelled with the episode's position in the sorted file list,
    matching the ``epNN`` prefix used by ``check_specific_frames``.
    """
    if cols < 1:
        raise ValueError(f"cols must be >= 1, got {cols}")

    episode_files = _list_episode_files(dataset_dir)
    num_compare = min(num_episodes, len(episode_files))

    os.makedirs(output_dir, exist_ok=True)

    print(f"生成对比图: 所有 Episode 的第 {frame_idx} 帧")
    print("="*80)

    # Each entry is (episode position, image, qpos row). Keeping the three
    # together guarantees a tile can never be paired with another episode's
    # label or qpos when an episode is skipped or fails half-way.
    entries = []

    for ep_idx, ep_file in enumerate(episode_files[:max(num_compare, 0)]):
        ep_name = os.path.basename(ep_file).replace('.hdf5', '')

        try:
            with h5py.File(ep_file, 'r') as f:
                qpos = f['/observations/qpos']
                img_path = f'/observations/images/{camera}'

                if img_path in f and frame_idx < f[img_path].shape[0]:
                    img = f[img_path][frame_idx]
                    q = qpos[frame_idx]
                    entries.append((ep_idx, img, q))
                    print(f"Episode {ep_name}: qpos[0:3]=[{q[0]:.2f}, {q[1]:.2f}, {q[2]:.2f}]")

        except Exception as e:
            # Best-effort: report and continue with the next episode.
            print(f"Episode {ep_name}: 错误 - {e}")

    if not entries:
        print("❌ 没有收集到图像")
        return

    # The first collected image defines the tile size; others are resized.
    h, w = entries[0][1].shape[:2]
    cell_h = h + _HEADER_HEIGHT
    # Size the grid from what was actually collected, not what was requested.
    rows = (len(entries) + cols - 1) // cols

    compare_img = np.zeros((rows * cell_h + _BOTTOM_MARGIN, cols * w, 3), dtype=np.uint8)

    for slot, (ep_idx, img, q) in enumerate(entries):
        row, col = divmod(slot, cols)

        y_start = row * cell_h + _HEADER_HEIGHT
        y_end = y_start + h
        x_start = col * w
        x_end = x_start + w

        if img.shape[:2] != (h, w):
            img = cv2.resize(img, (w, h))
        if img.ndim == 2:
            # Single-channel frame: replicate it across the 3 canvas channels.
            img = np.repeat(img[:, :, np.newaxis], 3, axis=2)

        compare_img[y_start:y_end, x_start:x_end] = img

        # Episode label goes into this row's own header strip, so it cannot
        # collide with the qpos text at the bottom of the tile above.
        cv2.putText(compare_img, f"Ep {ep_idx}", (x_start + 10, y_start - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 255), 2)
        cv2.putText(compare_img, f"qpos[3]={q[3]:.2f}", (x_start + 10, y_end - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1)

    output_path = f"{output_dir}/compare_frame{frame_idx:03d}.png"
    if not cv2.imwrite(output_path, compare_img):
        print(f"❌ 对比图保存失败: {output_path}")
        return

    print(f"\n✅ 对比图已保存: {output_path}")
    print(f" 查看方法: eog {output_path}")


def main(argv=None):
    """Run the tool from the command line and return a process exit code.

    With one argument (a frame index) only the comparison sheet for that
    frame is produced; without arguments frames 0-2 are extracted and a
    comparison sheet is produced for each of them.
    """
    argv = sys.argv if argv is None else argv

    print("="*80)
    print("特定帧检查工具")
    print("="*80)

    if len(argv) > 1:
        try:
            frame_idx = int(argv[1])
        except ValueError:
            print(f"❌ 无效的帧索引: {argv[1]!r}")
            return 1
        compare_frame_across_episodes(frame_idx=frame_idx, camera='top', num_episodes=10)
    else:
        # Default: inspect frames 0, 1 and 2.
        check_specific_frames(frame_indices=[0, 1, 2], camera='top', num_episodes=10)

        print("\n" + "="*80)
        print("生成对比图...")
        print("="*80)

        # One comparison sheet per inspected frame.
        compare_frame_across_episodes(frame_idx=0, camera='top', num_episodes=10)
        compare_frame_across_episodes(frame_idx=1, camera='top', num_episodes=10)
        compare_frame_across_episodes(frame_idx=2, camera='top', num_episodes=10)

        print("\n" + "="*80)
        print("其他用法:")
        print(" python check_specific_frames.py 0 # 只检查第 0 帧")
        print(" python check_specific_frames.py 1 # 只检查第 1 帧")
        print("="*80)

    return 0


if __name__ == "__main__":
    sys.exit(main())