7 Commits

Author      SHA1        Message                                                                    Date
Logic       5897a0afd1  feat: stabilize meanflow training and time sampling                        2026-03-11 22:54:48 +08:00
Logic       9b2968997c  Implement Mamba MeanFlow x-prediction training                             2026-03-11 16:33:40 +08:00
gameloader  01fc1e4eab  refactor: simplify delta-only flow training                                2026-03-10 18:23:17 +08:00
                        (Remove learned dt prediction and auxiliary losses.
                        Add repository contributor guidelines.)
gameloader  913740266b  fix: remove dt clamping and use raw softplus for step size                 2026-01-22 14:41:02 +08:00
gameloader  444f5fc109  feat: migrate switch to conditional flow matching from sphere trajectory   2026-01-22 14:37:50 +08:00
gameloader  c15115edc4  feat: add conditional AdaLNZero and two-target spheres sampling            2026-01-21 15:41:40 +08:00
gameloader  cac3236f9d  Add configurable dt sampling and loss toggles                              2026-01-21 15:14:04 +08:00
6 changed files with 1562 additions and 320 deletions

AGENTS.md (new file)

@@ -0,0 +1,27 @@
# Repository Guidelines
## Project Structure & Module Organization
This repository is a flat Python training project, not a packaged library. `main.py` is the CLI entry point. `as_mamba.py` holds `TrainConfig`, dataset loading, the training loop, validation image generation, and SwanLab logging. `mamba2_minimal.py` contains the standalone Mamba-2 building blocks. `train_as_mamba.sh` is the multi-GPU launcher. Runtime artifacts land in `outputs/` and `swanlog/`; treat both as generated data, not source.
## Build, Test, and Development Commands
Use `uv` with the committed lockfile.
- `uv sync` installs the Python 3.12 environment from `pyproject.toml` and `uv.lock`.
- `uv run python main.py --help` lists all training flags and is the fastest CLI sanity check.
- `uv run python main.py --device cpu --epochs 1 --steps-per-epoch 1 --batch-size 8 --num-workers 0 --val-every 0` runs a minimal local smoke test.
- `bash train_as_mamba.sh` launches the default distributed training job with `torchrun`; adjust `--nproc_per_node` and shell variables before using it on a new machine.
- `uv run python -m compileall main.py as_mamba.py mamba2_minimal.py` performs a lightweight syntax check when no test suite is available.
## Coding Style & Naming Conventions
Follow the existing Python style: 4-space indentation, type hints on public functions, `dataclass`-based config, and small helper functions over deeply nested logic. Use `snake_case` for functions, variables, and CLI flags; use `PascalCase` for classes such as `ASMamba` and `TrainConfig`. Keep new modules top-level unless the project is restructured. There is no configured formatter or linter yet, so match surrounding code closely and keep imports grouped and readable.
## Testing Guidelines
There is no dedicated `tests/` directory or pytest setup yet. Before submitting changes, run at minimum:
- `uv run python -m compileall main.py as_mamba.py mamba2_minimal.py`
- one short training smoke run with reduced epochs and workers
If you add reusable logic, prefer extracting pure functions so a future pytest suite can cover them easily. Name any new test files `test_<module>.py`.
## Commit & Pull Request Guidelines
Recent history favors short imperative commit subjects, usually Conventional Commit style: `feat: ...`, `fix: ...`, and scoped variants like `feat(mamba): ...`. Keep commits focused on one training or infrastructure change. Pull requests should describe the config or behavior change, list the verification commands you ran, and include sample images or metric notes when outputs in `outputs/` materially change. Do not commit `.venv/`, `__pycache__/`, or large generated artifacts.
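As a concrete sketch of the testing guidance above (extract pure functions, name test files `test_<module>.py`), here is a hypothetical example; `normalize_pixels` is an illustrative stand-in, not a function that exists in the repository — it mirrors the `/ 127.5 - 1.0` pixel scaling used in the training code:

```python
# test_helpers.py — hypothetical example of testing a pure helper
# extracted from as_mamba.py.
import numpy as np


def normalize_pixels(arr: np.ndarray) -> np.ndarray:
    """Map uint8 pixel values in [0, 255] to float32 in [-1, 1]."""
    return arr.astype(np.float32) / 127.5 - 1.0


def test_normalize_pixels_range():
    arr = np.array([0, 127.5, 255])
    out = normalize_pixels(arr)
    # Output stays in [-1, 1] and hits the endpoints exactly.
    assert out.min() >= -1.0 and out.max() <= 1.0
    assert np.isclose(out[0], -1.0) and np.isclose(out[-1], 1.0)
```

Run with `uv run python -m pytest test_helpers.py` once pytest is added, or call the test function directly in the interim.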

as_mamba.py
@@ -1,9 +1,14 @@
 from __future__ import annotations
+import math
+import os
 from dataclasses import asdict, dataclass
 from pathlib import Path
-from typing import Optional
+from typing import Iterator, Optional
+import lpips
+os.environ.setdefault("MPLCONFIGDIR", "/tmp/mamba_diffusion_mplconfig")
 import matplotlib
 matplotlib.use("Agg")
@@ -11,36 +16,41 @@ matplotlib.use("Agg")
 import numpy as np
 import torch
 import torch.nn.functional as F
+from datasets import load_dataset
 from matplotlib import pyplot as plt
-from mpl_toolkits.mplot3d import Axes3D  # noqa: F401
 from torch import Tensor, nn
+from torch.func import jvp
+from torch.utils.data import DataLoader
+from torch.utils.data.distributed import DistributedSampler
 from mamba2_minimal import InferenceCache, Mamba2, Mamba2Config, RMSNorm
+FIXED_VAL_SAMPLING_STEPS = 5
+FIXED_VAL_TIME_GRID = (1.0, 0.8, 0.6, 0.4, 0.2, 0.0)
 @dataclass
 class TrainConfig:
     seed: int = 42
     device: str = "cuda"
+    epochs: int = 50
+    steps_per_epoch: int = 200
     batch_size: int = 128
-    steps_per_epoch: int = 50
-    epochs: int = 60
-    warmup_epochs: int = 15
-    seq_len: int = 20
-    lr: float = 1e-3
+    seq_len: int = 5
+    lr: float = 2e-4
     weight_decay: float = 1e-2
-    dt_min: float = 1e-3
-    dt_max: float = 0.06
+    max_grad_norm: float = 1.0
     lambda_flow: float = 1.0
-    lambda_pos: float = 1.0
-    lambda_nfe: float = 0.05
-    radius_min: float = 0.6
-    radius_max: float = 1.4
-    center_min: float = -6.0
-    center_max: float = 6.0
-    center_distance_min: float = 6.0
-    d_model: int = 128
-    n_layer: int = 4
+    lambda_perceptual: float = 2.0
+    num_classes: int = 10
+    image_size: int = 28
+    channels: int = 1
+    num_workers: int = 8
+    dataset_name: str = "ylecun/mnist"
+    dataset_split: str = "train"
+    d_model: int = 0
+    n_layer: int = 6
     d_state: int = 64
     d_conv: int = 4
     expand: int = 2
@@ -48,12 +58,29 @@ class TrainConfig:
     chunk_size: int = 1
     use_residual: bool = False
     output_dir: str = "outputs"
-    project: str = "as-mamba"
-    run_name: str = "sphere-to-sphere"
+    project: str = "as-mamba-mnist"
+    run_name: str = "mnist-meanflow"
     val_every: int = 200
-    val_samples: int = 256
-    val_plot_samples: int = 16
-    val_max_steps: int = 100
+    val_samples_per_class: int = 8
+    val_grid_rows: int = 4
+    val_sampling_steps: int = FIXED_VAL_SAMPLING_STEPS
+    time_grid_size: int = 256
+    use_ddp: bool = False
+class AdaLNZero(nn.Module):
+    def __init__(self, d_model: int) -> None:
+        super().__init__()
+        self.norm = RMSNorm(d_model)
+        self.mod = nn.Linear(d_model, 2 * d_model)
+        nn.init.zeros_(self.mod.weight)
+        nn.init.zeros_(self.mod.bias)
+    def forward(self, x: Tensor, cond: Tensor) -> Tensor:
+        x = self.norm(x)
+        params = self.mod(cond).unsqueeze(1)
+        scale, shift = params.chunk(2, dim=-1)
+        return x * (1 + scale) + shift
 class Mamba2Backbone(nn.Module):
@@ -66,7 +93,7 @@ class Mamba2Backbone(nn.Module):
             nn.ModuleDict(
                 dict(
                     mixer=Mamba2(args),
-                    norm=RMSNorm(args.d_model),
+                    adaln=AdaLNZero(args.d_model),
                 )
             )
             for _ in range(args.n_layer)
@@ -75,25 +102,56 @@ class Mamba2Backbone(nn.Module):
         self.norm_f = RMSNorm(args.d_model)
     def forward(
-        self, x: Tensor, h: Optional[list[InferenceCache]] = None
+        self, x: Tensor, cond: Tensor, h: Optional[list[InferenceCache]] = None
     ) -> tuple[Tensor, list[InferenceCache]]:
         if h is None:
             h = [None for _ in range(self.args.n_layer)]
         for i, layer in enumerate(self.layers):
-            y, h[i] = layer["mixer"](layer["norm"](x), h[i])
+            x_mod = layer["adaln"](x, cond)
+            y, h[i] = layer["mixer"](x_mod, h[i])
             x = x + y if self.use_residual else y
         x = self.norm_f(x)
         return x, h
+def sinusoidal_embedding(t: Tensor, dim: int) -> Tensor:
+    if dim <= 0:
+        raise ValueError("sinusoidal embedding dim must be > 0")
+    half = dim // 2
+    if half == 0:
+        return torch.zeros(*t.shape, dim, device=t.device, dtype=t.dtype)
+    denom = max(half - 1, 1)
+    freqs = torch.exp(
+        -math.log(10000.0)
+        * torch.arange(half, device=t.device, dtype=torch.float32)
+        / denom
+    )
+    args = t.to(torch.float32).unsqueeze(-1) * freqs
+    emb = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
+    if dim % 2 == 1:
+        pad = torch.zeros(*emb.shape[:-1], 1, device=emb.device, dtype=emb.dtype)
+        emb = torch.cat([emb, pad], dim=-1)
+    return emb.to(t.dtype)
+def safe_time_divisor(t: Tensor) -> Tensor:
+    eps = torch.finfo(t.dtype).eps
+    return torch.where(t > 0, t, torch.full_like(t, eps))
 class ASMamba(nn.Module):
     def __init__(self, cfg: TrainConfig) -> None:
         super().__init__()
         self.cfg = cfg
-        self.dt_min = float(cfg.dt_min)
-        self.dt_max = float(cfg.dt_max)
+        input_dim = cfg.channels * cfg.image_size * cfg.image_size
+        if cfg.d_model == 0:
+            cfg.d_model = input_dim
+        if cfg.d_model != input_dim:
+            raise ValueError(
+                f"d_model must equal flattened image dim ({input_dim}) when input_proj is disabled."
+            )
         args = Mamba2Config(
             d_model=cfg.d_model,
@@ -105,29 +163,43 @@ class ASMamba(nn.Module):
             chunk_size=cfg.chunk_size,
         )
         self.backbone = Mamba2Backbone(args, use_residual=cfg.use_residual)
-        self.input_proj = nn.Linear(3, cfg.d_model)
-        self.delta_head = nn.Linear(cfg.d_model, 3)
-        self.dt_head = nn.Sequential(
-            nn.Linear(cfg.d_model, cfg.d_model),
-            nn.SiLU(),
-            nn.Linear(cfg.d_model, 1),
-        )
+        self.cond_emb = nn.Embedding(cfg.num_classes, cfg.d_model)
+        self.clean_head = nn.Linear(cfg.d_model, input_dim)
     def forward(
-        self, x: Tensor, h: Optional[list[InferenceCache]] = None
-    ) -> tuple[Tensor, Tensor, list[InferenceCache]]:
-        x_proj = self.input_proj(x)
-        feats, h = self.backbone(x_proj, h)
-        delta = self.delta_head(feats)
-        dt_raw = self.dt_head(feats).squeeze(-1)
-        dt = torch.clamp(F.softplus(dt_raw), min=self.dt_min, max=self.dt_max)
-        return delta, dt, h
+        self,
+        z_t: Tensor,
+        r: Tensor,
+        t: Tensor,
+        cond: Tensor,
+        h: Optional[list[InferenceCache]] = None,
+    ) -> tuple[Tensor, list[InferenceCache]]:
+        if r.dim() == 1:
+            r = r.unsqueeze(1)
+        elif r.dim() == 3 and r.shape[-1] == 1:
+            r = r.squeeze(-1)
+        if t.dim() == 1:
+            t = t.unsqueeze(1)
+        elif t.dim() == 3 and t.shape[-1] == 1:
+            t = t.squeeze(-1)
+        r = r.to(dtype=z_t.dtype)
+        t = t.to(dtype=z_t.dtype)
+        z_t = z_t + sinusoidal_embedding(r, z_t.shape[-1]) + sinusoidal_embedding(
+            t, z_t.shape[-1]
+        )
+        cond_vec = self.cond_emb(cond)
+        feats, h = self.backbone(z_t, cond_vec, h)
+        x_pred = torch.tanh(self.clean_head(feats))
+        return x_pred, h
     def step(
-        self, x: Tensor, h: list[InferenceCache]
-    ) -> tuple[Tensor, Tensor, list[InferenceCache]]:
-        delta, dt, h = self.forward(x.unsqueeze(1), h)
-        return delta[:, 0, :], dt[:, 0], h
+        self, z_t: Tensor, r: Tensor, t: Tensor, cond: Tensor, h: list[InferenceCache]
+    ) -> tuple[Tensor, list[InferenceCache]]:
+        x_pred, h = self.forward(
+            z_t.unsqueeze(1), r.unsqueeze(1), t.unsqueeze(1), cond, h
+        )
+        return x_pred[:, 0, :], h
     def init_cache(self, batch_size: int, device: torch.device) -> list[InferenceCache]:
         return [
@@ -137,10 +209,12 @@ class ASMamba(nn.Module):
 class SwanLogger:
-    def __init__(self, cfg: TrainConfig) -> None:
-        self.enabled = False
+    def __init__(self, cfg: TrainConfig, enabled: bool = True) -> None:
+        self.enabled = enabled
         self._swan = None
         self._run = None
+        if not self.enabled:
+            return
         try:
             import swanlab  # type: ignore
@@ -169,7 +243,11 @@ class SwanLogger:
         target.log(payload)
     def log_image(
-        self, key: str, image_path: Path, caption: str | None = None, step: int | None = None
+        self,
+        key: str,
+        image_path: Path,
+        caption: str | None = None,
+        step: int | None = None,
     ) -> None:
         if not self.enabled:
             return
@@ -188,324 +266,391 @@
         finish()
+class LPIPSPerceptualLoss(nn.Module):
+    def __init__(self, cfg: TrainConfig) -> None:
+        super().__init__()
+        torch_home = Path(cfg.output_dir) / ".torch"
+        torch_home.mkdir(parents=True, exist_ok=True)
+        os.environ["TORCH_HOME"] = str(torch_home)
+        self.channels = cfg.channels
+        self.loss_fn = lpips.LPIPS(net="vgg", verbose=False)
+        self.loss_fn.eval()
+        for param in self.loss_fn.parameters():
+            param.requires_grad_(False)
+    def _prepare_images(self, images: Tensor) -> Tensor:
+        if images.shape[1] == 1:
+            return images.repeat(1, 3, 1, 1)
+        if images.shape[1] != 3:
+            raise ValueError(
+                "LPIPS perceptual loss expects 1-channel or 3-channel images. "
+                f"Got {images.shape[1]} channels."
+            )
+        return images
+    def forward(self, pred: Tensor, target: Tensor) -> Tensor:
+        pred_rgb = self._prepare_images(pred)
+        target_rgb = self._prepare_images(target)
+        return self.loss_fn(pred_rgb, target_rgb).mean()
 def set_seed(seed: int) -> None:
     torch.manual_seed(seed)
     if torch.cuda.is_available():
         torch.cuda.manual_seed_all(seed)
-def sample_points_in_sphere(
-    center: Tensor, radius: float, batch_size: int, device: torch.device
-) -> Tensor:
-    direction = torch.randn(batch_size, 3, device=device)
-    direction = direction / (direction.norm(dim=-1, keepdim=True) + 1e-8)
-    u = torch.rand(batch_size, 1, device=device)
-    r = radius * torch.pow(u, 1.0 / 3.0)
-    return center + direction * r
-def sample_sphere_params(cfg: TrainConfig, device: torch.device) -> tuple[Tensor, Tensor]:
-    center_a = torch.empty(3, device=device).uniform_(cfg.center_min, cfg.center_max)
-    center_b = torch.empty(3, device=device).uniform_(cfg.center_min, cfg.center_max)
-    for _ in range(128):
-        if torch.norm(center_a - center_b) >= cfg.center_distance_min:
-            break
-        center_b = torch.empty(3, device=device).uniform_(cfg.center_min, cfg.center_max)
-    if torch.norm(center_a - center_b) < 1e-3:
-        center_b = center_b + torch.tensor([cfg.center_distance_min, 0.0, 0.0], device=device)
-    radius_a = float(torch.empty(1).uniform_(cfg.radius_min, cfg.radius_max).item())
-    radius_b = float(torch.empty(1).uniform_(cfg.radius_min, cfg.radius_max).item())
-    return (center_a, torch.tensor(radius_a, device=device)), (
-        center_b,
-        torch.tensor(radius_b, device=device),
-    )
+def setup_distributed(cfg: TrainConfig) -> tuple[bool, int, int, torch.device]:
+    world_size = int(os.environ.get("WORLD_SIZE", "1"))
+    rank = int(os.environ.get("RANK", "0"))
+    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
+    use_ddp = cfg.use_ddp and world_size > 1
+    if use_ddp:
+        torch.distributed.init_process_group(backend="nccl", init_method="env://")
+        torch.cuda.set_device(local_rank)
+        device = torch.device("cuda", local_rank)
+    else:
+        device = torch.device(cfg.device if torch.cuda.is_available() else "cpu")
+    return use_ddp, rank, world_size, device
+def unwrap_model(model: nn.Module) -> nn.Module:
+    return model.module if hasattr(model, "module") else model
+def validate_config(cfg: TrainConfig) -> None:
+    if cfg.seq_len != 5:
+        raise ValueError(
+            f"seq_len must be 5 for the required 5-block training setup (got {cfg.seq_len})."
+        )
+    if cfg.time_grid_size < 2:
+        raise ValueError("time_grid_size must be >= 2.")
+    if cfg.lambda_perceptual < 0:
+        raise ValueError("lambda_perceptual must be >= 0.")
+    if cfg.max_grad_norm <= 0:
+        raise ValueError("max_grad_norm must be > 0.")
+    if cfg.val_sampling_steps != FIXED_VAL_SAMPLING_STEPS:
+        raise ValueError(
+            f"val_sampling_steps is fixed to {FIXED_VAL_SAMPLING_STEPS} for validation sampling."
+        )
-def sample_batch(
-    cfg: TrainConfig,
-    sphere_a: tuple[Tensor, Tensor],
-    sphere_b: tuple[Tensor, Tensor],
-    device: torch.device,
-) -> tuple[Tensor, Tensor, Tensor, Tensor]:
-    center_a, radius_a = sphere_a
-    center_b, radius_b = sphere_b
-    x0 = sample_points_in_sphere(center_a, float(radius_a.item()), cfg.batch_size, device)
-    x1 = sample_points_in_sphere(center_b, float(radius_b.item()), cfg.batch_size, device)
-    v_gt = x1 - x0
-    dt_fixed = 1.0 / cfg.seq_len
-    t_seq = torch.arange(cfg.seq_len, device=device) * dt_fixed
-    x_seq = x0[:, None, :] + t_seq[None, :, None] * v_gt[:, None, :]
-    return x0, x1, x_seq, t_seq
+def sample_block_times(
+    cfg: TrainConfig, batch_size: int, device: torch.device, dtype: torch.dtype
+) -> tuple[Tensor, Tensor]:
+    num_internal = cfg.seq_len - 1
+    normal = torch.randn(batch_size, num_internal, device=device, dtype=dtype)
+    logit_normal = torch.sigmoid(normal * math.sqrt(0.8))
+    uniform = torch.rand(batch_size, num_internal, device=device, dtype=dtype)
+    use_uniform = torch.rand(batch_size, num_internal, device=device) < 0.1
+    cuts = torch.where(use_uniform, uniform, logit_normal)
+    cuts, _ = torch.sort(cuts, dim=-1)
+    boundaries = torch.cat(
+        [
+            torch.zeros(batch_size, 1, device=device, dtype=dtype),
+            cuts,
+            torch.ones(batch_size, 1, device=device, dtype=dtype),
+        ],
+        dim=-1,
+    )
+    return boundaries[:, :-1], boundaries[:, 1:]
+def build_dataloader(
+    cfg: TrainConfig, distributed: bool = False
+) -> tuple[DataLoader, Optional[DistributedSampler]]:
+    ds = load_dataset(cfg.dataset_name, split=cfg.dataset_split)
+    def transform(example):
+        image = example.get("img", example.get("image"))
+        label = example.get("label", example.get("labels"))
+        if isinstance(image, list):
+            arr = np.stack([np.array(im, dtype=np.float32) for im in image], axis=0)
+            arr = arr / 127.5 - 1.0
+            if arr.ndim == 3:
+                tensor = torch.from_numpy(arr).unsqueeze(1)
+            else:
+                tensor = torch.from_numpy(arr).permute(0, 3, 1, 2)
+            labels = torch.tensor(label, dtype=torch.long)
+            return {"pixel_values": tensor, "labels": labels}
+        arr = np.array(image, dtype=np.float32) / 127.5 - 1.0
+        if arr.ndim == 2:
+            tensor = torch.from_numpy(arr).unsqueeze(0)
+        else:
+            tensor = torch.from_numpy(arr).permute(2, 0, 1)
+        return {"pixel_values": tensor, "labels": torch.tensor(label, dtype=torch.long)}
+    ds = ds.with_transform(transform)
+    sampler = DistributedSampler(ds, shuffle=True) if distributed else None
+    loader = DataLoader(
+        ds,
+        batch_size=cfg.batch_size,
+        shuffle=(sampler is None),
+        sampler=sampler,
+        num_workers=cfg.num_workers,
+        drop_last=True,
+        pin_memory=torch.cuda.is_available(),
+    )
+    return loader, sampler
+def infinite_loader(loader: DataLoader) -> Iterator[dict]:
+    while True:
+        for batch in loader:
+            yield batch
+def build_noisy_sequence(
+    x0: Tensor,
+    eps: Tensor,
+    t_seq: Tensor,
+) -> tuple[Tensor, Tensor]:
+    z_t = (1.0 - t_seq.unsqueeze(-1)) * x0[:, None, :] + t_seq.unsqueeze(-1) * eps[:, None, :]
+    v_gt = eps - x0
+    return z_t, v_gt
def compute_losses( def compute_losses(
delta: Tensor, model: nn.Module,
dt: Tensor, perceptual_loss_fn: LPIPSPerceptualLoss,
x_seq: Tensor,
x0: Tensor, x0: Tensor,
z_t: Tensor,
v_gt: Tensor, v_gt: Tensor,
r_seq: Tensor,
t_seq: Tensor, t_seq: Tensor,
cond: Tensor,
cfg: TrainConfig, cfg: TrainConfig,
) -> tuple[Tensor, Tensor, Tensor]: ) -> tuple[dict[str, Tensor], Tensor]:
target_disp = v_gt[:, None, :] * dt.unsqueeze(-1) seq_len = z_t.shape[1]
flow_loss = F.mse_loss(delta, target_disp) safe_t = safe_time_divisor(t_seq).unsqueeze(-1)
t_next = t_seq[None, :, None] + dt.unsqueeze(-1) x_pred, _ = model(z_t, r_seq, t_seq, cond)
t_next = torch.clamp(t_next, 0.0, 1.0) u = (z_t - x_pred) / safe_t
x_target = x0[:, None, :] + t_next * v_gt[:, None, :]
x_next_pred = x_seq + delta
pos_loss = F.mse_loss(x_next_pred, x_target)
nfe_loss = (-torch.log(dt)).mean() x_pred_inst, _ = model(z_t, t_seq, t_seq, cond)
return flow_loss, pos_loss, nfe_loss v_inst = ((z_t - x_pred_inst) / safe_t).detach()
def u_fn(z_in: Tensor, r_in: Tensor, t_in: Tensor) -> Tensor:
x_pred_local, _ = model(z_in, r_in, t_in, cond)
return (z_in - x_pred_local) / safe_time_divisor(t_in).unsqueeze(-1)
def validate( _, dudt = jvp(
model: ASMamba, u_fn,
cfg: TrainConfig, (z_t, r_seq, t_seq),
sphere_a: tuple[Tensor, Tensor], (v_inst, torch.zeros_like(r_seq), torch.ones_like(t_seq)),
sphere_b: tuple[Tensor, Tensor],
device: torch.device,
logger: SwanLogger,
step: int,
) -> None:
model.eval()
center_b, radius_b = sphere_b
with torch.no_grad():
x0 = sample_points_in_sphere(
sphere_a[0], float(sphere_a[1].item()), cfg.val_samples, device
) )
traj = rollout_trajectory(model, x0, max_steps=cfg.val_max_steps) corrected_velocity = u + (t_seq - r_seq).unsqueeze(-1) * dudt.detach()
target_velocity = v_gt[:, None, :].expand(-1, seq_len, -1)
x_final = traj[:, -1, :] pred_images = x_pred.reshape(
center_b_cpu = center_b.detach().cpu() x0.shape[0] * seq_len, cfg.channels, cfg.image_size, cfg.image_size
radius_b_cpu = radius_b.detach().cpu() )
dist = torch.linalg.norm(x_final - center_b_cpu, dim=-1) target_images = (
inside = dist <= radius_b_cpu x0.reshape(x0.shape[0], cfg.channels, cfg.image_size, cfg.image_size)
.unsqueeze(1)
logger.log( .expand(-1, seq_len, -1, -1, -1)
{ .reshape(x0.shape[0] * seq_len, cfg.channels, cfg.image_size, cfg.image_size)
"val/inside_ratio": float(inside.float().mean().item()),
"val/inside_count": float(inside.float().sum().item()),
"val/final_dist_mean": float(dist.mean().item()),
"val/final_dist_min": float(dist.min().item()),
"val/final_dist_max": float(dist.max().item()),
},
step=step,
) )
if cfg.val_plot_samples > 0: losses = {
count = min(cfg.val_plot_samples, traj.shape[0]) "flow": F.mse_loss(corrected_velocity, target_velocity),
if count == 0: "perceptual": perceptual_loss_fn(pred_images, target_images),
model.train()
return
indices = torch.linspace(0, traj.shape[0] - 1, steps=count).long()
traj_plot = traj[indices]
save_path = Path(cfg.output_dir) / f"val_traj_step_{step:06d}.png"
plot_trajectories(
traj_plot,
sphere_a,
sphere_b,
save_path,
title=f"Validation Trajectories (step {step})",
)
ratio = float(inside.float().mean().item())
logger.log_image(
"val/trajectory",
save_path,
caption=f"step {step} | inside_ratio={ratio:.3f}",
step=step,
)
model.train()
def train(cfg: TrainConfig) -> tuple[ASMamba, tuple[Tensor, Tensor], tuple[Tensor, Tensor]]:
device = torch.device(cfg.device if torch.cuda.is_available() else "cpu")
set_seed(cfg.seed)
output_dir = Path(cfg.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
model = ASMamba(cfg).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
logger = SwanLogger(cfg)
sphere_a, sphere_b = sample_sphere_params(cfg, device)
center_a, radius_a = sphere_a
center_b, radius_b = sphere_b
center_dist = torch.norm(center_a - center_b).item()
logger.log(
{
"sphere_a/radius": float(radius_a.item()),
"sphere_b/radius": float(radius_b.item()),
"sphere_a/center_x": float(center_a[0].item()),
"sphere_a/center_y": float(center_a[1].item()),
"sphere_a/center_z": float(center_a[2].item()),
"sphere_b/center_x": float(center_b[0].item()),
"sphere_b/center_y": float(center_b[1].item()),
"sphere_b/center_z": float(center_b[2].item()),
"sphere/center_dist": float(center_dist),
} }
) losses["total"] = cfg.lambda_flow * losses["flow"] + cfg.lambda_perceptual * losses[
"perceptual"
global_step = 0 ]
for epoch in range(cfg.epochs): return losses, x_pred
warmup = epoch < cfg.warmup_epochs
model.train()
for p in model.dt_head.parameters():
p.requires_grad = not warmup
for _ in range(cfg.steps_per_epoch):
x0, x1, x_seq, t_seq = sample_batch(cfg, sphere_a, sphere_b, device)
v_gt = x1 - x0
delta, dt, _ = model(x_seq)
if warmup:
dt = torch.full_like(dt, 1.0 / cfg.seq_len)
flow_loss, pos_loss, nfe_loss = compute_losses(
delta=delta,
dt=dt,
x_seq=x_seq,
x0=x0,
v_gt=v_gt,
t_seq=t_seq,
cfg=cfg,
)
loss = cfg.lambda_flow * flow_loss + cfg.lambda_pos * pos_loss
if not warmup:
loss = loss + cfg.lambda_nfe * nfe_loss
optimizer.zero_grad(set_to_none=True)
loss.backward()
optimizer.step()
if global_step % 10 == 0:
logger.log(
{
"loss/total": float(loss.item()),
"loss/flow": float(flow_loss.item()),
"loss/pos": float(pos_loss.item()),
"loss/nfe": float(nfe_loss.item()),
"dt/mean": float(dt.mean().item()),
"dt/min": float(dt.min().item()),
"dt/max": float(dt.max().item()),
"stage": 0 if warmup else 1,
},
step=global_step,
)
if cfg.val_every > 0 and global_step > 0 and global_step % cfg.val_every == 0:
validate(model, cfg, sphere_a, sphere_b, device, logger, global_step)
global_step += 1
logger.finish()
return model, sphere_a, sphere_b
-def rollout_trajectory(
-    model: ASMamba,
-    x0: Tensor,
-    max_steps: int = 100,
-) -> Tensor:
-    device = x0.device
-    model.eval()
-    h = model.init_cache(batch_size=x0.shape[0], device=device)
-    x = x0
-    total_time = torch.zeros(x0.shape[0], device=device)
-    traj = [x0.detach().cpu()]
-    with torch.no_grad():
-        for _ in range(max_steps):
-            delta, dt, h = model.step(x, h)
-            dt = torch.clamp(dt, min=model.dt_min, max=model.dt_max)
-            remaining = 1.0 - total_time
-            overshoot = dt > remaining
-            if overshoot.any():
-                scale = (remaining / dt).unsqueeze(-1)
-                delta = torch.where(overshoot.unsqueeze(-1), delta * scale, delta)
-                dt = torch.where(overshoot, remaining, dt)
-            x = x + delta
-            total_time = total_time + dt
-            traj.append(x.detach().cpu())
-            if torch.all(total_time >= 1.0 - 1e-6):
-                break
-    return torch.stack(traj, dim=1)
+def make_grid(images: Tensor, nrow: int) -> np.ndarray:
+    images = images.detach().cpu().numpy()
+    b, c, h, w = images.shape
+    nrow = max(1, min(nrow, b))
+    ncol = math.ceil(b / nrow)
+    grid = np.zeros((c, ncol * h, nrow * w), dtype=np.float32)
+    for idx in range(b):
+        r = idx // nrow
+        cidx = idx % nrow
+        grid[:, r * h : (r + 1) * h, cidx * w : (cidx + 1) * w] = images[idx]
+    return np.transpose(grid, (1, 2, 0))
-def sphere_wireframe(
-    center: Tensor, radius: float, u_steps: int = 24, v_steps: int = 12
-) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
-    center_np = center.detach().cpu().numpy()
-    u = np.linspace(0, 2 * np.pi, u_steps)
-    v = np.linspace(0, np.pi, v_steps)
-    x = center_np[0] + radius * np.outer(np.cos(u), np.sin(v))
-    y = center_np[1] + radius * np.outer(np.sin(u), np.sin(v))
-    z = center_np[2] + radius * np.outer(np.ones_like(u), np.cos(v))
-    return x, y, z
-def plot_trajectories(
-    traj: Tensor,
-    sphere_a: tuple[Tensor, Tensor],
-    sphere_b: tuple[Tensor, Tensor],
-    save_path: Path,
-    title: str = "AS-Mamba Trajectories",
-) -> None:
-    traj = traj.detach().cpu()
-    if traj.dim() == 2:
-        traj = traj.unsqueeze(0)
-    traj_np = traj.numpy()
-    fig = plt.figure(figsize=(7, 6))
-    ax = fig.add_subplot(111, projection="3d")
-    for i in range(traj_np.shape[0]):
-        ax.plot(
-            traj_np[i, :, 0],
-            traj_np[i, :, 1],
-            traj_np[i, :, 2],
-            color="green",
-            alpha=0.6,
-        )
-    starts = traj_np[:, 0, :]
-    ends = traj_np[:, -1, :]
-    ax.scatter(starts[:, 0], starts[:, 1], starts[:, 2], color="blue", s=20, label="Start")
-    ax.scatter(ends[:, 0], ends[:, 1], ends[:, 2], color="red", s=20, label="End")
-    center_a, radius_a = sphere_a
-    center_b, radius_b = sphere_b
-    x_a, y_a, z_a = sphere_wireframe(center_a, float(radius_a.item()))
-    x_b, y_b, z_b = sphere_wireframe(center_b, float(radius_b.item()))
-    ax.plot_wireframe(x_a, y_a, z_a, color="blue", alpha=0.15, linewidth=0.5)
-    ax.plot_wireframe(x_b, y_b, z_b, color="red", alpha=0.15, linewidth=0.5)
-    ax.set_title(title)
-    ax.set_xlabel("X")
-    ax.set_ylabel("Y")
-    ax.set_zlabel("Z")
-    ax.legend(loc="best")
-    fig.tight_layout()
-    fig.savefig(save_path, dpi=160)
-    plt.close(fig)
+def save_image_grid(
+    images: Tensor, save_path: Path, nrow: int, title: str | None = None
+) -> None:
+    images = images.clamp(-1.0, 1.0)
+    images = (images + 1.0) / 2.0
+    grid = make_grid(images, nrow=nrow)
+    if grid.ndim == 3 and grid.shape[2] == 1:
+        grid = np.repeat(grid, 3, axis=2)
+    plt.imsave(save_path, grid)
+    if title is not None:
+        fig, ax = plt.subplots(figsize=(4, 3))
+        ax.imshow(grid)
+        ax.set_title(title)
+        ax.axis("off")
+        fig.tight_layout()
+        fig.savefig(save_path, dpi=160)
+        plt.close(fig)
def run_training_and_plot(cfg: TrainConfig) -> Path: def sample_class_images(
model, sphere_a, sphere_b = train(cfg) model: ASMamba,
device = next(model.parameters()).device cfg: TrainConfig,
device: torch.device,
cond: Tensor,
) -> Tensor:
model.eval()
input_dim = cfg.channels * cfg.image_size * cfg.image_size
z_t = torch.randn(cond.shape[0], input_dim, device=device)
time_grid = torch.tensor(FIXED_VAL_TIME_GRID, device=device)
plot_samples = max(1, cfg.val_plot_samples) with torch.no_grad():
x0 = sample_points_in_sphere( for step_idx in range(FIXED_VAL_SAMPLING_STEPS):
sphere_a[0], float(sphere_a[1].item()), plot_samples, device t_cur = torch.full(
(cond.shape[0],),
float(time_grid[step_idx].item()),
device=device,
) )
traj = rollout_trajectory(model, x0, max_steps=cfg.val_max_steps) t_next = time_grid[step_idx + 1]
x_pred, _ = model(
z_t.unsqueeze(1),
t_cur.unsqueeze(1),
t_cur.unsqueeze(1),
cond,
)
x_pred = x_pred[:, 0, :]
u_inst = (z_t - x_pred) / safe_time_divisor(t_cur).unsqueeze(-1)
z_t = z_t + (t_next - t_cur).unsqueeze(-1) * u_inst
return z_t.view(cond.shape[0], cfg.channels, cfg.image_size, cfg.image_size)
def log_class_samples(
model: ASMamba,
cfg: TrainConfig,
device: torch.device,
logger: SwanLogger,
step: int,
) -> None:
if cfg.val_samples_per_class <= 0:
return
training_mode = model.training
model.eval()
for cls in range(cfg.num_classes):
cond = torch.full(
(cfg.val_samples_per_class,), cls, device=device, dtype=torch.long
)
x_final = sample_class_images(model, cfg, device, cond)
save_path = Path(cfg.output_dir) / f"val_class_{cls}_step_{step:06d}.png"
save_image_grid(x_final, save_path, nrow=cfg.val_grid_rows)
logger.log_image(
f"val/class_{cls}",
save_path,
caption=f"class {cls} step {step}",
step=step,
)
if training_mode:
model.train()
def build_perceptual_loss(
cfg: TrainConfig, device: torch.device, rank: int, use_ddp: bool
) -> LPIPSPerceptualLoss:
if use_ddp and rank != 0:
torch.distributed.barrier()
perceptual_loss_fn = LPIPSPerceptualLoss(cfg).to(device)
if use_ddp and rank == 0:
torch.distributed.barrier()
return perceptual_loss_fn
def train(cfg: TrainConfig) -> ASMamba:
validate_config(cfg)
use_ddp, rank, world_size, device = setup_distributed(cfg)
del world_size
set_seed(cfg.seed + rank)
output_dir = Path(cfg.output_dir) output_dir = Path(cfg.output_dir)
save_path = output_dir / "as_mamba_trajectory.png" if rank == 0:
plot_trajectories(traj, sphere_a, sphere_b, save_path) output_dir.mkdir(parents=True, exist_ok=True)
return save_path
    model = ASMamba(cfg).to(device)
    if use_ddp:
        model = nn.parallel.DistributedDataParallel(model, device_ids=[device.index])
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay
    )
    perceptual_loss_fn = build_perceptual_loss(cfg, device, rank, use_ddp)
    logger = SwanLogger(cfg, enabled=(rank == 0))
    loader, sampler = build_dataloader(cfg, distributed=use_ddp)
    loader_iter = infinite_loader(loader)
    global_step = 0
    for epoch_idx in range(cfg.epochs):
        if sampler is not None:
            sampler.set_epoch(epoch_idx)
        model.train()
        for _ in range(cfg.steps_per_epoch):
            batch = next(loader_iter)
            x0 = batch["pixel_values"].to(device)
            cond = batch["labels"].to(device)
            b = x0.shape[0]
            x0 = x0.view(b, -1)
            eps = torch.randn_like(x0)
            r_seq, t_seq = sample_block_times(cfg, b, device, x0.dtype)
            z_t, v_gt = build_noisy_sequence(x0, eps, t_seq)
            losses, _ = compute_losses(
                model=model,
                perceptual_loss_fn=perceptual_loss_fn,
                x0=x0,
                z_t=z_t,
                v_gt=v_gt,
                r_seq=r_seq,
                t_seq=t_seq,
                cond=cond,
                cfg=cfg,
            )
            optimizer.zero_grad(set_to_none=True)
            losses["total"].backward()
            grad_norm = torch.nn.utils.clip_grad_norm_(
                model.parameters(), max_norm=cfg.max_grad_norm
            )
            optimizer.step()
            if global_step % 10 == 0 and rank == 0:
                logger.log(
                    {
                        "loss/total": float(losses["total"].item()),
                        "loss/flow": float(losses["flow"].item()),
                        "loss/perceptual": float(losses["perceptual"].item()),
                        "grad/total_norm": float(grad_norm.item()),
                        "time/r_mean": float(r_seq.mean().item()),
                        "time/t_mean": float(t_seq.mean().item()),
                        "time/zero_block_frac": float((t_seq == r_seq).float().mean().item()),
                    },
                    step=global_step,
                )
            if (
                cfg.val_every > 0
                and global_step > 0
                and global_step % cfg.val_every == 0
                and rank == 0
            ):
                log_class_samples(unwrap_model(model), cfg, device, logger, global_step)
            global_step += 1
    logger.finish()
    if use_ddp:
        torch.distributed.destroy_process_group()
    return unwrap_model(model)


def run_training_and_plot(cfg: TrainConfig) -> Path:
    train(cfg)
    return Path(cfg.output_dir)
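The `sample_block_times` and `build_noisy_sequence` helpers called in the loop above are defined elsewhere in `as_mamba.py` and are not part of this hunk. A minimal pure-Python sketch of the standard conditional-flow-matching construction they appear to implement — the linear interpolation path and the ordered `(r, t)` pairs on a discrete time grid are assumptions, not code from this diff:

```python
import random

def sample_block_times(batch_size: int, grid_size: int = 256) -> list[tuple[float, float]]:
    """Sample (r, t) pairs on a discrete time grid with r <= t.

    Hypothetical re-implementation of the MeanFlow-style convention:
    r is the earlier time, t the later one, and r == t degenerates to
    an instantaneous-velocity block (cf. the time/zero_block_frac metric).
    """
    pairs = []
    for _ in range(batch_size):
        a = random.randrange(grid_size + 1) / grid_size
        b = random.randrange(grid_size + 1) / grid_size
        pairs.append((min(a, b), max(a, b)))
    return pairs

def interpolate(x0: list[float], eps: list[float], t: float) -> tuple[list[float], list[float]]:
    """Linear conditional-flow-matching path: z_t = (1 - t) * x0 + t * eps.

    Along this path the ground-truth velocity is constant: v = eps - x0.
    """
    z_t = [(1.0 - t) * a + t * b for a, b in zip(x0, eps)]
    v_gt = [b - a for a, b in zip(x0, eps)]
    return z_t, v_gt
```

In the actual training code these operate on batched tensors rather than lists, and the exact noise-vs-data orientation of the path may differ; the sketch only fixes the shape of the construction.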

main.py
@@ -4,25 +4,40 @@ from as_mamba import TrainConfig, run_training_and_plot
 def build_parser() -> argparse.ArgumentParser:
-    parser = argparse.ArgumentParser(description="Train AS-Mamba on sphere-to-sphere flow.")
+    parser = argparse.ArgumentParser(description="Train AS-Mamba on MNIST MeanFlow x-prediction.")
     parser.add_argument("--epochs", type=int, default=None)
+    parser.add_argument("--warmup-epochs", type=int, default=None)
+    parser.add_argument("--batch-size", type=int, default=None)
     parser.add_argument("--steps-per-epoch", type=int, default=None)
-    parser.add_argument("--batch-size", type=int, default=None)
     parser.add_argument("--seq-len", type=int, default=None)
     parser.add_argument("--lr", type=float, default=None)
+    parser.add_argument("--weight-decay", type=float, default=None)
+    parser.add_argument("--max-grad-norm", type=float, default=None)
     parser.add_argument("--device", type=str, default=None)
     parser.add_argument("--output-dir", type=str, default=None)
     parser.add_argument("--project", type=str, default=None)
     parser.add_argument("--run-name", type=str, default=None)
+    parser.add_argument("--lambda-flow", type=float, default=None)
+    parser.add_argument("--lambda-perceptual", type=float, default=None)
+    parser.add_argument("--num-classes", type=int, default=None)
+    parser.add_argument("--image-size", type=int, default=None)
+    parser.add_argument("--channels", type=int, default=None)
+    parser.add_argument("--num-workers", type=int, default=None)
+    parser.add_argument("--dataset-name", type=str, default=None)
+    parser.add_argument("--dataset-split", type=str, default=None)
+    parser.add_argument("--d-model", type=int, default=None)
+    parser.add_argument("--n-layer", type=int, default=None)
+    parser.add_argument("--d-state", type=int, default=None)
+    parser.add_argument("--d-conv", type=int, default=None)
+    parser.add_argument("--expand", type=int, default=None)
+    parser.add_argument("--headdim", type=int, default=None)
+    parser.add_argument("--chunk-size", type=int, default=None)
+    parser.add_argument("--use-residual", action=argparse.BooleanOptionalAction, default=None)
     parser.add_argument("--val-every", type=int, default=None)
-    parser.add_argument("--val-samples", type=int, default=None)
-    parser.add_argument("--val-plot-samples", type=int, default=None)
-    parser.add_argument("--val-max-steps", type=int, default=None)
-    parser.add_argument("--center-min", type=float, default=None)
-    parser.add_argument("--center-max", type=float, default=None)
-    parser.add_argument("--center-distance-min", type=float, default=None)
-    parser.add_argument("--use-residual", action="store_true")
+    parser.add_argument("--val-samples-per-class", type=int, default=None)
+    parser.add_argument("--val-grid-rows", type=int, default=None)
+    parser.add_argument("--val-sampling-steps", type=int, default=None)
+    parser.add_argument("--time-grid-size", type=int, default=None)
+    parser.add_argument("--use-ddp", action=argparse.BooleanOptionalAction, default=None)
     return parser
@@ -35,8 +50,8 @@ def main() -> None:
         if value is not None:
             setattr(cfg, key.replace("-", "_"), value)
-    plot_path = run_training_and_plot(cfg)
-    print(f"Saved trajectory plot to {plot_path}")
+    out_path = run_training_and_plot(cfg)
+    print(f"Saved outputs to {out_path}")
 if __name__ == "__main__":
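The switch from `action="store_true"` to `argparse.BooleanOptionalAction` is what makes the paired `--use-residual` / `--no-use-residual` flags in `train_as_mamba.sh` work, while `default=None` preserves the "flag not passed" sentinel that the config-merging loop in `main()` relies on. A small standalone illustration:

```python
import argparse

parser = argparse.ArgumentParser()
# BooleanOptionalAction registers both --use-residual and --no-use-residual;
# default=None lets main() distinguish "not passed" from an explicit choice.
parser.add_argument("--use-residual", action=argparse.BooleanOptionalAction, default=None)

print(parser.parse_args(["--use-residual"]).use_residual)     # True
print(parser.parse_args(["--no-use-residual"]).use_residual)  # False
print(parser.parse_args([]).use_residual)                     # None
```

With a plain `store_true` action there is no generated negative flag and the default is `False`, so an unset flag would silently override a `TrainConfig` default of `True`.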
pyproject.toml
@@ -5,9 +5,12 @@ description = "Add your description here"
 readme = "README.md"
 requires-python = ">=3.12"
 dependencies = [
+    "datasets>=2.19.0",
     "einops>=0.7.0",
+    "lpips>=0.1.4",
     "matplotlib>=3.8.0",
     "numpy>=1.26.0",
     "swanlab>=0.5.0",
     "torch>=2.2.0",
+    "torchvision>=0.24.1",
 ]

train_as_mamba.sh (executable file)

@@ -0,0 +1,74 @@
#!/usr/bin/env bash
set -euo pipefail
DEVICE="cuda"
EPOCHS=2000
STEPS_PER_EPOCH=200
BATCH_SIZE=512
SEQ_LEN=5
LR=1e-3
WEIGHT_DECAY=1e-2
LAMBDA_FLOW=1.0
LAMBDA_PERCEPTUAL=0.4
NUM_CLASSES=10
IMAGE_SIZE=28
CHANNELS=1
NUM_WORKERS=32
DATASET_NAME="ylecun/mnist"
DATASET_SPLIT="train"
D_MODEL=784
N_LAYER=8
D_STATE=32
D_CONV=4
EXPAND=2
HEADDIM=32
CHUNK_SIZE=1
USE_RESIDUAL=true
USE_DDP=true
VAL_EVERY=1000
VAL_SAMPLES_PER_CLASS=8
VAL_GRID_ROWS=4
VAL_SAMPLING_STEPS=5
TIME_GRID_SIZE=256
PROJECT="as-mamba-mnist"
RUN_NAME="mnist-meanflow-xpred"
OUTPUT_DIR="outputs"
USE_RESIDUAL_FLAG="--use-residual"
if [ "${USE_RESIDUAL}" = "false" ]; then USE_RESIDUAL_FLAG="--no-use-residual"; fi
USE_DDP_FLAG="--use-ddp"
if [ "${USE_DDP}" = "false" ]; then USE_DDP_FLAG="--no-use-ddp"; fi
uv run torchrun --nproc_per_node=2 main.py \
--device "${DEVICE}" \
--epochs "${EPOCHS}" \
--steps-per-epoch "${STEPS_PER_EPOCH}" \
--batch-size "${BATCH_SIZE}" \
--seq-len "${SEQ_LEN}" \
--lr "${LR}" \
--weight-decay "${WEIGHT_DECAY}" \
--lambda-flow "${LAMBDA_FLOW}" \
--lambda-perceptual "${LAMBDA_PERCEPTUAL}" \
--num-classes "${NUM_CLASSES}" \
--image-size "${IMAGE_SIZE}" \
--channels "${CHANNELS}" \
--num-workers "${NUM_WORKERS}" \
--dataset-name "${DATASET_NAME}" \
--dataset-split "${DATASET_SPLIT}" \
--d-model "${D_MODEL}" \
--n-layer "${N_LAYER}" \
--d-state "${D_STATE}" \
--d-conv "${D_CONV}" \
--expand "${EXPAND}" \
--headdim "${HEADDIM}" \
--chunk-size "${CHUNK_SIZE}" \
${USE_RESIDUAL_FLAG} \
${USE_DDP_FLAG} \
--val-every "${VAL_EVERY}" \
--val-samples-per-class "${VAL_SAMPLES_PER_CLASS}" \
--val-grid-rows "${VAL_GRID_ROWS}" \
--val-sampling-steps "${VAL_SAMPLING_STEPS}" \
--time-grid-size "${TIME_GRID_SIZE}" \
--project "${PROJECT}" \
--run-name "${RUN_NAME}" \
--output-dir "${OUTPUT_DIR}"
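`torchrun --nproc_per_node=2` spawns one Python process per GPU and hands each worker its identity through environment variables. A minimal sketch of how a worker would read them — this is a hypothetical helper, the repo's actual DDP setup in `as_mamba.py` is not shown in this diff:

```python
import os

def ddp_env() -> tuple[int, int, int]:
    """Read the per-worker environment that torchrun sets.

    RANK is the global worker index, LOCAL_RANK indexes GPUs on this
    node (used for device placement), and WORLD_SIZE is the total
    worker count. The defaults keep the same code working for plain
    single-process runs without torchrun.
    """
    rank = int(os.environ.get("RANK", "0"))
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    return rank, local_rank, world_size
```

This is also why rank-0 guards (`if rank == 0:`) appear around logging, validation, and filesystem writes in the training loop: every worker runs the same script, so side effects must be restricted to one process.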

uv.lock (generated)

File diff suppressed because it is too large