refactor(graphmixer): enhance channel graph attention with ST-Gumbel
This commit is contained in:
@ -6,18 +6,19 @@ import math
|
|||||||
class HierarchicalGraphMixer(nn.Module):
|
class HierarchicalGraphMixer(nn.Module):
|
||||||
"""
|
"""
|
||||||
分层图混合器,同时考虑宏观通道关系和微观 Patch 级别注意力。
|
分层图混合器,同时考虑宏观通道关系和微观 Patch 级别注意力。
|
||||||
输入 z : 形状为 [B, C, N, D] 的张量
|
输入 z : [B, C, N, D]
|
||||||
输出 z_out : 形状同输入
|
输出 z_out : 同形状
|
||||||
"""
|
"""
|
||||||
def __init__(self, n_channel: int, dim: int, k: int = 5, tau: float = 0.2):
|
def __init__(self, n_channel: int, dim: int, k: int = 5, tau_fw: float = 0.3, tau_bw: float = 3.0):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.k = k
|
self.k = k
|
||||||
self.tau = tau
|
self.tau_fw = tau_fw # 前向温度(小)
|
||||||
|
self.tau_bw = tau_bw # 反向温度(大)
|
||||||
|
|
||||||
# Level 1: Channel Graph
|
# Level 1: Channel Graph (logits)
|
||||||
self.A = nn.Parameter(torch.zeros(n_channel, n_channel))
|
self.A = nn.Parameter(torch.zeros(n_channel, n_channel))
|
||||||
self.se = nn.Sequential(
|
self.se = nn.Sequential(
|
||||||
nn.Linear(dim, dim // 4, bias=False), nn.ReLU(),
|
nn.Linear(dim, dim // 4, bias=False), nn.SiLU(),
|
||||||
nn.Linear(dim // 4, 1, bias=False), nn.Sigmoid()
|
nn.Linear(dim // 4, 1, bias=False), nn.Sigmoid()
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -28,56 +29,96 @@ class HierarchicalGraphMixer(nn.Module):
|
|||||||
self.out_proj = nn.Linear(dim, dim)
|
self.out_proj = nn.Linear(dim, dim)
|
||||||
self.norm = nn.LayerNorm(dim)
|
self.norm = nn.LayerNorm(dim)
|
||||||
|
|
||||||
def _row_sparse(self, logits: torch.Tensor) -> torch.Tensor:
|
@torch.no_grad()
|
||||||
"""Gumbel-Softmax based sparse attention"""
|
def _mask_self_logits_(self, logits: torch.Tensor):
|
||||||
g = -torch.empty_like(logits).exponential_().log()
|
"""把对角线置为 -inf,确保不选到自己"""
|
||||||
y = (logits + g) / self.tau
|
C = logits.size(0)
|
||||||
probs = F.softmax(y, dim=-1)
|
eye = torch.eye(C, device=logits.device, dtype=torch.bool)
|
||||||
|
logits.masked_fill_(eye, float("-inf"))
|
||||||
# Ensure k doesn't exceed the dimension size
|
|
||||||
k_actual = min(self.k, probs.size(-1))
|
def _gumbel_topk_select(self, logits: torch.Tensor):
|
||||||
|
"""
|
||||||
|
返回:
|
||||||
|
- idx: [C, k_actual] 每行 top-k 的通道索引(不含自身)
|
||||||
|
- w_st: [C, k_actual] 选中边的权重(前向=用 tau_fw 的概率;反向梯度=来自 tau_bw 的概率)
|
||||||
|
"""
|
||||||
|
C = logits.size(0)
|
||||||
|
k_actual = min(self.k, C - 1)
|
||||||
if k_actual <= 0:
|
if k_actual <= 0:
|
||||||
return torch.zeros_like(probs)
|
idx = torch.empty((C, 0), dtype=torch.long, device=logits.device)
|
||||||
|
w_st = torch.empty((C, 0), dtype=logits.dtype, device=logits.device)
|
||||||
topk_val, _ = torch.topk(probs, k_actual, dim=-1)
|
return idx, w_st
|
||||||
thr = topk_val[..., -1].unsqueeze(-1)
|
|
||||||
sparse = torch.where(probs >= thr, probs, torch.zeros_like(probs))
|
# 共享一份 Gumbel 噪声,分别用不同温度构造前向/反向的分布
|
||||||
return sparse.detach() + probs - probs.detach()
|
g = -torch.empty_like(logits).exponential_().log()
|
||||||
|
y_fw = (logits + g) / self.tau_fw
|
||||||
|
y_bw = (logits + g) / self.tau_bw
|
||||||
|
|
||||||
|
# 排除自身
|
||||||
|
y_fw = y_fw.clone()
|
||||||
|
y_bw = y_bw.clone()
|
||||||
|
self._mask_self_logits_(y_fw)
|
||||||
|
self._mask_self_logits_(y_bw)
|
||||||
|
|
||||||
|
# 选择前向 top-k(严格选择)
|
||||||
|
topk_val, idx = torch.topk(y_fw, k_actual, dim=-1) # [C, k]
|
||||||
|
# 计算前向/反向的软概率,并仅收集被选中的 k 个
|
||||||
|
p_fw = F.softmax(y_fw, dim=-1) # [C, C]
|
||||||
|
p_bw = F.softmax(y_bw, dim=-1) # [C, C]
|
||||||
|
w_fw = torch.gather(p_fw, -1, idx) # [C, k]
|
||||||
|
w_bw = torch.gather(p_bw, -1, idx) # [C, k]
|
||||||
|
|
||||||
|
# 在被选集合内进行归一化,稳定训练
|
||||||
|
eps = 1e-9
|
||||||
|
w_fw = w_fw / (w_fw.sum(-1, keepdim=True) + eps)
|
||||||
|
w_bw = w_bw / (w_bw.sum(-1, keepdim=True) + eps)
|
||||||
|
|
||||||
|
# Straight-Through:前向用 w_fw,反向梯度用 w_bw
|
||||||
|
w_st = w_fw.detach() + w_bw - w_bw.detach() # [C, k]
|
||||||
|
return idx, w_st
|
||||||
|
|
||||||
def forward(self, z):
|
def forward(self, z):
|
||||||
# z 的形状: [B, C, N, D]
|
# z: [B, C, N, D]
|
||||||
B, C, N, D = z.shape
|
B, C, N, D = z.shape
|
||||||
|
|
||||||
# --- Level 1: 计算宏观权重 ---
|
# --- Level 1: 选每个通道的 top-k 相关通道(不含自身),并得到ST权重 ---
|
||||||
A_sparse = self._row_sparse(self.A) # 通道连接稀疏图 A_sparse: [C, C]
|
idx, w_st = self._gumbel_topk_select(self.A) # idx:[C,k], w_st:[C,k]
|
||||||
|
|
||||||
# --- Level 2: 跨通道 Patch 交互 ---
|
# --- Level 2: 仅对被选中的通道做跨通道 Patch 交互 ---
|
||||||
out_z = torch.zeros_like(z)
|
out_z = torch.zeros_like(z)
|
||||||
for i in range(C): # 遍历每个目标通道 i
|
|
||||||
|
for i in range(C):
|
||||||
target_z = z[:, i, :, :] # [B, N, D]
|
target_z = z[:, i, :, :] # [B, N, D]
|
||||||
|
|
||||||
# 准备聚合来自其他通道的 patch 级别上下文
|
|
||||||
aggregated_context = torch.zeros_like(target_z)
|
|
||||||
|
|
||||||
for j in range(C): # 遍历每个源通道 j
|
|
||||||
if A_sparse[i, j] != 0:
|
|
||||||
source_z = z[:, j, :, :] # [B, N, D]
|
|
||||||
|
|
||||||
# --- 执行交叉注意力 ---
|
# 如果该通道没有可选邻居,直接残差
|
||||||
Q = self.q_proj(target_z) # Query 来自目标通道 i
|
if idx.size(1) == 0:
|
||||||
K = self.k_proj(source_z) # Key 来自源通道 j
|
out_z[:, i, :, :] = self.norm(target_z)
|
||||||
V = self.v_proj(source_z) # Value 来自源通道 j
|
continue
|
||||||
|
|
||||||
attn_scores = torch.bmm(Q, K.transpose(1, 2)) / math.sqrt(D)
|
|
||||||
attn_probs = F.softmax(attn_scores, dim=-1) # [B, N, N]
|
|
||||||
|
|
||||||
context = torch.bmm(attn_probs, V) # [B, N, D], 从 j 聚合到 i 的上下文
|
|
||||||
|
|
||||||
# 加权上下文
|
|
||||||
weighted_context = A_sparse[i, j] * context
|
|
||||||
aggregated_context = aggregated_context + weighted_context
|
|
||||||
|
|
||||||
# 将聚合后的上下文通过输出层,并与原始目标表示相加(残差连接)
|
sel_idx = idx[i] # [k]
|
||||||
|
sel_w = w_st[i] # [k]
|
||||||
|
k_i = sel_idx.numel()
|
||||||
|
|
||||||
|
# 源通道块: [B, k, N, D]
|
||||||
|
source_z = z[:, sel_idx, :, :]
|
||||||
|
|
||||||
|
# 线性投影
|
||||||
|
Q = self.q_proj(target_z) # [B, N, D]
|
||||||
|
K = self.k_proj(source_z.reshape(B * k_i, N, D)).reshape(B, k_i, N, D)
|
||||||
|
V = self.v_proj(source_z.reshape(B * k_i, N, D)).reshape(B, k_i, N, D)
|
||||||
|
|
||||||
|
# 跨注意力(一次性对 k 个源通道)
|
||||||
|
# attn_scores: [B, k, N, N]
|
||||||
|
attn_scores = torch.einsum('bnd,bkmd->bknm', Q, K) / math.sqrt(D)
|
||||||
|
attn_probs = F.softmax(attn_scores, dim=-1) # [B, k, N, N]
|
||||||
|
context = torch.einsum('bknm,bkmd->bknd', attn_probs, V) # [B, k, N, D]
|
||||||
|
|
||||||
|
# 用 ST 的通道权重聚合(前向=小温度的权重,反向梯度=大温度)
|
||||||
|
w = sel_w.view(1, k_i, 1, 1) # [1, k, 1, 1]
|
||||||
|
aggregated_context = (context * w).sum(dim=1) # [B, N, D]
|
||||||
|
|
||||||
|
# 输出与残差
|
||||||
out_z[:, i, :, :] = self.norm(target_z + self.out_proj(aggregated_context))
|
out_z[:, i, :, :] = self.norm(target_z + self.out_proj(aggregated_context))
|
||||||
|
|
||||||
return out_z
|
return out_z
|
||||||
|
|
||||||
|
@ -41,7 +41,11 @@ class SeasonPatch(nn.Module):
|
|||||||
self.mixer = HierarchicalGraphMixer(c_in, dim=d_model, k=k_graph)
|
self.mixer = HierarchicalGraphMixer(c_in, dim=d_model, k=k_graph)
|
||||||
|
|
||||||
# Prediction head
|
# Prediction head
|
||||||
self.head = nn.Linear(patch_num * d_model, pred_len)
|
self.head = nn.Sequential(
|
||||||
|
nn.Linear(patch_num * d_model, patch_num * d_model),
|
||||||
|
nn.SiLU(), # 非线性激活(SiLU/Swish)
|
||||||
|
nn.Linear(patch_num * d_model, pred_len)
|
||||||
|
)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
# x: [B, L, C]
|
# x: [B, L, C]
|
||||||
@ -64,4 +68,4 @@ class SeasonPatch(nn.Module):
|
|||||||
z_mix = z_mix.view(B, C, N * D) # [B, C, patch_num * d_model]
|
z_mix = z_mix.view(B, C, N * D) # [B, C, patch_num * d_model]
|
||||||
y_pred = self.head(z_mix) # [B, C, pred_len]
|
y_pred = self.head(z_mix) # [B, C, pred_len]
|
||||||
|
|
||||||
return y_pred
|
return y_pred
|
||||||
|
@ -2,6 +2,71 @@
|
|||||||
|
|
||||||
model_name=xPatch_SparseChannel
|
model_name=xPatch_SparseChannel
|
||||||
|
|
||||||
|
# Weather dataset
|
||||||
|
for pred_len in 96 192 336 720
|
||||||
|
do
|
||||||
|
python -u run.py \
|
||||||
|
--task_name long_term_forecast \
|
||||||
|
--is_training 1 \
|
||||||
|
--root_path ./dataset/weather/ \
|
||||||
|
--data_path weather.csv \
|
||||||
|
--model_id weather_$pred_len'_'$pred_len \
|
||||||
|
--model $model_name \
|
||||||
|
--data custom \
|
||||||
|
--features M \
|
||||||
|
--seq_len 96 \
|
||||||
|
--label_len 48 \
|
||||||
|
--pred_len $pred_len \
|
||||||
|
--e_layers 2 \
|
||||||
|
--d_layers 1 \
|
||||||
|
--enc_in 21 \
|
||||||
|
--c_out 21 \
|
||||||
|
--d_model 128 \
|
||||||
|
--lradj 'sigmoid' \
|
||||||
|
--train_epochs 20 \
|
||||||
|
--d_ff 256 \
|
||||||
|
--n_heads 16 \
|
||||||
|
--patch_len 16 \
|
||||||
|
--stride 8 \
|
||||||
|
--k_graph 8 \
|
||||||
|
--dropout 0.1 \
|
||||||
|
--revin 1 \
|
||||||
|
--des 'Exp' \
|
||||||
|
--itr 1
|
||||||
|
done
|
||||||
|
|
||||||
|
# Exchange dataset
|
||||||
|
for pred_len in 96 192 336 720
|
||||||
|
do
|
||||||
|
python -u run.py \
|
||||||
|
--task_name long_term_forecast \
|
||||||
|
--is_training 1 \
|
||||||
|
--root_path ./dataset/exchange_rate/ \
|
||||||
|
--data_path exchange_rate.csv \
|
||||||
|
--model_id Exchange_$pred_len'_'$pred_len \
|
||||||
|
--model $model_name \
|
||||||
|
--data custom \
|
||||||
|
--features M \
|
||||||
|
--seq_len 96 \
|
||||||
|
--label_len 48 \
|
||||||
|
--pred_len $pred_len \
|
||||||
|
--e_layers 2 \
|
||||||
|
--d_layers 1 \
|
||||||
|
--enc_in 8 \
|
||||||
|
--c_out 8 \
|
||||||
|
--d_model 128 \
|
||||||
|
--lradj 'sigmoid' \
|
||||||
|
--d_ff 256 \
|
||||||
|
--n_heads 16 \
|
||||||
|
--patch_len 16 \
|
||||||
|
--stride 8 \
|
||||||
|
--k_graph 8 \
|
||||||
|
--dropout 0.1 \
|
||||||
|
--revin 1 \
|
||||||
|
--des 'Exp' \
|
||||||
|
--itr 1
|
||||||
|
done
|
||||||
|
|
||||||
# ETTm1 dataset
|
# ETTm1 dataset
|
||||||
for pred_len in 96 192 336 720
|
for pred_len in 96 192 336 720
|
||||||
do
|
do
|
||||||
@ -22,11 +87,12 @@ python -u run.py \
|
|||||||
--enc_in 7 \
|
--enc_in 7 \
|
||||||
--c_out 7 \
|
--c_out 7 \
|
||||||
--d_model 128 \
|
--d_model 128 \
|
||||||
|
--lradj 'sigmoid' \
|
||||||
--d_ff 256 \
|
--d_ff 256 \
|
||||||
--n_heads 16 \
|
--n_heads 16 \
|
||||||
--patch_len 16 \
|
--patch_len 16 \
|
||||||
--stride 8 \
|
--stride 8 \
|
||||||
--k_graph 7 \
|
--k_graph 5 \
|
||||||
--dropout 0.1 \
|
--dropout 0.1 \
|
||||||
--revin 1 \
|
--revin 1 \
|
||||||
--des 'Exp' \
|
--des 'Exp' \
|
||||||
@ -53,6 +119,7 @@ python -u run.py \
|
|||||||
--enc_in 7 \
|
--enc_in 7 \
|
||||||
--c_out 7 \
|
--c_out 7 \
|
||||||
--d_model 128 \
|
--d_model 128 \
|
||||||
|
--lradj 'sigmoid' \
|
||||||
--d_ff 256 \
|
--d_ff 256 \
|
||||||
--n_heads 16 \
|
--n_heads 16 \
|
||||||
--patch_len 16 \
|
--patch_len 16 \
|
||||||
@ -84,6 +151,7 @@ python -u run.py \
|
|||||||
--enc_in 7 \
|
--enc_in 7 \
|
||||||
--c_out 7 \
|
--c_out 7 \
|
||||||
--d_model 128 \
|
--d_model 128 \
|
||||||
|
--lradj 'sigmoid' \
|
||||||
--d_ff 256 \
|
--d_ff 256 \
|
||||||
--n_heads 16 \
|
--n_heads 16 \
|
||||||
--patch_len 16 \
|
--patch_len 16 \
|
||||||
@ -115,6 +183,7 @@ python -u run.py \
|
|||||||
--enc_in 7 \
|
--enc_in 7 \
|
||||||
--c_out 7 \
|
--c_out 7 \
|
||||||
--d_model 128 \
|
--d_model 128 \
|
||||||
|
--lradj 'sigmoid' \
|
||||||
--d_ff 256 \
|
--d_ff 256 \
|
||||||
--n_heads 16 \
|
--n_heads 16 \
|
||||||
--patch_len 16 \
|
--patch_len 16 \
|
||||||
@ -126,36 +195,6 @@ python -u run.py \
|
|||||||
--itr 1
|
--itr 1
|
||||||
done
|
done
|
||||||
|
|
||||||
# Weather dataset
|
|
||||||
for pred_len in 96 192 336 720
|
|
||||||
do
|
|
||||||
python -u run.py \
|
|
||||||
--task_name long_term_forecast \
|
|
||||||
--is_training 1 \
|
|
||||||
--root_path ./dataset/weather/ \
|
|
||||||
--data_path weather.csv \
|
|
||||||
--model_id weather_$pred_len'_'$pred_len \
|
|
||||||
--model $model_name \
|
|
||||||
--data custom \
|
|
||||||
--features M \
|
|
||||||
--seq_len 96 \
|
|
||||||
--label_len 48 \
|
|
||||||
--pred_len $pred_len \
|
|
||||||
--e_layers 2 \
|
|
||||||
--d_layers 1 \
|
|
||||||
--enc_in 21 \
|
|
||||||
--c_out 21 \
|
|
||||||
--d_model 128 \
|
|
||||||
--d_ff 256 \
|
|
||||||
--n_heads 16 \
|
|
||||||
--patch_len 16 \
|
|
||||||
--stride 8 \
|
|
||||||
--k_graph 8 \
|
|
||||||
--dropout 0.1 \
|
|
||||||
--revin 1 \
|
|
||||||
--des 'Exp' \
|
|
||||||
--itr 1
|
|
||||||
done
|
|
||||||
|
|
||||||
# ECL dataset
|
# ECL dataset
|
||||||
for pred_len in 96 192 336 720
|
for pred_len in 96 192 336 720
|
||||||
@ -219,33 +258,3 @@ python -u run.py \
|
|||||||
--itr 1
|
--itr 1
|
||||||
done
|
done
|
||||||
|
|
||||||
# Exchange dataset
|
|
||||||
for pred_len in 96 192 336 720
|
|
||||||
do
|
|
||||||
python -u run.py \
|
|
||||||
--task_name long_term_forecast \
|
|
||||||
--is_training 1 \
|
|
||||||
--root_path ./dataset/exchange_rate/ \
|
|
||||||
--data_path exchange_rate.csv \
|
|
||||||
--model_id Exchange_$pred_len'_'$pred_len \
|
|
||||||
--model $model_name \
|
|
||||||
--data custom \
|
|
||||||
--features M \
|
|
||||||
--seq_len 96 \
|
|
||||||
--label_len 48 \
|
|
||||||
--pred_len $pred_len \
|
|
||||||
--e_layers 2 \
|
|
||||||
--d_layers 1 \
|
|
||||||
--enc_in 8 \
|
|
||||||
--c_out 8 \
|
|
||||||
--d_model 128 \
|
|
||||||
--d_ff 256 \
|
|
||||||
--n_heads 16 \
|
|
||||||
--patch_len 16 \
|
|
||||||
--stride 8 \
|
|
||||||
--k_graph 8 \
|
|
||||||
--dropout 0.1 \
|
|
||||||
--revin 1 \
|
|
||||||
--des 'Exp' \
|
|
||||||
--itr 1
|
|
||||||
done
|
|
Reference in New Issue
Block a user