167 lines
4.6 KiB
Python
167 lines
4.6 KiB
Python
"""
|
||
测试Transformer1D Head
|
||
|
||
验证:
|
||
1. 模型初始化
|
||
2. 前向传播
|
||
3. 与VLAAgent集成
|
||
"""
|
||
|
||
import torch
|
||
import sys
|
||
sys.path.append('.')
|
||
|
||
def test_transformer_standalone():
    """Smoke-test the standalone Transformer1D head.

    Builds a small Transformer1D, runs one forward pass on random inputs and
    checks that the output has the same ``(B, T, action_dim)`` shape as the
    input ``sample``.

    Raises:
        AssertionError: if the output shape differs from ``(B, T, action_dim)``.
    """
    print("=" * 80)
    print("测试1: Transformer1D 独立模型")
    print("=" * 80)

    # Project import kept local so collecting this file does not require it.
    from roboimi.vla.models.heads.transformer1d import Transformer1D

    # Configuration
    B = 4
    T = 16
    action_dim = 16
    obs_horizon = 2
    # NOTE: for the Transformer head, cond_dim is the per-step condition width,
    # not the total width; cond below is shaped (B, obs_horizon, cond_dim_per_step).
    # Presumably 3 cameras x 64 visual features + 16 state dims — confirm.
    cond_dim_per_step = 208  # 64*3 + 16 = 192 + 16 = 208

    # Build the model
    model = Transformer1D(
        input_dim=action_dim,
        output_dim=action_dim,
        horizon=T,
        n_obs_steps=obs_horizon,
        cond_dim=cond_dim_per_step,  # per-step width
        n_layer=4,
        n_head=8,
        n_emb=256,
        causal_attn=False,
    )

    # Random inputs for a single forward pass
    sample = torch.randn(B, T, action_dim)
    timestep = torch.randint(0, 100, (B,))
    cond = torch.randn(B, obs_horizon, cond_dim_per_step)

    # Only the output shape is checked, so there is no need to track gradients.
    with torch.no_grad():
        output = model(sample, timestep, cond)

    print("\n✅ 输入:")
    print(f" sample: {sample.shape}")
    print(f" timestep: {timestep.shape}")
    print(f" cond: {cond.shape}")
    print("\n✅ 输出:")
    print(f" output: {output.shape}")

    assert output.shape == (B, T, action_dim), f"输出形状错误: {output.shape}"

    print("\n✅ 测试通过!")
|
||
|
||
|
||
def test_transformer_with_agent():
    """Smoke-test a Transformer1D head wired into VLAAgent.

    Builds a ResNet diffusion vision backbone, identity state/action encoders
    and a Transformer1D head. It assembles them into a VLAAgent and then:

    1. computes a training loss on a random batch and checks that it is finite;
    2. runs ``predict_action`` in eval mode under ``torch.no_grad()``.

    Raises:
        AssertionError: if the training loss is NaN or infinite.
    """
    print("\n" + "=" * 80)
    print("测试2: Transformer + VLAAgent 集成")
    print("=" * 80)

    # Project imports kept local so collecting this file does not require them.
    from roboimi.vla.agent import VLAAgent
    from roboimi.vla.models.backbones.resnet_diffusion import ResNetDiffusionBackbone
    from roboimi.vla.modules.encoders import IdentityStateEncoder, IdentityActionEncoder
    from roboimi.vla.models.heads.transformer1d import Transformer1D

    # Minimal component configuration
    vision_backbone = ResNetDiffusionBackbone(
        vision_backbone="resnet18",
        pretrained_backbone_weights=None,
        input_shape=(3, 84, 84),
        use_group_norm=True,
        spatial_softmax_num_keypoints=32,
        freeze_backbone=False,
        use_separate_rgb_encoder_per_camera=False,
        num_cameras=1,
    )

    state_encoder = IdentityStateEncoder()
    action_encoder = IdentityActionEncoder()

    # Transformer head dimensions
    action_dim = 16
    obs_dim = 16
    pred_horizon = 16
    obs_horizon = 2
    num_cams = 1

    # Per-step condition width = visual features of all cameras + raw state.
    # The "64" below is the author's note for output_dim with 32 keypoints — confirm.
    single_cam_feat_dim = vision_backbone.output_dim  # 64
    # Per step: NOT multiplied by obs_horizon.
    per_step_cond_dim = single_cam_feat_dim * num_cams + obs_dim  # 64 * 1 + 16 = 80

    transformer_head = Transformer1D(
        input_dim=action_dim,
        output_dim=action_dim,
        horizon=pred_horizon,
        n_obs_steps=obs_horizon,
        cond_dim=per_step_cond_dim,  # per-step width, not the total width
        n_layer=4,
        n_head=8,
        n_emb=128,
        causal_attn=False,
    )

    # Assemble the agent
    agent = VLAAgent(
        vision_backbone=vision_backbone,
        state_encoder=state_encoder,
        action_encoder=action_encoder,
        head=transformer_head,
        action_dim=action_dim,
        obs_dim=obs_dim,
        pred_horizon=pred_horizon,
        obs_horizon=obs_horizon,
        diffusion_steps=100,
        inference_steps=10,
        num_cams=num_cams,
        dataset_stats=None,
        normalization_type='min_max',
        num_action_steps=8,
        head_type='transformer',
    )

    print("\n✅ VLAAgent with Transformer创建成功")
    print(f" head_type: {agent.head_type}")
    print(f" 参数量: {sum(p.numel() for p in agent.parameters()):,}")

    # Training forward pass on a random batch
    B = 2
    batch = {
        'images': {'cam0': torch.randn(B, obs_horizon, 3, 84, 84)},
        'qpos': torch.randn(B, obs_horizon, obs_dim),
        'action': torch.randn(B, pred_horizon, action_dim),
    }

    loss = agent.compute_loss(batch)
    print(f"\n✅ 训练loss: {loss.item():.4f}")
    # A NaN/inf loss means the pipeline is broken even if shapes line up.
    assert torch.isfinite(loss).all(), f"训练loss不是有限值: {loss.item()}"

    # Inference
    agent.eval()
    with torch.no_grad():
        actions = agent.predict_action(batch['images'], batch['qpos'])
    print(f"✅ 推理输出shape: {actions.shape}")

    print("\n✅ 集成测试通过!")
|
||
|
||
|
||
if __name__ == "__main__":
|
||
try:
|
||
test_transformer_standalone()
|
||
test_transformer_with_agent()
|
||
print("\n" + "=" * 80)
|
||
print("🎉 所有测试通过!")
|
||
print("=" * 80)
|
||
except Exception as e:
|
||
print(f"\n❌ 测试失败: {e}")
|
||
import traceback
|
||
traceback.print_exc()
|