feat(pusht): add dual-head uv transformer
@@ -27,6 +27,7 @@ class PMFTransformerForDiffusion(ModuleAttrMixin):
         obs_as_cond: bool = False,
         n_cond_layers: int = 0,
         n_time_tokens: int = 4,
+        n_head_layers: int = 4,
     ) -> None:
         super().__init__()
 
@@ -34,11 +35,18 @@ class PMFTransformerForDiffusion(ModuleAttrMixin):
             n_obs_steps = horizon
         if n_time_tokens < 1:
             raise ValueError("n_time_tokens must be >= 1")
+        if n_head_layers < 0:
+            raise ValueError("n_head_layers must be >= 0")
+        if n_head_layers >= n_layer:
+            raise ValueError(
+                "n_head_layers must be smaller than n_layer so shared trunk depth stays positive"
+            )
 
         obs_as_cond = cond_dim > 0
         T = horizon
         n_global_cond_tokens = 2 * n_time_tokens
         T_cond = n_global_cond_tokens + (n_obs_steps if obs_as_cond else 0)
+        n_shared_layers = n_layer - n_head_layers
 
         self.input_emb = nn.Linear(input_dim, n_emb)
         self.pos_emb = nn.Parameter(torch.zeros(1, T, n_emb))
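Note: the two new guards bound the head depth from both sides, and the trunk keeps whatever depth remains. A minimal sketch of the boundary behaviour, assuming the default n_layer=12 from the config hunk below (the loop and prints are illustrative, not part of the diff):

    # mirrors the guards added above; n_layer=12 is taken from the config hunk below
    n_layer = 12
    for n_head_layers in (-1, 0, 4, 12):
        ok = 0 <= n_head_layers < n_layer
        # -1 -> raises "n_head_layers must be >= 0"
        # 12 -> raises "n_head_layers must be smaller than n_layer ..."
        #  0 -> valid: zero-layer head stacks, all 12 layers stay in the trunk
        #  4 -> valid: 12 - 4 = 8 shared layers
        print(n_head_layers, "ok" if ok else "ValueError")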
@@ -81,9 +89,17 @@ class PMFTransformerForDiffusion(ModuleAttrMixin):
             batch_first=True,
             norm_first=True,
         )
-        self.decoder = nn.TransformerDecoder(
+        self.shared_decoder = nn.TransformerDecoder(
             decoder_layer=decoder_layer,
-            num_layers=n_layer,
+            num_layers=n_shared_layers,
+        )
+        self.u_decoder = nn.TransformerDecoder(
+            decoder_layer=decoder_layer,
+            num_layers=n_head_layers,
+        )
+        self.v_decoder = nn.TransformerDecoder(
+            decoder_layer=decoder_layer,
+            num_layers=n_head_layers,
         )
 
         if causal_attn:
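For readers skimming the diff: the single decoder stack becomes a shared trunk plus two parallel head stacks. A self-contained sketch of that construction, with n_emb=256 and n_head=4 taken from the config hunk below and the depths following from the defaults n_layer=12, n_head_layers=4. One detail worth knowing: nn.TransformerDecoder deep-copies decoder_layer for each of its layers, so passing the same layer object to all three stacks does not make them share weights.

    import torch.nn as nn

    n_emb, n_head = 256, 4  # values from the config hunk below
    decoder_layer = nn.TransformerDecoderLayer(
        d_model=n_emb, nhead=n_head, batch_first=True, norm_first=True
    )
    # 8 shared + 4 + 4 = 16 decoder layers in total (vs. 12 before this commit)
    shared_decoder = nn.TransformerDecoder(decoder_layer=decoder_layer, num_layers=8)
    u_decoder = nn.TransformerDecoder(decoder_layer=decoder_layer, num_layers=4)
    v_decoder = nn.TransformerDecoder(decoder_layer=decoder_layer, num_layers=4)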
@@ -109,7 +125,8 @@ class PMFTransformerForDiffusion(ModuleAttrMixin):
             self.mask = None
             self.memory_mask = None
 
-        self.ln_f = nn.LayerNorm(n_emb)
+        self.ln_u = nn.LayerNorm(n_emb)
+        self.ln_v = nn.LayerNorm(n_emb)
         self.head_u = nn.Linear(n_emb, output_dim)
         self.head_v = nn.Linear(n_emb, output_dim)
 
@@ -120,11 +137,20 @@ class PMFTransformerForDiffusion(ModuleAttrMixin):
         self.obs_as_cond = obs_as_cond
         self.n_global_cond_tokens = n_global_cond_tokens
         self.n_time_tokens = n_time_tokens
+        self.n_layer = n_layer
+        self.n_head_layers = n_head_layers
+        self.n_shared_layers = n_shared_layers
 
         self.apply(self._init_weights)
         logger.info(
             "number of parameters: %e", sum(p.numel() for p in self.parameters())
         )
+        logger.info(
+            "PMFTransformerForDiffusion layers: shared=%d u_head=%d v_head=%d",
+            self.n_shared_layers,
+            self.n_head_layers,
+            self.n_head_layers,
+        )
 
     def _init_weights(self, module):
         ignore_types = (
@@ -255,11 +281,22 @@ class PMFTransformerForDiffusion(ModuleAttrMixin):
 
         token_pos = self.pos_emb[:, : input_emb.shape[1], :]
         x = self.drop(input_emb + token_pos)
-        x = self.decoder(
+        shared_x = self.shared_decoder(
             tgt=x,
             memory=memory,
             tgt_mask=self.mask,
             memory_mask=self.memory_mask,
         )
-        x = self.ln_f(x)
-        return self.head_u(x), self.head_v(x)
+        u_x = self.u_decoder(
+            tgt=shared_x,
+            memory=memory,
+            tgt_mask=self.mask,
+            memory_mask=self.memory_mask,
+        )
+        v_x = self.v_decoder(
+            tgt=shared_x,
+            memory=memory,
+            tgt_mask=self.mask,
+            memory_mask=self.memory_mask,
+        )
+        return self.head_u(self.ln_u(u_x)), self.head_v(self.ln_v(v_x))
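A shape-level sketch of the rewired forward pass, reusing shared_decoder, u_decoder, and v_decoder from the construction sketch above. Batch size, horizon, and conditioning length are made-up values, and the tgt_mask/memory_mask arguments are omitted for brevity:

    import torch

    B, T, T_cond = 2, 10, 12  # illustrative shapes only
    x = torch.randn(B, T, n_emb)            # embedded, position-encoded action tokens
    memory = torch.randn(B, T_cond, n_emb)  # encoder output (conditioning tokens)

    shared_x = shared_decoder(tgt=x, memory=memory)  # trunk runs once
    u_x = u_decoder(tgt=shared_x, memory=memory)     # u-specific refinement
    v_x = v_decoder(tgt=shared_x, memory=memory)     # v-specific refinement
    # each branch then gets its own LayerNorm and linear head (ln_u/head_u, ln_v/head_v),
    # replacing the former single ln_f shared by both output heads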
@@ -42,6 +42,7 @@ class PMFTransformerHybridImagePolicy(BaseImagePolicy):
         obs_as_cond=True,
         pred_action_steps_only=False,
         n_time_tokens=4,
+        n_head_layers=4,
         min_time=0.05,
         du_dt_epsilon=1.0e-3,
         pmf_u_loss_weight=1.0,
@@ -155,6 +156,7 @@ class PMFTransformerHybridImagePolicy(BaseImagePolicy):
             obs_as_cond=obs_as_cond,
             n_cond_layers=n_cond_layers,
             n_time_tokens=n_time_tokens,
+            n_head_layers=n_head_layers,
         )
         self.noise_scheduler = noise_scheduler
         self.mask_generator = LowdimMaskGenerator(
@@ -63,6 +63,7 @@ policy:
   n_emb: 256
   n_head: 4
   n_layer: 12
+  n_head_layers: 4
   n_obs_steps: 2
   n_time_tokens: 4
   noise_scale: 1.0
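With these defaults the trunk keeps 12 - 4 = 8 layers while each head adds 4 on top, so the committed model instantiates 16 decoder layers in total; the change increases capacity rather than merely splitting the old 12-layer stack. A quick check of that arithmetic (values copied from this config):

    n_layer, n_head_layers = 12, 4               # values from the config above
    n_shared_layers = n_layer - n_head_layers    # 8 trunk layers
    total = n_shared_layers + 2 * n_head_layers  # 16 layers instantiated
    assert (n_shared_layers, total) == (8, 16)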