import torch
from torch import nn

from .patchtst_ci import SeasonPatch  # <<< new import


class Network(nn.Module):
    """Two-stream forecaster: a seasonal stream plus a linear (MLP) trend stream.

    * trend  : the original MLP linear stream (kept unchanged).
    * season : ``SeasonPatch`` (PatchTST + Mixer).

    Both streams produce a per-channel forecast of ``pred_len`` steps.
    The two forecasts are concatenated along the last axis and fused back
    to ``pred_len`` steps by a final linear layer.

    Args:
        seq_len: Input window length ``L`` (input width of ``fc5``);
            also forwarded to ``SeasonPatch``.
        pred_len: Forecast horizon ``T``. Must be >= 2, because the trend
            stream average-pools a length-``T`` sequence by a factor of 2.
        patch_len: Forwarded to ``SeasonPatch``.
        stride: Forwarded to ``SeasonPatch``.
        padding_patch: Accepted for interface compatibility but currently
            unused (it is not forwarded to ``SeasonPatch``).
        c_in: Number of channels ``C``; forwarded to ``SeasonPatch``.
        k_graph: Forwarded to ``SeasonPatch`` (previously hard-coded to 5).
        d_model: Forwarded to ``SeasonPatch`` (previously hard-coded to 128).

    Raises:
        ValueError: If ``pred_len < 2``.
    """

    def __init__(self, seq_len, pred_len, patch_len, stride, padding_patch, c_in,
                 k_graph=5, d_model=128):
        super().__init__()
        if pred_len < 2:
            # With pred_len == 1, avgpool2 would pool a length-1 sequence to
            # length 0 and fail only at forward time with an opaque error.
            raise ValueError(f"pred_len must be >= 2, got {pred_len}")

        # -------- Seasonal stream --------
        self.season_net = SeasonPatch(c_in=c_in, seq_len=seq_len, pred_len=pred_len,
                                      patch_len=patch_len, stride=stride,
                                      k_graph=k_graph, d_model=d_model)

        # -------- Linear trend stream (original code, unchanged) --------
        # Width of the last axis through the stream:
        #   L -fc5-> 4T -pool-> 2T -fc6-> T -pool-> T//2 -fc7-> T
        self.pred_len = pred_len
        self.fc5 = nn.Linear(seq_len, pred_len * 4)
        self.avgpool1 = nn.AvgPool1d(kernel_size=2)
        self.ln1 = nn.LayerNorm(pred_len * 2)
        self.fc6 = nn.Linear(pred_len * 2, pred_len)
        self.avgpool2 = nn.AvgPool1d(kernel_size=2)
        self.ln2 = nn.LayerNorm(pred_len // 2)
        self.fc7 = nn.Linear(pred_len // 2, pred_len)

        # Fuses the concatenated [season | trend] forecasts (width 2T) back to T.
        self.fc_final = nn.Linear(pred_len * 2, pred_len)

    def _trend(self, t):
        """Run the MLP trend stream on ``t`` of shape [B, L, C]; return [B, C, T]."""
        B, L, C = t.shape
        # Treat every (batch, channel) pair as an independent series of length L.
        x = t.permute(0, 2, 1).reshape(B * C, L)  # [B*C, L]
        x = self.fc5(x)       # [B*C, 4T]
        # A 2-D tensor is taken by AvgPool1d as an unbatched (C, L) input, so
        # pooling runs along the last axis: 4T -> 2T.
        x = self.avgpool1(x)  # [B*C, 2T]
        x = self.ln1(x)
        x = self.fc6(x)       # [B*C, T]
        x = self.avgpool2(x)  # [B*C, T//2]
        x = self.ln2(x)
        x = self.fc7(x)       # [B*C, T]
        return x.view(B, C, -1)  # [B, C, T]

    def forward(self, s, t):
        """Forecast from the seasonal input ``s`` and the trend input ``t``.

        Args:
            s: Seasonal component, shape [B, L, C].
            t: Trend component; must have the same shape as ``s``.

        Returns:
            Tensor of shape [B, T, C].

        Raises:
            ValueError: If ``s`` and ``t`` have different shapes. Otherwise
                ``t`` could be reshaped with ``s``'s dimensions and its
                series silently scrambled.
        """
        if s.shape != t.shape:
            raise ValueError(
                f"s and t must have the same shape, got {tuple(s.shape)} and {tuple(t.shape)}"
            )

        # ---------- Seasonality ----------
        y_season = self.season_net(s)  # expected [B, C, T]

        # ---------- Trend (original MLP) ----------
        y_trend = self._trend(t)  # [B, C, T]

        # ---------- Concatenate & fuse ----------
        y = torch.cat([y_season, y_trend], dim=-1)  # [B, C, 2T]
        y = self.fc_final(y)                        # [B, C, T]
        return y.permute(0, 2, 1)                   # [B, T, C]