feat(mamba): add Mamba2 encoder option to SeasonPatch

gameloader committed 2025-09-10 15:51:28 +08:00
parent 96c40c6ab6
commit 908d3a7080
3 changed files with 138 additions and 37 deletions


@@ -1,11 +1,15 @@
 """
 SeasonPatch = PatchTST (CI) + ChannelGraphMixer + Linear prediction head
 Adapted for Time-Series-Library-main style
+Two encoder paths are supported:
+- Transformer path: PatchTST + GraphMixer + Head
+- Mamba2 path: Mamba2Encoder (no mixer); the final d_model representation feeds the Head directly
 """
 import torch
 import torch.nn as nn
 from layers.TSTEncoder import TSTiEncoder
 from layers.GraphMixer import HierarchicalGraphMixer
+from layers.MambaSeries import Mamba2Encoder
 
 class SeasonPatch(nn.Module):
     def __init__(self,
@@ -17,17 +21,22 @@ class SeasonPatch(nn.Module):
                  k_graph: int = 8,
                  d_model: int = 128,
                  n_layers: int = 3,
-                 n_heads: int = 16):
+                 n_heads: int = 16,
+                 # Optional Mamba2-related hyperparameters
+                 d_state: int = 64,
+                 d_conv: int = 4,
+                 expand: int = 2,
+                 headdim: int = 64):
         super().__init__()
         # Store patch parameters
-        self.patch_len = patch_len
-        self.stride = stride
+        self.patch_len = patch_len  # patch length
+        self.stride = stride        # patch stride
         # Calculate patch number
-        patch_num = (seq_len - patch_len) // stride + 1
-        # PatchTST encoder (channel independent)
+        patch_num = (seq_len - patch_len) // stride + 1  # patch_num: int
+        # Transformer (PatchTST) encoder (channel independent)
         self.encoder = TSTiEncoder(
             c_in=c_in,
             patch_num=patch_num,
@@ -36,36 +45,64 @@ class SeasonPatch(nn.Module):
             n_layers=n_layers,
             n_heads=n_heads
         )
-        # Cross-channel mixer
+        # Mamba2 encoder (channel independent); returns [B, C, d_model]
+        self.mamba_encoder = Mamba2Encoder(
+            c_in=c_in,
+            patch_num=patch_num,
+            patch_len=patch_len,
+            d_model=d_model,
+            d_state=d_state,
+            d_conv=d_conv,
+            expand=expand,
+            headdim=headdim,
+        )
+        # Cross-channel mixer (used by the Transformer path only)
         self.mixer = HierarchicalGraphMixer(c_in, dim=d_model, k=k_graph)
-        # Prediction head
-        self.head = nn.Sequential(
+        # Prediction head for the Transformer path; input dim is patch_num * d_model
+        self.head_tr = nn.Sequential(
             nn.Linear(patch_num * d_model, patch_num * d_model),
-            nn.SiLU(),  # non-linear activation (SiLU/Swish)
+            nn.SiLU(),
             nn.Linear(patch_num * d_model, pred_len)
         )
+        # Prediction head for the Mamba2 path; input dim is d_model
+        self.head_mamba = nn.Sequential(
+            nn.Linear(d_model, d_model),
+            nn.SiLU(),
+            nn.Linear(d_model, pred_len)
+        )
 
-    def forward(self, x):
+    def forward(self, x, encoder="Transformer"):
         # x: [B, L, C]
-        x = x.permute(0, 2, 1)  # -> [B, C, L]
+        x = x.permute(0, 2, 1)  # x: [B, C, L]
         # Patch the input
-        x_patch = x.unfold(-1, self.patch_len, self.stride)  # [B, C, patch_num, patch_len]
-        # Encode patches
-        z = self.encoder(x_patch)  # [B, C, d_model, patch_num]
-        # z: [B, C, d_model, patch_num] -> [B, C, patch_num, d_model]
-        B, C, D, N = z.shape
-        z = z.permute(0, 1, 3, 2)  # [B, C, patch_num, d_model]
-        # Cross-channel mixing
-        z_mix = self.mixer(z)  # [B, C, patch_num, d_model]
-        # Flatten and predict
-        z_mix = z_mix.view(B, C, N * D)  # [B, C, patch_num * d_model]
-        y_pred = self.head(z_mix)  # [B, C, pred_len]
-        return y_pred
+        x_patch = x.unfold(-1, self.patch_len, self.stride)  # x_patch: [B, C, patch_num, patch_len]
+        if encoder == "Transformer":
+            # Encode patches (PatchTST)
+            z = self.encoder(x_patch)  # z: [B, C, d_model, patch_num]
+            # [B, C, d_model, patch_num] -> [B, C, patch_num, d_model]
+            B, C, D, N = z.shape  # B: batch, C: channels, D: d_model, N: patch_num
+            z = z.permute(0, 1, 3, 2)  # z: [B, C, patch_num, d_model]
+            # Cross-channel mixing
+            z_mix = self.mixer(z)  # z_mix: [B, C, patch_num, d_model]
+            # Flatten and predict
+            z_mix = z_mix.view(B, C, N * D)  # z_mix: [B, C, patch_num * d_model]
+            y_pred = self.head_tr(z_mix)  # y_pred: [B, C, pred_len]
+        elif encoder == "Mamba2":
+            # Mamba2 encoder path (no cross-channel mixer)
+            z_last = self.mamba_encoder(x_patch)  # z_last: [B, C, d_model] (last time step only)
+            y_pred = self.head_mamba(z_last)  # y_pred: [B, C, pred_len]
+        else:
+            raise ValueError(f"Unsupported encoder type: {encoder}. Use 'Transformer' or 'Mamba2'.")
+        return y_pred  # [B, C, pred_len]
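
Note: layers/MambaSeries.py is one of the other changed files and is not shown in this diff, so Mamba2Encoder is only visible here through its call site. Below is a minimal, hypothetical sketch of the interface those calls assume (patch embedding, one Mamba2 block, last-step readout); the use of the mamba_ssm package is an assumption for illustration, not the repository's actual implementation.

# Hypothetical sketch of the Mamba2Encoder interface assumed above.
# The real layers/MambaSeries.py is not part of this hunk.
import torch
import torch.nn as nn
from mamba_ssm import Mamba2  # assumed dependency; kernels require a CUDA device in practice

class Mamba2EncoderSketch(nn.Module):
    def __init__(self, c_in, patch_num, patch_len, d_model=128,
                 d_state=64, d_conv=4, expand=2, headdim=64):
        super().__init__()
        self.embed = nn.Linear(patch_len, d_model)  # patch -> token embedding
        self.mamba = Mamba2(d_model=d_model, d_state=d_state,
                            d_conv=d_conv, expand=expand, headdim=headdim)

    def forward(self, x_patch):
        # x_patch: [B, C, patch_num, patch_len], processed channel-independently
        B, C, N, P = x_patch.shape
        tokens = self.embed(x_patch.reshape(B * C, N, P))  # [B*C, N, d_model]
        out = self.mamba(tokens)                           # [B*C, N, d_model]
        return out[:, -1, :].reshape(B, C, -1)             # last step only: [B, C, d_model]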
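
For reference, a usage sketch of the two forward paths. All values are illustrative; the leading c_in/seq_len/pred_len/patch_len/stride parameter names are inferred from the __init__ body (they sit above this diff's first hunk), and the models.SeasonPatch import path is assumed.

# Hypothetical usage of the two encoder paths; values are illustrative.
import torch
from models.SeasonPatch import SeasonPatch  # assumed module path

model = SeasonPatch(c_in=7, seq_len=96, pred_len=24,
                    patch_len=16, stride=8)  # patch_num = (96 - 16) // 8 + 1 = 11
x = torch.randn(32, 96, 7)                   # x: [B=32, L=96, C=7]

y_tr = model(x, encoder="Transformer")       # PatchTST + GraphMixer + head_tr
y_mb = model(x, encoder="Mamba2")            # Mamba2Encoder + head_mamba (no mixer)
print(y_tr.shape, y_mb.shape)                # torch.Size([32, 7, 24]) for both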