1 Commits

Author SHA1 Message Date
Logic
31925bbf39 feat: add pusht dit no-causal config 2026-03-27 17:06:16 +08:00
15 changed files with 118 additions and 907 deletions

View File

@@ -1,247 +0,0 @@
import math
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-6) -> None:
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def forward(self, x: Tensor) -> Tensor:
rms = torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + self.eps)
return (x.float() * rms).to(x.dtype) * self.weight
class RMSNormNoWeight(nn.Module):
def __init__(self, eps: float = 1e-6) -> None:
super().__init__()
self.eps = eps
def forward(self, x: Tensor) -> Tensor:
rms = torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + self.eps)
return (x.float() * rms).to(x.dtype)
def precompute_rope_freqs(
dim: int,
max_seq_len: int,
theta: float = 10000.0,
device: Optional[torch.device] = None,
) -> Tensor:
if dim % 2 != 0:
raise ValueError(f'RoPE requires an even head dimension, got {dim}.')
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, device=device).float() / dim))
positions = torch.arange(max_seq_len, device=device).float()
angles = torch.outer(positions, freqs)
return torch.polar(torch.ones_like(angles), angles)
def apply_rope(x: Tensor, freqs: Tensor) -> Tensor:
x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
freqs = freqs.unsqueeze(0).unsqueeze(2)
x_rotated = x_complex * freqs
return torch.view_as_real(x_rotated).reshape_as(x).to(x.dtype)
class GroupedQuerySelfAttention(nn.Module):
def __init__(
self,
d_model: int,
n_heads: int,
n_kv_heads: int,
dropout: float = 0.0,
) -> None:
super().__init__()
if d_model % n_heads != 0:
raise ValueError(f'd_model={d_model} must be divisible by n_heads={n_heads}.')
if n_heads % n_kv_heads != 0:
raise ValueError(f'n_heads={n_heads} must be divisible by n_kv_heads={n_kv_heads}.')
self.d_model = d_model
self.n_heads = n_heads
self.n_kv_heads = n_kv_heads
self.n_kv_groups = n_heads // n_kv_heads
self.d_head = d_model // n_heads
self.attn_dropout = nn.Dropout(dropout)
self.out_dropout = nn.Dropout(dropout)
self.w_q = nn.Linear(d_model, n_heads * self.d_head, bias=False)
self.w_k = nn.Linear(d_model, n_kv_heads * self.d_head, bias=False)
self.w_v = nn.Linear(d_model, n_kv_heads * self.d_head, bias=False)
self.w_o = nn.Linear(n_heads * self.d_head, d_model, bias=False)
def forward(
self,
x: Tensor,
rope_freqs: Tensor,
mask: Optional[Tensor] = None,
) -> Tensor:
batch_size, seq_len, _ = x.shape
q = self.w_q(x).view(batch_size, seq_len, self.n_heads, self.d_head)
k = self.w_k(x).view(batch_size, seq_len, self.n_kv_heads, self.d_head)
v = self.w_v(x).view(batch_size, seq_len, self.n_kv_heads, self.d_head)
q = apply_rope(q, rope_freqs)
k = apply_rope(k, rope_freqs)
if self.n_kv_heads != self.n_heads:
k = k.unsqueeze(3).expand(
batch_size, seq_len, self.n_kv_heads, self.n_kv_groups, self.d_head
)
k = k.reshape(batch_size, seq_len, self.n_heads, self.d_head)
v = v.unsqueeze(3).expand(
batch_size, seq_len, self.n_kv_heads, self.n_kv_groups, self.d_head
)
v = v.reshape(batch_size, seq_len, self.n_heads, self.d_head)
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
scale = 1.0 / math.sqrt(self.d_head)
attn_weights = torch.matmul(q, k.transpose(-2, -1)) * scale
if mask is not None:
attn_weights = attn_weights + mask
attn_weights = F.softmax(attn_weights, dim=-1)
attn_weights = self.attn_dropout(attn_weights)
out = torch.matmul(attn_weights, v)
out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
return self.out_dropout(self.w_o(out))
class SwiGLUFFN(nn.Module):
def __init__(self, d_model: int, dropout: float = 0.0, mult: float = 2.667) -> None:
super().__init__()
raw = int(mult * d_model)
d_ff = ((raw + 7) // 8) * 8
self.w_gate = nn.Linear(d_model, d_ff, bias=False)
self.w_up = nn.Linear(d_model, d_ff, bias=False)
self.w_down = nn.Linear(d_ff, d_model, bias=False)
self.dropout = nn.Dropout(dropout)
def forward(self, x: Tensor) -> Tensor:
return self.dropout(self.w_down(F.silu(self.w_gate(x)) * self.w_up(x)))
class AttnResOperator(nn.Module):
def __init__(self, d_model: int, eps: float = 1e-6) -> None:
super().__init__()
self.pseudo_query = nn.Parameter(torch.zeros(d_model))
self.key_norm = RMSNormNoWeight(eps=eps)
def forward(self, sources: Tensor) -> Tensor:
keys = self.key_norm(sources)
logits = torch.einsum('d,nbtd->nbt', self.pseudo_query, keys)
weights = F.softmax(logits, dim=0)
return torch.einsum('nbt,nbtd->btd', weights, sources)
class AttnResSubLayer(nn.Module):
def __init__(
self,
d_model: int,
n_heads: int,
n_kv_heads: int,
dropout: float,
ffn_mult: float,
eps: float,
is_attention: bool,
) -> None:
super().__init__()
self.norm = RMSNorm(d_model, eps=eps)
self.attn_res = AttnResOperator(d_model, eps=eps)
self.is_attention = is_attention
if is_attention:
self.fn = GroupedQuerySelfAttention(
d_model=d_model,
n_heads=n_heads,
n_kv_heads=n_kv_heads,
dropout=dropout,
)
else:
self.fn = SwiGLUFFN(d_model=d_model, dropout=dropout, mult=ffn_mult)
def forward(self, sources: Tensor, rope_freqs: Tensor, mask: Optional[Tensor] = None) -> Tensor:
h = self.attn_res(sources)
normed = self.norm(h)
if self.is_attention:
return self.fn(normed, rope_freqs, mask)
return self.fn(normed)
class AttnResTransformerBackbone(nn.Module):
def __init__(
self,
d_model: int,
n_blocks: int,
n_heads: int,
n_kv_heads: int,
max_seq_len: int,
dropout: float = 0.0,
ffn_mult: float = 2.667,
eps: float = 1e-6,
rope_theta: float = 10000.0,
causal_attn: bool = False,
) -> None:
super().__init__()
self.causal_attn = causal_attn
self.layers = nn.ModuleList()
for _ in range(n_blocks):
self.layers.append(
AttnResSubLayer(
d_model=d_model,
n_heads=n_heads,
n_kv_heads=n_kv_heads,
dropout=dropout,
ffn_mult=ffn_mult,
eps=eps,
is_attention=True,
)
)
self.layers.append(
AttnResSubLayer(
d_model=d_model,
n_heads=n_heads,
n_kv_heads=n_kv_heads,
dropout=dropout,
ffn_mult=ffn_mult,
eps=eps,
is_attention=False,
)
)
rope_freqs = precompute_rope_freqs(
dim=d_model // n_heads,
max_seq_len=max_seq_len,
theta=rope_theta,
)
self.register_buffer('rope_freqs', rope_freqs, persistent=False)
@staticmethod
def _build_causal_mask(seq_len: int, device: torch.device) -> Tensor:
mask = torch.full((seq_len, seq_len), float('-inf'), device=device)
mask = torch.triu(mask, diagonal=1)
return mask.unsqueeze(0).unsqueeze(0)
def forward(self, x: Tensor) -> Tensor:
seq_len = x.shape[1]
rope_freqs = self.rope_freqs[:seq_len]
mask = None
if self.causal_attn:
mask = self._build_causal_mask(seq_len, x.device)
layer_outputs = [x]
for layer in self.layers:
sources = torch.stack(layer_outputs, dim=0)
output = layer(sources, rope_freqs, mask)
layer_outputs.append(output)
return torch.stack(layer_outputs, dim=0).sum(dim=0)

View File

@@ -5,15 +5,6 @@ import torch
import torch.nn as nn import torch.nn as nn
from diffusion_policy.model.common.module_attr_mixin import ModuleAttrMixin from diffusion_policy.model.common.module_attr_mixin import ModuleAttrMixin
from diffusion_policy.model.diffusion.attnres_transformer_components import (
AttnResOperator,
AttnResSubLayer,
AttnResTransformerBackbone,
GroupedQuerySelfAttention,
RMSNorm,
RMSNormNoWeight,
SwiGLUFFN,
)
from diffusion_policy.model.diffusion.positional_embedding import SinusoidalPosEmb from diffusion_policy.model.diffusion.positional_embedding import SinusoidalPosEmb
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -36,20 +27,14 @@ class IMFTransformerForDiffusion(ModuleAttrMixin):
time_as_cond: bool = True, time_as_cond: bool = True,
obs_as_cond: bool = False, obs_as_cond: bool = False,
n_cond_layers: int = 0, n_cond_layers: int = 0,
backbone_type: str = 'vanilla',
n_kv_head: int = 1,
attn_res_ffn_mult: float = 2.667,
attn_res_eps: float = 1e-6,
attn_res_rope_theta: float = 10000.0,
) -> None: ) -> None:
super().__init__() super().__init__()
assert n_head == 1, 'IMFTransformerForDiffusion currently supports single-head attention only.' assert n_head == 1, 'IMFTransformerForDiffusion currently supports single-head attention only.'
if n_obs_steps is None: if n_obs_steps is None:
n_obs_steps = horizon n_obs_steps = horizon
self.backbone_type = backbone_type
T = horizon T = horizon
T_cond = 2 T_cond = 2
if not time_as_cond: if not time_as_cond:
@@ -61,77 +46,21 @@ class IMFTransformerForDiffusion(ModuleAttrMixin):
T_cond += n_obs_steps T_cond += n_obs_steps
self.input_emb = nn.Linear(input_dim, n_emb) self.input_emb = nn.Linear(input_dim, n_emb)
self.pos_emb = nn.Parameter(torch.zeros(1, T, n_emb))
self.drop = nn.Dropout(p_drop_emb) self.drop = nn.Dropout(p_drop_emb)
self.time_emb = SinusoidalPosEmb(n_emb) self.time_emb = SinusoidalPosEmb(n_emb)
self.cond_obs_emb = nn.Linear(cond_dim, n_emb) if obs_as_cond else None self.cond_obs_emb = None
self.time_token_proj = None if obs_as_cond:
self.cond_obs_emb = nn.Linear(cond_dim, n_emb)
self.cond_pos_emb = None self.cond_pos_emb = None
self.pos_emb = None
self.encoder = None self.encoder = None
self.decoder = None self.decoder = None
self.attnres_backbone = None
encoder_only = False encoder_only = False
if T_cond > 0:
if backbone_type == 'attnres_full': self.cond_pos_emb = nn.Parameter(torch.zeros(1, T_cond, n_emb))
if not time_as_cond: if n_cond_layers > 0:
raise ValueError('attnres_full backbone requires time_as_cond=True.')
if n_cond_layers != 0:
raise ValueError('attnres_full backbone does not support n_cond_layers > 0.')
self.time_token_proj = nn.Linear(n_emb, n_emb)
self.attnres_backbone = AttnResTransformerBackbone(
d_model=n_emb,
n_blocks=n_layer,
n_heads=n_head,
n_kv_heads=n_kv_head,
max_seq_len=T + T_cond,
dropout=p_drop_attn,
ffn_mult=attn_res_ffn_mult,
eps=attn_res_eps,
rope_theta=attn_res_rope_theta,
causal_attn=causal_attn,
)
self.ln_f = RMSNorm(n_emb, eps=attn_res_eps)
else:
self.pos_emb = nn.Parameter(torch.zeros(1, T, n_emb))
if T_cond > 0:
self.cond_pos_emb = nn.Parameter(torch.zeros(1, T_cond, n_emb))
if n_cond_layers > 0:
encoder_layer = nn.TransformerEncoderLayer(
d_model=n_emb,
nhead=n_head,
dim_feedforward=4 * n_emb,
dropout=p_drop_attn,
activation='gelu',
batch_first=True,
norm_first=True,
)
self.encoder = nn.TransformerEncoder(
encoder_layer=encoder_layer,
num_layers=n_cond_layers,
)
else:
self.encoder = nn.Sequential(
nn.Linear(n_emb, 4 * n_emb),
nn.Mish(),
nn.Linear(4 * n_emb, n_emb),
)
decoder_layer = nn.TransformerDecoderLayer(
d_model=n_emb,
nhead=n_head,
dim_feedforward=4 * n_emb,
dropout=p_drop_attn,
activation='gelu',
batch_first=True,
norm_first=True,
)
self.decoder = nn.TransformerDecoder(
decoder_layer=decoder_layer,
num_layers=n_layer,
)
else:
encoder_only = True
encoder_layer = nn.TransformerEncoderLayer( encoder_layer = nn.TransformerEncoderLayer(
d_model=n_emb, d_model=n_emb,
nhead=n_head, nhead=n_head,
@@ -143,12 +72,45 @@ class IMFTransformerForDiffusion(ModuleAttrMixin):
) )
self.encoder = nn.TransformerEncoder( self.encoder = nn.TransformerEncoder(
encoder_layer=encoder_layer, encoder_layer=encoder_layer,
num_layers=n_layer, num_layers=n_cond_layers,
)
else:
self.encoder = nn.Sequential(
nn.Linear(n_emb, 4 * n_emb),
nn.Mish(),
nn.Linear(4 * n_emb, n_emb),
) )
self.ln_f = nn.LayerNorm(n_emb) decoder_layer = nn.TransformerDecoderLayer(
d_model=n_emb,
nhead=n_head,
dim_feedforward=4 * n_emb,
dropout=p_drop_attn,
activation='gelu',
batch_first=True,
norm_first=True,
)
self.decoder = nn.TransformerDecoder(
decoder_layer=decoder_layer,
num_layers=n_layer,
)
else:
encoder_only = True
encoder_layer = nn.TransformerEncoderLayer(
d_model=n_emb,
nhead=n_head,
dim_feedforward=4 * n_emb,
dropout=p_drop_attn,
activation='gelu',
batch_first=True,
norm_first=True,
)
self.encoder = nn.TransformerEncoder(
encoder_layer=encoder_layer,
num_layers=n_layer,
)
if causal_attn and backbone_type != 'attnres_full': if causal_attn:
sz = T sz = T
mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1) mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0)) mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
@@ -170,6 +132,7 @@ class IMFTransformerForDiffusion(ModuleAttrMixin):
self.mask = None self.mask = None
self.memory_mask = None self.memory_mask = None
self.ln_f = nn.LayerNorm(n_emb)
self.head = nn.Linear(n_emb, output_dim) self.head = nn.Linear(n_emb, output_dim)
self.T = T self.T = T
@@ -196,11 +159,6 @@ class IMFTransformerForDiffusion(ModuleAttrMixin):
nn.ModuleList, nn.ModuleList,
nn.Mish, nn.Mish,
nn.Sequential, nn.Sequential,
AttnResTransformerBackbone,
AttnResSubLayer,
GroupedQuerySelfAttention,
SwiGLUFFN,
RMSNormNoWeight,
) )
if isinstance(module, (nn.Linear, nn.Embedding)): if isinstance(module, (nn.Linear, nn.Embedding)):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
@@ -218,16 +176,12 @@ class IMFTransformerForDiffusion(ModuleAttrMixin):
bias = getattr(module, name) bias = getattr(module, name)
if bias is not None: if bias is not None:
torch.nn.init.zeros_(bias) torch.nn.init.zeros_(bias)
elif isinstance(module, (nn.LayerNorm, RMSNorm)): elif isinstance(module, nn.LayerNorm):
if getattr(module, 'bias', None) is not None: torch.nn.init.zeros_(module.bias)
torch.nn.init.zeros_(module.bias)
torch.nn.init.ones_(module.weight) torch.nn.init.ones_(module.weight)
elif isinstance(module, AttnResOperator):
torch.nn.init.zeros_(module.pseudo_query)
elif isinstance(module, IMFTransformerForDiffusion): elif isinstance(module, IMFTransformerForDiffusion):
if module.pos_emb is not None: torch.nn.init.normal_(module.pos_emb, mean=0.0, std=0.02)
torch.nn.init.normal_(module.pos_emb, mean=0.0, std=0.02) if module.cond_obs_emb is not None:
if module.cond_pos_emb is not None:
torch.nn.init.normal_(module.cond_pos_emb, mean=0.0, std=0.02) torch.nn.init.normal_(module.cond_pos_emb, mean=0.0, std=0.02)
elif isinstance(module, ignore_types): elif isinstance(module, ignore_types):
pass pass
@@ -238,24 +192,21 @@ class IMFTransformerForDiffusion(ModuleAttrMixin):
decay = set() decay = set()
no_decay = set() no_decay = set()
whitelist_weight_modules = (torch.nn.Linear, torch.nn.MultiheadAttention) whitelist_weight_modules = (torch.nn.Linear, torch.nn.MultiheadAttention)
blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding, RMSNorm) blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
for mn, m in self.named_modules(): for mn, m in self.named_modules():
for pn, _ in m.named_parameters(recurse=False): for pn, _ in m.named_parameters():
fpn = '%s.%s' % (mn, pn) if mn else pn fpn = '%s.%s' % (mn, pn) if mn else pn
if pn.endswith('bias'): if pn.endswith('bias'):
no_decay.add(fpn) no_decay.add(fpn)
elif pn.startswith('bias'): elif pn.startswith('bias'):
no_decay.add(fpn) no_decay.add(fpn)
elif pn == 'pseudo_query':
no_decay.add(fpn)
elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules): elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
decay.add(fpn) decay.add(fpn)
elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules): elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
no_decay.add(fpn) no_decay.add(fpn)
if self.pos_emb is not None: no_decay.add('pos_emb')
no_decay.add('pos_emb')
no_decay.add('_dummy_variable') no_decay.add('_dummy_variable')
if self.cond_pos_emb is not None: if self.cond_pos_emb is not None:
no_decay.add('cond_pos_emb') no_decay.add('cond_pos_emb')
@@ -299,38 +250,18 @@ class IMFTransformerForDiffusion(ModuleAttrMixin):
value = value.to(device=sample.device, dtype=sample.dtype) value = value.to(device=sample.device, dtype=sample.dtype)
return value.expand(sample.shape[0]) return value.expand(sample.shape[0])
def _forward_attnres_full( def forward(
self, self,
sample: torch.Tensor, sample: torch.Tensor,
r: torch.Tensor, r: Union[torch.Tensor, float, int],
t: torch.Tensor, t: Union[torch.Tensor, float, int],
cond: Optional[torch.Tensor] = None, cond: Optional[torch.Tensor] = None,
) -> torch.Tensor: ):
sample_tokens = self.input_emb(sample) r = self._prepare_time_input(r, sample)
token_parts = [ t = self._prepare_time_input(t, sample)
self.time_token_proj(self.time_emb(r)).unsqueeze(1),
self.time_token_proj(self.time_emb(t)).unsqueeze(1),
]
if self.obs_as_cond:
if cond is None:
raise ValueError('cond is required when obs_as_cond=True for attnres_full backbone.')
token_parts.append(self.cond_obs_emb(cond))
token_parts.append(sample_tokens)
x = torch.cat(token_parts, dim=1)
x = self.drop(x)
x = self.attnres_backbone(x)
x = x[:, -sample_tokens.shape[1] :, :]
return x
def _forward_vanilla(
self,
sample: torch.Tensor,
r: torch.Tensor,
t: torch.Tensor,
cond: Optional[torch.Tensor] = None,
) -> torch.Tensor:
r_emb = self.time_emb(r).unsqueeze(1) r_emb = self.time_emb(r).unsqueeze(1)
t_emb = self.time_emb(t).unsqueeze(1) t_emb = self.time_emb(t).unsqueeze(1)
input_emb = self.input_emb(sample) input_emb = self.input_emb(sample)
if self.encoder_only: if self.encoder_only:
@@ -361,22 +292,6 @@ class IMFTransformerForDiffusion(ModuleAttrMixin):
tgt_mask=self.mask, tgt_mask=self.mask,
memory_mask=self.memory_mask, memory_mask=self.memory_mask,
) )
return x
def forward(
self,
sample: torch.Tensor,
r: Union[torch.Tensor, float, int],
t: Union[torch.Tensor, float, int],
cond: Optional[torch.Tensor] = None,
):
r = self._prepare_time_input(r, sample)
t = self._prepare_time_input(t, sample)
if self.backbone_type == 'attnres_full':
x = self._forward_attnres_full(sample, r, t, cond=cond)
else:
x = self._forward_vanilla(sample, r, t, cond=cond)
x = self.ln_f(x) x = self.ln_f(x)
x = self.head(x) x = self.head(x)

