import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class HierarchicalGraphMixer(nn.Module):
    """
    Hierarchical graph mixer that combines a macro-level channel graph with
    micro-level patch cross-attention across channels.

    Input  z     : tensor of shape [B, C, N, D]
    Output z_out : tensor of the same shape as the input
    """

    def __init__(self, n_channel: int, dim: int, k: int = 5, tau: float = 0.2):
        super().__init__()
        self.k = k
        self.tau = tau

        # Level 1: Channel Graph (learnable edge logits between channels)
        self.A = nn.Parameter(torch.zeros(n_channel, n_channel))
        self.se = nn.Sequential(  # channel-gating head (not used in forward below)
            nn.Linear(dim, dim // 4, bias=False), nn.ReLU(),
            nn.Linear(dim // 4, 1, bias=False), nn.Sigmoid()
        )

        # Level 2: Patch Cross-Attention
        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)
        self.out_proj = nn.Linear(dim, dim)
        self.norm = nn.LayerNorm(dim)

    def _row_sparse(self, logits: torch.Tensor) -> torch.Tensor:
        """Gumbel-Softmax based sparse row weights (top-k entries per row)."""
        # Sample Gumbel(0, 1) noise: if E ~ Exp(1), then -log(E) is Gumbel-distributed.
        g = -torch.empty_like(logits).exponential_().log()
        y = (logits + g) / self.tau
        probs = F.softmax(y, dim=-1)

        # Ensure k doesn't exceed the dimension size
        k_actual = min(self.k, probs.size(-1))
        if k_actual <= 0:
            return torch.zeros_like(probs)

        # Keep only the top-k entries per row; zero out the rest.
        topk_val, _ = torch.topk(probs, k_actual, dim=-1)
        thr = topk_val[..., -1].unsqueeze(-1)
        sparse = torch.where(probs >= thr, probs, torch.zeros_like(probs))
        # Straight-through estimator: the forward value is the sparse matrix,
        # while gradients flow through the dense softmax probabilities.
        return sparse.detach() + probs - probs.detach()

    def forward(self, z):
        # z has shape [B, C, N, D]
        B, C, N, D = z.shape

        # --- Level 1: compute macro-level channel weights ---
        A_sparse = self._row_sparse(self.A)  # sparse channel-connection graph, shape [C, C]

        # --- Level 2: cross-channel patch interaction ---
        out_z = torch.zeros_like(z)
        for i in range(C):  # iterate over target channels i
            target_z = z[:, i, :, :]  # [B, N, D]

            # Accumulate patch-level context aggregated from the other channels
            aggregated_context = torch.zeros_like(target_z)

            for j in range(C):  # iterate over source channels j
                if A_sparse[i, j] != 0:
                    source_z = z[:, j, :, :]  # [B, N, D]

                    # --- cross-attention ---
                    Q = self.q_proj(target_z)  # queries from target channel i
                    K = self.k_proj(source_z)  # keys from source channel j
                    V = self.v_proj(source_z)  # values from source channel j

                    attn_scores = torch.bmm(Q, K.transpose(1, 2)) / math.sqrt(D)
                    attn_probs = F.softmax(attn_scores, dim=-1)  # [B, N, N]

                    context = torch.bmm(attn_probs, V)  # [B, N, D], context aggregated from j into i

                    # Weight the context by the channel-graph edge weight
                    weighted_context = A_sparse[i, j] * context
                    aggregated_context = aggregated_context + weighted_context

            # Project the aggregated context and add it to the original target
            # representation (residual connection), followed by LayerNorm.
            out_z[:, i, :, :] = self.norm(target_z + self.out_proj(aggregated_context))

        return out_z
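

if __name__ == "__main__":
    # Minimal usage sketch with assumed toy sizes (B=2, C=4, N=16, D=32) and k=2;
    # the real hyperparameters depend on the surrounding model. It illustrates the
    # [B, C, N, D] -> [B, C, N, D] contract and that gradients reach the
    # channel-graph logits A through the straight-through estimator in _row_sparse.
    B, C, N, D = 2, 4, 16, 32
    mixer = HierarchicalGraphMixer(n_channel=C, dim=D, k=2, tau=0.2)
    z = torch.randn(B, C, N, D)

    z_out = mixer(z)
    assert z_out.shape == (B, C, N, D)

    z_out.sum().backward()
    assert mixer.A.grad is not None  # channel-graph logits receive gradients
    print("output shape:", tuple(z_out.shape))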