- check_all_episodes.py:检查各个episode是否有重复帧。 - check_specific_frames.py:检查前几帧是否位于正确初始位置。 - generate_dataset_videos.py:根据hdf5生成视频
92 lines · 3.0 KiB · Python
#!/usr/bin/env python3
|
|
"""
|
|
检查所有 episode 的重复帧情况
|
|
|
|
找出哪些 episode 有问题,需要删除或重新收集
|
|
"""
|
|
import os
|
|
import h5py
|
|
import glob
|
|
import numpy as np
|
|
|
|
|
|
def check_all_episodes():
|
|
"""检查所有 episode 的质量"""
|
|
|
|
dataset_dir = "roboimi/demos/dataset/sim_transfer"
|
|
episode_files = sorted(glob.glob(os.path.join(dataset_dir, "episode_*.hdf5")))
|
|
episode_files = sorted(episode_files, key=lambda x: int(x.split('_')[-1].replace('.hdf5', '')))
|
|
|
|
print("="*80)
|
|
print("所有 Episode 质量检查")
|
|
print("="*80)
|
|
|
|
good_episodes = []
|
|
bad_episodes = []
|
|
|
|
for ep_idx, ep_file in enumerate(episode_files):
|
|
ep_name = os.path.basename(ep_file).replace('.hdf5', '')
|
|
|
|
try:
|
|
with h5py.File(ep_file, 'r') as f:
|
|
img_path = '/observations/images/top'
|
|
if img_path not in f:
|
|
continue
|
|
|
|
images = f[img_path][:]
|
|
|
|
# 检查前 50 帧的重复情况
|
|
check_frames = min(50, len(images))
|
|
duplicate_count = 0
|
|
|
|
for i in range(check_frames - 1):
|
|
img1 = images[i]
|
|
img2 = images[i + 1]
|
|
diff = np.mean(np.abs(img1.astype(float) - img2.astype(float)))
|
|
|
|
if diff < 1.0: # 重复
|
|
duplicate_count += 1
|
|
|
|
duplicate_rate = duplicate_count / check_frames * 100
|
|
|
|
# 判断质量
|
|
if duplicate_rate > 10: # 超过10%重复
|
|
bad_episodes.append((ep_idx, ep_name, duplicate_rate, duplicate_count))
|
|
status = "❌"
|
|
else:
|
|
good_episodes.append((ep_idx, ep_name, duplicate_rate, duplicate_count))
|
|
status = "✅"
|
|
|
|
print(f"{status} Episode {ep_idx:2d}: {duplicate_rate:5.1f}% 重复 ({duplicate_count:2d}/{check_frames}) - {ep_name}")
|
|
|
|
except Exception as e:
|
|
print(f"❌ Episode {ep_idx}: 错误 - {e}")
|
|
|
|
# 总结
|
|
print("\n" + "="*80)
|
|
print("总结")
|
|
print("="*80)
|
|
print(f"总共检查: {len(episode_files)} 个 episodes")
|
|
print(f"正常的: {len(good_episodes)} 个 ✅")
|
|
print(f"有问题的: {len(bad_episodes)} 个 ❌")
|
|
|
|
if bad_episodes:
|
|
print(f"\n有问题的 episodes:")
|
|
for ep_idx, ep_name, rate, count in bad_episodes:
|
|
print(f" - episode_{ep_idx}.hdf5: {rate:.1f}% 重复")
|
|
|
|
print(f"\n删除命令:")
|
|
ep_names = [name for _, name, _, _ in bad_episodes]
|
|
print(f" rm " + " ".join([f"{dataset_dir}/{name}.hdf5" for name in ep_names]))
|
|
|
|
print(f"\n建议:")
|
|
if len(bad_episodes) > 0:
|
|
print(f" 1. 删除有问题的 {len(bad_episodes)} 个 episodes")
|
|
print(f" 2. 重新收集数据,或使用剩余的 {len(good_episodes)} 个正常 episodes")
|
|
else:
|
|
print(f" ✅ 所有 episodes 都正常,可以直接使用!")
|
|
|
|
|
|
# Script entry point: run the episode quality scan when executed directly.
if __name__ == "__main__":
    check_all_episodes()
|