View File

@@ -39,11 +39,6 @@ class IMFTransformerHybridImagePolicy(DiffusionTransformerHybridImagePolicy):
time_as_cond=True, time_as_cond=True,
obs_as_cond=True, obs_as_cond=True,
pred_action_steps_only=False, pred_action_steps_only=False,
backbone_type='vanilla',
n_kv_head=1,
attn_res_ffn_mult=2.667,
attn_res_eps=1e-6,
attn_res_rope_theta=10000.0,
**kwargs, **kwargs,
): ):
if num_inference_steps is None: if num_inference_steps is None:
@@ -95,11 +90,6 @@ class IMFTransformerHybridImagePolicy(DiffusionTransformerHybridImagePolicy):
time_as_cond=time_as_cond, time_as_cond=time_as_cond,
obs_as_cond=obs_as_cond, obs_as_cond=obs_as_cond,
n_cond_layers=n_cond_layers, n_cond_layers=n_cond_layers,
backbone_type=backbone_type,
n_kv_head=n_kv_head,
attn_res_ffn_mult=attn_res_ffn_mult,
attn_res_eps=attn_res_eps,
attn_res_rope_theta=attn_res_rope_theta,
) )
self.num_inference_steps = 1 self.num_inference_steps = 1

View File

@@ -0,0 +1,53 @@
# PushT DiT No-Causal Compare Implementation Plan
> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking.
**Goal:** Add a PushT image DiT no-causal config, rerun the two prior DiT baselines for 350 epochs, and compare max `test_mean_score` plus batch-1 inference latency.
**Architecture:** Keep the existing causal DiT baselines unchanged and add a separate no-causal config that only flips `policy.causal_attn=false` while preserving the SwanLab naming safeguards. Launch the default DiT (`256x8`) locally and the `256x18` DiT on 5880 GPU0, then parse `logs.json.txt` and benchmark both checkpoints on the same hardware.
**Tech Stack:** Hydra, Diffusion Policy transformer image workspace, SwanLab, uv Python env, local 5090 + trusted remote 5880.
---
### Task 1: Add no-causal DiT config and config regression test
**Files:**
- Create: `image_pusht_diffusion_policy_dit_nocausal.yaml`
- Modify: `tests/test_pusht_swanlab_config.py`
- [ ] Write a failing test asserting the new no-causal DiT config uses SwanLab-safe naming and `policy.causal_attn == False`.
- [ ] Run the targeted pytest command and verify it fails because the config does not exist yet.
- [ ] Add the minimal new config by composing from the existing PushT DiT config and overriding only `policy.causal_attn=false`.
- [ ] Re-run the targeted pytest command and verify it passes.
### Task 2: Smoke-verify the new config
**Files:**
- Read: `image_pusht_diffusion_policy_dit_nocausal.yaml`
- [ ] Run `train.py --help` against the new config.
- [ ] Verify Hydra resolves the config without errors.
### Task 3: Launch the two 350-epoch no-causal reruns
**Files:**
- Write runtime scripts/logs under `data/run_logs/`
- Write outputs under `data/outputs/`
- [ ] Launch local run: `dit_nocausal_img_pusht_default_seed42_local` with 350 epochs.
- [ ] Launch remote run: `dit_nocausal_img_pusht_emb256_layer18_seed42_5880gpu0` with 350 epochs and `policy.n_layer=18`.
- [ ] Use explicit SwanLab overrides: unique `logging.name`, `logging.resume=false`, `logging.id=null`, shared group `dit_pusht_nocausal_compare`.
- [ ] Record pid files and launcher scripts.
### Task 4: Monitor and summarize
**Files:**
- Read: per-run `logs.json.txt`
- Read: checkpoints directories
- [ ] Monitor until both runs reach epoch 349 completion.
- [ ] Extract `max(test_mean_score)` and final logged `test_mean_score`.
- [ ] Identify the best checkpoint for each run.
- [ ] Benchmark batch-1 `policy.predict_action(obs)` latency on the same hardware.
- [ ] Report the final comparison table and short conclusion.

