diff --git a/models/xPatch_SparseChannel.py b/models/xPatch_SparseChannel.py index 88c3561..c809073 100644 --- a/models/xPatch_SparseChannel.py +++ b/models/xPatch_SparseChannel.py @@ -92,7 +92,7 @@ class Model(nn.Module): seasonal_init, trend_init = self.decomp(x_enc) # Season stream - y_season = self.season_net(seasonal_init, encoder=self.season_encoder) # [B, C, pred_len] + y_season = self.season_net(seasonal_init) # [B, C, pred_len] # Trend stream B, L, C = trend_init.shape @@ -127,7 +127,7 @@ class Model(nn.Module): seasonal_init, trend_init = self.decomp(x_enc) # Season stream - y_season = self.season_net(seasonal_init, encoder=self.season_encoder) # [B, C, pred_len] + y_season = self.season_net(seasonal_init) # [B, C, pred_len] # print("shape:", trend_init.shape) # Trend stream