feat(core): add initial TSModel package with OLinear and RevIN
models/OLinear/Trans_EncDec.py | 87 (new file)
@@ -0,0 +1,87 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class Encoder_ori(nn.Module):
    """Stack of attention-style encoder layers with an optional final norm."""

    def __init__(self, attn_layers, conv_layers=None, norm_layer=None, one_output=False):
        super(Encoder_ori, self).__init__()
        self.attn_layers = nn.ModuleList(attn_layers)
        self.norm = norm_layer
        self.one_output = one_output

    def forward(self, x, attn_mask=None, tau=None, delta=None):
        # x: [B, nvars, D]
        attns = []
        X0 = None  # predefined to keep static analysis happy; set below but not otherwise used
        layer_len = len(self.attn_layers)
        for i, attn_layer in enumerate(self.attn_layers):
            x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta)
            attns.append(attn)

            if not self.training and layer_len > 1:
                if i == 0:
                    X0 = x

        # some layers return (output, ...) tuples/lists; keep only the output tensor
        if isinstance(x, (tuple, list)):
            x = x[0]

        if self.norm is not None:
            x = self.norm(x)

        if self.one_output:
            return x
        else:
            return x, attns

class LinearEncoder(nn.Module):
    """Encoder block that replaces self-attention with a learned linear token-mixing matrix."""

    def __init__(self, d_model, d_ff=None, CovMat=None, dropout=0.1, activation="relu", token_num=None, **kwargs):
        super(LinearEncoder, self).__init__()

        d_ff = d_ff or 4 * d_model
        self.d_model = d_model
        self.d_ff = d_ff
        self.CovMat = CovMat.unsqueeze(0) if CovMat is not None else None
        self.token_num = token_num

        self.norm1 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

        # attention --> linear: value/output projections as in attention, but the
        # mixing weights come from a learned matrix (optionally seeded by CovMat)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

        init_weight_mat = torch.eye(self.token_num) * 1.0 + torch.randn(self.token_num, self.token_num) * 1.0
        self.weight_mat = nn.Parameter(init_weight_mat[None, :, :])

        # self.bias = nn.Parameter(torch.zeros(1, 1, self.d_model))

        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
        self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)
        self.activation = F.relu if activation == "relu" else F.gelu
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x, **kwargs):
        # x.shape: [B, L, d_model] with L == token_num
        values = self.v_proj(x)

        # token-mixing matrix A: [1, token_num, token_num]
        if self.CovMat is not None:
            A = F.softmax(self.CovMat, dim=-1) + F.softplus(self.weight_mat)
        else:
            A = F.softplus(self.weight_mat)

        A = F.normalize(A, p=1, dim=-1)  # entries are non-negative, so each row sums to 1
        A = self.dropout(A)

        new_x = A @ values  # + self.bias

        x = x + self.dropout(self.out_proj(new_x))
        x = self.norm1(x)

        # position-wise feed-forward implemented with 1x1 convolutions
        y = self.dropout(self.activation(self.conv1(x.transpose(-1, 1))))
        y = self.dropout(self.conv2(y).transpose(-1, 1))
        output = self.norm2(x + y)

        # second element is None to match the (output, attn) interface of attention layers
        return output, None
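
A minimal usage sketch (not part of this commit) of how the two modules compose: LinearEncoder blocks are stacked inside Encoder_ori in place of attention layers. The batch size, token count, model width, the random CovMat, and the models.OLinear.Trans_EncDec import path are illustrative assumptions, not values taken from the diff.

import torch
from models.OLinear.Trans_EncDec import Encoder_ori, LinearEncoder

# Illustrative sizes: 32 tokens (e.g. variables/series), each embedded into 64 dims.
B, token_num, d_model = 8, 32, 64

# Optional prior over token-token relationships (e.g. an empirical covariance
# matrix); shape [token_num, token_num]. Random here purely for illustration.
CovMat = torch.randn(token_num, token_num)

encoder = Encoder_ori(
    attn_layers=[
        LinearEncoder(d_model=d_model, CovMat=CovMat, dropout=0.1, token_num=token_num)
        for _ in range(2)
    ],
    norm_layer=torch.nn.LayerNorm(d_model),
)

x = torch.randn(B, token_num, d_model)
out, attns = encoder(x)   # attns is [None, None]: LinearEncoder returns no attention map
print(out.shape)          # torch.Size([8, 32, 64])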