diff --git a/roboimi/vla/agent.py b/roboimi/vla/agent.py index 34fa47c..477f65a 100644 --- a/roboimi/vla/agent.py +++ b/roboimi/vla/agent.py @@ -27,6 +27,7 @@ class VLAAgent(nn.Module): dataset_stats=None, # 数据集统计信息,用于归一化 normalization_type='min_max', # 归一化类型: 'gaussian' 或 'min_max' num_action_steps=8, # 每次推理实际执行多少步动作 + head_type='unet', # Policy head类型: 'unet' 或 'transformer' ): super().__init__() # 保存参数 @@ -37,6 +38,7 @@ class VLAAgent(nn.Module): self.num_cams = num_cams self.num_action_steps = num_action_steps self.inference_steps = inference_steps + self.head_type = head_type # 'unet' 或 'transformer' # 归一化模块 - 统一训练和推理的归一化逻辑 @@ -47,10 +49,15 @@ class VLAAgent(nn.Module): self.vision_encoder = vision_backbone single_cam_feat_dim = self.vision_encoder.output_dim + # global_cond_dim: 展平后的总维度(用于UNet) total_vision_dim = single_cam_feat_dim * num_cams * obs_horizon total_prop_dim = obs_dim * obs_horizon self.global_cond_dim = total_vision_dim + total_prop_dim + # per_step_cond_dim: 每步的条件维度(用于Transformer) + # 注意:这里不乘以obs_horizon,因为Transformer的输入是序列形式 + self.per_step_cond_dim = single_cam_feat_dim * num_cams + obs_dim + self.noise_scheduler = DDPMScheduler( num_train_timesteps=diffusion_steps, beta_schedule='squaredcos_cap_v2', # 机器人任务常用的 schedule @@ -66,11 +73,27 @@ class VLAAgent(nn.Module): prediction_type='epsilon' ) - self.noise_pred_net = head( - input_dim=action_dim, - # input_dim = action_dim + obs_dim, # 备选:包含观测维度 - global_cond_dim=self.global_cond_dim - ) + # 根据head类型初始化不同的参数 + if head_type == 'transformer': + # 如果head已经是nn.Module实例,直接使用;否则需要初始化 + if isinstance(head, nn.Module): + # 已经是实例化的模块(测试时直接传入�� + self.noise_pred_net = head + else: + # Hydra部分初始化的对象,调用时传入参数 + self.noise_pred_net = head( + input_dim=action_dim, + output_dim=action_dim, + horizon=pred_horizon, + n_obs_steps=obs_horizon, + cond_dim=self.per_step_cond_dim # 每步的条件维度 + ) + else: # 'unet' (default) + # UNet接口: input_dim, global_cond_dim + self.noise_pred_net = head( + input_dim=action_dim, + 
global_cond_dim=self.global_cond_dim + ) self.state_encoder = state_encoder self.action_encoder = action_encoder @@ -124,13 +147,22 @@ class VLAAgent(nn.Module): global_cond = torch.cat([visual_features, state_features], dim=-1) global_cond = global_cond.flatten(start_dim=1) - - # 5. 网络预测噪声 - pred_noise = self.noise_pred_net( - sample=noisy_actions, - timestep=timesteps, - global_cond=global_cond - ) + # 5. 网络预测噪声(根据head类型选择接口) + if self.head_type == 'transformer': + # Transformer需要序列格式的条件: (B, obs_horizon, cond_dim_per_step) + # 将展平的global_cond reshape回序列格式 + cond = global_cond.reshape(B, self.obs_horizon, self.per_step_cond_dim) + pred_noise = self.noise_pred_net( + sample=noisy_actions, + timestep=timesteps, + cond=cond + ) + else: # 'unet' + pred_noise = self.noise_pred_net( + sample=noisy_actions, + timestep=timesteps, + global_cond=global_cond + ) # 6. 计算 Loss (MSE),支持 padding mask loss = nn.functional.mse_loss(pred_noise, noise, reduction='none') @@ -343,12 +375,21 @@ class VLAAgent(nn.Module): global_cond = torch.cat([visual_features, state_features], dim=-1) global_cond = global_cond.flatten(start_dim=1) - # 预测噪声 - noise_pred = self.noise_pred_net( - sample=model_input, - timestep=t, - global_cond=global_cond - ) + # 预测噪声(根据head类型选择接口) + if self.head_type == 'transformer': + # Transformer需要序列格式的条件 + cond = global_cond.reshape(B, self.obs_horizon, self.per_step_cond_dim) + noise_pred = self.noise_pred_net( + sample=model_input, + timestep=t, + cond=cond + ) + else: # 'unet' + noise_pred = self.noise_pred_net( + sample=model_input, + timestep=t, + global_cond=global_cond + ) # 移除噪声,更新 current_actions current_actions = self.infer_scheduler.step( diff --git a/roboimi/vla/conf/agent/resnet_transformer.yaml b/roboimi/vla/conf/agent/resnet_transformer.yaml new file mode 100644 index 0000000..fd306a1 --- /dev/null +++ b/roboimi/vla/conf/agent/resnet_transformer.yaml @@ -0,0 +1,54 @@ +# @package agent +defaults: + - /backbone@vision_backbone: resnet_diffusion + - 
/modules@state_encoder: identity_state_encoder + - /modules@action_encoder: identity_action_encoder + - /head: transformer1d + - _self_ + +_target_: roboimi.vla.agent.VLAAgent + +# ==================== +# 模型维度配置 +# ==================== +action_dim: 16 # 动作维度(机器人关节数) +obs_dim: 16 # 本体感知维度(关节位置) + +# ==================== +# 归一化配置 +# ==================== +normalization_type: "min_max" # "min_max" or "gaussian" + +# ==================== +# 时间步配置 +# ==================== +pred_horizon: 16 # 预测未来多少步动作 +obs_horizon: 2 # 使用多少步历史观测 +num_action_steps: 8 # 每次推理实际执行多少步动作(应 <= pred_horizon - obs_horizon + 1) + +# ==================== +# 相机配置 +# ==================== +num_cams: 3 # 摄像头数量 (r_vis, top, front) + +# ==================== +# 扩散过程配置 +# ==================== +diffusion_steps: 100 # 扩散训练步数(DDPM) +inference_steps: 10 # 推理时的去噪步数(DDIM,固定为 10) + +# ==================== +# Head 类型标识(用于VLAAgent选择调用方式) +# ==================== +head_type: "transformer" # "unet" 或 "transformer" + +# Head 参数覆盖 +head: + input_dim: ${agent.action_dim} + output_dim: ${agent.action_dim} + horizon: ${agent.pred_horizon} + n_obs_steps: ${agent.obs_horizon} + # Transformer的cond_dim是每步的维度 + # ResNet18 + SpatialSoftmax(32 keypoints) = 64维/相机 + # 计算方式:单相机特征(64) * 相机数(3) + obs_dim(16) = 208 + cond_dim: 208 diff --git a/roboimi/vla/conf/config.yaml b/roboimi/vla/conf/config.yaml index 8d14c93..ee4d75e 100644 --- a/roboimi/vla/conf/config.yaml +++ b/roboimi/vla/conf/config.yaml @@ -1,5 +1,5 @@ defaults: - - agent: resnet_diffusion + - agent: resnet_transformer - data: simpe_robot_dataset - eval: eval - _self_ diff --git a/roboimi/vla/conf/head/transformer1d.yaml b/roboimi/vla/conf/head/transformer1d.yaml new file mode 100644 index 0000000..5fad467 --- /dev/null +++ b/roboimi/vla/conf/head/transformer1d.yaml @@ -0,0 +1,29 @@ +# Transformer-based Diffusion Policy Head +_target_: roboimi.vla.models.heads.transformer1d.Transformer1D +_partial_: true + +# ==================== +# Transformer 架构配置 +# 
"""
Transformer-based diffusion policy head.

Predicts diffusion noise with a Transformer encoder-decoder instead of a UNet.
The conditioning sequence (one diffusion-timestep token, optionally followed by
per-step observation features) is encoded into a memory that the action decoder
attends to through cross-attention.
"""

