import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class HierarchicalGraphMixer(nn.Module):
    """
    Hierarchical graph mixer that models macro-level channel relations and
    micro-level patch attention at the same time.

    Input  z     : [B, C, N, D]
    Output z_out : same shape
    """
    def __init__(self, n_channel: int, dim: int, k: int = 5,
                 tau_fw: float = 0.3, tau_bw: float = 3.0):
        super().__init__()
        self.k = k
        self.tau_fw = tau_fw  # forward temperature (small)
        self.tau_bw = tau_bw  # backward temperature (large)

        # Level 1: channel graph (logits)
        self.A = nn.Parameter(torch.zeros(n_channel, n_channel))
        # Squeeze-excitation-style channel gate; defined here but not
        # referenced in forward() below.
        self.se = nn.Sequential(
            nn.Linear(dim, dim // 4, bias=False),
            nn.SiLU(),
            nn.Linear(dim // 4, 1, bias=False),
            nn.Sigmoid()
        )

        # Level 2: patch cross-attention
        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)
        self.out_proj = nn.Linear(dim, dim)
        self.norm = nn.LayerNorm(dim)

    @torch.no_grad()
    def _mask_self_logits_(self, logits: torch.Tensor):
        """Set the diagonal to -inf so a channel can never select itself."""
        C = logits.size(0)
        eye = torch.eye(C, device=logits.device, dtype=torch.bool)
        logits.masked_fill_(eye, float("-inf"))

    def _gumbel_topk_select(self, logits: torch.Tensor):
        """
        Returns:
        - idx : [C, k_actual] top-k channel indices per row (self excluded)
        - w_st: [C, k_actual] weights of the selected edges
                (forward pass = probabilities at tau_fw;
                 backward gradients = probabilities at tau_bw)
        """
        C = logits.size(0)
        k_actual = min(self.k, C - 1)
        if k_actual <= 0:
            idx = torch.empty((C, 0), dtype=torch.long, device=logits.device)
            w_st = torch.empty((C, 0), dtype=logits.dtype, device=logits.device)
            return idx, w_st

        # Share one draw of Gumbel noise; build the forward/backward
        # distributions with their respective temperatures.
        g = -torch.empty_like(logits).exponential_().log()
        y_fw = (logits + g) / self.tau_fw
        y_bw = (logits + g) / self.tau_bw

        # Exclude self-edges (clone first so the in-place mask does not
        # disturb autograd's version tracking).
        y_fw = y_fw.clone()
        y_bw = y_bw.clone()
        self._mask_self_logits_(y_fw)
        self._mask_self_logits_(y_bw)

        # Hard top-k selection in the forward pass
        topk_val, idx = torch.topk(y_fw, k_actual, dim=-1)  # [C, k]

        # Soft probabilities at both temperatures, gathered at the selected k
        p_fw = F.softmax(y_fw, dim=-1)  # [C, C]
        p_bw = F.softmax(y_bw, dim=-1)  # [C, C]
        w_fw = torch.gather(p_fw, -1, idx)  # [C, k]
        w_bw = torch.gather(p_bw, -1, idx)  # [C, k]

        # Renormalize within the selected set to stabilize training
        eps = 1e-9
        w_fw = w_fw / (w_fw.sum(-1, keepdim=True) + eps)
        w_bw = w_bw / (w_bw.sum(-1, keepdim=True) + eps)

        # Straight-through: forward uses w_fw, gradients flow through w_bw
        w_st = w_fw.detach() + w_bw - w_bw.detach()  # [C, k]
        return idx, w_st

    def forward(self, z):
        # z: [B, C, N, D]
        B, C, N, D = z.shape

        # --- Level 1: pick each channel's top-k related channels (self
        # excluded) and obtain their straight-through weights ---
        idx, w_st = self._gumbel_topk_select(self.A)  # idx: [C, k], w_st: [C, k]

        # --- Level 2: cross-channel patch interaction over the selected channels only ---
        out_z = torch.zeros_like(z)
        for i in range(C):
            target_z = z[:, i, :, :]  # [B, N, D]

            # No selectable neighbors: pass through with normalization only
            if idx.size(1) == 0:
                out_z[:, i, :, :] = self.norm(target_z)
                continue

            sel_idx = idx[i]  # [k]
            sel_w = w_st[i]   # [k]
            k_i = sel_idx.numel()

            # Source channel block: [B, k, N, D]
            source_z = z[:, sel_idx, :, :]

            # Linear projections
            Q = self.q_proj(target_z)  # [B, N, D]
            K = self.k_proj(source_z.reshape(B * k_i, N, D)).reshape(B, k_i, N, D)
            V = self.v_proj(source_z.reshape(B * k_i, N, D)).reshape(B, k_i, N, D)

            # Cross-attention over all k source channels at once
            # attn_scores: [B, k, N, N]
            attn_scores = torch.einsum('bnd,bkmd->bknm', Q, K) / math.sqrt(D)
            attn_probs = F.softmax(attn_scores, dim=-1)               # [B, k, N, N]
            context = torch.einsum('bknm,bkmd->bknd', attn_probs, V)  # [B, k, N, D]

            # Aggregate with the straight-through channel weights
            # (forward = low-temperature weights, gradients = high-temperature)
            w = sel_w.view(1, k_i, 1, 1)                   # [1, k, 1, 1]
            aggregated_context = (context * w).sum(dim=1)  # [B, N, D]

            # Output with residual connection
            out_z[:, i, :, :] = self.norm(target_z + self.out_proj(aggregated_context))

        return out_z
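

# A minimal smoke-test sketch (not part of the original module): the shapes
# and hyperparameters below are illustrative assumptions, chosen only to show
# how the mixer is wired up, that shapes round-trip, and that gradients reach
# the channel-graph logits through the straight-through weights.
if __name__ == "__main__":
    B, C, N, D = 2, 7, 16, 32  # batch, channels, patches, embed dim (assumed)
    mixer = HierarchicalGraphMixer(n_channel=C, dim=D, k=3)

    z = torch.randn(B, C, N, D)
    z_out = mixer(z)
    assert z_out.shape == (B, C, N, D)

    # Gradients to self.A flow only via the high-temperature (tau_bw) path
    # of the straight-through estimator.
    z_out.sum().backward()
    assert mixer.A.grad is not None
    print("output:", tuple(z_out.shape), "| A.grad norm:", mixer.A.grad.norm().item())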