refactor(mamba): adjust Mamba2Encoder layer configuration
This commit is contained in:
@ -23,7 +23,7 @@ class Mamba2Encoder(nn.Module):
|
|||||||
expand=2,
|
expand=2,
|
||||||
headdim=64,
|
headdim=64,
|
||||||
# 堆叠层数
|
# 堆叠层数
|
||||||
n_layers=1,
|
n_layers=2,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.patch_num = patch_num
|
self.patch_num = patch_num
|
||||||
@ -57,13 +57,3 @@ class Mamba2Encoder(nn.Module):
|
|||||||
u = x.reshape(bs * n_vars, patch_num, self.d_model) # u: [bs*nvars, patch_num, d_model]
|
u = x.reshape(bs * n_vars, patch_num, self.d_model) # u: [bs*nvars, patch_num, d_model]
|
||||||
|
|
||||||
# 3) 通过 n_layers 层 Mamba2 进行建模(在 patch_num 维度上)
|
# 3) 通过 n_layers 层 Mamba2 进行建模(在 patch_num 维度上)
|
||||||
for m in self.mambas:
|
|
||||||
u = m(u) # 形状保持 [bs*nvars, patch_num, d_model]
|
|
||||||
|
|
||||||
# 4) 仅取最后一个时间步
|
|
||||||
y_last = u[:, -1, :] # y_last: [bs*nvars, d_model]
|
|
||||||
|
|
||||||
# 5) 还原回 (bs, nvars, d_model)
|
|
||||||
y_last = y_last.view(bs, n_vars, self.d_model) # y_last: [bs, nvars, d_model]
|
|
||||||
|
|
||||||
return y_last # [bs, nvars, d_model]
|
|
||||||
|
@ -58,7 +58,7 @@ class SeasonPatch(nn.Module):
|
|||||||
self.encoder = Mamba2Encoder(
|
self.encoder = Mamba2Encoder(
|
||||||
c_in=c_in, patch_num=patch_num, patch_len=patch_len,
|
c_in=c_in, patch_num=patch_num, patch_len=patch_len,
|
||||||
d_model=d_model, d_state=d_state, d_conv=d_conv,
|
d_model=d_model, d_state=d_state, d_conv=d_conv,
|
||||||
expand=expand, headdim=headdim
|
expand=expand, headdim=headdim, n_layers=n_layers
|
||||||
)
|
)
|
||||||
# Prediction head(Mamba2 路径用到,输入维度为 d_model)
|
# Prediction head(Mamba2 路径用到,输入维度为 d_model)
|
||||||
self.head = nn.Sequential(
|
self.head = nn.Sequential(
|
||||||
|
Reference in New Issue
Block a user