"""
|
||
xPatch_SparseChannel model adapted for Time-Series-Library-main
|
||
Supports both long-term forecasting and classification tasks
|
||
"""
|
||
import torch
|
||
import torch.nn as nn
|
||
from layers.DECOMP import DECOMP
|
||
from layers.SeasonPatch import SeasonPatch
|
||
from layers.RevIN import RevIN
|
||
|
||
class Model(nn.Module):
    """
    xPatch SparseChannel model.

    The input series is decomposed into seasonal and trend components
    (DECOMP with EMA/DEMA). The seasonal part is modeled by SeasonPatch
    (PatchTST-style patching plus a channel graph mixer); the trend part is
    modeled by a small MLP. Task-specific heads support forecasting and
    classification.
    """

    def __init__(self, configs):
        super(Model, self).__init__()

        # Model configuration
        self.task_name = configs.task_name
        self.seq_len = configs.seq_len
        self.pred_len = configs.pred_len
        self.enc_in = configs.enc_in

        # Model parameters
        self.patch_len = getattr(configs, 'patch_len', 16)
        self.stride = getattr(configs, 'stride', 8)

        # Normalization
        self.revin = getattr(configs, 'revin', True)
        if self.revin:
            self.revin_layer = RevIN(self.enc_in, affine=True, subtract_last=False)

        # Decomposition using the original DECOMP block (EMA/DEMA moving averages)
        ma_type = getattr(configs, 'ma_type', 'ema')
        alpha = getattr(configs, 'alpha', torch.tensor(0.1))
        beta = getattr(configs, 'beta', torch.tensor(0.1))
        self.decomp = DECOMP(ma_type, alpha, beta)

        # Season network (PatchTST + Graph Mixer)
        # Selected encoder type ('Transformer' or 'Mamba2'); stored as an
        # attribute so the forward passes can hand it to the season network.
        self.season_encoder = getattr(configs, 'season_encoder', 'Transformer')
        self.season_net = SeasonPatch(
            c_in=self.enc_in,
            seq_len=self.seq_len,
            pred_len=self.pred_len,
            patch_len=self.patch_len,
            stride=self.stride,
            k_graph=getattr(configs, 'k_graph', 8),
            d_model=getattr(configs, 'd_model', 128),
            n_layers=getattr(configs, 'e_layers', 3),
            n_heads=getattr(configs, 'n_heads', 16),
            encoder_type=self.season_encoder
        )

        # Trend network (MLP): seq_len -> 4*pred_len -> 2*pred_len -> pred_len
        # -> pred_len//2 -> pred_len, with average pooling halving the width
        # after fc5 and fc6.
        self.fc5 = nn.Linear(self.seq_len, self.pred_len * 4)
        self.avgpool1 = nn.AvgPool1d(kernel_size=2)
        self.ln1 = nn.LayerNorm(self.pred_len * 2)
        self.fc6 = nn.Linear(self.pred_len * 2, self.pred_len)
        self.avgpool2 = nn.AvgPool1d(kernel_size=2)
        self.ln2 = nn.LayerNorm(self.pred_len // 2)
        self.fc7 = nn.Linear(self.pred_len // 2, self.pred_len)

        # Task-specific heads
        if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
            self.fc_final = nn.Linear(self.pred_len * 2, self.pred_len)
        elif self.task_name == 'classification':
            # Learned attention scorers (note: classification() below currently
            # pools with a plain softmax over time rather than these modules)
            self.season_attention = nn.Sequential(
                nn.Linear(self.pred_len, 64),
                nn.Tanh(),
                nn.Linear(64, 1)
            )
            self.trend_attention = nn.Sequential(
                nn.Linear(self.pred_len, 64),
                nn.Tanh(),
                nn.Linear(64, 1)
            )
            self.classifier = nn.Sequential(
                nn.Linear(self.enc_in * 2, 128),
                nn.ReLU(),
                nn.Dropout(getattr(configs, 'dropout', 0.1)),
                nn.Linear(128, configs.num_class)
            )

    def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
        """Long-term forecasting"""
        # Normalization
        if self.revin:
            x_enc = self.revin_layer(x_enc, 'norm')

        # Decomposition
        seasonal_init, trend_init = self.decomp(x_enc)

        # Season stream
        y_season = self.season_net(seasonal_init, encoder=self.season_encoder)  # [B, C, pred_len]

        # Trend stream
        B, L, C = trend_init.shape
        trend = trend_init.permute(0, 2, 1).reshape(B * C, L)  # [B*C, L]
        trend = self.fc5(trend)
        trend = self.avgpool1(trend)
        trend = self.ln1(trend)
        trend = self.fc6(trend)
        trend = self.avgpool2(trend)
        trend = self.ln2(trend)
        trend = self.fc7(trend)  # [B*C, pred_len]
        y_trend = trend.view(B, C, -1)  # [B, C, pred_len]

        # Combine streams
        y = torch.cat([y_season, y_trend], dim=-1)  # [B, C, 2*pred_len]
        y = self.fc_final(y)  # [B, C, pred_len]
        y = y.permute(0, 2, 1)  # [B, pred_len, C]

        # Denormalization
        if self.revin:
            y = self.revin_layer(y, 'denorm')

        return y

    def classification(self, x_enc, x_mark_enc):
        """Classification task"""
        # Normalization (disabled for classification)
        # if self.revin:
        #     x_enc = self.revin_layer(x_enc, 'norm')

        # Decomposition
        seasonal_init, trend_init = self.decomp(x_enc)

        # Season stream
        y_season = self.season_net(seasonal_init, encoder=self.season_encoder)  # [B, C, pred_len]

        # Trend stream
        B, L, C = trend_init.shape
        trend = trend_init.permute(0, 2, 1).reshape(B * C, L)  # [B*C, L]
        trend = self.fc5(trend)
        trend = self.avgpool1(trend)
        trend = self.ln1(trend)
        trend = self.fc6(trend)
        trend = self.avgpool2(trend)
        trend = self.ln2(trend)
        trend = self.fc7(trend)  # [B*C, pred_len]
        y_trend = trend.view(B, C, -1)  # [B, C, pred_len]

        # Softmax-weighted pooling over the time dimension for classification
        season_attn_weights = torch.softmax(y_season, dim=-1)
        season_pooled = (y_season * season_attn_weights).sum(dim=-1)  # [B, C]

        trend_attn_weights = torch.softmax(y_trend, dim=-1)  # softmax over the time dimension
        trend_pooled = (y_trend * trend_attn_weights).sum(dim=-1)  # [B, C]

        # Combine features
        features = torch.cat([season_pooled, trend_pooled], dim=-1)  # [B, 2*C]

        # Classification
        logits = self.classifier(features)  # [B, num_classes]
        return logits

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
        """Forward pass dispatching to task-specific methods"""
        if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
            dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
            return dec_out[:, -self.pred_len:, :]  # [B, L, D]
        elif self.task_name == 'classification':
            dec_out = self.classification(x_enc, x_mark_enc)
            return dec_out  # [B, N]
        else:
            raise ValueError(f'Task {self.task_name} not supported by xPatch_SparseChannel')
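

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the model): illustrates the config
# attributes this file reads and the expected tensor shapes. It assumes the
# sibling layers (DECOMP, SeasonPatch, RevIN) are importable from this
# repository and behave as used above; the config values below are
# illustrative, not tuned.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    configs = SimpleNamespace(
        task_name='long_term_forecast',
        seq_len=96,
        pred_len=96,
        enc_in=7,
        patch_len=16,
        stride=8,
        revin=True,
        ma_type='ema',
        alpha=torch.tensor(0.3),
        beta=torch.tensor(0.3),
        k_graph=8,
        d_model=128,
        e_layers=3,
        n_heads=16,
        season_encoder='Transformer',
        dropout=0.1,
        num_class=2,  # only read by the classification head
    )

    model = Model(configs)
    x_enc = torch.randn(4, configs.seq_len, configs.enc_in)  # [B, seq_len, C]
    y = model(x_enc, None, None, None)                       # [B, pred_len, C]
    print(y.shape)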