TSlib/layers/GraphMixer.py

import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class HierarchicalGraphMixer(nn.Module):
"""
分层图混合器,同时考虑宏观通道关系和微观 Patch 级别注意力。
输入 z : [B, C, N, D]
输出 z_out : 同形状
"""
    def __init__(self, n_channel: int, dim: int, k: int = 5, tau_fw: float = 0.3, tau_bw: float = 3.0):
        super().__init__()
        self.k = k
        self.tau_fw = tau_fw  # forward temperature (small, sharp)
        self.tau_bw = tau_bw  # backward temperature (large, smooth)
        # Level 1: channel graph (learnable adjacency logits)
        self.A = nn.Parameter(torch.zeros(n_channel, n_channel))
        # Squeeze-and-excitation gate (defined but not used in this file's forward pass)
        self.se = nn.Sequential(
            nn.Linear(dim, dim // 4, bias=False), nn.SiLU(),
            nn.Linear(dim // 4, 1, bias=False), nn.Sigmoid()
        )
        # Level 2: patch cross-attention
        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)
        self.out_proj = nn.Linear(dim, dim)
        self.norm = nn.LayerNorm(dim)
    @torch.no_grad()
    def _mask_self_logits_(self, logits: torch.Tensor):
        """Set the diagonal to -inf so a channel never selects itself."""
        C = logits.size(0)
        eye = torch.eye(C, device=logits.device, dtype=torch.bool)
        logits.masked_fill_(eye, float("-inf"))
    def _gumbel_topk_select(self, logits: torch.Tensor):
        """
        Returns:
        - idx : [C, k_actual] top-k channel indices per row (excluding self)
        - w_st: [C, k_actual] weights of the selected edges
          (forward pass uses the tau_fw probabilities; backward gradients
          come from the tau_bw probabilities)
        """
        C = logits.size(0)
        k_actual = min(self.k, C - 1)
        if k_actual <= 0:
            idx = torch.empty((C, 0), dtype=torch.long, device=logits.device)
            w_st = torch.empty((C, 0), dtype=logits.dtype, device=logits.device)
            return idx, w_st
        # Share one draw of Gumbel(0, 1) noise (if E ~ Exp(1), then -log E is
        # Gumbel-distributed) and build the forward / backward distributions
        # at their respective temperatures.
        g = -torch.empty_like(logits).exponential_().log()
        y_fw = (logits + g) / self.tau_fw
        y_bw = (logits + g) / self.tau_bw
        # Exclude self-edges
        y_fw = y_fw.clone()
        y_bw = y_bw.clone()
        self._mask_self_logits_(y_fw)
        self._mask_self_logits_(y_bw)
        # Forward pass: strict (hard) top-k selection
        _, idx = torch.topk(y_fw, k_actual, dim=-1)  # [C, k]
        # Soft probabilities at both temperatures; gather only the selected k
        p_fw = F.softmax(y_fw, dim=-1)  # [C, C]
        p_bw = F.softmax(y_bw, dim=-1)  # [C, C]
        w_fw = torch.gather(p_fw, -1, idx)  # [C, k]
        w_bw = torch.gather(p_bw, -1, idx)  # [C, k]
        # Renormalize within the selected set to stabilize training
        eps = 1e-9
        w_fw = w_fw / (w_fw.sum(-1, keepdim=True) + eps)
        w_bw = w_bw / (w_bw.sum(-1, keepdim=True) + eps)
        # Straight-through: forward uses w_fw, backward gradients come from w_bw
        w_st = w_fw.detach() + w_bw - w_bw.detach()  # [C, k]
        return idx, w_st
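
    # Note on _gumbel_topk_select: with
    # w_st = w_fw.detach() + w_bw - w_bw.detach(), the forward value equals
    # w_fw (sharp, tau_fw) while d(w_st)/d(logits) equals d(w_bw)/d(logits)
    # (smooth, tau_bw), because .detach() blocks gradient flow through the
    # other two terms. The masked -inf entries get zero probability, so no
    # gradient is needed through the no_grad diagonal mask.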

    def forward(self, z):
        # z: [B, C, N, D]
        B, C, N, D = z.shape
        # --- Level 1: pick each channel's top-k related channels (excluding
        # itself) and obtain the straight-through weights ---
        idx, w_st = self._gumbel_topk_select(self.A)  # idx: [C, k], w_st: [C, k]
        # --- Level 2: cross-channel patch interaction, only over the selected channels ---
        out_z = torch.zeros_like(z)
        for i in range(C):
            target_z = z[:, i, :, :]  # [B, N, D]
            # No selectable neighbors: pass the channel through (norm only)
            if idx.size(1) == 0:
                out_z[:, i, :, :] = self.norm(target_z)
                continue
            sel_idx = idx[i]  # [k]
            sel_w = w_st[i]   # [k]
            k_i = sel_idx.numel()
            # Source channel block: [B, k, N, D]
            source_z = z[:, sel_idx, :, :]
            # Linear projections
            Q = self.q_proj(target_z)  # [B, N, D]
            K = self.k_proj(source_z.reshape(B * k_i, N, D)).reshape(B, k_i, N, D)
            V = self.v_proj(source_z.reshape(B * k_i, N, D)).reshape(B, k_i, N, D)
            # Cross-attention over all k source channels at once
            # attn_scores: [B, k, N, N]
            attn_scores = torch.einsum('bnd,bkmd->bknm', Q, K) / math.sqrt(D)
            attn_probs = F.softmax(attn_scores, dim=-1)  # [B, k, N, N]
            context = torch.einsum('bknm,bkmd->bknd', attn_probs, V)  # [B, k, N, D]
            # Aggregate with the straight-through channel weights
            # (forward = low-temperature weights, backward = high-temperature gradients)
            w = sel_w.view(1, k_i, 1, 1)  # [1, k, 1, 1]
            aggregated_context = (context * w).sum(dim=1)  # [B, N, D]
            # Residual connection, output projection, and layer norm
            out_z[:, i, :, :] = self.norm(target_z + self.out_proj(aggregated_context))
        return out_z
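

# ----------------------------------------------------------------------
# Minimal usage sketch. The shapes below are illustrative assumptions,
# not values taken from the original repository.
if __name__ == "__main__":
    B, C, N, D = 2, 7, 16, 32  # batch, channels, patches, embedding dim (hypothetical)
    mixer = HierarchicalGraphMixer(n_channel=C, dim=D, k=3)
    z = torch.randn(B, C, N, D)
    z_out = mixer(z)
    print(z_out.shape)  # expected: torch.Size([2, 7, 16, 32])
    # The straight-through weights keep the channel logits trainable even
    # though the forward pass makes a hard top-k selection.
    z_out.sum().backward()
    print(mixer.A.grad is not None)  # expected: True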