feat(mamba): add Mamba2 encoder option to SeasonPatch

This commit is contained in:
gameloader
2025-09-10 15:51:28 +08:00
parent 96c40c6ab6
commit 908d3a7080
3 changed files with 138 additions and 37 deletions

62
layers/MambaSeries.py Normal file
View File

@ -0,0 +1,62 @@
import torch
import torch.nn as nn
from mamba_ssm import Mamba2
class Mamba2Encoder(nn.Module):
"""
使用 Mamba2 对 patch 维度进行序列建模:
输入: [bs, nvars, patch_num, patch_len]
映射: patch_len -> d_model
建模: 在 patch_num 维度上用 Mamba2
输出: [bs, nvars, d_model] (仅返回 Mamba 输出的最后一个时间步)
"""
def __init__(
self,
c_in,
patch_num,
patch_len,
d_model=128,
# Mamba2 超参
d_state=64,
d_conv=4,
expand=2,
headdim=64,
):
super().__init__()
self.patch_num = patch_num
self.patch_len = patch_len
self.d_model = d_model
# 将 patch_len 投影到 d_model
self.W_P = nn.Linear(patch_len, d_model) # 映射 patch_len -> d_model
# 直接使用 Mamba2 对序列 (patch_num) 建模
self.mamba = Mamba2(
d_model=d_model,
d_state=d_state,
d_conv=d_conv,
expand=expand,
headdim=headdim,
)
def forward(self, x):
# x: [bs, nvars, patch_num, patch_len]
bs, n_vars, patch_num, patch_len = x.shape # bs, n_vars, patch_num, patch_len
# 1) 线性映射: patch_len -> d_model
x = self.W_P(x) # x: [bs, nvars, patch_num, d_model]
# 2) 合并 batch 与通道维度,作为 Mamba 的 batch
u = x.reshape(bs * n_vars, patch_num, self.d_model) # u: [bs*nvars, patch_num, d_model]
# 3) Mamba2 建模(在 patch_num 维度上)
y = self.mamba(u) # y: [bs*nvars, patch_num, d_model]
# 4) 仅取最后一个时间步
y_last = y[:, -1, :] # y_last: [bs*nvars, d_model]
# 5) 还原回 (bs, nvars, d_model)
y_last = y_last.view(bs, n_vars, self.d_model) # y_last: [bs, nvars, d_model]
return y_last # [bs, nvars, d_model]