# tsmodel/layers/mixer.py

import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class ChannelGraphMixer(nn.Module):
"""
在 PatchTST 的通道独立输出上做一次可学习的稀疏跨通道交互.
输入 z_list : 长度=M 的 list, 每个元素形状 [B, D] (单通道表示)
输出 list, 形状同输入
"""
    def __init__(self, n_channel: int, dim: int, k: int = 5, tau: float = 0.2):
        super().__init__()
        self.k = k
        self.tau = tau
        self.dim = dim
        self.A = nn.Parameter(torch.zeros(n_channel, n_channel))  # learnable adjacency logits
        self.mix = nn.Linear(dim, dim, bias=False)                # cross-channel projection
        # Channel Attention Filter (CAF): a squeeze-excitation style gate
        # that scores the importance of each channel representation.
        self.se = nn.Sequential(
            nn.Linear(dim, dim // 4, bias=False),
            nn.ReLU(),
            nn.Linear(dim // 4, 1, bias=False),
            nn.Sigmoid(),
        )

    # -------- util: build a row-wise top-k sparse adjacency --------
    def _row_sparse(self, logits: torch.Tensor) -> torch.Tensor:
        """
        logits: [M, M]. Returns a sparse matrix that keeps only the
        top-k entries of each row and zeroes out the rest.
        """
        # Straight-through Gumbel-Softmax: perturb the logits with Gumbel
        # noise, soften with temperature tau, then threshold at the k-th
        # largest value per row.
        g = -torch.empty_like(logits).exponential_().log()  # Gumbel(0, 1) noise
        y = (logits + g) / self.tau
        probs = F.softmax(y, dim=-1)                        # differentiable
        topk_val, _ = torch.topk(probs, self.k, dim=-1)
        thr = topk_val[..., -1].unsqueeze(-1)               # per-row threshold
        sparse = torch.where(probs >= thr, probs, torch.zeros_like(probs))
        return sparse.detach() + probs - probs.detach()     # straight-through estimator
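
    # Note on the straight-through trick above (illustrative, not from the
    # original file): in value, the returned tensor equals `sparse`, because
    # `probs - probs.detach()` is identically zero; in gradient, it behaves
    # like the dense `probs`, since both `.detach()` terms are cut from the
    # autograd graph. The hard top-k selection is thus used in the forward
    # pass while gradients still flow to the adjacency logits.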

    # --------------------------------------------------------------
    def forward(self, z_list):
        M = len(z_list)
        Z = torch.stack(z_list, dim=1)       # [B, M, D]
        alpha = self.se(Z).squeeze(-1)       # [B, M] channel importance
        A_sparse = self._row_sparse(self.A)  # [M, M]
        out = []
        for i in range(M):
            # Aggregate representations from the other channels.
            agg = 0
            for j in range(M):
                if A_sparse[i, j] != 0:
                    agg = agg + alpha[:, j:j + 1] * A_sparse[i, j] * self.mix(z_list[j])
            out.append(z_list[i] + agg)      # residual connection
        return out
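

# Usage sketch for ChannelGraphMixer (illustrative only; the shapes and
# hyper-parameters below are assumptions, not from the original file):
#
#   mixer = ChannelGraphMixer(n_channel=7, dim=128, k=3)
#   z_list = [torch.randn(32, 128) for _ in range(7)]
#   out = mixer(z_list)   # list of 7 tensors, each of shape [32, 128]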


class HierarchicalGraphMixer(nn.Module):
    """
    Hierarchical graph mixer that jointly models macro-level channel
    relations and micro-level patch attention.

    Input  z    : tensor of shape [B, C, N, D]
    Output z_out: tensor of the same shape
    """

    def __init__(self, n_channel: int, dim: int, k: int = 5, tau: float = 0.2):
        super().__init__()
        self.k = k
        self.tau = tau
        # Level 1: learnable channel graph (adjacency logits) plus a
        # squeeze-excitation gate for channel importance.
        # NOTE: self.se is defined here but is not applied in forward().
        self.A = nn.Parameter(torch.zeros(n_channel, n_channel))
        self.se = nn.Sequential(
            nn.Linear(dim, dim // 4, bias=False), nn.ReLU(),
            nn.Linear(dim // 4, 1, bias=False), nn.Sigmoid(),
        )
        # Level 2: patch-level cross-attention projections.
        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)
        self.out_proj = nn.Linear(dim, dim)
        self.norm = nn.LayerNorm(dim)

    def _row_sparse(self, logits: torch.Tensor) -> torch.Tensor:
        # Same straight-through Gumbel-Softmax top-k as in ChannelGraphMixer.
        g = -torch.empty_like(logits).exponential_().log()
        y = (logits + g) / self.tau
        probs = F.softmax(y, dim=-1)
        topk_val, _ = torch.topk(probs, self.k, dim=-1)
        thr = topk_val[..., -1].unsqueeze(-1)
        sparse = torch.where(probs >= thr, probs, torch.zeros_like(probs))
        return sparse.detach() + probs - probs.detach()
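
    # Design note (not in the original file): this method duplicates
    # ChannelGraphMixer._row_sparse; factoring the straight-through top-k
    # sampler into a shared helper would keep the two copies in sync.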

    def forward(self, z):
        # z: [B, C, N, D]
        B, C, N, D = z.shape
        # --- Level 1: macro channel weights ---
        A_sparse = self._row_sparse(self.A)  # sparse channel-connection graph, [C, C]
        # --- Level 2: cross-channel patch interaction ---
        out_z = torch.zeros_like(z)
        for i in range(C):                   # each target channel i
            target_z = z[:, i, :, :]         # [B, N, D]
            # The Query depends only on the target channel, so it is
            # computed once rather than inside the inner loop.
            Q = self.q_proj(target_z)
            # Accumulate patch-level context from the connected source channels.
            aggregated_context = torch.zeros_like(target_z)
            for j in range(C):               # each source channel j
                if A_sparse[i, j] != 0:
                    source_z = z[:, j, :, :]   # [B, N, D]
                    # --- cross-attention: channel i attends over channel j's patches ---
                    K = self.k_proj(source_z)  # Key from source channel j
                    V = self.v_proj(source_z)  # Value from source channel j
                    attn_scores = torch.bmm(Q, K.transpose(1, 2)) / math.sqrt(D)
                    attn_probs = F.softmax(attn_scores, dim=-1)  # [B, N, N]
                    context = torch.bmm(attn_probs, V)           # [B, N, D], context from j to i
                    # Macro weight (channel connection strength): the scalar
                    # A_sparse[i, j] scales the patch-level context.
                    aggregated_context = aggregated_context + A_sparse[i, j] * context
            # Project the aggregated context, add the residual, and apply
            # LayerNorm for stability.
            out_z[:, i, :, :] = self.norm(target_z + self.out_proj(aggregated_context))
        return out_z
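

# Minimal smoke test (a sketch, not part of the original module): the shapes
# below (batch 8, 7 channels, 12 patches, model dim 64) are assumed values
# chosen only to exercise both mixers end to end.
if __name__ == "__main__":
    B, C, N, D = 8, 7, 12, 64

    cgm = ChannelGraphMixer(n_channel=C, dim=D, k=3)
    z_list = [torch.randn(B, D) for _ in range(C)]
    out_list = cgm(z_list)
    assert len(out_list) == C and out_list[0].shape == (B, D)

    hgm = HierarchicalGraphMixer(n_channel=C, dim=D, k=3)
    z = torch.randn(B, C, N, D)
    out = hgm(z)
    assert out.shape == z.shape
    print("ChannelGraphMixer / HierarchicalGraphMixer forward passes OK")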