Compare commits
13 Commits
3693640ca3
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
50b1f76263 | ||
|
|
b9c3720eba | ||
|
|
6bf32b08fd | ||
|
|
cf45afe325 | ||
|
|
1d1b4d2913 | ||
|
|
ae7df43ecc | ||
|
|
2c16b8f423 | ||
|
|
8038c16bee | ||
|
|
598f7b40a2 | ||
|
|
3093a65151 | ||
|
|
910b764275 | ||
|
|
26c8e54edb | ||
|
|
99d92c94e7 |
4
.gitignore
vendored
Normal file
4
.gitignore
vendored
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
.venv/
|
||||||
|
.idea/
|
||||||
@@ -13,6 +13,7 @@ We decouple diffusion transformer into encoder-decoder design, and surprisingly
|
|||||||
* We achieves **1.26 FID** on ImageNet256x256 Benchmark with DDT-XL/2(22en6de).
|
* We achieves **1.26 FID** on ImageNet256x256 Benchmark with DDT-XL/2(22en6de).
|
||||||
* We achieves **1.28 FID** on ImageNet512x512 Benchmark with DDT-XL/2(22en6de).
|
* We achieves **1.28 FID** on ImageNet512x512 Benchmark with DDT-XL/2(22en6de).
|
||||||
* As a byproduct, our DDT can reuse encoder among adjacent steps to accelerate inference.
|
* As a byproduct, our DDT can reuse encoder among adjacent steps to accelerate inference.
|
||||||
|
|
||||||
## Visualizations
|
## Visualizations
|
||||||

|

