feat: add mamba and dynamic chunking related code and test code
models/DC_hnet.py (new file, 339 lines)
@@ -0,0 +1,339 @@
from dataclasses import dataclass
from typing import Optional, Literal, List

import torch
import torch.nn as nn
import torch.nn.functional as F

# From your existing codebase (reused as-is)
from hnet.modules.dc import RoutingModule, ChunkLayer
from hnet.modules.isotropic import Isotropic
from hnet.models.config_hnet import HNetConfig, SSMConfig, AttnConfig

# -------------------- Helpers --------------------
def create_isotropic_encoder(d_model, arch="m", height=4, ssm_cfg=None, attn_cfg=None, device=None, dtype=None):
    """Create a simplified Isotropic encoder; ssm_cfg/attn_cfg may override the defaults below."""
    factory_kwargs = {"device": device, "dtype": dtype}

    # Build an HNetConfig; the list-valued fields need at least one element each.
    ssm_kwargs = dict(d_conv=4, expand=2, d_state=128, chunk_size=256)
    ssm_kwargs.update(ssm_cfg or {})
    attn_kwargs = dict(num_heads=[8], rotary_emb_dim=[0], window_size=[-1])
    attn_kwargs.update(attn_cfg or {})

    config = HNetConfig(
        arch_layout=[f"{arch}{height}"],
        d_model=[d_model],
        d_intermediate=[d_model * 2],
        ssm_cfg=SSMConfig(**ssm_kwargs),
        attn_cfg=AttnConfig(**attn_kwargs),
    )

    return Isotropic(
        config=config,
        pos_idx=0,
        stage_idx=0,
        **factory_kwargs,
    )

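# Example (a minimal sketch; values are illustrative): a 4-block Mamba2 encoder at
# width 128, overriding only the SSM state size while keeping the other defaults.
#   enc = create_isotropic_encoder(d_model=128, arch="m", height=4, ssm_cfg=dict(d_state=64))
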
def ratio_loss(boundary_mask: torch.Tensor, boundary_prob: torch.Tensor, target_N: int) -> torch.Tensor:
    """H-Net-style ratio loss: encourages an average compression factor of roughly target_N."""
    F_act = boundary_mask.float().mean(dim=1)   # (B,) fraction of positions selected as boundaries
    G_prob = boundary_prob[..., 1].mean(dim=1)  # (B,) mean predicted boundary probability
    N = float(target_N)
    # L = N/(N-1) * ((N-1)*F*G + (1-F)*(1-G))
    loss = N / (N - 1.0) * ((N - 1.0) * F_act * G_prob + (1.0 - F_act) * (1.0 - G_prob))
    return loss.mean()
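
# Worked check of the formula above: taking G ≈ F, the per-sample loss becomes
# N/(N-1) * ((N-1)*F^2 + (1-F)^2), which is minimized at F = 1/N with value 1.0.
# So target_N = 4 pushes roughly one position in four to be selected as a boundary.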

def masked_mean(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Mean-pool over valid (unmasked) time steps."""
    mask_f = mask.float().unsqueeze(-1)        # (B, L, 1)
    s = (x * mask_f).sum(dim=1)                # (B, D)
    denom = mask_f.sum(dim=1).clamp_min(1.0)
    return s / denom

# -------------------- Multi-stage encoders (pyramid): Mamba2 + routed downsampling per stage; a single main network at the top --------------------
class PyramidEncoders_NoDechunk(nn.Module):
    """
    Hierarchy (encoders compress stage by stage; the main network sits only on the final stage):
      Input x0: (B, L0, 1)
      - Linear projection -> D0
      For s = 0..S-1:
        E_s (Mamba2, D_s) -> h_s (B, L_s, D_s)
        Routing + downsampling -> x_{s+1} (B, L_{s+1}, D_s), mask_{s+1}
        Width expansion D_s -> D_{s+1} (concatenate a shared learned vector)
      Final x_S: (B, L_S, D_S) is fed to a single main network M (Transformer/Mamba)
      Cross-scale fusion (no dechunking): fuse pooled_enc0 from E^0 with pooled_main from the main network
    """
    def __init__(
        self,
        d_models: List[int],                    # [D0, D1, ..., D_S], monotonically non-decreasing
        encoder_cfg_per_stage: List[dict],      # S encoder configs (arch must be 'm'/'M')
        main_cfg: dict,                         # single main-network config (runs on the most compressed sequence)
        fusion_dropout: float = 0.1,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
    ):
        super().__init__()
        factory_kwargs = {"device": device, "dtype": dtype}

        assert len(d_models) >= 1
        S = len(d_models) - 1
        assert S == len(encoder_cfg_per_stage), "number of stages must match the number of encoder configs"
        for i in range(S):
            assert d_models[i + 1] >= d_models[i], "require D_s <= D_{s+1} (widths must not decrease)"
            assert encoder_cfg_per_stage[i].get("arch", "m") in ("m", "M"), "encoders must be Mamba2"

        self.S = S
        self.d_models = d_models

        # Project the scalar input up to D0
        self.input_proj = nn.Linear(1, d_models[0], **factory_kwargs)

        # Per-stage encoder + routing + downsampling + width-expansion parameters
        self.encoders = nn.ModuleList()
        self.routers = nn.ModuleList()
        self.chunks = nn.ModuleList()
        self.pad_vectors = nn.ParameterList()
        for s in range(S):
            self.encoders.append(
                create_isotropic_encoder(
                    d_model=d_models[s],
                    **{k: v for k, v in encoder_cfg_per_stage[s].items() if k != "d_model"},
                    **factory_kwargs
                )
            )
            self.routers.append(RoutingModule(d_models[s], **factory_kwargs))
            self.chunks.append(ChunkLayer())
            delta = d_models[s + 1] - d_models[s]
            self.pad_vectors.append(
                nn.Parameter(torch.zeros(delta, **factory_kwargs)) if delta > 0
                else nn.Parameter(torch.empty(0, **factory_kwargs))
            )

        # The single, final main network: runs at width D_S on length L_S
        self.main_network = create_isotropic_encoder(
            d_model=d_models[-1],
            **{k: v for k, v in main_cfg.items() if k != "d_model"},
            **factory_kwargs
        )

        # Cross-scale fusion: project pooled_enc0 (D0) to D_S and fuse with pooled_main (D_S) -> D_S
        self.proj_enc0_to_DS = nn.Linear(d_models[0], d_models[-1], **factory_kwargs)
        self.fusion_head = nn.Sequential(
            nn.Linear(d_models[-1] + d_models[-1], d_models[-1], **factory_kwargs),
            nn.GELU(),
            nn.Dropout(fusion_dropout),
            nn.Linear(d_models[-1], d_models[-1], **factory_kwargs),
        )

    def _expand_width(self, x: torch.Tensor, pad_vec: nn.Parameter) -> torch.Tensor:
        # SpaceByte-style width expansion: append a learned shared vector at every position.
        if pad_vec.numel() == 0:
            return x
        lead_shape = x.shape[:-1]
        return torch.cat([x, pad_vec.expand(*lead_shape, -1)], dim=-1)

    def forward(self, x_scalar: torch.Tensor, mask: Optional[torch.Tensor] = None, return_seq: bool = False):
        """
        x_scalar: (B, L) or (B, L, 1)
        mask: (B, L) bool
        Returns:
            fused_vec: (B, D_S)
            debug: optional per-stage sequences
            aux: per-stage routing info (for the ratio loss)
        """
        if x_scalar.dim() == 2:
            x_scalar = x_scalar.unsqueeze(-1)  # (B, L, 1)
        B, L, _ = x_scalar.shape
        device = x_scalar.device
        if mask is None:
            mask = torch.ones(B, L, dtype=torch.bool, device=device)

        # Initial projection up to D0
        x = self.input_proj(x_scalar)  # (B, L0, D0)
        cur_mask = mask

        pooled_enc0 = None
        aux_per_stage = []
        seq_debug = [] if return_seq else None

        # Per stage: Encoder (Mamba2) -> Routing -> Chunk -> Expand width
        for s in range(self.S):
            # Fine-grained encoding of the not-yet-compressed sequence
            h_enc = self.encoders[s](x, mask=cur_mask)  # (B, L_s, D_s)

            if s == 0:
                pooled_enc0 = masked_mean(h_enc, cur_mask)  # (B, D0)

            # Routing + downsampling (yields a shorter sequence)
            bpred = self.routers[s](h_enc, mask=cur_mask)
            x_next, _, _, mask_next = self.chunks[s](h_enc, bpred.boundary_mask, mask=cur_mask)  # (B, L_{s+1}, D_s)

            # Expand width D_s -> D_{s+1}
            x_next = self._expand_width(x_next, self.pad_vectors[s])  # (B, L_{s+1}, D_{s+1})

            # Advance to the next stage
            x, cur_mask = x_next, mask_next

            aux_per_stage.append({
                "boundary_mask": bpred.boundary_mask,
                "boundary_prob": bpred.boundary_prob,
                "selected_probs": bpred.selected_probs,
            })
            if return_seq:
                seq_debug.append({"stage": s, "seq": x, "mask": cur_mask})

        # Now x: (B, L_S, D_S), cur_mask: (B, L_S)
        # The single main network runs on the most compressed sequence
        h_main = self.main_network(x, mask=cur_mask)  # (B, L_S, D_S)

        # Pool the main-network output
        if cur_mask is None:
            pooled_main = h_main.mean(dim=1)  # (B, D_S)
        else:
            pooled_main = (h_main * cur_mask.float().unsqueeze(-1)).sum(dim=1) / \
                cur_mask.float().sum(dim=1, keepdim=True).clamp_min(1.0)

        # Cross-scale fusion: global pooling of E^0 with the main-network pooling
        pooled_enc0_in_DS = self.proj_enc0_to_DS(pooled_enc0)  # (B, D_S)
        fused = torch.cat([pooled_enc0_in_DS, pooled_main], dim=-1)  # (B, 2*D_S)
        fused = self.fusion_head(fused)  # (B, D_S)

        aux = {"per_stage": aux_per_stage}
        if return_seq:
            return fused, {"stages": seq_debug, "main_seq": h_main, "main_mask": cur_mask}, aux
        else:
            return fused, None, aux

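# Shape sketch with the example configuration below (d_models = [128, 256, 512],
# target compression ~4x per stage; compressed lengths are data-dependent, so the
# numbers are only indicative):
#   (B, 1024, 1) --input_proj--> (B, 1024, 128)
#   stage 0: encode (B, 1024, 128) --route/chunk--> (B, ~256, 128) --expand--> (B, ~256, 256)
#   stage 1: encode (B, ~256, 256) --route/chunk--> (B, ~64, 256)  --expand--> (B, ~64, 512)
#   main network on (B, ~64, 512), pooled to (B, 512), fused with the stage-0 pooled vector.
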
# -------------------- Top level: multi-channel fusion + classification head (single main network) --------------------
@dataclass
class HierEncodersSingleMainConfig:
    num_channels: int
    d_models: List[int]                        # [D0, D1, ..., D_S], monotonically non-decreasing
    num_classes: int
    encoder_cfg_per_stage: List[dict]          # S encoder configs (all Mamba2, height ~= 4)
    main_cfg: dict                             # single main-network config (Transformer or Mamba2); d_model is taken from D_S
    target_compression_N_per_stage: List[int]
    share_channel: bool = True
    fusion_across_channels: Literal["mean", "concat"] = "mean"
    dropout: float = 0.1

class HierEncodersSingleMainClassifier(nn.Module):
    def __init__(self, cfg: HierEncodersSingleMainConfig, dtype=None, device=None):
        super().__init__()
        self.cfg = cfg
        factory_kwargs = {"dtype": dtype, "device": device}

        S = len(cfg.d_models) - 1
        assert S == len(cfg.encoder_cfg_per_stage) == len(cfg.target_compression_N_per_stage), "stage counts do not match"

        if cfg.share_channel:
            self.channel_encoder = PyramidEncoders_NoDechunk(
                d_models=cfg.d_models,
                encoder_cfg_per_stage=cfg.encoder_cfg_per_stage,
                main_cfg=cfg.main_cfg,
                **factory_kwargs,
            )
        else:
            self.channel_encoder = nn.ModuleList([
                PyramidEncoders_NoDechunk(
                    d_models=cfg.d_models,
                    encoder_cfg_per_stage=cfg.encoder_cfg_per_stage,
                    main_cfg=cfg.main_cfg,
                    **factory_kwargs,
                )
                for _ in range(cfg.num_channels)
            ])

        fusion_dim = (cfg.num_channels * cfg.d_models[-1]) if cfg.fusion_across_channels == "concat" \
            else cfg.d_models[-1]
        self.dropout = nn.Dropout(cfg.dropout)
        self.head = nn.Linear(fusion_dim, cfg.num_classes, **factory_kwargs)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None, return_seq: bool = False):
        """
        x: (B, L, N) multi-channel input
        mask: (B, L) temporal mask
        """
        B, L, N = x.shape
        assert N == self.cfg.num_channels

        channel_vecs: List[torch.Tensor] = []
        ratio_losses = []
        seq_dbg_all = [] if return_seq else None

        for c in range(N):
            x_c = x[..., c]  # (B, L)
            if self.cfg.share_channel:
                vec, seq_dbg, aux = self.channel_encoder(x_c, mask=mask, return_seq=return_seq)
            else:
                vec, seq_dbg, aux = self.channel_encoder[c](x_c, mask=mask, return_seq=return_seq)

            # Accumulate the ratio loss (one term per encoder stage)
            total_rl = 0.0
            for s, aux_s in enumerate(aux["per_stage"]):
                rl = ratio_loss(aux_s["boundary_mask"], aux_s["boundary_prob"], self.cfg.target_compression_N_per_stage[s])
                total_rl = total_rl + rl
            ratio_losses.append(total_rl)

            channel_vecs.append(vec)
            if return_seq:
                seq_dbg_all.append(seq_dbg)

        if self.cfg.fusion_across_channels == "concat":
            fused = torch.cat(channel_vecs, dim=-1)  # (B, N*D_S)
        else:
            fused = torch.stack(channel_vecs, dim=1).mean(dim=1)  # (B, D_S)

        fused = self.dropout(fused)
        logits = self.head(fused)

        aux_all = {"ratio_loss": torch.stack(ratio_losses).mean()}
        if return_seq:
            return logits, seq_dbg_all, aux_all
        else:
            return logits, None, aux_all

# -------------------- Usage example --------------------
if __name__ == "__main__":
    """
    What this demonstrates:
    - Extra depth only adds encoders (Mamba2 + dynamic chunking per stage); there is a single final main network
    - Sequence length shrinks stage by stage (decided by DC) while d_model grows monotonically (SpaceByte-style shared-vector concatenation)
    - No dechunking; cross-scale fusion uses the global pooling of E^0 plus the pooling of the final main network
    """
    B, L, N = 8, 1024, 6
    num_classes = 7
    d_models = [128, 256, 512]  # D0 <= D1 <= D2

    encoder_cfg_per_stage = [
        dict(arch="m", height=4, ssm_cfg=dict(), attn_cfg=dict()),  # stage 0 encoder (Mamba2)
        dict(arch="m", height=4, ssm_cfg=dict(), attn_cfg=dict()),  # stage 1 encoder (Mamba2)
    ]
    main_cfg = dict(
        arch="T", height=12, ssm_cfg=dict(), attn_cfg=dict(num_heads=[8])  # final main network (heavier)
    )
    target_compression_N_per_stage = [4, 4]

    cfg = HierEncodersSingleMainConfig(
        num_channels=N,
        d_models=d_models,
        num_classes=num_classes,
        encoder_cfg_per_stage=encoder_cfg_per_stage,
        main_cfg=main_cfg,
        target_compression_N_per_stage=target_compression_N_per_stage,
        share_channel=True,
        fusion_across_channels="mean",
        dropout=0.1,
    )

    model = HierEncodersSingleMainClassifier(cfg).cuda().train()
    x = torch.randn(B, L, N, device="cuda")
    mask = torch.ones(B, L, dtype=torch.bool, device="cuda")

    logits, _, aux = model(x, mask=mask, return_seq=False)
    y = torch.randint(0, num_classes, (B,), device="cuda")
    cls_loss = F.cross_entropy(logits, y)
    ratio_reg = 0.03 * aux["ratio_loss"]
    loss = cls_loss + ratio_reg
    loss.backward()
    print("logits:", logits.shape, "loss:", float(loss))