Add initial project structure including core Mamba2 logic, entry point, and uv-based dependency management.
"""
|
|
mamba2-minimal
|
|
==============
|
|
|
|
Minimal Mamba-2 implementation for sequence modeling.
|
|
Reference: https://arxiv.org/abs/2405.21060
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import math
|
|
from dataclasses import dataclass
|
|
from typing import NamedTuple, TypeAlias
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from einops import rearrange, repeat
|
|
from torch import Tensor, nn
|
|
|
|
Device: TypeAlias = str | torch.device | None
|
|
|
|
|
|
@dataclass
class Mamba2Config:
    d_model: int  # model dimension (D)
    n_layer: int = 4
    d_state: int = 64  # SSM state dimension (N)
    d_conv: int = 4  # local convolution width
    expand: int = 2  # block expansion factor (E)
    headdim: int = 32  # head dimension (P)
    chunk_size: int = 1  # SSD chunk length; sequence length must be a multiple of this

    def __post_init__(self) -> None:
        self.d_inner = self.expand * self.d_model
        if self.d_inner % self.headdim != 0:
            raise ValueError("d_inner must be divisible by headdim")
        self.nheads = self.d_inner // self.headdim

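# Illustrative arithmetic (not part of the original code): with d_model=256 and the
# defaults above, d_inner = 2 * 256 = 512 and nheads = 512 // 32 = 16. A d_model whose
# d_inner is not a multiple of headdim (e.g. d_model=100 -> d_inner=200) raises the
# ValueError in __post_init__.
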
class InferenceCache(NamedTuple):
    conv_state: Tensor
    ssm_state: Tensor

    @staticmethod
    def alloc(batch_size: int, args: Mamba2Config, device: Device = None) -> "InferenceCache":
        return InferenceCache(
            torch.zeros(
                batch_size, args.d_inner + 2 * args.d_state, args.d_conv, device=device
            ),
            torch.zeros(
                batch_size, args.nheads, args.headdim, args.d_state, device=device
            ),
        )

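# For reference (derived from alloc above; example sizes assume d_model=256 with defaults):
#   conv_state: (batch_size, d_inner + 2 * d_state, d_conv) = (b, 640, 4), a rolling window
#               of the last d_conv inputs to the depthwise convolution
#   ssm_state:  (batch_size, nheads, headdim, d_state) = (b, 16, 32, 64), the recurrent
#               SSM state, one (headdim x d_state) matrix per head
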
class Mamba2(nn.Module):
    def __init__(self, args: Mamba2Config, device: Device = None) -> None:
        super().__init__()
        self.args = args
        self.device = device

        # Fused input projection; the output is later split into z (gate), x, B, C, dt.
        d_in_proj = 2 * args.d_inner + 2 * args.d_state + args.nheads
        self.in_proj = nn.Linear(args.d_model, d_in_proj, bias=False, device=device)

        # Depthwise causal convolution over the concatenated (x, B, C) channels.
        conv_dim = args.d_inner + 2 * args.d_state
        self.conv1d = nn.Conv1d(
            in_channels=conv_dim,
            out_channels=conv_dim,
            kernel_size=args.d_conv,
            groups=conv_dim,
            padding=args.d_conv - 1,
            device=device,
        )

        self.dt_bias = nn.Parameter(torch.empty(args.nheads, device=device))
        self.A_log = nn.Parameter(torch.empty(args.nheads, device=device))
        self.D = nn.Parameter(torch.empty(args.nheads, device=device))
        self.norm = RMSNorm(args.d_inner, device=device)
        self.out_proj = nn.Linear(args.d_inner, args.d_model, bias=False, device=device)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        # Sample the initial timestep log-uniformly from [dt_min, dt_max] and store its
        # inverse softplus, so that softplus(dt_bias) recovers dt at initialization.
        dt_min, dt_max = 1e-3, 1e-1
        device = self.dt_bias.device
        dtype = self.dt_bias.dtype
        dt = torch.exp(
            torch.rand(self.args.nheads, device=device, dtype=dtype)
            * (math.log(dt_max) - math.log(dt_min))
            + math.log(dt_min)
        )
        with torch.no_grad():
            self.dt_bias.copy_(torch.log(torch.expm1(dt)))
            # A is parameterized as -exp(A_log); initialize it to -1, -2, ..., -nheads.
            self.A_log.copy_(
                torch.log(
                    torch.arange(
                        1, self.args.nheads + 1, device=device, dtype=self.A_log.dtype
                    )
                )
            )
            self.D.fill_(1.0)

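    # Consequence of the init above (illustrative): forward() computes
    # dt = softplus(dt + dt_bias), so with a zero dt input each head's timestep starts
    # in [1e-3, 1e-1], and A = -exp(A_log) starts at -1, -2, ..., -nheads.
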
    def forward(self, u: Tensor, h: InferenceCache | None = None):
        # Parallel (training / prefill) path; with a cache, fall back to single-token decoding.
        if h is not None:
            return self.step(u, h)

        A = -torch.exp(self.A_log)
        zxbcdt = self.in_proj(u)
        z, xBC, dt = torch.split(
            zxbcdt,
            [
                self.args.d_inner,
                self.args.d_inner + 2 * self.args.d_state,
                self.args.nheads,
            ],
            dim=-1,
        )
        dt = F.softplus(dt + self.dt_bias)

        # Cache the last d_conv convolution inputs for later decoding (left-pad short sequences).
        xBC_t = rearrange(xBC, "b l d -> b d l")
        if u.shape[1] >= self.args.d_conv:
            conv_state = xBC_t[:, :, -self.args.d_conv :]
        else:
            conv_state = F.pad(xBC_t, (self.args.d_conv - u.shape[1], 0))

        # Causal depthwise conv; trim the right padding back to the sequence length.
        xBC = silu(
            self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)[:, : u.shape[1], :]
        )
        x, B, C = torch.split(
            xBC, [self.args.d_inner, self.args.d_state, self.args.d_state], dim=-1
        )
        # Split channels into heads and run the chunked SSD scan.
        x = rearrange(x, "b l (h p) -> b l h p", p=self.args.headdim)
        y, ssm_state = ssd(
            x * dt.unsqueeze(-1),
            A * dt,
            rearrange(B, "b l n -> b l 1 n"),
            rearrange(C, "b l n -> b l 1 n"),
            self.args.chunk_size,
            device=x.device,
        )
        # Per-head skip connection through D, then gated RMSNorm with z as the gate.
        y = y + x * self.D.unsqueeze(-1)
        y = rearrange(y, "b l h p -> b l (h p)")
        y = self.norm(y, z)
        y = self.out_proj(y)

        h = InferenceCache(conv_state, ssm_state)
        return y, h

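    # Shape walk-through for the parallel path (illustrative, d_model=256 with defaults):
    #   u      (b, l, 256) -> zxbcdt (b, l, 1168) = (b, l, 2*512 + 2*64 + 16)
    #   x      (b, l, 16, 32) after the per-head reshape
    #   y, h   (b, l, 256) and InferenceCache((b, 640, 4), (b, 16, 32, 64))
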
    def step(self, u: Tensor, h: InferenceCache):
        assert u.shape[1] == 1, "Only one token can be decoded per inference step"

        zxbcdt = self.in_proj(u.squeeze(1))
        z, xBC, dt = torch.split(
            zxbcdt,
            [
                self.args.d_inner,
                self.args.d_inner + 2 * self.args.d_state,
                self.args.nheads,
            ],
            dim=-1,
        )

        # Advance the rolling convolution window and apply the depthwise conv manually.
        h.conv_state.copy_(torch.roll(h.conv_state, shifts=-1, dims=-1))
        h.conv_state[:, :, -1] = xBC
        xBC = torch.sum(
            h.conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1
        )
        xBC += self.conv1d.bias
        xBC = silu(xBC)

        x, B, C = torch.split(
            xBC, [self.args.d_inner, self.args.d_state, self.args.d_state], dim=-1
        )
        A = -torch.exp(self.A_log)

        # SSM recurrence: ssm_state <- exp(dt * A) * ssm_state + dt * B * x, then y = C @ state.
        dt = F.softplus(dt + self.dt_bias)
        dA = torch.exp(dt * A)
        x = rearrange(x, "b (h p) -> b h p", p=self.args.headdim)
        dBx = torch.einsum("bh, bn, bhp -> bhpn", dt, B, x)
        h.ssm_state.copy_(h.ssm_state * rearrange(dA, "b h -> b h 1 1") + dBx)
        y = torch.einsum("bhpn, bn -> bhp", h.ssm_state, C)
        y = y + rearrange(self.D, "h -> h 1") * x
        y = rearrange(y, "b h p -> b (h p)")
        y = self.norm(y, z)
        y = self.out_proj(y)

        return y.unsqueeze(1), h

def segsum(x: Tensor, device: Device = None) -> Tensor:
    """Stable segment sum: out[..., i, j] = x[..., j+1] + ... + x[..., i] for i >= j, else -inf."""
    if device is None:
        device = x.device
    T = x.size(-1)
    x = repeat(x, "... d -> ... d e", e=T)
    mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=device), diagonal=-1)
    x = x.masked_fill(~mask, 0)
    x_segsum = torch.cumsum(x, dim=-2)
    mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=device), diagonal=0)
    x_segsum = x_segsum.masked_fill(~mask, -torch.inf)
    return x_segsum

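# Worked example (illustrative): for x = [a, b, c], segsum returns
#   [[  0,  -inf, -inf],
#    [  b,     0, -inf],
#    [b+c,     c,    0]]
# so exp(segsum(A)) is the lower-triangular decay matrix L used inside ssd().
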
def ssd(x, A, B, C, chunk_size, initial_states=None, device: Device = None):
    assert x.shape[1] % chunk_size == 0

    # Split the sequence into chunks of length chunk_size.
    x, A, B, C = [
        rearrange(m, "b (c l) ... -> b c l ...", l=chunk_size) for m in (x, A, B, C)
    ]

    A = rearrange(A, "b c l h -> b h c l")
    A_cumsum = torch.cumsum(A, dim=-1)

    # 1. Intra-chunk (diagonal block) outputs.
    L = torch.exp(segsum(A, device=device))
    Y_diag = torch.einsum("bclhn, bcshn, bhcls, bcshp -> bclhp", C, B, L, x)

    # 2. State at the end of each chunk.
    decay_states = torch.exp(A_cumsum[:, :, :, -1:] - A_cumsum)
    states = torch.einsum("bclhn, bhcl, bclhp -> bchpn", B, decay_states, x)

    # 3. Inter-chunk recurrence over the chunk-level states.
    if initial_states is None:
        initial_states = torch.zeros_like(states[:, :1])
    states = torch.cat([initial_states, states], dim=1)
    decay_chunk = torch.exp(segsum(F.pad(A_cumsum[:, :, :, -1], (1, 0)), device=device))
    new_states = torch.einsum("bhzc, bchpn -> bzhpn", decay_chunk, states)
    states, final_state = new_states[:, :-1], new_states[:, -1]

    # 4. State-to-output conversion for the inter-chunk (off-diagonal) contribution.
    state_decay_out = torch.exp(A_cumsum)
    Y_off = torch.einsum("bclhn, bchpn, bhcl -> bclhp", C, states, state_decay_out)

    Y = rearrange(Y_diag + Y_off, "b c l h p -> b (c l) h p")

    return Y, final_state

class RMSNorm(nn.Module):
    """RMSNorm with an optional SiLU gate."""

    def __init__(self, d: int, eps: float = 1e-5, device: Device = None) -> None:
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d, device=device))

    def forward(self, x: Tensor, z: Tensor | None = None) -> Tensor:
        if z is not None:
            x = x * silu(z)
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight

def silu(x: Tensor) -> Tensor:
    """SiLU (swish) activation: x * sigmoid(x)."""
    return x * torch.sigmoid(x)

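if __name__ == "__main__":
    # Minimal usage sketch (illustrative only; shapes and hyperparameters are arbitrary).
    # Prefill a short sequence with the chunked SSD path, then decode one extra token
    # recurrently from the returned InferenceCache.
    torch.manual_seed(0)
    args = Mamba2Config(d_model=64, chunk_size=4)  # d_inner=128, nheads=4
    model = Mamba2(args)

    prompt = torch.randn(2, 8, args.d_model)  # (batch, seqlen, d_model); seqlen % chunk_size == 0
    y, h = model(prompt)  # parallel path; h caches conv and SSM state
    print(y.shape)  # torch.Size([2, 8, 64])

    next_token = torch.randn(2, 1, args.d_model)
    y_next, h = model(next_token, h)  # single-token recurrent step
    print(y_next.shape)  # torch.Size([2, 1, 64])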