import math
import torch
import torch.nn as nn
from typing import Optional


class SinusoidalPosEmb(nn.Module):
    """Sinusoidal embedding used to encode the diffusion timestep.

    Args:
        dim: Width of the returned embedding.
    """

    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed a 1-D tensor of positions.

        Args:
            x: (B,) tensor of timesteps.

        Returns:
            (B, dim) tensor laid out as ``[sin | cos]`` (plus one trailing zero
            column when ``dim`` is odd).
        """
        half_dim = self.dim // 2
        # max(..., 1) keeps dim == 2 / dim == 3 from dividing by zero; for
        # half_dim >= 2 the value is unchanged.
        scale = math.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=x.device) * -scale)
        angles = x[:, None] * freqs[None, :]
        emb = torch.cat((angles.sin(), angles.cos()), dim=-1)
        if self.dim % 2 == 1:
            # sin/cos only fill 2 * half_dim columns; pad so that the output
            # width really is `dim` and matches the layers that consume it.
            emb = nn.functional.pad(emb, (0, 1))
        return emb


class Transformer1D(nn.Module):
    """
    Transformer-based 1D diffusion model (encoder-decoder).

    - Encoder: processes the conditioning tokens (timestep + observations).
    - Decoder: predicts the noise for the action sequence, reading the encoded
      conditioning through cross-attention.

    Args:
        input_dim: Dimension of each input action.
        output_dim: Dimension of each predicted noise vector.
        horizon: Length of the predicted action sequence.
        n_obs_steps: Number of observation steps; defaults to ``horizon``.
        cond_dim: Per-step observation feature dimension. Observations are used
            as conditioning only when this is greater than 0.
        n_layer: Number of decoder layers.
        n_head: Number of attention heads.
        n_emb: Embedding width.
        p_drop_emb: Dropout applied to the embedded tokens.
        p_drop_attn: Dropout used inside the Transformer layers.
        causal_attn: Whether to build causal self-/cross-attention masks.
        obs_as_cond: Accepted for configuration compatibility but ignored; the
            effective value is always ``cond_dim > 0``.
        n_cond_layers: Number of TransformerEncoder layers for the conditioning
            tokens (0 means a small MLP is used instead).
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        horizon: int,
        n_obs_steps: Optional[int] = None,
        cond_dim: int = 0,
        n_layer: int = 8,
        n_head: int = 8,
        n_emb: int = 256,
        p_drop_emb: float = 0.1,
        p_drop_attn: float = 0.1,
        causal_attn: bool = False,
        obs_as_cond: bool = False,
        n_cond_layers: int = 0
    ):
        super().__init__()

        # Sequence lengths
        if n_obs_steps is None:
            n_obs_steps = horizon

        T = horizon
        T_cond = 1  # the timestep token is always part of the conditioning

        # Observations are conditioning tokens whenever they have a width.
        obs_as_cond = cond_dim > 0
        if obs_as_cond:
            T_cond += n_obs_steps

        # Saved configuration
        self.T = T
        self.T_cond = T_cond
        self.horizon = horizon
        self.obs_as_cond = obs_as_cond
        self.input_dim = input_dim
        self.output_dim = output_dim

        # ==================== Input embedding ====================
        self.input_emb = nn.Linear(input_dim, n_emb)
        self.pos_emb = nn.Parameter(torch.zeros(1, T, n_emb))
        self.drop = nn.Dropout(p_drop_emb)

        # ==================== Conditioning embedding ====================
        self.time_emb = SinusoidalPosEmb(n_emb)

        # Observation projection (absent when cond_dim == 0)
        self.cond_obs_emb = nn.Linear(cond_dim, n_emb) if obs_as_cond else None

        # T_cond >= 1 always holds, so the conditioning position embedding is
        # always needed.
        self.cond_pos_emb = nn.Parameter(torch.zeros(1, T_cond, n_emb))

        # ==================== Encoder ====================
        # The former encoder-only (BERT-style) branch could never be selected
        # because T_cond is at least 1; the flag is kept for attribute
        # compatibility.
        self.encoder_only = False

        if n_cond_layers > 0:
            encoder_layer = nn.TransformerEncoderLayer(
                d_model=n_emb,
                nhead=n_head,
                dim_feedforward=4 * n_emb,
                dropout=p_drop_attn,
                activation='gelu',
                batch_first=True,
                norm_first=True  # pre-LN
            )
            self.encoder = nn.TransformerEncoder(
                encoder_layer=encoder_layer,
                num_layers=n_cond_layers
            )
        else:
            # Lightweight MLP over each conditioning token
            self.encoder = nn.Sequential(
                nn.Linear(n_emb, 4 * n_emb),
                nn.Mish(),
                nn.Linear(4 * n_emb, n_emb)
            )

        # ==================== Attention masks ====================
        # Additive float masks: 0 where attention is allowed, -inf elsewhere.
        mask = None
        memory_mask = None

        if causal_attn:
            # Each action token may attend to itself and to earlier tokens.
            allowed = torch.tril(torch.ones(T, T, dtype=torch.bool))
            mask = torch.zeros(T, T).masked_fill(~allowed, float('-inf'))

            if obs_as_cond:
                # Memory index 0 is the timestep token, index s >= 1 is
                # observation step s - 1; action t may read memory s when
                # t >= s - 1.
                t_idx = torch.arange(T)[:, None]
                s_idx = torch.arange(T_cond)[None, :]
                allowed = t_idx >= (s_idx - 1)
                memory_mask = torch.zeros(T, T_cond).masked_fill(
                    ~allowed, float('-inf')
                )

        # Register exactly once. Assigning `self.mask = None` first and calling
        # register_buffer afterwards raises KeyError ("attribute already
        # exists"), which made causal_attn=True unusable. A None buffer reads
        # back as None and is left out of the state_dict.
        self.register_buffer("mask", mask)
        self.register_buffer("memory_mask", memory_mask)

        # ==================== Decoder ====================
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=n_emb,
            nhead=n_head,
            dim_feedforward=4 * n_emb,
            dropout=p_drop_attn,
            activation='gelu',
            batch_first=True,
            norm_first=True
        )
        self.decoder = nn.TransformerDecoder(
            decoder_layer=decoder_layer,
            num_layers=n_layer
        )

        # ==================== Output head ====================
        self.ln_f = nn.LayerNorm(n_emb)
        self.head = nn.Linear(n_emb, output_dim)

        # ==================== Initialization ====================
        self.apply(self._init_weights)

        # Report the parameter count
        total_params = sum(p.numel() for p in self.parameters())
        print(f"Transformer1D parameters: {total_params:,}")

    def _init_weights(self, module):
        """Initialize one submodule; called for every submodule via ``apply``."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.MultiheadAttention):
            # Only the projection tensors that exist on this layer are touched.
            for name in ['in_proj_weight', 'q_proj_weight', 'k_proj_weight', 'v_proj_weight']:
                weight = getattr(module, name, None)
                if weight is not None:
                    torch.nn.init.normal_(weight, mean=0.0, std=0.02)

            for name in ['in_proj_bias', 'bias_k', 'bias_v']:
                bias = getattr(module, name, None)
                if bias is not None:
                    torch.nn.init.zeros_(bias)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.zeros_(module.bias)
            torch.nn.init.ones_(module.weight)
        elif isinstance(module, Transformer1D):
            # Learned position embeddings
            torch.nn.init.normal_(self.pos_emb, mean=0.0, std=0.02)
            if self.cond_pos_emb is not None:
                torch.nn.init.normal_(self.cond_pos_emb, mean=0.0, std=0.02)

    def forward(
        self,
        sample: torch.Tensor,
        timestep: torch.Tensor,
        cond: Optional[torch.Tensor] = None,
        **kwargs
    ):
        """
        Predict the noise for a batch of noisy action sequences.

        Args:
            sample: (B, T, input_dim) noisy actions.
            timestep: Diffusion step as a Python number, a 0-d tensor, or a
                (B,) tensor.
            cond: (B, T', cond_dim) observation features. When omitted, only
                the timestep token is used as conditioning.
            **kwargs: Ignored; accepted so that callers may pass extra
                keyword arguments.

        Returns:
            (B, T, output_dim) predicted noise.
        """
        # ==================== Timestep ====================
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
        elif timesteps.dim() == 0:
            timesteps = timesteps[None]

        # Keep every timestep form on the sample's device, then broadcast to
        # the batch dimension.
        timesteps = timesteps.to(sample.device).expand(sample.shape[0])
        time_emb = self.time_emb(timesteps).unsqueeze(1)  # (B, 1, n_emb)

        # ==================== Encoder: conditioning ====================
        cond_embeddings = time_emb
        if self.obs_as_cond and cond is not None:
            cond_obs_emb = self.cond_obs_emb(cond)  # (B, T', n_emb)
            cond_embeddings = torch.cat([cond_embeddings, cond_obs_emb], dim=1)

        tc = cond_embeddings.shape[1]
        cond_pos = self.cond_pos_emb[:, :tc, :]
        memory = self.encoder(self.drop(cond_embeddings + cond_pos))

        # ==================== Decoder: noise prediction ====================
        token_embeddings = self.input_emb(sample)  # (B, T, n_emb)
        t = token_embeddings.shape[1]
        x = self.drop(token_embeddings + self.pos_emb[:, :t, :])

        # The position embeddings are sliced to the actual lengths, so the
        # masks must be too; with full-length inputs the slices are no-ops.
        tgt_mask = self.mask
        if tgt_mask is not None:
            tgt_mask = tgt_mask[:t, :t]
        memory_mask = self.memory_mask
        if memory_mask is not None:
            memory_mask = memory_mask[:t, :tc]

        # Queries come from the actions, keys/values from the memory.
        x = self.decoder(
            tgt=x,
            memory=memory,
            tgt_mask=tgt_mask,
            memory_mask=memory_mask
        )

        # ==================== Output head ====================
        x = self.ln_f(x)
        x = self.head(x)  # (B, T, output_dim)

        return x


# ============================================================================
# Convenience factory for Transformer1D
# ============================================================================
def create_transformer1d(
    input_dim: int,
    output_dim: int,
    horizon: int,
    n_obs_steps: int,
    cond_dim: int,
    n_layer: int = 8,
    n_head: int = 8,
    n_emb: int = 256,
    **kwargs
) -> Transformer1D:
    """
    Build a Transformer1D model.

    Args:
        input_dim: Dimension of each input action.
        output_dim: Dimension of each predicted noise vector.
        horizon: Length of the predicted action sequence.
        n_obs_steps: Number of observation steps.
        cond_dim: Per-step observation feature dimension.
        n_layer: Number of decoder layers.
        n_head: Number of attention heads.
        n_emb: Embedding width.
        **kwargs: Forwarded unchanged to ``Transformer1D``.

    Returns:
        The constructed Transformer1D model.
    """
    model = Transformer1D(
        input_dim=input_dim,
        output_dim=output_dim,
        horizon=horizon,
        n_obs_steps=n_obs_steps,
        cond_dim=cond_dim,
        n_layer=n_layer,
        n_head=n_head,
        n_emb=n_emb,
        **kwargs
    )
    return model


if __name__ == "__main__":
    print("=" * 80)
    print("Testing Transformer1D")
    print("=" * 80)

    # Configuration
    B = 4
    T = 16
    action_dim = 16
    obs_horizon = 2
    cond_dim = 416  # width of the vision + state features used in this smoke test

    # Build the model
    model = Transformer1D(
        input_dim=action_dim,
        output_dim=action_dim,
        horizon=T,
        n_obs_steps=obs_horizon,
        cond_dim=cond_dim,
        n_layer=4,
        n_head=8,
        n_emb=256,
        causal_attn=False
    )

    # Forward pass
    sample = torch.randn(B, T, action_dim)
    timestep = torch.randint(0, 100, (B,))
    cond = torch.randn(B, obs_horizon, cond_dim)

    output = model(sample, timestep, cond)

    print(f"\n输入:")
    print(f"  sample: {sample.shape}")
    print(f"  timestep: {timestep.shape}")
    print(f"  cond: {cond.shape}")
    print(f"\n输出:")
    print(f"  output: {output.shape}")
    print(f"\n✅ 测试通过!")
"""
Tests for the Transformer1D head.

Checks:
1. Model construction
2. Forward pass
3. Integration with VLAAgent
"""

import sys
import traceback

import torch

sys.path.append('.')


def test_transformer_standalone():
    """Run a forward pass through a stand-alone Transformer1D and check its output shape."""
    print("=" * 80)
    print("测试1: Transformer1D 独立模型")
    print("=" * 80)

    from roboimi.vla.models.heads.transformer1d import Transformer1D

    # Configuration
    B = 4
    T = 16
    action_dim = 16
    obs_horizon = 2
    # cond_dim is the width of ONE conditioning step, not the flattened total:
    # cond has shape (B, obs_horizon, cond_dim_per_step).
    cond_dim_per_step = 208  # 64 * 3 + 16

    # Build the model
    model = Transformer1D(
        input_dim=action_dim,
        output_dim=action_dim,
        horizon=T,
        n_obs_steps=obs_horizon,
        cond_dim=cond_dim_per_step,
        n_layer=4,
        n_head=8,
        n_emb=256,
        causal_attn=False
    )

    # Forward pass
    sample = torch.randn(B, T, action_dim)
    timestep = torch.randint(0, 100, (B,))
    cond = torch.randn(B, obs_horizon, cond_dim_per_step)

    output = model(sample, timestep, cond)

    print(f"\n✅ 输入:")
    print(f"  sample: {sample.shape}")
    print(f"  timestep: {timestep.shape}")
    print(f"  cond: {cond.shape}")
    print(f"\n✅ 输出:")
    print(f"  output: {output.shape}")

    assert output.shape == (B, T, action_dim), f"输出形状错误: {output.shape}"
    print(f"\n✅ 测试通过!")


def test_transformer_with_agent():
    """Build a VLAAgent around a Transformer1D head and run a loss step and an inference step."""
    print("\n" + "=" * 80)
    print("测试2: Transformer + VLAAgent 集成")
    print("=" * 80)

    from roboimi.vla.agent import VLAAgent
    from roboimi.vla.models.backbones.resnet_diffusion import ResNetDiffusionBackbone
    from roboimi.vla.modules.encoders import IdentityStateEncoder, IdentityActionEncoder
    from roboimi.vla.models.heads.transformer1d import Transformer1D

    # Minimal single-camera setup
    vision_backbone = ResNetDiffusionBackbone(
        vision_backbone="resnet18",
        pretrained_backbone_weights=None,
        input_shape=(3, 84, 84),
        use_group_norm=True,
        spatial_softmax_num_keypoints=32,
        freeze_backbone=False,
        use_separate_rgb_encoder_per_camera=False,
        num_cameras=1
    )

    state_encoder = IdentityStateEncoder()
    action_encoder = IdentityActionEncoder()

    # Head configuration
    action_dim = 16
    obs_dim = 16
    pred_horizon = 16
    obs_horizon = 2
    num_cams = 1

    # Per-step conditioning width (NOT multiplied by obs_horizon).
    # NOTE(review): presumably output_dim is 64 for 32 keypoints, giving 80 - confirm.
    single_cam_feat_dim = vision_backbone.output_dim
    per_step_cond_dim = single_cam_feat_dim * num_cams + obs_dim

    transformer_head = Transformer1D(
        input_dim=action_dim,
        output_dim=action_dim,
        horizon=pred_horizon,
        n_obs_steps=obs_horizon,
        cond_dim=per_step_cond_dim,  # per-step width, not the flattened total
        n_layer=4,
        n_head=8,
        n_emb=128,
        causal_attn=False
    )

    # Build the agent with the already-instantiated head
    agent = VLAAgent(
        vision_backbone=vision_backbone,
        state_encoder=state_encoder,
        action_encoder=action_encoder,
        head=transformer_head,
        action_dim=action_dim,
        obs_dim=obs_dim,
        pred_horizon=pred_horizon,
        obs_horizon=obs_horizon,
        diffusion_steps=100,
        inference_steps=10,
        num_cams=num_cams,
        dataset_stats=None,
        normalization_type='min_max',
        num_action_steps=8,
        head_type='transformer'
    )

    print(f"\n✅ VLAAgent with Transformer创建成功")
    print(f"  head_type: {agent.head_type}")
    print(f"  参数量: {sum(p.numel() for p in agent.parameters()):,}")

    # Training-style forward pass
    B = 2
    batch = {
        'images': {'cam0': torch.randn(B, obs_horizon, 3, 84, 84)},
        'qpos': torch.randn(B, obs_horizon, obs_dim),
        'action': torch.randn(B, pred_horizon, action_dim)
    }

    loss = agent.compute_loss(batch)
    # A NaN/inf loss would previously have been printed and reported as a pass.
    assert torch.isfinite(loss).all(), f"non-finite loss: {loss}"
    print(f"\n✅ 训练loss: {loss.item():.4f}")

    # Inference
    agent.eval()
    with torch.no_grad():
        actions = agent.predict_action(batch['images'], batch['qpos'])
    print(f"✅ 推理输出shape: {actions.shape}")

    print(f"\n✅ 集成测试通过!")


if __name__ == "__main__":
    try:
        test_transformer_standalone()
        test_transformer_with_agent()
    except Exception as e:
        print(f"\n❌ 测试失败: {e}")
        traceback.print_exc()
        # Exit non-zero so that scripts/CI see the failure instead of a clean run.
        sys.exit(1)
    else:
        print("\n" + "=" * 80)
        print("🎉 所有测试通过!")
        print("=" * 80)