import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class ChannelGraphMixer(nn.Module):
    """
    Applies a learnable, sparse cross-channel interaction on top of
    PatchTST's channel-independent outputs.

    Input:
        z_list: list of length M; each element has shape [B, D]
                (one representation per channel)
    Output:
        list with the same length and shapes as the input
    """
    def __init__(self, n_channel: int, dim: int, k: int = 5, tau: float = 0.2):
        super().__init__()
        self.k = k
        self.tau = tau
        self.dim = dim
        self.A = nn.Parameter(torch.zeros(n_channel, n_channel))  # learnable adjacency
        self.mix = nn.Linear(dim, dim, bias=False)                # channel projection
        # Channel-attention filter (CAF): scores each channel's importance
        self.se = nn.Sequential(
            nn.Linear(dim, dim // 4, bias=False),
            nn.ReLU(),
            nn.Linear(dim // 4, 1, bias=False),
            nn.Sigmoid(),
        )

    # -------- util: build a row-wise top-k sparse adjacency -------------
    def _row_sparse(self, logits: torch.Tensor) -> torch.Tensor:
        """
        logits: [M, M]. Returns a matrix that keeps only the top-k
        entries per row and zeroes out the rest.
        """
        # Straight-through Gumbel-Softmax; g ~ Gumbel(0, 1)
        g = -torch.empty_like(logits).exponential_().log()
        y = (logits + g) / self.tau
        probs = F.softmax(y, dim=-1)                      # differentiable
        topk_val, _ = torch.topk(probs, self.k, dim=-1)
        thr = topk_val[..., -1].unsqueeze(-1)             # per-row threshold
        sparse = torch.where(probs >= thr, probs, torch.zeros_like(probs))
        # Straight-through estimator: the forward pass uses the sparse
        # matrix, the backward pass uses the dense softmax gradients.
        return sparse.detach() + probs - probs.detach()
    # ---------------------------------------------------------------------

    def forward(self, z_list):
        M = len(z_list)
        Z = torch.stack(z_list, dim=1)              # [B, M, D]
        alpha = self.se(Z).squeeze(-1)              # [B, M] channel importance
        A_sparse = self._row_sparse(self.A)         # [M, M]
        out = []
        for i in range(M):
            # Aggregate representations from the connected channels
            agg = 0
            for j in range(M):
                if A_sparse[i, j] != 0:
                    agg = agg + alpha[:, j:j + 1] * A_sparse[i, j] * self.mix(z_list[j])
            out.append(z_list[i] + agg)             # residual connection
        return out
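# ---------------------------------------------------------------------
# Usage sketch (illustrative addition, not part of the original module):
# a minimal shape check for ChannelGraphMixer. The sizes below
# (B=8, M=7, D=128) are arbitrary assumptions.
def _demo_channel_graph_mixer():
    B, M, D = 8, 7, 128
    mixer = ChannelGraphMixer(n_channel=M, dim=D, k=3)
    z_list = [torch.randn(B, D) for _ in range(M)]
    out = mixer(z_list)
    assert len(out) == M and out[0].shape == (B, D)
    print("ChannelGraphMixer OK:", len(out), tuple(out[0].shape))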
class HierarchicalGraphMixer(nn.Module):
    """
    Hierarchical graph mixer that combines macro-level channel relations
    with micro-level patch cross-attention.

    Input:
        z: tensor of shape [B, C, N, D]
    Output:
        z_out: tensor of the same shape
    """
    def __init__(self, n_channel: int, dim: int, k: int = 5, tau: float = 0.2):
        super().__init__()
        self.k = k
        self.tau = tau
        # Level 1: Channel Graph
        self.A = nn.Parameter(torch.zeros(n_channel, n_channel))
        # Channel-importance scorer (not used in forward below; kept for
        # parity with ChannelGraphMixer)
        self.se = nn.Sequential(
            nn.Linear(dim, dim // 4, bias=False),
            nn.ReLU(),
            nn.Linear(dim // 4, 1, bias=False),
            nn.Sigmoid(),
        )
        # Level 2: Patch Cross-Attention
        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)
        self.out_proj = nn.Linear(dim, dim)
        self.norm = nn.LayerNorm(dim)

    def _row_sparse(self, logits: torch.Tensor) -> torch.Tensor:
        # Same straight-through Gumbel-Softmax as in ChannelGraphMixer
        g = -torch.empty_like(logits).exponential_().log()
        y = (logits + g) / self.tau
        probs = F.softmax(y, dim=-1)
        topk_val, _ = torch.topk(probs, self.k, dim=-1)
        thr = topk_val[..., -1].unsqueeze(-1)
        sparse = torch.where(probs >= thr, probs, torch.zeros_like(probs))
        return sparse.detach() + probs - probs.detach()

    def forward(self, z):
        # z: [B, C, N, D]
        B, C, N, D = z.shape

        # --- Level 1: macro channel weights ---
        A_sparse = self._row_sparse(self.A)   # sparse channel graph, [C, C]

        # --- Level 2: cross-channel patch interaction ---
        out_z = torch.zeros_like(z)
        for i in range(C):                    # target channel i
            target_z = z[:, i, :, :]          # [B, N, D]
            # Accumulate patch-level context from the source channels
            aggregated_context = torch.zeros_like(target_z)
            for j in range(C):                # source channel j
                if A_sparse[i, j] != 0:
                    source_z = z[:, j, :, :]  # [B, N, D]
                    # --- cross-attention ---
                    Q = self.q_proj(target_z)  # queries from target channel i
                    K = self.k_proj(source_z)  # keys from source channel j
                    V = self.v_proj(source_z)  # values from source channel j
                    attn_scores = torch.bmm(Q, K.transpose(1, 2)) / math.sqrt(D)
                    attn_probs = F.softmax(attn_scores, dim=-1)  # [B, N, N]
                    context = torch.bmm(attn_probs, V)           # [B, N, D], context aggregated from j into i
                    # Macro weight (channel connection strength): A_sparse[i, j] is a scalar
                    weighted_context = A_sparse[i, j] * context
                    aggregated_context = aggregated_context + weighted_context
            # Project the aggregated context, add the residual target
            # representation, and apply LayerNorm for stability
            out_z[:, i, :, :] = self.norm(target_z + self.out_proj(aggregated_context))
        return out_z
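# ---------------------------------------------------------------------
# Usage sketch (illustrative addition, not part of the original module):
# a minimal shape check for HierarchicalGraphMixer on a dummy
# PatchTST-style tensor. Sizes (B=4, C=7, N=16, D=128) are arbitrary
# assumptions.
if __name__ == "__main__":
    _demo_channel_graph_mixer()
    B, C, N, D = 4, 7, 16, 128
    mixer = HierarchicalGraphMixer(n_channel=C, dim=D, k=3)
    z = torch.randn(B, C, N, D)
    z_out = mixer(z)
    assert z_out.shape == z.shape
    print("HierarchicalGraphMixer OK:", tuple(z_out.shape))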