feat: align pmf transformer training and config defaults
This commit is contained in:
@@ -47,6 +47,12 @@ class PMFTransformerHybridImagePolicy(BaseImagePolicy):
|
|||||||
pmf_u_loss_weight=1.0,
|
pmf_u_loss_weight=1.0,
|
||||||
pmf_v_loss_weight=1.0,
|
pmf_v_loss_weight=1.0,
|
||||||
noise_scale=1.0,
|
noise_scale=1.0,
|
||||||
|
adatloss_eps=0.01,
|
||||||
|
p_mean=-0.4,
|
||||||
|
p_std=1.0,
|
||||||
|
tr_uniform=True,
|
||||||
|
tr_uniform_prob=0.1,
|
||||||
|
data_proportion=0.5,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -171,6 +177,12 @@ class PMFTransformerHybridImagePolicy(BaseImagePolicy):
|
|||||||
self.pmf_u_loss_weight = pmf_u_loss_weight
|
self.pmf_u_loss_weight = pmf_u_loss_weight
|
||||||
self.pmf_v_loss_weight = pmf_v_loss_weight
|
self.pmf_v_loss_weight = pmf_v_loss_weight
|
||||||
self.noise_scale = noise_scale
|
self.noise_scale = noise_scale
|
||||||
|
self.adatloss_eps = adatloss_eps
|
||||||
|
self.p_mean = p_mean
|
||||||
|
self.p_std = p_std
|
||||||
|
self.tr_uniform = tr_uniform
|
||||||
|
self.tr_uniform_prob = tr_uniform_prob
|
||||||
|
self.data_proportion = data_proportion
|
||||||
self.kwargs = kwargs
|
self.kwargs = kwargs
|
||||||
|
|
||||||
if num_inference_steps is None:
|
if num_inference_steps is None:
|
||||||
@@ -185,6 +197,33 @@ class PMFTransformerHybridImagePolicy(BaseImagePolicy):
|
|||||||
def _time_view(self, value: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
|
def _time_view(self, value: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
|
||||||
return value.reshape(value.shape[0], *([1] * (ref.ndim - 1)))
|
return value.reshape(value.shape[0], *([1] * (ref.ndim - 1)))
|
||||||
|
|
||||||
|
def _adatloss(self, loss: torch.Tensor) -> torch.Tensor:
|
||||||
|
denom = loss.detach() + self.adatloss_eps
|
||||||
|
return loss / denom
|
||||||
|
|
||||||
|
def _sample_logit_normal(self, batch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
|
||||||
|
normal = torch.randn(batch_size, device=device, dtype=dtype)
|
||||||
|
return torch.sigmoid(normal * self.p_std + self.p_mean)
|
||||||
|
|
||||||
|
def _sample_tr(self, batch_size: int, device: torch.device, dtype: torch.dtype):
    """Sample a pair of per-example time tensors ``(t, r)`` with ``t >= r``.

    Steps, in order (the order of random draws is part of the behavior):

    1. Draw two independent candidates from ``_sample_logit_normal``.
    2. If ``self.tr_uniform`` is set, replace both candidates with
       ``torch.rand`` draws wherever a per-example random mask
       (probability ``self.tr_uniform_prob``) is true.
    3. For the first ``int(batch_size * self.data_proportion)`` examples,
       force the second candidate to equal the first.
    4. Order each pair elementwise so the larger value comes first.

    Args:
        batch_size: Number of pairs to sample.
        device: Device for all created tensors.
        dtype: Dtype of the sampled times.

    Returns:
        Tuple ``(t_final, r_final)`` of 1-D tensors of length ``batch_size``
        where ``t_final`` is the elementwise maximum and ``r_final`` the
        elementwise minimum of the two candidates.
    """
    first = self._sample_logit_normal(batch_size, device, dtype)
    second = self._sample_logit_normal(batch_size, device, dtype)

    if self.tr_uniform:
        # The mask draw uses the default dtype; only the replacement values use ``dtype``.
        swap_to_uniform = torch.rand(batch_size, device=device) < self.tr_uniform_prob
        first_uniform = torch.rand(batch_size, device=device, dtype=dtype)
        second_uniform = torch.rand(batch_size, device=device, dtype=dtype)
        first = torch.where(swap_to_uniform, first_uniform, first)
        second = torch.where(swap_to_uniform, second_uniform, second)

    # Leading slice of the batch where both times coincide.
    # NOTE(review): presumably the plain flow-matching (r == t) share of the batch — confirm.
    tied_count = int(batch_size * self.data_proportion)
    tied = torch.arange(batch_size, device=device) < tied_count
    second = torch.where(tied, first, second)

    upper = torch.maximum(first, second)
    lower = torch.minimum(first, second)
    return upper, lower
|
||||||
|
|
||||||
def _trajectory_inputs(
|
def _trajectory_inputs(
|
||||||
self,
|
self,
|
||||||
nobs: Dict[str, torch.Tensor],
|
nobs: Dict[str, torch.Tensor],
|
||||||
@@ -217,10 +256,9 @@ class PMFTransformerHybridImagePolicy(BaseImagePolicy):
|
|||||||
condition_data: torch.Tensor,
|
condition_data: torch.Tensor,
|
||||||
condition_mask: torch.Tensor,
|
condition_mask: torch.Tensor,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
if condition_mask.any():
|
if not condition_mask.any():
|
||||||
sample = sample.clone()
|
|
||||||
sample[condition_mask] = condition_data[condition_mask]
|
|
||||||
return sample
|
return sample
|
||||||
|
return torch.where(condition_mask, condition_data, sample)
|
||||||
|
|
||||||
def _compute_u_v(
|
def _compute_u_v(
|
||||||
self,
|
self,
|
||||||
@@ -230,7 +268,7 @@ class PMFTransformerHybridImagePolicy(BaseImagePolicy):
|
|||||||
cond: torch.Tensor,
|
cond: torch.Tensor,
|
||||||
):
|
):
|
||||||
x_hat_u, x_hat_v = self.model(sample, t, r, cond)
|
x_hat_u, x_hat_v = self.model(sample, t, r, cond)
|
||||||
denom = self._time_view(t.clamp_min(self.min_time), sample)
|
denom = self._time_view(t, sample)
|
||||||
u = (sample - x_hat_u) / denom
|
u = (sample - x_hat_u) / denom
|
||||||
v = (sample - x_hat_v) / denom
|
v = (sample - x_hat_v) / denom
|
||||||
return u, v
|
return u, v
|
||||||
@@ -245,16 +283,31 @@ class PMFTransformerHybridImagePolicy(BaseImagePolicy):
|
|||||||
condition_mask: torch.Tensor,
|
condition_mask: torch.Tensor,
|
||||||
tangent_v: torch.Tensor,
|
tangent_v: torch.Tensor,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
step = torch.full_like(t, self.du_dt_epsilon)
|
tangent_sample = tangent_v.detach()
|
||||||
sample_base = self._apply_conditioning(sample, condition_data, condition_mask)
|
tangent_r = torch.zeros_like(r)
|
||||||
u_base, _ = self._compute_u_v(sample_base, t, r, cond)
|
tangent_t = torch.ones_like(t)
|
||||||
|
|
||||||
t_next = (t + step).clamp(max=1.0)
|
def u_fn(sample_input, r_input, t_input):
|
||||||
sample_next = sample + self._time_view(step, sample) * tangent_v.detach()
|
conditioned_sample = self._apply_conditioning(
|
||||||
sample_next = self._apply_conditioning(sample_next, condition_data, condition_mask)
|
sample_input, condition_data, condition_mask
|
||||||
u_next, _ = self._compute_u_v(sample_next, t_next, r, cond)
|
)
|
||||||
|
u_value, _ = self._compute_u_v(conditioned_sample, t_input, r_input, cond)
|
||||||
|
return u_value
|
||||||
|
|
||||||
return (u_next - u_base) / self._time_view(step, sample)
|
primals = (sample, r, t)
|
||||||
|
tangents = (tangent_sample, tangent_r, tangent_t)
|
||||||
|
try:
|
||||||
|
_, du_dt = torch.func.jvp(u_fn, primals, tangents)
|
||||||
|
except (AttributeError, NotImplementedError, RuntimeError):
|
||||||
|
_, du_dt = torch.autograd.functional.jvp(
|
||||||
|
u_fn,
|
||||||
|
primals,
|
||||||
|
tangents,
|
||||||
|
create_graph=False,
|
||||||
|
strict=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
return du_dt
|
||||||
|
|
||||||
# ========= inference ============
|
# ========= inference ============
|
||||||
def conditional_sample(
|
def conditional_sample(
|
||||||
@@ -368,14 +421,14 @@ class PMFTransformerHybridImagePolicy(BaseImagePolicy):
|
|||||||
noise = torch.randn_like(trajectory) * self.noise_scale
|
noise = torch.randn_like(trajectory) * self.noise_scale
|
||||||
batch_size = trajectory.shape[0]
|
batch_size = trajectory.shape[0]
|
||||||
|
|
||||||
t = torch.rand(batch_size, device=trajectory.device, dtype=trajectory.dtype)
|
t, r = self._sample_tr(
|
||||||
r = torch.rand(batch_size, device=trajectory.device, dtype=trajectory.dtype) * t
|
batch_size, device=trajectory.device, dtype=trajectory.dtype
|
||||||
|
)
|
||||||
z_t = (1 - self._time_view(t, trajectory)) * trajectory + self._time_view(t, trajectory) * noise
|
z_t = (1 - self._time_view(t, trajectory)) * trajectory + self._time_view(t, trajectory) * noise
|
||||||
z_t = self._apply_conditioning(z_t, trajectory, condition_mask)
|
z_t = self._apply_conditioning(z_t, trajectory, condition_mask)
|
||||||
|
|
||||||
loss_mask = ~condition_mask
|
loss_mask = ~condition_mask
|
||||||
denom = self._time_view(t.clamp_min(self.min_time), trajectory)
|
target_v = noise - trajectory
|
||||||
target_v = (z_t - trajectory) / denom
|
|
||||||
|
|
||||||
u, v = self._compute_u_v(z_t, t, r, cond)
|
u, v = self._compute_u_v(z_t, t, r, cond)
|
||||||
du_dt = self._compute_du_dt(
|
du_dt = self._compute_du_dt(
|
||||||
@@ -395,4 +448,6 @@ class PMFTransformerHybridImagePolicy(BaseImagePolicy):
|
|||||||
loss_v = loss_v * loss_mask.type(loss_v.dtype)
|
loss_v = loss_v * loss_mask.type(loss_v.dtype)
|
||||||
loss_u = reduce(loss_u, "b ... -> b (...)", "mean").mean()
|
loss_u = reduce(loss_u, "b ... -> b (...)", "mean").mean()
|
||||||
loss_v = reduce(loss_v, "b ... -> b (...)", "mean").mean()
|
loss_v = reduce(loss_v, "b ... -> b (...)", "mean").mean()
|
||||||
|
loss_u = self._adatloss(loss_u)
|
||||||
|
loss_v = self._adatloss(loss_v)
|
||||||
return self.pmf_u_loss_weight * loss_u + self.pmf_v_loss_weight * loss_v
|
return self.pmf_u_loss_weight * loss_u + self.pmf_v_loss_weight * loss_v
|
||||||
|
|||||||
@@ -58,15 +58,17 @@ policy:
|
|||||||
- 84
|
- 84
|
||||||
eval_fixed_crop: true
|
eval_fixed_crop: true
|
||||||
horizon: 16
|
horizon: 16
|
||||||
min_time: 0.05
|
|
||||||
n_action_steps: 8
|
n_action_steps: 8
|
||||||
n_cond_layers: 0
|
n_cond_layers: 0
|
||||||
n_emb: 256
|
n_emb: 256
|
||||||
n_head: 4
|
n_head: 4
|
||||||
n_layer: 8
|
n_layer: 12
|
||||||
n_obs_steps: 2
|
n_obs_steps: 2
|
||||||
n_time_tokens: 4
|
n_time_tokens: 4
|
||||||
noise_scale: 1.0
|
noise_scale: 1.0
|
||||||
|
adatloss_eps: 0.01
|
||||||
|
p_mean: -0.4
|
||||||
|
p_std: 1.0
|
||||||
noise_scheduler:
|
noise_scheduler:
|
||||||
_target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler
|
_target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler
|
||||||
beta_end: 0.02
|
beta_end: 0.02
|
||||||
@@ -76,13 +78,16 @@ policy:
|
|||||||
num_train_timesteps: 100
|
num_train_timesteps: 100
|
||||||
prediction_type: sample
|
prediction_type: sample
|
||||||
variance_type: fixed_small
|
variance_type: fixed_small
|
||||||
num_inference_steps: 32
|
num_inference_steps: 1
|
||||||
obs_as_cond: true
|
obs_as_cond: true
|
||||||
obs_encoder_group_norm: true
|
obs_encoder_group_norm: true
|
||||||
p_drop_attn: 0.0
|
p_drop_attn: 0.0
|
||||||
p_drop_emb: 0.0
|
p_drop_emb: 0.0
|
||||||
pmf_u_loss_weight: 1.0
|
pmf_u_loss_weight: 1.0
|
||||||
pmf_v_loss_weight: 1.0
|
pmf_v_loss_weight: 1.0
|
||||||
|
tr_uniform: true
|
||||||
|
tr_uniform_prob: 0.1
|
||||||
|
data_proportion: 0.5
|
||||||
shape_meta:
|
shape_meta:
|
||||||
action:
|
action:
|
||||||
shape:
|
shape:
|
||||||
|
|||||||
Reference in New Issue
Block a user