feat(graph-mixer): implement L0 sparsity with Hard-Concrete gate for channel selection

gameloader
2025-09-11 16:50:58 +08:00
parent 5fc0da4239
commit 204d17086a
4 changed files with 268 additions and 124 deletions


@@ -1,27 +1,86 @@
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

class HardConcreteGate(nn.Module):
"""
Hard-Concrete gate for L0-style sparsity (Louizos et al., 2017).
Produces z in [0,1] without row-wise normalization.
"""
def __init__(self, shape, temperature=2./3., gamma=-0.1, zeta=1.1, init_log_alpha=-2.0):
super().__init__()
self.log_alpha = nn.Parameter(torch.full(shape, init_log_alpha))
self.temperature = temperature
self.gamma = gamma
self.zeta = zeta
def sample(self, training=True):
if training:
            # clamp u away from {0, 1} so log(u) and log(1 - u) stay finite
            u = torch.rand_like(self.log_alpha).clamp_(1e-6, 1.0 - 1e-6)
s = torch.sigmoid((self.log_alpha + torch.log(u) - torch.log(1 - u)) / self.temperature)
else:
# deterministic mean gate at eval
s = torch.sigmoid(self.log_alpha)
s_bar = s * (self.zeta - self.gamma) + self.gamma
z = torch.clamp(s_bar, 0., 1.)
return z
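    # Illustrative note on the stretch-and-clamp: with the defaults gamma=-0.1,
    # zeta=1.1, the stretch maps s in [0, 1] to s_bar in [-0.1, 1.1]; clamping
    # then sends s <= 1/12 to exactly 0 and s >= 11/12 to exactly 1, so a gate
    # can be exactly closed or open while staying differentiable in between.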
def expected_l0(self):
"""
E[1_{z>0}] closed-form for hard-concrete.
        Useful for the L0 penalty: lambda * expected_l0().sum()
"""
# s > t0 => z > 0, where t0 = -gamma / (zeta - gamma)
t0 = -self.gamma / (self.zeta - self.gamma)
# logit(t0)
logit_t0 = math.log(t0) - math.log(1 - t0)
        # s > t0  <=>  log_alpha + logit(u) > temperature * logit(t0),
        # with logit(u) ~ Logistic(0, 1), so the closed form is:
        p_open = torch.sigmoid(self.log_alpha - self.temperature * logit_t0)
return p_open
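
# A minimal sanity check (illustrative; `_check_expected_l0` is not part of the
# original file): the closed-form expected_l0() should agree with a Monte-Carlo
# estimate of P(z > 0) obtained by sampling the gate.
def _check_expected_l0(n_samples: int = 10_000) -> None:
    gate = HardConcreteGate(shape=(4, 4))
    with torch.no_grad():
        empirical = torch.stack(
            [(gate.sample(training=True) > 0).float() for _ in range(n_samples)]
        ).mean(dim=0)
    # agreement up to Monte-Carlo noise (~1 / sqrt(n_samples))
    assert torch.allclose(empirical, gate.expected_l0(), atol=0.02)
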
class HierarchicalGraphMixer(nn.Module):
    """
    Hierarchical graph mixer with Hard-Concrete edge gating:
      - Level 1: an unnormalized, thresholdable channel graph that may be empty
      - Level 2: patch-level cross-attention only along the selected edges
    Input:  z     [B, C, N, D]
    Output: z_out same shape
    """
    def __init__(
        self,
        n_channel: int,
        dim: int,
        max_degree: int | None = None,  # optional cap on the number of edges per row
        thr: float = 0.5,               # keep-edge threshold, e.g. 0.5 or 0.7
        temperature: float = 2./3.,
        tau_attn: float = 1.0,          # optional temperature for the patch attention
        symmetric: bool = True,         # whether to symmetrize the channel graph
        degree_rescale: str = "none",   # "none" | "count" | "count-sqrt" | "sum"
        init_log_alpha: float = -2.0
    ):
super().__init__()
self.C = n_channel
self.dim = dim
self.max_degree = max_degree
self.thr = thr
self.tau_attn = tau_attn
self.symmetric = symmetric
self.degree_rescale = degree_rescale
# Level 1: 非归一化门控
self.gate = HardConcreteGate(
shape=(n_channel, n_channel),
temperature=temperature,
init_log_alpha=init_log_alpha
)
        # Optional SE: the original se head could produce sample-dependent
        # channel priorities; the interface is kept here for now.
self.se = nn.Sequential(
nn.Linear(dim, dim // 4, bias=False), nn.SiLU(),
nn.Linear(dim // 4, 1, bias=False), nn.Sigmoid()
)
# Level 2: Patch Cross-Attention
self.q_proj = nn.Linear(dim, dim)
self.k_proj = nn.Linear(dim, dim)
@@ -29,96 +88,108 @@ class HierarchicalGraphMixer(nn.Module):
self.out_proj = nn.Linear(dim, dim)
self.norm = nn.LayerNorm(dim)
def _build_sparse_neighbors(self, z_gate):
"""
返回:
- idx: [C, k_actual] 每行 top-k 的通道索引(不含自身)
- w_st: [C, k_actual] 选中边的权重(前向=用 tau_fw 的概率;反向梯度=来自 tau_bw 的概率)
基于 z_gate 构造每行的邻接列表按阈值与可选top-k
返回:
- idx_list: 长度C的list每项是LongTensor[idx_j]
- w_list: 长度C的list每项是FloatTensor[w_j](非归一化)
"""
C = z_gate.size(0)
        # zero the diagonal: no self-edges
z_gate = z_gate.clone()
z_gate.fill_diagonal_(0.0)
if self.symmetric:
z_gate = 0.5 * (z_gate + z_gate.t())
z_gate.fill_diagonal_(0.0)
idx_list, w_list = [], []
for i in range(C):
row = z_gate[i] # [C]
            # keep only edges whose gate exceeds the threshold
mask = row > self.thr
if mask.any():
vals = row[mask]
idxs = torch.nonzero(mask, as_tuple=False).squeeze(-1)
                # optionally cap this row's degree
if (self.max_degree is not None) and (idxs.numel() > self.max_degree):
topk = torch.topk(vals, k=self.max_degree, dim=0)
vals = topk.values
idxs = idxs[topk.indices]
else:
idxs = torch.empty((0,), dtype=torch.long, device=row.device)
vals = torch.empty((0,), dtype=row.dtype, device=row.device)
idx_list.append(idxs)
w_list.append(vals)
return idx_list, w_list
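    # Worked example: with thr=0.5 and a gate row [0.0, 0.7, 0.6, 0.2], the
    # channel keeps neighbors {1, 2} with weights (0.7, 0.6); adding
    # max_degree=1 would keep only index 1. A row entirely below the threshold
    # yields an empty neighborhood, i.e. that channel attends to no one.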
    def _degree_rescale(self, ctx, w_sel):
        """
        Stabilizer for the unnormalized aggregation: optionally rescale the
        aggregate by the node degree to keep its magnitude in check.
        ctx:   [B, k, N, D]
        w_sel: [k]
        """
        agg = (ctx * w_sel.view(1, -1, 1, 1)).sum(dim=1)  # [B, N, D]
        if self.degree_rescale == "none":
            return agg
        elif self.degree_rescale == "count":
            return agg / float(max(1, w_sel.numel()))
        elif self.degree_rescale == "count-sqrt":
            return agg / math.sqrt(max(1, w_sel.numel()))
        elif self.degree_rescale == "sum":
            # keep the denominator a tensor so gradients flow through the weights
            return agg / w_sel.sum().clamp(min=1e-6)
        else:
            raise ValueError(f"unknown degree_rescale mode: {self.degree_rescale!r}")
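    # Worked example: for w_sel = [0.9, 0.6], "count" divides the aggregate by 2,
    # "count-sqrt" by sqrt(2) ~ 1.414, and "sum" by 1.5; "none" leaves the
    # unnormalized weighted sum untouched.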
def l0_loss(self, lam: float = 1e-4):
"""
期望L0正则鼓励稀疏邻接可调强度
"""
return lam * self.gate.expected_l0().sum()
def forward(self, z):
# z: [B, C, N, D]
B, C, N, D = z.shape
assert C == self.C and D == self.dim
        # Level 1: sample the unnormalized gates z_gate ∈ [0, 1]
z_gate = self.gate.sample(training=self.training) # [C, C]
        # build sparse neighborhoods (threshold + optional top-k)
idx_list, w_list = self._build_sparse_neighbors(z_gate)
        # Level 2: cross-channel patch interaction only along the retained edges
out_z = torch.zeros_like(z)
for i in range(C):
target_z = z[:, i, :, :] # [B, N, D]
idx = idx_list[i]
if idx.numel() == 0:
                # empty neighborhood: "no relevant channels" is allowed; just residual + norm
out_z[:, i, :, :] = self.norm(target_z)
continue
            w_sel = w_list[i]  # [k], unnormalized weights in [0, 1]
k_i = idx.numel()
            # source channel block: [B, k, N, D]
source_z = z[:, idx, :, :] # [B, k, N, D]
            # linear projections
Q = self.q_proj(target_z) # [B, N, D]
K = self.k_proj(source_z.reshape(B * k_i, N, D)).reshape(B, k_i, N, D)
V = self.v_proj(source_z.reshape(B * k_i, N, D)).reshape(B, k_i, N, D)
            # cross-channel patch attention over the k source channels; attn_scores: [B, k, N, N]
attn_scores = torch.einsum('bnd,bkmd->bknm', Q, K) / math.sqrt(D)
if self.tau_attn != 1.0:
attn_scores = attn_scores / self.tau_attn
attn_probs = F.softmax(attn_scores, dim=-1) # [B, k, N, N]
context = torch.einsum('bknm,bkmd->bknd', attn_probs, V) # [B, k, N, D]
            # aggregate with unnormalized channel weights + optional degree rescale
            # (numerical stability only; the "unnormalized" semantics are preserved)
aggregated_context = self._degree_rescale(context, w_sel) # [B, N, D]
out_z[:, i, :, :] = self.norm(target_z + self.out_proj(aggregated_context))
return out_z
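
# A minimal usage sketch (illustrative; the shapes and lam value here are
# assumptions, not part of the original file): one forward pass plus the
# expected-L0 penalty on the channel graph.
if __name__ == "__main__":
    B, C, N, D = 2, 8, 16, 32
    mixer = HierarchicalGraphMixer(n_channel=C, dim=D, thr=0.5, max_degree=3)
    z = torch.randn(B, C, N, D)
    out = mixer(z)  # [B, C, N, D], same shape as the input
    # dummy task loss + L0 sparsity penalty
    loss = out.pow(2).mean() + mixer.l0_loss(lam=1e-4)
    loss.backward()  # gradients reach gate.log_alpha through the sampled edges
    print(out.shape, float(mixer.gate.expected_l0().sum()))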