596 lines
20 KiB
Python
596 lines
20 KiB
Python
from __future__ import annotations
|
|
|
|
from dataclasses import asdict, dataclass
|
|
from pathlib import Path
|
|
from typing import Optional
|
|
|
|
import matplotlib
|
|
|
|
matplotlib.use("Agg")
|
|
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from matplotlib import pyplot as plt
|
|
from mpl_toolkits.mplot3d import Axes3D # noqa: F401
|
|
from torch import Tensor, nn
|
|
|
|
from mamba2_minimal import InferenceCache, Mamba2, Mamba2Config, RMSNorm
|
|
|
|
|
|
@dataclass
class TrainConfig:
    """Settings for the sphere-to-sphere AS-Mamba training run.

    The field order defines the positional constructor signature, so fields
    are only grouped by comments here and never reordered.
    """

    # --- reproducibility / hardware ---
    seed: int = 42  # forwarded to set_seed() at the start of train()
    device: str = "cuda"  # train() falls back to "cpu" when CUDA is unavailable

    # --- optimisation schedule ---
    batch_size: int = 128  # trajectories sampled per optimisation step
    steps_per_epoch: int = 50
    epochs: int = 60
    seq_len: int = 20  # time steps per sampled sequence; default rollout length
    lr: float = 1e-3  # AdamW learning rate
    weight_decay: float = 1e-2  # AdamW weight decay

    # --- step-size (dt) handling ---
    dt_min: float = 1e-3  # lower clamp applied to the predicted dt
    dt_max: float = 0.06  # upper clamp on predicted dt; also caps sampled dt
    dt_alpha: float = 8.0  # Gamma concentration used to sample dt sequences

    # --- loss weights and switches ---
    lambda_flow: float = 1.0
    lambda_pos: float = 1.0
    lambda_dt: float = 0.05
    use_flow_loss: bool = True
    use_pos_loss: bool = False
    use_dt_loss: bool = True

    # --- geometry of the source / target spheres ---
    radius_min: float = 0.6
    radius_max: float = 1.4
    center_min: float = -6.0  # per-coordinate lower bound for sphere centres
    center_max: float = 6.0  # per-coordinate upper bound for sphere centres
    center_distance_min: float = 6.0  # required separation between the centres

    # --- backbone hyper-parameters (forwarded to Mamba2Config) ---
    d_model: int = 128
    n_layer: int = 4
    d_state: int = 64
    d_conv: int = 4
    expand: int = 2
    headdim: int = 32
    chunk_size: int = 1
    use_residual: bool = False  # add skip connections around each mixer layer

    # --- outputs and experiment tracking ---
    output_dir: str = "outputs"
    project: str = "as-mamba"
    run_name: str = "sphere-to-sphere"

    # --- validation ---
    val_every: int = 200  # validate every N global steps (<= 0 disables)
    val_samples: int = 256  # start points rolled out per validation pass
    val_plot_samples: int = 16  # trajectories drawn in the validation figure
    val_max_steps: int = 0  # rollout step limit; <= 0 means "use seq_len"
|
|
|
|
|
|
class Mamba2Backbone(nn.Module):
    """Stack of ``args.n_layer`` normalised Mamba2 mixer layers plus a final RMSNorm.

    Each layer applies its RMSNorm before the mixer. When ``use_residual`` is
    true the mixer output is added to the layer input, otherwise it replaces it.
    """

    def __init__(self, args: Mamba2Config, use_residual: bool = True) -> None:
        super().__init__()
        self.args = args
        self.use_residual = use_residual

        # Build the mixer before the norm in every layer so that parameter
        # creation order (and therefore RNG consumption) stays unchanged.
        blocks = []
        for _ in range(args.n_layer):
            blocks.append(
                nn.ModuleDict({"mixer": Mamba2(args), "norm": RMSNorm(args.d_model)})
            )
        self.layers = nn.ModuleList(blocks)
        self.norm_f = RMSNorm(args.d_model)

    def forward(
        self, x: Tensor, h: Optional[list[InferenceCache]] = None
    ) -> tuple[Tensor, list[InferenceCache]]:
        """Run ``x`` through all layers.

        Args:
            x: Input features handed to the first layer.
            h: Optional per-layer caches. When omitted, a list of ``None``
                placeholders (one per layer) is created. A supplied list is
                updated in place with the caches returned by each mixer.

        Returns:
            The output of the final norm and the per-layer cache list.
        """
        if h is None:
            h = [None] * self.args.n_layer

        for idx, block in enumerate(self.layers):
            mixed, h[idx] = block["mixer"](block["norm"](x), h[idx])
            if self.use_residual:
                x = x + mixed
            else:
                x = mixed

        return self.norm_f(x), h
