feat: 添加transformer头

This commit is contained in:
gouhanke
2026-02-28 19:07:27 +08:00
parent abb4f501e3
commit cdb887c9bf
7 changed files with 708 additions and 21 deletions

166
test_transformer_head.py Normal file
View File

@@ -0,0 +1,166 @@
"""
测试Transformer1D Head
验证:
1. 模型初始化
2. 前向传播
3. 与VLAAgent集成
"""
import torch
import sys
sys.path.append('.')
def test_transformer_standalone():
    """Exercise the Transformer1D diffusion head in isolation.

    Builds a small Transformer1D, runs one denoising forward pass on random
    tensors, and asserts the output has the same (B, T, action_dim) shape as
    the input sample.
    """
    banner = "=" * 80
    print(banner)
    print("测试1: Transformer1D 独立模型")
    print(banner)

    from roboimi.vla.models.heads.transformer1d import Transformer1D

    # Test configuration.
    batch_size = 4
    pred_horizon = 16
    action_dim = 16
    obs_horizon = 2
    # NOTE: cond_dim is the per-step conditioning width, not the flattened
    # total — cond has shape (B, obs_horizon, cond_dim_per_step).
    cond_dim_per_step = 208  # 64*3 + 16 = 192 + 16 = 208

    # Instantiate the head under test.
    model = Transformer1D(
        input_dim=action_dim,
        output_dim=action_dim,
        horizon=pred_horizon,
        n_obs_steps=obs_horizon,
        cond_dim=cond_dim_per_step,  # per-step width
        n_layer=4,
        n_head=8,
        n_emb=256,
        causal_attn=False,
    )

    # One forward pass on random inputs.
    sample = torch.randn(batch_size, pred_horizon, action_dim)
    timestep = torch.randint(0, 100, (batch_size,))
    cond = torch.randn(batch_size, obs_horizon, cond_dim_per_step)
    output = model(sample, timestep, cond)

    print(f"\n✅ 输入:")
    print(f" sample: {sample.shape}")
    print(f" timestep: {timestep.shape}")
    print(f" cond: {cond.shape}")
    print(f"\n✅ 输出:")
    print(f" output: {output.shape}")

    expected_shape = (batch_size, pred_horizon, action_dim)
    assert output.shape == expected_shape, f"输出形状错误: {output.shape}"
    print(f"\n✅ 测试通过!")
def test_transformer_with_agent():
    """Integration test: Transformer1D head driven through VLAAgent.

    Builds a small ResNet18 vision backbone, identity state/action encoders,
    and a Transformer1D denoising head, wires them into a VLAAgent, then
    checks a training-loss computation and an inference rollout.
    """
    print("\n" + "=" * 80)
    print("测试2: Transformer + VLAAgent 集成")
    print("=" * 80)

    from roboimi.vla.agent import VLAAgent
    from roboimi.vla.models.backbones.resnet_diffusion import ResNetDiffusionBackbone
    from roboimi.vla.modules.encoders import IdentityStateEncoder, IdentityActionEncoder
    from roboimi.vla.models.heads.transformer1d import Transformer1D
    # NOTE: the original also imported omegaconf.OmegaConf here, but it was
    # never used; the unused import has been removed.

    # Vision backbone: ResNet18 with GroupNorm and spatial-softmax keypoints.
    vision_backbone = ResNetDiffusionBackbone(
        vision_backbone="resnet18",
        pretrained_backbone_weights=None,
        input_shape=(3, 84, 84),
        use_group_norm=True,
        spatial_softmax_num_keypoints=32,
        freeze_backbone=False,
        use_separate_rgb_encoder_per_camera=False,
        num_cameras=1
    )
    state_encoder = IdentityStateEncoder()
    action_encoder = IdentityActionEncoder()

    # Dimensions shared by the head and the agent.
    action_dim = 16
    obs_dim = 16
    pred_horizon = 16
    obs_horizon = 2
    num_cams = 1

    # Per-step conditioning width (do NOT multiply by obs_horizon):
    # camera features for all cameras plus the proprioceptive state.
    single_cam_feat_dim = vision_backbone.output_dim  # 64
    per_step_cond_dim = single_cam_feat_dim * num_cams + obs_dim  # 64 * 1 + 16 = 80

    transformer_head = Transformer1D(
        input_dim=action_dim,
        output_dim=action_dim,
        horizon=pred_horizon,
        n_obs_steps=obs_horizon,
        cond_dim=per_step_cond_dim,  # per-step width, not the flattened total!
        n_layer=4,
        n_head=8,
        n_emb=128,
        causal_attn=False
    )

    # Assemble the full agent around the transformer head.
    agent = VLAAgent(
        vision_backbone=vision_backbone,
        state_encoder=state_encoder,
        action_encoder=action_encoder,
        head=transformer_head,
        action_dim=action_dim,
        obs_dim=obs_dim,
        pred_horizon=pred_horizon,
        obs_horizon=obs_horizon,
        diffusion_steps=100,
        inference_steps=10,
        num_cams=num_cams,
        dataset_stats=None,
        normalization_type='min_max',
        num_action_steps=8,
        head_type='transformer'
    )

    print(f"\n✅ VLAAgent with Transformer创建成功")
    print(f" head_type: {agent.head_type}")
    print(f" 参数量: {sum(p.numel() for p in agent.parameters()):,}")

    # Training path: compute a diffusion loss on a random batch.
    B = 2
    batch = {
        'images': {'cam0': torch.randn(B, obs_horizon, 3, 84, 84)},
        'qpos': torch.randn(B, obs_horizon, obs_dim),
        'action': torch.randn(B, pred_horizon, action_dim)
    }
    loss = agent.compute_loss(batch)
    print(f"\n✅ 训练loss: {loss.item():.4f}")

    # Inference path: roll out predicted actions without gradients.
    agent.eval()
    with torch.no_grad():
        actions = agent.predict_action(batch['images'], batch['qpos'])
    print(f"✅ 推理输出shape: {actions.shape}")
    print(f"\n✅ 集成测试通过!")
if __name__ == "__main__":
try:
test_transformer_standalone()
test_transformer_with_agent()
print("\n" + "=" * 80)
print("🎉 所有测试通过!")
print("=" * 80)
except Exception as e:
print(f"\n❌ 测试失败: {e}")
import traceback
traceback.print_exc()