Compare commits
7 Commits
no_dt_pred
...
mamba_mean
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5897a0afd1 | ||
|
|
9b2968997c | ||
|
|
01fc1e4eab | ||
|
|
913740266b | ||
|
|
444f5fc109 | ||
|
|
c15115edc4 | ||
|
|
cac3236f9d |
27
AGENTS.md
Normal file
27
AGENTS.md
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
# Repository Guidelines
|
||||||
|
|
||||||
|
## Project Structure & Module Organization
|
||||||
|
This repository is a flat Python training project, not a packaged library. `main.py` is the CLI entry point. `as_mamba.py` holds `TrainConfig`, dataset loading, the training loop, validation image generation, and SwanLab logging. `mamba2_minimal.py` contains the standalone Mamba-2 building blocks. `train_as_mamba.sh` is the multi-GPU launcher. Runtime artifacts land in `outputs/` and `swanlog/`; treat both as generated data, not source.
|
||||||
|
|
||||||
|
## Build, Test, and Development Commands
|
||||||
|
Use `uv` with the committed lockfile.
|
||||||
|
|
||||||
|
- `uv sync` installs the Python 3.12 environment from `pyproject.toml` and `uv.lock`.
|
||||||
|
- `uv run python main.py --help` lists all training flags and is the fastest CLI sanity check.
|
||||||
|
- `uv run python main.py --device cpu --epochs 1 --steps-per-epoch 1 --batch-size 8 --num-workers 0 --val-every 0` runs a minimal local smoke test.
|
||||||
|
- `bash train_as_mamba.sh` launches the default distributed training job with `torchrun`; adjust `--nproc_per_node` and shell variables before using it on a new machine.
|
||||||
|
- `uv run python -m compileall main.py as_mamba.py mamba2_minimal.py` performs a lightweight syntax check when no test suite is available.
|
||||||
|
|
||||||
|
## Coding Style & Naming Conventions
|
||||||
|
Follow the existing Python style: 4-space indentation, type hints on public functions, `dataclass`-based config, and small helper functions over deeply nested logic. Use `snake_case` for functions, variables, and CLI flags; use `PascalCase` for classes such as `ASMamba` and `TrainConfig`. Keep new modules top-level unless the project is restructured. There is no configured formatter or linter yet, so match surrounding code closely and keep imports grouped and readable.
|
||||||
|
|
||||||
|
## Testing Guidelines
|
||||||
|
There is no dedicated `tests/` directory or pytest setup yet. For any change, run at minimum:
|
||||||
|
|
||||||
|
- `uv run python -m compileall main.py as_mamba.py mamba2_minimal.py`
|
||||||
|
- one short training smoke run with reduced epochs and workers
|
||||||
|
|
||||||
|
If you add reusable logic, prefer extracting pure functions so a future pytest suite can cover them easily. Name any new test files `test_<module>.py`.
|
||||||
|
|
||||||
|
## Commit & Pull Request Guidelines
|
||||||
|
Recent history favors short imperative commit subjects, usually Conventional Commit style: `feat: ...`, `fix: ...`, and scoped variants like `feat(mamba): ...`. Keep commits focused on one training or infrastructure change. Pull requests should describe the config or behavior change, list the verification commands you ran, and include sample images or metric notes when outputs in `outputs/` materially change. Do not commit `.venv/`, `__pycache__/`, or large generated artifacts.
|
||||||
753
as_mamba.py
753
as_mamba.py
@@ -1,9 +1,14 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import math
|
||||||
|
import os
|
||||||
from dataclasses import asdict, dataclass
|
from dataclasses import asdict, dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Iterator, Optional
|
||||||
|
|
||||||
|
import lpips
|
||||||
|
|
||||||
|
os.environ.setdefault("MPLCONFIGDIR", "/tmp/mamba_diffusion_mplconfig")
|
||||||
import matplotlib
|
import matplotlib
|
||||||
|
|
||||||
matplotlib.use("Agg")
|
matplotlib.use("Agg")
|
||||||
@@ -11,36 +16,41 @@ matplotlib.use("Agg")
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
from datasets import load_dataset
|
||||||
from matplotlib import pyplot as plt
|
from matplotlib import pyplot as plt
|
||||||
from mpl_toolkits.mplot3d import Axes3D # noqa: F401
|
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
|
from torch.func import jvp
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
from torch.utils.data.distributed import DistributedSampler
|
||||||
|
|
||||||
from mamba2_minimal import InferenceCache, Mamba2, Mamba2Config, RMSNorm
|
from mamba2_minimal import InferenceCache, Mamba2, Mamba2Config, RMSNorm
|
||||||
|
|
||||||
|
|
||||||
|
FIXED_VAL_SAMPLING_STEPS = 5
|
||||||
|
FIXED_VAL_TIME_GRID = (1.0, 0.8, 0.6, 0.4, 0.2, 0.0)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TrainConfig:
|
class TrainConfig:
|
||||||
seed: int = 42
|
seed: int = 42
|
||||||
device: str = "cuda"
|
device: str = "cuda"
|
||||||
|
epochs: int = 50
|
||||||
|
steps_per_epoch: int = 200
|
||||||
batch_size: int = 128
|
batch_size: int = 128
|
||||||
steps_per_epoch: int = 50
|
seq_len: int = 5
|
||||||
epochs: int = 60
|
lr: float = 2e-4
|
||||||
warmup_epochs: int = 15
|
|
||||||
seq_len: int = 20
|
|
||||||
lr: float = 1e-3
|
|
||||||
weight_decay: float = 1e-2
|
weight_decay: float = 1e-2
|
||||||
dt_min: float = 1e-3
|
max_grad_norm: float = 1.0
|
||||||
dt_max: float = 0.06
|
|
||||||
lambda_flow: float = 1.0
|
lambda_flow: float = 1.0
|
||||||
lambda_pos: float = 1.0
|
lambda_perceptual: float = 2.0
|
||||||
lambda_nfe: float = 0.05
|
num_classes: int = 10
|
||||||
radius_min: float = 0.6
|
image_size: int = 28
|
||||||
radius_max: float = 1.4
|
channels: int = 1
|
||||||
center_min: float = -6.0
|
num_workers: int = 8
|
||||||
center_max: float = 6.0
|
dataset_name: str = "ylecun/mnist"
|
||||||
center_distance_min: float = 6.0
|
dataset_split: str = "train"
|
||||||
d_model: int = 128
|
d_model: int = 0
|
||||||
n_layer: int = 4
|
n_layer: int = 6
|
||||||
d_state: int = 64
|
d_state: int = 64
|
||||||
d_conv: int = 4
|
d_conv: int = 4
|
||||||
expand: int = 2
|
expand: int = 2
|
||||||
@@ -48,12 +58,29 @@ class TrainConfig:
|
|||||||
chunk_size: int = 1
|
chunk_size: int = 1
|
||||||
use_residual: bool = False
|
use_residual: bool = False
|
||||||
output_dir: str = "outputs"
|
output_dir: str = "outputs"
|
||||||
project: str = "as-mamba"
|
project: str = "as-mamba-mnist"
|
||||||
run_name: str = "sphere-to-sphere"
|
run_name: str = "mnist-meanflow"
|
||||||
val_every: int = 200
|
val_every: int = 200
|
||||||
val_samples: int = 256
|
val_samples_per_class: int = 8
|
||||||
val_plot_samples: int = 16
|
val_grid_rows: int = 4
|
||||||
val_max_steps: int = 100
|
val_sampling_steps: int = FIXED_VAL_SAMPLING_STEPS
|
||||||
|
time_grid_size: int = 256
|
||||||
|
use_ddp: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
class AdaLNZero(nn.Module):
|
||||||
|
def __init__(self, d_model: int) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.norm = RMSNorm(d_model)
|
||||||
|
self.mod = nn.Linear(d_model, 2 * d_model)
|
||||||
|
nn.init.zeros_(self.mod.weight)
|
||||||
|
nn.init.zeros_(self.mod.bias)
|
||||||
|
|
||||||
|
def forward(self, x: Tensor, cond: Tensor) -> Tensor:
|
||||||
|
x = self.norm(x)
|
||||||
|
params = self.mod(cond).unsqueeze(1)
|
||||||
|
scale, shift = params.chunk(2, dim=-1)
|
||||||
|
return x * (1 + scale) + shift
|
||||||
|
|
||||||
|
|
||||||
class Mamba2Backbone(nn.Module):
|
class Mamba2Backbone(nn.Module):
|
||||||
@@ -66,7 +93,7 @@ class Mamba2Backbone(nn.Module):
|
|||||||
nn.ModuleDict(
|
nn.ModuleDict(
|
||||||
dict(
|
dict(
|
||||||
mixer=Mamba2(args),
|
mixer=Mamba2(args),
|
||||||
norm=RMSNorm(args.d_model),
|
adaln=AdaLNZero(args.d_model),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
for _ in range(args.n_layer)
|
for _ in range(args.n_layer)
|
||||||
@@ -75,25 +102,56 @@ class Mamba2Backbone(nn.Module):
|
|||||||
self.norm_f = RMSNorm(args.d_model)
|
self.norm_f = RMSNorm(args.d_model)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, x: Tensor, h: Optional[list[InferenceCache]] = None
|
self, x: Tensor, cond: Tensor, h: Optional[list[InferenceCache]] = None
|
||||||
) -> tuple[Tensor, list[InferenceCache]]:
|
) -> tuple[Tensor, list[InferenceCache]]:
|
||||||
if h is None:
|
if h is None:
|
||||||
h = [None for _ in range(self.args.n_layer)]
|
h = [None for _ in range(self.args.n_layer)]
|
||||||
|
|
||||||
for i, layer in enumerate(self.layers):
|
for i, layer in enumerate(self.layers):
|
||||||
y, h[i] = layer["mixer"](layer["norm"](x), h[i])
|
x_mod = layer["adaln"](x, cond)
|
||||||
|
y, h[i] = layer["mixer"](x_mod, h[i])
|
||||||
x = x + y if self.use_residual else y
|
x = x + y if self.use_residual else y
|
||||||
|
|
||||||
x = self.norm_f(x)
|
x = self.norm_f(x)
|
||||||
return x, h
|
return x, h
|
||||||
|
|
||||||
|
|
||||||
|
def sinusoidal_embedding(t: Tensor, dim: int) -> Tensor:
|
||||||
|
if dim <= 0:
|
||||||
|
raise ValueError("sinusoidal embedding dim must be > 0")
|
||||||
|
half = dim // 2
|
||||||
|
if half == 0:
|
||||||
|
return torch.zeros(*t.shape, dim, device=t.device, dtype=t.dtype)
|
||||||
|
denom = max(half - 1, 1)
|
||||||
|
freqs = torch.exp(
|
||||||
|
-math.log(10000.0)
|
||||||
|
* torch.arange(half, device=t.device, dtype=torch.float32)
|
||||||
|
/ denom
|
||||||
|
)
|
||||||
|
args = t.to(torch.float32).unsqueeze(-1) * freqs
|
||||||
|
emb = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
|
||||||
|
if dim % 2 == 1:
|
||||||
|
pad = torch.zeros(*emb.shape[:-1], 1, device=emb.device, dtype=emb.dtype)
|
||||||
|
emb = torch.cat([emb, pad], dim=-1)
|
||||||
|
return emb.to(t.dtype)
|
||||||
|
|
||||||
|
|
||||||
|
def safe_time_divisor(t: Tensor) -> Tensor:
|
||||||
|
eps = torch.finfo(t.dtype).eps
|
||||||
|
return torch.where(t > 0, t, torch.full_like(t, eps))
|
||||||
|
|
||||||
|
|
||||||
class ASMamba(nn.Module):
|
class ASMamba(nn.Module):
|
||||||
def __init__(self, cfg: TrainConfig) -> None:
|
def __init__(self, cfg: TrainConfig) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.cfg = cfg
|
self.cfg = cfg
|
||||||
self.dt_min = float(cfg.dt_min)
|
input_dim = cfg.channels * cfg.image_size * cfg.image_size
|
||||||
self.dt_max = float(cfg.dt_max)
|
if cfg.d_model == 0:
|
||||||
|
cfg.d_model = input_dim
|
||||||
|
if cfg.d_model != input_dim:
|
||||||
|
raise ValueError(
|
||||||
|
f"d_model must equal flattened image dim ({input_dim}) when input_proj is disabled."
|
||||||
|
)
|
||||||
|
|
||||||
args = Mamba2Config(
|
args = Mamba2Config(
|
||||||
d_model=cfg.d_model,
|
d_model=cfg.d_model,
|
||||||
@@ -105,29 +163,43 @@ class ASMamba(nn.Module):
|
|||||||
chunk_size=cfg.chunk_size,
|
chunk_size=cfg.chunk_size,
|
||||||
)
|
)
|
||||||
self.backbone = Mamba2Backbone(args, use_residual=cfg.use_residual)
|
self.backbone = Mamba2Backbone(args, use_residual=cfg.use_residual)
|
||||||
self.input_proj = nn.Linear(3, cfg.d_model)
|
self.cond_emb = nn.Embedding(cfg.num_classes, cfg.d_model)
|
||||||
self.delta_head = nn.Linear(cfg.d_model, 3)
|
self.clean_head = nn.Linear(cfg.d_model, input_dim)
|
||||||
self.dt_head = nn.Sequential(
|
|
||||||
nn.Linear(cfg.d_model, cfg.d_model),
|
|
||||||
nn.SiLU(),
|
|
||||||
nn.Linear(cfg.d_model, 1),
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, x: Tensor, h: Optional[list[InferenceCache]] = None
|
self,
|
||||||
) -> tuple[Tensor, Tensor, list[InferenceCache]]:
|
z_t: Tensor,
|
||||||
x_proj = self.input_proj(x)
|
r: Tensor,
|
||||||
feats, h = self.backbone(x_proj, h)
|
t: Tensor,
|
||||||
delta = self.delta_head(feats)
|
cond: Tensor,
|
||||||
dt_raw = self.dt_head(feats).squeeze(-1)
|
h: Optional[list[InferenceCache]] = None,
|
||||||
dt = torch.clamp(F.softplus(dt_raw), min=self.dt_min, max=self.dt_max)
|
) -> tuple[Tensor, list[InferenceCache]]:
|
||||||
return delta, dt, h
|
if r.dim() == 1:
|
||||||
|
r = r.unsqueeze(1)
|
||||||
|
elif r.dim() == 3 and r.shape[-1] == 1:
|
||||||
|
r = r.squeeze(-1)
|
||||||
|
if t.dim() == 1:
|
||||||
|
t = t.unsqueeze(1)
|
||||||
|
elif t.dim() == 3 and t.shape[-1] == 1:
|
||||||
|
t = t.squeeze(-1)
|
||||||
|
|
||||||
|
r = r.to(dtype=z_t.dtype)
|
||||||
|
t = t.to(dtype=z_t.dtype)
|
||||||
|
z_t = z_t + sinusoidal_embedding(r, z_t.shape[-1]) + sinusoidal_embedding(
|
||||||
|
t, z_t.shape[-1]
|
||||||
|
)
|
||||||
|
cond_vec = self.cond_emb(cond)
|
||||||
|
feats, h = self.backbone(z_t, cond_vec, h)
|
||||||
|
x_pred = torch.tanh(self.clean_head(feats))
|
||||||
|
return x_pred, h
|
||||||
|
|
||||||
def step(
|
def step(
|
||||||
self, x: Tensor, h: list[InferenceCache]
|
self, z_t: Tensor, r: Tensor, t: Tensor, cond: Tensor, h: list[InferenceCache]
|
||||||
) -> tuple[Tensor, Tensor, list[InferenceCache]]:
|
) -> tuple[Tensor, list[InferenceCache]]:
|
||||||
delta, dt, h = self.forward(x.unsqueeze(1), h)
|
x_pred, h = self.forward(
|
||||||
return delta[:, 0, :], dt[:, 0], h
|
z_t.unsqueeze(1), r.unsqueeze(1), t.unsqueeze(1), cond, h
|
||||||
|
)
|
||||||
|
return x_pred[:, 0, :], h
|
||||||
|
|
||||||
def init_cache(self, batch_size: int, device: torch.device) -> list[InferenceCache]:
|
def init_cache(self, batch_size: int, device: torch.device) -> list[InferenceCache]:
|
||||||
return [
|
return [
|
||||||
@@ -137,10 +209,12 @@ class ASMamba(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class SwanLogger:
|
class SwanLogger:
|
||||||
def __init__(self, cfg: TrainConfig) -> None:
|
def __init__(self, cfg: TrainConfig, enabled: bool = True) -> None:
|
||||||
self.enabled = False
|
self.enabled = enabled
|
||||||
self._swan = None
|
self._swan = None
|
||||||
self._run = None
|
self._run = None
|
||||||
|
if not self.enabled:
|
||||||
|
return
|
||||||
try:
|
try:
|
||||||
import swanlab # type: ignore
|
import swanlab # type: ignore
|
||||||
|
|
||||||
@@ -169,7 +243,11 @@ class SwanLogger:
|
|||||||
target.log(payload)
|
target.log(payload)
|
||||||
|
|
||||||
def log_image(
|
def log_image(
|
||||||
self, key: str, image_path: Path, caption: str | None = None, step: int | None = None
|
self,
|
||||||
|
key: str,
|
||||||
|
image_path: Path,
|
||||||
|
caption: str | None = None,
|
||||||
|
step: int | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
if not self.enabled:
|
if not self.enabled:
|
||||||
return
|
return
|
||||||
@@ -188,324 +266,391 @@ class SwanLogger:
|
|||||||
finish()
|
finish()
|
||||||
|
|
||||||
|
|
||||||
|
class LPIPSPerceptualLoss(nn.Module):
|
||||||
|
def __init__(self, cfg: TrainConfig) -> None:
|
||||||
|
super().__init__()
|
||||||
|
torch_home = Path(cfg.output_dir) / ".torch"
|
||||||
|
torch_home.mkdir(parents=True, exist_ok=True)
|
||||||
|
os.environ["TORCH_HOME"] = str(torch_home)
|
||||||
|
self.channels = cfg.channels
|
||||||
|
self.loss_fn = lpips.LPIPS(net="vgg", verbose=False)
|
||||||
|
self.loss_fn.eval()
|
||||||
|
for param in self.loss_fn.parameters():
|
||||||
|
param.requires_grad_(False)
|
||||||
|
|
||||||
|
def _prepare_images(self, images: Tensor) -> Tensor:
|
||||||
|
if images.shape[1] == 1:
|
||||||
|
return images.repeat(1, 3, 1, 1)
|
||||||
|
if images.shape[1] != 3:
|
||||||
|
raise ValueError(
|
||||||
|
"LPIPS perceptual loss expects 1-channel or 3-channel images. "
|
||||||
|
f"Got {images.shape[1]} channels."
|
||||||
|
)
|
||||||
|
return images
|
||||||
|
|
||||||
|
def forward(self, pred: Tensor, target: Tensor) -> Tensor:
|
||||||
|
pred_rgb = self._prepare_images(pred)
|
||||||
|
target_rgb = self._prepare_images(target)
|
||||||
|
return self.loss_fn(pred_rgb, target_rgb).mean()
|
||||||
|
|
||||||
|
|
||||||
def set_seed(seed: int) -> None:
|
def set_seed(seed: int) -> None:
|
||||||
torch.manual_seed(seed)
|
torch.manual_seed(seed)
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
torch.cuda.manual_seed_all(seed)
|
torch.cuda.manual_seed_all(seed)
|
||||||
|
|
||||||
|
|
||||||
def sample_points_in_sphere(
|
def setup_distributed(cfg: TrainConfig) -> tuple[bool, int, int, torch.device]:
|
||||||
center: Tensor, radius: float, batch_size: int, device: torch.device
|
world_size = int(os.environ.get("WORLD_SIZE", "1"))
|
||||||
) -> Tensor:
|
rank = int(os.environ.get("RANK", "0"))
|
||||||
direction = torch.randn(batch_size, 3, device=device)
|
local_rank = int(os.environ.get("LOCAL_RANK", "0"))
|
||||||
direction = direction / (direction.norm(dim=-1, keepdim=True) + 1e-8)
|
use_ddp = cfg.use_ddp and world_size > 1
|
||||||
u = torch.rand(batch_size, 1, device=device)
|
if use_ddp:
|
||||||
r = radius * torch.pow(u, 1.0 / 3.0)
|
torch.distributed.init_process_group(backend="nccl", init_method="env://")
|
||||||
return center + direction * r
|
torch.cuda.set_device(local_rank)
|
||||||
|
device = torch.device("cuda", local_rank)
|
||||||
|
else:
|
||||||
|
device = torch.device(cfg.device if torch.cuda.is_available() else "cpu")
|
||||||
|
return use_ddp, rank, world_size, device
|
||||||
|
|
||||||
|
|
||||||
def sample_sphere_params(cfg: TrainConfig, device: torch.device) -> tuple[Tensor, Tensor]:
|
def unwrap_model(model: nn.Module) -> nn.Module:
|
||||||
center_a = torch.empty(3, device=device).uniform_(cfg.center_min, cfg.center_max)
|
return model.module if hasattr(model, "module") else model
|
||||||
center_b = torch.empty(3, device=device).uniform_(cfg.center_min, cfg.center_max)
|
|
||||||
for _ in range(128):
|
|
||||||
if torch.norm(center_a - center_b) >= cfg.center_distance_min:
|
def validate_config(cfg: TrainConfig) -> None:
|
||||||
break
|
if cfg.seq_len != 5:
|
||||||
center_b = torch.empty(3, device=device).uniform_(cfg.center_min, cfg.center_max)
|
raise ValueError(
|
||||||
if torch.norm(center_a - center_b) < 1e-3:
|
f"seq_len must be 5 for the required 5-block training setup (got {cfg.seq_len})."
|
||||||
center_b = center_b + torch.tensor([cfg.center_distance_min, 0.0, 0.0], device=device)
|
)
|
||||||
radius_a = float(torch.empty(1).uniform_(cfg.radius_min, cfg.radius_max).item())
|
if cfg.time_grid_size < 2:
|
||||||
radius_b = float(torch.empty(1).uniform_(cfg.radius_min, cfg.radius_max).item())
|
raise ValueError("time_grid_size must be >= 2.")
|
||||||
return (center_a, torch.tensor(radius_a, device=device)), (
|
if cfg.lambda_perceptual < 0:
|
||||||
center_b,
|
raise ValueError("lambda_perceptual must be >= 0.")
|
||||||
torch.tensor(radius_b, device=device),
|
if cfg.max_grad_norm <= 0:
|
||||||
|
raise ValueError("max_grad_norm must be > 0.")
|
||||||
|
if cfg.val_sampling_steps != FIXED_VAL_SAMPLING_STEPS:
|
||||||
|
raise ValueError(
|
||||||
|
f"val_sampling_steps is fixed to {FIXED_VAL_SAMPLING_STEPS} for validation sampling."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def sample_block_times(
|
||||||
|
cfg: TrainConfig, batch_size: int, device: torch.device, dtype: torch.dtype
|
||||||
|
) -> tuple[Tensor, Tensor]:
|
||||||
|
num_internal = cfg.seq_len - 1
|
||||||
|
normal = torch.randn(batch_size, num_internal, device=device, dtype=dtype)
|
||||||
|
logit_normal = torch.sigmoid(normal * math.sqrt(0.8))
|
||||||
|
uniform = torch.rand(batch_size, num_internal, device=device, dtype=dtype)
|
||||||
|
use_uniform = torch.rand(batch_size, num_internal, device=device) < 0.1
|
||||||
|
cuts = torch.where(use_uniform, uniform, logit_normal)
|
||||||
|
cuts, _ = torch.sort(cuts, dim=-1)
|
||||||
|
boundaries = torch.cat(
|
||||||
|
[
|
||||||
|
torch.zeros(batch_size, 1, device=device, dtype=dtype),
|
||||||
|
cuts,
|
||||||
|
torch.ones(batch_size, 1, device=device, dtype=dtype),
|
||||||
|
],
|
||||||
|
dim=-1,
|
||||||
)
|
)
|
||||||
|
return boundaries[:, :-1], boundaries[:, 1:]
|
||||||
|
|
||||||
|
|
||||||
def sample_batch(
|
def build_dataloader(
|
||||||
cfg: TrainConfig,
|
cfg: TrainConfig, distributed: bool = False
|
||||||
sphere_a: tuple[Tensor, Tensor],
|
) -> tuple[DataLoader, Optional[DistributedSampler]]:
|
||||||
sphere_b: tuple[Tensor, Tensor],
|
ds = load_dataset(cfg.dataset_name, split=cfg.dataset_split)
|
||||||
device: torch.device,
|
|
||||||
) -> tuple[Tensor, Tensor, Tensor, Tensor]:
|
def transform(example):
|
||||||
center_a, radius_a = sphere_a
|
image = example.get("img", example.get("image"))
|
||||||
center_b, radius_b = sphere_b
|
label = example.get("label", example.get("labels"))
|
||||||
x0 = sample_points_in_sphere(center_a, float(radius_a.item()), cfg.batch_size, device)
|
if isinstance(image, list):
|
||||||
x1 = sample_points_in_sphere(center_b, float(radius_b.item()), cfg.batch_size, device)
|
arr = np.stack([np.array(im, dtype=np.float32) for im in image], axis=0)
|
||||||
v_gt = x1 - x0
|
arr = arr / 127.5 - 1.0
|
||||||
dt_fixed = 1.0 / cfg.seq_len
|
if arr.ndim == 3:
|
||||||
t_seq = torch.arange(cfg.seq_len, device=device) * dt_fixed
|
tensor = torch.from_numpy(arr).unsqueeze(1)
|
||||||
x_seq = x0[:, None, :] + t_seq[None, :, None] * v_gt[:, None, :]
|
else:
|
||||||
return x0, x1, x_seq, t_seq
|
tensor = torch.from_numpy(arr).permute(0, 3, 1, 2)
|
||||||
|
labels = torch.tensor(label, dtype=torch.long)
|
||||||
|
return {"pixel_values": tensor, "labels": labels}
|
||||||
|
arr = np.array(image, dtype=np.float32) / 127.5 - 1.0
|
||||||
|
if arr.ndim == 2:
|
||||||
|
tensor = torch.from_numpy(arr).unsqueeze(0)
|
||||||
|
else:
|
||||||
|
tensor = torch.from_numpy(arr).permute(2, 0, 1)
|
||||||
|
return {"pixel_values": tensor, "labels": torch.tensor(label, dtype=torch.long)}
|
||||||
|
|
||||||
|
ds = ds.with_transform(transform)
|
||||||
|
sampler = DistributedSampler(ds, shuffle=True) if distributed else None
|
||||||
|
loader = DataLoader(
|
||||||
|
ds,
|
||||||
|
batch_size=cfg.batch_size,
|
||||||
|
shuffle=(sampler is None),
|
||||||
|
sampler=sampler,
|
||||||
|
num_workers=cfg.num_workers,
|
||||||
|
drop_last=True,
|
||||||
|
pin_memory=torch.cuda.is_available(),
|
||||||
|
)
|
||||||
|
return loader, sampler
|
||||||
|
|
||||||
|
|
||||||
|
def infinite_loader(loader: DataLoader) -> Iterator[dict]:
|
||||||
|
while True:
|
||||||
|
for batch in loader:
|
||||||
|
yield batch
|
||||||
|
|
||||||
|
|
||||||
|
def build_noisy_sequence(
|
||||||
|
x0: Tensor,
|
||||||
|
eps: Tensor,
|
||||||
|
t_seq: Tensor,
|
||||||
|
) -> tuple[Tensor, Tensor]:
|
||||||
|
z_t = (1.0 - t_seq.unsqueeze(-1)) * x0[:, None, :] + t_seq.unsqueeze(-1) * eps[:, None, :]
|
||||||
|
v_gt = eps - x0
|
||||||
|
return z_t, v_gt
|
||||||
|
|
||||||
|
|
||||||
def compute_losses(
|
def compute_losses(
|
||||||
delta: Tensor,
|
model: nn.Module,
|
||||||
dt: Tensor,
|
perceptual_loss_fn: LPIPSPerceptualLoss,
|
||||||
x_seq: Tensor,
|
|
||||||
x0: Tensor,
|
x0: Tensor,
|
||||||
|
z_t: Tensor,
|
||||||
v_gt: Tensor,
|
v_gt: Tensor,
|
||||||
|
r_seq: Tensor,
|
||||||
t_seq: Tensor,
|
t_seq: Tensor,
|
||||||
|
cond: Tensor,
|
||||||
cfg: TrainConfig,
|
cfg: TrainConfig,
|
||||||
) -> tuple[Tensor, Tensor, Tensor]:
|
) -> tuple[dict[str, Tensor], Tensor]:
|
||||||
target_disp = v_gt[:, None, :] * dt.unsqueeze(-1)
|
seq_len = z_t.shape[1]
|
||||||
flow_loss = F.mse_loss(delta, target_disp)
|
safe_t = safe_time_divisor(t_seq).unsqueeze(-1)
|
||||||
|
|
||||||
t_next = t_seq[None, :, None] + dt.unsqueeze(-1)
|
x_pred, _ = model(z_t, r_seq, t_seq, cond)
|
||||||
t_next = torch.clamp(t_next, 0.0, 1.0)
|
u = (z_t - x_pred) / safe_t
|
||||||
x_target = x0[:, None, :] + t_next * v_gt[:, None, :]
|
|
||||||
x_next_pred = x_seq + delta
|
|
||||||
pos_loss = F.mse_loss(x_next_pred, x_target)
|
|
||||||
|
|
||||||
nfe_loss = (-torch.log(dt)).mean()
|
x_pred_inst, _ = model(z_t, t_seq, t_seq, cond)
|
||||||
return flow_loss, pos_loss, nfe_loss
|
v_inst = ((z_t - x_pred_inst) / safe_t).detach()
|
||||||
|
|
||||||
|
def u_fn(z_in: Tensor, r_in: Tensor, t_in: Tensor) -> Tensor:
|
||||||
|
x_pred_local, _ = model(z_in, r_in, t_in, cond)
|
||||||
|
return (z_in - x_pred_local) / safe_time_divisor(t_in).unsqueeze(-1)
|
||||||
|
|
||||||
|
_, dudt = jvp(
|
||||||
|
u_fn,
|
||||||
|
(z_t, r_seq, t_seq),
|
||||||
|
(v_inst, torch.zeros_like(r_seq), torch.ones_like(t_seq)),
|
||||||
|
)
|
||||||
|
corrected_velocity = u + (t_seq - r_seq).unsqueeze(-1) * dudt.detach()
|
||||||
|
target_velocity = v_gt[:, None, :].expand(-1, seq_len, -1)
|
||||||
|
|
||||||
|
pred_images = x_pred.reshape(
|
||||||
|
x0.shape[0] * seq_len, cfg.channels, cfg.image_size, cfg.image_size
|
||||||
|
)
|
||||||
|
target_images = (
|
||||||
|
x0.reshape(x0.shape[0], cfg.channels, cfg.image_size, cfg.image_size)
|
||||||
|
.unsqueeze(1)
|
||||||
|
.expand(-1, seq_len, -1, -1, -1)
|
||||||
|
.reshape(x0.shape[0] * seq_len, cfg.channels, cfg.image_size, cfg.image_size)
|
||||||
|
)
|
||||||
|
|
||||||
|
losses = {
|
||||||
|
"flow": F.mse_loss(corrected_velocity, target_velocity),
|
||||||
|
"perceptual": perceptual_loss_fn(pred_images, target_images),
|
||||||
|
}
|
||||||
|
losses["total"] = cfg.lambda_flow * losses["flow"] + cfg.lambda_perceptual * losses[
|
||||||
|
"perceptual"
|
||||||
|
]
|
||||||
|
return losses, x_pred
|
||||||
|
|
||||||
|
|
||||||
def validate(
|
def make_grid(images: Tensor, nrow: int) -> np.ndarray:
|
||||||
|
images = images.detach().cpu().numpy()
|
||||||
|
b, c, h, w = images.shape
|
||||||
|
nrow = max(1, min(nrow, b))
|
||||||
|
ncol = math.ceil(b / nrow)
|
||||||
|
grid = np.zeros((c, ncol * h, nrow * w), dtype=np.float32)
|
||||||
|
for idx in range(b):
|
||||||
|
r = idx // nrow
|
||||||
|
cidx = idx % nrow
|
||||||
|
grid[:, r * h : (r + 1) * h, cidx * w : (cidx + 1) * w] = images[idx]
|
||||||
|
return np.transpose(grid, (1, 2, 0))
|
||||||
|
|
||||||
|
|
||||||
|
def save_image_grid(
|
||||||
|
images: Tensor, save_path: Path, nrow: int, title: str | None = None
|
||||||
|
) -> None:
|
||||||
|
images = images.clamp(-1.0, 1.0)
|
||||||
|
images = (images + 1.0) / 2.0
|
||||||
|
grid = make_grid(images, nrow=nrow)
|
||||||
|
if grid.ndim == 3 and grid.shape[2] == 1:
|
||||||
|
grid = np.repeat(grid, 3, axis=2)
|
||||||
|
plt.imsave(save_path, grid)
|
||||||
|
if title is not None:
|
||||||
|
fig, ax = plt.subplots(figsize=(4, 3))
|
||||||
|
ax.imshow(grid)
|
||||||
|
ax.set_title(title)
|
||||||
|
ax.axis("off")
|
||||||
|
fig.tight_layout()
|
||||||
|
fig.savefig(save_path, dpi=160)
|
||||||
|
plt.close(fig)
|
||||||
|
|
||||||
|
|
||||||
|
def sample_class_images(
|
||||||
|
model: ASMamba,
|
||||||
|
cfg: TrainConfig,
|
||||||
|
device: torch.device,
|
||||||
|
cond: Tensor,
|
||||||
|
) -> Tensor:
|
||||||
|
model.eval()
|
||||||
|
input_dim = cfg.channels * cfg.image_size * cfg.image_size
|
||||||
|
z_t = torch.randn(cond.shape[0], input_dim, device=device)
|
||||||
|
time_grid = torch.tensor(FIXED_VAL_TIME_GRID, device=device)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for step_idx in range(FIXED_VAL_SAMPLING_STEPS):
|
||||||
|
t_cur = torch.full(
|
||||||
|
(cond.shape[0],),
|
||||||
|
float(time_grid[step_idx].item()),
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
t_next = time_grid[step_idx + 1]
|
||||||
|
x_pred, _ = model(
|
||||||
|
z_t.unsqueeze(1),
|
||||||
|
t_cur.unsqueeze(1),
|
||||||
|
t_cur.unsqueeze(1),
|
||||||
|
cond,
|
||||||
|
)
|
||||||
|
x_pred = x_pred[:, 0, :]
|
||||||
|
u_inst = (z_t - x_pred) / safe_time_divisor(t_cur).unsqueeze(-1)
|
||||||
|
z_t = z_t + (t_next - t_cur).unsqueeze(-1) * u_inst
|
||||||
|
|
||||||
|
return z_t.view(cond.shape[0], cfg.channels, cfg.image_size, cfg.image_size)
|
||||||
|
|
||||||
|
|
||||||
|
def log_class_samples(
|
||||||
model: ASMamba,
|
model: ASMamba,
|
||||||
cfg: TrainConfig,
|
cfg: TrainConfig,
|
||||||
sphere_a: tuple[Tensor, Tensor],
|
|
||||||
sphere_b: tuple[Tensor, Tensor],
|
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
logger: SwanLogger,
|
logger: SwanLogger,
|
||||||
step: int,
|
step: int,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
if cfg.val_samples_per_class <= 0:
|
||||||
|
return
|
||||||
|
training_mode = model.training
|
||||||
model.eval()
|
model.eval()
|
||||||
center_b, radius_b = sphere_b
|
for cls in range(cfg.num_classes):
|
||||||
|
cond = torch.full(
|
||||||
with torch.no_grad():
|
(cfg.val_samples_per_class,), cls, device=device, dtype=torch.long
|
||||||
x0 = sample_points_in_sphere(
|
|
||||||
sphere_a[0], float(sphere_a[1].item()), cfg.val_samples, device
|
|
||||||
)
|
)
|
||||||
traj = rollout_trajectory(model, x0, max_steps=cfg.val_max_steps)
|
x_final = sample_class_images(model, cfg, device, cond)
|
||||||
|
save_path = Path(cfg.output_dir) / f"val_class_{cls}_step_{step:06d}.png"
|
||||||
x_final = traj[:, -1, :]
|
save_image_grid(x_final, save_path, nrow=cfg.val_grid_rows)
|
||||||
center_b_cpu = center_b.detach().cpu()
|
|
||||||
radius_b_cpu = radius_b.detach().cpu()
|
|
||||||
dist = torch.linalg.norm(x_final - center_b_cpu, dim=-1)
|
|
||||||
inside = dist <= radius_b_cpu
|
|
||||||
|
|
||||||
logger.log(
|
|
||||||
{
|
|
||||||
"val/inside_ratio": float(inside.float().mean().item()),
|
|
||||||
"val/inside_count": float(inside.float().sum().item()),
|
|
||||||
"val/final_dist_mean": float(dist.mean().item()),
|
|
||||||
"val/final_dist_min": float(dist.min().item()),
|
|
||||||
"val/final_dist_max": float(dist.max().item()),
|
|
||||||
},
|
|
||||||
step=step,
|
|
||||||
)
|
|
||||||
|
|
||||||
if cfg.val_plot_samples > 0:
|
|
||||||
count = min(cfg.val_plot_samples, traj.shape[0])
|
|
||||||
if count == 0:
|
|
||||||
model.train()
|
|
||||||
return
|
|
||||||
indices = torch.linspace(0, traj.shape[0] - 1, steps=count).long()
|
|
||||||
traj_plot = traj[indices]
|
|
||||||
save_path = Path(cfg.output_dir) / f"val_traj_step_{step:06d}.png"
|
|
||||||
plot_trajectories(
|
|
||||||
traj_plot,
|
|
||||||
sphere_a,
|
|
||||||
sphere_b,
|
|
||||||
save_path,
|
|
||||||
title=f"Validation Trajectories (step {step})",
|
|
||||||
)
|
|
||||||
ratio = float(inside.float().mean().item())
|
|
||||||
logger.log_image(
|
logger.log_image(
|
||||||
"val/trajectory",
|
f"val/class_{cls}",
|
||||||
save_path,
|
save_path,
|
||||||
caption=f"step {step} | inside_ratio={ratio:.3f}",
|
caption=f"class {cls} step {step}",
|
||||||
step=step,
|
step=step,
|
||||||
)
|
)
|
||||||
|
if training_mode:
|
||||||
model.train()
|
model.train()
|
||||||
|
|
||||||
|
|
||||||
def train(cfg: TrainConfig) -> tuple[ASMamba, tuple[Tensor, Tensor], tuple[Tensor, Tensor]]:
|
def build_perceptual_loss(
|
||||||
device = torch.device(cfg.device if torch.cuda.is_available() else "cpu")
|
cfg: TrainConfig, device: torch.device, rank: int, use_ddp: bool
|
||||||
set_seed(cfg.seed)
|
) -> LPIPSPerceptualLoss:
|
||||||
|
if use_ddp and rank != 0:
|
||||||
|
torch.distributed.barrier()
|
||||||
|
perceptual_loss_fn = LPIPSPerceptualLoss(cfg).to(device)
|
||||||
|
if use_ddp and rank == 0:
|
||||||
|
torch.distributed.barrier()
|
||||||
|
return perceptual_loss_fn
|
||||||
|
|
||||||
|
|
||||||
|
def train(cfg: TrainConfig) -> ASMamba:
|
||||||
|
validate_config(cfg)
|
||||||
|
use_ddp, rank, world_size, device = setup_distributed(cfg)
|
||||||
|
del world_size
|
||||||
|
set_seed(cfg.seed + rank)
|
||||||
output_dir = Path(cfg.output_dir)
|
output_dir = Path(cfg.output_dir)
|
||||||
output_dir.mkdir(parents=True, exist_ok=True)
|
if rank == 0:
|
||||||
|
output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
model = ASMamba(cfg).to(device)
|
model = ASMamba(cfg).to(device)
|
||||||
optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
|
if use_ddp:
|
||||||
logger = SwanLogger(cfg)
|
model = nn.parallel.DistributedDataParallel(model, device_ids=[device.index])
|
||||||
|
optimizer = torch.optim.AdamW(
|
||||||
sphere_a, sphere_b = sample_sphere_params(cfg, device)
|
model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay
|
||||||
center_a, radius_a = sphere_a
|
|
||||||
center_b, radius_b = sphere_b
|
|
||||||
center_dist = torch.norm(center_a - center_b).item()
|
|
||||||
logger.log(
|
|
||||||
{
|
|
||||||
"sphere_a/radius": float(radius_a.item()),
|
|
||||||
"sphere_b/radius": float(radius_b.item()),
|
|
||||||
"sphere_a/center_x": float(center_a[0].item()),
|
|
||||||
"sphere_a/center_y": float(center_a[1].item()),
|
|
||||||
"sphere_a/center_z": float(center_a[2].item()),
|
|
||||||
"sphere_b/center_x": float(center_b[0].item()),
|
|
||||||
"sphere_b/center_y": float(center_b[1].item()),
|
|
||||||
"sphere_b/center_z": float(center_b[2].item()),
|
|
||||||
"sphere/center_dist": float(center_dist),
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
perceptual_loss_fn = build_perceptual_loss(cfg, device, rank, use_ddp)
|
||||||
|
logger = SwanLogger(cfg, enabled=(rank == 0))
|
||||||
|
|
||||||
|
loader, sampler = build_dataloader(cfg, distributed=use_ddp)
|
||||||
|
loader_iter = infinite_loader(loader)
|
||||||
|
|
||||||
global_step = 0
|
global_step = 0
|
||||||
for epoch in range(cfg.epochs):
|
for epoch_idx in range(cfg.epochs):
|
||||||
warmup = epoch < cfg.warmup_epochs
|
if sampler is not None:
|
||||||
|
sampler.set_epoch(epoch_idx)
|
||||||
model.train()
|
model.train()
|
||||||
for p in model.dt_head.parameters():
|
|
||||||
p.requires_grad = not warmup
|
|
||||||
|
|
||||||
for _ in range(cfg.steps_per_epoch):
|
for _ in range(cfg.steps_per_epoch):
|
||||||
x0, x1, x_seq, t_seq = sample_batch(cfg, sphere_a, sphere_b, device)
|
batch = next(loader_iter)
|
||||||
v_gt = x1 - x0
|
x0 = batch["pixel_values"].to(device)
|
||||||
|
cond = batch["labels"].to(device)
|
||||||
|
b = x0.shape[0]
|
||||||
|
x0 = x0.view(b, -1)
|
||||||
|
eps = torch.randn_like(x0)
|
||||||
|
|
||||||
delta, dt, _ = model(x_seq)
|
r_seq, t_seq = sample_block_times(cfg, b, device, x0.dtype)
|
||||||
if warmup:
|
z_t, v_gt = build_noisy_sequence(x0, eps, t_seq)
|
||||||
dt = torch.full_like(dt, 1.0 / cfg.seq_len)
|
|
||||||
|
|
||||||
flow_loss, pos_loss, nfe_loss = compute_losses(
|
losses, _ = compute_losses(
|
||||||
delta=delta,
|
model=model,
|
||||||
dt=dt,
|
perceptual_loss_fn=perceptual_loss_fn,
|
||||||
x_seq=x_seq,
|
|
||||||
x0=x0,
|
x0=x0,
|
||||||
|
z_t=z_t,
|
||||||
v_gt=v_gt,
|
v_gt=v_gt,
|
||||||
|
r_seq=r_seq,
|
||||||
t_seq=t_seq,
|
t_seq=t_seq,
|
||||||
|
cond=cond,
|
||||||
cfg=cfg,
|
cfg=cfg,
|
||||||
)
|
)
|
||||||
|
|
||||||
loss = cfg.lambda_flow * flow_loss + cfg.lambda_pos * pos_loss
|
|
||||||
if not warmup:
|
|
||||||
loss = loss + cfg.lambda_nfe * nfe_loss
|
|
||||||
|
|
||||||
optimizer.zero_grad(set_to_none=True)
|
optimizer.zero_grad(set_to_none=True)
|
||||||
loss.backward()
|
losses["total"].backward()
|
||||||
|
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||||
|
model.parameters(), max_norm=cfg.max_grad_norm
|
||||||
|
)
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|
||||||
if global_step % 10 == 0:
|
if global_step % 10 == 0 and rank == 0:
|
||||||
logger.log(
|
logger.log(
|
||||||
{
|
{
|
||||||
"loss/total": float(loss.item()),
|
"loss/total": float(losses["total"].item()),
|
||||||
"loss/flow": float(flow_loss.item()),
|
"loss/flow": float(losses["flow"].item()),
|
||||||
"loss/pos": float(pos_loss.item()),
|
"loss/perceptual": float(losses["perceptual"].item()),
|
||||||
"loss/nfe": float(nfe_loss.item()),
|
"grad/total_norm": float(grad_norm.item()),
|
||||||
"dt/mean": float(dt.mean().item()),
|
"time/r_mean": float(r_seq.mean().item()),
|
||||||
"dt/min": float(dt.min().item()),
|
"time/t_mean": float(t_seq.mean().item()),
|
||||||
"dt/max": float(dt.max().item()),
|
"time/zero_block_frac": float((t_seq == r_seq).float().mean().item()),
|
||||||
"stage": 0 if warmup else 1,
|
|
||||||
},
|
},
|
||||||
step=global_step,
|
step=global_step,
|
||||||
)
|
)
|
||||||
|
|
||||||
if cfg.val_every > 0 and global_step > 0 and global_step % cfg.val_every == 0:
|
if (
|
||||||
validate(model, cfg, sphere_a, sphere_b, device, logger, global_step)
|
cfg.val_every > 0
|
||||||
|
and global_step > 0
|
||||||
|
and global_step % cfg.val_every == 0
|
||||||
|
and rank == 0
|
||||||
|
):
|
||||||
|
log_class_samples(unwrap_model(model), cfg, device, logger, global_step)
|
||||||
|
|
||||||
global_step += 1
|
global_step += 1
|
||||||
|
|
||||||
logger.finish()
|
logger.finish()
|
||||||
return model, sphere_a, sphere_b
|
if use_ddp:
|
||||||
|
torch.distributed.destroy_process_group()
|
||||||
|
return unwrap_model(model)
|
||||||
def rollout_trajectory(
|
|
||||||
model: ASMamba,
|
|
||||||
x0: Tensor,
|
|
||||||
max_steps: int = 100,
|
|
||||||
) -> Tensor:
|
|
||||||
device = x0.device
|
|
||||||
model.eval()
|
|
||||||
h = model.init_cache(batch_size=x0.shape[0], device=device)
|
|
||||||
x = x0
|
|
||||||
total_time = torch.zeros(x0.shape[0], device=device)
|
|
||||||
traj = [x0.detach().cpu()]
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
for _ in range(max_steps):
|
|
||||||
delta, dt, h = model.step(x, h)
|
|
||||||
dt = torch.clamp(dt, min=model.dt_min, max=model.dt_max)
|
|
||||||
remaining = 1.0 - total_time
|
|
||||||
overshoot = dt > remaining
|
|
||||||
if overshoot.any():
|
|
||||||
scale = (remaining / dt).unsqueeze(-1)
|
|
||||||
delta = torch.where(overshoot.unsqueeze(-1), delta * scale, delta)
|
|
||||||
dt = torch.where(overshoot, remaining, dt)
|
|
||||||
|
|
||||||
x = x + delta
|
|
||||||
total_time = total_time + dt
|
|
||||||
traj.append(x.detach().cpu())
|
|
||||||
|
|
||||||
if torch.all(total_time >= 1.0 - 1e-6):
|
|
||||||
break
|
|
||||||
|
|
||||||
return torch.stack(traj, dim=1)
|
|
||||||
|
|
||||||
|
|
||||||
def sphere_wireframe(
|
|
||||||
center: Tensor, radius: float, u_steps: int = 24, v_steps: int = 12
|
|
||||||
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
|
||||||
center_np = center.detach().cpu().numpy()
|
|
||||||
u = np.linspace(0, 2 * np.pi, u_steps)
|
|
||||||
v = np.linspace(0, np.pi, v_steps)
|
|
||||||
x = center_np[0] + radius * np.outer(np.cos(u), np.sin(v))
|
|
||||||
y = center_np[1] + radius * np.outer(np.sin(u), np.sin(v))
|
|
||||||
z = center_np[2] + radius * np.outer(np.ones_like(u), np.cos(v))
|
|
||||||
return x, y, z
|
|
||||||
|
|
||||||
|
|
||||||
def plot_trajectories(
|
|
||||||
traj: Tensor,
|
|
||||||
sphere_a: tuple[Tensor, Tensor],
|
|
||||||
sphere_b: tuple[Tensor, Tensor],
|
|
||||||
save_path: Path,
|
|
||||||
title: str = "AS-Mamba Trajectories",
|
|
||||||
) -> None:
|
|
||||||
traj = traj.detach().cpu()
|
|
||||||
if traj.dim() == 2:
|
|
||||||
traj = traj.unsqueeze(0)
|
|
||||||
traj_np = traj.numpy()
|
|
||||||
|
|
||||||
fig = plt.figure(figsize=(7, 6))
|
|
||||||
ax = fig.add_subplot(111, projection="3d")
|
|
||||||
|
|
||||||
for i in range(traj_np.shape[0]):
|
|
||||||
ax.plot(
|
|
||||||
traj_np[i, :, 0],
|
|
||||||
traj_np[i, :, 1],
|
|
||||||
traj_np[i, :, 2],
|
|
||||||
color="green",
|
|
||||||
alpha=0.6,
|
|
||||||
)
|
|
||||||
|
|
||||||
starts = traj_np[:, 0, :]
|
|
||||||
ends = traj_np[:, -1, :]
|
|
||||||
ax.scatter(starts[:, 0], starts[:, 1], starts[:, 2], color="blue", s=20, label="Start")
|
|
||||||
ax.scatter(ends[:, 0], ends[:, 1], ends[:, 2], color="red", s=20, label="End")
|
|
||||||
|
|
||||||
center_a, radius_a = sphere_a
|
|
||||||
center_b, radius_b = sphere_b
|
|
||||||
x_a, y_a, z_a = sphere_wireframe(center_a, float(radius_a.item()))
|
|
||||||
x_b, y_b, z_b = sphere_wireframe(center_b, float(radius_b.item()))
|
|
||||||
ax.plot_wireframe(x_a, y_a, z_a, color="blue", alpha=0.15, linewidth=0.5)
|
|
||||||
ax.plot_wireframe(x_b, y_b, z_b, color="red", alpha=0.15, linewidth=0.5)
|
|
||||||
|
|
||||||
ax.set_title(title)
|
|
||||||
ax.set_xlabel("X")
|
|
||||||
ax.set_ylabel("Y")
|
|
||||||
ax.set_zlabel("Z")
|
|
||||||
ax.legend(loc="best")
|
|
||||||
fig.tight_layout()
|
|
||||||
fig.savefig(save_path, dpi=160)
|
|
||||||
plt.close(fig)
|
|
||||||
|
|
||||||
|
|
||||||
def run_training_and_plot(cfg: TrainConfig) -> Path:
|
def run_training_and_plot(cfg: TrainConfig) -> Path:
|
||||||
model, sphere_a, sphere_b = train(cfg)
|
train(cfg)
|
||||||
device = next(model.parameters()).device
|
return Path(cfg.output_dir)
|
||||||
|
|
||||||
plot_samples = max(1, cfg.val_plot_samples)
|
|
||||||
x0 = sample_points_in_sphere(
|
|
||||||
sphere_a[0], float(sphere_a[1].item()), plot_samples, device
|
|
||||||
)
|
|
||||||
traj = rollout_trajectory(model, x0, max_steps=cfg.val_max_steps)
|
|
||||||
output_dir = Path(cfg.output_dir)
|
|
||||||
save_path = output_dir / "as_mamba_trajectory.png"
|
|
||||||
plot_trajectories(traj, sphere_a, sphere_b, save_path)
|
|
||||||
return save_path
|
|
||||||
|
|||||||
39
main.py
39
main.py
@@ -4,25 +4,40 @@ from as_mamba import TrainConfig, run_training_and_plot
|
|||||||
|
|
||||||
|
|
||||||
def build_parser() -> argparse.ArgumentParser:
|
def build_parser() -> argparse.ArgumentParser:
|
||||||
parser = argparse.ArgumentParser(description="Train AS-Mamba on sphere-to-sphere flow.")
|
parser = argparse.ArgumentParser(description="Train AS-Mamba on MNIST MeanFlow x-prediction.")
|
||||||
parser.add_argument("--epochs", type=int, default=None)
|
parser.add_argument("--epochs", type=int, default=None)
|
||||||
parser.add_argument("--warmup-epochs", type=int, default=None)
|
|
||||||
parser.add_argument("--batch-size", type=int, default=None)
|
|
||||||
parser.add_argument("--steps-per-epoch", type=int, default=None)
|
parser.add_argument("--steps-per-epoch", type=int, default=None)
|
||||||
|
parser.add_argument("--batch-size", type=int, default=None)
|
||||||
parser.add_argument("--seq-len", type=int, default=None)
|
parser.add_argument("--seq-len", type=int, default=None)
|
||||||
parser.add_argument("--lr", type=float, default=None)
|
parser.add_argument("--lr", type=float, default=None)
|
||||||
|
parser.add_argument("--weight-decay", type=float, default=None)
|
||||||
|
parser.add_argument("--max-grad-norm", type=float, default=None)
|
||||||
parser.add_argument("--device", type=str, default=None)
|
parser.add_argument("--device", type=str, default=None)
|
||||||
parser.add_argument("--output-dir", type=str, default=None)
|
parser.add_argument("--output-dir", type=str, default=None)
|
||||||
parser.add_argument("--project", type=str, default=None)
|
parser.add_argument("--project", type=str, default=None)
|
||||||
parser.add_argument("--run-name", type=str, default=None)
|
parser.add_argument("--run-name", type=str, default=None)
|
||||||
|
parser.add_argument("--lambda-flow", type=float, default=None)
|
||||||
|
parser.add_argument("--lambda-perceptual", type=float, default=None)
|
||||||
|
parser.add_argument("--num-classes", type=int, default=None)
|
||||||
|
parser.add_argument("--image-size", type=int, default=None)
|
||||||
|
parser.add_argument("--channels", type=int, default=None)
|
||||||
|
parser.add_argument("--num-workers", type=int, default=None)
|
||||||
|
parser.add_argument("--dataset-name", type=str, default=None)
|
||||||
|
parser.add_argument("--dataset-split", type=str, default=None)
|
||||||
|
parser.add_argument("--d-model", type=int, default=None)
|
||||||
|
parser.add_argument("--n-layer", type=int, default=None)
|
||||||
|
parser.add_argument("--d-state", type=int, default=None)
|
||||||
|
parser.add_argument("--d-conv", type=int, default=None)
|
||||||
|
parser.add_argument("--expand", type=int, default=None)
|
||||||
|
parser.add_argument("--headdim", type=int, default=None)
|
||||||
|
parser.add_argument("--chunk-size", type=int, default=None)
|
||||||
|
parser.add_argument("--use-residual", action=argparse.BooleanOptionalAction, default=None)
|
||||||
parser.add_argument("--val-every", type=int, default=None)
|
parser.add_argument("--val-every", type=int, default=None)
|
||||||
parser.add_argument("--val-samples", type=int, default=None)
|
parser.add_argument("--val-samples-per-class", type=int, default=None)
|
||||||
parser.add_argument("--val-plot-samples", type=int, default=None)
|
parser.add_argument("--val-grid-rows", type=int, default=None)
|
||||||
parser.add_argument("--val-max-steps", type=int, default=None)
|
parser.add_argument("--val-sampling-steps", type=int, default=None)
|
||||||
parser.add_argument("--center-min", type=float, default=None)
|
parser.add_argument("--time-grid-size", type=int, default=None)
|
||||||
parser.add_argument("--center-max", type=float, default=None)
|
parser.add_argument("--use-ddp", action=argparse.BooleanOptionalAction, default=None)
|
||||||
parser.add_argument("--center-distance-min", type=float, default=None)
|
|
||||||
parser.add_argument("--use-residual", action="store_true")
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
@@ -35,8 +50,8 @@ def main() -> None:
|
|||||||
if value is not None:
|
if value is not None:
|
||||||
setattr(cfg, key.replace("-", "_"), value)
|
setattr(cfg, key.replace("-", "_"), value)
|
||||||
|
|
||||||
plot_path = run_training_and_plot(cfg)
|
out_path = run_training_and_plot(cfg)
|
||||||
print(f"Saved trajectory plot to {plot_path}")
|
print(f"Saved outputs to {out_path}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -5,9 +5,12 @@ description = "Add your description here"
|
|||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
requires-python = ">=3.12"
|
requires-python = ">=3.12"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
"datasets>=2.19.0",
|
||||||
"einops>=0.7.0",
|
"einops>=0.7.0",
|
||||||
|
"lpips>=0.1.4",
|
||||||
"matplotlib>=3.8.0",
|
"matplotlib>=3.8.0",
|
||||||
"numpy>=1.26.0",
|
"numpy>=1.26.0",
|
||||||
"swanlab>=0.5.0",
|
"swanlab>=0.5.0",
|
||||||
"torch>=2.2.0",
|
"torch>=2.2.0",
|
||||||
|
"torchvision>=0.24.1",
|
||||||
]
|
]
|
||||||
|
|||||||
74
train_as_mamba.sh
Executable file
74
train_as_mamba.sh
Executable file
@@ -0,0 +1,74 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
set -euo pipefail
|
||||||
|
|
||||||
|
DEVICE="cuda"
|
||||||
|
EPOCHS=2000
|
||||||
|
STEPS_PER_EPOCH=200
|
||||||
|
BATCH_SIZE=512
|
||||||
|
SEQ_LEN=5
|
||||||
|
LR=1e-3
|
||||||
|
WEIGHT_DECAY=1e-2
|
||||||
|
LAMBDA_FLOW=1.0
|
||||||
|
LAMBDA_PERCEPTUAL=0.4
|
||||||
|
NUM_CLASSES=10
|
||||||
|
IMAGE_SIZE=28
|
||||||
|
CHANNELS=1
|
||||||
|
NUM_WORKERS=32
|
||||||
|
DATASET_NAME="ylecun/mnist"
|
||||||
|
DATASET_SPLIT="train"
|
||||||
|
D_MODEL=784
|
||||||
|
N_LAYER=8
|
||||||
|
D_STATE=32
|
||||||
|
D_CONV=4
|
||||||
|
EXPAND=2
|
||||||
|
HEADDIM=32
|
||||||
|
CHUNK_SIZE=1
|
||||||
|
USE_RESIDUAL=true
|
||||||
|
USE_DDP=true
|
||||||
|
VAL_EVERY=1000
|
||||||
|
VAL_SAMPLES_PER_CLASS=8
|
||||||
|
VAL_GRID_ROWS=4
|
||||||
|
VAL_SAMPLING_STEPS=5
|
||||||
|
TIME_GRID_SIZE=256
|
||||||
|
PROJECT="as-mamba-mnist"
|
||||||
|
RUN_NAME="mnist-meanflow-xpred"
|
||||||
|
OUTPUT_DIR="outputs"
|
||||||
|
|
||||||
|
USE_RESIDUAL_FLAG="--use-residual"
|
||||||
|
if [ "${USE_RESIDUAL}" = "false" ]; then USE_RESIDUAL_FLAG="--no-use-residual"; fi
|
||||||
|
USE_DDP_FLAG="--use-ddp"
|
||||||
|
if [ "${USE_DDP}" = "false" ]; then USE_DDP_FLAG="--no-use-ddp"; fi
|
||||||
|
|
||||||
|
uv run torchrun --nproc_per_node=2 main.py \
|
||||||
|
--device "${DEVICE}" \
|
||||||
|
--epochs "${EPOCHS}" \
|
||||||
|
--steps-per-epoch "${STEPS_PER_EPOCH}" \
|
||||||
|
--batch-size "${BATCH_SIZE}" \
|
||||||
|
--seq-len "${SEQ_LEN}" \
|
||||||
|
--lr "${LR}" \
|
||||||
|
--weight-decay "${WEIGHT_DECAY}" \
|
||||||
|
--lambda-flow "${LAMBDA_FLOW}" \
|
||||||
|
--lambda-perceptual "${LAMBDA_PERCEPTUAL}" \
|
||||||
|
--num-classes "${NUM_CLASSES}" \
|
||||||
|
--image-size "${IMAGE_SIZE}" \
|
||||||
|
--channels "${CHANNELS}" \
|
||||||
|
--num-workers "${NUM_WORKERS}" \
|
||||||
|
--dataset-name "${DATASET_NAME}" \
|
||||||
|
--dataset-split "${DATASET_SPLIT}" \
|
||||||
|
--d-model "${D_MODEL}" \
|
||||||
|
--n-layer "${N_LAYER}" \
|
||||||
|
--d-state "${D_STATE}" \
|
||||||
|
--d-conv "${D_CONV}" \
|
||||||
|
--expand "${EXPAND}" \
|
||||||
|
--headdim "${HEADDIM}" \
|
||||||
|
--chunk-size "${CHUNK_SIZE}" \
|
||||||
|
${USE_RESIDUAL_FLAG} \
|
||||||
|
${USE_DDP_FLAG} \
|
||||||
|
--val-every "${VAL_EVERY}" \
|
||||||
|
--val-samples-per-class "${VAL_SAMPLES_PER_CLASS}" \
|
||||||
|
--val-grid-rows "${VAL_GRID_ROWS}" \
|
||||||
|
--val-sampling-steps "${VAL_SAMPLING_STEPS}" \
|
||||||
|
--time-grid-size "${TIME_GRID_SIZE}" \
|
||||||
|
--project "${PROJECT}" \
|
||||||
|
--run-name "${RUN_NAME}" \
|
||||||
|
--output-dir "${OUTPUT_DIR}"
|
||||||
Reference in New Issue
Block a user