# -*- coding: utf-8 -*-
"""
Created on Sun Jan  5 16:10:01 2025

@author: Murad
SISLab, USF
mmurad@usf.edu
https://github.com/Secure-and-Intelligent-Systems-Lab/WPMixer
"""
import torch
import torch.nn as nn

from layers.DWT_Decomposition import Decomposition


class TokenMixer(nn.Module):
    def __init__(self, input_seq=[], batch_size=[], channel=[], pred_seq=[],
                 dropout=[], factor=[], d_model=[]):
        super(TokenMixer, self).__init__()
        self.input_seq = input_seq
        self.batch_size = batch_size
        self.channel = channel
        self.pred_seq = pred_seq
        self.dropout = dropout
        self.factor = factor
        self.d_model = d_model

        self.dropoutLayer = nn.Dropout(self.dropout)
        # Two-layer MLP that mixes along the patch (token) axis, expanding it
        # by `factor` before projecting to `pred_seq`.
        self.layers = nn.Sequential(nn.Linear(self.input_seq, self.pred_seq * self.factor),
                                    nn.GELU(),
                                    nn.Dropout(self.dropout),
                                    nn.Linear(self.pred_seq * self.factor, self.pred_seq))

    def forward(self, x):
        x = x.transpose(1, 2)   # move the patch axis last so the MLP mixes over it
        x = self.layers(x)
        x = x.transpose(1, 2)
        return x


class Mixer(nn.Module):
    def __init__(self,
                 input_seq=[],
                 out_seq=[],
                 batch_size=[],
                 channel=[],
                 d_model=[],
                 dropout=[],
                 tfactor=[],
                 dfactor=[]):
        super(Mixer, self).__init__()
        self.input_seq = input_seq
        self.pred_seq = out_seq
        self.batch_size = batch_size
        self.channel = channel
        self.d_model = d_model
        self.dropout = dropout
        self.tfactor = tfactor  # expansion factor for patch mixer
        self.dfactor = dfactor  # expansion factor for embedding mixer

        self.tMixer = TokenMixer(input_seq=self.input_seq, batch_size=self.batch_size,
                                 channel=self.channel, pred_seq=self.pred_seq,
                                 dropout=self.dropout, factor=self.tfactor,
                                 d_model=self.d_model)
        self.dropoutLayer = nn.Dropout(self.dropout)
        self.norm1 = nn.BatchNorm2d(self.channel)
        self.norm2 = nn.BatchNorm2d(self.channel)
        # Two-layer MLP that mixes along the embedding (d_model) axis.
        self.embeddingMixer = nn.Sequential(nn.Linear(self.d_model, self.d_model * self.dfactor),
                                            nn.GELU(),
                                            nn.Dropout(self.dropout),
                                            nn.Linear(self.d_model * self.dfactor, self.d_model))

    def forward(self, x):
        '''
        Parameters
        ----------
        x : input: [Batch, Channel, Patch_number, d_model]

        Returns
        -------
        x : output: [Batch, Channel, Patch_number, d_model]
        '''
        x = self.norm1(x)
        x = x.permute(0, 3, 1, 2)               # [Batch, d_model, Channel, Patch_number]
        x = self.dropoutLayer(self.tMixer(x))   # mix along the patch axis
        x = x.permute(0, 2, 3, 1)               # back to [Batch, Channel, Patch_number, d_model]
        x = self.norm2(x)
        x = x + self.dropoutLayer(self.embeddingMixer(x))  # residual embedding mixing
        return x
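# --- Shape-check sketch for Mixer (illustrative; not part of the original
# module). The tensor sizes below are assumptions chosen for the demo; call
# _mixer_shape_sketch() manually to run it. Mixer preserves the
# [Batch, Channel, Patch_number, d_model] layout end to end.
def _mixer_shape_sketch():
    mixer = Mixer(input_seq=12, out_seq=12, batch_size=2, channel=7,
                  d_model=16, dropout=0.1, tfactor=5, dfactor=5)
    x = torch.randn(2, 7, 12, 16)       # [Batch, Channel, Patch_number, d_model]
    y = mixer(x)                        # token + embedding mixing, shape preserved
    assert y.shape == (2, 7, 12, 16)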
class ResolutionBranch(nn.Module):
    def __init__(self,
                 input_seq=[],
                 pred_seq=[],
                 batch_size=[],
                 channel=[],
                 d_model=[],
                 dropout=[],
                 embedding_dropout=[],
                 tfactor=[],
                 dfactor=[],
                 patch_len=[],
                 patch_stride=[]):
        super(ResolutionBranch, self).__init__()
        self.input_seq = input_seq
        self.pred_seq = pred_seq
        self.batch_size = batch_size
        self.channel = channel
        self.d_model = d_model
        self.dropout = dropout
        self.embedding_dropout = embedding_dropout
        self.tfactor = tfactor
        self.dfactor = dfactor
        self.patch_len = patch_len
        self.patch_stride = patch_stride
        # do_patching pads the series by `patch_stride` steps, hence the "+ 2"
        self.patch_num = int((self.input_seq - self.patch_len) / self.patch_stride + 2)

        self.patch_norm = nn.BatchNorm2d(self.channel)
        self.patch_embedding_layer = nn.Linear(self.patch_len, self.d_model)  # shared among all channels
        self.mixer1 = Mixer(input_seq=self.patch_num, out_seq=self.patch_num,
                            batch_size=self.batch_size, channel=self.channel,
                            d_model=self.d_model, dropout=self.dropout,
                            tfactor=self.tfactor, dfactor=self.dfactor)
        self.mixer2 = Mixer(input_seq=self.patch_num, out_seq=self.patch_num,
                            batch_size=self.batch_size, channel=self.channel,
                            d_model=self.d_model, dropout=self.dropout,
                            tfactor=self.tfactor, dfactor=self.dfactor)
        self.norm = nn.BatchNorm2d(self.channel)
        self.dropoutLayer = nn.Dropout(self.embedding_dropout)
        self.head = nn.Sequential(nn.Flatten(start_dim=-2, end_dim=-1),
                                  nn.Linear(self.patch_num * self.d_model, self.pred_seq))

    def forward(self, x):
        '''
        Parameters
        ----------
        x : input coefficient series: [Batch, channel, length_of_coefficient_series]

        Returns
        -------
        out : predicted coefficient series: [Batch, channel, length_of_pred_coeff_series]
        '''
        x_patch = self.do_patching(x)
        x_patch = self.patch_norm(x_patch)
        x_emb = self.dropoutLayer(self.patch_embedding_layer(x_patch))

        out = self.mixer1(x_emb)
        res = out
        out = res + self.mixer2(out)    # residual connection between the two mixers
        out = self.norm(out)
        out = self.head(out)
        return out

    def do_patching(self, x):
        # pad by repeating the last value `patch_stride` times, then unfold
        # into overlapping patches of length `patch_len`
        x_end = x[:, :, -1:]
        x_padding = x_end.repeat(1, 1, self.patch_stride)
        x_new = torch.cat((x, x_padding), dim=-1)
        x_patch = x_new.unfold(dimension=-1, size=self.patch_len, step=self.patch_stride)
        return x_patch


class WPMixerCore(nn.Module):
    def __init__(self,
                 input_length=[],
                 pred_length=[],
                 wavelet_name=[],
                 level=[],
                 batch_size=[],
                 channel=[],
                 d_model=[],
                 dropout=[],
                 embedding_dropout=[],
                 tfactor=[],
                 dfactor=[],
                 device=[],
                 patch_len=[],
                 patch_stride=[],
                 no_decomposition=[],
                 use_amp=[]):
        super(WPMixerCore, self).__init__()
        self.input_length = input_length
        self.pred_length = pred_length
        self.wavelet_name = wavelet_name
        self.level = level
        self.batch_size = batch_size
        self.channel = channel
        self.d_model = d_model
        self.dropout = dropout
        self.embedding_dropout = embedding_dropout
        self.device = device
        self.no_decomposition = no_decomposition
        self.tfactor = tfactor
        self.dfactor = dfactor
        self.use_amp = use_amp

        self.Decomposition_model = Decomposition(input_length=self.input_length,
                                                 pred_length=self.pred_length,
                                                 wavelet_name=self.wavelet_name,
                                                 level=self.level,
                                                 batch_size=self.batch_size,
                                                 channel=self.channel,
                                                 d_model=self.d_model,
                                                 tfactor=self.tfactor,
                                                 dfactor=self.dfactor,
                                                 device=self.device,
                                                 no_decomposition=self.no_decomposition,
                                                 use_amp=self.use_amp)

        self.input_w_dim = self.Decomposition_model.input_w_dim  # list of lengths of the input coefficient series
        self.pred_w_dim = self.Decomposition_model.pred_w_dim    # list of lengths of the predicted coefficient series

        self.patch_len = patch_len
        self.patch_stride = patch_stride

        # (m + 1) resolution branches: one for the approximation series and
        # one for each of the m detail series
        self.resolutionBranch = nn.ModuleList([ResolutionBranch(input_seq=self.input_w_dim[i],
                                                                pred_seq=self.pred_w_dim[i],
                                                                batch_size=self.batch_size,
                                                                channel=self.channel,
                                                                d_model=self.d_model,
                                                                dropout=self.dropout,
                                                                embedding_dropout=self.embedding_dropout,
                                                                tfactor=self.tfactor,
                                                                dfactor=self.dfactor,
                                                                patch_len=self.patch_len,
                                                                patch_stride=self.patch_stride)
                                               for i in range(len(self.input_w_dim))])

    def forward(self, xL):
        '''
        Parameters
        ----------
        xL : look-back window: [Batch, look_back_length, channel]

        Returns
        -------
        xT : prediction time series: [Batch, prediction_length, output_channel]
        '''
        x = xL.transpose(1, 2)  # [batch, channel, look_back_length]

        # xA: approximation coefficient series,
        # xD: list of detail coefficient series,
        # yA: predicted approximation coefficient series,
        # yD: list of predicted detail coefficient series
        xA, xD = self.Decomposition_model.transform(x)

        yA = self.resolutionBranch[0](xA)
        yD = []
        for i in range(len(xD)):
            yD_i = self.resolutionBranch[i + 1](xD[i])
            yD.append(yD_i)

        y = self.Decomposition_model.inv_transform(yA, yD)
        y = y.transpose(1, 2)
        xT = y[:, -self.pred_length:, :]  # decomposition output length is always even, but pred_length can be odd
        return xT
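# --- Patching arithmetic sketch (illustrative; not part of the original
# module). ResolutionBranch.do_patching pads the series with `patch_stride`
# copies of its last value before unfolding, so the patch count is
# floor((input_seq + patch_stride - patch_len) / patch_stride) + 1, which
# matches the int((input_seq - patch_len) / patch_stride + 2) used above
# whenever (input_seq - patch_len) is divisible by patch_stride. The sizes
# here (input_seq=96, patch_len=16, patch_stride=8) are assumptions; call
# _patching_sketch() manually to run it.
def _patching_sketch():
    patch_len, patch_stride = 16, 8
    x = torch.randn(2, 7, 96)                                    # [Batch, Channel, input_seq]
    x_pad = torch.cat((x, x[:, :, -1:].repeat(1, 1, patch_stride)), dim=-1)
    x_patch = x_pad.unfold(-1, patch_len, patch_stride)          # [Batch, Channel, patch_num, patch_len]
    assert x_patch.shape == (2, 7, int((96 - 16) / 8 + 2), 16)   # patch_num == 12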
class Model(nn.Module):
    def __init__(self, args, tfactor=5, dfactor=5, wavelet='db2', level=1,
                 stride=8, no_decomposition=False):
        super(Model, self).__init__()
        self.args = args
        self.task_name = args.task_name
        self.wpmixerCore = WPMixerCore(input_length=self.args.seq_len,
                                       pred_length=self.args.pred_len,
                                       wavelet_name=wavelet,
                                       level=level,
                                       batch_size=self.args.batch_size,
                                       channel=self.args.c_out,
                                       d_model=self.args.d_model,
                                       dropout=self.args.dropout,
                                       embedding_dropout=self.args.dropout,
                                       tfactor=tfactor,
                                       dfactor=dfactor,
                                       device=self.args.device,
                                       patch_len=self.args.patch_len,
                                       patch_stride=stride,
                                       no_decomposition=no_decomposition,
                                       use_amp=self.args.use_amp)

    def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
        # Normalization (per series, over the look-back window)
        means = x_enc.mean(1, keepdim=True).detach()
        x_enc = x_enc - means
        stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
        x_enc /= stdev

        pred = self.wpmixerCore(x_enc)
        pred = pred[:, :, -self.args.c_out:]

        # De-normalization
        dec_out = pred * (stdev[:, 0].unsqueeze(1).repeat(1, self.args.pred_len, 1))
        dec_out = dec_out + (means[:, 0].unsqueeze(1).repeat(1, self.args.pred_len, 1))
        return dec_out

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
        if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
            dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
            return dec_out  # [B, L, D]
        if self.task_name == 'imputation':
            raise NotImplementedError("Task imputation for WPMixer is temporarily not supported")
        if self.task_name == 'anomaly_detection':
            raise NotImplementedError("Task anomaly_detection for WPMixer is temporarily not supported")
        if self.task_name == 'classification':
            raise NotImplementedError("Task classification for WPMixer is temporarily not supported")
        return None
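# --- Minimal smoke test (illustrative; every hyper-parameter below is an
# assumption, not a recommended configuration). Runs only when this file is
# executed directly and the repo's layers.DWT_Decomposition dependency is
# importable; depending on that module, `device` may need adjusting.
if __name__ == "__main__":
    from types import SimpleNamespace

    args = SimpleNamespace(task_name='long_term_forecast', seq_len=96,
                           pred_len=24, batch_size=2, c_out=7, d_model=16,
                           dropout=0.1, device=torch.device('cpu'),
                           patch_len=16, use_amp=False)
    model = Model(args)
    x_enc = torch.randn(2, 96, 7)       # [Batch, seq_len, channel]
    out = model(x_enc, None, None, None)
    print(out.shape)                    # expected: torch.Size([2, 24, 7])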