"""Class-conditional MeanFlow training of a Mamba-2 x-prediction model on MNIST.

The model predicts the clean image x0 from a sequence of noisy blocks; velocities
are derived as (z_t - x_pred) / t. Validation draws fixed 5-step samples per class.
"""

from __future__ import annotations

import math
import os
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Iterator, Optional

import lpips

# Point matplotlib at a writable config dir and select a headless backend
# before pyplot is imported (required on display-less machines).
os.environ.setdefault("MPLCONFIGDIR", "/tmp/mamba_diffusion_mplconfig")

import matplotlib

matplotlib.use("Agg")

import numpy as np
import torch
import torch.nn.functional as F
from datasets import load_dataset
from matplotlib import pyplot as plt
from torch import Tensor, nn
from torch.func import jvp
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

from mamba2_minimal import InferenceCache, Mamba2, Mamba2Config, RMSNorm

# Fixed 5-step validation schedule from t=1 (pure noise) down to t=0 (data).
FIXED_VAL_SAMPLING_STEPS = 5
FIXED_VAL_TIME_GRID = (1.0, 0.8, 0.6, 0.4, 0.2, 0.0)


@dataclass
class TrainConfig:
    seed: int = 42
    device: str = "cuda"
    epochs: int = 50
    steps_per_epoch: int = 200
    batch_size: int = 128
    seq_len: int = 5
    lr: float = 2e-4
    weight_decay: float = 1e-2
    lambda_flow: float = 1.0
    lambda_perceptual: float = 0.4
    num_classes: int = 10
    image_size: int = 28
    channels: int = 1
    num_workers: int = 8
    dataset_name: str = "ylecun/mnist"
    dataset_split: str = "train"
    d_model: int = 0
    n_layer: int = 6
    d_state: int = 64
    d_conv: int = 4
    expand: int = 2
    headdim: int = 32
    chunk_size: int = 1
    use_residual: bool = False
    output_dir: str = "outputs"
    project: str = "as-mamba-mnist"
    run_name: str = "mnist-meanflow"
    val_every: int = 200
    val_samples_per_class: int = 8
    val_grid_rows: int = 4
    val_sampling_steps: int = FIXED_VAL_SAMPLING_STEPS
    time_grid_size: int = 256
    use_ddp: bool = False


class AdaLNZero(nn.Module):
    """AdaLN-Zero conditioning: scale/shift are zero-initialized, so each block
    starts out applying a plain RMSNorm."""

    def __init__(self, d_model: int) -> None:
        super().__init__()
        self.norm = RMSNorm(d_model)
        self.mod = nn.Linear(d_model, 2 * d_model)
        nn.init.zeros_(self.mod.weight)
        nn.init.zeros_(self.mod.bias)

    def forward(self, x: Tensor, cond: Tensor) -> Tensor:
        x = self.norm(x)
        params = self.mod(cond).unsqueeze(1)
        scale, shift = params.chunk(2, dim=-1)
        return x * (1 + scale) + shift


class Mamba2Backbone(nn.Module):
    def __init__(self, args: Mamba2Config, use_residual: bool = True) -> None:
        super().__init__()
        self.args = args
        self.use_residual = use_residual
        self.layers = nn.ModuleList(
            [
                nn.ModuleDict(
                    dict(
                        mixer=Mamba2(args),
                        adaln=AdaLNZero(args.d_model),
                    )
                )
                for _ in range(args.n_layer)
            ]
        )
        self.norm_f = RMSNorm(args.d_model)

    def forward(
        self, x: Tensor, cond: Tensor, h: Optional[list[InferenceCache]] = None
    ) -> tuple[Tensor, list[InferenceCache]]:
        if h is None:
            h = [None for _ in range(self.args.n_layer)]  # filled by each mixer
        for i, layer in enumerate(self.layers):
            x_mod = layer["adaln"](x, cond)
            y, h[i] = layer["mixer"](x_mod, h[i])
            # Residual connection is optional; without it each layer replaces x.
            x = x + y if self.use_residual else y
        x = self.norm_f(x)
        return x, h


def sinusoidal_embedding(t: Tensor, dim: int) -> Tensor:
    if dim <= 0:
        raise ValueError("sinusoidal embedding dim must be > 0")
    half = dim // 2
    if half == 0:
        return torch.zeros(*t.shape, dim, device=t.device, dtype=t.dtype)
    denom = max(half - 1, 1)
    freqs = torch.exp(
        -math.log(10000.0)
        * torch.arange(half, device=t.device, dtype=torch.float32)
        / denom
    )
    args = t.to(torch.float32).unsqueeze(-1) * freqs
    emb = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
    if dim % 2 == 1:
        pad = torch.zeros(*emb.shape[:-1], 1, device=emb.device, dtype=emb.dtype)
        emb = torch.cat([emb, pad], dim=-1)
    return emb.to(t.dtype)


def safe_time_divisor(t: Tensor) -> Tensor:
    # Replace non-positive times with machine epsilon so the x-prediction ->
    # velocity conversion (z_t - x_pred) / t never divides by zero.
    eps = torch.finfo(t.dtype).eps
    return torch.where(t > 0, t, torch.full_like(t, eps))
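

# --- Example: shape/value sanity check for the two time helpers above.
# A minimal sketch for documentation purposes (not called by the training
# path; the batch/sequence sizes are illustrative):
def _demo_time_helpers() -> None:
    t = torch.rand(4, 5)  # (batch, seq) times in [0, 1)
    emb = sinusoidal_embedding(t, 784)
    assert emb.shape == (4, 5, 784)  # one embedding per (batch, seq) time
    # Zeros are replaced by machine epsilon, so downstream divisions are safe.
    assert torch.all(safe_time_divisor(torch.zeros(3)) > 0)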


class ASMamba(nn.Module):
    def __init__(self, cfg: TrainConfig) -> None:
        super().__init__()
        self.cfg = cfg
        # The backbone runs directly on flattened pixels, so d_model must match
        # the flattened image dimension (there is no input projection).
        input_dim = cfg.channels * cfg.image_size * cfg.image_size
        if cfg.d_model == 0:
            cfg.d_model = input_dim
        if cfg.d_model != input_dim:
            raise ValueError(
                f"d_model must equal flattened image dim ({input_dim}) when input_proj is disabled."
            )
        args = Mamba2Config(
            d_model=cfg.d_model,
            n_layer=cfg.n_layer,
            d_state=cfg.d_state,
            d_conv=cfg.d_conv,
            expand=cfg.expand,
            headdim=cfg.headdim,
            chunk_size=cfg.chunk_size,
        )
        self.backbone = Mamba2Backbone(args, use_residual=cfg.use_residual)
        self.cond_emb = nn.Embedding(cfg.num_classes, cfg.d_model)
        self.clean_head = nn.Linear(cfg.d_model, input_dim)

    def forward(
        self,
        z_t: Tensor,
        r: Tensor,
        t: Tensor,
        cond: Tensor,
        h: Optional[list[InferenceCache]] = None,
    ) -> tuple[Tensor, list[InferenceCache]]:
        # Normalize r/t to shape (batch, seq) regardless of how they arrive.
        if r.dim() == 1:
            r = r.unsqueeze(1)
        elif r.dim() == 3 and r.shape[-1] == 1:
            r = r.squeeze(-1)
        if t.dim() == 1:
            t = t.unsqueeze(1)
        elif t.dim() == 3 and t.shape[-1] == 1:
            t = t.squeeze(-1)
        r = r.to(dtype=z_t.dtype)
        t = t.to(dtype=z_t.dtype)
        # Inject both times as sinusoidal embeddings added to the token stream.
        z_t = z_t + sinusoidal_embedding(r, z_t.shape[-1]) + sinusoidal_embedding(
            t, z_t.shape[-1]
        )
        cond_vec = self.cond_emb(cond)
        feats, h = self.backbone(z_t, cond_vec, h)
        x_pred = self.clean_head(feats)
        return x_pred, h

    def step(
        self, z_t: Tensor, r: Tensor, t: Tensor, cond: Tensor, h: list[InferenceCache]
    ) -> tuple[Tensor, list[InferenceCache]]:
        # Single-token recurrent step: wrap inputs as a length-1 sequence.
        x_pred, h = self.forward(
            z_t.unsqueeze(1), r.unsqueeze(1), t.unsqueeze(1), cond, h
        )
        return x_pred[:, 0, :], h

    def init_cache(
        self, batch_size: int, device: torch.device
    ) -> list[InferenceCache]:
        return [
            InferenceCache.alloc(batch_size, self.backbone.args, device=device)
            for _ in range(self.backbone.args.n_layer)
        ]


class SwanLogger:
    """Thin wrapper around swanlab that degrades to a no-op when the package is
    missing or initialization fails."""

    def __init__(self, cfg: TrainConfig, enabled: bool = True) -> None:
        self.enabled = enabled
        self._swan = None
        self._run = None
        if not self.enabled:
            return
        try:
            import swanlab  # type: ignore

            self._swan = swanlab
            self._run = self._swan.init(
                project=cfg.project,
                experiment_name=cfg.run_name,
                config=asdict(cfg),
            )
            self.enabled = True
        except Exception:
            self.enabled = False

    def log(self, metrics: dict, step: int | None = None) -> None:
        if not self.enabled:
            return
        target = self._run if self._run is not None else self._swan
        if step is None:
            target.log(metrics)
            return
        try:
            target.log(metrics, step=step)
        except TypeError:
            # Older logger APIs without a step kwarg: fold step into the payload.
            payload = dict(metrics)
            payload["step"] = step
            target.log(payload)

    def log_image(
        self,
        key: str,
        image_path: Path,
        caption: str | None = None,
        step: int | None = None,
    ) -> None:
        if not self.enabled:
            return
        image = self._swan.Image(str(image_path), caption=caption)
        self.log({key: image}, step=step)

    def finish(self) -> None:
        if not self.enabled:
            return
        finish = None
        if self._run is not None:
            finish = getattr(self._run, "finish", None)
        if finish is None:
            finish = getattr(self._swan, "finish", None)
        if callable(finish):
            finish()
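

# --- Example: single-token stepping with the recurrent inference cache.
# A minimal sketch (assumes the mamba2_minimal single-token step path runs on
# CPU; batch size and time values are illustrative):
def _demo_asmamba_step() -> None:
    cfg = TrainConfig(device="cpu")
    model = ASMamba(cfg)
    b, d = 2, cfg.channels * cfg.image_size * cfg.image_size
    h = model.init_cache(b, torch.device("cpu"))
    z_t = torch.randn(b, d)
    r = torch.full((b,), 0.5)
    t = torch.full((b,), 1.0)
    cond = torch.zeros(b, dtype=torch.long)
    x_pred, h = model.step(z_t, r, t, cond, h)  # -> (b, d) clean-image prediction
    assert x_pred.shape == (b, d)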


class LPIPSPerceptualLoss(nn.Module):
    def __init__(self, cfg: TrainConfig) -> None:
        super().__init__()
        # Download LPIPS/VGG weights into a repo-local TORCH_HOME and freeze them.
        torch_home = Path(cfg.output_dir) / ".torch"
        torch_home.mkdir(parents=True, exist_ok=True)
        os.environ["TORCH_HOME"] = str(torch_home)
        self.channels = cfg.channels
        self.loss_fn = lpips.LPIPS(net="vgg", verbose=False)
        self.loss_fn.eval()
        for param in self.loss_fn.parameters():
            param.requires_grad_(False)

    def _prepare_images(self, images: Tensor) -> Tensor:
        # LPIPS expects 3-channel input; repeat grayscale across RGB.
        if images.shape[1] == 1:
            return images.repeat(1, 3, 1, 1)
        if images.shape[1] != 3:
            raise ValueError(
                "LPIPS perceptual loss expects 1-channel or 3-channel images. "
                f"Got {images.shape[1]} channels."
            )
        return images

    def forward(self, pred: Tensor, target: Tensor) -> Tensor:
        pred_rgb = self._prepare_images(pred)
        target_rgb = self._prepare_images(target)
        return self.loss_fn(pred_rgb, target_rgb).mean()


def set_seed(seed: int) -> None:
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def setup_distributed(cfg: TrainConfig) -> tuple[bool, int, int, torch.device]:
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    rank = int(os.environ.get("RANK", "0"))
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    use_ddp = cfg.use_ddp and world_size > 1
    if use_ddp:
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
        torch.cuda.set_device(local_rank)
        device = torch.device("cuda", local_rank)
    else:
        device = torch.device(cfg.device if torch.cuda.is_available() else "cpu")
    return use_ddp, rank, world_size, device


def unwrap_model(model: nn.Module) -> nn.Module:
    return model.module if hasattr(model, "module") else model


def validate_config(cfg: TrainConfig) -> None:
    if cfg.seq_len != 5:
        raise ValueError(
            f"seq_len must be 5 for the required 5-block training setup (got {cfg.seq_len})."
        )
    if cfg.time_grid_size < 2:
        raise ValueError("time_grid_size must be >= 2.")
    if cfg.lambda_perceptual < 0:
        raise ValueError("lambda_perceptual must be >= 0.")
    if cfg.val_sampling_steps != FIXED_VAL_SAMPLING_STEPS:
        raise ValueError(
            f"val_sampling_steps is fixed to {FIXED_VAL_SAMPLING_STEPS} for validation sampling."
        )


def sample_block_times(
    cfg: TrainConfig, batch_size: int, device: torch.device, dtype: torch.dtype
) -> tuple[Tensor, Tensor]:
    # Sampling sorted discrete cut points allows repeated boundaries, so zero-length
    # interior blocks occur with non-zero probability while keeping t > 0.
    cuts = torch.randint(
        1,
        cfg.time_grid_size,
        (batch_size, cfg.seq_len - 1),
        device=device,
    )
    cuts, _ = torch.sort(cuts, dim=-1)
    boundaries = torch.cat(
        [
            torch.zeros(batch_size, 1, device=device, dtype=dtype),
            cuts.to(dtype=dtype) / float(cfg.time_grid_size),
            torch.ones(batch_size, 1, device=device, dtype=dtype),
        ],
        dim=-1,
    )
    # Block starts r and ends t: adjacent boundary pairs.
    return boundaries[:, :-1], boundaries[:, 1:]


def build_dataloader(
    cfg: TrainConfig, distributed: bool = False
) -> tuple[DataLoader, Optional[DistributedSampler]]:
    ds = load_dataset(cfg.dataset_name, split=cfg.dataset_split)

    def transform(example):
        # with_transform may hand over a single example or a batch (lists);
        # both paths normalize uint8 pixels to [-1, 1] and emit CHW tensors.
        image = example.get("img", example.get("image"))
        label = example.get("label", example.get("labels"))
        if isinstance(image, list):
            arr = np.stack([np.array(im, dtype=np.float32) for im in image], axis=0)
            arr = arr / 127.5 - 1.0
            if arr.ndim == 3:
                tensor = torch.from_numpy(arr).unsqueeze(1)
            else:
                tensor = torch.from_numpy(arr).permute(0, 3, 1, 2)
            labels = torch.tensor(label, dtype=torch.long)
            return {"pixel_values": tensor, "labels": labels}
        arr = np.array(image, dtype=np.float32) / 127.5 - 1.0
        if arr.ndim == 2:
            tensor = torch.from_numpy(arr).unsqueeze(0)
        else:
            tensor = torch.from_numpy(arr).permute(2, 0, 1)
        return {"pixel_values": tensor, "labels": torch.tensor(label, dtype=torch.long)}

    ds = ds.with_transform(transform)
    sampler = DistributedSampler(ds, shuffle=True) if distributed else None
    loader = DataLoader(
        ds,
        batch_size=cfg.batch_size,
        shuffle=(sampler is None),
        sampler=sampler,
        num_workers=cfg.num_workers,
        drop_last=True,
        pin_memory=torch.cuda.is_available(),
    )
    return loader, sampler


def infinite_loader(loader: DataLoader) -> Iterator[dict]:
    while True:
        for batch in loader:
            yield batch
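

# --- Example: the block-time sampler tiles [0, 1] with contiguous blocks,
# 0 = r[:, 0] <= t[:, 0] = r[:, 1] <= ... <= t[:, -1] = 1, and every t is
# strictly positive, so dividing by t downstream is always safe. A quick
# property check (illustrative batch size):
def _demo_block_times() -> None:
    cfg = TrainConfig()
    r, t = sample_block_times(cfg, 4, torch.device("cpu"), torch.float32)
    assert r.shape == t.shape == (4, cfg.seq_len)
    assert torch.all(r[:, 0] == 0.0) and torch.all(t[:, -1] == 1.0)
    assert torch.all(t >= r) and torch.all(t > 0)
    assert torch.equal(r[:, 1:], t[:, :-1])  # blocks tile [0, 1] with no gaps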


def build_noisy_sequence(
    x0: Tensor,
    eps: Tensor,
    t_seq: Tensor,
) -> tuple[Tensor, Tensor]:
    # Linear interpolation path: t=0 is data, t=1 is noise; v_gt is its velocity.
    t_col = t_seq.unsqueeze(-1)
    z_t = (1.0 - t_col) * x0[:, None, :] + t_col * eps[:, None, :]
    v_gt = eps - x0
    return z_t, v_gt


def compute_losses(
    model: nn.Module,
    perceptual_loss_fn: LPIPSPerceptualLoss,
    x0: Tensor,
    z_t: Tensor,
    v_gt: Tensor,
    r_seq: Tensor,
    t_seq: Tensor,
    cond: Tensor,
    cfg: TrainConfig,
) -> tuple[dict[str, Tensor], Tensor]:
    seq_len = z_t.shape[1]
    safe_t = safe_time_divisor(t_seq).unsqueeze(-1)
    # Average velocity over [r, t] implied by the x-prediction: u = (z_t - x) / t.
    x_pred, _ = model(z_t, r_seq, t_seq, cond)
    u = (z_t - x_pred) / safe_t
    # Instantaneous velocity estimate at r = t, detached to serve as the JVP tangent.
    x_pred_inst, _ = model(z_t, t_seq, t_seq, cond)
    v_inst = ((z_t - x_pred_inst) / safe_t).detach()

    def u_fn(z_in: Tensor, r_in: Tensor, t_in: Tensor) -> Tensor:
        x_pred_local, _ = model(z_in, r_in, t_in, cond)
        return (z_in - x_pred_local) / safe_time_divisor(t_in).unsqueeze(-1)

    # Total derivative du/dt along the flow: tangent (v, 0, 1) in (z, r, t).
    _, dudt = jvp(
        u_fn,
        (z_t, r_seq, t_seq),
        (v_inst, torch.zeros_like(r_seq), torch.ones_like(t_seq)),
    )
    # MeanFlow identity: v(z_t, t) = u(z_t, r, t) + (t - r) * du/dt, regressed
    # onto the ground-truth velocity of the interpolation path.
    corrected_velocity = u + (t_seq - r_seq).unsqueeze(-1) * dudt.detach()
    target_velocity = v_gt[:, None, :].expand(-1, seq_len, -1)
    pred_images = x_pred.reshape(
        x0.shape[0] * seq_len, cfg.channels, cfg.image_size, cfg.image_size
    )
    target_images = (
        x0.reshape(x0.shape[0], cfg.channels, cfg.image_size, cfg.image_size)
        .unsqueeze(1)
        .expand(-1, seq_len, -1, -1, -1)
        .reshape(x0.shape[0] * seq_len, cfg.channels, cfg.image_size, cfg.image_size)
    )
    losses = {
        "flow": F.mse_loss(corrected_velocity, target_velocity),
        "perceptual": perceptual_loss_fn(pred_images, target_images),
    }
    losses["total"] = (
        cfg.lambda_flow * losses["flow"] + cfg.lambda_perceptual * losses["perceptual"]
    )
    return losses, x_pred


def make_grid(images: Tensor, nrow: int) -> np.ndarray:
    # nrow counts images per row (torchvision convention).
    images = images.detach().cpu().numpy()
    b, c, h, w = images.shape
    nrow = max(1, min(nrow, b))
    ncol = math.ceil(b / nrow)
    grid = np.zeros((c, ncol * h, nrow * w), dtype=np.float32)
    for idx in range(b):
        r = idx // nrow
        cidx = idx % nrow
        grid[:, r * h : (r + 1) * h, cidx * w : (cidx + 1) * w] = images[idx]
    return np.transpose(grid, (1, 2, 0))


def save_image_grid(
    images: Tensor, save_path: Path, nrow: int, title: str | None = None
) -> None:
    images = images.clamp(-1.0, 1.0)
    images = (images + 1.0) / 2.0
    grid = make_grid(images, nrow=nrow)
    if grid.ndim == 3 and grid.shape[2] == 1:
        grid = np.repeat(grid, 3, axis=2)
    # Write either the raw pixel grid or a titled figure, not both.
    if title is None:
        plt.imsave(save_path, grid)
        return
    fig, ax = plt.subplots(figsize=(4, 3))
    ax.imshow(grid)
    ax.set_title(title)
    ax.axis("off")
    fig.tight_layout()
    fig.savefig(save_path, dpi=160)
    plt.close(fig)


def sample_class_images(
    model: ASMamba,
    cfg: TrainConfig,
    device: torch.device,
    cond: Tensor,
) -> Tensor:
    # Deterministic Euler steps on the instantaneous velocity (r == t) along
    # the fixed validation time grid, from noise (t=1) toward data (t=0).
    model.eval()
    input_dim = cfg.channels * cfg.image_size * cfg.image_size
    z_t = torch.randn(cond.shape[0], input_dim, device=device)
    time_grid = torch.tensor(FIXED_VAL_TIME_GRID, device=device)
    with torch.no_grad():
        for step_idx in range(FIXED_VAL_SAMPLING_STEPS):
            t_cur = torch.full(
                (cond.shape[0],),
                float(time_grid[step_idx].item()),
                device=device,
            )
            t_next = time_grid[step_idx + 1]
            x_pred, _ = model(
                z_t.unsqueeze(1),
                t_cur.unsqueeze(1),
                t_cur.unsqueeze(1),
                cond,
            )
            x_pred = x_pred[:, 0, :]
            u_inst = (z_t - x_pred) / safe_time_divisor(t_cur).unsqueeze(-1)
            z_t = z_t + (t_next - t_cur).unsqueeze(-1) * u_inst
    return z_t.view(cond.shape[0], cfg.channels, cfg.image_size, cfg.image_size)
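

# --- Example: drawing class-conditional samples outside the validation loop.
# A minimal sketch (an untrained model just produces noise; the class id and
# sample count are illustrative):
def _demo_sampling() -> Tensor:
    cfg = TrainConfig(device="cpu")
    device = torch.device("cpu")
    model = ASMamba(cfg).to(device)
    cond = torch.full((8,), 3, device=device, dtype=torch.long)  # eight "3" digits
    return sample_class_images(model, cfg, device, cond)  # (8, 1, 28, 28)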


def log_class_samples(
    model: ASMamba,
    cfg: TrainConfig,
    device: torch.device,
    logger: SwanLogger,
    step: int,
) -> None:
    if cfg.val_samples_per_class <= 0:
        return
    training_mode = model.training
    model.eval()
    for cls in range(cfg.num_classes):
        cond = torch.full(
            (cfg.val_samples_per_class,), cls, device=device, dtype=torch.long
        )
        x_final = sample_class_images(model, cfg, device, cond)
        save_path = Path(cfg.output_dir) / f"val_class_{cls}_step_{step:06d}.png"
        save_image_grid(x_final, save_path, nrow=cfg.val_grid_rows)
        logger.log_image(
            f"val/class_{cls}",
            save_path,
            caption=f"class {cls} step {step}",
            step=step,
        )
    if training_mode:
        model.train()


def build_perceptual_loss(
    cfg: TrainConfig, device: torch.device, rank: int, use_ddp: bool
) -> LPIPSPerceptualLoss:
    # Rank 0 downloads the LPIPS weights first; other ranks wait at the barrier.
    if use_ddp and rank != 0:
        torch.distributed.barrier()
    perceptual_loss_fn = LPIPSPerceptualLoss(cfg).to(device)
    if use_ddp and rank == 0:
        torch.distributed.barrier()
    return perceptual_loss_fn


def train(cfg: TrainConfig) -> ASMamba:
    validate_config(cfg)
    use_ddp, rank, world_size, device = setup_distributed(cfg)
    del world_size  # unused; kept for the tuple unpack
    set_seed(cfg.seed + rank)
    output_dir = Path(cfg.output_dir)
    if rank == 0:
        output_dir.mkdir(parents=True, exist_ok=True)
    model = ASMamba(cfg).to(device)
    if use_ddp:
        model = nn.parallel.DistributedDataParallel(model, device_ids=[device.index])
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay
    )
    perceptual_loss_fn = build_perceptual_loss(cfg, device, rank, use_ddp)
    logger = SwanLogger(cfg, enabled=(rank == 0))
    loader, sampler = build_dataloader(cfg, distributed=use_ddp)
    loader_iter = infinite_loader(loader)
    global_step = 0
    for epoch_idx in range(cfg.epochs):
        if sampler is not None:
            sampler.set_epoch(epoch_idx)
        model.train()
        for _ in range(cfg.steps_per_epoch):
            batch = next(loader_iter)
            x0 = batch["pixel_values"].to(device)
            cond = batch["labels"].to(device)
            b = x0.shape[0]
            x0 = x0.view(b, -1)
            eps = torch.randn_like(x0)
            r_seq, t_seq = sample_block_times(cfg, b, device, x0.dtype)
            z_t, v_gt = build_noisy_sequence(x0, eps, t_seq)
            losses, _ = compute_losses(
                model=model,
                perceptual_loss_fn=perceptual_loss_fn,
                x0=x0,
                z_t=z_t,
                v_gt=v_gt,
                r_seq=r_seq,
                t_seq=t_seq,
                cond=cond,
                cfg=cfg,
            )
            optimizer.zero_grad(set_to_none=True)
            losses["total"].backward()
            optimizer.step()
            if global_step % 10 == 0 and rank == 0:
                logger.log(
                    {
                        "loss/total": float(losses["total"].item()),
                        "loss/flow": float(losses["flow"].item()),
                        "loss/perceptual": float(losses["perceptual"].item()),
                        "time/r_mean": float(r_seq.mean().item()),
                        "time/t_mean": float(t_seq.mean().item()),
                        # Fraction of zero-length blocks (r == t).
                        "time/zero_block_frac": float(
                            (t_seq == r_seq).float().mean().item()
                        ),
                    },
                    step=global_step,
                )
            if (
                cfg.val_every > 0
                and global_step > 0
                and global_step % cfg.val_every == 0
                and rank == 0
            ):
                log_class_samples(unwrap_model(model), cfg, device, logger, global_step)
            global_step += 1
    logger.finish()
    if use_ddp:
        torch.distributed.destroy_process_group()
    return unwrap_model(model)


def run_training_and_plot(cfg: TrainConfig) -> Path:
    train(cfg)
    return Path(cfg.output_dir)
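

if __name__ == "__main__":
    # Minimal entry point, a sketch with default hyperparameters: runs a
    # single-process job (falling back to CPU when CUDA is unavailable). For
    # multi-GPU training, launch this script under torchrun with use_ddp=True.
    run_training_and_plot(TrainConfig())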