"""AS-Mamba: a Mamba2-based sequence model that transports points sampled
inside one 3D sphere to another, predicting a displacement and an adaptive
step size dt at every step of the trajectory."""

from __future__ import annotations

from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Optional

import matplotlib

matplotlib.use("Agg")  # headless backend; figures are only saved to disk

import numpy as np
import torch
import torch.nn.functional as F
from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401
from torch import Tensor, nn

from mamba2_minimal import InferenceCache, Mamba2, Mamba2Config, RMSNorm


@dataclass
class TrainConfig:
    seed: int = 42
    device: str = "cuda"

    # Optimization.
    batch_size: int = 128
    steps_per_epoch: int = 50
    epochs: int = 60
    seq_len: int = 20
    lr: float = 1e-3
    weight_decay: float = 1e-2

    # Adaptive step-size (dt) schedule.
    dt_min: float = 1e-3
    dt_max: float = 0.06
    dt_alpha: float = 8.0

    # Loss weights and switches.
    lambda_flow: float = 1.0
    lambda_pos: float = 1.0
    lambda_dt: float = 0.05
    use_flow_loss: bool = True
    use_pos_loss: bool = False
    use_dt_loss: bool = True

    # Source/target sphere geometry.
    radius_min: float = 0.6
    radius_max: float = 1.4
    center_min: float = -6.0
    center_max: float = 6.0
    center_distance_min: float = 6.0

    # Mamba2 backbone.
    d_model: int = 128
    n_layer: int = 4
    d_state: int = 64
    d_conv: int = 4
    expand: int = 2
    headdim: int = 32
    chunk_size: int = 1
    use_residual: bool = False

    # Logging and validation.
    output_dir: str = "outputs"
    project: str = "as-mamba"
    run_name: str = "sphere-to-sphere"
    val_every: int = 200
    val_samples: int = 256
    val_plot_samples: int = 16
    val_max_steps: int = 0


class Mamba2Backbone(nn.Module):
    """Stack of pre-norm Mamba2 mixer layers with a final RMSNorm."""

    def __init__(self, args: Mamba2Config, use_residual: bool = True) -> None:
        super().__init__()
        self.args = args
        self.use_residual = use_residual
        self.layers = nn.ModuleList(
            [
                nn.ModuleDict(
                    dict(
                        mixer=Mamba2(args),
                        norm=RMSNorm(args.d_model),
                    )
                )
                for _ in range(args.n_layer)
            ]
        )
        self.norm_f = RMSNorm(args.d_model)

    def forward(
        self, x: Tensor, h: Optional[list[InferenceCache]] = None
    ) -> tuple[Tensor, list[InferenceCache]]:
        if h is None:
            h = [None for _ in range(self.args.n_layer)]
        for i, layer in enumerate(self.layers):
            y, h[i] = layer["mixer"](layer["norm"](x), h[i])
            x = x + y if self.use_residual else y
        x = self.norm_f(x)
        return x, h


class ASMamba(nn.Module):
    """Predicts a 3D displacement and a positive step size dt per position."""

    def __init__(self, cfg: TrainConfig) -> None:
        super().__init__()
        self.cfg = cfg
        self.dt_min = float(cfg.dt_min)
        self.dt_max = float(cfg.dt_max)
        args = Mamba2Config(
            d_model=cfg.d_model,
            n_layer=cfg.n_layer,
            d_state=cfg.d_state,
            d_conv=cfg.d_conv,
            expand=cfg.expand,
            headdim=cfg.headdim,
            chunk_size=cfg.chunk_size,
        )
        self.backbone = Mamba2Backbone(args, use_residual=cfg.use_residual)
        self.input_proj = nn.Linear(3, cfg.d_model)
        self.delta_head = nn.Linear(cfg.d_model, 3)
        self.dt_head = nn.Sequential(
            nn.Linear(cfg.d_model, cfg.d_model),
            nn.SiLU(),
            nn.Linear(cfg.d_model, 1),
        )

    def forward(
        self, x: Tensor, h: Optional[list[InferenceCache]] = None
    ) -> tuple[Tensor, Tensor, list[InferenceCache]]:
        x_proj = self.input_proj(x)
        feats, h = self.backbone(x_proj, h)
        delta = self.delta_head(feats)
        dt_raw = self.dt_head(feats).squeeze(-1)
        # softplus keeps dt positive; the clamp enforces [dt_min, dt_max].
        dt = torch.clamp(F.softplus(dt_raw), min=self.dt_min, max=self.dt_max)
        return delta, dt, h

    def step(
        self, x: Tensor, h: list[InferenceCache]
    ) -> tuple[Tensor, Tensor, list[InferenceCache]]:
        """Single-token step for autoregressive rollout."""
        delta, dt, h = self.forward(x.unsqueeze(1), h)
        return delta[:, 0, :], dt[:, 0], h

    def init_cache(self, batch_size: int, device: torch.device) -> list[InferenceCache]:
        return [
            InferenceCache.alloc(batch_size, self.backbone.args, device=device)
            for _ in range(self.backbone.args.n_layer)
        ]

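
# A minimal smoke test for the model above, mirroring how train() calls it on
# a full (batch, seq_len, 3) sequence. The helper name and the batch size of 4
# are illustrative only; this is not part of the training pipeline.
def _check_asmamba_shapes() -> None:
    cfg = TrainConfig(device="cpu")
    model = ASMamba(cfg)
    x = torch.randn(4, cfg.seq_len, 3)  # (batch, seq, xyz)
    delta, dt, _ = model(x)
    assert delta.shape == (4, cfg.seq_len, 3)  # per-step displacement
    assert dt.shape == (4, cfg.seq_len)  # per-step dt in [dt_min, dt_max]
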

class SwanLogger:
    """Thin wrapper around swanlab that degrades to a no-op when the package
    is missing or initialization fails."""

    def __init__(self, cfg: TrainConfig) -> None:
        self.enabled = False
        self._swan = None
        self._run = None
        try:
            import swanlab  # type: ignore

            self._swan = swanlab
            self._run = self._swan.init(
                project=cfg.project,
                experiment_name=cfg.run_name,
                config=asdict(cfg),
            )
            self.enabled = True
        except Exception:
            self.enabled = False

    def log(self, metrics: dict, step: int | None = None) -> None:
        if not self.enabled:
            return
        target = self._run if self._run is not None else self._swan
        if step is None:
            target.log(metrics)
            return
        try:
            target.log(metrics, step=step)
        except TypeError:
            # Older APIs may not accept a step kwarg; fold it into the payload.
            payload = dict(metrics)
            payload["step"] = step
            target.log(payload)

    def log_image(
        self, key: str, image_path: Path, caption: str | None = None, step: int | None = None
    ) -> None:
        if not self.enabled:
            return
        image = self._swan.Image(str(image_path), caption=caption)
        self.log({key: image}, step=step)

    def finish(self) -> None:
        if not self.enabled:
            return
        finish = None
        if self._run is not None:
            finish = getattr(self._run, "finish", None)
        if finish is None:
            finish = getattr(self._swan, "finish", None)
        if callable(finish):
            finish()


def set_seed(seed: int) -> None:
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def sample_points_in_sphere(
    center: Tensor, radius: float, batch_size: int, device: torch.device
) -> Tensor:
    """Sample points uniformly inside a ball: a random unit direction times a
    radius r = R * u^(1/3), which makes the density uniform in volume."""
    direction = torch.randn(batch_size, 3, device=device)
    direction = direction / (direction.norm(dim=-1, keepdim=True) + 1e-8)
    u = torch.rand(batch_size, 1, device=device)
    r = radius * torch.pow(u, 1.0 / 3.0)
    return center + direction * r


def sample_sphere_params(
    cfg: TrainConfig, device: torch.device
) -> tuple[tuple[Tensor, Tensor], tuple[Tensor, Tensor]]:
    center_a = torch.empty(3, device=device).uniform_(cfg.center_min, cfg.center_max)
    center_b = torch.empty(3, device=device).uniform_(cfg.center_min, cfg.center_max)
    # Rejection-sample center_b until the spheres are far enough apart.
    for _ in range(128):
        if torch.norm(center_a - center_b) >= cfg.center_distance_min:
            break
        center_b = torch.empty(3, device=device).uniform_(cfg.center_min, cfg.center_max)
    if torch.norm(center_a - center_b) < cfg.center_distance_min:
        # Fallback if rejection sampling never succeeded: place center_b at
        # exactly the minimum separation from center_a.
        center_b = center_a + torch.tensor(
            [cfg.center_distance_min, 0.0, 0.0], device=device
        )
    radius_a = float(torch.empty(1).uniform_(cfg.radius_min, cfg.radius_max).item())
    radius_b = float(torch.empty(1).uniform_(cfg.radius_min, cfg.radius_max).item())
    return (center_a, torch.tensor(radius_a, device=device)), (
        center_b,
        torch.tensor(radius_b, device=device),
    )


def sample_time_sequence(cfg: TrainConfig, batch_size: int, device: torch.device) -> Tensor:
    """Draw per-step time increments from a normalized Gamma sequence so each
    row sums to 1, then shrink toward the uniform grid where needed so no
    single step exceeds cfg.dt_max."""
    alpha = float(cfg.dt_alpha)
    if alpha <= 0:
        raise ValueError("dt_alpha must be > 0")
    dist = torch.distributions.Gamma(alpha, 1.0)
    raw = dist.sample((batch_size, cfg.seq_len)).to(device)
    dt_seq = raw / raw.sum(dim=-1, keepdim=True)
    base = 1.0 / cfg.seq_len
    max_dt = float(cfg.dt_max)
    if max_dt <= base:
        return torch.full_like(dt_seq, base)
    max_current = dt_seq.max(dim=-1, keepdim=True).values
    if (max_current > max_dt).any():
        # Blending with the uniform grid preserves the row sum of 1 and maps
        # the largest step to exactly max_dt wherever gamma < 1.
        gamma = (max_dt - base) / (max_current - base)
        gamma = gamma.clamp(0.0, 1.0)
        dt_seq = gamma * dt_seq + (1.0 - gamma) * base
    return dt_seq


def sample_batch(
    cfg: TrainConfig,
    sphere_a: tuple[Tensor, Tensor],
    sphere_b: tuple[Tensor, Tensor],
    device: torch.device,
) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
    center_a, radius_a = sphere_a
    center_b, radius_b = sphere_b
    x0 = sample_points_in_sphere(center_a, float(radius_a.item()), cfg.batch_size, device)
    x1 = sample_points_in_sphere(center_b, float(radius_b.item()), cfg.batch_size, device)
    v_gt = x1 - x0
    dt_seq = sample_time_sequence(cfg, cfg.batch_size, device)
    # t_seq holds the time at the *start* of each step: [0, dt_1, dt_1+dt_2, ...].
    t_seq = torch.cumsum(dt_seq, dim=-1)
    t_seq = torch.cat([torch.zeros(cfg.batch_size, 1, device=device), t_seq[:, :-1]], dim=-1)
    # Ground-truth straight-line interpolation x(t) = x0 + t * (x1 - x0).
    x_seq = x0[:, None, :] + t_seq[:, :, None] * v_gt[:, None, :]
    return x0, x1, x_seq, t_seq, dt_seq

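
# Sketch: a quick check of the invariants sample_time_sequence is designed to
# satisfy. Illustrative helper only; the second assertion assumes
# cfg.dt_max > 1 / cfg.seq_len (true for the defaults above), otherwise the
# function returns the uniform grid at 1 / seq_len.
def _check_time_sequence(cfg: TrainConfig) -> None:
    dt_seq = sample_time_sequence(cfg, batch_size=8, device=torch.device("cpu"))
    # Each row of step sizes covers t in [0, 1] exactly.
    assert torch.allclose(dt_seq.sum(dim=-1), torch.ones(8), atol=1e-5)
    if cfg.dt_max > 1.0 / cfg.seq_len:
        # No single step exceeds the configured cap.
        assert float(dt_seq.max().item()) <= cfg.dt_max + 1e-6
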

def compute_losses(
    delta: Tensor,
    dt: Tensor,
    x_seq: Tensor,
    x0: Tensor,
    v_gt: Tensor,
    t_seq: Tensor,
    dt_seq: Tensor,
    cfg: TrainConfig,
) -> dict[str, Tensor]:
    losses: dict[str, Tensor] = {}
    if cfg.use_flow_loss:
        # Match the predicted displacement to dt * (x1 - x0), using the
        # model's own predicted dt.
        target_disp = v_gt[:, None, :] * dt.unsqueeze(-1)
        losses["flow"] = F.mse_loss(delta, target_disp)
    if cfg.use_pos_loss:
        t_next = t_seq + dt
        t_next = torch.clamp(t_next, 0.0, 1.0)
        x_target = x0[:, None, :] + t_next.unsqueeze(-1) * v_gt[:, None, :]
        x_next_pred = x_seq + delta
        losses["pos"] = F.mse_loss(x_next_pred, x_target)
    if cfg.use_dt_loss:
        losses["dt"] = F.mse_loss(dt, dt_seq)
    return losses


def validate(
    model: ASMamba,
    cfg: TrainConfig,
    sphere_a: tuple[Tensor, Tensor],
    sphere_b: tuple[Tensor, Tensor],
    device: torch.device,
    logger: SwanLogger,
    step: int,
) -> None:
    model.eval()
    center_b, radius_b = sphere_b
    max_steps = cfg.seq_len if cfg.val_max_steps <= 0 else cfg.val_max_steps
    with torch.no_grad():
        x0 = sample_points_in_sphere(
            sphere_a[0], float(sphere_a[1].item()), cfg.val_samples, device
        )
        traj = rollout_trajectory(model, x0, max_steps=max_steps)
    # rollout_trajectory returns CPU tensors, so move the target sphere to CPU.
    x_final = traj[:, -1, :]
    center_b_cpu = center_b.detach().cpu()
    radius_b_cpu = radius_b.detach().cpu()
    dist = torch.linalg.norm(x_final - center_b_cpu, dim=-1)
    inside = dist <= radius_b_cpu
    logger.log(
        {
            "val/inside_ratio": float(inside.float().mean().item()),
            "val/inside_count": float(inside.float().sum().item()),
            "val/final_dist_mean": float(dist.mean().item()),
            "val/final_dist_min": float(dist.min().item()),
            "val/final_dist_max": float(dist.max().item()),
            "val/max_steps": float(max_steps),
        },
        step=step,
    )
    if cfg.val_plot_samples > 0:
        count = min(cfg.val_plot_samples, traj.shape[0])
        if count == 0:
            model.train()
            return
        indices = torch.linspace(0, traj.shape[0] - 1, steps=count).long()
        traj_plot = traj[indices]
        save_path = Path(cfg.output_dir) / f"val_traj_step_{step:06d}.png"
        plot_trajectories(
            traj_plot,
            sphere_a,
            sphere_b,
            save_path,
            title=f"Validation Trajectories (step {step})",
        )
        ratio = float(inside.float().mean().item())
        logger.log_image(
            "val/trajectory",
            save_path,
            caption=f"step {step} | inside_ratio={ratio:.3f}",
            step=step,
        )
    model.train()

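
# Sketch of the geometry behind the flow target: points travel on straight
# lines x(t) = x0 + t * (x1 - x0), so a step of size dt should displace a
# point by dt * (x1 - x0). The numbers here are illustrative only.
def _check_flow_target() -> None:
    x0 = torch.tensor([[0.0, 0.0, 0.0]])
    x1 = torch.tensor([[2.0, 0.0, 0.0]])
    v_gt = x1 - x0
    dt = torch.tensor([[0.25]])  # one step covering a quarter of the path
    target = v_gt[:, None, :] * dt.unsqueeze(-1)  # same broadcast as compute_losses
    assert torch.allclose(target[0, 0], torch.tensor([0.5, 0.0, 0.0]))
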

def train(cfg: TrainConfig) -> tuple[ASMamba, tuple[Tensor, Tensor], tuple[Tensor, Tensor]]:
    device = torch.device(cfg.device if torch.cuda.is_available() else "cpu")
    set_seed(cfg.seed)
    output_dir = Path(cfg.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    model = ASMamba(cfg).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
    logger = SwanLogger(cfg)

    sphere_a, sphere_b = sample_sphere_params(cfg, device)
    center_a, radius_a = sphere_a
    center_b, radius_b = sphere_b
    center_dist = torch.norm(center_a - center_b).item()
    logger.log(
        {
            "sphere_a/radius": float(radius_a.item()),
            "sphere_b/radius": float(radius_b.item()),
            "sphere_a/center_x": float(center_a[0].item()),
            "sphere_a/center_y": float(center_a[1].item()),
            "sphere_a/center_z": float(center_a[2].item()),
            "sphere_b/center_x": float(center_b[0].item()),
            "sphere_b/center_y": float(center_b[1].item()),
            "sphere_b/center_z": float(center_b[2].item()),
            "sphere/center_dist": float(center_dist),
        }
    )

    global_step = 0
    for epoch in range(cfg.epochs):
        model.train()
        for _ in range(cfg.steps_per_epoch):
            x0, x1, x_seq, t_seq, dt_seq = sample_batch(cfg, sphere_a, sphere_b, device)
            v_gt = x1 - x0
            delta, dt, _ = model(x_seq)
            losses = compute_losses(
                delta=delta,
                dt=dt,
                x_seq=x_seq,
                x0=x0,
                v_gt=v_gt,
                t_seq=t_seq,
                dt_seq=dt_seq,
                cfg=cfg,
            )
            # Fail fast if every loss switch is off, instead of checking the
            # scalar value (which could legitimately be 0.0).
            if not losses:
                raise RuntimeError("No loss enabled: enable at least one of flow/pos/dt.")
            loss = torch.tensor(0.0, device=device)
            if cfg.use_flow_loss and "flow" in losses:
                loss = loss + cfg.lambda_flow * losses["flow"]
            if cfg.use_pos_loss and "pos" in losses:
                loss = loss + cfg.lambda_pos * losses["pos"]
            if cfg.use_dt_loss and "dt" in losses:
                loss = loss + cfg.lambda_dt * losses["dt"]

            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            optimizer.step()

            if global_step % 10 == 0:
                dt_min = float(dt.min().item())
                dt_max = float(dt.max().item())
                dt_mean = float(dt.mean().item())
                dt_gt_min = float(dt_seq.min().item())
                dt_gt_max = float(dt_seq.max().item())
                dt_gt_mean = float(dt_seq.mean().item())
                eps = 1e-6
                # Fraction of predicted dt values stuck at either clamp boundary.
                clamp_min_ratio = float((dt <= cfg.dt_min + eps).float().mean().item())
                clamp_max_ratio = float((dt >= cfg.dt_max - eps).float().mean().item())
                clamp_any_ratio = float(
                    ((dt <= cfg.dt_min + eps) | (dt >= cfg.dt_max - eps)).float().mean().item()
                )
                logger.log(
                    {
                        "loss/total": float(loss.item()),
                        "loss/flow": float(losses.get("flow", torch.tensor(0.0)).item()),
                        "loss/pos": float(losses.get("pos", torch.tensor(0.0)).item()),
                        "loss/dt": float(losses.get("dt", torch.tensor(0.0)).item()),
                        "dt/pred_mean": dt_mean,
                        "dt/pred_min": dt_min,
                        "dt/pred_max": dt_max,
                        "dt/gt_mean": dt_gt_mean,
                        "dt/gt_min": dt_gt_min,
                        "dt/gt_max": dt_gt_max,
                        "dt/clamp_min_ratio": clamp_min_ratio,
                        "dt/clamp_max_ratio": clamp_max_ratio,
                        "dt/clamp_any_ratio": clamp_any_ratio,
                    },
                    step=global_step,
                )

            if cfg.val_every > 0 and global_step > 0 and global_step % cfg.val_every == 0:
                validate(model, cfg, sphere_a, sphere_b, device, logger, global_step)
                dt_hist_path = Path(cfg.output_dir) / f"dt_hist_step_{global_step:06d}.png"
                plot_dt_hist(
                    dt,
                    dt_seq,
                    dt_hist_path,
                    title=f"dt Distribution (step {global_step})",
                )
                logger.log_image(
                    "train/dt_hist",
                    dt_hist_path,
                    caption=f"step {global_step}",
                    step=global_step,
                )
            global_step += 1

    logger.finish()
    return model, sphere_a, sphere_b


def rollout_trajectory(
    model: ASMamba,
    x0: Tensor,
    max_steps: int,
) -> Tensor:
    """Autoregressively integrate the model from x0 until t reaches 1 or
    max_steps is hit. Returns the trajectory on CPU, shape (batch, steps + 1, 3)."""
    device = x0.device
    model.eval()
    h = model.init_cache(batch_size=x0.shape[0], device=device)
    x = x0
    total_time = torch.zeros(x0.shape[0], device=device)
    traj = [x0.detach().cpu()]
    with torch.no_grad():
        for _ in range(max_steps):
            delta, dt, h = model.step(x, h)
            dt = torch.clamp(dt, min=model.dt_min, max=model.dt_max)
            remaining = 1.0 - total_time
            overshoot = dt > remaining
            if overshoot.any():
                # Shrink the final step so the rollout lands exactly at t = 1.
                scale = (remaining / dt).unsqueeze(-1)
                delta = torch.where(overshoot.unsqueeze(-1), delta * scale, delta)
                dt = torch.where(overshoot, remaining, dt)
            x = x + delta
            total_time = total_time + dt
            traj.append(x.detach().cpu())
            if torch.all(total_time >= 1.0 - 1e-6):
                break
    return torch.stack(traj, dim=1)


def sphere_wireframe(
    center: Tensor, radius: float, u_steps: int = 24, v_steps: int = 12
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    center_np = center.detach().cpu().numpy()
    u = np.linspace(0, 2 * np.pi, u_steps)
    v = np.linspace(0, np.pi, v_steps)
    x = center_np[0] + radius * np.outer(np.cos(u), np.sin(v))
    y = center_np[1] + radius * np.outer(np.sin(u), np.sin(v))
    z = center_np[2] + radius * np.outer(np.ones_like(u), np.cos(v))
    return x, y, z

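
# Sketch of a typical post-training rollout. The zero center and unit radius
# are made-up values for illustration; train() derives the real spheres from
# sample_sphere_params. The returned trajectory lives on CPU.
def _example_rollout(model: ASMamba, cfg: TrainConfig) -> Tensor:
    device = next(model.parameters()).device
    center_a = torch.zeros(3, device=device)
    x0 = sample_points_in_sphere(center_a, radius=1.0, batch_size=8, device=device)
    return rollout_trajectory(model, x0, max_steps=cfg.seq_len)
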

def plot_trajectories(
    traj: Tensor,
    sphere_a: tuple[Tensor, Tensor],
    sphere_b: tuple[Tensor, Tensor],
    save_path: Path,
    title: str = "AS-Mamba Trajectories",
) -> None:
    traj = traj.detach().cpu()
    if traj.dim() == 2:
        traj = traj.unsqueeze(0)
    traj_np = traj.numpy()
    fig = plt.figure(figsize=(7, 6))
    ax = fig.add_subplot(111, projection="3d")
    for i in range(traj_np.shape[0]):
        ax.plot(
            traj_np[i, :, 0],
            traj_np[i, :, 1],
            traj_np[i, :, 2],
            color="green",
            alpha=0.6,
        )
    starts = traj_np[:, 0, :]
    ends = traj_np[:, -1, :]
    ax.scatter(starts[:, 0], starts[:, 1], starts[:, 2], color="blue", s=20, label="Start")
    ax.scatter(ends[:, 0], ends[:, 1], ends[:, 2], color="red", s=20, label="End")
    center_a, radius_a = sphere_a
    center_b, radius_b = sphere_b
    x_a, y_a, z_a = sphere_wireframe(center_a, float(radius_a.item()))
    x_b, y_b, z_b = sphere_wireframe(center_b, float(radius_b.item()))
    ax.plot_wireframe(x_a, y_a, z_a, color="blue", alpha=0.15, linewidth=0.5)
    ax.plot_wireframe(x_b, y_b, z_b, color="red", alpha=0.15, linewidth=0.5)
    ax.set_title(title)
    ax.set_xlabel("X")
    ax.set_ylabel("Y")
    ax.set_zlabel("Z")
    ax.legend(loc="best")
    fig.tight_layout()
    fig.savefig(save_path, dpi=160)
    plt.close(fig)


def plot_dt_hist(
    dt_pred: Tensor,
    dt_gt: Tensor,
    save_path: Path,
    title: str = "dt Distribution",
) -> None:
    dt_pred_np = dt_pred.detach().cpu().numpy().reshape(-1)
    dt_gt_np = dt_gt.detach().cpu().numpy().reshape(-1)
    fig, ax = plt.subplots(figsize=(6, 4))
    ax.hist(dt_gt_np, bins=30, alpha=0.6, label="dt_gt", color="steelblue")
    ax.hist(dt_pred_np, bins=30, alpha=0.6, label="dt_pred", color="orange")
    ax.set_title(title)
    ax.set_xlabel("dt")
    ax.set_ylabel("count")
    ax.legend(loc="best")
    fig.tight_layout()
    fig.savefig(save_path, dpi=160)
    plt.close(fig)


def run_training_and_plot(cfg: TrainConfig) -> Path:
    model, sphere_a, sphere_b = train(cfg)
    device = next(model.parameters()).device
    plot_samples = max(1, cfg.val_plot_samples)
    x0 = sample_points_in_sphere(
        sphere_a[0], float(sphere_a[1].item()), plot_samples, device
    )
    max_steps = cfg.seq_len if cfg.val_max_steps <= 0 else cfg.val_max_steps
    traj = rollout_trajectory(model, x0, max_steps=max_steps)
    output_dir = Path(cfg.output_dir)
    save_path = output_dir / "as_mamba_trajectory.png"
    plot_trajectories(traj, sphere_a, sphere_b, save_path)
    return save_path
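
# Minimal entry point, assuming the TrainConfig defaults are what you want;
# train() falls back to CPU automatically when CUDA is unavailable.
if __name__ == "__main__":
    final_plot = run_training_and_plot(TrainConfig())
    print(f"Saved final trajectory plot to {final_plot}")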