refactor(SeasonPatch): unify encoder and head initialization

gameloader
2025-09-10 16:00:52 +08:00
parent 908d3a7080
commit 1044c60fe7
2 changed files with 39 additions and 48 deletions


@@ -19,6 +19,7 @@ class SeasonPatch(nn.Module):
         patch_len: int,
         stride: int,
         k_graph: int = 8,
+        encoder_type: str = "Transformer",
         d_model: int = 128,
         n_layers: int = 3,
         n_heads: int = 16,
@@ -36,53 +37,46 @@ class SeasonPatch(nn.Module):
         # Calculate patch number
         patch_num = (seq_len - patch_len) // stride + 1  # patch_num: int

-        # Transformer (PatchTST) encoder (channel independent)
-        self.encoder = TSTiEncoder(
-            c_in=c_in,
-            patch_num=patch_num,
-            patch_len=patch_len,
-            d_model=d_model,
-            n_layers=n_layers,
-            n_heads=n_heads
-        )
-
-        # Mamba2 encoder (channel independent); returns [B, C, d_model]
-        self.mamba_encoder = Mamba2Encoder(
-            c_in=c_in,
-            patch_num=patch_num,
-            patch_len=patch_len,
-            d_model=d_model,
-            d_state=d_state,
-            d_conv=d_conv,
-            expand=expand,
-            headdim=headdim,
-        )
-
-        # Cross-channel mixer (used only by the Transformer path)
-        self.mixer = HierarchicalGraphMixer(c_in, dim=d_model, k=k_graph)
-
-        # Prediction head (Transformer path); input dim is patch_num * d_model
-        self.head_tr = nn.Sequential(
-            nn.Linear(patch_num * d_model, patch_num * d_model),
-            nn.SiLU(),
-            nn.Linear(patch_num * d_model, pred_len)
-        )
-
-        # Prediction head (Mamba2 path); input dim is d_model
-        self.head_mamba = nn.Sequential(
-            nn.Linear(d_model, d_model),
-            nn.SiLU(),
-            nn.Linear(d_model, pred_len)
-        )
-
-    def forward(self, x, encoder="Transformer"):
+        self.encoder_type = encoder_type
+        # Initialize only the encoder that is actually needed
+        if encoder_type == "Transformer":
+            # Transformer (PatchTST) encoder (channel independent)
+            self.encoder = TSTiEncoder(
+                c_in=c_in, patch_num=patch_num, patch_len=patch_len,
+                d_model=d_model, n_layers=n_layers, n_heads=n_heads
+            )
+            # Cross-channel mixer (used only by the Transformer path)
+            self.mixer = HierarchicalGraphMixer(c_in, dim=d_model, k=k_graph)
+            # Prediction head (Transformer path); input dim is patch_num * d_model
+            self.head = nn.Sequential(
+                nn.Linear(patch_num * d_model, patch_num * d_model),
+                nn.SiLU(),
+                nn.Linear(patch_num * d_model, pred_len)
+            )
+        elif encoder_type == "Mamba2":
+            # Mamba2 encoder (channel independent); returns [B, C, d_model]
+            self.encoder = Mamba2Encoder(
+                c_in=c_in, patch_num=patch_num, patch_len=patch_len,
+                d_model=d_model, d_state=d_state, d_conv=d_conv,
+                expand=expand, headdim=headdim
+            )
+            # Prediction head (Mamba2 path); input dim is d_model
+            self.head = nn.Sequential(
+                nn.Linear(d_model, d_model),
+                nn.SiLU(),
+                nn.Linear(d_model, pred_len)
+            )
+        else:
+            raise ValueError(f"Unsupported encoder_type: {encoder_type}")
+
+    def forward(self, x):
         # x: [B, L, C]
         x = x.permute(0, 2, 1)  # x: [B, C, L]

         # Patch the input
         x_patch = x.unfold(-1, self.patch_len, self.stride)  # x_patch: [B, C, patch_num, patch_len]

-        if encoder == "Transformer":
+        if self.encoder_type == "Transformer":
             # Encode patches (PatchTST)
             z = self.encoder(x_patch)  # z: [B, C, d_model, patch_num]
@@ -95,14 +89,11 @@ class SeasonPatch(nn.Module):
             # Flatten and predict
             z_mix = z_mix.view(B, C, N * D)  # z_mix: [B, C, patch_num * d_model]
-            y_pred = self.head_tr(z_mix)  # y_pred: [B, C, pred_len]
-        elif encoder == "Mamba2":
+            y_pred = self.head(z_mix)  # y_pred: [B, C, pred_len]
+        elif self.encoder_type == "Mamba2":
             # Use the Mamba2 encoder (the mixer is not used)
-            z_last = self.mamba_encoder(x_patch)  # z_last: [B, C, d_model] (last time step only)
-            y_pred = self.head_mamba(z_last)  # y_pred: [B, C, pred_len]
-        else:
-            raise ValueError(f"Unsupported encoder type: {encoder}. Use 'Transformer' or 'Mamba2'.")
+            z_last = self.encoder(x_patch)  # z_last: [B, C, d_model] (last time step only)
+            y_pred = self.head(z_last)  # y_pred: [B, C, pred_len]

         return y_pred  # [B, C, pred_len]
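
After this change the encoder is chosen once at construction time, and forward() no longer takes a per-call encoder argument. A minimal usage sketch of the new interface; the import path, the names of the core size arguments (c_in, seq_len, pred_len), and the concrete shapes are assumptions for illustration, not taken from this diff:

    import torch
    from models.SeasonPatch import SeasonPatch  # assumed module path

    # encoder_type is now fixed at construction time; only the chosen
    # encoder and its matching prediction head (self.head) are built.
    season = SeasonPatch(
        c_in=7, seq_len=96, pred_len=24,   # assumed argument names
        patch_len=16, stride=8,
        encoder_type="Mamba2",             # or the default "Transformer"
    )

    x = torch.randn(32, 96, 7)             # [B, L, C]
    y = season(x)                          # [B, C, pred_len] -> [32, 7, 24]

Switching encoders now means constructing a new module rather than passing a different flag to forward(), so a checkpoint only contains weights for the encoder it was trained with, instead of both encoders and both heads as before.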


@@ -37,8 +37,6 @@ class Model(nn.Module):
         beta = getattr(configs, 'beta', torch.tensor(0.1))
         self.decomp = DECOMP(ma_type, alpha, beta)

-        # Read the selected encoder type ('Transformer' or 'Mamba2')
-        self.season_encoder = getattr(configs, 'season_encoder', 'Transformer')
         # Season network (PatchTST + Graph Mixer)
         self.season_net = SeasonPatch(
             c_in=self.enc_in,
@@ -49,7 +47,9 @@ class Model(nn.Module):
             k_graph=getattr(configs, 'k_graph', 8),
             d_model=getattr(configs, 'd_model', 128),
             n_layers=getattr(configs, 'e_layers', 3),
-            n_heads=getattr(configs, 'n_heads', 16)
+            n_heads=getattr(configs, 'n_heads', 16),
+            # Read the selected encoder type ('Transformer' or 'Mamba2')
+            encoder_type=getattr(configs, 'season_encoder', 'Transformer')
         )

         # Trend network (MLP)
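
Downstream, the only knob is the season_encoder attribute on the experiment config, which Model now forwards to SeasonPatch as encoder_type instead of storing it on itself. A hypothetical minimal config; season_encoder and the getattr field names (k_graph, d_model, e_layers, n_heads) come from this diff, while the remaining fields are assumed to be required elsewhere in Model:

    from types import SimpleNamespace

    configs = SimpleNamespace(
        enc_in=7, seq_len=96, pred_len=24,  # assumed field names
        k_graph=8, d_model=128, e_layers=3, n_heads=16,
        season_encoder='Mamba2',            # omit to fall back to 'Transformer'
    )
    model = Model(configs)                  # Model is the class patched above

Because the value is read with getattr(configs, 'season_encoder', 'Transformer'), configs written before this commit keep working and select the Transformer path by default.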