From cac3236f9d4dd7db55a23d0eeed924955312b473 Mon Sep 17 00:00:00 2001
From: gameloader
Date: Wed, 21 Jan 2026 15:14:04 +0800
Subject: [PATCH] Add configurable dt sampling and loss toggles

---
 as_mamba.py       | 163 +++++++++++++++++++++++++++++++++++-----------
 main.py           |  22 ++++++-
 train_as_mamba.sh |  62 ++++++++++++++++++
 3 files changed, 207 insertions(+), 40 deletions(-)
 create mode 100755 train_as_mamba.sh

diff --git a/as_mamba.py b/as_mamba.py
index 8ea998c..f3c5ba1 100644
--- a/as_mamba.py
+++ b/as_mamba.py
@@ -25,15 +25,18 @@ class TrainConfig:
     batch_size: int = 128
     steps_per_epoch: int = 50
     epochs: int = 60
-    warmup_epochs: int = 15
     seq_len: int = 20
     lr: float = 1e-3
     weight_decay: float = 1e-2
     dt_min: float = 1e-3
     dt_max: float = 0.06
+    dt_alpha: float = 8.0
     lambda_flow: float = 1.0
     lambda_pos: float = 1.0
-    lambda_nfe: float = 0.05
+    lambda_dt: float = 0.05
+    use_flow_loss: bool = True
+    use_pos_loss: bool = False
+    use_dt_loss: bool = True
     radius_min: float = 0.6
     radius_max: float = 1.4
     center_min: float = -6.0
@@ -53,7 +56,7 @@ class TrainConfig:
     val_every: int = 200
     val_samples: int = 256
     val_plot_samples: int = 16
-    val_max_steps: int = 100
+    val_max_steps: int = 0
 
 
 class Mamba2Backbone(nn.Module):
@@ -221,21 +224,41 @@ def sample_sphere_params(cfg: TrainConfig, device: torch.device) -> tuple[Tensor
     )
 
 
+def sample_time_sequence(cfg: TrainConfig, batch_size: int, device: torch.device) -> Tensor:
+    alpha = float(cfg.dt_alpha)
+    if alpha <= 0:
+        raise ValueError("dt_alpha must be > 0")
+    dist = torch.distributions.Gamma(alpha, 1.0)
+    raw = dist.sample((batch_size, cfg.seq_len)).to(device)
+    dt_seq = raw / raw.sum(dim=-1, keepdim=True)
+    base = 1.0 / cfg.seq_len
+    max_dt = float(cfg.dt_max)
+    if max_dt <= base:
+        return torch.full_like(dt_seq, base)
+    max_current = dt_seq.max(dim=-1, keepdim=True).values
+    if (max_current > max_dt).any():
+        gamma = (max_dt - base) / (max_current - base)
+        gamma = gamma.clamp(0.0, 1.0)
+        dt_seq = gamma * dt_seq + (1.0 - gamma) * base
+    return dt_seq
+
+
 def sample_batch(
     cfg: TrainConfig,
     sphere_a: tuple[Tensor, Tensor],
     sphere_b: tuple[Tensor, Tensor],
     device: torch.device,
-) -> tuple[Tensor, Tensor, Tensor, Tensor]:
+) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
     center_a, radius_a = sphere_a
     center_b, radius_b = sphere_b
     x0 = sample_points_in_sphere(center_a, float(radius_a.item()), cfg.batch_size, device)
     x1 = sample_points_in_sphere(center_b, float(radius_b.item()), cfg.batch_size, device)
     v_gt = x1 - x0
-    dt_fixed = 1.0 / cfg.seq_len
-    t_seq = torch.arange(cfg.seq_len, device=device) * dt_fixed
-    x_seq = x0[:, None, :] + t_seq[None, :, None] * v_gt[:, None, :]
-    return x0, x1, x_seq, t_seq
+    dt_seq = sample_time_sequence(cfg, cfg.batch_size, device)
+    t_seq = torch.cumsum(dt_seq, dim=-1)
+    t_seq = torch.cat([torch.zeros(cfg.batch_size, 1, device=device), t_seq[:, :-1]], dim=-1)
+    x_seq = x0[:, None, :] + t_seq[:, :, None] * v_gt[:, None, :]
+    return x0, x1, x_seq, t_seq, dt_seq
 
 
 def compute_losses(
@@ -245,19 +268,26 @@ def compute_losses(
     x0: Tensor,
     v_gt: Tensor,
     t_seq: Tensor,
+    dt_seq: Tensor,
     cfg: TrainConfig,
-) -> tuple[Tensor, Tensor, Tensor]:
-    target_disp = v_gt[:, None, :] * dt.unsqueeze(-1)
-    flow_loss = F.mse_loss(delta, target_disp)
+) -> dict[str, Tensor]:
+    losses: dict[str, Tensor] = {}
 
-    t_next = t_seq[None, :, None] + dt.unsqueeze(-1)
-    t_next = torch.clamp(t_next, 0.0, 1.0)
-    x_target = x0[:, None, :] + t_next * v_gt[:, None, :]
-    x_next_pred = x_seq + delta
-    pos_loss = F.mse_loss(x_next_pred, x_target)
+    if cfg.use_flow_loss:
+        target_disp = v_gt[:, None, :] * dt.unsqueeze(-1)
+        losses["flow"] = F.mse_loss(delta, target_disp)
 
-    nfe_loss = (-torch.log(dt)).mean()
-    return flow_loss, pos_loss, nfe_loss
+    if cfg.use_pos_loss:
+        t_next = t_seq + dt
+        t_next = torch.clamp(t_next, 0.0, 1.0)
+        x_target = x0[:, None, :] + t_next.unsqueeze(-1) * v_gt[:, None, :]
+        x_next_pred = x_seq + delta
+        losses["pos"] = F.mse_loss(x_next_pred, x_target)
+
+    if cfg.use_dt_loss:
+        losses["dt"] = F.mse_loss(dt, dt_seq)
+
+    return losses
 
 
 def validate(
@@ -271,12 +301,13 @@ def validate(
 ) -> None:
     model.eval()
     center_b, radius_b = sphere_b
+    max_steps = cfg.seq_len if cfg.val_max_steps <= 0 else cfg.val_max_steps
 
     with torch.no_grad():
         x0 = sample_points_in_sphere(
             sphere_a[0], float(sphere_a[1].item()), cfg.val_samples, device
         )
-        traj = rollout_trajectory(model, x0, max_steps=cfg.val_max_steps)
+        traj = rollout_trajectory(model, x0, max_steps=max_steps)
         x_final = traj[:, -1, :]
 
     center_b_cpu = center_b.detach().cpu()
@@ -291,6 +322,7 @@ def validate(
             "val/final_dist_mean": float(dist.mean().item()),
             "val/final_dist_min": float(dist.min().item()),
             "val/final_dist_max": float(dist.max().item()),
+            "val/max_steps": float(max_steps),
         },
         step=step,
     )
@@ -351,54 +383,86 @@ def train(cfg: TrainConfig) -> tuple[ASMamba, tuple[Tensor, Tensor], tuple[Tenso
     global_step = 0
 
     for epoch in range(cfg.epochs):
-        warmup = epoch < cfg.warmup_epochs
         model.train()
-        for p in model.dt_head.parameters():
-            p.requires_grad = not warmup
 
         for _ in range(cfg.steps_per_epoch):
-            x0, x1, x_seq, t_seq = sample_batch(cfg, sphere_a, sphere_b, device)
+            x0, x1, x_seq, t_seq, dt_seq = sample_batch(cfg, sphere_a, sphere_b, device)
             v_gt = x1 - x0
 
            delta, dt, _ = model(x_seq)
-            if warmup:
-                dt = torch.full_like(dt, 1.0 / cfg.seq_len)
 
-            flow_loss, pos_loss, nfe_loss = compute_losses(
+            losses = compute_losses(
                 delta=delta,
                 dt=dt,
                 x_seq=x_seq,
                 x0=x0,
                 v_gt=v_gt,
                 t_seq=t_seq,
+                dt_seq=dt_seq,
                 cfg=cfg,
             )
-            loss = cfg.lambda_flow * flow_loss + cfg.lambda_pos * pos_loss
-            if not warmup:
-                loss = loss + cfg.lambda_nfe * nfe_loss
+            if not losses:
+                raise RuntimeError("No loss enabled: enable at least one of flow/pos/dt.")
+            loss = torch.tensor(0.0, device=device)
+            if "flow" in losses:
+                loss = loss + cfg.lambda_flow * losses["flow"]
+            if "pos" in losses:
+                loss = loss + cfg.lambda_pos * losses["pos"]
+            if "dt" in losses:
+                loss = loss + cfg.lambda_dt * losses["dt"]
 
             optimizer.zero_grad(set_to_none=True)
             loss.backward()
             optimizer.step()
 
             if global_step % 10 == 0:
+                dt_min = float(dt.min().item())
+                dt_max = float(dt.max().item())
+                dt_mean = float(dt.mean().item())
+                dt_gt_min = float(dt_seq.min().item())
+                dt_gt_max = float(dt_seq.max().item())
+                dt_gt_mean = float(dt_seq.mean().item())
+                eps = 1e-6
+                clamp_min_ratio = float((dt <= cfg.dt_min + eps).float().mean().item())
+                clamp_max_ratio = float((dt >= cfg.dt_max - eps).float().mean().item())
+                clamp_any_ratio = float(
+                    ((dt <= cfg.dt_min + eps) | (dt >= cfg.dt_max - eps)).float().mean().item()
+                )
                 logger.log(
                     {
                         "loss/total": float(loss.item()),
-                        "loss/flow": float(flow_loss.item()),
-                        "loss/pos": float(pos_loss.item()),
-                        "loss/nfe": float(nfe_loss.item()),
-                        "dt/mean": float(dt.mean().item()),
-                        "dt/min": float(dt.min().item()),
-                        "dt/max": float(dt.max().item()),
-                        "stage": 0 if warmup else 1,
+                        "loss/flow": float(losses.get("flow", torch.tensor(0.0)).item()),
+                        "loss/pos": float(losses.get("pos", torch.tensor(0.0)).item()),
+                        "loss/dt": float(losses.get("dt", torch.tensor(0.0)).item()),
+                        "dt/pred_mean": dt_mean,
+                        "dt/pred_min": dt_min,
+                        "dt/pred_max": dt_max,
+                        "dt/gt_mean": dt_gt_mean,
+                        "dt/gt_min": dt_gt_min,
+                        "dt/gt_max": dt_gt_max,
+                        "dt/clamp_min_ratio": clamp_min_ratio,
+                        "dt/clamp_max_ratio": clamp_max_ratio,
+                        "dt/clamp_any_ratio": clamp_any_ratio,
                     },
                     step=global_step,
                 )
 
             if cfg.val_every > 0 and global_step > 0 and global_step % cfg.val_every == 0:
                 validate(model, cfg, sphere_a, sphere_b, device, logger, global_step)
+                dt_hist_path = Path(cfg.output_dir) / f"dt_hist_step_{global_step:06d}.png"
+                plot_dt_hist(
+                    dt,
+                    dt_seq,
+                    dt_hist_path,
+                    title=f"dt Distribution (step {global_step})",
+                )
+                logger.log_image(
+                    "train/dt_hist",
+                    dt_hist_path,
+                    caption=f"step {global_step}",
+                    step=global_step,
+                )
 
             global_step += 1
 
     logger.finish()
@@ -408,7 +472,7 @@ def train(cfg: TrainConfig) -> tuple[ASMamba, tuple[Tensor, Tensor], tuple[Tenso
 def rollout_trajectory(
     model: ASMamba,
     x0: Tensor,
-    max_steps: int = 100,
+    max_steps: int,
 ) -> Tensor:
     device = x0.device
     model.eval()
@@ -427,11 +491,10 @@ def rollout_trajectory(
         scale = (remaining / dt).unsqueeze(-1)
         delta = torch.where(overshoot.unsqueeze(-1), delta * scale, delta)
         dt = torch.where(overshoot, remaining, dt)
         x = x + delta
         total_time = total_time + dt
         traj.append(x.detach().cpu())
-
         if torch.all(total_time >= 1.0 - 1e-6):
             break
@@ -496,6 +559,27 @@ def plot_trajectories(
     plt.close(fig)
 
 
+def plot_dt_hist(
+    dt_pred: Tensor,
+    dt_gt: Tensor,
+    save_path: Path,
+    title: str = "dt Distribution",
+) -> None:
+    dt_pred_np = dt_pred.detach().cpu().numpy().reshape(-1)
+    dt_gt_np = dt_gt.detach().cpu().numpy().reshape(-1)
+
+    fig, ax = plt.subplots(figsize=(6, 4))
+    ax.hist(dt_gt_np, bins=30, alpha=0.6, label="dt_gt", color="steelblue")
+    ax.hist(dt_pred_np, bins=30, alpha=0.6, label="dt_pred", color="orange")
+    ax.set_title(title)
+    ax.set_xlabel("dt")
+    ax.set_ylabel("count")
+    ax.legend(loc="best")
+    fig.tight_layout()
+    fig.savefig(save_path, dpi=160)
+    plt.close(fig)
+
+
 def run_training_and_plot(cfg: TrainConfig) -> Path:
     model, sphere_a, sphere_b = train(cfg)
     device = next(model.parameters()).device
@@ -504,7 +588,8 @@ def run_training_and_plot(cfg: TrainConfig) -> Path:
     x0 = sample_points_in_sphere(
         sphere_a[0], float(sphere_a[1].item()), plot_samples, device
     )
-    traj = rollout_trajectory(model, x0, max_steps=cfg.val_max_steps)
+    max_steps = cfg.seq_len if cfg.val_max_steps <= 0 else cfg.val_max_steps
+    traj = rollout_trajectory(model, x0, max_steps=max_steps)
     output_dir = Path(cfg.output_dir)
     save_path = output_dir / "as_mamba_trajectory.png"
     plot_trajectories(traj, sphere_a, sphere_b, save_path)
diff --git a/main.py b/main.py
index 1e49981..c8344a3 100644
--- a/main.py
+++ b/main.py
@@ -6,7 +6,6 @@ from as_mamba import TrainConfig, run_training_and_plot
 def build_parser() -> argparse.ArgumentParser:
     parser = argparse.ArgumentParser(description="Train AS-Mamba on sphere-to-sphere flow.")
     parser.add_argument("--epochs", type=int, default=None)
-    parser.add_argument("--warmup-epochs", type=int, default=None)
     parser.add_argument("--batch-size", type=int, default=None)
     parser.add_argument("--steps-per-epoch", type=int, default=None)
     parser.add_argument("--seq-len", type=int, default=None)
@@ -15,6 +14,27 @@ def build_parser() -> argparse.ArgumentParser:
     parser.add_argument("--output-dir", type=str, default=None)
     parser.add_argument("--project", type=str, default=None)
     parser.add_argument("--run-name", type=str, default=None)
+    parser.add_argument("--dt-alpha", type=float, default=None)
+    parser.add_argument("--dt-min", type=float, default=None)
+    parser.add_argument("--dt-max", type=float, default=None)
+    parser.add_argument("--lambda-flow", type=float, default=None)
+    parser.add_argument("--lambda-pos", type=float, default=None)
+    parser.add_argument("--lambda-dt", type=float, default=None)
+    parser.add_argument(
+        "--use-flow-loss",
+        action=argparse.BooleanOptionalAction,
+        default=None,
+    )
+    parser.add_argument(
+        "--use-pos-loss",
+        action=argparse.BooleanOptionalAction,
+        default=None,
+    )
+    parser.add_argument(
+        "--use-dt-loss",
+        action=argparse.BooleanOptionalAction,
+        default=None,
+    )
     parser.add_argument("--val-every", type=int, default=None)
     parser.add_argument("--val-samples", type=int, default=None)
     parser.add_argument("--val-plot-samples", type=int, default=None)
diff --git a/train_as_mamba.sh b/train_as_mamba.sh
new file mode 100755
index 0000000..fe3688f
--- /dev/null
+++ b/train_as_mamba.sh
@@ -0,0 +1,62 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+DEVICE="cuda"
+EPOCHS=80
+STEPS_PER_EPOCH=100
+BATCH_SIZE=256
+SEQ_LEN=20
+LR=1e-3
+DT_MIN=1e-3
+DT_MAX=0.10
+DT_ALPHA=6.0
+LAMBDA_FLOW=1.0
+LAMBDA_POS=0.0
+LAMBDA_DT=0.5
+USE_FLOW_LOSS=true
+USE_POS_LOSS=false
+USE_DT_LOSS=true
+VAL_EVERY=200
+VAL_SAMPLES=512
+VAL_PLOT_SAMPLES=16
+VAL_MAX_STEPS=100
+CENTER_MIN=-8
+CENTER_MAX=8
+CENTER_DISTANCE_MIN=8
+PROJECT="as-mamba"
+RUN_NAME="sphere-to-sphere-dt"
+OUTPUT_DIR="outputs"
+
+USE_FLOW_FLAG="--use-flow-loss"
+if [ "${USE_FLOW_LOSS}" = "false" ]; then USE_FLOW_FLAG="--no-use-flow-loss"; fi
+USE_POS_FLAG="--use-pos-loss"
+if [ "${USE_POS_LOSS}" = "false" ]; then USE_POS_FLAG="--no-use-pos-loss"; fi
+USE_DT_FLAG="--use-dt-loss"
+if [ "${USE_DT_LOSS}" = "false" ]; then USE_DT_FLAG="--no-use-dt-loss"; fi
+
+uv run python main.py \
+    --device "${DEVICE}" \
+    --epochs "${EPOCHS}" \
+    --steps-per-epoch "${STEPS_PER_EPOCH}" \
+    --batch-size "${BATCH_SIZE}" \
+    --seq-len "${SEQ_LEN}" \
+    --lr "${LR}" \
+    --dt-min "${DT_MIN}" \
+    --dt-max "${DT_MAX}" \
+    --dt-alpha "${DT_ALPHA}" \
+    --lambda-flow "${LAMBDA_FLOW}" \
+    --lambda-pos "${LAMBDA_POS}" \
+    --lambda-dt "${LAMBDA_DT}" \
+    ${USE_FLOW_FLAG} \
+    ${USE_POS_FLAG} \
+    ${USE_DT_FLAG} \
+    --val-every "${VAL_EVERY}" \
+    --val-samples "${VAL_SAMPLES}" \
+    --val-plot-samples "${VAL_PLOT_SAMPLES}" \
+    --val-max-steps "${VAL_MAX_STEPS}" \
+    --center-min "${CENTER_MIN}" \
+    --center-max "${CENTER_MAX}" \
+    --center-distance-min "${CENTER_DISTANCE_MIN}" \
+    --project "${PROJECT}" \
+    --run-name "${RUN_NAME}" \
+    --output-dir "${OUTPUT_DIR}"
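
Notes (illustrative, not part of the diff):

1) Each row of dt_seq in sample_time_sequence is a symmetric Dirichlet(dt_alpha)
   draw: independent Gamma(dt_alpha, 1) samples normalized per row, so every row
   sums to 1 with per-entry mean 1/seq_len. The blend
   gamma * dt_seq + (1 - gamma) * base preserves the row sum (both terms sum to 1)
   and, for rows over the cap, pins the row maximum to exactly dt_max, since
   gamma = (dt_max - base) / (max_current - base); rows already under the cap get
   gamma clamped to 1 and stay unchanged. A minimal property check, assuming the
   patch is applied and as_mamba.py is importable:

       import torch
       from as_mamba import TrainConfig, sample_time_sequence

       cfg = TrainConfig()  # dt_alpha=8.0, dt_max=0.06, seq_len=20
       dt_seq = sample_time_sequence(cfg, batch_size=1024, device=torch.device("cpu"))

       # Rows sum to 1 up to float error, and no entry exceeds the cap.
       assert torch.allclose(dt_seq.sum(dim=-1), torch.ones(1024), atol=1e-5)
       assert float(dt_seq.max().item()) <= cfg.dt_max + 1e-6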
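2) In sample_batch, t_seq is the right-shifted cumulative sum of dt_seq, so
   x_seq[:, i] sits at time t_seq[:, i] while dt_seq[:, i] is the step taken from
   that state. A worked row with hypothetical values (seq_len=4):

       import torch

       dt_seq = torch.tensor([[0.1, 0.3, 0.4, 0.2]])  # one row, sums to 1
       t_seq = torch.cumsum(dt_seq, dim=-1)            # [[0.1, 0.4, 0.8, 1.0]]
       t_seq = torch.cat([torch.zeros(1, 1), t_seq[:, :-1]], dim=-1)
       print(t_seq)  # tensor([[0.0000, 0.1000, 0.4000, 0.8000]])

   The flow target scales v_gt by the model's own predicted dt, and the dt loss
   regresses that prediction onto these sampled steps.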
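3) The use_*_loss switches surface as BooleanOptionalAction pairs on the CLI
   (--use-dt-loss / --no-use-dt-loss, and likewise for flow and pos) and as plain
   dataclass fields, so a run without the dt regression can also be configured
   directly in Python; a sketch, assuming the patched modules:

       from as_mamba import TrainConfig, run_training_and_plot

       cfg = TrainConfig(use_flow_loss=True, use_pos_loss=True, use_dt_loss=False)
       run_training_and_plot(cfg)

   Disabling all three toggles is rejected up front in the training loop rather
   than silently optimizing a constant zero.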