Files
roboimi/roboimi/vla/scripts/calculate_stats.py
2026-02-12 12:23:34 +08:00

87 lines
2.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import h5py
import numpy as np
import os
import glob
import pickle
def get_data_stats(dataset_dir):
    """
    Compute min, max, mean and std of ``action`` and ``qpos`` over a dataset.

    Every ``episode_*.hdf5`` file directly inside ``dataset_dir`` is opened
    read-only. The ``action`` and ``observations/qpos`` datasets of all
    episodes are concatenated along axis 0, and each statistic is reduced
    over that axis.

    The result is a flat dict (the layout the original author notes that
    NormalizationModule expects):
        {
            'action_mean': [...],
            'action_std': [...],
            'action_min': [...],
            'action_max': [...],
            'qpos_mean': [...],
            'qpos_std': [...],
            'qpos_min': [...],
            'qpos_max': [...],
        }
    Std entries smaller than 1e-8 are replaced by 1e-8 so that a later
    division by std cannot divide by zero.

    Args:
        dataset_dir: Directory containing the ``episode_*.hdf5`` files.

    Returns:
        dict mapping each of the eight keys above to a NumPy array.

    Raises:
        ValueError: If no ``episode_*.hdf5`` file is found in ``dataset_dir``.
    """
    files = sorted(glob.glob(os.path.join(dataset_dir, 'episode_*.hdf5')))
    print(f"Found {len(files)} episodes in {dataset_dir}")
    if not files:
        # Without this guard np.concatenate([]) fails further down with the
        # opaque "need at least one array to concatenate".
        raise ValueError(
            f"No episode_*.hdf5 files found in {dataset_dir!r}; "
            "cannot compute dataset statistics"
        )

    all_actions = []
    all_qpos = []
    print("Reading data...")
    for file_path in files:
        with h5py.File(file_path, 'r') as f:
            # [:] copies the dataset into memory before the file is closed.
            all_actions.append(f['action'][:])
            all_qpos.append(f['observations']['qpos'][:])

    # Stack all episodes along the step axis.
    all_actions = np.concatenate(all_actions, axis=0)
    all_qpos = np.concatenate(all_qpos, axis=0)
    print(f"Total steps: {all_actions.shape[0]}")

    # --- Core computation ---
    eps = 1e-8  # floor for std, guards against division by zero
    stats_flat = {}
    for name, data in (('action', all_actions), ('qpos', all_qpos)):
        std = np.std(data, axis=0)
        stats_flat[f'{name}_mean'] = np.mean(data, axis=0)
        stats_flat[f'{name}_std'] = np.where(std < eps, eps, std)
        stats_flat[f'{name}_min'] = np.min(data, axis=0)
        stats_flat[f'{name}_max'] = np.max(data, axis=0)
    return stats_flat
if __name__ == "__main__":
    # Script entry point: compute the statistics for the hard-coded dataset
    # directory and pickle them next to the episode files.
    DATASET_DIR = 'roboimi/demos/dataset/sim_transfer'
    OUTPUT_PATH = f"{DATASET_DIR}/dataset_stats.pkl"

    stats_flat = get_data_stats(DATASET_DIR)

    # Quick sanity check of the computed arrays.
    print("\n--- Stats Computed ---")
    mean_shape = stats_flat['action_mean'].shape
    print(f"Action Mean shape: {mean_shape}")
    std_shape = stats_flat['action_std'].shape
    print(f"Action Std shape: {std_shape}")

    # Persist the flat stats dict.
    with open(OUTPUT_PATH, 'wb') as out_file:
        pickle.dump(stats_flat, out_file)

    print(f"\nStats saved to {OUTPUT_PATH}")