refactor: simplify delta-only flow training
Remove learned dt prediction and auxiliary losses. Add repository contributor guidelines.
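For reference, a minimal sketch of the training and sampling flow this commit moves to, with the sequence model replaced by a hypothetical stand-in MLP (TinyDeltaModel and every hyperparameter below are illustrative, not code from this repository): dt is supplied by the caller rather than predicted, the only loss is the flow MSE against v_gt * dt, and sampling walks a fixed dt schedule.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Illustrative stand-in for the Mamba2-based model: takes the current state x_t,
# a caller-supplied step size dt, and a class label, and predicts a displacement.
class TinyDeltaModel(nn.Module):
    def __init__(self, input_dim: int, num_classes: int, hidden: int = 256) -> None:
        super().__init__()
        self.cond_emb = nn.Embedding(num_classes, hidden)
        self.proj = nn.Linear(hidden, input_dim)
        self.net = nn.Sequential(
            nn.Linear(input_dim + 1, hidden),
            nn.SiLU(),
            nn.Linear(hidden, input_dim),
        )

    def forward(self, x: torch.Tensor, dt: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        # dt is an input now; there is no learned dt head.
        feats = self.net(torch.cat([x, dt.unsqueeze(-1)], dim=-1))
        return feats + self.proj(self.cond_emb(cond))  # predicted displacement delta

input_dim, num_classes, batch, steps = 784, 10, 32, 16
model = TinyDeltaModel(input_dim, num_classes)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)

# One training step: the single remaining loss is the flow MSE against v_gt * dt.
x1 = torch.rand(batch, input_dim)              # data sample (flattened image)
x0 = torch.randn(batch, input_dim)             # noise
v_gt = x1 - x0                                 # straight-line velocity target
t = torch.rand(batch)                          # current time in [0, 1)
dt = torch.full((batch,), 1.0 / steps)         # step size chosen by the caller
cond = torch.randint(0, num_classes, (batch,))
x_t = x0 + t.unsqueeze(-1) * v_gt

delta = model(x_t, dt, cond)
loss = F.mse_loss(delta, v_gt * dt.unsqueeze(-1))
opt.zero_grad(set_to_none=True)
loss.backward()
opt.step()

# Sampling: walk a fixed, uniform dt schedule instead of a predicted one.
x = torch.randn(batch, input_dim)
with torch.no_grad():
    for _ in range(steps):
        x = x + model(x, torch.full((batch,), 1.0 / steps), cond)
```

The real as_mamba.py additionally embeds dt with sinusoidal_embedding and threads an inference cache through the Mamba2 backbone; the sketch omits both.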
as_mamba.py
@@ -36,11 +36,6 @@ class TrainConfig:
     dt_max: float = 0.06
-    dt_alpha: float = 8.0
     lambda_flow: float = 1.0
-    lambda_pos: float = 1.0
-    lambda_dt: float = 1.0
     use_flow_loss: bool = True
-    use_pos_loss: bool = False
-    use_dt_loss: bool = True
     num_classes: int = 10
     image_size: int = 28
     channels: int = 1
@@ -61,7 +56,6 @@ class TrainConfig:
     val_every: int = 200
     val_samples_per_class: int = 8
     val_grid_rows: int = 4
-    val_max_steps: int = 0
     use_ddp: bool = False
 
 
@@ -113,12 +107,30 @@ class Mamba2Backbone(nn.Module):
         return x, h
 
 
+def sinusoidal_embedding(t: Tensor, dim: int) -> Tensor:
+    if dim <= 0:
+        raise ValueError("sinusoidal embedding dim must be > 0")
+    half = dim // 2
+    if half == 0:
+        return torch.zeros(*t.shape, dim, device=t.device, dtype=t.dtype)
+    denom = max(half - 1, 1)
+    freqs = torch.exp(
+        -math.log(10000.0)
+        * torch.arange(half, device=t.device, dtype=torch.float32)
+        / denom
+    )
+    args = t.to(torch.float32).unsqueeze(-1) * freqs
+    emb = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
+    if dim % 2 == 1:
+        pad = torch.zeros(*emb.shape[:-1], 1, device=emb.device, dtype=emb.dtype)
+        emb = torch.cat([emb, pad], dim=-1)
+    return emb.to(t.dtype)
+
+
 class ASMamba(nn.Module):
     def __init__(self, cfg: TrainConfig) -> None:
         super().__init__()
         self.cfg = cfg
         self.dt_min = float(cfg.dt_min)
         self.dt_max = float(cfg.dt_max)
         input_dim = cfg.channels * cfg.image_size * cfg.image_size
         if cfg.d_model == 0:
             cfg.d_model = input_dim
@@ -139,27 +151,31 @@ class ASMamba(nn.Module):
         self.backbone = Mamba2Backbone(args, use_residual=cfg.use_residual)
         self.cond_emb = nn.Embedding(cfg.num_classes, cfg.d_model)
         self.delta_head = nn.Linear(cfg.d_model, input_dim)
-        self.dt_head = nn.Sequential(
-            nn.Linear(cfg.d_model, cfg.d_model),
-            nn.SiLU(),
-            nn.Linear(cfg.d_model, 1),
-        )
 
     def forward(
-        self, x: Tensor, cond: Tensor, h: Optional[list[InferenceCache]] = None
-    ) -> tuple[Tensor, Tensor, list[InferenceCache]]:
+        self,
+        x: Tensor,
+        dt: Tensor,
+        cond: Tensor,
+        h: Optional[list[InferenceCache]] = None,
+    ) -> tuple[Tensor, list[InferenceCache]]:
+        if dt.dim() == 1:
+            dt = dt.unsqueeze(1)
+        elif dt.dim() == 3 and dt.shape[-1] == 1:
+            dt = dt.squeeze(-1)
+        dt = dt.to(dtype=x.dtype)
+        dt_emb = sinusoidal_embedding(dt, x.shape[-1])
+        x = x + dt_emb
         cond_vec = self.cond_emb(cond)
         feats, h = self.backbone(x, cond_vec, h)
         delta = self.delta_head(feats)
-        dt_raw = self.dt_head(feats).squeeze(-1)
-        dt = F.softplus(dt_raw)
-        return delta, dt, h
+        return delta, h
 
     def step(
-        self, x: Tensor, cond: Tensor, h: list[InferenceCache]
-    ) -> tuple[Tensor, Tensor, list[InferenceCache]]:
-        delta, dt, h = self.forward(x.unsqueeze(1), cond, h)
-        return delta[:, 0, :], dt[:, 0], h
+        self, x: Tensor, dt: Tensor, cond: Tensor, h: list[InferenceCache]
+    ) -> tuple[Tensor, list[InferenceCache]]:
+        delta, h = self.forward(x.unsqueeze(1), dt.unsqueeze(1), cond, h)
+        return delta[:, 0, :], h
 
     def init_cache(self, batch_size: int, device: torch.device) -> list[InferenceCache]:
         return [
@@ -332,51 +348,15 @@ def infinite_loader(loader: DataLoader) -> Iterator[dict]:
 
 def compute_losses(
     delta: Tensor,
-    dt: Tensor,
-    x_seq: Tensor,
-    x0: Tensor,
     v_gt: Tensor,
-    t_seq: Tensor,
     dt_seq: Tensor,
-    cfg: TrainConfig,
 ) -> dict[str, Tensor]:
     losses: dict[str, Tensor] = {}
-
-    if cfg.use_flow_loss:
-        target_disp = v_gt[:, None, :] * dt.unsqueeze(-1)
-        losses["flow"] = F.mse_loss(delta, target_disp)
-
-    if cfg.use_pos_loss:
-        t_next = t_seq + dt
-        t_next = torch.clamp(t_next, 0.0, 1.0)
-        x_target = x0[:, None, :] + t_next.unsqueeze(-1) * v_gt[:, None, :]
-        x_next_pred = x_seq + delta
-        losses["pos"] = F.mse_loss(x_next_pred, x_target)
-
-    if cfg.use_dt_loss:
-        losses["dt"] = F.mse_loss(dt, dt_seq)
-
+    target_disp = v_gt[:, None, :] * dt_seq.unsqueeze(-1)
+    losses["flow"] = F.mse_loss(delta, target_disp)
     return losses
 
 
-def plot_dt_hist(
-    dt_pred: Tensor, dt_gt: Tensor, save_path: Path, title: str = "dt Distribution"
-) -> None:
-    dt_pred_np = dt_pred.detach().cpu().numpy().reshape(-1)
-    dt_gt_np = dt_gt.detach().cpu().numpy().reshape(-1)
-
-    fig, ax = plt.subplots(figsize=(6, 4))
-    ax.hist(dt_gt_np, bins=30, alpha=0.6, label="dt_gt", color="steelblue")
-    ax.hist(dt_pred_np, bins=30, alpha=0.6, label="dt_pred", color="orange")
-    ax.set_title(title)
-    ax.set_xlabel("dt")
-    ax.set_ylabel("count")
-    ax.legend(loc="best")
-    fig.tight_layout()
-    fig.savefig(save_path, dpi=160)
-    plt.close(fig)
-
-
 def make_grid(images: Tensor, nrow: int) -> np.ndarray:
     images = images.detach().cpu().numpy()
     b, c, h, w = images.shape
@@ -413,30 +393,24 @@ def rollout_trajectory(
     model: ASMamba,
     x0: Tensor,
     cond: Tensor,
-    max_steps: int,
+    dt_seq: Tensor,
 ) -> Tensor:
     device = x0.device
     model.eval()
     h = model.init_cache(batch_size=x0.shape[0], device=device)
     x = x0
-    total_time = torch.zeros(x0.shape[0], device=device)
+    if dt_seq.dim() == 1:
+        dt_seq = dt_seq.unsqueeze(0).expand(x0.shape[0], -1)
+    elif dt_seq.shape[0] == 1 and x0.shape[0] > 1:
+        dt_seq = dt_seq.expand(x0.shape[0], -1)
     traj = [x0.detach().cpu()]
 
     with torch.no_grad():
-        for _ in range(max_steps):
-            delta, dt, h = model.step(x, cond, h)
-            dt = torch.clamp(dt, min=model.dt_min, max=model.dt_max)
-            remaining = 1.0 - total_time
-            overshoot = dt > remaining
-            if overshoot.any():
-                scale = (remaining / dt).unsqueeze(-1)
-                delta = torch.where(overshoot.unsqueeze(-1), delta * scale, delta)
-                dt = torch.where(overshoot, remaining, dt)
+        for step_idx in range(dt_seq.shape[1]):
+            dt = dt_seq[:, step_idx]
+            delta, h = model.step(x, dt, cond, h)
             x = x + delta
-            total_time = total_time + dt
             traj.append(x.detach().cpu())
-            if torch.all(total_time >= 1.0 - 1e-6):
-                break
 
     return torch.stack(traj, dim=1)
@@ -451,15 +425,20 @@ def log_class_samples(
     if cfg.val_samples_per_class <= 0:
         return
     model.eval()
-    max_steps = cfg.seq_len if cfg.val_max_steps <= 0 else cfg.val_max_steps
+    max_steps = cfg.seq_len
     input_dim = cfg.channels * cfg.image_size * cfg.image_size
+    dt_seq = torch.full(
+        (cfg.val_samples_per_class, max_steps),
+        1.0 / max_steps,
+        device=device,
+    )
 
     for cls in range(cfg.num_classes):
         cond = torch.full(
             (cfg.val_samples_per_class,), cls, device=device, dtype=torch.long
         )
         x0 = torch.randn(cfg.val_samples_per_class, input_dim, device=device)
-        traj = rollout_trajectory(model, x0, cond, max_steps=max_steps)
+        traj = rollout_trajectory(model, x0, cond, dt_seq=dt_seq)
         x_final = traj[:, -1, :].view(
             cfg.val_samples_per_class, cfg.channels, cfg.image_size, cfg.image_size
         )
@@ -511,68 +490,25 @@ def train(cfg: TrainConfig) -> ASMamba:
         t_seq = torch.cat([torch.zeros(b, 1, device=device), t_seq[:, :-1]], dim=-1)
         x_seq = x0[:, None, :] + t_seq[:, :, None] * v_gt[:, None, :]
 
-        delta, dt, _ = model(x_seq, cond)
+        delta, _ = model(x_seq, dt_seq, cond)
 
         losses = compute_losses(
             delta=delta,
-            dt=dt,
-            x_seq=x_seq,
-            x0=x0,
             v_gt=v_gt,
-            t_seq=t_seq,
             dt_seq=dt_seq,
-            cfg=cfg,
         )
 
-        loss = torch.tensor(0.0, device=device)
-        if cfg.use_flow_loss and "flow" in losses:
-            loss = loss + cfg.lambda_flow * losses["flow"]
-        if cfg.use_pos_loss and "pos" in losses:
-            loss = loss + cfg.lambda_pos * losses["pos"]
-        if cfg.use_dt_loss and "dt" in losses:
-            loss = loss + cfg.lambda_dt * losses["dt"]
-        if loss.item() == 0.0:
-            raise RuntimeError(
-                "No loss enabled: enable at least one of flow/pos/dt."
-            )
+        loss = cfg.lambda_flow * losses["flow"]
 
         optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
 
         if global_step % 10 == 0:
-            dt_min = float(dt.min().item())
-            dt_max = float(dt.max().item())
-            dt_mean = float(dt.mean().item())
-            dt_gt_min = float(dt_seq.min().item())
-            dt_gt_max = float(dt_seq.max().item())
-            dt_gt_mean = float(dt_seq.mean().item())
-            eps = 1e-6
-            clamp_min_ratio = float((dt <= cfg.dt_min + eps).float().mean().item())
-            clamp_max_ratio = float((dt >= cfg.dt_max - eps).float().mean().item())
-            clamp_any_ratio = float(
-                ((dt <= cfg.dt_min + eps) | (dt >= cfg.dt_max - eps))
-                .float()
-                .mean()
-                .item()
-            )
             logger.log(
                 {
                     "loss/total": float(loss.item()),
-                    "loss/flow": float(
-                        losses.get("flow", torch.tensor(0.0)).item()
-                    ),
-                    "loss/pos": float(losses.get("pos", torch.tensor(0.0)).item()),
-                    "loss/dt": float(losses.get("dt", torch.tensor(0.0)).item()),
-                    "dt/pred_mean": dt_mean,
-                    "dt/pred_min": dt_min,
-                    "dt/pred_max": dt_max,
-                    "dt/gt_mean": dt_gt_mean,
-                    "dt/gt_min": dt_gt_min,
-                    "dt/gt_max": dt_gt_max,
-                    "dt/clamp_min_ratio": clamp_min_ratio,
-                    "dt/clamp_max_ratio": clamp_max_ratio,
-                    "dt/clamp_any_ratio": clamp_any_ratio,
+                    "loss/flow": float(losses["flow"].item()),
                 },
                 step=global_step,
             )
@@ -584,21 +520,6 @@ def train(cfg: TrainConfig) -> ASMamba:
            and rank == 0
        ):
            log_class_samples(unwrap_model(model), cfg, device, logger, global_step)
-            dt_hist_path = (
-                Path(cfg.output_dir) / f"dt_hist_step_{global_step:06d}.png"
-            )
-            plot_dt_hist(
-                dt,
-                dt_seq,
-                dt_hist_path,
-                title=f"dt Distribution (step {global_step})",
-            )
-            logger.log_image(
-                "train/dt_hist",
-                dt_hist_path,
-                caption=f"step {global_step}",
-                step=global_step,
-            )
 
         global_step += 1