refactor(SeasonPatch): unify encoder and head initialization
This commit is contained in:
@ -37,8 +37,6 @@ class Model(nn.Module):
|
||||
beta = getattr(configs, 'beta', torch.tensor(0.1))
|
||||
self.decomp = DECOMP(ma_type, alpha, beta)
|
||||
|
||||
# 读取选择的编码器类型('Transformer' 或 'Mamba2')
|
||||
self.season_encoder = getattr(configs, 'season_encoder', 'Transformer')
|
||||
# Season network (PatchTST + Graph Mixer)
|
||||
self.season_net = SeasonPatch(
|
||||
c_in=self.enc_in,
|
||||
@ -49,7 +47,9 @@ class Model(nn.Module):
|
||||
k_graph=getattr(configs, 'k_graph', 8),
|
||||
d_model=getattr(configs, 'd_model', 128),
|
||||
n_layers=getattr(configs, 'e_layers', 3),
|
||||
n_heads=getattr(configs, 'n_heads', 16)
|
||||
n_heads=getattr(configs, 'n_heads', 16),
|
||||
# 读取选择的编码器类型('Transformer' 或 'Mamba2')
|
||||
encoder_type = getattr(configs, 'season_encoder', 'Transformer')
|
||||
)
|
||||
|
||||
# Trend network (MLP)
|
||||
|
Reference in New Issue
Block a user