""" mamba2-minimal ============== Minimal Mamba-2 implementation for sequence modeling. Reference: https://arxiv.org/abs/2405.21060 """ from __future__ import annotations import math from dataclasses import dataclass from typing import NamedTuple, TypeAlias import torch import torch.nn.functional as F from einops import rearrange, repeat from torch import Tensor, nn Device: TypeAlias = str | torch.device | None @dataclass class Mamba2Config: d_model: int n_layer: int = 4 d_state: int = 64 d_conv: int = 4 expand: int = 2 headdim: int = 32 chunk_size: int = 1 def __post_init__(self) -> None: self.d_inner = self.expand * self.d_model if self.d_inner % self.headdim != 0: raise ValueError("d_inner must be divisible by headdim") self.nheads = self.d_inner // self.headdim class InferenceCache(NamedTuple): conv_state: Tensor ssm_state: Tensor @staticmethod def alloc(batch_size: int, args: Mamba2Config, device: Device = None) -> "InferenceCache": return InferenceCache( torch.zeros( batch_size, args.d_inner + 2 * args.d_state, args.d_conv, device=device ), torch.zeros( batch_size, args.nheads, args.headdim, args.d_state, device=device ), ) class Mamba2(nn.Module): def __init__(self, args: Mamba2Config, device: Device = None) -> None: super().__init__() self.args = args self.device = device d_in_proj = 2 * args.d_inner + 2 * args.d_state + args.nheads self.in_proj = nn.Linear(args.d_model, d_in_proj, bias=False, device=device) conv_dim = args.d_inner + 2 * args.d_state self.conv1d = nn.Conv1d( in_channels=conv_dim, out_channels=conv_dim, kernel_size=args.d_conv, groups=conv_dim, padding=args.d_conv - 1, device=device, ) self.dt_bias = nn.Parameter(torch.empty(args.nheads, device=device)) self.A_log = nn.Parameter(torch.empty(args.nheads, device=device)) self.D = nn.Parameter(torch.empty(args.nheads, device=device)) self.norm = RMSNorm(args.d_inner, device=device) self.out_proj = nn.Linear(args.d_inner, args.d_model, bias=False, device=device) self.reset_parameters() def reset_parameters(self) -> None: dt_min, dt_max = 1e-3, 1e-1 device = self.dt_bias.device dtype = self.dt_bias.dtype dt = torch.exp( torch.rand(self.args.nheads, device=device, dtype=dtype) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min) ) with torch.no_grad(): self.dt_bias.copy_(torch.log(torch.expm1(dt))) self.A_log.copy_( torch.log( torch.arange( 1, self.args.nheads + 1, device=device, dtype=self.A_log.dtype ) ) ) self.D.fill_(1.0) def forward(self, u: Tensor, h: InferenceCache | None = None): if h is not None: return self.step(u, h) A = -torch.exp(self.A_log) zxbcdt = self.in_proj(u) z, xBC, dt = torch.split( zxbcdt, [ self.args.d_inner, self.args.d_inner + 2 * self.args.d_state, self.args.nheads, ], dim=-1, ) dt = F.softplus(dt + self.dt_bias) xBC_t = rearrange(xBC, "b l d -> b d l") if u.shape[1] >= self.args.d_conv: conv_state = xBC_t[:, :, -self.args.d_conv :] else: conv_state = F.pad(xBC_t, (self.args.d_conv - u.shape[1], 0)) xBC = silu( self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)[:, : u.shape[1], :] ) x, B, C = torch.split( xBC, [self.args.d_inner, self.args.d_state, self.args.d_state], dim=-1 ) x = rearrange(x, "b l (h p) -> b l h p", p=self.args.headdim) y, ssm_state = ssd( x * dt.unsqueeze(-1), A * dt, rearrange(B, "b l n -> b l 1 n"), rearrange(C, "b l n -> b l 1 n"), self.args.chunk_size, device=x.device, ) y = y + x * self.D.unsqueeze(-1) y = rearrange(y, "b l h p -> b l (h p)") y = self.norm(y, z) y = self.out_proj(y) h = InferenceCache(conv_state, ssm_state) return y, h def step(self, u: Tensor, h: 
InferenceCache): assert u.shape[1] == 1, "Only one token can be decoded per inference step" zxbcdt = self.in_proj(u.squeeze(1)) z, xBC, dt = torch.split( zxbcdt, [ self.args.d_inner, self.args.d_inner + 2 * self.args.d_state, self.args.nheads, ], dim=-1, ) h.conv_state.copy_(torch.roll(h.conv_state, shifts=-1, dims=-1)) h.conv_state[:, :, -1] = xBC xBC = torch.sum( h.conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1 ) xBC += self.conv1d.bias xBC = silu(xBC) x, B, C = torch.split( xBC, [self.args.d_inner, self.args.d_state, self.args.d_state], dim=-1 ) A = -torch.exp(self.A_log) dt = F.softplus(dt + self.dt_bias) dA = torch.exp(dt * A) x = rearrange(x, "b (h p) -> b h p", p=self.args.headdim) dBx = torch.einsum("bh, bn, bhp -> bhpn", dt, B, x) h.ssm_state.copy_(h.ssm_state * rearrange(dA, "b h -> b h 1 1") + dBx) y = torch.einsum("bhpn, bn -> bhp", h.ssm_state, C) y = y + rearrange(self.D, "h -> h 1") * x y = rearrange(y, "b h p -> b (h p)") y = self.norm(y, z) y = self.out_proj(y) return y.unsqueeze(1), h def segsum(x: Tensor, device: Device = None) -> Tensor: if device is None: device = x.device T = x.size(-1) x = repeat(x, "... d -> ... d e", e=T) mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=device), diagonal=-1) x = x.masked_fill(~mask, 0) x_segsum = torch.cumsum(x, dim=-2) mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=device), diagonal=0) x_segsum = x_segsum.masked_fill(~mask, -torch.inf) return x_segsum def ssd(x, A, B, C, chunk_size, initial_states=None, device: Device = None): assert x.shape[1] % chunk_size == 0 x, A, B, C = [ rearrange(m, "b (c l) ... -> b c l ...", l=chunk_size) for m in (x, A, B, C) ] A = rearrange(A, "b c l h -> b h c l") A_cumsum = torch.cumsum(A, dim=-1) L = torch.exp(segsum(A, device=device)) Y_diag = torch.einsum("bclhn, bcshn, bhcls, bcshp -> bclhp", C, B, L, x) decay_states = torch.exp(A_cumsum[:, :, :, -1:] - A_cumsum) states = torch.einsum("bclhn, bhcl, bclhp -> bchpn", B, decay_states, x) if initial_states is None: initial_states = torch.zeros_like(states[:, :1]) states = torch.cat([initial_states, states], dim=1) decay_chunk = torch.exp(segsum(F.pad(A_cumsum[:, :, :, -1], (1, 0)), device=device)) new_states = torch.einsum("bhzc, bchpn -> bzhpn", decay_chunk, states) states, final_state = new_states[:, :-1], new_states[:, -1] state_decay_out = torch.exp(A_cumsum) Y_off = torch.einsum("bclhn, bchpn, bhcl -> bclhp", C, states, state_decay_out) Y = rearrange(Y_diag + Y_off, "b c l h p -> b (c l) h p") return Y, final_state class RMSNorm(nn.Module): def __init__(self, d: int, eps: float = 1e-5, device: Device = None) -> None: super().__init__() self.eps = eps self.weight = nn.Parameter(torch.ones(d, device=device)) def forward(self, x: Tensor, z: Tensor | None = None) -> Tensor: if z is not None: x = x * silu(z) return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight def silu(x: Tensor) -> Tensor: return x * torch.sigmoid(x)
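

# Usage sketch (not part of the implementation above): a minimal smoke test
# that runs the chunked parallel pass over a short prompt and then decodes one
# extra token recurrently with the returned InferenceCache. The hyperparameters
# and random inputs below are illustrative assumptions only.
if __name__ == "__main__":
    torch.manual_seed(0)

    # d_inner = expand * d_model = 128, headdim = 16 -> nheads = 8
    args = Mamba2Config(d_model=64, headdim=16, chunk_size=4)
    model = Mamba2(args)

    batch, prompt_len = 2, 8  # prompt_len must be a multiple of chunk_size
    u = torch.randn(batch, prompt_len, args.d_model)

    # Parallel (SSD) pass over the whole prompt.
    y, cache = model(u)
    print("prompt output:", tuple(y.shape))  # (2, 8, 64)

    # Recurrent single-token decoding, reusing the cached conv/SSM state.
    next_token = torch.randn(batch, 1, args.d_model)
    y_step, cache = model(next_token, cache)
    print("step output:", tuple(y_step.shape))  # (2, 1, 64)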