feat(model): add initial PatchTST model architecture and utilities
layers/cross_channel_attn.py (new file, 132 lines)
@@ -0,0 +1,132 @@
import torch
from torch import nn
import math


class CrossChannelAttention(nn.Module):
    """
    For each channel i:
        Query:     a learnable vector covering the pred_len forecast steps
                   (independent of the historical values)
        Key/Value: the scalar values at every time step of the *other*
                   channels -> linear projection
    Output: [B, pred_len, C]
    Complexity: O(C * pred_len * (C-1) * L), usually acceptable when C is
    small (e.g. <= 64); for instance C=8, L=96, pred_len=24 gives about
    1.3e5 score entries per head.
    """

    def __init__(self, seq_len, pred_len, c_in,
                 d_model=64, n_heads=4,
                 dropout=0.1, use_layernorm=True):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.seq_len = seq_len
        self.pred_len = pred_len
        self.c_in = c_in
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads

        # Learnable query embeddings for the prediction steps: [pred_len, d_model]
        self.query_embed = nn.Parameter(torch.randn(pred_len, d_model))

        # Scalar value -> d_model projections (one each for Key and Value)
        self.key_proj = nn.Linear(1, d_model)
        self.value_proj = nn.Linear(1, d_model)

        # Compress the output back to a scalar
        self.out_proj = nn.Linear(d_model, 1)

        self.attn_dropout = nn.Dropout(dropout)
        self.proj_dropout = nn.Dropout(dropout)

        self.use_ln = use_layernorm
        if use_layernorm:
            self.ln_q = nn.LayerNorm(d_model)
            self.ln_kv = nn.LayerNorm(d_model)

        # Optional time + channel positional encodings (simple learnable vectors)
        self.time_pos = nn.Parameter(torch.zeros(seq_len, d_model))
        self.channel_pos = nn.Parameter(torch.zeros(c_in, d_model))
        nn.init.normal_(self.time_pos, std=0.02)
        nn.init.normal_(self.channel_pos, std=0.02)
        nn.init.normal_(self.query_embed, std=0.02)

    def split_heads(self, x):
        # x: [B, T, d_model] -> [B, n_heads, T, head_dim]
        B, T, D = x.shape
        return x.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)

    def merge_heads(self, x):
        # x: [B, n_heads, T, head_dim] -> [B, T, d_model]
        B, H, T, Hd = x.shape
        return x.transpose(1, 2).contiguous().view(B, T, H * Hd)

    def forward(self, x):
        """
        x: [B, L, C]
        Returns: cross_out [B, pred_len, C]
        """
        B, L, C = x.shape
        assert L == self.seq_len and C == self.c_in, "unexpected input shape"

        # Prepare K, V: project each channel's time series
        # First reshape to [B, C, L, 1] -> Linear -> [B, C, L, d_model]
        xc = x.permute(0, 2, 1).unsqueeze(-1)  # [B, C, L, 1]
        K = self.key_proj(xc)                  # [B, C, L, d_model]
        V = self.value_proj(xc)                # [B, C, L, d_model]

        # Add positional encodings (channel + time)
        # broadcast: time_pos [L, d_model] -> [1, 1, L, d]; channel_pos [C, d_model] -> [1, C, 1, d]
        K = K + self.time_pos.unsqueeze(0).unsqueeze(0) + self.channel_pos.unsqueeze(0).unsqueeze(2)
        V = V + self.time_pos.unsqueeze(0).unsqueeze(0) + self.channel_pos.unsqueeze(0).unsqueeze(2)

        if self.use_ln:
            K = self.ln_kv(K)
            V = self.ln_kv(V)

        cross_outputs = []

        # Prepare the Query (all channels share the base query; a per-channel offset is added below)
        base_q = self.query_embed  # [pred_len, d_model]

        for ci in range(C):
            if C == 1:
                # Single-channel degenerate case: there are no other channels
                # to attend to, so fall back to repeating the channel's own
                # last pred_len values (requires pred_len <= seq_len)
                self_out = x[:, -self.pred_len:, ci:ci+1]
                cross_outputs.append(self_out)
                continue
            # Indices of all the other channels
            other_idx = [j for j in range(C) if j != ci]

            K_i = K[:, other_idx, :, :]  # [B, C-1, L, d_model]
            V_i = V[:, other_idx, :, :]  # [B, C-1, L, d_model]

            # Flatten into a single token dimension
            K_i = K_i.reshape(B, (C-1)*L, self.d_model)  # [B, (C-1)*L, d]
            V_i = V_i.reshape(B, (C-1)*L, self.d_model)

            # Query: broadcast to the batch and add the per-channel offset
            Q_i = base_q.unsqueeze(0).expand(B, self.pred_len, self.d_model)  # [B, pred_len, d_model]
            Q_i = Q_i + self.channel_pos[ci].unsqueeze(0).unsqueeze(0)

            if self.use_ln:
                Q_i = self.ln_q(Q_i)

            # Split into heads
            Qh = self.split_heads(Q_i)  # [B, H, pred_len, head_dim]
            Kh = self.split_heads(K_i)  # [B, H, (C-1)*L, head_dim]
            Vh = self.split_heads(V_i)  # [B, H, (C-1)*L, head_dim]

            # Attention: [B, H, pred_len, (C-1)*L]
            scores = torch.matmul(Qh, Kh.transpose(-2, -1)) / math.sqrt(self.head_dim)
            attn = torch.softmax(scores, dim=-1)
            attn = self.attn_dropout(attn)

            # Context
            ctx = torch.matmul(attn, Vh)  # [B, H, pred_len, head_dim]
            ctx = self.merge_heads(ctx)   # [B, pred_len, d_model]
            ctx = self.proj_dropout(ctx)

            # Compress the output to one scalar per step: [B, pred_len, 1]
            out_ci = self.out_proj(ctx)
            cross_outputs.append(out_ci)  # list of [B, pred_len, 1]

        cross_out = torch.cat(cross_outputs, dim=-1)  # [B, pred_len, C]
        return cross_out
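
A minimal usage sketch (not part of the committed file; the import path and all shapes below are illustrative assumptions):

import torch
from layers.cross_channel_attn import CrossChannelAttention

# B=2 samples, L=96 history steps, C=8 channels, forecast horizon 24
attn = CrossChannelAttention(seq_len=96, pred_len=24, c_in=8)
x = torch.randn(2, 96, 8)   # [B, L, C] dummy batch
out = attn(x)               # [B, pred_len, C]
print(out.shape)            # torch.Size([2, 24, 8])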