refactor(mamba): add residual connection to Mamba2Encoder layers

This commit is contained in:
gameloader
2025-09-10 21:23:17 +08:00
parent 598fdaadbc
commit 5fc0da4239

View File

@ -8,7 +8,7 @@ class Mamba2Encoder(nn.Module):
使用 Mamba2 对 patch 维度进行序列建模: 使用 Mamba2 对 patch 维度进行序列建模:
输入: [bs, nvars, patch_num, patch_len] 输入: [bs, nvars, patch_num, patch_len]
映射: patch_len -> d_model 映射: patch_len -> d_model
建模: 在 patch_num 维度上用 Mamba2可堆叠多层 建模: 在 patch_num 维度上用 Mamba2可堆叠多层,层间加残差
输出: [bs, nvars, d_model] (仅返回 Mamba 输出的最后一个时间步) 输出: [bs, nvars, d_model] (仅返回 Mamba 输出的最后一个时间步)
""" """
def __init__( def __init__(
@ -56,9 +56,9 @@ class Mamba2Encoder(nn.Module):
# 2) 合并 batch 与通道维度,作为 Mamba 的 batch # 2) 合并 batch 与通道维度,作为 Mamba 的 batch
u = x.reshape(bs * n_vars, patch_num, self.d_model) # u: [bs*nvars, patch_num, d_model] u = x.reshape(bs * n_vars, patch_num, self.d_model) # u: [bs*nvars, patch_num, d_model]
# 3) 通过 n_layers 层 Mamba2 进行建模(在 patch_num 维度上) # 3) 通过 n_layers 层 Mamba2 进行建模(在 patch_num 维度上),并加残差连接
for m in self.mambas: for m in self.mambas:
u = m(u) # 形状保持 [bs*nvars, patch_num, d_model] u = u + m(u) # 残差连接,形状保持 [bs*nvars, patch_num, d_model]
# 4) 仅取最后一个时间步 # 4) 仅取最后一个时间步
y_last = u[:, -1, :] # y_last: [bs*nvars, d_model] y_last = u[:, -1, :] # y_last: [bs*nvars, d_model]