From 204d17086ad4efb859e26f7d0a96ead2b0e60d1b Mon Sep 17 00:00:00 2001
From: gameloader
Date: Thu, 11 Sep 2025 16:50:58 +0800
Subject: [PATCH] feat(graph-mixer): implement L0 sparsity with Hard-Concrete gate for channel selection

---
 layers/GraphMixer.py                      | 223 ++++++++++++------
 layers/SeasonPatch.py                     |  30 ++-
 models/xPatch_SparseChannel.py            |  42 +++-
 .../xPatch_SparseChannel_all-Copy1.sh     |  97 +++++---
 4 files changed, 268 insertions(+), 124 deletions(-)

diff --git a/layers/GraphMixer.py b/layers/GraphMixer.py
index 000b3fa..bb826da 100644
--- a/layers/GraphMixer.py
+++ b/layers/GraphMixer.py
@@ -1,27 +1,86 @@
+import math
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import math
+
+class HardConcreteGate(nn.Module):
+    """
+    Hard-Concrete gate for L0-style sparsity (Louizos et al., 2017).
+    Produces z in [0, 1] without row-wise normalization.
+    """
+    def __init__(self, shape, temperature=2./3., gamma=-0.1, zeta=1.1, init_log_alpha=-2.0):
+        super().__init__()
+        self.log_alpha = nn.Parameter(torch.full(shape, init_log_alpha))
+        self.temperature = temperature
+        self.gamma = gamma
+        self.zeta = zeta
+
+    def sample(self, training=True):
+        if training:
+            u = torch.rand_like(self.log_alpha).clamp_(1e-6, 1 - 1e-6)  # keep log(u), log(1-u) finite
+            s = torch.sigmoid((self.log_alpha + torch.log(u) - torch.log(1 - u)) / self.temperature)
+        else:
+            # deterministic mean gate at eval
+            s = torch.sigmoid(self.log_alpha)
+        s_bar = s * (self.zeta - self.gamma) + self.gamma
+        z = torch.clamp(s_bar, 0., 1.)
+        return z
+
+    def expected_l0(self):
+        """
+        Closed-form E[1_{z>0}] for the hard-concrete distribution.
+        Useful for the L0 penalty: lambda * expected_l0().sum()
+        """
+        # s > t0  =>  z > 0, where t0 = -gamma / (zeta - gamma)
+        t0 = -self.gamma / (self.zeta - self.gamma)
+        # logit(t0) = log(-gamma / zeta)
+        logit_t0 = math.log(t0) - math.log(1 - t0)
+        # P(z > 0) = sigmoid(log_alpha - temperature * logit(t0)) (Louizos et al., Eq. 12)
+        p_open = torch.sigmoid(self.log_alpha - self.temperature * logit_t0)
+        return p_open
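+
+# Minimal usage sketch (illustrative only; the shape below is made up):
+# expected_l0() is the differentiable surrogate for the number of open gates,
+# so its sum is exactly what the mixer's l0_loss scales by `lam`.
+#
+#   gate = HardConcreteGate(shape=(7, 7))
+#   z = gate.sample(training=False)      # deterministic gates in [0, 1]
+#   n_open = gate.expected_l0().sum()    # expected count of nonzero gates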
 
 class HierarchicalGraphMixer(nn.Module):
     """
-    Hierarchical graph mixer that jointly models macro channel relations and micro patch-level attention.
-    Input  z     : [B, C, N, D]
-    Output z_out : same shape
+    Hierarchical graph mixer with Hard-Concrete edge gating:
+      - Level 1: an unnormalized, thresholdable, possibly empty channel graph
+      - Level 2: patch-level cross-attention only over the selected edges
+    Input : z [B, C, N, D]
+    Output: z_out, same shape
     """
-    def __init__(self, n_channel: int, dim: int, k: int = 5, tau_fw: float = 0.3, tau_bw: float = 3.0):
+    def __init__(
+        self,
+        n_channel: int,
+        dim: int,
+        max_degree: int = None,        # optional: cap the number of edges per row
+        thr: float = 0.5,              # edge-keeping threshold, e.g. 0.5 / 0.7
+        temperature: float = 2./3.,
+        tau_attn: float = 1.0,         # patch-attention temperature (optional)
+        symmetric: bool = True,        # whether to symmetrize the channel graph
+        degree_rescale: str = "none",  # "none" | "count" | "count-sqrt" | "sum"
+        init_log_alpha: float = -2.0
+    ):
         super().__init__()
-        self.k = k
-        self.tau_fw = tau_fw  # forward temperature (small)
-        self.tau_bw = tau_bw  # backward temperature (large)
-
-        # Level 1: Channel Graph (logits)
-        self.A = nn.Parameter(torch.zeros(n_channel, n_channel))
+        self.C = n_channel
+        self.dim = dim
+        self.max_degree = max_degree
+        self.thr = thr
+        self.tau_attn = tau_attn
+        self.symmetric = symmetric
+        self.degree_rescale = degree_rescale
+
+        # Level 1: unnormalized gating
+        self.gate = HardConcreteGate(
+            shape=(n_channel, n_channel),
+            temperature=temperature,
+            init_log_alpha=init_log_alpha
+        )
+
+        # Optional SE (the original `se` can produce sample-dependent channel priorities; the interface is kept for now)
        self.se = nn.Sequential(
             nn.Linear(dim, dim // 4, bias=False),
             nn.SiLU(),
             nn.Linear(dim // 4, 1, bias=False),
             nn.Sigmoid()
         )
 
         # Level 2: Patch Cross-Attention
         self.q_proj = nn.Linear(dim, dim)
         self.k_proj = nn.Linear(dim, dim)
@@ -29,96 +88,108 @@ class HierarchicalGraphMixer(nn.Module):
         self.out_proj = nn.Linear(dim, dim)
         self.norm = nn.LayerNorm(dim)
 
-    @torch.no_grad()
-    def _mask_self_logits_(self, logits: torch.Tensor):
-        """Set the diagonal to -inf so a channel never selects itself."""
-        C = logits.size(0)
-        eye = torch.eye(C, device=logits.device, dtype=torch.bool)
-        logits.masked_fill_(eye, float("-inf"))
-
-    def _gumbel_topk_select(self, logits: torch.Tensor):
+    def _build_sparse_neighbors(self, z_gate):
         """
-        Returns:
-        - idx:  [C, k_actual] per-row top-k channel indices (excluding self)
-        - w_st: [C, k_actual] selected-edge weights (forward = probs at tau_fw; backward grads = probs at tau_bw)
+        Build per-row adjacency lists from z_gate (threshold plus optional top-k).
+        Returns:
+        - idx_list: list of length C; each entry is a LongTensor[idx_j]
+        - w_list:   list of length C; each entry is a FloatTensor[w_j] (unnormalized)
         """
-        C = logits.size(0)
-        k_actual = min(self.k, C - 1)
-        if k_actual <= 0:
-            idx = torch.empty((C, 0), dtype=torch.long, device=logits.device)
-            w_st = torch.empty((C, 0), dtype=logits.dtype, device=logits.device)
-            return idx, w_st
+        C = z_gate.size(0)
+        # zero out the diagonal
+        z_gate = z_gate.clone()
+        z_gate.fill_diagonal_(0.0)
 
-        # Share one draw of Gumbel noise; build the forward/backward distributions at different temperatures
-        g = -torch.empty_like(logits).exponential_().log()
-        y_fw = (logits + g) / self.tau_fw
-        y_bw = (logits + g) / self.tau_bw
+        if self.symmetric:
+            z_gate = 0.5 * (z_gate + z_gate.t())
+            z_gate.fill_diagonal_(0.0)
 
-        # Exclude self
-        y_fw = y_fw.clone()
-        y_bw = y_bw.clone()
-        self._mask_self_logits_(y_fw)
-        self._mask_self_logits_(y_bw)
+        idx_list, w_list = [], []
+        for i in range(C):
+            row = z_gate[i]  # [C]
+            # threshold filtering
+            mask = row > self.thr
+            if mask.any():
+                vals = row[mask]
+                idxs = torch.nonzero(mask, as_tuple=False).squeeze(-1)
+                # optional cap on the degree
+                if (self.max_degree is not None) and (idxs.numel() > self.max_degree):
+                    topk = torch.topk(vals, k=self.max_degree, dim=0)
+                    vals = topk.values
+                    idxs = idxs[topk.indices]
+            else:
+                idxs = torch.empty((0,), dtype=torch.long, device=row.device)
+                vals = torch.empty((0,), dtype=row.dtype, device=row.device)
+            idx_list.append(idxs)
+            w_list.append(vals)
+        return idx_list, w_list
 
-        # Strict forward top-k selection
-        topk_val, idx = torch.topk(y_fw, k_actual, dim=-1)  # [C, k]
-        # Compute forward/backward soft probabilities and gather only the selected k
-        p_fw = F.softmax(y_fw, dim=-1)  # [C, C]
-        p_bw = F.softmax(y_bw, dim=-1)  # [C, C]
-        w_fw = torch.gather(p_fw, -1, idx)  # [C, k]
-        w_bw = torch.gather(p_bw, -1, idx)  # [C, k]
+    def _degree_rescale(self, ctx, w_sel):
+        """
+        Stability handling for unnormalized aggregation: optionally rescale the aggregate by the (weighted) degree.
+        ctx:   [B, k, N, D]
+        w_sel: [k]
+        """
+        if self.degree_rescale == "none":
+            return (ctx * w_sel.view(1, -1, 1, 1)).sum(dim=1)
+        elif self.degree_rescale == "count":
+            k = max(1, w_sel.numel())
+            return (ctx * w_sel.view(1, -1, 1, 1)).sum(dim=1) / float(k)
+        elif self.degree_rescale == "count-sqrt":
+            k = max(1, w_sel.numel())
+            return (ctx * w_sel.view(1, -1, 1, 1)).sum(dim=1) / math.sqrt(k)
+        elif self.degree_rescale == "sum":
+            # keep the normalizer a tensor so gradients flow through it
+            s = w_sel.sum().clamp(min=1e-6)
+            return (ctx * w_sel.view(1, -1, 1, 1)).sum(dim=1) / s
+        else:
+            return (ctx * w_sel.view(1, -1, 1, 1)).sum(dim=1)
 
-        # Normalize within the selected set to stabilize training
-        eps = 1e-9
-        w_fw = w_fw / (w_fw.sum(-1, keepdim=True) + eps)
-        w_bw = w_bw / (w_bw.sum(-1, keepdim=True) + eps)
-
-        # Straight-through: forward uses w_fw, backward gradient comes from w_bw
-        w_st = w_fw.detach() + w_bw - w_bw.detach()  # [C, k]
-        return idx, w_st
+    def l0_loss(self, lam: float = 1e-4):
+        """
+        Expected-L0 regularizer: encourages a sparse adjacency (tunable strength).
+        """
+        return lam * self.gate.expected_l0().sum()
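+
+    # Worked example of the rescale options (illustrative numbers): with
+    # k = 4 kept edges of weight 0.8 each and roughly unit-norm contexts,
+    # the aggregate magnitude is about 3.2 ("none"), 0.8 ("count"),
+    # 1.6 ("count-sqrt"), and 1.0 ("sum"); only "none" lets the output
+    # scale grow with the density of the learned graph.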
 
     def forward(self, z):
         # z: [B, C, N, D]
         B, C, N, D = z.shape
+        assert C == self.C and D == self.dim
 
-        # --- Level 1: pick each channel's top-k related channels (excluding self) and get the ST weights ---
-        idx, w_st = self._gumbel_topk_select(self.A)  # idx:[C,k], w_st:[C,k]
+        # Level 1: sample the unnormalized gates z_gate in [0, 1]
+        z_gate = self.gate.sample(training=self.training)  # [C, C]
 
-        # --- Level 2: cross-channel patch interaction only for the selected channels ---
+        # Build the sparse neighborhoods (threshold + optional top-k)
+        idx_list, w_list = self._build_sparse_neighbors(z_gate)
+
+        # Level 2: cross-channel patch interaction only over the kept edges
         out_z = torch.zeros_like(z)
         for i in range(C):
             target_z = z[:, i, :, :]  # [B, N, D]
-
-            # If this channel has no available neighbors, go straight to the residual
-            if idx.size(1) == 0:
+            idx = idx_list[i]
+            if idx.numel() == 0:
+                # empty neighborhood: "no related channels" is allowed; residual/norm only
                 out_z[:, i, :, :] = self.norm(target_z)
                 continue
 
-            sel_idx = idx[i]   # [k]
-            sel_w = w_st[i]    # [k]
-            k_i = sel_idx.numel()
+            w_sel = w_list[i]  # [k], unnormalized weights in [0, 1]
+            k_i = idx.numel()
 
-            # Source channel block: [B, k, N, D]
-            source_z = z[:, sel_idx, :, :]
+            source_z = z[:, idx, :, :]  # [B, k, N, D]
 
-            # Linear projections
-            Q = self.q_proj(target_z)  # [B, N, D]
+            Q = self.q_proj(target_z)  # [B, N, D]
             K = self.k_proj(source_z.reshape(B * k_i, N, D)).reshape(B, k_i, N, D)
             V = self.v_proj(source_z.reshape(B * k_i, N, D)).reshape(B, k_i, N, D)
 
-            # Cross-attention (over all k source channels at once)
-            # attn_scores: [B, k, N, N]
+            # cross-channel patch attention
             attn_scores = torch.einsum('bnd,bkmd->bknm', Q, K) / math.sqrt(D)
-            attn_probs = F.softmax(attn_scores, dim=-1)  # [B, k, N, N]
-            context = torch.einsum('bknm,bkmd->bknd', attn_probs, V)  # [B, k, N, D]
+            if self.tau_attn != 1.0:
+                attn_scores = attn_scores / self.tau_attn
+            attn_probs = F.softmax(attn_scores, dim=-1)               # [B, k, N, N]
+            context = torch.einsum('bknm,bkmd->bknd', attn_probs, V)  # [B, k, N, D]
 
-            # Aggregate with the straight-through channel weights (forward = low temperature, backward = high temperature)
-            w = sel_w.view(1, k_i, 1, 1)  # [1, k, 1, 1]
-            aggregated_context = (context * w).sum(dim=1)  # [B, N, D]
-
-            # Output and residual
+            # unnormalized channel-weight aggregation + optional degree rescale (numerical stability only; the "unnormalized" semantics stay intact)
+            aggregated_context = self._degree_rescale(context, w_sel)  # [B, N, D]
             out_z[:, i, :, :] = self.norm(target_z + self.out_proj(aggregated_context))
 
         return out_z
-
diff --git a/layers/SeasonPatch.py b/layers/SeasonPatch.py
index 1491fb5..f25e445 100644
--- a/layers/SeasonPatch.py
+++ b/layers/SeasonPatch.py
@@ -27,7 +27,15 @@ class SeasonPatch(nn.Module):
                  d_state: int = 64,
                  d_conv: int = 4,
                  expand: int = 2,
-                 headdim: int = 64):
+                 headdim: int = 64,
+                 # optional GraphMixer hyper-parameters
+                 thr_graph: float = 0.5,
+                 symmetric_graph: bool = True,
+                 degree_rescale: str = "count-sqrt",  # "none" | "count" | "count-sqrt" | "sum"
+                 gate_temperature: float = 2./3.,
+                 tau_attn: float = 1.0,
+                 l0_lambda: float = 1e-4):
+
         super().__init__()
 
         # Store patch parameters
@@ -46,7 +54,18 @@ class SeasonPatch(nn.Module):
             c_in=c_in, patch_num=patch_num, patch_len=patch_len,
             d_model=d_model, n_layers=n_layers, n_heads=n_heads
         )
-        self.mixer = HierarchicalGraphMixer(c_in, dim=d_model, k=k_graph)
+        # Wire in the new (unnormalized) HierarchicalGraphMixer
+        self.mixer = HierarchicalGraphMixer(
+            n_channel=c_in,
+            dim=d_model,
+            max_degree=k_graph,
+            thr=thr_graph,
+            temperature=gate_temperature,
+            tau_attn=tau_attn,
+            symmetric=symmetric_graph,
+            degree_rescale=degree_rescale
+        )
+        self.l0_lambda = l0_lambda  # needed by reg_loss() below
         # Prediction head (used on the Transformer path; input dim is patch_num * d_model)
         self.head = nn.Sequential(
             nn.Linear(patch_num * d_model, patch_num * d_model),
@@ -97,3 +115,11 @@ class SeasonPatch(nn.Module):
             y_pred = self.head(z_last)  # y_pred: [B, C, pred_len]
 
         return y_pred  # [B, C, pred_len]
+
+    def reg_loss(self):
+        """
+        Optional: expose the L0 regularizer so training can add it to the total loss.
+        """
+        if self.encoder_type == "Transformer" and hasattr(self, "mixer"):
+            return self.mixer.l0_loss(self.l0_lambda)
+        return torch.tensor(0.0, device=self.head[0].weight.device)
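+
+# Training-loop sketch (illustrative; `criterion`, `optimizer`, and the batch
+# tensors `x`, `y` stand in for the caller's setup):
+#
+#   y_hat = model(x)
+#   loss = criterion(y_hat, y) + model.reg_loss()  # main loss + expected-L0 penalty
+#   optimizer.zero_grad(); loss.backward(); optimizer.step()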
diff --git a/models/xPatch_SparseChannel.py b/models/xPatch_SparseChannel.py
index c809073..e04319e 100644
--- a/models/xPatch_SparseChannel.py
+++ b/models/xPatch_SparseChannel.py
@@ -22,7 +22,7 @@ class Model(nn.Module):
         self.pred_len = configs.pred_len
         self.enc_in = configs.enc_in
 
-        # Model parameters
+        # Patch parameters
         self.patch_len = getattr(configs, 'patch_len', 16)
         self.stride = getattr(configs, 'stride', 8)
 
@@ -37,19 +37,33 @@ class Model(nn.Module):
         beta = getattr(configs, 'beta', torch.tensor(0.1))
         self.decomp = DECOMP(ma_type, alpha, beta)
 
-        # Season network (PatchTST + Graph Mixer)
+        # Season network (PatchTST/Mamba2 + Graph Mixer)
+        # Pass the new SeasonPatch arguments through (its GraphMixer now uses the unnormalized Hard-Concrete gating)
         self.season_net = SeasonPatch(
             c_in=self.enc_in,
             seq_len=self.seq_len,
             pred_len=self.pred_len,
             patch_len=self.patch_len,
             stride=self.stride,
-            k_graph=getattr(configs, 'k_graph', 8),
+            # encoder type: 'Transformer' or 'Mamba2'
+            encoder_type=getattr(configs, 'season_encoder', 'Transformer'),
+            # patch-related
             d_model=getattr(configs, 'd_model', 128),
             n_layers=getattr(configs, 'e_layers', 3),
             n_heads=getattr(configs, 'n_heads', 16),
-            # read the selected encoder type ('Transformer' or 'Mamba2')
-            encoder_type = getattr(configs, 'season_encoder', 'Transformer')
+            # GraphMixer-related (unnormalized)
+            k_graph=getattr(configs, 'k_graph', 8),  # -> max_degree
+            thr_graph=getattr(configs, 'thr_graph', 0.5),
+            symmetric_graph=getattr(configs, 'symmetric_graph', True),
+            degree_rescale=getattr(configs, 'degree_rescale', 'count-sqrt'),  # 'none' | 'count' | 'count-sqrt' | 'sum'
+            gate_temperature=getattr(configs, 'gate_temperature', 2.0/3.0),
+            tau_attn=getattr(configs, 'tau_attn', 1.0),
+            l0_lambda=getattr(configs, 'season_l0_lambda', 0.0),
+            # Mamba2-related
+            d_state=getattr(configs, 'd_state', 64),
+            d_conv=getattr(configs, 'd_conv', 4),
+            expand=getattr(configs, 'expand', 2),
+            headdim=getattr(configs, 'headdim', 64),
         )
 
         # Trend network (MLP)
@@ -119,17 +133,12 @@ class Model(nn.Module):
 
     def classification(self, x_enc, x_mark_enc):
         """Classification task"""
-        # Normalization
-        #if self.revin:
-        #    x_enc = self.revin_layer(x_enc, 'norm')
-
-        # Decomposition
+        # Decomposition (classification usually skips RevIN; re-enable it above if needed)
         seasonal_init, trend_init = self.decomp(x_enc)
 
         # Season stream
         y_season = self.season_net(seasonal_init)  # [B, C, pred_len]
 
-        # print("shape:", trend_init.shape)
         # Trend stream
         B, L, C = trend_init.shape
         trend = trend_init.permute(0, 2, 1).reshape(B * C, L)  # [B*C, L]
@@ -146,7 +155,7 @@ class Model(nn.Module):
         season_attn_weights = torch.softmax(y_season, dim=-1)
         season_pooled = (y_season * season_attn_weights).sum(dim=-1)  # [B, C]
 
-        trend_attn_weights = torch.softmax(y_trend, dim=-1)  # over the time dimension
+        trend_attn_weights = torch.softmax(y_trend, dim=-1)
         trend_pooled = (y_trend * trend_attn_weights).sum(dim=-1)  # [B, C]
 
         # Combine features
@@ -166,3 +175,12 @@ class Model(nn.Module):
             return dec_out  # [B, N]
         else:
             raise ValueError(f'Task {self.task_name} not supported by xPatch_SparseChannel')
+
+    def reg_loss(self):
+        """
+        L0 regularization term (nonzero only when the Transformer path enables the GraphMixer).
+        During training: total_loss = main_loss + model.reg_loss()
+        """
+        if hasattr(self, "season_net") and hasattr(self.season_net, "reg_loss"):
+            return self.season_net.reg_loss()
+        return torch.tensor(0.0, device=next(self.parameters()).device)
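+
+# Config sketch (illustrative): every new option is read via getattr with a
+# default, so an existing argparse Namespace keeps working unchanged; set the
+# new attributes only to override, e.g.:
+#
+#   configs.thr_graph = 0.6          # keep an edge only when its gate exceeds 0.6
+#   configs.degree_rescale = 'none'  # raw, unnormalized aggregation
+#   configs.season_l0_lambda = 1e-4  # > 0 turns on the expected-L0 penalty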
diff --git a/scripts/long_term_forecast/xPatch_SparseChannel_all-Copy1.sh b/scripts/long_term_forecast/xPatch_SparseChannel_all-Copy1.sh
index b4deafc..40b9723 100644
--- a/scripts/long_term_forecast/xPatch_SparseChannel_all-Copy1.sh
+++ b/scripts/long_term_forecast/xPatch_SparseChannel_all-Copy1.sh
@@ -2,6 +2,45 @@
 
 model_name=xPatch_SparseChannel
 
+# ETTm1 dataset
+for pred_len in 96 192 336 720
+do
+python -u run.py \
+  --task_name long_term_forecast \
+  --is_training 1 \
+  --root_path ./dataset/ETT-small/ \
+  --data_path ETTm1.csv \
+  --model_id ETTm1_$pred_len'_'$pred_len \
+  --model $model_name \
+  --data ETTm1 \
+  --features M \
+  --seq_len 96 \
+  --label_len 48 \
+  --pred_len $pred_len \
+  --e_layers 2 \
+  --d_layers 1 \
+  --enc_in 7 \
+  --c_out 7 \
+  --d_model 128 \
+  --lradj 'sigmoid' \
+  --d_ff 256 \
+  --n_heads 16 \
+  --patch_len 16 \
+  --stride 8 \
+  --k_graph 5 \
+  --dropout 0.1 \
+  --revin 1 \
+  --des 'Exp' \
+  --itr 1 \
+  --season_encoder 'Transformer' \
+  --thr_graph 0.6 \
+  --symmetric_graph 1 \
+  --degree_rescale 'none' \
+  --gate_temperature 0.6667 \
+  --tau_attn 1.0 \
+  --season_l0_lambda 0.0000
+done
+
 # Weather dataset
 for pred_len in 96 192 336 720
 do
@@ -32,7 +71,14 @@ python -u run.py \
   --dropout 0.1 \
   --revin 1 \
   --des 'Exp' \
-  --itr 1
+  --itr 1 \
+  --season_encoder 'Transformer' \
+  --thr_graph 0.6 \
+  --symmetric_graph 1 \
+  --degree_rescale 'none' \
+  --gate_temperature 0.6667 \
+  --tau_attn 1.0 \
+  --season_l0_lambda 0.0000
 done
 
 # Exchange dataset
 for pred_len in 96 192 336 720
 do
@@ -64,40 +110,16 @@ python -u run.py \
   --dropout 0.1 \
   --revin 1 \
   --des 'Exp' \
-  --itr 1
+  --itr 1 \
+  --season_encoder 'Transformer' \
+  --thr_graph 0.6 \
+  --symmetric_graph 1 \
+  --degree_rescale 'none' \
+  --gate_temperature 0.6667 \
+  --tau_attn 1.0 \
+  --season_l0_lambda 0.0000
 done
 
-# ETTm1 dataset
-for pred_len in 96 192 336 720
-do
-python -u run.py \
-  --task_name long_term_forecast \
-  --is_training 1 \
-  --root_path ./dataset/ETT-small/ \
-  --data_path ETTm1.csv \
-  --model_id ETTm1_$pred_len'_'$pred_len \
-  --model $model_name \
-  --data ETTm1 \
-  --features M \
-  --seq_len 96 \
-  --label_len 48 \
-  --pred_len $pred_len \
-  --e_layers 2 \
-  --d_layers 1 \
-  --enc_in 7 \
-  --c_out 7 \
-  --d_model 128 \
-  --lradj 'sigmoid' \
-  --d_ff 256 \
-  --n_heads 16 \
-  --patch_len 16 \
-  --stride 8 \
-  --k_graph 5 \
-  --dropout 0.1 \
-  --revin 1 \
-  --des 'Exp' \
-  --itr 1
-done
-
 # ETTm2 dataset
 for pred_len in 96 192 336 720
 do
@@ -128,7 +150,14 @@ python -u run.py \
   --dropout 0.1 \
   --revin 1 \
   --des 'Exp' \
-  --itr 1
+  --itr 1 \
+  --season_encoder 'Transformer' \
+  --thr_graph 0.6 \
+  --symmetric_graph 1 \
+  --degree_rescale 'none' \
+  --gate_temperature 0.6667 \
+  --tau_attn 1.0 \
+  --season_l0_lambda 0.0000
 done
 
 # ETTh1 dataset