refactor: simplify delta-only flow training
Remove learned dt prediction and auxiliary losses. Add repository contributor guidelines.
This commit is contained in:
27
AGENTS.md
Normal file
27
AGENTS.md
Normal file
@@ -0,0 +1,27 @@
|
||||
# Repository Guidelines
|
||||
|
||||
## Project Structure & Module Organization
|
||||
This repository is a flat Python training project, not a packaged library. `main.py` is the CLI entry point. `as_mamba.py` holds `TrainConfig`, dataset loading, the training loop, validation image generation, and SwanLab logging. `mamba2_minimal.py` contains the standalone Mamba-2 building blocks. `train_as_mamba.sh` is the multi-GPU launcher. Runtime artifacts land in `outputs/` and `swanlog/`; treat both as generated data, not source.
|
||||
|
||||
## Build, Test, and Development Commands
|
||||
Use `uv` with the committed lockfile.
|
||||
|
||||
- `uv sync` installs the Python 3.12 environment from `pyproject.toml` and `uv.lock`.
|
||||
- `uv run python main.py --help` lists all training flags and is the fastest CLI sanity check.
|
||||
- `uv run python main.py --device cpu --epochs 1 --steps-per-epoch 1 --batch-size 8 --num-workers 0 --val-every 0` runs a minimal local smoke test.
|
||||
- `bash train_as_mamba.sh` launches the default distributed training job with `torchrun`; adjust `--nproc_per_node` and shell variables before using it on a new machine.
|
||||
- `uv run python -m compileall main.py as_mamba.py mamba2_minimal.py` performs a lightweight syntax check when no test suite is available.
|
||||
|
||||
## Coding Style & Naming Conventions
|
||||
Follow the existing Python style: 4-space indentation, type hints on public functions, `dataclass`-based config, and small helper functions over deeply nested logic. Use `snake_case` for functions, variables, and CLI flags; use `PascalCase` for classes such as `ASMamba` and `TrainConfig`. Keep new modules top-level unless the project is restructured. There is no configured formatter or linter yet, so match surrounding code closely and keep imports grouped and readable.
|
||||
|
||||
## Testing Guidelines
|
||||
There is no dedicated `tests/` directory or pytest setup yet. For changes, require at minimum:
|
||||
|
||||
- `uv run python -m compileall main.py as_mamba.py mamba2_minimal.py`
|
||||
- one short training smoke run with reduced epochs and workers
|
||||
|
||||
If you add reusable logic, prefer extracting pure functions so a future pytest suite can cover them easily. Name any new test files `test_<module>.py`.
|
||||
|
||||
## Commit & Pull Request Guidelines
|
||||
Recent history favors short imperative commit subjects, usually Conventional Commit style: `feat: ...`, `fix: ...`, and scoped variants like `feat(mamba): ...`. Keep commits focused on one training or infrastructure change. Pull requests should describe the config or behavior change, list the verification commands you ran, and include sample images or metric notes when outputs in `outputs/` materially change. Do not commit `.venv/`, `__pycache__/`, or large generated artifacts.
|
||||
193
as_mamba.py
193
as_mamba.py
@@ -36,11 +36,6 @@ class TrainConfig:
|
||||
dt_max: float = 0.06
|
||||
dt_alpha: float = 8.0
|
||||
lambda_flow: float = 1.0
|
||||
lambda_pos: float = 1.0
|
||||
lambda_dt: float = 1.0
|
||||
use_flow_loss: bool = True
|
||||
use_pos_loss: bool = False
|
||||
use_dt_loss: bool = True
|
||||
num_classes: int = 10
|
||||
image_size: int = 28
|
||||
channels: int = 1
|
||||
@@ -61,7 +56,6 @@ class TrainConfig:
|
||||
val_every: int = 200
|
||||
val_samples_per_class: int = 8
|
||||
val_grid_rows: int = 4
|
||||
val_max_steps: int = 0
|
||||
use_ddp: bool = False
|
||||
|
||||
|
||||
@@ -113,12 +107,30 @@ class Mamba2Backbone(nn.Module):
|
||||
return x, h
|
||||
|
||||
|
||||
def sinusoidal_embedding(t: Tensor, dim: int) -> Tensor:
|
||||
if dim <= 0:
|
||||
raise ValueError("sinusoidal embedding dim must be > 0")
|
||||
half = dim // 2
|
||||
if half == 0:
|
||||
return torch.zeros(*t.shape, dim, device=t.device, dtype=t.dtype)
|
||||
denom = max(half - 1, 1)
|
||||
freqs = torch.exp(
|
||||
-math.log(10000.0)
|
||||
* torch.arange(half, device=t.device, dtype=torch.float32)
|
||||
/ denom
|
||||
)
|
||||
args = t.to(torch.float32).unsqueeze(-1) * freqs
|
||||
emb = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
|
||||
if dim % 2 == 1:
|
||||
pad = torch.zeros(*emb.shape[:-1], 1, device=emb.device, dtype=emb.dtype)
|
||||
emb = torch.cat([emb, pad], dim=-1)
|
||||
return emb.to(t.dtype)
|
||||
|
||||
|
||||
class ASMamba(nn.Module):
|
||||
def __init__(self, cfg: TrainConfig) -> None:
|
||||
super().__init__()
|
||||
self.cfg = cfg
|
||||
self.dt_min = float(cfg.dt_min)
|
||||
self.dt_max = float(cfg.dt_max)
|
||||
input_dim = cfg.channels * cfg.image_size * cfg.image_size
|
||||
if cfg.d_model == 0:
|
||||
cfg.d_model = input_dim
|
||||
@@ -139,27 +151,31 @@ class ASMamba(nn.Module):
|
||||
self.backbone = Mamba2Backbone(args, use_residual=cfg.use_residual)
|
||||
self.cond_emb = nn.Embedding(cfg.num_classes, cfg.d_model)
|
||||
self.delta_head = nn.Linear(cfg.d_model, input_dim)
|
||||
self.dt_head = nn.Sequential(
|
||||
nn.Linear(cfg.d_model, cfg.d_model),
|
||||
nn.SiLU(),
|
||||
nn.Linear(cfg.d_model, 1),
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, x: Tensor, cond: Tensor, h: Optional[list[InferenceCache]] = None
|
||||
) -> tuple[Tensor, Tensor, list[InferenceCache]]:
|
||||
self,
|
||||
x: Tensor,
|
||||
dt: Tensor,
|
||||
cond: Tensor,
|
||||
h: Optional[list[InferenceCache]] = None,
|
||||
) -> tuple[Tensor, list[InferenceCache]]:
|
||||
if dt.dim() == 1:
|
||||
dt = dt.unsqueeze(1)
|
||||
elif dt.dim() == 3 and dt.shape[-1] == 1:
|
||||
dt = dt.squeeze(-1)
|
||||
dt = dt.to(dtype=x.dtype)
|
||||
dt_emb = sinusoidal_embedding(dt, x.shape[-1])
|
||||
x = x + dt_emb
|
||||
cond_vec = self.cond_emb(cond)
|
||||
feats, h = self.backbone(x, cond_vec, h)
|
||||
delta = self.delta_head(feats)
|
||||
dt_raw = self.dt_head(feats).squeeze(-1)
|
||||
dt = F.softplus(dt_raw)
|
||||
return delta, dt, h
|
||||
return delta, h
|
||||
|
||||
def step(
|
||||
self, x: Tensor, cond: Tensor, h: list[InferenceCache]
|
||||
) -> tuple[Tensor, Tensor, list[InferenceCache]]:
|
||||
delta, dt, h = self.forward(x.unsqueeze(1), cond, h)
|
||||
return delta[:, 0, :], dt[:, 0], h
|
||||
self, x: Tensor, dt: Tensor, cond: Tensor, h: list[InferenceCache]
|
||||
) -> tuple[Tensor, list[InferenceCache]]:
|
||||
delta, h = self.forward(x.unsqueeze(1), dt.unsqueeze(1), cond, h)
|
||||
return delta[:, 0, :], h
|
||||
|
||||
def init_cache(self, batch_size: int, device: torch.device) -> list[InferenceCache]:
|
||||
return [
|
||||
@@ -332,51 +348,15 @@ def infinite_loader(loader: DataLoader) -> Iterator[dict]:
|
||||
|
||||
def compute_losses(
|
||||
delta: Tensor,
|
||||
dt: Tensor,
|
||||
x_seq: Tensor,
|
||||
x0: Tensor,
|
||||
v_gt: Tensor,
|
||||
t_seq: Tensor,
|
||||
dt_seq: Tensor,
|
||||
cfg: TrainConfig,
|
||||
) -> dict[str, Tensor]:
|
||||
losses: dict[str, Tensor] = {}
|
||||
|
||||
if cfg.use_flow_loss:
|
||||
target_disp = v_gt[:, None, :] * dt.unsqueeze(-1)
|
||||
target_disp = v_gt[:, None, :] * dt_seq.unsqueeze(-1)
|
||||
losses["flow"] = F.mse_loss(delta, target_disp)
|
||||
|
||||
if cfg.use_pos_loss:
|
||||
t_next = t_seq + dt
|
||||
t_next = torch.clamp(t_next, 0.0, 1.0)
|
||||
x_target = x0[:, None, :] + t_next.unsqueeze(-1) * v_gt[:, None, :]
|
||||
x_next_pred = x_seq + delta
|
||||
losses["pos"] = F.mse_loss(x_next_pred, x_target)
|
||||
|
||||
if cfg.use_dt_loss:
|
||||
losses["dt"] = F.mse_loss(dt, dt_seq)
|
||||
|
||||
return losses
|
||||
|
||||
|
||||
def plot_dt_hist(
|
||||
dt_pred: Tensor, dt_gt: Tensor, save_path: Path, title: str = "dt Distribution"
|
||||
) -> None:
|
||||
dt_pred_np = dt_pred.detach().cpu().numpy().reshape(-1)
|
||||
dt_gt_np = dt_gt.detach().cpu().numpy().reshape(-1)
|
||||
|
||||
fig, ax = plt.subplots(figsize=(6, 4))
|
||||
ax.hist(dt_gt_np, bins=30, alpha=0.6, label="dt_gt", color="steelblue")
|
||||
ax.hist(dt_pred_np, bins=30, alpha=0.6, label="dt_pred", color="orange")
|
||||
ax.set_title(title)
|
||||
ax.set_xlabel("dt")
|
||||
ax.set_ylabel("count")
|
||||
ax.legend(loc="best")
|
||||
fig.tight_layout()
|
||||
fig.savefig(save_path, dpi=160)
|
||||
plt.close(fig)
|
||||
|
||||
|
||||
def make_grid(images: Tensor, nrow: int) -> np.ndarray:
|
||||
images = images.detach().cpu().numpy()
|
||||
b, c, h, w = images.shape
|
||||
@@ -413,30 +393,24 @@ def rollout_trajectory(
|
||||
model: ASMamba,
|
||||
x0: Tensor,
|
||||
cond: Tensor,
|
||||
max_steps: int,
|
||||
dt_seq: Tensor,
|
||||
) -> Tensor:
|
||||
device = x0.device
|
||||
model.eval()
|
||||
h = model.init_cache(batch_size=x0.shape[0], device=device)
|
||||
x = x0
|
||||
total_time = torch.zeros(x0.shape[0], device=device)
|
||||
if dt_seq.dim() == 1:
|
||||
dt_seq = dt_seq.unsqueeze(0).expand(x0.shape[0], -1)
|
||||
elif dt_seq.shape[0] == 1 and x0.shape[0] > 1:
|
||||
dt_seq = dt_seq.expand(x0.shape[0], -1)
|
||||
traj = [x0.detach().cpu()]
|
||||
|
||||
with torch.no_grad():
|
||||
for _ in range(max_steps):
|
||||
delta, dt, h = model.step(x, cond, h)
|
||||
dt = torch.clamp(dt, min=model.dt_min, max=model.dt_max)
|
||||
remaining = 1.0 - total_time
|
||||
overshoot = dt > remaining
|
||||
if overshoot.any():
|
||||
scale = (remaining / dt).unsqueeze(-1)
|
||||
delta = torch.where(overshoot.unsqueeze(-1), delta * scale, delta)
|
||||
dt = torch.where(overshoot, remaining, dt)
|
||||
for step_idx in range(dt_seq.shape[1]):
|
||||
dt = dt_seq[:, step_idx]
|
||||
delta, h = model.step(x, dt, cond, h)
|
||||
x = x + delta
|
||||
total_time = total_time + dt
|
||||
traj.append(x.detach().cpu())
|
||||
if torch.all(total_time >= 1.0 - 1e-6):
|
||||
break
|
||||
|
||||
return torch.stack(traj, dim=1)
|
||||
|
||||
@@ -451,15 +425,20 @@ def log_class_samples(
|
||||
if cfg.val_samples_per_class <= 0:
|
||||
return
|
||||
model.eval()
|
||||
max_steps = cfg.seq_len if cfg.val_max_steps <= 0 else cfg.val_max_steps
|
||||
max_steps = cfg.seq_len
|
||||
input_dim = cfg.channels * cfg.image_size * cfg.image_size
|
||||
dt_seq = torch.full(
|
||||
(cfg.val_samples_per_class, max_steps),
|
||||
1.0 / max_steps,
|
||||
device=device,
|
||||
)
|
||||
|
||||
for cls in range(cfg.num_classes):
|
||||
cond = torch.full(
|
||||
(cfg.val_samples_per_class,), cls, device=device, dtype=torch.long
|
||||
)
|
||||
x0 = torch.randn(cfg.val_samples_per_class, input_dim, device=device)
|
||||
traj = rollout_trajectory(model, x0, cond, max_steps=max_steps)
|
||||
traj = rollout_trajectory(model, x0, cond, dt_seq=dt_seq)
|
||||
x_final = traj[:, -1, :].view(
|
||||
cfg.val_samples_per_class, cfg.channels, cfg.image_size, cfg.image_size
|
||||
)
|
||||
@@ -511,68 +490,25 @@ def train(cfg: TrainConfig) -> ASMamba:
|
||||
t_seq = torch.cat([torch.zeros(b, 1, device=device), t_seq[:, :-1]], dim=-1)
|
||||
x_seq = x0[:, None, :] + t_seq[:, :, None] * v_gt[:, None, :]
|
||||
|
||||
delta, dt, _ = model(x_seq, cond)
|
||||
delta, _ = model(x_seq, dt_seq, cond)
|
||||
|
||||
losses = compute_losses(
|
||||
delta=delta,
|
||||
dt=dt,
|
||||
x_seq=x_seq,
|
||||
x0=x0,
|
||||
v_gt=v_gt,
|
||||
t_seq=t_seq,
|
||||
dt_seq=dt_seq,
|
||||
cfg=cfg,
|
||||
)
|
||||
|
||||
loss = torch.tensor(0.0, device=device)
|
||||
if cfg.use_flow_loss and "flow" in losses:
|
||||
loss = loss + cfg.lambda_flow * losses["flow"]
|
||||
if cfg.use_pos_loss and "pos" in losses:
|
||||
loss = loss + cfg.lambda_pos * losses["pos"]
|
||||
if cfg.use_dt_loss and "dt" in losses:
|
||||
loss = loss + cfg.lambda_dt * losses["dt"]
|
||||
if loss.item() == 0.0:
|
||||
raise RuntimeError(
|
||||
"No loss enabled: enable at least one of flow/pos/dt."
|
||||
)
|
||||
loss = cfg.lambda_flow * losses["flow"]
|
||||
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
if global_step % 10 == 0:
|
||||
dt_min = float(dt.min().item())
|
||||
dt_max = float(dt.max().item())
|
||||
dt_mean = float(dt.mean().item())
|
||||
dt_gt_min = float(dt_seq.min().item())
|
||||
dt_gt_max = float(dt_seq.max().item())
|
||||
dt_gt_mean = float(dt_seq.mean().item())
|
||||
eps = 1e-6
|
||||
clamp_min_ratio = float((dt <= cfg.dt_min + eps).float().mean().item())
|
||||
clamp_max_ratio = float((dt >= cfg.dt_max - eps).float().mean().item())
|
||||
clamp_any_ratio = float(
|
||||
((dt <= cfg.dt_min + eps) | (dt >= cfg.dt_max - eps))
|
||||
.float()
|
||||
.mean()
|
||||
.item()
|
||||
)
|
||||
logger.log(
|
||||
{
|
||||
"loss/total": float(loss.item()),
|
||||
"loss/flow": float(
|
||||
losses.get("flow", torch.tensor(0.0)).item()
|
||||
),
|
||||
"loss/pos": float(losses.get("pos", torch.tensor(0.0)).item()),
|
||||
"loss/dt": float(losses.get("dt", torch.tensor(0.0)).item()),
|
||||
"dt/pred_mean": dt_mean,
|
||||
"dt/pred_min": dt_min,
|
||||
"dt/pred_max": dt_max,
|
||||
"dt/gt_mean": dt_gt_mean,
|
||||
"dt/gt_min": dt_gt_min,
|
||||
"dt/gt_max": dt_gt_max,
|
||||
"dt/clamp_min_ratio": clamp_min_ratio,
|
||||
"dt/clamp_max_ratio": clamp_max_ratio,
|
||||
"dt/clamp_any_ratio": clamp_any_ratio,
|
||||
"loss/flow": float(losses["flow"].item()),
|
||||
},
|
||||
step=global_step,
|
||||
)
|
||||
@@ -584,21 +520,6 @@ def train(cfg: TrainConfig) -> ASMamba:
|
||||
and rank == 0
|
||||
):
|
||||
log_class_samples(unwrap_model(model), cfg, device, logger, global_step)
|
||||
dt_hist_path = (
|
||||
Path(cfg.output_dir) / f"dt_hist_step_{global_step:06d}.png"
|
||||
)
|
||||
plot_dt_hist(
|
||||
dt,
|
||||
dt_seq,
|
||||
dt_hist_path,
|
||||
title=f"dt Distribution (step {global_step})",
|
||||
)
|
||||
logger.log_image(
|
||||
"train/dt_hist",
|
||||
dt_hist_path,
|
||||
caption=f"step {global_step}",
|
||||
step=global_step,
|
||||
)
|
||||
|
||||
global_step += 1
|
||||
|
||||
|
||||
6
main.py
6
main.py
@@ -19,11 +19,6 @@ def build_parser() -> argparse.ArgumentParser:
|
||||
parser.add_argument("--dt-min", type=float, default=None)
|
||||
parser.add_argument("--dt-max", type=float, default=None)
|
||||
parser.add_argument("--lambda-flow", type=float, default=None)
|
||||
parser.add_argument("--lambda-pos", type=float, default=None)
|
||||
parser.add_argument("--lambda-dt", type=float, default=None)
|
||||
parser.add_argument("--use-flow-loss", action=argparse.BooleanOptionalAction, default=None)
|
||||
parser.add_argument("--use-pos-loss", action=argparse.BooleanOptionalAction, default=None)
|
||||
parser.add_argument("--use-dt-loss", action=argparse.BooleanOptionalAction, default=None)
|
||||
parser.add_argument("--num-classes", type=int, default=None)
|
||||
parser.add_argument("--image-size", type=int, default=None)
|
||||
parser.add_argument("--channels", type=int, default=None)
|
||||
@@ -41,7 +36,6 @@ def build_parser() -> argparse.ArgumentParser:
|
||||
parser.add_argument("--val-every", type=int, default=None)
|
||||
parser.add_argument("--val-samples-per-class", type=int, default=None)
|
||||
parser.add_argument("--val-grid-rows", type=int, default=None)
|
||||
parser.add_argument("--val-max-steps", type=int, default=None)
|
||||
parser.add_argument("--use-ddp", action=argparse.BooleanOptionalAction, default=None)
|
||||
return parser
|
||||
|
||||
|
||||
@@ -4,48 +4,36 @@ set -euo pipefail
|
||||
DEVICE="cuda"
|
||||
EPOCHS=2000
|
||||
STEPS_PER_EPOCH=200
|
||||
BATCH_SIZE=256
|
||||
SEQ_LEN=100
|
||||
LR=2e-3
|
||||
BATCH_SIZE=512
|
||||
SEQ_LEN=1
|
||||
LR=1e-3
|
||||
WEIGHT_DECAY=1e-2
|
||||
DT_MIN=5e-4
|
||||
DT_MAX=0.06
|
||||
DT_MAX=1.1
|
||||
DT_ALPHA=9.0
|
||||
LAMBDA_FLOW=1.0
|
||||
LAMBDA_POS=1.0
|
||||
LAMBDA_DT=1.0
|
||||
USE_FLOW_LOSS=true
|
||||
USE_POS_LOSS=false
|
||||
USE_DT_LOSS=true
|
||||
NUM_CLASSES=10
|
||||
IMAGE_SIZE=28
|
||||
CHANNELS=1
|
||||
NUM_WORKERS=16
|
||||
NUM_WORKERS=32
|
||||
DATASET_NAME="ylecun/mnist"
|
||||
DATASET_SPLIT="train"
|
||||
D_MODEL=784
|
||||
N_LAYER=6
|
||||
N_LAYER=8
|
||||
D_STATE=32
|
||||
D_CONV=4
|
||||
EXPAND=2
|
||||
HEADDIM=32
|
||||
CHUNK_SIZE=20
|
||||
USE_RESIDUAL=false
|
||||
CHUNK_SIZE=1
|
||||
USE_RESIDUAL=true
|
||||
USE_DDP=true
|
||||
VAL_EVERY=1000
|
||||
VAL_SAMPLES_PER_CLASS=8
|
||||
VAL_GRID_ROWS=4
|
||||
VAL_MAX_STEPS=0
|
||||
PROJECT="as-mamba-mnist"
|
||||
RUN_NAME="mnist-flow"
|
||||
RUN_NAME="mnist-flow-res-5seq"
|
||||
OUTPUT_DIR="outputs"
|
||||
|
||||
USE_FLOW_FLAG="--use-flow-loss"
|
||||
if [ "${USE_FLOW_LOSS}" = "false" ]; then USE_FLOW_FLAG="--no-use-flow-loss"; fi
|
||||
USE_POS_FLAG="--use-pos-loss"
|
||||
if [ "${USE_POS_LOSS}" = "false" ]; then USE_POS_FLAG="--no-use-pos-loss"; fi
|
||||
USE_DT_FLAG="--use-dt-loss"
|
||||
if [ "${USE_DT_LOSS}" = "false" ]; then USE_DT_FLAG="--no-use-dt-loss"; fi
|
||||
USE_RESIDUAL_FLAG="--use-residual"
|
||||
if [ "${USE_RESIDUAL}" = "false" ]; then USE_RESIDUAL_FLAG="--no-use-residual"; fi
|
||||
USE_DDP_FLAG="--use-ddp"
|
||||
@@ -63,11 +51,6 @@ uv run torchrun --nproc_per_node=2 main.py \
|
||||
--dt-max "${DT_MAX}" \
|
||||
--dt-alpha "${DT_ALPHA}" \
|
||||
--lambda-flow "${LAMBDA_FLOW}" \
|
||||
--lambda-pos "${LAMBDA_POS}" \
|
||||
--lambda-dt "${LAMBDA_DT}" \
|
||||
${USE_FLOW_FLAG} \
|
||||
${USE_POS_FLAG} \
|
||||
${USE_DT_FLAG} \
|
||||
--num-classes "${NUM_CLASSES}" \
|
||||
--image-size "${IMAGE_SIZE}" \
|
||||
--channels "${CHANNELS}" \
|
||||
@@ -86,7 +69,6 @@ uv run torchrun --nproc_per_node=2 main.py \
|
||||
--val-every "${VAL_EVERY}" \
|
||||
--val-samples-per-class "${VAL_SAMPLES_PER_CLASS}" \
|
||||
--val-grid-rows "${VAL_GRID_ROWS}" \
|
||||
--val-max-steps "${VAL_MAX_STEPS}" \
|
||||
--project "${PROJECT}" \
|
||||
--run-name "${RUN_NAME}" \
|
||||
--output-dir "${OUTPUT_DIR}"
|
||||
|
||||
Reference in New Issue
Block a user