first timesnet try

This commit is contained in:
game-loader
2025-07-30 21:18:46 +08:00
parent dc8c9f1f09
commit 6ee5c769c4
17 changed files with 2918 additions and 0 deletions

View File

@ -0,0 +1,203 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
class my_Layernorm(nn.Module):
"""
Special designed layernorm for the seasonal part
"""
def __init__(self, channels):
super(my_Layernorm, self).__init__()
self.layernorm = nn.LayerNorm(channels)
def forward(self, x):
x_hat = self.layernorm(x)
bias = torch.mean(x_hat, dim=1).unsqueeze(1).repeat(1, x.shape[1], 1)
return x_hat - bias
class moving_avg(nn.Module):
"""
Moving average block to highlight the trend of time series
"""
def __init__(self, kernel_size, stride):
super(moving_avg, self).__init__()
self.kernel_size = kernel_size
self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0)
def forward(self, x):
# padding on the both ends of time series
front = x[:, 0:1, :].repeat(1, (self.kernel_size - 1) // 2, 1)
end = x[:, -1:, :].repeat(1, (self.kernel_size - 1) // 2, 1)
x = torch.cat([front, x, end], dim=1)
x = self.avg(x.permute(0, 2, 1))
x = x.permute(0, 2, 1)
return x
class series_decomp(nn.Module):
"""
Series decomposition block
"""
def __init__(self, kernel_size):
super(series_decomp, self).__init__()
self.moving_avg = moving_avg(kernel_size, stride=1)
def forward(self, x):
moving_mean = self.moving_avg(x)
res = x - moving_mean
return res, moving_mean
class series_decomp_multi(nn.Module):
"""
Multiple Series decomposition block from FEDformer
"""
def __init__(self, kernel_size):
super(series_decomp_multi, self).__init__()
self.kernel_size = kernel_size
self.series_decomp = [series_decomp(kernel) for kernel in kernel_size]
def forward(self, x):
moving_mean = []
res = []
for func in self.series_decomp:
sea, moving_avg = func(x)
moving_mean.append(moving_avg)
res.append(sea)
sea = sum(res) / len(res)
moving_mean = sum(moving_mean) / len(moving_mean)
return sea, moving_mean
class EncoderLayer(nn.Module):
"""
Autoformer encoder layer with the progressive decomposition architecture
"""
def __init__(self, attention, d_model, d_ff=None, moving_avg=25, dropout=0.1, activation="relu"):
super(EncoderLayer, self).__init__()
d_ff = d_ff or 4 * d_model
self.attention = attention
self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1, bias=False)
self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1, bias=False)
self.decomp1 = series_decomp(moving_avg)
self.decomp2 = series_decomp(moving_avg)
self.dropout = nn.Dropout(dropout)
self.activation = F.relu if activation == "relu" else F.gelu
def forward(self, x, attn_mask=None):
new_x, attn = self.attention(
x, x, x,
attn_mask=attn_mask
)
x = x + self.dropout(new_x)
x, _ = self.decomp1(x)
y = x
y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
y = self.dropout(self.conv2(y).transpose(-1, 1))
res, _ = self.decomp2(x + y)
return res, attn
class Encoder(nn.Module):
"""
Autoformer encoder
"""
def __init__(self, attn_layers, conv_layers=None, norm_layer=None):
super(Encoder, self).__init__()
self.attn_layers = nn.ModuleList(attn_layers)
self.conv_layers = nn.ModuleList(conv_layers) if conv_layers is not None else None
self.norm = norm_layer
def forward(self, x, attn_mask=None):
attns = []
if self.conv_layers is not None:
for attn_layer, conv_layer in zip(self.attn_layers, self.conv_layers):
x, attn = attn_layer(x, attn_mask=attn_mask)
x = conv_layer(x)
attns.append(attn)
x, attn = self.attn_layers[-1](x)
attns.append(attn)
else:
for attn_layer in self.attn_layers:
x, attn = attn_layer(x, attn_mask=attn_mask)
attns.append(attn)
if self.norm is not None:
x = self.norm(x)
return x, attns
class DecoderLayer(nn.Module):
"""
Autoformer decoder layer with the progressive decomposition architecture
"""
def __init__(self, self_attention, cross_attention, d_model, c_out, d_ff=None,
moving_avg=25, dropout=0.1, activation="relu"):
super(DecoderLayer, self).__init__()
d_ff = d_ff or 4 * d_model
self.self_attention = self_attention
self.cross_attention = cross_attention
self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1, bias=False)
self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1, bias=False)
self.decomp1 = series_decomp(moving_avg)
self.decomp2 = series_decomp(moving_avg)
self.decomp3 = series_decomp(moving_avg)
self.dropout = nn.Dropout(dropout)
self.projection = nn.Conv1d(in_channels=d_model, out_channels=c_out, kernel_size=3, stride=1, padding=1,
padding_mode='circular', bias=False)
self.activation = F.relu if activation == "relu" else F.gelu
def forward(self, x, cross, x_mask=None, cross_mask=None):
x = x + self.dropout(self.self_attention(
x, x, x,
attn_mask=x_mask
)[0])
x, trend1 = self.decomp1(x)
x = x + self.dropout(self.cross_attention(
x, cross, cross,
attn_mask=cross_mask
)[0])
x, trend2 = self.decomp2(x)
y = x
y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
y = self.dropout(self.conv2(y).transpose(-1, 1))
x, trend3 = self.decomp3(x + y)
residual_trend = trend1 + trend2 + trend3
residual_trend = self.projection(residual_trend.permute(0, 2, 1)).transpose(1, 2)
return x, residual_trend
class Decoder(nn.Module):
"""
Autoformer encoder
"""
def __init__(self, layers, norm_layer=None, projection=None):
super(Decoder, self).__init__()
self.layers = nn.ModuleList(layers)
self.norm = norm_layer
self.projection = projection
def forward(self, x, cross, x_mask=None, cross_mask=None, trend=None):
for layer in self.layers:
x, residual_trend = layer(x, cross, x_mask=x_mask, cross_mask=cross_mask)
trend = trend + residual_trend
if self.norm is not None:
x = self.norm(x)
if self.projection is not None:
x = self.projection(x)
return x, trend

