Files
TSlib/models/xPatch_SparseChannel.py

191 lines
7.8 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
xPatch_SparseChannel model adapted for Time-Series-Library-main
Supports both long-term forecasting and classification tasks
"""
import torch
import torch.nn as nn
from layers.DECOMP import DECOMP
from layers.SeasonPatch import SeasonPatch
from layers.RevIN import RevIN
class Model(nn.Module):
    """Dual-stream xPatch SparseChannel model.

    The input is split by ``DECOMP`` into a seasonal and a trend component.
    The seasonal component is processed by ``SeasonPatch``; the trend
    component is processed by a small MLP applied to every channel
    independently. A task-specific head then merges both streams, either
    into a forecast or into class logits.
    """

    # Task names that are routed to the forecasting head.
    _FORECAST_TASKS = ('long_term_forecast', 'short_term_forecast')

    def __init__(self, configs):
        """Build all sub-modules from ``configs``.

        Args:
            configs: Namespace-like object. ``task_name``, ``seq_len``,
                ``pred_len`` and ``enc_in`` are required (plus ``num_class``
                for classification). Every other option is read with
                ``getattr`` and falls back to the default written below.
        """
        super().__init__()

        # Model configuration
        self.task_name = configs.task_name
        self.seq_len = configs.seq_len
        self.pred_len = configs.pred_len
        self.enc_in = configs.enc_in

        # Patch parameters
        self.patch_len = getattr(configs, 'patch_len', 16)
        self.stride = getattr(configs, 'stride', 8)

        # Normalization (only applied on the forecasting path)
        self.revin = getattr(configs, 'revin', True)
        if self.revin:
            self.revin_layer = RevIN(self.enc_in, affine=True, subtract_last=False)

        # Decomposition using original DECOMP with EMA/DEMA
        ma_type = getattr(configs, 'ma_type', 'ema')
        alpha = getattr(configs, 'alpha', torch.tensor(0.1))
        beta = getattr(configs, 'beta', torch.tensor(0.1))
        self.decomp = DECOMP(ma_type, alpha, beta)

        # Season network (PatchTST/Mamba2 + Graph Mixer).
        # Options are passed straight through to the newer SeasonPatch, in
        # which GraphMixer uses un-normalized Hard-Concrete gating.
        self.season_net = SeasonPatch(
            c_in=self.enc_in,
            seq_len=self.seq_len,
            pred_len=self.pred_len,
            patch_len=self.patch_len,
            stride=self.stride,
            # Encoder type: 'Transformer' or 'Mamba2'
            encoder_type=getattr(configs, 'season_encoder', 'Transformer'),
            # Patch encoder options
            d_model=getattr(configs, 'd_model', 128),
            n_layers=getattr(configs, 'e_layers', 3),
            n_heads=getattr(configs, 'n_heads', 16),
            # GraphMixer options (un-normalized variant)
            k_graph=getattr(configs, 'k_graph', 8),  # -> max_degree
            thr_graph=getattr(configs, 'thr_graph', 0.5),
            thr_graph_min=getattr(configs, 'thr_graph_min', None),
            thr_graph_max=getattr(configs, 'thr_graph_max', None),
            thr_graph_steps=getattr(configs, 'thr_graph_steps', 0),
            thr_graph_schedule=getattr(configs, 'thr_graph_schedule', 'linear'),
            symmetric_graph=getattr(configs, 'symmetric_graph', True),
            # 'none' | 'count' | 'count-sqrt' | 'sum'
            degree_rescale=getattr(configs, 'degree_rescale', 'count-sqrt'),
            gate_temperature=getattr(configs, 'gate_temperature', 2.0 / 3.0),
            tau_attn=getattr(configs, 'tau_attn', 1.0),
            l0_lambda=getattr(configs, 'season_l0_lambda', 0.0),
            # Mamba2 options
            d_state=getattr(configs, 'd_state', 64),
            d_conv=getattr(configs, 'd_conv', 4),
            expand=getattr(configs, 'expand', 2),
            headdim=getattr(configs, 'headdim', 64),
        )

        # Trend network (MLP). Feature widths along the last axis:
        # seq_len -> 4*pred_len -(pool)-> 2*pred_len -> pred_len
        #         -(pool)-> pred_len // 2 -> pred_len
        self.fc5 = nn.Linear(self.seq_len, self.pred_len * 4)
        self.avgpool1 = nn.AvgPool1d(kernel_size=2)
        self.ln1 = nn.LayerNorm(self.pred_len * 2)
        self.fc6 = nn.Linear(self.pred_len * 2, self.pred_len)
        self.avgpool2 = nn.AvgPool1d(kernel_size=2)
        self.ln2 = nn.LayerNorm(self.pred_len // 2)
        self.fc7 = nn.Linear(self.pred_len // 2, self.pred_len)

        # Task-specific heads
        if self.task_name in self._FORECAST_TASKS:
            # Fuses the concatenated [season, trend] features per channel.
            self.fc_final = nn.Linear(self.pred_len * 2, self.pred_len)
        elif self.task_name == 'classification':
            # NOTE(review): season_attention and trend_attention are built
            # here but classification() does not call them; it pools with a
            # plain softmax over each stream's own output. They are kept so
            # that state-dict keys stay compatible with existing checkpoints.
            self.season_attention = nn.Sequential(
                nn.Linear(self.pred_len, 64),
                nn.Tanh(),
                nn.Linear(64, 1),
            )
            self.trend_attention = nn.Sequential(
                nn.Linear(self.pred_len, 64),
                nn.Tanh(),
                nn.Linear(64, 1),
            )
            self.classifier = nn.Sequential(
                nn.Linear(self.enc_in * 2, 128),
                nn.ReLU(),
                nn.Dropout(getattr(configs, 'dropout', 0.1)),
                nn.Linear(128, configs.num_class),
            )

    def _trend_stream(self, trend_init):
        """Run the trend MLP on every channel independently.

        Args:
            trend_init: 3-D trend component, unpacked as [B, L, C].

        Returns:
            Tensor of shape [B, C, pred_len].
        """
        B, L, C = trend_init.shape
        # Fold channels into the batch axis so the MLP is shared by channels.
        trend = trend_init.permute(0, 2, 1).reshape(B * C, L)  # [B*C, L]
        trend = self.fc5(trend)
        trend = self.avgpool1(trend)
        trend = self.ln1(trend)
        trend = self.fc6(trend)
        trend = self.avgpool2(trend)
        trend = self.ln2(trend)
        trend = self.fc7(trend)  # [B*C, pred_len]
        return trend.view(B, C, -1)  # [B, C, pred_len]

    def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec, training):
        """Forecast ``pred_len`` future steps.

        Args:
            x_enc: Input series, unpacked downstream as [B, L, C].
            x_mark_enc: Unused; kept for interface compatibility.
            x_dec: Unused; kept for interface compatibility.
            x_mark_dec: Unused; kept for interface compatibility.
            training: Flag forwarded unchanged to the season network.

        Returns:
            Tensor of shape [B, pred_len, C].
        """
        # Normalization
        if self.revin:
            x_enc = self.revin_layer(x_enc, 'norm')

        # Decomposition
        seasonal_init, trend_init = self.decomp(x_enc)

        # Season stream (expected [B, C, pred_len]) and trend stream
        y_season = self.season_net(seasonal_init, training)
        y_trend = self._trend_stream(trend_init)  # [B, C, pred_len]

        # Combine streams
        y = torch.cat([y_season, y_trend], dim=-1)  # [B, C, 2*pred_len]
        y = self.fc_final(y)  # [B, C, pred_len]
        y = y.permute(0, 2, 1)  # [B, pred_len, C]

        # Denormalization
        if self.revin:
            y = self.revin_layer(y, 'denorm')
        return y

    def classification(self, x_enc, x_mark_enc, training=None):
        """Produce class logits for the input series.

        Args:
            x_enc: Input series, unpacked downstream as [B, L, C].
            x_mark_enc: Unused; kept for interface compatibility.
            training: Flag forwarded to the season network, as in
                ``forecast``. ``None`` (default) means "use
                ``self.training``".

        Returns:
            Logits of shape [B, num_class].
        """
        if training is None:
            training = self.training

        # Decomposition. RevIN is skipped here: per the original author's
        # note, classification usually does not need it; enable if required.
        seasonal_init, trend_init = self.decomp(x_enc)

        # Season stream (expected [B, C, pred_len]) and trend stream
        y_season = self.season_net(seasonal_init, training)
        y_trend = self._trend_stream(trend_init)  # [B, C, pred_len]

        # Softmax-weighted pooling over the last axis of each stream
        season_attn_weights = torch.softmax(y_season, dim=-1)
        season_pooled = (y_season * season_attn_weights).sum(dim=-1)  # [B, C]
        trend_attn_weights = torch.softmax(y_trend, dim=-1)
        trend_pooled = (y_trend * trend_attn_weights).sum(dim=-1)  # [B, C]

        # Combine features
        features = torch.cat([season_pooled, trend_pooled], dim=-1)  # [B, 2*C]

        # Classification
        return self.classifier(features)  # [B, num_class]

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None, training=None):
        """Dispatch to the task-specific method.

        Args:
            x_enc: Input series.
            x_mark_enc: Passed through to the task method.
            x_dec: Passed through to ``forecast``.
            x_mark_dec: Passed through to ``forecast``.
            mask: Unused; kept for interface compatibility.
            training: Flag forwarded to the season network. ``None``
                (default) resolves to ``self.training`` so that
                ``model.eval()`` is respected by callers that do not pass
                the flag; an explicit bool is used as given.

        Returns:
            [B, pred_len, C] for forecasting tasks, [B, num_class] for
            classification.

        Raises:
            ValueError: If ``task_name`` is not a supported task.
        """
        if training is None:
            training = self.training

        if self.task_name in self._FORECAST_TASKS:
            dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec, training)
            return dec_out[:, -self.pred_len:, :]  # [B, L, D]
        if self.task_name == 'classification':
            return self.classification(x_enc, x_mark_enc, training)  # [B, N]
        raise ValueError(f'Task {self.task_name} not supported by xPatch_SparseChannel')

    def reg_loss(self):
        """Return the season network's regularization loss.

        Per the original author's note this is an L0 term that is non-zero
        only when GraphMixer is enabled on the Transformer path. Intended
        use during training: ``total_loss = main_loss + model.reg_loss()``.

        Returns:
            Whatever ``season_net.reg_loss()`` returns, or a zero scalar
            tensor on the model's device if that method is unavailable.
        """
        if hasattr(self, "season_net") and hasattr(self.season_net, "reg_loss"):
            return self.season_net.reg_loss()
        return torch.tensor(0.0, device=next(self.parameters()).device)