feat(core): add initial TSModel package with OLinear and RevIN
This commit is contained in:
55
models/RevIN/model.py
Normal file
55
models/RevIN/model.py
Normal file
@ -0,0 +1,55 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class RevIN(nn.Module):
|
||||
def __init__(self, num_features: int, eps=1e-5, affine=True):
|
||||
"""
|
||||
RevIN通过对输入层数据进行标准化并进行线性变换,并在输出层进行参数相同的反标准化得到最终输出。可作为可插入模块用于多种时序神经网络中。
|
||||
RevIN standardizes input layer data and performs linear transformation, then applies denormalization with the same parameters at the output layer to obtain the final output. It can serve as a plug-in module for various time series neural networks.
|
||||
|
||||
:param num_features: the number of features or channels
|
||||
:param eps: a value added for numerical stability
|
||||
:param affine: if True, RevIN has learnable affine parameters
|
||||
"""
|
||||
super(RevIN, self).__init__()
|
||||
self.num_features = num_features
|
||||
self.eps = eps
|
||||
self.affine = affine
|
||||
if self.affine:
|
||||
self._init_params()
|
||||
|
||||
def forward(self, x, mode:str):
|
||||
if mode == 'norm':
|
||||
self._get_statistics(x)
|
||||
x = self._normalize(x)
|
||||
elif mode == 'denorm':
|
||||
x = self._denormalize(x)
|
||||
else: raise NotImplementedError
|
||||
return x
|
||||
|
||||
def _init_params(self):
|
||||
# initialize RevIN params: (C,)
|
||||
self.affine_weight = nn.Parameter(torch.ones(self.num_features))
|
||||
self.affine_bias = nn.Parameter(torch.zeros(self.num_features))
|
||||
|
||||
def _get_statistics(self, x):
|
||||
dim2reduce = tuple(range(1, x.ndim-1))
|
||||
self.mean = torch.mean(x, dim=dim2reduce, keepdim=True).detach()
|
||||
self.stdev = torch.sqrt(torch.var(x, dim=dim2reduce, keepdim=True, unbiased=False) + self.eps).detach()
|
||||
|
||||
def _normalize(self, x):
|
||||
x = x - self.mean
|
||||
x = x / self.stdev
|
||||
if self.affine:
|
||||
x = x * self.affine_weight
|
||||
x = x + self.affine_bias
|
||||
return x
|
||||
|
||||
def _denormalize(self, x):
|
||||
if self.affine:
|
||||
x = x - self.affine_bias
|
||||
x = x / (self.affine_weight + self.eps*self.eps)
|
||||
x = x * self.stdev
|
||||
x = x + self.mean
|
||||
return x
|
Reference in New Issue
Block a user