234
models/TimeMixer++/Embed.py Normal file
View File

@ -0,0 +1,234 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import weight_norm
import math
class PositionalEmbedding(nn.Module):
def __init__(self, d_model, max_len=5000):
super(PositionalEmbedding, self).__init__()
# Compute the positional encodings once in log space.
pe = torch.zeros(max_len, d_model).float()
pe.require_grad = False
position = torch.arange(0, max_len).float().unsqueeze(1)
div_term = (torch.arange(0, d_model, 2).float()
* -(math.log(10000.0) / d_model)).exp()
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0)
self.register_buffer('pe', pe)
def forward(self, x):
return self.pe[:, :x.size(1)]
class TokenEmbedding(nn.Module):
def __init__(self, c_in, d_model):
super(TokenEmbedding, self).__init__()
padding = 1 if torch.__version__ >= '1.5.0' else 2
self.tokenConv = nn.Conv1d(in_channels=c_in, out_channels=d_model,
kernel_size=3, padding=padding, padding_mode='circular', bias=False)
for m in self.modules():
if isinstance(m, nn.Conv1d):
nn.init.kaiming_normal_(
m.weight, mode='fan_in', nonlinearity='leaky_relu')
def forward(self, x):
x = self.tokenConv(x.permute(0, 2, 1)).transpose(1, 2)
return x
class FixedEmbedding(nn.Module):
def __init__(self, c_in, d_model):
super(FixedEmbedding, self).__init__()
w = torch.zeros(c_in, d_model).float()
w.require_grad = False
position = torch.arange(0, c_in).float().unsqueeze(1)
div_term = (torch.arange(0, d_model, 2).float()
* -(math.log(10000.0) / d_model)).exp()
w[:, 0::2] = torch.sin(position * div_term)
w[:, 1::2] = torch.cos(position * div_term)
self.emb = nn.Embedding(c_in, d_model)
self.emb.weight = nn.Parameter(w, requires_grad=False)
def forward(self, x):
return self.emb(x).detach()
class TemporalEmbedding(nn.Module):
def __init__(self, d_model, embed_type='fixed', freq='h'):
super(TemporalEmbedding, self).__init__()
minute_size = 4
hour_size = 24
weekday_size = 7
day_size = 32
month_size = 13
Embed = FixedEmbedding if embed_type == 'fixed' else nn.Embedding
if freq == 't':
self.minute_embed = Embed(minute_size, d_model)
self.hour_embed = Embed(hour_size, d_model)
self.weekday_embed = Embed(weekday_size, d_model)
self.day_embed = Embed(day_size, d_model)
self.month_embed = Embed(month_size, d_model)
def forward(self, x):
x = x.long()
minute_x = self.minute_embed(x[:, :, 4]) if hasattr(
self, 'minute_embed') else 0.
hour_x = self.hour_embed(x[:, :, 3])
weekday_x = self.weekday_embed(x[:, :, 2])
day_x = self.day_embed(x[:, :, 1])
month_x = self.month_embed(x[:, :, 0])
return hour_x + weekday_x + day_x + month_x + minute_x
class TimeFeatureEmbedding(nn.Module):
def __init__(self, d_model, embed_type='timeF', freq='h'):
super(TimeFeatureEmbedding, self).__init__()
freq_map = {'h': 4, 't': 5, 's': 6, 'ms': 7,
'm': 1, 'a': 1, 'w': 2, 'd': 3, 'b': 3}
d_inp = freq_map[freq]
self.embed = nn.Linear(d_inp, d_model, bias=False)
def forward(self, x):
return self.embed(x)
class DataEmbedding(nn.Module):
def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1):
super(DataEmbedding, self).__init__()
self.c_in = c_in
self.d_model = d_model
self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model)
self.position_embedding = PositionalEmbedding(d_model=d_model)
self.temporal_embedding = TemporalEmbedding(d_model=d_model, embed_type=embed_type,
freq=freq) if embed_type != 'timeF' else TimeFeatureEmbedding(
d_model=d_model, embed_type=embed_type, freq=freq)
self.dropout = nn.Dropout(p=dropout)
def forward(self, x, x_mark):
_, _, N = x.size()
if N == self.c_in:
if x_mark is None:
x = self.value_embedding(x) + self.position_embedding(x)
else:
x = self.value_embedding(
x) + self.temporal_embedding(x_mark) + self.position_embedding(x)
elif N == self.d_model:
if x_mark is None:
x = x + self.position_embedding(x)
else:
x = x + self.temporal_embedding(x_mark) + self.position_embedding(x)
return self.dropout(x)
class DataEmbedding_ms(nn.Module):
def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1):
super(DataEmbedding_ms, self).__init__()
self.value_embedding = TokenEmbedding(c_in=1, d_model=d_model)
self.position_embedding = PositionalEmbedding(d_model=d_model)
self.temporal_embedding = TemporalEmbedding(d_model=d_model, embed_type=embed_type,
freq=freq) if embed_type != 'timeF' else TimeFeatureEmbedding(
d_model=d_model, embed_type=embed_type, freq=freq)
self.dropout = nn.Dropout(p=dropout)
def forward(self, x, x_mark):
B, T, N = x.shape
x1 = self.value_embedding(x.reshape(0, 2, 1).reshape(B * N, T).unsqueeze(-1)).reshape(B, N, T, -1).permute(0, 2,
1, 3)
if x_mark is None:
x = x1
else:
x = x1 + self.temporal_embedding(x_mark)
return self.dropout(x)
class DataEmbedding_wo_pos(nn.Module):
def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1):
super(DataEmbedding_wo_pos, self).__init__()
self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model)
self.position_embedding = PositionalEmbedding(d_model=d_model)
self.temporal_embedding = TemporalEmbedding(d_model=d_model, embed_type=embed_type,
freq=freq) if embed_type != 'timeF' else TimeFeatureEmbedding(
d_model=d_model, embed_type=embed_type, freq=freq)
self.dropout = nn.Dropout(p=dropout)
def forward(self, x, x_mark):
if x is None and x_mark is not None:
return self.temporal_embedding(x_mark)
if x_mark is None:
x = self.value_embedding(x)
else:
x = self.value_embedding(x) + self.temporal_embedding(x_mark)
return self.dropout(x)
class PatchEmbedding_crossformer(nn.Module):
def __init__(self, d_model, patch_len, stride, padding, dropout):
super(PatchEmbedding_crossformer, self).__init__()
# Patching
self.patch_len = patch_len
self.stride = stride
self.padding_patch_layer = nn.ReplicationPad1d((0, padding))
# Backbone, Input encoding: projection of feature vectors onto a d-dim vector space
self.value_embedding = nn.Linear(patch_len, d_model, bias=False)
# Positional embedding
self.position_embedding = PositionalEmbedding(d_model)
# Residual dropout
self.dropout = nn.Dropout(dropout)
def forward(self, x):
# do patching
n_vars = x.shape[1]
x = self.padding_patch_layer(x)
x = x.unfold(dimension=-1, size=self.patch_len, step=self.stride)
x = torch.reshape(x, (x.shape[0] * x.shape[1], x.shape[2], x.shape[3]))
# Input encoding
x = self.value_embedding(x) + self.position_embedding(x)
return self.dropout(x), n_vars
class PatchEmbedding(nn.Module):
def __init__(self, d_model, patch_len, stride, dropout):
super(PatchEmbedding, self).__init__()
# Patching
self.patch_len = patch_len
self.stride = stride
self.padding_patch_layer = nn.ReplicationPad1d((0, stride))
# Backbone, Input encoding: projection of feature vectors onto a d-dim vector space
self.value_embedding = TokenEmbedding(patch_len, d_model)
# Positional embedding
self.position_embedding = PositionalEmbedding(d_model)
# Residual dropout
self.dropout = nn.Dropout(dropout)
def forward(self, x):
# do patching
n_vars = x.shape[1]
x = self.padding_patch_layer(x)
x = x.unfold(dimension=-1, size=self.patch_len, step=self.stride)
x = torch.reshape(x, (x.shape[0] * x.shape[1], x.shape[2], x.shape[3]))
# Input encoding
x = self.value_embedding(x) + self.position_embedding(x)
return self.dropout(x), n_vars

