Merge branch 'dev' of gitlab.com:leeeezd0016-group/gouhanke-vla into dev
This commit is contained in:
@@ -94,9 +94,51 @@ class VLAAgent(nn.Module):
|
|||||||
B = actions.shape[0]
|
B = actions.shape[0]
|
||||||
|
|
||||||
# 归一化 states (qpos) 和 actions
|
# 归一化 states (qpos) 和 actions
|
||||||
|
# ======== 归一化测试代码 (调试用) ========
|
||||||
|
if not hasattr(self, '_norm_test_done'):
|
||||||
|
self._norm_test_done = True
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("归一化测试 - 第一个batch:")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
# 检查action归一化
|
||||||
|
action_orig = batch['action'].clone()
|
||||||
|
print(f"Action 原始范围: [{action_orig.min():.4f}, {action_orig.max():.4f}]")
|
||||||
|
print(f"Action 各维度范围 (前5维):")
|
||||||
|
for i in range(min(5, action_orig.shape[-1])):
|
||||||
|
print(f" 维度{i}: [{action_orig[..., i].min():.4f}, {action_orig[..., i].max():.4f}]")
|
||||||
|
|
||||||
|
# 检查qpos归一化
|
||||||
|
state_orig = states.clone()
|
||||||
|
print(f"Qpos 原始范围: [{state_orig.min():.4f}, {state_orig.max():.4f}]")
|
||||||
|
|
||||||
states = self.normalization.normalize_qpos(states)
|
states = self.normalization.normalize_qpos(states)
|
||||||
actions = self.normalization.normalize_action(actions)
|
actions = self.normalization.normalize_action(actions)
|
||||||
|
|
||||||
|
if hasattr(self, '_norm_test_done'):
|
||||||
|
print(f"Action 归一化后范围: [{actions.min():.4f}, {actions.max():.4f}]")
|
||||||
|
print(f"Qpos 归一化后范围: [{states.min():.4f}, {states.max():.4f}]")
|
||||||
|
|
||||||
|
# 检查是否在预期范围内
|
||||||
|
if self.normalization.normalization_type == 'min_max':
|
||||||
|
action_in_range = (actions >= -1.1) & (actions <= 1.1)
|
||||||
|
state_in_range = (states >= -1.1) & (states <= 1.1)
|
||||||
|
print(f"Action 在[-1,1]范围内: {action_in_range.all().item()}")
|
||||||
|
print(f"Qpos 在[-1,1]范围内: {state_in_range.all().item()}")
|
||||||
|
|
||||||
|
if not action_in_range.all():
|
||||||
|
print(f"⚠️ Action超出范围的维度:")
|
||||||
|
for i in range(actions.shape[-1]):
|
||||||
|
if not action_in_range[..., i].all():
|
||||||
|
print(f" 维度{i}: min={actions[..., i].min():.4f}, max={actions[..., i].max():.4f}")
|
||||||
|
if not state_in_range.all():
|
||||||
|
print(f"⚠️ Qpos超出范围的维度:")
|
||||||
|
for i in range(states.shape[-1]):
|
||||||
|
if not state_in_range[..., i].all():
|
||||||
|
print(f" 维度{i}: min={states[..., i].min():.4f}, max={states[..., i].max():.4f}")
|
||||||
|
print("=" * 60 + "\n")
|
||||||
|
# ======== 归一化测试代码结束 ========
|
||||||
|
|
||||||
state_features = self.state_encoder(states)
|
state_features = self.state_encoder(states)
|
||||||
|
|
||||||
# 1. 提取视觉特征
|
# 1. 提取视觉特征
|
||||||
@@ -316,9 +358,24 @@ class VLAAgent(nn.Module):
|
|||||||
"""
|
"""
|
||||||
B = proprioception.shape[0]
|
B = proprioception.shape[0]
|
||||||
|
|
||||||
|
# ======== 推理归一化测试代码 (调试用) ========
|
||||||
|
if not hasattr(self, '_infer_norm_test_done'):
|
||||||
|
self._infer_norm_test_done = True
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("推理归一化测试 - 第一个推理batch:")
|
||||||
|
print("=" * 60)
|
||||||
|
print(f"Qpos输入范围: [{proprioception.min():.4f}, {proprioception.max():.4f}]")
|
||||||
# 归一化 proprioception (qpos)
|
# 归一化 proprioception (qpos)
|
||||||
|
proprioception_orig = proprioception.clone()
|
||||||
proprioception = self.normalization.normalize_qpos(proprioception)
|
proprioception = self.normalization.normalize_qpos(proprioception)
|
||||||
|
|
||||||
|
if hasattr(self, '_infer_norm_test_done'):
|
||||||
|
print(f"Qpos归一化后范围: [{proprioception.min():.4f}, {proprioception.max():.4f}]")
|
||||||
|
if self.normalization.normalization_type == 'min_max':
|
||||||
|
in_range = (proprioception >= -1.1) & (proprioception <= 1.1)
|
||||||
|
print(f"Qpos在[-1,1]范围内: {in_range.all().item()}")
|
||||||
|
# ======== 推理归一化测试代码结束 ========
|
||||||
|
|
||||||
# 1. 提取当前观测特征(只提取一次)
|
# 1. 提取当前观测特征(只提取一次)
|
||||||
visual_features = self.vision_encoder(images)
|
visual_features = self.vision_encoder(images)
|
||||||
state_features = self.state_encoder(proprioception)
|
state_features = self.state_encoder(proprioception)
|
||||||
@@ -356,8 +413,17 @@ class VLAAgent(nn.Module):
|
|||||||
).prev_sample
|
).prev_sample
|
||||||
|
|
||||||
# 4. 反归一化动作序列
|
# 4. 反归一化动作序列
|
||||||
|
# ======== 反归一化测试代码 (调试用) ========
|
||||||
|
if hasattr(self, '_infer_norm_test_done'):
|
||||||
|
print(f"去噪后action范围 (归一化空间): [{current_actions.min():.4f}, {current_actions.max():.4f}]")
|
||||||
|
|
||||||
denormalized_actions = self.normalization.denormalize_action(current_actions)
|
denormalized_actions = self.normalization.denormalize_action(current_actions)
|
||||||
|
|
||||||
|
if hasattr(self, '_infer_norm_test_done'):
|
||||||
|
print(f"反归一化后action范围: [{denormalized_actions.min():.4f}, {denormalized_actions.max():.4f}]")
|
||||||
|
print("=" * 60 + "\n")
|
||||||
|
# ======== 反归一化测试代码结束 ========
|
||||||
|
|
||||||
return denormalized_actions
|
return denormalized_actions
|
||||||
|
|
||||||
def get_normalization_stats(self):
|
def get_normalization_stats(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user