From 42dc29a2cb369b353b7fe00402a6f60e00abbc97 Mon Sep 17 00:00:00 2001
From: gameloader
Date: Mon, 16 Mar 2026 15:37:32 +0800
Subject: [PATCH] feat: align pmf transformer training and config defaults

---
 .../pmf_transformer_hybrid_image_policy.py | 89 +++++++++++++++----
 image_pusht_diffusion_policy_dit_pmf.yaml  | 11 ++-
 2 files changed, 80 insertions(+), 20 deletions(-)

diff --git a/diffusion_policy/policy/pmf_transformer_hybrid_image_policy.py b/diffusion_policy/policy/pmf_transformer_hybrid_image_policy.py
index 20e7a1a..66afc6c 100644
--- a/diffusion_policy/policy/pmf_transformer_hybrid_image_policy.py
+++ b/diffusion_policy/policy/pmf_transformer_hybrid_image_policy.py
@@ -47,6 +47,12 @@ class PMFTransformerHybridImagePolicy(BaseImagePolicy):
         pmf_u_loss_weight=1.0,
         pmf_v_loss_weight=1.0,
         noise_scale=1.0,
+        adatloss_eps=0.01,
+        p_mean=-0.4,
+        p_std=1.0,
+        tr_uniform=True,
+        tr_uniform_prob=0.1,
+        data_proportion=0.5,
         **kwargs,
     ):
         super().__init__()
@@ -171,6 +177,12 @@ class PMFTransformerHybridImagePolicy(BaseImagePolicy):
         self.pmf_u_loss_weight = pmf_u_loss_weight
         self.pmf_v_loss_weight = pmf_v_loss_weight
         self.noise_scale = noise_scale
+        self.adatloss_eps = adatloss_eps
+        self.p_mean = p_mean
+        self.p_std = p_std
+        self.tr_uniform = tr_uniform
+        self.tr_uniform_prob = tr_uniform_prob
+        self.data_proportion = data_proportion
         self.kwargs = kwargs
 
         if num_inference_steps is None:
@@ -185,6 +197,33 @@ class PMFTransformerHybridImagePolicy(BaseImagePolicy):
     def _time_view(self, value: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
         return value.reshape(value.shape[0], *([1] * (ref.ndim - 1)))
 
+    def _adatloss(self, loss: torch.Tensor) -> torch.Tensor:
+        denom = loss.detach() + self.adatloss_eps
+        return loss / denom
+
+    def _sample_logit_normal(self, batch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
+        normal = torch.randn(batch_size, device=device, dtype=dtype)
+        return torch.sigmoid(normal * self.p_std + self.p_mean)
+
+    def _sample_tr(self, batch_size: int, device: torch.device, dtype: torch.dtype):
+        t = self._sample_logit_normal(batch_size, device, dtype)
+        r = self._sample_logit_normal(batch_size, device, dtype)
+
+        if self.tr_uniform:
+            uniform_mask = torch.rand(batch_size, device=device) < self.tr_uniform_prob
+            uniform_t = torch.rand(batch_size, device=device, dtype=dtype)
+            uniform_r = torch.rand(batch_size, device=device, dtype=dtype)
+            t = torch.where(uniform_mask, uniform_t, t)
+            r = torch.where(uniform_mask, uniform_r, r)
+
+        data_size = int(batch_size * self.data_proportion)
+        fm_mask = torch.arange(batch_size, device=device) < data_size
+        r = torch.where(fm_mask, t, r)
+
+        t_final = torch.maximum(t, r)
+        r_final = torch.minimum(t, r)
+        return t_final, r_final
+
     def _trajectory_inputs(
         self,
         nobs: Dict[str, torch.Tensor],
@@ -217,20 +256,19 @@ class PMFTransformerHybridImagePolicy(BaseImagePolicy):
         condition_data: torch.Tensor,
         condition_mask: torch.Tensor,
     ) -> torch.Tensor:
-        if condition_mask.any():
-            sample = sample.clone()
-            sample[condition_mask] = condition_data[condition_mask]
-        return sample
+        if not condition_mask.any():
+            return sample
+        return torch.where(condition_mask, condition_data, sample)
 
     def _compute_u_v(
         self,
         sample: torch.Tensor,
         t: torch.Tensor,
         r: torch.Tensor,
         cond: torch.Tensor,
     ):
         x_hat_u, x_hat_v = self.model(sample, t, r, cond)
-        denom = self._time_view(t.clamp_min(self.min_time), sample)
+        denom = self._time_view(t, sample)
         u = (sample - x_hat_u) / denom
         v = (sample - x_hat_v) / denom
         return u, v
@@ -245,16 +283,31 @@ class PMFTransformerHybridImagePolicy(BaseImagePolicy):
         condition_mask: torch.Tensor,
         tangent_v: torch.Tensor,
     ) -> torch.Tensor:
-        step = torch.full_like(t, self.du_dt_epsilon)
-        sample_base = self._apply_conditioning(sample, condition_data, condition_mask)
-        u_base, _ = self._compute_u_v(sample_base, t, r, cond)
+        tangent_sample = tangent_v.detach()
+        tangent_r = torch.zeros_like(r)
+        tangent_t = torch.ones_like(t)
 
-        t_next = (t + step).clamp(max=1.0)
-        sample_next = sample + self._time_view(step, sample) * tangent_v.detach()
-        sample_next = self._apply_conditioning(sample_next, condition_data, condition_mask)
-        u_next, _ = self._compute_u_v(sample_next, t_next, r, cond)
+        def u_fn(sample_input, r_input, t_input):
+            conditioned_sample = self._apply_conditioning(
+                sample_input, condition_data, condition_mask
+            )
+            u_value, _ = self._compute_u_v(conditioned_sample, t_input, r_input, cond)
+            return u_value
 
-        return (u_next - u_base) / self._time_view(step, sample)
+        primals = (sample, r, t)
+        tangents = (tangent_sample, tangent_r, tangent_t)
+        try:
+            _, du_dt = torch.func.jvp(u_fn, primals, tangents)
+        except (AttributeError, NotImplementedError, RuntimeError):
+            _, du_dt = torch.autograd.functional.jvp(
+                u_fn,
+                primals,
+                tangents,
+                create_graph=False,
+                strict=False,
+            )
+
+        return du_dt
 
     # ========= inference ============
     def conditional_sample(
@@ -368,14 +421,14 @@ class PMFTransformerHybridImagePolicy(BaseImagePolicy):
 
         noise = torch.randn_like(trajectory) * self.noise_scale
         batch_size = trajectory.shape[0]
-        t = torch.rand(batch_size, device=trajectory.device, dtype=trajectory.dtype)
-        r = torch.rand(batch_size, device=trajectory.device, dtype=trajectory.dtype) * t
+        t, r = self._sample_tr(
+            batch_size, device=trajectory.device, dtype=trajectory.dtype
+        )
 
         z_t = (1 - self._time_view(t, trajectory)) * trajectory + self._time_view(t, trajectory) * noise
         z_t = self._apply_conditioning(z_t, trajectory, condition_mask)
         loss_mask = ~condition_mask
-        denom = self._time_view(t.clamp_min(self.min_time), trajectory)
-        target_v = (z_t - trajectory) / denom
+        target_v = noise - trajectory
 
         u, v = self._compute_u_v(z_t, t, r, cond)
         du_dt = self._compute_du_dt(
@@ -395,4 +448,6 @@ class PMFTransformerHybridImagePolicy(BaseImagePolicy):
         loss_v = loss_v * loss_mask.type(loss_v.dtype)
         loss_u = reduce(loss_u, "b ... -> b (...)", "mean").mean()
         loss_v = reduce(loss_v, "b ... -> b (...)", "mean").mean()
+        loss_u = self._adatloss(loss_u)
+        loss_v = self._adatloss(loss_v)
         return self.pmf_u_loss_weight * loss_u + self.pmf_v_loss_weight * loss_v
diff --git a/image_pusht_diffusion_policy_dit_pmf.yaml b/image_pusht_diffusion_policy_dit_pmf.yaml
index 3123932..4b1afb1 100644
--- a/image_pusht_diffusion_policy_dit_pmf.yaml
+++ b/image_pusht_diffusion_policy_dit_pmf.yaml
@@ -58,15 +58,17 @@ policy:
   - 84
   eval_fixed_crop: true
   horizon: 16
-  min_time: 0.05
   n_action_steps: 8
   n_cond_layers: 0
   n_emb: 256
   n_head: 4
-  n_layer: 8
+  n_layer: 12
   n_obs_steps: 2
   n_time_tokens: 4
   noise_scale: 1.0
+  adatloss_eps: 0.01
+  p_mean: -0.4
+  p_std: 1.0
   noise_scheduler:
     _target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler
     beta_end: 0.02
@@ -76,13 +78,16 @@ policy:
     num_train_timesteps: 100
     prediction_type: sample
     variance_type: fixed_small
-  num_inference_steps: 32
+  num_inference_steps: 1
   obs_as_cond: true
   obs_encoder_group_norm: true
   p_drop_attn: 0.0
   p_drop_emb: 0.0
   pmf_u_loss_weight: 1.0
   pmf_v_loss_weight: 1.0
+  tr_uniform: true
+  tr_uniform_prob: 0.1
+  data_proportion: 0.5
   shape_meta:
     action:
       shape:
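
The standalone Python sketch below illustrates, under toy assumptions, how the training-time pieces introduced by this patch interact: the logit-normal (t, r) sampler with its uniform mixture and flow-matching fraction, the linear interpolation z_t with target velocity noise - x, a single-JVP estimate of du/dt, and the adaptive loss normalization. It is not part of the patch; the tensor shapes, toy_model, and the placeholder loss term are illustrative assumptions and do not reproduce the policy's transformer, conditioning, or full loss.

# Minimal sketch of the patch's training-time mechanics (illustration only).
import torch

p_mean, p_std = -0.4, 1.0      # logit-normal parameters (config defaults above)
tr_uniform_prob = 0.1          # chance of swapping (t, r) for uniform samples
data_proportion = 0.5          # fraction of the batch trained with r == t
adatloss_eps = 0.01


def sample_tr(batch_size, device=torch.device("cpu"), dtype=torch.float32):
    # Two logit-normal draws, optionally replaced by uniform draws per sample.
    t = torch.sigmoid(torch.randn(batch_size, device=device, dtype=dtype) * p_std + p_mean)
    r = torch.sigmoid(torch.randn(batch_size, device=device, dtype=dtype) * p_std + p_mean)
    swap = torch.rand(batch_size, device=device) < tr_uniform_prob
    t = torch.where(swap, torch.rand(batch_size, device=device, dtype=dtype), t)
    r = torch.where(swap, torch.rand(batch_size, device=device, dtype=dtype), r)
    # The first part of the batch collapses to r == t (pure flow matching).
    fm = torch.arange(batch_size, device=device) < int(batch_size * data_proportion)
    r = torch.where(fm, t, r)
    return torch.maximum(t, r), torch.minimum(t, r)


def time_view(value, ref):
    # Broadcast a per-sample scalar over the remaining dims of `ref`.
    return value.reshape(value.shape[0], *([1] * (ref.ndim - 1)))


def toy_model(z, t_in, r_in):
    # Stand-in for the policy transformer; a real model would also use r_in.
    return z * time_view(1.0 - t_in, z)


def u_fn(z, r_in, t_in):
    # Same parametrization as _compute_u_v: u = (z - x_hat) / t.
    x_hat = toy_model(z, t_in, r_in)
    return (z - x_hat) / time_view(t_in, z)


batch, horizon, dim = 8, 16, 2
x = torch.randn(batch, horizon, dim)   # clean action trajectory
noise = torch.randn_like(x)
t, r = sample_tr(batch)

# Linear path from data (t=0) to noise (t=1); its velocity is noise - x.
z_t = (1.0 - time_view(t, x)) * x + time_view(t, x) * noise
target_v = noise - x

# Total derivative du/dt along the path via one forward-mode JVP:
# tangents are dz/dt (a detached velocity estimate), dr/dt = 0, dt/dt = 1.
u = u_fn(z_t, r, t)
_, du_dt = torch.func.jvp(
    u_fn, (z_t, r, t), (u.detach(), torch.zeros_like(r), torch.ones_like(t))
)

# Adaptive loss normalization from the patch: loss / (loss.detach() + eps).
loss = ((u - target_v) ** 2).mean()    # placeholder loss term
loss = loss / (loss.detach() + adatloss_eps)
print(du_dt.shape, float(loss))

Under these assumptions, the single forward-mode JVP plays the role of the finite-difference estimate of du/dt that the patch removes, so no step size (the old du_dt_epsilon) has to be chosen.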