# TSlib/models/xPatch_SparseChannel.py
# (repository snapshot: 2025-08-28 10:17:59 +00:00, 166 lines, 6.1 KiB, Python)

"""
xPatch_SparseChannel model adapted for Time-Series-Library-main
Supports both long-term forecasting and classification tasks
"""
import torch
import torch.nn as nn
from layers.DECOMP import DECOMP
from layers.SeasonPatch import SeasonPatch
from layers.RevIN import RevIN
class Model(nn.Module):
    """xPatch SparseChannel model.

    The input window is split by ``DECOMP`` into a seasonal and a trend
    component. The seasonal part is processed by ``SeasonPatch``
    (PatchTST + graph mixer), the trend part by a small per-channel MLP,
    and the two streams are fused by a task-specific head:

    * ``long_term_forecast`` / ``short_term_forecast``: the concatenated
      ``[season, trend]`` horizons (``2 * pred_len``) are mapped back to
      ``pred_len`` by a linear layer, per channel.
    * ``classification``: each stream is pooled over its horizon with
      softmax(value) weights, and an MLP maps the ``2 * enc_in`` pooled
      features to ``num_class`` logits.

    Args:
        configs: namespace-like object. Required attributes: ``task_name``,
            ``seq_len``, ``pred_len``, ``enc_in`` (plus ``num_class`` for
            classification). Optional attributes and their defaults:
            ``patch_len`` (16), ``stride`` (8), ``revin`` (True),
            ``ma_type`` ('ema'), ``alpha`` (tensor 0.1), ``beta``
            (tensor 0.1), ``k_graph`` (8), ``d_model`` (128), ``e_layers``
            (3), ``n_heads`` (16), ``dropout`` (0.1).
    """

    # Task names handled by the forecasting head.
    _FORECAST_TASKS = ('long_term_forecast', 'short_term_forecast')

    def __init__(self, configs):
        super().__init__()
        # Model configuration
        self.task_name = configs.task_name
        self.seq_len = configs.seq_len
        self.pred_len = configs.pred_len
        self.enc_in = configs.enc_in
        if self.task_name == 'classification' and self.pred_len <= 0:
            # In classification, pred_len is only the width of the internal
            # per-stream representation. With pred_len <= 0 every layer below
            # would be zero-width and the pooled features would be constant
            # zeros, so fall back to seq_len.
            # NOTE(review): TSlib's classification experiment appears to set
            # pred_len = 0; confirm against the experiment runner.
            self.pred_len = self.seq_len

        # Patching parameters
        self.patch_len = getattr(configs, 'patch_len', 16)
        self.stride = getattr(configs, 'stride', 8)

        # Normalization (only used by the forecasting path)
        self.revin = getattr(configs, 'revin', True)
        if self.revin:
            self.revin_layer = RevIN(self.enc_in, affine=True, subtract_last=False)

        # Decomposition using original DECOMP with EMA/DEMA
        ma_type = getattr(configs, 'ma_type', 'ema')
        alpha = getattr(configs, 'alpha', torch.tensor(0.1))
        beta = getattr(configs, 'beta', torch.tensor(0.1))
        self.decomp = DECOMP(ma_type, alpha, beta)

        # Season network (PatchTST + Graph Mixer); output [B, C, pred_len]
        self.season_net = SeasonPatch(
            c_in=self.enc_in,
            seq_len=self.seq_len,
            pred_len=self.pred_len,
            patch_len=self.patch_len,
            stride=self.stride,
            k_graph=getattr(configs, 'k_graph', 8),
            d_model=getattr(configs, 'd_model', 128),
            n_layers=getattr(configs, 'e_layers', 3),
            n_heads=getattr(configs, 'n_heads', 16)
        )

        # Trend network (MLP), applied to each channel independently:
        # seq_len -> 4P -> pool 2P -> P -> pool P//2 -> P   (P = pred_len)
        self.fc5 = nn.Linear(self.seq_len, self.pred_len * 4)
        self.avgpool1 = nn.AvgPool1d(kernel_size=2)
        self.ln1 = nn.LayerNorm(self.pred_len * 2)
        self.fc6 = nn.Linear(self.pred_len * 2, self.pred_len)
        self.avgpool2 = nn.AvgPool1d(kernel_size=2)
        self.ln2 = nn.LayerNorm(self.pred_len // 2)
        self.fc7 = nn.Linear(self.pred_len // 2, self.pred_len)

        # Task-specific heads
        if self.task_name in self._FORECAST_TASKS:
            # Fuses [season, trend] (2 * pred_len) back to pred_len per channel.
            self.fc_final = nn.Linear(self.pred_len * 2, self.pred_len)
        elif self.task_name == 'classification':
            # NOTE(review): season_attention / trend_attention are registered
            # (so they appear in state_dict) but classification() never calls
            # them; pooling uses softmax over the stream values instead. They
            # are kept so existing checkpoints still load. Because they
            # receive no gradients, DistributedDataParallel may require
            # find_unused_parameters=True.
            self.season_attention = nn.Sequential(
                nn.Linear(self.pred_len, 64),
                nn.Tanh(),
                nn.Linear(64, 1)
            )
            self.trend_attention = nn.Sequential(
                nn.Linear(self.pred_len, 64),
                nn.Tanh(),
                nn.Linear(64, 1)
            )
            # Input: pooled season features and pooled trend features, one
            # scalar per channel each -> 2 * enc_in.
            self.classifier = nn.Sequential(
                nn.Linear(self.enc_in * 2, 128),
                nn.ReLU(),
                nn.Dropout(getattr(configs, 'dropout', 0.1)),
                nn.Linear(128, configs.num_class)
            )

    def _trend_stream(self, trend_init):
        """Run the trend MLP on every channel independently.

        Args:
            trend_init: trend component of shape [B, L, C]. L must equal
                ``seq_len`` because ``fc5`` consumes the time axis.

        Returns:
            Tensor of shape [B, C, pred_len].
        """
        B, L, C = trend_init.shape
        # Fold channels into the batch axis so each channel is processed alone.
        trend = trend_init.permute(0, 2, 1).reshape(B * C, L)  # [B*C, L]
        # AvgPool1d treats the 2-D input as unbatched (channels, length) and
        # halves the last dimension.
        trend = self.ln1(self.avgpool1(self.fc5(trend)))  # [B*C, 2*pred_len]
        trend = self.ln2(self.avgpool2(self.fc6(trend)))  # [B*C, pred_len//2]
        trend = self.fc7(trend)  # [B*C, pred_len]
        return trend.view(B, C, -1)  # [B, C, pred_len]

    @staticmethod
    def _softmax_pool(y):
        """Pool the last axis of ``y`` using softmax(y) as the weights.

        Returns ``sum_t y[..., t] * softmax(y)[..., t]``, dropping the last axis.
        """
        return (y * torch.softmax(y, dim=-1)).sum(dim=-1)

    def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
        """Forecast ``pred_len`` future steps.

        Args:
            x_enc: input window [B, seq_len, enc_in].
            x_mark_enc, x_dec, x_mark_dec: unused; accepted for the common
                Time-Series-Library model interface.

        Returns:
            Tensor of shape [B, pred_len, enc_in].
        """
        if self.revin:
            x_enc = self.revin_layer(x_enc, 'norm')

        seasonal_init, trend_init = self.decomp(x_enc)
        y_season = self.season_net(seasonal_init)  # [B, C, pred_len]
        y_trend = self._trend_stream(trend_init)  # [B, C, pred_len]

        # Combine streams per channel.
        y = torch.cat([y_season, y_trend], dim=-1)  # [B, C, 2*pred_len]
        y = self.fc_final(y)  # [B, C, pred_len]
        y = y.permute(0, 2, 1)  # [B, pred_len, C]

        if self.revin:
            y = self.revin_layer(y, 'denorm')
        return y

    def classification(self, x_enc, x_mark_enc):
        """Classify an input window.

        Args:
            x_enc: input window [B, seq_len, enc_in].
            x_mark_enc: unused here. NOTE(review): in TSlib classification
                this is presumably the padding mask. Padded steps are
                currently not masked out; confirm this is intended.

        Returns:
            Logits of shape [B, num_class].
        """
        # RevIN normalization is intentionally not applied on this path (it
        # was disabled in the original implementation).
        seasonal_init, trend_init = self.decomp(x_enc)
        y_season = self.season_net(seasonal_init)  # [B, C, pred_len]
        y_trend = self._trend_stream(trend_init)  # [B, C, pred_len]

        # Pool each stream over its horizon (time) axis.
        season_pooled = self._softmax_pool(y_season)  # [B, C]
        trend_pooled = self._softmax_pool(y_trend)  # [B, C]

        features = torch.cat([season_pooled, trend_pooled], dim=-1)  # [B, 2*C]
        return self.classifier(features)  # [B, num_class]

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
        """Dispatch to the head matching ``task_name``.

        Raises:
            ValueError: if ``task_name`` is not a supported task.
        """
        if self.task_name in self._FORECAST_TASKS:
            dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
            return dec_out[:, -self.pred_len:, :]  # [B, pred_len, C]
        if self.task_name == 'classification':
            return self.classification(x_enc, x_mark_enc)  # [B, num_class]
        raise ValueError(f'Task {self.task_name} not supported by xPatch_SparseChannel')