From 185ed6596cb9b4435ba812c0feb99a0a0c9d086b Mon Sep 17 00:00:00 2001 From: Logic Date: Sun, 29 Mar 2026 11:15:59 +0800 Subject: [PATCH] feat: add pusht imf attnres backbone --- .../attnres_transformer_components.py | 247 ++++++++++++++++++ .../imf_transformer_for_diffusion.py | 207 ++++++++++----- .../imf_transformer_hybrid_image_policy.py | 10 + ...-03-29-pusht-imf-attnres-implementation.md | 57 ++++ .../2026-03-29-pusht-imf-attnres-design.md | 108 ++++++++ ...diffusion_policy_dit_imf_attnres_full.yaml | 35 +++ tests/test_imf_transformer_for_diffusion.py | 31 +++ tests/test_pusht_swanlab_config.py | 13 + 8 files changed, 647 insertions(+), 61 deletions(-) create mode 100644 diffusion_policy/model/diffusion/attnres_transformer_components.py create mode 100644 docs/superpowers/plans/2026-03-29-pusht-imf-attnres-implementation.md create mode 100644 docs/superpowers/specs/2026-03-29-pusht-imf-attnres-design.md create mode 100644 image_pusht_diffusion_policy_dit_imf_attnres_full.yaml diff --git a/diffusion_policy/model/diffusion/attnres_transformer_components.py b/diffusion_policy/model/diffusion/attnres_transformer_components.py new file mode 100644 index 0000000..4796e1b --- /dev/null +++ b/diffusion_policy/model/diffusion/attnres_transformer_components.py @@ -0,0 +1,247 @@ +import math +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor + + +class RMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-6) -> None: + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def forward(self, x: Tensor) -> Tensor: + rms = torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + self.eps) + return (x.float() * rms).to(x.dtype) * self.weight + + +class RMSNormNoWeight(nn.Module): + def __init__(self, eps: float = 1e-6) -> None: + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + rms = torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + self.eps) + return (x.float() * rms).to(x.dtype) + + +def precompute_rope_freqs( + dim: int, + max_seq_len: int, + theta: float = 10000.0, + device: Optional[torch.device] = None, +) -> Tensor: + if dim % 2 != 0: + raise ValueError(f'RoPE requires an even head dimension, got {dim}.') + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, device=device).float() / dim)) + positions = torch.arange(max_seq_len, device=device).float() + angles = torch.outer(positions, freqs) + return torch.polar(torch.ones_like(angles), angles) + + +def apply_rope(x: Tensor, freqs: Tensor) -> Tensor: + x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) + freqs = freqs.unsqueeze(0).unsqueeze(2) + x_rotated = x_complex * freqs + return torch.view_as_real(x_rotated).reshape_as(x).to(x.dtype) + + +class GroupedQuerySelfAttention(nn.Module): + def __init__( + self, + d_model: int, + n_heads: int, + n_kv_heads: int, + dropout: float = 0.0, + ) -> None: + super().__init__() + if d_model % n_heads != 0: + raise ValueError(f'd_model={d_model} must be divisible by n_heads={n_heads}.') + if n_heads % n_kv_heads != 0: + raise ValueError(f'n_heads={n_heads} must be divisible by n_kv_heads={n_kv_heads}.') + + self.d_model = d_model + self.n_heads = n_heads + self.n_kv_heads = n_kv_heads + self.n_kv_groups = n_heads // n_kv_heads + self.d_head = d_model // n_heads + self.attn_dropout = nn.Dropout(dropout) + self.out_dropout = nn.Dropout(dropout) + + self.w_q = nn.Linear(d_model, n_heads * self.d_head, bias=False) + self.w_k = 
nn.Linear(d_model, n_kv_heads * self.d_head, bias=False) + self.w_v = nn.Linear(d_model, n_kv_heads * self.d_head, bias=False) + self.w_o = nn.Linear(n_heads * self.d_head, d_model, bias=False) + + def forward( + self, + x: Tensor, + rope_freqs: Tensor, + mask: Optional[Tensor] = None, + ) -> Tensor: + batch_size, seq_len, _ = x.shape + + q = self.w_q(x).view(batch_size, seq_len, self.n_heads, self.d_head) + k = self.w_k(x).view(batch_size, seq_len, self.n_kv_heads, self.d_head) + v = self.w_v(x).view(batch_size, seq_len, self.n_kv_heads, self.d_head) + + q = apply_rope(q, rope_freqs) + k = apply_rope(k, rope_freqs) + + if self.n_kv_heads != self.n_heads: + k = k.unsqueeze(3).expand( + batch_size, seq_len, self.n_kv_heads, self.n_kv_groups, self.d_head + ) + k = k.reshape(batch_size, seq_len, self.n_heads, self.d_head) + v = v.unsqueeze(3).expand( + batch_size, seq_len, self.n_kv_heads, self.n_kv_groups, self.d_head + ) + v = v.reshape(batch_size, seq_len, self.n_heads, self.d_head) + + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + scale = 1.0 / math.sqrt(self.d_head) + attn_weights = torch.matmul(q, k.transpose(-2, -1)) * scale + if mask is not None: + attn_weights = attn_weights + mask + attn_weights = F.softmax(attn_weights, dim=-1) + attn_weights = self.attn_dropout(attn_weights) + + out = torch.matmul(attn_weights, v) + out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model) + return self.out_dropout(self.w_o(out)) + + +class SwiGLUFFN(nn.Module): + def __init__(self, d_model: int, dropout: float = 0.0, mult: float = 2.667) -> None: + super().__init__() + raw = int(mult * d_model) + d_ff = ((raw + 7) // 8) * 8 + self.w_gate = nn.Linear(d_model, d_ff, bias=False) + self.w_up = nn.Linear(d_model, d_ff, bias=False) + self.w_down = nn.Linear(d_ff, d_model, bias=False) + self.dropout = nn.Dropout(dropout) + + def forward(self, x: Tensor) -> Tensor: + return self.dropout(self.w_down(F.silu(self.w_gate(x)) * self.w_up(x))) + + +class AttnResOperator(nn.Module): + def __init__(self, d_model: int, eps: float = 1e-6) -> None: + super().__init__() + self.pseudo_query = nn.Parameter(torch.zeros(d_model)) + self.key_norm = RMSNormNoWeight(eps=eps) + + def forward(self, sources: Tensor) -> Tensor: + keys = self.key_norm(sources) + logits = torch.einsum('d,nbtd->nbt', self.pseudo_query, keys) + weights = F.softmax(logits, dim=0) + return torch.einsum('nbt,nbtd->btd', weights, sources) + + +class AttnResSubLayer(nn.Module): + def __init__( + self, + d_model: int, + n_heads: int, + n_kv_heads: int, + dropout: float, + ffn_mult: float, + eps: float, + is_attention: bool, + ) -> None: + super().__init__() + self.norm = RMSNorm(d_model, eps=eps) + self.attn_res = AttnResOperator(d_model, eps=eps) + self.is_attention = is_attention + if is_attention: + self.fn = GroupedQuerySelfAttention( + d_model=d_model, + n_heads=n_heads, + n_kv_heads=n_kv_heads, + dropout=dropout, + ) + else: + self.fn = SwiGLUFFN(d_model=d_model, dropout=dropout, mult=ffn_mult) + + def forward(self, sources: Tensor, rope_freqs: Tensor, mask: Optional[Tensor] = None) -> Tensor: + h = self.attn_res(sources) + normed = self.norm(h) + if self.is_attention: + return self.fn(normed, rope_freqs, mask) + return self.fn(normed) + + +class AttnResTransformerBackbone(nn.Module): + def __init__( + self, + d_model: int, + n_blocks: int, + n_heads: int, + n_kv_heads: int, + max_seq_len: int, + dropout: float = 0.0, + ffn_mult: float = 2.667, + eps: float = 1e-6, + rope_theta: float = 10000.0, + 
causal_attn: bool = False, + ) -> None: + super().__init__() + self.causal_attn = causal_attn + self.layers = nn.ModuleList() + for _ in range(n_blocks): + self.layers.append( + AttnResSubLayer( + d_model=d_model, + n_heads=n_heads, + n_kv_heads=n_kv_heads, + dropout=dropout, + ffn_mult=ffn_mult, + eps=eps, + is_attention=True, + ) + ) + self.layers.append( + AttnResSubLayer( + d_model=d_model, + n_heads=n_heads, + n_kv_heads=n_kv_heads, + dropout=dropout, + ffn_mult=ffn_mult, + eps=eps, + is_attention=False, + ) + ) + + rope_freqs = precompute_rope_freqs( + dim=d_model // n_heads, + max_seq_len=max_seq_len, + theta=rope_theta, + ) + self.register_buffer('rope_freqs', rope_freqs, persistent=False) + + @staticmethod + def _build_causal_mask(seq_len: int, device: torch.device) -> Tensor: + mask = torch.full((seq_len, seq_len), float('-inf'), device=device) + mask = torch.triu(mask, diagonal=1) + return mask.unsqueeze(0).unsqueeze(0) + + def forward(self, x: Tensor) -> Tensor: + seq_len = x.shape[1] + rope_freqs = self.rope_freqs[:seq_len] + mask = None + if self.causal_attn: + mask = self._build_causal_mask(seq_len, x.device) + + layer_outputs = [x] + for layer in self.layers: + sources = torch.stack(layer_outputs, dim=0) + output = layer(sources, rope_freqs, mask) + layer_outputs.append(output) + + return torch.stack(layer_outputs, dim=0).sum(dim=0) diff --git a/diffusion_policy/model/diffusion/imf_transformer_for_diffusion.py b/diffusion_policy/model/diffusion/imf_transformer_for_diffusion.py index 0967915..7c47a2f 100644 --- a/diffusion_policy/model/diffusion/imf_transformer_for_diffusion.py +++ b/diffusion_policy/model/diffusion/imf_transformer_for_diffusion.py @@ -5,6 +5,15 @@ import torch import torch.nn as nn from diffusion_policy.model.common.module_attr_mixin import ModuleAttrMixin +from diffusion_policy.model.diffusion.attnres_transformer_components import ( + AttnResOperator, + AttnResSubLayer, + AttnResTransformerBackbone, + GroupedQuerySelfAttention, + RMSNorm, + RMSNormNoWeight, + SwiGLUFFN, +) from diffusion_policy.model.diffusion.positional_embedding import SinusoidalPosEmb logger = logging.getLogger(__name__) @@ -27,14 +36,20 @@ class IMFTransformerForDiffusion(ModuleAttrMixin): time_as_cond: bool = True, obs_as_cond: bool = False, n_cond_layers: int = 0, + backbone_type: str = 'vanilla', + n_kv_head: int = 1, + attn_res_ffn_mult: float = 2.667, + attn_res_eps: float = 1e-6, + attn_res_rope_theta: float = 10000.0, ) -> None: super().__init__() assert n_head == 1, 'IMFTransformerForDiffusion currently supports single-head attention only.' 
- if n_obs_steps is None: n_obs_steps = horizon + self.backbone_type = backbone_type + T = horizon T_cond = 2 if not time_as_cond: @@ -46,21 +61,77 @@ class IMFTransformerForDiffusion(ModuleAttrMixin): T_cond += n_obs_steps self.input_emb = nn.Linear(input_dim, n_emb) - self.pos_emb = nn.Parameter(torch.zeros(1, T, n_emb)) self.drop = nn.Dropout(p_drop_emb) - self.time_emb = SinusoidalPosEmb(n_emb) - self.cond_obs_emb = None - if obs_as_cond: - self.cond_obs_emb = nn.Linear(cond_dim, n_emb) - + self.cond_obs_emb = nn.Linear(cond_dim, n_emb) if obs_as_cond else None + self.time_token_proj = None self.cond_pos_emb = None + self.pos_emb = None self.encoder = None self.decoder = None + self.attnres_backbone = None encoder_only = False - if T_cond > 0: - self.cond_pos_emb = nn.Parameter(torch.zeros(1, T_cond, n_emb)) - if n_cond_layers > 0: + + if backbone_type == 'attnres_full': + if not time_as_cond: + raise ValueError('attnres_full backbone requires time_as_cond=True.') + if n_cond_layers != 0: + raise ValueError('attnres_full backbone does not support n_cond_layers > 0.') + + self.time_token_proj = nn.Linear(n_emb, n_emb) + self.attnres_backbone = AttnResTransformerBackbone( + d_model=n_emb, + n_blocks=n_layer, + n_heads=n_head, + n_kv_heads=n_kv_head, + max_seq_len=T + T_cond, + dropout=p_drop_attn, + ffn_mult=attn_res_ffn_mult, + eps=attn_res_eps, + rope_theta=attn_res_rope_theta, + causal_attn=causal_attn, + ) + self.ln_f = RMSNorm(n_emb, eps=attn_res_eps) + else: + self.pos_emb = nn.Parameter(torch.zeros(1, T, n_emb)) + if T_cond > 0: + self.cond_pos_emb = nn.Parameter(torch.zeros(1, T_cond, n_emb)) + if n_cond_layers > 0: + encoder_layer = nn.TransformerEncoderLayer( + d_model=n_emb, + nhead=n_head, + dim_feedforward=4 * n_emb, + dropout=p_drop_attn, + activation='gelu', + batch_first=True, + norm_first=True, + ) + self.encoder = nn.TransformerEncoder( + encoder_layer=encoder_layer, + num_layers=n_cond_layers, + ) + else: + self.encoder = nn.Sequential( + nn.Linear(n_emb, 4 * n_emb), + nn.Mish(), + nn.Linear(4 * n_emb, n_emb), + ) + + decoder_layer = nn.TransformerDecoderLayer( + d_model=n_emb, + nhead=n_head, + dim_feedforward=4 * n_emb, + dropout=p_drop_attn, + activation='gelu', + batch_first=True, + norm_first=True, + ) + self.decoder = nn.TransformerDecoder( + decoder_layer=decoder_layer, + num_layers=n_layer, + ) + else: + encoder_only = True encoder_layer = nn.TransformerEncoderLayer( d_model=n_emb, nhead=n_head, @@ -72,45 +143,12 @@ class IMFTransformerForDiffusion(ModuleAttrMixin): ) self.encoder = nn.TransformerEncoder( encoder_layer=encoder_layer, - num_layers=n_cond_layers, - ) - else: - self.encoder = nn.Sequential( - nn.Linear(n_emb, 4 * n_emb), - nn.Mish(), - nn.Linear(4 * n_emb, n_emb), + num_layers=n_layer, ) - decoder_layer = nn.TransformerDecoderLayer( - d_model=n_emb, - nhead=n_head, - dim_feedforward=4 * n_emb, - dropout=p_drop_attn, - activation='gelu', - batch_first=True, - norm_first=True, - ) - self.decoder = nn.TransformerDecoder( - decoder_layer=decoder_layer, - num_layers=n_layer, - ) - else: - encoder_only = True - encoder_layer = nn.TransformerEncoderLayer( - d_model=n_emb, - nhead=n_head, - dim_feedforward=4 * n_emb, - dropout=p_drop_attn, - activation='gelu', - batch_first=True, - norm_first=True, - ) - self.encoder = nn.TransformerEncoder( - encoder_layer=encoder_layer, - num_layers=n_layer, - ) + self.ln_f = nn.LayerNorm(n_emb) - if causal_attn: + if causal_attn and backbone_type != 'attnres_full': sz = T mask = (torch.triu(torch.ones(sz, sz)) == 
1).transpose(0, 1) mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0)) @@ -132,7 +170,6 @@ class IMFTransformerForDiffusion(ModuleAttrMixin): self.mask = None self.memory_mask = None - self.ln_f = nn.LayerNorm(n_emb) self.head = nn.Linear(n_emb, output_dim) self.T = T @@ -159,6 +196,11 @@ class IMFTransformerForDiffusion(ModuleAttrMixin): nn.ModuleList, nn.Mish, nn.Sequential, + AttnResTransformerBackbone, + AttnResSubLayer, + GroupedQuerySelfAttention, + SwiGLUFFN, + RMSNormNoWeight, ) if isinstance(module, (nn.Linear, nn.Embedding)): torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) @@ -176,12 +218,16 @@ class IMFTransformerForDiffusion(ModuleAttrMixin): bias = getattr(module, name) if bias is not None: torch.nn.init.zeros_(bias) - elif isinstance(module, nn.LayerNorm): - torch.nn.init.zeros_(module.bias) + elif isinstance(module, (nn.LayerNorm, RMSNorm)): + if getattr(module, 'bias', None) is not None: + torch.nn.init.zeros_(module.bias) torch.nn.init.ones_(module.weight) + elif isinstance(module, AttnResOperator): + torch.nn.init.zeros_(module.pseudo_query) elif isinstance(module, IMFTransformerForDiffusion): - torch.nn.init.normal_(module.pos_emb, mean=0.0, std=0.02) - if module.cond_obs_emb is not None: + if module.pos_emb is not None: + torch.nn.init.normal_(module.pos_emb, mean=0.0, std=0.02) + if module.cond_pos_emb is not None: torch.nn.init.normal_(module.cond_pos_emb, mean=0.0, std=0.02) elif isinstance(module, ignore_types): pass @@ -192,21 +238,24 @@ class IMFTransformerForDiffusion(ModuleAttrMixin): decay = set() no_decay = set() whitelist_weight_modules = (torch.nn.Linear, torch.nn.MultiheadAttention) - blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding) + blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding, RMSNorm) for mn, m in self.named_modules(): - for pn, _ in m.named_parameters(): + for pn, _ in m.named_parameters(recurse=False): fpn = '%s.%s' % (mn, pn) if mn else pn if pn.endswith('bias'): no_decay.add(fpn) elif pn.startswith('bias'): no_decay.add(fpn) + elif pn == 'pseudo_query': + no_decay.add(fpn) elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules): decay.add(fpn) elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules): no_decay.add(fpn) - no_decay.add('pos_emb') + if self.pos_emb is not None: + no_decay.add('pos_emb') no_decay.add('_dummy_variable') if self.cond_pos_emb is not None: no_decay.add('cond_pos_emb') @@ -250,18 +299,38 @@ class IMFTransformerForDiffusion(ModuleAttrMixin): value = value.to(device=sample.device, dtype=sample.dtype) return value.expand(sample.shape[0]) - def forward( + def _forward_attnres_full( self, sample: torch.Tensor, - r: Union[torch.Tensor, float, int], - t: Union[torch.Tensor, float, int], + r: torch.Tensor, + t: torch.Tensor, cond: Optional[torch.Tensor] = None, - ): - r = self._prepare_time_input(r, sample) - t = self._prepare_time_input(t, sample) + ) -> torch.Tensor: + sample_tokens = self.input_emb(sample) + token_parts = [ + self.time_token_proj(self.time_emb(r)).unsqueeze(1), + self.time_token_proj(self.time_emb(t)).unsqueeze(1), + ] + if self.obs_as_cond: + if cond is None: + raise ValueError('cond is required when obs_as_cond=True for attnres_full backbone.') + token_parts.append(self.cond_obs_emb(cond)) + token_parts.append(sample_tokens) + x = torch.cat(token_parts, dim=1) + x = self.drop(x) + x = self.attnres_backbone(x) + x = x[:, -sample_tokens.shape[1] :, :] + return x + + def _forward_vanilla( + self, 
+ sample: torch.Tensor, + r: torch.Tensor, + t: torch.Tensor, + cond: Optional[torch.Tensor] = None, + ) -> torch.Tensor: r_emb = self.time_emb(r).unsqueeze(1) t_emb = self.time_emb(t).unsqueeze(1) - input_emb = self.input_emb(sample) if self.encoder_only: @@ -292,6 +361,22 @@ class IMFTransformerForDiffusion(ModuleAttrMixin): tgt_mask=self.mask, memory_mask=self.memory_mask, ) + return x + + def forward( + self, + sample: torch.Tensor, + r: Union[torch.Tensor, float, int], + t: Union[torch.Tensor, float, int], + cond: Optional[torch.Tensor] = None, + ): + r = self._prepare_time_input(r, sample) + t = self._prepare_time_input(t, sample) + + if self.backbone_type == 'attnres_full': + x = self._forward_attnres_full(sample, r, t, cond=cond) + else: + x = self._forward_vanilla(sample, r, t, cond=cond) x = self.ln_f(x) x = self.head(x) diff --git a/diffusion_policy/policy/imf_transformer_hybrid_image_policy.py b/diffusion_policy/policy/imf_transformer_hybrid_image_policy.py index 4ffc77e..91d26f1 100644 --- a/diffusion_policy/policy/imf_transformer_hybrid_image_policy.py +++ b/diffusion_policy/policy/imf_transformer_hybrid_image_policy.py @@ -39,6 +39,11 @@ class IMFTransformerHybridImagePolicy(DiffusionTransformerHybridImagePolicy): time_as_cond=True, obs_as_cond=True, pred_action_steps_only=False, + backbone_type='vanilla', + n_kv_head=1, + attn_res_ffn_mult=2.667, + attn_res_eps=1e-6, + attn_res_rope_theta=10000.0, **kwargs, ): if num_inference_steps is None: @@ -90,6 +95,11 @@ class IMFTransformerHybridImagePolicy(DiffusionTransformerHybridImagePolicy): time_as_cond=time_as_cond, obs_as_cond=obs_as_cond, n_cond_layers=n_cond_layers, + backbone_type=backbone_type, + n_kv_head=n_kv_head, + attn_res_ffn_mult=attn_res_ffn_mult, + attn_res_eps=attn_res_eps, + attn_res_rope_theta=attn_res_rope_theta, ) self.num_inference_steps = 1 diff --git a/docs/superpowers/plans/2026-03-29-pusht-imf-attnres-implementation.md b/docs/superpowers/plans/2026-03-29-pusht-imf-attnres-implementation.md new file mode 100644 index 0000000..d782cb0 --- /dev/null +++ b/docs/superpowers/plans/2026-03-29-pusht-imf-attnres-implementation.md @@ -0,0 +1,57 @@ +# PushT Image iMF AttnRes Implementation Plan + +> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Add an AttnRes-backed full-attention iMF backbone for the PushT image experiment path, verify it with tests/smoke runs, then launch the 9-run 350-epoch architecture sweep across the local 5090 and remote 5880 GPUs. + +**Architecture:** Extend `IMFTransformerForDiffusion` with a selectable `attnres_full` backbone that keeps the current iMF training/inference API unchanged while replacing the transformer internals with RMSNorm + RoPE self-attention + SwiGLU + Full AttnRes depth-wise residual routing. Add one standalone Hydra config for the PushT image sweep and reuse queue-style launch scripts with unique SwanLab names. + +**Tech Stack:** Python 3.9 via uv, PyTorch 2.8 CUDA, Hydra, SwanLab online logging, local shell + SSH to trusted 5880 host. 
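As a reading aid for the Architecture paragraph above, here is a minimal standalone sketch of the Full AttnRes depth-wise residual aggregation. It mirrors the `AttnResOperator` added in `diffusion_policy/model/diffusion/attnres_transformer_components.py`, stripped of the weighted RMSNorm and dtype handling, and ends with a toy shape check; treat it as an illustration, not the implementation itself.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class AttnResOperatorSketch(nn.Module):
    """Instead of a plain x + f(x) residual, a learned pseudo-query attends over the
    embedding and every previous sub-layer output and returns their weighted sum."""

    def __init__(self, d_model: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.pseudo_query = nn.Parameter(torch.zeros(d_model))  # zero init => uniform mixing at start
        self.eps = eps

    def forward(self, sources: torch.Tensor) -> torch.Tensor:
        # sources: (n_sources, batch, seq, d_model), one slice per earlier sub-layer output
        keys = sources * torch.rsqrt(sources.pow(2).mean(-1, keepdim=True) + self.eps)
        logits = torch.einsum('d,nbtd->nbt', self.pseudo_query, keys)
        weights = F.softmax(logits, dim=0)  # normalize across depth, per (batch, position)
        return torch.einsum('nbt,nbtd->btd', weights, sources)


# Toy check: aggregate the embedding plus two prior sub-layer outputs (batch 2, seq 4, width 8).
sources = torch.randn(3, 2, 4, 8)
print(AttnResOperatorSketch(8)(sources).shape)  # torch.Size([2, 4, 8])
```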
+ +--- + +### Task 1: Add regression tests for the new AttnRes path + +**Files:** +- Modify: `tests/test_imf_transformer_for_diffusion.py` +- Modify: `tests/test_pusht_swanlab_config.py` + +- [ ] Add a failing model test that instantiates `IMFTransformerForDiffusion(backbone_type='attnres_full', causal_attn=False, ...)`, runs a forward pass with conditional observations, and asserts output shape plus optimizer construction. +- [ ] Run the targeted pytest selection and confirm the new test fails for the expected missing-backbone reason. +- [ ] Add a failing config regression test for `image_pusht_diffusion_policy_dit_imf_attnres_full.yaml` asserting SwanLab naming fields and `policy.causal_attn == False`. +- [ ] Re-run the targeted pytest selection and confirm the config test fails before implementation. + +### Task 2: Implement the AttnRes-backed iMF backbone + +**Files:** +- Create: `diffusion_policy/model/diffusion/attnres_transformer_components.py` +- Modify: `diffusion_policy/model/diffusion/imf_transformer_for_diffusion.py` + +- [ ] Add focused reusable modules for `RMSNorm`, RoPE helpers, grouped-query self-attention, SwiGLU FFN, and the Full AttnRes operator. +- [ ] Extend `IMFTransformerForDiffusion` with a `backbone_type` switch that preserves the existing vanilla path and adds an `attnres_full` path using concatenated `[r, t, obs, sample]` tokens. +- [ ] Ensure the AttnRes path slices condition tokens away before the output head so the returned tensor still matches the sample/action horizon. +- [ ] Update optimizer parameter grouping to treat RMSNorm weights like LayerNorm weights (no decay) and include any new positional/conditioning parameters. +- [ ] Run the targeted tests and get them green. + +### Task 3: Add the new PushT config and smoke-test path + +**Files:** +- Create: `image_pusht_diffusion_policy_dit_imf_attnres_full.yaml` +- Modify: `tests/test_pusht_swanlab_config.py` + +- [ ] Add a standalone PushT image config for the AttnRes iMF variant with SwanLab online logging, `policy.backbone_type=attnres_full`, and `policy.causal_attn=false`. +- [ ] Verify `uv run python train.py --config-dir=. --config-name=image_pusht_diffusion_policy_dit_imf_attnres_full.yaml --help` succeeds. +- [ ] Run a real smoke training command with `training.debug=true`, `training.device=cuda:0`, safety overrides (`dataloader.num_workers=0`, `task.env_runner.n_envs=1`, no vis), and confirm it reaches the training loop and writes a run directory. + +### Task 4: Prepare launch scripts and start the 9-run sweep + +**Files:** +- Create or modify: `data/run_logs/imf_attnres_local_queue.sh` +- Create or modify locally before copy: `data/run_logs/imf_attnres_remote_gpu0_queue.sh` +- Create or modify locally before copy: `data/run_logs/imf_attnres_remote_gpu1_queue.sh` + +- [ ] Write queue command templates for the 9 runs using config `image_pusht_diffusion_policy_dit_imf_attnres_full.yaml`, `training.num_epochs=350`, unique `exp_name/logging.name`, and shared `logging.group=imf_pusht_attnres_arch_sweep`. +- [ ] Sync the necessary config/model files plus remote queue scripts to `droid@100.73.14.65:~/project/diffusion_policy-smoke`. +- [ ] Start the local queue under `nohup`, record PID, and verify the first run log is advancing. +- [ ] Start the two remote queues under `nohup`, record PIDs, and verify both first-run logs are advancing. +- [ ] Confirm all three GPUs have officially entered training for the new sweep. 
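To make Task 4 concrete, a hypothetical sweep-command generator is sketched below. The config name, epoch count, run-name pattern, and logging group come from this plan and the accompanying design doc; the `policy.n_emb`, `policy.n_layer`, and `training.seed` override keys are assumptions about the Hydra config layout and may need adjusting, and the plan actually splits the nine commands three per queue rather than running them from one list.

```python
# Hypothetical generator for the 9 sweep commands in Task 4. The override keys
# policy.n_emb / policy.n_layer / training.seed are assumed, not confirmed.
CONFIG = 'image_pusht_diffusion_policy_dit_imf_attnres_full.yaml'


def sweep_commands(host_tag: str = 'local5090') -> list:
    cmds = []
    for n_emb in (128, 256, 384):
        for n_layer in (6, 12, 18):
            run_name = f'imf_attnres_emb{n_emb}_layer{n_layer}_seed42_{host_tag}'
            cmds.append(
                'uv run python train.py --config-dir=. '
                f'--config-name={CONFIG} '
                f'policy.n_emb={n_emb} policy.n_layer={n_layer} training.seed=42 '
                'training.num_epochs=350 '
                f'exp_name={run_name} logging.name={run_name} '
                'logging.group=imf_pusht_attnres_arch_sweep'
            )
    return cmds


if __name__ == '__main__':
    for cmd in sweep_commands():
        print(cmd)
```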
diff --git a/docs/superpowers/specs/2026-03-29-pusht-imf-attnres-design.md b/docs/superpowers/specs/2026-03-29-pusht-imf-attnres-design.md
new file mode 100644
index 0000000..cdd8c87
--- /dev/null
+++ b/docs/superpowers/specs/2026-03-29-pusht-imf-attnres-design.md
@@ -0,0 +1,108 @@
+# PushT Image iMF AttnRes Design
+
+## Goal
+On top of the existing PushT image iMF full-attention path, introduce the **Full AttnRes** residual-aggregation scheme from the `attn_res` repository, together with its matching **RMSNorm + self-attention + SwiGLU FFN** modules, while keeping the iMF training objective and one-step inference semantics unchanged and touching only this experiment path. Once implemented and verified, launch the same 9-run `n_emb × n_layer` sweep as before (350 epochs, seed=42, SwanLab online, no video recording).
+
+## Scope
+This work covers only:
+1. Adding an AttnRes-backed backbone variant to `IMFTransformerForDiffusion`;
+2. Keeping `forward(sample, r, t, cond=None)`, the iMF loss, and the one-step inference policy interface unchanged;
+3. Adding a standalone PushT image config for this variant;
+4. Reusing the local 5090 + remote dual-GPU 5880 three-queue parallel schedule for the 9 runs.
+
+Out of scope:
+- Replacing the existing vanilla iMF/full-attn configs;
+- Modifying the DiT baseline;
+- Adding video logging;
+- Expanding to multiple seeds.
+
+## Recommended Approach
+Adopt the approach of **adding an optional AttnRes backbone inside the current iMF model**, rather than building a separate policy path.
+
+Rationale:
+- The policy / workspace / loss / sampling paths are already validated; keeping them minimizes the change surface;
+- Switching the backbone only inside the model keeps the new experiments comparable with existing iMF results;
+- On the config side, only explicit switches such as `backbone_type=attnres_full` and `causal_attn=false` are needed, which makes reproduction straightforward.
+
+## Architecture
+### 1. Backbone split
+`IMFTransformerForDiffusion` keeps the existing vanilla encoder/decoder implementation as the default path and adds an `attnres_full` path:
+- **vanilla**: the current implementation, unchanged;
+- **attnres_full**: a single-stack full-attention Transformer whose input token sequence is
+  `[r token, t token, obs cond tokens..., action/sample tokens...]`.
+
+The model emits `u` predictions only at the trailing action/sample token positions; the leading condition tokens only provide context.
+
+### 2. AttnRes stack
+The new backbone uses the following modules:
+- `RMSNorm`
+- `Rotary Position Embedding` (applied to self-attention q/k)
+- `GroupedQueryAttention` (default `n_kv_head=1` in this experiment, compatible with the single-head setup)
+- `SwiGLU` FFN
+- `AttnResOperator` (one pseudo-query per sub-layer, performing full depth-wise residual aggregation)
+
+Each transformer block consists of two sub-layers:
+1. a self-attention sub-layer
+2. an FFN sub-layer
+
+Each sub-layer's input is no longer a plain `x + f(x)`; instead, `h_l` is aggregated from the embedding and all previous sub-layer outputs via Full AttnRes, followed by `RMSNorm(h_l) -> sublayer_fn(...)`.
+
+### 3. Conditioning and token flow
+- `sample` is first mapped to action tokens by `input_emb`;
+- `r` and `t` are each mapped to a condition token via `SinusoidalPosEmb + linear`;
+- the encoded image observation `cond` is mapped to obs tokens by `cond_obs_emb`;
+- the concatenated token sequence is fed through the AttnRes stack;
+- at the output, the leading condition tokens are sliced off, keeping only the action/sample segment, which then passes through `RMSNorm + head` to produce the final `u`.
+
+### 4. Attention mode
+This experiment path is fixed to **non-causal full attention**:
+- `causal_attn=false`
+- no causal mask is constructed
+- all tokens attend to each other bidirectionally
+
+This matches the user's requirement that training still use full attention (no causal masking).
+
+## Config and Logging
+Add a standalone config file, e.g.:
+- `image_pusht_diffusion_policy_dit_imf_attnres_full.yaml`
+
+This config must:
+- point to the existing `IMFTransformerHybridImagePolicy`
+- explicitly enable the AttnRes backbone parameters
+- set `policy.causal_attn=false`
+- keep `logging.backend=swanlab` and `logging.mode=online`
+- guarantee the following via runtime overrides:
+  - `logging.name=`
+  - `logging.group=imf_pusht_attnres_arch_sweep`
+  - `exp_name=`
+- keep `task.env_runner.n_test_vis=0` and `n_train_vis=0`, logging scalars only
+
+## Experiment Matrix
+Fixed 9 runs:
+- `n_emb ∈ {128, 256, 384}`
+- `n_layer ∈ {6, 12, 18}`
+- `seed=42`
+- `training.num_epochs=350`
+
+## Scheduling
+Reuse the previously validated three-queue assignment:
+- Local 5090: `384x18`, `256x6`, `128x6`
+- 5880 GPU0: `384x12`, `256x12`, `128x12`
+- 5880 GPU1: `384x6`, `256x18`, `128x18`
+
+Each run name encodes the backbone and architecture, e.g.:
+`imf_attnres_emb256_layer12_seed42_5880gpu0`
+
+## Verification
+During implementation, verify at least:
+1. The new config's SwanLab naming and `causal_attn=false` are correct;
+2. The new backbone's forward shape and `configure_optimizers()` both work;
+3. The existing vanilla-path tests do not regress;
+4. A `training.debug=true` smoke run completes end to end.
+
+## Success Criteria
+1. The new AttnRes iMF variant is trainable and supports one-step inference on this branch;
+2. The existing vanilla iMF/full-attn paths are unaffected;
+3. All 9 runs are officially launched across the three GPUs;
+4. SwanLab run names are unique, with no conflicts;
+5. No video is recorded; only scalars are logged.
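For quick reference alongside the design's conditioning/token-flow section, here is a shape-only sketch of how the `attnres_full` path assembles and then trims its token sequence. The sizes are illustrative; the real assembly lives in `_forward_attnres_full` in the patched `imf_transformer_for_diffusion.py`.

```python
import torch

# Illustrative sizes only: batch 2, horizon 5, 2 obs steps, embedding width 16.
B, T, T_obs, n_emb = 2, 5, 2, 16

r_tok = torch.randn(B, 1, n_emb)        # time_token_proj(time_emb(r))
t_tok = torch.randn(B, 1, n_emb)        # time_token_proj(time_emb(t))
obs_tok = torch.randn(B, T_obs, n_emb)  # cond_obs_emb(cond)
act_tok = torch.randn(B, T, n_emb)      # input_emb(sample)

# Concatenate [r, t, obs..., action...]; the backbone runs non-causal attention over everything.
tokens = torch.cat([r_tok, t_tok, obs_tok, act_tok], dim=1)  # (B, 2 + T_obs + T, n_emb)

# After the AttnRes stack, only the trailing action/sample positions feed RMSNorm + head.
u_tokens = tokens[:, -T:, :]
assert u_tokens.shape == (B, T, n_emb)
```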
diff --git a/image_pusht_diffusion_policy_dit_imf_attnres_full.yaml b/image_pusht_diffusion_policy_dit_imf_attnres_full.yaml new file mode 100644 index 0000000..303fed5 --- /dev/null +++ b/image_pusht_diffusion_policy_dit_imf_attnres_full.yaml @@ -0,0 +1,35 @@ +defaults: + - diffusion_policy/config/train_diffusion_transformer_hybrid_workspace@_here_ + - override /diffusion_policy/config/task@task: pusht_image + - _self_ + +exp_name: pusht_image_dit_imf_attnres_full + +policy: + _target_: diffusion_policy.policy.imf_transformer_hybrid_image_policy.IMFTransformerHybridImagePolicy + num_inference_steps: 1 + n_head: 1 + n_kv_head: 1 + causal_attn: false + backbone_type: attnres_full + +logging: + backend: swanlab + mode: online + name: ${exp_name} + resume: false + tags: ["${name}", "${task_name}", "${exp_name}", "swanlab"] + id: null + group: ${exp_name} + +dataloader: + num_workers: 0 + +val_dataloader: + num_workers: 0 + +task: + env_runner: + n_envs: 1 + n_test_vis: 0 + n_train_vis: 0
diff --git a/tests/test_imf_transformer_for_diffusion.py b/tests/test_imf_transformer_for_diffusion.py index c1ba345..d08be99 100644 --- a/tests/test_imf_transformer_for_diffusion.py +++ b/tests/test_imf_transformer_for_diffusion.py @@ -44,3 +44,34 @@ def test_imf_transformer_forward_signature_and_shape_single_head(): pred_u = model(sample, r, t, cond=cond) assert pred_u.shape == sample.shape + + +def test_imf_transformer_attnres_full_backbone_forward_shape_and_optimizer(): + model = IMFTransformerForDiffusion( + input_dim=3, + output_dim=3, + horizon=5, + n_obs_steps=2, + cond_dim=4, + n_layer=2, + n_head=1, + n_emb=16, + p_drop_emb=0.0, + p_drop_attn=0.0, + causal_attn=False, + time_as_cond=True, + obs_as_cond=True, + n_cond_layers=0, + backbone_type='attnres_full', + ) + optimizer = model.configure_optimizers() + + sample = torch.randn(2, 5, 3) + r = torch.rand(2) + t = torch.rand(2) + cond = torch.randn(2, 2, 4) + + pred_u = model(sample, r, t, cond=cond) + + assert pred_u.shape == sample.shape + assert optimizer is not None
diff --git a/tests/test_pusht_swanlab_config.py b/tests/test_pusht_swanlab_config.py index d008345..9e071ef 100644 --- a/tests/test_pusht_swanlab_config.py +++ b/tests/test_pusht_swanlab_config.py @@ -42,3 +42,16 @@ def test_image_pusht_dit_imf_fullattn_config_uses_exp_name_and_disables_causal_a assert cfg.logging.id is None assert cfg.logging.group == cfg.exp_name assert cfg.policy.causal_attn is False + + +def test_image_pusht_dit_imf_attnres_full_config_uses_exp_name_and_disables_causal_attention(): + cfg = _load_cfg('image_pusht_diffusion_policy_dit_imf_attnres_full.yaml') + + assert cfg.logging.backend == 'swanlab' + assert cfg.logging.mode == 'online' + assert cfg.logging.name == cfg.exp_name + assert cfg.logging.resume is False + assert cfg.logging.id is None + assert cfg.logging.group == cfg.exp_name + assert cfg.policy.causal_attn is False + assert cfg.policy.backbone_type == 'attnres_full'