657 lines
21 KiB
Python
657 lines
21 KiB
Python
from __future__ import annotations
|
|
|
|
import math
|
|
import os
|
|
from dataclasses import asdict, dataclass
|
|
from pathlib import Path
|
|
from typing import Iterator, Optional
|
|
|
|
import lpips
|
|
|
|
os.environ.setdefault("MPLCONFIGDIR", "/tmp/mamba_diffusion_mplconfig")
|
|
import matplotlib
|
|
|
|
matplotlib.use("Agg")
|
|
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from datasets import load_dataset
|
|
from matplotlib import pyplot as plt
|
|
from torch import Tensor, nn
|
|
from torch.func import jvp
|
|
from torch.utils.data import DataLoader
|
|
from torch.utils.data.distributed import DistributedSampler
|
|
|
|
from mamba2_minimal import InferenceCache, Mamba2, Mamba2Config, RMSNorm
|
|
|
|
|
|
# Number of Euler integration steps used for validation sampling; pinned by
# validate_config so validation grids stay comparable across runs.
FIXED_VAL_SAMPLING_STEPS = 5

# Descending sampling times (t=1 noise -> t=0 clean); length == steps + 1.
FIXED_VAL_TIME_GRID = (1.0, 0.8, 0.6, 0.4, 0.2, 0.0)
|
|
|
|
|
|
@dataclass
class TrainConfig:
    """Hyper-parameters and run settings for MeanFlow training on MNIST."""

    # Reproducibility / hardware.
    seed: int = 42
    device: str = "cuda"
    # Optimisation schedule.
    epochs: int = 50
    steps_per_epoch: int = 200
    batch_size: int = 128
    seq_len: int = 5  # number of diffusion time blocks; validate_config pins this to 5
    lr: float = 2e-4
    weight_decay: float = 1e-2
    max_grad_norm: float = 1.0
    # Loss weights (total = lambda_flow * flow + lambda_perceptual * perceptual).
    lambda_flow: float = 1.0
    lambda_perceptual: float = 2.0
    # Data.
    num_classes: int = 10
    image_size: int = 28
    channels: int = 1
    num_workers: int = 8
    dataset_name: str = "ylecun/mnist"
    dataset_split: str = "train"
    # Mamba2 backbone. d_model == 0 is a sentinel meaning "derive from the
    # flattened image size" (see ASMamba.__init__).
    d_model: int = 0
    n_layer: int = 6
    d_state: int = 64
    d_conv: int = 4
    expand: int = 2
    headdim: int = 32
    chunk_size: int = 1
    use_residual: bool = False
    # Output / experiment tracking.
    output_dir: str = "outputs"
    project: str = "as-mamba-mnist"
    run_name: str = "mnist-meanflow"
    # Validation sampling.
    val_every: int = 200
    val_samples_per_class: int = 8
    val_grid_rows: int = 4
    val_sampling_steps: int = FIXED_VAL_SAMPLING_STEPS
    time_grid_size: int = 256
    use_ddp: bool = False
|
|
|
|
|
|
class AdaLNZero(nn.Module):
    """Adaptive norm modulation: RMS-normalise, then apply a conditioning-
    derived scale and shift.

    The modulation projection is zero-initialised, so at the start of
    training the layer reduces to plain RMSNorm (scale=1, shift=0).
    """

    def __init__(self, d_model: int) -> None:
        super().__init__()
        self.norm = RMSNorm(d_model)
        self.mod = nn.Linear(d_model, 2 * d_model)
        for tensor in (self.mod.weight, self.mod.bias):
            nn.init.zeros_(tensor)

    def forward(self, x: Tensor, cond: Tensor) -> Tensor:
        """Modulate normalised `x` with scale/shift computed from `cond`."""
        normed = self.norm(x)
        # Broadcast the per-sample modulation over the sequence dimension.
        scale, shift = self.mod(cond).unsqueeze(1).chunk(2, dim=-1)
        return normed * (1 + scale) + shift
