refactor(graphmixer): enhance channel graph attention with ST-Gumbel

This commit is contained in:
game-loader
2025-09-06 00:06:26 +08:00
parent ef307a57e9
commit 9f7fb24beb
3 changed files with 165 additions and 111 deletions

View File

@ -6,18 +6,19 @@ import math
class HierarchicalGraphMixer(nn.Module):
"""
分层图混合器,同时考虑宏观通道关系和微观 Patch 级别注意力。
输入 z : 形状为 [B, C, N, D] 的张量
输出 z_out : 形状同输入
输入 z : [B, C, N, D]
输出 z_out : 形状
"""
def __init__(self, n_channel: int, dim: int, k: int = 5, tau: float = 0.2):
def __init__(self, n_channel: int, dim: int, k: int = 5, tau_fw: float = 0.3, tau_bw: float = 3.0):
super().__init__()
self.k = k
self.tau = tau
self.tau_fw = tau_fw # 前向温度(小)
self.tau_bw = tau_bw # 反向温度(大)
# Level 1: Channel Graph
# Level 1: Channel Graph (logits)
self.A = nn.Parameter(torch.zeros(n_channel, n_channel))
self.se = nn.Sequential(
nn.Linear(dim, dim // 4, bias=False), nn.ReLU(),
nn.Linear(dim, dim // 4, bias=False), nn.SiLU(),
nn.Linear(dim // 4, 1, bias=False), nn.Sigmoid()
)
@ -28,56 +29,96 @@ class HierarchicalGraphMixer(nn.Module):
self.out_proj = nn.Linear(dim, dim)
self.norm = nn.LayerNorm(dim)
def _row_sparse(self, logits: torch.Tensor) -> torch.Tensor:
"""Gumbel-Softmax based sparse attention"""
g = -torch.empty_like(logits).exponential_().log()
y = (logits + g) / self.tau
probs = F.softmax(y, dim=-1)
@torch.no_grad()
def _mask_self_logits_(self, logits: torch.Tensor):
"""把对角线置为 -inf确保不选到自己"""
C = logits.size(0)
eye = torch.eye(C, device=logits.device, dtype=torch.bool)
logits.masked_fill_(eye, float("-inf"))
# Ensure k doesn't exceed the dimension size
k_actual = min(self.k, probs.size(-1))
def _gumbel_topk_select(self, logits: torch.Tensor):
"""
返回:
- idx: [C, k_actual] 每行 top-k 的通道索引(不含自身)
- w_st: [C, k_actual] 选中边的权重(前向=用 tau_fw 的概率;反向梯度=来自 tau_bw 的概率)
"""
C = logits.size(0)
k_actual = min(self.k, C - 1)
if k_actual <= 0:
return torch.zeros_like(probs)
idx = torch.empty((C, 0), dtype=torch.long, device=logits.device)
w_st = torch.empty((C, 0), dtype=logits.dtype, device=logits.device)
return idx, w_st
topk_val, _ = torch.topk(probs, k_actual, dim=-1)
thr = topk_val[..., -1].unsqueeze(-1)
sparse = torch.where(probs >= thr, probs, torch.zeros_like(probs))
return sparse.detach() + probs - probs.detach()
# 共享一份 Gumbel 噪声,分别用不同温度构造前向/反向的分布
g = -torch.empty_like(logits).exponential_().log()
y_fw = (logits + g) / self.tau_fw
y_bw = (logits + g) / self.tau_bw
# 排除自身
y_fw = y_fw.clone()
y_bw = y_bw.clone()
self._mask_self_logits_(y_fw)
self._mask_self_logits_(y_bw)
# 选择前向 top-k严格选择
topk_val, idx = torch.topk(y_fw, k_actual, dim=-1) # [C, k]
# 计算前向/反向的软概率,并仅收集被选中的 k 个
p_fw = F.softmax(y_fw, dim=-1) # [C, C]
p_bw = F.softmax(y_bw, dim=-1) # [C, C]
w_fw = torch.gather(p_fw, -1, idx) # [C, k]
w_bw = torch.gather(p_bw, -1, idx) # [C, k]
# 在被选集合内进行归一化,稳定训练
eps = 1e-9
w_fw = w_fw / (w_fw.sum(-1, keepdim=True) + eps)
w_bw = w_bw / (w_bw.sum(-1, keepdim=True) + eps)
# Straight-Through前向用 w_fw反向梯度用 w_bw
w_st = w_fw.detach() + w_bw - w_bw.detach() # [C, k]
return idx, w_st
def forward(self, z):
# z 的形状: [B, C, N, D]
# z: [B, C, N, D]
B, C, N, D = z.shape
# --- Level 1: 计算宏观权重 ---
A_sparse = self._row_sparse(self.A) # 通道连接稀疏图 A_sparse: [C, C]
# --- Level 1: 选每个通道的 top-k 相关通道不含自身并得到ST权重 ---
idx, w_st = self._gumbel_topk_select(self.A) # idx:[C,k], w_st:[C,k]
# --- Level 2: 跨通道 Patch 交互 ---
# --- Level 2: 仅对被选中的通道做跨通道 Patch 交互 ---
out_z = torch.zeros_like(z)
for i in range(C): # 遍历每个目标通道 i
for i in range(C):
target_z = z[:, i, :, :] # [B, N, D]
# 准备聚合来自其他通道的 patch 级别上下文
aggregated_context = torch.zeros_like(target_z)
# 如果该通道没有可选邻居,直接残差
if idx.size(1) == 0:
out_z[:, i, :, :] = self.norm(target_z)
continue
for j in range(C): # 遍历每个源通道 j
if A_sparse[i, j] != 0:
source_z = z[:, j, :, :] # [B, N, D]
sel_idx = idx[i] # [k]
sel_w = w_st[i] # [k]
k_i = sel_idx.numel()
# --- 执行交叉注意力 ---
Q = self.q_proj(target_z) # Query 来自目标通道 i
K = self.k_proj(source_z) # Key 来自源通道 j
V = self.v_proj(source_z) # Value 来自源通道 j
# 源通道块: [B, k, N, D]
source_z = z[:, sel_idx, :, :]
attn_scores = torch.bmm(Q, K.transpose(1, 2)) / math.sqrt(D)
attn_probs = F.softmax(attn_scores, dim=-1) # [B, N, N]
# 线性投影
Q = self.q_proj(target_z) # [B, N, D]
K = self.k_proj(source_z.reshape(B * k_i, N, D)).reshape(B, k_i, N, D)
V = self.v_proj(source_z.reshape(B * k_i, N, D)).reshape(B, k_i, N, D)
context = torch.bmm(attn_probs, V) # [B, N, D], 从 j 聚合到 i 的上下文
# 跨注意力(一次性对 k 个源通道)
# attn_scores: [B, k, N, N]
attn_scores = torch.einsum('bnd,bkmd->bknm', Q, K) / math.sqrt(D)
attn_probs = F.softmax(attn_scores, dim=-1) # [B, k, N, N]
context = torch.einsum('bknm,bkmd->bknd', attn_probs, V) # [B, k, N, D]
# 加权上下文
weighted_context = A_sparse[i, j] * context
aggregated_context = aggregated_context + weighted_context
# 用 ST 的通道权重聚合(前向=小温度的权重,反向梯度=大温度)
w = sel_w.view(1, k_i, 1, 1) # [1, k, 1, 1]
aggregated_context = (context * w).sum(dim=1) # [B, N, D]
# 将聚合后的上下文通过输出层,并与原始目标表示相加(残差连接)
# 输出与残差
out_z[:, i, :, :] = self.norm(target_z + self.out_proj(aggregated_context))
return out_z

View File

@ -41,7 +41,11 @@ class SeasonPatch(nn.Module):
self.mixer = HierarchicalGraphMixer(c_in, dim=d_model, k=k_graph)
# Prediction head
self.head = nn.Linear(patch_num * d_model, pred_len)
self.head = nn.Sequential(
nn.Linear(patch_num * d_model, patch_num * d_model),
nn.SiLU(), # 非线性激活SiLU/Swish
nn.Linear(patch_num * d_model, pred_len)
)
def forward(self, x):
# x: [B, L, C]

View File

@ -2,6 +2,71 @@
model_name=xPatch_SparseChannel
# Weather dataset
for pred_len in 96 192 336 720
do
python -u run.py \
--task_name long_term_forecast \
--is_training 1 \
--root_path ./dataset/weather/ \
--data_path weather.csv \
--model_id weather_$pred_len'_'$pred_len \
--model $model_name \
--data custom \
--features M \
--seq_len 96 \
--label_len 48 \
--pred_len $pred_len \
--e_layers 2 \
--d_layers 1 \
--enc_in 21 \
--c_out 21 \
--d_model 128 \
--lradj 'sigmoid' \
--train_epochs 20 \
--d_ff 256 \
--n_heads 16 \
--patch_len 16 \
--stride 8 \
--k_graph 8 \
--dropout 0.1 \
--revin 1 \
--des 'Exp' \
--itr 1
done
# Exchange dataset
for pred_len in 96 192 336 720
do
python -u run.py \
--task_name long_term_forecast \
--is_training 1 \
--root_path ./dataset/exchange_rate/ \
--data_path exchange_rate.csv \
--model_id Exchange_$pred_len'_'$pred_len \
--model $model_name \
--data custom \
--features M \
--seq_len 96 \
--label_len 48 \
--pred_len $pred_len \
--e_layers 2 \
--d_layers 1 \
--enc_in 8 \
--c_out 8 \
--d_model 128 \
--lradj 'sigmoid' \
--d_ff 256 \
--n_heads 16 \
--patch_len 16 \
--stride 8 \
--k_graph 8 \
--dropout 0.1 \
--revin 1 \
--des 'Exp' \
--itr 1
done
# ETTm1 dataset
for pred_len in 96 192 336 720
do
@ -22,11 +87,12 @@ python -u run.py \
--enc_in 7 \
--c_out 7 \
--d_model 128 \
--lradj 'sigmoid' \
--d_ff 256 \
--n_heads 16 \
--patch_len 16 \
--stride 8 \
--k_graph 7 \
--k_graph 5 \
--dropout 0.1 \
--revin 1 \
--des 'Exp' \
@ -53,6 +119,7 @@ python -u run.py \
--enc_in 7 \
--c_out 7 \
--d_model 128 \
--lradj 'sigmoid' \
--d_ff 256 \
--n_heads 16 \
--patch_len 16 \
@ -84,6 +151,7 @@ python -u run.py \
--enc_in 7 \
--c_out 7 \
--d_model 128 \
--lradj 'sigmoid' \
--d_ff 256 \
--n_heads 16 \
--patch_len 16 \
@ -115,6 +183,7 @@ python -u run.py \
--enc_in 7 \
--c_out 7 \
--d_model 128 \
--lradj 'sigmoid' \
--d_ff 256 \
--n_heads 16 \
--patch_len 16 \
@ -126,36 +195,6 @@ python -u run.py \
--itr 1
done
# Weather dataset
for pred_len in 96 192 336 720
do
python -u run.py \
--task_name long_term_forecast \
--is_training 1 \
--root_path ./dataset/weather/ \
--data_path weather.csv \
--model_id weather_$pred_len'_'$pred_len \
--model $model_name \
--data custom \
--features M \
--seq_len 96 \
--label_len 48 \
--pred_len $pred_len \
--e_layers 2 \
--d_layers 1 \
--enc_in 21 \
--c_out 21 \
--d_model 128 \
--d_ff 256 \
--n_heads 16 \
--patch_len 16 \
--stride 8 \
--k_graph 8 \
--dropout 0.1 \
--revin 1 \
--des 'Exp' \
--itr 1
done
# ECL dataset
for pred_len in 96 192 336 720
@ -219,33 +258,3 @@ python -u run.py \
--itr 1
done
# Exchange dataset
for pred_len in 96 192 336 720
do
python -u run.py \
--task_name long_term_forecast \
--is_training 1 \
--root_path ./dataset/exchange_rate/ \
--data_path exchange_rate.csv \
--model_id Exchange_$pred_len'_'$pred_len \
--model $model_name \
--data custom \
--features M \
--seq_len 96 \
--label_len 48 \
--pred_len $pred_len \
--e_layers 2 \
--d_layers 1 \
--enc_in 8 \
--c_out 8 \
--d_model 128 \
--d_ff 256 \
--n_heads 16 \
--patch_len 16 \
--stride 8 \
--k_graph 8 \
--dropout 0.1 \
--revin 1 \
--des 'Exp' \
--itr 1
done