Files
tsmodel/models/xPatch_SparseChannel/xPatch_SparseChannel.py

64 lines
2.1 KiB
Python

import torch
from torch import nn
from .patchtst_ci import SeasonPatch # <<< 新导入
class Network(nn.Module):
    """Two-stream forecaster: a seasonal branch plus an MLP trend branch.

    season : ``SeasonPatch`` (described by the original author as
             PatchTST + Mixer; its implementation lives in ``patchtst_ci``).
    trend  : the original MLP linear stream, kept exactly as it was.

    Both branches produce one forecast per channel. The two forecasts are
    concatenated along the last axis and projected back to ``pred_len``
    steps by a final linear layer.

    Args:
        seq_len: Length of the input window; input width of the first trend
            layer and forwarded to ``SeasonPatch``.
        pred_len: Forecast horizon; forwarded to ``SeasonPatch`` and used to
            size every trend layer.
        patch_len: Forwarded to ``SeasonPatch``.
        stride: Forwarded to ``SeasonPatch``.
        padding_patch: Accepted so the constructor signature stays compatible
            with existing callers; it is not used by this class.
        c_in: Forwarded to ``SeasonPatch``.
        k_graph: Forwarded to ``SeasonPatch``. Defaults to 5, the value that
            used to be hard-coded.
        d_model: Forwarded to ``SeasonPatch``. Defaults to 128, the value that
            used to be hard-coded.
    """

    def __init__(self, seq_len, pred_len,
                 patch_len, stride, padding_patch,
                 c_in, k_graph=5, d_model=128):
        super().__init__()

        # -------- Seasonal stream --------
        self.season_net = SeasonPatch(c_in=c_in,
                                      seq_len=seq_len,
                                      pred_len=pred_len,
                                      patch_len=patch_len,
                                      stride=stride,
                                      k_graph=k_graph,
                                      d_model=d_model)

        # -------- Linear trend stream (original code, unchanged) --------
        self.pred_len = pred_len

        # seq_len -> 4*pred_len, halved by pooling to 2*pred_len.
        self.fc5 = nn.Linear(seq_len, pred_len * 4)
        self.avgpool1 = nn.AvgPool1d(kernel_size=2)
        self.ln1 = nn.LayerNorm(pred_len * 2)

        # 2*pred_len -> pred_len, halved by pooling to pred_len // 2.
        self.fc6 = nn.Linear(pred_len * 2, pred_len)
        self.avgpool2 = nn.AvgPool1d(kernel_size=2)
        self.ln2 = nn.LayerNorm(pred_len // 2)

        # pred_len // 2 -> pred_len.
        self.fc7 = nn.Linear(pred_len // 2, pred_len)

        # Fuses the concatenated [season, trend] forecasts (2*pred_len wide).
        self.fc_final = nn.Linear(pred_len * 2, pred_len)

    # ---------------- forward --------------------
    def forward(self, s, t):
        """Combine the seasonal and trend forecasts.

        Args:
            s: Seasonal component, shape [B, L, C].
            t: Trend component; must have the same shape as ``s``.

        Returns:
            Tensor of shape [B, T, C], where T is the width produced by
            ``fc_final`` (``pred_len``).

        Raises:
            ValueError: If ``s`` is not 3-D, or if ``t`` does not have the
                same shape as ``s``.
        """
        if s.dim() != 3:
            raise ValueError(
                f"expected s with shape [B, L, C], got {tuple(s.shape)}"
            )
        # Without this check a `t` with the same number of elements but a
        # different layout (e.g. [B, C, L]) would be reshaped silently into a
        # scrambled trend input instead of failing.
        if t.shape != s.shape:
            raise ValueError(
                f"t must have the same shape as s: "
                f"got s {tuple(s.shape)} and t {tuple(t.shape)}"
            )

        # Input shape: [B, L, C]
        B, L, C = s.shape

        # ---------- Seasonality ----------
        # Expected to be [B, C, T] so it can be concatenated with the trend.
        y_season = self.season_net(s)

        # ---------- Trend (original MLP) ----------
        # Fold channels into the batch axis so every channel goes through the
        # same MLP independently. `reshape` is used because the permuted
        # tensor is not contiguous.
        t = t.permute(0, 2, 1).reshape(B * C, L)    # [B*C, L]
        t = self.fc5(t)
        # AvgPool1d receives a 2-D tensor here and pools along the last axis.
        t = self.avgpool1(t)
        t = self.ln1(t)
        t = self.fc6(t)
        t = self.avgpool2(t)
        t = self.ln2(t)
        t = self.fc7(t)                             # [B*C, T]
        y_trend = t.view(B, C, -1)                  # [B, C, T]

        # ---------- Concatenate & project ----------
        y = torch.cat([y_season, y_trend], dim=-1)  # [B, C, 2T]
        y = self.fc_final(y)                        # [B, C, T]
        y = y.permute(0, 2, 1)                      # [B, T, C]
        return y