feat: align pmf transformer training and config defaults

gameloader
2026-03-16 15:37:32 +08:00
parent 79f31940c4
commit 42dc29a2cb
2 changed files with 80 additions and 20 deletions


@@ -47,6 +47,12 @@ class PMFTransformerHybridImagePolicy(BaseImagePolicy):
         pmf_u_loss_weight=1.0,
         pmf_v_loss_weight=1.0,
         noise_scale=1.0,
+        adatloss_eps=0.01,
+        p_mean=-0.4,
+        p_std=1.0,
+        tr_uniform=True,
+        tr_uniform_prob=0.1,
+        data_proportion=0.5,
         **kwargs,
     ):
         super().__init__()
@@ -171,6 +177,12 @@ class PMFTransformerHybridImagePolicy(BaseImagePolicy):
         self.pmf_u_loss_weight = pmf_u_loss_weight
         self.pmf_v_loss_weight = pmf_v_loss_weight
         self.noise_scale = noise_scale
+        self.adatloss_eps = adatloss_eps
+        self.p_mean = p_mean
+        self.p_std = p_std
+        self.tr_uniform = tr_uniform
+        self.tr_uniform_prob = tr_uniform_prob
+        self.data_proportion = data_proportion
         self.kwargs = kwargs
         if num_inference_steps is None:
@@ -185,6 +197,33 @@ class PMFTransformerHybridImagePolicy(BaseImagePolicy):
     def _time_view(self, value: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
         return value.reshape(value.shape[0], *([1] * (ref.ndim - 1)))
 
+    def _adatloss(self, loss: torch.Tensor) -> torch.Tensor:
+        denom = loss.detach() + self.adatloss_eps
+        return loss / denom
+
+    def _sample_logit_normal(self, batch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
+        normal = torch.randn(batch_size, device=device, dtype=dtype)
+        return torch.sigmoid(normal * self.p_std + self.p_mean)
+
+    def _sample_tr(self, batch_size: int, device: torch.device, dtype: torch.dtype):
+        t = self._sample_logit_normal(batch_size, device, dtype)
+        r = self._sample_logit_normal(batch_size, device, dtype)
+        if self.tr_uniform:
+            uniform_mask = torch.rand(batch_size, device=device) < self.tr_uniform_prob
+            uniform_t = torch.rand(batch_size, device=device, dtype=dtype)
+            uniform_r = torch.rand(batch_size, device=device, dtype=dtype)
+            t = torch.where(uniform_mask, uniform_t, t)
+            r = torch.where(uniform_mask, uniform_r, r)
+        data_size = int(batch_size * self.data_proportion)
+        fm_mask = torch.arange(batch_size, device=device) < data_size
+        r = torch.where(fm_mask, t, r)
+        t_final = torch.maximum(t, r)
+        r_final = torch.minimum(t, r)
+        return t_final, r_final
+
     def _trajectory_inputs(
         self,
         nobs: Dict[str, torch.Tensor],
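
Note on the hunk above: `_sample_tr` pairs a logit-normal time sampler with an optional uniform fallback and pins a `data_proportion` prefix of the batch to the flow-matching case `r == t`. A standalone sketch of the assumed semantics (the `sample_tr` helper below is hypothetical, not part of this diff):

```python
import torch

def sample_tr(batch_size, p_mean=-0.4, p_std=1.0, tr_uniform_prob=0.1,
              data_proportion=0.5, device="cpu", dtype=torch.float32):
    # Logit-normal times: sigmoid of a shifted and scaled standard normal.
    t = torch.sigmoid(torch.randn(batch_size, device=device, dtype=dtype) * p_std + p_mean)
    r = torch.sigmoid(torch.randn(batch_size, device=device, dtype=dtype) * p_std + p_mean)
    # With probability tr_uniform_prob, fall back to uniform times.
    mask = torch.rand(batch_size, device=device) < tr_uniform_prob
    t = torch.where(mask, torch.rand(batch_size, device=device, dtype=dtype), t)
    r = torch.where(mask, torch.rand(batch_size, device=device, dtype=dtype), r)
    # The first data_proportion of the batch trains the r == t (flow-matching) case.
    fm_mask = torch.arange(batch_size, device=device) < int(batch_size * data_proportion)
    r = torch.where(fm_mask, t, r)
    # Order the pair so t is always the later time.
    return torch.maximum(t, r), torch.minimum(t, r)

t, r = sample_tr(8)
assert torch.all(t >= r)
```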
@@ -217,10 +256,9 @@ class PMFTransformerHybridImagePolicy(BaseImagePolicy):
         condition_data: torch.Tensor,
         condition_mask: torch.Tensor,
     ) -> torch.Tensor:
-        if condition_mask.any():
-            sample = sample.clone()
-            sample[condition_mask] = condition_data[condition_mask]
-        return sample
+        if not condition_mask.any():
+            return sample
+        return torch.where(condition_mask, condition_data, sample)
 
     def _compute_u_v(
         self,
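
The `torch.where` rewrite of `_apply_conditioning` is behaviour-preserving but out of place, which keeps the op traceable under forward-mode autodiff. A minimal equivalence check with toy shapes (assuming `condition_mask` is boolean and shaped like `sample`):

```python
import torch

sample = torch.randn(2, 4)
condition_data = torch.zeros(2, 4)
condition_mask = torch.rand(2, 4) < 0.5

# Old behaviour: clone, then in-place masked assignment.
old = sample.clone()
old[condition_mask] = condition_data[condition_mask]

# New behaviour: a single out-of-place select.
new = torch.where(condition_mask, condition_data, sample)
assert torch.equal(old, new)
```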
@@ -230,7 +268,7 @@ class PMFTransformerHybridImagePolicy(BaseImagePolicy):
         cond: torch.Tensor,
     ):
         x_hat_u, x_hat_v = self.model(sample, t, r, cond)
-        denom = self._time_view(t.clamp_min(self.min_time), sample)
+        denom = self._time_view(t, sample)
         u = (sample - x_hat_u) / denom
         v = (sample - x_hat_v) / denom
         return u, v
@@ -245,16 +283,31 @@ class PMFTransformerHybridImagePolicy(BaseImagePolicy):
         condition_mask: torch.Tensor,
         tangent_v: torch.Tensor,
     ) -> torch.Tensor:
-        step = torch.full_like(t, self.du_dt_epsilon)
-        sample_base = self._apply_conditioning(sample, condition_data, condition_mask)
-        u_base, _ = self._compute_u_v(sample_base, t, r, cond)
+        tangent_sample = tangent_v.detach()
+        tangent_r = torch.zeros_like(r)
+        tangent_t = torch.ones_like(t)
-        t_next = (t + step).clamp(max=1.0)
-        sample_next = sample + self._time_view(step, sample) * tangent_v.detach()
-        sample_next = self._apply_conditioning(sample_next, condition_data, condition_mask)
-        u_next, _ = self._compute_u_v(sample_next, t_next, r, cond)
+        def u_fn(sample_input, r_input, t_input):
+            conditioned_sample = self._apply_conditioning(
+                sample_input, condition_data, condition_mask
+            )
+            u_value, _ = self._compute_u_v(conditioned_sample, t_input, r_input, cond)
+            return u_value
-        return (u_next - u_base) / self._time_view(step, sample)
+        primals = (sample, r, t)
+        tangents = (tangent_sample, tangent_r, tangent_t)
+        try:
+            _, du_dt = torch.func.jvp(u_fn, primals, tangents)
+        except (AttributeError, NotImplementedError, RuntimeError):
+            _, du_dt = torch.autograd.functional.jvp(
+                u_fn,
+                primals,
+                tangents,
+                create_graph=False,
+                strict=False,
+            )
+        return du_dt
 
     # ========= inference ============
     def conditional_sample(
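
The hunk above replaces the finite-difference estimate of du/dt with an exact forward-mode derivative: with tangents `(tangent_v, 0, 1)`, the JVP of `u` is the total derivative along the path `(z + dt * v, r, t + dt)`. A toy sketch of the same pattern (the quadratic `u` below is a stand-in for the transformer, not the real model):

```python
import torch

def u(z, r, t):
    # Toy field standing in for the network's u-head.
    return z * t + r

z = torch.randn(3)
r = torch.zeros(3)
t = torch.full((3,), 0.7)
v = torch.randn(3)  # tangent direction for the sample argument

# Total derivative du/dt along (z + dt*v, r, t + dt): tangents are (v, 0, 1).
_, du_dt = torch.func.jvp(u, (z, r, t), (v, torch.zeros_like(r), torch.ones_like(t)))

# Chain rule check: d/dt [z(t) * t + r] = dz/dt * t + z = v*t + z.
assert torch.allclose(du_dt, v * t + z)
```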
@@ -368,14 +421,14 @@ class PMFTransformerHybridImagePolicy(BaseImagePolicy):
         noise = torch.randn_like(trajectory) * self.noise_scale
         batch_size = trajectory.shape[0]
-        t = torch.rand(batch_size, device=trajectory.device, dtype=trajectory.dtype)
-        r = torch.rand(batch_size, device=trajectory.device, dtype=trajectory.dtype) * t
+        t, r = self._sample_tr(
+            batch_size, device=trajectory.device, dtype=trajectory.dtype
+        )
         z_t = (1 - self._time_view(t, trajectory)) * trajectory + self._time_view(t, trajectory) * noise
         z_t = self._apply_conditioning(z_t, trajectory, condition_mask)
         loss_mask = ~condition_mask
-        denom = self._time_view(t.clamp_min(self.min_time), trajectory)
-        target_v = (z_t - trajectory) / denom
+        target_v = noise - trajectory
 
         u, v = self._compute_u_v(z_t, t, r, cond)
         du_dt = self._compute_du_dt(
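
On the `target_v` change: with the interpolation `z_t = (1 - t) * x + t * noise` used above, `(z_t - x) / t` simplifies to `noise - x` exactly, so the closed form drops the `min_time` clamp without changing the objective. A quick numeric check:

```python
import torch

x = torch.randn(4, 2)                 # clean trajectory
noise = torch.randn(4, 2)
t = torch.rand(4, 1).clamp_min(1e-3)  # clamp only keeps the reference form finite

z_t = (1 - t) * x + t * noise
# z_t - x == t * (noise - x), so the division cancels for any t > 0.
assert torch.allclose((z_t - x) / t, noise - x, atol=1e-5)
```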
@@ -395,4 +448,6 @@ class PMFTransformerHybridImagePolicy(BaseImagePolicy):
         loss_v = loss_v * loss_mask.type(loss_v.dtype)
         loss_u = reduce(loss_u, "b ... -> b (...)", "mean").mean()
         loss_v = reduce(loss_v, "b ... -> b (...)", "mean").mean()
+        loss_u = self._adatloss(loss_u)
+        loss_v = self._adatloss(loss_v)
         return self.pmf_u_loss_weight * loss_u + self.pmf_v_loss_weight * loss_v
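
Note on `_adatloss`: dividing by a detached copy keeps the forward value near 1 while rescaling gradients by `1 / (loss + eps)`, so the u and v terms contribute at comparable magnitudes regardless of their raw scale. A small check of the gradient scaling:

```python
import torch

w = torch.tensor(2.0, requires_grad=True)
loss = (3.0 * w) ** 2          # raw loss = 36, d(loss)/dw = 36
scaled = loss / (loss.detach() + 0.01)
scaled.backward()

# Gradient of the scaled loss is the raw gradient divided by (loss + eps).
assert torch.allclose(w.grad, torch.tensor(36.0 / 36.01))
```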


@@ -58,15 +58,17 @@ policy:
     - 84
   eval_fixed_crop: true
   horizon: 16
-  min_time: 0.05
   n_action_steps: 8
   n_cond_layers: 0
   n_emb: 256
   n_head: 4
-  n_layer: 8
+  n_layer: 12
   n_obs_steps: 2
   n_time_tokens: 4
   noise_scale: 1.0
+  adatloss_eps: 0.01
+  p_mean: -0.4
+  p_std: 1.0
   noise_scheduler:
     _target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler
     beta_end: 0.02
@@ -76,13 +78,16 @@ policy:
       num_train_timesteps: 100
       prediction_type: sample
      variance_type: fixed_small
-  num_inference_steps: 32
+  num_inference_steps: 1
   obs_as_cond: true
   obs_encoder_group_norm: true
   p_drop_attn: 0.0
   p_drop_emb: 0.0
   pmf_u_loss_weight: 1.0
   pmf_v_loss_weight: 1.0
+  tr_uniform: true
+  tr_uniform_prob: 0.1
+  data_proportion: 0.5
   shape_meta:
     action:
       shape:
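
On `num_inference_steps: 32 -> 1`: with the parametrisation `u = (z - x_hat) / t` from `_compute_u_v`, a single step from t = 1 down to r = 0 recovers the model's x-prediction directly, assuming the sampler integrates `z_r = z_t - (t - r) * u` (the sampler itself is outside this diff, so that update rule is an assumption). A hedged sketch with a stand-in model:

```python
import torch

def one_step_sample(u_head, z1):
    # One flow-map step from t = 1 down to r = 0 (assumed update rule).
    t = torch.ones(z1.shape[0], 1)
    r = torch.zeros(z1.shape[0], 1)
    u = u_head(z1, t, r)
    return z1 - (t - r) * u  # equals x_hat when u == (z1 - x_hat) / 1

x_hat = torch.randn(2, 4)                 # pretend network x-prediction
u_head = lambda z, t, r: (z - x_hat) / t  # stand-in for the transformer head
z1 = torch.randn(2, 4)                    # pure noise at t = 1
assert torch.allclose(one_step_sample(u_head, z1), x_hat)
```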