diff --git a/layers/MambaSeries.py b/layers/MambaSeries.py
new file mode 100644
index 0000000..6fd05c1
--- /dev/null
+++ b/layers/MambaSeries.py
@@ -0,0 +1,62 @@
+import torch
+import torch.nn as nn
+from mamba_ssm import Mamba2
+
+
+class Mamba2Encoder(nn.Module):
+    """
+    Sequence modeling over the patch dimension with Mamba2:
+    Input:   [bs, nvars, patch_num, patch_len]
+    Project: patch_len -> d_model
+    Model:   Mamba2 along the patch_num dimension
+    Output:  [bs, nvars, d_model] (only the last step of the Mamba output is returned)
+    """
+    def __init__(
+        self,
+        c_in,  # unused here; kept for interface parity with the Transformer encoder
+        patch_num,
+        patch_len,
+        d_model=128,
+        # Mamba2 hyperparameters
+        d_state=64,
+        d_conv=4,
+        expand=2,
+        headdim=64,
+    ):
+        super().__init__()
+        self.patch_num = patch_num
+        self.patch_len = patch_len
+        self.d_model = d_model
+
+        # Project patch_len to d_model
+        self.W_P = nn.Linear(patch_len, d_model)
+
+        # Model the patch sequence (length patch_num) directly with Mamba2
+        self.mamba = Mamba2(
+            d_model=d_model,
+            d_state=d_state,
+            d_conv=d_conv,
+            expand=expand,
+            headdim=headdim,
+        )
+
+    def forward(self, x):
+        # x: [bs, nvars, patch_num, patch_len]
+        bs, n_vars, patch_num, patch_len = x.shape
+
+        # 1) Linear projection: patch_len -> d_model
+        x = self.W_P(x)  # x: [bs, nvars, patch_num, d_model]
+
+        # 2) Merge batch and channel dims to form Mamba's batch
+        u = x.reshape(bs * n_vars, patch_num, self.d_model)  # u: [bs*nvars, patch_num, d_model]
+
+        # 3) Mamba2 modeling (over the patch_num dimension)
+        y = self.mamba(u)  # y: [bs*nvars, patch_num, d_model]
+
+        # 4) Take only the last time step
+        y_last = y[:, -1, :]  # y_last: [bs*nvars, d_model]
+
+        # 5) Restore (bs, nvars, d_model)
+        y_last = y_last.view(bs, n_vars, self.d_model)  # y_last: [bs, nvars, d_model]
+
+        return y_last  # [bs, nvars, d_model]
diff --git a/layers/SeasonPatch.py b/layers/SeasonPatch.py
index 0d4d6a2..d0ce2e1 100644
--- a/layers/SeasonPatch.py
+++ b/layers/SeasonPatch.py
@@ -1,11 +1,15 @@
 """
 SeasonPatch = PatchTST (CI) + ChannelGraphMixer + Linear prediction head
-Adapted for Time-Series-Library-main style
+Two encoder paths are supported:
+- Transformer path: PatchTST + GraphMixer + Head
+- Mamba2 path: Mamba2Encoder (no mixer); the final d_model state is fed straight to the head
 """
 import torch
 import torch.nn as nn
 from layers.TSTEncoder import TSTiEncoder
 from layers.GraphMixer import HierarchicalGraphMixer
+from layers.MambaSeries import Mamba2Encoder
+
 
 class SeasonPatch(nn.Module):
     def __init__(self,
@@ -17,17 +21,22 @@ class SeasonPatch(nn.Module):
                  k_graph: int = 8,
                  d_model: int = 128,
                  n_layers: int = 3,
-                 n_heads: int = 16):
+                 n_heads: int = 16,
+                 # Optional Mamba2 hyperparameters
+                 d_state: int = 64,
+                 d_conv: int = 4,
+                 expand: int = 2,
+                 headdim: int = 64):
         super().__init__()
 
         # Store patch parameters
-        self.patch_len = patch_len
-        self.stride = stride
+        self.patch_len = patch_len  # patch length
+        self.stride = stride        # patch stride
 
         # Calculate patch number
-        patch_num = (seq_len - patch_len) // stride + 1
-
-        # PatchTST encoder (channel independent)
+        patch_num = (seq_len - patch_len) // stride + 1  # patch_num: int
+
+        # Transformer (PatchTST) encoder (channel independent)
         self.encoder = TSTiEncoder(
             c_in=c_in,
             patch_num=patch_num,
@@ -36,36 +45,64 @@ class SeasonPatch(nn.Module):
             n_layers=n_layers,
             n_heads=n_heads
         )
-
-        # Cross-channel mixer
+
+        # Mamba2 encoder (channel independent); returns [B, C, d_model]
+        self.mamba_encoder = Mamba2Encoder(
+            c_in=c_in,
+            patch_num=patch_num,
+            patch_len=patch_len,
+            d_model=d_model,
+            d_state=d_state,
+            d_conv=d_conv,
+            expand=expand,
+            headdim=headdim,
+        )
+
+        # Cross-channel mixer (used by the Transformer path only)
         self.mixer = HierarchicalGraphMixer(c_in, dim=d_model, k=k_graph)
-
-        # Prediction head
-        self.head = nn.Sequential(
+
+        # Prediction head (Transformer path; input dim is patch_num * d_model)
+        self.head_tr = nn.Sequential(
             nn.Linear(patch_num * d_model, patch_num * d_model),
-            nn.SiLU(),  # non-linear activation (SiLU/Swish)
+            nn.SiLU(),
             nn.Linear(patch_num * d_model, pred_len)
         )
 
-    def forward(self, x):
-        # x: [B, L, C]
-        x = x.permute(0, 2, 1)  # → [B, C, L]
-
-        # Patch the input
-        x_patch = x.unfold(-1, self.patch_len, self.stride)  # [B, C, patch_num, patch_len]
-
-        # Encode patches
-        z = self.encoder(x_patch)  # [B, C, d_model, patch_num]
-
-        # z: [B, C, d_model, patch_num] → [B, C, patch_num, d_model]
-        B, C, D, N = z.shape
-        z = z.permute(0, 1, 3, 2)  # [B, C, patch_num, d_model]
-
-        # Cross-channel mixing
-        z_mix = self.mixer(z)  # [B, C, patch_num, d_model]
-
-        # Flatten and predict
-        z_mix = z_mix.view(B, C, N * D)  # [B, C, patch_num * d_model]
-        y_pred = self.head(z_mix)  # [B, C, pred_len]
+        # Prediction head (Mamba2 path; input dim is d_model)
+        self.head_mamba = nn.Sequential(
+            nn.Linear(d_model, d_model),
+            nn.SiLU(),
+            nn.Linear(d_model, pred_len)
+        )
 
-        return y_pred
+    def forward(self, x, encoder="Transformer"):
+        # x: [B, L, C]
+        x = x.permute(0, 2, 1)  # x: [B, C, L]
+
+        # Patch the input
+        x_patch = x.unfold(-1, self.patch_len, self.stride)  # x_patch: [B, C, patch_num, patch_len]
+
+        if encoder == "Transformer":
+            # Encode patches (PatchTST)
+            z = self.encoder(x_patch)  # z: [B, C, d_model, patch_num]
+
+            # [B, C, d_model, patch_num] -> [B, C, patch_num, d_model]
+            B, C, D, N = z.shape  # B: batch, C: channels, D: d_model, N: patch_num
+            z = z.permute(0, 1, 3, 2)  # z: [B, C, patch_num, d_model]
+
+            # Cross-channel mixing
+            z_mix = self.mixer(z)  # z_mix: [B, C, patch_num, d_model]
+
+            # Flatten and predict
+            z_mix = z_mix.view(B, C, N * D)  # z_mix: [B, C, patch_num * d_model]
+            y_pred = self.head_tr(z_mix)  # y_pred: [B, C, pred_len]
+
+        elif encoder == "Mamba2":
+            # Use the Mamba2 encoder (no mixer)
+            z_last = self.mamba_encoder(x_patch)  # z_last: [B, C, d_model] (last time step only)
+            y_pred = self.head_mamba(z_last)  # y_pred: [B, C, pred_len]
+
+        else:
+            raise ValueError(f"Unsupported encoder type: {encoder}. Use 'Transformer' or 'Mamba2'.")
+
+        return y_pred  # [B, C, pred_len]
diff --git a/models/xPatch_SparseChannel.py b/models/xPatch_SparseChannel.py
index ea3862f..af6e599 100644
--- a/models/xPatch_SparseChannel.py
+++ b/models/xPatch_SparseChannel.py
@@ -37,6 +37,8 @@ class Model(nn.Module):
             beta = getattr(configs, 'beta', torch.tensor(0.1))
         self.decomp = DECOMP(ma_type, alpha, beta)
 
+        # Read the selected encoder type ('Transformer' or 'Mamba2')
+        self.season_encoder = getattr(configs, 'season_encoder', 'Transformer')
         # Season network (PatchTST + Graph Mixer)
         self.season_net = SeasonPatch(
             c_in=self.enc_in,
@@ -90,7 +92,7 @@ class Model(nn.Module):
         seasonal_init, trend_init = self.decomp(x_enc)
 
         # Season stream
-        y_season = self.season_net(seasonal_init)  # [B, C, pred_len]
+        y_season = self.season_net(seasonal_init, encoder=self.season_encoder)  # [B, C, pred_len]
 
         # Trend stream
         B, L, C = trend_init.shape
@@ -125,7 +127,7 @@ class Model(nn.Module):
         seasonal_init, trend_init = self.decomp(x_enc)
 
         # Season stream
-        y_season = self.season_net(seasonal_init)  # [B, C, pred_len]
+        y_season = self.season_net(seasonal_init, encoder=self.season_encoder)  # [B, C, pred_len]
         # print("shape:", trend_init.shape)
 
         # Trend stream
@@ -163,4 +165,4 @@ class Model(nn.Module):
             dec_out = self.classification(x_enc, x_mark_enc)
             return dec_out  # [B, N]
         else:
-            raise ValueError(f'Task {self.task_name} not supported by xPatch_SparseChannel')
\ No newline at end of file
+            raise ValueError(f'Task {self.task_name} not supported by xPatch_SparseChannel')
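
Usage sketch (not part of the diff): a minimal shape check for the two encoder paths exposed by SeasonPatch.forward. It assumes mamba_ssm is installed and a CUDA device is available (the Mamba2 kernels are CUDA-only); the batch size, lengths, and channel count below are illustrative, and the leading constructor arguments (c_in, seq_len, pred_len, patch_len, stride) are inferred from their use in the class body.

    import torch
    from layers.SeasonPatch import SeasonPatch

    # patch_num = (96 - 16) // 8 + 1 = 11
    model = SeasonPatch(c_in=7, seq_len=96, pred_len=24,
                        patch_len=16, stride=8).cuda()
    x = torch.randn(4, 96, 7, device="cuda")  # [B, L, C]

    y_tr = model(x, encoder="Transformer")    # PatchTST + mixer path
    y_mb = model(x, encoder="Mamba2")         # Mamba2 path (mixer bypassed)
    assert y_tr.shape == y_mb.shape == (4, 7, 24)  # [B, C, pred_len]

Note that the defaults satisfy Mamba2's head constraint: d_model * expand must be divisible by headdim, and 128 * 2 / 64 gives 4 heads.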
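Because Model reads the switch via getattr(configs, 'season_encoder', 'Transformer'), existing configs keep the Transformer path unchanged by default. A hedged sketch of exposing the option in a run script follows; the flag name simply mirrors the configs attribute and the parser variable is assumed from the usual Time-Series-Library setup:

    # run.py (sketch): expose the season-stream encoder choice
    parser.add_argument('--season_encoder', type=str, default='Transformer',
                        choices=['Transformer', 'Mamba2'],
                        help='Encoder for the season stream in xPatch_SparseChannel')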