import torch import torch.nn as nn class RevIN(nn.Module): def __init__(self, num_features: int, eps=1e-5, affine=True): """ RevIN通过对输入层数据进行标准化并进行线性变换,并在输出层进行参数相同的反标准化得到最终输出。可作为可插入模块用于多种时序神经网络中。 RevIN standardizes input layer data and performs linear transformation, then applies denormalization with the same parameters at the output layer to obtain the final output. It can serve as a plug-in module for various time series neural networks. :param num_features: the number of features or channels :param eps: a value added for numerical stability :param affine: if True, RevIN has learnable affine parameters """ super(RevIN, self).__init__() self.num_features = num_features self.eps = eps self.affine = affine if self.affine: self._init_params() def forward(self, x, mode:str): if mode == 'norm': self._get_statistics(x) x = self._normalize(x) elif mode == 'denorm': x = self._denormalize(x) else: raise NotImplementedError return x def _init_params(self): # initialize RevIN params: (C,) self.affine_weight = nn.Parameter(torch.ones(self.num_features)) self.affine_bias = nn.Parameter(torch.zeros(self.num_features)) def _get_statistics(self, x): dim2reduce = tuple(range(1, x.ndim-1)) self.mean = torch.mean(x, dim=dim2reduce, keepdim=True).detach() self.stdev = torch.sqrt(torch.var(x, dim=dim2reduce, keepdim=True, unbiased=False) + self.eps).detach() def _normalize(self, x): x = x - self.mean x = x / self.stdev if self.affine: x = x * self.affine_weight x = x + self.affine_bias return x def _denormalize(self, x): if self.affine: x = x - self.affine_bias x = x / (self.affine_weight + self.eps*self.eps) x = x * self.stdev x = x + self.mean return x