refactor(graphmixer): enhance channel graph attention with ST-Gumbel
This commit is contained in:
@ -41,7 +41,11 @@ class SeasonPatch(nn.Module):
|
||||
self.mixer = HierarchicalGraphMixer(c_in, dim=d_model, k=k_graph)
|
||||
|
||||
# Prediction head
|
||||
self.head = nn.Linear(patch_num * d_model, pred_len)
|
||||
self.head = nn.Sequential(
|
||||
nn.Linear(patch_num * d_model, patch_num * d_model),
|
||||
nn.SiLU(), # 非线性激活(SiLU/Swish)
|
||||
nn.Linear(patch_num * d_model, pred_len)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
# x: [B, L, C]
|
||||
@ -64,4 +68,4 @@ class SeasonPatch(nn.Module):
|
||||
z_mix = z_mix.view(B, C, N * D) # [B, C, patch_num * d_model]
|
||||
y_pred = self.head(z_mix) # [B, C, pred_len]
|
||||
|
||||
return y_pred
|
||||
return y_pred
|
||||
|
Reference in New Issue
Block a user