feat: add DiffusionTimeSeries and iTransformer models, introduce xPatch_SparseChannel
models/xPatch_SparseChannel/patchtst_ci.py (new file, 66 lines)
@@ -0,0 +1,66 @@
"""
SeasonPatch = PatchTST (CI) + ChannelGraphMixer + linear prediction head
"""
import torch
import torch.nn as nn
from layers.PatchTST_layers import positional_encoding  # already exists in the repo
from layers.mixer import HierarchicalGraphMixer         # created earlier in this commit


# ------ PatchTST CI encoder (equivalent to the official implementation,
# ------ but without the head, so the Mixer can be inserted) ------
class _Encoder(nn.Module):
    def __init__(self, c_in, seq_len, patch_len, stride,
                 d_model=128, n_layers=3, n_heads=16, dropout=0.):
        super().__init__()
        self.patch_len = patch_len
        self.stride = stride
        self.patch_num = (seq_len - patch_len) // stride + 1

        self.proj = nn.Linear(patch_len, d_model)
        self.pos = positional_encoding('zeros', True, self.patch_num, d_model)
        self.drop = nn.Dropout(dropout)

        from layers.PatchTST_backbone import TSTEncoder  # same name as in the official repo
        self.encoder = TSTEncoder(self.patch_num, d_model, n_heads,
                                  d_ff=d_model*2, dropout=dropout,
                                  n_layers=n_layers, norm='LayerNorm')

    def forward(self, x):                              # x: [B, C, L]
        x = x.unfold(-1, self.patch_len, self.stride)  # [B, C, patch_num, patch_len]
        B, C, N, P = x.shape
        z = self.proj(x)                               # [B, C, N, d_model]
        z = z.contiguous().view(B*C, N, -1)            # [B*C, N, d_model]
        z = self.drop(z + self.pos)                    # add positional encoding
        z = self.encoder(z)                            # [B*C, N, d_model]
        return z.view(B, C, N, -1)                     # [B, C, N, d_model]
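
# Shape walkthrough for _Encoder (illustrative numbers, not from the commit):
# with seq_len=96, patch_len=16, stride=8 -> patch_num = (96-16)//8 + 1 = 11.
# An input of [B=32, C=7, L=96] unfolds to [32, 7, 11, 16], is projected to
# [32, 7, 11, 128], and TSTEncoder attends over the 11 patches of each of the
# 32*7 = 224 channel sequences independently (this is the "CI" in PatchTST CI).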
# ------------------------- SeasonPatch -----------------------------
class SeasonPatch(nn.Module):
    def __init__(self,
                 c_in: int,
                 seq_len: int,
                 pred_len: int,
                 patch_len: int,
                 stride: int,
                 k_graph: int = 5,
                 d_model: int = 128,
                 revin: bool = True):  # accepted for config parity; not used in this module
        super().__init__()

        self.encoder = _Encoder(c_in, seq_len, patch_len, stride,
                                d_model=d_model)
        self.mixer = HierarchicalGraphMixer(c_in, dim=d_model, k=k_graph)

        # Calculate the actual number of patches
        patch_num = (seq_len - patch_len) // stride + 1
        self.head = nn.Linear(patch_num * d_model, pred_len)

    def forward(self, x):                        # x: [B, L, C]
        x = x.permute(0, 2, 1)                   # -> [B, C, L]
        z = self.encoder(x)                      # [B, C, N, d_model]
        B, C, N, D = z.shape
        # channel-independent encoding -> sparse cross-channel injection
        z_mix = self.mixer(z).view(B, C, N * D)  # flatten patches: [B, C, N*d_model]
        y_pred = self.head(z_mix)                # [B, C, pred_len]

        return y_pred
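
A minimal smoke test for the new module (a sketch, not part of the commit; it
assumes the repo root is on PYTHONPATH and that HierarchicalGraphMixer preserves
the [B, C, N, d] shape it receives, as the forward pass above implies):

    import torch
    from models.xPatch_SparseChannel.patchtst_ci import SeasonPatch

    model = SeasonPatch(c_in=7, seq_len=96, pred_len=24,
                        patch_len=16, stride=8, k_graph=3, d_model=64)
    x = torch.randn(32, 96, 7)   # [B, L, C], the layout forward() expects
    y = model(x)
    print(y.shape)               # expected: torch.Size([32, 7, 24])

Note the output is channel-first ([B, C, pred_len]); callers that expect
[B, pred_len, C] need a final permute.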