chore: 添加测试文件

- check_all_episodes.py:检查各个episode是否有重复帧。
- check_specific_frames.py:检查前几帧是否位于正确初始位置。
- generate_dataset_videos.py:根据 HDF5 文件生成视频。
This commit is contained in:
gouhanke
2026-02-26 13:59:47 +08:00
parent 3deeffb9fe
commit 40c40695dd
3 changed files with 617 additions and 0 deletions

202
check_specific_frames.py Normal file
View File

@@ -0,0 +1,202 @@
#!/usr/bin/env python3
"""
检查特定帧的图像 - 用于验证数据记录问题
功能:
1. 提取每个 episode 的第 0、1、2 帧图像
2. 对比不同 episode 的相同帧号
3. 保存图像供人工检查
"""
import os
import h5py
import glob
import cv2
import numpy as np
def check_specific_frames(frame_indices=None, camera='top', num_episodes=10,
                          dataset_dir="roboimi/demos/dataset/sim_transfer",
                          output_dir='/tmp/dataset_frames'):
    """Save selected frames of the first episodes as PNGs and print their qpos.

    For each of the first ``num_episodes`` files matching
    ``<dataset_dir>/episode_<N>.hdf5`` (ordered by ``N``), the frames listed
    in ``frame_indices`` are read from ``/observations/images/<camera>`` and
    written to ``<output_dir>/ep<idx>_frame<frame>.png``; the first four
    components of the matching ``/observations/qpos`` row are printed.

    Args:
        frame_indices: Frame indices to export. ``None`` means ``[0, 1, 2]``.
        camera: Name of the image dataset under ``/observations/images``.
        num_episodes: Maximum number of episodes to inspect.
        dataset_dir: Directory that holds the ``episode_*.hdf5`` files.
        output_dir: Directory that receives the PNG files (created on demand).

    Returns:
        None. Results are reported on stdout and as files in ``output_dir``.
    """
    # Avoid a shared mutable default argument.
    if frame_indices is None:
        frame_indices = [0, 1, 2]

    def episode_number(path):
        # "episode_12.hdf5" -> 12; None when the part after "episode_" is not
        # a plain number (e.g. "episode_0_backup.hdf5"), so such files are
        # skipped instead of crashing the sort.
        stem = os.path.basename(path)[len("episode_"):-len(".hdf5")]
        return int(stem) if stem.isdecimal() else None

    candidates = glob.glob(os.path.join(dataset_dir, "episode_*.hdf5"))
    episode_files = sorted(
        (p for p in candidates if episode_number(p) is not None),
        key=lambda p: (episode_number(p), p),
    )
    selected = episode_files[:max(num_episodes, 0)]

    if not selected:
        print(f"❌ 在 {dataset_dir} 中没有找到 episode 文件")
        return

    os.makedirs(output_dir, exist_ok=True)
    print(f"检查前 {len(selected)} 个 episode 的特定帧")
    print(f"帧索引: {frame_indices}")
    print(f"相机: {camera}")
    print("="*80)

    saved_frames = set()  # frame indices written at least once
    for ep_idx, ep_file in enumerate(selected):
        ep_name = os.path.basename(ep_file).replace('.hdf5', '')
        try:
            with h5py.File(ep_file, 'r') as f:
                # Keep dataset handles and index them per frame so only the
                # requested frames are read, not the whole episode.
                qpos = f['/observations/qpos']
                img_path = f'/observations/images/{camera}'
                if img_path not in f:
                    print(f"Episode {ep_name}: 相机 {camera} 不存在")
                    continue
                images = f[img_path]
                n_frames = images.shape[0]
                print(f"\nEpisode {ep_name}:")
                print(f" 总帧数: {n_frames}")
                for frame_idx in frame_indices:
                    if frame_idx >= n_frames:
                        print(f"{frame_idx}: 超出范围")
                        continue
                    img = images[frame_idx]
                    q = qpos[frame_idx]
                    filename = f"{output_dir}/ep{ep_idx:02d}_frame{frame_idx:03d}.png"
                    # NOTE(review): cv2.imwrite interprets 3-channel data as
                    # BGR; if the dataset stores RGB the saved colours are
                    # swapped — confirm against the recorder.
                    if not cv2.imwrite(filename, img):
                        print(f"{frame_idx}: 保存失败 -> {filename}")
                        continue
                    saved_frames.add(frame_idx)
                    print(f"{frame_idx}: qpos[0:3]=[{q[0]:6.2f}, {q[1]:6.2f}, {q[2]:6.2f}], qpos[3]={q[3]:6.2f} -> {filename}")
        except Exception as e:
            # Best-effort diagnostic tool: report the failing episode and
            # carry on with the remaining ones.
            print(f"Episode {ep_name}: 错误 - {e}")

    print("\n" + "="*80)
    if not saved_frames:
        print("❌ 没有保存任何图像")
        return
    print(f"✅ 所有图像已保存到: {output_dir}")
    print(f"\n查看方法:")
    print(f" eog {output_dir}/*.png")
    print(f" ")
    print(f" # 或对比特定帧:")
    # One hint per frame index that was actually written.
    for idx in sorted(saved_frames):
        print(f" eog {output_dir}/*_frame{idx:03d}.png # 所有 episode 的第 {idx} 帧")
