Files
TSlib/layers/MambaSeries.py

60 lines
1.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
import torch.nn as nn
from mamba_ssm import Mamba2
class Mamba2Encoder(nn.Module):
"""
使用 Mamba2 对 patch 维度进行序列建模:
输入: [bs, nvars, patch_num, patch_len]
映射: patch_len -> d_model
建模: 在 patch_num 维度上用 Mamba2可堆叠多层
输出: [bs, nvars, d_model] (仅返回 Mamba 输出的最后一个时间步)
"""
def __init__(
    self,
    c_in,
    patch_num,
    patch_len,
    d_model=128,
    # Mamba2 hyper-parameters
    d_state=64,
    d_conv=4,
    expand=2,
    headdim=64,
    # number of stacked Mamba2 layers
    n_layers=2,
):
    """Set up the patch embedding and the stack of Mamba2 blocks.

    Args:
        c_in: number of input variables/channels (accepted for interface
            parity with sibling encoders; not used inside this module).
        patch_num: number of patches per series — the sequence length that
            the Mamba2 layers model over.
        patch_len: width of each patch (input feature size of the
            projection).
        d_model: embedding dimension produced by the linear projection.
        d_state: Mamba2 state dimension, forwarded verbatim.
        d_conv: Mamba2 local-convolution width, forwarded verbatim.
        expand: Mamba2 expansion factor, forwarded verbatim.
        headdim: Mamba2 head dimension, forwarded verbatim.
        n_layers: how many Mamba2 blocks are stacked sequentially.
    """
    super().__init__()
    self.patch_num = patch_num
    self.patch_len = patch_len
    self.d_model = d_model
    self.n_layers = n_layers

    # Linear map applied to the last axis: patch_len -> d_model.
    self.W_P = nn.Linear(patch_len, d_model)

    # Build the n_layers Mamba2 blocks one at a time, then wrap them in a
    # ModuleList so their parameters are registered with this module.
    blocks = []
    for _ in range(n_layers):
        blocks.append(
            Mamba2(
                d_model=d_model,
                d_state=d_state,
                d_conv=d_conv,
                expand=expand,
                headdim=headdim,
            )
        )
    self.mambas = nn.ModuleList(blocks)
def forward(self, x):
# x: [bs, nvars, patch_num, patch_len]
bs, n_vars, patch_num, patch_len = x.shape # bs, n_vars, patch_num, patch_len
# 1) 线性映射: patch_len -> d_model
x = self.W_P(x) # x: [bs, nvars, patch_num, d_model]
# 2) 合并 batch 与通道维度,作为 Mamba 的 batch
u = x.reshape(bs * n_vars, patch_num, self.d_model) # u: [bs*nvars, patch_num, d_model]
# 3) 通过 n_layers 层 Mamba2 进行建模(在 patch_num 维度上)