Files
mamba_diffusion/as_mamba.py
gameloader 01fc1e4eab refactor: simplify delta-only flow training
Remove learned dt prediction and auxiliary losses.

Add repository contributor guidelines.
2026-03-10 18:23:17 +08:00

535 lines
16 KiB
Python

from __future__ import annotations
import math
from dataclasses import asdict, dataclass
import os
from pathlib import Path
from typing import Iterator, Optional
import matplotlib
matplotlib.use("Agg")
import numpy as np
import torch
import torch.nn.functional as F
from datasets import load_dataset
from matplotlib import pyplot as plt
from torch import Tensor, nn
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from mamba2_minimal import InferenceCache, Mamba2, Mamba2Config, RMSNorm
@dataclass
class TrainConfig:
    """Hyper-parameters and run settings for the MNIST flow training script.

    Field order and defaults are part of the positional constructor
    signature, so they must not be reordered.
    """

    # --- reproducibility / hardware ---
    seed: int = 42  # base RNG seed; train() adds the process rank to it
    device: str = "cuda"  # used for single-process runs when CUDA is available

    # --- optimisation schedule ---
    epochs: int = 50
    steps_per_epoch: int = 200
    batch_size: int = 128
    seq_len: int = 20  # number of time steps in each sampled trajectory
    lr: float = 2e-4
    weight_decay: float = 1e-2

    # --- step-size sampling ---
    dt_min: float = 1e-3  # only checked against dt_max in validate_time_config
    dt_max: float = 0.06  # cap on the largest step of a sampled sequence
    dt_alpha: float = 8.0  # Gamma concentration used by sample_time_sequence
    lambda_flow: float = 1.0  # weight applied to the flow loss

    # --- data ---
    num_classes: int = 10
    image_size: int = 28
    channels: int = 1
    num_workers: int = 8
    dataset_name: str = "ylecun/mnist"
    dataset_split: str = "train"

    # --- model (forwarded to Mamba2Config by ASMamba) ---
    d_model: int = 0  # 0 means "use the flattened image dimension"
    n_layer: int = 6
    d_state: int = 64
    d_conv: int = 4
    expand: int = 2
    headdim: int = 32
    chunk_size: int = 1
    use_residual: bool = False

    # --- logging / validation ---
    output_dir: str = "outputs"
    project: str = "as-mamba-mnist"
    run_name: str = "mnist-flow"
    val_every: int = 200  # validate every N optimisation steps (<= 0 disables)
    val_samples_per_class: int = 8
    val_grid_rows: int = 4  # passed as ``nrow`` (images per grid row) when saving

    # --- distributed ---
    use_ddp: bool = False
class AdaLNZero(nn.Module):
    """Normalise the input, then apply a conditioning-driven scale and shift.

    The modulation projection is zero-initialised, so a freshly constructed
    module returns just the normalised input (scale = shift = 0).
    """

    def __init__(self, d_model: int) -> None:
        super().__init__()
        self.norm = RMSNorm(d_model)
        self.mod = nn.Linear(d_model, 2 * d_model)
        # Zero both weight and bias so modulation starts as the identity.
        for param in (self.mod.weight, self.mod.bias):
            nn.init.zeros_(param)

    def forward(self, x: Tensor, cond: Tensor) -> Tensor:
        """Return ``norm(x) * (1 + scale) + shift`` with scale/shift from ``cond``.

        ``cond`` is projected to ``2 * d_model`` features, given an extra axis
        at position 1 so it broadcasts over that axis of ``x``, and split in
        half along the last axis into (scale, shift).
        """
        normed = self.norm(x)
        scale, shift = self.mod(cond).unsqueeze(1).chunk(2, dim=-1)
        return normed * (1 + scale) + shift
