feat: add pusht imf attnres backbone
This commit is contained in:
@@ -44,3 +44,34 @@ def test_imf_transformer_forward_signature_and_shape_single_head():
|
||||
pred_u = model(sample, r, t, cond=cond)
|
||||
|
||||
assert pred_u.shape == sample.shape
|
||||
|
||||
|
||||
def test_imf_transformer_attnres_full_backbone_forward_shape_and_optimizer():
|
||||
model = IMFTransformerForDiffusion(
|
||||
input_dim=3,
|
||||
output_dim=3,
|
||||
horizon=5,
|
||||
n_obs_steps=2,
|
||||
cond_dim=4,
|
||||
n_layer=2,
|
||||
n_head=1,
|
||||
n_emb=16,
|
||||
p_drop_emb=0.0,
|
||||
p_drop_attn=0.0,
|
||||
causal_attn=False,
|
||||
time_as_cond=True,
|
||||
obs_as_cond=True,
|
||||
n_cond_layers=0,
|
||||
backbone_type='attnres_full',
|
||||
)
|
||||
optimizer = model.configure_optimizers()
|
||||
|
||||
sample = torch.randn(2, 5, 3)
|
||||
r = torch.rand(2)
|
||||
t = torch.rand(2)
|
||||
cond = torch.randn(2, 2, 4)
|
||||
|
||||
pred_u = model(sample, r, t, cond=cond)
|
||||
|
||||
assert pred_u.shape == sample.shape
|
||||
assert optimizer is not None
|
||||
|
||||
@@ -42,3 +42,16 @@ def test_image_pusht_dit_imf_fullattn_config_uses_exp_name_and_disables_causal_a
|
||||
assert cfg.logging.id is None
|
||||
assert cfg.logging.group == cfg.exp_name
|
||||
assert cfg.policy.causal_attn is False
|
||||
|
||||
|
||||
def test_image_pusht_dit_imf_attnres_full_config_uses_exp_name_and_disables_causal_attention():
|
||||
cfg = _load_cfg('image_pusht_diffusion_policy_dit_imf_attnres_full.yaml')
|
||||
|
||||
assert cfg.logging.backend == 'swanlab'
|
||||
assert cfg.logging.mode == 'online'
|
||||
assert cfg.logging.name == cfg.exp_name
|
||||
assert cfg.logging.resume is False
|
||||
assert cfg.logging.id is None
|
||||
assert cfg.logging.group == cfg.exp_name
|
||||
assert cfg.policy.causal_attn is False
|
||||
assert cfg.policy.backbone_type == 'attnres_full'
|
||||
|
||||
Reference in New Issue
Block a user