feat(mamba): add Mamba2 encoder option to SeasonPatch

gameloader committed 2025-09-10 15:51:28 +08:00
parent 96c40c6ab6
commit 908d3a7080
3 changed files with 138 additions and 37 deletions

layers/MambaSeries.py (new file)

@@ -0,0 +1,62 @@
import torch
import torch.nn as nn
from mamba_ssm import Mamba2


class Mamba2Encoder(nn.Module):
    """
    Sequence modeling over the patch dimension with Mamba2:
        input:   [bs, nvars, patch_num, patch_len]
        project: patch_len -> d_model
        model:   Mamba2 along the patch_num dimension
        output:  [bs, nvars, d_model] (only the last time step of the Mamba output is returned)
    """
    def __init__(
        self,
        c_in,
        patch_num,
        patch_len,
        d_model=128,
        # Mamba2 hyperparameters
        d_state=64,
        d_conv=4,
        expand=2,
        headdim=64,
    ):
        super().__init__()
        self.patch_num = patch_num
        self.patch_len = patch_len
        self.d_model = d_model

        # Project patch_len to d_model
        self.W_P = nn.Linear(patch_len, d_model)

        # Model the patch sequence (length patch_num) directly with Mamba2
        self.mamba = Mamba2(
            d_model=d_model,
            d_state=d_state,
            d_conv=d_conv,
            expand=expand,
            headdim=headdim,
        )

    def forward(self, x):
        # x: [bs, nvars, patch_num, patch_len]
        bs, n_vars, patch_num, patch_len = x.shape

        # 1) Linear projection: patch_len -> d_model
        x = self.W_P(x)  # [bs, nvars, patch_num, d_model]

        # 2) Fold batch and channel dims together as Mamba's batch
        u = x.reshape(bs * n_vars, patch_num, self.d_model)  # [bs*nvars, patch_num, d_model]

        # 3) Mamba2 modeling (along the patch_num dimension)
        y = self.mamba(u)  # [bs*nvars, patch_num, d_model]

        # 4) Keep only the last time step
        y_last = y[:, -1, :]  # [bs*nvars, d_model]

        # 5) Restore (bs, nvars, d_model)
        y_last = y_last.view(bs, n_vars, self.d_model)  # [bs, nvars, d_model]
        return y_last  # [bs, nvars, d_model]
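A quick way to sanity-check the new module is a shape-only smoke test along these lines (illustrative sizes; mamba_ssm's Mamba2 kernels generally expect a CUDA device, so a GPU is assumed):

import torch
from layers.MambaSeries import Mamba2Encoder

# 4 samples, 7 variables, 11 patches of length 16;
# note (d_model * expand) must be divisible by headdim (256 / 64 = 4 heads here)
enc = Mamba2Encoder(c_in=7, patch_num=11, patch_len=16, d_model=128).cuda()
x = torch.randn(4, 7, 11, 16, device="cuda")  # [bs, nvars, patch_num, patch_len]
y = enc(x)
print(y.shape)  # expected: torch.Size([4, 7, 128])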


@@ -1,11 +1,15 @@
 """
 SeasonPatch = PatchTST (CI) + ChannelGraphMixer + Linear prediction head
-Adapted for Time-Series-Library-main style
+Two encoders are supported:
+- Transformer path: PatchTST + GraphMixer + prediction head
+- Mamba2 path: Mamba2Encoder (no mixer); the final d_model output feeds the head directly
 """
 import torch
 import torch.nn as nn
 from layers.TSTEncoder import TSTiEncoder
 from layers.GraphMixer import HierarchicalGraphMixer
+from layers.MambaSeries import Mamba2Encoder
 
 class SeasonPatch(nn.Module):
     def __init__(self,
@@ -17,17 +21,22 @@ class SeasonPatch(nn.Module):
                  k_graph: int = 8,
                  d_model: int = 128,
                  n_layers: int = 3,
-                 n_heads: int = 16):
+                 n_heads: int = 16,
+                 # Optional Mamba2 hyperparameters
+                 d_state: int = 64,
+                 d_conv: int = 4,
+                 expand: int = 2,
+                 headdim: int = 64):
         super().__init__()
 
         # Store patch parameters
-        self.patch_len = patch_len
-        self.stride = stride
+        self.patch_len = patch_len  # patch length
+        self.stride = stride        # patch stride
 
         # Calculate patch number
-        patch_num = (seq_len - patch_len) // stride + 1
+        patch_num = (seq_len - patch_len) // stride + 1  # patch_num: int
 
-        # PatchTST encoder (channel independent)
+        # Transformer (PatchTST) encoder (channel independent)
         self.encoder = TSTiEncoder(
             c_in=c_in,
             patch_num=patch_num,
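As a quick check of the patch-count formula above (illustrative numbers):

seq_len, patch_len, stride = 96, 16, 8
patch_num = (seq_len - patch_len) // stride + 1  # (96 - 16) // 8 + 1 = 11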
@@ -36,36 +45,64 @@ class SeasonPatch(nn.Module):
             n_layers=n_layers,
             n_heads=n_heads
         )
 
-        # Cross-channel mixer
+        # Mamba2 encoder (channel independent), returns [B, C, d_model]
+        self.mamba_encoder = Mamba2Encoder(
+            c_in=c_in,
+            patch_num=patch_num,
+            patch_len=patch_len,
+            d_model=d_model,
+            d_state=d_state,
+            d_conv=d_conv,
+            expand=expand,
+            headdim=headdim,
+        )
+
+        # Cross-channel mixer (Transformer path only)
         self.mixer = HierarchicalGraphMixer(c_in, dim=d_model, k=k_graph)
 
-        # Prediction head
-        self.head = nn.Sequential(
+        # Prediction head (Transformer path; input dim is patch_num * d_model)
+        self.head_tr = nn.Sequential(
             nn.Linear(patch_num * d_model, patch_num * d_model),
-            nn.SiLU(),  # nonlinear activation (SiLU/Swish)
+            nn.SiLU(),
             nn.Linear(patch_num * d_model, pred_len)
         )
 
-    def forward(self, x):
-        # x: [B, L, C]
-        x = x.permute(0, 2, 1)  # -> [B, C, L]
-
-        # Patch the input
-        x_patch = x.unfold(-1, self.patch_len, self.stride)  # [B, C, patch_num, patch_len]
-
-        # Encode patches
-        z = self.encoder(x_patch)  # [B, C, d_model, patch_num]
-
-        # z: [B, C, d_model, patch_num] -> [B, C, patch_num, d_model]
-        B, C, D, N = z.shape
-        z = z.permute(0, 1, 3, 2)  # [B, C, patch_num, d_model]
-
-        # Cross-channel mixing
-        z_mix = self.mixer(z)  # [B, C, patch_num, d_model]
-
-        # Flatten and predict
-        z_mix = z_mix.view(B, C, N * D)  # [B, C, patch_num * d_model]
-        y_pred = self.head(z_mix)  # [B, C, pred_len]
-
-        return y_pred
+        # Prediction head (Mamba2 path; input dim is d_model)
+        self.head_mamba = nn.Sequential(
+            nn.Linear(d_model, d_model),
+            nn.SiLU(),
+            nn.Linear(d_model, pred_len)
+        )
 
+    def forward(self, x, encoder="Transformer"):
+        # x: [B, L, C]
+        x = x.permute(0, 2, 1)  # x: [B, C, L]
+
+        # Patch the input
+        x_patch = x.unfold(-1, self.patch_len, self.stride)  # x_patch: [B, C, patch_num, patch_len]
+
+        if encoder == "Transformer":
+            # Encode patches (PatchTST)
+            z = self.encoder(x_patch)  # z: [B, C, d_model, patch_num]
+
+            # [B, C, d_model, patch_num] -> [B, C, patch_num, d_model]
+            B, C, D, N = z.shape  # B: batch, C: channels, D: d_model, N: patch_num
+            z = z.permute(0, 1, 3, 2)  # z: [B, C, patch_num, d_model]
+
+            # Cross-channel mixing
+            z_mix = self.mixer(z)  # z_mix: [B, C, patch_num, d_model]
+
+            # Flatten and predict
+            z_mix = z_mix.view(B, C, N * D)  # z_mix: [B, C, patch_num * d_model]
+            y_pred = self.head_tr(z_mix)  # y_pred: [B, C, pred_len]
+        elif encoder == "Mamba2":
+            # Mamba2 encoder path (no mixer)
+            z_last = self.mamba_encoder(x_patch)  # z_last: [B, C, d_model] (last step only)
+            y_pred = self.head_mamba(z_last)  # y_pred: [B, C, pred_len]
+        else:
+            raise ValueError(f"Unsupported encoder type: {encoder}. Use 'Transformer' or 'Mamba2'.")
+
+        return y_pred  # [B, C, pred_len]
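The forward path can then be exercised end to end with a sketch like the one below; the c_in/seq_len/pred_len/patch_len/stride keyword names ahead of k_graph and the module's import path are assumptions, since neither appears in this hunk:

import torch
from layers.SeasonPatch import SeasonPatch  # hypothetical import path

model = SeasonPatch(c_in=7, seq_len=96, pred_len=24, patch_len=16, stride=8)
x = torch.randn(4, 96, 7)                # [B, L, C]
y_tr = model(x, encoder="Transformer")   # [4, 7, 24] via PatchTST + GraphMixer
y_mb = model(x, encoder="Mamba2")        # [4, 7, 24] via Mamba2 (CUDA required)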


@@ -37,6 +37,8 @@ class Model(nn.Module):
         beta = getattr(configs, 'beta', torch.tensor(0.1))
         self.decomp = DECOMP(ma_type, alpha, beta)
 
+        # Selected season-stream encoder type ('Transformer' or 'Mamba2')
+        self.season_encoder = getattr(configs, 'season_encoder', 'Transformer')
         # Season network (PatchTST + Graph Mixer)
         self.season_net = SeasonPatch(
             c_in=self.enc_in,
@@ -90,7 +92,7 @@ class Model(nn.Module):
         seasonal_init, trend_init = self.decomp(x_enc)
 
         # Season stream
-        y_season = self.season_net(seasonal_init)  # [B, C, pred_len]
+        y_season = self.season_net(seasonal_init, encoder=self.season_encoder)  # [B, C, pred_len]
 
         # Trend stream
         B, L, C = trend_init.shape
@@ -125,7 +127,7 @@ class Model(nn.Module):
         seasonal_init, trend_init = self.decomp(x_enc)
 
         # Season stream
-        y_season = self.season_net(seasonal_init)  # [B, C, pred_len]
+        y_season = self.season_net(seasonal_init, encoder=self.season_encoder)  # [B, C, pred_len]
         # print("shape:", trend_init.shape)
 
         # Trend stream
@@ -163,4 +165,4 @@ class Model(nn.Module):
             dec_out = self.classification(x_enc, x_mark_enc)
             return dec_out  # [B, N]
         else:
             raise ValueError(f'Task {self.task_name} not supported by xPatch_SparseChannel')
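Since season_encoder is read with getattr and falls back to 'Transformer', existing configs keep their current behavior; opting into the Mamba2 path is one extra field. A minimal sketch, assuming the usual xPatch_SparseChannel config object (remaining fields elided):

from types import SimpleNamespace

configs = SimpleNamespace(
    task_name='long_term_forecast',
    seq_len=96, pred_len=24, enc_in=7,
    season_encoder='Mamba2',  # or 'Transformer' (the default when the field is absent)
    # ... remaining xPatch_SparseChannel fields ...
)
model = Model(configs)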