feat(mamba): add Mamba2 encoder option to SeasonPatch
layers/MambaSeries.py (new file, 62 lines)
@@ -0,0 +1,62 @@
+import torch
+import torch.nn as nn
+from mamba_ssm import Mamba2
+
+
+class Mamba2Encoder(nn.Module):
+    """
+    Sequence modeling over the patch dimension with Mamba2:
+    Input:   [bs, nvars, patch_num, patch_len]
+    Mapping: patch_len -> d_model
+    Model:   Mamba2 over the patch_num dimension
+    Output:  [bs, nvars, d_model] (only the last Mamba time step is returned)
+    """
+    def __init__(
+        self,
+        c_in,
+        patch_num,
+        patch_len,
+        d_model=128,
+        # Mamba2 hyperparameters
+        d_state=64,
+        d_conv=4,
+        expand=2,
+        headdim=64,
+    ):
+        super().__init__()
+        self.patch_num = patch_num
+        self.patch_len = patch_len
+        self.d_model = d_model
+
+        # Project patch_len to d_model
+        self.W_P = nn.Linear(patch_len, d_model)
+
+        # Model the patch sequence (length patch_num) directly with Mamba2
+        self.mamba = Mamba2(
+            d_model=d_model,
+            d_state=d_state,
+            d_conv=d_conv,
+            expand=expand,
+            headdim=headdim,
+        )
+
+    def forward(self, x):
+        # x: [bs, nvars, patch_num, patch_len]
+        bs, n_vars, patch_num, patch_len = x.shape
+
+        # 1) Linear projection: patch_len -> d_model
+        x = self.W_P(x)  # [bs, nvars, patch_num, d_model]
+
+        # 2) Merge batch and channel dims to form Mamba's batch
+        u = x.reshape(bs * n_vars, patch_num, self.d_model)  # [bs*nvars, patch_num, d_model]
+
+        # 3) Mamba2 modeling (over the patch_num dimension)
+        y = self.mamba(u)  # [bs*nvars, patch_num, d_model]
+
+        # 4) Keep only the last time step
+        y_last = y[:, -1, :]  # [bs*nvars, d_model]
+
+        # 5) Restore (bs, nvars, d_model)
+        y_last = y_last.view(bs, n_vars, self.d_model)  # [bs, nvars, d_model]
+
+        return y_last  # [bs, nvars, d_model]
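
A quick shape check for the new encoder (a hypothetical sketch, not part of the commit; it assumes mamba_ssm is installed and a CUDA device is available, since Mamba2's fused kernels are CUDA-only, and it satisfies Mamba2's requirement that d_model * expand / headdim be an integer: here 128 * 2 / 64 = 4):

import torch
from layers.MambaSeries import Mamba2Encoder

enc = Mamba2Encoder(c_in=7, patch_num=11, patch_len=16, d_model=128).cuda()  # illustrative sizes
x = torch.randn(4, 7, 11, 16, device="cuda")  # [bs, nvars, patch_num, patch_len]
print(enc(x).shape)                           # expected: torch.Size([4, 7, 128])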

@@ -1,11 +1,15 @@
 """
 SeasonPatch = PatchTST (CI) + ChannelGraphMixer + Linear prediction head
-Adapted for Time-Series-Library-main style
+Two encoder paths are supported:
+- Transformer path: PatchTST + GraphMixer + prediction head
+- Mamba2 path: Mamba2Encoder (no mixer); its final d_model feature goes straight to the head
 """
 import torch
 import torch.nn as nn
 from layers.TSTEncoder import TSTiEncoder
 from layers.GraphMixer import HierarchicalGraphMixer
+from layers.MambaSeries import Mamba2Encoder


 class SeasonPatch(nn.Module):
     def __init__(self,
@@ -17,17 +21,22 @@ class SeasonPatch(nn.Module):
                  k_graph: int = 8,
                  d_model: int = 128,
                  n_layers: int = 3,
-                 n_heads: int = 16):
+                 n_heads: int = 16,
+                 # Optional Mamba2 hyperparameters
+                 d_state: int = 64,
+                 d_conv: int = 4,
+                 expand: int = 2,
+                 headdim: int = 64):
         super().__init__()

         # Store patch parameters
-        self.patch_len = patch_len
-        self.stride = stride
+        self.patch_len = patch_len  # patch length
+        self.stride = stride        # patch stride

         # Calculate patch number
-        patch_num = (seq_len - patch_len) // stride + 1
+        patch_num = (seq_len - patch_len) // stride + 1  # patch_num: int

-        # PatchTST encoder (channel independent)
+        # Transformer (PatchTST) encoder (channel independent)
         self.encoder = TSTiEncoder(
             c_in=c_in,
             patch_num=patch_num,
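
To make the patch count concrete: with illustrative values seq_len = 96, patch_len = 16, and stride = 8, patch_num = (96 - 16) // 8 + 1 = 11, so the Transformer head in the next hunk operates on patch_num * d_model = 11 * 128 = 1408 features per channel, while the Mamba2 head operates on only d_model = 128.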
@@ -36,36 +45,64 @@ class SeasonPatch(nn.Module):
             n_layers=n_layers,
             n_heads=n_heads
         )

-        # Cross-channel mixer
+        # Mamba2 encoder (channel independent); returns [B, C, d_model]
+        self.mamba_encoder = Mamba2Encoder(
+            c_in=c_in,
+            patch_num=patch_num,
+            patch_len=patch_len,
+            d_model=d_model,
+            d_state=d_state,
+            d_conv=d_conv,
+            expand=expand,
+            headdim=headdim,
+        )
+
+        # Cross-channel mixer (Transformer path only)
         self.mixer = HierarchicalGraphMixer(c_in, dim=d_model, k=k_graph)

-        # Prediction head
-        self.head = nn.Sequential(
+        # Prediction head (Transformer path; input width patch_num * d_model)
+        self.head_tr = nn.Sequential(
             nn.Linear(patch_num * d_model, patch_num * d_model),
-            nn.SiLU(),  # nonlinear activation (SiLU/Swish)
+            nn.SiLU(),
             nn.Linear(patch_num * d_model, pred_len)
         )

-    def forward(self, x):
+        # Prediction head (Mamba2 path; input width d_model)
+        self.head_mamba = nn.Sequential(
+            nn.Linear(d_model, d_model),
+            nn.SiLU(),
+            nn.Linear(d_model, pred_len)
+        )
+
+    def forward(self, x, encoder="Transformer"):
         # x: [B, L, C]
-        x = x.permute(0, 2, 1)  # -> [B, C, L]
+        x = x.permute(0, 2, 1)  # [B, C, L]

         # Patch the input
         x_patch = x.unfold(-1, self.patch_len, self.stride)  # [B, C, patch_num, patch_len]

-        # Encode patches
-        z = self.encoder(x_patch)  # [B, C, d_model, patch_num]
-
-        # z: [B, C, d_model, patch_num] -> [B, C, patch_num, d_model]
-        B, C, D, N = z.shape
-        z = z.permute(0, 1, 3, 2)  # [B, C, patch_num, d_model]
-
-        # Cross-channel mixing
-        z_mix = self.mixer(z)  # [B, C, patch_num, d_model]
-
-        # Flatten and predict
-        z_mix = z_mix.view(B, C, N * D)  # [B, C, patch_num * d_model]
-        y_pred = self.head(z_mix)  # [B, C, pred_len]
-
-        return y_pred
+        if encoder == "Transformer":
+            # Encode patches (PatchTST)
+            z = self.encoder(x_patch)  # [B, C, d_model, patch_num]
+
+            # [B, C, d_model, patch_num] -> [B, C, patch_num, d_model]
+            B, C, D, N = z.shape
+            z = z.permute(0, 1, 3, 2)  # [B, C, patch_num, d_model]
+
+            # Cross-channel mixing
+            z_mix = self.mixer(z)  # [B, C, patch_num, d_model]
+
+            # Flatten and predict
+            z_mix = z_mix.view(B, C, N * D)  # [B, C, patch_num * d_model]
+            y_pred = self.head_tr(z_mix)  # [B, C, pred_len]
+
+        elif encoder == "Mamba2":
+            # Mamba2 encoder path (no cross-channel mixer)
+            z_last = self.mamba_encoder(x_patch)  # [B, C, d_model] (last time step only)
+            y_pred = self.head_mamba(z_last)  # [B, C, pred_len]
+
+        else:
+            raise ValueError(f"Unsupported encoder type: {encoder}. Use 'Transformer' or 'Mamba2'.")
+
+        return y_pred  # [B, C, pred_len]
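
A minimal usage sketch of the two forward paths (hypothetical, not part of the commit; the module path and keyword names are assumed from the class shown above, and all sizes are illustrative):

import torch
from layers.SeasonPatch import SeasonPatch  # assumed file location

net = SeasonPatch(c_in=7, seq_len=96, pred_len=24, patch_len=16, stride=8).cuda()
x = torch.randn(4, 96, 7, device="cuda")  # [B, L, C]
y_tr = net(x)                             # Transformer path (default): [4, 7, 24]
y_mb = net(x, encoder="Mamba2")           # Mamba2 path: [4, 7, 24]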

@@ -37,6 +37,8 @@ class Model(nn.Module):
         beta = getattr(configs, 'beta', torch.tensor(0.1))
         self.decomp = DECOMP(ma_type, alpha, beta)

+        # Read the selected encoder type ('Transformer' or 'Mamba2')
+        self.season_encoder = getattr(configs, 'season_encoder', 'Transformer')
         # Season network (PatchTST + Graph Mixer)
         self.season_net = SeasonPatch(
             c_in=self.enc_in,
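
Because the encoder type is read with getattr, existing configs keep the Transformer path by default; opting in to Mamba2 is a one-line config change (hypothetical usage, not part of the commit):

configs.season_encoder = 'Mamba2'  # default is 'Transformer'
model = Model(configs)             # season_net is then called with encoder='Mamba2'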

@@ -90,7 +92,7 @@ class Model(nn.Module):
         seasonal_init, trend_init = self.decomp(x_enc)

         # Season stream
-        y_season = self.season_net(seasonal_init)  # [B, C, pred_len]
+        y_season = self.season_net(seasonal_init, encoder=self.season_encoder)  # [B, C, pred_len]

         # Trend stream
         B, L, C = trend_init.shape
@@ -125,7 +127,7 @@ class Model(nn.Module):
         seasonal_init, trend_init = self.decomp(x_enc)

         # Season stream
-        y_season = self.season_net(seasonal_init)  # [B, C, pred_len]
+        y_season = self.season_net(seasonal_init, encoder=self.season_encoder)  # [B, C, pred_len]

         # print("shape:", trend_init.shape)
         # Trend stream
@@ -163,4 +165,4 @@ class Model(nn.Module):
             dec_out = self.classification(x_enc, x_mark_enc)
             return dec_out  # [B, N]
         else:
             raise ValueError(f'Task {self.task_name} not supported by xPatch_SparseChannel')