feat(model): add initial PatchTST model architecture and utilities

game-loader
2025-08-28 13:23:06 +08:00
parent 4129832f98
commit 59b23d4637
6 changed files with 1142 additions and 0 deletions


@@ -0,0 +1,132 @@
import torch
from torch import nn
import math
class CrossChannelAttention(nn.Module):
"""
对每个通道 i
Query: 预测区间长度 pred_len 的可学习向量 (不依赖历史值)
Key/Value: 其它通道所有时间点的标量 -> 线性投影
输出: [B, pred_len, C]
复杂度: O(C * pred_len * (C-1) * L),当 C 很小(如 <= 64)通常可接受
"""
def __init__(self, seq_len, pred_len, c_in,
d_model=64, n_heads=4,
dropout=0.1, use_layernorm=True):
super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
self.seq_len = seq_len
self.pred_len = pred_len
self.c_in = c_in
self.d_model = d_model
self.n_heads = n_heads
self.head_dim = d_model // n_heads
        # Learnable query embeddings, one per forecast step: [pred_len, d_model]
self.query_embed = nn.Parameter(torch.randn(pred_len, d_model))
        # Project each scalar observation to d_model (separate projections for Key and Value)
self.key_proj = nn.Linear(1, d_model)
self.value_proj = nn.Linear(1, d_model)
        # Compress the attention output back to one scalar per step
self.out_proj = nn.Linear(d_model, 1)
self.attn_dropout = nn.Dropout(dropout)
self.proj_dropout = nn.Dropout(dropout)
self.use_ln = use_layernorm
if use_layernorm:
self.ln_q = nn.LayerNorm(d_model)
self.ln_kv = nn.LayerNorm(d_model)
        # Optional temporal + channel positional encodings (simple learnable vectors)
self.time_pos = nn.Parameter(torch.zeros(seq_len, d_model))
self.channel_pos = nn.Parameter(torch.zeros(c_in, d_model))
nn.init.normal_(self.time_pos, std=0.02)
nn.init.normal_(self.channel_pos, std=0.02)
nn.init.normal_(self.query_embed, std=0.02)
def split_heads(self, x):
# x: [B, T, d_model] -> [B, n_heads, T, head_dim]
B, T, D = x.shape
return x.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
def merge_heads(self, x):
# x: [B, n_heads, T, head_dim] -> [B, T, d_model]
B, H, T, Hd = x.shape
return x.transpose(1, 2).contiguous().view(B, T, H * Hd)
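    # merge_heads inverts split_heads exactly: merge_heads(split_heads(x)) == x for any
    # [B, T, d_model] tensor (e.g. d_model=64, n_heads=4: [B, T, 64] -> [B, 4, T, 16] -> [B, T, 64])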
def forward(self, x):
"""
x: [B, L, C]
返回: cross_out [B, pred_len, C]
"""
B, L, C = x.shape
        assert L == self.seq_len and C == self.c_in, "input shape does not match (seq_len, c_in)"
        # Prepare K, V: project each channel's full time series
        # Reshape to [B, C, L, 1] -> Linear -> [B, C, L, d_model]
xc = x.permute(0, 2, 1).unsqueeze(-1) # [B, C, L, 1]
K = self.key_proj(xc) # [B, C, L, d_model]
V = self.value_proj(xc) # [B, C, L, d_model]
        # Add positional encodings (channel + time)
        # Broadcast: time_pos [L, d_model] -> [1, 1, L, d]; channel_pos [C, d_model] -> [1, C, 1, d]
K = K + self.time_pos.unsqueeze(0).unsqueeze(0) + self.channel_pos.unsqueeze(0).unsqueeze(2)
V = V + self.time_pos.unsqueeze(0).unsqueeze(0) + self.channel_pos.unsqueeze(0).unsqueeze(2)
if self.use_ln:
K = self.ln_kv(K)
V = self.ln_kv(V)
cross_outputs = []
        # Prepare the Query (all channels share the base query embeddings; a per-channel offset is added below)
base_q = self.query_embed # [pred_len, d_model]
for ci in range(C):
            if C == 1:
                # Single-channel degenerate case: there are no other channels to attend to,
                # so fall back to the channel's own last pred_len steps (requires pred_len <= seq_len)
                self_out = x[:, -self.pred_len:, ci:ci+1]
                cross_outputs.append(self_out)
                continue
            # Indices of all other channels
            other_idx = [j for j in range(C) if j != ci]
K_i = K[:, other_idx, :, :] # [B, C-1, L, d_model]
V_i = V[:, other_idx, :, :] # [B, C-1, L, d_model]
            # Flatten the (channel, time) axes into one token dimension
K_i = K_i.reshape(B, (C-1)*L, self.d_model) # [B, (C-1)*L, d]
V_i = V_i.reshape(B, (C-1)*L, self.d_model)
            # Query: broadcast to the batch, then add the per-channel offset (optional)
Q_i = base_q.unsqueeze(0).expand(B, self.pred_len, self.d_model) # [B, pred_len, d_model]
Q_i = Q_i + self.channel_pos[ci].unsqueeze(0).unsqueeze(0)
if self.use_ln:
Q_i = self.ln_q(Q_i)
            # Split into attention heads
Qh = self.split_heads(Q_i) # [B, H, pred_len, head_dim]
Kh = self.split_heads(K_i) # [B, H, (C-1)*L, head_dim]
Vh = self.split_heads(V_i) # [B, H, (C-1)*L, head_dim]
# Attention: [B, H, pred_len, (C-1)*L]
scores = torch.matmul(Qh, Kh.transpose(-2, -1)) / math.sqrt(self.head_dim)
attn = torch.softmax(scores, dim=-1)
attn = self.attn_dropout(attn)
            # Context: attention-weighted sum of the values
ctx = torch.matmul(attn, Vh) # [B, H, pred_len, head_dim]
ctx = self.merge_heads(ctx) # [B, pred_len, d_model]
ctx = self.proj_dropout(ctx)
            # Compress to one scalar per forecast step: [B, pred_len, 1]
out_ci = self.out_proj(ctx)
cross_outputs.append(out_ci) # list of [B, pred_len, 1]
cross_out = torch.cat(cross_outputs, dim=-1) # [B, pred_len, C]
return cross_out
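
A quick shape-check sketch for the new module (the sizes below are illustrative assumptions, not part of this commit):

# Minimal smoke test for CrossChannelAttention (assumes the class above is defined/importable)
import torch

seq_len, pred_len, c_in = 336, 96, 8   # assumed example sizes
attn = CrossChannelAttention(seq_len, pred_len, c_in, d_model=64, n_heads=4)

x = torch.randn(2, seq_len, c_in)      # [B, L, C]
out = attn(x)                          # expected: [B, pred_len, C]
assert out.shape == (2, pred_len, c_in)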