4 Commits

Author SHA1 Message Date
Logic
5897a0afd1 feat: stabilize meanflow training and time sampling 2026-03-11 22:54:48 +08:00
Logic
9b2968997c Implement Mamba MeanFlow x-prediction training 2026-03-11 16:33:40 +08:00
gameloader
01fc1e4eab refactor: simplify delta-only flow training
Remove learned dt prediction and auxiliary losses.

Add repository contributor guidelines.
2026-03-10 18:23:17 +08:00
gameloader
913740266b fix: remove dt clamping and use raw softplus for step size 2026-01-22 14:41:02 +08:00
6 changed files with 408 additions and 246 deletions

27
AGENTS.md Normal file
View File

@@ -0,0 +1,27 @@
# Repository Guidelines
## Project Structure & Module Organization
This repository is a flat Python training project, not a packaged library. `main.py` is the CLI entry point. `as_mamba.py` holds `TrainConfig`, dataset loading, the training loop, validation image generation, and SwanLab logging. `mamba2_minimal.py` contains the standalone Mamba-2 building blocks. `train_as_mamba.sh` is the multi-GPU launcher. Runtime artifacts land in `outputs/` and `swanlog/`; treat both as generated data, not source.
## Build, Test, and Development Commands
Use `uv` with the committed lockfile.
- `uv sync` installs the Python 3.12 environment from `pyproject.toml` and `uv.lock`.
- `uv run python main.py --help` lists all training flags and is the fastest CLI sanity check.
- `uv run python main.py --device cpu --epochs 1 --steps-per-epoch 1 --batch-size 8 --num-workers 0 --val-every 0` runs a minimal local smoke test.
- `bash train_as_mamba.sh` launches the default distributed training job with `torchrun`; adjust `--nproc_per_node` and shell variables before using it on a new machine.
- `uv run python -m compileall main.py as_mamba.py mamba2_minimal.py` performs a lightweight syntax check when no test suite is available.
## Coding Style & Naming Conventions
Follow the existing Python style: 4-space indentation, type hints on public functions, `dataclass`-based config, and small helper functions over deeply nested logic. Use `snake_case` for functions, variables, and CLI flags; use `PascalCase` for classes such as `ASMamba` and `TrainConfig`. Keep new modules top-level unless the project is restructured. There is no configured formatter or linter yet, so match surrounding code closely and keep imports grouped and readable.
## Testing Guidelines
There is no dedicated `tests/` directory or pytest setup yet. Before submitting a change, run at minimum:
- `uv run python -m compileall main.py as_mamba.py mamba2_minimal.py`
- one short training smoke run with reduced epochs and workers
If you add reusable logic, prefer extracting pure functions so a future pytest suite can cover them easily. Name any new test files `test_<module>.py`.
## Commit & Pull Request Guidelines
Recent history favors short imperative commit subjects, usually Conventional Commit style: `feat: ...`, `fix: ...`, and scoped variants like `feat(mamba): ...`. Keep commits focused on one training or infrastructure change. Pull requests should describe the config or behavior change, list the verification commands you ran, and include sample images or metric notes when outputs in `outputs/` materially change. Do not commit `.venv/`, `__pycache__/`, or large generated artifacts.

View File

@@ -1,11 +1,14 @@
from __future__ import annotations from __future__ import annotations
import math import math
from dataclasses import asdict, dataclass
import os import os
from dataclasses import asdict, dataclass
from pathlib import Path from pathlib import Path
from typing import Iterator, Optional from typing import Iterator, Optional
import lpips
os.environ.setdefault("MPLCONFIGDIR", "/tmp/mamba_diffusion_mplconfig")
import matplotlib import matplotlib
matplotlib.use("Agg") matplotlib.use("Agg")
@@ -16,12 +19,17 @@ import torch.nn.functional as F
from datasets import load_dataset from datasets import load_dataset
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
from torch import Tensor, nn from torch import Tensor, nn
from torch.func import jvp
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler from torch.utils.data.distributed import DistributedSampler
from mamba2_minimal import InferenceCache, Mamba2, Mamba2Config, RMSNorm from mamba2_minimal import InferenceCache, Mamba2, Mamba2Config, RMSNorm
FIXED_VAL_SAMPLING_STEPS = 5
FIXED_VAL_TIME_GRID = (1.0, 0.8, 0.6, 0.4, 0.2, 0.0)
@dataclass @dataclass
class TrainConfig: class TrainConfig:
seed: int = 42 seed: int = 42
@@ -29,18 +37,12 @@ class TrainConfig:
epochs: int = 50 epochs: int = 50
steps_per_epoch: int = 200 steps_per_epoch: int = 200
batch_size: int = 128 batch_size: int = 128
seq_len: int = 20 seq_len: int = 5
lr: float = 2e-4 lr: float = 2e-4
weight_decay: float = 1e-2 weight_decay: float = 1e-2
dt_min: float = 1e-3 max_grad_norm: float = 1.0
dt_max: float = 0.06
dt_alpha: float = 8.0
lambda_flow: float = 1.0 lambda_flow: float = 1.0
lambda_pos: float = 1.0 lambda_perceptual: float = 2.0
lambda_dt: float = 1.0
use_flow_loss: bool = True
use_pos_loss: bool = False
use_dt_loss: bool = True
num_classes: int = 10 num_classes: int = 10
image_size: int = 28 image_size: int = 28
channels: int = 1 channels: int = 1
@@ -57,11 +59,12 @@ class TrainConfig:
use_residual: bool = False use_residual: bool = False
output_dir: str = "outputs" output_dir: str = "outputs"
project: str = "as-mamba-mnist" project: str = "as-mamba-mnist"
run_name: str = "mnist-flow" run_name: str = "mnist-meanflow"
val_every: int = 200 val_every: int = 200
val_samples_per_class: int = 8 val_samples_per_class: int = 8
val_grid_rows: int = 4 val_grid_rows: int = 4
val_max_steps: int = 0 val_sampling_steps: int = FIXED_VAL_SAMPLING_STEPS
time_grid_size: int = 256
use_ddp: bool = False use_ddp: bool = False
@@ -113,12 +116,35 @@ class Mamba2Backbone(nn.Module):
return x, h return x, h
def sinusoidal_embedding(t: Tensor, dim: int) -> Tensor:
    """Transformer-style sinusoidal embedding of *t* with trailing width *dim*.

    Args:
        t: Tensor of time values; any shape.
        dim: Width of the embedding appended as a new last axis. Must be > 0.

    Returns:
        Tensor of shape ``(*t.shape, dim)`` cast back to ``t``'s dtype. For
        ``dim == 1`` no sin/cos pair fits, so an all-zero embedding is
        returned; for odd ``dim`` the final channel is zero-padded.

    Raises:
        ValueError: If ``dim`` is not positive.
    """
    if dim <= 0:
        raise ValueError("sinusoidal embedding dim must be > 0")
    n_freqs = dim // 2
    if n_freqs == 0:
        # dim == 1: not enough room for even one sin/cos pair.
        return torch.zeros(*t.shape, dim, device=t.device, dtype=t.dtype)
    # Geometric frequency ladder from 1 down toward 1/10000, built in fp32
    # for numerical stability regardless of t's dtype.
    ladder_span = max(n_freqs - 1, 1)
    freqs = torch.exp(
        -math.log(10000.0)
        * torch.arange(n_freqs, device=t.device, dtype=torch.float32)
        / ladder_span
    )
    phases = t.to(torch.float32).unsqueeze(-1) * freqs
    out = torch.cat([phases.sin(), phases.cos()], dim=-1)
    if dim % 2 == 1:
        zero_col = torch.zeros(
            *out.shape[:-1], 1, device=out.device, dtype=out.dtype
        )
        out = torch.cat([out, zero_col], dim=-1)
    return out.to(t.dtype)
def safe_time_divisor(t: Tensor) -> Tensor:
    """Return *t* with every non-positive entry replaced by the dtype's eps.

    Used before dividing by a time tensor so that ``t == 0`` (or a negative
    value) never yields inf/nan; strictly positive entries pass through
    unchanged.
    """
    tiny = torch.finfo(t.dtype).eps
    floor = torch.full_like(t, tiny)
    return torch.where(t > 0, t, floor)
