Implement Mamba MeanFlow x-prediction training
as_mamba.py
@@ -1,11 +1,14 @@
 from __future__ import annotations
 
 import math
-from dataclasses import asdict, dataclass
+import os
+from dataclasses import asdict, dataclass
 from pathlib import Path
 from typing import Iterator, Optional
 
+import lpips
+
+os.environ.setdefault("MPLCONFIGDIR", "/tmp/mamba_diffusion_mplconfig")
 import matplotlib
 
 matplotlib.use("Agg")
@@ -16,12 +19,17 @@ import torch.nn.functional as F
 from datasets import load_dataset
 from matplotlib import pyplot as plt
 from torch import Tensor, nn
+from torch.func import jvp
 from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler
 
 from mamba2_minimal import InferenceCache, Mamba2, Mamba2Config, RMSNorm
 
 
+FIXED_VAL_SAMPLING_STEPS = 5
+FIXED_VAL_TIME_GRID = (1.0, 0.8, 0.6, 0.4, 0.2, 0.0)
+
+
 @dataclass
 class TrainConfig:
     seed: int = 42
@@ -29,13 +37,11 @@ class TrainConfig:
     epochs: int = 50
     steps_per_epoch: int = 200
     batch_size: int = 128
-    seq_len: int = 20
+    seq_len: int = 5
     lr: float = 2e-4
     weight_decay: float = 1e-2
-    dt_min: float = 1e-3
-    dt_max: float = 0.06
-    dt_alpha: float = 8.0
     lambda_flow: float = 1.0
+    lambda_perceptual: float = 0.4
     num_classes: int = 10
     image_size: int = 28
     channels: int = 1
@@ -52,10 +58,12 @@ class TrainConfig:
     use_residual: bool = False
     output_dir: str = "outputs"
     project: str = "as-mamba-mnist"
-    run_name: str = "mnist-flow"
+    run_name: str = "mnist-meanflow"
     val_every: int = 200
     val_samples_per_class: int = 8
     val_grid_rows: int = 4
+    val_sampling_steps: int = FIXED_VAL_SAMPLING_STEPS
+    time_grid_size: int = 256
     use_ddp: bool = False
 
 
@@ -127,6 +135,11 @@ def sinusoidal_embedding(t: Tensor, dim: int) -> Tensor:
     return emb.to(t.dtype)
 
 
+def safe_time_divisor(t: Tensor) -> Tensor:
+    eps = torch.finfo(t.dtype).eps
+    return torch.where(t > 0, t, torch.full_like(t, eps))
+
+
 class ASMamba(nn.Module):
     def __init__(self, cfg: TrainConfig) -> None:
         super().__init__()
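safe_time_divisor exists because x-prediction recovers velocity by dividing by t, and t = 0 is the clean endpoint of the interpolation. A minimal standalone check of the guard (plain PyTorch, values illustrative):

import torch

t = torch.tensor([0.0, 0.5, 1.0])
eps = torch.finfo(t.dtype).eps
safe = torch.where(t > 0, t, torch.full_like(t, eps))
print(torch.ones_like(t) / safe)  # finite everywhere: ~[8.4e6, 2.0, 1.0]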
@@ -150,32 +163,42 @@ class ASMamba(nn.Module):
         )
         self.backbone = Mamba2Backbone(args, use_residual=cfg.use_residual)
         self.cond_emb = nn.Embedding(cfg.num_classes, cfg.d_model)
-        self.delta_head = nn.Linear(cfg.d_model, input_dim)
+        self.clean_head = nn.Linear(cfg.d_model, input_dim)
 
     def forward(
         self,
-        x: Tensor,
-        dt: Tensor,
+        z_t: Tensor,
+        r: Tensor,
+        t: Tensor,
         cond: Tensor,
         h: Optional[list[InferenceCache]] = None,
     ) -> tuple[Tensor, list[InferenceCache]]:
-        if dt.dim() == 1:
-            dt = dt.unsqueeze(1)
-        elif dt.dim() == 3 and dt.shape[-1] == 1:
-            dt = dt.squeeze(-1)
-        dt = dt.to(dtype=x.dtype)
-        dt_emb = sinusoidal_embedding(dt, x.shape[-1])
-        x = x + dt_emb
+        if r.dim() == 1:
+            r = r.unsqueeze(1)
+        elif r.dim() == 3 and r.shape[-1] == 1:
+            r = r.squeeze(-1)
+        if t.dim() == 1:
+            t = t.unsqueeze(1)
+        elif t.dim() == 3 and t.shape[-1] == 1:
+            t = t.squeeze(-1)
+
+        r = r.to(dtype=z_t.dtype)
+        t = t.to(dtype=z_t.dtype)
+        z_t = z_t + sinusoidal_embedding(r, z_t.shape[-1]) + sinusoidal_embedding(
+            t, z_t.shape[-1]
+        )
         cond_vec = self.cond_emb(cond)
-        feats, h = self.backbone(x, cond_vec, h)
-        delta = self.delta_head(feats)
-        return delta, h
+        feats, h = self.backbone(z_t, cond_vec, h)
+        x_pred = self.clean_head(feats)
+        return x_pred, h
 
     def step(
-        self, x: Tensor, dt: Tensor, cond: Tensor, h: list[InferenceCache]
+        self, z_t: Tensor, r: Tensor, t: Tensor, cond: Tensor, h: list[InferenceCache]
     ) -> tuple[Tensor, list[InferenceCache]]:
-        delta, h = self.forward(x.unsqueeze(1), dt.unsqueeze(1), cond, h)
-        return delta[:, 0, :], h
+        x_pred, h = self.forward(
+            z_t.unsqueeze(1), r.unsqueeze(1), t.unsqueeze(1), cond, h
+        )
+        return x_pred[:, 0, :], h
 
     def init_cache(self, batch_size: int, device: torch.device) -> list[InferenceCache]:
         return [
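This hunk is the heart of the x-prediction rewrite: the head now predicts the clean image x_pred from (z_t, r, t) instead of a displacement delta. Under the linear path built by build_noisy_sequence, z_t = (1 - t) * x0 + t * eps, so

    z_t - x0 = t * (eps - x0)  =>  (z_t - x0) / t = eps - x0 = v_gt,

and substituting the model's x_pred for x0 yields the velocity estimate u = (z_t - x_pred) / t that compute_losses and the validation sampler rely on.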
@@ -242,6 +265,34 @@ class SwanLogger:
         finish()
 
 
+class LPIPSPerceptualLoss(nn.Module):
+    def __init__(self, cfg: TrainConfig) -> None:
+        super().__init__()
+        torch_home = Path(cfg.output_dir) / ".torch"
+        torch_home.mkdir(parents=True, exist_ok=True)
+        os.environ["TORCH_HOME"] = str(torch_home)
+        self.channels = cfg.channels
+        self.loss_fn = lpips.LPIPS(net="vgg", verbose=False)
+        self.loss_fn.eval()
+        for param in self.loss_fn.parameters():
+            param.requires_grad_(False)
+
+    def _prepare_images(self, images: Tensor) -> Tensor:
+        if images.shape[1] == 1:
+            return images.repeat(1, 3, 1, 1)
+        if images.shape[1] != 3:
+            raise ValueError(
+                "LPIPS perceptual loss expects 1-channel or 3-channel images. "
+                f"Got {images.shape[1]} channels."
+            )
+        return images
+
+    def forward(self, pred: Tensor, target: Tensor) -> Tensor:
+        pred_rgb = self._prepare_images(pred)
+        target_rgb = self._prepare_images(target)
+        return self.loss_fn(pred_rgb, target_rgb).mean()
+
+
 def set_seed(seed: int) -> None:
     torch.manual_seed(seed)
     if torch.cuda.is_available():
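LPIPS compares images through a frozen VGG feature stack, which is why the wrapper repeats 1-channel MNIST images to 3 channels and freezes every parameter. A minimal usage sketch (assumes the lpips package; shapes and values are illustrative, inputs conventionally scaled to [-1, 1]):

import lpips
import torch

loss_fn = lpips.LPIPS(net="vgg").eval()
a = torch.rand(4, 3, 28, 28) * 2 - 1  # fake batch in [-1, 1]
b = torch.rand(4, 3, 28, 28) * 2 - 1
d = loss_fn(a, b)  # per-image distances, shape (4, 1, 1, 1)
print(d.mean().item())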
@@ -266,40 +317,42 @@ def unwrap_model(model: nn.Module) -> nn.Module:
     return model.module if hasattr(model, "module") else model
 
 
-def validate_time_config(cfg: TrainConfig) -> None:
-    if cfg.seq_len <= 0:
-        raise ValueError("seq_len must be > 0")
-    base = 1.0 / cfg.seq_len
-    if cfg.dt_max <= base:
+def validate_config(cfg: TrainConfig) -> None:
+    if cfg.seq_len != 5:
         raise ValueError(
-            "dt_max must be > 1/seq_len to allow non-uniform dt_seq. "
-            f"Got dt_max={cfg.dt_max}, seq_len={cfg.seq_len}, 1/seq_len={base}."
+            f"seq_len must be 5 for the required 5-block training setup (got {cfg.seq_len})."
         )
-    if cfg.dt_min >= cfg.dt_max:
+    if cfg.time_grid_size < 2:
+        raise ValueError("time_grid_size must be >= 2.")
+    if cfg.lambda_perceptual < 0:
+        raise ValueError("lambda_perceptual must be >= 0.")
+    if cfg.val_sampling_steps != FIXED_VAL_SAMPLING_STEPS:
         raise ValueError(
-            f"dt_min must be < dt_max (got dt_min={cfg.dt_min}, dt_max={cfg.dt_max})."
+            f"val_sampling_steps is fixed to {FIXED_VAL_SAMPLING_STEPS} for validation sampling."
         )
 
 
-def sample_time_sequence(
-    cfg: TrainConfig, batch_size: int, device: torch.device
-) -> Tensor:
-    alpha = float(cfg.dt_alpha)
-    if alpha <= 0:
-        raise ValueError("dt_alpha must be > 0")
-    dist = torch.distributions.Gamma(alpha, 1.0)
-    raw = dist.sample((batch_size, cfg.seq_len)).to(device)
-    dt_seq = raw / raw.sum(dim=-1, keepdim=True)
-    base = 1.0 / cfg.seq_len
-    max_dt = float(cfg.dt_max)
-    if max_dt <= base:
-        return torch.full_like(dt_seq, base)
-    max_current = dt_seq.max(dim=-1, keepdim=True).values
-    if (max_current > max_dt).any():
-        gamma = (max_dt - base) / (max_current - base)
-        gamma = gamma.clamp(0.0, 1.0)
-        dt_seq = gamma * dt_seq + (1.0 - gamma) * base
-    return dt_seq
+def sample_block_times(
+    cfg: TrainConfig, batch_size: int, device: torch.device, dtype: torch.dtype
+) -> tuple[Tensor, Tensor]:
+    # Sampling sorted discrete cut points allows repeated boundaries, so zero-length
+    # interior blocks occur with non-zero probability while keeping t > 0.
+    cuts = torch.randint(
+        1,
+        cfg.time_grid_size,
+        (batch_size, cfg.seq_len - 1),
+        device=device,
+    )
+    cuts, _ = torch.sort(cuts, dim=-1)
+    boundaries = torch.cat(
+        [
+            torch.zeros(batch_size, 1, device=device, dtype=dtype),
+            cuts.to(dtype=dtype) / float(cfg.time_grid_size),
+            torch.ones(batch_size, 1, device=device, dtype=dtype),
+        ],
+        dim=-1,
+    )
+    return boundaries[:, :-1], boundaries[:, 1:]
 
 
 def build_dataloader(
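sample_block_times partitions [0, 1] into seq_len blocks per sample: it draws seq_len - 1 cut points on a discrete grid of step 1/time_grid_size, sorts them, and pads with 0 and 1, returning the left endpoints r and right endpoints t of each block. A standalone sketch of the same construction (hypothetical sizes, not this repo's config):

import torch

batch, seq_len, grid = 3, 5, 256
cuts, _ = torch.sort(torch.randint(1, grid, (batch, seq_len - 1)), dim=-1)
b = torch.cat([torch.zeros(batch, 1), cuts.float() / grid, torch.ones(batch, 1)], dim=-1)
r, t = b[:, :-1], b[:, 1:]
assert torch.all(t >= r) and torch.all(t > 0)  # zero-length blocks allowed, t never 0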
@@ -346,15 +399,66 @@ def infinite_loader(loader: DataLoader) -> Iterator[dict]:
         yield batch
 
 
+def build_noisy_sequence(
+    x0: Tensor,
+    eps: Tensor,
+    t_seq: Tensor,
+) -> tuple[Tensor, Tensor]:
+    z_t = (1.0 - t_seq.unsqueeze(-1)) * x0[:, None, :] + t_seq.unsqueeze(-1) * eps[:, None, :]
+    v_gt = eps - x0
+    return z_t, v_gt
+
+
 def compute_losses(
-    delta: Tensor,
+    model: nn.Module,
+    perceptual_loss_fn: LPIPSPerceptualLoss,
+    x0: Tensor,
+    z_t: Tensor,
     v_gt: Tensor,
-    dt_seq: Tensor,
-) -> dict[str, Tensor]:
-    losses: dict[str, Tensor] = {}
-    target_disp = v_gt[:, None, :] * dt_seq.unsqueeze(-1)
-    losses["flow"] = F.mse_loss(delta, target_disp)
-    return losses
+    r_seq: Tensor,
+    t_seq: Tensor,
+    cond: Tensor,
+    cfg: TrainConfig,
+) -> tuple[dict[str, Tensor], Tensor]:
+    seq_len = z_t.shape[1]
+    safe_t = safe_time_divisor(t_seq).unsqueeze(-1)
+
+    x_pred, _ = model(z_t, r_seq, t_seq, cond)
+    u = (z_t - x_pred) / safe_t
+
+    x_pred_inst, _ = model(z_t, t_seq, t_seq, cond)
+    v_inst = ((z_t - x_pred_inst) / safe_t).detach()
+
+    def u_fn(z_in: Tensor, r_in: Tensor, t_in: Tensor) -> Tensor:
+        x_pred_local, _ = model(z_in, r_in, t_in, cond)
+        return (z_in - x_pred_local) / safe_time_divisor(t_in).unsqueeze(-1)
+
+    _, dudt = jvp(
+        u_fn,
+        (z_t, r_seq, t_seq),
+        (v_inst, torch.zeros_like(r_seq), torch.ones_like(t_seq)),
+    )
+    corrected_velocity = u + (t_seq - r_seq).unsqueeze(-1) * dudt.detach()
+    target_velocity = v_gt[:, None, :].expand(-1, seq_len, -1)
+
+    pred_images = x_pred.reshape(
+        x0.shape[0] * seq_len, cfg.channels, cfg.image_size, cfg.image_size
+    )
+    target_images = (
+        x0.reshape(x0.shape[0], cfg.channels, cfg.image_size, cfg.image_size)
+        .unsqueeze(1)
+        .expand(-1, seq_len, -1, -1, -1)
+        .reshape(x0.shape[0] * seq_len, cfg.channels, cfg.image_size, cfg.image_size)
+    )
+
+    losses = {
+        "flow": F.mse_loss(corrected_velocity, target_velocity),
+        "perceptual": perceptual_loss_fn(pred_images, target_images),
+    }
+    losses["total"] = cfg.lambda_flow * losses["flow"] + cfg.lambda_perceptual * losses[
+        "perceptual"
+    ]
+    return losses, x_pred
 
 
 def make_grid(images: Tensor, nrow: int) -> np.ndarray:
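compute_losses is the MeanFlow objective in x-prediction form. The average velocity u(z_t, r, t) = (z_t - x_pred) / t should satisfy the MeanFlow identity u = v - (t - r) * du/dt, where du/dt is the total derivative along the trajectory (dz/dt = v, dr/dt = 0, dt/dt = 1). The jvp call computes exactly that directional derivative; detaching it makes the loss ||u + (t - r) * sg(du/dt) - v_gt||^2, which is the identity rearranged into a regression target. A toy torch.func.jvp sketch showing the tangent convention (stand-in function, not the model):

import torch
from torch.func import jvp

def f(z, r, t):
    return z * t - r  # stand-in for u(z, r, t)

z = torch.ones(2)
r = torch.zeros(2)
t = torch.full((2,), 0.5)
v = torch.full((2,), 2.0)  # dz/dt along the path
# tangents (v, 0, 1) give the total derivative v * df/dz + df/dt
_, total_dt = jvp(f, (z, r, t), (v, torch.zeros_like(r), torch.ones_like(t)))
print(total_dt)  # t * v + z = 0.5 * 2 + 1 = tensor([2., 2.])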
@@ -389,30 +493,36 @@ def save_image_grid(
     plt.close(fig)
 
 
-def rollout_trajectory(
+def sample_class_images(
     model: ASMamba,
-    x0: Tensor,
+    cfg: TrainConfig,
+    device: torch.device,
     cond: Tensor,
-    dt_seq: Tensor,
 ) -> Tensor:
-    device = x0.device
     model.eval()
-    h = model.init_cache(batch_size=x0.shape[0], device=device)
-    x = x0
-    if dt_seq.dim() == 1:
-        dt_seq = dt_seq.unsqueeze(0).expand(x0.shape[0], -1)
-    elif dt_seq.shape[0] == 1 and x0.shape[0] > 1:
-        dt_seq = dt_seq.expand(x0.shape[0], -1)
-    traj = [x0.detach().cpu()]
+    input_dim = cfg.channels * cfg.image_size * cfg.image_size
+    z_t = torch.randn(cond.shape[0], input_dim, device=device)
+    time_grid = torch.tensor(FIXED_VAL_TIME_GRID, device=device)
 
     with torch.no_grad():
-        for step_idx in range(dt_seq.shape[1]):
-            dt = dt_seq[:, step_idx]
-            delta, h = model.step(x, dt, cond, h)
-            x = x + delta
-            traj.append(x.detach().cpu())
+        for step_idx in range(FIXED_VAL_SAMPLING_STEPS):
+            t_cur = torch.full(
+                (cond.shape[0],),
+                float(time_grid[step_idx].item()),
+                device=device,
+            )
+            t_next = time_grid[step_idx + 1]
+            x_pred, _ = model(
+                z_t.unsqueeze(1),
+                t_cur.unsqueeze(1),
+                t_cur.unsqueeze(1),
+                cond,
+            )
+            x_pred = x_pred[:, 0, :]
+            u_inst = (z_t - x_pred) / safe_time_divisor(t_cur).unsqueeze(-1)
+            z_t = z_t + (t_next - t_cur).unsqueeze(-1) * u_inst
 
-    return torch.stack(traj, dim=1)
+    return z_t.view(cond.shape[0], cfg.channels, cfg.image_size, cfg.image_size)
 
 
 def log_class_samples(
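Validation sampling is now a fixed 5-step Euler integration of the learned flow from t = 1 (pure noise) down to t = 0 (data) over FIXED_VAL_TIME_GRID, querying the model at r = t so the x-prediction converts to an instantaneous velocity. Worked first step: with t_cur = 1.0 and t_next = 0.8, the update z <- z + (0.8 - 1.0) * (z - x_pred) / 1.0 moves z 20% of the way toward the predicted clean image.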
@@ -424,24 +534,13 @@ def log_class_samples(
 ) -> None:
     if cfg.val_samples_per_class <= 0:
         return
+    training_mode = model.training
     model.eval()
-    max_steps = cfg.seq_len
-    input_dim = cfg.channels * cfg.image_size * cfg.image_size
-    dt_seq = torch.full(
-        (cfg.val_samples_per_class, max_steps),
-        1.0 / max_steps,
-        device=device,
-    )
-
     for cls in range(cfg.num_classes):
         cond = torch.full(
             (cfg.val_samples_per_class,), cls, device=device, dtype=torch.long
         )
-        x0 = torch.randn(cfg.val_samples_per_class, input_dim, device=device)
-        traj = rollout_trajectory(model, x0, cond, dt_seq=dt_seq)
-        x_final = traj[:, -1, :].view(
-            cfg.val_samples_per_class, cfg.channels, cfg.image_size, cfg.image_size
-        )
+        x_final = sample_class_images(model, cfg, device, cond)
         save_path = Path(cfg.output_dir) / f"val_class_{cls}_step_{step:06d}.png"
         save_image_grid(x_final, save_path, nrow=cfg.val_grid_rows)
         logger.log_image(
@@ -450,12 +549,25 @@ def log_class_samples(
             caption=f"class {cls} step {step}",
             step=step,
         )
-    model.train()
+    if training_mode:
+        model.train()
 
 
+def build_perceptual_loss(
+    cfg: TrainConfig, device: torch.device, rank: int, use_ddp: bool
+) -> LPIPSPerceptualLoss:
+    if use_ddp and rank != 0:
+        torch.distributed.barrier()
+    perceptual_loss_fn = LPIPSPerceptualLoss(cfg).to(device)
+    if use_ddp and rank == 0:
+        torch.distributed.barrier()
+    return perceptual_loss_fn
+
+
 def train(cfg: TrainConfig) -> ASMamba:
-    validate_time_config(cfg)
+    validate_config(cfg)
     use_ddp, rank, world_size, device = setup_distributed(cfg)
     del world_size
     set_seed(cfg.seed + rank)
     output_dir = Path(cfg.output_dir)
     if rank == 0:
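build_perceptual_loss serializes the one-time LPIPS/VGG weight download under DDP: every rank except 0 waits at the first barrier, rank 0 downloads into the shared TORCH_HOME cache, then releases the others at the second barrier so they construct from the warm cache. A generic sketch of this rank-0-first pattern (assumes an initialized process group; expensive_download is a placeholder):

import torch.distributed as dist

def rank0_first(rank: int, expensive_download) -> None:
    if rank != 0:
        dist.barrier()  # wait for rank 0 to fill the cache
    expensive_download()  # cache hit for every rank > 0
    if rank == 0:
        dist.barrier()  # release the waiting ranks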
@@ -467,48 +579,53 @@ def train(cfg: TrainConfig) -> ASMamba:
     optimizer = torch.optim.AdamW(
         model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay
     )
+    perceptual_loss_fn = build_perceptual_loss(cfg, device, rank, use_ddp)
     logger = SwanLogger(cfg, enabled=(rank == 0))
 
     loader, sampler = build_dataloader(cfg, distributed=use_ddp)
     loader_iter = infinite_loader(loader)
 
     global_step = 0
-    for _ in range(cfg.epochs):
+    for epoch_idx in range(cfg.epochs):
         if sampler is not None:
-            sampler.set_epoch(global_step)
+            sampler.set_epoch(epoch_idx)
         model.train()
         for _ in range(cfg.steps_per_epoch):
             batch = next(loader_iter)
-            x1 = batch["pixel_values"].to(device)
+            x0 = batch["pixel_values"].to(device)
             cond = batch["labels"].to(device)
-            b = x1.shape[0]
-            x1 = x1.view(b, -1)
-            x0 = torch.randn_like(x1)
-            v_gt = x1 - x0
-            dt_seq = sample_time_sequence(cfg, b, device)
-            t_seq = torch.cumsum(dt_seq, dim=-1)
-            t_seq = torch.cat([torch.zeros(b, 1, device=device), t_seq[:, :-1]], dim=-1)
-            x_seq = x0[:, None, :] + t_seq[:, :, None] * v_gt[:, None, :]
+            b = x0.shape[0]
+            x0 = x0.view(b, -1)
+            eps = torch.randn_like(x0)
 
-            delta, _ = model(x_seq, dt_seq, cond)
+            r_seq, t_seq = sample_block_times(cfg, b, device, x0.dtype)
+            z_t, v_gt = build_noisy_sequence(x0, eps, t_seq)
 
-            losses = compute_losses(
-                delta=delta,
+            losses, _ = compute_losses(
+                model=model,
+                perceptual_loss_fn=perceptual_loss_fn,
+                x0=x0,
+                z_t=z_t,
                 v_gt=v_gt,
-                dt_seq=dt_seq,
+                r_seq=r_seq,
+                t_seq=t_seq,
+                cond=cond,
+                cfg=cfg,
             )
 
-            loss = cfg.lambda_flow * losses["flow"]
-
             optimizer.zero_grad(set_to_none=True)
-            loss.backward()
+            losses["total"].backward()
             optimizer.step()
 
-            if global_step % 10 == 0:
+            if global_step % 10 == 0 and rank == 0:
                 logger.log(
                     {
-                        "loss/total": float(loss.item()),
+                        "loss/total": float(losses["total"].item()),
                         "loss/flow": float(losses["flow"].item()),
+                        "loss/perceptual": float(losses["perceptual"].item()),
+                        "time/r_mean": float(r_seq.mean().item()),
+                        "time/t_mean": float(t_seq.mean().item()),
+                        "time/zero_block_frac": float((t_seq == r_seq).float().mean().item()),
                     },
                     step=global_step,
                 )
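The new time/zero_block_frac metric reports the fraction of sampled blocks with r == t; for those blocks the (t - r) * du/dt correction vanishes and the MeanFlow target reduces to plain instantaneous-velocity matching. With 4 cuts on a 256-point grid repeats are rare, so the fraction should hover around half a percent; a quick empirical check (standalone, sizes matching the defaults):

import torch

cuts, _ = torch.sort(torch.randint(1, 256, (100000, 4)), dim=-1)
b = torch.cat([torch.zeros(100000, 1), cuts.float() / 256, torch.ones(100000, 1)], dim=-1)
print((b[:, 1:] == b[:, :-1]).float().mean())  # empirical zero-block fraction, ~0.005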