# mamba_diffusion/as_mamba.py
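"""Adaptive-step flow matching on MNIST with a Mamba-2 backbone.

Module summary (inferred from the code below): a class-conditional model
("AS-Mamba") that, at every step along a straight noise-to-data path,
predicts both a displacement ``delta`` and an adaptive step size ``dt``.
Training supervises the displacement (flow loss), optionally the next
position (pos loss), and the step size (dt loss); sampling rolls the
recurrence forward until the accumulated time reaches 1.
"""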
from __future__ import annotations
import math
from dataclasses import asdict, dataclass
import os
from pathlib import Path
from typing import Iterator, Optional
import matplotlib

# Select a non-interactive backend before pyplot is imported (headless-safe).
matplotlib.use("Agg")
import numpy as np
import torch
import torch.nn.functional as F
from datasets import load_dataset
from matplotlib import pyplot as plt
from torch import Tensor, nn
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from mamba2_minimal import InferenceCache, Mamba2, Mamba2Config, RMSNorm


@dataclass
class TrainConfig:
    # Optimization / schedule
    seed: int = 42
    device: str = "cuda"
    epochs: int = 50
    steps_per_epoch: int = 200
    batch_size: int = 128
    seq_len: int = 20
    lr: float = 2e-4
    weight_decay: float = 1e-2
    # Adaptive step-size (dt) sampling
    dt_min: float = 1e-3
    dt_max: float = 0.06
    dt_alpha: float = 8.0
    # Loss weights and toggles
    lambda_flow: float = 1.0
    lambda_pos: float = 1.0
    lambda_dt: float = 1.0
    use_flow_loss: bool = True
    use_pos_loss: bool = False
    use_dt_loss: bool = True
    # Data
    num_classes: int = 10
    image_size: int = 28
    channels: int = 1
    num_workers: int = 8
    dataset_name: str = "ylecun/mnist"
    dataset_split: str = "train"
    # Backbone (Mamba2Config passthrough)
    d_model: int = 0  # 0 means: infer from the flattened image dimension
    n_layer: int = 6
    d_state: int = 64
    d_conv: int = 4
    expand: int = 2
    headdim: int = 32
    chunk_size: int = 1
    use_residual: bool = False
    # Logging / validation
    output_dir: str = "outputs"
    project: str = "as-mamba-mnist"
    run_name: str = "mnist-flow"
    val_every: int = 200
    val_samples_per_class: int = 8
    val_grid_rows: int = 4
    val_max_steps: int = 0  # 0 means: use seq_len steps at validation
    use_ddp: bool = False
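

# AdaLN-Zero conditioning (in the style of DiT): the modulation projection is
# zero-initialized, so at the start of training scale == shift == 0 and each
# block reduces to a plain RMSNorm, since x * (1 + 0) + 0 == x after the norm.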
class AdaLNZero(nn.Module):
    def __init__(self, d_model: int) -> None:
        super().__init__()
        self.norm = RMSNorm(d_model)
        self.mod = nn.Linear(d_model, 2 * d_model)
        nn.init.zeros_(self.mod.weight)
        nn.init.zeros_(self.mod.bias)

    def forward(self, x: Tensor, cond: Tensor) -> Tensor:
        x = self.norm(x)
        params = self.mod(cond).unsqueeze(1)
        scale, shift = params.chunk(2, dim=-1)
        return x * (1 + scale) + shift


class Mamba2Backbone(nn.Module):
    def __init__(self, args: Mamba2Config, use_residual: bool = True) -> None:
        super().__init__()
        self.args = args
        self.use_residual = use_residual
        self.layers = nn.ModuleList(
            [
                nn.ModuleDict(
                    dict(
                        mixer=Mamba2(args),
                        adaln=AdaLNZero(args.d_model),
                    )
                )
                for _ in range(args.n_layer)
            ]
        )
        self.norm_f = RMSNorm(args.d_model)

    def forward(
        self, x: Tensor, cond: Tensor, h: Optional[list[InferenceCache]] = None
    ) -> tuple[Tensor, list[InferenceCache]]:
        if h is None:
            h = [None for _ in range(self.args.n_layer)]
        for i, layer in enumerate(self.layers):
            x_mod = layer["adaln"](x, cond)
            y, h[i] = layer["mixer"](x_mod, h[i])
            x = x + y if self.use_residual else y
        x = self.norm_f(x)
        return x, h
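

# ASMamba treats each flattened image as one token per step: the backbone
# consumes a (batch, steps, pixels) sequence plus a class embedding, and two
# heads read out a per-step displacement `delta` and a positive step size
# `dt` (softplus, clamped to [dt_min, dt_max]).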
class ASMamba(nn.Module):
    def __init__(self, cfg: TrainConfig) -> None:
        super().__init__()
        self.cfg = cfg
        self.dt_min = float(cfg.dt_min)
        self.dt_max = float(cfg.dt_max)
        input_dim = cfg.channels * cfg.image_size * cfg.image_size
        if cfg.d_model == 0:
            cfg.d_model = input_dim
        if cfg.d_model != input_dim:
            raise ValueError(
                f"d_model must equal flattened image dim ({input_dim}) when input_proj is disabled."
            )
        args = Mamba2Config(
            d_model=cfg.d_model,
            n_layer=cfg.n_layer,
            d_state=cfg.d_state,
            d_conv=cfg.d_conv,
            expand=cfg.expand,
            headdim=cfg.headdim,
            chunk_size=cfg.chunk_size,
        )
        self.backbone = Mamba2Backbone(args, use_residual=cfg.use_residual)
        self.cond_emb = nn.Embedding(cfg.num_classes, cfg.d_model)
        self.delta_head = nn.Linear(cfg.d_model, input_dim)
        self.dt_head = nn.Sequential(
            nn.Linear(cfg.d_model, cfg.d_model),
            nn.SiLU(),
            nn.Linear(cfg.d_model, 1),
        )

    def forward(
        self, x: Tensor, cond: Tensor, h: Optional[list[InferenceCache]] = None
    ) -> tuple[Tensor, Tensor, list[InferenceCache]]:
        cond_vec = self.cond_emb(cond)
        feats, h = self.backbone(x, cond_vec, h)
        delta = self.delta_head(feats)
        dt_raw = self.dt_head(feats).squeeze(-1)
        dt = torch.clamp(F.softplus(dt_raw), min=self.dt_min, max=self.dt_max)
        return delta, dt, h

    def step(
        self, x: Tensor, cond: Tensor, h: list[InferenceCache]
    ) -> tuple[Tensor, Tensor, list[InferenceCache]]:
        delta, dt, h = self.forward(x.unsqueeze(1), cond, h)
        return delta[:, 0, :], dt[:, 0], h

    def init_cache(self, batch_size: int, device: torch.device) -> list[InferenceCache]:
        return [
            InferenceCache.alloc(batch_size, self.backbone.args, device=device)
            for _ in range(self.backbone.args.n_layer)
        ]
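

# Shape sketch (a hypothetical smoke test, assuming the MNIST defaults above,
# where input_dim = 1 * 28 * 28 = 784):
#
#     model = ASMamba(TrainConfig(device="cpu"))
#     x = torch.randn(2, 20, 784)              # (batch, seq_len, pixels)
#     delta, dt, _ = model(x, torch.tensor([3, 7]))
#     # delta: (2, 20, 784); dt: (2, 20), clamped to [dt_min, dt_max]
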
class SwanLogger:
    def __init__(self, cfg: TrainConfig, enabled: bool = True) -> None:
        self.enabled = enabled
        self._swan = None
        self._run = None
        if not self.enabled:
            return
        try:
            import swanlab  # type: ignore

            self._swan = swanlab
            self._run = self._swan.init(
                project=cfg.project,
                experiment_name=cfg.run_name,
                config=asdict(cfg),
            )
            self.enabled = True
        except Exception:
            self.enabled = False

    def log(self, metrics: dict, step: int | None = None) -> None:
        if not self.enabled:
            return
        target = self._run if self._run is not None else self._swan
        if step is None:
            target.log(metrics)
            return
        try:
            target.log(metrics, step=step)
        except TypeError:
            payload = dict(metrics)
            payload["step"] = step
            target.log(payload)

    def log_image(
        self,
        key: str,
        image_path: Path,
        caption: str | None = None,
        step: int | None = None,
    ) -> None:
        if not self.enabled:
            return
        image = self._swan.Image(str(image_path), caption=caption)
        self.log({key: image}, step=step)

    def finish(self) -> None:
        if not self.enabled:
            return
        finish = None
        if self._run is not None:
            finish = getattr(self._run, "finish", None)
        if finish is None:
            finish = getattr(self._swan, "finish", None)
        if callable(finish):
            finish()


def set_seed(seed: int) -> None:
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def setup_distributed(cfg: TrainConfig) -> tuple[bool, int, int, torch.device]:
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    rank = int(os.environ.get("RANK", "0"))
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    use_ddp = cfg.use_ddp and world_size > 1
    if use_ddp:
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
        torch.cuda.set_device(local_rank)
        device = torch.device("cuda", local_rank)
    else:
        device = torch.device(cfg.device if torch.cuda.is_available() else "cpu")
    return use_ddp, rank, world_size, device


def unwrap_model(model: nn.Module) -> nn.Module:
    return model.module if hasattr(model, "module") else model


def validate_time_config(cfg: TrainConfig) -> None:
    if cfg.seq_len <= 0:
        raise ValueError("seq_len must be > 0")
    base = 1.0 / cfg.seq_len
    if cfg.dt_max <= base:
        raise ValueError(
            "dt_max must be > 1/seq_len to allow non-uniform dt_seq. "
            f"Got dt_max={cfg.dt_max}, seq_len={cfg.seq_len}, 1/seq_len={base}."
        )
    if cfg.dt_min >= cfg.dt_max:
        raise ValueError(
            f"dt_min must be < dt_max (got dt_min={cfg.dt_min}, dt_max={cfg.dt_max})."
        )
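

# Worked example for the checks above, with the defaults: seq_len=20 gives
# base = 1/20 = 0.05 and dt_max = 0.06 > 0.05, so the first check passes;
# dt_min = 1e-3 < dt_max as required.
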
def sample_time_sequence(
    cfg: TrainConfig, batch_size: int, device: torch.device
) -> Tensor:
    alpha = float(cfg.dt_alpha)
    if alpha <= 0:
        raise ValueError("dt_alpha must be > 0")
    dist = torch.distributions.Gamma(alpha, 1.0)
    raw = dist.sample((batch_size, cfg.seq_len)).to(device)
    dt_seq = raw / raw.sum(dim=-1, keepdim=True)
    base = 1.0 / cfg.seq_len
    max_dt = float(cfg.dt_max)
    if max_dt <= base:
        return torch.full_like(dt_seq, base)
    max_current = dt_seq.max(dim=-1, keepdim=True).values
    if (max_current > max_dt).any():
        gamma = (max_dt - base) / (max_current - base)
        gamma = gamma.clamp(0.0, 1.0)
        dt_seq = gamma * dt_seq + (1.0 - gamma) * base
    return dt_seq
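

# Note on the sampler above: normalizing i.i.d. Gamma(alpha, 1) draws yields a
# symmetric Dirichlet(alpha) vector, so each dt_seq row sums to exactly 1 with
# mean step 1/seq_len; larger dt_alpha concentrates steps toward uniform. The
# blend toward `base` then caps the largest step at dt_max without breaking
# the sum-to-one constraint: gamma * dt_seq + (1 - gamma) * base keeps the row
# sum at 1 because the constant base vector also sums to 1.
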
def build_dataloader(
    cfg: TrainConfig, distributed: bool = False
) -> tuple[DataLoader, Optional[DistributedSampler]]:
    ds = load_dataset(cfg.dataset_name, split=cfg.dataset_split)

    def transform(example):
        image = example.get("img", example.get("image"))
        label = example.get("label", example.get("labels"))
        if isinstance(image, list):
            # Batched access: stack, then rescale [0, 255] -> [-1, 1].
            arr = np.stack([np.array(im, dtype=np.float32) for im in image], axis=0)
            arr = arr / 127.5 - 1.0
            if arr.ndim == 3:
                tensor = torch.from_numpy(arr).unsqueeze(1)
            else:
                tensor = torch.from_numpy(arr).permute(0, 3, 1, 2)
            labels = torch.tensor(label, dtype=torch.long)
            return {"pixel_values": tensor, "labels": labels}
        # Single-example access: same rescaling, channel-first layout.
        arr = np.array(image, dtype=np.float32) / 127.5 - 1.0
        if arr.ndim == 2:
            tensor = torch.from_numpy(arr).unsqueeze(0)
        else:
            tensor = torch.from_numpy(arr).permute(2, 0, 1)
        return {"pixel_values": tensor, "labels": torch.tensor(label, dtype=torch.long)}

    ds = ds.with_transform(transform)
    sampler = DistributedSampler(ds, shuffle=True) if distributed else None
    loader = DataLoader(
        ds,
        batch_size=cfg.batch_size,
        shuffle=(sampler is None),
        sampler=sampler,
        num_workers=cfg.num_workers,
        drop_last=True,
        pin_memory=torch.cuda.is_available(),
    )
    return loader, sampler


def infinite_loader(loader: DataLoader) -> Iterator[dict]:
    while True:
        for batch in loader:
            yield batch


def compute_losses(
    delta: Tensor,
    dt: Tensor,
    x_seq: Tensor,
    x0: Tensor,
    v_gt: Tensor,
    t_seq: Tensor,
    dt_seq: Tensor,
    cfg: TrainConfig,
) -> dict[str, Tensor]:
    losses: dict[str, Tensor] = {}
    if cfg.use_flow_loss:
        target_disp = v_gt[:, None, :] * dt.unsqueeze(-1)
        losses["flow"] = F.mse_loss(delta, target_disp)
    if cfg.use_pos_loss:
        t_next = t_seq + dt
        t_next = torch.clamp(t_next, 0.0, 1.0)
        x_target = x0[:, None, :] + t_next.unsqueeze(-1) * v_gt[:, None, :]
        x_next_pred = x_seq + delta
        losses["pos"] = F.mse_loss(x_next_pred, x_target)
    if cfg.use_dt_loss:
        losses["dt"] = F.mse_loss(dt, dt_seq)
    return losses
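

# Loss targets, written out (straight-path flow matching):
#     x_t = x0 + t * v,  with v = x1 - x0
# so the ground-truth displacement over a step of size dt is v * dt (flow
# loss), the next position is x0 + (t + dt) * v (pos loss), and the dt head
# regresses the sampled step sizes directly (dt loss).
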
def plot_dt_hist(
    dt_pred: Tensor, dt_gt: Tensor, save_path: Path, title: str = "dt Distribution"
) -> None:
    dt_pred_np = dt_pred.detach().cpu().numpy().reshape(-1)
    dt_gt_np = dt_gt.detach().cpu().numpy().reshape(-1)
    fig, ax = plt.subplots(figsize=(6, 4))
    ax.hist(dt_gt_np, bins=30, alpha=0.6, label="dt_gt", color="steelblue")
    ax.hist(dt_pred_np, bins=30, alpha=0.6, label="dt_pred", color="orange")
    ax.set_title(title)
    ax.set_xlabel("dt")
    ax.set_ylabel("count")
    ax.legend(loc="best")
    fig.tight_layout()
    fig.savefig(save_path, dpi=160)
    plt.close(fig)


def make_grid(images: Tensor, nrow: int) -> np.ndarray:
    # `nrow` follows the torchvision convention: number of images per row.
    images = images.detach().cpu().numpy()
    b, c, h, w = images.shape
    nrow = max(1, min(nrow, b))
    ncol = math.ceil(b / nrow)
    grid = np.zeros((c, ncol * h, nrow * w), dtype=np.float32)
    for idx in range(b):
        r = idx // nrow
        cidx = idx % nrow
        grid[:, r * h : (r + 1) * h, cidx * w : (cidx + 1) * w] = images[idx]
    return np.transpose(grid, (1, 2, 0))


def save_image_grid(
    images: Tensor, save_path: Path, nrow: int, title: str | None = None
) -> None:
    images = images.clamp(-1.0, 1.0)
    images = (images + 1.0) / 2.0  # map [-1, 1] back to [0, 1] for display
    grid = make_grid(images, nrow=nrow)
    if grid.ndim == 3 and grid.shape[2] == 1:
        grid = np.repeat(grid, 3, axis=2)
    if title is None:
        plt.imsave(save_path, grid)
        return
    # With a title, render through a matplotlib figure so the title is drawn.
    fig, ax = plt.subplots(figsize=(4, 3))
    ax.imshow(grid)
    ax.set_title(title)
    ax.axis("off")
    fig.tight_layout()
    fig.savefig(save_path, dpi=160)
    plt.close(fig)


def rollout_trajectory(
    model: ASMamba,
    x0: Tensor,
    cond: Tensor,
    max_steps: int,
) -> Tensor:
    device = x0.device
    model.eval()
    h = model.init_cache(batch_size=x0.shape[0], device=device)
    x = x0
    total_time = torch.zeros(x0.shape[0], device=device)
    traj = [x0.detach().cpu()]
    with torch.no_grad():
        for _ in range(max_steps):
            delta, dt, h = model.step(x, cond, h)
            dt = torch.clamp(dt, min=model.dt_min, max=model.dt_max)
            remaining = 1.0 - total_time
            overshoot = dt > remaining
            if overshoot.any():
                scale = (remaining / dt).unsqueeze(-1)
                delta = torch.where(overshoot.unsqueeze(-1), delta * scale, delta)
                dt = torch.where(overshoot, remaining, dt)
            x = x + delta
            total_time = total_time + dt
            traj.append(x.detach().cpu())
            if torch.all(total_time >= 1.0 - 1e-6):
                break
    return torch.stack(traj, dim=1)
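

# Sampling note: the rollout above integrates x <- x + delta with the model's
# own dt until the accumulated time reaches 1. When a predicted dt would
# overshoot the remaining time, both delta and dt are rescaled so the final
# state lands exactly at t = 1. The returned trajectory has shape
# (batch, num_steps + 1, pixels), including the initial noise.
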
def log_class_samples(
    model: ASMamba,
    cfg: TrainConfig,
    device: torch.device,
    logger: SwanLogger,
    step: int,
) -> None:
    if cfg.val_samples_per_class <= 0:
        return
    model.eval()
    max_steps = cfg.seq_len if cfg.val_max_steps <= 0 else cfg.val_max_steps
    input_dim = cfg.channels * cfg.image_size * cfg.image_size
    for cls in range(cfg.num_classes):
        cond = torch.full(
            (cfg.val_samples_per_class,), cls, device=device, dtype=torch.long
        )
        x0 = torch.randn(cfg.val_samples_per_class, input_dim, device=device)
        traj = rollout_trajectory(model, x0, cond, max_steps=max_steps)
        x_final = traj[:, -1, :].view(
            cfg.val_samples_per_class, cfg.channels, cfg.image_size, cfg.image_size
        )
        save_path = Path(cfg.output_dir) / f"val_class_{cls}_step_{step:06d}.png"
        save_image_grid(x_final, save_path, nrow=cfg.val_grid_rows)
        logger.log_image(
            f"val/class_{cls}",
            save_path,
            caption=f"class {cls} step {step}",
            step=step,
        )
    model.train()
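

# Training procedure (one step): draw data x1 and labels, draw noise x0,
# sample a per-example step schedule dt_seq, form start times t_seq by an
# exclusive cumulative sum, build the straight-path states
# x_seq[:, k] = x0 + t_seq[:, k] * (x1 - x0), and supervise the model's
# per-step (delta, dt) against (v * dt, dt_seq) via compute_losses.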
def train(cfg: TrainConfig) -> ASMamba:
    validate_time_config(cfg)
    use_ddp, rank, world_size, device = setup_distributed(cfg)
    set_seed(cfg.seed + rank)
    output_dir = Path(cfg.output_dir)
    if rank == 0:
        output_dir.mkdir(parents=True, exist_ok=True)
    model = ASMamba(cfg).to(device)
    if use_ddp:
        model = nn.parallel.DistributedDataParallel(model, device_ids=[device.index])
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay
    )
    logger = SwanLogger(cfg, enabled=(rank == 0))
    loader, sampler = build_dataloader(cfg, distributed=use_ddp)
    loader_iter = infinite_loader(loader)
    global_step = 0
    for epoch in range(cfg.epochs):
        if sampler is not None:
            # Reshuffle per epoch; DistributedSampler expects the epoch index.
            sampler.set_epoch(epoch)
        model.train()
        for _ in range(cfg.steps_per_epoch):
            batch = next(loader_iter)
            x1 = batch["pixel_values"].to(device)
            cond = batch["labels"].to(device)
            b = x1.shape[0]
            x1 = x1.view(b, -1)
            x0 = torch.randn_like(x1)
            v_gt = x1 - x0
            dt_seq = sample_time_sequence(cfg, b, device)
            t_seq = torch.cumsum(dt_seq, dim=-1)
            t_seq = torch.cat([torch.zeros(b, 1, device=device), t_seq[:, :-1]], dim=-1)
            x_seq = x0[:, None, :] + t_seq[:, :, None] * v_gt[:, None, :]
            delta, dt, _ = model(x_seq, cond)
            losses = compute_losses(
                delta=delta,
                dt=dt,
                x_seq=x_seq,
                x0=x0,
                v_gt=v_gt,
                t_seq=t_seq,
                dt_seq=dt_seq,
                cfg=cfg,
            )
            if not losses:
                raise RuntimeError(
                    "No loss enabled: enable at least one of flow/pos/dt."
                )
            loss = torch.tensor(0.0, device=device)
            if cfg.use_flow_loss and "flow" in losses:
                loss = loss + cfg.lambda_flow * losses["flow"]
            if cfg.use_pos_loss and "pos" in losses:
                loss = loss + cfg.lambda_pos * losses["pos"]
            if cfg.use_dt_loss and "dt" in losses:
                loss = loss + cfg.lambda_dt * losses["dt"]
            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            optimizer.step()
            if global_step % 10 == 0:
                dt_min = float(dt.min().item())
                dt_max = float(dt.max().item())
                dt_mean = float(dt.mean().item())
                dt_gt_min = float(dt_seq.min().item())
                dt_gt_max = float(dt_seq.max().item())
                dt_gt_mean = float(dt_seq.mean().item())
                eps = 1e-6
                clamp_min_ratio = float((dt <= cfg.dt_min + eps).float().mean().item())
                clamp_max_ratio = float((dt >= cfg.dt_max - eps).float().mean().item())
                clamp_any_ratio = float(
                    ((dt <= cfg.dt_min + eps) | (dt >= cfg.dt_max - eps))
                    .float()
                    .mean()
                    .item()
                )
                logger.log(
                    {
                        "loss/total": float(loss.item()),
                        "loss/flow": float(
                            losses.get("flow", torch.tensor(0.0)).item()
                        ),
                        "loss/pos": float(losses.get("pos", torch.tensor(0.0)).item()),
                        "loss/dt": float(losses.get("dt", torch.tensor(0.0)).item()),
                        "dt/pred_mean": dt_mean,
                        "dt/pred_min": dt_min,
                        "dt/pred_max": dt_max,
                        "dt/gt_mean": dt_gt_mean,
                        "dt/gt_min": dt_gt_min,
                        "dt/gt_max": dt_gt_max,
                        "dt/clamp_min_ratio": clamp_min_ratio,
                        "dt/clamp_max_ratio": clamp_max_ratio,
                        "dt/clamp_any_ratio": clamp_any_ratio,
                    },
                    step=global_step,
                )
            if (
                cfg.val_every > 0
                and global_step > 0
                and global_step % cfg.val_every == 0
                and rank == 0
            ):
                log_class_samples(unwrap_model(model), cfg, device, logger, global_step)
                dt_hist_path = (
                    Path(cfg.output_dir) / f"dt_hist_step_{global_step:06d}.png"
                )
                plot_dt_hist(
                    dt,
                    dt_seq,
                    dt_hist_path,
                    title=f"dt Distribution (step {global_step})",
                )
                logger.log_image(
                    "train/dt_hist",
                    dt_hist_path,
                    caption=f"step {global_step}",
                    step=global_step,
                )
            global_step += 1
    logger.finish()
    if use_ddp:
        torch.distributed.destroy_process_group()
    return unwrap_model(model)


def run_training_and_plot(cfg: TrainConfig) -> Path:
    train(cfg)
    return Path(cfg.output_dir)
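

if __name__ == "__main__":
    # Minimal entry point sketch (an assumption; the original project may ship
    # its own CLI): train with the defaults and report where artifacts landed.
    out = run_training_and_plot(TrainConfig())
    print(f"Artifacts written to {out}")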