# TSlib/models/DC_hnet.py
from dataclasses import dataclass
from typing import Optional, Literal, List
import torch
import torch.nn as nn
import torch.nn.functional as F
# From the existing H-Net codebase (used as-is)
from hnet.modules.dc import RoutingModule, ChunkLayer
from hnet.modules.isotropic import Isotropic
from hnet.models.config_hnet import HNetConfig, SSMConfig, AttnConfig
# -------------------- Helpers --------------------
def create_isotropic_encoder(d_model, arch="m", height=4, ssm_cfg=None, attn_cfg=None, device=None, dtype=None):
    """Create a simplified single-stage Isotropic encoder."""
    factory_kwargs = {"device": device, "dtype": dtype}
    ssm_cfg = ssm_cfg or {}
    attn_cfg = attn_cfg or {}
    # Build an HNetConfig; every list-valued field must have at least one element.
    config = HNetConfig(
        arch_layout=[f"{arch}{height}"],
        d_model=[d_model],
        d_intermediate=[d_model * 2],
        ssm_cfg=SSMConfig(
            d_conv=ssm_cfg.get("d_conv", 4),
            expand=ssm_cfg.get("expand", 2),
            d_state=ssm_cfg.get("d_state", 128),
            chunk_size=ssm_cfg.get("chunk_size", 256),
        ),
        attn_cfg=AttnConfig(
            num_heads=[attn_cfg.get("num_heads", 8)],            # at least one element required
            rotary_emb_dim=[attn_cfg.get("rotary_emb_dim", 0)],  # at least one element required
            window_size=[attn_cfg.get("window_size", -1)],       # at least one element required
        ),
    )
    return Isotropic(
        config=config,
        pos_idx=0,
        stage_idx=0,
        **factory_kwargs,
    )
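# Illustrative usage of create_isotropic_encoder (a sketch; assumes the `hnet` package is
# importable and that Isotropic's forward accepts `mask=`, as it is called further below):
#   enc = create_isotropic_encoder(d_model=128, arch="m", height=4)
#   h = enc(torch.randn(2, 64, 128), mask=torch.ones(2, 64, dtype=torch.bool))  # -> (2, 64, 128)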
def ratio_loss(boundary_mask: torch.Tensor, boundary_prob: torch.Tensor, target_N: int) -> torch.Tensor:
    """H-Net-style ratio loss pushing the router toward a target compression ratio of 1/N."""
    F_act = boundary_mask.float().mean(dim=1)       # (B,) fraction of positions marked as boundaries
    G_prob = boundary_prob[..., 1].mean(dim=1)      # (B,) mean predicted boundary probability
    N = float(target_N)
    # L = N/(N-1) * ((N-1)*F*G + (1-F)*(1-G)); minimized (value 1.0) when F = G = 1/N.
    loss = N / (N - 1.0) * (((N - 1.0) * F_act * G_prob) + (1.0 - F_act) * (1.0 - G_prob))
    return loss.mean()
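# Sanity check for the formulation above (illustrative; values assume target_N = 4):
# an ideal router marks exactly 1/N of the positions as boundaries and predicts boundary
# probability 1/N everywhere, so F = G = 0.25 and the loss sits at its minimum of 1.0:
#   bm = torch.zeros(1, 8, dtype=torch.bool); bm[0, ::4] = True                      # F = 0.25
#   bp = torch.stack([torch.full((1, 8), 0.75), torch.full((1, 8), 0.25)], dim=-1)   # G = 0.25
#   ratio_loss(bm, bp, target_N=4)                                                   # -> 1.0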
def masked_mean(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    mask_f = mask.float().unsqueeze(-1)             # (B, L, 1)
    s = (x * mask_f).sum(dim=1)                     # (B, D)
    denom = mask_f.sum(dim=1).clamp_min(1.0)
    return s / denom
# -------------------- Multi-stage encoder pyramid (Mamba2 per stage + routed downsampling); only the final stage feeds the main network --------------------
class PyramidEncoders_NoDechunk(nn.Module):
    """
    Hierarchy (only the encoders compress stage by stage; the main network runs on the final stage only):
      Input x0: (B, L0, 1)
        - linear projection up to D0
      For s = 0..S-1:
        E_s (Mamba2, D_s)                 -> h_s     (B, L_s, D_s)
        routing + downsampling            -> x_{s+1} (B, L_{s+1}, D_s), mask_{s+1}
        width expansion D_s -> D_{s+1}    (concatenate a shared learned vector)
      Final x_S: (B, L_S, D_S) is fed to a single main network M (Transformer/Mamba).
      Cross-scale fusion (no dechunking): fuse pooled_enc0 from E^0 with pooled_main from the main network.
    """
    def __init__(
        self,
        d_models: List[int],                    # [D0, D1, ..., D_S], monotonically non-decreasing
        encoder_cfg_per_stage: List[dict],      # S encoder configs; arch must be 'm'/'M' (Mamba2)
        main_cfg: dict,                         # single main-network config (runs on the most compressed sequence)
        fusion_dropout: float = 0.1,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
    ):
        super().__init__()
        factory_kwargs = {"device": device, "dtype": dtype}
        assert len(d_models) >= 1
        S = len(d_models) - 1
        assert S == len(encoder_cfg_per_stage), "number of stages must equal number of encoder configs"
        for i in range(S):
            assert d_models[i + 1] >= d_models[i], "requires D_s <= D_{s+1} (widths must not shrink)"
            assert encoder_cfg_per_stage[i].get("arch", "m") in ("m", "M"), "encoders must be Mamba2"
        self.S = S
        self.d_models = d_models
        # Project the scalar input up to D0.
        self.input_proj = nn.Linear(1, d_models[0], **factory_kwargs)
        # Per-stage encoder + router + chunker + width-expansion parameters.
        self.encoders = nn.ModuleList()
        self.routers = nn.ModuleList()
        self.chunks = nn.ModuleList()
        self.pad_vectors = nn.ParameterList()
        for s in range(S):
            self.encoders.append(
                create_isotropic_encoder(
                    d_model=d_models[s],
                    **{k: v for k, v in encoder_cfg_per_stage[s].items() if k != "d_model"},
                    **factory_kwargs,
                )
            )
            self.routers.append(RoutingModule(d_models[s], **factory_kwargs))
            self.chunks.append(ChunkLayer())
            delta = d_models[s + 1] - d_models[s]
            self.pad_vectors.append(
                nn.Parameter(torch.zeros(delta, **factory_kwargs)) if delta > 0
                else nn.Parameter(torch.empty(0, **factory_kwargs))
            )
        # The single main network: runs at width D_S on the length-L_S sequence.
        self.main_network = create_isotropic_encoder(
            d_model=d_models[-1],
            **{k: v for k, v in main_cfg.items() if k != "d_model"},
            **factory_kwargs,
        )
        # Cross-scale fusion: project pooled_enc0 (D0) to D_S and fuse with pooled_main (D_S) -> D_S.
        self.proj_enc0_to_DS = nn.Linear(d_models[0], d_models[-1], **factory_kwargs)
        self.fusion_head = nn.Sequential(
            nn.Linear(d_models[-1] + d_models[-1], d_models[-1], **factory_kwargs),
            nn.GELU(),
            nn.Dropout(fusion_dropout),
            nn.Linear(d_models[-1], d_models[-1], **factory_kwargs),
        )

    def _expand_width(self, x: torch.Tensor, pad_vec: nn.Parameter) -> torch.Tensor:
        if pad_vec.numel() == 0:
            return x
        lead = x.shape[:-1]
        return torch.cat([x, pad_vec.expand(*lead, -1)], dim=-1)
    def forward(self, x_scalar: torch.Tensor, mask: Optional[torch.Tensor] = None, return_seq: bool = False):
        """
        x_scalar: (B, L) or (B, L, 1)
        mask: (B, L) bool
        Returns:
            fused_vec: (B, D_S)
            debug: optional per-stage sequences
            aux: per-stage routing info for the ratio loss
        """
        if x_scalar.dim() == 2:
            x_scalar = x_scalar.unsqueeze(-1)                           # (B, L, 1)
        B, L, _ = x_scalar.shape
        device = x_scalar.device
        if mask is None:
            mask = torch.ones(B, L, dtype=torch.bool, device=device)
        # Initial projection up to D0.
        x = self.input_proj(x_scalar)                                   # (B, L0, D0)
        cur_mask = mask
        pooled_enc0 = None
        aux_per_stage = []
        seq_debug = [] if return_seq else None
        # Per stage: encoder (Mamba2) -> routing -> chunk -> width expansion.
        for s in range(self.S):
            # Fine-grained encoding of the not-yet-compressed sequence.
            h_enc = self.encoders[s](x, mask=cur_mask)                  # (B, L_s, D_s)
            if s == 0:
                pooled_enc0 = masked_mean(h_enc, cur_mask)              # (B, D0)
            # Routing + downsampling (produces a shorter sequence).
            bpred = self.routers[s](h_enc, mask=cur_mask)
            x_next, _, _, mask_next = self.chunks[s](h_enc, bpred.boundary_mask, mask=cur_mask)  # (B, L_{s+1}, D_s)
            # Expand width D_s -> D_{s+1}.
            x_next = self._expand_width(x_next, self.pad_vectors[s])    # (B, L_{s+1}, D_{s+1})
            # Advance to the next stage.
            x, cur_mask = x_next, mask_next
            aux_per_stage.append({
                "boundary_mask": bpred.boundary_mask,
                "boundary_prob": bpred.boundary_prob,
                "selected_probs": bpred.selected_probs,
            })
            if return_seq:
                seq_debug.append({"stage": s, "seq": x, "mask": cur_mask})
        # Now x: (B, L_S, D_S), cur_mask: (B, L_S).
        # The single main network runs on the most compressed sequence.
        h_main = self.main_network(x, mask=cur_mask)                    # (B, L_S, D_S)
        # Pool the main-network output.
        if cur_mask is None:
            pooled_main = h_main.mean(dim=1)                            # (B, D_S)
        else:
            pooled_main = (h_main * cur_mask.float().unsqueeze(-1)).sum(dim=1) / \
                cur_mask.float().sum(dim=1, keepdim=True).clamp_min(1.0)
        # Cross-scale fusion: global pooling of E^0 with the main-network pooling.
        pooled_enc0_in_DS = self.proj_enc0_to_DS(pooled_enc0)           # (B, D_S)
        fused = torch.cat([pooled_enc0_in_DS, pooled_main], dim=-1)     # (B, 2*D_S)
        fused = self.fusion_head(fused)                                 # (B, D_S)
        aux = {"per_stage": aux_per_stage}
        if return_seq:
            return fused, {"stages": seq_debug, "main_seq": h_main, "main_mask": cur_mask}, aux
        return fused, None, aux
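# Shape walkthrough for PyramidEncoders_NoDechunk (illustrative; assumes d_models = [128, 256, 512],
# L0 = 1024, and routers that roughly hit their 4x compression targets; actual lengths are
# data-dependent because boundaries are chosen dynamically):
#   stage 0: encode (B, 1024, 128) -> route/chunk -> (B, ~256, 128) -> widen -> (B, ~256, 256)
#   stage 1: encode (B, ~256, 256) -> route/chunk -> (B, ~64, 256)  -> widen -> (B, ~64, 512)
#   main:    (B, ~64, 512) -> pooled_main (B, 512); fused with pooled_enc0 -> (B, 512)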
# -------------------- Top level: multi-channel fusion + classification head (single main network) --------------------
@dataclass
class HierEncodersSingleMainConfig:
    num_channels: int
    d_models: List[int]                         # [D0, D1, ..., D_S], monotonically non-decreasing
    num_classes: int
    encoder_cfg_per_stage: List[dict]           # S encoder configs, all Mamba2, height ≈ 4
    main_cfg: dict                              # single main-network config (Transformer or Mamba2); d_model is taken from D_S
    target_compression_N_per_stage: List[int]
    share_channel: bool = True
    fusion_across_channels: Literal["mean", "concat"] = "mean"
    dropout: float = 0.1
class HierEncodersSingleMainClassifier(nn.Module):
    def __init__(self, cfg: HierEncodersSingleMainConfig, dtype=None, device=None):
        super().__init__()
        self.cfg = cfg
        factory_kwargs = {"dtype": dtype, "device": device}
        S = len(cfg.d_models) - 1
        assert S == len(cfg.encoder_cfg_per_stage) == len(cfg.target_compression_N_per_stage), "stage counts do not match"
        if cfg.share_channel:
            self.channel_encoder = PyramidEncoders_NoDechunk(
                d_models=cfg.d_models,
                encoder_cfg_per_stage=cfg.encoder_cfg_per_stage,
                main_cfg=cfg.main_cfg,
                **factory_kwargs,
            )
        else:
            self.channel_encoder = nn.ModuleList([
                PyramidEncoders_NoDechunk(
                    d_models=cfg.d_models,
                    encoder_cfg_per_stage=cfg.encoder_cfg_per_stage,
                    main_cfg=cfg.main_cfg,
                    **factory_kwargs,
                )
                for _ in range(cfg.num_channels)
            ])
        fusion_dim = (cfg.num_channels * cfg.d_models[-1]) if cfg.fusion_across_channels == "concat" \
            else cfg.d_models[-1]
        self.dropout = nn.Dropout(cfg.dropout)
        self.head = nn.Linear(fusion_dim, cfg.num_classes, **factory_kwargs)
    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None, return_seq: bool = False):
        """
        x: (B, L, N) multi-channel input
        mask: (B, L) temporal mask
        """
        B, L, N = x.shape
        assert N == self.cfg.num_channels
        channel_vecs: List[torch.Tensor] = []
        ratio_losses = []
        seq_dbg_all = [] if return_seq else None
        for c in range(N):
            x_c = x[..., c]                                             # (B, L)
            if self.cfg.share_channel:
                vec, seq_dbg, aux = self.channel_encoder(x_c, mask=mask, return_seq=return_seq)
            else:
                vec, seq_dbg, aux = self.channel_encoder[c](x_c, mask=mask, return_seq=return_seq)
            # Accumulate the ratio loss: one term per encoder stage.
            total_rl = 0.0
            for s, aux_s in enumerate(aux["per_stage"]):
                rl = ratio_loss(aux_s["boundary_mask"], aux_s["boundary_prob"], self.cfg.target_compression_N_per_stage[s])
                total_rl = total_rl + rl
            ratio_losses.append(total_rl)
            channel_vecs.append(vec)
            if return_seq:
                seq_dbg_all.append(seq_dbg)
        if self.cfg.fusion_across_channels == "concat":
            fused = torch.cat(channel_vecs, dim=-1)                     # (B, N*D_S)
        else:
            fused = torch.stack(channel_vecs, dim=1).mean(dim=1)        # (B, D_S)
        fused = self.dropout(fused)
        logits = self.head(fused)
        aux_all = {"ratio_loss": torch.stack(ratio_losses).mean()}
        if return_seq:
            return logits, seq_dbg_all, aux_all
        return logits, None, aux_all
# -------------------- Usage example --------------------
if __name__ == "__main__":
    """
    Design requirements satisfied:
    - Extra depth only adds encoder stages (each Mamba2 + dynamic chunking); there is a single final main network.
    - Sequence length shrinks stage by stage (decided by DC); channel width d_model grows monotonically (SpaceByte-style shared-vector concatenation).
    - No dechunking: cross-scale fusion uses the global pooling of E^0 plus the final main-network pooling.
    """
    B, L, N = 8, 1024, 6
    num_classes = 7
    d_models = [128, 256, 512]                  # D0 <= D1 <= D2
    encoder_cfg_per_stage = [
        dict(arch="m", height=4, ssm_cfg=dict(), attn_cfg=dict()),      # stage 0 encoder (Mamba2)
        dict(arch="m", height=4, ssm_cfg=dict(), attn_cfg=dict()),      # stage 1 encoder (Mamba2)
    ]
    main_cfg = dict(
        arch="T", height=12, ssm_cfg=dict(), attn_cfg=dict(num_heads=8)  # final main network (heavier)
    )
    target_compression_N_per_stage = [4, 4]
    cfg = HierEncodersSingleMainConfig(
        num_channels=N,
        d_models=d_models,
        num_classes=num_classes,
        encoder_cfg_per_stage=encoder_cfg_per_stage,
        main_cfg=main_cfg,
        target_compression_N_per_stage=target_compression_N_per_stage,
        share_channel=True,
        fusion_across_channels="mean",
        dropout=0.1,
    )
    model = HierEncodersSingleMainClassifier(cfg).cuda().train()
    x = torch.randn(B, L, N, device="cuda")
    mask = torch.ones(B, L, dtype=torch.bool, device="cuda")
    logits, _, aux = model(x, mask=mask, return_seq=False)
    y = torch.randint(0, num_classes, (B,), device="cuda")
    cls_loss = F.cross_entropy(logits, y)
    ratio_reg = 0.03 * aux["ratio_loss"]
    loss = cls_loss + ratio_reg
    loss.backward()
    print("logits:", logits.shape, "loss:", float(loss))