"""Local IMF-AttnRes transformer head aligned with diffusion_policy@185ed659."""

from __future__ import annotations

import logging
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn

from .attnres_transformer_components import (
    AttnResOperator,
    AttnResSubLayer,
    AttnResTransformerBackbone,
    GroupedQuerySelfAttention,
    RMSNorm,
    RMSNormNoWeight,
    SwiGLUFFN,
)
from .transformer1d import ModuleAttrMixin, SinusoidalPosEmb

logger = logging.getLogger(__name__)


class IMFTransformer1D(ModuleAttrMixin):
    """Transformer head conditioned on two time values (r, t) and optional observation features.

    With ``backbone_type='attnres_full'`` all tokens (two time tokens, optional
    observation tokens, and sample tokens) pass through a single
    AttnResTransformerBackbone stack; any other value falls back to the vanilla
    encoder/decoder layout from diffusion_policy's transformer head.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        horizon: int,
        n_obs_steps: Optional[int] = None,
        cond_dim: int = 0,
        n_layer: int = 12,
        n_head: int = 1,
        n_emb: int = 768,
        p_drop_emb: float = 0.1,
        p_drop_attn: float = 0.1,
        causal_attn: bool = False,
        time_as_cond: bool = True,
        obs_as_cond: bool = False,
        n_cond_layers: int = 0,
        backbone_type: str = 'attnres_full',
        n_kv_head: int = 1,
        attn_res_ffn_mult: float = 2.667,
        attn_res_eps: float = 1e-6,
        attn_res_rope_theta: float = 10000.0,
    ) -> None:
        super().__init__()

        if n_head != 1:
            raise AssertionError('IMFTransformer1D currently supports single-head attention only.')
        if n_obs_steps is None:
            n_obs_steps = horizon

        self.backbone_type = backbone_type

        # Token bookkeeping: the two time tokens (r, t) live in the conditioning
        # sequence when time_as_cond=True, otherwise they are prepended to the
        # sample sequence.
        T = horizon
        T_cond = 2
        if not time_as_cond:
            T += 2
            T_cond -= 2
        obs_as_cond = cond_dim > 0
        if obs_as_cond:
            assert time_as_cond
            T_cond += n_obs_steps

        self.input_emb = nn.Linear(input_dim, n_emb)
        self.drop = nn.Dropout(p_drop_emb)
        self.time_emb = SinusoidalPosEmb(n_emb)
        self.cond_obs_emb = nn.Linear(cond_dim, n_emb) if obs_as_cond else None
        self.time_token_proj = None
        self.cond_pos_emb = None
        self.pos_emb = None
        self.encoder = None
        self.decoder = None
        self.attnres_backbone = None
        encoder_only = False

        if backbone_type == 'attnres_full':
            if not time_as_cond:
                raise ValueError('attnres_full backbone requires time_as_cond=True.')
            if n_cond_layers != 0:
                raise ValueError('attnres_full backbone does not support n_cond_layers > 0.')

            self.time_token_proj = nn.Linear(n_emb, n_emb)
            self.attnres_backbone = AttnResTransformerBackbone(
                d_model=n_emb,
                n_blocks=n_layer,
                n_heads=n_head,
                n_kv_heads=n_kv_head,
                max_seq_len=T + T_cond,
                dropout=p_drop_attn,
                ffn_mult=attn_res_ffn_mult,
                eps=attn_res_eps,
                rope_theta=attn_res_rope_theta,
                causal_attn=causal_attn,
            )
            self.ln_f = RMSNorm(n_emb, eps=attn_res_eps)
        else:
            self.pos_emb = nn.Parameter(torch.zeros(1, T, n_emb))
            if T_cond > 0:
                self.cond_pos_emb = nn.Parameter(torch.zeros(1, T_cond, n_emb))
                if n_cond_layers > 0:
                    encoder_layer = nn.TransformerEncoderLayer(
                        d_model=n_emb,
                        nhead=n_head,
                        dim_feedforward=4 * n_emb,
                        dropout=p_drop_attn,
                        activation='gelu',
                        batch_first=True,
                        norm_first=True,
                    )
                    self.encoder = nn.TransformerEncoder(
                        encoder_layer=encoder_layer,
                        num_layers=n_cond_layers,
                    )
                else:
                    self.encoder = nn.Sequential(
                        nn.Linear(n_emb, 4 * n_emb),
                        nn.Mish(),
                        nn.Linear(4 * n_emb, n_emb),
                    )

                decoder_layer = nn.TransformerDecoderLayer(
                    d_model=n_emb,
                    nhead=n_head,
                    dim_feedforward=4 * n_emb,
                    dropout=p_drop_attn,
                    activation='gelu',
                    batch_first=True,
                    norm_first=True,
                )
                self.decoder = nn.TransformerDecoder(
                    decoder_layer=decoder_layer,
                    num_layers=n_layer,
                )
            else:
                encoder_only = True
                encoder_layer = nn.TransformerEncoderLayer(
                    d_model=n_emb,
                    nhead=n_head,
                    dim_feedforward=4 * n_emb,
                    dropout=p_drop_attn,
                    activation='gelu',
                    batch_first=True,
                    norm_first=True,
                )
                self.encoder = nn.TransformerEncoder(
                    encoder_layer=encoder_layer,
                    num_layers=n_layer,
                )

            self.ln_f = nn.LayerNorm(n_emb)

        if causal_attn and backbone_type != 'attnres_full':
            # Causal self-attention mask over the sample sequence
            # (future positions are set to -inf).
            sz = T
            mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
            mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
            self.register_buffer('mask', mask)

            if time_as_cond and obs_as_cond:
                # Memory mask: cond index s is visible from sample step t when
                # s <= t + 2, i.e. both time tokens plus observation steps up to t.
                S = T_cond
                t_idx, s_idx = torch.meshgrid(
                    torch.arange(T),
                    torch.arange(S),
                    indexing='ij',
                )
                mask = t_idx >= (s_idx - 2)
                mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
                self.register_buffer('memory_mask', mask)
            else:
                self.memory_mask = None
        else:
            self.mask = None
            self.memory_mask = None

        self.head = nn.Linear(n_emb, output_dim)

        self.T = T
        self.T_cond = T_cond
        self.horizon = horizon
        self.time_as_cond = time_as_cond
        self.obs_as_cond = obs_as_cond
        self.encoder_only = encoder_only

        self.apply(self._init_weights)
        logger.info('number of parameters: %e', sum(p.numel() for p in self.parameters()))

    def _init_weights(self, module):
        """GPT-style init; composite modules are skipped and handled through their children."""
        ignore_types = (
            nn.Dropout,
            SinusoidalPosEmb,
            nn.TransformerEncoderLayer,
            nn.TransformerDecoderLayer,
            nn.TransformerEncoder,
            nn.TransformerDecoder,
            nn.ModuleList,
            nn.Mish,
            nn.Sequential,
            AttnResTransformerBackbone,
            AttnResSubLayer,
            GroupedQuerySelfAttention,
            SwiGLUFFN,
            RMSNormNoWeight,
        )
        if isinstance(module, (nn.Linear, nn.Embedding)):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.MultiheadAttention):
            for name in ('in_proj_weight', 'q_proj_weight', 'k_proj_weight', 'v_proj_weight'):
                weight = getattr(module, name)
                if weight is not None:
                    torch.nn.init.normal_(weight, mean=0.0, std=0.02)

            for name in ('in_proj_bias', 'bias_k', 'bias_v'):
                bias = getattr(module, name)
                if bias is not None:
                    torch.nn.init.zeros_(bias)
        elif isinstance(module, (nn.LayerNorm, RMSNorm)):
            if getattr(module, 'bias', None) is not None:
                torch.nn.init.zeros_(module.bias)
            torch.nn.init.ones_(module.weight)
        elif isinstance(module, AttnResOperator):
            torch.nn.init.zeros_(module.pseudo_query)
        elif isinstance(module, IMFTransformer1D):
            if module.pos_emb is not None:
                torch.nn.init.normal_(module.pos_emb, mean=0.0, std=0.02)
            if module.cond_pos_emb is not None:
                torch.nn.init.normal_(module.cond_pos_emb, mean=0.0, std=0.02)
        elif isinstance(module, ignore_types):
            pass
        else:
            raise RuntimeError(f'Unaccounted module {module}')

    def get_optim_groups(self, weight_decay: float = 1e-3):
        decay = set()
        no_decay = set()
        whitelist_weight_modules = (torch.nn.Linear, torch.nn.MultiheadAttention)
        blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding, RMSNorm)
        for mn, m in self.named_modules():
            for pn, _ in m.named_parameters(recurse=False):
                fpn = f'{mn}.{pn}' if mn else pn

                if pn.endswith('bias'):
                    no_decay.add(fpn)
                elif pn.startswith('bias'):
                    no_decay.add(fpn)
                elif pn == 'pseudo_query':
                    no_decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
                    decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
                    no_decay.add(fpn)

        if self.pos_emb is not None:
            no_decay.add('pos_emb')
        no_decay.add('_dummy_variable')
        if self.cond_pos_emb is not None:
            no_decay.add('cond_pos_emb')

        param_dict = {pn: p for pn, p in self.named_parameters()}
        inter_params = decay & no_decay
        union_params = decay | no_decay
        assert len(inter_params) == 0, f'parameters {inter_params} made it into both decay/no_decay sets!'
        assert len(param_dict.keys() - union_params) == 0, (
            f'parameters {param_dict.keys() - union_params} were not separated into either decay/no_decay sets!'
        )

        return [
            {
                'params': [param_dict[pn] for pn in sorted(list(decay))],
                'weight_decay': weight_decay,
            },
            {
                'params': [param_dict[pn] for pn in sorted(list(no_decay))],
                'weight_decay': 0.0,
            },
        ]

    def configure_optimizers(
        self,
        learning_rate: float = 1e-4,
        weight_decay: float = 1e-3,
        betas: Tuple[float, float] = (0.9, 0.95),
    ):
        optim_groups = self.get_optim_groups(weight_decay=weight_decay)
        return torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas)

    def _prepare_time_input(self, value: Union[torch.Tensor, float, int], sample: torch.Tensor) -> torch.Tensor:
        """Broadcast a scalar or tensor time value to shape (B,) on sample's device and dtype."""
        if not torch.is_tensor(value):
            value = torch.tensor([value], dtype=sample.dtype, device=sample.device)
        elif value.ndim == 0:
            value = value[None].to(device=sample.device, dtype=sample.dtype)
        else:
            value = value.to(device=sample.device, dtype=sample.dtype)
        return value.expand(sample.shape[0])

    def _forward_attnres_full(
        self,
        sample: torch.Tensor,
        r: torch.Tensor,
        t: torch.Tensor,
        cond: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        sample_tokens = self.input_emb(sample)
        # Token layout: [r token, t token, (optional) obs tokens, sample tokens].
        token_parts = [
            self.time_token_proj(self.time_emb(r)).unsqueeze(1),
            self.time_token_proj(self.time_emb(t)).unsqueeze(1),
        ]
        if self.obs_as_cond:
            if cond is None:
                raise ValueError('cond is required when obs_as_cond=True for attnres_full backbone.')
            token_parts.append(self.cond_obs_emb(cond))
        token_parts.append(sample_tokens)
        x = torch.cat(token_parts, dim=1)
        x = self.drop(x)
        x = self.attnres_backbone(x)
        # Keep only the trailing sample positions for the output head.
        x = x[:, -sample_tokens.shape[1]:, :]
        return x

    def _forward_vanilla(
        self,
        sample: torch.Tensor,
        r: torch.Tensor,
        t: torch.Tensor,
        cond: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        r_emb = self.time_emb(r).unsqueeze(1)
        t_emb = self.time_emb(t).unsqueeze(1)
        input_emb = self.input_emb(sample)

        if self.encoder_only:
            # Encoder-only layout: prepend the two time tokens and drop them after encoding.
            token_embeddings = torch.cat([r_emb, t_emb, input_emb], dim=1)
            token_count = token_embeddings.shape[1]
            position_embeddings = self.pos_emb[:, :token_count, :]
            x = self.drop(token_embeddings + position_embeddings)
            x = self.encoder(src=x, mask=self.mask)
            x = x[:, 2:, :]
        else:
            # Encoder consumes the [r, t, (optional) obs] conditioning tokens ...
            cond_embeddings = torch.cat([r_emb, t_emb], dim=1)
            if self.obs_as_cond:
                cond_embeddings = torch.cat([cond_embeddings, self.cond_obs_emb(cond)], dim=1)
            token_count = cond_embeddings.shape[1]
            position_embeddings = self.cond_pos_emb[:, :token_count, :]
            x = self.drop(cond_embeddings + position_embeddings)
            x = self.encoder(x)
            memory = x

            # ... and the decoder cross-attends from the sample tokens to that memory.
            token_embeddings = input_emb
            token_count = token_embeddings.shape[1]
            position_embeddings = self.pos_emb[:, :token_count, :]
            x = self.drop(token_embeddings + position_embeddings)
            x = self.decoder(
                tgt=x,
                memory=memory,
                tgt_mask=self.mask,
                memory_mask=self.memory_mask,
            )
        return x

    def forward(
        self,
        sample: torch.Tensor,
        r: Union[torch.Tensor, float, int],
        t: Union[torch.Tensor, float, int],
        cond: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Map sample (B, horizon, input_dim), times r and t, and optional cond
        (B, n_obs_steps, cond_dim) to a (B, horizon, output_dim) prediction."""
        r = self._prepare_time_input(r, sample)
        t = self._prepare_time_input(t, sample)

        if self.backbone_type == 'attnres_full':
            x = self._forward_attnres_full(sample, r, t, cond=cond)
        else:
            x = self._forward_vanilla(sample, r, t, cond=cond)

        x = self.ln_f(x)
        x = self.head(x)
        return x
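
# Minimal smoke-test sketch (not part of the original module). It assumes the sibling
# `attnres_transformer_components` and `transformer1d` modules are importable, so run
# it as a module inside its package (python -m <package>.<this_module>) rather than as
# a loose script. All dimensions below are arbitrary illustration values, not
# recommended settings; the check only confirms that a forward pass preserves the
# expected (B, horizon, output_dim) shape.
if __name__ == '__main__':
    model = IMFTransformer1D(
        input_dim=10,
        output_dim=10,
        horizon=16,
        n_obs_steps=2,
        cond_dim=32,
        n_layer=4,
        n_emb=128,
    )
    sample = torch.randn(2, 16, 10)  # (B, horizon, input_dim)
    cond = torch.randn(2, 2, 32)     # (B, n_obs_steps, cond_dim)
    out = model(sample, r=0.1, t=0.9, cond=cond)
    assert out.shape == (2, 16, 10)
    print('forward OK:', tuple(out.shape))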