mamba_diffusion/mamba2_minimal.py
"""
mamba2-minimal
==============
Minimal Mamba-2 implementation for sequence modeling.
Reference: https://arxiv.org/abs/2405.21060
"""
from __future__ import annotations
import math
from dataclasses import dataclass
from typing import NamedTuple, TypeAlias
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from torch import Tensor, nn
Device: TypeAlias = str | torch.device | None
@dataclass
class Mamba2Config:
    d_model: int  # model dimension (D)
    n_layer: int = 4  # number of Mamba-2 blocks to stack (not used within this file)
    d_state: int = 64  # SSM state dimension (N)
    d_conv: int = 4  # width of the local depthwise convolution
    expand: int = 2  # expansion factor, so d_inner = expand * d_model
    headdim: int = 32  # dimension of each SSM head (P)
    chunk_size: int = 1  # chunk length for the parallel SSD scan; seqlen must be a multiple of this
def __post_init__(self) -> None:
self.d_inner = self.expand * self.d_model
if self.d_inner % self.headdim != 0:
raise ValueError("d_inner must be divisible by headdim")
self.nheads = self.d_inner // self.headdim
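# A quick sanity sketch of the derived sizes (illustrative numbers only):
# Mamba2Config(d_model=64) gives d_inner = 2 * 64 = 128 and, with the default
# headdim=32, nheads = 128 // 32 = 4.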
class InferenceCache(NamedTuple):
    conv_state: Tensor  # (batch, d_inner + 2 * d_state, d_conv)
    ssm_state: Tensor  # (batch, nheads, headdim, d_state)
@staticmethod
def alloc(batch_size: int, args: Mamba2Config, device: Device = None) -> "InferenceCache":
return InferenceCache(
torch.zeros(
batch_size, args.d_inner + 2 * args.d_state, args.d_conv, device=device
),
torch.zeros(
batch_size, args.nheads, args.headdim, args.d_state, device=device
),
)
class Mamba2(nn.Module):
def __init__(self, args: Mamba2Config, device: Device = None) -> None:
super().__init__()
self.args = args
self.device = device
        # Fused input projection producing, in order: z, x, B, C, dt.
        d_in_proj = 2 * args.d_inner + 2 * args.d_state + args.nheads
        self.in_proj = nn.Linear(args.d_model, d_in_proj, bias=False, device=device)
        # Depthwise causal convolution over the concatenated x, B, C channels.
        conv_dim = args.d_inner + 2 * args.d_state
self.conv1d = nn.Conv1d(
in_channels=conv_dim,
out_channels=conv_dim,
kernel_size=args.d_conv,
groups=conv_dim,
padding=args.d_conv - 1,
device=device,
)
        # Per-head SSM parameters: dt bias, A_log = log(-A) (A is negative real), skip scale D.
        self.dt_bias = nn.Parameter(torch.empty(args.nheads, device=device))
        self.A_log = nn.Parameter(torch.empty(args.nheads, device=device))
        self.D = nn.Parameter(torch.empty(args.nheads, device=device))
self.norm = RMSNorm(args.d_inner, device=device)
self.out_proj = nn.Linear(args.d_inner, args.d_model, bias=False, device=device)
self.reset_parameters()
    def reset_parameters(self) -> None:
        # dt is sampled log-uniformly in [dt_min, dt_max]; dt_bias stores its inverse softplus.
        dt_min, dt_max = 1e-3, 1e-1
device = self.dt_bias.device
dtype = self.dt_bias.dtype
dt = torch.exp(
torch.rand(self.args.nheads, device=device, dtype=dtype)
* (math.log(dt_max) - math.log(dt_min))
+ math.log(dt_min)
)
        with torch.no_grad():
            self.dt_bias.copy_(torch.log(torch.expm1(dt)))  # inverse softplus of dt
            # A is initialized per head to -1, -2, ..., -nheads (real-valued init).
            self.A_log.copy_(
                torch.log(
                    torch.arange(
                        1, self.args.nheads + 1, device=device, dtype=self.A_log.dtype
                    )
                )
            )
            self.D.fill_(1.0)
    def forward(self, u: Tensor, h: InferenceCache | None = None):
        """Process a full sequence u of shape (batch, seqlen, d_model), where
        seqlen must be a multiple of args.chunk_size. If an InferenceCache h is
        given, a single-token decoding step is run instead (see `step`).
        Returns (y, h): the output (batch, seqlen, d_model) and the cache for
        subsequent decoding."""
        if h is not None:
            return self.step(u, h)
A = -torch.exp(self.A_log)
zxbcdt = self.in_proj(u)
z, xBC, dt = torch.split(
zxbcdt,
[
self.args.d_inner,
self.args.d_inner + 2 * self.args.d_state,
self.args.nheads,
],
dim=-1,
)
dt = F.softplus(dt + self.dt_bias)
        # Cache the last d_conv pre-activation columns of xBC for later decoding.
        xBC_t = rearrange(xBC, "b l d -> b d l")
        if u.shape[1] >= self.args.d_conv:
            conv_state = xBC_t[:, :, -self.args.d_conv :]
        else:
            conv_state = F.pad(xBC_t, (self.args.d_conv - u.shape[1], 0))
        # Depthwise causal conv; keep only the causal positions of the padded output.
        xBC = silu(
            self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)[:, : u.shape[1], :]
        )
x, B, C = torch.split(
xBC, [self.args.d_inner, self.args.d_state, self.args.d_state], dim=-1
)
x = rearrange(x, "b l (h p) -> b l h p", p=self.args.headdim)
y, ssm_state = ssd(
x * dt.unsqueeze(-1),
A * dt,
rearrange(B, "b l n -> b l 1 n"),
rearrange(C, "b l n -> b l 1 n"),
self.args.chunk_size,
device=x.device,
)
y = y + x * self.D.unsqueeze(-1)
y = rearrange(y, "b l h p -> b l (h p)")
y = self.norm(y, z)
y = self.out_proj(y)
h = InferenceCache(conv_state, ssm_state)
return y, h
    def step(self, u: Tensor, h: InferenceCache):
        """Decode a single token u of shape (batch, 1, d_model), updating the
        cached convolution and SSM states of h in place. Returns (y, h) with y
        of shape (batch, 1, d_model)."""
        assert u.shape[1] == 1, "Only one token can be decoded per inference step"
zxbcdt = self.in_proj(u.squeeze(1))
z, xBC, dt = torch.split(
zxbcdt,
[
self.args.d_inner,
self.args.d_inner + 2 * self.args.d_state,
self.args.nheads,
],
dim=-1,
)
        # Advance the rolling convolution window and apply the depthwise conv weights.
        h.conv_state.copy_(torch.roll(h.conv_state, shifts=-1, dims=-1))
        h.conv_state[:, :, -1] = xBC
        xBC = torch.sum(
            h.conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1
        )
        xBC += self.conv1d.bias
        xBC = silu(xBC)
x, B, C = torch.split(
xBC, [self.args.d_inner, self.args.d_state, self.args.d_state], dim=-1
)
        A = -torch.exp(self.A_log)
        dt = F.softplus(dt + self.dt_bias)
        dA = torch.exp(dt * A)
        x = rearrange(x, "b (h p) -> b h p", p=self.args.headdim)
        # Discretized SSM recurrence: state <- exp(dt * A) * state + dt * outer(x, B).
        dBx = torch.einsum("bh, bn, bhp -> bhpn", dt, B, x)
        h.ssm_state.copy_(h.ssm_state * rearrange(dA, "b h -> b h 1 1") + dBx)
        y = torch.einsum("bhpn, bn -> bhp", h.ssm_state, C)
y = y + rearrange(self.D, "h -> h 1") * x
y = rearrange(y, "b h p -> b (h p)")
y = self.norm(y, z)
y = self.out_proj(y)
return y.unsqueeze(1), h
def segsum(x: Tensor, device: Device = None) -> Tensor:
    """Stable "segment sum": for x of shape (..., T), return a (..., T, T) tensor
    whose (i, j) entry is x[..., j+1] + ... + x[..., i] for i >= j and -inf above
    the diagonal, so that exp(segsum(x)) is the lower-triangular decay matrix
    used by ssd."""
    if device is None:
        device = x.device
T = x.size(-1)
x = repeat(x, "... d -> ... d e", e=T)
mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=device), diagonal=-1)
x = x.masked_fill(~mask, 0)
x_segsum = torch.cumsum(x, dim=-2)
mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=device), diagonal=0)
x_segsum = x_segsum.masked_fill(~mask, -torch.inf)
return x_segsum
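# For example, for a length-3 input [a0, a1, a2], segsum returns (for each
# leading batch/head/chunk index)
#     [[    0,  -inf, -inf],
#      [   a1,     0, -inf],
#      [a1+a2,    a2,    0]]
# so exp(segsum(.)) gives the lower-triangular matrix of cumulative decays from
# position j to position i that ssd applies below.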
def ssd(x, A, B, C, chunk_size, initial_states=None, device: Device = None):
    """Structured state-space duality (SSD): chunkwise-parallel evaluation of the
    selective SSM.

    Shapes: x (batch, seqlen, nheads, headdim), A (batch, seqlen, nheads),
    B and C (batch, seqlen, 1, d_state) with the size-1 group dim broadcast over
    heads. seqlen must be divisible by chunk_size. Returns the output Y of shape
    (batch, seqlen, nheads, headdim) and the final SSM state of shape
    (batch, nheads, headdim, d_state)."""
    assert x.shape[1] % chunk_size == 0
    # 1. Split the sequence into chunks and compute per-chunk cumulative log-decays.
    x, A, B, C = [
        rearrange(m, "b (c l) ... -> b c l ...", l=chunk_size) for m in (x, A, B, C)
    ]
    A = rearrange(A, "b c l h -> b h c l")
    A_cumsum = torch.cumsum(A, dim=-1)
    # 2. Intra-chunk (diagonal block) outputs.
    L = torch.exp(segsum(A, device=device))
    Y_diag = torch.einsum("bclhn, bcshn, bhcls, bcshp -> bclhp", C, B, L, x)
    # 3. State contributed by each chunk, decayed to the chunk boundary.
    decay_states = torch.exp(A_cumsum[:, :, :, -1:] - A_cumsum)
    states = torch.einsum("bclhn, bhcl, bclhp -> bchpn", B, decay_states, x)
    # 4. Inter-chunk recurrence over the per-chunk states.
    if initial_states is None:
        initial_states = torch.zeros_like(states[:, :1])
    states = torch.cat([initial_states, states], dim=1)
    decay_chunk = torch.exp(segsum(F.pad(A_cumsum[:, :, :, -1], (1, 0)), device=device))
    new_states = torch.einsum("bhzc, bchpn -> bzhpn", decay_chunk, states)
    states, final_state = new_states[:, :-1], new_states[:, -1]
    # 5. Off-diagonal contribution: convert chunk-entry states into outputs.
    state_decay_out = torch.exp(A_cumsum)
    Y_off = torch.einsum("bclhn, bchpn, bhcl -> bclhp", C, states, state_decay_out)
Y = rearrange(Y_diag + Y_off, "b c l h p -> b (c l) h p")
return Y, final_state
class RMSNorm(nn.Module):
    """RMS normalization with an optional gate: when z is given, x is first
    multiplied by silu(z) before being normalized (gated RMSNorm)."""
    def __init__(self, d: int, eps: float = 1e-5, device: Device = None) -> None:
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(d, device=device))
def forward(self, x: Tensor, z: Tensor | None = None) -> Tensor:
if z is not None:
x = x * silu(z)
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight
def silu(x: Tensor) -> Tensor:
    """SiLU / swish activation, written out explicitly (equivalent to F.silu)."""
    return x * torch.sigmoid(x)
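

if __name__ == "__main__":
    # Quick smoke-test sketch with illustrative sizes (these values are assumptions,
    # not settings taken from elsewhere in this repo). It checks that the chunked
    # parallel forward pass and the token-by-token `step` decode agree.
    torch.manual_seed(0)
    cfg = Mamba2Config(d_model=64, d_state=64, d_conv=4, expand=2, headdim=32, chunk_size=16)
    layer = Mamba2(cfg)
    batch, seqlen = 2, 32  # seqlen must be a multiple of cfg.chunk_size
    u = torch.randn(batch, seqlen, cfg.d_model)

    with torch.no_grad():
        y_parallel, _ = layer(u)  # (batch, seqlen, d_model)

        # Decode the same sequence one token at a time from a freshly allocated cache.
        h = InferenceCache.alloc(batch, cfg)
        outs = []
        for t in range(seqlen):
            y_t, h = layer(u[:, t : t + 1], h)
            outs.append(y_t)
        y_sequential = torch.cat(outs, dim=1)

    print("output shape:", tuple(y_parallel.shape))
    print("max |parallel - sequential|:", (y_parallel - y_sequential).abs().max().item())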