# TSlib/layers/GraphMixer.py

import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class HardConcreteGate(nn.Module):
    """
    Hard-Concrete gate for L0-style sparsity (Louizos et al., 2017).
    Produces z in [0, 1] without row-wise normalization.
    """
    def __init__(self, shape, temperature=2./3., gamma=-0.1, zeta=1.1, init_log_alpha=-2.0):
        super().__init__()
        self.log_alpha = nn.Parameter(torch.full(shape, init_log_alpha))
        self.temperature = temperature
        self.gamma = gamma
        self.zeta = zeta

    def sample(self, training=True):
        if training:
            # clamp u away from 0/1 so log(u) and log(1 - u) stay finite
            u = torch.rand_like(self.log_alpha).clamp(1e-6, 1 - 1e-6)
            s = torch.sigmoid((self.log_alpha + torch.log(u) - torch.log(1 - u)) / self.temperature)
        else:
            # deterministic mean gate at eval
            s = torch.sigmoid(self.log_alpha)
        s_bar = s * (self.zeta - self.gamma) + self.gamma
        z = torch.clamp(s_bar, 0., 1.)
        return z
def expected_l0(self):
"""
E[1_{z>0}] closed-form for hard-concrete.
Useful for L0 penalty: lambda * expected_l0.sum()
"""
# s > t0 => z > 0, where t0 = -gamma / (zeta - gamma)
t0 = -self.gamma / (self.zeta - self.gamma)
# logit(t0)
logit_t0 = math.log(t0) - math.log(1 - t0)
# P(x > logit_t0) with x ~ Logistic(loc=log_alpha, scale=temperature)
p_open = torch.sigmoid((self.log_alpha - logit_t0) / self.temperature)
return p_open
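
# Example (illustrative sketch, not part of the original module): sampling the
# gates and computing the expected-L0 penalty for a 4x4 edge matrix.
#   gate = HardConcreteGate(shape=(4, 4))
#   z = gate.sample(training=True)              # stochastic gates in [0, 1], shape [4, 4]
#   penalty = 1e-4 * gate.expected_l0().sum()   # differentiable expected-L0 regularizer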


class HierarchicalGraphMixer(nn.Module):
    """
    Hierarchical graph mixer with Hard-Concrete edge gating:
    - Level 1: an unnormalized, thresholdable, possibly empty channel graph
    - Level 2: patch-level cross-attention only along the selected edges
    Input:  z [B, C, N, D]
    Output: z_out with the same shape
    """
    def __init__(
        self,
        n_channel: int,
        dim: int,
        max_degree: int = None,        # optional: cap the number of edges per row
        thr: float = 0.5,              # threshold for keeping an edge, e.g. 0.5 / 0.7
        thr_min: float = None,         # start of the dynamic threshold; defaults to thr
        thr_max: float = None,         # end of the dynamic threshold; defaults to thr
        thr_steps: int = 0,            # steps from thr_min -> thr_max; > 0 enables the schedule
        thr_schedule: str = "linear",  # "linear" | "cosine" | "exp"
        temperature: float = 2./3.,
        tau_attn: float = 1.0,         # patch-attention temperature (optional)
        symmetric: bool = True,        # whether to symmetrize the channel graph
        degree_rescale: str = "none",  # "none" | "count" | "count-sqrt" | "sum"
        init_log_alpha: float = -2.0
    ):
        super().__init__()
        self.C = n_channel
        self.dim = dim
        self.max_degree = max_degree
        self.thr = thr
        self.tau_attn = tau_attn
        self.symmetric = symmetric
        self.degree_rescale = degree_rescale
        self.thr_min = thr if (thr_min is None) else float(thr_min)
        self.thr_max = thr if (thr_max is None) else float(thr_max)
        self.thr_steps = int(thr_steps) if thr_steps is not None else 0
        self.thr_schedule = thr_schedule
        self._use_dynamic_thr = (self.thr_steps > 0) and (abs(self.thr_max - self.thr_min) > 1e-12)
        # buffer tracking how many scheduler steps have been taken (not saved with the weights)
        self.register_buffer("_thr_step", torch.zeros((), dtype=torch.long), persistent=False)
        # Level 1: unnormalized edge gates
        self.gate = HardConcreteGate(
            shape=(n_channel, n_channel),
            temperature=temperature,
            init_log_alpha=init_log_alpha
        )
        # Optional SE block (the original `se` can produce sample-dependent channel
        # priorities; the interface is kept here but it is unused in forward())
        self.se = nn.Sequential(
            nn.Linear(dim, dim // 4, bias=False), nn.SiLU(),
            nn.Linear(dim // 4, 1, bias=False), nn.Sigmoid()
        )
        # Level 2: patch cross-attention
        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)
        self.out_proj = nn.Linear(dim, dim)
        self.norm = nn.LayerNorm(dim)

    def _compute_thr_by_progress(self, progress: float) -> float:
        # progress in [0, 1]
        progress = max(0.0, min(1.0, float(progress)))
        if self.thr_schedule == "linear":
            g = progress
        elif self.thr_schedule == "cosine":
            # cosine easing: slow near both endpoints, fastest mid-schedule
            g = 0.5 - 0.5 * math.cos(math.pi * progress)
        elif self.thr_schedule == "exp":
            # exponential easing: stays near thr_min for most of the schedule,
            # then sweeps quickly to thr_max
            k = 5.0
            g = (math.exp(k * progress) - 1.0) / (math.exp(k) - 1.0)
        else:
            g = progress
        return self.thr_min + (self.thr_max - self.thr_min) * g

    def _maybe_update_thr(self):
        if self.training and self._use_dynamic_thr:
            step = int(self._thr_step.item())
            progress = step / float(self.thr_steps)
            self.thr = float(self._compute_thr_by_progress(progress))
            self._thr_step += 1
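
    # Example (illustrative numbers, assuming thr_min=0.3, thr_max=0.7, thr_steps=1000):
    #   "linear" passes thr = 0.3 / 0.5 / 0.7 at steps 0 / 500 / 1000;
    #   "cosine" also hits 0.5 at step 500 but lingers near both endpoints;
    #   "exp" stays near 0.3 for most of training (thr ~ 0.33 at step 500)
    #   and sweeps to 0.7 in the final steps.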

    def _build_sparse_neighbors(self, z_gate):
        """
        Build each row's neighbor list from z_gate (threshold + optional top-k).
        Returns:
            - idx_list: list of length C; each item is a LongTensor [idx_j]
            - w_list:   list of length C; each item is a FloatTensor [w_j] (unnormalized)
        """
        C = z_gate.size(0)
        # remove self-loops
        z_gate = z_gate.clone()
        z_gate.fill_diagonal_(0.0)
        if self.symmetric:
            z_gate = 0.5 * (z_gate + z_gate.t())
            z_gate.fill_diagonal_(0.0)
        idx_list, w_list = [], []
        for i in range(C):
            row = z_gate[i]  # [C]
            # threshold filtering
            mask = row > self.thr
            if mask.any():
                vals = row[mask]
                idxs = torch.nonzero(mask, as_tuple=False).squeeze(-1)
                # optional cap on the node degree
                if (self.max_degree is not None) and (idxs.numel() > self.max_degree):
                    topk = torch.topk(vals, k=self.max_degree, dim=0)
                    vals = topk.values
                    idxs = idxs[topk.indices]
            else:
                idxs = torch.empty((0,), dtype=torch.long, device=row.device)
                vals = torch.empty((0,), dtype=row.dtype, device=row.device)
            idx_list.append(idxs)
            w_list.append(vals)
        return idx_list, w_list

    def _degree_rescale(self, ctx, w_sel):
        """
        Stabilization for the unnormalized aggregation: optionally rescale the
        aggregate by the degree to keep its magnitude under control.
        ctx:   [B, k, N, D]
        w_sel: [k]
        """
        if self.degree_rescale == "none":
            return (ctx * w_sel.view(1, -1, 1, 1)).sum(dim=1)
        elif self.degree_rescale == "count":
            k = max(1, w_sel.numel())
            return (ctx * w_sel.view(1, -1, 1, 1)).sum(dim=1) / float(k)
        elif self.degree_rescale == "count-sqrt":
            k = max(1, w_sel.numel())
            return (ctx * w_sel.view(1, -1, 1, 1)).sum(dim=1) / math.sqrt(k)
        elif self.degree_rescale == "sum":
            s = float(w_sel.sum().clamp(min=1e-6))
            return (ctx * w_sel.view(1, -1, 1, 1)).sum(dim=1) / s
        else:
            return (ctx * w_sel.view(1, -1, 1, 1)).sum(dim=1)

    def l0_loss(self, lam: float = 1e-4):
        """
        Expected-L0 regularizer (encourages a sparse adjacency; strength is tunable).
        """
        return lam * self.gate.expected_l0().sum()
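
    # Example training-loop usage (illustrative sketch; `mixer`, `task_loss`, and
    # the optimizer step are assumptions, not part of this file):
    #   loss = task_loss + mixer.l0_loss(lam=1e-4)
    #   loss.backward()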

    def forward(self, z):
        self._maybe_update_thr()
        # z: [B, C, N, D]
        B, C, N, D = z.shape
        assert C == self.C and D == self.dim
        # Level 1: sample unnormalized gates z_gate in [0, 1]
        z_gate = self.gate.sample(training=self.training)  # [C, C]
        # build sparse neighbor lists (threshold + optional top-k)
        idx_list, w_list = self._build_sparse_neighbors(z_gate)
        # Level 2: cross-channel patch interaction only along the kept edges
        out_z = torch.zeros_like(z)
        for i in range(C):
            target_z = z[:, i, :, :]  # [B, N, D]
            idx = idx_list[i]
            if idx.numel() == 0:
                # empty neighborhood: "no related channels" is allowed; residual/norm only
                out_z[:, i, :, :] = self.norm(target_z)
                continue
            w_sel = w_list[i]  # [k], unnormalized weights in [0, 1]
            k_i = idx.numel()
            source_z = z[:, idx, :, :]  # [B, k, N, D]
            Q = self.q_proj(target_z)  # [B, N, D]
            K = self.k_proj(source_z.reshape(B * k_i, N, D)).reshape(B, k_i, N, D)
            V = self.v_proj(source_z.reshape(B * k_i, N, D)).reshape(B, k_i, N, D)
            # cross-channel patch attention
            attn_scores = torch.einsum('bnd,bkmd->bknm', Q, K) / math.sqrt(D)
            if self.tau_attn != 1.0:
                attn_scores = attn_scores / self.tau_attn
            attn_probs = F.softmax(attn_scores, dim=-1)  # [B, k, N, N]
            context = torch.einsum('bknm,bkmd->bknd', attn_probs, V)  # [B, k, N, D]
            # aggregate with unnormalized channel weights + optional degree rescaling
            # (numerical stabilization only; the "unnormalized" semantics are kept)
            aggregated_context = self._degree_rescale(context, w_sel)  # [B, N, D]
            out_z[:, i, :, :] = self.norm(target_z + self.out_proj(aggregated_context))
        return out_z
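

# Minimal smoke test (an illustrative sketch; the shapes and hyperparameters
# below are assumptions, not taken from the original repository):
if __name__ == "__main__":
    B, C, N, D = 2, 4, 16, 32  # batch, channels, patches, embedding dim
    mixer = HierarchicalGraphMixer(n_channel=C, dim=D, thr=0.3, max_degree=2)
    z = torch.randn(B, C, N, D)
    out = mixer(z)                 # [B, C, N, D], same shape as the input
    reg = mixer.l0_loss(lam=1e-4)  # expected-L0 penalty on the channel graph
    print(out.shape, float(reg))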