feat(mamba): add Mamba2 encoder option to SeasonPatch

This commit is contained in:
gameloader
2025-09-10 15:51:28 +08:00
parent 96c40c6ab6
commit 908d3a7080
3 changed files with 138 additions and 37 deletions

View File

@ -37,6 +37,8 @@ class Model(nn.Module):
beta = getattr(configs, 'beta', torch.tensor(0.1))
self.decomp = DECOMP(ma_type, alpha, beta)
# 读取选择的编码器类型('Transformer' 或 'Mamba2'
self.season_encoder = getattr(configs, 'season_encoder', 'Transformer')
# Season network (PatchTST + Graph Mixer)
self.season_net = SeasonPatch(
c_in=self.enc_in,
@ -90,7 +92,7 @@ class Model(nn.Module):
seasonal_init, trend_init = self.decomp(x_enc)
# Season stream
y_season = self.season_net(seasonal_init) # [B, C, pred_len]
y_season = self.season_net(seasonal_init, encoder=self.season_encoder) # [B, C, pred_len]
# Trend stream
B, L, C = trend_init.shape
@ -125,7 +127,7 @@ class Model(nn.Module):
seasonal_init, trend_init = self.decomp(x_enc)
# Season stream
y_season = self.season_net(seasonal_init) # [B, C, pred_len]
y_season = self.season_net(seasonal_init, encoder=self.season_encoder) # [B, C, pred_len]
# print("shape:", trend_init.shape)
# Trend stream
@ -163,4 +165,4 @@ class Model(nn.Module):
dec_out = self.classification(x_enc, x_mark_enc)
return dec_out # [B, N]
else:
raise ValueError(f'Task {self.task_name} not supported by xPatch_SparseChannel')
raise ValueError(f'Task {self.task_name} not supported by xPatch_SparseChannel')