diff --git a/layers/MambaSeries.py b/layers/MambaSeries.py index 6cef33a..2c5fd26 100644 --- a/layers/MambaSeries.py +++ b/layers/MambaSeries.py @@ -23,7 +23,7 @@ class Mamba2Encoder(nn.Module): expand=2, headdim=64, # 堆叠层数 - n_layers=1, + n_layers=2, ): super().__init__() self.patch_num = patch_num @@ -57,13 +57,3 @@ class Mamba2Encoder(nn.Module): u = x.reshape(bs * n_vars, patch_num, self.d_model) # u: [bs*nvars, patch_num, d_model] # 3) 通过 n_layers 层 Mamba2 进行建模(在 patch_num 维度上) - for m in self.mambas: - u = m(u) # 形状保持 [bs*nvars, patch_num, d_model] - - # 4) 仅取最后一个时间步 - y_last = u[:, -1, :] # y_last: [bs*nvars, d_model] - - # 5) 还原回 (bs, nvars, d_model) - y_last = y_last.view(bs, n_vars, self.d_model) # y_last: [bs, nvars, d_model] - - return y_last # [bs, nvars, d_model] diff --git a/layers/SeasonPatch.py b/layers/SeasonPatch.py index 5244377..1491fb5 100644 --- a/layers/SeasonPatch.py +++ b/layers/SeasonPatch.py @@ -58,7 +58,7 @@ class SeasonPatch(nn.Module): self.encoder = Mamba2Encoder( c_in=c_in, patch_num=patch_num, patch_len=patch_len, d_model=d_model, d_state=d_state, d_conv=d_conv, - expand=expand, headdim=headdim + expand=expand, headdim=headdim, n_layers=n_layers ) # Prediction head(Mamba2 路径用到,输入维度为 d_model) self.head = nn.Sequential(