Files
TSlib/models/xPatch_SparseChannel.py

169 lines
6.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
xPatch_SparseChannel model adapted for Time-Series-Library-main
Supports both long-term forecasting and classification tasks
"""
import torch
import torch.nn as nn
from layers.DECOMP import DECOMP
from layers.SeasonPatch import SeasonPatch
from layers.RevIN import RevIN
class Model(nn.Module):
    """
    xPatch SparseChannel model.

    The input is split by ``DECOMP`` into a seasonal and a trend component.
    The seasonal component is processed by ``SeasonPatch``. The trend component
    is processed channel-by-channel by a small Linear/AvgPool/LayerNorm MLP.
    A task-specific head then combines both streams for forecasting or
    classification.
    """

    def __init__(self, configs):
        super(Model, self).__init__()
        # Model configuration
        self.task_name = configs.task_name
        self.seq_len = configs.seq_len
        self.pred_len = configs.pred_len
        self.enc_in = configs.enc_in
        # The trend MLP halves its width twice with AvgPool1d(kernel_size=2):
        # pred_len*4 -> pred_len*2, then pred_len -> pred_len//2.
        # With pred_len < 2 the second pooling has no output, so fail early
        # with a clear message instead of inside forward().
        if self.pred_len < 2:
            raise ValueError(
                f'xPatch_SparseChannel requires pred_len >= 2, got {self.pred_len}'
            )
        # Patching parameters
        self.patch_len = getattr(configs, 'patch_len', 16)
        self.stride = getattr(configs, 'stride', 8)
        # Encoder type for the season stream. It is passed both to the
        # SeasonPatch constructor and to each season_net call. The original
        # note lists 'Transformer' or 'Mamba2' as choices; the accepted values
        # are decided by SeasonPatch.
        self.season_encoder = getattr(configs, 'season_encoder', 'Transformer')

        # Normalization
        self.revin = getattr(configs, 'revin', True)
        if self.revin:
            self.revin_layer = RevIN(self.enc_in, affine=True, subtract_last=False)

        # Decomposition using the original DECOMP (EMA/DEMA moving average)
        ma_type = getattr(configs, 'ma_type', 'ema')
        alpha = getattr(configs, 'alpha', torch.tensor(0.1))
        beta = getattr(configs, 'beta', torch.tensor(0.1))
        self.decomp = DECOMP(ma_type, alpha, beta)

        # Season network (PatchTST + Graph Mixer)
        self.season_net = SeasonPatch(
            c_in=self.enc_in,
            seq_len=self.seq_len,
            pred_len=self.pred_len,
            patch_len=self.patch_len,
            stride=self.stride,
            k_graph=getattr(configs, 'k_graph', 8),
            d_model=getattr(configs, 'd_model', 128),
            n_layers=getattr(configs, 'e_layers', 3),
            n_heads=getattr(configs, 'n_heads', 16),
            encoder_type=self.season_encoder,
        )

        # Trend network (MLP). Attribute names are kept stable for checkpoints.
        self.fc5 = nn.Linear(self.seq_len, self.pred_len * 4)
        self.avgpool1 = nn.AvgPool1d(kernel_size=2)
        self.ln1 = nn.LayerNorm(self.pred_len * 2)
        self.fc6 = nn.Linear(self.pred_len * 2, self.pred_len)
        self.avgpool2 = nn.AvgPool1d(kernel_size=2)
        self.ln2 = nn.LayerNorm(self.pred_len // 2)
        self.fc7 = nn.Linear(self.pred_len // 2, self.pred_len)

        # Task-specific heads
        if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
            # Fuses the concatenated [season | trend] outputs back to pred_len.
            self.fc_final = nn.Linear(self.pred_len * 2, self.pred_len)
        elif self.task_name == 'classification':
            # NOTE(review): season_attention / trend_attention are built here,
            # but classification() does not use them. Pooling there is a
            # parameter-free softmax over time. Kept so existing checkpoints
            # still load; confirm whether they should be wired in.
            self.season_attention = nn.Sequential(
                nn.Linear(self.pred_len, 64),
                nn.Tanh(),
                nn.Linear(64, 1)
            )
            self.trend_attention = nn.Sequential(
                nn.Linear(self.pred_len, 64),
                nn.Tanh(),
                nn.Linear(64, 1)
            )
            self.classifier = nn.Sequential(
                nn.Linear(self.enc_in * 2, 128),
                nn.ReLU(),
                nn.Dropout(getattr(configs, 'dropout', 0.1)),
                nn.Linear(128, configs.num_class)
            )

    def _trend_stream(self, trend_init):
        """Run the trend component through the per-channel trend MLP.

        Args:
            trend_init: tensor of shape [B, L, C], where L must equal seq_len
                (it feeds ``fc5 = Linear(seq_len, ...)``).

        Returns:
            Tensor of shape [B, C, pred_len].
        """
        B, L, C = trend_init.shape
        # Fold channels into the batch so each channel is mapped independently.
        trend = trend_init.permute(0, 2, 1).reshape(B * C, L)  # [B*C, L]
        trend = self.fc5(trend)       # [B*C, 4*pred_len]
        trend = self.avgpool1(trend)  # [B*C, 2*pred_len]
        trend = self.ln1(trend)
        trend = self.fc6(trend)       # [B*C, pred_len]
        trend = self.avgpool2(trend)  # [B*C, pred_len//2]
        trend = self.ln2(trend)
        trend = self.fc7(trend)       # [B*C, pred_len]
        return trend.view(B, C, -1)   # [B, C, pred_len]

    @staticmethod
    def _softmax_pool(x):
        """Collapse the last dimension by weighting each value with its own softmax score.

        Returns ``sum(x * softmax(x, dim=-1), dim=-1)``, which drops the last
        dimension of ``x``.
        """
        weights = torch.softmax(x, dim=-1)
        return (x * weights).sum(dim=-1)

    def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
        """Long/short-term forecasting.

        Args:
            x_enc: encoder input, expected [B, seq_len, C] (C == enc_in).
            x_mark_enc, x_dec, x_mark_dec: unused; kept for the TSlib interface.

        Returns:
            Tensor of shape [B, pred_len, C].
        """
        # Normalization
        if self.revin:
            x_enc = self.revin_layer(x_enc, 'norm')
        # Decomposition
        seasonal_init, trend_init = self.decomp(x_enc)
        # Season stream; expected output [B, C, pred_len]
        y_season = self.season_net(seasonal_init, encoder=self.season_encoder)
        # Trend stream
        y_trend = self._trend_stream(trend_init)  # [B, C, pred_len]
        # Combine streams
        y = torch.cat([y_season, y_trend], dim=-1)  # [B, C, 2*pred_len]
        y = self.fc_final(y)                        # [B, C, pred_len]
        y = y.permute(0, 2, 1)                      # [B, pred_len, C]
        # Denormalization
        if self.revin:
            y = self.revin_layer(y, 'denorm')
        return y

    def classification(self, x_enc, x_mark_enc):
        """Classification.

        Args:
            x_enc: encoder input, expected [B, seq_len, C] (C == enc_in).
            x_mark_enc: unused; kept for the TSlib interface.

        Returns:
            Logits of shape [B, num_class].
        """
        # NOTE(review): RevIN normalization is not applied on this path
        # (it is disabled in the source). Confirm this is intended.
        # Decomposition
        seasonal_init, trend_init = self.decomp(x_enc)
        # Season stream; expected output [B, C, pred_len]
        y_season = self.season_net(seasonal_init, encoder=self.season_encoder)
        # Trend stream
        y_trend = self._trend_stream(trend_init)  # [B, C, pred_len]
        # Softmax-weighted pooling over the time (last) dimension
        season_pooled = self._softmax_pool(y_season)  # [B, C]
        trend_pooled = self._softmax_pool(y_trend)    # [B, C]
        # Combine features
        features = torch.cat([season_pooled, trend_pooled], dim=-1)  # [B, 2*C]
        # Classification
        logits = self.classifier(features)  # [B, num_class]
        return logits

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
        """Dispatch to the task-specific method selected by ``task_name``.

        Raises:
            ValueError: if ``task_name`` is not a supported task.
        """
        if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
            dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
            return dec_out[:, -self.pred_len:, :]  # [B, L, D]
        elif self.task_name == 'classification':
            dec_out = self.classification(x_enc, x_mark_enc)
            return dec_out  # [B, N]
        else:
            raise ValueError(f'Task {self.task_name} not supported by xPatch_SparseChannel')