View File

@ -0,0 +1,67 @@
import torch
import torch.nn as nn
class Normalize(nn.Module):
def __init__(self, num_features: int, eps=1e-5, affine=False, subtract_last=False, non_norm=False):
"""
:param num_features: the number of features or channels
:param eps: a value added for numerical stability
:param affine: if True, RevIN has learnable affine parameters
"""
super(Normalize, self).__init__()
self.num_features = num_features
self.eps = eps
self.affine = affine
self.subtract_last = subtract_last
self.non_norm = non_norm
if self.affine:
self._init_params()
def forward(self, x, mode: str):
if mode == 'norm':
self._get_statistics(x)
x = self._normalize(x)
elif mode == 'denorm':
x = self._denormalize(x)
else:
raise NotImplementedError
return x
def _init_params(self):
# initialize RevIN params: (C,)
self.affine_weight = nn.Parameter(torch.ones(self.num_features))
self.affine_bias = nn.Parameter(torch.zeros(self.num_features))
def _get_statistics(self, x):
dim2reduce = tuple(range(1, x.ndim - 1))
if self.subtract_last:
self.last = x[:, -1, :].unsqueeze(1)
else:
self.mean = torch.mean(x, dim=dim2reduce, keepdim=True).detach()
self.stdev = torch.sqrt(torch.var(x, dim=dim2reduce, keepdim=True, unbiased=False) + self.eps).detach()
def _normalize(self, x):
if self.non_norm:
return x
if self.subtract_last:
x = x - self.last
else:
x = x - self.mean
x = x / self.stdev
if self.affine:
x = x * self.affine_weight
x = x + self.affine_bias
return x
def _denormalize(self, x):
if self.non_norm:
return x
if self.affine:
x = x - self.affine_bias
x = x / (self.affine_weight + self.eps * self.eps)
x = x * self.stdev
if self.subtract_last:
x = x + self.last
else:
x = x + self.mean
return x

