87 lines
2.5 KiB
Python
87 lines
2.5 KiB
Python
import h5py
|
||
import numpy as np
|
||
import os
|
||
import glob
|
||
import pickle
|
||
|
||
def get_data_stats(dataset_dir, eps=1e-8):
    """Compute min, max, mean and std of ``action`` and ``qpos`` over a dataset.

    Every file matching ``episode_*.hdf5`` directly inside *dataset_dir* is
    opened read-only. The ``action`` dataset and the ``observations/qpos``
    dataset of each episode are loaded, concatenated across episodes along
    axis 0, and reduced along that axis.

    The result is a flat dict (the layout NormalizationModule is expected to
    consume):

        {
            'action_mean': [...],
            'action_std': [...],
            'action_min': [...],
            'action_max': [...],
            'qpos_mean': [...],
            'qpos_std': [...],
            'qpos_min': [...],
            'qpos_max': [...],
        }

    Args:
        dataset_dir: Directory that contains the ``episode_*.hdf5`` files.
        eps: Floor applied to the returned std values; any std below it is
            replaced by ``eps`` so that dividing by the std is safe.

    Returns:
        dict mapping the eight keys above to NumPy arrays. The ``*_std``
        entries are the floored values, not the raw standard deviations.

    Raises:
        ValueError: If no ``episode_*.hdf5`` file is found in *dataset_dir*.
    """
    pattern = os.path.join(dataset_dir, 'episode_*.hdf5')
    # Sorted so episodes are always read in the same order.
    files = sorted(glob.glob(pattern))
    print(f"Found {len(files)} episodes in {dataset_dir}")

    # Fail early with an actionable message; otherwise np.concatenate below
    # raises a ValueError that does not mention the directory at all.
    if not files:
        raise ValueError(
            f"No files matching 'episode_*.hdf5' found in {dataset_dir}"
        )

    all_actions = []
    all_qpos = []

    print("Reading data...")
    for file_path in files:
        with h5py.File(file_path, 'r') as f:
            # [:] copies the dataset into memory before the file is closed.
            all_actions.append(f['action'][:])
            all_qpos.append(f['observations']['qpos'][:])

    # Stack all episodes along the first axis.
    all_actions = np.concatenate(all_actions, axis=0)
    all_qpos = np.concatenate(all_qpos, axis=0)

    print(f"Total steps: {all_actions.shape[0]}")

    # --- Core computation ---
    # Keys are inserted as mean, std, min, max for 'action' and then 'qpos'.
    stats_flat = {}
    for name, data in (('action', all_actions), ('qpos', all_qpos)):
        std = np.std(data, axis=0)
        stats_flat[f'{name}_mean'] = np.mean(data, axis=0)
        # Floor the std to avoid division by zero during normalization.
        stats_flat[f'{name}_std'] = np.where(std < eps, eps, std)
        stats_flat[f'{name}_min'] = np.min(data, axis=0)
        stats_flat[f'{name}_max'] = np.max(data, axis=0)
    return stats_flat
|
||
|
||
if __name__ == "__main__":
    # Script entry point: compute the statistics for one dataset directory
    # and pickle them next to the episode files.
    DATASET_DIR = 'roboimi/demos/dataset/sim_transfer'
    # Build the path with os.path.join instead of string concatenation.
    OUTPUT_PATH = os.path.join(DATASET_DIR, "dataset_stats.pkl")

    stats_flat = get_data_stats(DATASET_DIR)

    # Print the shapes as a quick sanity check.
    print("\n--- Stats Computed ---")
    print(f"Action Mean shape: {stats_flat['action_mean'].shape}")
    print(f"Action Std shape: {stats_flat['action_std'].shape}")

    # Save the flat stats dict as a pickle file.
    with open(OUTPUT_PATH, 'wb') as f:
        pickle.dump(stats_flat, f)
    print(f"\nStats saved to {OUTPUT_PATH}")