feat: migrate from the sphere-trajectory toy task to conditional flow matching on MNIST
as_mamba.py
@@ -1,8 +1,10 @@
from __future__ import annotations

import math
from dataclasses import asdict, dataclass
import os
from pathlib import Path
from typing import Optional
from typing import Iterator, Optional

import matplotlib
@@ -11,9 +13,11 @@ matplotlib.use("Agg")
import numpy as np
import torch
import torch.nn.functional as F
from datasets import load_dataset
from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401
from torch import Tensor, nn
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

from mamba2_minimal import InferenceCache, Mamba2, Mamba2Config, RMSNorm
@@ -22,28 +26,29 @@ from mamba2_minimal import InferenceCache, Mamba2, Mamba2Config, RMSNorm
class TrainConfig:
    seed: int = 42
    device: str = "cuda"
    epochs: int = 50
    steps_per_epoch: int = 200
    batch_size: int = 128
    steps_per_epoch: int = 50
    epochs: int = 60
    seq_len: int = 20
    lr: float = 1e-3
    lr: float = 2e-4
    weight_decay: float = 1e-2
    dt_min: float = 1e-3
    dt_max: float = 0.06
    dt_alpha: float = 8.0
    lambda_flow: float = 1.0
    lambda_pos: float = 1.0
    lambda_dt: float = 0.05
    lambda_dt: float = 1.0
    use_flow_loss: bool = True
    use_pos_loss: bool = False
    use_dt_loss: bool = True
    radius_min: float = 0.6
    radius_max: float = 1.4
    center_min: float = -6.0
    center_max: float = 6.0
    center_distance_min: float = 6.0
    d_model: int = 128
    n_layer: int = 4
    num_classes: int = 10
    image_size: int = 28
    channels: int = 1
    num_workers: int = 8
    dataset_name: str = "ylecun/mnist"
    dataset_split: str = "train"
    d_model: int = 0
    n_layer: int = 6
    d_state: int = 64
    d_conv: int = 4
    expand: int = 2
@@ -51,12 +56,13 @@ class TrainConfig:
    chunk_size: int = 1
    use_residual: bool = False
    output_dir: str = "outputs"
    project: str = "as-mamba"
    run_name: str = "sphere-to-sphere"
    project: str = "as-mamba-mnist"
    run_name: str = "mnist-flow"
    val_every: int = 200
    val_samples: int = 256
    val_plot_samples: int = 16
    val_samples_per_class: int = 8
    val_grid_rows: int = 4
    val_max_steps: int = 0
    use_ddp: bool = False


class AdaLNZero(nn.Module):
@@ -93,15 +99,10 @@ class Mamba2Backbone(nn.Module):
        self.norm_f = RMSNorm(args.d_model)

    def forward(
        self,
        x: Tensor,
        cond: Optional[Tensor] = None,
        h: Optional[list[InferenceCache]] = None,
        self, x: Tensor, cond: Tensor, h: Optional[list[InferenceCache]] = None
    ) -> tuple[Tensor, list[InferenceCache]]:
        if h is None:
            h = [None for _ in range(self.args.n_layer)]
        if cond is None:
            cond = torch.zeros(x.shape[0], x.shape[-1], device=x.device, dtype=x.dtype)

        for i, layer in enumerate(self.layers):
            x_mod = layer["adaln"](x, cond)
@@ -118,6 +119,13 @@ class ASMamba(nn.Module):
        self.cfg = cfg
        self.dt_min = float(cfg.dt_min)
        self.dt_max = float(cfg.dt_max)
        input_dim = cfg.channels * cfg.image_size * cfg.image_size
        if cfg.d_model == 0:
            cfg.d_model = input_dim
        if cfg.d_model != input_dim:
            raise ValueError(
                f"d_model must equal flattened image dim ({input_dim}) when input_proj is disabled."
            )

        args = Mamba2Config(
            d_model=cfg.d_model,
@@ -129,9 +137,8 @@ class ASMamba(nn.Module):
            chunk_size=cfg.chunk_size,
        )
        self.backbone = Mamba2Backbone(args, use_residual=cfg.use_residual)
        self.input_proj = nn.Linear(3, cfg.d_model)
        self.cond_emb = nn.Embedding(2, cfg.d_model)
        self.delta_head = nn.Linear(cfg.d_model, 3)
        self.cond_emb = nn.Embedding(cfg.num_classes, cfg.d_model)
        self.delta_head = nn.Linear(cfg.d_model, input_dim)
        self.dt_head = nn.Sequential(
            nn.Linear(cfg.d_model, cfg.d_model),
            nn.SiLU(),
@@ -141,9 +148,8 @@ class ASMamba(nn.Module):
    def forward(
        self, x: Tensor, cond: Tensor, h: Optional[list[InferenceCache]] = None
    ) -> tuple[Tensor, Tensor, list[InferenceCache]]:
        x_proj = self.input_proj(x)
        cond_vec = self.cond_emb(cond)
        feats, h = self.backbone(x_proj, cond_vec, h)
        feats, h = self.backbone(x, cond_vec, h)
        delta = self.delta_head(feats)
        dt_raw = self.dt_head(feats).squeeze(-1)
        dt = torch.clamp(F.softplus(dt_raw), min=self.dt_min, max=self.dt_max)
@@ -163,10 +169,12 @@ class ASMamba(nn.Module):


class SwanLogger:
    def __init__(self, cfg: TrainConfig) -> None:
        self.enabled = False
    def __init__(self, cfg: TrainConfig, enabled: bool = True) -> None:
        self.enabled = enabled
        self._swan = None
        self._run = None
        if not self.enabled:
            return
        try:
            import swanlab  # type: ignore
@@ -195,7 +203,11 @@ class SwanLogger:
        target.log(payload)

    def log_image(
        self, key: str, image_path: Path, caption: str | None = None, step: int | None = None
        self,
        key: str,
        image_path: Path,
        caption: str | None = None,
        step: int | None = None,
    ) -> None:
        if not self.enabled:
            return
@@ -220,55 +232,42 @@ def set_seed(seed: int) -> None:
    torch.cuda.manual_seed_all(seed)


def sample_points_in_sphere(
    center: Tensor, radius: float, batch_size: int, device: torch.device
) -> Tensor:
    direction = torch.randn(batch_size, 3, device=device)
    direction = direction / (direction.norm(dim=-1, keepdim=True) + 1e-8)
    u = torch.rand(batch_size, 1, device=device)
    r = radius * torch.pow(u, 1.0 / 3.0)
    return center + direction * r
def setup_distributed(cfg: TrainConfig) -> tuple[bool, int, int, torch.device]:
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    rank = int(os.environ.get("RANK", "0"))
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    use_ddp = cfg.use_ddp and world_size > 1
    if use_ddp:
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
        torch.cuda.set_device(local_rank)
        device = torch.device("cuda", local_rank)
    else:
        device = torch.device(cfg.device if torch.cuda.is_available() else "cpu")
    return use_ddp, rank, world_size, device


def sample_center(cfg: TrainConfig, device: torch.device) -> Tensor:
    return torch.empty(3, device=device).uniform_(cfg.center_min, cfg.center_max)
def unwrap_model(model: nn.Module) -> nn.Module:
    return model.module if hasattr(model, "module") else model


def sample_center_far(
    cfg: TrainConfig, device: torch.device, refs: list[Tensor]
) -> Tensor:
    center = sample_center(cfg, device)
    for _ in range(256):
        if all(torch.norm(center - ref) >= cfg.center_distance_min for ref in refs):
            return center
        center = sample_center(cfg, device)
    return center


def sample_spheres_params(
    cfg: TrainConfig, device: torch.device
) -> tuple[tuple[Tensor, Tensor], tuple[Tensor, Tensor], tuple[Tensor, Tensor]]:
    center_a = sample_center(cfg, device)
    center_b0 = sample_center_far(cfg, device, [center_a])
    center_b1 = sample_center_far(cfg, device, [center_a, center_b0])
    if torch.norm(center_a - center_b0) < 1e-3:
        center_b0 = center_b0 + torch.tensor(
            [cfg.center_distance_min, 0.0, 0.0], device=device
def validate_time_config(cfg: TrainConfig) -> None:
    if cfg.seq_len <= 0:
        raise ValueError("seq_len must be > 0")
    base = 1.0 / cfg.seq_len
    if cfg.dt_max <= base:
        raise ValueError(
            "dt_max must be > 1/seq_len to allow non-uniform dt_seq. "
            f"Got dt_max={cfg.dt_max}, seq_len={cfg.seq_len}, 1/seq_len={base}."
        )
    if torch.norm(center_a - center_b1) < 1e-3:
        center_b1 = center_b1 + torch.tensor(
            [-cfg.center_distance_min, 0.0, 0.0], device=device
    if cfg.dt_min >= cfg.dt_max:
        raise ValueError(
            f"dt_min must be < dt_max (got dt_min={cfg.dt_min}, dt_max={cfg.dt_max})."
        )
    radius_a = float(torch.empty(1).uniform_(cfg.radius_min, cfg.radius_max).item())
    radius_b0 = float(torch.empty(1).uniform_(cfg.radius_min, cfg.radius_max).item())
    radius_b1 = float(torch.empty(1).uniform_(cfg.radius_min, cfg.radius_max).item())
    sphere_a = (center_a, torch.tensor(radius_a, device=device))
    sphere_b0 = (center_b0, torch.tensor(radius_b0, device=device))
    sphere_b1 = (center_b1, torch.tensor(radius_b1, device=device))
    return sphere_a, sphere_b0, sphere_b1


def sample_time_sequence(cfg: TrainConfig, batch_size: int, device: torch.device) -> Tensor:
def sample_time_sequence(
    cfg: TrainConfig, batch_size: int, device: torch.device
) -> Tensor:
    alpha = float(cfg.dt_alpha)
    if alpha <= 0:
        raise ValueError("dt_alpha must be > 0")
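For reference, a minimal standalone sketch of the path construction the new training loop performs with these helpers; the shapes and the uniform dt placeholder are illustrative assumptions, not part of the diff. With the defaults above, 1/seq_len = 1/20 = 0.05, so dt_max = 0.06 satisfies the validate_time_config check while leaving room for non-uniform steps.

import torch

B, L, D = 4, 20, 784                     # batch, seq_len, flattened 1x28x28 image
x0 = torch.randn(B, D)                   # noise endpoint
x1 = torch.randn(B, D)                   # stand-in for a flattened MNIST image
dt_seq = torch.full((B, L), 1.0 / L)     # placeholder for sample_time_sequence output
t_seq = torch.cumsum(dt_seq, dim=-1)                             # time after each step
t_seq = torch.cat([torch.zeros(B, 1), t_seq[:, :-1]], dim=-1)    # time entering each step
x_seq = x0[:, None, :] + t_seq[:, :, None] * (x1 - x0)[:, None, :]  # x_t = x0 + t * (x1 - x0)
assert x_seq.shape == (B, L, D)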
@@ -287,29 +286,48 @@ def sample_time_sequence(cfg: TrainConfig, batch_size: int, device: torch.device
    return dt_seq


def sample_batch(
    cfg: TrainConfig,
    sphere_a: tuple[Tensor, Tensor],
    sphere_b0: tuple[Tensor, Tensor],
    sphere_b1: tuple[Tensor, Tensor],
    device: torch.device,
) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
    center_a, radius_a = sphere_a
    x0 = sample_points_in_sphere(center_a, float(radius_a.item()), cfg.batch_size, device)
    cond = torch.randint(0, 2, (cfg.batch_size,), device=device, dtype=torch.long)
    x1_0 = sample_points_in_sphere(
        sphere_b0[0], float(sphere_b0[1].item()), cfg.batch_size, device
def build_dataloader(
    cfg: TrainConfig, distributed: bool = False
) -> tuple[DataLoader, Optional[DistributedSampler]]:
    ds = load_dataset(cfg.dataset_name, split=cfg.dataset_split)

    def transform(example):
        image = example.get("img", example.get("image"))
        label = example.get("label", example.get("labels"))
        if isinstance(image, list):
            arr = np.stack([np.array(im, dtype=np.float32) for im in image], axis=0)
            arr = arr / 127.5 - 1.0
            if arr.ndim == 3:
                tensor = torch.from_numpy(arr).unsqueeze(1)
            else:
                tensor = torch.from_numpy(arr).permute(0, 3, 1, 2)
            labels = torch.tensor(label, dtype=torch.long)
            return {"pixel_values": tensor, "labels": labels}
        arr = np.array(image, dtype=np.float32) / 127.5 - 1.0
        if arr.ndim == 2:
            tensor = torch.from_numpy(arr).unsqueeze(0)
        else:
            tensor = torch.from_numpy(arr).permute(2, 0, 1)
        return {"pixel_values": tensor, "labels": torch.tensor(label, dtype=torch.long)}

    ds = ds.with_transform(transform)
    sampler = DistributedSampler(ds, shuffle=True) if distributed else None
    loader = DataLoader(
        ds,
        batch_size=cfg.batch_size,
        shuffle=(sampler is None),
        sampler=sampler,
        num_workers=cfg.num_workers,
        drop_last=True,
        pin_memory=torch.cuda.is_available(),
    )
    x1_1 = sample_points_in_sphere(
        sphere_b1[0], float(sphere_b1[1].item()), cfg.batch_size, device
    )
    x1 = torch.where(cond[:, None] == 0, x1_0, x1_1)
    v_gt = x1 - x0
    dt_seq = sample_time_sequence(cfg, cfg.batch_size, device)
    t_seq = torch.cumsum(dt_seq, dim=-1)
    t_seq = torch.cat([torch.zeros(cfg.batch_size, 1, device=device), t_seq[:, :-1]], dim=-1)
    x_seq = x0[:, None, :] + t_seq[:, :, None] * v_gt[:, None, :]
    return x0, x1, x_seq, t_seq, dt_seq, cond
    return loader, sampler


def infinite_loader(loader: DataLoader) -> Iterator[dict]:
    while True:
        for batch in loader:
            yield batch
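A quick numeric check of the normalization in transform above, assuming 8-bit MNIST pixel values; the inverse mapping is what save_image_grid applies later before plotting.

import numpy as np

px = np.array([0.0, 127.5, 255.0], dtype=np.float32)
norm = px / 127.5 - 1.0                          # transform(): [0, 255] -> [-1, 1]
back = (np.clip(norm, -1.0, 1.0) + 1.0) / 2.0    # save_image_grid(): [-1, 1] -> [0, 1]
print(norm)  # [-1.  0.  1.]
print(back)  # [0.  0.5 1. ]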


def compute_losses(
@@ -341,228 +359,54 @@
    return losses


def validate(
    model: ASMamba,
    cfg: TrainConfig,
    sphere_a: tuple[Tensor, Tensor],
    sphere_b0: tuple[Tensor, Tensor],
    sphere_b1: tuple[Tensor, Tensor],
    device: torch.device,
    logger: SwanLogger,
    step: int,
def plot_dt_hist(
    dt_pred: Tensor, dt_gt: Tensor, save_path: Path, title: str = "dt Distribution"
) -> None:
    model.eval()
    center_b0, radius_b0 = sphere_b0
    center_b1, radius_b1 = sphere_b1
    max_steps = cfg.seq_len if cfg.val_max_steps <= 0 else cfg.val_max_steps
    dt_pred_np = dt_pred.detach().cpu().numpy().reshape(-1)
    dt_gt_np = dt_gt.detach().cpu().numpy().reshape(-1)

    with torch.no_grad():
        x0 = sample_points_in_sphere(
            sphere_a[0], float(sphere_a[1].item()), cfg.val_samples, device
        )
        cond = torch.randint(0, 2, (cfg.val_samples,), device=device, dtype=torch.long)
        traj = rollout_trajectory(model, x0, cond, max_steps=max_steps)

    x_final = traj[:, -1, :]
    center_b0_cpu = center_b0.detach().cpu()
    center_b1_cpu = center_b1.detach().cpu()
    radius_b0_cpu = radius_b0.detach().cpu()
    radius_b1_cpu = radius_b1.detach().cpu()
    cond_cpu = cond.detach().cpu()
    target_center = torch.where(
        cond_cpu[:, None] == 0, center_b0_cpu.unsqueeze(0), center_b1_cpu.unsqueeze(0)
    )
    target_radius = torch.where(cond_cpu == 0, radius_b0_cpu, radius_b1_cpu)
    dist = torch.linalg.norm(x_final - target_center, dim=-1)
    inside = dist <= target_radius
    mask0 = cond_cpu == 0
    mask1 = cond_cpu == 1
    inside0 = inside[mask0]
    inside1 = inside[mask1]
    ratio0 = float(inside0.float().mean().item()) if inside0.numel() > 0 else 0.0
    ratio1 = float(inside1.float().mean().item()) if inside1.numel() > 0 else 0.0

    logger.log(
        {
            "val/inside_ratio": float(inside.float().mean().item()),
            "val/inside_ratio_c0": ratio0,
            "val/inside_ratio_c1": ratio1,
            "val/cond0_count": float(mask0.float().sum().item()),
            "val/cond1_count": float(mask1.float().sum().item()),
            "val/inside_count": float(inside.float().sum().item()),
            "val/final_dist_mean": float(dist.mean().item()),
            "val/final_dist_min": float(dist.min().item()),
            "val/final_dist_max": float(dist.max().item()),
            "val/max_steps": float(max_steps),
        },
        step=step,
    )

    if cfg.val_plot_samples > 0:
        count = min(cfg.val_plot_samples, traj.shape[0])
        if count == 0:
            model.train()
            return
        indices = torch.linspace(0, traj.shape[0] - 1, steps=count).long()
        traj_plot = traj[indices]
        cond_plot = cond_cpu[indices]
        save_path = Path(cfg.output_dir) / f"val_traj_step_{step:06d}.png"
        plot_trajectories_cond(
            traj_plot,
            cond_plot,
            sphere_a,
            sphere_b0,
            sphere_b1,
            save_path,
            title=f"Validation Trajectories (step {step})",
        )
        ratio = float(inside.float().mean().item())
        logger.log_image(
            "val/trajectory",
            save_path,
            caption=f"step {step} | inside_ratio={ratio:.3f}",
            step=step,
        )

    model.train()
    fig, ax = plt.subplots(figsize=(6, 4))
    ax.hist(dt_gt_np, bins=30, alpha=0.6, label="dt_gt", color="steelblue")
    ax.hist(dt_pred_np, bins=30, alpha=0.6, label="dt_pred", color="orange")
    ax.set_title(title)
    ax.set_xlabel("dt")
    ax.set_ylabel("count")
    ax.legend(loc="best")
    fig.tight_layout()
    fig.savefig(save_path, dpi=160)
    plt.close(fig)


def train(
    cfg: TrainConfig,
) -> tuple[ASMamba, tuple[Tensor, Tensor], tuple[Tensor, Tensor], tuple[Tensor, Tensor]]:
    device = torch.device(cfg.device if torch.cuda.is_available() else "cpu")
    set_seed(cfg.seed)
    output_dir = Path(cfg.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
def make_grid(images: Tensor, nrow: int) -> np.ndarray:
    images = images.detach().cpu().numpy()
    b, c, h, w = images.shape
    nrow = max(1, min(nrow, b))
    ncol = math.ceil(b / nrow)
    grid = np.zeros((c, ncol * h, nrow * w), dtype=np.float32)
    for idx in range(b):
        r = idx // nrow
        cidx = idx % nrow
        grid[:, r * h : (r + 1) * h, cidx * w : (cidx + 1) * w] = images[idx]
    return np.transpose(grid, (1, 2, 0))

    model = ASMamba(cfg).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
    logger = SwanLogger(cfg)

    sphere_a, sphere_b0, sphere_b1 = sample_spheres_params(cfg, device)
    center_a, radius_a = sphere_a
    center_b0, radius_b0 = sphere_b0
    center_b1, radius_b1 = sphere_b1
    dist_a_b0 = torch.norm(center_a - center_b0).item()
    dist_a_b1 = torch.norm(center_a - center_b1).item()
    dist_b0_b1 = torch.norm(center_b0 - center_b1).item()
    logger.log(
        {
            "sphere_a/radius": float(radius_a.item()),
            "sphere_a/center_x": float(center_a[0].item()),
            "sphere_a/center_y": float(center_a[1].item()),
            "sphere_a/center_z": float(center_a[2].item()),
            "sphere_b0/radius": float(radius_b0.item()),
            "sphere_b0/center_x": float(center_b0[0].item()),
            "sphere_b0/center_y": float(center_b0[1].item()),
            "sphere_b0/center_z": float(center_b0[2].item()),
            "sphere_b1/radius": float(radius_b1.item()),
            "sphere_b1/center_x": float(center_b1[0].item()),
            "sphere_b1/center_y": float(center_b1[1].item()),
            "sphere_b1/center_z": float(center_b1[2].item()),
            "sphere/dist_a_b0": float(dist_a_b0),
            "sphere/dist_a_b1": float(dist_a_b1),
            "sphere/dist_b0_b1": float(dist_b0_b1),
        }
    )

    global_step = 0
    for epoch in range(cfg.epochs):
        model.train()

        for _ in range(cfg.steps_per_epoch):
            x0, x1, x_seq, t_seq, dt_seq, cond = sample_batch(
                cfg, sphere_a, sphere_b0, sphere_b1, device
            )
            v_gt = x1 - x0

            delta, dt, _ = model(x_seq, cond)

            losses = compute_losses(
                delta=delta,
                dt=dt,
                x_seq=x_seq,
                x0=x0,
                v_gt=v_gt,
                t_seq=t_seq,
                dt_seq=dt_seq,
                cfg=cfg,
            )

            loss = torch.tensor(0.0, device=device)
            if cfg.use_flow_loss and "flow" in losses:
                loss = loss + cfg.lambda_flow * losses["flow"]
            if cfg.use_pos_loss and "pos" in losses:
                loss = loss + cfg.lambda_pos * losses["pos"]
            if cfg.use_dt_loss and "dt" in losses:
                loss = loss + cfg.lambda_dt * losses["dt"]
            if loss.item() == 0.0:
                raise RuntimeError("No loss enabled: enable at least one of flow/pos/dt.")

            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            optimizer.step()

            if global_step % 10 == 0:
                dt_min = float(dt.min().item())
                dt_max = float(dt.max().item())
                dt_mean = float(dt.mean().item())
                dt_gt_min = float(dt_seq.min().item())
                dt_gt_max = float(dt_seq.max().item())
                dt_gt_mean = float(dt_seq.mean().item())
                eps = 1e-6
                clamp_min_ratio = float((dt <= cfg.dt_min + eps).float().mean().item())
                clamp_max_ratio = float((dt >= cfg.dt_max - eps).float().mean().item())
                clamp_any_ratio = float(
                    ((dt <= cfg.dt_min + eps) | (dt >= cfg.dt_max - eps)).float().mean().item()
                )
                logger.log(
                    {
                        "loss/total": float(loss.item()),
                        "loss/flow": float(losses.get("flow", torch.tensor(0.0)).item()),
                        "loss/pos": float(losses.get("pos", torch.tensor(0.0)).item()),
                        "loss/dt": float(losses.get("dt", torch.tensor(0.0)).item()),
                        "dt/pred_mean": dt_mean,
                        "dt/pred_min": dt_min,
                        "dt/pred_max": dt_max,
                        "dt/gt_mean": dt_gt_mean,
                        "dt/gt_min": dt_gt_min,
                        "dt/gt_max": dt_gt_max,
                        "dt/clamp_min_ratio": clamp_min_ratio,
                        "dt/clamp_max_ratio": clamp_max_ratio,
                        "dt/clamp_any_ratio": clamp_any_ratio,
                    },
                    step=global_step,
                )

            if cfg.val_every > 0 and global_step > 0 and global_step % cfg.val_every == 0:
                validate(
                    model,
                    cfg,
                    sphere_a,
                    sphere_b0,
                    sphere_b1,
                    device,
                    logger,
                    global_step,
                )
                dt_hist_path = Path(cfg.output_dir) / f"dt_hist_step_{global_step:06d}.png"
                plot_dt_hist(
                    dt,
                    dt_seq,
                    dt_hist_path,
                    title=f"dt Distribution (step {global_step})",
                )
                logger.log_image(
                    "train/dt_hist",
                    dt_hist_path,
                    caption=f"step {global_step}",
                    step=global_step,
                )
            global_step += 1

    logger.finish()
    return model, sphere_a, sphere_b0, sphere_b1
def save_image_grid(
    images: Tensor, save_path: Path, nrow: int, title: str | None = None
) -> None:
    images = images.clamp(-1.0, 1.0)
    images = (images + 1.0) / 2.0
    grid = make_grid(images, nrow=nrow)
    if grid.ndim == 3 and grid.shape[2] == 1:
        grid = np.repeat(grid, 3, axis=2)
    plt.imsave(save_path, grid)
    if title is not None:
        fig, ax = plt.subplots(figsize=(4, 3))
        ax.imshow(grid)
        ax.set_title(title)
        ax.axis("off")
        fig.tight_layout()
        fig.savefig(save_path, dpi=160)
        plt.close(fig)


def rollout_trajectory(
@@ -597,104 +441,173 @@ def rollout_trajectory(
    return torch.stack(traj, dim=1)


def sphere_wireframe(
    center: Tensor, radius: float, u_steps: int = 24, v_steps: int = 12
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    center_np = center.detach().cpu().numpy()
    u = np.linspace(0, 2 * np.pi, u_steps)
    v = np.linspace(0, np.pi, v_steps)
    x = center_np[0] + radius * np.outer(np.cos(u), np.sin(v))
    y = center_np[1] + radius * np.outer(np.sin(u), np.sin(v))
    z = center_np[2] + radius * np.outer(np.ones_like(u), np.cos(v))
    return x, y, z


def plot_trajectories_cond(
    traj: Tensor,
    cond: Tensor,
    sphere_a: tuple[Tensor, Tensor],
    sphere_b0: tuple[Tensor, Tensor],
    sphere_b1: tuple[Tensor, Tensor],
    save_path: Path,
    title: str = "AS-Mamba Trajectories",
def log_class_samples(
    model: ASMamba,
    cfg: TrainConfig,
    device: torch.device,
    logger: SwanLogger,
    step: int,
) -> None:
    traj = traj.detach().cpu()
    if traj.dim() == 2:
        traj = traj.unsqueeze(0)
    traj_np = traj.numpy()
    cond_np = cond.detach().cpu().numpy()
    if cfg.val_samples_per_class <= 0:
        return
    model.eval()
    max_steps = cfg.seq_len if cfg.val_max_steps <= 0 else cfg.val_max_steps
    input_dim = cfg.channels * cfg.image_size * cfg.image_size

    fig = plt.figure(figsize=(7, 6))
    ax = fig.add_subplot(111, projection="3d")

    for i in range(traj_np.shape[0]):
        color = "tab:green" if cond_np[i] == 0 else "tab:orange"
        ax.plot(
            traj_np[i, :, 0],
            traj_np[i, :, 1],
            traj_np[i, :, 2],
            color=color,
            alpha=0.6,
    for cls in range(cfg.num_classes):
        cond = torch.full(
            (cfg.val_samples_per_class,), cls, device=device, dtype=torch.long
        )

    starts = traj_np[:, 0, :]
    ends = traj_np[:, -1, :]
    ax.scatter(starts[:, 0], starts[:, 1], starts[:, 2], color="blue", s=20, label="Start")
    ax.scatter(ends[:, 0], ends[:, 1], ends[:, 2], color="red", s=20, label="End")

    center_a, radius_a = sphere_a
    center_b0, radius_b0 = sphere_b0
    center_b1, radius_b1 = sphere_b1
    x_a, y_a, z_a = sphere_wireframe(center_a, float(radius_a.item()))
    x_b0, y_b0, z_b0 = sphere_wireframe(center_b0, float(radius_b0.item()))
    x_b1, y_b1, z_b1 = sphere_wireframe(center_b1, float(radius_b1.item()))
    ax.plot_wireframe(x_a, y_a, z_a, color="blue", alpha=0.15, linewidth=0.5)
    ax.plot_wireframe(x_b0, y_b0, z_b0, color="tab:green", alpha=0.2, linewidth=0.5)
    ax.plot_wireframe(x_b1, y_b1, z_b1, color="tab:orange", alpha=0.2, linewidth=0.5)

    ax.set_title(title)
    ax.set_xlabel("X")
    ax.set_ylabel("Y")
    ax.set_zlabel("Z")
    ax.legend(loc="best")
    fig.tight_layout()
    fig.savefig(save_path, dpi=160)
    plt.close(fig)
        x0 = torch.randn(cfg.val_samples_per_class, input_dim, device=device)
        traj = rollout_trajectory(model, x0, cond, max_steps=max_steps)
        x_final = traj[:, -1, :].view(
            cfg.val_samples_per_class, cfg.channels, cfg.image_size, cfg.image_size
        )
        save_path = Path(cfg.output_dir) / f"val_class_{cls}_step_{step:06d}.png"
        save_image_grid(x_final, save_path, nrow=cfg.val_grid_rows)
        logger.log_image(
            f"val/class_{cls}",
            save_path,
            caption=f"class {cls} step {step}",
            step=step,
        )
    model.train()


def plot_dt_hist(
    dt_pred: Tensor,
    dt_gt: Tensor,
    save_path: Path,
    title: str = "dt Distribution",
) -> None:
    dt_pred_np = dt_pred.detach().cpu().numpy().reshape(-1)
    dt_gt_np = dt_gt.detach().cpu().numpy().reshape(-1)
def train(cfg: TrainConfig) -> ASMamba:
    validate_time_config(cfg)
    use_ddp, rank, world_size, device = setup_distributed(cfg)
    set_seed(cfg.seed + rank)
    output_dir = Path(cfg.output_dir)
    if rank == 0:
        output_dir.mkdir(parents=True, exist_ok=True)

    fig, ax = plt.subplots(figsize=(6, 4))
    ax.hist(dt_gt_np, bins=30, alpha=0.6, label="dt_gt", color="steelblue")
    ax.hist(dt_pred_np, bins=30, alpha=0.6, label="dt_pred", color="orange")
    ax.set_title(title)
    ax.set_xlabel("dt")
    ax.set_ylabel("count")
    ax.legend(loc="best")
    fig.tight_layout()
    fig.savefig(save_path, dpi=160)
    plt.close(fig)
    model = ASMamba(cfg).to(device)
    if use_ddp:
        model = nn.parallel.DistributedDataParallel(model, device_ids=[device.index])
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay
    )
    logger = SwanLogger(cfg, enabled=(rank == 0))

    loader, sampler = build_dataloader(cfg, distributed=use_ddp)
    loader_iter = infinite_loader(loader)

    global_step = 0
    for _ in range(cfg.epochs):
        if sampler is not None:
            sampler.set_epoch(global_step)
        model.train()
        for _ in range(cfg.steps_per_epoch):
            batch = next(loader_iter)
            x1 = batch["pixel_values"].to(device)
            cond = batch["labels"].to(device)
            b = x1.shape[0]
            x1 = x1.view(b, -1)
            x0 = torch.randn_like(x1)
            v_gt = x1 - x0
            dt_seq = sample_time_sequence(cfg, b, device)
            t_seq = torch.cumsum(dt_seq, dim=-1)
            t_seq = torch.cat([torch.zeros(b, 1, device=device), t_seq[:, :-1]], dim=-1)
            x_seq = x0[:, None, :] + t_seq[:, :, None] * v_gt[:, None, :]

            delta, dt, _ = model(x_seq, cond)

            losses = compute_losses(
                delta=delta,
                dt=dt,
                x_seq=x_seq,
                x0=x0,
                v_gt=v_gt,
                t_seq=t_seq,
                dt_seq=dt_seq,
                cfg=cfg,
            )

            loss = torch.tensor(0.0, device=device)
            if cfg.use_flow_loss and "flow" in losses:
                loss = loss + cfg.lambda_flow * losses["flow"]
            if cfg.use_pos_loss and "pos" in losses:
                loss = loss + cfg.lambda_pos * losses["pos"]
            if cfg.use_dt_loss and "dt" in losses:
                loss = loss + cfg.lambda_dt * losses["dt"]
            if loss.item() == 0.0:
                raise RuntimeError(
                    "No loss enabled: enable at least one of flow/pos/dt."
                )

            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            optimizer.step()

            if global_step % 10 == 0:
                dt_min = float(dt.min().item())
                dt_max = float(dt.max().item())
                dt_mean = float(dt.mean().item())
                dt_gt_min = float(dt_seq.min().item())
                dt_gt_max = float(dt_seq.max().item())
                dt_gt_mean = float(dt_seq.mean().item())
                eps = 1e-6
                clamp_min_ratio = float((dt <= cfg.dt_min + eps).float().mean().item())
                clamp_max_ratio = float((dt >= cfg.dt_max - eps).float().mean().item())
                clamp_any_ratio = float(
                    ((dt <= cfg.dt_min + eps) | (dt >= cfg.dt_max - eps))
                    .float()
                    .mean()
                    .item()
                )
                logger.log(
                    {
                        "loss/total": float(loss.item()),
                        "loss/flow": float(
                            losses.get("flow", torch.tensor(0.0)).item()
                        ),
                        "loss/pos": float(losses.get("pos", torch.tensor(0.0)).item()),
                        "loss/dt": float(losses.get("dt", torch.tensor(0.0)).item()),
                        "dt/pred_mean": dt_mean,
                        "dt/pred_min": dt_min,
                        "dt/pred_max": dt_max,
                        "dt/gt_mean": dt_gt_mean,
                        "dt/gt_min": dt_gt_min,
                        "dt/gt_max": dt_gt_max,
                        "dt/clamp_min_ratio": clamp_min_ratio,
                        "dt/clamp_max_ratio": clamp_max_ratio,
                        "dt/clamp_any_ratio": clamp_any_ratio,
                    },
                    step=global_step,
                )

            if (
                cfg.val_every > 0
                and global_step > 0
                and global_step % cfg.val_every == 0
                and rank == 0
            ):
                log_class_samples(unwrap_model(model), cfg, device, logger, global_step)
                dt_hist_path = (
                    Path(cfg.output_dir) / f"dt_hist_step_{global_step:06d}.png"
                )
                plot_dt_hist(
                    dt,
                    dt_seq,
                    dt_hist_path,
                    title=f"dt Distribution (step {global_step})",
                )
                logger.log_image(
                    "train/dt_hist",
                    dt_hist_path,
                    caption=f"step {global_step}",
                    step=global_step,
                )

            global_step += 1

    logger.finish()
    if use_ddp:
        torch.distributed.destroy_process_group()
    return unwrap_model(model)


def run_training_and_plot(cfg: TrainConfig) -> Path:
    model, sphere_a, sphere_b0, sphere_b1 = train(cfg)
    device = next(model.parameters()).device

    plot_samples = max(1, cfg.val_plot_samples)
    x0 = sample_points_in_sphere(
        sphere_a[0], float(sphere_a[1].item()), plot_samples, device
    )
    max_steps = cfg.seq_len if cfg.val_max_steps <= 0 else cfg.val_max_steps
    cond = torch.randint(0, 2, (plot_samples,), device=device, dtype=torch.long)
    traj = rollout_trajectory(model, x0, cond, max_steps=max_steps)
    output_dir = Path(cfg.output_dir)
    save_path = output_dir / "as_mamba_trajectory.png"
    plot_trajectories_cond(traj, cond, sphere_a, sphere_b0, sphere_b1, save_path)
    return save_path
    train(cfg)
    return Path(cfg.output_dir)
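A minimal usage sketch, assuming the module is importable as as_mamba and keeps the TrainConfig/train API shown in the diff; the torchrun invocation is the standard way to populate the WORLD_SIZE, RANK, and LOCAL_RANK variables that setup_distributed reads.

from as_mamba import TrainConfig, train

cfg = TrainConfig()       # defaults target MNIST; d_model is inferred from 1 * 28 * 28
cfg.use_ddp = False       # single-process run
model = train(cfg)        # returns the unwrapped ASMamba; plots land in cfg.output_dir

# Multi-GPU sketch: set cfg.use_ddp = True and launch the script with torchrun,
# e.g. torchrun --nproc_per_node=4 <entry_script>.py, so that WORLD_SIZE, RANK,
# and LOCAL_RANK are set for setup_distributed.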