feat(mamba): add Mamba2 encoder option to SeasonPatch

Author: gameloader
Date:   2025-09-10 15:51:28 +08:00
Parent: 96c40c6ab6
Commit: 908d3a7080
3 changed files with 138 additions and 37 deletions

layers/MambaSeries.py (new file, 62 lines)
@@ -0,0 +1,62 @@
import torch
import torch.nn as nn
from mamba_ssm import Mamba2
class Mamba2Encoder(nn.Module):
"""
使用 Mamba2 对 patch 维度进行序列建模:
输入: [bs, nvars, patch_num, patch_len]
映射: patch_len -> d_model
建模: 在 patch_num 维度上用 Mamba2
输出: [bs, nvars, d_model] (仅返回 Mamba 输出的最后一个时间步)
"""
    def __init__(
        self,
        c_in,
        patch_num,
        patch_len,
        d_model=128,
        # Mamba2 hyperparameters
        d_state=64,
        d_conv=4,
        expand=2,
        headdim=64,
    ):
        super().__init__()
        self.patch_num = patch_num
        self.patch_len = patch_len
        self.d_model = d_model
        # Project patch_len up to d_model
        self.W_P = nn.Linear(patch_len, d_model)  # maps patch_len -> d_model
        # Model the patch sequence (length patch_num) directly with Mamba2
        self.mamba = Mamba2(
            d_model=d_model,
            d_state=d_state,
            d_conv=d_conv,
            expand=expand,
            headdim=headdim,
        )
    def forward(self, x):
        # x: [bs, nvars, patch_num, patch_len]
        bs, n_vars, patch_num, patch_len = x.shape
        # 1) Linear projection: patch_len -> d_model
        x = self.W_P(x)  # x: [bs, nvars, patch_num, d_model]
        # 2) Fold the channel dimension into the batch dimension for Mamba
        u = x.reshape(bs * n_vars, patch_num, self.d_model)  # u: [bs*nvars, patch_num, d_model]
        # 3) Sequence modeling with Mamba2 (along the patch_num dimension)
        y = self.mamba(u)  # y: [bs*nvars, patch_num, d_model]
        # 4) Keep only the last time step
        y_last = y[:, -1, :]  # y_last: [bs*nvars, d_model]
        # 5) Restore (bs, nvars, d_model)
        y_last = y_last.view(bs, n_vars, self.d_model)  # y_last: [bs, nvars, d_model]
        return y_last  # [bs, nvars, d_model]
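A minimal shape-check sketch for the new encoder (an editor's illustration, not part of the commit): it assumes mamba_ssm is installed and a CUDA device is available, since Mamba2's fused kernels generally require a GPU. Note that d_model * expand must be divisible by headdim (here 128 * 2 / 64 = 4 heads).

import torch
from layers.MambaSeries import Mamba2Encoder

bs, nvars, patch_num, patch_len = 4, 7, 11, 16                    # illustrative sizes
enc = Mamba2Encoder(c_in=nvars, patch_num=patch_num,
                    patch_len=patch_len, d_model=128).cuda()
x = torch.randn(bs, nvars, patch_num, patch_len, device="cuda")
y = enc(x)
print(y.shape)  # torch.Size([4, 7, 128]) -- one d_model vector per channel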

(second changed file; path not shown in this view)
@@ -1,11 +1,15 @@
"""
SeasonPatch = PatchTST (CI) + ChannelGraphMixer + Linear prediction head
Adapted for Time-Series-Library-main style
支持两种编码器:
- Transformer 编码器路径PatchTST + GraphMixer + Head
- Mamba2 编码器路径Mamba2Encoder不使用mixer直接用最后得到的 d_model 走 Head
"""
import torch
import torch.nn as nn
from layers.TSTEncoder import TSTiEncoder
from layers.GraphMixer import HierarchicalGraphMixer
from layers.MambaSeries import Mamba2Encoder
class SeasonPatch(nn.Module):
    def __init__(self,
@@ -17,17 +21,22 @@ class SeasonPatch(nn.Module):
                 k_graph: int = 8,
                 d_model: int = 128,
                 n_layers: int = 3,
                 n_heads: int = 16):
                 n_heads: int = 16,
                 # Optional Mamba2 hyperparameters
                 d_state: int = 64,
                 d_conv: int = 4,
                 expand: int = 2,
                 headdim: int = 64):
        super().__init__()
        # Store patch parameters
        self.patch_len = patch_len
        self.stride = stride
        self.patch_len = patch_len  # patch length
        self.stride = stride        # patch stride
        # Calculate patch number
        patch_num = (seq_len - patch_len) // stride + 1
        # PatchTST encoder (channel independent)
        patch_num = (seq_len - patch_len) // stride + 1  # patch_num: int
        # Transformer (PatchTST) encoder (channel independent)
        self.encoder = TSTiEncoder(
            c_in=c_in,
            patch_num=patch_num,
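As a sanity check on the patch-count formula above, a short sketch with illustrative values (not taken from this commit); torch.Tensor.unfold with the same window and stride yields exactly patch_num windows:

import torch

seq_len, patch_len, stride = 96, 16, 8                    # illustrative values
patch_num = (seq_len - patch_len) // stride + 1           # (96 - 16) // 8 + 1 = 11
x = torch.randn(2, 3, seq_len)                            # [B, C, L]
assert x.unfold(-1, patch_len, stride).shape == (2, 3, patch_num, patch_len)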
@@ -36,36 +45,64 @@ class SeasonPatch(nn.Module):
            n_layers=n_layers,
            n_heads=n_heads
        )
        # Cross-channel mixer
        # Mamba2 encoder (channel independent); returns [B, C, d_model]
        self.mamba_encoder = Mamba2Encoder(
            c_in=c_in,
            patch_num=patch_num,
            patch_len=patch_len,
            d_model=d_model,
            d_state=d_state,
            d_conv=d_conv,
            expand=expand,
            headdim=headdim,
        )
        # Cross-channel mixer (used only by the Transformer path)
        self.mixer = HierarchicalGraphMixer(c_in, dim=d_model, k=k_graph)
        # Prediction head
        self.head = nn.Sequential(
        # Prediction head for the Transformer path; input dim is patch_num * d_model
        self.head_tr = nn.Sequential(
            nn.Linear(patch_num * d_model, patch_num * d_model),
            nn.SiLU(),  # nonlinear activation (SiLU/Swish)
            nn.SiLU(),
            nn.Linear(patch_num * d_model, pred_len)
        )
    def forward(self, x):
        # x: [B, L, C]
        x = x.permute(0, 2, 1)  # -> [B, C, L]
        # Patch the input
        x_patch = x.unfold(-1, self.patch_len, self.stride)  # [B, C, patch_num, patch_len]
        # Encode patches
        z = self.encoder(x_patch)  # [B, C, d_model, patch_num]
        # z: [B, C, d_model, patch_num] -> [B, C, patch_num, d_model]
        B, C, D, N = z.shape
        z = z.permute(0, 1, 3, 2)  # [B, C, patch_num, d_model]
        # Cross-channel mixing
        z_mix = self.mixer(z)  # [B, C, patch_num, d_model]
        # Flatten and predict
        z_mix = z_mix.view(B, C, N * D)  # [B, C, patch_num * d_model]
        y_pred = self.head(z_mix)  # [B, C, pred_len]
        # Prediction head for the Mamba2 path; input dim is d_model
        self.head_mamba = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.SiLU(),
            nn.Linear(d_model, pred_len)
        )
        return y_pred
    def forward(self, x, encoder="Transformer"):
        # x: [B, L, C]
        x = x.permute(0, 2, 1)  # x: [B, C, L]
        # Patch the input
        x_patch = x.unfold(-1, self.patch_len, self.stride)  # x_patch: [B, C, patch_num, patch_len]
        if encoder == "Transformer":
            # Encode patches (PatchTST)
            z = self.encoder(x_patch)  # z: [B, C, d_model, patch_num]
            # [B, C, d_model, patch_num] -> [B, C, patch_num, d_model]
            B, C, D, N = z.shape  # B: batch, C: channels, D: d_model, N: patch_num
            z = z.permute(0, 1, 3, 2)  # z: [B, C, patch_num, d_model]
            # Cross-channel mixing
            z_mix = self.mixer(z)  # z_mix: [B, C, patch_num, d_model]
            # Flatten and predict
            z_mix = z_mix.view(B, C, N * D)  # z_mix: [B, C, patch_num * d_model]
            y_pred = self.head_tr(z_mix)  # y_pred: [B, C, pred_len]
        elif encoder == "Mamba2":
            # Use the Mamba2 encoder (no mixer)
            z_last = self.mamba_encoder(x_patch)  # z_last: [B, C, d_model] (last time step only)
            y_pred = self.head_mamba(z_last)  # y_pred: [B, C, pred_len]
        else:
            raise ValueError(f"Unsupported encoder type: {encoder}. Use 'Transformer' or 'Mamba2'.")
        return y_pred  # [B, C, pred_len]
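A hedged end-to-end sketch of both forward paths (an editor's illustration, not part of the commit): the positional parameters before k_graph are inferred from how the class body uses them, the import path is not shown in this diff, and the Mamba2 path assumes a CUDA device.

import torch
# from <module defining SeasonPatch> import SeasonPatch   # path not shown in the diff

model = SeasonPatch(c_in=7, seq_len=96, pred_len=24,
                    patch_len=16, stride=8).cuda()        # inferred signature
x = torch.randn(32, 96, 7, device="cuda")                 # [B, L, C]
y_tr = model(x, encoder="Transformer")                    # [32, 7, 24]
y_mb = model(x, encoder="Mamba2")                         # [32, 7, 24]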