441 lines
18 KiB
Python
441 lines
18 KiB
Python
from dataclasses import dataclass
|
||
|
||
import torch
|
||
import torch.nn as nn
|
||
import torch.nn.functional as F
|
||
|
||
from einops import repeat, rearrange
|
||
|
||
from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined
|
||
|
||
|
||
|
||
|
||
|
||
@dataclass
|
||
class RoutingModuleOutput:
|
||
# 路由模块的前向输出:
|
||
# - boundary_prob: 每个位置为“分隔点/非分隔点”的二分类概率,形状(..., 2)
|
||
# - boundary_mask: 基于最大概率得到的硬选择(True=分隔点),形状与输入序列相同的前两维
|
||
# - selected_probs: 对应硬选择类别的概率,形状(..., 1)
|
||
boundary_prob: torch.Tensor
|
||
boundary_mask: torch.Tensor
|
||
selected_probs: torch.Tensor
|
||
|
||
|
||
@dataclass
|
||
class RoutingModuleState:
|
||
"""
|
||
路由模块的推理状态(用于增量/流式step)
|
||
|
||
包含:
|
||
- has_seen_tokens: (batch_size,) 是否已经见过任意token(用于首token强制边界)
|
||
- last_hidden_state: (batch_size, d_model) 上一次的隐藏状态(用于与当前token做相邻相似度)
|
||
"""
|
||
|
||
has_seen_tokens: torch.Tensor # (batch_size,)
|
||
last_hidden_state: torch.Tensor # (batch_size, d_model)
|
||
|
||
|
||
@dataclass
|
||
class DeChunkState:
|
||
"""
|
||
DeChunk 的推理状态(EMA的记忆值)
|
||
|
||
包含:
|
||
- last_value: (batch_size, d_model) EMA反聚合的上一时刻值
|
||
"""
|
||
|
||
last_value: torch.Tensor # (batch_size, d_model)
|
||
|
||
|
||
def get_seq_idx(cu_seqlens, device=None):
|
||
seq_idx = torch.zeros(cu_seqlens[-1], dtype=torch.long, device=device)
|
||
seq_idx[cu_seqlens[:-1]] = 1
|
||
seq_idx = (torch.cumsum(seq_idx, dim=0) - 1).unsqueeze(0).int()
|
||
|
||
return seq_idx
|
||
|
||
class RoutingModule(nn.Module):
|
||
"""
|
||
路由模块:
|
||
用相邻token的余弦相似度构造“成为分隔点”的概率:
|
||
p_t = clamp((1 - cos(h_{t-1}, h_t)) / 2, 0, 1)
|
||
并强制首位置为边界(概率=1)。
|
||
支持:
|
||
- 常规batch掩码 mask 模式
|
||
- packed序列 cu_seqlens 模式(把多序列打包成单条序列的拼接)
|
||
- 流式推理 step()(维护状态)
|
||
"""
|
||
|
||
def __init__(self, d_model, device=None, dtype=None):
|
||
self.d_model = d_model
|
||
factory_kwargs = {"device": device, "dtype": dtype}
|
||
super().__init__()
|
||
# 相邻相似度计算前的线性投影(初始化为恒等)
|
||
self.q_proj_layer = nn.Linear(d_model, d_model, bias=False, **factory_kwargs)
|
||
self.k_proj_layer = nn.Linear(d_model, d_model, bias=False, **factory_kwargs)
|
||
with torch.no_grad():
|
||
self.q_proj_layer.weight.copy_(torch.eye(d_model))
|
||
self.k_proj_layer.weight.copy_(torch.eye(d_model))
|
||
# 防止外部权重再初始化
|
||
self.q_proj_layer.weight._no_reinit = True
|
||
self.k_proj_layer.weight._no_reinit = True
|
||
|
||
def allocate_inference_cache(self, batch_size, max_seqlen, device, dtype=None):
|
||
# 分配推理cache(用于step)
|
||
return RoutingModuleState(
|
||
has_seen_tokens=torch.zeros(batch_size, device=device, dtype=torch.bool),
|
||
last_hidden_state=torch.zeros(
|
||
batch_size, self.d_model, device=device, dtype=dtype
|
||
),
|
||
)
|
||
|
||
def forward(self, hidden_states, cu_seqlens=None, mask=None, inference_params=None):
|
||
"""
|
||
hidden_states:
|
||
- 若 cu_seqlens is None: (B, L, D)
|
||
- 若 cu_seqlens 非 None: 期望 packed 模式 (T, D),这里会临时扩维成 (1, T, D)
|
||
cu_seqlens: packed模式下每条序列的前缀和下标,形如 [0, len1, len1+len2, ...]
|
||
mask: (B, L) bool,True=有效(非packed)
|
||
inference_params: RoutingModuleState,用于prefill时的校验与状态维护
|
||
"""
|
||
assert (mask is not None) or (
|
||
cu_seqlens is not None
|
||
), "Either mask or cu_seqlens must be provided"
|
||
|
||
if inference_params is not None:
|
||
# prefill阶段必须提供mask,且不允许之前已经见过token
|
||
assert (
|
||
mask is not None
|
||
), "Mask must be provided if inference_params is provided"
|
||
assert (
|
||
~inference_params.has_seen_tokens
|
||
).all(), "Cannot have seen tokens when inference_params is not provided"
|
||
|
||
if cu_seqlens is not None:
|
||
# packed 模式:把 (T, D) 临时变为 (1, T, D)
|
||
hidden_states = hidden_states.unsqueeze(0)
|
||
|
||
# 计算相邻余弦相似度 cos(h_{t-1}, h_t)
|
||
cos_sim = torch.einsum(
|
||
"b l d, b l d -> b l",
|
||
F.normalize(self.q_proj_layer(hidden_states[:, :-1]), dim=-1),
|
||
F.normalize(self.k_proj_layer(hidden_states[:, 1:]), dim=-1),
|
||
)
|
||
# p = ((1 - cos) / 2) ∈ [0,1]
|
||
boundary_prob = torch.clamp(((1 - cos_sim) / 2), min=0.0, max=1.0)
|
||
|
||
# 强制首位置为边界:首位概率=1,补充后长度和输入序列长度想等
|
||
PAD_PROB = 1.0
|
||
boundary_prob = F.pad(boundary_prob, (1, 0), "constant", PAD_PROB)
|
||
|
||
if cu_seqlens is not None:
|
||
# packed 模式下,每段序列的第一个位置是 cu_seqlens[:-1]
|
||
boundary_prob = boundary_prob.squeeze(0)
|
||
boundary_prob[cu_seqlens[:-1]] = PAD_PROB
|
||
|
||
# 组装为二分类概率 [非边界, 边界]
|
||
boundary_prob = torch.stack(((1 - boundary_prob), boundary_prob), dim=-1)
|
||
|
||
# 取最大概率类别
|
||
selected_idx = torch.argmax(boundary_prob, dim=-1)
|
||
|
||
# 硬边界掩码
|
||
boundary_mask = selected_idx == 1 # 形状与 hidden_states 的前两维一致
|
||
if mask is not None:
|
||
# 不允许选择到无效token
|
||
boundary_mask = boundary_mask & mask
|
||
|
||
if inference_params is not None:
|
||
# 维护路由状态:是否见过token、最后一个有效token的隐藏状态
|
||
has_mask = mask.any(dim=-1)
|
||
inference_params.has_seen_tokens.copy_(
|
||
has_mask | inference_params.has_seen_tokens
|
||
)
|
||
last_mask = torch.clamp(mask.sum(dim=-1) - 1, min=0)
|
||
inference_params.last_hidden_state.copy_(
|
||
torch.where(
|
||
has_mask,
|
||
hidden_states[
|
||
torch.arange(
|
||
hidden_states.shape[0], device=hidden_states.device
|
||
),
|
||
last_mask,
|
||
],
|
||
inference_params.last_hidden_state,
|
||
)
|
||
)
|
||
|
||
# 取硬选择对应的概率(便于可视化/正则)
|
||
selected_probs = boundary_prob.gather(
|
||
dim=-1, index=selected_idx.unsqueeze(-1)
|
||
) # (..., 1)
|
||
|
||
return RoutingModuleOutput(
|
||
boundary_prob=boundary_prob, # (..., 2)
|
||
boundary_mask=boundary_mask, # (...)
|
||
selected_probs=selected_probs, # (..., 1)
|
||
)
|
||
|
||
def step(self, hidden_states, inference_params):
|
||
"""
|
||
流式单步:
|
||
hidden_states: (B, 1, D)
|
||
使用上一步缓存的 last_hidden_state 与当前token计算相邻相似度,得到当前步的边界概率
|
||
"""
|
||
# (B, D)
|
||
hidden_states = hidden_states.squeeze(1)
|
||
cos_sim = torch.einsum(
|
||
"b d, b d -> b",
|
||
F.normalize(self.q_proj_layer(inference_params.last_hidden_state), dim=-1),
|
||
F.normalize(self.k_proj_layer(hidden_states), dim=-1),
|
||
)
|
||
boundary_prob = torch.clamp(((1 - cos_sim) / 2), min=0.0, max=1.0)
|
||
# 更新最后隐藏状态
|
||
inference_params.last_hidden_state.copy_(hidden_states)
|
||
# 首个token前,强制边界
|
||
boundary_prob = torch.where(
|
||
inference_params.has_seen_tokens,
|
||
boundary_prob,
|
||
torch.ones_like(boundary_prob),
|
||
)
|
||
boundary_prob = torch.stack(((1 - boundary_prob), boundary_prob), dim=-1)
|
||
|
||
# 标记为已见token
|
||
inference_params.has_seen_tokens.copy_(
|
||
torch.ones_like(inference_params.has_seen_tokens)
|
||
)
|
||
return RoutingModuleOutput(
|
||
boundary_prob=boundary_prob, # (B, 2)
|
||
boundary_mask=boundary_prob[..., 1] > 0.5, # (B,)
|
||
selected_probs=boundary_prob.max(dim=-1).values.unsqueeze(-1), # (B, 1)
|
||
)
|
||
|
||
|
||
class ChunkLayer(nn.Module):
|
||
"""
|
||
Chunk层:根据 boundary_mask 将被选中的“边界token”抽取出来,形成下一层序列。
|
||
支持两种模式:
|
||
- packed(cu_seqlens 非 None):直接在拼接后的序列上索引
|
||
- 非packed(mask 非 None):通过排序 trick 把True位置排到前面,并生成 next_mask
|
||
返回:
|
||
- next_hidden_states: 选中的token序列(packed: shape=(#selected, D);非packed: (B, M, D))
|
||
- next_cu_seqlens: packed模式下新序列的cu_seqlens;否则None
|
||
- next_max_seqlen: packed模式下选中的最大长度;非packed模式返回None
|
||
- next_mask: 非packed模式下的右侧pad掩码;packed模式下None
|
||
"""
|
||
|
||
def forward(self, hidden_states, boundary_mask, cu_seqlens=None, mask=None):
|
||
assert (mask is not None) or (
|
||
cu_seqlens is not None
|
||
), "Either mask or cu_seqlens must be provided"
|
||
|
||
if cu_seqlens is not None:
|
||
# packed:直接选择True的行,得到拼接后的 selected
|
||
next_hidden_states = hidden_states[boundary_mask]
|
||
# 新的cu_seqlens = 对每段最后一个位置(=cu_seqlens[1:]-1)累计True的计数,再前置0
|
||
next_cu_seqlens = F.pad(
|
||
boundary_mask.cumsum(dim=0)[cu_seqlens[1:] - 1], (1, 0)
|
||
)
|
||
# 新序列的最大段长(仅用于内核/优化)
|
||
next_max_seqlen = int((next_cu_seqlens[1:] - next_cu_seqlens[:-1]).max())
|
||
next_mask = None
|
||
else:
|
||
# 非packed:对每个batch内,把True位置排到前面(False放到靠后)
|
||
next_cu_seqlens = None
|
||
num_tokens = boundary_mask.sum(dim=-1) # 每个样本被选中的数量
|
||
next_max_seqlen = int(num_tokens.max())
|
||
|
||
device = hidden_states.device
|
||
L = hidden_states.shape[1]
|
||
# trick:用 (~boundary_mask)*L 把False加大,从而 argsort 后 True 的下标排在前面
|
||
token_idx = (
|
||
torch.arange(L, device=device)[None, :] + (~boundary_mask).long() * L
|
||
)
|
||
seq_sorted_indices = torch.argsort(token_idx, dim=1)
|
||
|
||
# 收集前 next_max_seqlen 个(不足的样本右侧pad)
|
||
next_hidden_states = torch.gather(
|
||
hidden_states,
|
||
dim=1,
|
||
index=seq_sorted_indices[:, :next_max_seqlen, None].expand(
|
||
-1, -1, hidden_states.shape[-1]
|
||
),
|
||
)
|
||
|
||
# 下游的有效mask(右侧pad无效)
|
||
next_mask = (
|
||
torch.arange(next_max_seqlen, device=device)[None, :]
|
||
< num_tokens[:, None]
|
||
)
|
||
# 非packed模式下,不再需要 max_seqlen(返回None)
|
||
next_max_seqlen = None
|
||
|
||
return next_hidden_states, next_cu_seqlens, next_max_seqlen, next_mask
|
||
|
||
def step(self, hidden_states, boundary_mask):
|
||
# 流式step:仅返回当前步被选中的token(用于下一层)
|
||
return hidden_states[boundary_mask]
|
||
|
||
|
||
class DeChunkLayer(nn.Module):
|
||
"""
|
||
DeChunk层:把“被选中的边界token序列”反聚合(EMA)回原始等长序列。
|
||
实现上复用 Mamba2 的 Triton 扫描核 mamba_chunk_scan_combined:
|
||
- 将 d_model 切分为 nheads * headdim
|
||
- 使用参数 A=-1, b=p, c=1 的一阶状态空间/EMA形式进行前向扫描
|
||
- 最终把扫描输出根据分段索引映射回原位置(plug back)
|
||
支持:
|
||
- packed 模式(cu_seqlens)
|
||
- 非packed(batch+右侧pad)
|
||
- 流式 step(EMA递推)
|
||
"""
|
||
|
||
def __init__(
|
||
self,
|
||
d_model,
|
||
dtype=torch.bfloat16,
|
||
block_size=256,
|
||
headdim=32,
|
||
):
|
||
super().__init__()
|
||
self.d_model = d_model
|
||
|
||
# 仅为内核要求:使用 bfloat16,块大小与头维拆分
|
||
self.dtype = dtype
|
||
self.block_size = block_size
|
||
self.headdim = headdim
|
||
assert d_model % self.headdim == 0
|
||
self.nheads = d_model // self.headdim
|
||
|
||
def allocate_inference_cache(self, batch_size, max_seqlen, device, dtype=None):
|
||
# 分配EMA的last_value缓存
|
||
return DeChunkState(
|
||
last_value=torch.zeros(
|
||
batch_size, self.d_model, device=device, dtype=dtype
|
||
),
|
||
)
|
||
|
||
def forward(
|
||
self,
|
||
hidden_states, # 被选中的token序列(packed: (M, D);非packed: (B, M, D))
|
||
boundary_mask, # 原序列上的边界掩码((T,) 或 (B, L))
|
||
boundary_prob, # 原序列上的二分类概率((..., 2))
|
||
cu_seqlens=None,
|
||
inference_params=None,
|
||
mask=None,
|
||
):
|
||
"""
|
||
核心思路:
|
||
1) 从 boundary_prob 得到 p = P(boundary) ∈ (1e-4, 1-1e-4)
|
||
2) 构造 dt = log(1 / (1-p)),并对输入做缩放 x = h / dt
|
||
3) 用 mamba_chunk_scan_combined 扫描:A=-1, b=p, c=1,对 (B, M, H, P) 进行块扫描
|
||
4) 将结果根据 cumulative boundary index 回填到原序列位置
|
||
"""
|
||
if inference_params is not None:
|
||
# prefill时必须有mask,且首token必须是边界(保证EMA初始化)
|
||
assert (
|
||
mask is not None
|
||
), "Mask must be provided if inference_params is provided"
|
||
assert boundary_mask[
|
||
:, 0
|
||
].all(), "First token must be a boundary if running prefill"
|
||
|
||
# 取边界概率的“边界类”概率 p,并限制在(1e-4, 1-1e-4)内,避免数值不稳
|
||
p = torch.clamp(boundary_prob[..., -1].float(), min=1e-4, max=1 - (1e-4))
|
||
|
||
if cu_seqlens is not None:
|
||
# packed:从原序列p中取出被选中的位置对应的概率,形状(B=1, M)
|
||
p = p[boundary_mask].unsqueeze(0)
|
||
# 为triton核准备packed序列的索引映射
|
||
seq_idx = get_seq_idx(cu_seqlens, device=hidden_states.device)
|
||
else:
|
||
B, L = boundary_mask.shape
|
||
seq_idx = None
|
||
# 与ChunkLayer一致的排序 trick,得到选中的顺序(True在前)
|
||
token_idx = (
|
||
torch.arange(L, device=hidden_states.device)[None, :]
|
||
+ (~boundary_mask).long() * L
|
||
)
|
||
seq_sorted_indices = torch.argsort(token_idx, dim=1)
|
||
|
||
# 取出与 hidden_states 对应长度的 p((B, M))
|
||
p = torch.gather(
|
||
p, dim=1, index=seq_sorted_indices[:, : hidden_states.shape[1]]
|
||
) # (B, M)
|
||
|
||
original_dtype = hidden_states.dtype
|
||
# 构造 EMA 扫描所需变量
|
||
dt = torch.log(1 / (1 - p)).to(self.dtype) # (B, M)
|
||
x = (hidden_states / dt[..., None]).to(self.dtype) # (B, M, D) / (B, M, 1)
|
||
|
||
# A, b, c 分别对应一阶状态空间/EMA的参数
|
||
A = -torch.ones(
|
||
(self.nheads,), device=hidden_states.device, dtype=torch.float32
|
||
)
|
||
b = p.to(self.dtype)
|
||
c = torch.ones_like(b)
|
||
|
||
# 调用triton核进行块扫描
|
||
out = mamba_chunk_scan_combined(
|
||
rearrange(x, "b l (h p) -> b l h p", p=self.headdim), # (B, M, H, P)
|
||
repeat(dt, "b l -> b l h", h=self.nheads), # (B, M, H)
|
||
A, # (H,)
|
||
rearrange(b, "b l -> b l 1 1"), # (B, M, 1, 1)
|
||
rearrange(c, "b l -> b l 1 1"), # (B, M, 1, 1)
|
||
chunk_size=self.block_size,
|
||
seq_idx=seq_idx, # packed时提供
|
||
)
|
||
out = rearrange(out, "b l h p -> b l (h p)") # (B, M, D)
|
||
|
||
# 将扫描结果回填(plug back)到原序列位置
|
||
if cu_seqlens is not None:
|
||
out = out.squeeze(0) # (M, D)
|
||
plug_back_idx = boundary_mask.cumsum(dim=0) - 1 # (T,)
|
||
out = torch.gather(
|
||
out, dim=0, index=plug_back_idx.unsqueeze(-1).expand(-1, self.d_model)
|
||
) # (T, D)
|
||
else:
|
||
plug_back_idx = torch.cumsum(boundary_mask, dim=1) - 1 # (B, L)
|
||
out = torch.gather(
|
||
out,
|
||
dim=1,
|
||
index=plug_back_idx.unsqueeze(-1).expand(-1, -1, self.d_model),
|
||
) # (B, L, D)
|
||
|
||
# 更新流式缓存
|
||
if inference_params is not None:
|
||
inference_params.last_value.copy_(out[:, -1])
|
||
|
||
return out.to(original_dtype)
|
||
|
||
def step(self, hidden_states, boundary_mask, boundary_prob, inference_params):
|
||
"""
|
||
流式单步 EMA 反聚合:
|
||
hidden_states: (B', 1, D),其中 B' = 当前步被选中的数量(boundary_mask.sum())
|
||
boundary_mask: (B,) 当前batch哪些位置被选中为边界
|
||
boundary_prob: (B, 2) 当前batch各位置的边界概率
|
||
输出:(B, 1, D),对应对所有位置做了一步 EMA 更新后的值
|
||
"""
|
||
B = boundary_mask.shape[0]
|
||
D = hidden_states.shape[-1]
|
||
|
||
# 构造当前步每个位置的 p(未被选中的位置 p=0)
|
||
p = torch.zeros(B, device=hidden_states.device, dtype=hidden_states.dtype)
|
||
p[boundary_mask] = boundary_prob[boundary_mask, -1].clamp(
|
||
min=1e-4, max=1 - (1e-4)
|
||
)
|
||
|
||
# 构造当前被选中的隐藏状态(未选中为0)
|
||
current_hidden_states = torch.zeros(
|
||
B, D, device=hidden_states.device, dtype=hidden_states.dtype
|
||
)
|
||
current_hidden_states[boundary_mask] = hidden_states.squeeze(1)
|
||
|
||
# EMA:result = p * x + (1 - p) * last
|
||
result = p * current_hidden_states + (1 - p) * inference_params.last_value
|
||
inference_params.last_value.copy_(result)
|
||
|
||
return result.unsqueeze(1)
|