import torch
from torch import nn
import torch.nn.functional as F

from layers.SelfAttention_Family import FullAttention, AttentionLayer

# Mamba2 is required as the outer encoder
from mamba_ssm.modules.mamba2 import Mamba2


# -------------------- Routing (cosine routing, consistent with the paper) --------------------
class RoutingModule(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        with torch.no_grad():
            # start from identity projections so early routing reflects raw similarity
            nn.init.eye_(self.q_proj.weight)
            nn.init.eye_(self.k_proj.weight)
        self.q_proj.weight._no_reinit = True
        self.k_proj.weight._no_reinit = True

    def forward(self, x, mask=None):
        """
        x: (B, L, D)
        mask: (B, L) bool, True = valid
        Returns:
            boundary_prob: (B, L, 2)
            boundary_mask: (B, L) bool
            selected_probs: (B, L, 1)
        """
        B, L, D = x.shape
        q = F.normalize(self.q_proj(x[:, :-1]), dim=-1)  # (B, L-1, D)
        k = F.normalize(self.k_proj(x[:, 1:]), dim=-1)   # (B, L-1, D)
        cos_sim = (q * k).sum(dim=-1)                    # (B, L-1)
        p = torch.clamp((1 - cos_sim) / 2, 0.0, 1.0)     # (B, L-1)
        p = F.pad(p, (1, 0), value=1.0)                  # force the first position to be a boundary
        if mask is not None:
            p = p * mask.float()
            p[:, 0] = torch.where(mask[:, 0], torch.ones_like(p[:, 0]), p[:, 0])
        boundary_prob = torch.stack([1 - p, p], dim=-1)  # (B, L, 2)
        selected_idx = boundary_prob.argmax(dim=-1)
        boundary_mask = (selected_idx == 1)
        if mask is not None:
            boundary_mask = boundary_mask & mask
        selected_probs = boundary_prob.gather(-1, selected_idx.unsqueeze(-1))  # (B, L, 1)
        return boundary_prob, boundary_mask, selected_probs


# -------------------- Select boundary tokens and right-pad with zeros (nothing dropped, nothing duplicated) --------------------
def select_and_right_pad(x, boundary_mask):
    """
    Memory-optimized version: avoids creating temporary tensors per step.
    x: (B, L, D), boundary_mask: (B, L) bool
    Returns:
        x_pad: (B, T_max, D)
        key_padding_mask: (B, T_max) bool, True = valid
        lengths: (B,)
    """
    B, L, D = x.shape
    device = x.device
    lengths = boundary_mask.sum(dim=1)  # (B,)
    T_max = max(int(lengths.max().item()), 1)
    x_pad = x.new_zeros(B, T_max, D)
    key_padding_mask = torch.zeros(B, T_max, dtype=torch.bool, device=device)
    # pre-create the default index tensor instead of rebuilding it inside the loop
    default_idx = torch.tensor([0], device=device)
    for b in range(B):
        mask_b = boundary_mask[b]
        if mask_b.any():
            idx = mask_b.nonzero(as_tuple=True)[0]  # tuple form of nonzero is cheaper here
            t = idx.numel()
            x_pad[b, :t] = x[b, idx]
            key_padding_mask[b, :t] = True
        else:
            # no boundary selected: keep the first token so every row has at least one valid slot
            x_pad[b, 0] = x[b, default_idx]
            key_padding_mask[b, 0] = True
    return x_pad, key_padding_mask, lengths


# -------------------- Mamba2 stack (outer encoder) --------------------
class Mamba2Encoder(nn.Module):
    def __init__(self, d_model, depth=4, dropout=0.0):
        super().__init__()
        self.layers = nn.ModuleList([Mamba2(d_model=d_model) for _ in range(depth)])
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        x = self.norm(x)
        x = self.dropout(x)
        return x
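
# Usage sketch for the routing/padding pipeline above (shapes only; the sizes
# are illustrative, not taken from the original training setup):
#
#     router = RoutingModule(d_model=8)
#     h = torch.randn(2, 16, 8)                    # (B, L, D)
#     prob, bmask, _ = router(h)                   # (2, 16, 2), (2, 16), (2, 16, 1)
#     seg, kpm, lengths = select_and_right_pad(h, bmask)
#     # seg: (2, T_max, 8), zero-padded on the right; kpm: (2, T_max), True = valid;
#     # lengths[b] == bmask[b].sum(), and T_max == max(lengths.max(), 1).
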
# -------------------- Two-stage encoder + DC (variable length, information-preserving; compression constrained by ratio) --------------------
class DCEmbedding2StageVarLen(nn.Module):
    """
    - Stage 0: (B*nvars, L, 1) -> Linear(D0) -> Mamba2(D0) -> Routing -> select -> widen to D1
    - Stage 1: (B*nvars, L0_sel, D1) -> Mamba2(D1) -> Routing -> select
      (Stage 1 is currently disabled in forward; only Stage 0 runs.)
    Output:
        enc_out: (B*nvars, T_max, D1)
        key_padding_mask: (B*nvars, T_max)
        n_vars: int
        aux: dict (ratio loss and boundary info per active stage)
    """
    def __init__(self, d_model_out, d_model_stage0, depth_enc0=4, depth_enc1=4,
                 dropout=0.0, target_ratio0=0.25, target_ratio1=0.5):
        super().__init__()
        assert d_model_out >= d_model_stage0, "requires D0 <= D1"
        self.d0 = d_model_stage0
        self.d1 = d_model_out
        # scalar -> D0
        self.input_proj = nn.Linear(1, self.d0)
        # Stage 0
        self.enc0 = Mamba2Encoder(self.d0, depth=depth_enc0, dropout=dropout)
        self.router0 = RoutingModule(self.d0)
        delta = self.d1 - self.d0
        self.pad_vec = nn.Parameter(torch.zeros(delta)) if delta > 0 else None
        self.target_ratio0 = target_ratio0
        # Stage 1
        self.enc1 = Mamba2Encoder(self.d1, depth=depth_enc1, dropout=dropout)
        self.router1 = RoutingModule(self.d1)
        self.target_ratio1 = target_ratio1

    def _expand_width(self, x):
        if self.pad_vec is None:
            return x
        B, L, _ = x.shape
        return torch.cat([x, self.pad_vec.view(1, 1, -1).expand(B, L, -1)], dim=-1)

    @staticmethod
    def _ratio_loss(boundary_mask: torch.Tensor, boundary_prob: torch.Tensor,
                    target_ratio: float) -> torch.Tensor:
        eps = 1e-6
        F_act = boundary_mask.float().mean(dim=1)   # (B,) actual boundary fraction
        G_prob = boundary_prob[..., 1].mean(dim=1)  # (B,) mean predicted boundary prob
        N = 1.0 / max(target_ratio, eps)
        loss = N / (N - 1.0 + eps) * (((N - 1.0) * F_act * G_prob) + (1.0 - F_act) * (1.0 - G_prob))
        return loss.mean()

    def forward(self, x):
        """
        x: (B, nvars, L)
        Memory-optimized version: intermediate tensors are freed as soon as possible.
        """
        B, nvars, L = x.shape
        x = x.reshape(B * nvars, L, 1)
        x = self.input_proj(x)  # (B*nvars, L, D0)
        # Stage 0
        h0 = self.enc0(x)  # (B*nvars, L, D0)
        p0, bm0, _ = self.router0(h0)
        h0_sel, mask0, len0 = select_and_right_pad(h0, bm0)  # (B*nvars, L0_max, D0)
        # free tensors that are no longer needed
        del h0
        # widen D0 -> D1 so the output matches the documented width and the
        # downstream encoder, which is built with d_model = D1
        h0_sel = self._expand_width(h0_sel)  # (B*nvars, L0_max, D1)
        # Stage 1 (disabled for now)
        # h1 = self.enc1(h0_sel)  # (B*nvars, L0_max, D1)
        # p1, bm1, _ = self.router1(h1)
        # bm1 = bm1 & mask0
        # h1_sel, mask1, len1 = select_and_right_pad(h1, bm1)  # (B*nvars, L1_max, D1)
        # del h1, h0_sel  # free intermediates promptly
        # the ratio loss keeps its computation graph (it is trained); only the
        # boundary info stored in aux below is detached
        ratio_loss0 = self._ratio_loss(bm0, p0, target_ratio=self.target_ratio0)
        # ratio_loss1 = self._ratio_loss(bm1, p1, target_ratio=self.target_ratio1)
        # keep aux minimal: only what downstream losses/logging need
        aux = {
            "stage0": {"boundary_mask": bm0.detach(), "boundary_prob": p0.detach(), "lengths": len0.detach()},
            # "stage1": {"boundary_mask": bm1.detach(), "boundary_prob": p1.detach(), "lengths": len1.detach()},
            "ratio_loss0": ratio_loss0,
            # "ratio_loss1": ratio_loss1,
        }
        return h0_sel, mask0, nvars, aux
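
# A quick numeric check of _ratio_loss above (illustrative values). With
# target_ratio = 0.25 (so N = 4), the loss attains its minimum of ~1.0 exactly
# when a quarter of the positions are boundaries and the router is confident:
#
#     bm = torch.tensor([[1, 0, 0, 0, 1, 0, 0, 0]], dtype=torch.bool)
#     bp = torch.stack([1 - bm.float(), bm.float()], dim=-1)          # (1, 8, 2)
#     DCEmbedding2StageVarLen._ratio_loss(bm, bp, target_ratio=0.25)  # ~= 1.0
#
# Over- or under-segmenting pushes F_act and G_prob away from 1/N and the loss
# grows, which is the pressure that lets the router learn a compression rate.
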
# -------------------- Encoder/EncoderLayer (with key_padding_mask pass-through) --------------------
class EncoderLayerWithMask(nn.Module):
    """
    Same structure as the original EncoderLayer, but forward takes a key_padding_mask
    and hands it to the AttentionLayer. The FFN is a plain MLP, as in a standard
    Transformer. Note: this assumes AttentionLayer/FullAttention in
    layers.SelfAttention_Family have been extended to accept key_padding_mask;
    the stock versions do not take that argument.
    """
    def __init__(self, attention: AttentionLayer, d_model, d_ff, dropout=0.1, activation="gelu"):
        super().__init__()
        self.attention = attention
        self.dropout = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        if activation == "relu":
            act = nn.ReLU()
        elif activation == "gelu":
            act = nn.GELU()
        else:
            raise ValueError(f"Unsupported activation: {activation}")
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            act,
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
        )

    def forward(self, x, attn_mask=None, tau=None, delta=None, key_padding_mask=None):
        # multi-head attention with key padding mask
        attn_out, attn = self.attention(
            x, x, x, attn_mask, tau=tau, delta=delta, key_padding_mask=key_padding_mask
        )
        x = x + self.dropout(attn_out)
        x = self.norm1(x)
        # feed-forward block
        y = self.ffn(x)
        x = x + self.dropout(y)
        x = self.norm2(x)
        return x, attn


class EncoderWithMask(nn.Module):
    """
    Like the original Encoder, but forward supports a key_padding_mask and
    passes it to every layer's attention.
    """
    def __init__(self, attn_layers, norm_layer=None):
        super().__init__()
        self.attn_layers = nn.ModuleList(attn_layers)
        self.norm = norm_layer

    def forward(self, x, attn_mask=None, key_padding_mask=None):
        attns = []
        for attn_layer in self.attn_layers:
            x, attn = attn_layer(x, attn_mask=attn_mask, key_padding_mask=key_padding_mask)
            attns.append(attn)
        if self.norm is not None:
            x = self.norm(x)
        return x, attns


# -------------------- Gated attention aggregation + task heads (token-count independent; information-preserving) --------------------
def masked_softmax(logits: torch.Tensor, mask: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """
    logits: (..., T)
    mask: (..., T) bool, True = valid
    """
    neg_inf = torch.finfo(logits.dtype).min
    logits = logits.masked_fill(~mask, neg_inf)
    return torch.softmax(logits, dim=dim)


class GatedAttnAggregator(nn.Module):
    """
    Gated attention aggregator (learnable queries + masked softmax + value-side gating).
    Input:  x: (B*, T, D), mask: (B*, T) bool
    Output: slots: (B*, R, D), where R is the (configurable) number of aggregation slots
    """
    def __init__(self, d_model: int, num_slots: int = 4, d_att: int = None, dropout: float = 0.1):
        super().__init__()
        self.d_model = d_model
        self.R = num_slots
        self.d_att = d_att or d_model
        # R learnable queries
        self.query = nn.Parameter(torch.randn(self.R, self.d_att) / (self.d_att ** 0.5))
        # linear projections
        self.key_proj = nn.Linear(d_model, self.d_att)
        self.val_proj = nn.Linear(d_model, d_model)
        # value-side gating (per-token scalar gate)
        self.gate = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.GELU(),
            nn.Linear(d_model // 2, 1),
            nn.Sigmoid()
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x_bt_t_d: torch.Tensor, mask_bt: torch.Tensor) -> torch.Tensor:
        """
        x_bt_t_d: (B*, T, D)
        mask_bt: (B*, T) bool
        Returns: slots (B*, R, D)
        """
        BStar, T, D = x_bt_t_d.shape
        K = self.key_proj(x_bt_t_d)  # (B*, T, d_att)
        V = self.val_proj(x_bt_t_d)  # (B*, T, D)
        g = self.gate(x_bt_t_d)      # (B*, T, 1)
        Vg = V * g                   # gated values
        Q = self.query.unsqueeze(0).expand(BStar, -1, -1)                  # (B*, R, d_att)
        logits = torch.matmul(Q, K.transpose(1, 2)) / (self.d_att ** 0.5)  # (B*, R, T)
        attn_mask = mask_bt.unsqueeze(1)  # (B*, 1, T)
        attn = masked_softmax(logits, attn_mask, dim=-1)
        attn = self.dropout(attn)
        slots = torch.matmul(attn, Vg)  # (B*, R, D)
        return slots


class AttnPoolHeadForecast(nn.Module):
    """
    Forecasting head: gated attention aggregates each series into R slots,
    which are then mapped to target_window (pred_len).
    Output: (B, pred_len, nvars)
    """
    def __init__(self, d_model: int, target_window: int, num_slots: int = 4, dropout: float = 0.1):
        super().__init__()
        self.agg = GatedAttnAggregator(d_model, num_slots=num_slots, dropout=dropout)
        self.proj = nn.Sequential(
            nn.LayerNorm(num_slots * d_model),
            nn.Linear(num_slots * d_model, target_window),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, enc_out_bt_t_d: torch.Tensor, key_padding_mask_bt: torch.Tensor, n_vars: int, B: int):
        slots = self.agg(enc_out_bt_t_d, key_padding_mask_bt)  # (B*, R, D)
        slots = slots.reshape(B, n_vars, -1)                   # (B, nvars, R*D)
        out = self.proj(self.dropout(slots))                   # (B, nvars, pred_len)
        return out.permute(0, 2, 1)                            # (B, pred_len, nvars)


class AttnPoolHeadSeq(nn.Module):
    """
    Sequence reconstruction head: gated attention aggregation followed by a
    projection to seq_len.
    Output: (B, seq_len, nvars)
    """
    def __init__(self, d_model: int, target_window: int, num_slots: int = 4, dropout: float = 0.1):
        super().__init__()
        self.agg = GatedAttnAggregator(d_model, num_slots=num_slots, dropout=dropout)
        self.proj = nn.Sequential(
            nn.LayerNorm(num_slots * d_model),
            nn.Linear(num_slots * d_model, target_window),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, enc_out_bt_t_d: torch.Tensor, key_padding_mask_bt: torch.Tensor, n_vars: int, B: int):
        slots = self.agg(enc_out_bt_t_d, key_padding_mask_bt)  # (B*, R, D)
        slots = slots.reshape(B, n_vars, -1)                   # (B, nvars, R*D)
        out = self.proj(self.dropout(slots))                   # (B, nvars, seq_len)
        return out.permute(0, 2, 1)                            # (B, seq_len, nvars)
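
# Usage sketch for the gated aggregator used by the task heads in this file
# (illustrative shapes; B* is the flattened batch*nvars dimension):
#
#     agg = GatedAttnAggregator(d_model=64, num_slots=4)
#     x = torch.randn(6, 10, 64)                 # (B*, T, D)
#     m = torch.ones(6, 10, dtype=torch.bool)    # True = valid token
#     m[:, 7:] = False                           # pretend the last 3 are padding
#     slots = agg(x, m)                          # (6, 4, 64); padding gets zero attention
#
# Because the R queries are learned and the softmax is masked, the slot count
# (and hence each head's input width) stays fixed regardless of how many tokens
# dynamic chunking keeps per series.
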
class AttnPoolHeadCls(nn.Module):
    """
    Classification head: gated attention aggregates each variable into R slots,
    then all variables are concatenated and linearly classified.
    Output: (B, num_class)
    """
    def __init__(self, d_model: int, n_vars: int, num_class: int, num_slots: int = 4, dropout: float = 0.1):
        super().__init__()
        self.agg = GatedAttnAggregator(d_model, num_slots=num_slots, dropout=dropout)
        self.dropout = nn.Dropout(dropout)
        self.proj = nn.Sequential(
            nn.LayerNorm(n_vars * num_slots * d_model),
            nn.Linear(n_vars * num_slots * d_model, num_class),
        )
        self.n_vars = n_vars
        self.num_slots = num_slots
        self.d_model = d_model

    def forward(self, enc_out_bt_t_d: torch.Tensor, key_padding_mask_bt: torch.Tensor, n_vars: int, B: int):
        slots = self.agg(enc_out_bt_t_d, key_padding_mask_bt)             # (B*, R, D)
        slots = slots.reshape(B, n_vars, self.num_slots * self.d_model)  # (B, nvars, R*D)
        flat = self.dropout(slots.reshape(B, -1))                         # (B, nvars*R*D)
        return self.proj(flat)


# -------------------- Main model: two-stage DC (ratio-controlled) + masked encoder + gated aggregation heads --------------------
class Transpose(nn.Module):
    def __init__(self, *dims, contiguous=False):
        super().__init__()
        self.dims, self.contiguous = dims, contiguous

    def forward(self, x):
        return x.transpose(*self.dims).contiguous() if self.contiguous else x.transpose(*self.dims)
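
# Why Transpose exists: BatchNorm1d normalizes dim 1, so to use it inside an
# nn.Sequential over (B, T, D) tensors the encoder's norm layer below wraps it
# as Transpose(1, 2) -> BatchNorm1d(D) -> Transpose(1, 2). A shape-only sketch
# (sizes are illustrative):
#
#     norm = nn.Sequential(Transpose(1, 2), nn.BatchNorm1d(64), Transpose(1, 2))
#     norm(torch.randn(2, 10, 64)).shape   # torch.Size([2, 10, 64])
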
class Model(nn.Module):
    """
    PatchTST with dynamic chunking (DC), masked attention, and gated heads:
    - a two-stage Mamba2 encoder with dynamic chunking replaces PatchEmbedding
    - DC uses ratio losses (target_ratio0/1) to control compression strength; as the
      stages deepen, the sequence shortens and d_model grows (D0 -> D1)
    - attention receives a key_padding_mask so padded positions are screened out
    - heads use gated attention aggregation (independent of token count, which
      preserves information better than fixed pooling)
    """
    def __init__(
        self,
        configs,
        d_model_stage0=None,   # D0; defaults to max(16, d_model // 2)
        depth_enc0=1,
        depth_enc1=1,
        target_ratio0=0.25,    # roughly 1/N0
        target_ratio1=0.5,     # roughly 1/N1
        agg_slots=4,           # number of gated-aggregation slots
    ):
        super().__init__()
        self.task_name = configs.task_name
        self.seq_len = configs.seq_len
        self.pred_len = configs.pred_len
        self.enc_in = configs.enc_in
        # DC embedding
        D1 = configs.d_model
        D0 = d_model_stage0 if d_model_stage0 is not None else max(16, D1 // 2)
        assert D1 >= D0, "requires D0 <= D1"
        self.dc_embedding = DCEmbedding2StageVarLen(
            d_model_out=D1,
            d_model_stage0=D0,
            depth_enc0=depth_enc0,
            depth_enc1=depth_enc1,
            dropout=configs.dropout,
            target_ratio0=target_ratio0,
            target_ratio1=target_ratio1,
        )
        # masked encoder
        attn_layers = [
            EncoderLayerWithMask(
                AttentionLayer(
                    FullAttention(False, configs.factor, attention_dropout=configs.dropout,
                                  output_attention=False),
                    D1, configs.n_heads
                ),
                d_model=D1,
                d_ff=configs.d_ff,
                dropout=configs.dropout,
                activation=configs.activation,
            ) for _ in range(configs.e_layers)
        ]
        self.encoder = EncoderWithMask(
            attn_layers,
            norm_layer=nn.Sequential(Transpose(1, 2), nn.BatchNorm1d(D1), Transpose(1, 2))
        )
        # gated aggregation heads (independent of token count)
        if self.task_name in ('long_term_forecast', 'short_term_forecast'):
            self.head = AttnPoolHeadForecast(D1, self.pred_len, num_slots=agg_slots, dropout=configs.dropout)
        elif self.task_name in ('imputation', 'anomaly_detection'):
            self.head = AttnPoolHeadSeq(D1, self.seq_len, num_slots=agg_slots, dropout=configs.dropout)
        elif self.task_name == 'classification':
            self.head_cls = AttnPoolHeadCls(D1, n_vars=self.enc_in, num_class=configs.num_class,
                                            num_slots=agg_slots, dropout=configs.dropout)

    # --------- normalization / denormalization ---------
    def _pre_norm(self, x):
        means = x.mean(1, keepdim=True).detach()
        x = x - means
        stdev = torch.sqrt(torch.var(x, dim=1, keepdim=True, unbiased=False) + 1e-5)
        x = x / stdev
        return x, means, stdev

    def _denorm(self, y, means, stdev, length):
        return y * (stdev[:, 0, :].unsqueeze(1).repeat(1, length, 1)) + \
               (means[:, 0, :].unsqueeze(1).repeat(1, length, 1))

    # --------- DC + Transformer encoder (carries key_padding_mask) ---------
    def _embed_and_encode(self, x_enc):
        """
        x_enc: (B, L, C)
        Returns:
            enc_out: (B*nvars, T_max, D1)
            n_vars: int
            key_padding_mask: (B*nvars, T_max)
            B: int
            aux: dict
        """
        B, L, C = x_enc.shape
        x_vars = x_enc.permute(0, 2, 1)  # (B, nvars, L)
        enc_out, key_padding_mask, n_vars, aux = self.dc_embedding(x_vars)
        enc_out, _ = self.encoder(enc_out, attn_mask=None, key_padding_mask=key_padding_mask)
        return enc_out, n_vars, key_padding_mask, B, aux

    # --------- per-task forward passes ---------
    def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
        x_enc, means, stdev = self._pre_norm(x_enc)
        enc_out, n_vars, key_padding_mask, B, aux = self._embed_and_encode(x_enc)
        dec_out = self.head(enc_out, key_padding_mask, n_vars, B)  # (B, pred_len, nvars)
        dec_out = self._denorm(dec_out, means, stdev, self.pred_len)
        return dec_out, aux

    def imputation(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask):
        # normalize over observed positions only (mask == 1 marks observed values)
        means = torch.sum(x_enc, dim=1) / torch.sum(mask == 1, dim=1)
        means = means.unsqueeze(1).detach()
        x = x_enc - means
        x = x.masked_fill(mask == 0, 0)
        stdev = torch.sqrt(torch.sum(x * x, dim=1) / torch.sum(mask == 1, dim=1) + 1e-5)
        stdev = stdev.unsqueeze(1).detach()
        x = x / stdev
        enc_out, n_vars, key_padding_mask, B, aux = self._embed_and_encode(x)
        dec_out = self.head(enc_out, key_padding_mask, n_vars, B)  # (B, seq_len, nvars)
        dec_out = self._denorm(dec_out, means, stdev, self.seq_len)
        return dec_out, aux

    def anomaly_detection(self, x_enc):
        x_enc, means, stdev = self._pre_norm(x_enc)
        enc_out, n_vars, key_padding_mask, B, aux = self._embed_and_encode(x_enc)
        dec_out = self.head(enc_out, key_padding_mask, n_vars, B)  # (B, seq_len, nvars)
        dec_out = self._denorm(dec_out, means, stdev, self.seq_len)
        return dec_out, aux

    def classification(self, x_enc, x_mark_enc):
        x_enc, _, _ = self._pre_norm(x_enc)
        enc_out, n_vars, key_padding_mask, B, aux = self._embed_and_encode(x_enc)
        logits = self.head_cls(enc_out, key_padding_mask, n_vars, B)  # (B, num_class)
        return logits, aux

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
        if self.task_name in ('long_term_forecast', 'short_term_forecast'):
            dec_out, aux = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
            return dec_out[:, -self.pred_len:, :], aux  # (B, pred_len, nvars); aux carries the ratio losses
        if self.task_name == 'imputation':
            dec_out, aux = self.imputation(x_enc, x_mark_enc, x_dec, x_mark_dec, mask)
            return dec_out, aux
        if self.task_name == 'anomaly_detection':
            dec_out, aux = self.anomaly_detection(x_enc)
            return dec_out, aux
        if self.task_name == 'classification':
            logits, aux = self.classification(x_enc, x_mark_enc)
            return logits, aux
        return None, None
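
if __name__ == "__main__":
    # Minimal forward-pass smoke test, not part of the original module. It assumes
    # (a) mamba_ssm is installed and a CUDA device is available, since Mamba2's
    # fused kernels are GPU-only, and (b) AttentionLayer/FullAttention have been
    # extended to accept key_padding_mask, as EncoderLayerWithMask requires.
    # All config values below are illustrative, not the original experiment settings.
    from types import SimpleNamespace

    configs = SimpleNamespace(
        task_name="long_term_forecast", seq_len=96, pred_len=24, enc_in=7,
        d_model=64, n_heads=4, e_layers=2, d_ff=128, factor=1,
        dropout=0.1, activation="gelu", num_class=2,
    )
    model = Model(configs).cuda()
    x_enc = torch.randn(2, configs.seq_len, configs.enc_in, device="cuda")
    dec_out, aux = model(x_enc, None, None, None)
    print(dec_out.shape)              # expected: torch.Size([2, 24, 7])
    print(float(aux["ratio_loss0"]))  # add to the task loss with a small weight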