refactor(mamba): adjust Mamba2Encoder layer configuration

This commit is contained in:
gameloader
2025-09-10 21:00:31 +08:00
parent 9787badd25
commit b139f711bc
2 changed files with 2 additions and 12 deletions

View File

@ -58,7 +58,7 @@ class SeasonPatch(nn.Module):
self.encoder = Mamba2Encoder(
c_in=c_in, patch_num=patch_num, patch_len=patch_len,
d_model=d_model, d_state=d_state, d_conv=d_conv,
expand=expand, headdim=headdim
expand=expand, headdim=headdim, n_layers=n_layers
)
# Prediction headMamba2 路径用到,输入维度为 d_model
self.head = nn.Sequential(