View File

@@ -1,60 +0,0 @@
# PushT iMF Full-Attention Implementation Plan
> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking.
**Goal:** Add a separate full-attention PushT image iMF config, commit/push it on a new branch, and launch the 9-run 350-epoch architecture sweep across 3 GPUs.
**Architecture:** Keep the existing causal iMF path untouched and add a standalone full-attention config that only flips `policy.causal_attn=false` while retaining one-step iMF inference and SwanLab-safe naming. Reuse the previous 9-run architecture matrix and balanced 3-queue scheduling across local 5090 plus 5880 GPU0/GPU1.
**Tech Stack:** Hydra, Diffusion Policy iMF image workspace, SwanLab, uv env, local shell + trusted remote 5880 over SSH.
---
### Task 1: Add full-attention iMF config with TDD
**Files:**
- Create: `image_pusht_diffusion_policy_dit_imf_fullattn.yaml`
- Modify: `tests/test_pusht_swanlab_config.py`
- [ ] Write a failing config regression test asserting the new config uses SwanLab-safe naming and `policy.causal_attn == False`.
- [ ] Run the targeted pytest command and verify it fails because the config does not exist yet.
- [ ] Add the minimal full-attention config by composing from the existing PushT image iMF config and overriding only `exp_name` and `policy.causal_attn=false`.
- [ ] Re-run the targeted pytest and verify it passes.
### Task 2: Verify the new config
**Files:**
- Read: `image_pusht_diffusion_policy_dit_imf_fullattn.yaml`
- [ ] Run `train.py --help` for the new config.
- [ ] Run a real `training.debug=true` smoke test locally to confirm the training path is valid.
### Task 3: Commit and push the new branch
**Files:**
- Commit only the new config/test/plan files needed for the full-attention experiment chain.
- [ ] Run verification commands again before commit.
- [ ] Commit with a focused message.
- [ ] Push `feat/pusht-imf-fullattn` to origin.
### Task 4: Launch the 9-run sweep
**Files:**
- Write queue scripts and logs under `data/run_logs/` locally and on 5880.
- Write outputs under `data/outputs/` locally and on 5880.
- [ ] Use the same matrix as the prior iMF sweep: `n_emb ∈ {128,256,384}`, `n_layer ∈ {6,12,18}`, `seed=42`.
- [ ] Set `training.num_epochs=350` for all 9 runs.
- [ ] Encode `fullattn` in every `exp_name`, `logging.name`, and run directory to avoid collisions.
- [ ] Balance the 9 runs across local 5090, 5880 GPU0, and 5880 GPU1 as three serial queues.
- [ ] Sync the new config to the remote smoke repo before launching remote queues.
### Task 5: Monitor and auto-summarize
**Files:**
- Read local and remote pid files, logs, outputs, checkpoints.
- [ ] Start an xhigh monitoring agent that polls all three queues.
- [ ] On completion, parse all 9 `logs.json.txt` files and rank by max `test_mean_score`.
- [ ] Report embedding/layer trends and the best configuration.

