feat(dataset): 添加统计数据计算脚本
This commit is contained in:
72
roboimi/vla/scripts/calculate_stats.py
Normal file
72
roboimi/vla/scripts/calculate_stats.py
Normal file
@@ -0,0 +1,72 @@
|
|||||||
|
import h5py
|
||||||
|
import numpy as np
|
||||||
|
import os
|
||||||
|
import glob
|
||||||
|
import pickle
|
||||||
|
|
||||||
|
def get_data_stats(dataset_dir):
|
||||||
|
"""
|
||||||
|
计算 action 和 qpos 的 Min, Max, Mean, Std
|
||||||
|
"""
|
||||||
|
files = sorted(glob.glob(os.path.join(dataset_dir, 'episode_*.hdf5')))
|
||||||
|
print(f"Found {len(files)} episodes in {dataset_dir}")
|
||||||
|
|
||||||
|
all_actions = []
|
||||||
|
all_qpos = []
|
||||||
|
|
||||||
|
print("Reading data...")
|
||||||
|
for file_path in files:
|
||||||
|
with h5py.File(file_path, 'r') as f:
|
||||||
|
action = f['action'][:]
|
||||||
|
qpos = f['observations']['qpos'][:]
|
||||||
|
all_actions.append(action)
|
||||||
|
all_qpos.append(qpos)
|
||||||
|
|
||||||
|
# 拼接所有数据
|
||||||
|
all_actions = np.concatenate(all_actions, axis=0)
|
||||||
|
all_qpos = np.concatenate(all_qpos, axis=0)
|
||||||
|
|
||||||
|
print(f"Total steps: {all_actions.shape[0]}")
|
||||||
|
|
||||||
|
# --- 核心计算部分 ---
|
||||||
|
stats = {
|
||||||
|
'action': {
|
||||||
|
'min': np.min(all_actions, axis=0),
|
||||||
|
'max': np.max(all_actions, axis=0),
|
||||||
|
'mean': np.mean(all_actions, axis=0), # 均值
|
||||||
|
'std': np.std(all_actions, axis=0) # 标准差
|
||||||
|
},
|
||||||
|
'qpos': {
|
||||||
|
'min': np.min(all_qpos, axis=0),
|
||||||
|
'max': np.max(all_qpos, axis=0),
|
||||||
|
'mean': np.mean(all_qpos, axis=0), # 均值
|
||||||
|
'std': np.std(all_qpos, axis=0) # 标准差
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
# --- 修正标准差 (防止除以 0) ---
|
||||||
|
# 如果某个关节从未移动(例如备用按钮),std 会是 0,导致除零错误。
|
||||||
|
# 策略:将 std 为 0 的地方替换为 1.0 (不缩放) 或一个小的 epsilon
|
||||||
|
for key in stats:
|
||||||
|
# 找到 std 极小的维度
|
||||||
|
std = stats[key]['std']
|
||||||
|
std = np.where(std < 1e-8, 1.0, std) # 如果 std 太小,设为 1.0 避免除零
|
||||||
|
stats[key]['std'] = std
|
||||||
|
|
||||||
|
return stats
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
DATASET_DIR = 'roboimi/demos/dataset/sim_transfer'
|
||||||
|
OUTPUT_PATH = DATASET_DIR + "/data_stats.pkl"
|
||||||
|
|
||||||
|
stats = get_data_stats(DATASET_DIR)
|
||||||
|
|
||||||
|
# 打印检查
|
||||||
|
print("\n--- Stats Computed ---")
|
||||||
|
print(f"Action Mean shape: {stats['action']['mean'].shape}")
|
||||||
|
print(f"Action Std shape: {stats['action']['std'].shape}")
|
||||||
|
|
||||||
|
# 保存
|
||||||
|
with open(OUTPUT_PATH, 'wb') as f:
|
||||||
|
pickle.dump(stats, f)
|
||||||
|
print(f"\nStats saved to {OUTPUT_PATH}")
|
||||||
Reference in New Issue
Block a user