Compare commits
4 Commits
dynamic_dt
...
mamba_mean
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5897a0afd1 | ||
|
|
9b2968997c | ||
|
|
01fc1e4eab | ||
|
|
913740266b |
27
AGENTS.md
Normal file
27
AGENTS.md
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
# Repository Guidelines
|
||||||
|
|
||||||
|
## Project Structure & Module Organization
|
||||||
|
This repository is a flat Python training project, not a packaged library. `main.py` is the CLI entry point. `as_mamba.py` holds `TrainConfig`, dataset loading, the training loop, validation image generation, and SwanLab logging. `mamba2_minimal.py` contains the standalone Mamba-2 building blocks. `train_as_mamba.sh` is the multi-GPU launcher. Runtime artifacts land in `outputs/` and `swanlog/`; treat both as generated data, not source.
|
||||||
|
|
||||||
|
## Build, Test, and Development Commands
|
||||||
|
Use `uv` with the committed lockfile.
|
||||||
|
|
||||||
|
- `uv sync` installs the Python 3.12 environment from `pyproject.toml` and `uv.lock`.
|
||||||
|
- `uv run python main.py --help` lists all training flags and is the fastest CLI sanity check.
|
||||||
|
- `uv run python main.py --device cpu --epochs 1 --steps-per-epoch 1 --batch-size 8 --num-workers 0 --val-every 0` runs a minimal local smoke test.
|
||||||
|
- `bash train_as_mamba.sh` launches the default distributed training job with `torchrun`; adjust `--nproc_per_node` and shell variables before using it on a new machine.
|
||||||
|
- `uv run python -m compileall main.py as_mamba.py mamba2_minimal.py` performs a lightweight syntax check when no test suite is available.
|
||||||
|
|
||||||
|
## Coding Style & Naming Conventions
|
||||||
|
Follow the existing Python style: 4-space indentation, type hints on public functions, `dataclass`-based config, and small helper functions over deeply nested logic. Use `snake_case` for functions, variables, and CLI flags; use `PascalCase` for classes such as `ASMamba` and `TrainConfig`. Keep new modules top-level unless the project is restructured. There is no configured formatter or linter yet, so match surrounding code closely and keep imports grouped and readable.
|
||||||
|
|
||||||
|
## Testing Guidelines
|
||||||
|
There is no dedicated `tests/` directory or pytest setup yet. Before submitting any change, run at minimum:
|
||||||
|
|
||||||
|
- `uv run python -m compileall main.py as_mamba.py mamba2_minimal.py`
|
||||||
|
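- one short training smoke run with reduced epochs and workers, for example `uv run python main.py --device cpu --epochs 1 --steps-per-epoch 1 --batch-size 8 --num-workers 0 --val-every 0`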
|
||||||
|
|
||||||
|
If you add reusable logic, prefer extracting pure functions so a future pytest suite can cover them easily. Name any new test files `test_<module>.py`.
|
||||||
|
|
||||||
|
## Commit & Pull Request Guidelines
|
||||||
|
Recent history favors short imperative commit subjects, usually in Conventional Commits style: `feat: ...`, `fix: ...`, and scoped variants like `feat(mamba): ...`. Keep each commit focused on one training or infrastructure change. Pull requests should describe the config or behavior change, list the verification commands you ran, and include sample images or metric notes when outputs in `outputs/` materially change. Do not commit `.venv/`, `__pycache__/`, or large generated artifacts.
|
||||||
451
as_mamba.py
451
as_mamba.py
@@ -1,11 +1,14 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from dataclasses import asdict, dataclass
|
|
||||||
import os
|
import os
|
||||||
|
from dataclasses import asdict, dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Iterator, Optional
|
from typing import Iterator, Optional
|
||||||
|
|
||||||
|
import lpips
|
||||||
|
|
||||||
|
os.environ.setdefault("MPLCONFIGDIR", "/tmp/mamba_diffusion_mplconfig")
|
||||||
import matplotlib
|
import matplotlib
|
||||||
|
|
||||||
matplotlib.use("Agg")
|
matplotlib.use("Agg")
|
||||||
@@ -16,12 +19,17 @@ import torch.nn.functional as F
|
|||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
from matplotlib import pyplot as plt
|
from matplotlib import pyplot as plt
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
|
from torch.func import jvp
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from torch.utils.data.distributed import DistributedSampler
|
from torch.utils.data.distributed import DistributedSampler
|
||||||
|
|
||||||
from mamba2_minimal import InferenceCache, Mamba2, Mamba2Config, RMSNorm
|
from mamba2_minimal import InferenceCache, Mamba2, Mamba2Config, RMSNorm
|
||||||
|
|
||||||
|
|
||||||
|
FIXED_VAL_SAMPLING_STEPS = 5
|
||||||
|
FIXED_VAL_TIME_GRID = (1.0, 0.8, 0.6, 0.4, 0.2, 0.0)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TrainConfig:
|
class TrainConfig:
|
||||||
seed: int = 42
|
seed: int = 42
|
||||||
@@ -29,18 +37,12 @@ class TrainConfig:
|
|||||||
epochs: int = 50
|
epochs: int = 50
|
||||||
steps_per_epoch: int = 200
|
steps_per_epoch: int = 200
|
||||||
batch_size: int = 128
|
batch_size: int = 128
|
||||||
seq_len: int = 20
|
seq_len: int = 5
|
||||||
lr: float = 2e-4
|
lr: float = 2e-4
|
||||||
weight_decay: float = 1e-2
|
weight_decay: float = 1e-2
|
||||||
dt_min: float = 1e-3
|
max_grad_norm: float = 1.0
|
||||||
dt_max: float = 0.06
|
|
||||||
dt_alpha: float = 8.0
|
|
||||||
lambda_flow: float = 1.0
|
lambda_flow: float = 1.0
|
||||||
lambda_pos: float = 1.0
|
lambda_perceptual: float = 2.0
|
||||||
lambda_dt: float = 1.0
|
|
||||||
use_flow_loss: bool = True
|
|
||||||
use_pos_loss: bool = False
|
|
||||||
use_dt_loss: bool = True
|
|
||||||
num_classes: int = 10
|
num_classes: int = 10
|
||||||
image_size: int = 28
|
image_size: int = 28
|
||||||
channels: int = 1
|
channels: int = 1
|
||||||
@@ -57,11 +59,12 @@ class TrainConfig:
|
|||||||
use_residual: bool = False
|
use_residual: bool = False
|
||||||
output_dir: str = "outputs"
|
output_dir: str = "outputs"
|
||||||
project: str = "as-mamba-mnist"
|
project: str = "as-mamba-mnist"
|
||||||
run_name: str = "mnist-flow"
|
run_name: str = "mnist-meanflow"
|
||||||
val_every: int = 200
|
val_every: int = 200
|
||||||
val_samples_per_class: int = 8
|
val_samples_per_class: int = 8
|
||||||
val_grid_rows: int = 4
|
val_grid_rows: int = 4
|
||||||
val_max_steps: int = 0
|
val_sampling_steps: int = FIXED_VAL_SAMPLING_STEPS
|
||||||
|
time_grid_size: int = 256
|
||||||
use_ddp: bool = False
|
use_ddp: bool = False
|
||||||
|
|
||||||
|
|
||||||
@@ -113,12 +116,35 @@ class Mamba2Backbone(nn.Module):
|
|||||||
return x, h
|
return x, h
|
||||||
|
|
||||||
|
|
||||||
|
def sinusoidal_embedding(t: Tensor, dim: int) -> Tensor:
|
||||||
|
if dim <= 0:
|
||||||
|
raise ValueError("sinusoidal embedding dim must be > 0")
|
||||||
|
half = dim // 2
|
||||||
|
if half == 0:
|
||||||
|
return torch.zeros(*t.shape, dim, device=t.device, dtype=t.dtype)
|
||||||
|
denom = max(half - 1, 1)
|
||||||
|
freqs = torch.exp(
|
||||||
|
-math.log(10000.0)
|
||||||
|
* torch.arange(half, device=t.device, dtype=torch.float32)
|
||||||
|
/ denom
|
||||||
|
)
|
||||||
|
args = t.to(torch.float32).unsqueeze(-1) * freqs
|
||||||
|
emb = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
|
||||||
|
if dim % 2 == 1:
|
||||||
|
pad = torch.zeros(*emb.shape[:-1], 1, device=emb.device, dtype=emb.dtype)
|
||||||
|
emb = torch.cat([emb, pad], dim=-1)
|
||||||
|
return emb.to(t.dtype)
|
||||||
|
|
||||||
|
|
||||||
|
def safe_time_divisor(t: Tensor) -> Tensor:
|
||||||
|
eps = torch.finfo(t.dtype).eps
|
||||||
|
return torch.where(t > 0, t, torch.full_like(t, eps))
|
||||||
|
|
||||||
|
|
||||||
class ASMamba(nn.Module):
|
class ASMamba(nn.Module):
|
||||||
def __init__(self, cfg: TrainConfig) -> None:
|
def __init__(self, cfg: TrainConfig) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.cfg = cfg
|
self.cfg = cfg
|
||||||
self.dt_min = float(cfg.dt_min)
|
|
||||||
self.dt_max = float(cfg.dt_max)
|
|
||||||
input_dim = cfg.channels * cfg.image_size * cfg.image_size
|
input_dim = cfg.channels * cfg.image_size * cfg.image_size
|
||||||
if cfg.d_model == 0:
|
if cfg.d_model == 0:
|
||||||
cfg.d_model = input_dim
|
cfg.d_model = input_dim
|
||||||
@@ -138,28 +164,42 @@ class ASMamba(nn.Module):
|
|||||||
)
|
)
|
||||||
self.backbone = Mamba2Backbone(args, use_residual=cfg.use_residual)
|
self.backbone = Mamba2Backbone(args, use_residual=cfg.use_residual)
|
||||||
self.cond_emb = nn.Embedding(cfg.num_classes, cfg.d_model)
|
self.cond_emb = nn.Embedding(cfg.num_classes, cfg.d_model)
|
||||||
self.delta_head = nn.Linear(cfg.d_model, input_dim)
|
self.clean_head = nn.Linear(cfg.d_model, input_dim)
|
||||||
self.dt_head = nn.Sequential(
|
|
||||||
nn.Linear(cfg.d_model, cfg.d_model),
|
|
||||||
nn.SiLU(),
|
|
||||||
nn.Linear(cfg.d_model, 1),
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, x: Tensor, cond: Tensor, h: Optional[list[InferenceCache]] = None
|
self,
|
||||||
) -> tuple[Tensor, Tensor, list[InferenceCache]]:
|
z_t: Tensor,
|
||||||
|
r: Tensor,
|
||||||
|
t: Tensor,
|
||||||
|
cond: Tensor,
|
||||||
|
h: Optional[list[InferenceCache]] = None,
|
||||||
|
) -> tuple[Tensor, list[InferenceCache]]:
|
||||||
|
if r.dim() == 1:
|
||||||
|
r = r.unsqueeze(1)
|
||||||
|
elif r.dim() == 3 and r.shape[-1] == 1:
|
||||||
|
r = r.squeeze(-1)
|
||||||
|
if t.dim() == 1:
|
||||||
|
t = t.unsqueeze(1)
|
||||||
|
elif t.dim() == 3 and t.shape[-1] == 1:
|
||||||
|
t = t.squeeze(-1)
|
||||||
|
|
||||||
|
r = r.to(dtype=z_t.dtype)
|
||||||
|
t = t.to(dtype=z_t.dtype)
|
||||||
|
z_t = z_t + sinusoidal_embedding(r, z_t.shape[-1]) + sinusoidal_embedding(
|
||||||
|
t, z_t.shape[-1]
|
||||||
|
)
|
||||||
cond_vec = self.cond_emb(cond)
|
cond_vec = self.cond_emb(cond)
|
||||||
feats, h = self.backbone(x, cond_vec, h)
|
feats, h = self.backbone(z_t, cond_vec, h)
|
||||||
delta = self.delta_head(feats)
|
x_pred = torch.tanh(self.clean_head(feats))
|
||||||
dt_raw = self.dt_head(feats).squeeze(-1)
|
return x_pred, h
|
||||||
dt = torch.clamp(F.softplus(dt_raw), min=self.dt_min, max=self.dt_max)
|
|
||||||
return delta, dt, h
|
|
||||||
|
|
||||||
def step(
|
def step(
|
||||||
self, x: Tensor, cond: Tensor, h: list[InferenceCache]
|
self, z_t: Tensor, r: Tensor, t: Tensor, cond: Tensor, h: list[InferenceCache]
|
||||||
) -> tuple[Tensor, Tensor, list[InferenceCache]]:
|
) -> tuple[Tensor, list[InferenceCache]]:
|
||||||
delta, dt, h = self.forward(x.unsqueeze(1), cond, h)
|
x_pred, h = self.forward(
|
||||||
return delta[:, 0, :], dt[:, 0], h
|
z_t.unsqueeze(1), r.unsqueeze(1), t.unsqueeze(1), cond, h
|
||||||
|
)
|
||||||
|
return x_pred[:, 0, :], h
|
||||||
|
|
||||||
def init_cache(self, batch_size: int, device: torch.device) -> list[InferenceCache]:
|
def init_cache(self, batch_size: int, device: torch.device) -> list[InferenceCache]:
|
||||||
return [
|
return [
|
||||||
@@ -226,6 +266,34 @@ class SwanLogger:
|
|||||||
finish()
|
finish()
|
||||||
|
|
||||||
|
|
||||||
|
class LPIPSPerceptualLoss(nn.Module):
|
||||||
|
def __init__(self, cfg: TrainConfig) -> None:
|
||||||
|
super().__init__()
|
||||||
|
torch_home = Path(cfg.output_dir) / ".torch"
|
||||||
|
torch_home.mkdir(parents=True, exist_ok=True)
|
||||||
|
os.environ["TORCH_HOME"] = str(torch_home)
|
||||||
|
self.channels = cfg.channels
|
||||||
|
self.loss_fn = lpips.LPIPS(net="vgg", verbose=False)
|
||||||
|
self.loss_fn.eval()
|
||||||
|
for param in self.loss_fn.parameters():
|
||||||
|
param.requires_grad_(False)
|
||||||
|
|
||||||
|
def _prepare_images(self, images: Tensor) -> Tensor:
|
||||||
|
if images.shape[1] == 1:
|
||||||
|
return images.repeat(1, 3, 1, 1)
|
||||||
|
if images.shape[1] != 3:
|
||||||
|
raise ValueError(
|
||||||
|
"LPIPS perceptual loss expects 1-channel or 3-channel images. "
|
||||||
|
f"Got {images.shape[1]} channels."
|
||||||
|
)
|
||||||
|
return images
|
||||||
|
|
||||||
|
def forward(self, pred: Tensor, target: Tensor) -> Tensor:
|
||||||
|
pred_rgb = self._prepare_images(pred)
|
||||||
|
target_rgb = self._prepare_images(target)
|
||||||
|
return self.loss_fn(pred_rgb, target_rgb).mean()
|
||||||
|
|
||||||
|
|
||||||
def set_seed(seed: int) -> None:
|
def set_seed(seed: int) -> None:
|
||||||
torch.manual_seed(seed)
|
torch.manual_seed(seed)
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
@@ -250,40 +318,42 @@ def unwrap_model(model: nn.Module) -> nn.Module:
|
|||||||
return model.module if hasattr(model, "module") else model
|
return model.module if hasattr(model, "module") else model
|
||||||
|
|
||||||
|
|
||||||
def validate_time_config(cfg: TrainConfig) -> None:
|
def validate_config(cfg: TrainConfig) -> None:
|
||||||
if cfg.seq_len <= 0:
|
if cfg.seq_len != 5:
|
||||||
raise ValueError("seq_len must be > 0")
|
|
||||||
base = 1.0 / cfg.seq_len
|
|
||||||
if cfg.dt_max <= base:
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"dt_max must be > 1/seq_len to allow non-uniform dt_seq. "
|
f"seq_len must be 5 for the required 5-block training setup (got {cfg.seq_len})."
|
||||||
f"Got dt_max={cfg.dt_max}, seq_len={cfg.seq_len}, 1/seq_len={base}."
|
|
||||||
)
|
)
|
||||||
if cfg.dt_min >= cfg.dt_max:
|
if cfg.time_grid_size < 2:
|
||||||
|
raise ValueError("time_grid_size must be >= 2.")
|
||||||
|
if cfg.lambda_perceptual < 0:
|
||||||
|
raise ValueError("lambda_perceptual must be >= 0.")
|
||||||
|
if cfg.max_grad_norm <= 0:
|
||||||
|
raise ValueError("max_grad_norm must be > 0.")
|
||||||
|
if cfg.val_sampling_steps != FIXED_VAL_SAMPLING_STEPS:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"dt_min must be < dt_max (got dt_min={cfg.dt_min}, dt_max={cfg.dt_max})."
|
f"val_sampling_steps is fixed to {FIXED_VAL_SAMPLING_STEPS} for validation sampling."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def sample_time_sequence(
|
def sample_block_times(
|
||||||
cfg: TrainConfig, batch_size: int, device: torch.device
|
cfg: TrainConfig, batch_size: int, device: torch.device, dtype: torch.dtype
|
||||||
) -> Tensor:
|
) -> tuple[Tensor, Tensor]:
|
||||||
alpha = float(cfg.dt_alpha)
|
num_internal = cfg.seq_len - 1
|
||||||
if alpha <= 0:
|
normal = torch.randn(batch_size, num_internal, device=device, dtype=dtype)
|
||||||
raise ValueError("dt_alpha must be > 0")
|
logit_normal = torch.sigmoid(normal * math.sqrt(0.8))
|
||||||
dist = torch.distributions.Gamma(alpha, 1.0)
|
uniform = torch.rand(batch_size, num_internal, device=device, dtype=dtype)
|
||||||
raw = dist.sample((batch_size, cfg.seq_len)).to(device)
|
use_uniform = torch.rand(batch_size, num_internal, device=device) < 0.1
|
||||||
dt_seq = raw / raw.sum(dim=-1, keepdim=True)
|
cuts = torch.where(use_uniform, uniform, logit_normal)
|
||||||
base = 1.0 / cfg.seq_len
|
cuts, _ = torch.sort(cuts, dim=-1)
|
||||||
max_dt = float(cfg.dt_max)
|
boundaries = torch.cat(
|
||||||
if max_dt <= base:
|
[
|
||||||
return torch.full_like(dt_seq, base)
|
torch.zeros(batch_size, 1, device=device, dtype=dtype),
|
||||||
max_current = dt_seq.max(dim=-1, keepdim=True).values
|
cuts,
|
||||||
if (max_current > max_dt).any():
|
torch.ones(batch_size, 1, device=device, dtype=dtype),
|
||||||
gamma = (max_dt - base) / (max_current - base)
|
],
|
||||||
gamma = gamma.clamp(0.0, 1.0)
|
dim=-1,
|
||||||
dt_seq = gamma * dt_seq + (1.0 - gamma) * base
|
)
|
||||||
return dt_seq
|
return boundaries[:, :-1], boundaries[:, 1:]
|
||||||
|
|
||||||
|
|
||||||
def build_dataloader(
|
def build_dataloader(
|
||||||
@@ -330,51 +400,66 @@ def infinite_loader(loader: DataLoader) -> Iterator[dict]:
|
|||||||
yield batch
|
yield batch
|
||||||
|
|
||||||
|
|
||||||
def compute_losses(
|
def build_noisy_sequence(
|
||||||
delta: Tensor,
|
|
||||||
dt: Tensor,
|
|
||||||
x_seq: Tensor,
|
|
||||||
x0: Tensor,
|
x0: Tensor,
|
||||||
v_gt: Tensor,
|
eps: Tensor,
|
||||||
t_seq: Tensor,
|
t_seq: Tensor,
|
||||||
dt_seq: Tensor,
|
) -> tuple[Tensor, Tensor]:
|
||||||
|
z_t = (1.0 - t_seq.unsqueeze(-1)) * x0[:, None, :] + t_seq.unsqueeze(-1) * eps[:, None, :]
|
||||||
|
v_gt = eps - x0
|
||||||
|
return z_t, v_gt
|
||||||
|
|
||||||
|
|
||||||
|
def compute_losses(
|
||||||
|
model: nn.Module,
|
||||||
|
perceptual_loss_fn: LPIPSPerceptualLoss,
|
||||||
|
x0: Tensor,
|
||||||
|
z_t: Tensor,
|
||||||
|
v_gt: Tensor,
|
||||||
|
r_seq: Tensor,
|
||||||
|
t_seq: Tensor,
|
||||||
|
cond: Tensor,
|
||||||
cfg: TrainConfig,
|
cfg: TrainConfig,
|
||||||
) -> dict[str, Tensor]:
|
) -> tuple[dict[str, Tensor], Tensor]:
|
||||||
losses: dict[str, Tensor] = {}
|
seq_len = z_t.shape[1]
|
||||||
|
safe_t = safe_time_divisor(t_seq).unsqueeze(-1)
|
||||||
|
|
||||||
if cfg.use_flow_loss:
|
x_pred, _ = model(z_t, r_seq, t_seq, cond)
|
||||||
target_disp = v_gt[:, None, :] * dt.unsqueeze(-1)
|
u = (z_t - x_pred) / safe_t
|
||||||
losses["flow"] = F.mse_loss(delta, target_disp)
|
|
||||||
|
|
||||||
if cfg.use_pos_loss:
|
x_pred_inst, _ = model(z_t, t_seq, t_seq, cond)
|
||||||
t_next = t_seq + dt
|
v_inst = ((z_t - x_pred_inst) / safe_t).detach()
|
||||||
t_next = torch.clamp(t_next, 0.0, 1.0)
|
|
||||||
x_target = x0[:, None, :] + t_next.unsqueeze(-1) * v_gt[:, None, :]
|
|
||||||
x_next_pred = x_seq + delta
|
|
||||||
losses["pos"] = F.mse_loss(x_next_pred, x_target)
|
|
||||||
|
|
||||||
if cfg.use_dt_loss:
|
def u_fn(z_in: Tensor, r_in: Tensor, t_in: Tensor) -> Tensor:
|
||||||
losses["dt"] = F.mse_loss(dt, dt_seq)
|
x_pred_local, _ = model(z_in, r_in, t_in, cond)
|
||||||
|
return (z_in - x_pred_local) / safe_time_divisor(t_in).unsqueeze(-1)
|
||||||
|
|
||||||
return losses
|
_, dudt = jvp(
|
||||||
|
u_fn,
|
||||||
|
(z_t, r_seq, t_seq),
|
||||||
|
(v_inst, torch.zeros_like(r_seq), torch.ones_like(t_seq)),
|
||||||
|
)
|
||||||
|
corrected_velocity = u + (t_seq - r_seq).unsqueeze(-1) * dudt.detach()
|
||||||
|
target_velocity = v_gt[:, None, :].expand(-1, seq_len, -1)
|
||||||
|
|
||||||
|
pred_images = x_pred.reshape(
|
||||||
|
x0.shape[0] * seq_len, cfg.channels, cfg.image_size, cfg.image_size
|
||||||
|
)
|
||||||
|
target_images = (
|
||||||
|
x0.reshape(x0.shape[0], cfg.channels, cfg.image_size, cfg.image_size)
|
||||||
|
.unsqueeze(1)
|
||||||
|
.expand(-1, seq_len, -1, -1, -1)
|
||||||
|
.reshape(x0.shape[0] * seq_len, cfg.channels, cfg.image_size, cfg.image_size)
|
||||||
|
)
|
||||||
|
|
||||||
def plot_dt_hist(
|
losses = {
|
||||||
dt_pred: Tensor, dt_gt: Tensor, save_path: Path, title: str = "dt Distribution"
|
"flow": F.mse_loss(corrected_velocity, target_velocity),
|
||||||
) -> None:
|
"perceptual": perceptual_loss_fn(pred_images, target_images),
|
||||||
dt_pred_np = dt_pred.detach().cpu().numpy().reshape(-1)
|
}
|
||||||
dt_gt_np = dt_gt.detach().cpu().numpy().reshape(-1)
|
losses["total"] = cfg.lambda_flow * losses["flow"] + cfg.lambda_perceptual * losses[
|
||||||
|
"perceptual"
|
||||||
fig, ax = plt.subplots(figsize=(6, 4))
|
]
|
||||||
ax.hist(dt_gt_np, bins=30, alpha=0.6, label="dt_gt", color="steelblue")
|
return losses, x_pred
|
||||||
ax.hist(dt_pred_np, bins=30, alpha=0.6, label="dt_pred", color="orange")
|
|
||||||
ax.set_title(title)
|
|
||||||
ax.set_xlabel("dt")
|
|
||||||
ax.set_ylabel("count")
|
|
||||||
ax.legend(loc="best")
|
|
||||||
fig.tight_layout()
|
|
||||||
fig.savefig(save_path, dpi=160)
|
|
||||||
plt.close(fig)
|
|
||||||
|
|
||||||
|
|
||||||
def make_grid(images: Tensor, nrow: int) -> np.ndarray:
|
def make_grid(images: Tensor, nrow: int) -> np.ndarray:
|
||||||
@@ -409,36 +494,36 @@ def save_image_grid(
|
|||||||
plt.close(fig)
|
plt.close(fig)
|
||||||
|
|
||||||
|
|
||||||
def rollout_trajectory(
|
def sample_class_images(
|
||||||
model: ASMamba,
|
model: ASMamba,
|
||||||
x0: Tensor,
|
cfg: TrainConfig,
|
||||||
|
device: torch.device,
|
||||||
cond: Tensor,
|
cond: Tensor,
|
||||||
max_steps: int,
|
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
device = x0.device
|
|
||||||
model.eval()
|
model.eval()
|
||||||
h = model.init_cache(batch_size=x0.shape[0], device=device)
|
input_dim = cfg.channels * cfg.image_size * cfg.image_size
|
||||||
x = x0
|
z_t = torch.randn(cond.shape[0], input_dim, device=device)
|
||||||
total_time = torch.zeros(x0.shape[0], device=device)
|
time_grid = torch.tensor(FIXED_VAL_TIME_GRID, device=device)
|
||||||
traj = [x0.detach().cpu()]
|
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for _ in range(max_steps):
|
for step_idx in range(FIXED_VAL_SAMPLING_STEPS):
|
||||||
delta, dt, h = model.step(x, cond, h)
|
t_cur = torch.full(
|
||||||
dt = torch.clamp(dt, min=model.dt_min, max=model.dt_max)
|
(cond.shape[0],),
|
||||||
remaining = 1.0 - total_time
|
float(time_grid[step_idx].item()),
|
||||||
overshoot = dt > remaining
|
device=device,
|
||||||
if overshoot.any():
|
)
|
||||||
scale = (remaining / dt).unsqueeze(-1)
|
t_next = time_grid[step_idx + 1]
|
||||||
delta = torch.where(overshoot.unsqueeze(-1), delta * scale, delta)
|
x_pred, _ = model(
|
||||||
dt = torch.where(overshoot, remaining, dt)
|
z_t.unsqueeze(1),
|
||||||
x = x + delta
|
t_cur.unsqueeze(1),
|
||||||
total_time = total_time + dt
|
t_cur.unsqueeze(1),
|
||||||
traj.append(x.detach().cpu())
|
cond,
|
||||||
if torch.all(total_time >= 1.0 - 1e-6):
|
)
|
||||||
break
|
x_pred = x_pred[:, 0, :]
|
||||||
|
u_inst = (z_t - x_pred) / safe_time_divisor(t_cur).unsqueeze(-1)
|
||||||
|
z_t = z_t + (t_next - t_cur).unsqueeze(-1) * u_inst
|
||||||
|
|
||||||
return torch.stack(traj, dim=1)
|
return z_t.view(cond.shape[0], cfg.channels, cfg.image_size, cfg.image_size)
|
||||||
|
|
||||||
|
|
||||||
def log_class_samples(
|
def log_class_samples(
|
||||||
@@ -450,19 +535,13 @@ def log_class_samples(
|
|||||||
) -> None:
|
) -> None:
|
||||||
if cfg.val_samples_per_class <= 0:
|
if cfg.val_samples_per_class <= 0:
|
||||||
return
|
return
|
||||||
|
training_mode = model.training
|
||||||
model.eval()
|
model.eval()
|
||||||
max_steps = cfg.seq_len if cfg.val_max_steps <= 0 else cfg.val_max_steps
|
|
||||||
input_dim = cfg.channels * cfg.image_size * cfg.image_size
|
|
||||||
|
|
||||||
for cls in range(cfg.num_classes):
|
for cls in range(cfg.num_classes):
|
||||||
cond = torch.full(
|
cond = torch.full(
|
||||||
(cfg.val_samples_per_class,), cls, device=device, dtype=torch.long
|
(cfg.val_samples_per_class,), cls, device=device, dtype=torch.long
|
||||||
)
|
)
|
||||||
x0 = torch.randn(cfg.val_samples_per_class, input_dim, device=device)
|
x_final = sample_class_images(model, cfg, device, cond)
|
||||||
traj = rollout_trajectory(model, x0, cond, max_steps=max_steps)
|
|
||||||
x_final = traj[:, -1, :].view(
|
|
||||||
cfg.val_samples_per_class, cfg.channels, cfg.image_size, cfg.image_size
|
|
||||||
)
|
|
||||||
save_path = Path(cfg.output_dir) / f"val_class_{cls}_step_{step:06d}.png"
|
save_path = Path(cfg.output_dir) / f"val_class_{cls}_step_{step:06d}.png"
|
||||||
save_image_grid(x_final, save_path, nrow=cfg.val_grid_rows)
|
save_image_grid(x_final, save_path, nrow=cfg.val_grid_rows)
|
||||||
logger.log_image(
|
logger.log_image(
|
||||||
@@ -471,12 +550,25 @@ def log_class_samples(
|
|||||||
caption=f"class {cls} step {step}",
|
caption=f"class {cls} step {step}",
|
||||||
step=step,
|
step=step,
|
||||||
)
|
)
|
||||||
model.train()
|
if training_mode:
|
||||||
|
model.train()
|
||||||
|
|
||||||
|
|
||||||
|
def build_perceptual_loss(
|
||||||
|
cfg: TrainConfig, device: torch.device, rank: int, use_ddp: bool
|
||||||
|
) -> LPIPSPerceptualLoss:
|
||||||
|
if use_ddp and rank != 0:
|
||||||
|
torch.distributed.barrier()
|
||||||
|
perceptual_loss_fn = LPIPSPerceptualLoss(cfg).to(device)
|
||||||
|
if use_ddp and rank == 0:
|
||||||
|
torch.distributed.barrier()
|
||||||
|
return perceptual_loss_fn
|
||||||
|
|
||||||
|
|
||||||
def train(cfg: TrainConfig) -> ASMamba:
|
def train(cfg: TrainConfig) -> ASMamba:
|
||||||
validate_time_config(cfg)
|
validate_config(cfg)
|
||||||
use_ddp, rank, world_size, device = setup_distributed(cfg)
|
use_ddp, rank, world_size, device = setup_distributed(cfg)
|
||||||
|
del world_size
|
||||||
set_seed(cfg.seed + rank)
|
set_seed(cfg.seed + rank)
|
||||||
output_dir = Path(cfg.output_dir)
|
output_dir = Path(cfg.output_dir)
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
@@ -488,91 +580,57 @@ def train(cfg: TrainConfig) -> ASMamba:
|
|||||||
optimizer = torch.optim.AdamW(
|
optimizer = torch.optim.AdamW(
|
||||||
model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay
|
model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay
|
||||||
)
|
)
|
||||||
|
perceptual_loss_fn = build_perceptual_loss(cfg, device, rank, use_ddp)
|
||||||
logger = SwanLogger(cfg, enabled=(rank == 0))
|
logger = SwanLogger(cfg, enabled=(rank == 0))
|
||||||
|
|
||||||
loader, sampler = build_dataloader(cfg, distributed=use_ddp)
|
loader, sampler = build_dataloader(cfg, distributed=use_ddp)
|
||||||
loader_iter = infinite_loader(loader)
|
loader_iter = infinite_loader(loader)
|
||||||
|
|
||||||
global_step = 0
|
global_step = 0
|
||||||
for _ in range(cfg.epochs):
|
for epoch_idx in range(cfg.epochs):
|
||||||
if sampler is not None:
|
if sampler is not None:
|
||||||
sampler.set_epoch(global_step)
|
sampler.set_epoch(epoch_idx)
|
||||||
model.train()
|
model.train()
|
||||||
for _ in range(cfg.steps_per_epoch):
|
for _ in range(cfg.steps_per_epoch):
|
||||||
batch = next(loader_iter)
|
batch = next(loader_iter)
|
||||||
x1 = batch["pixel_values"].to(device)
|
x0 = batch["pixel_values"].to(device)
|
||||||
cond = batch["labels"].to(device)
|
cond = batch["labels"].to(device)
|
||||||
b = x1.shape[0]
|
b = x0.shape[0]
|
||||||
x1 = x1.view(b, -1)
|
x0 = x0.view(b, -1)
|
||||||
x0 = torch.randn_like(x1)
|
eps = torch.randn_like(x0)
|
||||||
v_gt = x1 - x0
|
|
||||||
dt_seq = sample_time_sequence(cfg, b, device)
|
|
||||||
t_seq = torch.cumsum(dt_seq, dim=-1)
|
|
||||||
t_seq = torch.cat([torch.zeros(b, 1, device=device), t_seq[:, :-1]], dim=-1)
|
|
||||||
x_seq = x0[:, None, :] + t_seq[:, :, None] * v_gt[:, None, :]
|
|
||||||
|
|
||||||
delta, dt, _ = model(x_seq, cond)
|
r_seq, t_seq = sample_block_times(cfg, b, device, x0.dtype)
|
||||||
|
z_t, v_gt = build_noisy_sequence(x0, eps, t_seq)
|
||||||
|
|
||||||
losses = compute_losses(
|
losses, _ = compute_losses(
|
||||||
delta=delta,
|
model=model,
|
||||||
dt=dt,
|
perceptual_loss_fn=perceptual_loss_fn,
|
||||||
x_seq=x_seq,
|
|
||||||
x0=x0,
|
x0=x0,
|
||||||
|
z_t=z_t,
|
||||||
v_gt=v_gt,
|
v_gt=v_gt,
|
||||||
|
r_seq=r_seq,
|
||||||
t_seq=t_seq,
|
t_seq=t_seq,
|
||||||
dt_seq=dt_seq,
|
cond=cond,
|
||||||
cfg=cfg,
|
cfg=cfg,
|
||||||
)
|
)
|
||||||
|
|
||||||
loss = torch.tensor(0.0, device=device)
|
|
||||||
if cfg.use_flow_loss and "flow" in losses:
|
|
||||||
loss = loss + cfg.lambda_flow * losses["flow"]
|
|
||||||
if cfg.use_pos_loss and "pos" in losses:
|
|
||||||
loss = loss + cfg.lambda_pos * losses["pos"]
|
|
||||||
if cfg.use_dt_loss and "dt" in losses:
|
|
||||||
loss = loss + cfg.lambda_dt * losses["dt"]
|
|
||||||
if loss.item() == 0.0:
|
|
||||||
raise RuntimeError(
|
|
||||||
"No loss enabled: enable at least one of flow/pos/dt."
|
|
||||||
)
|
|
||||||
|
|
||||||
optimizer.zero_grad(set_to_none=True)
|
optimizer.zero_grad(set_to_none=True)
|
||||||
loss.backward()
|
losses["total"].backward()
|
||||||
|
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||||
|
model.parameters(), max_norm=cfg.max_grad_norm
|
||||||
|
)
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|
||||||
if global_step % 10 == 0:
|
if global_step % 10 == 0 and rank == 0:
|
||||||
dt_min = float(dt.min().item())
|
|
||||||
dt_max = float(dt.max().item())
|
|
||||||
dt_mean = float(dt.mean().item())
|
|
||||||
dt_gt_min = float(dt_seq.min().item())
|
|
||||||
dt_gt_max = float(dt_seq.max().item())
|
|
||||||
dt_gt_mean = float(dt_seq.mean().item())
|
|
||||||
eps = 1e-6
|
|
||||||
clamp_min_ratio = float((dt <= cfg.dt_min + eps).float().mean().item())
|
|
||||||
clamp_max_ratio = float((dt >= cfg.dt_max - eps).float().mean().item())
|
|
||||||
clamp_any_ratio = float(
|
|
||||||
((dt <= cfg.dt_min + eps) | (dt >= cfg.dt_max - eps))
|
|
||||||
.float()
|
|
||||||
.mean()
|
|
||||||
.item()
|
|
||||||
)
|
|
||||||
logger.log(
|
logger.log(
|
||||||
{
|
{
|
||||||
"loss/total": float(loss.item()),
|
"loss/total": float(losses["total"].item()),
|
||||||
"loss/flow": float(
|
"loss/flow": float(losses["flow"].item()),
|
||||||
losses.get("flow", torch.tensor(0.0)).item()
|
"loss/perceptual": float(losses["perceptual"].item()),
|
||||||
),
|
"grad/total_norm": float(grad_norm.item()),
|
||||||
"loss/pos": float(losses.get("pos", torch.tensor(0.0)).item()),
|
"time/r_mean": float(r_seq.mean().item()),
|
||||||
"loss/dt": float(losses.get("dt", torch.tensor(0.0)).item()),
|
"time/t_mean": float(t_seq.mean().item()),
|
||||||
"dt/pred_mean": dt_mean,
|
"time/zero_block_frac": float((t_seq == r_seq).float().mean().item()),
|
||||||
"dt/pred_min": dt_min,
|
|
||||||
"dt/pred_max": dt_max,
|
|
||||||
"dt/gt_mean": dt_gt_mean,
|
|
||||||
"dt/gt_min": dt_gt_min,
|
|
||||||
"dt/gt_max": dt_gt_max,
|
|
||||||
"dt/clamp_min_ratio": clamp_min_ratio,
|
|
||||||
"dt/clamp_max_ratio": clamp_max_ratio,
|
|
||||||
"dt/clamp_any_ratio": clamp_any_ratio,
|
|
||||||
},
|
},
|
||||||
step=global_step,
|
step=global_step,
|
||||||
)
|
)
|
||||||
@@ -584,21 +642,6 @@ def train(cfg: TrainConfig) -> ASMamba:
|
|||||||
and rank == 0
|
and rank == 0
|
||||||
):
|
):
|
||||||
log_class_samples(unwrap_model(model), cfg, device, logger, global_step)
|
log_class_samples(unwrap_model(model), cfg, device, logger, global_step)
|
||||||
dt_hist_path = (
|
|
||||||
Path(cfg.output_dir) / f"dt_hist_step_{global_step:06d}.png"
|
|
||||||
)
|
|
||||||
plot_dt_hist(
|
|
||||||
dt,
|
|
||||||
dt_seq,
|
|
||||||
dt_hist_path,
|
|
||||||
title=f"dt Distribution (step {global_step})",
|
|
||||||
)
|
|
||||||
logger.log_image(
|
|
||||||
"train/dt_hist",
|
|
||||||
dt_hist_path,
|
|
||||||
caption=f"step {global_step}",
|
|
||||||
step=global_step,
|
|
||||||
)
|
|
||||||
|
|
||||||
global_step += 1
|
global_step += 1
|
||||||
|
|
||||||
|
|||||||
15
main.py
15
main.py
@@ -4,26 +4,20 @@ from as_mamba import TrainConfig, run_training_and_plot
|
|||||||
|
|
||||||
|
|
||||||
def build_parser() -> argparse.ArgumentParser:
|
def build_parser() -> argparse.ArgumentParser:
|
||||||
parser = argparse.ArgumentParser(description="Train AS-Mamba on MNIST flow matching.")
|
parser = argparse.ArgumentParser(description="Train AS-Mamba on MNIST MeanFlow x-prediction.")
|
||||||
parser.add_argument("--epochs", type=int, default=None)
|
parser.add_argument("--epochs", type=int, default=None)
|
||||||
parser.add_argument("--steps-per-epoch", type=int, default=None)
|
parser.add_argument("--steps-per-epoch", type=int, default=None)
|
||||||
parser.add_argument("--batch-size", type=int, default=None)
|
parser.add_argument("--batch-size", type=int, default=None)
|
||||||
parser.add_argument("--seq-len", type=int, default=None)
|
parser.add_argument("--seq-len", type=int, default=None)
|
||||||
parser.add_argument("--lr", type=float, default=None)
|
parser.add_argument("--lr", type=float, default=None)
|
||||||
parser.add_argument("--weight-decay", type=float, default=None)
|
parser.add_argument("--weight-decay", type=float, default=None)
|
||||||
|
parser.add_argument("--max-grad-norm", type=float, default=None)
|
||||||
parser.add_argument("--device", type=str, default=None)
|
parser.add_argument("--device", type=str, default=None)
|
||||||
parser.add_argument("--output-dir", type=str, default=None)
|
parser.add_argument("--output-dir", type=str, default=None)
|
||||||
parser.add_argument("--project", type=str, default=None)
|
parser.add_argument("--project", type=str, default=None)
|
||||||
parser.add_argument("--run-name", type=str, default=None)
|
parser.add_argument("--run-name", type=str, default=None)
|
||||||
parser.add_argument("--dt-alpha", type=float, default=None)
|
|
||||||
parser.add_argument("--dt-min", type=float, default=None)
|
|
||||||
parser.add_argument("--dt-max", type=float, default=None)
|
|
||||||
parser.add_argument("--lambda-flow", type=float, default=None)
|
parser.add_argument("--lambda-flow", type=float, default=None)
|
||||||
parser.add_argument("--lambda-pos", type=float, default=None)
|
parser.add_argument("--lambda-perceptual", type=float, default=None)
|
||||||
parser.add_argument("--lambda-dt", type=float, default=None)
|
|
||||||
parser.add_argument("--use-flow-loss", action=argparse.BooleanOptionalAction, default=None)
|
|
||||||
parser.add_argument("--use-pos-loss", action=argparse.BooleanOptionalAction, default=None)
|
|
||||||
parser.add_argument("--use-dt-loss", action=argparse.BooleanOptionalAction, default=None)
|
|
||||||
parser.add_argument("--num-classes", type=int, default=None)
|
parser.add_argument("--num-classes", type=int, default=None)
|
||||||
parser.add_argument("--image-size", type=int, default=None)
|
parser.add_argument("--image-size", type=int, default=None)
|
||||||
parser.add_argument("--channels", type=int, default=None)
|
parser.add_argument("--channels", type=int, default=None)
|
||||||
@@ -41,7 +35,8 @@ def build_parser() -> argparse.ArgumentParser:
|
|||||||
parser.add_argument("--val-every", type=int, default=None)
|
parser.add_argument("--val-every", type=int, default=None)
|
||||||
parser.add_argument("--val-samples-per-class", type=int, default=None)
|
parser.add_argument("--val-samples-per-class", type=int, default=None)
|
||||||
parser.add_argument("--val-grid-rows", type=int, default=None)
|
parser.add_argument("--val-grid-rows", type=int, default=None)
|
||||||
parser.add_argument("--val-max-steps", type=int, default=None)
|
parser.add_argument("--val-sampling-steps", type=int, default=None)
|
||||||
|
parser.add_argument("--time-grid-size", type=int, default=None)
|
||||||
parser.add_argument("--use-ddp", action=argparse.BooleanOptionalAction, default=None)
|
parser.add_argument("--use-ddp", action=argparse.BooleanOptionalAction, default=None)
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|||||||
@@ -7,8 +7,10 @@ requires-python = ">=3.12"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"datasets>=2.19.0",
|
"datasets>=2.19.0",
|
||||||
"einops>=0.7.0",
|
"einops>=0.7.0",
|
||||||
|
"lpips>=0.1.4",
|
||||||
"matplotlib>=3.8.0",
|
"matplotlib>=3.8.0",
|
||||||
"numpy>=1.26.0",
|
"numpy>=1.26.0",
|
||||||
"swanlab>=0.5.0",
|
"swanlab>=0.5.0",
|
||||||
"torch>=2.2.0",
|
"torch>=2.2.0",
|
||||||
|
"torchvision>=0.24.1",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -4,48 +4,36 @@ set -euo pipefail
|
|||||||
DEVICE="cuda"
|
DEVICE="cuda"
|
||||||
EPOCHS=2000
|
EPOCHS=2000
|
||||||
STEPS_PER_EPOCH=200
|
STEPS_PER_EPOCH=200
|
||||||
BATCH_SIZE=256
|
BATCH_SIZE=512
|
||||||
SEQ_LEN=100
|
SEQ_LEN=5
|
||||||
LR=2e-3
|
LR=1e-3
|
||||||
WEIGHT_DECAY=1e-2
|
WEIGHT_DECAY=1e-2
|
||||||
DT_MIN=5e-4
|
|
||||||
DT_MAX=0.06
|
|
||||||
DT_ALPHA=9.0
|
|
||||||
LAMBDA_FLOW=1.0
|
LAMBDA_FLOW=1.0
|
||||||
LAMBDA_POS=1.0
|
LAMBDA_PERCEPTUAL=0.4
|
||||||
LAMBDA_DT=1.0
|
|
||||||
USE_FLOW_LOSS=true
|
|
||||||
USE_POS_LOSS=false
|
|
||||||
USE_DT_LOSS=true
|
|
||||||
NUM_CLASSES=10
|
NUM_CLASSES=10
|
||||||
IMAGE_SIZE=28
|
IMAGE_SIZE=28
|
||||||
CHANNELS=1
|
CHANNELS=1
|
||||||
NUM_WORKERS=16
|
NUM_WORKERS=32
|
||||||
DATASET_NAME="ylecun/mnist"
|
DATASET_NAME="ylecun/mnist"
|
||||||
DATASET_SPLIT="train"
|
DATASET_SPLIT="train"
|
||||||
D_MODEL=784
|
D_MODEL=784
|
||||||
N_LAYER=6
|
N_LAYER=8
|
||||||
D_STATE=32
|
D_STATE=32
|
||||||
D_CONV=4
|
D_CONV=4
|
||||||
EXPAND=2
|
EXPAND=2
|
||||||
HEADDIM=32
|
HEADDIM=32
|
||||||
CHUNK_SIZE=20
|
CHUNK_SIZE=1
|
||||||
USE_RESIDUAL=false
|
USE_RESIDUAL=true
|
||||||
USE_DDP=true
|
USE_DDP=true
|
||||||
VAL_EVERY=1000
|
VAL_EVERY=1000
|
||||||
VAL_SAMPLES_PER_CLASS=8
|
VAL_SAMPLES_PER_CLASS=8
|
||||||
VAL_GRID_ROWS=4
|
VAL_GRID_ROWS=4
|
||||||
VAL_MAX_STEPS=0
|
VAL_SAMPLING_STEPS=5
|
||||||
|
TIME_GRID_SIZE=256
|
||||||
PROJECT="as-mamba-mnist"
|
PROJECT="as-mamba-mnist"
|
||||||
RUN_NAME="mnist-flow"
|
RUN_NAME="mnist-meanflow-xpred"
|
||||||
OUTPUT_DIR="outputs"
|
OUTPUT_DIR="outputs"
|
||||||
|
|
||||||
USE_FLOW_FLAG="--use-flow-loss"
|
|
||||||
if [ "${USE_FLOW_LOSS}" = "false" ]; then USE_FLOW_FLAG="--no-use-flow-loss"; fi
|
|
||||||
USE_POS_FLAG="--use-pos-loss"
|
|
||||||
if [ "${USE_POS_LOSS}" = "false" ]; then USE_POS_FLAG="--no-use-pos-loss"; fi
|
|
||||||
USE_DT_FLAG="--use-dt-loss"
|
|
||||||
if [ "${USE_DT_LOSS}" = "false" ]; then USE_DT_FLAG="--no-use-dt-loss"; fi
|
|
||||||
USE_RESIDUAL_FLAG="--use-residual"
|
USE_RESIDUAL_FLAG="--use-residual"
|
||||||
if [ "${USE_RESIDUAL}" = "false" ]; then USE_RESIDUAL_FLAG="--no-use-residual"; fi
|
if [ "${USE_RESIDUAL}" = "false" ]; then USE_RESIDUAL_FLAG="--no-use-residual"; fi
|
||||||
USE_DDP_FLAG="--use-ddp"
|
USE_DDP_FLAG="--use-ddp"
|
||||||
@@ -59,15 +47,8 @@ uv run torchrun --nproc_per_node=2 main.py \
|
|||||||
--seq-len "${SEQ_LEN}" \
|
--seq-len "${SEQ_LEN}" \
|
||||||
--lr "${LR}" \
|
--lr "${LR}" \
|
||||||
--weight-decay "${WEIGHT_DECAY}" \
|
--weight-decay "${WEIGHT_DECAY}" \
|
||||||
--dt-min "${DT_MIN}" \
|
|
||||||
--dt-max "${DT_MAX}" \
|
|
||||||
--dt-alpha "${DT_ALPHA}" \
|
|
||||||
--lambda-flow "${LAMBDA_FLOW}" \
|
--lambda-flow "${LAMBDA_FLOW}" \
|
||||||
--lambda-pos "${LAMBDA_POS}" \
|
--lambda-perceptual "${LAMBDA_PERCEPTUAL}" \
|
||||||
--lambda-dt "${LAMBDA_DT}" \
|
|
||||||
${USE_FLOW_FLAG} \
|
|
||||||
${USE_POS_FLAG} \
|
|
||||||
${USE_DT_FLAG} \
|
|
||||||
--num-classes "${NUM_CLASSES}" \
|
--num-classes "${NUM_CLASSES}" \
|
||||||
--image-size "${IMAGE_SIZE}" \
|
--image-size "${IMAGE_SIZE}" \
|
||||||
--channels "${CHANNELS}" \
|
--channels "${CHANNELS}" \
|
||||||
@@ -86,7 +67,8 @@ uv run torchrun --nproc_per_node=2 main.py \
|
|||||||
--val-every "${VAL_EVERY}" \
|
--val-every "${VAL_EVERY}" \
|
||||||
--val-samples-per-class "${VAL_SAMPLES_PER_CLASS}" \
|
--val-samples-per-class "${VAL_SAMPLES_PER_CLASS}" \
|
||||||
--val-grid-rows "${VAL_GRID_ROWS}" \
|
--val-grid-rows "${VAL_GRID_ROWS}" \
|
||||||
--val-max-steps "${VAL_MAX_STEPS}" \
|
--val-sampling-steps "${VAL_SAMPLING_STEPS}" \
|
||||||
|
--time-grid-size "${TIME_GRID_SIZE}" \
|
||||||
--project "${PROJECT}" \
|
--project "${PROJECT}" \
|
||||||
--run-name "${RUN_NAME}" \
|
--run-name "${RUN_NAME}" \
|
||||||
--output-dir "${OUTPUT_DIR}"
|
--output-dir "${OUTPUT_DIR}"
|
||||||
|
|||||||
113
uv.lock
generated
113
uv.lock
generated
@@ -682,6 +682,22 @@ wheels = [
|
|||||||
{ url = "https://files.pythonhosted.org/packages/80/be/3578e8afd18c88cdf9cb4cffde75a96d2be38c5a903f1ed0ceec061bd09e/kiwisolver-1.4.9-cp314-cp314t-win_arm64.whl", hash = "sha256:4a48a2ce79d65d363597ef7b567ce3d14d68783d2b2263d98db3d9477805ba32", size = 70260, upload-time = "2025-08-10T21:27:36.606Z" },
|
{ url = "https://files.pythonhosted.org/packages/80/be/3578e8afd18c88cdf9cb4cffde75a96d2be38c5a903f1ed0ceec061bd09e/kiwisolver-1.4.9-cp314-cp314t-win_arm64.whl", hash = "sha256:4a48a2ce79d65d363597ef7b567ce3d14d68783d2b2263d98db3d9477805ba32", size = 70260, upload-time = "2025-08-10T21:27:36.606Z" },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "lpips"
|
||||||
|
version = "0.1.4"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
dependencies = [
|
||||||
|
{ name = "numpy" },
|
||||||
|
{ name = "scipy" },
|
||||||
|
{ name = "torch" },
|
||||||
|
{ name = "torchvision" },
|
||||||
|
{ name = "tqdm" },
|
||||||
|
]
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/e8/2d/4b8148d32f5bd461eb7d5daa54fcc998f86eaa709a57f4ef6aa4c62f024f/lpips-0.1.4.tar.gz", hash = "sha256:3846331df6c69688aec3d300a5eeef6c529435bc8460bd58201c3d62e56188fa", size = 18029, upload-time = "2021-08-25T22:10:32.803Z" }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/9b/13/1df50c7925d9d2746702719f40e864f51ed66f307b20ad32392f1ad2bb87/lpips-0.1.4-py3-none-any.whl", hash = "sha256:fd537af5828b69d2e6ffc0a397bd506dbc28ca183543617690844c08e102ec5e", size = 53763, upload-time = "2021-08-25T22:10:31.257Z" },
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "markdown-it-py"
|
name = "markdown-it-py"
|
||||||
version = "4.0.0"
|
version = "4.0.0"
|
||||||
@@ -1687,6 +1703,67 @@ wheels = [
|
|||||||
{ url = "https://files.pythonhosted.org/packages/19/71/39c7c0d87f8d4e6c020a393182060eaefeeae6c01dab6a84ec346f2567df/rich-13.9.4-py3-none-any.whl", hash = "sha256:6049d5e6ec054bf2779ab3358186963bac2ea89175919d699e378b99738c2a90", size = 242424, upload-time = "2024-11-01T16:43:55.817Z" },
|
{ url = "https://files.pythonhosted.org/packages/19/71/39c7c0d87f8d4e6c020a393182060eaefeeae6c01dab6a84ec346f2567df/rich-13.9.4-py3-none-any.whl", hash = "sha256:6049d5e6ec054bf2779ab3358186963bac2ea89175919d699e378b99738c2a90", size = 242424, upload-time = "2024-11-01T16:43:55.817Z" },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "scipy"
|
||||||
|
version = "1.17.1"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
dependencies = [
|
||||||
|
{ name = "numpy" },
|
||||||
|
]
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/7a/97/5a3609c4f8d58b039179648e62dd220f89864f56f7357f5d4f45c29eb2cc/scipy-1.17.1.tar.gz", hash = "sha256:95d8e012d8cb8816c226aef832200b1d45109ed4464303e997c5b13122b297c0", size = 30573822, upload-time = "2026-02-23T00:26:24.851Z" }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/35/48/b992b488d6f299dbe3f11a20b24d3dda3d46f1a635ede1c46b5b17a7b163/scipy-1.17.1-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:35c3a56d2ef83efc372eaec584314bd0ef2e2f0d2adb21c55e6ad5b344c0dcb8", size = 31610954, upload-time = "2026-02-23T00:17:49.855Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/b2/02/cf107b01494c19dc100f1d0b7ac3cc08666e96ba2d64db7626066cee895e/scipy-1.17.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:fcb310ddb270a06114bb64bbe53c94926b943f5b7f0842194d585c65eb4edd76", size = 28172662, upload-time = "2026-02-23T00:18:01.64Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/cf/a9/599c28631bad314d219cf9ffd40e985b24d603fc8a2f4ccc5ae8419a535b/scipy-1.17.1-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:cc90d2e9c7e5c7f1a482c9875007c095c3194b1cfedca3c2f3291cdc2bc7c086", size = 20344366, upload-time = "2026-02-23T00:18:12.015Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/35/f5/906eda513271c8deb5af284e5ef0206d17a96239af79f9fa0aebfe0e36b4/scipy-1.17.1-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:c80be5ede8f3f8eded4eff73cc99a25c388ce98e555b17d31da05287015ffa5b", size = 22704017, upload-time = "2026-02-23T00:18:21.502Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/da/34/16f10e3042d2f1d6b66e0428308ab52224b6a23049cb2f5c1756f713815f/scipy-1.17.1-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e19ebea31758fac5893a2ac360fedd00116cbb7628e650842a6691ba7ca28a21", size = 32927842, upload-time = "2026-02-23T00:18:35.367Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/01/8e/1e35281b8ab6d5d72ebe9911edcdffa3f36b04ed9d51dec6dd140396e220/scipy-1.17.1-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:02ae3b274fde71c5e92ac4d54bc06c42d80e399fec704383dcd99b301df37458", size = 35235890, upload-time = "2026-02-23T00:18:49.188Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/c5/5c/9d7f4c88bea6e0d5a4f1bc0506a53a00e9fcb198de372bfe4d3652cef482/scipy-1.17.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:8a604bae87c6195d8b1045eddece0514d041604b14f2727bbc2b3020172045eb", size = 35003557, upload-time = "2026-02-23T00:18:54.74Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/65/94/7698add8f276dbab7a9de9fb6b0e02fc13ee61d51c7c3f85ac28b65e1239/scipy-1.17.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:f590cd684941912d10becc07325a3eeb77886fe981415660d9265c4c418d0bea", size = 37625856, upload-time = "2026-02-23T00:19:00.307Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/a2/84/dc08d77fbf3d87d3ee27f6a0c6dcce1de5829a64f2eae85a0ecc1f0daa73/scipy-1.17.1-cp312-cp312-win_amd64.whl", hash = "sha256:41b71f4a3a4cab9d366cd9065b288efc4d4f3c0b37a91a8e0947fb5bd7f31d87", size = 36549682, upload-time = "2026-02-23T00:19:07.67Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/bc/98/fe9ae9ffb3b54b62559f52dedaebe204b408db8109a8c66fdd04869e6424/scipy-1.17.1-cp312-cp312-win_arm64.whl", hash = "sha256:f4115102802df98b2b0db3cce5cb9b92572633a1197c77b7553e5203f284a5b3", size = 24547340, upload-time = "2026-02-23T00:19:12.024Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/76/27/07ee1b57b65e92645f219b37148a7e7928b82e2b5dbeccecb4dff7c64f0b/scipy-1.17.1-cp313-cp313-macosx_10_14_x86_64.whl", hash = "sha256:5e3c5c011904115f88a39308379c17f91546f77c1667cea98739fe0fccea804c", size = 31590199, upload-time = "2026-02-23T00:19:17.192Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/ec/ae/db19f8ab842e9b724bf5dbb7db29302a91f1e55bc4d04b1025d6d605a2c5/scipy-1.17.1-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:6fac755ca3d2c3edcb22f479fceaa241704111414831ddd3bc6056e18516892f", size = 28154001, upload-time = "2026-02-23T00:19:22.241Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/5b/58/3ce96251560107b381cbd6e8413c483bbb1228a6b919fa8652b0d4090e7f/scipy-1.17.1-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:7ff200bf9d24f2e4d5dc6ee8c3ac64d739d3a89e2326ba68aaf6c4a2b838fd7d", size = 20325719, upload-time = "2026-02-23T00:19:26.329Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/b2/83/15087d945e0e4d48ce2377498abf5ad171ae013232ae31d06f336e64c999/scipy-1.17.1-cp313-cp313-macosx_14_0_x86_64.whl", hash = "sha256:4b400bdc6f79fa02a4d86640310dde87a21fba0c979efff5248908c6f15fad1b", size = 22683595, upload-time = "2026-02-23T00:19:30.304Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/b4/e0/e58fbde4a1a594c8be8114eb4aac1a55bcd6587047efc18a61eb1f5c0d30/scipy-1.17.1-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:2b64ca7d4aee0102a97f3ba22124052b4bd2152522355073580bf4845e2550b6", size = 32896429, upload-time = "2026-02-23T00:19:35.536Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/f5/5f/f17563f28ff03c7b6799c50d01d5d856a1d55f2676f537ca8d28c7f627cd/scipy-1.17.1-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:581b2264fc0aa555f3f435a5944da7504ea3a065d7029ad60e7c3d1ae09c5464", size = 35203952, upload-time = "2026-02-23T00:19:42.259Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/8d/a5/9afd17de24f657fdfe4df9a3f1ea049b39aef7c06000c13db1530d81ccca/scipy-1.17.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:beeda3d4ae615106d7094f7e7cef6218392e4465cc95d25f900bebabfded0950", size = 34979063, upload-time = "2026-02-23T00:19:47.547Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/8b/13/88b1d2384b424bf7c924f2038c1c409f8d88bb2a8d49d097861dd64a57b2/scipy-1.17.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:6609bc224e9568f65064cfa72edc0f24ee6655b47575954ec6339534b2798369", size = 37598449, upload-time = "2026-02-23T00:19:53.238Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/35/e5/d6d0e51fc888f692a35134336866341c08655d92614f492c6860dc45bb2c/scipy-1.17.1-cp313-cp313-win_amd64.whl", hash = "sha256:37425bc9175607b0268f493d79a292c39f9d001a357bebb6b88fdfaff13f6448", size = 36510943, upload-time = "2026-02-23T00:20:50.89Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/2a/fd/3be73c564e2a01e690e19cc618811540ba5354c67c8680dce3281123fb79/scipy-1.17.1-cp313-cp313-win_arm64.whl", hash = "sha256:5cf36e801231b6a2059bf354720274b7558746f3b1a4efb43fcf557ccd484a87", size = 24545621, upload-time = "2026-02-23T00:20:55.871Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/6f/6b/17787db8b8114933a66f9dcc479a8272e4b4da75fe03b0c282f7b0ade8cd/scipy-1.17.1-cp313-cp313t-macosx_10_14_x86_64.whl", hash = "sha256:d59c30000a16d8edc7e64152e30220bfbd724c9bbb08368c054e24c651314f0a", size = 31936708, upload-time = "2026-02-23T00:19:58.694Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/38/2e/524405c2b6392765ab1e2b722a41d5da33dc5c7b7278184a8ad29b6cb206/scipy-1.17.1-cp313-cp313t-macosx_12_0_arm64.whl", hash = "sha256:010f4333c96c9bb1a4516269e33cb5917b08ef2166d5556ca2fd9f082a9e6ea0", size = 28570135, upload-time = "2026-02-23T00:20:03.934Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/fd/c3/5bd7199f4ea8556c0c8e39f04ccb014ac37d1468e6cfa6a95c6b3562b76e/scipy-1.17.1-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:2ceb2d3e01c5f1d83c4189737a42d9cb2fc38a6eeed225e7515eef71ad301dce", size = 20741977, upload-time = "2026-02-23T00:20:07.935Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/d9/b8/8ccd9b766ad14c78386599708eb745f6b44f08400a5fd0ade7cf89b6fc93/scipy-1.17.1-cp313-cp313t-macosx_14_0_x86_64.whl", hash = "sha256:844e165636711ef41f80b4103ed234181646b98a53c8f05da12ca5ca289134f6", size = 23029601, upload-time = "2026-02-23T00:20:12.161Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/6d/a0/3cb6f4d2fb3e17428ad2880333cac878909ad1a89f678527b5328b93c1d4/scipy-1.17.1-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:158dd96d2207e21c966063e1635b1063cd7787b627b6f07305315dd73d9c679e", size = 33019667, upload-time = "2026-02-23T00:20:17.208Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/f3/c3/2d834a5ac7bf3a0c806ad1508efc02dda3c8c61472a56132d7894c312dea/scipy-1.17.1-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:74cbb80d93260fe2ffa334efa24cb8f2f0f622a9b9febf8b483c0b865bfb3475", size = 35264159, upload-time = "2026-02-23T00:20:23.087Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/4d/77/d3ed4becfdbd217c52062fafe35a72388d1bd82c2d0ba5ca19d6fcc93e11/scipy-1.17.1-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:dbc12c9f3d185f5c737d801da555fb74b3dcfa1a50b66a1a93e09190f41fab50", size = 35102771, upload-time = "2026-02-23T00:20:28.636Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/bd/12/d19da97efde68ca1ee5538bb261d5d2c062f0c055575128f11a2730e3ac1/scipy-1.17.1-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:94055a11dfebe37c656e70317e1996dc197e1a15bbcc351bcdd4610e128fe1ca", size = 37665910, upload-time = "2026-02-23T00:20:34.743Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/06/1c/1172a88d507a4baaf72c5a09bb6c018fe2ae0ab622e5830b703a46cc9e44/scipy-1.17.1-cp313-cp313t-win_amd64.whl", hash = "sha256:e30bdeaa5deed6bc27b4cc490823cd0347d7dae09119b8803ae576ea0ce52e4c", size = 36562980, upload-time = "2026-02-23T00:20:40.575Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/70/b0/eb757336e5a76dfa7911f63252e3b7d1de00935d7705cf772db5b45ec238/scipy-1.17.1-cp313-cp313t-win_arm64.whl", hash = "sha256:a720477885a9d2411f94a93d16f9d89bad0f28ca23c3f8daa521e2dcc3f44d49", size = 24856543, upload-time = "2026-02-23T00:20:45.313Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/cf/83/333afb452af6f0fd70414dc04f898647ee1423979ce02efa75c3b0f2c28e/scipy-1.17.1-cp314-cp314-macosx_10_14_x86_64.whl", hash = "sha256:a48a72c77a310327f6a3a920092fa2b8fd03d7deaa60f093038f22d98e096717", size = 31584510, upload-time = "2026-02-23T00:21:01.015Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/ed/a6/d05a85fd51daeb2e4ea71d102f15b34fedca8e931af02594193ae4fd25f7/scipy-1.17.1-cp314-cp314-macosx_12_0_arm64.whl", hash = "sha256:45abad819184f07240d8a696117a7aacd39787af9e0b719d00285549ed19a1e9", size = 28170131, upload-time = "2026-02-23T00:21:05.888Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/db/7b/8624a203326675d7746a254083a187398090a179335b2e4a20e2ddc46e83/scipy-1.17.1-cp314-cp314-macosx_14_0_arm64.whl", hash = "sha256:3fd1fcdab3ea951b610dc4cef356d416d5802991e7e32b5254828d342f7b7e0b", size = 20342032, upload-time = "2026-02-23T00:21:09.904Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/c9/35/2c342897c00775d688d8ff3987aced3426858fd89d5a0e26e020b660b301/scipy-1.17.1-cp314-cp314-macosx_14_0_x86_64.whl", hash = "sha256:7bdf2da170b67fdf10bca777614b1c7d96ae3ca5794fd9587dce41eb2966e866", size = 22678766, upload-time = "2026-02-23T00:21:14.313Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/ef/f2/7cdb8eb308a1a6ae1e19f945913c82c23c0c442a462a46480ce487fdc0ac/scipy-1.17.1-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:adb2642e060a6549c343603a3851ba76ef0b74cc8c079a9a58121c7ec9fe2350", size = 32957007, upload-time = "2026-02-23T00:21:19.663Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/0b/2e/7eea398450457ecb54e18e9d10110993fa65561c4f3add5e8eccd2b9cd41/scipy-1.17.1-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:eee2cfda04c00a857206a4330f0c5e3e56535494e30ca445eb19ec624ae75118", size = 35221333, upload-time = "2026-02-23T00:21:25.278Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/d9/77/5b8509d03b77f093a0d52e606d3c4f79e8b06d1d38c441dacb1e26cacf46/scipy-1.17.1-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:d2650c1fb97e184d12d8ba010493ee7b322864f7d3d00d3f9bb97d9c21de4068", size = 35042066, upload-time = "2026-02-23T00:21:31.358Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/f9/df/18f80fb99df40b4070328d5ae5c596f2f00fffb50167e31439e932f29e7d/scipy-1.17.1-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:08b900519463543aa604a06bec02461558a6e1cef8fdbb8098f77a48a83c8118", size = 37612763, upload-time = "2026-02-23T00:21:37.247Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/4b/39/f0e8ea762a764a9dc52aa7dabcfad51a354819de1f0d4652b6a1122424d6/scipy-1.17.1-cp314-cp314-win_amd64.whl", hash = "sha256:3877ac408e14da24a6196de0ddcace62092bfc12a83823e92e49e40747e52c19", size = 37290984, upload-time = "2026-02-23T00:22:35.023Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/7c/56/fe201e3b0f93d1a8bcf75d3379affd228a63d7e2d80ab45467a74b494947/scipy-1.17.1-cp314-cp314-win_arm64.whl", hash = "sha256:f8885db0bc2bffa59d5c1b72fad7a6a92d3e80e7257f967dd81abb553a90d293", size = 25192877, upload-time = "2026-02-23T00:22:39.798Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/96/ad/f8c414e121f82e02d76f310f16db9899c4fcde36710329502a6b2a3c0392/scipy-1.17.1-cp314-cp314t-macosx_10_14_x86_64.whl", hash = "sha256:1cc682cea2ae55524432f3cdff9e9a3be743d52a7443d0cba9017c23c87ae2f6", size = 31949750, upload-time = "2026-02-23T00:21:42.289Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/7c/b0/c741e8865d61b67c81e255f4f0a832846c064e426636cd7de84e74d209be/scipy-1.17.1-cp314-cp314t-macosx_12_0_arm64.whl", hash = "sha256:2040ad4d1795a0ae89bfc7e8429677f365d45aa9fd5e4587cf1ea737f927b4a1", size = 28585858, upload-time = "2026-02-23T00:21:47.706Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/ed/1b/3985219c6177866628fa7c2595bfd23f193ceebbe472c98a08824b9466ff/scipy-1.17.1-cp314-cp314t-macosx_14_0_arm64.whl", hash = "sha256:131f5aaea57602008f9822e2115029b55d4b5f7c070287699fe45c661d051e39", size = 20757723, upload-time = "2026-02-23T00:21:52.039Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/c0/19/2a04aa25050d656d6f7b9e7b685cc83d6957fb101665bfd9369ca6534563/scipy-1.17.1-cp314-cp314t-macosx_14_0_x86_64.whl", hash = "sha256:9cdc1a2fcfd5c52cfb3045feb399f7b3ce822abdde3a193a6b9a60b3cb5854ca", size = 23043098, upload-time = "2026-02-23T00:21:56.185Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/86/f1/3383beb9b5d0dbddd030335bf8a8b32d4317185efe495374f134d8be6cce/scipy-1.17.1-cp314-cp314t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6e3dcd57ab780c741fde8dc68619de988b966db759a3c3152e8e9142c26295ad", size = 33030397, upload-time = "2026-02-23T00:22:01.404Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/41/68/8f21e8a65a5a03f25a79165ec9d2b28c00e66dc80546cf5eb803aeeff35b/scipy-1.17.1-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a9956e4d4f4a301ebf6cde39850333a6b6110799d470dbbb1e25326ac447f52a", size = 35281163, upload-time = "2026-02-23T00:22:07.024Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/84/8d/c8a5e19479554007a5632ed7529e665c315ae7492b4f946b0deb39870e39/scipy-1.17.1-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:a4328d245944d09fd639771de275701ccadf5f781ba0ff092ad141e017eccda4", size = 35116291, upload-time = "2026-02-23T00:22:12.585Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/52/52/e57eceff0e342a1f50e274264ed47497b59e6a4e3118808ee58ddda7b74a/scipy-1.17.1-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:a77cbd07b940d326d39a1d1b37817e2ee4d79cb30e7338f3d0cddffae70fcaa2", size = 37682317, upload-time = "2026-02-23T00:22:18.513Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/11/2f/b29eafe4a3fbc3d6de9662b36e028d5f039e72d345e05c250e121a230dd4/scipy-1.17.1-cp314-cp314t-win_amd64.whl", hash = "sha256:eb092099205ef62cd1782b006658db09e2fed75bffcae7cc0d44052d8aa0f484", size = 37345327, upload-time = "2026-02-23T00:22:24.442Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/07/39/338d9219c4e87f3e708f18857ecd24d22a0c3094752393319553096b98af/scipy-1.17.1-cp314-cp314t-win_arm64.whl", hash = "sha256:200e1050faffacc162be6a486a984a0497866ec54149a01270adc8a59b7c7d21", size = 25489165, upload-time = "2026-02-23T00:22:29.563Z" },
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "setuptools"
|
name = "setuptools"
|
||||||
version = "80.9.0"
|
version = "80.9.0"
|
||||||
@@ -1792,20 +1869,24 @@ source = { virtual = "." }
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
{ name = "datasets" },
|
{ name = "datasets" },
|
||||||
{ name = "einops" },
|
{ name = "einops" },
|
||||||
|
{ name = "lpips" },
|
||||||
{ name = "matplotlib" },
|
{ name = "matplotlib" },
|
||||||
{ name = "numpy" },
|
{ name = "numpy" },
|
||||||
{ name = "swanlab" },
|
{ name = "swanlab" },
|
||||||
{ name = "torch" },
|
{ name = "torch" },
|
||||||
|
{ name = "torchvision" },
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.metadata]
|
[package.metadata]
|
||||||
requires-dist = [
|
requires-dist = [
|
||||||
{ name = "datasets", specifier = ">=2.19.0" },
|
{ name = "datasets", specifier = ">=2.19.0" },
|
||||||
{ name = "einops", specifier = ">=0.7.0" },
|
{ name = "einops", specifier = ">=0.7.0" },
|
||||||
|
{ name = "lpips", specifier = ">=0.1.4" },
|
||||||
{ name = "matplotlib", specifier = ">=3.8.0" },
|
{ name = "matplotlib", specifier = ">=3.8.0" },
|
||||||
{ name = "numpy", specifier = ">=1.26.0" },
|
{ name = "numpy", specifier = ">=1.26.0" },
|
||||||
{ name = "swanlab", specifier = ">=0.5.0" },
|
{ name = "swanlab", specifier = ">=0.5.0" },
|
||||||
{ name = "torch", specifier = ">=2.2.0" },
|
{ name = "torch", specifier = ">=2.2.0" },
|
||||||
|
{ name = "torchvision", specifier = ">=0.24.1" },
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@@ -1860,6 +1941,38 @@ wheels = [
|
|||||||
{ url = "https://files.pythonhosted.org/packages/db/2b/f7818f6ec88758dfd21da46b6cd46af9d1b3433e53ddbb19ad1e0da17f9b/torch-2.9.1-cp314-cp314t-win_amd64.whl", hash = "sha256:c88d3299ddeb2b35dcc31753305612db485ab6f1823e37fb29451c8b2732b87e", size = 111163659, upload-time = "2025-11-12T15:23:20.009Z" },
|
{ url = "https://files.pythonhosted.org/packages/db/2b/f7818f6ec88758dfd21da46b6cd46af9d1b3433e53ddbb19ad1e0da17f9b/torch-2.9.1-cp314-cp314t-win_amd64.whl", hash = "sha256:c88d3299ddeb2b35dcc31753305612db485ab6f1823e37fb29451c8b2732b87e", size = 111163659, upload-time = "2025-11-12T15:23:20.009Z" },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "torchvision"
|
||||||
|
version = "0.24.1"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
dependencies = [
|
||||||
|
{ name = "numpy" },
|
||||||
|
{ name = "pillow" },
|
||||||
|
{ name = "torch" },
|
||||||
|
]
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/f0/af/18e2c6b9538a045f60718a0c5a058908ccb24f88fde8e6f0fc12d5ff7bd3/torchvision-0.24.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:e48bf6a8ec95872eb45763f06499f87bd2fb246b9b96cb00aae260fda2f96193", size = 1891433, upload-time = "2025-11-12T15:25:03.232Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/9d/43/600e5cfb0643d10d633124f5982d7abc2170dfd7ce985584ff16edab3e76/torchvision-0.24.1-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:7fb7590c737ebe3e1c077ad60c0e5e2e56bb26e7bccc3b9d04dbfc34fd09f050", size = 2386737, upload-time = "2025-11-12T15:25:08.288Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/93/b1/db2941526ecddd84884132e2742a55c9311296a6a38627f9e2627f5ac889/torchvision-0.24.1-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:66a98471fc18cad9064123106d810a75f57f0838eee20edc56233fd8484b0cc7", size = 8049868, upload-time = "2025-11-12T15:25:13.058Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/69/98/16e583f59f86cd59949f59d52bfa8fc286f86341a229a9d15cbe7a694f0c/torchvision-0.24.1-cp312-cp312-win_amd64.whl", hash = "sha256:4aa6cb806eb8541e92c9b313e96192c6b826e9eb0042720e2fa250d021079952", size = 4302006, upload-time = "2025-11-12T15:25:16.184Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/e4/97/ab40550f482577f2788304c27220e8ba02c63313bd74cf2f8920526aac20/torchvision-0.24.1-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:8a6696db7fb71eadb2c6a48602106e136c785642e598eb1533e0b27744f2cce6", size = 1891435, upload-time = "2025-11-12T15:25:28.642Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/30/65/ac0a3f9be6abdbe4e1d82c915d7e20de97e7fd0e9a277970508b015309f3/torchvision-0.24.1-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:db2125c46f9cb25dc740be831ce3ce99303cfe60439249a41b04fd9f373be671", size = 2338718, upload-time = "2025-11-12T15:25:26.19Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/10/b5/5bba24ff9d325181508501ed7f0c3de8ed3dd2edca0784d48b144b6c5252/torchvision-0.24.1-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:f035f0cacd1f44a8ff6cb7ca3627d84c54d685055961d73a1a9fb9827a5414c8", size = 8049661, upload-time = "2025-11-12T15:25:22.558Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/5c/ec/54a96ae9ab6a0dd66d4bba27771f892e36478a9c3489fa56e51c70abcc4d/torchvision-0.24.1-cp313-cp313-win_amd64.whl", hash = "sha256:16274823b93048e0a29d83415166a2e9e0bf4e1b432668357b657612a4802864", size = 4319808, upload-time = "2025-11-12T15:25:17.318Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/d5/f3/a90a389a7e547f3eb8821b13f96ea7c0563cdefbbbb60a10e08dda9720ff/torchvision-0.24.1-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:e3f96208b4bef54cd60e415545f5200346a65024e04f29a26cd0006dbf9e8e66", size = 2005342, upload-time = "2025-11-12T15:25:11.871Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/a9/fe/ff27d2ed1b524078164bea1062f23d2618a5fc3208e247d6153c18c91a76/torchvision-0.24.1-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:f231f6a4f2aa6522713326d0d2563538fa72d613741ae364f9913027fa52ea35", size = 2341708, upload-time = "2025-11-12T15:25:25.08Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/b1/b9/d6c903495cbdfd2533b3ef6f7b5643ff589ea062f8feb5c206ee79b9d9e5/torchvision-0.24.1-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:1540a9e7f8cf55fe17554482f5a125a7e426347b71de07327d5de6bfd8d17caa", size = 8177239, upload-time = "2025-11-12T15:25:18.554Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/4f/2b/ba02e4261369c3798310483028495cf507e6cb3f394f42e4796981ecf3a7/torchvision-0.24.1-cp313-cp313t-win_amd64.whl", hash = "sha256:d83e16d70ea85d2f196d678bfb702c36be7a655b003abed84e465988b6128938", size = 4251604, upload-time = "2025-11-12T15:25:34.069Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/42/84/577b2cef8f32094add5f52887867da4c2a3e6b4261538447e9b48eb25812/torchvision-0.24.1-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:cccf4b4fec7fdfcd3431b9ea75d1588c0a8596d0333245dafebee0462abe3388", size = 2005319, upload-time = "2025-11-12T15:25:23.827Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/5f/34/ecb786bffe0159a3b49941a61caaae089853132f3cd1e8f555e3621f7e6f/torchvision-0.24.1-cp314-cp314-manylinux_2_28_aarch64.whl", hash = "sha256:1b495edd3a8f9911292424117544f0b4ab780452e998649425d1f4b2bed6695f", size = 2338844, upload-time = "2025-11-12T15:25:32.625Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/51/99/a84623786a6969504c87f2dc3892200f586ee13503f519d282faab0bb4f0/torchvision-0.24.1-cp314-cp314-manylinux_2_28_x86_64.whl", hash = "sha256:ab211e1807dc3e53acf8f6638df9a7444c80c0ad050466e8d652b3e83776987b", size = 8175144, upload-time = "2025-11-12T15:25:31.355Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/6d/ba/8fae3525b233e109317ce6a9c1de922ab2881737b029a7e88021f81e068f/torchvision-0.24.1-cp314-cp314-win_amd64.whl", hash = "sha256:18f9cb60e64b37b551cd605a3d62c15730c086362b40682d23e24b616a697d41", size = 4234459, upload-time = "2025-11-12T15:25:19.859Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/50/33/481602c1c72d0485d4b3a6b48c9534b71c2957c9d83bf860eb837bf5a620/torchvision-0.24.1-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:ec9d7379c519428395e4ffda4dbb99ec56be64b0a75b95989e00f9ec7ae0b2d7", size = 2005336, upload-time = "2025-11-12T15:25:27.225Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/d0/7f/372de60bf3dd8f5593bd0d03f4aecf0d1fd58f5bc6943618d9d913f5e6d5/torchvision-0.24.1-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:af9201184c2712d808bd4eb656899011afdfce1e83721c7cb08000034df353fe", size = 2341704, upload-time = "2025-11-12T15:25:29.857Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/36/9b/0f3b9ff3d0225ee2324ec663de0e7fb3eb855615ca958ac1875f22f1f8e5/torchvision-0.24.1-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:9ef95d819fd6df81bc7cc97b8f21a15d2c0d3ac5dbfaab5cbc2d2ce57114b19e", size = 8177422, upload-time = "2025-11-12T15:25:37.357Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/d6/ab/e2bcc7c2f13d882a58f8b30ff86f794210b075736587ea50f8c545834f8a/torchvision-0.24.1-cp314-cp314t-win_amd64.whl", hash = "sha256:480b271d6edff83ac2e8d69bbb4cf2073f93366516a50d48f140ccfceedb002e", size = 4335190, upload-time = "2025-11-12T15:25:35.745Z" },
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "tqdm"
|
name = "tqdm"
|
||||||
version = "4.67.1"
|
version = "4.67.1"
|
||||||
|
|||||||
Reference in New Issue
Block a user