Add configurable dt sampling and loss toggles
This commit is contained in:
164
as_mamba.py
164
as_mamba.py
@@ -25,15 +25,18 @@ class TrainConfig:
|
|||||||
batch_size: int = 128
|
batch_size: int = 128
|
||||||
steps_per_epoch: int = 50
|
steps_per_epoch: int = 50
|
||||||
epochs: int = 60
|
epochs: int = 60
|
||||||
warmup_epochs: int = 15
|
|
||||||
seq_len: int = 20
|
seq_len: int = 20
|
||||||
lr: float = 1e-3
|
lr: float = 1e-3
|
||||||
weight_decay: float = 1e-2
|
weight_decay: float = 1e-2
|
||||||
dt_min: float = 1e-3
|
dt_min: float = 1e-3
|
||||||
dt_max: float = 0.06
|
dt_max: float = 0.06
|
||||||
|
dt_alpha: float = 8.0
|
||||||
lambda_flow: float = 1.0
|
lambda_flow: float = 1.0
|
||||||
lambda_pos: float = 1.0
|
lambda_pos: float = 1.0
|
||||||
lambda_nfe: float = 0.05
|
lambda_dt: float = 0.05
|
||||||
|
use_flow_loss: bool = True
|
||||||
|
use_pos_loss: bool = False
|
||||||
|
use_dt_loss: bool = True
|
||||||
radius_min: float = 0.6
|
radius_min: float = 0.6
|
||||||
radius_max: float = 1.4
|
radius_max: float = 1.4
|
||||||
center_min: float = -6.0
|
center_min: float = -6.0
|
||||||
@@ -53,7 +56,7 @@ class TrainConfig:
|
|||||||
val_every: int = 200
|
val_every: int = 200
|
||||||
val_samples: int = 256
|
val_samples: int = 256
|
||||||
val_plot_samples: int = 16
|
val_plot_samples: int = 16
|
||||||
val_max_steps: int = 100
|
val_max_steps: int = 0
|
||||||
|
|
||||||
|
|
||||||
class Mamba2Backbone(nn.Module):
|
class Mamba2Backbone(nn.Module):
|
||||||
@@ -221,21 +224,41 @@ def sample_sphere_params(cfg: TrainConfig, device: torch.device) -> tuple[Tensor
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def sample_time_sequence(cfg: "TrainConfig", batch_size: int, device: torch.device) -> Tensor:
    """Sample per-step time increments that sum to 1 along the sequence axis.

    Each row is drawn by normalising ``cfg.seq_len`` independent
    ``Gamma(cfg.dt_alpha, 1)`` samples by their sum. Rows whose largest
    increment exceeds ``cfg.dt_max`` are blended towards the uniform step
    ``1 / seq_len`` just far enough that their maximum equals ``dt_max``;
    the blend keeps the row sum at 1.

    Args:
        cfg: Configuration providing ``dt_alpha``, ``seq_len`` and ``dt_max``.
        batch_size: Number of sequences (rows) to sample.
        device: Device the returned tensor is moved to.

    Returns:
        Tensor of shape ``(batch_size, cfg.seq_len)``.

    Raises:
        ValueError: If ``cfg.dt_alpha`` or ``cfg.seq_len`` is not positive.
    """
    alpha = float(cfg.dt_alpha)
    if alpha <= 0:
        raise ValueError("dt_alpha must be > 0")
    seq_len = int(cfg.seq_len)
    if seq_len <= 0:
        raise ValueError("seq_len must be > 0")

    # The distribution is built from Python floats, so sampling happens on
    # the default device and the result is moved afterwards.
    dist = torch.distributions.Gamma(alpha, 1.0)
    raw = dist.sample((batch_size, seq_len)).to(device)
    dt_seq = raw / raw.sum(dim=-1, keepdim=True)

    base = 1.0 / seq_len
    max_dt = float(cfg.dt_max)
    if max_dt <= base:
        # A cap at or below the uniform step cannot be met by any sequence
        # that sums to 1, so fall back to uniform steps.
        return torch.full_like(dt_seq, base)

    max_current = dt_seq.max(dim=-1, keepdim=True).values
    over_cap = max_current > max_dt
    # For a row over the cap, max_current > max_dt > base, so the ratio lies
    # in (0, 1). The clamp only guards the rows that the mask below discards.
    gamma = ((max_dt - base) / (max_current - base)).clamp(0.0, 1.0)
    # Per-row mask instead of a batch-wide `.any()` branch: rows already
    # within the cap keep gamma == 1 and are returned unchanged, and no
    # device-to-host synchronisation is needed.
    gamma = torch.where(over_cap, gamma, torch.ones_like(gamma))
    return gamma * dt_seq + (1.0 - gamma) * base
|
||||||
|
|
||||||
|
|
||||||
def sample_batch(
|
def sample_batch(
|
||||||
cfg: TrainConfig,
|
cfg: TrainConfig,
|
||||||
sphere_a: tuple[Tensor, Tensor],
|
sphere_a: tuple[Tensor, Tensor],
|
||||||
sphere_b: tuple[Tensor, Tensor],
|
sphere_b: tuple[Tensor, Tensor],
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
) -> tuple[Tensor, Tensor, Tensor, Tensor]:
|
) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
|
||||||
center_a, radius_a = sphere_a
|
center_a, radius_a = sphere_a
|
||||||
center_b, radius_b = sphere_b
|
center_b, radius_b = sphere_b
|
||||||
x0 = sample_points_in_sphere(center_a, float(radius_a.item()), cfg.batch_size, device)
|
x0 = sample_points_in_sphere(center_a, float(radius_a.item()), cfg.batch_size, device)
|
||||||
x1 = sample_points_in_sphere(center_b, float(radius_b.item()), cfg.batch_size, device)
|
x1 = sample_points_in_sphere(center_b, float(radius_b.item()), cfg.batch_size, device)
|
||||||
v_gt = x1 - x0
|
v_gt = x1 - x0
|
||||||
dt_fixed = 1.0 / cfg.seq_len
|
dt_seq = sample_time_sequence(cfg, cfg.batch_size, device)
|
||||||
t_seq = torch.arange(cfg.seq_len, device=device) * dt_fixed
|
t_seq = torch.cumsum(dt_seq, dim=-1)
|
||||||
x_seq = x0[:, None, :] + t_seq[None, :, None] * v_gt[:, None, :]
|
t_seq = torch.cat([torch.zeros(cfg.batch_size, 1, device=device), t_seq[:, :-1]], dim=-1)
|
||||||
return x0, x1, x_seq, t_seq
|
x_seq = x0[:, None, :] + t_seq[:, :, None] * v_gt[:, None, :]
|
||||||
|
return x0, x1, x_seq, t_seq, dt_seq
|
||||||
|
|
||||||
|
|
||||||
def compute_losses(
|
def compute_losses(
|
||||||
@@ -245,19 +268,26 @@ def compute_losses(
|
|||||||
x0: Tensor,
|
x0: Tensor,
|
||||||
v_gt: Tensor,
|
v_gt: Tensor,
|
||||||
t_seq: Tensor,
|
t_seq: Tensor,
|
||||||
|
dt_seq: Tensor,
|
||||||
cfg: TrainConfig,
|
cfg: TrainConfig,
|
||||||
) -> tuple[Tensor, Tensor, Tensor]:
|
) -> dict[str, Tensor]:
|
||||||
target_disp = v_gt[:, None, :] * dt.unsqueeze(-1)
|
losses: dict[str, Tensor] = {}
|
||||||
flow_loss = F.mse_loss(delta, target_disp)
|
|
||||||
|
|
||||||
t_next = t_seq[None, :, None] + dt.unsqueeze(-1)
|
if cfg.use_flow_loss:
|
||||||
t_next = torch.clamp(t_next, 0.0, 1.0)
|
target_disp = v_gt[:, None, :] * dt.unsqueeze(-1)
|
||||||
x_target = x0[:, None, :] + t_next * v_gt[:, None, :]
|
losses["flow"] = F.mse_loss(delta, target_disp)
|
||||||
x_next_pred = x_seq + delta
|
|
||||||
pos_loss = F.mse_loss(x_next_pred, x_target)
|
|
||||||
|
|
||||||
nfe_loss = (-torch.log(dt)).mean()
|
if cfg.use_pos_loss:
|
||||||
return flow_loss, pos_loss, nfe_loss
|
t_next = t_seq + dt
|
||||||
|
t_next = torch.clamp(t_next, 0.0, 1.0)
|
||||||
|
x_target = x0[:, None, :] + t_next.unsqueeze(-1) * v_gt[:, None, :]
|
||||||
|
x_next_pred = x_seq + delta
|
||||||
|
losses["pos"] = F.mse_loss(x_next_pred, x_target)
|
||||||
|
|
||||||
|
if cfg.use_dt_loss:
|
||||||
|
losses["dt"] = F.mse_loss(dt, dt_seq)
|
||||||
|
|
||||||
|
return losses
|
||||||
|
|
||||||
|
|
||||||
def validate(
|
def validate(
|
||||||
@@ -271,12 +301,13 @@ def validate(
|
|||||||
) -> None:
|
) -> None:
|
||||||
model.eval()
|
model.eval()
|
||||||
center_b, radius_b = sphere_b
|
center_b, radius_b = sphere_b
|
||||||
|
max_steps = cfg.seq_len if cfg.val_max_steps <= 0 else cfg.val_max_steps
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
x0 = sample_points_in_sphere(
|
x0 = sample_points_in_sphere(
|
||||||
sphere_a[0], float(sphere_a[1].item()), cfg.val_samples, device
|
sphere_a[0], float(sphere_a[1].item()), cfg.val_samples, device
|
||||||
)
|
)
|
||||||
traj = rollout_trajectory(model, x0, max_steps=cfg.val_max_steps)
|
traj = rollout_trajectory(model, x0, max_steps=max_steps)
|
||||||
|
|
||||||
x_final = traj[:, -1, :]
|
x_final = traj[:, -1, :]
|
||||||
center_b_cpu = center_b.detach().cpu()
|
center_b_cpu = center_b.detach().cpu()
|
||||||
@@ -291,6 +322,7 @@ def validate(
|
|||||||
"val/final_dist_mean": float(dist.mean().item()),
|
"val/final_dist_mean": float(dist.mean().item()),
|
||||||
"val/final_dist_min": float(dist.min().item()),
|
"val/final_dist_min": float(dist.min().item()),
|
||||||
"val/final_dist_max": float(dist.max().item()),
|
"val/final_dist_max": float(dist.max().item()),
|
||||||
|
"val/max_steps": float(max_steps),
|
||||||
},
|
},
|
||||||
step=step,
|
step=step,
|
||||||
)
|
)
|
||||||
@@ -351,54 +383,86 @@ def train(cfg: TrainConfig) -> tuple[ASMamba, tuple[Tensor, Tensor], tuple[Tenso
|
|||||||
|
|
||||||
global_step = 0
|
global_step = 0
|
||||||
for epoch in range(cfg.epochs):
|
for epoch in range(cfg.epochs):
|
||||||
warmup = epoch < cfg.warmup_epochs
|
|
||||||
model.train()
|
model.train()
|
||||||
for p in model.dt_head.parameters():
|
|
||||||
p.requires_grad = not warmup
|
|
||||||
|
|
||||||
for _ in range(cfg.steps_per_epoch):
|
for _ in range(cfg.steps_per_epoch):
|
||||||
x0, x1, x_seq, t_seq = sample_batch(cfg, sphere_a, sphere_b, device)
|
x0, x1, x_seq, t_seq, dt_seq = sample_batch(cfg, sphere_a, sphere_b, device)
|
||||||
v_gt = x1 - x0
|
v_gt = x1 - x0
|
||||||
|
|
||||||
delta, dt, _ = model(x_seq)
|
delta, dt, _ = model(x_seq)
|
||||||
if warmup:
|
|
||||||
dt = torch.full_like(dt, 1.0 / cfg.seq_len)
|
|
||||||
|
|
||||||
flow_loss, pos_loss, nfe_loss = compute_losses(
|
losses = compute_losses(
|
||||||
delta=delta,
|
delta=delta,
|
||||||
dt=dt,
|
dt=dt,
|
||||||
x_seq=x_seq,
|
x_seq=x_seq,
|
||||||
x0=x0,
|
x0=x0,
|
||||||
v_gt=v_gt,
|
v_gt=v_gt,
|
||||||
t_seq=t_seq,
|
t_seq=t_seq,
|
||||||
|
dt_seq=dt_seq,
|
||||||
cfg=cfg,
|
cfg=cfg,
|
||||||
)
|
)
|
||||||
|
|
||||||
loss = cfg.lambda_flow * flow_loss + cfg.lambda_pos * pos_loss
|
loss = torch.tensor(0.0, device=device)
|
||||||
if not warmup:
|
if cfg.use_flow_loss and "flow" in losses:
|
||||||
loss = loss + cfg.lambda_nfe * nfe_loss
|
loss = loss + cfg.lambda_flow * losses["flow"]
|
||||||
|
if cfg.use_pos_loss and "pos" in losses:
|
||||||
|
loss = loss + cfg.lambda_pos * losses["pos"]
|
||||||
|
if cfg.use_dt_loss and "dt" in losses:
|
||||||
|
loss = loss + cfg.lambda_dt * losses["dt"]
|
||||||
|
if loss.item() == 0.0:
|
||||||
|
raise RuntimeError("No loss enabled: enable at least one of flow/pos/dt.")
|
||||||
|
|
||||||
optimizer.zero_grad(set_to_none=True)
|
optimizer.zero_grad(set_to_none=True)
|
||||||
loss.backward()
|
loss.backward()
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|
||||||
if global_step % 10 == 0:
|
if global_step % 10 == 0:
|
||||||
|
dt_min = float(dt.min().item())
|
||||||
|
dt_max = float(dt.max().item())
|
||||||
|
dt_mean = float(dt.mean().item())
|
||||||
|
dt_gt_min = float(dt_seq.min().item())
|
||||||
|
dt_gt_max = float(dt_seq.max().item())
|
||||||
|
dt_gt_mean = float(dt_seq.mean().item())
|
||||||
|
eps = 1e-6
|
||||||
|
clamp_min_ratio = float((dt <= cfg.dt_min + eps).float().mean().item())
|
||||||
|
clamp_max_ratio = float((dt >= cfg.dt_max - eps).float().mean().item())
|
||||||
|
clamp_any_ratio = float(
|
||||||
|
((dt <= cfg.dt_min + eps) | (dt >= cfg.dt_max - eps)).float().mean().item()
|
||||||
|
)
|
||||||
logger.log(
|
logger.log(
|
||||||
{
|
{
|
||||||
"loss/total": float(loss.item()),
|
"loss/total": float(loss.item()),
|
||||||
"loss/flow": float(flow_loss.item()),
|
"loss/flow": float(losses.get("flow", torch.tensor(0.0)).item()),
|
||||||
"loss/pos": float(pos_loss.item()),
|
"loss/pos": float(losses.get("pos", torch.tensor(0.0)).item()),
|
||||||
"loss/nfe": float(nfe_loss.item()),
|
"loss/dt": float(losses.get("dt", torch.tensor(0.0)).item()),
|
||||||
"dt/mean": float(dt.mean().item()),
|
"dt/pred_mean": dt_mean,
|
||||||
"dt/min": float(dt.min().item()),
|
"dt/pred_min": dt_min,
|
||||||
"dt/max": float(dt.max().item()),
|
"dt/pred_max": dt_max,
|
||||||
"stage": 0 if warmup else 1,
|
"dt/gt_mean": dt_gt_mean,
|
||||||
|
"dt/gt_min": dt_gt_min,
|
||||||
|
"dt/gt_max": dt_gt_max,
|
||||||
|
"dt/clamp_min_ratio": clamp_min_ratio,
|
||||||
|
"dt/clamp_max_ratio": clamp_max_ratio,
|
||||||
|
"dt/clamp_any_ratio": clamp_any_ratio,
|
||||||
},
|
},
|
||||||
step=global_step,
|
step=global_step,
|
||||||
)
|
)
|
||||||
|
|
||||||
if cfg.val_every > 0 and global_step > 0 and global_step % cfg.val_every == 0:
|
if cfg.val_every > 0 and global_step > 0 and global_step % cfg.val_every == 0:
|
||||||
validate(model, cfg, sphere_a, sphere_b, device, logger, global_step)
|
validate(model, cfg, sphere_a, sphere_b, device, logger, global_step)
|
||||||
|
dt_hist_path = Path(cfg.output_dir) / f"dt_hist_step_{global_step:06d}.png"
|
||||||
|
plot_dt_hist(
|
||||||
|
dt,
|
||||||
|
dt_seq,
|
||||||
|
dt_hist_path,
|
||||||
|
title=f"dt Distribution (step {global_step})",
|
||||||
|
)
|
||||||
|
logger.log_image(
|
||||||
|
"train/dt_hist",
|
||||||
|
dt_hist_path,
|
||||||
|
caption=f"step {global_step}",
|
||||||
|
step=global_step,
|
||||||
|
)
|
||||||
global_step += 1
|
global_step += 1
|
||||||
|
|
||||||
logger.finish()
|
logger.finish()
|
||||||
@@ -408,7 +472,7 @@ def train(cfg: TrainConfig) -> tuple[ASMamba, tuple[Tensor, Tensor], tuple[Tenso
|
|||||||
def rollout_trajectory(
|
def rollout_trajectory(
|
||||||
model: ASMamba,
|
model: ASMamba,
|
||||||
x0: Tensor,
|
x0: Tensor,
|
||||||
max_steps: int = 100,
|
max_steps: int,
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
device = x0.device
|
device = x0.device
|
||||||
model.eval()
|
model.eval()
|
||||||
@@ -427,11 +491,9 @@ def rollout_trajectory(
|
|||||||
scale = (remaining / dt).unsqueeze(-1)
|
scale = (remaining / dt).unsqueeze(-1)
|
||||||
delta = torch.where(overshoot.unsqueeze(-1), delta * scale, delta)
|
delta = torch.where(overshoot.unsqueeze(-1), delta * scale, delta)
|
||||||
dt = torch.where(overshoot, remaining, dt)
|
dt = torch.where(overshoot, remaining, dt)
|
||||||
|
|
||||||
x = x + delta
|
x = x + delta
|
||||||
total_time = total_time + dt
|
total_time = total_time + dt
|
||||||
traj.append(x.detach().cpu())
|
traj.append(x.detach().cpu())
|
||||||
|
|
||||||
if torch.all(total_time >= 1.0 - 1e-6):
|
if torch.all(total_time >= 1.0 - 1e-6):
|
||||||
break
|
break
|
||||||
|
|
||||||
@@ -496,6 +558,27 @@ def plot_trajectories(
|
|||||||
plt.close(fig)
|
plt.close(fig)
|
||||||
|
|
||||||
|
|
||||||
|
def plot_dt_hist(
    dt_pred: Tensor,
    dt_gt: Tensor,
    save_path: Path,
    title: str = "dt Distribution",
) -> None:
    """Save an overlaid histogram of predicted and reference dt values.

    Both tensors are flattened and drawn on the same axes with 30 bins over
    a shared value range, so the two sets of bars have identical bin edges
    and their heights can be compared directly.

    Args:
        dt_pred: Predicted dt values; any shape.
        dt_gt: Reference dt values; any shape.
        save_path: Destination image path. Missing parent directories are
            created.
        title: Title drawn above the plot.
    """
    # `.float()` lets dtypes NumPy cannot represent (e.g. bfloat16) convert.
    dt_pred_np = dt_pred.detach().float().cpu().numpy().reshape(-1)
    dt_gt_np = dt_gt.detach().float().cpu().numpy().reshape(-1)

    # Use one (lo, hi) range for both histograms. Without it each call picks
    # its own bin edges and the overlay compares bars of different widths.
    populated = [values for values in (dt_gt_np, dt_pred_np) if values.size]
    hist_range = None
    if populated:
        lo = float(min(values.min() for values in populated))
        hi = float(max(values.max() for values in populated))
        if lo < hi:
            hist_range = (lo, hi)

    Path(save_path).parent.mkdir(parents=True, exist_ok=True)

    fig, ax = plt.subplots(figsize=(6, 4))
    try:
        ax.hist(
            dt_gt_np,
            bins=30,
            range=hist_range,
            alpha=0.6,
            label="dt_gt",
            color="steelblue",
        )
        ax.hist(
            dt_pred_np,
            bins=30,
            range=hist_range,
            alpha=0.6,
            label="dt_pred",
            color="orange",
        )
        ax.set_title(title)
        ax.set_xlabel("dt")
        ax.set_ylabel("count")
        ax.legend(loc="best")
        fig.tight_layout()
        fig.savefig(save_path, dpi=160)
    finally:
        # Close even when plotting or saving fails so figures do not
        # accumulate across training steps.
        plt.close(fig)
|
||||||
|
|
||||||
|
|
||||||
def run_training_and_plot(cfg: TrainConfig) -> Path:
|
def run_training_and_plot(cfg: TrainConfig) -> Path:
|
||||||
model, sphere_a, sphere_b = train(cfg)
|
model, sphere_a, sphere_b = train(cfg)
|
||||||
device = next(model.parameters()).device
|
device = next(model.parameters()).device
|
||||||
@@ -504,7 +587,8 @@ def run_training_and_plot(cfg: TrainConfig) -> Path:
|
|||||||
x0 = sample_points_in_sphere(
|
x0 = sample_points_in_sphere(
|
||||||
sphere_a[0], float(sphere_a[1].item()), plot_samples, device
|
sphere_a[0], float(sphere_a[1].item()), plot_samples, device
|
||||||
)
|
)
|
||||||
traj = rollout_trajectory(model, x0, max_steps=cfg.val_max_steps)
|
max_steps = cfg.seq_len if cfg.val_max_steps <= 0 else cfg.val_max_steps
|
||||||
|
traj = rollout_trajectory(model, x0, max_steps=max_steps)
|
||||||
output_dir = Path(cfg.output_dir)
|
output_dir = Path(cfg.output_dir)
|
||||||
save_path = output_dir / "as_mamba_trajectory.png"
|
save_path = output_dir / "as_mamba_trajectory.png"
|
||||||
plot_trajectories(traj, sphere_a, sphere_b, save_path)
|
plot_trajectories(traj, sphere_a, sphere_b, save_path)
|
||||||
|
|||||||
22
main.py
22
main.py
@@ -6,7 +6,6 @@ from as_mamba import TrainConfig, run_training_and_plot
|
|||||||
def build_parser() -> argparse.ArgumentParser:
|
def build_parser() -> argparse.ArgumentParser:
|
||||||
parser = argparse.ArgumentParser(description="Train AS-Mamba on sphere-to-sphere flow.")
|
parser = argparse.ArgumentParser(description="Train AS-Mamba on sphere-to-sphere flow.")
|
||||||
parser.add_argument("--epochs", type=int, default=None)
|
parser.add_argument("--epochs", type=int, default=None)
|
||||||
parser.add_argument("--warmup-epochs", type=int, default=None)
|
|
||||||
parser.add_argument("--batch-size", type=int, default=None)
|
parser.add_argument("--batch-size", type=int, default=None)
|
||||||
parser.add_argument("--steps-per-epoch", type=int, default=None)
|
parser.add_argument("--steps-per-epoch", type=int, default=None)
|
||||||
parser.add_argument("--seq-len", type=int, default=None)
|
parser.add_argument("--seq-len", type=int, default=None)
|
||||||
@@ -15,6 +14,27 @@ def build_parser() -> argparse.ArgumentParser:
|
|||||||
parser.add_argument("--output-dir", type=str, default=None)
|
parser.add_argument("--output-dir", type=str, default=None)
|
||||||
parser.add_argument("--project", type=str, default=None)
|
parser.add_argument("--project", type=str, default=None)
|
||||||
parser.add_argument("--run-name", type=str, default=None)
|
parser.add_argument("--run-name", type=str, default=None)
|
||||||
|
parser.add_argument("--dt-alpha", type=float, default=None)
|
||||||
|
parser.add_argument("--dt-min", type=float, default=None)
|
||||||
|
parser.add_argument("--dt-max", type=float, default=None)
|
||||||
|
parser.add_argument("--lambda-flow", type=float, default=None)
|
||||||
|
parser.add_argument("--lambda-pos", type=float, default=None)
|
||||||
|
parser.add_argument("--lambda-dt", type=float, default=None)
|
||||||
|
parser.add_argument(
|
||||||
|
"--use-flow-loss",
|
||||||
|
action=argparse.BooleanOptionalAction,
|
||||||
|
default=None,
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--use-pos-loss",
|
||||||
|
action=argparse.BooleanOptionalAction,
|
||||||
|
default=None,
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--use-dt-loss",
|
||||||
|
action=argparse.BooleanOptionalAction,
|
||||||
|
default=None,
|
||||||
|
)
|
||||||
parser.add_argument("--val-every", type=int, default=None)
|
parser.add_argument("--val-every", type=int, default=None)
|
||||||
parser.add_argument("--val-samples", type=int, default=None)
|
parser.add_argument("--val-samples", type=int, default=None)
|
||||||
parser.add_argument("--val-plot-samples", type=int, default=None)
|
parser.add_argument("--val-plot-samples", type=int, default=None)
|
||||||
|
|||||||
62
train_as_mamba.sh
Executable file
62
train_as_mamba.sh
Executable file
@@ -0,0 +1,62 @@
|
|||||||
|
#!/usr/bin/env bash
# Launch AS-Mamba training through main.py with the settings below.
# Edit the variables in this file to change a run; each one is forwarded
# to main.py as the matching command-line option.
set -euo pipefail

DEVICE="cuda"
EPOCHS=80
STEPS_PER_EPOCH=100
BATCH_SIZE=256
SEQ_LEN=20
LR=1e-3
DT_MIN=1e-3
DT_MAX=0.10
DT_ALPHA=6.0
LAMBDA_FLOW=1.0
LAMBDA_POS=0.0
LAMBDA_DT=0.5
# Loss toggles: must be exactly "true" or "false".
USE_FLOW_LOSS=true
USE_POS_LOSS=false
USE_DT_LOSS=true
VAL_EVERY=200
VAL_SAMPLES=512
VAL_PLOT_SAMPLES=16
VAL_MAX_STEPS=100
CENTER_MIN=-8
CENTER_MAX=8
CENTER_DISTANCE_MIN=8
PROJECT="as-mamba"
RUN_NAME="sphere-to-sphere-dt"
OUTPUT_DIR="outputs"

# Print the --NAME / --no-NAME flag for a true/false setting.
# Any other value aborts the script instead of being silently treated as
# "enabled", which is what a typo such as "False" or "0" used to do.
bool_flag() {
  local name="$1" value="$2"
  case "${value}" in
    true) printf -- '--%s\n' "${name}" ;;
    false) printf -- '--no-%s\n' "${name}" ;;
    *)
      echo "error: ${name} must be 'true' or 'false', got '${value}'" >&2
      return 1
      ;;
  esac
}

USE_FLOW_FLAG="$(bool_flag use-flow-loss "${USE_FLOW_LOSS}")"
USE_POS_FLAG="$(bool_flag use-pos-loss "${USE_POS_LOSS}")"
USE_DT_FLAG="$(bool_flag use-dt-loss "${USE_DT_LOSS}")"

uv run python main.py \
  --device "${DEVICE}" \
  --epochs "${EPOCHS}" \
  --steps-per-epoch "${STEPS_PER_EPOCH}" \
  --batch-size "${BATCH_SIZE}" \
  --seq-len "${SEQ_LEN}" \
  --lr "${LR}" \
  --dt-min "${DT_MIN}" \
  --dt-max "${DT_MAX}" \
  --dt-alpha "${DT_ALPHA}" \
  --lambda-flow "${LAMBDA_FLOW}" \
  --lambda-pos "${LAMBDA_POS}" \
  --lambda-dt "${LAMBDA_DT}" \
  "${USE_FLOW_FLAG}" \
  "${USE_POS_FLAG}" \
  "${USE_DT_FLAG}" \
  --val-every "${VAL_EVERY}" \
  --val-samples "${VAL_SAMPLES}" \
  --val-plot-samples "${VAL_PLOT_SAMPLES}" \
  --val-max-steps "${VAL_MAX_STEPS}" \
  --center-min "${CENTER_MIN}" \
  --center-max "${CENTER_MAX}" \
  --center-distance-min "${CENTER_DISTANCE_MIN}" \
  --project "${PROJECT}" \
  --run-name "${RUN_NAME}" \
  --output-dir "${OUTPUT_DIR}"
|
||||||
Reference in New Issue
Block a user