View File

@ -0,0 +1,527 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from layers.Autoformer_EncDec import series_decomp
from layers.Embed import DataEmbedding_wo_pos
from layers.StandardNorm import Normalize
class DFT_series_decomp(nn.Module):
"""
Series decomposition block
"""
def __init__(self, top_k=5):
super(DFT_series_decomp, self).__init__()
self.top_k = top_k
def forward(self, x):
xf = torch.fft.rfft(x)
freq = abs(xf)
freq[0] = 0
top_k_freq, top_list = torch.topk(freq, self.top_k)
xf[freq <= top_k_freq.min()] = 0
x_season = torch.fft.irfft(xf)
x_trend = x - x_season
return x_season, x_trend
class MultiScaleSeasonMixing(nn.Module):
"""
Bottom-up mixing season pattern
"""
def __init__(self, configs):
super(MultiScaleSeasonMixing, self).__init__()
self.down_sampling_layers = torch.nn.ModuleList(
[
nn.Sequential(
torch.nn.Linear(
configs.seq_len // (configs.down_sampling_window ** i),
configs.seq_len // (configs.down_sampling_window ** (i + 1)),
),
nn.GELU(),
torch.nn.Linear(
configs.seq_len // (configs.down_sampling_window ** (i + 1)),
configs.seq_len // (configs.down_sampling_window ** (i + 1)),
),
)
for i in range(configs.down_sampling_layers)
]
)
def forward(self, season_list):
# mixing high->low
out_high = season_list[0]
out_low = season_list[1]
out_season_list = [out_high.permute(0, 2, 1)]
for i in range(len(season_list) - 1):
out_low_res = self.down_sampling_layers[i](out_high)
out_low = out_low + out_low_res
out_high = out_low
if i + 2 <= len(season_list) - 1:
out_low = season_list[i + 2]
out_season_list.append(out_high.permute(0, 2, 1))
return out_season_list
class MultiScaleTrendMixing(nn.Module):
"""
Top-down mixing trend pattern
"""
def __init__(self, configs):
super(MultiScaleTrendMixing, self).__init__()
self.up_sampling_layers = torch.nn.ModuleList(
[
nn.Sequential(
torch.nn.Linear(
configs.seq_len // (configs.down_sampling_window ** (i + 1)),
configs.seq_len // (configs.down_sampling_window ** i),
),
nn.GELU(),
torch.nn.Linear(
configs.seq_len // (configs.down_sampling_window ** i),
configs.seq_len // (configs.down_sampling_window ** i),
),
)
for i in reversed(range(configs.down_sampling_layers))
])
def forward(self, trend_list):
# mixing low->high
trend_list_reverse = trend_list.copy()
trend_list_reverse.reverse()
out_low = trend_list_reverse[0]
out_high = trend_list_reverse[1]
out_trend_list = [out_low.permute(0, 2, 1)]
for i in range(len(trend_list_reverse) - 1):
out_high_res = self.up_sampling_layers[i](out_low)
out_high = out_high + out_high_res
out_low = out_high
if i + 2 <= len(trend_list_reverse) - 1:
out_high = trend_list_reverse[i + 2]
out_trend_list.append(out_low.permute(0, 2, 1))
out_trend_list.reverse()
return out_trend_list
class PastDecomposableMixing(nn.Module):
def __init__(self, configs):
super(PastDecomposableMixing, self).__init__()
self.seq_len = configs.seq_len
self.pred_len = configs.pred_len
self.down_sampling_window = configs.down_sampling_window
self.layer_norm = nn.LayerNorm(configs.d_model)
self.dropout = nn.Dropout(configs.dropout)
self.channel_independence = configs.channel_independence
if configs.decomp_method == 'moving_avg':
self.decompsition = series_decomp(configs.moving_avg)
elif configs.decomp_method == "dft_decomp":
self.decompsition = DFT_series_decomp(configs.top_k)
else:
raise ValueError('decompsition is error')
if configs.channel_independence == 0:
self.cross_layer = nn.Sequential(
nn.Linear(in_features=configs.d_model, out_features=configs.d_ff),
nn.GELU(),
nn.Linear(in_features=configs.d_ff, out_features=configs.d_model),
)
# Mixing season
self.mixing_multi_scale_season = MultiScaleSeasonMixing(configs)
# Mxing trend
self.mixing_multi_scale_trend = MultiScaleTrendMixing(configs)
self.out_cross_layer = nn.Sequential(
nn.Linear(in_features=configs.d_model, out_features=configs.d_ff),
nn.GELU(),
nn.Linear(in_features=configs.d_ff, out_features=configs.d_model),
)
def forward(self, x_list):
length_list = []
for x in x_list:
_, T, _ = x.size()
length_list.append(T)
# Decompose to obtain the season and trend
season_list = []
trend_list = []
for x in x_list:
season, trend = self.decompsition(x)
if self.channel_independence == 0:
season = self.cross_layer(season)
trend = self.cross_layer(trend)
season_list.append(season.permute(0, 2, 1))
trend_list.append(trend.permute(0, 2, 1))
# bottom-up season mixing
out_season_list = self.mixing_multi_scale_season(season_list)
# top-down trend mixing
out_trend_list = self.mixing_multi_scale_trend(trend_list)
out_list = []
for ori, out_season, out_trend, length in zip(x_list, out_season_list, out_trend_list,
length_list):
out = out_season + out_trend
if self.channel_independence:
out = ori + self.out_cross_layer(out)
out_list.append(out[:, :length, :])
return out_list
class TimeMixer(nn.Module):
def __init__(self, configs):
super(TimeMixer, self).__init__()
self.configs = configs
self.task_name = configs.task_name
self.seq_len = configs.seq_len
self.label_len = configs.label_len
self.pred_len = configs.pred_len
self.down_sampling_window = configs.down_sampling_window
self.channel_independence = configs.channel_independence
self.pdm_blocks = nn.ModuleList([PastDecomposableMixing(configs)
for _ in range(configs.e_layers)])
self.preprocess = series_decomp(configs.moving_avg)
self.enc_in = configs.enc_in
self.use_future_temporal_feature = configs.use_future_temporal_feature
if self.channel_independence == 1:
self.enc_embedding = DataEmbedding_wo_pos(1, configs.d_model, configs.embed, configs.freq,
configs.dropout)
else:
self.enc_embedding = DataEmbedding_wo_pos(configs.enc_in, configs.d_model, configs.embed, configs.freq,
configs.dropout)
self.layer = configs.e_layers
self.normalize_layers = torch.nn.ModuleList(
[
Normalize(self.configs.enc_in, affine=True, non_norm=True if configs.use_norm == 0 else False)
for i in range(configs.down_sampling_layers + 1)
]
)
if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
self.predict_layers = torch.nn.ModuleList(
[
torch.nn.Linear(
configs.seq_len // (configs.down_sampling_window ** i),
configs.pred_len,
)
for i in range(configs.down_sampling_layers + 1)
]
)
if self.channel_independence == 1:
self.projection_layer = nn.Linear(
configs.d_model, 1, bias=True)
else:
self.projection_layer = nn.Linear(
configs.d_model, configs.c_out, bias=True)
self.out_res_layers = torch.nn.ModuleList([
torch.nn.Linear(
configs.seq_len // (configs.down_sampling_window ** i),
configs.seq_len // (configs.down_sampling_window ** i),
)
for i in range(configs.down_sampling_layers + 1)
])
self.regression_layers = torch.nn.ModuleList(
[
torch.nn.Linear(
configs.seq_len // (configs.down_sampling_window ** i),
configs.pred_len,
)
for i in range(configs.down_sampling_layers + 1)
]
)
if self.task_name == 'imputation' or self.task_name == 'anomaly_detection':
if self.channel_independence == 1:
self.projection_layer = nn.Linear(
configs.d_model, 1, bias=True)
else:
self.projection_layer = nn.Linear(
configs.d_model, configs.c_out, bias=True)
if self.task_name == 'classification':
self.act = F.gelu
self.dropout = nn.Dropout(configs.dropout)
self.projection = nn.Linear(
configs.d_model * configs.seq_len, configs.num_class)
def out_projection(self, dec_out, i, out_res):
dec_out = self.projection_layer(dec_out)
out_res = out_res.permute(0, 2, 1)
out_res = self.out_res_layers[i](out_res)
out_res = self.regression_layers[i](out_res).permute(0, 2, 1)
dec_out = dec_out + out_res
return dec_out
def pre_enc(self, x_list):
if self.channel_independence == 1:
return (x_list, None)
else:
out1_list = []
out2_list = []
for x in x_list:
x_1, x_2 = self.preprocess(x)
out1_list.append(x_1)
out2_list.append(x_2)
return (out1_list, out2_list)
def __multi_scale_process_inputs(self, x_enc, x_mark_enc):
if self.configs.down_sampling_method == 'max':
down_pool = torch.nn.MaxPool1d(self.configs.down_sampling_window, return_indices=False)
elif self.configs.down_sampling_method == 'avg':
down_pool = torch.nn.AvgPool1d(self.configs.down_sampling_window)
elif self.configs.down_sampling_method == 'conv':
padding = 1 if torch.__version__ >= '1.5.0' else 2
down_pool = nn.Conv1d(in_channels=self.configs.enc_in, out_channels=self.configs.enc_in,
kernel_size=3, padding=padding,
stride=self.configs.down_sampling_window,
padding_mode='circular',
bias=False)
else:
return x_enc, x_mark_enc
# B,T,C -> B,C,T
x_enc = x_enc.permute(0, 2, 1)
x_enc_ori = x_enc
x_mark_enc_mark_ori = x_mark_enc
x_enc_sampling_list = []
x_mark_sampling_list = []
x_enc_sampling_list.append(x_enc.permute(0, 2, 1))
x_mark_sampling_list.append(x_mark_enc)
for i in range(self.configs.down_sampling_layers):
x_enc_sampling = down_pool(x_enc_ori)
x_enc_sampling_list.append(x_enc_sampling.permute(0, 2, 1))
x_enc_ori = x_enc_sampling
if x_mark_enc_mark_ori is not None:
x_mark_sampling_list.append(x_mark_enc_mark_ori[:, ::self.configs.down_sampling_window, :])
x_mark_enc_mark_ori = x_mark_enc_mark_ori[:, ::self.configs.down_sampling_window, :]
x_enc = x_enc_sampling_list
if x_mark_enc_mark_ori is not None:
x_mark_enc = x_mark_sampling_list
else:
x_mark_enc = x_mark_enc
return x_enc, x_mark_enc
def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
if self.use_future_temporal_feature:
if self.channel_independence == 1:
B, T, N = x_enc.size()
x_mark_dec = x_mark_dec.repeat(N, 1, 1)
self.x_mark_dec = self.enc_embedding(None, x_mark_dec)
else:
self.x_mark_dec = self.enc_embedding(None, x_mark_dec)
x_enc, x_mark_enc = self.__multi_scale_process_inputs(x_enc, x_mark_enc)
x_list = []
x_mark_list = []
if x_mark_enc is not None:
for i, x, x_mark in zip(range(len(x_enc)), x_enc, x_mark_enc):
B, T, N = x.size()
x = self.normalize_layers[i](x, 'norm')
if self.channel_independence == 1:
x = x.permute(0, 2, 1).contiguous().reshape(B * N, T, 1)
x_mark = x_mark.repeat(N, 1, 1)
x_list.append(x)
x_mark_list.append(x_mark)
else:
for i, x in zip(range(len(x_enc)), x_enc, ):
B, T, N = x.size()
x = self.normalize_layers[i](x, 'norm')
if self.channel_independence == 1:
x = x.permute(0, 2, 1).contiguous().reshape(B * N, T, 1)
x_list.append(x)
# embedding
enc_out_list = []
x_list = self.pre_enc(x_list)
if x_mark_enc is not None:
for i, x, x_mark in zip(range(len(x_list[0])), x_list[0], x_mark_list):
enc_out = self.enc_embedding(x, x_mark) # [B,T,C]
enc_out_list.append(enc_out)
else:
for i, x in zip(range(len(x_list[0])), x_list[0]):
enc_out = self.enc_embedding(x, None) # [B,T,C]
enc_out_list.append(enc_out)
# Past Decomposable Mixing as encoder for past
for i in range(self.layer):
enc_out_list = self.pdm_blocks[i](enc_out_list)
# Future Multipredictor Mixing as decoder for future
dec_out_list = self.future_multi_mixing(B, enc_out_list, x_list)
dec_out = torch.stack(dec_out_list, dim=-1).sum(-1)
dec_out = self.normalize_layers[0](dec_out, 'denorm')
return dec_out
def future_multi_mixing(self, B, enc_out_list, x_list):
dec_out_list = []
if self.channel_independence == 1:
x_list = x_list[0]
for i, enc_out in zip(range(len(x_list)), enc_out_list):
dec_out = self.predict_layers[i](enc_out.permute(0, 2, 1)).permute(
0, 2, 1) # align temporal dimension
if self.use_future_temporal_feature:
dec_out = dec_out + self.x_mark_dec
dec_out = self.projection_layer(dec_out)
else:
dec_out = self.projection_layer(dec_out)
dec_out = dec_out.reshape(B, self.configs.c_out, self.pred_len).permute(0, 2, 1).contiguous()
dec_out_list.append(dec_out)
else:
for i, enc_out, out_res in zip(range(len(x_list[0])), enc_out_list, x_list[1]):
dec_out = self.predict_layers[i](enc_out.permute(0, 2, 1)).permute(
0, 2, 1) # align temporal dimension
dec_out = self.out_projection(dec_out, i, out_res)
dec_out_list.append(dec_out)
return dec_out_list
def classification(self, x_enc, x_mark_enc):
x_enc, _ = self.__multi_scale_process_inputs(x_enc, None)
x_list = x_enc
# embedding
enc_out_list = []
for x in x_list:
enc_out = self.enc_embedding(x, None) # [B,T,C]
enc_out_list.append(enc_out)
# MultiScale-CrissCrossAttention as encoder for past
for i in range(self.layer):
enc_out_list = self.pdm_blocks[i](enc_out_list)
enc_out = enc_out_list[0]
# Output
# the output transformer encoder/decoder embeddings don't include non-linearity
output = self.act(enc_out)
output = self.dropout(output)
# zero-out padding embeddings
output = output * x_mark_enc.unsqueeze(-1)
# (batch_size, seq_length * d_model)
output = output.reshape(output.shape[0], -1)
output = self.projection(output) # (batch_size, num_classes)
return output
def anomaly_detection(self, x_enc):
B, T, N = x_enc.size()
x_enc, _ = self.__multi_scale_process_inputs(x_enc, None)
x_list = []
for i, x in zip(range(len(x_enc)), x_enc, ):
B, T, N = x.size()
x = self.normalize_layers[i](x, 'norm')
if self.channel_independence == 1:
x = x.permute(0, 2, 1).contiguous().reshape(B * N, T, 1)
x_list.append(x)
# embedding
enc_out_list = []
for x in x_list:
enc_out = self.enc_embedding(x, None) # [B,T,C]
enc_out_list.append(enc_out)
# MultiScale-CrissCrossAttention as encoder for past
for i in range(self.layer):
enc_out_list = self.pdm_blocks[i](enc_out_list)
dec_out = self.projection_layer(enc_out_list[0])
dec_out = dec_out.reshape(B, self.configs.c_out, -1).permute(0, 2, 1).contiguous()
dec_out = self.normalize_layers[0](dec_out, 'denorm')
return dec_out
def imputation(self, x_enc, x_mark_enc, mask):
means = torch.sum(x_enc, dim=1) / torch.sum(mask == 1, dim=1)
means = means.unsqueeze(1).detach()
x_enc = x_enc - means
x_enc = x_enc.masked_fill(mask == 0, 0)
stdev = torch.sqrt(torch.sum(x_enc * x_enc, dim=1) /
torch.sum(mask == 1, dim=1) + 1e-5)
stdev = stdev.unsqueeze(1).detach()
x_enc /= stdev
B, T, N = x_enc.size()
x_enc, x_mark_enc = self.__multi_scale_process_inputs(x_enc, x_mark_enc)
x_list = []
x_mark_list = []
if x_mark_enc is not None:
for i, x, x_mark in zip(range(len(x_enc)), x_enc, x_mark_enc):
B, T, N = x.size()
if self.channel_independence == 1:
x = x.permute(0, 2, 1).contiguous().reshape(B * N, T, 1)
x_list.append(x)
x_mark = x_mark.repeat(N, 1, 1)
x_mark_list.append(x_mark)
else:
for i, x in zip(range(len(x_enc)), x_enc, ):
B, T, N = x.size()
if self.channel_independence == 1:
x = x.permute(0, 2, 1).contiguous().reshape(B * N, T, 1)
x_list.append(x)
# embedding
enc_out_list = []
for x in x_list:
enc_out = self.enc_embedding(x, None) # [B,T,C]
enc_out_list.append(enc_out)
# MultiScale-CrissCrossAttention as encoder for past
for i in range(self.layer):
enc_out_list = self.pdm_blocks[i](enc_out_list)
dec_out = self.projection_layer(enc_out_list[0])
dec_out = dec_out.reshape(B, self.configs.c_out, -1).permute(0, 2, 1).contiguous()
dec_out = dec_out * \
(stdev[:, 0, :].unsqueeze(1).repeat(1, self.seq_len, 1))
dec_out = dec_out + \
(means[:, 0, :].unsqueeze(1).repeat(1, self.seq_len, 1))
return dec_out
def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
return dec_out
if self.task_name == 'imputation':
dec_out = self.imputation(x_enc, x_mark_enc, mask)
return dec_out # [B, L, D]
if self.task_name == 'anomaly_detection':
dec_out = self.anomaly_detection(x_enc)
return dec_out # [B, L, D]
if self.task_name == 'classification':
dec_out = self.classification(x_enc, x_mark_enc)
return dec_out # [B, N]
else:
raise ValueError('Other tasks implemented yet')

View File