feat: add DiffusionTimeSeries and iTransformer models, introduce xPatch_SparseChannel

This commit is contained in:
game-loader
2025-08-26 20:53:35 +08:00
parent 44bd5c8f29
commit c3713f5c0b
11 changed files with 1528 additions and 41 deletions

View File

@ -0,0 +1,3 @@
from .xPatch import Model
__all__ = ['Model']

View File

@ -0,0 +1,66 @@
"""
SeasonPatch = PatchTST (CI) + ChannelGraphMixer + 线性预测头
"""
import torch
import torch.nn as nn
from layers.PatchTST_layers import positional_encoding # 已存在
from layers.mixer import HierarchicalGraphMixer # 刚才创建
# ------ PatchTST CI 编码器(与官方实现等价, 但去掉 head 便于插入 Mixer) ------
class _Encoder(nn.Module):
def __init__(self, c_in, seq_len, patch_len, stride,
d_model=128, n_layers=3, n_heads=16, dropout=0.):
super().__init__()
self.patch_len = patch_len
self.stride = stride
self.patch_num = (seq_len - patch_len)//stride + 1
self.proj = nn.Linear(patch_len, d_model)
self.pos = positional_encoding('zeros', True, self.patch_num, d_model)
self.drop = nn.Dropout(dropout)
from layers.PatchTST_backbone import TSTEncoder # 与官方同名
self.encoder = TSTEncoder(self.patch_num, d_model, n_heads,
d_ff=d_model*2, dropout=dropout,
n_layers=n_layers, norm='LayerNorm')
def forward(self, x): # [B,C,L]
x = x.unfold(-1, self.patch_len, self.stride) # [B,C,patch_num,patch_len]
B,C,N,P = x.shape
z = self.proj(x) # [B,C,N,d]
z = z.contiguous().view(B*C, N, -1) # [B*C,N,d]
z = self.drop(z + self.pos)
z = self.encoder(z) # [B*C,N,d]
return z.view(B,C,N,-1) # [B,C,N,d]
# ------------------------- SeasonPatch -----------------------------
class SeasonPatch(nn.Module):
def __init__(self,
c_in: int,
seq_len: int,
pred_len: int,
patch_len: int,
stride: int,
k_graph: int = 5,
d_model: int = 128,
revin: bool = True):
super().__init__()
self.encoder = _Encoder(c_in, seq_len, patch_len, stride,
d_model=d_model)
self.mixer = HierarchicalGraphMixer(c_in, dim=d_model, k=k_graph)
# Calculate actual number of patches
patch_num = (seq_len - patch_len) // stride + 1
self.head = nn.Linear(patch_num * d_model, pred_len)
def forward(self, x): # x [B,L,C]
x = x.permute(0,2,1) # → [B,C,L]
z = self.encoder(x) # [B,C,N,d]
B,C,N,D = z.shape
# 通道独立 -> 稀疏跨通道注入
z_mix = self.mixer(z).view(B,C,N*D) # [B,C,N,d]
y_pred = self.head(z_mix) # [B,C,T]
return y_pred

View File

@ -0,0 +1,56 @@
import torch
import torch.nn as nn
import math
from layers.decomp import DECOMP
from .xPatch_SparseChannel import Network
from layers.revin import RevIN
class Model(nn.Module):
def __init__(self, configs):
super(Model, self).__init__()
# Parameters
seq_len = configs.seq_len # lookback window L
pred_len = configs.pred_len # prediction length (96, 192, 336, 720)
c_in = configs.enc_in # input channels
# Patching
patch_len = configs.patch_len
stride = configs.stride
padding_patch = configs.padding_patch
# Normalization
self.revin = configs.revin
self.revin_layer = RevIN(c_in,affine=True,subtract_last=False)
# Moving Average
self.ma_type = configs.ma_type
alpha = configs.alpha # smoothing factor for EMA (Exponential Moving Average)
beta = configs.beta # smoothing factor for DEMA (Double Exponential Moving Average)
self.decomp = DECOMP(self.ma_type, alpha, beta)
self.net = Network(seq_len, pred_len, patch_len, stride, padding_patch,c_in)
# self.net_mlp = NetworkMLP(seq_len, pred_len) # For ablation study with MLP-only stream
# self.net_cnn = NetworkCNN(seq_len, pred_len, patch_len, stride, padding_patch) # For ablation study with CNN-only stream
def forward(self, x):
# x: [Batch, Input, Channel]
# Normalization
if self.revin:
x = self.revin_layer(x, 'norm')
if self.ma_type == 'reg': # If no decomposition, directly pass the input to the network
x = self.net(x, x)
# x = self.net_mlp(x) # For ablation study with MLP-only stream
# x = self.net_cnn(x) # For ablation study with CNN-only stream
else:
seasonal_init, trend_init = self.decomp(x)
x = self.net(seasonal_init, trend_init)
# Denormalization
if self.revin:
x = self.revin_layer(x, 'denorm')
return x

View File

@ -0,0 +1,63 @@
import torch
from torch import nn
from .patchtst_ci import SeasonPatch # <<< 新导入
class Network(nn.Module):
"""
trend : 原 MLP 线性流 (完全保留)
season : SeasonPatch (PatchTST + Mixer)
"""
def __init__(self, seq_len, pred_len,
patch_len, stride, padding_patch,
c_in):
super().__init__()
# -------- 季节性流 ---------------
self.season_net = SeasonPatch(c_in=c_in,
seq_len=seq_len,
pred_len=pred_len,
patch_len=patch_len,
stride=stride,
k_graph=5,
d_model=128)
# --------- 线性趋势流 (原代码保持不变) ----------
self.pred_len = pred_len
self.fc5 = nn.Linear(seq_len, pred_len * 4)
self.avgpool1 = nn.AvgPool1d(kernel_size=2)
self.ln1 = nn.LayerNorm(pred_len * 2)
self.fc6 = nn.Linear(pred_len * 2, pred_len)
self.avgpool2 = nn.AvgPool1d(kernel_size=2)
self.ln2 = nn.LayerNorm(pred_len // 2)
self.fc7 = nn.Linear(pred_len // 2, pred_len)
# 流结果拼接
self.fc_final = nn.Linear(pred_len * 2, pred_len)
# ---------------- forward --------------------
def forward(self, s, t):
# 输入形状: [B,L,C]
B,L,C = s.shape
# ---------- Seasonality ------------
y_season = self.season_net(s) # [B,C,T]
# ---------- Trend (原 MLP) ----------
t = t.permute(0,2,1).reshape(B*C, L) # [B*C,L]
t = self.fc5(t)
t = self.avgpool1(t)
t = self.ln1(t)
t = self.fc6(t)
t = self.avgpool2(t)
t = self.ln2(t)
t = self.fc7(t) # [B*C,T]
y_trend = t.view(B, C, -1) # [B,C,T]
# --------- 拼接 & 输出 --------------
y = torch.cat([y_season, y_trend], dim=-1) # [B,C,2T]
y = self.fc_final(y) # [B,C,T]
y = y.permute(0,2,1) # [B,T,C]
return y