"""
|
||
SeasonPatch = PatchTST (CI) + ChannelGraphMixer + Linear prediction head
|
||
支持两种编码器:
|
||
- Transformer 编码器路径:PatchTST + GraphMixer + Head
|
||
- Mamba2 编码器路径:Mamba2Encoder(不使用mixer),直接用最后得到的 d_model 走 Head
|
||
"""
|
||
import torch
import torch.nn as nn

from layers.TSTEncoder import TSTiEncoder
from layers.GraphMixer import HierarchicalGraphMixer
from layers.MambaSeries import Mamba2Encoder


class SeasonPatch(nn.Module):
    def __init__(self,
                 c_in: int,
                 seq_len: int,
                 pred_len: int,
                 patch_len: int,
                 stride: int,
                 k_graph: int = 8,
                 encoder_type: str = "Transformer",
                 d_model: int = 128,
                 n_layers: int = 3,
                 n_heads: int = 16,
                 # Optional Mamba2 hyperparameters
                 d_state: int = 64,
                 d_conv: int = 4,
                 expand: int = 2,
                 headdim: int = 64,
                 # Optional GraphMixer hyperparameters
                 thr_graph: float = 0.5,
                 thr_graph_min: float = None,
                 thr_graph_max: float = None,
                 thr_graph_steps: int = 0,
                 thr_graph_schedule: str = "linear",
                 symmetric_graph: bool = True,
                 degree_rescale: str = "count-sqrt",  # "none" | "count" | "count-sqrt" | "sum"
                 gate_temperature: float = 2./3.,
                 tau_attn: float = 1.0,
                 l0_lambda: float = 1e-4):

        super().__init__()

        # Store l0_lambda so reg_loss() can be called without errors
        self.l0_lambda = l0_lambda

        # Store patch parameters
        self.patch_len = patch_len  # patch length
        self.stride = stride        # patch stride

        # Calculate the number of patches
        patch_num = (seq_len - patch_len) // stride + 1  # patch_num: int

        self.encoder_type = encoder_type

        # Only instantiate the encoder that is actually needed
        if encoder_type == "Transformer":
            # Transformer (PatchTST) encoder, channel-independent
            self.encoder = TSTiEncoder(
                c_in=c_in, patch_num=patch_num, patch_len=patch_len,
                d_model=d_model, n_layers=n_layers, n_heads=n_heads
            )
            # HierarchicalGraphMixer for cross-channel mixing (unnormalized)
            self.mixer = HierarchicalGraphMixer(
                n_channel=c_in,
                dim=d_model,
                max_degree=k_graph,
                thr=thr_graph,
                thr_min=thr_graph_min,
                thr_max=thr_graph_max,
                thr_steps=thr_graph_steps,
                thr_schedule=thr_graph_schedule,
                temperature=gate_temperature,
                tau_attn=tau_attn,
                symmetric=symmetric_graph,
                degree_rescale=degree_rescale
            )
            # Prediction head for the Transformer path (input dim: patch_num * d_model)
            self.head = nn.Sequential(
                nn.Linear(patch_num * d_model, patch_num * d_model),
                nn.SiLU(),
                nn.Linear(patch_num * d_model, pred_len)
            )
        elif encoder_type == "Mamba2":
            # Mamba2 encoder, channel-independent; returns [B, C, d_model]
            self.encoder = Mamba2Encoder(
                c_in=c_in, patch_num=patch_num, patch_len=patch_len,
                d_model=d_model, d_state=d_state, d_conv=d_conv,
                expand=expand, headdim=headdim, n_layers=n_layers
            )
            # Prediction head for the Mamba2 path (input dim: d_model)
            self.head = nn.Sequential(
                nn.Linear(d_model, d_model),
                nn.SiLU(),
                nn.Linear(d_model, pred_len)
            )
        else:
            raise ValueError(f"Unsupported encoder_type: {encoder_type}")

    def forward(self, x):
        # x: [B, L, C]
        x = x.permute(0, 2, 1)  # x: [B, C, L]

        # Patch the input
        x_patch = x.unfold(-1, self.patch_len, self.stride)  # x_patch: [B, C, patch_num, patch_len]

        if self.encoder_type == "Transformer":
            # Encode patches (PatchTST)
            z = self.encoder(x_patch)  # z: [B, C, d_model, patch_num]

            # [B, C, d_model, patch_num] -> [B, C, patch_num, d_model]
            B, C, D, N = z.shape  # B: batch, C: channels, D: d_model, N: patch_num
            z = z.permute(0, 1, 3, 2)  # z: [B, C, patch_num, d_model]

            # Cross-channel mixing
            z_mix = self.mixer(z)  # z_mix: [B, C, patch_num, d_model]

            # Flatten and predict (reshape is safe for non-contiguous tensors after permute)
            z_mix = z_mix.reshape(B, C, N * D)  # z_mix: [B, C, patch_num * d_model]
            y_pred = self.head(z_mix)  # y_pred: [B, C, pred_len]

        elif self.encoder_type == "Mamba2":
            # Mamba2 encoder path (no mixer)
            z_last = self.encoder(x_patch)  # z_last: [B, C, d_model] (last time step only)
            y_pred = self.head(z_last)      # y_pred: [B, C, pred_len]

        return y_pred  # [B, C, pred_len]

    def reg_loss(self):
        """
        Optional: expose the L0 regularization term so it can be added to the
        total training loss.
        """
        if self.encoder_type == "Transformer" and hasattr(self, "mixer"):
            return self.mixer.l0_loss(self.l0_lambda)
        return torch.tensor(0.0, device=self.head[0].weight.device)
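

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). It assumes the layers.* modules
# imported above are available and follow the tensor shapes noted in the
# comments; the shape and hyperparameter values below are arbitrary examples.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    B, L, C, H = 4, 96, 7, 24  # batch, lookback length, channels, horizon
    x = torch.randn(B, L, C)   # input: [B, L, C]

    model = SeasonPatch(
        c_in=C, seq_len=L, pred_len=H,
        patch_len=16, stride=8,
        k_graph=3,
        encoder_type="Transformer",
    )
    y = model(x)  # output: [B, C, pred_len]

    # The L0 penalty from the graph mixer can be added to the training loss.
    target = torch.randn(B, C, H)
    loss = nn.functional.mse_loss(y, target) + model.reg_loss()
    print(y.shape, loss.item())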