feat: add pusht image imf transformer

This commit is contained in:
Logic
2026-03-26 20:41:37 +08:00
parent 5e7ae6cfa5
commit 4cd5085b33
5 changed files with 960 additions and 0 deletions

View File

@@ -0,0 +1,46 @@
import inspect
import pathlib
import sys
import torch
ROOT_DIR = pathlib.Path(__file__).resolve().parents[1]
if str(ROOT_DIR) not in sys.path:
sys.path.append(str(ROOT_DIR))
from diffusion_policy.model.diffusion.imf_transformer_for_diffusion import ( # noqa: E402
IMFTransformerForDiffusion,
)
def test_imf_transformer_forward_signature_and_shape_single_head():
signature = inspect.signature(IMFTransformerForDiffusion.forward)
assert list(signature.parameters)[:5] == ['self', 'sample', 'r', 't', 'cond']
assert signature.parameters['cond'].default is None
model = IMFTransformerForDiffusion(
input_dim=3,
output_dim=3,
horizon=5,
n_obs_steps=2,
cond_dim=4,
n_layer=1,
n_head=1,
n_emb=16,
p_drop_emb=0.0,
p_drop_attn=0.0,
causal_attn=True,
time_as_cond=True,
obs_as_cond=True,
n_cond_layers=0,
)
model.configure_optimizers()
sample = torch.randn(2, 5, 3)
r = torch.rand(2)
t = torch.rand(2)
cond = torch.randn(2, 2, 4)
pred_u = model(sample, r, t, cond=cond)
assert pred_u.shape == sample.shape

View File

