feat: implement dynamic threshold scheduling for GraphMixer
This commit is contained in:
@ -53,6 +53,10 @@ class HierarchicalGraphMixer(nn.Module):
|
||||
dim: int,
|
||||
max_degree: int = None, # 可选:限制每行最多边数
|
||||
thr: float = 0.5, # 保留边阈值,例如 0.5/0.7
|
||||
thr_min: float = None, # 动态阈值起点,不传则用 thr
|
||||
thr_max: float = None, # 动态阈值终点,不传则用 thr
|
||||
thr_steps: int = 0, # 从 thr_min -> thr_max 的步数,>0 时启用动态调度
|
||||
thr_schedule: str = "linear", # "linear" | "cosine" | "exp"
|
||||
temperature: float = 2./3.,
|
||||
tau_attn: float = 1.0, # Patch attention 温度(可选)
|
||||
symmetric: bool = True, # 是否对称化通道图
|
||||
@ -67,6 +71,13 @@ class HierarchicalGraphMixer(nn.Module):
|
||||
self.tau_attn = tau_attn
|
||||
self.symmetric = symmetric
|
||||
self.degree_rescale = degree_rescale
|
||||
self.thr_min = thr if (thr_min is None) else float(thr_min)
|
||||
self.thr_max = thr if (thr_max is None) else float(thr_max)
|
||||
self.thr_steps = int(thr_steps) if thr_steps is not None else 0
|
||||
self.thr_schedule = thr_schedule
|
||||
self._use_dynamic_thr = (self.thr_steps > 0) and (abs(self.thr_max - self.thr_min) > 1e-12)
|
||||
# 用 buffer 记录已步进次数(不保存到权重里)
|
||||
self.register_buffer("_thr_step", torch.zeros((), dtype=torch.long), persistent=False)
|
||||
|
||||
# Level 1: 非归一化门控
|
||||
self.gate = HardConcreteGate(
|
||||
@ -88,6 +99,30 @@ class HierarchicalGraphMixer(nn.Module):
|
||||
self.out_proj = nn.Linear(dim, dim)
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
|
||||
def _compute_thr_by_progress(self, progress: float) -> float:
    """Interpolate the edge-keeping threshold for a given training progress.

    ``progress`` is converted to ``float`` and clamped to ``[0, 1]``. It is
    then turned into an interpolation weight ``g`` according to
    ``self.thr_schedule``:

    * ``"linear"``: ``g = progress``.
    * ``"cosine"``: ``g = 0.5 - 0.5 * cos(pi * progress)``, a half-cosine
      ramp that is flat near both ends and steepest in the middle.
    * ``"exp"``: ``g = (exp(k * progress) - 1) / (exp(k) - 1)`` with
      ``k = 5``; this curve stays close to 0 for most of the range and
      rises steeply as ``progress`` approaches 1.
    * any other value: handled exactly like ``"linear"``.

    Args:
        progress: Fraction of the schedule that has elapsed. Values outside
            ``[0, 1]`` are clamped.

    Returns:
        ``self.thr_min + (self.thr_max - self.thr_min) * g``.
    """
    import math

    # Clamp in two steps, keeping the same argument order as a nested
    # max(0.0, min(1.0, x)) call.
    capped = min(1.0, float(progress))
    frac = max(0.0, capped)

    schedule = self.thr_schedule
    if schedule == "cosine":
        weight = 0.5 - 0.5 * math.cos(math.pi * frac)
    elif schedule == "exp":
        rate = 5.0
        # Normalised so that weight is 0 at frac == 0 and 1 at frac == 1.
        weight = (math.exp(rate * frac) - 1.0) / (math.exp(rate) - 1.0)
    else:
        # "linear", and the fallback for any unrecognised schedule name.
        weight = frac

    span = self.thr_max - self.thr_min
    return self.thr_min + span * weight
|
||||
def _maybe_update_thr(self):
    """Advance the dynamic threshold schedule by one step.

    Does nothing unless the module is in training mode and
    ``self._use_dynamic_thr`` is set. Otherwise it:

    1. reads the number of steps taken so far from ``self._thr_step``,
    2. converts it to a progress fraction ``step / thr_steps``,
    3. stores the resulting threshold in ``self.thr``,
    4. increments ``self._thr_step`` by one.

    The counter is never capped here; progress values above 1 are passed
    on to ``_compute_thr_by_progress`` unchanged.
    """
    if not self.training or not self._use_dynamic_thr:
        return

    # NOTE(review): `.item()` copies the counter to a Python number; if the
    # counter lives on an accelerator this likely forces a host sync on
    # every call — confirm whether that matters here.
    steps_done = int(self._thr_step.item())
    fraction = steps_done / float(self.thr_steps)

    new_thr = self._compute_thr_by_progress(fraction)
    self.thr = float(new_thr)

    # Increment only after the threshold for the current step is set, so
    # the very first call uses progress 0.
    self._thr_step += 1
|
||||
|
||||
def _build_sparse_neighbors(self, z_gate):
|
||||
"""
|
||||
基于 z_gate 构造每行的邻接列表(按阈值与可选top-k)。
|
||||
@ -151,6 +186,7 @@ class HierarchicalGraphMixer(nn.Module):
|
||||
return lam * self.gate.expected_l0().sum()
|
||||
|
||||
def forward(self, z):
|
||||
self._maybe_update_thr()
|
||||
# z: [B, C, N, D]
|
||||
B, C, N, D = z.shape
|
||||
assert C == self.C and D == self.dim
|
||||
|
Reference in New Issue
Block a user