refactor(as_mamba): Remove dt prediction and use fixed dt
Removes the `dt_head` network and its associated configuration parameters (`dt_min`, `dt_max`, `lambda_nfe`, `warmup_epochs`). Replaces the predicted per-step time step with a fixed value derived from the sequence length (`dt = 1.0 / seq_len`). Eliminates the warmup phase and the NFE loss term, and simplifies rollout to a fixed number of steps with no overshoot handling.
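For reference, the arithmetic behind the fixed step: a minimal sketch assuming the default `seq_len = 20` from `TrainConfig` (values illustrative, not part of the diff):

    # Fixed step size replacing the learned, clamped softplus(dt_head(feats)).
    seq_len = 20                      # TrainConfig.seq_len default
    dt_fixed = 1.0 / seq_len          # 0.05 per step
    t = [k * dt_fixed for k in range(seq_len + 1)]
    assert abs(t[-1] - 1.0) < 1e-9    # seq_len fixed steps span t in [0, 1] exactly

This is why the rollout below no longer needs `total_time` bookkeeping or overshoot clamping: the step count alone determines the final time.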
as_mamba.py (80 changed lines)
--- a/as_mamba.py
+++ b/as_mamba.py
@@ -25,15 +25,11 @@ class TrainConfig:
     batch_size: int = 128
     steps_per_epoch: int = 50
     epochs: int = 60
-    warmup_epochs: int = 15
     seq_len: int = 20
     lr: float = 1e-3
     weight_decay: float = 1e-2
-    dt_min: float = 1e-3
-    dt_max: float = 0.06
     lambda_flow: float = 1.0
     lambda_pos: float = 1.0
-    lambda_nfe: float = 0.05
     radius_min: float = 0.6
     radius_max: float = 1.4
     center_min: float = -6.0
@@ -53,7 +49,7 @@ class TrainConfig:
     val_every: int = 200
     val_samples: int = 256
     val_plot_samples: int = 16
-    val_max_steps: int = 100
+    val_max_steps: int = 0


 class Mamba2Backbone(nn.Module):
@@ -92,8 +88,6 @@ class ASMamba(nn.Module):
     def __init__(self, cfg: TrainConfig) -> None:
         super().__init__()
         self.cfg = cfg
-        self.dt_min = float(cfg.dt_min)
-        self.dt_max = float(cfg.dt_max)

         args = Mamba2Config(
             d_model=cfg.d_model,
@@ -107,27 +101,20 @@ class ASMamba(nn.Module):
         self.backbone = Mamba2Backbone(args, use_residual=cfg.use_residual)
         self.input_proj = nn.Linear(3, cfg.d_model)
         self.delta_head = nn.Linear(cfg.d_model, 3)
-        self.dt_head = nn.Sequential(
-            nn.Linear(cfg.d_model, cfg.d_model),
-            nn.SiLU(),
-            nn.Linear(cfg.d_model, 1),
-        )

     def forward(
         self, x: Tensor, h: Optional[list[InferenceCache]] = None
-    ) -> tuple[Tensor, Tensor, list[InferenceCache]]:
+    ) -> tuple[Tensor, list[InferenceCache]]:
         x_proj = self.input_proj(x)
         feats, h = self.backbone(x_proj, h)
         delta = self.delta_head(feats)
-        dt_raw = self.dt_head(feats).squeeze(-1)
-        dt = torch.clamp(F.softplus(dt_raw), min=self.dt_min, max=self.dt_max)
-        return delta, dt, h
+        return delta, h

     def step(
         self, x: Tensor, h: list[InferenceCache]
-    ) -> tuple[Tensor, Tensor, list[InferenceCache]]:
-        delta, dt, h = self.forward(x.unsqueeze(1), h)
-        return delta[:, 0, :], dt[:, 0], h
+    ) -> tuple[Tensor, list[InferenceCache]]:
+        delta, h = self.forward(x.unsqueeze(1), h)
+        return delta[:, 0, :], h

     def init_cache(self, batch_size: int, device: torch.device) -> list[InferenceCache]:
         return [
@@ -240,24 +227,23 @@ def sample_batch(

 def compute_losses(
     delta: Tensor,
-    dt: Tensor,
     x_seq: Tensor,
     x0: Tensor,
     v_gt: Tensor,
     t_seq: Tensor,
     cfg: TrainConfig,
-) -> tuple[Tensor, Tensor, Tensor]:
-    target_disp = v_gt[:, None, :] * dt.unsqueeze(-1)
+) -> tuple[Tensor, Tensor]:
+    dt_fixed = 1.0 / cfg.seq_len
+    target_disp = v_gt[:, None, :] * dt_fixed
     flow_loss = F.mse_loss(delta, target_disp)

-    t_next = t_seq[None, :, None] + dt.unsqueeze(-1)
+    t_next = t_seq[None, :, None] + dt_fixed
     t_next = torch.clamp(t_next, 0.0, 1.0)
     x_target = x0[:, None, :] + t_next * v_gt[:, None, :]
     x_next_pred = x_seq + delta
     pos_loss = F.mse_loss(x_next_pred, x_target)

-    nfe_loss = (-torch.log(dt)).mean()
-    return flow_loss, pos_loss, nfe_loss
+    return flow_loss, pos_loss


 def validate(
@@ -271,12 +257,13 @@ def validate(
 ) -> None:
     model.eval()
     center_b, radius_b = sphere_b
+    steps = cfg.seq_len if cfg.val_max_steps <= 0 else cfg.val_max_steps

     with torch.no_grad():
         x0 = sample_points_in_sphere(
             sphere_a[0], float(sphere_a[1].item()), cfg.val_samples, device
         )
-        traj = rollout_trajectory(model, x0, max_steps=cfg.val_max_steps)
+        traj = rollout_trajectory(model, x0, steps=steps)

     x_final = traj[:, -1, :]
     center_b_cpu = center_b.detach().cpu()
@@ -291,6 +278,7 @@ def validate(
             "val/final_dist_mean": float(dist.mean().item()),
             "val/final_dist_min": float(dist.min().item()),
             "val/final_dist_max": float(dist.max().item()),
+            "val/steps": float(steps),
         },
         step=step,
     )
@@ -351,22 +339,15 @@ def train(cfg: TrainConfig) -> tuple[ASMamba, tuple[Tensor, Tensor], tuple[Tenso

     global_step = 0
     for epoch in range(cfg.epochs):
-        warmup = epoch < cfg.warmup_epochs
         model.train()
-        for p in model.dt_head.parameters():
-            p.requires_grad = not warmup

         for _ in range(cfg.steps_per_epoch):
             x0, x1, x_seq, t_seq = sample_batch(cfg, sphere_a, sphere_b, device)
             v_gt = x1 - x0

-            delta, dt, _ = model(x_seq)
-            if warmup:
-                dt = torch.full_like(dt, 1.0 / cfg.seq_len)
-
-            flow_loss, pos_loss, nfe_loss = compute_losses(
+            delta, _ = model(x_seq)
+            flow_loss, pos_loss = compute_losses(
                 delta=delta,
-                dt=dt,
                 x_seq=x_seq,
                 x0=x0,
                 v_gt=v_gt,
@@ -375,8 +356,6 @@ def train(cfg: TrainConfig) -> tuple[ASMamba, tuple[Tensor, Tensor], tuple[Tenso
             )

             loss = cfg.lambda_flow * flow_loss + cfg.lambda_pos * pos_loss
-            if not warmup:
-                loss = loss + cfg.lambda_nfe * nfe_loss

             optimizer.zero_grad(set_to_none=True)
             loss.backward()
@@ -388,11 +367,6 @@ def train(cfg: TrainConfig) -> tuple[ASMamba, tuple[Tensor, Tensor], tuple[Tenso
                     "loss/total": float(loss.item()),
                     "loss/flow": float(flow_loss.item()),
                     "loss/pos": float(pos_loss.item()),
-                    "loss/nfe": float(nfe_loss.item()),
-                    "dt/mean": float(dt.mean().item()),
-                    "dt/min": float(dt.min().item()),
-                    "dt/max": float(dt.max().item()),
-                    "stage": 0 if warmup else 1,
                 },
                 step=global_step,
             )
@@ -408,33 +382,20 @@ def train(cfg: TrainConfig) -> tuple[ASMamba, tuple[Tensor, Tensor], tuple[Tenso
 def rollout_trajectory(
     model: ASMamba,
     x0: Tensor,
-    max_steps: int = 100,
+    steps: int,
 ) -> Tensor:
     device = x0.device
     model.eval()
     h = model.init_cache(batch_size=x0.shape[0], device=device)
     x = x0
-    total_time = torch.zeros(x0.shape[0], device=device)
     traj = [x0.detach().cpu()]

     with torch.no_grad():
-        for _ in range(max_steps):
-            delta, dt, h = model.step(x, h)
-            dt = torch.clamp(dt, min=model.dt_min, max=model.dt_max)
-            remaining = 1.0 - total_time
-            overshoot = dt > remaining
-            if overshoot.any():
-                scale = (remaining / dt).unsqueeze(-1)
-                delta = torch.where(overshoot.unsqueeze(-1), delta * scale, delta)
-                dt = torch.where(overshoot, remaining, dt)
-
+        for _ in range(steps):
+            delta, h = model.step(x, h)
             x = x + delta
-            total_time = total_time + dt
             traj.append(x.detach().cpu())

-            if torch.all(total_time >= 1.0 - 1e-6):
-                break
-
     return torch.stack(traj, dim=1)

@@ -504,7 +465,8 @@ def run_training_and_plot(cfg: TrainConfig) -> Path:
     x0 = sample_points_in_sphere(
         sphere_a[0], float(sphere_a[1].item()), plot_samples, device
     )
-    traj = rollout_trajectory(model, x0, max_steps=cfg.val_max_steps)
+    steps = cfg.seq_len if cfg.val_max_steps <= 0 else cfg.val_max_steps
+    traj = rollout_trajectory(model, x0, steps=steps)
     output_dir = Path(cfg.output_dir)
     save_path = output_dir / "as_mamba_trajectory.png"
     plot_trajectories(traj, sphere_a, sphere_b, save_path)
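With the adaptive stepping removed, the caller picks the step count explicitly; a minimal usage sketch, assuming `train(cfg)` returns the model and the two spheres in the order the (truncated) hunk headers suggest:

    cfg = TrainConfig()
    model, sphere_a, sphere_b = train(cfg)             # assumed return order
    device = next(model.parameters()).device
    x0 = sample_points_in_sphere(sphere_a[0], float(sphere_a[1].item()), 16, device)
    steps = cfg.seq_len if cfg.val_max_steps <= 0 else cfg.val_max_steps
    traj = rollout_trajectory(model, x0, steps=steps)  # shape (16, steps + 1, 3)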
main.py (1 changed line)
--- a/main.py
+++ b/main.py
@@ -6,7 +6,6 @@ from as_mamba import TrainConfig, run_training_and_plot
 def build_parser() -> argparse.ArgumentParser:
     parser = argparse.ArgumentParser(description="Train AS-Mamba on sphere-to-sphere flow.")
     parser.add_argument("--epochs", type=int, default=None)
-    parser.add_argument("--warmup-epochs", type=int, default=None)
     parser.add_argument("--batch-size", type=int, default=None)
     parser.add_argument("--steps-per-epoch", type=int, default=None)
     parser.add_argument("--seq-len", type=int, default=None)
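Note that invocations passing `--warmup-epochs` will now fail argument parsing; a typical run after this change uses only the remaining flags (defaults mirror `TrainConfig`):

    python main.py --epochs 60 --batch-size 128 --steps-per-epoch 50 --seq-len 20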