from dataclasses import dataclass
from typing import Optional, Literal, List

import torch
import torch.nn as nn
import torch.nn.functional as F

# From your codebase (usable directly)
from hnet.modules.dc import RoutingModule, ChunkLayer
from hnet.modules.isotropic import Isotropic
from hnet.models.config_hnet import HNetConfig, SSMConfig, AttnConfig


# -------------------- Helpers --------------------

def create_isotropic_encoder(d_model, arch="m", height=4, ssm_cfg=None, attn_cfg=None,
                             device=None, dtype=None):
    """Create a simplified single-stage Isotropic encoder.

    Optional `ssm_cfg` / `attn_cfg` dicts override the defaults below
    (the stage configs in the usage example pass such dicts).
    """
    factory_kwargs = {"device": device, "dtype": dtype}
    ssm_cfg = ssm_cfg or {}
    attn_cfg = attn_cfg or {}

    # Build an HNetConfig; the list-valued fields need at least one element.
    config = HNetConfig(
        arch_layout=[f"{arch}{height}"],
        d_model=[d_model],
        d_intermediate=[d_model * 2],
        ssm_cfg=SSMConfig(
            d_conv=ssm_cfg.get("d_conv", 4),
            expand=ssm_cfg.get("expand", 2),
            d_state=ssm_cfg.get("d_state", 128),
            chunk_size=ssm_cfg.get("chunk_size", 256),
        ),
        attn_cfg=AttnConfig(
            num_heads=[attn_cfg.get("num_heads", 8)],            # at least one element
            rotary_emb_dim=[attn_cfg.get("rotary_emb_dim", 0)],  # at least one element
            window_size=[attn_cfg.get("window_size", -1)],       # at least one element
        ),
    )

    return Isotropic(
        config=config,
        pos_idx=0,
        stage_idx=0,
        **factory_kwargs,
    )


def ratio_loss(boundary_mask: torch.Tensor, boundary_prob: torch.Tensor, target_N: int) -> torch.Tensor:
    """Compression-ratio regularizer: pushes each stage toward roughly one boundary every target_N steps."""
    F_act = boundary_mask.float().mean(dim=1)   # (B,) fraction of positions selected as boundaries
    G_prob = boundary_prob[..., 1].mean(dim=1)  # (B,) mean predicted boundary probability
    N = float(target_N)
    loss = N / (N - 1.0) * (((N - 1.0) * F_act) + (1.0 - F_act) * (1.0 - G_prob))
    return loss.mean()


def masked_mean(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Mean over the time dimension, ignoring masked-out positions."""
    mask_f = mask.float().unsqueeze(-1)   # (B, L, 1)
    s = (x * mask_f).sum(dim=1)           # (B, D)
    denom = mask_f.sum(dim=1).clamp_min(1.0)
    return s / denom


# -------------------- Multi-stage encoder pyramid: Mamba2 + routed downsampling per stage; a single final main network --------------------

class PyramidEncoders_NoDechunk(nn.Module):
    """
    Hierarchy (only the encoders compress stage by stage; the main network runs on the final stage only):
      Input x0: (B, L0, 1) -> linear projection to D0
      For s = 0..S-1:
        E_s (Mamba2, D_s) -> h_s (B, L_s, D_s)
        routing + downsampling -> x_{s+1} (B, L_{s+1}, D_s), mask_{s+1}
        width expansion D_s -> D_{s+1} (concatenate a shared vector)
      Final x_S: (B, L_S, D_S) is fed into a single main network M (Transformer/Mamba)
      Cross-scale fusion (no dechunking): fuse pooled_enc0 from E^0 with pooled_main from the main network
    """

    def __init__(
        self,
        d_models: List[int],                # [D0, D1, ..., D_S], monotonically non-decreasing
        encoder_cfg_per_stage: List[dict],  # S encoder configs (arch must be 'm'/'M')
        main_cfg: dict,                     # single main-network config (runs on the most compressed sequence)
        fusion_dropout: float = 0.1,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
    ):
        super().__init__()
        factory_kwargs = {"device": device, "dtype": dtype}

        assert len(d_models) >= 1
        S = len(d_models) - 1
        assert S == len(encoder_cfg_per_stage), "number of stages must equal number of encoder configs"
        for i in range(S):
            assert d_models[i + 1] >= d_models[i], "requires D_s <= D_{s+1} (width increases monotonically)"
            assert encoder_cfg_per_stage[i].get("arch", "m") in ("m", "M"), "encoders must be Mamba2"

        self.S = S
        self.d_models = d_models

        # Project the scalar input up to D0
        self.input_proj = nn.Linear(1, d_models[0], **factory_kwargs)

        # Per-stage encoder + router + downsampler + width-expansion parameters
        self.encoders = nn.ModuleList()
        self.routers = nn.ModuleList()
        self.chunks = nn.ModuleList()
        self.pad_vectors = nn.ParameterList()
        for s in range(S):
            self.encoders.append(
                create_isotropic_encoder(
                    d_model=d_models[s],
                    **{k: v for k, v in encoder_cfg_per_stage[s].items() if k != "d_model"},
                    **factory_kwargs,
                )
            )
            self.routers.append(RoutingModule(d_models[s], **factory_kwargs))
            self.chunks.append(ChunkLayer())
            delta = d_models[s + 1] - d_models[s]
            self.pad_vectors.append(
                nn.Parameter(torch.zeros(delta, **factory_kwargs))
                if delta > 0 else nn.Parameter(torch.empty(0, **factory_kwargs))
            )

        # The single final main network: runs at width D_S on length L_S
        self.main_network = create_isotropic_encoder(
            d_model=d_models[-1],
            **{k: v for k, v in main_cfg.items() if k != "d_model"},
            **factory_kwargs,
        )

        # Cross-scale fusion: project pooled_enc0 (D0) to D_S and fuse with pooled_main (D_S) -> D_S
        self.proj_enc0_to_DS = nn.Linear(d_models[0], d_models[-1], **factory_kwargs)
        self.fusion_head = nn.Sequential(
            nn.Linear(d_models[-1] + d_models[-1], d_models[-1], **factory_kwargs),
            nn.GELU(),
            nn.Dropout(fusion_dropout),
            nn.Linear(d_models[-1], d_models[-1], **factory_kwargs),
        )

    def _expand_width(self, x: torch.Tensor, pad_vec: nn.Parameter) -> torch.Tensor:
        """Append a shared learned vector to grow the feature dim from D_s to D_{s+1}."""
        if pad_vec.numel() == 0:
            return x
        lead_shape = x.shape[:-1]
        return torch.cat([x, pad_vec.expand(*lead_shape, -1)], dim=-1)

    def forward(self, x_scalar: torch.Tensor, mask: Optional[torch.Tensor] = None, return_seq: bool = False):
        """
        x_scalar: (B, L) or (B, L, 1)
        mask: (B, L) bool
        Returns:
          fused_vec: (B, D_S)
          debug: optional per-stage sequences (only if return_seq=True)
          aux: routing info per stage (for the ratio loss)
        """
        if x_scalar.dim() == 2:
            x_scalar = x_scalar.unsqueeze(-1)   # (B, L, 1)
        B, L, _ = x_scalar.shape
        device = x_scalar.device
        if mask is None:
            mask = torch.ones(B, L, dtype=torch.bool, device=device)

        # Initial projection up to D0
        x = self.input_proj(x_scalar)   # (B, L0, D0)
        cur_mask = mask

        pooled_enc0 = None
        aux_per_stage = []
        seq_debug = [] if return_seq else None

        # Per stage: Encoder (Mamba2) -> Routing -> Chunk -> Expand D
        for s in range(self.S):
            # Fine-grained encoding (uncompressed sequence)
            h_enc = self.encoders[s](x, mask=cur_mask)   # (B, L_s, D_s)
            if s == 0:
                pooled_enc0 = masked_mean(h_enc, cur_mask)   # (B, D0)

            # Routing + downsampling (yields a shorter sequence)
            bpred = self.routers[s](h_enc, mask=cur_mask)
            x_next, _, _, mask_next = self.chunks[s](h_enc, bpred.boundary_mask, mask=cur_mask)  # (B, L_{s+1}, D_s)

            # Expand width D_s -> D_{s+1}
            x_next = self._expand_width(x_next, self.pad_vectors[s])   # (B, L_{s+1}, D_{s+1})

            # Advance to the next stage
            x, cur_mask = x_next, mask_next

            aux_per_stage.append({
                "boundary_mask": bpred.boundary_mask,
                "boundary_prob": bpred.boundary_prob,
                "selected_probs": bpred.selected_probs,
            })
            if return_seq:
                seq_debug.append({"stage": s, "seq": x, "mask": cur_mask})

        # Now x: (B, L_S, D_S), cur_mask: (B, L_S)
        # The single main network runs on the most compressed sequence
        h_main = self.main_network(x, mask=cur_mask)   # (B, L_S, D_S)

        # Pool the main-network output
        if cur_mask is None:
            pooled_main = h_main.mean(dim=1)   # (B, D_S)
        else:
            pooled_main = (h_main * cur_mask.float().unsqueeze(-1)).sum(dim=1) / \
                cur_mask.float().sum(dim=1, keepdim=True).clamp_min(1.0)

        # Cross-scale fusion: global pooling of E^0 with the main-network pooling
        pooled_enc0_in_DS = self.proj_enc0_to_DS(pooled_enc0)        # (B, D_S)
        fused = torch.cat([pooled_enc0_in_DS, pooled_main], dim=-1)  # (B, 2*D_S)
        fused = self.fusion_head(fused)                              # (B, D_S)

        aux = {"per_stage": aux_per_stage}
        if return_seq:
            return fused, {"stages": seq_debug, "main_seq": h_main, "main_mask": cur_mask}, aux
        else:
            return fused, None, aux


# -------------------- Top level: multi-channel fusion + classification head (single main network) --------------------

@dataclass
class HierEncodersSingleMainConfig:
    num_channels: int
    d_models: List[int]                        # [D0, D1, ..., D_S], monotonically non-decreasing
    num_classes: int
    encoder_cfg_per_stage: List[dict]          # S encoder configs (all Mamba2, height ~ 4)
    main_cfg: dict                             # single main-network config (Transformer or Mamba2); d_model is taken from D_S
    target_compression_N_per_stage: List[int]
    share_channel: bool = True
    fusion_across_channels: Literal["mean", "concat"] = "mean"
    dropout: float = 0.1


class HierEncodersSingleMainClassifier(nn.Module):
    def __init__(self, cfg: HierEncodersSingleMainConfig, dtype=None, device=None):
        super().__init__()
        self.cfg = cfg
        factory_kwargs = {"dtype": dtype, "device": device}

        S = len(cfg.d_models) - 1
        assert S == len(cfg.encoder_cfg_per_stage) == len(cfg.target_compression_N_per_stage), \
            "inconsistent number of stages"

        if cfg.share_channel:
            self.channel_encoder = PyramidEncoders_NoDechunk(
                d_models=cfg.d_models,
                encoder_cfg_per_stage=cfg.encoder_cfg_per_stage,
                main_cfg=cfg.main_cfg,
                **factory_kwargs,
            )
        else:
            self.channel_encoder = nn.ModuleList([
                PyramidEncoders_NoDechunk(
                    d_models=cfg.d_models,
                    encoder_cfg_per_stage=cfg.encoder_cfg_per_stage,
                    main_cfg=cfg.main_cfg,
                    **factory_kwargs,
                )
                for _ in range(cfg.num_channels)
            ])

        fusion_dim = (cfg.num_channels * cfg.d_models[-1]) if cfg.fusion_across_channels == "concat" \
            else cfg.d_models[-1]
        self.dropout = nn.Dropout(cfg.dropout)
        self.head = nn.Linear(fusion_dim, cfg.num_classes, **factory_kwargs)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None, return_seq: bool = False):
        """
        x: (B, L, N) multi-channel input
        mask: (B, L) temporal mask
        """
        B, L, N = x.shape
        assert N == self.cfg.num_channels

        channel_vecs: List[torch.Tensor] = []
        ratio_losses = []
        seq_dbg_all = [] if return_seq else None

        for c in range(N):
            x_c = x[..., c]   # (B, L)
            if self.cfg.share_channel:
                vec, seq_dbg, aux = self.channel_encoder(x_c, mask=mask, return_seq=return_seq)
            else:
                vec, seq_dbg, aux = self.channel_encoder[c](x_c, mask=mask, return_seq=return_seq)

            # Accumulate the ratio loss (one term per encoder stage)
            total_rl = 0.0
            for s, aux_s in enumerate(aux["per_stage"]):
                rl = ratio_loss(aux_s["boundary_mask"], aux_s["boundary_prob"],
                                self.cfg.target_compression_N_per_stage[s])
                total_rl = total_rl + rl
            ratio_losses.append(total_rl)

            channel_vecs.append(vec)
            if return_seq:
                seq_dbg_all.append(seq_dbg)

        if self.cfg.fusion_across_channels == "concat":
            fused = torch.cat(channel_vecs, dim=-1)                # (B, N*D_S)
        else:
            fused = torch.stack(channel_vecs, dim=1).mean(dim=1)   # (B, D_S)

        fused = self.dropout(fused)
        logits = self.head(fused)

        aux_all = {"ratio_loss": torch.stack(ratio_losses).mean()}
        if return_seq:
            return logits, seq_dbg_all, aux_all
        else:
            return logits, None, aux_all


# -------------------- Usage example --------------------

if __name__ == "__main__":
    """
    What this setup satisfies:
      - Depth is added only through encoders (each stage: Mamba2 + dynamic chunking); there is a single final main network
      - Sequence length shrinks stage by stage (decided by DC) while the channel width d_model grows monotonically
        (SpaceByte-style concatenation of a shared vector)
      - No dechunking; cross-scale fusion uses the global pooling of E^0 plus the pooling of the final main network
    """
    B, L, N = 8, 1024, 6
    num_classes = 7
    d_models = [128, 256, 512]   # D0 <= D1 <= D2

    encoder_cfg_per_stage = [
        dict(arch="m", height=4, ssm_cfg=dict(), attn_cfg=dict()),   # stage 0 encoder (Mamba2)
        dict(arch="m", height=4, ssm_cfg=dict(), attn_cfg=dict()),   # stage 1 encoder (Mamba2)
    ]
    main_cfg = dict(
        arch="T", height=12, ssm_cfg=dict(), attn_cfg=dict(num_heads=8)   # final main network (heavier)
    )
    target_compression_N_per_stage = [4, 4]

    cfg = HierEncodersSingleMainConfig(
        num_channels=N,
        d_models=d_models,
        num_classes=num_classes,
        encoder_cfg_per_stage=encoder_cfg_per_stage,
        main_cfg=main_cfg,
        target_compression_N_per_stage=target_compression_N_per_stage,
        share_channel=True,
        fusion_across_channels="mean",
        dropout=0.1,
    )

    model = HierEncodersSingleMainClassifier(cfg).cuda().train()

    x = torch.randn(B, L, N, device="cuda")
    mask = torch.ones(B, L, dtype=torch.bool, device="cuda")

    logits, _, aux = model(x, mask=mask, return_seq=False)
    y = torch.randint(0, num_classes, (B,), device="cuda")

    cls_loss = F.cross_entropy(logits, y)
    ratio_reg = 0.03 * aux["ratio_loss"]
    loss = cls_loss + ratio_reg
    loss.backward()

    print("logits:", logits.shape, "loss:", float(loss))
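
    # A minimal sketch of inspecting the return_seq=True debug output: it shows how the
    # per-stage compressed sequences built in PyramidEncoders_NoDechunk.forward can be
    # examined; the printout format itself is only an illustration, not part of the model.
    _, seq_dbg_all, _ = model(x, mask=mask, return_seq=True)
    for c, dbg in enumerate(seq_dbg_all):                       # one debug dict per channel
        for stage_info in dbg["stages"]:
            s = stage_info["stage"]
            seq, m = stage_info["seq"], stage_info["mask"]      # (B, L_{s+1}, D_{s+1}), (B, L_{s+1})
            kept = m.float().sum(dim=1).mean().item()           # average retained length per sample
            print(f"channel {c} stage {s}: seq {tuple(seq.shape)}, avg kept length {kept:.1f}")
        print(f"channel {c} main seq: {tuple(dbg['main_seq'].shape)}")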