View File

@@ -1,57 +0,0 @@
# PushT Image iMF AttnRes Implementation Plan
> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking.
**Goal:** Add an AttnRes-backed full-attention iMF backbone for the PushT image experiment path, verify it with tests/smoke runs, then launch the 9-run 350-epoch architecture sweep across the local 5090 and remote 5880 GPUs.
**Architecture:** Extend `IMFTransformerForDiffusion` with a selectable `attnres_full` backbone that keeps the current iMF training/inference API unchanged while replacing the transformer internals with RMSNorm + RoPE self-attention + SwiGLU + Full AttnRes depth-wise residual routing. Add one standalone Hydra config for the PushT image sweep and reuse queue-style launch scripts with unique SwanLab names.
**Tech Stack:** Python 3.9 via uv, PyTorch 2.8 CUDA, Hydra, SwanLab online logging, local shell + SSH to trusted 5880 host.
---
### Task 1: Add regression tests for the new AttnRes path
**Files:**
- Modify: `tests/test_imf_transformer_for_diffusion.py`
- Modify: `tests/test_pusht_swanlab_config.py`
- [ ] Add a failing model test that instantiates `IMFTransformerForDiffusion(backbone_type='attnres_full', causal_attn=False, ...)`, runs a forward pass with conditional observations, and asserts output shape plus optimizer construction.
- [ ] Run the targeted pytest selection and confirm the new test fails for the expected missing-backbone reason.
- [ ] Add a failing config regression test for `image_pusht_diffusion_policy_dit_imf_attnres_full.yaml` asserting SwanLab naming fields and `policy.causal_attn == False`.
- [ ] Re-run the targeted pytest selection and confirm the config test fails before implementation.
### Task 2: Implement the AttnRes-backed iMF backbone
**Files:**
- Create: `diffusion_policy/model/diffusion/attnres_transformer_components.py`
- Modify: `diffusion_policy/model/diffusion/imf_transformer_for_diffusion.py`
- [ ] Add focused reusable modules for `RMSNorm`, RoPE helpers, grouped-query self-attention, SwiGLU FFN, and the Full AttnRes operator.
- [ ] Extend `IMFTransformerForDiffusion` with a `backbone_type` switch that preserves the existing vanilla path and adds an `attnres_full` path using concatenated `[r, t, obs, sample]` tokens.
- [ ] Ensure the AttnRes path slices condition tokens away before the output head so the returned tensor still matches the sample/action horizon.
- [ ] Update optimizer parameter grouping to treat RMSNorm weights like LayerNorm weights (no decay) and include any new positional/conditioning parameters.
- [ ] Run the targeted tests and get them green.
### Task 3: Add the new PushT config and smoke-test path
**Files:**
- Create: `image_pusht_diffusion_policy_dit_imf_attnres_full.yaml`
- Modify: `tests/test_pusht_swanlab_config.py`
- [ ] Add a standalone PushT image config for the AttnRes iMF variant with SwanLab online logging, `policy.backbone_type=attnres_full`, and `policy.causal_attn=false`.
- [ ] Verify `uv run python train.py --config-dir=. --config-name=image_pusht_diffusion_policy_dit_imf_attnres_full.yaml --help` succeeds.
- [ ] Run a real smoke training command with `training.debug=true`, `training.device=cuda:0`, safety overrides (`dataloader.num_workers=0`, `task.env_runner.n_envs=1`, no vis), and confirm it reaches the training loop and writes a run directory.
### Task 4: Prepare launch scripts and start the 9-run sweep
**Files:**
- Create or modify: `data/run_logs/imf_attnres_local_queue.sh`
- Create or modify locally before copy: `data/run_logs/imf_attnres_remote_gpu0_queue.sh`
- Create or modify locally before copy: `data/run_logs/imf_attnres_remote_gpu1_queue.sh`
- [ ] Write queue command templates for the 9 runs using config `image_pusht_diffusion_policy_dit_imf_attnres_full.yaml`, `training.num_epochs=350`, unique `exp_name/logging.name`, and shared `logging.group=imf_pusht_attnres_arch_sweep`.
- [ ] Sync the necessary config/model files plus remote queue scripts to `droid@100.73.14.65:~/project/diffusion_policy-smoke`.
- [ ] Start the local queue under `nohup`, record PID, and verify the first run log is advancing.
- [ ] Start the two remote queues under `nohup`, record PIDs, and verify both first-run logs are advancing.
- [ ] Confirm all three GPUs have officially entered training for the new sweep.