|
|
|
|
|
|
class ASMamba(nn.Module):
    """Mamba2-based model predicting a 3-D displacement and a step size per position.

    Inputs with 3 features are projected to ``cfg.d_model``, passed through a
    ``Mamba2Backbone`` and decoded by two heads: ``delta_head`` (3 outputs)
    and ``dt_head`` (1 output, mapped through softplus and clamped to
    ``[cfg.dt_min, cfg.dt_max]``).
    """

    def __init__(self, cfg: TrainConfig) -> None:
        super().__init__()
        self.cfg = cfg
        self.dt_min = float(cfg.dt_min)
        self.dt_max = float(cfg.dt_max)

        # Copy the backbone hyper-parameters from the training config verbatim.
        backbone_fields = (
            "d_model",
            "n_layer",
            "d_state",
            "d_conv",
            "expand",
            "headdim",
            "chunk_size",
        )
        backbone_args = Mamba2Config(**{name: getattr(cfg, name) for name in backbone_fields})

        # Submodules are created in a fixed order (backbone, input projection,
        # heads) so parameter initialisation draws stay unchanged.
        self.backbone = Mamba2Backbone(backbone_args, use_residual=cfg.use_residual)
        self.input_proj = nn.Linear(3, cfg.d_model)
        self.delta_head = nn.Linear(cfg.d_model, 3)
        hidden = cfg.d_model
        self.dt_head = nn.Sequential(
            nn.Linear(hidden, hidden),
            nn.SiLU(),
            nn.Linear(hidden, 1),
        )

    def forward(
        self, x: Tensor, h: Optional[list[InferenceCache]] = None
    ) -> tuple[Tensor, Tensor, list[InferenceCache]]:
        """Predict displacements and step sizes for ``x``.

        Args:
            x: Positions whose last dimension has size 3.
            h: Optional per-layer caches forwarded to the backbone.

        Returns:
            ``(delta, dt, h)``: the displacement head output, the clamped
            step size with the trailing singleton dimension removed, and the
            caches returned by the backbone.
        """
        features, h = self.backbone(self.input_proj(x), h)
        displacement = self.delta_head(features)
        step_size = F.softplus(self.dt_head(features).squeeze(-1))
        # Hard clamp keeps the predicted step inside the configured bounds.
        step_size = torch.clamp(step_size, min=self.dt_min, max=self.dt_max)
        return displacement, step_size, h

    def step(
        self, x: Tensor, h: list[InferenceCache]
    ) -> tuple[Tensor, Tensor, list[InferenceCache]]:
        """Advance one position: wraps ``forward`` for inputs without a sequence axis.

        A length-1 axis is inserted at dimension 1 before the forward pass and
        stripped from both outputs afterwards.
        """
        displacement, step_size, h = self.forward(x.unsqueeze(1), h)
        return displacement[:, 0, :], step_size[:, 0], h

    def init_cache(self, batch_size: int, device: torch.device) -> list[InferenceCache]:
        """Allocate one ``InferenceCache`` per backbone layer for ``batch_size`` items."""
        backbone_args = self.backbone.args
        caches = []
        for _ in range(backbone_args.n_layer):
            caches.append(InferenceCache.alloc(batch_size, backbone_args, device=device))
        return caches
