Files
roboimi/roboimi/vla/models/heads/transformer1d.py
2026-02-28 19:07:27 +08:00

397 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Transformer-based Diffusion Policy Head
使用Transformer架构Encoder-Decoder替代UNet进行噪声预测。
支持通过Cross-Attention注入全局条件观测特征
"""
import math
import torch
import torch.nn as nn
from typing import Optional
class SinusoidalPosEmb(nn.Module):
    """Sinusoidal positional embedding (used for diffusion timestep encoding).

    Maps a batch of scalar positions/timesteps ``(B,)`` to embeddings
    ``(B, dim)`` made of ``dim // 2`` sine features followed by ``dim // 2``
    cosine features at geometrically spaced frequencies.
    """

    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        half = self.dim // 2
        # Geometric frequency spacing from 1 down to 1/10000, as in
        # "Attention Is All You Need".
        scale = math.log(10000) / (half - 1)
        freqs = torch.exp(-scale * torch.arange(half, device=x.device))
        angles = x[:, None] * freqs[None, :]
        return torch.cat((angles.sin(), angles.cos()), dim=-1)
class Transformer1D(nn.Module):
"""
Transformer-based 1D Diffusion Model
使用Encoder-Decoder架构
- Encoder: 处理条件(观测 + 时间步)
- Decoder: 通过Cross-Attention预测噪声
Args:
input_dim: 输入动作维度
output_dim: 输出动作维度
horizon: 预测horizon长度
n_obs_steps: 观测步数
cond_dim: 条件维度
n_layer: Transformer层数
n_head: 注意力头数
n_emb: 嵌入维度
p_drop_emb: Embedding dropout
p_drop_attn: Attention dropout
causal_attn: 是否使用因果注意力(自回归)
n_cond_layers: Encoder层数0表示使用MLP
"""
def __init__(
self,
input_dim: int,
output_dim: int,
horizon: int,
n_obs_steps: int = None,
cond_dim: int = 0,
n_layer: int = 8,
n_head: int = 8,
n_emb: int = 256,
p_drop_emb: float = 0.1,
p_drop_attn: float = 0.1,
causal_attn: bool = False,
obs_as_cond: bool = False,
n_cond_layers: int = 0
):
super().__init__()
# 计算序列长度
if n_obs_steps is None:
n_obs_steps = horizon
T = horizon
T_cond = 1 # 时间步token数量
# 确定是否使用观测作为条件
obs_as_cond = cond_dim > 0
if obs_as_cond:
T_cond += n_obs_steps
# 保存配置
self.T = T
self.T_cond = T_cond
self.horizon = horizon
self.obs_as_cond = obs_as_cond
self.input_dim = input_dim
self.output_dim = output_dim
# ==================== 输入嵌入 ====================
self.input_emb = nn.Linear(input_dim, n_emb)
self.pos_emb = nn.Parameter(torch.zeros(1, T, n_emb))
self.drop = nn.Dropout(p_drop_emb)
# ==================== 条件编码 ====================
# 时间步嵌入
self.time_emb = SinusoidalPosEmb(n_emb)
# 观测条件嵌入(可选)
self.cond_obs_emb = None
if obs_as_cond:
self.cond_obs_emb = nn.Linear(cond_dim, n_emb)
# 条件位置编码
self.cond_pos_emb = None
if T_cond > 0:
self.cond_pos_emb = nn.Parameter(torch.zeros(1, T_cond, n_emb))
# ==================== Encoder ====================
self.encoder = None
self.encoder_only = False
if T_cond > 0:
if n_cond_layers > 0:
# 使用Transformer Encoder
encoder_layer = nn.TransformerEncoderLayer(
d_model=n_emb,
nhead=n_head,
dim_feedforward=4 * n_emb,
dropout=p_drop_attn,
activation='gelu',
batch_first=True,
norm_first=True # Pre-LN更稳定
)
self.encoder = nn.TransformerEncoder(
encoder_layer=encoder_layer,
num_layers=n_cond_layers
)
else:
# 使用简单的MLP
self.encoder = nn.Sequential(
nn.Linear(n_emb, 4 * n_emb),
nn.Mish(),
nn.Linear(4 * n_emb, n_emb)
)
else:
# Encoder-only模式BERT风格
self.encoder_only = True
encoder_layer = nn.TransformerEncoderLayer(
d_model=n_emb,
nhead=n_head,
dim_feedforward=4 * n_emb,
dropout=p_drop_attn,
activation='gelu',
batch_first=True,
norm_first=True
)
self.encoder = nn.TransformerEncoder(
encoder_layer=encoder_layer,
num_layers=n_layer
)
# ==================== Attention Mask ====================
self.mask = None
self.memory_mask = None
if causal_attn:
# 因果mask确保只关注左侧
sz = T
mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
self.register_buffer("mask", mask)
if obs_as_cond:
# 交叉注意力mask
S = T_cond
t, s = torch.meshgrid(
torch.arange(T),
torch.arange(S),
indexing='ij'
)
mask = t >= (s - 1)
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
self.register_buffer('memory_mask', mask)
# ==================== Decoder ====================
if not self.encoder_only:
decoder_layer = nn.TransformerDecoderLayer(
d_model=n_emb,
nhead=n_head,
dim_feedforward=4 * n_emb,
dropout=p_drop_attn,
activation='gelu',
batch_first=True,
norm_first=True
)
self.decoder = nn.TransformerDecoder(
decoder_layer=decoder_layer,
num_layers=n_layer
)
# ==================== 输出头 ====================
self.ln_f = nn.LayerNorm(n_emb)
self.head = nn.Linear(n_emb, output_dim)
# ==================== 初始化 ====================
self.apply(self._init_weights)
# 打印参数量
total_params = sum(p.numel() for p in self.parameters())
print(f"Transformer1D parameters: {total_params:,}")
def _init_weights(self, module):
"""初始化权重"""
if isinstance(module, (nn.Linear, nn.Embedding)):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
if isinstance(module, nn.Linear) and module.bias is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, nn.MultiheadAttention):
# MultiheadAttention的权重初始化
for name in ['in_proj_weight', 'q_proj_weight', 'k_proj_weight', 'v_proj_weight']:
weight = getattr(module, name, None)
if weight is not None:
torch.nn.init.normal_(weight, mean=0.0, std=0.02)
for name in ['in_proj_bias', 'bias_k', 'bias_v']:
bias = getattr(module, name, None)
if bias is not None:
torch.nn.init.zeros_(bias)
elif isinstance(module, nn.LayerNorm):
torch.nn.init.zeros_(module.bias)
torch.nn.init.ones_(module.weight)
elif isinstance(module, Transformer1D):
# 位置编码初始化
torch.nn.init.normal_(self.pos_emb, mean=0.0, std=0.02)
if self.cond_pos_emb is not None:
torch.nn.init.normal_(self.cond_pos_emb, mean=0.0, std=0.02)
def forward(
self,
sample: torch.Tensor,
timestep: torch.Tensor,
cond: Optional[torch.Tensor] = None,
**kwargs
):
"""
前向传播
Args:
sample: (B, T, input_dim) 输入序列(加噪动作)
timestep: (B,) 时间步
cond: (B, T', cond_dim) 条件序列(观测特征)
Returns:
(B, T, output_dim) 预测的噪声
"""
# ==================== 处理时间步 ====================
timesteps = timestep
if not torch.is_tensor(timesteps):
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
# 扩展到batch维度
timesteps = timesteps.expand(sample.shape[0])
time_emb = self.time_emb(timesteps).unsqueeze(1) # (B, 1, n_emb)
# ==================== 处理输入 ====================
input_emb = self.input_emb(sample) # (B, T, n_emb)
# ==================== Encoder-Decoder模式 ====================
if not self.encoder_only:
# --- Encoder: 处理条件 ---
cond_embeddings = time_emb
if self.obs_as_cond and cond is not None:
# 添加观测条件
cond_obs_emb = self.cond_obs_emb(cond) # (B, T_cond-1, n_emb)
cond_embeddings = torch.cat([cond_embeddings, cond_obs_emb], dim=1)
# 添加位置编码
tc = cond_embeddings.shape[1]
pos_emb = self.cond_pos_emb[:, :tc, :]
x = self.drop(cond_embeddings + pos_emb)
# 通过encoder
memory = self.encoder(x) # (B, T_cond, n_emb)
# --- Decoder: 预测噪声 ---
# 添加位置编码到输入
token_embeddings = input_emb
t = token_embeddings.shape[1]
pos_emb = self.pos_emb[:, :t, :]
x = self.drop(token_embeddings + pos_emb)
# Cross-Attention: Query来自输入Key/Value来自memory
x = self.decoder(
tgt=x,
memory=memory,
tgt_mask=self.mask,
memory_mask=self.memory_mask
)
# ==================== Encoder-Only模式 ====================
else:
# BERT风格时间步作为特殊token
token_embeddings = torch.cat([time_emb, input_emb], dim=1)
t = token_embeddings.shape[1]
pos_emb = self.pos_emb[:, :t, :]
x = self.drop(token_embeddings + pos_emb)
x = self.encoder(src=x, mask=self.mask)
x = x[:, 1:, :] # 移除时间步token
# ==================== 输出头 ====================
x = self.ln_f(x)
x = self.head(x) # (B, T, output_dim)
return x
# ============================================================================
# 便捷函数创建Transformer1D模型
# ============================================================================
def create_transformer1d(
    input_dim: int,
    output_dim: int,
    horizon: int,
    n_obs_steps: int,
    cond_dim: int,
    n_layer: int = 8,
    n_head: int = 8,
    n_emb: int = 256,
    **kwargs
) -> Transformer1D:
    """Convenience factory for :class:`Transformer1D`.

    Args:
        input_dim: Input action dimensionality.
        output_dim: Output action dimensionality.
        horizon: Prediction horizon.
        n_obs_steps: Number of observation steps.
        cond_dim: Condition feature dimensionality.
        n_layer: Number of Transformer layers.
        n_head: Number of attention heads.
        n_emb: Embedding width.
        **kwargs: Forwarded verbatim to the ``Transformer1D`` constructor.

    Returns:
        A freshly constructed ``Transformer1D`` model.
    """
    return Transformer1D(
        input_dim=input_dim,
        output_dim=output_dim,
        horizon=horizon,
        n_obs_steps=n_obs_steps,
        cond_dim=cond_dim,
        n_layer=n_layer,
        n_head=n_head,
        n_emb=n_emb,
        **kwargs
    )
if __name__ == "__main__":
    # Smoke test: build a small model and verify the forward-pass shapes.
    print("=" * 80)
    print("Testing Transformer1D")
    print("=" * 80)

    # Test configuration.
    batch_size = 4
    pred_horizon = 16
    action_dim = 16
    obs_horizon = 2
    cond_dim = 416  # vision + state feature dimensionality

    # Build the model.
    model = Transformer1D(
        input_dim=action_dim,
        output_dim=action_dim,
        horizon=pred_horizon,
        n_obs_steps=obs_horizon,
        cond_dim=cond_dim,
        n_layer=4,
        n_head=8,
        n_emb=256,
        causal_attn=False
    )

    # Run one forward pass on random inputs.
    sample = torch.randn(batch_size, pred_horizon, action_dim)
    timestep = torch.randint(0, 100, (batch_size,))
    cond = torch.randn(batch_size, obs_horizon, cond_dim)
    output = model(sample, timestep, cond)

    print(f"\n输入:")
    print(f" sample: {sample.shape}")
    print(f" timestep: {timestep.shape}")
    print(f" cond: {cond.shape}")
    print(f"\n输出:")
    print(f" output: {output.shape}")
    print(f"\n✅ 测试通过!")