import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class Splitting(nn.Module):
    # Split a [B, T, C] sequence into its even- and odd-indexed time steps.
    def __init__(self):
        super(Splitting, self).__init__()

    def even(self, x):
        return x[:, ::2, :]

    def odd(self, x):
        return x[:, 1::2, :]

    def forward(self, x):
        # Return the even and odd parts of the sequence.
        return self.even(x), self.odd(x)


class CausalConvBlock(nn.Module):
    # Two Conv1d layers over the time axis with replication padding,
    # so the output length equals the input length.
    def __init__(self, d_model, kernel_size=5, dropout=0.0):
        super(CausalConvBlock, self).__init__()
        module_list = [
            nn.ReplicationPad1d((kernel_size - 1, kernel_size - 1)),
            nn.Conv1d(d_model, d_model, kernel_size=kernel_size),
            nn.LeakyReLU(negative_slope=0.01, inplace=True),
            nn.Dropout(dropout),
            nn.Conv1d(d_model, d_model, kernel_size=kernel_size),
            nn.Tanh()
        ]
        self.causal_conv = nn.Sequential(*module_list)

    def forward(self, x):
        return self.causal_conv(x)  # output has the same shape as the input


class SCIBlock(nn.Module):
    # One level of sample convolution and interaction: split the sequence,
    # scale each half by exp(conv(other half)), then exchange interaction terms.
    def __init__(self, d_model, kernel_size=5, dropout=0.0):
        super(SCIBlock, self).__init__()
        self.splitting = Splitting()
        self.modules_even, self.modules_odd, self.interactor_even, self.interactor_odd = [
            CausalConvBlock(d_model, kernel_size, dropout) for _ in range(4)]

    def forward(self, x):
        x_even, x_odd = self.splitting(x)
        x_even = x_even.permute(0, 2, 1)
        x_odd = x_odd.permute(0, 2, 1)

        x_even_temp = x_even.mul(torch.exp(self.modules_even(x_odd)))
        x_odd_temp = x_odd.mul(torch.exp(self.modules_odd(x_even)))

        x_even_update = x_even_temp + self.interactor_even(x_odd_temp)
        x_odd_update = x_odd_temp - self.interactor_odd(x_even_temp)

        return x_even_update.permute(0, 2, 1), x_odd_update.permute(0, 2, 1)


class SCINet(nn.Module):
    # Binary tree of SCIBlocks: each level splits the sequence in half,
    # processes both halves recursively, and re-interleaves the results.
    def __init__(self, d_model, current_level=3, kernel_size=5, dropout=0.0):
        super(SCINet, self).__init__()
        self.current_level = current_level
        self.working_block = SCIBlock(d_model, kernel_size, dropout)

        if current_level != 0:
            self.SCINet_Tree_odd = SCINet(d_model, current_level - 1, kernel_size, dropout)
            self.SCINet_Tree_even = SCINet(d_model, current_level - 1, kernel_size, dropout)

    def forward(self, x):
        # Pad odd-length sequences with the last time step so they split evenly.
        odd_flag = False
        if x.shape[1] % 2 == 1:
            odd_flag = True
            x = torch.cat((x, x[:, -1:, :]), dim=1)
        x_even_update, x_odd_update = self.working_block(x)
        if odd_flag:
            x_odd_update = x_odd_update[:, :-1]

        if self.current_level == 0:
            return self.zip_up_the_pants(x_even_update, x_odd_update)
        else:
            return self.zip_up_the_pants(self.SCINet_Tree_even(x_even_update),
                                         self.SCINet_Tree_odd(x_odd_update))

    def zip_up_the_pants(self, even, odd):
        # Interleave the even and odd halves back into one sequence along the time axis.
        even = even.permute(1, 0, 2)
        odd = odd.permute(1, 0, 2)
        even_len = even.shape[0]
        odd_len = odd.shape[0]
        min_len = min(even_len, odd_len)

        zipped_data = []
        for i in range(min_len):
            zipped_data.append(even[i].unsqueeze(0))
            zipped_data.append(odd[i].unsqueeze(0))
        if even_len > odd_len:
            zipped_data.append(even[-1].unsqueeze(0))
        return torch.cat(zipped_data, 0).permute(1, 0, 2)


class Model(nn.Module):
    def __init__(self, configs):
        super(Model, self).__init__()
        self.task_name = configs.task_name
        self.seq_len = configs.seq_len
        self.label_len = configs.label_len
        self.pred_len = configs.pred_len

        # The number of SCINet stacks is set by the argument "d_layers"; it should be 1 or 2.
        self.num_stacks = configs.d_layers
        if self.num_stacks == 1:
            self.sci_net_1 = SCINet(configs.enc_in, dropout=configs.dropout)
            self.projection_1 = nn.Conv1d(self.seq_len, self.seq_len + self.pred_len,
                                          kernel_size=1, stride=1, bias=False)
        else:
            self.sci_net_1, self.sci_net_2 = [SCINet(configs.enc_in, dropout=configs.dropout) for _ in range(2)]
            self.projection_1 = nn.Conv1d(self.seq_len, self.pred_len,
                                          kernel_size=1, stride=1, bias=False)
            self.projection_2 = nn.Conv1d(self.seq_len + self.pred_len, self.seq_len + self.pred_len,
                                          kernel_size=1, bias=False)

        # Sinusoidal positional encoding over the input channels.
        self.pe_hidden_size = configs.enc_in
        if self.pe_hidden_size % 2 == 1:
            self.pe_hidden_size += 1

        num_timescales = self.pe_hidden_size // 2
        max_timescale = 10000.0
        min_timescale = 1.0

        log_timescale_increment = (
            math.log(float(max_timescale) / float(min_timescale)) /
            max(num_timescales - 1, 1))
        inv_timescales = min_timescale * torch.exp(
            torch.arange(num_timescales, dtype=torch.float32) * -log_timescale_increment)
        self.register_buffer('inv_timescales', inv_timescales)

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
        if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
            dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)  # [B, seq_len + pred_len, C]
            # Keep only the forecast horizon and zero-pad the history part,
            # so the returned length is seq_len + pred_len.
            dec_out = torch.cat([torch.zeros_like(x_enc), dec_out[:, -self.pred_len:, :]], dim=1)
            return dec_out  # [B, T, D]
        return None

    def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
        # Normalization from Non-stationary Transformer
        means = x_enc.mean(1, keepdim=True).detach()
        x_enc = x_enc - means
        stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
        x_enc /= stdev

        # Positional encoding; drop the padding channel if enc_in is odd.
        pe = self.get_position_encoding(x_enc)
        if pe.shape[2] > x_enc.shape[2]:
            x_enc += pe[:, :, :-1]
        else:
            x_enc += pe

        # SCINet stacks with residual connections.
        dec_out = self.sci_net_1(x_enc)
        dec_out += x_enc
        dec_out = self.projection_1(dec_out)
        if self.num_stacks != 1:
            dec_out = torch.cat((x_enc, dec_out), dim=1)
            temp = dec_out
            dec_out = self.sci_net_2(dec_out)
            dec_out += temp
            dec_out = self.projection_2(dec_out)

        # De-Normalization from Non-stationary Transformer
        dec_out = dec_out * \
                  (stdev[:, 0, :].unsqueeze(1).repeat(
                      1, self.pred_len + self.seq_len, 1))
        dec_out = dec_out + \
                  (means[:, 0, :].unsqueeze(1).repeat(
                      1, self.pred_len + self.seq_len, 1))
        return dec_out

    def get_position_encoding(self, x):
        max_length = x.size()[1]
        position = torch.arange(max_length, dtype=torch.float32, device=x.device)  # [T]
        scaled_time = position.unsqueeze(1) * self.inv_timescales.unsqueeze(0)  # [T, num_timescales]
        signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)  # [T, pe_hidden_size]
        signal = F.pad(signal, (0, 0, 0, self.pe_hidden_size % 2))
        signal = signal.view(1, max_length, self.pe_hidden_size)

        return signal
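

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (not part of the model definition above). It builds
# a hypothetical `configs` namespace carrying the attributes that Model reads
# (task_name, seq_len, label_len, pred_len, d_layers, enc_in, dropout) and runs
# one forward pass on random data. The concrete values are illustrative
# assumptions, not settings prescribed by the original code.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    configs = SimpleNamespace(
        task_name='long_term_forecast',
        seq_len=96,     # length of the input window
        label_len=48,   # kept for interface compatibility; unused by this model
        pred_len=24,    # forecast horizon
        d_layers=1,     # number of SCINet stacks (1 or 2)
        enc_in=7,       # number of input variables / channels
        dropout=0.1,
    )

    model = Model(configs)
    model.eval()

    batch_size = 4
    x_enc = torch.randn(batch_size, configs.seq_len, configs.enc_in)
    x_mark_enc = torch.zeros(batch_size, configs.seq_len, 4)  # time features (unused by SCINet)
    x_dec = torch.randn(batch_size, configs.label_len + configs.pred_len, configs.enc_in)
    x_mark_dec = torch.zeros(batch_size, configs.label_len + configs.pred_len, 4)

    with torch.no_grad():
        out = model(x_enc, x_mark_enc, x_dec, x_mark_dec)
    print(out.shape)  # expected: torch.Size([4, 120, 7]) == [B, seq_len + pred_len, enc_in]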