feat: 添加transformer头

This commit is contained in:
gouhanke
2026-02-28 19:07:27 +08:00
parent abb4f501e3
commit cdb887c9bf
7 changed files with 708 additions and 21 deletions

View File

@@ -27,6 +27,7 @@ class VLAAgent(nn.Module):
dataset_stats=None, # 数据集统计信息,用于归一化 dataset_stats=None, # 数据集统计信息,用于归一化
normalization_type='min_max', # 归一化类型: 'gaussian' 或 'min_max' normalization_type='min_max', # 归一化类型: 'gaussian' 或 'min_max'
num_action_steps=8, # 每次推理实际执行多少步动作 num_action_steps=8, # 每次推理实际执行多少步动作
head_type='unet', # Policy head类型: 'unet' 或 'transformer'
): ):
super().__init__() super().__init__()
# 保存参数 # 保存参数
@@ -37,6 +38,7 @@ class VLAAgent(nn.Module):
self.num_cams = num_cams self.num_cams = num_cams
self.num_action_steps = num_action_steps self.num_action_steps = num_action_steps
self.inference_steps = inference_steps self.inference_steps = inference_steps
self.head_type = head_type # 'unet' 或 'transformer'
# 归一化模块 - 统一训练和推理的归一化逻辑 # 归一化模块 - 统一训练和推理的归一化逻辑
@@ -47,10 +49,15 @@ class VLAAgent(nn.Module):
self.vision_encoder = vision_backbone self.vision_encoder = vision_backbone
single_cam_feat_dim = self.vision_encoder.output_dim single_cam_feat_dim = self.vision_encoder.output_dim
# global_cond_dim: 展平后的总维度用于UNet
total_vision_dim = single_cam_feat_dim * num_cams * obs_horizon total_vision_dim = single_cam_feat_dim * num_cams * obs_horizon
total_prop_dim = obs_dim * obs_horizon total_prop_dim = obs_dim * obs_horizon
self.global_cond_dim = total_vision_dim + total_prop_dim self.global_cond_dim = total_vision_dim + total_prop_dim
# per_step_cond_dim: 每步的条件维度用于Transformer
# 注意这里不乘以obs_horizon因为Transformer的输入是序列形式
self.per_step_cond_dim = single_cam_feat_dim * num_cams + obs_dim
self.noise_scheduler = DDPMScheduler( self.noise_scheduler = DDPMScheduler(
num_train_timesteps=diffusion_steps, num_train_timesteps=diffusion_steps,
beta_schedule='squaredcos_cap_v2', # 机器人任务常用的 schedule beta_schedule='squaredcos_cap_v2', # 机器人任务常用的 schedule
@@ -66,11 +73,27 @@ class VLAAgent(nn.Module):
prediction_type='epsilon' prediction_type='epsilon'
) )
self.noise_pred_net = head( # 根据head类型初始化不同的参数
input_dim=action_dim, if head_type == 'transformer':
# input_dim = action_dim + obs_dim, # 备选:包含观测维度 # 如果head已经是nn.Module实例直接使用否则需要初始化
global_cond_dim=self.global_cond_dim if isinstance(head, nn.Module):
# 已经是实例化的模块,测试时直接传入
self.noise_pred_net = head
else:
# Hydra部分初始化的对象调用时传入参数
self.noise_pred_net = head(
input_dim=action_dim,
output_dim=action_dim,
horizon=pred_horizon,
n_obs_steps=obs_horizon,
cond_dim=self.per_step_cond_dim # 每步的条件维度
)
else: # 'unet' (default)
# UNet接口: input_dim, global_cond_dim
self.noise_pred_net = head(
input_dim=action_dim,
global_cond_dim=self.global_cond_dim
)
self.state_encoder = state_encoder self.state_encoder = state_encoder
self.action_encoder = action_encoder self.action_encoder = action_encoder
@@ -124,13 +147,22 @@ class VLAAgent(nn.Module):
global_cond = torch.cat([visual_features, state_features], dim=-1) global_cond = torch.cat([visual_features, state_features], dim=-1)
global_cond = global_cond.flatten(start_dim=1) global_cond = global_cond.flatten(start_dim=1)
# 5. 网络预测噪声根据head类型选择接口
# 5. 网络预测噪声 if self.head_type == 'transformer':
pred_noise = self.noise_pred_net( # Transformer需要序列格式的条件: (B, obs_horizon, cond_dim_per_step)
sample=noisy_actions, # 将展平的global_cond reshape回序列格式
timestep=timesteps, cond = global_cond.reshape(B, self.obs_horizon, self.per_step_cond_dim)
global_cond=global_cond pred_noise = self.noise_pred_net(
) sample=noisy_actions,
timestep=timesteps,
cond=cond
)
else: # 'unet'
pred_noise = self.noise_pred_net(
sample=noisy_actions,
timestep=timesteps,
global_cond=global_cond
)
# 6. 计算 Loss (MSE),支持 padding mask # 6. 计算 Loss (MSE),支持 padding mask
loss = nn.functional.mse_loss(pred_noise, noise, reduction='none') loss = nn.functional.mse_loss(pred_noise, noise, reduction='none')
@@ -343,12 +375,21 @@ class VLAAgent(nn.Module):
global_cond = torch.cat([visual_features, state_features], dim=-1) global_cond = torch.cat([visual_features, state_features], dim=-1)
global_cond = global_cond.flatten(start_dim=1) global_cond = global_cond.flatten(start_dim=1)
# 预测噪声 # 预测噪声根据head类型选择接口
noise_pred = self.noise_pred_net( if self.head_type == 'transformer':
sample=model_input, # Transformer需要序列格式的条件
timestep=t, cond = global_cond.reshape(B, self.obs_horizon, self.per_step_cond_dim)
global_cond=global_cond noise_pred = self.noise_pred_net(
) sample=model_input,
timestep=t,
cond=cond
)
else: # 'unet'
noise_pred = self.noise_pred_net(
sample=model_input,
timestep=t,
global_cond=global_cond
)
# 移除噪声,更新 current_actions # 移除噪声,更新 current_actions
current_actions = self.infer_scheduler.step( current_actions = self.infer_scheduler.step(

View File

@@ -0,0 +1,54 @@
# @package agent
defaults:
- /backbone@vision_backbone: resnet_diffusion
- /modules@state_encoder: identity_state_encoder
- /modules@action_encoder: identity_action_encoder
- /head: transformer1d
- _self_
_target_: roboimi.vla.agent.VLAAgent
# ====================
# 模型维度配置
# ====================
action_dim: 16 # 动作维度(机器人关节数)
obs_dim: 16 # 本体感知维度(关节位置)
# ====================
# 归一化配置
# ====================
normalization_type: "min_max" # "min_max" or "gaussian"
# ====================
# 时间步配置
# ====================
pred_horizon: 16 # 预测未来多少步动作
obs_horizon: 2 # 使用多少步历史观测
num_action_steps: 8 # 每次推理实际执行多少步动作(应 <= pred_horizon - obs_horizon + 1)
# ====================
# 相机配置
# ====================
num_cams: 3 # 摄像头数量 (r_vis, top, front)
# ====================
# 扩散过程配置
# ====================
diffusion_steps: 100 # 扩散训练步数DDPM
inference_steps: 10 # 推理时的去噪步数(DDIM),固定为 10
# ====================
# Head 类型标识用于VLAAgent选择调用方式
# ====================
head_type: "transformer" # "unet" 或 "transformer"
# Head 参数覆盖
head:
input_dim: ${agent.action_dim}
output_dim: ${agent.action_dim}
horizon: ${agent.pred_horizon}
n_obs_steps: ${agent.obs_horizon}
# Transformer的cond_dim是每步的维度
# ResNet18 + SpatialSoftmax(32 keypoints) = 64维/相机
# 计算方式:单相机特征(64) * 相机数(3) + obs_dim(16) = 208
cond_dim: 208

View File

@@ -1,5 +1,5 @@
defaults: defaults:
- agent: resnet_diffusion - agent: resnet_transformer
- data: simpe_robot_dataset - data: simpe_robot_dataset
- eval: eval - eval: eval
- _self_ - _self_

View File

@@ -0,0 +1,29 @@
# Transformer-based Diffusion Policy Head
_target_: roboimi.vla.models.heads.transformer1d.Transformer1D
_partial_: true
# ====================
# Transformer 架构配置
# ====================
n_layer: 8 # Transformer层数
n_head: 8 # 注意力头数
n_emb: 256 # 嵌入维度
p_drop_emb: 0.1 # Embedding dropout
p_drop_attn: 0.1 # Attention dropout
# ====================
# 条件配置
# ====================
causal_attn: false # 是否使用因果注意力(自回归生成)
obs_as_cond: true # 观测作为条件由cond_dim > 0决定
n_cond_layers: 0 # 条件编码器层数0表示使用MLP>0使用TransformerEncoder
# ====================
# 注意事项
# ====================
# 以下参数将在agent配置中通过interpolation提供
# - input_dim: ${agent.action_dim}
# - output_dim: ${agent.action_dim}
# - horizon: ${agent.pred_horizon}
# - n_obs_steps: ${agent.obs_horizon}
# - cond_dim: 通过agent中的global_cond_dim计算

View File

@@ -1,4 +1,5 @@
# # Action Head models # Action Head models
from .conditional_unet1d import ConditionalUnet1D from .conditional_unet1d import ConditionalUnet1D
from .transformer1d import Transformer1D
__all__ = ["ConditionalUnet1D"] __all__ = ["ConditionalUnet1D", "Transformer1D"]

View File

@@ -0,0 +1,396 @@
"""
Transformer-based Diffusion Policy Head
使用Transformer架构Encoder-Decoder替代UNet进行噪声预测。
支持通过Cross-Attention注入全局条件观测特征
"""
import math
import torch
import torch.nn as nn
from typing import Optional
class SinusoidalPosEmb(nn.Module):
    """Sinusoidal positional embedding, used here for the diffusion timestep.

    Maps a 1-D tensor of positions to ``(len(x), dim)`` embeddings: the first
    half of each vector holds sine components, the second half cosine
    components, with geometrically spaced frequencies.
    """

    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed positions ``x`` of shape (B,) into shape (B, dim)."""
        device = x.device
        half_dim = self.dim // 2
        # Frequency ladder: exp(-log(10000) * k / (half_dim - 1)), k = 0..half_dim-1.
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb


class Transformer1D(nn.Module):
    """
    Transformer-based 1D diffusion model (noise-prediction head).

    Encoder-decoder layout:
      - Encoder: embeds the conditioning tokens (one diffusion-timestep token
        plus, optionally, one token per observation step).
      - Decoder: cross-attends from the noisy action sequence to the encoded
        condition and predicts the noise.

    Args:
        input_dim: action dimension of the (noisy) input sequence
        output_dim: action dimension of the predicted noise
        horizon: prediction horizon (sequence length T of the actions)
        n_obs_steps: number of observation steps used as condition tokens;
            defaults to ``horizon`` when None
        cond_dim: per-step condition feature dimension; 0 disables observation
            conditioning
        n_layer: number of transformer (decoder) layers
        n_head: number of attention heads
        n_emb: embedding dimension
        p_drop_emb: dropout applied to token embeddings
        p_drop_attn: dropout inside attention / feed-forward layers
        causal_attn: use causal self-attention over the action sequence
        obs_as_cond: accepted for config compatibility but ignored — whether
            observations are used as condition is derived from ``cond_dim > 0``
        n_cond_layers: condition-encoder depth; 0 uses a small MLP instead of
            a TransformerEncoder
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        horizon: int,
        n_obs_steps: Optional[int] = None,
        cond_dim: int = 0,
        n_layer: int = 8,
        n_head: int = 8,
        n_emb: int = 256,
        p_drop_emb: float = 0.1,
        p_drop_attn: float = 0.1,
        causal_attn: bool = False,
        obs_as_cond: bool = False,
        n_cond_layers: int = 0
    ):
        super().__init__()

        # ---- sequence lengths ----
        if n_obs_steps is None:
            n_obs_steps = horizon
        T = horizon
        T_cond = 1  # one token for the diffusion timestep
        # NOTE: the `obs_as_cond` constructor argument is intentionally
        # overridden — conditioning is decided solely by `cond_dim > 0`.
        obs_as_cond = cond_dim > 0
        if obs_as_cond:
            T_cond += n_obs_steps

        # ---- saved configuration ----
        self.T = T
        self.T_cond = T_cond
        self.horizon = horizon
        self.obs_as_cond = obs_as_cond
        self.input_dim = input_dim
        self.output_dim = output_dim

        # ==================== input embedding ====================
        self.input_emb = nn.Linear(input_dim, n_emb)
        self.pos_emb = nn.Parameter(torch.zeros(1, T, n_emb))
        self.drop = nn.Dropout(p_drop_emb)

        # ==================== condition embedding ====================
        # Timestep embedding.
        self.time_emb = SinusoidalPosEmb(n_emb)
        # Optional per-step observation embedding.
        self.cond_obs_emb = None
        if obs_as_cond:
            self.cond_obs_emb = nn.Linear(cond_dim, n_emb)
        # Positional embedding for the condition tokens.
        self.cond_pos_emb = None
        if T_cond > 0:
            self.cond_pos_emb = nn.Parameter(torch.zeros(1, T_cond, n_emb))

        # ==================== encoder ====================
        # NOTE(review): T_cond starts at 1 and only grows, so the
        # encoder-only (BERT-style) branch below is currently unreachable;
        # it is kept for structural parity with the reference implementation.
        self.encoder = None
        self.encoder_only = False
        if T_cond > 0:
            if n_cond_layers > 0:
                # Full TransformerEncoder for the condition tokens.
                encoder_layer = nn.TransformerEncoderLayer(
                    d_model=n_emb,
                    nhead=n_head,
                    dim_feedforward=4 * n_emb,
                    dropout=p_drop_attn,
                    activation='gelu',
                    batch_first=True,
                    norm_first=True  # Pre-LN is more stable to train
                )
                self.encoder = nn.TransformerEncoder(
                    encoder_layer=encoder_layer,
                    num_layers=n_cond_layers
                )
            else:
                # Lightweight MLP encoder.
                self.encoder = nn.Sequential(
                    nn.Linear(n_emb, 4 * n_emb),
                    nn.Mish(),
                    nn.Linear(4 * n_emb, n_emb)
                )
        else:
            # Encoder-only (BERT-style) mode.
            self.encoder_only = True
            encoder_layer = nn.TransformerEncoderLayer(
                d_model=n_emb,
                nhead=n_head,
                dim_feedforward=4 * n_emb,
                dropout=p_drop_attn,
                activation='gelu',
                batch_first=True,
                norm_first=True
            )
            self.encoder = nn.TransformerEncoder(
                encoder_layer=encoder_layer,
                num_layers=n_layer
            )

        # ==================== attention masks ====================
        # BUGFIX: the previous version assigned `self.mask = None` /
        # `self.memory_mask = None` unconditionally and then called
        # `register_buffer` with the same names; `nn.Module.register_buffer`
        # raises KeyError when the attribute already exists, so any
        # `causal_attn=True` construction crashed. Plain-None attributes are
        # now only set on the branches that do not register a buffer.
        if causal_attn:
            # Causal mask: each action token attends only to itself and to
            # earlier positions.
            sz = T
            mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
            mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
            self.register_buffer("mask", mask)
            if obs_as_cond:
                # Cross-attention mask: action step t may see condition token
                # s only when t >= s - 1 (the timestep token, s=0, is always
                # visible).
                S = T_cond
                t, s = torch.meshgrid(
                    torch.arange(T),
                    torch.arange(S),
                    indexing='ij'
                )
                mask = t >= (s - 1)
                mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
                self.register_buffer('memory_mask', mask)
            else:
                self.memory_mask = None
        else:
            self.mask = None
            self.memory_mask = None

        # ==================== decoder ====================
        if not self.encoder_only:
            decoder_layer = nn.TransformerDecoderLayer(
                d_model=n_emb,
                nhead=n_head,
                dim_feedforward=4 * n_emb,
                dropout=p_drop_attn,
                activation='gelu',
                batch_first=True,
                norm_first=True
            )
            self.decoder = nn.TransformerDecoder(
                decoder_layer=decoder_layer,
                num_layers=n_layer
            )

        # ==================== output head ====================
        self.ln_f = nn.LayerNorm(n_emb)
        self.head = nn.Linear(n_emb, output_dim)

        # ==================== weight init ====================
        self.apply(self._init_weights)

        # Report model size once at construction.
        total_params = sum(p.numel() for p in self.parameters())
        print(f"Transformer1D parameters: {total_params:,}")

    def _init_weights(self, module):
        """Per-module weight initialization (GPT-style: N(0, 0.02), zero bias)."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.MultiheadAttention):
            # MultiheadAttention stores projections under several names
            # depending on whether q/k/v dims differ; initialize whichever
            # exist.
            for name in ['in_proj_weight', 'q_proj_weight', 'k_proj_weight', 'v_proj_weight']:
                weight = getattr(module, name, None)
                if weight is not None:
                    torch.nn.init.normal_(weight, mean=0.0, std=0.02)
            for name in ['in_proj_bias', 'bias_k', 'bias_v']:
                bias = getattr(module, name, None)
                if bias is not None:
                    torch.nn.init.zeros_(bias)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.zeros_(module.bias)
            torch.nn.init.ones_(module.weight)
        elif isinstance(module, Transformer1D):
            # Positional embeddings live on the top-level module itself.
            torch.nn.init.normal_(self.pos_emb, mean=0.0, std=0.02)
            if self.cond_pos_emb is not None:
                torch.nn.init.normal_(self.cond_pos_emb, mean=0.0, std=0.02)

    def forward(
        self,
        sample: torch.Tensor,
        timestep: torch.Tensor,
        cond: Optional[torch.Tensor] = None,
        **kwargs
    ):
        """
        Predict noise for a batch of noisy action sequences.

        Args:
            sample: (B, T, input_dim) noisy action sequence
            timestep: (B,) diffusion timesteps; a Python int or 0-d tensor is
                broadcast across the batch
            cond: (B, T', cond_dim) per-step observation features (optional)
            **kwargs: ignored; accepted for interface compatibility

        Returns:
            (B, T, output_dim) predicted noise
        """
        # ---- normalize the timestep to a (B,) long tensor ----
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
        elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)
        timesteps = timesteps.expand(sample.shape[0])
        time_emb = self.time_emb(timesteps).unsqueeze(1)  # (B, 1, n_emb)

        # ---- embed the noisy actions ----
        input_emb = self.input_emb(sample)  # (B, T, n_emb)

        if not self.encoder_only:
            # ==================== encoder-decoder mode ====================
            # Encoder input: timestep token, optionally followed by one token
            # per observation step.
            cond_embeddings = time_emb
            if self.obs_as_cond and cond is not None:
                cond_obs_emb = self.cond_obs_emb(cond)  # (B, T_cond-1, n_emb)
                cond_embeddings = torch.cat([cond_embeddings, cond_obs_emb], dim=1)
            tc = cond_embeddings.shape[1]
            pos_emb = self.cond_pos_emb[:, :tc, :]
            x = self.drop(cond_embeddings + pos_emb)
            memory = self.encoder(x)  # (B, T_cond, n_emb)

            # Decoder: queries come from the action tokens, keys/values from
            # the encoded condition.
            token_embeddings = input_emb
            t = token_embeddings.shape[1]
            pos_emb = self.pos_emb[:, :t, :]
            x = self.drop(token_embeddings + pos_emb)
            x = self.decoder(
                tgt=x,
                memory=memory,
                tgt_mask=self.mask,
                memory_mask=self.memory_mask
            )
        else:
            # ==================== encoder-only mode ====================
            # BERT-style: timestep prepended as a special token.
            # NOTE(review): with causal_attn the (T, T) mask would not match
            # the T+1-long token sequence here; unreachable as constructed.
            token_embeddings = torch.cat([time_emb, input_emb], dim=1)
            t = token_embeddings.shape[1]
            pos_emb = self.pos_emb[:, :t, :]
            x = self.drop(token_embeddings + pos_emb)
            x = self.encoder(src=x, mask=self.mask)
            x = x[:, 1:, :]  # drop the timestep token

        # ---- output head ----
        x = self.ln_f(x)
        x = self.head(x)  # (B, T, output_dim)
        return x
# ============================================================================
# 便捷函数创建Transformer1D模型
# ============================================================================
def create_transformer1d(
    input_dim: int,
    output_dim: int,
    horizon: int,
    n_obs_steps: int,
    cond_dim: int,
    n_layer: int = 8,
    n_head: int = 8,
    n_emb: int = 256,
    **kwargs
) -> Transformer1D:
    """Convenience factory for a :class:`Transformer1D` diffusion head.

    Args:
        input_dim: input action dimension
        output_dim: output action dimension
        horizon: prediction horizon
        n_obs_steps: number of observation steps
        cond_dim: per-step condition dimension
        n_layer: number of transformer layers
        n_head: number of attention heads
        n_emb: embedding dimension
        **kwargs: forwarded verbatim to the Transformer1D constructor

    Returns:
        A freshly constructed Transformer1D model.
    """
    return Transformer1D(
        input_dim=input_dim,
        output_dim=output_dim,
        horizon=horizon,
        n_obs_steps=n_obs_steps,
        cond_dim=cond_dim,
        n_layer=n_layer,
        n_head=n_head,
        n_emb=n_emb,
        **kwargs
    )
if __name__ == "__main__":
    banner = "=" * 80
    print(banner)
    print("Testing Transformer1D")
    print(banner)

    # Smoke-test configuration.
    batch, pred_len = 4, 16
    action_dim, obs_horizon = 16, 2
    cond_dim = 416  # vision + state feature width per observation step

    model = Transformer1D(
        input_dim=action_dim,
        output_dim=action_dim,
        horizon=pred_len,
        n_obs_steps=obs_horizon,
        cond_dim=cond_dim,
        n_layer=4,
        n_head=8,
        n_emb=256,
        causal_attn=False
    )

    # One forward pass on random data.
    sample = torch.randn(batch, pred_len, action_dim)
    timestep = torch.randint(0, 100, (batch,))
    cond = torch.randn(batch, obs_horizon, cond_dim)
    output = model(sample, timestep, cond)

    print(f"\n输入:")
    print(f" sample: {sample.shape}")
    print(f" timestep: {timestep.shape}")
    print(f" cond: {cond.shape}")
    print(f"\n输出:")
    print(f" output: {output.shape}")
    print(f"\n✅ 测试通过!")

166
test_transformer_head.py Normal file
View File

@@ -0,0 +1,166 @@
"""
测试Transformer1D Head
验证:
1. 模型初始化
2. 前向传播
3. 与VLAAgent集成
"""
import torch
import sys
sys.path.append('.')
def test_transformer_standalone():
    """Smoke-test the Transformer1D head in isolation (no agent)."""
    print("=" * 80)
    print("测试1: Transformer1D 独立模型")
    print("=" * 80)

    from roboimi.vla.models.heads.transformer1d import Transformer1D

    # Shapes for the synthetic batch.
    batch, pred_len = 4, 16
    action_dim, obs_horizon = 16, 2
    # NOTE: the Transformer's cond_dim is the *per-step* condition width, not
    # the flattened total — cond is (B, obs_horizon, cond_dim_per_step).
    cond_dim_per_step = 208  # 64*3 + 16 = 192 + 16 = 208

    head = Transformer1D(
        input_dim=action_dim,
        output_dim=action_dim,
        horizon=pred_len,
        n_obs_steps=obs_horizon,
        cond_dim=cond_dim_per_step,  # per-step width
        n_layer=4,
        n_head=8,
        n_emb=256,
        causal_attn=False
    )

    # One forward pass on random data.
    sample = torch.randn(batch, pred_len, action_dim)
    timestep = torch.randint(0, 100, (batch,))
    cond = torch.randn(batch, obs_horizon, cond_dim_per_step)
    output = head(sample, timestep, cond)

    print(f"\n✅ 输入:")
    print(f" sample: {sample.shape}")
    print(f" timestep: {timestep.shape}")
    print(f" cond: {cond.shape}")
    print(f"\n✅ 输出:")
    print(f" output: {output.shape}")

    assert output.shape == (batch, pred_len, action_dim), f"输出形状错误: {output.shape}"
    print(f"\n✅ 测试通过!")
def test_transformer_with_agent():
    """Integration test: a Transformer1D head wired into VLAAgent.

    Exercises both the training path (compute_loss) and the inference path
    (predict_action) with a single-camera setup.
    """
    print("\n" + "=" * 80)
    print("测试2: Transformer + VLAAgent 集成")
    print("=" * 80)

    from roboimi.vla.agent import VLAAgent
    from roboimi.vla.models.backbones.resnet_diffusion import ResNetDiffusionBackbone
    from roboimi.vla.modules.encoders import IdentityStateEncoder, IdentityActionEncoder
    from roboimi.vla.models.heads.transformer1d import Transformer1D
    from omegaconf import OmegaConf  # NOTE(review): unused here — candidate for removal

    # Minimal single-camera vision backbone.
    vision_backbone = ResNetDiffusionBackbone(
        vision_backbone="resnet18",
        pretrained_backbone_weights=None,
        input_shape=(3, 84, 84),
        use_group_norm=True,
        spatial_softmax_num_keypoints=32,
        freeze_backbone=False,
        use_separate_rgb_encoder_per_camera=False,
        num_cameras=1
    )
    state_encoder = IdentityStateEncoder()
    action_encoder = IdentityActionEncoder()

    # Dimensions for the Transformer head.
    action_dim = 16
    obs_dim = 16
    pred_horizon = 16
    obs_horizon = 2
    num_cams = 1

    # Per-step condition width — NOT multiplied by obs_horizon; the agent
    # reshapes its flattened condition back to (B, obs_horizon, per_step).
    single_cam_feat_dim = vision_backbone.output_dim  # 64
    per_step_cond_dim = single_cam_feat_dim * num_cams + obs_dim  # 64 * 1 + 16 = 80

    transformer_head = Transformer1D(
        input_dim=action_dim,
        output_dim=action_dim,
        horizon=pred_horizon,
        n_obs_steps=obs_horizon,
        cond_dim=per_step_cond_dim,  # per-step width, not the flattened total!
        n_layer=4,
        n_head=8,
        n_emb=128,
        causal_attn=False
    )

    # Assemble the agent; head_type='transformer' selects the sequence-cond
    # call convention inside VLAAgent.
    agent = VLAAgent(
        vision_backbone=vision_backbone,
        state_encoder=state_encoder,
        action_encoder=action_encoder,
        head=transformer_head,
        action_dim=action_dim,
        obs_dim=obs_dim,
        pred_horizon=pred_horizon,
        obs_horizon=obs_horizon,
        diffusion_steps=100,
        inference_steps=10,
        num_cams=num_cams,
        dataset_stats=None,
        normalization_type='min_max',
        num_action_steps=8,
        head_type='transformer'
    )
    print(f"\n✅ VLAAgent with Transformer创建成功")
    print(f" head_type: {agent.head_type}")
    print(f" 参数量: {sum(p.numel() for p in agent.parameters()):,}")

    # Training path: loss over a random batch.
    B = 2
    batch = {
        'images': {'cam0': torch.randn(B, obs_horizon, 3, 84, 84)},
        'qpos': torch.randn(B, obs_horizon, obs_dim),
        'action': torch.randn(B, pred_horizon, action_dim)
    }
    loss = agent.compute_loss(batch)
    print(f"\n✅ 训练loss: {loss.item():.4f}")

    # Inference path.
    agent.eval()
    with torch.no_grad():
        actions = agent.predict_action(batch['images'], batch['qpos'])
    print(f"✅ 推理输出shape: {actions.shape}")
    print(f"\n✅ 集成测试通过!")
if __name__ == "__main__":
    try:
        # Run both suites; any failure falls through to the handler below.
        test_transformer_standalone()
        test_transformer_with_agent()
        closing = "=" * 80
        print("\n" + closing)
        print("🎉 所有测试通过!")
        print(closing)
    except Exception as exc:
        print(f"\n❌ 测试失败: {exc}")
        import traceback
        traceback.print_exc()