feat(model): add initial PatchTST model architecture and utilities

This commit is contained in:
game-loader
2025-08-28 13:23:06 +08:00
parent 4129832f98
commit 59b23d4637
6 changed files with 1142 additions and 0 deletions

136
layers/mixer.py Normal file
View File

@ -0,0 +1,136 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class ChannelGraphMixer(nn.Module):
"""
在 PatchTST 的通道独立输出上做一次可学习的稀疏跨通道交互.
输入 z_list : 长度=M 的 list, 每个元素形状 [B, D] (单通道表示)
输出 list, 形状同输入
"""
def __init__(self, n_channel: int, dim: int, k: int = 5, tau: float = 0.2):
super().__init__()
self.k = k
self.tau = tau
self.dim = dim
self.A = nn.Parameter(torch.zeros(n_channel, n_channel)) # 可学习邻接
self.mix = nn.Linear(dim, dim, bias=False) # 通道映射
# 通道注意力过滤 CAF
self.se = nn.Sequential(
nn.Linear(dim, dim // 4, bias=False),
nn.ReLU(),
nn.Linear(dim // 4, 1, bias=False),
nn.Sigmoid()
)
# -------- util: 生成 row-wise top-k 稀疏邻接 -------------
def _row_sparse(self, logits: torch.Tensor) -> torch.Tensor:
"""
logits: [M,M]. 返回一个稀疏矩阵, 每行只保留 top-k, 其余为 0
"""
# Straight-through GumbelSoftmax
g = -torch.empty_like(logits).exponential_().log()
y = (logits + g) / self.tau
probs = F.softmax(y, dim=-1) # 可导
topk_val, _ = torch.topk(probs, self.k, dim=-1)
thr = topk_val[..., -1].unsqueeze(-1) # 每行阈值
sparse = torch.where(probs >= thr, probs, torch.zeros_like(probs))
return sparse.detach() + probs - probs.detach() # ST-estimator
# --------------------------------------------------------
def forward(self, z_list):
M = len(z_list)
B = z_list[0].shape[0]
Z = torch.stack(z_list, dim=1) # [B,M,D]
alpha = self.se(Z).squeeze(-1) # [B,M] 通道重要性
A_sparse = self._row_sparse(self.A) # [M,M]
out = []
for i in range(M):
# 汇聚来自其余通道的表示
agg = 0
for j in range(M):
if A_sparse[i, j] != 0:
agg = agg + alpha[:, j:j+1] * A_sparse[i, j] * self.mix(z_list[j])
out.append(z_list[i] + agg) # 残差
return out
class HierarchicalGraphMixer(nn.Module):
"""
分层图混合器,同时考虑宏观通道关系和微观 Patch 级别注意力。
输入 z : 形状为 [B, C, N, D] 的张量
输出 z_out : 形状同输入
"""
def __init__(self, n_channel: int, dim: int, k: int = 5, tau: float = 0.2):
super().__init__()
self.k = k
self.tau = tau
# Level 1: Channel Graph
self.A = nn.Parameter(torch.zeros(n_channel, n_channel))
self.se = nn.Sequential(
nn.Linear(dim, dim // 4, bias=False), nn.ReLU(),
nn.Linear(dim // 4, 1, bias=False), nn.Sigmoid()
)
# Level 2: Patch Cross-Attention
self.q_proj = nn.Linear(dim, dim)
self.k_proj = nn.Linear(dim, dim)
self.v_proj = nn.Linear(dim, dim)
self.out_proj = nn.Linear(dim, dim)
self.norm = nn.LayerNorm(dim)
def _row_sparse(self, logits: torch.Tensor) -> torch.Tensor:
# (Gumbel-Softmax 函数,无需改变)
g = -torch.empty_like(logits).exponential_().log()
y = (logits + g) / self.tau
probs = F.softmax(y, dim=-1)
topk_val, _ = torch.topk(probs, self.k, dim=-1)
thr = topk_val[..., -1].unsqueeze(-1)
sparse = torch.where(probs >= thr, probs, torch.zeros_like(probs))
return sparse.detach() + probs - probs.detach()
def forward(self, z):
# z 的形状: [B, C, N, D]
B, C, N, D = z.shape
# --- Level 1: 计算宏观权重 ---
A_sparse = self._row_sparse(self.A) # 通道连接稀疏图 A_sparse: [C, C]
# --- Level 2: 跨通道 Patch 交互 ---
out_z = torch.zeros_like(z)
for i in range(C): # 遍历每个目标通道 i
target_z = z[:, i, :, :] # [B, N, D]
# 准备聚合来自其他通道的 patch 级别上下文
aggregated_context = torch.zeros_like(target_z)
for j in range(C): # 遍历每个源通道 j
if A_sparse[i, j] != 0:
source_z = z[:, j, :, :] # [B, N, D]
# --- 执行交叉注意力 ---
Q = self.q_proj(target_z) # Query 来自目标通道 i
K = self.k_proj(source_z) # Key 来自源通道 j
V = self.v_proj(source_z) # Value 来自源通道 j
attn_scores = torch.bmm(Q, K.transpose(1, 2)) / math.sqrt(D)
attn_probs = F.softmax(attn_scores, dim=-1) # [B, N, N]
context = torch.bmm(attn_probs, V) # [B, N, D], 从 j 聚合到 i 的上下文
# 宏观权重2 (通道连接强度): A_sparse[i, j] -> 标量
# 加权上下文
weighted_context = A_sparse[i, j] * context
aggregated_context = aggregated_context + weighted_context
# 将聚合后的上下文通过输出层,并与原始目标表示相加(残差连接)
# LayerNorm 增加稳定性
out_z[:, i, :, :] = self.norm(target_z + self.out_proj(aggregated_context))
return out_z