|
||||||
## Checkpoints
|
## Checkpoints
|
||||||
|
|||||||
116
configs/ddt_butterflies_b2_256.yaml
Normal file
116
configs/ddt_butterflies_b2_256.yaml
Normal file
@@ -0,0 +1,116 @@
|
|||||||
|
seed_everything: true
|
||||||
|
tags:
|
||||||
|
exp: &exp ddt_butterflies_b2_256
|
||||||
|
torch_hub_dir: null
|
||||||
|
huggingface_cache_dir: null
|
||||||
|
trainer:
|
||||||
|
default_root_dir: workdirs
|
||||||
|
accelerator: auto
|
||||||
|
strategy: auto
|
||||||
|
devices: auto
|
||||||
|
num_nodes: 1
|
||||||
|
precision: bf16-mixed
|
||||||
|
logger:
|
||||||
|
class_path: lightning.pytorch.loggers.WandbLogger
|
||||||
|
init_args:
|
||||||
|
project: ddt_butterflies
|
||||||
|
name: *exp
|
||||||
|
save_dir: workdirs
|
||||||
|
num_sanity_val_steps: 0
|
||||||
|
max_steps: 200000
|
||||||
|
val_check_interval: 2000
|
||||||
|
check_val_every_n_epoch: null
|
||||||
|
log_every_n_steps: 50
|
||||||
|
deterministic: null
|
||||||
|
inference_mode: true
|
||||||
|
use_distributed_sampler: false
|
||||||
|
callbacks:
|
||||||
|
- class_path: src.callbacks.model_checkpoint.CheckpointHook
|
||||||
|
init_args:
|
||||||
|
every_n_train_steps: 10000
|
||||||
|
save_top_k: -1
|
||||||
|
save_last: true
|
||||||
|
- class_path: src.callbacks.save_images.SaveImagesHook
|
||||||
|
init_args:
|
||||||
|
save_dir: val
|
||||||
|
max_save_num: 64
|
||||||
|
model:
|
||||||
|
vae:
|
||||||
|
class_path: src.models.vae.LatentVAE
|
||||||
|
init_args:
|
||||||
|
precompute: false
|
||||||
|
weight_path: stabilityai/sd-vae-ft-ema
|
||||||
|
denoiser:
|
||||||
|
class_path: src.models.denoiser.decoupled_improved_dit.DDT
|
||||||
|
init_args:
|
||||||
|
in_channels: 4
|
||||||
|
patch_size: 2
|
||||||
|
num_groups: 12
|
||||||
|
hidden_size: &hidden_dim 768
|
||||||
|
num_blocks: 12
|
||||||
|
num_encoder_blocks: 8
|
||||||
|
num_classes: 1
|
||||||
|
conditioner:
|
||||||
|
class_path: src.models.conditioner.LabelConditioner
|
||||||
|
init_args:
|
||||||
|
null_class: 1
|
||||||
|
diffusion_trainer:
|
||||||
|
class_path: src.diffusion.stateful_flow_matching.training_repa.REPATrainer
|
||||||
|
init_args:
|
||||||
|
lognorm_t: true
|
||||||
|
encoder_weight_path: dinov2_vitb14
|
||||||
|
align_layer: 4
|
||||||
|
proj_denoiser_dim: *hidden_dim
|
||||||
|
proj_hidden_dim: *hidden_dim
|
||||||
|
proj_encoder_dim: 768
|
||||||
|
scheduler: &scheduler src.diffusion.stateful_flow_matching.scheduling.LinearScheduler
|
||||||
|
diffusion_sampler:
|
||||||
|
class_path: src.diffusion.stateful_flow_matching.sampling.EulerSampler
|
||||||
|
init_args:
|
||||||
|
num_steps: 250
|
||||||
|
guidance: 1.0
|
||||||
|
timeshift: 1.0
|
||||||
|
state_refresh_rate: 1
|
||||||
|
guidance_interval_min: 0.3
|
||||||
|
guidance_interval_max: 1.0
|
||||||
|
scheduler: *scheduler
|
||||||
|
w_scheduler: src.diffusion.stateful_flow_matching.scheduling.LinearScheduler
|
||||||
|
guidance_fn: src.diffusion.base.guidance.simple_guidance_fn
|
||||||
|
last_step: 0.04
|
||||||
|
step_fn: src.diffusion.stateful_flow_matching.sampling.ode_step_fn
|
||||||
|
ema_tracker:
|
||||||
|
class_path: src.callbacks.simple_ema.SimpleEMA
|
||||||
|
init_args:
|
||||||
|
decay: 0.9999
|
||||||
|
optimizer:
|
||||||
|
class_path: torch.optim.AdamW
|
||||||
|
init_args:
|
||||||
|
lr: 1e-3
|
||||||
|
betas:
|
||||||
|
- 0.9
|
||||||
|
- 0.95
|
||||||
|
weight_decay: 0.0
|
||||||
|
data:
|
||||||
|
train_dataset: hf_image
|
||||||
|
train_root: ./data/butterflies
|
||||||
|
test_nature_root: null
|
||||||
|
test_gen_root: null
|
||||||
|
train_image_size: 256
|
||||||
|
train_batch_size: 128
|
||||||
|
train_num_workers: 8
|
||||||
|
train_prefetch_factor: 2
|
||||||
|
train_hf_name: huggan/smithsonian_butterflies_subset
|
||||||
|
train_hf_split: train
|
||||||
|
train_hf_image_column: image
|
||||||
|
train_hf_label_column: null
|
||||||
|
train_hf_cache_dir: null
|
||||||
|
eval_max_num_instances: 256
|
||||||
|
pred_batch_size: 32
|
||||||
|
pred_num_workers: 4
|
||||||
|
pred_seeds: null
|
||||||
|
pred_selected_classes: null
|
||||||
|
num_classes: 1
|
||||||
|
latent_shape:
|
||||||
|
- 4
|
||||||
|
- 32
|
||||||
|
- 32
|
||||||
@@ -6,4 +6,5 @@ jsonargparse[signatures]>=4.27.7
|
|||||||
torchvision
|
torchvision
|
||||||
timm
|
timm
|
||||||
accelerate
|
accelerate
|
||||||
gradio
|
gradio
|
||||||
|
datasets
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
import lightning.pytorch as pl
|
import lightning.pytorch as pl
|
||||||
from lightning.pytorch import Callback
|
from lightning.pytorch import Callback
|
||||||
|
from lightning.pytorch.loggers import WandbLogger
|
||||||
|
|
||||||
|
|
||||||
import os.path
|
import os.path
|
||||||
@@ -10,15 +11,20 @@ from concurrent.futures import ThreadPoolExecutor
|
|||||||
|
|
||||||
from lightning.pytorch.utilities.types import STEP_OUTPUT
|
from lightning.pytorch.utilities.types import STEP_OUTPUT
|
||||||
from lightning_utilities.core.rank_zero import rank_zero_info
|
from lightning_utilities.core.rank_zero import rank_zero_info
|
||||||
|
try:
|
||||||
|
import wandb
|
||||||
|
except ImportError: # pragma: no cover - optional dependency
|
||||||
|
wandb = None
|
||||||
|
|
||||||
def process_fn(image, path):
|
def process_fn(image, path):
|
||||||
Image.fromarray(image).save(path)
|
Image.fromarray(image).save(path)
|
||||||
|
|
||||||
class SaveImagesHook(Callback):
|
class SaveImagesHook(Callback):
|
||||||
def __init__(self, save_dir="val", max_save_num=100, compressed=True):
|
def __init__(self, save_dir="val", max_save_num=100, compressed=True, wandb_max_save_num=None):
|
||||||
self.save_dir = save_dir
|
self.save_dir = save_dir
|
||||||
self.max_save_num = max_save_num
|
self.max_save_num = max_save_num
|
||||||
self.compressed = compressed
|
self.compressed = compressed
|
||||||
|
self.wandb_max_save_num = wandb_max_save_num
|
||||||
|
|
||||||
def save_start(self, target_dir):
|
def save_start(self, target_dir):
|
||||||
self.target_dir = target_dir
|
self.target_dir = target_dir
|
||||||
@@ -68,6 +74,50 @@ class SaveImagesHook(Callback):
|
|||||||
self._have_saved_num = 0
|
self._have_saved_num = 0
|
||||||
self.executor_pool = None
|
self.executor_pool = None
|
||||||
|
|
||||||
|
def _get_wandb_logger(self, trainer: "pl.Trainer"):
|
||||||
|
if isinstance(trainer.logger, WandbLogger):
|
||||||
|
return trainer.logger
|
||||||
|
if getattr(trainer, "loggers", None):
|
||||||
|
for logger in trainer.loggers:
|
||||||
|
if isinstance(logger, WandbLogger):
|
||||||
|
return logger
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _log_wandb_samples(self, trainer: "pl.Trainer"):
|
||||||
|
if not trainer.is_global_zero:
|
||||||
|
return
|
||||||
|
if wandb is None:
|
||||||
|
return
|
||||||
|
if not self.samples:
|
||||||
|
return
|
||||||
|
if self.max_save_num == 0:
|
||||||
|
return
|
||||||
|
wandb_logger = self._get_wandb_logger(trainer)
|
||||||
|
if wandb_logger is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
max_num = self.wandb_max_save_num if self.wandb_max_save_num is not None else self.max_save_num
|
||||||
|
if max_num <= 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
images = []
|
||||||
|
remaining = max_num
|
||||||
|
for batch in self.samples:
|
||||||
|
if remaining <= 0:
|
||||||
|
break
|
||||||
|
take = min(batch.shape[0], remaining)
|
||||||
|
chunk = batch[:take]
|
||||||
|
if chunk.dtype != numpy.uint8:
|
||||||
|
chunk = numpy.clip(chunk, 0, 255).astype(numpy.uint8)
|
||||||
|
images.extend([wandb.Image(img) for img in chunk])
|
||||||
|
remaining -= take
|
||||||
|
|
||||||
|
if images:
|
||||||
|
wandb_logger.experiment.log(
|
||||||
|
{f"{self.save_dir}/samples": images},
|
||||||
|
step=trainer.global_step,
|
||||||
|
)
|
||||||
|
|
||||||
def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
|
def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
|
||||||
target_dir = os.path.join(trainer.default_root_dir, self.save_dir, f"iter_{trainer.global_step}")
|
target_dir = os.path.join(trainer.default_root_dir, self.save_dir, f"iter_{trainer.global_step}")
|
||||||
self.save_start(target_dir)
|
self.save_start(target_dir)
|
||||||
@@ -84,6 +134,7 @@ class SaveImagesHook(Callback):
|
|||||||
return self.process_batch(trainer, pl_module, outputs, batch)
|
return self.process_batch(trainer, pl_module, outputs, batch)
|
||||||
|
|
||||||
def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
|
def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
|
||||||
|
self._log_wandb_samples(trainer)
|
||||||
self.save_end()
|
self.save_end()
|
||||||
|
|
||||||
def on_predict_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
|
def on_predict_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
|
||||||
@@ -102,4 +153,4 @@ class SaveImagesHook(Callback):
|
|||||||
return self.process_batch(trainer, pl_module, samples, batch)
|
return self.process_batch(trainer, pl_module, samples, batch)
|
||||||
|
|
||||||
def on_predict_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
|
def on_predict_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
|
||||||
self.save_end()
|
self.save_end()
|
||||||
|
|||||||
43
src/data/dataset/hf_image.py
Normal file
43
src/data/dataset/hf_image.py
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
from torchvision.transforms import Normalize
|
||||||
|
from torchvision.transforms.functional import to_tensor
|
||||||
|
from PIL import Image
|
||||||
|
from datasets import load_dataset
|
||||||
|
|
||||||
|
from src.data.dataset.metric_dataset import CenterCrop
|
||||||
|
|
||||||
|
|
||||||
|
class HuggingFaceImageDataset(Dataset):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
split: str = "train",
|
||||||
|
image_column: str = "image",
|
||||||
|
label_column: Optional[str] = None,
|
||||||
|
resolution: int = 256,
|
||||||
|
cache_dir: Optional[str] = None,
|
||||||
|
) -> None:
|
||||||
|
self.dataset = load_dataset(name, split=split, cache_dir=cache_dir)
|
||||||
|
self.image_column = image_column
|
||||||
|
self.label_column = label_column
|
||||||
|
self.transform = CenterCrop(resolution)
|
||||||
|
self.normalize = Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
return len(self.dataset)
|
||||||
|
|
||||||
|
def __getitem__(self, idx: int):
|
||||||
|
item = self.dataset[idx]
|
||||||
|
image = item[self.image_column]
|
||||||
|
if not isinstance(image, Image.Image):
|
||||||
|
image = Image.fromarray(image)
|
||||||
|
image = image.convert("RGB")
|
||||||
|
image = self.transform(image)
|
||||||
|
|
||||||
|
raw_image = to_tensor(image)
|
||||||
|
normalized_image = self.normalize(raw_image)
|
||||||
|
label = 0 if self.label_column is None else int(item[self.label_column])
|
||||||
|
return raw_image, normalized_image, label
|
||||||
@@ -41,23 +41,25 @@ class AdamLMSampler(BaseSampler):
|
|||||||
self,
|
self,
|
||||||
order: int = 2,
|
order: int = 2,
|
||||||
timeshift: float = 1.0,
|
timeshift: float = 1.0,
|
||||||
|
guidance_interval_min: float = 0.0,
|
||||||
|
guidance_interval_max: float = 1.0,
|
||||||
lms_transform_fn: Callable = nop,
|
lms_transform_fn: Callable = nop,
|
||||||
w_scheduler: BaseScheduler = None,
|
|
||||||
step_fn: Callable = ode_step_fn,
|
step_fn: Callable = ode_step_fn,
|
||||||
*args,
|
*args,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.step_fn = step_fn
|
self.step_fn = step_fn
|
||||||
self.w_scheduler = w_scheduler
|
|
||||||
|
|
||||||
assert self.scheduler is not None
|
assert self.scheduler is not None
|
||||||
assert self.w_scheduler is not None or self.step_fn in [ode_step_fn, ]
|
assert self.step_fn in [ode_step_fn, ]
|
||||||
self.order = order
|
self.order = order
|
||||||
self.lms_transform_fn = lms_transform_fn
|
self.lms_transform_fn = lms_transform_fn
|
||||||
|
|
||||||
timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps)
|
timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps)
|
||||||
timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0)
|
timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0)
|
||||||
|
self.guidance_interval_min = guidance_interval_min
|
||||||
|
self.guidance_interval_max = guidance_interval_max
|
||||||
self.timesteps = shift_respace_fn(timesteps, timeshift)
|
self.timesteps = shift_respace_fn(timesteps, timeshift)
|
||||||
self.timedeltas = self.timesteps[1:] - self.timesteps[:-1]
|
self.timedeltas = self.timesteps[1:] - self.timesteps[:-1]
|
||||||
self._reparameterize_coeffs()
|
self._reparameterize_coeffs()
|
||||||
@@ -93,7 +95,11 @@ class AdamLMSampler(BaseSampler):
|
|||||||
cfg_x = torch.cat([x, x], dim=0)
|
cfg_x = torch.cat([x, x], dim=0)
|
||||||
cfg_t = t_cur.repeat(2)
|
cfg_t = t_cur.repeat(2)
|
||||||
out = net(cfg_x, cfg_t, cfg_condition)
|
out = net(cfg_x, cfg_t, cfg_condition)
|
||||||
out = self.guidance_fn(out, self.guidance)
|
if t_cur[0] > self.guidance_interval_min and t_cur[0] < self.guidance_interval_max:
|
||||||
|
guidance = self.guidance
|
||||||
|
out = self.guidance_fn(out, guidance)
|
||||||
|
else:
|
||||||
|
out = self.guidance_fn(out, 1.0)
|
||||||
pred_trajectory.append(out)
|
pred_trajectory.append(out)
|
||||||
out = torch.zeros_like(out)
|
out = torch.zeros_like(out)
|
||||||
order = len(self.solver_coeffs[i])
|
order = len(self.solver_coeffs[i])
|
||||||
|
|||||||
104
src/diffusion/flow_matching/training_disperse.py
Normal file
104
src/diffusion/flow_matching/training_disperse.py
Normal file
@@ -0,0 +1,104 @@
|
|||||||
|
import torch
|
||||||
|
import copy
|
||||||
|
import timm
|
||||||
|
from torch.nn import Parameter
|
||||||
|
|
||||||
|
from src.utils.no_grad import no_grad
|
||||||
|
from typing import Callable, Iterator, Tuple
|
||||||
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||||
|
from torchvision.transforms import Normalize
|
||||||
|
from src.diffusion.base.training import *
|
||||||
|
from src.diffusion.base.scheduling import BaseScheduler
|
||||||
|
|
||||||
|
def inverse_sigma(alpha, sigma):
|
||||||
|
return 1/sigma**2
|
||||||
|
def snr(alpha, sigma):
|
||||||
|
return alpha/sigma
|
||||||
|
def minsnr(alpha, sigma, threshold=5):
|
||||||
|
return torch.clip(alpha/sigma, min=threshold)
|
||||||
|
def maxsnr(alpha, sigma, threshold=5):
|
||||||
|
return torch.clip(alpha/sigma, max=threshold)
|
||||||
|
def constant(alpha, sigma):
|
||||||
|
return 1
|
||||||
|
|
||||||
|
|
||||||
|
def time_shift_fn(t, timeshift=1.0):
|
||||||
|
return t/(t+(1-t)*timeshift)
|
||||||
|
|
||||||
|
|
||||||
|
class DisperseTrainer(BaseTrainer):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
scheduler: BaseScheduler,
|
||||||
|
loss_weight_fn:Callable=constant,
|
||||||
|
feat_loss_weight: float=0.5,
|
||||||
|
lognorm_t=False,
|
||||||
|
timeshift=1.0,
|
||||||
|
align_layer=8,
|
||||||
|
temperature=1.0,
|
||||||
|
*args,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.lognorm_t = lognorm_t
|
||||||
|
self.scheduler = scheduler
|
||||||
|
self.timeshift = timeshift
|
||||||
|
self.loss_weight_fn = loss_weight_fn
|
||||||
|
self.feat_loss_weight = feat_loss_weight
|
||||||
|
self.align_layer = align_layer
|
||||||
|
self.temperature = temperature
|
||||||
|
|
||||||
|
def _impl_trainstep(self, net, ema_net, raw_images, x, y):
|
||||||
|
batch_size, c, height, width = x.shape
|
||||||
|
if self.lognorm_t:
|
||||||
|
base_t = torch.randn((batch_size), device=x.device, dtype=torch.float32).sigmoid()
|
||||||
|
else:
|
||||||
|
base_t = torch.rand((batch_size), device=x.device, dtype=torch.float32)
|
||||||
|
t = time_shift_fn(base_t, self.timeshift).to(x.dtype)
|
||||||
|
noise = torch.randn_like(x)
|
||||||
|
alpha = self.scheduler.alpha(t)
|
||||||
|
dalpha = self.scheduler.dalpha(t)
|
||||||
|
sigma = self.scheduler.sigma(t)
|
||||||
|
dsigma = self.scheduler.dsigma(t)
|
||||||
|
|
||||||
|
x_t = alpha * x + noise * sigma
|
||||||
|
v_t = dalpha * x + dsigma * noise
|
||||||
|
|
||||||
|
src_feature = []
|
||||||
|
def forward_hook(net, input, output):
|
||||||
|
feature = output
|
||||||
|
if isinstance(feature, tuple):
|
||||||
|
feature = feature[0] # mmdit
|
||||||
|
src_feature.append(feature)
|
||||||
|
|
||||||
|
if getattr(net, "encoder", None) is not None:
|
||||||
|
handle = net.encoder.blocks[self.align_layer - 1].register_forward_hook(forward_hook)
|
||||||
|
else:
|
||||||
|
handle = net.blocks[self.align_layer - 1].register_forward_hook(forward_hook)
|
||||||
|
|
||||||
|
out = net(x_t, t, y)
|
||||||
|
handle.remove()
|
||||||
|
disperse_loss = 0.0
|
||||||
|
world_size = torch.distributed.get_world_size()
|
||||||
|
for sf in src_feature:
|
||||||
|
gathered_sf = [torch.zeros_like(sf) for _ in range(world_size)]
|
||||||
|
torch.distributed.all_gather(gathered_sf, sf)
|
||||||
|
gathered_sf = torch.cat(gathered_sf, dim=0)
|
||||||
|
sf = gathered_sf.view(batch_size * world_size, -1)
|
||||||
|
sf = sf.view(batch_size * world_size, -1)
|
||||||
|
# normalize sf
|
||||||
|
sf = sf / torch.norm(sf, dim=1, keepdim=True)
|
||||||
|
distance = torch.nn.functional.pdist(sf, p=2) ** 2
|
||||||
|
sf_disperse_distance = torch.exp(-distance / self.temperature) + 1e-5
|
||||||
|
disperse_loss += sf_disperse_distance.mean().log()
|
||||||
|
|
||||||
|
|
||||||
|
weight = self.loss_weight_fn(alpha, sigma)
|
||||||
|
fm_loss = weight*(out - v_t)**2
|
||||||
|
|
||||||
|
out = dict(
|
||||||
|
fm_loss=fm_loss.mean(),
|
||||||
|
disperse_loss=disperse_loss.mean(),
|
||||||
|
loss=fm_loss.mean() + self.feat_loss_weight*disperse_loss.mean(),
|
||||||
|
)
|
||||||
|
return out
|
||||||
@@ -41,22 +41,24 @@ class AdamLMSampler(BaseSampler):
|
|||||||
self,
|
self,
|
||||||
order: int = 2,
|
order: int = 2,
|
||||||
timeshift: float = 1.0,
|
timeshift: float = 1.0,
|
||||||
|
guidance_interval_min: float = 0.0,
|
||||||
|
guidance_interval_max: float = 1.0,
|
||||||
state_refresh_rate: int = 1,
|
state_refresh_rate: int = 1,
|
||||||
lms_transform_fn: Callable = nop,
|
lms_transform_fn: Callable = nop,
|
||||||
w_scheduler: BaseScheduler = None,
|
|
||||||
step_fn: Callable = ode_step_fn,
|
step_fn: Callable = ode_step_fn,
|
||||||
*args,
|
*args,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.step_fn = step_fn
|
self.step_fn = step_fn
|
||||||
self.w_scheduler = w_scheduler
|
|
||||||
self.state_refresh_rate = state_refresh_rate
|
self.state_refresh_rate = state_refresh_rate
|
||||||
|
|
||||||
assert self.scheduler is not None
|
assert self.scheduler is not None
|
||||||
assert self.w_scheduler is not None or self.step_fn in [ode_step_fn, ]
|
assert self.step_fn in [ode_step_fn, ]
|
||||||
self.order = order
|
self.order = order
|
||||||
self.lms_transform_fn = lms_transform_fn
|
self.lms_transform_fn = lms_transform_fn
|
||||||
|
self.guidance_interval_min = guidance_interval_min
|
||||||
|
self.guidance_interval_max = guidance_interval_max
|
||||||
|
|
||||||
timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps)
|
timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps)
|
||||||
timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0)
|
timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0)
|
||||||
@@ -98,7 +100,11 @@ class AdamLMSampler(BaseSampler):
|
|||||||
if i % self.state_refresh_rate == 0:
|
if i % self.state_refresh_rate == 0:
|
||||||
state = None
|
state = None
|
||||||
out, state = net(cfg_x, cfg_t, cfg_condition, state)
|
out, state = net(cfg_x, cfg_t, cfg_condition, state)
|
||||||
out = self.guidance_fn(out, self.guidance)
|
if t_cur[0] > self.guidance_interval_min and t_cur[0] < self.guidance_interval_max:
|
||||||
|
guidance = self.guidance
|
||||||
|
out = self.guidance_fn(out, guidance)
|
||||||
|
else:
|
||||||
|
out = self.guidance_fn(out, 1.0)
|
||||||
pred_trajectory.append(out)
|
pred_trajectory.append(out)
|
||||||
out = torch.zeros_like(out)
|
out = torch.zeros_like(out)
|
||||||
order = len(self.solver_coeffs[i])
|
order = len(self.solver_coeffs[i])
|
||||||
|
|||||||
@@ -29,6 +29,11 @@ class DataModule(pl.LightningDataModule):
|
|||||||
var_transform_engine: VARTransformEngine = None,
|
var_transform_engine: VARTransformEngine = None,
|
||||||
train_prefetch_factor=2,
|
train_prefetch_factor=2,
|
||||||
train_dataset: str = None,
|
train_dataset: str = None,
|
||||||
|
train_hf_name: str = None,
|
||||||
|
train_hf_split: str = "train",
|
||||||
|
train_hf_image_column: str = "image",
|
||||||
|
train_hf_label_column: str = None,
|
||||||
|
train_hf_cache_dir: str = None,
|
||||||
eval_batch_size=32,
|
eval_batch_size=32,
|
||||||
eval_num_workers=4,
|
eval_num_workers=4,
|
||||||
eval_max_num_instances=50000,
|
eval_max_num_instances=50000,
|
||||||
@@ -45,6 +50,11 @@ class DataModule(pl.LightningDataModule):
|
|||||||
self.train_root = train_root
|
self.train_root = train_root
|
||||||
self.train_image_size = train_image_size
|
self.train_image_size = train_image_size
|
||||||
self.train_dataset = train_dataset
|
self.train_dataset = train_dataset
|
||||||
|
self.train_hf_name = train_hf_name
|
||||||
|
self.train_hf_split = train_hf_split
|
||||||
|
self.train_hf_image_column = train_hf_image_column
|
||||||
|
self.train_hf_label_column = train_hf_label_column
|
||||||
|
self.train_hf_cache_dir = train_hf_cache_dir
|
||||||
# stupid data_convert override, just to make nebular happy
|
# stupid data_convert override, just to make nebular happy
|
||||||
self.train_batch_size = train_batch_size
|
self.train_batch_size = train_batch_size
|
||||||
self.train_num_workers = train_num_workers
|
self.train_num_workers = train_num_workers
|
||||||
@@ -101,6 +111,18 @@ class DataModule(pl.LightningDataModule):
|
|||||||
self.train_dataset = PixImageNet512(
|
self.train_dataset = PixImageNet512(
|
||||||
root=self.train_root,
|
root=self.train_root,
|
||||||
)
|
)
|
||||||
|
elif self.train_dataset == "hf_image":
|
||||||
|
from src.data.dataset.hf_image import HuggingFaceImageDataset
|
||||||
|
if self.train_hf_name is None:
|
||||||
|
raise ValueError("train_hf_name must be set when train_dataset=hf_image")
|
||||||
|
self.train_dataset = HuggingFaceImageDataset(
|
||||||
|
name=self.train_hf_name,
|
||||||
|
split=self.train_hf_split,
|
||||||
|
image_column=self.train_hf_image_column,
|
||||||
|
label_column=self.train_hf_label_column,
|
||||||
|
resolution=self.train_image_size,
|
||||||
|
cache_dir=self.train_hf_cache_dir,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("no such dataset")
|
raise NotImplementedError("no such dataset")
|
||||||
|
|
||||||
|
|||||||
@@ -34,6 +34,7 @@ class LightningModel(pl.LightningModule):
|
|||||||
ema_tracker: Optional[EMACallable] = None,
|
ema_tracker: Optional[EMACallable] = None,
|
||||||
optimizer: OptimizerCallable = None,
|
optimizer: OptimizerCallable = None,
|
||||||
lr_scheduler: LRSchedulerCallable = None,
|
lr_scheduler: LRSchedulerCallable = None,
|
||||||
|
compile: bool = False
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.vae = vae
|
self.vae = vae
|
||||||
@@ -48,6 +49,7 @@ class LightningModel(pl.LightningModule):
|
|||||||
# self.model_loader = ModelLoader()
|
# self.model_loader = ModelLoader()
|
||||||
|
|
||||||
self._strict_loading = False
|
self._strict_loading = False
|
||||||
|
self._compile = compile
|
||||||
|
|
||||||
def configure_model(self) -> None:
|
def configure_model(self) -> None:
|
||||||
self.trainer.strategy.barrier()
|
self.trainer.strategy.barrier()
|
||||||
@@ -61,6 +63,11 @@ class LightningModel(pl.LightningModule):
|
|||||||
no_grad(self.diffusion_sampler)
|
no_grad(self.diffusion_sampler)
|
||||||
no_grad(self.ema_denoiser)
|
no_grad(self.ema_denoiser)
|
||||||
|
|
||||||
|
# add compile to speed up
|
||||||
|
if self._compile:
|
||||||
|
self.denoiser = torch.compile(self.denoiser)
|
||||||
|
self.ema_denoiser = torch.compile(self.ema_denoiser)
|
||||||
|
|
||||||
def configure_callbacks(self) -> Union[Sequence[Callback], Callback]:
|
def configure_callbacks(self) -> Union[Sequence[Callback], Callback]:
|
||||||
ema_tracker = self.ema_tracker(self.denoiser, self.ema_denoiser)
|
ema_tracker = self.ema_tracker(self.denoiser, self.ema_denoiser)
|
||||||
return [ema_tracker]
|
return [ema_tracker]
|
||||||
|
|||||||
@@ -305,4 +305,4 @@ class DDT(nn.Module):
|
|||||||
x = self.blocks[i](x, s, pos, None)
|
x = self.blocks[i](x, s, pos, None)
|
||||||
x = self.final_layer(x, s)
|
x = self.final_layer(x, s)
|
||||||
x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size)
|
x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size)
|
||||||
return x, s
|
return x, s
|
||||||
|
|||||||
@@ -251,11 +251,11 @@ class DiT(nn.Module):
|
|||||||
self.initialize_weights()
|
self.initialize_weights()
|
||||||
self.precompute_pos = dict()
|
self.precompute_pos = dict()
|
||||||
|
|
||||||
def fetch_pos(self, height, width, device, dtype):
|
def fetch_pos(self, height, width, device):
|
||||||
if (height, width) in self.precompute_pos:
|
if (height, width) in self.precompute_pos:
|
||||||
return self.precompute_pos[(height, width)].to(device, dtype)
|
return self.precompute_pos[(height, width)].to(device)
|
||||||
else:
|
else:
|
||||||
pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device, dtype)
|
pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
|
||||||
self.precompute_pos[(height, width)] = pos
|
self.precompute_pos[(height, width)] = pos
|
||||||
return pos
|
return pos
|
||||||
|
|
||||||
@@ -289,7 +289,7 @@ class DiT(nn.Module):
|
|||||||
B, _, H, W = x.shape
|
B, _, H, W = x.shape
|
||||||
x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
|
x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
|
||||||
x = self.x_embedder(x)
|
x = self.x_embedder(x)
|
||||||
pos = self.fetch_pos(H // self.patch_size, W // self.patch_size, x.device, x.dtype)
|
pos = self.fetch_pos(H // self.patch_size, W // self.patch_size, x.device)
|
||||||
B, L, C = x.shape
|
B, L, C = x.shape
|
||||||
t = self.t_embedder(t.view(-1)).view(B, -1, C)
|
t = self.t_embedder(t.view(-1)).view(B, -1, C)
|
||||||
y = self.y_embedder(y).view(B, 1, C)
|
y = self.y_embedder(y).view(B, 1, C)
|
||||||
@@ -298,4 +298,4 @@ class DiT(nn.Module):
|
|||||||
x = block(x, condition, pos, masks[i])
|
x = block(x, condition, pos, masks[i])
|
||||||
x = self.final_layer(x, condition)
|
x = self.final_layer(x, condition)
|
||||||
x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size)
|
x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size)
|
||||||
return x
|
return x
|
||||||
|
|||||||
463
src/models/denoiser/mamba2.py
Normal file
463
src/models/denoiser/mamba2.py
Normal file
@@ -0,0 +1,463 @@
|
|||||||
|
"""
|
||||||
|
mamba2-minimal
|
||||||
|
==============
|
||||||
|
|
||||||
|
A minimal, single-file implementation of the Mamba-2 model in PyTorch.
|
||||||
|
|
||||||
|
> **Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality**
|
||||||
|
> Authors: Tri Dao, Albert Gu
|
||||||
|
> Paper: https://arxiv.org/abs/2405.21060
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import math
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Iterable, NamedTuple, TypeAlias, cast
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from einops import rearrange, repeat
|
||||||
|
from torch import LongTensor, Tensor, nn
|
||||||
|
|
||||||
|
Device: TypeAlias = str | torch.device | None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Mamba2Config:
|
||||||
|
d_model: int # model dimension (D)
|
||||||
|
n_layer: int = 24 # number of Mamba-2 layers in the language model
|
||||||
|
d_state: int = 128 # state dimension (N)
|
||||||
|
d_conv: int = 4 # convolution kernel size
|
||||||
|
expand: int = 2 # expansion factor (E)
|
||||||
|
headdim: int = 64 # head dimension (P)
|
||||||
|
chunk_size: int = 64 # matrix partition size (Q)
|
||||||
|
vocab_size: int = 50277
|
||||||
|
pad_vocab_size_multiple: int = 16
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
self.d_inner = self.expand * self.d_model
|
||||||
|
assert self.d_inner % self.headdim == 0
|
||||||
|
self.nheads = self.d_inner // self.headdim
|
||||||
|
if self.vocab_size % self.pad_vocab_size_multiple != 0:
|
||||||
|
self.vocab_size += (
|
||||||
|
self.pad_vocab_size_multiple
|
||||||
|
- self.vocab_size % self.pad_vocab_size_multiple
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class InferenceCache(NamedTuple):
|
||||||
|
conv_state: Tensor # (batch, d_inner + 2 * d_state, d_conv)
|
||||||
|
ssm_state: Tensor # (batch, nheads, headdim, d_state)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def alloc(batch_size: int, args: Mamba2Config, device: Device = None):
|
||||||
|
return InferenceCache(
|
||||||
|
torch.zeros(
|
||||||
|
batch_size, args.d_inner + 2 * args.d_state, args.d_conv, device=device
|
||||||
|
),
|
||||||
|
torch.zeros(
|
||||||
|
batch_size, args.nheads, args.headdim, args.d_state, device=device
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Mamba2LMHeadModel(nn.Module):
|
||||||
|
def __init__(self, args: Mamba2Config, device: Device = None):
|
||||||
|
super().__init__()
|
||||||
|
self.args = args
|
||||||
|
self.device = device
|
||||||
|
|
||||||
|
self.backbone = nn.ModuleDict(
|
||||||
|
dict(
|
||||||
|
embedding=nn.Embedding(args.vocab_size, args.d_model, device=device),
|
||||||
|
layers=nn.ModuleList(
|
||||||
|
[
|
||||||
|
nn.ModuleDict(
|
||||||
|
dict(
|
||||||
|
mixer=Mamba2(args, device=device),
|
||||||
|
norm=RMSNorm(args.d_model, device=device),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
for _ in range(args.n_layer)
|
||||||
|
]
|
||||||
|
),
|
||||||
|
norm_f=RMSNorm(args.d_model, device=device),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.lm_head = nn.Linear(
|
||||||
|
args.d_model, args.vocab_size, bias=False, device=device
|
||||||
|
)
|
||||||
|
self.lm_head.weight = self.backbone.embedding.weight
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_pretrained(huggingface_model_id: str, device: Device = None):
|
||||||
|
from transformers.utils import CONFIG_NAME, WEIGHTS_NAME
|
||||||
|
from transformers.utils.hub import cached_file
|
||||||
|
|
||||||
|
config_path = cached_file(huggingface_model_id, CONFIG_NAME)
|
||||||
|
assert config_path, "Failed to get huggingface config file"
|
||||||
|
state_dict_path = cached_file(huggingface_model_id, WEIGHTS_NAME)
|
||||||
|
assert state_dict_path, "Failed to get huggingface state dict file"
|
||||||
|
|
||||||
|
config = json.load(open(config_path))
|
||||||
|
args = Mamba2Config(
|
||||||
|
d_model=config["d_model"],
|
||||||
|
n_layer=config["n_layer"],
|
||||||
|
vocab_size=config["vocab_size"],
|
||||||
|
pad_vocab_size_multiple=config["pad_vocab_size_multiple"],
|
||||||
|
)
|
||||||
|
|
||||||
|
map_location = "cpu" if device is None else device
|
||||||
|
state_dict = torch.load(
|
||||||
|
state_dict_path, weights_only=True, map_location=map_location, mmap=True
|
||||||
|
)
|
||||||
|
model = Mamba2LMHeadModel(args, device=device)
|
||||||
|
model.load_state_dict(state_dict)
|
||||||
|
model.eval()
|
||||||
|
return model
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self, input_ids: LongTensor, h: list[InferenceCache] | list[None] | None = None
|
||||||
|
) -> tuple[LongTensor, list[InferenceCache]]:
|
||||||
|
"""
|
||||||
|
Arguments
|
||||||
|
input_ids: (batch, seqlen) tokens from `EleutherAI/gpt-neox-20b` tokenizer
|
||||||
|
h: hidden states for inference step. If present the constant-time
|
||||||
|
(wrt sequence length) inference path will be taken, input_ids
|
||||||
|
should have shape (batch, 1) containing the next batch of prompt
|
||||||
|
token.
|
||||||
|
|
||||||
|
Return (logits, h)
|
||||||
|
logits: (batch, seqlen, vocab_size)
|
||||||
|
h: updated inference cache after processing `input_ids`
|
||||||
|
"""
|
||||||
|
seqlen = input_ids.shape[1]
|
||||||
|
|
||||||
|
if h is None:
|
||||||
|
h = [None for _ in range(self.args.n_layer)]
|
||||||
|
|
||||||
|
x = self.backbone.embedding(input_ids)
|
||||||
|
for i, layer in enumerate(self.backbone.layers):
|
||||||
|
y, h[i] = layer.mixer(layer.norm(x), h[i])
|
||||||
|
x = y + x
|
||||||
|
|
||||||
|
x = self.backbone.norm_f(x)
|
||||||
|
logits = self.lm_head(x)
|
||||||
|
return logits[:, :seqlen], cast(list[InferenceCache], h)
|
||||||
|
|
||||||
|
def generate(
|
||||||
|
self,
|
||||||
|
input_ids: LongTensor,
|
||||||
|
max_new_length: int = 20,
|
||||||
|
temperature: float = 1.0,
|
||||||
|
top_k: int = 50,
|
||||||
|
top_p: float = 1.0,
|
||||||
|
eos_token_id: int = 0,
|
||||||
|
) -> Iterable[tuple[int, list[InferenceCache]]]:
|
||||||
|
prefix, tokens = input_ids[:-1], input_ids[-1:].unsqueeze(0)
|
||||||
|
cache_device = input_ids.device
|
||||||
|
|
||||||
|
# Process prompt
|
||||||
|
# The input sequence to forward (non-inference path) must have length multiple that of chunk_size.
|
||||||
|
# We split out excess tokens so that n_chunked tokens can be processed by one forward call and
|
||||||
|
# process the rest in multiple inference steps.
|
||||||
|
n_chunked = (prefix.shape[0] // self.args.chunk_size) * self.args.chunk_size
|
||||||
|
if n_chunked > 0:
|
||||||
|
_, h = self(prefix[:n_chunked].unsqueeze(0), None)
|
||||||
|
else:
|
||||||
|
h = [
|
||||||
|
InferenceCache.alloc(1, self.args, device=cache_device)
|
||||||
|
for _ in range(self.args.n_layer)
|
||||||
|
]
|
||||||
|
for i in range(n_chunked, prefix.shape[0]):
|
||||||
|
_, h = self(prefix[i : i + 1].unsqueeze(0), h)
|
||||||
|
|
||||||
|
# Generate
|
||||||
|
for _ in range(max_new_length):
|
||||||
|
with torch.no_grad():
|
||||||
|
out, h = self(tokens, h)
|
||||||
|
logits = out[0, -1]
|
||||||
|
if temperature != 1.0:
|
||||||
|
logits = logits / temperature
|
||||||
|
if top_k > 0:
|
||||||
|
indices_to_remove = logits < torch.topk(logits, k=top_k)[0][-1]
|
||||||
|
logits[indices_to_remove] = -torch.inf
|
||||||
|
if top_p < 1.0:
|
||||||
|
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
||||||
|
cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
|
||||||
|
sorted_indices_to_remove = cum_probs > 0.5
|
||||||
|
sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()
|
||||||
|
sorted_indices_to_remove[0] = False
|
||||||
|
indices_to_remove = sorted_indices[sorted_indices_to_remove]
|
||||||
|
logits[indices_to_remove] = -torch.inf
|
||||||
|
probs = F.softmax(logits, dim=-1)
|
||||||
|
next_token = torch.multinomial(probs, num_samples=1)
|
||||||
|
if next_token.item() == eos_token_id:
|
||||||
|
return
|
||||||
|
tokens = next_token.unsqueeze(0)
|
||||||
|
yield cast(int, next_token.item()), h
|
||||||
|
|
||||||
|
|
||||||
|
class Mamba2(nn.Module):
|
||||||
|
def __init__(self, args: Mamba2Config, device: Device = None):
|
||||||
|
super().__init__()
|
||||||
|
self.args = args
|
||||||
|
self.device = device
|
||||||
|
|
||||||
|
# Order: (z, x, B, C, dt)
|
||||||
|
d_in_proj = 2 * args.d_inner + 2 * args.d_state + args.nheads
|
||||||
|
self.in_proj = nn.Linear(args.d_model, d_in_proj, bias=False, device=device)
|
||||||
|
|
||||||
|
conv_dim = args.d_inner + 2 * args.d_state
|
||||||
|
self.conv1d = nn.Conv1d(
|
||||||
|
in_channels=conv_dim,
|
||||||
|
out_channels=conv_dim,
|
||||||
|
kernel_size=args.d_conv,
|
||||||
|
groups=conv_dim,
|
||||||
|
padding=args.d_conv - 1,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.dt_bias = nn.Parameter(torch.empty(args.nheads, device=device))
|
||||||
|
self.A_log = nn.Parameter(torch.empty(args.nheads, device=device))
|
||||||
|
self.D = nn.Parameter(torch.empty(args.nheads, device=device))
|
||||||
|
self.norm = RMSNorm(args.d_inner, device=device)
|
||||||
|
self.out_proj = nn.Linear(args.d_inner, args.d_model, bias=False, device=device)
|
||||||
|
self.reset_parameters()
|
||||||
|
|
||||||
|
def reset_parameters(self):
|
||||||
|
# Mamba-style parameter init for stable SSM dynamics.
|
||||||
|
dt_min, dt_max = 1e-3, 1e-1
|
||||||
|
device = self.dt_bias.device
|
||||||
|
dtype = self.dt_bias.dtype
|
||||||
|
dt = torch.exp(
|
||||||
|
torch.rand(self.args.nheads, device=device, dtype=dtype)
|
||||||
|
* (math.log(dt_max) - math.log(dt_min))
|
||||||
|
+ math.log(dt_min)
|
||||||
|
)
|
||||||
|
with torch.no_grad():
|
||||||
|
self.dt_bias.copy_(torch.log(torch.expm1(dt)))
|
||||||
|
self.A_log.copy_(
|
||||||
|
torch.log(
|
||||||
|
torch.arange(
|
||||||
|
1, self.args.nheads + 1, device=device, dtype=self.A_log.dtype
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.D.fill_(1.0)
|
||||||
|
|
||||||
|
def forward(self, u: Tensor, h: InferenceCache | None = None):
|
||||||
|
"""
|
||||||
|
Arguments
|
||||||
|
u: (batch, seqlen, d_model) input. seqlen should be a multiple of chunk_size.
|
||||||
|
h: hidden states for inference step. Initialized to 0s if not present.
|
||||||
|
|
||||||
|
Return (y, h)
|
||||||
|
y: (batch, seqlen, d_model) output
|
||||||
|
h: updated inference cache after processing `u`
|
||||||
|
"""
|
||||||
|
if h:
|
||||||
|
return self.step(u, h)
|
||||||
|
|
||||||
|
A = -torch.exp(self.A_log) # (nheads,)
|
||||||
|
zxbcdt = self.in_proj(u) # (batch, seqlen, d_in_proj)
|
||||||
|
z, xBC, dt = torch.split(
|
||||||
|
zxbcdt,
|
||||||
|
[
|
||||||
|
self.args.d_inner,
|
||||||
|
self.args.d_inner + 2 * self.args.d_state,
|
||||||
|
self.args.nheads,
|
||||||
|
],
|
||||||
|
dim=-1,
|
||||||
|
)
|
||||||
|
dt = F.softplus(dt + self.dt_bias) # (batch, seqlen, nheads)
|
||||||
|
|
||||||
|
# Pad or truncate xBC seqlen to d_conv
|
||||||
|
conv_state = F.pad(
|
||||||
|
rearrange(xBC, "b l d -> b d l"), (self.args.d_conv - u.shape[1], 0)
|
||||||
|
)
|
||||||
|
|
||||||
|
xBC = silu(
|
||||||
|
self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)[:, : u.shape[1], :]
|
||||||
|
) # (batch, seqlen, d_inner + 2 * d_state))
|
||||||
|
x, B, C = torch.split(
|
||||||
|
xBC, [self.args.d_inner, self.args.d_state, self.args.d_state], dim=-1
|
||||||
|
)
|
||||||
|
x = rearrange(x, "b l (h p) -> b l h p", p=self.args.headdim)
|
||||||
|
y, ssm_state = ssd(
|
||||||
|
x * dt.unsqueeze(-1),
|
||||||
|
A * dt,
|
||||||
|
rearrange(B, "b l n -> b l 1 n"),
|
||||||
|
rearrange(C, "b l n -> b l 1 n"),
|
||||||
|
self.args.chunk_size,
|
||||||
|
device=x.device,
|
||||||
|
)
|
||||||
|
y = y + x * self.D.unsqueeze(-1)
|
||||||
|
y = rearrange(y, "b l h p -> b l (h p)")
|
||||||
|
y = self.norm(y, z)
|
||||||
|
y = self.out_proj(y)
|
||||||
|
|
||||||
|
h = InferenceCache(conv_state, ssm_state)
|
||||||
|
return y, h
|
||||||
|
|
||||||
|
def step(self, u: Tensor, h: InferenceCache) -> tuple[Tensor, InferenceCache]:
|
||||||
|
"""Take a single inference step for the current input and hidden state
|
||||||
|
|
||||||
|
Unlike attention-based models, RNN-based models (eg Mamba) does not need
|
||||||
|
to look back at all the past tokens to generate a new token. Instead a
|
||||||
|
hidden state (initialized to 0s initially) is updated for each input and
|
||||||
|
passed to the next inference step. This means that the total inference
|
||||||
|
time is linear with respect to the sequence length instead of quadratic
|
||||||
|
in attention's case.
|
||||||
|
|
||||||
|
Arguments
|
||||||
|
u: (batch, 1, d_model)
|
||||||
|
h: initial/running hidden state
|
||||||
|
|
||||||
|
Return (y, h)
|
||||||
|
y: (batch, 1, d_model)
|
||||||
|
h: updated hidden state
|
||||||
|
"""
|
||||||
|
assert u.shape[1] == 1, "Only one token can be decoded per inference step"
|
||||||
|
|
||||||
|
zxbcdt = self.in_proj(u.squeeze(1)) # (batch, d_in_proj)
|
||||||
|
z, xBC, dt = torch.split(
|
||||||
|
zxbcdt,
|
||||||
|
[
|
||||||
|
self.args.d_inner,
|
||||||
|
self.args.d_inner + 2 * self.args.d_state,
|
||||||
|
self.args.nheads,
|
||||||
|
],
|
||||||
|
dim=-1,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Advance convolution input
|
||||||
|
h.conv_state.copy_(torch.roll(h.conv_state, shifts=-1, dims=-1))
|
||||||
|
h.conv_state[:, :, -1] = xBC
|
||||||
|
# Convolution step
|
||||||
|
xBC = torch.sum(
|
||||||
|
h.conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1
|
||||||
|
)
|
||||||
|
xBC += self.conv1d.bias
|
||||||
|
xBC = silu(xBC)
|
||||||
|
|
||||||
|
x, B, C = torch.split(
|
||||||
|
xBC, [self.args.d_inner, self.args.d_state, self.args.d_state], dim=-1
|
||||||
|
)
|
||||||
|
A = -torch.exp(self.A_log) # (nheads,)
|
||||||
|
|
||||||
|
# SSM step
|
||||||
|
dt = F.softplus(dt + self.dt_bias) # (batch, nheads)
|
||||||
|
dA = torch.exp(dt * A) # (batch, nheads)
|
||||||
|
x = rearrange(x, "b (h p) -> b h p", p=self.args.headdim)
|
||||||
|
dBx = torch.einsum("bh, bn, bhp -> bhpn", dt, B, x)
|
||||||
|
h.ssm_state.copy_(h.ssm_state * rearrange(dA, "b h -> b h 1 1") + dBx)
|
||||||
|
y = torch.einsum("bhpn, bn -> bhp", h.ssm_state, C)
|
||||||
|
y = y + rearrange(self.D, "h -> h 1") * x
|
||||||
|
y = rearrange(y, "b h p -> b (h p)")
|
||||||
|
y = self.norm(y, z)
|
||||||
|
y = self.out_proj(y)
|
||||||
|
|
||||||
|
return y.unsqueeze(1), h
|
||||||
|
|
||||||
|
|
||||||
|
def segsum(x: Tensor, device: Device = None) -> Tensor:
|
||||||
|
"""Stable segment sum calculation.
|
||||||
|
|
||||||
|
`exp(segsum(A))` produces a 1-semiseparable matrix, which is equivalent to a scalar SSM.
|
||||||
|
|
||||||
|
Source: https://github.com/state-spaces/mamba/blob/219f03c840d5a44e7d42e4e728134834fddccf45/mamba_ssm/modules/ssd_minimal.py#L23-L32
|
||||||
|
"""
|
||||||
|
if device is None:
|
||||||
|
device = x.device
|
||||||
|
T = x.size(-1)
|
||||||
|
x = repeat(x, "... d -> ... d e", e=T)
|
||||||
|
mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=device), diagonal=-1)
|
||||||
|
x = x.masked_fill(~mask, 0)
|
||||||
|
x_segsum = torch.cumsum(x, dim=-2)
|
||||||
|
mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=device), diagonal=0)
|
||||||
|
x_segsum = x_segsum.masked_fill(~mask, -torch.inf)
|
||||||
|
return x_segsum
|
||||||
|
|
||||||
|
|
||||||
|
def ssd(x, A, B, C, chunk_size, initial_states=None, device: Device = None):
|
||||||
|
"""Structed State Space Duality (SSD) - the core of Mamba-2
|
||||||
|
|
||||||
|
This is almost the exact same minimal SSD code from the blog post.
|
||||||
|
|
||||||
|
Arguments
|
||||||
|
x: (batch, seqlen, n_heads, d_head)
|
||||||
|
A: (batch, seqlen, n_heads)
|
||||||
|
B: (batch, seqlen, n_heads, d_state)
|
||||||
|
C: (batch, seqlen, n_heads, d_state)
|
||||||
|
|
||||||
|
Return
|
||||||
|
y: (batch, seqlen, n_heads, d_head)
|
||||||
|
|
||||||
|
Source
|
||||||
|
1. https://tridao.me/blog/2024/mamba2-part3-algorithm/
|
||||||
|
2. https://github.com/state-spaces/mamba/blob/219f03c840d5a44e7d42e4e728134834fddccf45/mamba_ssm/modules/ssd_minimal.py#L34-L78
|
||||||
|
"""
|
||||||
|
assert x.shape[1] % chunk_size == 0
|
||||||
|
|
||||||
|
# Rearrange into chunks
|
||||||
|
# Step 1, 2 and 4 of SSD can be computed in parallel for each chunk across devices (sequence parallel)
|
||||||
|
# This is not implemented and left as an exercise for the reader 😜
|
||||||
|
x, A, B, C = [
|
||||||
|
rearrange(m, "b (c l) ... -> b c l ...", l=chunk_size) for m in (x, A, B, C)
|
||||||
|
]
|
||||||
|
|
||||||
|
A = rearrange(A, "b c l h -> b h c l")
|
||||||
|
A_cumsum = torch.cumsum(A, dim=-1)
|
||||||
|
|
||||||
|
# 1. Compute the output for each intra-chunk (diagonal blocks)
|
||||||
|
L = torch.exp(segsum(A, device=device))
|
||||||
|
Y_diag = torch.einsum("bclhn, bcshn, bhcls, bcshp -> bclhp", C, B, L, x)
|
||||||
|
|
||||||
|
# 2. Compute the state for each intra-chunk
|
||||||
|
# (right term of low-rank factorization of off-diagonal blocks; B terms)
|
||||||
|
decay_states = torch.exp(A_cumsum[:, :, :, -1:] - A_cumsum)
|
||||||
|
states = torch.einsum("bclhn, bhcl, bclhp -> bchpn", B, decay_states, x)
|
||||||
|
|
||||||
|
# 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
|
||||||
|
# (middle term of factorization of off-diag blocks; A terms)
|
||||||
|
if initial_states is None:
|
||||||
|
initial_states = torch.zeros_like(states[:, :1])
|
||||||
|
states = torch.cat([initial_states, states], dim=1)
|
||||||
|
decay_chunk = torch.exp(segsum(F.pad(A_cumsum[:, :, :, -1], (1, 0)), device=device))
|
||||||
|
new_states = torch.einsum("bhzc, bchpn -> bzhpn", decay_chunk, states)
|
||||||
|
states, final_state = new_states[:, :-1], new_states[:, -1]
|
||||||
|
|
||||||
|
# 4. Compute state -> output conversion per chunk
|
||||||
|
# (left term of low-rank factorization of off-diagonal blocks; C terms)
|
||||||
|
state_decay_out = torch.exp(A_cumsum)
|
||||||
|
Y_off = torch.einsum("bclhn, bchpn, bhcl -> bclhp", C, states, state_decay_out)
|
||||||
|
|
||||||
|
# Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks)
|
||||||
|
Y = rearrange(Y_diag + Y_off, "b c l h p -> b (c l) h p")
|
||||||
|
|
||||||
|
return Y, final_state
|
||||||
|
|
||||||
|
|
||||||
|
class RMSNorm(nn.Module):
|
||||||
|
def __init__(self, d: int, eps: float = 1e-5, device: Device = None):
|
||||||
|
"""Gated Root Mean Square Layer Normalization
|
||||||
|
|
||||||
|
Paper: https://arxiv.org/abs/1910.07467
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.eps = eps
|
||||||
|
self.weight = nn.Parameter(torch.ones(d, device=device))
|
||||||
|
|
||||||
|
def forward(self, x, z=None):
|
||||||
|
if z is not None:
|
||||||
|
x = x * silu(z)
|
||||||
|
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight
|
||||||
|
|
||||||
|
|
||||||
|
def silu(x):
|
||||||
|
"""Applies the Sigmoid Linear Unit (SiLU), element-wise.
|
||||||
|
|
||||||
|
Define this manually since torch's version doesn't seem to work on MPS.
|
||||||
|
"""
|
||||||
|
return x * F.sigmoid(x)
|
||||||
BIN
src/utils/__pycache__/patch_bugs.cpython-312.pyc
Normal file
BIN
src/utils/__pycache__/patch_bugs.cpython-312.pyc
Normal file
Binary file not shown.
Reference in New Issue
Block a user