Compare commits
2 Commits
78ab18e8f3
...
feat/pusht
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
211abbb87f | ||
|
|
185ed6596c |
@@ -0,0 +1,247 @@
|
|||||||
|
import math
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
|
||||||
|
class RMSNorm(nn.Module):
|
||||||
|
def __init__(self, dim: int, eps: float = 1e-6) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.eps = eps
|
||||||
|
self.weight = nn.Parameter(torch.ones(dim))
|
||||||
|
|
||||||
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
rms = torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + self.eps)
|
||||||
|
return (x.float() * rms).to(x.dtype) * self.weight
|
||||||
|
|
||||||
|
|
||||||
|
class RMSNormNoWeight(nn.Module):
|
||||||
|
def __init__(self, eps: float = 1e-6) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.eps = eps
|
||||||
|
|
||||||
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
rms = torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + self.eps)
|
||||||
|
return (x.float() * rms).to(x.dtype)
|
||||||
|
|
||||||
|
|
||||||
|
def precompute_rope_freqs(
|
||||||
|
dim: int,
|
||||||
|
max_seq_len: int,
|
||||||
|
theta: float = 10000.0,
|
||||||
|
device: Optional[torch.device] = None,
|
||||||
|
) -> Tensor:
|
||||||
|
if dim % 2 != 0:
|
||||||
|
raise ValueError(f'RoPE requires an even head dimension, got {dim}.')
|
||||||
|
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, device=device).float() / dim))
|
||||||
|
positions = torch.arange(max_seq_len, device=device).float()
|
||||||
|
angles = torch.outer(positions, freqs)
|
||||||
|
return torch.polar(torch.ones_like(angles), angles)
|
||||||
|
|
||||||
|
|
||||||
|
def apply_rope(x: Tensor, freqs: Tensor) -> Tensor:
|
||||||
|
x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
|
||||||
|
freqs = freqs.unsqueeze(0).unsqueeze(2)
|
||||||
|
x_rotated = x_complex * freqs
|
||||||
|
return torch.view_as_real(x_rotated).reshape_as(x).to(x.dtype)
|
||||||
|
|
||||||
|
|
||||||
|
class GroupedQuerySelfAttention(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
d_model: int,
|
||||||
|
n_heads: int,
|
||||||
|
n_kv_heads: int,
|
||||||
|
dropout: float = 0.0,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
if d_model % n_heads != 0:
|
||||||
|
raise ValueError(f'd_model={d_model} must be divisible by n_heads={n_heads}.')
|
||||||
|
if n_heads % n_kv_heads != 0:
|
||||||
|
raise ValueError(f'n_heads={n_heads} must be divisible by n_kv_heads={n_kv_heads}.')
|
||||||
|
|
||||||
|
self.d_model = d_model
|
||||||
|
self.n_heads = n_heads
|
||||||
|
self.n_kv_heads = n_kv_heads
|
||||||
|
self.n_kv_groups = n_heads // n_kv_heads
|
||||||
|
self.d_head = d_model // n_heads
|
||||||
|
self.attn_dropout = nn.Dropout(dropout)
|
||||||
|
self.out_dropout = nn.Dropout(dropout)
|
||||||
|
|
||||||
|
self.w_q = nn.Linear(d_model, n_heads * self.d_head, bias=False)
|
||||||
|
self.w_k = nn.Linear(d_model, n_kv_heads * self.d_head, bias=False)
|
||||||
|
self.w_v = nn.Linear(d_model, n_kv_heads * self.d_head, bias=False)
|
||||||
|
self.w_o = nn.Linear(n_heads * self.d_head, d_model, bias=False)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: Tensor,
|
||||||
|
rope_freqs: Tensor,
|
||||||
|
mask: Optional[Tensor] = None,
|
||||||
|
) -> Tensor:
|
||||||
|
batch_size, seq_len, _ = x.shape
|
||||||
|
|
||||||
|
q = self.w_q(x).view(batch_size, seq_len, self.n_heads, self.d_head)
|
||||||
|
k = self.w_k(x).view(batch_size, seq_len, self.n_kv_heads, self.d_head)
|
||||||
|
v = self.w_v(x).view(batch_size, seq_len, self.n_kv_heads, self.d_head)
|
||||||
|
|
||||||
|
q = apply_rope(q, rope_freqs)
|
||||||
|
k = apply_rope(k, rope_freqs)
|
||||||
|
|
||||||
|
if self.n_kv_heads != self.n_heads:
|
||||||
|
k = k.unsqueeze(3).expand(
|
||||||
|
batch_size, seq_len, self.n_kv_heads, self.n_kv_groups, self.d_head
|
||||||
|
)
|
||||||
|
k = k.reshape(batch_size, seq_len, self.n_heads, self.d_head)
|
||||||
|
v = v.unsqueeze(3).expand(
|
||||||
|
batch_size, seq_len, self.n_kv_heads, self.n_kv_groups, self.d_head
|
||||||
|
)
|
||||||
|
v = v.reshape(batch_size, seq_len, self.n_heads, self.d_head)
|
||||||
|
|
||||||
|
q = q.transpose(1, 2)
|
||||||
|
k = k.transpose(1, 2)
|
||||||
|
v = v.transpose(1, 2)
|
||||||
|
|
||||||
|
scale = 1.0 / math.sqrt(self.d_head)
|
||||||
|
attn_weights = torch.matmul(q, k.transpose(-2, -1)) * scale
|
||||||
|
if mask is not None:
|
||||||
|
attn_weights = attn_weights + mask
|
||||||
|
attn_weights = F.softmax(attn_weights, dim=-1)
|
||||||
|
attn_weights = self.attn_dropout(attn_weights)
|
||||||
|
|
||||||
|
out = torch.matmul(attn_weights, v)
|
||||||
|
out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
|
||||||
|
return self.out_dropout(self.w_o(out))
|
||||||
|
|
||||||
|
|
||||||
|
class SwiGLUFFN(nn.Module):
|
||||||
|
def __init__(self, d_model: int, dropout: float = 0.0, mult: float = 2.667) -> None:
|
||||||
|
super().__init__()
|
||||||
|
raw = int(mult * d_model)
|
||||||
|
d_ff = ((raw + 7) // 8) * 8
|
||||||
|
self.w_gate = nn.Linear(d_model, d_ff, bias=False)
|
||||||
|
self.w_up = nn.Linear(d_model, d_ff, bias=False)
|
||||||
|
self.w_down = nn.Linear(d_ff, d_model, bias=False)
|
||||||
|
self.dropout = nn.Dropout(dropout)
|
||||||
|
|
||||||
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
return self.dropout(self.w_down(F.silu(self.w_gate(x)) * self.w_up(x)))
|
||||||
|
|
||||||
|
|
||||||
|
class AttnResOperator(nn.Module):
|
||||||
|
def __init__(self, d_model: int, eps: float = 1e-6) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.pseudo_query = nn.Parameter(torch.zeros(d_model))
|
||||||
|
self.key_norm = RMSNormNoWeight(eps=eps)
|
||||||
|
|
||||||
|
def forward(self, sources: Tensor) -> Tensor:
|
||||||
|
keys = self.key_norm(sources)
|
||||||
|
logits = torch.einsum('d,nbtd->nbt', self.pseudo_query, keys)
|
||||||
|
weights = F.softmax(logits, dim=0)
|
||||||
|
return torch.einsum('nbt,nbtd->btd', weights, sources)
|
||||||
|
|
||||||
|
|
||||||
|
class AttnResSubLayer(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
d_model: int,
|
||||||
|
n_heads: int,
|
||||||
|
n_kv_heads: int,
|
||||||
|
dropout: float,
|
||||||
|
ffn_mult: float,
|
||||||
|
eps: float,
|
||||||
|
is_attention: bool,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.norm = RMSNorm(d_model, eps=eps)
|
||||||
|
self.attn_res = AttnResOperator(d_model, eps=eps)
|
||||||
|
self.is_attention = is_attention
|
||||||
|
if is_attention:
|
||||||
|
self.fn = GroupedQuerySelfAttention(
|
||||||
|
d_model=d_model,
|
||||||
|
n_heads=n_heads,
|
||||||
|
n_kv_heads=n_kv_heads,
|
||||||
|
dropout=dropout,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.fn = SwiGLUFFN(d_model=d_model, dropout=dropout, mult=ffn_mult)
|
||||||
|
|
||||||
|
def forward(self, sources: Tensor, rope_freqs: Tensor, mask: Optional[Tensor] = None) -> Tensor:
|
||||||
|
h = self.attn_res(sources)
|
||||||
|
normed = self.norm(h)
|
||||||
|
if self.is_attention:
|
||||||
|
return self.fn(normed, rope_freqs, mask)
|
||||||
|
return self.fn(normed)
|
||||||
|
|
||||||
|
|
||||||
|
class AttnResTransformerBackbone(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
d_model: int,
|
||||||
|
n_blocks: int,
|
||||||
|
n_heads: int,
|
||||||
|
n_kv_heads: int,
|
||||||
|
max_seq_len: int,
|
||||||
|
dropout: float = 0.0,
|
||||||
|
ffn_mult: float = 2.667,
|
||||||
|
eps: float = 1e-6,
|
||||||
|
rope_theta: float = 10000.0,
|
||||||
|
causal_attn: bool = False,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.causal_attn = causal_attn
|
||||||
|
self.layers = nn.ModuleList()
|
||||||
|
for _ in range(n_blocks):
|
||||||
|
self.layers.append(
|
||||||
|
AttnResSubLayer(
|
||||||
|
d_model=d_model,
|
||||||
|
n_heads=n_heads,
|
||||||
|
n_kv_heads=n_kv_heads,
|
||||||
|
dropout=dropout,
|
||||||
|
ffn_mult=ffn_mult,
|
||||||
|
eps=eps,
|
||||||
|
is_attention=True,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.layers.append(
|
||||||
|
AttnResSubLayer(
|
||||||
|
d_model=d_model,
|
||||||
|
n_heads=n_heads,
|
||||||
|
n_kv_heads=n_kv_heads,
|
||||||
|
dropout=dropout,
|
||||||
|
ffn_mult=ffn_mult,
|
||||||
|
eps=eps,
|
||||||
|
is_attention=False,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
rope_freqs = precompute_rope_freqs(
|
||||||
|
dim=d_model // n_heads,
|
||||||
|
max_seq_len=max_seq_len,
|
||||||
|
theta=rope_theta,
|
||||||
|
)
|
||||||
|
self.register_buffer('rope_freqs', rope_freqs, persistent=False)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _build_causal_mask(seq_len: int, device: torch.device) -> Tensor:
|
||||||
|
mask = torch.full((seq_len, seq_len), float('-inf'), device=device)
|
||||||
|
mask = torch.triu(mask, diagonal=1)
|
||||||
|
return mask.unsqueeze(0).unsqueeze(0)
|
||||||
|
|
||||||
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
seq_len = x.shape[1]
|
||||||
|
rope_freqs = self.rope_freqs[:seq_len]
|
||||||
|
mask = None
|
||||||
|
if self.causal_attn:
|
||||||
|
mask = self._build_causal_mask(seq_len, x.device)
|
||||||
|
|
||||||
|
layer_outputs = [x]
|
||||||
|
for layer in self.layers:
|
||||||
|
sources = torch.stack(layer_outputs, dim=0)
|
||||||
|
output = layer(sources, rope_freqs, mask)
|
||||||
|
layer_outputs.append(output)
|
||||||
|
|
||||||
|
return torch.stack(layer_outputs, dim=0).sum(dim=0)
|
||||||
@@ -5,6 +5,15 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from diffusion_policy.model.common.module_attr_mixin import ModuleAttrMixin
|
from diffusion_policy.model.common.module_attr_mixin import ModuleAttrMixin
|
||||||
|
from diffusion_policy.model.diffusion.attnres_transformer_components import (
|
||||||
|
AttnResOperator,
|
||||||
|
AttnResSubLayer,
|
||||||
|
AttnResTransformerBackbone,
|
||||||
|
GroupedQuerySelfAttention,
|
||||||
|
RMSNorm,
|
||||||
|
RMSNormNoWeight,
|
||||||
|
SwiGLUFFN,
|
||||||
|
)
|
||||||
from diffusion_policy.model.diffusion.positional_embedding import SinusoidalPosEmb
|
from diffusion_policy.model.diffusion.positional_embedding import SinusoidalPosEmb
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -27,14 +36,20 @@ class IMFTransformerForDiffusion(ModuleAttrMixin):
|
|||||||
time_as_cond: bool = True,
|
time_as_cond: bool = True,
|
||||||
obs_as_cond: bool = False,
|
obs_as_cond: bool = False,
|
||||||
n_cond_layers: int = 0,
|
n_cond_layers: int = 0,
|
||||||
|
backbone_type: str = 'vanilla',
|
||||||
|
n_kv_head: int = 1,
|
||||||
|
attn_res_ffn_mult: float = 2.667,
|
||||||
|
attn_res_eps: float = 1e-6,
|
||||||
|
attn_res_rope_theta: float = 10000.0,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
assert n_head == 1, 'IMFTransformerForDiffusion currently supports single-head attention only.'
|
assert n_head == 1, 'IMFTransformerForDiffusion currently supports single-head attention only.'
|
||||||
|
|
||||||
if n_obs_steps is None:
|
if n_obs_steps is None:
|
||||||
n_obs_steps = horizon
|
n_obs_steps = horizon
|
||||||
|
|
||||||
|
self.backbone_type = backbone_type
|
||||||
|
|
||||||
T = horizon
|
T = horizon
|
||||||
T_cond = 2
|
T_cond = 2
|
||||||
if not time_as_cond:
|
if not time_as_cond:
|
||||||
@@ -46,21 +61,77 @@ class IMFTransformerForDiffusion(ModuleAttrMixin):
|
|||||||
T_cond += n_obs_steps
|
T_cond += n_obs_steps
|
||||||
|
|
||||||
self.input_emb = nn.Linear(input_dim, n_emb)
|
self.input_emb = nn.Linear(input_dim, n_emb)
|
||||||
self.pos_emb = nn.Parameter(torch.zeros(1, T, n_emb))
|
|
||||||
self.drop = nn.Dropout(p_drop_emb)
|
self.drop = nn.Dropout(p_drop_emb)
|
||||||
|
|
||||||
self.time_emb = SinusoidalPosEmb(n_emb)
|
self.time_emb = SinusoidalPosEmb(n_emb)
|
||||||
self.cond_obs_emb = None
|
self.cond_obs_emb = nn.Linear(cond_dim, n_emb) if obs_as_cond else None
|
||||||
if obs_as_cond:
|
self.time_token_proj = None
|
||||||
self.cond_obs_emb = nn.Linear(cond_dim, n_emb)
|
|
||||||
|
|
||||||
self.cond_pos_emb = None
|
self.cond_pos_emb = None
|
||||||
|
self.pos_emb = None
|
||||||
self.encoder = None
|
self.encoder = None
|
||||||
self.decoder = None
|
self.decoder = None
|
||||||
|
self.attnres_backbone = None
|
||||||
encoder_only = False
|
encoder_only = False
|
||||||
if T_cond > 0:
|
|
||||||
self.cond_pos_emb = nn.Parameter(torch.zeros(1, T_cond, n_emb))
|
if backbone_type == 'attnres_full':
|
||||||
if n_cond_layers > 0:
|
if not time_as_cond:
|
||||||
|
raise ValueError('attnres_full backbone requires time_as_cond=True.')
|
||||||
|
if n_cond_layers != 0:
|
||||||
|
raise ValueError('attnres_full backbone does not support n_cond_layers > 0.')
|
||||||
|
|
||||||
|
self.time_token_proj = nn.Linear(n_emb, n_emb)
|
||||||
|
self.attnres_backbone = AttnResTransformerBackbone(
|
||||||
|
d_model=n_emb,
|
||||||
|
n_blocks=n_layer,
|
||||||
|
n_heads=n_head,
|
||||||
|
n_kv_heads=n_kv_head,
|
||||||
|
max_seq_len=T + T_cond,
|
||||||
|
dropout=p_drop_attn,
|
||||||
|
ffn_mult=attn_res_ffn_mult,
|
||||||
|
eps=attn_res_eps,
|
||||||
|
rope_theta=attn_res_rope_theta,
|
||||||
|
causal_attn=causal_attn,
|
||||||
|
)
|
||||||
|
self.ln_f = RMSNorm(n_emb, eps=attn_res_eps)
|
||||||
|
else:
|
||||||
|
self.pos_emb = nn.Parameter(torch.zeros(1, T, n_emb))
|
||||||
|
if T_cond > 0:
|
||||||
|
self.cond_pos_emb = nn.Parameter(torch.zeros(1, T_cond, n_emb))
|
||||||
|
if n_cond_layers > 0:
|
||||||
|
encoder_layer = nn.TransformerEncoderLayer(
|
||||||
|
d_model=n_emb,
|
||||||
|
nhead=n_head,
|
||||||
|
dim_feedforward=4 * n_emb,
|
||||||
|
dropout=p_drop_attn,
|
||||||
|
activation='gelu',
|
||||||
|
batch_first=True,
|
||||||
|
norm_first=True,
|
||||||
|
)
|
||||||
|
self.encoder = nn.TransformerEncoder(
|
||||||
|
encoder_layer=encoder_layer,
|
||||||
|
num_layers=n_cond_layers,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.encoder = nn.Sequential(
|
||||||
|
nn.Linear(n_emb, 4 * n_emb),
|
||||||
|
nn.Mish(),
|
||||||
|
nn.Linear(4 * n_emb, n_emb),
|
||||||
|
)
|
||||||
|
|
||||||
|
decoder_layer = nn.TransformerDecoderLayer(
|
||||||
|
d_model=n_emb,
|
||||||
|
nhead=n_head,
|
||||||
|
dim_feedforward=4 * n_emb,
|
||||||
|
dropout=p_drop_attn,
|
||||||
|
activation='gelu',
|
||||||
|
batch_first=True,
|
||||||
|
norm_first=True,
|
||||||
|
)
|
||||||
|
self.decoder = nn.TransformerDecoder(
|
||||||
|
decoder_layer=decoder_layer,
|
||||||
|
num_layers=n_layer,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
encoder_only = True
|
||||||
encoder_layer = nn.TransformerEncoderLayer(
|
encoder_layer = nn.TransformerEncoderLayer(
|
||||||
d_model=n_emb,
|
d_model=n_emb,
|
||||||
nhead=n_head,
|
nhead=n_head,
|
||||||
@@ -72,45 +143,12 @@ class IMFTransformerForDiffusion(ModuleAttrMixin):
|
|||||||
)
|
)
|
||||||
self.encoder = nn.TransformerEncoder(
|
self.encoder = nn.TransformerEncoder(
|
||||||
encoder_layer=encoder_layer,
|
encoder_layer=encoder_layer,
|
||||||
num_layers=n_cond_layers,
|
num_layers=n_layer,
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.encoder = nn.Sequential(
|
|
||||||
nn.Linear(n_emb, 4 * n_emb),
|
|
||||||
nn.Mish(),
|
|
||||||
nn.Linear(4 * n_emb, n_emb),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
decoder_layer = nn.TransformerDecoderLayer(
|
self.ln_f = nn.LayerNorm(n_emb)
|
||||||
d_model=n_emb,
|
|
||||||
nhead=n_head,
|
|
||||||
dim_feedforward=4 * n_emb,
|
|
||||||
dropout=p_drop_attn,
|
|
||||||
activation='gelu',
|
|
||||||
batch_first=True,
|
|
||||||
norm_first=True,
|
|
||||||
)
|
|
||||||
self.decoder = nn.TransformerDecoder(
|
|
||||||
decoder_layer=decoder_layer,
|
|
||||||
num_layers=n_layer,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
encoder_only = True
|
|
||||||
encoder_layer = nn.TransformerEncoderLayer(
|
|
||||||
d_model=n_emb,
|
|
||||||
nhead=n_head,
|
|
||||||
dim_feedforward=4 * n_emb,
|
|
||||||
dropout=p_drop_attn,
|
|
||||||
activation='gelu',
|
|
||||||
batch_first=True,
|
|
||||||
norm_first=True,
|
|
||||||
)
|
|
||||||
self.encoder = nn.TransformerEncoder(
|
|
||||||
encoder_layer=encoder_layer,
|
|
||||||
num_layers=n_layer,
|
|
||||||
)
|
|
||||||
|
|
||||||
if causal_attn:
|
if causal_attn and backbone_type != 'attnres_full':
|
||||||
sz = T
|
sz = T
|
||||||
mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
|
mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
|
||||||
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
|
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
|
||||||
@@ -132,7 +170,6 @@ class IMFTransformerForDiffusion(ModuleAttrMixin):
|
|||||||
self.mask = None
|
self.mask = None
|
||||||
self.memory_mask = None
|
self.memory_mask = None
|
||||||
|
|
||||||
self.ln_f = nn.LayerNorm(n_emb)
|
|
||||||
self.head = nn.Linear(n_emb, output_dim)
|
self.head = nn.Linear(n_emb, output_dim)
|
||||||
|
|
||||||
self.T = T
|
self.T = T
|
||||||
@@ -159,6 +196,11 @@ class IMFTransformerForDiffusion(ModuleAttrMixin):
|
|||||||
nn.ModuleList,
|
nn.ModuleList,
|
||||||
nn.Mish,
|
nn.Mish,
|
||||||
nn.Sequential,
|
nn.Sequential,
|
||||||
|
AttnResTransformerBackbone,
|
||||||
|
AttnResSubLayer,
|
||||||
|
GroupedQuerySelfAttention,
|
||||||
|
SwiGLUFFN,
|
||||||
|
RMSNormNoWeight,
|
||||||
)
|
)
|
||||||
if isinstance(module, (nn.Linear, nn.Embedding)):
|
if isinstance(module, (nn.Linear, nn.Embedding)):
|
||||||
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
||||||
@@ -176,12 +218,16 @@ class IMFTransformerForDiffusion(ModuleAttrMixin):
|
|||||||
bias = getattr(module, name)
|
bias = getattr(module, name)
|
||||||
if bias is not None:
|
if bias is not None:
|
||||||
torch.nn.init.zeros_(bias)
|
torch.nn.init.zeros_(bias)
|
||||||
elif isinstance(module, nn.LayerNorm):
|
elif isinstance(module, (nn.LayerNorm, RMSNorm)):
|
||||||
torch.nn.init.zeros_(module.bias)
|
if getattr(module, 'bias', None) is not None:
|
||||||
|
torch.nn.init.zeros_(module.bias)
|
||||||
torch.nn.init.ones_(module.weight)
|
torch.nn.init.ones_(module.weight)
|
||||||
|
elif isinstance(module, AttnResOperator):
|
||||||
|
torch.nn.init.zeros_(module.pseudo_query)
|
||||||
elif isinstance(module, IMFTransformerForDiffusion):
|
elif isinstance(module, IMFTransformerForDiffusion):
|
||||||
torch.nn.init.normal_(module.pos_emb, mean=0.0, std=0.02)
|
if module.pos_emb is not None:
|
||||||
if module.cond_obs_emb is not None:
|
torch.nn.init.normal_(module.pos_emb, mean=0.0, std=0.02)
|
||||||
|
if module.cond_pos_emb is not None:
|
||||||
torch.nn.init.normal_(module.cond_pos_emb, mean=0.0, std=0.02)
|
torch.nn.init.normal_(module.cond_pos_emb, mean=0.0, std=0.02)
|
||||||
elif isinstance(module, ignore_types):
|
elif isinstance(module, ignore_types):
|
||||||
pass
|
pass
|
||||||
@@ -192,21 +238,24 @@ class IMFTransformerForDiffusion(ModuleAttrMixin):
|
|||||||
decay = set()
|
decay = set()
|
||||||
no_decay = set()
|
no_decay = set()
|
||||||
whitelist_weight_modules = (torch.nn.Linear, torch.nn.MultiheadAttention)
|
whitelist_weight_modules = (torch.nn.Linear, torch.nn.MultiheadAttention)
|
||||||
blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
|
blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding, RMSNorm)
|
||||||
for mn, m in self.named_modules():
|
for mn, m in self.named_modules():
|
||||||
for pn, _ in m.named_parameters():
|
for pn, _ in m.named_parameters(recurse=False):
|
||||||
fpn = '%s.%s' % (mn, pn) if mn else pn
|
fpn = '%s.%s' % (mn, pn) if mn else pn
|
||||||
|
|
||||||
if pn.endswith('bias'):
|
if pn.endswith('bias'):
|
||||||
no_decay.add(fpn)
|
no_decay.add(fpn)
|
||||||
elif pn.startswith('bias'):
|
elif pn.startswith('bias'):
|
||||||
no_decay.add(fpn)
|
no_decay.add(fpn)
|
||||||
|
elif pn == 'pseudo_query':
|
||||||
|
no_decay.add(fpn)
|
||||||
elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
|
elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
|
||||||
decay.add(fpn)
|
decay.add(fpn)
|
||||||
elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
|
elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
|
||||||
no_decay.add(fpn)
|
no_decay.add(fpn)
|
||||||
|
|
||||||
no_decay.add('pos_emb')
|
if self.pos_emb is not None:
|
||||||
|
no_decay.add('pos_emb')
|
||||||
no_decay.add('_dummy_variable')
|
no_decay.add('_dummy_variable')
|
||||||
if self.cond_pos_emb is not None:
|
if self.cond_pos_emb is not None:
|
||||||
no_decay.add('cond_pos_emb')
|
no_decay.add('cond_pos_emb')
|
||||||
@@ -250,18 +299,38 @@ class IMFTransformerForDiffusion(ModuleAttrMixin):
|
|||||||
value = value.to(device=sample.device, dtype=sample.dtype)
|
value = value.to(device=sample.device, dtype=sample.dtype)
|
||||||
return value.expand(sample.shape[0])
|
return value.expand(sample.shape[0])
|
||||||
|
|
||||||
def forward(
|
def _forward_attnres_full(
|
||||||
self,
|
self,
|
||||||
sample: torch.Tensor,
|
sample: torch.Tensor,
|
||||||
r: Union[torch.Tensor, float, int],
|
r: torch.Tensor,
|
||||||
t: Union[torch.Tensor, float, int],
|
t: torch.Tensor,
|
||||||
cond: Optional[torch.Tensor] = None,
|
cond: Optional[torch.Tensor] = None,
|
||||||
):
|
) -> torch.Tensor:
|
||||||
r = self._prepare_time_input(r, sample)
|
sample_tokens = self.input_emb(sample)
|
||||||
t = self._prepare_time_input(t, sample)
|
token_parts = [
|
||||||
|
self.time_token_proj(self.time_emb(r)).unsqueeze(1),
|
||||||
|
self.time_token_proj(self.time_emb(t)).unsqueeze(1),
|
||||||
|
]
|
||||||
|
if self.obs_as_cond:
|
||||||
|
if cond is None:
|
||||||
|
raise ValueError('cond is required when obs_as_cond=True for attnres_full backbone.')
|
||||||
|
token_parts.append(self.cond_obs_emb(cond))
|
||||||
|
token_parts.append(sample_tokens)
|
||||||
|
x = torch.cat(token_parts, dim=1)
|
||||||
|
x = self.drop(x)
|
||||||
|
x = self.attnres_backbone(x)
|
||||||
|
x = x[:, -sample_tokens.shape[1] :, :]
|
||||||
|
return x
|
||||||
|
|
||||||
|
def _forward_vanilla(
|
||||||
|
self,
|
||||||
|
sample: torch.Tensor,
|
||||||
|
r: torch.Tensor,
|
||||||
|
t: torch.Tensor,
|
||||||
|
cond: Optional[torch.Tensor] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
r_emb = self.time_emb(r).unsqueeze(1)
|
r_emb = self.time_emb(r).unsqueeze(1)
|
||||||
t_emb = self.time_emb(t).unsqueeze(1)
|
t_emb = self.time_emb(t).unsqueeze(1)
|
||||||
|
|
||||||
input_emb = self.input_emb(sample)
|
input_emb = self.input_emb(sample)
|
||||||
|
|
||||||
if self.encoder_only:
|
if self.encoder_only:
|
||||||
@@ -292,6 +361,22 @@ class IMFTransformerForDiffusion(ModuleAttrMixin):
|
|||||||
tgt_mask=self.mask,
|
tgt_mask=self.mask,
|
||||||
memory_mask=self.memory_mask,
|
memory_mask=self.memory_mask,
|
||||||
)
|
)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
sample: torch.Tensor,
|
||||||
|
r: Union[torch.Tensor, float, int],
|
||||||
|
t: Union[torch.Tensor, float, int],
|
||||||
|
cond: Optional[torch.Tensor] = None,
|
||||||
|
):
|
||||||
|
r = self._prepare_time_input(r, sample)
|
||||||
|
t = self._prepare_time_input(t, sample)
|
||||||
|
|
||||||
|
if self.backbone_type == 'attnres_full':
|
||||||
|
x = self._forward_attnres_full(sample, r, t, cond=cond)
|
||||||
|
else:
|
||||||
|
x = self._forward_vanilla(sample, r, t, cond=cond)
|
||||||
|
|
||||||
x = self.ln_f(x)
|
x = self.ln_f(x)
|
||||||
x = self.head(x)
|
x = self.head(x)
|
||||||
|
|||||||
@@ -39,6 +39,11 @@ class IMFTransformerHybridImagePolicy(DiffusionTransformerHybridImagePolicy):
|
|||||||
time_as_cond=True,
|
time_as_cond=True,
|
||||||
obs_as_cond=True,
|
obs_as_cond=True,
|
||||||
pred_action_steps_only=False,
|
pred_action_steps_only=False,
|
||||||
|
backbone_type='vanilla',
|
||||||
|
n_kv_head=1,
|
||||||
|
attn_res_ffn_mult=2.667,
|
||||||
|
attn_res_eps=1e-6,
|
||||||
|
attn_res_rope_theta=10000.0,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
if num_inference_steps is None:
|
if num_inference_steps is None:
|
||||||
@@ -90,6 +95,11 @@ class IMFTransformerHybridImagePolicy(DiffusionTransformerHybridImagePolicy):
|
|||||||
time_as_cond=time_as_cond,
|
time_as_cond=time_as_cond,
|
||||||
obs_as_cond=obs_as_cond,
|
obs_as_cond=obs_as_cond,
|
||||||
n_cond_layers=n_cond_layers,
|
n_cond_layers=n_cond_layers,
|
||||||
|
backbone_type=backbone_type,
|
||||||
|
n_kv_head=n_kv_head,
|
||||||
|
attn_res_ffn_mult=attn_res_ffn_mult,
|
||||||
|
attn_res_eps=attn_res_eps,
|
||||||
|
attn_res_rope_theta=attn_res_rope_theta,
|
||||||
)
|
)
|
||||||
self.num_inference_steps = 1
|
self.num_inference_steps = 1
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,57 @@
|
|||||||
|
# PushT Image iMF AttnRes Implementation Plan
|
||||||
|
|
||||||
|
> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking.
|
||||||
|
|
||||||
|
**Goal:** Add an AttnRes-backed full-attention iMF backbone for the PushT image experiment path, verify it with tests/smoke runs, then launch the 9-run 350-epoch architecture sweep across the local 5090 and remote 5880 GPUs.
|
||||||
|
|
||||||
|
**Architecture:** Extend `IMFTransformerForDiffusion` with a selectable `attnres_full` backbone that keeps the current iMF training/inference API unchanged while replacing the transformer internals with RMSNorm + RoPE self-attention + SwiGLU + Full AttnRes depth-wise residual routing. Add one standalone Hydra config for the PushT image sweep and reuse queue-style launch scripts with unique SwanLab names.
|
||||||
|
|
||||||
|
**Tech Stack:** Python 3.9 via uv, PyTorch 2.8 CUDA, Hydra, SwanLab online logging, local shell + SSH to trusted 5880 host.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 1: Add regression tests for the new AttnRes path
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Modify: `tests/test_imf_transformer_for_diffusion.py`
|
||||||
|
- Modify: `tests/test_pusht_swanlab_config.py`
|
||||||
|
|
||||||
|
- [ ] Add a failing model test that instantiates `IMFTransformerForDiffusion(backbone_type='attnres_full', causal_attn=False, ...)`, runs a forward pass with conditional observations, and asserts output shape plus optimizer construction.
|
||||||
|
- [ ] Run the targeted pytest selection and confirm the new test fails for the expected missing-backbone reason.
|
||||||
|
- [ ] Add a failing config regression test for `image_pusht_diffusion_policy_dit_imf_attnres_full.yaml` asserting SwanLab naming fields and `policy.causal_attn == False`.
|
||||||
|
- [ ] Re-run the targeted pytest selection and confirm the config test fails before implementation.
|
||||||
|
|
||||||
|
### Task 2: Implement the AttnRes-backed iMF backbone
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Create: `diffusion_policy/model/diffusion/attnres_transformer_components.py`
|
||||||
|
- Modify: `diffusion_policy/model/diffusion/imf_transformer_for_diffusion.py`
|
||||||
|
|
||||||
|
- [ ] Add focused reusable modules for `RMSNorm`, RoPE helpers, grouped-query self-attention, SwiGLU FFN, and the Full AttnRes operator.
|
||||||
|
- [ ] Extend `IMFTransformerForDiffusion` with a `backbone_type` switch that preserves the existing vanilla path and adds an `attnres_full` path using concatenated `[r, t, obs, sample]` tokens.
|
||||||
|
- [ ] Ensure the AttnRes path slices condition tokens away before the output head so the returned tensor still matches the sample/action horizon.
|
||||||
|
- [ ] Update optimizer parameter grouping to treat RMSNorm weights like LayerNorm weights (no decay) and include any new positional/conditioning parameters.
|
||||||
|
- [ ] Run the targeted tests and get them green.
|
||||||
|
|
||||||
|
### Task 3: Add the new PushT config and smoke-test path
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Create: `image_pusht_diffusion_policy_dit_imf_attnres_full.yaml`
|
||||||
|
- Modify: `tests/test_pusht_swanlab_config.py`
|
||||||
|
|
||||||
|
- [ ] Add a standalone PushT image config for the AttnRes iMF variant with SwanLab online logging, `policy.backbone_type=attnres_full`, and `policy.causal_attn=false`.
|
||||||
|
- [ ] Verify `uv run python train.py --config-dir=. --config-name=image_pusht_diffusion_policy_dit_imf_attnres_full.yaml --help` succeeds.
|
||||||
|
- [ ] Run a real smoke training command with `training.debug=true`, `training.device=cuda:0`, safety overrides (`dataloader.num_workers=0`, `task.env_runner.n_envs=1`, no vis), and confirm it reaches the training loop and writes a run directory.
|
||||||
|
|
||||||
|
### Task 4: Prepare launch scripts and start the 9-run sweep
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Create or modify: `data/run_logs/imf_attnres_local_queue.sh`
|
||||||
|
- Create or modify locally before copy: `data/run_logs/imf_attnres_remote_gpu0_queue.sh`
|
||||||
|
- Create or modify locally before copy: `data/run_logs/imf_attnres_remote_gpu1_queue.sh`
|
||||||
|
|
||||||
|
- [ ] Write queue command templates for the 9 runs using config `image_pusht_diffusion_policy_dit_imf_attnres_full.yaml`, `training.num_epochs=350`, unique `exp_name/logging.name`, and shared `logging.group=imf_pusht_attnres_arch_sweep`.
|
||||||
|
- [ ] Sync the necessary config/model files plus remote queue scripts to `droid@100.73.14.65:~/project/diffusion_policy-smoke`.
|
||||||
|
- [ ] Start the local queue under `nohup`, record PID, and verify the first run log is advancing.
|
||||||
|
- [ ] Start the two remote queues under `nohup`, record PIDs, and verify both first-run logs are advancing.
|
||||||
|
- [ ] Confirm all three GPUs have officially entered training for the new sweep.
|
||||||
108
docs/superpowers/specs/2026-03-29-pusht-imf-attnres-design.md
Normal file
108
docs/superpowers/specs/2026-03-29-pusht-imf-attnres-design.md
Normal file
@@ -0,0 +1,108 @@
|
|||||||
|
# PushT Image iMF AttnRes Design
|
||||||
|
|
||||||
|
## Goal
|
||||||
|
在现有 PushT 图像 iMF full-attention 路线之上,引入 `attn_res` 仓库中的 **Full AttnRes** 残差聚合形式,并同步使用与其匹配的 **RMSNorm + 自注意力 + SwiGLU FFN** 模块,保持 iMF 训练目标与一步推理语义不变,仅作用于本次实验链路。实现完成并验证后,启动与此前相同的 9 组 `n_emb × n_layer` 扫描(350 epochs, seed=42, SwanLab online, 无视频记录)。
|
||||||
|
|
||||||
|
## Scope
|
||||||
|
本次工作仅覆盖:
|
||||||
|
1. 为 `IMFTransformerForDiffusion` 增加一个 AttnRes-backed backbone 变体;
|
||||||
|
2. 保持 `forward(sample, r, t, cond=None)`、iMF loss、一步推理策略接口不变;
|
||||||
|
3. 新增独立 PushT 图像配置用于该变体;
|
||||||
|
4. 复用本地 5090 + 远端 5880 双卡三路并行调度 9 组实验。
|
||||||
|
|
||||||
|
不在范围内:
|
||||||
|
- 不替换已有 vanilla iMF/full-attn 配置;
|
||||||
|
- 不修改 DiT baseline;
|
||||||
|
- 不增加视频日志;
|
||||||
|
- 不扩大到多 seed。
|
||||||
|
|
||||||
|
## Recommended Approach
|
||||||
|
采用“**在当前 iMF 模型内增加可选 AttnRes backbone**”的方式,而不是新建独立 policy 链路。
|
||||||
|
|
||||||
|
理由:
|
||||||
|
- policy / workspace / loss / sampling 路径已经被验证,保留这些路径可最大程度缩小变动面;
|
||||||
|
- 仅在模型内部切换 backbone,可以让新实验与既有 iMF 结果保持可比;
|
||||||
|
- 配置上只需显式打开 `backbone_type=attnres_full`、`causal_attn=false` 等开关,复现实验更直接。
|
||||||
|
|
||||||
|
## Architecture
|
||||||
|
### 1. Backbone split
|
||||||
|
`IMFTransformerForDiffusion` 保留现有 vanilla encoder/decoder 实现为默认路径,并新增 `attnres_full` 路径:
|
||||||
|
- **vanilla**:保持当前实现不变;
|
||||||
|
- **attnres_full**:使用单栈式全注意力 Transformer,输入 token 序列为
|
||||||
|
`[r token, t token, obs cond tokens..., action/sample tokens...]`。
|
||||||
|
|
||||||
|
模型只对末尾的 action/sample token 位置输出 `u` 预测,前置条件 token 仅参与上下文建模。
|
||||||
|
|
||||||
|
### 2. AttnRes stack
|
||||||
|
新 backbone 使用以下模块:
|
||||||
|
- `RMSNorm`
|
||||||
|
- `Rotary Position Embedding`(用于自注意力 q/k)
|
||||||
|
- `GroupedQueryAttention`(本实验默认 `n_kv_head=1`,与单头配置兼容)
|
||||||
|
- `SwiGLU` FFN
|
||||||
|
- `AttnResOperator`(每个子层一个 pseudo-query,执行 full depth-wise residual aggregation)
|
||||||
|
|
||||||
|
每个 transformer block 由两个子层组成:
|
||||||
|
1. self-attention 子层
|
||||||
|
2. FFN 子层
|
||||||
|
|
||||||
|
每个子层的输入不再是简单 `x + f(x)`,而是从 embedding 与全部历史子层输出中通过 Full AttnRes 聚合得到 `h_l`,再执行 `RMSNorm(h_l) -> sublayer_fn(...)`。
|
||||||
|
|
||||||
|
### 3. Conditioning and token flow
|
||||||
|
- `sample` 先经 `input_emb` 映射为 action tokens;
|
||||||
|
- `r` 和 `t` 各自经 `SinusoidalPosEmb + linear` 映射为两个条件 token;
|
||||||
|
- 图像观测编码后的 `cond` 通过 `cond_obs_emb` 映射为 obs tokens;
|
||||||
|
- 拼接后的完整 token 序列进入 AttnRes stack;
|
||||||
|
- 输出时切掉前置条件 token,仅保留 action/sample token 段,随后经 `RMSNorm + head` 得到最终 `u`。
|
||||||
|
|
||||||
|
### 4. Attention mode
|
||||||
|
本次实验链路固定为 **non-causal full attention**:
|
||||||
|
- `causal_attn=false`
|
||||||
|
- 不构造 causal mask
|
||||||
|
- 所有 token 可彼此双向可见
|
||||||
|
|
||||||
|
这与用户指定的“训练过程仍然使用全注意力(不加因果注意)”一致。
|
||||||
|
|
||||||
|
## Config and Logging
|
||||||
|
新增独立配置文件,例如:
|
||||||
|
- `image_pusht_diffusion_policy_dit_imf_attnres_full.yaml`
|
||||||
|
|
||||||
|
该配置需要:
|
||||||
|
- 指向现有 `IMFTransformerHybridImagePolicy`
|
||||||
|
- 显式开启 AttnRes backbone 相关参数
|
||||||
|
- 设置 `policy.causal_attn=false`
|
||||||
|
- 保持 `logging.backend=swanlab`、`logging.mode=online`
|
||||||
|
- 运行时通过覆盖保证:
|
||||||
|
- `logging.name=<unique_run_name>`
|
||||||
|
- `logging.group=imf_pusht_attnres_arch_sweep`
|
||||||
|
- `exp_name=<unique_run_name>`
|
||||||
|
- 保持 `task.env_runner.n_test_vis=0` 与 `n_train_vis=0`,仅记录标量
|
||||||
|
|
||||||
|
## Experiment Matrix
|
||||||
|
固定 9 组:
|
||||||
|
- `n_emb ∈ {128, 256, 384}`
|
||||||
|
- `n_layer ∈ {6, 12, 18}`
|
||||||
|
- `seed=42`
|
||||||
|
- `training.num_epochs=350`
|
||||||
|
|
||||||
|
## Scheduling
|
||||||
|
沿用之前验证过的三队列分配:
|
||||||
|
- 本机 5090:`384x18`, `256x6`, `128x6`
|
||||||
|
- 5880 GPU0:`384x12`, `256x12`, `128x12`
|
||||||
|
- 5880 GPU1:`384x6`, `256x18`, `128x18`
|
||||||
|
|
||||||
|
每个 run name 编码 backbone 与结构,例如:
|
||||||
|
`imf_attnres_emb256_layer12_seed42_5880gpu0`
|
||||||
|
|
||||||
|
## Verification
|
||||||
|
实现阶段至少验证:
|
||||||
|
1. 新配置的 SwanLab 命名与 `causal_attn=false` 正确;
|
||||||
|
2. 新 backbone 的 forward shape 与 `configure_optimizers()` 可用;
|
||||||
|
3. 旧 vanilla 路径测试不回归;
|
||||||
|
4. `training.debug=true` smoke run 可以完整通过。
|
||||||
|
|
||||||
|
## Success Criteria
|
||||||
|
1. 新 AttnRes iMF 变体在本分支可训练、可一步推理;
|
||||||
|
2. 不影响已有 vanilla iMF/full-attn 链路;
|
||||||
|
3. 9 组实验成功在三张卡上正式启动;
|
||||||
|
4. SwanLab run 名称唯一,无冲突;
|
||||||
|
5. 不记录视频,仅记录标量。
|
||||||
35
image_pusht_diffusion_policy_dit_imf_attnres_full.yaml
Normal file
35
image_pusht_diffusion_policy_dit_imf_attnres_full.yaml
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
defaults:
|
||||||
|
- diffusion_policy/config/train_diffusion_transformer_hybrid_workspace@_here_
|
||||||
|
- override /diffusion_policy/config/task@task: pusht_image
|
||||||
|
- _self_
|
||||||
|
|
||||||
|
exp_name: pusht_image_dit_imf_attnres_full
|
||||||
|
|
||||||
|
policy:
|
||||||
|
_target_: diffusion_policy.policy.imf_transformer_hybrid_image_policy.IMFTransformerHybridImagePolicy
|
||||||
|
num_inference_steps: 1
|
||||||
|
n_head: 1
|
||||||
|
n_kv_head: 1
|
||||||
|
causal_attn: false
|
||||||
|
backbone_type: attnres_full
|
||||||
|
|
||||||
|
logging:
|
||||||
|
backend: swanlab
|
||||||
|
mode: online
|
||||||
|
name: ${exp_name}
|
||||||
|
resume: false
|
||||||
|
tags: ["${name}", "${task_name}", "${exp_name}", "swanlab"]
|
||||||
|
id: null
|
||||||
|
group: ${exp_name}
|
||||||
|
|
||||||
|
dataloader:
|
||||||
|
num_workers: 0
|
||||||
|
|
||||||
|
val_dataloader:
|
||||||
|
num_workers: 0
|
||||||
|
|
||||||
|
task:
|
||||||
|
env_runner:
|
||||||
|
n_envs: 1
|
||||||
|
n_test_vis: 0
|
||||||
|
n_train_vis: 0
|
||||||
29
scripts/pusht/imf_attnres_local_queue.sh
Executable file
29
scripts/pusht/imf_attnres_local_queue.sh
Executable file
@@ -0,0 +1,29 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
set -euo pipefail
|
||||||
|
cd /home/droid/project/diffusion_policy/.worktrees/feat-pusht-imf-attnres
|
||||||
|
export PYTHONUNBUFFERED=1
|
||||||
|
export SWANLAB_API_KEY='PSZrBMLx1XAjDjvmhUcNz'
|
||||||
|
export LD_LIBRARY_PATH="$(printf '%s:' .venv/lib/python3.9/site-packages/nvidia/*/lib | sed 's/:$//')"
|
||||||
|
run_exp() {
|
||||||
|
local name="$1" emb="$2" layer="$3"
|
||||||
|
echo "[$(date '+%F %T')] START $name emb=$emb layer=$layer"
|
||||||
|
.venv/bin/python train.py \
|
||||||
|
--config-dir=. \
|
||||||
|
--config-name=image_pusht_diffusion_policy_dit_imf_attnres_full.yaml \
|
||||||
|
training.device=cuda:0 \
|
||||||
|
training.num_epochs=350 \
|
||||||
|
training.resume=false \
|
||||||
|
exp_name="$name" \
|
||||||
|
logging.group=imf_pusht_attnres_arch_sweep \
|
||||||
|
logging.name="$name" \
|
||||||
|
logging.resume=false \
|
||||||
|
logging.id=null \
|
||||||
|
hydra.run.dir="data/outputs/$name" \
|
||||||
|
policy.n_emb="$emb" \
|
||||||
|
policy.n_layer="$layer" \
|
||||||
|
> "data/run_logs/${name}.log" 2>&1
|
||||||
|
echo "[$(date '+%F %T')] END $name"
|
||||||
|
}
|
||||||
|
run_exp imf_attnres_emb384_layer18_seed42_local 384 18
|
||||||
|
run_exp imf_attnres_emb256_layer6_seed42_local 256 6
|
||||||
|
run_exp imf_attnres_emb128_layer6_seed42_local 128 6
|
||||||
29
scripts/pusht/imf_attnres_remote_gpu0_queue.sh
Executable file
29
scripts/pusht/imf_attnres_remote_gpu0_queue.sh
Executable file
@@ -0,0 +1,29 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
set -euo pipefail
|
||||||
|
cd /home/droid/project/diffusion_policy-smoke
|
||||||
|
export PYTHONUNBUFFERED=1
|
||||||
|
export SWANLAB_API_KEY='PSZrBMLx1XAjDjvmhUcNz'
|
||||||
|
export LD_LIBRARY_PATH="$(printf '%s:' .venv/lib/python3.9/site-packages/nvidia/*/lib | sed 's/:$//')"
|
||||||
|
run_exp() {
|
||||||
|
local name="$1" emb="$2" layer="$3"
|
||||||
|
echo "[$(date '+%F %T')] START $name emb=$emb layer=$layer"
|
||||||
|
.venv/bin/python train.py \
|
||||||
|
--config-dir=. \
|
||||||
|
--config-name=image_pusht_diffusion_policy_dit_imf_attnres_full.yaml \
|
||||||
|
training.device=cuda:0 \
|
||||||
|
training.num_epochs=350 \
|
||||||
|
training.resume=false \
|
||||||
|
exp_name="$name" \
|
||||||
|
logging.group=imf_pusht_attnres_arch_sweep \
|
||||||
|
logging.name="$name" \
|
||||||
|
logging.resume=false \
|
||||||
|
logging.id=null \
|
||||||
|
hydra.run.dir="data/outputs/$name" \
|
||||||
|
policy.n_emb="$emb" \
|
||||||
|
policy.n_layer="$layer" \
|
||||||
|
> "data/run_logs/${name}.log" 2>&1
|
||||||
|
echo "[$(date '+%F %T')] END $name"
|
||||||
|
}
|
||||||
|
run_exp imf_attnres_emb384_layer12_seed42_5880gpu0 384 12
|
||||||
|
run_exp imf_attnres_emb256_layer12_seed42_5880gpu0 256 12
|
||||||
|
run_exp imf_attnres_emb128_layer12_seed42_5880gpu0 128 12
|
||||||
29
scripts/pusht/imf_attnres_remote_gpu1_queue.sh
Executable file
29
scripts/pusht/imf_attnres_remote_gpu1_queue.sh
Executable file
@@ -0,0 +1,29 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
set -euo pipefail
|
||||||
|
cd /home/droid/project/diffusion_policy-smoke
|
||||||
|
export PYTHONUNBUFFERED=1
|
||||||
|
export SWANLAB_API_KEY='PSZrBMLx1XAjDjvmhUcNz'
|
||||||
|
export LD_LIBRARY_PATH="$(printf '%s:' .venv/lib/python3.9/site-packages/nvidia/*/lib | sed 's/:$//')"
|
||||||
|
run_exp() {
|
||||||
|
local name="$1" emb="$2" layer="$3"
|
||||||
|
echo "[$(date '+%F %T')] START $name emb=$emb layer=$layer"
|
||||||
|
.venv/bin/python train.py \
|
||||||
|
--config-dir=. \
|
||||||
|
--config-name=image_pusht_diffusion_policy_dit_imf_attnres_full.yaml \
|
||||||
|
training.device=cuda:1 \
|
||||||
|
training.num_epochs=350 \
|
||||||
|
training.resume=false \
|
||||||
|
exp_name="$name" \
|
||||||
|
logging.group=imf_pusht_attnres_arch_sweep \
|
||||||
|
logging.name="$name" \
|
||||||
|
logging.resume=false \
|
||||||
|
logging.id=null \
|
||||||
|
hydra.run.dir="data/outputs/$name" \
|
||||||
|
policy.n_emb="$emb" \
|
||||||
|
policy.n_layer="$layer" \
|
||||||
|
> "data/run_logs/${name}.log" 2>&1
|
||||||
|
echo "[$(date '+%F %T')] END $name"
|
||||||
|
}
|
||||||
|
run_exp imf_attnres_emb384_layer6_seed42_5880gpu1 384 6
|
||||||
|
run_exp imf_attnres_emb256_layer18_seed42_5880gpu1 256 18
|
||||||
|
run_exp imf_attnres_emb128_layer18_seed42_5880gpu1 128 18
|
||||||
@@ -44,3 +44,34 @@ def test_imf_transformer_forward_signature_and_shape_single_head():
|
|||||||
pred_u = model(sample, r, t, cond=cond)
|
pred_u = model(sample, r, t, cond=cond)
|
||||||
|
|
||||||
assert pred_u.shape == sample.shape
|
assert pred_u.shape == sample.shape
|
||||||
|
|
||||||
|
|
||||||
|
def test_imf_transformer_attnres_full_backbone_forward_shape_and_optimizer():
|
||||||
|
model = IMFTransformerForDiffusion(
|
||||||
|
input_dim=3,
|
||||||
|
output_dim=3,
|
||||||
|
horizon=5,
|
||||||
|
n_obs_steps=2,
|
||||||
|
cond_dim=4,
|
||||||
|
n_layer=2,
|
||||||
|
n_head=1,
|
||||||
|
n_emb=16,
|
||||||
|
p_drop_emb=0.0,
|
||||||
|
p_drop_attn=0.0,
|
||||||
|
causal_attn=False,
|
||||||
|
time_as_cond=True,
|
||||||
|
obs_as_cond=True,
|
||||||
|
n_cond_layers=0,
|
||||||
|
backbone_type='attnres_full',
|
||||||
|
)
|
||||||
|
optimizer = model.configure_optimizers()
|
||||||
|
|
||||||
|
sample = torch.randn(2, 5, 3)
|
||||||
|
r = torch.rand(2)
|
||||||
|
t = torch.rand(2)
|
||||||
|
cond = torch.randn(2, 2, 4)
|
||||||
|
|
||||||
|
pred_u = model(sample, r, t, cond=cond)
|
||||||
|
|
||||||
|
assert pred_u.shape == sample.shape
|
||||||
|
assert optimizer is not None
|
||||||
|
|||||||
@@ -42,3 +42,16 @@ def test_image_pusht_dit_imf_fullattn_config_uses_exp_name_and_disables_causal_a
|
|||||||
assert cfg.logging.id is None
|
assert cfg.logging.id is None
|
||||||
assert cfg.logging.group == cfg.exp_name
|
assert cfg.logging.group == cfg.exp_name
|
||||||
assert cfg.policy.causal_attn is False
|
assert cfg.policy.causal_attn is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_image_pusht_dit_imf_attnres_full_config_uses_exp_name_and_disables_causal_attention():
|
||||||
|
cfg = _load_cfg('image_pusht_diffusion_policy_dit_imf_attnres_full.yaml')
|
||||||
|
|
||||||
|
assert cfg.logging.backend == 'swanlab'
|
||||||
|
assert cfg.logging.mode == 'online'
|
||||||
|
assert cfg.logging.name == cfg.exp_name
|
||||||
|
assert cfg.logging.resume is False
|
||||||
|
assert cfg.logging.id is None
|
||||||
|
assert cfg.logging.group == cfg.exp_name
|
||||||
|
assert cfg.policy.causal_attn is False
|
||||||
|
assert cfg.policy.backbone_type == 'attnres_full'
|
||||||
|
|||||||
Reference in New Issue
Block a user