From 9f7fb24beb0fde638eccc565459c031476a42417 Mon Sep 17 00:00:00 2001 From: game-loader Date: Sat, 6 Sep 2025 00:06:26 +0800 Subject: [PATCH] refactor(graphmixer): enhance channel graph attention with ST-Gumbel --- layers/GraphMixer.py | 137 ++++++++++++------ layers/SeasonPatch.py | 8 +- .../xPatch_SparseChannel_all-Copy1.sh | 131 +++++++++-------- 3 files changed, 165 insertions(+), 111 deletions(-) diff --git a/layers/GraphMixer.py b/layers/GraphMixer.py index c900805..000b3fa 100644 --- a/layers/GraphMixer.py +++ b/layers/GraphMixer.py @@ -6,18 +6,19 @@ import math class HierarchicalGraphMixer(nn.Module): """ 分层图混合器,同时考虑宏观通道关系和微观 Patch 级别注意力。 - 输入 z : 形状为 [B, C, N, D] 的张量 - 输出 z_out : 形状同输入 + 输入 z : [B, C, N, D] + 输出 z_out : 同形状 """ - def __init__(self, n_channel: int, dim: int, k: int = 5, tau: float = 0.2): + def __init__(self, n_channel: int, dim: int, k: int = 5, tau_fw: float = 0.3, tau_bw: float = 3.0): super().__init__() self.k = k - self.tau = tau + self.tau_fw = tau_fw # 前向温度(小) + self.tau_bw = tau_bw # 反向温度(大) - # Level 1: Channel Graph + # Level 1: Channel Graph (logits) self.A = nn.Parameter(torch.zeros(n_channel, n_channel)) self.se = nn.Sequential( - nn.Linear(dim, dim // 4, bias=False), nn.ReLU(), + nn.Linear(dim, dim // 4, bias=False), nn.SiLU(), nn.Linear(dim // 4, 1, bias=False), nn.Sigmoid() ) @@ -28,56 +29,96 @@ class HierarchicalGraphMixer(nn.Module): self.out_proj = nn.Linear(dim, dim) self.norm = nn.LayerNorm(dim) - def _row_sparse(self, logits: torch.Tensor) -> torch.Tensor: - """Gumbel-Softmax based sparse attention""" - g = -torch.empty_like(logits).exponential_().log() - y = (logits + g) / self.tau - probs = F.softmax(y, dim=-1) - - # Ensure k doesn't exceed the dimension size - k_actual = min(self.k, probs.size(-1)) + @torch.no_grad() + def _mask_self_logits_(self, logits: torch.Tensor): + """把对角线置为 -inf,确保不选到自己""" + C = logits.size(0) + eye = torch.eye(C, device=logits.device, dtype=torch.bool) + logits.masked_fill_(eye, float("-inf")) + + def _gumbel_topk_select(self, logits: torch.Tensor): + """ + 返回: + - idx: [C, k_actual] 每行 top-k 的通道索引(不含自身) + - w_st: [C, k_actual] 选中边的权重(前向=用 tau_fw 的概率;反向梯度=来自 tau_bw 的概率) + """ + C = logits.size(0) + k_actual = min(self.k, C - 1) if k_actual <= 0: - return torch.zeros_like(probs) - - topk_val, _ = torch.topk(probs, k_actual, dim=-1) - thr = topk_val[..., -1].unsqueeze(-1) - sparse = torch.where(probs >= thr, probs, torch.zeros_like(probs)) - return sparse.detach() + probs - probs.detach() + idx = torch.empty((C, 0), dtype=torch.long, device=logits.device) + w_st = torch.empty((C, 0), dtype=logits.dtype, device=logits.device) + return idx, w_st + + # 共享一份 Gumbel 噪声,分别用不同温度构造前向/反向的分布 + g = -torch.empty_like(logits).exponential_().log() + y_fw = (logits + g) / self.tau_fw + y_bw = (logits + g) / self.tau_bw + + # 排除自身 + y_fw = y_fw.clone() + y_bw = y_bw.clone() + self._mask_self_logits_(y_fw) + self._mask_self_logits_(y_bw) + + # 选择前向 top-k(严格选择) + topk_val, idx = torch.topk(y_fw, k_actual, dim=-1) # [C, k] + # 计算前向/反向的软概率,并仅收集被选中的 k 个 + p_fw = F.softmax(y_fw, dim=-1) # [C, C] + p_bw = F.softmax(y_bw, dim=-1) # [C, C] + w_fw = torch.gather(p_fw, -1, idx) # [C, k] + w_bw = torch.gather(p_bw, -1, idx) # [C, k] + + # 在被选集合内进行归一化,稳定训练 + eps = 1e-9 + w_fw = w_fw / (w_fw.sum(-1, keepdim=True) + eps) + w_bw = w_bw / (w_bw.sum(-1, keepdim=True) + eps) + + # Straight-Through:前向用 w_fw,反向梯度用 w_bw + w_st = w_fw.detach() + w_bw - w_bw.detach() # [C, k] + return idx, w_st def forward(self, z): - # z 的形状: [B, C, N, D] + # z: [B, C, N, D] B, C, N, D = z.shape - # --- Level 1: 计算宏观权重 --- - A_sparse = self._row_sparse(self.A) # 通道连接稀疏图 A_sparse: [C, C] + # --- Level 1: 选每个通道的 top-k 相关通道(不含自身),并得到ST权重 --- + idx, w_st = self._gumbel_topk_select(self.A) # idx:[C,k], w_st:[C,k] - # --- Level 2: 跨通道 Patch 交互 --- + # --- Level 2: 仅对被选中的通道做跨通道 Patch 交互 --- out_z = torch.zeros_like(z) - for i in range(C): # 遍历每个目标通道 i + + for i in range(C): target_z = z[:, i, :, :] # [B, N, D] - - # 准备聚合来自其他通道的 patch 级别上下文 - aggregated_context = torch.zeros_like(target_z) - - for j in range(C): # 遍历每个源通道 j - if A_sparse[i, j] != 0: - source_z = z[:, j, :, :] # [B, N, D] - # --- 执行交叉注意力 --- - Q = self.q_proj(target_z) # Query 来自目标通道 i - K = self.k_proj(source_z) # Key 来自源通道 j - V = self.v_proj(source_z) # Value 来自源通道 j - - attn_scores = torch.bmm(Q, K.transpose(1, 2)) / math.sqrt(D) - attn_probs = F.softmax(attn_scores, dim=-1) # [B, N, N] - - context = torch.bmm(attn_probs, V) # [B, N, D], 从 j 聚合到 i 的上下文 - - # 加权上下文 - weighted_context = A_sparse[i, j] * context - aggregated_context = aggregated_context + weighted_context + # 如果该通道没有可选邻居,直接残差 + if idx.size(1) == 0: + out_z[:, i, :, :] = self.norm(target_z) + continue - # 将聚合后的上下文通过输出层,并与原始目标表示相加(残差连接) + sel_idx = idx[i] # [k] + sel_w = w_st[i] # [k] + k_i = sel_idx.numel() + + # 源通道块: [B, k, N, D] + source_z = z[:, sel_idx, :, :] + + # 线性投影 + Q = self.q_proj(target_z) # [B, N, D] + K = self.k_proj(source_z.reshape(B * k_i, N, D)).reshape(B, k_i, N, D) + V = self.v_proj(source_z.reshape(B * k_i, N, D)).reshape(B, k_i, N, D) + + # 跨注意力(一次性对 k 个源通道) + # attn_scores: [B, k, N, N] + attn_scores = torch.einsum('bnd,bkmd->bknm', Q, K) / math.sqrt(D) + attn_probs = F.softmax(attn_scores, dim=-1) # [B, k, N, N] + context = torch.einsum('bknm,bkmd->bknd', attn_probs, V) # [B, k, N, D] + + # 用 ST 的通道权重聚合(前向=小温度的权重,反向梯度=大温度) + w = sel_w.view(1, k_i, 1, 1) # [1, k, 1, 1] + aggregated_context = (context * w).sum(dim=1) # [B, N, D] + + # 输出与残差 out_z[:, i, :, :] = self.norm(target_z + self.out_proj(aggregated_context)) - - return out_z \ No newline at end of file + + return out_z + diff --git a/layers/SeasonPatch.py b/layers/SeasonPatch.py index 3dd349e..0d4d6a2 100644 --- a/layers/SeasonPatch.py +++ b/layers/SeasonPatch.py @@ -41,7 +41,11 @@ class SeasonPatch(nn.Module): self.mixer = HierarchicalGraphMixer(c_in, dim=d_model, k=k_graph) # Prediction head - self.head = nn.Linear(patch_num * d_model, pred_len) + self.head = nn.Sequential( + nn.Linear(patch_num * d_model, patch_num * d_model), + nn.SiLU(), # 非线性激活(SiLU/Swish) + nn.Linear(patch_num * d_model, pred_len) + ) def forward(self, x): # x: [B, L, C] @@ -64,4 +68,4 @@ class SeasonPatch(nn.Module): z_mix = z_mix.view(B, C, N * D) # [B, C, patch_num * d_model] y_pred = self.head(z_mix) # [B, C, pred_len] - return y_pred \ No newline at end of file + return y_pred diff --git a/scripts/long_term_forecast/xPatch_SparseChannel_all-Copy1.sh b/scripts/long_term_forecast/xPatch_SparseChannel_all-Copy1.sh index ae2c04d..b4deafc 100644 --- a/scripts/long_term_forecast/xPatch_SparseChannel_all-Copy1.sh +++ b/scripts/long_term_forecast/xPatch_SparseChannel_all-Copy1.sh @@ -2,6 +2,71 @@ model_name=xPatch_SparseChannel +# Weather dataset +for pred_len in 96 192 336 720 +do +python -u run.py \ + --task_name long_term_forecast \ + --is_training 1 \ + --root_path ./dataset/weather/ \ + --data_path weather.csv \ + --model_id weather_$pred_len'_'$pred_len \ + --model $model_name \ + --data custom \ + --features M \ + --seq_len 96 \ + --label_len 48 \ + --pred_len $pred_len \ + --e_layers 2 \ + --d_layers 1 \ + --enc_in 21 \ + --c_out 21 \ + --d_model 128 \ + --lradj 'sigmoid' \ + --train_epochs 20 \ + --d_ff 256 \ + --n_heads 16 \ + --patch_len 16 \ + --stride 8 \ + --k_graph 8 \ + --dropout 0.1 \ + --revin 1 \ + --des 'Exp' \ + --itr 1 +done + +# Exchange dataset +for pred_len in 96 192 336 720 +do +python -u run.py \ + --task_name long_term_forecast \ + --is_training 1 \ + --root_path ./dataset/exchange_rate/ \ + --data_path exchange_rate.csv \ + --model_id Exchange_$pred_len'_'$pred_len \ + --model $model_name \ + --data custom \ + --features M \ + --seq_len 96 \ + --label_len 48 \ + --pred_len $pred_len \ + --e_layers 2 \ + --d_layers 1 \ + --enc_in 8 \ + --c_out 8 \ + --d_model 128 \ + --lradj 'sigmoid' \ + --d_ff 256 \ + --n_heads 16 \ + --patch_len 16 \ + --stride 8 \ + --k_graph 8 \ + --dropout 0.1 \ + --revin 1 \ + --des 'Exp' \ + --itr 1 +done + # ETTm1 dataset for pred_len in 96 192 336 720 do @@ -22,11 +87,12 @@ python -u run.py \ --enc_in 7 \ --c_out 7 \ --d_model 128 \ + --lradj 'sigmoid' \ --d_ff 256 \ --n_heads 16 \ --patch_len 16 \ --stride 8 \ - --k_graph 7 \ + --k_graph 5 \ --dropout 0.1 \ --revin 1 \ --des 'Exp' \ @@ -53,6 +119,7 @@ python -u run.py \ --enc_in 7 \ --c_out 7 \ --d_model 128 \ + --lradj 'sigmoid' \ --d_ff 256 \ --n_heads 16 \ --patch_len 16 \ @@ -84,6 +151,7 @@ python -u run.py \ --enc_in 7 \ --c_out 7 \ --d_model 128 \ + --lradj 'sigmoid' \ --d_ff 256 \ --n_heads 16 \ --patch_len 16 \ @@ -115,6 +183,7 @@ python -u run.py \ --enc_in 7 \ --c_out 7 \ --d_model 128 \ + --lradj 'sigmoid' \ --d_ff 256 \ --n_heads 16 \ --patch_len 16 \ @@ -126,36 +195,6 @@ python -u run.py \ --itr 1 done -# Weather dataset -for pred_len in 96 192 336 720 -do -python -u run.py \ - --task_name long_term_forecast \ - --is_training 1 \ - --root_path ./dataset/weather/ \ - --data_path weather.csv \ - --model_id weather_$pred_len'_'$pred_len \ - --model $model_name \ - --data custom \ - --features M \ - --seq_len 96 \ - --label_len 48 \ - --pred_len $pred_len \ - --e_layers 2 \ - --d_layers 1 \ - --enc_in 21 \ - --c_out 21 \ - --d_model 128 \ - --d_ff 256 \ - --n_heads 16 \ - --patch_len 16 \ - --stride 8 \ - --k_graph 8 \ - --dropout 0.1 \ - --revin 1 \ - --des 'Exp' \ - --itr 1 -done # ECL dataset for pred_len in 96 192 336 720 @@ -219,33 +258,3 @@ python -u run.py \ --itr 1 done -# Exchange dataset -for pred_len in 96 192 336 720 -do -python -u run.py \ - --task_name long_term_forecast \ - --is_training 1 \ - --root_path ./dataset/exchange_rate/ \ - --data_path exchange_rate.csv \ - --model_id Exchange_$pred_len'_'$pred_len \ - --model $model_name \ - --data custom \ - --features M \ - --seq_len 96 \ - --label_len 48 \ - --pred_len $pred_len \ - --e_layers 2 \ - --d_layers 1 \ - --enc_in 8 \ - --c_out 8 \ - --d_model 128 \ - --d_ff 256 \ - --n_heads 16 \ - --patch_len 16 \ - --stride 8 \ - --k_graph 8 \ - --dropout 0.1 \ - --revin 1 \ - --des 'Exp' \ - --itr 1 -done \ No newline at end of file