feat(graph-mixer): implement L0 sparsity with Hard-Concrete gate for channel selection

This commit is contained in:
gameloader
2025-09-11 16:50:58 +08:00
parent 5fc0da4239
commit 204d17086a
4 changed files with 268 additions and 124 deletions

View File

@ -27,7 +27,15 @@ class SeasonPatch(nn.Module):
d_state: int = 64,
d_conv: int = 4,
expand: int = 2,
headdim: int = 64):
headdim: int = 64,
# Mixergraph 可选超参数
thr_graph: float = 0.5,
symmetric_graph: bool = True,
degree_rescale: str = "count-sqrt", # "none" | "count" | "count-sqrt" | "sum"
gate_temperature: float = 2./3.,
tau_attn: float = 1.0,
l0_lambda: float = 1e-4):
super().__init__()
# Store patch parameters
@ -46,7 +54,17 @@ class SeasonPatch(nn.Module):
c_in=c_in, patch_num=patch_num, patch_len=patch_len,
d_model=d_model, n_layers=n_layers, n_heads=n_heads
)
self.mixer = HierarchicalGraphMixer(c_in, dim=d_model, k=k_graph)
# Integrate the new HierarchicalGraphMixer (non-normalized)
self.mixer = HierarchicalGraphMixer(
n_channel=c_in,
dim=d_model,
max_degree=k_graph,
thr=thr_graph,
temperature=gate_temperature,
tau_attn=tau_attn,
symmetric=symmetric_graph,
degree_rescale=degree_rescale
)
# Prediction head (used on the Transformer path; input dimension is patch_num * d_model)
self.head = nn.Sequential(
nn.Linear(patch_num * d_model, patch_num * d_model),
@ -97,3 +115,11 @@ class SeasonPatch(nn.Module):
y_pred = self.head(z_last) # y_pred: [B, C, pred_len]
return y_pred # [B, C, pred_len]
def reg_loss(self):
    """
    Optional hook that exposes the L0 regularization term so the training
    loop can add it to the total loss.

    Returns:
        ``self.mixer.l0_loss(self.l0_lambda)`` when ``encoder_type`` is
        ``"Transformer"`` and a ``mixer`` attribute is present; otherwise a
        scalar zero tensor placed on the device of the first head layer's
        weight.
    """
    # Short-circuit order matters: `mixer` is only probed on the Transformer path.
    has_graph_mixer = self.encoder_type == "Transformer" and hasattr(self, "mixer")
    if not has_graph_mixer:
        # No mixer on this path: contribute a zero penalty on the module's device.
        head_device = self.head[0].weight.device
        return torch.tensor(0.0, device=head_device)
    return self.mixer.l0_loss(self.l0_lambda)