From 1446f97459f9472716cd32a3e8c4b7bebcafe397 Mon Sep 17 00:00:00 2001
From: gameloader
Date: Wed, 21 Jan 2026 13:07:36 +0800
Subject: [PATCH] refactor(as_mamba): Remove dt prediction and use fixed dt

Removes the `dt_head` network and its associated configuration
parameters (dt_min, dt_max, lambda_nfe, warmup_epochs). Replaces the
predicted per-step time increments with a fixed dt of 1 / seq_len, so
the warmup phase, the NFE loss term, and the overshoot clamping in
rollout_trajectory are no longer needed. val_max_steps now defaults to
0, meaning "roll out for exactly seq_len steps".
---
 as_mamba.py | 80 ++++++++++++++---------------------------------------
 main.py     |  1 -
 2 files changed, 21 insertions(+), 60 deletions(-)

diff --git a/as_mamba.py b/as_mamba.py
index 8ea998c..f01d829 100644
--- a/as_mamba.py
+++ b/as_mamba.py
@@ -25,15 +25,11 @@ class TrainConfig:
     batch_size: int = 128
     steps_per_epoch: int = 50
     epochs: int = 60
-    warmup_epochs: int = 15
     seq_len: int = 20
     lr: float = 1e-3
     weight_decay: float = 1e-2
-    dt_min: float = 1e-3
-    dt_max: float = 0.06
     lambda_flow: float = 1.0
     lambda_pos: float = 1.0
-    lambda_nfe: float = 0.05
     radius_min: float = 0.6
     radius_max: float = 1.4
     center_min: float = -6.0
@@ -53,7 +49,7 @@ class TrainConfig:
     val_every: int = 200
     val_samples: int = 256
     val_plot_samples: int = 16
-    val_max_steps: int = 100
+    val_max_steps: int = 0
 
 
 class Mamba2Backbone(nn.Module):
@@ -92,8 +88,6 @@ class ASMamba(nn.Module):
     def __init__(self, cfg: TrainConfig) -> None:
         super().__init__()
         self.cfg = cfg
-        self.dt_min = float(cfg.dt_min)
-        self.dt_max = float(cfg.dt_max)
 
         args = Mamba2Config(
             d_model=cfg.d_model,
@@ -107,27 +101,20 @@ class ASMamba(nn.Module):
         self.backbone = Mamba2Backbone(args, use_residual=cfg.use_residual)
         self.input_proj = nn.Linear(3, cfg.d_model)
         self.delta_head = nn.Linear(cfg.d_model, 3)
-        self.dt_head = nn.Sequential(
-            nn.Linear(cfg.d_model, cfg.d_model),
-            nn.SiLU(),
-            nn.Linear(cfg.d_model, 1),
-        )
 
     def forward(
         self, x: Tensor, h: Optional[list[InferenceCache]] = None
-    ) -> tuple[Tensor, Tensor, list[InferenceCache]]:
+    ) -> tuple[Tensor, list[InferenceCache]]:
         x_proj = self.input_proj(x)
         feats, h = self.backbone(x_proj, h)
         delta = self.delta_head(feats)
-        dt_raw = self.dt_head(feats).squeeze(-1)
-        dt = torch.clamp(F.softplus(dt_raw), min=self.dt_min, max=self.dt_max)
-        return delta, dt, h
+        return delta, h
 
     def step(
         self, x: Tensor, h: list[InferenceCache]
-    ) -> tuple[Tensor, Tensor, list[InferenceCache]]:
-        delta, dt, h = self.forward(x.unsqueeze(1), h)
-        return delta[:, 0, :], dt[:, 0], h
+    ) -> tuple[Tensor, list[InferenceCache]]:
+        delta, h = self.forward(x.unsqueeze(1), h)
+        return delta[:, 0, :], h
 
     def init_cache(self, batch_size: int, device: torch.device) -> list[InferenceCache]:
         return [
@@ -240,24 +227,23 @@ def sample_batch(
 
 def compute_losses(
     delta: Tensor,
-    dt: Tensor,
     x_seq: Tensor,
     x0: Tensor,
     v_gt: Tensor,
     t_seq: Tensor,
     cfg: TrainConfig,
-) -> tuple[Tensor, Tensor, Tensor]:
-    target_disp = v_gt[:, None, :] * dt.unsqueeze(-1)
+) -> tuple[Tensor, Tensor]:
+    dt_fixed = 1.0 / cfg.seq_len
+    target_disp = v_gt[:, None, :] * dt_fixed
     flow_loss = F.mse_loss(delta, target_disp)
 
-    t_next = t_seq[None, :, None] + dt.unsqueeze(-1)
+    t_next = t_seq[None, :, None] + dt_fixed
     t_next = torch.clamp(t_next, 0.0, 1.0)
     x_target = x0[:, None, :] + t_next * v_gt[:, None, :]
     x_next_pred = x_seq + delta
     pos_loss = F.mse_loss(x_next_pred, x_target)
 
-    nfe_loss = (-torch.log(dt)).mean()
-    return flow_loss, pos_loss, nfe_loss
+    return flow_loss, pos_loss
 
 
 def validate(
@@ -271,12 +257,13 @@
 ) -> None:
     model.eval()
     center_b, radius_b = sphere_b
+    steps = cfg.seq_len if cfg.val_max_steps <= 0 else cfg.val_max_steps
 
     with torch.no_grad():
         x0 = sample_points_in_sphere(
             sphere_a[0], float(sphere_a[1].item()), cfg.val_samples, device
         )
-        traj = rollout_trajectory(model, x0, max_steps=cfg.val_max_steps)
+        traj = rollout_trajectory(model, x0, steps=steps)
         x_final = traj[:, -1, :]
 
     center_b_cpu = center_b.detach().cpu()
@@ -291,6 +278,7 @@
             "val/final_dist_mean": float(dist.mean().item()),
             "val/final_dist_min": float(dist.min().item()),
             "val/final_dist_max": float(dist.max().item()),
+            "val/steps": float(steps),
         },
         step=step,
    )
@@ -351,22 +339,15 @@ def train(cfg: TrainConfig) -> tuple[ASMamba, tuple[Tensor, Tensor], tuple[Tenso
 
     global_step = 0
     for epoch in range(cfg.epochs):
-        warmup = epoch < cfg.warmup_epochs
         model.train()
-        for p in model.dt_head.parameters():
-            p.requires_grad = not warmup
 
         for _ in range(cfg.steps_per_epoch):
             x0, x1, x_seq, t_seq = sample_batch(cfg, sphere_a, sphere_b, device)
             v_gt = x1 - x0
-            delta, dt, _ = model(x_seq)
-            if warmup:
-                dt = torch.full_like(dt, 1.0 / cfg.seq_len)
-
-            flow_loss, pos_loss, nfe_loss = compute_losses(
+            delta, _ = model(x_seq)
+            flow_loss, pos_loss = compute_losses(
                 delta=delta,
-                dt=dt,
                 x_seq=x_seq,
                 x0=x0,
                 v_gt=v_gt,
@@ -375,8 +356,6 @@ def train(cfg: TrainConfig) -> tuple[ASMamba, tuple[Tensor, Tensor], tuple[Tenso
             )
 
             loss = cfg.lambda_flow * flow_loss + cfg.lambda_pos * pos_loss
-            if not warmup:
-                loss = loss + cfg.lambda_nfe * nfe_loss
 
             optimizer.zero_grad(set_to_none=True)
             loss.backward()
@@ -388,11 +367,6 @@
                     "loss/total": float(loss.item()),
                     "loss/flow": float(flow_loss.item()),
                     "loss/pos": float(pos_loss.item()),
-                    "loss/nfe": float(nfe_loss.item()),
-                    "dt/mean": float(dt.mean().item()),
-                    "dt/min": float(dt.min().item()),
-                    "dt/max": float(dt.max().item()),
-                    "stage": 0 if warmup else 1,
                 },
                 step=global_step,
             )
@@ -408,33 +382,20 @@
 def rollout_trajectory(
     model: ASMamba,
     x0: Tensor,
-    max_steps: int = 100,
+    steps: int,
 ) -> Tensor:
     device = x0.device
     model.eval()
 
     h = model.init_cache(batch_size=x0.shape[0], device=device)
     x = x0
-    total_time = torch.zeros(x0.shape[0], device=device)
     traj = [x0.detach().cpu()]
 
     with torch.no_grad():
-        for _ in range(max_steps):
-            delta, dt, h = model.step(x, h)
-            dt = torch.clamp(dt, min=model.dt_min, max=model.dt_max)
-            remaining = 1.0 - total_time
-            overshoot = dt > remaining
-            if overshoot.any():
-                scale = (remaining / dt).unsqueeze(-1)
-                delta = torch.where(overshoot.unsqueeze(-1), delta * scale, delta)
-                dt = torch.where(overshoot, remaining, dt)
-
+        for _ in range(steps):
+            delta, h = model.step(x, h)
             x = x + delta
-            total_time = total_time + dt
             traj.append(x.detach().cpu())
 
-            if torch.all(total_time >= 1.0 - 1e-6):
-                break
-
     return torch.stack(traj, dim=1)
 
@@ -504,7 +465,8 @@ def run_training_and_plot(cfg: TrainConfig) -> Path:
     x0 = sample_points_in_sphere(
         sphere_a[0], float(sphere_a[1].item()), plot_samples, device
     )
-    traj = rollout_trajectory(model, x0, max_steps=cfg.val_max_steps)
+    steps = cfg.seq_len if cfg.val_max_steps <= 0 else cfg.val_max_steps
+    traj = rollout_trajectory(model, x0, steps=steps)
     output_dir = Path(cfg.output_dir)
     save_path = output_dir / "as_mamba_trajectory.png"
     plot_trajectories(traj, sphere_a, sphere_b, save_path)
diff --git a/main.py b/main.py
index 1e49981..f93a783 100644
--- a/main.py
+++ b/main.py
@@ -6,7 +6,6 @@ from as_mamba import TrainConfig, run_training_and_plot
 def build_parser() -> argparse.ArgumentParser:
     parser = argparse.ArgumentParser(description="Train AS-Mamba on sphere-to-sphere flow.")
     parser.add_argument("--epochs", type=int, default=None)
-    parser.add_argument("--warmup-epochs", type=int, default=None)
     parser.add_argument("--batch-size", type=int, default=None)
     parser.add_argument("--steps-per-epoch", type=int, default=None)
     parser.add_argument("--seq-len", type=int, default=None)
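
Reviewer note (illustrative only, not part of the patch to apply): with dt
fixed at 1.0 / cfg.seq_len the time grid is uniform, so exactly seq_len Euler
updates integrate t from 0 to 1; this is why the overshoot clamping removed
from rollout_trajectory is no longer needed. A minimal standalone sketch of
that arithmetic, assuming only torch and random stand-in point clouds in
place of the repo's sphere samplers:

    import torch

    seq_len = 20                    # matches TrainConfig.seq_len
    dt_fixed = 1.0 / seq_len        # the fixed step introduced by this patch

    x0 = torch.randn(128, 3)        # stand-in for points sampled in sphere A
    x1 = torch.randn(128, 3) + 6.0  # stand-in for matched targets in sphere B
    v_gt = x1 - x0                  # straight-line flow velocity

    # delta_head is trained to predict v_gt * dt_fixed at each step; rolling
    # out the ground-truth displacement shows the schedule lands on x1 after
    # exactly seq_len steps, with no remaining-time bookkeeping.
    x = x0.clone()
    for _ in range(seq_len):
        x = x + v_gt * dt_fixed

    assert torch.allclose(x, x1, atol=1e-4)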