from __future__ import annotations

import warnings
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Optional

import matplotlib

# Select a non-interactive backend before pyplot is imported so plotting works headless.
matplotlib.use("Agg")

import numpy as np
import torch
import torch.nn.functional as F
from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401  (registers the "3d" projection)
from torch import Tensor, nn

from mamba2_minimal import InferenceCache, Mamba2, Mamba2Config, RMSNorm


@dataclass
class TrainConfig:
    """Hyper-parameters for the sphere-to-sphere AS-Mamba experiment."""

    # --- reproducibility / hardware ---
    seed: int = 42
    device: str = "cuda"  # train() falls back to "cpu" when CUDA is unavailable
    # --- optimisation schedule ---
    batch_size: int = 128
    steps_per_epoch: int = 50
    epochs: int = 60
    warmup_epochs: int = 15  # epochs with dt forced to 1/seq_len and dt_head frozen
    seq_len: int = 20
    lr: float = 1e-3
    weight_decay: float = 1e-2
    # --- bounds applied to the predicted step size dt ---
    dt_min: float = 1e-3
    dt_max: float = 0.06
    # --- loss weights ---
    lambda_flow: float = 1.0
    lambda_pos: float = 1.0
    lambda_nfe: float = 0.05  # only added to the loss after warmup
    # --- source / target sphere sampling ---
    radius_min: float = 0.6
    radius_max: float = 1.4
    center_min: float = -6.0
    center_max: float = 6.0
    center_distance_min: float = 6.0
    # --- Mamba2 backbone (forwarded to Mamba2Config) ---
    d_model: int = 128
    n_layer: int = 4
    d_state: int = 64
    d_conv: int = 4
    expand: int = 2
    headdim: int = 32
    chunk_size: int = 1
    use_residual: bool = False
    # --- output / logging ---
    output_dir: str = "outputs"
    project: str = "as-mamba"
    run_name: str = "sphere-to-sphere"
    # --- validation ---
    val_every: int = 200  # in optimizer steps; <= 0 disables validation
    val_samples: int = 256
    val_plot_samples: int = 16
    val_max_steps: int = 100


class Mamba2Backbone(nn.Module):
    """Stack of pre-norm Mamba2 mixer layers followed by a final RMSNorm.

    With ``use_residual=False`` each layer's output replaces its input rather
    than being added to it.
    """

    def __init__(self, args: Mamba2Config, use_residual: bool = True) -> None:
        super().__init__()
        self.args = args
        self.use_residual = use_residual
        self.layers = nn.ModuleList(
            [
                nn.ModuleDict(
                    dict(
                        mixer=Mamba2(args),
                        norm=RMSNorm(args.d_model),
                    )
                )
                for _ in range(args.n_layer)
            ]
        )
        self.norm_f = RMSNorm(args.d_model)

    def forward(
        self, x: Tensor, h: Optional[list[Optional[InferenceCache]]] = None
    ) -> tuple[Tensor, list[InferenceCache]]:
        """Run every layer over ``x``.

        Args:
            x: features of width ``args.d_model``.
            h: one inference cache per layer, or None. When None, a list of
               Nones is passed to the mixers. Each mixer's returned cache is
               written back into this list (the caller's list is updated in place).

        Returns:
            ``(features, caches)`` — the final RMS-normalised features and the
            per-layer caches.
        """
        if h is None:
            h = [None for _ in range(self.args.n_layer)]
        for i, layer in enumerate(self.layers):
            y, h[i] = layer["mixer"](layer["norm"](x), h[i])
            x = x + y if self.use_residual else y
        x = self.norm_f(x)
        return x, h


