166 lines
6.1 KiB
Python
166 lines
6.1 KiB
Python
"""
xPatch_SparseChannel model adapted for Time-Series-Library.

Supports long-term forecasting, short-term forecasting, and classification tasks.
"""
|
|
import torch
|
|
import torch.nn as nn
|
|
from layers.DECOMP import DECOMP
|
|
from layers.SeasonPatch import SeasonPatch
|
|
from layers.RevIN import RevIN
|
|
|
|
class Model(nn.Module):
    """
    xPatch SparseChannel model.

    The input is split by ``DECOMP`` into a seasonal and a trend component.
    The seasonal component is processed by ``SeasonPatch`` and the trend
    component by a small MLP shared across channels. The two streams are then
    either fused by a linear layer (forecasting) or pooled over the time axis
    and passed to an MLP classifier (classification).
    """

    # Task names routed to the forecasting head / ``forecast()``.
    _FORECAST_TASKS = ('long_term_forecast', 'short_term_forecast')

    def __init__(self, configs):
        """
        Build the model from a ``configs`` namespace.

        Required attributes: ``task_name``, ``seq_len``, ``pred_len``,
        ``enc_in`` (plus ``num_class`` for classification). Optional
        attributes, read with ``getattr`` defaults: ``patch_len``, ``stride``,
        ``revin``, ``ma_type``, ``alpha``, ``beta``, ``k_graph``, ``d_model``,
        ``e_layers``, ``n_heads``, ``dropout``.

        Raises:
            ValueError: if ``pred_len < 2``. The trend MLP shrinks its width to
                ``pred_len`` and then average-pools it with kernel 2, so a
                smaller horizon would otherwise only fail later inside
                ``AvgPool1d`` during the forward pass.
        """
        super(Model, self).__init__()

        # Model configuration
        self.task_name = configs.task_name
        self.seq_len = configs.seq_len
        self.pred_len = configs.pred_len
        self.enc_in = configs.enc_in
        if self.pred_len < 2:
            raise ValueError(
                f'pred_len must be >= 2 for the xPatch_SparseChannel trend MLP, '
                f'got {self.pred_len}'
            )

        # Patching parameters forwarded to the season network
        self.patch_len = getattr(configs, 'patch_len', 16)
        self.stride = getattr(configs, 'stride', 8)

        # Optional RevIN normalization; only forecast() applies it.
        self.revin = getattr(configs, 'revin', True)
        if self.revin:
            self.revin_layer = RevIN(self.enc_in, affine=True, subtract_last=False)

        # Decomposition using original DECOMP with EMA/DEMA
        ma_type = getattr(configs, 'ma_type', 'ema')
        alpha = getattr(configs, 'alpha', torch.tensor(0.1))
        beta = getattr(configs, 'beta', torch.tensor(0.1))
        self.decomp = DECOMP(ma_type, alpha, beta)

        # Season network (PatchTST + Graph Mixer)
        self.season_net = SeasonPatch(
            c_in=self.enc_in,
            seq_len=self.seq_len,
            pred_len=self.pred_len,
            patch_len=self.patch_len,
            stride=self.stride,
            k_graph=getattr(configs, 'k_graph', 8),
            d_model=getattr(configs, 'd_model', 128),
            n_layers=getattr(configs, 'e_layers', 3),
            n_heads=getattr(configs, 'n_heads', 16),
        )

        # Trend network (MLP). Widths along the last axis:
        # seq_len -> 4p -> pool 2p -> p -> pool p//2 -> p   (p = pred_len)
        self.fc5 = nn.Linear(self.seq_len, self.pred_len * 4)
        self.avgpool1 = nn.AvgPool1d(kernel_size=2)
        self.ln1 = nn.LayerNorm(self.pred_len * 2)
        self.fc6 = nn.Linear(self.pred_len * 2, self.pred_len)
        self.avgpool2 = nn.AvgPool1d(kernel_size=2)
        self.ln2 = nn.LayerNorm(self.pred_len // 2)
        self.fc7 = nn.Linear(self.pred_len // 2, self.pred_len)

        # Task-specific heads
        if self.task_name in self._FORECAST_TASKS:
            # Fuses the concatenated [season | trend] outputs back to pred_len.
            self.fc_final = nn.Linear(self.pred_len * 2, self.pred_len)
        elif self.task_name == 'classification':
            # NOTE(review): season_attention / trend_attention are not referenced
            # by classification(). They are kept so parameter names and
            # state_dict keys stay unchanged for existing checkpoints. Unused
            # parameters may need find_unused_parameters=True under DDP.
            self.season_attention = nn.Sequential(
                nn.Linear(self.pred_len, 64),
                nn.Tanh(),
                nn.Linear(64, 1),
            )
            self.trend_attention = nn.Sequential(
                nn.Linear(self.pred_len, 64),
                nn.Tanh(),
                nn.Linear(64, 1),
            )
            # Input is the concatenation of one pooled season value and one
            # pooled trend value per channel.
            self.classifier = nn.Sequential(
                nn.Linear(self.enc_in * 2, 128),
                nn.ReLU(),
                nn.Dropout(getattr(configs, 'dropout', 0.1)),
                nn.Linear(128, configs.num_class),
            )

    def _trend_stream(self, trend_init):
        """
        Run the trend MLP independently on every channel.

        Args:
            trend_init: 3-D tensor unpacked as ``[B, L, C]``. ``L`` must equal
                ``seq_len`` because it is the input size of ``fc5``.

        Returns:
            Tensor of shape ``[B, C, pred_len]``.
        """
        batch, length, channels = trend_init.shape
        # Fold channels into the batch so all channels share the same MLP weights.
        trend = trend_init.permute(0, 2, 1).reshape(batch * channels, length)
        # A 2-D input is treated by AvgPool1d as unbatched (C, L), so pooling
        # acts on the last (feature) axis.
        trend = self.ln1(self.avgpool1(self.fc5(trend)))  # [B*C, 2*pred_len]
        trend = self.ln2(self.avgpool2(self.fc6(trend)))  # [B*C, pred_len//2]
        trend = self.fc7(trend)  # [B*C, pred_len]
        return trend.view(batch, channels, -1)  # [B, C, pred_len]

    def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
        """
        Forecasting path (long- and short-term).

        Only ``x_enc`` is used; the other arguments exist for interface
        compatibility.

        Returns:
            Tensor of shape ``[B, pred_len, C]``. This assumes ``season_net``
            returns ``[B, C, pred_len]`` as the original comments state.
        """
        if self.revin:
            x_enc = self.revin_layer(x_enc, 'norm')

        seasonal_init, trend_init = self.decomp(x_enc)
        y_season = self.season_net(seasonal_init)  # [B, C, pred_len]
        y_trend = self._trend_stream(trend_init)  # [B, C, pred_len]

        # Concatenate the streams on the time axis and project back to pred_len.
        y = torch.cat([y_season, y_trend], dim=-1)  # [B, C, 2*pred_len]
        y = self.fc_final(y).permute(0, 2, 1)  # [B, pred_len, C]

        if self.revin:
            y = self.revin_layer(y, 'denorm')
        return y

    def classification(self, x_enc, x_mark_enc):
        """
        Classification path.

        RevIN is not applied here, and ``x_mark_enc`` is accepted but unused.

        Returns:
            Logits of shape ``[B, num_class]``.
        """
        seasonal_init, trend_init = self.decomp(x_enc)
        y_season = self.season_net(seasonal_init)  # [B, C, pred_len]
        y_trend = self._trend_stream(trend_init)  # [B, C, pred_len]

        # Softmax-weighted pooling over the time dimension: each channel's
        # sequence is weighted by the softmax of its own values.
        season_pooled = (y_season * torch.softmax(y_season, dim=-1)).sum(dim=-1)  # [B, C]
        trend_pooled = (y_trend * torch.softmax(y_trend, dim=-1)).sum(dim=-1)  # [B, C]

        features = torch.cat([season_pooled, trend_pooled], dim=-1)  # [B, 2*C]
        return self.classifier(features)  # [B, num_class]

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
        """
        Dispatch to the task-specific path selected by ``task_name``.

        Raises:
            ValueError: if ``task_name`` is not a supported task.
        """
        if self.task_name in self._FORECAST_TASKS:
            dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
            return dec_out[:, -self.pred_len:, :]  # [B, L, D]
        if self.task_name == 'classification':
            return self.classification(x_enc, x_mark_enc)  # [B, N]
        raise ValueError(f'Task {self.task_name} not supported by xPatch_SparseChannel')