"""Compute normalization statistics for a directory of HDF5 episodes.

Reconstructed from the patch that adds
``roboimi/vla/scripts/calculate_stats.py`` (commit subject:
"feat(dataset): add statistics computation script").

The script reads ``action`` and ``observations/qpos`` from every
``episode_*.hdf5`` file in a dataset directory, computes min / max / mean /
std along axis 0 of the concatenated data, and pickles the result next to
the dataset.
"""

import argparse
import glob
import os
import pickle

import numpy as np

DEFAULT_DATASET_DIR = 'roboimi/demos/dataset/sim_transfer'
STATS_FILENAME = 'data_stats.pkl'

# Standard deviations below this threshold are treated as zero.
_STD_FLOOR = 1e-8


def _load_episode(file_path):
    """Return the ``(action, qpos)`` arrays stored in one episode file."""
    # Imported lazily so that importing this module, and the empty-dataset
    # check in get_data_stats, do not require h5py to be installed.
    import h5py

    with h5py.File(file_path, 'r') as f:
        return f['action'][:], f['observations']['qpos'][:]


def _summarize(samples):
    """Return min / max / mean / std of *samples*, reduced over axis 0."""
    std = np.std(samples, axis=0)
    return {
        'min': np.min(samples, axis=0),
        'max': np.max(samples, axis=0),
        'mean': np.mean(samples, axis=0),
        # A dimension that never changes has std == 0, which would cause a
        # division by zero when normalizing. Replace such entries with 1.0
        # so that dividing by std leaves the dimension unscaled.
        'std': np.where(std < _STD_FLOOR, 1.0, std),
    }


def get_data_stats(dataset_dir):
    """Compute min, max, mean and std of ``action`` and ``qpos``.

    Every file matching ``episode_*.hdf5`` directly inside *dataset_dir* is
    read; the ``action`` and ``observations/qpos`` datasets of all episodes
    are concatenated along axis 0 and reduced along that axis.

    Args:
        dataset_dir: Directory containing the ``episode_*.hdf5`` files.

    Returns:
        A dict ``{'action': {...}, 'qpos': {...}}`` where each inner dict has
        the keys ``'min'``, ``'max'``, ``'mean'`` and ``'std'``. Entries of
        ``'std'`` smaller than 1e-8 are replaced by 1.0.

    Raises:
        ValueError: If no episode file is found in *dataset_dir*.
    """
    pattern = os.path.join(dataset_dir, 'episode_*.hdf5')
    files = sorted(glob.glob(pattern))
    print(f"Found {len(files)} episodes in {dataset_dir}")

    # Fail early with an actionable message instead of letting
    # np.concatenate complain about an empty list further down.
    if not files:
        raise ValueError(
            f"No files matching {pattern!r}; cannot compute statistics"
        )

    all_actions = []
    all_qpos = []

    print("Reading data...")
    for file_path in files:
        action, qpos = _load_episode(file_path)
        all_actions.append(action)
        all_qpos.append(qpos)

    # Stack all episodes into one array per signal.
    all_actions = np.concatenate(all_actions, axis=0)
    all_qpos = np.concatenate(all_qpos, axis=0)

    print(f"Total steps: {all_actions.shape[0]}")

    return {
        'action': _summarize(all_actions),
        'qpos': _summarize(all_qpos),
    }


def main(argv=None):
    """Compute the statistics and write them to a pickle file.

    Args:
        argv: Command-line arguments (defaults to ``sys.argv[1:]``). With no
            arguments the original hard-coded dataset and output paths are
            used.
    """
    parser = argparse.ArgumentParser(
        description="Compute action/qpos statistics for an HDF5 dataset."
    )
    parser.add_argument(
        '--dataset-dir',
        default=DEFAULT_DATASET_DIR,
        help="directory containing episode_*.hdf5 files",
    )
    parser.add_argument(
        '--output',
        default=None,
        help=f"output pickle path (default: <dataset-dir>/{STATS_FILENAME})",
    )
    args = parser.parse_args(argv)

    output_path = args.output or os.path.join(args.dataset_dir, STATS_FILENAME)

    stats = get_data_stats(args.dataset_dir)

    # Print shapes as a quick sanity check.
    print("\n--- Stats Computed ---")
    print(f"Action Mean shape: {stats['action']['mean'].shape}")
    print(f"Action Std shape: {stats['action']['std'].shape}")

    # Persist the statistics.
    with open(output_path, 'wb') as f:
        pickle.dump(stats, f)
    print(f"\nStats saved to {output_path}")


if __name__ == "__main__":
    main()