314 lines
13 KiB
Python
314 lines
13 KiB
Python
import torch
|
||
import torch.nn as nn
|
||
import numpy as np
|
||
from math import sqrt
|
||
from utils.masking import TriangularCausalMask, ProbMask
|
||
from reformer_pytorch import LSHSelfAttention
|
||
from einops import rearrange, repeat
|
||
|
||
|
||
class DSAttention(nn.Module):
    """De-stationary Attention.

    Scaled dot-product attention whose raw scores are rescaled by a
    per-sample factor ``tau`` and shifted by a per-key offset ``delta``
    before masking and softmax.

    Args:
        mask_flag: apply a causal mask (a ``TriangularCausalMask`` is built
            when no ``attn_mask`` is supplied).
        factor: unused; kept so the signature matches the other attention
            classes in this module.
        scale: score multiplier before softmax; defaults to ``1 / sqrt(E)``.
        attention_dropout: dropout probability on the attention weights.
        output_attention: if True, ``forward`` also returns the weights.
    """

    def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False):
        super(DSAttention, self).__init__()
        self.scale = scale
        self.mask_flag = mask_flag
        self.output_attention = output_attention
        self.dropout = nn.Dropout(attention_dropout)

    def forward(self, queries, keys, values, attn_mask, tau=None, delta=None, key_padding_mask=None):
        """Compute de-stationary attention.

        Args:
            queries: (B, L, H, E).
            keys: (B, S, H, E).
            values: (B, S, H, D).
            attn_mask: object exposing a boolean ``.mask`` (True = masked)
                broadcastable to (B, H, L, S), or None.
            tau: optional per-sample scale, reshaped to B x 1 x 1 x 1.
            delta: optional per-key shift, reshaped to B x 1 x 1 x S.
            key_padding_mask: optional (B, S) mask, True = valid key,
                False = padding. Any dtype is accepted; it is converted to
                bool, and padded keys receive zero attention weight.

        Returns:
            (V, A): V is (B, L, H, D); A is the (B, H, L, S) attention
            weights if ``output_attention`` else None. Query rows whose keys
            are all masked get all-zero weights (and a zero output) instead
            of NaN.
        """
        B, L, H, E = queries.shape
        _, S, _, D = values.shape
        scale = self.scale or 1. / sqrt(E)

        tau = 1.0 if tau is None else tau.unsqueeze(1).unsqueeze(1)  # B x 1 x 1 x 1
        delta = 0.0 if delta is None else delta.unsqueeze(1).unsqueeze(1)  # B x 1 x 1 x S

        scores = torch.einsum("blhe,bshe->bhls", queries, keys) * tau + delta  # (B, H, L, S)

        if self.mask_flag:
            if attn_mask is None:
                attn_mask = TriangularCausalMask(B, L, device=queries.device)
            scores.masked_fill_(attn_mask.mask, -np.inf)

        if key_padding_mask is not None:
            # Convert to bool on the scores' device: on an integer/float 0-1
            # mask, `~` would be a bitwise NOT (or an error), not a logical one.
            valid_k = key_padding_mask.to(device=scores.device, dtype=torch.bool)
            scores = scores.masked_fill(~valid_k[:, None, None, :], -np.inf)  # (B,1,1,S) broadcast

        A = torch.softmax(scale * scores, dim=-1)
        if self.mask_flag or key_padding_mask is not None:
            # softmax over a row that is entirely -inf is NaN; zero such rows
            # so the NaN does not propagate through V and later layers.
            fully_masked = torch.isneginf(scores).all(dim=-1, keepdim=True)
            A = A.masked_fill(fully_masked, 0.0)
        A = self.dropout(A)
        V = torch.einsum("bhls,bshd->blhd", A, values)  # (B, L, H, D)

        if self.output_attention:
            return V.contiguous(), A
        return V.contiguous(), None
