Compare commits: DiT-imageP...dualhead_u (1 commit: 9169e4d7e0)

This commit splits the model's single `nn.TransformerDecoder` into a shared trunk plus two separate decoder heads (`u_decoder`, `v_decoder`), each with its own final LayerNorm, controlled by a new `n_head_layers` parameter that is threaded through the image policy and the training config.
```diff
@@ -27,6 +27,7 @@ class PMFTransformerForDiffusion(ModuleAttrMixin):
         obs_as_cond: bool = False,
         n_cond_layers: int = 0,
         n_time_tokens: int = 4,
+        n_head_layers: int = 4,
     ) -> None:
         super().__init__()
 
```
```diff
@@ -34,11 +35,18 @@ class PMFTransformerForDiffusion(ModuleAttrMixin):
         n_obs_steps = horizon
         if n_time_tokens < 1:
             raise ValueError("n_time_tokens must be >= 1")
+        if n_head_layers < 0:
+            raise ValueError("n_head_layers must be >= 0")
+        if n_head_layers >= n_layer:
+            raise ValueError(
+                "n_head_layers must be smaller than n_layer so shared trunk depth stays positive"
+            )
 
         obs_as_cond = cond_dim > 0
         T = horizon
         n_global_cond_tokens = 2 * n_time_tokens
         T_cond = n_global_cond_tokens + (n_obs_steps if obs_as_cond else 0)
+        n_shared_layers = n_layer - n_head_layers
 
         self.input_emb = nn.Linear(input_dim, n_emb)
         self.pos_emb = nn.Parameter(torch.zeros(1, T, n_emb))
```
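For intuition, the new guards just pin down a positive trunk depth: whatever `n_layer` leaves after the heads becomes the shared trunk. A minimal sketch of the rule (the helper name is illustrative, not from the repo):

```python
def split_depth(n_layer: int, n_head_layers: int) -> int:
    # Illustrative re-statement of the constructor checks above.
    if n_head_layers < 0:
        raise ValueError("n_head_layers must be >= 0")
    if n_head_layers >= n_layer:
        raise ValueError(
            "n_head_layers must be smaller than n_layer "
            "so shared trunk depth stays positive"
        )
    return n_layer - n_head_layers  # n_shared_layers

assert split_depth(12, 4) == 8  # the values used in the config hunk below
# split_depth(12, 12) would raise: no layers would remain for the trunk.
```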
```diff
@@ -81,9 +89,17 @@ class PMFTransformerForDiffusion(ModuleAttrMixin):
             batch_first=True,
             norm_first=True,
         )
-        self.decoder = nn.TransformerDecoder(
+        self.shared_decoder = nn.TransformerDecoder(
             decoder_layer=decoder_layer,
-            num_layers=n_layer,
+            num_layers=n_shared_layers,
         )
+        self.u_decoder = nn.TransformerDecoder(
+            decoder_layer=decoder_layer,
+            num_layers=n_head_layers,
+        )
+        self.v_decoder = nn.TransformerDecoder(
+            decoder_layer=decoder_layer,
+            num_layers=n_head_layers,
+        )
 
         if causal_attn:
```
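A detail worth noting about this construction: all three stacks are built from the same `decoder_layer` template, which is safe because `nn.TransformerDecoder` deep-copies the layer internally, so `shared_decoder`, `u_decoder`, and `v_decoder` get independent weights. A quick standalone check, with sizes assumed from the config hunk below:

```python
import torch.nn as nn

layer = nn.TransformerDecoderLayer(
    d_model=256, nhead=4, batch_first=True, norm_first=True
)
u_dec = nn.TransformerDecoder(decoder_layer=layer, num_layers=4)
v_dec = nn.TransformerDecoder(decoder_layer=layer, num_layers=4)

# nn.TransformerDecoder deep-copies decoder_layer num_layers times,
# so the two heads share no parameters despite the single template.
u_params = set(id(p) for p in u_dec.parameters())
v_params = set(id(p) for p in v_dec.parameters())
assert u_params.isdisjoint(v_params)
```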
```diff
@@ -109,7 +125,8 @@ class PMFTransformerForDiffusion(ModuleAttrMixin):
             self.mask = None
             self.memory_mask = None
 
-        self.ln_f = nn.LayerNorm(n_emb)
+        self.ln_u = nn.LayerNorm(n_emb)
+        self.ln_v = nn.LayerNorm(n_emb)
         self.head_u = nn.Linear(n_emb, output_dim)
         self.head_v = nn.Linear(n_emb, output_dim)
 
```
```diff
@@ -120,11 +137,20 @@ class PMFTransformerForDiffusion(ModuleAttrMixin):
         self.obs_as_cond = obs_as_cond
         self.n_global_cond_tokens = n_global_cond_tokens
         self.n_time_tokens = n_time_tokens
+        self.n_layer = n_layer
+        self.n_head_layers = n_head_layers
+        self.n_shared_layers = n_shared_layers
 
         self.apply(self._init_weights)
         logger.info(
            "number of parameters: %e", sum(p.numel() for p in self.parameters())
         )
+        logger.info(
+            "PMFTransformerForDiffusion layers: shared=%d u_head=%d v_head=%d",
+            self.n_shared_layers,
+            self.n_head_layers,
+            self.n_head_layers,
+        )
 
     def _init_weights(self, module):
         ignore_types = (
```
```diff
@@ -255,11 +281,22 @@ class PMFTransformerForDiffusion(ModuleAttrMixin):
 
         token_pos = self.pos_emb[:, : input_emb.shape[1], :]
         x = self.drop(input_emb + token_pos)
-        x = self.decoder(
+        shared_x = self.shared_decoder(
             tgt=x,
             memory=memory,
             tgt_mask=self.mask,
             memory_mask=self.memory_mask,
         )
-        x = self.ln_f(x)
-        return self.head_u(x), self.head_v(x)
+        u_x = self.u_decoder(
+            tgt=shared_x,
+            memory=memory,
+            tgt_mask=self.mask,
+            memory_mask=self.memory_mask,
+        )
+        v_x = self.v_decoder(
+            tgt=shared_x,
+            memory=memory,
+            tgt_mask=self.mask,
+            memory_mask=self.memory_mask,
+        )
+        return self.head_u(self.ln_u(u_x)), self.head_v(self.ln_v(v_x))
```
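Reduced to a standalone sketch, the new forward wiring looks like the following (toy sizes; `memory` stands in for the conditioning tokens built earlier in `forward`, and the causal masks are omitted). Note that both heads re-attend to `memory`, not just to the trunk output:

```python
import torch
import torch.nn as nn

n_emb, output_dim = 256, 10
layer = nn.TransformerDecoderLayer(d_model=n_emb, nhead=4,
                                   batch_first=True, norm_first=True)
shared_decoder = nn.TransformerDecoder(layer, num_layers=8)
u_decoder = nn.TransformerDecoder(layer, num_layers=4)
v_decoder = nn.TransformerDecoder(layer, num_layers=4)
ln_u, ln_v = nn.LayerNorm(n_emb), nn.LayerNorm(n_emb)
head_u = nn.Linear(n_emb, output_dim)
head_v = nn.Linear(n_emb, output_dim)

x = torch.randn(2, 16, n_emb)       # embedded action tokens (B, T, n_emb)
memory = torch.randn(2, 10, n_emb)  # conditioning tokens (B, T_cond, n_emb)

shared_x = shared_decoder(tgt=x, memory=memory)
u = head_u(ln_u(u_decoder(tgt=shared_x, memory=memory)))
v = head_v(ln_v(v_decoder(tgt=shared_x, memory=memory)))
assert u.shape == v.shape == (2, 16, output_dim)
```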
```diff
@@ -42,6 +42,7 @@ class PMFTransformerHybridImagePolicy(BaseImagePolicy):
         obs_as_cond=True,
         pred_action_steps_only=False,
         n_time_tokens=4,
+        n_head_layers=4,
         min_time=0.05,
         du_dt_epsilon=1.0e-3,
         pmf_u_loss_weight=1.0,
```
```diff
@@ -155,6 +156,7 @@ class PMFTransformerHybridImagePolicy(BaseImagePolicy):
             obs_as_cond=obs_as_cond,
             n_cond_layers=n_cond_layers,
             n_time_tokens=n_time_tokens,
+            n_head_layers=n_head_layers,
         )
         self.noise_scheduler = noise_scheduler
         self.mask_generator = LowdimMaskGenerator(
```
```diff
@@ -63,6 +63,7 @@ policy:
   n_emb: 256
   n_head: 4
   n_layer: 12
+  n_head_layers: 4
   n_obs_steps: 2
   n_time_tokens: 4
   noise_scale: 1.0
```
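A quick consistency check on the new key: with `n_layer: 12` and `n_head_layers: 4`, the constructor guards added in this commit pass, and the startup log should report `shared=8 u_head=4 v_head=4`:

```python
n_layer, n_head_layers = 12, 4           # values from the config hunk above
assert 0 <= n_head_layers < n_layer      # passes the constructor checks
print(f"shared={n_layer - n_head_layers} "
      f"u_head={n_head_layers} v_head={n_head_layers}")
```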