"""SeasonPatch = PatchTST (channel-independent) encoder + HierarchicalGraphMixer + linear head.

Each channel is first encoded independently by a PatchTST-style encoder.  A
graph mixer then adds cross-channel information (the original design note calls
this a "sparse cross-channel injection"), and a per-channel linear head maps the
flattened patch embeddings to the forecast horizon.
"""
import torch
import torch.nn as nn

from layers.PatchTST_backbone import TSTEncoder  # same class name as the official PatchTST code
from layers.PatchTST_layers import positional_encoding
from layers.mixer import HierarchicalGraphMixer


def _num_patches(seq_len: int, patch_len: int, stride: int) -> int:
    """Return how many windows ``Tensor.unfold(-1, patch_len, stride)`` yields.

    Raises:
        ValueError: if ``patch_len``/``stride`` are not positive or the
            sequence is shorter than a single patch (no patch could be formed).
    """
    if patch_len <= 0 or stride <= 0:
        raise ValueError(
            f"patch_len and stride must be positive, got patch_len={patch_len}, stride={stride}"
        )
    if seq_len < patch_len:
        raise ValueError(
            f"seq_len ({seq_len}) must be >= patch_len ({patch_len}) to form at least one patch"
        )
    return (seq_len - patch_len) // stride + 1


# ------ PatchTST channel-independent encoder (no head, so a mixer can be inserted) ------
class _Encoder(nn.Module):
    """Channel-independent PatchTST-style encoder that stops before the prediction head.

    Modeled on the official PatchTST backbone, but the head is removed so that a
    mixer can sit between encoder and head.  No end-padding patch and no RevIN
    normalization are applied here.

    Args:
        c_in: number of input channels (not used by this module; kept for
            interface compatibility).
        seq_len: input sequence length L.
        patch_len: length P of each patch.
        stride: step between consecutive patches.
        d_model: embedding width of each patch token.
        n_layers: number of Transformer encoder layers.
        n_heads: number of attention heads.
        dropout: dropout applied after adding positional encodings and inside
            the Transformer layers.
    """

    def __init__(self, c_in, seq_len, patch_len, stride,
                 d_model=128, n_layers=3, n_heads=16, dropout=0.):
        super().__init__()
        self.patch_len = patch_len
        self.stride = stride
        self.patch_num = _num_patches(seq_len, patch_len, stride)

        # Attribute names and creation order are kept stable so existing
        # checkpoints (state_dict keys) and RNG-dependent init remain unchanged.
        self.proj = nn.Linear(patch_len, d_model)
        # Learnable positional encoding ('zeros' mode, learn_pe=True), one row per patch.
        self.pos = positional_encoding('zeros', True, self.patch_num, d_model)
        self.drop = nn.Dropout(dropout)
        self.encoder = TSTEncoder(self.patch_num, d_model, n_heads,
                                  d_ff=d_model * 2, dropout=dropout,
                                  n_layers=n_layers, norm='LayerNorm')

    def forward(self, x):
        """Encode ``x`` of shape [B, C, L] into patch embeddings [B, C, N, d_model]."""
        # [B, C, L] -> [B, C, N, P]: sliding windows of length patch_len along time.
        patches = x.unfold(-1, self.patch_len, self.stride)
        B, C, N, _ = patches.shape
        z = self.proj(patches)              # [B, C, N, d]
        # Fold channels into the batch dimension so each channel is encoded independently.
        z = z.reshape(B * C, N, -1)         # [B*C, N, d]
        # self.pos is built with (patch_num, d_model); presumably shape [N, d],
        # broadcast over B*C -- requires the input length to match seq_len.
        z = self.drop(z + self.pos)
        z = self.encoder(z)                 # [B*C, N, d]
        return z.reshape(B, C, N, -1)       # [B, C, N, d]


# ------------------------- SeasonPatch -----------------------------
class SeasonPatch(nn.Module):
    """Forecaster: channel-independent patch encoder -> graph mixer -> linear head.

    Args:
        c_in: number of input channels C.
        seq_len: input sequence length L.
        pred_len: forecast horizon T.
        patch_len: length of each patch.
        stride: step between consecutive patches.
        k_graph: passed to ``HierarchicalGraphMixer`` as ``k``.
        d_model: patch embedding width (shared by encoder and mixer).
        revin: accepted for configuration compatibility; currently NOT applied --
            this module contains no RevIN layer.
        n_layers: number of encoder layers (keyword-only; default 3).
        n_heads: number of attention heads (keyword-only; default 16).
        dropout: encoder dropout (keyword-only; default 0.0).

    Input:  x of shape [B, L, C].
    Output: forecast of shape [B, C, pred_len] (channel-major, not [B, T, C]).
    """

    def __init__(self,
                 c_in: int,
                 seq_len: int,
                 pred_len: int,
                 patch_len: int,
                 stride: int,
                 k_graph: int = 5,
                 d_model: int = 128,
                 revin: bool = True,
                 *,
                 n_layers: int = 3,
                 n_heads: int = 16,
                 dropout: float = 0.):
        super().__init__()
        self.encoder = _Encoder(c_in, seq_len, patch_len, stride,
                                d_model=d_model, n_layers=n_layers,
                                n_heads=n_heads, dropout=dropout)
        self.mixer = HierarchicalGraphMixer(c_in, dim=d_model, k=k_graph)
        # Reuse the encoder's patch count so the two can never disagree.
        self.head = nn.Linear(self.encoder.patch_num * d_model, pred_len)

    def forward(self, x):
        """Map x [B, L, C] to a forecast of shape [B, C, pred_len]."""
        z = self.encoder(x.permute(0, 2, 1))   # [B, L, C] -> [B, C, L] -> [B, C, N, d]
        B, C, N, D = z.shape
        # Channel-independent features -> cross-channel mixing.  The mixer is
        # assumed to preserve the [B, C, N, d] shape; reshape (not view) also
        # copes with a non-contiguous mixer output.
        z_mix = self.mixer(z).reshape(B, C, N * D)   # [B, C, N*d]
        return self.head(z_mix)                      # [B, C, pred_len]