from dataclasses import dataclass
from typing import Optional, Literal, List

import torch
import torch.nn as nn
import torch.nn.functional as F

# From your codebase (used as-is)
from hnet.modules.dc import RoutingModule, ChunkLayer
from hnet.modules.isotropic import Isotropic
from hnet.models.config_hnet import HNetConfig, SSMConfig, AttnConfig


# -------------------- Helpers --------------------
def create_isotropic_encoder(d_model, arch="m", height=4, ssm_cfg=None, attn_cfg=None, device=None, dtype=None):
    """Create a simplified Isotropic encoder.

    `ssm_cfg` / `attn_cfg` are optional dicts of overrides so the per-stage configs
    (and `main_cfg`) defined below can pass them without raising a TypeError.
    """
    factory_kwargs = {"device": device, "dtype": dtype}

    # Default SSM / attention settings, with caller overrides merged in
    ssm_kwargs = dict(d_conv=4, expand=2, d_state=128, chunk_size=256)
    ssm_kwargs.update(ssm_cfg or {})
    attn_kwargs = dict(num_heads=8, rotary_emb_dim=0, window_size=-1)
    attn_kwargs.update(attn_cfg or {})

    # Build the HNetConfig; the list-valued fields need at least one element each
    config = HNetConfig(
        arch_layout=[f"{arch}{height}"],
        d_model=[d_model],
        d_intermediate=[d_model * 2],
        ssm_cfg=SSMConfig(**ssm_kwargs),
        attn_cfg=AttnConfig(
            num_heads=[attn_kwargs["num_heads"]],
            rotary_emb_dim=[attn_kwargs["rotary_emb_dim"]],
            window_size=[attn_kwargs["window_size"]],
        ),
    )

    return Isotropic(
        config=config,
        pos_idx=0,
        stage_idx=0,
        **factory_kwargs
    )

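# Example call (illustrative, not from the original): a 12-layer Transformer stack at
# width 512 with 8 attention heads, mirroring how `main_cfg` passes `attn_cfg` overrides below:
#   main = create_isotropic_encoder(512, arch="T", height=12, attn_cfg=dict(num_heads=8))
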
def ratio_loss(boundary_mask: torch.Tensor, boundary_prob: torch.Tensor, target_N: int) -> torch.Tensor:
    """H-Net-style ratio loss: pushes the router toward selecting roughly 1/target_N of positions as boundaries."""
    F_act = boundary_mask.float().mean(dim=1)   # (B,) realized fraction of boundary positions
    G_prob = boundary_prob[..., 1].mean(dim=1)  # (B,) mean predicted boundary probability
    N = float(target_N)
    # Both terms carry the probability so the minimum sits near F_act ~= 1/N
    # (with only (N - 1) * F_act in the first term, the loss would be minimized at F_act = 0).
    loss = N / (N - 1.0) * (((N - 1.0) * F_act * G_prob) + (1.0 - F_act) * (1.0 - G_prob))
    return loss.mean()

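# Hypothetical self-check helper (not part of the original code): builds random routing
# outputs with the shapes ratio_loss expects and returns the resulting scalar loss.
def _ratio_loss_sanity_check(B: int = 2, L: int = 16, target_N: int = 4) -> torch.Tensor:
    p_boundary = torch.rand(B, L)
    boundary_prob = torch.stack([1.0 - p_boundary, p_boundary], dim=-1)  # (B, L, 2)
    boundary_mask = p_boundary > 0.5                                     # (B, L) bool
    return ratio_loss(boundary_mask, boundary_prob, target_N)
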
def masked_mean(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    mask_f = mask.float().unsqueeze(-1)       # (B, L, 1)
    s = (x * mask_f).sum(dim=1)               # (B, D)
    denom = mask_f.sum(dim=1).clamp_min(1.0)
    return s / denom

# -------------------- Multi-stage encoder pyramid: Mamba2 + routed downsampling per stage; only the final stage feeds the main network --------------------
class PyramidEncoders_NoDechunk(nn.Module):
    """
    Hierarchy (encoders compress stage by stage; the main network runs only on the final stage):
      Input x0: (B, L0, 1)
      - Linear up-projection -> D0
      For s = 0..S-1:
        E_s (Mamba2, D_s) -> h_s (B, L_s, D_s)
        Routing + downsampling -> x_{s+1} (B, L_{s+1}, D_s), mask_{s+1}
        Width expansion D_s -> D_{s+1} (concatenate a shared vector)
      Final x_S: (B, L_S, D_S) is fed to a single main network M (Transformer/Mamba)
      Cross-scale fusion (no dechunking): fuse pooled_enc0 from E^0 with pooled_main from the main network
    """
    def __init__(
        self,
        d_models: List[int],                # [D0, D1, ..., D_S], monotonically non-decreasing
        encoder_cfg_per_stage: List[dict],  # S encoder configs (arch must be 'm'/'M')
        main_cfg: dict,                     # single main-network config (runs on the most compressed sequence)
        fusion_dropout: float = 0.1,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
    ):
        super().__init__()
        factory_kwargs = {"device": device, "dtype": dtype}

        assert len(d_models) >= 1
        S = len(d_models) - 1
        assert S == len(encoder_cfg_per_stage), "number of stages must match the number of encoder configs"
        for i in range(S):
            assert d_models[i+1] >= d_models[i], "requires D_s <= D_{s+1} (widths must be non-decreasing)"
            assert encoder_cfg_per_stage[i].get("arch", "m") in ("m", "M"), "encoders must be Mamba2"

        self.S = S
        self.d_models = d_models

        # Project the scalar input up to D0
        self.input_proj = nn.Linear(1, d_models[0], **factory_kwargs)

        # Per-stage encoder + router + downsampler + width-expansion parameters
        self.encoders = nn.ModuleList()
        self.routers = nn.ModuleList()
        self.chunks = nn.ModuleList()
        self.pad_vectors = nn.ParameterList()
        for s in range(S):
            self.encoders.append(
                create_isotropic_encoder(
                    d_model=d_models[s],
                    **{k: v for k, v in encoder_cfg_per_stage[s].items() if k != "d_model"},
                    **factory_kwargs
                )
            )
            self.routers.append(RoutingModule(d_models[s], **factory_kwargs))
            self.chunks.append(ChunkLayer())
            delta = d_models[s+1] - d_models[s]
            self.pad_vectors.append(
                nn.Parameter(torch.zeros(delta, **factory_kwargs)) if delta > 0
                else nn.Parameter(torch.empty(0, **factory_kwargs))
            )

        # The single final main network: runs at width D_S on the length-L_S sequence
        self.main_network = create_isotropic_encoder(
            d_model=d_models[-1],
            **{k: v for k, v in main_cfg.items() if k != "d_model"},
            **factory_kwargs
        )

        # Cross-scale fusion: project pooled_enc0 (D0) to D_S and fuse it with pooled_main (D_S) -> D_S
        self.proj_enc0_to_DS = nn.Linear(d_models[0], d_models[-1], **factory_kwargs)
        self.fusion_head = nn.Sequential(
            nn.Linear(d_models[-1] + d_models[-1], d_models[-1], **factory_kwargs),
            nn.GELU(),
            nn.Dropout(fusion_dropout),
            nn.Linear(d_models[-1], d_models[-1], **factory_kwargs),
        )
    def _expand_width(self, x: torch.Tensor, pad_vec: nn.Parameter) -> torch.Tensor:
        if pad_vec.numel() == 0:
            return x
        early = x.shape[:-1]
        return torch.cat([x, pad_vec.expand(*early, -1)], dim=-1)
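    # Shape example for _expand_width (illustrative): with x of shape (B, L, 256) and a
    # 256-dim pad vector, the output is (B, L, 512). The same learned vector is appended
    # at every position: the SpaceByte-style shared-vector width expansion used here.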
    def forward(self, x_scalar: torch.Tensor, mask: Optional[torch.Tensor] = None, return_seq: bool = False):
        """
        x_scalar: (B, L) or (B, L, 1)
        mask: (B, L) bool
        Returns:
          fused_vec: (B, D_S)
          debug: optional per-stage sequences
          aux: per-stage routing info (for the ratio loss)
        """
        if x_scalar.dim() == 2:
            x_scalar = x_scalar.unsqueeze(-1)  # (B, L, 1)
        B, L, _ = x_scalar.shape
        device = x_scalar.device
        if mask is None:
            mask = torch.ones(B, L, dtype=torch.bool, device=device)

        # Initial up-projection to D0
        x = self.input_proj(x_scalar)  # (B, L0, D0)
        cur_mask = mask

        pooled_enc0 = None
        aux_per_stage = []
        seq_debug = [] if return_seq else None

        # Per stage: Encoder (Mamba2) -> Routing -> Chunk -> Expand D
        for s in range(self.S):
            # Fine-grained encoding of the not-yet-compressed sequence
            h_enc = self.encoders[s](x, mask=cur_mask)  # (B, L_s, D_s)

            if s == 0:
                pooled_enc0 = masked_mean(h_enc, cur_mask)  # (B, D0)

            # Routing + downsampling (yields a shorter sequence)
            bpred = self.routers[s](h_enc, mask=cur_mask)
            x_next, _, _, mask_next = self.chunks[s](h_enc, bpred.boundary_mask, mask=cur_mask)  # (B, L_{s+1}, D_s)

            # Expand width D_s -> D_{s+1}
            x_next = self._expand_width(x_next, self.pad_vectors[s])  # (B, L_{s+1}, D_{s+1})

            # Advance to the next stage
            x, cur_mask = x_next, mask_next

            aux_per_stage.append({
                "boundary_mask": bpred.boundary_mask,
                "boundary_prob": bpred.boundary_prob,
                "selected_probs": bpred.selected_probs,
            })
            if return_seq:
                seq_debug.append({"stage": s, "seq": x, "mask": cur_mask})

        # Now x: (B, L_S, D_S), cur_mask: (B, L_S)
        # The single main network runs on the most compressed sequence
        h_main = self.main_network(x, mask=cur_mask)  # (B, L_S, D_S)

        # Pool the main-network outputs
        if cur_mask is None:
            pooled_main = h_main.mean(dim=1)  # (B, D_S)
        else:
            pooled_main = (h_main * cur_mask.float().unsqueeze(-1)).sum(dim=1) / \
                cur_mask.float().sum(dim=1, keepdim=True).clamp_min(1.0)

        # Cross-scale fusion: global pooling of E^0 plus main-network pooling
        pooled_enc0_in_DS = self.proj_enc0_to_DS(pooled_enc0)        # (B, D_S)
        fused = torch.cat([pooled_enc0_in_DS, pooled_main], dim=-1)  # (B, 2*D_S)
        fused = self.fusion_head(fused)                              # (B, D_S)

        aux = {"per_stage": aux_per_stage}
        if return_seq:
            return fused, {"stages": seq_debug, "main_seq": h_main, "main_mask": cur_mask}, aux
        else:
            return fused, None, aux

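# Minimal single-channel usage sketch (hypothetical helper, not in the original file).
# It assumes the hnet modules imported above are available, runs one forward pass
# through a two-stage pyramid, and returns the fused vector plus the stage-0 ratio loss.
def _example_pyramid_forward(B: int = 2, L: int = 256) -> tuple:
    enc = PyramidEncoders_NoDechunk(
        d_models=[64, 128],                                # D0 <= D1
        encoder_cfg_per_stage=[dict(arch="m", height=2)],
        main_cfg=dict(arch="m", height=2),
    )
    x = torch.randn(B, L)                                  # (B, L) scalar series
    fused, _, aux = enc(x)                                 # fused: (B, 128)
    stage0 = aux["per_stage"][0]
    rl = ratio_loss(stage0["boundary_mask"], stage0["boundary_prob"], target_N=4)
    return fused, rl
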
# -------------------- Top level: multi-channel fusion + classification head (single main network) --------------------
@dataclass
class HierEncodersSingleMainConfig:
    num_channels: int
    d_models: List[int]                 # [D0, D1, ..., D_S], monotonically non-decreasing
    num_classes: int
    encoder_cfg_per_stage: List[dict]   # S encoder configs (all Mamba2, height ~ 4)
    main_cfg: dict                      # single main-network config (Transformer or Mamba2); d_model is taken from D_S
    target_compression_N_per_stage: List[int]
    share_channel: bool = True
    fusion_across_channels: Literal["mean", "concat"] = "mean"
    dropout: float = 0.1


class HierEncodersSingleMainClassifier(nn.Module):
    def __init__(self, cfg: HierEncodersSingleMainConfig, dtype=None, device=None):
        super().__init__()
        self.cfg = cfg
        factory_kwargs = {"dtype": dtype, "device": device}

        S = len(cfg.d_models) - 1
        assert S == len(cfg.encoder_cfg_per_stage) == len(cfg.target_compression_N_per_stage), "inconsistent number of stages"

        if cfg.share_channel:
            self.channel_encoder = PyramidEncoders_NoDechunk(
                d_models=cfg.d_models,
                encoder_cfg_per_stage=cfg.encoder_cfg_per_stage,
                main_cfg=cfg.main_cfg,
                **factory_kwargs,
            )
        else:
            self.channel_encoder = nn.ModuleList([
                PyramidEncoders_NoDechunk(
                    d_models=cfg.d_models,
                    encoder_cfg_per_stage=cfg.encoder_cfg_per_stage,
                    main_cfg=cfg.main_cfg,
                    **factory_kwargs,
                )
                for _ in range(cfg.num_channels)
            ])

        fusion_dim = (cfg.num_channels * cfg.d_models[-1]) if cfg.fusion_across_channels == "concat" \
            else cfg.d_models[-1]
        self.dropout = nn.Dropout(cfg.dropout)
        self.head = nn.Linear(fusion_dim, cfg.num_classes, **factory_kwargs)
    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None, return_seq: bool = False):
        """
        x: (B, L, N) multi-channel input
        mask: (B, L) temporal mask
        """
        B, L, N = x.shape
        assert N == self.cfg.num_channels

        channel_vecs: List[torch.Tensor] = []
        ratio_losses = []
        seq_dbg_all = [] if return_seq else None

        for c in range(N):
            x_c = x[..., c]  # (B, L)
            if self.cfg.share_channel:
                vec, seq_dbg, aux = self.channel_encoder(x_c, mask=mask, return_seq=return_seq)
            else:
                vec, seq_dbg, aux = self.channel_encoder[c](x_c, mask=mask, return_seq=return_seq)

            # Accumulate the ratio loss (one term per encoder stage)
            total_rl = 0.0
            for s, aux_s in enumerate(aux["per_stage"]):
                rl = ratio_loss(aux_s["boundary_mask"], aux_s["boundary_prob"], self.cfg.target_compression_N_per_stage[s])
                total_rl = total_rl + rl
            ratio_losses.append(total_rl)

            channel_vecs.append(vec)
            if return_seq:
                seq_dbg_all.append(seq_dbg)

        if self.cfg.fusion_across_channels == "concat":
            fused = torch.cat(channel_vecs, dim=-1)               # (B, N*D_S)
        else:
            fused = torch.stack(channel_vecs, dim=1).mean(dim=1)  # (B, D_S)

        fused = self.dropout(fused)
        logits = self.head(fused)

        aux_all = {"ratio_loss": torch.stack(ratio_losses).mean()}
        if return_seq:
            return logits, seq_dbg_all, aux_all
        else:
            return logits, None, aux_all


# -------------------- Usage example --------------------
if __name__ == "__main__":
    """
    What this demonstrates:
    - Going deeper only adds encoders (each stage: Mamba2 + dynamic chunking); there is a single final main network
    - Sequence length shrinks stage by stage (decided by the DC router) while the width d_model grows monotonically (SpaceByte-style shared-vector concatenation)
    - No dechunking; cross-scale fusion uses E^0's global pooling plus the final main-network pooling
    """
    B, L, N = 8, 1024, 6
    num_classes = 7
    d_models = [128, 256, 512]  # D0 <= D1 <= D2

    encoder_cfg_per_stage = [
        dict(arch="m", height=4, ssm_cfg=dict(), attn_cfg=dict()),  # stage 0 encoder (Mamba2)
        dict(arch="m", height=4, ssm_cfg=dict(), attn_cfg=dict()),  # stage 1 encoder (Mamba2)
    ]
    main_cfg = dict(
        arch="T", height=12, ssm_cfg=dict(), attn_cfg=dict(num_heads=8)  # final main network (heavier)
    )
    target_compression_N_per_stage = [4, 4]

    cfg = HierEncodersSingleMainConfig(
        num_channels=N,
        d_models=d_models,
        num_classes=num_classes,
        encoder_cfg_per_stage=encoder_cfg_per_stage,
        main_cfg=main_cfg,
        target_compression_N_per_stage=target_compression_N_per_stage,
        share_channel=True,
        fusion_across_channels="mean",
        dropout=0.1,
    )

    model = HierEncodersSingleMainClassifier(cfg).cuda().train()
    x = torch.randn(B, L, N, device="cuda")
    mask = torch.ones(B, L, dtype=torch.bool, device="cuda")

    logits, _, aux = model(x, mask=mask, return_seq=False)
    y = torch.randint(0, num_classes, (B,), device="cuda")
    cls_loss = F.cross_entropy(logits, y)
    ratio_reg = 0.03 * aux["ratio_loss"]
    loss = cls_loss + ratio_reg
    loss.backward()
    print("logits:", logits.shape, "loss:", float(loss))
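    # Optional continuation (illustrative, not from the original): a single optimizer
    # step using the gradients produced by loss.backward() above. AdamW and lr=3e-4
    # are arbitrary example choices.
    optim = torch.optim.AdamW(model.parameters(), lr=3e-4)
    optim.step()
    optim.zero_grad(set_to_none=True)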