first commit
This commit is contained in:
157
models/Autoformer.py
Normal file
157
models/Autoformer.py
Normal file
@ -0,0 +1,157 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from layers.Embed import DataEmbedding, DataEmbedding_wo_pos
|
||||
from layers.AutoCorrelation import AutoCorrelation, AutoCorrelationLayer
|
||||
from layers.Autoformer_EncDec import Encoder, Decoder, EncoderLayer, DecoderLayer, my_Layernorm, series_decomp
|
||||
import math
|
||||
import numpy as np
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
"""
|
||||
Autoformer is the first method to achieve the series-wise connection,
|
||||
with inherent O(LlogL) complexity
|
||||
Paper link: https://openreview.net/pdf?id=I55UqU-M11y
|
||||
"""
|
||||
|
||||
def __init__(self, configs):
|
||||
super(Model, self).__init__()
|
||||
self.task_name = configs.task_name
|
||||
self.seq_len = configs.seq_len
|
||||
self.label_len = configs.label_len
|
||||
self.pred_len = configs.pred_len
|
||||
|
||||
# Decomp
|
||||
kernel_size = configs.moving_avg
|
||||
self.decomp = series_decomp(kernel_size)
|
||||
|
||||
# Embedding
|
||||
self.enc_embedding = DataEmbedding_wo_pos(configs.enc_in, configs.d_model, configs.embed, configs.freq,
|
||||
configs.dropout)
|
||||
# Encoder
|
||||
self.encoder = Encoder(
|
||||
[
|
||||
EncoderLayer(
|
||||
AutoCorrelationLayer(
|
||||
AutoCorrelation(False, configs.factor, attention_dropout=configs.dropout,
|
||||
output_attention=False),
|
||||
configs.d_model, configs.n_heads),
|
||||
configs.d_model,
|
||||
configs.d_ff,
|
||||
moving_avg=configs.moving_avg,
|
||||
dropout=configs.dropout,
|
||||
activation=configs.activation
|
||||
) for l in range(configs.e_layers)
|
||||
],
|
||||
norm_layer=my_Layernorm(configs.d_model)
|
||||
)
|
||||
# Decoder
|
||||
if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
|
||||
self.dec_embedding = DataEmbedding_wo_pos(configs.dec_in, configs.d_model, configs.embed, configs.freq,
|
||||
configs.dropout)
|
||||
self.decoder = Decoder(
|
||||
[
|
||||
DecoderLayer(
|
||||
AutoCorrelationLayer(
|
||||
AutoCorrelation(True, configs.factor, attention_dropout=configs.dropout,
|
||||
output_attention=False),
|
||||
configs.d_model, configs.n_heads),
|
||||
AutoCorrelationLayer(
|
||||
AutoCorrelation(False, configs.factor, attention_dropout=configs.dropout,
|
||||
output_attention=False),
|
||||
configs.d_model, configs.n_heads),
|
||||
configs.d_model,
|
||||
configs.c_out,
|
||||
configs.d_ff,
|
||||
moving_avg=configs.moving_avg,
|
||||
dropout=configs.dropout,
|
||||
activation=configs.activation,
|
||||
)
|
||||
for l in range(configs.d_layers)
|
||||
],
|
||||
norm_layer=my_Layernorm(configs.d_model),
|
||||
projection=nn.Linear(configs.d_model, configs.c_out, bias=True)
|
||||
)
|
||||
if self.task_name == 'imputation':
|
||||
self.projection = nn.Linear(
|
||||
configs.d_model, configs.c_out, bias=True)
|
||||
if self.task_name == 'anomaly_detection':
|
||||
self.projection = nn.Linear(
|
||||
configs.d_model, configs.c_out, bias=True)
|
||||
if self.task_name == 'classification':
|
||||
self.act = F.gelu
|
||||
self.dropout = nn.Dropout(configs.dropout)
|
||||
self.projection = nn.Linear(
|
||||
configs.d_model * configs.seq_len, configs.num_class)
|
||||
|
||||
def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
|
||||
# decomp init
|
||||
mean = torch.mean(x_enc, dim=1).unsqueeze(
|
||||
1).repeat(1, self.pred_len, 1)
|
||||
zeros = torch.zeros([x_dec.shape[0], self.pred_len,
|
||||
x_dec.shape[2]], device=x_enc.device)
|
||||
seasonal_init, trend_init = self.decomp(x_enc)
|
||||
# decoder input
|
||||
trend_init = torch.cat(
|
||||
[trend_init[:, -self.label_len:, :], mean], dim=1)
|
||||
seasonal_init = torch.cat(
|
||||
[seasonal_init[:, -self.label_len:, :], zeros], dim=1)
|
||||
# enc
|
||||
enc_out = self.enc_embedding(x_enc, x_mark_enc)
|
||||
enc_out, attns = self.encoder(enc_out, attn_mask=None)
|
||||
# dec
|
||||
dec_out = self.dec_embedding(seasonal_init, x_mark_dec)
|
||||
seasonal_part, trend_part = self.decoder(dec_out, enc_out, x_mask=None, cross_mask=None,
|
||||
trend=trend_init)
|
||||
# final
|
||||
dec_out = trend_part + seasonal_part
|
||||
return dec_out
|
||||
|
||||
def imputation(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask):
|
||||
# enc
|
||||
enc_out = self.enc_embedding(x_enc, x_mark_enc)
|
||||
enc_out, attns = self.encoder(enc_out, attn_mask=None)
|
||||
# final
|
||||
dec_out = self.projection(enc_out)
|
||||
return dec_out
|
||||
|
||||
def anomaly_detection(self, x_enc):
|
||||
# enc
|
||||
enc_out = self.enc_embedding(x_enc, None)
|
||||
enc_out, attns = self.encoder(enc_out, attn_mask=None)
|
||||
# final
|
||||
dec_out = self.projection(enc_out)
|
||||
return dec_out
|
||||
|
||||
def classification(self, x_enc, x_mark_enc):
|
||||
# enc
|
||||
enc_out = self.enc_embedding(x_enc, None)
|
||||
enc_out, attns = self.encoder(enc_out, attn_mask=None)
|
||||
|
||||
# Output
|
||||
# the output transformer encoder/decoder embeddings don't include non-linearity
|
||||
output = self.act(enc_out)
|
||||
output = self.dropout(output)
|
||||
# zero-out padding embeddings
|
||||
output = output * x_mark_enc.unsqueeze(-1)
|
||||
# (batch_size, seq_length * d_model)
|
||||
output = output.reshape(output.shape[0], -1)
|
||||
output = self.projection(output) # (batch_size, num_classes)
|
||||
return output
|
||||
|
||||
def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
|
||||
if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
|
||||
dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
|
||||
return dec_out[:, -self.pred_len:, :] # [B, L, D]
|
||||
if self.task_name == 'imputation':
|
||||
dec_out = self.imputation(
|
||||
x_enc, x_mark_enc, x_dec, x_mark_dec, mask)
|
||||
return dec_out # [B, L, D]
|
||||
if self.task_name == 'anomaly_detection':
|
||||
dec_out = self.anomaly_detection(x_enc)
|
||||
return dec_out # [B, L, D]
|
||||
if self.task_name == 'classification':
|
||||
dec_out = self.classification(x_enc, x_mark_enc)
|
||||
return dec_out # [B, N]
|
||||
return None
|
145
models/Crossformer.py
Normal file
145
models/Crossformer.py
Normal file
@ -0,0 +1,145 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange, repeat
|
||||
from layers.Crossformer_EncDec import scale_block, Encoder, Decoder, DecoderLayer
|
||||
from layers.Embed import PatchEmbedding
|
||||
from layers.SelfAttention_Family import AttentionLayer, FullAttention, TwoStageAttentionLayer
|
||||
from models.PatchTST import FlattenHead
|
||||
|
||||
|
||||
from math import ceil
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
"""
|
||||
Paper link: https://openreview.net/pdf?id=vSVLM2j9eie
|
||||
"""
|
||||
def __init__(self, configs):
|
||||
super(Model, self).__init__()
|
||||
self.enc_in = configs.enc_in
|
||||
self.seq_len = configs.seq_len
|
||||
self.pred_len = configs.pred_len
|
||||
self.seg_len = 12
|
||||
self.win_size = 2
|
||||
self.task_name = configs.task_name
|
||||
|
||||
# The padding operation to handle invisible sgemnet length
|
||||
self.pad_in_len = ceil(1.0 * configs.seq_len / self.seg_len) * self.seg_len
|
||||
self.pad_out_len = ceil(1.0 * configs.pred_len / self.seg_len) * self.seg_len
|
||||
self.in_seg_num = self.pad_in_len // self.seg_len
|
||||
self.out_seg_num = ceil(self.in_seg_num / (self.win_size ** (configs.e_layers - 1)))
|
||||
self.head_nf = configs.d_model * self.out_seg_num
|
||||
|
||||
# Embedding
|
||||
self.enc_value_embedding = PatchEmbedding(configs.d_model, self.seg_len, self.seg_len, self.pad_in_len - configs.seq_len, 0)
|
||||
self.enc_pos_embedding = nn.Parameter(
|
||||
torch.randn(1, configs.enc_in, self.in_seg_num, configs.d_model))
|
||||
self.pre_norm = nn.LayerNorm(configs.d_model)
|
||||
|
||||
# Encoder
|
||||
self.encoder = Encoder(
|
||||
[
|
||||
scale_block(configs, 1 if l == 0 else self.win_size, configs.d_model, configs.n_heads, configs.d_ff,
|
||||
1, configs.dropout,
|
||||
self.in_seg_num if l == 0 else ceil(self.in_seg_num / self.win_size ** l), configs.factor
|
||||
) for l in range(configs.e_layers)
|
||||
]
|
||||
)
|
||||
# Decoder
|
||||
self.dec_pos_embedding = nn.Parameter(
|
||||
torch.randn(1, configs.enc_in, (self.pad_out_len // self.seg_len), configs.d_model))
|
||||
|
||||
self.decoder = Decoder(
|
||||
[
|
||||
DecoderLayer(
|
||||
TwoStageAttentionLayer(configs, (self.pad_out_len // self.seg_len), configs.factor, configs.d_model, configs.n_heads,
|
||||
configs.d_ff, configs.dropout),
|
||||
AttentionLayer(
|
||||
FullAttention(False, configs.factor, attention_dropout=configs.dropout,
|
||||
output_attention=False),
|
||||
configs.d_model, configs.n_heads),
|
||||
self.seg_len,
|
||||
configs.d_model,
|
||||
configs.d_ff,
|
||||
dropout=configs.dropout,
|
||||
# activation=configs.activation,
|
||||
)
|
||||
for l in range(configs.e_layers + 1)
|
||||
],
|
||||
)
|
||||
if self.task_name == 'imputation' or self.task_name == 'anomaly_detection':
|
||||
self.head = FlattenHead(configs.enc_in, self.head_nf, configs.seq_len,
|
||||
head_dropout=configs.dropout)
|
||||
elif self.task_name == 'classification':
|
||||
self.flatten = nn.Flatten(start_dim=-2)
|
||||
self.dropout = nn.Dropout(configs.dropout)
|
||||
self.projection = nn.Linear(
|
||||
self.head_nf * configs.enc_in, configs.num_class)
|
||||
|
||||
|
||||
|
||||
def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
|
||||
# embedding
|
||||
x_enc, n_vars = self.enc_value_embedding(x_enc.permute(0, 2, 1))
|
||||
x_enc = rearrange(x_enc, '(b d) seg_num d_model -> b d seg_num d_model', d = n_vars)
|
||||
x_enc += self.enc_pos_embedding
|
||||
x_enc = self.pre_norm(x_enc)
|
||||
enc_out, attns = self.encoder(x_enc)
|
||||
|
||||
dec_in = repeat(self.dec_pos_embedding, 'b ts_d l d -> (repeat b) ts_d l d', repeat=x_enc.shape[0])
|
||||
dec_out = self.decoder(dec_in, enc_out)
|
||||
return dec_out
|
||||
|
||||
def imputation(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask):
|
||||
# embedding
|
||||
x_enc, n_vars = self.enc_value_embedding(x_enc.permute(0, 2, 1))
|
||||
x_enc = rearrange(x_enc, '(b d) seg_num d_model -> b d seg_num d_model', d=n_vars)
|
||||
x_enc += self.enc_pos_embedding
|
||||
x_enc = self.pre_norm(x_enc)
|
||||
enc_out, attns = self.encoder(x_enc)
|
||||
|
||||
dec_out = self.head(enc_out[-1].permute(0, 1, 3, 2)).permute(0, 2, 1)
|
||||
|
||||
return dec_out
|
||||
|
||||
def anomaly_detection(self, x_enc):
|
||||
# embedding
|
||||
x_enc, n_vars = self.enc_value_embedding(x_enc.permute(0, 2, 1))
|
||||
x_enc = rearrange(x_enc, '(b d) seg_num d_model -> b d seg_num d_model', d=n_vars)
|
||||
x_enc += self.enc_pos_embedding
|
||||
x_enc = self.pre_norm(x_enc)
|
||||
enc_out, attns = self.encoder(x_enc)
|
||||
|
||||
dec_out = self.head(enc_out[-1].permute(0, 1, 3, 2)).permute(0, 2, 1)
|
||||
return dec_out
|
||||
|
||||
def classification(self, x_enc, x_mark_enc):
|
||||
# embedding
|
||||
x_enc, n_vars = self.enc_value_embedding(x_enc.permute(0, 2, 1))
|
||||
|
||||
x_enc = rearrange(x_enc, '(b d) seg_num d_model -> b d seg_num d_model', d=n_vars)
|
||||
x_enc += self.enc_pos_embedding
|
||||
x_enc = self.pre_norm(x_enc)
|
||||
enc_out, attns = self.encoder(x_enc)
|
||||
# Output from Non-stationary Transformer
|
||||
output = self.flatten(enc_out[-1].permute(0, 1, 3, 2))
|
||||
output = self.dropout(output)
|
||||
output = output.reshape(output.shape[0], -1)
|
||||
output = self.projection(output)
|
||||
return output
|
||||
|
||||
def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
|
||||
if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
|
||||
dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
|
||||
return dec_out[:, -self.pred_len:, :] # [B, L, D]
|
||||
if self.task_name == 'imputation':
|
||||
dec_out = self.imputation(x_enc, x_mark_enc, x_dec, x_mark_dec, mask)
|
||||
return dec_out # [B, L, D]
|
||||
if self.task_name == 'anomaly_detection':
|
||||
dec_out = self.anomaly_detection(x_enc)
|
||||
return dec_out # [B, L, D]
|
||||
if self.task_name == 'classification':
|
||||
dec_out = self.classification(x_enc, x_mark_enc)
|
||||
return dec_out # [B, N]
|
||||
return None
|
110
models/DLinear.py
Normal file
110
models/DLinear.py
Normal file
@ -0,0 +1,110 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from layers.Autoformer_EncDec import series_decomp
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
"""
|
||||
Paper link: https://arxiv.org/pdf/2205.13504.pdf
|
||||
"""
|
||||
|
||||
def __init__(self, configs, individual=False):
|
||||
"""
|
||||
individual: Bool, whether shared model among different variates.
|
||||
"""
|
||||
super(Model, self).__init__()
|
||||
self.task_name = configs.task_name
|
||||
self.seq_len = configs.seq_len
|
||||
if self.task_name == 'classification' or self.task_name == 'anomaly_detection' or self.task_name == 'imputation':
|
||||
self.pred_len = configs.seq_len
|
||||
else:
|
||||
self.pred_len = configs.pred_len
|
||||
# Series decomposition block from Autoformer
|
||||
self.decompsition = series_decomp(configs.moving_avg)
|
||||
self.individual = individual
|
||||
self.channels = configs.enc_in
|
||||
|
||||
if self.individual:
|
||||
self.Linear_Seasonal = nn.ModuleList()
|
||||
self.Linear_Trend = nn.ModuleList()
|
||||
|
||||
for i in range(self.channels):
|
||||
self.Linear_Seasonal.append(
|
||||
nn.Linear(self.seq_len, self.pred_len))
|
||||
self.Linear_Trend.append(
|
||||
nn.Linear(self.seq_len, self.pred_len))
|
||||
|
||||
self.Linear_Seasonal[i].weight = nn.Parameter(
|
||||
(1 / self.seq_len) * torch.ones([self.pred_len, self.seq_len]))
|
||||
self.Linear_Trend[i].weight = nn.Parameter(
|
||||
(1 / self.seq_len) * torch.ones([self.pred_len, self.seq_len]))
|
||||
else:
|
||||
self.Linear_Seasonal = nn.Linear(self.seq_len, self.pred_len)
|
||||
self.Linear_Trend = nn.Linear(self.seq_len, self.pred_len)
|
||||
|
||||
self.Linear_Seasonal.weight = nn.Parameter(
|
||||
(1 / self.seq_len) * torch.ones([self.pred_len, self.seq_len]))
|
||||
self.Linear_Trend.weight = nn.Parameter(
|
||||
(1 / self.seq_len) * torch.ones([self.pred_len, self.seq_len]))
|
||||
|
||||
if self.task_name == 'classification':
|
||||
self.projection = nn.Linear(
|
||||
configs.enc_in * configs.seq_len, configs.num_class)
|
||||
|
||||
def encoder(self, x):
|
||||
seasonal_init, trend_init = self.decompsition(x)
|
||||
seasonal_init, trend_init = seasonal_init.permute(
|
||||
0, 2, 1), trend_init.permute(0, 2, 1)
|
||||
if self.individual:
|
||||
seasonal_output = torch.zeros([seasonal_init.size(0), seasonal_init.size(1), self.pred_len],
|
||||
dtype=seasonal_init.dtype).to(seasonal_init.device)
|
||||
trend_output = torch.zeros([trend_init.size(0), trend_init.size(1), self.pred_len],
|
||||
dtype=trend_init.dtype).to(trend_init.device)
|
||||
for i in range(self.channels):
|
||||
seasonal_output[:, i, :] = self.Linear_Seasonal[i](
|
||||
seasonal_init[:, i, :])
|
||||
trend_output[:, i, :] = self.Linear_Trend[i](
|
||||
trend_init[:, i, :])
|
||||
else:
|
||||
seasonal_output = self.Linear_Seasonal(seasonal_init)
|
||||
trend_output = self.Linear_Trend(trend_init)
|
||||
x = seasonal_output + trend_output
|
||||
return x.permute(0, 2, 1)
|
||||
|
||||
def forecast(self, x_enc):
|
||||
# Encoder
|
||||
return self.encoder(x_enc)
|
||||
|
||||
def imputation(self, x_enc):
|
||||
# Encoder
|
||||
return self.encoder(x_enc)
|
||||
|
||||
def anomaly_detection(self, x_enc):
|
||||
# Encoder
|
||||
return self.encoder(x_enc)
|
||||
|
||||
def classification(self, x_enc):
|
||||
# Encoder
|
||||
enc_out = self.encoder(x_enc)
|
||||
# Output
|
||||
# (batch_size, seq_length * d_model)
|
||||
output = enc_out.reshape(enc_out.shape[0], -1)
|
||||
# (batch_size, num_classes)
|
||||
output = self.projection(output)
|
||||
return output
|
||||
|
||||
def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
|
||||
if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
|
||||
dec_out = self.forecast(x_enc)
|
||||
return dec_out[:, -self.pred_len:, :] # [B, L, D]
|
||||
if self.task_name == 'imputation':
|
||||
dec_out = self.imputation(x_enc)
|
||||
return dec_out # [B, L, D]
|
||||
if self.task_name == 'anomaly_detection':
|
||||
dec_out = self.anomaly_detection(x_enc)
|
||||
return dec_out # [B, L, D]
|
||||
if self.task_name == 'classification':
|
||||
dec_out = self.classification(x_enc)
|
||||
return dec_out # [B, N]
|
||||
return None
|
110
models/ETSformer.py
Normal file
110
models/ETSformer.py
Normal file
@ -0,0 +1,110 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from layers.Embed import DataEmbedding
|
||||
from layers.ETSformer_EncDec import EncoderLayer, Encoder, DecoderLayer, Decoder, Transform
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
"""
|
||||
Paper link: https://arxiv.org/abs/2202.01381
|
||||
"""
|
||||
|
||||
def __init__(self, configs):
|
||||
super(Model, self).__init__()
|
||||
self.task_name = configs.task_name
|
||||
self.seq_len = configs.seq_len
|
||||
self.label_len = configs.label_len
|
||||
if self.task_name == 'classification' or self.task_name == 'anomaly_detection' or self.task_name == 'imputation':
|
||||
self.pred_len = configs.seq_len
|
||||
else:
|
||||
self.pred_len = configs.pred_len
|
||||
|
||||
assert configs.e_layers == configs.d_layers, "Encoder and decoder layers must be equal"
|
||||
|
||||
# Embedding
|
||||
self.enc_embedding = DataEmbedding(configs.enc_in, configs.d_model, configs.embed, configs.freq,
|
||||
configs.dropout)
|
||||
|
||||
# Encoder
|
||||
self.encoder = Encoder(
|
||||
[
|
||||
EncoderLayer(
|
||||
configs.d_model, configs.n_heads, configs.enc_in, configs.seq_len, self.pred_len, configs.top_k,
|
||||
dim_feedforward=configs.d_ff,
|
||||
dropout=configs.dropout,
|
||||
activation=configs.activation,
|
||||
) for _ in range(configs.e_layers)
|
||||
]
|
||||
)
|
||||
# Decoder
|
||||
self.decoder = Decoder(
|
||||
[
|
||||
DecoderLayer(
|
||||
configs.d_model, configs.n_heads, configs.c_out, self.pred_len,
|
||||
dropout=configs.dropout,
|
||||
) for _ in range(configs.d_layers)
|
||||
],
|
||||
)
|
||||
self.transform = Transform(sigma=0.2)
|
||||
|
||||
if self.task_name == 'classification':
|
||||
self.act = torch.nn.functional.gelu
|
||||
self.dropout = nn.Dropout(configs.dropout)
|
||||
self.projection = nn.Linear(configs.d_model * configs.seq_len, configs.num_class)
|
||||
|
||||
def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
|
||||
with torch.no_grad():
|
||||
if self.training:
|
||||
x_enc = self.transform.transform(x_enc)
|
||||
res = self.enc_embedding(x_enc, x_mark_enc)
|
||||
level, growths, seasons = self.encoder(res, x_enc, attn_mask=None)
|
||||
|
||||
growth, season = self.decoder(growths, seasons)
|
||||
preds = level[:, -1:] + growth + season
|
||||
return preds
|
||||
|
||||
def imputation(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask):
|
||||
res = self.enc_embedding(x_enc, x_mark_enc)
|
||||
level, growths, seasons = self.encoder(res, x_enc, attn_mask=None)
|
||||
growth, season = self.decoder(growths, seasons)
|
||||
preds = level[:, -1:] + growth + season
|
||||
return preds
|
||||
|
||||
def anomaly_detection(self, x_enc):
|
||||
res = self.enc_embedding(x_enc, None)
|
||||
level, growths, seasons = self.encoder(res, x_enc, attn_mask=None)
|
||||
growth, season = self.decoder(growths, seasons)
|
||||
preds = level[:, -1:] + growth + season
|
||||
return preds
|
||||
|
||||
def classification(self, x_enc, x_mark_enc):
|
||||
res = self.enc_embedding(x_enc, None)
|
||||
_, growths, seasons = self.encoder(res, x_enc, attn_mask=None)
|
||||
|
||||
growths = torch.sum(torch.stack(growths, 0), 0)[:, :self.seq_len, :]
|
||||
seasons = torch.sum(torch.stack(seasons, 0), 0)[:, :self.seq_len, :]
|
||||
|
||||
enc_out = growths + seasons
|
||||
output = self.act(enc_out) # the output transformer encoder/decoder embeddings don't include non-linearity
|
||||
output = self.dropout(output)
|
||||
|
||||
# Output
|
||||
output = output * x_mark_enc.unsqueeze(-1) # zero-out padding embeddings
|
||||
output = output.reshape(output.shape[0], -1) # (batch_size, seq_length * d_model)
|
||||
output = self.projection(output) # (batch_size, num_classes)
|
||||
return output
|
||||
|
||||
def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
|
||||
if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
|
||||
dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
|
||||
return dec_out[:, -self.pred_len:, :] # [B, L, D]
|
||||
if self.task_name == 'imputation':
|
||||
dec_out = self.imputation(x_enc, x_mark_enc, x_dec, x_mark_dec, mask)
|
||||
return dec_out # [B, L, D]
|
||||
if self.task_name == 'anomaly_detection':
|
||||
dec_out = self.anomaly_detection(x_enc)
|
||||
return dec_out # [B, L, D]
|
||||
if self.task_name == 'classification':
|
||||
dec_out = self.classification(x_enc, x_mark_enc)
|
||||
return dec_out # [B, N]
|
||||
return None
|
178
models/FEDformer.py
Normal file
178
models/FEDformer.py
Normal file
@ -0,0 +1,178 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from layers.Embed import DataEmbedding
|
||||
from layers.AutoCorrelation import AutoCorrelationLayer
|
||||
from layers.FourierCorrelation import FourierBlock, FourierCrossAttention
|
||||
from layers.MultiWaveletCorrelation import MultiWaveletCross, MultiWaveletTransform
|
||||
from layers.Autoformer_EncDec import Encoder, Decoder, EncoderLayer, DecoderLayer, my_Layernorm, series_decomp
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
"""
|
||||
FEDformer performs the attention mechanism on frequency domain and achieved O(N) complexity
|
||||
Paper link: https://proceedings.mlr.press/v162/zhou22g.html
|
||||
"""
|
||||
|
||||
def __init__(self, configs, version='fourier', mode_select='random', modes=32):
|
||||
"""
|
||||
version: str, for FEDformer, there are two versions to choose, options: [Fourier, Wavelets].
|
||||
mode_select: str, for FEDformer, there are two mode selection method, options: [random, low].
|
||||
modes: int, modes to be selected.
|
||||
"""
|
||||
super(Model, self).__init__()
|
||||
self.task_name = configs.task_name
|
||||
self.seq_len = configs.seq_len
|
||||
self.label_len = configs.label_len
|
||||
self.pred_len = configs.pred_len
|
||||
|
||||
self.version = version
|
||||
self.mode_select = mode_select
|
||||
self.modes = modes
|
||||
|
||||
# Decomp
|
||||
self.decomp = series_decomp(configs.moving_avg)
|
||||
self.enc_embedding = DataEmbedding(configs.enc_in, configs.d_model, configs.embed, configs.freq,
|
||||
configs.dropout)
|
||||
self.dec_embedding = DataEmbedding(configs.dec_in, configs.d_model, configs.embed, configs.freq,
|
||||
configs.dropout)
|
||||
|
||||
if self.version == 'Wavelets':
|
||||
encoder_self_att = MultiWaveletTransform(ich=configs.d_model, L=1, base='legendre')
|
||||
decoder_self_att = MultiWaveletTransform(ich=configs.d_model, L=1, base='legendre')
|
||||
decoder_cross_att = MultiWaveletCross(in_channels=configs.d_model,
|
||||
out_channels=configs.d_model,
|
||||
seq_len_q=self.seq_len // 2 + self.pred_len,
|
||||
seq_len_kv=self.seq_len,
|
||||
modes=self.modes,
|
||||
ich=configs.d_model,
|
||||
base='legendre',
|
||||
activation='tanh')
|
||||
else:
|
||||
encoder_self_att = FourierBlock(in_channels=configs.d_model,
|
||||
out_channels=configs.d_model,
|
||||
n_heads=configs.n_heads,
|
||||
seq_len=self.seq_len,
|
||||
modes=self.modes,
|
||||
mode_select_method=self.mode_select)
|
||||
decoder_self_att = FourierBlock(in_channels=configs.d_model,
|
||||
out_channels=configs.d_model,
|
||||
n_heads=configs.n_heads,
|
||||
seq_len=self.seq_len // 2 + self.pred_len,
|
||||
modes=self.modes,
|
||||
mode_select_method=self.mode_select)
|
||||
decoder_cross_att = FourierCrossAttention(in_channels=configs.d_model,
|
||||
out_channels=configs.d_model,
|
||||
seq_len_q=self.seq_len // 2 + self.pred_len,
|
||||
seq_len_kv=self.seq_len,
|
||||
modes=self.modes,
|
||||
mode_select_method=self.mode_select,
|
||||
num_heads=configs.n_heads)
|
||||
# Encoder
|
||||
self.encoder = Encoder(
|
||||
[
|
||||
EncoderLayer(
|
||||
AutoCorrelationLayer(
|
||||
encoder_self_att, # instead of multi-head attention in transformer
|
||||
configs.d_model, configs.n_heads),
|
||||
configs.d_model,
|
||||
configs.d_ff,
|
||||
moving_avg=configs.moving_avg,
|
||||
dropout=configs.dropout,
|
||||
activation=configs.activation
|
||||
) for l in range(configs.e_layers)
|
||||
],
|
||||
norm_layer=my_Layernorm(configs.d_model)
|
||||
)
|
||||
# Decoder
|
||||
self.decoder = Decoder(
|
||||
[
|
||||
DecoderLayer(
|
||||
AutoCorrelationLayer(
|
||||
decoder_self_att,
|
||||
configs.d_model, configs.n_heads),
|
||||
AutoCorrelationLayer(
|
||||
decoder_cross_att,
|
||||
configs.d_model, configs.n_heads),
|
||||
configs.d_model,
|
||||
configs.c_out,
|
||||
configs.d_ff,
|
||||
moving_avg=configs.moving_avg,
|
||||
dropout=configs.dropout,
|
||||
activation=configs.activation,
|
||||
)
|
||||
for l in range(configs.d_layers)
|
||||
],
|
||||
norm_layer=my_Layernorm(configs.d_model),
|
||||
projection=nn.Linear(configs.d_model, configs.c_out, bias=True)
|
||||
)
|
||||
|
||||
if self.task_name == 'imputation':
|
||||
self.projection = nn.Linear(configs.d_model, configs.c_out, bias=True)
|
||||
if self.task_name == 'anomaly_detection':
|
||||
self.projection = nn.Linear(configs.d_model, configs.c_out, bias=True)
|
||||
if self.task_name == 'classification':
|
||||
self.act = F.gelu
|
||||
self.dropout = nn.Dropout(configs.dropout)
|
||||
self.projection = nn.Linear(configs.d_model * configs.seq_len, configs.num_class)
|
||||
|
||||
def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
|
||||
# decomp init
|
||||
mean = torch.mean(x_enc, dim=1).unsqueeze(1).repeat(1, self.pred_len, 1)
|
||||
seasonal_init, trend_init = self.decomp(x_enc) # x - moving_avg, moving_avg
|
||||
# decoder input
|
||||
trend_init = torch.cat([trend_init[:, -self.label_len:, :], mean], dim=1)
|
||||
seasonal_init = F.pad(seasonal_init[:, -self.label_len:, :], (0, 0, 0, self.pred_len))
|
||||
# enc
|
||||
enc_out = self.enc_embedding(x_enc, x_mark_enc)
|
||||
dec_out = self.dec_embedding(seasonal_init, x_mark_dec)
|
||||
enc_out, attns = self.encoder(enc_out, attn_mask=None)
|
||||
# dec
|
||||
seasonal_part, trend_part = self.decoder(dec_out, enc_out, x_mask=None, cross_mask=None, trend=trend_init)
|
||||
# final
|
||||
dec_out = trend_part + seasonal_part
|
||||
return dec_out
|
||||
|
||||
def imputation(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask):
|
||||
# enc
|
||||
enc_out = self.enc_embedding(x_enc, x_mark_enc)
|
||||
enc_out, attns = self.encoder(enc_out, attn_mask=None)
|
||||
# final
|
||||
dec_out = self.projection(enc_out)
|
||||
return dec_out
|
||||
|
||||
def anomaly_detection(self, x_enc):
|
||||
# enc
|
||||
enc_out = self.enc_embedding(x_enc, None)
|
||||
enc_out, attns = self.encoder(enc_out, attn_mask=None)
|
||||
# final
|
||||
dec_out = self.projection(enc_out)
|
||||
return dec_out
|
||||
|
||||
def classification(self, x_enc, x_mark_enc):
|
||||
# enc
|
||||
enc_out = self.enc_embedding(x_enc, None)
|
||||
enc_out, attns = self.encoder(enc_out, attn_mask=None)
|
||||
|
||||
# Output
|
||||
output = self.act(enc_out)
|
||||
output = self.dropout(output)
|
||||
output = output * x_mark_enc.unsqueeze(-1)
|
||||
output = output.reshape(output.shape[0], -1)
|
||||
output = self.projection(output)
|
||||
return output
|
||||
|
||||
def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
|
||||
if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
|
||||
dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
|
||||
return dec_out[:, -self.pred_len:, :] # [B, L, D]
|
||||
if self.task_name == 'imputation':
|
||||
dec_out = self.imputation(x_enc, x_mark_enc, x_dec, x_mark_dec, mask)
|
||||
return dec_out # [B, L, D]
|
||||
if self.task_name == 'anomaly_detection':
|
||||
dec_out = self.anomaly_detection(x_enc)
|
||||
return dec_out # [B, L, D]
|
||||
if self.task_name == 'classification':
|
||||
dec_out = self.classification(x_enc, x_mark_enc)
|
||||
return dec_out # [B, N]
|
||||
return None
|
268
models/FiLM.py
Normal file
268
models/FiLM.py
Normal file
@ -0,0 +1,268 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
from scipy import signal
|
||||
from scipy import special as ss
|
||||
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
|
||||
def transition(N):
|
||||
Q = np.arange(N, dtype=np.float64)
|
||||
R = (2 * Q + 1)[:, None] # / theta
|
||||
j, i = np.meshgrid(Q, Q)
|
||||
A = np.where(i < j, -1, (-1.) ** (i - j + 1)) * R
|
||||
B = (-1.) ** Q[:, None] * R
|
||||
return A, B
|
||||
|
||||
|
||||
class HiPPO_LegT(nn.Module):
|
||||
def __init__(self, N, dt=1.0, discretization='bilinear'):
|
||||
"""
|
||||
N: the order of the HiPPO projection
|
||||
dt: discretization step size - should be roughly inverse to the length of the sequence
|
||||
"""
|
||||
super(HiPPO_LegT, self).__init__()
|
||||
self.N = N
|
||||
A, B = transition(N)
|
||||
C = np.ones((1, N))
|
||||
D = np.zeros((1,))
|
||||
A, B, _, _, _ = signal.cont2discrete((A, B, C, D), dt=dt, method=discretization)
|
||||
|
||||
B = B.squeeze(-1)
|
||||
|
||||
self.register_buffer('A', torch.Tensor(A).to(device))
|
||||
self.register_buffer('B', torch.Tensor(B).to(device))
|
||||
vals = np.arange(0.0, 1.0, dt)
|
||||
self.register_buffer('eval_matrix', torch.Tensor(
|
||||
ss.eval_legendre(np.arange(N)[:, None], 1 - 2 * vals).T).to(device))
|
||||
|
||||
def forward(self, inputs):
|
||||
"""
|
||||
inputs : (length, ...)
|
||||
output : (length, ..., N) where N is the order of the HiPPO projection
|
||||
"""
|
||||
c = torch.zeros(inputs.shape[:-1] + tuple([self.N])).to(device)
|
||||
cs = []
|
||||
for f in inputs.permute([-1, 0, 1]):
|
||||
f = f.unsqueeze(-1)
|
||||
new = f @ self.B.unsqueeze(0)
|
||||
c = F.linear(c, self.A) + new
|
||||
cs.append(c)
|
||||
return torch.stack(cs, dim=0)
|
||||
|
||||
def reconstruct(self, c):
|
||||
return (self.eval_matrix @ c.unsqueeze(-1)).squeeze(-1)
|
||||
|
||||
|
||||
class SpectralConv1d(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, seq_len, ratio=0.5):
|
||||
"""
|
||||
1D Fourier layer. It does FFT, linear transform, and Inverse FFT.
|
||||
"""
|
||||
super(SpectralConv1d, self).__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.ratio = ratio
|
||||
self.modes = min(32, seq_len // 2)
|
||||
self.index = list(range(0, self.modes))
|
||||
|
||||
self.scale = (1 / (in_channels * out_channels))
|
||||
self.weights_real = nn.Parameter(
|
||||
self.scale * torch.rand(in_channels, out_channels, len(self.index), dtype=torch.float))
|
||||
self.weights_imag = nn.Parameter(
|
||||
self.scale * torch.rand(in_channels, out_channels, len(self.index), dtype=torch.float))
|
||||
|
||||
def compl_mul1d(self, order, x, weights_real, weights_imag):
|
||||
return torch.complex(torch.einsum(order, x.real, weights_real) - torch.einsum(order, x.imag, weights_imag),
|
||||
torch.einsum(order, x.real, weights_imag) + torch.einsum(order, x.imag, weights_real))
|
||||
|
||||
def forward(self, x):
|
||||
B, H, E, N = x.shape
|
||||
x_ft = torch.fft.rfft(x)
|
||||
out_ft = torch.zeros(B, H, self.out_channels, x.size(-1) // 2 + 1, device=x.device, dtype=torch.cfloat)
|
||||
a = x_ft[:, :, :, :self.modes]
|
||||
out_ft[:, :, :, :self.modes] = self.compl_mul1d("bjix,iox->bjox", a, self.weights_real, self.weights_imag)
|
||||
x = torch.fft.irfft(out_ft, n=x.size(-1))
|
||||
return x
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
"""
|
||||
Paper link: https://arxiv.org/abs/2205.08897
|
||||
"""
|
||||
def __init__(self, configs):
|
||||
super(Model, self).__init__()
|
||||
self.task_name = configs.task_name
|
||||
self.configs = configs
|
||||
self.seq_len = configs.seq_len
|
||||
self.label_len = configs.label_len
|
||||
self.pred_len = configs.seq_len if configs.pred_len == 0 else configs.pred_len
|
||||
|
||||
self.seq_len_all = self.seq_len + self.label_len
|
||||
|
||||
self.layers = configs.e_layers
|
||||
self.enc_in = configs.enc_in
|
||||
self.e_layers = configs.e_layers
|
||||
# b, s, f means b, f
|
||||
self.affine_weight = nn.Parameter(torch.ones(1, 1, configs.enc_in))
|
||||
self.affine_bias = nn.Parameter(torch.zeros(1, 1, configs.enc_in))
|
||||
|
||||
self.multiscale = [1, 2, 4]
|
||||
self.window_size = [256]
|
||||
configs.ratio = 0.5
|
||||
self.legts = nn.ModuleList(
|
||||
[HiPPO_LegT(N=n, dt=1. / self.pred_len / i) for n in self.window_size for i in self.multiscale])
|
||||
self.spec_conv_1 = nn.ModuleList([SpectralConv1d(in_channels=n, out_channels=n,
|
||||
seq_len=min(self.pred_len, self.seq_len),
|
||||
ratio=configs.ratio) for n in
|
||||
self.window_size for _ in range(len(self.multiscale))])
|
||||
self.mlp = nn.Linear(len(self.multiscale) * len(self.window_size), 1)
|
||||
|
||||
if self.task_name == 'imputation' or self.task_name == 'anomaly_detection':
|
||||
self.projection = nn.Linear(
|
||||
configs.d_model, configs.c_out, bias=True)
|
||||
if self.task_name == 'classification':
|
||||
self.act = F.gelu
|
||||
self.dropout = nn.Dropout(configs.dropout)
|
||||
self.projection = nn.Linear(
|
||||
configs.enc_in * configs.seq_len, configs.num_class)
|
||||
|
||||
def forecast(self, x_enc, x_mark_enc, x_dec_true, x_mark_dec):
|
||||
# Normalization from Non-stationary Transformer
|
||||
means = x_enc.mean(1, keepdim=True).detach()
|
||||
x_enc = x_enc - means
|
||||
stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5).detach()
|
||||
x_enc /= stdev
|
||||
|
||||
x_enc = x_enc * self.affine_weight + self.affine_bias
|
||||
x_decs = []
|
||||
jump_dist = 0
|
||||
for i in range(0, len(self.multiscale) * len(self.window_size)):
|
||||
x_in_len = self.multiscale[i % len(self.multiscale)] * self.pred_len
|
||||
x_in = x_enc[:, -x_in_len:]
|
||||
legt = self.legts[i]
|
||||
x_in_c = legt(x_in.transpose(1, 2)).permute([1, 2, 3, 0])[:, :, :, jump_dist:]
|
||||
out1 = self.spec_conv_1[i](x_in_c)
|
||||
if self.seq_len >= self.pred_len:
|
||||
x_dec_c = out1.transpose(2, 3)[:, :, self.pred_len - 1 - jump_dist, :]
|
||||
else:
|
||||
x_dec_c = out1.transpose(2, 3)[:, :, -1, :]
|
||||
x_dec = x_dec_c @ legt.eval_matrix[-self.pred_len:, :].T
|
||||
x_decs.append(x_dec)
|
||||
x_dec = torch.stack(x_decs, dim=-1)
|
||||
x_dec = self.mlp(x_dec).squeeze(-1).permute(0, 2, 1)
|
||||
|
||||
# De-Normalization from Non-stationary Transformer
|
||||
x_dec = x_dec - self.affine_bias
|
||||
x_dec = x_dec / (self.affine_weight + 1e-10)
|
||||
x_dec = x_dec * stdev
|
||||
x_dec = x_dec + means
|
||||
return x_dec
|
||||
|
||||
def imputation(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask):
|
||||
# Normalization from Non-stationary Transformer
|
||||
means = x_enc.mean(1, keepdim=True).detach()
|
||||
x_enc = x_enc - means
|
||||
stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5).detach()
|
||||
x_enc /= stdev
|
||||
|
||||
x_enc = x_enc * self.affine_weight + self.affine_bias
|
||||
x_decs = []
|
||||
jump_dist = 0
|
||||
for i in range(0, len(self.multiscale) * len(self.window_size)):
|
||||
x_in_len = self.multiscale[i % len(self.multiscale)] * self.pred_len
|
||||
x_in = x_enc[:, -x_in_len:]
|
||||
legt = self.legts[i]
|
||||
x_in_c = legt(x_in.transpose(1, 2)).permute([1, 2, 3, 0])[:, :, :, jump_dist:]
|
||||
out1 = self.spec_conv_1[i](x_in_c)
|
||||
if self.seq_len >= self.pred_len:
|
||||
x_dec_c = out1.transpose(2, 3)[:, :, self.pred_len - 1 - jump_dist, :]
|
||||
else:
|
||||
x_dec_c = out1.transpose(2, 3)[:, :, -1, :]
|
||||
x_dec = x_dec_c @ legt.eval_matrix[-self.pred_len:, :].T
|
||||
x_decs.append(x_dec)
|
||||
x_dec = torch.stack(x_decs, dim=-1)
|
||||
x_dec = self.mlp(x_dec).squeeze(-1).permute(0, 2, 1)
|
||||
|
||||
# De-Normalization from Non-stationary Transformer
|
||||
x_dec = x_dec - self.affine_bias
|
||||
x_dec = x_dec / (self.affine_weight + 1e-10)
|
||||
x_dec = x_dec * stdev
|
||||
x_dec = x_dec + means
|
||||
return x_dec
|
||||
|
||||
def anomaly_detection(self, x_enc):
|
||||
# Normalization from Non-stationary Transformer
|
||||
means = x_enc.mean(1, keepdim=True).detach()
|
||||
x_enc = x_enc - means
|
||||
stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5).detach()
|
||||
x_enc /= stdev
|
||||
|
||||
x_enc = x_enc * self.affine_weight + self.affine_bias
|
||||
x_decs = []
|
||||
jump_dist = 0
|
||||
for i in range(0, len(self.multiscale) * len(self.window_size)):
|
||||
x_in_len = self.multiscale[i % len(self.multiscale)] * self.pred_len
|
||||
x_in = x_enc[:, -x_in_len:]
|
||||
legt = self.legts[i]
|
||||
x_in_c = legt(x_in.transpose(1, 2)).permute([1, 2, 3, 0])[:, :, :, jump_dist:]
|
||||
out1 = self.spec_conv_1[i](x_in_c)
|
||||
if self.seq_len >= self.pred_len:
|
||||
x_dec_c = out1.transpose(2, 3)[:, :, self.pred_len - 1 - jump_dist, :]
|
||||
else:
|
||||
x_dec_c = out1.transpose(2, 3)[:, :, -1, :]
|
||||
x_dec = x_dec_c @ legt.eval_matrix[-self.pred_len:, :].T
|
||||
x_decs.append(x_dec)
|
||||
x_dec = torch.stack(x_decs, dim=-1)
|
||||
x_dec = self.mlp(x_dec).squeeze(-1).permute(0, 2, 1)
|
||||
|
||||
# De-Normalization from Non-stationary Transformer
|
||||
x_dec = x_dec - self.affine_bias
|
||||
x_dec = x_dec / (self.affine_weight + 1e-10)
|
||||
x_dec = x_dec * stdev
|
||||
x_dec = x_dec + means
|
||||
return x_dec
|
||||
|
||||
def classification(self, x_enc, x_mark_enc):
|
||||
x_enc = x_enc * self.affine_weight + self.affine_bias
|
||||
x_decs = []
|
||||
jump_dist = 0
|
||||
for i in range(0, len(self.multiscale) * len(self.window_size)):
|
||||
x_in_len = self.multiscale[i % len(self.multiscale)] * self.pred_len
|
||||
x_in = x_enc[:, -x_in_len:]
|
||||
legt = self.legts[i]
|
||||
x_in_c = legt(x_in.transpose(1, 2)).permute([1, 2, 3, 0])[:, :, :, jump_dist:]
|
||||
out1 = self.spec_conv_1[i](x_in_c)
|
||||
if self.seq_len >= self.pred_len:
|
||||
x_dec_c = out1.transpose(2, 3)[:, :, self.pred_len - 1 - jump_dist, :]
|
||||
else:
|
||||
x_dec_c = out1.transpose(2, 3)[:, :, -1, :]
|
||||
x_dec = x_dec_c @ legt.eval_matrix[-self.pred_len:, :].T
|
||||
x_decs.append(x_dec)
|
||||
x_dec = torch.stack(x_decs, dim=-1)
|
||||
x_dec = self.mlp(x_dec).squeeze(-1).permute(0, 2, 1)
|
||||
|
||||
# Output from Non-stationary Transformer
|
||||
output = self.act(x_dec)
|
||||
output = self.dropout(output)
|
||||
output = output * x_mark_enc.unsqueeze(-1)
|
||||
output = output.reshape(output.shape[0], -1)
|
||||
output = self.projection(output)
|
||||
return output
|
||||
|
||||
def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
|
||||
if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
|
||||
dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
|
||||
return dec_out[:, -self.pred_len:, :] # [B, L, D]
|
||||
if self.task_name == 'imputation':
|
||||
dec_out = self.imputation(x_enc, x_mark_enc, x_dec, x_mark_dec, mask)
|
||||
return dec_out # [B, L, D]
|
||||
if self.task_name == 'anomaly_detection':
|
||||
dec_out = self.anomaly_detection(x_enc)
|
||||
return dec_out # [B, L, D]
|
||||
if self.task_name == 'classification':
|
||||
dec_out = self.classification(x_enc, x_mark_enc)
|
||||
return dec_out # [B, N]
|
||||
return None
|
118
models/FreTS.py
Normal file
118
models/FreTS.py
Normal file
@ -0,0 +1,118 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
"""
|
||||
Paper link: https://arxiv.org/pdf/2311.06184.pdf
|
||||
"""
|
||||
|
||||
def __init__(self, configs):
|
||||
super(Model, self).__init__()
|
||||
self.task_name = configs.task_name
|
||||
if self.task_name == 'classification' or self.task_name == 'anomaly_detection' or self.task_name == 'imputation':
|
||||
self.pred_len = configs.seq_len
|
||||
else:
|
||||
self.pred_len = configs.pred_len
|
||||
self.embed_size = 128 # embed_size
|
||||
self.hidden_size = 256 # hidden_size
|
||||
self.pred_len = configs.pred_len
|
||||
self.feature_size = configs.enc_in # channels
|
||||
self.seq_len = configs.seq_len
|
||||
self.channel_independence = configs.channel_independence
|
||||
self.sparsity_threshold = 0.01
|
||||
self.scale = 0.02
|
||||
self.embeddings = nn.Parameter(torch.randn(1, self.embed_size))
|
||||
self.r1 = nn.Parameter(self.scale * torch.randn(self.embed_size, self.embed_size))
|
||||
self.i1 = nn.Parameter(self.scale * torch.randn(self.embed_size, self.embed_size))
|
||||
self.rb1 = nn.Parameter(self.scale * torch.randn(self.embed_size))
|
||||
self.ib1 = nn.Parameter(self.scale * torch.randn(self.embed_size))
|
||||
self.r2 = nn.Parameter(self.scale * torch.randn(self.embed_size, self.embed_size))
|
||||
self.i2 = nn.Parameter(self.scale * torch.randn(self.embed_size, self.embed_size))
|
||||
self.rb2 = nn.Parameter(self.scale * torch.randn(self.embed_size))
|
||||
self.ib2 = nn.Parameter(self.scale * torch.randn(self.embed_size))
|
||||
|
||||
self.fc = nn.Sequential(
|
||||
nn.Linear(self.seq_len * self.embed_size, self.hidden_size),
|
||||
nn.LeakyReLU(),
|
||||
nn.Linear(self.hidden_size, self.pred_len)
|
||||
)
|
||||
|
||||
# dimension extension
|
||||
def tokenEmb(self, x):
|
||||
# x: [Batch, Input length, Channel]
|
||||
x = x.permute(0, 2, 1)
|
||||
x = x.unsqueeze(3)
|
||||
# N*T*1 x 1*D = N*T*D
|
||||
y = self.embeddings
|
||||
return x * y
|
||||
|
||||
# frequency temporal learner
|
||||
def MLP_temporal(self, x, B, N, L):
|
||||
# [B, N, T, D]
|
||||
x = torch.fft.rfft(x, dim=2, norm='ortho') # FFT on L dimension
|
||||
y = self.FreMLP(B, N, L, x, self.r2, self.i2, self.rb2, self.ib2)
|
||||
x = torch.fft.irfft(y, n=self.seq_len, dim=2, norm="ortho")
|
||||
return x
|
||||
|
||||
# frequency channel learner
|
||||
def MLP_channel(self, x, B, N, L):
|
||||
# [B, N, T, D]
|
||||
x = x.permute(0, 2, 1, 3)
|
||||
# [B, T, N, D]
|
||||
x = torch.fft.rfft(x, dim=2, norm='ortho') # FFT on N dimension
|
||||
y = self.FreMLP(B, L, N, x, self.r1, self.i1, self.rb1, self.ib1)
|
||||
x = torch.fft.irfft(y, n=self.feature_size, dim=2, norm="ortho")
|
||||
x = x.permute(0, 2, 1, 3)
|
||||
# [B, N, T, D]
|
||||
return x
|
||||
|
||||
# frequency-domain MLPs
|
||||
# dimension: FFT along the dimension, r: the real part of weights, i: the imaginary part of weights
|
||||
# rb: the real part of bias, ib: the imaginary part of bias
|
||||
def FreMLP(self, B, nd, dimension, x, r, i, rb, ib):
|
||||
o1_real = torch.zeros([B, nd, dimension // 2 + 1, self.embed_size],
|
||||
device=x.device)
|
||||
o1_imag = torch.zeros([B, nd, dimension // 2 + 1, self.embed_size],
|
||||
device=x.device)
|
||||
|
||||
o1_real = F.relu(
|
||||
torch.einsum('bijd,dd->bijd', x.real, r) - \
|
||||
torch.einsum('bijd,dd->bijd', x.imag, i) + \
|
||||
rb
|
||||
)
|
||||
|
||||
o1_imag = F.relu(
|
||||
torch.einsum('bijd,dd->bijd', x.imag, r) + \
|
||||
torch.einsum('bijd,dd->bijd', x.real, i) + \
|
||||
ib
|
||||
)
|
||||
|
||||
y = torch.stack([o1_real, o1_imag], dim=-1)
|
||||
y = F.softshrink(y, lambd=self.sparsity_threshold)
|
||||
y = torch.view_as_complex(y)
|
||||
return y
|
||||
|
||||
def forecast(self, x_enc):
|
||||
# x: [Batch, Input length, Channel]
|
||||
B, T, N = x_enc.shape
|
||||
# embedding x: [B, N, T, D]
|
||||
x = self.tokenEmb(x_enc)
|
||||
bias = x
|
||||
# [B, N, T, D]
|
||||
if self.channel_independence == '0':
|
||||
x = self.MLP_channel(x, B, N, T)
|
||||
# [B, N, T, D]
|
||||
x = self.MLP_temporal(x, B, N, T)
|
||||
x = x + bias
|
||||
x = self.fc(x.reshape(B, N, -1)).permute(0, 2, 1)
|
||||
return x
|
||||
|
||||
def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
|
||||
if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
|
||||
dec_out = self.forecast(x_enc)
|
||||
return dec_out[:, -self.pred_len:, :] # [B, L, D]
|
||||
else:
|
||||
raise ValueError('Only forecast tasks implemented yet')
|
147
models/Informer.py
Normal file
147
models/Informer.py
Normal file
@ -0,0 +1,147 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from layers.Transformer_EncDec import Decoder, DecoderLayer, Encoder, EncoderLayer, ConvLayer
|
||||
from layers.SelfAttention_Family import ProbAttention, AttentionLayer
|
||||
from layers.Embed import DataEmbedding
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
"""
|
||||
Informer with Propspare attention in O(LlogL) complexity
|
||||
Paper link: https://ojs.aaai.org/index.php/AAAI/article/view/17325/17132
|
||||
"""
|
||||
|
||||
def __init__(self, configs):
|
||||
super(Model, self).__init__()
|
||||
self.task_name = configs.task_name
|
||||
self.pred_len = configs.pred_len
|
||||
self.label_len = configs.label_len
|
||||
|
||||
# Embedding
|
||||
self.enc_embedding = DataEmbedding(configs.enc_in, configs.d_model, configs.embed, configs.freq,
|
||||
configs.dropout)
|
||||
self.dec_embedding = DataEmbedding(configs.dec_in, configs.d_model, configs.embed, configs.freq,
|
||||
configs.dropout)
|
||||
|
||||
# Encoder
|
||||
self.encoder = Encoder(
|
||||
[
|
||||
EncoderLayer(
|
||||
AttentionLayer(
|
||||
ProbAttention(False, configs.factor, attention_dropout=configs.dropout,
|
||||
output_attention=False),
|
||||
configs.d_model, configs.n_heads),
|
||||
configs.d_model,
|
||||
configs.d_ff,
|
||||
dropout=configs.dropout,
|
||||
activation=configs.activation
|
||||
) for l in range(configs.e_layers)
|
||||
],
|
||||
[
|
||||
ConvLayer(
|
||||
configs.d_model
|
||||
) for l in range(configs.e_layers - 1)
|
||||
] if configs.distil and ('forecast' in configs.task_name) else None,
|
||||
norm_layer=torch.nn.LayerNorm(configs.d_model)
|
||||
)
|
||||
# Decoder
|
||||
self.decoder = Decoder(
|
||||
[
|
||||
DecoderLayer(
|
||||
AttentionLayer(
|
||||
ProbAttention(True, configs.factor, attention_dropout=configs.dropout, output_attention=False),
|
||||
configs.d_model, configs.n_heads),
|
||||
AttentionLayer(
|
||||
ProbAttention(False, configs.factor, attention_dropout=configs.dropout, output_attention=False),
|
||||
configs.d_model, configs.n_heads),
|
||||
configs.d_model,
|
||||
configs.d_ff,
|
||||
dropout=configs.dropout,
|
||||
activation=configs.activation,
|
||||
)
|
||||
for l in range(configs.d_layers)
|
||||
],
|
||||
norm_layer=torch.nn.LayerNorm(configs.d_model),
|
||||
projection=nn.Linear(configs.d_model, configs.c_out, bias=True)
|
||||
)
|
||||
if self.task_name == 'imputation':
|
||||
self.projection = nn.Linear(configs.d_model, configs.c_out, bias=True)
|
||||
if self.task_name == 'anomaly_detection':
|
||||
self.projection = nn.Linear(configs.d_model, configs.c_out, bias=True)
|
||||
if self.task_name == 'classification':
|
||||
self.act = F.gelu
|
||||
self.dropout = nn.Dropout(configs.dropout)
|
||||
self.projection = nn.Linear(configs.d_model * configs.seq_len, configs.num_class)
|
||||
|
||||
def long_forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
|
||||
enc_out = self.enc_embedding(x_enc, x_mark_enc)
|
||||
dec_out = self.dec_embedding(x_dec, x_mark_dec)
|
||||
enc_out, attns = self.encoder(enc_out, attn_mask=None)
|
||||
|
||||
dec_out = self.decoder(dec_out, enc_out, x_mask=None, cross_mask=None)
|
||||
|
||||
return dec_out # [B, L, D]
|
||||
|
||||
def short_forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
|
||||
# Normalization
|
||||
mean_enc = x_enc.mean(1, keepdim=True).detach() # B x 1 x E
|
||||
x_enc = x_enc - mean_enc
|
||||
std_enc = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5).detach() # B x 1 x E
|
||||
x_enc = x_enc / std_enc
|
||||
|
||||
enc_out = self.enc_embedding(x_enc, x_mark_enc)
|
||||
dec_out = self.dec_embedding(x_dec, x_mark_dec)
|
||||
enc_out, attns = self.encoder(enc_out, attn_mask=None)
|
||||
|
||||
dec_out = self.decoder(dec_out, enc_out, x_mask=None, cross_mask=None)
|
||||
|
||||
dec_out = dec_out * std_enc + mean_enc
|
||||
return dec_out # [B, L, D]
|
||||
|
||||
def imputation(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask):
|
||||
# enc
|
||||
enc_out = self.enc_embedding(x_enc, x_mark_enc)
|
||||
enc_out, attns = self.encoder(enc_out, attn_mask=None)
|
||||
# final
|
||||
dec_out = self.projection(enc_out)
|
||||
return dec_out
|
||||
|
||||
def anomaly_detection(self, x_enc):
|
||||
# enc
|
||||
enc_out = self.enc_embedding(x_enc, None)
|
||||
enc_out, attns = self.encoder(enc_out, attn_mask=None)
|
||||
# final
|
||||
dec_out = self.projection(enc_out)
|
||||
return dec_out
|
||||
|
||||
def classification(self, x_enc, x_mark_enc):
|
||||
# enc
|
||||
enc_out = self.enc_embedding(x_enc, None)
|
||||
enc_out, attns = self.encoder(enc_out, attn_mask=None)
|
||||
|
||||
# Output
|
||||
output = self.act(enc_out) # the output transformer encoder/decoder embeddings don't include non-linearity
|
||||
output = self.dropout(output)
|
||||
output = output * x_mark_enc.unsqueeze(-1) # zero-out padding embeddings
|
||||
output = output.reshape(output.shape[0], -1) # (batch_size, seq_length * d_model)
|
||||
output = self.projection(output) # (batch_size, num_classes)
|
||||
return output
|
||||
|
||||
def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
|
||||
if self.task_name == 'long_term_forecast':
|
||||
dec_out = self.long_forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
|
||||
return dec_out[:, -self.pred_len:, :] # [B, L, D]
|
||||
if self.task_name == 'short_term_forecast':
|
||||
dec_out = self.short_forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
|
||||
return dec_out[:, -self.pred_len:, :] # [B, L, D]
|
||||
if self.task_name == 'imputation':
|
||||
dec_out = self.imputation(x_enc, x_mark_enc, x_dec, x_mark_dec, mask)
|
||||
return dec_out # [B, L, D]
|
||||
if self.task_name == 'anomaly_detection':
|
||||
dec_out = self.anomaly_detection(x_enc)
|
||||
return dec_out # [B, L, D]
|
||||
if self.task_name == 'classification':
|
||||
dec_out = self.classification(x_enc, x_mark_enc)
|
||||
return dec_out # [B, N]
|
||||
return None
|
337
models/Koopa.py
Normal file
337
models/Koopa.py
Normal file
@ -0,0 +1,337 @@
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from data_provider.data_factory import data_provider
|
||||
|
||||
|
||||
|
||||
class FourierFilter(nn.Module):
|
||||
"""
|
||||
Fourier Filter: to time-variant and time-invariant term
|
||||
"""
|
||||
def __init__(self, mask_spectrum):
|
||||
super(FourierFilter, self).__init__()
|
||||
self.mask_spectrum = mask_spectrum
|
||||
|
||||
def forward(self, x):
|
||||
xf = torch.fft.rfft(x, dim=1)
|
||||
mask = torch.ones_like(xf)
|
||||
mask[:, self.mask_spectrum, :] = 0
|
||||
x_var = torch.fft.irfft(xf*mask, dim=1)
|
||||
x_inv = x - x_var
|
||||
|
||||
return x_var, x_inv
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
'''
|
||||
Multilayer perceptron to encode/decode high dimension representation of sequential data
|
||||
'''
|
||||
def __init__(self,
|
||||
f_in,
|
||||
f_out,
|
||||
hidden_dim=128,
|
||||
hidden_layers=2,
|
||||
dropout=0.05,
|
||||
activation='tanh'):
|
||||
super(MLP, self).__init__()
|
||||
self.f_in = f_in
|
||||
self.f_out = f_out
|
||||
self.hidden_dim = hidden_dim
|
||||
self.hidden_layers = hidden_layers
|
||||
self.dropout = dropout
|
||||
if activation == 'relu':
|
||||
self.activation = nn.ReLU()
|
||||
elif activation == 'tanh':
|
||||
self.activation = nn.Tanh()
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
layers = [nn.Linear(self.f_in, self.hidden_dim),
|
||||
self.activation, nn.Dropout(self.dropout)]
|
||||
for i in range(self.hidden_layers-2):
|
||||
layers += [nn.Linear(self.hidden_dim, self.hidden_dim),
|
||||
self.activation, nn.Dropout(dropout)]
|
||||
|
||||
layers += [nn.Linear(hidden_dim, f_out)]
|
||||
self.layers = nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
# x: B x S x f_in
|
||||
# y: B x S x f_out
|
||||
y = self.layers(x)
|
||||
return y
|
||||
|
||||
|
||||
class KPLayer(nn.Module):
|
||||
"""
|
||||
A demonstration of finding one step transition of linear system by DMD iteratively
|
||||
"""
|
||||
def __init__(self):
|
||||
super(KPLayer, self).__init__()
|
||||
|
||||
self.K = None # B E E
|
||||
|
||||
def one_step_forward(self, z, return_rec=False, return_K=False):
|
||||
B, input_len, E = z.shape
|
||||
assert input_len > 1, 'snapshots number should be larger than 1'
|
||||
x, y = z[:, :-1], z[:, 1:]
|
||||
|
||||
# solve linear system
|
||||
self.K = torch.linalg.lstsq(x, y).solution # B E E
|
||||
if torch.isnan(self.K).any():
|
||||
print('Encounter K with nan, replace K by identity matrix')
|
||||
self.K = torch.eye(self.K.shape[1]).to(self.K.device).unsqueeze(0).repeat(B, 1, 1)
|
||||
|
||||
z_pred = torch.bmm(z[:, -1:], self.K)
|
||||
if return_rec:
|
||||
z_rec = torch.cat((z[:, :1], torch.bmm(x, self.K)), dim=1)
|
||||
return z_rec, z_pred
|
||||
|
||||
return z_pred
|
||||
|
||||
def forward(self, z, pred_len=1):
|
||||
assert pred_len >= 1, 'prediction length should not be less than 1'
|
||||
z_rec, z_pred= self.one_step_forward(z, return_rec=True)
|
||||
z_preds = [z_pred]
|
||||
for i in range(1, pred_len):
|
||||
z_pred = torch.bmm(z_pred, self.K)
|
||||
z_preds.append(z_pred)
|
||||
z_preds = torch.cat(z_preds, dim=1)
|
||||
return z_rec, z_preds
|
||||
|
||||
|
||||
class KPLayerApprox(nn.Module):
|
||||
"""
|
||||
Find koopman transition of linear system by DMD with multistep K approximation
|
||||
"""
|
||||
def __init__(self):
|
||||
super(KPLayerApprox, self).__init__()
|
||||
|
||||
self.K = None # B E E
|
||||
self.K_step = None # B E E
|
||||
|
||||
def forward(self, z, pred_len=1):
|
||||
# z: B L E, koopman invariance space representation
|
||||
# z_rec: B L E, reconstructed representation
|
||||
# z_pred: B S E, forecasting representation
|
||||
B, input_len, E = z.shape
|
||||
assert input_len > 1, 'snapshots number should be larger than 1'
|
||||
x, y = z[:, :-1], z[:, 1:]
|
||||
|
||||
# solve linear system
|
||||
self.K = torch.linalg.lstsq(x, y).solution # B E E
|
||||
|
||||
if torch.isnan(self.K).any():
|
||||
print('Encounter K with nan, replace K by identity matrix')
|
||||
self.K = torch.eye(self.K.shape[1]).to(self.K.device).unsqueeze(0).repeat(B, 1, 1)
|
||||
|
||||
z_rec = torch.cat((z[:, :1], torch.bmm(x, self.K)), dim=1) # B L E
|
||||
|
||||
if pred_len <= input_len:
|
||||
self.K_step = torch.linalg.matrix_power(self.K, pred_len)
|
||||
if torch.isnan(self.K_step).any():
|
||||
print('Encounter multistep K with nan, replace it by identity matrix')
|
||||
self.K_step = torch.eye(self.K_step.shape[1]).to(self.K_step.device).unsqueeze(0).repeat(B, 1, 1)
|
||||
z_pred = torch.bmm(z[:, -pred_len:, :], self.K_step)
|
||||
else:
|
||||
self.K_step = torch.linalg.matrix_power(self.K, input_len)
|
||||
if torch.isnan(self.K_step).any():
|
||||
print('Encounter multistep K with nan, replace it by identity matrix')
|
||||
self.K_step = torch.eye(self.K_step.shape[1]).to(self.K_step.device).unsqueeze(0).repeat(B, 1, 1)
|
||||
temp_z_pred, all_pred = z, []
|
||||
for _ in range(math.ceil(pred_len / input_len)):
|
||||
temp_z_pred = torch.bmm(temp_z_pred, self.K_step)
|
||||
all_pred.append(temp_z_pred)
|
||||
z_pred = torch.cat(all_pred, dim=1)[:, :pred_len, :]
|
||||
|
||||
return z_rec, z_pred
|
||||
|
||||
|
||||
class TimeVarKP(nn.Module):
|
||||
"""
|
||||
Koopman Predictor with DMD (analysitical solution of Koopman operator)
|
||||
Utilize local variations within individual sliding window to predict the future of time-variant term
|
||||
"""
|
||||
def __init__(self,
|
||||
enc_in=8,
|
||||
input_len=96,
|
||||
pred_len=96,
|
||||
seg_len=24,
|
||||
dynamic_dim=128,
|
||||
encoder=None,
|
||||
decoder=None,
|
||||
multistep=False,
|
||||
):
|
||||
super(TimeVarKP, self).__init__()
|
||||
self.input_len = input_len
|
||||
self.pred_len = pred_len
|
||||
self.enc_in = enc_in
|
||||
self.seg_len = seg_len
|
||||
self.dynamic_dim = dynamic_dim
|
||||
self.multistep = multistep
|
||||
self.encoder, self.decoder = encoder, decoder
|
||||
self.freq = math.ceil(self.input_len / self.seg_len) # segment number of input
|
||||
self.step = math.ceil(self.pred_len / self.seg_len) # segment number of output
|
||||
self.padding_len = self.seg_len * self.freq - self.input_len
|
||||
# Approximate mulitstep K by KPLayerApprox when pred_len is large
|
||||
self.dynamics = KPLayerApprox() if self.multistep else KPLayer()
|
||||
|
||||
def forward(self, x):
|
||||
# x: B L C
|
||||
B, L, C = x.shape
|
||||
|
||||
res = torch.cat((x[:, L-self.padding_len:, :], x) ,dim=1)
|
||||
|
||||
res = res.chunk(self.freq, dim=1) # F x B P C, P means seg_len
|
||||
res = torch.stack(res, dim=1).reshape(B, self.freq, -1) # B F PC
|
||||
|
||||
res = self.encoder(res) # B F H
|
||||
x_rec, x_pred = self.dynamics(res, self.step) # B F H, B S H
|
||||
|
||||
x_rec = self.decoder(x_rec) # B F PC
|
||||
x_rec = x_rec.reshape(B, self.freq, self.seg_len, self.enc_in)
|
||||
x_rec = x_rec.reshape(B, -1, self.enc_in)[:, :self.input_len, :] # B L C
|
||||
|
||||
x_pred = self.decoder(x_pred) # B S PC
|
||||
x_pred = x_pred.reshape(B, self.step, self.seg_len, self.enc_in)
|
||||
x_pred = x_pred.reshape(B, -1, self.enc_in)[:, :self.pred_len, :] # B S C
|
||||
|
||||
return x_rec, x_pred
|
||||
|
||||
|
||||
class TimeInvKP(nn.Module):
|
||||
"""
|
||||
Koopman Predictor with learnable Koopman operator
|
||||
Utilize lookback and forecast window snapshots to predict the future of time-invariant term
|
||||
"""
|
||||
def __init__(self,
|
||||
input_len=96,
|
||||
pred_len=96,
|
||||
dynamic_dim=128,
|
||||
encoder=None,
|
||||
decoder=None):
|
||||
super(TimeInvKP, self).__init__()
|
||||
self.dynamic_dim = dynamic_dim
|
||||
self.input_len = input_len
|
||||
self.pred_len = pred_len
|
||||
self.encoder = encoder
|
||||
self.decoder = decoder
|
||||
|
||||
K_init = torch.randn(self.dynamic_dim, self.dynamic_dim)
|
||||
U, _, V = torch.svd(K_init) # stable initialization
|
||||
self.K = nn.Linear(self.dynamic_dim, self.dynamic_dim, bias=False)
|
||||
self.K.weight.data = torch.mm(U, V.t())
|
||||
|
||||
def forward(self, x):
|
||||
# x: B L C
|
||||
res = x.transpose(1, 2) # B C L
|
||||
res = self.encoder(res) # B C H
|
||||
res = self.K(res) # B C H
|
||||
res = self.decoder(res) # B C S
|
||||
res = res.transpose(1, 2) # B S C
|
||||
|
||||
return res
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
'''
|
||||
Paper link: https://arxiv.org/pdf/2305.18803.pdf
|
||||
'''
|
||||
def __init__(self, configs, dynamic_dim=128, hidden_dim=64, hidden_layers=2, num_blocks=3, multistep=False):
|
||||
"""
|
||||
mask_spectrum: list, shared frequency spectrums
|
||||
seg_len: int, segment length of time series
|
||||
dynamic_dim: int, latent dimension of koopman embedding
|
||||
hidden_dim: int, hidden dimension of en/decoder
|
||||
hidden_layers: int, number of hidden layers of en/decoder
|
||||
num_blocks: int, number of Koopa blocks
|
||||
multistep: bool, whether to use approximation for multistep K
|
||||
alpha: float, spectrum filter ratio
|
||||
"""
|
||||
super(Model, self).__init__()
|
||||
self.task_name = configs.task_name
|
||||
self.enc_in = configs.enc_in
|
||||
self.input_len = configs.seq_len
|
||||
self.pred_len = configs.pred_len
|
||||
|
||||
self.seg_len = self.pred_len
|
||||
self.num_blocks = num_blocks
|
||||
self.dynamic_dim = dynamic_dim
|
||||
self.hidden_dim = hidden_dim
|
||||
self.hidden_layers = hidden_layers
|
||||
self.multistep = multistep
|
||||
self.alpha = 0.2
|
||||
self.mask_spectrum = self._get_mask_spectrum(configs)
|
||||
|
||||
self.disentanglement = FourierFilter(self.mask_spectrum)
|
||||
|
||||
# shared encoder/decoder to make koopman embedding consistent
|
||||
self.time_inv_encoder = MLP(f_in=self.input_len, f_out=self.dynamic_dim, activation='relu',
|
||||
hidden_dim=self.hidden_dim, hidden_layers=self.hidden_layers)
|
||||
self.time_inv_decoder = MLP(f_in=self.dynamic_dim, f_out=self.pred_len, activation='relu',
|
||||
hidden_dim=self.hidden_dim, hidden_layers=self.hidden_layers)
|
||||
self.time_inv_kps = self.time_var_kps = nn.ModuleList([
|
||||
TimeInvKP(input_len=self.input_len,
|
||||
pred_len=self.pred_len,
|
||||
dynamic_dim=self.dynamic_dim,
|
||||
encoder=self.time_inv_encoder,
|
||||
decoder=self.time_inv_decoder)
|
||||
for _ in range(self.num_blocks)])
|
||||
|
||||
# shared encoder/decoder to make koopman embedding consistent
|
||||
self.time_var_encoder = MLP(f_in=self.seg_len*self.enc_in, f_out=self.dynamic_dim, activation='tanh',
|
||||
hidden_dim=self.hidden_dim, hidden_layers=self.hidden_layers)
|
||||
self.time_var_decoder = MLP(f_in=self.dynamic_dim, f_out=self.seg_len*self.enc_in, activation='tanh',
|
||||
hidden_dim=self.hidden_dim, hidden_layers=self.hidden_layers)
|
||||
self.time_var_kps = nn.ModuleList([
|
||||
TimeVarKP(enc_in=configs.enc_in,
|
||||
input_len=self.input_len,
|
||||
pred_len=self.pred_len,
|
||||
seg_len=self.seg_len,
|
||||
dynamic_dim=self.dynamic_dim,
|
||||
encoder=self.time_var_encoder,
|
||||
decoder=self.time_var_decoder,
|
||||
multistep=self.multistep)
|
||||
for _ in range(self.num_blocks)])
|
||||
|
||||
def _get_mask_spectrum(self, configs):
|
||||
"""
|
||||
get shared frequency spectrums
|
||||
"""
|
||||
train_data, train_loader = data_provider(configs, 'train')
|
||||
amps = 0.0
|
||||
for data in train_loader:
|
||||
lookback_window = data[0]
|
||||
amps += abs(torch.fft.rfft(lookback_window, dim=1)).mean(dim=0).mean(dim=1)
|
||||
mask_spectrum = amps.topk(int(amps.shape[0]*self.alpha)).indices
|
||||
return mask_spectrum # as the spectrums of time-invariant component
|
||||
|
||||
def forecast(self, x_enc):
|
||||
# Series Stationarization adopted from NSformer
|
||||
mean_enc = x_enc.mean(1, keepdim=True).detach() # B x 1 x E
|
||||
x_enc = x_enc - mean_enc
|
||||
std_enc = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5).detach()
|
||||
x_enc = x_enc / std_enc
|
||||
|
||||
# Koopman Forecasting
|
||||
residual, forecast = x_enc, None
|
||||
for i in range(self.num_blocks):
|
||||
time_var_input, time_inv_input = self.disentanglement(residual)
|
||||
time_inv_output = self.time_inv_kps[i](time_inv_input)
|
||||
time_var_backcast, time_var_output = self.time_var_kps[i](time_var_input)
|
||||
residual = residual - time_var_backcast
|
||||
if forecast is None:
|
||||
forecast = (time_inv_output + time_var_output)
|
||||
else:
|
||||
forecast += (time_inv_output + time_var_output)
|
||||
|
||||
# Series Stationarization adopted from NSformer
|
||||
res = forecast * std_enc + mean_enc
|
||||
|
||||
return res
|
||||
|
||||
def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
|
||||
if self.task_name == 'long_term_forecast':
|
||||
dec_out = self.forecast(x_enc)
|
||||
return dec_out[:, -self.pred_len:, :] # [B, L, D]
|
165
models/LightTS.py
Normal file
165
models/LightTS.py
Normal file
@ -0,0 +1,165 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class IEBlock(nn.Module):
|
||||
def __init__(self, input_dim, hid_dim, output_dim, num_node):
|
||||
super(IEBlock, self).__init__()
|
||||
|
||||
self.input_dim = input_dim
|
||||
self.hid_dim = hid_dim
|
||||
self.output_dim = output_dim
|
||||
self.num_node = num_node
|
||||
|
||||
self._build()
|
||||
|
||||
def _build(self):
|
||||
self.spatial_proj = nn.Sequential(
|
||||
nn.Linear(self.input_dim, self.hid_dim),
|
||||
nn.LeakyReLU(),
|
||||
nn.Linear(self.hid_dim, self.hid_dim // 4)
|
||||
)
|
||||
|
||||
self.channel_proj = nn.Linear(self.num_node, self.num_node)
|
||||
torch.nn.init.eye_(self.channel_proj.weight)
|
||||
|
||||
self.output_proj = nn.Linear(self.hid_dim // 4, self.output_dim)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.spatial_proj(x.permute(0, 2, 1))
|
||||
x = x.permute(0, 2, 1) + self.channel_proj(x.permute(0, 2, 1))
|
||||
x = self.output_proj(x.permute(0, 2, 1))
|
||||
|
||||
x = x.permute(0, 2, 1)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
"""
|
||||
Paper link: https://arxiv.org/abs/2207.01186
|
||||
"""
|
||||
|
||||
def __init__(self, configs, chunk_size=24):
|
||||
"""
|
||||
chunk_size: int, reshape T into [num_chunks, chunk_size]
|
||||
"""
|
||||
super(Model, self).__init__()
|
||||
self.task_name = configs.task_name
|
||||
self.seq_len = configs.seq_len
|
||||
if self.task_name == 'classification' or self.task_name == 'anomaly_detection' or self.task_name == 'imputation':
|
||||
self.pred_len = configs.seq_len
|
||||
else:
|
||||
self.pred_len = configs.pred_len
|
||||
|
||||
if configs.task_name == 'long_term_forecast' or configs.task_name == 'short_term_forecast':
|
||||
self.chunk_size = min(configs.pred_len, configs.seq_len, chunk_size)
|
||||
else:
|
||||
self.chunk_size = min(configs.seq_len, chunk_size)
|
||||
# assert (self.seq_len % self.chunk_size == 0)
|
||||
if self.seq_len % self.chunk_size != 0:
|
||||
self.seq_len += (self.chunk_size - self.seq_len % self.chunk_size) # padding in order to ensure complete division
|
||||
self.num_chunks = self.seq_len // self.chunk_size
|
||||
|
||||
self.d_model = configs.d_model
|
||||
self.enc_in = configs.enc_in
|
||||
self.dropout = configs.dropout
|
||||
if self.task_name == 'classification':
|
||||
self.act = F.gelu
|
||||
self.dropout = nn.Dropout(configs.dropout)
|
||||
self.projection = nn.Linear(configs.enc_in * configs.seq_len, configs.num_class)
|
||||
self._build()
|
||||
|
||||
def _build(self):
|
||||
self.layer_1 = IEBlock(
|
||||
input_dim=self.chunk_size,
|
||||
hid_dim=self.d_model // 4,
|
||||
output_dim=self.d_model // 4,
|
||||
num_node=self.num_chunks
|
||||
)
|
||||
|
||||
self.chunk_proj_1 = nn.Linear(self.num_chunks, 1)
|
||||
|
||||
self.layer_2 = IEBlock(
|
||||
input_dim=self.chunk_size,
|
||||
hid_dim=self.d_model // 4,
|
||||
output_dim=self.d_model // 4,
|
||||
num_node=self.num_chunks
|
||||
)
|
||||
|
||||
self.chunk_proj_2 = nn.Linear(self.num_chunks, 1)
|
||||
|
||||
self.layer_3 = IEBlock(
|
||||
input_dim=self.d_model // 2,
|
||||
hid_dim=self.d_model // 2,
|
||||
output_dim=self.pred_len,
|
||||
num_node=self.enc_in
|
||||
)
|
||||
|
||||
self.ar = nn.Linear(self.seq_len, self.pred_len)
|
||||
|
||||
def encoder(self, x):
|
||||
B, T, N = x.size()
|
||||
|
||||
# padding
|
||||
x = torch.cat([x, torch.zeros((B, self.seq_len - T, N)).to(x.device)], dim=1)
|
||||
|
||||
highway = self.ar(x.permute(0, 2, 1))
|
||||
highway = highway.permute(0, 2, 1)
|
||||
|
||||
# continuous sampling
|
||||
x1 = x.reshape(B, self.num_chunks, self.chunk_size, N)
|
||||
x1 = x1.permute(0, 3, 2, 1)
|
||||
x1 = x1.reshape(-1, self.chunk_size, self.num_chunks)
|
||||
x1 = self.layer_1(x1)
|
||||
x1 = self.chunk_proj_1(x1).squeeze(dim=-1)
|
||||
|
||||
# interval sampling
|
||||
x2 = x.reshape(B, self.chunk_size, self.num_chunks, N)
|
||||
x2 = x2.permute(0, 3, 1, 2)
|
||||
x2 = x2.reshape(-1, self.chunk_size, self.num_chunks)
|
||||
x2 = self.layer_2(x2)
|
||||
x2 = self.chunk_proj_2(x2).squeeze(dim=-1)
|
||||
|
||||
x3 = torch.cat([x1, x2], dim=-1)
|
||||
|
||||
x3 = x3.reshape(B, N, -1)
|
||||
x3 = x3.permute(0, 2, 1)
|
||||
|
||||
out = self.layer_3(x3)
|
||||
|
||||
out = out + highway
|
||||
return out
|
||||
|
||||
def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
|
||||
return self.encoder(x_enc)
|
||||
|
||||
def imputation(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask):
|
||||
return self.encoder(x_enc)
|
||||
|
||||
def anomaly_detection(self, x_enc):
|
||||
return self.encoder(x_enc)
|
||||
|
||||
def classification(self, x_enc, x_mark_enc):
|
||||
enc_out = self.encoder(x_enc)
|
||||
|
||||
# Output
|
||||
output = enc_out.reshape(enc_out.shape[0], -1) # (batch_size, seq_length * d_model)
|
||||
output = self.projection(output) # (batch_size, num_classes)
|
||||
return output
|
||||
|
||||
def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
|
||||
if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
|
||||
dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
|
||||
return dec_out[:, -self.pred_len:, :] # [B, L, D]
|
||||
if self.task_name == 'imputation':
|
||||
dec_out = self.imputation(x_enc, x_mark_enc, x_dec, x_mark_dec, mask)
|
||||
return dec_out # [B, L, D]
|
||||
if self.task_name == 'anomaly_detection':
|
||||
dec_out = self.anomaly_detection(x_enc)
|
||||
return dec_out # [B, L, D]
|
||||
if self.task_name == 'classification':
|
||||
dec_out = self.classification(x_enc, x_mark_enc)
|
||||
return dec_out # [B, N]
|
||||
return None
|
222
models/MICN.py
Normal file
222
models/MICN.py
Normal file
@ -0,0 +1,222 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from layers.Embed import DataEmbedding
|
||||
from layers.Autoformer_EncDec import series_decomp, series_decomp_multi
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class MIC(nn.Module):
|
||||
"""
|
||||
MIC layer to extract local and global features
|
||||
"""
|
||||
|
||||
def __init__(self, feature_size=512, n_heads=8, dropout=0.05, decomp_kernel=[32], conv_kernel=[24],
|
||||
isometric_kernel=[18, 6], device='cuda'):
|
||||
super(MIC, self).__init__()
|
||||
self.conv_kernel = conv_kernel
|
||||
self.device = device
|
||||
|
||||
# isometric convolution
|
||||
self.isometric_conv = nn.ModuleList([nn.Conv1d(in_channels=feature_size, out_channels=feature_size,
|
||||
kernel_size=i, padding=0, stride=1)
|
||||
for i in isometric_kernel])
|
||||
|
||||
# downsampling convolution: padding=i//2, stride=i
|
||||
self.conv = nn.ModuleList([nn.Conv1d(in_channels=feature_size, out_channels=feature_size,
|
||||
kernel_size=i, padding=i // 2, stride=i)
|
||||
for i in conv_kernel])
|
||||
|
||||
# upsampling convolution
|
||||
self.conv_trans = nn.ModuleList([nn.ConvTranspose1d(in_channels=feature_size, out_channels=feature_size,
|
||||
kernel_size=i, padding=0, stride=i)
|
||||
for i in conv_kernel])
|
||||
|
||||
self.decomp = nn.ModuleList([series_decomp(k) for k in decomp_kernel])
|
||||
self.merge = torch.nn.Conv2d(in_channels=feature_size, out_channels=feature_size,
|
||||
kernel_size=(len(self.conv_kernel), 1))
|
||||
|
||||
# feedforward network
|
||||
self.conv1 = nn.Conv1d(in_channels=feature_size, out_channels=feature_size * 4, kernel_size=1)
|
||||
self.conv2 = nn.Conv1d(in_channels=feature_size * 4, out_channels=feature_size, kernel_size=1)
|
||||
self.norm1 = nn.LayerNorm(feature_size)
|
||||
self.norm2 = nn.LayerNorm(feature_size)
|
||||
|
||||
self.norm = torch.nn.LayerNorm(feature_size)
|
||||
self.act = torch.nn.Tanh()
|
||||
self.drop = torch.nn.Dropout(0.05)
|
||||
|
||||
def conv_trans_conv(self, input, conv1d, conv1d_trans, isometric):
|
||||
batch, seq_len, channel = input.shape
|
||||
x = input.permute(0, 2, 1)
|
||||
|
||||
# downsampling convolution
|
||||
x1 = self.drop(self.act(conv1d(x)))
|
||||
x = x1
|
||||
|
||||
# isometric convolution
|
||||
zeros = torch.zeros((x.shape[0], x.shape[1], x.shape[2] - 1), device=self.device)
|
||||
x = torch.cat((zeros, x), dim=-1)
|
||||
x = self.drop(self.act(isometric(x)))
|
||||
x = self.norm((x + x1).permute(0, 2, 1)).permute(0, 2, 1)
|
||||
|
||||
# upsampling convolution
|
||||
x = self.drop(self.act(conv1d_trans(x)))
|
||||
x = x[:, :, :seq_len] # truncate
|
||||
|
||||
x = self.norm(x.permute(0, 2, 1) + input)
|
||||
return x
|
||||
|
||||
def forward(self, src):
|
||||
self.device = src.device
|
||||
# multi-scale
|
||||
multi = []
|
||||
for i in range(len(self.conv_kernel)):
|
||||
src_out, trend1 = self.decomp[i](src)
|
||||
src_out = self.conv_trans_conv(src_out, self.conv[i], self.conv_trans[i], self.isometric_conv[i])
|
||||
multi.append(src_out)
|
||||
|
||||
# merge
|
||||
mg = torch.tensor([], device=self.device)
|
||||
for i in range(len(self.conv_kernel)):
|
||||
mg = torch.cat((mg, multi[i].unsqueeze(1).to(self.device)), dim=1)
|
||||
mg = self.merge(mg.permute(0, 3, 1, 2)).squeeze(-2).permute(0, 2, 1)
|
||||
|
||||
y = self.norm1(mg)
|
||||
y = self.conv2(self.conv1(y.transpose(-1, 1))).transpose(-1, 1)
|
||||
|
||||
return self.norm2(mg + y)
|
||||
|
||||
|
||||
class SeasonalPrediction(nn.Module):
|
||||
def __init__(self, embedding_size=512, n_heads=8, dropout=0.05, d_layers=1, decomp_kernel=[32], c_out=1,
|
||||
conv_kernel=[2, 4], isometric_kernel=[18, 6], device='cuda'):
|
||||
super(SeasonalPrediction, self).__init__()
|
||||
|
||||
self.mic = nn.ModuleList([MIC(feature_size=embedding_size, n_heads=n_heads,
|
||||
decomp_kernel=decomp_kernel, conv_kernel=conv_kernel,
|
||||
isometric_kernel=isometric_kernel, device=device)
|
||||
for i in range(d_layers)])
|
||||
|
||||
self.projection = nn.Linear(embedding_size, c_out)
|
||||
|
||||
def forward(self, dec):
|
||||
for mic_layer in self.mic:
|
||||
dec = mic_layer(dec)
|
||||
return self.projection(dec)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
"""
|
||||
Paper link: https://openreview.net/pdf?id=zt53IDUR1U
|
||||
"""
|
||||
def __init__(self, configs, conv_kernel=[12, 16]):
|
||||
"""
|
||||
conv_kernel: downsampling and upsampling convolution kernel_size
|
||||
"""
|
||||
super(Model, self).__init__()
|
||||
|
||||
decomp_kernel = [] # kernel of decomposition operation
|
||||
isometric_kernel = [] # kernel of isometric convolution
|
||||
for ii in conv_kernel:
|
||||
if ii % 2 == 0: # the kernel of decomposition operation must be odd
|
||||
decomp_kernel.append(ii + 1)
|
||||
isometric_kernel.append((configs.seq_len + configs.pred_len + ii) // ii)
|
||||
else:
|
||||
decomp_kernel.append(ii)
|
||||
isometric_kernel.append((configs.seq_len + configs.pred_len + ii - 1) // ii)
|
||||
|
||||
self.task_name = configs.task_name
|
||||
self.pred_len = configs.pred_len
|
||||
self.seq_len = configs.seq_len
|
||||
|
||||
# Multiple Series decomposition block from FEDformer
|
||||
self.decomp_multi = series_decomp_multi(decomp_kernel)
|
||||
|
||||
# embedding
|
||||
self.dec_embedding = DataEmbedding(configs.enc_in, configs.d_model, configs.embed, configs.freq,
|
||||
configs.dropout)
|
||||
|
||||
self.conv_trans = SeasonalPrediction(embedding_size=configs.d_model, n_heads=configs.n_heads,
|
||||
dropout=configs.dropout,
|
||||
d_layers=configs.d_layers, decomp_kernel=decomp_kernel,
|
||||
c_out=configs.c_out, conv_kernel=conv_kernel,
|
||||
isometric_kernel=isometric_kernel, device=torch.device('cuda:0'))
|
||||
if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
|
||||
# refer to DLinear
|
||||
self.regression = nn.Linear(configs.seq_len, configs.pred_len)
|
||||
self.regression.weight = nn.Parameter(
|
||||
(1 / configs.pred_len) * torch.ones([configs.pred_len, configs.seq_len]),
|
||||
requires_grad=True)
|
||||
if self.task_name == 'imputation':
|
||||
self.projection = nn.Linear(configs.d_model, configs.c_out, bias=True)
|
||||
if self.task_name == 'anomaly_detection':
|
||||
self.projection = nn.Linear(configs.d_model, configs.c_out, bias=True)
|
||||
if self.task_name == 'classification':
|
||||
self.act = F.gelu
|
||||
self.dropout = nn.Dropout(configs.dropout)
|
||||
self.projection = nn.Linear(configs.c_out * configs.seq_len, configs.num_class)
|
||||
|
||||
def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
|
||||
# Multi-scale Hybrid Decomposition
|
||||
seasonal_init_enc, trend = self.decomp_multi(x_enc)
|
||||
trend = self.regression(trend.permute(0, 2, 1)).permute(0, 2, 1)
|
||||
|
||||
# embedding
|
||||
zeros = torch.zeros([x_dec.shape[0], self.pred_len, x_dec.shape[2]], device=x_enc.device)
|
||||
seasonal_init_dec = torch.cat([seasonal_init_enc[:, -self.seq_len:, :], zeros], dim=1)
|
||||
dec_out = self.dec_embedding(seasonal_init_dec, x_mark_dec)
|
||||
dec_out = self.conv_trans(dec_out)
|
||||
dec_out = dec_out[:, -self.pred_len:, :] + trend[:, -self.pred_len:, :]
|
||||
return dec_out
|
||||
|
||||
def imputation(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask):
|
||||
# Multi-scale Hybrid Decomposition
|
||||
seasonal_init_enc, trend = self.decomp_multi(x_enc)
|
||||
|
||||
# embedding
|
||||
dec_out = self.dec_embedding(seasonal_init_enc, x_mark_dec)
|
||||
dec_out = self.conv_trans(dec_out)
|
||||
dec_out = dec_out + trend
|
||||
return dec_out
|
||||
|
||||
def anomaly_detection(self, x_enc):
|
||||
# Multi-scale Hybrid Decomposition
|
||||
seasonal_init_enc, trend = self.decomp_multi(x_enc)
|
||||
|
||||
# embedding
|
||||
dec_out = self.dec_embedding(seasonal_init_enc, None)
|
||||
dec_out = self.conv_trans(dec_out)
|
||||
dec_out = dec_out + trend
|
||||
return dec_out
|
||||
|
||||
def classification(self, x_enc, x_mark_enc):
|
||||
# Multi-scale Hybrid Decomposition
|
||||
seasonal_init_enc, trend = self.decomp_multi(x_enc)
|
||||
# embedding
|
||||
dec_out = self.dec_embedding(seasonal_init_enc, None)
|
||||
dec_out = self.conv_trans(dec_out)
|
||||
dec_out = dec_out + trend
|
||||
|
||||
# Output from Non-stationary Transformer
|
||||
output = self.act(dec_out) # the output transformer encoder/decoder embeddings don't include non-linearity
|
||||
output = self.dropout(output)
|
||||
output = output * x_mark_enc.unsqueeze(-1) # zero-out padding embeddings
|
||||
output = output.reshape(output.shape[0], -1) # (batch_size, seq_length * d_model)
|
||||
output = self.projection(output) # (batch_size, num_classes)
|
||||
return output
|
||||
|
||||
def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
|
||||
if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
|
||||
dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
|
||||
return dec_out[:, -self.pred_len:, :] # [B, L, D]
|
||||
if self.task_name == 'imputation':
|
||||
dec_out = self.imputation(
|
||||
x_enc, x_mark_enc, x_dec, x_mark_dec, mask)
|
||||
return dec_out # [B, L, D]
|
||||
if self.task_name == 'anomaly_detection':
|
||||
dec_out = self.anomaly_detection(x_enc)
|
||||
return dec_out # [B, L, D]
|
||||
if self.task_name == 'classification':
|
||||
dec_out = self.classification(x_enc, x_mark_enc)
|
||||
return dec_out # [B, N]
|
||||
return None
|
50
models/Mamba.py
Normal file
50
models/Mamba.py
Normal file
@ -0,0 +1,50 @@
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from mamba_ssm import Mamba
|
||||
|
||||
from layers.Embed import DataEmbedding
|
||||
|
||||
class Model(nn.Module):
|
||||
|
||||
def __init__(self, configs):
|
||||
super(Model, self).__init__()
|
||||
self.task_name = configs.task_name
|
||||
self.pred_len = configs.pred_len
|
||||
|
||||
self.d_inner = configs.d_model * configs.expand
|
||||
self.dt_rank = math.ceil(configs.d_model / 16) # TODO implement "auto"
|
||||
|
||||
self.embedding = DataEmbedding(configs.enc_in, configs.d_model, configs.embed, configs.freq, configs.dropout)
|
||||
|
||||
self.mamba = Mamba(
|
||||
d_model = configs.d_model,
|
||||
d_state = configs.d_ff,
|
||||
d_conv = configs.d_conv,
|
||||
expand = configs.expand,
|
||||
)
|
||||
|
||||
self.out_layer = nn.Linear(configs.d_model, configs.c_out, bias=False)
|
||||
|
||||
def forecast(self, x_enc, x_mark_enc):
|
||||
mean_enc = x_enc.mean(1, keepdim=True).detach()
|
||||
x_enc = x_enc - mean_enc
|
||||
std_enc = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5).detach()
|
||||
x_enc = x_enc / std_enc
|
||||
|
||||
x = self.embedding(x_enc, x_mark_enc)
|
||||
x = self.mamba(x)
|
||||
x_out = self.out_layer(x)
|
||||
|
||||
x_out = x_out * std_enc + mean_enc
|
||||
return x_out
|
||||
|
||||
def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
|
||||
if self.task_name in ['short_term_forecast', 'long_term_forecast']:
|
||||
x_out = self.forecast(x_enc, x_mark_enc)
|
||||
return x_out[:, -self.pred_len:, :]
|
||||
|
||||
# other tasks not implemented
|
162
models/MambaSimple.py
Normal file
162
models/MambaSimple.py
Normal file
@ -0,0 +1,162 @@
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange, repeat, einsum
|
||||
|
||||
from layers.Embed import DataEmbedding
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
"""
|
||||
Mamba, linear-time sequence modeling with selective state spaces O(L)
|
||||
Paper link: https://arxiv.org/abs/2312.00752
|
||||
Implementation refernce: https://github.com/johnma2006/mamba-minimal/
|
||||
"""
|
||||
|
||||
def __init__(self, configs):
|
||||
super(Model, self).__init__()
|
||||
self.task_name = configs.task_name
|
||||
self.pred_len = configs.pred_len
|
||||
|
||||
self.d_inner = configs.d_model * configs.expand
|
||||
self.dt_rank = math.ceil(configs.d_model / 16)
|
||||
|
||||
self.embedding = DataEmbedding(configs.enc_in, configs.d_model, configs.embed, configs.freq, configs.dropout)
|
||||
|
||||
self.layers = nn.ModuleList([ResidualBlock(configs, self.d_inner, self.dt_rank) for _ in range(configs.e_layers)])
|
||||
self.norm = RMSNorm(configs.d_model)
|
||||
|
||||
self.out_layer = nn.Linear(configs.d_model, configs.c_out, bias=False)
|
||||
|
||||
def forecast(self, x_enc, x_mark_enc):
|
||||
mean_enc = x_enc.mean(1, keepdim=True).detach()
|
||||
x_enc = x_enc - mean_enc
|
||||
std_enc = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5).detach()
|
||||
x_enc = x_enc / std_enc
|
||||
|
||||
x = self.embedding(x_enc, x_mark_enc)
|
||||
for layer in self.layers:
|
||||
x = layer(x)
|
||||
|
||||
x = self.norm(x)
|
||||
x_out = self.out_layer(x)
|
||||
|
||||
x_out = x_out * std_enc + mean_enc
|
||||
return x_out
|
||||
|
||||
def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
|
||||
if self.task_name in ['short_term_forecast', 'long_term_forecast']:
|
||||
x_out = self.forecast(x_enc, x_mark_enc)
|
||||
return x_out[:, -self.pred_len:, :]
|
||||
|
||||
|
||||
class ResidualBlock(nn.Module):
|
||||
def __init__(self, configs, d_inner, dt_rank):
|
||||
super(ResidualBlock, self).__init__()
|
||||
|
||||
self.mixer = MambaBlock(configs, d_inner, dt_rank)
|
||||
self.norm = RMSNorm(configs.d_model)
|
||||
|
||||
def forward(self, x):
|
||||
output = self.mixer(self.norm(x)) + x
|
||||
return output
|
||||
|
||||
class MambaBlock(nn.Module):
|
||||
def __init__(self, configs, d_inner, dt_rank):
|
||||
super(MambaBlock, self).__init__()
|
||||
self.d_inner = d_inner
|
||||
self.dt_rank = dt_rank
|
||||
|
||||
self.in_proj = nn.Linear(configs.d_model, self.d_inner * 2, bias=False)
|
||||
|
||||
self.conv1d = nn.Conv1d(
|
||||
in_channels = self.d_inner,
|
||||
out_channels = self.d_inner,
|
||||
bias = True,
|
||||
kernel_size = configs.d_conv,
|
||||
padding = configs.d_conv - 1,
|
||||
groups = self.d_inner
|
||||
)
|
||||
|
||||
# takes in x and outputs the input-specific delta, B, C
|
||||
self.x_proj = nn.Linear(self.d_inner, self.dt_rank + configs.d_ff * 2, bias=False)
|
||||
|
||||
# projects delta
|
||||
self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True)
|
||||
|
||||
A = repeat(torch.arange(1, configs.d_ff + 1), "n -> d n", d=self.d_inner).float()
|
||||
self.A_log = nn.Parameter(torch.log(A))
|
||||
self.D = nn.Parameter(torch.ones(self.d_inner))
|
||||
|
||||
self.out_proj = nn.Linear(self.d_inner, configs.d_model, bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Figure 3 in Section 3.4 in the paper
|
||||
"""
|
||||
(b, l, d) = x.shape
|
||||
|
||||
x_and_res = self.in_proj(x) # [B, L, 2 * d_inner]
|
||||
(x, res) = x_and_res.split(split_size=[self.d_inner, self.d_inner], dim=-1)
|
||||
|
||||
x = rearrange(x, "b l d -> b d l")
|
||||
x = self.conv1d(x)[:, :, :l]
|
||||
x = rearrange(x, "b d l -> b l d")
|
||||
|
||||
x = F.silu(x)
|
||||
|
||||
y = self.ssm(x)
|
||||
y = y * F.silu(res)
|
||||
|
||||
output = self.out_proj(y)
|
||||
return output
|
||||
|
||||
|
||||
def ssm(self, x):
|
||||
"""
|
||||
Algorithm 2 in Section 3.2 in the paper
|
||||
"""
|
||||
|
||||
(d_in, n) = self.A_log.shape
|
||||
|
||||
A = -torch.exp(self.A_log.float()) # [d_in, n]
|
||||
D = self.D.float() # [d_in]
|
||||
|
||||
x_dbl = self.x_proj(x) # [B, L, d_rank + 2 * d_ff]
|
||||
(delta, B, C) = x_dbl.split(split_size=[self.dt_rank, n, n], dim=-1) # delta: [B, L, d_rank]; B, C: [B, L, n]
|
||||
delta = F.softplus(self.dt_proj(delta)) # [B, L, d_in]
|
||||
y = self.selective_scan(x, delta, A, B, C, D)
|
||||
|
||||
return y
|
||||
|
||||
def selective_scan(self, u, delta, A, B, C, D):
|
||||
(b, l, d_in) = u.shape
|
||||
n = A.shape[1]
|
||||
|
||||
deltaA = torch.exp(einsum(delta, A, "b l d, d n -> b l d n")) # A is discretized using zero-order hold (ZOH) discretization
|
||||
deltaB_u = einsum(delta, B, u, "b l d, b l n, b l d -> b l d n") # B is discretized using a simplified Euler discretization instead of ZOH. From a discussion with authors: "A is the more important term and the performance doesn't change much with the simplification on B"
|
||||
|
||||
# selective scan, sequential instead of parallel
|
||||
x = torch.zeros((b, d_in, n), device=deltaA.device)
|
||||
ys = []
|
||||
for i in range(l):
|
||||
x = deltaA[:, i] * x + deltaB_u[:, i]
|
||||
y = einsum(x, C[:, i, :], "b d n, b n -> b d")
|
||||
ys.append(y)
|
||||
|
||||
y = torch.stack(ys, dim=1) # [B, L, d_in]
|
||||
y = y + u * D
|
||||
|
||||
return y
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, d_model, eps=1e-5):
|
||||
super(RMSNorm, self).__init__()
|
||||
self.eps = eps
|
||||
self.weight = nn.Parameter(torch.ones(d_model))
|
||||
|
||||
def forward(self, x):
|
||||
output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight
|
||||
return output
|
365
models/MultiPatchFormer.py
Normal file
365
models/MultiPatchFormer.py
Normal file
@ -0,0 +1,365 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import math
|
||||
from einops import rearrange
|
||||
|
||||
from layers.SelfAttention_Family import AttentionLayer, FullAttention
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, d_model: int, d_hidden: int = 512):
|
||||
super(FeedForward, self).__init__()
|
||||
|
||||
self.linear_1 = torch.nn.Linear(d_model, d_hidden)
|
||||
self.linear_2 = torch.nn.Linear(d_hidden, d_model)
|
||||
self.activation = torch.nn.GELU()
|
||||
|
||||
def forward(self, x):
|
||||
x = self.linear_1(x)
|
||||
x = self.activation(x)
|
||||
x = self.linear_2(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
d_model: int,
|
||||
mha: AttentionLayer,
|
||||
d_hidden: int,
|
||||
dropout: float = 0,
|
||||
channel_wise=False,
|
||||
):
|
||||
super(Encoder, self).__init__()
|
||||
|
||||
self.channel_wise = channel_wise
|
||||
if self.channel_wise:
|
||||
self.conv = torch.nn.Conv1d(
|
||||
in_channels=d_model,
|
||||
out_channels=d_model,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
padding_mode="reflect",
|
||||
)
|
||||
self.MHA = mha
|
||||
self.feedforward = FeedForward(d_model=d_model, d_hidden=d_hidden)
|
||||
self.dropout = torch.nn.Dropout(p=dropout)
|
||||
self.layerNormal_1 = torch.nn.LayerNorm(d_model)
|
||||
self.layerNormal_2 = torch.nn.LayerNorm(d_model)
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
q = residual
|
||||
if self.channel_wise:
|
||||
x_r = self.conv(x.permute(0, 2, 1)).transpose(1, 2)
|
||||
k = x_r
|
||||
v = x_r
|
||||
else:
|
||||
k = residual
|
||||
v = residual
|
||||
x, score = self.MHA(q, k, v, attn_mask=None)
|
||||
x = self.dropout(x)
|
||||
x = self.layerNormal_1(x + residual)
|
||||
|
||||
residual = x
|
||||
x = self.feedforward(residual)
|
||||
x = self.dropout(x)
|
||||
x = self.layerNormal_2(x + residual)
|
||||
|
||||
return x, score
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, configs):
|
||||
super(Model, self).__init__()
|
||||
self.task_name = configs.task_name
|
||||
self.seq_len = configs.seq_len
|
||||
self.pred_len = configs.pred_len
|
||||
self.d_channel = configs.enc_in
|
||||
self.N = configs.e_layers
|
||||
# Embedding
|
||||
self.d_model = configs.d_model
|
||||
self.d_hidden = configs.d_ff
|
||||
self.n_heads = configs.n_heads
|
||||
self.mask = True
|
||||
self.dropout = configs.dropout
|
||||
|
||||
self.stride1 = 8
|
||||
self.patch_len1 = 8
|
||||
self.stride2 = 8
|
||||
self.patch_len2 = 16
|
||||
self.stride3 = 7
|
||||
self.patch_len3 = 24
|
||||
self.stride4 = 6
|
||||
self.patch_len4 = 32
|
||||
self.patch_num1 = int((self.seq_len - self.patch_len2) // self.stride2) + 2
|
||||
self.padding_patch_layer1 = nn.ReplicationPad1d((0, self.stride1))
|
||||
self.padding_patch_layer2 = nn.ReplicationPad1d((0, self.stride2))
|
||||
self.padding_patch_layer3 = nn.ReplicationPad1d((0, self.stride3))
|
||||
self.padding_patch_layer4 = nn.ReplicationPad1d((0, self.stride4))
|
||||
|
||||
self.shared_MHA = nn.ModuleList(
|
||||
[
|
||||
AttentionLayer(
|
||||
FullAttention(mask_flag=self.mask),
|
||||
d_model=self.d_model,
|
||||
n_heads=self.n_heads,
|
||||
)
|
||||
for _ in range(self.N)
|
||||
]
|
||||
)
|
||||
|
||||
self.shared_MHA_ch = nn.ModuleList(
|
||||
[
|
||||
AttentionLayer(
|
||||
FullAttention(mask_flag=self.mask),
|
||||
d_model=self.d_model,
|
||||
n_heads=self.n_heads,
|
||||
)
|
||||
for _ in range(self.N)
|
||||
]
|
||||
)
|
||||
|
||||
self.encoder_list = nn.ModuleList(
|
||||
[
|
||||
Encoder(
|
||||
d_model=self.d_model,
|
||||
mha=self.shared_MHA[ll],
|
||||
d_hidden=self.d_hidden,
|
||||
dropout=self.dropout,
|
||||
channel_wise=False,
|
||||
)
|
||||
for ll in range(self.N)
|
||||
]
|
||||
)
|
||||
|
||||
self.encoder_list_ch = nn.ModuleList(
|
||||
[
|
||||
Encoder(
|
||||
d_model=self.d_model,
|
||||
mha=self.shared_MHA_ch[0],
|
||||
d_hidden=self.d_hidden,
|
||||
dropout=self.dropout,
|
||||
channel_wise=True,
|
||||
)
|
||||
for ll in range(self.N)
|
||||
]
|
||||
)
|
||||
|
||||
pe = torch.zeros(self.patch_num1, self.d_model)
|
||||
for pos in range(self.patch_num1):
|
||||
for i in range(0, self.d_model, 2):
|
||||
wavelength = 10000 ** ((2 * i) / self.d_model)
|
||||
pe[pos, i] = math.sin(pos / wavelength)
|
||||
pe[pos, i + 1] = math.cos(pos / wavelength)
|
||||
pe = pe.unsqueeze(0) # add a batch dimention to your pe matrix
|
||||
self.register_buffer("pe", pe)
|
||||
|
||||
self.embedding_channel = nn.Conv1d(
|
||||
in_channels=self.d_model * self.patch_num1,
|
||||
out_channels=self.d_model,
|
||||
kernel_size=1,
|
||||
)
|
||||
|
||||
self.embedding_patch_1 = torch.nn.Conv1d(
|
||||
in_channels=1,
|
||||
out_channels=self.d_model // 4,
|
||||
kernel_size=self.patch_len1,
|
||||
stride=self.stride1,
|
||||
)
|
||||
self.embedding_patch_2 = torch.nn.Conv1d(
|
||||
in_channels=1,
|
||||
out_channels=self.d_model // 4,
|
||||
kernel_size=self.patch_len2,
|
||||
stride=self.stride2,
|
||||
)
|
||||
self.embedding_patch_3 = torch.nn.Conv1d(
|
||||
in_channels=1,
|
||||
out_channels=self.d_model // 4,
|
||||
kernel_size=self.patch_len3,
|
||||
stride=self.stride3,
|
||||
)
|
||||
self.embedding_patch_4 = torch.nn.Conv1d(
|
||||
in_channels=1,
|
||||
out_channels=self.d_model // 4,
|
||||
kernel_size=self.patch_len4,
|
||||
stride=self.stride4,
|
||||
)
|
||||
|
||||
self.out_linear_1 = torch.nn.Linear(self.d_model, self.pred_len // 8)
|
||||
self.out_linear_2 = torch.nn.Linear(
|
||||
self.d_model + self.pred_len // 8, self.pred_len // 8
|
||||
)
|
||||
self.out_linear_3 = torch.nn.Linear(
|
||||
self.d_model + 2 * self.pred_len // 8, self.pred_len // 8
|
||||
)
|
||||
self.out_linear_4 = torch.nn.Linear(
|
||||
self.d_model + 3 * self.pred_len // 8, self.pred_len // 8
|
||||
)
|
||||
self.out_linear_5 = torch.nn.Linear(
|
||||
self.d_model + self.pred_len // 2, self.pred_len // 8
|
||||
)
|
||||
self.out_linear_6 = torch.nn.Linear(
|
||||
self.d_model + 5 * self.pred_len // 8, self.pred_len // 8
|
||||
)
|
||||
self.out_linear_7 = torch.nn.Linear(
|
||||
self.d_model + 6 * self.pred_len // 8, self.pred_len // 8
|
||||
)
|
||||
self.out_linear_8 = torch.nn.Linear(
|
||||
self.d_model + 7 * self.pred_len // 8,
|
||||
self.pred_len - 7 * (self.pred_len // 8),
|
||||
)
|
||||
|
||||
self.remap = torch.nn.Linear(self.d_model, self.seq_len)
|
||||
|
||||
def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
|
||||
# Normalization
|
||||
means = x_enc.mean(1, keepdim=True).detach()
|
||||
x_enc = x_enc - means
|
||||
stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
|
||||
x_enc /= stdev
|
||||
|
||||
# Multi-scale embedding
|
||||
x_i = x_enc.permute(0, 2, 1)
|
||||
|
||||
x_i_p1 = x_i
|
||||
x_i_p2 = self.padding_patch_layer2(x_i)
|
||||
x_i_p3 = self.padding_patch_layer3(x_i)
|
||||
x_i_p4 = self.padding_patch_layer4(x_i)
|
||||
encoding_patch1 = self.embedding_patch_1(
|
||||
rearrange(x_i_p1, "b c l -> (b c) l").unsqueeze(-1).permute(0, 2, 1)
|
||||
).permute(0, 2, 1)
|
||||
encoding_patch2 = self.embedding_patch_2(
|
||||
rearrange(x_i_p2, "b c l -> (b c) l").unsqueeze(-1).permute(0, 2, 1)
|
||||
).permute(0, 2, 1)
|
||||
encoding_patch3 = self.embedding_patch_3(
|
||||
rearrange(x_i_p3, "b c l -> (b c) l").unsqueeze(-1).permute(0, 2, 1)
|
||||
).permute(0, 2, 1)
|
||||
encoding_patch4 = self.embedding_patch_4(
|
||||
rearrange(x_i_p4, "b c l -> (b c) l").unsqueeze(-1).permute(0, 2, 1)
|
||||
).permute(0, 2, 1)
|
||||
|
||||
encoding_patch = (
|
||||
torch.cat(
|
||||
(encoding_patch1, encoding_patch2, encoding_patch3, encoding_patch4),
|
||||
dim=-1,
|
||||
)
|
||||
+ self.pe
|
||||
)
|
||||
# Temporal encoding
|
||||
for i in range(self.N):
|
||||
encoding_patch = self.encoder_list[i](encoding_patch)[0]
|
||||
|
||||
# Channel-wise encoding
|
||||
x_patch_c = rearrange(
|
||||
encoding_patch, "(b c) p d -> b c (p d)", b=x_enc.shape[0], c=self.d_channel
|
||||
)
|
||||
x_ch = self.embedding_channel(x_patch_c.permute(0, 2, 1)).transpose(
|
||||
1, 2
|
||||
) # [b c d]
|
||||
|
||||
encoding_1_ch = self.encoder_list_ch[0](x_ch)[0]
|
||||
|
||||
# Semi Auto-regressive
|
||||
forecast_ch1 = self.out_linear_1(encoding_1_ch)
|
||||
forecast_ch2 = self.out_linear_2(
|
||||
torch.cat((encoding_1_ch, forecast_ch1), dim=-1)
|
||||
)
|
||||
forecast_ch3 = self.out_linear_3(
|
||||
torch.cat((encoding_1_ch, forecast_ch1, forecast_ch2), dim=-1)
|
||||
)
|
||||
forecast_ch4 = self.out_linear_4(
|
||||
torch.cat((encoding_1_ch, forecast_ch1, forecast_ch2, forecast_ch3), dim=-1)
|
||||
)
|
||||
forecast_ch5 = self.out_linear_5(
|
||||
torch.cat(
|
||||
(encoding_1_ch, forecast_ch1, forecast_ch2, forecast_ch3, forecast_ch4),
|
||||
dim=-1,
|
||||
)
|
||||
)
|
||||
forecast_ch6 = self.out_linear_6(
|
||||
torch.cat(
|
||||
(
|
||||
encoding_1_ch,
|
||||
forecast_ch1,
|
||||
forecast_ch2,
|
||||
forecast_ch3,
|
||||
forecast_ch4,
|
||||
forecast_ch5,
|
||||
),
|
||||
dim=-1,
|
||||
)
|
||||
)
|
||||
forecast_ch7 = self.out_linear_7(
|
||||
torch.cat(
|
||||
(
|
||||
encoding_1_ch,
|
||||
forecast_ch1,
|
||||
forecast_ch2,
|
||||
forecast_ch3,
|
||||
forecast_ch4,
|
||||
forecast_ch5,
|
||||
forecast_ch6,
|
||||
),
|
||||
dim=-1,
|
||||
)
|
||||
)
|
||||
forecast_ch8 = self.out_linear_8(
|
||||
torch.cat(
|
||||
(
|
||||
encoding_1_ch,
|
||||
forecast_ch1,
|
||||
forecast_ch2,
|
||||
forecast_ch3,
|
||||
forecast_ch4,
|
||||
forecast_ch5,
|
||||
forecast_ch6,
|
||||
forecast_ch7,
|
||||
),
|
||||
dim=-1,
|
||||
)
|
||||
)
|
||||
|
||||
final_forecast = torch.cat(
|
||||
(
|
||||
forecast_ch1,
|
||||
forecast_ch2,
|
||||
forecast_ch3,
|
||||
forecast_ch4,
|
||||
forecast_ch5,
|
||||
forecast_ch6,
|
||||
forecast_ch7,
|
||||
forecast_ch8,
|
||||
),
|
||||
dim=-1,
|
||||
).permute(0, 2, 1)
|
||||
|
||||
# De-Normalization
|
||||
dec_out = final_forecast * (
|
||||
stdev[:, 0].unsqueeze(1).repeat(1, self.pred_len, 1)
|
||||
)
|
||||
dec_out = dec_out + (means[:, 0].unsqueeze(1).repeat(1, self.pred_len, 1))
|
||||
return dec_out
|
||||
|
||||
def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
|
||||
if (
|
||||
self.task_name == "long_term_forecast"
|
||||
or self.task_name == "short_term_forecast"
|
||||
):
|
||||
dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
|
||||
return dec_out[:, -self.pred_len :, :] # [B, L, D]
|
||||
if self.task_name == "imputation":
|
||||
raise NotImplementedError(
|
||||
"Task imputation for WPMixer is temporarily not supported"
|
||||
)
|
||||
if self.task_name == "anomaly_detection":
|
||||
raise NotImplementedError(
|
||||
"Task anomaly_detection for WPMixer is temporarily not supported"
|
||||
)
|
||||
if self.task_name == "classification":
|
||||
raise NotImplementedError(
|
||||
"Task classification for WPMixer is temporarily not supported"
|
||||
)
|
||||
return None
|
230
models/Nonstationary_Transformer.py
Normal file
230
models/Nonstationary_Transformer.py
Normal file
@ -0,0 +1,230 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from layers.Transformer_EncDec import Decoder, DecoderLayer, Encoder, EncoderLayer
|
||||
from layers.SelfAttention_Family import DSAttention, AttentionLayer
|
||||
from layers.Embed import DataEmbedding
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class Projector(nn.Module):
|
||||
'''
|
||||
MLP to learn the De-stationary factors
|
||||
Paper link: https://openreview.net/pdf?id=ucNDIDRNjjv
|
||||
'''
|
||||
|
||||
def __init__(self, enc_in, seq_len, hidden_dims, hidden_layers, output_dim, kernel_size=3):
|
||||
super(Projector, self).__init__()
|
||||
|
||||
padding = 1 if torch.__version__ >= '1.5.0' else 2
|
||||
self.series_conv = nn.Conv1d(in_channels=seq_len, out_channels=1, kernel_size=kernel_size, padding=padding,
|
||||
padding_mode='circular', bias=False)
|
||||
|
||||
layers = [nn.Linear(2 * enc_in, hidden_dims[0]), nn.ReLU()]
|
||||
for i in range(hidden_layers - 1):
|
||||
layers += [nn.Linear(hidden_dims[i], hidden_dims[i + 1]), nn.ReLU()]
|
||||
|
||||
layers += [nn.Linear(hidden_dims[-1], output_dim, bias=False)]
|
||||
self.backbone = nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x, stats):
|
||||
# x: B x S x E
|
||||
# stats: B x 1 x E
|
||||
# y: B x O
|
||||
batch_size = x.shape[0]
|
||||
x = self.series_conv(x) # B x 1 x E
|
||||
x = torch.cat([x, stats], dim=1) # B x 2 x E
|
||||
x = x.view(batch_size, -1) # B x 2E
|
||||
y = self.backbone(x) # B x O
|
||||
|
||||
return y
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
"""
|
||||
Paper link: https://openreview.net/pdf?id=ucNDIDRNjjv
|
||||
"""
|
||||
|
||||
def __init__(self, configs):
|
||||
super(Model, self).__init__()
|
||||
self.task_name = configs.task_name
|
||||
self.pred_len = configs.pred_len
|
||||
self.seq_len = configs.seq_len
|
||||
self.label_len = configs.label_len
|
||||
|
||||
# Embedding
|
||||
self.enc_embedding = DataEmbedding(configs.enc_in, configs.d_model, configs.embed, configs.freq,
|
||||
configs.dropout)
|
||||
|
||||
# Encoder
|
||||
self.encoder = Encoder(
|
||||
[
|
||||
EncoderLayer(
|
||||
AttentionLayer(
|
||||
DSAttention(False, configs.factor, attention_dropout=configs.dropout,
|
||||
output_attention=False), configs.d_model, configs.n_heads),
|
||||
configs.d_model,
|
||||
configs.d_ff,
|
||||
dropout=configs.dropout,
|
||||
activation=configs.activation
|
||||
) for l in range(configs.e_layers)
|
||||
],
|
||||
norm_layer=torch.nn.LayerNorm(configs.d_model)
|
||||
)
|
||||
# Decoder
|
||||
if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
|
||||
self.dec_embedding = DataEmbedding(configs.dec_in, configs.d_model, configs.embed, configs.freq,
|
||||
configs.dropout)
|
||||
self.decoder = Decoder(
|
||||
[
|
||||
DecoderLayer(
|
||||
AttentionLayer(
|
||||
DSAttention(True, configs.factor, attention_dropout=configs.dropout,
|
||||
output_attention=False),
|
||||
configs.d_model, configs.n_heads),
|
||||
AttentionLayer(
|
||||
DSAttention(False, configs.factor, attention_dropout=configs.dropout,
|
||||
output_attention=False),
|
||||
configs.d_model, configs.n_heads),
|
||||
configs.d_model,
|
||||
configs.d_ff,
|
||||
dropout=configs.dropout,
|
||||
activation=configs.activation,
|
||||
)
|
||||
for l in range(configs.d_layers)
|
||||
],
|
||||
norm_layer=torch.nn.LayerNorm(configs.d_model),
|
||||
projection=nn.Linear(configs.d_model, configs.c_out, bias=True)
|
||||
)
|
||||
if self.task_name == 'imputation':
|
||||
self.projection = nn.Linear(configs.d_model, configs.c_out, bias=True)
|
||||
if self.task_name == 'anomaly_detection':
|
||||
self.projection = nn.Linear(configs.d_model, configs.c_out, bias=True)
|
||||
if self.task_name == 'classification':
|
||||
self.act = F.gelu
|
||||
self.dropout = nn.Dropout(configs.dropout)
|
||||
self.projection = nn.Linear(configs.d_model * configs.seq_len, configs.num_class)
|
||||
|
||||
self.tau_learner = Projector(enc_in=configs.enc_in, seq_len=configs.seq_len, hidden_dims=configs.p_hidden_dims,
|
||||
hidden_layers=configs.p_hidden_layers, output_dim=1)
|
||||
self.delta_learner = Projector(enc_in=configs.enc_in, seq_len=configs.seq_len,
|
||||
hidden_dims=configs.p_hidden_dims, hidden_layers=configs.p_hidden_layers,
|
||||
output_dim=configs.seq_len)
|
||||
|
||||
def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
|
||||
x_raw = x_enc.clone().detach()
|
||||
|
||||
# Normalization
|
||||
mean_enc = x_enc.mean(1, keepdim=True).detach() # B x 1 x E
|
||||
x_enc = x_enc - mean_enc
|
||||
std_enc = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5).detach() # B x 1 x E
|
||||
x_enc = x_enc / std_enc
|
||||
# B x S x E, B x 1 x E -> B x 1, positive scalar
|
||||
tau = self.tau_learner(x_raw, std_enc)
|
||||
threshold = 80.0
|
||||
tau_clamped = torch.clamp(tau, max=threshold) # avoid numerical overflow
|
||||
tau = tau_clamped.exp()
|
||||
# B x S x E, B x 1 x E -> B x S
|
||||
delta = self.delta_learner(x_raw, mean_enc)
|
||||
|
||||
x_dec_new = torch.cat([x_enc[:, -self.label_len:, :], torch.zeros_like(x_dec[:, -self.pred_len:, :])],
|
||||
dim=1).to(x_enc.device).clone()
|
||||
|
||||
enc_out = self.enc_embedding(x_enc, x_mark_enc)
|
||||
enc_out, attns = self.encoder(enc_out, attn_mask=None, tau=tau, delta=delta)
|
||||
|
||||
dec_out = self.dec_embedding(x_dec_new, x_mark_dec)
|
||||
dec_out = self.decoder(dec_out, enc_out, x_mask=None, cross_mask=None, tau=tau, delta=delta)
|
||||
dec_out = dec_out * std_enc + mean_enc
|
||||
return dec_out
|
||||
|
||||
def imputation(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask):
|
||||
x_raw = x_enc.clone().detach()
|
||||
|
||||
# Normalization
|
||||
mean_enc = torch.sum(x_enc, dim=1) / torch.sum(mask == 1, dim=1)
|
||||
mean_enc = mean_enc.unsqueeze(1).detach()
|
||||
x_enc = x_enc - mean_enc
|
||||
x_enc = x_enc.masked_fill(mask == 0, 0)
|
||||
std_enc = torch.sqrt(torch.sum(x_enc * x_enc, dim=1) / torch.sum(mask == 1, dim=1) + 1e-5)
|
||||
std_enc = std_enc.unsqueeze(1).detach()
|
||||
x_enc /= std_enc
|
||||
# B x S x E, B x 1 x E -> B x 1, positive scalar
|
||||
tau = self.tau_learner(x_raw, std_enc)
|
||||
threshold = 80.0
|
||||
tau_clamped = torch.clamp(tau, max=threshold) # avoid numerical overflow
|
||||
tau = tau_clamped.exp()
|
||||
# B x S x E, B x 1 x E -> B x S
|
||||
delta = self.delta_learner(x_raw, mean_enc)
|
||||
|
||||
enc_out = self.enc_embedding(x_enc, x_mark_enc)
|
||||
enc_out, attns = self.encoder(enc_out, attn_mask=None, tau=tau, delta=delta)
|
||||
|
||||
dec_out = self.projection(enc_out)
|
||||
dec_out = dec_out * std_enc + mean_enc
|
||||
return dec_out
|
||||
|
||||
def anomaly_detection(self, x_enc):
|
||||
x_raw = x_enc.clone().detach()
|
||||
|
||||
# Normalization
|
||||
mean_enc = x_enc.mean(1, keepdim=True).detach() # B x 1 x E
|
||||
x_enc = x_enc - mean_enc
|
||||
std_enc = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5).detach() # B x 1 x E
|
||||
x_enc = x_enc / std_enc
|
||||
# B x S x E, B x 1 x E -> B x 1, positive scalar
|
||||
tau = self.tau_learner(x_raw, std_enc)
|
||||
threshold = 80.0
|
||||
tau_clamped = torch.clamp(tau, max=threshold) # avoid numerical overflow
|
||||
tau = tau_clamped.exp()
|
||||
# B x S x E, B x 1 x E -> B x S
|
||||
delta = self.delta_learner(x_raw, mean_enc)
|
||||
# embedding
|
||||
enc_out = self.enc_embedding(x_enc, None)
|
||||
enc_out, attns = self.encoder(enc_out, attn_mask=None, tau=tau, delta=delta)
|
||||
|
||||
dec_out = self.projection(enc_out)
|
||||
dec_out = dec_out * std_enc + mean_enc
|
||||
return dec_out
|
||||
|
||||
def classification(self, x_enc, x_mark_enc):
|
||||
x_raw = x_enc.clone().detach()
|
||||
|
||||
# Normalization
|
||||
mean_enc = x_enc.mean(1, keepdim=True).detach() # B x 1 x E
|
||||
std_enc = torch.sqrt(
|
||||
torch.var(x_enc - mean_enc, dim=1, keepdim=True, unbiased=False) + 1e-5).detach() # B x 1 x E
|
||||
# B x S x E, B x 1 x E -> B x 1, positive scalar
|
||||
tau = self.tau_learner(x_raw, std_enc)
|
||||
threshold = 80.0
|
||||
tau_clamped = torch.clamp(tau, max=threshold) # avoid numerical overflow
|
||||
tau = tau_clamped.exp()
|
||||
# B x S x E, B x 1 x E -> B x S
|
||||
delta = self.delta_learner(x_raw, mean_enc)
|
||||
# embedding
|
||||
enc_out = self.enc_embedding(x_enc, None)
|
||||
enc_out, attns = self.encoder(enc_out, attn_mask=None, tau=tau, delta=delta)
|
||||
|
||||
# Output
|
||||
output = self.act(enc_out) # the output transformer encoder/decoder embeddings don't include non-linearity
|
||||
output = self.dropout(output)
|
||||
output = output * x_mark_enc.unsqueeze(-1) # zero-out padding embeddings
|
||||
# (batch_size, seq_length * d_model)
|
||||
output = output.reshape(output.shape[0], -1)
|
||||
# (batch_size, num_classes)
|
||||
output = self.projection(output)
|
||||
return output
|
||||
|
||||
def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
|
||||
if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
|
||||
dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
|
||||
return dec_out[:, -self.pred_len:, :] # [B, L, D]
|
||||
if self.task_name == 'imputation':
|
||||
dec_out = self.imputation(x_enc, x_mark_enc, x_dec, x_mark_dec, mask)
|
||||
return dec_out # [B, L, D]
|
||||
if self.task_name == 'anomaly_detection':
|
||||
dec_out = self.anomaly_detection(x_enc)
|
||||
return dec_out # [B, L, D]
|
||||
if self.task_name == 'classification':
|
||||
dec_out = self.classification(x_enc, x_mark_enc)
|
||||
return dec_out # [B, L, D]
|
||||
return None
|
62
models/PAttn.py
Normal file
62
models/PAttn.py
Normal file
@ -0,0 +1,62 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from layers.Transformer_EncDec import Encoder, EncoderLayer
|
||||
from layers.SelfAttention_Family import FullAttention, AttentionLayer
|
||||
from einops import rearrange
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
"""
|
||||
Paper link: https://arxiv.org/abs/2406.16964
|
||||
"""
|
||||
def __init__(self, configs, patch_len=16, stride=8):
|
||||
super().__init__()
|
||||
self.seq_len = configs.seq_len
|
||||
self.pred_len = configs.pred_len
|
||||
self.patch_size = patch_len
|
||||
self.stride = stride
|
||||
|
||||
self.d_model = configs.d_model
|
||||
|
||||
self.patch_num = (configs.seq_len - self.patch_size) // self.stride + 2
|
||||
self.padding_patch_layer = nn.ReplicationPad1d((0, self.stride))
|
||||
self.in_layer = nn.Linear(self.patch_size, self.d_model)
|
||||
self.encoder = Encoder(
|
||||
[
|
||||
EncoderLayer(
|
||||
AttentionLayer(
|
||||
FullAttention(False, configs.factor, attention_dropout=configs.dropout,
|
||||
output_attention=False), configs.d_model, configs.n_heads),
|
||||
configs.d_model,
|
||||
configs.d_ff,
|
||||
dropout=configs.dropout,
|
||||
activation=configs.activation
|
||||
) for l in range(1)
|
||||
],
|
||||
norm_layer=nn.LayerNorm(configs.d_model)
|
||||
)
|
||||
self.out_layer = nn.Linear(self.d_model * self.patch_num, configs.pred_len)
|
||||
|
||||
def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
|
||||
means = x_enc.mean(1, keepdim=True).detach()
|
||||
x_enc = x_enc - means
|
||||
stdev = torch.sqrt(
|
||||
torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
|
||||
x_enc /= stdev
|
||||
|
||||
B, _, C = x_enc.shape
|
||||
x_enc = x_enc.permute(0, 2, 1)
|
||||
x_enc = self.padding_patch_layer(x_enc)
|
||||
x_enc = x_enc.unfold(dimension=-1, size=self.patch_size, step=self.stride)
|
||||
enc_out = self.in_layer(x_enc)
|
||||
enc_out = rearrange(enc_out, 'b c m l -> (b c) m l')
|
||||
dec_out, _ = self.encoder(enc_out)
|
||||
dec_out = rearrange(dec_out, '(b c) m l -> b c (m l)' , b=B , c=C)
|
||||
dec_out = self.out_layer(dec_out)
|
||||
dec_out = dec_out.permute(0, 2, 1)
|
||||
|
||||
dec_out = dec_out * \
|
||||
(stdev[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1))
|
||||
dec_out = dec_out + \
|
||||
(means[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1))
|
||||
return dec_out
|
227
models/PatchTST.py
Normal file
227
models/PatchTST.py
Normal file
@ -0,0 +1,227 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from layers.Transformer_EncDec import Encoder, EncoderLayer
|
||||
from layers.SelfAttention_Family import FullAttention, AttentionLayer
|
||||
from layers.Embed import PatchEmbedding
|
||||
|
||||
class Transpose(nn.Module):
|
||||
def __init__(self, *dims, contiguous=False):
|
||||
super().__init__()
|
||||
self.dims, self.contiguous = dims, contiguous
|
||||
def forward(self, x):
|
||||
if self.contiguous: return x.transpose(*self.dims).contiguous()
|
||||
else: return x.transpose(*self.dims)
|
||||
|
||||
|
||||
class FlattenHead(nn.Module):
|
||||
def __init__(self, n_vars, nf, target_window, head_dropout=0):
|
||||
super().__init__()
|
||||
self.n_vars = n_vars
|
||||
self.flatten = nn.Flatten(start_dim=-2)
|
||||
self.linear = nn.Linear(nf, target_window)
|
||||
self.dropout = nn.Dropout(head_dropout)
|
||||
|
||||
def forward(self, x): # x: [bs x nvars x d_model x patch_num]
|
||||
x = self.flatten(x)
|
||||
x = self.linear(x)
|
||||
x = self.dropout(x)
|
||||
return x
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
"""
|
||||
Paper link: https://arxiv.org/pdf/2211.14730.pdf
|
||||
"""
|
||||
|
||||
def __init__(self, configs, patch_len=16, stride=8):
|
||||
"""
|
||||
patch_len: int, patch len for patch_embedding
|
||||
stride: int, stride for patch_embedding
|
||||
"""
|
||||
super().__init__()
|
||||
self.task_name = configs.task_name
|
||||
self.seq_len = configs.seq_len
|
||||
self.pred_len = configs.pred_len
|
||||
padding = stride
|
||||
|
||||
# patching and embedding
|
||||
self.patch_embedding = PatchEmbedding(
|
||||
configs.d_model, patch_len, stride, padding, configs.dropout)
|
||||
|
||||
# Encoder
|
||||
self.encoder = Encoder(
|
||||
[
|
||||
EncoderLayer(
|
||||
AttentionLayer(
|
||||
FullAttention(False, configs.factor, attention_dropout=configs.dropout,
|
||||
output_attention=False), configs.d_model, configs.n_heads),
|
||||
configs.d_model,
|
||||
configs.d_ff,
|
||||
dropout=configs.dropout,
|
||||
activation=configs.activation
|
||||
) for l in range(configs.e_layers)
|
||||
],
|
||||
norm_layer=nn.Sequential(Transpose(1,2), nn.BatchNorm1d(configs.d_model), Transpose(1,2))
|
||||
)
|
||||
|
||||
# Prediction Head
|
||||
self.head_nf = configs.d_model * \
|
||||
int((configs.seq_len - patch_len) / stride + 2)
|
||||
if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
|
||||
self.head = FlattenHead(configs.enc_in, self.head_nf, configs.pred_len,
|
||||
head_dropout=configs.dropout)
|
||||
elif self.task_name == 'imputation' or self.task_name == 'anomaly_detection':
|
||||
self.head = FlattenHead(configs.enc_in, self.head_nf, configs.seq_len,
|
||||
head_dropout=configs.dropout)
|
||||
elif self.task_name == 'classification':
|
||||
self.flatten = nn.Flatten(start_dim=-2)
|
||||
self.dropout = nn.Dropout(configs.dropout)
|
||||
self.projection = nn.Linear(
|
||||
self.head_nf * configs.enc_in, configs.num_class)
|
||||
|
||||
def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
|
||||
# Normalization from Non-stationary Transformer
|
||||
means = x_enc.mean(1, keepdim=True).detach()
|
||||
x_enc = x_enc - means
|
||||
stdev = torch.sqrt(
|
||||
torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
|
||||
x_enc /= stdev
|
||||
|
||||
# do patching and embedding
|
||||
x_enc = x_enc.permute(0, 2, 1)
|
||||
# u: [bs * nvars x patch_num x d_model]
|
||||
enc_out, n_vars = self.patch_embedding(x_enc)
|
||||
|
||||
# Encoder
|
||||
# z: [bs * nvars x patch_num x d_model]
|
||||
enc_out, attns = self.encoder(enc_out)
|
||||
# z: [bs x nvars x patch_num x d_model]
|
||||
enc_out = torch.reshape(
|
||||
enc_out, (-1, n_vars, enc_out.shape[-2], enc_out.shape[-1]))
|
||||
# z: [bs x nvars x d_model x patch_num]
|
||||
enc_out = enc_out.permute(0, 1, 3, 2)
|
||||
|
||||
# Decoder
|
||||
dec_out = self.head(enc_out) # z: [bs x nvars x target_window]
|
||||
dec_out = dec_out.permute(0, 2, 1)
|
||||
|
||||
# De-Normalization from Non-stationary Transformer
|
||||
dec_out = dec_out * \
|
||||
(stdev[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1))
|
||||
dec_out = dec_out + \
|
||||
(means[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1))
|
||||
return dec_out
|
||||
|
||||
def imputation(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask):
|
||||
# Normalization from Non-stationary Transformer
|
||||
means = torch.sum(x_enc, dim=1) / torch.sum(mask == 1, dim=1)
|
||||
means = means.unsqueeze(1).detach()
|
||||
x_enc = x_enc - means
|
||||
x_enc = x_enc.masked_fill(mask == 0, 0)
|
||||
stdev = torch.sqrt(torch.sum(x_enc * x_enc, dim=1) /
|
||||
torch.sum(mask == 1, dim=1) + 1e-5)
|
||||
stdev = stdev.unsqueeze(1).detach()
|
||||
x_enc /= stdev
|
||||
|
||||
# do patching and embedding
|
||||
x_enc = x_enc.permute(0, 2, 1)
|
||||
# u: [bs * nvars x patch_num x d_model]
|
||||
enc_out, n_vars = self.patch_embedding(x_enc)
|
||||
|
||||
# Encoder
|
||||
# z: [bs * nvars x patch_num x d_model]
|
||||
enc_out, attns = self.encoder(enc_out)
|
||||
# z: [bs x nvars x patch_num x d_model]
|
||||
enc_out = torch.reshape(
|
||||
enc_out, (-1, n_vars, enc_out.shape[-2], enc_out.shape[-1]))
|
||||
# z: [bs x nvars x d_model x patch_num]
|
||||
enc_out = enc_out.permute(0, 1, 3, 2)
|
||||
|
||||
# Decoder
|
||||
dec_out = self.head(enc_out) # z: [bs x nvars x target_window]
|
||||
dec_out = dec_out.permute(0, 2, 1)
|
||||
|
||||
# De-Normalization from Non-stationary Transformer
|
||||
dec_out = dec_out * \
|
||||
(stdev[:, 0, :].unsqueeze(1).repeat(1, self.seq_len, 1))
|
||||
dec_out = dec_out + \
|
||||
(means[:, 0, :].unsqueeze(1).repeat(1, self.seq_len, 1))
|
||||
return dec_out
|
||||
|
||||
def anomaly_detection(self, x_enc):
|
||||
# Normalization from Non-stationary Transformer
|
||||
means = x_enc.mean(1, keepdim=True).detach()
|
||||
x_enc = x_enc - means
|
||||
stdev = torch.sqrt(
|
||||
torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
|
||||
x_enc /= stdev
|
||||
|
||||
# do patching and embedding
|
||||
x_enc = x_enc.permute(0, 2, 1)
|
||||
# u: [bs * nvars x patch_num x d_model]
|
||||
enc_out, n_vars = self.patch_embedding(x_enc)
|
||||
|
||||
# Encoder
|
||||
# z: [bs * nvars x patch_num x d_model]
|
||||
enc_out, attns = self.encoder(enc_out)
|
||||
# z: [bs x nvars x patch_num x d_model]
|
||||
enc_out = torch.reshape(
|
||||
enc_out, (-1, n_vars, enc_out.shape[-2], enc_out.shape[-1]))
|
||||
# z: [bs x nvars x d_model x patch_num]
|
||||
enc_out = enc_out.permute(0, 1, 3, 2)
|
||||
|
||||
# Decoder
|
||||
dec_out = self.head(enc_out) # z: [bs x nvars x target_window]
|
||||
dec_out = dec_out.permute(0, 2, 1)
|
||||
|
||||
# De-Normalization from Non-stationary Transformer
|
||||
dec_out = dec_out * \
|
||||
(stdev[:, 0, :].unsqueeze(1).repeat(1, self.seq_len, 1))
|
||||
dec_out = dec_out + \
|
||||
(means[:, 0, :].unsqueeze(1).repeat(1, self.seq_len, 1))
|
||||
return dec_out
|
||||
|
||||
def classification(self, x_enc, x_mark_enc):
|
||||
# Normalization from Non-stationary Transformer
|
||||
means = x_enc.mean(1, keepdim=True).detach()
|
||||
x_enc = x_enc - means
|
||||
stdev = torch.sqrt(
|
||||
torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
|
||||
x_enc /= stdev
|
||||
|
||||
# do patching and embedding
|
||||
x_enc = x_enc.permute(0, 2, 1)
|
||||
# u: [bs * nvars x patch_num x d_model]
|
||||
enc_out, n_vars = self.patch_embedding(x_enc)
|
||||
|
||||
# Encoder
|
||||
# z: [bs * nvars x patch_num x d_model]
|
||||
enc_out, attns = self.encoder(enc_out)
|
||||
# z: [bs x nvars x patch_num x d_model]
|
||||
enc_out = torch.reshape(
|
||||
enc_out, (-1, n_vars, enc_out.shape[-2], enc_out.shape[-1]))
|
||||
# z: [bs x nvars x d_model x patch_num]
|
||||
enc_out = enc_out.permute(0, 1, 3, 2)
|
||||
|
||||
# Decoder
|
||||
output = self.flatten(enc_out)
|
||||
output = self.dropout(output)
|
||||
output = output.reshape(output.shape[0], -1)
|
||||
output = self.projection(output) # (batch_size, num_classes)
|
||||
return output
|
||||
|
||||
def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
|
||||
if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
|
||||
dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
|
||||
return dec_out[:, -self.pred_len:, :] # [B, L, D]
|
||||
if self.task_name == 'imputation':
|
||||
dec_out = self.imputation(
|
||||
x_enc, x_mark_enc, x_dec, x_mark_dec, mask)
|
||||
return dec_out # [B, L, D]
|
||||
if self.task_name == 'anomaly_detection':
|
||||
dec_out = self.anomaly_detection(x_enc)
|
||||
return dec_out # [B, L, D]
|
||||
if self.task_name == 'classification':
|
||||
dec_out = self.classification(x_enc, x_mark_enc)
|
||||
return dec_out # [B, N]
|
||||
return None
|
101
models/Pyraformer.py
Normal file
101
models/Pyraformer.py
Normal file
@ -0,0 +1,101 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from layers.Pyraformer_EncDec import Encoder
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
"""
|
||||
Pyraformer: Pyramidal attention to reduce complexity
|
||||
Paper link: https://openreview.net/pdf?id=0EXmFzUn5I
|
||||
"""
|
||||
|
||||
def __init__(self, configs, window_size=[4,4], inner_size=5):
|
||||
"""
|
||||
window_size: list, the downsample window size in pyramidal attention.
|
||||
inner_size: int, the size of neighbour attention
|
||||
"""
|
||||
super().__init__()
|
||||
self.task_name = configs.task_name
|
||||
self.pred_len = configs.pred_len
|
||||
self.d_model = configs.d_model
|
||||
|
||||
if self.task_name == 'short_term_forecast':
|
||||
window_size = [2,2]
|
||||
self.encoder = Encoder(configs, window_size, inner_size)
|
||||
|
||||
if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
|
||||
self.projection = nn.Linear(
|
||||
(len(window_size)+1)*self.d_model, self.pred_len * configs.enc_in)
|
||||
elif self.task_name == 'imputation' or self.task_name == 'anomaly_detection':
|
||||
self.projection = nn.Linear(
|
||||
(len(window_size)+1)*self.d_model, configs.enc_in, bias=True)
|
||||
elif self.task_name == 'classification':
|
||||
self.act = torch.nn.functional.gelu
|
||||
self.dropout = nn.Dropout(configs.dropout)
|
||||
self.projection = nn.Linear(
|
||||
(len(window_size)+1)*self.d_model * configs.seq_len, configs.num_class)
|
||||
|
||||
def long_forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
|
||||
enc_out = self.encoder(x_enc, x_mark_enc)[:, -1, :]
|
||||
dec_out = self.projection(enc_out).view(
|
||||
enc_out.size(0), self.pred_len, -1)
|
||||
return dec_out
|
||||
|
||||
def short_forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
|
||||
# Normalization
|
||||
mean_enc = x_enc.mean(1, keepdim=True).detach() # B x 1 x E
|
||||
x_enc = x_enc - mean_enc
|
||||
std_enc = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5).detach() # B x 1 x E
|
||||
x_enc = x_enc / std_enc
|
||||
|
||||
enc_out = self.encoder(x_enc, x_mark_enc)[:, -1, :]
|
||||
dec_out = self.projection(enc_out).view(
|
||||
enc_out.size(0), self.pred_len, -1)
|
||||
|
||||
dec_out = dec_out * std_enc + mean_enc
|
||||
return dec_out
|
||||
|
||||
def imputation(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask):
|
||||
enc_out = self.encoder(x_enc, x_mark_enc)
|
||||
dec_out = self.projection(enc_out)
|
||||
return dec_out
|
||||
|
||||
def anomaly_detection(self, x_enc, x_mark_enc):
|
||||
enc_out = self.encoder(x_enc, x_mark_enc)
|
||||
dec_out = self.projection(enc_out)
|
||||
return dec_out
|
||||
|
||||
def classification(self, x_enc, x_mark_enc):
|
||||
# enc
|
||||
enc_out = self.encoder(x_enc, x_mark_enc=None)
|
||||
|
||||
# Output
|
||||
# the output transformer encoder/decoder embeddings don't include non-linearity
|
||||
output = self.act(enc_out)
|
||||
output = self.dropout(output)
|
||||
# zero-out padding embeddings
|
||||
output = output * x_mark_enc.unsqueeze(-1)
|
||||
# (batch_size, seq_length * d_model)
|
||||
output = output.reshape(output.shape[0], -1)
|
||||
output = self.projection(output) # (batch_size, num_classes)
|
||||
|
||||
return output
|
||||
|
||||
def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
|
||||
if self.task_name == 'long_term_forecast':
|
||||
dec_out = self.long_forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
|
||||
return dec_out[:, -self.pred_len:, :] # [B, L, D]
|
||||
if self.task_name == 'short_term_forecast':
|
||||
dec_out = self.short_forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
|
||||
return dec_out[:, -self.pred_len:, :] # [B, L, D]
|
||||
if self.task_name == 'imputation':
|
||||
dec_out = self.imputation(
|
||||
x_enc, x_mark_enc, x_dec, x_mark_dec, mask)
|
||||
return dec_out # [B, L, D]
|
||||
if self.task_name == 'anomaly_detection':
|
||||
dec_out = self.anomaly_detection(x_enc, x_mark_enc)
|
||||
return dec_out # [B, L, D]
|
||||
if self.task_name == 'classification':
|
||||
dec_out = self.classification(x_enc, x_mark_enc)
|
||||
return dec_out # [B, N]
|
||||
return None
|
132
models/Reformer.py
Normal file
132
models/Reformer.py
Normal file
@ -0,0 +1,132 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from layers.Transformer_EncDec import Encoder, EncoderLayer
|
||||
from layers.SelfAttention_Family import ReformerLayer
|
||||
from layers.Embed import DataEmbedding
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
"""
|
||||
Reformer with O(LlogL) complexity
|
||||
Paper link: https://openreview.net/forum?id=rkgNKkHtvB
|
||||
"""
|
||||
|
||||
def __init__(self, configs, bucket_size=4, n_hashes=4):
|
||||
"""
|
||||
bucket_size: int,
|
||||
n_hashes: int,
|
||||
"""
|
||||
super(Model, self).__init__()
|
||||
self.task_name = configs.task_name
|
||||
self.pred_len = configs.pred_len
|
||||
self.seq_len = configs.seq_len
|
||||
|
||||
self.enc_embedding = DataEmbedding(configs.enc_in, configs.d_model, configs.embed, configs.freq,
|
||||
configs.dropout)
|
||||
# Encoder
|
||||
self.encoder = Encoder(
|
||||
[
|
||||
EncoderLayer(
|
||||
ReformerLayer(None, configs.d_model, configs.n_heads,
|
||||
bucket_size=bucket_size, n_hashes=n_hashes),
|
||||
configs.d_model,
|
||||
configs.d_ff,
|
||||
dropout=configs.dropout,
|
||||
activation=configs.activation
|
||||
) for l in range(configs.e_layers)
|
||||
],
|
||||
norm_layer=torch.nn.LayerNorm(configs.d_model)
|
||||
)
|
||||
|
||||
if self.task_name == 'classification':
|
||||
self.act = F.gelu
|
||||
self.dropout = nn.Dropout(configs.dropout)
|
||||
self.projection = nn.Linear(
|
||||
configs.d_model * configs.seq_len, configs.num_class)
|
||||
else:
|
||||
self.projection = nn.Linear(
|
||||
configs.d_model, configs.c_out, bias=True)
|
||||
|
||||
def long_forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
|
||||
# add placeholder
|
||||
x_enc = torch.cat([x_enc, x_dec[:, -self.pred_len:, :]], dim=1)
|
||||
if x_mark_enc is not None:
|
||||
x_mark_enc = torch.cat(
|
||||
[x_mark_enc, x_mark_dec[:, -self.pred_len:, :]], dim=1)
|
||||
|
||||
enc_out = self.enc_embedding(x_enc, x_mark_enc) # [B,T,C]
|
||||
enc_out, attns = self.encoder(enc_out, attn_mask=None)
|
||||
dec_out = self.projection(enc_out)
|
||||
|
||||
return dec_out # [B, L, D]
|
||||
|
||||
def short_forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
|
||||
# Normalization
|
||||
mean_enc = x_enc.mean(1, keepdim=True).detach() # B x 1 x E
|
||||
x_enc = x_enc - mean_enc
|
||||
std_enc = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5).detach() # B x 1 x E
|
||||
x_enc = x_enc / std_enc
|
||||
|
||||
# add placeholder
|
||||
x_enc = torch.cat([x_enc, x_dec[:, -self.pred_len:, :]], dim=1)
|
||||
if x_mark_enc is not None:
|
||||
x_mark_enc = torch.cat(
|
||||
[x_mark_enc, x_mark_dec[:, -self.pred_len:, :]], dim=1)
|
||||
|
||||
enc_out = self.enc_embedding(x_enc, x_mark_enc) # [B,T,C]
|
||||
enc_out, attns = self.encoder(enc_out, attn_mask=None)
|
||||
dec_out = self.projection(enc_out)
|
||||
|
||||
dec_out = dec_out * std_enc + mean_enc
|
||||
return dec_out # [B, L, D]
|
||||
|
||||
def imputation(self, x_enc, x_mark_enc):
|
||||
enc_out = self.enc_embedding(x_enc, x_mark_enc) # [B,T,C]
|
||||
|
||||
enc_out, attns = self.encoder(enc_out)
|
||||
enc_out = self.projection(enc_out)
|
||||
|
||||
return enc_out # [B, L, D]
|
||||
|
||||
def anomaly_detection(self, x_enc):
|
||||
enc_out = self.enc_embedding(x_enc, None) # [B,T,C]
|
||||
|
||||
enc_out, attns = self.encoder(enc_out)
|
||||
enc_out = self.projection(enc_out)
|
||||
|
||||
return enc_out # [B, L, D]
|
||||
|
||||
def classification(self, x_enc, x_mark_enc):
|
||||
# enc
|
||||
enc_out = self.enc_embedding(x_enc, None)
|
||||
enc_out, attns = self.encoder(enc_out)
|
||||
|
||||
# Output
|
||||
# the output transformer encoder/decoder embeddings don't include non-linearity
|
||||
output = self.act(enc_out)
|
||||
output = self.dropout(output)
|
||||
# zero-out padding embeddings
|
||||
output = output * x_mark_enc.unsqueeze(-1)
|
||||
# (batch_size, seq_length * d_model)
|
||||
output = output.reshape(output.shape[0], -1)
|
||||
output = self.projection(output) # (batch_size, num_classes)
|
||||
return output
|
||||
|
||||
def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
|
||||
if self.task_name == 'long_term_forecast':
|
||||
dec_out = self.long_forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
|
||||
return dec_out[:, -self.pred_len:, :] # [B, L, D]
|
||||
if self.task_name == 'short_term_forecast':
|
||||
dec_out = self.short_forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
|
||||
return dec_out[:, -self.pred_len:, :] # [B, L, D]
|
||||
if self.task_name == 'imputation':
|
||||
dec_out = self.imputation(x_enc, x_mark_enc)
|
||||
return dec_out # [B, L, D]
|
||||
if self.task_name == 'anomaly_detection':
|
||||
dec_out = self.anomaly_detection(x_enc)
|
||||
return dec_out # [B, L, D]
|
||||
if self.task_name == 'classification':
|
||||
dec_out = self.classification(x_enc, x_mark_enc)
|
||||
return dec_out # [B, N]
|
||||
return None
|
188
models/SCINet.py
Normal file
188
models/SCINet.py
Normal file
@ -0,0 +1,188 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import math
|
||||
|
||||
class Splitting(nn.Module):
|
||||
def __init__(self):
|
||||
super(Splitting, self).__init__()
|
||||
|
||||
def even(self, x):
|
||||
return x[:, ::2, :]
|
||||
|
||||
def odd(self, x):
|
||||
return x[:, 1::2, :]
|
||||
|
||||
def forward(self, x):
|
||||
# return the odd and even part
|
||||
return self.even(x), self.odd(x)
|
||||
|
||||
|
||||
class CausalConvBlock(nn.Module):
|
||||
def __init__(self, d_model, kernel_size=5, dropout=0.0):
|
||||
super(CausalConvBlock, self).__init__()
|
||||
module_list = [
|
||||
nn.ReplicationPad1d((kernel_size - 1, kernel_size - 1)),
|
||||
|
||||
nn.Conv1d(d_model, d_model,
|
||||
kernel_size=kernel_size),
|
||||
nn.LeakyReLU(negative_slope=0.01, inplace=True),
|
||||
|
||||
nn.Dropout(dropout),
|
||||
nn.Conv1d(d_model, d_model,
|
||||
kernel_size=kernel_size),
|
||||
nn.Tanh()
|
||||
]
|
||||
self.causal_conv = nn.Sequential(*module_list)
|
||||
|
||||
def forward(self, x):
|
||||
return self.causal_conv(x) # return value is the same as input dimension
|
||||
|
||||
|
||||
class SCIBlock(nn.Module):
|
||||
def __init__(self, d_model, kernel_size=5, dropout=0.0):
|
||||
super(SCIBlock, self).__init__()
|
||||
self.splitting = Splitting()
|
||||
self.modules_even, self.modules_odd, self.interactor_even, self.interactor_odd = [CausalConvBlock(d_model) for _ in range(4)]
|
||||
|
||||
def forward(self, x):
|
||||
x_even, x_odd = self.splitting(x)
|
||||
x_even = x_even.permute(0, 2, 1)
|
||||
x_odd = x_odd.permute(0, 2, 1)
|
||||
|
||||
x_even_temp = x_even.mul(torch.exp(self.modules_even(x_odd)))
|
||||
x_odd_temp = x_odd.mul(torch.exp(self.modules_odd(x_even)))
|
||||
|
||||
x_even_update = x_even_temp + self.interactor_even(x_odd_temp)
|
||||
x_odd_update = x_odd_temp - self.interactor_odd(x_even_temp)
|
||||
|
||||
return x_even_update.permute(0, 2, 1), x_odd_update.permute(0, 2, 1)
|
||||
|
||||
|
||||
class SCINet(nn.Module):
|
||||
def __init__(self, d_model, current_level=3, kernel_size=5, dropout=0.0):
|
||||
super(SCINet, self).__init__()
|
||||
self.current_level = current_level
|
||||
self.working_block = SCIBlock(d_model, kernel_size, dropout)
|
||||
|
||||
if current_level != 0:
|
||||
self.SCINet_Tree_odd = SCINet(d_model, current_level-1, kernel_size, dropout)
|
||||
self.SCINet_Tree_even = SCINet(d_model, current_level-1, kernel_size, dropout)
|
||||
|
||||
def forward(self, x):
|
||||
odd_flag = False
|
||||
if x.shape[1] % 2 == 1:
|
||||
odd_flag = True
|
||||
x = torch.cat((x, x[:, -1:, :]), dim=1)
|
||||
x_even_update, x_odd_update = self.working_block(x)
|
||||
if odd_flag:
|
||||
x_odd_update = x_odd_update[:, :-1]
|
||||
|
||||
if self.current_level == 0:
|
||||
return self.zip_up_the_pants(x_even_update, x_odd_update)
|
||||
else:
|
||||
return self.zip_up_the_pants(self.SCINet_Tree_even(x_even_update), self.SCINet_Tree_odd(x_odd_update))
|
||||
|
||||
def zip_up_the_pants(self, even, odd):
|
||||
even = even.permute(1, 0, 2)
|
||||
odd = odd.permute(1, 0, 2)
|
||||
even_len = even.shape[0]
|
||||
odd_len = odd.shape[0]
|
||||
min_len = min(even_len, odd_len)
|
||||
|
||||
zipped_data = []
|
||||
for i in range(min_len):
|
||||
zipped_data.append(even[i].unsqueeze(0))
|
||||
zipped_data.append(odd[i].unsqueeze(0))
|
||||
if even_len > odd_len:
|
||||
zipped_data.append(even[-1].unsqueeze(0))
|
||||
return torch.cat(zipped_data,0).permute(1, 0, 2)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, configs):
|
||||
super(Model, self).__init__()
|
||||
self.task_name = configs.task_name
|
||||
self.seq_len = configs.seq_len
|
||||
self.label_len = configs.label_len
|
||||
self.pred_len = configs.pred_len
|
||||
|
||||
# You can set the number of SCINet stacks by argument "d_layers", but should choose 1 or 2.
|
||||
self.num_stacks = configs.d_layers
|
||||
if self.num_stacks == 1:
|
||||
self.sci_net_1 = SCINet(configs.enc_in, dropout=configs.dropout)
|
||||
self.projection_1 = nn.Conv1d(self.seq_len, self.seq_len + self.pred_len, kernel_size=1, stride=1, bias=False)
|
||||
else:
|
||||
self.sci_net_1, self.sci_net_2 = [SCINet(configs.enc_in, dropout=configs.dropout) for _ in range(2)]
|
||||
self.projection_1 = nn.Conv1d(self.seq_len, self.pred_len, kernel_size=1, stride=1, bias=False)
|
||||
self.projection_2 = nn.Conv1d(self.seq_len+self.pred_len, self.seq_len+self.pred_len,
|
||||
kernel_size = 1, bias = False)
|
||||
|
||||
# For positional encoding
|
||||
self.pe_hidden_size = configs.enc_in
|
||||
if self.pe_hidden_size % 2 == 1:
|
||||
self.pe_hidden_size += 1
|
||||
|
||||
num_timescales = self.pe_hidden_size // 2
|
||||
max_timescale = 10000.0
|
||||
min_timescale = 1.0
|
||||
|
||||
log_timescale_increment = (
|
||||
math.log(float(max_timescale) / float(min_timescale)) /
|
||||
max(num_timescales - 1, 1))
|
||||
inv_timescales = min_timescale * torch.exp(
|
||||
torch.arange(num_timescales, dtype=torch.float32) *
|
||||
-log_timescale_increment)
|
||||
self.register_buffer('inv_timescales', inv_timescales)
|
||||
|
||||
def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
|
||||
if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
|
||||
dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec) # [B,pred_len,C]
|
||||
dec_out = torch.cat([torch.zeros_like(x_enc), dec_out], dim=1)
|
||||
return dec_out # [B, T, D]
|
||||
return None
|
||||
|
||||
def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
|
||||
# Normalization from Non-stationary Transformer
|
||||
means = x_enc.mean(1, keepdim=True).detach()
|
||||
x_enc = x_enc - means
|
||||
stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
|
||||
x_enc /= stdev
|
||||
|
||||
# position-encoding
|
||||
pe = self.get_position_encoding(x_enc)
|
||||
if pe.shape[2] > x_enc.shape[2]:
|
||||
x_enc += pe[:, :, :-1]
|
||||
else:
|
||||
x_enc += self.get_position_encoding(x_enc)
|
||||
|
||||
# SCINet
|
||||
dec_out = self.sci_net_1(x_enc)
|
||||
dec_out += x_enc
|
||||
dec_out = self.projection_1(dec_out)
|
||||
if self.num_stacks != 1:
|
||||
dec_out = torch.cat((x_enc, dec_out), dim=1)
|
||||
temp = dec_out
|
||||
dec_out = self.sci_net_2(dec_out)
|
||||
dec_out += temp
|
||||
dec_out = self.projection_2(dec_out)
|
||||
|
||||
# De-Normalization from Non-stationary Transformer
|
||||
dec_out = dec_out * \
|
||||
(stdev[:, 0, :].unsqueeze(1).repeat(
|
||||
1, self.pred_len + self.seq_len, 1))
|
||||
dec_out = dec_out + \
|
||||
(means[:, 0, :].unsqueeze(1).repeat(
|
||||
1, self.pred_len + self.seq_len, 1))
|
||||
return dec_out
|
||||
|
||||
def get_position_encoding(self, x):
|
||||
max_length = x.size()[1]
|
||||
position = torch.arange(max_length, dtype=torch.float32,
|
||||
device=x.device) # tensor([0., 1., 2., 3., 4.], device='cuda:0')
|
||||
scaled_time = position.unsqueeze(1) * self.inv_timescales.unsqueeze(0) # 5 256
|
||||
signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1) # [T, C]
|
||||
signal = F.pad(signal, (0, 0, 0, self.pe_hidden_size % 2))
|
||||
signal = signal.view(1, max_length, self.pe_hidden_size)
|
||||
|
||||
return signal
|
119
models/SegRNN.py
Normal file
119
models/SegRNN.py
Normal file
@ -0,0 +1,119 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from layers.Autoformer_EncDec import series_decomp
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
"""
|
||||
Paper link: https://arxiv.org/abs/2308.11200.pdf
|
||||
"""
|
||||
|
||||
def __init__(self, configs):
|
||||
super(Model, self).__init__()
|
||||
|
||||
# get parameters
|
||||
self.seq_len = configs.seq_len
|
||||
self.enc_in = configs.enc_in
|
||||
self.d_model = configs.d_model
|
||||
self.dropout = configs.dropout
|
||||
|
||||
self.task_name = configs.task_name
|
||||
if self.task_name == 'classification' or self.task_name == 'anomaly_detection' or self.task_name == 'imputation':
|
||||
self.pred_len = configs.seq_len
|
||||
else:
|
||||
self.pred_len = configs.pred_len
|
||||
|
||||
self.seg_len = configs.seg_len
|
||||
self.seg_num_x = self.seq_len // self.seg_len
|
||||
self.seg_num_y = self.pred_len // self.seg_len
|
||||
|
||||
# building model
|
||||
self.valueEmbedding = nn.Sequential(
|
||||
nn.Linear(self.seg_len, self.d_model),
|
||||
nn.ReLU()
|
||||
)
|
||||
self.rnn = nn.GRU(input_size=self.d_model, hidden_size=self.d_model, num_layers=1, bias=True,
|
||||
batch_first=True, bidirectional=False)
|
||||
self.pos_emb = nn.Parameter(torch.randn(self.seg_num_y, self.d_model // 2))
|
||||
self.channel_emb = nn.Parameter(torch.randn(self.enc_in, self.d_model // 2))
|
||||
|
||||
self.predict = nn.Sequential(
|
||||
nn.Dropout(self.dropout),
|
||||
nn.Linear(self.d_model, self.seg_len)
|
||||
)
|
||||
|
||||
if self.task_name == 'classification':
|
||||
self.act = F.gelu
|
||||
self.dropout = nn.Dropout(configs.dropout)
|
||||
self.projection = nn.Linear(
|
||||
configs.enc_in * configs.seq_len, configs.num_class)
|
||||
|
||||
def encoder(self, x):
|
||||
# b:batch_size c:channel_size s:seq_len s:seq_len
|
||||
# d:d_model w:seg_len n:seg_num_x m:seg_num_y
|
||||
batch_size = x.size(0)
|
||||
|
||||
# normalization and permute b,s,c -> b,c,s
|
||||
seq_last = x[:, -1:, :].detach()
|
||||
x = (x - seq_last).permute(0, 2, 1) # b,c,s
|
||||
|
||||
# segment and embedding b,c,s -> bc,n,w -> bc,n,d
|
||||
x = self.valueEmbedding(x.reshape(-1, self.seg_num_x, self.seg_len))
|
||||
|
||||
# encoding
|
||||
_, hn = self.rnn(x) # bc,n,d 1,bc,d
|
||||
|
||||
# m,d//2 -> 1,m,d//2 -> c,m,d//2
|
||||
# c,d//2 -> c,1,d//2 -> c,m,d//2
|
||||
# c,m,d -> cm,1,d -> bcm, 1, d
|
||||
pos_emb = torch.cat([
|
||||
self.pos_emb.unsqueeze(0).repeat(self.enc_in, 1, 1),
|
||||
self.channel_emb.unsqueeze(1).repeat(1, self.seg_num_y, 1)
|
||||
], dim=-1).view(-1, 1, self.d_model).repeat(batch_size,1,1)
|
||||
|
||||
_, hy = self.rnn(pos_emb, hn.repeat(1, 1, self.seg_num_y).view(1, -1, self.d_model)) # bcm,1,d 1,bcm,d
|
||||
|
||||
# 1,bcm,d -> 1,bcm,w -> b,c,s
|
||||
y = self.predict(hy).view(-1, self.enc_in, self.pred_len)
|
||||
|
||||
# permute and denorm
|
||||
y = y.permute(0, 2, 1) + seq_last
|
||||
return y
|
||||
|
||||
def forecast(self, x_enc):
|
||||
# Encoder
|
||||
return self.encoder(x_enc)
|
||||
|
||||
def imputation(self, x_enc):
|
||||
# Encoder
|
||||
return self.encoder(x_enc)
|
||||
|
||||
def anomaly_detection(self, x_enc):
|
||||
# Encoder
|
||||
return self.encoder(x_enc)
|
||||
|
||||
def classification(self, x_enc):
|
||||
# Encoder
|
||||
enc_out = self.encoder(x_enc)
|
||||
# Output
|
||||
# (batch_size, seq_length * d_model)
|
||||
output = enc_out.reshape(enc_out.shape[0], -1)
|
||||
# (batch_size, num_classes)
|
||||
output = self.projection(output)
|
||||
return output
|
||||
|
||||
def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
|
||||
if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
|
||||
dec_out = self.forecast(x_enc)
|
||||
return dec_out[:, -self.pred_len:, :] # [B, L, D]
|
||||
if self.task_name == 'imputation':
|
||||
dec_out = self.imputation(x_enc)
|
||||
return dec_out # [B, L, D]
|
||||
if self.task_name == 'anomaly_detection':
|
||||
dec_out = self.anomaly_detection(x_enc)
|
||||
return dec_out # [B, L, D]
|
||||
if self.task_name == 'classification':
|
||||
dec_out = self.classification(x_enc)
|
||||
return dec_out # [B, N]
|
||||
return None
|
54
models/TSMixer.py
Normal file
54
models/TSMixer.py
Normal file
@ -0,0 +1,54 @@
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class ResBlock(nn.Module):
|
||||
def __init__(self, configs):
|
||||
super(ResBlock, self).__init__()
|
||||
|
||||
self.temporal = nn.Sequential(
|
||||
nn.Linear(configs.seq_len, configs.d_model),
|
||||
nn.ReLU(),
|
||||
nn.Linear(configs.d_model, configs.seq_len),
|
||||
nn.Dropout(configs.dropout)
|
||||
)
|
||||
|
||||
self.channel = nn.Sequential(
|
||||
nn.Linear(configs.enc_in, configs.d_model),
|
||||
nn.ReLU(),
|
||||
nn.Linear(configs.d_model, configs.enc_in),
|
||||
nn.Dropout(configs.dropout)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
# x: [B, L, D]
|
||||
x = x + self.temporal(x.transpose(1, 2)).transpose(1, 2)
|
||||
x = x + self.channel(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, configs):
|
||||
super(Model, self).__init__()
|
||||
self.task_name = configs.task_name
|
||||
self.layer = configs.e_layers
|
||||
self.model = nn.ModuleList([ResBlock(configs)
|
||||
for _ in range(configs.e_layers)])
|
||||
self.pred_len = configs.pred_len
|
||||
self.projection = nn.Linear(configs.seq_len, configs.pred_len)
|
||||
|
||||
def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
|
||||
|
||||
# x: [B, L, D]
|
||||
for i in range(self.layer):
|
||||
x_enc = self.model[i](x_enc)
|
||||
enc_out = self.projection(x_enc.transpose(1, 2)).transpose(1, 2)
|
||||
|
||||
return enc_out
|
||||
|
||||
def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
|
||||
if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
|
||||
dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
|
||||
return dec_out[:, -self.pred_len:, :] # [B, L, D]
|
||||
else:
|
||||
raise ValueError('Only forecast tasks implemented yet')
|
309
models/TemporalFusionTransformer.py
Normal file
309
models/TemporalFusionTransformer.py
Normal file
@ -0,0 +1,309 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from layers.Embed import DataEmbedding, TemporalEmbedding
|
||||
from torch import Tensor
|
||||
from typing import Optional
|
||||
from collections import namedtuple
|
||||
|
||||
# static: time-independent features
|
||||
# observed: time features of the past(e.g. predicted targets)
|
||||
# known: known information about the past and future(i.e. time stamp)
|
||||
TypePos = namedtuple('TypePos', ['static', 'observed'])
|
||||
|
||||
# When you want to use new dataset, please add the index of 'static, observed' columns here.
|
||||
# 'known' columns needn't be added, because 'known' inputs are automatically judged and provided by the program.
|
||||
datatype_dict = {'ETTh1': TypePos([], [x for x in range(7)]),
|
||||
'ETTm1': TypePos([], [x for x in range(7)])}
|
||||
|
||||
|
||||
def get_known_len(embed_type, freq):
|
||||
if embed_type != 'timeF':
|
||||
if freq == 't':
|
||||
return 5
|
||||
else:
|
||||
return 4
|
||||
else:
|
||||
freq_map = {'h': 4, 't': 5, 's': 6,
|
||||
'm': 1, 'a': 1, 'w': 2, 'd': 3, 'b': 3}
|
||||
return freq_map[freq]
|
||||
|
||||
|
||||
class TFTTemporalEmbedding(TemporalEmbedding):
|
||||
def __init__(self, d_model, embed_type='fixed', freq='h'):
|
||||
super(TFTTemporalEmbedding, self).__init__(d_model, embed_type, freq)
|
||||
|
||||
def forward(self, x):
|
||||
x = x.long()
|
||||
minute_x = self.minute_embed(x[:, :, 4]) if hasattr(
|
||||
self, 'minute_embed') else 0.
|
||||
hour_x = self.hour_embed(x[:, :, 3])
|
||||
weekday_x = self.weekday_embed(x[:, :, 2])
|
||||
day_x = self.day_embed(x[:, :, 1])
|
||||
month_x = self.month_embed(x[:, :, 0])
|
||||
|
||||
embedding_x = torch.stack([month_x, day_x, weekday_x, hour_x, minute_x], dim=-2) if hasattr(
|
||||
self, 'minute_embed') else torch.stack([month_x, day_x, weekday_x, hour_x], dim=-2)
|
||||
return embedding_x
|
||||
|
||||
|
||||
class TFTTimeFeatureEmbedding(nn.Module):
|
||||
def __init__(self, d_model, embed_type='timeF', freq='h'):
|
||||
super(TFTTimeFeatureEmbedding, self).__init__()
|
||||
d_inp = get_known_len(embed_type, freq)
|
||||
self.embed = nn.ModuleList([nn.Linear(1, d_model, bias=False) for _ in range(d_inp)])
|
||||
|
||||
def forward(self, x):
|
||||
return torch.stack([embed(x[:,:,i].unsqueeze(-1)) for i, embed in enumerate(self.embed)], dim=-2)
|
||||
|
||||
|
||||
class TFTEmbedding(nn.Module):
|
||||
def __init__(self, configs):
|
||||
super(TFTEmbedding, self).__init__()
|
||||
self.pred_len = configs.pred_len
|
||||
self.static_pos = datatype_dict[configs.data].static
|
||||
self.observed_pos = datatype_dict[configs.data].observed
|
||||
self.static_len = len(self.static_pos)
|
||||
self.observed_len = len(self.observed_pos)
|
||||
|
||||
self.static_embedding = nn.ModuleList([DataEmbedding(1,configs.d_model,dropout=configs.dropout) for _ in range(self.static_len)]) \
|
||||
if self.static_len else None
|
||||
self.observed_embedding = nn.ModuleList([DataEmbedding(1,configs.d_model,dropout=configs.dropout) for _ in range(self.observed_len)])
|
||||
self.known_embedding = TFTTemporalEmbedding(configs.d_model, configs.embed, configs.freq) \
|
||||
if configs.embed != 'timeF' else TFTTimeFeatureEmbedding(configs.d_model, configs.embed, configs.freq)
|
||||
|
||||
def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
|
||||
if self.static_len:
|
||||
# static_input: [B,C,d_model]
|
||||
static_input = torch.stack([embed(x_enc[:,:1,self.static_pos[i]].unsqueeze(-1), None).squeeze(1) for i, embed in enumerate(self.static_embedding)], dim=-2)
|
||||
else:
|
||||
static_input = None
|
||||
|
||||
# observed_input: [B,T,C,d_model]
|
||||
observed_input = torch.stack([embed(x_enc[:,:,self.observed_pos[i]].unsqueeze(-1), None) for i, embed in enumerate(self.observed_embedding)], dim=-2)
|
||||
|
||||
x_mark = torch.cat([x_mark_enc, x_mark_dec[:,-self.pred_len:,:]], dim=-2)
|
||||
# known_input: [B,T,C,d_model]
|
||||
known_input = self.known_embedding(x_mark)
|
||||
|
||||
return static_input, observed_input, known_input
|
||||
|
||||
|
||||
class GLU(nn.Module):
|
||||
def __init__(self, input_size, output_size):
|
||||
super().__init__()
|
||||
self.fc1 = nn.Linear(input_size, output_size)
|
||||
self.fc2 = nn.Linear(input_size, output_size)
|
||||
self.glu = nn.GLU()
|
||||
|
||||
def forward(self, x):
|
||||
a = self.fc1(x)
|
||||
b = self.fc2(x)
|
||||
return self.glu(torch.cat([a, b], dim=-1))
|
||||
|
||||
|
||||
class GateAddNorm(nn.Module):
|
||||
def __init__(self, input_size, output_size):
|
||||
super(GateAddNorm, self).__init__()
|
||||
self.glu = GLU(input_size, input_size)
|
||||
self.projection = nn.Linear(input_size, output_size) if input_size != output_size else nn.Identity()
|
||||
self.layer_norm = nn.LayerNorm(output_size)
|
||||
|
||||
def forward(self, x, skip_a):
|
||||
x = self.glu(x)
|
||||
x = x + skip_a
|
||||
return self.layer_norm(self.projection(x))
|
||||
|
||||
|
||||
class GRN(nn.Module):
|
||||
def __init__(self, input_size, output_size, hidden_size=None, context_size=None, dropout=0.0):
|
||||
super(GRN, self).__init__()
|
||||
hidden_size = input_size if hidden_size is None else hidden_size
|
||||
self.lin_a = nn.Linear(input_size, hidden_size)
|
||||
self.lin_c = nn.Linear(context_size, hidden_size) if context_size is not None else None
|
||||
self.lin_i = nn.Linear(hidden_size, hidden_size)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.project_a = nn.Linear(input_size, hidden_size) if hidden_size != input_size else nn.Identity()
|
||||
self.gate = GateAddNorm(hidden_size, output_size)
|
||||
|
||||
def forward(self, a: Tensor, c: Optional[Tensor] = None):
|
||||
# a: [B,T,d], c: [B,d]
|
||||
x = self.lin_a(a)
|
||||
if c is not None:
|
||||
x = x + self.lin_c(c).unsqueeze(1)
|
||||
x = F.elu(x)
|
||||
x = self.lin_i(x)
|
||||
x = self.dropout(x)
|
||||
return self.gate(x, self.project_a(a))
|
||||
|
||||
|
||||
class VariableSelectionNetwork(nn.Module):
|
||||
def __init__(self, d_model, variable_num, dropout=0.0):
|
||||
super(VariableSelectionNetwork, self).__init__()
|
||||
self.joint_grn = GRN(d_model * variable_num, variable_num, hidden_size=d_model, context_size=d_model, dropout=dropout)
|
||||
self.variable_grns = nn.ModuleList([GRN(d_model, d_model, dropout=dropout) for _ in range(variable_num)])
|
||||
|
||||
def forward(self, x: Tensor, context: Optional[Tensor] = None):
|
||||
# x: [B,T,C,d] or [B,C,d]
|
||||
# selection_weights: [B,T,C] or [B,C]
|
||||
# x_processed: [B,T,d,C] or [B,d,C]
|
||||
# selection_result: [B,T,d] or [B,d]
|
||||
x_flattened = torch.flatten(x, start_dim=-2)
|
||||
selection_weights = self.joint_grn(x_flattened, context)
|
||||
selection_weights = F.softmax(selection_weights, dim=-1)
|
||||
|
||||
x_processed = torch.stack([grn(x[...,i,:]) for i, grn in enumerate(self.variable_grns)], dim=-1)
|
||||
|
||||
selection_result = torch.matmul(x_processed, selection_weights.unsqueeze(-1)).squeeze(-1)
|
||||
return selection_result
|
||||
|
||||
|
||||
class StaticCovariateEncoder(nn.Module):
|
||||
def __init__(self, d_model, static_len, dropout=0.0):
|
||||
super(StaticCovariateEncoder, self).__init__()
|
||||
self.static_vsn = VariableSelectionNetwork(d_model, static_len) if static_len else None
|
||||
self.grns = nn.ModuleList([GRN(d_model, d_model, dropout=dropout) for _ in range(4)])
|
||||
|
||||
def forward(self, static_input):
|
||||
# static_input: [B,C,d]
|
||||
if static_input is not None:
|
||||
static_features = self.static_vsn(static_input)
|
||||
return [grn(static_features) for grn in self.grns]
|
||||
else:
|
||||
return [None] * 4
|
||||
|
||||
|
||||
class InterpretableMultiHeadAttention(nn.Module):
|
||||
def __init__(self, configs):
|
||||
super(InterpretableMultiHeadAttention, self).__init__()
|
||||
self.n_heads = configs.n_heads
|
||||
assert configs.d_model % configs.n_heads == 0
|
||||
self.d_head = configs.d_model // configs.n_heads
|
||||
self.qkv_linears = nn.Linear(configs.d_model, (2 * self.n_heads + 1) * self.d_head, bias=False)
|
||||
self.out_projection = nn.Linear(self.d_head, configs.d_model, bias=False)
|
||||
self.out_dropout = nn.Dropout(configs.dropout)
|
||||
self.scale = self.d_head ** -0.5
|
||||
example_len = configs.seq_len + configs.pred_len
|
||||
self.register_buffer("mask", torch.triu(torch.full((example_len, example_len), float('-inf')), 1))
|
||||
|
||||
def forward(self, x):
|
||||
# Q,K,V are all from x
|
||||
B, T, d_model = x.shape
|
||||
qkv = self.qkv_linears(x)
|
||||
q, k, v = qkv.split((self.n_heads * self.d_head, self.n_heads * self.d_head, self.d_head), dim=-1)
|
||||
q = q.view(B, T, self.n_heads, self.d_head)
|
||||
k = k.view(B, T, self.n_heads, self.d_head)
|
||||
v = v.view(B, T, self.d_head)
|
||||
|
||||
attention_score = torch.matmul(q.permute((0, 2, 1, 3)), k.permute((0, 2, 3, 1))) # [B,n,T,T]
|
||||
attention_score.mul_(self.scale)
|
||||
attention_score = attention_score + self.mask
|
||||
attention_prob = F.softmax(attention_score, dim=3) # [B,n,T,T]
|
||||
|
||||
attention_out = torch.matmul(attention_prob, v.unsqueeze(1)) # [B,n,T,d]
|
||||
attention_out = torch.mean(attention_out, dim=1) # [B,T,d]
|
||||
out = self.out_projection(attention_out)
|
||||
out = self.out_dropout(out) # [B,T,d]
|
||||
return out
|
||||
|
||||
|
||||
class TemporalFusionDecoder(nn.Module):
|
||||
def __init__(self, configs):
|
||||
super(TemporalFusionDecoder, self).__init__()
|
||||
self.pred_len = configs.pred_len
|
||||
|
||||
self.history_encoder = nn.LSTM(configs.d_model, configs.d_model, batch_first=True)
|
||||
self.future_encoder = nn.LSTM(configs.d_model, configs.d_model, batch_first=True)
|
||||
self.gate_after_lstm = GateAddNorm(configs.d_model, configs.d_model)
|
||||
self.enrichment_grn = GRN(configs.d_model, configs.d_model, context_size=configs.d_model, dropout=configs.dropout)
|
||||
self.attention = InterpretableMultiHeadAttention(configs)
|
||||
self.gate_after_attention = GateAddNorm(configs.d_model, configs.d_model)
|
||||
self.position_wise_grn = GRN(configs.d_model, configs.d_model, dropout=configs.dropout)
|
||||
self.gate_final = GateAddNorm(configs.d_model, configs.d_model)
|
||||
self.out_projection = nn.Linear(configs.d_model, configs.c_out)
|
||||
|
||||
def forward(self, history_input, future_input, c_c, c_h, c_e):
|
||||
# history_input, future_input: [B,T,d]
|
||||
# c_c, c_h, c_e: [B,d]
|
||||
# LSTM
|
||||
c = (c_c.unsqueeze(0), c_h.unsqueeze(0)) if c_c is not None and c_h is not None else None
|
||||
historical_features, state = self.history_encoder(history_input, c)
|
||||
future_features, _ = self.future_encoder(future_input, state)
|
||||
|
||||
# Skip connection
|
||||
temporal_input = torch.cat([history_input, future_input], dim=1)
|
||||
temporal_features = torch.cat([historical_features, future_features], dim=1)
|
||||
temporal_features = self.gate_after_lstm(temporal_features, temporal_input) # [B,T,d]
|
||||
|
||||
# Static enrichment
|
||||
enriched_features = self.enrichment_grn(temporal_features, c_e) # [B,T,d]
|
||||
|
||||
# Temporal self-attention
|
||||
attention_out = self.attention(enriched_features) # [B,T,d]
|
||||
# Don't compute historical loss
|
||||
attention_out = self.gate_after_attention(attention_out[:,-self.pred_len:], enriched_features[:,-self.pred_len:])
|
||||
|
||||
# Position-wise feed-forward
|
||||
out = self.position_wise_grn(attention_out) # [B,T,d]
|
||||
|
||||
# Final skip connection
|
||||
out = self.gate_final(out, temporal_features[:,-self.pred_len:])
|
||||
return self.out_projection(out)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, configs):
|
||||
super(Model, self).__init__()
|
||||
self.configs = configs
|
||||
self.task_name = configs.task_name
|
||||
self.seq_len = configs.seq_len
|
||||
self.label_len = configs.label_len
|
||||
self.pred_len = configs.pred_len
|
||||
|
||||
# Number of variables
|
||||
self.static_len = len(datatype_dict[configs.data].static)
|
||||
self.observed_len = len(datatype_dict[configs.data].observed)
|
||||
self.known_len = get_known_len(configs.embed, configs.freq)
|
||||
|
||||
self.embedding = TFTEmbedding(configs)
|
||||
self.static_encoder = StaticCovariateEncoder(configs.d_model, self.static_len)
|
||||
self.history_vsn = VariableSelectionNetwork(configs.d_model, self.observed_len + self.known_len)
|
||||
self.future_vsn = VariableSelectionNetwork(configs.d_model, self.known_len)
|
||||
self.temporal_fusion_decoder = TemporalFusionDecoder(configs)
|
||||
|
||||
def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
|
||||
# Normalization from Non-stationary Transformer
|
||||
means = x_enc.mean(1, keepdim=True).detach()
|
||||
x_enc = x_enc - means
|
||||
stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
|
||||
x_enc /= stdev
|
||||
|
||||
# Data embedding
|
||||
# static_input: [B,C,d], observed_input:[B,T,C,d], known_input: [B,T,C,d]
|
||||
static_input, observed_input, known_input = self.embedding(x_enc, x_mark_enc, x_dec, x_mark_dec)
|
||||
|
||||
# Static context
|
||||
# c_s,...,c_e: [B,d]
|
||||
c_s, c_c, c_h, c_e = self.static_encoder(static_input)
|
||||
|
||||
# Temporal input Selection
|
||||
history_input = torch.cat([observed_input, known_input[:,:self.seq_len]], dim=-2)
|
||||
future_input = known_input[:,self.seq_len:]
|
||||
history_input = self.history_vsn(history_input, c_s)
|
||||
future_input = self.future_vsn(future_input, c_s)
|
||||
|
||||
# TFT main procedure after variable selection
|
||||
# history_input: [B,T,d], future_input: [B,T,d]
|
||||
dec_out = self.temporal_fusion_decoder(history_input, future_input, c_c, c_h, c_e)
|
||||
|
||||
# De-Normalization from Non-stationary Transformer
|
||||
dec_out = dec_out * (stdev[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1))
|
||||
dec_out = dec_out + (means[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1))
|
||||
return dec_out
|
||||
|
||||
def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
|
||||
if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
|
||||
dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec) # [B,pred_len,C]
|
||||
dec_out = torch.cat([torch.zeros_like(x_enc), dec_out], dim=1)
|
||||
return dec_out # [B, T, D]
|
||||
return None
|
145
models/TiDE.py
Normal file
145
models/TiDE.py
Normal file
@ -0,0 +1,145 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class LayerNorm(nn.Module):
|
||||
""" LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False """
|
||||
|
||||
def __init__(self, ndim, bias):
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(ndim))
|
||||
self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None
|
||||
|
||||
def forward(self, input):
|
||||
return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)
|
||||
|
||||
|
||||
|
||||
class ResBlock(nn.Module):
|
||||
def __init__(self, input_dim, hidden_dim, output_dim, dropout=0.1, bias=True):
|
||||
super().__init__()
|
||||
|
||||
self.fc1 = nn.Linear(input_dim, hidden_dim, bias=bias)
|
||||
self.fc2 = nn.Linear(hidden_dim, output_dim, bias=bias)
|
||||
self.fc3 = nn.Linear(input_dim, output_dim, bias=bias)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.relu = nn.ReLU()
|
||||
self.ln = LayerNorm(output_dim, bias=bias)
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
out = self.fc1(x)
|
||||
out = self.relu(out)
|
||||
out = self.fc2(out)
|
||||
out = self.dropout(out)
|
||||
out = out + self.fc3(x)
|
||||
out = self.ln(out)
|
||||
return out
|
||||
|
||||
|
||||
#TiDE
|
||||
class Model(nn.Module):
|
||||
"""
|
||||
paper: https://arxiv.org/pdf/2304.08424.pdf
|
||||
"""
|
||||
def __init__(self, configs, bias=True, feature_encode_dim=2):
|
||||
super(Model, self).__init__()
|
||||
self.configs = configs
|
||||
self.task_name = configs.task_name
|
||||
self.seq_len = configs.seq_len #L
|
||||
self.label_len = configs.label_len
|
||||
self.pred_len = configs.pred_len #H
|
||||
self.hidden_dim=configs.d_model
|
||||
self.res_hidden=configs.d_model
|
||||
self.encoder_num=configs.e_layers
|
||||
self.decoder_num=configs.d_layers
|
||||
self.freq=configs.freq
|
||||
self.feature_encode_dim=feature_encode_dim
|
||||
self.decode_dim = configs.c_out
|
||||
self.temporalDecoderHidden=configs.d_ff
|
||||
dropout=configs.dropout
|
||||
|
||||
|
||||
freq_map = {'h': 4, 't': 5, 's': 6,
|
||||
'm': 1, 'a': 1, 'w': 2, 'd': 3, 'b': 3}
|
||||
|
||||
self.feature_dim=freq_map[self.freq]
|
||||
|
||||
|
||||
flatten_dim = self.seq_len + (self.seq_len + self.pred_len) * self.feature_encode_dim
|
||||
|
||||
self.feature_encoder = ResBlock(self.feature_dim, self.res_hidden, self.feature_encode_dim, dropout, bias)
|
||||
self.encoders = nn.Sequential(ResBlock(flatten_dim, self.res_hidden, self.hidden_dim, dropout, bias),*([ ResBlock(self.hidden_dim, self.res_hidden, self.hidden_dim, dropout, bias)]*(self.encoder_num-1)))
|
||||
if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
|
||||
self.decoders = nn.Sequential(*([ ResBlock(self.hidden_dim, self.res_hidden, self.hidden_dim, dropout, bias)]*(self.decoder_num-1)),ResBlock(self.hidden_dim, self.res_hidden, self.decode_dim * self.pred_len, dropout, bias))
|
||||
self.temporalDecoder = ResBlock(self.decode_dim + self.feature_encode_dim, self.temporalDecoderHidden, 1, dropout, bias)
|
||||
self.residual_proj = nn.Linear(self.seq_len, self.pred_len, bias=bias)
|
||||
if self.task_name == 'imputation':
|
||||
self.decoders = nn.Sequential(*([ ResBlock(self.hidden_dim, self.res_hidden, self.hidden_dim, dropout, bias)]*(self.decoder_num-1)),ResBlock(self.hidden_dim, self.res_hidden, self.decode_dim * self.seq_len, dropout, bias))
|
||||
self.temporalDecoder = ResBlock(self.decode_dim + self.feature_encode_dim, self.temporalDecoderHidden, 1, dropout, bias)
|
||||
self.residual_proj = nn.Linear(self.seq_len, self.seq_len, bias=bias)
|
||||
if self.task_name == 'anomaly_detection':
|
||||
self.decoders = nn.Sequential(*([ ResBlock(self.hidden_dim, self.res_hidden, self.hidden_dim, dropout, bias)]*(self.decoder_num-1)),ResBlock(self.hidden_dim, self.res_hidden, self.decode_dim * self.seq_len, dropout, bias))
|
||||
self.temporalDecoder = ResBlock(self.decode_dim + self.feature_encode_dim, self.temporalDecoderHidden, 1, dropout, bias)
|
||||
self.residual_proj = nn.Linear(self.seq_len, self.seq_len, bias=bias)
|
||||
|
||||
|
||||
def forecast(self, x_enc, x_mark_enc, x_dec, batch_y_mark):
|
||||
# Normalization
|
||||
means = x_enc.mean(1, keepdim=True).detach()
|
||||
x_enc = x_enc - means
|
||||
stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
|
||||
x_enc /= stdev
|
||||
|
||||
feature = self.feature_encoder(batch_y_mark)
|
||||
hidden = self.encoders(torch.cat([x_enc, feature.reshape(feature.shape[0], -1)], dim=-1))
|
||||
decoded = self.decoders(hidden).reshape(hidden.shape[0], self.pred_len, self.decode_dim)
|
||||
dec_out = self.temporalDecoder(torch.cat([feature[:,self.seq_len:], decoded], dim=-1)).squeeze(-1) + self.residual_proj(x_enc)
|
||||
|
||||
|
||||
# De-Normalization
|
||||
dec_out = dec_out * (stdev[:, 0].unsqueeze(1).repeat(1, self.pred_len))
|
||||
dec_out = dec_out + (means[:, 0].unsqueeze(1).repeat(1, self.pred_len))
|
||||
return dec_out
|
||||
|
||||
def imputation(self, x_enc, x_mark_enc, x_dec, batch_y_mark, mask):
|
||||
# Normalization
|
||||
means = x_enc.mean(1, keepdim=True).detach()
|
||||
x_enc = x_enc - means
|
||||
stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
|
||||
x_enc /= stdev
|
||||
|
||||
feature = self.feature_encoder(x_mark_enc)
|
||||
hidden = self.encoders(torch.cat([x_enc, feature.reshape(feature.shape[0], -1)], dim=-1))
|
||||
decoded = self.decoders(hidden).reshape(hidden.shape[0], self.seq_len, self.decode_dim)
|
||||
dec_out = self.temporalDecoder(torch.cat([feature[:,:self.seq_len], decoded], dim=-1)).squeeze(-1) + self.residual_proj(x_enc)
|
||||
|
||||
# De-Normalization
|
||||
dec_out = dec_out * (stdev[:, 0].unsqueeze(1).repeat(1, self.seq_len))
|
||||
dec_out = dec_out + (means[:, 0].unsqueeze(1).repeat(1, self.seq_len))
|
||||
return dec_out
|
||||
|
||||
|
||||
def forward(self, x_enc, x_mark_enc, x_dec, batch_y_mark, mask=None):
|
||||
'''x_mark_enc is the exogenous dynamic feature described in the original paper'''
|
||||
if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
|
||||
if batch_y_mark is None:
|
||||
batch_y_mark = torch.zeros((x_enc.shape[0], self.seq_len+self.pred_len, self.feature_dim)).to(x_enc.device).detach()
|
||||
else:
|
||||
batch_y_mark = torch.concat([x_mark_enc, batch_y_mark[:, -self.pred_len:, :]],dim=1)
|
||||
dec_out = torch.stack([self.forecast(x_enc[:, :, feature], x_mark_enc, x_dec, batch_y_mark) for feature in range(x_enc.shape[-1])],dim=-1)
|
||||
return dec_out # [B, L, D]
|
||||
if self.task_name == 'imputation':
|
||||
dec_out = torch.stack([self.imputation(x_enc[:, :, feature], x_mark_enc, x_dec, batch_y_mark, mask) for feature in range(x_enc.shape[-1])],dim=-1)
|
||||
return dec_out # [B, L, D]
|
||||
if self.task_name == 'anomaly_detection':
|
||||
raise NotImplementedError("Task anomaly_detection for Tide is temporarily not supported")
|
||||
if self.task_name == 'classification':
|
||||
raise NotImplementedError("Task classification for Tide is temporarily not supported")
|
||||
return None
|
||||
|
||||
|
||||
|
||||
|
||||
|
516
models/TimeMixer.py
Executable file
516
models/TimeMixer.py
Executable file
@ -0,0 +1,516 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from layers.Autoformer_EncDec import series_decomp
|
||||
from layers.Embed import DataEmbedding_wo_pos
|
||||
from layers.StandardNorm import Normalize
|
||||
|
||||
|
||||
class DFT_series_decomp(nn.Module):
|
||||
"""
|
||||
Series decomposition block
|
||||
"""
|
||||
|
||||
def __init__(self, top_k: int = 5):
|
||||
super(DFT_series_decomp, self).__init__()
|
||||
self.top_k = top_k
|
||||
|
||||
def forward(self, x):
|
||||
xf = torch.fft.rfft(x)
|
||||
freq = abs(xf)
|
||||
freq[0] = 0
|
||||
top_k_freq, top_list = torch.topk(freq, k=self.top_k)
|
||||
xf[freq <= top_k_freq.min()] = 0
|
||||
x_season = torch.fft.irfft(xf)
|
||||
x_trend = x - x_season
|
||||
return x_season, x_trend
|
||||
|
||||
|
||||
class MultiScaleSeasonMixing(nn.Module):
|
||||
"""
|
||||
Bottom-up mixing season pattern
|
||||
"""
|
||||
|
||||
def __init__(self, configs):
|
||||
super(MultiScaleSeasonMixing, self).__init__()
|
||||
|
||||
self.down_sampling_layers = torch.nn.ModuleList(
|
||||
[
|
||||
nn.Sequential(
|
||||
torch.nn.Linear(
|
||||
configs.seq_len // (configs.down_sampling_window ** i),
|
||||
configs.seq_len // (configs.down_sampling_window ** (i + 1)),
|
||||
),
|
||||
nn.GELU(),
|
||||
torch.nn.Linear(
|
||||
configs.seq_len // (configs.down_sampling_window ** (i + 1)),
|
||||
configs.seq_len // (configs.down_sampling_window ** (i + 1)),
|
||||
),
|
||||
|
||||
)
|
||||
for i in range(configs.down_sampling_layers)
|
||||
]
|
||||
)
|
||||
|
||||
def forward(self, season_list):
|
||||
|
||||
# mixing high->low
|
||||
out_high = season_list[0]
|
||||
out_low = season_list[1]
|
||||
out_season_list = [out_high.permute(0, 2, 1)]
|
||||
|
||||
for i in range(len(season_list) - 1):
|
||||
out_low_res = self.down_sampling_layers[i](out_high)
|
||||
out_low = out_low + out_low_res
|
||||
out_high = out_low
|
||||
if i + 2 <= len(season_list) - 1:
|
||||
out_low = season_list[i + 2]
|
||||
out_season_list.append(out_high.permute(0, 2, 1))
|
||||
|
||||
return out_season_list
|
||||
|
||||
|
||||
class MultiScaleTrendMixing(nn.Module):
|
||||
"""
|
||||
Top-down mixing trend pattern
|
||||
"""
|
||||
|
||||
def __init__(self, configs):
|
||||
super(MultiScaleTrendMixing, self).__init__()
|
||||
|
||||
self.up_sampling_layers = torch.nn.ModuleList(
|
||||
[
|
||||
nn.Sequential(
|
||||
torch.nn.Linear(
|
||||
configs.seq_len // (configs.down_sampling_window ** (i + 1)),
|
||||
configs.seq_len // (configs.down_sampling_window ** i),
|
||||
),
|
||||
nn.GELU(),
|
||||
torch.nn.Linear(
|
||||
configs.seq_len // (configs.down_sampling_window ** i),
|
||||
configs.seq_len // (configs.down_sampling_window ** i),
|
||||
),
|
||||
)
|
||||
for i in reversed(range(configs.down_sampling_layers))
|
||||
])
|
||||
|
||||
def forward(self, trend_list):
|
||||
|
||||
# mixing low->high
|
||||
trend_list_reverse = trend_list.copy()
|
||||
trend_list_reverse.reverse()
|
||||
out_low = trend_list_reverse[0]
|
||||
out_high = trend_list_reverse[1]
|
||||
out_trend_list = [out_low.permute(0, 2, 1)]
|
||||
|
||||
for i in range(len(trend_list_reverse) - 1):
|
||||
out_high_res = self.up_sampling_layers[i](out_low)
|
||||
out_high = out_high + out_high_res
|
||||
out_low = out_high
|
||||
if i + 2 <= len(trend_list_reverse) - 1:
|
||||
out_high = trend_list_reverse[i + 2]
|
||||
out_trend_list.append(out_low.permute(0, 2, 1))
|
||||
|
||||
out_trend_list.reverse()
|
||||
return out_trend_list
|
||||
|
||||
|
||||
class PastDecomposableMixing(nn.Module):
|
||||
def __init__(self, configs):
|
||||
super(PastDecomposableMixing, self).__init__()
|
||||
self.seq_len = configs.seq_len
|
||||
self.pred_len = configs.pred_len
|
||||
self.down_sampling_window = configs.down_sampling_window
|
||||
|
||||
self.layer_norm = nn.LayerNorm(configs.d_model)
|
||||
self.dropout = nn.Dropout(configs.dropout)
|
||||
self.channel_independence = configs.channel_independence
|
||||
|
||||
if configs.decomp_method == 'moving_avg':
|
||||
self.decompsition = series_decomp(configs.moving_avg)
|
||||
elif configs.decomp_method == "dft_decomp":
|
||||
self.decompsition = DFT_series_decomp(configs.top_k)
|
||||
else:
|
||||
raise ValueError('decompsition is error')
|
||||
|
||||
if not configs.channel_independence:
|
||||
self.cross_layer = nn.Sequential(
|
||||
nn.Linear(in_features=configs.d_model, out_features=configs.d_ff),
|
||||
nn.GELU(),
|
||||
nn.Linear(in_features=configs.d_ff, out_features=configs.d_model),
|
||||
)
|
||||
|
||||
# Mixing season
|
||||
self.mixing_multi_scale_season = MultiScaleSeasonMixing(configs)
|
||||
|
||||
# Mxing trend
|
||||
self.mixing_multi_scale_trend = MultiScaleTrendMixing(configs)
|
||||
|
||||
self.out_cross_layer = nn.Sequential(
|
||||
nn.Linear(in_features=configs.d_model, out_features=configs.d_ff),
|
||||
nn.GELU(),
|
||||
nn.Linear(in_features=configs.d_ff, out_features=configs.d_model),
|
||||
)
|
||||
|
||||
def forward(self, x_list):
|
||||
length_list = []
|
||||
for x in x_list:
|
||||
_, T, _ = x.size()
|
||||
length_list.append(T)
|
||||
|
||||
# Decompose to obtain the season and trend
|
||||
season_list = []
|
||||
trend_list = []
|
||||
for x in x_list:
|
||||
season, trend = self.decompsition(x)
|
||||
if not self.channel_independence:
|
||||
season = self.cross_layer(season)
|
||||
trend = self.cross_layer(trend)
|
||||
season_list.append(season.permute(0, 2, 1))
|
||||
trend_list.append(trend.permute(0, 2, 1))
|
||||
|
||||
# bottom-up season mixing
|
||||
out_season_list = self.mixing_multi_scale_season(season_list)
|
||||
# top-down trend mixing
|
||||
out_trend_list = self.mixing_multi_scale_trend(trend_list)
|
||||
|
||||
out_list = []
|
||||
for ori, out_season, out_trend, length in zip(x_list, out_season_list, out_trend_list,
|
||||
length_list):
|
||||
out = out_season + out_trend
|
||||
if self.channel_independence:
|
||||
out = ori + self.out_cross_layer(out)
|
||||
out_list.append(out[:, :length, :])
|
||||
return out_list
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
|
||||
def __init__(self, configs):
|
||||
super(Model, self).__init__()
|
||||
self.configs = configs
|
||||
self.task_name = configs.task_name
|
||||
self.seq_len = configs.seq_len
|
||||
self.label_len = configs.label_len
|
||||
self.pred_len = configs.pred_len
|
||||
self.down_sampling_window = configs.down_sampling_window
|
||||
self.channel_independence = configs.channel_independence
|
||||
self.pdm_blocks = nn.ModuleList([PastDecomposableMixing(configs)
|
||||
for _ in range(configs.e_layers)])
|
||||
|
||||
self.preprocess = series_decomp(configs.moving_avg)
|
||||
self.enc_in = configs.enc_in
|
||||
|
||||
if self.channel_independence:
|
||||
self.enc_embedding = DataEmbedding_wo_pos(1, configs.d_model, configs.embed, configs.freq,
|
||||
configs.dropout)
|
||||
else:
|
||||
self.enc_embedding = DataEmbedding_wo_pos(configs.enc_in, configs.d_model, configs.embed, configs.freq,
|
||||
configs.dropout)
|
||||
|
||||
self.layer = configs.e_layers
|
||||
|
||||
self.normalize_layers = torch.nn.ModuleList(
|
||||
[
|
||||
Normalize(self.configs.enc_in, affine=True, non_norm=True if configs.use_norm == 0 else False)
|
||||
for i in range(configs.down_sampling_layers + 1)
|
||||
]
|
||||
)
|
||||
|
||||
if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
|
||||
self.predict_layers = torch.nn.ModuleList(
|
||||
[
|
||||
torch.nn.Linear(
|
||||
configs.seq_len // (configs.down_sampling_window ** i),
|
||||
configs.pred_len,
|
||||
)
|
||||
for i in range(configs.down_sampling_layers + 1)
|
||||
]
|
||||
)
|
||||
|
||||
if self.channel_independence:
|
||||
self.projection_layer = nn.Linear(
|
||||
configs.d_model, 1, bias=True)
|
||||
else:
|
||||
self.projection_layer = nn.Linear(
|
||||
configs.d_model, configs.c_out, bias=True)
|
||||
|
||||
self.out_res_layers = torch.nn.ModuleList([
|
||||
torch.nn.Linear(
|
||||
configs.seq_len // (configs.down_sampling_window ** i),
|
||||
configs.seq_len // (configs.down_sampling_window ** i),
|
||||
)
|
||||
for i in range(configs.down_sampling_layers + 1)
|
||||
])
|
||||
|
||||
self.regression_layers = torch.nn.ModuleList(
|
||||
[
|
||||
torch.nn.Linear(
|
||||
configs.seq_len // (configs.down_sampling_window ** i),
|
||||
configs.pred_len,
|
||||
)
|
||||
for i in range(configs.down_sampling_layers + 1)
|
||||
]
|
||||
)
|
||||
|
||||
if self.task_name == 'imputation' or self.task_name == 'anomaly_detection':
|
||||
if self.channel_independence:
|
||||
self.projection_layer = nn.Linear(
|
||||
configs.d_model, 1, bias=True)
|
||||
else:
|
||||
self.projection_layer = nn.Linear(
|
||||
configs.d_model, configs.c_out, bias=True)
|
||||
if self.task_name == 'classification':
|
||||
self.act = F.gelu
|
||||
self.dropout = nn.Dropout(configs.dropout)
|
||||
self.projection = nn.Linear(
|
||||
configs.d_model * configs.seq_len, configs.num_class)
|
||||
|
||||
def out_projection(self, dec_out, i, out_res):
|
||||
dec_out = self.projection_layer(dec_out)
|
||||
out_res = out_res.permute(0, 2, 1)
|
||||
out_res = self.out_res_layers[i](out_res)
|
||||
out_res = self.regression_layers[i](out_res).permute(0, 2, 1)
|
||||
dec_out = dec_out + out_res
|
||||
return dec_out
|
||||
|
||||
def pre_enc(self, x_list):
|
||||
if self.channel_independence:
|
||||
return (x_list, None)
|
||||
else:
|
||||
out1_list = []
|
||||
out2_list = []
|
||||
for x in x_list:
|
||||
x_1, x_2 = self.preprocess(x)
|
||||
out1_list.append(x_1)
|
||||
out2_list.append(x_2)
|
||||
return (out1_list, out2_list)
|
||||
|
||||
def __multi_scale_process_inputs(self, x_enc, x_mark_enc):
|
||||
if self.configs.down_sampling_method == 'max':
|
||||
down_pool = torch.nn.MaxPool1d(self.configs.down_sampling_window, return_indices=False)
|
||||
elif self.configs.down_sampling_method == 'avg':
|
||||
down_pool = torch.nn.AvgPool1d(self.configs.down_sampling_window)
|
||||
elif self.configs.down_sampling_method == 'conv':
|
||||
padding = 1 if torch.__version__ >= '1.5.0' else 2
|
||||
down_pool = nn.Conv1d(in_channels=self.configs.enc_in, out_channels=self.configs.enc_in,
|
||||
kernel_size=3, padding=padding,
|
||||
stride=self.configs.down_sampling_window,
|
||||
padding_mode='circular',
|
||||
bias=False)
|
||||
else:
|
||||
return x_enc, x_mark_enc
|
||||
# B,T,C -> B,C,T
|
||||
x_enc = x_enc.permute(0, 2, 1)
|
||||
|
||||
x_enc_ori = x_enc
|
||||
x_mark_enc_mark_ori = x_mark_enc
|
||||
|
||||
x_enc_sampling_list = []
|
||||
x_mark_sampling_list = []
|
||||
x_enc_sampling_list.append(x_enc.permute(0, 2, 1))
|
||||
x_mark_sampling_list.append(x_mark_enc)
|
||||
|
||||
for i in range(self.configs.down_sampling_layers):
|
||||
x_enc_sampling = down_pool(x_enc_ori)
|
||||
|
||||
x_enc_sampling_list.append(x_enc_sampling.permute(0, 2, 1))
|
||||
x_enc_ori = x_enc_sampling
|
||||
|
||||
if x_mark_enc is not None:
|
||||
x_mark_sampling_list.append(x_mark_enc_mark_ori[:, ::self.configs.down_sampling_window, :])
|
||||
x_mark_enc_mark_ori = x_mark_enc_mark_ori[:, ::self.configs.down_sampling_window, :]
|
||||
|
||||
x_enc = x_enc_sampling_list
|
||||
x_mark_enc = x_mark_sampling_list if x_mark_enc is not None else None
|
||||
|
||||
return x_enc, x_mark_enc
|
||||
|
||||
def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
|
||||
|
||||
x_enc, x_mark_enc = self.__multi_scale_process_inputs(x_enc, x_mark_enc)
|
||||
|
||||
x_list = []
|
||||
x_mark_list = []
|
||||
if x_mark_enc is not None:
|
||||
for i, x, x_mark in zip(range(len(x_enc)), x_enc, x_mark_enc):
|
||||
B, T, N = x.size()
|
||||
x = self.normalize_layers[i](x, 'norm')
|
||||
if self.channel_independence:
|
||||
x = x.permute(0, 2, 1).contiguous().reshape(B * N, T, 1)
|
||||
x_list.append(x)
|
||||
x_mark = x_mark.repeat(N, 1, 1)
|
||||
x_mark_list.append(x_mark)
|
||||
else:
|
||||
x_list.append(x)
|
||||
x_mark_list.append(x_mark)
|
||||
else:
|
||||
for i, x in zip(range(len(x_enc)), x_enc, ):
|
||||
B, T, N = x.size()
|
||||
x = self.normalize_layers[i](x, 'norm')
|
||||
if self.channel_independence:
|
||||
x = x.permute(0, 2, 1).contiguous().reshape(B * N, T, 1)
|
||||
x_list.append(x)
|
||||
|
||||
# embedding
|
||||
enc_out_list = []
|
||||
x_list = self.pre_enc(x_list)
|
||||
if x_mark_enc is not None:
|
||||
for i, x, x_mark in zip(range(len(x_list[0])), x_list[0], x_mark_list):
|
||||
enc_out = self.enc_embedding(x, x_mark) # [B,T,C]
|
||||
enc_out_list.append(enc_out)
|
||||
else:
|
||||
for i, x in zip(range(len(x_list[0])), x_list[0]):
|
||||
enc_out = self.enc_embedding(x, None) # [B,T,C]
|
||||
enc_out_list.append(enc_out)
|
||||
|
||||
# Past Decomposable Mixing as encoder for past
|
||||
for i in range(self.layer):
|
||||
enc_out_list = self.pdm_blocks[i](enc_out_list)
|
||||
|
||||
# Future Multipredictor Mixing as decoder for future
|
||||
dec_out_list = self.future_multi_mixing(B, enc_out_list, x_list)
|
||||
|
||||
dec_out = torch.stack(dec_out_list, dim=-1).sum(-1)
|
||||
dec_out = self.normalize_layers[0](dec_out, 'denorm')
|
||||
return dec_out
|
||||
|
||||
def future_multi_mixing(self, B, enc_out_list, x_list):
|
||||
dec_out_list = []
|
||||
if self.channel_independence:
|
||||
x_list = x_list[0]
|
||||
for i, enc_out in zip(range(len(x_list)), enc_out_list):
|
||||
dec_out = self.predict_layers[i](enc_out.permute(0, 2, 1)).permute(
|
||||
0, 2, 1) # align temporal dimension
|
||||
dec_out = self.projection_layer(dec_out)
|
||||
dec_out = dec_out.reshape(B, self.configs.c_out, self.pred_len).permute(0, 2, 1).contiguous()
|
||||
dec_out_list.append(dec_out)
|
||||
|
||||
else:
|
||||
for i, enc_out, out_res in zip(range(len(x_list[0])), enc_out_list, x_list[1]):
|
||||
dec_out = self.predict_layers[i](enc_out.permute(0, 2, 1)).permute(
|
||||
0, 2, 1) # align temporal dimension
|
||||
dec_out = self.out_projection(dec_out, i, out_res)
|
||||
dec_out_list.append(dec_out)
|
||||
|
||||
return dec_out_list
|
||||
|
||||
def classification(self, x_enc, x_mark_enc):
|
||||
x_enc, _ = self.__multi_scale_process_inputs(x_enc, None)
|
||||
x_list = x_enc
|
||||
|
||||
# embedding
|
||||
enc_out_list = []
|
||||
for x in x_list:
|
||||
enc_out = self.enc_embedding(x, None) # [B,T,C]
|
||||
enc_out_list.append(enc_out)
|
||||
|
||||
# MultiScale-CrissCrossAttention as encoder for past
|
||||
for i in range(self.layer):
|
||||
enc_out_list = self.pdm_blocks[i](enc_out_list)
|
||||
|
||||
enc_out = enc_out_list[0]
|
||||
# Output
|
||||
# the output transformer encoder/decoder embeddings don't include non-linearity
|
||||
output = self.act(enc_out)
|
||||
output = self.dropout(output)
|
||||
# zero-out padding embeddings
|
||||
output = output * x_mark_enc.unsqueeze(-1)
|
||||
# (batch_size, seq_length * d_model)
|
||||
output = output.reshape(output.shape[0], -1)
|
||||
output = self.projection(output) # (batch_size, num_classes)
|
||||
return output
|
||||
|
||||
def anomaly_detection(self, x_enc):
|
||||
B, T, N = x_enc.size()
|
||||
x_enc, _ = self.__multi_scale_process_inputs(x_enc, None)
|
||||
|
||||
x_list = []
|
||||
|
||||
for i, x in zip(range(len(x_enc)), x_enc, ):
|
||||
B, T, N = x.size()
|
||||
x = self.normalize_layers[i](x, 'norm')
|
||||
if self.channel_independence:
|
||||
x = x.permute(0, 2, 1).contiguous().reshape(B * N, T, 1)
|
||||
x_list.append(x)
|
||||
|
||||
# embedding
|
||||
enc_out_list = []
|
||||
for x in x_list:
|
||||
enc_out = self.enc_embedding(x, None) # [B,T,C]
|
||||
enc_out_list.append(enc_out)
|
||||
|
||||
# MultiScale-CrissCrossAttention as encoder for past
|
||||
for i in range(self.layer):
|
||||
enc_out_list = self.pdm_blocks[i](enc_out_list)
|
||||
|
||||
dec_out = self.projection_layer(enc_out_list[0])
|
||||
dec_out = dec_out.reshape(B, self.configs.c_out, -1).permute(0, 2, 1).contiguous()
|
||||
|
||||
dec_out = self.normalize_layers[0](dec_out, 'denorm')
|
||||
return dec_out
|
||||
|
||||
def imputation(self, x_enc, x_mark_enc, mask):
|
||||
means = torch.sum(x_enc, dim=1) / torch.sum(mask == 1, dim=1)
|
||||
means = means.unsqueeze(1).detach()
|
||||
x_enc = x_enc - means
|
||||
x_enc = x_enc.masked_fill(mask == 0, 0)
|
||||
stdev = torch.sqrt(torch.sum(x_enc * x_enc, dim=1) /
|
||||
torch.sum(mask == 1, dim=1) + 1e-5)
|
||||
stdev = stdev.unsqueeze(1).detach()
|
||||
x_enc /= stdev
|
||||
|
||||
B, T, N = x_enc.size()
|
||||
x_enc, x_mark_enc = self.__multi_scale_process_inputs(x_enc, x_mark_enc)
|
||||
|
||||
x_list = []
|
||||
x_mark_list = []
|
||||
if x_mark_enc is not None:
|
||||
for i, x, x_mark in zip(range(len(x_enc)), x_enc, x_mark_enc):
|
||||
B, T, N = x.size()
|
||||
if self.channel_independence:
|
||||
x = x.permute(0, 2, 1).contiguous().reshape(B * N, T, 1)
|
||||
x_list.append(x)
|
||||
x_mark = x_mark.repeat(N, 1, 1)
|
||||
x_mark_list.append(x_mark)
|
||||
else:
|
||||
for i, x in zip(range(len(x_enc)), x_enc, ):
|
||||
B, T, N = x.size()
|
||||
if self.channel_independence:
|
||||
x = x.permute(0, 2, 1).contiguous().reshape(B * N, T, 1)
|
||||
x_list.append(x)
|
||||
|
||||
# embedding
|
||||
enc_out_list = []
|
||||
for x in x_list:
|
||||
enc_out = self.enc_embedding(x, None) # [B,T,C]
|
||||
enc_out_list.append(enc_out)
|
||||
|
||||
# MultiScale-CrissCrossAttention as encoder for past
|
||||
for i in range(self.layer):
|
||||
enc_out_list = self.pdm_blocks[i](enc_out_list)
|
||||
|
||||
dec_out = self.projection_layer(enc_out_list[0])
|
||||
dec_out = dec_out.reshape(B, self.configs.c_out, -1).permute(0, 2, 1).contiguous()
|
||||
|
||||
dec_out = dec_out * \
|
||||
(stdev[:, 0, :].unsqueeze(1).repeat(1, self.seq_len, 1))
|
||||
dec_out = dec_out + \
|
||||
(means[:, 0, :].unsqueeze(1).repeat(1, self.seq_len, 1))
|
||||
return dec_out
|
||||
|
||||
def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
|
||||
if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
|
||||
dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
|
||||
return dec_out
|
||||
if self.task_name == 'imputation':
|
||||
dec_out = self.imputation(x_enc, x_mark_enc, mask)
|
||||
return dec_out # [B, L, D]
|
||||
if self.task_name == 'anomaly_detection':
|
||||
dec_out = self.anomaly_detection(x_enc)
|
||||
return dec_out # [B, L, D]
|
||||
if self.task_name == 'classification':
|
||||
dec_out = self.classification(x_enc, x_mark_enc)
|
||||
return dec_out # [B, N]
|
||||
else:
|
||||
raise ValueError('Other tasks implemented yet')
|
225
models/TimeXer.py
Normal file
225
models/TimeXer.py
Normal file
@ -0,0 +1,225 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from layers.SelfAttention_Family import FullAttention, AttentionLayer
|
||||
from layers.Embed import DataEmbedding_inverted, PositionalEmbedding
|
||||
import numpy as np
|
||||
|
||||
|
||||
class FlattenHead(nn.Module):
|
||||
def __init__(self, n_vars, nf, target_window, head_dropout=0):
|
||||
super().__init__()
|
||||
self.n_vars = n_vars
|
||||
self.flatten = nn.Flatten(start_dim=-2)
|
||||
self.linear = nn.Linear(nf, target_window)
|
||||
self.dropout = nn.Dropout(head_dropout)
|
||||
|
||||
def forward(self, x): # x: [bs x nvars x d_model x patch_num]
|
||||
x = self.flatten(x)
|
||||
x = self.linear(x)
|
||||
x = self.dropout(x)
|
||||
return x
|
||||
|
||||
|
||||
class EnEmbedding(nn.Module):
|
||||
def __init__(self, n_vars, d_model, patch_len, dropout):
|
||||
super(EnEmbedding, self).__init__()
|
||||
# Patching
|
||||
self.patch_len = patch_len
|
||||
|
||||
self.value_embedding = nn.Linear(patch_len, d_model, bias=False)
|
||||
self.glb_token = nn.Parameter(torch.randn(1, n_vars, 1, d_model))
|
||||
self.position_embedding = PositionalEmbedding(d_model)
|
||||
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
def forward(self, x):
|
||||
# do patching
|
||||
n_vars = x.shape[1]
|
||||
glb = self.glb_token.repeat((x.shape[0], 1, 1, 1))
|
||||
|
||||
x = x.unfold(dimension=-1, size=self.patch_len, step=self.patch_len)
|
||||
x = torch.reshape(x, (x.shape[0] * x.shape[1], x.shape[2], x.shape[3]))
|
||||
# Input encoding
|
||||
x = self.value_embedding(x) + self.position_embedding(x)
|
||||
x = torch.reshape(x, (-1, n_vars, x.shape[-2], x.shape[-1]))
|
||||
x = torch.cat([x, glb], dim=2)
|
||||
x = torch.reshape(x, (x.shape[0] * x.shape[1], x.shape[2], x.shape[3]))
|
||||
return self.dropout(x), n_vars
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
def __init__(self, layers, norm_layer=None, projection=None):
|
||||
super(Encoder, self).__init__()
|
||||
self.layers = nn.ModuleList(layers)
|
||||
self.norm = norm_layer
|
||||
self.projection = projection
|
||||
|
||||
def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None):
|
||||
for layer in self.layers:
|
||||
x = layer(x, cross, x_mask=x_mask, cross_mask=cross_mask, tau=tau, delta=delta)
|
||||
|
||||
if self.norm is not None:
|
||||
x = self.norm(x)
|
||||
|
||||
if self.projection is not None:
|
||||
x = self.projection(x)
|
||||
return x
|
||||
|
||||
|
||||
class EncoderLayer(nn.Module):
|
||||
def __init__(self, self_attention, cross_attention, d_model, d_ff=None,
|
||||
dropout=0.1, activation="relu"):
|
||||
super(EncoderLayer, self).__init__()
|
||||
d_ff = d_ff or 4 * d_model
|
||||
self.self_attention = self_attention
|
||||
self.cross_attention = cross_attention
|
||||
self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
|
||||
self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)
|
||||
self.norm1 = nn.LayerNorm(d_model)
|
||||
self.norm2 = nn.LayerNorm(d_model)
|
||||
self.norm3 = nn.LayerNorm(d_model)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.activation = F.relu if activation == "relu" else F.gelu
|
||||
|
||||
def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None):
|
||||
B, L, D = cross.shape
|
||||
x = x + self.dropout(self.self_attention(
|
||||
x, x, x,
|
||||
attn_mask=x_mask,
|
||||
tau=tau, delta=None
|
||||
)[0])
|
||||
x = self.norm1(x)
|
||||
|
||||
x_glb_ori = x[:, -1, :].unsqueeze(1)
|
||||
x_glb = torch.reshape(x_glb_ori, (B, -1, D))
|
||||
x_glb_attn = self.dropout(self.cross_attention(
|
||||
x_glb, cross, cross,
|
||||
attn_mask=cross_mask,
|
||||
tau=tau, delta=delta
|
||||
)[0])
|
||||
x_glb_attn = torch.reshape(x_glb_attn,
|
||||
(x_glb_attn.shape[0] * x_glb_attn.shape[1], x_glb_attn.shape[2])).unsqueeze(1)
|
||||
x_glb = x_glb_ori + x_glb_attn
|
||||
x_glb = self.norm2(x_glb)
|
||||
|
||||
y = x = torch.cat([x[:, :-1, :], x_glb], dim=1)
|
||||
|
||||
y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
|
||||
y = self.dropout(self.conv2(y).transpose(-1, 1))
|
||||
|
||||
return self.norm3(x + y)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
|
||||
def __init__(self, configs):
|
||||
super(Model, self).__init__()
|
||||
self.task_name = configs.task_name
|
||||
self.features = configs.features
|
||||
self.seq_len = configs.seq_len
|
||||
self.pred_len = configs.pred_len
|
||||
self.use_norm = configs.use_norm
|
||||
self.patch_len = configs.patch_len
|
||||
self.patch_num = int(configs.seq_len // configs.patch_len)
|
||||
self.n_vars = 1 if configs.features == 'MS' else configs.enc_in
|
||||
# Embedding
|
||||
self.en_embedding = EnEmbedding(self.n_vars, configs.d_model, self.patch_len, configs.dropout)
|
||||
|
||||
self.ex_embedding = DataEmbedding_inverted(configs.seq_len, configs.d_model, configs.embed, configs.freq,
|
||||
configs.dropout)
|
||||
|
||||
# Encoder-only architecture
|
||||
self.encoder = Encoder(
|
||||
[
|
||||
EncoderLayer(
|
||||
AttentionLayer(
|
||||
FullAttention(False, configs.factor, attention_dropout=configs.dropout,
|
||||
output_attention=False),
|
||||
configs.d_model, configs.n_heads),
|
||||
AttentionLayer(
|
||||
FullAttention(False, configs.factor, attention_dropout=configs.dropout,
|
||||
output_attention=False),
|
||||
configs.d_model, configs.n_heads),
|
||||
configs.d_model,
|
||||
configs.d_ff,
|
||||
dropout=configs.dropout,
|
||||
activation=configs.activation,
|
||||
)
|
||||
for l in range(configs.e_layers)
|
||||
],
|
||||
norm_layer=torch.nn.LayerNorm(configs.d_model)
|
||||
)
|
||||
self.head_nf = configs.d_model * (self.patch_num + 1)
|
||||
self.head = FlattenHead(configs.enc_in, self.head_nf, configs.pred_len,
|
||||
head_dropout=configs.dropout)
|
||||
|
||||
def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
|
||||
if self.use_norm:
|
||||
# Normalization from Non-stationary Transformer
|
||||
means = x_enc.mean(1, keepdim=True).detach()
|
||||
x_enc = x_enc - means
|
||||
stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
|
||||
x_enc /= stdev
|
||||
|
||||
_, _, N = x_enc.shape
|
||||
|
||||
en_embed, n_vars = self.en_embedding(x_enc[:, :, -1].unsqueeze(-1).permute(0, 2, 1))
|
||||
ex_embed = self.ex_embedding(x_enc[:, :, :-1], x_mark_enc)
|
||||
|
||||
enc_out = self.encoder(en_embed, ex_embed)
|
||||
enc_out = torch.reshape(
|
||||
enc_out, (-1, n_vars, enc_out.shape[-2], enc_out.shape[-1]))
|
||||
# z: [bs x nvars x d_model x patch_num]
|
||||
enc_out = enc_out.permute(0, 1, 3, 2)
|
||||
|
||||
dec_out = self.head(enc_out) # z: [bs x nvars x target_window]
|
||||
dec_out = dec_out.permute(0, 2, 1)
|
||||
|
||||
if self.use_norm:
|
||||
# De-Normalization from Non-stationary Transformer
|
||||
dec_out = dec_out * (stdev[:, 0, -1:].unsqueeze(1).repeat(1, self.pred_len, 1))
|
||||
dec_out = dec_out + (means[:, 0, -1:].unsqueeze(1).repeat(1, self.pred_len, 1))
|
||||
|
||||
return dec_out
|
||||
|
||||
|
||||
def forecast_multi(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
|
||||
if self.use_norm:
|
||||
# Normalization from Non-stationary Transformer
|
||||
means = x_enc.mean(1, keepdim=True).detach()
|
||||
x_enc = x_enc - means
|
||||
stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
|
||||
x_enc /= stdev
|
||||
|
||||
_, _, N = x_enc.shape
|
||||
|
||||
en_embed, n_vars = self.en_embedding(x_enc.permute(0, 2, 1))
|
||||
ex_embed = self.ex_embedding(x_enc, x_mark_enc)
|
||||
|
||||
enc_out = self.encoder(en_embed, ex_embed)
|
||||
enc_out = torch.reshape(
|
||||
enc_out, (-1, n_vars, enc_out.shape[-2], enc_out.shape[-1]))
|
||||
# z: [bs x nvars x d_model x patch_num]
|
||||
enc_out = enc_out.permute(0, 1, 3, 2)
|
||||
|
||||
dec_out = self.head(enc_out) # z: [bs x nvars x target_window]
|
||||
dec_out = dec_out.permute(0, 2, 1)
|
||||
|
||||
if self.use_norm:
|
||||
# De-Normalization from Non-stationary Transformer
|
||||
dec_out = dec_out * (stdev[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1))
|
||||
dec_out = dec_out + (means[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1))
|
||||
|
||||
return dec_out
|
||||
|
||||
def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
|
||||
if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
|
||||
if self.features == 'M':
|
||||
dec_out = self.forecast_multi(x_enc, x_mark_enc, x_dec, x_mark_dec)
|
||||
return dec_out[:, -self.pred_len:, :] # [B, L, D]
|
||||
else:
|
||||
dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
|
||||
return dec_out[:, -self.pred_len:, :] # [B, L, D]
|
||||
else:
|
||||
return None
|
215
models/TimesNet.py
Normal file
215
models/TimesNet.py
Normal file
@ -0,0 +1,215 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.fft
|
||||
from layers.Embed import DataEmbedding
|
||||
from layers.Conv_Blocks import Inception_Block_V1
|
||||
|
||||
|
||||
def FFT_for_Period(x, k=2):
|
||||
# [B, T, C]
|
||||
xf = torch.fft.rfft(x, dim=1)
|
||||
# find period by amplitudes
|
||||
frequency_list = abs(xf).mean(0).mean(-1)
|
||||
frequency_list[0] = 0
|
||||
_, top_list = torch.topk(frequency_list, k)
|
||||
top_list = top_list.detach().cpu().numpy()
|
||||
period = x.shape[1] // top_list
|
||||
return period, abs(xf).mean(-1)[:, top_list]
|
||||
|
||||
|
||||
class TimesBlock(nn.Module):
|
||||
def __init__(self, configs):
|
||||
super(TimesBlock, self).__init__()
|
||||
self.seq_len = configs.seq_len
|
||||
self.pred_len = configs.pred_len
|
||||
self.k = configs.top_k
|
||||
# parameter-efficient design
|
||||
self.conv = nn.Sequential(
|
||||
Inception_Block_V1(configs.d_model, configs.d_ff,
|
||||
num_kernels=configs.num_kernels),
|
||||
nn.GELU(),
|
||||
Inception_Block_V1(configs.d_ff, configs.d_model,
|
||||
num_kernels=configs.num_kernels)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
B, T, N = x.size()
|
||||
period_list, period_weight = FFT_for_Period(x, self.k)
|
||||
|
||||
res = []
|
||||
for i in range(self.k):
|
||||
period = period_list[i]
|
||||
# padding
|
||||
if (self.seq_len + self.pred_len) % period != 0:
|
||||
length = (
|
||||
((self.seq_len + self.pred_len) // period) + 1) * period
|
||||
padding = torch.zeros([x.shape[0], (length - (self.seq_len + self.pred_len)), x.shape[2]]).to(x.device)
|
||||
out = torch.cat([x, padding], dim=1)
|
||||
else:
|
||||
length = (self.seq_len + self.pred_len)
|
||||
out = x
|
||||
# reshape
|
||||
out = out.reshape(B, length // period, period,
|
||||
N).permute(0, 3, 1, 2).contiguous()
|
||||
# 2D conv: from 1d Variation to 2d Variation
|
||||
out = self.conv(out)
|
||||
# reshape back
|
||||
out = out.permute(0, 2, 3, 1).reshape(B, -1, N)
|
||||
res.append(out[:, :(self.seq_len + self.pred_len), :])
|
||||
res = torch.stack(res, dim=-1)
|
||||
# adaptive aggregation
|
||||
period_weight = F.softmax(period_weight, dim=1)
|
||||
period_weight = period_weight.unsqueeze(
|
||||
1).unsqueeze(1).repeat(1, T, N, 1)
|
||||
res = torch.sum(res * period_weight, -1)
|
||||
# residual connection
|
||||
res = res + x
|
||||
return res
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
"""
|
||||
Paper link: https://openreview.net/pdf?id=ju_Uqw384Oq
|
||||
"""
|
||||
|
||||
def __init__(self, configs):
|
||||
super(Model, self).__init__()
|
||||
self.configs = configs
|
||||
self.task_name = configs.task_name
|
||||
self.seq_len = configs.seq_len
|
||||
self.label_len = configs.label_len
|
||||
self.pred_len = configs.pred_len
|
||||
self.model = nn.ModuleList([TimesBlock(configs)
|
||||
for _ in range(configs.e_layers)])
|
||||
self.enc_embedding = DataEmbedding(configs.enc_in, configs.d_model, configs.embed, configs.freq,
|
||||
configs.dropout)
|
||||
self.layer = configs.e_layers
|
||||
self.layer_norm = nn.LayerNorm(configs.d_model)
|
||||
if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
|
||||
self.predict_linear = nn.Linear(
|
||||
self.seq_len, self.pred_len + self.seq_len)
|
||||
self.projection = nn.Linear(
|
||||
configs.d_model, configs.c_out, bias=True)
|
||||
if self.task_name == 'imputation' or self.task_name == 'anomaly_detection':
|
||||
self.projection = nn.Linear(
|
||||
configs.d_model, configs.c_out, bias=True)
|
||||
if self.task_name == 'classification':
|
||||
self.act = F.gelu
|
||||
self.dropout = nn.Dropout(configs.dropout)
|
||||
self.projection = nn.Linear(
|
||||
configs.d_model * configs.seq_len, configs.num_class)
|
||||
|
||||
def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
|
||||
# Normalization from Non-stationary Transformer
|
||||
means = x_enc.mean(1, keepdim=True).detach()
|
||||
x_enc = x_enc.sub(means)
|
||||
stdev = torch.sqrt(
|
||||
torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
|
||||
x_enc = x_enc.div(stdev)
|
||||
|
||||
# embedding
|
||||
enc_out = self.enc_embedding(x_enc, x_mark_enc) # [B,T,C]
|
||||
enc_out = self.predict_linear(enc_out.permute(0, 2, 1)).permute(
|
||||
0, 2, 1) # align temporal dimension
|
||||
# TimesNet
|
||||
for i in range(self.layer):
|
||||
enc_out = self.layer_norm(self.model[i](enc_out))
|
||||
# project back
|
||||
dec_out = self.projection(enc_out)
|
||||
|
||||
# De-Normalization from Non-stationary Transformer
|
||||
dec_out = dec_out.mul(
|
||||
(stdev[:, 0, :].unsqueeze(1).repeat(
|
||||
1, self.pred_len + self.seq_len, 1)))
|
||||
dec_out = dec_out.add(
|
||||
(means[:, 0, :].unsqueeze(1).repeat(
|
||||
1, self.pred_len + self.seq_len, 1)))
|
||||
return dec_out
|
||||
|
||||
def imputation(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask):
|
||||
# Normalization from Non-stationary Transformer
|
||||
means = torch.sum(x_enc, dim=1) / torch.sum(mask == 1, dim=1)
|
||||
means = means.unsqueeze(1).detach()
|
||||
x_enc = x_enc.sub(means)
|
||||
x_enc = x_enc.masked_fill(mask == 0, 0)
|
||||
stdev = torch.sqrt(torch.sum(x_enc * x_enc, dim=1) /
|
||||
torch.sum(mask == 1, dim=1) + 1e-5)
|
||||
stdev = stdev.unsqueeze(1).detach()
|
||||
x_enc = x_enc.div(stdev)
|
||||
|
||||
# embedding
|
||||
enc_out = self.enc_embedding(x_enc, x_mark_enc) # [B,T,C]
|
||||
# TimesNet
|
||||
for i in range(self.layer):
|
||||
enc_out = self.layer_norm(self.model[i](enc_out))
|
||||
# project back
|
||||
dec_out = self.projection(enc_out)
|
||||
|
||||
# De-Normalization from Non-stationary Transformer
|
||||
dec_out = dec_out.mul(
|
||||
(stdev[:, 0, :].unsqueeze(1).repeat(
|
||||
1, self.pred_len + self.seq_len, 1)))
|
||||
dec_out = dec_out.add(
|
||||
(means[:, 0, :].unsqueeze(1).repeat(
|
||||
1, self.pred_len + self.seq_len, 1)))
|
||||
return dec_out
|
||||
|
||||
def anomaly_detection(self, x_enc):
|
||||
# Normalization from Non-stationary Transformer
|
||||
means = x_enc.mean(1, keepdim=True).detach()
|
||||
x_enc = x_enc.sub(means)
|
||||
stdev = torch.sqrt(
|
||||
torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
|
||||
x_enc = x_enc.div(stdev)
|
||||
|
||||
# embedding
|
||||
enc_out = self.enc_embedding(x_enc, None) # [B,T,C]
|
||||
# TimesNet
|
||||
for i in range(self.layer):
|
||||
enc_out = self.layer_norm(self.model[i](enc_out))
|
||||
# project back
|
||||
dec_out = self.projection(enc_out)
|
||||
|
||||
# De-Normalization from Non-stationary Transformer
|
||||
dec_out = dec_out.mul(
|
||||
(stdev[:, 0, :].unsqueeze(1).repeat(
|
||||
1, self.pred_len + self.seq_len, 1)))
|
||||
dec_out = dec_out.add(
|
||||
(means[:, 0, :].unsqueeze(1).repeat(
|
||||
1, self.pred_len + self.seq_len, 1)))
|
||||
return dec_out
|
||||
|
||||
def classification(self, x_enc, x_mark_enc):
|
||||
# embedding
|
||||
enc_out = self.enc_embedding(x_enc, None) # [B,T,C]
|
||||
# TimesNet
|
||||
for i in range(self.layer):
|
||||
enc_out = self.layer_norm(self.model[i](enc_out))
|
||||
|
||||
# Output
|
||||
# the output transformer encoder/decoder embeddings don't include non-linearity
|
||||
output = self.act(enc_out)
|
||||
output = self.dropout(output)
|
||||
# zero-out padding embeddings
|
||||
output = output * x_mark_enc.unsqueeze(-1)
|
||||
# (batch_size, seq_length * d_model)
|
||||
output = output.reshape(output.shape[0], -1)
|
||||
output = self.projection(output) # (batch_size, num_classes)
|
||||
return output
|
||||
|
||||
def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
|
||||
if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
|
||||
dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
|
||||
return dec_out[:, -self.pred_len:, :] # [B, L, D]
|
||||
if self.task_name == 'imputation':
|
||||
dec_out = self.imputation(
|
||||
x_enc, x_mark_enc, x_dec, x_mark_dec, mask)
|
||||
return dec_out # [B, L, D]
|
||||
if self.task_name == 'anomaly_detection':
|
||||
dec_out = self.anomaly_detection(x_enc)
|
||||
return dec_out # [B, L, D]
|
||||
if self.task_name == 'classification':
|
||||
dec_out = self.classification(x_enc, x_mark_enc)
|
||||
return dec_out # [B, N]
|
||||
return None
|
124
models/Transformer.py
Normal file
124
models/Transformer.py
Normal file
@ -0,0 +1,124 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from layers.Transformer_EncDec import Decoder, DecoderLayer, Encoder, EncoderLayer, ConvLayer
|
||||
from layers.SelfAttention_Family import FullAttention, AttentionLayer
|
||||
from layers.Embed import DataEmbedding
|
||||
import numpy as np
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
"""
|
||||
Vanilla Transformer
|
||||
with O(L^2) complexity
|
||||
Paper link: https://proceedings.neurips.cc/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf
|
||||
"""
|
||||
|
||||
def __init__(self, configs):
|
||||
super(Model, self).__init__()
|
||||
self.task_name = configs.task_name
|
||||
self.pred_len = configs.pred_len
|
||||
# Embedding
|
||||
self.enc_embedding = DataEmbedding(configs.enc_in, configs.d_model, configs.embed, configs.freq,
|
||||
configs.dropout)
|
||||
# Encoder
|
||||
self.encoder = Encoder(
|
||||
[
|
||||
EncoderLayer(
|
||||
AttentionLayer(
|
||||
FullAttention(False, configs.factor, attention_dropout=configs.dropout,
|
||||
output_attention=False), configs.d_model, configs.n_heads),
|
||||
configs.d_model,
|
||||
configs.d_ff,
|
||||
dropout=configs.dropout,
|
||||
activation=configs.activation
|
||||
) for l in range(configs.e_layers)
|
||||
],
|
||||
norm_layer=torch.nn.LayerNorm(configs.d_model)
|
||||
)
|
||||
# Decoder
|
||||
if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
|
||||
self.dec_embedding = DataEmbedding(configs.dec_in, configs.d_model, configs.embed, configs.freq,
|
||||
configs.dropout)
|
||||
self.decoder = Decoder(
|
||||
[
|
||||
DecoderLayer(
|
||||
AttentionLayer(
|
||||
FullAttention(True, configs.factor, attention_dropout=configs.dropout,
|
||||
output_attention=False),
|
||||
configs.d_model, configs.n_heads),
|
||||
AttentionLayer(
|
||||
FullAttention(False, configs.factor, attention_dropout=configs.dropout,
|
||||
output_attention=False),
|
||||
configs.d_model, configs.n_heads),
|
||||
configs.d_model,
|
||||
configs.d_ff,
|
||||
dropout=configs.dropout,
|
||||
activation=configs.activation,
|
||||
)
|
||||
for l in range(configs.d_layers)
|
||||
],
|
||||
norm_layer=torch.nn.LayerNorm(configs.d_model),
|
||||
projection=nn.Linear(configs.d_model, configs.c_out, bias=True)
|
||||
)
|
||||
if self.task_name == 'imputation':
|
||||
self.projection = nn.Linear(configs.d_model, configs.c_out, bias=True)
|
||||
if self.task_name == 'anomaly_detection':
|
||||
self.projection = nn.Linear(configs.d_model, configs.c_out, bias=True)
|
||||
if self.task_name == 'classification':
|
||||
self.act = F.gelu
|
||||
self.dropout = nn.Dropout(configs.dropout)
|
||||
self.projection = nn.Linear(configs.d_model * configs.seq_len, configs.num_class)
|
||||
|
||||
def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
|
||||
# Embedding
|
||||
enc_out = self.enc_embedding(x_enc, x_mark_enc)
|
||||
enc_out, attns = self.encoder(enc_out, attn_mask=None)
|
||||
|
||||
dec_out = self.dec_embedding(x_dec, x_mark_dec)
|
||||
dec_out = self.decoder(dec_out, enc_out, x_mask=None, cross_mask=None)
|
||||
return dec_out
|
||||
|
||||
def imputation(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask):
|
||||
# Embedding
|
||||
enc_out = self.enc_embedding(x_enc, x_mark_enc)
|
||||
enc_out, attns = self.encoder(enc_out, attn_mask=None)
|
||||
|
||||
dec_out = self.projection(enc_out)
|
||||
return dec_out
|
||||
|
||||
def anomaly_detection(self, x_enc):
|
||||
# Embedding
|
||||
enc_out = self.enc_embedding(x_enc, None)
|
||||
enc_out, attns = self.encoder(enc_out, attn_mask=None)
|
||||
|
||||
dec_out = self.projection(enc_out)
|
||||
return dec_out
|
||||
|
||||
def classification(self, x_enc, x_mark_enc):
|
||||
# Embedding
|
||||
enc_out = self.enc_embedding(x_enc, None)
|
||||
enc_out, attns = self.encoder(enc_out, attn_mask=None)
|
||||
|
||||
# Output
|
||||
output = self.act(enc_out) # the output transformer encoder/decoder embeddings don't include non-linearity
|
||||
output = self.dropout(output)
|
||||
output = output * x_mark_enc.unsqueeze(-1) # zero-out padding embeddings
|
||||
output = output.reshape(output.shape[0], -1) # (batch_size, seq_length * d_model)
|
||||
output = self.projection(output) # (batch_size, num_classes)
|
||||
return output
|
||||
|
||||
def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
|
||||
if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
|
||||
dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
|
||||
return dec_out[:, -self.pred_len:, :] # [B, L, D]
|
||||
if self.task_name == 'imputation':
|
||||
dec_out = self.imputation(x_enc, x_mark_enc, x_dec, x_mark_dec, mask)
|
||||
return dec_out # [B, L, D]
|
||||
if self.task_name == 'anomaly_detection':
|
||||
dec_out = self.anomaly_detection(x_enc)
|
||||
return dec_out # [B, L, D]
|
||||
if self.task_name == 'classification':
|
||||
dec_out = self.classification(x_enc, x_mark_enc)
|
||||
return dec_out # [B, N]
|
||||
return None
|
319
models/WPMixer.py
Normal file
319
models/WPMixer.py
Normal file
@ -0,0 +1,319 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Created on Sun Jan 5 16:10:01 2025
|
||||
@author: Murad
|
||||
SISLab, USF
|
||||
mmurad@usf.edu
|
||||
https://github.com/Secure-and-Intelligent-Systems-Lab/WPMixer
|
||||
"""
|
||||
|
||||
import torch.nn as nn
|
||||
import torch
|
||||
from layers.DWT_Decomposition import Decomposition
|
||||
|
||||
|
||||
class TokenMixer(nn.Module):
|
||||
def __init__(self, input_seq=[], batch_size=[], channel=[], pred_seq=[], dropout=[], factor=[], d_model=[]):
|
||||
super(TokenMixer, self).__init__()
|
||||
self.input_seq = input_seq
|
||||
self.batch_size = batch_size
|
||||
self.channel = channel
|
||||
self.pred_seq = pred_seq
|
||||
self.dropout = dropout
|
||||
self.factor = factor
|
||||
self.d_model = d_model
|
||||
|
||||
self.dropoutLayer = nn.Dropout(self.dropout)
|
||||
self.layers = nn.Sequential(nn.Linear(self.input_seq, self.pred_seq * self.factor),
|
||||
nn.GELU(),
|
||||
nn.Dropout(self.dropout),
|
||||
nn.Linear(self.pred_seq * self.factor, self.pred_seq)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = x.transpose(1, 2)
|
||||
x = self.layers(x)
|
||||
x = x.transpose(1, 2)
|
||||
return x
|
||||
|
||||
|
||||
class Mixer(nn.Module):
|
||||
def __init__(self,
|
||||
input_seq=[],
|
||||
out_seq=[],
|
||||
batch_size=[],
|
||||
channel=[],
|
||||
d_model=[],
|
||||
dropout=[],
|
||||
tfactor=[],
|
||||
dfactor=[]):
|
||||
super(Mixer, self).__init__()
|
||||
self.input_seq = input_seq
|
||||
self.pred_seq = out_seq
|
||||
self.batch_size = batch_size
|
||||
self.channel = channel
|
||||
self.d_model = d_model
|
||||
self.dropout = dropout
|
||||
self.tfactor = tfactor # expansion factor for patch mixer
|
||||
self.dfactor = dfactor # expansion factor for embedding mixer
|
||||
|
||||
self.tMixer = TokenMixer(input_seq=self.input_seq, batch_size=self.batch_size, channel=self.channel,
|
||||
pred_seq=self.pred_seq, dropout=self.dropout, factor=self.tfactor,
|
||||
d_model=self.d_model)
|
||||
self.dropoutLayer = nn.Dropout(self.dropout)
|
||||
self.norm1 = nn.BatchNorm2d(self.channel)
|
||||
self.norm2 = nn.BatchNorm2d(self.channel)
|
||||
|
||||
self.embeddingMixer = nn.Sequential(nn.Linear(self.d_model, self.d_model * self.dfactor),
|
||||
nn.GELU(),
|
||||
nn.Dropout(self.dropout),
|
||||
nn.Linear(self.d_model * self.dfactor, self.d_model))
|
||||
|
||||
def forward(self, x):
|
||||
'''
|
||||
Parameters
|
||||
----------
|
||||
x : input: [Batch, Channel, Patch_number, d_model]
|
||||
|
||||
Returns
|
||||
-------
|
||||
x: output: [Batch, Channel, Patch_number, d_model]
|
||||
|
||||
'''
|
||||
x = self.norm1(x)
|
||||
x = x.permute(0, 3, 1, 2)
|
||||
x = self.dropoutLayer(self.tMixer(x))
|
||||
x = x.permute(0, 2, 3, 1)
|
||||
x = self.norm2(x)
|
||||
x = x + self.dropoutLayer(self.embeddingMixer(x))
|
||||
return x
|
||||
|
||||
|
||||
class ResolutionBranch(nn.Module):
|
||||
def __init__(self,
|
||||
input_seq=[],
|
||||
pred_seq=[],
|
||||
batch_size=[],
|
||||
channel=[],
|
||||
d_model=[],
|
||||
dropout=[],
|
||||
embedding_dropout=[],
|
||||
tfactor=[],
|
||||
dfactor=[],
|
||||
patch_len=[],
|
||||
patch_stride=[]):
|
||||
super(ResolutionBranch, self).__init__()
|
||||
self.input_seq = input_seq
|
||||
self.pred_seq = pred_seq
|
||||
self.batch_size = batch_size
|
||||
self.channel = channel
|
||||
self.d_model = d_model
|
||||
self.dropout = dropout
|
||||
self.embedding_dropout = embedding_dropout
|
||||
self.tfactor = tfactor
|
||||
self.dfactor = dfactor
|
||||
self.patch_len = patch_len
|
||||
self.patch_stride = patch_stride
|
||||
self.patch_num = int((self.input_seq - self.patch_len) / self.patch_stride + 2)
|
||||
|
||||
self.patch_norm = nn.BatchNorm2d(self.channel)
|
||||
self.patch_embedding_layer = nn.Linear(self.patch_len, self.d_model) # shared among all channels
|
||||
self.mixer1 = Mixer(input_seq=self.patch_num,
|
||||
out_seq=self.patch_num,
|
||||
batch_size=self.batch_size,
|
||||
channel=self.channel,
|
||||
d_model=self.d_model,
|
||||
dropout=self.dropout,
|
||||
tfactor=self.tfactor,
|
||||
dfactor=self.dfactor)
|
||||
self.mixer2 = Mixer(input_seq=self.patch_num,
|
||||
out_seq=self.patch_num,
|
||||
batch_size=self.batch_size,
|
||||
channel=self.channel,
|
||||
d_model=self.d_model,
|
||||
dropout=self.dropout,
|
||||
tfactor=self.tfactor,
|
||||
dfactor=self.dfactor)
|
||||
self.norm = nn.BatchNorm2d(self.channel)
|
||||
self.dropoutLayer = nn.Dropout(self.embedding_dropout)
|
||||
self.head = nn.Sequential(nn.Flatten(start_dim=-2, end_dim=-1),
|
||||
nn.Linear(self.patch_num * self.d_model, self.pred_seq))
|
||||
|
||||
def forward(self, x):
|
||||
'''
|
||||
Parameters
|
||||
----------
|
||||
x : input coefficient series: [Batch, channel, length_of_coefficient_series]
|
||||
|
||||
Returns
|
||||
-------
|
||||
out : predicted coefficient series: [Batch, channel, length_of_pred_coeff_series]
|
||||
'''
|
||||
|
||||
x_patch = self.do_patching(x)
|
||||
x_patch = self.patch_norm(x_patch)
|
||||
x_emb = self.dropoutLayer(self.patch_embedding_layer(x_patch))
|
||||
|
||||
out = self.mixer1(x_emb)
|
||||
res = out
|
||||
out = res + self.mixer2(out)
|
||||
out = self.norm(out)
|
||||
|
||||
out = self.head(out)
|
||||
return out
|
||||
|
||||
def do_patching(self, x):
|
||||
x_end = x[:, :, -1:]
|
||||
x_padding = x_end.repeat(1, 1, self.patch_stride)
|
||||
x_new = torch.cat((x, x_padding), dim=-1)
|
||||
x_patch = x_new.unfold(dimension=-1, size=self.patch_len, step=self.patch_stride)
|
||||
return x_patch
|
||||
|
||||
|
||||
class WPMixerCore(nn.Module):
|
||||
def __init__(self,
|
||||
input_length=[],
|
||||
pred_length=[],
|
||||
wavelet_name=[],
|
||||
level=[],
|
||||
batch_size=[],
|
||||
channel=[],
|
||||
d_model=[],
|
||||
dropout=[],
|
||||
embedding_dropout=[],
|
||||
tfactor=[],
|
||||
dfactor=[],
|
||||
device=[],
|
||||
patch_len=[],
|
||||
patch_stride=[],
|
||||
no_decomposition=[],
|
||||
use_amp=[]):
|
||||
super(WPMixerCore, self).__init__()
|
||||
self.input_length = input_length
|
||||
self.pred_length = pred_length
|
||||
self.wavelet_name = wavelet_name
|
||||
self.level = level
|
||||
self.batch_size = batch_size
|
||||
self.channel = channel
|
||||
self.d_model = d_model
|
||||
self.dropout = dropout
|
||||
self.embedding_dropout = embedding_dropout
|
||||
self.device = device
|
||||
self.no_decomposition = no_decomposition
|
||||
self.tfactor = tfactor
|
||||
self.dfactor = dfactor
|
||||
self.use_amp = use_amp
|
||||
|
||||
self.Decomposition_model = Decomposition(input_length=self.input_length,
|
||||
pred_length=self.pred_length,
|
||||
wavelet_name=self.wavelet_name,
|
||||
level=self.level,
|
||||
batch_size=self.batch_size,
|
||||
channel=self.channel,
|
||||
d_model=self.d_model,
|
||||
tfactor=self.tfactor,
|
||||
dfactor=self.dfactor,
|
||||
device=self.device,
|
||||
no_decomposition=self.no_decomposition,
|
||||
use_amp=self.use_amp)
|
||||
|
||||
self.input_w_dim = self.Decomposition_model.input_w_dim # list of the length of the input coefficient series
|
||||
self.pred_w_dim = self.Decomposition_model.pred_w_dim # list of the length of the predicted coefficient series
|
||||
|
||||
self.patch_len = patch_len
|
||||
self.patch_stride = patch_stride
|
||||
|
||||
# (m+1) number of resolutionBranch
|
||||
self.resolutionBranch = nn.ModuleList([ResolutionBranch(input_seq=self.input_w_dim[i],
|
||||
pred_seq=self.pred_w_dim[i],
|
||||
batch_size=self.batch_size,
|
||||
channel=self.channel,
|
||||
d_model=self.d_model,
|
||||
dropout=self.dropout,
|
||||
embedding_dropout=self.embedding_dropout,
|
||||
tfactor=self.tfactor,
|
||||
dfactor=self.dfactor,
|
||||
patch_len=self.patch_len,
|
||||
patch_stride=self.patch_stride) for i in
|
||||
range(len(self.input_w_dim))])
|
||||
|
||||
def forward(self, xL):
|
||||
'''
|
||||
Parameters
|
||||
----------
|
||||
xL : Look back window: [Batch, look_back_length, channel]
|
||||
|
||||
Returns
|
||||
-------
|
||||
xT : Prediction time series: [Batch, prediction_length, output_channel]
|
||||
'''
|
||||
x = xL.transpose(1, 2) # [batch, channel, look_back_length]
|
||||
|
||||
# xA: approximation coefficient series,
|
||||
# xD: detail coefficient series
|
||||
# yA: predicted approximation coefficient series
|
||||
# yD: predicted detail coefficient series
|
||||
|
||||
xA, xD = self.Decomposition_model.transform(x)
|
||||
|
||||
yA = self.resolutionBranch[0](xA)
|
||||
yD = []
|
||||
for i in range(len(xD)):
|
||||
yD_i = self.resolutionBranch[i + 1](xD[i])
|
||||
yD.append(yD_i)
|
||||
|
||||
y = self.Decomposition_model.inv_transform(yA, yD)
|
||||
y = y.transpose(1, 2)
|
||||
xT = y[:, -self.pred_length:, :] # decomposition output is always even, but pred length can be odd
|
||||
|
||||
return xT
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, args, tfactor=5, dfactor=5, wavelet='db2', level=1, stride=8, no_decomposition=False):
|
||||
super(Model, self).__init__()
|
||||
self.args = args
|
||||
self.task_name = args.task_name
|
||||
self.wpmixerCore = WPMixerCore(input_length=self.args.seq_len,
|
||||
pred_length=self.args.pred_len,
|
||||
wavelet_name=wavelet,
|
||||
level=level,
|
||||
batch_size=self.args.batch_size,
|
||||
channel=self.args.c_out,
|
||||
d_model=self.args.d_model,
|
||||
dropout=self.args.dropout,
|
||||
embedding_dropout=self.args.dropout,
|
||||
tfactor=tfactor,
|
||||
dfactor=dfactor,
|
||||
device=self.args.device,
|
||||
patch_len=self.args.patch_len,
|
||||
patch_stride=stride,
|
||||
no_decomposition=no_decomposition,
|
||||
use_amp=self.args.use_amp)
|
||||
|
||||
def forecast(self, x_enc, x_mark_enc, x_dec, batch_y_mark):
|
||||
# Normalization
|
||||
means = x_enc.mean(1, keepdim=True).detach()
|
||||
x_enc = x_enc - means
|
||||
stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
|
||||
x_enc /= stdev
|
||||
|
||||
pred = self.wpmixerCore(x_enc)
|
||||
pred = pred[:, :, -self.args.c_out:]
|
||||
|
||||
# De-Normalization
|
||||
dec_out = pred * (stdev[:, 0].unsqueeze(1).repeat(1, self.args.pred_len, 1))
|
||||
dec_out = dec_out + (means[:, 0].unsqueeze(1).repeat(1, self.args.pred_len, 1))
|
||||
return dec_out
|
||||
|
||||
def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
|
||||
if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
|
||||
dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
|
||||
return dec_out # [B, L, D]
|
||||
if self.task_name == 'imputation':
|
||||
raise NotImplementedError("Task imputation for WPMixer is temporarily not supported")
|
||||
if self.task_name == 'anomaly_detection':
|
||||
raise NotImplementedError("Task anomaly_detection for WPMixer is temporarily not supported")
|
||||
if self.task_name == 'classification':
|
||||
raise NotImplementedError("Task classification for WPMixer is temporarily not supported")
|
||||
return None
|
0
models/__init__.py
Normal file
0
models/__init__.py
Normal file
132
models/iTransformer.py
Normal file
132
models/iTransformer.py
Normal file
@ -0,0 +1,132 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from layers.Transformer_EncDec import Encoder, EncoderLayer
|
||||
from layers.SelfAttention_Family import FullAttention, AttentionLayer
|
||||
from layers.Embed import DataEmbedding_inverted
|
||||
import numpy as np
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
"""
|
||||
Paper link: https://arxiv.org/abs/2310.06625
|
||||
"""
|
||||
|
||||
def __init__(self, configs):
|
||||
super(Model, self).__init__()
|
||||
self.task_name = configs.task_name
|
||||
self.seq_len = configs.seq_len
|
||||
self.pred_len = configs.pred_len
|
||||
# Embedding
|
||||
self.enc_embedding = DataEmbedding_inverted(configs.seq_len, configs.d_model, configs.embed, configs.freq,
|
||||
configs.dropout)
|
||||
# Encoder
|
||||
self.encoder = Encoder(
|
||||
[
|
||||
EncoderLayer(
|
||||
AttentionLayer(
|
||||
FullAttention(False, configs.factor, attention_dropout=configs.dropout,
|
||||
output_attention=False), configs.d_model, configs.n_heads),
|
||||
configs.d_model,
|
||||
configs.d_ff,
|
||||
dropout=configs.dropout,
|
||||
activation=configs.activation
|
||||
) for l in range(configs.e_layers)
|
||||
],
|
||||
norm_layer=torch.nn.LayerNorm(configs.d_model)
|
||||
)
|
||||
# Decoder
|
||||
if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
|
||||
self.projection = nn.Linear(configs.d_model, configs.pred_len, bias=True)
|
||||
if self.task_name == 'imputation':
|
||||
self.projection = nn.Linear(configs.d_model, configs.seq_len, bias=True)
|
||||
if self.task_name == 'anomaly_detection':
|
||||
self.projection = nn.Linear(configs.d_model, configs.seq_len, bias=True)
|
||||
if self.task_name == 'classification':
|
||||
self.act = F.gelu
|
||||
self.dropout = nn.Dropout(configs.dropout)
|
||||
self.projection = nn.Linear(configs.d_model * configs.enc_in, configs.num_class)
|
||||
|
||||
def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
|
||||
# Normalization from Non-stationary Transformer
|
||||
means = x_enc.mean(1, keepdim=True).detach()
|
||||
x_enc = x_enc - means
|
||||
stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
|
||||
x_enc /= stdev
|
||||
|
||||
_, _, N = x_enc.shape
|
||||
|
||||
# Embedding
|
||||
enc_out = self.enc_embedding(x_enc, x_mark_enc)
|
||||
enc_out, attns = self.encoder(enc_out, attn_mask=None)
|
||||
|
||||
dec_out = self.projection(enc_out).permute(0, 2, 1)[:, :, :N]
|
||||
# De-Normalization from Non-stationary Transformer
|
||||
dec_out = dec_out * (stdev[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1))
|
||||
dec_out = dec_out + (means[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1))
|
||||
return dec_out
|
||||
|
||||
def imputation(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask):
|
||||
# Normalization from Non-stationary Transformer
|
||||
means = x_enc.mean(1, keepdim=True).detach()
|
||||
x_enc = x_enc - means
|
||||
stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
|
||||
x_enc /= stdev
|
||||
|
||||
_, L, N = x_enc.shape
|
||||
|
||||
# Embedding
|
||||
enc_out = self.enc_embedding(x_enc, x_mark_enc)
|
||||
enc_out, attns = self.encoder(enc_out, attn_mask=None)
|
||||
|
||||
dec_out = self.projection(enc_out).permute(0, 2, 1)[:, :, :N]
|
||||
# De-Normalization from Non-stationary Transformer
|
||||
dec_out = dec_out * (stdev[:, 0, :].unsqueeze(1).repeat(1, L, 1))
|
||||
dec_out = dec_out + (means[:, 0, :].unsqueeze(1).repeat(1, L, 1))
|
||||
return dec_out
|
||||
|
||||
def anomaly_detection(self, x_enc):
|
||||
# Normalization from Non-stationary Transformer
|
||||
means = x_enc.mean(1, keepdim=True).detach()
|
||||
x_enc = x_enc - means
|
||||
stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
|
||||
x_enc /= stdev
|
||||
|
||||
_, L, N = x_enc.shape
|
||||
|
||||
# Embedding
|
||||
enc_out = self.enc_embedding(x_enc, None)
|
||||
enc_out, attns = self.encoder(enc_out, attn_mask=None)
|
||||
|
||||
dec_out = self.projection(enc_out).permute(0, 2, 1)[:, :, :N]
|
||||
# De-Normalization from Non-stationary Transformer
|
||||
dec_out = dec_out * (stdev[:, 0, :].unsqueeze(1).repeat(1, L, 1))
|
||||
dec_out = dec_out + (means[:, 0, :].unsqueeze(1).repeat(1, L, 1))
|
||||
return dec_out
|
||||
|
||||
def classification(self, x_enc, x_mark_enc):
|
||||
# Embedding
|
||||
enc_out = self.enc_embedding(x_enc, None)
|
||||
enc_out, attns = self.encoder(enc_out, attn_mask=None)
|
||||
|
||||
# Output
|
||||
output = self.act(enc_out) # the output transformer encoder/decoder embeddings don't include non-linearity
|
||||
output = self.dropout(output)
|
||||
output = output.reshape(output.shape[0], -1) # (batch_size, c_in * d_model)
|
||||
output = self.projection(output) # (batch_size, num_classes)
|
||||
return output
|
||||
|
||||
def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
|
||||
if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
|
||||
dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
|
||||
return dec_out[:, -self.pred_len:, :] # [B, L, D]
|
||||
if self.task_name == 'imputation':
|
||||
dec_out = self.imputation(x_enc, x_mark_enc, x_dec, x_mark_dec, mask)
|
||||
return dec_out # [B, L, D]
|
||||
if self.task_name == 'anomaly_detection':
|
||||
dec_out = self.anomaly_detection(x_enc)
|
||||
return dec_out # [B, L, D]
|
||||
if self.task_name == 'classification':
|
||||
dec_out = self.classification(x_enc, x_mark_enc)
|
||||
return dec_out # [B, N]
|
||||
return None
|
166
models/xPatch_SparseChannel.py
Normal file
166
models/xPatch_SparseChannel.py
Normal file
@ -0,0 +1,166 @@
|
||||
"""
|
||||
xPatch_SparseChannel model adapted for Time-Series-Library-main
|
||||
Supports both long-term forecasting and classification tasks
|
||||
"""
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from layers.DECOMP import DECOMP
|
||||
from layers.SeasonPatch import SeasonPatch
|
||||
from layers.RevIN import RevIN
|
||||
|
||||
class Model(nn.Module):
|
||||
"""
|
||||
xPatch SparseChannel Model
|
||||
"""
|
||||
|
||||
def __init__(self, configs):
|
||||
super(Model, self).__init__()
|
||||
|
||||
# Model configuration
|
||||
self.task_name = configs.task_name
|
||||
self.seq_len = configs.seq_len
|
||||
self.pred_len = configs.pred_len
|
||||
self.enc_in = configs.enc_in
|
||||
|
||||
# Model parameters
|
||||
self.patch_len = getattr(configs, 'patch_len', 16)
|
||||
self.stride = getattr(configs, 'stride', 8)
|
||||
|
||||
# Normalization
|
||||
self.revin = getattr(configs, 'revin', True)
|
||||
if self.revin:
|
||||
self.revin_layer = RevIN(self.enc_in, affine=True, subtract_last=False)
|
||||
|
||||
# Decomposition using original DECOMP with EMA/DEMA
|
||||
ma_type = getattr(configs, 'ma_type', 'ema')
|
||||
alpha = getattr(configs, 'alpha', torch.tensor(0.1))
|
||||
beta = getattr(configs, 'beta', torch.tensor(0.1))
|
||||
self.decomp = DECOMP(ma_type, alpha, beta)
|
||||
|
||||
# Season network (PatchTST + Graph Mixer)
|
||||
self.season_net = SeasonPatch(
|
||||
c_in=self.enc_in,
|
||||
seq_len=self.seq_len,
|
||||
pred_len=self.pred_len,
|
||||
patch_len=self.patch_len,
|
||||
stride=self.stride,
|
||||
k_graph=getattr(configs, 'k_graph', 8),
|
||||
d_model=getattr(configs, 'd_model', 128),
|
||||
n_layers=getattr(configs, 'e_layers', 3),
|
||||
n_heads=getattr(configs, 'n_heads', 16)
|
||||
)
|
||||
|
||||
# Trend network (MLP)
|
||||
self.fc5 = nn.Linear(self.seq_len, self.pred_len * 4)
|
||||
self.avgpool1 = nn.AvgPool1d(kernel_size=2)
|
||||
self.ln1 = nn.LayerNorm(self.pred_len * 2)
|
||||
self.fc6 = nn.Linear(self.pred_len * 2, self.pred_len)
|
||||
self.avgpool2 = nn.AvgPool1d(kernel_size=2)
|
||||
self.ln2 = nn.LayerNorm(self.pred_len // 2)
|
||||
self.fc7 = nn.Linear(self.pred_len // 2, self.pred_len)
|
||||
|
||||
# Task-specific heads
|
||||
if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
|
||||
self.fc_final = nn.Linear(self.pred_len * 2, self.pred_len)
|
||||
elif self.task_name == 'classification':
|
||||
self.season_attention = nn.Sequential(
|
||||
nn.Linear(self.pred_len, 64),
|
||||
nn.Tanh(),
|
||||
nn.Linear(64, 1)
|
||||
)
|
||||
self.trend_attention = nn.Sequential(
|
||||
nn.Linear(self.pred_len, 64),
|
||||
nn.Tanh(),
|
||||
nn.Linear(64, 1)
|
||||
)
|
||||
self.classifier = nn.Sequential(
|
||||
nn.Linear(self.enc_in * 2, 128),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(getattr(configs, 'dropout', 0.1)),
|
||||
nn.Linear(128, configs.num_class)
|
||||
)
|
||||
|
||||
def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
|
||||
"""Long-term forecasting"""
|
||||
# Normalization
|
||||
if self.revin:
|
||||
x_enc = self.revin_layer(x_enc, 'norm')
|
||||
|
||||
# Decomposition
|
||||
seasonal_init, trend_init = self.decomp(x_enc)
|
||||
|
||||
# Season stream
|
||||
y_season = self.season_net(seasonal_init) # [B, C, pred_len]
|
||||
|
||||
# Trend stream
|
||||
B, L, C = trend_init.shape
|
||||
trend = trend_init.permute(0, 2, 1).reshape(B * C, L) # [B*C, L]
|
||||
trend = self.fc5(trend)
|
||||
trend = self.avgpool1(trend)
|
||||
trend = self.ln1(trend)
|
||||
trend = self.fc6(trend)
|
||||
trend = self.avgpool2(trend)
|
||||
trend = self.ln2(trend)
|
||||
trend = self.fc7(trend) # [B*C, pred_len]
|
||||
y_trend = trend.view(B, C, -1) # [B, C, pred_len]
|
||||
|
||||
# Combine streams
|
||||
y = torch.cat([y_season, y_trend], dim=-1) # [B, C, 2*pred_len]
|
||||
y = self.fc_final(y) # [B, C, pred_len]
|
||||
y = y.permute(0, 2, 1) # [B, pred_len, C]
|
||||
|
||||
# Denormalization
|
||||
if self.revin:
|
||||
y = self.revin_layer(y, 'denorm')
|
||||
|
||||
return y
|
||||
|
||||
def classification(self, x_enc, x_mark_enc):
|
||||
"""Classification task"""
|
||||
# Normalization
|
||||
#if self.revin:
|
||||
# x_enc = self.revin_layer(x_enc, 'norm')
|
||||
|
||||
# Decomposition
|
||||
seasonal_init, trend_init = self.decomp(x_enc)
|
||||
|
||||
# Season stream
|
||||
y_season = self.season_net(seasonal_init) # [B, C, pred_len]
|
||||
|
||||
# print("shape:", trend_init.shape)
|
||||
# Trend stream
|
||||
B, L, C = trend_init.shape
|
||||
trend = trend_init.permute(0, 2, 1).reshape(B * C, L) # [B*C, L]
|
||||
trend = self.fc5(trend)
|
||||
trend = self.avgpool1(trend)
|
||||
trend = self.ln1(trend)
|
||||
trend = self.fc6(trend)
|
||||
trend = self.avgpool2(trend)
|
||||
trend = self.ln2(trend)
|
||||
trend = self.fc7(trend) # [B*C, pred_len]
|
||||
y_trend = trend.view(B, C, -1) # [B, C, pred_len]
|
||||
|
||||
# Attention-based pooling for classification
|
||||
season_attn_weights = torch.softmax(y_season, dim=-1)
|
||||
season_pooled = (y_season * season_attn_weights).sum(dim=-1) # [B, C]
|
||||
|
||||
trend_attn_weights = torch.softmax(y_trend, dim=-1) # 时间维
|
||||
trend_pooled = (y_trend * trend_attn_weights).sum(dim=-1) # [B, C]
|
||||
|
||||
# Combine features
|
||||
features = torch.cat([season_pooled, trend_pooled], dim=-1) # [B, 2*C]
|
||||
|
||||
# Classification
|
||||
logits = self.classifier(features) # [B, num_classes]
|
||||
return logits
|
||||
|
||||
def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
|
||||
"""Forward pass dispatching to task-specific methods"""
|
||||
if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
|
||||
dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
|
||||
return dec_out[:, -self.pred_len:, :] # [B, L, D]
|
||||
elif self.task_name == 'classification':
|
||||
dec_out = self.classification(x_enc, x_mark_enc)
|
||||
return dec_out # [B, N]
|
||||
else:
|
||||
raise ValueError(f'Task {self.task_name} not supported by xPatch_SparseChannel')
|
Reference in New Issue
Block a user