first commit

layers/AutoCorrelation.py (new file, 163 lines)

@@ -0,0 +1,163 @@
import math

import torch
import torch.nn as nn


class AutoCorrelation(nn.Module):
    """
    AutoCorrelation mechanism with the following two phases:
    (1) period-based dependencies discovery
    (2) time delay aggregation
    This block can seamlessly replace any of the self-attention family of mechanisms.
    """

    def __init__(self, mask_flag=True, factor=1, scale=None, attention_dropout=0.1, output_attention=False):
        super(AutoCorrelation, self).__init__()
        self.factor = factor
        self.scale = scale
        self.mask_flag = mask_flag
        self.output_attention = output_attention
        self.dropout = nn.Dropout(attention_dropout)

    def time_delay_agg_training(self, values, corr):
        """
        Speed-up version of autocorrelation (a batch-normalization-style design).
        This is for the training phase.
        """
        head = values.shape[1]
        channel = values.shape[2]
        length = values.shape[3]
        # find top k delays; k scales with log(length)
        top_k = int(self.factor * math.log(length))
        # average the correlation map over heads and channels: (B, H, E, L) -> (B, L)
        mean_value = torch.mean(torch.mean(corr, dim=1), dim=1)
        # one set of delays shared across the whole batch
        index = torch.topk(torch.mean(mean_value, dim=0), top_k, dim=-1)[1]
        weights = torch.stack([mean_value[:, index[i]] for i in range(top_k)], dim=-1)
        # normalize the top-k correlations into aggregation weights
        tmp_corr = torch.softmax(weights, dim=-1)
        # aggregation: roll the series by each delay and blend with the weights
        tmp_values = values
        delays_agg = torch.zeros_like(values).float()
        for i in range(top_k):
            pattern = torch.roll(tmp_values, -int(index[i]), -1)
            delays_agg = delays_agg + pattern * \
                (tmp_corr[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length))
        return delays_agg

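    # Editorial note (not in the original commit): "batch-normalization style" refers
    # to picking one shared set of top-k delays from batch-averaged correlations
    # (torch.mean(mean_value, dim=0)), the way batch norm shares batch statistics;
    # this lets the cheap scalar-shift torch.roll serve the whole batch at once.
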
    def time_delay_agg_inference(self, values, corr):
        """
        Speed-up version of autocorrelation (a batch-normalization-style design).
        This is for the inference phase.
        """
        batch = values.shape[0]
        head = values.shape[1]
        channel = values.shape[2]
        length = values.shape[3]
        # index init: positions 0..length-1, broadcast to (B, H, E, L)
        init_index = torch.arange(length).unsqueeze(0).unsqueeze(0).unsqueeze(0).repeat(batch, head, channel, 1).to(values.device)
        # find top k delays; k scales with log(length)
        top_k = int(self.factor * math.log(length))
        # average the correlation map over heads and channels: (B, H, E, L) -> (B, L)
        mean_value = torch.mean(torch.mean(corr, dim=1), dim=1)
        # per-sample top-k delays, unlike the batch-shared delays used in training
        weights, delay = torch.topk(mean_value, top_k, dim=-1)
        # normalize the top-k correlations into aggregation weights
        tmp_corr = torch.softmax(weights, dim=-1)
        # aggregation: double the series so gather can emulate a circular roll
        tmp_values = values.repeat(1, 1, 1, 2)
        delays_agg = torch.zeros_like(values).float()
        for i in range(top_k):
            tmp_delay = init_index + delay[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length)
            pattern = torch.gather(tmp_values, dim=-1, index=tmp_delay)
            delays_agg = delays_agg + pattern * \
                (tmp_corr[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length))
        return delays_agg

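    # Editorial note (not in the original commit): values.repeat(1, 1, 1, 2) doubles
    # the series so torch.gather can emulate a per-sample circular roll; the summed
    # index init_index + delay never exceeds 2 * length - 2, so no wrap-around is
    # needed. torch.roll (used in the training path) only accepts a scalar shift,
    # which is why the per-sample delays of the inference path use gather instead.
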
    def time_delay_agg_full(self, values, corr):
        """
        Standard version of autocorrelation (top-k delays kept separately
        per batch, head, and channel).
        """
        batch = values.shape[0]
        head = values.shape[1]
        channel = values.shape[2]
        length = values.shape[3]
        # index init: positions 0..length-1, broadcast to (B, H, E, L)
        init_index = torch.arange(length).unsqueeze(0).unsqueeze(0).unsqueeze(0).repeat(batch, head, channel, 1).to(values.device)
        # find top k delays without any averaging over heads or channels
        top_k = int(self.factor * math.log(length))
        weights, delay = torch.topk(corr, top_k, dim=-1)
        # normalize the top-k correlations into aggregation weights
        tmp_corr = torch.softmax(weights, dim=-1)
        # aggregation: double the series so gather can emulate a circular roll
        tmp_values = values.repeat(1, 1, 1, 2)
        delays_agg = torch.zeros_like(values).float()
        for i in range(top_k):
            tmp_delay = init_index + delay[..., i].unsqueeze(-1)
            pattern = torch.gather(tmp_values, dim=-1, index=tmp_delay)
            delays_agg = delays_agg + pattern * (tmp_corr[..., i].unsqueeze(-1))
        return delays_agg

    def forward(self, queries, keys, values, attn_mask):
        B, L, H, E = queries.shape
        _, S, _, D = values.shape
        # align keys/values to the query length: zero-pad if shorter, truncate if longer
        if L > S:
            zeros = torch.zeros_like(queries[:, :(L - S), :]).float()
            values = torch.cat([values, zeros], dim=1)
            keys = torch.cat([keys, zeros], dim=1)
        else:
            values = values[:, :L, :, :]
            keys = keys[:, :L, :, :]

        # period-based dependencies: correlation at every lag via FFT
        # (Wiener-Khinchin: the inverse FFT of Q_fft * conj(K_fft) is the cross-correlation)
        q_fft = torch.fft.rfft(queries.permute(0, 2, 3, 1).contiguous(), dim=-1)
        k_fft = torch.fft.rfft(keys.permute(0, 2, 3, 1).contiguous(), dim=-1)
        res = q_fft * torch.conj(k_fft)
        corr = torch.fft.irfft(res, dim=-1)

        # time delay aggregation
        if self.training:
            V = self.time_delay_agg_training(values.permute(0, 2, 3, 1).contiguous(), corr).permute(0, 3, 1, 2)
        else:
            V = self.time_delay_agg_inference(values.permute(0, 2, 3, 1).contiguous(), corr).permute(0, 3, 1, 2)

        if self.output_attention:
            return (V.contiguous(), corr.permute(0, 3, 1, 2))
        else:
            return (V.contiguous(), None)


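# A minimal standalone sketch (editorial addition, not part of this commit) of the
# FFT identity the forward pass above relies on: by the Wiener-Khinchin theorem,
# the inverse FFT of Q_fft * conj(K_fft) yields the cross-correlation of two
# series at every lag in O(L log L) rather than O(L^2). The function name here
# is hypothetical and purely illustrative.
def _fft_cross_correlation_sketch(q, k):
    """Cross-correlation of q and k over the last dimension (equal lengths assumed)."""
    q_fft = torch.fft.rfft(q, dim=-1)
    k_fft = torch.fft.rfft(k, dim=-1)
    # conjugation in the frequency domain turns convolution into correlation
    return torch.fft.irfft(q_fft * torch.conj(k_fft), n=q.shape[-1], dim=-1)

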
class AutoCorrelationLayer(nn.Module):
    def __init__(self, correlation, d_model, n_heads, d_keys=None,
                 d_values=None):
        super(AutoCorrelationLayer, self).__init__()

        # default per-head dimensions to d_model split evenly across heads
        d_keys = d_keys or (d_model // n_heads)
        d_values = d_values or (d_model // n_heads)

        self.inner_correlation = correlation
        self.query_projection = nn.Linear(d_model, d_keys * n_heads)
        self.key_projection = nn.Linear(d_model, d_keys * n_heads)
        self.value_projection = nn.Linear(d_model, d_values * n_heads)
        self.out_projection = nn.Linear(d_values * n_heads, d_model)
        self.n_heads = n_heads

    def forward(self, queries, keys, values, attn_mask):
        B, L, _ = queries.shape
        _, S, _ = keys.shape
        H = self.n_heads

        # project and split into heads: (B, L, d_model) -> (B, L, H, d_head)
        queries = self.query_projection(queries).view(B, L, H, -1)
        keys = self.key_projection(keys).view(B, S, H, -1)
        values = self.value_projection(values).view(B, S, H, -1)

        out, attn = self.inner_correlation(
            queries,
            keys,
            values,
            attn_mask
        )
        # merge heads back: (B, L, H, d_head) -> (B, L, H * d_head)
        out = out.view(B, L, -1)

        return self.out_projection(out), attn
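
# A minimal usage sketch (editorial addition, not part of this commit). The shapes
# below (batch=4, length=24, d_model=16, n_heads=2) are arbitrary assumptions.
if __name__ == "__main__":
    attn = AutoCorrelation(mask_flag=False, factor=1, attention_dropout=0.1, output_attention=True)
    layer = AutoCorrelationLayer(attn, d_model=16, n_heads=2)
    layer.eval()  # routes through time_delay_agg_inference; train() would use the roll-based path
    x = torch.randn(4, 24, 16)
    out, corr = layer(x, x, x, attn_mask=None)  # attn_mask is accepted but unused here
    print(out.shape)   # torch.Size([4, 24, 16])
    print(corr.shape)  # torch.Size([4, 24, 2, 8]) -- lag correlations returned as "attention"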