feat: align pmf transformer training and config defaults
This commit is contained in:
@@ -47,6 +47,12 @@ class PMFTransformerHybridImagePolicy(BaseImagePolicy):
|
||||
pmf_u_loss_weight=1.0,
|
||||
pmf_v_loss_weight=1.0,
|
||||
noise_scale=1.0,
|
||||
adatloss_eps=0.01,
|
||||
p_mean=-0.4,
|
||||
p_std=1.0,
|
||||
tr_uniform=True,
|
||||
tr_uniform_prob=0.1,
|
||||
data_proportion=0.5,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
@@ -171,6 +177,12 @@ class PMFTransformerHybridImagePolicy(BaseImagePolicy):
|
||||
self.pmf_u_loss_weight = pmf_u_loss_weight
|
||||
self.pmf_v_loss_weight = pmf_v_loss_weight
|
||||
self.noise_scale = noise_scale
|
||||
self.adatloss_eps = adatloss_eps
|
||||
self.p_mean = p_mean
|
||||
self.p_std = p_std
|
||||
self.tr_uniform = tr_uniform
|
||||
self.tr_uniform_prob = tr_uniform_prob
|
||||
self.data_proportion = data_proportion
|
||||
self.kwargs = kwargs
|
||||
|
||||
if num_inference_steps is None:
|
||||
@@ -185,6 +197,33 @@ class PMFTransformerHybridImagePolicy(BaseImagePolicy):
|
||||
def _time_view(self, value: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
|
||||
return value.reshape(value.shape[0], *([1] * (ref.ndim - 1)))
|
||||
|
||||
def _adatloss(self, loss: torch.Tensor) -> torch.Tensor:
|
||||
denom = loss.detach() + self.adatloss_eps
|
||||
return loss / denom
|
||||
|
||||
def _sample_logit_normal(self, batch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
|
||||
normal = torch.randn(batch_size, device=device, dtype=dtype)
|
||||
return torch.sigmoid(normal * self.p_std + self.p_mean)
|
||||
|
||||
def _sample_tr(self, batch_size: int, device: torch.device, dtype: torch.dtype):
|
||||
t = self._sample_logit_normal(batch_size, device, dtype)
|
||||
r = self._sample_logit_normal(batch_size, device, dtype)
|
||||
|
||||
if self.tr_uniform:
|
||||
uniform_mask = torch.rand(batch_size, device=device) < self.tr_uniform_prob
|
||||
uniform_t = torch.rand(batch_size, device=device, dtype=dtype)
|
||||
uniform_r = torch.rand(batch_size, device=device, dtype=dtype)
|
||||
t = torch.where(uniform_mask, uniform_t, t)
|
||||
r = torch.where(uniform_mask, uniform_r, r)
|
||||
|
||||
data_size = int(batch_size * self.data_proportion)
|
||||
fm_mask = torch.arange(batch_size, device=device) < data_size
|
||||
r = torch.where(fm_mask, t, r)
|
||||
|
||||
t_final = torch.maximum(t, r)
|
||||
r_final = torch.minimum(t, r)
|
||||
return t_final, r_final
|
||||
|
||||
def _trajectory_inputs(
|
||||
self,
|
||||
nobs: Dict[str, torch.Tensor],
|
||||
@@ -217,10 +256,9 @@ class PMFTransformerHybridImagePolicy(BaseImagePolicy):
|
||||
condition_data: torch.Tensor,
|
||||
condition_mask: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
if condition_mask.any():
|
||||
sample = sample.clone()
|
||||
sample[condition_mask] = condition_data[condition_mask]
|
||||
return sample
|
||||
if not condition_mask.any():
|
||||
return sample
|
||||
return torch.where(condition_mask, condition_data, sample)
|
||||
|
||||
def _compute_u_v(
|
||||
self,
|
||||
@@ -230,7 +268,7 @@ class PMFTransformerHybridImagePolicy(BaseImagePolicy):
|
||||
cond: torch.Tensor,
|
||||
):
|
||||
x_hat_u, x_hat_v = self.model(sample, t, r, cond)
|
||||
denom = self._time_view(t.clamp_min(self.min_time), sample)
|
||||
denom = self._time_view(t, sample)
|
||||
u = (sample - x_hat_u) / denom
|
||||
v = (sample - x_hat_v) / denom
|
||||
return u, v
|
||||
@@ -245,16 +283,31 @@ class PMFTransformerHybridImagePolicy(BaseImagePolicy):
|
||||
condition_mask: torch.Tensor,
|
||||
tangent_v: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
step = torch.full_like(t, self.du_dt_epsilon)
|
||||
sample_base = self._apply_conditioning(sample, condition_data, condition_mask)
|
||||
u_base, _ = self._compute_u_v(sample_base, t, r, cond)
|
||||
tangent_sample = tangent_v.detach()
|
||||
tangent_r = torch.zeros_like(r)
|
||||
tangent_t = torch.ones_like(t)
|
||||
|
||||
t_next = (t + step).clamp(max=1.0)
|
||||
sample_next = sample + self._time_view(step, sample) * tangent_v.detach()
|
||||
sample_next = self._apply_conditioning(sample_next, condition_data, condition_mask)
|
||||
u_next, _ = self._compute_u_v(sample_next, t_next, r, cond)
|
||||
def u_fn(sample_input, r_input, t_input):
|
||||
conditioned_sample = self._apply_conditioning(
|
||||
sample_input, condition_data, condition_mask
|
||||
)
|
||||
u_value, _ = self._compute_u_v(conditioned_sample, t_input, r_input, cond)
|
||||
return u_value
|
||||
|
||||
return (u_next - u_base) / self._time_view(step, sample)
|
||||
primals = (sample, r, t)
|
||||
tangents = (tangent_sample, tangent_r, tangent_t)
|
||||
try:
|
||||
_, du_dt = torch.func.jvp(u_fn, primals, tangents)
|
||||
except (AttributeError, NotImplementedError, RuntimeError):
|
||||
_, du_dt = torch.autograd.functional.jvp(
|
||||
u_fn,
|
||||
primals,
|
||||
tangents,
|
||||
create_graph=False,
|
||||
strict=False,
|
||||
)
|
||||
|
||||
return du_dt
|
||||
|
||||
# ========= inference ============
|
||||
def conditional_sample(
|
||||
@@ -368,14 +421,14 @@ class PMFTransformerHybridImagePolicy(BaseImagePolicy):
|
||||
noise = torch.randn_like(trajectory) * self.noise_scale
|
||||
batch_size = trajectory.shape[0]
|
||||
|
||||
t = torch.rand(batch_size, device=trajectory.device, dtype=trajectory.dtype)
|
||||
r = torch.rand(batch_size, device=trajectory.device, dtype=trajectory.dtype) * t
|
||||
t, r = self._sample_tr(
|
||||
batch_size, device=trajectory.device, dtype=trajectory.dtype
|
||||
)
|
||||
z_t = (1 - self._time_view(t, trajectory)) * trajectory + self._time_view(t, trajectory) * noise
|
||||
z_t = self._apply_conditioning(z_t, trajectory, condition_mask)
|
||||
|
||||
loss_mask = ~condition_mask
|
||||
denom = self._time_view(t.clamp_min(self.min_time), trajectory)
|
||||
target_v = (z_t - trajectory) / denom
|
||||
target_v = noise - trajectory
|
||||
|
||||
u, v = self._compute_u_v(z_t, t, r, cond)
|
||||
du_dt = self._compute_du_dt(
|
||||
@@ -395,4 +448,6 @@ class PMFTransformerHybridImagePolicy(BaseImagePolicy):
|
||||
loss_v = loss_v * loss_mask.type(loss_v.dtype)
|
||||
loss_u = reduce(loss_u, "b ... -> b (...)", "mean").mean()
|
||||
loss_v = reduce(loss_v, "b ... -> b (...)", "mean").mean()
|
||||
loss_u = self._adatloss(loss_u)
|
||||
loss_v = self._adatloss(loss_v)
|
||||
return self.pmf_u_loss_weight * loss_u + self.pmf_v_loss_weight * loss_v
|
||||
|
||||
@@ -58,15 +58,17 @@ policy:
|
||||
- 84
|
||||
eval_fixed_crop: true
|
||||
horizon: 16
|
||||
min_time: 0.05
|
||||
n_action_steps: 8
|
||||
n_cond_layers: 0
|
||||
n_emb: 256
|
||||
n_head: 4
|
||||
n_layer: 8
|
||||
n_layer: 12
|
||||
n_obs_steps: 2
|
||||
n_time_tokens: 4
|
||||
noise_scale: 1.0
|
||||
adatloss_eps: 0.01
|
||||
p_mean: -0.4
|
||||
p_std: 1.0
|
||||
noise_scheduler:
|
||||
_target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler
|
||||
beta_end: 0.02
|
||||
@@ -76,13 +78,16 @@ policy:
|
||||
num_train_timesteps: 100
|
||||
prediction_type: sample
|
||||
variance_type: fixed_small
|
||||
num_inference_steps: 32
|
||||
num_inference_steps: 1
|
||||
obs_as_cond: true
|
||||
obs_encoder_group_norm: true
|
||||
p_drop_attn: 0.0
|
||||
p_drop_emb: 0.0
|
||||
pmf_u_loss_weight: 1.0
|
||||
pmf_v_loss_weight: 1.0
|
||||
tr_uniform: true
|
||||
tr_uniform_prob: 0.1
|
||||
data_proportion: 0.5
|
||||
shape_meta:
|
||||
action:
|
||||
shape:
|
||||
|
||||
Reference in New Issue
Block a user