diff --git a/check_all_episodes.py b/check_all_episodes.py
new file mode 100644
index 0000000..2734216
--- /dev/null
+++ b/check_all_episodes.py
@@ -0,0 +1,121 @@
+#!/usr/bin/env python3
+"""
+Check every episode for duplicated frames.
+
+Identifies episodes that look broken and should be deleted or re-collected.
+"""
+import os
+import h5py
+import glob
+import numpy as np
+
+
+def _episode_number(path):
+    """Return the integer N of an ``episode_N.hdf5`` path, or None if N is not an integer."""
+    stem = os.path.basename(path).replace('.hdf5', '')
+    try:
+        return int(stem.split('_')[-1])
+    except ValueError:
+        return None
+
+
+def _list_episode_files(dataset_dir):
+    """Return the ``episode_*.hdf5`` files in dataset_dir sorted by episode number.
+
+    Files whose suffix is not an integer are left out instead of crashing the sort.
+    """
+    pattern = os.path.join(dataset_dir, "episode_*.hdf5")
+    numbered = [p for p in glob.glob(pattern) if _episode_number(p) is not None]
+    return sorted(numbered, key=_episode_number)
+
+
+def check_all_episodes(dataset_dir="roboimi/demos/dataset/sim_transfer",
+                       max_frames=50, diff_threshold=1.0, max_duplicate_rate=10):
+    """Print a duplicate-frame report for every episode in dataset_dir.
+
+    Consecutive frames of the 'top' camera are compared over the first
+    max_frames frames; a pair counts as duplicated when its mean absolute
+    pixel difference is below diff_threshold.
+
+    Args:
+        dataset_dir: Directory containing ``episode_N.hdf5`` files.
+        max_frames: Number of leading frames inspected per episode.
+        diff_threshold: Mean absolute difference below which two frames are duplicates.
+        max_duplicate_rate: Percentage of duplicated pairs above which an episode is bad.
+    """
+    episode_files = _list_episode_files(dataset_dir)
+
+    print("="*80)
+    print("所有 Episode 质量检查")
+    print("="*80)
+
+    good_episodes = []
+    bad_episodes = []
+
+    for ep_idx, ep_file in enumerate(episode_files):
+        ep_name = os.path.basename(ep_file).replace('.hdf5', '')
+
+        try:
+            with h5py.File(ep_file, 'r') as f:
+                img_path = '/observations/images/top'
+                if img_path not in f:
+                    print(f"⚠️ Episode {ep_idx:2d}: 缺少 {img_path} - {ep_name}")
+                    continue
+
+                dataset = f[img_path]
+                check_frames = min(max_frames, len(dataset))
+                # Read only the inspected prefix instead of the whole episode.
+                images = dataset[:check_frames]
+
+            # N frames give N - 1 consecutive pairs; the rate is relative to the pairs.
+            num_pairs = max(check_frames - 1, 0)
+            duplicate_count = 0
+            for prev_img, next_img in zip(images[:-1], images[1:]):
+                diff = np.mean(np.abs(prev_img.astype(float) - next_img.astype(float)))
+                if diff < diff_threshold:
+                    duplicate_count += 1
+
+            # Episodes with fewer than two frames have no pairs; avoid dividing by zero.
+            duplicate_rate = duplicate_count / num_pairs * 100 if num_pairs else 0.0
+
+            record = (ep_idx, ep_name, duplicate_rate, duplicate_count)
+            if duplicate_rate > max_duplicate_rate:
+                bad_episodes.append(record)
+                status = "❌"
+            else:
+                good_episodes.append(record)
+                status = "✅"
+
+            print(f"{status} Episode {ep_idx:2d}: {duplicate_rate:5.1f}% 重复 ({duplicate_count:2d}/{num_pairs}) - {ep_name}")
+
+        except Exception as e:
+            print(f"❌ Episode {ep_idx}: 错误 - {e}")
+
+    # Summary
+    print("\n" + "="*80)
+    print("总结")
+    print("="*80)
+    print(f"总共检查: {len(episode_files)} 个 episodes")
+    print(f"正常的: {len(good_episodes)} 个 ✅")
+    print(f"有问题的: {len(bad_episodes)} 个 ❌")
+
+    if bad_episodes:
+        print(f"\n有问题的 episodes:")
+        for _, ep_name, rate, _ in bad_episodes:
+            # Use the real file name: the loop index differs from N when numbering has gaps.
+            print(f" - {ep_name}.hdf5: {rate:.1f}% 重复")
+
+        print(f"\n删除命令:")
+        ep_names = [name for _, name, _, _ in bad_episodes]
+        print(f" rm " + " ".join([f"{dataset_dir}/{name}.hdf5" for name in ep_names]))
+
+    print(f"\n建议:")
+    if len(bad_episodes) > 0:
+        print(f" 1. 删除有问题的 {len(bad_episodes)} 个 episodes")
+        print(f" 2. 重新收集数据,或使用剩余的 {len(good_episodes)} 个正常 episodes")
+    else:
+        print(f" ✅ 所有 episodes 都正常,可以直接使用!")
+
+
+if __name__ == "__main__":
+    check_all_episodes()
diff --git a/check_specific_frames.py b/check_specific_frames.py
new file mode 100644
index 0000000..ce93d35
--- /dev/null
+++ b/check_specific_frames.py
@@ -0,0 +1,215 @@
+#!/usr/bin/env python3
+"""
+Inspect specific frames of each episode to verify how the data was recorded.
+
+Features:
+1. Extract selected frames (0, 1 and 2 by default) of every episode.
+2. Compare the same frame index across episodes.
+3. Save the images for manual inspection.
+"""
+import os
+import h5py
+import glob
+import cv2
+import numpy as np
+
+
+DATASET_DIR = "roboimi/demos/dataset/sim_transfer"
+OUTPUT_DIR = '/tmp/dataset_frames'
+
+
+def _episode_number(path):
+    """Return the integer N of an ``episode_N.hdf5`` path, or None if N is not an integer."""
+    stem = os.path.basename(path).replace('.hdf5', '')
+    try:
+        return int(stem.split('_')[-1])
+    except ValueError:
+        return None
+
+
+def _list_episode_files(dataset_dir):
+    """Return the ``episode_*.hdf5`` files in dataset_dir sorted by episode number.
+
+    Files whose suffix is not an integer are left out instead of crashing the sort.
+    """
+    pattern = os.path.join(dataset_dir, "episode_*.hdf5")
+    numbered = [p for p in glob.glob(pattern) if _episode_number(p) is not None]
+    return sorted(numbered, key=_episode_number)
+
+
+def check_specific_frames(frame_indices=(0, 1, 2), camera='top', num_episodes=10):
+    """
+    Save selected frames of the first episodes as PNG files and print their qpos.
+
+    Args:
+        frame_indices: Frame indices to inspect.
+        camera: Camera name under /observations/images.
+        num_episodes: Maximum number of episodes to inspect.
+    """
+    # The default is an immutable tuple; normalise to a list so output matches list callers.
+    frame_indices = list(frame_indices)
+    episode_files = _list_episode_files(DATASET_DIR)
+    num_check = min(num_episodes, len(episode_files))
+
+    # Create the output directory.
+    output_dir = OUTPUT_DIR
+    os.makedirs(output_dir, exist_ok=True)
+
+    print(f"检查前 {num_check} 个 episode 的特定帧")
+    print(f"帧索引: {frame_indices}")
+    print(f"相机: {camera}")
+    print("="*80)
+
+    for ep_idx in range(num_check):
+        ep_file = episode_files[ep_idx]
+        ep_name = os.path.basename(ep_file).replace('.hdf5', '')
+
+        try:
+            with h5py.File(ep_file, 'r') as f:
+                qpos = f['/observations/qpos'][:]
+
+                img_path = f'/observations/images/{camera}'
+                if img_path not in f:
+                    print(f"Episode {ep_name}: 相机 {camera} 不存在")
+                    continue
+
+                # Keep a dataset handle; only the requested frames are read from disk.
+                images = f[img_path]
+                num_frames = len(images)
+
+                print(f"\nEpisode {ep_name}:")
+                print(f" 总帧数: {num_frames}")
+
+                for frame_idx in frame_indices:
+                    if frame_idx >= num_frames:
+                        print(f" 帧 {frame_idx}: 超出范围")
+                        continue
+
+                    # NOTE(review): cv2.imwrite interprets 3-channel data as BGR; confirm the
+                    # channel order stored in the dataset if saved colors look swapped.
+                    img = images[frame_idx]
+                    filename = f"{output_dir}/ep{ep_idx:02d}_frame{frame_idx:03d}.png"
+                    cv2.imwrite(filename, img)
+
+                    q = qpos[frame_idx]
+                    print(f" 帧 {frame_idx}: qpos[0:3]=[{q[0]:6.2f}, {q[1]:6.2f}, {q[2]:6.2f}], qpos[3]={q[3]:6.2f} → {filename}")
+
+        except Exception as e:
+            print(f"Episode {ep_name}: 错误 - {e}")
+
+    print("\n" + "="*80)
+    print(f"✅ 所有图像已保存到: {output_dir}")
+    print(f"\n查看方法:")
+    print(f" eog {output_dir}/*.png")
+    print(f" ")
+    print(f" # 或对比特定帧:")
+    print(f" eog {output_dir}/*_frame000.png # 所有 episode 的第 0 帧")
+    print(f" eog {output_dir}/*_frame001.png # 所有 episode 的第 1 帧")
+    print(f" eog {output_dir}/*_frame002.png # 所有 episode 的第 2 帧")
+
+
+def compare_frame_across_episodes(frame_idx=0, camera='top', num_episodes=10):
+    """
+    Build one grid image showing the same frame of every inspected episode.
+
+    Each cell is labelled with the episode's index in the sorted file list and
+    its qpos[3] value; the grid is written to OUTPUT_DIR.
+    """
+    episode_files = _list_episode_files(DATASET_DIR)
+    num_compare = min(num_episodes, len(episode_files))
+    cols = 5  # cells per row
+    header = 30  # pixel band above every row reserved for the "Ep N" label
+
+    # Create the output directory.
+    output_dir = OUTPUT_DIR
+    os.makedirs(output_dir, exist_ok=True)
+
+    print(f"生成对比图: 所有 Episode 的第 {frame_idx} 帧")
+    print("="*80)
+
+    # Keep (episode index, image, qpos) together so labels stay right when episodes are skipped.
+    collected = []
+
+    for ep_idx in range(num_compare):
+        ep_file = episode_files[ep_idx]
+        ep_name = os.path.basename(ep_file).replace('.hdf5', '')
+
+        try:
+            with h5py.File(ep_file, 'r') as f:
+                qpos = f['/observations/qpos'][:]
+                img_path = f'/observations/images/{camera}'
+
+                if img_path in f and frame_idx < f[img_path].shape[0]:
+                    collected.append((ep_idx, f[img_path][frame_idx], qpos[frame_idx]))
+                    print(f"Episode {ep_name}: qpos[0:3]=[{qpos[frame_idx][0]:.2f}, {qpos[frame_idx][1]:.2f}, {qpos[frame_idx][2]:.2f}]")
+
+        except Exception as e:
+            print(f"Episode {ep_name}: 错误 - {e}")
+
+    if not collected:
+        print("❌ 没有收集到图像")
+        return
+
+    # Cell size comes from the first collected image.
+    h, w = collected[0][1].shape[:2]
+    cell_h = h + header
+    rows = (len(collected) + cols - 1) // cols
+
+    compare_img = np.zeros((rows * cell_h, cols * w, 3), dtype=np.uint8)
+
+    for slot, (ep_idx, img, qpos) in enumerate(collected):
+        row, col = divmod(slot, cols)
+
+        y_start = row * cell_h + header
+        y_end = y_start + h
+        x_start = col * w
+        x_end = x_start + w
+
+        # Resize if this episode's image size differs from the first one.
+        if img.shape[:2] != (h, w):
+            img = cv2.resize(img, (w, h))
+
+        compare_img[y_start:y_end, x_start:x_end] = img
+
+        # The label sits in this row's own header band, so it no longer overlaps the row above.
+        cv2.putText(compare_img, f"Ep {ep_idx}", (x_start + 10, y_start - 10),
+                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 255), 2)
+        cv2.putText(compare_img, f"qpos[3]={qpos[3]:.2f}", (x_start + 10, y_end - 10),
+                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1)
+
+    # Save the comparison grid.
+    output_path = f"{output_dir}/compare_frame{frame_idx:03d}.png"
+    cv2.imwrite(output_path, compare_img)
+
+    print(f"\n✅ 对比图已保存: {output_path}")
+    print(f" 查看方法: eog {output_path}")
+
+
+if __name__ == "__main__":
+    import sys
+
+    print("="*80)
+    print("特定帧检查工具")
+    print("="*80)
+
+    if len(sys.argv) > 1:
+        frame_idx = int(sys.argv[1])
+        compare_frame_across_episodes(frame_idx=frame_idx, camera='top', num_episodes=10)
+    else:
+        # Default: inspect frames 0, 1 and 2.
+        check_specific_frames(frame_indices=[0, 1, 2], camera='top', num_episodes=10)
+
+        print("\n" + "="*80)
+        print("生成对比图...")
+        print("="*80)
+
+        # Build the comparison grids for frames 0, 1 and 2.
+        compare_frame_across_episodes(frame_idx=0, camera='top', num_episodes=10)
+        compare_frame_across_episodes(frame_idx=1, camera='top', num_episodes=10)
+        compare_frame_across_episodes(frame_idx=2, camera='top', num_episodes=10)
+
+    print("\n" + "="*80)
+    print("其他用法:")
+    print(" python check_specific_frames.py 0 # 只检查第 0 帧")
+    print(" python check_specific_frames.py 1 # 只检查第 1 帧")
+    print("="*80)
diff --git a/generate_dataset_videos.py b/generate_dataset_videos.py
new file mode 100644
index 0000000..0adae9f
--- /dev/null
+++ b/generate_dataset_videos.py
@@ -0,0 +1,360 @@
+#!/usr/bin/env python3
+"""
+Convert HDF5 dataset episodes to videos for visual inspection.
+
+Features:
+1. Convert a single episode to a video.
+2. Compare several episodes side by side.
+3. Slow down playback to make details easier to see.
+"""
+import os
+import h5py
+import glob
+import cv2
+import numpy as np
+
+
+DATASET_DIR = "roboimi/demos/dataset/sim_transfer"
+OUTPUT_DIR = '/tmp/dataset_videos'
+IMAGES_GROUP = '/observations/images'
+
+
+def _episode_number(path):
+    """Return the integer N of an ``episode_N.hdf5`` path, or None if N is not an integer."""
+    stem = os.path.basename(path).replace('.hdf5', '')
+    try:
+        return int(stem.split('_')[-1])
+    except ValueError:
+        return None
+
+
+def _list_episode_files(dataset_dir):
+    """Return the ``episode_*.hdf5`` files in dataset_dir sorted by episode number.
+
+    Numeric ordering keeps episode_10 after episode_2, so a list index matches the
+    episode number when numbering starts at 0 without gaps. Files whose suffix is
+    not an integer are left out.
+    """
+    pattern = os.path.join(dataset_dir, "episode_*.hdf5")
+    numbered = [p for p in glob.glob(pattern) if _episode_number(p) is not None]
+    return sorted(numbered, key=_episode_number)
+
+
+def _playback_fps(fps, slow_factor):
+    """Return fps // slow_factor, clamped to at least 1 so the writer never gets 0 FPS."""
+    if slow_factor <= 0:
+        raise ValueError(f"slow_factor must be positive, got {slow_factor}")
+    return max(1, fps // slow_factor)
+
+
+def _open_writer(output_path, fps, frame_size):
+    """Create the output directory and open an mp4v writer for (width, height) frames.
+
+    Raises:
+        RuntimeError: If OpenCV could not open the writer.
+    """
+    os.makedirs(os.path.dirname(output_path) or '.', exist_ok=True)
+    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+    writer = cv2.VideoWriter(output_path, fourcc, fps, frame_size)
+    if not writer.isOpened():
+        writer.release()
+        raise RuntimeError(f"cannot open video writer: {output_path}")
+    return writer
+
+
+def episode_to_video(episode_file, output_path, camera='top', fps=30, slow_factor=1):
+    """
+    Convert one episode to a video.
+
+    Args:
+        episode_file: Path of the HDF5 file.
+        output_path: Path of the video to write.
+        camera: Name of the camera to use.
+        fps: Frame rate before slow-down.
+        slow_factor: Slow-motion factor (1 = normal speed, 2 = half speed).
+
+    Returns:
+        True when the video was written, False when the camera is missing or an error occurred.
+    """
+    try:
+        with h5py.File(episode_file, 'r') as f:
+            img_path = f'{IMAGES_GROUP}/{camera}'
+
+            if img_path not in f:
+                print(f" ❌ 相机 {camera} 不存在")
+                return False
+
+            # Dataset handle: frames are read one at a time instead of loading the episode.
+            images = f[img_path]
+            qpos = f['/observations/qpos'][:]
+
+            total_frames = len(images)
+            height, width = images.shape[1], images.shape[2]
+            ep_label = os.path.basename(episode_file).replace('.hdf5', '')
+
+            actual_fps = _playback_fps(fps, slow_factor)
+            out = _open_writer(output_path, actual_fps, (width, height))
+            try:
+                for i in range(total_frames):
+                    # astype copies, so the overlay text never touches the source data.
+                    frame = images[i].astype(np.uint8)
+
+                    info_text = [
+                        f"Episode: {ep_label}",
+                        f"Frame: {i}/{total_frames}",
+                        f"qpos[0:3]: [{qpos[i, 0]:.2f}, {qpos[i, 1]:.2f}, {qpos[i, 2]:.2f}]",
+                    ]
+
+                    for j, text in enumerate(info_text):
+                        cv2.putText(frame, text, (10, 30 + j*30),
+                                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
+
+                    out.write(frame)
+            finally:
+                # Release even when a frame fails, so the file handle is not leaked.
+                out.release()
+
+            print(f" ✅ 保存: {output_path}")
+            print(f" 帧数: {total_frames}, 尺寸: {width}x{height}, FPS: {actual_fps}")
+            return True
+
+    except Exception as e:
+        print(f" ❌ 错误: {e}")
+        return False
+
+
+def generate_all_videos(camera='top', num_episodes=5, slow_factor=1):
+    """Generate one video per episode for the first num_episodes episodes."""
+    dataset_dir = DATASET_DIR
+    episode_files = _list_episode_files(dataset_dir)
+
+    if len(episode_files) == 0:
+        print(f"❌ 没有找到数据文件: {dataset_dir}")
+        return
+
+    # Create the output directory.
+    output_dir = OUTPUT_DIR
+    os.makedirs(output_dir, exist_ok=True)
+
+    num_videos = min(num_episodes, len(episode_files))
+    print(f"找到 {len(episode_files)} 个 episode 文件")
+    print(f"将生成前 {num_videos} 个 episode 的视频\n")
+
+    for i in range(num_videos):
+        ep_file = episode_files[i]
+        ep_name = os.path.basename(ep_file).replace('.hdf5', '')
+        output_path = f"{output_dir}/{ep_name}_{camera}.mp4"
+
+        print(f"[{i+1}/{num_videos}] {ep_name}")
+        episode_to_video(ep_file, output_path, camera=camera, slow_factor=slow_factor)
+        print()
+
+    print(f"✅ 所有视频已保存到: {output_dir}")
+    print(f"\n播放方法:")
+    print(f" # 播放单个视频")
+    print(f" vlc {output_dir}/*.mp4")
+    print(f" ")
+    print(f" # 或用文件管理器")
+    print(f" nautilus {output_dir}")
+
+
+def generate_multi_camera_video(episode_idx=0, slow_factor=1):
+    """Write a split-screen video that shows every camera of one episode."""
+    episode_files = _list_episode_files(DATASET_DIR)
+
+    if episode_idx >= len(episode_files):
+        print(f"❌ Episode {episode_idx} 不存在")
+        return
+
+    ep_file = episode_files[episode_idx]
+
+    try:
+        with h5py.File(ep_file, 'r') as f:
+            # Cameras are the children of /observations/images; scanning the top-level keys
+            # (as before) never matched, because the images group is nested.
+            cameras = list(f[IMAGES_GROUP].keys()) if IMAGES_GROUP in f else []
+
+            print(f"Episode {episode_idx} 的相机: {cameras}")
+
+            # Dataset handles per camera; frames are read lazily in the loop below.
+            all_images = {cam: f[f'{IMAGES_GROUP}/{cam}'] for cam in cameras}
+
+            if not all_images:
+                print("❌ 没有找到图像数据")
+                return
+
+            # Frame size comes from the first camera; stop at the shortest stream.
+            first_cam = cameras[0]
+            height, width = all_images[first_cam].shape[1], all_images[first_cam].shape[2]
+            total_frames = min(len(images) for images in all_images.values())
+
+            # Grid layout: at most two cameras per row.
+            num_cams = len(all_images)
+            cols = min(2, num_cams)
+            rows = (num_cams + cols - 1) // cols
+
+            canvas_width = width * cols
+            canvas_height = height * rows
+
+            output_path = f'{OUTPUT_DIR}/episode_{episode_idx}_all_cameras.mp4'
+            out = _open_writer(output_path, _playback_fps(30, slow_factor),
+                               (canvas_width, canvas_height))
+            try:
+                for i in range(total_frames):
+                    canvas = np.zeros((canvas_height, canvas_width, 3), dtype=np.uint8)
+
+                    for cam_idx, (cam_name, images) in enumerate(all_images.items()):
+                        img = images[i]
+
+                        # Position of this camera on the canvas.
+                        row, col = divmod(cam_idx, cols)
+                        y_start = row * height
+                        y_end = y_start + height
+                        x_start = col * width
+                        x_end = x_start + width
+
+                        # Resize if this camera's resolution differs from the first one.
+                        if img.shape[:2] != (height, width):
+                            img = cv2.resize(img, (width, height))
+
+                        canvas[y_start:y_end, x_start:x_end] = img
+
+                        cv2.putText(canvas, cam_name, (x_start + 10, y_start + 30),
+                                    cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 255), 2)
+
+                    cv2.putText(canvas, f"Frame: {i}/{total_frames}", (10, canvas_height - 10),
+                                cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
+
+                    out.write(canvas)
+            finally:
+                out.release()
+
+            print(f"✅ 保存多相机视频: {output_path}")
+
+    except Exception as e:
+        print(f"❌ 错误: {e}")
+
+
+def compare_episodes(camera='top', slow_factor=2, episode_indices=(0, 1, 2, 3, 4)):
+    """Write a video that shows several episodes side by side.
+
+    Args:
+        camera: Name of the camera to compare.
+        slow_factor: Slow-motion factor applied to a 30 FPS base rate.
+        episode_indices: Indices (in the sorted file list) of the episodes to compare;
+            indices past the end of the list are skipped.
+    """
+    episode_files = _list_episode_files(DATASET_DIR)
+
+    episodes_to_compare = list(episode_indices)
+
+    print(f"对比 Episodes: {episodes_to_compare}")
+
+    # Load the data of every requested episode.
+    all_data = []
+    for ep_idx in episodes_to_compare:
+        if ep_idx >= len(episode_files):
+            continue
+
+        try:
+            with h5py.File(episode_files[ep_idx], 'r') as f:
+                img_path = f'{IMAGES_GROUP}/{camera}'
+                if img_path in f:
+                    all_data.append({
+                        'idx': ep_idx,
+                        'images': f[img_path][:],
+                        'qpos': f['/observations/qpos'][:]
+                    })
+        except Exception as e:
+            # Report unreadable episodes instead of silently dropping them.
+            print(f"Episode {ep_idx}: 错误 - {e}")
+
+    if len(all_data) == 0:
+        print("❌ 没有数据")
+        return
+
+    first_data = all_data[0]
+    height, width = first_data['images'].shape[1], first_data['images'].shape[2]
+    # Stop at the shortest episode (images and qpos) so every panel has data.
+    total_frames = min(min(len(d['images']), len(d['qpos'])) for d in all_data)
+
+    # Side-by-side layout: one panel per episode.
+    num_compare = len(all_data)
+    canvas_width = width * num_compare
+    canvas_height = height
+
+    output_path = f'{OUTPUT_DIR}/compare_{camera}.mp4'
+    out = _open_writer(output_path, _playback_fps(30, slow_factor),
+                       (canvas_width, canvas_height))
+
+    print(f"生成对比视频,共 {total_frames} 帧...")
+
+    try:
+        for i in range(total_frames):
+            canvas = np.zeros((canvas_height, canvas_width, 3), dtype=np.uint8)
+
+            for j, data in enumerate(all_data):
+                img = data['images'][i]
+                qpos = data['qpos'][i]
+
+                # Resize if this episode's resolution differs from the first one.
+                if img.shape[:2] != (height, width):
+                    img = cv2.resize(img, (width, height))
+
+                x_start = j * width
+                x_end = x_start + width
+                canvas[:, x_start:x_end] = img
+
+                ep_name = f"Ep {data['idx']}"
+                cv2.putText(canvas, ep_name, (x_start + 10, 30),
+                            cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 255), 2)
+                cv2.putText(canvas, f"qpos[0:3]: [{qpos[0]:.2f}, {qpos[1]:.2f}, {qpos[2]:.2f}]",
+                            (x_start + 10, height - 10),
+                            cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1)
+
+            cv2.putText(canvas, f"Frame: {i}/{total_frames}", (10, canvas_height - 30),
+                        cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
+
+            out.write(canvas)
+
+            if i % 100 == 0:
+                print(f" 进度: {i}/{total_frames}")
+    finally:
+        out.release()
+
+    print(f"✅ 保存对比视频: {output_path}")
+
+
+if __name__ == "__main__":
+    import sys
+
+    print("="*60)
+    print("数据集视频生成工具")
+    print("="*60)
+
+    if len(sys.argv) > 1:
+        command = sys.argv[1]
+
+        if command == 'compare':
+            # Compare several episodes.
+            camera = sys.argv[2] if len(sys.argv) > 2 else 'top'
+            compare_episodes(camera=camera, slow_factor=2)
+
+        elif command == 'multi':
+            # Multi-camera video.
+            ep_idx = int(sys.argv[2]) if len(sys.argv) > 2 else 0
+            generate_multi_camera_video(episode_idx=ep_idx, slow_factor=1)
+
+        else:
+            print("未知命令")
+    else:
+        # Default: videos of the first 5 episodes.
+        print("\n生成前 5 个 episode 的视频(top 相机,慢放 2x)...")
+        print("="*60 + "\n")
+        generate_all_videos(camera='top', num_episodes=5, slow_factor=2)
+
+    print("\n" + "="*60)
+    print("其他用法:")
+    print(" python generate_dataset_videos.py compare top # 对比多个 episode")
+    print(" python generate_dataset_videos.py multi 0 # 多相机视频")
+    print("="*60)