refactor(mamba): adjust Mamba2Encoder layer configuration
This commit is contained in:
@ -58,7 +58,7 @@ class SeasonPatch(nn.Module):
|
||||
self.encoder = Mamba2Encoder(
|
||||
c_in=c_in, patch_num=patch_num, patch_len=patch_len,
|
||||
d_model=d_model, d_state=d_state, d_conv=d_conv,
|
||||
expand=expand, headdim=headdim
|
||||
expand=expand, headdim=headdim, n_layers=n_layers
|
||||
)
|
||||
# Prediction head(Mamba2 路径用到,输入维度为 d_model)
|
||||
self.head = nn.Sequential(
|
||||
|
Reference in New Issue
Block a user