# TSlib/layers/SelfAttention_Family.py

import torch
import torch.nn as nn
import numpy as np
from math import sqrt
from utils.masking import TriangularCausalMask, ProbMask
from reformer_pytorch import LSHSelfAttention
from einops import rearrange, repeat
class DSAttention(nn.Module):
'''De-stationary Attention'''
def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False):
super(DSAttention, self).__init__()
self.scale = scale
self.mask_flag = mask_flag
self.output_attention = output_attention
self.dropout = nn.Dropout(attention_dropout)
def forward(self, queries, keys, values, attn_mask, tau=None, delta=None, key_padding_mask=None):
"""
key_padding_mask: (B, S) bool, True=valid, False=pad可选忽略或由上层应用
"""
B, L, H, E = queries.shape
_, S, _, D = values.shape
scale = self.scale or 1. / sqrt(E)
tau = 1.0 if tau is None else tau.unsqueeze(1).unsqueeze(1) # B x 1 x 1 x 1
delta = 0.0 if delta is None else delta.unsqueeze(1).unsqueeze(1) # B x 1 x 1 x S
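        # De-stationary rescaling (Non-stationary Transformer): tau scales the raw
        # attention scores and delta shifts them to compensate for per-series normalization.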
scores = torch.einsum("blhe,bshe->bhls", queries, keys) * tau + delta # (B,H,L,S)
if self.mask_flag:
if attn_mask is None:
attn_mask = TriangularCausalMask(B, L, device=queries.device)
scores.masked_fill_(attn_mask.mask, -np.inf)
        # Optional masking of invalid keys via key_padding_mask; the default (None) keeps the original behavior
if key_padding_mask is not None:
            # key_padding_mask: True = valid, False = padding
invalid_k = (~key_padding_mask).unsqueeze(1).unsqueeze(1) # (B,1,1,S)
scores = scores.masked_fill(invalid_k, -np.inf)
A = self.dropout(torch.softmax(scale * scores, dim=-1))
V = torch.einsum("bhls,bshd->blhd", A, values)
if self.output_attention:
return V.contiguous(), A
else:
return V.contiguous(), None
class FullAttention(nn.Module):
"""
修正点:
- 新增 key_padding_mask 支持用于屏蔽批内右侧pad的键向量与DC变长对齐
- key_padding_mask 约定shape=(B, S)True=有效False=padding
- 其余行为与原实现保持一致
"""
def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False):
super(FullAttention, self).__init__()
self.scale = scale
self.mask_flag = mask_flag
self.output_attention = output_attention
self.dropout = nn.Dropout(attention_dropout)
def forward(self, queries, keys, values, attn_mask, tau=None, delta=None, key_padding_mask=None):
"""
queries: (B, L, H, E)
keys: (B, S, H, E)
values: (B, S, H, D)
attn_mask: TriangularCausalMask 或 None
key_padding_mask: (B, S) boolTrue=有效False=padding可选
"""
B, L, H, E = queries.shape
_, S, _, D = values.shape
scale = self.scale or 1. / sqrt(E)
scores = torch.einsum("blhe,bshe->bhls", queries, keys) # (B,H,L,S)
if self.mask_flag:
if attn_mask is None:
attn_mask = TriangularCausalMask(B, L, device=queries.device)
scores.masked_fill_(attn_mask.mask, -np.inf)
        # Mask invalid keys via key_padding_mask: padding positions take no part in attention
if key_padding_mask is not None:
            # key_padding_mask: True = valid, False = padding
invalid_k = (~key_padding_mask).unsqueeze(1).unsqueeze(1) # (B,1,1,S)
scores = scores.masked_fill(invalid_k, -np.inf)
A = self.dropout(torch.softmax(scale * scores, dim=-1)) # (B,H,L,S)
V = torch.einsum("bhls,bshd->blhd", A, values) # (B,L,H,D)
if self.output_attention:
return V.contiguous(), A
else:
return V.contiguous(), None
class ProbAttention(nn.Module):
def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False):
super(ProbAttention, self).__init__()
self.factor = factor
self.scale = scale
self.mask_flag = mask_flag
self.output_attention = output_attention
self.dropout = nn.Dropout(attention_dropout)
def _prob_QK(self, Q, K, sample_k, n_top): # n_top: c*ln(L_q)
# Q [B, H, L_q, D], K [B, H, L_k, D]
B, H, L_K, E = K.shape
_, _, L_Q, _ = Q.shape
K_expand = K.unsqueeze(-3).expand(B, H, L_Q, L_K, E)
index_sample = torch.randint(L_K, (L_Q, sample_k), device=Q.device)
K_sample = K_expand[:, :, torch.arange(L_Q, device=Q.device).unsqueeze(1), index_sample, :]
Q_K_sample = torch.matmul(Q.unsqueeze(-2), K_sample.transpose(-2, -1)).squeeze(-2) # (B,H,L_Q,sample_k)
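        # Sparsity measurement per query (Informer's ProbSparse criterion): the max sampled
        # score minus sum / L_K; queries with a large gap deviate most from uniform attention
        # and are kept as the top-u.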
M = Q_K_sample.max(-1)[0] - torch.div(Q_K_sample.sum(-1), L_K) # (B,H,L_Q)
M_top = M.topk(n_top, sorted=False)[1] # indices
Q_reduce = Q[torch.arange(B)[:, None, None],
torch.arange(H)[None, :, None],
M_top, :] # (B,H,n_top,D)
Q_K = torch.matmul(Q_reduce, K.transpose(-2, -1)) # (B,H,n_top,L_K)
return Q_K, M_top
def _get_initial_context(self, V, L_Q):
B, H, L_V, D = V.shape
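        # Without causal masking, every query starts from the mean of V; with masking,
        # the causal cumulative sum of V serves as the fallback for unselected queries.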
if not self.mask_flag:
V_mean = V.mean(dim=-2) # (B,H,D)
context = V_mean.unsqueeze(-2).expand(B, H, L_Q, D).clone()
else:
assert L_Q == L_V
context = V.cumsum(dim=-2)
return context
def _update_context(self, context_in, V, scores, index, L_Q, attn_mask):
B, H, L_V, D = V.shape
if self.mask_flag:
attn_mask = ProbMask(B, H, L_Q, index, scores, device=V.device)
scores.masked_fill_(attn_mask.mask, -np.inf)
attn = torch.softmax(scores, dim=-1)
context_in[torch.arange(B)[:, None, None],
torch.arange(H)[None, :, None],
index, :] = torch.matmul(attn, V).type_as(context_in)
if self.output_attention:
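            # Queries not selected by _prob_QK are reported with uniform weights 1 / L_V.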
attns = (torch.ones([B, H, L_V, L_V], device=attn.device, dtype=attn.dtype) / L_V)
attns[torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], index, :] = attn
return context_in, attns
else:
return context_in, None
def forward(self, queries, keys, values, attn_mask, tau=None, delta=None, key_padding_mask=None):
"""
key_padding_mask 目前未集成到 ProbAttention如需可在scores处对无效键置 -inf
"""
B, L_Q, H, D = queries.shape
_, L_K, _, _ = keys.shape
queries = queries.transpose(2, 1) # (B,H,L_Q,D)
keys = keys.transpose(2, 1) # (B,H,L_K,D)
values = values.transpose(2, 1) # (B,H,L_K,D)
U_part = self.factor * int(np.ceil(np.log(L_K)))
u = self.factor * int(np.ceil(np.log(L_Q)))
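        # ProbSparse setup: sample U_part ~ factor * ln(L_K) keys per query and keep only the
        # top u ~ factor * ln(L_Q) queries, each clamped to the full length below.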
U_part = min(U_part, L_K)
u = min(u, L_Q)
scores_top, index = self._prob_QK(queries, keys, sample_k=U_part, n_top=u)
scale = self.scale or 1. / sqrt(D)
scores_top = scores_top * scale
context = self._get_initial_context(values, L_Q)
context, attn = self._update_context(context, values, scores_top, index, L_Q, attn_mask)
return context.contiguous(), attn
class AttentionLayer(nn.Module):
"""
修正点:
- forward 新增 key_padding_mask 参数,并向 inner_attention 透传
- 保持与旧调用兼容不传时默认None
"""
def __init__(self, attention, d_model, n_heads, d_keys=None, d_values=None):
super(AttentionLayer, self).__init__()
d_keys = d_keys or (d_model // n_heads)
d_values = d_values or (d_model // n_heads)
self.inner_attention = attention
self.query_projection = nn.Linear(d_model, d_keys * n_heads)
self.key_projection = nn.Linear(d_model, d_keys * n_heads)
self.value_projection = nn.Linear(d_model, d_values * n_heads)
self.out_projection = nn.Linear(d_values * n_heads, d_model)
self.n_heads = n_heads
def forward(self, queries, keys, values, attn_mask, tau=None, delta=None, key_padding_mask=None):
"""
key_padding_mask: (B, S) bool, True=有效False=padding
"""
B, L, _ = queries.shape
_, S, _ = keys.shape
H = self.n_heads
queries = self.query_projection(queries).view(B, L, H, -1)
keys = self.key_projection(keys).view(B, S, H, -1)
values = self.value_projection(values).view(B, S, H, -1)
out, attn = self.inner_attention(
queries,
keys,
values,
attn_mask,
tau=tau,
delta=delta,
key_padding_mask=key_padding_mask,
)
out = out.view(B, L, -1)
return self.out_projection(out), attn
class ReformerLayer(nn.Module):
def __init__(self, attention, d_model, n_heads, d_keys=None,
d_values=None, causal=False, bucket_size=4, n_hashes=4):
super().__init__()
self.bucket_size = bucket_size
self.attn = LSHSelfAttention(
dim=d_model,
heads=n_heads,
bucket_size=bucket_size,
n_hashes=n_hashes,
causal=causal
)
def fit_length(self, queries):
# inside reformer: assert N % (bucket_size * 2) == 0
B, N, C = queries.shape
if N % (self.bucket_size * 2) == 0:
return queries
else:
fill_len = (self.bucket_size * 2) - (N % (self.bucket_size * 2))
return torch.cat([queries, torch.zeros([B, fill_len, C]).to(queries.device)], dim=1)
def forward(self, queries, keys, values, attn_mask, tau, delta, key_padding_mask=None):
# queries=keys in Reformer
B, N, C = queries.shape
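        # Pad the sequence to a multiple of 2 * bucket_size, run LSH self-attention, then trim back to length N.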
queries = self.attn(self.fit_length(queries))[:, :N, :]
return queries, None
class TwoStageAttentionLayer(nn.Module):
'''
The Two Stage Attention (TSA) Layer
input/output shape: [batch_size, Data_dim(D), Seg_num(L), d_model]
'''
def __init__(self, configs,
seg_num, factor, d_model, n_heads, d_ff=None, dropout=0.1):
super(TwoStageAttentionLayer, self).__init__()
d_ff = d_ff or 4 * d_model
self.time_attention = AttentionLayer(FullAttention(False, configs.factor, attention_dropout=configs.dropout,
output_attention=False), d_model, n_heads)
self.dim_sender = AttentionLayer(FullAttention(False, configs.factor, attention_dropout=configs.dropout,
output_attention=False), d_model, n_heads)
self.dim_receiver = AttentionLayer(FullAttention(False, configs.factor, attention_dropout=configs.dropout,
output_attention=False), d_model, n_heads)
self.router = nn.Parameter(torch.randn(seg_num, factor, d_model))
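        # Learnable router tokens (seg_num x factor x d_model): the cross-dimension stage
        # exchanges information through these few tokens instead of full pairwise attention
        # across all D series, keeping the cost linear in the number of dimensions.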
self.dropout = nn.Dropout(dropout)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)
self.norm4 = nn.LayerNorm(d_model)
self.MLP1 = nn.Sequential(nn.Linear(d_model, d_ff),
nn.GELU(),
nn.Linear(d_ff, d_model))
self.MLP2 = nn.Sequential(nn.Linear(d_model, d_ff),
nn.GELU(),
nn.Linear(d_ff, d_model))
def forward(self, x, attn_mask=None, tau=None, delta=None, key_padding_mask=None):
# Cross Time Stage: Directly apply MSA to each dimension
batch = x.shape[0]
time_in = rearrange(x, 'b ts_d seg_num d_model -> (b ts_d) seg_num d_model')
time_enc, attn = self.time_attention(
time_in, time_in, time_in, attn_mask=None, tau=None, delta=None, key_padding_mask=key_padding_mask
)
dim_in = time_in + self.dropout(time_enc)
dim_in = self.norm1(dim_in)
dim_in = dim_in + self.dropout(self.MLP1(dim_in))
dim_in = self.norm2(dim_in)
# Cross Dimension Stage
dim_send = rearrange(dim_in, '(b ts_d) seg_num d_model -> (b seg_num) ts_d d_model', b=batch)
batch_router = repeat(self.router, 'seg_num factor d_model -> (repeat seg_num) factor d_model', repeat=batch)
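        # dim_sender gathers information from every dimension into the router tokens;
        # dim_receiver distributes it back to each dimension.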
dim_buffer, _ = self.dim_sender(batch_router, dim_send, dim_send, attn_mask=None, tau=None, delta=None)
dim_receive, _ = self.dim_receiver(dim_send, dim_buffer, dim_buffer, attn_mask=None, tau=None, delta=None)
dim_enc = dim_send + self.dropout(dim_receive)
dim_enc = self.norm3(dim_enc)
dim_enc = dim_enc + self.dropout(self.MLP2(dim_enc))
dim_enc = self.norm4(dim_enc)
final_out = rearrange(dim_enc, '(b seg_num) ts_d d_model -> b ts_d seg_num d_model', b=batch)
return final_out
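

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original TSlib API):
# builds an AttentionLayer around FullAttention and checks that keys marked as
# padding in key_padding_mask receive zero attention weight. Shapes and values
# below are arbitrary assumptions chosen for the demonstration.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    B, L, d_model, n_heads = 2, 16, 32, 4
    layer = AttentionLayer(
        FullAttention(mask_flag=False, attention_dropout=0.1, output_attention=True),
        d_model=d_model,
        n_heads=n_heads,
    )
    layer.eval()  # disable dropout so the check below is deterministic

    x = torch.randn(B, L, d_model)
    # True = valid, False = padding; mark the last 4 positions of the second sequence as padding.
    key_padding_mask = torch.ones(B, L, dtype=torch.bool)
    key_padding_mask[1, -4:] = False

    out, attn = layer(x, x, x, attn_mask=None, key_padding_mask=key_padding_mask)
    print(out.shape, attn.shape)  # torch.Size([2, 16, 32]) torch.Size([2, 4, 16, 16])
    assert torch.all(attn[1, :, :, -4:] == 0)  # padded keys receive zero attention weight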