# TSlib/layers/SeasonPatch.py

"""
SeasonPatch = PatchTST (CI) + ChannelGraphMixer + Linear prediction head
支持两种编码器:
- Transformer 编码器路径PatchTST + GraphMixer + Head
- Mamba2 编码器路径Mamba2Encoder不使用mixer直接用最后得到的 d_model 走 Head
"""
import torch
import torch.nn as nn
from layers.TSTEncoder import TSTiEncoder
from layers.GraphMixer import HierarchicalGraphMixer
from layers.MambaSeries import Mamba2Encoder


class SeasonPatch(nn.Module):
    def __init__(self,
                 c_in: int,
                 seq_len: int,
                 pred_len: int,
                 patch_len: int,
                 stride: int,
                 k_graph: int = 8,
                 encoder_type: str = "Transformer",
                 d_model: int = 128,
                 n_layers: int = 3,
                 n_heads: int = 16,
                 # Optional Mamba2-specific hyperparameters
                 d_state: int = 64,
                 d_conv: int = 4,
                 expand: int = 2,
                 headdim: int = 64):
        super().__init__()
        # Store patch parameters
        self.patch_len = patch_len  # patch length
        self.stride = stride        # patch stride

        # Calculate patch number
        patch_num = (seq_len - patch_len) // stride + 1  # patch_num: int
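        # e.g., with hypothetical values seq_len = 96, patch_len = 16, stride = 8:
        #       patch_num = (96 - 16) // 8 + 1 = 11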
        self.encoder_type = encoder_type

        # Only instantiate the encoder the chosen path needs
        if encoder_type == "Transformer":
            # Transformer (PatchTST) encoder, channel independent
            self.encoder = TSTiEncoder(
                c_in=c_in, patch_num=patch_num, patch_len=patch_len,
                d_model=d_model, n_layers=n_layers, n_heads=n_heads
            )
            self.mixer = HierarchicalGraphMixer(c_in, dim=d_model, k=k_graph)

            # Prediction head (Transformer path); input dim is patch_num * d_model
            self.head = nn.Sequential(
                nn.Linear(patch_num * d_model, patch_num * d_model),
                nn.SiLU(),
                nn.Linear(patch_num * d_model, pred_len)
            )
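            # e.g., with illustrative values patch_num = 11 and d_model = 128, the
            # head maps 11 * 128 = 1408 flattened features per channel to pred_len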
        elif encoder_type == "Mamba2":
            # Mamba2 encoder, channel independent; returns [B, C, d_model]
            self.encoder = Mamba2Encoder(
                c_in=c_in, patch_num=patch_num, patch_len=patch_len,
                d_model=d_model, d_state=d_state, d_conv=d_conv,
                expand=expand, headdim=headdim, n_layers=n_layers
            )

            # Prediction head (Mamba2 path); input dim is d_model
            self.head = nn.Sequential(
                nn.Linear(d_model, d_model),
                nn.SiLU(),
                nn.Linear(d_model, pred_len)
            )
        else:
            raise ValueError(f"Unsupported encoder_type: {encoder_type}")

    def forward(self, x):
        # x: [B, L, C]
        x = x.permute(0, 2, 1)  # x: [B, C, L]

        # Patch the input
        x_patch = x.unfold(-1, self.patch_len, self.stride)  # x_patch: [B, C, patch_num, patch_len]
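        # e.g., with illustrative values L = 96, patch_len = 16, stride = 8, unfold
        # yields x_patch of shape [B, C, 11, 16]: 11 overlapping patches of length 16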
        if self.encoder_type == "Transformer":
            # Encode patches (PatchTST)
            z = self.encoder(x_patch)  # z: [B, C, d_model, patch_num]

            # [B, C, d_model, patch_num] -> [B, C, patch_num, d_model]
            B, C, D, N = z.shape  # B: batch, C: channels, D: d_model, N: patch_num
            z = z.permute(0, 1, 3, 2)  # z: [B, C, patch_num, d_model]

            # Cross-channel mixing
            z_mix = self.mixer(z)  # z_mix: [B, C, patch_num, d_model]

            # Flatten and predict; reshape (rather than view) stays safe even if
            # z_mix is non-contiguous after the permute above
            z_mix = z_mix.reshape(B, C, N * D)  # z_mix: [B, C, patch_num * d_model]
            y_pred = self.head(z_mix)  # y_pred: [B, C, pred_len]
        elif self.encoder_type == "Mamba2":
            # Mamba2 encoder path (no mixer)
            z_last = self.encoder(x_patch)  # z_last: [B, C, d_model] (last time step only)
            y_pred = self.head(z_last)  # y_pred: [B, C, pred_len]
        return y_pred  # [B, C, pred_len]
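

# Minimal usage sketch (illustrative, not part of the original module): exercises
# the Transformer path on random data as a shape check. Assumes the layers
# package (TSTEncoder, GraphMixer, MambaSeries) is importable; the "Mamba2"
# path may additionally require a Mamba SSM dependency to be installed.
if __name__ == "__main__":
    B, L, C, H = 2, 96, 7, 24  # batch, seq_len, channels, pred_len
    model = SeasonPatch(c_in=C, seq_len=L, pred_len=H, patch_len=16, stride=8)
    x = torch.randn(B, L, C)
    y = model(x)
    print(y.shape)  # expected: torch.Size([2, 7, 24])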