feat: 添加transformer头
This commit is contained in:
166
test_transformer_head.py
Normal file
166
test_transformer_head.py
Normal file
@@ -0,0 +1,166 @@
|
||||
"""
|
||||
测试Transformer1D Head
|
||||
|
||||
验证:
|
||||
1. 模型初始化
|
||||
2. 前向传播
|
||||
3. 与VLAAgent集成
|
||||
"""
|
||||
|
||||
import torch
|
||||
import sys
|
||||
sys.path.append('.')
|
||||
|
||||
def test_transformer_standalone():
    """Smoke-test the Transformer1D diffusion head in isolation.

    Builds the head with a fixed configuration, runs a single forward
    (denoising) pass on random tensors, and checks the output trajectory
    keeps the input shape (B, T, action_dim).
    """
    rule = "=" * 80
    print(rule)
    print("测试1: Transformer1D 独立模型")
    print(rule)

    from roboimi.vla.models.heads.transformer1d import Transformer1D

    # Test configuration.
    batch_size = 4
    horizon = 16
    act_dim = 16
    n_obs_steps = 2
    # NOTE: for the transformer head, cond_dim is the width of a SINGLE
    # conditioning step, not the flattened total;
    # cond is shaped (B, obs_horizon, cond_dim_per_step).
    per_step_cond = 208  # 64*3 + 16 = 192 + 16 = 208

    # Model under test.
    model = Transformer1D(
        input_dim=act_dim,
        output_dim=act_dim,
        horizon=horizon,
        n_obs_steps=n_obs_steps,
        cond_dim=per_step_cond,  # per-step width
        n_layer=4,
        n_head=8,
        n_emb=256,
        causal_attn=False,
    )

    # One forward pass on random data.
    sample = torch.randn(batch_size, horizon, act_dim)
    timestep = torch.randint(0, 100, (batch_size,))
    cond = torch.randn(batch_size, n_obs_steps, per_step_cond)
    output = model(sample, timestep, cond)

    print(f"\n✅ 输入:")
    print(f" sample: {sample.shape}")
    print(f" timestep: {timestep.shape}")
    print(f" cond: {cond.shape}")
    print(f"\n✅ 输出:")
    print(f" output: {output.shape}")

    # The head must be shape-preserving over the action trajectory.
    assert output.shape == (batch_size, horizon, act_dim), f"输出形状错误: {output.shape}"
    print(f"\n✅ 测试通过!")
def test_transformer_with_agent():
    """Integration test: Transformer1D head driven through VLAAgent.

    Builds a minimal agent (ResNet18 vision backbone + identity encoders +
    Transformer1D head), then exercises one training step (compute_loss)
    and one inference call (predict_action).
    """
    print("\n" + "=" * 80)
    print("测试2: Transformer + VLAAgent 集成")
    print("=" * 80)

    from roboimi.vla.agent import VLAAgent
    from roboimi.vla.models.backbones.resnet_diffusion import ResNetDiffusionBackbone
    from roboimi.vla.modules.encoders import IdentityStateEncoder, IdentityActionEncoder
    from roboimi.vla.models.heads.transformer1d import Transformer1D
    # Removed unused `from omegaconf import OmegaConf` import.

    # Small vision backbone so the test stays fast.
    vision_backbone = ResNetDiffusionBackbone(
        vision_backbone="resnet18",
        pretrained_backbone_weights=None,
        input_shape=(3, 84, 84),
        use_group_norm=True,
        spatial_softmax_num_keypoints=32,
        freeze_backbone=False,
        use_separate_rgb_encoder_per_camera=False,
        num_cameras=1
    )

    state_encoder = IdentityStateEncoder()
    action_encoder = IdentityActionEncoder()

    # Problem sizes for the agent.
    action_dim = 16
    obs_dim = 16
    pred_horizon = 16
    obs_horizon = 2
    num_cams = 1

    # Conditioning width PER observation step (not multiplied by obs_horizon).
    single_cam_feat_dim = vision_backbone.output_dim  # 64
    per_step_cond_dim = single_cam_feat_dim * num_cams + obs_dim  # 64 * 1 + 16 = 80

    transformer_head = Transformer1D(
        input_dim=action_dim,
        output_dim=action_dim,
        horizon=pred_horizon,
        n_obs_steps=obs_horizon,
        cond_dim=per_step_cond_dim,  # per-step width, NOT the flattened total!
        n_layer=4,
        n_head=8,
        n_emb=128,
        causal_attn=False
    )

    agent = VLAAgent(
        vision_backbone=vision_backbone,
        state_encoder=state_encoder,
        action_encoder=action_encoder,
        head=transformer_head,
        action_dim=action_dim,
        obs_dim=obs_dim,
        pred_horizon=pred_horizon,
        obs_horizon=obs_horizon,
        diffusion_steps=100,
        inference_steps=10,
        num_cams=num_cams,
        dataset_stats=None,
        normalization_type='min_max',
        num_action_steps=8,
        head_type='transformer'
    )

    print(f"\n✅ VLAAgent with Transformer创建成功")
    print(f" head_type: {agent.head_type}")
    print(f" 参数量: {sum(p.numel() for p in agent.parameters()):,}")

    # One training step on random data.
    B = 2
    batch = {
        'images': {'cam0': torch.randn(B, obs_horizon, 3, 84, 84)},
        'qpos': torch.randn(B, obs_horizon, obs_dim),
        'action': torch.randn(B, pred_horizon, action_dim)
    }

    loss = agent.compute_loss(batch)
    # Previously only printed — a NaN/inf loss would pass silently.
    assert torch.isfinite(loss), f"loss is not finite: {loss}"
    print(f"\n✅ 训练loss: {loss.item():.4f}")

    # One inference pass.
    agent.eval()
    with torch.no_grad():
        actions = agent.predict_action(batch['images'], batch['qpos'])
        # Batch dimension must be preserved. (Exact output horizon depends on
        # num_action_steps — TODO confirm against VLAAgent.predict_action.)
        assert actions.shape[0] == B, f"推理输出batch维错误: {actions.shape}"
        print(f"✅ 推理输出shape: {actions.shape}")

    print(f"\n✅ 集成测试通过!")
if __name__ == "__main__":
|
||||
try:
|
||||
test_transformer_standalone()
|
||||
test_transformer_with_agent()
|
||||
print("\n" + "=" * 80)
|
||||
print("🎉 所有测试通过!")
|
||||
print("=" * 80)
|
||||
except Exception as e:
|
||||
print(f"\n❌ 测试失败: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
Reference in New Issue
Block a user