refactor(SeasonPatch): unify encoder and head initialization
This commit is contained in:
@ -19,6 +19,7 @@ class SeasonPatch(nn.Module):
|
||||
patch_len: int,
|
||||
stride: int,
|
||||
k_graph: int = 8,
|
||||
encoder_type: str = "Transformer",
|
||||
d_model: int = 128,
|
||||
n_layers: int = 3,
|
||||
n_heads: int = 16,
|
||||
@ -36,53 +37,46 @@ class SeasonPatch(nn.Module):
|
||||
# Calculate patch number
|
||||
patch_num = (seq_len - patch_len) // stride + 1 # patch_num: int
|
||||
|
||||
self.encoder_type = encoder_type
|
||||
# 只初始化需要的encoder
|
||||
|
||||
if encoder_type == "Transformer":
|
||||
# Transformer (PatchTST) 编码器(channel independent)
|
||||
self.encoder = TSTiEncoder(
|
||||
c_in=c_in,
|
||||
patch_num=patch_num,
|
||||
patch_len=patch_len,
|
||||
d_model=d_model,
|
||||
n_layers=n_layers,
|
||||
n_heads=n_heads
|
||||
c_in=c_in, patch_num=patch_num, patch_len=patch_len,
|
||||
d_model=d_model, n_layers=n_layers, n_heads=n_heads
|
||||
)
|
||||
|
||||
# Mamba2 编码器(channel independent),返回 [B, C, d_model]
|
||||
self.mamba_encoder = Mamba2Encoder(
|
||||
c_in=c_in,
|
||||
patch_num=patch_num,
|
||||
patch_len=patch_len,
|
||||
d_model=d_model,
|
||||
d_state=d_state,
|
||||
d_conv=d_conv,
|
||||
expand=expand,
|
||||
headdim=headdim,
|
||||
)
|
||||
|
||||
# Cross-channel mixer(仅 Transformer 路径使用)
|
||||
self.mixer = HierarchicalGraphMixer(c_in, dim=d_model, k=k_graph)
|
||||
|
||||
# Prediction head(Transformer 路径用到,输入维度为 patch_num * d_model)
|
||||
self.head_tr = nn.Sequential(
|
||||
self.head = nn.Sequential(
|
||||
nn.Linear(patch_num * d_model, patch_num * d_model),
|
||||
nn.SiLU(),
|
||||
nn.Linear(patch_num * d_model, pred_len)
|
||||
)
|
||||
|
||||
elif encoder_type == "Mamba2":
|
||||
# Mamba2 编码器(channel independent),返回 [B, C, d_model]
|
||||
self.encoder = Mamba2Encoder(
|
||||
c_in=c_in, patch_num=patch_num, patch_len=patch_len,
|
||||
d_model=d_model, d_state=d_state, d_conv=d_conv,
|
||||
expand=expand, headdim=headdim
|
||||
)
|
||||
# Prediction head(Mamba2 路径用到,输入维度为 d_model)
|
||||
self.head_mamba = nn.Sequential(
|
||||
self.head = nn.Sequential(
|
||||
nn.Linear(d_model, d_model),
|
||||
nn.SiLU(),
|
||||
nn.Linear(d_model, pred_len)
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported encoder_type: {encoder_type}")
|
||||
|
||||
def forward(self, x, encoder="Transformer"):
|
||||
def forward(self, x ):
|
||||
# x: [B, L, C]
|
||||
x = x.permute(0, 2, 1) # x: [B, C, L]
|
||||
|
||||
# Patch the input
|
||||
x_patch = x.unfold(-1, self.patch_len, self.stride) # x_patch: [B, C, patch_num, patch_len]
|
||||
|
||||
if encoder == "Transformer":
|
||||
if self.encoder_type == "Transformer":
|
||||
# Encode patches (PatchTST)
|
||||
z = self.encoder(x_patch) # z: [B, C, d_model, patch_num]
|
||||
|
||||
@ -95,14 +89,11 @@ class SeasonPatch(nn.Module):
|
||||
|
||||
# Flatten and predict
|
||||
z_mix = z_mix.view(B, C, N * D) # z_mix: [B, C, patch_num * d_model]
|
||||
y_pred = self.head_tr(z_mix) # y_pred: [B, C, pred_len]
|
||||
y_pred = self.head(z_mix) # y_pred: [B, C, pred_len]
|
||||
|
||||
elif encoder == "Mamba2":
|
||||
elif self.encoder_type == "Mamba2":
|
||||
# 使用 Mamba2 编码器(不使用 mixer)
|
||||
z_last = self.mamba_encoder(x_patch) # z_last: [B, C, d_model](仅最后一个时间步)
|
||||
y_pred = self.head_mamba(z_last) # y_pred: [B, C, pred_len]
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported encoder type: {encoder}. Use 'Transformer' or 'Mamba2'.")
|
||||
z_last = self.encoder(x_patch) # z_last: [B, C, d_model](仅最后一个时间步)
|
||||
y_pred = self.head(z_last) # y_pred: [B, C, pred_len]
|
||||
|
||||
return y_pred # [B, C, pred_len]
|
||||
|
@ -37,8 +37,6 @@ class Model(nn.Module):
|
||||
beta = getattr(configs, 'beta', torch.tensor(0.1))
|
||||
self.decomp = DECOMP(ma_type, alpha, beta)
|
||||
|
||||
# 读取选择的编码器类型('Transformer' 或 'Mamba2')
|
||||
self.season_encoder = getattr(configs, 'season_encoder', 'Transformer')
|
||||
# Season network (PatchTST + Graph Mixer)
|
||||
self.season_net = SeasonPatch(
|
||||
c_in=self.enc_in,
|
||||
@ -49,7 +47,9 @@ class Model(nn.Module):
|
||||
k_graph=getattr(configs, 'k_graph', 8),
|
||||
d_model=getattr(configs, 'd_model', 128),
|
||||
n_layers=getattr(configs, 'e_layers', 3),
|
||||
n_heads=getattr(configs, 'n_heads', 16)
|
||||
n_heads=getattr(configs, 'n_heads', 16),
|
||||
# 读取选择的编码器类型('Transformer' 或 'Mamba2')
|
||||
encoder_type = getattr(configs, 'season_encoder', 'Transformer')
|
||||
)
|
||||
|
||||
# Trend network (MLP)
|
||||
|
Reference in New Issue
Block a user