feat(model): introduce dynamic training flag for model forward pass

This commit is contained in:
gameloader
2025-09-13 00:04:01 +08:00
parent 172328a4e6
commit 93f14077da
4 changed files with 40 additions and 22 deletions

View File

@ -100,7 +100,7 @@ class Model(nn.Module):
nn.Linear(128, configs.num_class)
)
def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec, training):
"""Long-term forecasting"""
# Normalization
if self.revin:
@ -110,7 +110,7 @@ class Model(nn.Module):
seasonal_init, trend_init = self.decomp(x_enc)
# Season stream
y_season = self.season_net(seasonal_init) # [B, C, pred_len]
y_season = self.season_net(seasonal_init, training) # [B, C, pred_len]
# Trend stream
B, L, C = trend_init.shape
@ -169,10 +169,10 @@ class Model(nn.Module):
logits = self.classifier(features) # [B, num_classes]
return logits
def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None, training=True):
"""Forward pass dispatching to task-specific methods"""
if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec, training)
return dec_out[:, -self.pred_len:, :] # [B, L, D]
elif self.task_name == 'classification':
dec_out = self.classification(x_enc, x_mark_enc)