feat: stabilize meanflow training and time sampling

This commit is contained in:
Logic
2026-03-11 22:54:48 +08:00
parent 9b2968997c
commit 5897a0afd1
2 changed files with 17 additions and 11 deletions

View File

@@ -40,8 +40,9 @@ class TrainConfig:
seq_len: int = 5
lr: float = 2e-4
weight_decay: float = 1e-2
max_grad_norm: float = 1.0
lambda_flow: float = 1.0
lambda_perceptual: float = 0.4
lambda_perceptual: float = 2.0
num_classes: int = 10
image_size: int = 28
channels: int = 1
@@ -189,7 +190,7 @@ class ASMamba(nn.Module):
)
cond_vec = self.cond_emb(cond)
feats, h = self.backbone(z_t, cond_vec, h)
x_pred = self.clean_head(feats)
x_pred = torch.tanh(self.clean_head(feats))
return x_pred, h
def step(
@@ -326,6 +327,8 @@ def validate_config(cfg: TrainConfig) -> None:
raise ValueError("time_grid_size must be >= 2.")
if cfg.lambda_perceptual < 0:
raise ValueError("lambda_perceptual must be >= 0.")
if cfg.max_grad_norm <= 0:
raise ValueError("max_grad_norm must be > 0.")
if cfg.val_sampling_steps != FIXED_VAL_SAMPLING_STEPS:
raise ValueError(
f"val_sampling_steps is fixed to {FIXED_VAL_SAMPLING_STEPS} for validation sampling."
@@ -335,19 +338,17 @@ def validate_config(cfg: TrainConfig) -> None:
def sample_block_times(
cfg: TrainConfig, batch_size: int, device: torch.device, dtype: torch.dtype
) -> tuple[Tensor, Tensor]:
# Sampling sorted discrete cut points allows repeated boundaries, so zero-length
# interior blocks occur with non-zero probability while keeping t > 0.
cuts = torch.randint(
1,
cfg.time_grid_size,
(batch_size, cfg.seq_len - 1),
device=device,
)
num_internal = cfg.seq_len - 1
normal = torch.randn(batch_size, num_internal, device=device, dtype=dtype)
logit_normal = torch.sigmoid(normal * math.sqrt(0.8))
uniform = torch.rand(batch_size, num_internal, device=device, dtype=dtype)
use_uniform = torch.rand(batch_size, num_internal, device=device) < 0.1
cuts = torch.where(use_uniform, uniform, logit_normal)
cuts, _ = torch.sort(cuts, dim=-1)
boundaries = torch.cat(
[
torch.zeros(batch_size, 1, device=device, dtype=dtype),
cuts.to(dtype=dtype) / float(cfg.time_grid_size),
cuts,
torch.ones(batch_size, 1, device=device, dtype=dtype),
],
dim=-1,
@@ -615,6 +616,9 @@ def train(cfg: TrainConfig) -> ASMamba:
optimizer.zero_grad(set_to_none=True)
losses["total"].backward()
grad_norm = torch.nn.utils.clip_grad_norm_(
model.parameters(), max_norm=cfg.max_grad_norm
)
optimizer.step()
if global_step % 10 == 0 and rank == 0:
@@ -623,6 +627,7 @@ def train(cfg: TrainConfig) -> ASMamba:
"loss/total": float(losses["total"].item()),
"loss/flow": float(losses["flow"].item()),
"loss/perceptual": float(losses["perceptual"].item()),
"grad/total_norm": float(grad_norm.item()),
"time/r_mean": float(r_seq.mean().item()),
"time/t_mean": float(t_seq.mean().item()),
"time/zero_block_frac": float((t_seq == r_seq).float().mean().item()),