feat: add mamba and dynamic chunking related code and test code

This commit is contained in:
gameloader
2025-09-04 01:32:13 +00:00
parent 12cb7652cf
commit ef307a57e9
21 changed files with 4550 additions and 86 deletions

440
layers/DynamicChunking.py Normal file
View File

@ -0,0 +1,440 @@
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import repeat, rearrange
from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined
@dataclass
class RoutingModuleOutput:
# 路由模块的前向输出:
# - boundary_prob: 每个位置为“分隔点/非分隔点”的二分类概率,形状(..., 2)
# - boundary_mask: 基于最大概率得到的硬选择True=分隔点),形状与输入序列相同的前两维
# - selected_probs: 对应硬选择类别的概率,形状(..., 1)
boundary_prob: torch.Tensor
boundary_mask: torch.Tensor
selected_probs: torch.Tensor
@dataclass
class RoutingModuleState:
"""
路由模块的推理状态(用于增量/流式step
包含:
- has_seen_tokens: (batch_size,) 是否已经见过任意token用于首token强制边界
- last_hidden_state: (batch_size, d_model) 上一次的隐藏状态用于与当前token做相邻相似度
"""
has_seen_tokens: torch.Tensor # (batch_size,)
last_hidden_state: torch.Tensor # (batch_size, d_model)
@dataclass
class DeChunkState:
"""
DeChunk 的推理状态EMA的记忆值
包含:
- last_value: (batch_size, d_model) EMA反聚合的上一时刻值
"""
last_value: torch.Tensor # (batch_size, d_model)
def get_seq_idx(cu_seqlens, device=None):
seq_idx = torch.zeros(cu_seqlens[-1], dtype=torch.long, device=device)
seq_idx[cu_seqlens[:-1]] = 1
seq_idx = (torch.cumsum(seq_idx, dim=0) - 1).unsqueeze(0).int()
return seq_idx
class RoutingModule(nn.Module):
"""
路由模块:
用相邻token的余弦相似度构造“成为分隔点”的概率
p_t = clamp((1 - cos(h_{t-1}, h_t)) / 2, 0, 1)
并强制首位置为边界(概率=1
支持:
- 常规batch掩码 mask 模式
- packed序列 cu_seqlens 模式(把多序列打包成单条序列的拼接)
- 流式推理 step()(维护状态)
"""
def __init__(self, d_model, device=None, dtype=None):
self.d_model = d_model
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
# 相邻相似度计算前的线性投影(初始化为恒等)
self.q_proj_layer = nn.Linear(d_model, d_model, bias=False, **factory_kwargs)
self.k_proj_layer = nn.Linear(d_model, d_model, bias=False, **factory_kwargs)
with torch.no_grad():
self.q_proj_layer.weight.copy_(torch.eye(d_model))
self.k_proj_layer.weight.copy_(torch.eye(d_model))
# 防止外部权重再初始化
self.q_proj_layer.weight._no_reinit = True
self.k_proj_layer.weight._no_reinit = True
def allocate_inference_cache(self, batch_size, max_seqlen, device, dtype=None):
# 分配推理cache用于step
return RoutingModuleState(
has_seen_tokens=torch.zeros(batch_size, device=device, dtype=torch.bool),
last_hidden_state=torch.zeros(
batch_size, self.d_model, device=device, dtype=dtype
),
)
def forward(self, hidden_states, cu_seqlens=None, mask=None, inference_params=None):
"""
hidden_states:
- 若 cu_seqlens is None: (B, L, D)
- 若 cu_seqlens 非 None: 期望 packed 模式 (T, D),这里会临时扩维成 (1, T, D)
cu_seqlens: packed模式下每条序列的前缀和下标形如 [0, len1, len1+len2, ...]
mask: (B, L) boolTrue=有效非packed
inference_params: RoutingModuleState用于prefill时的校验与状态维护
"""
assert (mask is not None) or (
cu_seqlens is not None
), "Either mask or cu_seqlens must be provided"
if inference_params is not None:
# prefill阶段必须提供mask且不允许之前已经见过token
assert (
mask is not None
), "Mask must be provided if inference_params is provided"
assert (
~inference_params.has_seen_tokens
).all(), "Cannot have seen tokens when inference_params is not provided"
if cu_seqlens is not None:
# packed 模式:把 (T, D) 临时变为 (1, T, D)
hidden_states = hidden_states.unsqueeze(0)
# 计算相邻余弦相似度 cos(h_{t-1}, h_t)
cos_sim = torch.einsum(
"b l d, b l d -> b l",
F.normalize(self.q_proj_layer(hidden_states[:, :-1]), dim=-1),
F.normalize(self.k_proj_layer(hidden_states[:, 1:]), dim=-1),
)
# p = ((1 - cos) / 2) ∈ [0,1]
boundary_prob = torch.clamp(((1 - cos_sim) / 2), min=0.0, max=1.0)
# 强制首位置为边界:首位概率=1补充后长度和输入序列长度想等
PAD_PROB = 1.0
boundary_prob = F.pad(boundary_prob, (1, 0), "constant", PAD_PROB)
if cu_seqlens is not None:
# packed 模式下,每段序列的第一个位置是 cu_seqlens[:-1]
boundary_prob = boundary_prob.squeeze(0)
boundary_prob[cu_seqlens[:-1]] = PAD_PROB
# 组装为二分类概率 [非边界, 边界]
boundary_prob = torch.stack(((1 - boundary_prob), boundary_prob), dim=-1)
# 取最大概率类别
selected_idx = torch.argmax(boundary_prob, dim=-1)
# 硬边界掩码
boundary_mask = selected_idx == 1 # 形状与 hidden_states 的前两维一致
if mask is not None:
# 不允许选择到无效token
boundary_mask = boundary_mask & mask
if inference_params is not None:
# 维护路由状态是否见过token、最后一个有效token的隐藏状态
has_mask = mask.any(dim=-1)
inference_params.has_seen_tokens.copy_(
has_mask | inference_params.has_seen_tokens
)
last_mask = torch.clamp(mask.sum(dim=-1) - 1, min=0)
inference_params.last_hidden_state.copy_(
torch.where(
has_mask,
hidden_states[
torch.arange(
hidden_states.shape[0], device=hidden_states.device
),
last_mask,
],
inference_params.last_hidden_state,
)
)
# 取硬选择对应的概率(便于可视化/正则)
selected_probs = boundary_prob.gather(
dim=-1, index=selected_idx.unsqueeze(-1)
) # (..., 1)
return RoutingModuleOutput(
boundary_prob=boundary_prob, # (..., 2)
boundary_mask=boundary_mask, # (...)
selected_probs=selected_probs, # (..., 1)
)
def step(self, hidden_states, inference_params):
"""
流式单步:
hidden_states: (B, 1, D)
使用上一步缓存的 last_hidden_state 与当前token计算相邻相似度得到当前步的边界概率
"""
# (B, D)
hidden_states = hidden_states.squeeze(1)
cos_sim = torch.einsum(
"b d, b d -> b",
F.normalize(self.q_proj_layer(inference_params.last_hidden_state), dim=-1),
F.normalize(self.k_proj_layer(hidden_states), dim=-1),
)
boundary_prob = torch.clamp(((1 - cos_sim) / 2), min=0.0, max=1.0)
# 更新最后隐藏状态
inference_params.last_hidden_state.copy_(hidden_states)
# 首个token前强制边界
boundary_prob = torch.where(
inference_params.has_seen_tokens,
boundary_prob,
torch.ones_like(boundary_prob),
)
boundary_prob = torch.stack(((1 - boundary_prob), boundary_prob), dim=-1)
# 标记为已见token
inference_params.has_seen_tokens.copy_(
torch.ones_like(inference_params.has_seen_tokens)
)
return RoutingModuleOutput(
boundary_prob=boundary_prob, # (B, 2)
boundary_mask=boundary_prob[..., 1] > 0.5, # (B,)
selected_probs=boundary_prob.max(dim=-1).values.unsqueeze(-1), # (B, 1)
)
class ChunkLayer(nn.Module):
"""
Chunk层根据 boundary_mask 将被选中的“边界token”抽取出来形成下一层序列。
支持两种模式:
- packedcu_seqlens 非 None直接在拼接后的序列上索引
- 非packedmask 非 None通过排序 trick 把True位置排到前面并生成 next_mask
返回:
- next_hidden_states: 选中的token序列packed: shape=(#selected, D)非packed: (B, M, D)
- next_cu_seqlens: packed模式下新序列的cu_seqlens否则None
- next_max_seqlen: packed模式下选中的最大长度非packed模式返回None
- next_mask: 非packed模式下的右侧pad掩码packed模式下None
"""
def forward(self, hidden_states, boundary_mask, cu_seqlens=None, mask=None):
assert (mask is not None) or (
cu_seqlens is not None
), "Either mask or cu_seqlens must be provided"
if cu_seqlens is not None:
# packed直接选择True的行得到拼接后的 selected
next_hidden_states = hidden_states[boundary_mask]
# 新的cu_seqlens = 对每段最后一个位置(=cu_seqlens[1:]-1)累计True的计数再前置0
next_cu_seqlens = F.pad(
boundary_mask.cumsum(dim=0)[cu_seqlens[1:] - 1], (1, 0)
)
# 新序列的最大段长(仅用于内核/优化)
next_max_seqlen = int((next_cu_seqlens[1:] - next_cu_seqlens[:-1]).max())
next_mask = None
else:
# 非packed对每个batch内把True位置排到前面False放到靠后
next_cu_seqlens = None
num_tokens = boundary_mask.sum(dim=-1) # 每个样本被选中的数量
next_max_seqlen = int(num_tokens.max())
device = hidden_states.device
L = hidden_states.shape[1]
# trick用 (~boundary_mask)*L 把False加大从而 argsort 后 True 的下标排在前面
token_idx = (
torch.arange(L, device=device)[None, :] + (~boundary_mask).long() * L
)
seq_sorted_indices = torch.argsort(token_idx, dim=1)
# 收集前 next_max_seqlen 个不足的样本右侧pad
next_hidden_states = torch.gather(
hidden_states,
dim=1,
index=seq_sorted_indices[:, :next_max_seqlen, None].expand(
-1, -1, hidden_states.shape[-1]
),
)
# 下游的有效mask右侧pad无效
next_mask = (
torch.arange(next_max_seqlen, device=device)[None, :]
< num_tokens[:, None]
)
# 非packed模式下不再需要 max_seqlen返回None
next_max_seqlen = None
return next_hidden_states, next_cu_seqlens, next_max_seqlen, next_mask
def step(self, hidden_states, boundary_mask):
# 流式step仅返回当前步被选中的token用于下一层
return hidden_states[boundary_mask]
class DeChunkLayer(nn.Module):
"""
DeChunk层把“被选中的边界token序列”反聚合EMA回原始等长序列。
实现上复用 Mamba2 的 Triton 扫描核 mamba_chunk_scan_combined
- 将 d_model 切分为 nheads * headdim
- 使用参数 A=-1, b=p, c=1 的一阶状态空间/EMA形式进行前向扫描
- 最终把扫描输出根据分段索引映射回原位置plug back
支持:
- packed 模式cu_seqlens
- 非packedbatch+右侧pad
- 流式 stepEMA递推
"""
def __init__(
self,
d_model,
dtype=torch.bfloat16,
block_size=256,
headdim=32,
):
super().__init__()
self.d_model = d_model
# 仅为内核要求:使用 bfloat16块大小与头维拆分
self.dtype = dtype
self.block_size = block_size
self.headdim = headdim
assert d_model % self.headdim == 0
self.nheads = d_model // self.headdim
def allocate_inference_cache(self, batch_size, max_seqlen, device, dtype=None):
# 分配EMA的last_value缓存
return DeChunkState(
last_value=torch.zeros(
batch_size, self.d_model, device=device, dtype=dtype
),
)
def forward(
self,
hidden_states, # 被选中的token序列packed: (M, D)非packed: (B, M, D)
boundary_mask, # 原序列上的边界掩码((T,) 或 (B, L)
boundary_prob, # 原序列上的二分类概率((..., 2)
cu_seqlens=None,
inference_params=None,
mask=None,
):
"""
核心思路:
1) 从 boundary_prob 得到 p = P(boundary) ∈ (1e-4, 1-1e-4)
2) 构造 dt = log(1 / (1-p)),并对输入做缩放 x = h / dt
3) 用 mamba_chunk_scan_combined 扫描A=-1, b=p, c=1对 (B, M, H, P) 进行块扫描
4) 将结果根据 cumulative boundary index 回填到原序列位置
"""
if inference_params is not None:
# prefill时必须有mask且首token必须是边界保证EMA初始化
assert (
mask is not None
), "Mask must be provided if inference_params is provided"
assert boundary_mask[
:, 0
].all(), "First token must be a boundary if running prefill"
# 取边界概率的“边界类”概率 p并限制在(1e-4, 1-1e-4)内,避免数值不稳
p = torch.clamp(boundary_prob[..., -1].float(), min=1e-4, max=1 - (1e-4))
if cu_seqlens is not None:
# packed从原序列p中取出被选中的位置对应的概率形状(B=1, M)
p = p[boundary_mask].unsqueeze(0)
# 为triton核准备packed序列的索引映射
seq_idx = get_seq_idx(cu_seqlens, device=hidden_states.device)
else:
B, L = boundary_mask.shape
seq_idx = None
# 与ChunkLayer一致的排序 trick得到选中的顺序True在前
token_idx = (
torch.arange(L, device=hidden_states.device)[None, :]
+ (~boundary_mask).long() * L
)
seq_sorted_indices = torch.argsort(token_idx, dim=1)
# 取出与 hidden_states 对应长度的 p(B, M)
p = torch.gather(
p, dim=1, index=seq_sorted_indices[:, : hidden_states.shape[1]]
) # (B, M)
original_dtype = hidden_states.dtype
# 构造 EMA 扫描所需变量
dt = torch.log(1 / (1 - p)).to(self.dtype) # (B, M)
x = (hidden_states / dt[..., None]).to(self.dtype) # (B, M, D) / (B, M, 1)
# A, b, c 分别对应一阶状态空间/EMA的参数
A = -torch.ones(
(self.nheads,), device=hidden_states.device, dtype=torch.float32
)
b = p.to(self.dtype)
c = torch.ones_like(b)
# 调用triton核进行块扫描
out = mamba_chunk_scan_combined(
rearrange(x, "b l (h p) -> b l h p", p=self.headdim), # (B, M, H, P)
repeat(dt, "b l -> b l h", h=self.nheads), # (B, M, H)
A, # (H,)
rearrange(b, "b l -> b l 1 1"), # (B, M, 1, 1)
rearrange(c, "b l -> b l 1 1"), # (B, M, 1, 1)
chunk_size=self.block_size,
seq_idx=seq_idx, # packed时提供
)
out = rearrange(out, "b l h p -> b l (h p)") # (B, M, D)
# 将扫描结果回填plug back到原序列位置
if cu_seqlens is not None:
out = out.squeeze(0) # (M, D)
plug_back_idx = boundary_mask.cumsum(dim=0) - 1 # (T,)
out = torch.gather(
out, dim=0, index=plug_back_idx.unsqueeze(-1).expand(-1, self.d_model)
) # (T, D)
else:
plug_back_idx = torch.cumsum(boundary_mask, dim=1) - 1 # (B, L)
out = torch.gather(
out,
dim=1,
index=plug_back_idx.unsqueeze(-1).expand(-1, -1, self.d_model),
) # (B, L, D)
# 更新流式缓存
if inference_params is not None:
inference_params.last_value.copy_(out[:, -1])
return out.to(original_dtype)
def step(self, hidden_states, boundary_mask, boundary_prob, inference_params):
"""
流式单步 EMA 反聚合:
hidden_states: (B', 1, D),其中 B' = 当前步被选中的数量boundary_mask.sum()
boundary_mask: (B,) 当前batch哪些位置被选中为边界
boundary_prob: (B, 2) 当前batch各位置的边界概率
输出:(B, 1, D),对应对所有位置做了一步 EMA 更新后的值
"""
B = boundary_mask.shape[0]
D = hidden_states.shape[-1]
# 构造当前步每个位置的 p未被选中的位置 p=0
p = torch.zeros(B, device=hidden_states.device, dtype=hidden_states.dtype)
p[boundary_mask] = boundary_prob[boundary_mask, -1].clamp(
min=1e-4, max=1 - (1e-4)
)
# 构造当前被选中的隐藏状态未选中为0
current_hidden_states = torch.zeros(
B, D, device=hidden_states.device, dtype=hidden_states.dtype
)
current_hidden_states[boundary_mask] = hidden_states.squeeze(1)
# EMAresult = p * x + (1 - p) * last
result = p * current_hidden_states + (1 - p) * inference_params.last_value
inference_params.last_value.copy_(result)
return result.unsqueeze(1)