class ASMamba(nn.Module): class ASMamba(nn.Module):
def __init__(self, cfg: TrainConfig) -> None: def __init__(self, cfg: TrainConfig) -> None:
super().__init__() super().__init__()
self.cfg = cfg self.cfg = cfg
self.dt_min = float(cfg.dt_min)
self.dt_max = float(cfg.dt_max)
input_dim = cfg.channels * cfg.image_size * cfg.image_size input_dim = cfg.channels * cfg.image_size * cfg.image_size
if cfg.d_model == 0: if cfg.d_model == 0:
cfg.d_model = input_dim cfg.d_model = input_dim
@@ -138,28 +164,42 @@ class ASMamba(nn.Module):
) )
self.backbone = Mamba2Backbone(args, use_residual=cfg.use_residual) self.backbone = Mamba2Backbone(args, use_residual=cfg.use_residual)
self.cond_emb = nn.Embedding(cfg.num_classes, cfg.d_model) self.cond_emb = nn.Embedding(cfg.num_classes, cfg.d_model)
self.delta_head = nn.Linear(cfg.d_model, input_dim) self.clean_head = nn.Linear(cfg.d_model, input_dim)
self.dt_head = nn.Sequential(
nn.Linear(cfg.d_model, cfg.d_model),
nn.SiLU(),
nn.Linear(cfg.d_model, 1),
)
def forward( def forward(
self, x: Tensor, cond: Tensor, h: Optional[list[InferenceCache]] = None self,
) -> tuple[Tensor, Tensor, list[InferenceCache]]: z_t: Tensor,
r: Tensor,
t: Tensor,
cond: Tensor,
h: Optional[list[InferenceCache]] = None,
) -> tuple[Tensor, list[InferenceCache]]:
if r.dim() == 1:
r = r.unsqueeze(1)
elif r.dim() == 3 and r.shape[-1] == 1:
r = r.squeeze(-1)
if t.dim() == 1:
t = t.unsqueeze(1)
elif t.dim() == 3 and t.shape[-1] == 1:
t = t.squeeze(-1)
r = r.to(dtype=z_t.dtype)
t = t.to(dtype=z_t.dtype)
z_t = z_t + sinusoidal_embedding(r, z_t.shape[-1]) + sinusoidal_embedding(
t, z_t.shape[-1]
)
cond_vec = self.cond_emb(cond) cond_vec = self.cond_emb(cond)
feats, h = self.backbone(x, cond_vec, h) feats, h = self.backbone(z_t, cond_vec, h)
delta = self.delta_head(feats) x_pred = torch.tanh(self.clean_head(feats))
dt_raw = self.dt_head(feats).squeeze(-1) return x_pred, h
dt = torch.clamp(F.softplus(dt_raw), min=self.dt_min, max=self.dt_max)
return delta, dt, h
def step( def step(
self, x: Tensor, cond: Tensor, h: list[InferenceCache] self, z_t: Tensor, r: Tensor, t: Tensor, cond: Tensor, h: list[InferenceCache]
) -> tuple[Tensor, Tensor, list[InferenceCache]]: ) -> tuple[Tensor, list[InferenceCache]]:
delta, dt, h = self.forward(x.unsqueeze(1), cond, h) x_pred, h = self.forward(
return delta[:, 0, :], dt[:, 0], h z_t.unsqueeze(1), r.unsqueeze(1), t.unsqueeze(1), cond, h
)
return x_pred[:, 0, :], h
def init_cache(self, batch_size: int, device: torch.device) -> list[InferenceCache]: def init_cache(self, batch_size: int, device: torch.device) -> list[InferenceCache]:
return [ return [
@@ -226,6 +266,34 @@ class SwanLogger:
finish() finish()
class LPIPSPerceptualLoss(nn.Module):
    """Frozen LPIPS (VGG backbone) perceptual distance between image batches.

    The LPIPS network is loaded once at construction time — its pretrained
    weights are cached under ``<output_dir>/.torch`` — and kept frozen: this
    module only evaluates the metric and is never trained.
    """

    def __init__(self, cfg: TrainConfig) -> None:
        super().__init__()
        # Point torch's weight cache at the run's output directory so the
        # LPIPS/VGG download lands somewhere writable and self-contained.
        cache_dir = Path(cfg.output_dir) / ".torch"
        cache_dir.mkdir(parents=True, exist_ok=True)
        os.environ["TORCH_HOME"] = str(cache_dir)
        self.channels = cfg.channels
        self.loss_fn = lpips.LPIPS(net="vgg", verbose=False)
        self.loss_fn.eval()
        # Freeze every LPIPS parameter: it is a fixed metric, not a model.
        for p in self.loss_fn.parameters():
            p.requires_grad_(False)

    def _prepare_images(self, images: Tensor) -> Tensor:
        # LPIPS's VGG backbone expects 3-channel input; tile grayscale to RGB.
        channel_count = images.shape[1]
        if channel_count == 1:
            return images.repeat(1, 3, 1, 1)
        if channel_count != 3:
            raise ValueError(
                "LPIPS perceptual loss expects 1-channel or 3-channel images. "
                f"Got {images.shape[1]} channels."
            )
        return images

    def forward(self, pred: Tensor, target: Tensor) -> Tensor:
        """Mean LPIPS distance between the *pred* and *target* batches."""
        return self.loss_fn(
            self._prepare_images(pred), self._prepare_images(target)
        ).mean()
def set_seed(seed: int) -> None: def set_seed(seed: int) -> None:
torch.manual_seed(seed) torch.manual_seed(seed)
if torch.cuda.is_available(): if torch.cuda.is_available():
@@ -250,40 +318,42 @@ def unwrap_model(model: nn.Module) -> nn.Module:
return model.module if hasattr(model, "module") else model return model.module if hasattr(model, "module") else model
def validate_time_config(cfg: TrainConfig) -> None: def validate_config(cfg: TrainConfig) -> None:
if cfg.seq_len <= 0: if cfg.seq_len != 5:
raise ValueError("seq_len must be > 0")
base = 1.0 / cfg.seq_len
if cfg.dt_max <= base:
raise ValueError( raise ValueError(
"dt_max must be > 1/seq_len to allow non-uniform dt_seq. " f"seq_len must be 5 for the required 5-block training setup (got {cfg.seq_len})."
f"Got dt_max={cfg.dt_max}, seq_len={cfg.seq_len}, 1/seq_len={base}."
) )
if cfg.dt_min >= cfg.dt_max: if cfg.time_grid_size < 2:
raise ValueError("time_grid_size must be >= 2.")
if cfg.lambda_perceptual < 0:
raise ValueError("lambda_perceptual must be >= 0.")
if cfg.max_grad_norm <= 0:
raise ValueError("max_grad_norm must be > 0.")
if cfg.val_sampling_steps != FIXED_VAL_SAMPLING_STEPS:
raise ValueError( raise ValueError(
f"dt_min must be < dt_max (got dt_min={cfg.dt_min}, dt_max={cfg.dt_max})." f"val_sampling_steps is fixed to {FIXED_VAL_SAMPLING_STEPS} for validation sampling."
) )
def sample_time_sequence( def sample_block_times(
cfg: TrainConfig, batch_size: int, device: torch.device cfg: TrainConfig, batch_size: int, device: torch.device, dtype: torch.dtype
) -> Tensor: ) -> tuple[Tensor, Tensor]:
alpha = float(cfg.dt_alpha) num_internal = cfg.seq_len - 1
if alpha <= 0: normal = torch.randn(batch_size, num_internal, device=device, dtype=dtype)
raise ValueError("dt_alpha must be > 0") logit_normal = torch.sigmoid(normal * math.sqrt(0.8))
dist = torch.distributions.Gamma(alpha, 1.0) uniform = torch.rand(batch_size, num_internal, device=device, dtype=dtype)
raw = dist.sample((batch_size, cfg.seq_len)).to(device) use_uniform = torch.rand(batch_size, num_internal, device=device) < 0.1
dt_seq = raw / raw.sum(dim=-1, keepdim=True) cuts = torch.where(use_uniform, uniform, logit_normal)
base = 1.0 / cfg.seq_len cuts, _ = torch.sort(cuts, dim=-1)
max_dt = float(cfg.dt_max) boundaries = torch.cat(
if max_dt <= base: [
return torch.full_like(dt_seq, base) torch.zeros(batch_size, 1, device=device, dtype=dtype),
max_current = dt_seq.max(dim=-1, keepdim=True).values cuts,
if (max_current > max_dt).any(): torch.ones(batch_size, 1, device=device, dtype=dtype),
gamma = (max_dt - base) / (max_current - base) ],
gamma = gamma.clamp(0.0, 1.0) dim=-1,
dt_seq = gamma * dt_seq + (1.0 - gamma) * base )
return dt_seq return boundaries[:, :-1], boundaries[:, 1:]
def build_dataloader( def build_dataloader(
@@ -330,51 +400,66 @@ def infinite_loader(loader: DataLoader) -> Iterator[dict]:
yield batch yield batch
def compute_losses( def build_noisy_sequence(
delta: Tensor,
dt: Tensor,
x_seq: Tensor,
x0: Tensor, x0: Tensor,
v_gt: Tensor, eps: Tensor,
t_seq: Tensor, t_seq: Tensor,
dt_seq: Tensor, ) -> tuple[Tensor, Tensor]:
z_t = (1.0 - t_seq.unsqueeze(-1)) * x0[:, None, :] + t_seq.unsqueeze(-1) * eps[:, None, :]
v_gt = eps - x0
return z_t, v_gt
def compute_losses(
model: nn.Module,
perceptual_loss_fn: LPIPSPerceptualLoss,
x0: Tensor,
z_t: Tensor,
v_gt: Tensor,
r_seq: Tensor,
t_seq: Tensor,
cond: Tensor,
cfg: TrainConfig, cfg: TrainConfig,
) -> dict[str, Tensor]: ) -> tuple[dict[str, Tensor], Tensor]:
losses: dict[str, Tensor] = {} seq_len = z_t.shape[1]
safe_t = safe_time_divisor(t_seq).unsqueeze(-1)
if cfg.use_flow_loss: x_pred, _ = model(z_t, r_seq, t_seq, cond)
target_disp = v_gt[:, None, :] * dt.unsqueeze(-1) u = (z_t - x_pred) / safe_t
losses["flow"] = F.mse_loss(delta, target_disp)
if cfg.use_pos_loss: x_pred_inst, _ = model(z_t, t_seq, t_seq, cond)
t_next = t_seq + dt v_inst = ((z_t - x_pred_inst) / safe_t).detach()
t_next = torch.clamp(t_next, 0.0, 1.0)
x_target = x0[:, None, :] + t_next.unsqueeze(-1) * v_gt[:, None, :]
x_next_pred = x_seq + delta
losses["pos"] = F.mse_loss(x_next_pred, x_target)
if cfg.use_dt_loss: def u_fn(z_in: Tensor, r_in: Tensor, t_in: Tensor) -> Tensor:
losses["dt"] = F.mse_loss(dt, dt_seq) x_pred_local, _ = model(z_in, r_in, t_in, cond)
return (z_in - x_pred_local) / safe_time_divisor(t_in).unsqueeze(-1)
return losses _, dudt = jvp(
u_fn,
(z_t, r_seq, t_seq),
(v_inst, torch.zeros_like(r_seq), torch.ones_like(t_seq)),
)
corrected_velocity = u + (t_seq - r_seq).unsqueeze(-1) * dudt.detach()
target_velocity = v_gt[:, None, :].expand(-1, seq_len, -1)
pred_images = x_pred.reshape(
x0.shape[0] * seq_len, cfg.channels, cfg.image_size, cfg.image_size
)
target_images = (
x0.reshape(x0.shape[0], cfg.channels, cfg.image_size, cfg.image_size)
.unsqueeze(1)
.expand(-1, seq_len, -1, -1, -1)
.reshape(x0.shape[0] * seq_len, cfg.channels, cfg.image_size, cfg.image_size)
)
def plot_dt_hist( losses = {
dt_pred: Tensor, dt_gt: Tensor, save_path: Path, title: str = "dt Distribution" "flow": F.mse_loss(corrected_velocity, target_velocity),
) -> None: "perceptual": perceptual_loss_fn(pred_images, target_images),
dt_pred_np = dt_pred.detach().cpu().numpy().reshape(-1) }
dt_gt_np = dt_gt.detach().cpu().numpy().reshape(-1) losses["total"] = cfg.lambda_flow * losses["flow"] + cfg.lambda_perceptual * losses[
"perceptual"
fig, ax = plt.subplots(figsize=(6, 4)) ]
ax.hist(dt_gt_np, bins=30, alpha=0.6, label="dt_gt", color="steelblue") return losses, x_pred
ax.hist(dt_pred_np, bins=30, alpha=0.6, label="dt_pred", color="orange")
ax.set_title(title)
ax.set_xlabel("dt")
ax.set_ylabel("count")
ax.legend(loc="best")
fig.tight_layout()
fig.savefig(save_path, dpi=160)
plt.close(fig)
def make_grid(images: Tensor, nrow: int) -> np.ndarray: def make_grid(images: Tensor, nrow: int) -> np.ndarray:
@@ -409,36 +494,36 @@ def save_image_grid(
plt.close(fig) plt.close(fig)
def rollout_trajectory( def sample_class_images(
model: ASMamba, model: ASMamba,
x0: Tensor, cfg: TrainConfig,
device: torch.device,
cond: Tensor, cond: Tensor,
max_steps: int,
) -> Tensor: ) -> Tensor:
device = x0.device
model.eval() model.eval()
h = model.init_cache(batch_size=x0.shape[0], device=device) input_dim = cfg.channels * cfg.image_size * cfg.image_size
x = x0 z_t = torch.randn(cond.shape[0], input_dim, device=device)
total_time = torch.zeros(x0.shape[0], device=device) time_grid = torch.tensor(FIXED_VAL_TIME_GRID, device=device)
traj = [x0.detach().cpu()]
with torch.no_grad(): with torch.no_grad():
for _ in range(max_steps): for step_idx in range(FIXED_VAL_SAMPLING_STEPS):
delta, dt, h = model.step(x, cond, h) t_cur = torch.full(
dt = torch.clamp(dt, min=model.dt_min, max=model.dt_max) (cond.shape[0],),
remaining = 1.0 - total_time float(time_grid[step_idx].item()),
overshoot = dt > remaining device=device,
if overshoot.any(): )
scale = (remaining / dt).unsqueeze(-1) t_next = time_grid[step_idx + 1]
delta = torch.where(overshoot.unsqueeze(-1), delta * scale, delta) x_pred, _ = model(
dt = torch.where(overshoot, remaining, dt) z_t.unsqueeze(1),
x = x + delta t_cur.unsqueeze(1),
total_time = total_time + dt t_cur.unsqueeze(1),
traj.append(x.detach().cpu()) cond,
if torch.all(total_time >= 1.0 - 1e-6): )
break x_pred = x_pred[:, 0, :]
u_inst = (z_t - x_pred) / safe_time_divisor(t_cur).unsqueeze(-1)
z_t = z_t + (t_next - t_cur).unsqueeze(-1) * u_inst
return torch.stack(traj, dim=1) return z_t.view(cond.shape[0], cfg.channels, cfg.image_size, cfg.image_size)
def log_class_samples( def log_class_samples(
@@ -450,19 +535,13 @@ def log_class_samples(
) -> None: ) -> None:
if cfg.val_samples_per_class <= 0: if cfg.val_samples_per_class <= 0:
return return
training_mode = model.training
model.eval() model.eval()
max_steps = cfg.seq_len if cfg.val_max_steps <= 0 else cfg.val_max_steps
input_dim = cfg.channels * cfg.image_size * cfg.image_size
for cls in range(cfg.num_classes): for cls in range(cfg.num_classes):
cond = torch.full( cond = torch.full(
(cfg.val_samples_per_class,), cls, device=device, dtype=torch.long (cfg.val_samples_per_class,), cls, device=device, dtype=torch.long
) )
x0 = torch.randn(cfg.val_samples_per_class, input_dim, device=device) x_final = sample_class_images(model, cfg, device, cond)
traj = rollout_trajectory(model, x0, cond, max_steps=max_steps)
x_final = traj[:, -1, :].view(
cfg.val_samples_per_class, cfg.channels, cfg.image_size, cfg.image_size
)
save_path = Path(cfg.output_dir) / f"val_class_{cls}_step_{step:06d}.png" save_path = Path(cfg.output_dir) / f"val_class_{cls}_step_{step:06d}.png"
save_image_grid(x_final, save_path, nrow=cfg.val_grid_rows) save_image_grid(x_final, save_path, nrow=cfg.val_grid_rows)
logger.log_image( logger.log_image(
@@ -471,12 +550,25 @@ def log_class_samples(
caption=f"class {cls} step {step}", caption=f"class {cls} step {step}",
step=step, step=step,
) )
if training_mode:
model.train() model.train()
def build_perceptual_loss(
    cfg: TrainConfig, device: torch.device, rank: int, use_ddp: bool
) -> LPIPSPerceptualLoss:
    """Construct the LPIPS loss, serializing the weight download under DDP.

    Rank 0 constructs first (triggering the one-time LPIPS/VGG weight
    download) while every other rank waits at a barrier; once rank 0 passes
    its trailing barrier, the remaining ranks build from the warm cache.
    """
    is_nonzero_rank = use_ddp and rank != 0
    if is_nonzero_rank:
        # Wait here until rank 0 has finished downloading the weights.
        torch.distributed.barrier()
    loss_module = LPIPSPerceptualLoss(cfg).to(device)
    if use_ddp and not is_nonzero_rank:
        # Rank 0 releases the waiting ranks now that the cache is populated.
        torch.distributed.barrier()
    return loss_module
def train(cfg: TrainConfig) -> ASMamba: def train(cfg: TrainConfig) -> ASMamba:
validate_time_config(cfg) validate_config(cfg)
use_ddp, rank, world_size, device = setup_distributed(cfg) use_ddp, rank, world_size, device = setup_distributed(cfg)
del world_size
set_seed(cfg.seed + rank) set_seed(cfg.seed + rank)
output_dir = Path(cfg.output_dir) output_dir = Path(cfg.output_dir)
if rank == 0: if rank == 0:
@@ -488,91 +580,57 @@ def train(cfg: TrainConfig) -> ASMamba:
optimizer = torch.optim.AdamW( optimizer = torch.optim.AdamW(
model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay
) )
perceptual_loss_fn = build_perceptual_loss(cfg, device, rank, use_ddp)
logger = SwanLogger(cfg, enabled=(rank == 0)) logger = SwanLogger(cfg, enabled=(rank == 0))
loader, sampler = build_dataloader(cfg, distributed=use_ddp) loader, sampler = build_dataloader(cfg, distributed=use_ddp)
loader_iter = infinite_loader(loader) loader_iter = infinite_loader(loader)
global_step = 0 global_step = 0
for _ in range(cfg.epochs): for epoch_idx in range(cfg.epochs):
if sampler is not None: if sampler is not None:
sampler.set_epoch(global_step) sampler.set_epoch(epoch_idx)
model.train() model.train()
for _ in range(cfg.steps_per_epoch): for _ in range(cfg.steps_per_epoch):
batch = next(loader_iter) batch = next(loader_iter)
x1 = batch["pixel_values"].to(device) x0 = batch["pixel_values"].to(device)
cond = batch["labels"].to(device) cond = batch["labels"].to(device)
b = x1.shape[0] b = x0.shape[0]
x1 = x1.view(b, -1) x0 = x0.view(b, -1)
x0 = torch.randn_like(x1) eps = torch.randn_like(x0)
v_gt = x1 - x0
dt_seq = sample_time_sequence(cfg, b, device)
t_seq = torch.cumsum(dt_seq, dim=-1)
t_seq = torch.cat([torch.zeros(b, 1, device=device), t_seq[:, :-1]], dim=-1)
x_seq = x0[:, None, :] + t_seq[:, :, None] * v_gt[:, None, :]
delta, dt, _ = model(x_seq, cond) r_seq, t_seq = sample_block_times(cfg, b, device, x0.dtype)
z_t, v_gt = build_noisy_sequence(x0, eps, t_seq)
losses = compute_losses( losses, _ = compute_losses(
delta=delta, model=model,
dt=dt, perceptual_loss_fn=perceptual_loss_fn,
x_seq=x_seq,
x0=x0, x0=x0,
z_t=z_t,
v_gt=v_gt, v_gt=v_gt,
r_seq=r_seq,
t_seq=t_seq, t_seq=t_seq,
dt_seq=dt_seq, cond=cond,
cfg=cfg, cfg=cfg,
) )
loss = torch.tensor(0.0, device=device)
if cfg.use_flow_loss and "flow" in losses:
loss = loss + cfg.lambda_flow * losses["flow"]
if cfg.use_pos_loss and "pos" in losses:
loss = loss + cfg.lambda_pos * losses["pos"]
if cfg.use_dt_loss and "dt" in losses:
loss = loss + cfg.lambda_dt * losses["dt"]
if loss.item() == 0.0:
raise RuntimeError(
"No loss enabled: enable at least one of flow/pos/dt."
)
optimizer.zero_grad(set_to_none=True) optimizer.zero_grad(set_to_none=True)
loss.backward() losses["total"].backward()
grad_norm = torch.nn.utils.clip_grad_norm_(
model.parameters(), max_norm=cfg.max_grad_norm
)
optimizer.step() optimizer.step()
if global_step % 10 == 0: if global_step % 10 == 0 and rank == 0:
dt_min = float(dt.min().item())
dt_max = float(dt.max().item())
dt_mean = float(dt.mean().item())
dt_gt_min = float(dt_seq.min().item())
dt_gt_max = float(dt_seq.max().item())
dt_gt_mean = float(dt_seq.mean().item())
eps = 1e-6
clamp_min_ratio = float((dt <= cfg.dt_min + eps).float().mean().item())
clamp_max_ratio = float((dt >= cfg.dt_max - eps).float().mean().item())
clamp_any_ratio = float(
((dt <= cfg.dt_min + eps) | (dt >= cfg.dt_max - eps))
.float()
.mean()
.item()
)
logger.log( logger.log(
{ {
"loss/total": float(loss.item()), "loss/total": float(losses["total"].item()),
"loss/flow": float( "loss/flow": float(losses["flow"].item()),
losses.get("flow", torch.tensor(0.0)).item() "loss/perceptual": float(losses["perceptual"].item()),
), "grad/total_norm": float(grad_norm.item()),
"loss/pos": float(losses.get("pos", torch.tensor(0.0)).item()), "time/r_mean": float(r_seq.mean().item()),
"loss/dt": float(losses.get("dt", torch.tensor(0.0)).item()), "time/t_mean": float(t_seq.mean().item()),
"dt/pred_mean": dt_mean, "time/zero_block_frac": float((t_seq == r_seq).float().mean().item()),
"dt/pred_min": dt_min,
"dt/pred_max": dt_max,
"dt/gt_mean": dt_gt_mean,
"dt/gt_min": dt_gt_min,
"dt/gt_max": dt_gt_max,
"dt/clamp_min_ratio": clamp_min_ratio,
"dt/clamp_max_ratio": clamp_max_ratio,
"dt/clamp_any_ratio": clamp_any_ratio,
}, },
step=global_step, step=global_step,
) )
@@ -584,21 +642,6 @@ def train(cfg: TrainConfig) -> ASMamba:
and rank == 0 and rank == 0
): ):
log_class_samples(unwrap_model(model), cfg, device, logger, global_step) log_class_samples(unwrap_model(model), cfg, device, logger, global_step)
dt_hist_path = (
Path(cfg.output_dir) / f"dt_hist_step_{global_step:06d}.png"
)
plot_dt_hist(
dt,
dt_seq,
dt_hist_path,
title=f"dt Distribution (step {global_step})",
)
logger.log_image(
"train/dt_hist",
dt_hist_path,
caption=f"step {global_step}",
step=global_step,
)
global_step += 1 global_step += 1

15
main.py
View File

@@ -4,26 +4,20 @@ from as_mamba import TrainConfig, run_training_and_plot
def build_parser() -> argparse.ArgumentParser: def build_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(description="Train AS-Mamba on MNIST flow matching.") parser = argparse.ArgumentParser(description="Train AS-Mamba on MNIST MeanFlow x-prediction.")
parser.add_argument("--epochs", type=int, default=None) parser.add_argument("--epochs", type=int, default=None)
parser.add_argument("--steps-per-epoch", type=int, default=None) parser.add_argument("--steps-per-epoch", type=int, default=None)
parser.add_argument("--batch-size", type=int, default=None) parser.add_argument("--batch-size", type=int, default=None)
parser.add_argument("--seq-len", type=int, default=None) parser.add_argument("--seq-len", type=int, default=None)
parser.add_argument("--lr", type=float, default=None) parser.add_argument("--lr", type=float, default=None)
parser.add_argument("--weight-decay", type=float, default=None) parser.add_argument("--weight-decay", type=float, default=None)
parser.add_argument("--max-grad-norm", type=float, default=None)
parser.add_argument("--device", type=str, default=None) parser.add_argument("--device", type=str, default=None)
parser.add_argument("--output-dir", type=str, default=None) parser.add_argument("--output-dir", type=str, default=None)
parser.add_argument("--project", type=str, default=None) parser.add_argument("--project", type=str, default=None)
parser.add_argument("--run-name", type=str, default=None) parser.add_argument("--run-name", type=str, default=None)
parser.add_argument("--dt-alpha", type=float, default=None)
parser.add_argument("--dt-min", type=float, default=None)
parser.add_argument("--dt-max", type=float, default=None)
parser.add_argument("--lambda-flow", type=float, default=None) parser.add_argument("--lambda-flow", type=float, default=None)
parser.add_argument("--lambda-pos", type=float, default=None) parser.add_argument("--lambda-perceptual", type=float, default=None)
parser.add_argument("--lambda-dt", type=float, default=None)
parser.add_argument("--use-flow-loss", action=argparse.BooleanOptionalAction, default=None)
parser.add_argument("--use-pos-loss", action=argparse.BooleanOptionalAction, default=None)
parser.add_argument("--use-dt-loss", action=argparse.BooleanOptionalAction, default=None)
parser.add_argument("--num-classes", type=int, default=None) parser.add_argument("--num-classes", type=int, default=None)
parser.add_argument("--image-size", type=int, default=None) parser.add_argument("--image-size", type=int, default=None)
parser.add_argument("--channels", type=int, default=None) parser.add_argument("--channels", type=int, default=None)
@@ -41,7 +35,8 @@ def build_parser() -> argparse.ArgumentParser:
parser.add_argument("--val-every", type=int, default=None) parser.add_argument("--val-every", type=int, default=None)
parser.add_argument("--val-samples-per-class", type=int, default=None) parser.add_argument("--val-samples-per-class", type=int, default=None)
parser.add_argument("--val-grid-rows", type=int, default=None) parser.add_argument("--val-grid-rows", type=int, default=None)
parser.add_argument("--val-max-steps", type=int, default=None) parser.add_argument("--val-sampling-steps", type=int, default=None)
parser.add_argument("--time-grid-size", type=int, default=None)
parser.add_argument("--use-ddp", action=argparse.BooleanOptionalAction, default=None) parser.add_argument("--use-ddp", action=argparse.BooleanOptionalAction, default=None)
return parser return parser

View File

@@ -7,8 +7,10 @@ requires-python = ">=3.12"
dependencies = [ dependencies = [
"datasets>=2.19.0", "datasets>=2.19.0",
"einops>=0.7.0", "einops>=0.7.0",
"lpips>=0.1.4",
"matplotlib>=3.8.0", "matplotlib>=3.8.0",
"numpy>=1.26.0", "numpy>=1.26.0",
"swanlab>=0.5.0", "swanlab>=0.5.0",
"torch>=2.2.0", "torch>=2.2.0",
"torchvision>=0.24.1",
] ]

View File

@@ -4,48 +4,36 @@ set -euo pipefail
DEVICE="cuda" DEVICE="cuda"
EPOCHS=2000 EPOCHS=2000
STEPS_PER_EPOCH=200 STEPS_PER_EPOCH=200
BATCH_SIZE=256 BATCH_SIZE=512
SEQ_LEN=100 SEQ_LEN=5
LR=2e-3 LR=1e-3
WEIGHT_DECAY=1e-2 WEIGHT_DECAY=1e-2
DT_MIN=5e-4
DT_MAX=0.06
DT_ALPHA=9.0
LAMBDA_FLOW=1.0 LAMBDA_FLOW=1.0
LAMBDA_POS=1.0 LAMBDA_PERCEPTUAL=0.4
LAMBDA_DT=1.0
USE_FLOW_LOSS=true
USE_POS_LOSS=false
USE_DT_LOSS=true
NUM_CLASSES=10 NUM_CLASSES=10
IMAGE_SIZE=28 IMAGE_SIZE=28
CHANNELS=1 CHANNELS=1
NUM_WORKERS=16 NUM_WORKERS=32
DATASET_NAME="ylecun/mnist" DATASET_NAME="ylecun/mnist"
DATASET_SPLIT="train" DATASET_SPLIT="train"
D_MODEL=784 D_MODEL=784
N_LAYER=6 N_LAYER=8
D_STATE=32 D_STATE=32
D_CONV=4 D_CONV=4
EXPAND=2 EXPAND=2
HEADDIM=32 HEADDIM=32
CHUNK_SIZE=20 CHUNK_SIZE=1
USE_RESIDUAL=false USE_RESIDUAL=true
USE_DDP=true USE_DDP=true
VAL_EVERY=1000 VAL_EVERY=1000
VAL_SAMPLES_PER_CLASS=8 VAL_SAMPLES_PER_CLASS=8
VAL_GRID_ROWS=4 VAL_GRID_ROWS=4
VAL_MAX_STEPS=0 VAL_SAMPLING_STEPS=5
TIME_GRID_SIZE=256
PROJECT="as-mamba-mnist" PROJECT="as-mamba-mnist"
RUN_NAME="mnist-flow" RUN_NAME="mnist-meanflow-xpred"
OUTPUT_DIR="outputs" OUTPUT_DIR="outputs"
USE_FLOW_FLAG="--use-flow-loss"
if [ "${USE_FLOW_LOSS}" = "false" ]; then USE_FLOW_FLAG="--no-use-flow-loss"; fi
USE_POS_FLAG="--use-pos-loss"
if [ "${USE_POS_LOSS}" = "false" ]; then USE_POS_FLAG="--no-use-pos-loss"; fi
USE_DT_FLAG="--use-dt-loss"
if [ "${USE_DT_LOSS}" = "false" ]; then USE_DT_FLAG="--no-use-dt-loss"; fi
USE_RESIDUAL_FLAG="--use-residual" USE_RESIDUAL_FLAG="--use-residual"
if [ "${USE_RESIDUAL}" = "false" ]; then USE_RESIDUAL_FLAG="--no-use-residual"; fi if [ "${USE_RESIDUAL}" = "false" ]; then USE_RESIDUAL_FLAG="--no-use-residual"; fi
USE_DDP_FLAG="--use-ddp" USE_DDP_FLAG="--use-ddp"
@@ -59,15 +47,8 @@ uv run torchrun --nproc_per_node=2 main.py \
--seq-len "${SEQ_LEN}" \ --seq-len "${SEQ_LEN}" \
--lr "${LR}" \ --lr "${LR}" \
--weight-decay "${WEIGHT_DECAY}" \ --weight-decay "${WEIGHT_DECAY}" \
--dt-min "${DT_MIN}" \
--dt-max "${DT_MAX}" \
--dt-alpha "${DT_ALPHA}" \
--lambda-flow "${LAMBDA_FLOW}" \ --lambda-flow "${LAMBDA_FLOW}" \
--lambda-pos "${LAMBDA_POS}" \ --lambda-perceptual "${LAMBDA_PERCEPTUAL}" \
--lambda-dt "${LAMBDA_DT}" \
${USE_FLOW_FLAG} \
${USE_POS_FLAG} \
${USE_DT_FLAG} \
--num-classes "${NUM_CLASSES}" \ --num-classes "${NUM_CLASSES}" \
--image-size "${IMAGE_SIZE}" \ --image-size "${IMAGE_SIZE}" \
--channels "${CHANNELS}" \ --channels "${CHANNELS}" \
@@ -86,7 +67,8 @@ uv run torchrun --nproc_per_node=2 main.py \
--val-every "${VAL_EVERY}" \ --val-every "${VAL_EVERY}" \
--val-samples-per-class "${VAL_SAMPLES_PER_CLASS}" \ --val-samples-per-class "${VAL_SAMPLES_PER_CLASS}" \
--val-grid-rows "${VAL_GRID_ROWS}" \ --val-grid-rows "${VAL_GRID_ROWS}" \
--val-max-steps "${VAL_MAX_STEPS}" \ --val-sampling-steps "${VAL_SAMPLING_STEPS}" \
--time-grid-size "${TIME_GRID_SIZE}" \
--project "${PROJECT}" \ --project "${PROJECT}" \
--run-name "${RUN_NAME}" \ --run-name "${RUN_NAME}" \
--output-dir "${OUTPUT_DIR}" --output-dir "${OUTPUT_DIR}"

113
uv.lock generated
View File

@@ -682,6 +682,22 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/80/be/3578e8afd18c88cdf9cb4cffde75a96d2be38c5a903f1ed0ceec061bd09e/kiwisolver-1.4.9-cp314-cp314t-win_arm64.whl", hash = "sha256:4a48a2ce79d65d363597ef7b567ce3d14d68783d2b2263d98db3d9477805ba32", size = 70260, upload-time = "2025-08-10T21:27:36.606Z" }, { url = "https://files.pythonhosted.org/packages/80/be/3578e8afd18c88cdf9cb4cffde75a96d2be38c5a903f1ed0ceec061bd09e/kiwisolver-1.4.9-cp314-cp314t-win_arm64.whl", hash = "sha256:4a48a2ce79d65d363597ef7b567ce3d14d68783d2b2263d98db3d9477805ba32", size = 70260, upload-time = "2025-08-10T21:27:36.606Z" },
] ]
[[package]]
name = "lpips"
version = "0.1.4"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "numpy" },
{ name = "scipy" },
{ name = "torch" },
{ name = "torchvision" },
{ name = "tqdm" },
]
sdist = { url = "https://files.pythonhosted.org/packages/e8/2d/4b8148d32f5bd461eb7d5daa54fcc998f86eaa709a57f4ef6aa4c62f024f/lpips-0.1.4.tar.gz", hash = "sha256:3846331df6c69688aec3d300a5eeef6c529435bc8460bd58201c3d62e56188fa", size = 18029, upload-time = "2021-08-25T22:10:32.803Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/9b/13/1df50c7925d9d2746702719f40e864f51ed66f307b20ad32392f1ad2bb87/lpips-0.1.4-py3-none-any.whl", hash = "sha256:fd537af5828b69d2e6ffc0a397bd506dbc28ca183543617690844c08e102ec5e", size = 53763, upload-time = "2021-08-25T22:10:31.257Z" },
]
[[package]] [[package]]
name = "markdown-it-py" name = "markdown-it-py"
version = "4.0.0" version = "4.0.0"
@@ -1687,6 +1703,67 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/19/71/39c7c0d87f8d4e6c020a393182060eaefeeae6c01dab6a84ec346f2567df/rich-13.9.4-py3-none-any.whl", hash = "sha256:6049d5e6ec054bf2779ab3358186963bac2ea89175919d699e378b99738c2a90", size = 242424, upload-time = "2024-11-01T16:43:55.817Z" }, { url = "https://files.pythonhosted.org/packages/19/71/39c7c0d87f8d4e6c020a393182060eaefeeae6c01dab6a84ec346f2567df/rich-13.9.4-py3-none-any.whl", hash = "sha256:6049d5e6ec054bf2779ab3358186963bac2ea89175919d699e378b99738c2a90", size = 242424, upload-time = "2024-11-01T16:43:55.817Z" },
] ]
[[package]]
name = "scipy"
version = "1.17.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "numpy" },
]
sdist = { url = "https://files.pythonhosted.org/packages/7a/97/5a3609c4f8d58b039179648e62dd220f89864f56f7357f5d4f45c29eb2cc/scipy-1.17.1.tar.gz", hash = "sha256:95d8e012d8cb8816c226aef832200b1d45109ed4464303e997c5b13122b297c0", size = 30573822, upload-time = "2026-02-23T00:26:24.851Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/35/48/b992b488d6f299dbe3f11a20b24d3dda3d46f1a635ede1c46b5b17a7b163/scipy-1.17.1-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:35c3a56d2ef83efc372eaec584314bd0ef2e2f0d2adb21c55e6ad5b344c0dcb8", size = 31610954, upload-time = "2026-02-23T00:17:49.855Z" },
{ url = "https://files.pythonhosted.org/packages/b2/02/cf107b01494c19dc100f1d0b7ac3cc08666e96ba2d64db7626066cee895e/scipy-1.17.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:fcb310ddb270a06114bb64bbe53c94926b943f5b7f0842194d585c65eb4edd76", size = 28172662, upload-time = "2026-02-23T00:18:01.64Z" },
{ url = "https://files.pythonhosted.org/packages/cf/a9/599c28631bad314d219cf9ffd40e985b24d603fc8a2f4ccc5ae8419a535b/scipy-1.17.1-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:cc90d2e9c7e5c7f1a482c9875007c095c3194b1cfedca3c2f3291cdc2bc7c086", size = 20344366, upload-time = "2026-02-23T00:18:12.015Z" },
{ url = "https://files.pythonhosted.org/packages/35/f5/906eda513271c8deb5af284e5ef0206d17a96239af79f9fa0aebfe0e36b4/scipy-1.17.1-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:c80be5ede8f3f8eded4eff73cc99a25c388ce98e555b17d31da05287015ffa5b", size = 22704017, upload-time = "2026-02-23T00:18:21.502Z" },
{ url = "https://files.pythonhosted.org/packages/da/34/16f10e3042d2f1d6b66e0428308ab52224b6a23049cb2f5c1756f713815f/scipy-1.17.1-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e19ebea31758fac5893a2ac360fedd00116cbb7628e650842a6691ba7ca28a21", size = 32927842, upload-time = "2026-02-23T00:18:35.367Z" },
{ url = "https://files.pythonhosted.org/packages/01/8e/1e35281b8ab6d5d72ebe9911edcdffa3f36b04ed9d51dec6dd140396e220/scipy-1.17.1-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:02ae3b274fde71c5e92ac4d54bc06c42d80e399fec704383dcd99b301df37458", size = 35235890, upload-time = "2026-02-23T00:18:49.188Z" },
{ url = "https://files.pythonhosted.org/packages/c5/5c/9d7f4c88bea6e0d5a4f1bc0506a53a00e9fcb198de372bfe4d3652cef482/scipy-1.17.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:8a604bae87c6195d8b1045eddece0514d041604b14f2727bbc2b3020172045eb", size = 35003557, upload-time = "2026-02-23T00:18:54.74Z" },
{ url = "https://files.pythonhosted.org/packages/65/94/7698add8f276dbab7a9de9fb6b0e02fc13ee61d51c7c3f85ac28b65e1239/scipy-1.17.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:f590cd684941912d10becc07325a3eeb77886fe981415660d9265c4c418d0bea", size = 37625856, upload-time = "2026-02-23T00:19:00.307Z" },
{ url = "https://files.pythonhosted.org/packages/a2/84/dc08d77fbf3d87d3ee27f6a0c6dcce1de5829a64f2eae85a0ecc1f0daa73/scipy-1.17.1-cp312-cp312-win_amd64.whl", hash = "sha256:41b71f4a3a4cab9d366cd9065b288efc4d4f3c0b37a91a8e0947fb5bd7f31d87", size = 36549682, upload-time = "2026-02-23T00:19:07.67Z" },
{ url = "https://files.pythonhosted.org/packages/bc/98/fe9ae9ffb3b54b62559f52dedaebe204b408db8109a8c66fdd04869e6424/scipy-1.17.1-cp312-cp312-win_arm64.whl", hash = "sha256:f4115102802df98b2b0db3cce5cb9b92572633a1197c77b7553e5203f284a5b3", size = 24547340, upload-time = "2026-02-23T00:19:12.024Z" },
{ url = "https://files.pythonhosted.org/packages/76/27/07ee1b57b65e92645f219b37148a7e7928b82e2b5dbeccecb4dff7c64f0b/scipy-1.17.1-cp313-cp313-macosx_10_14_x86_64.whl", hash = "sha256:5e3c5c011904115f88a39308379c17f91546f77c1667cea98739fe0fccea804c", size = 31590199, upload-time = "2026-02-23T00:19:17.192Z" },
{ url = "https://files.pythonhosted.org/packages/ec/ae/db19f8ab842e9b724bf5dbb7db29302a91f1e55bc4d04b1025d6d605a2c5/scipy-1.17.1-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:6fac755ca3d2c3edcb22f479fceaa241704111414831ddd3bc6056e18516892f", size = 28154001, upload-time = "2026-02-23T00:19:22.241Z" },
{ url = "https://files.pythonhosted.org/packages/5b/58/3ce96251560107b381cbd6e8413c483bbb1228a6b919fa8652b0d4090e7f/scipy-1.17.1-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:7ff200bf9d24f2e4d5dc6ee8c3ac64d739d3a89e2326ba68aaf6c4a2b838fd7d", size = 20325719, upload-time = "2026-02-23T00:19:26.329Z" },
{ url = "https://files.pythonhosted.org/packages/b2/83/15087d945e0e4d48ce2377498abf5ad171ae013232ae31d06f336e64c999/scipy-1.17.1-cp313-cp313-macosx_14_0_x86_64.whl", hash = "sha256:4b400bdc6f79fa02a4d86640310dde87a21fba0c979efff5248908c6f15fad1b", size = 22683595, upload-time = "2026-02-23T00:19:30.304Z" },
{ url = "https://files.pythonhosted.org/packages/b4/e0/e58fbde4a1a594c8be8114eb4aac1a55bcd6587047efc18a61eb1f5c0d30/scipy-1.17.1-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:2b64ca7d4aee0102a97f3ba22124052b4bd2152522355073580bf4845e2550b6", size = 32896429, upload-time = "2026-02-23T00:19:35.536Z" },
{ url = "https://files.pythonhosted.org/packages/f5/5f/f17563f28ff03c7b6799c50d01d5d856a1d55f2676f537ca8d28c7f627cd/scipy-1.17.1-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:581b2264fc0aa555f3f435a5944da7504ea3a065d7029ad60e7c3d1ae09c5464", size = 35203952, upload-time = "2026-02-23T00:19:42.259Z" },
{ url = "https://files.pythonhosted.org/packages/8d/a5/9afd17de24f657fdfe4df9a3f1ea049b39aef7c06000c13db1530d81ccca/scipy-1.17.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:beeda3d4ae615106d7094f7e7cef6218392e4465cc95d25f900bebabfded0950", size = 34979063, upload-time = "2026-02-23T00:19:47.547Z" },
{ url = "https://files.pythonhosted.org/packages/8b/13/88b1d2384b424bf7c924f2038c1c409f8d88bb2a8d49d097861dd64a57b2/scipy-1.17.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:6609bc224e9568f65064cfa72edc0f24ee6655b47575954ec6339534b2798369", size = 37598449, upload-time = "2026-02-23T00:19:53.238Z" },
{ url = "https://files.pythonhosted.org/packages/35/e5/d6d0e51fc888f692a35134336866341c08655d92614f492c6860dc45bb2c/scipy-1.17.1-cp313-cp313-win_amd64.whl", hash = "sha256:37425bc9175607b0268f493d79a292c39f9d001a357bebb6b88fdfaff13f6448", size = 36510943, upload-time = "2026-02-23T00:20:50.89Z" },
{ url = "https://files.pythonhosted.org/packages/2a/fd/3be73c564e2a01e690e19cc618811540ba5354c67c8680dce3281123fb79/scipy-1.17.1-cp313-cp313-win_arm64.whl", hash = "sha256:5cf36e801231b6a2059bf354720274b7558746f3b1a4efb43fcf557ccd484a87", size = 24545621, upload-time = "2026-02-23T00:20:55.871Z" },
{ url = "https://files.pythonhosted.org/packages/6f/6b/17787db8b8114933a66f9dcc479a8272e4b4da75fe03b0c282f7b0ade8cd/scipy-1.17.1-cp313-cp313t-macosx_10_14_x86_64.whl", hash = "sha256:d59c30000a16d8edc7e64152e30220bfbd724c9bbb08368c054e24c651314f0a", size = 31936708, upload-time = "2026-02-23T00:19:58.694Z" },
{ url = "https://files.pythonhosted.org/packages/38/2e/524405c2b6392765ab1e2b722a41d5da33dc5c7b7278184a8ad29b6cb206/scipy-1.17.1-cp313-cp313t-macosx_12_0_arm64.whl", hash = "sha256:010f4333c96c9bb1a4516269e33cb5917b08ef2166d5556ca2fd9f082a9e6ea0", size = 28570135, upload-time = "2026-02-23T00:20:03.934Z" },
{ url = "https://files.pythonhosted.org/packages/fd/c3/5bd7199f4ea8556c0c8e39f04ccb014ac37d1468e6cfa6a95c6b3562b76e/scipy-1.17.1-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:2ceb2d3e01c5f1d83c4189737a42d9cb2fc38a6eeed225e7515eef71ad301dce", size = 20741977, upload-time = "2026-02-23T00:20:07.935Z" },
{ url = "https://files.pythonhosted.org/packages/d9/b8/8ccd9b766ad14c78386599708eb745f6b44f08400a5fd0ade7cf89b6fc93/scipy-1.17.1-cp313-cp313t-macosx_14_0_x86_64.whl", hash = "sha256:844e165636711ef41f80b4103ed234181646b98a53c8f05da12ca5ca289134f6", size = 23029601, upload-time = "2026-02-23T00:20:12.161Z" },
{ url = "https://files.pythonhosted.org/packages/6d/a0/3cb6f4d2fb3e17428ad2880333cac878909ad1a89f678527b5328b93c1d4/scipy-1.17.1-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:158dd96d2207e21c966063e1635b1063cd7787b627b6f07305315dd73d9c679e", size = 33019667, upload-time = "2026-02-23T00:20:17.208Z" },
{ url = "https://files.pythonhosted.org/packages/f3/c3/2d834a5ac7bf3a0c806ad1508efc02dda3c8c61472a56132d7894c312dea/scipy-1.17.1-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:74cbb80d93260fe2ffa334efa24cb8f2f0f622a9b9febf8b483c0b865bfb3475", size = 35264159, upload-time = "2026-02-23T00:20:23.087Z" },
{ url = "https://files.pythonhosted.org/packages/4d/77/d3ed4becfdbd217c52062fafe35a72388d1bd82c2d0ba5ca19d6fcc93e11/scipy-1.17.1-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:dbc12c9f3d185f5c737d801da555fb74b3dcfa1a50b66a1a93e09190f41fab50", size = 35102771, upload-time = "2026-02-23T00:20:28.636Z" },
{ url = "https://files.pythonhosted.org/packages/bd/12/d19da97efde68ca1ee5538bb261d5d2c062f0c055575128f11a2730e3ac1/scipy-1.17.1-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:94055a11dfebe37c656e70317e1996dc197e1a15bbcc351bcdd4610e128fe1ca", size = 37665910, upload-time = "2026-02-23T00:20:34.743Z" },
{ url = "https://files.pythonhosted.org/packages/06/1c/1172a88d507a4baaf72c5a09bb6c018fe2ae0ab622e5830b703a46cc9e44/scipy-1.17.1-cp313-cp313t-win_amd64.whl", hash = "sha256:e30bdeaa5deed6bc27b4cc490823cd0347d7dae09119b8803ae576ea0ce52e4c", size = 36562980, upload-time = "2026-02-23T00:20:40.575Z" },
{ url = "https://files.pythonhosted.org/packages/70/b0/eb757336e5a76dfa7911f63252e3b7d1de00935d7705cf772db5b45ec238/scipy-1.17.1-cp313-cp313t-win_arm64.whl", hash = "sha256:a720477885a9d2411f94a93d16f9d89bad0f28ca23c3f8daa521e2dcc3f44d49", size = 24856543, upload-time = "2026-02-23T00:20:45.313Z" },
{ url = "https://files.pythonhosted.org/packages/cf/83/333afb452af6f0fd70414dc04f898647ee1423979ce02efa75c3b0f2c28e/scipy-1.17.1-cp314-cp314-macosx_10_14_x86_64.whl", hash = "sha256:a48a72c77a310327f6a3a920092fa2b8fd03d7deaa60f093038f22d98e096717", size = 31584510, upload-time = "2026-02-23T00:21:01.015Z" },
{ url = "https://files.pythonhosted.org/packages/ed/a6/d05a85fd51daeb2e4ea71d102f15b34fedca8e931af02594193ae4fd25f7/scipy-1.17.1-cp314-cp314-macosx_12_0_arm64.whl", hash = "sha256:45abad819184f07240d8a696117a7aacd39787af9e0b719d00285549ed19a1e9", size = 28170131, upload-time = "2026-02-23T00:21:05.888Z" },
{ url = "https://files.pythonhosted.org/packages/db/7b/8624a203326675d7746a254083a187398090a179335b2e4a20e2ddc46e83/scipy-1.17.1-cp314-cp314-macosx_14_0_arm64.whl", hash = "sha256:3fd1fcdab3ea951b610dc4cef356d416d5802991e7e32b5254828d342f7b7e0b", size = 20342032, upload-time = "2026-02-23T00:21:09.904Z" },
{ url = "https://files.pythonhosted.org/packages/c9/35/2c342897c00775d688d8ff3987aced3426858fd89d5a0e26e020b660b301/scipy-1.17.1-cp314-cp314-macosx_14_0_x86_64.whl", hash = "sha256:7bdf2da170b67fdf10bca777614b1c7d96ae3ca5794fd9587dce41eb2966e866", size = 22678766, upload-time = "2026-02-23T00:21:14.313Z" },
{ url = "https://files.pythonhosted.org/packages/ef/f2/7cdb8eb308a1a6ae1e19f945913c82c23c0c442a462a46480ce487fdc0ac/scipy-1.17.1-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:adb2642e060a6549c343603a3851ba76ef0b74cc8c079a9a58121c7ec9fe2350", size = 32957007, upload-time = "2026-02-23T00:21:19.663Z" },
{ url = "https://files.pythonhosted.org/packages/0b/2e/7eea398450457ecb54e18e9d10110993fa65561c4f3add5e8eccd2b9cd41/scipy-1.17.1-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:eee2cfda04c00a857206a4330f0c5e3e56535494e30ca445eb19ec624ae75118", size = 35221333, upload-time = "2026-02-23T00:21:25.278Z" },
{ url = "https://files.pythonhosted.org/packages/d9/77/5b8509d03b77f093a0d52e606d3c4f79e8b06d1d38c441dacb1e26cacf46/scipy-1.17.1-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:d2650c1fb97e184d12d8ba010493ee7b322864f7d3d00d3f9bb97d9c21de4068", size = 35042066, upload-time = "2026-02-23T00:21:31.358Z" },
{ url = "https://files.pythonhosted.org/packages/f9/df/18f80fb99df40b4070328d5ae5c596f2f00fffb50167e31439e932f29e7d/scipy-1.17.1-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:08b900519463543aa604a06bec02461558a6e1cef8fdbb8098f77a48a83c8118", size = 37612763, upload-time = "2026-02-23T00:21:37.247Z" },
{ url = "https://files.pythonhosted.org/packages/4b/39/f0e8ea762a764a9dc52aa7dabcfad51a354819de1f0d4652b6a1122424d6/scipy-1.17.1-cp314-cp314-win_amd64.whl", hash = "sha256:3877ac408e14da24a6196de0ddcace62092bfc12a83823e92e49e40747e52c19", size = 37290984, upload-time = "2026-02-23T00:22:35.023Z" },
{ url = "https://files.pythonhosted.org/packages/7c/56/fe201e3b0f93d1a8bcf75d3379affd228a63d7e2d80ab45467a74b494947/scipy-1.17.1-cp314-cp314-win_arm64.whl", hash = "sha256:f8885db0bc2bffa59d5c1b72fad7a6a92d3e80e7257f967dd81abb553a90d293", size = 25192877, upload-time = "2026-02-23T00:22:39.798Z" },
{ url = "https://files.pythonhosted.org/packages/96/ad/f8c414e121f82e02d76f310f16db9899c4fcde36710329502a6b2a3c0392/scipy-1.17.1-cp314-cp314t-macosx_10_14_x86_64.whl", hash = "sha256:1cc682cea2ae55524432f3cdff9e9a3be743d52a7443d0cba9017c23c87ae2f6", size = 31949750, upload-time = "2026-02-23T00:21:42.289Z" },
{ url = "https://files.pythonhosted.org/packages/7c/b0/c741e8865d61b67c81e255f4f0a832846c064e426636cd7de84e74d209be/scipy-1.17.1-cp314-cp314t-macosx_12_0_arm64.whl", hash = "sha256:2040ad4d1795a0ae89bfc7e8429677f365d45aa9fd5e4587cf1ea737f927b4a1", size = 28585858, upload-time = "2026-02-23T00:21:47.706Z" },
{ url = "https://files.pythonhosted.org/packages/ed/1b/3985219c6177866628fa7c2595bfd23f193ceebbe472c98a08824b9466ff/scipy-1.17.1-cp314-cp314t-macosx_14_0_arm64.whl", hash = "sha256:131f5aaea57602008f9822e2115029b55d4b5f7c070287699fe45c661d051e39", size = 20757723, upload-time = "2026-02-23T00:21:52.039Z" },
{ url = "https://files.pythonhosted.org/packages/c0/19/2a04aa25050d656d6f7b9e7b685cc83d6957fb101665bfd9369ca6534563/scipy-1.17.1-cp314-cp314t-macosx_14_0_x86_64.whl", hash = "sha256:9cdc1a2fcfd5c52cfb3045feb399f7b3ce822abdde3a193a6b9a60b3cb5854ca", size = 23043098, upload-time = "2026-02-23T00:21:56.185Z" },
{ url = "https://files.pythonhosted.org/packages/86/f1/3383beb9b5d0dbddd030335bf8a8b32d4317185efe495374f134d8be6cce/scipy-1.17.1-cp314-cp314t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6e3dcd57ab780c741fde8dc68619de988b966db759a3c3152e8e9142c26295ad", size = 33030397, upload-time = "2026-02-23T00:22:01.404Z" },
{ url = "https://files.pythonhosted.org/packages/41/68/8f21e8a65a5a03f25a79165ec9d2b28c00e66dc80546cf5eb803aeeff35b/scipy-1.17.1-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a9956e4d4f4a301ebf6cde39850333a6b6110799d470dbbb1e25326ac447f52a", size = 35281163, upload-time = "2026-02-23T00:22:07.024Z" },
{ url = "https://files.pythonhosted.org/packages/84/8d/c8a5e19479554007a5632ed7529e665c315ae7492b4f946b0deb39870e39/scipy-1.17.1-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:a4328d245944d09fd639771de275701ccadf5f781ba0ff092ad141e017eccda4", size = 35116291, upload-time = "2026-02-23T00:22:12.585Z" },
{ url = "https://files.pythonhosted.org/packages/52/52/e57eceff0e342a1f50e274264ed47497b59e6a4e3118808ee58ddda7b74a/scipy-1.17.1-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:a77cbd07b940d326d39a1d1b37817e2ee4d79cb30e7338f3d0cddffae70fcaa2", size = 37682317, upload-time = "2026-02-23T00:22:18.513Z" },
{ url = "https://files.pythonhosted.org/packages/11/2f/b29eafe4a3fbc3d6de9662b36e028d5f039e72d345e05c250e121a230dd4/scipy-1.17.1-cp314-cp314t-win_amd64.whl", hash = "sha256:eb092099205ef62cd1782b006658db09e2fed75bffcae7cc0d44052d8aa0f484", size = 37345327, upload-time = "2026-02-23T00:22:24.442Z" },
{ url = "https://files.pythonhosted.org/packages/07/39/338d9219c4e87f3e708f18857ecd24d22a0c3094752393319553096b98af/scipy-1.17.1-cp314-cp314t-win_arm64.whl", hash = "sha256:200e1050faffacc162be6a486a984a0497866ec54149a01270adc8a59b7c7d21", size = 25489165, upload-time = "2026-02-23T00:22:29.563Z" },
]
[[package]] [[package]]
name = "setuptools" name = "setuptools"
version = "80.9.0" version = "80.9.0"
@@ -1792,20 +1869,24 @@ source = { virtual = "." }
dependencies = [ dependencies = [
{ name = "datasets" }, { name = "datasets" },
{ name = "einops" }, { name = "einops" },
{ name = "lpips" },
{ name = "matplotlib" }, { name = "matplotlib" },
{ name = "numpy" }, { name = "numpy" },
{ name = "swanlab" }, { name = "swanlab" },
{ name = "torch" }, { name = "torch" },
{ name = "torchvision" },
] ]
[package.metadata] [package.metadata]
requires-dist = [ requires-dist = [
{ name = "datasets", specifier = ">=2.19.0" }, { name = "datasets", specifier = ">=2.19.0" },
{ name = "einops", specifier = ">=0.7.0" }, { name = "einops", specifier = ">=0.7.0" },
{ name = "lpips", specifier = ">=0.1.4" },
{ name = "matplotlib", specifier = ">=3.8.0" }, { name = "matplotlib", specifier = ">=3.8.0" },
{ name = "numpy", specifier = ">=1.26.0" }, { name = "numpy", specifier = ">=1.26.0" },
{ name = "swanlab", specifier = ">=0.5.0" }, { name = "swanlab", specifier = ">=0.5.0" },
{ name = "torch", specifier = ">=2.2.0" }, { name = "torch", specifier = ">=2.2.0" },
{ name = "torchvision", specifier = ">=0.24.1" },
] ]
[[package]] [[package]]
@@ -1860,6 +1941,38 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/db/2b/f7818f6ec88758dfd21da46b6cd46af9d1b3433e53ddbb19ad1e0da17f9b/torch-2.9.1-cp314-cp314t-win_amd64.whl", hash = "sha256:c88d3299ddeb2b35dcc31753305612db485ab6f1823e37fb29451c8b2732b87e", size = 111163659, upload-time = "2025-11-12T15:23:20.009Z" }, { url = "https://files.pythonhosted.org/packages/db/2b/f7818f6ec88758dfd21da46b6cd46af9d1b3433e53ddbb19ad1e0da17f9b/torch-2.9.1-cp314-cp314t-win_amd64.whl", hash = "sha256:c88d3299ddeb2b35dcc31753305612db485ab6f1823e37fb29451c8b2732b87e", size = 111163659, upload-time = "2025-11-12T15:23:20.009Z" },
] ]
[[package]]
name = "torchvision"
version = "0.24.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "numpy" },
{ name = "pillow" },
{ name = "torch" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/f0/af/18e2c6b9538a045f60718a0c5a058908ccb24f88fde8e6f0fc12d5ff7bd3/torchvision-0.24.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:e48bf6a8ec95872eb45763f06499f87bd2fb246b9b96cb00aae260fda2f96193", size = 1891433, upload-time = "2025-11-12T15:25:03.232Z" },
{ url = "https://files.pythonhosted.org/packages/9d/43/600e5cfb0643d10d633124f5982d7abc2170dfd7ce985584ff16edab3e76/torchvision-0.24.1-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:7fb7590c737ebe3e1c077ad60c0e5e2e56bb26e7bccc3b9d04dbfc34fd09f050", size = 2386737, upload-time = "2025-11-12T15:25:08.288Z" },
{ url = "https://files.pythonhosted.org/packages/93/b1/db2941526ecddd84884132e2742a55c9311296a6a38627f9e2627f5ac889/torchvision-0.24.1-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:66a98471fc18cad9064123106d810a75f57f0838eee20edc56233fd8484b0cc7", size = 8049868, upload-time = "2025-11-12T15:25:13.058Z" },
{ url = "https://files.pythonhosted.org/packages/69/98/16e583f59f86cd59949f59d52bfa8fc286f86341a229a9d15cbe7a694f0c/torchvision-0.24.1-cp312-cp312-win_amd64.whl", hash = "sha256:4aa6cb806eb8541e92c9b313e96192c6b826e9eb0042720e2fa250d021079952", size = 4302006, upload-time = "2025-11-12T15:25:16.184Z" },
{ url = "https://files.pythonhosted.org/packages/e4/97/ab40550f482577f2788304c27220e8ba02c63313bd74cf2f8920526aac20/torchvision-0.24.1-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:8a6696db7fb71eadb2c6a48602106e136c785642e598eb1533e0b27744f2cce6", size = 1891435, upload-time = "2025-11-12T15:25:28.642Z" },
{ url = "https://files.pythonhosted.org/packages/30/65/ac0a3f9be6abdbe4e1d82c915d7e20de97e7fd0e9a277970508b015309f3/torchvision-0.24.1-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:db2125c46f9cb25dc740be831ce3ce99303cfe60439249a41b04fd9f373be671", size = 2338718, upload-time = "2025-11-12T15:25:26.19Z" },
{ url = "https://files.pythonhosted.org/packages/10/b5/5bba24ff9d325181508501ed7f0c3de8ed3dd2edca0784d48b144b6c5252/torchvision-0.24.1-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:f035f0cacd1f44a8ff6cb7ca3627d84c54d685055961d73a1a9fb9827a5414c8", size = 8049661, upload-time = "2025-11-12T15:25:22.558Z" },
{ url = "https://files.pythonhosted.org/packages/5c/ec/54a96ae9ab6a0dd66d4bba27771f892e36478a9c3489fa56e51c70abcc4d/torchvision-0.24.1-cp313-cp313-win_amd64.whl", hash = "sha256:16274823b93048e0a29d83415166a2e9e0bf4e1b432668357b657612a4802864", size = 4319808, upload-time = "2025-11-12T15:25:17.318Z" },
{ url = "https://files.pythonhosted.org/packages/d5/f3/a90a389a7e547f3eb8821b13f96ea7c0563cdefbbbb60a10e08dda9720ff/torchvision-0.24.1-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:e3f96208b4bef54cd60e415545f5200346a65024e04f29a26cd0006dbf9e8e66", size = 2005342, upload-time = "2025-11-12T15:25:11.871Z" },
{ url = "https://files.pythonhosted.org/packages/a9/fe/ff27d2ed1b524078164bea1062f23d2618a5fc3208e247d6153c18c91a76/torchvision-0.24.1-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:f231f6a4f2aa6522713326d0d2563538fa72d613741ae364f9913027fa52ea35", size = 2341708, upload-time = "2025-11-12T15:25:25.08Z" },
{ url = "https://files.pythonhosted.org/packages/b1/b9/d6c903495cbdfd2533b3ef6f7b5643ff589ea062f8feb5c206ee79b9d9e5/torchvision-0.24.1-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:1540a9e7f8cf55fe17554482f5a125a7e426347b71de07327d5de6bfd8d17caa", size = 8177239, upload-time = "2025-11-12T15:25:18.554Z" },
{ url = "https://files.pythonhosted.org/packages/4f/2b/ba02e4261369c3798310483028495cf507e6cb3f394f42e4796981ecf3a7/torchvision-0.24.1-cp313-cp313t-win_amd64.whl", hash = "sha256:d83e16d70ea85d2f196d678bfb702c36be7a655b003abed84e465988b6128938", size = 4251604, upload-time = "2025-11-12T15:25:34.069Z" },
{ url = "https://files.pythonhosted.org/packages/42/84/577b2cef8f32094add5f52887867da4c2a3e6b4261538447e9b48eb25812/torchvision-0.24.1-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:cccf4b4fec7fdfcd3431b9ea75d1588c0a8596d0333245dafebee0462abe3388", size = 2005319, upload-time = "2025-11-12T15:25:23.827Z" },
{ url = "https://files.pythonhosted.org/packages/5f/34/ecb786bffe0159a3b49941a61caaae089853132f3cd1e8f555e3621f7e6f/torchvision-0.24.1-cp314-cp314-manylinux_2_28_aarch64.whl", hash = "sha256:1b495edd3a8f9911292424117544f0b4ab780452e998649425d1f4b2bed6695f", size = 2338844, upload-time = "2025-11-12T15:25:32.625Z" },
{ url = "https://files.pythonhosted.org/packages/51/99/a84623786a6969504c87f2dc3892200f586ee13503f519d282faab0bb4f0/torchvision-0.24.1-cp314-cp314-manylinux_2_28_x86_64.whl", hash = "sha256:ab211e1807dc3e53acf8f6638df9a7444c80c0ad050466e8d652b3e83776987b", size = 8175144, upload-time = "2025-11-12T15:25:31.355Z" },
{ url = "https://files.pythonhosted.org/packages/6d/ba/8fae3525b233e109317ce6a9c1de922ab2881737b029a7e88021f81e068f/torchvision-0.24.1-cp314-cp314-win_amd64.whl", hash = "sha256:18f9cb60e64b37b551cd605a3d62c15730c086362b40682d23e24b616a697d41", size = 4234459, upload-time = "2025-11-12T15:25:19.859Z" },
{ url = "https://files.pythonhosted.org/packages/50/33/481602c1c72d0485d4b3a6b48c9534b71c2957c9d83bf860eb837bf5a620/torchvision-0.24.1-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:ec9d7379c519428395e4ffda4dbb99ec56be64b0a75b95989e00f9ec7ae0b2d7", size = 2005336, upload-time = "2025-11-12T15:25:27.225Z" },
{ url = "https://files.pythonhosted.org/packages/d0/7f/372de60bf3dd8f5593bd0d03f4aecf0d1fd58f5bc6943618d9d913f5e6d5/torchvision-0.24.1-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:af9201184c2712d808bd4eb656899011afdfce1e83721c7cb08000034df353fe", size = 2341704, upload-time = "2025-11-12T15:25:29.857Z" },
{ url = "https://files.pythonhosted.org/packages/36/9b/0f3b9ff3d0225ee2324ec663de0e7fb3eb855615ca958ac1875f22f1f8e5/torchvision-0.24.1-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:9ef95d819fd6df81bc7cc97b8f21a15d2c0d3ac5dbfaab5cbc2d2ce57114b19e", size = 8177422, upload-time = "2025-11-12T15:25:37.357Z" },
{ url = "https://files.pythonhosted.org/packages/d6/ab/e2bcc7c2f13d882a58f8b30ff86f794210b075736587ea50f8c545834f8a/torchvision-0.24.1-cp314-cp314t-win_amd64.whl", hash = "sha256:480b271d6edff83ac2e8d69bbb4cf2073f93366516a50d48f140ccfceedb002e", size = 4335190, upload-time = "2025-11-12T15:25:35.745Z" },
]
[[package]] [[package]]
name = "tqdm" name = "tqdm"
version = "4.67.1" version = "4.67.1"