refactor: simplify delta-only flow training

Remove learned dt prediction and auxiliary losses.

Add repository contributor guidelines.
Author: gameloader
Date: 2026-03-10 18:23:17 +08:00
parent 913740266b
commit 01fc1e4eab
4 changed files with 94 additions and 170 deletions
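
Reviewer note: after this change the model no longer predicts its own step size; dt becomes an input, the schedule is supplied by the caller, and training reduces to a single flow-matching displacement loss. A minimal sketch of that objective, using the diff's variable names (the straight-line path x_seq = x0 + t * v_gt suggests v_gt = x1 - x0, which is an assumption here):

    import torch
    import torch.nn.functional as F

    B, T, D = 32, 16, 784                    # illustrative sizes
    x0 = torch.randn(B, D)                   # noise endpoint
    x1 = torch.randn(B, D)                   # stand-in for a data sample
    v_gt = x1 - x0                           # straight-line velocity (assumed)
    dt_seq = torch.rand(B, T) * 0.05         # caller-supplied step sizes
    delta = torch.randn(B, T, D)             # stand-in for the model output
    target = v_gt[:, None, :] * dt_seq.unsqueeze(-1)  # per-step displacement
    loss = F.mse_loss(delta, target)         # the one remaining training loss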


@@ -36,11 +36,6 @@ class TrainConfig:
     dt_max: float = 0.06
     dt_alpha: float = 8.0
     lambda_flow: float = 1.0
-    lambda_pos: float = 1.0
-    lambda_dt: float = 1.0
-    use_flow_loss: bool = True
-    use_pos_loss: bool = False
-    use_dt_loss: bool = True
     num_classes: int = 10
     image_size: int = 28
     channels: int = 1
@@ -61,7 +56,6 @@ class TrainConfig:
     val_every: int = 200
     val_samples_per_class: int = 8
     val_grid_rows: int = 4
-    val_max_steps: int = 0
     use_ddp: bool = False
@@ -113,12 +107,30 @@ class Mamba2Backbone(nn.Module):
         return x, h
 
+def sinusoidal_embedding(t: Tensor, dim: int) -> Tensor:
+    if dim <= 0:
+        raise ValueError("sinusoidal embedding dim must be > 0")
+    half = dim // 2
+    if half == 0:
+        return torch.zeros(*t.shape, dim, device=t.device, dtype=t.dtype)
+    denom = max(half - 1, 1)
+    freqs = torch.exp(
+        -math.log(10000.0)
+        * torch.arange(half, device=t.device, dtype=torch.float32)
+        / denom
+    )
+    args = t.to(torch.float32).unsqueeze(-1) * freqs
+    emb = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
+    if dim % 2 == 1:
+        pad = torch.zeros(*emb.shape[:-1], 1, device=emb.device, dtype=emb.dtype)
+        emb = torch.cat([emb, pad], dim=-1)
+    return emb.to(t.dtype)
 
 class ASMamba(nn.Module):
     def __init__(self, cfg: TrainConfig) -> None:
         super().__init__()
         self.cfg = cfg
         self.dt_min = float(cfg.dt_min)
         self.dt_max = float(cfg.dt_max)
         input_dim = cfg.channels * cfg.image_size * cfg.image_size
         if cfg.d_model == 0:
             cfg.d_model = input_dim
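
A quick shape check for the new embedding helper (illustrative values; note the zero-pad column for odd dims):

    import torch

    dt = torch.rand(8, 16)                  # (batch, seq) of step sizes
    emb = sinusoidal_embedding(dt, 784)     # one vector per scalar dt
    assert emb.shape == (8, 16, 784)
    emb_odd = sinusoidal_embedding(dt, 7)   # odd dim -> zero pad column
    assert emb_odd.shape == (8, 16, 7)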
@@ -139,27 +151,31 @@ class ASMamba(nn.Module):
         self.backbone = Mamba2Backbone(args, use_residual=cfg.use_residual)
         self.cond_emb = nn.Embedding(cfg.num_classes, cfg.d_model)
         self.delta_head = nn.Linear(cfg.d_model, input_dim)
-        self.dt_head = nn.Sequential(
-            nn.Linear(cfg.d_model, cfg.d_model),
-            nn.SiLU(),
-            nn.Linear(cfg.d_model, 1),
-        )
 
     def forward(
-        self, x: Tensor, cond: Tensor, h: Optional[list[InferenceCache]] = None
-    ) -> tuple[Tensor, Tensor, list[InferenceCache]]:
+        self,
+        x: Tensor,
+        dt: Tensor,
+        cond: Tensor,
+        h: Optional[list[InferenceCache]] = None,
+    ) -> tuple[Tensor, list[InferenceCache]]:
+        if dt.dim() == 1:
+            dt = dt.unsqueeze(1)
+        elif dt.dim() == 3 and dt.shape[-1] == 1:
+            dt = dt.squeeze(-1)
+        dt = dt.to(dtype=x.dtype)
+        dt_emb = sinusoidal_embedding(dt, x.shape[-1])
+        x = x + dt_emb
         cond_vec = self.cond_emb(cond)
         feats, h = self.backbone(x, cond_vec, h)
         delta = self.delta_head(feats)
-        dt_raw = self.dt_head(feats).squeeze(-1)
-        dt = F.softplus(dt_raw)
-        return delta, dt, h
+        return delta, h
 
     def step(
-        self, x: Tensor, cond: Tensor, h: list[InferenceCache]
-    ) -> tuple[Tensor, Tensor, list[InferenceCache]]:
-        delta, dt, h = self.forward(x.unsqueeze(1), cond, h)
-        return delta[:, 0, :], dt[:, 0], h
+        self, x: Tensor, dt: Tensor, cond: Tensor, h: list[InferenceCache]
+    ) -> tuple[Tensor, list[InferenceCache]]:
+        delta, h = self.forward(x.unsqueeze(1), dt.unsqueeze(1), cond, h)
+        return delta[:, 0, :], h
 
     def init_cache(self, batch_size: int, device: torch.device) -> list[InferenceCache]:
         return [
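
With dt moved from output to input, a single sampling step now looks like this (a sketch only; model, cond, cfg, and device are assumed to be in scope as in the surrounding code):

    batch, input_dim = 4, 784                            # illustrative sizes
    h = model.init_cache(batch_size=batch, device=device)
    x = torch.randn(batch, input_dim, device=device)     # start from noise
    dt = torch.full((batch,), 1.0 / cfg.seq_len, device=device)
    delta, h = model.step(x, dt, cond, h)                # dt is an input now
    x = x + delta                                        # explicit Euler update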
@@ -332,51 +348,15 @@ def infinite_loader(loader: DataLoader) -> Iterator[dict]:
 def compute_losses(
     delta: Tensor,
-    dt: Tensor,
-    x_seq: Tensor,
-    x0: Tensor,
     v_gt: Tensor,
-    t_seq: Tensor,
     dt_seq: Tensor,
     cfg: TrainConfig,
 ) -> dict[str, Tensor]:
     losses: dict[str, Tensor] = {}
-    if cfg.use_flow_loss:
-        target_disp = v_gt[:, None, :] * dt.unsqueeze(-1)
-        losses["flow"] = F.mse_loss(delta, target_disp)
-    if cfg.use_pos_loss:
-        t_next = t_seq + dt
-        t_next = torch.clamp(t_next, 0.0, 1.0)
-        x_target = x0[:, None, :] + t_next.unsqueeze(-1) * v_gt[:, None, :]
-        x_next_pred = x_seq + delta
-        losses["pos"] = F.mse_loss(x_next_pred, x_target)
-    if cfg.use_dt_loss:
-        losses["dt"] = F.mse_loss(dt, dt_seq)
+    target_disp = v_gt[:, None, :] * dt_seq.unsqueeze(-1)
+    losses["flow"] = F.mse_loss(delta, target_disp)
     return losses
 
-def plot_dt_hist(
-    dt_pred: Tensor, dt_gt: Tensor, save_path: Path, title: str = "dt Distribution"
-) -> None:
-    dt_pred_np = dt_pred.detach().cpu().numpy().reshape(-1)
-    dt_gt_np = dt_gt.detach().cpu().numpy().reshape(-1)
-    fig, ax = plt.subplots(figsize=(6, 4))
-    ax.hist(dt_gt_np, bins=30, alpha=0.6, label="dt_gt", color="steelblue")
-    ax.hist(dt_pred_np, bins=30, alpha=0.6, label="dt_pred", color="orange")
-    ax.set_title(title)
-    ax.set_xlabel("dt")
-    ax.set_ylabel("count")
-    ax.legend(loc="best")
-    fig.tight_layout()
-    fig.savefig(save_path, dpi=160)
-    plt.close(fig)
-
 def make_grid(images: Tensor, nrow: int) -> np.ndarray:
     images = images.detach().cpu().numpy()
     b, c, h, w = images.shape
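
One property worth noting for review: because the target displacements are v_gt * dt, they telescope along the sequence, so applying the exact targets lands on the data point whenever the dt values sum to 1 (a numerical check under the straight-line assumption):

    import torch

    B, T, D = 4, 16, 784
    x0, x1 = torch.randn(B, D), torch.randn(B, D)
    v_gt = x1 - x0                                   # assumed straight-line velocity
    dt_seq = torch.full((B, T), 1.0 / T)             # sums to 1 along T
    target = v_gt[:, None, :] * dt_seq.unsqueeze(-1)
    x_end = x0 + target.sum(dim=1)                   # telescoping sum of deltas
    assert torch.allclose(x_end, x1, atol=1e-5)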
@@ -413,30 +393,24 @@ def rollout_trajectory(
     model: ASMamba,
     x0: Tensor,
     cond: Tensor,
-    max_steps: int,
+    dt_seq: Tensor,
 ) -> Tensor:
     device = x0.device
     model.eval()
     h = model.init_cache(batch_size=x0.shape[0], device=device)
     x = x0
-    total_time = torch.zeros(x0.shape[0], device=device)
+    if dt_seq.dim() == 1:
+        dt_seq = dt_seq.unsqueeze(0).expand(x0.shape[0], -1)
+    elif dt_seq.shape[0] == 1 and x0.shape[0] > 1:
+        dt_seq = dt_seq.expand(x0.shape[0], -1)
     traj = [x0.detach().cpu()]
     with torch.no_grad():
-        for _ in range(max_steps):
-            delta, dt, h = model.step(x, cond, h)
-            dt = torch.clamp(dt, min=model.dt_min, max=model.dt_max)
-            remaining = 1.0 - total_time
-            overshoot = dt > remaining
-            if overshoot.any():
-                scale = (remaining / dt).unsqueeze(-1)
-                delta = torch.where(overshoot.unsqueeze(-1), delta * scale, delta)
-                dt = torch.where(overshoot, remaining, dt)
+        for step_idx in range(dt_seq.shape[1]):
+            dt = dt_seq[:, step_idx]
+            delta, h = model.step(x, dt, cond, h)
             x = x + delta
-            total_time = total_time + dt
             traj.append(x.detach().cpu())
-            if torch.all(total_time >= 1.0 - 1e-6):
-                break
     return torch.stack(traj, dim=1)
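
Since the schedule is now caller-supplied, rollouts are no longer tied to uniform steps. For example, a cosine-spaced schedule that still sums to 1.0 (a sketch, not exercised anywhere in this diff; model, x0, and cond as above):

    import math
    import torch

    steps, batch = 16, 4
    t = torch.linspace(0.0, math.pi, steps + 1)
    knots = (1.0 - torch.cos(t)) / 2.0           # monotone ramp from 0 to 1
    dt_seq = (knots[1:] - knots[:-1]).expand(batch, -1)
    traj = rollout_trajectory(model, x0, cond, dt_seq=dt_seq.to(x0.device))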
@@ -451,15 +425,20 @@ def log_class_samples(
     if cfg.val_samples_per_class <= 0:
         return
     model.eval()
-    max_steps = cfg.seq_len if cfg.val_max_steps <= 0 else cfg.val_max_steps
+    max_steps = cfg.seq_len
     input_dim = cfg.channels * cfg.image_size * cfg.image_size
+    dt_seq = torch.full(
+        (cfg.val_samples_per_class, max_steps),
+        1.0 / max_steps,
+        device=device,
+    )
     for cls in range(cfg.num_classes):
         cond = torch.full(
             (cfg.val_samples_per_class,), cls, device=device, dtype=torch.long
         )
         x0 = torch.randn(cfg.val_samples_per_class, input_dim, device=device)
-        traj = rollout_trajectory(model, x0, cond, max_steps=max_steps)
+        traj = rollout_trajectory(model, x0, cond, dt_seq=dt_seq)
         x_final = traj[:, -1, :].view(
             cfg.val_samples_per_class, cfg.channels, cfg.image_size, cfg.image_size
         )
@@ -511,68 +490,25 @@ def train(cfg: TrainConfig) -> ASMamba:
         t_seq = torch.cat([torch.zeros(b, 1, device=device), t_seq[:, :-1]], dim=-1)
         x_seq = x0[:, None, :] + t_seq[:, :, None] * v_gt[:, None, :]
-        delta, dt, _ = model(x_seq, cond)
+        delta, _ = model(x_seq, dt_seq, cond)
         losses = compute_losses(
             delta=delta,
-            dt=dt,
-            x_seq=x_seq,
-            x0=x0,
             v_gt=v_gt,
-            t_seq=t_seq,
             dt_seq=dt_seq,
             cfg=cfg,
         )
-        loss = torch.tensor(0.0, device=device)
-        if cfg.use_flow_loss and "flow" in losses:
-            loss = loss + cfg.lambda_flow * losses["flow"]
-        if cfg.use_pos_loss and "pos" in losses:
-            loss = loss + cfg.lambda_pos * losses["pos"]
-        if cfg.use_dt_loss and "dt" in losses:
-            loss = loss + cfg.lambda_dt * losses["dt"]
-        if loss.item() == 0.0:
-            raise RuntimeError(
-                "No loss enabled: enable at least one of flow/pos/dt."
-            )
+        loss = cfg.lambda_flow * losses["flow"]
         optimizer.zero_grad(set_to_none=True)
         loss.backward()
         optimizer.step()
 
         if global_step % 10 == 0:
-            dt_min = float(dt.min().item())
-            dt_max = float(dt.max().item())
-            dt_mean = float(dt.mean().item())
-            dt_gt_min = float(dt_seq.min().item())
-            dt_gt_max = float(dt_seq.max().item())
-            dt_gt_mean = float(dt_seq.mean().item())
-            eps = 1e-6
-            clamp_min_ratio = float((dt <= cfg.dt_min + eps).float().mean().item())
-            clamp_max_ratio = float((dt >= cfg.dt_max - eps).float().mean().item())
-            clamp_any_ratio = float(
-                ((dt <= cfg.dt_min + eps) | (dt >= cfg.dt_max - eps))
-                .float()
-                .mean()
-                .item()
-            )
             logger.log(
                 {
                     "loss/total": float(loss.item()),
-                    "loss/flow": float(
-                        losses.get("flow", torch.tensor(0.0)).item()
-                    ),
-                    "loss/pos": float(losses.get("pos", torch.tensor(0.0)).item()),
-                    "loss/dt": float(losses.get("dt", torch.tensor(0.0)).item()),
-                    "dt/pred_mean": dt_mean,
-                    "dt/pred_min": dt_min,
-                    "dt/pred_max": dt_max,
-                    "dt/gt_mean": dt_gt_mean,
-                    "dt/gt_min": dt_gt_min,
-                    "dt/gt_max": dt_gt_max,
-                    "dt/clamp_min_ratio": clamp_min_ratio,
-                    "dt/clamp_max_ratio": clamp_max_ratio,
-                    "dt/clamp_any_ratio": clamp_any_ratio,
+                    "loss/flow": float(losses["flow"].item()),
                 },
                 step=global_step,
             )
@@ -584,21 +520,6 @@ def train(cfg: TrainConfig) -> ASMamba:
             and rank == 0
         ):
             log_class_samples(unwrap_model(model), cfg, device, logger, global_step)
-            dt_hist_path = (
-                Path(cfg.output_dir) / f"dt_hist_step_{global_step:06d}.png"
-            )
-            plot_dt_hist(
-                dt,
-                dt_seq,
-                dt_hist_path,
-                title=f"dt Distribution (step {global_step})",
-            )
-            logger.log_image(
-                "train/dt_hist",
-                dt_hist_path,
-                caption=f"step {global_step}",
-                step=global_step,
-            )
 
         global_step += 1