debug: 归一化

This commit is contained in:
gouhanke
2026-02-12 12:23:34 +08:00
parent 83cd55e67b
commit ab971b3f96
4 changed files with 92 additions and 74 deletions

View File

@@ -14,20 +14,18 @@ from typing import Optional, Dict, Literal
class NormalizationModule(nn.Module):
"""
统一的归一化模块
用于在 Agent 内部对 qpos 和 action 进行归一化/反归一化
"""
def __init__(
self,
stats: Optional[Dict] = None,
normalization_type: Literal['gaussian', 'min_max'] = 'min_max'
normalization_type: Literal['gaussian', 'min_max'] = None,
):
"""
Args:
stats: 数据集统计信息字典,格式:
{
'normalization_type': 'gaussian' | 'min_max',
'qpos_mean': [...],
'qpos_std': [...],
'qpos_min': [...], # 仅 min_max 需要
@@ -45,26 +43,17 @@ class NormalizationModule(nn.Module):
self.enabled = stats is not None
if self.enabled:
# 从 stats 中读取归一化类型(如果提供)
self.normalization_type = stats.get('normalization_type', normalization_type)
if self.normalization_type == 'gaussian':
self.register_buffer('qpos_mean', torch.tensor(stats['qpos_mean'], dtype=torch.float32))
self.register_buffer('qpos_std', torch.tensor(stats['qpos_std'], dtype=torch.float32))
self.register_buffer('action_mean', torch.tensor(stats['action_mean'], dtype=torch.float32))
self.register_buffer('action_std', torch.tensor(stats['action_std'], dtype=torch.float32))
# 注册为 buffer (不会被优化,但会随模型保存)
self.register_buffer('qpos_mean', torch.tensor(stats['qpos_mean'], dtype=torch.float32))
self.register_buffer('qpos_std', torch.tensor(stats['qpos_std'], dtype=torch.float32))
self.register_buffer('action_mean', torch.tensor(stats['action_mean'], dtype=torch.float32))
self.register_buffer('action_std', torch.tensor(stats['action_std'], dtype=torch.float32))
# MinMax 归一化需要 min/max
if self.normalization_type == 'min_max':
qpos_min = stats.get('qpos_min', [0.0] * len(stats['qpos_mean']))
qpos_max = stats.get('qpos_max', [1.0] * len(stats['qpos_mean']))
action_min = stats.get('action_min', [0.0] * len(stats['action_mean']))
action_max = stats.get('action_max', [1.0] * len(stats['action_mean']))
self.register_buffer('qpos_min', torch.tensor(qpos_min, dtype=torch.float32))
self.register_buffer('qpos_max', torch.tensor(qpos_max, dtype=torch.float32))
self.register_buffer('action_min', torch.tensor(action_min, dtype=torch.float32))
self.register_buffer('action_max', torch.tensor(action_max, dtype=torch.float32))
elif self.normalization_type == 'min_max':
self.register_buffer('qpos_min', torch.tensor(stats['qpos_min'], dtype=torch.float32))
self.register_buffer('qpos_max', torch.tensor(stats['qpos_max'], dtype=torch.float32))
self.register_buffer('action_min', torch.tensor(stats['action_min'], dtype=torch.float32))
self.register_buffer('action_max', torch.tensor(stats['action_max'], dtype=torch.float32))
def normalize_qpos(self, qpos: torch.Tensor) -> torch.Tensor:
"""归一化 qpos"""
@@ -73,8 +62,10 @@ class NormalizationModule(nn.Module):
if self.normalization_type == 'gaussian':
return (qpos - self.qpos_mean) / self.qpos_std
else: # min_max
elif self.normalization_type == 'min_max':
return 2 * (qpos - self.qpos_min) / (self.qpos_max - self.qpos_min) - 1
else:
raise ValueError(f"Unknown normalization type: {self.normalization_type}")
def denormalize_qpos(self, qpos: torch.Tensor) -> torch.Tensor:
"""反归一化 qpos"""
@@ -83,8 +74,10 @@ class NormalizationModule(nn.Module):
if self.normalization_type == 'gaussian':
return qpos * self.qpos_std + self.qpos_mean
else: # min_max
elif self.normalization_type == 'min_max':
return (qpos + 1) / 2 * (self.qpos_max - self.qpos_min) + self.qpos_min
else:
raise ValueError(f"Unknown normalization type: {self.normalization_type}")
def normalize_action(self, action: torch.Tensor) -> torch.Tensor:
"""归一化 action"""
@@ -93,8 +86,10 @@ class NormalizationModule(nn.Module):
if self.normalization_type == 'gaussian':
return (action - self.action_mean) / self.action_std
else: # min_max
elif self.normalization_type == 'min_max':
return 2 * (action - self.action_min) / (self.action_max - self.action_min) - 1
else:
raise ValueError(f"Unknown normalization type: {self.normalization_type}")
def denormalize_action(self, action: torch.Tensor) -> torch.Tensor:
"""反归一化 action"""
@@ -103,8 +98,10 @@ class NormalizationModule(nn.Module):
if self.normalization_type == 'gaussian':
return action * self.action_std + self.action_mean
else: # min_max
elif self.normalization_type == 'min_max':
return (action + 1) / 2 * (self.action_max - self.action_min) + self.action_min
else:
raise ValueError(f"Unknown normalization type: {self.normalization_type}")
def get_stats(self) -> Optional[Dict]:
"""导出统计信息(用于保存到 checkpoint)"""
@@ -113,13 +110,14 @@ class NormalizationModule(nn.Module):
stats = {
'normalization_type': self.normalization_type,
'qpos_mean': self.qpos_mean.cpu().tolist(),
'qpos_std': self.qpos_std.cpu().tolist(),
'action_mean': self.action_mean.cpu().tolist(),
'action_std': self.action_std.cpu().tolist(),
}
if self.normalization_type == 'min_max':
if self.normalization_type == 'gaussian':
stats['qpos_mean'] = self.qpos_mean.cpu().tolist()
stats['qpos_std'] = self.qpos_std.cpu().tolist()
stats['action_mean'] = self.action_mean.cpu().tolist()
stats['action_std'] = self.action_std.cpu().tolist()
elif self.normalization_type == 'min_max':
stats['qpos_min'] = self.qpos_min.cpu().tolist()
stats['qpos_max'] = self.qpos_max.cpu().tolist()
stats['action_min'] = self.action_min.cpu().tolist()