|
|
|
|
|
|
class SwanLogger:
    """Best-effort wrapper around SwanLab experiment tracking.

    If ``swanlab`` cannot be imported or initialised, the logger disables
    itself and every method becomes a no-op, so training never depends on the
    tracker. The reason for disabling is reported once via ``RuntimeWarning``
    instead of being dropped silently.
    """

    def __init__(self, cfg: TrainConfig) -> None:
        """Try to start a SwanLab run described by ``cfg``.

        Args:
            cfg: Training configuration; ``project`` and ``run_name`` name the
                run and the whole config is uploaded as the run config.
        """
        self.enabled = False
        self._swan = None
        self._run = None
        try:
            import swanlab  # type: ignore

            self._swan = swanlab
            self._run = self._swan.init(
                project=cfg.project,
                experiment_name=cfg.run_name,
                config=asdict(cfg),
            )
            self.enabled = True
        except Exception as exc:
            # Deliberately broad: tracking is optional and must not abort
            # training. Surface the cause so a missing or misconfigured
            # tracker is not mistaken for a run that simply logged nothing.
            import warnings

            self.enabled = False
            warnings.warn(
                f"SwanLab logging disabled: {exc!r}",
                RuntimeWarning,
                stacklevel=2,
            )

    def log(self, metrics: dict, step: int | None = None) -> None:
        """Send ``metrics`` to the tracker, optionally tagged with ``step``.

        Logs through the run object when one exists, otherwise through the
        module. If the target's ``log`` rejects the ``step`` keyword with a
        ``TypeError``, the step is folded into the payload under ``"step"``.
        """
        if not self.enabled:
            return
        target = self._run if self._run is not None else self._swan
        if step is None:
            target.log(metrics)
            return
        try:
            target.log(metrics, step=step)
        except TypeError:
            payload = dict(metrics)
            payload["step"] = step
            target.log(payload)

    def log_image(
        self, key: str, image_path: Path, caption: str | None = None, step: int | None = None
    ) -> None:
        """Log the image file at ``image_path`` under ``key``."""
        if not self.enabled:
            return
        image = self._swan.Image(str(image_path), caption=caption)
        self.log({key: image}, step=step)

    def finish(self) -> None:
        """Close the run, preferring the run object's ``finish`` over the module's."""
        if not self.enabled:
            return
        finish = None
        if self._run is not None:
            finish = getattr(self._run, "finish", None)
        if finish is None:
            finish = getattr(self._swan, "finish", None)
        if callable(finish):
            finish()
|
|
|
|
|
|
def set_seed(seed: int) -> None:
    """Seed the Python, NumPy and PyTorch random number generators.

    Seeding only torch leaves ``random`` and ``numpy.random`` unseeded, so any
    code path drawing from them would differ between runs. All three are
    seeded here.

    Args:
        seed: Seed value. It is wrapped into ``[0, 2**32)`` for NumPy, whose
            legacy seeding API rejects values outside that range.
    """
    import random

    random.seed(seed)
    np.random.seed(seed % (2**32))
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
|
|
|
|
|
|
def sample_points_in_sphere(
    center: Tensor, radius: float, batch_size: int, device: torch.device
) -> Tensor:
    """Draw ``batch_size`` random 3-D points inside the ball around ``center``.

    Args:
        center: Ball centre; it is added to every sampled offset.
        radius: Ball radius.
        batch_size: Number of points to draw.
        device: Device on which the random draws are created.

    Returns:
        Tensor of shape ``(batch_size, 3)``.
    """
    # Normalised Gaussian draws give a random direction; the epsilon guards
    # against division by a zero-length vector.
    gauss = torch.randn(batch_size, 3, device=device)
    unit = gauss / (gauss.norm(dim=-1, keepdim=True) + 1e-8)

    # The cube root of a uniform variate spreads the radii so that points
    # fill the volume evenly rather than clustering near the centre.
    radial = radius * torch.rand(batch_size, 1, device=device).pow(1.0 / 3.0)

    return center + unit * radial
|
|
|
|
|
|
def sample_sphere_params(
    cfg: TrainConfig, device: torch.device
) -> tuple[tuple[Tensor, Tensor], tuple[Tensor, Tensor]]:
    """Sample the source and target spheres for a run.

    Centres are drawn uniformly per coordinate from
    ``[cfg.center_min, cfg.center_max]``. The second centre is redrawn up to
    128 times until it is at least ``cfg.center_distance_min`` away from the
    first. If that still fails, it is pushed out along the line joining the
    two centres so the minimum distance always holds. Previously the last
    redraw was never checked and the constraint could be silently violated.

    Args:
        cfg: Provides the centre range, minimum centre distance and radius range.
        device: Device for the returned tensors.

    Returns:
        ``((center_a, radius_a), (center_b, radius_b))`` where centres are
        tensors of shape ``(3,)`` and radii are scalar tensors.
    """

    def _draw_center() -> Tensor:
        return torch.empty(3, device=device).uniform_(cfg.center_min, cfg.center_max)

    min_dist = float(cfg.center_distance_min)
    center_a = _draw_center()
    center_b = _draw_center()

    # Rejection sampling with a bounded number of redraws.
    redraws = 0
    while redraws < 128 and float(torch.norm(center_a - center_b)) < min_dist:
        center_b = _draw_center()
        redraws += 1

    # Deterministic fallback: enforce the separation explicitly. This may move
    # center_b outside the configured coordinate range.
    offset = center_b - center_a
    dist = float(torch.norm(offset))
    if dist < min_dist:
        if dist < 1e-3:
            # Centres (nearly) coincide, so no meaningful direction exists.
            direction = torch.tensor([1.0, 0.0, 0.0], device=device)
        else:
            direction = offset / dist
        center_b = center_a + direction * min_dist

    # Radii are drawn on the CPU and then moved to the target device.
    radius_a = float(torch.empty(1).uniform_(cfg.radius_min, cfg.radius_max).item())
    radius_b = float(torch.empty(1).uniform_(cfg.radius_min, cfg.radius_max).item())
    return (center_a, torch.tensor(radius_a, device=device)), (
        center_b,
        torch.tensor(radius_b, device=device),
    )
