614 lines
20 KiB
Python
614 lines
20 KiB
Python
from __future__ import annotations
|
|
|
|
import math
|
|
from dataclasses import asdict, dataclass
|
|
import os
|
|
from pathlib import Path
|
|
from typing import Iterator, Optional
|
|
|
|
import matplotlib
|
|
|
|
matplotlib.use("Agg")
|
|
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from datasets import load_dataset
|
|
from matplotlib import pyplot as plt
|
|
from torch import Tensor, nn
|
|
from torch.utils.data import DataLoader
|
|
from torch.utils.data.distributed import DistributedSampler
|
|
|
|
from mamba2_minimal import InferenceCache, Mamba2, Mamba2Config, RMSNorm
|
|
|
|
|
|
@dataclass
class TrainConfig:
    """Settings for the AS-Mamba MNIST flow-matching training script.

    Field order is part of the interface (positional construction and
    ``asdict`` ordering), so new fields should only be appended.
    """

    # --- runtime ---------------------------------------------------------
    seed: int = 42  # base RNG seed; the trainer adds the process rank
    device: str = "cuda"  # requested device when DDP is not used

    # --- schedule / optimiser --------------------------------------------
    epochs: int = 50  # outer loop iterations in ``train``
    steps_per_epoch: int = 200  # optimiser steps per outer iteration
    batch_size: int = 128
    seq_len: int = 20  # number of time steps sampled per trajectory
    lr: float = 2e-4  # AdamW learning rate
    weight_decay: float = 1e-2  # AdamW weight decay

    # --- step-size (dt) handling -----------------------------------------
    dt_min: float = 1e-3  # lower clamp for the predicted dt
    dt_max: float = 0.06  # upper clamp for predicted dt and cap for sampled dt
    dt_alpha: float = 8.0  # Gamma concentration used to sample dt sequences

    # --- loss weights and switches ---------------------------------------
    lambda_flow: float = 1.0
    lambda_pos: float = 1.0
    lambda_dt: float = 1.0
    use_flow_loss: bool = True
    use_pos_loss: bool = False
    use_dt_loss: bool = True

    # --- data ------------------------------------------------------------
    num_classes: int = 10  # size of the class-conditioning embedding
    image_size: int = 28
    channels: int = 1
    num_workers: int = 8  # DataLoader worker processes
    dataset_name: str = "ylecun/mnist"
    dataset_split: str = "train"

    # --- model (forwarded to Mamba2Config) -------------------------------
    d_model: int = 0  # 0 means "use channels * image_size**2"
    n_layer: int = 6
    d_state: int = 64
    d_conv: int = 4
    expand: int = 2
    headdim: int = 32
    chunk_size: int = 1
    use_residual: bool = False  # add a skip connection around each mixer

    # --- outputs / experiment tracking -----------------------------------
    output_dir: str = "outputs"
    project: str = "as-mamba-mnist"
    run_name: str = "mnist-flow"

    # --- validation ------------------------------------------------------
    val_every: int = 200  # validate every N global steps (<= 0 disables)
    val_samples_per_class: int = 8  # <= 0 skips class sampling
    val_grid_rows: int = 4  # passed as ``nrow`` (tiles per grid row)
    val_max_steps: int = 0  # <= 0 falls back to ``seq_len`` rollout steps

    # --- distributed -----------------------------------------------------
    use_ddp: bool = False  # only takes effect when WORLD_SIZE > 1
|
|
|
|
|
|
class AdaLNZero(nn.Module):
    """RMS-normalise the input, then apply a conditioning-driven scale/shift.

    The modulation projection is zero-initialised, so a freshly built layer
    returns exactly the normalised input regardless of the conditioning.
    """

    def __init__(self, d_model: int) -> None:
        super().__init__()
        self.norm = RMSNorm(d_model)
        # One projection yields both halves: [scale | shift].
        self.mod = nn.Linear(d_model, 2 * d_model)
        for tensor in (self.mod.weight, self.mod.bias):
            nn.init.zeros_(tensor)

    def forward(self, x: Tensor, cond: Tensor) -> Tensor:
        """Return ``norm(x) * (1 + scale) + shift``.

        ``scale`` and ``shift`` are the two halves (last axis) of
        ``mod(cond)``; a singleton axis is inserted at position 1 so they
        broadcast over that axis of ``x``.
        """
        normed = self.norm(x)
        scale, shift = self.mod(cond).unsqueeze(1).chunk(2, dim=-1)
        return normed * (1 + scale) + shift
