Compare commits

...

3 Commits

5 changed files with 564 additions and 3 deletions

4
.gitignore vendored Normal file
View File

@@ -0,0 +1,4 @@
__pycache__/
*.py[cod]
.venv/
.idea/

View File

@@ -1,5 +1,6 @@
import lightning.pytorch as pl import lightning.pytorch as pl
from lightning.pytorch import Callback from lightning.pytorch import Callback
from lightning.pytorch.loggers import WandbLogger
import os.path import os.path
@@ -10,15 +11,20 @@ from concurrent.futures import ThreadPoolExecutor
from lightning.pytorch.utilities.types import STEP_OUTPUT from lightning.pytorch.utilities.types import STEP_OUTPUT
from lightning_utilities.core.rank_zero import rank_zero_info from lightning_utilities.core.rank_zero import rank_zero_info
try:
import wandb
except ImportError: # pragma: no cover - optional dependency
wandb = None
def process_fn(image, path): def process_fn(image, path):
Image.fromarray(image).save(path) Image.fromarray(image).save(path)
class SaveImagesHook(Callback): class SaveImagesHook(Callback):
def __init__(self, save_dir="val", max_save_num=100, compressed=True): def __init__(self, save_dir="val", max_save_num=100, compressed=True, wandb_max_save_num=None):
self.save_dir = save_dir self.save_dir = save_dir
self.max_save_num = max_save_num self.max_save_num = max_save_num
self.compressed = compressed self.compressed = compressed
self.wandb_max_save_num = wandb_max_save_num
def save_start(self, target_dir): def save_start(self, target_dir):
self.target_dir = target_dir self.target_dir = target_dir
@@ -68,6 +74,50 @@ class SaveImagesHook(Callback):
self._have_saved_num = 0 self._have_saved_num = 0
self.executor_pool = None self.executor_pool = None
def _get_wandb_logger(self, trainer: "pl.Trainer"):
if isinstance(trainer.logger, WandbLogger):
return trainer.logger
if getattr(trainer, "loggers", None):
for logger in trainer.loggers:
if isinstance(logger, WandbLogger):
return logger
return None
def _log_wandb_samples(self, trainer: "pl.Trainer"):
if not trainer.is_global_zero:
return
if wandb is None:
return
if not self.samples:
return
if self.max_save_num == 0:
return
wandb_logger = self._get_wandb_logger(trainer)
if wandb_logger is None:
return
max_num = self.wandb_max_save_num if self.wandb_max_save_num is not None else self.max_save_num
if max_num <= 0:
return
images = []
remaining = max_num
for batch in self.samples:
if remaining <= 0:
break
take = min(batch.shape[0], remaining)
chunk = batch[:take]
if chunk.dtype != numpy.uint8:
chunk = numpy.clip(chunk, 0, 255).astype(numpy.uint8)
images.extend([wandb.Image(img) for img in chunk])
remaining -= take
if images:
wandb_logger.experiment.log(
{f"{self.save_dir}/samples": images},
step=trainer.global_step,
)
def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
target_dir = os.path.join(trainer.default_root_dir, self.save_dir, f"iter_{trainer.global_step}") target_dir = os.path.join(trainer.default_root_dir, self.save_dir, f"iter_{trainer.global_step}")
self.save_start(target_dir) self.save_start(target_dir)
@@ -84,6 +134,7 @@ class SaveImagesHook(Callback):
return self.process_batch(trainer, pl_module, outputs, batch) return self.process_batch(trainer, pl_module, outputs, batch)
def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
self._log_wandb_samples(trainer)
self.save_end() self.save_end()
def on_predict_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: def on_predict_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:

View File

