"""
xPatch_SparseChannel model adapted for Time-Series-Library-main.

Supports both long-term forecasting and classification tasks.
"""
import torch
import torch.nn as nn

from layers.DECOMP import DECOMP
from layers.SeasonPatch import SeasonPatch
from layers.RevIN import RevIN


class Model(nn.Module):
    """
    xPatch SparseChannel model.

    The input is split by ``DECOMP`` into a seasonal and a trend component.
    The seasonal component is processed by ``SeasonPatch`` (patch encoder plus
    graph mixer); the trend component by a small per-channel MLP. Forecasting
    fuses both streams with a linear layer; classification pools each stream
    over its last axis and feeds the concatenated per-channel features to an
    MLP classifier.
    """

    def __init__(self, configs):
        super().__init__()

        # Model configuration
        self.task_name = configs.task_name
        self.seq_len = configs.seq_len
        self.pred_len = configs.pred_len
        self.enc_in = configs.enc_in

        # Length of the season/trend stream outputs. For forecasting this is
        # pred_len. Time-Series-Library's classification experiment sets
        # pred_len = 0, which would build zero-width Linear/LayerNorm layers
        # and yield all-zero pooled features, so classification falls back to
        # `configs.cls_head_len` (if given) or seq_len.
        self.out_len = self.pred_len
        if self.task_name == 'classification' and self.pred_len <= 0:
            self.out_len = getattr(configs, 'cls_head_len', None) or self.seq_len

        # Patch parameters
        self.patch_len = getattr(configs, 'patch_len', 16)
        self.stride = getattr(configs, 'stride', 8)

        # Normalization (only applied in the forecasting path)
        self.revin = getattr(configs, 'revin', True)
        if self.revin:
            self.revin_layer = RevIN(self.enc_in, affine=True, subtract_last=False)

        # Decomposition using the original DECOMP with EMA/DEMA
        ma_type = getattr(configs, 'ma_type', 'ema')
        alpha = getattr(configs, 'alpha', torch.tensor(0.1))
        beta = getattr(configs, 'beta', torch.tensor(0.1))
        self.decomp = DECOMP(ma_type, alpha, beta)

        # Season network (PatchTST/Mamba2 + Graph Mixer).
        # Passes through the parameters of the new SeasonPatch, in which the
        # GraphMixer is replaced by an unnormalized Hard-Concrete gate.
        self.season_net = SeasonPatch(
            c_in=self.enc_in,
            seq_len=self.seq_len,
            pred_len=self.out_len,
            patch_len=self.patch_len,
            stride=self.stride,
            # Encoder type: 'Transformer' or 'Mamba2'
            encoder_type=getattr(configs, 'season_encoder', 'Transformer'),
            # Patch encoder
            d_model=getattr(configs, 'd_model', 128),
            n_layers=getattr(configs, 'e_layers', 3),
            n_heads=getattr(configs, 'n_heads', 16),
            # GraphMixer (unnormalized)
            k_graph=getattr(configs, 'k_graph', 8),  # -> max_degree
            thr_graph=getattr(configs, 'thr_graph', 0.5),
            thr_graph_min=getattr(configs, 'thr_graph_min', None),
            thr_graph_max=getattr(configs, 'thr_graph_max', None),
            thr_graph_steps=getattr(configs, 'thr_graph_steps', 0),
            thr_graph_schedule=getattr(configs, 'thr_graph_schedule', 'linear'),
            symmetric_graph=getattr(configs, 'symmetric_graph', True),
            # 'none' | 'count' | 'count-sqrt' | 'sum'
            degree_rescale=getattr(configs, 'degree_rescale', 'count-sqrt'),
            gate_temperature=getattr(configs, 'gate_temperature', 2.0 / 3.0),
            tau_attn=getattr(configs, 'tau_attn', 1.0),
            l0_lambda=getattr(configs, 'season_l0_lambda', 0.0),
            # Mamba2
            d_state=getattr(configs, 'd_state', 64),
            d_conv=getattr(configs, 'd_conv', 4),
            expand=getattr(configs, 'expand', 2),
            headdim=getattr(configs, 'headdim', 64),
        )

        # Trend network (per-channel MLP): L -> 4*out -> 2*out -> out -> out//2 -> out
        self.fc5 = nn.Linear(self.seq_len, self.out_len * 4)
        self.avgpool1 = nn.AvgPool1d(kernel_size=2)
        self.ln1 = nn.LayerNorm(self.out_len * 2)
        self.fc6 = nn.Linear(self.out_len * 2, self.out_len)
        self.avgpool2 = nn.AvgPool1d(kernel_size=2)
        self.ln2 = nn.LayerNorm(self.out_len // 2)
        self.fc7 = nn.Linear(self.out_len // 2, self.out_len)

        # Task-specific heads
        if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
            # Fuses the concatenated [season, trend] streams back to pred_len.
            self.fc_final = nn.Linear(self.out_len * 2, self.out_len)
        elif self.task_name == 'classification':
            # NOTE(review): season_attention / trend_attention are registered but
            # not used by classification(), which pools with softmax(values)
            # weights instead. They are kept so existing checkpoints still load;
            # their parameters receive no gradient (DDP may then need
            # find_unused_parameters=True).
            self.season_attention = nn.Sequential(
                nn.Linear(self.out_len, 64),
                nn.Tanh(),
                nn.Linear(64, 1),
            )
            self.trend_attention = nn.Sequential(
                nn.Linear(self.out_len, 64),
                nn.Tanh(),
                nn.Linear(64, 1),
            )
            self.classifier = nn.Sequential(
                nn.Linear(self.enc_in * 2, 128),
                nn.ReLU(),
                nn.Dropout(getattr(configs, 'dropout', 0.1)),
                nn.Linear(128, configs.num_class),
            )

    def _trend_stream(self, trend_init):
        """Run the per-channel trend MLP.

        Args:
            trend_init: trend component, unpacked as [B, L, C].

        Returns:
            Tensor of shape [B, C, out_len].
        """
        B, L, C = trend_init.shape
        trend = trend_init.permute(0, 2, 1).reshape(B * C, L)  # [B*C, L]
        trend = self.fc5(trend)       # [B*C, 4*out_len]
        trend = self.avgpool1(trend)  # [B*C, 2*out_len]
        trend = self.ln1(trend)
        trend = self.fc6(trend)       # [B*C, out_len]
        trend = self.avgpool2(trend)  # [B*C, out_len // 2]
        trend = self.ln2(trend)
        trend = self.fc7(trend)       # [B*C, out_len]
        return trend.view(B, C, -1)   # [B, C, out_len]

    def _encode_streams(self, x):
        """Decompose ``x`` and encode both components.

        Returns:
            (y_season, y_trend), each expected to be [B, C, out_len].
        """
        seasonal_init, trend_init = self.decomp(x)
        y_season = self.season_net(seasonal_init)  # [B, C, out_len]
        y_trend = self._trend_stream(trend_init)   # [B, C, out_len]
        return y_season, y_trend

    def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
        """Long/short-term forecasting.

        Args:
            x_enc: input window, [B, seq_len, C].
            x_mark_enc, x_dec, x_mark_dec: unused; kept for the library interface.

        Returns:
            Forecast of shape [B, pred_len, C].
        """
        # Normalization
        if self.revin:
            x_enc = self.revin_layer(x_enc, 'norm')

        y_season, y_trend = self._encode_streams(x_enc)

        # Combine streams
        y = torch.cat([y_season, y_trend], dim=-1)  # [B, C, 2*pred_len]
        y = self.fc_final(y)                        # [B, C, pred_len]
        y = y.permute(0, 2, 1)                      # [B, pred_len, C]

        # Denormalization
        if self.revin:
            y = self.revin_layer(y, 'denorm')
        return y

    def classification(self, x_enc, x_mark_enc):
        """Classification.

        RevIN is not applied here (classification usually does without it;
        enable it manually if needed).

        Args:
            x_enc: input series, [B, seq_len, C].
            x_mark_enc: unused here (the library passes a padding mask).

        Returns:
            Logits of shape [B, num_class].
        """
        y_season, y_trend = self._encode_streams(x_enc)

        # Soft pooling over the last axis: each value is weighted by the
        # softmax of its own stream.
        season_attn_weights = torch.softmax(y_season, dim=-1)
        season_pooled = (y_season * season_attn_weights).sum(dim=-1)  # [B, C]

        trend_attn_weights = torch.softmax(y_trend, dim=-1)
        trend_pooled = (y_trend * trend_attn_weights).sum(dim=-1)  # [B, C]

        # Combine features
        features = torch.cat([season_pooled, trend_pooled], dim=-1)  # [B, 2*C]

        # Classification
        logits = self.classifier(features)  # [B, num_class]
        return logits

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
        """Dispatch to the task-specific method.

        Raises:
            ValueError: if ``task_name`` is not a supported task.
        """
        if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
            dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
            return dec_out[:, -self.pred_len:, :]  # [B, L, D]
        elif self.task_name == 'classification':
            dec_out = self.classification(x_enc, x_mark_enc)
            return dec_out  # [B, N]
        else:
            raise ValueError(f'Task {self.task_name} not supported by xPatch_SparseChannel')

    def reg_loss(self):
        """
        L0 regularization term delegated to ``season_net.reg_loss()``.

        Per the original note, it is expected to be non-zero only when the
        GraphMixer is enabled on the Transformer path. Falls back to a zero
        scalar on the parameters' device if ``season_net`` has no ``reg_loss``.
        Training usage: total_loss = main_loss + model.reg_loss()
        """
        if hasattr(self, "season_net") and hasattr(self.season_net, "reg_loss"):
            return self.season_net.reg_loss()
        return torch.tensor(0.0, device=next(self.parameters()).device)