Files
roboimi/roboimi/vla/models/normalization.py
2026-02-12 12:23:34 +08:00

127 lines
5.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Normalization module - unifies the normalization logic used by training and inference.
Two normalization schemes are supported:
1. Gaussian (z-score): (x - mean) / std
2. MinMax: 2 * (x - min) / (max - min) - 1 -> [-1, 1]
"""
import torch
import torch.nn as nn
from typing import Optional, Dict, Literal
class NormalizationModule(nn.Module):
    """Unified normalization module for ``qpos`` and ``action`` tensors.

    Normalizes / denormalizes ``qpos`` and ``action`` inside the agent so that
    training and inference share a single implementation.

    Two schemes are supported:

    * ``'gaussian'`` (z-score): ``(x - mean) / std``
    * ``'min_max'``: ``2 * (x - min) / (max - min) - 1``, which maps
      ``[min, max]`` onto ``[-1, 1]``

    When constructed without ``stats`` the module is disabled and every
    normalize / denormalize method returns its input unchanged.
    """

    # Statistic suffixes required by each scheme, for every quantity.
    _STAT_SUFFIXES = {
        'gaussian': ('mean', 'std'),
        'min_max': ('min', 'max'),
    }
    # Quantities handled here; buffers are named '<quantity>_<suffix>'.
    _QUANTITIES = ('qpos', 'action')

    def __init__(
        self,
        stats: Optional[Dict] = None,
        normalization_type: Optional[Literal['gaussian', 'min_max']] = None,
        eps: float = 1e-8,
    ):
        """Register the dataset statistics as (non-trainable) buffers.

        Args:
            stats: Dataset statistics dictionary, or ``None`` to disable
                normalization. Expected format:
                {
                    'qpos_mean': [...],
                    'qpos_std': [...],
                    'qpos_min': [...],    # only needed for min_max
                    'qpos_max': [...],    # only needed for min_max
                    'action_mean': [...],
                    'action_std': [...],
                    'action_min': [...],  # only needed for min_max
                    'action_max': [...],  # only needed for min_max
                }
                It may also carry a 'normalization_type' entry (as produced
                by ``get_stats``), used when ``normalization_type`` is None.
            normalization_type: Normalization scheme, 'gaussian' or 'min_max'.
            eps: Lower bound applied to the divisor (``std`` or ``max - min``)
                when normalizing, so constant dimensions do not produce
                NaN/Inf. Divisors that are already >= ``eps`` are unaffected.

        Raises:
            KeyError: If ``stats`` lacks a statistic required by the scheme.
        """
        super().__init__()
        self.enabled = stats is not None
        self.eps = float(eps)

        # Allow NormalizationModule(module.get_stats()) to round-trip: the
        # exported dict already records which scheme it belongs to.
        if normalization_type is None and self.enabled:
            normalization_type = stats.get('normalization_type')
        self.normalization_type = normalization_type

        if not self.enabled:
            return

        suffixes = self._STAT_SUFFIXES.get(self.normalization_type)
        if suffixes is None:
            # Unknown scheme: nothing to register. The normalize/denormalize
            # methods raise ValueError when they are actually called.
            return

        names = [
            f'{quantity}_{suffix}'
            for quantity in self._QUANTITIES
            for suffix in suffixes
        ]
        missing = [name for name in names if name not in stats]
        if missing:
            raise KeyError(
                f"stats is missing {missing} required for "
                f"'{self.normalization_type}' normalization"
            )
        for name in names:
            self.register_buffer(
                name, torch.tensor(stats[name], dtype=torch.float32)
            )

    def _normalize(self, x: torch.Tensor, prefix: str) -> torch.Tensor:
        """Normalize ``x`` using the buffers named ``<prefix>_*``."""
        if not self.enabled:
            return x
        if self.normalization_type == 'gaussian':
            mean = getattr(self, f'{prefix}_mean')
            std = getattr(self, f'{prefix}_std')
            # Clamp so a zero std (constant dimension) cannot yield NaN/Inf.
            return (x - mean) / torch.clamp(std, min=self.eps)
        if self.normalization_type == 'min_max':
            low = getattr(self, f'{prefix}_min')
            high = getattr(self, f'{prefix}_max')
            # Clamp so max == min (constant dimension) cannot yield NaN/Inf.
            span = torch.clamp(high - low, min=self.eps)
            return 2 * (x - low) / span - 1
        raise ValueError(f"Unknown normalization type: {self.normalization_type}")

    def _denormalize(self, x: torch.Tensor, prefix: str) -> torch.Tensor:
        """Invert ``_normalize`` using the buffers named ``<prefix>_*``."""
        if not self.enabled:
            return x
        if self.normalization_type == 'gaussian':
            mean = getattr(self, f'{prefix}_mean')
            std = getattr(self, f'{prefix}_std')
            return x * std + mean
        if self.normalization_type == 'min_max':
            low = getattr(self, f'{prefix}_min')
            high = getattr(self, f'{prefix}_max')
            return (x + 1) / 2 * (high - low) + low
        raise ValueError(f"Unknown normalization type: {self.normalization_type}")

    def normalize_qpos(self, qpos: torch.Tensor) -> torch.Tensor:
        """Normalize qpos.

        Raises:
            ValueError: If the module is enabled with an unknown scheme.
        """
        return self._normalize(qpos, 'qpos')

    def denormalize_qpos(self, qpos: torch.Tensor) -> torch.Tensor:
        """Denormalize qpos.

        Raises:
            ValueError: If the module is enabled with an unknown scheme.
        """
        return self._denormalize(qpos, 'qpos')

    def normalize_action(self, action: torch.Tensor) -> torch.Tensor:
        """Normalize action.

        Raises:
            ValueError: If the module is enabled with an unknown scheme.
        """
        return self._normalize(action, 'action')

    def denormalize_action(self, action: torch.Tensor) -> torch.Tensor:
        """Denormalize action.

        Raises:
            ValueError: If the module is enabled with an unknown scheme.
        """
        return self._denormalize(action, 'action')

    def get_stats(self) -> Optional[Dict]:
        """Export the statistics (e.g. for saving into a checkpoint).

        Returns:
            ``None`` when the module is disabled; otherwise a dict holding
            'normalization_type' plus the registered statistics as plain
            Python lists. The dict can be passed back as ``stats`` to
            rebuild an equivalent module.
        """
        if not self.enabled:
            return None
        stats = {
            'normalization_type': self.normalization_type,
        }
        for quantity in self._QUANTITIES:
            for suffix in self._STAT_SUFFIXES.get(self.normalization_type, ()):
                name = f'{quantity}_{suffix}'
                stats[name] = getattr(self, name).cpu().tolist()
        return stats