View File

@@ -1,107 +0,0 @@
# PushT Image iMF Full-Attention Sweep Design
## Goal
在一个独立新分支上,为 PushT 图像 iMF 路线新增 **full-attention** 变体(关闭因果注意力),并按与之前相同的架构扫描网格运行 **9 组实验**,每组训练 **350 epochs**。所有实验完成后,提取每组 **`max(test_mean_score)`** 并输出完整排名和趋势总结。
## Scope
本次工作仅覆盖:
1. 在不影响现有因果版 iMF 路线的前提下,新增 full-attention 实验链路;
2.`n_emb ∈ {128, 256, 384}``n_layer ∈ {6, 12, 18}` 的 9 组组合做 350-epoch 扫描;
3. 在本机 5090 与 5880 双卡上做三路并行调度;
4. 在全部实验完成后自动汇总结果并直接向用户汇报。
不在本次范围内:
- 不替换或删除现有因果版 iMF 配置;
- 不改动已有 DiT baseline 实现;
- 不做多 seed 扩展;
- 不额外增加视频记录。
## Design Choice
采用“**新增独立配置 + 新分支**”的方式,而不是覆盖现有 iMF 默认配置。
原因:
- 现有因果版 iMF 已完成实验与结果记录,保持不动更利于对照;
- full-attention 作为新的实验链路,使用独立配置更易复现;
- 运行时只需要通过配置切换 `policy.causal_attn=false`,不需要重新设计 iMF 算法本身。
## Configuration Design
新增一个独立配置文件,例如:
- `image_pusht_diffusion_policy_dit_imf_fullattn.yaml`
其职责:
- 继承当前 PushT image iMF 配置链路;
- 保持 iMF 单步推理、SwanLab 标量记录、无视频记录;
- 显式设置:
- `policy.causal_attn=false`
- `policy.n_head=1`
- 保持其余 iMF 训练语义不变。
SwanLab 命名延续当前修复后的策略:
- `logging.name=${exp_name}`
- `logging.resume=false`
- `logging.id=null`
- `logging.group=${exp_name}` 或统一 sweep group override
## Code Change Strategy
优先最小改动:
- 若当前 `IMFTransformerForDiffusion` 已支持 `causal_attn=False` 分支,则不改核心算法,仅通过新配置关闭因果 mask
- 如需补充回归验证,则新增针对 full-attention 配置/掩码行为的最小测试;
- 不改变已有因果版实验配置和已有测试语义。
## Experiment Matrix
实验网格固定为:
- `n_emb=128, n_layer=6`
- `n_emb=128, n_layer=12`
- `n_emb=128, n_layer=18`
- `n_emb=256, n_layer=6`
- `n_emb=256, n_layer=12`
- `n_emb=256, n_layer=18`
- `n_emb=384, n_layer=6`
- `n_emb=384, n_layer=12`
- `n_emb=384, n_layer=18`
统一设置:
- `training.num_epochs=350`
- `training.resume=false`
- `seed=42`
- PushT image 数据路径不变
- 指标以 **`logs.json.txt``test_mean_score` 的最大值** 为准
## Scheduling Design
使用三路串行队列并行执行 9 个实验:
- 本机 50901 个顺序队列
- 5880 GPU01 个顺序队列
- 5880 GPU11 个顺序队列
分配原则:
- 延续按 `n_emb × n_layer` 近似平衡工作量;
- 每张卡同一时刻只跑 1 个实验;
- 队列脚本负责“前一个结束后自动启动下一个”。
## Monitoring Design
继续采用“**训练队列脚本 + 监控 agent**”双层机制:
1. **实际调度**由本地/远端队列脚本负责;
2. **监控**由一个 xhigh 子 agent 轮询:
- 读取 pid 状态
- 检查 master log
- 检查每个 run 的 `logs.json.txt`
- 判断是否卡死/失败/全部完成
3. 一旦全部完成,监控 agent 直接返回:
- 9 组实验的最终 epoch
- 每组 `max(test_mean_score)`
- 排名表
- embedding / layer 趋势总结
本次要求下agent 在收到全部完成信号后应直接向主会话回报结果,不等待用户再次提醒。
## Success Criteria
满足以下条件即视为完成:
1. full-attention iMF 配置在新分支上可运行;
2. 9 组 350-epoch 实验全部完成;
3. 不记录仿真视频,只记录标量;
4. SwanLab 运行命名不冲突;
5. 输出 9 组实验 `max(test_mean_score)` 的完整汇总与结论;
6. 全部实验结束后主会话可直接给用户最终总结。

