"""
SeasonPatch = PatchTST (channel-independent encoder) + channel graph mixer
(HierarchicalGraphMixer) + linear prediction head.
"""
|
|
import torch
|
|
import torch.nn as nn
|
|
from layers.PatchTST_layers import positional_encoding # 已存在
|
|
from layers.mixer import HierarchicalGraphMixer # 刚才创建
|
|
|
|
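# ------ PatchTST CI (channel-independent) encoder: intended to match the official implementation, but without the prediction head so a mixer can be inserted ------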
|
|
class _Encoder(nn.Module):
    """Channel-independent PatchTST encoder without a prediction head.

    Each channel's series is cut into (possibly overlapping) patches with
    ``unfold``. Every patch is linearly projected to ``d_model`` and a
    positional table is added. The patch sequence is then run through a
    ``TSTEncoder``. Channels are folded into the batch dimension, so each
    channel is encoded independently.

    Args:
        c_in: Number of input channels. Not used by this module; kept for
            interface compatibility with callers.
        seq_len: Input length ``L`` the positional table is sized for.
        patch_len: Length ``P`` of each patch.
        stride: Step between the starts of consecutive patches.
        d_model: Embedding size of each patch token.
        n_layers: Number of ``TSTEncoder`` layers.
        n_heads: Number of attention heads passed to ``TSTEncoder``.
        dropout: Dropout applied after adding the positional table; also
            passed to ``TSTEncoder``.

    Raises:
        ValueError: If ``patch_len`` or ``stride`` is not positive, or if
            ``seq_len`` is shorter than ``patch_len`` (no complete patch).
    """

    def __init__(self, c_in, seq_len, patch_len, stride,
                 d_model=128, n_layers=3, n_heads=16, dropout=0.):
        super().__init__()
        # Fail fast: a non-positive stride would divide by zero below.
        # seq_len < patch_len gives patch_num <= 0, which would otherwise
        # surface later as an obscure shape error.
        if patch_len <= 0 or stride <= 0:
            raise ValueError(
                f"patch_len and stride must be positive, got "
                f"patch_len={patch_len}, stride={stride}")
        if seq_len < patch_len:
            raise ValueError(
                f"seq_len ({seq_len}) must be >= patch_len ({patch_len}) "
                f"to produce at least one patch")

        self.patch_len = patch_len
        self.stride = stride
        # Number of complete windows produced by unfold(-1, patch_len, stride).
        self.patch_num = (seq_len - patch_len) // stride + 1

        self.proj = nn.Linear(patch_len, d_model)
        # NOTE(review): presumably a learnable (patch_num, d_model) table
        # ('zeros' init, learn flag True) — confirm in layers.PatchTST_layers.
        self.pos = positional_encoding('zeros', True, self.patch_num, d_model)
        self.drop = nn.Dropout(dropout)

        # Kept as a local import, as in the original module.
        from layers.PatchTST_backbone import TSTEncoder  # same name as the official one
        self.encoder = TSTEncoder(self.patch_num, d_model, n_heads,
                                  d_ff=d_model * 2, dropout=dropout,
                                  n_layers=n_layers, norm='LayerNorm')

    def forward(self, x):
        """Encode ``x`` of shape [B, C, L] into patch tokens [B, C, N, d_model].

        ``L`` is expected to equal the ``seq_len`` given at construction, so
        that ``N == self.patch_num`` matches the positional table.
        """
        # [B, C, L] -> [B, C, N, P]: sliding windows along the last (time) axis.
        patches = x.unfold(-1, self.patch_len, self.stride)
        B, C, N, _ = patches.shape
        z = self.proj(patches)              # [B, C, N, d]
        # Fold channels into the batch so each channel is encoded on its own.
        z = z.reshape(B * C, N, -1)         # [B*C, N, d]
        # assumes self.pos is (N, d) and broadcasts over B*C — see __init__ note.
        z = self.drop(z + self.pos)
        z = self.encoder(z)                 # [B*C, N, d]
        # reshape (not view) also works if the encoder output is non-contiguous.
        return z.reshape(B, C, N, -1)       # [B, C, N, d]
|
|
|
|
# ------------------------- SeasonPatch -----------------------------
|
|
class SeasonPatch(nn.Module):
    """Channel-independent PatchTST encoder + graph mixer + linear head.

    The pipeline has four steps:

    1. Permute the input to channel-first.
    2. Encode each channel's patches independently with ``_Encoder``.
    3. Let ``HierarchicalGraphMixer`` mix information across channels.
    4. Flatten each channel's patch tokens and map them to ``pred_len``
       values with one linear layer.

    Args:
        c_in: Number of input channels ``C``.
        seq_len: Input length ``L``.
        pred_len: Forecast horizon ``T``.
        patch_len: Patch length used by the encoder.
        stride: Patch stride used by the encoder.
        k_graph: Passed to ``HierarchicalGraphMixer`` as ``k``.
        d_model: Patch embedding size.
        revin: Accepted but currently ignored. No reversible instance
            normalisation is applied anywhere in this class.
    """

    def __init__(self,
                 c_in: int,
                 seq_len: int,
                 pred_len: int,
                 patch_len: int,
                 stride: int,
                 k_graph: int = 5,
                 d_model: int = 128,
                 revin: bool = True):
        super().__init__()
        # NOTE(review): `revin` is never used, so the input is not normalised.
        # Kept for signature compatibility. TODO: wire up RevIN or drop the flag.

        self.encoder = _Encoder(c_in, seq_len, patch_len, stride,
                                d_model=d_model)
        self.mixer = HierarchicalGraphMixer(c_in, dim=d_model, k=k_graph)

        # Take the patch count from the encoder itself, so the head's input
        # size cannot drift from what the encoder actually produces.
        self.head = nn.Linear(self.encoder.patch_num * d_model, pred_len)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forecast from ``x`` of shape [B, L, C].

        Returns:
            Tensor of shape [B, C, pred_len]. Note that the output is
            channel-first, unlike the input.
        """
        x = x.permute(0, 2, 1)                  # [B, L, C] -> [B, C, L]
        z = self.encoder(x)                     # [B, C, N, d]
        B, C, N, D = z.shape
        # Channel-independent tokens -> cross-channel mixing via the graph mixer.
        # assumes the mixer preserves the [B, C, N, d] shape — TODO confirm in layers.mixer.
        # reshape (not view) also works if the mixer output is non-contiguous.
        z_mix = self.mixer(z).reshape(B, C, N * D)   # [B, C, N*d]
        y_pred = self.head(z_mix)               # [B, C, T]

        return y_pred
|
|
|