chore: 验证归一化是否有效

This commit is contained in:
gouhanke
2026-02-12 13:01:13 +08:00
parent 37a47ac2dd
commit 116ba13fb9

View File

@@ -94,9 +94,51 @@ class VLAAgent(nn.Module):
B = actions.shape[0] B = actions.shape[0]
# 归一化 states (qpos) 和 actions # 归一化 states (qpos) 和 actions
# ======== 归一化测试代码 (调试用) ========
if not hasattr(self, '_norm_test_done'):
self._norm_test_done = True
print("\n" + "=" * 60)
print("归一化测试 - 第一个batch:")
print("=" * 60)
# 检查action归一化
action_orig = batch['action'].clone()
print(f"Action 原始范围: [{action_orig.min():.4f}, {action_orig.max():.4f}]")
print(f"Action 各维度范围 (前5维):")
for i in range(min(5, action_orig.shape[-1])):
print(f" 维度{i}: [{action_orig[..., i].min():.4f}, {action_orig[..., i].max():.4f}]")
# 检查qpos归一化
state_orig = states.clone()
print(f"Qpos 原始范围: [{state_orig.min():.4f}, {state_orig.max():.4f}]")
states = self.normalization.normalize_qpos(states) states = self.normalization.normalize_qpos(states)
actions = self.normalization.normalize_action(actions) actions = self.normalization.normalize_action(actions)
if hasattr(self, '_norm_test_done'):
print(f"Action 归一化后范围: [{actions.min():.4f}, {actions.max():.4f}]")
print(f"Qpos 归一化后范围: [{states.min():.4f}, {states.max():.4f}]")
# 检查是否在预期范围内
if self.normalization.normalization_type == 'min_max':
action_in_range = (actions >= -1.1) & (actions <= 1.1)
state_in_range = (states >= -1.1) & (states <= 1.1)
print(f"Action 在[-1,1]范围内: {action_in_range.all().item()}")
print(f"Qpos 在[-1,1]范围内: {state_in_range.all().item()}")
if not action_in_range.all():
print(f"⚠️ Action超出范围的维度:")
for i in range(actions.shape[-1]):
if not action_in_range[..., i].all():
print(f" 维度{i}: min={actions[..., i].min():.4f}, max={actions[..., i].max():.4f}")
if not state_in_range.all():
print(f"⚠️ Qpos超出范围的维度:")
for i in range(states.shape[-1]):
if not state_in_range[..., i].all():
print(f" 维度{i}: min={states[..., i].min():.4f}, max={states[..., i].max():.4f}")
print("=" * 60 + "\n")
# ======== 归一化测试代码结束 ========
state_features = self.state_encoder(states) state_features = self.state_encoder(states)
# 1. 提取视觉特征 # 1. 提取视觉特征
@@ -316,9 +358,24 @@ class VLAAgent(nn.Module):
""" """
B = proprioception.shape[0] B = proprioception.shape[0]
# ======== 推理归一化测试代码 (调试用) ========
if not hasattr(self, '_infer_norm_test_done'):
self._infer_norm_test_done = True
print("\n" + "=" * 60)
print("推理归一化测试 - 第一个推理batch:")
print("=" * 60)
print(f"Qpos输入范围: [{proprioception.min():.4f}, {proprioception.max():.4f}]")
# 归一化 proprioception (qpos) # 归一化 proprioception (qpos)
proprioception_orig = proprioception.clone()
proprioception = self.normalization.normalize_qpos(proprioception) proprioception = self.normalization.normalize_qpos(proprioception)
if hasattr(self, '_infer_norm_test_done'):
print(f"Qpos归一化后范围: [{proprioception.min():.4f}, {proprioception.max():.4f}]")
if self.normalization.normalization_type == 'min_max':
in_range = (proprioception >= -1.1) & (proprioception <= 1.1)
print(f"Qpos在[-1,1]范围内: {in_range.all().item()}")
# ======== 推理归一化测试代码结束 ========
# 1. 提取当前观测特征(只提取一次) # 1. 提取当前观测特征(只提取一次)
visual_features = self.vision_encoder(images) visual_features = self.vision_encoder(images)
state_features = self.state_encoder(proprioception) state_features = self.state_encoder(proprioception)
@@ -356,8 +413,17 @@ class VLAAgent(nn.Module):
).prev_sample ).prev_sample
# 4. 反归一化动作序列 # 4. 反归一化动作序列
# ======== 反归一化测试代码 (调试用) ========
if hasattr(self, '_infer_norm_test_done'):
print(f"去噪后action范围 (归一化空间): [{current_actions.min():.4f}, {current_actions.max():.4f}]")
denormalized_actions = self.normalization.denormalize_action(current_actions) denormalized_actions = self.normalization.denormalize_action(current_actions)
if hasattr(self, '_infer_norm_test_done'):
print(f"反归一化后action范围: [{denormalized_actions.min():.4f}, {denormalized_actions.max():.4f}]")
print("=" * 60 + "\n")
# ======== 反归一化测试代码结束 ========
return denormalized_actions return denormalized_actions
def get_normalization_stats(self): def get_normalization_stats(self):