|
|
|
|
|
|
class Mamba2Backbone(nn.Module):
    """Stack of Mamba2 mixer layers, each preceded by AdaLN-Zero conditioning,
    followed by a final RMSNorm."""

    def __init__(self, args: Mamba2Config, use_residual: bool = True) -> None:
        super().__init__()
        self.args = args
        self.use_residual = use_residual
        # One (mixer, adaln) pair per layer; ModuleDict keeps both registered.
        self.layers = nn.ModuleList(
            [
                nn.ModuleDict(
                    dict(
                        mixer=Mamba2(args),
                        adaln=AdaLNZero(args.d_model),
                    )
                )
                for _ in range(args.n_layer)
            ]
        )
        self.norm_f = RMSNorm(args.d_model)

    def forward(
        self, x: Tensor, cond: Tensor, h: Optional[list[InferenceCache]] = None
    ) -> tuple[Tensor, list[InferenceCache]]:
        """Run the layer stack.

        Args:
            x: (batch, seq, d_model) input sequence — callers in this file
               always pass 3-D inputs.
            cond: per-sample conditioning vector for every AdaLN-Zero block.
            h: optional per-layer inference caches. NOTE(review): when given,
               the list is mutated in place and also returned.

        Returns:
            (features, caches) with features RMS-normalised.
        """
        if h is None:
            h = [None for _ in range(self.args.n_layer)]

        for i, layer in enumerate(self.layers):
            x_mod = layer["adaln"](x, cond)
            y, h[i] = layer["mixer"](x_mod, h[i])
            # Residual connections are optional (TrainConfig defaults them off).
            x = x + y if self.use_residual else y

        x = self.norm_f(x)
        return x, h
|
|
|
|
|
|
def sinusoidal_embedding(t: Tensor, dim: int) -> Tensor:
    """Embed scalar times `t` into `dim` sinusoidal features.

    Frequencies follow the transformer convention (geometric decay from 1 to
    1/10000). An odd `dim` is handled by zero-padding one trailing feature.
    The result has `t`'s dtype and a new trailing axis of size `dim`.
    """
    if dim <= 0:
        raise ValueError("sinusoidal embedding dim must be > 0")
    half = dim // 2
    if half == 0:
        # dim == 1: no sin/cos pair fits, so the embedding is all zeros.
        return torch.zeros(*t.shape, dim, device=t.device, dtype=t.dtype)
    denom = max(half - 1, 1)
    exponents = -math.log(10000.0) * torch.arange(
        half, device=t.device, dtype=torch.float32
    ) / denom
    angles = t.to(torch.float32).unsqueeze(-1) * exponents.exp()
    emb = torch.cat((angles.sin(), angles.cos()), dim=-1)
    if dim % 2 == 1:
        emb = torch.cat((emb, emb.new_zeros(*emb.shape[:-1], 1)), dim=-1)
    return emb.to(t.dtype)
|
|
|
|
|
|
def safe_time_divisor(t: Tensor) -> Tensor:
    """Return `t` with all non-positive entries replaced by machine epsilon,
    so it can safely be used as a divisor (e.g. at t == 0)."""
    tiny = torch.finfo(t.dtype).eps
    return t.where(t > 0, torch.full_like(t, tiny))
