first timesnet try
This commit is contained in:
203
models/TimeMixer++/Autoformer_EncDec.py
Normal file
203
models/TimeMixer++/Autoformer_EncDec.py
Normal file
@ -0,0 +1,203 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class my_Layernorm(nn.Module):
|
||||
"""
|
||||
Special designed layernorm for the seasonal part
|
||||
"""
|
||||
|
||||
def __init__(self, channels):
|
||||
super(my_Layernorm, self).__init__()
|
||||
self.layernorm = nn.LayerNorm(channels)
|
||||
|
||||
def forward(self, x):
|
||||
x_hat = self.layernorm(x)
|
||||
bias = torch.mean(x_hat, dim=1).unsqueeze(1).repeat(1, x.shape[1], 1)
|
||||
return x_hat - bias
|
||||
|
||||
|
||||
class moving_avg(nn.Module):
|
||||
"""
|
||||
Moving average block to highlight the trend of time series
|
||||
"""
|
||||
|
||||
def __init__(self, kernel_size, stride):
|
||||
super(moving_avg, self).__init__()
|
||||
self.kernel_size = kernel_size
|
||||
self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0)
|
||||
|
||||
def forward(self, x):
|
||||
# padding on the both ends of time series
|
||||
front = x[:, 0:1, :].repeat(1, (self.kernel_size - 1) // 2, 1)
|
||||
end = x[:, -1:, :].repeat(1, (self.kernel_size - 1) // 2, 1)
|
||||
x = torch.cat([front, x, end], dim=1)
|
||||
x = self.avg(x.permute(0, 2, 1))
|
||||
x = x.permute(0, 2, 1)
|
||||
return x
|
||||
|
||||
|
||||
class series_decomp(nn.Module):
|
||||
"""
|
||||
Series decomposition block
|
||||
"""
|
||||
|
||||
def __init__(self, kernel_size):
|
||||
super(series_decomp, self).__init__()
|
||||
self.moving_avg = moving_avg(kernel_size, stride=1)
|
||||
|
||||
def forward(self, x):
|
||||
moving_mean = self.moving_avg(x)
|
||||
res = x - moving_mean
|
||||
return res, moving_mean
|
||||
|
||||
|
||||
class series_decomp_multi(nn.Module):
|
||||
"""
|
||||
Multiple Series decomposition block from FEDformer
|
||||
"""
|
||||
|
||||
def __init__(self, kernel_size):
|
||||
super(series_decomp_multi, self).__init__()
|
||||
self.kernel_size = kernel_size
|
||||
self.series_decomp = [series_decomp(kernel) for kernel in kernel_size]
|
||||
|
||||
def forward(self, x):
|
||||
moving_mean = []
|
||||
res = []
|
||||
for func in self.series_decomp:
|
||||
sea, moving_avg = func(x)
|
||||
moving_mean.append(moving_avg)
|
||||
res.append(sea)
|
||||
|
||||
sea = sum(res) / len(res)
|
||||
moving_mean = sum(moving_mean) / len(moving_mean)
|
||||
return sea, moving_mean
|
||||
|
||||
|
||||
class EncoderLayer(nn.Module):
|
||||
"""
|
||||
Autoformer encoder layer with the progressive decomposition architecture
|
||||
"""
|
||||
|
||||
def __init__(self, attention, d_model, d_ff=None, moving_avg=25, dropout=0.1, activation="relu"):
|
||||
super(EncoderLayer, self).__init__()
|
||||
d_ff = d_ff or 4 * d_model
|
||||
self.attention = attention
|
||||
self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1, bias=False)
|
||||
self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1, bias=False)
|
||||
self.decomp1 = series_decomp(moving_avg)
|
||||
self.decomp2 = series_decomp(moving_avg)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.activation = F.relu if activation == "relu" else F.gelu
|
||||
|
||||
def forward(self, x, attn_mask=None):
|
||||
new_x, attn = self.attention(
|
||||
x, x, x,
|
||||
attn_mask=attn_mask
|
||||
)
|
||||
x = x + self.dropout(new_x)
|
||||
x, _ = self.decomp1(x)
|
||||
y = x
|
||||
y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
|
||||
y = self.dropout(self.conv2(y).transpose(-1, 1))
|
||||
res, _ = self.decomp2(x + y)
|
||||
return res, attn
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
"""
|
||||
Autoformer encoder
|
||||
"""
|
||||
|
||||
def __init__(self, attn_layers, conv_layers=None, norm_layer=None):
|
||||
super(Encoder, self).__init__()
|
||||
self.attn_layers = nn.ModuleList(attn_layers)
|
||||
self.conv_layers = nn.ModuleList(conv_layers) if conv_layers is not None else None
|
||||
self.norm = norm_layer
|
||||
|
||||
def forward(self, x, attn_mask=None):
|
||||
attns = []
|
||||
if self.conv_layers is not None:
|
||||
for attn_layer, conv_layer in zip(self.attn_layers, self.conv_layers):
|
||||
x, attn = attn_layer(x, attn_mask=attn_mask)
|
||||
x = conv_layer(x)
|
||||
attns.append(attn)
|
||||
x, attn = self.attn_layers[-1](x)
|
||||
attns.append(attn)
|
||||
else:
|
||||
for attn_layer in self.attn_layers:
|
||||
x, attn = attn_layer(x, attn_mask=attn_mask)
|
||||
attns.append(attn)
|
||||
|
||||
if self.norm is not None:
|
||||
x = self.norm(x)
|
||||
|
||||
return x, attns
|
||||
|
||||
|
||||
class DecoderLayer(nn.Module):
|
||||
"""
|
||||
Autoformer decoder layer with the progressive decomposition architecture
|
||||
"""
|
||||
|
||||
def __init__(self, self_attention, cross_attention, d_model, c_out, d_ff=None,
|
||||
moving_avg=25, dropout=0.1, activation="relu"):
|
||||
super(DecoderLayer, self).__init__()
|
||||
d_ff = d_ff or 4 * d_model
|
||||
self.self_attention = self_attention
|
||||
self.cross_attention = cross_attention
|
||||
self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1, bias=False)
|
||||
self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1, bias=False)
|
||||
self.decomp1 = series_decomp(moving_avg)
|
||||
self.decomp2 = series_decomp(moving_avg)
|
||||
self.decomp3 = series_decomp(moving_avg)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.projection = nn.Conv1d(in_channels=d_model, out_channels=c_out, kernel_size=3, stride=1, padding=1,
|
||||
padding_mode='circular', bias=False)
|
||||
self.activation = F.relu if activation == "relu" else F.gelu
|
||||
|
||||
def forward(self, x, cross, x_mask=None, cross_mask=None):
|
||||
x = x + self.dropout(self.self_attention(
|
||||
x, x, x,
|
||||
attn_mask=x_mask
|
||||
)[0])
|
||||
x, trend1 = self.decomp1(x)
|
||||
x = x + self.dropout(self.cross_attention(
|
||||
x, cross, cross,
|
||||
attn_mask=cross_mask
|
||||
)[0])
|
||||
x, trend2 = self.decomp2(x)
|
||||
y = x
|
||||
y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
|
||||
y = self.dropout(self.conv2(y).transpose(-1, 1))
|
||||
x, trend3 = self.decomp3(x + y)
|
||||
|
||||
residual_trend = trend1 + trend2 + trend3
|
||||
residual_trend = self.projection(residual_trend.permute(0, 2, 1)).transpose(1, 2)
|
||||
return x, residual_trend
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
"""
|
||||
Autoformer encoder
|
||||
"""
|
||||
|
||||
def __init__(self, layers, norm_layer=None, projection=None):
|
||||
super(Decoder, self).__init__()
|
||||
self.layers = nn.ModuleList(layers)
|
||||
self.norm = norm_layer
|
||||
self.projection = projection
|
||||
|
||||
def forward(self, x, cross, x_mask=None, cross_mask=None, trend=None):
|
||||
for layer in self.layers:
|
||||
x, residual_trend = layer(x, cross, x_mask=x_mask, cross_mask=cross_mask)
|
||||
trend = trend + residual_trend
|
||||
|
||||
if self.norm is not None:
|
||||
x = self.norm(x)
|
||||
|
||||
if self.projection is not None:
|
||||
x = self.projection(x)
|
||||
return x, trend
|
234
models/TimeMixer++/Embed.py
Normal file
234
models/TimeMixer++/Embed.py
Normal file
@ -0,0 +1,234 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.utils import weight_norm
|
||||
import math
|
||||
|
||||
|
||||
class PositionalEmbedding(nn.Module):
|
||||
def __init__(self, d_model, max_len=5000):
|
||||
super(PositionalEmbedding, self).__init__()
|
||||
# Compute the positional encodings once in log space.
|
||||
pe = torch.zeros(max_len, d_model).float()
|
||||
pe.require_grad = False
|
||||
|
||||
position = torch.arange(0, max_len).float().unsqueeze(1)
|
||||
div_term = (torch.arange(0, d_model, 2).float()
|
||||
* -(math.log(10000.0) / d_model)).exp()
|
||||
|
||||
pe[:, 0::2] = torch.sin(position * div_term)
|
||||
pe[:, 1::2] = torch.cos(position * div_term)
|
||||
|
||||
pe = pe.unsqueeze(0)
|
||||
self.register_buffer('pe', pe)
|
||||
|
||||
def forward(self, x):
|
||||
return self.pe[:, :x.size(1)]
|
||||
|
||||
|
||||
class TokenEmbedding(nn.Module):
|
||||
def __init__(self, c_in, d_model):
|
||||
super(TokenEmbedding, self).__init__()
|
||||
padding = 1 if torch.__version__ >= '1.5.0' else 2
|
||||
self.tokenConv = nn.Conv1d(in_channels=c_in, out_channels=d_model,
|
||||
kernel_size=3, padding=padding, padding_mode='circular', bias=False)
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv1d):
|
||||
nn.init.kaiming_normal_(
|
||||
m.weight, mode='fan_in', nonlinearity='leaky_relu')
|
||||
|
||||
def forward(self, x):
|
||||
x = self.tokenConv(x.permute(0, 2, 1)).transpose(1, 2)
|
||||
return x
|
||||
|
||||
|
||||
class FixedEmbedding(nn.Module):
|
||||
def __init__(self, c_in, d_model):
|
||||
super(FixedEmbedding, self).__init__()
|
||||
|
||||
w = torch.zeros(c_in, d_model).float()
|
||||
w.require_grad = False
|
||||
|
||||
position = torch.arange(0, c_in).float().unsqueeze(1)
|
||||
div_term = (torch.arange(0, d_model, 2).float()
|
||||
* -(math.log(10000.0) / d_model)).exp()
|
||||
|
||||
w[:, 0::2] = torch.sin(position * div_term)
|
||||
w[:, 1::2] = torch.cos(position * div_term)
|
||||
|
||||
self.emb = nn.Embedding(c_in, d_model)
|
||||
self.emb.weight = nn.Parameter(w, requires_grad=False)
|
||||
|
||||
def forward(self, x):
|
||||
return self.emb(x).detach()
|
||||
|
||||
|
||||
class TemporalEmbedding(nn.Module):
|
||||
def __init__(self, d_model, embed_type='fixed', freq='h'):
|
||||
super(TemporalEmbedding, self).__init__()
|
||||
|
||||
minute_size = 4
|
||||
hour_size = 24
|
||||
weekday_size = 7
|
||||
day_size = 32
|
||||
month_size = 13
|
||||
|
||||
Embed = FixedEmbedding if embed_type == 'fixed' else nn.Embedding
|
||||
if freq == 't':
|
||||
self.minute_embed = Embed(minute_size, d_model)
|
||||
self.hour_embed = Embed(hour_size, d_model)
|
||||
self.weekday_embed = Embed(weekday_size, d_model)
|
||||
self.day_embed = Embed(day_size, d_model)
|
||||
self.month_embed = Embed(month_size, d_model)
|
||||
|
||||
def forward(self, x):
|
||||
x = x.long()
|
||||
minute_x = self.minute_embed(x[:, :, 4]) if hasattr(
|
||||
self, 'minute_embed') else 0.
|
||||
hour_x = self.hour_embed(x[:, :, 3])
|
||||
weekday_x = self.weekday_embed(x[:, :, 2])
|
||||
day_x = self.day_embed(x[:, :, 1])
|
||||
month_x = self.month_embed(x[:, :, 0])
|
||||
|
||||
return hour_x + weekday_x + day_x + month_x + minute_x
|
||||
|
||||
|
||||
class TimeFeatureEmbedding(nn.Module):
|
||||
def __init__(self, d_model, embed_type='timeF', freq='h'):
|
||||
super(TimeFeatureEmbedding, self).__init__()
|
||||
|
||||
freq_map = {'h': 4, 't': 5, 's': 6, 'ms': 7,
|
||||
'm': 1, 'a': 1, 'w': 2, 'd': 3, 'b': 3}
|
||||
d_inp = freq_map[freq]
|
||||
self.embed = nn.Linear(d_inp, d_model, bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
return self.embed(x)
|
||||
|
||||
|
||||
class DataEmbedding(nn.Module):
|
||||
def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1):
|
||||
super(DataEmbedding, self).__init__()
|
||||
self.c_in = c_in
|
||||
self.d_model = d_model
|
||||
self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model)
|
||||
self.position_embedding = PositionalEmbedding(d_model=d_model)
|
||||
self.temporal_embedding = TemporalEmbedding(d_model=d_model, embed_type=embed_type,
|
||||
freq=freq) if embed_type != 'timeF' else TimeFeatureEmbedding(
|
||||
d_model=d_model, embed_type=embed_type, freq=freq)
|
||||
self.dropout = nn.Dropout(p=dropout)
|
||||
|
||||
def forward(self, x, x_mark):
|
||||
_, _, N = x.size()
|
||||
if N == self.c_in:
|
||||
if x_mark is None:
|
||||
x = self.value_embedding(x) + self.position_embedding(x)
|
||||
else:
|
||||
x = self.value_embedding(
|
||||
x) + self.temporal_embedding(x_mark) + self.position_embedding(x)
|
||||
elif N == self.d_model:
|
||||
if x_mark is None:
|
||||
x = x + self.position_embedding(x)
|
||||
else:
|
||||
x = x + self.temporal_embedding(x_mark) + self.position_embedding(x)
|
||||
|
||||
return self.dropout(x)
|
||||
|
||||
|
||||
class DataEmbedding_ms(nn.Module):
|
||||
def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1):
|
||||
super(DataEmbedding_ms, self).__init__()
|
||||
|
||||
self.value_embedding = TokenEmbedding(c_in=1, d_model=d_model)
|
||||
self.position_embedding = PositionalEmbedding(d_model=d_model)
|
||||
self.temporal_embedding = TemporalEmbedding(d_model=d_model, embed_type=embed_type,
|
||||
freq=freq) if embed_type != 'timeF' else TimeFeatureEmbedding(
|
||||
d_model=d_model, embed_type=embed_type, freq=freq)
|
||||
self.dropout = nn.Dropout(p=dropout)
|
||||
|
||||
def forward(self, x, x_mark):
|
||||
B, T, N = x.shape
|
||||
x1 = self.value_embedding(x.reshape(0, 2, 1).reshape(B * N, T).unsqueeze(-1)).reshape(B, N, T, -1).permute(0, 2,
|
||||
1, 3)
|
||||
if x_mark is None:
|
||||
x = x1
|
||||
else:
|
||||
x = x1 + self.temporal_embedding(x_mark)
|
||||
return self.dropout(x)
|
||||
|
||||
|
||||
class DataEmbedding_wo_pos(nn.Module):
|
||||
def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1):
|
||||
super(DataEmbedding_wo_pos, self).__init__()
|
||||
|
||||
self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model)
|
||||
self.position_embedding = PositionalEmbedding(d_model=d_model)
|
||||
self.temporal_embedding = TemporalEmbedding(d_model=d_model, embed_type=embed_type,
|
||||
freq=freq) if embed_type != 'timeF' else TimeFeatureEmbedding(
|
||||
d_model=d_model, embed_type=embed_type, freq=freq)
|
||||
self.dropout = nn.Dropout(p=dropout)
|
||||
|
||||
def forward(self, x, x_mark):
|
||||
if x is None and x_mark is not None:
|
||||
return self.temporal_embedding(x_mark)
|
||||
if x_mark is None:
|
||||
x = self.value_embedding(x)
|
||||
else:
|
||||
x = self.value_embedding(x) + self.temporal_embedding(x_mark)
|
||||
return self.dropout(x)
|
||||
|
||||
|
||||
class PatchEmbedding_crossformer(nn.Module):
|
||||
def __init__(self, d_model, patch_len, stride, padding, dropout):
|
||||
super(PatchEmbedding_crossformer, self).__init__()
|
||||
# Patching
|
||||
self.patch_len = patch_len
|
||||
self.stride = stride
|
||||
self.padding_patch_layer = nn.ReplicationPad1d((0, padding))
|
||||
|
||||
# Backbone, Input encoding: projection of feature vectors onto a d-dim vector space
|
||||
self.value_embedding = nn.Linear(patch_len, d_model, bias=False)
|
||||
|
||||
# Positional embedding
|
||||
self.position_embedding = PositionalEmbedding(d_model)
|
||||
|
||||
# Residual dropout
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
def forward(self, x):
|
||||
# do patching
|
||||
n_vars = x.shape[1]
|
||||
x = self.padding_patch_layer(x)
|
||||
x = x.unfold(dimension=-1, size=self.patch_len, step=self.stride)
|
||||
x = torch.reshape(x, (x.shape[0] * x.shape[1], x.shape[2], x.shape[3]))
|
||||
# Input encoding
|
||||
x = self.value_embedding(x) + self.position_embedding(x)
|
||||
return self.dropout(x), n_vars
|
||||
|
||||
|
||||
class PatchEmbedding(nn.Module):
|
||||
def __init__(self, d_model, patch_len, stride, dropout):
|
||||
super(PatchEmbedding, self).__init__()
|
||||
# Patching
|
||||
self.patch_len = patch_len
|
||||
self.stride = stride
|
||||
self.padding_patch_layer = nn.ReplicationPad1d((0, stride))
|
||||
|
||||
# Backbone, Input encoding: projection of feature vectors onto a d-dim vector space
|
||||
self.value_embedding = TokenEmbedding(patch_len, d_model)
|
||||
|
||||
# Positional embedding
|
||||
self.position_embedding = PositionalEmbedding(d_model)
|
||||
|
||||
# Residual dropout
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
def forward(self, x):
|
||||
# do patching
|
||||
n_vars = x.shape[1]
|
||||
x = self.padding_patch_layer(x)
|
||||
x = x.unfold(dimension=-1, size=self.patch_len, step=self.stride)
|
||||
x = torch.reshape(x, (x.shape[0] * x.shape[1], x.shape[2], x.shape[3]))
|
||||
# Input encoding
|
||||
x = self.value_embedding(x) + self.position_embedding(x)
|
||||
return self.dropout(x), n_vars
|
67
models/TimeMixer++/StandardNorm.py
Normal file
67
models/TimeMixer++/StandardNorm.py
Normal file
@ -0,0 +1,67 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
class Normalize(nn.Module):
|
||||
def __init__(self, num_features: int, eps=1e-5, affine=False, subtract_last=False, non_norm=False):
|
||||
"""
|
||||
:param num_features: the number of features or channels
|
||||
:param eps: a value added for numerical stability
|
||||
:param affine: if True, RevIN has learnable affine parameters
|
||||
"""
|
||||
super(Normalize, self).__init__()
|
||||
self.num_features = num_features
|
||||
self.eps = eps
|
||||
self.affine = affine
|
||||
self.subtract_last = subtract_last
|
||||
self.non_norm = non_norm
|
||||
if self.affine:
|
||||
self._init_params()
|
||||
|
||||
def forward(self, x, mode: str):
|
||||
if mode == 'norm':
|
||||
self._get_statistics(x)
|
||||
x = self._normalize(x)
|
||||
elif mode == 'denorm':
|
||||
x = self._denormalize(x)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return x
|
||||
|
||||
def _init_params(self):
|
||||
# initialize RevIN params: (C,)
|
||||
self.affine_weight = nn.Parameter(torch.ones(self.num_features))
|
||||
self.affine_bias = nn.Parameter(torch.zeros(self.num_features))
|
||||
|
||||
def _get_statistics(self, x):
|
||||
dim2reduce = tuple(range(1, x.ndim - 1))
|
||||
if self.subtract_last:
|
||||
self.last = x[:, -1, :].unsqueeze(1)
|
||||
else:
|
||||
self.mean = torch.mean(x, dim=dim2reduce, keepdim=True).detach()
|
||||
self.stdev = torch.sqrt(torch.var(x, dim=dim2reduce, keepdim=True, unbiased=False) + self.eps).detach()
|
||||
|
||||
def _normalize(self, x):
|
||||
if self.non_norm:
|
||||
return x
|
||||
if self.subtract_last:
|
||||
x = x - self.last
|
||||
else:
|
||||
x = x - self.mean
|
||||
x = x / self.stdev
|
||||
if self.affine:
|
||||
x = x * self.affine_weight
|
||||
x = x + self.affine_bias
|
||||
return x
|
||||
|
||||
def _denormalize(self, x):
|
||||
if self.non_norm:
|
||||
return x
|
||||
if self.affine:
|
||||
x = x - self.affine_bias
|
||||
x = x / (self.affine_weight + self.eps * self.eps)
|
||||
x = x * self.stdev
|
||||
if self.subtract_last:
|
||||
x = x + self.last
|
||||
else:
|
||||
x = x + self.mean
|
||||
return x
|
527
models/TimeMixer++/TimeMixer.py
Normal file
527
models/TimeMixer++/TimeMixer.py
Normal file
@ -0,0 +1,527 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from layers.Autoformer_EncDec import series_decomp
|
||||
from layers.Embed import DataEmbedding_wo_pos
|
||||
from layers.StandardNorm import Normalize
|
||||
|
||||
class DFT_series_decomp(nn.Module):
|
||||
"""
|
||||
Series decomposition block
|
||||
"""
|
||||
|
||||
def __init__(self, top_k=5):
|
||||
super(DFT_series_decomp, self).__init__()
|
||||
self.top_k = top_k
|
||||
|
||||
def forward(self, x):
|
||||
xf = torch.fft.rfft(x)
|
||||
freq = abs(xf)
|
||||
freq[0] = 0
|
||||
top_k_freq, top_list = torch.topk(freq, self.top_k)
|
||||
xf[freq <= top_k_freq.min()] = 0
|
||||
x_season = torch.fft.irfft(xf)
|
||||
x_trend = x - x_season
|
||||
return x_season, x_trend
|
||||
|
||||
|
||||
class MultiScaleSeasonMixing(nn.Module):
|
||||
"""
|
||||
Bottom-up mixing season pattern
|
||||
"""
|
||||
|
||||
def __init__(self, configs):
|
||||
super(MultiScaleSeasonMixing, self).__init__()
|
||||
|
||||
self.down_sampling_layers = torch.nn.ModuleList(
|
||||
[
|
||||
nn.Sequential(
|
||||
torch.nn.Linear(
|
||||
configs.seq_len // (configs.down_sampling_window ** i),
|
||||
configs.seq_len // (configs.down_sampling_window ** (i + 1)),
|
||||
),
|
||||
nn.GELU(),
|
||||
torch.nn.Linear(
|
||||
configs.seq_len // (configs.down_sampling_window ** (i + 1)),
|
||||
configs.seq_len // (configs.down_sampling_window ** (i + 1)),
|
||||
),
|
||||
|
||||
)
|
||||
for i in range(configs.down_sampling_layers)
|
||||
]
|
||||
)
|
||||
|
||||
def forward(self, season_list):
|
||||
|
||||
# mixing high->low
|
||||
out_high = season_list[0]
|
||||
out_low = season_list[1]
|
||||
out_season_list = [out_high.permute(0, 2, 1)]
|
||||
|
||||
for i in range(len(season_list) - 1):
|
||||
out_low_res = self.down_sampling_layers[i](out_high)
|
||||
out_low = out_low + out_low_res
|
||||
out_high = out_low
|
||||
if i + 2 <= len(season_list) - 1:
|
||||
out_low = season_list[i + 2]
|
||||
out_season_list.append(out_high.permute(0, 2, 1))
|
||||
|
||||
return out_season_list
|
||||
|
||||
|
||||
class MultiScaleTrendMixing(nn.Module):
|
||||
"""
|
||||
Top-down mixing trend pattern
|
||||
"""
|
||||
|
||||
def __init__(self, configs):
|
||||
super(MultiScaleTrendMixing, self).__init__()
|
||||
|
||||
self.up_sampling_layers = torch.nn.ModuleList(
|
||||
[
|
||||
nn.Sequential(
|
||||
torch.nn.Linear(
|
||||
configs.seq_len // (configs.down_sampling_window ** (i + 1)),
|
||||
configs.seq_len // (configs.down_sampling_window ** i),
|
||||
),
|
||||
nn.GELU(),
|
||||
torch.nn.Linear(
|
||||
configs.seq_len // (configs.down_sampling_window ** i),
|
||||
configs.seq_len // (configs.down_sampling_window ** i),
|
||||
),
|
||||
)
|
||||
for i in reversed(range(configs.down_sampling_layers))
|
||||
])
|
||||
|
||||
def forward(self, trend_list):
|
||||
|
||||
# mixing low->high
|
||||
trend_list_reverse = trend_list.copy()
|
||||
trend_list_reverse.reverse()
|
||||
out_low = trend_list_reverse[0]
|
||||
out_high = trend_list_reverse[1]
|
||||
out_trend_list = [out_low.permute(0, 2, 1)]
|
||||
|
||||
for i in range(len(trend_list_reverse) - 1):
|
||||
out_high_res = self.up_sampling_layers[i](out_low)
|
||||
out_high = out_high + out_high_res
|
||||
out_low = out_high
|
||||
if i + 2 <= len(trend_list_reverse) - 1:
|
||||
out_high = trend_list_reverse[i + 2]
|
||||
out_trend_list.append(out_low.permute(0, 2, 1))
|
||||
|
||||
out_trend_list.reverse()
|
||||
return out_trend_list
|
||||
|
||||
|
||||
class PastDecomposableMixing(nn.Module):
|
||||
def __init__(self, configs):
|
||||
super(PastDecomposableMixing, self).__init__()
|
||||
self.seq_len = configs.seq_len
|
||||
self.pred_len = configs.pred_len
|
||||
self.down_sampling_window = configs.down_sampling_window
|
||||
|
||||
self.layer_norm = nn.LayerNorm(configs.d_model)
|
||||
self.dropout = nn.Dropout(configs.dropout)
|
||||
self.channel_independence = configs.channel_independence
|
||||
|
||||
if configs.decomp_method == 'moving_avg':
|
||||
self.decompsition = series_decomp(configs.moving_avg)
|
||||
elif configs.decomp_method == "dft_decomp":
|
||||
self.decompsition = DFT_series_decomp(configs.top_k)
|
||||
else:
|
||||
raise ValueError('decompsition is error')
|
||||
|
||||
if configs.channel_independence == 0:
|
||||
self.cross_layer = nn.Sequential(
|
||||
nn.Linear(in_features=configs.d_model, out_features=configs.d_ff),
|
||||
nn.GELU(),
|
||||
nn.Linear(in_features=configs.d_ff, out_features=configs.d_model),
|
||||
)
|
||||
|
||||
# Mixing season
|
||||
self.mixing_multi_scale_season = MultiScaleSeasonMixing(configs)
|
||||
|
||||
# Mxing trend
|
||||
self.mixing_multi_scale_trend = MultiScaleTrendMixing(configs)
|
||||
|
||||
self.out_cross_layer = nn.Sequential(
|
||||
nn.Linear(in_features=configs.d_model, out_features=configs.d_ff),
|
||||
nn.GELU(),
|
||||
nn.Linear(in_features=configs.d_ff, out_features=configs.d_model),
|
||||
)
|
||||
|
||||
def forward(self, x_list):
|
||||
length_list = []
|
||||
for x in x_list:
|
||||
_, T, _ = x.size()
|
||||
length_list.append(T)
|
||||
|
||||
# Decompose to obtain the season and trend
|
||||
season_list = []
|
||||
trend_list = []
|
||||
for x in x_list:
|
||||
season, trend = self.decompsition(x)
|
||||
if self.channel_independence == 0:
|
||||
season = self.cross_layer(season)
|
||||
trend = self.cross_layer(trend)
|
||||
season_list.append(season.permute(0, 2, 1))
|
||||
trend_list.append(trend.permute(0, 2, 1))
|
||||
|
||||
# bottom-up season mixing
|
||||
out_season_list = self.mixing_multi_scale_season(season_list)
|
||||
# top-down trend mixing
|
||||
out_trend_list = self.mixing_multi_scale_trend(trend_list)
|
||||
|
||||
out_list = []
|
||||
for ori, out_season, out_trend, length in zip(x_list, out_season_list, out_trend_list,
|
||||
length_list):
|
||||
out = out_season + out_trend
|
||||
if self.channel_independence:
|
||||
out = ori + self.out_cross_layer(out)
|
||||
out_list.append(out[:, :length, :])
|
||||
return out_list
|
||||
|
||||
|
||||
class TimeMixer(nn.Module):
|
||||
|
||||
def __init__(self, configs):
|
||||
super(TimeMixer, self).__init__()
|
||||
self.configs = configs
|
||||
self.task_name = configs.task_name
|
||||
self.seq_len = configs.seq_len
|
||||
self.label_len = configs.label_len
|
||||
self.pred_len = configs.pred_len
|
||||
self.down_sampling_window = configs.down_sampling_window
|
||||
self.channel_independence = configs.channel_independence
|
||||
self.pdm_blocks = nn.ModuleList([PastDecomposableMixing(configs)
|
||||
for _ in range(configs.e_layers)])
|
||||
|
||||
self.preprocess = series_decomp(configs.moving_avg)
|
||||
self.enc_in = configs.enc_in
|
||||
self.use_future_temporal_feature = configs.use_future_temporal_feature
|
||||
|
||||
if self.channel_independence == 1:
|
||||
self.enc_embedding = DataEmbedding_wo_pos(1, configs.d_model, configs.embed, configs.freq,
|
||||
configs.dropout)
|
||||
else:
|
||||
self.enc_embedding = DataEmbedding_wo_pos(configs.enc_in, configs.d_model, configs.embed, configs.freq,
|
||||
configs.dropout)
|
||||
|
||||
self.layer = configs.e_layers
|
||||
|
||||
self.normalize_layers = torch.nn.ModuleList(
|
||||
[
|
||||
Normalize(self.configs.enc_in, affine=True, non_norm=True if configs.use_norm == 0 else False)
|
||||
for i in range(configs.down_sampling_layers + 1)
|
||||
]
|
||||
)
|
||||
|
||||
if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
|
||||
self.predict_layers = torch.nn.ModuleList(
|
||||
[
|
||||
torch.nn.Linear(
|
||||
configs.seq_len // (configs.down_sampling_window ** i),
|
||||
configs.pred_len,
|
||||
)
|
||||
for i in range(configs.down_sampling_layers + 1)
|
||||
]
|
||||
)
|
||||
|
||||
if self.channel_independence == 1:
|
||||
self.projection_layer = nn.Linear(
|
||||
configs.d_model, 1, bias=True)
|
||||
else:
|
||||
self.projection_layer = nn.Linear(
|
||||
configs.d_model, configs.c_out, bias=True)
|
||||
|
||||
self.out_res_layers = torch.nn.ModuleList([
|
||||
torch.nn.Linear(
|
||||
configs.seq_len // (configs.down_sampling_window ** i),
|
||||
configs.seq_len // (configs.down_sampling_window ** i),
|
||||
)
|
||||
for i in range(configs.down_sampling_layers + 1)
|
||||
])
|
||||
|
||||
self.regression_layers = torch.nn.ModuleList(
|
||||
[
|
||||
torch.nn.Linear(
|
||||
configs.seq_len // (configs.down_sampling_window ** i),
|
||||
configs.pred_len,
|
||||
)
|
||||
for i in range(configs.down_sampling_layers + 1)
|
||||
]
|
||||
)
|
||||
if self.task_name == 'imputation' or self.task_name == 'anomaly_detection':
|
||||
if self.channel_independence == 1:
|
||||
self.projection_layer = nn.Linear(
|
||||
configs.d_model, 1, bias=True)
|
||||
else:
|
||||
self.projection_layer = nn.Linear(
|
||||
configs.d_model, configs.c_out, bias=True)
|
||||
if self.task_name == 'classification':
|
||||
self.act = F.gelu
|
||||
self.dropout = nn.Dropout(configs.dropout)
|
||||
self.projection = nn.Linear(
|
||||
configs.d_model * configs.seq_len, configs.num_class)
|
||||
|
||||
def out_projection(self, dec_out, i, out_res):
|
||||
dec_out = self.projection_layer(dec_out)
|
||||
out_res = out_res.permute(0, 2, 1)
|
||||
out_res = self.out_res_layers[i](out_res)
|
||||
out_res = self.regression_layers[i](out_res).permute(0, 2, 1)
|
||||
dec_out = dec_out + out_res
|
||||
return dec_out
|
||||
|
||||
def pre_enc(self, x_list):
|
||||
if self.channel_independence == 1:
|
||||
return (x_list, None)
|
||||
else:
|
||||
out1_list = []
|
||||
out2_list = []
|
||||
for x in x_list:
|
||||
x_1, x_2 = self.preprocess(x)
|
||||
out1_list.append(x_1)
|
||||
out2_list.append(x_2)
|
||||
return (out1_list, out2_list)
|
||||
|
||||
def __multi_scale_process_inputs(self, x_enc, x_mark_enc):
|
||||
if self.configs.down_sampling_method == 'max':
|
||||
down_pool = torch.nn.MaxPool1d(self.configs.down_sampling_window, return_indices=False)
|
||||
elif self.configs.down_sampling_method == 'avg':
|
||||
down_pool = torch.nn.AvgPool1d(self.configs.down_sampling_window)
|
||||
elif self.configs.down_sampling_method == 'conv':
|
||||
padding = 1 if torch.__version__ >= '1.5.0' else 2
|
||||
down_pool = nn.Conv1d(in_channels=self.configs.enc_in, out_channels=self.configs.enc_in,
|
||||
kernel_size=3, padding=padding,
|
||||
stride=self.configs.down_sampling_window,
|
||||
padding_mode='circular',
|
||||
bias=False)
|
||||
else:
|
||||
return x_enc, x_mark_enc
|
||||
# B,T,C -> B,C,T
|
||||
x_enc = x_enc.permute(0, 2, 1)
|
||||
|
||||
x_enc_ori = x_enc
|
||||
x_mark_enc_mark_ori = x_mark_enc
|
||||
|
||||
x_enc_sampling_list = []
|
||||
x_mark_sampling_list = []
|
||||
x_enc_sampling_list.append(x_enc.permute(0, 2, 1))
|
||||
x_mark_sampling_list.append(x_mark_enc)
|
||||
|
||||
for i in range(self.configs.down_sampling_layers):
|
||||
x_enc_sampling = down_pool(x_enc_ori)
|
||||
|
||||
x_enc_sampling_list.append(x_enc_sampling.permute(0, 2, 1))
|
||||
x_enc_ori = x_enc_sampling
|
||||
|
||||
if x_mark_enc_mark_ori is not None:
|
||||
x_mark_sampling_list.append(x_mark_enc_mark_ori[:, ::self.configs.down_sampling_window, :])
|
||||
x_mark_enc_mark_ori = x_mark_enc_mark_ori[:, ::self.configs.down_sampling_window, :]
|
||||
|
||||
x_enc = x_enc_sampling_list
|
||||
if x_mark_enc_mark_ori is not None:
|
||||
x_mark_enc = x_mark_sampling_list
|
||||
else:
|
||||
x_mark_enc = x_mark_enc
|
||||
|
||||
return x_enc, x_mark_enc
|
||||
|
||||
def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
|
||||
|
||||
if self.use_future_temporal_feature:
|
||||
if self.channel_independence == 1:
|
||||
B, T, N = x_enc.size()
|
||||
x_mark_dec = x_mark_dec.repeat(N, 1, 1)
|
||||
self.x_mark_dec = self.enc_embedding(None, x_mark_dec)
|
||||
else:
|
||||
self.x_mark_dec = self.enc_embedding(None, x_mark_dec)
|
||||
|
||||
x_enc, x_mark_enc = self.__multi_scale_process_inputs(x_enc, x_mark_enc)
|
||||
|
||||
x_list = []
|
||||
x_mark_list = []
|
||||
if x_mark_enc is not None:
|
||||
for i, x, x_mark in zip(range(len(x_enc)), x_enc, x_mark_enc):
|
||||
B, T, N = x.size()
|
||||
x = self.normalize_layers[i](x, 'norm')
|
||||
if self.channel_independence == 1:
|
||||
x = x.permute(0, 2, 1).contiguous().reshape(B * N, T, 1)
|
||||
x_mark = x_mark.repeat(N, 1, 1)
|
||||
x_list.append(x)
|
||||
x_mark_list.append(x_mark)
|
||||
else:
|
||||
for i, x in zip(range(len(x_enc)), x_enc, ):
|
||||
B, T, N = x.size()
|
||||
x = self.normalize_layers[i](x, 'norm')
|
||||
if self.channel_independence == 1:
|
||||
x = x.permute(0, 2, 1).contiguous().reshape(B * N, T, 1)
|
||||
x_list.append(x)
|
||||
|
||||
# embedding
|
||||
enc_out_list = []
|
||||
x_list = self.pre_enc(x_list)
|
||||
if x_mark_enc is not None:
|
||||
for i, x, x_mark in zip(range(len(x_list[0])), x_list[0], x_mark_list):
|
||||
enc_out = self.enc_embedding(x, x_mark) # [B,T,C]
|
||||
enc_out_list.append(enc_out)
|
||||
else:
|
||||
for i, x in zip(range(len(x_list[0])), x_list[0]):
|
||||
enc_out = self.enc_embedding(x, None) # [B,T,C]
|
||||
enc_out_list.append(enc_out)
|
||||
|
||||
# Past Decomposable Mixing as encoder for past
|
||||
for i in range(self.layer):
|
||||
enc_out_list = self.pdm_blocks[i](enc_out_list)
|
||||
|
||||
# Future Multipredictor Mixing as decoder for future
|
||||
dec_out_list = self.future_multi_mixing(B, enc_out_list, x_list)
|
||||
|
||||
dec_out = torch.stack(dec_out_list, dim=-1).sum(-1)
|
||||
dec_out = self.normalize_layers[0](dec_out, 'denorm')
|
||||
return dec_out
|
||||
|
||||
def future_multi_mixing(self, B, enc_out_list, x_list):
|
||||
dec_out_list = []
|
||||
if self.channel_independence == 1:
|
||||
x_list = x_list[0]
|
||||
for i, enc_out in zip(range(len(x_list)), enc_out_list):
|
||||
dec_out = self.predict_layers[i](enc_out.permute(0, 2, 1)).permute(
|
||||
0, 2, 1) # align temporal dimension
|
||||
if self.use_future_temporal_feature:
|
||||
dec_out = dec_out + self.x_mark_dec
|
||||
dec_out = self.projection_layer(dec_out)
|
||||
else:
|
||||
dec_out = self.projection_layer(dec_out)
|
||||
dec_out = dec_out.reshape(B, self.configs.c_out, self.pred_len).permute(0, 2, 1).contiguous()
|
||||
dec_out_list.append(dec_out)
|
||||
|
||||
else:
|
||||
for i, enc_out, out_res in zip(range(len(x_list[0])), enc_out_list, x_list[1]):
|
||||
dec_out = self.predict_layers[i](enc_out.permute(0, 2, 1)).permute(
|
||||
0, 2, 1) # align temporal dimension
|
||||
dec_out = self.out_projection(dec_out, i, out_res)
|
||||
dec_out_list.append(dec_out)
|
||||
|
||||
return dec_out_list
|
||||
|
||||
def classification(self, x_enc, x_mark_enc):
|
||||
x_enc, _ = self.__multi_scale_process_inputs(x_enc, None)
|
||||
x_list = x_enc
|
||||
|
||||
# embedding
|
||||
enc_out_list = []
|
||||
for x in x_list:
|
||||
enc_out = self.enc_embedding(x, None) # [B,T,C]
|
||||
enc_out_list.append(enc_out)
|
||||
|
||||
# MultiScale-CrissCrossAttention as encoder for past
|
||||
for i in range(self.layer):
|
||||
enc_out_list = self.pdm_blocks[i](enc_out_list)
|
||||
|
||||
enc_out = enc_out_list[0]
|
||||
# Output
|
||||
# the output transformer encoder/decoder embeddings don't include non-linearity
|
||||
output = self.act(enc_out)
|
||||
output = self.dropout(output)
|
||||
# zero-out padding embeddings
|
||||
output = output * x_mark_enc.unsqueeze(-1)
|
||||
# (batch_size, seq_length * d_model)
|
||||
output = output.reshape(output.shape[0], -1)
|
||||
output = self.projection(output) # (batch_size, num_classes)
|
||||
return output
|
||||
|
||||
def anomaly_detection(self, x_enc):
|
||||
B, T, N = x_enc.size()
|
||||
x_enc, _ = self.__multi_scale_process_inputs(x_enc, None)
|
||||
|
||||
x_list = []
|
||||
|
||||
for i, x in zip(range(len(x_enc)), x_enc, ):
|
||||
B, T, N = x.size()
|
||||
x = self.normalize_layers[i](x, 'norm')
|
||||
if self.channel_independence == 1:
|
||||
x = x.permute(0, 2, 1).contiguous().reshape(B * N, T, 1)
|
||||
x_list.append(x)
|
||||
|
||||
# embedding
|
||||
enc_out_list = []
|
||||
for x in x_list:
|
||||
enc_out = self.enc_embedding(x, None) # [B,T,C]
|
||||
enc_out_list.append(enc_out)
|
||||
|
||||
# MultiScale-CrissCrossAttention as encoder for past
|
||||
for i in range(self.layer):
|
||||
enc_out_list = self.pdm_blocks[i](enc_out_list)
|
||||
|
||||
dec_out = self.projection_layer(enc_out_list[0])
|
||||
dec_out = dec_out.reshape(B, self.configs.c_out, -1).permute(0, 2, 1).contiguous()
|
||||
|
||||
dec_out = self.normalize_layers[0](dec_out, 'denorm')
|
||||
return dec_out
|
||||
|
||||
def imputation(self, x_enc, x_mark_enc, mask):
|
||||
means = torch.sum(x_enc, dim=1) / torch.sum(mask == 1, dim=1)
|
||||
means = means.unsqueeze(1).detach()
|
||||
x_enc = x_enc - means
|
||||
x_enc = x_enc.masked_fill(mask == 0, 0)
|
||||
stdev = torch.sqrt(torch.sum(x_enc * x_enc, dim=1) /
|
||||
torch.sum(mask == 1, dim=1) + 1e-5)
|
||||
stdev = stdev.unsqueeze(1).detach()
|
||||
x_enc /= stdev
|
||||
|
||||
B, T, N = x_enc.size()
|
||||
x_enc, x_mark_enc = self.__multi_scale_process_inputs(x_enc, x_mark_enc)
|
||||
|
||||
x_list = []
|
||||
x_mark_list = []
|
||||
if x_mark_enc is not None:
|
||||
for i, x, x_mark in zip(range(len(x_enc)), x_enc, x_mark_enc):
|
||||
B, T, N = x.size()
|
||||
if self.channel_independence == 1:
|
||||
x = x.permute(0, 2, 1).contiguous().reshape(B * N, T, 1)
|
||||
x_list.append(x)
|
||||
x_mark = x_mark.repeat(N, 1, 1)
|
||||
x_mark_list.append(x_mark)
|
||||
else:
|
||||
for i, x in zip(range(len(x_enc)), x_enc, ):
|
||||
B, T, N = x.size()
|
||||
if self.channel_independence == 1:
|
||||
x = x.permute(0, 2, 1).contiguous().reshape(B * N, T, 1)
|
||||
x_list.append(x)
|
||||
|
||||
# embedding
|
||||
enc_out_list = []
|
||||
for x in x_list:
|
||||
enc_out = self.enc_embedding(x, None) # [B,T,C]
|
||||
enc_out_list.append(enc_out)
|
||||
|
||||
# MultiScale-CrissCrossAttention as encoder for past
|
||||
for i in range(self.layer):
|
||||
enc_out_list = self.pdm_blocks[i](enc_out_list)
|
||||
|
||||
dec_out = self.projection_layer(enc_out_list[0])
|
||||
dec_out = dec_out.reshape(B, self.configs.c_out, -1).permute(0, 2, 1).contiguous()
|
||||
|
||||
dec_out = dec_out * \
|
||||
(stdev[:, 0, :].unsqueeze(1).repeat(1, self.seq_len, 1))
|
||||
dec_out = dec_out + \
|
||||
(means[:, 0, :].unsqueeze(1).repeat(1, self.seq_len, 1))
|
||||
return dec_out
|
||||
|
||||
def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
|
||||
if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
|
||||
dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
|
||||
return dec_out
|
||||
if self.task_name == 'imputation':
|
||||
dec_out = self.imputation(x_enc, x_mark_enc, mask)
|
||||
return dec_out # [B, L, D]
|
||||
if self.task_name == 'anomaly_detection':
|
||||
dec_out = self.anomaly_detection(x_enc)
|
||||
return dec_out # [B, L, D]
|
||||
if self.task_name == 'classification':
|
||||
dec_out = self.classification(x_enc, x_mark_enc)
|
||||
return dec_out # [B, N]
|
||||
else:
|
||||
raise ValueError('Other tasks implemented yet')
|
0
models/TimeMixer++/__init__.py
Normal file
0
models/TimeMixer++/__init__.py
Normal file
216
models/TimesNet/TimesNet.py
Normal file
216
models/TimesNet/TimesNet.py
Normal file
@ -0,0 +1,216 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.fft
|
||||
from layers.Embed import DataEmbedding
|
||||
from layers.Conv_Blocks import Inception_Block_V1
|
||||
|
||||
|
||||
def FFT_for_Period(x, k=2):
|
||||
# [B, T, C]
|
||||
xf = torch.fft.rfft(x, dim=1)
|
||||
# find period by amplitudes
|
||||
frequency_list = abs(xf).mean(0).mean(-1)
|
||||
frequency_list[0] = 0
|
||||
_, top_list = torch.topk(frequency_list, k)
|
||||
top_list = top_list.detach().cpu().numpy()
|
||||
period = x.shape[1] // top_list
|
||||
return period, abs(xf).mean(-1)[:, top_list]
|
||||
|
||||
|
||||
class TimesBlock(nn.Module):
|
||||
def __init__(self, configs):
|
||||
super(TimesBlock, self).__init__()
|
||||
self.seq_len = configs.seq_len
|
||||
self.pred_len = configs.pred_len
|
||||
self.k = configs.top_k
|
||||
# parameter-efficient design
|
||||
self.conv = nn.Sequential(
|
||||
Inception_Block_V1(configs.d_model, configs.d_ff,
|
||||
num_kernels=configs.num_kernels),
|
||||
nn.GELU(),
|
||||
Inception_Block_V1(configs.d_ff, configs.d_model,
|
||||
num_kernels=configs.num_kernels)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
B, T, N = x.size()
|
||||
period_list, period_weight = FFT_for_Period(x, self.k)
|
||||
|
||||
res = []
|
||||
for i in range(self.k):
|
||||
period = period_list[i]
|
||||
# padding
|
||||
if (self.seq_len + self.pred_len) % period != 0:
|
||||
length = (
|
||||
((self.seq_len + self.pred_len) // period) + 1) * period
|
||||
padding = torch.zeros([x.shape[0], (length - (self.seq_len + self.pred_len)), x.shape[2]]).to(x.device)
|
||||
out = torch.cat([x, padding], dim=1)
|
||||
else:
|
||||
length = (self.seq_len + self.pred_len)
|
||||
out = x
|
||||
# reshape
|
||||
out = out.reshape(B, length // period, period,
|
||||
N).permute(0, 3, 1, 2).contiguous()
|
||||
# 2D conv: from 1d Variation to 2d Variation
|
||||
out = self.conv(out)
|
||||
# reshape back
|
||||
out = out.permute(0, 2, 3, 1).reshape(B, -1, N)
|
||||
res.append(out[:, :(self.seq_len + self.pred_len), :])
|
||||
res = torch.stack(res, dim=-1)
|
||||
# adaptive aggregation
|
||||
period_weight = F.softmax(period_weight, dim=1)
|
||||
period_weight = period_weight.unsqueeze(
|
||||
1).unsqueeze(1).repeat(1, T, N, 1)
|
||||
res = torch.sum(res * period_weight, -1)
|
||||
# residual connection
|
||||
res = res + x
|
||||
return res
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
"""
|
||||
Paper link: https://openreview.net/pdf?id=ju_Uqw384Oq
|
||||
"""
|
||||
|
||||
def __init__(self, configs):
|
||||
super(Model, self).__init__()
|
||||
self.configs = configs
|
||||
self.task_name = configs.task_name
|
||||
self.seq_len = configs.seq_len
|
||||
self.label_len = configs.label_len
|
||||
self.pred_len = configs.pred_len
|
||||
self.model = nn.ModuleList([TimesBlock(configs)
|
||||
for _ in range(configs.e_layers)])
|
||||
self.enc_embedding = DataEmbedding(configs.enc_in, configs.d_model, configs.embed, configs.freq,
|
||||
configs.dropout)
|
||||
self.layer = configs.e_layers
|
||||
self.layer_norm = nn.LayerNorm(configs.d_model)
|
||||
if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
|
||||
self.predict_linear = nn.Linear(
|
||||
self.seq_len, self.pred_len + self.seq_len)
|
||||
self.projection = nn.Linear(
|
||||
configs.d_model, configs.c_out, bias=True)
|
||||
if self.task_name == 'imputation' or self.task_name == 'anomaly_detection':
|
||||
self.projection = nn.Linear(
|
||||
configs.d_model, configs.c_out, bias=True)
|
||||
if self.task_name == 'classification':
|
||||
self.act = F.gelu
|
||||
self.dropout = nn.Dropout(configs.dropout)
|
||||
self.projection = nn.Linear(
|
||||
configs.d_model * configs.seq_len, configs.num_class)
|
||||
|
||||
def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
|
||||
# Normalization from Non-stationary Transformer
|
||||
means = x_enc.mean(1, keepdim=True).detach()
|
||||
x_enc = x_enc.sub(means)
|
||||
stdev = torch.sqrt(
|
||||
torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
|
||||
x_enc = x_enc.div(stdev)
|
||||
|
||||
# embedding
|
||||
enc_out = self.enc_embedding(x_enc, x_mark_enc) # [B,T,C]
|
||||
enc_out = self.predict_linear(enc_out.permute(0, 2, 1)).permute(
|
||||
0, 2, 1) # align temporal dimension
|
||||
|
||||
# TimesNet
|
||||
for i in range(self.layer):
|
||||
enc_out = self.layer_norm(self.model[i](enc_out))
|
||||
# project back
|
||||
dec_out = self.projection(enc_out)
|
||||
|
||||
# De-Normalization from Non-stationary Transformer
|
||||
dec_out = dec_out.mul(
|
||||
(stdev[:, 0, :].unsqueeze(1).repeat(
|
||||
1, self.pred_len + self.seq_len, 1)))
|
||||
dec_out = dec_out.add(
|
||||
(means[:, 0, :].unsqueeze(1).repeat(
|
||||
1, self.pred_len + self.seq_len, 1)))
|
||||
return dec_out
|
||||
|
||||
def imputation(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask):
|
||||
# Normalization from Non-stationary Transformer
|
||||
means = torch.sum(x_enc, dim=1) / torch.sum(mask == 1, dim=1)
|
||||
means = means.unsqueeze(1).detach()
|
||||
x_enc = x_enc.sub(means)
|
||||
x_enc = x_enc.masked_fill(mask == 0, 0)
|
||||
stdev = torch.sqrt(torch.sum(x_enc * x_enc, dim=1) /
|
||||
torch.sum(mask == 1, dim=1) + 1e-5)
|
||||
stdev = stdev.unsqueeze(1).detach()
|
||||
x_enc = x_enc.div(stdev)
|
||||
|
||||
# embedding
|
||||
enc_out = self.enc_embedding(x_enc, x_mark_enc) # [B,T,C]
|
||||
# TimesNet
|
||||
for i in range(self.layer):
|
||||
enc_out = self.layer_norm(self.model[i](enc_out))
|
||||
# project back
|
||||
dec_out = self.projection(enc_out)
|
||||
|
||||
# De-Normalization from Non-stationary Transformer
|
||||
dec_out = dec_out.mul(
|
||||
(stdev[:, 0, :].unsqueeze(1).repeat(
|
||||
1, self.pred_len + self.seq_len, 1)))
|
||||
dec_out = dec_out.add(
|
||||
(means[:, 0, :].unsqueeze(1).repeat(
|
||||
1, self.pred_len + self.seq_len, 1)))
|
||||
return dec_out
|
||||
|
||||
def anomaly_detection(self, x_enc):
|
||||
# Normalization from Non-stationary Transformer
|
||||
means = x_enc.mean(1, keepdim=True).detach()
|
||||
x_enc = x_enc.sub(means)
|
||||
stdev = torch.sqrt(
|
||||
torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
|
||||
x_enc = x_enc.div(stdev)
|
||||
|
||||
# embedding
|
||||
enc_out = self.enc_embedding(x_enc, None) # [B,T,C]
|
||||
# TimesNet
|
||||
for i in range(self.layer):
|
||||
enc_out = self.layer_norm(self.model[i](enc_out))
|
||||
# project back
|
||||
dec_out = self.projection(enc_out)
|
||||
|
||||
# De-Normalization from Non-stationary Transformer
|
||||
dec_out = dec_out.mul(
|
||||
(stdev[:, 0, :].unsqueeze(1).repeat(
|
||||
1, self.pred_len + self.seq_len, 1)))
|
||||
dec_out = dec_out.add(
|
||||
(means[:, 0, :].unsqueeze(1).repeat(
|
||||
1, self.pred_len + self.seq_len, 1)))
|
||||
return dec_out
|
||||
|
||||
def classification(self, x_enc, x_mark_enc):
|
||||
# embedding
|
||||
enc_out = self.enc_embedding(x_enc, None) # [B,T,C]
|
||||
# TimesNet
|
||||
for i in range(self.layer):
|
||||
enc_out = self.layer_norm(self.model[i](enc_out))
|
||||
|
||||
# Output
|
||||
# the output transformer encoder/decoder embeddings don't include non-linearity
|
||||
output = self.act(enc_out)
|
||||
output = self.dropout(output)
|
||||
# zero-out padding embeddings
|
||||
output = output * x_mark_enc.unsqueeze(-1)
|
||||
# (batch_size, seq_length * d_model)
|
||||
output = output.reshape(output.shape[0], -1)
|
||||
output = self.projection(output) # (batch_size, num_classes)
|
||||
return output
|
||||
|
||||
def forward(self, x_enc, x_mark_enc=None, x_dec=None, x_mark_dec=None, mask=None):
|
||||
if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
|
||||
dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
|
||||
return dec_out[:, -self.pred_len:, :] # [B, L, D]
|
||||
if self.task_name == 'imputation':
|
||||
dec_out = self.imputation(
|
||||
x_enc, x_mark_enc, x_dec, x_mark_dec, mask)
|
||||
return dec_out # [B, L, D]
|
||||
if self.task_name == 'anomaly_detection':
|
||||
dec_out = self.anomaly_detection(x_enc)
|
||||
return dec_out # [B, L, D]
|
||||
if self.task_name == 'classification':
|
||||
dec_out = self.classification(x_enc, x_mark_enc)
|
||||
return dec_out # [B, N]
|
||||
return None
|
0
models/TimesNet/__init__.py
Normal file
0
models/TimesNet/__init__.py
Normal file
Reference in New Issue
Block a user