From 5897a0afd1257f0b7ca453af766cdf53fdd185e1 Mon Sep 17 00:00:00 2001 From: Logic Date: Wed, 11 Mar 2026 22:54:48 +0800 Subject: [PATCH] feat: stabilize meanflow training and time sampling --- as_mamba.py | 27 ++++++++++++++++----------- main.py | 1 + 2 files changed, 17 insertions(+), 11 deletions(-) diff --git a/as_mamba.py b/as_mamba.py index b0d5d27..7bbd614 100644 --- a/as_mamba.py +++ b/as_mamba.py @@ -40,8 +40,9 @@ class TrainConfig: seq_len: int = 5 lr: float = 2e-4 weight_decay: float = 1e-2 + max_grad_norm: float = 1.0 lambda_flow: float = 1.0 - lambda_perceptual: float = 0.4 + lambda_perceptual: float = 2.0 num_classes: int = 10 image_size: int = 28 channels: int = 1 @@ -189,7 +190,7 @@ class ASMamba(nn.Module): ) cond_vec = self.cond_emb(cond) feats, h = self.backbone(z_t, cond_vec, h) - x_pred = self.clean_head(feats) + x_pred = torch.tanh(self.clean_head(feats)) return x_pred, h def step( @@ -326,6 +327,8 @@ def validate_config(cfg: TrainConfig) -> None: raise ValueError("time_grid_size must be >= 2.") if cfg.lambda_perceptual < 0: raise ValueError("lambda_perceptual must be >= 0.") + if cfg.max_grad_norm <= 0: + raise ValueError("max_grad_norm must be > 0.") if cfg.val_sampling_steps != FIXED_VAL_SAMPLING_STEPS: raise ValueError( f"val_sampling_steps is fixed to {FIXED_VAL_SAMPLING_STEPS} for validation sampling." @@ -335,19 +338,17 @@ def validate_config(cfg: TrainConfig) -> None: def sample_block_times( cfg: TrainConfig, batch_size: int, device: torch.device, dtype: torch.dtype ) -> tuple[Tensor, Tensor]: - # Sampling sorted discrete cut points allows repeated boundaries, so zero-length - # interior blocks occur with non-zero probability while keeping t > 0. - cuts = torch.randint( - 1, - cfg.time_grid_size, - (batch_size, cfg.seq_len - 1), - device=device, - ) + num_internal = cfg.seq_len - 1 + normal = torch.randn(batch_size, num_internal, device=device, dtype=dtype) + logit_normal = torch.sigmoid(normal * math.sqrt(0.8)) + uniform = torch.rand(batch_size, num_internal, device=device, dtype=dtype) + use_uniform = torch.rand(batch_size, num_internal, device=device) < 0.1 + cuts = torch.where(use_uniform, uniform, logit_normal) cuts, _ = torch.sort(cuts, dim=-1) boundaries = torch.cat( [ torch.zeros(batch_size, 1, device=device, dtype=dtype), - cuts.to(dtype=dtype) / float(cfg.time_grid_size), + cuts, torch.ones(batch_size, 1, device=device, dtype=dtype), ], dim=-1, @@ -615,6 +616,9 @@ def train(cfg: TrainConfig) -> ASMamba: optimizer.zero_grad(set_to_none=True) losses["total"].backward() + grad_norm = torch.nn.utils.clip_grad_norm_( + model.parameters(), max_norm=cfg.max_grad_norm + ) optimizer.step() if global_step % 10 == 0 and rank == 0: @@ -623,6 +627,7 @@ def train(cfg: TrainConfig) -> ASMamba: "loss/total": float(losses["total"].item()), "loss/flow": float(losses["flow"].item()), "loss/perceptual": float(losses["perceptual"].item()), + "grad/total_norm": float(grad_norm.item()), "time/r_mean": float(r_seq.mean().item()), "time/t_mean": float(t_seq.mean().item()), "time/zero_block_frac": float((t_seq == r_seq).float().mean().item()), diff --git a/main.py b/main.py index 8447b7e..7e40c4b 100644 --- a/main.py +++ b/main.py @@ -11,6 +11,7 @@ def build_parser() -> argparse.ArgumentParser: parser.add_argument("--seq-len", type=int, default=None) parser.add_argument("--lr", type=float, default=None) parser.add_argument("--weight-decay", type=float, default=None) + parser.add_argument("--max-grad-norm", type=float, default=None) parser.add_argument("--device", type=str, default=None) parser.add_argument("--output-dir", type=str, default=None) parser.add_argument("--project", type=str, default=None)