|
|
|
|
|
|
class ASMamba(nn.Module):
    """MeanFlow model: a Mamba2 backbone over flattened images that predicts
    the clean image from (z_t, r, t, class)."""

    def __init__(self, cfg: TrainConfig) -> None:
        super().__init__()
        self.cfg = cfg
        input_dim = cfg.channels * cfg.image_size * cfg.image_size
        # cfg.d_model == 0 is a sentinel: derive width from the flattened
        # image size. NOTE: this mutates the shared cfg object.
        if cfg.d_model == 0:
            cfg.d_model = input_dim
        if cfg.d_model != input_dim:
            raise ValueError(
                f"d_model must equal flattened image dim ({input_dim}) when input_proj is disabled."
            )

        args = Mamba2Config(
            d_model=cfg.d_model,
            n_layer=cfg.n_layer,
            d_state=cfg.d_state,
            d_conv=cfg.d_conv,
            expand=cfg.expand,
            headdim=cfg.headdim,
            chunk_size=cfg.chunk_size,
        )
        self.backbone = Mamba2Backbone(args, use_residual=cfg.use_residual)
        self.cond_emb = nn.Embedding(cfg.num_classes, cfg.d_model)
        self.clean_head = nn.Linear(cfg.d_model, input_dim)

    def forward(
        self,
        z_t: Tensor,
        r: Tensor,
        t: Tensor,
        cond: Tensor,
        h: Optional[list[InferenceCache]] = None,
    ) -> tuple[Tensor, list[InferenceCache]]:
        """Predict clean images for each time block.

        Args:
            z_t: (batch, seq, input_dim) noisy inputs.
            r, t: block start/end times; accepted as (batch,), (batch, seq)
                or (batch, seq, 1) and normalised to (batch, seq).
            cond: integer class ids for the embedding table.
            h: optional inference caches threaded through the backbone.

        Returns:
            (x_pred in (-1, 1) via tanh, caches).
        """
        # Normalise r/t to (batch, seq) regardless of how callers shaped them.
        if r.dim() == 1:
            r = r.unsqueeze(1)
        elif r.dim() == 3 and r.shape[-1] == 1:
            r = r.squeeze(-1)
        if t.dim() == 1:
            t = t.unsqueeze(1)
        elif t.dim() == 3 and t.shape[-1] == 1:
            t = t.squeeze(-1)

        r = r.to(dtype=z_t.dtype)
        t = t.to(dtype=z_t.dtype)
        # Both times are injected additively as sinusoidal embeddings of the
        # same width as the flattened image.
        z_t = z_t + sinusoidal_embedding(r, z_t.shape[-1]) + sinusoidal_embedding(
            t, z_t.shape[-1]
        )
        cond_vec = self.cond_emb(cond)
        feats, h = self.backbone(z_t, cond_vec, h)
        x_pred = torch.tanh(self.clean_head(feats))
        return x_pred, h

    def step(
        self, z_t: Tensor, r: Tensor, t: Tensor, cond: Tensor, h: list[InferenceCache]
    ) -> tuple[Tensor, list[InferenceCache]]:
        """Single-step convenience wrapper: (batch, dim) in, (batch, dim) out."""
        x_pred, h = self.forward(
            z_t.unsqueeze(1), r.unsqueeze(1), t.unsqueeze(1), cond, h
        )
        return x_pred[:, 0, :], h

    def init_cache(self, batch_size: int, device: torch.device) -> list[InferenceCache]:
        """Allocate one fresh InferenceCache per backbone layer."""
        return [
            InferenceCache.alloc(batch_size, self.backbone.args, device=device)
            for _ in range(self.backbone.args.n_layer)
        ]
|
|
|
|
|
|
class SwanLogger:
    """Failure-tolerant wrapper around the optional `swanlab` tracker.

    If `swanlab` is not importable or `init` fails, the logger disables
    itself and every method becomes a no-op — training never depends on it.
    """

    def __init__(self, cfg: TrainConfig, enabled: bool = True) -> None:
        self.enabled = enabled
        self._swan = None  # the imported swanlab module, if available
        self._run = None  # the run object returned by swanlab.init
        if not self.enabled:
            return
        try:
            import swanlab  # type: ignore

            self._swan = swanlab
            self._run = self._swan.init(
                project=cfg.project,
                experiment_name=cfg.run_name,
                config=asdict(cfg),
            )
            self.enabled = True
        except Exception:
            # Best-effort logging: any import/init failure silently disables.
            self.enabled = False

    def log(self, metrics: dict, step: int | None = None) -> None:
        """Log a metrics dict, preferring the run object over the module."""
        if not self.enabled:
            return
        target = self._run if self._run is not None else self._swan
        if step is None:
            target.log(metrics)
            return
        try:
            target.log(metrics, step=step)
        except TypeError:
            # Some swanlab versions take no `step` kwarg; fold it into the payload.
            payload = dict(metrics)
            payload["step"] = step
            target.log(payload)

    def log_image(
        self,
        key: str,
        image_path: Path,
        caption: str | None = None,
        step: int | None = None,
    ) -> None:
        """Log an image file under `key`; no-op when disabled."""
        if not self.enabled:
            return
        image = self._swan.Image(str(image_path), caption=caption)
        self.log({key: image}, step=step)

    def finish(self) -> None:
        """Close the run, tolerating API variants that lack `finish`."""
        if not self.enabled:
            return
        finish = None
        if self._run is not None:
            finish = getattr(self._run, "finish", None)
        if finish is None:
            finish = getattr(self._swan, "finish", None)
        if callable(finish):
            finish()
