__all__ = ['Transpose', 'get_activation_fn', 'moving_avg', 'series_decomp', 'PositionalEncoding',
           'SinCosPosEncoding', 'Coord2dPosEncoding', 'Coord1dPosEncoding', 'positional_encoding']

import math

import torch
from torch import nn


class Transpose(nn.Module):
    """Module wrapper around ``Tensor.transpose`` so it can be used inside ``nn.Sequential``.

    Args:
        *dims: The two dimensions to swap, forwarded to ``x.transpose``.
        contiguous: If True, call ``.contiguous()`` on the transposed tensor.
    """

    def __init__(self, *dims, contiguous=False):
        super().__init__()
        self.dims, self.contiguous = dims, contiguous

    def forward(self, x):
        """Return ``x`` with ``self.dims`` swapped (optionally made contiguous)."""
        out = x.transpose(*self.dims)
        return out.contiguous() if self.contiguous else out


def get_activation_fn(activation):
    """Build an activation module from a name or a factory.

    Args:
        activation: Either a callable (it is called with no arguments and its
            result is returned) or one of the strings ``"relu"`` / ``"gelu"``
            (case-insensitive).

    Returns:
        The result of ``activation()`` for callables, otherwise a fresh
        ``nn.ReLU`` or ``nn.GELU`` instance.

    Raises:
        ValueError: If ``activation`` is neither callable nor a supported name.
    """
    if callable(activation):
        return activation()
    # Only strings have a meaningful name; anything else (e.g. None) falls
    # through to the ValueError below instead of failing on ``.lower()``.
    if isinstance(activation, str):
        name = activation.lower()
        if name == "relu":
            return nn.ReLU()
        if name == "gelu":
            return nn.GELU()
    raise ValueError(f'{activation} is not available. You can use "relu", "gelu", or a callable')


# decomposition

class moving_avg(nn.Module):
    """
    Moving average block to highlight the trend of time series

    The input is indexed as a 3-D tensor and the average is taken along
    dim 1 (assumed to be the time axis, i.e. ``(batch, time, channels)``).
    Both ends are padded by repeating the first / last step so that, with
    ``stride=1``, the output has the same length as the input for any
    ``kernel_size`` (odd or even).
    """

    def __init__(self, kernel_size, stride):
        super().__init__()
        self.kernel_size = kernel_size
        self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0)

    def forward(self, x):
        """Return the moving average of ``x`` along dim 1."""
        # Total padding must be kernel_size - 1 to keep the length at stride 1.
        # For odd kernels this is symmetric; for even kernels the extra step
        # goes to the front (symmetric padding would lose one time step).
        pad_end = (self.kernel_size - 1) // 2
        pad_front = self.kernel_size - 1 - pad_end
        front = x[:, 0:1, :].repeat(1, pad_front, 1)
        end = x[:, -1:, :].repeat(1, pad_end, 1)
        padded = torch.cat([front, x, end], dim=1)
        # AvgPool1d pools over the last dim, so move dim 1 there and back.
        pooled = self.avg(padded.permute(0, 2, 1))
        return pooled.permute(0, 2, 1)


class series_decomp(nn.Module):
    """
    Series decomposition block

    Splits the input into a moving-average component and the residual
    ``x - moving_average``.
    """

    def __init__(self, kernel_size):
        super().__init__()
        self.moving_avg = moving_avg(kernel_size, stride=1)

    def forward(self, x):
        """Return ``(residual, moving_mean)`` for ``x``."""
        moving_mean = self.moving_avg(x)
        res = x - moving_mean
        return res, moving_mean


# pos_encoding

def PositionalEncoding(q_len, d_model, normalize=True):
    """Build a sine/cosine positional encoding table.

    Even columns hold ``sin(position * div_term)`` and odd columns hold
    ``cos(position * div_term)``.

    Args:
        q_len: Number of positions (rows).
        d_model: Encoding width (columns); may be odd or even.
        normalize: If True, subtract the mean and divide by ``10 * std``.

    Returns:
        A ``(q_len, d_model)`` float tensor.
    """
    pe = torch.zeros(q_len, d_model)
    position = torch.arange(0, q_len).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
    angles = position * div_term
    pe[:, 0::2] = torch.sin(angles)
    # For odd d_model there is one fewer odd column than even columns, so
    # trim the angles to match the cosine slice.
    pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
    if normalize:
        pe = pe - pe.mean()
        pe = pe / (pe.std() * 10)
    return pe


SinCosPosEncoding = PositionalEncoding


