import pathlib
import sys

import pytest
import torch
import torch.nn as nn

ROOT_DIR = pathlib.Path(__file__).resolve().parents[1]
if str(ROOT_DIR) not in sys.path:
    sys.path.append(str(ROOT_DIR))

import diffusion_policy.policy.imf_transformer_hybrid_image_policy as policy_module  # noqa: E402
from diffusion_policy.model.common.module_attr_mixin import ModuleAttrMixin  # noqa: E402
from diffusion_policy.policy.imf_transformer_hybrid_image_policy import (  # noqa: E402
    IMFTransformerHybridImagePolicy,
)


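# Lightweight stand-ins for the policy's velocity-prediction network. Each one
# implements the model's (sample, r, t, cond) call signature with a closed-form
# output, so the tests below can check the surrounding IMF plumbing against
# hand-computed expectations.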
class ConstantModel(nn.Module):
    def __init__(self, value):
        super().__init__()
        self.value = value

    def forward(self, sample, r, t, cond=None):
        return torch.full_like(sample, self.value)


class AffineModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.tensor(2.0))

    def forward(self, sample, r, t, cond=None):
        return sample * self.weight + (r + t).view(-1, 1, 1)


class SumMixModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.tensor(2.0))

    def forward(self, sample, r, t, cond=None):
        # Mixing across the last dim makes the output sensitive to every input
        # entry, which the tangent-leakage test below relies on.
        mixed = sample.sum(dim=-1, keepdim=True).expand_as(sample)
        return mixed * self.weight + t.view(-1, 1, 1)


class TrackingContext:
    def __init__(self):
        self.active = False
        self.enter_count = 0

    def __enter__(self):
        self.active = True
        self.enter_count += 1
        return self

    def __exit__(self, exc_type, exc, tb):
        self.active = False
        return False


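# Build a policy instance without running the real constructor: __new__ skips
# __init__ entirely, and ModuleAttrMixin.__init__ supplies just enough
# nn.Module state for attribute assignment to work.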
def make_policy(model):
    policy = IMFTransformerHybridImagePolicy.__new__(IMFTransformerHybridImagePolicy)
    ModuleAttrMixin.__init__(policy)
    policy.model = model
    return policy


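# Stand-in for DiffusionTransformerHybridImagePolicy.__init__: it skips the
# expensive obs-encoder and transformer construction but records the handful
# of attributes that IMFTransformerHybridImagePolicy.__init__ reads afterwards.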
def fake_parent_init(
    self,
    shape_meta,
    noise_scheduler,
    horizon,
    n_action_steps,
    n_obs_steps,
    num_inference_steps=None,
    crop_shape=(76, 76),
    obs_encoder_group_norm=False,
    eval_fixed_crop=False,
    n_layer=8,
    n_cond_layers=0,
    n_head=1,
    n_emb=256,
    p_drop_emb=0.0,
    p_drop_attn=0.3,
    causal_attn=True,
    time_as_cond=True,
    obs_as_cond=True,
    pred_action_steps_only=False,
    **kwargs,
):
    ModuleAttrMixin.__init__(self)
    self.action_dim = shape_meta['action']['shape'][0]
    self.obs_feature_dim = 4
    self.obs_as_cond = obs_as_cond
    self.pred_action_steps_only = pred_action_steps_only
    self.n_action_steps = n_action_steps
    self.n_obs_steps = n_obs_steps
    self.horizon = horizon


@pytest.fixture
def shape_meta():
    return {
        'action': {'shape': [2]},
        'obs': {},
    }


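# With a constant predicted velocity u = 0.25, the one-step update the test
# pins down, x_hat = z_t - (t - r) * u(z_t, r, t), can be checked exactly:
# here r = 0 and t = 1, so x_hat should be z_1 - 0.25.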
def test_sample_one_step_uses_imf_update_formula():
    policy = make_policy(ConstantModel(0.25))
    z_1 = torch.tensor([
        [[1.0, -1.0], [0.5, 0.0]],
        [[2.0, 3.0], [-2.0, 4.0]],
    ])
    r = torch.zeros(z_1.shape[0])
    t = torch.ones(z_1.shape[0])

    x_hat = policy._sample_one_step(z_1, r=r, t=t, cond=None)

    expected = z_1 - (t - r).view(-1, 1, 1) * 0.25
    assert torch.allclose(x_hat, expected)


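# The compound velocity follows the MeanFlow-style identity
# v = u + (t - r) * du/dt, with the du/dt term detached so gradients flow only
# through u; the backward pass below verifies the stop-gradient.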
def test_compound_velocity_uses_detached_du_dt_term():
    policy = make_policy(ConstantModel(0.0))
    u = torch.tensor([[[1.0], [2.0]]], requires_grad=True)
    du_dt = torch.tensor([[[3.0], [4.0]]], requires_grad=True)
    r = torch.tensor([0.2])
    t = torch.tensor([0.8])

    compound = policy._compound_velocity(u, du_dt, r, t)
    expected = u + (t - r).view(-1, 1, 1) * du_dt.detach()

    assert torch.allclose(compound, expected)

    compound.sum().backward()
    assert u.grad is not None
    assert du_dt.grad is None


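# PyTorch's fused scaled-dot-product-attention kernels do not support
# forward-mode autodiff, so the policy is expected to run its JVP inside a
# math-backend SDPA context. The next two tests check that both the
# torch.func.jvp path and the torch.autograd.functional.jvp fallback enter
# that context exactly once.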
def test_compute_u_and_du_dt_uses_math_sdpa_context_for_torch_func_jvp(monkeypatch):
    tracker = TrackingContext()

    def fake_jvp(fn, primals, tangents):
        assert tracker.active is True
        return fn(*primals), torch.zeros_like(primals[0])

    monkeypatch.setattr(policy_module, 'TORCH_FUNC_JVP', fake_jvp)

    policy = make_policy(ConstantModel(0.5))
    policy._jvp_math_sdp_context = lambda tensor: tracker
    z_t = torch.randn(2, 3, 4)
    r = torch.rand(2, requires_grad=True)
    t = torch.rand(2, requires_grad=True)
    v = torch.randn_like(z_t, requires_grad=True)

    policy._compute_u_and_du_dt(z_t, r, t, cond=None, v=v)

    assert tracker.enter_count == 1


def test_compute_u_and_du_dt_uses_math_sdpa_context_for_autograd_fallback(monkeypatch):
    tracker = TrackingContext()

    def fake_autograd_jvp(fn, primals, tangents, create_graph=False, strict=False):
        assert tracker.active is True
        return fn(*primals), torch.zeros_like(primals[0])

    monkeypatch.setattr(policy_module, 'TORCH_FUNC_JVP', None)
    monkeypatch.setattr(policy_module.torch.autograd.functional, 'jvp', fake_autograd_jvp)

    policy = make_policy(ConstantModel(0.5))
    policy._jvp_math_sdp_context = lambda tensor: tracker
    z_t = torch.randn(2, 3, 4)
    r = torch.rand(2, requires_grad=True)
    t = torch.rand(2, requires_grad=True)
    v = torch.randn_like(z_t, requires_grad=True)

    policy._compute_u_and_du_dt(z_t, r, t, cond=None, v=v)

    assert tracker.enter_count == 1


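# Along the sampling trajectory dz/dt = v, the total derivative of the
# predicted velocity is du/dt = v · ∂u/∂z + ∂u/∂t, which a single JVP computes
# with tangents (v, 0, 1) for (z_t, r, t). The v tangent must be detached, and
# conditioning must be re-applied to z_t before the model sees it.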
def test_compute_u_and_du_dt_uses_detached_v_zero_r_unit_t_and_reapplies_conditioning(monkeypatch):
    captured = {}

    def fake_jvp(fn, primals, tangents):
        captured['tangents'] = tangents
        captured['primal_output'] = fn(*primals)
        return captured['primal_output'], torch.zeros_like(primals[0])

    monkeypatch.setattr(policy_module, 'TORCH_FUNC_JVP', fake_jvp)

    policy = make_policy(SumMixModel())
    z_t = torch.tensor([[[1.0, 2.0, 3.0]]])
    r = torch.rand(1, requires_grad=True)
    t = torch.rand(1, requires_grad=True)
    v = torch.tensor([[[10.0, 20.0, 30.0]]], requires_grad=True)
    condition_mask = torch.tensor([[[False, True, False]]])
    condition_data = torch.tensor([[[0.0, 7.0, 0.0]]])

    policy._compute_u_and_du_dt(
        z_t,
        r,
        t,
        cond=None,
        v=v,
        condition_data=condition_data,
        condition_mask=condition_mask,
    )

    tangent_v, tangent_r, tangent_t = captured['tangents']
    assert torch.equal(tangent_v, v.detach())
    assert tangent_v.requires_grad is False
    assert torch.equal(tangent_r, torch.zeros_like(r))
    assert torch.equal(tangent_t, torch.ones_like(t))

    conditioned = z_t.clone()
    conditioned[condition_mask] = condition_data[condition_mask]
    expected_primal = policy.model(conditioned, r, t, cond=None)
    assert torch.allclose(captured['primal_output'], expected_primal)


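# Conditioned entries are clamped to fixed data, so their tangent must be
# zeroed before the JVP; otherwise v would leak through the conditioning. For
# SumMixModel with v = [1, 10, 100] and index 1 conditioned, the surviving
# tangent sums to 1 + 100 = 101, giving du/dt = weight * 101 + 1 (the +1 comes
# from the unit tangent on t).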
def test_compute_u_and_du_dt_fallback_blocks_conditioned_tangent_leakage_and_keeps_primal_gradients(monkeypatch):
    monkeypatch.setattr(policy_module, 'TORCH_FUNC_JVP', None)

    policy = make_policy(SumMixModel())
    z_t = torch.tensor([[[1.0, 2.0, 3.0]]], requires_grad=True)
    r = torch.rand(1, requires_grad=True)
    t = torch.rand(1, requires_grad=True)
    v = torch.tensor([[[1.0, 10.0, 100.0]]], requires_grad=True)
    condition_mask = torch.tensor([[[False, True, False]]])
    condition_data = torch.tensor([[[0.0, 7.0, 0.0]]])

    u, du_dt = policy._compute_u_and_du_dt(
        z_t,
        r,
        t,
        cond=None,
        v=v,
        condition_data=condition_data,
        condition_mask=condition_mask,
    )

    conditioned = z_t.detach().clone()
    conditioned[condition_mask] = condition_data[condition_mask]
    expected_u = policy.model(conditioned, r, t, cond=None)
    expected_du_dt_scalar = policy.model.weight.detach() * torch.tensor(101.0) + 1.0
    expected_du_dt = torch.full_like(z_t, expected_du_dt_scalar)

    assert u.shape == z_t.shape
    assert du_dt.shape == z_t.shape
    assert torch.allclose(u, expected_u)
    assert torch.allclose(du_dt, expected_du_dt)

    u.sum().backward()
    assert policy.model.weight.grad is not None
    assert torch.count_nonzero(policy.model.weight.grad) > 0


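# When pred_action_steps_only is set, the inner model only ever predicts
# n_action_steps steps, so the constructor should build it with that shortened
# horizon rather than the full planning horizon.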
def test_init_uses_action_step_horizon_when_pred_action_steps_only(monkeypatch, shape_meta):
    monkeypatch.setattr(
        policy_module.DiffusionTransformerHybridImagePolicy,
        '__init__',
        fake_parent_init,
    )

    policy = IMFTransformerHybridImagePolicy(
        shape_meta=shape_meta,
        noise_scheduler=None,
        horizon=10,
        n_action_steps=4,
        n_obs_steps=2,
        num_inference_steps=1,
        n_layer=1,
        n_head=1,
        n_emb=16,
        p_drop_emb=0.0,
        p_drop_attn=0.0,
        causal_attn=True,
        obs_as_cond=True,
        pred_action_steps_only=True,
    )

    assert policy.model.horizon == 4
    assert policy.num_inference_steps == 1


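# The IMF policy is a one-step sampler by construction, so any
# num_inference_steps other than 1 should be rejected at init time.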
def test_init_rejects_non_one_step_inference(monkeypatch, shape_meta):
    monkeypatch.setattr(
        policy_module.DiffusionTransformerHybridImagePolicy,
        '__init__',
        fake_parent_init,
    )

    with pytest.raises(ValueError, match='num_inference_steps'):
        IMFTransformerHybridImagePolicy(
            shape_meta=shape_meta,
            noise_scheduler=None,
            horizon=10,
            n_action_steps=4,
            n_obs_steps=2,
            num_inference_steps=2,
            n_layer=1,
            n_head=1,
            n_emb=16,
            p_drop_emb=0.0,
            p_drop_attn=0.0,
            causal_attn=True,
            obs_as_cond=True,
            pred_action_steps_only=False,
        )