feat(mamba): add Mamba2 implementation

Add initial project structure including core Mamba2 logic,
entry point, and uv-based dependency management.
This commit is contained in:
gameloader
2026-01-21 12:54:49 +08:00
commit c58a73ae26
8 changed files with 2151 additions and 0 deletions

511
as_mamba.py Normal file
View File

@@ -0,0 +1,511 @@
from __future__ import annotations
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Optional
import matplotlib
matplotlib.use("Agg")
import numpy as np
import torch
import torch.nn.functional as F
from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D # noqa: F401
from torch import Tensor, nn
from mamba2_minimal import InferenceCache, Mamba2, Mamba2Config, RMSNorm
@dataclass
class TrainConfig:
seed: int = 42
device: str = "cuda"
batch_size: int = 128
steps_per_epoch: int = 50
epochs: int = 60
warmup_epochs: int = 15
seq_len: int = 20
lr: float = 1e-3
weight_decay: float = 1e-2
dt_min: float = 1e-3
dt_max: float = 0.06
lambda_flow: float = 1.0
lambda_pos: float = 1.0
lambda_nfe: float = 0.05
radius_min: float = 0.6
radius_max: float = 1.4
center_min: float = -6.0
center_max: float = 6.0
center_distance_min: float = 6.0
d_model: int = 128
n_layer: int = 4
d_state: int = 64
d_conv: int = 4
expand: int = 2
headdim: int = 32
chunk_size: int = 1
use_residual: bool = False
output_dir: str = "outputs"
project: str = "as-mamba"
run_name: str = "sphere-to-sphere"
val_every: int = 200
val_samples: int = 256
val_plot_samples: int = 16
val_max_steps: int = 100
class Mamba2Backbone(nn.Module):
def __init__(self, args: Mamba2Config, use_residual: bool = True) -> None:
super().__init__()
self.args = args
self.use_residual = use_residual
self.layers = nn.ModuleList(
[
nn.ModuleDict(
dict(
mixer=Mamba2(args),
norm=RMSNorm(args.d_model),
)
)
for _ in range(args.n_layer)
]
)
self.norm_f = RMSNorm(args.d_model)
def forward(
self, x: Tensor, h: Optional[list[InferenceCache]] = None
) -> tuple[Tensor, list[InferenceCache]]:
if h is None:
h = [None for _ in range(self.args.n_layer)]
for i, layer in enumerate(self.layers):
y, h[i] = layer["mixer"](layer["norm"](x), h[i])
x = x + y if self.use_residual else y
x = self.norm_f(x)
return x, h
class ASMamba(nn.Module):
def __init__(self, cfg: TrainConfig) -> None:
super().__init__()
self.cfg = cfg
self.dt_min = float(cfg.dt_min)
self.dt_max = float(cfg.dt_max)
args = Mamba2Config(
d_model=cfg.d_model,
n_layer=cfg.n_layer,
d_state=cfg.d_state,
d_conv=cfg.d_conv,
expand=cfg.expand,
headdim=cfg.headdim,
chunk_size=cfg.chunk_size,
)
self.backbone = Mamba2Backbone(args, use_residual=cfg.use_residual)
self.input_proj = nn.Linear(3, cfg.d_model)
self.delta_head = nn.Linear(cfg.d_model, 3)
self.dt_head = nn.Sequential(
nn.Linear(cfg.d_model, cfg.d_model),
nn.SiLU(),
nn.Linear(cfg.d_model, 1),
)
def forward(
self, x: Tensor, h: Optional[list[InferenceCache]] = None
) -> tuple[Tensor, Tensor, list[InferenceCache]]:
x_proj = self.input_proj(x)
feats, h = self.backbone(x_proj, h)
delta = self.delta_head(feats)
dt_raw = self.dt_head(feats).squeeze(-1)
dt = torch.clamp(F.softplus(dt_raw), min=self.dt_min, max=self.dt_max)
return delta, dt, h
def step(
self, x: Tensor, h: list[InferenceCache]
) -> tuple[Tensor, Tensor, list[InferenceCache]]:
delta, dt, h = self.forward(x.unsqueeze(1), h)
return delta[:, 0, :], dt[:, 0], h
def init_cache(self, batch_size: int, device: torch.device) -> list[InferenceCache]:
return [
InferenceCache.alloc(batch_size, self.backbone.args, device=device)
for _ in range(self.backbone.args.n_layer)
]
class SwanLogger:
def __init__(self, cfg: TrainConfig) -> None:
self.enabled = False
self._swan = None
self._run = None
try:
import swanlab # type: ignore
self._swan = swanlab
self._run = self._swan.init(
project=cfg.project,
experiment_name=cfg.run_name,
config=asdict(cfg),
)
self.enabled = True
except Exception:
self.enabled = False
def log(self, metrics: dict, step: int | None = None) -> None:
if not self.enabled:
return
target = self._run if self._run is not None else self._swan
if step is None:
target.log(metrics)
return
try:
target.log(metrics, step=step)
except TypeError:
payload = dict(metrics)
payload["step"] = step
target.log(payload)
def log_image(
self, key: str, image_path: Path, caption: str | None = None, step: int | None = None
) -> None:
if not self.enabled:
return
image = self._swan.Image(str(image_path), caption=caption)
self.log({key: image}, step=step)
def finish(self) -> None:
if not self.enabled:
return
finish = None
if self._run is not None:
finish = getattr(self._run, "finish", None)
if finish is None:
finish = getattr(self._swan, "finish", None)
if callable(finish):
finish()
def set_seed(seed: int) -> None:
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
def sample_points_in_sphere(
center: Tensor, radius: float, batch_size: int, device: torch.device
) -> Tensor:
direction = torch.randn(batch_size, 3, device=device)
direction = direction / (direction.norm(dim=-1, keepdim=True) + 1e-8)
u = torch.rand(batch_size, 1, device=device)
r = radius * torch.pow(u, 1.0 / 3.0)
return center + direction * r
def sample_sphere_params(cfg: TrainConfig, device: torch.device) -> tuple[Tensor, Tensor]:
center_a = torch.empty(3, device=device).uniform_(cfg.center_min, cfg.center_max)
center_b = torch.empty(3, device=device).uniform_(cfg.center_min, cfg.center_max)
for _ in range(128):
if torch.norm(center_a - center_b) >= cfg.center_distance_min:
break
center_b = torch.empty(3, device=device).uniform_(cfg.center_min, cfg.center_max)
if torch.norm(center_a - center_b) < 1e-3:
center_b = center_b + torch.tensor([cfg.center_distance_min, 0.0, 0.0], device=device)
radius_a = float(torch.empty(1).uniform_(cfg.radius_min, cfg.radius_max).item())
radius_b = float(torch.empty(1).uniform_(cfg.radius_min, cfg.radius_max).item())
return (center_a, torch.tensor(radius_a, device=device)), (
center_b,
torch.tensor(radius_b, device=device),
)
def sample_batch(
cfg: TrainConfig,
sphere_a: tuple[Tensor, Tensor],
sphere_b: tuple[Tensor, Tensor],
device: torch.device,
) -> tuple[Tensor, Tensor, Tensor, Tensor]:
center_a, radius_a = sphere_a
center_b, radius_b = sphere_b
x0 = sample_points_in_sphere(center_a, float(radius_a.item()), cfg.batch_size, device)
x1 = sample_points_in_sphere(center_b, float(radius_b.item()), cfg.batch_size, device)
v_gt = x1 - x0
dt_fixed = 1.0 / cfg.seq_len
t_seq = torch.arange(cfg.seq_len, device=device) * dt_fixed
x_seq = x0[:, None, :] + t_seq[None, :, None] * v_gt[:, None, :]
return x0, x1, x_seq, t_seq
def compute_losses(
delta: Tensor,
dt: Tensor,
x_seq: Tensor,
x0: Tensor,
v_gt: Tensor,
t_seq: Tensor,
cfg: TrainConfig,
) -> tuple[Tensor, Tensor, Tensor]:
target_disp = v_gt[:, None, :] * dt.unsqueeze(-1)
flow_loss = F.mse_loss(delta, target_disp)
t_next = t_seq[None, :, None] + dt.unsqueeze(-1)
t_next = torch.clamp(t_next, 0.0, 1.0)
x_target = x0[:, None, :] + t_next * v_gt[:, None, :]
x_next_pred = x_seq + delta
pos_loss = F.mse_loss(x_next_pred, x_target)
nfe_loss = (-torch.log(dt)).mean()
return flow_loss, pos_loss, nfe_loss
def validate(
model: ASMamba,
cfg: TrainConfig,
sphere_a: tuple[Tensor, Tensor],
sphere_b: tuple[Tensor, Tensor],
device: torch.device,
logger: SwanLogger,
step: int,
) -> None:
model.eval()
center_b, radius_b = sphere_b
with torch.no_grad():
x0 = sample_points_in_sphere(
sphere_a[0], float(sphere_a[1].item()), cfg.val_samples, device
)
traj = rollout_trajectory(model, x0, max_steps=cfg.val_max_steps)
x_final = traj[:, -1, :]
center_b_cpu = center_b.detach().cpu()
radius_b_cpu = radius_b.detach().cpu()
dist = torch.linalg.norm(x_final - center_b_cpu, dim=-1)
inside = dist <= radius_b_cpu
logger.log(
{
"val/inside_ratio": float(inside.float().mean().item()),
"val/inside_count": float(inside.float().sum().item()),
"val/final_dist_mean": float(dist.mean().item()),
"val/final_dist_min": float(dist.min().item()),
"val/final_dist_max": float(dist.max().item()),
},
step=step,
)
if cfg.val_plot_samples > 0:
count = min(cfg.val_plot_samples, traj.shape[0])
if count == 0:
model.train()
return
indices = torch.linspace(0, traj.shape[0] - 1, steps=count).long()
traj_plot = traj[indices]
save_path = Path(cfg.output_dir) / f"val_traj_step_{step:06d}.png"
plot_trajectories(
traj_plot,
sphere_a,
sphere_b,
save_path,
title=f"Validation Trajectories (step {step})",
)
ratio = float(inside.float().mean().item())
logger.log_image(
"val/trajectory",
save_path,
caption=f"step {step} | inside_ratio={ratio:.3f}",
step=step,
)
model.train()
def train(cfg: TrainConfig) -> tuple[ASMamba, tuple[Tensor, Tensor], tuple[Tensor, Tensor]]:
device = torch.device(cfg.device if torch.cuda.is_available() else "cpu")
set_seed(cfg.seed)
output_dir = Path(cfg.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
model = ASMamba(cfg).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
logger = SwanLogger(cfg)
sphere_a, sphere_b = sample_sphere_params(cfg, device)
center_a, radius_a = sphere_a
center_b, radius_b = sphere_b
center_dist = torch.norm(center_a - center_b).item()
logger.log(
{
"sphere_a/radius": float(radius_a.item()),
"sphere_b/radius": float(radius_b.item()),
"sphere_a/center_x": float(center_a[0].item()),
"sphere_a/center_y": float(center_a[1].item()),
"sphere_a/center_z": float(center_a[2].item()),
"sphere_b/center_x": float(center_b[0].item()),
"sphere_b/center_y": float(center_b[1].item()),
"sphere_b/center_z": float(center_b[2].item()),
"sphere/center_dist": float(center_dist),
}
)
global_step = 0
for epoch in range(cfg.epochs):
warmup = epoch < cfg.warmup_epochs
model.train()
for p in model.dt_head.parameters():
p.requires_grad = not warmup
for _ in range(cfg.steps_per_epoch):
x0, x1, x_seq, t_seq = sample_batch(cfg, sphere_a, sphere_b, device)
v_gt = x1 - x0
delta, dt, _ = model(x_seq)
if warmup:
dt = torch.full_like(dt, 1.0 / cfg.seq_len)
flow_loss, pos_loss, nfe_loss = compute_losses(
delta=delta,
dt=dt,
x_seq=x_seq,
x0=x0,
v_gt=v_gt,
t_seq=t_seq,
cfg=cfg,
)
loss = cfg.lambda_flow * flow_loss + cfg.lambda_pos * pos_loss
if not warmup:
loss = loss + cfg.lambda_nfe * nfe_loss
optimizer.zero_grad(set_to_none=True)
loss.backward()
optimizer.step()
if global_step % 10 == 0:
logger.log(
{
"loss/total": float(loss.item()),
"loss/flow": float(flow_loss.item()),
"loss/pos": float(pos_loss.item()),
"loss/nfe": float(nfe_loss.item()),
"dt/mean": float(dt.mean().item()),
"dt/min": float(dt.min().item()),
"dt/max": float(dt.max().item()),
"stage": 0 if warmup else 1,
},
step=global_step,
)
if cfg.val_every > 0 and global_step > 0 and global_step % cfg.val_every == 0:
validate(model, cfg, sphere_a, sphere_b, device, logger, global_step)
global_step += 1
logger.finish()
return model, sphere_a, sphere_b
def rollout_trajectory(
model: ASMamba,
x0: Tensor,
max_steps: int = 100,
) -> Tensor:
device = x0.device
model.eval()
h = model.init_cache(batch_size=x0.shape[0], device=device)
x = x0
total_time = torch.zeros(x0.shape[0], device=device)
traj = [x0.detach().cpu()]
with torch.no_grad():
for _ in range(max_steps):
delta, dt, h = model.step(x, h)
dt = torch.clamp(dt, min=model.dt_min, max=model.dt_max)
remaining = 1.0 - total_time
overshoot = dt > remaining
if overshoot.any():
scale = (remaining / dt).unsqueeze(-1)
delta = torch.where(overshoot.unsqueeze(-1), delta * scale, delta)
dt = torch.where(overshoot, remaining, dt)
x = x + delta
total_time = total_time + dt
traj.append(x.detach().cpu())
if torch.all(total_time >= 1.0 - 1e-6):
break
return torch.stack(traj, dim=1)
def sphere_wireframe(
center: Tensor, radius: float, u_steps: int = 24, v_steps: int = 12
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
center_np = center.detach().cpu().numpy()
u = np.linspace(0, 2 * np.pi, u_steps)
v = np.linspace(0, np.pi, v_steps)
x = center_np[0] + radius * np.outer(np.cos(u), np.sin(v))
y = center_np[1] + radius * np.outer(np.sin(u), np.sin(v))
z = center_np[2] + radius * np.outer(np.ones_like(u), np.cos(v))
return x, y, z
def plot_trajectories(
traj: Tensor,
sphere_a: tuple[Tensor, Tensor],
sphere_b: tuple[Tensor, Tensor],
save_path: Path,
title: str = "AS-Mamba Trajectories",
) -> None:
traj = traj.detach().cpu()
if traj.dim() == 2:
traj = traj.unsqueeze(0)
traj_np = traj.numpy()
fig = plt.figure(figsize=(7, 6))
ax = fig.add_subplot(111, projection="3d")
for i in range(traj_np.shape[0]):
ax.plot(
traj_np[i, :, 0],
traj_np[i, :, 1],
traj_np[i, :, 2],
color="green",
alpha=0.6,
)
starts = traj_np[:, 0, :]
ends = traj_np[:, -1, :]
ax.scatter(starts[:, 0], starts[:, 1], starts[:, 2], color="blue", s=20, label="Start")
ax.scatter(ends[:, 0], ends[:, 1], ends[:, 2], color="red", s=20, label="End")
center_a, radius_a = sphere_a
center_b, radius_b = sphere_b
x_a, y_a, z_a = sphere_wireframe(center_a, float(radius_a.item()))
x_b, y_b, z_b = sphere_wireframe(center_b, float(radius_b.item()))
ax.plot_wireframe(x_a, y_a, z_a, color="blue", alpha=0.15, linewidth=0.5)
ax.plot_wireframe(x_b, y_b, z_b, color="red", alpha=0.15, linewidth=0.5)
ax.set_title(title)
ax.set_xlabel("X")
ax.set_ylabel("Y")
ax.set_zlabel("Z")
ax.legend(loc="best")
fig.tight_layout()
fig.savefig(save_path, dpi=160)
plt.close(fig)
def run_training_and_plot(cfg: TrainConfig) -> Path:
model, sphere_a, sphere_b = train(cfg)
device = next(model.parameters()).device
plot_samples = max(1, cfg.val_plot_samples)
x0 = sample_points_in_sphere(
sphere_a[0], float(sphere_a[1].item()), plot_samples, device
)
traj = rollout_trajectory(model, x0, max_steps=cfg.val_max_steps)
output_dir = Path(cfg.output_dir)
save_path = output_dir / "as_mamba_trajectory.png"
plot_trajectories(traj, sphere_a, sphere_b, save_path)
return save_path