feat(model): introduce dynamic training flag for model forward pass
This commit is contained in:
@@ -116,8 +116,8 @@ class HierarchicalGraphMixer(nn.Module):
|
||||
else:
|
||||
g = progress
|
||||
return self.thr_min + (self.thr_max - self.thr_min) * g
|
||||
def _maybe_update_thr(self):
|
||||
if self.training and self._use_dynamic_thr:
|
||||
def _maybe_update_thr(self, training):
|
||||
if training and self._use_dynamic_thr:
|
||||
step = int(self._thr_step.item())
|
||||
progress = step / float(self.thr_steps)
|
||||
self.thr = float(self._compute_thr_by_progress(progress))
|
||||
@@ -185,14 +185,14 @@ class HierarchicalGraphMixer(nn.Module):
|
||||
"""
|
||||
return lam * self.gate.expected_l0().sum()
|
||||
|
||||
def forward(self, z):
|
||||
self._maybe_update_thr()
|
||||
def forward(self, z, is_training):
|
||||
self._maybe_update_thr(training=is_training)
|
||||
# z: [B, C, N, D]
|
||||
B, C, N, D = z.shape
|
||||
assert C == self.C and D == self.dim
|
||||
|
||||
# Level 1: 采样非归一化门 z_gate ∈ [0,1]
|
||||
z_gate = self.gate.sample(training=self.training) # [C, C]
|
||||
z_gate = self.gate.sample(training=is_training) # [C, C]
|
||||
|
||||
# 构建稀疏邻居(阈值 + 可选 top-k)
|
||||
idx_list, w_list = self._build_sparse_neighbors(z_gate)
|
||||
|
@@ -98,7 +98,7 @@ class SeasonPatch(nn.Module):
|
||||
else:
|
||||
raise ValueError(f"Unsupported encoder_type: {encoder_type}")
|
||||
|
||||
def forward(self, x ):
|
||||
def forward(self, x, training):
|
||||
# x: [B, L, C]
|
||||
x = x.permute(0, 2, 1) # x: [B, C, L]
|
||||
|
||||
@@ -114,7 +114,7 @@ class SeasonPatch(nn.Module):
|
||||
z = z.permute(0, 1, 3, 2) # z: [B, C, patch_num, d_model]
|
||||
|
||||
# Cross-channel mixing
|
||||
z_mix = self.mixer(z) # z_mix: [B, C, patch_num, d_model]
|
||||
z_mix = self.mixer(z, training) # z_mix: [B, C, patch_num, d_model]
|
||||
|
||||
# Flatten and predict
|
||||
z_mix = z_mix.view(B, C, N * D) # z_mix: [B, C, patch_num * d_model]
|
||||
|
Reference in New Issue
Block a user