feat(core): add initial TSModel package with OLinear and RevIN
This commit is contained in:
150
models/OLinear/model.py
Normal file
150
models/OLinear/model.py
Normal file
@ -0,0 +1,150 @@
|
||||
import os
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
from RevIN import RevIN
|
||||
from Trans_EncDec import Encoder_ori, LinearEncoder
|
||||
|
||||
|
||||
|
||||
class OLinear(nn.Module):
|
||||
def __init__(self, configs):
|
||||
super(OLinear, self).__init__()
|
||||
self.pred_len = configs.pred_len
|
||||
self.enc_in = configs.enc_in # channels
|
||||
self.seq_len = configs.seq_len
|
||||
self.hidden_size = self.d_model = configs.d_model # hidden_size
|
||||
self.d_ff = configs.d_ff # d_ff
|
||||
|
||||
self.Q_chan_indep = configs.Q_chan_indep
|
||||
|
||||
q_mat_dir = configs.Q_MAT_file if self.Q_chan_indep else configs.q_mat_file
|
||||
if not os.path.isfile(q_mat_dir):
|
||||
q_mat_dir = os.path.join(configs.root_path, q_mat_dir)
|
||||
assert os.path.isfile(q_mat_dir)
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
self.Q_mat = torch.from_numpy(np.load(q_mat_dir)).to(torch.float32).to(device)
|
||||
|
||||
assert (self.Q_mat.ndim == 3 if self.Q_chan_indep else self.Q_mat.ndim == 2)
|
||||
assert (self.Q_mat.shape[0] == self.enc_in if self.Q_chan_indep else self.Q_mat.shape[0] == self.seq_len)
|
||||
|
||||
q_out_mat_dir = configs.Q_OUT_MAT_file if self.Q_chan_indep else configs.q_out_mat_file
|
||||
if not os.path.isfile(q_out_mat_dir):
|
||||
q_out_mat_dir = os.path.join(configs.root_path, q_out_mat_dir)
|
||||
assert os.path.isfile(q_out_mat_dir)
|
||||
self.Q_out_mat = torch.from_numpy(np.load(q_out_mat_dir)).to(torch.float32).to(device)
|
||||
|
||||
assert (self.Q_out_mat.ndim == 3 if self.Q_chan_indep else self.Q_out_mat.ndim == 2)
|
||||
assert (self.Q_out_mat.shape[0] == self.enc_in if self.Q_chan_indep else
|
||||
self.Q_out_mat.shape[0] == self.pred_len)
|
||||
|
||||
self.patch_len = configs.temp_patch_len
|
||||
self.stride = configs.temp_stride
|
||||
|
||||
# self.channel_independence = configs.channel_independence
|
||||
self.embed_size = configs.embed_size # embed_size
|
||||
self.embeddings = nn.Parameter(torch.randn(1, self.embed_size))
|
||||
|
||||
self.fc = nn.Sequential(
|
||||
nn.Linear(self.pred_len * self.embed_size, self.d_ff),
|
||||
nn.GELU(),
|
||||
nn.Linear(self.d_ff, self.pred_len)
|
||||
)
|
||||
|
||||
# for final input and output
|
||||
self.revin_layer = RevIN(self.enc_in, affine=True)
|
||||
self.dropout = nn.Dropout(configs.dropout)
|
||||
|
||||
# ############# transformer related #########
|
||||
self.encoder = Encoder_ori(
|
||||
[
|
||||
LinearEncoder(
|
||||
d_model=configs.d_model, d_ff=configs.d_ff, CovMat=None,
|
||||
dropout=configs.dropout, activation=configs.activation, token_num=self.enc_in,
|
||||
) for _ in range(configs.e_layers)
|
||||
],
|
||||
norm_layer=nn.LayerNorm(configs.d_model),
|
||||
one_output=True,
|
||||
CKA_flag=configs.CKA_flag
|
||||
)
|
||||
self.ortho_trans = nn.Sequential(
|
||||
nn.Linear(self.seq_len * self.embed_size, self.d_model),
|
||||
self.encoder,
|
||||
nn.Linear(self.d_model, self.pred_len * self.embed_size)
|
||||
)
|
||||
|
||||
# learnable delta
|
||||
self.delta1 = nn.Parameter(torch.zeros(1, self.enc_in, 1, self.seq_len))
|
||||
self.delta2 = nn.Parameter(torch.zeros(1, self.enc_in, 1, self.pred_len))
|
||||
|
||||
# dimension extension
|
||||
def tokenEmb(self, x, embeddings):
|
||||
if self.embed_size <= 1:
|
||||
return x.transpose(-1, -2).unsqueeze(-1)
|
||||
# x: [B, T, N] --> [B, N, T]
|
||||
x = x.transpose(-1, -2)
|
||||
x = x.unsqueeze(-1)
|
||||
# B*N*T*1 x 1*D = B*N*T*D
|
||||
return x * embeddings
|
||||
|
||||
def Fre_Trans(self, x):
|
||||
# [B, N, T, D]
|
||||
B, N, T, D = x.shape
|
||||
assert T == self.seq_len
|
||||
# [B, N, D, T]
|
||||
x = x.transpose(-1, -2)
|
||||
|
||||
# orthogonal transformation
|
||||
# [B, N, D, T]
|
||||
if self.Q_chan_indep:
|
||||
x_trans = torch.einsum('bndt,ntv->bndv', x, self.Q_mat.transpose(-1, -2))
|
||||
else:
|
||||
x_trans = torch.einsum('bndt,tv->bndv', x, self.Q_mat.transpose(-1, -2)) + self.delta1
|
||||
# added on 25/1/30
|
||||
# x_trans = F.gelu(x_trans)
|
||||
# [B, N, D, T]
|
||||
assert x_trans.shape[-1] == self.seq_len
|
||||
|
||||
# ########## transformer ####
|
||||
x_trans = self.ortho_trans(x_trans.flatten(-2)).reshape(B, N, D, self.pred_len)
|
||||
|
||||
# [B, N, D, tau]; orthogonal transformation
|
||||
if self.Q_chan_indep:
|
||||
x = torch.einsum('bndt,ntv->bndv', x_trans, self.Q_out_mat)
|
||||
else:
|
||||
x = torch.einsum('bndt,tv->bndv', x_trans, self.Q_out_mat) + self.delta2
|
||||
# added on 25/1/30
|
||||
# x = F.gelu(x)
|
||||
|
||||
# [B, N, tau, D]
|
||||
x = x.transpose(-1, -2)
|
||||
return x
|
||||
|
||||
def forward(self, x, x_mark_enc=None, x_dec=None, x_mark_dec=None, mask=None):
|
||||
# x: [Batch, Input length, Channel]
|
||||
B, T, N = x.shape
|
||||
|
||||
# revin norm
|
||||
x = self.revin_layer(x, mode='norm')
|
||||
x_ori = x
|
||||
|
||||
# ########### frequency (high-level) part ##########
|
||||
# input fre fine-tuning
|
||||
# [B, T, N]
|
||||
# embedding x: [B, N, T, D]
|
||||
x = self.tokenEmb(x_ori, self.embeddings)
|
||||
# [B, N, tau, D]
|
||||
x = self.Fre_Trans(x)
|
||||
|
||||
# linear
|
||||
# [B, N, tau*D] --> [B, N, dim] --> [B, N, tau] --> [B, tau, N]
|
||||
out = self.fc(x.flatten(-2)).transpose(-1, -2)
|
||||
|
||||
# dropout
|
||||
out = self.dropout(out)
|
||||
|
||||
# revin denorm
|
||||
out = self.revin_layer(out, mode='denorm')
|
||||
|
||||
return out
|
Reference in New Issue
Block a user