From c15115edc426f457b296fd14bd1bed9a0f036cc3 Mon Sep 17 00:00:00 2001 From: gameloader Date: Wed, 21 Jan 2026 15:41:40 +0800 Subject: [PATCH] feat: add conditional AdaLNZero and two-target spheres sampling --- as_mamba.py | 221 ++++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 163 insertions(+), 58 deletions(-) diff --git a/as_mamba.py b/as_mamba.py index f3c5ba1..8a053fa 100644 --- a/as_mamba.py +++ b/as_mamba.py @@ -59,6 +59,21 @@ class TrainConfig: val_max_steps: int = 0 +class AdaLNZero(nn.Module): + def __init__(self, d_model: int) -> None: + super().__init__() + self.norm = RMSNorm(d_model) + self.mod = nn.Linear(d_model, 2 * d_model) + nn.init.zeros_(self.mod.weight) + nn.init.zeros_(self.mod.bias) + + def forward(self, x: Tensor, cond: Tensor) -> Tensor: + x = self.norm(x) + params = self.mod(cond).unsqueeze(1) + scale, shift = params.chunk(2, dim=-1) + return x * (1 + scale) + shift + + class Mamba2Backbone(nn.Module): def __init__(self, args: Mamba2Config, use_residual: bool = True) -> None: super().__init__() @@ -69,7 +84,7 @@ class Mamba2Backbone(nn.Module): nn.ModuleDict( dict( mixer=Mamba2(args), - norm=RMSNorm(args.d_model), + adaln=AdaLNZero(args.d_model), ) ) for _ in range(args.n_layer) @@ -78,13 +93,19 @@ class Mamba2Backbone(nn.Module): self.norm_f = RMSNorm(args.d_model) def forward( - self, x: Tensor, h: Optional[list[InferenceCache]] = None + self, + x: Tensor, + cond: Optional[Tensor] = None, + h: Optional[list[InferenceCache]] = None, ) -> tuple[Tensor, list[InferenceCache]]: if h is None: h = [None for _ in range(self.args.n_layer)] + if cond is None: + cond = torch.zeros(x.shape[0], x.shape[-1], device=x.device, dtype=x.dtype) for i, layer in enumerate(self.layers): - y, h[i] = layer["mixer"](layer["norm"](x), h[i]) + x_mod = layer["adaln"](x, cond) + y, h[i] = layer["mixer"](x_mod, h[i]) x = x + y if self.use_residual else y x = self.norm_f(x) @@ -109,6 +130,7 @@ class ASMamba(nn.Module): ) self.backbone = Mamba2Backbone(args, use_residual=cfg.use_residual) self.input_proj = nn.Linear(3, cfg.d_model) + self.cond_emb = nn.Embedding(2, cfg.d_model) self.delta_head = nn.Linear(cfg.d_model, 3) self.dt_head = nn.Sequential( nn.Linear(cfg.d_model, cfg.d_model), @@ -117,19 +139,20 @@ class ASMamba(nn.Module): ) def forward( - self, x: Tensor, h: Optional[list[InferenceCache]] = None + self, x: Tensor, cond: Tensor, h: Optional[list[InferenceCache]] = None ) -> tuple[Tensor, Tensor, list[InferenceCache]]: x_proj = self.input_proj(x) - feats, h = self.backbone(x_proj, h) + cond_vec = self.cond_emb(cond) + feats, h = self.backbone(x_proj, cond_vec, h) delta = self.delta_head(feats) dt_raw = self.dt_head(feats).squeeze(-1) dt = torch.clamp(F.softplus(dt_raw), min=self.dt_min, max=self.dt_max) return delta, dt, h def step( - self, x: Tensor, h: list[InferenceCache] + self, x: Tensor, cond: Tensor, h: list[InferenceCache] ) -> tuple[Tensor, Tensor, list[InferenceCache]]: - delta, dt, h = self.forward(x.unsqueeze(1), h) + delta, dt, h = self.forward(x.unsqueeze(1), cond, h) return delta[:, 0, :], dt[:, 0], h def init_cache(self, batch_size: int, device: torch.device) -> list[InferenceCache]: @@ -207,21 +230,42 @@ def sample_points_in_sphere( return center + direction * r -def sample_sphere_params(cfg: TrainConfig, device: torch.device) -> tuple[Tensor, Tensor]: - center_a = torch.empty(3, device=device).uniform_(cfg.center_min, cfg.center_max) - center_b = torch.empty(3, device=device).uniform_(cfg.center_min, cfg.center_max) - for _ in range(128): - if torch.norm(center_a - center_b) >= cfg.center_distance_min: - break - center_b = torch.empty(3, device=device).uniform_(cfg.center_min, cfg.center_max) - if torch.norm(center_a - center_b) < 1e-3: - center_b = center_b + torch.tensor([cfg.center_distance_min, 0.0, 0.0], device=device) +def sample_center(cfg: TrainConfig, device: torch.device) -> Tensor: + return torch.empty(3, device=device).uniform_(cfg.center_min, cfg.center_max) + + +def sample_center_far( + cfg: TrainConfig, device: torch.device, refs: list[Tensor] +) -> Tensor: + center = sample_center(cfg, device) + for _ in range(256): + if all(torch.norm(center - ref) >= cfg.center_distance_min for ref in refs): + return center + center = sample_center(cfg, device) + return center + + +def sample_spheres_params( + cfg: TrainConfig, device: torch.device +) -> tuple[tuple[Tensor, Tensor], tuple[Tensor, Tensor], tuple[Tensor, Tensor]]: + center_a = sample_center(cfg, device) + center_b0 = sample_center_far(cfg, device, [center_a]) + center_b1 = sample_center_far(cfg, device, [center_a, center_b0]) + if torch.norm(center_a - center_b0) < 1e-3: + center_b0 = center_b0 + torch.tensor( + [cfg.center_distance_min, 0.0, 0.0], device=device + ) + if torch.norm(center_a - center_b1) < 1e-3: + center_b1 = center_b1 + torch.tensor( + [-cfg.center_distance_min, 0.0, 0.0], device=device + ) radius_a = float(torch.empty(1).uniform_(cfg.radius_min, cfg.radius_max).item()) - radius_b = float(torch.empty(1).uniform_(cfg.radius_min, cfg.radius_max).item()) - return (center_a, torch.tensor(radius_a, device=device)), ( - center_b, - torch.tensor(radius_b, device=device), - ) + radius_b0 = float(torch.empty(1).uniform_(cfg.radius_min, cfg.radius_max).item()) + radius_b1 = float(torch.empty(1).uniform_(cfg.radius_min, cfg.radius_max).item()) + sphere_a = (center_a, torch.tensor(radius_a, device=device)) + sphere_b0 = (center_b0, torch.tensor(radius_b0, device=device)) + sphere_b1 = (center_b1, torch.tensor(radius_b1, device=device)) + return sphere_a, sphere_b0, sphere_b1 def sample_time_sequence(cfg: TrainConfig, batch_size: int, device: torch.device) -> Tensor: @@ -246,19 +290,26 @@ def sample_time_sequence(cfg: TrainConfig, batch_size: int, device: torch.device def sample_batch( cfg: TrainConfig, sphere_a: tuple[Tensor, Tensor], - sphere_b: tuple[Tensor, Tensor], + sphere_b0: tuple[Tensor, Tensor], + sphere_b1: tuple[Tensor, Tensor], device: torch.device, -) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: +) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: center_a, radius_a = sphere_a - center_b, radius_b = sphere_b x0 = sample_points_in_sphere(center_a, float(radius_a.item()), cfg.batch_size, device) - x1 = sample_points_in_sphere(center_b, float(radius_b.item()), cfg.batch_size, device) + cond = torch.randint(0, 2, (cfg.batch_size,), device=device, dtype=torch.long) + x1_0 = sample_points_in_sphere( + sphere_b0[0], float(sphere_b0[1].item()), cfg.batch_size, device + ) + x1_1 = sample_points_in_sphere( + sphere_b1[0], float(sphere_b1[1].item()), cfg.batch_size, device + ) + x1 = torch.where(cond[:, None] == 0, x1_0, x1_1) v_gt = x1 - x0 dt_seq = sample_time_sequence(cfg, cfg.batch_size, device) t_seq = torch.cumsum(dt_seq, dim=-1) t_seq = torch.cat([torch.zeros(cfg.batch_size, 1, device=device), t_seq[:, :-1]], dim=-1) x_seq = x0[:, None, :] + t_seq[:, :, None] * v_gt[:, None, :] - return x0, x1, x_seq, t_seq, dt_seq + return x0, x1, x_seq, t_seq, dt_seq, cond def compute_losses( @@ -294,30 +345,50 @@ def validate( model: ASMamba, cfg: TrainConfig, sphere_a: tuple[Tensor, Tensor], - sphere_b: tuple[Tensor, Tensor], + sphere_b0: tuple[Tensor, Tensor], + sphere_b1: tuple[Tensor, Tensor], device: torch.device, logger: SwanLogger, step: int, ) -> None: model.eval() - center_b, radius_b = sphere_b + center_b0, radius_b0 = sphere_b0 + center_b1, radius_b1 = sphere_b1 max_steps = cfg.seq_len if cfg.val_max_steps <= 0 else cfg.val_max_steps with torch.no_grad(): x0 = sample_points_in_sphere( sphere_a[0], float(sphere_a[1].item()), cfg.val_samples, device ) - traj = rollout_trajectory(model, x0, max_steps=max_steps) + cond = torch.randint(0, 2, (cfg.val_samples,), device=device, dtype=torch.long) + traj = rollout_trajectory(model, x0, cond, max_steps=max_steps) x_final = traj[:, -1, :] - center_b_cpu = center_b.detach().cpu() - radius_b_cpu = radius_b.detach().cpu() - dist = torch.linalg.norm(x_final - center_b_cpu, dim=-1) - inside = dist <= radius_b_cpu + center_b0_cpu = center_b0.detach().cpu() + center_b1_cpu = center_b1.detach().cpu() + radius_b0_cpu = radius_b0.detach().cpu() + radius_b1_cpu = radius_b1.detach().cpu() + cond_cpu = cond.detach().cpu() + target_center = torch.where( + cond_cpu[:, None] == 0, center_b0_cpu.unsqueeze(0), center_b1_cpu.unsqueeze(0) + ) + target_radius = torch.where(cond_cpu == 0, radius_b0_cpu, radius_b1_cpu) + dist = torch.linalg.norm(x_final - target_center, dim=-1) + inside = dist <= target_radius + mask0 = cond_cpu == 0 + mask1 = cond_cpu == 1 + inside0 = inside[mask0] + inside1 = inside[mask1] + ratio0 = float(inside0.float().mean().item()) if inside0.numel() > 0 else 0.0 + ratio1 = float(inside1.float().mean().item()) if inside1.numel() > 0 else 0.0 logger.log( { "val/inside_ratio": float(inside.float().mean().item()), + "val/inside_ratio_c0": ratio0, + "val/inside_ratio_c1": ratio1, + "val/cond0_count": float(mask0.float().sum().item()), + "val/cond1_count": float(mask1.float().sum().item()), "val/inside_count": float(inside.float().sum().item()), "val/final_dist_mean": float(dist.mean().item()), "val/final_dist_min": float(dist.min().item()), @@ -334,11 +405,14 @@ def validate( return indices = torch.linspace(0, traj.shape[0] - 1, steps=count).long() traj_plot = traj[indices] + cond_plot = cond_cpu[indices] save_path = Path(cfg.output_dir) / f"val_traj_step_{step:06d}.png" - plot_trajectories( + plot_trajectories_cond( traj_plot, + cond_plot, sphere_a, - sphere_b, + sphere_b0, + sphere_b1, save_path, title=f"Validation Trajectories (step {step})", ) @@ -353,7 +427,9 @@ def validate( model.train() -def train(cfg: TrainConfig) -> tuple[ASMamba, tuple[Tensor, Tensor], tuple[Tensor, Tensor]]: +def train( + cfg: TrainConfig, +) -> tuple[ASMamba, tuple[Tensor, Tensor], tuple[Tensor, Tensor], tuple[Tensor, Tensor]]: device = torch.device(cfg.device if torch.cuda.is_available() else "cpu") set_seed(cfg.seed) output_dir = Path(cfg.output_dir) @@ -363,21 +439,30 @@ def train(cfg: TrainConfig) -> tuple[ASMamba, tuple[Tensor, Tensor], tuple[Tenso optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay) logger = SwanLogger(cfg) - sphere_a, sphere_b = sample_sphere_params(cfg, device) + sphere_a, sphere_b0, sphere_b1 = sample_spheres_params(cfg, device) center_a, radius_a = sphere_a - center_b, radius_b = sphere_b - center_dist = torch.norm(center_a - center_b).item() + center_b0, radius_b0 = sphere_b0 + center_b1, radius_b1 = sphere_b1 + dist_a_b0 = torch.norm(center_a - center_b0).item() + dist_a_b1 = torch.norm(center_a - center_b1).item() + dist_b0_b1 = torch.norm(center_b0 - center_b1).item() logger.log( { "sphere_a/radius": float(radius_a.item()), - "sphere_b/radius": float(radius_b.item()), "sphere_a/center_x": float(center_a[0].item()), "sphere_a/center_y": float(center_a[1].item()), "sphere_a/center_z": float(center_a[2].item()), - "sphere_b/center_x": float(center_b[0].item()), - "sphere_b/center_y": float(center_b[1].item()), - "sphere_b/center_z": float(center_b[2].item()), - "sphere/center_dist": float(center_dist), + "sphere_b0/radius": float(radius_b0.item()), + "sphere_b0/center_x": float(center_b0[0].item()), + "sphere_b0/center_y": float(center_b0[1].item()), + "sphere_b0/center_z": float(center_b0[2].item()), + "sphere_b1/radius": float(radius_b1.item()), + "sphere_b1/center_x": float(center_b1[0].item()), + "sphere_b1/center_y": float(center_b1[1].item()), + "sphere_b1/center_z": float(center_b1[2].item()), + "sphere/dist_a_b0": float(dist_a_b0), + "sphere/dist_a_b1": float(dist_a_b1), + "sphere/dist_b0_b1": float(dist_b0_b1), } ) @@ -386,10 +471,12 @@ def train(cfg: TrainConfig) -> tuple[ASMamba, tuple[Tensor, Tensor], tuple[Tenso model.train() for _ in range(cfg.steps_per_epoch): - x0, x1, x_seq, t_seq, dt_seq = sample_batch(cfg, sphere_a, sphere_b, device) + x0, x1, x_seq, t_seq, dt_seq, cond = sample_batch( + cfg, sphere_a, sphere_b0, sphere_b1, device + ) v_gt = x1 - x0 - delta, dt, _ = model(x_seq) + delta, dt, _ = model(x_seq, cond) losses = compute_losses( delta=delta, @@ -449,7 +536,16 @@ def train(cfg: TrainConfig) -> tuple[ASMamba, tuple[Tensor, Tensor], tuple[Tenso ) if cfg.val_every > 0 and global_step > 0 and global_step % cfg.val_every == 0: - validate(model, cfg, sphere_a, sphere_b, device, logger, global_step) + validate( + model, + cfg, + sphere_a, + sphere_b0, + sphere_b1, + device, + logger, + global_step, + ) dt_hist_path = Path(cfg.output_dir) / f"dt_hist_step_{global_step:06d}.png" plot_dt_hist( dt, @@ -466,12 +562,13 @@ def train(cfg: TrainConfig) -> tuple[ASMamba, tuple[Tensor, Tensor], tuple[Tenso global_step += 1 logger.finish() - return model, sphere_a, sphere_b + return model, sphere_a, sphere_b0, sphere_b1 def rollout_trajectory( model: ASMamba, x0: Tensor, + cond: Tensor, max_steps: int, ) -> Tensor: device = x0.device @@ -483,7 +580,7 @@ def rollout_trajectory( with torch.no_grad(): for _ in range(max_steps): - delta, dt, h = model.step(x, h) + delta, dt, h = model.step(x, cond, h) dt = torch.clamp(dt, min=model.dt_min, max=model.dt_max) remaining = 1.0 - total_time overshoot = dt > remaining @@ -512,10 +609,12 @@ def sphere_wireframe( return x, y, z -def plot_trajectories( +def plot_trajectories_cond( traj: Tensor, + cond: Tensor, sphere_a: tuple[Tensor, Tensor], - sphere_b: tuple[Tensor, Tensor], + sphere_b0: tuple[Tensor, Tensor], + sphere_b1: tuple[Tensor, Tensor], save_path: Path, title: str = "AS-Mamba Trajectories", ) -> None: @@ -523,16 +622,18 @@ def plot_trajectories( if traj.dim() == 2: traj = traj.unsqueeze(0) traj_np = traj.numpy() + cond_np = cond.detach().cpu().numpy() fig = plt.figure(figsize=(7, 6)) ax = fig.add_subplot(111, projection="3d") for i in range(traj_np.shape[0]): + color = "tab:green" if cond_np[i] == 0 else "tab:orange" ax.plot( traj_np[i, :, 0], traj_np[i, :, 1], traj_np[i, :, 2], - color="green", + color=color, alpha=0.6, ) @@ -542,11 +643,14 @@ def plot_trajectories( ax.scatter(ends[:, 0], ends[:, 1], ends[:, 2], color="red", s=20, label="End") center_a, radius_a = sphere_a - center_b, radius_b = sphere_b + center_b0, radius_b0 = sphere_b0 + center_b1, radius_b1 = sphere_b1 x_a, y_a, z_a = sphere_wireframe(center_a, float(radius_a.item())) - x_b, y_b, z_b = sphere_wireframe(center_b, float(radius_b.item())) + x_b0, y_b0, z_b0 = sphere_wireframe(center_b0, float(radius_b0.item())) + x_b1, y_b1, z_b1 = sphere_wireframe(center_b1, float(radius_b1.item())) ax.plot_wireframe(x_a, y_a, z_a, color="blue", alpha=0.15, linewidth=0.5) - ax.plot_wireframe(x_b, y_b, z_b, color="red", alpha=0.15, linewidth=0.5) + ax.plot_wireframe(x_b0, y_b0, z_b0, color="tab:green", alpha=0.2, linewidth=0.5) + ax.plot_wireframe(x_b1, y_b1, z_b1, color="tab:orange", alpha=0.2, linewidth=0.5) ax.set_title(title) ax.set_xlabel("X") @@ -580,7 +684,7 @@ def plot_dt_hist( def run_training_and_plot(cfg: TrainConfig) -> Path: - model, sphere_a, sphere_b = train(cfg) + model, sphere_a, sphere_b0, sphere_b1 = train(cfg) device = next(model.parameters()).device plot_samples = max(1, cfg.val_plot_samples) @@ -588,8 +692,9 @@ def run_training_and_plot(cfg: TrainConfig) -> Path: sphere_a[0], float(sphere_a[1].item()), plot_samples, device ) max_steps = cfg.seq_len if cfg.val_max_steps <= 0 else cfg.val_max_steps - traj = rollout_trajectory(model, x0, max_steps=max_steps) + cond = torch.randint(0, 2, (plot_samples,), device=device, dtype=torch.long) + traj = rollout_trajectory(model, x0, cond, max_steps=max_steps) output_dir = Path(cfg.output_dir) save_path = output_dir / "as_mamba_trajectory.png" - plot_trajectories(traj, sphere_a, sphere_b, save_path) + plot_trajectories_cond(traj, cond, sphere_a, sphere_b0, sphere_b1, save_path) return save_path