2 Commits

Author  SHA1        Message                                 Date
Logic   211abbb87f  chore: add attnres sweep queue scripts  2026-03-29 11:18:35 +08:00
Logic   185ed6596c  feat: add pusht imf attnres backbone    2026-03-29 11:15:59 +08:00
11 changed files with 734 additions and 61 deletions

View File: diffusion_policy/model/diffusion/attnres_transformer_components.py

@@ -0,0 +1,247 @@
import math
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

class RMSNorm(nn.Module):
    """Scale-only RMS layer norm, computed in fp32 and cast back to the input dtype."""

    def __init__(self, dim: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: Tensor) -> Tensor:
        rms = torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + self.eps)
        return (x.float() * rms).to(x.dtype) * self.weight

class RMSNormNoWeight(nn.Module):
    """RMS norm without a learned scale; used to normalize the AttnRes keys."""

    def __init__(self, eps: float = 1e-6) -> None:
        super().__init__()
        self.eps = eps

    def forward(self, x: Tensor) -> Tensor:
        rms = torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + self.eps)
        return (x.float() * rms).to(x.dtype)

def precompute_rope_freqs(
    dim: int,
    max_seq_len: int,
    theta: float = 10000.0,
    device: Optional[torch.device] = None,
) -> Tensor:
    """Precompute complex RoPE rotations of shape (max_seq_len, dim // 2)."""
    if dim % 2 != 0:
        raise ValueError(f'RoPE requires an even head dimension, got {dim}.')
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, device=device).float() / dim))
    positions = torch.arange(max_seq_len, device=device).float()
    angles = torch.outer(positions, freqs)
    return torch.polar(torch.ones_like(angles), angles)

def apply_rope(x: Tensor, freqs: Tensor) -> Tensor:
    """Rotate (batch, seq, heads, head_dim) features by the precomputed frequencies."""
    x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    freqs = freqs.unsqueeze(0).unsqueeze(2)  # (1, seq, 1, head_dim // 2)
    x_rotated = x_complex * freqs
    return torch.view_as_real(x_rotated).reshape_as(x).to(x.dtype)

class GroupedQuerySelfAttention(nn.Module):
    """Self-attention with n_kv_heads shared K/V heads (GQA); n_kv_heads == n_heads
    recovers standard multi-head attention."""

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        n_kv_heads: int,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()
        if d_model % n_heads != 0:
            raise ValueError(f'd_model={d_model} must be divisible by n_heads={n_heads}.')
        if n_heads % n_kv_heads != 0:
            raise ValueError(f'n_heads={n_heads} must be divisible by n_kv_heads={n_kv_heads}.')
        self.d_model = d_model
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.n_kv_groups = n_heads // n_kv_heads
        self.d_head = d_model // n_heads
        self.attn_dropout = nn.Dropout(dropout)
        self.out_dropout = nn.Dropout(dropout)
        self.w_q = nn.Linear(d_model, n_heads * self.d_head, bias=False)
        self.w_k = nn.Linear(d_model, n_kv_heads * self.d_head, bias=False)
        self.w_v = nn.Linear(d_model, n_kv_heads * self.d_head, bias=False)
        self.w_o = nn.Linear(n_heads * self.d_head, d_model, bias=False)

    def forward(
        self,
        x: Tensor,
        rope_freqs: Tensor,
        mask: Optional[Tensor] = None,
    ) -> Tensor:
        batch_size, seq_len, _ = x.shape
        q = self.w_q(x).view(batch_size, seq_len, self.n_heads, self.d_head)
        k = self.w_k(x).view(batch_size, seq_len, self.n_kv_heads, self.d_head)
        v = self.w_v(x).view(batch_size, seq_len, self.n_kv_heads, self.d_head)
        q = apply_rope(q, rope_freqs)
        k = apply_rope(k, rope_freqs)
        if self.n_kv_heads != self.n_heads:
            # Repeat each K/V head across its query group (repeat-interleave semantics).
            k = k.unsqueeze(3).expand(
                batch_size, seq_len, self.n_kv_heads, self.n_kv_groups, self.d_head
            )
            k = k.reshape(batch_size, seq_len, self.n_heads, self.d_head)
            v = v.unsqueeze(3).expand(
                batch_size, seq_len, self.n_kv_heads, self.n_kv_groups, self.d_head
            )
            v = v.reshape(batch_size, seq_len, self.n_heads, self.d_head)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        scale = 1.0 / math.sqrt(self.d_head)
        attn_weights = torch.matmul(q, k.transpose(-2, -1)) * scale
        if mask is not None:
            attn_weights = attn_weights + mask
        attn_weights = F.softmax(attn_weights, dim=-1)
        attn_weights = self.attn_dropout(attn_weights)
        out = torch.matmul(attn_weights, v)
        out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        return self.out_dropout(self.w_o(out))

class SwiGLUFFN(nn.Module):
    def __init__(self, d_model: int, dropout: float = 0.0, mult: float = 2.667) -> None:
        super().__init__()
        raw = int(mult * d_model)
        d_ff = ((raw + 7) // 8) * 8  # round the hidden width up to a multiple of 8
        self.w_gate = nn.Linear(d_model, d_ff, bias=False)
        self.w_up = nn.Linear(d_model, d_ff, bias=False)
        self.w_down = nn.Linear(d_ff, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: Tensor) -> Tensor:
        return self.dropout(self.w_down(F.silu(self.w_gate(x)) * self.w_up(x)))

class AttnResOperator(nn.Module):
    """Full AttnRes: softmax-mix the embedding and all previous sub-layer outputs
    along the depth axis, scored by a single learned pseudo-query."""

    def __init__(self, d_model: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.pseudo_query = nn.Parameter(torch.zeros(d_model))
        self.key_norm = RMSNormNoWeight(eps=eps)

    def forward(self, sources: Tensor) -> Tensor:
        # sources: (n_sources, batch, seq, d_model)
        keys = self.key_norm(sources)
        logits = torch.einsum('d,nbtd->nbt', self.pseudo_query, keys)
        weights = F.softmax(logits, dim=0)  # softmax over depth, per (batch, position)
        return torch.einsum('nbt,nbtd->btd', weights, sources)

class AttnResSubLayer(nn.Module):
    """One sub-layer (attention or FFN) whose input is AttnRes-aggregated from
    all earlier outputs instead of a plain residual stream."""

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        n_kv_heads: int,
        dropout: float,
        ffn_mult: float,
        eps: float,
        is_attention: bool,
    ) -> None:
        super().__init__()
        self.norm = RMSNorm(d_model, eps=eps)
        self.attn_res = AttnResOperator(d_model, eps=eps)
        self.is_attention = is_attention
        if is_attention:
            self.fn = GroupedQuerySelfAttention(
                d_model=d_model,
                n_heads=n_heads,
                n_kv_heads=n_kv_heads,
                dropout=dropout,
            )
        else:
            self.fn = SwiGLUFFN(d_model=d_model, dropout=dropout, mult=ffn_mult)

    def forward(self, sources: Tensor, rope_freqs: Tensor, mask: Optional[Tensor] = None) -> Tensor:
        h = self.attn_res(sources)
        normed = self.norm(h)
        if self.is_attention:
            return self.fn(normed, rope_freqs, mask)
        return self.fn(normed)

class AttnResTransformerBackbone(nn.Module):
    """Stack of n_blocks x (attention, FFN) sub-layers with Full AttnRes routing;
    the output is the sum of the embedding and every sub-layer output."""

    def __init__(
        self,
        d_model: int,
        n_blocks: int,
        n_heads: int,
        n_kv_heads: int,
        max_seq_len: int,
        dropout: float = 0.0,
        ffn_mult: float = 2.667,
        eps: float = 1e-6,
        rope_theta: float = 10000.0,
        causal_attn: bool = False,
    ) -> None:
        super().__init__()
        self.causal_attn = causal_attn
        self.layers = nn.ModuleList()
        for _ in range(n_blocks):
            self.layers.append(
                AttnResSubLayer(
                    d_model=d_model,
                    n_heads=n_heads,
                    n_kv_heads=n_kv_heads,
                    dropout=dropout,
                    ffn_mult=ffn_mult,
                    eps=eps,
                    is_attention=True,
                )
            )
            self.layers.append(
                AttnResSubLayer(
                    d_model=d_model,
                    n_heads=n_heads,
                    n_kv_heads=n_kv_heads,
                    dropout=dropout,
                    ffn_mult=ffn_mult,
                    eps=eps,
                    is_attention=False,
                )
            )
        rope_freqs = precompute_rope_freqs(
            dim=d_model // n_heads,
            max_seq_len=max_seq_len,
            theta=rope_theta,
        )
        self.register_buffer('rope_freqs', rope_freqs, persistent=False)

    @staticmethod
    def _build_causal_mask(seq_len: int, device: torch.device) -> Tensor:
        mask = torch.full((seq_len, seq_len), float('-inf'), device=device)
        mask = torch.triu(mask, diagonal=1)
        return mask.unsqueeze(0).unsqueeze(0)

    def forward(self, x: Tensor) -> Tensor:
        seq_len = x.shape[1]
        rope_freqs = self.rope_freqs[:seq_len]
        mask = None
        if self.causal_attn:
            mask = self._build_causal_mask(seq_len, x.device)
        layer_outputs = [x]
        for layer in self.layers:
            # Each sub-layer sees the embedding plus all previous sub-layer outputs.
            sources = torch.stack(layer_outputs, dim=0)
            output = layer(sources, rope_freqs, mask)
            layer_outputs.append(output)
        return torch.stack(layer_outputs, dim=0).sum(dim=0)
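
Usage sketch (not part of the diff): a minimal standalone forward pass through the backbone above, with illustrative sizes; the import path follows the file location named in the plan later on this page.

import torch
from diffusion_policy.model.diffusion.attnres_transformer_components import AttnResTransformerBackbone

backbone = AttnResTransformerBackbone(
    d_model=16, n_blocks=2, n_heads=1, n_kv_heads=1,
    max_seq_len=32, causal_attn=False,
)
x = torch.randn(2, 9, 16)  # (batch, seq, d_model)
y = backbone(x)            # sum of the embedding and all sub-layer outputs
assert y.shape == x.shape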

View File: diffusion_policy/model/diffusion/imf_transformer_for_diffusion.py

@@ -5,6 +5,15 @@ import torch
import torch.nn as nn
from diffusion_policy.model.common.module_attr_mixin import ModuleAttrMixin
from diffusion_policy.model.diffusion.attnres_transformer_components import (
    AttnResOperator,
    AttnResSubLayer,
    AttnResTransformerBackbone,
    GroupedQuerySelfAttention,
    RMSNorm,
    RMSNormNoWeight,
    SwiGLUFFN,
)
from diffusion_policy.model.diffusion.positional_embedding import SinusoidalPosEmb

logger = logging.getLogger(__name__)
@@ -27,14 +36,20 @@ class IMFTransformerForDiffusion(ModuleAttrMixin):
        time_as_cond: bool = True,
        obs_as_cond: bool = False,
        n_cond_layers: int = 0,
        backbone_type: str = 'vanilla',
        n_kv_head: int = 1,
        attn_res_ffn_mult: float = 2.667,
        attn_res_eps: float = 1e-6,
        attn_res_rope_theta: float = 10000.0,
    ) -> None:
        super().__init__()
        assert n_head == 1, 'IMFTransformerForDiffusion currently supports single-head attention only.'
        if n_obs_steps is None:
            n_obs_steps = horizon
        self.backbone_type = backbone_type

        T = horizon
        T_cond = 2
        if not time_as_cond:
@@ -46,21 +61,77 @@ class IMFTransformerForDiffusion(ModuleAttrMixin):
            T_cond += n_obs_steps
        self.input_emb = nn.Linear(input_dim, n_emb)
        self.pos_emb = nn.Parameter(torch.zeros(1, T, n_emb))
        self.drop = nn.Dropout(p_drop_emb)
        self.time_emb = SinusoidalPosEmb(n_emb)
        self.cond_obs_emb = None
        if obs_as_cond:
            self.cond_obs_emb = nn.Linear(cond_dim, n_emb)
        self.cond_obs_emb = nn.Linear(cond_dim, n_emb) if obs_as_cond else None
        self.time_token_proj = None
        self.cond_pos_emb = None
        self.pos_emb = None
        self.encoder = None
        self.decoder = None
        self.attnres_backbone = None
        encoder_only = False
        if T_cond > 0:
            self.cond_pos_emb = nn.Parameter(torch.zeros(1, T_cond, n_emb))
            if n_cond_layers > 0:
        if backbone_type == 'attnres_full':
            if not time_as_cond:
                raise ValueError('attnres_full backbone requires time_as_cond=True.')
            if n_cond_layers != 0:
                raise ValueError('attnres_full backbone does not support n_cond_layers > 0.')
            self.time_token_proj = nn.Linear(n_emb, n_emb)
            self.attnres_backbone = AttnResTransformerBackbone(
                d_model=n_emb,
                n_blocks=n_layer,
                n_heads=n_head,
                n_kv_heads=n_kv_head,
                max_seq_len=T + T_cond,
                dropout=p_drop_attn,
                ffn_mult=attn_res_ffn_mult,
                eps=attn_res_eps,
                rope_theta=attn_res_rope_theta,
                causal_attn=causal_attn,
            )
            self.ln_f = RMSNorm(n_emb, eps=attn_res_eps)
        else:
            self.pos_emb = nn.Parameter(torch.zeros(1, T, n_emb))
            if T_cond > 0:
                self.cond_pos_emb = nn.Parameter(torch.zeros(1, T_cond, n_emb))
                if n_cond_layers > 0:
                    encoder_layer = nn.TransformerEncoderLayer(
                        d_model=n_emb,
                        nhead=n_head,
                        dim_feedforward=4 * n_emb,
                        dropout=p_drop_attn,
                        activation='gelu',
                        batch_first=True,
                        norm_first=True,
                    )
                    self.encoder = nn.TransformerEncoder(
                        encoder_layer=encoder_layer,
                        num_layers=n_cond_layers,
                    )
                else:
                    self.encoder = nn.Sequential(
                        nn.Linear(n_emb, 4 * n_emb),
                        nn.Mish(),
                        nn.Linear(4 * n_emb, n_emb),
                    )
                decoder_layer = nn.TransformerDecoderLayer(
                    d_model=n_emb,
                    nhead=n_head,
                    dim_feedforward=4 * n_emb,
                    dropout=p_drop_attn,
                    activation='gelu',
                    batch_first=True,
                    norm_first=True,
                )
                self.decoder = nn.TransformerDecoder(
                    decoder_layer=decoder_layer,
                    num_layers=n_layer,
                )
            else:
                encoder_only = True
                encoder_layer = nn.TransformerEncoderLayer(
                    d_model=n_emb,
                    nhead=n_head,
@@ -72,45 +143,12 @@ class IMFTransformerForDiffusion(ModuleAttrMixin):
                )
                self.encoder = nn.TransformerEncoder(
                    encoder_layer=encoder_layer,
                    num_layers=n_cond_layers,
                )
            else:
                self.encoder = nn.Sequential(
                    nn.Linear(n_emb, 4 * n_emb),
                    nn.Mish(),
                    nn.Linear(4 * n_emb, n_emb),
                num_layers=n_layer,
                )
            decoder_layer = nn.TransformerDecoderLayer(
                d_model=n_emb,
                nhead=n_head,
                dim_feedforward=4 * n_emb,
                dropout=p_drop_attn,
                activation='gelu',
                batch_first=True,
                norm_first=True,
            )
            self.decoder = nn.TransformerDecoder(
                decoder_layer=decoder_layer,
                num_layers=n_layer,
            )
        else:
            encoder_only = True
            encoder_layer = nn.TransformerEncoderLayer(
                d_model=n_emb,
                nhead=n_head,
                dim_feedforward=4 * n_emb,
                dropout=p_drop_attn,
                activation='gelu',
                batch_first=True,
                norm_first=True,
            )
            self.encoder = nn.TransformerEncoder(
                encoder_layer=encoder_layer,
                num_layers=n_layer,
            )
        self.ln_f = nn.LayerNorm(n_emb)
        if causal_attn:
        if causal_attn and backbone_type != 'attnres_full':
            sz = T
            mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
            mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
@@ -132,7 +170,6 @@ class IMFTransformerForDiffusion(ModuleAttrMixin):
            self.mask = None
            self.memory_mask = None
        self.ln_f = nn.LayerNorm(n_emb)
        self.head = nn.Linear(n_emb, output_dim)
        self.T = T
@@ -159,6 +196,11 @@ class IMFTransformerForDiffusion(ModuleAttrMixin):
            nn.ModuleList,
            nn.Mish,
            nn.Sequential,
            AttnResTransformerBackbone,
            AttnResSubLayer,
            GroupedQuerySelfAttention,
            SwiGLUFFN,
            RMSNormNoWeight,
        )
        if isinstance(module, (nn.Linear, nn.Embedding)):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
@@ -176,12 +218,16 @@ class IMFTransformerForDiffusion(ModuleAttrMixin):
                bias = getattr(module, name)
                if bias is not None:
                    torch.nn.init.zeros_(bias)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.zeros_(module.bias)
        elif isinstance(module, (nn.LayerNorm, RMSNorm)):
            if getattr(module, 'bias', None) is not None:
                torch.nn.init.zeros_(module.bias)
            torch.nn.init.ones_(module.weight)
        elif isinstance(module, AttnResOperator):
            torch.nn.init.zeros_(module.pseudo_query)
        elif isinstance(module, IMFTransformerForDiffusion):
            torch.nn.init.normal_(module.pos_emb, mean=0.0, std=0.02)
            if module.cond_obs_emb is not None:
            if module.pos_emb is not None:
                torch.nn.init.normal_(module.pos_emb, mean=0.0, std=0.02)
            if module.cond_pos_emb is not None:
                torch.nn.init.normal_(module.cond_pos_emb, mean=0.0, std=0.02)
        elif isinstance(module, ignore_types):
            pass
@@ -192,21 +238,24 @@ class IMFTransformerForDiffusion(ModuleAttrMixin):
        decay = set()
        no_decay = set()
        whitelist_weight_modules = (torch.nn.Linear, torch.nn.MultiheadAttention)
        blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
        blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding, RMSNorm)
        for mn, m in self.named_modules():
            for pn, _ in m.named_parameters():
            for pn, _ in m.named_parameters(recurse=False):
                fpn = '%s.%s' % (mn, pn) if mn else pn
                if pn.endswith('bias'):
                    no_decay.add(fpn)
                elif pn.startswith('bias'):
                    no_decay.add(fpn)
                elif pn == 'pseudo_query':
                    no_decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
                    decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
                    no_decay.add(fpn)
        no_decay.add('pos_emb')
        if self.pos_emb is not None:
            no_decay.add('pos_emb')
        no_decay.add('_dummy_variable')
        if self.cond_pos_emb is not None:
            no_decay.add('cond_pos_emb')
@@ -250,18 +299,38 @@ class IMFTransformerForDiffusion(ModuleAttrMixin):
            value = value.to(device=sample.device, dtype=sample.dtype)
        return value.expand(sample.shape[0])

    def forward(
    def _forward_attnres_full(
        self,
        sample: torch.Tensor,
        r: Union[torch.Tensor, float, int],
        t: Union[torch.Tensor, float, int],
        r: torch.Tensor,
        t: torch.Tensor,
        cond: Optional[torch.Tensor] = None,
    ):
        r = self._prepare_time_input(r, sample)
        t = self._prepare_time_input(t, sample)
    ) -> torch.Tensor:
        sample_tokens = self.input_emb(sample)
        token_parts = [
            self.time_token_proj(self.time_emb(r)).unsqueeze(1),
            self.time_token_proj(self.time_emb(t)).unsqueeze(1),
        ]
        if self.obs_as_cond:
            if cond is None:
                raise ValueError('cond is required when obs_as_cond=True for attnres_full backbone.')
            token_parts.append(self.cond_obs_emb(cond))
        token_parts.append(sample_tokens)
        x = torch.cat(token_parts, dim=1)
        x = self.drop(x)
        x = self.attnres_backbone(x)
        x = x[:, -sample_tokens.shape[1]:, :]
        return x

    def _forward_vanilla(
        self,
        sample: torch.Tensor,
        r: torch.Tensor,
        t: torch.Tensor,
        cond: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        r_emb = self.time_emb(r).unsqueeze(1)
        t_emb = self.time_emb(t).unsqueeze(1)
        input_emb = self.input_emb(sample)
        if self.encoder_only:
@@ -292,6 +361,22 @@ class IMFTransformerForDiffusion(ModuleAttrMixin):
                tgt_mask=self.mask,
                memory_mask=self.memory_mask,
            )
        return x

    def forward(
        self,
        sample: torch.Tensor,
        r: Union[torch.Tensor, float, int],
        t: Union[torch.Tensor, float, int],
        cond: Optional[torch.Tensor] = None,
    ):
        r = self._prepare_time_input(r, sample)
        t = self._prepare_time_input(t, sample)
        if self.backbone_type == 'attnres_full':
            x = self._forward_attnres_full(sample, r, t, cond=cond)
        else:
            x = self._forward_vanilla(sample, r, t, cond=cond)
        x = self.ln_f(x)
        x = self.head(x)

View File: diffusion_policy/policy/imf_transformer_hybrid_image_policy.py

@@ -39,6 +39,11 @@ class IMFTransformerHybridImagePolicy(DiffusionTransformerHybridImagePolicy):
        time_as_cond=True,
        obs_as_cond=True,
        pred_action_steps_only=False,
        backbone_type='vanilla',
        n_kv_head=1,
        attn_res_ffn_mult=2.667,
        attn_res_eps=1e-6,
        attn_res_rope_theta=10000.0,
        **kwargs,
    ):
        if num_inference_steps is None:
@@ -90,6 +95,11 @@ class IMFTransformerHybridImagePolicy(DiffusionTransformerHybridImagePolicy):
            time_as_cond=time_as_cond,
            obs_as_cond=obs_as_cond,
            n_cond_layers=n_cond_layers,
            backbone_type=backbone_type,
            n_kv_head=n_kv_head,
            attn_res_ffn_mult=attn_res_ffn_mult,
            attn_res_eps=attn_res_eps,
            attn_res_rope_theta=attn_res_rope_theta,
        )
        self.num_inference_steps = 1

View File

@@ -0,0 +1,57 @@
# PushT Image iMF AttnRes Implementation Plan
> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking.
**Goal:** Add an AttnRes-backed full-attention iMF backbone for the PushT image experiment path, verify it with tests/smoke runs, then launch the 9-run 350-epoch architecture sweep across the local 5090 and remote 5880 GPUs.
**Architecture:** Extend `IMFTransformerForDiffusion` with a selectable `attnres_full` backbone that keeps the current iMF training/inference API unchanged while replacing the transformer internals with RMSNorm + RoPE self-attention + SwiGLU + Full AttnRes depth-wise residual routing. Add one standalone Hydra config for the PushT image sweep and reuse queue-style launch scripts with unique SwanLab names.
**Tech Stack:** Python 3.9 via uv, PyTorch 2.8 CUDA, Hydra, SwanLab online logging, local shell + SSH to trusted 5880 host.
---
### Task 1: Add regression tests for the new AttnRes path
**Files:**
- Modify: `tests/test_imf_transformer_for_diffusion.py`
- Modify: `tests/test_pusht_swanlab_config.py`
- [ ] Add a failing model test that instantiates `IMFTransformerForDiffusion(backbone_type='attnres_full', causal_attn=False, ...)`, runs a forward pass with conditional observations, and asserts output shape plus optimizer construction.
- [ ] Run the targeted pytest selection and confirm the new test fails for the expected missing-backbone reason.
- [ ] Add a failing config regression test for `image_pusht_diffusion_policy_dit_imf_attnres_full.yaml` asserting SwanLab naming fields and `policy.causal_attn == False`.
- [ ] Re-run the targeted pytest selection and confirm the config test fails before implementation.
### Task 2: Implement the AttnRes-backed iMF backbone
**Files:**
- Create: `diffusion_policy/model/diffusion/attnres_transformer_components.py`
- Modify: `diffusion_policy/model/diffusion/imf_transformer_for_diffusion.py`
- [ ] Add focused reusable modules for `RMSNorm`, RoPE helpers, grouped-query self-attention, SwiGLU FFN, and the Full AttnRes operator.
- [ ] Extend `IMFTransformerForDiffusion` with a `backbone_type` switch that preserves the existing vanilla path and adds an `attnres_full` path using concatenated `[r, t, obs, sample]` tokens.
- [ ] Ensure the AttnRes path slices condition tokens away before the output head so the returned tensor still matches the sample/action horizon.
- [ ] Update optimizer parameter grouping to treat RMSNorm weights like LayerNorm weights (no decay) and include any new positional/conditioning parameters; see the sketch after this task's checklist.
- [ ] Run the targeted tests and get them green.
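
As referenced in the grouping item above, a minimal sketch of the intended decay/no-decay split (hedged: class-name matching stands in for the real `isinstance` checks, and the returned groups are assumed to feed a standard AdamW constructor):

```python
import torch.nn as nn

def split_param_groups(model: nn.Module, weight_decay: float):
    decay, no_decay = [], []
    for module in model.modules():
        for name, param in module.named_parameters(recurse=False):
            is_norm = module.__class__.__name__ in ('LayerNorm', 'RMSNorm')
            # Biases, norm weights, and the AttnRes pseudo_query get no decay.
            if name.endswith('bias') or name == 'pseudo_query' or is_norm:
                no_decay.append(param)
            else:
                decay.append(param)
    return [
        {'params': decay, 'weight_decay': weight_decay},
        {'params': no_decay, 'weight_decay': 0.0},
    ]
```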
### Task 3: Add the new PushT config and smoke-test path
**Files:**
- Create: `image_pusht_diffusion_policy_dit_imf_attnres_full.yaml`
- Modify: `tests/test_pusht_swanlab_config.py`
- [ ] Add a standalone PushT image config for the AttnRes iMF variant with SwanLab online logging, `policy.backbone_type=attnres_full`, and `policy.causal_attn=false`.
- [ ] Verify `uv run python train.py --config-dir=. --config-name=image_pusht_diffusion_policy_dit_imf_attnres_full.yaml --help` succeeds.
- [ ] Run a real smoke training command with `training.debug=true`, `training.device=cuda:0`, safety overrides (`dataloader.num_workers=0`, `task.env_runner.n_envs=1`, no vis), and confirm it reaches the training loop and writes a run directory.
### Task 4: Prepare launch scripts and start the 9-run sweep
**Files:**
- Create or modify: `data/run_logs/imf_attnres_local_queue.sh`
- Create or modify locally before copy: `data/run_logs/imf_attnres_remote_gpu0_queue.sh`
- Create or modify locally before copy: `data/run_logs/imf_attnres_remote_gpu1_queue.sh`
- [ ] Write queue command templates for the 9 runs using config `image_pusht_diffusion_policy_dit_imf_attnres_full.yaml`, `training.num_epochs=350`, unique `exp_name/logging.name`, and shared `logging.group=imf_pusht_attnres_arch_sweep`.
- [ ] Sync the necessary config/model files plus remote queue scripts to `droid@100.73.14.65:~/project/diffusion_policy-smoke`.
- [ ] Start the local queue under `nohup`, record PID, and verify the first run log is advancing.
- [ ] Start the two remote queues under `nohup`, record PIDs, and verify both first-run logs are advancing.
- [ ] Confirm that training is underway on all three GPUs for the new sweep.

View File

@@ -0,0 +1,108 @@
# PushT Image iMF AttnRes Design
## Goal
On top of the existing PushT image iMF full-attention path, introduce the **Full AttnRes** residual-aggregation scheme from the `attn_res` repository, together with its matching **RMSNorm + self-attention + SwiGLU FFN** modules, while keeping the iMF training objective and one-step inference semantics unchanged and touching only this experiment path. Once implemented and verified, launch the same 9-run `n_emb × n_layer` sweep as before (350 epochs, seed=42, SwanLab online, no video recording).
## Scope
This work covers only:
1. adding an AttnRes-backed backbone variant to `IMFTransformerForDiffusion`;
2. keeping the `forward(sample, r, t, cond=None)` signature, the iMF loss, and the one-step inference policy interface unchanged;
3. adding a standalone PushT image config for this variant;
4. reusing the local 5090 plus the dual-GPU remote 5880 as three parallel queues to schedule the 9 runs.
Out of scope:
- replacing the existing vanilla iMF/full-attn configs;
- modifying the DiT baseline;
- adding video logging;
- expanding to multiple seeds.
## Recommended Approach
Add an **optional AttnRes backbone inside the current iMF model** instead of building a separate policy path.
Rationale:
- the policy / workspace / loss / sampling paths are already validated; keeping them minimizes the change surface;
- switching the backbone inside the model keeps the new experiment comparable with existing iMF results;
- the config only needs explicit switches such as `backbone_type=attnres_full` and `causal_attn=false`, which makes reproduction straightforward.
## Architecture
### 1. Backbone split
`IMFTransformerForDiffusion` keeps the existing vanilla encoder/decoder implementation as the default path and adds an `attnres_full` path:
- **vanilla**: the current implementation, unchanged;
- **attnres_full**: a single-stack full-attention transformer whose input token sequence is
`[r token, t token, obs cond tokens..., action/sample tokens...]`.
The model emits `u` predictions only at the trailing action/sample token positions; the leading condition tokens serve purely as context.
### 2. AttnRes stack
The new backbone uses the following modules:
- `RMSNorm`
- rotary position embedding (applied to the self-attention q/k)
- `GroupedQueryAttention` (this experiment defaults to `n_kv_head=1`, compatible with the single-head configuration)
- `SwiGLU` FFN
- `AttnResOperator` (one pseudo-query per sub-layer, performing full depth-wise residual aggregation)
Each transformer block consists of two sub-layers:
1. a self-attention sub-layer
2. an FFN sub-layer
A sub-layer's input is no longer a plain `x + f(x)`; instead, `h_l` is aggregated from the embedding and all previous sub-layer outputs via Full AttnRes, then the sub-layer runs `RMSNorm(h_l) -> sublayer_fn(...)`. A minimal sketch of the aggregation step follows.
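
The sketch below mirrors `AttnResOperator` from the components file in this diff; the function wrapper and shape comments are illustrative, not part of the module.

```python
import torch
import torch.nn.functional as F

def full_attnres(sources: torch.Tensor, pseudo_query: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # sources: (n_sources, batch, seq, d_model) = embedding + all prior sub-layer outputs
    keys = sources * torch.rsqrt(sources.float().pow(2).mean(-1, keepdim=True) + eps)
    logits = torch.einsum('d,nbtd->nbt', pseudo_query, keys)  # one score per depth and position
    weights = F.softmax(logits, dim=0)                        # softmax over depth, per (batch, position)
    return torch.einsum('nbt,nbtd->btd', weights, sources)    # aggregated h_l: (batch, seq, d_model)
```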
### 3. Conditioning and token flow
- `sample` is first mapped to action tokens by `input_emb`;
- `r` and `t` each pass through `SinusoidalPosEmb + linear` to become two condition tokens;
- the encoded image observation `cond` is mapped to obs tokens by `cond_obs_emb`;
- the concatenated full token sequence enters the AttnRes stack;
- at the output, the leading condition tokens are sliced away so only the action/sample token span remains, which then passes through `RMSNorm + head` to produce the final `u`. A sketch of this flow appears below the list.
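
A minimal sketch of the assembly and slicing (hedged: `model` stands for an `IMFTransformerForDiffusion` instance exposing the sub-modules created in the diff above; shapes are assumptions):

```python
import torch

def assemble_and_slice(model, sample, r, t, cond):
    # sample: (B, horizon, input_dim); r, t: (B,); cond: (B, n_obs_steps, cond_dim)
    parts = [
        model.time_token_proj(model.time_emb(r)).unsqueeze(1),  # 1 token for r
        model.time_token_proj(model.time_emb(t)).unsqueeze(1),  # 1 token for t
        model.cond_obs_emb(cond),                               # n_obs_steps obs tokens
        model.input_emb(sample),                                # horizon action tokens
    ]
    x = model.attnres_backbone(torch.cat(parts, dim=1))
    return x[:, -sample.shape[1]:, :]  # keep only the trailing action span
```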
### 4. Attention mode
This experiment path is fixed to **non-causal full attention**:
- `causal_attn=false`
- no causal mask is constructed
- all tokens are bidirectionally visible to each other
This matches the user's requirement that training still use full attention (no causal masking).
## Config and Logging
Add a standalone config file, for example:
- `image_pusht_diffusion_policy_dit_imf_attnres_full.yaml`
This config must:
- point at the existing `IMFTransformerHybridImagePolicy`;
- explicitly enable the AttnRes backbone parameters;
- set `policy.causal_attn=false`;
- keep `logging.backend=swanlab` and `logging.mode=online`;
- guarantee at launch time, via overrides:
  - `logging.name=<unique_run_name>`
  - `logging.group=imf_pusht_attnres_arch_sweep`
  - `exp_name=<unique_run_name>`
- keep `task.env_runner.n_test_vis=0` and `n_train_vis=0`, logging scalars only.
## Experiment Matrix
A fixed grid of 9 runs:
- `n_emb ∈ {128, 256, 384}`
- `n_layer ∈ {6, 12, 18}`
- `seed=42`
- `training.num_epochs=350`
## Scheduling
Reuse the previously validated three-queue assignment:
- local 5090: `384x18`, `256x6`, `128x6`
- 5880 GPU0: `384x12`, `256x12`, `128x12`
- 5880 GPU1: `384x6`, `256x18`, `128x18`
Each run name encodes the backbone and architecture, e.g.
`imf_attnres_emb256_layer12_seed42_5880gpu0`; a sketch generating the full grid of names follows.
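
A minimal sketch enumerating the grid and its queue assignment (the `QUEUES` dict restates the table above; the host suffixes match the run names in the queue scripts later in this diff):

```python
from itertools import product

QUEUES = {
    'local': [(384, 18), (256, 6), (128, 6)],
    '5880gpu0': [(384, 12), (256, 12), (128, 12)],
    '5880gpu1': [(384, 6), (256, 18), (128, 18)],
}

def run_name(emb: int, layer: int, host: str, seed: int = 42) -> str:
    return f'imf_attnres_emb{emb}_layer{layer}_seed{seed}_{host}'

# The three queues together cover the full 3x3 grid exactly once.
assert sorted(p for q in QUEUES.values() for p in q) == sorted(product((128, 256, 384), (6, 12, 18)))
for host, grid in QUEUES.items():
    for emb, layer in grid:
        print(run_name(emb, layer, host))
```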
## Verification
The implementation phase must verify at least:
1. the new config's SwanLab naming and `causal_attn=false` are correct;
2. the new backbone's forward shape and `configure_optimizers()` both work;
3. the existing vanilla-path tests do not regress;
4. a `training.debug=true` smoke run passes end to end.
## Success Criteria
1. The new AttnRes iMF variant trains and runs one-step inference on this branch;
2. the existing vanilla iMF/full-attn path is unaffected;
3. all 9 runs launch successfully across the three GPUs;
4. SwanLab run names are unique, with no collisions;
5. no video is recorded; scalars only.

View File: image_pusht_diffusion_policy_dit_imf_attnres_full.yaml

@@ -0,0 +1,35 @@
defaults:
  - diffusion_policy/config/train_diffusion_transformer_hybrid_workspace@_here_
  - override /diffusion_policy/config/task@task: pusht_image
  - _self_

exp_name: pusht_image_dit_imf_attnres_full

policy:
  _target_: diffusion_policy.policy.imf_transformer_hybrid_image_policy.IMFTransformerHybridImagePolicy
  num_inference_steps: 1
  n_head: 1
  n_kv_head: 1
  causal_attn: false
  backbone_type: attnres_full

logging:
  backend: swanlab
  mode: online
  name: ${exp_name}
  resume: false
  tags: ["${name}", "${task_name}", "${exp_name}", "swanlab"]
  id: null
  group: ${exp_name}

dataloader:
  num_workers: 0

val_dataloader:
  num_workers: 0

task:
  env_runner:
    n_envs: 1
    n_test_vis: 0
    n_train_vis: 0

View File: data/run_logs/imf_attnres_local_queue.sh

@@ -0,0 +1,29 @@
#!/usr/bin/env bash
set -euo pipefail
cd /home/droid/project/diffusion_policy/.worktrees/feat-pusht-imf-attnres
export PYTHONUNBUFFERED=1
export SWANLAB_API_KEY='PSZrBMLx1XAjDjvmhUcNz'
export LD_LIBRARY_PATH="$(printf '%s:' .venv/lib/python3.9/site-packages/nvidia/*/lib | sed 's/:$//')"
run_exp() {
  local name="$1" emb="$2" layer="$3"
  echo "[$(date '+%F %T')] START $name emb=$emb layer=$layer"
  .venv/bin/python train.py \
    --config-dir=. \
    --config-name=image_pusht_diffusion_policy_dit_imf_attnres_full.yaml \
    training.device=cuda:0 \
    training.num_epochs=350 \
    training.resume=false \
    exp_name="$name" \
    logging.group=imf_pusht_attnres_arch_sweep \
    logging.name="$name" \
    logging.resume=false \
    logging.id=null \
    hydra.run.dir="data/outputs/$name" \
    policy.n_emb="$emb" \
    policy.n_layer="$layer" \
    > "data/run_logs/${name}.log" 2>&1
  echo "[$(date '+%F %T')] END $name"
}
run_exp imf_attnres_emb384_layer18_seed42_local 384 18
run_exp imf_attnres_emb256_layer6_seed42_local 256 6
run_exp imf_attnres_emb128_layer6_seed42_local 128 6

View File: data/run_logs/imf_attnres_remote_gpu0_queue.sh

@@ -0,0 +1,29 @@
#!/usr/bin/env bash
set -euo pipefail
cd /home/droid/project/diffusion_policy-smoke
export PYTHONUNBUFFERED=1
export SWANLAB_API_KEY='PSZrBMLx1XAjDjvmhUcNz'
export LD_LIBRARY_PATH="$(printf '%s:' .venv/lib/python3.9/site-packages/nvidia/*/lib | sed 's/:$//')"
run_exp() {
  local name="$1" emb="$2" layer="$3"
  echo "[$(date '+%F %T')] START $name emb=$emb layer=$layer"
  .venv/bin/python train.py \
    --config-dir=. \
    --config-name=image_pusht_diffusion_policy_dit_imf_attnres_full.yaml \
    training.device=cuda:0 \
    training.num_epochs=350 \
    training.resume=false \
    exp_name="$name" \
    logging.group=imf_pusht_attnres_arch_sweep \
    logging.name="$name" \
    logging.resume=false \
    logging.id=null \
    hydra.run.dir="data/outputs/$name" \
    policy.n_emb="$emb" \
    policy.n_layer="$layer" \
    > "data/run_logs/${name}.log" 2>&1
  echo "[$(date '+%F %T')] END $name"
}
run_exp imf_attnres_emb384_layer12_seed42_5880gpu0 384 12
run_exp imf_attnres_emb256_layer12_seed42_5880gpu0 256 12
run_exp imf_attnres_emb128_layer12_seed42_5880gpu0 128 12

View File: data/run_logs/imf_attnres_remote_gpu1_queue.sh

@@ -0,0 +1,29 @@
#!/usr/bin/env bash
set -euo pipefail
cd /home/droid/project/diffusion_policy-smoke
export PYTHONUNBUFFERED=1
export SWANLAB_API_KEY='PSZrBMLx1XAjDjvmhUcNz'
export LD_LIBRARY_PATH="$(printf '%s:' .venv/lib/python3.9/site-packages/nvidia/*/lib | sed 's/:$//')"
run_exp() {
  local name="$1" emb="$2" layer="$3"
  echo "[$(date '+%F %T')] START $name emb=$emb layer=$layer"
  .venv/bin/python train.py \
    --config-dir=. \
    --config-name=image_pusht_diffusion_policy_dit_imf_attnres_full.yaml \
    training.device=cuda:1 \
    training.num_epochs=350 \
    training.resume=false \
    exp_name="$name" \
    logging.group=imf_pusht_attnres_arch_sweep \
    logging.name="$name" \
    logging.resume=false \
    logging.id=null \
    hydra.run.dir="data/outputs/$name" \
    policy.n_emb="$emb" \
    policy.n_layer="$layer" \
    > "data/run_logs/${name}.log" 2>&1
  echo "[$(date '+%F %T')] END $name"
}
run_exp imf_attnres_emb384_layer6_seed42_5880gpu1 384 6
run_exp imf_attnres_emb256_layer18_seed42_5880gpu1 256 18
run_exp imf_attnres_emb128_layer18_seed42_5880gpu1 128 18

View File: tests/test_imf_transformer_for_diffusion.py

@@ -44,3 +44,34 @@ def test_imf_transformer_forward_signature_and_shape_single_head():
    pred_u = model(sample, r, t, cond=cond)
    assert pred_u.shape == sample.shape


def test_imf_transformer_attnres_full_backbone_forward_shape_and_optimizer():
    model = IMFTransformerForDiffusion(
        input_dim=3,
        output_dim=3,
        horizon=5,
        n_obs_steps=2,
        cond_dim=4,
        n_layer=2,
        n_head=1,
        n_emb=16,
        p_drop_emb=0.0,
        p_drop_attn=0.0,
        causal_attn=False,
        time_as_cond=True,
        obs_as_cond=True,
        n_cond_layers=0,
        backbone_type='attnres_full',
    )
    optimizer = model.configure_optimizers()
    sample = torch.randn(2, 5, 3)
    r = torch.rand(2)
    t = torch.rand(2)
    cond = torch.randn(2, 2, 4)
    pred_u = model(sample, r, t, cond=cond)
    assert pred_u.shape == sample.shape
    assert optimizer is not None

View File: tests/test_pusht_swanlab_config.py

@@ -42,3 +42,16 @@ def test_image_pusht_dit_imf_fullattn_config_uses_exp_name_and_disables_causal_a
    assert cfg.logging.id is None
    assert cfg.logging.group == cfg.exp_name
    assert cfg.policy.causal_attn is False


def test_image_pusht_dit_imf_attnres_full_config_uses_exp_name_and_disables_causal_attention():
    cfg = _load_cfg('image_pusht_diffusion_policy_dit_imf_attnres_full.yaml')
    assert cfg.logging.backend == 'swanlab'
    assert cfg.logging.mode == 'online'
    assert cfg.logging.name == cfg.exp_name
    assert cfg.logging.resume is False
    assert cfg.logging.id is None
    assert cfg.logging.group == cfg.exp_name
    assert cfg.policy.causal_attn is False
    assert cfg.policy.backbone_type == 'attnres_full'
assert cfg.policy.backbone_type == 'attnres_full'