@@ -0,0 +1,313 @@
import pathlib
import sys
import pytest
import torch
import torch.nn as nn
ROOT_DIR = pathlib.Path(__file__).resolve().parents[1]
if str(ROOT_DIR) not in sys.path:
sys.path.append(str(ROOT_DIR))
import diffusion_policy.policy.imf_transformer_hybrid_image_policy as policy_module # noqa: E402
from diffusion_policy.model.common.module_attr_mixin import ModuleAttrMixin # noqa: E402
from diffusion_policy.policy.imf_transformer_hybrid_image_policy import ( # noqa: E402
IMFTransformerHybridImagePolicy,
)
class ConstantModel(nn.Module):
def __init__(self, value):
super().__init__()
self.value = value
def forward(self, sample, r, t, cond=None):
return torch.full_like(sample, self.value)
class AffineModel(nn.Module):
def __init__(self):
super().__init__()
self.weight = nn.Parameter(torch.tensor(2.0))
def forward(self, sample, r, t, cond=None):
return sample * self.weight + (r + t).view(-1, 1, 1)
class SumMixModel(nn.Module):
def __init__(self):
super().__init__()
self.weight = nn.Parameter(torch.tensor(2.0))
def forward(self, sample, r, t, cond=None):
mixed = sample.sum(dim=-1, keepdim=True).expand_as(sample)
return mixed * self.weight + t.view(-1, 1, 1)
class TrackingContext:
def __init__(self):
self.active = False
self.enter_count = 0
def __enter__(self):
self.active = True
self.enter_count += 1
return self
def __exit__(self, exc_type, exc, tb):
self.active = False
return False
def make_policy(model):
policy = IMFTransformerHybridImagePolicy.__new__(IMFTransformerHybridImagePolicy)
ModuleAttrMixin.__init__(policy)
policy.model = model
return policy
def fake_parent_init(
self,
shape_meta,
noise_scheduler,
horizon,
n_action_steps,
n_obs_steps,
num_inference_steps=None,
crop_shape=(76, 76),
obs_encoder_group_norm=False,
eval_fixed_crop=False,
n_layer=8,
n_cond_layers=0,
n_head=1,
n_emb=256,
p_drop_emb=0.0,
p_drop_attn=0.3,
causal_attn=True,
time_as_cond=True,
obs_as_cond=True,
pred_action_steps_only=False,
**kwargs,
):
ModuleAttrMixin.__init__(self)
self.action_dim = shape_meta['action']['shape'][0]
self.obs_feature_dim = 4
self.obs_as_cond = obs_as_cond
self.pred_action_steps_only = pred_action_steps_only
self.n_action_steps = n_action_steps
self.n_obs_steps = n_obs_steps
self.horizon = horizon
@pytest.fixture
def shape_meta():
return {
'action': {'shape': [2]},
'obs': {},
}
def test_sample_one_step_uses_imf_update_formula():
policy = make_policy(ConstantModel(0.25))
z_1 = torch.tensor([
[[1.0, -1.0], [0.5, 0.0]],
[[2.0, 3.0], [-2.0, 4.0]],
])
r = torch.zeros(z_1.shape[0])
t = torch.ones(z_1.shape[0])
x_hat = policy._sample_one_step(z_1, r=r, t=t, cond=None)
expected = z_1 - (t - r).view(-1, 1, 1) * 0.25
assert torch.allclose(x_hat, expected)
def test_compound_velocity_uses_detached_du_dt_term():
policy = make_policy(ConstantModel(0.0))
u = torch.tensor([[[1.0], [2.0]]], requires_grad=True)
du_dt = torch.tensor([[[3.0], [4.0]]], requires_grad=True)
r = torch.tensor([0.2])
t = torch.tensor([0.8])
compound = policy._compound_velocity(u, du_dt, r, t)
expected = u + (t - r).view(-1, 1, 1) * du_dt.detach()
assert torch.allclose(compound, expected)
compound.sum().backward()
assert u.grad is not None
assert du_dt.grad is None
def test_compute_u_and_du_dt_uses_math_sdpa_context_for_torch_func_jvp(monkeypatch):
tracker = TrackingContext()
def fake_jvp(fn, primals, tangents):
assert tracker.active is True
return fn(*primals), torch.zeros_like(primals[0])
monkeypatch.setattr(policy_module, 'TORCH_FUNC_JVP', fake_jvp)
policy = make_policy(ConstantModel(0.5))
policy._jvp_math_sdp_context = lambda tensor: tracker
z_t = torch.randn(2, 3, 4)
r = torch.rand(2, requires_grad=True)
t = torch.rand(2, requires_grad=True)
v = torch.randn_like(z_t, requires_grad=True)
policy._compute_u_and_du_dt(z_t, r, t, cond=None, v=v)
assert tracker.enter_count == 1
def test_compute_u_and_du_dt_uses_math_sdpa_context_for_autograd_fallback(monkeypatch):
tracker = TrackingContext()
def fake_autograd_jvp(fn, primals, tangents, create_graph=False, strict=False):
assert tracker.active is True
return fn(*primals), torch.zeros_like(primals[0])
monkeypatch.setattr(policy_module, 'TORCH_FUNC_JVP', None)
monkeypatch.setattr(policy_module.torch.autograd.functional, 'jvp', fake_autograd_jvp)
policy = make_policy(ConstantModel(0.5))
policy._jvp_math_sdp_context = lambda tensor: tracker
z_t = torch.randn(2, 3, 4)
r = torch.rand(2, requires_grad=True)
t = torch.rand(2, requires_grad=True)
v = torch.randn_like(z_t, requires_grad=True)
policy._compute_u_and_du_dt(z_t, r, t, cond=None, v=v)
assert tracker.enter_count == 1
def test_compute_u_and_du_dt_uses_detached_v_zero_r_unit_t_and_reapplies_conditioning(monkeypatch):
captured = {}
def fake_jvp(fn, primals, tangents):
captured['tangents'] = tangents
captured['primal_output'] = fn(*primals)
return captured['primal_output'], torch.zeros_like(primals[0])
monkeypatch.setattr(policy_module, 'TORCH_FUNC_JVP', fake_jvp)
policy = make_policy(SumMixModel())
z_t = torch.tensor([[[1.0, 2.0, 3.0]]])
r = torch.rand(1, requires_grad=True)
t = torch.rand(1, requires_grad=True)
v = torch.tensor([[[10.0, 20.0, 30.0]]], requires_grad=True)
condition_mask = torch.tensor([[[False, True, False]]])
condition_data = torch.tensor([[[0.0, 7.0, 0.0]]])
policy._compute_u_and_du_dt(
z_t,
r,
t,
cond=None,
v=v,
condition_data=condition_data,
condition_mask=condition_mask,
)
tangent_v, tangent_r, tangent_t = captured['tangents']
assert torch.equal(tangent_v, v.detach())
assert tangent_v.requires_grad is False
assert torch.equal(tangent_r, torch.zeros_like(r))
assert torch.equal(tangent_t, torch.ones_like(t))
conditioned = z_t.clone()
conditioned[condition_mask] = condition_data[condition_mask]
expected_primal = policy.model(conditioned, r, t, cond=None)
assert torch.allclose(captured['primal_output'], expected_primal)
def test_compute_u_and_du_dt_fallback_blocks_conditioned_tangent_leakage_and_keeps_primal_gradients(monkeypatch):
monkeypatch.setattr(policy_module, 'TORCH_FUNC_JVP', None)
policy = make_policy(SumMixModel())
z_t = torch.tensor([[[1.0, 2.0, 3.0]]], requires_grad=True)
r = torch.rand(1, requires_grad=True)
t = torch.rand(1, requires_grad=True)
v = torch.tensor([[[1.0, 10.0, 100.0]]], requires_grad=True)
condition_mask = torch.tensor([[[False, True, False]]])
condition_data = torch.tensor([[[0.0, 7.0, 0.0]]])
u, du_dt = policy._compute_u_and_du_dt(
z_t,
r,
t,
cond=None,
v=v,
condition_data=condition_data,
condition_mask=condition_mask,
)
conditioned = z_t.detach().clone()
conditioned[condition_mask] = condition_data[condition_mask]
expected_u = policy.model(conditioned, r, t, cond=None)
expected_du_dt_scalar = policy.model.weight.detach() * torch.tensor(101.0) + 1.0
expected_du_dt = torch.full_like(z_t, expected_du_dt_scalar)
assert u.shape == z_t.shape
assert du_dt.shape == z_t.shape
assert torch.allclose(u, expected_u)
assert torch.allclose(du_dt, expected_du_dt)
u.sum().backward()
assert policy.model.weight.grad is not None
assert torch.count_nonzero(policy.model.weight.grad) > 0
def test_init_uses_action_step_horizon_when_pred_action_steps_only(monkeypatch, shape_meta):
monkeypatch.setattr(
policy_module.DiffusionTransformerHybridImagePolicy,
'__init__',
fake_parent_init,
)
policy = IMFTransformerHybridImagePolicy(
shape_meta=shape_meta,
noise_scheduler=None,
horizon=10,
n_action_steps=4,
n_obs_steps=2,
num_inference_steps=1,
n_layer=1,
n_head=1,
n_emb=16,
p_drop_emb=0.0,
p_drop_attn=0.0,
causal_attn=True,
obs_as_cond=True,
pred_action_steps_only=True,
)
assert policy.model.horizon == 4
assert policy.num_inference_steps == 1
def test_init_rejects_non_one_step_inference(monkeypatch, shape_meta):
monkeypatch.setattr(
policy_module.DiffusionTransformerHybridImagePolicy,
'__init__',
fake_parent_init,
)
with pytest.raises(ValueError, match='num_inference_steps'):
IMFTransformerHybridImagePolicy(
shape_meta=shape_meta,
noise_scheduler=None,
horizon=10,
n_action_steps=4,
n_obs_steps=2,
num_inference_steps=2,
n_layer=1,
n_head=1,
n_emb=16,
p_drop_emb=0.0,
p_drop_attn=0.0,
causal_attn=True,
obs_as_cond=True,
pred_action_steps_only=False,
)