View File

@@ -1,108 +0,0 @@
# PushT Image iMF AttnRes Design
## Goal
在现有 PushT 图像 iMF full-attention 路线之上,引入 `attn_res` 仓库中的 **Full AttnRes** 残差聚合形式,并同步使用与其匹配的 **RMSNorm + 自注意力 + SwiGLU FFN** 模块,保持 iMF 训练目标与一步推理语义不变,仅作用于本次实验链路。实现完成并验证后,启动与此前相同的 9 组 `n_emb × n_layer` 扫描350 epochs, seed=42, SwanLab online, 无视频记录)。
## Scope
本次工作仅覆盖:
1.`IMFTransformerForDiffusion` 增加一个 AttnRes-backed backbone 变体;
2. 保持 `forward(sample, r, t, cond=None)`、iMF loss、一步推理策略接口不变
3. 新增独立 PushT 图像配置用于该变体;
4. 复用本地 5090 + 远端 5880 双卡三路并行调度 9 组实验。
不在范围内:
- 不替换已有 vanilla iMF/full-attn 配置;
- 不修改 DiT baseline
- 不增加视频日志;
- 不扩大到多 seed。
## Recommended Approach
采用“**在当前 iMF 模型内增加可选 AttnRes backbone**”的方式,而不是新建独立 policy 链路。
理由:
- policy / workspace / loss / sampling 路径已经被验证,保留这些路径可最大程度缩小变动面;
- 仅在模型内部切换 backbone可以让新实验与既有 iMF 结果保持可比;
- 配置上只需显式打开 `backbone_type=attnres_full``causal_attn=false` 等开关,复现实验更直接。
## Architecture
### 1. Backbone split
`IMFTransformerForDiffusion` 保留现有 vanilla encoder/decoder 实现为默认路径,并新增 `attnres_full` 路径:
- **vanilla**:保持当前实现不变;
- **attnres_full**:使用单栈式全注意力 Transformer输入 token 序列为
`[r token, t token, obs cond tokens..., action/sample tokens...]`
模型只对末尾的 action/sample token 位置输出 `u` 预测,前置条件 token 仅参与上下文建模。
### 2. AttnRes stack
新 backbone 使用以下模块:
- `RMSNorm`
- `Rotary Position Embedding`(用于自注意力 q/k
- `GroupedQueryAttention`(本实验默认 `n_kv_head=1`,与单头配置兼容)
- `SwiGLU` FFN
- `AttnResOperator`(每个子层一个 pseudo-query执行 full depth-wise residual aggregation
每个 transformer block 由两个子层组成:
1. self-attention 子层
2. FFN 子层
每个子层的输入不再是简单 `x + f(x)`,而是从 embedding 与全部历史子层输出中通过 Full AttnRes 聚合得到 `h_l`,再执行 `RMSNorm(h_l) -> sublayer_fn(...)`
### 3. Conditioning and token flow
- `sample` 先经 `input_emb` 映射为 action tokens
- `r``t` 各自经 `SinusoidalPosEmb + linear` 映射为两个条件 token
- 图像观测编码后的 `cond` 通过 `cond_obs_emb` 映射为 obs tokens
- 拼接后的完整 token 序列进入 AttnRes stack
- 输出时切掉前置条件 token仅保留 action/sample token 段,随后经 `RMSNorm + head` 得到最终 `u`
### 4. Attention mode
本次实验链路固定为 **non-causal full attention**
- `causal_attn=false`
- 不构造 causal mask
- 所有 token 可彼此双向可见
这与用户指定的“训练过程仍然使用全注意力(不加因果注意)”一致。
## Config and Logging
新增独立配置文件,例如:
- `image_pusht_diffusion_policy_dit_imf_attnres_full.yaml`
该配置需要:
- 指向现有 `IMFTransformerHybridImagePolicy`
- 显式开启 AttnRes backbone 相关参数
- 设置 `policy.causal_attn=false`
- 保持 `logging.backend=swanlab``logging.mode=online`
- 运行时通过覆盖保证:
- `logging.name=<unique_run_name>`
- `logging.group=imf_pusht_attnres_arch_sweep`
- `exp_name=<unique_run_name>`
- 保持 `task.env_runner.n_test_vis=0``n_train_vis=0`,仅记录标量
## Experiment Matrix
固定 9 组:
- `n_emb ∈ {128, 256, 384}`
- `n_layer ∈ {6, 12, 18}`
- `seed=42`
- `training.num_epochs=350`
## Scheduling
沿用之前验证过的三队列分配:
- 本机 5090`384x18`, `256x6`, `128x6`
- 5880 GPU0`384x12`, `256x12`, `128x12`
- 5880 GPU1`384x6`, `256x18`, `128x18`
每个 run name 编码 backbone 与结构,例如:
`imf_attnres_emb256_layer12_seed42_5880gpu0`
## Verification
实现阶段至少验证:
1. 新配置的 SwanLab 命名与 `causal_attn=false` 正确;
2. 新 backbone 的 forward shape 与 `configure_optimizers()` 可用;
3. 旧 vanilla 路径测试不回归;
4. `training.debug=true` smoke run 可以完整通过。
## Success Criteria
1. 新 AttnRes iMF 变体在本分支可训练、可一步推理;
2. 不影响已有 vanilla iMF/full-attn 链路;
3. 9 组实验成功在三张卡上正式启动;
4. SwanLab run 名称唯一,无冲突;
5. 不记录视频,仅记录标量。

View File

@@ -1,35 +0,0 @@
defaults:
- diffusion_policy/config/train_diffusion_transformer_hybrid_workspace@_here_
- override /diffusion_policy/config/task@task: pusht_image
- _self_
exp_name: pusht_image_dit_imf_attnres_full
policy:
_target_: diffusion_policy.policy.imf_transformer_hybrid_image_policy.IMFTransformerHybridImagePolicy
num_inference_steps: 1
n_head: 1
n_kv_head: 1
causal_attn: false
backbone_type: attnres_full
logging:
backend: swanlab
mode: online
name: ${exp_name}
resume: false
tags: ["${name}", "${task_name}", "${exp_name}", "swanlab"]
id: null
group: ${exp_name}
dataloader:
num_workers: 0
val_dataloader:
num_workers: 0
task:
env_runner:
n_envs: 1
n_test_vis: 0
n_train_vis: 0

View File

@@ -3,12 +3,10 @@ defaults:
- override /diffusion_policy/config/task@task: pusht_image - override /diffusion_policy/config/task@task: pusht_image
- _self_ - _self_
exp_name: pusht_image_dit_imf_fullattn exp_name: pusht_image_dit_nocausal
policy: policy:
_target_: diffusion_policy.policy.imf_transformer_hybrid_image_policy.IMFTransformerHybridImagePolicy _target_: diffusion_policy.policy.diffusion_transformer_hybrid_image_policy.DiffusionTransformerHybridImagePolicy
num_inference_steps: 1
n_head: 1
causal_attn: false causal_attn: false
logging: logging:

View File

@@ -1,29 +0,0 @@
#!/usr/bin/env bash
set -euo pipefail
cd /home/droid/project/diffusion_policy/.worktrees/feat-pusht-imf-attnres
export PYTHONUNBUFFERED=1
export SWANLAB_API_KEY='PSZrBMLx1XAjDjvmhUcNz'
export LD_LIBRARY_PATH="$(printf '%s:' .venv/lib/python3.9/site-packages/nvidia/*/lib | sed 's/:$//')"
run_exp() {
local name="$1" emb="$2" layer="$3"
echo "[$(date '+%F %T')] START $name emb=$emb layer=$layer"
.venv/bin/python train.py \
--config-dir=. \
--config-name=image_pusht_diffusion_policy_dit_imf_attnres_full.yaml \
training.device=cuda:0 \
training.num_epochs=350 \
training.resume=false \
exp_name="$name" \
logging.group=imf_pusht_attnres_arch_sweep \
logging.name="$name" \
logging.resume=false \
logging.id=null \
hydra.run.dir="data/outputs/$name" \
policy.n_emb="$emb" \
policy.n_layer="$layer" \
> "data/run_logs/${name}.log" 2>&1
echo "[$(date '+%F %T')] END $name"
}
run_exp imf_attnres_emb384_layer18_seed42_local 384 18
run_exp imf_attnres_emb256_layer6_seed42_local 256 6
run_exp imf_attnres_emb128_layer6_seed42_local 128 6

View File

@@ -1,29 +0,0 @@
#!/usr/bin/env bash
set -euo pipefail
cd /home/droid/project/diffusion_policy-smoke
export PYTHONUNBUFFERED=1
export SWANLAB_API_KEY='PSZrBMLx1XAjDjvmhUcNz'
export LD_LIBRARY_PATH="$(printf '%s:' .venv/lib/python3.9/site-packages/nvidia/*/lib | sed 's/:$//')"
run_exp() {
local name="$1" emb="$2" layer="$3"
echo "[$(date '+%F %T')] START $name emb=$emb layer=$layer"
.venv/bin/python train.py \
--config-dir=. \
--config-name=image_pusht_diffusion_policy_dit_imf_attnres_full.yaml \
training.device=cuda:0 \
training.num_epochs=350 \
training.resume=false \
exp_name="$name" \
logging.group=imf_pusht_attnres_arch_sweep \
logging.name="$name" \
logging.resume=false \
logging.id=null \
hydra.run.dir="data/outputs/$name" \
policy.n_emb="$emb" \
policy.n_layer="$layer" \
> "data/run_logs/${name}.log" 2>&1
echo "[$(date '+%F %T')] END $name"
}
run_exp imf_attnres_emb384_layer12_seed42_5880gpu0 384 12
run_exp imf_attnres_emb256_layer12_seed42_5880gpu0 256 12
run_exp imf_attnres_emb128_layer12_seed42_5880gpu0 128 12

View File

@@ -1,29 +0,0 @@
#!/usr/bin/env bash
set -euo pipefail
cd /home/droid/project/diffusion_policy-smoke
export PYTHONUNBUFFERED=1
export SWANLAB_API_KEY='PSZrBMLx1XAjDjvmhUcNz'
export LD_LIBRARY_PATH="$(printf '%s:' .venv/lib/python3.9/site-packages/nvidia/*/lib | sed 's/:$//')"
run_exp() {
local name="$1" emb="$2" layer="$3"
echo "[$(date '+%F %T')] START $name emb=$emb layer=$layer"
.venv/bin/python train.py \
--config-dir=. \
--config-name=image_pusht_diffusion_policy_dit_imf_attnres_full.yaml \
training.device=cuda:1 \
training.num_epochs=350 \
training.resume=false \
exp_name="$name" \
logging.group=imf_pusht_attnres_arch_sweep \
logging.name="$name" \
logging.resume=false \
logging.id=null \
hydra.run.dir="data/outputs/$name" \
policy.n_emb="$emb" \
policy.n_layer="$layer" \
> "data/run_logs/${name}.log" 2>&1
echo "[$(date '+%F %T')] END $name"
}
run_exp imf_attnres_emb384_layer6_seed42_5880gpu1 384 6
run_exp imf_attnres_emb256_layer18_seed42_5880gpu1 256 18
run_exp imf_attnres_emb128_layer18_seed42_5880gpu1 128 18

View File

@@ -44,34 +44,3 @@ def test_imf_transformer_forward_signature_and_shape_single_head():
pred_u = model(sample, r, t, cond=cond) pred_u = model(sample, r, t, cond=cond)
assert pred_u.shape == sample.shape assert pred_u.shape == sample.shape
def test_imf_transformer_attnres_full_backbone_forward_shape_and_optimizer():
model = IMFTransformerForDiffusion(
input_dim=3,
output_dim=3,
horizon=5,
n_obs_steps=2,
cond_dim=4,
n_layer=2,
n_head=1,
n_emb=16,
p_drop_emb=0.0,
p_drop_attn=0.0,
causal_attn=False,
time_as_cond=True,
obs_as_cond=True,
n_cond_layers=0,
backbone_type='attnres_full',
)
optimizer = model.configure_optimizers()
sample = torch.randn(2, 5, 3)
r = torch.rand(2)
t = torch.rand(2)
cond = torch.randn(2, 2, 4)
pred_u = model(sample, r, t, cond=cond)
assert pred_u.shape == sample.shape
assert optimizer is not None

View File

@@ -32,8 +32,8 @@ def test_image_pusht_dit_imf_swanlab_config_uses_exp_name_and_no_resume_collisio
assert cfg.logging.group == cfg.exp_name assert cfg.logging.group == cfg.exp_name
def test_image_pusht_dit_imf_fullattn_config_uses_exp_name_and_disables_causal_attention(): def test_image_pusht_dit_nocausal_config_uses_exp_name_and_disables_causal_attention():
cfg = _load_cfg('image_pusht_diffusion_policy_dit_imf_fullattn.yaml') cfg = _load_cfg('image_pusht_diffusion_policy_dit_nocausal.yaml')
assert cfg.logging.backend == 'swanlab' assert cfg.logging.backend == 'swanlab'
assert cfg.logging.mode == 'online' assert cfg.logging.mode == 'online'
@@ -42,16 +42,3 @@ def test_image_pusht_dit_imf_fullattn_config_uses_exp_name_and_disables_causal_a
assert cfg.logging.id is None assert cfg.logging.id is None
assert cfg.logging.group == cfg.exp_name assert cfg.logging.group == cfg.exp_name
assert cfg.policy.causal_attn is False assert cfg.policy.causal_attn is False
def test_image_pusht_dit_imf_attnres_full_config_uses_exp_name_and_disables_causal_attention():
cfg = _load_cfg('image_pusht_diffusion_policy_dit_imf_attnres_full.yaml')
assert cfg.logging.backend == 'swanlab'
assert cfg.logging.mode == 'online'
assert cfg.logging.name == cfg.exp_name
assert cfg.logging.resume is False
assert cfg.logging.id is None
assert cfg.logging.group == cfg.exp_name
assert cfg.policy.causal_attn is False
assert cfg.policy.backbone_type == 'attnres_full'