feat(vla): align transformer training stack and rollout validation

This commit is contained in:
Logic
2026-03-31 15:39:20 +08:00
parent 424c265823
commit d84bc6876e
25 changed files with 4043 additions and 706 deletions

View File

@@ -178,12 +178,18 @@ class ResNetDiffusionBackbone(VLABackbone):
spatial_softmax_num_keypoints: int = 32,
use_separate_rgb_encoder_per_camera: bool = False, # 新增:是否为每个摄像头使用独立编码器
num_cameras: int = 1, # 新增:摄像头数量(仅在独立编码器模式下使用)
camera_names: Optional[Tuple[str, ...]] = None, # 显式相机顺序
freeze_backbone: bool = True, # 新增是否冻结ResNet backbone推荐True
):
super().__init__()
self.use_separate_rgb_encoder_per_camera = use_separate_rgb_encoder_per_camera
self.num_cameras = num_cameras
self.camera_names = tuple(camera_names) if camera_names is not None else None
if self.camera_names is not None and len(self.camera_names) != self.num_cameras:
raise ValueError(
f"camera_names 长度({len(self.camera_names)})与 num_cameras({self.num_cameras})不一致"
)
if use_separate_rgb_encoder_per_camera:
# 独立编码器模式:为每个摄像头创建独立的编码器
@@ -217,6 +223,22 @@ class ResNetDiffusionBackbone(VLABackbone):
)
self.feature_dim = self.rgb_encoder.feature_dim
def _ordered_camera_names(self, images) -> Tuple[str, ...]:
    """Return the camera names of ``images`` in a deterministic order.

    If an explicit ``camera_names`` order was configured, validate that every
    required camera is present and return that configured order; otherwise
    fall back to sorted dict keys and validate the count against
    ``num_cameras``.

    Args:
        images: mapping of camera name -> image tensor (only the keys are
            inspected here).

    Returns:
        Tuple of camera names defining the feature concatenation order.

    Raises:
        ValueError: if the number of cameras does not match ``num_cameras``,
            or a configured camera name is missing from ``images``.
    """
    if self.camera_names is None:
        # No explicit order configured: sort the keys so the ordering is
        # stable across calls/processes, then sanity-check the count.
        camera_names = tuple(sorted(images.keys()))
        if len(camera_names) != self.num_cameras:
            raise ValueError(
                f"图像输入相机数量({len(camera_names)})与 num_cameras({self.num_cameras})不一致"
            )
        return camera_names
    # Explicit order configured: every configured camera must be present.
    missing = [cam_name for cam_name in self.camera_names if cam_name not in images]
    if missing:
        raise ValueError(
            f"图像输入缺少必需相机。missing={missing}, expected={list(self.camera_names)}"
        )
    return self.camera_names
def forward(self, images):
"""
Args:
@@ -228,7 +250,7 @@ class ResNetDiffusionBackbone(VLABackbone):
"""
any_tensor = next(iter(images.values()))
B, T = any_tensor.shape[:2]
cam_names = sorted(images.keys())
cam_names = self._ordered_camera_names(images)
if self.use_separate_rgb_encoder_per_camera:
# 独立编码器模式:每个摄像头使用对应的编码器
@@ -236,7 +258,7 @@ class ResNetDiffusionBackbone(VLABackbone):
for cam_idx, cam_name in enumerate(cam_names):
img = images[cam_name]
encoder = self.rgb_encoder[cam_idx]
features = encoder.forward_single_image(img.view(B * T, *img.shape[2:]))
features = encoder.forward_single_image(img.reshape(B * T, *img.shape[2:]))
features_all.append(features)
return torch.cat(features_all, dim=1).view(B, T, -1)
else:
@@ -244,7 +266,7 @@ class ResNetDiffusionBackbone(VLABackbone):
features_all = []
for cam_name in cam_names:
img = images[cam_name]
features = self.rgb_encoder.forward_single_image(img.view(B * T, *img.shape[2:]))
features = self.rgb_encoder.forward_single_image(img.reshape(B * T, *img.shape[2:]))
features_all.append(features)
return torch.cat(features_all, dim=1).view(B, T, -1)
@@ -369,4 +391,4 @@ if __name__ == "__main__":
print("\n" + "=" * 60)
print("🎉 All tests completed successfully!")
print("=" * 60)
print("=" * 60)

View File

