feat: stabilize meanflow training and time sampling
This commit is contained in:
27
as_mamba.py
27
as_mamba.py
@@ -40,8 +40,9 @@ class TrainConfig:
|
|||||||
seq_len: int = 5
|
seq_len: int = 5
|
||||||
lr: float = 2e-4
|
lr: float = 2e-4
|
||||||
weight_decay: float = 1e-2
|
weight_decay: float = 1e-2
|
||||||
|
max_grad_norm: float = 1.0
|
||||||
lambda_flow: float = 1.0
|
lambda_flow: float = 1.0
|
||||||
lambda_perceptual: float = 0.4
|
lambda_perceptual: float = 2.0
|
||||||
num_classes: int = 10
|
num_classes: int = 10
|
||||||
image_size: int = 28
|
image_size: int = 28
|
||||||
channels: int = 1
|
channels: int = 1
|
||||||
@@ -189,7 +190,7 @@ class ASMamba(nn.Module):
|
|||||||
)
|
)
|
||||||
cond_vec = self.cond_emb(cond)
|
cond_vec = self.cond_emb(cond)
|
||||||
feats, h = self.backbone(z_t, cond_vec, h)
|
feats, h = self.backbone(z_t, cond_vec, h)
|
||||||
x_pred = self.clean_head(feats)
|
x_pred = torch.tanh(self.clean_head(feats))
|
||||||
return x_pred, h
|
return x_pred, h
|
||||||
|
|
||||||
def step(
|
def step(
|
||||||
@@ -326,6 +327,8 @@ def validate_config(cfg: TrainConfig) -> None:
|
|||||||
raise ValueError("time_grid_size must be >= 2.")
|
raise ValueError("time_grid_size must be >= 2.")
|
||||||
if cfg.lambda_perceptual < 0:
|
if cfg.lambda_perceptual < 0:
|
||||||
raise ValueError("lambda_perceptual must be >= 0.")
|
raise ValueError("lambda_perceptual must be >= 0.")
|
||||||
|
if cfg.max_grad_norm <= 0:
|
||||||
|
raise ValueError("max_grad_norm must be > 0.")
|
||||||
if cfg.val_sampling_steps != FIXED_VAL_SAMPLING_STEPS:
|
if cfg.val_sampling_steps != FIXED_VAL_SAMPLING_STEPS:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"val_sampling_steps is fixed to {FIXED_VAL_SAMPLING_STEPS} for validation sampling."
|
f"val_sampling_steps is fixed to {FIXED_VAL_SAMPLING_STEPS} for validation sampling."
|
||||||
@@ -335,19 +338,17 @@ def validate_config(cfg: TrainConfig) -> None:
|
|||||||
def sample_block_times(
|
def sample_block_times(
|
||||||
cfg: TrainConfig, batch_size: int, device: torch.device, dtype: torch.dtype
|
cfg: TrainConfig, batch_size: int, device: torch.device, dtype: torch.dtype
|
||||||
) -> tuple[Tensor, Tensor]:
|
) -> tuple[Tensor, Tensor]:
|
||||||
# Sampling sorted discrete cut points allows repeated boundaries, so zero-length
|
num_internal = cfg.seq_len - 1
|
||||||
# interior blocks occur with non-zero probability while keeping t > 0.
|
normal = torch.randn(batch_size, num_internal, device=device, dtype=dtype)
|
||||||
cuts = torch.randint(
|
logit_normal = torch.sigmoid(normal * math.sqrt(0.8))
|
||||||
1,
|
uniform = torch.rand(batch_size, num_internal, device=device, dtype=dtype)
|
||||||
cfg.time_grid_size,
|
use_uniform = torch.rand(batch_size, num_internal, device=device) < 0.1
|
||||||
(batch_size, cfg.seq_len - 1),
|
cuts = torch.where(use_uniform, uniform, logit_normal)
|
||||||
device=device,
|
|
||||||
)
|
|
||||||
cuts, _ = torch.sort(cuts, dim=-1)
|
cuts, _ = torch.sort(cuts, dim=-1)
|
||||||
boundaries = torch.cat(
|
boundaries = torch.cat(
|
||||||
[
|
[
|
||||||
torch.zeros(batch_size, 1, device=device, dtype=dtype),
|
torch.zeros(batch_size, 1, device=device, dtype=dtype),
|
||||||
cuts.to(dtype=dtype) / float(cfg.time_grid_size),
|
cuts,
|
||||||
torch.ones(batch_size, 1, device=device, dtype=dtype),
|
torch.ones(batch_size, 1, device=device, dtype=dtype),
|
||||||
],
|
],
|
||||||
dim=-1,
|
dim=-1,
|
||||||
@@ -615,6 +616,9 @@ def train(cfg: TrainConfig) -> ASMamba:
|
|||||||
|
|
||||||
optimizer.zero_grad(set_to_none=True)
|
optimizer.zero_grad(set_to_none=True)
|
||||||
losses["total"].backward()
|
losses["total"].backward()
|
||||||
|
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||||
|
model.parameters(), max_norm=cfg.max_grad_norm
|
||||||
|
)
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|
||||||
if global_step % 10 == 0 and rank == 0:
|
if global_step % 10 == 0 and rank == 0:
|
||||||
@@ -623,6 +627,7 @@ def train(cfg: TrainConfig) -> ASMamba:
|
|||||||
"loss/total": float(losses["total"].item()),
|
"loss/total": float(losses["total"].item()),
|
||||||
"loss/flow": float(losses["flow"].item()),
|
"loss/flow": float(losses["flow"].item()),
|
||||||
"loss/perceptual": float(losses["perceptual"].item()),
|
"loss/perceptual": float(losses["perceptual"].item()),
|
||||||
|
"grad/total_norm": float(grad_norm.item()),
|
||||||
"time/r_mean": float(r_seq.mean().item()),
|
"time/r_mean": float(r_seq.mean().item()),
|
||||||
"time/t_mean": float(t_seq.mean().item()),
|
"time/t_mean": float(t_seq.mean().item()),
|
||||||
"time/zero_block_frac": float((t_seq == r_seq).float().mean().item()),
|
"time/zero_block_frac": float((t_seq == r_seq).float().mean().item()),
|
||||||
|
|||||||
1
main.py
1
main.py
@@ -11,6 +11,7 @@ def build_parser() -> argparse.ArgumentParser:
|
|||||||
parser.add_argument("--seq-len", type=int, default=None)
|
parser.add_argument("--seq-len", type=int, default=None)
|
||||||
parser.add_argument("--lr", type=float, default=None)
|
parser.add_argument("--lr", type=float, default=None)
|
||||||
parser.add_argument("--weight-decay", type=float, default=None)
|
parser.add_argument("--weight-decay", type=float, default=None)
|
||||||
|
parser.add_argument("--max-grad-norm", type=float, default=None)
|
||||||
parser.add_argument("--device", type=str, default=None)
|
parser.add_argument("--device", type=str, default=None)
|
||||||
parser.add_argument("--output-dir", type=str, default=None)
|
parser.add_argument("--output-dir", type=str, default=None)
|
||||||
parser.add_argument("--project", type=str, default=None)
|
parser.add_argument("--project", type=str, default=None)
|
||||||
|
|||||||
Reference in New Issue
Block a user