commit dc8c9f1f09162368445f749600468792e73882e9 Author: gameloader Date: Tue Jul 1 21:00:23 2025 +0800 feat(core): add initial TSModel package with OLinear and RevIN diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000..6b4981b --- /dev/null +++ b/__init__.py @@ -0,0 +1,10 @@ +""" +TSModel - Time Series Modeling Package +""" + +__version__ = "0.1.0" + +# 导入主要模块,方便外部使用 +from .models import OLinear, RevIN + +__all__ = ['OLinear', 'RevIN'] diff --git a/models/OLinear/Trans_EncDec.py b/models/OLinear/Trans_EncDec.py new file mode 100644 index 0000000..c29d0f3 --- /dev/null +++ b/models/OLinear/Trans_EncDec.py @@ -0,0 +1,87 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import List + + +class Encoder_ori(nn.Module): + def __init__(self, attn_layers, conv_layers=None, norm_layer=None, one_output=False): + super(Encoder_ori, self).__init__() + self.attn_layers = nn.ModuleList(attn_layers) + self.norm = norm_layer + self.one_output = one_output + + def forward(self, x, attn_mask=None, tau=None, delta=None): + # x [B, nvars, D] + attns = [] + X0 = None # to make Pycharm happy + layer_len = len(self.attn_layers) + for i, attn_layer in enumerate(self.attn_layers): + x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta) + attns.append(attn) + + if not self.training and layer_len > 1: + if i == 0: + X0 = x + + if isinstance(x, tuple) or isinstance(x, List): + x = x[0] + + if self.norm is not None: + x = self.norm(x) + + if self.one_output: + return x + else: + return x, attns + + +class LinearEncoder(nn.Module): + def __init__(self, d_model, d_ff=None, CovMat=None, dropout=0.1, activation="relu", token_num=None, **kwargs): + super(LinearEncoder, self).__init__() + + d_ff = d_ff or 4 * d_model + self.d_model = d_model + self.d_ff = d_ff + self.CovMat = CovMat.unsqueeze(0) if CovMat is not None else None + self.token_num = token_num + + self.norm1 = nn.LayerNorm(d_model) + self.dropout = nn.Dropout(dropout) + + # attention --> linear + self.v_proj = nn.Linear(d_model, d_model) + self.out_proj = nn.Linear(d_model, d_model) + + init_weight_mat = torch.eye(self.token_num) * 1.0 + torch.randn(self.token_num, self.token_num) * 1.0 + self.weight_mat = nn.Parameter(init_weight_mat[None, :, :]) + + # self.bias = nn.Parameter(torch.zeros(1, 1, self.d_model)) + + self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1) + self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1) + self.activation = F.relu if activation == "relu" else F.gelu + self.norm2 = nn.LayerNorm(d_model) + + def forward(self, x, **kwargs): + # x.shape: b, l, d_model + values = self.v_proj(x) + + if self.CovMat is not None: + A = F.softmax(self.CovMat, dim=-1) + F.softplus(self.weight_mat) + else: + A = F.softplus(self.weight_mat) + + A = F.normalize(A, p=1, dim=-1) + A = self.dropout(A) + + new_x = A @ values # + self.bias + + x = x + self.dropout(self.out_proj(new_x)) + x = self.norm1(x) + + y = self.dropout(self.activation(self.conv1(x.transpose(-1, 1)))) + y = self.dropout(self.conv2(y).transpose(-1, 1)) + output = self.norm2(x + y) + + return output, None diff --git a/models/OLinear/__init__.py b/models/OLinear/__init__.py new file mode 100644 index 0000000..882a35a --- /dev/null +++ b/models/OLinear/__init__.py @@ -0,0 +1,8 @@ +""" +OLinear model implementation +""" + +from .model import * # 导入model.py中的类和函数 + +# 如果有特定的类名,可以明确指定 +__all__ = ['OLinear'] # 根据实际的类名调整 diff --git a/models/OLinear/model.py b/models/OLinear/model.py new file mode 100644 index 0000000..1974e97 --- /dev/null +++ b/models/OLinear/model.py @@ -0,0 +1,150 @@ +import os +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from RevIN import RevIN +from Trans_EncDec import Encoder_ori, LinearEncoder + + + +class OLinear(nn.Module): + def __init__(self, configs): + super(OLinear, self).__init__() + self.pred_len = configs.pred_len + self.enc_in = configs.enc_in # channels + self.seq_len = configs.seq_len + self.hidden_size = self.d_model = configs.d_model # hidden_size + self.d_ff = configs.d_ff # d_ff + + self.Q_chan_indep = configs.Q_chan_indep + + q_mat_dir = configs.Q_MAT_file if self.Q_chan_indep else configs.q_mat_file + if not os.path.isfile(q_mat_dir): + q_mat_dir = os.path.join(configs.root_path, q_mat_dir) + assert os.path.isfile(q_mat_dir) + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + self.Q_mat = torch.from_numpy(np.load(q_mat_dir)).to(torch.float32).to(device) + + assert (self.Q_mat.ndim == 3 if self.Q_chan_indep else self.Q_mat.ndim == 2) + assert (self.Q_mat.shape[0] == self.enc_in if self.Q_chan_indep else self.Q_mat.shape[0] == self.seq_len) + + q_out_mat_dir = configs.Q_OUT_MAT_file if self.Q_chan_indep else configs.q_out_mat_file + if not os.path.isfile(q_out_mat_dir): + q_out_mat_dir = os.path.join(configs.root_path, q_out_mat_dir) + assert os.path.isfile(q_out_mat_dir) + self.Q_out_mat = torch.from_numpy(np.load(q_out_mat_dir)).to(torch.float32).to(device) + + assert (self.Q_out_mat.ndim == 3 if self.Q_chan_indep else self.Q_out_mat.ndim == 2) + assert (self.Q_out_mat.shape[0] == self.enc_in if self.Q_chan_indep else + self.Q_out_mat.shape[0] == self.pred_len) + + self.patch_len = configs.temp_patch_len + self.stride = configs.temp_stride + + # self.channel_independence = configs.channel_independence + self.embed_size = configs.embed_size # embed_size + self.embeddings = nn.Parameter(torch.randn(1, self.embed_size)) + + self.fc = nn.Sequential( + nn.Linear(self.pred_len * self.embed_size, self.d_ff), + nn.GELU(), + nn.Linear(self.d_ff, self.pred_len) + ) + + # for final input and output + self.revin_layer = RevIN(self.enc_in, affine=True) + self.dropout = nn.Dropout(configs.dropout) + + # ############# transformer related ######### + self.encoder = Encoder_ori( + [ + LinearEncoder( + d_model=configs.d_model, d_ff=configs.d_ff, CovMat=None, + dropout=configs.dropout, activation=configs.activation, token_num=self.enc_in, + ) for _ in range(configs.e_layers) + ], + norm_layer=nn.LayerNorm(configs.d_model), + one_output=True, + CKA_flag=configs.CKA_flag + ) + self.ortho_trans = nn.Sequential( + nn.Linear(self.seq_len * self.embed_size, self.d_model), + self.encoder, + nn.Linear(self.d_model, self.pred_len * self.embed_size) + ) + + # learnable delta + self.delta1 = nn.Parameter(torch.zeros(1, self.enc_in, 1, self.seq_len)) + self.delta2 = nn.Parameter(torch.zeros(1, self.enc_in, 1, self.pred_len)) + + # dimension extension + def tokenEmb(self, x, embeddings): + if self.embed_size <= 1: + return x.transpose(-1, -2).unsqueeze(-1) + # x: [B, T, N] --> [B, N, T] + x = x.transpose(-1, -2) + x = x.unsqueeze(-1) + # B*N*T*1 x 1*D = B*N*T*D + return x * embeddings + + def Fre_Trans(self, x): + # [B, N, T, D] + B, N, T, D = x.shape + assert T == self.seq_len + # [B, N, D, T] + x = x.transpose(-1, -2) + + # orthogonal transformation + # [B, N, D, T] + if self.Q_chan_indep: + x_trans = torch.einsum('bndt,ntv->bndv', x, self.Q_mat.transpose(-1, -2)) + else: + x_trans = torch.einsum('bndt,tv->bndv', x, self.Q_mat.transpose(-1, -2)) + self.delta1 + # added on 25/1/30 + # x_trans = F.gelu(x_trans) + # [B, N, D, T] + assert x_trans.shape[-1] == self.seq_len + + # ########## transformer #### + x_trans = self.ortho_trans(x_trans.flatten(-2)).reshape(B, N, D, self.pred_len) + + # [B, N, D, tau]; orthogonal transformation + if self.Q_chan_indep: + x = torch.einsum('bndt,ntv->bndv', x_trans, self.Q_out_mat) + else: + x = torch.einsum('bndt,tv->bndv', x_trans, self.Q_out_mat) + self.delta2 + # added on 25/1/30 + # x = F.gelu(x) + + # [B, N, tau, D] + x = x.transpose(-1, -2) + return x + + def forward(self, x, x_mark_enc=None, x_dec=None, x_mark_dec=None, mask=None): + # x: [Batch, Input length, Channel] + B, T, N = x.shape + + # revin norm + x = self.revin_layer(x, mode='norm') + x_ori = x + + # ########### frequency (high-level) part ########## + # input fre fine-tuning + # [B, T, N] + # embedding x: [B, N, T, D] + x = self.tokenEmb(x_ori, self.embeddings) + # [B, N, tau, D] + x = self.Fre_Trans(x) + + # linear + # [B, N, tau*D] --> [B, N, dim] --> [B, N, tau] --> [B, tau, N] + out = self.fc(x.flatten(-2)).transpose(-1, -2) + + # dropout + out = self.dropout(out) + + # revin denorm + out = self.revin_layer(out, mode='denorm') + + return out diff --git a/models/RevIN/__init__.py b/models/RevIN/__init__.py new file mode 100644 index 0000000..4a10a05 --- /dev/null +++ b/models/RevIN/__init__.py @@ -0,0 +1 @@ +from model import * diff --git a/models/RevIN/model.py b/models/RevIN/model.py new file mode 100644 index 0000000..4de724c --- /dev/null +++ b/models/RevIN/model.py @@ -0,0 +1,55 @@ +import torch +import torch.nn as nn + + +class RevIN(nn.Module): + def __init__(self, num_features: int, eps=1e-5, affine=True): + """ + RevIN通过对输入层数据进行标准化并进行线性变换,并在输出层进行参数相同的反标准化得到最终输出。可作为可插入模块用于多种时序神经网络中。 + RevIN standardizes input layer data and performs linear transformation, then applies denormalization with the same parameters at the output layer to obtain the final output. It can serve as a plug-in module for various time series neural networks. + + :param num_features: the number of features or channels + :param eps: a value added for numerical stability + :param affine: if True, RevIN has learnable affine parameters + """ + super(RevIN, self).__init__() + self.num_features = num_features + self.eps = eps + self.affine = affine + if self.affine: + self._init_params() + + def forward(self, x, mode:str): + if mode == 'norm': + self._get_statistics(x) + x = self._normalize(x) + elif mode == 'denorm': + x = self._denormalize(x) + else: raise NotImplementedError + return x + + def _init_params(self): + # initialize RevIN params: (C,) + self.affine_weight = nn.Parameter(torch.ones(self.num_features)) + self.affine_bias = nn.Parameter(torch.zeros(self.num_features)) + + def _get_statistics(self, x): + dim2reduce = tuple(range(1, x.ndim-1)) + self.mean = torch.mean(x, dim=dim2reduce, keepdim=True).detach() + self.stdev = torch.sqrt(torch.var(x, dim=dim2reduce, keepdim=True, unbiased=False) + self.eps).detach() + + def _normalize(self, x): + x = x - self.mean + x = x / self.stdev + if self.affine: + x = x * self.affine_weight + x = x + self.affine_bias + return x + + def _denormalize(self, x): + if self.affine: + x = x - self.affine_bias + x = x / (self.affine_weight + self.eps*self.eps) + x = x * self.stdev + x = x + self.mean + return x diff --git a/models/__init__.py b/models/__init__.py new file mode 100644 index 0000000..08b9749 --- /dev/null +++ b/models/__init__.py @@ -0,0 +1,8 @@ +""" +Models module for time series forecasting +""" + +from . import OLinear +from . import RevIN + +__all__ = ['OLinear', 'RevIN'] diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..938223c --- /dev/null +++ b/setup.py @@ -0,0 +1,16 @@ +from setuptools import setup, find_packages + +setup( + name="tsmodel", + version="0.1.0", + packages=find_packages(), + install_requires=[ + # 添加你的依赖包 + "numpy", + "torch", + # 其他依赖... + ], + author="Gameloader", + description="Time Series Modeling Package", + python_requires=">=3.10", +)