feat(dataset): 添加统计数据计算脚本

This commit is contained in:
gouhanke
2026-02-04 21:53:48 +08:00
parent 03f10b0c22
commit 92660562fb

View File

@@ -0,0 +1,72 @@
import h5py
import numpy as np
import os
import glob
import pickle
def get_data_stats(dataset_dir):
"""
计算 action 和 qpos 的 Min, Max, Mean, Std
"""
files = sorted(glob.glob(os.path.join(dataset_dir, 'episode_*.hdf5')))
print(f"Found {len(files)} episodes in {dataset_dir}")
all_actions = []
all_qpos = []
print("Reading data...")
for file_path in files:
with h5py.File(file_path, 'r') as f:
action = f['action'][:]
qpos = f['observations']['qpos'][:]
all_actions.append(action)
all_qpos.append(qpos)
# 拼接所有数据
all_actions = np.concatenate(all_actions, axis=0)
all_qpos = np.concatenate(all_qpos, axis=0)
print(f"Total steps: {all_actions.shape[0]}")
# --- 核心计算部分 ---
stats = {
'action': {
'min': np.min(all_actions, axis=0),
'max': np.max(all_actions, axis=0),
'mean': np.mean(all_actions, axis=0), # 均值
'std': np.std(all_actions, axis=0) # 标准差
},
'qpos': {
'min': np.min(all_qpos, axis=0),
'max': np.max(all_qpos, axis=0),
'mean': np.mean(all_qpos, axis=0), # 均值
'std': np.std(all_qpos, axis=0) # 标准差
}
}
# --- 修正标准差 (防止除以 0) ---
# 如果某个关节从未移动例如备用按钮std 会是 0导致除零错误。
# 策略:将 std 为 0 的地方替换为 1.0 (不缩放) 或一个小的 epsilon
for key in stats:
# 找到 std 极小的维度
std = stats[key]['std']
std = np.where(std < 1e-8, 1.0, std) # 如果 std 太小,设为 1.0 避免除零
stats[key]['std'] = std
return stats
if __name__ == "__main__":
DATASET_DIR = 'roboimi/demos/dataset/sim_transfer'
OUTPUT_PATH = DATASET_DIR + "/data_stats.pkl"
stats = get_data_stats(DATASET_DIR)
# 打印检查
print("\n--- Stats Computed ---")
print(f"Action Mean shape: {stats['action']['mean'].shape}")
print(f"Action Std shape: {stats['action']['std'].shape}")
# 保存
with open(OUTPUT_PATH, 'wb') as f:
pickle.dump(stats, f)
print(f"\nStats saved to {OUTPUT_PATH}")