diff --git a/roboimi/vla/scripts/calculate_stats.py b/roboimi/vla/scripts/calculate_stats.py new file mode 100644 index 0000000..8fd5e9d --- /dev/null +++ b/roboimi/vla/scripts/calculate_stats.py @@ -0,0 +1,72 @@ +import h5py +import numpy as np +import os +import glob +import pickle + +def get_data_stats(dataset_dir): + """ + 计算 action 和 qpos 的 Min, Max, Mean, Std + """ + files = sorted(glob.glob(os.path.join(dataset_dir, 'episode_*.hdf5'))) + print(f"Found {len(files)} episodes in {dataset_dir}") + + all_actions = [] + all_qpos = [] + + print("Reading data...") + for file_path in files: + with h5py.File(file_path, 'r') as f: + action = f['action'][:] + qpos = f['observations']['qpos'][:] + all_actions.append(action) + all_qpos.append(qpos) + + # 拼接所有数据 + all_actions = np.concatenate(all_actions, axis=0) + all_qpos = np.concatenate(all_qpos, axis=0) + + print(f"Total steps: {all_actions.shape[0]}") + + # --- 核心计算部分 --- + stats = { + 'action': { + 'min': np.min(all_actions, axis=0), + 'max': np.max(all_actions, axis=0), + 'mean': np.mean(all_actions, axis=0), # 均值 + 'std': np.std(all_actions, axis=0) # 标准差 + }, + 'qpos': { + 'min': np.min(all_qpos, axis=0), + 'max': np.max(all_qpos, axis=0), + 'mean': np.mean(all_qpos, axis=0), # 均值 + 'std': np.std(all_qpos, axis=0) # 标准差 + } + } + + # --- 修正标准差 (防止除以 0) --- + # 如果某个关节从未移动(例如备用按钮),std 会是 0,导致除零错误。 + # 策略:将 std 为 0 的地方替换为 1.0 (不缩放) 或一个小的 epsilon + for key in stats: + # 找到 std 极小的维度 + std = stats[key]['std'] + std = np.where(std < 1e-8, 1.0, std) # 如果 std 太小,设为 1.0 避免除零 + stats[key]['std'] = std + + return stats + +if __name__ == "__main__": + DATASET_DIR = 'roboimi/demos/dataset/sim_transfer' + OUTPUT_PATH = DATASET_DIR + "/data_stats.pkl" + + stats = get_data_stats(DATASET_DIR) + + # 打印检查 + print("\n--- Stats Computed ---") + print(f"Action Mean shape: {stats['action']['mean'].shape}") + print(f"Action Std shape: {stats['action']['std'].shape}") + + # 保存 + with open(OUTPUT_PATH, 'wb') as f: + pickle.dump(stats, f) + print(f"\nStats saved to {OUTPUT_PATH}") \ No newline at end of file