diff --git a/roboimi/demos/vla_scripts/train_vla.py b/roboimi/demos/vla_scripts/train_vla.py index f5fbcb1..358cb5e 100644 --- a/roboimi/demos/vla_scripts/train_vla.py +++ b/roboimi/demos/vla_scripts/train_vla.py @@ -164,24 +164,24 @@ def main(cfg: DictConfig): dataset_stats = None try: dataset_dir = cfg.data.get('dataset_dir', 'roboimi/demos/dataset/sim_transfer') - stats_path = Path(dataset_dir) / 'data_stats.pkl' + stats_path = Path(dataset_dir) / 'dataset_stats.pkl' if stats_path.exists(): with open(stats_path, 'rb') as f: stats = pickle.load(f) + # 扁平化stats字典(嵌套结构→扁平结构)以匹配NormalizationModule的期望格式 dataset_stats = { - 'normalization_type': cfg.data.get('normalization_type', 'gaussian'), - 'action_mean': stats['action']['mean'].tolist(), - 'action_std': stats['action']['std'].tolist(), - 'action_min': stats['action']['min'].tolist(), - 'action_max': stats['action']['max'].tolist(), - 'qpos_mean': stats['qpos']['mean'].tolist(), - 'qpos_std': stats['qpos']['std'].tolist(), - 'qpos_min': stats['qpos']['min'].tolist(), - 'qpos_max': stats['qpos']['max'].tolist(), + 'action_mean': stats['action_mean'].tolist(), + 'action_std': stats['action_std'].tolist(), + 'action_min': stats['action_min'].tolist(), + 'action_max': stats['action_max'].tolist(), + 'qpos_mean': stats['qpos_mean'].tolist(), + 'qpos_std': stats['qpos_std'].tolist(), + 'qpos_min': stats['qpos_min'].tolist(), + 'qpos_max': stats['qpos_max'].tolist(), } - log.info(f"✅ 数据集统计信息加载完成 (归一化: {dataset_stats['normalization_type']})") + log.info(f"✅ 数据集统计信息加载完成 (归一化: {cfg.agent.normalization_type})") else: log.warning(f"⚠️ 统计文件未找到: {stats_path}") log.warning("⚠️ 推理时动作将无法反归一化!") diff --git a/roboimi/vla/conf/agent/resnet_diffusion.yaml b/roboimi/vla/conf/agent/resnet_diffusion.yaml index e079f52..3574f96 100644 --- a/roboimi/vla/conf/agent/resnet_diffusion.yaml +++ b/roboimi/vla/conf/agent/resnet_diffusion.yaml @@ -15,6 +15,11 @@ _target_: roboimi.vla.agent.VLAAgent action_dim: 16 # 动作维度(机器人关节数) obs_dim: 16 # 本体感知维度(关节位置) +# ==================== +# +# ==================== +normalization_type: "min_max" # "min_max" or "gaussian" + # ==================== # 时间步配置 # ==================== diff --git a/roboimi/vla/models/normalization.py b/roboimi/vla/models/normalization.py index 8cfbce7..cb5adef 100644 --- a/roboimi/vla/models/normalization.py +++ b/roboimi/vla/models/normalization.py @@ -14,20 +14,18 @@ from typing import Optional, Dict, Literal class NormalizationModule(nn.Module): """ 统一的归一化模块 - 用于在 Agent 内部对 qpos 和 action 进行归一化/反归一化 """ def __init__( self, stats: Optional[Dict] = None, - normalization_type: Literal['gaussian', 'min_max'] = 'min_max' + normalization_type: Literal['gaussian', 'min_max'] = None, ): """ Args: stats: 数据集统计信息字典,格式: { - 'normalization_type': 'gaussian' | 'min_max', 'qpos_mean': [...], 'qpos_std': [...], 'qpos_min': [...], # 仅 min_max 需要 @@ -45,26 +43,17 @@ class NormalizationModule(nn.Module): self.enabled = stats is not None if self.enabled: - # 从 stats 中读取归一化类型(如果提供) - self.normalization_type = stats.get('normalization_type', normalization_type) + if self.normalization_type == 'gaussian': + self.register_buffer('qpos_mean', torch.tensor(stats['qpos_mean'], dtype=torch.float32)) + self.register_buffer('qpos_std', torch.tensor(stats['qpos_std'], dtype=torch.float32)) + self.register_buffer('action_mean', torch.tensor(stats['action_mean'], dtype=torch.float32)) + self.register_buffer('action_std', torch.tensor(stats['action_std'], dtype=torch.float32)) - # 注册为 buffer (不会被优化,但会随模型保存) - self.register_buffer('qpos_mean', torch.tensor(stats['qpos_mean'], dtype=torch.float32)) - self.register_buffer('qpos_std', torch.tensor(stats['qpos_std'], dtype=torch.float32)) - self.register_buffer('action_mean', torch.tensor(stats['action_mean'], dtype=torch.float32)) - self.register_buffer('action_std', torch.tensor(stats['action_std'], dtype=torch.float32)) - - # MinMax 归一化需要 min/max - if self.normalization_type == 'min_max': - qpos_min = stats.get('qpos_min', [0.0] * len(stats['qpos_mean'])) - qpos_max = stats.get('qpos_max', [1.0] * len(stats['qpos_mean'])) - action_min = stats.get('action_min', [0.0] * len(stats['action_mean'])) - action_max = stats.get('action_max', [1.0] * len(stats['action_mean'])) - - self.register_buffer('qpos_min', torch.tensor(qpos_min, dtype=torch.float32)) - self.register_buffer('qpos_max', torch.tensor(qpos_max, dtype=torch.float32)) - self.register_buffer('action_min', torch.tensor(action_min, dtype=torch.float32)) - self.register_buffer('action_max', torch.tensor(action_max, dtype=torch.float32)) + elif self.normalization_type == 'min_max': + self.register_buffer('qpos_min', torch.tensor(stats['qpos_min'], dtype=torch.float32)) + self.register_buffer('qpos_max', torch.tensor(stats['qpos_max'], dtype=torch.float32)) + self.register_buffer('action_min', torch.tensor(stats['action_min'], dtype=torch.float32)) + self.register_buffer('action_max', torch.tensor(stats['action_max'], dtype=torch.float32)) def normalize_qpos(self, qpos: torch.Tensor) -> torch.Tensor: """归一化 qpos""" @@ -73,8 +62,10 @@ class NormalizationModule(nn.Module): if self.normalization_type == 'gaussian': return (qpos - self.qpos_mean) / self.qpos_std - else: # min_max + elif self.normalization_type == 'min_max': return 2 * (qpos - self.qpos_min) / (self.qpos_max - self.qpos_min) - 1 + else: + raise ValueError(f"Unknown normalization type: {self.normalization_type}") def denormalize_qpos(self, qpos: torch.Tensor) -> torch.Tensor: """反归一化 qpos""" @@ -83,8 +74,10 @@ class NormalizationModule(nn.Module): if self.normalization_type == 'gaussian': return qpos * self.qpos_std + self.qpos_mean - else: # min_max + elif self.normalization_type == 'min_max': return (qpos + 1) / 2 * (self.qpos_max - self.qpos_min) + self.qpos_min + else: + raise ValueError(f"Unknown normalization type: {self.normalization_type}") def normalize_action(self, action: torch.Tensor) -> torch.Tensor: """归一化 action""" @@ -93,8 +86,10 @@ class NormalizationModule(nn.Module): if self.normalization_type == 'gaussian': return (action - self.action_mean) / self.action_std - else: # min_max + elif self.normalization_type == 'min_max': return 2 * (action - self.action_min) / (self.action_max - self.action_min) - 1 + else: + raise ValueError(f"Unknown normalization type: {self.normalization_type}") def denormalize_action(self, action: torch.Tensor) -> torch.Tensor: """反归一化 action""" @@ -103,8 +98,10 @@ class NormalizationModule(nn.Module): if self.normalization_type == 'gaussian': return action * self.action_std + self.action_mean - else: # min_max + elif self.normalization_type == 'min_max': return (action + 1) / 2 * (self.action_max - self.action_min) + self.action_min + else: + raise ValueError(f"Unknown normalization type: {self.normalization_type}") def get_stats(self) -> Optional[Dict]: """导出统计信息(用于保存到 checkpoint)""" @@ -113,13 +110,14 @@ class NormalizationModule(nn.Module): stats = { 'normalization_type': self.normalization_type, - 'qpos_mean': self.qpos_mean.cpu().tolist(), - 'qpos_std': self.qpos_std.cpu().tolist(), - 'action_mean': self.action_mean.cpu().tolist(), - 'action_std': self.action_std.cpu().tolist(), } - if self.normalization_type == 'min_max': + if self.normalization_type == 'gaussian': + stats['qpos_mean'] = self.qpos_mean.cpu().tolist() + stats['qpos_std'] = self.qpos_std.cpu().tolist() + stats['action_mean'] = self.action_mean.cpu().tolist() + stats['action_std'] = self.action_std.cpu().tolist() + elif self.normalization_type == 'min_max': stats['qpos_min'] = self.qpos_min.cpu().tolist() stats['qpos_max'] = self.qpos_max.cpu().tolist() stats['action_min'] = self.action_min.cpu().tolist() diff --git a/roboimi/vla/scripts/calculate_stats.py b/roboimi/vla/scripts/calculate_stats.py index 8fd5e9d..5fece0e 100644 --- a/roboimi/vla/scripts/calculate_stats.py +++ b/roboimi/vla/scripts/calculate_stats.py @@ -7,6 +7,18 @@ import pickle def get_data_stats(dataset_dir): """ 计算 action 和 qpos 的 Min, Max, Mean, Std + + 输出扁平化结构(与 NormalizationModule 期望一致): + { + 'action_mean': [...], + 'action_std': [...], + 'action_min': [...], + 'action_max': [...], + 'qpos_mean': [...], + 'qpos_std': [...], + 'qpos_min': [...], + 'qpos_max': [...], + } """ files = sorted(glob.glob(os.path.join(dataset_dir, 'episode_*.hdf5'))) print(f"Found {len(files)} episodes in {dataset_dir}") @@ -17,8 +29,8 @@ def get_data_stats(dataset_dir): print("Reading data...") for file_path in files: with h5py.File(file_path, 'r') as f: - action = f['action'][:] - qpos = f['observations']['qpos'][:] + action = f['action'][:] + qpos = f['observations']['qpos'][:] all_actions.append(action) all_qpos.append(qpos) @@ -29,44 +41,47 @@ def get_data_stats(dataset_dir): print(f"Total steps: {all_actions.shape[0]}") # --- 核心计算部分 --- - stats = { - 'action': { - 'min': np.min(all_actions, axis=0), - 'max': np.max(all_actions, axis=0), - 'mean': np.mean(all_actions, axis=0), # 均值 - 'std': np.std(all_actions, axis=0) # 标准差 - }, - 'qpos': { - 'min': np.min(all_qpos, axis=0), - 'max': np.max(all_qpos, axis=0), - 'mean': np.mean(all_qpos, axis=0), # 均值 - 'std': np.std(all_qpos, axis=0) # 标准差 - } + # 计算统计量 + action_mean = np.mean(all_actions, axis=0) + action_std = np.std(all_actions, axis=0) + action_min = np.min(all_actions, axis=0) + action_max = np.max(all_actions, axis=0) + + qpos_mean = np.mean(all_qpos, axis=0) + qpos_std = np.std(all_qpos, axis=0) + qpos_min = np.min(all_qpos, axis=0) + qpos_max = np.max(all_qpos, axis=0) + + # 修正标准差(防止除以 0) + eps = 1e-8 + action_std_corrected = np.where(action_std < eps, eps, action_std) + qpos_std_corrected = np.where(qpos_std < eps, eps, qpos_std) + + # 转换为扁平化结构(与 NormalizationModule 期望一致) + stats_flat = { + 'action_mean': action_mean, + 'action_std': action_std_corrected, + 'action_min': action_min, + 'action_max': action_max, + 'qpos_mean': qpos_mean, + 'qpos_std': qpos_std_corrected, + 'qpos_min': qpos_min, + 'qpos_max': qpos_max } - - # --- 修正标准差 (防止除以 0) --- - # 如果某个关节从未移动(例如备用按钮),std 会是 0,导致除零错误。 - # 策略:将 std 为 0 的地方替换为 1.0 (不缩放) 或一个小的 epsilon - for key in stats: - # 找到 std 极小的维度 - std = stats[key]['std'] - std = np.where(std < 1e-8, 1.0, std) # 如果 std 太小,设为 1.0 避免除零 - stats[key]['std'] = std - - return stats + return stats_flat if __name__ == "__main__": DATASET_DIR = 'roboimi/demos/dataset/sim_transfer' - OUTPUT_PATH = DATASET_DIR + "/data_stats.pkl" + OUTPUT_PATH = DATASET_DIR + "/dataset_stats.pkl" - stats = get_data_stats(DATASET_DIR) + stats_flat = get_data_stats(DATASET_DIR) # 打印检查 print("\n--- Stats Computed ---") - print(f"Action Mean shape: {stats['action']['mean'].shape}") - print(f"Action Std shape: {stats['action']['std'].shape}") - + print(f"Action Mean shape: {stats_flat['action_mean'].shape}") + print(f"Action Std shape: {stats_flat['action_std'].shape}") + # 保存 with open(OUTPUT_PATH, 'wb') as f: - pickle.dump(stats, f) + pickle.dump(stats_flat, f) print(f"\nStats saved to {OUTPUT_PATH}") \ No newline at end of file