feat(mamba): add Mamba2 encoder option to SeasonPatch
This commit is contained in:
@ -37,6 +37,8 @@ class Model(nn.Module):
|
||||
beta = getattr(configs, 'beta', torch.tensor(0.1))
|
||||
self.decomp = DECOMP(ma_type, alpha, beta)
|
||||
|
||||
# 读取选择的编码器类型('Transformer' 或 'Mamba2')
|
||||
self.season_encoder = getattr(configs, 'season_encoder', 'Transformer')
|
||||
# Season network (PatchTST + Graph Mixer)
|
||||
self.season_net = SeasonPatch(
|
||||
c_in=self.enc_in,
|
||||
@ -90,7 +92,7 @@ class Model(nn.Module):
|
||||
seasonal_init, trend_init = self.decomp(x_enc)
|
||||
|
||||
# Season stream
|
||||
y_season = self.season_net(seasonal_init) # [B, C, pred_len]
|
||||
y_season = self.season_net(seasonal_init, encoder=self.season_encoder) # [B, C, pred_len]
|
||||
|
||||
# Trend stream
|
||||
B, L, C = trend_init.shape
|
||||
@ -125,7 +127,7 @@ class Model(nn.Module):
|
||||
seasonal_init, trend_init = self.decomp(x_enc)
|
||||
|
||||
# Season stream
|
||||
y_season = self.season_net(seasonal_init) # [B, C, pred_len]
|
||||
y_season = self.season_net(seasonal_init, encoder=self.season_encoder) # [B, C, pred_len]
|
||||
|
||||
# print("shape:", trend_init.shape)
|
||||
# Trend stream
|
||||
@ -163,4 +165,4 @@ class Model(nn.Module):
|
||||
dec_out = self.classification(x_enc, x_mark_enc)
|
||||
return dec_out # [B, N]
|
||||
else:
|
||||
raise ValueError(f'Task {self.task_name} not supported by xPatch_SparseChannel')
|
||||
raise ValueError(f'Task {self.task_name} not supported by xPatch_SparseChannel')
|
||||
|
Reference in New Issue
Block a user