chore: 添加测试文件
- check_all_episodes.py:检查各个episode是否有重复帧。 - check_specific_frames.py:检查前几帧是否位于正确初始位置。 - generate_dataset_videos.py:根据hdf5生成视频
This commit is contained in:
91
check_all_episodes.py
Normal file
91
check_all_episodes.py
Normal file
@@ -0,0 +1,91 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
检查所有 episode 的重复帧情况
|
||||
|
||||
找出哪些 episode 有问题,需要删除或重新收集
|
||||
"""
|
||||
import os
|
||||
import h5py
|
||||
import glob
|
||||
import numpy as np
|
||||
|
||||
|
||||
def check_all_episodes():
|
||||
"""检查所有 episode 的质量"""
|
||||
|
||||
dataset_dir = "roboimi/demos/dataset/sim_transfer"
|
||||
episode_files = sorted(glob.glob(os.path.join(dataset_dir, "episode_*.hdf5")))
|
||||
episode_files = sorted(episode_files, key=lambda x: int(x.split('_')[-1].replace('.hdf5', '')))
|
||||
|
||||
print("="*80)
|
||||
print("所有 Episode 质量检查")
|
||||
print("="*80)
|
||||
|
||||
good_episodes = []
|
||||
bad_episodes = []
|
||||
|
||||
for ep_idx, ep_file in enumerate(episode_files):
|
||||
ep_name = os.path.basename(ep_file).replace('.hdf5', '')
|
||||
|
||||
try:
|
||||
with h5py.File(ep_file, 'r') as f:
|
||||
img_path = '/observations/images/top'
|
||||
if img_path not in f:
|
||||
continue
|
||||
|
||||
images = f[img_path][:]
|
||||
|
||||
# 检查前 50 帧的重复情况
|
||||
check_frames = min(50, len(images))
|
||||
duplicate_count = 0
|
||||
|
||||
for i in range(check_frames - 1):
|
||||
img1 = images[i]
|
||||
img2 = images[i + 1]
|
||||
diff = np.mean(np.abs(img1.astype(float) - img2.astype(float)))
|
||||
|
||||
if diff < 1.0: # 重复
|
||||
duplicate_count += 1
|
||||
|
||||
duplicate_rate = duplicate_count / check_frames * 100
|
||||
|
||||
# 判断质量
|
||||
if duplicate_rate > 10: # 超过10%重复
|
||||
bad_episodes.append((ep_idx, ep_name, duplicate_rate, duplicate_count))
|
||||
status = "❌"
|
||||
else:
|
||||
good_episodes.append((ep_idx, ep_name, duplicate_rate, duplicate_count))
|
||||
status = "✅"
|
||||
|
||||
print(f"{status} Episode {ep_idx:2d}: {duplicate_rate:5.1f}% 重复 ({duplicate_count:2d}/{check_frames}) - {ep_name}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Episode {ep_idx}: 错误 - {e}")
|
||||
|
||||
# 总结
|
||||
print("\n" + "="*80)
|
||||
print("总结")
|
||||
print("="*80)
|
||||
print(f"总共检查: {len(episode_files)} 个 episodes")
|
||||
print(f"正常的: {len(good_episodes)} 个 ✅")
|
||||
print(f"有问题的: {len(bad_episodes)} 个 ❌")
|
||||
|
||||
if bad_episodes:
|
||||
print(f"\n有问题的 episodes:")
|
||||
for ep_idx, ep_name, rate, count in bad_episodes:
|
||||
print(f" - episode_{ep_idx}.hdf5: {rate:.1f}% 重复")
|
||||
|
||||
print(f"\n删除命令:")
|
||||
ep_names = [name for _, name, _, _ in bad_episodes]
|
||||
print(f" rm " + " ".join([f"{dataset_dir}/{name}.hdf5" for name in ep_names]))
|
||||
|
||||
print(f"\n建议:")
|
||||
if len(bad_episodes) > 0:
|
||||
print(f" 1. 删除有问题的 {len(bad_episodes)} 个 episodes")
|
||||
print(f" 2. 重新收集数据,或使用剩余的 {len(good_episodes)} 个正常 episodes")
|
||||
else:
|
||||
print(f" ✅ 所有 episodes 都正常,可以直接使用!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
check_all_episodes()
|
||||
202
check_specific_frames.py
Normal file
202
check_specific_frames.py
Normal file
@@ -0,0 +1,202 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
检查特定帧的图像 - 用于验证数据记录问题
|
||||
|
||||
功能:
|
||||
1. 提取每个 episode 的第 0、1、2 帧图像
|
||||
2. 对比不同 episode 的相同帧号
|
||||
3. 保存图像供人工检查
|
||||
"""
|
||||
import os
|
||||
import h5py
|
||||
import glob
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
|
||||
def check_specific_frames(frame_indices=[0, 1, 2], camera='top', num_episodes=10):
|
||||
"""
|
||||
检查特定帧的图像和 qpos
|
||||
|
||||
Args:
|
||||
frame_indices: 要检查的帧索引列表
|
||||
camera: 相机名称
|
||||
num_episodes: 要检查的 episode 数量
|
||||
"""
|
||||
|
||||
dataset_dir = "roboimi/demos/dataset/sim_transfer"
|
||||
episode_files = sorted(glob.glob(os.path.join(dataset_dir, "episode_*.hdf5")))
|
||||
# 按数字排序
|
||||
episode_files = sorted(episode_files, key=lambda x: int(x.split('_')[-1].replace('.hdf5', '')))
|
||||
|
||||
# 创建输出目录
|
||||
output_dir = f'/tmp/dataset_frames'
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
print(f"检查前 {min(num_episodes, len(episode_files))} 个 episode 的特定帧")
|
||||
print(f"帧索引: {frame_indices}")
|
||||
print(f"相机: {camera}")
|
||||
print("="*80)
|
||||
|
||||
# 收集所有数据
|
||||
for ep_idx in range(min(num_episodes, len(episode_files))):
|
||||
ep_file = episode_files[ep_idx]
|
||||
ep_name = os.path.basename(ep_file).replace('.hdf5', '')
|
||||
|
||||
try:
|
||||
with h5py.File(ep_file, 'r') as f:
|
||||
# 读取 qpos
|
||||
qpos = f['/observations/qpos'][:]
|
||||
|
||||
# 读取图像
|
||||
img_path = f'/observations/images/{camera}'
|
||||
if img_path not in f:
|
||||
print(f"Episode {ep_name}: 相机 {camera} 不存在")
|
||||
continue
|
||||
|
||||
images = f[img_path][:]
|
||||
|
||||
print(f"\nEpisode {ep_name}:")
|
||||
print(f" 总帧数: {len(images)}")
|
||||
|
||||
# 保存指定帧
|
||||
for frame_idx in frame_indices:
|
||||
if frame_idx >= len(images):
|
||||
print(f" 帧 {frame_idx}: 超出范围")
|
||||
continue
|
||||
|
||||
# 保存图像
|
||||
img = images[frame_idx]
|
||||
filename = f"{output_dir}/ep{ep_idx:02d}_frame{frame_idx:03d}.png"
|
||||
cv2.imwrite(filename, img)
|
||||
|
||||
# 打印 qpos
|
||||
q = qpos[frame_idx]
|
||||
print(f" 帧 {frame_idx}: qpos[0:3]=[{q[0]:6.2f}, {q[1]:6.2f}, {q[2]:6.2f}], qpos[3]={q[3]:6.2f} → {filename}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Episode {ep_name}: 错误 - {e}")
|
||||
|
||||
print("\n" + "="*80)
|
||||
print(f"✅ 所有图像已保存到: {output_dir}")
|
||||
print(f"\n查看方法:")
|
||||
print(f" eog {output_dir}/*.png")
|
||||
print(f" ")
|
||||
print(f" # 或对比特定帧:")
|
||||
print(f" eog {output_dir}/*_frame000.png # 所有 episode 的第 0 帧")
|
||||
print(f" eog {output_dir}/*_frame001.png # 所有 episode 的第 1 帧")
|
||||
print(f" eog {output_dir}/*_frame002.png # 所有 episode 的第 2 帧")
|
||||
|
||||
|
||||
def compare_frame_across_episodes(frame_idx=0, camera='top', num_episodes=10):
|
||||
"""
|
||||
并排对比所有 episode 的某一帧
|
||||
|
||||
生成一个大的对比图,包含所有 episode 的指定帧
|
||||
"""
|
||||
|
||||
dataset_dir = "roboimi/demos/dataset/sim_transfer"
|
||||
episode_files = sorted(glob.glob(os.path.join(dataset_dir, "episode_*.hdf5")))
|
||||
episode_files = sorted(episode_files, key=lambda x: int(x.split('_')[-1].replace('.hdf5', '')))
|
||||
|
||||
num_compare = min(num_episodes, len(episode_files))
|
||||
cols = 5 # 每行 5 个
|
||||
rows = (num_compare + cols - 1) // cols
|
||||
|
||||
# 创建输出目录
|
||||
output_dir = f'/tmp/dataset_frames'
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
print(f"生成对比图: 所有 Episode 的第 {frame_idx} 帧")
|
||||
print("="*80)
|
||||
|
||||
# 收集图像
|
||||
images_compare = []
|
||||
qpos_list = []
|
||||
|
||||
for ep_idx in range(num_compare):
|
||||
ep_file = episode_files[ep_idx]
|
||||
ep_name = os.path.basename(ep_file).replace('.hdf5', '')
|
||||
|
||||
try:
|
||||
with h5py.File(ep_file, 'r') as f:
|
||||
qpos = f['/observations/qpos'][:]
|
||||
img_path = f'/observations/images/{camera}'
|
||||
|
||||
if img_path in f and frame_idx < f[img_path].shape[0]:
|
||||
img = f[img_path][frame_idx]
|
||||
images_compare.append(img)
|
||||
qpos_list.append(qpos[frame_idx])
|
||||
print(f"Episode {ep_name}: qpos[0:3]=[{qpos[frame_idx][0]:.2f}, {qpos[frame_idx][1]:.2f}, {qpos[frame_idx][2]:.2f}]")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Episode {ep_name}: 错误 - {e}")
|
||||
|
||||
if not images_compare:
|
||||
print("❌ 没有收集到图像")
|
||||
return
|
||||
|
||||
# 获取图像尺寸
|
||||
h, w = images_compare[0].shape[:2]
|
||||
|
||||
# 创建对比图
|
||||
compare_img = np.zeros((rows * h + 50, cols * w, 3), dtype=np.uint8)
|
||||
|
||||
for i, (img, qpos) in enumerate(zip(images_compare, qpos_list)):
|
||||
row = i // cols
|
||||
col = i % cols
|
||||
|
||||
y_start = row * h + 30
|
||||
y_end = y_start + h
|
||||
x_start = col * w
|
||||
x_end = x_start + w
|
||||
|
||||
# 调整大小(如果需要)
|
||||
if img.shape[:2] != (h, w):
|
||||
img = cv2.resize(img, (w, h))
|
||||
|
||||
compare_img[y_start:y_end, x_start:x_end] = img
|
||||
|
||||
# 添加信息
|
||||
ep_name = f"Ep {i}"
|
||||
cv2.putText(compare_img, ep_name, (x_start + 10, row * h + 20),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 255), 2)
|
||||
cv2.putText(compare_img, f"qpos[3]={qpos[3]:.2f}", (x_start + 10, y_end - 10),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1)
|
||||
|
||||
# 保存对比图
|
||||
output_path = f"{output_dir}/compare_frame{frame_idx:03d}.png"
|
||||
cv2.imwrite(output_path, compare_img)
|
||||
|
||||
print(f"\n✅ 对比图已保存: {output_path}")
|
||||
print(f" 查看方法: eog {output_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
|
||||
print("="*80)
|
||||
print("特定帧检查工具")
|
||||
print("="*80)
|
||||
|
||||
if len(sys.argv) > 1:
|
||||
frame_idx = int(sys.argv[1])
|
||||
compare_frame_across_episodes(frame_idx=frame_idx, camera='top', num_episodes=10)
|
||||
else:
|
||||
# 默认检查第 0、1、2 帧
|
||||
check_specific_frames(frame_indices=[0, 1, 2], camera='top', num_episodes=10)
|
||||
|
||||
print("\n" + "="*80)
|
||||
print("生成对比图...")
|
||||
print("="*80)
|
||||
|
||||
# 生成第 0 帧的对比图
|
||||
compare_frame_across_episodes(frame_idx=0, camera='top', num_episodes=10)
|
||||
compare_frame_across_episodes(frame_idx=1, camera='top', num_episodes=10)
|
||||
compare_frame_across_episodes(frame_idx=2, camera='top', num_episodes=10)
|
||||
|
||||
print("\n" + "="*80)
|
||||
print("其他用法:")
|
||||
print(" python check_specific_frames.py 0 # 只检查第 0 帧")
|
||||
print(" python check_specific_frames.py 1 # 只检查第 1 帧")
|
||||
print("="*80)
|
||||
324
generate_dataset_videos.py
Normal file
324
generate_dataset_videos.py
Normal file
@@ -0,0 +1,324 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
将 HDF5 数据集转换为视频,用于可视化检查
|
||||
|
||||
功能:
|
||||
1. 将单个 episode 转换为视频
|
||||
2. 对比多个 episode 的视频
|
||||
3. 放慢播放速度便于观察
|
||||
"""
|
||||
import os
|
||||
import h5py
|
||||
import glob
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
|
||||
def episode_to_video(episode_file, output_path, camera='top', fps=30, slow_factor=1):
|
||||
"""
|
||||
将单个 episode 转换为视频
|
||||
|
||||
Args:
|
||||
episode_file: HDF5 文件路径
|
||||
output_path: 输出视频路径
|
||||
camera: 要使用的相机名称
|
||||
fps: 帧率
|
||||
slow_factor: 慢放倍数(1=正常,2=半速)
|
||||
"""
|
||||
try:
|
||||
with h5py.File(episode_file, 'r') as f:
|
||||
# 读取图像序列
|
||||
img_path = f'/observations/images/{camera}'
|
||||
|
||||
if img_path not in f:
|
||||
print(f" ❌ 相机 {camera} 不存在")
|
||||
return False
|
||||
|
||||
images = f[img_path][:] # shape: (T, H, W, C)
|
||||
qpos = f['/observations/qpos'][:]
|
||||
actions = f['/action'][:]
|
||||
|
||||
total_frames = len(images)
|
||||
height, width = images.shape[1], images.shape[2]
|
||||
|
||||
# 创建视频写入器
|
||||
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
|
||||
actual_fps = fps // slow_factor
|
||||
out = cv2.VideoWriter(output_path, fourcc, actual_fps, (width, height))
|
||||
|
||||
# 逐帧写入
|
||||
for i in range(total_frames):
|
||||
frame = images[i].astype(np.uint8)
|
||||
|
||||
# 在图像上添加信息
|
||||
info_text = [
|
||||
f"Episode: {os.path.basename(episode_file).replace('.hdf5', '')}",
|
||||
f"Frame: {i}/{total_frames}",
|
||||
f"qpos[0:3]: [{qpos[i, 0]:.2f}, {qpos[i, 1]:.2f}, {qpos[i, 2]:.2f}]",
|
||||
]
|
||||
|
||||
for j, text in enumerate(info_text):
|
||||
cv2.putText(frame, text, (10, 30 + j*30),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
|
||||
|
||||
out.write(frame)
|
||||
|
||||
out.release()
|
||||
print(f" ✅ 保存: {output_path}")
|
||||
print(f" 帧数: {total_frames}, 尺寸: {width}x{height}, FPS: {actual_fps}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f" ❌ 错误: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def generate_all_videos(camera='top', num_episodes=5, slow_factor=1):
|
||||
"""生成前 N 个 episode 的视频"""
|
||||
|
||||
dataset_dir = "roboimi/demos/dataset/sim_transfer"
|
||||
episode_files = sorted(glob.glob(os.path.join(dataset_dir, "episode_*.hdf5")))
|
||||
|
||||
if len(episode_files) == 0:
|
||||
print(f"❌ 没有找到数据文件: {dataset_dir}")
|
||||
return
|
||||
|
||||
# 创建输出目录
|
||||
output_dir = '/tmp/dataset_videos'
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
print(f"找到 {len(episode_files)} 个 episode 文件")
|
||||
print(f"将生成前 {min(num_episodes, len(episode_files))} 个 episode 的视频\n")
|
||||
|
||||
# 生成视频
|
||||
for i in range(min(num_episodes, len(episode_files))):
|
||||
ep_file = episode_files[i]
|
||||
ep_name = os.path.basename(ep_file).replace('.hdf5', '')
|
||||
output_path = f"{output_dir}/{ep_name}_{camera}.mp4"
|
||||
|
||||
print(f"[{i+1}/{min(num_episodes, len(episode_files))}] {ep_name}")
|
||||
episode_to_video(ep_file, output_path, camera=camera, slow_factor=slow_factor)
|
||||
print()
|
||||
|
||||
print(f"✅ 所有视频已保存到: {output_dir}")
|
||||
print(f"\n播放方法:")
|
||||
print(f" # 播放单个视频")
|
||||
print(f" vlc {output_dir}/*.mp4")
|
||||
print(f" ")
|
||||
print(f" # 或用文件管理器")
|
||||
print(f" nautilus {output_dir}")
|
||||
|
||||
|
||||
def generate_multi_camera_video(episode_idx=0, slow_factor=1):
|
||||
"""生成包含多个相机的视频(分屏显示)"""
|
||||
|
||||
dataset_dir = "roboimi/demos/dataset/sim_transfer"
|
||||
episode_files = sorted(glob.glob(os.path.join(dataset_dir, "episode_*.hdf5")))
|
||||
|
||||
if episode_idx >= len(episode_files):
|
||||
print(f"❌ Episode {episode_idx} 不存在")
|
||||
return
|
||||
|
||||
ep_file = episode_files[episode_idx]
|
||||
|
||||
try:
|
||||
with h5py.File(ep_file, 'r') as f:
|
||||
# 获取所有相机
|
||||
cameras = []
|
||||
for key in f.keys():
|
||||
if 'images' in key:
|
||||
for cam_name in f[key].keys():
|
||||
if cam_name not in cameras:
|
||||
cameras.append(cam_name)
|
||||
|
||||
print(f"Episode {episode_idx} 的相机: {cameras}")
|
||||
|
||||
# 读取所有相机的图像
|
||||
all_images = {}
|
||||
for cam in cameras:
|
||||
img_path = f'/observations/images/{cam}'
|
||||
if img_path in f:
|
||||
all_images[cam] = f[img_path][:]
|
||||
|
||||
if not all_images:
|
||||
print("❌ 没有找到图像数据")
|
||||
return
|
||||
|
||||
# 获取第一个相机的尺寸
|
||||
first_cam = list(all_images.keys())[0]
|
||||
total_frames = len(all_images[first_cam])
|
||||
height, width = all_images[first_cam].shape[1], all_images[first_cam].shape[2]
|
||||
|
||||
# 创建多相机布局
|
||||
num_cams = len(all_images)
|
||||
cols = min(2, num_cams)
|
||||
rows = (num_cams + cols - 1) // cols
|
||||
|
||||
canvas_width = width * cols
|
||||
canvas_height = height * rows
|
||||
|
||||
# 创建视频写入器
|
||||
output_path = f'/tmp/dataset_videos/episode_{episode_idx}_all_cameras.mp4'
|
||||
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
|
||||
out = cv2.VideoWriter(output_path, fourcc, 30 // slow_factor, (canvas_width, canvas_height))
|
||||
|
||||
# 逐帧合成
|
||||
for i in range(total_frames):
|
||||
canvas = np.zeros((canvas_height, canvas_width, 3), dtype=np.uint8)
|
||||
|
||||
for cam_idx, cam_name in enumerate(all_images.keys()):
|
||||
img = all_images[cam_name][i]
|
||||
|
||||
# 计算在画布上的位置
|
||||
row = cam_idx // cols
|
||||
col = cam_idx % cols
|
||||
y_start = row * height
|
||||
y_end = y_start + height
|
||||
x_start = col * width
|
||||
x_end = x_start + width
|
||||
|
||||
# 调整大小(如果需要)
|
||||
if img.shape[:2] != (height, width):
|
||||
img = cv2.resize(img, (width, height))
|
||||
|
||||
# 放到画布上
|
||||
canvas[y_start:y_end, x_start:x_end] = img
|
||||
|
||||
# 添加相机名称
|
||||
cv2.putText(canvas, cam_name, (x_start + 10, y_start + 30),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 255), 2)
|
||||
|
||||
# 添加帧信息
|
||||
cv2.putText(canvas, f"Frame: {i}/{total_frames}", (10, canvas_height - 10),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
|
||||
|
||||
out.write(canvas)
|
||||
|
||||
out.release()
|
||||
print(f"✅ 保存多相机视频: {output_path}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ 错误: {e}")
|
||||
|
||||
|
||||
def compare_episodes(camera='top', slow_factor=2):
|
||||
"""并排对比多个 episode 的视频"""
|
||||
|
||||
dataset_dir = "roboimi/demos/dataset/sim_transfer"
|
||||
episode_files = sorted(glob.glob(os.path.join(dataset_dir, "episode_*.hdf5")))
|
||||
|
||||
# 选择要对比的 episode
|
||||
episodes_to_compare = [0, 1, 2, 3, 4] # 对比前 5 个
|
||||
|
||||
print(f"对比 Episodes: {episodes_to_compare}")
|
||||
|
||||
# 读取所有 episode 的数据
|
||||
all_data = []
|
||||
for ep_idx in episodes_to_compare:
|
||||
if ep_idx >= len(episode_files):
|
||||
continue
|
||||
|
||||
try:
|
||||
with h5py.File(episode_files[ep_idx], 'r') as f:
|
||||
img_path = f'/observations/images/{camera}'
|
||||
if img_path in f:
|
||||
all_data.append({
|
||||
'idx': ep_idx,
|
||||
'images': f[img_path][:],
|
||||
'qpos': f['/observations/qpos'][:]
|
||||
})
|
||||
except:
|
||||
pass
|
||||
|
||||
if len(all_data) == 0:
|
||||
print("❌ 没有数据")
|
||||
return
|
||||
|
||||
# 获取参数
|
||||
first_data = all_data[0]
|
||||
height, width = first_data['images'].shape[1], first_data['images'].shape[2]
|
||||
total_frames = min([d['images'].shape[0] for d in all_data])
|
||||
|
||||
# 创建并排布局
|
||||
num_compare = len(all_data)
|
||||
canvas_width = width * num_compare
|
||||
canvas_height = height
|
||||
|
||||
# 创建视频
|
||||
output_path = f'/tmp/dataset_videos/compare_{camera}.mp4'
|
||||
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
|
||||
out = cv2.VideoWriter(output_path, fourcc, 30 // slow_factor, (canvas_width, canvas_height))
|
||||
|
||||
print(f"生成对比视频,共 {total_frames} 帧...")
|
||||
|
||||
# 逐帧对比
|
||||
for i in range(total_frames):
|
||||
canvas = np.zeros((canvas_height, canvas_width, 3), dtype=np.uint8)
|
||||
|
||||
for j, data in enumerate(all_data):
|
||||
img = data['images'][i]
|
||||
qpos = data['qpos'][i]
|
||||
|
||||
# 调整大小(如果需要)
|
||||
if img.shape[:2] != (height, width):
|
||||
img = cv2.resize(img, (width, height))
|
||||
|
||||
# 放到画布上
|
||||
x_start = j * width
|
||||
x_end = x_start + width
|
||||
canvas[:, x_start:x_end] = img
|
||||
|
||||
# 添加信息
|
||||
ep_name = f"Ep {data['idx']}"
|
||||
cv2.putText(canvas, ep_name, (x_start + 10, 30),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 255), 2)
|
||||
cv2.putText(canvas, f"qpos[0:3]: [{qpos[0]:.2f}, {qpos[1]:.2f}, {qpos[2]:.2f}]",
|
||||
(x_start + 10, height - 10),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1)
|
||||
|
||||
# 添加帧号
|
||||
cv2.putText(canvas, f"Frame: {i}/{total_frames}", (10, canvas_height - 30),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
|
||||
|
||||
out.write(canvas)
|
||||
|
||||
if i % 100 == 0:
|
||||
print(f" 进度: {i}/{total_frames}")
|
||||
|
||||
out.release()
|
||||
print(f"✅ 保存对比视频: {output_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
|
||||
print("="*60)
|
||||
print("数据集视频生成工具")
|
||||
print("="*60)
|
||||
|
||||
if len(sys.argv) > 1:
|
||||
command = sys.argv[1]
|
||||
|
||||
if command == 'compare':
|
||||
# 对比多个 episode
|
||||
camera = sys.argv[2] if len(sys.argv) > 2 else 'top'
|
||||
compare_episodes(camera=camera, slow_factor=2)
|
||||
|
||||
elif command == 'multi':
|
||||
# 多相机视频
|
||||
ep_idx = int(sys.argv[2]) if len(sys.argv) > 2 else 0
|
||||
generate_multi_camera_video(episode_idx=ep_idx, slow_factor=1)
|
||||
|
||||
else:
|
||||
print("未知命令")
|
||||
else:
|
||||
# 默认:生成前 5 个 episode 的视频
|
||||
print("\n生成前 5 个 episode 的视频(top 相机,慢放 2x)...")
|
||||
print("="*60 + "\n")
|
||||
generate_all_videos(camera='top', num_episodes=5, slow_factor=2)
|
||||
|
||||
print("\n" + "="*60)
|
||||
print("其他用法:")
|
||||
print(" python generate_dataset_videos.py compare top # 对比多个 episode")
|
||||
print(" python generate_dataset_videos.py multi 0 # 多相机视频")
|
||||
print("="*60)
|
||||
Reference in New Issue
Block a user