import torch
import torch.nn as nn
import torch.nn.functional as F


class Encoder_ori(nn.Module):
    """Stacks attention-style layers and applies an optional final norm."""

    def __init__(self, attn_layers, conv_layers=None, norm_layer=None, one_output=False):
        super(Encoder_ori, self).__init__()
        self.attn_layers = nn.ModuleList(attn_layers)
        self.norm = norm_layer
        self.one_output = one_output

    def forward(self, x, attn_mask=None, tau=None, delta=None):
        # x: [B, nvars, D]
        attns = []
        X0 = None  # to make PyCharm happy
        layer_len = len(self.attn_layers)
        for i, attn_layer in enumerate(self.attn_layers):
            x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta)
            attns.append(attn)
            if not self.training and layer_len > 1:
                if i == 0:
                    X0 = x  # first layer's output (kept at eval time, unused here)

        if isinstance(x, (tuple, list)):
            x = x[0]
        if self.norm is not None:
            x = self.norm(x)

        if self.one_output:
            return x
        else:
            return x, attns


class LinearEncoder(nn.Module):
    """Encoder block that replaces attention with a learned (optionally
    covariance-informed) linear token-mixing matrix, followed by a
    position-wise feed-forward layer."""

    def __init__(self, d_model, d_ff=None, CovMat=None, dropout=0.1, activation="relu",
                 token_num=None, **kwargs):
        super(LinearEncoder, self).__init__()
        d_ff = d_ff or 4 * d_model
        self.d_model = d_model
        self.d_ff = d_ff
        self.CovMat = CovMat.unsqueeze(0) if CovMat is not None else None
        self.token_num = token_num
        self.norm1 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

        # attention --> linear: a learnable token-mixing matrix replaces the attention map
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)
        init_weight_mat = torch.eye(self.token_num) * 1.0 + torch.randn(self.token_num, self.token_num) * 1.0
        self.weight_mat = nn.Parameter(init_weight_mat[None, :, :])
        # self.bias = nn.Parameter(torch.zeros(1, 1, self.d_model))

        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
        self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)
        self.activation = F.relu if activation == "relu" else F.gelu
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x, **kwargs):
        # x: [b, l, d_model]
        values = self.v_proj(x)

        # Build the mixing matrix A: softplus keeps entries non-negative; when a
        # covariance matrix is provided, its softmax is added as a data-driven prior.
        if self.CovMat is not None:
            A = F.softmax(self.CovMat, dim=-1) + F.softplus(self.weight_mat)
        else:
            A = F.softplus(self.weight_mat)
        A = F.normalize(A, p=1, dim=-1)  # rows sum to 1, analogous to attention weights
        A = self.dropout(A)

        new_x = A @ values  # + self.bias
        x = x + self.dropout(self.out_proj(new_x))
        x = self.norm1(x)

        # Position-wise feed-forward network with residual connection
        y = self.dropout(self.activation(self.conv1(x.transpose(-1, 1))))
        y = self.dropout(self.conv2(y).transpose(-1, 1))
        output = self.norm2(x + y)
        return output, None
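

if __name__ == "__main__":
    # Minimal usage sketch, illustrative only: the shapes and the random CovMat
    # below are assumptions for demonstration, not values from the original pipeline.
    B, L, D = 4, 7, 64  # batch size, number of tokens (e.g. variables), model dimension
    cov = torch.randn(L, L)  # stand-in for a precomputed token-covariance matrix
    layers = [LinearEncoder(d_model=D, CovMat=cov, token_num=L) for _ in range(2)]
    encoder = Encoder_ori(layers, norm_layer=nn.LayerNorm(D))
    out, attns = encoder(torch.randn(B, L, D))
    print(out.shape)  # torch.Size([4, 7, 64]); attns is [None, None] for LinearEncoder layers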