""" SeasonPatch = PatchTST (CI) + ChannelGraphMixer + Linear prediction head 支持两种编码器: - Transformer 编码器路径:PatchTST + GraphMixer + Head - Mamba2 编码器路径:Mamba2Encoder(不使用mixer),直接用最后得到的 d_model 走 Head """ import torch import torch.nn as nn from layers.TSTEncoder import TSTiEncoder from layers.GraphMixer import HierarchicalGraphMixer from layers.MambaSeries import Mamba2Encoder class SeasonPatch(nn.Module): def __init__(self, c_in: int, seq_len: int, pred_len: int, patch_len: int, stride: int, k_graph: int = 8, encoder_type: str = "Transformer", d_model: int = 128, n_layers: int = 3, n_heads: int = 16, # Mamba2 相关可选超参 d_state: int = 64, d_conv: int = 4, expand: int = 2, headdim: int = 64, # Mixergraph 可选超参数 thr_graph: float = 0.5, thr_graph_min: float = None, thr_graph_max: float = None, thr_graph_steps: int = 0, thr_graph_schedule: str = "linear", symmetric_graph: bool = True, degree_rescale: str = "count-sqrt", # "none" | "count" | "count-sqrt" | "sum" gate_temperature: float = 2./3., tau_attn: float = 1.0, l0_lambda: float = 1e-4): super().__init__() # ===== 新增:保存 l0_lambda,防止 reg_loss 访问报错 ===== self.l0_lambda = l0_lambda # Store patch parameters self.patch_len = patch_len # patch 长度 self.stride = stride # patch 步幅 # Calculate patch number patch_num = (seq_len - patch_len) // stride + 1 # patch_num: int self.encoder_type = encoder_type # 只初始化需要的encoder if encoder_type == "Transformer": # Transformer (PatchTST) 编码器(channel independent) self.encoder = TSTiEncoder( c_in=c_in, patch_num=patch_num, patch_len=patch_len, d_model=d_model, n_layers=n_layers, n_heads=n_heads ) # 集成新 HierarchicalGraphMixer(非归一化) self.mixer = HierarchicalGraphMixer( n_channel=c_in, dim=d_model, max_degree=k_graph, thr=thr_graph, thr_min=thr_graph_min, thr_max=thr_graph_max, thr_steps=thr_graph_steps, thr_schedule=thr_graph_schedule, temperature=gate_temperature, tau_attn=tau_attn, symmetric=symmetric_graph, degree_rescale=degree_rescale ) # Prediction head(Transformer 路径用到,输入维度为 patch_num * d_model) self.head = nn.Sequential( nn.Linear(patch_num * d_model, patch_num * d_model), nn.SiLU(), nn.Linear(patch_num * d_model, pred_len) ) elif encoder_type == "Mamba2": # Mamba2 编码器(channel independent),返回 [B, C, d_model] self.encoder = Mamba2Encoder( c_in=c_in, patch_num=patch_num, patch_len=patch_len, d_model=d_model, d_state=d_state, d_conv=d_conv, expand=expand, headdim=headdim, n_layers=n_layers ) # Prediction head(Mamba2 路径用到,输入维度为 d_model) self.head = nn.Sequential( nn.Linear(d_model, d_model), nn.SiLU(), nn.Linear(d_model, pred_len) ) else: raise ValueError(f"Unsupported encoder_type: {encoder_type}") def forward(self, x, training): # x: [B, L, C] x = x.permute(0, 2, 1) # x: [B, C, L] # Patch the input x_patch = x.unfold(-1, self.patch_len, self.stride) # x_patch: [B, C, patch_num, patch_len] if self.encoder_type == "Transformer": # Encode patches (PatchTST) z = self.encoder(x_patch) # z: [B, C, d_model, patch_num] # [B, C, d_model, patch_num] -> [B, C, patch_num, d_model] B, C, D, N = z.shape # B: batch, C: channels, D: d_model, N: patch_num z = z.permute(0, 1, 3, 2) # z: [B, C, patch_num, d_model] # Cross-channel mixing z_mix = self.mixer(z, training) # z_mix: [B, C, patch_num, d_model] # Flatten and predict z_mix = z_mix.view(B, C, N * D) # z_mix: [B, C, patch_num * d_model] y_pred = self.head(z_mix) # y_pred: [B, C, pred_len] elif self.encoder_type == "Mamba2": # 使用 Mamba2 编码器(不使用 mixer) z_last = self.encoder(x_patch) # z_last: [B, C, d_model](仅最后一个时间步) y_pred = self.head(z_last) # y_pred: [B, C, pred_len] return 
y_pred # [B, C, pred_len] def reg_loss(self): """ 可选:把 L0 正则暴露出去,训练时加到总loss。 """ if self.encoder_type == "Transformer" and hasattr(self, "mixer"): return self.mixer.l0_loss(self.l0_lambda) return torch.tensor(0.0, device=self.head[0].weight.device)
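

if __name__ == "__main__":
    # Minimal smoke-test sketch showing the expected call pattern and shapes.
    # The hyperparameters below (c_in=7, seq_len=96, pred_len=24, patch_len=16,
    # stride=8, batch size 4) are illustrative assumptions, not values prescribed
    # by this repo, and running it assumes the `layers` package imported above
    # is on the Python path.
    model = SeasonPatch(c_in=7, seq_len=96, pred_len=24,
                        patch_len=16, stride=8,
                        encoder_type="Transformer")
    x = torch.randn(4, 96, 7)        # [B, L, C]
    y = model(x, training=True)      # [B, C, pred_len] -> [4, 7, 24]
    # reg_loss() adds the mixer's L0 penalty on the Transformer path
    total_loss = nn.functional.mse_loss(y, torch.zeros_like(y)) + model.reg_loss()
    print(y.shape, total_loss.item())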