|
|
|
|
|
|
class LPIPSPerceptualLoss(nn.Module):
    """Frozen VGG-based LPIPS perceptual distance between image batches.

    TORCH_HOME is redirected to `<output_dir>/.torch` so the LPIPS weights
    are cached alongside the run outputs; the network stays in eval mode
    with gradients disabled.
    """

    def __init__(self, cfg: TrainConfig) -> None:
        super().__init__()
        torch_home = Path(cfg.output_dir) / ".torch"
        torch_home.mkdir(parents=True, exist_ok=True)
        # Redirect torch's weight cache before LPIPS triggers any download.
        os.environ["TORCH_HOME"] = str(torch_home)
        self.channels = cfg.channels
        self.loss_fn = lpips.LPIPS(net="vgg", verbose=False)
        self.loss_fn.eval()
        for param in self.loss_fn.parameters():
            param.requires_grad_(False)

    def _prepare_images(self, images: Tensor) -> Tensor:
        """Expand grayscale (N,1,H,W) to 3 channels; reject other widths."""
        if images.shape[1] == 1:
            return images.repeat(1, 3, 1, 1)
        if images.shape[1] != 3:
            raise ValueError(
                "LPIPS perceptual loss expects 1-channel or 3-channel images. "
                f"Got {images.shape[1]} channels."
            )
        return images

    def forward(self, pred: Tensor, target: Tensor) -> Tensor:
        """Mean LPIPS distance over the batch.

        Inputs are NCHW; in this file they arrive in [-1, 1] (dataset
        normalisation / tanh head), which matches LPIPS's default range.
        """
        pred_rgb = self._prepare_images(pred)
        target_rgb = self._prepare_images(target)
        return self.loss_fn(pred_rgb, target_rgb).mean()
