"""Unit tests for IMFTransformerHybridImagePolicy.

Covers the one-step IMF sampling update, compound-velocity computation,
the JVP paths of _compute_u_and_du_dt (torch.func and autograd fallback),
and constructor validation.
"""
import pathlib
import sys

import pytest
import torch
import torch.nn as nn

ROOT_DIR = pathlib.Path(__file__).resolve().parents[1]
if str(ROOT_DIR) not in sys.path:
    sys.path.append(str(ROOT_DIR))

import diffusion_policy.policy.imf_transformer_hybrid_image_policy as policy_module  # noqa: E402
from diffusion_policy.model.common.module_attr_mixin import ModuleAttrMixin  # noqa: E402
from diffusion_policy.policy.imf_transformer_hybrid_image_policy import (  # noqa: E402
    IMFTransformerHybridImagePolicy,
)


class ConstantModel(nn.Module):
    """Stub velocity model that predicts a constant value everywhere."""

    def __init__(self, value):
        super().__init__()
        self.value = value

    def forward(self, sample, r, t, cond=None):
        return torch.full_like(sample, self.value)


class AffineModel(nn.Module):
    """Stub model whose output is affine in the sample and the times."""

    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.tensor(2.0))

    def forward(self, sample, r, t, cond=None):
        return sample * self.weight + (r + t).view(-1, 1, 1)


class SumMixModel(nn.Module):
    """Stub model that mixes feature dimensions, making conditioning leaks visible."""

    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.tensor(2.0))

    def forward(self, sample, r, t, cond=None):
        mixed = sample.sum(dim=-1, keepdim=True).expand_as(sample)
        return mixed * self.weight + t.view(-1, 1, 1)


class TrackingContext:
    """Context manager that records whether and how often it was entered."""

    def __init__(self):
        self.active = False
        self.enter_count = 0

    def __enter__(self):
        self.active = True
        self.enter_count += 1
        return self

    def __exit__(self, exc_type, exc, tb):
        self.active = False
        return False


def make_policy(model):
    # Bypass the full constructor; set only what the tested methods need.
    policy = IMFTransformerHybridImagePolicy.__new__(IMFTransformerHybridImagePolicy)
    ModuleAttrMixin.__init__(policy)
    policy.model = model
    return policy


def fake_parent_init(
    self,
    shape_meta,
    noise_scheduler,
    horizon,
    n_action_steps,
    n_obs_steps,
    num_inference_steps=None,
    crop_shape=(76, 76),
    obs_encoder_group_norm=False,
    eval_fixed_crop=False,
    n_layer=8,
    n_cond_layers=0,
    n_head=1,
    n_emb=256,
    p_drop_emb=0.0,
    p_drop_attn=0.3,
    causal_attn=True,
    time_as_cond=True,
    obs_as_cond=True,
    pred_action_steps_only=False,
    **kwargs,
):
    # Lightweight stand-in for DiffusionTransformerHybridImagePolicy.__init__.
    ModuleAttrMixin.__init__(self)
    self.action_dim = shape_meta['action']['shape'][0]
    self.obs_feature_dim = 4
    self.obs_as_cond = obs_as_cond
    self.pred_action_steps_only = pred_action_steps_only
    self.n_action_steps = n_action_steps
    self.n_obs_steps = n_obs_steps
    self.horizon = horizon


@pytest.fixture
def shape_meta():
    return {
        'action': {'shape': [2]},
        'obs': {},
    }


def test_sample_one_step_uses_imf_update_formula():
    policy = make_policy(ConstantModel(0.25))
    z_1 = torch.tensor([
        [[1.0, -1.0], [0.5, 0.0]],
        [[2.0, 3.0], [-2.0, 4.0]],
    ])
    r = torch.zeros(z_1.shape[0])
    t = torch.ones(z_1.shape[0])

    x_hat = policy._sample_one_step(z_1, r=r, t=t, cond=None)

    # IMF update: x_hat = z_1 - (t - r) * v_theta, with v_theta == 0.25 here.
    expected = z_1 - (t - r).view(-1, 1, 1) * 0.25
    assert torch.allclose(x_hat, expected)


def test_compound_velocity_uses_detached_du_dt_term():
    policy = make_policy(ConstantModel(0.0))
    u = torch.tensor([[[1.0], [2.0]]], requires_grad=True)
    du_dt = torch.tensor([[[3.0], [4.0]]], requires_grad=True)
    r = torch.tensor([0.2])
    t = torch.tensor([0.8])

    compound = policy._compound_velocity(u, du_dt, r, t)

    expected = u + (t - r).view(-1, 1, 1) * du_dt.detach()
    assert torch.allclose(compound, expected)

    # Gradients must flow through u but not through the detached du/dt term.
    compound.sum().backward()
    assert u.grad is not None
    assert du_dt.grad is None


def test_compute_u_and_du_dt_uses_math_sdpa_context_for_torch_func_jvp(monkeypatch):
    tracker = TrackingContext()

    def fake_jvp(fn, primals, tangents):
        assert tracker.active is True
        return fn(*primals), torch.zeros_like(primals[0])

    monkeypatch.setattr(policy_module, 'TORCH_FUNC_JVP', fake_jvp)
    policy = make_policy(ConstantModel(0.5))
    policy._jvp_math_sdp_context = lambda tensor: tracker

    z_t = torch.randn(2, 3, 4)
    r = torch.rand(2, requires_grad=True)
    t = torch.rand(2, requires_grad=True)
    v = torch.randn_like(z_t, requires_grad=True)

    policy._compute_u_and_du_dt(z_t, r, t, cond=None, v=v)

    assert tracker.enter_count == 1


def test_compute_u_and_du_dt_uses_math_sdpa_context_for_autograd_fallback(monkeypatch):
    tracker = TrackingContext()

    def fake_autograd_jvp(fn, primals, tangents, create_graph=False, strict=False):
        assert tracker.active is True
        return fn(*primals), torch.zeros_like(primals[0])

    monkeypatch.setattr(policy_module, 'TORCH_FUNC_JVP', None)
    monkeypatch.setattr(policy_module.torch.autograd.functional, 'jvp', fake_autograd_jvp)
    policy = make_policy(ConstantModel(0.5))
    policy._jvp_math_sdp_context = lambda tensor: tracker

    z_t = torch.randn(2, 3, 4)
    r = torch.rand(2, requires_grad=True)
    t = torch.rand(2, requires_grad=True)
    v = torch.randn_like(z_t, requires_grad=True)

    policy._compute_u_and_du_dt(z_t, r, t, cond=None, v=v)

    assert tracker.enter_count == 1


def test_compute_u_and_du_dt_uses_detached_v_zero_r_unit_t_and_reapplies_conditioning(monkeypatch):
    captured = {}

    def fake_jvp(fn, primals, tangents):
        captured['tangents'] = tangents
        captured['primal_output'] = fn(*primals)
        return captured['primal_output'], torch.zeros_like(primals[0])

    monkeypatch.setattr(policy_module, 'TORCH_FUNC_JVP', fake_jvp)
    policy = make_policy(SumMixModel())

    z_t = torch.tensor([[[1.0, 2.0, 3.0]]])
    r = torch.rand(1, requires_grad=True)
    t = torch.rand(1, requires_grad=True)
    v = torch.tensor([[[10.0, 20.0, 30.0]]], requires_grad=True)
    condition_mask = torch.tensor([[[False, True, False]]])
    condition_data = torch.tensor([[[0.0, 7.0, 0.0]]])

    policy._compute_u_and_du_dt(
        z_t, r, t,
        cond=None,
        v=v,
        condition_data=condition_data,
        condition_mask=condition_mask,
    )

    # Tangents should be (detached v, zeros for r, ones for t).
    tangent_v, tangent_r, tangent_t = captured['tangents']
    assert torch.equal(tangent_v, v.detach())
    assert tangent_v.requires_grad is False
    assert torch.equal(tangent_r, torch.zeros_like(r))
    assert torch.equal(tangent_t, torch.ones_like(t))

    # The primal input must have the conditioning re-applied before the model call.
    conditioned = z_t.clone()
    conditioned[condition_mask] = condition_data[condition_mask]
    expected_primal = policy.model(conditioned, r, t, cond=None)
    assert torch.allclose(captured['primal_output'], expected_primal)


def test_compute_u_and_du_dt_fallback_blocks_conditioned_tangent_leakage_and_keeps_primal_gradients(monkeypatch):
    monkeypatch.setattr(policy_module, 'TORCH_FUNC_JVP', None)
    policy = make_policy(SumMixModel())

    z_t = torch.tensor([[[1.0, 2.0, 3.0]]], requires_grad=True)
    r = torch.rand(1, requires_grad=True)
    t = torch.rand(1, requires_grad=True)
    v = torch.tensor([[[1.0, 10.0, 100.0]]], requires_grad=True)
    condition_mask = torch.tensor([[[False, True, False]]])
    condition_data = torch.tensor([[[0.0, 7.0, 0.0]]])

    u, du_dt = policy._compute_u_and_du_dt(
        z_t, r, t,
        cond=None,
        v=v,
        condition_data=condition_data,
        condition_mask=condition_mask,
    )

    conditioned = z_t.detach().clone()
    conditioned[condition_mask] = condition_data[condition_mask]
    expected_u = policy.model(conditioned, r, t, cond=None)
    # With the conditioned entry's tangent zeroed, the effective tangent sums
    # to 1 + 100 = 101, so du/dt = weight * 101 + 1 (the +1 from d(t)/dt).
    expected_du_dt_scalar = policy.model.weight.detach() * torch.tensor(101.0) + 1.0
    expected_du_dt = torch.full_like(z_t, expected_du_dt_scalar)

    assert u.shape == z_t.shape
    assert du_dt.shape == z_t.shape
    assert torch.allclose(u, expected_u)
    assert torch.allclose(du_dt, expected_du_dt)

    u.sum().backward()
    assert policy.model.weight.grad is not None
    assert torch.count_nonzero(policy.model.weight.grad) > 0


def test_init_uses_action_step_horizon_when_pred_action_steps_only(monkeypatch, shape_meta):
    monkeypatch.setattr(
        policy_module.DiffusionTransformerHybridImagePolicy,
        '__init__',
        fake_parent_init,
    )

    policy = IMFTransformerHybridImagePolicy(
        shape_meta=shape_meta,
        noise_scheduler=None,
        horizon=10,
        n_action_steps=4,
        n_obs_steps=2,
        num_inference_steps=1,
        n_layer=1,
        n_head=1,
        n_emb=16,
        p_drop_emb=0.0,
        p_drop_attn=0.0,
        causal_attn=True,
        obs_as_cond=True,
        pred_action_steps_only=True,
    )

    assert policy.model.horizon == 4
    assert policy.num_inference_steps == 1


def test_init_rejects_non_one_step_inference(monkeypatch, shape_meta):
    monkeypatch.setattr(
        policy_module.DiffusionTransformerHybridImagePolicy,
        '__init__',
        fake_parent_init,
    )

    with pytest.raises(ValueError, match='num_inference_steps'):
        IMFTransformerHybridImagePolicy(
            shape_meta=shape_meta,
            noise_scheduler=None,
            horizon=10,
            n_action_steps=4,
            n_obs_steps=2,
            num_inference_steps=2,
            n_layer=1,
            n_head=1,
            n_emb=16,
            p_drop_emb=0.0,
            p_drop_attn=0.0,
            causal_attn=True,
            obs_as_cond=True,
            pred_action_steps_only=False,
        )
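

# ---------------------------------------------------------------------------
# Reference sketch (an assumption reconstructed from the assertions above, not
# the production code): the update rules the first two tests pin down reduce to
# two small formulas. These hypothetical helpers restate them; the real methods
# are IMFTransformerHybridImagePolicy._sample_one_step and ._compound_velocity.
# ---------------------------------------------------------------------------


def _reference_sample_one_step(model, z_1, r, t, cond=None):
    # One-step IMF update asserted above: x_hat = z_1 - (t - r) * v_theta.
    v = model(z_1, r, t, cond=cond)
    return z_1 - (t - r).view(-1, 1, 1) * v


def _reference_compound_velocity(u, du_dt, r, t):
    # Compound velocity asserted above: u + (t - r) * du/dt, with du/dt
    # detached so gradients flow only through the u term.
    return u + (t - r).view(-1, 1, 1) * du_dt.detach()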