class Mamba2Backbone(nn.Module):
    """Stack of (AdaLNZero -> Mamba2 mixer) layers followed by a final RMSNorm."""

    def __init__(self, args: Mamba2Config, use_residual: bool = True) -> None:
        super().__init__()
        self.args = args
        self.use_residual = use_residual
        blocks = []
        for _ in range(args.n_layer):
            # Key order (mixer, adaln) is kept so parameter creation order
            # and state-dict layout stay the same.
            blocks.append(
                nn.ModuleDict(
                    {
                        "mixer": Mamba2(args),
                        "adaln": AdaLNZero(args.d_model),
                    }
                )
            )
        self.layers = nn.ModuleList(blocks)
        self.norm_f = RMSNorm(args.d_model)

    def forward(
        self, x: Tensor, cond: Tensor, h: Optional[list[InferenceCache]] = None
    ) -> tuple[Tensor, list[InferenceCache]]:
        """Run every layer and return ``(normalised features, per-layer caches)``.

        When ``h`` is None a list of ``None`` placeholders (one per layer) is
        used. A caller-supplied list is updated in place with the cache each
        mixer returns, and the same list object is returned.
        """
        if h is None:
            h = [None] * self.args.n_layer
        for idx, block in enumerate(self.layers):
            mixed, h[idx] = block["mixer"](block["adaln"](x, cond), h[idx])
            if self.use_residual:
                x = x + mixed
            else:
                x = mixed
        return self.norm_f(x), h
def sinusoidal_embedding(t: Tensor, dim: int) -> Tensor:
    """Return a sin/cos embedding of ``t`` with a trailing axis of size ``dim``.

    The output has shape ``(*t.shape, dim)`` and the dtype of ``t``; the
    trigonometric terms are computed in float32. The first ``dim // 2``
    features are sines, the next ``dim // 2`` are cosines, and an odd ``dim``
    gets one trailing zero feature. ``dim == 1`` yields all zeros.

    Raises:
        ValueError: if ``dim`` is not positive.
    """
    if dim <= 0:
        raise ValueError("sinusoidal embedding dim must be > 0")

    half = dim // 2
    if half == 0:
        return torch.zeros(*t.shape, dim, device=t.device, dtype=t.dtype)

    # Geometric frequency ladder from 1 down to 1/10000; the divisor is
    # floored at 1 so half == 1 does not divide by zero.
    denom = max(half - 1, 1)
    positions = torch.arange(half, device=t.device, dtype=torch.float32)
    freqs = torch.exp(-math.log(10000.0) * positions / denom)

    phase = t.to(torch.float32).unsqueeze(-1) * freqs
    emb = torch.cat([torch.sin(phase), torch.cos(phase)], dim=-1)
    if dim % 2 == 1:
        # Odd width: append a single zero feature to reach ``dim``.
        emb = torch.cat([emb, emb.new_zeros(*emb.shape[:-1], 1)], dim=-1)
    return emb.to(t.dtype)
