From 1044c60fe7a03f0e353f45fe2dbe1ad92ed8b2bb Mon Sep 17 00:00:00 2001
From: gameloader
Date: Wed, 10 Sep 2025 16:00:52 +0800
Subject: [PATCH] refactor(SeasonPatch): unify encoder and head initialization

---
 layers/SeasonPatch.py          | 81 +++++++++++++++-------------------
 models/xPatch_SparseChannel.py |  6 +--
 2 files changed, 39 insertions(+), 48 deletions(-)

diff --git a/layers/SeasonPatch.py b/layers/SeasonPatch.py
index d0ce2e1..5244377 100644
--- a/layers/SeasonPatch.py
+++ b/layers/SeasonPatch.py
@@ -19,6 +19,7 @@ class SeasonPatch(nn.Module):
                  patch_len: int,
                  stride: int,
                  k_graph: int = 8,
+                 encoder_type: str = "Transformer",
                  d_model: int = 128,
                  n_layers: int = 3,
                  n_heads: int = 16,
@@ -36,53 +37,46 @@ class SeasonPatch(nn.Module):
         # Calculate patch number
         patch_num = (seq_len - patch_len) // stride + 1    # patch_num: int
 
-        # Transformer (PatchTST) 编码器(channel independent)
-        self.encoder = TSTiEncoder(
-            c_in=c_in,
-            patch_num=patch_num,
-            patch_len=patch_len,
-            d_model=d_model,
-            n_layers=n_layers,
-            n_heads=n_heads
-        )
-
+        self.encoder_type = encoder_type
+        # 只初始化需要的encoder
+        if encoder_type == "Transformer":
+            # Transformer (PatchTST) 编码器(channel independent)
+            self.encoder = TSTiEncoder(
+                c_in=c_in, patch_num=patch_num, patch_len=patch_len,
+                d_model=d_model, n_layers=n_layers, n_heads=n_heads
+            )
+            self.mixer = HierarchicalGraphMixer(c_in, dim=d_model, k=k_graph)
+            # Prediction head(Transformer 路径用到,输入维度为 patch_num * d_model)
+            self.head = nn.Sequential(
+                nn.Linear(patch_num * d_model, patch_num * d_model),
+                nn.SiLU(),
+                nn.Linear(patch_num * d_model, pred_len)
+            )
+        elif encoder_type == "Mamba2":
         # Mamba2 编码器(channel independent),返回 [B, C, d_model]
-        self.mamba_encoder = Mamba2Encoder(
-            c_in=c_in,
-            patch_num=patch_num,
-            patch_len=patch_len,
-            d_model=d_model,
-            d_state=d_state,
-            d_conv=d_conv,
-            expand=expand,
-            headdim=headdim,
-        )
+            self.encoder = Mamba2Encoder(
+                c_in=c_in, patch_num=patch_num, patch_len=patch_len,
+                d_model=d_model, d_state=d_state,
+                d_conv=d_conv,
+                expand=expand, headdim=headdim
+            )
+            # Prediction head(Mamba2 路径用到,输入维度为 d_model)
+            self.head = nn.Sequential(
+                nn.Linear(d_model, d_model),
+                nn.SiLU(),
+                nn.Linear(d_model, pred_len)
+            )
+        else:
+            raise ValueError(f"Unsupported encoder_type: {encoder_type}")
 
-        # Cross-channel mixer(仅 Transformer 路径使用)
-        self.mixer = HierarchicalGraphMixer(c_in, dim=d_model, k=k_graph)
-
-        # Prediction head(Transformer 路径用到,输入维度为 patch_num * d_model)
-        self.head_tr = nn.Sequential(
-            nn.Linear(patch_num * d_model, patch_num * d_model),
-            nn.SiLU(),
-            nn.Linear(patch_num * d_model, pred_len)
-        )
-
-        # Prediction head(Mamba2 路径用到,输入维度为 d_model)
-        self.head_mamba = nn.Sequential(
-            nn.Linear(d_model, d_model),
-            nn.SiLU(),
-            nn.Linear(d_model, pred_len)
-        )
-
-    def forward(self, x, encoder="Transformer"):
+    def forward(self, x ):
         # x: [B, L, C]
         x = x.permute(0, 2, 1)              # x: [B, C, L]
 
         # Patch the input
         x_patch = x.unfold(-1, self.patch_len, self.stride)
                                             # x_patch: [B, C, patch_num, patch_len]
 
-        if encoder == "Transformer":
+        if self.encoder_type == "Transformer":
             # Encode patches (PatchTST)
             z = self.encoder(x_patch)       # z: [B, C, d_model, patch_num]
@@ -95,14 +89,11 @@ class SeasonPatch(nn.Module):
 
             # Flatten and predict
             z_mix = z_mix.view(B, C, N * D) # z_mix: [B, C, patch_num * d_model]
-            y_pred = self.head_tr(z_mix)    # y_pred: [B, C, pred_len]
+            y_pred = self.head(z_mix)       # y_pred: [B, C, pred_len]
 
-        elif encoder == "Mamba2":
+        elif self.encoder_type == "Mamba2":
             # 使用 Mamba2 编码器(不使用 mixer)
-            z_last = self.mamba_encoder(x_patch)  # z_last: [B, C, d_model](仅最后一个时间步)
-            y_pred = self.head_mamba(z_last)      # y_pred: [B, C, pred_len]
-
-        else:
-            raise ValueError(f"Unsupported encoder type: {encoder}. Use 'Transformer' or 'Mamba2'.")
+            z_last = self.encoder(x_patch)  # z_last: [B, C, d_model](仅最后一个时间步)
+            y_pred = self.head(z_last)      # y_pred: [B, C, pred_len]
 
         return y_pred  # [B, C, pred_len]
diff --git a/models/xPatch_SparseChannel.py b/models/xPatch_SparseChannel.py
index af6e599..88c3561 100644
--- a/models/xPatch_SparseChannel.py
+++ b/models/xPatch_SparseChannel.py
@@ -37,8 +37,6 @@ class Model(nn.Module):
         beta = getattr(configs, 'beta', torch.tensor(0.1))
         self.decomp = DECOMP(ma_type, alpha, beta)
 
-        # 读取选择的编码器类型('Transformer' 或 'Mamba2')
-        self.season_encoder = getattr(configs, 'season_encoder', 'Transformer')
 
         # Season network (PatchTST + Graph Mixer)
         self.season_net = SeasonPatch(
@@ -49,7 +47,9 @@ class Model(nn.Module):
             k_graph=getattr(configs, 'k_graph', 8),
             d_model=getattr(configs, 'd_model', 128),
             n_layers=getattr(configs, 'e_layers', 3),
-            n_heads=getattr(configs, 'n_heads', 16)
+            n_heads=getattr(configs, 'n_heads', 16),
+            # 读取选择的编码器类型('Transformer' 或 'Mamba2')
+            encoder_type = getattr(configs, 'season_encoder', 'Transformer')
         )
 
         # Trend network (MLP)