diff --git a/layers/MambaSeries.py b/layers/MambaSeries.py index 6fd05c1..6cef33a 100644 --- a/layers/MambaSeries.py +++ b/layers/MambaSeries.py @@ -8,7 +8,7 @@ class Mamba2Encoder(nn.Module): 使用 Mamba2 对 patch 维度进行序列建模: 输入: [bs, nvars, patch_num, patch_len] 映射: patch_len -> d_model - 建模: 在 patch_num 维度上用 Mamba2 + 建模: 在 patch_num 维度上用 Mamba2(可堆叠多层) 输出: [bs, nvars, d_model] (仅返回 Mamba 输出的最后一个时间步) """ def __init__( @@ -22,23 +22,29 @@ class Mamba2Encoder(nn.Module): d_conv=4, expand=2, headdim=64, + # 堆叠层数 + n_layers=1, ): super().__init__() self.patch_num = patch_num self.patch_len = patch_len self.d_model = d_model + self.n_layers = n_layers # 将 patch_len 投影到 d_model self.W_P = nn.Linear(patch_len, d_model) # 映射 patch_len -> d_model - # 直接使用 Mamba2 对序列 (patch_num) 建模 - self.mamba = Mamba2( - d_model=d_model, - d_state=d_state, - d_conv=d_conv, - expand=expand, - headdim=headdim, - ) + # 堆叠 n_layers 层 Mamba2 + self.mambas = nn.ModuleList([ + Mamba2( + d_model=d_model, + d_state=d_state, + d_conv=d_conv, + expand=expand, + headdim=headdim, + ) + for _ in range(n_layers) + ]) def forward(self, x): # x: [bs, nvars, patch_num, patch_len] @@ -50,11 +56,12 @@ class Mamba2Encoder(nn.Module): # 2) 合并 batch 与通道维度,作为 Mamba 的 batch u = x.reshape(bs * n_vars, patch_num, self.d_model) # u: [bs*nvars, patch_num, d_model] - # 3) Mamba2 建模(在 patch_num 维度上) - y = self.mamba(u) # y: [bs*nvars, patch_num, d_model] + # 3) 通过 n_layers 层 Mamba2 进行建模(在 patch_num 维度上) + for m in self.mambas: + u = m(u) # 形状保持 [bs*nvars, patch_num, d_model] # 4) 仅取最后一个时间步 - y_last = y[:, -1, :] # y_last: [bs*nvars, d_model] + y_last = u[:, -1, :] # y_last: [bs*nvars, d_model] # 5) 还原回 (bs, nvars, d_model) y_last = y_last.view(bs, n_vars, self.d_model) # y_last: [bs, nvars, d_model]