|
||
|
||
|
||
class FullAttention(nn.Module):
    """Standard scaled dot-product attention with optional key padding mask.

    Compared with the plain version, ``forward`` accepts a ``key_padding_mask``
    that masks right-padded keys inside a batch of variable-length sequences.
    Convention: shape (B, S), True = valid, False = padding. Without it the
    behavior is the usual full (optionally causal) attention.

    Args:
        mask_flag: apply a causal mask (a ``TriangularCausalMask`` is built
            when no ``attn_mask`` is supplied).
        factor: unused; kept so the signature matches the other attention
            classes in this module.
        scale: score multiplier before softmax; defaults to ``1 / sqrt(E)``.
        attention_dropout: dropout probability on the attention weights.
        output_attention: if True, ``forward`` also returns the weights.
    """

    def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False):
        super(FullAttention, self).__init__()
        self.scale = scale
        self.mask_flag = mask_flag
        self.output_attention = output_attention
        self.dropout = nn.Dropout(attention_dropout)

    def forward(self, queries, keys, values, attn_mask, tau=None, delta=None, key_padding_mask=None):
        """Compute multi-head scaled dot-product attention.

        Args:
            queries: (B, L, H, E).
            keys: (B, S, H, E).
            values: (B, S, H, D).
            attn_mask: object exposing a boolean ``.mask`` (True = masked)
                broadcastable to (B, H, L, S), e.g. TriangularCausalMask,
                or None.
            tau, delta: unused; accepted for interface compatibility.
            key_padding_mask: optional (B, S) mask, True = valid key,
                False = padding. Any dtype is accepted; it is converted to
                bool, and padded keys receive zero attention weight.

        Returns:
            (V, A): V is (B, L, H, D); A is the (B, H, L, S) attention
            weights if ``output_attention`` else None. Query rows whose keys
            are all masked get all-zero weights (and a zero output) instead
            of NaN.
        """
        B, L, H, E = queries.shape
        _, S, _, D = values.shape
        scale = self.scale or 1. / sqrt(E)

        scores = torch.einsum("blhe,bshe->bhls", queries, keys)  # (B, H, L, S)

        if self.mask_flag:
            if attn_mask is None:
                attn_mask = TriangularCausalMask(B, L, device=queries.device)
            scores.masked_fill_(attn_mask.mask, -np.inf)

        if key_padding_mask is not None:
            # Convert to bool on the scores' device: on an integer/float 0-1
            # mask, `~` would be a bitwise NOT (or an error), not a logical one.
            valid_k = key_padding_mask.to(device=scores.device, dtype=torch.bool)
            scores = scores.masked_fill(~valid_k[:, None, None, :], -np.inf)  # (B,1,1,S) broadcast

        A = torch.softmax(scale * scores, dim=-1)  # (B, H, L, S)
        if self.mask_flag or key_padding_mask is not None:
            # softmax over a row that is entirely -inf is NaN; zero such rows
            # so the NaN does not propagate through V and later layers.
            fully_masked = torch.isneginf(scores).all(dim=-1, keepdim=True)
            A = A.masked_fill(fully_masked, 0.0)
        A = self.dropout(A)
        V = torch.einsum("bhls,bshd->blhd", A, values)  # (B, L, H, D)

        if self.output_attention:
            return V.contiguous(), A
        return V.contiguous(), None
