feat: add mamba and dynamic chunking related code and test code

528  models/DC_PatchTST.py  Normal file
@@ -0,0 +1,528 @@
import torch
from torch import nn
import torch.nn.functional as F
from layers.SelfAttention_Family import FullAttention, AttentionLayer

# Mamba2 is required as the outer encoder
from mamba_ssm.modules.mamba2 import Mamba2


# -------------------- Routing (cosine routing, consistent with the paper) --------------------
class RoutingModule(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        with torch.no_grad():
            nn.init.eye_(self.q_proj.weight)
            nn.init.eye_(self.k_proj.weight)
        self.q_proj.weight._no_reinit = True
        self.k_proj.weight._no_reinit = True

    def forward(self, x, mask=None):
        """
        x: (B, L, D)
        mask: (B, L) bool, True = valid
        Returns:
            boundary_prob: (B, L, 2)
            boundary_mask: (B, L) bool
            selected_probs: (B, L, 1)
        """
        B, L, D = x.shape
        q = F.normalize(self.q_proj(x[:, :-1]), dim=-1)   # (B, L-1, D)
        k = F.normalize(self.k_proj(x[:, 1:]), dim=-1)    # (B, L-1, D)
        cos_sim = (q * k).sum(dim=-1)                     # (B, L-1)
        p = torch.clamp((1 - cos_sim) / 2, 0.0, 1.0)      # (B, L-1)
        p = F.pad(p, (1, 0), value=1.0)                   # force the first position to be a boundary

        if mask is not None:
            p = p * mask.float()
            p[:, 0] = torch.where(mask[:, 0], torch.ones_like(p[:, 0]), p[:, 0])

        boundary_prob = torch.stack([1 - p, p], dim=-1)   # (B, L, 2)
        selected_idx = boundary_prob.argmax(dim=-1)
        boundary_mask = (selected_idx == 1)
        if mask is not None:
            boundary_mask = boundary_mask & mask
        selected_probs = boundary_prob.gather(-1, selected_idx.unsqueeze(-1))  # (B, L, 1)
        return boundary_prob, boundary_mask, selected_probs
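
# Illustrative shape check for RoutingModule (a minimal sketch; the helper name is ours and
# nothing below is called by the model). Cosine routing sets p_t = (1 - cos(q_{t-1}, k_t)) / 2,
# so similar neighbours get p near 0, dissimilar neighbours p near 1, and t = 0 is forced to 1.
def _routing_shape_check():
    router = RoutingModule(d_model=16)
    x = torch.randn(2, 32, 16)
    prob, bmask, sel = router(x)
    assert prob.shape == (2, 32, 2)
    assert bmask.shape == (2, 32) and bool(bmask[:, 0].all())  # first step is always a boundary
    assert sel.shape == (2, 32, 1)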

# -------------------- Select boundary positions and right-pad with zeros (nothing dropped, no repeated filling) --------------------
def select_and_right_pad(x, boundary_mask):
    """
    Memory-optimized version: fewer temporary tensors.
    x: (B, L, D), boundary_mask: (B, L) bool
    Returns:
        x_pad: (B, T_max, D)
        key_padding_mask: (B, T_max) bool, True = valid
        lengths: (B,)
    """
    B, L, D = x.shape
    device = x.device
    lengths = boundary_mask.sum(dim=1)  # (B,)
    T_max = int(lengths.max().item()) if lengths.max() > 0 else 1

    x_pad = x.new_zeros(B, T_max, D)
    key_padding_mask = torch.zeros(B, T_max, dtype=torch.bool, device=device)

    # pre-create the default index tensor to avoid rebuilding it inside the loop
    default_idx = torch.tensor([0], device=device)

    for b in range(B):
        mask_b = boundary_mask[b]
        if mask_b.any():
            idx = mask_b.nonzero(as_tuple=True)[0]  # cheaper form of nonzero
            t = idx.numel()
            x_pad[b, :t] = x[b, idx]
            key_padding_mask[b, :t] = True
        else:
            # fall back to the pre-created index tensor
            x_pad[b, 0] = x[b, default_idx]
            key_padding_mask[b, 0] = True

    return x_pad, key_padding_mask, lengths
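
# Worked example for select_and_right_pad (illustrative; not called anywhere): rows keep only
# their boundary positions and are padded on the right, so a row with fewer boundaries ends
# with False entries in key_padding_mask.
def _select_and_pad_example():
    x = torch.arange(12, dtype=torch.float32).view(1, 4, 3).repeat(2, 1, 1)  # (2, 4, 3)
    bmask = torch.tensor([[True, False, True, False],
                          [True, False, False, False]])
    x_pad, kpm, lengths = select_and_right_pad(x, bmask)
    assert x_pad.shape == (2, 2, 3) and lengths.tolist() == [2, 1]
    assert kpm.tolist() == [[True, True], [True, False]]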

# -------------------- Stacked Mamba2 (outer encoder) --------------------
class Mamba2Encoder(nn.Module):
    def __init__(self, d_model, depth=4, dropout=0.0):
        super().__init__()
        self.layers = nn.ModuleList([Mamba2(d_model=d_model) for _ in range(depth)])
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        x = self.norm(x)
        x = self.dropout(x)
        return x


# -------------------- Two-stage encoder + DC (variable length, no information dropped; compression constrained by a target ratio) --------------------
class DCEmbedding2StageVarLen(nn.Module):
    """
    - Stage 0: (B*nvars, L, 1) -> Linear(D0) -> Mamba2(D0) -> Routing -> select -> widen to D1
    - Stage 1: (B*nvars, L0_sel, D1) -> Mamba2(D1) -> Routing -> select
      (the stage-1 pass is currently commented out in forward, so only stage 0 compresses)
    Outputs:
        enc_out: (B*nvars, T_max, D1)
        key_padding_mask: (B*nvars, T_max)
        n_vars: int
        aux: dict (per-stage ratio losses and boundary information)
    """
    def __init__(self, d_model_out, d_model_stage0, depth_enc0=4, depth_enc1=4, dropout=0.0,
                 target_ratio0=0.25, target_ratio1=0.5):
        super().__init__()
        assert d_model_out >= d_model_stage0, "requires D0 <= D1"
        self.d0 = d_model_stage0
        self.d1 = d_model_out

        # scalar -> D0
        self.input_proj = nn.Linear(1, self.d0)

        # Stage 0
        self.enc0 = Mamba2Encoder(self.d0, depth=depth_enc0, dropout=dropout)
        self.router0 = RoutingModule(self.d0)
        delta = self.d1 - self.d0
        self.pad_vec = nn.Parameter(torch.zeros(delta)) if delta > 0 else None
        self.target_ratio0 = target_ratio0

        # Stage 1
        self.enc1 = Mamba2Encoder(self.d1, depth=depth_enc1, dropout=dropout)
        self.router1 = RoutingModule(self.d1)
        self.target_ratio1 = target_ratio1

    def _expand_width(self, x):
        if self.pad_vec is None:
            return x
        B, L, _ = x.shape
        return torch.cat([x, self.pad_vec.view(1, 1, -1).expand(B, L, -1)], dim=-1)

    @staticmethod
    def _ratio_loss(boundary_mask: torch.Tensor, boundary_prob: torch.Tensor, target_ratio: float) -> torch.Tensor:
        eps = 1e-6
        F_act = boundary_mask.float().mean(dim=1)   # (B,)
        G_prob = boundary_prob[..., 1].mean(dim=1)  # (B,)
        N = 1.0 / max(target_ratio, eps)
        loss = N / (N - 1.0 + eps) * (((N - 1.0) * F_act * G_prob) + (1.0 - F_act) * (1.0 - G_prob))
        return loss.mean()

    def forward(self, x):
        """
        x: (B, nvars, L)
        Memory-optimized version: intermediate tensors are freed as soon as possible.
        """
        B, nvars, L = x.shape
        x = x.reshape(B * nvars, L, 1)
        x = self.input_proj(x)  # (B*nvars, L, D0)

        # Stage 0
        h0 = self.enc0(x)  # (B*nvars, L, D0)
        p0, bm0, _ = self.router0(h0)
        h0_sel, mask0, len0 = select_and_right_pad(h0, bm0)  # (B*nvars, L0_max, D0)

        # free tensors that are no longer needed
        del h0

        h0_sel = self._expand_width(h0_sel)  # (B*nvars, L0_max, D1)

        # Stage 1 (currently disabled)
        # h1 = self.enc1(h0_sel)  # (B*nvars, L0_max, D1)
        # p1, bm1, _ = self.router1(h1)
        # bm1 = bm1 & mask0
        # h1_sel, mask1, len1 = select_and_right_pad(h1, bm1)  # (B*nvars, L1_max, D1)

        # free intermediate tensors
        # del h1, h0_sel

        # ratio loss for stage 0 (the aux tensors below are stored detached so no graph is kept)
        ratio_loss0 = self._ratio_loss(bm0, p0, target_ratio=self.target_ratio0)
        # ratio_loss1 = self._ratio_loss(bm1, p1, target_ratio=self.target_ratio1)

        # keep the aux dict small: only what is needed downstream
        aux = {
            "stage0": {"boundary_mask": bm0.detach(), "boundary_prob": p0.detach(), "lengths": len0.detach()},
            # "stage1": {"boundary_mask": bm1.detach(), "boundary_prob": p1.detach(), "lengths": len1.detach()},
            "ratio_loss0": ratio_loss0,
            # "ratio_loss1": ratio_loss1,
        }

        return h0_sel, mask0, nvars, aux
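
# Shape sketch for the DC embedding (illustrative sizes; assumes a CUDA build of mamba_ssm).
# With stage-1 routing disabled, compression happens once: T_max is the largest boundary count
# across the B*nvars series, and the features are widened from D0 to D1 before being returned.
def _dc_embedding_shape_check():
    dc = DCEmbedding2StageVarLen(d_model_out=128, d_model_stage0=64, depth_enc0=1, depth_enc1=1).cuda()
    x = torch.randn(2, 3, 96, device="cuda")  # (B, nvars, L)
    enc_out, key_padding_mask, n_vars, aux = dc(x)
    assert enc_out.shape[0] == 2 * 3 and enc_out.shape[-1] == 128
    assert key_padding_mask.shape == enc_out.shape[:2] and n_vars == 3
    assert "ratio_loss0" in aux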

# -------------------- Encoder/EncoderLayer (with key_padding_mask passthrough) --------------------
class EncoderLayerWithMask(nn.Module):
    """
    Same structure as the original EncoderLayer, but forward additionally takes key_padding_mask
    and forwards it to the AttentionLayer (which must accept this keyword).
    The FFN is a plain MLP, as in a standard Transformer.
    """
    def __init__(self, attention: AttentionLayer, d_model, d_ff, dropout=0.1, activation="gelu"):
        super().__init__()
        self.attention = attention
        self.dropout = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        if activation == "relu":
            act = nn.ReLU()
        elif activation == "gelu":
            act = nn.GELU()
        else:
            raise ValueError(f"Unsupported activation: {activation}")

        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            act,
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
        )

    def forward(self, x, attn_mask=None, tau=None, delta=None, key_padding_mask=None):
        # multi-head attention with key padding mask
        attn_out, attn = self.attention(
            x, x, x, attn_mask, tau=tau, delta=delta, key_padding_mask=key_padding_mask
        )
        x = x + self.dropout(attn_out)
        x = self.norm1(x)

        # FFN
        y = self.ffn(x)
        x = x + self.dropout(y)
        x = self.norm2(x)
        return x, attn


class EncoderWithMask(nn.Module):
    """
    Like the original Encoder, but forward supports key_padding_mask and passes it to the
    attention of every layer.
    """
    def __init__(self, attn_layers, norm_layer=None):
        super().__init__()
        self.attn_layers = nn.ModuleList(attn_layers)
        self.norm = norm_layer

    def forward(self, x, attn_mask=None, key_padding_mask=None):
        attns = []
        for attn_layer in self.attn_layers:
            x, attn = attn_layer(x, attn_mask=attn_mask, key_padding_mask=key_padding_mask)
            attns.append(attn)
        if self.norm is not None:
            x = self.norm(x)
        return x, attns


# -------------------- Gated attention pooling + task heads (independent of token count; information preserving) --------------------
def masked_softmax(logits: torch.Tensor, mask: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """
    logits: (..., T)
    mask: (..., T) bool, True = valid
    """
    neg_inf = torch.finfo(logits.dtype).min
    logits = logits.masked_fill(~mask, neg_inf)
    return torch.softmax(logits, dim=dim)
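
# Tiny illustration of masked_softmax (values are arbitrary): masked positions are filled with
# the dtype minimum before the softmax, so their weights vanish and the rest renormalize.
def _masked_softmax_example():
    logits = torch.tensor([[1.0, 2.0, 3.0]])
    mask = torch.tensor([[True, True, False]])
    w = masked_softmax(logits, mask)
    assert torch.allclose(w.sum(-1), torch.ones(1))
    assert w[0, 2].item() < 1e-6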

class GatedAttnAggregator(nn.Module):
    """
    Gated attention aggregator (learnable queries + masked softmax + value-side gating).
    Input: x: (B*, T, D), mask: (B*, T) bool
    Output: slots: (B*, R, D), where R is the (configurable) number of pooling slots.
    """
    def __init__(self, d_model: int, num_slots: int = 4, d_att: int = None, dropout: float = 0.1):
        super().__init__()
        self.d_model = d_model
        self.R = num_slots
        self.d_att = d_att or d_model

        # learnable queries (R of them)
        self.query = nn.Parameter(torch.randn(self.R, self.d_att) / (self.d_att ** 0.5))

        # linear projections
        self.key_proj = nn.Linear(d_model, self.d_att)
        self.val_proj = nn.Linear(d_model, d_model)

        # value-side gate (one scalar gate per token)
        self.gate = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.GELU(),
            nn.Linear(d_model // 2, 1),
            nn.Sigmoid()
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, x_bt_t_d: torch.Tensor, mask_bt: torch.Tensor) -> torch.Tensor:
        """
        x_bt_t_d: (B*, T, D)
        mask_bt: (B*, T) bool
        return: slots (B*, R, D)
        """
        BStar, T, D = x_bt_t_d.shape
        K = self.key_proj(x_bt_t_d)   # (B*, T, d_att)
        V = self.val_proj(x_bt_t_d)   # (B*, T, D)
        g = self.gate(x_bt_t_d)       # (B*, T, 1)
        Vg = V * g                    # gated values

        Q = self.query.unsqueeze(0).expand(BStar, -1, -1)  # (B*, R, d_att)

        logits = torch.matmul(Q, K.transpose(1, 2)) / (self.d_att ** 0.5)  # (B*, R, T)
        attn_mask = mask_bt.unsqueeze(1)  # (B*, 1, T)
        attn = masked_softmax(logits, attn_mask, dim=-1)
        attn = self.dropout(attn)

        slots = torch.matmul(attn, Vg)  # (B*, R, D)
        return slots
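
# Quick shape sketch for the aggregator (illustrative only): R learnable queries attend over the
# valid tokens, so the slot outputs do not depend on how much right-padding a row carries.
def _aggregator_shape_check():
    agg = GatedAttnAggregator(d_model=32, num_slots=4)
    x = torch.randn(5, 17, 32)
    mask = torch.ones(5, 17, dtype=torch.bool)
    mask[:, 10:] = False  # pretend the tail is padding
    assert agg(x, mask).shape == (5, 4, 32)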

class AttnPoolHeadForecast(nn.Module):
    """
    Forecasting head: gated attention pooling into R slots, then a projection to target_window (pred_len).
    Output: (B, pred_len, nvars)
    """
    def __init__(self, d_model: int, target_window: int, num_slots: int = 4, dropout: float = 0.1):
        super().__init__()
        self.agg = GatedAttnAggregator(d_model, num_slots=num_slots, dropout=dropout)
        self.proj = nn.Sequential(
            nn.LayerNorm(num_slots * d_model),
            nn.Linear(num_slots * d_model, target_window),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, enc_out_bt_t_d: torch.Tensor, key_padding_mask_bt: torch.Tensor, n_vars: int, B: int):
        slots = self.agg(enc_out_bt_t_d, key_padding_mask_bt)  # (B*, R, D)
        slots = slots.reshape(B, n_vars, -1)                   # (B, nvars, R*D)
        out = self.proj(self.dropout(slots))                   # (B, nvars, pred_len)
        return out.permute(0, 2, 1)                            # (B, pred_len, nvars)


class AttnPoolHeadSeq(nn.Module):
    """
    Sequence reconstruction head: gated attention pooling followed by a projection to seq_len.
    Output: (B, seq_len, nvars)
    """
    def __init__(self, d_model: int, target_window: int, num_slots: int = 4, dropout: float = 0.1):
        super().__init__()
        self.agg = GatedAttnAggregator(d_model, num_slots=num_slots, dropout=dropout)
        self.proj = nn.Sequential(
            nn.LayerNorm(num_slots * d_model),
            nn.Linear(num_slots * d_model, target_window),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, enc_out_bt_t_d: torch.Tensor, key_padding_mask_bt: torch.Tensor, n_vars: int, B: int):
        slots = self.agg(enc_out_bt_t_d, key_padding_mask_bt)  # (B*, R, D)
        slots = slots.reshape(B, n_vars, -1)                   # (B, nvars, R*D)
        out = self.proj(self.dropout(slots))                   # (B, nvars, seq_len)
        return out.permute(0, 2, 1)                            # (B, seq_len, nvars)


class AttnPoolHeadCls(nn.Module):
    """
    Classification head: gated attention pooling into R slots per variable, concatenation across
    all variables, then a linear classifier.
    Output: (B, num_class)
    """
    def __init__(self, d_model: int, n_vars: int, num_class: int, num_slots: int = 4, dropout: float = 0.1):
        super().__init__()
        self.agg = GatedAttnAggregator(d_model, num_slots=num_slots, dropout=dropout)
        self.dropout = nn.Dropout(dropout)
        self.proj = nn.Sequential(
            nn.LayerNorm(n_vars * num_slots * d_model),
            nn.Linear(n_vars * num_slots * d_model, num_class),
        )
        self.n_vars = n_vars
        self.num_slots = num_slots
        self.d_model = d_model

    def forward(self, enc_out_bt_t_d: torch.Tensor, key_padding_mask_bt: torch.Tensor, n_vars: int, B: int):
        slots = self.agg(enc_out_bt_t_d, key_padding_mask_bt)            # (B*, R, D)
        slots = slots.reshape(B, n_vars, self.num_slots * self.d_model)  # (B, nvars, R*D)
        flat = self.dropout(slots.reshape(B, -1))                        # (B, nvars*R*D)
        return self.proj(flat)


# -------------------- Main model: two-stage DC (ratio-controlled) + masked encoder + gated pooling heads --------------------
class Transpose(nn.Module):
    def __init__(self, *dims, contiguous=False):
        super().__init__()
        self.dims, self.contiguous = dims, contiguous

    def forward(self, x):
        return x.transpose(*self.dims).contiguous() if self.contiguous else x.transpose(*self.dims)


class Model(nn.Module):
    """
    PatchTST with DC and masked attention + gated heads:
    - a two-stage Mamba2 encoder + dynamic chunking (DC) replaces PatchEmbedding
    - DC uses ratio losses (target_ratio0/1) to control compression strength; as stages deepen,
      the sequence gets shorter and d_model grows (D0 -> D1)
    - the attention receives a key_padding_mask so padded positions are ignored
    - heads use gated attention pooling (independent of the token count, retaining more information)
    """

    def __init__(
        self, configs,
        d_model_stage0=None,   # D0; defaults to max(16, d_model // 2)
        depth_enc0=1, depth_enc1=1,
        target_ratio0=0.25,    # roughly 1/N0
        target_ratio1=0.5,     # roughly 1/N1
        agg_slots=4,           # number of gated-pooling slots
    ):
        super().__init__()
        self.task_name = configs.task_name
        self.seq_len = configs.seq_len
        self.pred_len = configs.pred_len
        self.enc_in = configs.enc_in

        # DC embedding
        D1 = configs.d_model
        D0 = d_model_stage0 if d_model_stage0 is not None else max(16, D1 // 2)
        assert D1 >= D0, "requires D0 <= D1"
        self.dc_embedding = DCEmbedding2StageVarLen(
            d_model_out=D1,
            d_model_stage0=D0,
            depth_enc0=depth_enc0,
            depth_enc1=depth_enc1,
            dropout=configs.dropout,
            target_ratio0=target_ratio0,
            target_ratio1=target_ratio1,
        )

        # masked encoder
        attn_layers = [
            EncoderLayerWithMask(
                AttentionLayer(
                    FullAttention(False, configs.factor, attention_dropout=configs.dropout, output_attention=False),
                    D1, configs.n_heads
                ),
                d_model=D1,
                d_ff=configs.d_ff,
                dropout=configs.dropout,
                activation=configs.activation
            ) for _ in range(configs.e_layers)
        ]
        self.encoder = EncoderWithMask(
            attn_layers,
            norm_layer=nn.Sequential(Transpose(1, 2), nn.BatchNorm1d(D1), Transpose(1, 2))
        )

        # gated pooling heads (independent of the token count)
        if self.task_name in ('long_term_forecast', 'short_term_forecast'):
            self.head = AttnPoolHeadForecast(D1, self.pred_len, num_slots=agg_slots, dropout=configs.dropout)
        elif self.task_name in ('imputation', 'anomaly_detection'):
            self.head = AttnPoolHeadSeq(D1, self.seq_len, num_slots=agg_slots, dropout=configs.dropout)
        elif self.task_name == 'classification':
            self.head_cls = AttnPoolHeadCls(D1, n_vars=self.enc_in, num_class=configs.num_class, num_slots=agg_slots, dropout=configs.dropout)

    # --------- normalization / de-normalization ---------
    def _pre_norm(self, x):
        means = x.mean(1, keepdim=True).detach()
        x = x - means
        stdev = torch.sqrt(torch.var(x, dim=1, keepdim=True, unbiased=False) + 1e-5)
        x = x / stdev
        return x, means, stdev

    def _denorm(self, y, means, stdev, length):
        return y * (stdev[:, 0, :].unsqueeze(1).repeat(1, length, 1)) + \
               (means[:, 0, :].unsqueeze(1).repeat(1, length, 1))

    # --------- DC + Transformer encoder (carrying key_padding_mask) ----------
    def _embed_and_encode(self, x_enc):
        """
        x_enc: (B, L, C)
        Returns:
            enc_out: (B*nvars, T_max, D1)
            n_vars: int
            key_padding_mask: (B*nvars, T_max)
            aux: dict
        """
        B, L, C = x_enc.shape
        x_vars = x_enc.permute(0, 2, 1)  # (B, nvars, L)
        enc_out, key_padding_mask, n_vars, aux = self.dc_embedding(x_vars)
        enc_out, _ = self.encoder(enc_out, attn_mask=None, key_padding_mask=key_padding_mask)
        return enc_out, n_vars, key_padding_mask, B, aux

    # --------- task-specific forward passes ---------
    def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
        x_enc, means, stdev = self._pre_norm(x_enc)
        enc_out, n_vars, key_padding_mask, B, aux = self._embed_and_encode(x_enc)
        dec_out = self.head(enc_out, key_padding_mask, n_vars, B)  # (B, pred_len, nvars)
        dec_out = self._denorm(dec_out, means, stdev, self.pred_len)
        return dec_out, aux

    def imputation(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask):
        means = torch.sum(x_enc, dim=1) / torch.sum(mask == 1, dim=1)
        means = means.unsqueeze(1).detach()
        x = x_enc - means
        x = x.masked_fill(mask == 0, 0)
        stdev = torch.sqrt(torch.sum(x * x, dim=1) / torch.sum(mask == 1, dim=1) + 1e-5)
        stdev = stdev.unsqueeze(1).detach()
        x = x / stdev

        enc_out, n_vars, key_padding_mask, B, aux = self._embed_and_encode(x)
        dec_out = self.head(enc_out, key_padding_mask, n_vars, B)  # (B, seq_len, nvars)
        dec_out = self._denorm(dec_out, means, stdev, self.seq_len)
        return dec_out, aux

    def anomaly_detection(self, x_enc):
        x_enc, means, stdev = self._pre_norm(x_enc)
        enc_out, n_vars, key_padding_mask, B, aux = self._embed_and_encode(x_enc)
        dec_out = self.head(enc_out, key_padding_mask, n_vars, B)  # (B, seq_len, nvars)
        dec_out = self._denorm(dec_out, means, stdev, self.seq_len)
        return dec_out, aux

    def classification(self, x_enc, x_mark_enc):
        x_enc, _, _ = self._pre_norm(x_enc)
        enc_out, n_vars, key_padding_mask, B, aux = self._embed_and_encode(x_enc)
        logits = self.head_cls(enc_out, key_padding_mask, n_vars, B)  # (B, num_class)
        return logits, aux

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
        if self.task_name in ('long_term_forecast', 'short_term_forecast'):
            dec_out, aux = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
            return dec_out[:, -self.pred_len:, :], aux  # [B, pred_len, D]; aux carries the ratio losses
        if self.task_name == 'imputation':
            dec_out, aux = self.imputation(x_enc, x_mark_enc, x_dec, x_mark_dec, mask)
            return dec_out, aux
        if self.task_name == 'anomaly_detection':
            dec_out, aux = self.anomaly_detection(x_enc)
            return dec_out, aux
        if self.task_name == 'classification':
            logits, aux = self.classification(x_enc, x_mark_enc)
            return logits, aux
        return None, None
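
# End-to-end smoke-test sketch for the forecasting path. Assumptions: a CUDA build of mamba_ssm,
# and an AttentionLayer/FullAttention implementation that accepts the key_padding_mask keyword
# forwarded by EncoderLayerWithMask; the SimpleNamespace below merely stands in for the usual
# argparse config object of the framework.
def _dc_patchtst_smoke_test():
    from types import SimpleNamespace
    configs = SimpleNamespace(
        task_name='long_term_forecast', seq_len=96, pred_len=24, enc_in=7,
        d_model=128, n_heads=4, e_layers=2, d_ff=256, dropout=0.1,
        factor=1, activation='gelu', num_class=None,
    )
    model = Model(configs).cuda()
    x = torch.randn(4, configs.seq_len, configs.enc_in, device='cuda')
    out, aux = model(x, None, None, None)
    assert out.shape == (4, configs.pred_len, configs.enc_in)
    print('ratio_loss0:', float(aux['ratio_loss0']))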

339  models/DC_hnet.py  Normal file
@@ -0,0 +1,339 @@
from dataclasses import dataclass
from typing import Optional, Literal, List

import torch
import torch.nn as nn
import torch.nn.functional as F

# from your codebase (usable as-is)
from hnet.modules.dc import RoutingModule, ChunkLayer
from hnet.modules.isotropic import Isotropic
from hnet.models.config_hnet import HNetConfig, SSMConfig, AttnConfig


# -------------------- Helpers --------------------
def create_isotropic_encoder(d_model, arch="m", height=4, device=None, dtype=None, **unused_cfg):
    """Create a simplified Isotropic encoder.

    Extra per-stage options (e.g. the ssm_cfg / attn_cfg entries passed through the stage dicts)
    are accepted but ignored: this helper builds its own SSM and attention configs below.
    """
    factory_kwargs = {"device": device, "dtype": dtype}

    # build an HNetConfig; every list field needs at least one element
    config = HNetConfig(
        arch_layout=[f"{arch}{height}"],
        d_model=[d_model],
        d_intermediate=[d_model * 2],
        ssm_cfg=SSMConfig(
            d_conv=4,
            expand=2,
            d_state=128,
            chunk_size=256
        ),
        attn_cfg=AttnConfig(
            num_heads=[8],        # at least one element
            rotary_emb_dim=[0],   # at least one element
            window_size=[-1]      # at least one element
        )
    )

    return Isotropic(
        config=config,
        pos_idx=0,
        stage_idx=0,
        **factory_kwargs
    )


def ratio_loss(boundary_mask: torch.Tensor, boundary_prob: torch.Tensor, target_N: int) -> torch.Tensor:
    F_act = boundary_mask.float().mean(dim=1)   # (B,)
    G_prob = boundary_prob[..., 1].mean(dim=1)  # (B,)
    N = float(target_N)
    loss = N / (N - 1.0) * (((N - 1.0) * F_act * G_prob) + (1.0 - F_act) * (1.0 - G_prob))
    return loss.mean()
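
# Numerical sanity sketch for ratio_loss (illustrative; the helper is not called elsewhere):
# with target_N = 4 a boundary rate of about 1/4 is cheap, while marking every position as a
# boundary (F = G = 1) is penalised.
def _ratio_loss_example():
    bmask = torch.zeros(1, 8, dtype=torch.bool)
    bmask[:, ::4] = True  # one boundary every 4 steps
    bprob = torch.stack([1 - bmask.float(), bmask.float()], dim=-1)
    sparse = ratio_loss(bmask, bprob, target_N=4)
    dense = ratio_loss(torch.ones_like(bmask),
                       torch.stack([torch.zeros(1, 8), torch.ones(1, 8)], dim=-1), target_N=4)
    assert sparse < dense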

def masked_mean(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    mask_f = mask.float().unsqueeze(-1)  # (B, L, 1)
    s = (x * mask_f).sum(dim=1)          # (B, D)
    denom = mask_f.sum(dim=1).clamp_min(1.0)
    return s / denom


# -------------------- Multi-stage encoder pyramid: Mamba2 + routed downsampling per stage, a single final main network --------------------
class PyramidEncoders_NoDechunk(nn.Module):
    """
    Hierarchy (encoder-only compression per stage; the main network exists only at the final stage):
        Input x0: (B, L0, 1)
        - linear projection -> D0
        For s = 0..S-1:
            Es (Mamba2, D_s) -> h_s (B, L_s, D_s)
            routing + downsampling -> x_{s+1} (B, L_{s+1}, D_s), mask_{s+1}
            width expansion D_s -> D_{s+1} (concatenating a shared vector)
        The final x_S: (B, L_S, D_S) is fed to the single main network M (Transformer/Mamba).
    Cross-scale fusion (no dechunking): fuse pooled_enc0 from E^0 with pooled_main from the main network.
    """
    def __init__(
        self,
        d_models: List[int],                # [D0, D1, ..., D_S], monotonically non-decreasing
        encoder_cfg_per_stage: List[dict],  # S encoder configs (arch must be 'm'/'M')
        main_cfg: dict,                     # config of the single main network (runs on the most compressed sequence)
        fusion_dropout: float = 0.1,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
    ):
        super().__init__()
        factory_kwargs = {"device": device, "dtype": dtype}

        assert len(d_models) >= 1
        S = len(d_models) - 1
        assert S == len(encoder_cfg_per_stage), "the number of stages must equal the number of encoder configs"
        for i in range(S):
            assert d_models[i + 1] >= d_models[i], "requires D_s <= D_{s+1} (widths must not decrease)"
            assert encoder_cfg_per_stage[i].get("arch", "m") in ("m", "M"), "encoders must be Mamba2"

        self.S = S
        self.d_models = d_models

        # project the input up to D0
        self.input_proj = nn.Linear(1, d_models[0], **factory_kwargs)

        # per-stage encoder + router + downsampling + width-expansion parameters
        self.encoders = nn.ModuleList()
        self.routers = nn.ModuleList()
        self.chunks = nn.ModuleList()
        self.pad_vectors = nn.ParameterList()
        for s in range(S):
            self.encoders.append(
                create_isotropic_encoder(
                    d_model=d_models[s],
                    **{k: v for k, v in encoder_cfg_per_stage[s].items() if k != "d_model"},
                    **factory_kwargs
                )
            )
            self.routers.append(RoutingModule(d_models[s], **factory_kwargs))
            self.chunks.append(ChunkLayer())
            delta = d_models[s + 1] - d_models[s]
            self.pad_vectors.append(nn.Parameter(torch.zeros(delta, **factory_kwargs)) if delta > 0 else nn.Parameter(torch.empty(0, **factory_kwargs)))

        # the single final main network: runs at width D_S on length L_S
        self.main_network = create_isotropic_encoder(
            d_model=d_models[-1],
            **{k: v for k, v in main_cfg.items() if k != "d_model"},
            **factory_kwargs
        )

        # cross-scale fusion: project pooled_enc0 (D0) to D_S and fuse it with pooled_main (D_S) -> D_S
        self.proj_enc0_to_DS = nn.Linear(d_models[0], d_models[-1], **factory_kwargs)
        self.fusion_head = nn.Sequential(
            nn.Linear(d_models[-1] + d_models[-1], d_models[-1], **factory_kwargs),
            nn.GELU(),
            nn.Dropout(fusion_dropout),
            nn.Linear(d_models[-1], d_models[-1], **factory_kwargs),
        )

    def _expand_width(self, x: torch.Tensor, pad_vec: nn.Parameter) -> torch.Tensor:
        if pad_vec.numel() == 0:
            return x
        early = x.shape[:-1]
        return torch.cat([x, pad_vec.expand(*early, -1)], dim=-1)

    def forward(self, x_scalar: torch.Tensor, mask: Optional[torch.Tensor] = None, return_seq: bool = False):
        """
        x_scalar: (B, L) or (B, L, 1)
        mask: (B, L) bool
        Returns:
            fused_vec: (B, D_S)
            debug: optional
            aux: routing information per stage (for the ratio loss)
        """
        if x_scalar.dim() == 2:
            x_scalar = x_scalar.unsqueeze(-1)  # (B, L, 1)
        B, L, _ = x_scalar.shape
        device = x_scalar.device
        if mask is None:
            mask = torch.ones(B, L, dtype=torch.bool, device=device)

        # initial projection up to D0
        x = self.input_proj(x_scalar)  # (B, L0, D0)
        cur_mask = mask

        pooled_enc0 = None
        aux_per_stage = []
        seq_debug = [] if return_seq else None

        # per stage: Encoder (Mamba2) -> Routing -> Chunk -> Expand D
        for s in range(self.S):
            d_in = self.d_models[s]
            # fine-grained encoding (uncompressed sequence)
            h_enc = self.encoders[s](x, mask=cur_mask)  # (B, L_s, D_s)

            if s == 0:
                pooled_enc0 = masked_mean(h_enc, cur_mask)  # (B, D0)

            # routing + downsampling (yields a shorter sequence)
            bpred = self.routers[s](h_enc, mask=cur_mask)
            x_next, _, _, mask_next = self.chunks[s](h_enc, bpred.boundary_mask, mask=cur_mask)  # (B, L_{s+1}, D_s)

            # expand the width D_s -> D_{s+1}
            x_next = self._expand_width(x_next, self.pad_vectors[s])  # (B, L_{s+1}, D_{s+1})

            # advance to the next stage
            x, cur_mask = x_next, mask_next

            aux_per_stage.append({
                "boundary_mask": bpred.boundary_mask,
                "boundary_prob": bpred.boundary_prob,
                "selected_probs": bpred.selected_probs,
            })
            if return_seq:
                seq_debug.append({"stage": s, "seq": x, "mask": cur_mask})

        # now x: (B, L_S, D_S), cur_mask: (B, L_S)
        # the single main network runs on the most compressed sequence
        h_main = self.main_network(x, mask=cur_mask)  # (B, L_S, D_S)

        # pool the main-network output
        if cur_mask is None:
            pooled_main = h_main.mean(dim=1)  # (B, D_S)
        else:
            pooled_main = (h_main * cur_mask.float().unsqueeze(-1)).sum(dim=1) / \
                          cur_mask.float().sum(dim=1, keepdim=True).clamp_min(1.0)

        # cross-scale fusion: global pooling of E^0 with the main-network pooling
        pooled_enc0_in_DS = self.proj_enc0_to_DS(pooled_enc0)        # (B, D_S)
        fused = torch.cat([pooled_enc0_in_DS, pooled_main], dim=-1)  # (B, 2*D_S)
        fused = self.fusion_head(fused)                              # (B, D_S)

        aux = {"per_stage": aux_per_stage}
        if return_seq:
            return fused, {"stages": seq_debug, "main_seq": h_main, "main_mask": cur_mask}, aux
        else:
            return fused, None, aux
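
# Illustrative helper (name and usage are ours; nothing else calls it): the realised compression
# per stage can be read off aux["per_stage"] by averaging the boundary mask, which should
# approach 1 / target_N as the ratio loss takes effect.
def realized_compression(aux):
    stats = []
    for s, a in enumerate(aux["per_stage"]):
        stats.append({"stage": s, "avg_boundary_rate": float(a["boundary_mask"].float().mean())})
    return stats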

# -------------------- Top level: multi-channel fusion + classification head (a single main network) --------------------
@dataclass
class HierEncodersSingleMainConfig:
    num_channels: int
    d_models: List[int]                  # [D0, D1, ..., D_S], monotonically non-decreasing
    num_classes: int
    encoder_cfg_per_stage: List[dict]    # S encoder configs (all Mamba2, height ~ 4)
    main_cfg: dict                       # config of the single main network (Transformer or Mamba2); d_model is taken from D_S
    target_compression_N_per_stage: List[int]
    share_channel: bool = True
    fusion_across_channels: Literal["mean", "concat"] = "mean"
    dropout: float = 0.1


class HierEncodersSingleMainClassifier(nn.Module):
    def __init__(self, cfg: HierEncodersSingleMainConfig, dtype=None, device=None):
        super().__init__()
        self.cfg = cfg
        factory_kwargs = {"dtype": dtype, "device": device}

        S = len(cfg.d_models) - 1
        assert S == len(cfg.encoder_cfg_per_stage) == len(cfg.target_compression_N_per_stage), "stage counts do not match"

        if cfg.share_channel:
            self.channel_encoder = PyramidEncoders_NoDechunk(
                d_models=cfg.d_models,
                encoder_cfg_per_stage=cfg.encoder_cfg_per_stage,
                main_cfg=cfg.main_cfg,
                **factory_kwargs,
            )
        else:
            self.channel_encoder = nn.ModuleList([
                PyramidEncoders_NoDechunk(
                    d_models=cfg.d_models,
                    encoder_cfg_per_stage=cfg.encoder_cfg_per_stage,
                    main_cfg=cfg.main_cfg,
                    **factory_kwargs,
                )
                for _ in range(cfg.num_channels)
            ])

        fusion_dim = (cfg.num_channels * cfg.d_models[-1]) if cfg.fusion_across_channels == "concat" \
            else cfg.d_models[-1]
        self.dropout = nn.Dropout(cfg.dropout)
        self.head = nn.Linear(fusion_dim, cfg.num_classes, **factory_kwargs)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None, return_seq: bool = False):
        """
        x: (B, L, N) multi-channel input
        mask: (B, L) temporal mask
        """
        B, L, N = x.shape
        assert N == self.cfg.num_channels

        channel_vecs: List[torch.Tensor] = []
        ratio_losses = []
        seq_dbg_all = [] if return_seq else None

        for c in range(N):
            x_c = x[..., c]  # (B, L)
            if self.cfg.share_channel:
                vec, seq_dbg, aux = self.channel_encoder(x_c, mask=mask, return_seq=return_seq)
            else:
                vec, seq_dbg, aux = self.channel_encoder[c](x_c, mask=mask, return_seq=return_seq)

            # accumulate the ratio loss (one term per encoder stage)
            total_rl = 0.0
            for s, aux_s in enumerate(aux["per_stage"]):
                rl = ratio_loss(aux_s["boundary_mask"], aux_s["boundary_prob"], self.cfg.target_compression_N_per_stage[s])
                total_rl = total_rl + rl
            ratio_losses.append(total_rl)

            channel_vecs.append(vec)
            if return_seq:
                seq_dbg_all.append(seq_dbg)

        if self.cfg.fusion_across_channels == "concat":
            fused = torch.cat(channel_vecs, dim=-1)  # (B, N*D_S)
        else:
            fused = torch.stack(channel_vecs, dim=1).mean(dim=1)  # (B, D_S)

        fused = self.dropout(fused)
        logits = self.head(fused)

        aux_all = {"ratio_loss": torch.stack(ratio_losses).mean()}
        if return_seq:
            return logits, seq_dbg_all, aux_all
        else:
            return logits, None, aux_all


# -------------------- Usage example --------------------
if __name__ == "__main__":
    """
    This example follows the design constraints:
    - extra depth only adds encoders (Mamba2 + dynamic chunking per stage); there is a single final main network
    - sequence length shrinks stage by stage (decided by DC) while d_model grows monotonically
      (SpaceByte-style shared-vector concatenation)
    - no dechunking; cross-scale fusion uses the global pooling of E^0 plus the final main-network pooling
    """
    B, L, N = 8, 1024, 6
    num_classes = 7
    d_models = [128, 256, 512]  # D0 <= D1 <= D2

    encoder_cfg_per_stage = [
        dict(arch="m", height=4, ssm_cfg=dict(), attn_cfg=dict()),  # stage 0 encoder (Mamba2)
        dict(arch="m", height=4, ssm_cfg=dict(), attn_cfg=dict()),  # stage 1 encoder (Mamba2)
    ]
    main_cfg = dict(
        arch="T", height=12, ssm_cfg=dict(), attn_cfg=dict(num_heads=8)  # final main network (heavier)
    )
    target_compression_N_per_stage = [4, 4]

    cfg = HierEncodersSingleMainConfig(
        num_channels=N,
        d_models=d_models,
        num_classes=num_classes,
        encoder_cfg_per_stage=encoder_cfg_per_stage,
        main_cfg=main_cfg,
        target_compression_N_per_stage=target_compression_N_per_stage,
        share_channel=True,
        fusion_across_channels="mean",
        dropout=0.1,
    )

    model = HierEncodersSingleMainClassifier(cfg).cuda().train()
    x = torch.randn(B, L, N, device="cuda")
    mask = torch.ones(B, L, dtype=torch.bool, device="cuda")

    logits, _, aux = model(x, mask=mask, return_seq=False)
    y = torch.randint(0, num_classes, (B,), device="cuda")
    cls_loss = F.cross_entropy(logits, y)
    ratio_reg = 0.03 * aux["ratio_loss"]
    loss = cls_loss + ratio_reg
    loss.backward()
    print("logits:", logits.shape, "loss:", float(loss))

138  models/vanillaMamba-Copy1.py  Normal file
@@ -0,0 +1,138 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from mamba_ssm import Mamba2


class ValueEmbedding(nn.Module):
    """
    Linearly project the single-channel scalar at each time step to d_model, with optional dropout.
    No temporal embedding and no positional embedding.
    """
    def __init__(self, in_dim: int, d_model: int, dropout: float = 0.0, bias: bool = True):
        super().__init__()
        self.proj = nn.Linear(in_dim, d_model, bias=bias)
        self.dropout = nn.Dropout(dropout) if dropout and dropout > 0.0 else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [B, L, 1] -> [B, L, d_model]
        return self.dropout(self.proj(x))


class ChannelMambaBlock(nn.Module):
    """
    Two-layer Mamba-2 block for a single channel:
    - input: [B, L, 1], first projected to d_model
    - two Mamba2 layers, with a residual connection around the output of each layer
    - LayerNorm after each layer
    - output: [B, L, d_model]
    """
    def __init__(self, d_model: int, dropout: float, m2_kwargs: dict):
        super().__init__()
        self.embed = ValueEmbedding(in_dim=1, d_model=d_model, dropout=dropout, bias=True)

        # two Mamba-2 layers
        self.mamba1 = Mamba2(d_model=d_model, **m2_kwargs)
        self.mamba2 = Mamba2(d_model=d_model, **m2_kwargs)

        # per-layer normalization
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)

    def forward(self, x_ch: torch.Tensor) -> torch.Tensor:
        # x_ch: [B, L, 1]
        x = self.embed(x_ch)  # [B, L, d_model]

        # layer 1 + residual
        y1 = self.mamba1(x)    # [B, L, d_model]
        y1 = self.ln1(x + y1)  # residual 1 + LN

        # layer 2 + residual
        y2 = self.mamba2(y1)    # [B, L, d_model]
        y2 = self.ln2(y1 + y2)  # residual 2 + LN

        return y2  # [B, L, d_model]


class Model(nn.Module):
    """
    Channel-independent Mamba-2 classification model:
    - each input channel is processed by its own two-layer Mamba2 block (with two residuals)
    - each channel yields a [B, L, d_model] output
    - the last-time-step representations of all channels are concatenated and fed to the classification head
    Input:
    - x_enc: [B, L, D] multivariate time series
    Output:
    - logits: [B, num_class]
    """
    def __init__(self, configs):
        super().__init__()
        self.task_name = getattr(configs, 'task_name', 'classification')
        assert self.task_name == 'classification', "this model only implements the classification task"

        # basic configuration
        self.enc_in = configs.enc_in      # number of channels D
        self.d_model = configs.d_model    # per-channel model width
        self.num_class = configs.num_class
        self.dropout = getattr(configs, 'dropout', 0.1)

        # Mamba-2 hyperparameters (read from configs where present)
        # Note: e_layers stacking is not used here; each channel gets exactly two layers so that
        # residual connections can be added at the outputs of layer 1 and layer 2.
        m2_kwargs = dict(
            d_state=getattr(configs, 'd_state', 64),
            d_conv=getattr(configs, 'd_conv', 4),
            expand=getattr(configs, 'expand', 2),
            headdim=getattr(configs, 'headdim', 64),
            d_ssm=getattr(configs, 'd_ssm', None),
            ngroups=getattr(configs, 'ngroups', 1),
            A_init_range=getattr(configs, 'A_init_range', (1, 16)),
            D_has_hdim=getattr(configs, 'D_has_hdim', False),
            rmsnorm=getattr(configs, 'rmsnorm', True),
            norm_before_gate=getattr(configs, 'norm_before_gate', False),
            dt_min=getattr(configs, 'dt_min', 0.001),
            dt_max=getattr(configs, 'dt_max', 0.1),
            dt_init_floor=getattr(configs, 'dt_init_floor', 1e-4),
            dt_limit=getattr(configs, 'dt_limit', (0.0, float("inf"))),
            bias=getattr(configs, 'bias', False),
            conv_bias=getattr(configs, 'conv_bias', True),
            chunk_size=getattr(configs, 'chunk_size', 256),
            use_mem_eff_path=getattr(configs, 'use_mem_eff_path', True),
        )

        # one independent two-layer Mamba2 block per channel
        self.channel_blocks = nn.ModuleList([
            ChannelMambaBlock(d_model=self.d_model, dropout=self.dropout, m2_kwargs=m2_kwargs)
            for _ in range(self.enc_in)
        ])

        # classification head: concatenate the last-step representation of every channel -> GELU -> Dropout -> Linear
        self.act = nn.GELU()
        self.head = nn.Sequential(
            nn.Dropout(self.dropout),
            nn.Linear(self.d_model * self.enc_in, self.num_class)
        )

    def classification(self, x_enc: torch.Tensor) -> torch.Tensor:
        # x_enc: [B, L, D]
        B, L, D = x_enc.shape
        assert D == self.enc_in, f"input channel count {D} does not match enc_in {self.enc_in}"

        per_channel_last = []
        for c in range(D):
            # take the single-channel series [B, L] -> [B, L, 1]
            x_ch = x_enc[:, :, c].unsqueeze(-1)
            y_ch = self.channel_blocks[c](x_ch)      # [B, L, d_model]
            per_channel_last.append(y_ch[:, -1, :])  # [B, d_model]

        # concatenate the last-step representations of all channels -> [B, D * d_model]
        h_last = torch.cat(per_channel_last, dim=-1)

        # classification head
        h_last = self.act(h_last)
        logits = self.head(h_last)  # [B, num_class]
        return logits

    # keep the forward signature consistent with TimesNet; x_mark_enc / x_dec / x_mark_dec / mask are ignored
    def forward(self, x_enc, x_mark_enc=None, x_dec=None, x_mark_dec=None, mask=None):
        return self.classification(x_enc)
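
# Smoke-test sketch for the per-channel classifier (illustrative; assumes a CUDA build of
# mamba_ssm and a d_model compatible with the default headdim=64 / expand=2, e.g. 256).
def _vanilla_mamba_cls_smoke_test():
    from types import SimpleNamespace
    configs = SimpleNamespace(task_name='classification', enc_in=3, d_model=256,
                              num_class=5, dropout=0.1)
    model = Model(configs).cuda()
    x = torch.randn(4, 96, 3, device='cuda')
    logits = model(x)
    assert logits.shape == (4, 5)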

203  models/vanillaMamba.py  Normal file
@@ -0,0 +1,203 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from mamba_ssm import Mamba2


class ValueEmbedding(nn.Module):
    """
    Linearly project the single-channel scalar at each time step to d_model, with optional dropout.
    No temporal embedding and no positional embedding.
    """
    def __init__(self, in_dim: int, d_model: int, dropout: float = 0.0, bias: bool = True):
        super().__init__()
        self.proj = nn.Linear(in_dim, d_model, bias=bias)
        self.dropout = nn.Dropout(dropout) if dropout and dropout > 0.0 else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [B, L, 1] -> [B, L, d_model]
        return self.dropout(self.proj(x))


class ChannelMambaBlock(nn.Module):
    """
    Two-layer Mamba-2 block for a single channel:
    - input: [B, L, 1], first projected to d_model
    - two Mamba2 layers, with a residual connection around the output of each layer
    - LayerNorm after each layer
    - output: [B, L, d_model]
    """
    def __init__(self, d_model: int, dropout: float, m2_kwargs: dict):
        super().__init__()
        self.embed = ValueEmbedding(in_dim=1, d_model=d_model, dropout=dropout, bias=True)

        # two Mamba-2 layers
        self.mamba1 = Mamba2(d_model=d_model, **m2_kwargs)
        self.mamba2 = Mamba2(d_model=d_model, **m2_kwargs)

        # per-layer normalization
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)

    def forward(self, x_ch: torch.Tensor) -> torch.Tensor:
        # x_ch: [B, L, 1]
        x = self.embed(x_ch)  # [B, L, d_model]

        # layer 1 + residual
        y1 = self.mamba1(x)    # [B, L, d_model]
        y1 = self.ln1(x + y1)  # residual 1 + LN

        # layer 2 + residual
        y2 = self.mamba2(y1)    # [B, L, d_model]
        y2 = self.ln2(y1 + y2)  # residual 2 + LN

        return y2  # [B, L, d_model]


class Model(nn.Module):
    """
    Channel-independent Mamba-2 model supporting:
    - classification: per-channel feature extraction, concatenate the last time step -> classification head
    - long/short-term forecasting: per-channel extraction, keep the whole sequence, map the time
      dimension to the target length with a linear layer, project back to a scalar, then concatenate
    Note: the number of forecast output channels is strictly equal to the number of input channels
    (per-channel prediction).

    Inputs:
    - x_enc: [B, L, D] multivariate time series
    - x_mark_enc, x_dec, x_mark_dec, mask: kept for interface compatibility (not used here)

    Outputs:
    - classification: logits [B, num_class]
    - forecast: [B, pred_len, D]
    """
    def __init__(self, configs):
        super().__init__()
        # task type
        self.task_name = getattr(configs, 'task_name', 'classification')
        assert self.task_name in ['classification', 'long_term_forecast', 'short_term_forecast'], \
            "only classification / long_term_forecast / short_term_forecast are supported"

        # basic configuration
        self.enc_in = configs.enc_in      # number of channels D
        self.d_model = configs.d_model    # per-channel model width
        self.num_class = getattr(configs, 'num_class', None)
        self.dropout = getattr(configs, 'dropout', 0.1)

        # forecasting-related settings
        self.seq_len = getattr(configs, 'seq_len', None)
        self.pred_len = getattr(configs, 'pred_len', None)
        if self.task_name in ['long_term_forecast', 'short_term_forecast']:
            assert self.seq_len is not None and self.pred_len is not None, "forecasting requires seq_len and pred_len"
            # the output channels must match the input channels
            self.c_out = getattr(configs, 'c_out', self.enc_in)
            assert self.c_out == self.enc_in, "forecasting requires c_out to equal enc_in"

        # Mamba-2 hyperparameters
        m2_kwargs = dict(
            d_state=getattr(configs, 'd_state', 64),
            d_conv=getattr(configs, 'd_conv', 4),
            expand=getattr(configs, 'expand', 2),
            headdim=getattr(configs, 'headdim', 64),
            d_ssm=getattr(configs, 'd_ssm', None),
            ngroups=getattr(configs, 'ngroups', 1),
            A_init_range=getattr(configs, 'A_init_range', (1, 16)),
            D_has_hdim=getattr(configs, 'D_has_hdim', False),
            rmsnorm=getattr(configs, 'rmsnorm', True),
            norm_before_gate=getattr(configs, 'norm_before_gate', False),
            dt_min=getattr(configs, 'dt_min', 0.001),
            dt_max=getattr(configs, 'dt_max', 0.1),
            dt_init_floor=getattr(configs, 'dt_init_floor', 1e-4),
            dt_limit=getattr(configs, 'dt_limit', (0.0, float("inf"))),
            bias=getattr(configs, 'bias', False),
            conv_bias=getattr(configs, 'conv_bias', True),
            chunk_size=getattr(configs, 'chunk_size', 256),
            use_mem_eff_path=getattr(configs, 'use_mem_eff_path', True),
        )

        # one independent two-layer Mamba2 block per channel
        self.channel_blocks = nn.ModuleList([
            ChannelMambaBlock(d_model=self.d_model, dropout=self.dropout, m2_kwargs=m2_kwargs)
            for _ in range(self.enc_in)
        ])

        # classification head: concatenate the last-step representation of every channel -> GELU -> Dropout -> Linear
        if self.task_name == 'classification':
            assert self.num_class is not None, "classification requires num_class"
            self.act = nn.GELU()
            self.head = nn.Sequential(
                nn.Dropout(self.dropout),
                nn.Linear(self.d_model * self.enc_in, self.num_class)
            )

        # forecasting head:
        # - first map the time dimension linearly: [B, L, d_model] -> [B, pred_len, d_model]
        # - then project d_model to a single scalar channel: [B, pred_len, d_model] -> [B, pred_len, 1]
        if self.task_name in ['long_term_forecast', 'short_term_forecast']:
            self.predict_linear = nn.Linear(self.seq_len, self.pred_len)
            self.projection = nn.Linear(self.d_model, 1, bias=True)

    def classification(self, x_enc: torch.Tensor) -> torch.Tensor:
        # x_enc: [B, L, D]
        B, L, D = x_enc.shape
        assert D == self.enc_in, f"input channel count {D} does not match enc_in {self.enc_in}"

        per_channel_last = []
        for c in range(D):
            # take the single-channel series [B, L] -> [B, L, 1]
            x_ch = x_enc[:, :, c].unsqueeze(-1)
            y_ch = self.channel_blocks[c](x_ch)      # [B, L, d_model]
            per_channel_last.append(y_ch[:, -1, :])  # [B, d_model]

        # concatenate the last-step representations of all channels -> [B, D * d_model]
        h_last = torch.cat(per_channel_last, dim=-1)

        # classification head
        logits = self.head(self.act(h_last))  # [B, num_class]
        return logits

    def forecast(self, x_enc: torch.Tensor) -> torch.Tensor:
        """
        Per-channel forecasting:
        - normalize along the time dimension, then extract each channel independently
        - use the full Mamba output sequence, map the time dimension to the target length, project back to a scalar
        - de-normalize
        Returns:
            dec_out: [B, pred_len, D]; forward slices the last pred_len steps
        """
        B, L, D = x_enc.shape
        assert L == self.seq_len, f"input length {L} does not match configured seq_len {self.seq_len}"
        assert D == self.enc_in, f"input channel count {D} does not match enc_in {self.enc_in}"

        # Normalization (per Non-stationary Transformer)
        means = x_enc.mean(1, keepdim=True).detach()  # [B, 1, D]
        x = x_enc - means
        stdev = torch.sqrt(x.var(dim=1, keepdim=True, unbiased=False) + 1e-5)  # [B, 1, D]
        x = x / stdev

        per_channel_seq = []
        for c in range(D):
            x_ch = x[:, :, c].unsqueeze(-1)      # [B, L, 1]
            h_ch = self.channel_blocks[c](x_ch)  # [B, L, d_model]
            # map the time dimension to pred_len
            h_ch = self.predict_linear(h_ch.permute(0, 2, 1)).permute(0, 2, 1)  # [B, pred_len, d_model]
            # project back to a single channel
            y_ch = self.projection(h_ch)         # [B, pred_len, 1]
            per_channel_seq.append(y_ch)

        # concatenate the channels
        dec_out = torch.cat(per_channel_seq, dim=-1)  # [B, pred_len, D]

        # De-normalization
        dec_out = dec_out * stdev[:, 0, :].unsqueeze(1) + means[:, 0, :].unsqueeze(1)

        return dec_out

    # keep the forward signature consistent with TimesNet; x_mark_enc / x_dec / x_mark_dec / mask are ignored
    def forward(self, x_enc, x_mark_enc=None, x_dec=None, x_mark_dec=None, mask=None):
        if self.task_name in ['long_term_forecast', 'short_term_forecast']:
            dec_out = self.forecast(x_enc)         # [B, pred_len, D]
            return dec_out[:, -self.pred_len:, :]  # return only the forecast part [B, pred_len, D]
        elif self.task_name == 'classification':
            return self.classification(x_enc)
        else:
            raise NotImplementedError(f"Unsupported task: {self.task_name}")
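
# Forecast-path smoke-test sketch (illustrative; assumes a CUDA build of mamba_ssm, with seq_len
# and pred_len chosen freely since the head is a plain linear map over the time dimension).
def _vanilla_mamba_forecast_smoke_test():
    from types import SimpleNamespace
    configs = SimpleNamespace(task_name='long_term_forecast', enc_in=3, d_model=256,
                              seq_len=96, pred_len=24, dropout=0.1)
    model = Model(configs).cuda()
    x = torch.randn(2, 96, 3, device='cuda')
    y = model(x)  # [B, pred_len, D]
    assert y.shape == (2, 24, 3)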