from __future__ import annotations

import math
import os
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Iterator, Optional

import matplotlib

matplotlib.use("Agg")  # headless backend: figures are only written to disk

import numpy as np
import torch
import torch.nn.functional as F
from datasets import load_dataset
from matplotlib import pyplot as plt
from torch import Tensor, nn
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

from mamba2_minimal import InferenceCache, Mamba2, Mamba2Config, RMSNorm


@dataclass
class TrainConfig:
    # Run setup
    seed: int = 42
    device: str = "cuda"
    epochs: int = 50
    steps_per_epoch: int = 200
    batch_size: int = 128
    seq_len: int = 20
    lr: float = 2e-4
    weight_decay: float = 1e-2

    # Step-size (dt) sampling and supervision
    dt_min: float = 1e-3
    dt_max: float = 0.06
    dt_alpha: float = 8.0

    # Loss weights and toggles
    lambda_flow: float = 1.0
    lambda_pos: float = 1.0
    lambda_dt: float = 1.0
    use_flow_loss: bool = True
    use_pos_loss: bool = False
    use_dt_loss: bool = True

    # Data
    num_classes: int = 10
    image_size: int = 28
    channels: int = 1
    num_workers: int = 8
    dataset_name: str = "ylecun/mnist"
    dataset_split: str = "train"

    # Backbone (d_model == 0 means "use the flattened image dimension")
    d_model: int = 0
    n_layer: int = 6
    d_state: int = 64
    d_conv: int = 4
    expand: int = 2
    headdim: int = 32
    chunk_size: int = 1
    use_residual: bool = False

    # Logging / validation
    output_dir: str = "outputs"
    project: str = "as-mamba-mnist"
    run_name: str = "mnist-flow"
    val_every: int = 200
    val_samples_per_class: int = 8
    val_grid_rows: int = 4
    val_max_steps: int = 0
    use_ddp: bool = False


class AdaLNZero(nn.Module):
    """Adaptive RMSNorm with zero-initialized modulation (adaLN-Zero style)."""

    def __init__(self, d_model: int) -> None:
        super().__init__()
        self.norm = RMSNorm(d_model)
        self.mod = nn.Linear(d_model, 2 * d_model)
        # Zero init: every block starts out as plain RMSNorm (identity modulation).
        nn.init.zeros_(self.mod.weight)
        nn.init.zeros_(self.mod.bias)

    def forward(self, x: Tensor, cond: Tensor) -> Tensor:
        x = self.norm(x)
        params = self.mod(cond).unsqueeze(1)  # (b, 1, 2*d) broadcasts over seq
        scale, shift = params.chunk(2, dim=-1)
        return x * (1 + scale) + shift


class Mamba2Backbone(nn.Module):
    def __init__(self, args: Mamba2Config, use_residual: bool = True) -> None:
        super().__init__()
        self.args = args
        self.use_residual = use_residual
        self.layers = nn.ModuleList(
            [
                nn.ModuleDict(
                    dict(
                        mixer=Mamba2(args),
                        adaln=AdaLNZero(args.d_model),
                    )
                )
                for _ in range(args.n_layer)
            ]
        )
        self.norm_f = RMSNorm(args.d_model)

    def forward(
        self, x: Tensor, cond: Tensor, h: Optional[list[InferenceCache]] = None
    ) -> tuple[Tensor, list[InferenceCache]]:
        if h is None:
            h = [None for _ in range(self.args.n_layer)]
        for i, layer in enumerate(self.layers):
            x_mod = layer["adaln"](x, cond)
            y, h[i] = layer["mixer"](x_mod, h[i])
            x = x + y if self.use_residual else y
        x = self.norm_f(x)
        return x, h
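
# Conditioning contract for the backbone above (a descriptive note; the
# Mamba2 / InferenceCache behavior is taken from mamba2_minimal as used here):
#   x:    (batch, seq_len, d_model)  token sequence
#   cond: (batch, d_model)           class embedding, broadcast over seq_len
#                                    by AdaLNZero's unsqueeze(1)
# The forward pass returns (batch, seq_len, d_model) features plus one
# InferenceCache per layer for step-by-step generation.
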
class ASMamba(nn.Module):
    def __init__(self, cfg: TrainConfig) -> None:
        super().__init__()
        self.cfg = cfg
        self.dt_min = float(cfg.dt_min)
        self.dt_max = float(cfg.dt_max)
        input_dim = cfg.channels * cfg.image_size * cfg.image_size
        if cfg.d_model == 0:
            cfg.d_model = input_dim
        if cfg.d_model != input_dim:
            raise ValueError(
                f"d_model must equal flattened image dim ({input_dim}) "
                "when input_proj is disabled."
            )
        args = Mamba2Config(
            d_model=cfg.d_model,
            n_layer=cfg.n_layer,
            d_state=cfg.d_state,
            d_conv=cfg.d_conv,
            expand=cfg.expand,
            headdim=cfg.headdim,
            chunk_size=cfg.chunk_size,
        )
        self.backbone = Mamba2Backbone(args, use_residual=cfg.use_residual)
        self.cond_emb = nn.Embedding(cfg.num_classes, cfg.d_model)
        self.delta_head = nn.Linear(cfg.d_model, input_dim)
        self.dt_head = nn.Sequential(
            nn.Linear(cfg.d_model, cfg.d_model),
            nn.SiLU(),
            nn.Linear(cfg.d_model, 1),
        )

    def forward(
        self, x: Tensor, cond: Tensor, h: Optional[list[InferenceCache]] = None
    ) -> tuple[Tensor, Tensor, list[InferenceCache]]:
        cond_vec = self.cond_emb(cond)
        feats, h = self.backbone(x, cond_vec, h)
        delta = self.delta_head(feats)
        dt_raw = self.dt_head(feats).squeeze(-1)
        # softplus keeps dt strictly positive; clamping to [dt_min, dt_max]
        # happens only at inference time (see rollout_trajectory).
        dt = F.softplus(dt_raw)
        return delta, dt, h

    def step(
        self, x: Tensor, cond: Tensor, h: list[InferenceCache]
    ) -> tuple[Tensor, Tensor, list[InferenceCache]]:
        delta, dt, h = self.forward(x.unsqueeze(1), cond, h)
        return delta[:, 0, :], dt[:, 0], h

    def init_cache(self, batch_size: int, device: torch.device) -> list[InferenceCache]:
        return [
            InferenceCache.alloc(batch_size, self.backbone.args, device=device)
            for _ in range(self.backbone.args.n_layer)
        ]


class SwanLogger:
    """Thin wrapper around swanlab that degrades to a no-op if unavailable."""

    def __init__(self, cfg: TrainConfig, enabled: bool = True) -> None:
        self.enabled = enabled
        self._swan = None
        self._run = None
        if not self.enabled:
            return
        try:
            import swanlab  # type: ignore

            self._swan = swanlab
            self._run = self._swan.init(
                project=cfg.project,
                experiment_name=cfg.run_name,
                config=asdict(cfg),
            )
        except Exception:
            self.enabled = False

    def log(self, metrics: dict, step: int | None = None) -> None:
        if not self.enabled:
            return
        target = self._run if self._run is not None else self._swan
        if step is None:
            target.log(metrics)
            return
        try:
            target.log(metrics, step=step)
        except TypeError:
            # Older APIs may not accept a step kwarg; fold it into the payload.
            payload = dict(metrics)
            payload["step"] = step
            target.log(payload)

    def log_image(
        self,
        key: str,
        image_path: Path,
        caption: str | None = None,
        step: int | None = None,
    ) -> None:
        if not self.enabled:
            return
        image = self._swan.Image(str(image_path), caption=caption)
        self.log({key: image}, step=step)

    def finish(self) -> None:
        if not self.enabled:
            return
        finish = None
        if self._run is not None:
            finish = getattr(self._run, "finish", None)
        if finish is None:
            finish = getattr(self._swan, "finish", None)
        if callable(finish):
            finish()


def set_seed(seed: int) -> None:
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def setup_distributed(cfg: TrainConfig) -> tuple[bool, int, int, torch.device]:
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    rank = int(os.environ.get("RANK", "0"))
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    use_ddp = cfg.use_ddp and world_size > 1
    if use_ddp:
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
        torch.cuda.set_device(local_rank)
        device = torch.device("cuda", local_rank)
    else:
        device = torch.device(cfg.device if torch.cuda.is_available() else "cpu")
    return use_ddp, rank, world_size, device


def unwrap_model(model: nn.Module) -> nn.Module:
    return model.module if hasattr(model, "module") else model


def validate_time_config(cfg: TrainConfig) -> None:
    if cfg.seq_len <= 0:
        raise ValueError("seq_len must be > 0")
    base = 1.0 / cfg.seq_len
    if cfg.dt_max <= base:
        raise ValueError(
            "dt_max must be > 1/seq_len to allow non-uniform dt_seq. "
            f"Got dt_max={cfg.dt_max}, seq_len={cfg.seq_len}, 1/seq_len={base}."
        )
    if cfg.dt_min >= cfg.dt_max:
        raise ValueError(
            f"dt_min must be < dt_max (got dt_min={cfg.dt_min}, dt_max={cfg.dt_max})."
        )
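
# Ground-truth step-size sampling (a note on sample_time_sequence below):
# seq_len i.i.d. Gamma(dt_alpha, 1) draws, normalized to sum to 1, give a
# symmetric Dirichlet(dt_alpha) step sequence with mean 1/seq_len per step.
# Rows whose largest step exceeds dt_max are blended toward the uniform grid
# 1/seq_len exactly enough to meet the cap; the blend is affine with
# coefficients summing to 1, so each row still sums to 1.
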
def sample_time_sequence(
    cfg: TrainConfig, batch_size: int, device: torch.device
) -> Tensor:
    alpha = float(cfg.dt_alpha)
    if alpha <= 0:
        raise ValueError("dt_alpha must be > 0")
    dist = torch.distributions.Gamma(alpha, 1.0)
    raw = dist.sample((batch_size, cfg.seq_len)).to(device)
    dt_seq = raw / raw.sum(dim=-1, keepdim=True)
    base = 1.0 / cfg.seq_len
    max_dt = float(cfg.dt_max)
    if max_dt <= base:
        return torch.full_like(dt_seq, base)
    max_current = dt_seq.max(dim=-1, keepdim=True).values
    if (max_current > max_dt).any():
        gamma = (max_dt - base) / (max_current - base)
        gamma = gamma.clamp(0.0, 1.0)
        dt_seq = gamma * dt_seq + (1.0 - gamma) * base
    return dt_seq


def build_dataloader(
    cfg: TrainConfig, distributed: bool = False
) -> tuple[DataLoader, Optional[DistributedSampler]]:
    ds = load_dataset(cfg.dataset_name, split=cfg.dataset_split)

    def transform(example):
        image = example.get("img", example.get("image"))
        label = example.get("label", example.get("labels"))
        if isinstance(image, list):
            arr = np.stack([np.array(im, dtype=np.float32) for im in image], axis=0)
            arr = arr / 127.5 - 1.0
            if arr.ndim == 3:
                tensor = torch.from_numpy(arr).unsqueeze(1)
            else:
                tensor = torch.from_numpy(arr).permute(0, 3, 1, 2)
            labels = torch.tensor(label, dtype=torch.long)
            return {"pixel_values": tensor, "labels": labels}
        arr = np.array(image, dtype=np.float32) / 127.5 - 1.0
        if arr.ndim == 2:
            tensor = torch.from_numpy(arr).unsqueeze(0)
        else:
            tensor = torch.from_numpy(arr).permute(2, 0, 1)
        return {"pixel_values": tensor, "labels": torch.tensor(label, dtype=torch.long)}

    ds = ds.with_transform(transform)
    sampler = DistributedSampler(ds, shuffle=True) if distributed else None
    loader = DataLoader(
        ds,
        batch_size=cfg.batch_size,
        shuffle=(sampler is None),
        sampler=sampler,
        num_workers=cfg.num_workers,
        drop_last=True,
        pin_memory=torch.cuda.is_available(),
    )
    return loader, sampler


def infinite_loader(loader: DataLoader) -> Iterator[dict]:
    while True:
        for batch in loader:
            yield batch


def compute_losses(
    delta: Tensor,
    dt: Tensor,
    x_seq: Tensor,
    x0: Tensor,
    v_gt: Tensor,
    t_seq: Tensor,
    dt_seq: Tensor,
    cfg: TrainConfig,
) -> dict[str, Tensor]:
    losses: dict[str, Tensor] = {}
    if cfg.use_flow_loss:
        target_disp = v_gt[:, None, :] * dt.unsqueeze(-1)
        losses["flow"] = F.mse_loss(delta, target_disp)
    if cfg.use_pos_loss:
        t_next = t_seq + dt
        t_next = torch.clamp(t_next, 0.0, 1.0)
        x_target = x0[:, None, :] + t_next.unsqueeze(-1) * v_gt[:, None, :]
        x_next_pred = x_seq + delta
        losses["pos"] = F.mse_loss(x_next_pred, x_target)
    if cfg.use_dt_loss:
        losses["dt"] = F.mse_loss(dt, dt_seq)
    return losses
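
# Supervision targets in one place (a restatement of what compute_losses
# implements): with noise x0 ~ N(0, I), data x1, and the straight-line
# interpolation
#     x_t = x0 + t * v,    v = x1 - x0,
# the displacement for a step of size dt from x_t is v * dt and the next
# position is x0 + (t + dt) * v. Note that the "flow" target uses the model's
# *predicted* dt, so delta is trained to match whatever step size the model
# chooses, while the separate "dt" loss pulls that step size toward the
# sampled ground-truth schedule dt_seq.
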
def plot_dt_hist(
    dt_pred: Tensor, dt_gt: Tensor, save_path: Path, title: str = "dt Distribution"
) -> None:
    dt_pred_np = dt_pred.detach().cpu().numpy().reshape(-1)
    dt_gt_np = dt_gt.detach().cpu().numpy().reshape(-1)
    fig, ax = plt.subplots(figsize=(6, 4))
    ax.hist(dt_gt_np, bins=30, alpha=0.6, label="dt_gt", color="steelblue")
    ax.hist(dt_pred_np, bins=30, alpha=0.6, label="dt_pred", color="orange")
    ax.set_title(title)
    ax.set_xlabel("dt")
    ax.set_ylabel("count")
    ax.legend(loc="best")
    fig.tight_layout()
    fig.savefig(save_path, dpi=160)
    plt.close(fig)


def make_grid(images: Tensor, nrow: int) -> np.ndarray:
    """Tile a (b, c, h, w) batch into one (H, W, c) image; `nrow` is the
    number of images per grid row."""
    images = images.detach().cpu().numpy()
    b, c, h, w = images.shape
    nrow = max(1, min(nrow, b))
    ncol = math.ceil(b / nrow)  # number of grid rows
    grid = np.zeros((c, ncol * h, nrow * w), dtype=np.float32)
    for idx in range(b):
        r = idx // nrow
        cidx = idx % nrow
        grid[:, r * h : (r + 1) * h, cidx * w : (cidx + 1) * w] = images[idx]
    return np.transpose(grid, (1, 2, 0))


def save_image_grid(
    images: Tensor, save_path: Path, nrow: int, title: str | None = None
) -> None:
    images = images.clamp(-1.0, 1.0)
    images = (images + 1.0) / 2.0  # map [-1, 1] -> [0, 1] for plotting
    grid = make_grid(images, nrow=nrow)
    if grid.ndim == 3 and grid.shape[2] == 1:
        grid = np.repeat(grid, 3, axis=2)
    if title is None:
        plt.imsave(save_path, grid)
        return
    # With a title, render a figure instead of the raw array (previously the
    # plain imsave output was written and then immediately overwritten).
    fig, ax = plt.subplots(figsize=(4, 3))
    ax.imshow(grid)
    ax.set_title(title)
    ax.axis("off")
    fig.tight_layout()
    fig.savefig(save_path, dpi=160)
    plt.close(fig)


def rollout_trajectory(
    model: ASMamba,
    x0: Tensor,
    cond: Tensor,
    max_steps: int,
) -> Tensor:
    device = x0.device
    model.eval()
    h = model.init_cache(batch_size=x0.shape[0], device=device)
    x = x0
    total_time = torch.zeros(x0.shape[0], device=device)
    traj = [x0.detach().cpu()]
    with torch.no_grad():
        for _ in range(max_steps):
            delta, dt, h = model.step(x, cond, h)
            dt = torch.clamp(dt, min=model.dt_min, max=model.dt_max)
            remaining = 1.0 - total_time
            overshoot = dt > remaining
            if overshoot.any():
                # Shrink the final step so each sample lands exactly at t = 1.
                scale = (remaining / dt).unsqueeze(-1)
                delta = torch.where(overshoot.unsqueeze(-1), delta * scale, delta)
                dt = torch.where(overshoot, remaining, dt)
            x = x + delta
            total_time = total_time + dt
            traj.append(x.detach().cpu())
            if torch.all(total_time >= 1.0 - 1e-6):
                break
    return torch.stack(traj, dim=1)
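
# rollout_trajectory returns a (batch, n_steps + 1, input_dim) tensor on CPU:
# the initial noise x0 plus one entry per integration step. Step sizes come
# from the model's dt head, clamped to [dt_min, dt_max]; any step that would
# pass t = 1 is rescaled (both delta and dt) so each sample lands exactly on
# t = 1, after which the loop exits early.
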
def log_class_samples(
    model: ASMamba,
    cfg: TrainConfig,
    device: torch.device,
    logger: SwanLogger,
    step: int,
) -> None:
    if cfg.val_samples_per_class <= 0:
        return
    model.eval()
    max_steps = cfg.seq_len if cfg.val_max_steps <= 0 else cfg.val_max_steps
    input_dim = cfg.channels * cfg.image_size * cfg.image_size
    for cls in range(cfg.num_classes):
        cond = torch.full(
            (cfg.val_samples_per_class,), cls, device=device, dtype=torch.long
        )
        x0 = torch.randn(cfg.val_samples_per_class, input_dim, device=device)
        traj = rollout_trajectory(model, x0, cond, max_steps=max_steps)
        x_final = traj[:, -1, :].view(
            cfg.val_samples_per_class, cfg.channels, cfg.image_size, cfg.image_size
        )
        save_path = Path(cfg.output_dir) / f"val_class_{cls}_step_{step:06d}.png"
        save_image_grid(x_final, save_path, nrow=cfg.val_grid_rows)
        logger.log_image(
            f"val/class_{cls}",
            save_path,
            caption=f"class {cls} step {step}",
            step=step,
        )
    model.train()


def train(cfg: TrainConfig) -> ASMamba:
    validate_time_config(cfg)
    use_ddp, rank, world_size, device = setup_distributed(cfg)
    set_seed(cfg.seed + rank)
    output_dir = Path(cfg.output_dir)
    if rank == 0:
        output_dir.mkdir(parents=True, exist_ok=True)

    model = ASMamba(cfg).to(device)
    if use_ddp:
        model = nn.parallel.DistributedDataParallel(model, device_ids=[device.index])
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay
    )
    logger = SwanLogger(cfg, enabled=(rank == 0))
    loader, sampler = build_dataloader(cfg, distributed=use_ddp)
    loader_iter = infinite_loader(loader)

    global_step = 0
    for epoch in range(cfg.epochs):
        if sampler is not None:
            sampler.set_epoch(epoch)
        model.train()
        for _ in range(cfg.steps_per_epoch):
            batch = next(loader_iter)
            x1 = batch["pixel_values"].to(device)
            cond = batch["labels"].to(device)
            b = x1.shape[0]
            x1 = x1.view(b, -1)
            x0 = torch.randn_like(x1)
            v_gt = x1 - x0
            dt_seq = sample_time_sequence(cfg, b, device)
            # t_seq holds the time *before* each step: [0, dt_1, dt_1 + dt_2, ...].
            t_seq = torch.cumsum(dt_seq, dim=-1)
            t_seq = torch.cat([torch.zeros(b, 1, device=device), t_seq[:, :-1]], dim=-1)
            x_seq = x0[:, None, :] + t_seq[:, :, None] * v_gt[:, None, :]

            delta, dt, _ = model(x_seq, cond)
            losses = compute_losses(
                delta=delta,
                dt=dt,
                x_seq=x_seq,
                x0=x0,
                v_gt=v_gt,
                t_seq=t_seq,
                dt_seq=dt_seq,
                cfg=cfg,
            )
            if not losses:
                raise RuntimeError(
                    "No loss enabled: enable at least one of flow/pos/dt."
                )
            loss = torch.tensor(0.0, device=device)
            if "flow" in losses:
                loss = loss + cfg.lambda_flow * losses["flow"]
            if "pos" in losses:
                loss = loss + cfg.lambda_pos * losses["pos"]
            if "dt" in losses:
                loss = loss + cfg.lambda_dt * losses["dt"]

            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            optimizer.step()

            if global_step % 10 == 0:
                eps = 1e-6
                clamp_min = dt <= cfg.dt_min + eps
                clamp_max = dt >= cfg.dt_max - eps
                logger.log(
                    {
                        "loss/total": float(loss.item()),
                        "loss/flow": float(losses.get("flow", torch.tensor(0.0)).item()),
                        "loss/pos": float(losses.get("pos", torch.tensor(0.0)).item()),
                        "loss/dt": float(losses.get("dt", torch.tensor(0.0)).item()),
                        "dt/pred_mean": float(dt.mean().item()),
                        "dt/pred_min": float(dt.min().item()),
                        "dt/pred_max": float(dt.max().item()),
                        "dt/gt_mean": float(dt_seq.mean().item()),
                        "dt/gt_min": float(dt_seq.min().item()),
                        "dt/gt_max": float(dt_seq.max().item()),
                        "dt/clamp_min_ratio": float(clamp_min.float().mean().item()),
                        "dt/clamp_max_ratio": float(clamp_max.float().mean().item()),
                        "dt/clamp_any_ratio": float(
                            (clamp_min | clamp_max).float().mean().item()
                        ),
                    },
                    step=global_step,
                )

            if (
                cfg.val_every > 0
                and global_step > 0
                and global_step % cfg.val_every == 0
                and rank == 0
            ):
                log_class_samples(unwrap_model(model), cfg, device, logger, global_step)
                dt_hist_path = (
                    Path(cfg.output_dir) / f"dt_hist_step_{global_step:06d}.png"
                )
                plot_dt_hist(
                    dt,
                    dt_seq,
                    dt_hist_path,
                    title=f"dt Distribution (step {global_step})",
                )
                logger.log_image(
                    "train/dt_hist",
                    dt_hist_path,
                    caption=f"step {global_step}",
                    step=global_step,
                )
            global_step += 1

    logger.finish()
    if use_ddp:
        torch.distributed.destroy_process_group()
    return unwrap_model(model)


def run_training_and_plot(cfg: TrainConfig) -> Path:
    train(cfg)
    return Path(cfg.output_dir)
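
# Minimal launch sketch (an assumption about intended usage; the original
# entry point is not shown, so adjust config overrides to taste):
if __name__ == "__main__":
    cfg = TrainConfig()
    # Single-process run with defaults. For DDP, set use_ddp=True and launch
    # via `torchrun --nproc_per_node=<gpus> <this_file>.py`.
    run_training_and_plot(cfg)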