import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.linear import Linear

from layers.SelfAttention_Family import AttentionLayer, FullAttention
from layers.Embed import DataEmbedding


def get_mask(input_size, window_size, inner_size):
    """Get the attention mask of PAM-Naive."""
    # Get the size of all layers of the pyramid: the input scale plus one
    # coarser scale per entry in window_size.
    all_size = [input_size]
    for i in range(len(window_size)):
        layer_size = math.floor(all_size[i] / window_size[i])
        all_size.append(layer_size)

    seq_length = sum(all_size)
    mask = torch.zeros(seq_length, seq_length)

    # Intra-scale mask: each node attends to its neighbours within the same scale.
    inner_window = inner_size // 2
    for layer_idx in range(len(all_size)):
        start = sum(all_size[:layer_idx])
        for i in range(start, start + all_size[layer_idx]):
            left_side = max(i - inner_window, start)
            right_side = min(i + inner_window + 1, start + all_size[layer_idx])
            mask[i, left_side:right_side] = 1

    # Inter-scale mask: each coarser node attends to (and is attended by) the
    # finer-scale nodes it summarizes.
    for layer_idx in range(1, len(all_size)):
        start = sum(all_size[:layer_idx])
        for i in range(start, start + all_size[layer_idx]):
            left_side = (start - all_size[layer_idx - 1]) + \
                (i - start) * window_size[layer_idx - 1]
            if i == (start + all_size[layer_idx] - 1):
                right_side = start
            else:
                right_side = (start - all_size[layer_idx - 1]) + \
                    (i - start + 1) * window_size[layer_idx - 1]
            mask[i, left_side:right_side] = 1
            mask[left_side:right_side, i] = 1

    # Invert so that True marks positions that must NOT be attended to.
    mask = (1 - mask).bool()

    return mask, all_size


def refer_points(all_sizes, window_size):
    """Gather features from PAM's pyramid sequences."""
    input_size = all_sizes[0]
    indexes = torch.zeros(input_size, len(all_sizes))

    for i in range(input_size):
        indexes[i][0] = i
        former_index = i
        for j in range(1, len(all_sizes)):
            start = sum(all_sizes[:j])
            inner_layer_idx = former_index - (start - all_sizes[j - 1])
            former_index = start + \
                min(inner_layer_idx // window_size[j - 1], all_sizes[j] - 1)
            indexes[i][j] = former_index

    indexes = indexes.unsqueeze(0).unsqueeze(3)

    return indexes.long()


class RegularMask:
    """Wraps a precomputed attention mask behind the `.mask` property read by FullAttention."""

    def __init__(self, mask):
        # Add a head dimension so the mask broadcasts over attention heads.
        self._mask = mask.unsqueeze(1)

    @property
    def mask(self):
        return self._mask


class EncoderLayer(nn.Module):
    """Composed of a masked self-attention layer and a position-wise feed-forward layer."""

    def __init__(self, d_model, d_inner, n_head, dropout=0.1, normalize_before=True):
        super(EncoderLayer, self).__init__()
        self.slf_attn = AttentionLayer(
            FullAttention(mask_flag=True, factor=0,
                          attention_dropout=dropout, output_attention=False),
            d_model, n_head)
        self.pos_ffn = PositionwiseFeedForward(
            d_model, d_inner, dropout=dropout, normalize_before=normalize_before)

    def forward(self, enc_input, slf_attn_mask=None):
        attn_mask = RegularMask(slf_attn_mask)
        enc_output, _ = self.slf_attn(
            enc_input, enc_input, enc_input, attn_mask=attn_mask)
        enc_output = self.pos_ffn(enc_output)
        return enc_output

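# Illustrative shape walk-through (example values, not library defaults): with
# configs.seq_len = 96, window_size = [4, 4] and inner_size = 3, get_mask
# produces pyramid sizes all_size = [96, 24, 6], so the attention graph has
# 126 nodes and the returned mask is a [126, 126] boolean matrix in which True
# marks blocked positions. refer_points then returns indexes of shape
# [1, 96, 3, 1], one pyramid node per scale for each of the 96 input steps,
# which Encoder.forward below gathers into an output of shape [B, 96, 3 * d_model].
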
""" def __init__(self, configs, window_size, inner_size): super().__init__() d_bottleneck = configs.d_model//4 self.mask, self.all_size = get_mask( configs.seq_len, window_size, inner_size) self.indexes = refer_points(self.all_size, window_size) self.layers = nn.ModuleList([ EncoderLayer(configs.d_model, configs.d_ff, configs.n_heads, dropout=configs.dropout, normalize_before=False) for _ in range(configs.e_layers) ]) # naive pyramid attention self.enc_embedding = DataEmbedding( configs.enc_in, configs.d_model, configs.dropout) self.conv_layers = Bottleneck_Construct( configs.d_model, window_size, d_bottleneck) def forward(self, x_enc, x_mark_enc): seq_enc = self.enc_embedding(x_enc, x_mark_enc) mask = self.mask.repeat(len(seq_enc), 1, 1).to(x_enc.device) seq_enc = self.conv_layers(seq_enc) for i in range(len(self.layers)): seq_enc = self.layers[i](seq_enc, mask) indexes = self.indexes.repeat(seq_enc.size( 0), 1, 1, seq_enc.size(2)).to(seq_enc.device) indexes = indexes.view(seq_enc.size(0), -1, seq_enc.size(2)) all_enc = torch.gather(seq_enc, 1, indexes) seq_enc = all_enc.view(seq_enc.size(0), self.all_size[0], -1) return seq_enc class ConvLayer(nn.Module): def __init__(self, c_in, window_size): super(ConvLayer, self).__init__() self.downConv = nn.Conv1d(in_channels=c_in, out_channels=c_in, kernel_size=window_size, stride=window_size) self.norm = nn.BatchNorm1d(c_in) self.activation = nn.ELU() def forward(self, x): x = self.downConv(x) x = self.norm(x) x = self.activation(x) return x class Bottleneck_Construct(nn.Module): """Bottleneck convolution CSCM""" def __init__(self, d_model, window_size, d_inner): super(Bottleneck_Construct, self).__init__() if not isinstance(window_size, list): self.conv_layers = nn.ModuleList([ ConvLayer(d_inner, window_size), ConvLayer(d_inner, window_size), ConvLayer(d_inner, window_size) ]) else: self.conv_layers = [] for i in range(len(window_size)): self.conv_layers.append(ConvLayer(d_inner, window_size[i])) self.conv_layers = nn.ModuleList(self.conv_layers) self.up = Linear(d_inner, d_model) self.down = Linear(d_model, d_inner) self.norm = nn.LayerNorm(d_model) def forward(self, enc_input): temp_input = self.down(enc_input).permute(0, 2, 1) all_inputs = [] for i in range(len(self.conv_layers)): temp_input = self.conv_layers[i](temp_input) all_inputs.append(temp_input) all_inputs = torch.cat(all_inputs, dim=2).transpose(1, 2) all_inputs = self.up(all_inputs) all_inputs = torch.cat([enc_input, all_inputs], dim=1) all_inputs = self.norm(all_inputs) return all_inputs class PositionwiseFeedForward(nn.Module): """ Two-layer position-wise feed-forward neural network. """ def __init__(self, d_in, d_hid, dropout=0.1, normalize_before=True): super().__init__() self.normalize_before = normalize_before self.w_1 = nn.Linear(d_in, d_hid) self.w_2 = nn.Linear(d_hid, d_in) self.layer_norm = nn.LayerNorm(d_in, eps=1e-6) self.dropout = nn.Dropout(dropout) def forward(self, x): residual = x if self.normalize_before: x = self.layer_norm(x) x = F.gelu(self.w_1(x)) x = self.dropout(x) x = self.w_2(x) x = self.dropout(x) x = x + residual if not self.normalize_before: x = self.layer_norm(x) return x