Files
roboimi/check_all_episodes.py
gouhanke 40c40695dd chore: 添加测试文件
- check_all_episodes.py:检查各个episode是否有重复帧。
- check_specific_frames.py:检查前几帧是否位于正确初始位置。
- generate_dataset_videos.py:根据hdf5生成视频
2026-02-26 13:59:47 +08:00

92 lines
3.0 KiB
Python

#!/usr/bin/env python3
"""
检查所有 episode 的重复帧情况
找出哪些 episode 有问题,需要删除或重新收集
"""
import os
import h5py
import glob
import numpy as np
def check_all_episodes():
    """Check every episode HDF5 file for duplicated frames.

    Scans ``dataset_dir`` for ``episode_*.hdf5`` files, compares consecutive
    frames of the ``/observations/images/top`` camera stream within the first
    50 frames, and classifies each episode as good or bad (bad = more than
    10% of consecutive-frame comparisons are near-identical). Prints a
    per-episode report, a summary, and a ready-to-use ``rm`` command for the
    bad episodes.

    Returns:
        None. All output goes to stdout.
    """
    dataset_dir = "roboimi/demos/dataset/sim_transfer"
    episode_files = glob.glob(os.path.join(dataset_dir, "episode_*.hdf5"))
    # Sort numerically by the episode index embedded in the filename
    # (a single numeric-key sort; the extra lexicographic pre-sort was redundant).
    episode_files.sort(key=lambda p: int(p.split('_')[-1].replace('.hdf5', '')))

    print("=" * 80)
    print("所有 Episode 质量检查")
    print("=" * 80)

    good_episodes = []
    bad_episodes = []

    for ep_idx, ep_file in enumerate(episode_files):
        ep_name = os.path.basename(ep_file).replace('.hdf5', '')
        try:
            with h5py.File(ep_file, 'r') as f:
                img_path = '/observations/images/top'
                if img_path not in f:
                    continue
                # Only the first 50 frames are compared — slice at the h5py
                # level so the rest of the (large) dataset is never loaded.
                images = f[img_path][:50]
                check_frames = len(images)
                if check_frames < 2:
                    # Fewer than 2 frames: no comparison possible (and the
                    # rate formula would divide by zero). Treat as clean.
                    good_episodes.append((ep_idx, ep_name, 0.0, 0))
                    print(f"✅ Episode {ep_idx:2d}:   0.0% 重复 ( 0/{check_frames}) - {ep_name}")
                    continue
                duplicate_count = 0
                for i in range(check_frames - 1):
                    # Mean absolute pixel difference between consecutive frames.
                    diff = np.mean(np.abs(images[i].astype(float) - images[i + 1].astype(float)))
                    if diff < 1.0:  # essentially identical -> duplicated frame
                        duplicate_count += 1
                # Rate over the number of comparisons actually made
                # (check_frames - 1), not the frame count — off-by-one fix.
                duplicate_rate = duplicate_count / (check_frames - 1) * 100
                if duplicate_rate > 10:  # more than 10% duplicated -> bad
                    bad_episodes.append((ep_idx, ep_name, duplicate_rate, duplicate_count))
                    status = "❌"  # restore the marker (was an empty string)
                else:
                    good_episodes.append((ep_idx, ep_name, duplicate_rate, duplicate_count))
                    status = "✅"
                print(f"{status} Episode {ep_idx:2d}: {duplicate_rate:5.1f}% 重复 ({duplicate_count:2d}/{check_frames}) - {ep_name}")
        except Exception as e:
            # Best-effort scan: report unreadable/corrupt files and keep going.
            print(f"❌ Episode {ep_idx}: 错误 - {e}")

    # Summary
    print("\n" + "=" * 80)
    print("总结")
    print("=" * 80)
    print(f"总共检查: {len(episode_files)} 个 episodes")
    print(f"正常的: {len(good_episodes)} 个 ✅")
    print(f"有问题的: {len(bad_episodes)} 个 ❌")

    if bad_episodes:
        print(f"\n有问题的 episodes:")
        for ep_idx, ep_name, rate, count in bad_episodes:
            # Use the real filename (ep_name), not the enumeration index:
            # they differ when episode numbers are non-contiguous, and the
            # rm command below is built from ep_name.
            print(f"  - {ep_name}.hdf5: {rate:.1f}% 重复")
        print(f"\n删除命令:")
        ep_names = [name for _, name, _, _ in bad_episodes]
        print(f"  rm " + " ".join([f"{dataset_dir}/{name}.hdf5" for name in ep_names]))

    print(f"\n建议:")
    if len(bad_episodes) > 0:
        print(f"  1. 删除有问题的 {len(bad_episodes)} 个 episodes")
        print(f"  2. 重新收集数据,或使用剩余的 {len(good_episodes)} 个正常 episodes")
    else:
        print(f"  ✅ 所有 episodes 都正常,可以直接使用!")
# Script entry point: run the full episode-quality scan when invoked directly.
if __name__ == "__main__":
    check_all_episodes()