@@ -1,19 +1,35 @@
"""
Transformer-based Diffusion Policy Head
"""Transformer-based diffusion head aligned with diffusion_policy's TransformerForDiffusion."""
使用Transformer架构Encoder-Decoder替代UNet进行噪声预测。
支持通过Cross-Attention注入全局条件观测特征
"""
from __future__ import annotations
import logging
import math
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
from typing import Optional
logger = logging.getLogger(__name__)
class ModuleAttrMixin(nn.Module):
    """Mixin exposing the ``device``/``dtype`` of a module's parameters.

    Registers an empty ``_dummy_variable`` parameter so ``parameters()`` is
    never empty and state-dict keys stay compatible with diffusion_policy's
    original ModuleAttrMixin.
    """

    def __init__(self) -> None:
        super().__init__()
        # Guarantees at least one parameter exists for device/dtype lookup.
        self._dummy_variable = nn.Parameter()

    @property
    def device(self):
        # Device of the first parameter; assumes all parameters are co-located.
        first_param = next(iter(self.parameters()))
        return first_param.device

    @property
    def dtype(self):
        first_param = next(iter(self.parameters()))
        return first_param.dtype
class SinusoidalPosEmb(nn.Module):
"""正弦位置编码(用于时间步嵌入)"""
def __init__(self, dim: int):
def __init__(self, dim: int) -> None:
super().__init__()
self.dim = dim
@@ -27,35 +43,13 @@ class SinusoidalPosEmb(nn.Module):
return emb
class Transformer1D(nn.Module):
"""
Transformer-based 1D Diffusion Model
使用Encoder-Decoder架构
- Encoder: 处理条件(观测 + 时间步)
- Decoder: 通过Cross-Attention预测噪声
Args:
input_dim: 输入动作维度
output_dim: 输出动作维度
horizon: 预测horizon长度
n_obs_steps: 观测步数
cond_dim: 条件维度
n_layer: Transformer层数
n_head: 注意力头数
n_emb: 嵌入维度
p_drop_emb: Embedding dropout
p_drop_attn: Attention dropout
causal_attn: 是否使用因果注意力(自回归)
n_cond_layers: Encoder层数0表示使用MLP
"""
class Transformer1D(ModuleAttrMixin):
def __init__(
self,
input_dim: int,
output_dim: int,
horizon: int,
n_obs_steps: int = None,
n_obs_steps: Optional[int] = None,
cond_dim: int = 0,
n_layer: int = 8,
n_head: int = 8,
@@ -63,57 +57,42 @@ class Transformer1D(nn.Module):
p_drop_emb: float = 0.1,
p_drop_attn: float = 0.1,
causal_attn: bool = False,
time_as_cond: bool = True,
obs_as_cond: bool = False,
n_cond_layers: int = 0
):
n_cond_layers: int = 0,
) -> None:
super().__init__()
# 计算序列长度
if n_obs_steps is None:
n_obs_steps = horizon
T = horizon
T_cond = 1 # 时间步token数量
# 确定是否使用观测作为条件
T_cond = 1
if not time_as_cond:
T += 1
T_cond -= 1
obs_as_cond = cond_dim > 0
if obs_as_cond:
assert time_as_cond
T_cond += n_obs_steps
# 保存配置
self.T = T
self.T_cond = T_cond
self.horizon = horizon
self.obs_as_cond = obs_as_cond
self.input_dim = input_dim
self.output_dim = output_dim
# ==================== 输入嵌入 ====================
self.input_emb = nn.Linear(input_dim, n_emb)
self.pos_emb = nn.Parameter(torch.zeros(1, T, n_emb))
self.drop = nn.Dropout(p_drop_emb)
# ==================== 条件编码 ====================
# 时间步嵌入
self.time_emb = SinusoidalPosEmb(n_emb)
# 观测条件嵌入(可选)
self.cond_obs_emb = None
if obs_as_cond:
self.cond_obs_emb = nn.Linear(cond_dim, n_emb)
# 条件位置编码
self.cond_pos_emb = None
self.encoder = None
self.decoder = None
encoder_only = False
if T_cond > 0:
self.cond_pos_emb = nn.Parameter(torch.zeros(1, T_cond, n_emb))
# ==================== Encoder ====================
self.encoder = None
self.encoder_only = False
if T_cond > 0:
if n_cond_layers > 0:
# 使用Transformer Encoder
encoder_layer = nn.TransformerEncoderLayer(
d_model=n_emb,
nhead=n_head,
@@ -121,61 +100,19 @@ class Transformer1D(nn.Module):
dropout=p_drop_attn,
activation='gelu',
batch_first=True,
norm_first=True # Pre-LN更稳定
norm_first=True,
)
self.encoder = nn.TransformerEncoder(
encoder_layer=encoder_layer,
num_layers=n_cond_layers
num_layers=n_cond_layers,
)
else:
# 使用简单的MLP
self.encoder = nn.Sequential(
nn.Linear(n_emb, 4 * n_emb),
nn.Mish(),
nn.Linear(4 * n_emb, n_emb)
nn.Linear(4 * n_emb, n_emb),
)
else:
# Encoder-only模式BERT风格
self.encoder_only = True
encoder_layer = nn.TransformerEncoderLayer(
d_model=n_emb,
nhead=n_head,
dim_feedforward=4 * n_emb,
dropout=p_drop_attn,
activation='gelu',
batch_first=True,
norm_first=True
)
self.encoder = nn.TransformerEncoder(
encoder_layer=encoder_layer,
num_layers=n_layer
)
# ==================== Attention Mask ====================
self.mask = None
self.memory_mask = None
if causal_attn:
# 因果mask确保只关注左侧
sz = T
mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
self.register_buffer("mask", mask)
if obs_as_cond:
# 交叉注意力mask
S = T_cond
t, s = torch.meshgrid(
torch.arange(T),
torch.arange(S),
indexing='ij'
)
mask = t >= (s - 1)
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
self.register_buffer('memory_mask', mask)
# ==================== Decoder ====================
if not self.encoder_only:
decoder_layer = nn.TransformerDecoderLayer(
d_model=n_emb,
nhead=n_head,
@@ -183,136 +120,199 @@ class Transformer1D(nn.Module):
dropout=p_drop_attn,
activation='gelu',
batch_first=True,
norm_first=True
norm_first=True,
)
self.decoder = nn.TransformerDecoder(
decoder_layer=decoder_layer,
num_layers=n_layer
num_layers=n_layer,
)
else:
encoder_only = True
encoder_layer = nn.TransformerEncoderLayer(
d_model=n_emb,
nhead=n_head,
dim_feedforward=4 * n_emb,
dropout=p_drop_attn,
activation='gelu',
batch_first=True,
norm_first=True,
)
self.encoder = nn.TransformerEncoder(
encoder_layer=encoder_layer,
num_layers=n_layer,
)
# ==================== 输出头 ====================
if causal_attn:
sz = T
mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
self.register_buffer('mask', mask)
if time_as_cond and obs_as_cond:
S = T_cond
t, s = torch.meshgrid(torch.arange(T), torch.arange(S), indexing='ij')
mask = t >= (s - 1)
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
self.register_buffer('memory_mask', mask)
else:
self.memory_mask = None
else:
self.mask = None
self.memory_mask = None
self.ln_f = nn.LayerNorm(n_emb)
self.head = nn.Linear(n_emb, output_dim)
# ==================== 初始化 ====================
self.apply(self._init_weights)
self.T = T
self.T_cond = T_cond
self.horizon = horizon
self.time_as_cond = time_as_cond
self.obs_as_cond = obs_as_cond
self.encoder_only = encoder_only
# 打印参数量
total_params = sum(p.numel() for p in self.parameters())
print(f"Transformer1D parameters: {total_params:,}")
self.apply(self._init_weights)
logger.info('number of parameters: %e', sum(p.numel() for p in self.parameters()))
def _init_weights(self, module):
"""初始化权重"""
ignore_types = (
nn.Dropout,
SinusoidalPosEmb,
nn.TransformerEncoderLayer,
nn.TransformerDecoderLayer,
nn.TransformerEncoder,
nn.TransformerDecoder,
nn.ModuleList,
nn.Mish,
nn.Sequential,
)
if isinstance(module, (nn.Linear, nn.Embedding)):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
if isinstance(module, nn.Linear) and module.bias is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, nn.MultiheadAttention):
# MultiheadAttention的权重初始化
for name in ['in_proj_weight', 'q_proj_weight', 'k_proj_weight', 'v_proj_weight']:
weight = getattr(module, name, None)
for name in ('in_proj_weight', 'q_proj_weight', 'k_proj_weight', 'v_proj_weight'):
weight = getattr(module, name)
if weight is not None:
torch.nn.init.normal_(weight, mean=0.0, std=0.02)
for name in ['in_proj_bias', 'bias_k', 'bias_v']:
bias = getattr(module, name, None)
for name in ('in_proj_bias', 'bias_k', 'bias_v'):
bias = getattr(module, name)
if bias is not None:
torch.nn.init.zeros_(bias)
elif isinstance(module, nn.LayerNorm):
torch.nn.init.zeros_(module.bias)
torch.nn.init.ones_(module.weight)
elif isinstance(module, Transformer1D):
# 位置编码初始化
torch.nn.init.normal_(self.pos_emb, mean=0.0, std=0.02)
if self.cond_pos_emb is not None:
torch.nn.init.normal_(self.cond_pos_emb, mean=0.0, std=0.02)
torch.nn.init.normal_(module.pos_emb, mean=0.0, std=0.02)
if module.cond_obs_emb is not None:
torch.nn.init.normal_(module.cond_pos_emb, mean=0.0, std=0.02)
elif isinstance(module, ignore_types):
pass
else:
raise RuntimeError(f'Unaccounted module {module}')
def get_optim_groups(self, weight_decay: float = 1e-3):
    """Partition parameters into weight-decay and no-weight-decay groups.

    Follows the minGPT/diffusion_policy convention: Linear/MultiheadAttention
    weights decay; all biases, LayerNorm/Embedding weights, and the learned
    positional embeddings do not. Asserts that every parameter is assigned to
    exactly one group.

    Args:
        weight_decay: decay coefficient applied to the decaying group.

    Returns:
        A two-element list of optimizer param-group dicts (decay / no-decay)
        suitable for ``torch.optim.AdamW``.
    """
    decay = set()
    no_decay = set()
    # Module types whose 'weight' parameters should (not) receive decay.
    whitelist_weight_modules = (torch.nn.Linear, torch.nn.MultiheadAttention)
    blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
    for module_name, module in self.named_modules():
        for param_name, _ in module.named_parameters():
            # Fully-qualified name as seen by self.named_parameters().
            full_param_name = f'{module_name}.{param_name}' if module_name else param_name
            if param_name.endswith('bias'):
                no_decay.add(full_param_name)
            elif param_name.startswith('bias'):
                # Covers MultiheadAttention's bias_k / bias_v parameters.
                no_decay.add(full_param_name)
            elif param_name.endswith('weight') and isinstance(module, whitelist_weight_modules):
                decay.add(full_param_name)
            elif param_name.endswith('weight') and isinstance(module, blacklist_weight_modules):
                no_decay.add(full_param_name)
    # Learned positional embeddings and the mixin's dummy parameter never
    # decay. NOTE(review): assumes these exact attribute names exist on the
    # model ('pos_emb', '_dummy_variable', optional 'cond_pos_emb').
    no_decay.add('pos_emb')
    no_decay.add('_dummy_variable')
    if self.cond_pos_emb is not None:
        no_decay.add('cond_pos_emb')
    param_dict = {name: param for name, param in self.named_parameters()}
    inter_params = decay & no_decay
    union_params = decay | no_decay
    # Every parameter must land in exactly one of the two groups.
    assert len(inter_params) == 0, f'parameters {inter_params} made it into both decay/no_decay sets!'
    assert len(param_dict.keys() - union_params) == 0, (
        f'parameters {param_dict.keys() - union_params} were not separated into either decay/no_decay sets!'
    )
    return [
        {
            'params': [param_dict[name] for name in sorted(decay)],
            'weight_decay': weight_decay,
        },
        {
            'params': [param_dict[name] for name in sorted(no_decay)],
            'weight_decay': 0.0,
        },
    ]
def configure_optimizers(
    self,
    learning_rate: float = 1e-4,
    weight_decay: float = 1e-3,
    betas: Tuple[float, float] = (0.9, 0.95),
):
    """Build an AdamW optimizer over the decay/no-decay parameter groups."""
    groups = self.get_optim_groups(weight_decay=weight_decay)
    optimizer = torch.optim.AdamW(groups, lr=learning_rate, betas=betas)
    return optimizer
def forward(
self,
sample: torch.Tensor,
timestep: torch.Tensor,
timestep: Union[torch.Tensor, float, int],
cond: Optional[torch.Tensor] = None,
**kwargs
**kwargs,
):
"""
前向传播
Args:
sample: (B, T, input_dim) 输入序列(加噪动作)
timestep: (B,) 时间步
cond: (B, T', cond_dim) 条件序列(观测特征)
Returns:
(B, T, output_dim) 预测的噪声
"""
# ==================== 处理时间步 ====================
timesteps = timestep
if not torch.is_tensor(timesteps):
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
# 扩展到batch维度
timesteps = timesteps.expand(sample.shape[0])
time_emb = self.time_emb(timesteps).unsqueeze(1) # (B, 1, n_emb)
time_emb = self.time_emb(timesteps).unsqueeze(1)
# ==================== 处理输入 ====================
input_emb = self.input_emb(sample) # (B, T, n_emb)
input_emb = self.input_emb(sample)
# ==================== Encoder-Decoder模式 ====================
if not self.encoder_only:
# --- Encoder: 处理条件 ---
if self.encoder_only:
token_embeddings = torch.cat([time_emb, input_emb], dim=1)
t = token_embeddings.shape[1]
position_embeddings = self.pos_emb[:, :t, :]
x = self.drop(token_embeddings + position_embeddings)
x = self.encoder(src=x, mask=self.mask)
x = x[:, 1:, :]
else:
cond_embeddings = time_emb
if self.obs_as_cond and cond is not None:
# 添加观测条件
cond_obs_emb = self.cond_obs_emb(cond) # (B, T_cond-1, n_emb)
if self.obs_as_cond:
cond_obs_emb = self.cond_obs_emb(cond)
cond_embeddings = torch.cat([cond_embeddings, cond_obs_emb], dim=1)
# 添加位置编码
tc = cond_embeddings.shape[1]
pos_emb = self.cond_pos_emb[:, :tc, :]
x = self.drop(cond_embeddings + pos_emb)
position_embeddings = self.cond_pos_emb[:, :tc, :]
x = self.drop(cond_embeddings + position_embeddings)
memory = self.encoder(x)
# 通过encoder
memory = self.encoder(x) # (B, T_cond, n_emb)
# --- Decoder: 预测噪声 ---
# 添加位置编码到输入
token_embeddings = input_emb
t = token_embeddings.shape[1]
pos_emb = self.pos_emb[:, :t, :]
x = self.drop(token_embeddings + pos_emb)
# Cross-Attention: Query来自输入Key/Value来自memory
position_embeddings = self.pos_emb[:, :t, :]
x = self.drop(token_embeddings + position_embeddings)
x = self.decoder(
tgt=x,
memory=memory,
tgt_mask=self.mask,
memory_mask=self.memory_mask
memory_mask=self.memory_mask,
)
# ==================== Encoder-Only模式 ====================
else:
# BERT风格时间步作为特殊token
token_embeddings = torch.cat([time_emb, input_emb], dim=1)
t = token_embeddings.shape[1]
pos_emb = self.pos_emb[:, :t, :]
x = self.drop(token_embeddings + pos_emb)
x = self.encoder(src=x, mask=self.mask)
x = x[:, 1:, :] # 移除时间步token
# ==================== 输出头 ====================
x = self.ln_f(x)
x = self.head(x) # (B, T, output_dim)
x = self.head(x)
return x
# ============================================================================
# 便捷函数创建Transformer1D模型
# ============================================================================
def create_transformer1d(
input_dim: int,
output_dim: int,
@@ -322,26 +322,9 @@ def create_transformer1d(
n_layer: int = 8,
n_head: int = 8,
n_emb: int = 256,
**kwargs
**kwargs,
) -> Transformer1D:
"""
创建Transformer1D模型的便捷函数
Args:
input_dim: 输入动作维度
output_dim: 输出动作维度
horizon: 预测horizon
n_obs_steps: 观测步数
cond_dim: 条件维度
n_layer: Transformer层数
n_head: 注意力头数
n_emb: 嵌入维度
**kwargs: 其他参数
Returns:
Transformer1D模型
"""
model = Transformer1D(
return Transformer1D(
input_dim=input_dim,
output_dim=output_dim,
horizon=horizon,
@@ -350,47 +333,5 @@ def create_transformer1d(
n_layer=n_layer,
n_head=n_head,
n_emb=n_emb,
**kwargs
**kwargs,
)
return model
if __name__ == "__main__":
print("=" * 80)
print("Testing Transformer1D")
print("=" * 80)
# 配置
B = 4
T = 16
action_dim = 16
obs_horizon = 2
cond_dim = 416 # vision + state特征维度
# 创建模型
model = Transformer1D(
input_dim=action_dim,
output_dim=action_dim,
horizon=T,
n_obs_steps=obs_horizon,
cond_dim=cond_dim,
n_layer=4,
n_head=8,
n_emb=256,
causal_attn=False
)
# 测试前向传播
sample = torch.randn(B, T, action_dim)
timestep = torch.randint(0, 100, (B,))
cond = torch.randn(B, obs_horizon, cond_dim)
output = model(sample, timestep, cond)
print(f"\n输入:")
print(f" sample: {sample.shape}")
print(f" timestep: {timestep.shape}")
print(f" cond: {cond.shape}")
print(f"\n输出:")
print(f" output: {output.shape}")
print(f"\n✅ 测试通过!")