refactor(graphmixer): enhance channel graph attention with ST-Gumbel

This commit is contained in:
game-loader
2025-09-06 00:06:26 +08:00
parent ef307a57e9
commit 9f7fb24beb
3 changed files with 165 additions and 111 deletions

View File

@ -41,7 +41,11 @@ class SeasonPatch(nn.Module):
self.mixer = HierarchicalGraphMixer(c_in, dim=d_model, k=k_graph)
# Prediction head
self.head = nn.Linear(patch_num * d_model, pred_len)
self.head = nn.Sequential(
nn.Linear(patch_num * d_model, patch_num * d_model),
nn.SiLU(), # 非线性激活SiLU/Swish
nn.Linear(patch_num * d_model, pred_len)
)
def forward(self, x):
# x: [B, L, C]
@ -64,4 +68,4 @@ class SeasonPatch(nn.Module):
z_mix = z_mix.view(B, C, N * D) # [B, C, patch_num * d_model]
y_pred = self.head(z_mix) # [B, C, pred_len]
return y_pred
return y_pred