Files
TSlib/layers/SeasonPatch.py

72 lines
2.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
SeasonPatch = channel-independent PatchTST-style encoder (TSTiEncoder)
+ cross-channel HierarchicalGraphMixer + MLP prediction head (Linear-SiLU-Linear).
Adapted to the Time-Series-Library-main code style.
"""
import torch
import torch.nn as nn
from layers.TSTEncoder import TSTiEncoder
from layers.GraphMixer import HierarchicalGraphMixer
class SeasonPatch(nn.Module):
    """Patch-based forecaster: patch encoder + cross-channel mixer + MLP head.

    Data flow (shapes as annotated in ``forward``)::

        [B, L, C] -> patches [B, C, patch_num, patch_len]
                  -> TSTiEncoder            [B, C, d_model, patch_num]
                  -> HierarchicalGraphMixer [B, C, patch_num, d_model]
                  -> flatten + head         [B, C, pred_len]

    Args:
        c_in: Number of input channels; forwarded to the encoder and mixer.
        seq_len: Input length used to compute the number of patches, which
            fixes the head's input size.
        pred_len: Output length per channel.
        patch_len: Length of each patch.
        stride: Step between the starts of consecutive patches.
        k_graph: Forwarded to ``HierarchicalGraphMixer`` as ``k``.
        d_model: Per-patch embedding size; forwarded to encoder and mixer.
        n_layers: Forwarded to ``TSTiEncoder``.
        n_heads: Forwarded to ``TSTiEncoder``.

    Raises:
        ValueError: If ``patch_len`` or ``stride`` is smaller than 1, or if
            ``seq_len`` is shorter than ``patch_len`` (no complete patch fits).
    """

    def __init__(self,
                 c_in: int,
                 seq_len: int,
                 pred_len: int,
                 patch_len: int,
                 stride: int,
                 k_graph: int = 8,
                 d_model: int = 128,
                 n_layers: int = 3,
                 n_heads: int = 16):
        super().__init__()
        # Fail fast with a clear message instead of a ZeroDivisionError
        # (stride == 0) or a non-positive patch count that only breaks later.
        if patch_len < 1 or stride < 1:
            raise ValueError(
                f"patch_len and stride must be >= 1, got "
                f"patch_len={patch_len}, stride={stride}"
            )
        if seq_len < patch_len:
            raise ValueError(
                f"seq_len ({seq_len}) must be >= patch_len ({patch_len})"
            )

        # Stored for Tensor.unfold in forward.
        self.patch_len = patch_len
        self.stride = stride

        # Number of complete windows that Tensor.unfold produces over seq_len.
        patch_num = (seq_len - patch_len) // stride + 1

        # PatchTST encoder (channel independent)
        self.encoder = TSTiEncoder(
            c_in=c_in,
            patch_num=patch_num,
            patch_len=patch_len,
            d_model=d_model,
            n_layers=n_layers,
            n_heads=n_heads,
        )
        # Cross-channel mixer
        self.mixer = HierarchicalGraphMixer(c_in, dim=d_model, k=k_graph)
        # Prediction head over the flattened [patch_num * d_model] features.
        self.head = nn.Sequential(
            nn.Linear(patch_num * d_model, patch_num * d_model),
            nn.SiLU(),  # non-linear activation (SiLU / Swish)
            nn.Linear(patch_num * d_model, pred_len),
        )

    def forward(self, x):
        """Forecast ``pred_len`` steps for every channel.

        Args:
            x: Input of shape ``[B, L, C]``. ``L`` should equal the
                ``seq_len`` given at construction; otherwise the number of
                patches will not match the head's input size.

        Returns:
            Tensor of shape ``[B, C, pred_len]``.
        """
        x = x.permute(0, 2, 1)  # [B, L, C] -> [B, C, L]
        # Sliding windows along time: [B, C, patch_num, patch_len]
        x_patch = x.unfold(-1, self.patch_len, self.stride)
        # Encoder output is expected as [B, C, d_model, patch_num].
        z = self.encoder(x_patch)
        B, C, D, N = z.shape
        z = z.permute(0, 1, 3, 2)  # -> [B, C, patch_num, d_model]
        # Cross-channel mixing: [B, C, patch_num, d_model]
        z_mix = self.mixer(z)
        # reshape rather than view: the permute above makes z non-contiguous,
        # and a mixer that returns it as-is (or via element-wise ops) may keep
        # that layout, in which case .view() raises.
        z_mix = z_mix.reshape(B, C, N * D)  # [B, C, patch_num * d_model]
        return self.head(z_mix)  # [B, C, pred_len]