|
|
|
|
|
|
def sample_time_sequence(cfg: TrainConfig, batch_size: int, device: torch.device) -> Tensor:
    """Sample per-step durations that sum to 1 along the sequence axis.

    Raw durations are Gamma(``cfg.dt_alpha``, 1) draws normalised per row.
    Any row whose largest step exceeds ``cfg.dt_max`` is blended towards the
    uniform step ``1 / cfg.seq_len`` just enough for its maximum to land on
    ``cfg.dt_max``; blending keeps the row sum at 1.

    Args:
        cfg: Provides ``dt_alpha``, ``seq_len`` and ``dt_max``.
        batch_size: Number of sequences to draw.
        device: Device of the returned tensor.

    Returns:
        Tensor of shape ``(batch_size, cfg.seq_len)``.

    Raises:
        ValueError: If ``cfg.dt_alpha`` is not strictly positive.
    """
    concentration = float(cfg.dt_alpha)
    if concentration <= 0:
        raise ValueError("dt_alpha must be > 0")

    draws = torch.distributions.Gamma(concentration, 1.0).sample((batch_size, cfg.seq_len))
    draws = draws.to(device)
    steps = draws / draws.sum(dim=-1, keepdim=True)

    uniform_step = 1.0 / cfg.seq_len
    cap = float(cfg.dt_max)
    if cap <= uniform_step:
        # The cap cannot be met while summing to 1, so use uniform steps.
        return torch.full_like(steps, uniform_step)

    row_max = steps.max(dim=-1, keepdim=True).values
    if not (row_max > cap).any():
        return steps

    # Rows already under the cap get a blend factor >= 1, which the clamp
    # turns into 1 (i.e. they are left untouched).
    blend = ((cap - uniform_step) / (row_max - uniform_step)).clamp(0.0, 1.0)
    return blend * steps + (1.0 - blend) * uniform_step
|
|
|
|
|
|
def sample_batch(
    cfg: TrainConfig,
    sphere_a: tuple[Tensor, Tensor],
    sphere_b: tuple[Tensor, Tensor],
    device: torch.device,
) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
    """Build one training batch of straight-line trajectories between the spheres.

    Args:
        cfg: Provides ``batch_size`` and the time-sampling settings.
        sphere_a: ``(center, radius)`` of the source sphere.
        sphere_b: ``(center, radius)`` of the target sphere.
        device: Device for all sampled tensors.

    Returns:
        ``(x0, x1, x_seq, t_seq, dt_seq)`` where ``x0``/``x1`` are start/end
        points, ``dt_seq`` the sampled step durations, ``t_seq`` the start
        time of each step (first entry 0) and ``x_seq`` the positions on the
        segment ``x0 -> x1`` at those start times.
    """
    n = cfg.batch_size
    (src_center, src_radius), (dst_center, dst_radius) = sphere_a, sphere_b

    x0 = sample_points_in_sphere(src_center, float(src_radius.item()), n, device)
    x1 = sample_points_in_sphere(dst_center, float(dst_radius.item()), n, device)
    dt_seq = sample_time_sequence(cfg, n, device)

    # Exclusive cumulative sum: each step starts where the previous ones ended.
    elapsed_before = torch.cumsum(dt_seq, dim=-1)[:, :-1]
    t_seq = torch.cat([torch.zeros(n, 1, device=device), elapsed_before], dim=-1)

    velocity = x1 - x0
    x_seq = x0[:, None, :] + t_seq[:, :, None] * velocity[:, None, :]
    return x0, x1, x_seq, t_seq, dt_seq