class ASMamba(nn.Module):
    """Class-conditioned Mamba2 model that predicts a per-step displacement.

    There is no input projection: the backbone width must equal the flattened
    image dimension ``channels * image_size * image_size``.
    """

    def __init__(self, cfg: TrainConfig) -> None:
        """Build the backbone, class embedding and output head from ``cfg``.

        Note: when ``cfg.d_model`` is 0 it is overwritten on the passed-in
        config object with the flattened image dimension.

        Raises:
            ValueError: if ``cfg.d_model`` differs from the flattened image
                dimension.
        """
        super().__init__()
        self.cfg = cfg
        input_dim = cfg.channels * cfg.image_size * cfg.image_size
        if cfg.d_model == 0:
            cfg.d_model = input_dim
        if cfg.d_model != input_dim:
            raise ValueError(
                f"d_model must equal flattened image dim ({input_dim}) when input_proj is disabled."
            )

        # Every Mamba2Config field is read from the identically named
        # TrainConfig field.
        mamba_fields = (
            "d_model",
            "n_layer",
            "d_state",
            "d_conv",
            "expand",
            "headdim",
            "chunk_size",
        )
        args = Mamba2Config(**{name: getattr(cfg, name) for name in mamba_fields})

        self.backbone = Mamba2Backbone(args, use_residual=cfg.use_residual)
        self.cond_emb = nn.Embedding(cfg.num_classes, cfg.d_model)
        self.delta_head = nn.Linear(cfg.d_model, input_dim)

    def forward(
        self,
        x: Tensor,
        dt: Tensor,
        cond: Tensor,
        h: Optional[list[InferenceCache]] = None,
    ) -> tuple[Tensor, list[InferenceCache]]:
        """Predict displacements for a sequence of states.

        ``dt`` may be 1-D (an axis is appended), 2-D, or 3-D with a trailing
        singleton axis (which is dropped). Its sinusoidal embedding, sized to
        the last axis of ``x``, is added to ``x`` before the backbone runs.

        Returns:
            ``(delta, h)``: the head output and the per-layer caches returned
            by the backbone.
        """
        if dt.dim() == 1:
            dt = dt.unsqueeze(1)
        elif dt.dim() == 3 and dt.shape[-1] == 1:
            dt = dt.squeeze(-1)
        dt = dt.to(dtype=x.dtype)

        stepped = x + sinusoidal_embedding(dt, x.shape[-1])
        feats, h = self.backbone(stepped, self.cond_emb(cond), h)
        return self.delta_head(feats), h

    def step(
        self, x: Tensor, dt: Tensor, cond: Tensor, h: list[InferenceCache]
    ) -> tuple[Tensor, list[InferenceCache]]:
        """Single-step variant of ``forward``.

        Inserts an axis at position 1 into ``x`` and ``dt``, runs ``forward``
        with the given caches, and strips that axis from the result.
        """
        delta, h = self.forward(x.unsqueeze(1), dt.unsqueeze(1), cond, h)
        return delta[:, 0, :], h

    def init_cache(self, batch_size: int, device: torch.device) -> list[InferenceCache]:
        """Allocate one ``InferenceCache`` per backbone layer."""
        backbone_args = self.backbone.args
        caches = []
        for _ in range(backbone_args.n_layer):
            caches.append(
                InferenceCache.alloc(batch_size, backbone_args, device=device)
            )
        return caches
class SwanLogger:
    """Best-effort wrapper around SwanLab experiment logging.

    Every public method is a no-op when ``enabled`` is False. Logging is
    disabled either on request or because SwanLab could not be imported or
    initialised; in the latter case a ``RuntimeWarning`` is emitted so the
    failure is visible instead of being silently swallowed.
    """

    def __init__(self, cfg: TrainConfig, enabled: bool = True) -> None:
        """Start a SwanLab run described by ``cfg`` unless ``enabled`` is False.

        Any exception raised while importing ``swanlab`` or calling
        ``swanlab.init`` disables the logger rather than propagating, so
        training can continue without experiment tracking.
        """
        self.enabled = enabled
        self._swan = None
        self._run = None
        if not self.enabled:
            return
        try:
            import swanlab  # type: ignore

            self._swan = swanlab
            self._run = self._swan.init(
                project=cfg.project,
                experiment_name=cfg.run_name,
                config=asdict(cfg),
            )
        except Exception as exc:
            # Deliberately broad: logging must never abort training. The
            # cause is surfaced as a warning so a silent loss of metrics can
            # be diagnosed.
            import warnings

            self._swan = None
            self._run = None
            self.enabled = False
            warnings.warn(
                f"SwanLab logging disabled: {exc!r}",
                RuntimeWarning,
                stacklevel=2,
            )
        else:
            self.enabled = True

    def log(self, metrics: dict, step: int | None = None) -> None:
        """Send ``metrics`` to the run object (or the module if there is none).

        If the target's ``log`` rejects the ``step`` keyword with a
        ``TypeError``, the step is folded into a copy of the payload under the
        ``"step"`` key and logged that way instead.
        """
        if not self.enabled:
            return
        target = self._run if self._run is not None else self._swan
        if step is None:
            target.log(metrics)
            return
        try:
            target.log(metrics, step=step)
        except TypeError:
            payload = dict(metrics)
            payload["step"] = step
            target.log(payload)

    def log_image(
        self,
        key: str,
        image_path: Path,
        caption: str | None = None,
        step: int | None = None,
    ) -> None:
        """Wrap the file at ``image_path`` in a SwanLab ``Image`` and log it under ``key``."""
        if not self.enabled:
            return
        image = self._swan.Image(str(image_path), caption=caption)
        self.log({key: image}, step=step)

    def finish(self) -> None:
        """Call ``finish`` on the run if it has one, otherwise on the module."""
        if not self.enabled:
            return
        finish = None
        if self._run is not None:
            finish = getattr(self._run, "finish", None)
        if finish is None:
            finish = getattr(self._swan, "finish", None)
        if callable(finish):
            finish()