|
|
|
|
|
|
class Mamba2Backbone(nn.Module):
    """Stack of ``n_layer`` (AdaLNZero -> Mamba2) blocks with a final RMSNorm."""

    def __init__(self, args: Mamba2Config, use_residual: bool = True) -> None:
        super().__init__()
        self.args = args
        self.use_residual = use_residual

        blocks = []
        for _ in range(args.n_layer):
            # Key order (mixer, adaln) is kept so parameter ordering is stable.
            blocks.append(
                nn.ModuleDict(
                    {
                        "mixer": Mamba2(args),
                        "adaln": AdaLNZero(args.d_model),
                    }
                )
            )
        self.layers = nn.ModuleList(blocks)
        self.norm_f = RMSNorm(args.d_model)

    def forward(
        self, x: Tensor, cond: Tensor, h: Optional[list[InferenceCache]] = None
    ) -> tuple[Tensor, list[InferenceCache]]:
        """Run every block and return ``(normalised features, per-layer state)``.

        When ``h`` is omitted, each mixer receives ``None`` as its state.
        A list passed as ``h`` is updated in place and also returned.
        """
        if h is None:
            h = [None] * self.args.n_layer

        for idx, block in enumerate(self.layers):
            modulated = block["adaln"](x, cond)
            mixed, h[idx] = block["mixer"](modulated, h[idx])
            if self.use_residual:
                x = x + mixed
            else:
                x = mixed

        return self.norm_f(x), h
|
|
|
|
|
|
class ASMamba(nn.Module):
    """Class-conditioned Mamba2 model predicting a displacement and a step size.

    The backbone consumes flattened images directly (there is no input
    projection), which is why ``d_model`` must equal
    ``channels * image_size * image_size``.
    """

    def __init__(self, cfg: TrainConfig) -> None:
        """Build the backbone and heads from ``cfg``.

        Note: when ``cfg.d_model`` is 0 it is overwritten on ``cfg`` itself
        with the flattened image dimension.

        Raises:
            ValueError: if a non-zero ``cfg.d_model`` differs from the
                flattened image dimension.
        """
        super().__init__()
        self.cfg = cfg
        self.dt_min = float(cfg.dt_min)
        self.dt_max = float(cfg.dt_max)

        input_dim = cfg.channels * cfg.image_size * cfg.image_size
        if cfg.d_model == 0:
            cfg.d_model = input_dim
        elif cfg.d_model != input_dim:
            raise ValueError(
                f"d_model must equal flattened image dim ({input_dim}) when input_proj is disabled."
            )

        width = cfg.d_model
        backbone_args = Mamba2Config(
            d_model=width,
            n_layer=cfg.n_layer,
            d_state=cfg.d_state,
            d_conv=cfg.d_conv,
            expand=cfg.expand,
            headdim=cfg.headdim,
            chunk_size=cfg.chunk_size,
        )
        self.backbone = Mamba2Backbone(backbone_args, use_residual=cfg.use_residual)
        self.cond_emb = nn.Embedding(cfg.num_classes, width)
        self.delta_head = nn.Linear(width, input_dim)
        self.dt_head = nn.Sequential(
            nn.Linear(width, width),
            nn.SiLU(),
            nn.Linear(width, 1),
        )

    def forward(
        self, x: Tensor, cond: Tensor, h: Optional[list[InferenceCache]] = None
    ) -> tuple[Tensor, Tensor, list[InferenceCache]]:
        """Return ``(delta, dt, state)`` for a sequence of inputs.

        ``cond`` holds class indices for the embedding. ``dt`` is
        ``softplus`` of the dt head output (last axis squeezed), clamped to
        ``[dt_min, dt_max]``.
        """
        features, h = self.backbone(x, self.cond_emb(cond), h)
        delta = self.delta_head(features)
        raw_dt = self.dt_head(features).squeeze(-1)
        dt = F.softplus(raw_dt).clamp(min=self.dt_min, max=self.dt_max)
        return delta, dt, h

    def step(
        self, x: Tensor, cond: Tensor, h: list[InferenceCache]
    ) -> tuple[Tensor, Tensor, list[InferenceCache]]:
        """Advance a single time step.

        ``x`` gets a length-1 axis inserted at position 1 before the forward
        pass; that axis is stripped again from ``delta`` and ``dt``.
        """
        delta_seq, dt_seq, h = self.forward(x.unsqueeze(1), cond, h)
        return delta_seq[:, 0, :], dt_seq[:, 0], h

    def init_cache(self, batch_size: int, device: torch.device) -> list[InferenceCache]:
        """Allocate one ``InferenceCache`` per backbone layer."""
        backbone_args = self.backbone.args
        caches = []
        for _ in range(backbone_args.n_layer):
            caches.append(InferenceCache.alloc(batch_size, backbone_args, device=device))
        return caches
