22 lines
514 B
Python
22 lines
514 B
Python
import torch
|
|
from torch import nn
|
|
|
|
from layers.ema import EMA
|
|
from layers.dema import DEMA
|
|
|
|
class DECOMP(nn.Module):
    """
    Series decomposition block.

    Splits an input into a moving-average component (computed by an EMA or
    DEMA sub-module) and the residual left after subtracting it.

    Args:
        ma_type: Which moving average to use: ``'ema'`` or ``'dema'``.
        alpha: Forwarded to ``EMA(alpha)`` / ``DEMA(alpha, beta)``
            (presumably a smoothing factor — confirm in layers.ema/dema).
        beta: Forwarded to ``DEMA`` only; ignored when ``ma_type == 'ema'``.

    Raises:
        ValueError: If ``ma_type`` is not ``'ema'`` or ``'dema'``.
    """

    def __init__(self, ma_type, alpha, beta):
        super().__init__()
        if ma_type == 'ema':
            self.ma = EMA(alpha)
        elif ma_type == 'dema':
            self.ma = DEMA(alpha, beta)
        else:
            # Previously this fell through silently and left `self.ma` unset,
            # so the failure only appeared later as an AttributeError in forward().
            raise ValueError(
                f"Unsupported ma_type {ma_type!r}; expected 'ema' or 'dema'"
            )

    def forward(self, x):
        """
        Decompose ``x`` into residual and moving-average parts.

        Args:
            x: Input series. NOTE(review): assumed to be a tensor whose shape
                matches what ``self.ma`` returns, so the subtraction is
                elementwise — confirm against EMA/DEMA.

        Returns:
            A ``(res, moving_average)`` pair where ``res = x - moving_average``.
        """
        moving_average = self.ma(x)
        res = x - moving_average
        return res, moving_average
|