feat: stabilize meanflow training and time sampling
This commit is contained in:
27
as_mamba.py
27
as_mamba.py
@@ -40,8 +40,9 @@ class TrainConfig:
|
||||
seq_len: int = 5
|
||||
lr: float = 2e-4
|
||||
weight_decay: float = 1e-2
|
||||
max_grad_norm: float = 1.0
|
||||
lambda_flow: float = 1.0
|
||||
lambda_perceptual: float = 0.4
|
||||
lambda_perceptual: float = 2.0
|
||||
num_classes: int = 10
|
||||
image_size: int = 28
|
||||
channels: int = 1
|
||||
@@ -189,7 +190,7 @@ class ASMamba(nn.Module):
|
||||
)
|
||||
cond_vec = self.cond_emb(cond)
|
||||
feats, h = self.backbone(z_t, cond_vec, h)
|
||||
x_pred = self.clean_head(feats)
|
||||
x_pred = torch.tanh(self.clean_head(feats))
|
||||
return x_pred, h
|
||||
|
||||
def step(
|
||||
@@ -326,6 +327,8 @@ def validate_config(cfg: TrainConfig) -> None:
|
||||
raise ValueError("time_grid_size must be >= 2.")
|
||||
if cfg.lambda_perceptual < 0:
|
||||
raise ValueError("lambda_perceptual must be >= 0.")
|
||||
if cfg.max_grad_norm <= 0:
|
||||
raise ValueError("max_grad_norm must be > 0.")
|
||||
if cfg.val_sampling_steps != FIXED_VAL_SAMPLING_STEPS:
|
||||
raise ValueError(
|
||||
f"val_sampling_steps is fixed to {FIXED_VAL_SAMPLING_STEPS} for validation sampling."
|
||||
@@ -335,19 +338,17 @@ def validate_config(cfg: TrainConfig) -> None:
|
||||
def sample_block_times(
|
||||
cfg: TrainConfig, batch_size: int, device: torch.device, dtype: torch.dtype
|
||||
) -> tuple[Tensor, Tensor]:
|
||||
# Sampling sorted discrete cut points allows repeated boundaries, so zero-length
|
||||
# interior blocks occur with non-zero probability while keeping t > 0.
|
||||
cuts = torch.randint(
|
||||
1,
|
||||
cfg.time_grid_size,
|
||||
(batch_size, cfg.seq_len - 1),
|
||||
device=device,
|
||||
)
|
||||
num_internal = cfg.seq_len - 1
|
||||
normal = torch.randn(batch_size, num_internal, device=device, dtype=dtype)
|
||||
logit_normal = torch.sigmoid(normal * math.sqrt(0.8))
|
||||
uniform = torch.rand(batch_size, num_internal, device=device, dtype=dtype)
|
||||
use_uniform = torch.rand(batch_size, num_internal, device=device) < 0.1
|
||||
cuts = torch.where(use_uniform, uniform, logit_normal)
|
||||
cuts, _ = torch.sort(cuts, dim=-1)
|
||||
boundaries = torch.cat(
|
||||
[
|
||||
torch.zeros(batch_size, 1, device=device, dtype=dtype),
|
||||
cuts.to(dtype=dtype) / float(cfg.time_grid_size),
|
||||
cuts,
|
||||
torch.ones(batch_size, 1, device=device, dtype=dtype),
|
||||
],
|
||||
dim=-1,
|
||||
@@ -615,6 +616,9 @@ def train(cfg: TrainConfig) -> ASMamba:
|
||||
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
losses["total"].backward()
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
model.parameters(), max_norm=cfg.max_grad_norm
|
||||
)
|
||||
optimizer.step()
|
||||
|
||||
if global_step % 10 == 0 and rank == 0:
|
||||
@@ -623,6 +627,7 @@ def train(cfg: TrainConfig) -> ASMamba:
|
||||
"loss/total": float(losses["total"].item()),
|
||||
"loss/flow": float(losses["flow"].item()),
|
||||
"loss/perceptual": float(losses["perceptual"].item()),
|
||||
"grad/total_norm": float(grad_norm.item()),
|
||||
"time/r_mean": float(r_seq.mean().item()),
|
||||
"time/t_mean": float(t_seq.mean().item()),
|
||||
"time/zero_block_frac": float((t_seq == r_seq).float().mean().item()),
|
||||
|
||||
1
main.py
1
main.py
@@ -11,6 +11,7 @@ def build_parser() -> argparse.ArgumentParser:
|
||||
parser.add_argument("--seq-len", type=int, default=None)
|
||||
parser.add_argument("--lr", type=float, default=None)
|
||||
parser.add_argument("--weight-decay", type=float, default=None)
|
||||
parser.add_argument("--max-grad-norm", type=float, default=None)
|
||||
parser.add_argument("--device", type=str, default=None)
|
||||
parser.add_argument("--output-dir", type=str, default=None)
|
||||
parser.add_argument("--project", type=str, default=None)
|
||||
|
||||
Reference in New Issue
Block a user