import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class HierarchicalGraphMixer(nn.Module):
    """
    Hierarchical graph mixer that jointly models macro-level channel
    relationships and micro-level patch attention.

    Input  z     : tensor of shape [B, C, N, D]
    Output z_out : same shape as the input
    """

    def __init__(self, n_channel: int, dim: int, k: int = 5, tau: float = 0.2):
        super().__init__()
        self.k = k
        self.tau = tau

        # Level 1: Channel Graph
        self.A = nn.Parameter(torch.zeros(n_channel, n_channel))
        # Squeeze-and-excitation style gate (defined here but not used in forward below)
        self.se = nn.Sequential(
            nn.Linear(dim, dim // 4, bias=False),
            nn.ReLU(),
            nn.Linear(dim // 4, 1, bias=False),
            nn.Sigmoid()
        )

        # Level 2: Patch Cross-Attention
        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)
        self.out_proj = nn.Linear(dim, dim)
        self.norm = nn.LayerNorm(dim)

    def _row_sparse(self, logits: torch.Tensor) -> torch.Tensor:
        """Gumbel-Softmax based sparse attention."""
        # Sample Gumbel(0, 1) noise: -log(Exponential(1)) is Gumbel-distributed
        g = -torch.empty_like(logits).exponential_().log()
        y = (logits + g) / self.tau
        probs = F.softmax(y, dim=-1)

        # Ensure k doesn't exceed the dimension size
        k_actual = min(self.k, probs.size(-1))
        if k_actual <= 0:
            return torch.zeros_like(probs)

        # Keep only the top-k entries per row; zero out the rest
        topk_val, _ = torch.topk(probs, k_actual, dim=-1)
        thr = topk_val[..., -1].unsqueeze(-1)
        sparse = torch.where(probs >= thr, probs, torch.zeros_like(probs))

        # Straight-through estimator: the forward pass uses the sparse values,
        # while gradients flow through the dense softmax probabilities
        return sparse.detach() + probs - probs.detach()

    def forward(self, z):
        # z: [B, C, N, D]
        B, C, N, D = z.shape

        # --- Level 1: compute macro channel weights ---
        A_sparse = self._row_sparse(self.A)  # sparse channel graph, [C, C]

        # --- Level 2: cross-channel patch interaction ---
        out_z = torch.zeros_like(z)
        for i in range(C):  # iterate over target channels
            target_z = z[:, i, :, :]  # [B, N, D]
            Q = self.q_proj(target_z)  # queries come from target channel i

            # Aggregate patch-level context from the connected source channels
            aggregated_context = torch.zeros_like(target_z)
            for j in range(C):  # iterate over source channels
                if A_sparse[i, j] != 0:
                    source_z = z[:, j, :, :]  # [B, N, D]

                    # --- Cross-attention from channel j to channel i ---
                    K = self.k_proj(source_z)  # keys from source channel j
                    V = self.v_proj(source_z)  # values from source channel j

                    attn_scores = torch.bmm(Q, K.transpose(1, 2)) / math.sqrt(D)
                    attn_probs = F.softmax(attn_scores, dim=-1)  # [B, N, N]
                    context = torch.bmm(attn_probs, V)  # [B, N, D], context aggregated from j to i

                    # Weight the context by the channel-graph edge weight
                    weighted_context = A_sparse[i, j] * context
                    aggregated_context = aggregated_context + weighted_context

            # Project the aggregated context and add it back to the target
            # representation (residual connection), then normalize
            out_z[:, i, :, :] = self.norm(target_z + self.out_proj(aggregated_context))

        return out_z
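

# Minimal smoke test (a sketch only; the batch size, channel count, patch
# count, embedding dim, and hyperparameters below are arbitrary assumptions,
# not values taken from the module above).
if __name__ == "__main__":
    B, C, N, D = 2, 8, 16, 64  # batch, channels, patches, embedding dim
    mixer = HierarchicalGraphMixer(n_channel=C, dim=D, k=3, tau=0.2)
    z = torch.randn(B, C, N, D)
    z_out = mixer(z)
    # The mixer preserves the input shape: [B, C, N, D]
    assert z_out.shape == (B, C, N, D)
    print(z_out.shape)  # torch.Size([2, 8, 16, 64])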