|
||
|
||
|
||
class ProbAttention(nn.Module):
    """ProbSparse attention.

    Only the ``u = factor * ceil(ln L_Q)`` queries with the largest sparsity
    measurement (estimated from ``U_part = factor * ceil(ln L_K)`` randomly
    sampled keys) get full softmax attention; the remaining "lazy" queries
    keep an initial context: the mean of V (non-causal) or the cumulative sum
    of V (causal).

    Args:
        mask_flag: causal mode (uses ``ProbMask`` and requires L_Q == L_K).
        factor: sampling factor ``c`` in ``c * ln(L)``.
        scale: score multiplier before softmax; defaults to ``1 / sqrt(D)``.
        attention_dropout: stored as ``self.dropout`` but not applied by this
            class.
        output_attention: if True, ``forward`` also returns a dense
            (B, H, L_Q, L_K) attention map.
    """

    def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False):
        super(ProbAttention, self).__init__()
        self.factor = factor
        self.scale = scale
        self.mask_flag = mask_flag
        self.output_attention = output_attention
        self.dropout = nn.Dropout(attention_dropout)

    def _prob_QK(self, Q, K, sample_k, n_top):  # n_top: c*ln(L_q)
        """Select the ``n_top`` most active queries and score them against all keys.

        Args:
            Q: (B, H, L_Q, E) queries.
            K: (B, H, L_K, E) keys.
            sample_k: number of keys sampled (with replacement) per query to
                estimate the sparsity measurement M.
            n_top: number of queries to keep.

        Returns:
            (Q_K, M_top): scores (B, H, n_top, L_K) and the indices
            (B, H, n_top) of the selected queries.
        """
        B, H, L_K, E = K.shape
        _, _, L_Q, _ = Q.shape
        device = Q.device

        # The same sampled key indices are used for every batch element and head.
        K_expand = K.unsqueeze(-3).expand(B, H, L_Q, L_K, E)
        index_sample = torch.randint(L_K, (L_Q, sample_k), device=device)
        K_sample = K_expand[:, :, torch.arange(L_Q, device=device).unsqueeze(1), index_sample, :]
        Q_K_sample = torch.matmul(Q.unsqueeze(-2), K_sample.transpose(-2, -1)).squeeze(-2)  # (B, H, L_Q, sample_k)

        # Sparsity measurement: max of the sampled scores minus their sum / L_K.
        M = Q_K_sample.max(-1)[0] - torch.div(Q_K_sample.sum(-1), L_K)  # (B, H, L_Q)
        M_top = M.topk(n_top, sorted=False)[1]  # indices of the active queries

        batch_idx = torch.arange(B, device=device)[:, None, None]
        head_idx = torch.arange(H, device=device)[None, :, None]
        Q_reduce = Q[batch_idx, head_idx, M_top, :]  # (B, H, n_top, E)
        Q_K = torch.matmul(Q_reduce, K.transpose(-2, -1))  # (B, H, n_top, L_K)
        return Q_K, M_top

    def _get_initial_context(self, V, L_Q):
        """Build the context every query starts with, shape (B, H, L_Q, D).

        Non-causal: the mean of V over keys, broadcast to all L_Q rows.
        Causal: ``V.cumsum`` along the sequence (requires L_Q == L_V).
        NOTE(review): the causal case is a cumulative *sum*, not a running
        mean; kept unchanged.
        """
        B, H, L_V, D = V.shape
        if not self.mask_flag:
            V_mean = V.mean(dim=-2)  # (B, H, D)
            context = V_mean.unsqueeze(-2).expand(B, H, L_Q, D).clone()
        else:
            assert L_Q == L_V
            context = V.cumsum(dim=-2)
        return context

    def _update_context(self, context_in, V, scores, index, L_Q, attn_mask):
        """Overwrite the context rows of the selected queries with their attention output.

        Args:
            context_in: (B, H, L_Q, D) initial context; modified in place.
            V: (B, H, L_V, D) values.
            scores: (B, H, n_top, L_V) scaled scores of the selected queries;
                modified in place when ``mask_flag`` is set.
            index: (B, H, n_top) indices of the selected queries.
            L_Q: number of queries.
            attn_mask: unused; a ``ProbMask`` is built instead in causal mode.

        Returns:
            (context, attns): attns is (B, H, L_Q, L_V) when
            ``output_attention``, with placeholder rows of 1 / L_V for lazy
            queries; otherwise None.
        """
        B, H, L_V, D = V.shape
        if self.mask_flag:
            prob_mask = ProbMask(B, H, L_Q, index, scores, device=V.device)
            scores.masked_fill_(prob_mask.mask, -np.inf)
        attn = torch.softmax(scores, dim=-1)  # (B, H, n_top, L_V)

        batch_idx = torch.arange(B, device=V.device)[:, None, None]
        head_idx = torch.arange(H, device=V.device)[None, :, None]
        context_in[batch_idx, head_idx, index, :] = torch.matmul(attn, V).type_as(context_in)

        if not self.output_attention:
            return context_in, None
        # One row per *query*: L_Q and L_V differ in cross-attention, and the
        # selected indices range over L_Q.
        attns = torch.ones([B, H, L_Q, L_V], device=attn.device, dtype=attn.dtype) / L_V
        attns[batch_idx, head_idx, index, :] = attn
        return context_in, attns

    def forward(self, queries, keys, values, attn_mask, tau=None, delta=None, key_padding_mask=None):
        """Compute ProbSparse attention.

        Args:
            queries: (B, L_Q, H, D).
            keys: (B, L_K, H, D).
            values: (B, L_K, H, D).
            attn_mask: passed through to ``_update_context`` (unused there).
            tau, delta: unused; accepted for interface compatibility.
            key_padding_mask: accepted for interface compatibility but NOT
                applied. Padded keys still take part in key sampling, the
                lazy-query context and the top-query softmax.

        Returns:
            (context, attn): context is (B, L_Q, H, D), the same layout as
            FullAttention, so ``AttentionLayer`` can ``view(B, L, -1)`` it.
            attn is (B, H, L_Q, L_K) or None.
        """
        B, L_Q, H, D = queries.shape
        _, L_K, _, _ = keys.shape

        queries = queries.transpose(2, 1)  # (B, H, L_Q, D)
        keys = keys.transpose(2, 1)  # (B, H, L_K, D)
        values = values.transpose(2, 1)  # (B, H, L_K, D)

        U_part = min(self.factor * int(np.ceil(np.log(L_K))), L_K)  # c*ln(L_k)
        u = min(self.factor * int(np.ceil(np.log(L_Q))), L_Q)  # c*ln(L_q)
        # ln(1) == 0 would sample zero keys, and max() over an empty dimension
        # raises; always sample at least one key.
        U_part = max(U_part, 1)

        scores_top, index = self._prob_QK(queries, keys, sample_k=U_part, n_top=u)

        scale = self.scale or 1. / sqrt(D)
        scores_top = scores_top * scale

        context = self._get_initial_context(values, L_Q)
        context, attn = self._update_context(context, values, scores_top, index, L_Q, attn_mask)

        # (B, H, L_Q, D) -> (B, L_Q, H, D): without this transpose the caller's
        # view(B, L, -1) interleaves time steps across heads.
        return context.transpose(2, 1).contiguous(), attn