|
|
|
|
|
|
class SwanLogger:
    """Best-effort wrapper around the optional ``swanlab`` experiment tracker.

    If ``swanlab`` cannot be imported or initialised the logger disables
    itself and every method becomes a no-op, so training never depends on
    the tracker being available. The reason is reported through a
    ``RuntimeWarning`` instead of being swallowed silently.
    """

    def __init__(self, cfg: TrainConfig, enabled: bool = True) -> None:
        """Start a swanlab run described by ``cfg`` unless ``enabled`` is false.

        Args:
            cfg: supplies ``project``, ``run_name`` and the logged config dict.
            enabled: pass ``False`` (e.g. on non-zero ranks) to skip swanlab
                entirely; ``cfg`` is not touched in that case.
        """
        self.enabled = enabled
        self._swan = None
        self._run = None
        if not self.enabled:
            return
        try:
            import swanlab  # type: ignore

            self._swan = swanlab
            self._run = swanlab.init(
                project=cfg.project,
                experiment_name=cfg.run_name,
                config=asdict(cfg),
            )
        except Exception as exc:
            # Deliberately broad: tracking is optional and must not abort
            # training. Surface the cause so a missing install or a failed
            # login does not go unnoticed.
            import warnings

            self._swan = None
            self._run = None
            self.enabled = False
            warnings.warn(
                f"swanlab logging disabled: {exc!r}",
                RuntimeWarning,
                stacklevel=2,
            )

    def log(self, metrics: dict, step: int | None = None) -> None:
        """Forward ``metrics`` to the run (or the module if no run object).

        If the target's ``log`` rejects the ``step`` keyword with a
        ``TypeError``, the step is folded into the payload under ``"step"``.
        """
        if not self.enabled:
            return
        target = self._run if self._run is not None else self._swan
        if step is None:
            target.log(metrics)
            return
        try:
            target.log(metrics, step=step)
        except TypeError:
            payload = dict(metrics)
            payload["step"] = step
            target.log(payload)

    def log_image(
        self,
        key: str,
        image_path: Path,
        caption: str | None = None,
        step: int | None = None,
    ) -> None:
        """Log the image file at ``image_path`` under ``key``."""
        if not self.enabled:
            return
        image = self._swan.Image(str(image_path), caption=caption)
        self.log({key: image}, step=step)

    def finish(self) -> None:
        """Call ``finish`` on the run if it has one, else on the module."""
        if not self.enabled:
            return
        finish = None
        if self._run is not None:
            finish = getattr(self._run, "finish", None)
        if finish is None:
            finish = getattr(self._swan, "finish", None)
        if callable(finish):
            finish()
|
|
|
|
|
|
def set_seed(seed: int) -> None:
    """Seed every RNG this process may draw from.

    Seeds Python's ``random``, NumPy's global generator and torch (CPU, plus
    all CUDA devices when CUDA is available), so results do not depend on
    which library a random draw happens to come from.
    """
    import random

    random.seed(seed)
    # NumPy only accepts seeds in [0, 2**32); wrap instead of raising so any
    # seed that torch accepts keeps working.
    np.random.seed(seed % 2**32)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
