1 Commit

Author   SHA1         Message                                      Date
Logic    9169e4d7e0   feat(pusht): add dual-head uv transformer    2026-03-17 17:05:02 +08:00
3 changed files with 46 additions and 6 deletions

View File

@@ -27,6 +27,7 @@ class PMFTransformerForDiffusion(ModuleAttrMixin):
         obs_as_cond: bool = False,
         n_cond_layers: int = 0,
         n_time_tokens: int = 4,
+        n_head_layers: int = 4,
     ) -> None:
         super().__init__()
@@ -34,11 +35,18 @@ class PMFTransformerForDiffusion(ModuleAttrMixin):
             n_obs_steps = horizon
         if n_time_tokens < 1:
             raise ValueError("n_time_tokens must be >= 1")
+        if n_head_layers < 0:
+            raise ValueError("n_head_layers must be >= 0")
+        if n_head_layers >= n_layer:
+            raise ValueError(
+                "n_head_layers must be smaller than n_layer so shared trunk depth stays positive"
+            )
         obs_as_cond = cond_dim > 0
         T = horizon
         n_global_cond_tokens = 2 * n_time_tokens
         T_cond = n_global_cond_tokens + (n_obs_steps if obs_as_cond else 0)
+        n_shared_layers = n_layer - n_head_layers

         self.input_emb = nn.Linear(input_dim, n_emb)
         self.pos_emb = nn.Parameter(torch.zeros(1, T, n_emb))
@@ -81,9 +89,17 @@ class PMFTransformerForDiffusion(ModuleAttrMixin):
             batch_first=True,
             norm_first=True,
         )
-        self.decoder = nn.TransformerDecoder(
+        self.shared_decoder = nn.TransformerDecoder(
             decoder_layer=decoder_layer,
-            num_layers=n_layer,
+            num_layers=n_shared_layers,
+        )
+        self.u_decoder = nn.TransformerDecoder(
+            decoder_layer=decoder_layer,
+            num_layers=n_head_layers,
+        )
+        self.v_decoder = nn.TransformerDecoder(
+            decoder_layer=decoder_layer,
+            num_layers=n_head_layers,
         )

         if causal_attn:
@@ -109,7 +125,8 @@ class PMFTransformerForDiffusion(ModuleAttrMixin):
             self.mask = None
             self.memory_mask = None

-        self.ln_f = nn.LayerNorm(n_emb)
+        self.ln_u = nn.LayerNorm(n_emb)
+        self.ln_v = nn.LayerNorm(n_emb)
         self.head_u = nn.Linear(n_emb, output_dim)
         self.head_v = nn.Linear(n_emb, output_dim)
@@ -120,11 +137,20 @@ class PMFTransformerForDiffusion(ModuleAttrMixin):
         self.obs_as_cond = obs_as_cond
         self.n_global_cond_tokens = n_global_cond_tokens
         self.n_time_tokens = n_time_tokens
+        self.n_layer = n_layer
+        self.n_head_layers = n_head_layers
+        self.n_shared_layers = n_shared_layers

         self.apply(self._init_weights)
         logger.info(
             "number of parameters: %e", sum(p.numel() for p in self.parameters())
         )
+        logger.info(
+            "PMFTransformerForDiffusion layers: shared=%d u_head=%d v_head=%d",
+            self.n_shared_layers,
+            self.n_head_layers,
+            self.n_head_layers,
+        )

     def _init_weights(self, module):
         ignore_types = (
@@ -255,11 +281,22 @@ class PMFTransformerForDiffusion(ModuleAttrMixin):
         token_pos = self.pos_emb[:, : input_emb.shape[1], :]
         x = self.drop(input_emb + token_pos)

-        x = self.decoder(
+        shared_x = self.shared_decoder(
             tgt=x,
             memory=memory,
             tgt_mask=self.mask,
             memory_mask=self.memory_mask,
         )
-        x = self.ln_f(x)
-        return self.head_u(x), self.head_v(x)
+        u_x = self.u_decoder(
+            tgt=shared_x,
+            memory=memory,
+            tgt_mask=self.mask,
+            memory_mask=self.memory_mask,
+        )
+        v_x = self.v_decoder(
+            tgt=shared_x,
+            memory=memory,
+            tgt_mask=self.mask,
+            memory_mask=self.memory_mask,
+        )
+        return self.head_u(self.ln_u(u_x)), self.head_v(self.ln_v(v_x))
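For orientation, here is a minimal, self-contained sketch of the trunk-plus-dual-head pattern this hunk introduces. The dimensions, dummy inputs, and output_dim=2 are illustrative assumptions, not values from the repo; the real model additionally applies embeddings, dropout, and the causal masks elided here.

import torch
import torch.nn as nn

# Illustrative sizes; n_emb and n_head mirror the config below, the rest are assumptions.
n_emb, n_head, n_layer, n_head_layers, output_dim = 256, 4, 12, 4, 2
n_shared_layers = n_layer - n_head_layers  # 8-layer shared trunk

layer = nn.TransformerDecoderLayer(
    d_model=n_emb, nhead=n_head, batch_first=True, norm_first=True
)
# nn.TransformerDecoder deep-copies `layer`, so the three stacks get independent weights.
shared_decoder = nn.TransformerDecoder(layer, num_layers=n_shared_layers)
u_decoder = nn.TransformerDecoder(layer, num_layers=n_head_layers)
v_decoder = nn.TransformerDecoder(layer, num_layers=n_head_layers)
ln_u, ln_v = nn.LayerNorm(n_emb), nn.LayerNorm(n_emb)
head_u, head_v = nn.Linear(n_emb, output_dim), nn.Linear(n_emb, output_dim)

x = torch.randn(1, 16, n_emb)       # (batch, horizon, n_emb) action-token embeddings
memory = torch.randn(1, 10, n_emb)  # conditioning tokens (time + obs in the real model)

shared_x = shared_decoder(tgt=x, memory=memory)  # one pass through the trunk
u = head_u(ln_u(u_decoder(tgt=shared_x, memory=memory)))
v = head_v(ln_v(v_decoder(tgt=shared_x, memory=memory)))
print(u.shape, v.shape)  # torch.Size([1, 16, 2]) twice

Because nn.TransformerDecoder clones its layer argument, the trunk and both heads train separate weights. Note also that the guard only rejects negative values, so n_head_layers=0 appears to be legal and would collapse both heads to empty stacks that differ only in their final LayerNorm and Linear projections.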

View File

@@ -42,6 +42,7 @@ class PMFTransformerHybridImagePolicy(BaseImagePolicy):
             obs_as_cond=True,
             pred_action_steps_only=False,
             n_time_tokens=4,
+            n_head_layers=4,
             min_time=0.05,
             du_dt_epsilon=1.0e-3,
             pmf_u_loss_weight=1.0,
@@ -155,6 +156,7 @@ class PMFTransformerHybridImagePolicy(BaseImagePolicy):
             obs_as_cond=obs_as_cond,
             n_cond_layers=n_cond_layers,
             n_time_tokens=n_time_tokens,
+            n_head_layers=n_head_layers,
         )
         self.noise_scheduler = noise_scheduler
         self.mask_generator = LowdimMaskGenerator(
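The policy-side change is pure plumbing: the new n_head_layers kwarg is threaded straight through to the PMFTransformerForDiffusion constructor. One thing worth noticing is that the default is 4 rather than 0, so existing configs that do not set the key will silently switch from a single n_layer-deep decoder to the new trunk-plus-heads split.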

View File

@@ -63,6 +63,7 @@ policy:
   n_emb: 256
   n_head: 4
   n_layer: 12
+  n_head_layers: 4
   n_obs_steps: 2
   n_time_tokens: 4
   noise_scale: 1.0
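With the values above (n_layer: 12, n_head_layers: 4), the model builds a shared trunk of 12 - 4 = 8 layers plus two 4-layer heads, i.e. 16 decoder layers in total, so this config is somewhat larger than the single 12-layer decoder it replaces.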