feat(model): introduce dynamic training flag for model forward pass

This commit is contained in:
gameloader
2025-09-13 00:04:01 +08:00
parent 172328a4e6
commit 93f14077da
4 changed files with 40 additions and 22 deletions

View File

@ -116,8 +116,8 @@ class HierarchicalGraphMixer(nn.Module):
else:
g = progress
return self.thr_min + (self.thr_max - self.thr_min) * g
def _maybe_update_thr(self):
if self.training and self._use_dynamic_thr:
def _maybe_update_thr(self, training):
if training and self._use_dynamic_thr:
step = int(self._thr_step.item())
progress = step / float(self.thr_steps)
self.thr = float(self._compute_thr_by_progress(progress))
@ -185,14 +185,14 @@ class HierarchicalGraphMixer(nn.Module):
"""
return lam * self.gate.expected_l0().sum()
def forward(self, z):
self._maybe_update_thr()
def forward(self, z, is_training):
self._maybe_update_thr(training=is_training)
# z: [B, C, N, D]
B, C, N, D = z.shape
assert C == self.C and D == self.dim
# Level 1: 采样非归一化门 z_gate ∈ [0,1]
z_gate = self.gate.sample(training=self.training) # [C, C]
z_gate = self.gate.sample(training=is_training) # [C, C]
# 构建稀疏邻居(阈值 + 可选 top-k
idx_list, w_list = self._build_sparse_neighbors(z_gate)