first commit
layers/Pyraformer_EncDec.py (Normal file, 218 additions)
@@ -0,0 +1,218 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.linear import Linear
from layers.SelfAttention_Family import AttentionLayer, FullAttention
from layers.Embed import DataEmbedding
import math


def get_mask(input_size, window_size, inner_size):
    """Get the attention mask of PAM-Naive"""
    # Get the size of all layers
    all_size = []
    all_size.append(input_size)
    for i in range(len(window_size)):
        layer_size = math.floor(all_size[i] / window_size[i])
        all_size.append(layer_size)

    seq_length = sum(all_size)
    mask = torch.zeros(seq_length, seq_length)

    # get intra-scale mask
    inner_window = inner_size // 2
    for layer_idx in range(len(all_size)):
        start = sum(all_size[:layer_idx])
        for i in range(start, start + all_size[layer_idx]):
            left_side = max(i - inner_window, start)
            right_side = min(i + inner_window + 1, start + all_size[layer_idx])
            mask[i, left_side:right_side] = 1

    # get inter-scale mask
    for layer_idx in range(1, len(all_size)):
        start = sum(all_size[:layer_idx])
        for i in range(start, start + all_size[layer_idx]):
            left_side = (start - all_size[layer_idx - 1]) + \
                (i - start) * window_size[layer_idx - 1]
            if i == (start + all_size[layer_idx] - 1):
                right_side = start
            else:
                right_side = (
                    start - all_size[layer_idx - 1]) + (i - start + 1) * window_size[layer_idx - 1]
            mask[i, left_side:right_side] = 1
            mask[left_side:right_side, i] = 1

    mask = (1 - mask).bool()

    return mask, all_size
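
# Illustration (hypothetical sizes, not taken from any particular config):
# get_mask(input_size=8, window_size=[2, 2], inner_size=3) yields
# all_size = [8, 4, 2] and a (8 + 4 + 2) x (8 + 4 + 2) boolean mask in which
# False marks visible pairs (intra-scale neighbours within inner_size and
# parent/child nodes across adjacent scales) and True marks blocked positions.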


def refer_points(all_sizes, window_size):
    """Gather features from PAM's pyramid sequences"""
    input_size = all_sizes[0]
    indexes = torch.zeros(input_size, len(all_sizes))

    for i in range(input_size):
        indexes[i][0] = i
        former_index = i
        for j in range(1, len(all_sizes)):
            start = sum(all_sizes[:j])
            inner_layer_idx = former_index - (start - all_sizes[j - 1])
            former_index = start + \
                min(inner_layer_idx // window_size[j - 1], all_sizes[j] - 1)
            indexes[i][j] = former_index

    indexes = indexes.unsqueeze(0).unsqueeze(3)

    return indexes.long()
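
# Example (values are illustrative): with all_sizes = [8, 4, 2] and
# window_size = [2, 2], row i = 5 of the returned index tensor is
# [5, 10, 13] -- position 5 at the finest scale, its parent (node 2 of the
# 4-node scale, stored at offset 8), and its grandparent (node 1 of the
# 2-node scale, stored at offset 12). The returned shape is (1, 8, 3, 1).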


class RegularMask():
    def __init__(self, mask):
        self._mask = mask.unsqueeze(1)

    @property
    def mask(self):
        return self._mask


class EncoderLayer(nn.Module):
    """ Compose with two layers """

    def __init__(self, d_model, d_inner, n_head, dropout=0.1, normalize_before=True):
        super(EncoderLayer, self).__init__()

        self.slf_attn = AttentionLayer(
            FullAttention(mask_flag=True, factor=0,
                          attention_dropout=dropout, output_attention=False),
            d_model, n_head)
        self.pos_ffn = PositionwiseFeedForward(
            d_model, d_inner, dropout=dropout, normalize_before=normalize_before)

    def forward(self, enc_input, slf_attn_mask=None):
        attn_mask = RegularMask(slf_attn_mask)
        enc_output, _ = self.slf_attn(
            enc_input, enc_input, enc_input, attn_mask=attn_mask)
        enc_output = self.pos_ffn(enc_output)
        return enc_output


class Encoder(nn.Module):
    """ An encoder model with self attention mechanism. """

    def __init__(self, configs, window_size, inner_size):
        super().__init__()

        d_bottleneck = configs.d_model // 4

        self.mask, self.all_size = get_mask(
            configs.seq_len, window_size, inner_size)
        self.indexes = refer_points(self.all_size, window_size)
        self.layers = nn.ModuleList([
            EncoderLayer(configs.d_model, configs.d_ff, configs.n_heads, dropout=configs.dropout,
                         normalize_before=False) for _ in range(configs.e_layers)
        ])  # naive pyramid attention

        self.enc_embedding = DataEmbedding(
            configs.enc_in, configs.d_model, configs.dropout)
        self.conv_layers = Bottleneck_Construct(
            configs.d_model, window_size, d_bottleneck)

    def forward(self, x_enc, x_mark_enc):
        seq_enc = self.enc_embedding(x_enc, x_mark_enc)

        mask = self.mask.repeat(len(seq_enc), 1, 1).to(x_enc.device)
        seq_enc = self.conv_layers(seq_enc)

        for i in range(len(self.layers)):
            seq_enc = self.layers[i](seq_enc, mask)

        indexes = self.indexes.repeat(seq_enc.size(
            0), 1, 1, seq_enc.size(2)).to(seq_enc.device)
        indexes = indexes.view(seq_enc.size(0), -1, seq_enc.size(2))
        all_enc = torch.gather(seq_enc, 1, indexes)
        seq_enc = all_enc.view(seq_enc.size(0), self.all_size[0], -1)

        return seq_enc
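
# Note on shapes (illustrative): after the pyramid attention layers, the
# gather/view above stacks each time step together with its ancestors from
# every coarser scale along the feature dimension, so the encoder returns a
# tensor of shape (batch, seq_len, d_model * num_scales), where
# num_scales = len(window_size) + 1.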


class ConvLayer(nn.Module):
    def __init__(self, c_in, window_size):
        super(ConvLayer, self).__init__()
        self.downConv = nn.Conv1d(in_channels=c_in,
                                  out_channels=c_in,
                                  kernel_size=window_size,
                                  stride=window_size)
        self.norm = nn.BatchNorm1d(c_in)
        self.activation = nn.ELU()

    def forward(self, x):
        x = self.downConv(x)
        x = self.norm(x)
        x = self.activation(x)
        return x


class Bottleneck_Construct(nn.Module):
    """Bottleneck convolution CSCM"""

    def __init__(self, d_model, window_size, d_inner):
        super(Bottleneck_Construct, self).__init__()
        if not isinstance(window_size, list):
            self.conv_layers = nn.ModuleList([
                ConvLayer(d_inner, window_size),
                ConvLayer(d_inner, window_size),
                ConvLayer(d_inner, window_size)
            ])
        else:
            self.conv_layers = []
            for i in range(len(window_size)):
                self.conv_layers.append(ConvLayer(d_inner, window_size[i]))
            self.conv_layers = nn.ModuleList(self.conv_layers)
        self.up = Linear(d_inner, d_model)
        self.down = Linear(d_model, d_inner)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, enc_input):
        temp_input = self.down(enc_input).permute(0, 2, 1)
        all_inputs = []
        for i in range(len(self.conv_layers)):
            temp_input = self.conv_layers[i](temp_input)
            all_inputs.append(temp_input)

        all_inputs = torch.cat(all_inputs, dim=2).transpose(1, 2)
        all_inputs = self.up(all_inputs)
        all_inputs = torch.cat([enc_input, all_inputs], dim=1)

        all_inputs = self.norm(all_inputs)
        return all_inputs
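
# Example (hypothetical sizes): for enc_input of shape (B, 96, 512) with
# window_size = [2, 2] and d_inner = 128, the strided convolutions produce
# coarse sequences of length 48 and 24 in the 128-dim bottleneck space;
# after projecting back to 512 dims and concatenating with the input along
# the time axis, the output has shape (B, 96 + 48 + 24, 512).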


class PositionwiseFeedForward(nn.Module):
    """ Two-layer position-wise feed-forward neural network. """

    def __init__(self, d_in, d_hid, dropout=0.1, normalize_before=True):
        super().__init__()

        self.normalize_before = normalize_before

        self.w_1 = nn.Linear(d_in, d_hid)
        self.w_2 = nn.Linear(d_hid, d_in)

        self.layer_norm = nn.LayerNorm(d_in, eps=1e-6)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        residual = x
        if self.normalize_before:
            x = self.layer_norm(x)

        x = F.gelu(self.w_1(x))
        x = self.dropout(x)
        x = self.w_2(x)
        x = self.dropout(x)
        x = x + residual

        if not self.normalize_before:
            x = self.layer_norm(x)
        return x
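

if __name__ == "__main__":
    # Minimal smoke test (a sketch with arbitrary sizes, not part of any
    # training configuration): exercises the mask and index construction
    # for a 3-scale pyramid built from an 8-step input.
    mask, all_size = get_mask(input_size=8, window_size=[2, 2], inner_size=3)
    indexes = refer_points(all_size, [2, 2])
    print(all_size)       # [8, 4, 2]
    print(mask.shape)     # torch.Size([14, 14])
    print(indexes.shape)  # torch.Size([1, 8, 3, 1])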