import torch
import torch.nn as nn
import math

from layers.decomp import DECOMP
from .network import Network
# from layers.network_mlp import NetworkMLP # For ablation study with MLP-only stream
# from layers.network_cnn import NetworkCNN # For ablation study with CNN-only stream
from layers.revin import RevIN


class Model(nn.Module):
    def __init__(self, configs):
        super(Model, self).__init__()

        # Parameters
        seq_len = configs.seq_len      # lookback window L
        pred_len = configs.pred_len    # prediction length (96, 192, 336, 720)
        c_in = configs.enc_in          # input channels

        # Patching
        patch_len = configs.patch_len
        stride = configs.stride
        padding_patch = configs.padding_patch

        # Normalization (RevIN: Reversible Instance Normalization)
        self.revin = configs.revin
        self.revin_layer = RevIN(c_in, affine=True, subtract_last=False)

        # Moving Average
        self.ma_type = configs.ma_type
        alpha = configs.alpha          # smoothing factor for EMA (Exponential Moving Average)
        beta = configs.beta            # smoothing factor for DEMA (Double Exponential Moving Average)

        # Seasonal-trend decomposition and the dual-stream (seasonal/trend) forecasting network
        self.decomp = DECOMP(self.ma_type, alpha, beta)
        self.net = Network(seq_len, pred_len, patch_len, stride, padding_patch)
        # self.net_mlp = NetworkMLP(seq_len, pred_len) # For ablation study with MLP-only stream
        # self.net_cnn = NetworkCNN(seq_len, pred_len, patch_len, stride, padding_patch) # For ablation study with CNN-only stream

    def forward(self, x):
        # x: [Batch, Input, Channel]

        # Normalization
        if self.revin:
            x = self.revin_layer(x, 'norm')

        if self.ma_type == 'reg':  # If no decomposition, directly pass the input to both streams of the network
            x = self.net(x, x)
            # x = self.net_mlp(x) # For ablation study with MLP-only stream
            # x = self.net_cnn(x) # For ablation study with CNN-only stream
        else:
            # Decompose into seasonal and trend components, one per network stream
            seasonal_init, trend_init = self.decomp(x)
            x = self.net(seasonal_init, trend_init)

        # Denormalization
        if self.revin:
            x = self.revin_layer(x, 'denorm')

        return x  # [Batch, Output, Channel]
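

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file): shows how the model
# could be constructed from a plain namespace standing in for the usual
# argparse config object. The attribute names mirror the reads in __init__
# above; the concrete values (lookback 336, horizon 96, 'ema' decomposition,
# 'end' patch padding, etc.) are illustrative assumptions, not defaults taken
# from the repository. Because this module uses a relative import
# (`from .network import Network`), run it as a module within its package,
# e.g. `python -m <package>.<this_module>`.
if __name__ == "__main__":
    from types import SimpleNamespace

    configs = SimpleNamespace(
        seq_len=336, pred_len=96, enc_in=7,           # lookback, horizon, channels
        patch_len=16, stride=8, padding_patch='end',  # patching (values assumed)
        revin=1,                                      # enable RevIN normalization
        ma_type='ema', alpha=0.3, beta=0.3,           # decomposition settings (assumed)
    )

    model = Model(configs)
    x = torch.randn(32, configs.seq_len, configs.enc_in)  # [Batch, Input, Channel]
    y = model(x)
    print(y.shape)  # expected: [32, 96, 7], i.e. [Batch, Output, Channel]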