"""
|
||
Transformer-based Diffusion Policy Head
|
||
|
||
使用Transformer架构(Encoder-Decoder)替代UNet进行噪声预测。
|
||
支持通过Cross-Attention注入全局条件(观测特征)。
|
||
"""
|
||
|
||
import math
|
||
import torch
|
||
import torch.nn as nn
|
||
from typing import Optional
|
||
|
||
|
||
class SinusoidalPosEmb(nn.Module):
    """Sinusoidal positional encoding, used here for diffusion timestep embedding.

    Produces the classic [sin | cos] embedding of shape (B, dim) for a batch
    of scalar positions/timesteps of shape (B,).
    """

    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        half = self.dim // 2
        # Geometrically spaced frequencies from 1 down to 1/10000.
        scale = math.log(10000) / (half - 1)
        freqs = torch.exp(-scale * torch.arange(half, device=x.device))
        angles = x[:, None] * freqs[None, :]
        return torch.cat((angles.sin(), angles.cos()), dim=-1)
class Transformer1D(nn.Module):
    """
    Transformer-based 1D diffusion model for noise prediction.

    Encoder-decoder architecture:
    - Encoder: processes the conditioning tokens (diffusion timestep +
      optional observation features).
    - Decoder: predicts the noise via cross-attention over the encoder memory.

    Args:
        input_dim: input action dimensionality.
        output_dim: output action dimensionality.
        horizon: prediction horizon length (number of action tokens).
        n_obs_steps: number of observation steps (defaults to ``horizon``).
        cond_dim: condition feature dimensionality; observations are used as
            conditioning tokens whenever this is > 0.
        n_layer: number of transformer (decoder) layers.
        n_head: number of attention heads.
        n_emb: embedding dimensionality.
        p_drop_emb: embedding dropout probability.
        p_drop_attn: attention dropout probability.
        causal_attn: whether to apply a causal (autoregressive) attention mask.
        obs_as_cond: kept for interface compatibility; the effective value is
            recomputed internally from ``cond_dim > 0``.
        n_cond_layers: number of encoder layers (0 means use an MLP encoder).
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        horizon: int,
        n_obs_steps: Optional[int] = None,
        cond_dim: int = 0,
        n_layer: int = 8,
        n_head: int = 8,
        n_emb: int = 256,
        p_drop_emb: float = 0.1,
        p_drop_attn: float = 0.1,
        causal_attn: bool = False,
        obs_as_cond: bool = False,
        n_cond_layers: int = 0
    ):
        super().__init__()

        # Compute sequence lengths.
        if n_obs_steps is None:
            n_obs_steps = horizon

        T = horizon
        T_cond = 1  # number of timestep tokens

        # Determine whether observations are used as conditioning.
        # NOTE: the incoming `obs_as_cond` argument is intentionally
        # overridden; conditioning is driven solely by `cond_dim`.
        obs_as_cond = cond_dim > 0
        if obs_as_cond:
            T_cond += n_obs_steps

        # Save configuration.
        self.T = T
        self.T_cond = T_cond
        self.horizon = horizon
        self.obs_as_cond = obs_as_cond
        self.input_dim = input_dim
        self.output_dim = output_dim

        # ==================== input embedding ====================
        self.input_emb = nn.Linear(input_dim, n_emb)
        self.pos_emb = nn.Parameter(torch.zeros(1, T, n_emb))
        self.drop = nn.Dropout(p_drop_emb)

        # ==================== condition encoding ====================
        # Diffusion timestep embedding.
        self.time_emb = SinusoidalPosEmb(n_emb)

        # Observation condition embedding (optional).
        self.cond_obs_emb = None
        if obs_as_cond:
            self.cond_obs_emb = nn.Linear(cond_dim, n_emb)

        # Condition positional embedding.
        self.cond_pos_emb = None
        if T_cond > 0:
            self.cond_pos_emb = nn.Parameter(torch.zeros(1, T_cond, n_emb))

        # ==================== Encoder ====================
        self.encoder = None
        self.encoder_only = False

        if T_cond > 0:
            if n_cond_layers > 0:
                # Transformer encoder over the condition tokens.
                encoder_layer = nn.TransformerEncoderLayer(
                    d_model=n_emb,
                    nhead=n_head,
                    dim_feedforward=4 * n_emb,
                    dropout=p_drop_attn,
                    activation='gelu',
                    batch_first=True,
                    norm_first=True  # Pre-LN is more stable
                )
                self.encoder = nn.TransformerEncoder(
                    encoder_layer=encoder_layer,
                    num_layers=n_cond_layers
                )
            else:
                # Simple MLP encoder.
                self.encoder = nn.Sequential(
                    nn.Linear(n_emb, 4 * n_emb),
                    nn.Mish(),
                    nn.Linear(4 * n_emb, n_emb)
                )
        else:
            # Encoder-only mode (BERT style).
            # NOTE(review): with T_cond initialized to 1 above, this branch is
            # currently unreachable; kept for future configurations.
            self.encoder_only = True
            encoder_layer = nn.TransformerEncoderLayer(
                d_model=n_emb,
                nhead=n_head,
                dim_feedforward=4 * n_emb,
                dropout=p_drop_attn,
                activation='gelu',
                batch_first=True,
                norm_first=True
            )
            self.encoder = nn.TransformerEncoder(
                encoder_layer=encoder_layer,
                num_layers=n_layer
            )

        # ==================== Attention masks ====================
        # BUGFIX: do NOT pre-assign `self.mask = None` before calling
        # register_buffer("mask", ...) -- nn.Module.register_buffer raises
        # KeyError when the name already exists as a plain attribute, so the
        # original code crashed whenever causal_attn=True. Plain-None
        # attributes are only set on paths that never register the buffer.
        if causal_attn:
            # Causal mask: each action token attends only to itself and to
            # tokens on its left.
            sz = T
            mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
            mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
            self.register_buffer("mask", mask)

            if obs_as_cond:
                # Cross-attention mask; the timestep token (s == 0) is always
                # visible since t >= (0 - 1) holds for every t.
                S = T_cond
                t, s = torch.meshgrid(
                    torch.arange(T),
                    torch.arange(S),
                    indexing='ij'
                )
                mask = t >= (s - 1)
                mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
                self.register_buffer('memory_mask', mask)
            else:
                self.memory_mask = None
        else:
            self.mask = None
            self.memory_mask = None

        # ==================== Decoder ====================
        if not self.encoder_only:
            decoder_layer = nn.TransformerDecoderLayer(
                d_model=n_emb,
                nhead=n_head,
                dim_feedforward=4 * n_emb,
                dropout=p_drop_attn,
                activation='gelu',
                batch_first=True,
                norm_first=True
            )
            self.decoder = nn.TransformerDecoder(
                decoder_layer=decoder_layer,
                num_layers=n_layer
            )

        # ==================== Output head ====================
        self.ln_f = nn.LayerNorm(n_emb)
        self.head = nn.Linear(n_emb, output_dim)

        # ==================== Initialization ====================
        self.apply(self._init_weights)

        # Report parameter count.
        total_params = sum(p.numel() for p in self.parameters())
        print(f"Transformer1D parameters: {total_params:,}")

    def _init_weights(self, module):
        """Initialize weights (normal(0, 0.02) for linears/embeddings, standard LN init)."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.MultiheadAttention):
            # MultiheadAttention weight initialization; attribute names differ
            # depending on whether q/k/v projections share one matrix.
            for name in ['in_proj_weight', 'q_proj_weight', 'k_proj_weight', 'v_proj_weight']:
                weight = getattr(module, name, None)
                if weight is not None:
                    torch.nn.init.normal_(weight, mean=0.0, std=0.02)

            for name in ['in_proj_bias', 'bias_k', 'bias_v']:
                bias = getattr(module, name, None)
                if bias is not None:
                    torch.nn.init.zeros_(bias)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.zeros_(module.bias)
            torch.nn.init.ones_(module.weight)
        elif isinstance(module, Transformer1D):
            # Positional embedding initialization (self.apply visits self too).
            torch.nn.init.normal_(self.pos_emb, mean=0.0, std=0.02)
            if self.cond_pos_emb is not None:
                torch.nn.init.normal_(self.cond_pos_emb, mean=0.0, std=0.02)

    def forward(
        self,
        sample: torch.Tensor,
        timestep: torch.Tensor,
        cond: Optional[torch.Tensor] = None,
        **kwargs
    ) -> torch.Tensor:
        """
        Forward pass.

        Args:
            sample: (B, T, input_dim) input sequence (noisy actions).
            timestep: (B,) diffusion timesteps; a scalar tensor or plain
                number is also accepted and broadcast to the batch.
            cond: (B, T', cond_dim) condition sequence (observation features).

        Returns:
            (B, T, output_dim) predicted noise.
        """
        # ==================== timestep handling ====================
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
        elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        # Broadcast to the batch dimension.
        timesteps = timesteps.expand(sample.shape[0])
        time_emb = self.time_emb(timesteps).unsqueeze(1)  # (B, 1, n_emb)

        # ==================== input embedding ====================
        input_emb = self.input_emb(sample)  # (B, T, n_emb)

        # ==================== encoder-decoder mode ====================
        if not self.encoder_only:
            # --- Encoder: process conditioning tokens ---
            cond_embeddings = time_emb

            if self.obs_as_cond and cond is not None:
                # Append observation conditioning tokens.
                cond_obs_emb = self.cond_obs_emb(cond)  # (B, T_cond-1, n_emb)
                cond_embeddings = torch.cat([cond_embeddings, cond_obs_emb], dim=1)

            # Add positional embeddings.
            tc = cond_embeddings.shape[1]
            pos_emb = self.cond_pos_emb[:, :tc, :]
            x = self.drop(cond_embeddings + pos_emb)

            # Run the encoder.
            memory = self.encoder(x)  # (B, T_cond, n_emb)

            # --- Decoder: predict the noise ---
            # Add positional embeddings to the input tokens.
            token_embeddings = input_emb
            t = token_embeddings.shape[1]
            pos_emb = self.pos_emb[:, :t, :]
            x = self.drop(token_embeddings + pos_emb)

            # Cross-attention: queries from the input, keys/values from memory.
            x = self.decoder(
                tgt=x,
                memory=memory,
                tgt_mask=self.mask,
                memory_mask=self.memory_mask
            )

        # ==================== encoder-only mode ====================
        else:
            # BERT style: prepend the timestep as a special token.
            token_embeddings = torch.cat([time_emb, input_emb], dim=1)
            t = token_embeddings.shape[1]
            pos_emb = self.pos_emb[:, :t, :]
            x = self.drop(token_embeddings + pos_emb)

            x = self.encoder(src=x, mask=self.mask)
            x = x[:, 1:, :]  # drop the timestep token

        # ==================== output head ====================
        x = self.ln_f(x)
        x = self.head(x)  # (B, T, output_dim)

        return x
