64 lines
2.1 KiB
Python
64 lines
2.1 KiB
Python
import torch
|
|
from torch import nn
|
|
from .patchtst_ci import SeasonPatch # <<< 新导入
|
|
|
|
class Network(nn.Module):
    """Two-stream forecaster combining a seasonal and a trend branch.

    * trend  : the original MLP linear stream (kept unchanged).
    * season : ``SeasonPatch`` (PatchTST + Mixer).

    Each stream produces a per-channel forecast of ``pred_len`` steps; the two
    are concatenated along the time axis and fused back to ``pred_len`` steps
    by a final linear layer.
    """

    def __init__(self, seq_len, pred_len,
                 patch_len, stride, padding_patch,
                 c_in, *, k_graph=5, d_model=128):
        """Build the seasonal stream, the trend MLP and the fusion layer.

        Args:
            seq_len: Length L of the input window (``fc5`` maps L -> 4*T).
            pred_len: Forecast horizon T. Must be >= 2 because the trend
                stream halves it with ``AvgPool1d(kernel_size=2)``.
            patch_len: Patch length forwarded to ``SeasonPatch``.
            stride: Patch stride forwarded to ``SeasonPatch``.
            padding_patch: Accepted for interface compatibility only; this
                module does not use it (it is not forwarded to ``SeasonPatch``).
            c_in: Number of input channels, forwarded to ``SeasonPatch``.
            k_graph: Forwarded to ``SeasonPatch`` (formerly hard-coded to 5).
            d_model: Forwarded to ``SeasonPatch`` (formerly hard-coded to 128).

        Raises:
            ValueError: if ``pred_len < 2``. The second pooling stage would
                otherwise yield an empty feature vector and ``forward`` could
                never run.
        """
        super().__init__()

        # Validate before building any submodule so that parameter
        # initialisation order (and thus seeded RNG behaviour) is unchanged
        # for every valid configuration.
        if pred_len < 2:
            raise ValueError(
                f"pred_len must be >= 2 for the trend stream, got {pred_len}"
            )

        # -------- Seasonal stream ---------------
        self.season_net = SeasonPatch(c_in=c_in,
                                      seq_len=seq_len,
                                      pred_len=pred_len,
                                      patch_len=patch_len,
                                      stride=stride,
                                      k_graph=k_graph,
                                      d_model=d_model)

        # --------- Linear trend stream (original code, unchanged) ----------
        # Feature widths: L -> 4T -(pool)-> 2T -> T -(pool)-> T//2 -> T
        self.pred_len = pred_len
        self.fc5 = nn.Linear(seq_len, pred_len * 4)
        self.avgpool1 = nn.AvgPool1d(kernel_size=2)  # 4T -> 2T
        self.ln1 = nn.LayerNorm(pred_len * 2)

        self.fc6 = nn.Linear(pred_len * 2, pred_len)
        self.avgpool2 = nn.AvgPool1d(kernel_size=2)  # T -> T // 2 (floor)
        self.ln2 = nn.LayerNorm(pred_len // 2)

        self.fc7 = nn.Linear(pred_len // 2, pred_len)

        # Fuse the concatenated stream outputs: 2T -> T
        self.fc_final = nn.Linear(pred_len * 2, pred_len)

    # ---------------- forward --------------------
    def forward(self, s, t):
        """Forecast ``pred_len`` steps from seasonal and trend inputs.

        Args:
            s: Seasonal component, shape ``[B, L, C]``.
            t: Trend component; must have exactly the same shape as ``s``.

        Returns:
            Tensor of shape ``[B, T, C]`` with ``T == pred_len``.

        Raises:
            ValueError: if ``t.shape != s.shape``.
        """
        B, L, C = s.shape
        if t.shape != s.shape:
            # The trend reshape below borrows B, L, C from ``s``; a ``t`` with
            # a different layout but the same element count would otherwise be
            # silently scrambled instead of raising.
            raise ValueError(
                f"trend input shape {tuple(t.shape)} must match seasonal "
                f"input shape {tuple(s.shape)}"
            )

        # ---------- Seasonality ------------
        # Presumably [B, C, T]; required by the concat below — TODO confirm
        # against SeasonPatch.
        y_season = self.season_net(s)

        # ---------- Trend (original MLP) ----------
        t = t.permute(0, 2, 1).reshape(B * C, L)  # [B*C, L]: one row/channel
        t = self.fc5(t)
        t = self.avgpool1(t)
        t = self.ln1(t)
        t = self.fc6(t)
        t = self.avgpool2(t)
        t = self.ln2(t)
        t = self.fc7(t)  # [B*C, T]
        y_trend = t.view(B, C, -1)  # [B, C, T]

        # --------- Concatenate & output --------------
        y = torch.cat([y_season, y_trend], dim=-1)  # [B, C, 2T]
        y = self.fc_final(y)  # [B, C, T]
        y = y.permute(0, 2, 1)  # [B, T, C]
        return y
|
|
|