|
|
|
|
|
|
def setup_distributed(cfg: TrainConfig) -> tuple[bool, int, int, torch.device]:
    """Resolve the process topology and the device to train on.

    ``WORLD_SIZE``, ``RANK`` and ``LOCAL_RANK`` are read from the environment
    (defaults 1, 0, 0). DDP is used only when ``cfg.use_ddp`` is set and the
    world size exceeds one; in that case an NCCL process group is
    initialised and the local CUDA device is selected. Otherwise
    ``cfg.device`` is used when CUDA is available and ``"cpu"`` when not.

    Returns:
        ``(use_ddp, rank, world_size, device)``; rank and world size are the
        raw environment values even when DDP is not used.
    """
    env = os.environ
    world_size = int(env.get("WORLD_SIZE", "1"))
    rank = int(env.get("RANK", "0"))
    local_rank = int(env.get("LOCAL_RANK", "0"))
    use_ddp = cfg.use_ddp and world_size > 1

    if not use_ddp:
        chosen = cfg.device if torch.cuda.is_available() else "cpu"
        return use_ddp, rank, world_size, torch.device(chosen)

    torch.distributed.init_process_group(backend="nccl", init_method="env://")
    torch.cuda.set_device(local_rank)
    return use_ddp, rank, world_size, torch.device("cuda", local_rank)
|
|
|
|
|
|
def unwrap_model(model: nn.Module) -> nn.Module:
    """Return ``model.module`` when that attribute exists, else ``model``.

    Used to reach the underlying network behind a wrapper such as
    ``DistributedDataParallel``.
    """
    return getattr(model, "module", model)
|
|
|
|
|
|
def validate_time_config(cfg: TrainConfig) -> None:
    """Check that the time-step settings of ``cfg`` are mutually consistent.

    Raises:
        ValueError: if ``seq_len`` is not positive, if ``dt_max`` does not
            exceed the uniform step ``1 / seq_len``, or if ``dt_min`` is not
            strictly below ``dt_max``.
    """
    if cfg.seq_len <= 0:
        raise ValueError("seq_len must be > 0")

    # A uniform schedule uses exactly 1/seq_len per step; dt_max has to leave
    # room above that or every sampled sequence collapses to uniform.
    base = 1.0 / cfg.seq_len
    if cfg.dt_max <= base:
        message = (
            "dt_max must be > 1/seq_len to allow non-uniform dt_seq. "
            f"Got dt_max={cfg.dt_max}, seq_len={cfg.seq_len}, 1/seq_len={base}."
        )
        raise ValueError(message)

    if cfg.dt_min >= cfg.dt_max:
        message = f"dt_min must be < dt_max (got dt_min={cfg.dt_min}, dt_max={cfg.dt_max})."
        raise ValueError(message)
|
|
|
|
|
|
def sample_time_sequence(
    cfg: TrainConfig, batch_size: int, device: torch.device
) -> Tensor:
    """Sample per-step durations of shape ``(batch_size, cfg.seq_len)``.

    Each row is a set of Gamma(``dt_alpha``, 1) draws normalised to sum to 1.
    If any row's largest step exceeds ``dt_max``, rows are blended towards
    the uniform schedule ``1 / seq_len`` just enough that their maximum
    equals ``dt_max``; rows already within the cap get a blend factor of 1
    and are unchanged. When ``dt_max`` is not above the uniform step, the
    uniform schedule is returned.

    Raises:
        ValueError: if ``cfg.dt_alpha`` is not positive.
    """
    concentration = float(cfg.dt_alpha)
    if concentration <= 0:
        raise ValueError("dt_alpha must be > 0")

    sampler = torch.distributions.Gamma(concentration, 1.0)
    draws = sampler.sample((batch_size, cfg.seq_len)).to(device)
    dt_seq = draws / draws.sum(dim=-1, keepdim=True)

    uniform_step = 1.0 / cfg.seq_len
    cap = float(cfg.dt_max)
    if cap <= uniform_step:
        return torch.full_like(dt_seq, uniform_step)

    row_max = dt_seq.max(dim=-1, keepdim=True).values
    if not (row_max > cap).any():
        return dt_seq

    # Convex combination with the uniform schedule keeps every row summing
    # to 1 while pulling its maximum down to the cap.
    blend = ((cap - uniform_step) / (row_max - uniform_step)).clamp(0.0, 1.0)
    return blend * dt_seq + (1.0 - blend) * uniform_step
