feat: 添加transformer头
This commit is contained in:
@@ -27,6 +27,7 @@ class VLAAgent(nn.Module):
|
|||||||
dataset_stats=None, # 数据集统计信息,用于归一化
|
dataset_stats=None, # 数据集统计信息,用于归一化
|
||||||
normalization_type='min_max', # 归一化类型: 'gaussian' 或 'min_max'
|
normalization_type='min_max', # 归一化类型: 'gaussian' 或 'min_max'
|
||||||
num_action_steps=8, # 每次推理实际执行多少步动作
|
num_action_steps=8, # 每次推理实际执行多少步动作
|
||||||
|
head_type='unet', # Policy head类型: 'unet' 或 'transformer'
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
# 保存参数
|
# 保存参数
|
||||||
@@ -37,6 +38,7 @@ class VLAAgent(nn.Module):
|
|||||||
self.num_cams = num_cams
|
self.num_cams = num_cams
|
||||||
self.num_action_steps = num_action_steps
|
self.num_action_steps = num_action_steps
|
||||||
self.inference_steps = inference_steps
|
self.inference_steps = inference_steps
|
||||||
|
self.head_type = head_type # 'unet' 或 'transformer'
|
||||||
|
|
||||||
|
|
||||||
# 归一化模块 - 统一训练和推理的归一化逻辑
|
# 归一化模块 - 统一训练和推理的归一化逻辑
|
||||||
@@ -47,10 +49,15 @@ class VLAAgent(nn.Module):
|
|||||||
|
|
||||||
self.vision_encoder = vision_backbone
|
self.vision_encoder = vision_backbone
|
||||||
single_cam_feat_dim = self.vision_encoder.output_dim
|
single_cam_feat_dim = self.vision_encoder.output_dim
|
||||||
|
# global_cond_dim: 展平后的总维度(用于UNet)
|
||||||
total_vision_dim = single_cam_feat_dim * num_cams * obs_horizon
|
total_vision_dim = single_cam_feat_dim * num_cams * obs_horizon
|
||||||
total_prop_dim = obs_dim * obs_horizon
|
total_prop_dim = obs_dim * obs_horizon
|
||||||
self.global_cond_dim = total_vision_dim + total_prop_dim
|
self.global_cond_dim = total_vision_dim + total_prop_dim
|
||||||
|
|
||||||
|
# per_step_cond_dim: 每步的条件维度(用于Transformer)
|
||||||
|
# 注意:这里不乘以obs_horizon,因为Transformer的输入是序列形式
|
||||||
|
self.per_step_cond_dim = single_cam_feat_dim * num_cams + obs_dim
|
||||||
|
|
||||||
self.noise_scheduler = DDPMScheduler(
|
self.noise_scheduler = DDPMScheduler(
|
||||||
num_train_timesteps=diffusion_steps,
|
num_train_timesteps=diffusion_steps,
|
||||||
beta_schedule='squaredcos_cap_v2', # 机器人任务常用的 schedule
|
beta_schedule='squaredcos_cap_v2', # 机器人任务常用的 schedule
|
||||||
@@ -66,11 +73,27 @@ class VLAAgent(nn.Module):
|
|||||||
prediction_type='epsilon'
|
prediction_type='epsilon'
|
||||||
)
|
)
|
||||||
|
|
||||||
self.noise_pred_net = head(
|
# 根据head类型初始化不同的参数
|
||||||
input_dim=action_dim,
|
if head_type == 'transformer':
|
||||||
# input_dim = action_dim + obs_dim, # 备选:包含观测维度
|
# 如果head已经是nn.Module实例,直接使用;否则需要初始化
|
||||||
global_cond_dim=self.global_cond_dim
|
if isinstance(head, nn.Module):
|
||||||
)
|
# 已经是实例化的模块(测试时直接传入<E4BCA0><E585A5>
|
||||||
|
self.noise_pred_net = head
|
||||||
|
else:
|
||||||
|
# Hydra部分初始化的对象,调用时传入参数
|
||||||
|
self.noise_pred_net = head(
|
||||||
|
input_dim=action_dim,
|
||||||
|
output_dim=action_dim,
|
||||||
|
horizon=pred_horizon,
|
||||||
|
n_obs_steps=obs_horizon,
|
||||||
|
cond_dim=self.per_step_cond_dim # 每步的条件维度
|
||||||
|
)
|
||||||
|
else: # 'unet' (default)
|
||||||
|
# UNet接口: input_dim, global_cond_dim
|
||||||
|
self.noise_pred_net = head(
|
||||||
|
input_dim=action_dim,
|
||||||
|
global_cond_dim=self.global_cond_dim
|
||||||
|
)
|
||||||
|
|
||||||
self.state_encoder = state_encoder
|
self.state_encoder = state_encoder
|
||||||
self.action_encoder = action_encoder
|
self.action_encoder = action_encoder
|
||||||
@@ -124,13 +147,22 @@ class VLAAgent(nn.Module):
|
|||||||
global_cond = torch.cat([visual_features, state_features], dim=-1)
|
global_cond = torch.cat([visual_features, state_features], dim=-1)
|
||||||
global_cond = global_cond.flatten(start_dim=1)
|
global_cond = global_cond.flatten(start_dim=1)
|
||||||
|
|
||||||
|
# 5. 网络预测噪声(根据head类型选择接口)
|
||||||
# 5. 网络预测噪声
|
if self.head_type == 'transformer':
|
||||||
pred_noise = self.noise_pred_net(
|
# Transformer需要序列格式的条件: (B, obs_horizon, cond_dim_per_step)
|
||||||
sample=noisy_actions,
|
# 将展平的global_cond reshape回序列格式
|
||||||
timestep=timesteps,
|
cond = global_cond.reshape(B, self.obs_horizon, self.per_step_cond_dim)
|
||||||
global_cond=global_cond
|
pred_noise = self.noise_pred_net(
|
||||||
)
|
sample=noisy_actions,
|
||||||
|
timestep=timesteps,
|
||||||
|
cond=cond
|
||||||
|
)
|
||||||
|
else: # 'unet'
|
||||||
|
pred_noise = self.noise_pred_net(
|
||||||
|
sample=noisy_actions,
|
||||||
|
timestep=timesteps,
|
||||||
|
global_cond=global_cond
|
||||||
|
)
|
||||||
|
|
||||||
# 6. 计算 Loss (MSE),支持 padding mask
|
# 6. 计算 Loss (MSE),支持 padding mask
|
||||||
loss = nn.functional.mse_loss(pred_noise, noise, reduction='none')
|
loss = nn.functional.mse_loss(pred_noise, noise, reduction='none')
|
||||||
@@ -343,12 +375,21 @@ class VLAAgent(nn.Module):
|
|||||||
global_cond = torch.cat([visual_features, state_features], dim=-1)
|
global_cond = torch.cat([visual_features, state_features], dim=-1)
|
||||||
global_cond = global_cond.flatten(start_dim=1)
|
global_cond = global_cond.flatten(start_dim=1)
|
||||||
|
|
||||||
# 预测噪声
|
# 预测噪声(根据head类型选择接口)
|
||||||
noise_pred = self.noise_pred_net(
|
if self.head_type == 'transformer':
|
||||||
sample=model_input,
|
# Transformer需要序列格式的条件
|
||||||
timestep=t,
|
cond = global_cond.reshape(B, self.obs_horizon, self.per_step_cond_dim)
|
||||||
global_cond=global_cond
|
noise_pred = self.noise_pred_net(
|
||||||
)
|
sample=model_input,
|
||||||
|
timestep=t,
|
||||||
|
cond=cond
|
||||||
|
)
|
||||||
|
else: # 'unet'
|
||||||
|
noise_pred = self.noise_pred_net(
|
||||||
|
sample=model_input,
|
||||||
|
timestep=t,
|
||||||
|
global_cond=global_cond
|
||||||
|
)
|
||||||
|
|
||||||
# 移除噪声,更新 current_actions
|
# 移除噪声,更新 current_actions
|
||||||
current_actions = self.infer_scheduler.step(
|
current_actions = self.infer_scheduler.step(
|
||||||
|
|||||||
54
roboimi/vla/conf/agent/resnet_transformer.yaml
Normal file
54
roboimi/vla/conf/agent/resnet_transformer.yaml
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
# @package agent
|
||||||
|
defaults:
|
||||||
|
- /backbone@vision_backbone: resnet_diffusion
|
||||||
|
- /modules@state_encoder: identity_state_encoder
|
||||||
|
- /modules@action_encoder: identity_action_encoder
|
||||||
|
- /head: transformer1d
|
||||||
|
- _self_
|
||||||
|
|
||||||
|
_target_: roboimi.vla.agent.VLAAgent
|
||||||
|
|
||||||
|
# ====================
|
||||||
|
# 模型维度配置
|
||||||
|
# ====================
|
||||||
|
action_dim: 16 # 动作维度(机器人关节数)
|
||||||
|
obs_dim: 16 # 本体感知维度(关节位置)
|
||||||
|
|
||||||
|
# ====================
|
||||||
|
# 归一化配置
|
||||||
|
# ====================
|
||||||
|
normalization_type: "min_max" # "min_max" or "gaussian"
|
||||||
|
|
||||||
|
# ====================
|
||||||
|
# 时间步配置
|
||||||
|
# ====================
|
||||||
|
pred_horizon: 16 # 预测未来多少步动作
|
||||||
|
obs_horizon: 2 # 使用多少步历史观测
|
||||||
|
num_action_steps: 8 # 每次推理实际执行多少步动作(应 <= pred_horizon - obs_horizon + 1)
|
||||||
|
|
||||||
|
# ====================
|
||||||
|
# 相机配置
|
||||||
|
# ====================
|
||||||
|
num_cams: 3 # 摄像头数量 (r_vis, top, front)
|
||||||
|
|
||||||
|
# ====================
|
||||||
|
# 扩散过程配置
|
||||||
|
# ====================
|
||||||
|
diffusion_steps: 100 # 扩散训练步数(DDPM)
|
||||||
|
inference_steps: 10 # 推理时的去噪步数(DDIM,<4D><EFBC8C>定为 10)
|
||||||
|
|
||||||
|
# ====================
|
||||||
|
# Head 类型标识(用于VLAAgent选择调用方式)
|
||||||
|
# ====================
|
||||||
|
head_type: "transformer" # "unet" 或 "transformer"
|
||||||
|
|
||||||
|
# Head 参数覆盖
|
||||||
|
head:
|
||||||
|
input_dim: ${agent.action_dim}
|
||||||
|
output_dim: ${agent.action_dim}
|
||||||
|
horizon: ${agent.pred_horizon}
|
||||||
|
n_obs_steps: ${agent.obs_horizon}
|
||||||
|
# Transformer的cond_dim是每步的维度
|
||||||
|
# ResNet18 + SpatialSoftmax(32 keypoints) = 64维/相机
|
||||||
|
# 计算方式:单相机特征(64) * 相机数(3) + obs_dim(16) = 208
|
||||||
|
cond_dim: 208
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
defaults:
|
defaults:
|
||||||
- agent: resnet_diffusion
|
- agent: resnet_transformer
|
||||||
- data: simpe_robot_dataset
|
- data: simpe_robot_dataset
|
||||||
- eval: eval
|
- eval: eval
|
||||||
- _self_
|
- _self_
|
||||||
|
|||||||
29
roboimi/vla/conf/head/transformer1d.yaml
Normal file
29
roboimi/vla/conf/head/transformer1d.yaml
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
# Transformer-based Diffusion Policy Head
|
||||||
|
_target_: roboimi.vla.models.heads.transformer1d.Transformer1D
|
||||||
|
_partial_: true
|
||||||
|
|
||||||
|
# ====================
|
||||||
|
# Transformer 架构配置
|
||||||
|
# ====================
|
||||||
|
n_layer: 8 # Transformer层数
|
||||||
|
n_head: 8 # 注意力头数
|
||||||
|
n_emb: 256 # 嵌入维度
|
||||||
|
p_drop_emb: 0.1 # Embedding dropout
|
||||||
|
p_drop_attn: 0.1 # Attention dropout
|
||||||
|
|
||||||
|
# ====================
|
||||||
|
# 条件配置
|
||||||
|
# ====================
|
||||||
|
causal_attn: false # 是否使用因果注意力(自回归生成)
|
||||||
|
obs_as_cond: true # 观测作为条件(由cond_dim > 0决定)
|
||||||
|
n_cond_layers: 0 # 条件编码器层数(0表示使用MLP,>0使用TransformerEncoder)
|
||||||
|
|
||||||
|
# ====================
|
||||||
|
# 注意事项
|
||||||
|
# ====================
|
||||||
|
# 以下参数将在agent配置中通过interpolation提供:
|
||||||
|
# - input_dim: ${agent.action_dim}
|
||||||
|
# - output_dim: ${agent.action_dim}
|
||||||
|
# - horizon: ${agent.pred_horizon}
|
||||||
|
# - n_obs_steps: ${agent.obs_horizon}
|
||||||
|
# - cond_dim: 通过agent中的global_cond_dim计算
|
||||||
@@ -1,4 +1,5 @@
|
|||||||
# # Action Head models
|
# Action Head models
|
||||||
from .conditional_unet1d import ConditionalUnet1D
|
from .conditional_unet1d import ConditionalUnet1D
|
||||||
|
from .transformer1d import Transformer1D
|
||||||
|
|
||||||
__all__ = ["ConditionalUnet1D"]
|
__all__ = ["ConditionalUnet1D", "Transformer1D"]
|
||||||
|
|||||||
396
roboimi/vla/models/heads/transformer1d.py
Normal file
396
roboimi/vla/models/heads/transformer1d.py
Normal file
@@ -0,0 +1,396 @@
|
|||||||
|
"""
|
||||||
|
Transformer-based Diffusion Policy Head
|
||||||
|
|
||||||
|
使用Transformer架构(Encoder-Decoder)替代UNet进行噪声预测。
|
||||||
|
支持通过Cross-Attention注入全局条件(观测特征)。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import math
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
class SinusoidalPosEmb(nn.Module):
|
||||||
|
"""正弦位置编码(用于时间步嵌入)"""
|
||||||
|
def __init__(self, dim: int):
|
||||||
|
super().__init__()
|
||||||
|
self.dim = dim
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
device = x.device
|
||||||
|
half_dim = self.dim // 2
|
||||||
|
emb = math.log(10000) / (half_dim - 1)
|
||||||
|
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
|
||||||
|
emb = x[:, None] * emb[None, :]
|
||||||
|
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
|
||||||
|
return emb
|
||||||
|
|
||||||
|
|
||||||
|
class Transformer1D(nn.Module):
|
||||||
|
"""
|
||||||
|
Transformer-based 1D Diffusion Model
|
||||||
|
|
||||||
|
使用Encoder-Decoder架构:
|
||||||
|
- Encoder: 处理条件(观测 + 时间步)
|
||||||
|
- Decoder: 通过Cross-Attention预测噪声
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_dim: 输入动作维度
|
||||||
|
output_dim: 输出动作维度
|
||||||
|
horizon: 预测horizon长度
|
||||||
|
n_obs_steps: 观测步数
|
||||||
|
cond_dim: 条件维度
|
||||||
|
n_layer: Transformer层数
|
||||||
|
n_head: 注意力头数
|
||||||
|
n_emb: 嵌入维度
|
||||||
|
p_drop_emb: Embedding dropout
|
||||||
|
p_drop_attn: Attention dropout
|
||||||
|
causal_attn: 是否使用因果注意力(自回归)
|
||||||
|
n_cond_layers: Encoder层数(0表示使用MLP)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
input_dim: int,
|
||||||
|
output_dim: int,
|
||||||
|
horizon: int,
|
||||||
|
n_obs_steps: int = None,
|
||||||
|
cond_dim: int = 0,
|
||||||
|
n_layer: int = 8,
|
||||||
|
n_head: int = 8,
|
||||||
|
n_emb: int = 256,
|
||||||
|
p_drop_emb: float = 0.1,
|
||||||
|
p_drop_attn: float = 0.1,
|
||||||
|
causal_attn: bool = False,
|
||||||
|
obs_as_cond: bool = False,
|
||||||
|
n_cond_layers: int = 0
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
# 计算序列长度
|
||||||
|
if n_obs_steps is None:
|
||||||
|
n_obs_steps = horizon
|
||||||
|
|
||||||
|
T = horizon
|
||||||
|
T_cond = 1 # 时间步token数量
|
||||||
|
|
||||||
|
# 确定是否使用观测作为条件
|
||||||
|
obs_as_cond = cond_dim > 0
|
||||||
|
if obs_as_cond:
|
||||||
|
T_cond += n_obs_steps
|
||||||
|
|
||||||
|
# 保存配置
|
||||||
|
self.T = T
|
||||||
|
self.T_cond = T_cond
|
||||||
|
self.horizon = horizon
|
||||||
|
self.obs_as_cond = obs_as_cond
|
||||||
|
self.input_dim = input_dim
|
||||||
|
self.output_dim = output_dim
|
||||||
|
|
||||||
|
# ==================== 输入嵌入 ====================
|
||||||
|
self.input_emb = nn.Linear(input_dim, n_emb)
|
||||||
|
self.pos_emb = nn.Parameter(torch.zeros(1, T, n_emb))
|
||||||
|
self.drop = nn.Dropout(p_drop_emb)
|
||||||
|
|
||||||
|
# ==================== 条件编码 ====================
|
||||||
|
# 时间步嵌入
|
||||||
|
self.time_emb = SinusoidalPosEmb(n_emb)
|
||||||
|
|
||||||
|
# 观测条件嵌入(可选)
|
||||||
|
self.cond_obs_emb = None
|
||||||
|
if obs_as_cond:
|
||||||
|
self.cond_obs_emb = nn.Linear(cond_dim, n_emb)
|
||||||
|
|
||||||
|
# 条件位置编码
|
||||||
|
self.cond_pos_emb = None
|
||||||
|
if T_cond > 0:
|
||||||
|
self.cond_pos_emb = nn.Parameter(torch.zeros(1, T_cond, n_emb))
|
||||||
|
|
||||||
|
# ==================== Encoder ====================
|
||||||
|
self.encoder = None
|
||||||
|
self.encoder_only = False
|
||||||
|
|
||||||
|
if T_cond > 0:
|
||||||
|
if n_cond_layers > 0:
|
||||||
|
# 使用Transformer Encoder
|
||||||
|
encoder_layer = nn.TransformerEncoderLayer(
|
||||||
|
d_model=n_emb,
|
||||||
|
nhead=n_head,
|
||||||
|
dim_feedforward=4 * n_emb,
|
||||||
|
dropout=p_drop_attn,
|
||||||
|
activation='gelu',
|
||||||
|
batch_first=True,
|
||||||
|
norm_first=True # Pre-LN更稳定
|
||||||
|
)
|
||||||
|
self.encoder = nn.TransformerEncoder(
|
||||||
|
encoder_layer=encoder_layer,
|
||||||
|
num_layers=n_cond_layers
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# 使用简单的MLP
|
||||||
|
self.encoder = nn.Sequential(
|
||||||
|
nn.Linear(n_emb, 4 * n_emb),
|
||||||
|
nn.Mish(),
|
||||||
|
nn.Linear(4 * n_emb, n_emb)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Encoder-only模式(BERT风格)
|
||||||
|
self.encoder_only = True
|
||||||
|
encoder_layer = nn.TransformerEncoderLayer(
|
||||||
|
d_model=n_emb,
|
||||||
|
nhead=n_head,
|
||||||
|
dim_feedforward=4 * n_emb,
|
||||||
|
dropout=p_drop_attn,
|
||||||
|
activation='gelu',
|
||||||
|
batch_first=True,
|
||||||
|
norm_first=True
|
||||||
|
)
|
||||||
|
self.encoder = nn.TransformerEncoder(
|
||||||
|
encoder_layer=encoder_layer,
|
||||||
|
num_layers=n_layer
|
||||||
|
)
|
||||||
|
|
||||||
|
# ==================== Attention Mask ====================
|
||||||
|
self.mask = None
|
||||||
|
self.memory_mask = None
|
||||||
|
|
||||||
|
if causal_attn:
|
||||||
|
# 因果mask:确保只关注左侧
|
||||||
|
sz = T
|
||||||
|
mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
|
||||||
|
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
|
||||||
|
self.register_buffer("mask", mask)
|
||||||
|
|
||||||
|
if obs_as_cond:
|
||||||
|
# 交叉注意力mask
|
||||||
|
S = T_cond
|
||||||
|
t, s = torch.meshgrid(
|
||||||
|
torch.arange(T),
|
||||||
|
torch.arange(S),
|
||||||
|
indexing='ij'
|
||||||
|
)
|
||||||
|
mask = t >= (s - 1)
|
||||||
|
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
|
||||||
|
self.register_buffer('memory_mask', mask)
|
||||||
|
|
||||||
|
# ==================== Decoder ====================
|
||||||
|
if not self.encoder_only:
|
||||||
|
decoder_layer = nn.TransformerDecoderLayer(
|
||||||
|
d_model=n_emb,
|
||||||
|
nhead=n_head,
|
||||||
|
dim_feedforward=4 * n_emb,
|
||||||
|
dropout=p_drop_attn,
|
||||||
|
activation='gelu',
|
||||||
|
batch_first=True,
|
||||||
|
norm_first=True
|
||||||
|
)
|
||||||
|
self.decoder = nn.TransformerDecoder(
|
||||||
|
decoder_layer=decoder_layer,
|
||||||
|
num_layers=n_layer
|
||||||
|
)
|
||||||
|
|
||||||
|
# ==================== 输出头 ====================
|
||||||
|
self.ln_f = nn.LayerNorm(n_emb)
|
||||||
|
self.head = nn.Linear(n_emb, output_dim)
|
||||||
|
|
||||||
|
# ==================== 初始化 ====================
|
||||||
|
self.apply(self._init_weights)
|
||||||
|
|
||||||
|
# 打印参数量
|
||||||
|
total_params = sum(p.numel() for p in self.parameters())
|
||||||
|
print(f"Transformer1D parameters: {total_params:,}")
|
||||||
|
|
||||||
|
def _init_weights(self, module):
|
||||||
|
"""初始化权重"""
|
||||||
|
if isinstance(module, (nn.Linear, nn.Embedding)):
|
||||||
|
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
||||||
|
if isinstance(module, nn.Linear) and module.bias is not None:
|
||||||
|
torch.nn.init.zeros_(module.bias)
|
||||||
|
elif isinstance(module, nn.MultiheadAttention):
|
||||||
|
# MultiheadAttention的权重初始化
|
||||||
|
for name in ['in_proj_weight', 'q_proj_weight', 'k_proj_weight', 'v_proj_weight']:
|
||||||
|
weight = getattr(module, name, None)
|
||||||
|
if weight is not None:
|
||||||
|
torch.nn.init.normal_(weight, mean=0.0, std=0.02)
|
||||||
|
|
||||||
|
for name in ['in_proj_bias', 'bias_k', 'bias_v']:
|
||||||
|
bias = getattr(module, name, None)
|
||||||
|
if bias is not None:
|
||||||
|
torch.nn.init.zeros_(bias)
|
||||||
|
elif isinstance(module, nn.LayerNorm):
|
||||||
|
torch.nn.init.zeros_(module.bias)
|
||||||
|
torch.nn.init.ones_(module.weight)
|
||||||
|
elif isinstance(module, Transformer1D):
|
||||||
|
# 位置编码初始化
|
||||||
|
torch.nn.init.normal_(self.pos_emb, mean=0.0, std=0.02)
|
||||||
|
if self.cond_pos_emb is not None:
|
||||||
|
torch.nn.init.normal_(self.cond_pos_emb, mean=0.0, std=0.02)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
sample: torch.Tensor,
|
||||||
|
timestep: torch.Tensor,
|
||||||
|
cond: Optional[torch.Tensor] = None,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
前向传播
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sample: (B, T, input_dim) 输入序列(加噪动作)
|
||||||
|
timestep: (B,) 时间步
|
||||||
|
cond: (B, T', cond_dim) 条件序列(观测特征)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(B, T, output_dim) 预测的噪声
|
||||||
|
"""
|
||||||
|
# ==================== 处理时间步 ====================
|
||||||
|
timesteps = timestep
|
||||||
|
if not torch.is_tensor(timesteps):
|
||||||
|
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
|
||||||
|
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
|
||||||
|
timesteps = timesteps[None].to(sample.device)
|
||||||
|
|
||||||
|
# 扩展到batch维度
|
||||||
|
timesteps = timesteps.expand(sample.shape[0])
|
||||||
|
time_emb = self.time_emb(timesteps).unsqueeze(1) # (B, 1, n_emb)
|
||||||
|
|
||||||
|
# ==================== 处理输入 ====================
|
||||||
|
input_emb = self.input_emb(sample) # (B, T, n_emb)
|
||||||
|
|
||||||
|
# ==================== Encoder-Decoder模式 ====================
|
||||||
|
if not self.encoder_only:
|
||||||
|
# --- Encoder: 处理条件 ---
|
||||||
|
cond_embeddings = time_emb
|
||||||
|
|
||||||
|
if self.obs_as_cond and cond is not None:
|
||||||
|
# 添加观测条件
|
||||||
|
cond_obs_emb = self.cond_obs_emb(cond) # (B, T_cond-1, n_emb)
|
||||||
|
cond_embeddings = torch.cat([cond_embeddings, cond_obs_emb], dim=1)
|
||||||
|
|
||||||
|
# 添加位置编码
|
||||||
|
tc = cond_embeddings.shape[1]
|
||||||
|
pos_emb = self.cond_pos_emb[:, :tc, :]
|
||||||
|
x = self.drop(cond_embeddings + pos_emb)
|
||||||
|
|
||||||
|
# 通过encoder
|
||||||
|
memory = self.encoder(x) # (B, T_cond, n_emb)
|
||||||
|
|
||||||
|
# --- Decoder: 预测噪声 ---
|
||||||
|
# 添加位置编码到输入
|
||||||
|
token_embeddings = input_emb
|
||||||
|
t = token_embeddings.shape[1]
|
||||||
|
pos_emb = self.pos_emb[:, :t, :]
|
||||||
|
x = self.drop(token_embeddings + pos_emb)
|
||||||
|
|
||||||
|
# Cross-Attention: Query来自输入,Key/Value来自memory
|
||||||
|
x = self.decoder(
|
||||||
|
tgt=x,
|
||||||
|
memory=memory,
|
||||||
|
tgt_mask=self.mask,
|
||||||
|
memory_mask=self.memory_mask
|
||||||
|
)
|
||||||
|
|
||||||
|
# ==================== Encoder-Only模式 ====================
|
||||||
|
else:
|
||||||
|
# BERT风格:时间步作为特殊token
|
||||||
|
token_embeddings = torch.cat([time_emb, input_emb], dim=1)
|
||||||
|
t = token_embeddings.shape[1]
|
||||||
|
pos_emb = self.pos_emb[:, :t, :]
|
||||||
|
x = self.drop(token_embeddings + pos_emb)
|
||||||
|
|
||||||
|
x = self.encoder(src=x, mask=self.mask)
|
||||||
|
x = x[:, 1:, :] # 移除时间步token
|
||||||
|
|
||||||
|
# ==================== 输出头 ====================
|
||||||
|
x = self.ln_f(x)
|
||||||
|
x = self.head(x) # (B, T, output_dim)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# 便捷函数:创建Transformer1D模型
|
||||||
|
# ============================================================================
|
||||||
|
def create_transformer1d(
|
||||||
|
input_dim: int,
|
||||||
|
output_dim: int,
|
||||||
|
horizon: int,
|
||||||
|
n_obs_steps: int,
|
||||||
|
cond_dim: int,
|
||||||
|
n_layer: int = 8,
|
||||||
|
n_head: int = 8,
|
||||||
|
n_emb: int = 256,
|
||||||
|
**kwargs
|
||||||
|
) -> Transformer1D:
|
||||||
|
"""
|
||||||
|
创建Transformer1D模型的便捷函数
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_dim: 输入动作维度
|
||||||
|
output_dim: 输出动作维度
|
||||||
|
horizon: 预测horizon
|
||||||
|
n_obs_steps: 观测步数
|
||||||
|
cond_dim: 条件维度
|
||||||
|
n_layer: Transformer层数
|
||||||
|
n_head: 注意力头数
|
||||||
|
n_emb: 嵌入维度
|
||||||
|
**kwargs: 其他参数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Transformer1D模型
|
||||||
|
"""
|
||||||
|
model = Transformer1D(
|
||||||
|
input_dim=input_dim,
|
||||||
|
output_dim=output_dim,
|
||||||
|
horizon=horizon,
|
||||||
|
n_obs_steps=n_obs_steps,
|
||||||
|
cond_dim=cond_dim,
|
||||||
|
n_layer=n_layer,
|
||||||
|
n_head=n_head,
|
||||||
|
n_emb=n_emb,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
print("=" * 80)
|
||||||
|
print("Testing Transformer1D")
|
||||||
|
print("=" * 80)
|
||||||
|
|
||||||
|
# 配置
|
||||||
|
B = 4
|
||||||
|
T = 16
|
||||||
|
action_dim = 16
|
||||||
|
obs_horizon = 2
|
||||||
|
cond_dim = 416 # vision + state特征维度
|
||||||
|
|
||||||
|
# 创建模型
|
||||||
|
model = Transformer1D(
|
||||||
|
input_dim=action_dim,
|
||||||
|
output_dim=action_dim,
|
||||||
|
horizon=T,
|
||||||
|
n_obs_steps=obs_horizon,
|
||||||
|
cond_dim=cond_dim,
|
||||||
|
n_layer=4,
|
||||||
|
n_head=8,
|
||||||
|
n_emb=256,
|
||||||
|
causal_attn=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# 测试前向传播
|
||||||
|
sample = torch.randn(B, T, action_dim)
|
||||||
|
timestep = torch.randint(0, 100, (B,))
|
||||||
|
cond = torch.randn(B, obs_horizon, cond_dim)
|
||||||
|
|
||||||
|
output = model(sample, timestep, cond)
|
||||||
|
|
||||||
|
print(f"\n输入:")
|
||||||
|
print(f" sample: {sample.shape}")
|
||||||
|
print(f" timestep: {timestep.shape}")
|
||||||
|
print(f" cond: {cond.shape}")
|
||||||
|
print(f"\n输出:")
|
||||||
|
print(f" output: {output.shape}")
|
||||||
|
print(f"\n✅ 测试通过!")
|
||||||
166
test_transformer_head.py
Normal file
166
test_transformer_head.py
Normal file
@@ -0,0 +1,166 @@
|
|||||||
|
"""
|
||||||
|
测试Transformer1D Head
|
||||||
|
|
||||||
|
验证:
|
||||||
|
1. 模型初始化
|
||||||
|
2. 前向传播
|
||||||
|
3. 与VLAAgent集成
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import sys
|
||||||
|
sys.path.append('.')
|
||||||
|
|
||||||
|
def test_transformer_standalone():
|
||||||
|
"""测试独立的Transformer1D模型"""
|
||||||
|
print("=" * 80)
|
||||||
|
print("测试1: Transformer1D 独立模型")
|
||||||
|
print("=" * 80)
|
||||||
|
|
||||||
|
from roboimi.vla.models.heads.transformer1d import Transformer1D
|
||||||
|
|
||||||
|
# 配置
|
||||||
|
B = 4
|
||||||
|
T = 16
|
||||||
|
action_dim = 16
|
||||||
|
obs_horizon = 2
|
||||||
|
# 注意:Transformer的cond_dim是指每步条件的维度,不是总维度
|
||||||
|
# cond: (B, obs_horizon, cond_dim_per_step)
|
||||||
|
cond_dim_per_step = 208 # 64*3 + 16 = 192 + 16 = 208
|
||||||
|
|
||||||
|
# 创建模型
|
||||||
|
model = Transformer1D(
|
||||||
|
input_dim=action_dim,
|
||||||
|
output_dim=action_dim,
|
||||||
|
horizon=T,
|
||||||
|
n_obs_steps=obs_horizon,
|
||||||
|
cond_dim=cond_dim_per_step, # 每步的维度
|
||||||
|
n_layer=4,
|
||||||
|
n_head=8,
|
||||||
|
n_emb=256,
|
||||||
|
causal_attn=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# 测试前向传播
|
||||||
|
sample = torch.randn(B, T, action_dim)
|
||||||
|
timestep = torch.randint(0, 100, (B,))
|
||||||
|
cond = torch.randn(B, obs_horizon, cond_dim_per_step)
|
||||||
|
|
||||||
|
output = model(sample, timestep, cond)
|
||||||
|
|
||||||
|
print(f"\n✅ 输入:")
|
||||||
|
print(f" sample: {sample.shape}")
|
||||||
|
print(f" timestep: {timestep.shape}")
|
||||||
|
print(f" cond: {cond.shape}")
|
||||||
|
print(f"\n✅ 输出:")
|
||||||
|
print(f" output: {output.shape}")
|
||||||
|
|
||||||
|
assert output.shape == (B, T, action_dim), f"输出形状错误: {output.shape}"
|
||||||
|
print(f"\n✅ 测试通过!")
|
||||||
|
|
||||||
|
|
||||||
|
def test_transformer_with_agent():
|
||||||
|
"""测试Transformer与VLAAgent集成"""
|
||||||
|
print("\n" + "=" * 80)
|
||||||
|
print("测试2: Transformer + VLAAgent 集成")
|
||||||
|
print("=" * 80)
|
||||||
|
|
||||||
|
from roboimi.vla.agent import VLAAgent
|
||||||
|
from roboimi.vla.models.backbones.resnet_diffusion import ResNetDiffusionBackbone
|
||||||
|
from roboimi.vla.modules.encoders import IdentityStateEncoder, IdentityActionEncoder
|
||||||
|
from roboimi.vla.models.heads.transformer1d import Transformer1D
|
||||||
|
from omegaconf import OmegaConf
|
||||||
|
|
||||||
|
# 创建简单的配置
|
||||||
|
vision_backbone = ResNetDiffusionBackbone(
|
||||||
|
vision_backbone="resnet18",
|
||||||
|
pretrained_backbone_weights=None,
|
||||||
|
input_shape=(3, 84, 84),
|
||||||
|
use_group_norm=True,
|
||||||
|
spatial_softmax_num_keypoints=32,
|
||||||
|
freeze_backbone=False,
|
||||||
|
use_separate_rgb_encoder_per_camera=False,
|
||||||
|
num_cameras=1
|
||||||
|
)
|
||||||
|
|
||||||
|
state_encoder = IdentityStateEncoder()
|
||||||
|
action_encoder = IdentityActionEncoder()
|
||||||
|
|
||||||
|
# 创建Transformer head
|
||||||
|
action_dim = 16
|
||||||
|
obs_dim = 16
|
||||||
|
pred_horizon = 16
|
||||||
|
obs_horizon = 2
|
||||||
|
num_cams = 1
|
||||||
|
|
||||||
|
# 计算条件维度
|
||||||
|
single_cam_feat_dim = vision_backbone.output_dim # 64
|
||||||
|
# 每步的条件维度(不乘以obs_horizon)
|
||||||
|
per_step_cond_dim = single_cam_feat_dim * num_cams + obs_dim # 64 * 1 + 16 = 80
|
||||||
|
|
||||||
|
transformer_head = Transformer1D(
|
||||||
|
input_dim=action_dim,
|
||||||
|
output_dim=action_dim,
|
||||||
|
horizon=pred_horizon,
|
||||||
|
n_obs_steps=obs_horizon,
|
||||||
|
cond_dim=per_step_cond_dim, # 每步的维度,不是总维度!
|
||||||
|
n_layer=4,
|
||||||
|
n_head=8,
|
||||||
|
n_emb=128,
|
||||||
|
causal_attn=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# 创建Agent
|
||||||
|
agent = VLAAgent(
|
||||||
|
vision_backbone=vision_backbone,
|
||||||
|
state_encoder=state_encoder,
|
||||||
|
action_encoder=action_encoder,
|
||||||
|
head=transformer_head,
|
||||||
|
action_dim=action_dim,
|
||||||
|
obs_dim=obs_dim,
|
||||||
|
pred_horizon=pred_horizon,
|
||||||
|
obs_horizon=obs_horizon,
|
||||||
|
diffusion_steps=100,
|
||||||
|
inference_steps=10,
|
||||||
|
num_cams=num_cams,
|
||||||
|
dataset_stats=None,
|
||||||
|
normalization_type='min_max',
|
||||||
|
num_action_steps=8,
|
||||||
|
head_type='transformer'
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"\n✅ VLAAgent with Transformer创建成功")
|
||||||
|
print(f" head_type: {agent.head_type}")
|
||||||
|
print(f" 参数量: {sum(p.numel() for p in agent.parameters()):,}")
|
||||||
|
|
||||||
|
# 测试前向传播
|
||||||
|
B = 2
|
||||||
|
batch = {
|
||||||
|
'images': {'cam0': torch.randn(B, obs_horizon, 3, 84, 84)},
|
||||||
|
'qpos': torch.randn(B, obs_horizon, obs_dim),
|
||||||
|
'action': torch.randn(B, pred_horizon, action_dim)
|
||||||
|
}
|
||||||
|
|
||||||
|
loss = agent.compute_loss(batch)
|
||||||
|
print(f"\n✅ 训练loss: {loss.item():.4f}")
|
||||||
|
|
||||||
|
# 测试推理
|
||||||
|
agent.eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
actions = agent.predict_action(batch['images'], batch['qpos'])
|
||||||
|
print(f"✅ 推理输出shape: {actions.shape}")
|
||||||
|
|
||||||
|
print(f"\n✅ 集成测试通过!")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
try:
|
||||||
|
test_transformer_standalone()
|
||||||
|
test_transformer_with_agent()
|
||||||
|
print("\n" + "=" * 80)
|
||||||
|
print("🎉 所有测试通过!")
|
||||||
|
print("=" * 80)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\n❌ 测试失败: {e}")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
Reference in New Issue
Block a user