feat(mambaseries): allow stacking multiple Mamba2 layers
@@ -8,7 +8,7 @@ class Mamba2Encoder(nn.Module):
     Sequence modeling over the patch dimension with Mamba2:
     Input:    [bs, nvars, patch_num, patch_len]
     Mapping:  patch_len -> d_model
-    Modeling: Mamba2 over the patch_num dimension
+    Modeling: Mamba2 over the patch_num dimension (multiple layers can be stacked)
     Output:   [bs, nvars, d_model] (only the last time step of the Mamba output is returned)
     """

     def __init__(
@@ -22,23 +22,29 @@ class Mamba2Encoder(nn.Module):
         d_conv=4,
         expand=2,
         headdim=64,
+        # number of stacked Mamba2 layers
+        n_layers=1,
     ):
         super().__init__()
         self.patch_num = patch_num
         self.patch_len = patch_len
         self.d_model = d_model
+        self.n_layers = n_layers

         # project patch_len to d_model
         self.W_P = nn.Linear(patch_len, d_model)  # maps patch_len -> d_model

-        # model the sequence (patch_num) directly with a single Mamba2
-        self.mamba = Mamba2(
-            d_model=d_model,
-            d_state=d_state,
-            d_conv=d_conv,
-            expand=expand,
-            headdim=headdim,
-        )
+        # stack n_layers Mamba2 layers
+        self.mambas = nn.ModuleList([
+            Mamba2(
+                d_model=d_model,
+                d_state=d_state,
+                d_conv=d_conv,
+                expand=expand,
+                headdim=headdim,
+            )
+            for _ in range(n_layers)
+        ])

     def forward(self, x):
         # x: [bs, nvars, patch_num, patch_len]
@@ -50,11 +56,12 @@ class Mamba2Encoder(nn.Module):
         # 2) merge the batch and channel dims to serve as the Mamba batch
         u = x.reshape(bs * n_vars, patch_num, self.d_model)  # u: [bs*nvars, patch_num, d_model]

-        # 3) Mamba2 modeling (over the patch_num dimension)
-        y = self.mamba(u)  # y: [bs*nvars, patch_num, d_model]
+        # 3) run the n_layers stacked Mamba2 layers (over the patch_num dimension)
+        for m in self.mambas:
+            u = m(u)  # shape stays [bs*nvars, patch_num, d_model]

         # 4) keep only the last time step
-        y_last = y[:, -1, :]  # y_last: [bs*nvars, d_model]
+        y_last = u[:, -1, :]  # y_last: [bs*nvars, d_model]

         # 5) reshape back to (bs, nvars, d_model)
         y_last = y_last.view(bs, n_vars, self.d_model)  # y_last: [bs, nvars, d_model]
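
A minimal sketch of what the stacked encoder does end to end. FakeMamba2 is a hypothetical stand-in (a plain linear layer) used only so the snippet runs without mamba_ssm or a GPU; like the real Mamba2 block it maps [batch, seq_len, d_model] to [batch, seq_len, d_model], which is what allows the forward loop to feed each layer's output straight into the next.

import torch
import torch.nn as nn

# Hypothetical stand-in for Mamba2 so this sketch runs without mamba_ssm / CUDA.
# Like the real layer, it maps [batch, seq_len, d_model] -> [batch, seq_len, d_model].
class FakeMamba2(nn.Module):
    def __init__(self, d_model, **kwargs):
        super().__init__()
        self.proj = nn.Linear(d_model, d_model)

    def forward(self, u):
        return self.proj(u)

bs, n_vars, patch_num, d_model, n_layers = 8, 7, 16, 128, 3

# Same stacking pattern as in __init__ above.
mambas = nn.ModuleList([FakeMamba2(d_model=d_model) for _ in range(n_layers)])

# Same flow as forward(): merged batch, layer loop, last time step, reshape.
u = torch.randn(bs * n_vars, patch_num, d_model)  # [bs*nvars, patch_num, d_model]
for m in mambas:
    u = m(u)                                       # shape preserved by every layer
y_last = u[:, -1, :]                               # [bs*nvars, d_model]
y_last = y_last.view(bs, n_vars, d_model)          # [bs, nvars, d_model]
print(y_last.shape)                                # torch.Size([8, 7, 128])

With n_layers=1 the loop applies a single Mamba2, so the output matches the previous single self.mamba call functionally.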