feat: add conditional AdaLNZero and two-target spheres sampling

This commit is contained in:
gameloader
2026-01-21 15:41:40 +08:00
parent cac3236f9d
commit c15115edc4

View File

@@ -59,6 +59,21 @@ class TrainConfig:
val_max_steps: int = 0 val_max_steps: int = 0
class AdaLNZero(nn.Module):
    """Adaptive layer norm whose modulation starts at zero.

    The input is normalised with ``RMSNorm`` and then modulated by a scale
    and a shift that a single linear layer predicts from a conditioning
    tensor. That layer's weight and bias are zero-initialised, so right
    after construction the module returns exactly ``norm(x)``.
    """

    def __init__(self, d_model: int) -> None:
        super().__init__()
        self.norm = RMSNorm(d_model)
        # One projection produces both modulation terms: [scale | shift].
        self.mod = nn.Linear(d_model, 2 * d_model)
        nn.init.zeros_(self.mod.weight)
        nn.init.zeros_(self.mod.bias)

    def forward(self, x: Tensor, cond: Tensor) -> Tensor:
        """Return ``norm(x) * (1 + scale) + shift``.

        Args:
            x: Features whose last dimension is ``d_model``
                (callers in this file pass ``(batch, seq, d_model)``).
            cond: Conditioning features whose last dimension is ``d_model``.
                Either one vector per sample (rank ``x.dim() - 1``, broadcast
                over the remaining axes) or a tensor of the same rank as
                ``x`` for per-position conditioning.

        Returns:
            The modulated tensor, same shape as ``x``.
        """
        x = self.norm(x)
        params = self.mod(cond)
        # Add singleton axes after the batch axis only when ``cond`` has a
        # lower rank than ``x``. An unconditional ``unsqueeze(1)`` would make
        # a same-rank ``cond`` broadcast into an extra leading dimension.
        while params.dim() < x.dim():
            params = params.unsqueeze(1)
        scale, shift = params.chunk(2, dim=-1)
        return x * (1 + scale) + shift
