feat(core): add initial TSModel package with OLinear and RevIN

This commit is contained in:
gameloader 2025-07-01 21:00:23 +08:00
commit dc8c9f1f09
8 changed files with 335 additions and 0 deletions

10
__init__.py Normal file
View File

@ -0,0 +1,10 @@
"""
TSModel - Time Series Modeling Package
"""
__version__ = "0.1.0"
# 导入主要模块,方便外部使用
from .models import OLinear, RevIN
__all__ = ['OLinear', 'RevIN']

View File

@ -0,0 +1,87 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List
class Encoder_ori(nn.Module):
def __init__(self, attn_layers, conv_layers=None, norm_layer=None, one_output=False):
super(Encoder_ori, self).__init__()
self.attn_layers = nn.ModuleList(attn_layers)
self.norm = norm_layer
self.one_output = one_output
def forward(self, x, attn_mask=None, tau=None, delta=None):
# x [B, nvars, D]
attns = []
X0 = None # to make Pycharm happy
layer_len = len(self.attn_layers)
for i, attn_layer in enumerate(self.attn_layers):
x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta)
attns.append(attn)
if not self.training and layer_len > 1:
if i == 0:
X0 = x
if isinstance(x, tuple) or isinstance(x, List):
x = x[0]
if self.norm is not None:
x = self.norm(x)
if self.one_output:
return x
else:
return x, attns
class LinearEncoder(nn.Module):
def __init__(self, d_model, d_ff=None, CovMat=None, dropout=0.1, activation="relu", token_num=None, **kwargs):
super(LinearEncoder, self).__init__()
d_ff = d_ff or 4 * d_model
self.d_model = d_model
self.d_ff = d_ff
self.CovMat = CovMat.unsqueeze(0) if CovMat is not None else None
self.token_num = token_num
self.norm1 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
# attention --> linear
self.v_proj = nn.Linear(d_model, d_model)
self.out_proj = nn.Linear(d_model, d_model)
init_weight_mat = torch.eye(self.token_num) * 1.0 + torch.randn(self.token_num, self.token_num) * 1.0
self.weight_mat = nn.Parameter(init_weight_mat[None, :, :])
# self.bias = nn.Parameter(torch.zeros(1, 1, self.d_model))
self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)
self.activation = F.relu if activation == "relu" else F.gelu
self.norm2 = nn.LayerNorm(d_model)
def forward(self, x, **kwargs):
# x.shape: b, l, d_model
values = self.v_proj(x)
if self.CovMat is not None:
A = F.softmax(self.CovMat, dim=-1) + F.softplus(self.weight_mat)
else:
A = F.softplus(self.weight_mat)
A = F.normalize(A, p=1, dim=-1)
A = self.dropout(A)
new_x = A @ values # + self.bias
x = x + self.dropout(self.out_proj(new_x))
x = self.norm1(x)
y = self.dropout(self.activation(self.conv1(x.transpose(-1, 1))))
y = self.dropout(self.conv2(y).transpose(-1, 1))
output = self.norm2(x + y)
return output, None

View File

@ -0,0 +1,8 @@
"""
OLinear model implementation
"""
from .model import * # 导入model.py中的类和函数
# 如果有特定的类名,可以明确指定
__all__ = ['OLinear'] # 根据实际的类名调整

150
models/OLinear/model.py Normal file
View File