def compare_frame_across_episodes(frame_idx=0, camera='top', num_episodes=10,
                                  dataset_dir="roboimi/demos/dataset/sim_transfer",
                                  output_dir='/tmp/dataset_frames'):
    """Tile one frame from several episodes into a single comparison image.

    Reads frame ``frame_idx`` of ``/observations/images/<camera>`` and the
    matching ``/observations/qpos`` row from the first ``num_episodes`` files
    matching ``<dataset_dir>/episode_<N>.hdf5`` (ordered by ``N``), lays the
    images out on a 5-column grid and writes the result to
    ``<output_dir>/compare_frame<frame_idx>.png``. Every cell is annotated
    with the episode's position in the sorted file list and with ``qpos[3]``.

    Args:
        frame_idx: Index of the frame to compare.
        camera: Name of the image dataset under ``/observations/images``.
        num_episodes: Maximum number of episodes to include.
        dataset_dir: Directory that holds the ``episode_*.hdf5`` files.
        output_dir: Directory that receives the PNG file (created on demand).

    Returns:
        None. Progress is printed; nothing is written when no image could be
        collected.
    """
    cols = 5          # cells per grid row
    header_h = 30     # strip above every row that holds the "Ep N" label
    bottom_margin = 20

    def episode_number(path):
        # "episode_12.hdf5" -> 12; None when the part after "episode_" is not
        # a plain number, so such files are skipped instead of crashing.
        stem = os.path.basename(path)[len("episode_"):-len(".hdf5")]
        return int(stem) if stem.isdecimal() else None

    candidates = glob.glob(os.path.join(dataset_dir, "episode_*.hdf5"))
    episode_files = sorted(
        (p for p in candidates if episode_number(p) is not None),
        key=lambda p: (episode_number(p), p),
    )
    selected = episode_files[:max(num_episodes, 0)]

    print(f"生成对比图: 所有 Episode 的第 {frame_idx} 帧")
    print("="*80)

    # Each entry is (episode index, image, qpos row). Storing them together
    # keeps image, qpos and label aligned even when an episode is skipped.
    cells = []
    for ep_idx, ep_file in enumerate(selected):
        ep_name = os.path.basename(ep_file).replace('.hdf5', '')
        try:
            with h5py.File(ep_file, 'r') as f:
                qpos = f['/observations/qpos']
                img_path = f'/observations/images/{camera}'
                if img_path not in f or frame_idx >= f[img_path].shape[0]:
                    print(f"Episode {ep_name}: 相机 {camera} 不存在或帧 {frame_idx} 超出范围")
                    continue
                # Index the datasets directly: only one frame / one row is read.
                img = f[img_path][frame_idx]
                q = qpos[frame_idx]
            print(f"Episode {ep_name}: qpos[0:3]=[{q[0]:.2f}, {q[1]:.2f}, {q[2]:.2f}]")
            # Append last, so a failure above never leaves a partial entry.
            cells.append((ep_idx, img, q))
        except Exception as e:
            # Best-effort diagnostic tool: report and continue.
            print(f"Episode {ep_name}: 错误 - {e}")

    if not cells:
        print("❌ 没有收集到图像")
        return

    # The first collected image defines the cell size; others are resized.
    h, w = cells[0][1].shape[:2]
    rows = (len(cells) + cols - 1) // cols
    # Every row gets its own header strip; otherwise the label of row r would
    # be drawn on top of the qpos text at the bottom of row r-1.
    row_stride = h + header_h
    compare_img = np.zeros((rows * row_stride + bottom_margin, cols * w, 3), dtype=np.uint8)

    for slot, (ep_idx, img, q) in enumerate(cells):
        row, col = divmod(slot, cols)
        y_start = row * row_stride + header_h
        y_end = y_start + h
        x_start = col * w
        x_end = x_start + w
        if img.shape[:2] != (h, w):
            img = cv2.resize(img, (w, h))
        # NOTE(review): OpenCV treats 3-channel data as BGR; if the dataset
        # stores RGB the saved colours are swapped — confirm.
        compare_img[y_start:y_end, x_start:x_end] = img
        cv2.putText(compare_img, f"Ep {ep_idx}", (x_start + 10, y_start - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 255), 2)
        cv2.putText(compare_img, f"qpos[3]={q[3]:.2f}", (x_start + 10, y_end - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1)

    os.makedirs(output_dir, exist_ok=True)
    output_path = f"{output_dir}/compare_frame{frame_idx:03d}.png"
    if not cv2.imwrite(output_path, compare_img):
        print(f"❌ 对比图保存失败: {output_path}")
        return
    print(f"\n✅ 对比图已保存: {output_path}")
    print(f" 查看方法: eog {output_path}")
if __name__ == "__main__":
    import sys

    # Command-line entry point.
    #   no argument  -> export frames 0-2, then build comparison images for them
    #   one argument -> build the comparison image for that frame index only
    separator = "="*80

    print(separator)
    print("特定帧检查工具")
    print(separator)

    cli_args = sys.argv[1:]
    if cli_args:
        # int() raises ValueError when the argument is not an integer.
        compare_frame_across_episodes(frame_idx=int(cli_args[0]), camera='top', num_episodes=10)
    else:
        check_specific_frames(frame_indices=[0, 1, 2], camera='top', num_episodes=10)

        print("\n" + separator)
        print("生成对比图...")
        print(separator)

        # One comparison image per default frame.
        for default_frame in (0, 1, 2):
            compare_frame_across_episodes(frame_idx=default_frame, camera='top', num_episodes=10)

    # NOTE(review): the original indentation was lost; the usage hints are
    # assumed to be printed after both branches — confirm.
    print("\n" + separator)
    print("其他用法:")
    print(" python check_specific_frames.py 0 # 只检查第 0 帧")
    print(" python check_specific_frames.py 1 # 只检查第 1 帧")
    print(separator)