feat: add DiffusionTimeSeries and iTransformer models, introduce xPatch_SparseChannel
models/xPatch_SparseChannel/__init__.py (Normal file, 3 lines added)
@@ -0,0 +1,3 @@
from .xPatch import Model

__all__ = ['Model']
models/xPatch_SparseChannel/patchtst_ci.py (Normal file, 66 lines added)
@@ -0,0 +1,66 @@
"""
SeasonPatch = PatchTST (CI) + ChannelGraphMixer + linear prediction head
"""
import torch
import torch.nn as nn
from layers.PatchTST_layers import positional_encoding  # existing layer
from layers.mixer import HierarchicalGraphMixer         # newly added layer


# ---- PatchTST CI encoder (equivalent to the official implementation,
# ---- but with the head removed so the Mixer can be inserted) ----
class _Encoder(nn.Module):
    def __init__(self, c_in, seq_len, patch_len, stride,
                 d_model=128, n_layers=3, n_heads=16, dropout=0.):
        super().__init__()
        self.patch_len = patch_len
        self.stride = stride
        self.patch_num = (seq_len - patch_len) // stride + 1

        self.proj = nn.Linear(patch_len, d_model)
        self.pos = positional_encoding('zeros', True, self.patch_num, d_model)
        self.drop = nn.Dropout(dropout)

        from layers.PatchTST_backbone import TSTEncoder  # same name as in the official implementation
        self.encoder = TSTEncoder(self.patch_num, d_model, n_heads,
                                  d_ff=d_model * 2, dropout=dropout,
                                  n_layers=n_layers, norm='LayerNorm')

    def forward(self, x):                              # x: [B, C, L]
        x = x.unfold(-1, self.patch_len, self.stride)  # [B, C, patch_num, patch_len]
        B, C, N, P = x.shape
        z = self.proj(x)                               # [B, C, N, d]
        z = z.contiguous().view(B * C, N, -1)          # [B*C, N, d]
        z = self.drop(z + self.pos)
        z = self.encoder(z)                            # [B*C, N, d]
        return z.view(B, C, N, -1)                     # [B, C, N, d]


# ------------------------- SeasonPatch -----------------------------
class SeasonPatch(nn.Module):
    def __init__(self,
                 c_in: int,
                 seq_len: int,
                 pred_len: int,
                 patch_len: int,
                 stride: int,
                 k_graph: int = 5,
                 d_model: int = 128,
                 revin: bool = True):  # normalization itself is applied by the outer Model
        super().__init__()

        self.encoder = _Encoder(c_in, seq_len, patch_len, stride,
                                d_model=d_model)
        self.mixer = HierarchicalGraphMixer(c_in, dim=d_model, k=k_graph)

        # Calculate the actual number of patches
        patch_num = (seq_len - patch_len) // stride + 1
        self.head = nn.Linear(patch_num * d_model, pred_len)

    def forward(self, x):            # x: [B, L, C]
        x = x.permute(0, 2, 1)       # -> [B, C, L]
        z = self.encoder(x)          # [B, C, N, d]
        B, C, N, D = z.shape
        # channel-independent encoding -> sparse cross-channel injection
        z_mix = self.mixer(z).view(B, C, N * D)  # [B, C, N*D]
        y_pred = self.head(z_mix)                # [B, C, pred_len]

        return y_pred
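For reference, a minimal shape check of SeasonPatch, assuming HierarchicalGraphMixer (added in layers/mixer.py, not shown in this diff) maps [B, C, N, d] back to [B, C, N, d]:

# Hypothetical smoke test, not part of this commit.
import torch
from models.xPatch_SparseChannel.patchtst_ci import SeasonPatch

m = SeasonPatch(c_in=7, seq_len=96, pred_len=24, patch_len=16, stride=8)
x = torch.randn(2, 96, 7)  # [B, L, C]
y = m(x)
print(y.shape)             # expected: torch.Size([2, 7, 24]) -> [B, C, pred_len]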
models/xPatch_SparseChannel/xPatch.py (Normal file, 56 lines added)
@@ -0,0 +1,56 @@
import torch
import torch.nn as nn
import math

from layers.decomp import DECOMP
from .xPatch_SparseChannel import Network
from layers.revin import RevIN

class Model(nn.Module):
    def __init__(self, configs):
        super(Model, self).__init__()

        # Parameters
        seq_len = configs.seq_len    # lookback window L
        pred_len = configs.pred_len  # prediction length (96, 192, 336, 720)
        c_in = configs.enc_in        # input channels

        # Patching
        patch_len = configs.patch_len
        stride = configs.stride
        padding_patch = configs.padding_patch

        # Normalization
        self.revin = configs.revin
        self.revin_layer = RevIN(c_in, affine=True, subtract_last=False)

        # Moving Average
        self.ma_type = configs.ma_type
        alpha = configs.alpha  # smoothing factor for EMA (Exponential Moving Average)
        beta = configs.beta    # smoothing factor for DEMA (Double Exponential Moving Average)

        self.decomp = DECOMP(self.ma_type, alpha, beta)
        self.net = Network(seq_len, pred_len, patch_len, stride, padding_patch, c_in)
        # self.net_mlp = NetworkMLP(seq_len, pred_len)  # For ablation study with MLP-only stream
        # self.net_cnn = NetworkCNN(seq_len, pred_len, patch_len, stride, padding_patch)  # For ablation study with CNN-only stream

    def forward(self, x):
        # x: [Batch, Input, Channel]

        # Normalization
        if self.revin:
            x = self.revin_layer(x, 'norm')

        if self.ma_type == 'reg':  # If no decomposition, directly pass the input to the network
            x = self.net(x, x)
            # x = self.net_mlp(x)  # For ablation study with MLP-only stream
            # x = self.net_cnn(x)  # For ablation study with CNN-only stream
        else:
            seasonal_init, trend_init = self.decomp(x)
            x = self.net(seasonal_init, trend_init)

        # Denormalization
        if self.revin:
            x = self.revin_layer(x, 'denorm')

        return x
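A minimal sketch of the config object Model expects; the attribute names mirror the ones read in __init__ above, while the concrete values (including the 'ema' ma_type string) are illustrative assumptions:

# Hypothetical usage sketch, not part of this commit.
from types import SimpleNamespace
import torch
from models.xPatch_SparseChannel import Model

configs = SimpleNamespace(
    seq_len=96, pred_len=96, enc_in=7,            # lookback, horizon, channels
    patch_len=16, stride=8, padding_patch='end',  # patching settings
    revin=True,                                   # toggle RevIN normalization
    ma_type='ema', alpha=0.3, beta=0.3,           # decomposition ('reg' skips it entirely)
)
model = Model(configs)
y = model(torch.randn(4, configs.seq_len, configs.enc_in))  # [B, L, C] -> [B, pred_len, C]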
models/xPatch_SparseChannel/xPatch_SparseChannel.py (Normal file, 63 lines added)
@@ -0,0 +1,63 @@
import torch
from torch import nn
from .patchtst_ci import SeasonPatch  # <<< new import

class Network(nn.Module):
    """
    trend  : original MLP linear stream (kept fully intact)
    season : SeasonPatch (PatchTST + Mixer)
    """
    def __init__(self, seq_len, pred_len,
                 patch_len, stride, padding_patch,
                 c_in):
        super().__init__()

        # -------- Seasonal stream ---------------
        self.season_net = SeasonPatch(c_in=c_in,
                                      seq_len=seq_len,
                                      pred_len=pred_len,
                                      patch_len=patch_len,
                                      stride=stride,
                                      k_graph=5,
                                      d_model=128)

        # --------- Linear trend stream (original code unchanged) ----------
        self.pred_len = pred_len
        self.fc5 = nn.Linear(seq_len, pred_len * 4)
        self.avgpool1 = nn.AvgPool1d(kernel_size=2)
        self.ln1 = nn.LayerNorm(pred_len * 2)

        self.fc6 = nn.Linear(pred_len * 2, pred_len)
        self.avgpool2 = nn.AvgPool1d(kernel_size=2)
        self.ln2 = nn.LayerNorm(pred_len // 2)

        self.fc7 = nn.Linear(pred_len // 2, pred_len)

        # Fuse the concatenated stream outputs
        self.fc_final = nn.Linear(pred_len * 2, pred_len)

    # ---------------- forward --------------------
    def forward(self, s, t):
        # Input shapes: [B, L, C]
        B, L, C = s.shape

        # ---------- Seasonality ------------
        y_season = self.season_net(s)             # [B, C, T]

        # ---------- Trend (original MLP) ----------
        t = t.permute(0, 2, 1).reshape(B * C, L)  # [B*C, L]
        t = self.fc5(t)
        t = self.avgpool1(t)
        t = self.ln1(t)
        t = self.fc6(t)
        t = self.avgpool2(t)
        t = self.ln2(t)
        t = self.fc7(t)                           # [B*C, T]
        y_trend = t.view(B, C, -1)                # [B, C, T]

        # --------- Concatenate & project --------------
        y = torch.cat([y_season, y_trend], dim=-1)  # [B, C, 2T]
        y = self.fc_final(y)                        # [B, C, T]
        y = y.permute(0, 2, 1)                      # [B, T, C]
        return y
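HierarchicalGraphMixer itself is not part of this diff. A hypothetical stand-in matching the interface used above (constructor HierarchicalGraphMixer(c_in, dim, k), shape-preserving on [B, C, N, d]) could look like the sketch below; the real layers/mixer.py may differ:

# Hypothetical stand-in for layers/mixer.py, for interface illustration only.
import torch
from torch import nn

class HierarchicalGraphMixer(nn.Module):
    def __init__(self, c_in, dim=128, k=5):
        super().__init__()
        self.k = min(k, c_in)
        # learnable channel-affinity matrix; only the top-k neighbours per channel are kept
        self.adj = nn.Parameter(torch.randn(c_in, c_in) * 0.01)
        self.proj = nn.Linear(dim, dim)

    def forward(self, z):                        # z: [B, C, N, d]
        scores = self.adj.softmax(dim=-1)        # [C, C]
        vals, idx = scores.topk(self.k, dim=-1)  # sparsify: k strongest edges per channel
        mask = torch.zeros_like(scores).scatter_(-1, idx, vals)
        mixed = torch.einsum('ij,bjnd->bind', mask, z)  # cross-channel message passing
        return z + self.proj(mixed)              # residual keeps the CI backbone intact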