diff --git a/exp/exp_long_term_forecasting.py b/exp/exp_long_term_forecasting.py index 858b9bf..a76b99c 100644 --- a/exp/exp_long_term_forecasting.py +++ b/exp/exp_long_term_forecasting.py @@ -11,6 +11,7 @@ import warnings import numpy as np from utils.dtw_metric import dtw, accelerated_dtw from utils.augmentation import run_augmentation, run_augmentation_single +import inspect warnings.filterwarnings('ignore') @@ -18,9 +19,18 @@ warnings.filterwarnings('ignore') class Exp_Long_Term_Forecast(Exp_Basic): def __init__(self, args): super(Exp_Long_Term_Forecast, self).__init__(args) + self._model_supports_training_flag = False def _build_model(self): model = self.model_dict[self.args.model].Model(self.args).float() + # Inspect the model before any DataParallel wrapping, so the original model is what gets checked + model_to_inspect = model + # inspect.signature() exposes the parameters of a function or method + forward_signature = inspect.signature(model_to_inspect.forward) + # Check whether 'training' is among the forward() parameters + if 'training' in forward_signature.parameters: + self._model_supports_training_flag = True + print("Model supports 'training' flag.") if self.args.use_multi_gpu and self.args.use_gpu: model = nn.DataParallel(model, device_ids=self.args.device_ids) @@ -63,12 +73,16 @@ class Exp_Long_Term_Forecast(Exp_Basic): # decoder input dec_inp = torch.zeros_like(batch_y[:, -self.args.pred_len:, :]).float() dec_inp = torch.cat([batch_y[:, :self.args.label_len, :], dec_inp], dim=1).float().to(self.device) - # encoder - decoder + # --- Modified model-call section --- + model_args = (batch_x, batch_x_mark, dec_inp, batch_y_mark) + model_kwargs = {} + if self._model_supports_training_flag: + model_kwargs['training'] = False # False during validation if self.args.use_amp: with torch.cuda.amp.autocast(): - outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark) + outputs = self.model(*model_args, **model_kwargs) else: - outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark) + outputs = self.model(*model_args, **model_kwargs) f_dim = -1 if self.args.features == 'MS' else 0 outputs = outputs[:, 
-self.args.pred_len:, f_dim:] batch_y = batch_y[:, -self.args.pred_len:, f_dim:].to(self.device) @@ -130,19 +144,20 @@ class Exp_Long_Term_Forecast(Exp_Basic): dec_inp = torch.zeros_like(batch_y[:, -self.args.pred_len:, :]).float() dec_inp = torch.cat([batch_y[:, :self.args.label_len, :], dec_inp], dim=1).float().to(self.device) - # encoder - decoder + model_args = (batch_x, batch_x_mark, dec_inp, batch_y_mark) + model_kwargs = {} + if self._model_supports_training_flag: + model_kwargs['training'] = True # True during training if self.args.use_amp: with torch.cuda.amp.autocast(): - outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark) - + outputs = self.model(*model_args, **model_kwargs) f_dim = -1 if self.args.features == 'MS' else 0 outputs = outputs[:, -self.args.pred_len:, f_dim:] batch_y = batch_y[:, -self.args.pred_len:, f_dim:].to(self.device) loss = criterion(outputs, batch_y) train_loss.append(loss.item()) else: - outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark) - + outputs = self.model(*model_args, **model_kwargs) f_dim = -1 if self.args.features == 'MS' else 0 outputs = outputs[:, -self.args.pred_len:, f_dim:] batch_y = batch_y[:, -self.args.pred_len:, f_dim:].to(self.device) @@ -208,12 +223,15 @@ class Exp_Long_Term_Forecast(Exp_Basic): # decoder input dec_inp = torch.zeros_like(batch_y[:, -self.args.pred_len:, :]).float() dec_inp = torch.cat([batch_y[:, :self.args.label_len, :], dec_inp], dim=1).float().to(self.device) - # encoder - decoder + model_args = (batch_x, batch_x_mark, dec_inp, batch_y_mark) + model_kwargs = {} + if self._model_supports_training_flag: + model_kwargs['training'] = False # False during testing if self.args.use_amp: with torch.cuda.amp.autocast(): - outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark) + outputs = self.model(*model_args, **model_kwargs) else: - outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark) + outputs = self.model(*model_args, **model_kwargs) f_dim = -1 if 
self.args.features == 'MS' else 0 outputs = outputs[:, -self.args.pred_len:, :] diff --git a/layers/GraphMixer.py b/layers/GraphMixer.py index 277b989..04a07bb 100644 --- a/layers/GraphMixer.py +++ b/layers/GraphMixer.py @@ -116,8 +116,8 @@ class HierarchicalGraphMixer(nn.Module): else: g = progress return self.thr_min + (self.thr_max - self.thr_min) * g - def _maybe_update_thr(self): - if self.training and self._use_dynamic_thr: + def _maybe_update_thr(self, training): + if training and self._use_dynamic_thr: step = int(self._thr_step.item()) progress = step / float(self.thr_steps) self.thr = float(self._compute_thr_by_progress(progress)) @@ -185,14 +185,14 @@ class HierarchicalGraphMixer(nn.Module): """ return lam * self.gate.expected_l0().sum() - def forward(self, z): - self._maybe_update_thr() + def forward(self, z, is_training): + self._maybe_update_thr(training=is_training) # z: [B, C, N, D] B, C, N, D = z.shape assert C == self.C and D == self.dim # Level 1: 采样非归一化门 z_gate ∈ [0,1] - z_gate = self.gate.sample(training=self.training) # [C, C] + z_gate = self.gate.sample(training=is_training) # [C, C] # 构建稀疏邻居(阈值 + 可选 top-k) idx_list, w_list = self._build_sparse_neighbors(z_gate) diff --git a/layers/SeasonPatch.py b/layers/SeasonPatch.py index a508d68..a338cec 100644 --- a/layers/SeasonPatch.py +++ b/layers/SeasonPatch.py @@ -98,7 +98,7 @@ class SeasonPatch(nn.Module): else: raise ValueError(f"Unsupported encoder_type: {encoder_type}") - def forward(self, x ): + def forward(self, x, training): # x: [B, L, C] x = x.permute(0, 2, 1) # x: [B, C, L] @@ -114,7 +114,7 @@ class SeasonPatch(nn.Module): z = z.permute(0, 1, 3, 2) # z: [B, C, patch_num, d_model] # Cross-channel mixing - z_mix = self.mixer(z) # z_mix: [B, C, patch_num, d_model] + z_mix = self.mixer(z, training) # z_mix: [B, C, patch_num, d_model] # Flatten and predict z_mix = z_mix.view(B, C, N * D) # z_mix: [B, C, patch_num * d_model] diff --git a/models/xPatch_SparseChannel.py 
b/models/xPatch_SparseChannel.py index 2b5f009..bdaabda 100644 --- a/models/xPatch_SparseChannel.py +++ b/models/xPatch_SparseChannel.py @@ -100,7 +100,7 @@ class Model(nn.Module): nn.Linear(128, configs.num_class) ) - def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec): + def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec, training): """Long-term forecasting""" # Normalization if self.revin: @@ -110,7 +110,7 @@ class Model(nn.Module): seasonal_init, trend_init = self.decomp(x_enc) # Season stream - y_season = self.season_net(seasonal_init) # [B, C, pred_len] + y_season = self.season_net(seasonal_init, training) # [B, C, pred_len] # Trend stream B, L, C = trend_init.shape @@ -169,10 +169,10 @@ class Model(nn.Module): logits = self.classifier(features) # [B, num_classes] return logits - def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None): + def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None, training=True): """Forward pass dispatching to task-specific methods""" if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast': - dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec) + dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec, training) return dec_out[:, -self.pred_len:, :] # [B, L, D] elif self.task_name == 'classification': dec_out = self.classification(x_enc, x_mark_enc)