import math

import torch
from torch import nn


class CrossChannelAttention(nn.Module):
    """
    For each channel i:
        Query:     learnable vectors for the pred_len prediction steps (independent of the history)
        Key/Value: scalars at every time step of the *other* channels -> linear projection
    Output: [B, pred_len, C]
    Complexity: O(C * pred_len * (C-1) * L) attention scores, usually acceptable when C is small (e.g. <= 64)
    """

    def __init__(self, seq_len, pred_len, c_in, d_model=64, n_heads=4, dropout=0.1, use_layernorm=True):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.seq_len = seq_len
        self.pred_len = pred_len
        self.c_in = c_in
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads

        # Learnable query embeddings for the prediction steps: [pred_len, d_model]
        self.query_embed = nn.Parameter(torch.randn(pred_len, d_model))
        # Project each scalar value to d_model (one projection each for Key and Value)
        self.key_proj = nn.Linear(1, d_model)
        self.value_proj = nn.Linear(1, d_model)
        # Collapse the attention output back to a scalar
        self.out_proj = nn.Linear(d_model, 1)

        self.attn_dropout = nn.Dropout(dropout)
        self.proj_dropout = nn.Dropout(dropout)

        self.use_ln = use_layernorm
        if use_layernorm:
            self.ln_q = nn.LayerNorm(d_model)
            self.ln_kv = nn.LayerNorm(d_model)

        # Optional time + channel positional encodings (simple learnable vectors)
        self.time_pos = nn.Parameter(torch.zeros(seq_len, d_model))
        self.channel_pos = nn.Parameter(torch.zeros(c_in, d_model))
        nn.init.normal_(self.time_pos, std=0.02)
        nn.init.normal_(self.channel_pos, std=0.02)
        nn.init.normal_(self.query_embed, std=0.02)

    def split_heads(self, x):
        # x: [B, T, d_model] -> [B, n_heads, T, head_dim]
        B, T, D = x.shape
        return x.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)

    def merge_heads(self, x):
        # x: [B, n_heads, T, head_dim] -> [B, T, d_model]
        B, H, T, Hd = x.shape
        return x.transpose(1, 2).contiguous().view(B, T, H * Hd)

    def forward(self, x):
        """
        x: [B, L, C]
        Returns: cross_out [B, pred_len, C]
        """
        B, L, C = x.shape
        assert L == self.seq_len and C == self.c_in, "input shape must match (seq_len, c_in)"

        # Prepare K, V: project each channel's time series.
        # Reshape to [B, C, L, 1] -> Linear -> [B, C, L, d_model]
        xc = x.permute(0, 2, 1).unsqueeze(-1)   # [B, C, L, 1]
        K = self.key_proj(xc)                   # [B, C, L, d_model]
        V = self.value_proj(xc)                 # [B, C, L, d_model]

        # Add positional encodings (channel + time).
        # Broadcast: time_pos [L, d_model] -> [1, 1, L, d]; channel_pos [C, d_model] -> [1, C, 1, d]
        K = K + self.time_pos.unsqueeze(0).unsqueeze(0) + self.channel_pos.unsqueeze(0).unsqueeze(2)
        V = V + self.time_pos.unsqueeze(0).unsqueeze(0) + self.channel_pos.unsqueeze(0).unsqueeze(2)
        if self.use_ln:
            K = self.ln_kv(K)
            V = self.ln_kv(V)

        cross_outputs = []
        # Base queries (shared across channels; a per-channel offset is added below)
        base_q = self.query_embed  # [pred_len, d_model]

        for ci in range(C):
            if C == 1:
                # Degenerate single-channel case: there are no other channels to attend to,
                # so fall back to the channel's own last pred_len values (assumes pred_len <= seq_len).
                fallback = x[:, -self.pred_len:, ci:ci + 1]
                cross_outputs.append(fallback)
                continue

            # Indices of all channels except the current one
            other_idx = [j for j in range(C) if j != ci]
            K_i = K[:, other_idx, :, :]  # [B, C-1, L, d_model]
            V_i = V[:, other_idx, :, :]  # [B, C-1, L, d_model]

            # Flatten (channel, time) into a single token dimension
            K_i = K_i.reshape(B, (C - 1) * L, self.d_model)  # [B, (C-1)*L, d_model]
            V_i = V_i.reshape(B, (C - 1) * L, self.d_model)

            # Query: broadcast to the batch and add the current channel's offset (optional)
            Q_i = base_q.unsqueeze(0).expand(B, self.pred_len, self.d_model)  # [B, pred_len, d_model]
            Q_i = Q_i + self.channel_pos[ci].unsqueeze(0).unsqueeze(0)
            if self.use_ln:
                Q_i = self.ln_q(Q_i)

            # Split heads
            Qh = self.split_heads(Q_i)  # [B, H, pred_len, head_dim]
            Kh = self.split_heads(K_i)  # [B, H, (C-1)*L, head_dim]
            Vh = self.split_heads(V_i)  # [B, H, (C-1)*L, head_dim]

            # Attention scores: [B, H, pred_len, (C-1)*L]
            scores = torch.matmul(Qh, Kh.transpose(-2, -1)) / math.sqrt(self.head_dim)
            attn = torch.softmax(scores, dim=-1)
            attn = self.attn_dropout(attn)

            # Context vectors
            ctx = torch.matmul(attn, Vh)  # [B, H, pred_len, head_dim]
            ctx = self.merge_heads(ctx)   # [B, pred_len, d_model]
            ctx = self.proj_dropout(ctx)

            # Collapse to a scalar per prediction step: [B, pred_len, 1]
            out_ci = self.out_proj(ctx)
            cross_outputs.append(out_ci)

        # Concatenate the per-channel outputs: list of [B, pred_len, 1] -> [B, pred_len, C]
        cross_out = torch.cat(cross_outputs, dim=-1)
        return cross_out
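

# Minimal usage sketch: the shapes below (batch=8, seq_len=96, pred_len=24, c_in=7)
# are arbitrary example values chosen only to exercise the forward pass and check
# the output shape; they are not prescribed by the module.
if __name__ == "__main__":
    B, L, P, C = 8, 96, 24, 7
    layer = CrossChannelAttention(seq_len=L, pred_len=P, c_in=C, d_model=64, n_heads=4)
    x = torch.randn(B, L, C)   # [B, L, C] multivariate history
    y = layer(x)               # [B, pred_len, C] cross-channel forecast component
    print(y.shape)             # expected: torch.Size([8, 24, 7])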