class ASMamba(nn.Module):
    """Adaptive-step model: maps 3-D points to a displacement and a step size dt."""

    def __init__(self, cfg: TrainConfig) -> None:
        super().__init__()
        self.cfg = cfg
        self.dt_min = float(cfg.dt_min)
        self.dt_max = float(cfg.dt_max)
        args = Mamba2Config(
            d_model=cfg.d_model,
            n_layer=cfg.n_layer,
            d_state=cfg.d_state,
            d_conv=cfg.d_conv,
            expand=cfg.expand,
            headdim=cfg.headdim,
            chunk_size=cfg.chunk_size,
        )
        self.backbone = Mamba2Backbone(args, use_residual=cfg.use_residual)
        self.input_proj = nn.Linear(3, cfg.d_model)
        self.delta_head = nn.Linear(cfg.d_model, 3)
        self.dt_head = nn.Sequential(
            nn.Linear(cfg.d_model, cfg.d_model),
            nn.SiLU(),
            nn.Linear(cfg.d_model, 1),
        )

    def forward(
        self, x: Tensor, h: Optional[list[InferenceCache]] = None
    ) -> tuple[Tensor, Tensor, list[InferenceCache]]:
        """Predict displacements and step sizes for a sequence of points.

        Args:
            x: points of shape ``(batch, seq, 3)``.
            h: optional per-layer caches forwarded to the backbone.

        Returns:
            ``(delta, dt, h)``. ``delta`` has shape ``(batch, seq, 3)``.
            ``dt`` has shape ``(batch, seq)``: softplus of the dt head,
            clamped into ``[dt_min, dt_max]``. ``h`` is the updated caches.
        """
        x_proj = self.input_proj(x)
        feats, h = self.backbone(x_proj, h)
        delta = self.delta_head(feats)
        dt_raw = self.dt_head(feats).squeeze(-1)
        dt = torch.clamp(F.softplus(dt_raw), min=self.dt_min, max=self.dt_max)
        return delta, dt, h

    def step(
        self, x: Tensor, h: list[InferenceCache]
    ) -> tuple[Tensor, Tensor, list[InferenceCache]]:
        """Single recurrent step: ``x`` is ``(batch, 3)``.

        Returns ``(delta (batch, 3), dt (batch,), h)``.
        """
        delta, dt, h = self.forward(x.unsqueeze(1), h)
        return delta[:, 0, :], dt[:, 0], h

    def init_cache(self, batch_size: int, device: torch.device) -> list[InferenceCache]:
        """Allocate one fresh inference cache per backbone layer."""
        return [
            InferenceCache.alloc(batch_size, self.backbone.args, device=device)
            for _ in range(self.backbone.args.n_layer)
        ]


class SwanLogger:
    """Best-effort wrapper around SwanLab.

    If ``swanlab`` cannot be imported or initialised, every method becomes a
    no-op. A warning is emitted once, at construction, explaining why.
    """

    def __init__(self, cfg: TrainConfig) -> None:
        self.enabled = False
        self._swan = None
        self._run = None
        try:
            import swanlab  # type: ignore

            self._swan = swanlab
            self._run = self._swan.init(
                project=cfg.project,
                experiment_name=cfg.run_name,
                config=asdict(cfg),
            )
            self.enabled = True
        except Exception as exc:  # logging is optional; training must not depend on it
            self.enabled = False
            warnings.warn(
                f"SwanLab logging disabled ({type(exc).__name__}: {exc})",
                RuntimeWarning,
                stacklevel=2,
            )

    def log(self, metrics: dict, step: int | None = None) -> None:
        """Log a metrics dict, optionally at an explicit step."""
        if not self.enabled:
            return
        target = self._run if self._run is not None else self._swan
        if step is None:
            target.log(metrics)
            return
        try:
            target.log(metrics, step=step)
        except TypeError:
            # Fallback for a log() that does not accept a ``step`` keyword:
            # record the step as an ordinary metric instead.
            payload = dict(metrics)
            payload["step"] = step
            target.log(payload)

    def log_image(
        self, key: str, image_path: Path, caption: str | None = None, step: int | None = None
    ) -> None:
        """Log an image file under ``key``."""
        if not self.enabled:
            return
        image = self._swan.Image(str(image_path), caption=caption)
        self.log({key: image}, step=step)

    def finish(self) -> None:
        """Close the run, preferring the run object's ``finish`` over the module's."""
        if not self.enabled:
            return
        finish = None
        if self._run is not None:
            finish = getattr(self._run, "finish", None)
        if finish is None:
            finish = getattr(self._swan, "finish", None)
        if callable(finish):
            finish()


