feat: stabilize meanflow training and time sampling

This commit is contained in:
Logic
2026-03-11 22:54:48 +08:00
parent 9b2968997c
commit 5897a0afd1
2 changed files with 17 additions and 11 deletions

View File

@@ -40,8 +40,9 @@ class TrainConfig:
seq_len: int = 5 seq_len: int = 5
lr: float = 2e-4 lr: float = 2e-4
weight_decay: float = 1e-2 weight_decay: float = 1e-2
max_grad_norm: float = 1.0
lambda_flow: float = 1.0 lambda_flow: float = 1.0
lambda_perceptual: float = 0.4 lambda_perceptual: float = 2.0
num_classes: int = 10 num_classes: int = 10
image_size: int = 28 image_size: int = 28
channels: int = 1 channels: int = 1
@@ -189,7 +190,7 @@ class ASMamba(nn.Module):
) )
cond_vec = self.cond_emb(cond) cond_vec = self.cond_emb(cond)
feats, h = self.backbone(z_t, cond_vec, h) feats, h = self.backbone(z_t, cond_vec, h)
x_pred = self.clean_head(feats) x_pred = torch.tanh(self.clean_head(feats))
return x_pred, h return x_pred, h
def step( def step(
@@ -326,6 +327,8 @@ def validate_config(cfg: TrainConfig) -> None:
raise ValueError("time_grid_size must be >= 2.") raise ValueError("time_grid_size must be >= 2.")
if cfg.lambda_perceptual < 0: if cfg.lambda_perceptual < 0:
raise ValueError("lambda_perceptual must be >= 0.") raise ValueError("lambda_perceptual must be >= 0.")
if cfg.max_grad_norm <= 0:
raise ValueError("max_grad_norm must be > 0.")
if cfg.val_sampling_steps != FIXED_VAL_SAMPLING_STEPS: if cfg.val_sampling_steps != FIXED_VAL_SAMPLING_STEPS:
raise ValueError( raise ValueError(
f"val_sampling_steps is fixed to {FIXED_VAL_SAMPLING_STEPS} for validation sampling." f"val_sampling_steps is fixed to {FIXED_VAL_SAMPLING_STEPS} for validation sampling."
@@ -335,19 +338,17 @@ def validate_config(cfg: TrainConfig) -> None:
def sample_block_times( def sample_block_times(
cfg: TrainConfig, batch_size: int, device: torch.device, dtype: torch.dtype cfg: TrainConfig, batch_size: int, device: torch.device, dtype: torch.dtype
) -> tuple[Tensor, Tensor]: ) -> tuple[Tensor, Tensor]:
# Sampling sorted discrete cut points allows repeated boundaries, so zero-length num_internal = cfg.seq_len - 1
# interior blocks occur with non-zero probability while keeping t > 0. normal = torch.randn(batch_size, num_internal, device=device, dtype=dtype)
cuts = torch.randint( logit_normal = torch.sigmoid(normal * math.sqrt(0.8))
1, uniform = torch.rand(batch_size, num_internal, device=device, dtype=dtype)
cfg.time_grid_size, use_uniform = torch.rand(batch_size, num_internal, device=device) < 0.1
(batch_size, cfg.seq_len - 1), cuts = torch.where(use_uniform, uniform, logit_normal)
device=device,
)
cuts, _ = torch.sort(cuts, dim=-1) cuts, _ = torch.sort(cuts, dim=-1)
boundaries = torch.cat( boundaries = torch.cat(
[ [
torch.zeros(batch_size, 1, device=device, dtype=dtype), torch.zeros(batch_size, 1, device=device, dtype=dtype),
cuts.to(dtype=dtype) / float(cfg.time_grid_size), cuts,
torch.ones(batch_size, 1, device=device, dtype=dtype), torch.ones(batch_size, 1, device=device, dtype=dtype),
], ],
dim=-1, dim=-1,
@@ -615,6 +616,9 @@ def train(cfg: TrainConfig) -> ASMamba:
optimizer.zero_grad(set_to_none=True) optimizer.zero_grad(set_to_none=True)
losses["total"].backward() losses["total"].backward()
grad_norm = torch.nn.utils.clip_grad_norm_(
model.parameters(), max_norm=cfg.max_grad_norm
)
optimizer.step() optimizer.step()
if global_step % 10 == 0 and rank == 0: if global_step % 10 == 0 and rank == 0:
@@ -623,6 +627,7 @@ def train(cfg: TrainConfig) -> ASMamba:
"loss/total": float(losses["total"].item()), "loss/total": float(losses["total"].item()),
"loss/flow": float(losses["flow"].item()), "loss/flow": float(losses["flow"].item()),
"loss/perceptual": float(losses["perceptual"].item()), "loss/perceptual": float(losses["perceptual"].item()),
"grad/total_norm": float(grad_norm.item()),
"time/r_mean": float(r_seq.mean().item()), "time/r_mean": float(r_seq.mean().item()),
"time/t_mean": float(t_seq.mean().item()), "time/t_mean": float(t_seq.mean().item()),
"time/zero_block_frac": float((t_seq == r_seq).float().mean().item()), "time/zero_block_frac": float((t_seq == r_seq).float().mean().item()),

View File

@@ -11,6 +11,7 @@ def build_parser() -> argparse.ArgumentParser:
parser.add_argument("--seq-len", type=int, default=None) parser.add_argument("--seq-len", type=int, default=None)
parser.add_argument("--lr", type=float, default=None) parser.add_argument("--lr", type=float, default=None)
parser.add_argument("--weight-decay", type=float, default=None) parser.add_argument("--weight-decay", type=float, default=None)
parser.add_argument("--max-grad-norm", type=float, default=None)
parser.add_argument("--device", type=str, default=None) parser.add_argument("--device", type=str, default=None)
parser.add_argument("--output-dir", type=str, default=None) parser.add_argument("--output-dir", type=str, default=None)
parser.add_argument("--project", type=str, default=None) parser.add_argument("--project", type=str, default=None)