feat: add DiffusionTimeSeries and iTransformer models, introduce xPatch_SparseChannel
This commit is contained in:
63
models/xPatch_SparseChannel/xPatch_SparseChannel.py
Normal file
63
models/xPatch_SparseChannel/xPatch_SparseChannel.py
Normal file
@ -0,0 +1,63 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from .patchtst_ci import SeasonPatch # <<< 新导入
|
||||
|
||||
class Network(nn.Module):
|
||||
"""
|
||||
trend : 原 MLP 线性流 (完全保留)
|
||||
season : SeasonPatch (PatchTST + Mixer)
|
||||
"""
|
||||
def __init__(self, seq_len, pred_len,
|
||||
patch_len, stride, padding_patch,
|
||||
c_in):
|
||||
super().__init__()
|
||||
|
||||
# -------- 季节性流 ---------------
|
||||
self.season_net = SeasonPatch(c_in=c_in,
|
||||
seq_len=seq_len,
|
||||
pred_len=pred_len,
|
||||
patch_len=patch_len,
|
||||
stride=stride,
|
||||
k_graph=5,
|
||||
d_model=128)
|
||||
|
||||
# --------- 线性趋势流 (原代码保持不变) ----------
|
||||
self.pred_len = pred_len
|
||||
self.fc5 = nn.Linear(seq_len, pred_len * 4)
|
||||
self.avgpool1 = nn.AvgPool1d(kernel_size=2)
|
||||
self.ln1 = nn.LayerNorm(pred_len * 2)
|
||||
|
||||
self.fc6 = nn.Linear(pred_len * 2, pred_len)
|
||||
self.avgpool2 = nn.AvgPool1d(kernel_size=2)
|
||||
self.ln2 = nn.LayerNorm(pred_len // 2)
|
||||
|
||||
self.fc7 = nn.Linear(pred_len // 2, pred_len)
|
||||
|
||||
# 流结果拼接
|
||||
self.fc_final = nn.Linear(pred_len * 2, pred_len)
|
||||
|
||||
# ---------------- forward --------------------
|
||||
def forward(self, s, t):
|
||||
# 输入形状: [B,L,C]
|
||||
B,L,C = s.shape
|
||||
|
||||
# ---------- Seasonality ------------
|
||||
y_season = self.season_net(s) # [B,C,T]
|
||||
|
||||
# ---------- Trend (原 MLP) ----------
|
||||
t = t.permute(0,2,1).reshape(B*C, L) # [B*C,L]
|
||||
t = self.fc5(t)
|
||||
t = self.avgpool1(t)
|
||||
t = self.ln1(t)
|
||||
t = self.fc6(t)
|
||||
t = self.avgpool2(t)
|
||||
t = self.ln2(t)
|
||||
t = self.fc7(t) # [B*C,T]
|
||||
y_trend = t.view(B, C, -1) # [B,C,T]
|
||||
|
||||
# --------- 拼接 & 输出 --------------
|
||||
y = torch.cat([y_season, y_trend], dim=-1) # [B,C,2T]
|
||||
y = self.fc_final(y) # [B,C,T]
|
||||
y = y.permute(0,2,1) # [B,T,C]
|
||||
return y
|
||||
|
||||
Reference in New Issue
Block a user