feat(model): introduce dynamic training flag for model forward pass
@@ -11,6 +11,7 @@ import warnings
 import numpy as np
 from utils.dtw_metric import dtw, accelerated_dtw
 from utils.augmentation import run_augmentation, run_augmentation_single
+import inspect

 warnings.filterwarnings('ignore')
@@ -18,9 +19,18 @@ warnings.filterwarnings('ignore')

 class Exp_Long_Term_Forecast(Exp_Basic):
     def __init__(self, args):
         super(Exp_Long_Term_Forecast, self).__init__(args)
+        self._model_supports_training_flag = False

     def _build_model(self):
         model = self.model_dict[self.args.model].Model(self.args).float()
+        # If the model gets wrapped in DataParallel, we need to inspect the original model
+        model_to_inspect = model
+        # inspect.signature() retrieves the parameters of a function or method
+        forward_signature = inspect.signature(model_to_inspect.forward)
+        # Check whether 'training' is among the parameters
+        if 'training' in forward_signature.parameters:
+            self._model_supports_training_flag = True
+            print("Model supports 'training' flag.")

         if self.args.use_multi_gpu and self.args.use_gpu:
             model = nn.DataParallel(model, device_ids=self.args.device_ids)
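
Note: the detection relies only on inspect.signature. A minimal standalone sketch of the check (the two toy functions below are illustrative, not from this repo):

    import inspect

    def forward_plain(x_enc, x_mark_enc, x_dec, x_mark_dec):                     # hypothetical: no flag
        pass

    def forward_flagged(x_enc, x_mark_enc, x_dec, x_mark_dec, training=False):   # hypothetical: has flag
        pass

    print('training' in inspect.signature(forward_plain).parameters)    # False
    print('training' in inspect.signature(forward_flagged).parameters)  # True

The check runs before nn.DataParallel wraps the model, so model.forward is still the model's own forward; inspected after wrapping, it would see DataParallel's generic signature instead.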
@@ -63,12 +73,16 @@ class Exp_Long_Term_Forecast(Exp_Basic):
                 # decoder input
                 dec_inp = torch.zeros_like(batch_y[:, -self.args.pred_len:, :]).float()
                 dec_inp = torch.cat([batch_y[:, :self.args.label_len, :], dec_inp], dim=1).float().to(self.device)
                 # encoder - decoder
+                # --- modified model invocation ---
+                model_args = (batch_x, batch_x_mark, dec_inp, batch_y_mark)
+                model_kwargs = {}
+                if self._model_supports_training_flag:
+                    model_kwargs['training'] = False  # False during validation
                 if self.args.use_amp:
                     with torch.cuda.amp.autocast():
-                        outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
+                        outputs = self.model(*model_args, **model_kwargs)
                 else:
-                    outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
+                    outputs = self.model(*model_args, **model_kwargs)
                 f_dim = -1 if self.args.features == 'MS' else 0
                 outputs = outputs[:, -self.args.pred_len:, f_dim:]
                 batch_y = batch_y[:, -self.args.pred_len:, f_dim:].to(self.device)
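
For context, a forward signature that this guard would detect might look like the following (a hedged sketch; the class body is hypothetical, not from this commit):

    import torch.nn as nn

    class Model(nn.Module):
        # Hypothetical sketch: 'training' defaults to False, so call sites
        # that never pass the flag behave exactly as before.
        def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, training=False):
            ...  # e.g. enable stochastic behaviour only when training=True

Without the guard, passing training= to a model whose forward lacks the parameter would raise a TypeError; that is why model_kwargs stays empty for such models and the call reduces to the original positional form.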
@@ -130,19 +144,20 @@ class Exp_Long_Term_Forecast(Exp_Basic):
                 dec_inp = torch.zeros_like(batch_y[:, -self.args.pred_len:, :]).float()
                 dec_inp = torch.cat([batch_y[:, :self.args.label_len, :], dec_inp], dim=1).float().to(self.device)

                 # encoder - decoder
+                model_args = (batch_x, batch_x_mark, dec_inp, batch_y_mark)
+                model_kwargs = {}
+                if self._model_supports_training_flag:
+                    model_kwargs['training'] = True  # True during training
                 if self.args.use_amp:
                     with torch.cuda.amp.autocast():
-                        outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
-
+                        outputs = self.model(*model_args, **model_kwargs)
                         f_dim = -1 if self.args.features == 'MS' else 0
                         outputs = outputs[:, -self.args.pred_len:, f_dim:]
                         batch_y = batch_y[:, -self.args.pred_len:, f_dim:].to(self.device)
                         loss = criterion(outputs, batch_y)
                         train_loss.append(loss.item())
                 else:
-                    outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
-
+                    outputs = self.model(*model_args, **model_kwargs)
                     f_dim = -1 if self.args.features == 'MS' else 0
                     outputs = outputs[:, -self.args.pred_len:, f_dim:]
                     batch_y = batch_y[:, -self.args.pred_len:, f_dim:].to(self.device)
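
When use_multi_gpu is enabled, self.model is the nn.DataParallel wrapper rather than the raw model, but keyword arguments given to the wrapper are forwarded to the wrapped module's forward, so the flag still arrives. A self-contained sketch (the Toy module is hypothetical):

    import torch
    import torch.nn as nn

    class Toy(nn.Module):
        # Hypothetical module whose forward accepts the optional flag.
        def forward(self, x, training=False):
            return x * (2.0 if training else 1.0)

    # Keyword arguments passed to the DataParallel wrapper reach Toy.forward
    # (scattered across replicas on multi-GPU, passed through directly on CPU).
    wrapped = nn.DataParallel(Toy())
    out = wrapped(torch.ones(4, 3), training=True)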
@@ -208,12 +223,15 @@ class Exp_Long_Term_Forecast(Exp_Basic):
                 # decoder input
                 dec_inp = torch.zeros_like(batch_y[:, -self.args.pred_len:, :]).float()
                 dec_inp = torch.cat([batch_y[:, :self.args.label_len, :], dec_inp], dim=1).float().to(self.device)
                 # encoder - decoder
+                model_args = (batch_x, batch_x_mark, dec_inp, batch_y_mark)
+                model_kwargs = {}
+                if self._model_supports_training_flag:
+                    model_kwargs['training'] = False  # False during testing
                 if self.args.use_amp:
                     with torch.cuda.amp.autocast():
-                        outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
+                        outputs = self.model(*model_args, **model_kwargs)
                 else:
-                    outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
+                    outputs = self.model(*model_args, **model_kwargs)

                 f_dim = -1 if self.args.features == 'MS' else 0
                 outputs = outputs[:, -self.args.pred_len:, :]
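
The DataParallel comment in _build_model matters because the signature check runs on the unwrapped model; if the check ever had to run after wrapping, the original module would need to be unwrapped first. A minimal sketch of such a guard (the helper name is hypothetical, not part of this commit):

    import inspect
    import torch.nn as nn

    def supports_training_flag(model: nn.Module) -> bool:
        # Hypothetical helper: unwrap DataParallel before inspecting, since the
        # wrapper's generic forward(*inputs, **kwargs) signature would hide the
        # underlying model's parameter list.
        target = model.module if isinstance(model, nn.DataParallel) else model
        return 'training' in inspect.signature(target.forward).parameters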