def set_seed(seed: int) -> None:
    """Seed torch's RNG and, when CUDA is available, every CUDA device RNG."""
    torch.manual_seed(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
def setup_distributed(cfg: TrainConfig) -> tuple[bool, int, int, torch.device]:
    """Resolve the distributed context from the launcher's environment variables.

    ``WORLD_SIZE``, ``RANK`` and ``LOCAL_RANK`` are read from the environment
    (defaulting to 1, 0 and 0). DDP is used only when ``cfg.use_ddp`` is set
    and the world size exceeds 1; in that case the NCCL process group is
    initialised and the local CUDA device is selected. Otherwise the device is
    ``cfg.device`` when CUDA is available and ``"cpu"`` when it is not.

    Returns:
        ``(use_ddp, rank, world_size, device)``.
    """
    env = os.environ
    world_size = int(env.get("WORLD_SIZE", "1"))
    rank = int(env.get("RANK", "0"))
    local_rank = int(env.get("LOCAL_RANK", "0"))

    use_ddp = cfg.use_ddp and world_size > 1
    if not use_ddp:
        fallback = cfg.device if torch.cuda.is_available() else "cpu"
        return use_ddp, rank, world_size, torch.device(fallback)

    torch.distributed.init_process_group(backend="nccl", init_method="env://")
    torch.cuda.set_device(local_rank)
    return use_ddp, rank, world_size, torch.device("cuda", local_rank)
def unwrap_model(model: nn.Module) -> nn.Module:
    """Return ``model.module`` when that attribute exists, else ``model`` itself.

    Used to reach the underlying network behind the DistributedDataParallel
    wrapper that ``train`` may apply.
    """
    return getattr(model, "module", model)
def validate_time_config(cfg: TrainConfig) -> None:
    """Check that the step-size settings in ``cfg`` are mutually consistent.

    Checks run in this order, and the first failure is the one reported:
    ``seq_len`` is positive, ``dt_max`` exceeds the uniform step
    ``1 / seq_len``, and ``dt_min`` is strictly below ``dt_max``.

    Raises:
        ValueError: on the first violated condition.
    """
    seq_len = cfg.seq_len
    if seq_len <= 0:
        raise ValueError("seq_len must be > 0")

    base = 1.0 / seq_len
    dt_min, dt_max = cfg.dt_min, cfg.dt_max

    if dt_max <= base:
        message = (
            "dt_max must be > 1/seq_len to allow non-uniform dt_seq. "
            f"Got dt_max={cfg.dt_max}, seq_len={cfg.seq_len}, 1/seq_len={base}."
        )
        raise ValueError(message)

    if dt_min >= dt_max:
        message = (
            f"dt_min must be < dt_max (got dt_min={cfg.dt_min}, dt_max={cfg.dt_max})."
        )
        raise ValueError(message)
def sample_time_sequence(
    cfg: TrainConfig, batch_size: int, device: torch.device
) -> Tensor:
    """Sample per-step sizes of shape ``(batch_size, cfg.seq_len)``.

    Each row is a set of Gamma(``dt_alpha``, 1) draws normalised to sum to 1.
    If any row's largest step exceeds ``cfg.dt_max``, such rows are blended
    toward the uniform step ``1 / seq_len`` just far enough that their maximum
    equals ``dt_max``; blending keeps the row sum unchanged. When ``dt_max``
    is not above the uniform step, a uniform sequence is returned.

    Raises:
        ValueError: if ``cfg.dt_alpha`` is not positive.
    """
    concentration = float(cfg.dt_alpha)
    if concentration <= 0:
        raise ValueError("dt_alpha must be > 0")

    gamma_dist = torch.distributions.Gamma(concentration, 1.0)
    draws = gamma_dist.sample((batch_size, cfg.seq_len)).to(device)
    dt_seq = draws / draws.sum(dim=-1, keepdim=True)

    uniform_dt = 1.0 / cfg.seq_len
    cap = float(cfg.dt_max)
    if cap <= uniform_dt:
        return torch.full_like(dt_seq, uniform_dt)

    row_max, _ = torch.max(dt_seq, dim=-1, keepdim=True)
    if not (row_max > cap).any():
        return dt_seq

    # Rows already under the cap get a ratio >= 1, which clamps to 1 and
    # leaves them untouched.
    blend = ((cap - uniform_dt) / (row_max - uniform_dt)).clamp(0.0, 1.0)
    return blend * dt_seq + (1.0 - blend) * uniform_dt
def build_dataloader(
    cfg: TrainConfig, distributed: bool = False
) -> tuple[DataLoader, Optional[DistributedSampler]]:
    """Load the configured dataset split and wrap it in a ``DataLoader``.

    Images are read from the ``"img"`` (or ``"image"``) field, converted to
    float32, rescaled with ``value / 127.5 - 1`` and laid out channel-first.
    Labels come from ``"label"`` (or ``"labels"``). The on-the-fly transform
    handles both a single example and a list of examples.

    Returns:
        ``(loader, sampler)``; ``sampler`` is a shuffling
        ``DistributedSampler`` when ``distributed`` is True, else ``None``.
    """
    ds = load_dataset(cfg.dataset_name, split=cfg.dataset_split)

    def transform(example):
        image = example.get("img", example.get("image"))
        label = example.get("label", example.get("labels"))

        batched = isinstance(image, list)
        if batched:
            pixels = np.stack(
                [np.array(im, dtype=np.float32) for im in image], axis=0
            )
        else:
            pixels = np.array(image, dtype=np.float32)
        pixels = pixels / 127.5 - 1.0

        tensor = torch.from_numpy(pixels)
        if batched:
            # (N, H, W) gains a channel axis; (N, H, W, C) becomes (N, C, H, W).
            if pixels.ndim == 3:
                tensor = tensor.unsqueeze(1)
            else:
                tensor = tensor.permute(0, 3, 1, 2)
        else:
            # (H, W) gains a channel axis; (H, W, C) becomes (C, H, W).
            if pixels.ndim == 2:
                tensor = tensor.unsqueeze(0)
            else:
                tensor = tensor.permute(2, 0, 1)

        return {
            "pixel_values": tensor,
            "labels": torch.tensor(label, dtype=torch.long),
        }

    ds = ds.with_transform(transform)

    sampler = None
    if distributed:
        sampler = DistributedSampler(ds, shuffle=True)

    loader = DataLoader(
        ds,
        batch_size=cfg.batch_size,
        # The loader shuffles itself only when no sampler does it.
        shuffle=(sampler is None),
        sampler=sampler,
        num_workers=cfg.num_workers,
        drop_last=True,
        pin_memory=torch.cuda.is_available(),
    )
    return loader, sampler
def infinite_loader(loader: DataLoader) -> Iterator[dict]:
    """Yield batches from ``loader`` forever, restarting it whenever it runs out.

    Note: a loader that produces no batches makes this spin without ever
    yielding.
    """
    while True:
        yield from loader
def compute_losses(
    delta: Tensor,
    v_gt: Tensor,
    dt_seq: Tensor,
) -> dict[str, Tensor]:
    """Return the training losses as a dict with the single key ``"flow"``.

    The target displacement for each step is ``v_gt`` (given an axis at
    position 1) scaled by that step's ``dt`` (given a trailing axis); the
    flow loss is the mean squared error between ``delta`` and that target.
    """
    target = v_gt.unsqueeze(1) * dt_seq[..., None]
    return {"flow": F.mse_loss(delta, target)}
def make_grid(images: Tensor, nrow: int) -> np.ndarray:
    """Tile a ``(B, C, H, W)`` batch into one channel-last float32 image.

    ``nrow`` is the number of images placed side by side in each grid row
    (clamped to ``[1, B]``); images fill the grid left to right, top to
    bottom, and unused cells stay zero. The result has shape
    ``(rows * H, nrow * W, C)``.
    """
    batch = images.detach().cpu().numpy()
    count, channels, height, width = batch.shape

    per_row = max(1, min(nrow, count))
    n_rows = math.ceil(count / per_row)

    canvas = np.zeros(
        (channels, n_rows * height, per_row * width), dtype=np.float32
    )
    for idx, tile in enumerate(batch):
        row, col = divmod(idx, per_row)
        top, left = row * height, col * width
        canvas[:, top : top + height, left : left + width] = tile
    return np.transpose(canvas, (1, 2, 0))
def save_image_grid(
    images: Tensor, save_path: Path, nrow: int, title: str | None = None
) -> None:
    """Save a batch of images in ``[-1, 1]`` as a single grid image.

    Values are clamped to ``[-1, 1]`` and rescaled to ``[0, 1]``; a
    single-channel grid is repeated to three channels. Without a title the
    raw grid is written with ``plt.imsave``. With a title the grid is drawn
    in a 4x3-inch figure and saved at 160 dpi.

    Args:
        images: batch accepted by ``make_grid`` (``(B, C, H, W)``).
        save_path: destination file.
        nrow: number of images per grid row, forwarded to ``make_grid``.
        title: optional figure title; selects the titled-figure output.
    """
    images = (images.clamp(-1.0, 1.0) + 1.0) / 2.0
    grid = make_grid(images, nrow=nrow)
    if grid.ndim == 3 and grid.shape[2] == 1:
        grid = np.repeat(grid, 3, axis=2)

    if title is None:
        plt.imsave(save_path, grid)
        return

    # The titled figure is the file's final content, so the raw grid is not
    # written first only to be overwritten.
    fig, ax = plt.subplots(figsize=(4, 3))
    try:
        ax.imshow(grid)
        ax.set_title(title)
        ax.axis("off")
        fig.tight_layout()
        fig.savefig(save_path, dpi=160)
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)
def rollout_trajectory(
    model: ASMamba,
    x0: Tensor,
    cond: Tensor,
    dt_seq: Tensor,
) -> Tensor:
    """Integrate the model step by step from ``x0`` and return every state.

    ``dt_seq`` may be 1-D (shared by the whole batch) or 2-D; a leading axis
    of size 1 is expanded to the batch size. The model is switched to eval
    mode and is not switched back. Stepping runs under ``torch.no_grad``.

    Returns:
        CPU tensor stacking the initial state and the state after each step
        along axis 1.
    """
    batch = x0.shape[0]
    model.eval()
    cache = model.init_cache(batch_size=batch, device=x0.device)

    if dt_seq.dim() == 1:
        dt_seq = dt_seq.unsqueeze(0).expand(batch, -1)
    elif dt_seq.shape[0] == 1 and batch > 1:
        dt_seq = dt_seq.expand(batch, -1)

    state = x0
    frames = [x0.detach().cpu()]
    with torch.no_grad():
        for dt in dt_seq.unbind(dim=1):
            delta, cache = model.step(state, dt, cond, cache)
            state = state + delta
            frames.append(state.detach().cpu())
    return torch.stack(frames, dim=1)