def set_seed(seed: int) -> None:
    """Seed torch's CPU and (if present) CUDA generators."""
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def sample_points_in_sphere(
    center: Tensor, radius: float, batch_size: int, device: torch.device
) -> Tensor:
    """Sample ``batch_size`` points uniformly from the solid ball around ``center``.

    Direction: a normalised Gaussian vector. Radius: ``radius * u**(1/3)``,
    which makes the density uniform in volume rather than concentrated at
    the centre.
    """
    direction = torch.randn(batch_size, 3, device=device)
    direction = direction / (direction.norm(dim=-1, keepdim=True) + 1e-8)
    u = torch.rand(batch_size, 1, device=device)
    r = radius * torch.pow(u, 1.0 / 3.0)
    return center + direction * r


def sample_sphere_params(cfg: TrainConfig, device: torch.device) -> tuple[Tensor, Tensor]:
    """Sample a source sphere A and target sphere B.

    Centres are drawn uniformly in ``[center_min, center_max]^3``. B is
    resampled until it is at least ``center_distance_min`` from A. If
    rejection sampling gives up, B is placed at exactly that distance.

    Returns:
        ``((center_a, radius_a), (center_b, radius_b))``. Radii are 0-d tensors
        drawn from ``[radius_min, radius_max]``.
    """

    def _uniform_center() -> Tensor:
        return torch.empty(3, device=device).uniform_(cfg.center_min, cfg.center_max)

    center_a = _uniform_center()
    center_b = _uniform_center()
    for _ in range(128):
        if torch.norm(center_a - center_b) >= cfg.center_distance_min:
            break
        center_b = _uniform_center()

    offset = center_b - center_a
    dist = torch.norm(offset)
    if dist < cfg.center_distance_min:
        # Rejection sampling did not succeed (the last resample is unchecked by
        # the loop): push B out along the A->B direction to the minimum distance.
        if dist < 1e-3:
            direction = torch.tensor([1.0, 0.0, 0.0], device=device)
        else:
            direction = offset / dist
        center_b = center_a + direction * cfg.center_distance_min

    radius_a = float(torch.empty(1).uniform_(cfg.radius_min, cfg.radius_max).item())
    radius_b = float(torch.empty(1).uniform_(cfg.radius_min, cfg.radius_max).item())
    return (center_a, torch.tensor(radius_a, device=device)), (
        center_b,
        torch.tensor(radius_b, device=device),
    )


def sample_batch(
    cfg: TrainConfig,
    sphere_a: tuple[Tensor, Tensor],
    sphere_b: tuple[Tensor, Tensor],
    device: torch.device,
) -> tuple[Tensor, Tensor, Tensor, Tensor]:
    """Build straight-line training sequences from A to B.

    Returns:
        ``(x0, x1, x_seq, t_seq)``. ``x0``/``x1`` are ``(batch, 3)`` endpoints.
        ``t_seq`` is ``arange(seq_len) / seq_len``; it excludes t=1.
        ``x_seq[:, k] = x0 + t_seq[k] * (x1 - x0)``.
    """
    center_a, radius_a = sphere_a
    center_b, radius_b = sphere_b
    x0 = sample_points_in_sphere(center_a, float(radius_a.item()), cfg.batch_size, device)
    x1 = sample_points_in_sphere(center_b, float(radius_b.item()), cfg.batch_size, device)
    v_gt = x1 - x0
    dt_fixed = 1.0 / cfg.seq_len
    t_seq = torch.arange(cfg.seq_len, device=device) * dt_fixed
    x_seq = x0[:, None, :] + t_seq[None, :, None] * v_gt[:, None, :]
    return x0, x1, x_seq, t_seq


