feat: implement dynamic threshold scheduling for GraphMixer
This commit is contained in:
@ -53,6 +53,10 @@ class HierarchicalGraphMixer(nn.Module):
|
|||||||
dim: int,
|
dim: int,
|
||||||
max_degree: int = None, # 可选:限制每行最多边数
|
max_degree: int = None, # 可选:限制每行最多边数
|
||||||
thr: float = 0.5, # 保留边阈值,例如 0.5/0.7
|
thr: float = 0.5, # 保留边阈值,例如 0.5/0.7
|
||||||
|
thr_min: float = None, # 动态阈值起点,不传则用 thr
|
||||||
|
thr_max: float = None, # 动态阈值终点,不传则用 thr
|
||||||
|
thr_steps: int = 0, # 从 thr_min -> thr_max 的步数,>0 时启用动态调度
|
||||||
|
thr_schedule: str = "linear", # "linear" | "cosine" | "exp"
|
||||||
temperature: float = 2./3.,
|
temperature: float = 2./3.,
|
||||||
tau_attn: float = 1.0, # Patch attention 温度(可选)
|
tau_attn: float = 1.0, # Patch attention 温度(可选)
|
||||||
symmetric: bool = True, # 是否对称化通道图
|
symmetric: bool = True, # 是否对称化通道图
|
||||||
@ -67,6 +71,13 @@ class HierarchicalGraphMixer(nn.Module):
|
|||||||
self.tau_attn = tau_attn
|
self.tau_attn = tau_attn
|
||||||
self.symmetric = symmetric
|
self.symmetric = symmetric
|
||||||
self.degree_rescale = degree_rescale
|
self.degree_rescale = degree_rescale
|
||||||
|
self.thr_min = thr if (thr_min is None) else float(thr_min)
|
||||||
|
self.thr_max = thr if (thr_max is None) else float(thr_max)
|
||||||
|
self.thr_steps = int(thr_steps) if thr_steps is not None else 0
|
||||||
|
self.thr_schedule = thr_schedule
|
||||||
|
self._use_dynamic_thr = (self.thr_steps > 0) and (abs(self.thr_max - self.thr_min) > 1e-12)
|
||||||
|
# 用 buffer 记录已步进次数(不保存到权重里)
|
||||||
|
self.register_buffer("_thr_step", torch.zeros((), dtype=torch.long), persistent=False)
|
||||||
|
|
||||||
# Level 1: 非归一化门控
|
# Level 1: 非归一化门控
|
||||||
self.gate = HardConcreteGate(
|
self.gate = HardConcreteGate(
|
||||||
@ -88,6 +99,30 @@ class HierarchicalGraphMixer(nn.Module):
|
|||||||
self.out_proj = nn.Linear(dim, dim)
|
self.out_proj = nn.Linear(dim, dim)
|
||||||
self.norm = nn.LayerNorm(dim)
|
self.norm = nn.LayerNorm(dim)
|
||||||
|
|
||||||
|
def _compute_thr_by_progress(self, progress: float) -> float:
|
||||||
|
# progress in [0,1]
|
||||||
|
progress = max(0.0, min(1.0, float(progress)))
|
||||||
|
if self.thr_schedule == "linear":
|
||||||
|
g = progress
|
||||||
|
elif self.thr_schedule == "cosine":
|
||||||
|
# 慢起步,后期加速
|
||||||
|
import math
|
||||||
|
g = 0.5 - 0.5 * math.cos(math.pi * progress)
|
||||||
|
elif self.thr_schedule == "exp":
|
||||||
|
# 更快从 thr_min 过渡到 thr_max(指数式)
|
||||||
|
import math
|
||||||
|
k = 5.0
|
||||||
|
g = (math.exp(k * progress) - 1.0) / (math.exp(k) - 1.0)
|
||||||
|
else:
|
||||||
|
g = progress
|
||||||
|
return self.thr_min + (self.thr_max - self.thr_min) * g
|
||||||
|
def _maybe_update_thr(self):
|
||||||
|
if self.training and self._use_dynamic_thr:
|
||||||
|
step = int(self._thr_step.item())
|
||||||
|
progress = step / float(self.thr_steps)
|
||||||
|
self.thr = float(self._compute_thr_by_progress(progress))
|
||||||
|
self._thr_step += 1
|
||||||
|
|
||||||
def _build_sparse_neighbors(self, z_gate):
|
def _build_sparse_neighbors(self, z_gate):
|
||||||
"""
|
"""
|
||||||
基于 z_gate 构造每行的邻接列表(按阈值与可选top-k)。
|
基于 z_gate 构造每行的邻接列表(按阈值与可选top-k)。
|
||||||
@ -151,6 +186,7 @@ class HierarchicalGraphMixer(nn.Module):
|
|||||||
return lam * self.gate.expected_l0().sum()
|
return lam * self.gate.expected_l0().sum()
|
||||||
|
|
||||||
def forward(self, z):
|
def forward(self, z):
|
||||||
|
self._maybe_update_thr()
|
||||||
# z: [B, C, N, D]
|
# z: [B, C, N, D]
|
||||||
B, C, N, D = z.shape
|
B, C, N, D = z.shape
|
||||||
assert C == self.C and D == self.dim
|
assert C == self.C and D == self.dim
|
||||||
|
@ -30,6 +30,10 @@ class SeasonPatch(nn.Module):
|
|||||||
headdim: int = 64,
|
headdim: int = 64,
|
||||||
# Mixergraph 可选超参数
|
# Mixergraph 可选超参数
|
||||||
thr_graph: float = 0.5,
|
thr_graph: float = 0.5,
|
||||||
|
thr_graph_min: float = None,
|
||||||
|
thr_graph_max: float = None,
|
||||||
|
thr_graph_steps: int = 0,
|
||||||
|
thr_graph_schedule: str = "linear",
|
||||||
symmetric_graph: bool = True,
|
symmetric_graph: bool = True,
|
||||||
degree_rescale: str = "count-sqrt", # "none" | "count" | "count-sqrt" | "sum"
|
degree_rescale: str = "count-sqrt", # "none" | "count" | "count-sqrt" | "sum"
|
||||||
gate_temperature: float = 2./3.,
|
gate_temperature: float = 2./3.,
|
||||||
@ -38,6 +42,9 @@ class SeasonPatch(nn.Module):
|
|||||||
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
# ===== 新增:保存 l0_lambda,防止 reg_loss 访问报错 =====
|
||||||
|
self.l0_lambda = l0_lambda
|
||||||
|
|
||||||
# Store patch parameters
|
# Store patch parameters
|
||||||
self.patch_len = patch_len # patch 长度
|
self.patch_len = patch_len # patch 长度
|
||||||
self.stride = stride # patch 步幅
|
self.stride = stride # patch 步幅
|
||||||
@ -60,6 +67,10 @@ class SeasonPatch(nn.Module):
|
|||||||
dim=d_model,
|
dim=d_model,
|
||||||
max_degree=k_graph,
|
max_degree=k_graph,
|
||||||
thr=thr_graph,
|
thr=thr_graph,
|
||||||
|
thr_min=thr_graph_min,
|
||||||
|
thr_max=thr_graph_max,
|
||||||
|
thr_steps=thr_graph_steps,
|
||||||
|
thr_schedule=thr_graph_schedule,
|
||||||
temperature=gate_temperature,
|
temperature=gate_temperature,
|
||||||
tau_attn=tau_attn,
|
tau_attn=tau_attn,
|
||||||
symmetric=symmetric_graph,
|
symmetric=symmetric_graph,
|
||||||
|
@ -54,6 +54,10 @@ class Model(nn.Module):
|
|||||||
# GraphMixer相关(非归一化)
|
# GraphMixer相关(非归一化)
|
||||||
k_graph=getattr(configs, 'k_graph', 8), # -> max_degree
|
k_graph=getattr(configs, 'k_graph', 8), # -> max_degree
|
||||||
thr_graph=getattr(configs, 'thr_graph', 0.5),
|
thr_graph=getattr(configs, 'thr_graph', 0.5),
|
||||||
|
thr_graph_min=getattr(configs, 'thr_graph_min', None),
|
||||||
|
thr_graph_max=getattr(configs, 'thr_graph_max', None),
|
||||||
|
thr_graph_steps=getattr(configs, 'thr_graph_steps', 0),
|
||||||
|
thr_graph_schedule=getattr(configs, 'thr_graph_schedule', 'linear'),
|
||||||
symmetric_graph=getattr(configs, 'symmetric_graph', True),
|
symmetric_graph=getattr(configs, 'symmetric_graph', True),
|
||||||
degree_rescale=getattr(configs, 'degree_rescale', 'count-sqrt'), # 'none' | 'count' | 'count-sqrt' | 'sum'
|
degree_rescale=getattr(configs, 'degree_rescale', 'count-sqrt'), # 'none' | 'count' | 'count-sqrt' | 'sum'
|
||||||
gate_temperature=getattr(configs, 'gate_temperature', 2.0/3.0),
|
gate_temperature=getattr(configs, 'gate_temperature', 2.0/3.0),
|
||||||
|
@ -2,45 +2,6 @@
|
|||||||
|
|
||||||
model_name=xPatch_SparseChannel
|
model_name=xPatch_SparseChannel
|
||||||
|
|
||||||
# ETTm1 dataset
|
|
||||||
for pred_len in 96 192 336 720
|
|
||||||
do
|
|
||||||
python -u run.py \
|
|
||||||
--task_name long_term_forecast \
|
|
||||||
--is_training 1 \
|
|
||||||
--root_path ./dataset/ETT-small/ \
|
|
||||||
--data_path ETTm1.csv \
|
|
||||||
--model_id ETTm1_$pred_len'_'$pred_len \
|
|
||||||
--model $model_name \
|
|
||||||
--data ETTm1 \
|
|
||||||
--features M \
|
|
||||||
--seq_len 96 \
|
|
||||||
--label_len 48 \
|
|
||||||
--pred_len $pred_len \
|
|
||||||
--e_layers 2 \
|
|
||||||
--d_layers 1 \
|
|
||||||
--enc_in 7 \
|
|
||||||
--c_out 7 \
|
|
||||||
--d_model 128 \
|
|
||||||
--lradj 'sigmoid' \
|
|
||||||
--d_ff 256 \
|
|
||||||
--n_heads 16 \
|
|
||||||
--patch_len 16 \
|
|
||||||
--stride 8 \
|
|
||||||
--k_graph 5 \
|
|
||||||
--dropout 0.1 \
|
|
||||||
--revin 1 \
|
|
||||||
--des 'Exp' \
|
|
||||||
--itr 1 \
|
|
||||||
--season_encoder 'Transformer' \
|
|
||||||
--thr_graph 0.6 \
|
|
||||||
--symmetric_graph 1 \
|
|
||||||
--degree_rescale 'none' \
|
|
||||||
--gate_temperature 0.6667 \
|
|
||||||
--tau_attn 1.0 \
|
|
||||||
--season_l0_lambda 0.0000
|
|
||||||
done
|
|
||||||
|
|
||||||
# Weather dataset
|
# Weather dataset
|
||||||
for pred_len in 96 192 336 720
|
for pred_len in 96 192 336 720
|
||||||
do
|
do
|
||||||
@ -78,7 +39,11 @@ python -u run.py \
|
|||||||
--degree_rescale 'none' \
|
--degree_rescale 'none' \
|
||||||
--gate_temperature 0.6667 \
|
--gate_temperature 0.6667 \
|
||||||
--tau_attn 1.0 \
|
--tau_attn 1.0 \
|
||||||
--season_l0_lambda 0.0000
|
--season_l0_lambda 0.0000 \
|
||||||
|
--thr_graph_min 0.1 \
|
||||||
|
--thr_graph_max 0.6 \
|
||||||
|
--thr_graph_steps 1000 \
|
||||||
|
--thr_graph_schedule 'cosine'
|
||||||
done
|
done
|
||||||
|
|
||||||
# Exchange dataset
|
# Exchange dataset
|
||||||
@ -117,9 +82,57 @@ python -u run.py \
|
|||||||
--degree_rescale 'none' \
|
--degree_rescale 'none' \
|
||||||
--gate_temperature 0.6667 \
|
--gate_temperature 0.6667 \
|
||||||
--tau_attn 1.0 \
|
--tau_attn 1.0 \
|
||||||
--season_l0_lambda 0.0000
|
--season_l0_lambda 0.0000 \
|
||||||
|
--thr_graph_min 0.1 \
|
||||||
|
--thr_graph_max 0.6 \
|
||||||
|
--thr_graph_steps 1000 \
|
||||||
|
--thr_graph_schedule 'cosine'
|
||||||
done
|
done
|
||||||
|
|
||||||
|
# ETTm1 dataset
|
||||||
|
for pred_len in 96 192 336 720
|
||||||
|
do
|
||||||
|
python -u run.py \
|
||||||
|
--task_name long_term_forecast \
|
||||||
|
--is_training 1 \
|
||||||
|
--root_path ./dataset/ETT-small/ \
|
||||||
|
--data_path ETTm1.csv \
|
||||||
|
--model_id ETTm1_$pred_len'_'$pred_len \
|
||||||
|
--model $model_name \
|
||||||
|
--data ETTm1 \
|
||||||
|
--features M \
|
||||||
|
--seq_len 96 \
|
||||||
|
--label_len 48 \
|
||||||
|
--pred_len $pred_len \
|
||||||
|
--e_layers 2 \
|
||||||
|
--d_layers 1 \
|
||||||
|
--enc_in 7 \
|
||||||
|
--c_out 7 \
|
||||||
|
--d_model 128 \
|
||||||
|
--lradj 'sigmoid' \
|
||||||
|
--d_ff 256 \
|
||||||
|
--n_heads 16 \
|
||||||
|
--patch_len 16 \
|
||||||
|
--stride 8 \
|
||||||
|
--k_graph 5 \
|
||||||
|
--dropout 0.1 \
|
||||||
|
--revin 1 \
|
||||||
|
--des 'Exp' \
|
||||||
|
--itr 1 \
|
||||||
|
--season_encoder 'Transformer' \
|
||||||
|
--thr_graph 0.6 \
|
||||||
|
--symmetric_graph 1 \
|
||||||
|
--degree_rescale 'none' \
|
||||||
|
--gate_temperature 0.6667 \
|
||||||
|
--tau_attn 1.0 \
|
||||||
|
--season_l0_lambda 0.0000 \
|
||||||
|
--thr_graph_min 0.1 \
|
||||||
|
--thr_graph_max 0.6 \
|
||||||
|
--thr_graph_steps 1000 \
|
||||||
|
--thr_graph_schedule 'cosine'
|
||||||
|
done
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# ETTm2 dataset
|
# ETTm2 dataset
|
||||||
for pred_len in 96 192 336 720
|
for pred_len in 96 192 336 720
|
||||||
|
Reference in New Issue
Block a user