feat(vla): align transformer training stack and rollout validation
This commit is contained in:
@@ -178,12 +178,18 @@ class ResNetDiffusionBackbone(VLABackbone):
|
||||
spatial_softmax_num_keypoints: int = 32,
|
||||
use_separate_rgb_encoder_per_camera: bool = False, # 新增:是否为每个摄像头使用独立编码器
|
||||
num_cameras: int = 1, # 新增:摄像头数量(仅在独立编码器模式下使用)
|
||||
camera_names: Optional[Tuple[str, ...]] = None, # 显式相机顺序
|
||||
freeze_backbone: bool = True, # 新增:是否冻结ResNet backbone(推荐True)
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.use_separate_rgb_encoder_per_camera = use_separate_rgb_encoder_per_camera
|
||||
self.num_cameras = num_cameras
|
||||
self.camera_names = tuple(camera_names) if camera_names is not None else None
|
||||
if self.camera_names is not None and len(self.camera_names) != self.num_cameras:
|
||||
raise ValueError(
|
||||
f"camera_names 长度({len(self.camera_names)})与 num_cameras({self.num_cameras})不一致"
|
||||
)
|
||||
|
||||
if use_separate_rgb_encoder_per_camera:
|
||||
# 独立编码器模式:为每个摄像头创建独立的编码器
|
||||
@@ -217,6 +223,22 @@ class ResNetDiffusionBackbone(VLABackbone):
|
||||
)
|
||||
self.feature_dim = self.rgb_encoder.feature_dim
|
||||
|
||||
def _ordered_camera_names(self, images) -> Tuple[str, ...]:
|
||||
if self.camera_names is None:
|
||||
camera_names = tuple(sorted(images.keys()))
|
||||
if len(camera_names) != self.num_cameras:
|
||||
raise ValueError(
|
||||
f"图像输入相机数量({len(camera_names)})与 num_cameras({self.num_cameras})不一致"
|
||||
)
|
||||
return camera_names
|
||||
|
||||
missing = [cam_name for cam_name in self.camera_names if cam_name not in images]
|
||||
if missing:
|
||||
raise ValueError(
|
||||
f"图像输入缺少必需相机。missing={missing}, expected={list(self.camera_names)}"
|
||||
)
|
||||
return self.camera_names
|
||||
|
||||
def forward(self, images):
|
||||
"""
|
||||
Args:
|
||||
@@ -228,7 +250,7 @@ class ResNetDiffusionBackbone(VLABackbone):
|
||||
"""
|
||||
any_tensor = next(iter(images.values()))
|
||||
B, T = any_tensor.shape[:2]
|
||||
cam_names = sorted(images.keys())
|
||||
cam_names = self._ordered_camera_names(images)
|
||||
|
||||
if self.use_separate_rgb_encoder_per_camera:
|
||||
# 独立编码器模式:每个摄像头使用对应的编码器
|
||||
@@ -236,7 +258,7 @@ class ResNetDiffusionBackbone(VLABackbone):
|
||||
for cam_idx, cam_name in enumerate(cam_names):
|
||||
img = images[cam_name]
|
||||
encoder = self.rgb_encoder[cam_idx]
|
||||
features = encoder.forward_single_image(img.view(B * T, *img.shape[2:]))
|
||||
features = encoder.forward_single_image(img.reshape(B * T, *img.shape[2:]))
|
||||
features_all.append(features)
|
||||
return torch.cat(features_all, dim=1).view(B, T, -1)
|
||||
else:
|
||||
@@ -244,7 +266,7 @@ class ResNetDiffusionBackbone(VLABackbone):
|
||||
features_all = []
|
||||
for cam_name in cam_names:
|
||||
img = images[cam_name]
|
||||
features = self.rgb_encoder.forward_single_image(img.view(B * T, *img.shape[2:]))
|
||||
features = self.rgb_encoder.forward_single_image(img.reshape(B * T, *img.shape[2:]))
|
||||
features_all.append(features)
|
||||
return torch.cat(features_all, dim=1).view(B, T, -1)
|
||||
|
||||
@@ -369,4 +391,4 @@ if __name__ == "__main__":
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("🎉 All tests completed successfully!")
|
||||
print("=" * 60)
|
||||
print("=" * 60)
|
||||
|
||||
@@ -1,19 +1,35 @@
|
||||
"""
|
||||
Transformer-based Diffusion Policy Head
|
||||
"""Transformer-based diffusion head aligned with diffusion_policy's TransformerForDiffusion."""
|
||||
|
||||
使用Transformer架构(Encoder-Decoder)替代UNet进行噪声预测。
|
||||
支持通过Cross-Attention注入全局条件(观测特征)。
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import math
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from typing import Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ModuleAttrMixin(nn.Module):
|
||||
"""Minimal local copy of diffusion_policy's ModuleAttrMixin for state-dict parity."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self._dummy_variable = nn.Parameter()
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return next(iter(self.parameters())).device
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return next(iter(self.parameters())).dtype
|
||||
|
||||
|
||||
class SinusoidalPosEmb(nn.Module):
|
||||
"""正弦位置编码(用于时间步嵌入)"""
|
||||
def __init__(self, dim: int):
|
||||
def __init__(self, dim: int) -> None:
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
|
||||
@@ -27,35 +43,13 @@ class SinusoidalPosEmb(nn.Module):
|
||||
return emb
|
||||
|
||||
|
||||
class Transformer1D(nn.Module):
|
||||
"""
|
||||
Transformer-based 1D Diffusion Model
|
||||
|
||||
使用Encoder-Decoder架构:
|
||||
- Encoder: 处理条件(观测 + 时间步)
|
||||
- Decoder: 通过Cross-Attention预测噪声
|
||||
|
||||
Args:
|
||||
input_dim: 输入动作维度
|
||||
output_dim: 输出动作维度
|
||||
horizon: 预测horizon长度
|
||||
n_obs_steps: 观测步数
|
||||
cond_dim: 条件维度
|
||||
n_layer: Transformer层数
|
||||
n_head: 注意力头数
|
||||
n_emb: 嵌入维度
|
||||
p_drop_emb: Embedding dropout
|
||||
p_drop_attn: Attention dropout
|
||||
causal_attn: 是否使用因果注意力(自回归)
|
||||
n_cond_layers: Encoder层数(0表示使用MLP)
|
||||
"""
|
||||
|
||||
class Transformer1D(ModuleAttrMixin):
|
||||
def __init__(
|
||||
self,
|
||||
input_dim: int,
|
||||
output_dim: int,
|
||||
horizon: int,
|
||||
n_obs_steps: int = None,
|
||||
n_obs_steps: Optional[int] = None,
|
||||
cond_dim: int = 0,
|
||||
n_layer: int = 8,
|
||||
n_head: int = 8,
|
||||
@@ -63,57 +57,42 @@ class Transformer1D(nn.Module):
|
||||
p_drop_emb: float = 0.1,
|
||||
p_drop_attn: float = 0.1,
|
||||
causal_attn: bool = False,
|
||||
time_as_cond: bool = True,
|
||||
obs_as_cond: bool = False,
|
||||
n_cond_layers: int = 0
|
||||
):
|
||||
n_cond_layers: int = 0,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
# 计算序列长度
|
||||
if n_obs_steps is None:
|
||||
n_obs_steps = horizon
|
||||
|
||||
T = horizon
|
||||
T_cond = 1 # 时间步token数量
|
||||
|
||||
# 确定是否使用观测作为条件
|
||||
T_cond = 1
|
||||
if not time_as_cond:
|
||||
T += 1
|
||||
T_cond -= 1
|
||||
obs_as_cond = cond_dim > 0
|
||||
if obs_as_cond:
|
||||
assert time_as_cond
|
||||
T_cond += n_obs_steps
|
||||
|
||||
# 保存配置
|
||||
self.T = T
|
||||
self.T_cond = T_cond
|
||||
self.horizon = horizon
|
||||
self.obs_as_cond = obs_as_cond
|
||||
self.input_dim = input_dim
|
||||
self.output_dim = output_dim
|
||||
|
||||
# ==================== 输入嵌入 ====================
|
||||
self.input_emb = nn.Linear(input_dim, n_emb)
|
||||
self.pos_emb = nn.Parameter(torch.zeros(1, T, n_emb))
|
||||
self.drop = nn.Dropout(p_drop_emb)
|
||||
|
||||
# ==================== 条件编码 ====================
|
||||
# 时间步嵌入
|
||||
self.time_emb = SinusoidalPosEmb(n_emb)
|
||||
|
||||
# 观测条件嵌入(可选)
|
||||
self.cond_obs_emb = None
|
||||
if obs_as_cond:
|
||||
self.cond_obs_emb = nn.Linear(cond_dim, n_emb)
|
||||
|
||||
# 条件位置编码
|
||||
self.cond_pos_emb = None
|
||||
self.encoder = None
|
||||
self.decoder = None
|
||||
encoder_only = False
|
||||
|
||||
if T_cond > 0:
|
||||
self.cond_pos_emb = nn.Parameter(torch.zeros(1, T_cond, n_emb))
|
||||
|
||||
# ==================== Encoder ====================
|
||||
self.encoder = None
|
||||
self.encoder_only = False
|
||||
|
||||
if T_cond > 0:
|
||||
if n_cond_layers > 0:
|
||||
# 使用Transformer Encoder
|
||||
encoder_layer = nn.TransformerEncoderLayer(
|
||||
d_model=n_emb,
|
||||
nhead=n_head,
|
||||
@@ -121,61 +100,19 @@ class Transformer1D(nn.Module):
|
||||
dropout=p_drop_attn,
|
||||
activation='gelu',
|
||||
batch_first=True,
|
||||
norm_first=True # Pre-LN更稳定
|
||||
norm_first=True,
|
||||
)
|
||||
self.encoder = nn.TransformerEncoder(
|
||||
encoder_layer=encoder_layer,
|
||||
num_layers=n_cond_layers
|
||||
num_layers=n_cond_layers,
|
||||
)
|
||||
else:
|
||||
# 使用简单的MLP
|
||||
self.encoder = nn.Sequential(
|
||||
nn.Linear(n_emb, 4 * n_emb),
|
||||
nn.Mish(),
|
||||
nn.Linear(4 * n_emb, n_emb)
|
||||
nn.Linear(4 * n_emb, n_emb),
|
||||
)
|
||||
else:
|
||||
# Encoder-only模式(BERT风格)
|
||||
self.encoder_only = True
|
||||
encoder_layer = nn.TransformerEncoderLayer(
|
||||
d_model=n_emb,
|
||||
nhead=n_head,
|
||||
dim_feedforward=4 * n_emb,
|
||||
dropout=p_drop_attn,
|
||||
activation='gelu',
|
||||
batch_first=True,
|
||||
norm_first=True
|
||||
)
|
||||
self.encoder = nn.TransformerEncoder(
|
||||
encoder_layer=encoder_layer,
|
||||
num_layers=n_layer
|
||||
)
|
||||
|
||||
# ==================== Attention Mask ====================
|
||||
self.mask = None
|
||||
self.memory_mask = None
|
||||
|
||||
if causal_attn:
|
||||
# 因果mask:确保只关注左侧
|
||||
sz = T
|
||||
mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
|
||||
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
|
||||
self.register_buffer("mask", mask)
|
||||
|
||||
if obs_as_cond:
|
||||
# 交叉注意力mask
|
||||
S = T_cond
|
||||
t, s = torch.meshgrid(
|
||||
torch.arange(T),
|
||||
torch.arange(S),
|
||||
indexing='ij'
|
||||
)
|
||||
mask = t >= (s - 1)
|
||||
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
|
||||
self.register_buffer('memory_mask', mask)
|
||||
|
||||
# ==================== Decoder ====================
|
||||
if not self.encoder_only:
|
||||
decoder_layer = nn.TransformerDecoderLayer(
|
||||
d_model=n_emb,
|
||||
nhead=n_head,
|
||||
@@ -183,136 +120,199 @@ class Transformer1D(nn.Module):
|
||||
dropout=p_drop_attn,
|
||||
activation='gelu',
|
||||
batch_first=True,
|
||||
norm_first=True
|
||||
norm_first=True,
|
||||
)
|
||||
self.decoder = nn.TransformerDecoder(
|
||||
decoder_layer=decoder_layer,
|
||||
num_layers=n_layer
|
||||
num_layers=n_layer,
|
||||
)
|
||||
else:
|
||||
encoder_only = True
|
||||
encoder_layer = nn.TransformerEncoderLayer(
|
||||
d_model=n_emb,
|
||||
nhead=n_head,
|
||||
dim_feedforward=4 * n_emb,
|
||||
dropout=p_drop_attn,
|
||||
activation='gelu',
|
||||
batch_first=True,
|
||||
norm_first=True,
|
||||
)
|
||||
self.encoder = nn.TransformerEncoder(
|
||||
encoder_layer=encoder_layer,
|
||||
num_layers=n_layer,
|
||||
)
|
||||
|
||||
# ==================== 输出头 ====================
|
||||
if causal_attn:
|
||||
sz = T
|
||||
mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
|
||||
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
|
||||
self.register_buffer('mask', mask)
|
||||
|
||||
if time_as_cond and obs_as_cond:
|
||||
S = T_cond
|
||||
t, s = torch.meshgrid(torch.arange(T), torch.arange(S), indexing='ij')
|
||||
mask = t >= (s - 1)
|
||||
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
|
||||
self.register_buffer('memory_mask', mask)
|
||||
else:
|
||||
self.memory_mask = None
|
||||
else:
|
||||
self.mask = None
|
||||
self.memory_mask = None
|
||||
|
||||
self.ln_f = nn.LayerNorm(n_emb)
|
||||
self.head = nn.Linear(n_emb, output_dim)
|
||||
|
||||
# ==================== 初始化 ====================
|
||||
self.apply(self._init_weights)
|
||||
self.T = T
|
||||
self.T_cond = T_cond
|
||||
self.horizon = horizon
|
||||
self.time_as_cond = time_as_cond
|
||||
self.obs_as_cond = obs_as_cond
|
||||
self.encoder_only = encoder_only
|
||||
|
||||
# 打印参数量
|
||||
total_params = sum(p.numel() for p in self.parameters())
|
||||
print(f"Transformer1D parameters: {total_params:,}")
|
||||
self.apply(self._init_weights)
|
||||
logger.info('number of parameters: %e', sum(p.numel() for p in self.parameters()))
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""初始化权重"""
|
||||
ignore_types = (
|
||||
nn.Dropout,
|
||||
SinusoidalPosEmb,
|
||||
nn.TransformerEncoderLayer,
|
||||
nn.TransformerDecoderLayer,
|
||||
nn.TransformerEncoder,
|
||||
nn.TransformerDecoder,
|
||||
nn.ModuleList,
|
||||
nn.Mish,
|
||||
nn.Sequential,
|
||||
)
|
||||
if isinstance(module, (nn.Linear, nn.Embedding)):
|
||||
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
||||
if isinstance(module, nn.Linear) and module.bias is not None:
|
||||
torch.nn.init.zeros_(module.bias)
|
||||
elif isinstance(module, nn.MultiheadAttention):
|
||||
# MultiheadAttention的权重初始化
|
||||
for name in ['in_proj_weight', 'q_proj_weight', 'k_proj_weight', 'v_proj_weight']:
|
||||
weight = getattr(module, name, None)
|
||||
for name in ('in_proj_weight', 'q_proj_weight', 'k_proj_weight', 'v_proj_weight'):
|
||||
weight = getattr(module, name)
|
||||
if weight is not None:
|
||||
torch.nn.init.normal_(weight, mean=0.0, std=0.02)
|
||||
|
||||
for name in ['in_proj_bias', 'bias_k', 'bias_v']:
|
||||
bias = getattr(module, name, None)
|
||||
for name in ('in_proj_bias', 'bias_k', 'bias_v'):
|
||||
bias = getattr(module, name)
|
||||
if bias is not None:
|
||||
torch.nn.init.zeros_(bias)
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
torch.nn.init.zeros_(module.bias)
|
||||
torch.nn.init.ones_(module.weight)
|
||||
elif isinstance(module, Transformer1D):
|
||||
# 位置编码初始化
|
||||
torch.nn.init.normal_(self.pos_emb, mean=0.0, std=0.02)
|
||||
if self.cond_pos_emb is not None:
|
||||
torch.nn.init.normal_(self.cond_pos_emb, mean=0.0, std=0.02)
|
||||
torch.nn.init.normal_(module.pos_emb, mean=0.0, std=0.02)
|
||||
if module.cond_obs_emb is not None:
|
||||
torch.nn.init.normal_(module.cond_pos_emb, mean=0.0, std=0.02)
|
||||
elif isinstance(module, ignore_types):
|
||||
pass
|
||||
else:
|
||||
raise RuntimeError(f'Unaccounted module {module}')
|
||||
|
||||
def get_optim_groups(self, weight_decay: float = 1e-3):
|
||||
decay = set()
|
||||
no_decay = set()
|
||||
whitelist_weight_modules = (torch.nn.Linear, torch.nn.MultiheadAttention)
|
||||
blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
|
||||
|
||||
for module_name, module in self.named_modules():
|
||||
for param_name, _ in module.named_parameters():
|
||||
full_param_name = f'{module_name}.{param_name}' if module_name else param_name
|
||||
|
||||
if param_name.endswith('bias'):
|
||||
no_decay.add(full_param_name)
|
||||
elif param_name.startswith('bias'):
|
||||
no_decay.add(full_param_name)
|
||||
elif param_name.endswith('weight') and isinstance(module, whitelist_weight_modules):
|
||||
decay.add(full_param_name)
|
||||
elif param_name.endswith('weight') and isinstance(module, blacklist_weight_modules):
|
||||
no_decay.add(full_param_name)
|
||||
|
||||
no_decay.add('pos_emb')
|
||||
no_decay.add('_dummy_variable')
|
||||
if self.cond_pos_emb is not None:
|
||||
no_decay.add('cond_pos_emb')
|
||||
|
||||
param_dict = {name: param for name, param in self.named_parameters()}
|
||||
inter_params = decay & no_decay
|
||||
union_params = decay | no_decay
|
||||
assert len(inter_params) == 0, f'parameters {inter_params} made it into both decay/no_decay sets!'
|
||||
assert len(param_dict.keys() - union_params) == 0, (
|
||||
f'parameters {param_dict.keys() - union_params} were not separated into either decay/no_decay sets!'
|
||||
)
|
||||
|
||||
return [
|
||||
{
|
||||
'params': [param_dict[name] for name in sorted(decay)],
|
||||
'weight_decay': weight_decay,
|
||||
},
|
||||
{
|
||||
'params': [param_dict[name] for name in sorted(no_decay)],
|
||||
'weight_decay': 0.0,
|
||||
},
|
||||
]
|
||||
|
||||
def configure_optimizers(
|
||||
self,
|
||||
learning_rate: float = 1e-4,
|
||||
weight_decay: float = 1e-3,
|
||||
betas: Tuple[float, float] = (0.9, 0.95),
|
||||
):
|
||||
optim_groups = self.get_optim_groups(weight_decay=weight_decay)
|
||||
return torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.Tensor,
|
||||
timestep: torch.Tensor,
|
||||
timestep: Union[torch.Tensor, float, int],
|
||||
cond: Optional[torch.Tensor] = None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
前向传播
|
||||
|
||||
Args:
|
||||
sample: (B, T, input_dim) 输入序列(加噪动作)
|
||||
timestep: (B,) 时间步
|
||||
cond: (B, T', cond_dim) 条件序列(观测特征)
|
||||
|
||||
Returns:
|
||||
(B, T, output_dim) 预测的噪声
|
||||
"""
|
||||
# ==================== 处理时间步 ====================
|
||||
timesteps = timestep
|
||||
if not torch.is_tensor(timesteps):
|
||||
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
|
||||
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
|
||||
timesteps = timesteps[None].to(sample.device)
|
||||
|
||||
# 扩展到batch维度
|
||||
timesteps = timesteps.expand(sample.shape[0])
|
||||
time_emb = self.time_emb(timesteps).unsqueeze(1) # (B, 1, n_emb)
|
||||
time_emb = self.time_emb(timesteps).unsqueeze(1)
|
||||
|
||||
# ==================== 处理输入 ====================
|
||||
input_emb = self.input_emb(sample) # (B, T, n_emb)
|
||||
input_emb = self.input_emb(sample)
|
||||
|
||||
# ==================== Encoder-Decoder模式 ====================
|
||||
if not self.encoder_only:
|
||||
# --- Encoder: 处理条件 ---
|
||||
if self.encoder_only:
|
||||
token_embeddings = torch.cat([time_emb, input_emb], dim=1)
|
||||
t = token_embeddings.shape[1]
|
||||
position_embeddings = self.pos_emb[:, :t, :]
|
||||
x = self.drop(token_embeddings + position_embeddings)
|
||||
x = self.encoder(src=x, mask=self.mask)
|
||||
x = x[:, 1:, :]
|
||||
else:
|
||||
cond_embeddings = time_emb
|
||||
|
||||
if self.obs_as_cond and cond is not None:
|
||||
# 添加观测条件
|
||||
cond_obs_emb = self.cond_obs_emb(cond) # (B, T_cond-1, n_emb)
|
||||
if self.obs_as_cond:
|
||||
cond_obs_emb = self.cond_obs_emb(cond)
|
||||
cond_embeddings = torch.cat([cond_embeddings, cond_obs_emb], dim=1)
|
||||
|
||||
# 添加位置编码
|
||||
tc = cond_embeddings.shape[1]
|
||||
pos_emb = self.cond_pos_emb[:, :tc, :]
|
||||
x = self.drop(cond_embeddings + pos_emb)
|
||||
position_embeddings = self.cond_pos_emb[:, :tc, :]
|
||||
x = self.drop(cond_embeddings + position_embeddings)
|
||||
memory = self.encoder(x)
|
||||
|
||||
# 通过encoder
|
||||
memory = self.encoder(x) # (B, T_cond, n_emb)
|
||||
|
||||
# --- Decoder: 预测噪声 ---
|
||||
# 添加位置编码到输入
|
||||
token_embeddings = input_emb
|
||||
t = token_embeddings.shape[1]
|
||||
pos_emb = self.pos_emb[:, :t, :]
|
||||
x = self.drop(token_embeddings + pos_emb)
|
||||
|
||||
# Cross-Attention: Query来自输入,Key/Value来自memory
|
||||
position_embeddings = self.pos_emb[:, :t, :]
|
||||
x = self.drop(token_embeddings + position_embeddings)
|
||||
x = self.decoder(
|
||||
tgt=x,
|
||||
memory=memory,
|
||||
tgt_mask=self.mask,
|
||||
memory_mask=self.memory_mask
|
||||
memory_mask=self.memory_mask,
|
||||
)
|
||||
|
||||
# ==================== Encoder-Only模式 ====================
|
||||
else:
|
||||
# BERT风格:时间步作为特殊token
|
||||
token_embeddings = torch.cat([time_emb, input_emb], dim=1)
|
||||
t = token_embeddings.shape[1]
|
||||
pos_emb = self.pos_emb[:, :t, :]
|
||||
x = self.drop(token_embeddings + pos_emb)
|
||||
|
||||
x = self.encoder(src=x, mask=self.mask)
|
||||
x = x[:, 1:, :] # 移除时间步token
|
||||
|
||||
# ==================== 输出头 ====================
|
||||
x = self.ln_f(x)
|
||||
x = self.head(x) # (B, T, output_dim)
|
||||
|
||||
x = self.head(x)
|
||||
return x
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 便捷函数:创建Transformer1D模型
|
||||
# ============================================================================
|
||||
def create_transformer1d(
|
||||
input_dim: int,
|
||||
output_dim: int,
|
||||
@@ -322,26 +322,9 @@ def create_transformer1d(
|
||||
n_layer: int = 8,
|
||||
n_head: int = 8,
|
||||
n_emb: int = 256,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
) -> Transformer1D:
|
||||
"""
|
||||
创建Transformer1D模型的便捷函数
|
||||
|
||||
Args:
|
||||
input_dim: 输入动作维度
|
||||
output_dim: 输出动作维度
|
||||
horizon: 预测horizon
|
||||
n_obs_steps: 观测步数
|
||||
cond_dim: 条件维度
|
||||
n_layer: Transformer层数
|
||||
n_head: 注意力头数
|
||||
n_emb: 嵌入维度
|
||||
**kwargs: 其他参数
|
||||
|
||||
Returns:
|
||||
Transformer1D模型
|
||||
"""
|
||||
model = Transformer1D(
|
||||
return Transformer1D(
|
||||
input_dim=input_dim,
|
||||
output_dim=output_dim,
|
||||
horizon=horizon,
|
||||
@@ -350,47 +333,5 @@ def create_transformer1d(
|
||||
n_layer=n_layer,
|
||||
n_head=n_head,
|
||||
n_emb=n_emb,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
)
|
||||
return model
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("=" * 80)
|
||||
print("Testing Transformer1D")
|
||||
print("=" * 80)
|
||||
|
||||
# 配置
|
||||
B = 4
|
||||
T = 16
|
||||
action_dim = 16
|
||||
obs_horizon = 2
|
||||
cond_dim = 416 # vision + state特征维度
|
||||
|
||||
# 创建模型
|
||||
model = Transformer1D(
|
||||
input_dim=action_dim,
|
||||
output_dim=action_dim,
|
||||
horizon=T,
|
||||
n_obs_steps=obs_horizon,
|
||||
cond_dim=cond_dim,
|
||||
n_layer=4,
|
||||
n_head=8,
|
||||
n_emb=256,
|
||||
causal_attn=False
|
||||
)
|
||||
|
||||
# 测试前向传播
|
||||
sample = torch.randn(B, T, action_dim)
|
||||
timestep = torch.randint(0, 100, (B,))
|
||||
cond = torch.randn(B, obs_horizon, cond_dim)
|
||||
|
||||
output = model(sample, timestep, cond)
|
||||
|
||||
print(f"\n输入:")
|
||||
print(f" sample: {sample.shape}")
|
||||
print(f" timestep: {timestep.shape}")
|
||||
print(f" cond: {cond.shape}")
|
||||
print(f"\n输出:")
|
||||
print(f" output: {output.shape}")
|
||||
print(f"\n✅ 测试通过!")
|
||||
|
||||
Reference in New Issue
Block a user