feat(mamba): add Mamba2 implementation
Add initial project structure including core Mamba2 logic, entry point, and uv-based dependency management.
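Dependencies are declared in pyproject.toml and are intended to be managed with uv (typically `uv sync`, then `uv run main.py`). For a quick programmatic smoke test of the new code, a minimal sketch (the small epoch/step counts below are illustrative overrides, not the defaults):

    from as_mamba import TrainConfig, run_training_and_plot

    cfg = TrainConfig(device="cpu", epochs=2, warmup_epochs=1, steps_per_epoch=5, val_every=0)
    plot_path = run_training_and_plot(cfg)  # writes outputs/as_mamba_trajectory.png
    print(plot_path)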
.gitignore (vendored, new file, 10 lines)
@@ -0,0 +1,10 @@
# Python-generated files
__pycache__/
*.py[oc]
build/
dist/
wheels/
*.egg-info

# Virtual environments
.venv
.python-version (new file, 1 line)
@@ -0,0 +1 @@
3.12
as_mamba.py (new file, 511 lines)
@@ -0,0 +1,511 @@
from __future__ import annotations

from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Optional

import matplotlib

matplotlib.use("Agg")

import numpy as np
import torch
import torch.nn.functional as F
from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401
from torch import Tensor, nn

from mamba2_minimal import InferenceCache, Mamba2, Mamba2Config, RMSNorm

@dataclass
class TrainConfig:
    seed: int = 42
    device: str = "cuda"
    batch_size: int = 128
    steps_per_epoch: int = 50
    epochs: int = 60
    warmup_epochs: int = 15
    seq_len: int = 20
    lr: float = 1e-3
    weight_decay: float = 1e-2
    dt_min: float = 1e-3
    dt_max: float = 0.06
    lambda_flow: float = 1.0
    lambda_pos: float = 1.0
    lambda_nfe: float = 0.05
    radius_min: float = 0.6
    radius_max: float = 1.4
    center_min: float = -6.0
    center_max: float = 6.0
    center_distance_min: float = 6.0
    d_model: int = 128
    n_layer: int = 4
    d_state: int = 64
    d_conv: int = 4
    expand: int = 2
    headdim: int = 32
    chunk_size: int = 1
    use_residual: bool = False
    output_dir: str = "outputs"
    project: str = "as-mamba"
    run_name: str = "sphere-to-sphere"
    val_every: int = 200
    val_samples: int = 256
    val_plot_samples: int = 16
    val_max_steps: int = 100

class Mamba2Backbone(nn.Module):
    def __init__(self, args: Mamba2Config, use_residual: bool = True) -> None:
        super().__init__()
        self.args = args
        self.use_residual = use_residual
        self.layers = nn.ModuleList(
            [
                nn.ModuleDict(
                    dict(
                        mixer=Mamba2(args),
                        norm=RMSNorm(args.d_model),
                    )
                )
                for _ in range(args.n_layer)
            ]
        )
        self.norm_f = RMSNorm(args.d_model)

    def forward(
        self, x: Tensor, h: Optional[list[InferenceCache]] = None
    ) -> tuple[Tensor, list[InferenceCache]]:
        if h is None:
            h = [None for _ in range(self.args.n_layer)]

        for i, layer in enumerate(self.layers):
            y, h[i] = layer["mixer"](layer["norm"](x), h[i])
            x = x + y if self.use_residual else y

        x = self.norm_f(x)
        return x, h

class ASMamba(nn.Module):
    def __init__(self, cfg: TrainConfig) -> None:
        super().__init__()
        self.cfg = cfg
        self.dt_min = float(cfg.dt_min)
        self.dt_max = float(cfg.dt_max)

        args = Mamba2Config(
            d_model=cfg.d_model,
            n_layer=cfg.n_layer,
            d_state=cfg.d_state,
            d_conv=cfg.d_conv,
            expand=cfg.expand,
            headdim=cfg.headdim,
            chunk_size=cfg.chunk_size,
        )
        self.backbone = Mamba2Backbone(args, use_residual=cfg.use_residual)
        self.input_proj = nn.Linear(3, cfg.d_model)
        self.delta_head = nn.Linear(cfg.d_model, 3)
        self.dt_head = nn.Sequential(
            nn.Linear(cfg.d_model, cfg.d_model),
            nn.SiLU(),
            nn.Linear(cfg.d_model, 1),
        )

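    # forward() maps a point sequence x of shape (B, L, 3) to a displacement delta of
    # shape (B, L, 3) and a per-step size dt of shape (B, L); step() consumes a single
    # point per sample, shape (B, 3), reusing the layer-wise recurrent caches in h.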
    def forward(
        self, x: Tensor, h: Optional[list[InferenceCache]] = None
    ) -> tuple[Tensor, Tensor, list[InferenceCache]]:
        x_proj = self.input_proj(x)
        feats, h = self.backbone(x_proj, h)
        delta = self.delta_head(feats)
        dt_raw = self.dt_head(feats).squeeze(-1)
        dt = torch.clamp(F.softplus(dt_raw), min=self.dt_min, max=self.dt_max)
        return delta, dt, h

    def step(
        self, x: Tensor, h: list[InferenceCache]
    ) -> tuple[Tensor, Tensor, list[InferenceCache]]:
        delta, dt, h = self.forward(x.unsqueeze(1), h)
        return delta[:, 0, :], dt[:, 0], h

    def init_cache(self, batch_size: int, device: torch.device) -> list[InferenceCache]:
        return [
            InferenceCache.alloc(batch_size, self.backbone.args, device=device)
            for _ in range(self.backbone.args.n_layer)
        ]


class SwanLogger:
    def __init__(self, cfg: TrainConfig) -> None:
        self.enabled = False
        self._swan = None
        self._run = None
        try:
            import swanlab  # type: ignore

            self._swan = swanlab
            self._run = self._swan.init(
                project=cfg.project,
                experiment_name=cfg.run_name,
                config=asdict(cfg),
            )
            self.enabled = True
        except Exception:
            self.enabled = False

    def log(self, metrics: dict, step: int | None = None) -> None:
        if not self.enabled:
            return
        target = self._run if self._run is not None else self._swan
        if step is None:
            target.log(metrics)
            return
        try:
            target.log(metrics, step=step)
        except TypeError:
            payload = dict(metrics)
            payload["step"] = step
            target.log(payload)

    def log_image(
        self, key: str, image_path: Path, caption: str | None = None, step: int | None = None
    ) -> None:
        if not self.enabled:
            return
        image = self._swan.Image(str(image_path), caption=caption)
        self.log({key: image}, step=step)

    def finish(self) -> None:
        if not self.enabled:
            return
        finish = None
        if self._run is not None:
            finish = getattr(self._run, "finish", None)
        if finish is None:
            finish = getattr(self._swan, "finish", None)
        if callable(finish):
            finish()

def set_seed(seed: int) -> None:
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

def sample_points_in_sphere(
    center: Tensor, radius: float, batch_size: int, device: torch.device
) -> Tensor:
    direction = torch.randn(batch_size, 3, device=device)
    direction = direction / (direction.norm(dim=-1, keepdim=True) + 1e-8)
    u = torch.rand(batch_size, 1, device=device)
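    # The cube root of a uniform variate gives a radial density proportional to r**2,
    # so samples are uniform over the ball's volume rather than clustered at the center.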
    r = radius * torch.pow(u, 1.0 / 3.0)
    return center + direction * r


def sample_sphere_params(
    cfg: TrainConfig, device: torch.device
) -> tuple[tuple[Tensor, Tensor], tuple[Tensor, Tensor]]:
    center_a = torch.empty(3, device=device).uniform_(cfg.center_min, cfg.center_max)
    center_b = torch.empty(3, device=device).uniform_(cfg.center_min, cfg.center_max)
    for _ in range(128):
        if torch.norm(center_a - center_b) >= cfg.center_distance_min:
            break
        center_b = torch.empty(3, device=device).uniform_(cfg.center_min, cfg.center_max)
    if torch.norm(center_a - center_b) < 1e-3:
        center_b = center_b + torch.tensor([cfg.center_distance_min, 0.0, 0.0], device=device)
    radius_a = float(torch.empty(1).uniform_(cfg.radius_min, cfg.radius_max).item())
    radius_b = float(torch.empty(1).uniform_(cfg.radius_min, cfg.radius_max).item())
    return (center_a, torch.tensor(radius_a, device=device)), (
        center_b,
        torch.tensor(radius_b, device=device),
    )

def sample_batch(
    cfg: TrainConfig,
    sphere_a: tuple[Tensor, Tensor],
    sphere_b: tuple[Tensor, Tensor],
    device: torch.device,
) -> tuple[Tensor, Tensor, Tensor, Tensor]:
    center_a, radius_a = sphere_a
    center_b, radius_b = sphere_b
    x0 = sample_points_in_sphere(center_a, float(radius_a.item()), cfg.batch_size, device)
    x1 = sample_points_in_sphere(center_b, float(radius_b.item()), cfg.batch_size, device)
    v_gt = x1 - x0
    dt_fixed = 1.0 / cfg.seq_len
    t_seq = torch.arange(cfg.seq_len, device=device) * dt_fixed
    x_seq = x0[:, None, :] + t_seq[None, :, None] * v_gt[:, None, :]
    return x0, x1, x_seq, t_seq


def compute_losses(
    delta: Tensor,
    dt: Tensor,
    x_seq: Tensor,
    x0: Tensor,
    v_gt: Tensor,
    t_seq: Tensor,
    cfg: TrainConfig,
) -> tuple[Tensor, Tensor, Tensor]:
    target_disp = v_gt[:, None, :] * dt.unsqueeze(-1)
    flow_loss = F.mse_loss(delta, target_disp)

    t_next = t_seq[None, :, None] + dt.unsqueeze(-1)
    t_next = torch.clamp(t_next, 0.0, 1.0)
    x_target = x0[:, None, :] + t_next * v_gt[:, None, :]
    x_next_pred = x_seq + delta
    pos_loss = F.mse_loss(x_next_pred, x_target)

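    # The -log(dt) term rewards larger step sizes, i.e. fewer rollout steps (lower NFE);
    # dt is already clamped to [cfg.dt_min, cfg.dt_max] in ASMamba.forward.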
    nfe_loss = (-torch.log(dt)).mean()
    return flow_loss, pos_loss, nfe_loss


def validate(
    model: ASMamba,
    cfg: TrainConfig,
    sphere_a: tuple[Tensor, Tensor],
    sphere_b: tuple[Tensor, Tensor],
    device: torch.device,
    logger: SwanLogger,
    step: int,
) -> None:
    model.eval()
    center_b, radius_b = sphere_b

    with torch.no_grad():
        x0 = sample_points_in_sphere(
            sphere_a[0], float(sphere_a[1].item()), cfg.val_samples, device
        )
        traj = rollout_trajectory(model, x0, max_steps=cfg.val_max_steps)

    x_final = traj[:, -1, :]
    center_b_cpu = center_b.detach().cpu()
    radius_b_cpu = radius_b.detach().cpu()
    dist = torch.linalg.norm(x_final - center_b_cpu, dim=-1)
    inside = dist <= radius_b_cpu

    logger.log(
        {
            "val/inside_ratio": float(inside.float().mean().item()),
            "val/inside_count": float(inside.float().sum().item()),
            "val/final_dist_mean": float(dist.mean().item()),
            "val/final_dist_min": float(dist.min().item()),
            "val/final_dist_max": float(dist.max().item()),
        },
        step=step,
    )

    if cfg.val_plot_samples > 0:
        count = min(cfg.val_plot_samples, traj.shape[0])
        if count == 0:
            model.train()
            return
        indices = torch.linspace(0, traj.shape[0] - 1, steps=count).long()
        traj_plot = traj[indices]
        save_path = Path(cfg.output_dir) / f"val_traj_step_{step:06d}.png"
        plot_trajectories(
            traj_plot,
            sphere_a,
            sphere_b,
            save_path,
            title=f"Validation Trajectories (step {step})",
        )
        ratio = float(inside.float().mean().item())
        logger.log_image(
            "val/trajectory",
            save_path,
            caption=f"step {step} | inside_ratio={ratio:.3f}",
            step=step,
        )

    model.train()


def train(cfg: TrainConfig) -> tuple[ASMamba, tuple[Tensor, Tensor], tuple[Tensor, Tensor]]:
    device = torch.device(cfg.device if torch.cuda.is_available() else "cpu")
    set_seed(cfg.seed)
    output_dir = Path(cfg.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    model = ASMamba(cfg).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
    logger = SwanLogger(cfg)

    sphere_a, sphere_b = sample_sphere_params(cfg, device)
    center_a, radius_a = sphere_a
    center_b, radius_b = sphere_b
    center_dist = torch.norm(center_a - center_b).item()
    logger.log(
        {
            "sphere_a/radius": float(radius_a.item()),
            "sphere_b/radius": float(radius_b.item()),
            "sphere_a/center_x": float(center_a[0].item()),
            "sphere_a/center_y": float(center_a[1].item()),
            "sphere_a/center_z": float(center_a[2].item()),
            "sphere_b/center_x": float(center_b[0].item()),
            "sphere_b/center_y": float(center_b[1].item()),
            "sphere_b/center_z": float(center_b[2].item()),
            "sphere/center_dist": float(center_dist),
        }
    )

    global_step = 0
    for epoch in range(cfg.epochs):
        warmup = epoch < cfg.warmup_epochs
        model.train()
        for p in model.dt_head.parameters():
            p.requires_grad = not warmup

        for _ in range(cfg.steps_per_epoch):
            x0, x1, x_seq, t_seq = sample_batch(cfg, sphere_a, sphere_b, device)
            v_gt = x1 - x0

            delta, dt, _ = model(x_seq)
            if warmup:
                dt = torch.full_like(dt, 1.0 / cfg.seq_len)

            flow_loss, pos_loss, nfe_loss = compute_losses(
                delta=delta,
                dt=dt,
                x_seq=x_seq,
                x0=x0,
                v_gt=v_gt,
                t_seq=t_seq,
                cfg=cfg,
            )

            loss = cfg.lambda_flow * flow_loss + cfg.lambda_pos * pos_loss
            if not warmup:
                loss = loss + cfg.lambda_nfe * nfe_loss

            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            optimizer.step()

            if global_step % 10 == 0:
                logger.log(
                    {
                        "loss/total": float(loss.item()),
                        "loss/flow": float(flow_loss.item()),
                        "loss/pos": float(pos_loss.item()),
                        "loss/nfe": float(nfe_loss.item()),
                        "dt/mean": float(dt.mean().item()),
                        "dt/min": float(dt.min().item()),
                        "dt/max": float(dt.max().item()),
                        "stage": 0 if warmup else 1,
                    },
                    step=global_step,
                )

            if cfg.val_every > 0 and global_step > 0 and global_step % cfg.val_every == 0:
                validate(model, cfg, sphere_a, sphere_b, device, logger, global_step)
            global_step += 1

    logger.finish()
    return model, sphere_a, sphere_b


def rollout_trajectory(
    model: ASMamba,
    x0: Tensor,
    max_steps: int = 100,
) -> Tensor:
    device = x0.device
    model.eval()
    h = model.init_cache(batch_size=x0.shape[0], device=device)
    x = x0
    total_time = torch.zeros(x0.shape[0], device=device)
    traj = [x0.detach().cpu()]

    with torch.no_grad():
        for _ in range(max_steps):
            delta, dt, h = model.step(x, h)
            dt = torch.clamp(dt, min=model.dt_min, max=model.dt_max)
            remaining = 1.0 - total_time
            overshoot = dt > remaining
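            # If a step would carry a sample past t = 1, rescale its displacement and
            # dt so the trajectory lands exactly on t = 1 instead of overshooting.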
            if overshoot.any():
                scale = (remaining / dt).unsqueeze(-1)
                delta = torch.where(overshoot.unsqueeze(-1), delta * scale, delta)
                dt = torch.where(overshoot, remaining, dt)

            x = x + delta
            total_time = total_time + dt
            traj.append(x.detach().cpu())

            if torch.all(total_time >= 1.0 - 1e-6):
                break

    return torch.stack(traj, dim=1)


def sphere_wireframe(
    center: Tensor, radius: float, u_steps: int = 24, v_steps: int = 12
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    center_np = center.detach().cpu().numpy()
    u = np.linspace(0, 2 * np.pi, u_steps)
    v = np.linspace(0, np.pi, v_steps)
    x = center_np[0] + radius * np.outer(np.cos(u), np.sin(v))
    y = center_np[1] + radius * np.outer(np.sin(u), np.sin(v))
    z = center_np[2] + radius * np.outer(np.ones_like(u), np.cos(v))
    return x, y, z


def plot_trajectories(
    traj: Tensor,
    sphere_a: tuple[Tensor, Tensor],
    sphere_b: tuple[Tensor, Tensor],
    save_path: Path,
    title: str = "AS-Mamba Trajectories",
) -> None:
    traj = traj.detach().cpu()
    if traj.dim() == 2:
        traj = traj.unsqueeze(0)
    traj_np = traj.numpy()

    fig = plt.figure(figsize=(7, 6))
    ax = fig.add_subplot(111, projection="3d")

    for i in range(traj_np.shape[0]):
        ax.plot(
            traj_np[i, :, 0],
            traj_np[i, :, 1],
            traj_np[i, :, 2],
            color="green",
            alpha=0.6,
        )

    starts = traj_np[:, 0, :]
    ends = traj_np[:, -1, :]
    ax.scatter(starts[:, 0], starts[:, 1], starts[:, 2], color="blue", s=20, label="Start")
    ax.scatter(ends[:, 0], ends[:, 1], ends[:, 2], color="red", s=20, label="End")

    center_a, radius_a = sphere_a
    center_b, radius_b = sphere_b
    x_a, y_a, z_a = sphere_wireframe(center_a, float(radius_a.item()))
    x_b, y_b, z_b = sphere_wireframe(center_b, float(radius_b.item()))
    ax.plot_wireframe(x_a, y_a, z_a, color="blue", alpha=0.15, linewidth=0.5)
    ax.plot_wireframe(x_b, y_b, z_b, color="red", alpha=0.15, linewidth=0.5)

    ax.set_title(title)
    ax.set_xlabel("X")
    ax.set_ylabel("Y")
    ax.set_zlabel("Z")
    ax.legend(loc="best")
    fig.tight_layout()
    fig.savefig(save_path, dpi=160)
    plt.close(fig)


def run_training_and_plot(cfg: TrainConfig) -> Path:
    model, sphere_a, sphere_b = train(cfg)
    device = next(model.parameters()).device

    plot_samples = max(1, cfg.val_plot_samples)
    x0 = sample_points_in_sphere(
        sphere_a[0], float(sphere_a[1].item()), plot_samples, device
    )
    traj = rollout_trajectory(model, x0, max_steps=cfg.val_max_steps)
    output_dir = Path(cfg.output_dir)
    save_path = output_dir / "as_mamba_trajectory.png"
    plot_trajectories(traj, sphere_a, sphere_b, save_path)
    return save_path
main.py (new file, 43 lines)
@@ -0,0 +1,43 @@
import argparse

from as_mamba import TrainConfig, run_training_and_plot


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="Train AS-Mamba on sphere-to-sphere flow.")
    parser.add_argument("--epochs", type=int, default=None)
    parser.add_argument("--warmup-epochs", type=int, default=None)
    parser.add_argument("--batch-size", type=int, default=None)
    parser.add_argument("--steps-per-epoch", type=int, default=None)
    parser.add_argument("--seq-len", type=int, default=None)
    parser.add_argument("--lr", type=float, default=None)
    parser.add_argument("--device", type=str, default=None)
    parser.add_argument("--output-dir", type=str, default=None)
    parser.add_argument("--project", type=str, default=None)
    parser.add_argument("--run-name", type=str, default=None)
    parser.add_argument("--val-every", type=int, default=None)
    parser.add_argument("--val-samples", type=int, default=None)
    parser.add_argument("--val-plot-samples", type=int, default=None)
    parser.add_argument("--val-max-steps", type=int, default=None)
    parser.add_argument("--center-min", type=float, default=None)
    parser.add_argument("--center-max", type=float, default=None)
    parser.add_argument("--center-distance-min", type=float, default=None)
    parser.add_argument("--use-residual", action="store_true")
    return parser


def main() -> None:
    parser = build_parser()
    args = parser.parse_args()

    cfg = TrainConfig()
    for key, value in vars(args).items():
        if value is not None:
            setattr(cfg, key.replace("-", "_"), value)

    plot_path = run_training_and_plot(cfg)
    print(f"Saved trajectory plot to {plot_path}")


if __name__ == "__main__":
    main()
mamba2_minimal.py (new file, 246 lines)
@@ -0,0 +1,246 @@
"""
mamba2-minimal
==============

Minimal Mamba-2 implementation for sequence modeling.
Reference: https://arxiv.org/abs/2405.21060
"""

from __future__ import annotations

import math
from dataclasses import dataclass
from typing import NamedTuple, TypeAlias

import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from torch import Tensor, nn

Device: TypeAlias = str | torch.device | None


@dataclass
class Mamba2Config:
    d_model: int
    n_layer: int = 4
    d_state: int = 64
    d_conv: int = 4
    expand: int = 2
    headdim: int = 32
    chunk_size: int = 1

    def __post_init__(self) -> None:
        self.d_inner = self.expand * self.d_model
        if self.d_inner % self.headdim != 0:
            raise ValueError("d_inner must be divisible by headdim")
        self.nheads = self.d_inner // self.headdim
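        # For example, with the configuration as_mamba.py passes in (d_model=128,
        # expand=2, headdim=32), this yields d_inner = 256 and nheads = 8.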


class InferenceCache(NamedTuple):
    conv_state: Tensor
    ssm_state: Tensor

    @staticmethod
    def alloc(batch_size: int, args: Mamba2Config, device: Device = None) -> "InferenceCache":
        return InferenceCache(
            torch.zeros(
                batch_size, args.d_inner + 2 * args.d_state, args.d_conv, device=device
            ),
            torch.zeros(
                batch_size, args.nheads, args.headdim, args.d_state, device=device
            ),
        )

class Mamba2(nn.Module):
    def __init__(self, args: Mamba2Config, device: Device = None) -> None:
        super().__init__()
        self.args = args
        self.device = device

        d_in_proj = 2 * args.d_inner + 2 * args.d_state + args.nheads
        self.in_proj = nn.Linear(args.d_model, d_in_proj, bias=False, device=device)

        conv_dim = args.d_inner + 2 * args.d_state
        self.conv1d = nn.Conv1d(
            in_channels=conv_dim,
            out_channels=conv_dim,
            kernel_size=args.d_conv,
            groups=conv_dim,
            padding=args.d_conv - 1,
            device=device,
        )

        self.dt_bias = nn.Parameter(torch.empty(args.nheads, device=device))
        self.A_log = nn.Parameter(torch.empty(args.nheads, device=device))
        self.D = nn.Parameter(torch.empty(args.nheads, device=device))
        self.norm = RMSNorm(args.d_inner, device=device)
        self.out_proj = nn.Linear(args.d_inner, args.d_model, bias=False, device=device)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        dt_min, dt_max = 1e-3, 1e-1
        device = self.dt_bias.device
        dtype = self.dt_bias.dtype
        dt = torch.exp(
            torch.rand(self.args.nheads, device=device, dtype=dtype)
            * (math.log(dt_max) - math.log(dt_min))
            + math.log(dt_min)
        )
        with torch.no_grad():
            self.dt_bias.copy_(torch.log(torch.expm1(dt)))
            self.A_log.copy_(
                torch.log(
                    torch.arange(
                        1, self.args.nheads + 1, device=device, dtype=self.A_log.dtype
                    )
                )
            )
            self.D.fill_(1.0)

    def forward(self, u: Tensor, h: InferenceCache | None = None):
        if h is not None:
            return self.step(u, h)

        A = -torch.exp(self.A_log)
        zxbcdt = self.in_proj(u)
        z, xBC, dt = torch.split(
            zxbcdt,
            [
                self.args.d_inner,
                self.args.d_inner + 2 * self.args.d_state,
                self.args.nheads,
            ],
            dim=-1,
        )
        dt = F.softplus(dt + self.dt_bias)

        xBC_t = rearrange(xBC, "b l d -> b d l")
        if u.shape[1] >= self.args.d_conv:
            conv_state = xBC_t[:, :, -self.args.d_conv :]
        else:
            conv_state = F.pad(xBC_t, (self.args.d_conv - u.shape[1], 0))

        xBC = silu(
            self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)[:, : u.shape[1], :]
        )
        x, B, C = torch.split(
            xBC, [self.args.d_inner, self.args.d_state, self.args.d_state], dim=-1
        )
        x = rearrange(x, "b l (h p) -> b l h p", p=self.args.headdim)
        y, ssm_state = ssd(
            x * dt.unsqueeze(-1),
            A * dt,
            rearrange(B, "b l n -> b l 1 n"),
            rearrange(C, "b l n -> b l 1 n"),
            self.args.chunk_size,
            device=x.device,
        )
        y = y + x * self.D.unsqueeze(-1)
        y = rearrange(y, "b l h p -> b l (h p)")
        y = self.norm(y, z)
        y = self.out_proj(y)

        h = InferenceCache(conv_state, ssm_state)
        return y, h

    def step(self, u: Tensor, h: InferenceCache):
        assert u.shape[1] == 1, "Only one token can be decoded per inference step"

        zxbcdt = self.in_proj(u.squeeze(1))
        z, xBC, dt = torch.split(
            zxbcdt,
            [
                self.args.d_inner,
                self.args.d_inner + 2 * self.args.d_state,
                self.args.nheads,
            ],
            dim=-1,
        )

        h.conv_state.copy_(torch.roll(h.conv_state, shifts=-1, dims=-1))
        h.conv_state[:, :, -1] = xBC
        xBC = torch.sum(
            h.conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1
        )
        xBC += self.conv1d.bias
        xBC = silu(xBC)

        x, B, C = torch.split(
            xBC, [self.args.d_inner, self.args.d_state, self.args.d_state], dim=-1
        )
        A = -torch.exp(self.A_log)

        dt = F.softplus(dt + self.dt_bias)
        dA = torch.exp(dt * A)
        x = rearrange(x, "b (h p) -> b h p", p=self.args.headdim)
        dBx = torch.einsum("bh, bn, bhp -> bhpn", dt, B, x)
        h.ssm_state.copy_(h.ssm_state * rearrange(dA, "b h -> b h 1 1") + dBx)
        y = torch.einsum("bhpn, bn -> bhp", h.ssm_state, C)
        y = y + rearrange(self.D, "h -> h 1") * x
        y = rearrange(y, "b h p -> b (h p)")
        y = self.norm(y, z)
        y = self.out_proj(y)

        return y.unsqueeze(1), h

def segsum(x: Tensor, device: Device = None) -> Tensor:
    if device is None:
        device = x.device
    T = x.size(-1)
    x = repeat(x, "... d -> ... d e", e=T)
    mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=device), diagonal=-1)
    x = x.masked_fill(~mask, 0)
    x_segsum = torch.cumsum(x, dim=-2)
    mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=device), diagonal=0)
    x_segsum = x_segsum.masked_fill(~mask, -torch.inf)
    return x_segsum


def ssd(x, A, B, C, chunk_size, initial_states=None, device: Device = None):
    assert x.shape[1] % chunk_size == 0

    x, A, B, C = [
        rearrange(m, "b (c l) ... -> b c l ...", l=chunk_size) for m in (x, A, B, C)
    ]

    A = rearrange(A, "b c l h -> b h c l")
    A_cumsum = torch.cumsum(A, dim=-1)

    L = torch.exp(segsum(A, device=device))
    Y_diag = torch.einsum("bclhn, bcshn, bhcls, bcshp -> bclhp", C, B, L, x)

    decay_states = torch.exp(A_cumsum[:, :, :, -1:] - A_cumsum)
    states = torch.einsum("bclhn, bhcl, bclhp -> bchpn", B, decay_states, x)

    if initial_states is None:
        initial_states = torch.zeros_like(states[:, :1])
    states = torch.cat([initial_states, states], dim=1)
    decay_chunk = torch.exp(segsum(F.pad(A_cumsum[:, :, :, -1], (1, 0)), device=device))
    new_states = torch.einsum("bhzc, bchpn -> bzhpn", decay_chunk, states)
    states, final_state = new_states[:, :-1], new_states[:, -1]

    state_decay_out = torch.exp(A_cumsum)
    Y_off = torch.einsum("bclhn, bchpn, bhcl -> bclhp", C, states, state_decay_out)

    Y = rearrange(Y_diag + Y_off, "b c l h p -> b (c l) h p")

    return Y, final_state


class RMSNorm(nn.Module):
    def __init__(self, d: int, eps: float = 1e-5, device: Device = None) -> None:
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d, device=device))

    def forward(self, x: Tensor, z: Tensor | None = None) -> Tensor:
        if z is not None:
            x = x * silu(z)
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight


def silu(x: Tensor) -> Tensor:
    return x * torch.sigmoid(x)
pyproject.toml (new file, 13 lines)
@@ -0,0 +1,13 @@
[project]
name = "test-diffusion"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.12"
dependencies = [
    "einops>=0.7.0",
    "matplotlib>=3.8.0",
    "numpy>=1.26.0",
    "swanlab>=0.5.0",
    "torch>=2.2.0",
]