import torch
import torch.nn as nn
from mamba_ssm import Mamba2


class Mamba2Encoder(nn.Module):
    """
    Sequence modeling over the patch dimension with Mamba2.

    Input:      [bs, nvars, patch_num, patch_len]
    Projection: patch_len -> d_model
    Modeling:   Mamba2 over the patch_num dimension (optionally stacked)
    Output:     [bs, nvars, d_model] (only the last time step of the Mamba output is returned)
    """
    def __init__(
        self,
        c_in,
        patch_num,
        patch_len,
        d_model=128,
        # Mamba2 hyperparameters
        d_state=64,
        d_conv=4,
        expand=2,
        headdim=64,
        # number of stacked layers
        n_layers=2,
    ):
        super().__init__()
        self.patch_num = patch_num
        self.patch_len = patch_len
        self.d_model = d_model
        self.n_layers = n_layers

        # project patch_len to d_model
        self.W_P = nn.Linear(patch_len, d_model)  # maps patch_len -> d_model

        # stack n_layers Mamba2 blocks
        self.mambas = nn.ModuleList([
            Mamba2(
                d_model=d_model,
                d_state=d_state,
                d_conv=d_conv,
                expand=expand,
                headdim=headdim,
            )
            for _ in range(n_layers)
        ])

    def forward(self, x):
        # x: [bs, nvars, patch_num, patch_len]
        bs, n_vars, patch_num, patch_len = x.shape

        # 1) linear projection: patch_len -> d_model
        x = self.W_P(x)                                       # [bs, nvars, patch_num, d_model]

        # 2) merge batch and variable dimensions into the Mamba batch
        u = x.reshape(bs * n_vars, patch_num, self.d_model)   # [bs*nvars, patch_num, d_model]

        # 3) run the stacked Mamba2 layers over the patch_num dimension
        for mamba in self.mambas:
            u = mamba(u)                                      # [bs*nvars, patch_num, d_model]

        # keep only the last time step and restore the variable dimension
        out = u[:, -1, :].reshape(bs, n_vars, self.d_model)   # [bs, nvars, d_model]
        return out
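

# Minimal usage sketch (an assumption, not part of the module above): it assumes
# mamba_ssm is installed and a CUDA device is available, since the Mamba2 kernels
# run on GPU. Depending on the mamba_ssm version, (d_model * expand) / headdim may
# also need to satisfy additional divisibility constraints. Shapes follow the
# class docstring.
if __name__ == "__main__":
    bs, nvars, patch_num, patch_len = 8, 7, 12, 16
    model = Mamba2Encoder(
        c_in=nvars,
        patch_num=patch_num,
        patch_len=patch_len,
        d_model=128,
    ).cuda()

    x = torch.randn(bs, nvars, patch_num, patch_len, device="cuda")
    out = model(x)
    print(out.shape)  # expected: torch.Size([8, 7, 128])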