Files
TSlib/layers/DECOMP.py
2025-08-28 10:17:59 +00:00

22 lines
626 B
Python

import torch
from torch import nn
from layers.EMA import EMA
from layers.DEMA import DEMA
class DECOMP(nn.Module):
    """Split a series into a moving-average part and the residual left over.

    The moving-average operator is selected by ``ma_type``:

    * ``'ema'``  -> ``EMA(alpha)``
    * ``'dema'`` -> ``DEMA(alpha, beta)``

    Args:
        ma_type: Name of the moving-average operator, ``'ema'`` or ``'dema'``.
        alpha: Forwarded to ``EMA`` / ``DEMA`` (presumably a smoothing
            factor -- confirm against those layers).
        beta: Forwarded to ``DEMA`` only; unused when ``ma_type == 'ema'``.

    Raises:
        ValueError: If ``ma_type`` is neither ``'ema'`` nor ``'dema'``.
    """

    def __init__(self, ma_type, alpha, beta):
        super().__init__()
        # Reject unknown operator names up front so the selection below only
        # has to distinguish between the two supported ones.
        if ma_type not in ('ema', 'dema'):
            raise ValueError(f"Unsupported ma_type: {ma_type}. Use 'ema' or 'dema'")
        self.ma = EMA(alpha) if ma_type == 'ema' else DEMA(alpha, beta)

    def forward(self, x):
        """Decompose ``x`` with the configured moving average.

        Returns:
            A ``(residual, moving_average)`` pair, where
            ``moving_average = self.ma(x)`` and ``residual = x - moving_average``.
            Shapes follow whatever ``self.ma`` returns (assumed to match ``x``;
            TODO confirm in EMA/DEMA).
        """
        smoothed = self.ma(x)
        return x - smoothed, smoothed