From 172328a4e61b8d9da536af47959a08b2c97d28a4 Mon Sep 17 00:00:00 2001 From: gameloader Date: Fri, 12 Sep 2025 17:02:42 +0800 Subject: [PATCH] feat: implement dynamic threshold scheduling for GraphMixer --- layers/GraphMixer.py | 36 +++++++ layers/SeasonPatch.py | 11 +++ models/xPatch_SparseChannel.py | 4 + .../xPatch_SparseChannel_all-Copy1.sh | 95 +++++++++++-------- 4 files changed, 105 insertions(+), 41 deletions(-) diff --git a/layers/GraphMixer.py b/layers/GraphMixer.py index bb826da..277b989 100644 --- a/layers/GraphMixer.py +++ b/layers/GraphMixer.py @@ -53,6 +53,10 @@ class HierarchicalGraphMixer(nn.Module): dim: int, max_degree: int = None, # 可选:限制每行最多边数 thr: float = 0.5, # 保留边阈值,例如 0.5/0.7 + thr_min: float = None, # 动态阈值起点,不传则用 thr + thr_max: float = None, # 动态阈值终点,不传则用 thr + thr_steps: int = 0, # 从 thr_min -> thr_max 的步数,>0 时启用动态调度 + thr_schedule: str = "linear", # "linear" | "cosine" | "exp" temperature: float = 2./3., tau_attn: float = 1.0, # Patch attention 温度(可选) symmetric: bool = True, # 是否对称化通道图 @@ -67,6 +71,13 @@ class HierarchicalGraphMixer(nn.Module): self.tau_attn = tau_attn self.symmetric = symmetric self.degree_rescale = degree_rescale + self.thr_min = thr if (thr_min is None) else float(thr_min) + self.thr_max = thr if (thr_max is None) else float(thr_max) + self.thr_steps = int(thr_steps) if thr_steps is not None else 0 + self.thr_schedule = thr_schedule + self._use_dynamic_thr = (self.thr_steps > 0) and (abs(self.thr_max - self.thr_min) > 1e-12) + # 用 buffer 记录已步进次数(不保存到权重里) + self.register_buffer("_thr_step", torch.zeros((), dtype=torch.long), persistent=False) # Level 1: 非归一化门控 self.gate = HardConcreteGate( @@ -88,6 +99,30 @@ class HierarchicalGraphMixer(nn.Module): self.out_proj = nn.Linear(dim, dim) self.norm = nn.LayerNorm(dim) + def _compute_thr_by_progress(self, progress: float) -> float: + # progress in [0,1] + progress = max(0.0, min(1.0, float(progress))) + if self.thr_schedule == "linear": + g = progress + elif self.thr_schedule == "cosine": + # 慢起步,后期加速 + import math + g = 0.5 - 0.5 * math.cos(math.pi * progress) + elif self.thr_schedule == "exp": + # 更快从 thr_min 过渡到 thr_max(指数式) + import math + k = 5.0 + g = (math.exp(k * progress) - 1.0) / (math.exp(k) - 1.0) + else: + g = progress + return self.thr_min + (self.thr_max - self.thr_min) * g + def _maybe_update_thr(self): + if self.training and self._use_dynamic_thr: + step = int(self._thr_step.item()) + progress = step / float(self.thr_steps) + self.thr = float(self._compute_thr_by_progress(progress)) + self._thr_step += 1 + def _build_sparse_neighbors(self, z_gate): """ 基于 z_gate 构造每行的邻接列表(按阈值与可选top-k)。 @@ -151,6 +186,7 @@ class HierarchicalGraphMixer(nn.Module): return lam * self.gate.expected_l0().sum() def forward(self, z): + self._maybe_update_thr() # z: [B, C, N, D] B, C, N, D = z.shape assert C == self.C and D == self.dim diff --git a/layers/SeasonPatch.py b/layers/SeasonPatch.py index f25e445..a508d68 100644 --- a/layers/SeasonPatch.py +++ b/layers/SeasonPatch.py @@ -30,6 +30,10 @@ class SeasonPatch(nn.Module): headdim: int = 64, # Mixergraph 可选超参数 thr_graph: float = 0.5, + thr_graph_min: float = None, + thr_graph_max: float = None, + thr_graph_steps: int = 0, + thr_graph_schedule: str = "linear", symmetric_graph: bool = True, degree_rescale: str = "count-sqrt", # "none" | "count" | "count-sqrt" | "sum" gate_temperature: float = 2./3., @@ -38,6 +42,9 @@ class SeasonPatch(nn.Module): super().__init__() + # ===== 新增:保存 l0_lambda,防止 reg_loss 访问报错 ===== + self.l0_lambda = l0_lambda + # Store patch parameters self.patch_len = patch_len # patch 长度 self.stride = stride # patch 步幅 @@ -60,6 +67,10 @@ class SeasonPatch(nn.Module): dim=d_model, max_degree=k_graph, thr=thr_graph, + thr_min=thr_graph_min, + thr_max=thr_graph_max, + thr_steps=thr_graph_steps, + thr_schedule=thr_graph_schedule, temperature=gate_temperature, tau_attn=tau_attn, symmetric=symmetric_graph, diff --git a/models/xPatch_SparseChannel.py b/models/xPatch_SparseChannel.py index e04319e..2b5f009 100644 --- a/models/xPatch_SparseChannel.py +++ b/models/xPatch_SparseChannel.py @@ -54,6 +54,10 @@ class Model(nn.Module): # GraphMixer相关(非归一化) k_graph=getattr(configs, 'k_graph', 8), # -> max_degree thr_graph=getattr(configs, 'thr_graph', 0.5), + thr_graph_min=getattr(configs, 'thr_graph_min', None), + thr_graph_max=getattr(configs, 'thr_graph_max', None), + thr_graph_steps=getattr(configs, 'thr_graph_steps', 0), + thr_graph_schedule=getattr(configs, 'thr_graph_schedule', 'linear'), symmetric_graph=getattr(configs, 'symmetric_graph', True), degree_rescale=getattr(configs, 'degree_rescale', 'count-sqrt'), # 'none' | 'count' | 'count-sqrt' | 'sum' gate_temperature=getattr(configs, 'gate_temperature', 2.0/3.0), diff --git a/scripts/long_term_forecast/xPatch_SparseChannel_all-Copy1.sh b/scripts/long_term_forecast/xPatch_SparseChannel_all-Copy1.sh index aa07a76..ed212b9 100644 --- a/scripts/long_term_forecast/xPatch_SparseChannel_all-Copy1.sh +++ b/scripts/long_term_forecast/xPatch_SparseChannel_all-Copy1.sh @@ -2,45 +2,6 @@ model_name=xPatch_SparseChannel -# ETTm1 dataset -for pred_len in 96 192 336 720 -do -python -u run.py \ - --task_name long_term_forecast \ - --is_training 1 \ - --root_path ./dataset/ETT-small/ \ - --data_path ETTm1.csv \ - --model_id ETTm1_$pred_len'_'$pred_len \ - --model $model_name \ - --data ETTm1 \ - --features M \ - --seq_len 96 \ - --label_len 48 \ - --pred_len $pred_len \ - --e_layers 2 \ - --d_layers 1 \ - --enc_in 7 \ - --c_out 7 \ - --d_model 128 \ - --lradj 'sigmoid' \ - --d_ff 256 \ - --n_heads 16 \ - --patch_len 16 \ - --stride 8 \ - --k_graph 5 \ - --dropout 0.1 \ - --revin 1 \ - --des 'Exp' \ - --itr 1 \ - --season_encoder 'Transformer' \ - --thr_graph 0.6 \ - --symmetric_graph 1 \ - --degree_rescale 'none' \ - --gate_temperature 0.6667 \ - --tau_attn 1.0 \ - --season_l0_lambda 0.0000 -done - # Weather dataset for pred_len in 96 192 336 720 do @@ -78,7 +39,11 @@ python -u run.py \ --degree_rescale 'none' \ --gate_temperature 0.6667 \ --tau_attn 1.0 \ - --season_l0_lambda 0.0000 + --season_l0_lambda 0.0000 \ + --thr_graph_min 0.1 \ + --thr_graph_max 0.6 \ + --thr_graph_steps 1000 \ + --thr_graph_schedule 'cosine' done # Exchange dataset @@ -117,9 +82,57 @@ python -u run.py \ --degree_rescale 'none' \ --gate_temperature 0.6667 \ --tau_attn 1.0 \ - --season_l0_lambda 0.0000 + --season_l0_lambda 0.0000 \ + --thr_graph_min 0.1 \ + --thr_graph_max 0.6 \ + --thr_graph_steps 1000 \ + --thr_graph_schedule 'cosine' done +# ETTm1 dataset +for pred_len in 96 192 336 720 +do +python -u run.py \ + --task_name long_term_forecast \ + --is_training 1 \ + --root_path ./dataset/ETT-small/ \ + --data_path ETTm1.csv \ + --model_id ETTm1_$pred_len'_'$pred_len \ + --model $model_name \ + --data ETTm1 \ + --features M \ + --seq_len 96 \ + --label_len 48 \ + --pred_len $pred_len \ + --e_layers 2 \ + --d_layers 1 \ + --enc_in 7 \ + --c_out 7 \ + --d_model 128 \ + --lradj 'sigmoid' \ + --d_ff 256 \ + --n_heads 16 \ + --patch_len 16 \ + --stride 8 \ + --k_graph 5 \ + --dropout 0.1 \ + --revin 1 \ + --des 'Exp' \ + --itr 1 \ + --season_encoder 'Transformer' \ + --thr_graph 0.6 \ + --symmetric_graph 1 \ + --degree_rescale 'none' \ + --gate_temperature 0.6667 \ + --tau_attn 1.0 \ + --season_l0_lambda 0.0000 \ + --thr_graph_min 0.1 \ + --thr_graph_max 0.6 \ + --thr_graph_steps 1000 \ + --thr_graph_schedule 'cosine' +done + + # ETTm2 dataset for pred_len in 96 192 336 720