# File: tsmodel/models/xPatch_SparseChannel/patchtst_ci.py (Python)
"""
SeasonPatch = channel-independent (CI) PatchTST encoder + HierarchicalGraphMixer + linear prediction head.
"""
import torch
import torch.nn as nn
from layers.PatchTST_layers import positional_encoding # 已存在
from layers.mixer import HierarchicalGraphMixer # 刚才创建
# ------ PatchTST CI encoder (mirrors the official implementation, but with the prediction head removed so a mixer can be inserted) ------
class _Encoder(nn.Module):
    """Channel-independent PatchTST encoder without a prediction head.

    Each channel's series is cut into patches with ``Tensor.unfold``, each
    patch is linearly projected to ``d_model``, the positional encoding
    built by ``positional_encoding('zeros', True, ...)`` is added, and the
    patch sequence of every (batch, channel) pair is encoded independently
    by a ``TSTEncoder`` stack (``d_ff = 2 * d_model``, ``norm='LayerNorm'``).

    Args:
        c_in: Number of input channels. Not used here (channels are folded
            into the batch dimension); kept for interface compatibility.
        seq_len: Input length ``L`` the encoder is built for.
        patch_len: Length ``P`` of each patch.
        stride: Step between consecutive patch starts.
        d_model: Patch embedding width.
        n_layers: Number of ``TSTEncoder`` layers.
        n_heads: Number of attention heads.
        dropout: Dropout applied after adding the positional encoding; also
            passed to ``TSTEncoder``.

    Raises:
        ValueError: If ``patch_len`` or ``stride`` is not positive, or if
            ``seq_len < patch_len`` (no complete patch fits).
    """

    def __init__(self, c_in, seq_len, patch_len, stride,
                 d_model=128, n_layers=3, n_heads=16, dropout=0.):
        super().__init__()
        # Fail fast with a clear message instead of a ZeroDivisionError or
        # empty/negative-sized layers further down.
        if patch_len <= 0 or stride <= 0:
            raise ValueError(
                f"patch_len and stride must be positive, "
                f"got patch_len={patch_len}, stride={stride}")
        if seq_len < patch_len:
            raise ValueError(
                f"seq_len ({seq_len}) must be >= patch_len ({patch_len}) "
                f"to form at least one patch")
        self.patch_len = patch_len
        self.stride = stride
        # Same count as produced by x.unfold(-1, patch_len, stride) on length seq_len.
        self.patch_num = (seq_len - patch_len) // stride + 1
        self.proj = nn.Linear(patch_len, d_model)
        self.pos = positional_encoding('zeros', True, self.patch_num, d_model)
        self.drop = nn.Dropout(dropout)
        from layers.PatchTST_backbone import TSTEncoder  # same name as the official implementation
        self.encoder = TSTEncoder(self.patch_num, d_model, n_heads,
                                  d_ff=d_model * 2, dropout=dropout,
                                  n_layers=n_layers, norm='LayerNorm')

    def forward(self, x):
        """Encode ``x`` of shape ``[B, C, L]`` into ``[B, C, N, d_model]``.

        Raises:
            ValueError: If ``L`` does not yield exactly ``self.patch_num``
                patches (the positional encoding has that many rows).
        """
        patches = x.unfold(-1, self.patch_len, self.stride)  # [B, C, N, P]
        B, C, N, _ = patches.shape
        if N != self.patch_num:
            # Without this check, N == 1 would silently broadcast against
            # self.pos below and produce wrongly-shaped output.
            raise ValueError(
                f"input length {x.shape[-1]} gives {N} patches, "
                f"expected {self.patch_num}")
        z = self.proj(patches)                 # [B, C, N, d]
        # Fold channels into the batch so every channel is encoded independently.
        z = z.reshape(B * C, N, -1)            # [B*C, N, d]
        z = self.drop(z + self.pos)
        z = self.encoder(z)                    # [B*C, N, d]
        # reshape (not view) tolerates a non-contiguous encoder output.
        return z.reshape(B, C, N, -1)          # [B, C, N, d]
# ------------------------- SeasonPatch -----------------------------
class SeasonPatch(nn.Module):
    """Channel-independent PatchTST encoder + HierarchicalGraphMixer + linear head.

    The input is encoded per channel into patch embeddings ``[B, C, N, d]``.
    ``HierarchicalGraphMixer`` then injects cross-channel information. A
    single linear layer maps the flattened ``N * d`` features of each
    channel to ``pred_len`` outputs.

    Note: the input is ``[B, L, C]`` but the output is channel-first,
    ``[B, C, pred_len]``.

    Args:
        c_in: Number of channels ``C``.
        seq_len: Input length ``L``.
        pred_len: Forecast horizon ``T``.
        patch_len: Patch length.
        stride: Patch stride.
        k_graph: Passed to ``HierarchicalGraphMixer`` as ``k``.
        d_model: Embedding width of encoder and mixer.
        revin: Accepted but not used by this module; no normalization is
            applied here. NOTE(review): presumably handled by the caller —
            confirm.
        n_layers: Encoder depth (default matches the encoder's default).
        n_heads: Attention heads (default matches the encoder's default).
        dropout: Encoder dropout (default matches the encoder's default).
    """

    def __init__(self,
                 c_in: int,
                 seq_len: int,
                 pred_len: int,
                 patch_len: int,
                 stride: int,
                 k_graph: int = 5,
                 d_model: int = 128,
                 revin: bool = True,
                 n_layers: int = 3,
                 n_heads: int = 16,
                 dropout: float = 0.):
        super().__init__()
        self.encoder = _Encoder(c_in, seq_len, patch_len, stride,
                                d_model=d_model, n_layers=n_layers,
                                n_heads=n_heads, dropout=dropout)
        self.mixer = HierarchicalGraphMixer(c_in, dim=d_model, k=k_graph)
        # Reuse the encoder's patch count so the head's input width cannot
        # drift from the feature size the encoder actually produces.
        self.head = nn.Linear(self.encoder.patch_num * d_model, pred_len)

    def forward(self, x):
        """Map ``x`` of shape ``[B, L, C]`` to a forecast of shape ``[B, C, pred_len]``."""
        x = x.permute(0, 2, 1)                  # [B, L, C] -> [B, C, L]
        z = self.encoder(x)                     # [B, C, N, d]
        B, C, N, D = z.shape
        # Channel-independent features -> sparse cross-channel injection.
        # Assumes the mixer preserves the [B, C, N, d] shape — TODO confirm.
        # reshape (not view): the mixer output need not be contiguous.
        z_mix = self.mixer(z).reshape(B, C, N * D)   # [B, C, N*d]
        return self.head(z_mix)                      # [B, C, pred_len]