@ -0,0 +1,150 @@
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from RevIN import RevIN
from Trans_EncDec import Encoder_ori, LinearEncoder
class OLinear(nn.Module):
def __init__(self, configs):
super(OLinear, self).__init__()
self.pred_len = configs.pred_len
self.enc_in = configs.enc_in # channels
self.seq_len = configs.seq_len
self.hidden_size = self.d_model = configs.d_model # hidden_size
self.d_ff = configs.d_ff # d_ff
self.Q_chan_indep = configs.Q_chan_indep
q_mat_dir = configs.Q_MAT_file if self.Q_chan_indep else configs.q_mat_file
if not os.path.isfile(q_mat_dir):
q_mat_dir = os.path.join(configs.root_path, q_mat_dir)
assert os.path.isfile(q_mat_dir)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
self.Q_mat = torch.from_numpy(np.load(q_mat_dir)).to(torch.float32).to(device)
assert (self.Q_mat.ndim == 3 if self.Q_chan_indep else self.Q_mat.ndim == 2)
assert (self.Q_mat.shape[0] == self.enc_in if self.Q_chan_indep else self.Q_mat.shape[0] == self.seq_len)
q_out_mat_dir = configs.Q_OUT_MAT_file if self.Q_chan_indep else configs.q_out_mat_file
if not os.path.isfile(q_out_mat_dir):
q_out_mat_dir = os.path.join(configs.root_path, q_out_mat_dir)
assert os.path.isfile(q_out_mat_dir)
self.Q_out_mat = torch.from_numpy(np.load(q_out_mat_dir)).to(torch.float32).to(device)
assert (self.Q_out_mat.ndim == 3 if self.Q_chan_indep else self.Q_out_mat.ndim == 2)
assert (self.Q_out_mat.shape[0] == self.enc_in if self.Q_chan_indep else
self.Q_out_mat.shape[0] == self.pred_len)
self.patch_len = configs.temp_patch_len
self.stride = configs.temp_stride
# self.channel_independence = configs.channel_independence
self.embed_size = configs.embed_size # embed_size
self.embeddings = nn.Parameter(torch.randn(1, self.embed_size))
self.fc = nn.Sequential(
nn.Linear(self.pred_len * self.embed_size, self.d_ff),
nn.GELU(),
nn.Linear(self.d_ff, self.pred_len)
)
# for final input and output
self.revin_layer = RevIN(self.enc_in, affine=True)
self.dropout = nn.Dropout(configs.dropout)
# ############# transformer related #########
self.encoder = Encoder_ori(
[
LinearEncoder(
d_model=configs.d_model, d_ff=configs.d_ff, CovMat=None,
dropout=configs.dropout, activation=configs.activation, token_num=self.enc_in,
) for _ in range(configs.e_layers)
],
norm_layer=nn.LayerNorm(configs.d_model),
one_output=True,
CKA_flag=configs.CKA_flag
)
self.ortho_trans = nn.Sequential(
nn.Linear(self.seq_len * self.embed_size, self.d_model),
self.encoder,
nn.Linear(self.d_model, self.pred_len * self.embed_size)
)
# learnable delta
self.delta1 = nn.Parameter(torch.zeros(1, self.enc_in, 1, self.seq_len))
self.delta2 = nn.Parameter(torch.zeros(1, self.enc_in, 1, self.pred_len))
# dimension extension
def tokenEmb(self, x, embeddings):
if self.embed_size <= 1:
return x.transpose(-1, -2).unsqueeze(-1)
# x: [B, T, N] --> [B, N, T]
x = x.transpose(-1, -2)
x = x.unsqueeze(-1)
# B*N*T*1 x 1*D = B*N*T*D
return x * embeddings
def Fre_Trans(self, x):
# [B, N, T, D]
B, N, T, D = x.shape
assert T == self.seq_len
# [B, N, D, T]
x = x.transpose(-1, -2)
# orthogonal transformation
# [B, N, D, T]
if self.Q_chan_indep:
x_trans = torch.einsum('bndt,ntv->bndv', x, self.Q_mat.transpose(-1, -2))
else:
x_trans = torch.einsum('bndt,tv->bndv', x, self.Q_mat.transpose(-1, -2)) + self.delta1
# added on 25/1/30
# x_trans = F.gelu(x_trans)
# [B, N, D, T]
assert x_trans.shape[-1] == self.seq_len
# ########## transformer ####
x_trans = self.ortho_trans(x_trans.flatten(-2)).reshape(B, N, D, self.pred_len)
# [B, N, D, tau]; orthogonal transformation
if self.Q_chan_indep:
x = torch.einsum('bndt,ntv->bndv', x_trans, self.Q_out_mat)
else:
x = torch.einsum('bndt,tv->bndv', x_trans, self.Q_out_mat) + self.delta2
# added on 25/1/30
# x = F.gelu(x)
# [B, N, tau, D]
x = x.transpose(-1, -2)
return x
def forward(self, x, x_mark_enc=None, x_dec=None, x_mark_dec=None, mask=None):
# x: [Batch, Input length, Channel]
B, T, N = x.shape
# revin norm
x = self.revin_layer(x, mode='norm')
x_ori = x
# ########### frequency (high-level) part ##########
# input fre fine-tuning
# [B, T, N]
# embedding x: [B, N, T, D]
x = self.tokenEmb(x_ori, self.embeddings)
# [B, N, tau, D]
x = self.Fre_Trans(x)
# linear
# [B, N, tau*D] --> [B, N, dim] --> [B, N, tau] --> [B, tau, N]
out = self.fc(x.flatten(-2)).transpose(-1, -2)
# dropout
out = self.dropout(out)
# revin denorm
out = self.revin_layer(out, mode='denorm')
return out

