feat(model): introduce dynamic training flag for model forward pass
This commit is contained in:
@ -100,7 +100,7 @@ class Model(nn.Module):
|
||||
nn.Linear(128, configs.num_class)
|
||||
)
|
||||
|
||||
def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
|
||||
def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec, training):
|
||||
"""Long-term forecasting"""
|
||||
# Normalization
|
||||
if self.revin:
|
||||
@ -110,7 +110,7 @@ class Model(nn.Module):
|
||||
seasonal_init, trend_init = self.decomp(x_enc)
|
||||
|
||||
# Season stream
|
||||
y_season = self.season_net(seasonal_init) # [B, C, pred_len]
|
||||
y_season = self.season_net(seasonal_init, training) # [B, C, pred_len]
|
||||
|
||||
# Trend stream
|
||||
B, L, C = trend_init.shape
|
||||
@ -169,10 +169,10 @@ class Model(nn.Module):
|
||||
logits = self.classifier(features) # [B, num_classes]
|
||||
return logits
|
||||
|
||||
def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
|
||||
def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None, training=True):
|
||||
"""Forward pass dispatching to task-specific methods"""
|
||||
if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
|
||||
dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
|
||||
dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec, training)
|
||||
return dec_out[:, -self.pred_len:, :] # [B, L, D]
|
||||
elif self.task_name == 'classification':
|
||||
dec_out = self.classification(x_enc, x_mark_enc)
|
||||
|
Reference in New Issue
Block a user