|
|
|
|
|
|
def compute_losses(
    delta: Tensor,
    dt: Tensor,
    x_seq: Tensor,
    x0: Tensor,
    v_gt: Tensor,
    t_seq: Tensor,
    dt_seq: Tensor,
    cfg: TrainConfig,
) -> dict[str, Tensor]:
    """Compute the enabled training losses (all mean-squared errors).

    Args:
        delta: Predicted displacement per step.
        dt: Predicted step size per step.
        x_seq: Positions fed to the model.
        x0: Trajectory start points.
        v_gt: Ground-truth displacement from start to end point.
        t_seq: Start time of each step.
        dt_seq: Ground-truth step sizes.
        cfg: Supplies the ``use_flow_loss`` / ``use_pos_loss`` /
            ``use_dt_loss`` switches.

    Returns:
        Dict containing only the enabled terms among ``"flow"``, ``"pos"``
        and ``"dt"``; empty when every switch is off.
    """
    velocity = v_gt[:, None, :]
    terms: dict[str, Tensor] = {}

    if cfg.use_flow_loss:
        # Target displacement is the true velocity scaled by the predicted dt.
        terms["flow"] = F.mse_loss(delta, velocity * dt.unsqueeze(-1))

    if cfg.use_pos_loss:
        # Compare the predicted next position with the point on the straight
        # path at the (clamped) time reached after the predicted step.
        t_after = torch.clamp(t_seq + dt, 0.0, 1.0)
        target_pos = x0[:, None, :] + t_after.unsqueeze(-1) * velocity
        terms["pos"] = F.mse_loss(x_seq + delta, target_pos)

    if cfg.use_dt_loss:
        terms["dt"] = F.mse_loss(dt, dt_seq)

    return terms
|
|
|
|
|
|
def validate(
    model: ASMamba,
    cfg: TrainConfig,
    sphere_a: tuple[Tensor, Tensor],
    sphere_b: tuple[Tensor, Tensor],
    device: torch.device,
    logger: SwanLogger,
    step: int,
) -> None:
    """Roll out trajectories from sphere A and log how many end inside sphere B.

    The model is switched to eval mode for the rollout and restored to the
    mode it had on entry, even when validation raises. A non-positive
    ``cfg.val_samples`` makes this a no-op instead of failing on an empty batch.

    Args:
        model: Model to evaluate.
        cfg: Supplies ``val_samples``, ``val_max_steps`` (``<= 0`` means
            ``seq_len``), ``val_plot_samples`` and ``output_dir``.
        sphere_a: ``(center, radius)`` of the source sphere.
        sphere_b: ``(center, radius)`` of the target sphere.
        device: Device on which start points are sampled.
        logger: Receives the scalar metrics and the trajectory image.
        step: Global training step used to tag logs and name the image file.
    """
    if cfg.val_samples <= 0:
        # Nothing to roll out; min()/max() over an empty batch would raise.
        return

    center_b, radius_b = sphere_b
    max_steps = cfg.seq_len if cfg.val_max_steps <= 0 else cfg.val_max_steps

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            x0 = sample_points_in_sphere(
                sphere_a[0], float(sphere_a[1].item()), cfg.val_samples, device
            )
            traj = rollout_trajectory(model, x0, max_steps=max_steps)

        # rollout_trajectory returns CPU tensors, so compare on the CPU.
        x_final = traj[:, -1, :]
        center_b_cpu = center_b.detach().cpu()
        radius_b_cpu = radius_b.detach().cpu()
        dist = torch.linalg.norm(x_final - center_b_cpu, dim=-1)
        inside = dist <= radius_b_cpu
        inside_ratio = float(inside.float().mean().item())

        logger.log(
            {
                "val/inside_ratio": inside_ratio,
                "val/inside_count": float(inside.float().sum().item()),
                "val/final_dist_mean": float(dist.mean().item()),
                "val/final_dist_min": float(dist.min().item()),
                "val/final_dist_max": float(dist.max().item()),
                "val/max_steps": float(max_steps),
            },
            step=step,
        )

        if cfg.val_plot_samples > 0:
            # Pick evenly spaced trajectories across the batch for the figure.
            count = min(cfg.val_plot_samples, traj.shape[0])
            indices = torch.linspace(0, traj.shape[0] - 1, steps=count).long()
            traj_plot = traj[indices]
            save_path = Path(cfg.output_dir) / f"val_traj_step_{step:06d}.png"
            plot_trajectories(
                traj_plot,
                sphere_a,
                sphere_b,
                save_path,
                title=f"Validation Trajectories (step {step})",
            )
            logger.log_image(
                "val/trajectory",
                save_path,
                caption=f"step {step} | inside_ratio={inside_ratio:.3f}",
                step=step,
            )
    finally:
        # Restore the caller's mode rather than forcing train mode.
        model.train(was_training)