|
|
|
|
|
|
def build_dataloader(
    cfg: TrainConfig, distributed: bool = False
) -> tuple[DataLoader, Optional[DistributedSampler]]:
    """Create the training ``DataLoader`` for ``cfg.dataset_name``.

    Images are read from the ``"img"`` (or ``"image"``) column and labels
    from ``"label"`` (or ``"labels"``). Pixel values are mapped with
    ``x / 127.5 - 1`` and laid out channels-first.

    Returns:
        ``(loader, sampler)``; ``sampler`` is a shuffling
        ``DistributedSampler`` when ``distributed`` is true, else ``None``
        (the loader then shuffles itself). Incomplete batches are dropped.
    """

    def transform(example):
        image = example.get("img", example.get("image"))
        label = example.get("label", example.get("labels"))

        if not isinstance(image, list):
            # Single example: (H, W) gains a channel axis, (H, W, C) -> (C, H, W).
            single = np.array(image, dtype=np.float32) / 127.5 - 1.0
            pixels = torch.from_numpy(single)
            pixels = pixels.unsqueeze(0) if single.ndim == 2 else pixels.permute(2, 0, 1)
            return {
                "pixel_values": pixels,
                "labels": torch.tensor(label, dtype=torch.long),
            }

        # Batched access: stack to (N, H, W) or (N, H, W, C) first.
        stacked = np.stack([np.array(im, dtype=np.float32) for im in image], axis=0)
        stacked = stacked / 127.5 - 1.0
        pixels = torch.from_numpy(stacked)
        if stacked.ndim == 3:
            pixels = pixels.unsqueeze(1)
        else:
            pixels = pixels.permute(0, 3, 1, 2)
        return {
            "pixel_values": pixels,
            "labels": torch.tensor(label, dtype=torch.long),
        }

    dataset = load_dataset(cfg.dataset_name, split=cfg.dataset_split)
    dataset = dataset.with_transform(transform)

    sampler = None
    if distributed:
        sampler = DistributedSampler(dataset, shuffle=True)

    loader = DataLoader(
        dataset,
        batch_size=cfg.batch_size,
        # DataLoader rejects shuffle=True together with a sampler.
        shuffle=sampler is None,
        sampler=sampler,
        num_workers=cfg.num_workers,
        drop_last=True,
        pin_memory=torch.cuda.is_available(),
    )
    return loader, sampler
|
|
|
|
|
|
def infinite_loader(loader: DataLoader) -> Iterator[dict]:
    """Yield batches from ``loader`` forever, restarting it when exhausted.

    ``loader`` is re-iterated from scratch on every pass, so a shuffling
    loader reshuffles each time.

    Raises:
        ValueError: if a full pass over ``loader`` produces no batch. Without
            this check an empty loader (for example ``drop_last=True`` with
            fewer samples than one batch) would spin forever without ever
            yielding.
    """
    while True:
        produced = False
        for batch in loader:
            produced = True
            yield batch
        if not produced:
            raise ValueError(
                "loader produced no batches; check dataset size versus batch_size/drop_last."
            )
|
|
|
|
|
|
def compute_losses(
    delta: Tensor,
    dt: Tensor,
    x_seq: Tensor,
    x0: Tensor,
    v_gt: Tensor,
    t_seq: Tensor,
    dt_seq: Tensor,
    cfg: TrainConfig,
) -> dict[str, Tensor]:
    """Compute the enabled training losses, keyed ``"flow"``/``"pos"``/``"dt"``.

    * ``flow``: MSE between ``delta`` and ``v_gt * dt`` (the displacement the
      ground-truth velocity would produce over the predicted step).
    * ``pos``: MSE between ``x_seq + delta`` and the point on the straight
      path ``x0 + t * v_gt`` at ``t = clamp(t_seq + dt, 0, 1)``.
    * ``dt``: MSE between the predicted ``dt`` and the sampled ``dt_seq``.

    ``x0`` and ``v_gt`` get an axis inserted at position 1 so they broadcast
    against the per-step tensors. Only losses whose ``cfg.use_*_loss`` flag
    is set appear in the result, inserted in the order flow, pos, dt.
    """
    result: dict[str, Tensor] = {}

    if cfg.use_flow_loss:
        expected_disp = v_gt[:, None, :] * dt.unsqueeze(-1)
        result["flow"] = F.mse_loss(delta, expected_disp)

    if cfg.use_pos_loss:
        t_after = (t_seq + dt).clamp(0.0, 1.0)
        expected_pos = x0[:, None, :] + t_after.unsqueeze(-1) * v_gt[:, None, :]
        result["pos"] = F.mse_loss(x_seq + delta, expected_pos)

    if cfg.use_dt_loss:
        result["dt"] = F.mse_loss(dt, dt_seq)

    return result
