refactor(SeasonPatch): unify encoder and head initialization

This commit is contained in:
gameloader
2025-09-10 16:00:52 +08:00
parent 908d3a7080
commit 1044c60fe7
2 changed files with 39 additions and 48 deletions

View File

@ -37,8 +37,6 @@ class Model(nn.Module):
beta = getattr(configs, 'beta', torch.tensor(0.1))
self.decomp = DECOMP(ma_type, alpha, beta)
# 读取选择的编码器类型('Transformer' 或 'Mamba2'
self.season_encoder = getattr(configs, 'season_encoder', 'Transformer')
# Season network (PatchTST + Graph Mixer)
self.season_net = SeasonPatch(
c_in=self.enc_in,
@ -49,7 +47,9 @@ class Model(nn.Module):
k_graph=getattr(configs, 'k_graph', 8),
d_model=getattr(configs, 'd_model', 128),
n_layers=getattr(configs, 'e_layers', 3),
n_heads=getattr(configs, 'n_heads', 16)
n_heads=getattr(configs, 'n_heads', 16),
# 读取选择的编码器类型('Transformer' 或 'Mamba2'
encoder_type = getattr(configs, 'season_encoder', 'Transformer')
)
# Trend network (MLP)