feat: implement dynamic threshold scheduling for GraphMixer

This commit is contained in:
gameloader
2025-09-12 17:02:42 +08:00
parent 6a1f9d30f3
commit 172328a4e6
4 changed files with 105 additions and 41 deletions

View File

@ -30,6 +30,10 @@ class SeasonPatch(nn.Module):
headdim: int = 64,
# Mixergraph 可选超参数
thr_graph: float = 0.5,
thr_graph_min: float = None,
thr_graph_max: float = None,
thr_graph_steps: int = 0,
thr_graph_schedule: str = "linear",
symmetric_graph: bool = True,
degree_rescale: str = "count-sqrt", # "none" | "count" | "count-sqrt" | "sum"
gate_temperature: float = 2./3.,
@ -38,6 +42,9 @@ class SeasonPatch(nn.Module):
super().__init__()
# ===== 新增:保存 l0_lambda防止 reg_loss 访问报错 =====
self.l0_lambda = l0_lambda
# Store patch parameters
self.patch_len = patch_len # patch 长度
self.stride = stride # patch 步幅
@ -60,6 +67,10 @@ class SeasonPatch(nn.Module):
dim=d_model,
max_degree=k_graph,
thr=thr_graph,
thr_min=thr_graph_min,
thr_max=thr_graph_max,
thr_steps=thr_graph_steps,
thr_schedule=thr_graph_schedule,
temperature=gate_temperature,
tau_attn=tau_attn,
symmetric=symmetric_graph,