"""
归一化模块 - 统一训练和推理的归一化逻辑
(Normalization module: one normalization implementation shared by training
and inference.)

支持两种归一化方式 (two normalization schemes are supported):

1. Gaussian (z-score): (x - mean) / std
2. MinMax: 2 * (x - min) / (max - min) - 1, which maps [min, max] onto [-1, 1]
"""
|
||
|
||
import torch
|
||
import torch.nn as nn
|
||
from typing import Optional, Dict, Literal
|
||
|
||
|
||
class NormalizationModule(nn.Module):
    """Unified normalization shared by training and inference.

    Normalizes / denormalizes ``qpos`` and ``action`` tensors inside the
    Agent. Dataset statistics are registered as buffers, so they follow
    ``.to(device)`` and are stored in (and restored from) ``state_dict``.

    Supported schemes:
        * ``'gaussian'`` (z-score): ``(x - mean) / std``
        * ``'min_max'``: ``2 * (x - min) / (max - min) - 1``, which maps
          ``[min, max]`` onto ``[-1, 1]``

    With ``stats=None`` the module is disabled and every transform returns
    its input unchanged.
    """

    # Entities that each get their own set of statistics.
    _PREFIXES = ('qpos', 'action')
    # Statistic suffixes required by each scheme; buffers are named
    # f'{prefix}_{suffix}' (e.g. 'qpos_mean', 'action_max').
    _STAT_SUFFIXES = {
        'gaussian': ('mean', 'std'),
        'min_max': ('min', 'max'),
    }

    def __init__(
        self,
        stats: Optional[Dict] = None,
        normalization_type: Optional[Literal['gaussian', 'min_max']] = None,
        eps: float = 1e-8,
    ):
        """
        Args:
            stats: Dataset statistics, or ``None`` to disable normalization.
                Expected layout::

                    {
                        'qpos_mean': [...],    # gaussian only
                        'qpos_std': [...],     # gaussian only
                        'qpos_min': [...],     # min_max only
                        'qpos_max': [...],     # min_max only
                        'action_mean': [...],  # gaussian only
                        'action_std': [...],   # gaussian only
                        'action_min': [...],   # min_max only
                        'action_max': [...],   # min_max only
                    }

                Values may be lists, NumPy arrays or tensors; they are copied
                into float32 buffers. A dict returned by ``get_stats`` (which
                also carries ``'normalization_type'``) is accepted as is.
            normalization_type: ``'gaussian'`` or ``'min_max'``. If ``None``
                and ``stats`` is given, it is read from
                ``stats['normalization_type']``.
            eps: Denominator entries (``std`` or ``max - min``) whose absolute
                value is below ``eps`` are replaced by ``eps`` so that constant
                dimensions do not produce inf/nan. ``0`` disables this.

        Raises:
            ValueError: ``stats`` is given but the normalization type is
                missing or unsupported.
            KeyError: ``stats`` lacks a statistic required by the chosen type.
        """
        super().__init__()

        self.enabled = stats is not None
        if self.enabled and normalization_type is None:
            # get_stats() embeds the type, so exported stats round-trip.
            normalization_type = stats.get('normalization_type')
        self.normalization_type = normalization_type
        self.eps = eps

        if not self.enabled:
            return
        if normalization_type not in self._STAT_SUFFIXES:
            # Fail at construction rather than on the first forward pass.
            raise ValueError(f"Unknown normalization type: {normalization_type}")

        for prefix in self._PREFIXES:
            for suffix in self._STAT_SUFFIXES[normalization_type]:
                name = f'{prefix}_{suffix}'
                # as_tensor avoids the torch.tensor(tensor) copy warning;
                # detach().clone() guarantees the buffer never aliases
                # memory owned by the caller.
                value = torch.as_tensor(stats[name], dtype=torch.float32)
                self.register_buffer(name, value.detach().clone())

    def _offset_and_scale(self, prefix: str):
        """Return ``(offset, scale)`` tensors for ``prefix`` ('qpos'/'action').

        gaussian -> ``(mean, std)``; min_max -> ``(min, max - min)``.
        Near-zero scale entries are replaced by ``eps``. Computed on each call
        (not cached in ``__init__``) so values derived from buffers that are
        later overwritten by ``load_state_dict`` stay in sync.

        Raises:
            ValueError: ``normalization_type`` is not a supported scheme.
        """
        if self.normalization_type == 'gaussian':
            offset = getattr(self, f'{prefix}_mean')
            scale = getattr(self, f'{prefix}_std')
        elif self.normalization_type == 'min_max':
            offset = getattr(self, f'{prefix}_min')
            scale = getattr(self, f'{prefix}_max') - offset
        else:
            raise ValueError(f"Unknown normalization type: {self.normalization_type}")
        # Same guarded scale is used by both directions, so denormalize
        # remains the exact inverse of normalize.
        scale = scale.masked_fill(scale.abs() < self.eps, self.eps)
        return offset, scale

    def _normalize(self, x: torch.Tensor, prefix: str) -> torch.Tensor:
        """Map raw ``x`` into normalized space using the ``prefix`` stats."""
        if not self.enabled:
            return x
        offset, scale = self._offset_and_scale(prefix)
        if self.normalization_type == 'gaussian':
            return (x - offset) / scale
        return 2 * (x - offset) / scale - 1

    def _denormalize(self, x: torch.Tensor, prefix: str) -> torch.Tensor:
        """Map normalized ``x`` back to raw space using the ``prefix`` stats."""
        if not self.enabled:
            return x
        offset, scale = self._offset_and_scale(prefix)
        if self.normalization_type == 'gaussian':
            return x * scale + offset
        return (x + 1) / 2 * scale + offset

    def normalize_qpos(self, qpos: torch.Tensor) -> torch.Tensor:
        """Normalize ``qpos`` (identity when the module is disabled)."""
        return self._normalize(qpos, 'qpos')

    def denormalize_qpos(self, qpos: torch.Tensor) -> torch.Tensor:
        """Invert ``normalize_qpos`` (identity when the module is disabled)."""
        return self._denormalize(qpos, 'qpos')

    def normalize_action(self, action: torch.Tensor) -> torch.Tensor:
        """Normalize ``action`` (identity when the module is disabled)."""
        return self._normalize(action, 'action')

    def denormalize_action(self, action: torch.Tensor) -> torch.Tensor:
        """Invert ``normalize_action`` (identity when the module is disabled)."""
        return self._denormalize(action, 'action')

    def get_stats(self) -> Optional[Dict]:
        """Export the statistics as plain Python lists (e.g. for a checkpoint).

        The result contains ``'normalization_type'`` plus the raw (un-guarded)
        statistics of that scheme, and can be passed straight back to the
        constructor. Returns ``None`` when the module is disabled.
        """
        if not self.enabled:
            return None

        stats = {'normalization_type': self.normalization_type}
        for prefix in self._PREFIXES:
            for suffix in self._STAT_SUFFIXES.get(self.normalization_type, ()):
                name = f'{prefix}_{suffix}'
                stats[name] = getattr(self, name).cpu().tolist()
        return stats

    def extra_repr(self) -> str:
        """Show the scheme and enabled flag in ``repr(module)``."""
        return f"normalization_type={self.normalization_type!r}, enabled={self.enabled}"