"""
Normalization module - shared normalization logic for training and inference.

Two normalization schemes are supported:
    1. Gaussian (z-score): (x - mean) / std
    2. MinMax: 2 * (x - min) / (max - min) - 1  -> [-1, 1]
"""
import torch
import torch.nn as nn
from typing import Optional, Dict, Literal


class NormalizationModule(nn.Module):
    """
    Unified normalization module.

    Used inside the Agent to normalize / denormalize qpos and action. The
    statistics are registered as buffers, so they are part of the module's
    ``state_dict`` and move with it on ``.to()``.
    """

    # Statistic suffixes stored for every field, per normalization type.
    _STAT_SUFFIXES = {
        'gaussian': ('mean', 'std'),
        'min_max': ('min', 'max'),
    }
    # Fields that carry their own statistics.
    _FIELDS = ('qpos', 'action')

    def __init__(
        self,
        stats: Optional[Dict] = None,
        normalization_type: Optional[Literal['gaussian', 'min_max']] = None,
    ):
        """
        Args:
            stats: Dataset statistics, or None to disable normalization
                (every method then returns its input unchanged). Keys read:
                    'gaussian': 'qpos_mean', 'qpos_std', 'action_mean', 'action_std'
                    'min_max':  'qpos_min', 'qpos_max', 'action_min', 'action_max'
                Values may be lists, numpy arrays or tensors; they are copied
                into float32 buffers. An optional 'normalization_type' key
                (as written by ``get_stats``) is used when the
                ``normalization_type`` argument is None.
            normalization_type: 'gaussian' or 'min_max'.

        Raises:
            ValueError: if ``stats`` is given but the resolved normalization
                type is neither 'gaussian' nor 'min_max'.
            KeyError: if a statistic required by that type is missing.
        """
        super().__init__()
        if stats is not None and normalization_type is None:
            # Lets the dict produced by get_stats() be passed straight back in.
            normalization_type = stats.get('normalization_type')
        self.normalization_type = normalization_type
        self.enabled = stats is not None

        if not self.enabled:
            return
        if self.normalization_type not in self._STAT_SUFFIXES:
            # Fail fast instead of registering nothing and erroring on first use.
            raise ValueError(f"Unknown normalization type: {self.normalization_type}")

        for field in self._FIELDS:
            for suffix in self._STAT_SUFFIXES[self.normalization_type]:
                name = f'{field}_{suffix}'
                # as_tensor + clone: accepts tensors without a copy-construct
                # warning and never aliases the caller's numpy array.
                value = torch.as_tensor(stats[name], dtype=torch.float32)
                self.register_buffer(name, value.detach().clone())

    @staticmethod
    def _safe_denominator(denom: torch.Tensor) -> torch.Tensor:
        """Return ``denom`` with exact zeros replaced by 1.

        Constant dimensions (std == 0 or max == min) would otherwise produce
        inf/NaN when dividing; non-zero entries are left untouched.
        """
        return torch.where(denom == 0, torch.ones_like(denom), denom)

    def _normalize(self, x: torch.Tensor, field: str) -> torch.Tensor:
        """Normalize ``x`` using the statistics stored for ``field`` ('qpos' or 'action')."""
        if not self.enabled:
            return x
        if self.normalization_type == 'gaussian':
            mean = getattr(self, f'{field}_mean')
            std = getattr(self, f'{field}_std')
            return (x - mean) / self._safe_denominator(std)
        if self.normalization_type == 'min_max':
            lo = getattr(self, f'{field}_min')
            hi = getattr(self, f'{field}_max')
            return 2 * (x - lo) / self._safe_denominator(hi - lo) - 1
        raise ValueError(f"Unknown normalization type: {self.normalization_type}")

    def _denormalize(self, x: torch.Tensor, field: str) -> torch.Tensor:
        """Invert ``_normalize`` for ``field`` ('qpos' or 'action')."""
        if not self.enabled:
            return x
        if self.normalization_type == 'gaussian':
            mean = getattr(self, f'{field}_mean')
            std = getattr(self, f'{field}_std')
            # With std == 0 this yields the stored mean, i.e. the constant value.
            return x * std + mean
        if self.normalization_type == 'min_max':
            lo = getattr(self, f'{field}_min')
            hi = getattr(self, f'{field}_max')
            # With max == min this yields the stored min, i.e. the constant value.
            return (x + 1) / 2 * (hi - lo) + lo
        raise ValueError(f"Unknown normalization type: {self.normalization_type}")

    def normalize_qpos(self, qpos: torch.Tensor) -> torch.Tensor:
        """Normalize qpos (returned unchanged when normalization is disabled)."""
        return self._normalize(qpos, 'qpos')

    def denormalize_qpos(self, qpos: torch.Tensor) -> torch.Tensor:
        """Denormalize qpos (returned unchanged when normalization is disabled)."""
        return self._denormalize(qpos, 'qpos')

    def normalize_action(self, action: torch.Tensor) -> torch.Tensor:
        """Normalize action (returned unchanged when normalization is disabled)."""
        return self._normalize(action, 'action')

    def denormalize_action(self, action: torch.Tensor) -> torch.Tensor:
        """Denormalize action (returned unchanged when normalization is disabled)."""
        return self._denormalize(action, 'action')

    def get_stats(self) -> Optional[Dict]:
        """Export the statistics as plain Python lists (for saving to a checkpoint).

        Returns:
            None when normalization is disabled; otherwise a dict holding
            'normalization_type' plus the statistics used by that type. The
            dict can be passed back to ``__init__`` as ``stats``.
        """
        if not self.enabled:
            return None
        stats = {'normalization_type': self.normalization_type}
        for field in self._FIELDS:
            for suffix in self._STAT_SUFFIXES.get(self.normalization_type, ()):
                name = f'{field}_{suffix}'
                stats[name] = getattr(self, name).cpu().tolist()
        return stats