"""Smoke test for ``IMFTransformerForDiffusion``: forward signature and output shape."""

import inspect
import pathlib
import sys

import torch

# Make the repository root (one directory above this file's directory) importable.
ROOT_DIR = pathlib.Path(__file__).resolve().parents[1]
if str(ROOT_DIR) not in sys.path:
    # Prepend rather than append so the checkout under test takes precedence
    # over any separately installed ``diffusion_policy`` distribution.
    sys.path.insert(0, str(ROOT_DIR))

from diffusion_policy.model.diffusion.imf_transformer_for_diffusion import (  # noqa: E402
    IMFTransformerForDiffusion,
)


def test_imf_transformer_forward_signature_and_shape_single_head():
    """Check ``forward``'s leading parameters and that its output matches ``sample``'s shape.

    Builds a tiny single-layer, single-head model with observation
    conditioning, calls ``forward(sample, r, t, cond=cond)`` and asserts the
    prediction has the same shape as ``sample``.
    """
    # forward must take (sample, r, t, cond) in this order, with cond optional.
    signature = inspect.signature(IMFTransformerForDiffusion.forward)
    assert list(signature.parameters)[:5] == ['self', 'sample', 'r', 't', 'cond']
    assert signature.parameters['cond'].default is None

    # Seed once so weight init and the random inputs below are reproducible
    # if this test ever fails.
    torch.manual_seed(0)

    batch_size = 2
    horizon = 5
    action_dim = 3
    n_obs_steps = 2
    cond_dim = 4

    model = IMFTransformerForDiffusion(
        input_dim=action_dim,
        output_dim=action_dim,
        horizon=horizon,
        n_obs_steps=n_obs_steps,
        cond_dim=cond_dim,
        n_layer=1,
        n_head=1,
        n_emb=16,
        p_drop_emb=0.0,
        p_drop_attn=0.0,
        causal_attn=True,
        time_as_cond=True,
        obs_as_cond=True,
        n_cond_layers=0,
    )
    # Exercised only to check it runs without raising; the return value is
    # not inspected here.
    model.configure_optimizers()

    sample = torch.randn(batch_size, horizon, action_dim)
    r = torch.rand(batch_size)
    t = torch.rand(batch_size)
    cond = torch.randn(batch_size, n_obs_steps, cond_dim)

    pred_u = model(sample, r, t, cond=cond)

    assert pred_u.shape == sample.shape