def Coord2dPosEncoding(q_len, d_model, exponential=False, normalize=True, eps=1e-3, verbose=False):
    """Build a 2-D coordinate positional encoding.

    The table is ``2 * r**x * c**x - 1`` where ``r`` and ``c`` are evenly
    spaced values in ``[0, 1]`` along rows and columns. The exponent ``x`` is
    nudged in steps of 0.001 (at most 100 times) until the table mean is
    within ``eps`` of zero.

    Args:
        q_len: Number of rows.
        d_model: Number of columns.
        exponential: Start the search at exponent 0.5 instead of 1.
        normalize: If True, subtract the mean and divide by ``10 * std``.
        eps: Tolerance on the absolute mean that stops the search.
        verbose: Currently unused; kept for interface compatibility.

    Returns:
        A ``(q_len, d_model)`` float tensor.
    """
    x = .5 if exponential else 1
    # The coordinate grids do not depend on the exponent; build them once.
    rows = torch.linspace(0, 1, q_len).reshape(-1, 1)
    cols = torch.linspace(0, 1, d_model).reshape(1, -1)
    for _ in range(100):
        cpe = 2 * (rows ** x) * (cols ** x) - 1
        mean = cpe.mean()
        if abs(mean) <= eps:
            break
        elif mean > eps:
            x += .001
        else:
            x -= .001
    if normalize:
        cpe = cpe - cpe.mean()
        cpe = cpe / (cpe.std() * 10)
    return cpe


def Coord1dPosEncoding(q_len, exponential=False, normalize=True):
    """Build a 1-D coordinate positional encoding.

    Args:
        q_len: Number of positions.
        exponential: Use exponent 0.5 on the ``[0, 1]`` ramp instead of 1.
        normalize: If True, subtract the mean and divide by ``10 * std``.

    Returns:
        A ``(q_len, 1)`` float tensor.
    """
    power = .5 if exponential else 1
    cpe = (2 * (torch.linspace(0, 1, q_len).reshape(-1, 1) ** power) - 1)
    if normalize:
        cpe = cpe - cpe.mean()
        cpe = cpe / (cpe.std() * 10)
    return cpe


def positional_encoding(pe, learn_pe, q_len, d_model):
    """Create the positional-encoding parameter selected by ``pe``.

    Args:
        pe: Encoding type. One of ``None``, ``'zero'``, ``'zeros'``,
            ``'normal'``/``'gauss'``, ``'uniform'``, ``'lin1d'``, ``'exp1d'``,
            ``'lin2d'``, ``'exp2d'``, ``'sincos'``.
        learn_pe: Whether the returned parameter requires gradients. Forced
            to False when ``pe`` is None.
        q_len: Number of positions.
        d_model: Encoding width; the ``'zero'``, ``'normal'``/``'gauss'``,
            ``'uniform'``, ``'lin1d'`` and ``'exp1d'`` types produce a
            ``(q_len, 1)`` tensor and ignore it.

    Returns:
        An ``nn.Parameter`` of shape ``(q_len, d_model)`` or ``(q_len, 1)``.

    Raises:
        ValueError: If ``pe`` is not one of the supported types.
    """
    # Positional encoding
    if pe is None:
        # pe = None and learn_pe = False can be used to measure impact of pe
        W_pos = torch.empty((q_len, d_model))
        nn.init.uniform_(W_pos, -0.02, 0.02)
        learn_pe = False
    elif pe == 'zero':
        # Note: despite the name, this is a small uniform init, not zeros.
        W_pos = torch.empty((q_len, 1))
        nn.init.uniform_(W_pos, -0.02, 0.02)
    elif pe == 'zeros':
        # Note: despite the name, this is a small uniform init, not zeros.
        W_pos = torch.empty((q_len, d_model))
        nn.init.uniform_(W_pos, -0.02, 0.02)
    elif pe == 'normal' or pe == 'gauss':
        W_pos = torch.zeros((q_len, 1))
        torch.nn.init.normal_(W_pos, mean=0.0, std=0.1)
    elif pe == 'uniform':
        W_pos = torch.zeros((q_len, 1))
        nn.init.uniform_(W_pos, a=0.0, b=0.1)
    elif pe == 'lin1d':
        W_pos = Coord1dPosEncoding(q_len, exponential=False, normalize=True)
    elif pe == 'exp1d':
        W_pos = Coord1dPosEncoding(q_len, exponential=True, normalize=True)
    elif pe == 'lin2d':
        W_pos = Coord2dPosEncoding(q_len, d_model, exponential=False, normalize=True)
    elif pe == 'exp2d':
        W_pos = Coord2dPosEncoding(q_len, d_model, exponential=True, normalize=True)
    elif pe == 'sincos':
        W_pos = PositionalEncoding(q_len, d_model, normalize=True)
    else:
        raise ValueError(
            f"{pe} is not a valid pe (positional encoder). Available types: 'gauss'=='normal', "
            "'zeros', 'zero', 'uniform', 'lin1d', 'exp1d', 'lin2d', 'exp2d', 'sincos', None."
        )
    return nn.Parameter(W_pos, requires_grad=learn_pe)