|
|
|
|
|
|
def plot_dt_hist(
    dt_pred: Tensor, dt_gt: Tensor, save_path: Path, title: str = "dt Distribution"
) -> None:
    """Save overlaid histograms of predicted and ground-truth step sizes.

    Both tensors are detached, moved to the CPU and flattened before
    plotting. The figure is written to ``save_path`` at 160 dpi.
    """
    dt_pred_np = dt_pred.detach().cpu().numpy().reshape(-1)
    dt_gt_np = dt_gt.detach().cpu().numpy().reshape(-1)

    fig, ax = plt.subplots(figsize=(6, 4))
    try:
        ax.hist(dt_gt_np, bins=30, alpha=0.6, label="dt_gt", color="steelblue")
        ax.hist(dt_pred_np, bins=30, alpha=0.6, label="dt_pred", color="orange")
        ax.set_title(title)
        ax.set_xlabel("dt")
        ax.set_ylabel("count")
        ax.legend(loc="best")
        fig.tight_layout()
        fig.savefig(save_path, dpi=160)
    finally:
        # pyplot keeps figures alive until they are closed explicitly; close
        # even when plotting or saving fails so repeated calls do not leak.
        plt.close(fig)
|
|
|
|
|
|
def make_grid(images: Tensor, nrow: int) -> np.ndarray:
    """Tile a ``(B, C, H, W)`` batch into one ``(rows*H, cols*W, C)`` image.

    ``nrow`` is the number of tiles per grid row (clamped to ``[1, B]``);
    tiles are placed left-to-right, top-to-bottom and unused cells stay
    zero. The result is ``float32`` with channels last.
    """
    batch = images.detach().cpu().numpy()
    count, channels, height, width = batch.shape
    per_row = max(1, min(nrow, count))
    n_rows = math.ceil(count / per_row)

    # Pad to a full rectangle of tiles, then rearrange axes in one shot:
    # (row, col, C, H, W) -> (row, H, col, W, C).
    cells = np.zeros((n_rows * per_row, channels, height, width), dtype=np.float32)
    cells[:count] = batch
    tiled = cells.reshape(n_rows, per_row, channels, height, width)
    tiled = tiled.transpose(0, 3, 1, 4, 2)
    return tiled.reshape(n_rows * height, per_row * width, channels)
|
|
|
|
|
|
def save_image_grid(
    images: Tensor, save_path: Path, nrow: int, title: str | None = None
) -> None:
    """Write a batch of images in ``[-1, 1]`` to ``save_path`` as one grid.

    Values are clamped to ``[-1, 1]`` and rescaled to ``[0, 1]``; a
    single-channel grid is repeated to three channels. Without ``title`` the
    raw grid is saved via ``plt.imsave``. With ``title`` a titled, axis-free
    figure is saved instead; the file is written once (the plain grid is no
    longer written first only to be overwritten).
    """
    images = (images.clamp(-1.0, 1.0) + 1.0) / 2.0
    grid = make_grid(images, nrow=nrow)
    if grid.ndim == 3 and grid.shape[2] == 1:
        grid = np.repeat(grid, 3, axis=2)

    if title is None:
        plt.imsave(save_path, grid)
        return

    fig, ax = plt.subplots(figsize=(4, 3))
    try:
        ax.imshow(grid)
        ax.set_title(title)
        ax.axis("off")
        fig.tight_layout()
        fig.savefig(save_path, dpi=160)
    finally:
        # Close even on failure so the figure is not kept alive by pyplot.
        plt.close(fig)
|
|
|
|
|
|
def rollout_trajectory(
    model: ASMamba,
    x0: Tensor,
    cond: Tensor,
    max_steps: int,
) -> Tensor:
    """Integrate the model from ``x0`` until every sample reaches time 1.

    The model is switched to eval mode (and left there). Each step adds the
    predicted ``delta`` and advances time by the predicted ``dt``; a step
    that would pass ``t = 1`` is shortened to land exactly on it, with
    ``delta`` scaled by the same factor. Iteration stops after
    ``max_steps`` or once all samples are within ``1e-6`` of ``t = 1``.

    Returns:
        CPU tensor of states stacked along axis 1, starting with ``x0``.
    """
    device = x0.device
    batch = x0.shape[0]

    model.eval()
    state = model.init_cache(batch_size=batch, device=device)
    current = x0
    elapsed = torch.zeros(batch, device=device)
    frames = [x0.detach().cpu()]

    with torch.no_grad():
        for _ in range(max_steps):
            delta, dt, state = model.step(current, cond, state)
            dt = dt.clamp(min=model.dt_min, max=model.dt_max)

            budget = 1.0 - elapsed
            too_far = dt > budget
            if too_far.any():
                # Shrink the overshooting steps so they end exactly at t = 1.
                shrink = (budget / dt).unsqueeze(-1)
                delta = torch.where(too_far.unsqueeze(-1), delta * shrink, delta)
                dt = torch.where(too_far, budget, dt)

            current = current + delta
            elapsed = elapsed + dt
            frames.append(current.detach().cpu())

            if torch.all(elapsed >= 1.0 - 1e-6):
                break

    return torch.stack(frames, dim=1)
