"""
xPatch_SparseChannel model adapted for Time-Series-Library-main.

Supports long-/short-term forecasting and classification tasks.
"""
import torch
import torch.nn as nn

from layers.DECOMP import DECOMP
from layers.SeasonPatch import SeasonPatch
from layers.RevIN import RevIN


class Model(nn.Module):
    """xPatch SparseChannel model.

    The input window is split by ``DECOMP`` (EMA/DEMA moving average) into a
    seasonal and a trend component. The seasonal component is processed by
    ``SeasonPatch``; the trend component by a small MLP shared across
    channels. A task-specific head fuses the two streams.
    """

    def __init__(self, configs):
        super().__init__()

        # Model configuration
        self.task_name = configs.task_name
        self.seq_len = configs.seq_len
        self.pred_len = configs.pred_len
        self.enc_in = configs.enc_in

        # Patching parameters
        self.patch_len = getattr(configs, 'patch_len', 16)
        self.stride = getattr(configs, 'stride', 8)

        # Reversible instance normalization (only applied by forecast())
        self.revin = getattr(configs, 'revin', True)
        if self.revin:
            self.revin_layer = RevIN(self.enc_in, affine=True, subtract_last=False)

        # Decomposition using the original DECOMP with EMA/DEMA
        ma_type = getattr(configs, 'ma_type', 'ema')
        alpha = getattr(configs, 'alpha', torch.tensor(0.1))
        beta = getattr(configs, 'beta', torch.tensor(0.1))
        self.decomp = DECOMP(ma_type, alpha, beta)

        # Seasonal network (PatchTST + Graph Mixer)
        self.season_net = SeasonPatch(
            c_in=self.enc_in,
            seq_len=self.seq_len,
            pred_len=self.pred_len,
            patch_len=self.patch_len,
            stride=self.stride,
            k_graph=getattr(configs, 'k_graph', 8),
            d_model=getattr(configs, 'd_model', 128),
            n_layers=getattr(configs, 'e_layers', 3),
            n_heads=getattr(configs, 'n_heads', 16),
            # Encoder type for the seasonal stream ('Transformer' or 'Mamba2')
            encoder_type=getattr(configs, 'season_encoder', 'Transformer'),
        )

        # Trend network (MLP) applied along the time axis of each channel.
        # Feature sizes (P = pred_len):
        #   seq_len -> 4P -> pool/2 -> 2P -> P -> pool/2 -> P//2 -> P
        self.fc5 = nn.Linear(self.seq_len, self.pred_len * 4)
        self.avgpool1 = nn.AvgPool1d(kernel_size=2)
        self.ln1 = nn.LayerNorm(self.pred_len * 2)
        self.fc6 = nn.Linear(self.pred_len * 2, self.pred_len)
        self.avgpool2 = nn.AvgPool1d(kernel_size=2)
        self.ln2 = nn.LayerNorm(self.pred_len // 2)
        self.fc7 = nn.Linear(self.pred_len // 2, self.pred_len)

        # Task-specific heads
        if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
            # Fuses the concatenated [season | trend] predictions per channel.
            self.fc_final = nn.Linear(self.pred_len * 2, self.pred_len)
        elif self.task_name == 'classification':
            # NOTE(review): season_attention / trend_attention are not referenced
            # by any method of this class (classification() pools without them).
            # They are kept so existing checkpoints keep the same state_dict keys.
            self.season_attention = nn.Sequential(
                nn.Linear(self.pred_len, 64),
                nn.Tanh(),
                nn.Linear(64, 1),
            )
            self.trend_attention = nn.Sequential(
                nn.Linear(self.pred_len, 64),
                nn.Tanh(),
                nn.Linear(64, 1),
            )
            # Input: pooled season features and pooled trend features, one per channel each.
            self.classifier = nn.Sequential(
                nn.Linear(self.enc_in * 2, 128),
                nn.ReLU(),
                nn.Dropout(getattr(configs, 'dropout', 0.1)),
                nn.Linear(128, configs.num_class),
            )

    # ------------------------------------------------------------------ #
    # Shared building blocks
    # ------------------------------------------------------------------ #
    def _trend_stream(self, trend_init):
        """Run the trend MLP on every channel independently.

        Args:
            trend_init: trend component of shape [B, L, C].

        Returns:
            Tensor of shape [B, C, pred_len].
        """
        B, L, C = trend_init.shape
        # Fold channels into the batch axis so the same MLP serves every channel.
        trend = trend_init.permute(0, 2, 1).reshape(B * C, L)  # [B*C, L]
        trend = self.fc5(trend)       # [B*C, 4P]
        # AvgPool1d treats a 2-D input as unbatched (C, L) and pools the last axis.
        trend = self.avgpool1(trend)  # [B*C, 2P]
        trend = self.ln1(trend)
        trend = self.fc6(trend)       # [B*C, P]
        trend = self.avgpool2(trend)  # [B*C, P//2]
        trend = self.ln2(trend)
        trend = self.fc7(trend)       # [B*C, P]
        return trend.view(B, C, -1)   # [B, C, P]

    def _encode(self, x_enc):
        """Decompose ``x_enc`` and run the seasonal and trend streams.

        Args:
            x_enc: input window, assumed [B, seq_len, enc_in] — TODO confirm
                against DECOMP/SeasonPatch.

        Returns:
            ``(y_season, y_trend)``, each expected to be [B, C, pred_len].
        """
        seasonal_init, trend_init = self.decomp(x_enc)
        y_season = self.season_net(seasonal_init)  # [B, C, pred_len]
        y_trend = self._trend_stream(trend_init)   # [B, C, pred_len]
        return y_season, y_trend

    @staticmethod
    def _softmax_pool(x):
        """Collapse the last axis of ``x``, weighting each entry by softmax(x) over that axis.

        This is parameter-free: the weights come from the values themselves.
        """
        return (x * torch.softmax(x, dim=-1)).sum(dim=-1)

    # ------------------------------------------------------------------ #
    # Task entry points
    # ------------------------------------------------------------------ #
    def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
        """Forecast ``pred_len`` steps (used for long- and short-term forecasting).

        Args:
            x_enc: input window, assumed [B, seq_len, enc_in].
            x_mark_enc, x_dec, x_mark_dec: not used by this model; accepted
                for interface compatibility with ``forward``.

        Returns:
            Tensor of shape [B, pred_len, C].
        """
        if self.revin:
            x_enc = self.revin_layer(x_enc, 'norm')

        y_season, y_trend = self._encode(x_enc)

        # Combine streams
        y = torch.cat([y_season, y_trend], dim=-1)  # [B, C, 2*pred_len]
        y = self.fc_final(y)                        # [B, C, pred_len]
        y = y.permute(0, 2, 1)                      # [B, pred_len, C]

        if self.revin:
            y = self.revin_layer(y, 'denorm')
        return y

    def classification(self, x_enc, x_mark_enc):
        """Produce class logits for an input window.

        RevIN normalization is NOT applied on this path, even when ``revin``
        is enabled.

        Args:
            x_enc: input window, assumed [B, seq_len, enc_in].
            x_mark_enc: not used. NOTE(review): in Time-Series-Library
                classification this presumably carries the padding mask —
                confirm whether padded steps should be masked out.

        Returns:
            Logits of shape [B, num_class].
        """
        y_season, y_trend = self._encode(x_enc)

        # Soft pooling over the last (prediction) axis of each stream.
        # NOTE(review): the learned season_attention / trend_attention modules
        # are not used here, so their parameters receive no gradient; DDP with
        # find_unused_parameters=False typically errors on such parameters.
        season_pooled = self._softmax_pool(y_season)  # [B, C]
        trend_pooled = self._softmax_pool(y_trend)    # [B, C]

        features = torch.cat([season_pooled, trend_pooled], dim=-1)  # [B, 2*C]
        return self.classifier(features)  # [B, num_class]

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
        """Dispatch to the method for ``self.task_name``.

        ``mask`` is accepted for interface compatibility and is not used.

        Raises:
            ValueError: if ``task_name`` is not a supported task.
        """
        if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
            dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
            return dec_out[:, -self.pred_len:, :]  # [B, L, D]
        elif self.task_name == 'classification':
            dec_out = self.classification(x_enc, x_mark_enc)
            return dec_out  # [B, N]
        else:
            raise ValueError(f'Task {self.task_name} not supported by xPatch_SparseChannel')