def log_class_samples(
    model: ASMamba,
    cfg: TrainConfig,
    device: torch.device,
    logger: SwanLogger,
    step: int,
) -> None:
    """Generate, save and log a grid of samples for every class.

    For each class, ``cfg.val_samples_per_class`` Gaussian noise vectors are
    rolled out over ``cfg.seq_len`` uniform steps of size ``1 / seq_len``.
    The final states are reshaped to images, written to
    ``<output_dir>/val_class_<cls>_step_<step>.png`` and sent to ``logger``.
    Does nothing when ``val_samples_per_class`` is not positive. The model is
    left in train mode afterwards.
    """
    per_class = cfg.val_samples_per_class
    if per_class <= 0:
        return

    model.eval()
    n_steps = cfg.seq_len
    flat_dim = cfg.channels * cfg.image_size * cfg.image_size
    image_shape = (per_class, cfg.channels, cfg.image_size, cfg.image_size)
    uniform_dt = torch.full((per_class, n_steps), 1.0 / n_steps, device=device)

    for cls in range(cfg.num_classes):
        cond = torch.full((per_class,), cls, device=device, dtype=torch.long)
        noise = torch.randn(per_class, flat_dim, device=device)
        traj = rollout_trajectory(model, noise, cond, dt_seq=uniform_dt)

        # Only the last state of each trajectory is visualised.
        samples = traj[:, -1, :].view(*image_shape)
        save_path = Path(cfg.output_dir) / f"val_class_{cls}_step_{step:06d}.png"
        save_image_grid(samples, save_path, nrow=cfg.val_grid_rows)
        logger.log_image(
            f"val/class_{cls}",
            save_path,
            caption=f"class {cls} step {step}",
            step=step,
        )

    model.train()
