feat(core): add initial TSModel package with OLinear and RevIN
__init__.py (new file, 10 lines)
@@ -0,0 +1,10 @@
"""
TSModel - Time Series Modeling Package
"""

__version__ = "0.1.0"

# Import the main submodules for convenient external use
from .models import OLinear, RevIN

__all__ = ['OLinear', 'RevIN']
models/OLinear/Trans_EncDec.py (new file, 87 lines)
@@ -0,0 +1,87 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List


class Encoder_ori(nn.Module):
    def __init__(self, attn_layers, conv_layers=None, norm_layer=None, one_output=False, **kwargs):
        # **kwargs absorbs extra config flags passed by callers (e.g., CKA_flag) that are unused here
        super(Encoder_ori, self).__init__()
        self.attn_layers = nn.ModuleList(attn_layers)
        self.norm = norm_layer
        self.one_output = one_output

    def forward(self, x, attn_mask=None, tau=None, delta=None):
        # x [B, nvars, D]
        attns = []
        X0 = None  # to make PyCharm happy
        layer_len = len(self.attn_layers)
        for i, attn_layer in enumerate(self.attn_layers):
            x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta)
            attns.append(attn)

            if not self.training and layer_len > 1:
                if i == 0:
                    X0 = x

        if isinstance(x, tuple) or isinstance(x, List):
            x = x[0]

        if self.norm is not None:
            x = self.norm(x)

        if self.one_output:
            return x
        else:
            return x, attns


class LinearEncoder(nn.Module):
    def __init__(self, d_model, d_ff=None, CovMat=None, dropout=0.1, activation="relu", token_num=None, **kwargs):
        super(LinearEncoder, self).__init__()

        d_ff = d_ff or 4 * d_model
        self.d_model = d_model
        self.d_ff = d_ff
        self.CovMat = CovMat.unsqueeze(0) if CovMat is not None else None
        self.token_num = token_num

        self.norm1 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

        # attention --> linear
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

        init_weight_mat = torch.eye(self.token_num) * 1.0 + torch.randn(self.token_num, self.token_num) * 1.0
        self.weight_mat = nn.Parameter(init_weight_mat[None, :, :])

        # self.bias = nn.Parameter(torch.zeros(1, 1, self.d_model))

        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
        self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)
        self.activation = F.relu if activation == "relu" else F.gelu
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x, **kwargs):
        # x.shape: b, l, d_model
        values = self.v_proj(x)

        if self.CovMat is not None:
            A = F.softmax(self.CovMat, dim=-1) + F.softplus(self.weight_mat)
        else:
            A = F.softplus(self.weight_mat)

        A = F.normalize(A, p=1, dim=-1)
        A = self.dropout(A)

        new_x = A @ values  # + self.bias

        x = x + self.dropout(self.out_proj(new_x))
        x = self.norm1(x)

        y = self.dropout(self.activation(self.conv1(x.transpose(-1, 1))))
        y = self.dropout(self.conv2(y).transpose(-1, 1))
        output = self.norm2(x + y)

        return output, None
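Review note: a quick shape check for the encoder stack above (a minimal sketch; the batch size, d_model, and token_num are illustrative values, not from this commit):

    import torch
    import torch.nn as nn
    from models.OLinear.Trans_EncDec import Encoder_ori, LinearEncoder

    d_model, token_num = 64, 7  # hypothetical sizes
    enc = Encoder_ori(
        [LinearEncoder(d_model=d_model, token_num=token_num) for _ in range(2)],
        norm_layer=nn.LayerNorm(d_model),
        one_output=True,
    )
    x = torch.randn(8, token_num, d_model)  # [B, nvars, D], as in the forward() comment
    y = enc(x)
    assert y.shape == (8, token_num, d_model)  # shape is preserved through the stack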
models/OLinear/__init__.py (new file, 8 lines)
@@ -0,0 +1,8 @@
"""
OLinear model implementation
"""

from .model import *  # import the classes and functions defined in model.py

# If specific names should be exported, list them explicitly
__all__ = ['OLinear']  # adjust to the actual class names
models/OLinear/model.py (new file, 150 lines)
@@ -0,0 +1,150 @@
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from ..RevIN import RevIN
from .Trans_EncDec import Encoder_ori, LinearEncoder


class OLinear(nn.Module):
    def __init__(self, configs):
        super(OLinear, self).__init__()
        self.pred_len = configs.pred_len
        self.enc_in = configs.enc_in  # channels
        self.seq_len = configs.seq_len
        self.hidden_size = self.d_model = configs.d_model
        self.d_ff = configs.d_ff

        self.Q_chan_indep = configs.Q_chan_indep

        q_mat_path = configs.Q_MAT_file if self.Q_chan_indep else configs.q_mat_file
        if not os.path.isfile(q_mat_path):
            q_mat_path = os.path.join(configs.root_path, q_mat_path)
        assert os.path.isfile(q_mat_path)
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.Q_mat = torch.from_numpy(np.load(q_mat_path)).to(torch.float32).to(device)

        assert (self.Q_mat.ndim == 3 if self.Q_chan_indep else self.Q_mat.ndim == 2)
        assert (self.Q_mat.shape[0] == self.enc_in if self.Q_chan_indep else self.Q_mat.shape[0] == self.seq_len)

        q_out_mat_path = configs.Q_OUT_MAT_file if self.Q_chan_indep else configs.q_out_mat_file
        if not os.path.isfile(q_out_mat_path):
            q_out_mat_path = os.path.join(configs.root_path, q_out_mat_path)
        assert os.path.isfile(q_out_mat_path)
        self.Q_out_mat = torch.from_numpy(np.load(q_out_mat_path)).to(torch.float32).to(device)

        assert (self.Q_out_mat.ndim == 3 if self.Q_chan_indep else self.Q_out_mat.ndim == 2)
        assert (self.Q_out_mat.shape[0] == self.enc_in if self.Q_chan_indep else
                self.Q_out_mat.shape[0] == self.pred_len)

        self.patch_len = configs.temp_patch_len
        self.stride = configs.temp_stride

        # self.channel_independence = configs.channel_independence
        self.embed_size = configs.embed_size
        self.embeddings = nn.Parameter(torch.randn(1, self.embed_size))

        self.fc = nn.Sequential(
            nn.Linear(self.pred_len * self.embed_size, self.d_ff),
            nn.GELU(),
            nn.Linear(self.d_ff, self.pred_len)
        )

        # for final input and output
        self.revin_layer = RevIN(self.enc_in, affine=True)
        self.dropout = nn.Dropout(configs.dropout)

        # ############# transformer related #########
        self.encoder = Encoder_ori(
            [
                LinearEncoder(
                    d_model=configs.d_model, d_ff=configs.d_ff, CovMat=None,
                    dropout=configs.dropout, activation=configs.activation, token_num=self.enc_in,
                ) for _ in range(configs.e_layers)
            ],
            norm_layer=nn.LayerNorm(configs.d_model),
            one_output=True,
            CKA_flag=configs.CKA_flag  # absorbed by Encoder_ori's **kwargs; unused there
        )
        self.ortho_trans = nn.Sequential(
            nn.Linear(self.seq_len * self.embed_size, self.d_model),
            self.encoder,
            nn.Linear(self.d_model, self.pred_len * self.embed_size)
        )

        # learnable delta
        self.delta1 = nn.Parameter(torch.zeros(1, self.enc_in, 1, self.seq_len))
        self.delta2 = nn.Parameter(torch.zeros(1, self.enc_in, 1, self.pred_len))

    # dimension extension
    def tokenEmb(self, x, embeddings):
        if self.embed_size <= 1:
            return x.transpose(-1, -2).unsqueeze(-1)
        # x: [B, T, N] --> [B, N, T]
        x = x.transpose(-1, -2)
        x = x.unsqueeze(-1)
        # B*N*T*1 x 1*D = B*N*T*D
        return x * embeddings

    def Fre_Trans(self, x):
        # [B, N, T, D]
        B, N, T, D = x.shape
        assert T == self.seq_len
        # [B, N, D, T]
        x = x.transpose(-1, -2)

        # orthogonal transformation
        # [B, N, D, T]
        if self.Q_chan_indep:
            x_trans = torch.einsum('bndt,ntv->bndv', x, self.Q_mat.transpose(-1, -2))
        else:
            x_trans = torch.einsum('bndt,tv->bndv', x, self.Q_mat.transpose(-1, -2)) + self.delta1
        # added on 25/1/30
        # x_trans = F.gelu(x_trans)
        # [B, N, D, T]
        assert x_trans.shape[-1] == self.seq_len

        # ########## transformer ####
        x_trans = self.ortho_trans(x_trans.flatten(-2)).reshape(B, N, D, self.pred_len)

        # [B, N, D, tau]; orthogonal transformation
        if self.Q_chan_indep:
            x = torch.einsum('bndt,ntv->bndv', x_trans, self.Q_out_mat)
        else:
            x = torch.einsum('bndt,tv->bndv', x_trans, self.Q_out_mat) + self.delta2
        # added on 25/1/30
        # x = F.gelu(x)

        # [B, N, tau, D]
        x = x.transpose(-1, -2)
        return x

    def forward(self, x, x_mark_enc=None, x_dec=None, x_mark_dec=None, mask=None):
        # x: [Batch, Input length, Channel]
        B, T, N = x.shape

        # revin norm
        x = self.revin_layer(x, mode='norm')
        x_ori = x

        # ########### frequency (high-level) part ##########
        # input fre fine-tuning
        # [B, T, N]
        # embedding x: [B, N, T, D]
        x = self.tokenEmb(x_ori, self.embeddings)
        # [B, N, tau, D]
        x = self.Fre_Trans(x)

        # linear
        # [B, N, tau*D] --> [B, N, dim] --> [B, N, tau] --> [B, tau, N]
        out = self.fc(x.flatten(-2)).transpose(-1, -2)

        # dropout
        out = self.dropout(out)

        # revin denorm
        out = self.revin_layer(out, mode='denorm')

        return out
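Review note: an end-to-end smoke test for OLinear (a sketch under assumptions: SimpleNamespace stands in for the real config object, and identity matrices stand in for the precomputed orthogonal Q matrices the model loads from .npy files):

    import numpy as np
    import torch
    from types import SimpleNamespace
    from models.OLinear import OLinear

    seq_len, pred_len, n_ch = 96, 24, 7  # hypothetical sizes
    np.save('q_in.npy', np.eye(seq_len, dtype=np.float32))    # stand-in orthogonal basis
    np.save('q_out.npy', np.eye(pred_len, dtype=np.float32))  # stand-in orthogonal basis
    configs = SimpleNamespace(
        seq_len=seq_len, pred_len=pred_len, enc_in=n_ch, d_model=64, d_ff=128,
        Q_chan_indep=False, q_mat_file='q_in.npy', q_out_mat_file='q_out.npy',
        root_path='.', temp_patch_len=16, temp_stride=8, embed_size=4,
        dropout=0.1, activation='gelu', e_layers=1, CKA_flag=False,
    )
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = OLinear(configs).to(device)  # match the device the Q matrices were moved to
    out = model(torch.randn(2, seq_len, n_ch, device=device))
    assert out.shape == (2, pred_len, n_ch)  # [Batch, pred_len, Channel]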
models/RevIN/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
from .model import *  # relative import so the package resolves its own model.py
models/RevIN/model.py (new file, 55 lines)
@@ -0,0 +1,55 @@
import torch
import torch.nn as nn


class RevIN(nn.Module):
    def __init__(self, num_features: int, eps=1e-5, affine=True):
        """
        RevIN standardizes the data at the input layer and applies a linear transformation, then denormalizes at the output layer with the same parameters to obtain the final output. It can serve as a plug-in module for a wide range of time series neural networks.

        :param num_features: the number of features or channels
        :param eps: a value added for numerical stability
        :param affine: if True, RevIN has learnable affine parameters
        """
        super(RevIN, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.affine = affine
        if self.affine:
            self._init_params()

    def forward(self, x, mode: str):
        if mode == 'norm':
            self._get_statistics(x)
            x = self._normalize(x)
        elif mode == 'denorm':
            x = self._denormalize(x)
        else:
            raise NotImplementedError
        return x

    def _init_params(self):
        # initialize RevIN params: (C,)
        self.affine_weight = nn.Parameter(torch.ones(self.num_features))
        self.affine_bias = nn.Parameter(torch.zeros(self.num_features))

    def _get_statistics(self, x):
        dim2reduce = tuple(range(1, x.ndim - 1))
        self.mean = torch.mean(x, dim=dim2reduce, keepdim=True).detach()
        self.stdev = torch.sqrt(torch.var(x, dim=dim2reduce, keepdim=True, unbiased=False) + self.eps).detach()

    def _normalize(self, x):
        x = x - self.mean
        x = x / self.stdev
        if self.affine:
            x = x * self.affine_weight
            x = x + self.affine_bias
        return x

    def _denormalize(self, x):
        if self.affine:
            x = x - self.affine_bias
            x = x / (self.affine_weight + self.eps * self.eps)
        x = x * self.stdev
        x = x + self.mean
        return x
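Review note: a minimal round-trip check for RevIN (illustrative shapes; with freshly initialized affine parameters the denormalization inverts the normalization up to eps):

    import torch
    from models.RevIN import RevIN

    revin = RevIN(num_features=7)
    x = torch.randn(4, 96, 7)               # [batch, time, channels]
    x_norm = revin(x, mode='norm')           # standardize per instance, store mean/stdev
    x_back = revin(x_norm, mode='denorm')    # invert using the stored statistics
    assert torch.allclose(x, x_back, atol=1e-4)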
models/__init__.py (new file, 8 lines)
@@ -0,0 +1,8 @@
"""
Models module for time series forecasting
"""

from . import OLinear
from . import RevIN

__all__ = ['OLinear', 'RevIN']
setup.py (new file, 16 lines)
@@ -0,0 +1,16 @@
from setuptools import setup, find_packages

setup(
    name="tsmodel",
    version="0.1.0",
    packages=find_packages(),
    install_requires=[
        # add your dependency packages here
        "numpy",
        "torch",
        # other dependencies...
    ],
    author="Gameloader",
    description="Time Series Modeling Package",
    python_requires=">=3.10",
)
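Review note: after an editable install, a quick import smoke test (assuming the layout above, where find_packages() picks up the models package and its subpackages):

    # run `pip install -e .` first, then:
    from models.OLinear import OLinear
    from models.RevIN import RevIN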