def compute_losses(
    delta: Tensor,
    dt: Tensor,
    x_seq: Tensor,
    x0: Tensor,
    v_gt: Tensor,
    t_seq: Tensor,
    cfg: TrainConfig,
) -> tuple[Tensor, Tensor, Tensor]:
    """Return ``(flow_loss, pos_loss, nfe_loss)``.

    - ``flow_loss``: MSE between ``delta`` and ``v_gt * dt``.
    - ``pos_loss``: MSE between ``x_seq + delta`` and the point at time
      ``clamp(t + dt, 0, 1)`` on the straight line.
    - ``nfe_loss``: mean of ``-log(dt)``; it decreases as dt grows.

    NOTE(review): because ``x_seq`` lies exactly on the line, ``pos_loss``
    equals ``flow_loss`` except where ``t + dt > 1`` is clamped. At those
    positions the two targets disagree (``v*dt`` vs ``v*(1-t)``). Confirm
    which target is intended. ``cfg`` is currently unused.
    """
    target_disp = v_gt[:, None, :] * dt.unsqueeze(-1)
    flow_loss = F.mse_loss(delta, target_disp)
    t_next = t_seq[None, :, None] + dt.unsqueeze(-1)
    t_next = torch.clamp(t_next, 0.0, 1.0)
    x_target = x0[:, None, :] + t_next * v_gt[:, None, :]
    x_next_pred = x_seq + delta
    pos_loss = F.mse_loss(x_next_pred, x_target)
    nfe_loss = (-torch.log(dt)).mean()
    return flow_loss, pos_loss, nfe_loss


def validate(
    model: ASMamba,
    cfg: TrainConfig,
    sphere_a: tuple[Tensor, Tensor],
    sphere_b: tuple[Tensor, Tensor],
    device: torch.device,
    logger: SwanLogger,
    step: int,
) -> None:
    """Roll out fresh samples from A and log how many end inside B.

    Logs distance statistics. Optionally saves and logs a trajectory plot.
    The model is put back in train mode on exit, even if logging or
    plotting raises.
    """
    if cfg.val_samples <= 0:
        # Nothing to evaluate; reductions like dist.min() would fail on an empty batch.
        return
    model.eval()
    try:
        center_b, radius_b = sphere_b
        with torch.no_grad():
            x0 = sample_points_in_sphere(
                sphere_a[0], float(sphere_a[1].item()), cfg.val_samples, device
            )
            traj = rollout_trajectory(model, x0, max_steps=cfg.val_max_steps)
        x_final = traj[:, -1, :]
        # rollout_trajectory returns CPU tensors, so compare against CPU copies.
        center_b_cpu = center_b.detach().cpu()
        radius_b_cpu = radius_b.detach().cpu()
        dist = torch.linalg.norm(x_final - center_b_cpu, dim=-1)
        inside = dist <= radius_b_cpu
        ratio = float(inside.float().mean().item())
        logger.log(
            {
                "val/inside_ratio": ratio,
                "val/inside_count": float(inside.float().sum().item()),
                "val/final_dist_mean": float(dist.mean().item()),
                "val/final_dist_min": float(dist.min().item()),
                "val/final_dist_max": float(dist.max().item()),
            },
            step=step,
        )
        if cfg.val_plot_samples > 0:
            # traj.shape[0] == val_samples > 0 here, so count >= 1.
            count = min(cfg.val_plot_samples, traj.shape[0])
            # Evenly spaced subset of trajectories for a readable plot.
            indices = torch.linspace(0, traj.shape[0] - 1, steps=count).long()
            traj_plot = traj[indices]
            save_path = Path(cfg.output_dir) / f"val_traj_step_{step:06d}.png"
            plot_trajectories(
                traj_plot,
                sphere_a,
                sphere_b,
                save_path,
                title=f"Validation Trajectories (step {step})",
            )
            logger.log_image(
                "val/trajectory",
                save_path,
                caption=f"step {step} | inside_ratio={ratio:.3f}",
                step=step,
            )
    finally:
        model.train()