|
|
|
|
|
|
def set_seed(seed: int) -> None:
    """Seed torch's CPU RNG and, when CUDA is present, every device RNG."""
    torch.manual_seed(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
|
|
|
|
|
|
def setup_distributed(cfg: TrainConfig) -> tuple[bool, int, int, torch.device]:
    """Read torchrun environment variables and optionally start a NCCL group.

    Returns (use_ddp, rank, world_size, device). DDP is only enabled when
    both cfg.use_ddp is set and more than one process is launched.
    """
    env = os.environ
    world_size = int(env.get("WORLD_SIZE", "1"))
    rank = int(env.get("RANK", "0"))
    local_rank = int(env.get("LOCAL_RANK", "0"))
    use_ddp = bool(cfg.use_ddp and world_size > 1)
    if not use_ddp:
        # Single-process fallback; honour cfg.device only when CUDA exists.
        fallback = cfg.device if torch.cuda.is_available() else "cpu"
        return use_ddp, rank, world_size, torch.device(fallback)
    torch.distributed.init_process_group(backend="nccl", init_method="env://")
    torch.cuda.set_device(local_rank)
    return use_ddp, rank, world_size, torch.device("cuda", local_rank)
|
|
|
|
|
|
def unwrap_model(model: nn.Module) -> nn.Module:
    """Return the wrapped module for DDP-style wrappers, else `model` itself."""
    return getattr(model, "module", model)
|
|
|
|
|
|
def validate_config(cfg: TrainConfig) -> None:
|
|
if cfg.seq_len != 5:
|
|
raise ValueError(
|
|
f"seq_len must be 5 for the required 5-block training setup (got {cfg.seq_len})."
|
|
)
|
|
if cfg.time_grid_size < 2:
|
|
raise ValueError("time_grid_size must be >= 2.")
|
|
if cfg.lambda_perceptual < 0:
|
|
raise ValueError("lambda_perceptual must be >= 0.")
|
|
if cfg.max_grad_norm <= 0:
|
|
raise ValueError("max_grad_norm must be > 0.")
|
|
if cfg.val_sampling_steps != FIXED_VAL_SAMPLING_STEPS:
|
|
raise ValueError(
|
|
f"val_sampling_steps is fixed to {FIXED_VAL_SAMPLING_STEPS} for validation sampling."
|
|
)
|
|
|
|
|
|
def sample_block_times(
    cfg: TrainConfig, batch_size: int, device: torch.device, dtype: torch.dtype
) -> tuple[Tensor, Tensor]:
    """Draw per-sample time boundaries for the `seq_len` contiguous blocks.

    Interior cut points are logit-normal (sigmoid of scaled Gaussian) with a
    10% per-cut chance of a plain uniform draw; sorted cuts plus the fixed
    endpoints 0 and 1 partition [0, 1] into `seq_len` intervals.

    Returns:
        (r_seq, t_seq): left and right interval edges, each (batch, seq_len).
    """
    n_cuts = cfg.seq_len - 1
    gauss = torch.randn(batch_size, n_cuts, device=device, dtype=dtype)
    logit_normal = torch.sigmoid(gauss * math.sqrt(0.8))
    uniform = torch.rand(batch_size, n_cuts, device=device, dtype=dtype)
    pick_uniform = torch.rand(batch_size, n_cuts, device=device) < 0.1
    cuts = torch.where(pick_uniform, uniform, logit_normal).sort(dim=-1).values
    zeros = torch.zeros(batch_size, 1, device=device, dtype=dtype)
    ones = torch.ones(batch_size, 1, device=device, dtype=dtype)
    edges = torch.cat((zeros, cuts, ones), dim=-1)
    return edges[:, :-1], edges[:, 1:]
|
|
|
|
|
|
def build_dataloader(
    cfg: TrainConfig, distributed: bool = False
) -> tuple[DataLoader, Optional[DistributedSampler]]:
    """Build the (optionally distributed) training DataLoader.

    Images are rescaled from [0, 255] to [-1, 1] floats and returned as CHW
    tensors under "pixel_values", labels as int64 under "labels". Returns
    (loader, sampler); sampler is None outside DDP.
    """
    ds = load_dataset(cfg.dataset_name, split=cfg.dataset_split)

    def transform(example):
        # `with_transform` may hand us one example or a batch (lists), and
        # the image/label columns may be named "img"/"image", "label"/"labels".
        image = example.get("img", example.get("image"))
        label = example.get("label", example.get("labels"))
        if isinstance(image, list):
            # Batched path.
            arr = np.stack([np.array(im, dtype=np.float32) for im in image], axis=0)
            arr = arr / 127.5 - 1.0
            if arr.ndim == 3:
                # Grayscale (B, H, W) -> (B, 1, H, W).
                tensor = torch.from_numpy(arr).unsqueeze(1)
            else:
                # Colour (B, H, W, C) -> (B, C, H, W).
                tensor = torch.from_numpy(arr).permute(0, 3, 1, 2)
            labels = torch.tensor(label, dtype=torch.long)
            return {"pixel_values": tensor, "labels": labels}
        # Single-example path, same normalisation and layout rules.
        arr = np.array(image, dtype=np.float32) / 127.5 - 1.0
        if arr.ndim == 2:
            tensor = torch.from_numpy(arr).unsqueeze(0)
        else:
            tensor = torch.from_numpy(arr).permute(2, 0, 1)
        return {"pixel_values": tensor, "labels": torch.tensor(label, dtype=torch.long)}

    ds = ds.with_transform(transform)
    sampler = DistributedSampler(ds, shuffle=True) if distributed else None
    loader = DataLoader(
        ds,
        batch_size=cfg.batch_size,
        shuffle=(sampler is None),
        sampler=sampler,
        num_workers=cfg.num_workers,
        drop_last=True,
        pin_memory=torch.cuda.is_available(),
    )
    return loader, sampler
|
|
|
|
|
|
def infinite_loader(loader: DataLoader) -> Iterator[dict]:
    """Cycle over `loader` forever, restarting each time it is exhausted."""
    while True:
        yield from loader
|
|
|
|
|
|
def build_noisy_sequence(
    x0: Tensor,
    eps: Tensor,
    t_seq: Tensor,
) -> tuple[Tensor, Tensor]:
    """Linearly interpolate clean data and noise at each block time.

    z_t = (1 - t) * x0 + t * eps broadcast over (batch, seq, dim); the
    ground-truth straight-flow velocity eps - x0 is constant in t.
    """
    t = t_seq.unsqueeze(-1)
    x0_b = x0.unsqueeze(1)
    eps_b = eps.unsqueeze(1)
    z_t = (1.0 - t) * x0_b + t * eps_b
    v_gt = eps - x0
    return z_t, v_gt
|
|
|
|
|
|
def compute_losses(
    model: nn.Module,
    perceptual_loss_fn: LPIPSPerceptualLoss,
    x0: Tensor,
    z_t: Tensor,
    v_gt: Tensor,
    r_seq: Tensor,
    t_seq: Tensor,
    cond: Tensor,
    cfg: TrainConfig,
) -> tuple[dict[str, Tensor], Tensor]:
    """Compute the MeanFlow training losses.

    Args:
        x0: (batch, input_dim) flattened clean images.
        z_t: (batch, seq, input_dim) noisy inputs per time block.
        v_gt: (batch, input_dim) ground-truth velocity eps - x0.
        r_seq, t_seq: (batch, seq) block start / end times.

    Returns:
        (losses dict with "flow", "perceptual", "total"; x_pred clean-image
        predictions for the (r, t) pairs).
    """
    seq_len = z_t.shape[1]
    # t is a divisor below; guard t == 0 entries.
    safe_t = safe_time_divisor(t_seq).unsqueeze(-1)

    # Average velocity over [r, t], derived from the clean-image prediction:
    # u = (z_t - x_pred) / t.
    x_pred, _ = model(z_t, r_seq, t_seq, cond)
    u = (z_t - x_pred) / safe_t

    # Instantaneous velocity at r == t, used only as the JVP tangent
    # direction (hence detached).
    x_pred_inst, _ = model(z_t, t_seq, t_seq, cond)
    v_inst = ((z_t - x_pred_inst) / safe_t).detach()

    def u_fn(z_in: Tensor, r_in: Tensor, t_in: Tensor) -> Tensor:
        # Same u = (z - x_pred) / t mapping, re-expressed for jvp.
        x_pred_local, _ = model(z_in, r_in, t_in, cond)
        return (z_in - x_pred_local) / safe_time_divisor(t_in).unsqueeze(-1)

    # Total derivative du/dt along the flow: tangents (dz/dt, dr/dt, dt/dt)
    # = (v_inst, 0, 1).
    _, dudt = jvp(
        u_fn,
        (z_t, r_seq, t_seq),
        (v_inst, torch.zeros_like(r_seq), torch.ones_like(t_seq)),
    )
    # MeanFlow consistency target: u + (t - r) * du/dt should equal the
    # straight-flow velocity; the derivative term is treated as a constant.
    corrected_velocity = u + (t_seq - r_seq).unsqueeze(-1) * dudt.detach()
    target_velocity = v_gt[:, None, :].expand(-1, seq_len, -1)

    # Reshape predictions/targets into (batch*seq) image batches for LPIPS.
    pred_images = x_pred.reshape(
        x0.shape[0] * seq_len, cfg.channels, cfg.image_size, cfg.image_size
    )
    target_images = (
        x0.reshape(x0.shape[0], cfg.channels, cfg.image_size, cfg.image_size)
        .unsqueeze(1)
        .expand(-1, seq_len, -1, -1, -1)
        .reshape(x0.shape[0] * seq_len, cfg.channels, cfg.image_size, cfg.image_size)
    )

    losses = {
        "flow": F.mse_loss(corrected_velocity, target_velocity),
        "perceptual": perceptual_loss_fn(pred_images, target_images),
    }
    losses["total"] = cfg.lambda_flow * losses["flow"] + cfg.lambda_perceptual * losses[
        "perceptual"
    ]
    return losses, x_pred
|
|
|
|
|
|
def make_grid(images: Tensor, nrow: int) -> np.ndarray:
    """Tile a batch of CHW images into a single HWC numpy array.

    `nrow` is the number of images per row (clamped to [1, batch]); unfilled
    cells in the last row stay zero.
    """
    arr = images.detach().cpu().numpy()
    b, c, h, w = arr.shape
    per_row = max(1, min(nrow, b))
    n_rows = math.ceil(b / per_row)
    canvas = np.zeros((c, n_rows * h, per_row * w), dtype=np.float32)
    for idx, img in enumerate(arr):
        row, col = divmod(idx, per_row)
        canvas[:, row * h : (row + 1) * h, col * w : (col + 1) * w] = img
    # CHW -> HWC for image writers.
    return np.transpose(canvas, (1, 2, 0))
|
|
|
|
|
|
def save_image_grid(
    images: Tensor, save_path: Path, nrow: int, title: str | None = None
) -> None:
    """Save a batch of images as one tiled grid PNG at `save_path`.

    Values are clamped from [-1, 1] and rescaled to [0, 1]. Without `title`
    the grid is written directly via `plt.imsave`; with `title` it is
    rendered through a matplotlib figure instead. (Previously the file was
    written twice in the titled case — imsave first, then overwritten by
    savefig — so the redundant first write is now skipped.)
    """
    images = images.clamp(-1.0, 1.0)
    images = (images + 1.0) / 2.0
    grid = make_grid(images, nrow=nrow)
    if grid.ndim == 3 and grid.shape[2] == 1:
        # Expand grayscale to 3 channels; a trailing singleton channel is not
        # a valid RGB(A) image for imsave/imshow.
        grid = np.repeat(grid, 3, axis=2)
    if title is None:
        plt.imsave(save_path, grid)
        return
    fig, ax = plt.subplots(figsize=(4, 3))
    ax.imshow(grid)
    ax.set_title(title)
    ax.axis("off")
    fig.tight_layout()
    fig.savefig(save_path, dpi=160)
    plt.close(fig)
|
|
|
|
|
|
def sample_class_images(
    model: ASMamba,
    cfg: TrainConfig,
    device: torch.device,
    cond: Tensor,
) -> Tensor:
    """Generate one image per entry of `cond` by Euler integration.

    Integrates the model's instantaneous velocity along FIXED_VAL_TIME_GRID
    (t: 1 -> 0) starting from pure noise. Returns images shaped
    (len(cond), channels, H, W). NOTE: leaves the model in eval mode;
    callers (log_class_samples) restore train mode.
    """
    model.eval()
    input_dim = cfg.channels * cfg.image_size * cfg.image_size
    # Start from pure Gaussian noise at t = 1.
    z_t = torch.randn(cond.shape[0], input_dim, device=device)
    time_grid = torch.tensor(FIXED_VAL_TIME_GRID, device=device)

    with torch.no_grad():
        for step_idx in range(FIXED_VAL_SAMPLING_STEPS):
            t_cur = torch.full(
                (cond.shape[0],),
                float(time_grid[step_idx].item()),
                device=device,
            )
            t_next = time_grid[step_idx + 1]
            # Passing r == t asks the model for the instantaneous velocity
            # (zero-width time block).
            x_pred, _ = model(
                z_t.unsqueeze(1),
                t_cur.unsqueeze(1),
                t_cur.unsqueeze(1),
                cond,
            )
            x_pred = x_pred[:, 0, :]
            u_inst = (z_t - x_pred) / safe_time_divisor(t_cur).unsqueeze(-1)
            # Euler step; t_next < t_cur, so this moves towards the data.
            z_t = z_t + (t_next - t_cur).unsqueeze(-1) * u_inst

    return z_t.view(cond.shape[0], cfg.channels, cfg.image_size, cfg.image_size)
|
|
|
|
|
|
def log_class_samples(
    model: ASMamba,
    cfg: TrainConfig,
    device: torch.device,
    logger: SwanLogger,
    step: int,
) -> None:
    """Sample a grid of images for every class, save PNGs and log them.

    No-op when val_samples_per_class <= 0. Restores the model's previous
    train/eval mode before returning.
    """
    if cfg.val_samples_per_class <= 0:
        return
    training_mode = model.training
    model.eval()
    for cls in range(cfg.num_classes):
        # One conditioning vector per requested sample, all the same class.
        cond = torch.full(
            (cfg.val_samples_per_class,), cls, device=device, dtype=torch.long
        )
        x_final = sample_class_images(model, cfg, device, cond)
        save_path = Path(cfg.output_dir) / f"val_class_{cls}_step_{step:06d}.png"
        save_image_grid(x_final, save_path, nrow=cfg.val_grid_rows)
        logger.log_image(
            f"val/class_{cls}",
            save_path,
            caption=f"class {cls} step {step}",
            step=step,
        )
    if training_mode:
        model.train()
|
|
|
|
|
|
def build_perceptual_loss(
    cfg: TrainConfig, device: torch.device, rank: int, use_ddp: bool
) -> LPIPSPerceptualLoss:
    """Construct the LPIPS loss, letting rank 0 download weights first.

    Under DDP, non-zero ranks wait at a barrier so rank 0 can populate the
    shared weight cache; rank 0 then releases them with a second barrier.
    """
    if use_ddp and rank != 0:
        torch.distributed.barrier()
    perceptual_loss_fn = LPIPSPerceptualLoss(cfg).to(device)
    if use_ddp and rank == 0:
        torch.distributed.barrier()
    return perceptual_loss_fn
|
|
|
|
|
|
def train(cfg: TrainConfig) -> ASMamba:
    """Run the full MeanFlow training loop and return the trained model.

    Handles optional DDP setup, per-rank seeding, metric logging every 10
    steps, periodic validation sampling on rank 0, and returns the
    unwrapped (non-DDP) module.
    """
    validate_config(cfg)
    use_ddp, rank, world_size, device = setup_distributed(cfg)
    del world_size  # unused beyond the DDP decision made in setup_distributed
    # Offset the seed by rank so noise differs across processes.
    set_seed(cfg.seed + rank)
    output_dir = Path(cfg.output_dir)
    if rank == 0:
        output_dir.mkdir(parents=True, exist_ok=True)

    model = ASMamba(cfg).to(device)
    if use_ddp:
        model = nn.parallel.DistributedDataParallel(model, device_ids=[device.index])
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay
    )
    perceptual_loss_fn = build_perceptual_loss(cfg, device, rank, use_ddp)
    # Only rank 0 logs to the tracker.
    logger = SwanLogger(cfg, enabled=(rank == 0))

    loader, sampler = build_dataloader(cfg, distributed=use_ddp)
    loader_iter = infinite_loader(loader)

    global_step = 0
    for epoch_idx in range(cfg.epochs):
        if sampler is not None:
            # Reshuffle shards deterministically per epoch under DDP.
            sampler.set_epoch(epoch_idx)
        model.train()
        for _ in range(cfg.steps_per_epoch):
            batch = next(loader_iter)
            x0 = batch["pixel_values"].to(device)
            cond = batch["labels"].to(device)
            b = x0.shape[0]
            # Flatten images: the backbone operates on vectors.
            x0 = x0.view(b, -1)
            eps = torch.randn_like(x0)

            # Random block boundaries and the corresponding noisy sequence.
            r_seq, t_seq = sample_block_times(cfg, b, device, x0.dtype)
            z_t, v_gt = build_noisy_sequence(x0, eps, t_seq)

            losses, _ = compute_losses(
                model=model,
                perceptual_loss_fn=perceptual_loss_fn,
                x0=x0,
                z_t=z_t,
                v_gt=v_gt,
                r_seq=r_seq,
                t_seq=t_seq,
                cond=cond,
                cfg=cfg,
            )

            optimizer.zero_grad(set_to_none=True)
            losses["total"].backward()
            grad_norm = torch.nn.utils.clip_grad_norm_(
                model.parameters(), max_norm=cfg.max_grad_norm
            )
            optimizer.step()

            # Lightweight scalar logging every 10 steps, rank 0 only.
            if global_step % 10 == 0 and rank == 0:
                logger.log(
                    {
                        "loss/total": float(losses["total"].item()),
                        "loss/flow": float(losses["flow"].item()),
                        "loss/perceptual": float(losses["perceptual"].item()),
                        "grad/total_norm": float(grad_norm.item()),
                        "time/r_mean": float(r_seq.mean().item()),
                        "time/t_mean": float(t_seq.mean().item()),
                        "time/zero_block_frac": float((t_seq == r_seq).float().mean().item()),
                    },
                    step=global_step,
                )

            # Periodic validation sampling (skipped at step 0), rank 0 only.
            if (
                cfg.val_every > 0
                and global_step > 0
                and global_step % cfg.val_every == 0
                and rank == 0
            ):
                log_class_samples(unwrap_model(model), cfg, device, logger, global_step)

            global_step += 1

    logger.finish()
    if use_ddp:
        torch.distributed.destroy_process_group()
    return unwrap_model(model)
|
|
|
|
|
|
def run_training_and_plot(cfg: TrainConfig) -> Path:
    """Run the full training loop, then return the directory holding its artifacts."""
    artifacts_dir = Path(cfg.output_dir)
    train(cfg)
    return artifacts_dir
|