tsmodel/layers/cross_channel_attn.py

import torch
from torch import nn
import math


class CrossChannelAttention(nn.Module):
    """
    For each channel i:
        Query: a learnable vector per prediction step (length pred_len), independent of past values
        Key/Value: the scalar values of every other channel at every time step -> linear projection
    Output: [B, pred_len, C]
    Complexity: O(C * pred_len * (C-1) * L), usually acceptable when C is small (e.g. <= 64)
    """

    def __init__(self, seq_len, pred_len, c_in,
                 d_model=64, n_heads=4,
                 dropout=0.1, use_layernorm=True):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.seq_len = seq_len
        self.pred_len = pred_len
        self.c_in = c_in
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        # Learnable query embeddings, one per prediction step: [pred_len, d_model]
        self.query_embed = nn.Parameter(torch.randn(pred_len, d_model))
        # Project each scalar value to d_model (separate projections for Key and Value)
        self.key_proj = nn.Linear(1, d_model)
        self.value_proj = nn.Linear(1, d_model)
        # Compress the attention output back to a scalar
        self.out_proj = nn.Linear(d_model, 1)
        self.attn_dropout = nn.Dropout(dropout)
        self.proj_dropout = nn.Dropout(dropout)
        self.use_ln = use_layernorm
        if use_layernorm:
            self.ln_q = nn.LayerNorm(d_model)
            self.ln_kv = nn.LayerNorm(d_model)
        # Optional time + channel positional encodings (simple learnable vectors)
        self.time_pos = nn.Parameter(torch.zeros(seq_len, d_model))
        self.channel_pos = nn.Parameter(torch.zeros(c_in, d_model))
        nn.init.normal_(self.time_pos, std=0.02)
        nn.init.normal_(self.channel_pos, std=0.02)
        nn.init.normal_(self.query_embed, std=0.02)
    def split_heads(self, x):
        # x: [B, T, d_model] -> [B, n_heads, T, head_dim]
        B, T, D = x.shape
        return x.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)

    def merge_heads(self, x):
        # x: [B, n_heads, T, head_dim] -> [B, T, d_model]
        B, H, T, Hd = x.shape
        return x.transpose(1, 2).contiguous().view(B, T, H * Hd)
    def forward(self, x):
        """
        x: [B, L, C]
        Returns: cross_out [B, pred_len, C]
        """
        B, L, C = x.shape
        assert L == self.seq_len and C == self.c_in
        # Build K, V: project every channel's time series
        # Reshape to [B, C, L, 1] -> Linear -> [B, C, L, d_model]
        xc = x.permute(0, 2, 1).unsqueeze(-1)  # [B, C, L, 1]
        K = self.key_proj(xc)    # [B, C, L, d_model]
        V = self.value_proj(xc)  # [B, C, L, d_model]
        # Add positional encodings (channel + time)
        # broadcast: time_pos [L, d_model] -> [1, 1, L, d]; channel_pos [C, d_model] -> [1, C, 1, d]
        K = K + self.time_pos.unsqueeze(0).unsqueeze(0) + self.channel_pos.unsqueeze(0).unsqueeze(2)
        V = V + self.time_pos.unsqueeze(0).unsqueeze(0) + self.channel_pos.unsqueeze(0).unsqueeze(2)
        if self.use_ln:
            K = self.ln_kv(K)
            V = self.ln_kv(V)
        cross_outputs = []
        # Prepare the Query (all channels share the base query embeddings; a per-channel offset is added below)
        base_q = self.query_embed  # [pred_len, d_model]
        for ci in range(C):
            if C == 1:
                # Single-channel degenerate case: there is no other channel to attend to,
                # so just copy the channel's last pred_len values (assumes pred_len <= seq_len)
                last_vals = x[:, -self.pred_len:, ci:ci + 1]
                cross_outputs.append(last_vals)
                continue
            # Indices of all other channels
            other_idx = [j for j in range(C) if j != ci]
            K_i = K[:, other_idx, :, :]  # [B, C-1, L, d_model]
            V_i = V[:, other_idx, :, :]  # [B, C-1, L, d_model]
            # Flatten the other channels into one token dimension
            K_i = K_i.reshape(B, (C - 1) * L, self.d_model)  # [B, (C-1)*L, d]
            V_i = V_i.reshape(B, (C - 1) * L, self.d_model)
            # Query: broadcast to the batch and add the per-channel offset
            Q_i = base_q.unsqueeze(0).expand(B, self.pred_len, self.d_model)  # [B, pred_len, d_model]
            Q_i = Q_i + self.channel_pos[ci].unsqueeze(0).unsqueeze(0)
            if self.use_ln:
                Q_i = self.ln_q(Q_i)
            # Split into heads
            Qh = self.split_heads(Q_i)  # [B, H, pred_len, head_dim]
            Kh = self.split_heads(K_i)  # [B, H, (C-1)*L, head_dim]
            Vh = self.split_heads(V_i)  # [B, H, (C-1)*L, head_dim]
            # Attention scores: [B, H, pred_len, (C-1)*L]
            scores = torch.matmul(Qh, Kh.transpose(-2, -1)) / math.sqrt(self.head_dim)
            attn = torch.softmax(scores, dim=-1)
            attn = self.attn_dropout(attn)
            # Context vectors
            ctx = torch.matmul(attn, Vh)  # [B, H, pred_len, head_dim]
            ctx = self.merge_heads(ctx)   # [B, pred_len, d_model]
            ctx = self.proj_dropout(ctx)
            # Compress the output back to a scalar per step: [B, pred_len, 1]
            out_ci = self.out_proj(ctx)
            cross_outputs.append(out_ci)  # list of [B, pred_len, 1]
        cross_out = torch.cat(cross_outputs, dim=-1)  # [B, pred_len, C]
        return cross_out
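

# A minimal smoke-test sketch (not part of the original file): it instantiates the layer with
# illustrative sizes and checks that the forward pass returns the documented [B, pred_len, C] shape.
# The concrete numbers (batch 2, seq_len 96, pred_len 24, c_in 7) are assumptions chosen for the
# example, not values taken from the project's configs.
if __name__ == "__main__":
    B, L, P, C = 2, 96, 24, 7
    layer = CrossChannelAttention(seq_len=L, pred_len=P, c_in=C, d_model=64, n_heads=4)
    x = torch.randn(B, L, C)  # dummy input: [batch, history length, channels]
    out = layer(x)            # cross-channel forecast: [batch, pred_len, channels]
    assert out.shape == (B, P, C), out.shape
    print("cross_out shape:", tuple(out.shape))  # expected: (2, 24, 7)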