class Mamba2Backbone(nn.Module): class Mamba2Backbone(nn.Module):
def __init__(self, args: Mamba2Config, use_residual: bool = True) -> None: def __init__(self, args: Mamba2Config, use_residual: bool = True) -> None:
super().__init__() super().__init__()
@@ -69,7 +84,7 @@ class Mamba2Backbone(nn.Module):
nn.ModuleDict( nn.ModuleDict(
dict( dict(
mixer=Mamba2(args), mixer=Mamba2(args),
norm=RMSNorm(args.d_model), adaln=AdaLNZero(args.d_model),
) )
) )
for _ in range(args.n_layer) for _ in range(args.n_layer)
@@ -78,13 +93,19 @@ class Mamba2Backbone(nn.Module):
self.norm_f = RMSNorm(args.d_model) self.norm_f = RMSNorm(args.d_model)
def forward( def forward(
self, x: Tensor, h: Optional[list[InferenceCache]] = None self,
x: Tensor,
cond: Optional[Tensor] = None,
h: Optional[list[InferenceCache]] = None,
) -> tuple[Tensor, list[InferenceCache]]: ) -> tuple[Tensor, list[InferenceCache]]:
if h is None: if h is None:
h = [None for _ in range(self.args.n_layer)] h = [None for _ in range(self.args.n_layer)]
if cond is None:
cond = torch.zeros(x.shape[0], x.shape[-1], device=x.device, dtype=x.dtype)
for i, layer in enumerate(self.layers): for i, layer in enumerate(self.layers):
y, h[i] = layer["mixer"](layer["norm"](x), h[i]) x_mod = layer["adaln"](x, cond)
y, h[i] = layer["mixer"](x_mod, h[i])
x = x + y if self.use_residual else y x = x + y if self.use_residual else y
x = self.norm_f(x) x = self.norm_f(x)
@@ -109,6 +130,7 @@ class ASMamba(nn.Module):
) )
self.backbone = Mamba2Backbone(args, use_residual=cfg.use_residual) self.backbone = Mamba2Backbone(args, use_residual=cfg.use_residual)
self.input_proj = nn.Linear(3, cfg.d_model) self.input_proj = nn.Linear(3, cfg.d_model)
self.cond_emb = nn.Embedding(2, cfg.d_model)
self.delta_head = nn.Linear(cfg.d_model, 3) self.delta_head = nn.Linear(cfg.d_model, 3)
self.dt_head = nn.Sequential( self.dt_head = nn.Sequential(
nn.Linear(cfg.d_model, cfg.d_model), nn.Linear(cfg.d_model, cfg.d_model),
@@ -117,19 +139,20 @@ class ASMamba(nn.Module):
) )
def forward(
    self, x: Tensor, cond: Tensor, h: Optional[list[InferenceCache]] = None
) -> tuple[Tensor, Tensor, list[InferenceCache]]:
    """Predict a displacement and a step size for every input position.

    Args:
        x: Input points; the last dimension must be 3 to match
            ``input_proj`` (presumably ``(batch, seq, 3)`` - confirm
            against callers).
        cond: Integer class indices looked up in ``cond_emb`` (an
            embedding with two entries), one index per sample.
        h: Optional per-layer inference caches forwarded to the backbone.

    Returns:
        ``(delta, dt, h)``: the ``delta_head`` output, the step size
        (softplus of the ``dt_head`` output, clamped to
        ``[dt_min, dt_max]``), and the caches returned by the backbone.
    """
    x_proj = self.input_proj(x)
    cond_vec = self.cond_emb(cond)
    feats, h = self.backbone(x_proj, cond_vec, h)
    delta = self.delta_head(feats)
    # dt_head ends in a size-1 feature axis; drop it before the positivity
    # transform and the clamp.
    dt_raw = self.dt_head(feats).squeeze(-1)
    dt = torch.clamp(F.softplus(dt_raw), min=self.dt_min, max=self.dt_max)
    return delta, dt, h
def step(
    self, x: Tensor, cond: Tensor, h: list[InferenceCache]
) -> tuple[Tensor, Tensor, list[InferenceCache]]:
    """Run a single step on one position per sample.

    ``x`` gets a length-1 axis inserted at position 1, is passed through
    ``forward`` together with ``cond`` and the caches ``h``, and that axis
    is removed again from the outputs.

    Returns:
        ``(delta, dt, h)`` with ``delta`` and ``dt`` indexed at position 0
        of the inserted axis, plus the caches returned by ``forward``.
    """
    delta, dt, h = self.forward(x.unsqueeze(1), cond, h)
    return delta[:, 0, :], dt[:, 0], h