|
|
|
|
|
|
def log_class_samples(
    model: ASMamba,
    cfg: TrainConfig,
    device: torch.device,
    logger: SwanLogger,
    step: int,
) -> None:
    """Generate, save and log a grid of samples for every class.

    For each class, ``cfg.val_samples_per_class`` Gaussian noise vectors are
    rolled out with ``rollout_trajectory`` (``cfg.val_max_steps`` steps, or
    ``cfg.seq_len`` when that is <= 0). The final states are reshaped to
    images, written to ``cfg.output_dir`` and sent to ``logger``. Does
    nothing when ``cfg.val_samples_per_class`` <= 0. The model is put back
    into train mode afterwards.
    """
    per_class = cfg.val_samples_per_class
    if per_class <= 0:
        return

    model.eval()
    max_steps = cfg.val_max_steps if cfg.val_max_steps > 0 else cfg.seq_len
    input_dim = cfg.channels * cfg.image_size * cfg.image_size
    image_shape = (per_class, cfg.channels, cfg.image_size, cfg.image_size)
    out_dir = Path(cfg.output_dir)

    for cls in range(cfg.num_classes):
        cond = torch.full((per_class,), cls, device=device, dtype=torch.long)
        x0 = torch.randn(per_class, input_dim, device=device)
        traj = rollout_trajectory(model, x0, cond, max_steps=max_steps)
        # Only the last state of each trajectory is rendered.
        x_final = traj[:, -1, :].view(*image_shape)

        save_path = out_dir / f"val_class_{cls}_step_{step:06d}.png"
        save_image_grid(x_final, save_path, nrow=cfg.val_grid_rows)
        logger.log_image(
            f"val/class_{cls}",
            save_path,
            caption=f"class {cls} step {step}",
            step=step,
        )

    model.train()
|
|
|
|
|
|
def _dt_statistics(dt: Tensor, dt_seq: Tensor, cfg: TrainConfig) -> dict[str, float]:
    """Summarise predicted (``dt``) and sampled (``dt_seq``) step sizes for logging."""
    eps = 1e-6
    at_floor = dt <= cfg.dt_min + eps
    at_ceiling = dt >= cfg.dt_max - eps
    return {
        "dt/pred_mean": float(dt.mean().item()),
        "dt/pred_min": float(dt.min().item()),
        "dt/pred_max": float(dt.max().item()),
        "dt/gt_mean": float(dt_seq.mean().item()),
        "dt/gt_min": float(dt_seq.min().item()),
        "dt/gt_max": float(dt_seq.max().item()),
        # Fraction of predictions sitting on a clamp boundary.
        "dt/clamp_min_ratio": float(at_floor.float().mean().item()),
        "dt/clamp_max_ratio": float(at_ceiling.float().mean().item()),
        "dt/clamp_any_ratio": float((at_floor | at_ceiling).float().mean().item()),
    }