def train(cfg: TrainConfig) -> tuple[ASMamba, tuple[Tensor, Tensor], tuple[Tensor, Tensor]]:
    """Train ASMamba to transport points from sphere A to sphere B.

    Two stages:
    - Warmup (first ``warmup_epochs``): dt is replaced by the constant
      ``1/seq_len`` and ``dt_head`` is frozen.
    - Afterwards: the predicted dt is used and ``lambda_nfe * nfe_loss`` is
      added to the loss.

    Returns:
        ``(model, sphere_a, sphere_b)``.
    """
    device = torch.device(cfg.device if torch.cuda.is_available() else "cpu")
    set_seed(cfg.seed)
    output_dir = Path(cfg.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    model = ASMamba(cfg).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
    logger = SwanLogger(cfg)

    sphere_a, sphere_b = sample_sphere_params(cfg, device)
    center_a, radius_a = sphere_a
    center_b, radius_b = sphere_b
    center_dist = torch.norm(center_a - center_b).item()
    logger.log(
        {
            "sphere_a/radius": float(radius_a.item()),
            "sphere_b/radius": float(radius_b.item()),
            "sphere_a/center_x": float(center_a[0].item()),
            "sphere_a/center_y": float(center_a[1].item()),
            "sphere_a/center_z": float(center_a[2].item()),
            "sphere_b/center_x": float(center_b[0].item()),
            "sphere_b/center_y": float(center_b[1].item()),
            "sphere_b/center_z": float(center_b[2].item()),
            "sphere/center_dist": float(center_dist),
        }
    )

    global_step = 0
    for epoch in range(cfg.epochs):
        warmup = epoch < cfg.warmup_epochs
        model.train()
        # Freeze the dt head during warmup. With zero_grad(set_to_none=True)
        # its grads stay None, so AdamW skips those parameters.
        for p in model.dt_head.parameters():
            p.requires_grad = not warmup

        for _ in range(cfg.steps_per_epoch):
            x0, x1, x_seq, t_seq = sample_batch(cfg, sphere_a, sphere_b, device)
            v_gt = x1 - x0
            delta, dt, _ = model(x_seq)
            if warmup:
                # Match the fixed spacing used to build x_seq in sample_batch.
                dt = torch.full_like(dt, 1.0 / cfg.seq_len)

            flow_loss, pos_loss, nfe_loss = compute_losses(
                delta=delta,
                dt=dt,
                x_seq=x_seq,
                x0=x0,
                v_gt=v_gt,
                t_seq=t_seq,
                cfg=cfg,
            )
            loss = cfg.lambda_flow * flow_loss + cfg.lambda_pos * pos_loss
            if not warmup:
                loss = loss + cfg.lambda_nfe * nfe_loss

            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            optimizer.step()

            if global_step % 10 == 0:
                logger.log(
                    {
                        "loss/total": float(loss.item()),
                        "loss/flow": float(flow_loss.item()),
                        "loss/pos": float(pos_loss.item()),
                        "loss/nfe": float(nfe_loss.item()),
                        "dt/mean": float(dt.mean().item()),
                        "dt/min": float(dt.min().item()),
                        "dt/max": float(dt.max().item()),
                        "stage": 0 if warmup else 1,
                    },
                    step=global_step,
                )

            if cfg.val_every > 0 and global_step > 0 and global_step % cfg.val_every == 0:
                validate(model, cfg, sphere_a, sphere_b, device, logger, global_step)

            global_step += 1

    logger.finish()
    return model, sphere_a, sphere_b


def rollout_trajectory(
    model: ASMamba,
    x0: Tensor,
    max_steps: int = 100,
) -> Tensor:
    """Integrate from ``x0`` using the model's own step sizes, up to t = 1.

    Behaviour:
    - Puts the model in eval mode and leaves it there.
    - A step that would overshoot t = 1 has its dt cut to the remaining time
      and its delta scaled by the same factor.
    - Samples that have already reached t = 1 stop moving.

    Returns:
        CPU tensor of shape ``(batch, n_steps + 1, 3)``, including ``x0``.
    """
    device = x0.device
    model.eval()
    h = model.init_cache(batch_size=x0.shape[0], device=device)
    x = x0
    total_time = torch.zeros(x0.shape[0], device=device)
    traj = [x0.detach().cpu()]
    with torch.no_grad():
        for _ in range(max_steps):
            delta, dt, h = model.step(x, h)
            dt = torch.clamp(dt, min=model.dt_min, max=model.dt_max)
            remaining = 1.0 - total_time
            overshoot = dt > remaining
            if overshoot.any():
                scale = (remaining / dt).unsqueeze(-1)
                delta = torch.where(overshoot.unsqueeze(-1), delta * scale, delta)
                dt = torch.where(overshoot, remaining, dt)
            x = x + delta
            total_time = total_time + dt
            traj.append(x.detach().cpu())
            if torch.all(total_time >= 1.0 - 1e-6):
                break
    return torch.stack(traj, dim=1)


def sphere_wireframe(
    center: Tensor, radius: float, u_steps: int = 24, v_steps: int = 12
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Return ``(x, y, z)`` grids of shape ``(u_steps, v_steps)`` on a sphere surface."""
    center_np = center.detach().cpu().numpy()
    u = np.linspace(0, 2 * np.pi, u_steps)
    v = np.linspace(0, np.pi, v_steps)
    x = center_np[0] + radius * np.outer(np.cos(u), np.sin(v))
    y = center_np[1] + radius * np.outer(np.sin(u), np.sin(v))
    z = center_np[2] + radius * np.outer(np.ones_like(u), np.cos(v))
    return x, y, z


def plot_trajectories(
    traj: Tensor,
    sphere_a: tuple[Tensor, Tensor],
    sphere_b: tuple[Tensor, Tensor],
    save_path: Path,
    title: str = "AS-Mamba Trajectories",
) -> None:
    """Save a 3-D plot of trajectories with both spheres drawn as wireframes.

    Args:
        traj: ``(batch, steps, 3)``, or ``(steps, 3)`` for a single trajectory.
        save_path: output image path. Missing parent directories are created.
    """
    traj = traj.detach().cpu()
    if traj.dim() == 2:
        traj = traj.unsqueeze(0)
    traj_np = traj.numpy()

    fig = plt.figure(figsize=(7, 6))
    ax = fig.add_subplot(111, projection="3d")
    for i in range(traj_np.shape[0]):
        ax.plot(
            traj_np[i, :, 0],
            traj_np[i, :, 1],
            traj_np[i, :, 2],
            color="green",
            alpha=0.6,
        )
    starts = traj_np[:, 0, :]
    ends = traj_np[:, -1, :]
    ax.scatter(starts[:, 0], starts[:, 1], starts[:, 2], color="blue", s=20, label="Start")
    ax.scatter(ends[:, 0], ends[:, 1], ends[:, 2], color="red", s=20, label="End")

    center_a, radius_a = sphere_a
    center_b, radius_b = sphere_b
    x_a, y_a, z_a = sphere_wireframe(center_a, float(radius_a.item()))
    x_b, y_b, z_b = sphere_wireframe(center_b, float(radius_b.item()))
    ax.plot_wireframe(x_a, y_a, z_a, color="blue", alpha=0.15, linewidth=0.5)
    ax.plot_wireframe(x_b, y_b, z_b, color="red", alpha=0.15, linewidth=0.5)

    ax.set_title(title)
    ax.set_xlabel("X")
    ax.set_ylabel("Y")
    ax.set_zlabel("Z")
    ax.legend(loc="best")
    fig.tight_layout()

    save_path = Path(save_path)
    save_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(save_path, dpi=160)
    plt.close(fig)


def run_training_and_plot(cfg: TrainConfig) -> Path:
    """Train a model, roll out samples from sphere A, and save the final plot.

    Returns:
        Path of the saved trajectory image.
    """
    model, sphere_a, sphere_b = train(cfg)
    device = next(model.parameters()).device
    plot_samples = max(1, cfg.val_plot_samples)
    x0 = sample_points_in_sphere(
        sphere_a[0], float(sphere_a[1].item()), plot_samples, device
    )
    traj = rollout_trajectory(model, x0, max_steps=cfg.val_max_steps)
    output_dir = Path(cfg.output_dir)
    save_path = output_dir / "as_mamba_trajectory.png"
    plot_trajectories(traj, sphere_a, sphere_b, save_path)
    return save_path