debug: 归一化
This commit is contained in:
@@ -14,20 +14,18 @@ from typing import Optional, Dict, Literal
|
||||
class NormalizationModule(nn.Module):
|
||||
"""
|
||||
统一的归一化模块
|
||||
|
||||
用于在 Agent 内部对 qpos 和 action 进行归一化/反归一化
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
stats: Optional[Dict] = None,
|
||||
normalization_type: Literal['gaussian', 'min_max'] = 'min_max'
|
||||
normalization_type: Literal['gaussian', 'min_max'] = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
stats: 数据集统计信息字典,格式:
|
||||
{
|
||||
'normalization_type': 'gaussian' | 'min_max',
|
||||
'qpos_mean': [...],
|
||||
'qpos_std': [...],
|
||||
'qpos_min': [...], # 仅 min_max 需要
|
||||
@@ -45,26 +43,17 @@ class NormalizationModule(nn.Module):
|
||||
self.enabled = stats is not None
|
||||
|
||||
if self.enabled:
|
||||
# 从 stats 中读取归一化类型(如果提供)
|
||||
self.normalization_type = stats.get('normalization_type', normalization_type)
|
||||
if self.normalization_type == 'gaussian':
|
||||
self.register_buffer('qpos_mean', torch.tensor(stats['qpos_mean'], dtype=torch.float32))
|
||||
self.register_buffer('qpos_std', torch.tensor(stats['qpos_std'], dtype=torch.float32))
|
||||
self.register_buffer('action_mean', torch.tensor(stats['action_mean'], dtype=torch.float32))
|
||||
self.register_buffer('action_std', torch.tensor(stats['action_std'], dtype=torch.float32))
|
||||
|
||||
# 注册为 buffer (不会被优化,但会随模型保存)
|
||||
self.register_buffer('qpos_mean', torch.tensor(stats['qpos_mean'], dtype=torch.float32))
|
||||
self.register_buffer('qpos_std', torch.tensor(stats['qpos_std'], dtype=torch.float32))
|
||||
self.register_buffer('action_mean', torch.tensor(stats['action_mean'], dtype=torch.float32))
|
||||
self.register_buffer('action_std', torch.tensor(stats['action_std'], dtype=torch.float32))
|
||||
|
||||
# MinMax 归一化需要 min/max
|
||||
if self.normalization_type == 'min_max':
|
||||
qpos_min = stats.get('qpos_min', [0.0] * len(stats['qpos_mean']))
|
||||
qpos_max = stats.get('qpos_max', [1.0] * len(stats['qpos_mean']))
|
||||
action_min = stats.get('action_min', [0.0] * len(stats['action_mean']))
|
||||
action_max = stats.get('action_max', [1.0] * len(stats['action_mean']))
|
||||
|
||||
self.register_buffer('qpos_min', torch.tensor(qpos_min, dtype=torch.float32))
|
||||
self.register_buffer('qpos_max', torch.tensor(qpos_max, dtype=torch.float32))
|
||||
self.register_buffer('action_min', torch.tensor(action_min, dtype=torch.float32))
|
||||
self.register_buffer('action_max', torch.tensor(action_max, dtype=torch.float32))
|
||||
elif self.normalization_type == 'min_max':
|
||||
self.register_buffer('qpos_min', torch.tensor(stats['qpos_min'], dtype=torch.float32))
|
||||
self.register_buffer('qpos_max', torch.tensor(stats['qpos_max'], dtype=torch.float32))
|
||||
self.register_buffer('action_min', torch.tensor(stats['action_min'], dtype=torch.float32))
|
||||
self.register_buffer('action_max', torch.tensor(stats['action_max'], dtype=torch.float32))
|
||||
|
||||
def normalize_qpos(self, qpos: torch.Tensor) -> torch.Tensor:
|
||||
"""归一化 qpos"""
|
||||
@@ -73,8 +62,10 @@ class NormalizationModule(nn.Module):
|
||||
|
||||
if self.normalization_type == 'gaussian':
|
||||
return (qpos - self.qpos_mean) / self.qpos_std
|
||||
else: # min_max
|
||||
elif self.normalization_type == 'min_max':
|
||||
return 2 * (qpos - self.qpos_min) / (self.qpos_max - self.qpos_min) - 1
|
||||
else:
|
||||
raise ValueError(f"Unknown normalization type: {self.normalization_type}")
|
||||
|
||||
def denormalize_qpos(self, qpos: torch.Tensor) -> torch.Tensor:
|
||||
"""反归一化 qpos"""
|
||||
@@ -83,8 +74,10 @@ class NormalizationModule(nn.Module):
|
||||
|
||||
if self.normalization_type == 'gaussian':
|
||||
return qpos * self.qpos_std + self.qpos_mean
|
||||
else: # min_max
|
||||
elif self.normalization_type == 'min_max':
|
||||
return (qpos + 1) / 2 * (self.qpos_max - self.qpos_min) + self.qpos_min
|
||||
else:
|
||||
raise ValueError(f"Unknown normalization type: {self.normalization_type}")
|
||||
|
||||
def normalize_action(self, action: torch.Tensor) -> torch.Tensor:
|
||||
"""归一化 action"""
|
||||
@@ -93,8 +86,10 @@ class NormalizationModule(nn.Module):
|
||||
|
||||
if self.normalization_type == 'gaussian':
|
||||
return (action - self.action_mean) / self.action_std
|
||||
else: # min_max
|
||||
elif self.normalization_type == 'min_max':
|
||||
return 2 * (action - self.action_min) / (self.action_max - self.action_min) - 1
|
||||
else:
|
||||
raise ValueError(f"Unknown normalization type: {self.normalization_type}")
|
||||
|
||||
def denormalize_action(self, action: torch.Tensor) -> torch.Tensor:
|
||||
"""反归一化 action"""
|
||||
@@ -103,8 +98,10 @@ class NormalizationModule(nn.Module):
|
||||
|
||||
if self.normalization_type == 'gaussian':
|
||||
return action * self.action_std + self.action_mean
|
||||
else: # min_max
|
||||
elif self.normalization_type == 'min_max':
|
||||
return (action + 1) / 2 * (self.action_max - self.action_min) + self.action_min
|
||||
else:
|
||||
raise ValueError(f"Unknown normalization type: {self.normalization_type}")
|
||||
|
||||
def get_stats(self) -> Optional[Dict]:
|
||||
"""导出统计信息(用于保存到 checkpoint)"""
|
||||
@@ -113,13 +110,14 @@ class NormalizationModule(nn.Module):
|
||||
|
||||
stats = {
|
||||
'normalization_type': self.normalization_type,
|
||||
'qpos_mean': self.qpos_mean.cpu().tolist(),
|
||||
'qpos_std': self.qpos_std.cpu().tolist(),
|
||||
'action_mean': self.action_mean.cpu().tolist(),
|
||||
'action_std': self.action_std.cpu().tolist(),
|
||||
}
|
||||
|
||||
if self.normalization_type == 'min_max':
|
||||
if self.normalization_type == 'gaussian':
|
||||
stats['qpos_mean'] = self.qpos_mean.cpu().tolist()
|
||||
stats['qpos_std'] = self.qpos_std.cpu().tolist()
|
||||
stats['action_mean'] = self.action_mean.cpu().tolist()
|
||||
stats['action_std'] = self.action_std.cpu().tolist()
|
||||
elif self.normalization_type == 'min_max':
|
||||
stats['qpos_min'] = self.qpos_min.cpu().tolist()
|
||||
stats['qpos_max'] = self.qpos_max.cpu().tolist()
|
||||
stats['action_min'] = self.action_min.cpu().tolist()
|
||||
|
||||
Reference in New Issue
Block a user