|
||
|
||
|
||
class AttentionLayer(nn.Module):
    """Multi-head wrapper around an inner attention module.

    Projects queries/keys/values into ``n_heads`` heads, delegates the
    attention computation to ``attention`` and projects the concatenated
    head outputs back to ``d_model``.

    ``forward`` accepts an optional ``key_padding_mask`` and passes it
    unchanged to the inner attention. Callers that omit it keep working,
    because it defaults to None.
    """

    def __init__(self, attention, d_model, n_heads, d_keys=None, d_values=None):
        super(AttentionLayer, self).__init__()

        # Per-head widths default to an even split of d_model.
        head_k = d_keys or (d_model // n_heads)
        head_v = d_values or (d_model // n_heads)

        self.inner_attention = attention
        self.query_projection = nn.Linear(d_model, head_k * n_heads)
        self.key_projection = nn.Linear(d_model, head_k * n_heads)
        self.value_projection = nn.Linear(d_model, head_v * n_heads)
        self.out_projection = nn.Linear(head_v * n_heads, d_model)
        self.n_heads = n_heads

    def forward(self, queries, keys, values, attn_mask, tau=None, delta=None, key_padding_mask=None):
        """Project, attend per head, and merge heads.

        Args:
            queries: (B, L, d_model).
            keys: (B, S, d_model).
            values: (B, S, d_model).
            attn_mask, tau, delta: forwarded to the inner attention.
            key_padding_mask: optional (B, S) bool mask, True = valid,
                False = padding; forwarded to the inner attention.

        Returns:
            (out, attn): out is (B, L, d_model); attn is whatever the inner
            attention returns as its second value.
        """
        batch, q_len, _ = queries.shape
        _, k_len, _ = keys.shape
        heads = self.n_heads

        def split_heads(projection, x, length):
            # (batch, length, d_model) -> (batch, length, heads, head_dim)
            return projection(x).view(batch, length, heads, -1)

        out, attn = self.inner_attention(
            split_heads(self.query_projection, queries, q_len),
            split_heads(self.key_projection, keys, k_len),
            split_heads(self.value_projection, values, k_len),
            attn_mask,
            tau=tau,
            delta=delta,
            key_padding_mask=key_padding_mask,
        )
        merged = out.view(batch, q_len, -1)
        return self.out_projection(merged), attn
|
||
|
||
|
||
class ReformerLayer(nn.Module):
    """LSH self-attention (reformer_pytorch) behind the AttentionLayer call signature.

    ``attention``, ``d_keys`` and ``d_values`` are accepted only for interface
    compatibility and are unused. In ``forward`` only ``queries`` is used:
    keys and values are assumed to equal the queries (self-attention).
    """

    def __init__(self, attention, d_model, n_heads, d_keys=None,
                 d_values=None, causal=False, bucket_size=4, n_hashes=4):
        super().__init__()
        self.bucket_size = bucket_size
        self.attn = LSHSelfAttention(
            dim=d_model,
            heads=n_heads,
            bucket_size=bucket_size,
            n_hashes=n_hashes,
            causal=causal
        )

    def fit_length(self, queries):
        """Right-pad ``queries`` (B, N, C) with zeros to a multiple of ``2 * bucket_size``.

        Inside reformer: assert N % (bucket_size * 2) == 0. Inputs that are
        already aligned are returned unchanged (same object).
        """
        B, N, C = queries.shape
        chunk = self.bucket_size * 2
        remainder = N % chunk
        if remainder == 0:
            return queries
        # new_zeros matches queries' dtype and device. A default float32 tensor
        # would make torch.cat silently up-cast half-precision inputs.
        pad = queries.new_zeros((B, chunk - remainder, C))
        return torch.cat([queries, pad], dim=1)

    def forward(self, queries, keys, values, attn_mask, tau=None, delta=None, key_padding_mask=None):
        """Run LSH self-attention on ``queries`` and trim the padding.

        Args:
            queries: (B, N, C); queries=keys in Reformer.
            keys, values, attn_mask, tau, delta, key_padding_mask: accepted
                for interface compatibility; unused. NOTE(review): the zero
                padding added by ``fit_length`` is not masked out.

        Returns:
            (out, None) with out of shape (B, N, C).
        """
        B, N, C = queries.shape
        out = self.attn(self.fit_length(queries))[:, :N, :]
        return out, None
|
||
|
||
|
||
class TwoStageAttentionLayer(nn.Module):
    """The Two Stage Attention (TSA) layer.

    Input/output shape: ``[batch_size, Data_dim (D), Seg_num (L), d_model]``.

    Stage 1 (cross-time) applies full self-attention over the segments of
    each data dimension independently. Stage 2 (cross-dimension) exchanges
    information between dimensions through ``factor`` learnable router
    vectors per segment. The routers attend to all dimensions
    (``dim_sender``), then each dimension attends to the routers
    (``dim_receiver``).

    Note: the three inner FullAttention modules take their factor and dropout
    from ``configs.factor`` / ``configs.dropout``. The ``factor`` argument only
    sets the number of routers, and ``dropout`` only the residual dropout.
    """

    def __init__(self, configs,
                 seg_num, factor, d_model, n_heads, d_ff=None, dropout=0.1):
        super(TwoStageAttentionLayer, self).__init__()
        d_ff = d_ff or 4 * d_model
        self.time_attention = AttentionLayer(FullAttention(False, configs.factor, attention_dropout=configs.dropout,
                                                           output_attention=False), d_model, n_heads)
        self.dim_sender = AttentionLayer(FullAttention(False, configs.factor, attention_dropout=configs.dropout,
                                                       output_attention=False), d_model, n_heads)
        self.dim_receiver = AttentionLayer(FullAttention(False, configs.factor, attention_dropout=configs.dropout,
                                                         output_attention=False), d_model, n_heads)
        self.router = nn.Parameter(torch.randn(seg_num, factor, d_model))

        self.dropout = nn.Dropout(dropout)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.norm4 = nn.LayerNorm(d_model)

        self.MLP1 = nn.Sequential(nn.Linear(d_model, d_ff),
                                  nn.GELU(),
                                  nn.Linear(d_ff, d_model))
        self.MLP2 = nn.Sequential(nn.Linear(d_model, d_ff),
                                  nn.GELU(),
                                  nn.Linear(d_ff, d_model))

    def forward(self, x, attn_mask=None, tau=None, delta=None, key_padding_mask=None):
        """Apply cross-time then cross-dimension attention.

        Args:
            x: (batch, ts_d, seg_num, d_model).
            attn_mask, tau, delta: accepted for interface compatibility; unused.
            key_padding_mask: optional mask over segments, True = valid,
                False = padding, used only in the cross-time stage. Shape
                (batch * ts_d, seg_num) in the ``(b ts_d)`` row order, or
                (batch, seg_num), which is then shared by every data dimension.

        Returns:
            Tensor with the same shape as ``x``.
        """
        batch, ts_d = x.shape[0], x.shape[1]

        # Cross Time Stage: directly apply MSA to each dimension.
        time_in = rearrange(x, 'b ts_d seg_num d_model -> (b ts_d) seg_num d_model')
        if key_padding_mask is not None and ts_d > 1 and key_padding_mask.shape[0] == batch:
            # Expand a per-sample mask to match the (b ts_d) row order above.
            key_padding_mask = key_padding_mask.repeat_interleave(ts_d, dim=0)
        time_enc, _ = self.time_attention(
            time_in, time_in, time_in, attn_mask=None, tau=None, delta=None, key_padding_mask=key_padding_mask
        )
        dim_in = time_in + self.dropout(time_enc)
        dim_in = self.norm1(dim_in)
        dim_in = dim_in + self.dropout(self.MLP1(dim_in))
        dim_in = self.norm2(dim_in)

        # Cross Dimension Stage: routers gather from all dimensions, then
        # every dimension reads back from the routers.
        dim_send = rearrange(dim_in, '(b ts_d) seg_num d_model -> (b seg_num) ts_d d_model', b=batch)
        batch_router = repeat(self.router, 'seg_num factor d_model -> (repeat seg_num) factor d_model', repeat=batch)
        dim_buffer, _ = self.dim_sender(batch_router, dim_send, dim_send, attn_mask=None, tau=None, delta=None)
        dim_receive, _ = self.dim_receiver(dim_send, dim_buffer, dim_buffer, attn_mask=None, tau=None, delta=None)
        dim_enc = dim_send + self.dropout(dim_receive)
        dim_enc = self.norm3(dim_enc)
        dim_enc = dim_enc + self.dropout(self.MLP2(dim_enc))
        dim_enc = self.norm4(dim_enc)

        final_out = rearrange(dim_enc, '(b seg_num) ts_d d_model -> b ts_d seg_num d_model', b=batch)

        return final_out
|