From 9169e4d7e038638bfba161a8232a00ede03b7321 Mon Sep 17 00:00:00 2001 From: Logic Date: Tue, 17 Mar 2026 17:05:02 +0800 Subject: [PATCH] feat(pusht): add dual-head uv transformer --- .../pmf_transformer_for_diffusion.py | 49 ++++++++++++++++--- .../pmf_transformer_hybrid_image_policy.py | 2 + image_pusht_diffusion_policy_dit_pmf.yaml | 1 + 3 files changed, 46 insertions(+), 6 deletions(-) diff --git a/diffusion_policy/model/diffusion/pmf_transformer_for_diffusion.py b/diffusion_policy/model/diffusion/pmf_transformer_for_diffusion.py index 96d16d2..1a65963 100644 --- a/diffusion_policy/model/diffusion/pmf_transformer_for_diffusion.py +++ b/diffusion_policy/model/diffusion/pmf_transformer_for_diffusion.py @@ -27,6 +27,7 @@ class PMFTransformerForDiffusion(ModuleAttrMixin): obs_as_cond: bool = False, n_cond_layers: int = 0, n_time_tokens: int = 4, + n_head_layers: int = 4, ) -> None: super().__init__() @@ -34,11 +35,18 @@ class PMFTransformerForDiffusion(ModuleAttrMixin): n_obs_steps = horizon if n_time_tokens < 1: raise ValueError("n_time_tokens must be >= 1") + if n_head_layers < 0: + raise ValueError("n_head_layers must be >= 0") + if n_head_layers >= n_layer: + raise ValueError( + "n_head_layers must be smaller than n_layer so shared trunk depth stays positive" + ) obs_as_cond = cond_dim > 0 T = horizon n_global_cond_tokens = 2 * n_time_tokens T_cond = n_global_cond_tokens + (n_obs_steps if obs_as_cond else 0) + n_shared_layers = n_layer - n_head_layers self.input_emb = nn.Linear(input_dim, n_emb) self.pos_emb = nn.Parameter(torch.zeros(1, T, n_emb)) @@ -81,9 +89,17 @@ class PMFTransformerForDiffusion(ModuleAttrMixin): batch_first=True, norm_first=True, ) - self.decoder = nn.TransformerDecoder( + self.shared_decoder = nn.TransformerDecoder( decoder_layer=decoder_layer, - num_layers=n_layer, + num_layers=n_shared_layers, + ) + self.u_decoder = nn.TransformerDecoder( + decoder_layer=decoder_layer, + num_layers=n_head_layers, + ) + self.v_decoder = nn.TransformerDecoder( + decoder_layer=decoder_layer, + num_layers=n_head_layers, ) if causal_attn: @@ -109,7 +125,8 @@ class PMFTransformerForDiffusion(ModuleAttrMixin): self.mask = None self.memory_mask = None - self.ln_f = nn.LayerNorm(n_emb) + self.ln_u = nn.LayerNorm(n_emb) + self.ln_v = nn.LayerNorm(n_emb) self.head_u = nn.Linear(n_emb, output_dim) self.head_v = nn.Linear(n_emb, output_dim) @@ -120,11 +137,20 @@ class PMFTransformerForDiffusion(ModuleAttrMixin): self.obs_as_cond = obs_as_cond self.n_global_cond_tokens = n_global_cond_tokens self.n_time_tokens = n_time_tokens + self.n_layer = n_layer + self.n_head_layers = n_head_layers + self.n_shared_layers = n_shared_layers self.apply(self._init_weights) logger.info( "number of parameters: %e", sum(p.numel() for p in self.parameters()) ) + logger.info( + "PMFTransformerForDiffusion layers: shared=%d u_head=%d v_head=%d", + self.n_shared_layers, + self.n_head_layers, + self.n_head_layers, + ) def _init_weights(self, module): ignore_types = ( @@ -255,11 +281,22 @@ class PMFTransformerForDiffusion(ModuleAttrMixin): token_pos = self.pos_emb[:, : input_emb.shape[1], :] x = self.drop(input_emb + token_pos) - x = self.decoder( + shared_x = self.shared_decoder( tgt=x, memory=memory, tgt_mask=self.mask, memory_mask=self.memory_mask, ) - x = self.ln_f(x) - return self.head_u(x), self.head_v(x) + u_x = self.u_decoder( + tgt=shared_x, + memory=memory, + tgt_mask=self.mask, + memory_mask=self.memory_mask, + ) + v_x = self.v_decoder( + tgt=shared_x, + memory=memory, + tgt_mask=self.mask, + memory_mask=self.memory_mask, + ) + return self.head_u(self.ln_u(u_x)), self.head_v(self.ln_v(v_x)) diff --git a/diffusion_policy/policy/pmf_transformer_hybrid_image_policy.py b/diffusion_policy/policy/pmf_transformer_hybrid_image_policy.py index 66afc6c..3ce50e0 100644 --- a/diffusion_policy/policy/pmf_transformer_hybrid_image_policy.py +++ b/diffusion_policy/policy/pmf_transformer_hybrid_image_policy.py @@ -42,6 +42,7 @@ class PMFTransformerHybridImagePolicy(BaseImagePolicy): obs_as_cond=True, pred_action_steps_only=False, n_time_tokens=4, + n_head_layers=4, min_time=0.05, du_dt_epsilon=1.0e-3, pmf_u_loss_weight=1.0, @@ -155,6 +156,7 @@ class PMFTransformerHybridImagePolicy(BaseImagePolicy): obs_as_cond=obs_as_cond, n_cond_layers=n_cond_layers, n_time_tokens=n_time_tokens, + n_head_layers=n_head_layers, ) self.noise_scheduler = noise_scheduler self.mask_generator = LowdimMaskGenerator( diff --git a/image_pusht_diffusion_policy_dit_pmf.yaml b/image_pusht_diffusion_policy_dit_pmf.yaml index 4b1afb1..5ce495b 100644 --- a/image_pusht_diffusion_policy_dit_pmf.yaml +++ b/image_pusht_diffusion_policy_dit_pmf.yaml @@ -63,6 +63,7 @@ policy: n_emb: 256 n_head: 4 n_layer: 12 + n_head_layers: 4 n_obs_steps: 2 n_time_tokens: 4 noise_scale: 1.0