@@ -0,0 +1,43 @@
from typing import Optional
import torch
from torch.utils.data import Dataset
from torchvision.transforms import Normalize
from torchvision.transforms.functional import to_tensor
from PIL import Image
from datasets import load_dataset
from src.data.dataset.metric_dataset import CenterCrop
class HuggingFaceImageDataset(Dataset):
def __init__(
self,
name: str,
split: str = "train",
image_column: str = "image",
label_column: Optional[str] = None,
resolution: int = 256,
cache_dir: Optional[str] = None,
) -> None:
self.dataset = load_dataset(name, split=split, cache_dir=cache_dir)
self.image_column = image_column
self.label_column = label_column
self.transform = CenterCrop(resolution)
self.normalize = Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
def __len__(self) -> int:
return len(self.dataset)
def __getitem__(self, idx: int):
item = self.dataset[idx]
image = item[self.image_column]
if not isinstance(image, Image.Image):
image = Image.fromarray(image)
image = image.convert("RGB")
image = self.transform(image)
raw_image = to_tensor(image)
normalized_image = self.normalize(raw_image)
label = 0 if self.label_column is None else int(item[self.label_column])
return raw_image, normalized_image, label

View File

@@ -0,0 +1,463 @@
"""
mamba2-minimal
==============
A minimal, single-file implementation of the Mamba-2 model in PyTorch.
> **Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality**
> Authors: Tri Dao, Albert Gu
> Paper: https://arxiv.org/abs/2405.21060
"""
import json
import math
from dataclasses import dataclass
from typing import Iterable, NamedTuple, TypeAlias, cast
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from torch import LongTensor, Tensor, nn
Device: TypeAlias = str | torch.device | None
@dataclass
class Mamba2Config:
d_model: int # model dimension (D)
n_layer: int = 24 # number of Mamba-2 layers in the language model
d_state: int = 128 # state dimension (N)
d_conv: int = 4 # convolution kernel size
expand: int = 2 # expansion factor (E)
headdim: int = 64 # head dimension (P)
chunk_size: int = 64 # matrix partition size (Q)
vocab_size: int = 50277
pad_vocab_size_multiple: int = 16
def __post_init__(self):
self.d_inner = self.expand * self.d_model
assert self.d_inner % self.headdim == 0
self.nheads = self.d_inner // self.headdim
if self.vocab_size % self.pad_vocab_size_multiple != 0:
self.vocab_size += (
self.pad_vocab_size_multiple
- self.vocab_size % self.pad_vocab_size_multiple
)
class InferenceCache(NamedTuple):
conv_state: Tensor # (batch, d_inner + 2 * d_state, d_conv)
ssm_state: Tensor # (batch, nheads, headdim, d_state)
@staticmethod
def alloc(batch_size: int, args: Mamba2Config, device: Device = None):
return InferenceCache(
torch.zeros(
batch_size, args.d_inner + 2 * args.d_state, args.d_conv, device=device
),
torch.zeros(
batch_size, args.nheads, args.headdim, args.d_state, device=device
),
)
class Mamba2LMHeadModel(nn.Module):
def __init__(self, args: Mamba2Config, device: Device = None):
super().__init__()
self.args = args
self.device = device
self.backbone = nn.ModuleDict(
dict(
embedding=nn.Embedding(args.vocab_size, args.d_model, device=device),
layers=nn.ModuleList(
[
nn.ModuleDict(
dict(
mixer=Mamba2(args, device=device),
norm=RMSNorm(args.d_model, device=device),
)
)
for _ in range(args.n_layer)
]
),
norm_f=RMSNorm(args.d_model, device=device),
)
)
self.lm_head = nn.Linear(
args.d_model, args.vocab_size, bias=False, device=device
)
self.lm_head.weight = self.backbone.embedding.weight
@staticmethod
def from_pretrained(huggingface_model_id: str, device: Device = None):
from transformers.utils import CONFIG_NAME, WEIGHTS_NAME
from transformers.utils.hub import cached_file
config_path = cached_file(huggingface_model_id, CONFIG_NAME)
assert config_path, "Failed to get huggingface config file"
state_dict_path = cached_file(huggingface_model_id, WEIGHTS_NAME)
assert state_dict_path, "Failed to get huggingface state dict file"
config = json.load(open(config_path))
args = Mamba2Config(
d_model=config["d_model"],
n_layer=config["n_layer"],
vocab_size=config["vocab_size"],
pad_vocab_size_multiple=config["pad_vocab_size_multiple"],
)
map_location = "cpu" if device is None else device
state_dict = torch.load(
state_dict_path, weights_only=True, map_location=map_location, mmap=True
)
model = Mamba2LMHeadModel(args, device=device)
model.load_state_dict(state_dict)
model.eval()
return model
def forward(
self, input_ids: LongTensor, h: list[InferenceCache] | list[None] | None = None
) -> tuple[LongTensor, list[InferenceCache]]:
"""
Arguments
input_ids: (batch, seqlen) tokens from `EleutherAI/gpt-neox-20b` tokenizer
h: hidden states for inference step. If present the constant-time
(wrt sequence length) inference path will be taken, input_ids
should have shape (batch, 1) containing the next batch of prompt
token.
Return (logits, h)
logits: (batch, seqlen, vocab_size)
h: updated inference cache after processing `input_ids`
"""
seqlen = input_ids.shape[1]
if h is None:
h = [None for _ in range(self.args.n_layer)]
x = self.backbone.embedding(input_ids)
for i, layer in enumerate(self.backbone.layers):
y, h[i] = layer.mixer(layer.norm(x), h[i])
x = y + x
x = self.backbone.norm_f(x)
logits = self.lm_head(x)
return logits[:, :seqlen], cast(list[InferenceCache], h)
def generate(
self,
input_ids: LongTensor,
max_new_length: int = 20,
temperature: float = 1.0,
top_k: int = 50,
top_p: float = 1.0,
eos_token_id: int = 0,
) -> Iterable[tuple[int, list[InferenceCache]]]:
prefix, tokens = input_ids[:-1], input_ids[-1:].unsqueeze(0)
cache_device = input_ids.device
# Process prompt
# The input sequence to forward (non-inference path) must have length multiple that of chunk_size.
# We split out excess tokens so that n_chunked tokens can be processed by one forward call and
# process the rest in multiple inference steps.
n_chunked = (prefix.shape[0] // self.args.chunk_size) * self.args.chunk_size
if n_chunked > 0:
_, h = self(prefix[:n_chunked].unsqueeze(0), None)
else:
h = [
InferenceCache.alloc(1, self.args, device=cache_device)
for _ in range(self.args.n_layer)
]
for i in range(n_chunked, prefix.shape[0]):
_, h = self(prefix[i : i + 1].unsqueeze(0), h)
# Generate
for _ in range(max_new_length):
with torch.no_grad():
out, h = self(tokens, h)
logits = out[0, -1]
if temperature != 1.0:
logits = logits / temperature
if top_k > 0:
indices_to_remove = logits < torch.topk(logits, k=top_k)[0][-1]
logits[indices_to_remove] = -torch.inf
if top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
sorted_indices_to_remove = cum_probs > 0.5
sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()
sorted_indices_to_remove[0] = False
indices_to_remove = sorted_indices[sorted_indices_to_remove]
logits[indices_to_remove] = -torch.inf
probs = F.softmax(logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
if next_token.item() == eos_token_id:
return
tokens = next_token.unsqueeze(0)
yield cast(int, next_token.item()), h
class Mamba2(nn.Module):
def __init__(self, args: Mamba2Config, device: Device = None):
super().__init__()
self.args = args
self.device = device
# Order: (z, x, B, C, dt)
d_in_proj = 2 * args.d_inner + 2 * args.d_state + args.nheads
self.in_proj = nn.Linear(args.d_model, d_in_proj, bias=False, device=device)
conv_dim = args.d_inner + 2 * args.d_state
self.conv1d = nn.Conv1d(
in_channels=conv_dim,
out_channels=conv_dim,
kernel_size=args.d_conv,
groups=conv_dim,
padding=args.d_conv - 1,
device=device,
)
self.dt_bias = nn.Parameter(torch.empty(args.nheads, device=device))
self.A_log = nn.Parameter(torch.empty(args.nheads, device=device))
self.D = nn.Parameter(torch.empty(args.nheads, device=device))
self.norm = RMSNorm(args.d_inner, device=device)
self.out_proj = nn.Linear(args.d_inner, args.d_model, bias=False, device=device)
self.reset_parameters()
def reset_parameters(self):
# Mamba-style parameter init for stable SSM dynamics.
dt_min, dt_max = 1e-3, 1e-1
device = self.dt_bias.device
dtype = self.dt_bias.dtype
dt = torch.exp(
torch.rand(self.args.nheads, device=device, dtype=dtype)
* (math.log(dt_max) - math.log(dt_min))
+ math.log(dt_min)
)
with torch.no_grad():
self.dt_bias.copy_(torch.log(torch.expm1(dt)))
self.A_log.copy_(
torch.log(
torch.arange(
1, self.args.nheads + 1, device=device, dtype=self.A_log.dtype
)
)
)
self.D.fill_(1.0)
def forward(self, u: Tensor, h: InferenceCache | None = None):
"""
Arguments
u: (batch, seqlen, d_model) input. seqlen should be a multiple of chunk_size.
h: hidden states for inference step. Initialized to 0s if not present.
Return (y, h)
y: (batch, seqlen, d_model) output
h: updated inference cache after processing `u`
"""
if h:
return self.step(u, h)
A = -torch.exp(self.A_log) # (nheads,)
zxbcdt = self.in_proj(u) # (batch, seqlen, d_in_proj)
z, xBC, dt = torch.split(
zxbcdt,
[
self.args.d_inner,
self.args.d_inner + 2 * self.args.d_state,
self.args.nheads,
],
dim=-1,
)
dt = F.softplus(dt + self.dt_bias) # (batch, seqlen, nheads)
# Pad or truncate xBC seqlen to d_conv
conv_state = F.pad(
rearrange(xBC, "b l d -> b d l"), (self.args.d_conv - u.shape[1], 0)
)
xBC = silu(
self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)[:, : u.shape[1], :]
) # (batch, seqlen, d_inner + 2 * d_state))
x, B, C = torch.split(
xBC, [self.args.d_inner, self.args.d_state, self.args.d_state], dim=-1
)
x = rearrange(x, "b l (h p) -> b l h p", p=self.args.headdim)
y, ssm_state = ssd(
x * dt.unsqueeze(-1),
A * dt,
rearrange(B, "b l n -> b l 1 n"),
rearrange(C, "b l n -> b l 1 n"),
self.args.chunk_size,
device=x.device,
)
y = y + x * self.D.unsqueeze(-1)
y = rearrange(y, "b l h p -> b l (h p)")
y = self.norm(y, z)
y = self.out_proj(y)
h = InferenceCache(conv_state, ssm_state)
return y, h
def step(self, u: Tensor, h: InferenceCache) -> tuple[Tensor, InferenceCache]:
"""Take a single inference step for the current input and hidden state
Unlike attention-based models, RNN-based models (eg Mamba) does not need
to look back at all the past tokens to generate a new token. Instead a
hidden state (initialized to 0s initially) is updated for each input and
passed to the next inference step. This means that the total inference
time is linear with respect to the sequence length instead of quadratic
in attention's case.
Arguments
u: (batch, 1, d_model)
h: initial/running hidden state
Return (y, h)
y: (batch, 1, d_model)
h: updated hidden state
"""
assert u.shape[1] == 1, "Only one token can be decoded per inference step"
zxbcdt = self.in_proj(u.squeeze(1)) # (batch, d_in_proj)
z, xBC, dt = torch.split(
zxbcdt,
[
self.args.d_inner,
self.args.d_inner + 2 * self.args.d_state,
self.args.nheads,
],
dim=-1,
)
# Advance convolution input
h.conv_state.copy_(torch.roll(h.conv_state, shifts=-1, dims=-1))
h.conv_state[:, :, -1] = xBC
# Convolution step
xBC = torch.sum(
h.conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1
)
xBC += self.conv1d.bias
xBC = silu(xBC)
x, B, C = torch.split(
xBC, [self.args.d_inner, self.args.d_state, self.args.d_state], dim=-1
)
A = -torch.exp(self.A_log) # (nheads,)
# SSM step
dt = F.softplus(dt + self.dt_bias) # (batch, nheads)
dA = torch.exp(dt * A) # (batch, nheads)
x = rearrange(x, "b (h p) -> b h p", p=self.args.headdim)
dBx = torch.einsum("bh, bn, bhp -> bhpn", dt, B, x)
h.ssm_state.copy_(h.ssm_state * rearrange(dA, "b h -> b h 1 1") + dBx)
y = torch.einsum("bhpn, bn -> bhp", h.ssm_state, C)
y = y + rearrange(self.D, "h -> h 1") * x
y = rearrange(y, "b h p -> b (h p)")
y = self.norm(y, z)
y = self.out_proj(y)
return y.unsqueeze(1), h
def segsum(x: Tensor, device: Device = None) -> Tensor:
"""Stable segment sum calculation.
`exp(segsum(A))` produces a 1-semiseparable matrix, which is equivalent to a scalar SSM.
Source: https://github.com/state-spaces/mamba/blob/219f03c840d5a44e7d42e4e728134834fddccf45/mamba_ssm/modules/ssd_minimal.py#L23-L32
"""
if device is None:
device = x.device
T = x.size(-1)
x = repeat(x, "... d -> ... d e", e=T)
mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=device), diagonal=-1)
x = x.masked_fill(~mask, 0)
x_segsum = torch.cumsum(x, dim=-2)
mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=device), diagonal=0)
x_segsum = x_segsum.masked_fill(~mask, -torch.inf)
return x_segsum
def ssd(x, A, B, C, chunk_size, initial_states=None, device: Device = None):
"""Structed State Space Duality (SSD) - the core of Mamba-2
This is almost the exact same minimal SSD code from the blog post.
Arguments
x: (batch, seqlen, n_heads, d_head)
A: (batch, seqlen, n_heads)
B: (batch, seqlen, n_heads, d_state)
C: (batch, seqlen, n_heads, d_state)
Return
y: (batch, seqlen, n_heads, d_head)
Source
1. https://tridao.me/blog/2024/mamba2-part3-algorithm/
2. https://github.com/state-spaces/mamba/blob/219f03c840d5a44e7d42e4e728134834fddccf45/mamba_ssm/modules/ssd_minimal.py#L34-L78
"""
assert x.shape[1] % chunk_size == 0
# Rearrange into chunks
# Step 1, 2 and 4 of SSD can be computed in parallel for each chunk across devices (sequence parallel)
# This is not implemented and left as an exercise for the reader 😜
x, A, B, C = [
rearrange(m, "b (c l) ... -> b c l ...", l=chunk_size) for m in (x, A, B, C)
]
A = rearrange(A, "b c l h -> b h c l")
A_cumsum = torch.cumsum(A, dim=-1)
# 1. Compute the output for each intra-chunk (diagonal blocks)
L = torch.exp(segsum(A, device=device))
Y_diag = torch.einsum("bclhn, bcshn, bhcls, bcshp -> bclhp", C, B, L, x)
# 2. Compute the state for each intra-chunk
# (right term of low-rank factorization of off-diagonal blocks; B terms)
decay_states = torch.exp(A_cumsum[:, :, :, -1:] - A_cumsum)
states = torch.einsum("bclhn, bhcl, bclhp -> bchpn", B, decay_states, x)
# 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
# (middle term of factorization of off-diag blocks; A terms)
if initial_states is None:
initial_states = torch.zeros_like(states[:, :1])
states = torch.cat([initial_states, states], dim=1)
decay_chunk = torch.exp(segsum(F.pad(A_cumsum[:, :, :, -1], (1, 0)), device=device))
new_states = torch.einsum("bhzc, bchpn -> bzhpn", decay_chunk, states)
states, final_state = new_states[:, :-1], new_states[:, -1]
# 4. Compute state -> output conversion per chunk
# (left term of low-rank factorization of off-diagonal blocks; C terms)
state_decay_out = torch.exp(A_cumsum)
Y_off = torch.einsum("bclhn, bchpn, bhcl -> bclhp", C, states, state_decay_out)
# Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks)
Y = rearrange(Y_diag + Y_off, "b c l h p -> b (c l) h p")
return Y, final_state
class RMSNorm(nn.Module):
def __init__(self, d: int, eps: float = 1e-5, device: Device = None):
"""Gated Root Mean Square Layer Normalization
Paper: https://arxiv.org/abs/1910.07467
"""
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(d, device=device))
def forward(self, x, z=None):
if z is not None:
x = x * silu(z)
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight
def silu(x):
"""Applies the Sigmoid Linear Unit (SiLU), element-wise.
Define this manually since torch's version doesn't seem to work on MPS.
"""
return x * F.sigmoid(x)