Compare commits
4 Commits
feat/pusht
...
feat/pusht
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
211abbb87f | ||
|
|
185ed6596c | ||
|
|
78ab18e8f3 | ||
|
|
484d008997 |
@@ -0,0 +1,247 @@
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, dim: int, eps: float = 1e-6) -> None:
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.weight = nn.Parameter(torch.ones(dim))
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
rms = torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + self.eps)
|
||||
return (x.float() * rms).to(x.dtype) * self.weight
|
||||
|
||||
|
||||
class RMSNormNoWeight(nn.Module):
|
||||
def __init__(self, eps: float = 1e-6) -> None:
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
rms = torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + self.eps)
|
||||
return (x.float() * rms).to(x.dtype)
|
||||
|
||||
|
||||
def precompute_rope_freqs(
|
||||
dim: int,
|
||||
max_seq_len: int,
|
||||
theta: float = 10000.0,
|
||||
device: Optional[torch.device] = None,
|
||||
) -> Tensor:
|
||||
if dim % 2 != 0:
|
||||
raise ValueError(f'RoPE requires an even head dimension, got {dim}.')
|
||||
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, device=device).float() / dim))
|
||||
positions = torch.arange(max_seq_len, device=device).float()
|
||||
angles = torch.outer(positions, freqs)
|
||||
return torch.polar(torch.ones_like(angles), angles)
|
||||
|
||||
|
||||
def apply_rope(x: Tensor, freqs: Tensor) -> Tensor:
|
||||
x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
|
||||
freqs = freqs.unsqueeze(0).unsqueeze(2)
|
||||
x_rotated = x_complex * freqs
|
||||
return torch.view_as_real(x_rotated).reshape_as(x).to(x.dtype)
|
||||
|
||||
|
||||
class GroupedQuerySelfAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
d_model: int,
|
||||
n_heads: int,
|
||||
n_kv_heads: int,
|
||||
dropout: float = 0.0,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
if d_model % n_heads != 0:
|
||||
raise ValueError(f'd_model={d_model} must be divisible by n_heads={n_heads}.')
|
||||
if n_heads % n_kv_heads != 0:
|
||||
raise ValueError(f'n_heads={n_heads} must be divisible by n_kv_heads={n_kv_heads}.')
|
||||
|
||||
self.d_model = d_model
|
||||
self.n_heads = n_heads
|
||||
self.n_kv_heads = n_kv_heads
|
||||
self.n_kv_groups = n_heads // n_kv_heads
|
||||
self.d_head = d_model // n_heads
|
||||
self.attn_dropout = nn.Dropout(dropout)
|
||||
self.out_dropout = nn.Dropout(dropout)
|
||||
|
||||
self.w_q = nn.Linear(d_model, n_heads * self.d_head, bias=False)
|
||||
self.w_k = nn.Linear(d_model, n_kv_heads * self.d_head, bias=False)
|
||||
self.w_v = nn.Linear(d_model, n_kv_heads * self.d_head, bias=False)
|
||||
self.w_o = nn.Linear(n_heads * self.d_head, d_model, bias=False)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: Tensor,
|
||||
rope_freqs: Tensor,
|
||||
mask: Optional[Tensor] = None,
|
||||
) -> Tensor:
|
||||
batch_size, seq_len, _ = x.shape
|
||||
|
||||
q = self.w_q(x).view(batch_size, seq_len, self.n_heads, self.d_head)
|
||||
k = self.w_k(x).view(batch_size, seq_len, self.n_kv_heads, self.d_head)
|
||||
v = self.w_v(x).view(batch_size, seq_len, self.n_kv_heads, self.d_head)
|
||||
|
||||
q = apply_rope(q, rope_freqs)
|
||||
k = apply_rope(k, rope_freqs)
|
||||
|
||||
if self.n_kv_heads != self.n_heads:
|
||||
k = k.unsqueeze(3).expand(
|
||||
batch_size, seq_len, self.n_kv_heads, self.n_kv_groups, self.d_head
|
||||
)
|
||||
k = k.reshape(batch_size, seq_len, self.n_heads, self.d_head)
|
||||
v = v.unsqueeze(3).expand(
|
||||
batch_size, seq_len, self.n_kv_heads, self.n_kv_groups, self.d_head
|
||||
)
|
||||
v = v.reshape(batch_size, seq_len, self.n_heads, self.d_head)
|
||||
|
||||
q = q.transpose(1, 2)
|
||||
k = k.transpose(1, 2)
|
||||
v = v.transpose(1, 2)
|
||||
|
||||
scale = 1.0 / math.sqrt(self.d_head)
|
||||
attn_weights = torch.matmul(q, k.transpose(-2, -1)) * scale
|
||||
if mask is not None:
|
||||
attn_weights = attn_weights + mask
|
||||
attn_weights = F.softmax(attn_weights, dim=-1)
|
||||
attn_weights = self.attn_dropout(attn_weights)
|
||||
|
||||
out = torch.matmul(attn_weights, v)
|
||||
out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
|
||||
return self.out_dropout(self.w_o(out))
|
||||
|
||||
|
||||
class SwiGLUFFN(nn.Module):
|
||||
def __init__(self, d_model: int, dropout: float = 0.0, mult: float = 2.667) -> None:
|
||||
super().__init__()
|
||||
raw = int(mult * d_model)
|
||||
d_ff = ((raw + 7) // 8) * 8
|
||||
self.w_gate = nn.Linear(d_model, d_ff, bias=False)
|
||||
self.w_up = nn.Linear(d_model, d_ff, bias=False)
|
||||
self.w_down = nn.Linear(d_ff, d_model, bias=False)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return self.dropout(self.w_down(F.silu(self.w_gate(x)) * self.w_up(x)))
|
||||
|
||||
|
||||
class AttnResOperator(nn.Module):
|
||||
def __init__(self, d_model: int, eps: float = 1e-6) -> None:
|
||||
super().__init__()
|
||||
self.pseudo_query = nn.Parameter(torch.zeros(d_model))
|
||||
self.key_norm = RMSNormNoWeight(eps=eps)
|
||||
|
||||
def forward(self, sources: Tensor) -> Tensor:
|
||||
keys = self.key_norm(sources)
|
||||
logits = torch.einsum('d,nbtd->nbt', self.pseudo_query, keys)
|
||||
weights = F.softmax(logits, dim=0)
|
||||
return torch.einsum('nbt,nbtd->btd', weights, sources)
|
||||
|
||||
|
||||
class AttnResSubLayer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
d_model: int,
|
||||
n_heads: int,
|
||||
n_kv_heads: int,
|
||||
dropout: float,
|
||||
ffn_mult: float,
|
||||
eps: float,
|
||||
is_attention: bool,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.norm = RMSNorm(d_model, eps=eps)
|
||||
self.attn_res = AttnResOperator(d_model, eps=eps)
|
||||
self.is_attention = is_attention
|
||||
if is_attention:
|
||||
self.fn = GroupedQuerySelfAttention(
|
||||
d_model=d_model,
|
||||
n_heads=n_heads,
|
||||
n_kv_heads=n_kv_heads,
|
||||
dropout=dropout,
|
||||
)
|
||||
else:
|
||||
self.fn = SwiGLUFFN(d_model=d_model, dropout=dropout, mult=ffn_mult)
|
||||
|
||||
def forward(self, sources: Tensor, rope_freqs: Tensor, mask: Optional[Tensor] = None) -> Tensor:
|
||||
h = self.attn_res(sources)
|
||||
normed = self.norm(h)
|
||||
if self.is_attention:
|
||||
return self.fn(normed, rope_freqs, mask)
|
||||
return self.fn(normed)
|
||||
|
||||
|
||||
class AttnResTransformerBackbone(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
d_model: int,
|
||||
n_blocks: int,
|
||||
n_heads: int,
|
||||
n_kv_heads: int,
|
||||
max_seq_len: int,
|
||||
dropout: float = 0.0,
|
||||
ffn_mult: float = 2.667,
|
||||
eps: float = 1e-6,
|
||||
rope_theta: float = 10000.0,
|
||||
causal_attn: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.causal_attn = causal_attn
|
||||
self.layers = nn.ModuleList()
|
||||
for _ in range(n_blocks):
|
||||
self.layers.append(
|
||||
AttnResSubLayer(
|
||||
d_model=d_model,
|
||||
n_heads=n_heads,
|
||||
n_kv_heads=n_kv_heads,
|
||||
dropout=dropout,
|
||||
ffn_mult=ffn_mult,
|
||||
eps=eps,
|
||||
is_attention=True,
|
||||
)
|
||||
)
|
||||
self.layers.append(
|
||||
AttnResSubLayer(
|
||||
d_model=d_model,
|
||||
n_heads=n_heads,
|
||||
n_kv_heads=n_kv_heads,
|
||||
dropout=dropout,
|
||||
ffn_mult=ffn_mult,
|
||||
eps=eps,
|
||||
is_attention=False,
|
||||
)
|
||||
)
|
||||
|
||||
rope_freqs = precompute_rope_freqs(
|
||||
dim=d_model // n_heads,
|
||||
max_seq_len=max_seq_len,
|
||||
theta=rope_theta,
|
||||
)
|
||||
self.register_buffer('rope_freqs', rope_freqs, persistent=False)
|
||||
|
||||
@staticmethod
|
||||
def _build_causal_mask(seq_len: int, device: torch.device) -> Tensor:
|
||||
mask = torch.full((seq_len, seq_len), float('-inf'), device=device)
|
||||
mask = torch.triu(mask, diagonal=1)
|
||||
return mask.unsqueeze(0).unsqueeze(0)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
seq_len = x.shape[1]
|
||||
rope_freqs = self.rope_freqs[:seq_len]
|
||||
mask = None
|
||||
if self.causal_attn:
|
||||
mask = self._build_causal_mask(seq_len, x.device)
|
||||
|
||||
layer_outputs = [x]
|
||||
for layer in self.layers:
|
||||
sources = torch.stack(layer_outputs, dim=0)
|
||||
output = layer(sources, rope_freqs, mask)
|
||||
layer_outputs.append(output)
|
||||
|
||||
return torch.stack(layer_outputs, dim=0).sum(dim=0)
|
||||
@@ -5,6 +5,15 @@ import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from diffusion_policy.model.common.module_attr_mixin import ModuleAttrMixin
|
||||
from diffusion_policy.model.diffusion.attnres_transformer_components import (
|
||||
AttnResOperator,
|
||||
AttnResSubLayer,
|
||||
AttnResTransformerBackbone,
|
||||
GroupedQuerySelfAttention,
|
||||
RMSNorm,
|
||||
RMSNormNoWeight,
|
||||
SwiGLUFFN,
|
||||
)
|
||||
from diffusion_policy.model.diffusion.positional_embedding import SinusoidalPosEmb
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -27,14 +36,20 @@ class IMFTransformerForDiffusion(ModuleAttrMixin):
|
||||
time_as_cond: bool = True,
|
||||
obs_as_cond: bool = False,
|
||||
n_cond_layers: int = 0,
|
||||
backbone_type: str = 'vanilla',
|
||||
n_kv_head: int = 1,
|
||||
attn_res_ffn_mult: float = 2.667,
|
||||
attn_res_eps: float = 1e-6,
|
||||
attn_res_rope_theta: float = 10000.0,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
assert n_head == 1, 'IMFTransformerForDiffusion currently supports single-head attention only.'
|
||||
|
||||
if n_obs_steps is None:
|
||||
n_obs_steps = horizon
|
||||
|
||||
self.backbone_type = backbone_type
|
||||
|
||||
T = horizon
|
||||
T_cond = 2
|
||||
if not time_as_cond:
|
||||
@@ -46,21 +61,77 @@ class IMFTransformerForDiffusion(ModuleAttrMixin):
|
||||
T_cond += n_obs_steps
|
||||
|
||||
self.input_emb = nn.Linear(input_dim, n_emb)
|
||||
self.pos_emb = nn.Parameter(torch.zeros(1, T, n_emb))
|
||||
self.drop = nn.Dropout(p_drop_emb)
|
||||
|
||||
self.time_emb = SinusoidalPosEmb(n_emb)
|
||||
self.cond_obs_emb = None
|
||||
if obs_as_cond:
|
||||
self.cond_obs_emb = nn.Linear(cond_dim, n_emb)
|
||||
|
||||
self.cond_obs_emb = nn.Linear(cond_dim, n_emb) if obs_as_cond else None
|
||||
self.time_token_proj = None
|
||||
self.cond_pos_emb = None
|
||||
self.pos_emb = None
|
||||
self.encoder = None
|
||||
self.decoder = None
|
||||
self.attnres_backbone = None
|
||||
encoder_only = False
|
||||
if T_cond > 0:
|
||||
self.cond_pos_emb = nn.Parameter(torch.zeros(1, T_cond, n_emb))
|
||||
if n_cond_layers > 0:
|
||||
|
||||
if backbone_type == 'attnres_full':
|
||||
if not time_as_cond:
|
||||
raise ValueError('attnres_full backbone requires time_as_cond=True.')
|
||||
if n_cond_layers != 0:
|
||||
raise ValueError('attnres_full backbone does not support n_cond_layers > 0.')
|
||||
|
||||
self.time_token_proj = nn.Linear(n_emb, n_emb)
|
||||
self.attnres_backbone = AttnResTransformerBackbone(
|
||||
d_model=n_emb,
|
||||
n_blocks=n_layer,
|
||||
n_heads=n_head,
|
||||
n_kv_heads=n_kv_head,
|
||||
max_seq_len=T + T_cond,
|
||||
dropout=p_drop_attn,
|
||||
ffn_mult=attn_res_ffn_mult,
|
||||
eps=attn_res_eps,
|
||||
rope_theta=attn_res_rope_theta,
|
||||
causal_attn=causal_attn,
|
||||
)
|
||||
self.ln_f = RMSNorm(n_emb, eps=attn_res_eps)
|
||||
else:
|
||||
self.pos_emb = nn.Parameter(torch.zeros(1, T, n_emb))
|
||||
if T_cond > 0:
|
||||
self.cond_pos_emb = nn.Parameter(torch.zeros(1, T_cond, n_emb))
|
||||
if n_cond_layers > 0:
|
||||
encoder_layer = nn.TransformerEncoderLayer(
|
||||
d_model=n_emb,
|
||||
nhead=n_head,
|
||||
dim_feedforward=4 * n_emb,
|
||||
dropout=p_drop_attn,
|
||||
activation='gelu',
|
||||
batch_first=True,
|
||||
norm_first=True,
|
||||
)
|
||||
self.encoder = nn.TransformerEncoder(
|
||||
encoder_layer=encoder_layer,
|
||||
num_layers=n_cond_layers,
|
||||
)
|
||||
else:
|
||||
self.encoder = nn.Sequential(
|
||||
nn.Linear(n_emb, 4 * n_emb),
|
||||
nn.Mish(),
|
||||
nn.Linear(4 * n_emb, n_emb),
|
||||
)
|
||||
|
||||
decoder_layer = nn.TransformerDecoderLayer(
|
||||
d_model=n_emb,
|
||||
nhead=n_head,
|
||||
dim_feedforward=4 * n_emb,
|
||||
dropout=p_drop_attn,
|
||||
activation='gelu',
|
||||
batch_first=True,
|
||||
norm_first=True,
|
||||
)
|
||||
self.decoder = nn.TransformerDecoder(
|
||||
decoder_layer=decoder_layer,
|
||||
num_layers=n_layer,
|
||||
)
|
||||
else:
|
||||
encoder_only = True
|
||||
encoder_layer = nn.TransformerEncoderLayer(
|
||||
d_model=n_emb,
|
||||
nhead=n_head,
|
||||
@@ -72,45 +143,12 @@ class IMFTransformerForDiffusion(ModuleAttrMixin):
|
||||
)
|
||||
self.encoder = nn.TransformerEncoder(
|
||||
encoder_layer=encoder_layer,
|
||||
num_layers=n_cond_layers,
|
||||
)
|
||||
else:
|
||||
self.encoder = nn.Sequential(
|
||||
nn.Linear(n_emb, 4 * n_emb),
|
||||
nn.Mish(),
|
||||
nn.Linear(4 * n_emb, n_emb),
|
||||
num_layers=n_layer,
|
||||
)
|
||||
|
||||
decoder_layer = nn.TransformerDecoderLayer(
|
||||
d_model=n_emb,
|
||||
nhead=n_head,
|
||||
dim_feedforward=4 * n_emb,
|
||||
dropout=p_drop_attn,
|
||||
activation='gelu',
|
||||
batch_first=True,
|
||||
norm_first=True,
|
||||
)
|
||||
self.decoder = nn.TransformerDecoder(
|
||||
decoder_layer=decoder_layer,
|
||||
num_layers=n_layer,
|
||||
)
|
||||
else:
|
||||
encoder_only = True
|
||||
encoder_layer = nn.TransformerEncoderLayer(
|
||||
d_model=n_emb,
|
||||
nhead=n_head,
|
||||
dim_feedforward=4 * n_emb,
|
||||
dropout=p_drop_attn,
|
||||
activation='gelu',
|
||||
batch_first=True,
|
||||
norm_first=True,
|
||||
)
|
||||
self.encoder = nn.TransformerEncoder(
|
||||
encoder_layer=encoder_layer,
|
||||
num_layers=n_layer,
|
||||
)
|
||||
self.ln_f = nn.LayerNorm(n_emb)
|
||||
|
||||
if causal_attn:
|
||||
if causal_attn and backbone_type != 'attnres_full':
|
||||
sz = T
|
||||
mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
|
||||
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
|
||||
@@ -132,7 +170,6 @@ class IMFTransformerForDiffusion(ModuleAttrMixin):
|
||||
self.mask = None
|
||||
self.memory_mask = None
|
||||
|
||||
self.ln_f = nn.LayerNorm(n_emb)
|
||||
self.head = nn.Linear(n_emb, output_dim)
|
||||
|
||||
self.T = T
|
||||
@@ -159,6 +196,11 @@ class IMFTransformerForDiffusion(ModuleAttrMixin):
|
||||
nn.ModuleList,
|
||||
nn.Mish,
|
||||
nn.Sequential,
|
||||
AttnResTransformerBackbone,
|
||||
AttnResSubLayer,
|
||||
GroupedQuerySelfAttention,
|
||||
SwiGLUFFN,
|
||||
RMSNormNoWeight,
|
||||
)
|
||||
if isinstance(module, (nn.Linear, nn.Embedding)):
|
||||
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
||||
@@ -176,12 +218,16 @@ class IMFTransformerForDiffusion(ModuleAttrMixin):
|
||||
bias = getattr(module, name)
|
||||
if bias is not None:
|
||||
torch.nn.init.zeros_(bias)
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
torch.nn.init.zeros_(module.bias)
|
||||
elif isinstance(module, (nn.LayerNorm, RMSNorm)):
|
||||
if getattr(module, 'bias', None) is not None:
|
||||
torch.nn.init.zeros_(module.bias)
|
||||
torch.nn.init.ones_(module.weight)
|
||||
elif isinstance(module, AttnResOperator):
|
||||
torch.nn.init.zeros_(module.pseudo_query)
|
||||
elif isinstance(module, IMFTransformerForDiffusion):
|
||||
torch.nn.init.normal_(module.pos_emb, mean=0.0, std=0.02)
|
||||
if module.cond_obs_emb is not None:
|
||||
if module.pos_emb is not None:
|
||||
torch.nn.init.normal_(module.pos_emb, mean=0.0, std=0.02)
|
||||
if module.cond_pos_emb is not None:
|
||||
torch.nn.init.normal_(module.cond_pos_emb, mean=0.0, std=0.02)
|
||||
elif isinstance(module, ignore_types):
|
||||
pass
|
||||
@@ -192,21 +238,24 @@ class IMFTransformerForDiffusion(ModuleAttrMixin):
|
||||
decay = set()
|
||||
no_decay = set()
|
||||
whitelist_weight_modules = (torch.nn.Linear, torch.nn.MultiheadAttention)
|
||||
blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
|
||||
blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding, RMSNorm)
|
||||
for mn, m in self.named_modules():
|
||||
for pn, _ in m.named_parameters():
|
||||
for pn, _ in m.named_parameters(recurse=False):
|
||||
fpn = '%s.%s' % (mn, pn) if mn else pn
|
||||
|
||||
if pn.endswith('bias'):
|
||||
no_decay.add(fpn)
|
||||
elif pn.startswith('bias'):
|
||||
no_decay.add(fpn)
|
||||
elif pn == 'pseudo_query':
|
||||
no_decay.add(fpn)
|
||||
elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
|
||||
decay.add(fpn)
|
||||
elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
|
||||
no_decay.add(fpn)
|
||||
|
||||
no_decay.add('pos_emb')
|
||||
if self.pos_emb is not None:
|
||||
no_decay.add('pos_emb')
|
||||
no_decay.add('_dummy_variable')
|
||||
if self.cond_pos_emb is not None:
|
||||
no_decay.add('cond_pos_emb')
|
||||
@@ -250,18 +299,38 @@ class IMFTransformerForDiffusion(ModuleAttrMixin):
|
||||
value = value.to(device=sample.device, dtype=sample.dtype)
|
||||
return value.expand(sample.shape[0])
|
||||
|
||||
def forward(
|
||||
def _forward_attnres_full(
|
||||
self,
|
||||
sample: torch.Tensor,
|
||||
r: Union[torch.Tensor, float, int],
|
||||
t: Union[torch.Tensor, float, int],
|
||||
r: torch.Tensor,
|
||||
t: torch.Tensor,
|
||||
cond: Optional[torch.Tensor] = None,
|
||||
):
|
||||
r = self._prepare_time_input(r, sample)
|
||||
t = self._prepare_time_input(t, sample)
|
||||
) -> torch.Tensor:
|
||||
sample_tokens = self.input_emb(sample)
|
||||
token_parts = [
|
||||
self.time_token_proj(self.time_emb(r)).unsqueeze(1),
|
||||
self.time_token_proj(self.time_emb(t)).unsqueeze(1),
|
||||
]
|
||||
if self.obs_as_cond:
|
||||
if cond is None:
|
||||
raise ValueError('cond is required when obs_as_cond=True for attnres_full backbone.')
|
||||
token_parts.append(self.cond_obs_emb(cond))
|
||||
token_parts.append(sample_tokens)
|
||||
x = torch.cat(token_parts, dim=1)
|
||||
x = self.drop(x)
|
||||
x = self.attnres_backbone(x)
|
||||
x = x[:, -sample_tokens.shape[1] :, :]
|
||||
return x
|
||||
|
||||
def _forward_vanilla(
|
||||
self,
|
||||
sample: torch.Tensor,
|
||||
r: torch.Tensor,
|
||||
t: torch.Tensor,
|
||||
cond: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
r_emb = self.time_emb(r).unsqueeze(1)
|
||||
t_emb = self.time_emb(t).unsqueeze(1)
|
||||
|
||||
input_emb = self.input_emb(sample)
|
||||
|
||||
if self.encoder_only:
|
||||
@@ -292,6 +361,22 @@ class IMFTransformerForDiffusion(ModuleAttrMixin):
|
||||
tgt_mask=self.mask,
|
||||
memory_mask=self.memory_mask,
|
||||
)
|
||||
return x
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.Tensor,
|
||||
r: Union[torch.Tensor, float, int],
|
||||
t: Union[torch.Tensor, float, int],
|
||||
cond: Optional[torch.Tensor] = None,
|
||||
):
|
||||
r = self._prepare_time_input(r, sample)
|
||||
t = self._prepare_time_input(t, sample)
|
||||
|
||||
if self.backbone_type == 'attnres_full':
|
||||
x = self._forward_attnres_full(sample, r, t, cond=cond)
|
||||
else:
|
||||
x = self._forward_vanilla(sample, r, t, cond=cond)
|
||||
|
||||
x = self.ln_f(x)
|
||||
x = self.head(x)
|
||||
|
||||
@@ -39,6 +39,11 @@ class IMFTransformerHybridImagePolicy(DiffusionTransformerHybridImagePolicy):
|
||||
time_as_cond=True,
|
||||
obs_as_cond=True,
|
||||
pred_action_steps_only=False,
|
||||
backbone_type='vanilla',
|
||||
n_kv_head=1,
|
||||
attn_res_ffn_mult=2.667,
|
||||
attn_res_eps=1e-6,
|
||||
attn_res_rope_theta=10000.0,
|
||||
**kwargs,
|
||||
):
|
||||
if num_inference_steps is None:
|
||||
@@ -90,6 +95,11 @@ class IMFTransformerHybridImagePolicy(DiffusionTransformerHybridImagePolicy):
|
||||
time_as_cond=time_as_cond,
|
||||
obs_as_cond=obs_as_cond,
|
||||
n_cond_layers=n_cond_layers,
|
||||
backbone_type=backbone_type,
|
||||
n_kv_head=n_kv_head,
|
||||
attn_res_ffn_mult=attn_res_ffn_mult,
|
||||
attn_res_eps=attn_res_eps,
|
||||
attn_res_rope_theta=attn_res_rope_theta,
|
||||
)
|
||||
self.num_inference_steps = 1
|
||||
|
||||
|
||||
@@ -1,53 +0,0 @@
|
||||
# PushT DiT No-Causal Compare Implementation Plan
|
||||
|
||||
> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking.
|
||||
|
||||
**Goal:** Add a PushT image DiT no-causal config, rerun the two prior DiT baselines for 350 epochs, and compare max `test_mean_score` plus batch-1 inference latency.
|
||||
|
||||
**Architecture:** Keep the existing causal DiT baselines unchanged and add a separate no-causal config that only flips `policy.causal_attn=false` while preserving the SwanLab naming safeguards. Launch the default DiT (`256x8`) locally and the `256x18` DiT on 5880 GPU0, then parse `logs.json.txt` and benchmark both checkpoints on the same hardware.
|
||||
|
||||
**Tech Stack:** Hydra, Diffusion Policy transformer image workspace, SwanLab, uv Python env, local 5090 + trusted remote 5880.
|
||||
|
||||
---
|
||||
|
||||
### Task 1: Add no-causal DiT config and config regression test
|
||||
|
||||
**Files:**
|
||||
- Create: `image_pusht_diffusion_policy_dit_nocausal.yaml`
|
||||
- Modify: `tests/test_pusht_swanlab_config.py`
|
||||
|
||||
- [ ] Write a failing test asserting the new no-causal DiT config uses SwanLab-safe naming and `policy.causal_attn == False`.
|
||||
- [ ] Run the targeted pytest command and verify it fails because the config does not exist yet.
|
||||
- [ ] Add the minimal new config by composing from the existing PushT DiT config and overriding only `policy.causal_attn=false`.
|
||||
- [ ] Re-run the targeted pytest command and verify it passes.
|
||||
|
||||
### Task 2: Smoke-verify the new config
|
||||
|
||||
**Files:**
|
||||
- Read: `image_pusht_diffusion_policy_dit_nocausal.yaml`
|
||||
|
||||
- [ ] Run `train.py --help` against the new config.
|
||||
- [ ] Verify Hydra resolves the config without errors.
|
||||
|
||||
### Task 3: Launch the two 350-epoch no-causal reruns
|
||||
|
||||
**Files:**
|
||||
- Write runtime scripts/logs under `data/run_logs/`
|
||||
- Write outputs under `data/outputs/`
|
||||
|
||||
- [ ] Launch local run: `dit_nocausal_img_pusht_default_seed42_local` with 350 epochs.
|
||||
- [ ] Launch remote run: `dit_nocausal_img_pusht_emb256_layer18_seed42_5880gpu0` with 350 epochs and `policy.n_layer=18`.
|
||||
- [ ] Use explicit SwanLab overrides: unique `logging.name`, `logging.resume=false`, `logging.id=null`, shared group `dit_pusht_nocausal_compare`.
|
||||
- [ ] Record pid files and launcher scripts.
|
||||
|
||||
### Task 4: Monitor and summarize
|
||||
|
||||
**Files:**
|
||||
- Read: per-run `logs.json.txt`
|
||||
- Read: checkpoints directories
|
||||
|
||||
- [ ] Monitor until both runs reach epoch 349 completion.
|
||||
- [ ] Extract `max(test_mean_score)` and final logged `test_mean_score`.
|
||||
- [ ] Identify the best checkpoint for each run.
|
||||
- [ ] Benchmark batch-1 `policy.predict_action(obs)` latency on the same hardware.
|
||||
- [ ] Report the final comparison table and short conclusion.
|
||||
@@ -0,0 +1,60 @@
|
||||
# PushT iMF Full-Attention Implementation Plan
|
||||
|
||||
> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking.
|
||||
|
||||
**Goal:** Add a separate full-attention PushT image iMF config, commit/push it on a new branch, and launch the 9-run 350-epoch architecture sweep across 3 GPUs.
|
||||
|
||||
**Architecture:** Keep the existing causal iMF path untouched and add a standalone full-attention config that only flips `policy.causal_attn=false` while retaining one-step iMF inference and SwanLab-safe naming. Reuse the previous 9-run architecture matrix and balanced 3-queue scheduling across local 5090 plus 5880 GPU0/GPU1.
|
||||
|
||||
**Tech Stack:** Hydra, Diffusion Policy iMF image workspace, SwanLab, uv env, local shell + trusted remote 5880 over SSH.
|
||||
|
||||
---
|
||||
|
||||
### Task 1: Add full-attention iMF config with TDD
|
||||
|
||||
**Files:**
|
||||
- Create: `image_pusht_diffusion_policy_dit_imf_fullattn.yaml`
|
||||
- Modify: `tests/test_pusht_swanlab_config.py`
|
||||
|
||||
- [ ] Write a failing config regression test asserting the new config uses SwanLab-safe naming and `policy.causal_attn == False`.
|
||||
- [ ] Run the targeted pytest command and verify it fails because the config does not exist yet.
|
||||
- [ ] Add the minimal full-attention config by composing from the existing PushT image iMF config and overriding only `exp_name` and `policy.causal_attn=false`.
|
||||
- [ ] Re-run the targeted pytest and verify it passes.
|
||||
|
||||
### Task 2: Verify the new config
|
||||
|
||||
**Files:**
|
||||
- Read: `image_pusht_diffusion_policy_dit_imf_fullattn.yaml`
|
||||
|
||||
- [ ] Run `train.py --help` for the new config.
|
||||
- [ ] Run a real `training.debug=true` smoke test locally to confirm the training path is valid.
|
||||
|
||||
### Task 3: Commit and push the new branch
|
||||
|
||||
**Files:**
|
||||
- Commit only the new config/test/plan files needed for the full-attention experiment chain.
|
||||
|
||||
- [ ] Run verification commands again before commit.
|
||||
- [ ] Commit with a focused message.
|
||||
- [ ] Push `feat/pusht-imf-fullattn` to origin.
|
||||
|
||||
### Task 4: Launch the 9-run sweep
|
||||
|
||||
**Files:**
|
||||
- Write queue scripts and logs under `data/run_logs/` locally and on 5880.
|
||||
- Write outputs under `data/outputs/` locally and on 5880.
|
||||
|
||||
- [ ] Use the same matrix as the prior iMF sweep: `n_emb ∈ {128,256,384}`, `n_layer ∈ {6,12,18}`, `seed=42`.
|
||||
- [ ] Set `training.num_epochs=350` for all 9 runs.
|
||||
- [ ] Encode `fullattn` in every `exp_name`, `logging.name`, and run directory to avoid collisions.
|
||||
- [ ] Balance the 9 runs across local 5090, 5880 GPU0, and 5880 GPU1 as three serial queues.
|
||||
- [ ] Sync the new config to the remote smoke repo before launching remote queues.
|
||||
|
||||
### Task 5: Monitor and auto-summarize
|
||||
|
||||
**Files:**
|
||||
- Read local and remote pid files, logs, outputs, checkpoints.
|
||||
|
||||
- [ ] Start an xhigh monitoring agent that polls all three queues.
|
||||
- [ ] On completion, parse all 9 `logs.json.txt` files and rank by max `test_mean_score`.
|
||||
- [ ] Report embedding/layer trends and the best configuration.
|
||||
@@ -0,0 +1,57 @@
|
||||
# PushT Image iMF AttnRes Implementation Plan
|
||||
|
||||
> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking.
|
||||
|
||||
**Goal:** Add an AttnRes-backed full-attention iMF backbone for the PushT image experiment path, verify it with tests/smoke runs, then launch the 9-run 350-epoch architecture sweep across the local 5090 and remote 5880 GPUs.
|
||||
|
||||
**Architecture:** Extend `IMFTransformerForDiffusion` with a selectable `attnres_full` backbone that keeps the current iMF training/inference API unchanged while replacing the transformer internals with RMSNorm + RoPE self-attention + SwiGLU + Full AttnRes depth-wise residual routing. Add one standalone Hydra config for the PushT image sweep and reuse queue-style launch scripts with unique SwanLab names.
|
||||
|
||||
**Tech Stack:** Python 3.9 via uv, PyTorch 2.8 CUDA, Hydra, SwanLab online logging, local shell + SSH to trusted 5880 host.
|
||||
|
||||
---
|
||||
|
||||
### Task 1: Add regression tests for the new AttnRes path
|
||||
|
||||
**Files:**
|
||||
- Modify: `tests/test_imf_transformer_for_diffusion.py`
|
||||
- Modify: `tests/test_pusht_swanlab_config.py`
|
||||
|
||||
- [ ] Add a failing model test that instantiates `IMFTransformerForDiffusion(backbone_type='attnres_full', causal_attn=False, ...)`, runs a forward pass with conditional observations, and asserts output shape plus optimizer construction.
|
||||
- [ ] Run the targeted pytest selection and confirm the new test fails for the expected missing-backbone reason.
|
||||
- [ ] Add a failing config regression test for `image_pusht_diffusion_policy_dit_imf_attnres_full.yaml` asserting SwanLab naming fields and `policy.causal_attn == False`.
|
||||
- [ ] Re-run the targeted pytest selection and confirm the config test fails before implementation.
|
||||
|
||||
### Task 2: Implement the AttnRes-backed iMF backbone
|
||||
|
||||
**Files:**
|
||||
- Create: `diffusion_policy/model/diffusion/attnres_transformer_components.py`
|
||||
- Modify: `diffusion_policy/model/diffusion/imf_transformer_for_diffusion.py`
|
||||
|
||||
- [ ] Add focused reusable modules for `RMSNorm`, RoPE helpers, grouped-query self-attention, SwiGLU FFN, and the Full AttnRes operator.
|
||||
- [ ] Extend `IMFTransformerForDiffusion` with a `backbone_type` switch that preserves the existing vanilla path and adds an `attnres_full` path using concatenated `[r, t, obs, sample]` tokens.
|
||||
- [ ] Ensure the AttnRes path slices condition tokens away before the output head so the returned tensor still matches the sample/action horizon.
|
||||
- [ ] Update optimizer parameter grouping to treat RMSNorm weights like LayerNorm weights (no decay) and include any new positional/conditioning parameters.
|
||||
- [ ] Run the targeted tests and get them green.
|
||||
|
||||
### Task 3: Add the new PushT config and smoke-test path
|
||||
|
||||
**Files:**
|
||||
- Create: `image_pusht_diffusion_policy_dit_imf_attnres_full.yaml`
|
||||
- Modify: `tests/test_pusht_swanlab_config.py`
|
||||
|
||||
- [ ] Add a standalone PushT image config for the AttnRes iMF variant with SwanLab online logging, `policy.backbone_type=attnres_full`, and `policy.causal_attn=false`.
|
||||
- [ ] Verify `uv run python train.py --config-dir=. --config-name=image_pusht_diffusion_policy_dit_imf_attnres_full.yaml --help` succeeds.
|
||||
- [ ] Run a real smoke training command with `training.debug=true`, `training.device=cuda:0`, safety overrides (`dataloader.num_workers=0`, `task.env_runner.n_envs=1`, no vis), and confirm it reaches the training loop and writes a run directory.
|
||||
|
||||
### Task 4: Prepare launch scripts and start the 9-run sweep
|
||||
|
||||
**Files:**
|
||||
- Create or modify: `data/run_logs/imf_attnres_local_queue.sh`
|
||||
- Create or modify locally before copy: `data/run_logs/imf_attnres_remote_gpu0_queue.sh`
|
||||
- Create or modify locally before copy: `data/run_logs/imf_attnres_remote_gpu1_queue.sh`
|
||||
|
||||
- [ ] Write queue command templates for the 9 runs using config `image_pusht_diffusion_policy_dit_imf_attnres_full.yaml`, `training.num_epochs=350`, unique `exp_name/logging.name`, and shared `logging.group=imf_pusht_attnres_arch_sweep`.
|
||||
- [ ] Sync the necessary config/model files plus remote queue scripts to `droid@100.73.14.65:~/project/diffusion_policy-smoke`.
|
||||
- [ ] Start the local queue under `nohup`, record PID, and verify the first run log is advancing.
|
||||
- [ ] Start the two remote queues under `nohup`, record PIDs, and verify both first-run logs are advancing.
|
||||
- [ ] Confirm all three GPUs have officially entered training for the new sweep.
|
||||
107
docs/superpowers/specs/2026-03-27-pusht-imf-fullattn-design.md
Normal file
107
docs/superpowers/specs/2026-03-27-pusht-imf-fullattn-design.md
Normal file
@@ -0,0 +1,107 @@
|
||||
# PushT Image iMF Full-Attention Sweep Design
|
||||
|
||||
## Goal
|
||||
在一个独立新分支上,为 PushT 图像 iMF 路线新增 **full-attention** 变体(关闭因果注意力),并按与之前相同的架构扫描网格运行 **9 组实验**,每组训练 **350 epochs**。所有实验完成后,提取每组 **`max(test_mean_score)`** 并输出完整排名和趋势总结。
|
||||
|
||||
## Scope
|
||||
本次工作仅覆盖:
|
||||
1. 在不影响现有因果版 iMF 路线的前提下,新增 full-attention 实验链路;
|
||||
2. 对 `n_emb ∈ {128, 256, 384}` 与 `n_layer ∈ {6, 12, 18}` 的 9 组组合做 350-epoch 扫描;
|
||||
3. 在本机 5090 与 5880 双卡上做三路并行调度;
|
||||
4. 在全部实验完成后自动汇总结果并直接向用户汇报。
|
||||
|
||||
不在本次范围内:
|
||||
- 不替换或删除现有因果版 iMF 配置;
|
||||
- 不改动已有 DiT baseline 实现;
|
||||
- 不做多 seed 扩展;
|
||||
- 不额外增加视频记录。
|
||||
|
||||
## Design Choice
|
||||
采用“**新增独立配置 + 新分支**”的方式,而不是覆盖现有 iMF 默认配置。
|
||||
|
||||
原因:
|
||||
- 现有因果版 iMF 已完成实验与结果记录,保持不动更利于对照;
|
||||
- full-attention 作为新的实验链路,使用独立配置更易复现;
|
||||
- 运行时只需要通过配置切换 `policy.causal_attn=false`,不需要重新设计 iMF 算法本身。
|
||||
|
||||
## Configuration Design
|
||||
新增一个独立配置文件,例如:
|
||||
- `image_pusht_diffusion_policy_dit_imf_fullattn.yaml`
|
||||
|
||||
其职责:
|
||||
- 继承当前 PushT image iMF 配置链路;
|
||||
- 保持 iMF 单步推理、SwanLab 标量记录、无视频记录;
|
||||
- 显式设置:
|
||||
- `policy.causal_attn=false`
|
||||
- `policy.n_head=1`
|
||||
- 保持其余 iMF 训练语义不变。
|
||||
|
||||
SwanLab 命名延续当前修复后的策略:
|
||||
- `logging.name=${exp_name}`
|
||||
- `logging.resume=false`
|
||||
- `logging.id=null`
|
||||
- `logging.group=${exp_name}` 或统一 sweep group override
|
||||
|
||||
## Code Change Strategy
|
||||
优先最小改动:
|
||||
- 若当前 `IMFTransformerForDiffusion` 已支持 `causal_attn=False` 分支,则不改核心算法,仅通过新配置关闭因果 mask;
|
||||
- 如需补充回归验证,则新增针对 full-attention 配置/掩码行为的最小测试;
|
||||
- 不改变已有因果版实验配置和已有测试语义。
|
||||
|
||||
## Experiment Matrix
|
||||
实验网格固定为:
|
||||
|
||||
- `n_emb=128, n_layer=6`
|
||||
- `n_emb=128, n_layer=12`
|
||||
- `n_emb=128, n_layer=18`
|
||||
- `n_emb=256, n_layer=6`
|
||||
- `n_emb=256, n_layer=12`
|
||||
- `n_emb=256, n_layer=18`
|
||||
- `n_emb=384, n_layer=6`
|
||||
- `n_emb=384, n_layer=12`
|
||||
- `n_emb=384, n_layer=18`
|
||||
|
||||
统一设置:
|
||||
- `training.num_epochs=350`
|
||||
- `training.resume=false`
|
||||
- `seed=42`
|
||||
- PushT image 数据路径不变
|
||||
- 指标以 **`logs.json.txt` 中 `test_mean_score` 的最大值** 为准
|
||||
|
||||
## Scheduling Design
|
||||
使用三路串行队列并行执行 9 个实验:
|
||||
|
||||
- 本机 5090:1 个顺序队列
|
||||
- 5880 GPU0:1 个顺序队列
|
||||
- 5880 GPU1:1 个顺序队列
|
||||
|
||||
分配原则:
|
||||
- 延续按 `n_emb × n_layer` 近似平衡工作量;
|
||||
- 每张卡同一时刻只跑 1 个实验;
|
||||
- 队列脚本负责“前一个结束后自动启动下一个”。
|
||||
|
||||
## Monitoring Design
|
||||
继续采用“**训练队列脚本 + 监控 agent**”双层机制:
|
||||
|
||||
1. **实际调度**由本地/远端队列脚本负责;
|
||||
2. **监控**由一个 xhigh 子 agent 轮询:
|
||||
- 读取 pid 状态
|
||||
- 检查 master log
|
||||
- 检查每个 run 的 `logs.json.txt`
|
||||
- 判断是否卡死/失败/全部完成
|
||||
3. 一旦全部完成,监控 agent 直接返回:
|
||||
- 9 组实验的最终 epoch
|
||||
- 每组 `max(test_mean_score)`
|
||||
- 排名表
|
||||
- embedding / layer 趋势总结
|
||||
|
||||
本次要求下,agent 在收到全部完成信号后应直接向主会话回报结果,不等待用户再次提醒。
|
||||
|
||||
## Success Criteria
|
||||
满足以下条件即视为完成:
|
||||
1. full-attention iMF 配置在新分支上可运行;
|
||||
2. 9 组 350-epoch 实验全部完成;
|
||||
3. 不记录仿真视频,只记录标量;
|
||||
4. SwanLab 运行命名不冲突;
|
||||
5. 输出 9 组实验 `max(test_mean_score)` 的完整汇总与结论;
|
||||
6. 全部实验结束后主会话可直接给用户最终总结。
|
||||
108
docs/superpowers/specs/2026-03-29-pusht-imf-attnres-design.md
Normal file
108
docs/superpowers/specs/2026-03-29-pusht-imf-attnres-design.md
Normal file
@@ -0,0 +1,108 @@
|
||||
# PushT Image iMF AttnRes Design
|
||||
|
||||
## Goal
|
||||
在现有 PushT 图像 iMF full-attention 路线之上,引入 `attn_res` 仓库中的 **Full AttnRes** 残差聚合形式,并同步使用与其匹配的 **RMSNorm + 自注意力 + SwiGLU FFN** 模块,保持 iMF 训练目标与一步推理语义不变,仅作用于本次实验链路。实现完成并验证后,启动与此前相同的 9 组 `n_emb × n_layer` 扫描(350 epochs, seed=42, SwanLab online, 无视频记录)。
|
||||
|
||||
## Scope
|
||||
本次工作仅覆盖:
|
||||
1. 为 `IMFTransformerForDiffusion` 增加一个 AttnRes-backed backbone 变体;
|
||||
2. 保持 `forward(sample, r, t, cond=None)`、iMF loss、一步推理策略接口不变;
|
||||
3. 新增独立 PushT 图像配置用于该变体;
|
||||
4. 复用本地 5090 + 远端 5880 双卡三路并行调度 9 组实验。
|
||||
|
||||
不在范围内:
|
||||
- 不替换已有 vanilla iMF/full-attn 配置;
|
||||
- 不修改 DiT baseline;
|
||||
- 不增加视频日志;
|
||||
- 不扩大到多 seed。
|
||||
|
||||
## Recommended Approach
|
||||
采用“**在当前 iMF 模型内增加可选 AttnRes backbone**”的方式,而不是新建独立 policy 链路。
|
||||
|
||||
理由:
|
||||
- policy / workspace / loss / sampling 路径已经被验证,保留这些路径可最大程度缩小变动面;
|
||||
- 仅在模型内部切换 backbone,可以让新实验与既有 iMF 结果保持可比;
|
||||
- 配置上只需显式打开 `backbone_type=attnres_full`、`causal_attn=false` 等开关,复现实验更直接。
|
||||
|
||||
## Architecture
|
||||
### 1. Backbone split
|
||||
`IMFTransformerForDiffusion` 保留现有 vanilla encoder/decoder 实现为默认路径,并新增 `attnres_full` 路径:
|
||||
- **vanilla**:保持当前实现不变;
|
||||
- **attnres_full**:使用单栈式全注意力 Transformer,输入 token 序列为
|
||||
`[r token, t token, obs cond tokens..., action/sample tokens...]`。
|
||||
|
||||
模型只对末尾的 action/sample token 位置输出 `u` 预测,前置条件 token 仅参与上下文建模。
|
||||
|
||||
### 2. AttnRes stack
|
||||
新 backbone 使用以下模块:
|
||||
- `RMSNorm`
|
||||
- `Rotary Position Embedding`(用于自注意力 q/k)
|
||||
- `GroupedQueryAttention`(本实验默认 `n_kv_head=1`,与单头配置兼容)
|
||||
- `SwiGLU` FFN
|
||||
- `AttnResOperator`(每个子层一个 pseudo-query,执行 full depth-wise residual aggregation)
|
||||
|
||||
每个 transformer block 由两个子层组成:
|
||||
1. self-attention 子层
|
||||
2. FFN 子层
|
||||
|
||||
每个子层的输入不再是简单 `x + f(x)`,而是从 embedding 与全部历史子层输出中通过 Full AttnRes 聚合得到 `h_l`,再执行 `RMSNorm(h_l) -> sublayer_fn(...)`。
|
||||
|
||||
### 3. Conditioning and token flow
|
||||
- `sample` 先经 `input_emb` 映射为 action tokens;
|
||||
- `r` 和 `t` 各自经 `SinusoidalPosEmb + linear` 映射为两个条件 token;
|
||||
- 图像观测编码后的 `cond` 通过 `cond_obs_emb` 映射为 obs tokens;
|
||||
- 拼接后的完整 token 序列进入 AttnRes stack;
|
||||
- 输出时切掉前置条件 token,仅保留 action/sample token 段,随后经 `RMSNorm + head` 得到最终 `u`。
|
||||
|
||||
### 4. Attention mode
|
||||
本次实验链路固定为 **non-causal full attention**:
|
||||
- `causal_attn=false`
|
||||
- 不构造 causal mask
|
||||
- 所有 token 可彼此双向可见
|
||||
|
||||
这与用户指定的“训练过程仍然使用全注意力(不加因果注意)”一致。
|
||||
|
||||
## Config and Logging
|
||||
新增独立配置文件,例如:
|
||||
- `image_pusht_diffusion_policy_dit_imf_attnres_full.yaml`
|
||||
|
||||
该配置需要:
|
||||
- 指向现有 `IMFTransformerHybridImagePolicy`
|
||||
- 显式开启 AttnRes backbone 相关参数
|
||||
- 设置 `policy.causal_attn=false`
|
||||
- 保持 `logging.backend=swanlab`、`logging.mode=online`
|
||||
- 运行时通过覆盖保证:
|
||||
- `logging.name=<unique_run_name>`
|
||||
- `logging.group=imf_pusht_attnres_arch_sweep`
|
||||
- `exp_name=<unique_run_name>`
|
||||
- 保持 `task.env_runner.n_test_vis=0` 与 `n_train_vis=0`,仅记录标量
|
||||
|
||||
## Experiment Matrix
|
||||
固定 9 组:
|
||||
- `n_emb ∈ {128, 256, 384}`
|
||||
- `n_layer ∈ {6, 12, 18}`
|
||||
- `seed=42`
|
||||
- `training.num_epochs=350`
|
||||
|
||||
## Scheduling
|
||||
沿用之前验证过的三队列分配:
|
||||
- 本机 5090:`384x18`, `256x6`, `128x6`
|
||||
- 5880 GPU0:`384x12`, `256x12`, `128x12`
|
||||
- 5880 GPU1:`384x6`, `256x18`, `128x18`
|
||||
|
||||
每个 run name 编码 backbone 与结构,例如:
|
||||
`imf_attnres_emb256_layer12_seed42_5880gpu0`
|
||||
|
||||
## Verification
|
||||
实现阶段至少验证:
|
||||
1. 新配置的 SwanLab 命名与 `causal_attn=false` 正确;
|
||||
2. 新 backbone 的 forward shape 与 `configure_optimizers()` 可用;
|
||||
3. 旧 vanilla 路径测试不回归;
|
||||
4. `training.debug=true` smoke run 可以完整通过。
|
||||
|
||||
## Success Criteria
|
||||
1. 新 AttnRes iMF 变体在本分支可训练、可一步推理;
|
||||
2. 不影响已有 vanilla iMF/full-attn 链路;
|
||||
3. 9 组实验成功在三张卡上正式启动;
|
||||
4. SwanLab run 名称唯一,无冲突;
|
||||
5. 不记录视频,仅记录标量。
|
||||
35
image_pusht_diffusion_policy_dit_imf_attnres_full.yaml
Normal file
35
image_pusht_diffusion_policy_dit_imf_attnres_full.yaml
Normal file
@@ -0,0 +1,35 @@
|
||||
defaults:
|
||||
- diffusion_policy/config/train_diffusion_transformer_hybrid_workspace@_here_
|
||||
- override /diffusion_policy/config/task@task: pusht_image
|
||||
- _self_
|
||||
|
||||
exp_name: pusht_image_dit_imf_attnres_full
|
||||
|
||||
policy:
|
||||
_target_: diffusion_policy.policy.imf_transformer_hybrid_image_policy.IMFTransformerHybridImagePolicy
|
||||
num_inference_steps: 1
|
||||
n_head: 1
|
||||
n_kv_head: 1
|
||||
causal_attn: false
|
||||
backbone_type: attnres_full
|
||||
|
||||
logging:
|
||||
backend: swanlab
|
||||
mode: online
|
||||
name: ${exp_name}
|
||||
resume: false
|
||||
tags: ["${name}", "${task_name}", "${exp_name}", "swanlab"]
|
||||
id: null
|
||||
group: ${exp_name}
|
||||
|
||||
dataloader:
|
||||
num_workers: 0
|
||||
|
||||
val_dataloader:
|
||||
num_workers: 0
|
||||
|
||||
task:
|
||||
env_runner:
|
||||
n_envs: 1
|
||||
n_test_vis: 0
|
||||
n_train_vis: 0
|
||||
@@ -3,10 +3,12 @@ defaults:
|
||||
- override /diffusion_policy/config/task@task: pusht_image
|
||||
- _self_
|
||||
|
||||
exp_name: pusht_image_dit_nocausal
|
||||
exp_name: pusht_image_dit_imf_fullattn
|
||||
|
||||
policy:
|
||||
_target_: diffusion_policy.policy.diffusion_transformer_hybrid_image_policy.DiffusionTransformerHybridImagePolicy
|
||||
_target_: diffusion_policy.policy.imf_transformer_hybrid_image_policy.IMFTransformerHybridImagePolicy
|
||||
num_inference_steps: 1
|
||||
n_head: 1
|
||||
causal_attn: false
|
||||
|
||||
logging:
|
||||
29
scripts/pusht/imf_attnres_local_queue.sh
Executable file
29
scripts/pusht/imf_attnres_local_queue.sh
Executable file
@@ -0,0 +1,29 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
cd /home/droid/project/diffusion_policy/.worktrees/feat-pusht-imf-attnres
|
||||
export PYTHONUNBUFFERED=1
|
||||
export SWANLAB_API_KEY='PSZrBMLx1XAjDjvmhUcNz'
|
||||
export LD_LIBRARY_PATH="$(printf '%s:' .venv/lib/python3.9/site-packages/nvidia/*/lib | sed 's/:$//')"
|
||||
run_exp() {
|
||||
local name="$1" emb="$2" layer="$3"
|
||||
echo "[$(date '+%F %T')] START $name emb=$emb layer=$layer"
|
||||
.venv/bin/python train.py \
|
||||
--config-dir=. \
|
||||
--config-name=image_pusht_diffusion_policy_dit_imf_attnres_full.yaml \
|
||||
training.device=cuda:0 \
|
||||
training.num_epochs=350 \
|
||||
training.resume=false \
|
||||
exp_name="$name" \
|
||||
logging.group=imf_pusht_attnres_arch_sweep \
|
||||
logging.name="$name" \
|
||||
logging.resume=false \
|
||||
logging.id=null \
|
||||
hydra.run.dir="data/outputs/$name" \
|
||||
policy.n_emb="$emb" \
|
||||
policy.n_layer="$layer" \
|
||||
> "data/run_logs/${name}.log" 2>&1
|
||||
echo "[$(date '+%F %T')] END $name"
|
||||
}
|
||||
run_exp imf_attnres_emb384_layer18_seed42_local 384 18
|
||||
run_exp imf_attnres_emb256_layer6_seed42_local 256 6
|
||||
run_exp imf_attnres_emb128_layer6_seed42_local 128 6
|
||||
29
scripts/pusht/imf_attnres_remote_gpu0_queue.sh
Executable file
29
scripts/pusht/imf_attnres_remote_gpu0_queue.sh
Executable file
@@ -0,0 +1,29 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
cd /home/droid/project/diffusion_policy-smoke
|
||||
export PYTHONUNBUFFERED=1
|
||||
export SWANLAB_API_KEY='PSZrBMLx1XAjDjvmhUcNz'
|
||||
export LD_LIBRARY_PATH="$(printf '%s:' .venv/lib/python3.9/site-packages/nvidia/*/lib | sed 's/:$//')"
|
||||
run_exp() {
|
||||
local name="$1" emb="$2" layer="$3"
|
||||
echo "[$(date '+%F %T')] START $name emb=$emb layer=$layer"
|
||||
.venv/bin/python train.py \
|
||||
--config-dir=. \
|
||||
--config-name=image_pusht_diffusion_policy_dit_imf_attnres_full.yaml \
|
||||
training.device=cuda:0 \
|
||||
training.num_epochs=350 \
|
||||
training.resume=false \
|
||||
exp_name="$name" \
|
||||
logging.group=imf_pusht_attnres_arch_sweep \
|
||||
logging.name="$name" \
|
||||
logging.resume=false \
|
||||
logging.id=null \
|
||||
hydra.run.dir="data/outputs/$name" \
|
||||
policy.n_emb="$emb" \
|
||||
policy.n_layer="$layer" \
|
||||
> "data/run_logs/${name}.log" 2>&1
|
||||
echo "[$(date '+%F %T')] END $name"
|
||||
}
|
||||
run_exp imf_attnres_emb384_layer12_seed42_5880gpu0 384 12
|
||||
run_exp imf_attnres_emb256_layer12_seed42_5880gpu0 256 12
|
||||
run_exp imf_attnres_emb128_layer12_seed42_5880gpu0 128 12
|
||||
29
scripts/pusht/imf_attnres_remote_gpu1_queue.sh
Executable file
29
scripts/pusht/imf_attnres_remote_gpu1_queue.sh
Executable file
@@ -0,0 +1,29 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
cd /home/droid/project/diffusion_policy-smoke
|
||||
export PYTHONUNBUFFERED=1
|
||||
export SWANLAB_API_KEY='PSZrBMLx1XAjDjvmhUcNz'
|
||||
export LD_LIBRARY_PATH="$(printf '%s:' .venv/lib/python3.9/site-packages/nvidia/*/lib | sed 's/:$//')"
|
||||
run_exp() {
|
||||
local name="$1" emb="$2" layer="$3"
|
||||
echo "[$(date '+%F %T')] START $name emb=$emb layer=$layer"
|
||||
.venv/bin/python train.py \
|
||||
--config-dir=. \
|
||||
--config-name=image_pusht_diffusion_policy_dit_imf_attnres_full.yaml \
|
||||
training.device=cuda:1 \
|
||||
training.num_epochs=350 \
|
||||
training.resume=false \
|
||||
exp_name="$name" \
|
||||
logging.group=imf_pusht_attnres_arch_sweep \
|
||||
logging.name="$name" \
|
||||
logging.resume=false \
|
||||
logging.id=null \
|
||||
hydra.run.dir="data/outputs/$name" \
|
||||
policy.n_emb="$emb" \
|
||||
policy.n_layer="$layer" \
|
||||
> "data/run_logs/${name}.log" 2>&1
|
||||
echo "[$(date '+%F %T')] END $name"
|
||||
}
|
||||
run_exp imf_attnres_emb384_layer6_seed42_5880gpu1 384 6
|
||||
run_exp imf_attnres_emb256_layer18_seed42_5880gpu1 256 18
|
||||
run_exp imf_attnres_emb128_layer18_seed42_5880gpu1 128 18
|
||||
@@ -44,3 +44,34 @@ def test_imf_transformer_forward_signature_and_shape_single_head():
|
||||
pred_u = model(sample, r, t, cond=cond)
|
||||
|
||||
assert pred_u.shape == sample.shape
|
||||
|
||||
|
||||
def test_imf_transformer_attnres_full_backbone_forward_shape_and_optimizer():
|
||||
model = IMFTransformerForDiffusion(
|
||||
input_dim=3,
|
||||
output_dim=3,
|
||||
horizon=5,
|
||||
n_obs_steps=2,
|
||||
cond_dim=4,
|
||||
n_layer=2,
|
||||
n_head=1,
|
||||
n_emb=16,
|
||||
p_drop_emb=0.0,
|
||||
p_drop_attn=0.0,
|
||||
causal_attn=False,
|
||||
time_as_cond=True,
|
||||
obs_as_cond=True,
|
||||
n_cond_layers=0,
|
||||
backbone_type='attnres_full',
|
||||
)
|
||||
optimizer = model.configure_optimizers()
|
||||
|
||||
sample = torch.randn(2, 5, 3)
|
||||
r = torch.rand(2)
|
||||
t = torch.rand(2)
|
||||
cond = torch.randn(2, 2, 4)
|
||||
|
||||
pred_u = model(sample, r, t, cond=cond)
|
||||
|
||||
assert pred_u.shape == sample.shape
|
||||
assert optimizer is not None
|
||||
|
||||
@@ -32,8 +32,8 @@ def test_image_pusht_dit_imf_swanlab_config_uses_exp_name_and_no_resume_collisio
|
||||
assert cfg.logging.group == cfg.exp_name
|
||||
|
||||
|
||||
def test_image_pusht_dit_nocausal_config_uses_exp_name_and_disables_causal_attention():
|
||||
cfg = _load_cfg('image_pusht_diffusion_policy_dit_nocausal.yaml')
|
||||
def test_image_pusht_dit_imf_fullattn_config_uses_exp_name_and_disables_causal_attention():
|
||||
cfg = _load_cfg('image_pusht_diffusion_policy_dit_imf_fullattn.yaml')
|
||||
|
||||
assert cfg.logging.backend == 'swanlab'
|
||||
assert cfg.logging.mode == 'online'
|
||||
@@ -42,3 +42,16 @@ def test_image_pusht_dit_nocausal_config_uses_exp_name_and_disables_causal_atten
|
||||
assert cfg.logging.id is None
|
||||
assert cfg.logging.group == cfg.exp_name
|
||||
assert cfg.policy.causal_attn is False
|
||||
|
||||
|
||||
def test_image_pusht_dit_imf_attnres_full_config_uses_exp_name_and_disables_causal_attention():
|
||||
cfg = _load_cfg('image_pusht_diffusion_policy_dit_imf_attnres_full.yaml')
|
||||
|
||||
assert cfg.logging.backend == 'swanlab'
|
||||
assert cfg.logging.mode == 'online'
|
||||
assert cfg.logging.name == cfg.exp_name
|
||||
assert cfg.logging.resume is False
|
||||
assert cfg.logging.id is None
|
||||
assert cfg.logging.group == cfg.exp_name
|
||||
assert cfg.policy.causal_attn is False
|
||||
assert cfg.policy.backbone_type == 'attnres_full'
|
||||
|
||||
Reference in New Issue
Block a user