23 lines
761 B
Python
23 lines
761 B
Python
import torch
|
|
from torch import nn
|
|
|
|
class EMA(nn.Module):
|
|
"""
|
|
Exponential Moving Average (EMA) block to highlight the trend of time series
|
|
"""
|
|
def __init__(self, alpha):
|
|
super(EMA, self).__init__()
|
|
self.alpha = alpha
|
|
|
|
def forward(self, x):
|
|
# x: [Batch, Input, Channel]
|
|
_, t, _ = x.shape
|
|
powers = torch.flip(torch.arange(t, dtype=torch.double), dims=(0,))
|
|
weights = torch.pow((1 - self.alpha), powers).to(x.device)
|
|
divisor = weights.clone()
|
|
weights[1:] = weights[1:] * self.alpha
|
|
weights = weights.reshape(1, t, 1)
|
|
divisor = divisor.reshape(1, t, 1)
|
|
x = torch.cumsum(x * weights, dim=1)
|
|
x = torch.div(x, divisor)
|
|
return x.to(torch.float32) |