|
|
|
|
|
|
def train(cfg: TrainConfig) -> tuple[ASMamba, tuple[Tensor, Tensor], tuple[Tensor, Tensor]]:
    """Train an ``ASMamba`` model on straight-line sphere-to-sphere trajectories.

    One pair of spheres is sampled up front and reused for every batch.
    Metrics are logged every 10 global steps. Every ``cfg.val_every`` steps
    (when positive, skipping step 0) the model is validated and a histogram
    of predicted versus ground-truth ``dt`` is saved and logged.

    Args:
        cfg: Full training configuration.

    Returns:
        ``(model, sphere_a, sphere_b)`` with each sphere as ``(center, radius)``.

    Raises:
        RuntimeError: If no loss term is enabled when a step is taken.
    """
    device = torch.device(cfg.device if torch.cuda.is_available() else "cpu")
    set_seed(cfg.seed)
    output_dir = Path(cfg.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    model = ASMamba(cfg).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
    logger = SwanLogger(cfg)

    # The try/finally ensures the tracking run is closed even if training fails.
    try:
        sphere_a, sphere_b = sample_sphere_params(cfg, device)
        center_a, radius_a = sphere_a
        center_b, radius_b = sphere_b
        center_dist = torch.norm(center_a - center_b).item()
        logger.log(
            {
                "sphere_a/radius": float(radius_a.item()),
                "sphere_b/radius": float(radius_b.item()),
                "sphere_a/center_x": float(center_a[0].item()),
                "sphere_a/center_y": float(center_a[1].item()),
                "sphere_a/center_z": float(center_a[2].item()),
                "sphere_b/center_x": float(center_b[0].item()),
                "sphere_b/center_y": float(center_b[1].item()),
                "sphere_b/center_z": float(center_b[2].item()),
                "sphere/center_dist": float(center_dist),
            }
        )

        global_step = 0
        for _epoch in range(cfg.epochs):
            model.train()

            for _ in range(cfg.steps_per_epoch):
                x0, x1, x_seq, t_seq, dt_seq = sample_batch(cfg, sphere_a, sphere_b, device)
                v_gt = x1 - x0

                delta, dt, _ = model(x_seq)

                losses = compute_losses(
                    delta=delta,
                    dt=dt,
                    x_seq=x_seq,
                    x0=x0,
                    v_gt=v_gt,
                    t_seq=t_seq,
                    dt_seq=dt_seq,
                    cfg=cfg,
                )

                # Collect the weighted terms explicitly. "No loss enabled" is
                # detected from the absence of terms, not from the loss value:
                # a loss that is genuinely 0.0 is valid, and reading the value
                # would force a device synchronisation on every step.
                weighted_terms = []
                if cfg.use_flow_loss and "flow" in losses:
                    weighted_terms.append(cfg.lambda_flow * losses["flow"])
                if cfg.use_pos_loss and "pos" in losses:
                    weighted_terms.append(cfg.lambda_pos * losses["pos"])
                if cfg.use_dt_loss and "dt" in losses:
                    weighted_terms.append(cfg.lambda_dt * losses["dt"])
                if not weighted_terms:
                    raise RuntimeError("No loss enabled: enable at least one of flow/pos/dt.")

                loss = weighted_terms[0]
                for term in weighted_terms[1:]:
                    loss = loss + term

                optimizer.zero_grad(set_to_none=True)
                loss.backward()
                optimizer.step()

                if global_step % 10 == 0:
                    # Share of predictions sitting on a clamp boundary; eps
                    # absorbs floating-point error in the comparison.
                    eps = 1e-6
                    at_min = dt <= cfg.dt_min + eps
                    at_max = dt >= cfg.dt_max - eps
                    logger.log(
                        {
                            "loss/total": float(loss.item()),
                            "loss/flow": float(losses.get("flow", torch.tensor(0.0)).item()),
                            "loss/pos": float(losses.get("pos", torch.tensor(0.0)).item()),
                            "loss/dt": float(losses.get("dt", torch.tensor(0.0)).item()),
                            "dt/pred_mean": float(dt.mean().item()),
                            "dt/pred_min": float(dt.min().item()),
                            "dt/pred_max": float(dt.max().item()),
                            "dt/gt_mean": float(dt_seq.mean().item()),
                            "dt/gt_min": float(dt_seq.min().item()),
                            "dt/gt_max": float(dt_seq.max().item()),
                            "dt/clamp_min_ratio": float(at_min.float().mean().item()),
                            "dt/clamp_max_ratio": float(at_max.float().mean().item()),
                            "dt/clamp_any_ratio": float((at_min | at_max).float().mean().item()),
                        },
                        step=global_step,
                    )

                if cfg.val_every > 0 and global_step > 0 and global_step % cfg.val_every == 0:
                    validate(model, cfg, sphere_a, sphere_b, device, logger, global_step)
                    dt_hist_path = Path(cfg.output_dir) / f"dt_hist_step_{global_step:06d}.png"
                    plot_dt_hist(
                        dt,
                        dt_seq,
                        dt_hist_path,
                        title=f"dt Distribution (step {global_step})",
                    )
                    logger.log_image(
                        "train/dt_hist",
                        dt_hist_path,
                        caption=f"step {global_step}",
                        step=global_step,
                    )
                global_step += 1
    finally:
        logger.finish()

    return model, sphere_a, sphere_b
|
|
|
|
|
|
def rollout_trajectory(
    model: ASMamba,
    x0: Tensor,
    max_steps: int,
) -> Tensor:
    """Integrate the model step by step from ``x0`` until total time reaches 1.

    Each step asks the model for a displacement and a step size. When a step
    would push a sample's accumulated time past 1, both the step size and the
    displacement are scaled down so the sample lands exactly on 1. The loop
    ends after ``max_steps`` steps or once every sample has reached time 1.
    The model is put into eval mode and left there.

    Args:
        model: Provides ``eval``, ``init_cache``, ``step``, ``dt_min`` and ``dt_max``.
        x0: Start positions; the first dimension is the batch.
        max_steps: Upper bound on the number of steps.

    Returns:
        CPU tensor stacking the visited positions along dimension 1, starting
        with ``x0``.
    """
    device = x0.device
    model.eval()
    batch = x0.shape[0]
    cache = model.init_cache(batch_size=batch, device=device)
    position = x0
    elapsed = torch.zeros(batch, device=device)
    history = [x0.detach().cpu()]

    with torch.no_grad():
        for _ in range(max_steps):
            delta, dt, cache = model.step(position, cache)
            dt = torch.clamp(dt, min=model.dt_min, max=model.dt_max)

            # Shrink any step that would run past the end of the unit interval.
            budget = 1.0 - elapsed
            exceeds = dt > budget
            if exceeds.any():
                shrink = (budget / dt).unsqueeze(-1)
                delta = torch.where(exceeds.unsqueeze(-1), delta * shrink, delta)
                dt = torch.where(exceeds, budget, dt)

            position = position + delta
            elapsed = elapsed + dt
            history.append(position.detach().cpu())

            if torch.all(elapsed >= 1.0 - 1e-6):
                break

    return torch.stack(history, dim=1)
|
|
|
|
|
|
def sphere_wireframe(
    center: Tensor, radius: float, u_steps: int = 24, v_steps: int = 12
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Return the coordinate grids of a sphere surface for wireframe plotting.

    Args:
        center: Tensor whose first three entries are the sphere centre.
        radius: Sphere radius.
        u_steps: Number of samples of the angle spanning ``[0, 2*pi]``.
        v_steps: Number of samples of the angle spanning ``[0, pi]``.

    Returns:
        ``(x, y, z)`` arrays, each of shape ``(u_steps, v_steps)``.
    """
    origin = center.detach().cpu().numpy()
    azimuth = np.linspace(0, 2 * np.pi, u_steps)
    polar = np.linspace(0, np.pi, v_steps)
    sin_polar = np.sin(polar)

    grid_x = origin[0] + radius * np.outer(np.cos(azimuth), sin_polar)
    grid_y = origin[1] + radius * np.outer(np.sin(azimuth), sin_polar)
    grid_z = origin[2] + radius * np.outer(np.ones_like(azimuth), np.cos(polar))
    return grid_x, grid_y, grid_z
|
|
|
|
|
|
def plot_trajectories(
    traj: Tensor,
    sphere_a: tuple[Tensor, Tensor],
    sphere_b: tuple[Tensor, Tensor],
    save_path: Path,
    title: str = "AS-Mamba Trajectories",
) -> None:
    """Save a 3-D figure of trajectories together with both sphere wireframes.

    Args:
        traj: Trajectories indexed as ``[sample, step, xyz]``; a 2-D tensor is
            treated as a single trajectory.
        sphere_a: ``(center, radius)`` of the source sphere (drawn in blue).
        sphere_b: ``(center, radius)`` of the target sphere (drawn in red).
        save_path: Destination image file.
        title: Figure title.
    """
    paths = traj.detach().cpu()
    if paths.dim() == 2:
        paths = paths.unsqueeze(0)
    paths = paths.numpy()

    fig = plt.figure(figsize=(7, 6))
    ax = fig.add_subplot(111, projection="3d")

    for path in paths:
        ax.plot(path[:, 0], path[:, 1], path[:, 2], color="green", alpha=0.6)

    # Mark the first and last point of every trajectory.
    endpoint_sets = (
        (paths[:, 0, :], "blue", "Start"),
        (paths[:, -1, :], "red", "End"),
    )
    for points, colour, tag in endpoint_sets:
        ax.scatter(points[:, 0], points[:, 1], points[:, 2], color=colour, s=20, label=tag)

    for (centre, radius), colour in ((sphere_a, "blue"), (sphere_b, "red")):
        grid_x, grid_y, grid_z = sphere_wireframe(centre, float(radius.item()))
        ax.plot_wireframe(grid_x, grid_y, grid_z, color=colour, alpha=0.15, linewidth=0.5)

    ax.set_title(title)
    ax.set_xlabel("X")
    ax.set_ylabel("Y")
    ax.set_zlabel("Z")
    ax.legend(loc="best")
    fig.tight_layout()
    fig.savefig(save_path, dpi=160)
    plt.close(fig)
|
|
|
|
|
|
def plot_dt_hist(
    dt_pred: Tensor,
    dt_gt: Tensor,
    save_path: Path,
    title: str = "dt Distribution",
) -> None:
    """Save overlaid histograms of ground-truth and predicted step sizes.

    Args:
        dt_pred: Predicted step sizes (any shape; flattened before plotting).
        dt_gt: Ground-truth step sizes (any shape; flattened before plotting).
        save_path: Destination image file.
        title: Figure title.
    """
    fig, ax = plt.subplots(figsize=(6, 4))

    # Ground truth is drawn first so the prediction overlays it.
    series = (
        (dt_gt, "dt_gt", "steelblue"),
        (dt_pred, "dt_pred", "orange"),
    )
    for values, tag, colour in series:
        flat = values.detach().cpu().numpy().reshape(-1)
        ax.hist(flat, bins=30, alpha=0.6, label=tag, color=colour)

    ax.set_title(title)
    ax.set_xlabel("dt")
    ax.set_ylabel("count")
    ax.legend(loc="best")
    fig.tight_layout()
    fig.savefig(save_path, dpi=160)
    plt.close(fig)
|
|
|
|
|
|
def run_training_and_plot(cfg: TrainConfig) -> Path:
    """Train a model, roll out a few trajectories and save the final figure.

    Args:
        cfg: Training configuration; ``val_plot_samples`` (at least 1 is
            used) sets how many trajectories are drawn and ``val_max_steps``
            (``<= 0`` means ``seq_len``) bounds the rollout length.

    Returns:
        Path of the saved image, ``<cfg.output_dir>/as_mamba_trajectory.png``.
    """
    model, sphere_a, sphere_b = train(cfg)
    device = next(model.parameters()).device

    center_a, radius_a = sphere_a
    n_plot = max(1, cfg.val_plot_samples)
    starts = sample_points_in_sphere(center_a, float(radius_a.item()), n_plot, device)

    if cfg.val_max_steps <= 0:
        step_limit = cfg.seq_len
    else:
        step_limit = cfg.val_max_steps
    traj = rollout_trajectory(model, starts, max_steps=step_limit)

    save_path = Path(cfg.output_dir) / "as_mamba_trajectory.png"
    plot_trajectories(traj, sphere_a, sphere_b, save_path)
    return save_path
|