def init_cache(self, batch_size: int, device: torch.device) -> list[InferenceCache]: def init_cache(self, batch_size: int, device: torch.device) -> list[InferenceCache]:
@@ -207,21 +230,42 @@ def sample_points_in_sphere(
return center + direction * r return center + direction * r
def sample_center(cfg: TrainConfig, device: torch.device) -> Tensor:
    """Draw one 3-D centre, each coordinate uniform between
    ``cfg.center_min`` and ``cfg.center_max``."""
    return torch.empty(3, device=device).uniform_(cfg.center_min, cfg.center_max)


def sample_center_far(
    cfg: TrainConfig, device: torch.device, refs: list[Tensor], max_tries: int = 256
) -> Tensor:
    """Rejection-sample a centre at least ``cfg.center_distance_min`` from ``refs``.

    Args:
        cfg: Supplies the sampling box and ``center_distance_min``.
        device: Device the centre is created on.
        refs: Centres the new one has to stay away from.
        max_tries: Number of re-draws allowed after the first candidate.

    Returns:
        The first candidate that satisfies the distance constraint. If no
        candidate does, the one whose nearest reference is farthest away
        is returned (best effort; the constraint may then be violated).
    """
    center = sample_center(cfg, device)
    if not refs:
        return center
    best, best_gap = center, -1.0
    for attempt in range(max_tries + 1):
        gap = min(float(torch.norm(center - ref)) for ref in refs)
        if gap >= cfg.center_distance_min:
            return center
        # Remember the least-bad candidate so that running out of tries
        # does not hand back an arbitrary, never-checked draw.
        if gap > best_gap:
            best, best_gap = center, gap
        if attempt < max_tries:
            center = sample_center(cfg, device)
    return best


def _shift_if_coincident(
    center: Tensor, refs: list[Tensor], offset: Tensor, eps: float = 1e-3
) -> Tensor:
    """Return ``center + offset`` if ``center`` is within ``eps`` of any of ``refs``."""
    if any(float(torch.norm(center - ref)) < eps for ref in refs):
        return center + offset
    return center


def sample_spheres_params(
    cfg: TrainConfig, device: torch.device
) -> tuple[tuple[Tensor, Tensor], tuple[Tensor, Tensor], tuple[Tensor, Tensor]]:
    """Sample the source sphere and the two target spheres.

    Returns:
        ``(sphere_a, sphere_b0, sphere_b1)``; each is ``(center, radius)``
        with ``center`` a 3-vector and ``radius`` a 0-dim tensor, both on
        ``device``.
    """
    center_a = sample_center(cfg, device)
    center_b0 = sample_center_far(cfg, device, [center_a])
    center_b1 = sample_center_far(cfg, device, [center_a, center_b0])

    # Last-resort separation for (near-)coincident centres: push the two
    # targets in opposite directions along x. b1 is compared with b0 as
    # well as with a, so the two targets cannot end up on top of each other.
    push = torch.tensor([cfg.center_distance_min, 0.0, 0.0], device=device)
    center_b0 = _shift_if_coincident(center_b0, [center_a], push)
    center_b1 = _shift_if_coincident(center_b1, [center_a, center_b0], -push)

    # Radii are drawn on the CPU generator, in the order a, b0, b1.
    radius_a, radius_b0, radius_b1 = (
        float(torch.empty(1).uniform_(cfg.radius_min, cfg.radius_max).item())
        for _ in range(3)
    )
    sphere_a = (center_a, torch.tensor(radius_a, device=device))
    sphere_b0 = (center_b0, torch.tensor(radius_b0, device=device))
    sphere_b1 = (center_b1, torch.tensor(radius_b1, device=device))
    return sphere_a, sphere_b0, sphere_b1
def sample_time_sequence(cfg: TrainConfig, batch_size: int, device: torch.device) -> Tensor: def sample_time_sequence(cfg: TrainConfig, batch_size: int, device: torch.device) -> Tensor:
@@ -246,19 +290,26 @@ def sample_time_sequence(cfg: TrainConfig, batch_size: int, device: torch.device
def sample_batch(
    cfg: TrainConfig,
    sphere_a: tuple[Tensor, Tensor],
    sphere_b0: tuple[Tensor, Tensor],
    sphere_b1: tuple[Tensor, Tensor],
    device: torch.device,
) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
    """Build one training batch of straight-line paths towards a random target.

    Start points come from ``sphere_a``. Every sample gets a random label
    in ``{0, 1}`` that selects its end point from ``sphere_b0`` (label 0)
    or ``sphere_b1`` (label 1).

    Returns:
        ``(x0, x1, x_seq, t_seq, dt_seq, cond)`` where ``t_seq`` holds the
        time at the start of each step (exclusive cumulative sum of
        ``dt_seq``) and ``x_seq`` is ``x0 + t * (x1 - x0)`` at those times.
    """
    n = cfg.batch_size

    def _draw(sphere: tuple[Tensor, Tensor]) -> Tensor:
        center, radius = sphere
        return sample_points_in_sphere(center, float(radius.item()), n, device)

    x0 = _draw(sphere_a)
    cond = torch.randint(0, 2, (n,), device=device, dtype=torch.long)
    # Candidates are drawn from both targets for every sample; the label
    # then picks one of them per row.
    targets_0 = _draw(sphere_b0)
    targets_1 = _draw(sphere_b1)
    x1 = torch.where(cond[:, None] == 0, targets_0, targets_1)
    velocity = x1 - x0

    dt_seq = sample_time_sequence(cfg, n, device)
    elapsed = torch.cumsum(dt_seq, dim=-1)
    # Shift right by one step so each entry is the time before its own dt.
    start = torch.zeros(n, 1, device=device)
    t_seq = torch.cat([start, elapsed[:, :-1]], dim=-1)
    x_seq = x0[:, None, :] + t_seq[:, :, None] * velocity[:, None, :]
    return x0, x1, x_seq, t_seq, dt_seq, cond
def compute_losses( def compute_losses(
@@ -294,30 +345,50 @@ def validate(
model: ASMamba, model: ASMamba,
cfg: TrainConfig, cfg: TrainConfig,
sphere_a: tuple[Tensor, Tensor], sphere_a: tuple[Tensor, Tensor],
sphere_b: tuple[Tensor, Tensor], sphere_b0: tuple[Tensor, Tensor],
sphere_b1: tuple[Tensor, Tensor],
device: torch.device, device: torch.device,
logger: SwanLogger, logger: SwanLogger,
step: int, step: int,
) -> None: ) -> None:
model.eval() model.eval()
center_b, radius_b = sphere_b center_b0, radius_b0 = sphere_b0
center_b1, radius_b1 = sphere_b1
max_steps = cfg.seq_len if cfg.val_max_steps <= 0 else cfg.val_max_steps max_steps = cfg.seq_len if cfg.val_max_steps <= 0 else cfg.val_max_steps
with torch.no_grad(): with torch.no_grad():
x0 = sample_points_in_sphere( x0 = sample_points_in_sphere(
sphere_a[0], float(sphere_a[1].item()), cfg.val_samples, device sphere_a[0], float(sphere_a[1].item()), cfg.val_samples, device
) )
traj = rollout_trajectory(model, x0, max_steps=max_steps) cond = torch.randint(0, 2, (cfg.val_samples,), device=device, dtype=torch.long)
traj = rollout_trajectory(model, x0, cond, max_steps=max_steps)
x_final = traj[:, -1, :] x_final = traj[:, -1, :]
center_b_cpu = center_b.detach().cpu() center_b0_cpu = center_b0.detach().cpu()
radius_b_cpu = radius_b.detach().cpu() center_b1_cpu = center_b1.detach().cpu()
dist = torch.linalg.norm(x_final - center_b_cpu, dim=-1) radius_b0_cpu = radius_b0.detach().cpu()
inside = dist <= radius_b_cpu radius_b1_cpu = radius_b1.detach().cpu()
cond_cpu = cond.detach().cpu()
target_center = torch.where(
cond_cpu[:, None] == 0, center_b0_cpu.unsqueeze(0), center_b1_cpu.unsqueeze(0)
)
target_radius = torch.where(cond_cpu == 0, radius_b0_cpu, radius_b1_cpu)
dist = torch.linalg.norm(x_final - target_center, dim=-1)
inside = dist <= target_radius
mask0 = cond_cpu == 0
mask1 = cond_cpu == 1
inside0 = inside[mask0]
inside1 = inside[mask1]
ratio0 = float(inside0.float().mean().item()) if inside0.numel() > 0 else 0.0
ratio1 = float(inside1.float().mean().item()) if inside1.numel() > 0 else 0.0
logger.log( logger.log(
{ {
"val/inside_ratio": float(inside.float().mean().item()), "val/inside_ratio": float(inside.float().mean().item()),
"val/inside_ratio_c0": ratio0,
"val/inside_ratio_c1": ratio1,
"val/cond0_count": float(mask0.float().sum().item()),
"val/cond1_count": float(mask1.float().sum().item()),
"val/inside_count": float(inside.float().sum().item()), "val/inside_count": float(inside.float().sum().item()),
"val/final_dist_mean": float(dist.mean().item()), "val/final_dist_mean": float(dist.mean().item()),
"val/final_dist_min": float(dist.min().item()), "val/final_dist_min": float(dist.min().item()),
@@ -334,11 +405,14 @@ def validate(
return return
indices = torch.linspace(0, traj.shape[0] - 1, steps=count).long() indices = torch.linspace(0, traj.shape[0] - 1, steps=count).long()
traj_plot = traj[indices] traj_plot = traj[indices]
cond_plot = cond_cpu[indices]
save_path = Path(cfg.output_dir) / f"val_traj_step_{step:06d}.png" save_path = Path(cfg.output_dir) / f"val_traj_step_{step:06d}.png"
plot_trajectories( plot_trajectories_cond(
traj_plot, traj_plot,
cond_plot,
sphere_a, sphere_a,
sphere_b, sphere_b0,
sphere_b1,
save_path, save_path,
title=f"Validation Trajectories (step {step})", title=f"Validation Trajectories (step {step})",
) )
@@ -353,7 +427,9 @@ def validate(
model.train() model.train()
def train(cfg: TrainConfig) -> tuple[ASMamba, tuple[Tensor, Tensor], tuple[Tensor, Tensor]]: def train(
cfg: TrainConfig,
) -> tuple[ASMamba, tuple[Tensor, Tensor], tuple[Tensor, Tensor], tuple[Tensor, Tensor]]:
device = torch.device(cfg.device if torch.cuda.is_available() else "cpu") device = torch.device(cfg.device if torch.cuda.is_available() else "cpu")
set_seed(cfg.seed) set_seed(cfg.seed)
output_dir = Path(cfg.output_dir) output_dir = Path(cfg.output_dir)
@@ -363,21 +439,30 @@ def train(cfg: TrainConfig) -> tuple[ASMamba, tuple[Tensor, Tensor], tuple[Tenso
optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay) optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
logger = SwanLogger(cfg) logger = SwanLogger(cfg)
sphere_a, sphere_b = sample_sphere_params(cfg, device) sphere_a, sphere_b0, sphere_b1 = sample_spheres_params(cfg, device)
center_a, radius_a = sphere_a center_a, radius_a = sphere_a
center_b, radius_b = sphere_b center_b0, radius_b0 = sphere_b0
center_dist = torch.norm(center_a - center_b).item() center_b1, radius_b1 = sphere_b1
dist_a_b0 = torch.norm(center_a - center_b0).item()
dist_a_b1 = torch.norm(center_a - center_b1).item()
dist_b0_b1 = torch.norm(center_b0 - center_b1).item()
logger.log( logger.log(
{ {
"sphere_a/radius": float(radius_a.item()), "sphere_a/radius": float(radius_a.item()),
"sphere_b/radius": float(radius_b.item()),
"sphere_a/center_x": float(center_a[0].item()), "sphere_a/center_x": float(center_a[0].item()),
"sphere_a/center_y": float(center_a[1].item()), "sphere_a/center_y": float(center_a[1].item()),
"sphere_a/center_z": float(center_a[2].item()), "sphere_a/center_z": float(center_a[2].item()),
"sphere_b/center_x": float(center_b[0].item()), "sphere_b0/radius": float(radius_b0.item()),
"sphere_b/center_y": float(center_b[1].item()), "sphere_b0/center_x": float(center_b0[0].item()),
"sphere_b/center_z": float(center_b[2].item()), "sphere_b0/center_y": float(center_b0[1].item()),
"sphere/center_dist": float(center_dist), "sphere_b0/center_z": float(center_b0[2].item()),
"sphere_b1/radius": float(radius_b1.item()),
"sphere_b1/center_x": float(center_b1[0].item()),
"sphere_b1/center_y": float(center_b1[1].item()),
"sphere_b1/center_z": float(center_b1[2].item()),
"sphere/dist_a_b0": float(dist_a_b0),
"sphere/dist_a_b1": float(dist_a_b1),
"sphere/dist_b0_b1": float(dist_b0_b1),
} }
) )
@@ -386,10 +471,12 @@ def train(cfg: TrainConfig) -> tuple[ASMamba, tuple[Tensor, Tensor], tuple[Tenso
model.train() model.train()
for _ in range(cfg.steps_per_epoch): for _ in range(cfg.steps_per_epoch):
x0, x1, x_seq, t_seq, dt_seq = sample_batch(cfg, sphere_a, sphere_b, device) x0, x1, x_seq, t_seq, dt_seq, cond = sample_batch(
cfg, sphere_a, sphere_b0, sphere_b1, device
)
v_gt = x1 - x0 v_gt = x1 - x0
delta, dt, _ = model(x_seq) delta, dt, _ = model(x_seq, cond)
losses = compute_losses( losses = compute_losses(
delta=delta, delta=delta,
@@ -449,7 +536,16 @@ def train(cfg: TrainConfig) -> tuple[ASMamba, tuple[Tensor, Tensor], tuple[Tenso
) )
if cfg.val_every > 0 and global_step > 0 and global_step % cfg.val_every == 0: if cfg.val_every > 0 and global_step > 0 and global_step % cfg.val_every == 0:
validate(model, cfg, sphere_a, sphere_b, device, logger, global_step) validate(
model,
cfg,
sphere_a,
sphere_b0,
sphere_b1,
device,
logger,
global_step,
)
dt_hist_path = Path(cfg.output_dir) / f"dt_hist_step_{global_step:06d}.png" dt_hist_path = Path(cfg.output_dir) / f"dt_hist_step_{global_step:06d}.png"
plot_dt_hist( plot_dt_hist(
dt, dt,
@@ -466,12 +562,13 @@ def train(cfg: TrainConfig) -> tuple[ASMamba, tuple[Tensor, Tensor], tuple[Tenso
global_step += 1 global_step += 1
logger.finish() logger.finish()
return model, sphere_a, sphere_b return model, sphere_a, sphere_b0, sphere_b1
def rollout_trajectory( def rollout_trajectory(
model: ASMamba, model: ASMamba,
x0: Tensor, x0: Tensor,
cond: Tensor,
max_steps: int, max_steps: int,
) -> Tensor: ) -> Tensor:
device = x0.device device = x0.device
@@ -483,7 +580,7 @@ def rollout_trajectory(
with torch.no_grad(): with torch.no_grad():
for _ in range(max_steps): for _ in range(max_steps):
delta, dt, h = model.step(x, h) delta, dt, h = model.step(x, cond, h)
dt = torch.clamp(dt, min=model.dt_min, max=model.dt_max) dt = torch.clamp(dt, min=model.dt_min, max=model.dt_max)
remaining = 1.0 - total_time remaining = 1.0 - total_time
overshoot = dt > remaining overshoot = dt > remaining
@@ -512,10 +609,12 @@ def sphere_wireframe(
return x, y, z return x, y, z
def plot_trajectories( def plot_trajectories_cond(
traj: Tensor, traj: Tensor,
cond: Tensor,
sphere_a: tuple[Tensor, Tensor], sphere_a: tuple[Tensor, Tensor],
sphere_b: tuple[Tensor, Tensor], sphere_b0: tuple[Tensor, Tensor],
sphere_b1: tuple[Tensor, Tensor],
save_path: Path, save_path: Path,
title: str = "AS-Mamba Trajectories", title: str = "AS-Mamba Trajectories",
) -> None: ) -> None:
@@ -523,16 +622,18 @@ def plot_trajectories(
if traj.dim() == 2: if traj.dim() == 2:
traj = traj.unsqueeze(0) traj = traj.unsqueeze(0)
traj_np = traj.numpy() traj_np = traj.numpy()
cond_np = cond.detach().cpu().numpy()
fig = plt.figure(figsize=(7, 6)) fig = plt.figure(figsize=(7, 6))
ax = fig.add_subplot(111, projection="3d") ax = fig.add_subplot(111, projection="3d")
for i in range(traj_np.shape[0]): for i in range(traj_np.shape[0]):
color = "tab:green" if cond_np[i] == 0 else "tab:orange"
ax.plot( ax.plot(
traj_np[i, :, 0], traj_np[i, :, 0],
traj_np[i, :, 1], traj_np[i, :, 1],
traj_np[i, :, 2], traj_np[i, :, 2],
color="green", color=color,
alpha=0.6, alpha=0.6,
) )
@@ -542,11 +643,14 @@ def plot_trajectories(
ax.scatter(ends[:, 0], ends[:, 1], ends[:, 2], color="red", s=20, label="End") ax.scatter(ends[:, 0], ends[:, 1], ends[:, 2], color="red", s=20, label="End")
center_a, radius_a = sphere_a center_a, radius_a = sphere_a
center_b, radius_b = sphere_b center_b0, radius_b0 = sphere_b0
center_b1, radius_b1 = sphere_b1
x_a, y_a, z_a = sphere_wireframe(center_a, float(radius_a.item())) x_a, y_a, z_a = sphere_wireframe(center_a, float(radius_a.item()))
x_b, y_b, z_b = sphere_wireframe(center_b, float(radius_b.item())) x_b0, y_b0, z_b0 = sphere_wireframe(center_b0, float(radius_b0.item()))
x_b1, y_b1, z_b1 = sphere_wireframe(center_b1, float(radius_b1.item()))
ax.plot_wireframe(x_a, y_a, z_a, color="blue", alpha=0.15, linewidth=0.5) ax.plot_wireframe(x_a, y_a, z_a, color="blue", alpha=0.15, linewidth=0.5)
ax.plot_wireframe(x_b, y_b, z_b, color="red", alpha=0.15, linewidth=0.5) ax.plot_wireframe(x_b0, y_b0, z_b0, color="tab:green", alpha=0.2, linewidth=0.5)
ax.plot_wireframe(x_b1, y_b1, z_b1, color="tab:orange", alpha=0.2, linewidth=0.5)
ax.set_title(title) ax.set_title(title)
ax.set_xlabel("X") ax.set_xlabel("X")
@@ -580,7 +684,7 @@ def plot_dt_hist(
def run_training_and_plot(cfg: TrainConfig) -> Path: def run_training_and_plot(cfg: TrainConfig) -> Path:
model, sphere_a, sphere_b = train(cfg) model, sphere_a, sphere_b0, sphere_b1 = train(cfg)
device = next(model.parameters()).device device = next(model.parameters()).device
plot_samples = max(1, cfg.val_plot_samples) plot_samples = max(1, cfg.val_plot_samples)
@@ -588,8 +692,9 @@ def run_training_and_plot(cfg: TrainConfig) -> Path:
sphere_a[0], float(sphere_a[1].item()), plot_samples, device sphere_a[0], float(sphere_a[1].item()), plot_samples, device
) )
max_steps = cfg.seq_len if cfg.val_max_steps <= 0 else cfg.val_max_steps max_steps = cfg.seq_len if cfg.val_max_steps <= 0 else cfg.val_max_steps
traj = rollout_trajectory(model, x0, max_steps=max_steps) cond = torch.randint(0, 2, (plot_samples,), device=device, dtype=torch.long)
traj = rollout_trajectory(model, x0, cond, max_steps=max_steps)
output_dir = Path(cfg.output_dir) output_dir = Path(cfg.output_dir)
save_path = output_dir / "as_mamba_trajectory.png" save_path = output_dir / "as_mamba_trajectory.png"
plot_trajectories(traj, sphere_a, sphere_b, save_path) plot_trajectories_cond(traj, cond, sphere_a, sphere_b0, sphere_b1, save_path)
return save_path return save_path