# tsmodel/models/OLinear/Trans_EncDec.py

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List


class Encoder_ori(nn.Module):
    """Stack of attention-style encoder layers with an optional final normalization."""

    def __init__(self, attn_layers, conv_layers=None, norm_layer=None, one_output=False):
        super(Encoder_ori, self).__init__()
        self.attn_layers = nn.ModuleList(attn_layers)
        self.norm = norm_layer
        self.one_output = one_output

    def forward(self, x, attn_mask=None, tau=None, delta=None):
        # x: [B, nvars, D]
        attns = []
        X0 = None  # to make PyCharm happy
        layer_len = len(self.attn_layers)
        for i, attn_layer in enumerate(self.attn_layers):
            x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta)
            attns.append(attn)
            # At eval time with multiple layers, remember the first layer's output.
            if not self.training and layer_len > 1:
                if i == 0:
                    X0 = x

        # Some layers return a tuple/list; keep only the hidden states.
        if isinstance(x, tuple) or isinstance(x, List):
            x = x[0]

        if self.norm is not None:
            x = self.norm(x)

        if self.one_output:
            return x
        else:
            return x, attns
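

# A minimal usage sketch for Encoder_ori, assuming layers that return an
# (output, attn) pair. `_IdentityAttn` is a hypothetical stand-in layer used
# only for illustration; it is not part of OLinear.
def _demo_encoder_ori():
    class _IdentityAttn(nn.Module):
        # Hypothetical layer mimicking the (output, attn) interface Encoder_ori expects.
        def forward(self, x, attn_mask=None, tau=None, delta=None):
            return x, None

    enc = Encoder_ori([_IdentityAttn(), _IdentityAttn()], norm_layer=nn.LayerNorm(16))
    x = torch.randn(2, 7, 16)  # [B, nvars, D]
    out, attns = enc(x)
    assert out.shape == (2, 7, 16) and len(attns) == 2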


class LinearEncoder(nn.Module):
    """Encoder layer that replaces self-attention with a learnable linear mixing
    matrix over tokens, optionally combined with a precomputed covariance prior,
    followed by a position-wise feed-forward sublayer."""

    def __init__(self, d_model, d_ff=None, CovMat=None, dropout=0.1, activation="relu",
                 token_num=None, **kwargs):
        super(LinearEncoder, self).__init__()
        d_ff = d_ff or 4 * d_model
        self.d_model = d_model
        self.d_ff = d_ff
        # Cross-token covariance prior; a leading batch dim is added for broadcasting.
        self.CovMat = CovMat.unsqueeze(0) if CovMat is not None else None
        self.token_num = token_num

        self.norm1 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

        # attention --> linear
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

        # Learnable token-mixing matrix, initialized as identity plus Gaussian noise.
        init_weight_mat = torch.eye(self.token_num) * 1.0 + torch.randn(self.token_num, self.token_num) * 1.0
        self.weight_mat = nn.Parameter(init_weight_mat[None, :, :])
        # self.bias = nn.Parameter(torch.zeros(1, 1, self.d_model))

        # Position-wise feed-forward network implemented with 1x1 convolutions.
        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
        self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)
        self.activation = F.relu if activation == "relu" else F.gelu
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x, **kwargs):
        # x: [B, token_num, d_model]
        values = self.v_proj(x)

        # Mixing matrix A: softmax over the covariance prior (if provided) plus a
        # softplus-activated learnable matrix, L1-normalized along each row.
        if self.CovMat is not None:
            A = F.softmax(self.CovMat, dim=-1) + F.softplus(self.weight_mat)
        else:
            A = F.softplus(self.weight_mat)
        A = F.normalize(A, p=1, dim=-1)
        A = self.dropout(A)

        # Token mixing, then residual connection and LayerNorm.
        new_x = A @ values  # + self.bias
        x = x + self.dropout(self.out_proj(new_x))
        x = self.norm1(x)

        # Feed-forward sublayer with residual connection and LayerNorm.
        y = self.dropout(self.activation(self.conv1(x.transpose(-1, 1))))
        y = self.dropout(self.conv2(y).transpose(-1, 1))
        output = self.norm2(x + y)

        return output, None
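

# A minimal smoke test wiring LinearEncoder layers into Encoder_ori; the batch
# size, token_num and d_model values below are illustrative assumptions.
if __name__ == "__main__":
    B, token_num, d_model = 8, 7, 64
    layers = [LinearEncoder(d_model=d_model, token_num=token_num) for _ in range(2)]
    encoder = Encoder_ori(layers, norm_layer=nn.LayerNorm(d_model), one_output=True)
    x = torch.randn(B, token_num, d_model)  # [B, nvars, D]
    out = encoder(x)
    print(out.shape)  # expected: torch.Size([8, 7, 64])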