feat(mamba): extract last time step from Mamba2Encoder output

This commit is contained in:
gameloader
2025-09-10 21:03:35 +08:00
parent b139f711bc
commit 598fdaadbc

View File

@ -23,7 +23,7 @@ class Mamba2Encoder(nn.Module):
expand=2, expand=2,
headdim=64, headdim=64,
# 堆叠层数 # 堆叠层数
n_layers=2, n_layers=1,
): ):
super().__init__() super().__init__()
self.patch_num = patch_num self.patch_num = patch_num
@ -57,3 +57,13 @@ class Mamba2Encoder(nn.Module):
u = x.reshape(bs * n_vars, patch_num, self.d_model) # u: [bs*nvars, patch_num, d_model] u = x.reshape(bs * n_vars, patch_num, self.d_model) # u: [bs*nvars, patch_num, d_model]
# 3) 通过 n_layers 层 Mamba2 进行建模(在 patch_num 维度上) # 3) 通过 n_layers 层 Mamba2 进行建模(在 patch_num 维度上)
for m in self.mambas:
u = m(u) # 形状保持 [bs*nvars, patch_num, d_model]
# 4) 仅取最后一个时间步
y_last = u[:, -1, :] # y_last: [bs*nvars, d_model]
# 5) 还原回 (bs, nvars, d_model)
y_last = y_last.view(bs, n_vars, self.d_model) # y_last: [bs, nvars, d_model]
return y_last # [bs, nvars, d_model]