first commit
This commit is contained in:
67
layers/SeasonPatch.py
Normal file
67
layers/SeasonPatch.py
Normal file
@ -0,0 +1,67 @@
|
||||
"""
|
||||
SeasonPatch = PatchTST (CI) + ChannelGraphMixer + Linear prediction head
|
||||
Adapted for Time-Series-Library-main style
|
||||
"""
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from layers.TSTEncoder import TSTiEncoder
|
||||
from layers.GraphMixer import HierarchicalGraphMixer
|
||||
|
||||
class SeasonPatch(nn.Module):
|
||||
def __init__(self,
|
||||
c_in: int,
|
||||
seq_len: int,
|
||||
pred_len: int,
|
||||
patch_len: int,
|
||||
stride: int,
|
||||
k_graph: int = 8,
|
||||
d_model: int = 128,
|
||||
n_layers: int = 3,
|
||||
n_heads: int = 16):
|
||||
super().__init__()
|
||||
|
||||
# Store patch parameters
|
||||
self.patch_len = patch_len
|
||||
self.stride = stride
|
||||
|
||||
# Calculate patch number
|
||||
patch_num = (seq_len - patch_len) // stride + 1
|
||||
|
||||
# PatchTST encoder (channel independent)
|
||||
self.encoder = TSTiEncoder(
|
||||
c_in=c_in,
|
||||
patch_num=patch_num,
|
||||
patch_len=patch_len,
|
||||
d_model=d_model,
|
||||
n_layers=n_layers,
|
||||
n_heads=n_heads
|
||||
)
|
||||
|
||||
# Cross-channel mixer
|
||||
self.mixer = HierarchicalGraphMixer(c_in, dim=d_model, k=k_graph)
|
||||
|
||||
# Prediction head
|
||||
self.head = nn.Linear(patch_num * d_model, pred_len)
|
||||
|
||||
def forward(self, x):
|
||||
# x: [B, L, C]
|
||||
x = x.permute(0, 2, 1) # → [B, C, L]
|
||||
|
||||
# Patch the input
|
||||
x_patch = x.unfold(-1, self.patch_len, self.stride) # [B, C, patch_num, patch_len]
|
||||
|
||||
# Encode patches
|
||||
z = self.encoder(x_patch) # [B, C, d_model, patch_num]
|
||||
|
||||
# z: [B, C, d_model, patch_num] → [B, C, patch_num, d_model]
|
||||
B, C, D, N = z.shape
|
||||
z = z.permute(0, 1, 3, 2) # [B, C, patch_num, d_model]
|
||||
|
||||
# Cross-channel mixing
|
||||
z_mix = self.mixer(z) # [B, C, patch_num, d_model]
|
||||
|
||||
# Flatten and predict
|
||||
z_mix = z_mix.view(B, C, N * D) # [B, C, patch_num * d_model]
|
||||
y_pred = self.head(z_mix) # [B, C, pred_len]
|
||||
|
||||
return y_pred
|
Reference in New Issue
Block a user