def train(cfg: TrainConfig) -> ASMamba:
    """Train an ``ASMamba`` model according to ``cfg`` and return it unwrapped.

    Each step draws a batch, samples Gaussian noise ``x0`` and a random time
    partition, builds the points ``x0 + t * (x1 - x0)`` along the straight
    path to the data ``x1``, and minimises the weighted sum of the enabled
    losses with AdamW. Metrics are logged every 10 steps. Every
    ``cfg.val_every`` steps rank 0 saves per-class samples and a dt
    histogram into ``cfg.output_dir``.

    Raises:
        RuntimeError: if none of the flow/pos/dt losses is enabled. This is
            checked from the config flags before any setup, rather than by
            testing the loss value for 0.0 on every step.
        ValueError: propagated from ``validate_time_config``.
    """
    if not (cfg.use_flow_loss or cfg.use_pos_loss or cfg.use_dt_loss):
        raise RuntimeError("No loss enabled: enable at least one of flow/pos/dt.")

    validate_time_config(cfg)
    use_ddp, rank, _world_size, device = setup_distributed(cfg)

    logger = None
    try:
        # Different seed per rank so processes draw different noise.
        set_seed(cfg.seed + rank)
        output_dir = Path(cfg.output_dir)
        if rank == 0:
            output_dir.mkdir(parents=True, exist_ok=True)

        model = ASMamba(cfg).to(device)
        if use_ddp:
            model = nn.parallel.DistributedDataParallel(model, device_ids=[device.index])
        optimizer = torch.optim.AdamW(
            model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay
        )
        # Created after the model so the logged config holds the resolved d_model.
        logger = SwanLogger(cfg, enabled=(rank == 0))

        loader, sampler = build_dataloader(cfg, distributed=use_ddp)
        loader_iter = infinite_loader(loader)
        loss_weights = {
            "flow": cfg.lambda_flow,
            "pos": cfg.lambda_pos,
            "dt": cfg.lambda_dt,
        }

        global_step = 0
        for _ in range(cfg.epochs):
            if sampler is not None:
                sampler.set_epoch(global_step)
            model.train()
            for _ in range(cfg.steps_per_epoch):
                batch = next(loader_iter)
                x1 = batch["pixel_values"].to(device)
                cond = batch["labels"].to(device)
                b = x1.shape[0]
                x1 = x1.view(b, -1)
                x0 = torch.randn_like(x1)
                v_gt = x1 - x0

                dt_seq = sample_time_sequence(cfg, b, device)
                # Start time of each step: exclusive cumulative sum of dt_seq.
                t_seq = torch.cumsum(dt_seq, dim=-1)
                t_seq = torch.cat(
                    [torch.zeros(b, 1, device=device), t_seq[:, :-1]], dim=-1
                )
                x_seq = x0[:, None, :] + t_seq[:, :, None] * v_gt[:, None, :]

                delta, dt, _ = model(x_seq, cond)

                losses = compute_losses(
                    delta=delta,
                    dt=dt,
                    x_seq=x_seq,
                    x0=x0,
                    v_gt=v_gt,
                    t_seq=t_seq,
                    dt_seq=dt_seq,
                    cfg=cfg,
                )

                # compute_losses only returns enabled losses (flow, pos, dt order).
                loss = torch.tensor(0.0, device=device)
                for name, value in losses.items():
                    loss = loss + loss_weights[name] * value

                optimizer.zero_grad(set_to_none=True)
                loss.backward()
                optimizer.step()

                if global_step % 10 == 0:
                    metrics = {
                        "loss/total": float(loss.item()),
                        "loss/flow": float(losses.get("flow", torch.tensor(0.0)).item()),
                        "loss/pos": float(losses.get("pos", torch.tensor(0.0)).item()),
                        "loss/dt": float(losses.get("dt", torch.tensor(0.0)).item()),
                    }
                    metrics.update(_dt_statistics(dt, dt_seq, cfg))
                    logger.log(metrics, step=global_step)

                if (
                    cfg.val_every > 0
                    and global_step > 0
                    and global_step % cfg.val_every == 0
                    and rank == 0
                ):
                    log_class_samples(unwrap_model(model), cfg, device, logger, global_step)
                    dt_hist_path = output_dir / f"dt_hist_step_{global_step:06d}.png"
                    plot_dt_hist(
                        dt,
                        dt_seq,
                        dt_hist_path,
                        title=f"dt Distribution (step {global_step})",
                    )
                    logger.log_image(
                        "train/dt_hist",
                        dt_hist_path,
                        caption=f"step {global_step}",
                        step=global_step,
                    )

                global_step += 1
    finally:
        # Release the tracker run and the process group even if training fails.
        if logger is not None:
            logger.finish()
        if use_ddp:
            torch.distributed.destroy_process_group()

    return unwrap_model(model)
|
|
|
|
|
|
def run_training_and_plot(cfg: TrainConfig) -> Path:
    """Run ``train(cfg)`` and return ``cfg.output_dir`` as a ``Path``.

    The trained model returned by ``train`` is discarded.
    """
    train(cfg)
    artifacts_dir = Path(cfg.output_dir)
    return artifacts_dir
|