Files
roboimi/test_transformer_head.py
2026-02-28 19:07:27 +08:00

167 lines
4.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
测试Transformer1D Head
验证:
1. 模型初始化
2. 前向传播
3. 与VLAAgent集成
"""
import torch
import sys
sys.path.append('.')
def test_transformer_standalone():
    """Smoke-test the standalone Transformer1D denoising head.

    Builds a small model, runs one forward pass on random tensors and
    checks that the output keeps the (B, T, action_dim) shape.
    """
    print("=" * 80)
    print("测试1: Transformer1D 独立模型")
    print("=" * 80)

    from roboimi.vla.models.heads.transformer1d import Transformer1D

    # Test configuration.
    B, T = 4, 16          # batch size, prediction horizon
    action_dim = 16
    obs_horizon = 2
    # NOTE: the Transformer's cond_dim is the per-step conditioning width,
    # not the flattened total; cond has shape (B, obs_horizon, cond_dim).
    cond_dim_per_step = 208  # 64*3 + 16 = 192 + 16 = 208

    model = Transformer1D(
        input_dim=action_dim,
        output_dim=action_dim,
        horizon=T,
        n_obs_steps=obs_horizon,
        cond_dim=cond_dim_per_step,  # per-step width
        n_layer=4,
        n_head=8,
        n_emb=256,
        causal_attn=False,
    )

    # One forward pass on random inputs.
    sample = torch.randn(B, T, action_dim)
    timestep = torch.randint(0, 100, (B,))
    cond = torch.randn(B, obs_horizon, cond_dim_per_step)
    output = model(sample, timestep, cond)

    print(f"\n✅ 输入:")
    print(f" sample: {sample.shape}")
    print(f" timestep: {timestep.shape}")
    print(f" cond: {cond.shape}")
    print(f"\n✅ 输出:")
    print(f" output: {output.shape}")

    # The head must be shape-preserving over the action trajectory.
    assert output.shape == (B, T, action_dim), f"输出形状错误: {output.shape}"
    print(f"\n✅ 测试通过!")
def test_transformer_with_agent():
    """Integration test: Transformer1D head wired into a VLAAgent.

    Assembles backbone + encoders + head into an agent, runs a training
    loss computation and an inference call on random batches.
    """
    print("\n" + "=" * 80)
    print("测试2: Transformer + VLAAgent 集成")
    print("=" * 80)

    from roboimi.vla.agent import VLAAgent
    from roboimi.vla.models.backbones.resnet_diffusion import ResNetDiffusionBackbone
    from roboimi.vla.modules.encoders import IdentityStateEncoder, IdentityActionEncoder
    from roboimi.vla.models.heads.transformer1d import Transformer1D
    from omegaconf import OmegaConf

    # Vision backbone with a simple single-camera configuration.
    vision_backbone = ResNetDiffusionBackbone(
        vision_backbone="resnet18",
        pretrained_backbone_weights=None,
        input_shape=(3, 84, 84),
        use_group_norm=True,
        spatial_softmax_num_keypoints=32,
        freeze_backbone=False,
        use_separate_rgb_encoder_per_camera=False,
        num_cameras=1,
    )
    state_encoder = IdentityStateEncoder()
    action_encoder = IdentityActionEncoder()

    # Dimensions for the head and agent.
    action_dim = 16
    obs_dim = 16
    pred_horizon = 16
    obs_horizon = 2
    num_cams = 1

    # Conditioning width per observation step — NOT multiplied by obs_horizon.
    single_cam_feat_dim = vision_backbone.output_dim  # 64
    per_step_cond_dim = single_cam_feat_dim * num_cams + obs_dim  # 64 * 1 + 16 = 80

    transformer_head = Transformer1D(
        input_dim=action_dim,
        output_dim=action_dim,
        horizon=pred_horizon,
        n_obs_steps=obs_horizon,
        cond_dim=per_step_cond_dim,  # per-step width, not the flattened total!
        n_layer=4,
        n_head=8,
        n_emb=128,
        causal_attn=False,
    )

    agent = VLAAgent(
        vision_backbone=vision_backbone,
        state_encoder=state_encoder,
        action_encoder=action_encoder,
        head=transformer_head,
        action_dim=action_dim,
        obs_dim=obs_dim,
        pred_horizon=pred_horizon,
        obs_horizon=obs_horizon,
        diffusion_steps=100,
        inference_steps=10,
        num_cams=num_cams,
        dataset_stats=None,
        normalization_type='min_max',
        num_action_steps=8,
        head_type='transformer',
    )
    print(f"\n✅ VLAAgent with Transformer创建成功")
    print(f" head_type: {agent.head_type}")
    print(f" 参数量: {sum(p.numel() for p in agent.parameters()):,}")

    # Training path: compute the diffusion loss on a random batch.
    B = 2
    batch = {
        'images': {'cam0': torch.randn(B, obs_horizon, 3, 84, 84)},
        'qpos': torch.randn(B, obs_horizon, obs_dim),
        'action': torch.randn(B, pred_horizon, action_dim),
    }
    loss = agent.compute_loss(batch)
    print(f"\n✅ 训练loss: {loss.item():.4f}")

    # Inference path: sample actions with gradients disabled.
    agent.eval()
    with torch.no_grad():
        actions = agent.predict_action(batch['images'], batch['qpos'])
        print(f"✅ 推理输出shape: {actions.shape}")

    print(f"\n✅ 集成测试通过!")
if __name__ == "__main__":
try:
test_transformer_standalone()
test_transformer_with_agent()
print("\n" + "=" * 80)
print("🎉 所有测试通过!")
print("=" * 80)
except Exception as e:
print(f"\n❌ 测试失败: {e}")
import traceback
traceback.print_exc()