From 4cd5085b335720a90f4882ec3faf62f2023a075c Mon Sep 17 00:00:00 2001 From: Logic Date: Thu, 26 Mar 2026 20:41:37 +0800 Subject: [PATCH] feat: add pusht image imf transformer --- .../imf_transformer_for_diffusion.py | 298 +++++++++++++++++ .../imf_transformer_hybrid_image_policy.py | 273 +++++++++++++++ image_pusht_diffusion_policy_dit_imf.yaml | 30 ++ tests/test_imf_transformer_for_diffusion.py | 46 +++ ...est_imf_transformer_hybrid_image_policy.py | 313 ++++++++++++++++++ 5 files changed, 960 insertions(+) create mode 100644 diffusion_policy/model/diffusion/imf_transformer_for_diffusion.py create mode 100644 diffusion_policy/policy/imf_transformer_hybrid_image_policy.py create mode 100644 image_pusht_diffusion_policy_dit_imf.yaml create mode 100644 tests/test_imf_transformer_for_diffusion.py create mode 100644 tests/test_imf_transformer_hybrid_image_policy.py diff --git a/diffusion_policy/model/diffusion/imf_transformer_for_diffusion.py b/diffusion_policy/model/diffusion/imf_transformer_for_diffusion.py new file mode 100644 index 0000000..0967915 --- /dev/null +++ b/diffusion_policy/model/diffusion/imf_transformer_for_diffusion.py @@ -0,0 +1,298 @@ +from typing import Optional, Tuple, Union +import logging + +import torch +import torch.nn as nn + +from diffusion_policy.model.common.module_attr_mixin import ModuleAttrMixin +from diffusion_policy.model.diffusion.positional_embedding import SinusoidalPosEmb + +logger = logging.getLogger(__name__) + + +class IMFTransformerForDiffusion(ModuleAttrMixin): + def __init__( + self, + input_dim: int, + output_dim: int, + horizon: int, + n_obs_steps: int = None, + cond_dim: int = 0, + n_layer: int = 12, + n_head: int = 1, + n_emb: int = 768, + p_drop_emb: float = 0.1, + p_drop_attn: float = 0.1, + causal_attn: bool = False, + time_as_cond: bool = True, + obs_as_cond: bool = False, + n_cond_layers: int = 0, + ) -> None: + super().__init__() + + assert n_head == 1, 'IMFTransformerForDiffusion currently supports single-head attention only.' + + if n_obs_steps is None: + n_obs_steps = horizon + + T = horizon + T_cond = 2 + if not time_as_cond: + T += 2 + T_cond -= 2 + obs_as_cond = cond_dim > 0 + if obs_as_cond: + assert time_as_cond + T_cond += n_obs_steps + + self.input_emb = nn.Linear(input_dim, n_emb) + self.pos_emb = nn.Parameter(torch.zeros(1, T, n_emb)) + self.drop = nn.Dropout(p_drop_emb) + + self.time_emb = SinusoidalPosEmb(n_emb) + self.cond_obs_emb = None + if obs_as_cond: + self.cond_obs_emb = nn.Linear(cond_dim, n_emb) + + self.cond_pos_emb = None + self.encoder = None + self.decoder = None + encoder_only = False + if T_cond > 0: + self.cond_pos_emb = nn.Parameter(torch.zeros(1, T_cond, n_emb)) + if n_cond_layers > 0: + encoder_layer = nn.TransformerEncoderLayer( + d_model=n_emb, + nhead=n_head, + dim_feedforward=4 * n_emb, + dropout=p_drop_attn, + activation='gelu', + batch_first=True, + norm_first=True, + ) + self.encoder = nn.TransformerEncoder( + encoder_layer=encoder_layer, + num_layers=n_cond_layers, + ) + else: + self.encoder = nn.Sequential( + nn.Linear(n_emb, 4 * n_emb), + nn.Mish(), + nn.Linear(4 * n_emb, n_emb), + ) + + decoder_layer = nn.TransformerDecoderLayer( + d_model=n_emb, + nhead=n_head, + dim_feedforward=4 * n_emb, + dropout=p_drop_attn, + activation='gelu', + batch_first=True, + norm_first=True, + ) + self.decoder = nn.TransformerDecoder( + decoder_layer=decoder_layer, + num_layers=n_layer, + ) + else: + encoder_only = True + encoder_layer = nn.TransformerEncoderLayer( + d_model=n_emb, + nhead=n_head, + dim_feedforward=4 * n_emb, + dropout=p_drop_attn, + activation='gelu', + batch_first=True, + norm_first=True, + ) + self.encoder = nn.TransformerEncoder( + encoder_layer=encoder_layer, + num_layers=n_layer, + ) + + if causal_attn: + sz = T + mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1) + mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0)) + self.register_buffer('mask', mask) + + if time_as_cond and obs_as_cond: + S = T_cond + t_idx, s_idx = torch.meshgrid( + torch.arange(T), + torch.arange(S), + indexing='ij', + ) + mask = t_idx >= (s_idx - 2) + mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0)) + self.register_buffer('memory_mask', mask) + else: + self.memory_mask = None + else: + self.mask = None + self.memory_mask = None + + self.ln_f = nn.LayerNorm(n_emb) + self.head = nn.Linear(n_emb, output_dim) + + self.T = T + self.T_cond = T_cond + self.horizon = horizon + self.time_as_cond = time_as_cond + self.obs_as_cond = obs_as_cond + self.encoder_only = encoder_only + + self.apply(self._init_weights) + logger.info( + 'number of parameters: %e', + sum(p.numel() for p in self.parameters()), + ) + + def _init_weights(self, module): + ignore_types = ( + nn.Dropout, + SinusoidalPosEmb, + nn.TransformerEncoderLayer, + nn.TransformerDecoderLayer, + nn.TransformerEncoder, + nn.TransformerDecoder, + nn.ModuleList, + nn.Mish, + nn.Sequential, + ) + if isinstance(module, (nn.Linear, nn.Embedding)): + torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) + if isinstance(module, nn.Linear) and module.bias is not None: + torch.nn.init.zeros_(module.bias) + elif isinstance(module, nn.MultiheadAttention): + weight_names = ['in_proj_weight', 'q_proj_weight', 'k_proj_weight', 'v_proj_weight'] + for name in weight_names: + weight = getattr(module, name) + if weight is not None: + torch.nn.init.normal_(weight, mean=0.0, std=0.02) + + bias_names = ['in_proj_bias', 'bias_k', 'bias_v'] + for name in bias_names: + bias = getattr(module, name) + if bias is not None: + torch.nn.init.zeros_(bias) + elif isinstance(module, nn.LayerNorm): + torch.nn.init.zeros_(module.bias) + torch.nn.init.ones_(module.weight) + elif isinstance(module, IMFTransformerForDiffusion): + torch.nn.init.normal_(module.pos_emb, mean=0.0, std=0.02) + if module.cond_obs_emb is not None: + torch.nn.init.normal_(module.cond_pos_emb, mean=0.0, std=0.02) + elif isinstance(module, ignore_types): + pass + else: + raise RuntimeError(f'Unaccounted module {module}') + + def get_optim_groups(self, weight_decay: float = 1e-3): + decay = set() + no_decay = set() + whitelist_weight_modules = (torch.nn.Linear, torch.nn.MultiheadAttention) + blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding) + for mn, m in self.named_modules(): + for pn, _ in m.named_parameters(): + fpn = '%s.%s' % (mn, pn) if mn else pn + + if pn.endswith('bias'): + no_decay.add(fpn) + elif pn.startswith('bias'): + no_decay.add(fpn) + elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules): + decay.add(fpn) + elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules): + no_decay.add(fpn) + + no_decay.add('pos_emb') + no_decay.add('_dummy_variable') + if self.cond_pos_emb is not None: + no_decay.add('cond_pos_emb') + + param_dict = {pn: p for pn, p in self.named_parameters()} + inter_params = decay & no_decay + union_params = decay | no_decay + assert len(inter_params) == 0, f'parameters {inter_params} made it into both decay/no_decay sets!' + assert len(param_dict.keys() - union_params) == 0, ( + f'parameters {param_dict.keys() - union_params} were not separated into either decay/no_decay sets!' + ) + + optim_groups = [ + { + 'params': [param_dict[pn] for pn in sorted(list(decay))], + 'weight_decay': weight_decay, + }, + { + 'params': [param_dict[pn] for pn in sorted(list(no_decay))], + 'weight_decay': 0.0, + }, + ] + return optim_groups + + def configure_optimizers( + self, + learning_rate: float = 1e-4, + weight_decay: float = 1e-3, + betas: Tuple[float, float] = (0.9, 0.95), + ): + optim_groups = self.get_optim_groups(weight_decay=weight_decay) + optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas) + return optimizer + + def _prepare_time_input(self, value: Union[torch.Tensor, float, int], sample: torch.Tensor) -> torch.Tensor: + if not torch.is_tensor(value): + value = torch.tensor([value], dtype=sample.dtype, device=sample.device) + elif value.ndim == 0: + value = value[None].to(device=sample.device, dtype=sample.dtype) + else: + value = value.to(device=sample.device, dtype=sample.dtype) + return value.expand(sample.shape[0]) + + def forward( + self, + sample: torch.Tensor, + r: Union[torch.Tensor, float, int], + t: Union[torch.Tensor, float, int], + cond: Optional[torch.Tensor] = None, + ): + r = self._prepare_time_input(r, sample) + t = self._prepare_time_input(t, sample) + r_emb = self.time_emb(r).unsqueeze(1) + t_emb = self.time_emb(t).unsqueeze(1) + + input_emb = self.input_emb(sample) + + if self.encoder_only: + token_embeddings = torch.cat([r_emb, t_emb, input_emb], dim=1) + token_count = token_embeddings.shape[1] + position_embeddings = self.pos_emb[:, :token_count, :] + x = self.drop(token_embeddings + position_embeddings) + x = self.encoder(src=x, mask=self.mask) + x = x[:, 2:, :] + else: + cond_embeddings = torch.cat([r_emb, t_emb], dim=1) + if self.obs_as_cond: + cond_obs_emb = self.cond_obs_emb(cond) + cond_embeddings = torch.cat([cond_embeddings, cond_obs_emb], dim=1) + token_count = cond_embeddings.shape[1] + position_embeddings = self.cond_pos_emb[:, :token_count, :] + x = self.drop(cond_embeddings + position_embeddings) + x = self.encoder(x) + memory = x + + token_embeddings = input_emb + token_count = token_embeddings.shape[1] + position_embeddings = self.pos_emb[:, :token_count, :] + x = self.drop(token_embeddings + position_embeddings) + x = self.decoder( + tgt=x, + memory=memory, + tgt_mask=self.mask, + memory_mask=self.memory_mask, + ) + + x = self.ln_f(x) + x = self.head(x) + return x diff --git a/diffusion_policy/policy/imf_transformer_hybrid_image_policy.py b/diffusion_policy/policy/imf_transformer_hybrid_image_policy.py new file mode 100644 index 0000000..4ffc77e --- /dev/null +++ b/diffusion_policy/policy/imf_transformer_hybrid_image_policy.py @@ -0,0 +1,273 @@ +from contextlib import nullcontext +from typing import Dict, Optional + +import torch +import torch.nn.functional as F +from einops import reduce + +from diffusion_policy.common.pytorch_util import dict_apply +from diffusion_policy.model.diffusion.imf_transformer_for_diffusion import IMFTransformerForDiffusion +from diffusion_policy.policy.diffusion_transformer_hybrid_image_policy import ( + DiffusionTransformerHybridImagePolicy, +) + +try: + from torch.func import jvp as TORCH_FUNC_JVP +except ImportError: # pragma: no cover - depends on torch version + TORCH_FUNC_JVP = None + + +class IMFTransformerHybridImagePolicy(DiffusionTransformerHybridImagePolicy): + def __init__( + self, + shape_meta: dict, + noise_scheduler, + horizon, + n_action_steps, + n_obs_steps, + num_inference_steps=None, + crop_shape=(76, 76), + obs_encoder_group_norm=False, + eval_fixed_crop=False, + n_layer=8, + n_cond_layers=0, + n_head=1, + n_emb=256, + p_drop_emb=0.0, + p_drop_attn=0.3, + causal_attn=True, + time_as_cond=True, + obs_as_cond=True, + pred_action_steps_only=False, + **kwargs, + ): + if num_inference_steps is None: + num_inference_steps = 1 + elif num_inference_steps != 1: + raise ValueError( + 'IMFTransformerHybridImagePolicy only supports one-step inference; ' + f'num_inference_steps must be 1, got {num_inference_steps}.' + ) + + super().__init__( + shape_meta=shape_meta, + noise_scheduler=noise_scheduler, + horizon=horizon, + n_action_steps=n_action_steps, + n_obs_steps=n_obs_steps, + num_inference_steps=num_inference_steps, + crop_shape=crop_shape, + obs_encoder_group_norm=obs_encoder_group_norm, + eval_fixed_crop=eval_fixed_crop, + n_layer=n_layer, + n_cond_layers=n_cond_layers, + n_head=n_head, + n_emb=n_emb, + p_drop_emb=p_drop_emb, + p_drop_attn=p_drop_attn, + causal_attn=causal_attn, + time_as_cond=time_as_cond, + obs_as_cond=obs_as_cond, + pred_action_steps_only=pred_action_steps_only, + **kwargs, + ) + + input_dim = self.action_dim if self.obs_as_cond else (self.obs_feature_dim + self.action_dim) + cond_dim = self.obs_feature_dim if self.obs_as_cond else 0 + model_horizon = self.n_action_steps if self.pred_action_steps_only else horizon + self.model = IMFTransformerForDiffusion( + input_dim=input_dim, + output_dim=input_dim, + horizon=model_horizon, + n_obs_steps=n_obs_steps, + cond_dim=cond_dim, + n_layer=n_layer, + n_head=n_head, + n_emb=n_emb, + p_drop_emb=p_drop_emb, + p_drop_attn=p_drop_attn, + causal_attn=causal_attn, + time_as_cond=time_as_cond, + obs_as_cond=obs_as_cond, + n_cond_layers=n_cond_layers, + ) + self.num_inference_steps = 1 + + def fn(self, z: torch.Tensor, r: torch.Tensor, t: torch.Tensor, cond=None) -> torch.Tensor: + return self.model(z, r, t, cond=cond) + + @staticmethod + def _broadcast_batch_time(value: torch.Tensor, reference: torch.Tensor) -> torch.Tensor: + while value.ndim < reference.ndim: + value = value.unsqueeze(-1) + return value + + @staticmethod + def _apply_conditioning( + trajectory: torch.Tensor, + condition_data: Optional[torch.Tensor] = None, + condition_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if condition_data is None or condition_mask is None: + return trajectory + conditioned = trajectory.clone() + conditioned[condition_mask] = condition_data[condition_mask] + return conditioned + + @staticmethod + def _jvp_math_sdp_context(z_t: torch.Tensor): + if z_t.is_cuda: + return torch.backends.cuda.sdp_kernel( + enable_flash=False, + enable_math=True, + enable_mem_efficient=False, + enable_cudnn=False, + ) + return nullcontext() + + @staticmethod + def _jvp_tangents(v: torch.Tensor, r: torch.Tensor, t: torch.Tensor): + return v.detach(), torch.zeros_like(r), torch.ones_like(t) + + def _compute_u_and_du_dt( + self, + z_t: torch.Tensor, + r: torch.Tensor, + t: torch.Tensor, + cond, + v: torch.Tensor, + condition_data: Optional[torch.Tensor] = None, + condition_mask: Optional[torch.Tensor] = None, + ): + tangents = self._jvp_tangents(v, r, t) + + def g(z, r_value, t_value): + conditioned_z = self._apply_conditioning(z, condition_data, condition_mask) + return self.fn(conditioned_z, r_value, t_value, cond=cond) + + with self._jvp_math_sdp_context(z_t): + if TORCH_FUNC_JVP is not None: + try: + return TORCH_FUNC_JVP(g, (z_t, r, t), tangents) + except (RuntimeError, TypeError, NotImplementedError): + pass + + u = g(z_t, r, t) + _, du_dt = torch.autograd.functional.jvp( + g, + (z_t, r, t), + tangents, + create_graph=False, + strict=False, + ) + return u, du_dt + + def _compound_velocity( + self, + u: torch.Tensor, + du_dt: torch.Tensor, + r: torch.Tensor, + t: torch.Tensor, + ) -> torch.Tensor: + delta = self._broadcast_batch_time(t - r, u) + return u + delta * du_dt.detach() + + def _sample_one_step( + self, + z_t: torch.Tensor, + r: torch.Tensor = None, + t: torch.Tensor = None, + cond=None, + ) -> torch.Tensor: + batch_size = z_t.shape[0] + if t is None: + t = torch.ones(batch_size, device=z_t.device, dtype=z_t.dtype) + if r is None: + r = torch.zeros(batch_size, device=z_t.device, dtype=z_t.dtype) + u = self.fn(z_t, r, t, cond=cond) + delta = self._broadcast_batch_time(t - r, z_t) + return z_t - delta * u + + def conditional_sample( + self, + condition_data, + condition_mask, + cond=None, + generator=None, + **kwargs, + ): + trajectory = torch.randn( + size=condition_data.shape, + dtype=condition_data.dtype, + device=condition_data.device, + generator=generator, + ) + trajectory = self._apply_conditioning(trajectory, condition_data, condition_mask) + trajectory = self._sample_one_step(trajectory, cond=cond) + trajectory = self._apply_conditioning(trajectory, condition_data, condition_mask) + return trajectory + + def compute_loss(self, batch): + assert 'valid_mask' not in batch + nobs = self.normalizer.normalize(batch['obs']) + nactions = self.normalizer['action'].normalize(batch['action']) + batch_size = nactions.shape[0] + horizon = nactions.shape[1] + To = self.n_obs_steps + + cond = None + trajectory = nactions + if self.obs_as_cond: + this_nobs = dict_apply( + nobs, + lambda x: x[:, :To, ...].reshape(-1, *x.shape[2:]), + ) + nobs_features = self.obs_encoder(this_nobs) + cond = nobs_features.reshape(batch_size, To, -1) + if self.pred_action_steps_only: + start = To - 1 + end = start + self.n_action_steps + trajectory = nactions[:, start:end] + else: + this_nobs = dict_apply(nobs, lambda x: x.reshape(-1, *x.shape[2:])) + nobs_features = self.obs_encoder(this_nobs) + nobs_features = nobs_features.reshape(batch_size, horizon, -1) + trajectory = torch.cat([nactions, nobs_features], dim=-1).detach() + + if self.pred_action_steps_only: + condition_mask = torch.zeros_like(trajectory, dtype=torch.bool) + else: + condition_mask = self.mask_generator(trajectory.shape) + + loss_mask = torch.zeros_like(trajectory, dtype=torch.bool) + loss_mask[..., : self.action_dim] = True + loss_mask = loss_mask & ~condition_mask + + x = trajectory + e = torch.randn_like(x) + t = torch.rand(batch_size, device=x.device, dtype=x.dtype) + r = torch.rand(batch_size, device=x.device, dtype=x.dtype) + t, r = torch.maximum(t, r), torch.minimum(t, r) + + t_broadcast = self._broadcast_batch_time(t, x) + z_t = (1 - t_broadcast) * x + t_broadcast * e + z_t = self._apply_conditioning(z_t, x, condition_mask) + + v = self.fn(z_t, t, t, cond=cond) + u, du_dt = self._compute_u_and_du_dt( + z_t, + r, + t, + cond=cond, + v=v, + condition_data=x, + condition_mask=condition_mask, + ) + V = self._compound_velocity(u, du_dt, r, t) + target = e - x + + loss = F.mse_loss(V, target, reduction='none') + loss = loss * loss_mask.type(loss.dtype) + loss = reduce(loss, 'b ... -> b (...)', 'mean') + loss = loss.mean() + return loss diff --git a/image_pusht_diffusion_policy_dit_imf.yaml b/image_pusht_diffusion_policy_dit_imf.yaml new file mode 100644 index 0000000..804a80d --- /dev/null +++ b/image_pusht_diffusion_policy_dit_imf.yaml @@ -0,0 +1,30 @@ +defaults: + - diffusion_policy/config/train_diffusion_transformer_hybrid_workspace@_here_ + - override /diffusion_policy/config/task@task: pusht_image + - _self_ + +exp_name: pusht_image_dit_imf + +policy: + _target_: diffusion_policy.policy.imf_transformer_hybrid_image_policy.IMFTransformerHybridImagePolicy + num_inference_steps: 1 + n_head: 1 + +logging: + backend: swanlab + mode: online + tags: ["${name}", "${task_name}", "${exp_name}", "swanlab"] + id: ${now:%Y%m%d%H%M%S}_${name}_${task_name} + group: ${exp_name} + +dataloader: + num_workers: 0 + +val_dataloader: + num_workers: 0 + +task: + env_runner: + n_envs: 1 + n_test_vis: 0 + n_train_vis: 0 diff --git a/tests/test_imf_transformer_for_diffusion.py b/tests/test_imf_transformer_for_diffusion.py new file mode 100644 index 0000000..c1ba345 --- /dev/null +++ b/tests/test_imf_transformer_for_diffusion.py @@ -0,0 +1,46 @@ +import inspect +import pathlib +import sys + +import torch + +ROOT_DIR = pathlib.Path(__file__).resolve().parents[1] +if str(ROOT_DIR) not in sys.path: + sys.path.append(str(ROOT_DIR)) + +from diffusion_policy.model.diffusion.imf_transformer_for_diffusion import ( # noqa: E402 + IMFTransformerForDiffusion, +) + + +def test_imf_transformer_forward_signature_and_shape_single_head(): + signature = inspect.signature(IMFTransformerForDiffusion.forward) + assert list(signature.parameters)[:5] == ['self', 'sample', 'r', 't', 'cond'] + assert signature.parameters['cond'].default is None + + model = IMFTransformerForDiffusion( + input_dim=3, + output_dim=3, + horizon=5, + n_obs_steps=2, + cond_dim=4, + n_layer=1, + n_head=1, + n_emb=16, + p_drop_emb=0.0, + p_drop_attn=0.0, + causal_attn=True, + time_as_cond=True, + obs_as_cond=True, + n_cond_layers=0, + ) + model.configure_optimizers() + + sample = torch.randn(2, 5, 3) + r = torch.rand(2) + t = torch.rand(2) + cond = torch.randn(2, 2, 4) + + pred_u = model(sample, r, t, cond=cond) + + assert pred_u.shape == sample.shape diff --git a/tests/test_imf_transformer_hybrid_image_policy.py b/tests/test_imf_transformer_hybrid_image_policy.py new file mode 100644 index 0000000..6f3ce50 --- /dev/null +++ b/tests/test_imf_transformer_hybrid_image_policy.py @@ -0,0 +1,313 @@ +import pathlib +import sys + +import pytest +import torch +import torch.nn as nn + +ROOT_DIR = pathlib.Path(__file__).resolve().parents[1] +if str(ROOT_DIR) not in sys.path: + sys.path.append(str(ROOT_DIR)) + +import diffusion_policy.policy.imf_transformer_hybrid_image_policy as policy_module # noqa: E402 +from diffusion_policy.model.common.module_attr_mixin import ModuleAttrMixin # noqa: E402 +from diffusion_policy.policy.imf_transformer_hybrid_image_policy import ( # noqa: E402 + IMFTransformerHybridImagePolicy, +) + + +class ConstantModel(nn.Module): + def __init__(self, value): + super().__init__() + self.value = value + + def forward(self, sample, r, t, cond=None): + return torch.full_like(sample, self.value) + + +class AffineModel(nn.Module): + def __init__(self): + super().__init__() + self.weight = nn.Parameter(torch.tensor(2.0)) + + def forward(self, sample, r, t, cond=None): + return sample * self.weight + (r + t).view(-1, 1, 1) + + +class SumMixModel(nn.Module): + def __init__(self): + super().__init__() + self.weight = nn.Parameter(torch.tensor(2.0)) + + def forward(self, sample, r, t, cond=None): + mixed = sample.sum(dim=-1, keepdim=True).expand_as(sample) + return mixed * self.weight + t.view(-1, 1, 1) + + +class TrackingContext: + def __init__(self): + self.active = False + self.enter_count = 0 + + def __enter__(self): + self.active = True + self.enter_count += 1 + return self + + def __exit__(self, exc_type, exc, tb): + self.active = False + return False + + +def make_policy(model): + policy = IMFTransformerHybridImagePolicy.__new__(IMFTransformerHybridImagePolicy) + ModuleAttrMixin.__init__(policy) + policy.model = model + return policy + + +def fake_parent_init( + self, + shape_meta, + noise_scheduler, + horizon, + n_action_steps, + n_obs_steps, + num_inference_steps=None, + crop_shape=(76, 76), + obs_encoder_group_norm=False, + eval_fixed_crop=False, + n_layer=8, + n_cond_layers=0, + n_head=1, + n_emb=256, + p_drop_emb=0.0, + p_drop_attn=0.3, + causal_attn=True, + time_as_cond=True, + obs_as_cond=True, + pred_action_steps_only=False, + **kwargs, +): + ModuleAttrMixin.__init__(self) + self.action_dim = shape_meta['action']['shape'][0] + self.obs_feature_dim = 4 + self.obs_as_cond = obs_as_cond + self.pred_action_steps_only = pred_action_steps_only + self.n_action_steps = n_action_steps + self.n_obs_steps = n_obs_steps + self.horizon = horizon + + +@pytest.fixture +def shape_meta(): + return { + 'action': {'shape': [2]}, + 'obs': {}, + } + + +def test_sample_one_step_uses_imf_update_formula(): + policy = make_policy(ConstantModel(0.25)) + z_1 = torch.tensor([ + [[1.0, -1.0], [0.5, 0.0]], + [[2.0, 3.0], [-2.0, 4.0]], + ]) + r = torch.zeros(z_1.shape[0]) + t = torch.ones(z_1.shape[0]) + + x_hat = policy._sample_one_step(z_1, r=r, t=t, cond=None) + + expected = z_1 - (t - r).view(-1, 1, 1) * 0.25 + assert torch.allclose(x_hat, expected) + + +def test_compound_velocity_uses_detached_du_dt_term(): + policy = make_policy(ConstantModel(0.0)) + u = torch.tensor([[[1.0], [2.0]]], requires_grad=True) + du_dt = torch.tensor([[[3.0], [4.0]]], requires_grad=True) + r = torch.tensor([0.2]) + t = torch.tensor([0.8]) + + compound = policy._compound_velocity(u, du_dt, r, t) + expected = u + (t - r).view(-1, 1, 1) * du_dt.detach() + + assert torch.allclose(compound, expected) + + compound.sum().backward() + assert u.grad is not None + assert du_dt.grad is None + + +def test_compute_u_and_du_dt_uses_math_sdpa_context_for_torch_func_jvp(monkeypatch): + tracker = TrackingContext() + + def fake_jvp(fn, primals, tangents): + assert tracker.active is True + return fn(*primals), torch.zeros_like(primals[0]) + + monkeypatch.setattr(policy_module, 'TORCH_FUNC_JVP', fake_jvp) + + policy = make_policy(ConstantModel(0.5)) + policy._jvp_math_sdp_context = lambda tensor: tracker + z_t = torch.randn(2, 3, 4) + r = torch.rand(2, requires_grad=True) + t = torch.rand(2, requires_grad=True) + v = torch.randn_like(z_t, requires_grad=True) + + policy._compute_u_and_du_dt(z_t, r, t, cond=None, v=v) + + assert tracker.enter_count == 1 + + +def test_compute_u_and_du_dt_uses_math_sdpa_context_for_autograd_fallback(monkeypatch): + tracker = TrackingContext() + + def fake_autograd_jvp(fn, primals, tangents, create_graph=False, strict=False): + assert tracker.active is True + return fn(*primals), torch.zeros_like(primals[0]) + + monkeypatch.setattr(policy_module, 'TORCH_FUNC_JVP', None) + monkeypatch.setattr(policy_module.torch.autograd.functional, 'jvp', fake_autograd_jvp) + + policy = make_policy(ConstantModel(0.5)) + policy._jvp_math_sdp_context = lambda tensor: tracker + z_t = torch.randn(2, 3, 4) + r = torch.rand(2, requires_grad=True) + t = torch.rand(2, requires_grad=True) + v = torch.randn_like(z_t, requires_grad=True) + + policy._compute_u_and_du_dt(z_t, r, t, cond=None, v=v) + + assert tracker.enter_count == 1 + + +def test_compute_u_and_du_dt_uses_detached_v_zero_r_unit_t_and_reapplies_conditioning(monkeypatch): + captured = {} + + def fake_jvp(fn, primals, tangents): + captured['tangents'] = tangents + captured['primal_output'] = fn(*primals) + return captured['primal_output'], torch.zeros_like(primals[0]) + + monkeypatch.setattr(policy_module, 'TORCH_FUNC_JVP', fake_jvp) + + policy = make_policy(SumMixModel()) + z_t = torch.tensor([[[1.0, 2.0, 3.0]]]) + r = torch.rand(1, requires_grad=True) + t = torch.rand(1, requires_grad=True) + v = torch.tensor([[[10.0, 20.0, 30.0]]], requires_grad=True) + condition_mask = torch.tensor([[[False, True, False]]]) + condition_data = torch.tensor([[[0.0, 7.0, 0.0]]]) + + policy._compute_u_and_du_dt( + z_t, + r, + t, + cond=None, + v=v, + condition_data=condition_data, + condition_mask=condition_mask, + ) + + tangent_v, tangent_r, tangent_t = captured['tangents'] + assert torch.equal(tangent_v, v.detach()) + assert tangent_v.requires_grad is False + assert torch.equal(tangent_r, torch.zeros_like(r)) + assert torch.equal(tangent_t, torch.ones_like(t)) + + conditioned = z_t.clone() + conditioned[condition_mask] = condition_data[condition_mask] + expected_primal = policy.model(conditioned, r, t, cond=None) + assert torch.allclose(captured['primal_output'], expected_primal) + + +def test_compute_u_and_du_dt_fallback_blocks_conditioned_tangent_leakage_and_keeps_primal_gradients(monkeypatch): + monkeypatch.setattr(policy_module, 'TORCH_FUNC_JVP', None) + + policy = make_policy(SumMixModel()) + z_t = torch.tensor([[[1.0, 2.0, 3.0]]], requires_grad=True) + r = torch.rand(1, requires_grad=True) + t = torch.rand(1, requires_grad=True) + v = torch.tensor([[[1.0, 10.0, 100.0]]], requires_grad=True) + condition_mask = torch.tensor([[[False, True, False]]]) + condition_data = torch.tensor([[[0.0, 7.0, 0.0]]]) + + u, du_dt = policy._compute_u_and_du_dt( + z_t, + r, + t, + cond=None, + v=v, + condition_data=condition_data, + condition_mask=condition_mask, + ) + + conditioned = z_t.detach().clone() + conditioned[condition_mask] = condition_data[condition_mask] + expected_u = policy.model(conditioned, r, t, cond=None) + expected_du_dt_scalar = policy.model.weight.detach() * torch.tensor(101.0) + 1.0 + expected_du_dt = torch.full_like(z_t, expected_du_dt_scalar) + + assert u.shape == z_t.shape + assert du_dt.shape == z_t.shape + assert torch.allclose(u, expected_u) + assert torch.allclose(du_dt, expected_du_dt) + + u.sum().backward() + assert policy.model.weight.grad is not None + assert torch.count_nonzero(policy.model.weight.grad) > 0 + + +def test_init_uses_action_step_horizon_when_pred_action_steps_only(monkeypatch, shape_meta): + monkeypatch.setattr( + policy_module.DiffusionTransformerHybridImagePolicy, + '__init__', + fake_parent_init, + ) + + policy = IMFTransformerHybridImagePolicy( + shape_meta=shape_meta, + noise_scheduler=None, + horizon=10, + n_action_steps=4, + n_obs_steps=2, + num_inference_steps=1, + n_layer=1, + n_head=1, + n_emb=16, + p_drop_emb=0.0, + p_drop_attn=0.0, + causal_attn=True, + obs_as_cond=True, + pred_action_steps_only=True, + ) + + assert policy.model.horizon == 4 + assert policy.num_inference_steps == 1 + + +def test_init_rejects_non_one_step_inference(monkeypatch, shape_meta): + monkeypatch.setattr( + policy_module.DiffusionTransformerHybridImagePolicy, + '__init__', + fake_parent_init, + ) + + with pytest.raises(ValueError, match='num_inference_steps'): + IMFTransformerHybridImagePolicy( + shape_meta=shape_meta, + noise_scheduler=None, + horizon=10, + n_action_steps=4, + n_obs_steps=2, + num_inference_steps=2, + n_layer=1, + n_head=1, + n_emb=16, + p_drop_emb=0.0, + p_drop_attn=0.0, + causal_attn=True, + obs_as_cond=True, + pred_action_steps_only=False, + )