""" SeasonPatch = PatchTST (CI) + ChannelGraphMixer + Linear prediction head Adapted for Time-Series-Library-main style """ import torch import torch.nn as nn from layers.TSTEncoder import TSTiEncoder from layers.GraphMixer import HierarchicalGraphMixer class SeasonPatch(nn.Module): def __init__(self, c_in: int, seq_len: int, pred_len: int, patch_len: int, stride: int, k_graph: int = 8, d_model: int = 128, n_layers: int = 3, n_heads: int = 16): super().__init__() # Store patch parameters self.patch_len = patch_len self.stride = stride # Calculate patch number patch_num = (seq_len - patch_len) // stride + 1 # PatchTST encoder (channel independent) self.encoder = TSTiEncoder( c_in=c_in, patch_num=patch_num, patch_len=patch_len, d_model=d_model, n_layers=n_layers, n_heads=n_heads ) # Cross-channel mixer self.mixer = HierarchicalGraphMixer(c_in, dim=d_model, k=k_graph) # Prediction head self.head = nn.Sequential( nn.Linear(patch_num * d_model, patch_num * d_model), nn.SiLU(), # 非线性激活(SiLU/Swish) nn.Linear(patch_num * d_model, pred_len) ) def forward(self, x): # x: [B, L, C] x = x.permute(0, 2, 1) # → [B, C, L] # Patch the input x_patch = x.unfold(-1, self.patch_len, self.stride) # [B, C, patch_num, patch_len] # Encode patches z = self.encoder(x_patch) # [B, C, d_model, patch_num] # z: [B, C, d_model, patch_num] → [B, C, patch_num, d_model] B, C, D, N = z.shape z = z.permute(0, 1, 3, 2) # [B, C, patch_num, d_model] # Cross-channel mixing z_mix = self.mixer(z) # [B, C, patch_num, d_model] # Flatten and predict z_mix = z_mix.view(B, C, N * D) # [B, C, patch_num * d_model] y_pred = self.head(z_mix) # [B, C, pred_len] return y_pred