"""Tests for IMFTransformerForDiffusion: forward signature, output shapes,
and optimizer construction for both the default and 'attnres_full' backbones."""
import inspect
import pathlib
import sys

import torch

# Ensure the repository root is importable so the project package resolves
# when pytest is invoked from this tests directory.
ROOT_DIR = pathlib.Path(__file__).resolve().parents[1]
if str(ROOT_DIR) not in sys.path:
    sys.path.append(str(ROOT_DIR))

from diffusion_policy.model.diffusion.imf_transformer_for_diffusion import (  # noqa: E402
    IMFTransformerForDiffusion,
)

|
def test_imf_transformer_forward_signature_and_shape_single_head():
    """Check the forward signature is (self, sample, r, t, cond=None) and
    that a minimal single-head model maps a sample to an equally-shaped output.

    Shapes assumed from the constructor arguments below: sample is
    (batch, horizon, input_dim) and cond is (batch, n_obs_steps, cond_dim).
    """
    signature = inspect.signature(IMFTransformerForDiffusion.forward)
    assert list(signature.parameters)[:5] == ['self', 'sample', 'r', 't', 'cond']
    assert signature.parameters['cond'].default is None

    # Smallest configuration that still exercises the conditioned, causal path.
    model = IMFTransformerForDiffusion(
        input_dim=3,
        output_dim=3,
        horizon=5,
        n_obs_steps=2,
        cond_dim=4,
        n_layer=1,
        n_head=1,
        n_emb=16,
        p_drop_emb=0.0,
        p_drop_attn=0.0,
        causal_attn=True,
        time_as_cond=True,
        obs_as_cond=True,
        n_cond_layers=0,
    )
    # Consistency fix: assert the optimizer is actually built instead of
    # discarding the return value (matches the attnres_full test).
    optimizer = model.configure_optimizers()
    assert optimizer is not None

    sample = torch.randn(2, 5, 3)
    r = torch.rand(2)
    t = torch.rand(2)
    cond = torch.randn(2, 2, 4)

    pred_u = model(sample, r, t, cond=cond)

    assert pred_u.shape == sample.shape
|
|
|
|
|
|
def test_imf_transformer_attnres_full_backbone_forward_shape_and_optimizer():
    """The 'attnres_full' backbone variant constructs, yields an optimizer,
    and preserves the sample's (batch, horizon, dim) shape through forward."""
    config = dict(
        input_dim=3,
        output_dim=3,
        horizon=5,
        n_obs_steps=2,
        cond_dim=4,
        n_layer=2,
        n_head=1,
        n_emb=16,
        p_drop_emb=0.0,
        p_drop_attn=0.0,
        causal_attn=False,
        time_as_cond=True,
        obs_as_cond=True,
        n_cond_layers=0,
        backbone_type='attnres_full',
    )
    model = IMFTransformerForDiffusion(**config)
    optimizer = model.configure_optimizers()

    # Inputs sized to match the config above.
    batch = 2
    sample = torch.randn(batch, config['horizon'], config['input_dim'])
    cond = torch.randn(batch, config['n_obs_steps'], config['cond_dim'])
    r = torch.rand(batch)
    t = torch.rand(batch)

    pred_u = model(sample, r, t, cond=cond)

    assert pred_u.shape == sample.shape
    assert optimizer is not None
|