# ============================================================================
# Convenience factory for Transformer1D models
# ============================================================================
def create_transformer1d(
    input_dim: int,
    output_dim: int,
    horizon: int,
    n_obs_steps: int,
    cond_dim: int,
    n_layer: int = 8,
    n_head: int = 8,
    n_emb: int = 256,
    **kwargs
) -> Transformer1D:
    """
    Convenience function for building a Transformer1D model.

    Args:
        input_dim: input action dimensionality.
        output_dim: output action dimensionality.
        horizon: prediction horizon.
        n_obs_steps: number of observation steps.
        cond_dim: condition feature dimensionality.
        n_layer: number of transformer layers.
        n_head: number of attention heads.
        n_emb: embedding dimensionality.
        **kwargs: forwarded verbatim to the Transformer1D constructor.

    Returns:
        A freshly constructed Transformer1D model.
    """
    return Transformer1D(
        input_dim=input_dim,
        output_dim=output_dim,
        horizon=horizon,
        n_obs_steps=n_obs_steps,
        cond_dim=cond_dim,
        n_layer=n_layer,
        n_head=n_head,
        n_emb=n_emb,
        **kwargs
    )
if __name__ == "__main__":
    print("=" * 80)
    print("Testing Transformer1D")
    print("=" * 80)

    # Test configuration.
    batch_size = 4
    pred_horizon = 16
    action_dim = 16
    obs_horizon = 2
    cond_dim = 416  # vision + state feature dimension

    # Build the model.
    model = Transformer1D(
        input_dim=action_dim,
        output_dim=action_dim,
        horizon=pred_horizon,
        n_obs_steps=obs_horizon,
        cond_dim=cond_dim,
        n_layer=4,
        n_head=8,
        n_emb=256,
        causal_attn=False
    )

    # Run a forward pass on random data.
    sample = torch.randn(batch_size, pred_horizon, action_dim)
    timestep = torch.randint(0, 100, (batch_size,))
    cond = torch.randn(batch_size, obs_horizon, cond_dim)
    output = model(sample, timestep, cond)

    print(f"\n输入:")
    print(f" sample: {sample.shape}")
    print(f" timestep: {timestep.shape}")
    print(f" cond: {cond.shape}")
    print(f"\n输出:")
    print(f" output: {output.shape}")
    print(f"\n✅ 测试通过!")