feat: implement dynamic threshold scheduling for GraphMixer
This commit is contained in:
@ -30,6 +30,10 @@ class SeasonPatch(nn.Module):
|
||||
headdim: int = 64,
|
||||
# Mixergraph 可选超参数
|
||||
thr_graph: float = 0.5,
|
||||
thr_graph_min: float = None,
|
||||
thr_graph_max: float = None,
|
||||
thr_graph_steps: int = 0,
|
||||
thr_graph_schedule: str = "linear",
|
||||
symmetric_graph: bool = True,
|
||||
degree_rescale: str = "count-sqrt", # "none" | "count" | "count-sqrt" | "sum"
|
||||
gate_temperature: float = 2./3.,
|
||||
@ -38,6 +42,9 @@ class SeasonPatch(nn.Module):
|
||||
|
||||
super().__init__()
|
||||
|
||||
# ===== 新增:保存 l0_lambda,防止 reg_loss 访问报错 =====
|
||||
self.l0_lambda = l0_lambda
|
||||
|
||||
# Store patch parameters
|
||||
self.patch_len = patch_len # patch 长度
|
||||
self.stride = stride # patch 步幅
|
||||
@ -60,6 +67,10 @@ class SeasonPatch(nn.Module):
|
||||
dim=d_model,
|
||||
max_degree=k_graph,
|
||||
thr=thr_graph,
|
||||
thr_min=thr_graph_min,
|
||||
thr_max=thr_graph_max,
|
||||
thr_steps=thr_graph_steps,
|
||||
thr_schedule=thr_graph_schedule,
|
||||
temperature=gate_temperature,
|
||||
tau_attn=tau_attn,
|
||||
symmetric=symmetric_graph,
|
||||
|
Reference in New Issue
Block a user