feat: add pusht imf attnres backbone

This commit is contained in:
Logic
2026-03-29 11:15:59 +08:00
parent 78ab18e8f3
commit 185ed6596c
8 changed files with 647 additions and 61 deletions

View File

@@ -44,3 +44,34 @@ def test_imf_transformer_forward_signature_and_shape_single_head():
pred_u = model(sample, r, t, cond=cond)
assert pred_u.shape == sample.shape
def test_imf_transformer_attnres_full_backbone_forward_shape_and_optimizer():
model = IMFTransformerForDiffusion(
input_dim=3,
output_dim=3,
horizon=5,
n_obs_steps=2,
cond_dim=4,
n_layer=2,
n_head=1,
n_emb=16,
p_drop_emb=0.0,
p_drop_attn=0.0,
causal_attn=False,
time_as_cond=True,
obs_as_cond=True,
n_cond_layers=0,
backbone_type='attnres_full',
)
optimizer = model.configure_optimizers()
sample = torch.randn(2, 5, 3)
r = torch.rand(2)
t = torch.rand(2)
cond = torch.randn(2, 2, 4)
pred_u = model(sample, r, t, cond=cond)
assert pred_u.shape == sample.shape
assert optimizer is not None

View File

@@ -42,3 +42,16 @@ def test_image_pusht_dit_imf_fullattn_config_uses_exp_name_and_disables_causal_a
assert cfg.logging.id is None
assert cfg.logging.group == cfg.exp_name
assert cfg.policy.causal_attn is False
def test_image_pusht_dit_imf_attnres_full_config_uses_exp_name_and_disables_causal_attention():
cfg = _load_cfg('image_pusht_diffusion_policy_dit_imf_attnres_full.yaml')
assert cfg.logging.backend == 'swanlab'
assert cfg.logging.mode == 'online'
assert cfg.logging.name == cfg.exp_name
assert cfg.logging.resume is False
assert cfg.logging.id is None
assert cfg.logging.group == cfg.exp_name
assert cfg.policy.causal_attn is False
assert cfg.policy.backbone_type == 'attnres_full'