import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class HardConcreteGate(nn.Module):
    """
    Hard-Concrete gate for L0-style sparsity (Louizos et al., 2017).
    Produces z in [0, 1] without row-wise normalization.
    """
    def __init__(self, shape, temperature=2./3., gamma=-0.1, zeta=1.1,
                 init_log_alpha=-2.0):
        super().__init__()
        self.log_alpha = nn.Parameter(torch.full(shape, init_log_alpha))
        self.temperature = temperature
        self.gamma = gamma
        self.zeta = zeta

    def sample(self, training=True):
        if training:
            # Clamp u away from {0, 1} so log(u) and log(1 - u) stay finite.
            u = torch.rand_like(self.log_alpha).clamp_(1e-6, 1.0 - 1e-6)
            s = torch.sigmoid(
                (self.log_alpha + torch.log(u) - torch.log(1 - u)) / self.temperature
            )
        else:
            # Deterministic mean gate at eval time.
            s = torch.sigmoid(self.log_alpha)
        # Stretch to [gamma, zeta], then clip to [0, 1] ("hard" concrete).
        s_bar = s * (self.zeta - self.gamma) + self.gamma
        z = torch.clamp(s_bar, 0., 1.)
        return z

    def expected_l0(self):
        """
        Closed-form E[1_{z > 0}] under the hard-concrete distribution.
        Useful as an L0 penalty: lambda * expected_l0().sum()
        """
        # z > 0  <=>  s > t0, where t0 = -gamma / (zeta - gamma),
        # and logit(t0) = log(-gamma / zeta).
        t0 = -self.gamma / (self.zeta - self.gamma)
        logit_t0 = math.log(t0) - math.log(1 - t0)
        # s = sigmoid((log_alpha + logit(u)) / beta) with logit(u) ~ Logistic(0, 1),
        # hence P(z > 0) = sigmoid(log_alpha - beta * logit(t0)) with beta = temperature
        # (the expected-L0 formula of Louizos et al., 2017).
        p_open = torch.sigmoid(self.log_alpha - self.temperature * logit_t0)
        return p_open
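# A minimal usage sketch for the gate on its own (illustrative only; the shape
# and penalty weight below are arbitrary assumptions, not fixed by the model code):
#
#     gate = HardConcreteGate(shape=(8, 8))
#     z = gate.sample(training=True)              # stochastic gates in [0, 1]
#     penalty = 1e-4 * gate.expected_l0().sum()   # differentiable L0 surrogate
#     # Add `penalty` to the task loss; its gradient pushes log_alpha down,
#     # driving gates to exactly zero.
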
class HierarchicalGraphMixer(nn.Module):
    """
    Hierarchical graph mixer with Hard-Concrete edge gating:
      - Level 1: an unnormalized, thresholdable, possibly empty channel graph.
      - Level 2: patch-level cross-attention, computed only along the selected edges.

    Input:  z [B, C, N, D]
    Output: z_out with the same shape.
    """
    def __init__(
        self,
        n_channel: int,
        dim: int,
        max_degree: int = None,        # optional: cap on the number of edges per row
        thr: float = 0.5,              # edge-keeping threshold, e.g. 0.5 or 0.7
        thr_min: float = None,         # dynamic-threshold start; defaults to thr
        thr_max: float = None,         # dynamic-threshold end; defaults to thr
        thr_steps: int = 0,            # steps from thr_min -> thr_max; > 0 enables the schedule
        thr_schedule: str = "linear",  # "linear" | "cosine" | "exp"
        temperature: float = 2./3.,
        tau_attn: float = 1.0,         # patch-attention temperature (optional)
        symmetric: bool = True,        # whether to symmetrize the channel graph
        degree_rescale: str = "none",  # "none" | "count" | "count-sqrt" | "sum"
        init_log_alpha: float = -2.0,
    ):
        super().__init__()
        self.C = n_channel
        self.dim = dim
        self.max_degree = max_degree
        self.thr = thr
        self.tau_attn = tau_attn
        self.symmetric = symmetric
        self.degree_rescale = degree_rescale

        self.thr_min = thr if (thr_min is None) else float(thr_min)
        self.thr_max = thr if (thr_max is None) else float(thr_max)
        self.thr_steps = int(thr_steps) if thr_steps is not None else 0
        self.thr_schedule = thr_schedule
        self._use_dynamic_thr = (self.thr_steps > 0) and (abs(self.thr_max - self.thr_min) > 1e-12)
        # Track the schedule position in a buffer (not saved with the weights).
        self.register_buffer("_thr_step", torch.zeros((), dtype=torch.long), persistent=False)

        # Level 1: unnormalized gating.
        self.gate = HardConcreteGate(
            shape=(n_channel, n_channel),
            temperature=temperature,
            init_log_alpha=init_log_alpha,
        )

        # Optional SE block (the original `se` can produce sample-dependent channel
        # priorities; the interface is kept here for that purpose).
        self.se = nn.Sequential(
            nn.Linear(dim, dim // 4, bias=False),
            nn.SiLU(),
            nn.Linear(dim // 4, 1, bias=False),
            nn.Sigmoid(),
        )

        # Level 2: patch cross-attention.
        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)
        self.out_proj = nn.Linear(dim, dim)
        self.norm = nn.LayerNorm(dim)

    def _compute_thr_by_progress(self, progress: float) -> float:
        # progress in [0, 1]
        progress = max(0.0, min(1.0, float(progress)))
        if self.thr_schedule == "linear":
            g = progress
        elif self.thr_schedule == "cosine":
            # Slow start, accelerating later.
            g = 0.5 - 0.5 * math.cos(math.pi * progress)
        elif self.thr_schedule == "exp":
            # Exponential transition from thr_min to thr_max.
            k = 5.0
            g = (math.exp(k * progress) - 1.0) / (math.exp(k) - 1.0)
        else:
            g = progress
        return self.thr_min + (self.thr_max - self.thr_min) * g

    def _maybe_update_thr(self, training):
        if training and self._use_dynamic_thr:
            step = int(self._thr_step.item())
            progress = step / float(self.thr_steps)
            self.thr = float(self._compute_thr_by_progress(progress))
            self._thr_step += 1

    def _build_sparse_neighbors(self, z_gate):
        """
        Build per-row adjacency lists from z_gate (threshold plus optional top-k).
        Returns:
          - idx_list: list of length C; each entry a LongTensor of neighbor indices.
          - w_list:   list of length C; each entry a FloatTensor of unnormalized weights.
        """
        C = z_gate.size(0)
        # Remove self-edges.
        z_gate = z_gate.clone()
        z_gate.fill_diagonal_(0.0)
        if self.symmetric:
            z_gate = 0.5 * (z_gate + z_gate.t())
            z_gate.fill_diagonal_(0.0)

        idx_list, w_list = [], []
        for i in range(C):
            row = z_gate[i]  # [C]
            # Threshold filtering.
            mask = row > self.thr
            if mask.any():
                vals = row[mask]
                idxs = torch.nonzero(mask, as_tuple=False).squeeze(-1)
                # Optional cap on the degree.
                if (self.max_degree is not None) and (idxs.numel() > self.max_degree):
                    topk = torch.topk(vals, k=self.max_degree, dim=0)
                    vals = topk.values
                    idxs = idxs[topk.indices]
            else:
                idxs = torch.empty((0,), dtype=torch.long, device=row.device)
                vals = torch.empty((0,), dtype=row.dtype, device=row.device)
            idx_list.append(idxs)
            w_list.append(vals)
        return idx_list, w_list

    def _degree_rescale(self, ctx, w_sel):
        """
        Stability handling for the unnormalized aggregation: optionally rescale
        the aggregated values by degree to keep magnitudes in check.
        ctx:   [B, k, N, D]
        w_sel: [k]
        """
        weighted = (ctx * w_sel.view(1, -1, 1, 1)).sum(dim=1)  # [B, N, D]
        if self.degree_rescale == "count":
            return weighted / float(max(1, w_sel.numel()))
        elif self.degree_rescale == "count-sqrt":
            return weighted / math.sqrt(max(1, w_sel.numel()))
        elif self.degree_rescale == "sum":
            return weighted / float(w_sel.sum().clamp(min=1e-6))
        # "none" (or any unrecognized mode): plain weighted sum.
        return weighted

    def l0_loss(self, lam: float = 1e-4):
        """
        Expected-L0 regularizer: encourages a sparse adjacency (strength tunable via lam).
        """
        return lam * self.gate.expected_l0().sum()

    def forward(self, z, is_training):
        self._maybe_update_thr(training=is_training)

        # z: [B, C, N, D]
        B, C, N, D = z.shape
        assert C == self.C and D == self.dim

        # Level 1: sample unnormalized gates z_gate in [0, 1].
        z_gate = self.gate.sample(training=is_training)  # [C, C]

        # Build sparse neighborhoods (threshold + optional top-k).
        idx_list, w_list = self._build_sparse_neighbors(z_gate)

        # Level 2: cross-channel patch interaction only along the kept edges.
        out_z = torch.zeros_like(z)
        for i in range(C):
            target_z = z[:, i, :, :]  # [B, N, D]
            idx = idx_list[i]
            if idx.numel() == 0:
                # Empty neighborhood: a channel may have no relevant peers;
                # keep only the normalized residual.
                out_z[:, i, :, :] = self.norm(target_z)
                continue

            w_sel = w_list[i]           # [k], unnormalized weights in [0, 1]
            k_i = idx.numel()
            source_z = z[:, idx, :, :]  # [B, k, N, D]

            Q = self.q_proj(target_z)   # [B, N, D]
            K = self.k_proj(source_z.reshape(B * k_i, N, D)).reshape(B, k_i, N, D)
            V = self.v_proj(source_z.reshape(B * k_i, N, D)).reshape(B, k_i, N, D)

            # Cross-channel patch attention.
            attn_scores = torch.einsum('bnd,bkmd->bknm', Q, K) / math.sqrt(D)
            if self.tau_attn != 1.0:
                attn_scores = attn_scores / self.tau_attn
            attn_probs = F.softmax(attn_scores, dim=-1)               # [B, k, N, N]
            context = torch.einsum('bknm,bkmd->bknd', attn_probs, V)  # [B, k, N, D]

            # Aggregate with unnormalized channel weights; the optional degree
            # rescaling is for numerical stability only and does not change the
            # "unnormalized" semantics.
            aggregated_context = self._degree_rescale(context, w_sel)  # [B, N, D]

            out_z[:, i, :, :] = self.norm(target_z + self.out_proj(aggregated_context))

        return out_z
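
# A minimal smoke test (illustrative only: the batch size, channel count, patch
# count, width, and hyperparameters below are arbitrary assumptions, not values
# prescribed by the model above).
if __name__ == "__main__":
    torch.manual_seed(0)
    B, C, N, D = 2, 6, 16, 32
    mixer = HierarchicalGraphMixer(
        n_channel=C, dim=D, thr=0.3, max_degree=3, degree_rescale="count-sqrt"
    )
    z = torch.randn(B, C, N, D)

    out = mixer(z, is_training=True)
    assert out.shape == z.shape

    # In training, add the expected-L0 penalty to the task loss so unneeded
    # edges are driven to exactly zero.
    loss = out.pow(2).mean() + mixer.l0_loss(lam=1e-4)
    loss.backward()
    print("output:", tuple(out.shape),
          "gate grad norm:", float(mixer.gate.log_alpha.grad.norm()))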