feat(model): introduce dynamic training flag for model forward pass
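Some models in this repo need train-time behaviour inside the forward pass itself (stochastic gate sampling and the dynamic threshold schedule in HierarchicalGraphMixer) rather than relying on nn.Module's train()/eval() state. Exp_Long_Term_Forecast therefore inspects the model's forward signature once in _build_model; when a 'training' parameter is present, the experiment passes training=True in train() and training=False in vali() and test(). The flag is threaded down through Model.forward -> forecast -> SeasonPatch -> HierarchicalGraphMixer, replacing the previous reads of self.training. Models whose forward does not accept the flag are called exactly as before.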
@@ -11,6 +11,7 @@ import warnings
 import numpy as np
 from utils.dtw_metric import dtw, accelerated_dtw
 from utils.augmentation import run_augmentation, run_augmentation_single
+import inspect
 
 warnings.filterwarnings('ignore')
 
@@ -18,9 +19,18 @@ warnings.filterwarnings('ignore')
 class Exp_Long_Term_Forecast(Exp_Basic):
     def __init__(self, args):
         super(Exp_Long_Term_Forecast, self).__init__(args)
+        self._model_supports_training_flag = False
 
     def _build_model(self):
         model = self.model_dict[self.args.model].Model(self.args).float()
+        # If the model gets wrapped in DataParallel, we need to inspect the original model
+        model_to_inspect = model
+        # inspect.signature() retrieves the parameter list of a function or method
+        forward_signature = inspect.signature(model_to_inspect.forward)
+        # Check whether 'training' appears among the parameters
+        if 'training' in forward_signature.parameters:
+            self._model_supports_training_flag = True
+            print("Model supports 'training' flag.")
 
         if self.args.use_multi_gpu and self.args.use_gpu:
             model = nn.DataParallel(model, device_ids=self.args.device_ids)
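A minimal, runnable sketch of the detection logic added in _build_model above, with toy stand-in modules (LegacyModel and FlagAwareModel are illustrative, not classes from this repo). The .module unwrap is an assumption covering the case where inspection happens after DataParallel wrapping; in this commit the check runs before wrapping, so model_to_inspect = model suffices.

# Sketch only; assumes torch is installed. DataParallel's own forward is
# (*inputs, **kwargs), so inspecting the wrapper would never find 'training'.
import inspect
import torch.nn as nn

class LegacyModel(nn.Module):
    def forward(self, x):
        return x

class FlagAwareModel(nn.Module):
    def forward(self, x, training=True):
        return x

def supports_training_flag(model: nn.Module) -> bool:
    # Unwrap DataParallel so we inspect the user-defined forward,
    # not the wrapper's generic *args/**kwargs signature.
    inner = model.module if isinstance(model, nn.DataParallel) else model
    return 'training' in inspect.signature(inner.forward).parameters

print(supports_training_flag(LegacyModel()))     # False
print(supports_training_flag(FlagAwareModel()))  # True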
@@ -63,12 +73,16 @@ class Exp_Long_Term_Forecast(Exp_Basic):
                 # decoder input
                 dec_inp = torch.zeros_like(batch_y[:, -self.args.pred_len:, :]).float()
                 dec_inp = torch.cat([batch_y[:, :self.args.label_len, :], dec_inp], dim=1).float().to(self.device)
-                # encoder - decoder
+                # --- modified model invocation ---
+                model_args = (batch_x, batch_x_mark, dec_inp, batch_y_mark)
+                model_kwargs = {}
+                if self._model_supports_training_flag:
+                    model_kwargs['training'] = False  # False during validation
                 if self.args.use_amp:
                     with torch.cuda.amp.autocast():
-                        outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
+                        outputs = self.model(*model_args, **model_kwargs)
                 else:
-                    outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
+                    outputs = self.model(*model_args, **model_kwargs)
                 f_dim = -1 if self.args.features == 'MS' else 0
                 outputs = outputs[:, -self.args.pred_len:, f_dim:]
                 batch_y = batch_y[:, -self.args.pred_len:, f_dim:].to(self.device)
@@ -130,19 +144,20 @@ class Exp_Long_Term_Forecast(Exp_Basic):
                 dec_inp = torch.zeros_like(batch_y[:, -self.args.pred_len:, :]).float()
                 dec_inp = torch.cat([batch_y[:, :self.args.label_len, :], dec_inp], dim=1).float().to(self.device)
 
-                # encoder - decoder
+                model_args = (batch_x, batch_x_mark, dec_inp, batch_y_mark)
+                model_kwargs = {}
+                if self._model_supports_training_flag:
+                    model_kwargs['training'] = True  # True during training
                 if self.args.use_amp:
                     with torch.cuda.amp.autocast():
-                        outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
+                        outputs = self.model(*model_args, **model_kwargs)
 
                         f_dim = -1 if self.args.features == 'MS' else 0
                         outputs = outputs[:, -self.args.pred_len:, f_dim:]
                         batch_y = batch_y[:, -self.args.pred_len:, f_dim:].to(self.device)
                         loss = criterion(outputs, batch_y)
                         train_loss.append(loss.item())
                 else:
-                    outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
+                    outputs = self.model(*model_args, **model_kwargs)
 
                     f_dim = -1 if self.args.features == 'MS' else 0
                     outputs = outputs[:, -self.args.pred_len:, f_dim:]
                     batch_y = batch_y[:, -self.args.pred_len:, f_dim:].to(self.device)
@@ -208,12 +223,15 @@ class Exp_Long_Term_Forecast(Exp_Basic):
                 # decoder input
                 dec_inp = torch.zeros_like(batch_y[:, -self.args.pred_len:, :]).float()
                 dec_inp = torch.cat([batch_y[:, :self.args.label_len, :], dec_inp], dim=1).float().to(self.device)
-                # encoder - decoder
+                model_args = (batch_x, batch_x_mark, dec_inp, batch_y_mark)
+                model_kwargs = {}
+                if self._model_supports_training_flag:
+                    model_kwargs['training'] = False  # False during testing
                 if self.args.use_amp:
                     with torch.cuda.amp.autocast():
-                        outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
+                        outputs = self.model(*model_args, **model_kwargs)
                 else:
-                    outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
+                    outputs = self.model(*model_args, **model_kwargs)
 
                 f_dim = -1 if self.args.features == 'MS' else 0
                 outputs = outputs[:, -self.args.pred_len:, :]
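The three call sites above (vali, train, test) share one pattern: build the positional arguments once, then add the keyword only when the model is known to accept it, so legacy models keep working unchanged. A condensed sketch of that pattern; run_model is a hypothetical helper name, not a function in this repo.

# Sketch only: 'supports_flag' is assumed to come from the detection
# performed in _build_model (see the sketch after the second hunk).
def run_model(model, supports_flag, batch_x, batch_x_mark, dec_inp, batch_y_mark, is_train):
    model_args = (batch_x, batch_x_mark, dec_inp, batch_y_mark)
    model_kwargs = {}
    if supports_flag:
        # True in train(), False in vali()/test()
        model_kwargs['training'] = is_train
    # Models without the parameter receive no extra kwarg at all.
    return model(*model_args, **model_kwargs)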
@@ -116,8 +116,8 @@ class HierarchicalGraphMixer(nn.Module):
         else:
             g = progress
         return self.thr_min + (self.thr_max - self.thr_min) * g
-    def _maybe_update_thr(self):
-        if self.training and self._use_dynamic_thr:
+    def _maybe_update_thr(self, training):
+        if training and self._use_dynamic_thr:
             step = int(self._thr_step.item())
             progress = step / float(self.thr_steps)
             self.thr = float(self._compute_thr_by_progress(progress))
@@ -185,14 +185,14 @@ class HierarchicalGraphMixer(nn.Module):
         """
         return lam * self.gate.expected_l0().sum()
 
-    def forward(self, z):
-        self._maybe_update_thr()
+    def forward(self, z, is_training):
+        self._maybe_update_thr(training=is_training)
         # z: [B, C, N, D]
         B, C, N, D = z.shape
         assert C == self.C and D == self.dim
 
         # Level 1: sample the unnormalized gates z_gate ∈ [0,1]
-        z_gate = self.gate.sample(training=self.training) # [C, C]
+        z_gate = self.gate.sample(training=is_training) # [C, C]
 
         # Build the sparse neighborhood (threshold + optional top-k)
         idx_list, w_list = self._build_sparse_neighbors(z_gate)
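The point of passing is_training explicitly instead of reading self.training: gate sampling and the threshold schedule are now controlled per call by the experiment loop, independent of whether train()/eval() was toggled on the (possibly DataParallel-wrapped) module. A toy sketch of the pattern; NoisyBlock is illustrative and assumes nothing from this repo.

import torch
import torch.nn as nn

class NoisyBlock(nn.Module):
    """Toy module showing the commit's pattern: stochastic behaviour
    keys off an explicit argument rather than self.training."""
    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x, is_training):
        if is_training:
            # Stochastic path only when the caller asks for it.
            x = x + 0.01 * torch.randn_like(x)
        return self.proj(x)

block = NoisyBlock(8)
x = torch.randn(2, 8)
y_train = block(x, is_training=True)   # stochastic
y_eval = block(x, is_training=False)   # deterministic, regardless of block.training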
@@ -98,7 +98,7 @@ class SeasonPatch(nn.Module):
         else:
             raise ValueError(f"Unsupported encoder_type: {encoder_type}")
 
-    def forward(self, x ):
+    def forward(self, x, training):
         # x: [B, L, C]
         x = x.permute(0, 2, 1) # x: [B, C, L]
 
@@ -114,7 +114,7 @@ class SeasonPatch(nn.Module):
         z = z.permute(0, 1, 3, 2) # z: [B, C, patch_num, d_model]
 
         # Cross-channel mixing
-        z_mix = self.mixer(z) # z_mix: [B, C, patch_num, d_model]
+        z_mix = self.mixer(z, training) # z_mix: [B, C, patch_num, d_model]
 
         # Flatten and predict
         z_mix = z_mix.view(B, C, N * D) # z_mix: [B, C, patch_num * d_model]
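Note that the flag is a required positional parameter here: SeasonPatch.forward(self, x, training) and HierarchicalGraphMixer.forward(self, z, is_training) carry no default. Only the top-level Model.forward supplies one (training=True, in the last hunk below), so any code that calls these submodules directly must be updated to pass the flag explicitly.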
@@ -100,7 +100,7 @@ class Model(nn.Module):
             nn.Linear(128, configs.num_class)
         )
 
-    def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
+    def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec, training):
         """Long-term forecasting"""
         # Normalization
         if self.revin:
@@ -110,7 +110,7 @@ class Model(nn.Module):
         seasonal_init, trend_init = self.decomp(x_enc)
 
         # Season stream
-        y_season = self.season_net(seasonal_init) # [B, C, pred_len]
+        y_season = self.season_net(seasonal_init, training) # [B, C, pred_len]
 
         # Trend stream
         B, L, C = trend_init.shape
@@ -169,10 +169,10 @@ class Model(nn.Module):
         logits = self.classifier(features) # [B, num_classes]
         return logits
 
-    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
+    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None, training=True):
         """Forward pass dispatching to task-specific methods"""
         if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
-            dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
+            dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec, training)
             return dec_out[:, -self.pred_len:, :] # [B, L, D]
         elif self.task_name == 'classification':
             dec_out = self.classification(x_enc, x_mark_enc)
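Because Model.forward defaults to training=True, code that calls the model outside the experiment class must pass training=False at inference. A hedged usage sketch with toy shapes; 'model' is assumed to have been built as in _build_model above, and pred_len/label_len stand in for self.args.pred_len/self.args.label_len.

import torch

pred_len, label_len = 24, 12
batch_x = torch.randn(4, 96, 7)       # [B, seq_len, C]
batch_x_mark = torch.randn(4, 96, 4)  # time features
batch_y = torch.randn(4, label_len + pred_len, 7)
batch_y_mark = torch.randn(4, label_len + pred_len, 4)

# Decoder input built exactly as in the experiment code above.
dec_inp = torch.zeros_like(batch_y[:, -pred_len:, :])
dec_inp = torch.cat([batch_y[:, :label_len, :], dec_inp], dim=1)

model.eval()  # 'model' assumed available; eval() alone no longer controls the flag
with torch.no_grad():
    outputs = model(batch_x, batch_x_mark, dec_inp, batch_y_mark, training=False)
    outputs = outputs[:, -pred_len:, :]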