first commit
This commit is contained in:
22
layers/DECOMP.py
Normal file
22
layers/DECOMP.py
Normal file
@ -0,0 +1,22 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from layers.EMA import EMA
|
||||
from layers.DEMA import DEMA
|
||||
|
||||
class DECOMP(nn.Module):
|
||||
"""
|
||||
Series decomposition block using EMA/DEMA
|
||||
"""
|
||||
def __init__(self, ma_type, alpha, beta):
|
||||
super(DECOMP, self).__init__()
|
||||
if ma_type == 'ema':
|
||||
self.ma = EMA(alpha)
|
||||
elif ma_type == 'dema':
|
||||
self.ma = DEMA(alpha, beta)
|
||||
else:
|
||||
raise ValueError(f"Unsupported ma_type: {ma_type}. Use 'ema' or 'dema'")
|
||||
|
||||
def forward(self, x):
|
||||
moving_average = self.ma(x)
|
||||
res = x - moving_average
|
||||
return res, moving_average
|
Reference in New Issue
Block a user