# TSlib/layers/DynamicChunking.py
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import repeat, rearrange
from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined


@dataclass
class RoutingModuleOutput:
    # Forward output of the routing module:
    # - boundary_prob: per-position two-class probability (not-boundary, boundary), shape (..., 2)
    # - boundary_mask: hard selection by argmax (True = boundary), matching the leading dims of the input sequence
    # - selected_probs: probability of the selected class, shape (..., 1)
    boundary_prob: torch.Tensor
    boundary_mask: torch.Tensor
    selected_probs: torch.Tensor


@dataclass
class RoutingModuleState:
    """
    Inference state of the routing module (for incremental/streaming step):
    - has_seen_tokens: (batch_size,) whether any token has been seen (used to force a boundary at the first token)
    - last_hidden_state: (batch_size, d_model) previous hidden state, used for adjacent similarity with the current token
    """

    has_seen_tokens: torch.Tensor  # (batch_size,)
    last_hidden_state: torch.Tensor  # (batch_size, d_model)


@dataclass
class DeChunkState:
    """
    Inference state of the DeChunk layer (the EMA memory):
    - last_value: (batch_size, d_model) previous EMA value used for de-aggregation
    """

    last_value: torch.Tensor  # (batch_size, d_model)


def get_seq_idx(cu_seqlens, device=None):
    # Map every position of a packed sequence to its sequence index, shape (1, T).
    seq_idx = torch.zeros(cu_seqlens[-1], dtype=torch.long, device=device)
    seq_idx[cu_seqlens[:-1]] = 1
    seq_idx = (torch.cumsum(seq_idx, dim=0) - 1).unsqueeze(0).int()
    return seq_idx
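

# Example (a small sanity sketch, not in the original file): for
# cu_seqlens = torch.tensor([0, 3, 5]) the function returns
# tensor([[0, 0, 0, 1, 1]], dtype=torch.int32): positions 0-2 belong to
# sequence 0 and positions 3-4 to sequence 1.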


class RoutingModule(nn.Module):
    """
    Routing module:
    Turns the cosine similarity of adjacent tokens into a "boundary" probability,
        p_t = clamp((1 - cos(h_{t-1}, h_t)) / 2, 0, 1),
    and forces the first position to be a boundary (probability = 1).
    Supports:
    - regular batch mode with a mask
    - packed mode with cu_seqlens (multiple sequences concatenated into one)
    - streaming inference via step() (maintains state)
    """

    def __init__(self, d_model, device=None, dtype=None):
        super().__init__()
        self.d_model = d_model
        factory_kwargs = {"device": device, "dtype": dtype}
        # Linear projections applied before the adjacent-similarity computation (initialized to identity)
        self.q_proj_layer = nn.Linear(d_model, d_model, bias=False, **factory_kwargs)
        self.k_proj_layer = nn.Linear(d_model, d_model, bias=False, **factory_kwargs)
        with torch.no_grad():
            self.q_proj_layer.weight.copy_(torch.eye(d_model))
            self.k_proj_layer.weight.copy_(torch.eye(d_model))
        # Prevent external weight re-initialization from overwriting the identity init
        self.q_proj_layer.weight._no_reinit = True
        self.k_proj_layer.weight._no_reinit = True

    def allocate_inference_cache(self, batch_size, max_seqlen, device, dtype=None):
        # Allocate the inference cache used by step()
        return RoutingModuleState(
            has_seen_tokens=torch.zeros(batch_size, device=device, dtype=torch.bool),
            last_hidden_state=torch.zeros(
                batch_size, self.d_model, device=device, dtype=dtype
            ),
        )

    def forward(self, hidden_states, cu_seqlens=None, mask=None, inference_params=None):
        """
        hidden_states:
            - if cu_seqlens is None: (B, L, D)
            - if cu_seqlens is not None: packed mode, (T, D); temporarily expanded to (1, T, D)
        cu_seqlens: in packed mode, prefix-sum offsets of the sequences, e.g. [0, len1, len1+len2, ...]
        mask: (B, L) bool, True = valid (non-packed mode)
        inference_params: RoutingModuleState, used for validation and state updates during prefill
        """
        assert (mask is not None) or (
            cu_seqlens is not None
        ), "Either mask or cu_seqlens must be provided"
        if inference_params is not None:
            # Prefill: a mask is required, and no tokens may have been seen yet
            assert (
                mask is not None
            ), "Mask must be provided if inference_params is provided"
            assert (
                ~inference_params.has_seen_tokens
            ).all(), "Cannot have seen tokens when inference_params is provided"
        if cu_seqlens is not None:
            # Packed mode: temporarily view (T, D) as (1, T, D)
            hidden_states = hidden_states.unsqueeze(0)
        # Adjacent cosine similarity cos(h_{t-1}, h_t)
        cos_sim = torch.einsum(
            "b l d, b l d -> b l",
            F.normalize(self.q_proj_layer(hidden_states[:, :-1]), dim=-1),
            F.normalize(self.k_proj_layer(hidden_states[:, 1:]), dim=-1),
        )
        # p = (1 - cos) / 2 ∈ [0, 1]
        boundary_prob = torch.clamp(((1 - cos_sim) / 2), min=0.0, max=1.0)
        # Force the first position to be a boundary (probability = 1); after padding, the length equals the input sequence length
        PAD_PROB = 1.0
        boundary_prob = F.pad(boundary_prob, (1, 0), "constant", PAD_PROB)
        if cu_seqlens is not None:
            # In packed mode, the first position of each segment is cu_seqlens[:-1]
            boundary_prob = boundary_prob.squeeze(0)
            boundary_prob[cu_seqlens[:-1]] = PAD_PROB
        # Assemble two-class probabilities [not-boundary, boundary]
        boundary_prob = torch.stack(((1 - boundary_prob), boundary_prob), dim=-1)
        # Pick the class with the larger probability
        selected_idx = torch.argmax(boundary_prob, dim=-1)
        # Hard boundary mask
        boundary_mask = selected_idx == 1  # matches the leading dims of hidden_states
        if mask is not None:
            # Never select invalid (padded) tokens
            boundary_mask = boundary_mask & mask
        if inference_params is not None:
            # Update the routing state: whether tokens were seen, and the hidden state of the last valid token
            has_mask = mask.any(dim=-1)
            inference_params.has_seen_tokens.copy_(
                has_mask | inference_params.has_seen_tokens
            )
            last_token_idx = torch.clamp(mask.sum(dim=-1) - 1, min=0)
            inference_params.last_hidden_state.copy_(
                torch.where(
                    has_mask[:, None],  # broadcast (B,) against (B, D)
                    hidden_states[
                        torch.arange(
                            hidden_states.shape[0], device=hidden_states.device
                        ),
                        last_token_idx,
                    ],
                    inference_params.last_hidden_state,
                )
            )
        # Probability of the hard selection (useful for visualization / regularization)
        selected_probs = boundary_prob.gather(
            dim=-1, index=selected_idx.unsqueeze(-1)
        )  # (..., 1)
        return RoutingModuleOutput(
            boundary_prob=boundary_prob,  # (..., 2)
            boundary_mask=boundary_mask,  # (...)
            selected_probs=selected_probs,  # (..., 1)
        )

    def step(self, hidden_states, inference_params):
        """
        Single streaming step:
        hidden_states: (B, 1, D)
        Computes the adjacent similarity between the cached last_hidden_state and
        the current token to obtain this step's boundary probability.
        """
        # (B, D)
        hidden_states = hidden_states.squeeze(1)
        cos_sim = torch.einsum(
            "b d, b d -> b",
            F.normalize(self.q_proj_layer(inference_params.last_hidden_state), dim=-1),
            F.normalize(self.k_proj_layer(hidden_states), dim=-1),
        )
        boundary_prob = torch.clamp(((1 - cos_sim) / 2), min=0.0, max=1.0)
        # Update the cached hidden state
        inference_params.last_hidden_state.copy_(hidden_states)
        # Force a boundary before the first token has been seen
        boundary_prob = torch.where(
            inference_params.has_seen_tokens,
            boundary_prob,
            torch.ones_like(boundary_prob),
        )
        boundary_prob = torch.stack(((1 - boundary_prob), boundary_prob), dim=-1)
        # Mark tokens as seen
        inference_params.has_seen_tokens.copy_(
            torch.ones_like(inference_params.has_seen_tokens)
        )
        return RoutingModuleOutput(
            boundary_prob=boundary_prob,  # (B, 2)
            boundary_mask=boundary_prob[..., 1] > 0.5,  # (B,)
            selected_probs=boundary_prob.max(dim=-1).values.unsqueeze(-1),  # (B, 1)
        )
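

# A minimal usage sketch (not part of the original file; sizes are arbitrary):
# runs the routing module on random inputs with a right-padded mask and checks
# the shapes and invariants documented in forward().
def _demo_routing_module():
    torch.manual_seed(0)
    B, L, D = 2, 8, 16
    router = RoutingModule(D)
    hidden_states = torch.randn(B, L, D)
    mask = torch.ones(B, L, dtype=torch.bool)
    mask[1, 6:] = False  # second sample has only 6 valid tokens
    out = router(hidden_states, mask=mask)
    assert out.boundary_prob.shape == (B, L, 2)
    assert out.boundary_mask[:, 0].all()  # the first position is always a boundary
    assert not out.boundary_mask[1, 6:].any()  # padded positions are never selected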


class ChunkLayer(nn.Module):
    """
    Chunk layer: extracts the selected "boundary tokens" (per boundary_mask) to form the next-level sequence.
    Two modes:
    - packed (cu_seqlens is not None): index directly into the concatenated sequence
    - non-packed (mask is not None): a sorting trick moves True positions to the front and builds next_mask
    Returns:
    - next_hidden_states: selected tokens (packed: (#selected, D); non-packed: (B, M, D))
    - next_cu_seqlens: cu_seqlens of the new sequence in packed mode, else None
    - next_max_seqlen: max segment length in packed mode; None in non-packed mode
    - next_mask: right-padding mask in non-packed mode; None in packed mode
    """

    def forward(self, hidden_states, boundary_mask, cu_seqlens=None, mask=None):
        assert (mask is not None) or (
            cu_seqlens is not None
        ), "Either mask or cu_seqlens must be provided"
        if cu_seqlens is not None:
            # Packed: select the True rows directly, yielding the concatenated selection
            next_hidden_states = hidden_states[boundary_mask]
            # New cu_seqlens = cumulative True count at each segment's last position (= cu_seqlens[1:] - 1), with a leading 0
            next_cu_seqlens = F.pad(
                boundary_mask.cumsum(dim=0)[cu_seqlens[1:] - 1], (1, 0)
            )
            # Max segment length of the new sequence (used by kernels/optimizations only)
            next_max_seqlen = int((next_cu_seqlens[1:] - next_cu_seqlens[:-1]).max())
            next_mask = None
        else:
            # Non-packed: within each batch row, move True positions to the front and False toward the back
            next_cu_seqlens = None
            num_tokens = boundary_mask.sum(dim=-1)  # number of selected tokens per sample
            next_max_seqlen = int(num_tokens.max())
            device = hidden_states.device
            L = hidden_states.shape[1]
            # Trick: add (~boundary_mask) * L so that after argsort all True positions come first
            token_idx = (
                torch.arange(L, device=device)[None, :] + (~boundary_mask).long() * L
            )
            seq_sorted_indices = torch.argsort(token_idx, dim=1)
            # Gather the first next_max_seqlen tokens (shorter samples are right-padded)
            next_hidden_states = torch.gather(
                hidden_states,
                dim=1,
                index=seq_sorted_indices[:, :next_max_seqlen, None].expand(
                    -1, -1, hidden_states.shape[-1]
                ),
            )
            # Validity mask for downstream layers (right padding is invalid)
            next_mask = (
                torch.arange(next_max_seqlen, device=device)[None, :]
                < num_tokens[:, None]
            )
            # max_seqlen is not needed in non-packed mode; return None
            next_max_seqlen = None
        return next_hidden_states, next_cu_seqlens, next_max_seqlen, next_mask

    def step(self, hidden_states, boundary_mask):
        # Streaming step: return only the tokens selected at this step for the next level
        return hidden_states[boundary_mask]
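

# A minimal usage sketch (not part of the original file): for
# boundary_mask = [True, False, True, True] and L = 4, the sorting trick gives
# token_idx = [0, 5, 2, 3], so argsort yields [0, 2, 3, 1] and the three
# boundary tokens are compacted to the front.
def _demo_chunk_layer():
    B, L, D = 1, 4, 8
    chunk = ChunkLayer()
    hidden_states = torch.randn(B, L, D)
    boundary_mask = torch.tensor([[True, False, True, True]])
    mask = torch.ones(B, L, dtype=torch.bool)
    next_hidden, _, _, next_mask = chunk(hidden_states, boundary_mask, mask=mask)
    assert next_hidden.shape == (B, 3, D)  # three selected tokens
    assert next_mask.all()  # single sample, so no right padding needed
    assert torch.equal(next_hidden[0, 1], hidden_states[0, 2])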


class DeChunkLayer(nn.Module):
    """
    DeChunk layer: de-aggregates the selected boundary-token sequence back into a full-length sequence via an EMA.
    Implemented by reusing the Mamba2 Triton scan kernel mamba_chunk_scan_combined:
    - split d_model into nheads * headdim
    - run a first-order state-space / EMA forward scan with parameters A = -1, b = p, c = 1
    - map the scan output back to the original positions via the segment index ("plug back")
    Supports:
    - packed mode (cu_seqlens)
    - non-packed mode (batch + right padding)
    - streaming step() (EMA recurrence)
    """

    def __init__(
        self,
        d_model,
        dtype=torch.bfloat16,
        block_size=256,
        headdim=32,
    ):
        super().__init__()
        self.d_model = d_model
        # Kernel requirements only: bfloat16 compute, block size, and head-dim split
        self.dtype = dtype
        self.block_size = block_size
        self.headdim = headdim
        assert d_model % self.headdim == 0
        self.nheads = d_model // self.headdim

    def allocate_inference_cache(self, batch_size, max_seqlen, device, dtype=None):
        # Allocate the EMA last_value cache
        return DeChunkState(
            last_value=torch.zeros(
                batch_size, self.d_model, device=device, dtype=dtype
            ),
        )

    def forward(
        self,
        hidden_states,  # selected token sequence (packed: (M, D); non-packed: (B, M, D))
        boundary_mask,  # boundary mask over the original sequence ((T,) or (B, L))
        boundary_prob,  # two-class probabilities over the original sequence ((..., 2))
        cu_seqlens=None,
        inference_params=None,
        mask=None,
    ):
        """
        Core idea:
        1) From boundary_prob, take p = P(boundary), clamped to (1e-4, 1 - 1e-4)
        2) Build dt = log(1 / (1 - p)) and scale the input: x = h / dt
        3) Run mamba_chunk_scan_combined (A = -1, b = p, c = 1) as a block scan over (B, M, H, P)
        4) Plug the results back into the original positions via the cumulative boundary index
        """
        if inference_params is not None:
            # Prefill requires a mask, and the first token must be a boundary so the EMA is initialized
            assert (
                mask is not None
            ), "Mask must be provided if inference_params is provided"
            assert boundary_mask[
                :, 0
            ].all(), "First token must be a boundary if running prefill"
        # Boundary-class probability p, clamped to (1e-4, 1 - 1e-4) for numerical stability
        p = torch.clamp(boundary_prob[..., -1].float(), min=1e-4, max=1 - (1e-4))
        if cu_seqlens is not None:
            # Packed: gather p at the selected positions of the original sequence, shape (B=1, M)
            p = p[boundary_mask].unsqueeze(0)
            # Index map of the packed sequence for the Triton kernel
            seq_idx = get_seq_idx(cu_seqlens, device=hidden_states.device)
        else:
            B, L = boundary_mask.shape
            seq_idx = None
            # Same sorting trick as in ChunkLayer: selected positions come first (True in front)
            token_idx = (
                torch.arange(L, device=hidden_states.device)[None, :]
                + (~boundary_mask).long() * L
            )
            seq_sorted_indices = torch.argsort(token_idx, dim=1)
            # Gather p at the length matching hidden_states: (B, M)
            p = torch.gather(
                p, dim=1, index=seq_sorted_indices[:, : hidden_states.shape[1]]
            )  # (B, M)
        original_dtype = hidden_states.dtype
        # Variables for the EMA scan
        dt = torch.log(1 / (1 - p)).to(self.dtype)  # (B, M)
        x = (hidden_states / dt[..., None]).to(self.dtype)  # (B, M, D) / (B, M, 1)
        # A, b, c are the parameters of the first-order state-space / EMA form
        A = -torch.ones(
            (self.nheads,), device=hidden_states.device, dtype=torch.float32
        )
        b = p.to(self.dtype)
        c = torch.ones_like(b)
        # Run the Triton block-scan kernel
        out = mamba_chunk_scan_combined(
            rearrange(x, "b l (h p) -> b l h p", p=self.headdim),  # (B, M, H, P)
            repeat(dt, "b l -> b l h", h=self.nheads),  # (B, M, H)
            A,  # (H,)
            rearrange(b, "b l -> b l 1 1"),  # (B, M, 1, 1)
            rearrange(c, "b l -> b l 1 1"),  # (B, M, 1, 1)
            chunk_size=self.block_size,
            seq_idx=seq_idx,  # provided in packed mode
        )
        out = rearrange(out, "b l h p -> b l (h p)")  # (B, M, D)
        # Plug the scan results back into the original sequence positions
        if cu_seqlens is not None:
            out = out.squeeze(0)  # (M, D)
            plug_back_idx = boundary_mask.cumsum(dim=0) - 1  # (T,)
            out = torch.gather(
                out, dim=0, index=plug_back_idx.unsqueeze(-1).expand(-1, self.d_model)
            )  # (T, D)
        else:
            plug_back_idx = torch.cumsum(boundary_mask, dim=1) - 1  # (B, L)
            out = torch.gather(
                out,
                dim=1,
                index=plug_back_idx.unsqueeze(-1).expand(-1, -1, self.d_model),
            )  # (B, L, D)
        # Update the streaming cache
        if inference_params is not None:
            inference_params.last_value.copy_(out[:, -1])
        return out.to(original_dtype)

    def step(self, hidden_states, boundary_mask, boundary_prob, inference_params):
        """
        Single streaming EMA de-aggregation step:
        hidden_states: (B', 1, D), where B' = number of selected entries (boundary_mask.sum())
        boundary_mask: (B,) which batch entries are selected as boundaries at this step
        boundary_prob: (B, 2) boundary probabilities for the batch at this step
        Output: (B, 1, D), the value at every position after one EMA update
        """
        B = boundary_mask.shape[0]
        D = hidden_states.shape[-1]
        # Per-position p for this step (p = 0 where not selected)
        p = torch.zeros(B, device=hidden_states.device, dtype=hidden_states.dtype)
        p[boundary_mask] = boundary_prob[boundary_mask, -1].clamp(
            min=1e-4, max=1 - (1e-4)
        )
        # Hidden states of the selected entries (zeros where not selected)
        current_hidden_states = torch.zeros(
            B, D, device=hidden_states.device, dtype=hidden_states.dtype
        )
        current_hidden_states[boundary_mask] = hidden_states.squeeze(1)
        # EMA: result = p * x + (1 - p) * last; p is broadcast over the feature dim
        result = (
            p[:, None] * current_hidden_states
            + (1 - p[:, None]) * inference_params.last_value
        )
        inference_params.last_value.copy_(result)
        return result.unsqueeze(1)
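

# A minimal reference sketch (not part of the original file): a naive PyTorch
# EMA equivalent to the Triton scan above under A = -1, b = p_t, c = 1 and
# x_t = h_t / dt_t, since then s_t = (1 - p_t) * s_{t-1} + p_t * h_t. Handy for
# checking DeChunkLayer on CPU, as mamba_chunk_scan_combined requires CUDA.
def _ema_reference(h, p):
    # h: (B, M, D) selected tokens; p: (B, M) boundary probabilities in (0, 1)
    out = torch.zeros_like(h)
    last = torch.zeros(h.shape[0], h.shape[-1], dtype=h.dtype, device=h.device)
    for t in range(h.shape[1]):
        # EMA recurrence: new value is a p-weighted blend of input and memory
        last = p[:, t, None] * h[:, t] + (1 - p[:, t, None]) * last
        out[:, t] = last
    return out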