1
models/RevIN/__init__.py Normal file
View File

@ -0,0 +1 @@
from model import *

55
models/RevIN/model.py Normal file
View File

@ -0,0 +1,55 @@
import torch
import torch.nn as nn
class RevIN(nn.Module):
def __init__(self, num_features: int, eps=1e-5, affine=True):
"""
RevIN通过对输入层数据进行标准化并进行线性变换并在输出层进行参数相同的反标准化得到最终输出可作为可插入模块用于多种时序神经网络中
RevIN standardizes input layer data and performs linear transformation, then applies denormalization with the same parameters at the output layer to obtain the final output. It can serve as a plug-in module for various time series neural networks.
:param num_features: the number of features or channels
:param eps: a value added for numerical stability
:param affine: if True, RevIN has learnable affine parameters
"""
super(RevIN, self).__init__()
self.num_features = num_features
self.eps = eps
self.affine = affine
if self.affine:
self._init_params()
def forward(self, x, mode:str):
if mode == 'norm':
self._get_statistics(x)
x = self._normalize(x)
elif mode == 'denorm':
x = self._denormalize(x)
else: raise NotImplementedError
return x
def _init_params(self):
# initialize RevIN params: (C,)
self.affine_weight = nn.Parameter(torch.ones(self.num_features))
self.affine_bias = nn.Parameter(torch.zeros(self.num_features))
def _get_statistics(self, x):
dim2reduce = tuple(range(1, x.ndim-1))
self.mean = torch.mean(x, dim=dim2reduce, keepdim=True).detach()
self.stdev = torch.sqrt(torch.var(x, dim=dim2reduce, keepdim=True, unbiased=False) + self.eps).detach()
def _normalize(self, x):
x = x - self.mean
x = x / self.stdev
if self.affine:
x = x * self.affine_weight
x = x + self.affine_bias
return x
def _denormalize(self, x):
if self.affine:
x = x - self.affine_bias
x = x / (self.affine_weight + self.eps*self.eps)
x = x * self.stdev
x = x + self.mean
return x

8
models/__init__.py Normal file
View File

@ -0,0 +1,8 @@
"""
Models module for time series forecasting
"""
from . import OLinear
from . import RevIN
__all__ = ['OLinear', 'RevIN']

16
setup.py Normal file
View File

@ -0,0 +1,16 @@
from setuptools import setup, find_packages
setup(
name="tsmodel",
version="0.1.0",
packages=find_packages(),
install_requires=[
# 添加你的依赖包
"numpy",
"torch",
# 其他依赖...
],
author="Gameloader",
description="Time Series Modeling Package",
python_requires=">=3.10",
)