def train(cfg: TrainConfig) -> ASMamba:
    """Train an ``ASMamba`` model on the configured dataset and return it.

    Each optimisation step draws a batch, pairs every image ``x1`` with
    Gaussian noise ``x0``, samples a step-size sequence, places states on the
    straight line from ``x0`` to ``x1`` at the cumulative times *before* each
    step, and regresses the model output onto ``(x1 - x0) * dt``.
    Scalar losses are logged every 10 steps; on rank 0, class samples are
    generated every ``cfg.val_every`` steps (never at step 0).

    Returns:
        The trained model with any DistributedDataParallel wrapper removed.
    """
    validate_time_config(cfg)
    use_ddp, rank, _world_size, device = setup_distributed(cfg)
    # Offset the seed per rank so processes draw different noise.
    set_seed(cfg.seed + rank)
    is_main = rank == 0

    output_dir = Path(cfg.output_dir)
    if is_main:
        output_dir.mkdir(parents=True, exist_ok=True)

    model = ASMamba(cfg).to(device)
    if use_ddp:
        model = nn.parallel.DistributedDataParallel(model, device_ids=[device.index])
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay
    )

    logger = SwanLogger(cfg, enabled=is_main)
    loader, sampler = build_dataloader(cfg, distributed=use_ddp)
    batches = infinite_loader(loader)

    global_step = 0
    for _epoch in range(cfg.epochs):
        if sampler is not None:
            sampler.set_epoch(global_step)
        model.train()

        for _inner in range(cfg.steps_per_epoch):
            batch = next(batches)
            x1 = batch["pixel_values"].to(device)
            cond = batch["labels"].to(device)
            b = x1.shape[0]
            x1 = x1.view(b, -1)

            x0 = torch.randn_like(x1)
            v_gt = x1 - x0

            dt_seq = sample_time_sequence(cfg, b, device)
            # Start time of each step: cumulative sum shifted right by one,
            # with 0 for the first step.
            elapsed = torch.cumsum(dt_seq, dim=-1)
            t_seq = torch.cat(
                [torch.zeros(b, 1, device=device), elapsed[:, :-1]], dim=-1
            )
            x_seq = x0.unsqueeze(1) + t_seq.unsqueeze(-1) * v_gt.unsqueeze(1)

            delta, _ = model(x_seq, dt_seq, cond)
            losses = compute_losses(delta=delta, v_gt=v_gt, dt_seq=dt_seq)
            loss = cfg.lambda_flow * losses["flow"]

            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            optimizer.step()

            if global_step % 10 == 0:
                metrics = {
                    "loss/total": float(loss.item()),
                    "loss/flow": float(losses["flow"].item()),
                }
                logger.log(metrics, step=global_step)

            should_validate = (
                cfg.val_every > 0
                and global_step > 0
                and global_step % cfg.val_every == 0
                and rank == 0
            )
            if should_validate:
                log_class_samples(
                    unwrap_model(model), cfg, device, logger, global_step
                )

            global_step += 1

    logger.finish()
    if use_ddp:
        torch.distributed.destroy_process_group()
    return unwrap_model(model)
def run_training_and_plot(cfg: TrainConfig) -> Path:
    """Run ``train(cfg)`` and return the configured output directory as a ``Path``.

    The trained model returned by ``train`` is discarded.
    """
    train(cfg)
    output_dir = Path(cfg.output_dir)
    return output_dir