from dataclasses import dataclass import torch import torch.nn as nn import torch.nn.functional as F from einops import repeat, rearrange from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined @dataclass class RoutingModuleOutput: # 路由模块的前向输出: # - boundary_prob: 每个位置为“分隔点/非分隔点”的二分类概率,形状(..., 2) # - boundary_mask: 基于最大概率得到的硬选择(True=分隔点),形状与输入序列相同的前两维 # - selected_probs: 对应硬选择类别的概率,形状(..., 1) boundary_prob: torch.Tensor boundary_mask: torch.Tensor selected_probs: torch.Tensor @dataclass class RoutingModuleState: """ 路由模块的推理状态(用于增量/流式step) 包含: - has_seen_tokens: (batch_size,) 是否已经见过任意token(用于首token强制边界) - last_hidden_state: (batch_size, d_model) 上一次的隐藏状态(用于与当前token做相邻相似度) """ has_seen_tokens: torch.Tensor # (batch_size,) last_hidden_state: torch.Tensor # (batch_size, d_model) @dataclass class DeChunkState: """ DeChunk 的推理状态(EMA的记忆值) 包含: - last_value: (batch_size, d_model) EMA反聚合的上一时刻值 """ last_value: torch.Tensor # (batch_size, d_model) def get_seq_idx(cu_seqlens, device=None): seq_idx = torch.zeros(cu_seqlens[-1], dtype=torch.long, device=device) seq_idx[cu_seqlens[:-1]] = 1 seq_idx = (torch.cumsum(seq_idx, dim=0) - 1).unsqueeze(0).int() return seq_idx class RoutingModule(nn.Module): """ 路由模块: 用相邻token的余弦相似度构造“成为分隔点”的概率: p_t = clamp((1 - cos(h_{t-1}, h_t)) / 2, 0, 1) 并强制首位置为边界(概率=1)。 支持: - 常规batch掩码 mask 模式 - packed序列 cu_seqlens 模式(把多序列打包成单条序列的拼接) - 流式推理 step()(维护状态) """ def __init__(self, d_model, device=None, dtype=None): self.d_model = d_model factory_kwargs = {"device": device, "dtype": dtype} super().__init__() # 相邻相似度计算前的线性投影(初始化为恒等) self.q_proj_layer = nn.Linear(d_model, d_model, bias=False, **factory_kwargs) self.k_proj_layer = nn.Linear(d_model, d_model, bias=False, **factory_kwargs) with torch.no_grad(): self.q_proj_layer.weight.copy_(torch.eye(d_model)) self.k_proj_layer.weight.copy_(torch.eye(d_model)) # 防止外部权重再初始化 self.q_proj_layer.weight._no_reinit = True self.k_proj_layer.weight._no_reinit = True def allocate_inference_cache(self, batch_size, max_seqlen, device, dtype=None): # 分配推理cache(用于step) return RoutingModuleState( has_seen_tokens=torch.zeros(batch_size, device=device, dtype=torch.bool), last_hidden_state=torch.zeros( batch_size, self.d_model, device=device, dtype=dtype ), ) def forward(self, hidden_states, cu_seqlens=None, mask=None, inference_params=None): """ hidden_states: - 若 cu_seqlens is None: (B, L, D) - 若 cu_seqlens 非 None: 期望 packed 模式 (T, D),这里会临时扩维成 (1, T, D) cu_seqlens: packed模式下每条序列的前缀和下标,形如 [0, len1, len1+len2, ...] mask: (B, L) bool,True=有效(非packed) inference_params: RoutingModuleState,用于prefill时的校验与状态维护 """ assert (mask is not None) or ( cu_seqlens is not None ), "Either mask or cu_seqlens must be provided" if inference_params is not None: # prefill阶段必须提供mask,且不允许之前已经见过token assert ( mask is not None ), "Mask must be provided if inference_params is provided" assert ( ~inference_params.has_seen_tokens ).all(), "Cannot have seen tokens when inference_params is not provided" if cu_seqlens is not None: # packed 模式:把 (T, D) 临时变为 (1, T, D) hidden_states = hidden_states.unsqueeze(0) # 计算相邻余弦相似度 cos(h_{t-1}, h_t) cos_sim = torch.einsum( "b l d, b l d -> b l", F.normalize(self.q_proj_layer(hidden_states[:, :-1]), dim=-1), F.normalize(self.k_proj_layer(hidden_states[:, 1:]), dim=-1), ) # p = ((1 - cos) / 2) ∈ [0,1] boundary_prob = torch.clamp(((1 - cos_sim) / 2), min=0.0, max=1.0) # 强制首位置为边界:首位概率=1,补充后长度和输入序列长度想等 PAD_PROB = 1.0 boundary_prob = F.pad(boundary_prob, (1, 0), "constant", PAD_PROB) if cu_seqlens is not None: # packed 模式下,每段序列的第一个位置是 cu_seqlens[:-1] boundary_prob = boundary_prob.squeeze(0) boundary_prob[cu_seqlens[:-1]] = PAD_PROB # 组装为二分类概率 [非边界, 边界] boundary_prob = torch.stack(((1 - boundary_prob), boundary_prob), dim=-1) # 取最大概率类别 selected_idx = torch.argmax(boundary_prob, dim=-1) # 硬边界掩码 boundary_mask = selected_idx == 1 # 形状与 hidden_states 的前两维一致 if mask is not None: # 不允许选择到无效token boundary_mask = boundary_mask & mask if inference_params is not None: # 维护路由状态:是否见过token、最后一个有效token的隐藏状态 has_mask = mask.any(dim=-1) inference_params.has_seen_tokens.copy_( has_mask | inference_params.has_seen_tokens ) last_mask = torch.clamp(mask.sum(dim=-1) - 1, min=0) inference_params.last_hidden_state.copy_( torch.where( has_mask, hidden_states[ torch.arange( hidden_states.shape[0], device=hidden_states.device ), last_mask, ], inference_params.last_hidden_state, ) ) # 取硬选择对应的概率(便于可视化/正则) selected_probs = boundary_prob.gather( dim=-1, index=selected_idx.unsqueeze(-1) ) # (..., 1) return RoutingModuleOutput( boundary_prob=boundary_prob, # (..., 2) boundary_mask=boundary_mask, # (...) selected_probs=selected_probs, # (..., 1) ) def step(self, hidden_states, inference_params): """ 流式单步: hidden_states: (B, 1, D) 使用上一步缓存的 last_hidden_state 与当前token计算相邻相似度,得到当前步的边界概率 """ # (B, D) hidden_states = hidden_states.squeeze(1) cos_sim = torch.einsum( "b d, b d -> b", F.normalize(self.q_proj_layer(inference_params.last_hidden_state), dim=-1), F.normalize(self.k_proj_layer(hidden_states), dim=-1), ) boundary_prob = torch.clamp(((1 - cos_sim) / 2), min=0.0, max=1.0) # 更新最后隐藏状态 inference_params.last_hidden_state.copy_(hidden_states) # 首个token前,强制边界 boundary_prob = torch.where( inference_params.has_seen_tokens, boundary_prob, torch.ones_like(boundary_prob), ) boundary_prob = torch.stack(((1 - boundary_prob), boundary_prob), dim=-1) # 标记为已见token inference_params.has_seen_tokens.copy_( torch.ones_like(inference_params.has_seen_tokens) ) return RoutingModuleOutput( boundary_prob=boundary_prob, # (B, 2) boundary_mask=boundary_prob[..., 1] > 0.5, # (B,) selected_probs=boundary_prob.max(dim=-1).values.unsqueeze(-1), # (B, 1) ) class ChunkLayer(nn.Module): """ Chunk层:根据 boundary_mask 将被选中的“边界token”抽取出来,形成下一层序列。 支持两种模式: - packed(cu_seqlens 非 None):直接在拼接后的序列上索引 - 非packed(mask 非 None):通过排序 trick 把True位置排到前面,并生成 next_mask 返回: - next_hidden_states: 选中的token序列(packed: shape=(#selected, D);非packed: (B, M, D)) - next_cu_seqlens: packed模式下新序列的cu_seqlens;否则None - next_max_seqlen: packed模式下选中的最大长度;非packed模式返回None - next_mask: 非packed模式下的右侧pad掩码;packed模式下None """ def forward(self, hidden_states, boundary_mask, cu_seqlens=None, mask=None): assert (mask is not None) or ( cu_seqlens is not None ), "Either mask or cu_seqlens must be provided" if cu_seqlens is not None: # packed:直接选择True的行,得到拼接后的 selected next_hidden_states = hidden_states[boundary_mask] # 新的cu_seqlens = 对每段最后一个位置(=cu_seqlens[1:]-1)累计True的计数,再前置0 next_cu_seqlens = F.pad( boundary_mask.cumsum(dim=0)[cu_seqlens[1:] - 1], (1, 0) ) # 新序列的最大段长(仅用于内核/优化) next_max_seqlen = int((next_cu_seqlens[1:] - next_cu_seqlens[:-1]).max()) next_mask = None else: # 非packed:对每个batch内,把True位置排到前面(False放到靠后) next_cu_seqlens = None num_tokens = boundary_mask.sum(dim=-1) # 每个样本被选中的数量 next_max_seqlen = int(num_tokens.max()) device = hidden_states.device L = hidden_states.shape[1] # trick:用 (~boundary_mask)*L 把False加大,从而 argsort 后 True 的下标排在前面 token_idx = ( torch.arange(L, device=device)[None, :] + (~boundary_mask).long() * L ) seq_sorted_indices = torch.argsort(token_idx, dim=1) # 收集前 next_max_seqlen 个(不足的样本右侧pad) next_hidden_states = torch.gather( hidden_states, dim=1, index=seq_sorted_indices[:, :next_max_seqlen, None].expand( -1, -1, hidden_states.shape[-1] ), ) # 下游的有效mask(右侧pad无效) next_mask = ( torch.arange(next_max_seqlen, device=device)[None, :] < num_tokens[:, None] ) # 非packed模式下,不再需要 max_seqlen(返回None) next_max_seqlen = None return next_hidden_states, next_cu_seqlens, next_max_seqlen, next_mask def step(self, hidden_states, boundary_mask): # 流式step:仅返回当前步被选中的token(用于下一层) return hidden_states[boundary_mask] class DeChunkLayer(nn.Module): """ DeChunk层:把“被选中的边界token序列”反聚合(EMA)回原始等长序列。 实现上复用 Mamba2 的 Triton 扫描核 mamba_chunk_scan_combined: - 将 d_model 切分为 nheads * headdim - 使用参数 A=-1, b=p, c=1 的一阶状态空间/EMA形式进行前向扫描 - 最终把扫描输出根据分段索引映射回原位置(plug back) 支持: - packed 模式(cu_seqlens) - 非packed(batch+右侧pad) - 流式 step(EMA递推) """ def __init__( self, d_model, dtype=torch.bfloat16, block_size=256, headdim=32, ): super().__init__() self.d_model = d_model # 仅为内核要求:使用 bfloat16,块大小与头维拆分 self.dtype = dtype self.block_size = block_size self.headdim = headdim assert d_model % self.headdim == 0 self.nheads = d_model // self.headdim def allocate_inference_cache(self, batch_size, max_seqlen, device, dtype=None): # 分配EMA的last_value缓存 return DeChunkState( last_value=torch.zeros( batch_size, self.d_model, device=device, dtype=dtype ), ) def forward( self, hidden_states, # 被选中的token序列(packed: (M, D);非packed: (B, M, D)) boundary_mask, # 原序列上的边界掩码((T,) 或 (B, L)) boundary_prob, # 原序列上的二分类概率((..., 2)) cu_seqlens=None, inference_params=None, mask=None, ): """ 核心思路: 1) 从 boundary_prob 得到 p = P(boundary) ∈ (1e-4, 1-1e-4) 2) 构造 dt = log(1 / (1-p)),并对输入做缩放 x = h / dt 3) 用 mamba_chunk_scan_combined 扫描:A=-1, b=p, c=1,对 (B, M, H, P) 进行块扫描 4) 将结果根据 cumulative boundary index 回填到原序列位置 """ if inference_params is not None: # prefill时必须有mask,且首token必须是边界(保证EMA初始化) assert ( mask is not None ), "Mask must be provided if inference_params is provided" assert boundary_mask[ :, 0 ].all(), "First token must be a boundary if running prefill" # 取边界概率的“边界类”概率 p,并限制在(1e-4, 1-1e-4)内,避免数值不稳 p = torch.clamp(boundary_prob[..., -1].float(), min=1e-4, max=1 - (1e-4)) if cu_seqlens is not None: # packed:从原序列p中取出被选中的位置对应的概率,形状(B=1, M) p = p[boundary_mask].unsqueeze(0) # 为triton核准备packed序列的索引映射 seq_idx = get_seq_idx(cu_seqlens, device=hidden_states.device) else: B, L = boundary_mask.shape seq_idx = None # 与ChunkLayer一致的排序 trick,得到选中的顺序(True在前) token_idx = ( torch.arange(L, device=hidden_states.device)[None, :] + (~boundary_mask).long() * L ) seq_sorted_indices = torch.argsort(token_idx, dim=1) # 取出与 hidden_states 对应长度的 p((B, M)) p = torch.gather( p, dim=1, index=seq_sorted_indices[:, : hidden_states.shape[1]] ) # (B, M) original_dtype = hidden_states.dtype # 构造 EMA 扫描所需变量 dt = torch.log(1 / (1 - p)).to(self.dtype) # (B, M) x = (hidden_states / dt[..., None]).to(self.dtype) # (B, M, D) / (B, M, 1) # A, b, c 分别对应一阶状态空间/EMA的参数 A = -torch.ones( (self.nheads,), device=hidden_states.device, dtype=torch.float32 ) b = p.to(self.dtype) c = torch.ones_like(b) # 调用triton核进行块扫描 out = mamba_chunk_scan_combined( rearrange(x, "b l (h p) -> b l h p", p=self.headdim), # (B, M, H, P) repeat(dt, "b l -> b l h", h=self.nheads), # (B, M, H) A, # (H,) rearrange(b, "b l -> b l 1 1"), # (B, M, 1, 1) rearrange(c, "b l -> b l 1 1"), # (B, M, 1, 1) chunk_size=self.block_size, seq_idx=seq_idx, # packed时提供 ) out = rearrange(out, "b l h p -> b l (h p)") # (B, M, D) # 将扫描结果回填(plug back)到原序列位置 if cu_seqlens is not None: out = out.squeeze(0) # (M, D) plug_back_idx = boundary_mask.cumsum(dim=0) - 1 # (T,) out = torch.gather( out, dim=0, index=plug_back_idx.unsqueeze(-1).expand(-1, self.d_model) ) # (T, D) else: plug_back_idx = torch.cumsum(boundary_mask, dim=1) - 1 # (B, L) out = torch.gather( out, dim=1, index=plug_back_idx.unsqueeze(-1).expand(-1, -1, self.d_model), ) # (B, L, D) # 更新流式缓存 if inference_params is not None: inference_params.last_value.copy_(out[:, -1]) return out.to(original_dtype) def step(self, hidden_states, boundary_mask, boundary_prob, inference_params): """ 流式单步 EMA 反聚合: hidden_states: (B', 1, D),其中 B' = 当前步被选中的数量(boundary_mask.sum()) boundary_mask: (B,) 当前batch哪些位置被选中为边界 boundary_prob: (B, 2) 当前batch各位置的边界概率 输出:(B, 1, D),对应对所有位置做了一步 EMA 更新后的值 """ B = boundary_mask.shape[0] D = hidden_states.shape[-1] # 构造当前步每个位置的 p(未被选中的位置 p=0) p = torch.zeros(B, device=hidden_states.device, dtype=hidden_states.dtype) p[boundary_mask] = boundary_prob[boundary_mask, -1].clamp( min=1e-4, max=1 - (1e-4) ) # 构造当前被选中的隐藏状态(未选中为0) current_hidden_states = torch.zeros( B, D, device=hidden_states.device, dtype=hidden_states.dtype ) current_hidden_states[boundary_mask] = hidden_states.squeeze(1) # EMA:result = p * x + (1 - p) * last result = p * current_hidden_states + (1 - p) * inference_params.last_value inference_params.last_value.copy_(result) return result.unsqueeze(1)