Compare commits

..

13 Commits

Author SHA1 Message Date
gameloader
50b1f76263 feat(models/denoiser): Add Mamba2 model implementation
Add a minimal, single-file implementation of the Mamba-2 model in PyTorch.
Based on the paper "Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality".
2026-01-20 14:44:03 +08:00
gameloader
b9c3720eba feat(callbacks): add Weights & Biases image logging to SaveImagesHook
- Introduce optional wandb_max_save_num parameter in __init__
- Add helper methods _get_wandb_logger and _log_wandb_samples
- Log sampled images as wandb.Image artifacts during validation epoch end (global zero only)
- Handle proper dtype conversion and clipping for wandb compatibility
- Respect both general max_save_num and wandb-specific limit

This enables visualization of validation samples directly in the W&B dashboard without manual intervention.
2026-01-19 16:03:54 +08:00
game-loader
6bf32b08fd feat(dataset): add HuggingFace image dataset 2026-01-18 17:10:01 +08:00
game-loader
cf45afe325 feat(data): add support for Hugging Face datasets 2026-01-16 14:25:32 +08:00
wang shuai
1d1b4d2913 Update README.md 2025-08-22 10:49:18 +08:00
wangshuai6
ae7df43ecc torch.compile 2025-07-03 20:23:59 +08:00
wang shuai
2c16b8f423 Update improved_dit.py 2025-07-03 20:21:04 +08:00
wangshuai6
8038c16bee disperse loss 2025-06-14 12:14:01 +08:00
wangshuai6
598f7b40a2 disperse loss 2025-06-13 10:41:31 +08:00
wangshuai6
3093a65151 disperse loss 2025-06-12 22:18:15 +08:00
wangshuai6
910b764275 pixelddt-t2i update 2025-06-05 09:56:03 +08:00
wangshuai6
26c8e54edb pixelddt-t2i update 2025-06-05 09:55:43 +08:00
wangshuai6
99d92c94e7 fix bugs(admas timedeltas) 2025-05-20 12:30:40 +08:00
15 changed files with 841 additions and 17 deletions

.gitignore vendored Normal file
View File

@@ -0,0 +1,4 @@
__pycache__/
*.py[cod]
.venv/
.idea/

View File

@@ -13,6 +13,7 @@ We decouple diffusion transformer into encoder-decoder design, and surprisingly
* We achieve **1.26 FID** on the ImageNet 256x256 benchmark with DDT-XL/2 (22en6de).
* We achieve **1.28 FID** on the ImageNet 512x512 benchmark with DDT-XL/2 (22en6de).
* As a byproduct, our DDT can reuse the encoder across adjacent steps to accelerate inference.
## Visualizations
![](./figs/teaser.png)
## Checkpoints

View File

@@ -0,0 +1,116 @@
seed_everything: true
tags:
exp: &exp ddt_butterflies_b2_256
torch_hub_dir: null
huggingface_cache_dir: null
trainer:
default_root_dir: workdirs
accelerator: auto
strategy: auto
devices: auto
num_nodes: 1
precision: bf16-mixed
logger:
class_path: lightning.pytorch.loggers.WandbLogger
init_args:
project: ddt_butterflies
name: *exp
save_dir: workdirs
num_sanity_val_steps: 0
max_steps: 200000
val_check_interval: 2000
check_val_every_n_epoch: null
log_every_n_steps: 50
deterministic: null
inference_mode: true
use_distributed_sampler: false
callbacks:
- class_path: src.callbacks.model_checkpoint.CheckpointHook
init_args:
every_n_train_steps: 10000
save_top_k: -1
save_last: true
- class_path: src.callbacks.save_images.SaveImagesHook
init_args:
save_dir: val
max_save_num: 64
model:
vae:
class_path: src.models.vae.LatentVAE
init_args:
precompute: false
weight_path: stabilityai/sd-vae-ft-ema
denoiser:
class_path: src.models.denoiser.decoupled_improved_dit.DDT
init_args:
in_channels: 4
patch_size: 2
num_groups: 12
hidden_size: &hidden_dim 768
num_blocks: 12
num_encoder_blocks: 8
num_classes: 1
conditioner:
class_path: src.models.conditioner.LabelConditioner
init_args:
null_class: 1
diffusion_trainer:
class_path: src.diffusion.stateful_flow_matching.training_repa.REPATrainer
init_args:
lognorm_t: true
encoder_weight_path: dinov2_vitb14
align_layer: 4
proj_denoiser_dim: *hidden_dim
proj_hidden_dim: *hidden_dim
proj_encoder_dim: 768
scheduler: &scheduler src.diffusion.stateful_flow_matching.scheduling.LinearScheduler
diffusion_sampler:
class_path: src.diffusion.stateful_flow_matching.sampling.EulerSampler
init_args:
num_steps: 250
guidance: 1.0
timeshift: 1.0
state_refresh_rate: 1
guidance_interval_min: 0.3
guidance_interval_max: 1.0
scheduler: *scheduler
w_scheduler: src.diffusion.stateful_flow_matching.scheduling.LinearScheduler
guidance_fn: src.diffusion.base.guidance.simple_guidance_fn
last_step: 0.04
step_fn: src.diffusion.stateful_flow_matching.sampling.ode_step_fn
ema_tracker:
class_path: src.callbacks.simple_ema.SimpleEMA
init_args:
decay: 0.9999
optimizer:
class_path: torch.optim.AdamW
init_args:
lr: 1e-3
betas:
- 0.9
- 0.95
weight_decay: 0.0
data:
train_dataset: hf_image
train_root: ./data/butterflies
test_nature_root: null
test_gen_root: null
train_image_size: 256
train_batch_size: 128
train_num_workers: 8
train_prefetch_factor: 2
train_hf_name: huggan/smithsonian_butterflies_subset
train_hf_split: train
train_hf_image_column: image
train_hf_label_column: null
train_hf_cache_dir: null
eval_max_num_instances: 256
pred_batch_size: 32
pred_num_workers: 4
pred_seeds: null
pred_selected_classes: null
num_classes: 1
latent_shape:
- 4
- 32
- 32
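As a quick sanity check, the denoiser section of this config can be instantiated directly in Python. The class path and arguments below are taken verbatim from the YAML above; bypassing the config-driven entry point like this is only an illustrative shortcut, not how training is launched.

from src.models.denoiser.decoupled_improved_dit import DDT

# Mirrors model.denoiser.init_args from the YAML above (values copied from the config).
denoiser = DDT(
    in_channels=4,
    patch_size=2,
    num_groups=12,
    hidden_size=768,
    num_blocks=12,
    num_encoder_blocks=8,
    num_classes=1,
)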

View File

@@ -6,4 +6,5 @@ jsonargparse[signatures]>=4.27.7
torchvision
timm
accelerate
gradio
gradio
datasets

View File

@@ -1,5 +1,6 @@
import lightning.pytorch as pl
from lightning.pytorch import Callback
from lightning.pytorch.loggers import WandbLogger
import os.path
@@ -10,15 +11,20 @@ from concurrent.futures import ThreadPoolExecutor
from lightning.pytorch.utilities.types import STEP_OUTPUT
from lightning_utilities.core.rank_zero import rank_zero_info
try:
import wandb
except ImportError: # pragma: no cover - optional dependency
wandb = None
def process_fn(image, path):
Image.fromarray(image).save(path)
class SaveImagesHook(Callback):
def __init__(self, save_dir="val", max_save_num=100, compressed=True):
def __init__(self, save_dir="val", max_save_num=100, compressed=True, wandb_max_save_num=None):
self.save_dir = save_dir
self.max_save_num = max_save_num
self.compressed = compressed
self.wandb_max_save_num = wandb_max_save_num
def save_start(self, target_dir):
self.target_dir = target_dir
@@ -68,6 +74,50 @@ class SaveImagesHook(Callback):
self._have_saved_num = 0
self.executor_pool = None
def _get_wandb_logger(self, trainer: "pl.Trainer"):
if isinstance(trainer.logger, WandbLogger):
return trainer.logger
if getattr(trainer, "loggers", None):
for logger in trainer.loggers:
if isinstance(logger, WandbLogger):
return logger
return None
def _log_wandb_samples(self, trainer: "pl.Trainer"):
if not trainer.is_global_zero:
return
if wandb is None:
return
if not self.samples:
return
if self.max_save_num == 0:
return
wandb_logger = self._get_wandb_logger(trainer)
if wandb_logger is None:
return
max_num = self.wandb_max_save_num if self.wandb_max_save_num is not None else self.max_save_num
if max_num <= 0:
return
images = []
remaining = max_num
for batch in self.samples:
if remaining <= 0:
break
take = min(batch.shape[0], remaining)
chunk = batch[:take]
if chunk.dtype != numpy.uint8:
chunk = numpy.clip(chunk, 0, 255).astype(numpy.uint8)
images.extend([wandb.Image(img) for img in chunk])
remaining -= take
if images:
wandb_logger.experiment.log(
{f"{self.save_dir}/samples": images},
step=trainer.global_step,
)
def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
target_dir = os.path.join(trainer.default_root_dir, self.save_dir, f"iter_{trainer.global_step}")
self.save_start(target_dir)
@@ -84,6 +134,7 @@ class SaveImagesHook(Callback):
return self.process_batch(trainer, pl_module, outputs, batch)
def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
self._log_wandb_samples(trainer)
self.save_end()
def on_predict_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
@@ -102,4 +153,4 @@ class SaveImagesHook(Callback):
return self.process_batch(trainer, pl_module, samples, batch)
def on_predict_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
self.save_end()
self.save_end()
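For reference, a minimal sketch of wiring the extended hook into a run with a W&B logger. The logger and hook values mirror the butterflies config above; the wandb_max_save_num of 16 is an arbitrary illustration, not a default.

import lightning.pytorch as pl
from lightning.pytorch.loggers import WandbLogger
from src.callbacks.save_images import SaveImagesHook

logger = WandbLogger(project="ddt_butterflies", name="ddt_butterflies_b2_256", save_dir="workdirs")
# wandb_max_save_num caps how many of the saved samples are uploaded to the W&B dashboard.
hook = SaveImagesHook(save_dir="val", max_save_num=64, wandb_max_save_num=16)
trainer = pl.Trainer(logger=logger, callbacks=[hook], max_steps=200_000, val_check_interval=2_000)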

View File

@@ -0,0 +1,43 @@
from typing import Optional
import torch
from torch.utils.data import Dataset
from torchvision.transforms import Normalize
from torchvision.transforms.functional import to_tensor
from PIL import Image
from datasets import load_dataset
from src.data.dataset.metric_dataset import CenterCrop
class HuggingFaceImageDataset(Dataset):
def __init__(
self,
name: str,
split: str = "train",
image_column: str = "image",
label_column: Optional[str] = None,
resolution: int = 256,
cache_dir: Optional[str] = None,
) -> None:
self.dataset = load_dataset(name, split=split, cache_dir=cache_dir)
self.image_column = image_column
self.label_column = label_column
self.transform = CenterCrop(resolution)
self.normalize = Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
def __len__(self) -> int:
return len(self.dataset)
def __getitem__(self, idx: int):
item = self.dataset[idx]
image = item[self.image_column]
if not isinstance(image, Image.Image):
image = Image.fromarray(image)
image = image.convert("RGB")
image = self.transform(image)
raw_image = to_tensor(image)
normalized_image = self.normalize(raw_image)
label = 0 if self.label_column is None else int(item[self.label_column])
return raw_image, normalized_image, label
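A standalone usage sketch of the new dataset. The dataset name and resolution come from the butterflies config above; the DataLoader settings are illustrative only.

from torch.utils.data import DataLoader
from src.data.dataset.hf_image import HuggingFaceImageDataset

dataset = HuggingFaceImageDataset(
    name="huggan/smithsonian_butterflies_subset",
    split="train",
    image_column="image",
    resolution=256,
)
raw, normalized, label = dataset[0]   # raw in [0, 1], normalized in [-1, 1], label defaults to 0
loader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=8)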

View File

@@ -41,23 +41,25 @@ class AdamLMSampler(BaseSampler):
self,
order: int = 2,
timeshift: float = 1.0,
guidance_interval_min: float = 0.0,
guidance_interval_max: float = 1.0,
lms_transform_fn: Callable = nop,
w_scheduler: BaseScheduler = None,
step_fn: Callable = ode_step_fn,
*args,
**kwargs
):
super().__init__(*args, **kwargs)
self.step_fn = step_fn
self.w_scheduler = w_scheduler
assert self.scheduler is not None
assert self.w_scheduler is not None or self.step_fn in [ode_step_fn, ]
assert self.step_fn in [ode_step_fn, ]
self.order = order
self.lms_transform_fn = lms_transform_fn
timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps)
timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0)
self.guidance_interval_min = guidance_interval_min
self.guidance_interval_max = guidance_interval_max
self.timesteps = shift_respace_fn(timesteps, timeshift)
self.timedeltas = self.timesteps[1:] - self.timesteps[:-1]
self._reparameterize_coeffs()
@@ -93,7 +95,11 @@ class AdamLMSampler(BaseSampler):
cfg_x = torch.cat([x, x], dim=0)
cfg_t = t_cur.repeat(2)
out = net(cfg_x, cfg_t, cfg_condition)
out = self.guidance_fn(out, self.guidance)
if t_cur[0] > self.guidance_interval_min and t_cur[0] < self.guidance_interval_max:
guidance = self.guidance
out = self.guidance_fn(out, guidance)
else:
out = self.guidance_fn(out, 1.0)
pred_trajectory.append(out)
out = torch.zeros_like(out)
order = len(self.solver_coeffs[i])
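The gating added above amounts to a few lines: classifier-free guidance is applied only when the current time falls strictly inside (guidance_interval_min, guidance_interval_max); outside that window the guidance weight falls back to 1.0. A small self-contained sketch (interval values taken from the config above):

def effective_guidance(t: float, guidance: float, g_min: float = 0.0, g_max: float = 1.0) -> float:
    # Mirrors the branch above: full guidance inside the interval, neutral guidance (1.0) outside.
    return guidance if g_min < t < g_max else 1.0

assert effective_guidance(0.5, 2.0, g_min=0.3, g_max=1.0) == 2.0  # inside the interval
assert effective_guidance(0.1, 2.0, g_min=0.3, g_max=1.0) == 1.0  # before the interval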

View File

@@ -0,0 +1,104 @@
import torch
import copy
import timm
from torch.nn import Parameter
from src.utils.no_grad import no_grad
from typing import Callable, Iterator, Tuple
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torchvision.transforms import Normalize
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler
def inverse_sigma(alpha, sigma):
return 1/sigma**2
def snr(alpha, sigma):
return alpha/sigma
def minsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, min=threshold)
def maxsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, max=threshold)
def constant(alpha, sigma):
return 1
def time_shift_fn(t, timeshift=1.0):
return t/(t+(1-t)*timeshift)
class DisperseTrainer(BaseTrainer):
def __init__(
self,
scheduler: BaseScheduler,
loss_weight_fn:Callable=constant,
feat_loss_weight: float=0.5,
lognorm_t=False,
timeshift=1.0,
align_layer=8,
temperature=1.0,
*args,
**kwargs
):
super().__init__(*args, **kwargs)
self.lognorm_t = lognorm_t
self.scheduler = scheduler
self.timeshift = timeshift
self.loss_weight_fn = loss_weight_fn
self.feat_loss_weight = feat_loss_weight
self.align_layer = align_layer
self.temperature = temperature
def _impl_trainstep(self, net, ema_net, raw_images, x, y):
batch_size, c, height, width = x.shape
if self.lognorm_t:
base_t = torch.randn((batch_size), device=x.device, dtype=torch.float32).sigmoid()
else:
base_t = torch.rand((batch_size), device=x.device, dtype=torch.float32)
t = time_shift_fn(base_t, self.timeshift).to(x.dtype)
noise = torch.randn_like(x)
alpha = self.scheduler.alpha(t)
dalpha = self.scheduler.dalpha(t)
sigma = self.scheduler.sigma(t)
dsigma = self.scheduler.dsigma(t)
x_t = alpha * x + noise * sigma
v_t = dalpha * x + dsigma * noise
src_feature = []
def forward_hook(net, input, output):
feature = output
if isinstance(feature, tuple):
feature = feature[0] # mmdit
src_feature.append(feature)
if getattr(net, "encoder", None) is not None:
handle = net.encoder.blocks[self.align_layer - 1].register_forward_hook(forward_hook)
else:
handle = net.blocks[self.align_layer - 1].register_forward_hook(forward_hook)
out = net(x_t, t, y)
handle.remove()
disperse_loss = 0.0
world_size = torch.distributed.get_world_size()
for sf in src_feature:
gathered_sf = [torch.zeros_like(sf) for _ in range(world_size)]
torch.distributed.all_gather(gathered_sf, sf)
gathered_sf = torch.cat(gathered_sf, dim=0)
sf = gathered_sf.view(batch_size * world_size, -1)
sf = sf.view(batch_size * world_size, -1)
# normalize sf
sf = sf / torch.norm(sf, dim=1, keepdim=True)
distance = torch.nn.functional.pdist(sf, p=2) ** 2
sf_disperse_distance = torch.exp(-distance / self.temperature) + 1e-5
disperse_loss += sf_disperse_distance.mean().log()
weight = self.loss_weight_fn(alpha, sigma)
fm_loss = weight*(out - v_t)**2
out = dict(
fm_loss=fm_loss.mean(),
disperse_loss=disperse_loss.mean(),
loss=fm_loss.mean() + self.feat_loss_weight*disperse_loss.mean(),
)
return out
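In isolation (single process, no all_gather), the dispersive term above reduces to: L2-normalize the features, compute pairwise squared distances, and take the log of the mean of exp(-d / temperature). A minimal sketch of just that term:

import torch

def dispersive_loss(features: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
    f = features.flatten(1)
    f = f / f.norm(dim=1, keepdim=True)                 # normalize each feature vector
    dist_sq = torch.nn.functional.pdist(f, p=2) ** 2    # pairwise squared distances
    return (torch.exp(-dist_sq / temperature) + 1e-5).mean().log()

print(dispersive_loss(torch.randn(8, 16)))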

View File

@@ -41,22 +41,24 @@ class AdamLMSampler(BaseSampler):
self,
order: int = 2,
timeshift: float = 1.0,
guidance_interval_min: float = 0.0,
guidance_interval_max: float = 1.0,
state_refresh_rate: int = 1,
lms_transform_fn: Callable = nop,
w_scheduler: BaseScheduler = None,
step_fn: Callable = ode_step_fn,
*args,
**kwargs
):
super().__init__(*args, **kwargs)
self.step_fn = step_fn
self.w_scheduler = w_scheduler
self.state_refresh_rate = state_refresh_rate
assert self.scheduler is not None
assert self.w_scheduler is not None or self.step_fn in [ode_step_fn, ]
assert self.step_fn in [ode_step_fn, ]
self.order = order
self.lms_transform_fn = lms_transform_fn
self.guidance_interval_min = guidance_interval_min
self.guidance_interval_max = guidance_interval_max
timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps)
timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0)
@@ -98,7 +100,11 @@ class AdamLMSampler(BaseSampler):
if i % self.state_refresh_rate == 0:
state = None
out, state = net(cfg_x, cfg_t, cfg_condition, state)
out = self.guidance_fn(out, self.guidance)
if t_cur[0] > self.guidance_interval_min and t_cur[0] < self.guidance_interval_max:
guidance = self.guidance
out = self.guidance_fn(out, guidance)
else:
out = self.guidance_fn(out, 1.0)
pred_trajectory.append(out)
out = torch.zeros_like(out)
order = len(self.solver_coeffs[i])
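Besides the same guidance-interval gating as in the previous file, this sampler resets its cached state every state_refresh_rate steps and reuses it otherwise (the encoder reuse mentioned in the README). A toy illustration of the refresh schedule, assuming nothing beyond the i % state_refresh_rate == 0 check shown above:

def refresh_schedule(num_steps: int, state_refresh_rate: int):
    # True marks steps where the cached state is recomputed; on the rest it is reused.
    return [i % state_refresh_rate == 0 for i in range(num_steps)]

print(refresh_schedule(8, 2))  # [True, False, True, False, True, False, True, False]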

View File

@@ -29,6 +29,11 @@ class DataModule(pl.LightningDataModule):
var_transform_engine: VARTransformEngine = None,
train_prefetch_factor=2,
train_dataset: str = None,
train_hf_name: str = None,
train_hf_split: str = "train",
train_hf_image_column: str = "image",
train_hf_label_column: str = None,
train_hf_cache_dir: str = None,
eval_batch_size=32,
eval_num_workers=4,
eval_max_num_instances=50000,
@@ -45,6 +50,11 @@ class DataModule(pl.LightningDataModule):
self.train_root = train_root
self.train_image_size = train_image_size
self.train_dataset = train_dataset
self.train_hf_name = train_hf_name
self.train_hf_split = train_hf_split
self.train_hf_image_column = train_hf_image_column
self.train_hf_label_column = train_hf_label_column
self.train_hf_cache_dir = train_hf_cache_dir
# stupid data_convert override, just to make nebular happy
self.train_batch_size = train_batch_size
self.train_num_workers = train_num_workers
@@ -101,6 +111,18 @@ class DataModule(pl.LightningDataModule):
self.train_dataset = PixImageNet512(
root=self.train_root,
)
elif self.train_dataset == "hf_image":
from src.data.dataset.hf_image import HuggingFaceImageDataset
if self.train_hf_name is None:
raise ValueError("train_hf_name must be set when train_dataset=hf_image")
self.train_dataset = HuggingFaceImageDataset(
name=self.train_hf_name,
split=self.train_hf_split,
image_column=self.train_hf_image_column,
label_column=self.train_hf_label_column,
resolution=self.train_image_size,
cache_dir=self.train_hf_cache_dir,
)
else:
raise NotImplementedError("no such dataset")
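The new constructor arguments that select this path, gathered in one place for reference (names taken from the diff above; all other DataModule arguments are omitted):

hf_data_kwargs = dict(
    train_dataset="hf_image",
    train_hf_name="huggan/smithsonian_butterflies_subset",  # required when train_dataset="hf_image"
    train_hf_split="train",
    train_hf_image_column="image",
    train_hf_label_column=None,
    train_hf_cache_dir=None,
)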

View File

@@ -34,6 +34,7 @@ class LightningModel(pl.LightningModule):
ema_tracker: Optional[EMACallable] = None,
optimizer: OptimizerCallable = None,
lr_scheduler: LRSchedulerCallable = None,
compile: bool = False
):
super().__init__()
self.vae = vae
@@ -48,6 +49,7 @@ class LightningModel(pl.LightningModule):
# self.model_loader = ModelLoader()
self._strict_loading = False
self._compile = compile
def configure_model(self) -> None:
self.trainer.strategy.barrier()
@@ -61,6 +63,11 @@ class LightningModel(pl.LightningModule):
no_grad(self.diffusion_sampler)
no_grad(self.ema_denoiser)
# add compile to speed up
if self._compile:
self.denoiser = torch.compile(self.denoiser)
self.ema_denoiser = torch.compile(self.ema_denoiser)
def configure_callbacks(self) -> Union[Sequence[Callback], Callback]:
ema_tracker = self.ema_tracker(self.denoiser, self.ema_denoiser)
return [ema_tracker]
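The new compile flag simply wraps the denoiser and EMA denoiser in torch.compile inside configure_model. The standalone sketch below shows the same call on a toy module:

import torch

model = torch.nn.Linear(8, 8)
compiled = torch.compile(model)            # same call applied to denoiser / ema_denoiser when compile=True
print(compiled(torch.randn(2, 8)).shape)   # torch.Size([2, 8])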

View File

@@ -305,4 +305,4 @@ class DDT(nn.Module):
x = self.blocks[i](x, s, pos, None)
x = self.final_layer(x, s)
x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size)
return x, s
return x, s

View File

@@ -251,11 +251,11 @@ class DiT(nn.Module):
self.initialize_weights()
self.precompute_pos = dict()
def fetch_pos(self, height, width, device, dtype):
def fetch_pos(self, height, width, device):
if (height, width) in self.precompute_pos:
return self.precompute_pos[(height, width)].to(device, dtype)
return self.precompute_pos[(height, width)].to(device)
else:
pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device, dtype)
pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
self.precompute_pos[(height, width)] = pos
return pos
@@ -289,7 +289,7 @@ class DiT(nn.Module):
B, _, H, W = x.shape
x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
x = self.x_embedder(x)
pos = self.fetch_pos(H // self.patch_size, W // self.patch_size, x.device, x.dtype)
pos = self.fetch_pos(H // self.patch_size, W // self.patch_size, x.device)
B, L, C = x.shape
t = self.t_embedder(t.view(-1)).view(B, -1, C)
y = self.y_embedder(y).view(B, 1, C)
@@ -298,4 +298,4 @@ class DiT(nn.Module):
x = block(x, condition, pos, masks[i])
x = self.final_layer(x, condition)
x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size)
return x
return x
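This change keeps fetch_pos as a per-(height, width) cache but drops the dtype cast, so the precomputed table is only moved between devices and keeps whatever dtype it was built with. A self-contained sketch of that caching pattern (the random builder below is a stand-in, not the real precompute_freqs_cis_2d):

import torch

_pos_cache = {}

def fetch_pos_like(height, width, device, dim=64):
    if (height, width) not in _pos_cache:
        _pos_cache[(height, width)] = torch.randn(height * width, dim)  # stand-in builder
    return _pos_cache[(height, width)].to(device)                       # device move only, dtype untouched

print(fetch_pos_like(16, 16, "cpu").shape)  # torch.Size([256, 64])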

View File

@@ -0,0 +1,463 @@
"""
mamba2-minimal
==============
A minimal, single-file implementation of the Mamba-2 model in PyTorch.
> **Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality**
> Authors: Tri Dao, Albert Gu
> Paper: https://arxiv.org/abs/2405.21060
"""
import json
import math
from dataclasses import dataclass
from typing import Iterable, NamedTuple, TypeAlias, cast
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from torch import LongTensor, Tensor, nn
Device: TypeAlias = str | torch.device | None
@dataclass
class Mamba2Config:
d_model: int # model dimension (D)
n_layer: int = 24 # number of Mamba-2 layers in the language model
d_state: int = 128 # state dimension (N)
d_conv: int = 4 # convolution kernel size
expand: int = 2 # expansion factor (E)
headdim: int = 64 # head dimension (P)
chunk_size: int = 64 # matrix partition size (Q)
vocab_size: int = 50277
pad_vocab_size_multiple: int = 16
def __post_init__(self):
self.d_inner = self.expand * self.d_model
assert self.d_inner % self.headdim == 0
self.nheads = self.d_inner // self.headdim
if self.vocab_size % self.pad_vocab_size_multiple != 0:
self.vocab_size += (
self.pad_vocab_size_multiple
- self.vocab_size % self.pad_vocab_size_multiple
)
class InferenceCache(NamedTuple):
conv_state: Tensor # (batch, d_inner + 2 * d_state, d_conv)
ssm_state: Tensor # (batch, nheads, headdim, d_state)
@staticmethod
def alloc(batch_size: int, args: Mamba2Config, device: Device = None):
return InferenceCache(
torch.zeros(
batch_size, args.d_inner + 2 * args.d_state, args.d_conv, device=device
),
torch.zeros(
batch_size, args.nheads, args.headdim, args.d_state, device=device
),
)
class Mamba2LMHeadModel(nn.Module):
def __init__(self, args: Mamba2Config, device: Device = None):
super().__init__()
self.args = args
self.device = device
self.backbone = nn.ModuleDict(
dict(
embedding=nn.Embedding(args.vocab_size, args.d_model, device=device),
layers=nn.ModuleList(
[
nn.ModuleDict(
dict(
mixer=Mamba2(args, device=device),
norm=RMSNorm(args.d_model, device=device),
)
)
for _ in range(args.n_layer)
]
),
norm_f=RMSNorm(args.d_model, device=device),
)
)
self.lm_head = nn.Linear(
args.d_model, args.vocab_size, bias=False, device=device
)
self.lm_head.weight = self.backbone.embedding.weight
@staticmethod
def from_pretrained(huggingface_model_id: str, device: Device = None):
from transformers.utils import CONFIG_NAME, WEIGHTS_NAME
from transformers.utils.hub import cached_file
config_path = cached_file(huggingface_model_id, CONFIG_NAME)
assert config_path, "Failed to get huggingface config file"
state_dict_path = cached_file(huggingface_model_id, WEIGHTS_NAME)
assert state_dict_path, "Failed to get huggingface state dict file"
config = json.load(open(config_path))
args = Mamba2Config(
d_model=config["d_model"],
n_layer=config["n_layer"],
vocab_size=config["vocab_size"],
pad_vocab_size_multiple=config["pad_vocab_size_multiple"],
)
map_location = "cpu" if device is None else device
state_dict = torch.load(
state_dict_path, weights_only=True, map_location=map_location, mmap=True
)
model = Mamba2LMHeadModel(args, device=device)
model.load_state_dict(state_dict)
model.eval()
return model
def forward(
self, input_ids: LongTensor, h: list[InferenceCache] | list[None] | None = None
) -> tuple[LongTensor, list[InferenceCache]]:
"""
Arguments
input_ids: (batch, seqlen) tokens from `EleutherAI/gpt-neox-20b` tokenizer
h: hidden states for the inference step. If present, the constant-time
(w.r.t. sequence length) inference path is taken; input_ids should then
have shape (batch, 1), containing the next prompt token.
Return (logits, h)
logits: (batch, seqlen, vocab_size)
h: updated inference cache after processing `input_ids`
"""
seqlen = input_ids.shape[1]
if h is None:
h = [None for _ in range(self.args.n_layer)]
x = self.backbone.embedding(input_ids)
for i, layer in enumerate(self.backbone.layers):
y, h[i] = layer.mixer(layer.norm(x), h[i])
x = y + x
x = self.backbone.norm_f(x)
logits = self.lm_head(x)
return logits[:, :seqlen], cast(list[InferenceCache], h)
def generate(
self,
input_ids: LongTensor,
max_new_length: int = 20,
temperature: float = 1.0,
top_k: int = 50,
top_p: float = 1.0,
eos_token_id: int = 0,
) -> Iterable[tuple[int, list[InferenceCache]]]:
prefix, tokens = input_ids[:-1], input_ids[-1:].unsqueeze(0)
cache_device = input_ids.device
# Process prompt
# The input sequence to forward (non-inference path) must have a length that is a multiple of chunk_size.
# We split off the excess tokens so that the first n_chunked tokens are processed in a single forward call
# and the remainder is handled by individual inference steps.
n_chunked = (prefix.shape[0] // self.args.chunk_size) * self.args.chunk_size
if n_chunked > 0:
_, h = self(prefix[:n_chunked].unsqueeze(0), None)
else:
h = [
InferenceCache.alloc(1, self.args, device=cache_device)
for _ in range(self.args.n_layer)
]
for i in range(n_chunked, prefix.shape[0]):
_, h = self(prefix[i : i + 1].unsqueeze(0), h)
# Generate
for _ in range(max_new_length):
with torch.no_grad():
out, h = self(tokens, h)
logits = out[0, -1]
if temperature != 1.0:
logits = logits / temperature
if top_k > 0:
indices_to_remove = logits < torch.topk(logits, k=top_k)[0][-1]
logits[indices_to_remove] = -torch.inf
if top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
sorted_indices_to_remove = cum_probs > top_p  # mask tokens outside the nucleus
sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()
sorted_indices_to_remove[0] = False
indices_to_remove = sorted_indices[sorted_indices_to_remove]
logits[indices_to_remove] = -torch.inf
probs = F.softmax(logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
if next_token.item() == eos_token_id:
return
tokens = next_token.unsqueeze(0)
yield cast(int, next_token.item()), h
class Mamba2(nn.Module):
def __init__(self, args: Mamba2Config, device: Device = None):
super().__init__()
self.args = args
self.device = device
# Order: (z, x, B, C, dt)
d_in_proj = 2 * args.d_inner + 2 * args.d_state + args.nheads
self.in_proj = nn.Linear(args.d_model, d_in_proj, bias=False, device=device)
conv_dim = args.d_inner + 2 * args.d_state
self.conv1d = nn.Conv1d(
in_channels=conv_dim,
out_channels=conv_dim,
kernel_size=args.d_conv,
groups=conv_dim,
padding=args.d_conv - 1,
device=device,
)
self.dt_bias = nn.Parameter(torch.empty(args.nheads, device=device))
self.A_log = nn.Parameter(torch.empty(args.nheads, device=device))
self.D = nn.Parameter(torch.empty(args.nheads, device=device))
self.norm = RMSNorm(args.d_inner, device=device)
self.out_proj = nn.Linear(args.d_inner, args.d_model, bias=False, device=device)
self.reset_parameters()
def reset_parameters(self):
# Mamba-style parameter init for stable SSM dynamics.
dt_min, dt_max = 1e-3, 1e-1
device = self.dt_bias.device
dtype = self.dt_bias.dtype
dt = torch.exp(
torch.rand(self.args.nheads, device=device, dtype=dtype)
* (math.log(dt_max) - math.log(dt_min))
+ math.log(dt_min)
)
with torch.no_grad():
self.dt_bias.copy_(torch.log(torch.expm1(dt)))
self.A_log.copy_(
torch.log(
torch.arange(
1, self.args.nheads + 1, device=device, dtype=self.A_log.dtype
)
)
)
self.D.fill_(1.0)
def forward(self, u: Tensor, h: InferenceCache | None = None):
"""
Arguments
u: (batch, seqlen, d_model) input. seqlen should be a multiple of chunk_size.
h: hidden states for inference step. Initialized to 0s if not present.
Return (y, h)
y: (batch, seqlen, d_model) output
h: updated inference cache after processing `u`
"""
if h:
return self.step(u, h)
A = -torch.exp(self.A_log) # (nheads,)
zxbcdt = self.in_proj(u) # (batch, seqlen, d_in_proj)
z, xBC, dt = torch.split(
zxbcdt,
[
self.args.d_inner,
self.args.d_inner + 2 * self.args.d_state,
self.args.nheads,
],
dim=-1,
)
dt = F.softplus(dt + self.dt_bias) # (batch, seqlen, nheads)
# Pad or truncate xBC seqlen to d_conv
conv_state = F.pad(
rearrange(xBC, "b l d -> b d l"), (self.args.d_conv - u.shape[1], 0)
)
xBC = silu(
self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)[:, : u.shape[1], :]
) # (batch, seqlen, d_inner + 2 * d_state))
x, B, C = torch.split(
xBC, [self.args.d_inner, self.args.d_state, self.args.d_state], dim=-1
)
x = rearrange(x, "b l (h p) -> b l h p", p=self.args.headdim)
y, ssm_state = ssd(
x * dt.unsqueeze(-1),
A * dt,
rearrange(B, "b l n -> b l 1 n"),
rearrange(C, "b l n -> b l 1 n"),
self.args.chunk_size,
device=x.device,
)
y = y + x * self.D.unsqueeze(-1)
y = rearrange(y, "b l h p -> b l (h p)")
y = self.norm(y, z)
y = self.out_proj(y)
h = InferenceCache(conv_state, ssm_state)
return y, h
def step(self, u: Tensor, h: InferenceCache) -> tuple[Tensor, InferenceCache]:
"""Take a single inference step for the current input and hidden state
Unlike attention-based models, RNN-based models (e.g. Mamba) do not need
to look back at all past tokens to generate a new token. Instead, a
hidden state (initialized to zeros) is updated for each input and
passed to the next inference step. This means the total inference
time is linear with respect to sequence length, rather than quadratic
as it is for attention.
Arguments
u: (batch, 1, d_model)
h: initial/running hidden state
Return (y, h)
y: (batch, 1, d_model)
h: updated hidden state
"""
assert u.shape[1] == 1, "Only one token can be decoded per inference step"
zxbcdt = self.in_proj(u.squeeze(1)) # (batch, d_in_proj)
z, xBC, dt = torch.split(
zxbcdt,
[
self.args.d_inner,
self.args.d_inner + 2 * self.args.d_state,
self.args.nheads,
],
dim=-1,
)
# Advance convolution input
h.conv_state.copy_(torch.roll(h.conv_state, shifts=-1, dims=-1))
h.conv_state[:, :, -1] = xBC
# Convolution step
xBC = torch.sum(
h.conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1
)
xBC += self.conv1d.bias
xBC = silu(xBC)
x, B, C = torch.split(
xBC, [self.args.d_inner, self.args.d_state, self.args.d_state], dim=-1
)
A = -torch.exp(self.A_log) # (nheads,)
# SSM step
dt = F.softplus(dt + self.dt_bias) # (batch, nheads)
dA = torch.exp(dt * A) # (batch, nheads)
x = rearrange(x, "b (h p) -> b h p", p=self.args.headdim)
dBx = torch.einsum("bh, bn, bhp -> bhpn", dt, B, x)
h.ssm_state.copy_(h.ssm_state * rearrange(dA, "b h -> b h 1 1") + dBx)
y = torch.einsum("bhpn, bn -> bhp", h.ssm_state, C)
y = y + rearrange(self.D, "h -> h 1") * x
y = rearrange(y, "b h p -> b (h p)")
y = self.norm(y, z)
y = self.out_proj(y)
return y.unsqueeze(1), h
def segsum(x: Tensor, device: Device = None) -> Tensor:
"""Stable segment sum calculation.
`exp(segsum(A))` produces a 1-semiseparable matrix, which is equivalent to a scalar SSM.
Source: https://github.com/state-spaces/mamba/blob/219f03c840d5a44e7d42e4e728134834fddccf45/mamba_ssm/modules/ssd_minimal.py#L23-L32
"""
if device is None:
device = x.device
T = x.size(-1)
x = repeat(x, "... d -> ... d e", e=T)
mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=device), diagonal=-1)
x = x.masked_fill(~mask, 0)
x_segsum = torch.cumsum(x, dim=-2)
mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=device), diagonal=0)
x_segsum = x_segsum.masked_fill(~mask, -torch.inf)
return x_segsum
def ssd(x, A, B, C, chunk_size, initial_states=None, device: Device = None):
"""Structed State Space Duality (SSD) - the core of Mamba-2
This is almost the exact same minimal SSD code from the blog post.
Arguments
x: (batch, seqlen, n_heads, d_head)
A: (batch, seqlen, n_heads)
B: (batch, seqlen, n_heads, d_state)
C: (batch, seqlen, n_heads, d_state)
Return
y: (batch, seqlen, n_heads, d_head)
Source
1. https://tridao.me/blog/2024/mamba2-part3-algorithm/
2. https://github.com/state-spaces/mamba/blob/219f03c840d5a44e7d42e4e728134834fddccf45/mamba_ssm/modules/ssd_minimal.py#L34-L78
"""
assert x.shape[1] % chunk_size == 0
# Rearrange into chunks
# Step 1, 2 and 4 of SSD can be computed in parallel for each chunk across devices (sequence parallel)
# This is not implemented and left as an exercise for the reader 😜
x, A, B, C = [
rearrange(m, "b (c l) ... -> b c l ...", l=chunk_size) for m in (x, A, B, C)
]
A = rearrange(A, "b c l h -> b h c l")
A_cumsum = torch.cumsum(A, dim=-1)
# 1. Compute the output for each intra-chunk (diagonal blocks)
L = torch.exp(segsum(A, device=device))
Y_diag = torch.einsum("bclhn, bcshn, bhcls, bcshp -> bclhp", C, B, L, x)
# 2. Compute the state for each intra-chunk
# (right term of low-rank factorization of off-diagonal blocks; B terms)
decay_states = torch.exp(A_cumsum[:, :, :, -1:] - A_cumsum)
states = torch.einsum("bclhn, bhcl, bclhp -> bchpn", B, decay_states, x)
# 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
# (middle term of factorization of off-diag blocks; A terms)
if initial_states is None:
initial_states = torch.zeros_like(states[:, :1])
states = torch.cat([initial_states, states], dim=1)
decay_chunk = torch.exp(segsum(F.pad(A_cumsum[:, :, :, -1], (1, 0)), device=device))
new_states = torch.einsum("bhzc, bchpn -> bzhpn", decay_chunk, states)
states, final_state = new_states[:, :-1], new_states[:, -1]
# 4. Compute state -> output conversion per chunk
# (left term of low-rank factorization of off-diagonal blocks; C terms)
state_decay_out = torch.exp(A_cumsum)
Y_off = torch.einsum("bclhn, bchpn, bhcl -> bclhp", C, states, state_decay_out)
# Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks)
Y = rearrange(Y_diag + Y_off, "b c l h p -> b (c l) h p")
return Y, final_state
class RMSNorm(nn.Module):
def __init__(self, d: int, eps: float = 1e-5, device: Device = None):
"""Gated Root Mean Square Layer Normalization
Paper: https://arxiv.org/abs/1910.07467
"""
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(d, device=device))
def forward(self, x, z=None):
if z is not None:
x = x * silu(z)
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight
def silu(x):
"""Applies the Sigmoid Linear Unit (SiLU), element-wise.
Define this manually since torch's version doesn't seem to work on MPS.
"""
return x * F.sigmoid(x)
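A short usage sketch of the block defined above, run on random features with the classes in this file in scope. Shapes follow the Mamba2Config defaults; for the non-inference path the sequence length must be a multiple of chunk_size.

import torch

args = Mamba2Config(d_model=768, n_layer=1, chunk_size=64)
block = Mamba2(args)
u = torch.randn(2, 128, 768)     # (batch, seqlen = 2 * chunk_size, d_model)
y, cache = block(u)              # full (chunked SSD) path; cache can seed later step() calls
print(y.shape)                   # torch.Size([2, 128, 768])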

Binary file not shown.