first commit
layers/AutoCorrelation.py (new file, 163 lines)
@@ -0,0 +1,163 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import math
from math import sqrt
import os


class AutoCorrelation(nn.Module):
    """
    AutoCorrelation Mechanism with the following two phases:
    (1) period-based dependencies discovery
    (2) time delay aggregation
    This block can replace the self-attention family mechanism seamlessly.
    """

    def __init__(self, mask_flag=True, factor=1, scale=None, attention_dropout=0.1, output_attention=False):
        super(AutoCorrelation, self).__init__()
        self.factor = factor
        self.scale = scale
        self.mask_flag = mask_flag
        self.output_attention = output_attention
        self.dropout = nn.Dropout(attention_dropout)

    def time_delay_agg_training(self, values, corr):
        """
        SpeedUp version of Autocorrelation (a batch-normalization style design)
        This is for the training phase.
        """
        head = values.shape[1]
        channel = values.shape[2]
        length = values.shape[3]
        # find top k
        top_k = int(self.factor * math.log(length))
        mean_value = torch.mean(torch.mean(corr, dim=1), dim=1)
        index = torch.topk(torch.mean(mean_value, dim=0), top_k, dim=-1)[1]
        weights = torch.stack([mean_value[:, index[i]] for i in range(top_k)], dim=-1)
        # update corr
        tmp_corr = torch.softmax(weights, dim=-1)
        # aggregation
        tmp_values = values
        delays_agg = torch.zeros_like(values).float()
        for i in range(top_k):
            pattern = torch.roll(tmp_values, -int(index[i]), -1)
            delays_agg = delays_agg + pattern * \
                         (tmp_corr[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length))
        return delays_agg

    def time_delay_agg_inference(self, values, corr):
        """
        SpeedUp version of Autocorrelation (a batch-normalization style design)
        This is for the inference phase.
        """
        batch = values.shape[0]
        head = values.shape[1]
        channel = values.shape[2]
        length = values.shape[3]
        # index init
        init_index = torch.arange(length).unsqueeze(0).unsqueeze(0).unsqueeze(0).repeat(batch, head, channel, 1).to(values.device)
        # find top k
        top_k = int(self.factor * math.log(length))
        mean_value = torch.mean(torch.mean(corr, dim=1), dim=1)
        weights, delay = torch.topk(mean_value, top_k, dim=-1)
        # update corr
        tmp_corr = torch.softmax(weights, dim=-1)
        # aggregation
        tmp_values = values.repeat(1, 1, 1, 2)
        delays_agg = torch.zeros_like(values).float()
        for i in range(top_k):
            tmp_delay = init_index + delay[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length)
            pattern = torch.gather(tmp_values, dim=-1, index=tmp_delay)
            delays_agg = delays_agg + pattern * \
                         (tmp_corr[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length))
        return delays_agg

    def time_delay_agg_full(self, values, corr):
        """
        Standard version of Autocorrelation
        """
        batch = values.shape[0]
        head = values.shape[1]
        channel = values.shape[2]
        length = values.shape[3]
        # index init
        init_index = torch.arange(length).unsqueeze(0).unsqueeze(0).unsqueeze(0).repeat(batch, head, channel, 1).to(values.device)
        # find top k
        top_k = int(self.factor * math.log(length))
        weights, delay = torch.topk(corr, top_k, dim=-1)
        # update corr
        tmp_corr = torch.softmax(weights, dim=-1)
        # aggregation
        tmp_values = values.repeat(1, 1, 1, 2)
        delays_agg = torch.zeros_like(values).float()
        for i in range(top_k):
            tmp_delay = init_index + delay[..., i].unsqueeze(-1)
            pattern = torch.gather(tmp_values, dim=-1, index=tmp_delay)
            delays_agg = delays_agg + pattern * (tmp_corr[..., i].unsqueeze(-1))
        return delays_agg

    def forward(self, queries, keys, values, attn_mask):
        B, L, H, E = queries.shape
        _, S, _, D = values.shape
        if L > S:
            zeros = torch.zeros_like(queries[:, :(L - S), :]).float()
            values = torch.cat([values, zeros], dim=1)
            keys = torch.cat([keys, zeros], dim=1)
        else:
            values = values[:, :L, :, :]
            keys = keys[:, :L, :, :]

        # period-based dependencies
        q_fft = torch.fft.rfft(queries.permute(0, 2, 3, 1).contiguous(), dim=-1)
        k_fft = torch.fft.rfft(keys.permute(0, 2, 3, 1).contiguous(), dim=-1)
        res = q_fft * torch.conj(k_fft)
        corr = torch.fft.irfft(res, dim=-1)

        # time delay agg
        if self.training:
            V = self.time_delay_agg_training(values.permute(0, 2, 3, 1).contiguous(), corr).permute(0, 3, 1, 2)
        else:
            V = self.time_delay_agg_inference(values.permute(0, 2, 3, 1).contiguous(), corr).permute(0, 3, 1, 2)

        if self.output_attention:
            return (V.contiguous(), corr.permute(0, 3, 1, 2))
        else:
            return (V.contiguous(), None)


class AutoCorrelationLayer(nn.Module):
    def __init__(self, correlation, d_model, n_heads, d_keys=None,
                 d_values=None):
        super(AutoCorrelationLayer, self).__init__()

        d_keys = d_keys or (d_model // n_heads)
        d_values = d_values or (d_model // n_heads)

        self.inner_correlation = correlation
        self.query_projection = nn.Linear(d_model, d_keys * n_heads)
        self.key_projection = nn.Linear(d_model, d_keys * n_heads)
        self.value_projection = nn.Linear(d_model, d_values * n_heads)
        self.out_projection = nn.Linear(d_values * n_heads, d_model)
        self.n_heads = n_heads

    def forward(self, queries, keys, values, attn_mask):
        B, L, _ = queries.shape
        _, S, _ = keys.shape
        H = self.n_heads

        queries = self.query_projection(queries).view(B, L, H, -1)
        keys = self.key_projection(keys).view(B, S, H, -1)
        values = self.value_projection(values).view(B, S, H, -1)

        out, attn = self.inner_correlation(
            queries,
            keys,
            values,
            attn_mask
        )
        out = out.view(B, L, -1)

        return self.out_projection(out), attn
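
For reference, a minimal sketch of how the two classes above are typically wired together; the batch size, sequence length, and widths below are illustrative assumptions, not values taken from this commit.

import torch
from layers.AutoCorrelation import AutoCorrelation, AutoCorrelationLayer

B, L, d_model, n_heads = 4, 96, 64, 8            # hypothetical sizes
layer = AutoCorrelationLayer(
    AutoCorrelation(mask_flag=False, factor=1, attention_dropout=0.1, output_attention=False),
    d_model, n_heads)

x = torch.randn(B, L, d_model)                   # [batch, time, d_model]
out, attn = layer(x, x, x, attn_mask=None)       # self-correlation: queries = keys = values
print(out.shape)                                 # torch.Size([4, 96, 64]); attn is None here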
layers/Autoformer_EncDec.py (new file, 203 lines)
@@ -0,0 +1,203 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class my_Layernorm(nn.Module):
    """
    Special designed layernorm for the seasonal part
    """

    def __init__(self, channels):
        super(my_Layernorm, self).__init__()
        self.layernorm = nn.LayerNorm(channels)

    def forward(self, x):
        x_hat = self.layernorm(x)
        bias = torch.mean(x_hat, dim=1).unsqueeze(1).repeat(1, x.shape[1], 1)
        return x_hat - bias


class moving_avg(nn.Module):
    """
    Moving average block to highlight the trend of time series
    """

    def __init__(self, kernel_size, stride):
        super(moving_avg, self).__init__()
        self.kernel_size = kernel_size
        self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0)

    def forward(self, x):
        # padding on both ends of the time series
        front = x[:, 0:1, :].repeat(1, (self.kernel_size - 1) // 2, 1)
        end = x[:, -1:, :].repeat(1, (self.kernel_size - 1) // 2, 1)
        x = torch.cat([front, x, end], dim=1)
        x = self.avg(x.permute(0, 2, 1))
        x = x.permute(0, 2, 1)
        return x


class series_decomp(nn.Module):
    """
    Series decomposition block
    """

    def __init__(self, kernel_size):
        super(series_decomp, self).__init__()
        self.moving_avg = moving_avg(kernel_size, stride=1)

    def forward(self, x):
        moving_mean = self.moving_avg(x)
        res = x - moving_mean
        return res, moving_mean


class series_decomp_multi(nn.Module):
    """
    Multiple series decomposition block from FEDformer
    """

    def __init__(self, kernel_size):
        super(series_decomp_multi, self).__init__()
        self.kernel_size = kernel_size
        self.series_decomp = [series_decomp(kernel) for kernel in kernel_size]

    def forward(self, x):
        moving_mean = []
        res = []
        for func in self.series_decomp:
            sea, moving_avg = func(x)
            moving_mean.append(moving_avg)
            res.append(sea)

        sea = sum(res) / len(res)
        moving_mean = sum(moving_mean) / len(moving_mean)
        return sea, moving_mean


class EncoderLayer(nn.Module):
    """
    Autoformer encoder layer with the progressive decomposition architecture
    """

    def __init__(self, attention, d_model, d_ff=None, moving_avg=25, dropout=0.1, activation="relu"):
        super(EncoderLayer, self).__init__()
        d_ff = d_ff or 4 * d_model
        self.attention = attention
        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1, bias=False)
        self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1, bias=False)
        self.decomp1 = series_decomp(moving_avg)
        self.decomp2 = series_decomp(moving_avg)
        self.dropout = nn.Dropout(dropout)
        self.activation = F.relu if activation == "relu" else F.gelu

    def forward(self, x, attn_mask=None):
        new_x, attn = self.attention(
            x, x, x,
            attn_mask=attn_mask
        )
        x = x + self.dropout(new_x)
        x, _ = self.decomp1(x)
        y = x
        y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
        y = self.dropout(self.conv2(y).transpose(-1, 1))
        res, _ = self.decomp2(x + y)
        return res, attn


class Encoder(nn.Module):
    """
    Autoformer encoder
    """

    def __init__(self, attn_layers, conv_layers=None, norm_layer=None):
        super(Encoder, self).__init__()
        self.attn_layers = nn.ModuleList(attn_layers)
        self.conv_layers = nn.ModuleList(conv_layers) if conv_layers is not None else None
        self.norm = norm_layer

    def forward(self, x, attn_mask=None):
        attns = []
        if self.conv_layers is not None:
            for attn_layer, conv_layer in zip(self.attn_layers, self.conv_layers):
                x, attn = attn_layer(x, attn_mask=attn_mask)
                x = conv_layer(x)
                attns.append(attn)
            x, attn = self.attn_layers[-1](x)
            attns.append(attn)
        else:
            for attn_layer in self.attn_layers:
                x, attn = attn_layer(x, attn_mask=attn_mask)
                attns.append(attn)

        if self.norm is not None:
            x = self.norm(x)

        return x, attns


class DecoderLayer(nn.Module):
    """
    Autoformer decoder layer with the progressive decomposition architecture
    """

    def __init__(self, self_attention, cross_attention, d_model, c_out, d_ff=None,
                 moving_avg=25, dropout=0.1, activation="relu"):
        super(DecoderLayer, self).__init__()
        d_ff = d_ff or 4 * d_model
        self.self_attention = self_attention
        self.cross_attention = cross_attention
        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1, bias=False)
        self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1, bias=False)
        self.decomp1 = series_decomp(moving_avg)
        self.decomp2 = series_decomp(moving_avg)
        self.decomp3 = series_decomp(moving_avg)
        self.dropout = nn.Dropout(dropout)
        self.projection = nn.Conv1d(in_channels=d_model, out_channels=c_out, kernel_size=3, stride=1, padding=1,
                                    padding_mode='circular', bias=False)
        self.activation = F.relu if activation == "relu" else F.gelu

    def forward(self, x, cross, x_mask=None, cross_mask=None):
        x = x + self.dropout(self.self_attention(
            x, x, x,
            attn_mask=x_mask
        )[0])
        x, trend1 = self.decomp1(x)
        x = x + self.dropout(self.cross_attention(
            x, cross, cross,
            attn_mask=cross_mask
        )[0])
        x, trend2 = self.decomp2(x)
        y = x
        y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
        y = self.dropout(self.conv2(y).transpose(-1, 1))
        x, trend3 = self.decomp3(x + y)

        residual_trend = trend1 + trend2 + trend3
        residual_trend = self.projection(residual_trend.permute(0, 2, 1)).transpose(1, 2)
        return x, residual_trend


class Decoder(nn.Module):
    """
    Autoformer decoder
    """

    def __init__(self, layers, norm_layer=None, projection=None):
        super(Decoder, self).__init__()
        self.layers = nn.ModuleList(layers)
        self.norm = norm_layer
        self.projection = projection

    def forward(self, x, cross, x_mask=None, cross_mask=None, trend=None):
        for layer in self.layers:
            x, residual_trend = layer(x, cross, x_mask=x_mask, cross_mask=cross_mask)
            trend = trend + residual_trend

        if self.norm is not None:
            x = self.norm(x)

        if self.projection is not None:
            x = self.projection(x)
        return x, trend
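
A small self-contained check of the decomposition block above (the shapes are illustrative assumptions): the seasonal and trend parts keep the input shape and add back to the original series.

import torch
from layers.Autoformer_EncDec import series_decomp

x = torch.randn(2, 96, 7)                 # [batch, length, channels]
decomp = series_decomp(kernel_size=25)
seasonal, trend = decomp(x)
assert seasonal.shape == trend.shape == x.shape
assert torch.allclose(seasonal + trend, x, atol=1e-5)   # decomposition is exact up to float rounding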
60
layers/Conv_Blocks.py
Normal file
60
layers/Conv_Blocks.py
Normal file
@ -0,0 +1,60 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class Inception_Block_V1(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, num_kernels=6, init_weight=True):
|
||||
super(Inception_Block_V1, self).__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.num_kernels = num_kernels
|
||||
kernels = []
|
||||
for i in range(self.num_kernels):
|
||||
kernels.append(nn.Conv2d(in_channels, out_channels, kernel_size=2 * i + 1, padding=i))
|
||||
self.kernels = nn.ModuleList(kernels)
|
||||
if init_weight:
|
||||
self._initialize_weights()
|
||||
|
||||
def _initialize_weights(self):
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
||||
if m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
def forward(self, x):
|
||||
res_list = []
|
||||
for i in range(self.num_kernels):
|
||||
res_list.append(self.kernels[i](x))
|
||||
res = torch.stack(res_list, dim=-1).mean(-1)
|
||||
return res
|
||||
|
||||
|
||||
class Inception_Block_V2(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, num_kernels=6, init_weight=True):
|
||||
super(Inception_Block_V2, self).__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.num_kernels = num_kernels
|
||||
kernels = []
|
||||
for i in range(self.num_kernels // 2):
|
||||
kernels.append(nn.Conv2d(in_channels, out_channels, kernel_size=[1, 2 * i + 3], padding=[0, i + 1]))
|
||||
kernels.append(nn.Conv2d(in_channels, out_channels, kernel_size=[2 * i + 3, 1], padding=[i + 1, 0]))
|
||||
kernels.append(nn.Conv2d(in_channels, out_channels, kernel_size=1))
|
||||
self.kernels = nn.ModuleList(kernels)
|
||||
if init_weight:
|
||||
self._initialize_weights()
|
||||
|
||||
def _initialize_weights(self):
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
||||
if m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
def forward(self, x):
|
||||
res_list = []
|
||||
for i in range(self.num_kernels // 2 * 2 + 1):
|
||||
res_list.append(self.kernels[i](x))
|
||||
res = torch.stack(res_list, dim=-1).mean(-1)
|
||||
return res
|
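
A quick shape check for the inception block: every kernel size uses matching padding, so the spatial dimensions are preserved and the branch outputs can be averaged. Sizes below are illustrative.

import torch
from layers.Conv_Blocks import Inception_Block_V1

block = Inception_Block_V1(in_channels=16, out_channels=32, num_kernels=6)
x = torch.randn(8, 16, 24, 4)     # [batch, channels, height, width], e.g. a series folded to 2D
y = block(x)
print(y.shape)                    # torch.Size([8, 32, 24, 4])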
layers/Crossformer_EncDec.py (new file, 131 lines)
@@ -0,0 +1,131 @@
import torch
import torch.nn as nn
from einops import rearrange, repeat
from layers.SelfAttention_Family import TwoStageAttentionLayer


class SegMerging(nn.Module):
    def __init__(self, d_model, win_size, norm_layer=nn.LayerNorm):
        super().__init__()
        self.d_model = d_model
        self.win_size = win_size
        self.linear_trans = nn.Linear(win_size * d_model, d_model)
        self.norm = norm_layer(win_size * d_model)

    def forward(self, x):
        batch_size, ts_d, seg_num, d_model = x.shape
        pad_num = seg_num % self.win_size
        if pad_num != 0:
            pad_num = self.win_size - pad_num
            x = torch.cat((x, x[:, :, -pad_num:, :]), dim=-2)

        seg_to_merge = []
        for i in range(self.win_size):
            seg_to_merge.append(x[:, :, i::self.win_size, :])
        x = torch.cat(seg_to_merge, -1)

        x = self.norm(x)
        x = self.linear_trans(x)

        return x


class scale_block(nn.Module):
    def __init__(self, configs, win_size, d_model, n_heads, d_ff, depth, dropout, \
                 seg_num=10, factor=10):
        super(scale_block, self).__init__()

        if win_size > 1:
            self.merge_layer = SegMerging(d_model, win_size, nn.LayerNorm)
        else:
            self.merge_layer = None

        self.encode_layers = nn.ModuleList()

        for i in range(depth):
            self.encode_layers.append(TwoStageAttentionLayer(configs, seg_num, factor, d_model, n_heads, \
                                                             d_ff, dropout))

    def forward(self, x, attn_mask=None, tau=None, delta=None):
        _, ts_dim, _, _ = x.shape

        if self.merge_layer is not None:
            x = self.merge_layer(x)

        for layer in self.encode_layers:
            x = layer(x)

        return x, None


class Encoder(nn.Module):
    def __init__(self, attn_layers):
        super(Encoder, self).__init__()
        self.encode_blocks = nn.ModuleList(attn_layers)

    def forward(self, x):
        encode_x = []
        encode_x.append(x)

        for block in self.encode_blocks:
            x, attns = block(x)
            encode_x.append(x)

        return encode_x, None


class DecoderLayer(nn.Module):
    def __init__(self, self_attention, cross_attention, seg_len, d_model, d_ff=None, dropout=0.1):
        super(DecoderLayer, self).__init__()
        self.self_attention = self_attention
        self.cross_attention = cross_attention
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.MLP1 = nn.Sequential(nn.Linear(d_model, d_model),
                                  nn.GELU(),
                                  nn.Linear(d_model, d_model))
        self.linear_pred = nn.Linear(d_model, seg_len)

    def forward(self, x, cross):
        batch = x.shape[0]
        x = self.self_attention(x)
        x = rearrange(x, 'b ts_d out_seg_num d_model -> (b ts_d) out_seg_num d_model')

        cross = rearrange(cross, 'b ts_d in_seg_num d_model -> (b ts_d) in_seg_num d_model')
        tmp, attn = self.cross_attention(x, cross, cross, None, None, None,)
        x = x + self.dropout(tmp)
        y = x = self.norm1(x)
        y = self.MLP1(y)
        dec_output = self.norm2(x + y)

        dec_output = rearrange(dec_output, '(b ts_d) seg_dec_num d_model -> b ts_d seg_dec_num d_model', b=batch)
        layer_predict = self.linear_pred(dec_output)
        layer_predict = rearrange(layer_predict, 'b out_d seg_num seg_len -> b (out_d seg_num) seg_len')

        return dec_output, layer_predict


class Decoder(nn.Module):
    def __init__(self, layers):
        super(Decoder, self).__init__()
        self.decode_layers = nn.ModuleList(layers)

    def forward(self, x, cross):
        final_predict = None
        i = 0

        ts_d = x.shape[1]
        for layer in self.decode_layers:
            cross_enc = cross[i]
            x, layer_predict = layer(x, cross_enc)
            if final_predict is None:
                final_predict = layer_predict
            else:
                final_predict = final_predict + layer_predict
            i += 1

        final_predict = rearrange(final_predict, 'b (out_d seg_num) seg_len -> b (seg_num seg_len) out_d', out_d=ts_d)

        return final_predict
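
SegMerging can be exercised on its own; the other blocks depend on TwoStageAttentionLayer from layers/SelfAttention_Family.py, which is not part of this excerpt. Sizes below are illustrative.

import torch
from layers.Crossformer_EncDec import SegMerging

merge = SegMerging(d_model=32, win_size=2)
x = torch.randn(4, 7, 10, 32)     # [batch, n_vars, seg_num, d_model]
y = merge(x)                      # merges win_size adjacent segments into one
print(y.shape)                    # torch.Size([4, 7, 5, 32])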
layers/DECOMP.py (new file, 22 lines)
@@ -0,0 +1,22 @@
import torch
from torch import nn
from layers.EMA import EMA
from layers.DEMA import DEMA


class DECOMP(nn.Module):
    """
    Series decomposition block using EMA/DEMA
    """
    def __init__(self, ma_type, alpha, beta):
        super(DECOMP, self).__init__()
        if ma_type == 'ema':
            self.ma = EMA(alpha)
        elif ma_type == 'dema':
            self.ma = DEMA(alpha, beta)
        else:
            raise ValueError(f"Unsupported ma_type: {ma_type}. Use 'ema' or 'dema'")

    def forward(self, x):
        moving_average = self.ma(x)
        res = x - moving_average
        return res, moving_average
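
A minimal usage sketch of the EMA-based decomposition; the alpha/beta values and tensor sizes are illustrative assumptions (beta is only used by the DEMA branch).

import torch
from layers.DECOMP import DECOMP

decomp = DECOMP(ma_type='ema', alpha=0.3, beta=0.3)
x = torch.randn(2, 96, 7)                          # [batch, length, channels]
res, trend = decomp(x)
assert torch.allclose(res + trend, x, atol=1e-5)   # residual plus moving average reconstructs x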
layers/DEMA.py (new file, 23 lines)
@@ -0,0 +1,23 @@
import torch
from torch import nn


class DEMA(nn.Module):
    """
    Double Exponential Moving Average (DEMA) block to highlight the trend of time series
    """
    def __init__(self, alpha, beta):
        super(DEMA, self).__init__()
        self.alpha = alpha.to(device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
        self.beta = beta.to(device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'))

    def forward(self, x):
        s_prev = x[:, 0, :]
        b = x[:, 1, :] - s_prev
        res = [s_prev.unsqueeze(1)]
        for t in range(1, x.shape[1]):
            xt = x[:, t, :]
            s = self.alpha * xt + (1 - self.alpha) * (s_prev + b)
            b = self.beta * (s - s_prev) + (1 - self.beta) * b
            s_prev = s
            res.append(s.unsqueeze(1))
        return torch.cat(res, dim=1)
layers/DWT_Decomposition.py (new file, 1268 lines): diff suppressed because the file is too large.
layers/EMA.py (new file, 23 lines)
@@ -0,0 +1,23 @@
import torch
from torch import nn


class EMA(nn.Module):
    """
    Exponential Moving Average (EMA) block to highlight the trend of time series
    """
    def __init__(self, alpha):
        super(EMA, self).__init__()
        self.alpha = alpha

    def forward(self, x):
        # x: [Batch, Input, Channel]
        _, t, _ = x.shape
        powers = torch.flip(torch.arange(t, dtype=torch.double), dims=(0,))
        weights = torch.pow((1 - self.alpha), powers).to(x.device)
        divisor = weights.clone()
        weights[1:] = weights[1:] * self.alpha
        weights = weights.reshape(1, t, 1)
        divisor = divisor.reshape(1, t, 1)
        x = torch.cumsum(x * weights, dim=1)
        x = torch.div(x, divisor)
        return x.to(torch.float32)
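
The vectorised form above matches the textbook recursion s_t = alpha * x_t + (1 - alpha) * s_{t-1} with s_0 = x_0; a small sanity check with an illustrative alpha:

import torch
from layers.EMA import EMA

alpha = 0.3
x = torch.randn(1, 50, 3)
fast = EMA(alpha)(x)

slow = [x[:, 0, :]]
for t in range(1, x.shape[1]):
    slow.append(alpha * x[:, t, :] + (1 - alpha) * slow[-1])
slow = torch.stack(slow, dim=1)
assert torch.allclose(fast, slow, atol=1e-4)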
layers/ETSformer_EncDec.py (new file, 334 lines)
@@ -0,0 +1,334 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.fft as fft
from einops import rearrange, reduce, repeat
import math, random
from scipy.fftpack import next_fast_len


class Transform:
    def __init__(self, sigma):
        self.sigma = sigma

    @torch.no_grad()
    def transform(self, x):
        return self.jitter(self.shift(self.scale(x)))

    def jitter(self, x):
        return x + (torch.randn(x.shape).to(x.device) * self.sigma)

    def scale(self, x):
        return x * (torch.randn(x.size(-1)).to(x.device) * self.sigma + 1)

    def shift(self, x):
        return x + (torch.randn(x.size(-1)).to(x.device) * self.sigma)


def conv1d_fft(f, g, dim=-1):
    N = f.size(dim)
    M = g.size(dim)

    fast_len = next_fast_len(N + M - 1)

    F_f = fft.rfft(f, fast_len, dim=dim)
    F_g = fft.rfft(g, fast_len, dim=dim)

    F_fg = F_f * F_g.conj()
    out = fft.irfft(F_fg, fast_len, dim=dim)
    out = out.roll((-1,), dims=(dim,))
    idx = torch.as_tensor(range(fast_len - N, fast_len)).to(out.device)
    out = out.index_select(dim, idx)

    return out


class ExponentialSmoothing(nn.Module):

    def __init__(self, dim, nhead, dropout=0.1, aux=False):
        super().__init__()
        self._smoothing_weight = nn.Parameter(torch.randn(nhead, 1))
        self.v0 = nn.Parameter(torch.randn(1, 1, nhead, dim))
        self.dropout = nn.Dropout(dropout)
        if aux:
            self.aux_dropout = nn.Dropout(dropout)

    def forward(self, values, aux_values=None):
        b, t, h, d = values.shape

        init_weight, weight = self.get_exponential_weight(t)
        output = conv1d_fft(self.dropout(values), weight, dim=1)
        output = init_weight * self.v0 + output

        if aux_values is not None:
            aux_weight = weight / (1 - self.weight) * self.weight
            aux_output = conv1d_fft(self.aux_dropout(aux_values), aux_weight)
            output = output + aux_output

        return output

    def get_exponential_weight(self, T):
        # Generate array [0, 1, ..., T-1]
        powers = torch.arange(T, dtype=torch.float, device=self.weight.device)

        # (1 - \alpha) * \alpha^t, for all t = T-1, T-2, ..., 0
        weight = (1 - self.weight) * (self.weight ** torch.flip(powers, dims=(0,)))

        # \alpha^t for all t = 1, 2, ..., T
        init_weight = self.weight ** (powers + 1)

        return rearrange(init_weight, 'h t -> 1 t h 1'), \
               rearrange(weight, 'h t -> 1 t h 1')

    @property
    def weight(self):
        return torch.sigmoid(self._smoothing_weight)


class Feedforward(nn.Module):
    def __init__(self, d_model, dim_feedforward, dropout=0.1, activation='sigmoid'):
        # Implementation of Feedforward model
        super().__init__()
        self.linear1 = nn.Linear(d_model, dim_feedforward, bias=False)
        self.dropout1 = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model, bias=False)
        self.dropout2 = nn.Dropout(dropout)
        self.activation = getattr(F, activation)

    def forward(self, x):
        x = self.linear2(self.dropout1(self.activation(self.linear1(x))))
        return self.dropout2(x)


class GrowthLayer(nn.Module):

    def __init__(self, d_model, nhead, d_head=None, dropout=0.1):
        super().__init__()
        self.d_head = d_head or (d_model // nhead)
        self.d_model = d_model
        self.nhead = nhead

        self.z0 = nn.Parameter(torch.randn(self.nhead, self.d_head))
        self.in_proj = nn.Linear(self.d_model, self.d_head * self.nhead)
        self.es = ExponentialSmoothing(self.d_head, self.nhead, dropout=dropout)
        self.out_proj = nn.Linear(self.d_head * self.nhead, self.d_model)

        assert self.d_head * self.nhead == self.d_model, "d_model must be divisible by nhead"

    def forward(self, inputs):
        """
        :param inputs: shape: (batch, seq_len, dim)
        :return: shape: (batch, seq_len, dim)
        """
        b, t, d = inputs.shape
        values = self.in_proj(inputs).view(b, t, self.nhead, -1)
        values = torch.cat([repeat(self.z0, 'h d -> b 1 h d', b=b), values], dim=1)
        values = values[:, 1:] - values[:, :-1]
        out = self.es(values)
        out = torch.cat([repeat(self.es.v0, '1 1 h d -> b 1 h d', b=b), out], dim=1)
        out = rearrange(out, 'b t h d -> b t (h d)')
        return self.out_proj(out)


class FourierLayer(nn.Module):

    def __init__(self, d_model, pred_len, k=None, low_freq=1):
        super().__init__()
        self.d_model = d_model
        self.pred_len = pred_len
        self.k = k
        self.low_freq = low_freq

    def forward(self, x):
        """x: (b, t, d)"""
        b, t, d = x.shape
        x_freq = fft.rfft(x, dim=1)

        if t % 2 == 0:
            x_freq = x_freq[:, self.low_freq:-1]
            f = fft.rfftfreq(t)[self.low_freq:-1]
        else:
            x_freq = x_freq[:, self.low_freq:]
            f = fft.rfftfreq(t)[self.low_freq:]

        x_freq, index_tuple = self.topk_freq(x_freq)
        f = repeat(f, 'f -> b f d', b=x_freq.size(0), d=x_freq.size(2))
        f = rearrange(f[index_tuple], 'b f d -> b f () d').to(x_freq.device)

        return self.extrapolate(x_freq, f, t)

    def extrapolate(self, x_freq, f, t):
        x_freq = torch.cat([x_freq, x_freq.conj()], dim=1)
        f = torch.cat([f, -f], dim=1)
        t_val = rearrange(torch.arange(t + self.pred_len, dtype=torch.float),
                          't -> () () t ()').to(x_freq.device)

        amp = rearrange(x_freq.abs() / t, 'b f d -> b f () d')
        phase = rearrange(x_freq.angle(), 'b f d -> b f () d')

        x_time = amp * torch.cos(2 * math.pi * f * t_val + phase)

        return reduce(x_time, 'b f t d -> b t d', 'sum')

    def topk_freq(self, x_freq):
        values, indices = torch.topk(x_freq.abs(), self.k, dim=1, largest=True, sorted=True)
        mesh_a, mesh_b = torch.meshgrid(torch.arange(x_freq.size(0)), torch.arange(x_freq.size(2)))
        index_tuple = (mesh_a.unsqueeze(1).to(indices.device), indices, mesh_b.unsqueeze(1).to(indices.device))
        x_freq = x_freq[index_tuple]

        return x_freq, index_tuple


class LevelLayer(nn.Module):

    def __init__(self, d_model, c_out, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.c_out = c_out

        self.es = ExponentialSmoothing(1, self.c_out, dropout=dropout, aux=True)
        self.growth_pred = nn.Linear(self.d_model, self.c_out)
        self.season_pred = nn.Linear(self.d_model, self.c_out)

    def forward(self, level, growth, season):
        b, t, _ = level.shape
        growth = self.growth_pred(growth).view(b, t, self.c_out, 1)
        season = self.season_pred(season).view(b, t, self.c_out, 1)
        growth = growth.view(b, t, self.c_out, 1)
        season = season.view(b, t, self.c_out, 1)
        level = level.view(b, t, self.c_out, 1)
        out = self.es(level - season, aux_values=growth)
        out = rearrange(out, 'b t h d -> b t (h d)')
        return out


class EncoderLayer(nn.Module):

    def __init__(self, d_model, nhead, c_out, seq_len, pred_len, k, dim_feedforward=None, dropout=0.1,
                 activation='sigmoid', layer_norm_eps=1e-5):
        super().__init__()
        self.d_model = d_model
        self.nhead = nhead
        self.c_out = c_out
        self.seq_len = seq_len
        self.pred_len = pred_len
        dim_feedforward = dim_feedforward or 4 * d_model
        self.dim_feedforward = dim_feedforward

        self.growth_layer = GrowthLayer(d_model, nhead, dropout=dropout)
        self.seasonal_layer = FourierLayer(d_model, pred_len, k=k)
        self.level_layer = LevelLayer(d_model, c_out, dropout=dropout)

        # Implementation of Feedforward model
        self.ff = Feedforward(d_model, dim_feedforward, dropout=dropout, activation=activation)
        self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)

        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, res, level, attn_mask=None):
        season = self._season_block(res)
        res = res - season[:, :-self.pred_len]
        growth = self._growth_block(res)
        res = self.norm1(res - growth[:, 1:])
        res = self.norm2(res + self.ff(res))

        level = self.level_layer(level, growth[:, :-1], season[:, :-self.pred_len])
        return res, level, growth, season

    def _growth_block(self, x):
        x = self.growth_layer(x)
        return self.dropout1(x)

    def _season_block(self, x):
        x = self.seasonal_layer(x)
        return self.dropout2(x)


class Encoder(nn.Module):

    def __init__(self, layers):
        super().__init__()
        self.layers = nn.ModuleList(layers)

    def forward(self, res, level, attn_mask=None):
        growths = []
        seasons = []
        for layer in self.layers:
            res, level, growth, season = layer(res, level, attn_mask=None)
            growths.append(growth)
            seasons.append(season)

        return level, growths, seasons


class DampingLayer(nn.Module):

    def __init__(self, pred_len, nhead, dropout=0.1):
        super().__init__()
        self.pred_len = pred_len
        self.nhead = nhead
        self._damping_factor = nn.Parameter(torch.randn(1, nhead))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = repeat(x, 'b 1 d -> b t d', t=self.pred_len)
        b, t, d = x.shape

        powers = torch.arange(self.pred_len).to(self._damping_factor.device) + 1
        powers = powers.view(self.pred_len, 1)
        damping_factors = self.damping_factor ** powers
        damping_factors = damping_factors.cumsum(dim=0)
        x = x.view(b, t, self.nhead, -1)
        x = self.dropout(x) * damping_factors.unsqueeze(-1)
        return x.view(b, t, d)

    @property
    def damping_factor(self):
        return torch.sigmoid(self._damping_factor)


class DecoderLayer(nn.Module):

    def __init__(self, d_model, nhead, c_out, pred_len, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.nhead = nhead
        self.c_out = c_out
        self.pred_len = pred_len

        self.growth_damping = DampingLayer(pred_len, nhead, dropout=dropout)
        self.dropout1 = nn.Dropout(dropout)

    def forward(self, growth, season):
        growth_horizon = self.growth_damping(growth[:, -1:])
        growth_horizon = self.dropout1(growth_horizon)

        seasonal_horizon = season[:, -self.pred_len:]
        return growth_horizon, seasonal_horizon


class Decoder(nn.Module):

    def __init__(self, layers):
        super().__init__()
        self.d_model = layers[0].d_model
        self.c_out = layers[0].c_out
        self.pred_len = layers[0].pred_len
        self.nhead = layers[0].nhead

        self.layers = nn.ModuleList(layers)
        self.pred = nn.Linear(self.d_model, self.c_out)

    def forward(self, growths, seasons):
        growth_repr = []
        season_repr = []

        for idx, layer in enumerate(self.layers):
            growth_horizon, season_horizon = layer(growths[idx], seasons[idx])
            growth_repr.append(growth_horizon)
            season_repr.append(season_horizon)
        growth_repr = sum(growth_repr)
        season_repr = sum(season_repr)
        return self.pred(growth_repr), self.pred(season_repr)
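
FourierLayer above extrapolates a series by keeping its k dominant frequencies and re-synthesising history plus horizon; a quick shape check with illustrative sizes (d_model is stored but unused in forward):

import torch
from layers.ETSformer_EncDec import FourierLayer

layer = FourierLayer(d_model=7, pred_len=24, k=3)
x = torch.randn(2, 96, 7)         # [batch, length, channels]
y = layer(x)
print(y.shape)                    # torch.Size([2, 120, 7]) = history + pred_len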
layers/Embed.py (new file, 190 lines)
@@ -0,0 +1,190 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import weight_norm
import math


class PositionalEmbedding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super(PositionalEmbedding, self).__init__()
        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, d_model).float()
        pe.require_grad = False

        position = torch.arange(0, max_len).float().unsqueeze(1)
        div_term = (torch.arange(0, d_model, 2).float()
                    * -(math.log(10000.0) / d_model)).exp()

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        return self.pe[:, :x.size(1)]


class TokenEmbedding(nn.Module):
    def __init__(self, c_in, d_model):
        super(TokenEmbedding, self).__init__()
        padding = 1 if torch.__version__ >= '1.5.0' else 2
        self.tokenConv = nn.Conv1d(in_channels=c_in, out_channels=d_model,
                                   kernel_size=3, padding=padding, padding_mode='circular', bias=False)
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                nn.init.kaiming_normal_(
                    m.weight, mode='fan_in', nonlinearity='leaky_relu')

    def forward(self, x):
        x = self.tokenConv(x.permute(0, 2, 1)).transpose(1, 2)
        return x


class FixedEmbedding(nn.Module):
    def __init__(self, c_in, d_model):
        super(FixedEmbedding, self).__init__()

        w = torch.zeros(c_in, d_model).float()
        w.require_grad = False

        position = torch.arange(0, c_in).float().unsqueeze(1)
        div_term = (torch.arange(0, d_model, 2).float()
                    * -(math.log(10000.0) / d_model)).exp()

        w[:, 0::2] = torch.sin(position * div_term)
        w[:, 1::2] = torch.cos(position * div_term)

        self.emb = nn.Embedding(c_in, d_model)
        self.emb.weight = nn.Parameter(w, requires_grad=False)

    def forward(self, x):
        return self.emb(x).detach()


class TemporalEmbedding(nn.Module):
    def __init__(self, d_model, embed_type='fixed', freq='h'):
        super(TemporalEmbedding, self).__init__()

        minute_size = 4
        hour_size = 24
        weekday_size = 7
        day_size = 32
        month_size = 13

        Embed = FixedEmbedding if embed_type == 'fixed' else nn.Embedding
        if freq == 't':
            self.minute_embed = Embed(minute_size, d_model)
        self.hour_embed = Embed(hour_size, d_model)
        self.weekday_embed = Embed(weekday_size, d_model)
        self.day_embed = Embed(day_size, d_model)
        self.month_embed = Embed(month_size, d_model)

    def forward(self, x):
        x = x.long()
        minute_x = self.minute_embed(x[:, :, 4]) if hasattr(
            self, 'minute_embed') else 0.
        hour_x = self.hour_embed(x[:, :, 3])
        weekday_x = self.weekday_embed(x[:, :, 2])
        day_x = self.day_embed(x[:, :, 1])
        month_x = self.month_embed(x[:, :, 0])

        return hour_x + weekday_x + day_x + month_x + minute_x


class TimeFeatureEmbedding(nn.Module):
    def __init__(self, d_model, embed_type='timeF', freq='h'):
        super(TimeFeatureEmbedding, self).__init__()

        freq_map = {'h': 4, 't': 5, 's': 6,
                    'm': 1, 'a': 1, 'w': 2, 'd': 3, 'b': 3}
        d_inp = freq_map[freq]
        self.embed = nn.Linear(d_inp, d_model, bias=False)

    def forward(self, x):
        return self.embed(x)


class DataEmbedding(nn.Module):
    def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1):
        super(DataEmbedding, self).__init__()

        self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model)
        self.position_embedding = PositionalEmbedding(d_model=d_model)
        self.temporal_embedding = TemporalEmbedding(d_model=d_model, embed_type=embed_type,
                                                    freq=freq) if embed_type != 'timeF' else TimeFeatureEmbedding(
            d_model=d_model, embed_type=embed_type, freq=freq)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, x_mark):
        if x_mark is None:
            x = self.value_embedding(x) + self.position_embedding(x)
        else:
            x = self.value_embedding(
                x) + self.temporal_embedding(x_mark) + self.position_embedding(x)
        return self.dropout(x)


class DataEmbedding_inverted(nn.Module):
    def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1):
        super(DataEmbedding_inverted, self).__init__()
        self.value_embedding = nn.Linear(c_in, d_model)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, x_mark):
        x = x.permute(0, 2, 1)
        # x: [Batch Variate Time]
        if x_mark is None:
            x = self.value_embedding(x)
        else:
            x = self.value_embedding(torch.cat([x, x_mark.permute(0, 2, 1)], 1))
        # x: [Batch Variate d_model]
        return self.dropout(x)


class DataEmbedding_wo_pos(nn.Module):
    def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1):
        super(DataEmbedding_wo_pos, self).__init__()

        self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model)
        self.position_embedding = PositionalEmbedding(d_model=d_model)
        self.temporal_embedding = TemporalEmbedding(d_model=d_model, embed_type=embed_type,
                                                    freq=freq) if embed_type != 'timeF' else TimeFeatureEmbedding(
            d_model=d_model, embed_type=embed_type, freq=freq)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, x_mark):
        if x_mark is None:
            x = self.value_embedding(x)
        else:
            x = self.value_embedding(x) + self.temporal_embedding(x_mark)
        return self.dropout(x)


class PatchEmbedding(nn.Module):
    def __init__(self, d_model, patch_len, stride, padding, dropout):
        super(PatchEmbedding, self).__init__()
        # Patching
        self.patch_len = patch_len
        self.stride = stride
        self.padding_patch_layer = nn.ReplicationPad1d((0, padding))

        # Backbone, Input encoding: projection of feature vectors onto a d-dim vector space
        self.value_embedding = nn.Linear(patch_len, d_model, bias=False)

        # Positional embedding
        self.position_embedding = PositionalEmbedding(d_model)

        # Residual dropout
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # do patching
        n_vars = x.shape[1]
        x = self.padding_patch_layer(x)
        x = x.unfold(dimension=-1, size=self.patch_len, step=self.stride)
        x = torch.reshape(x, (x.shape[0] * x.shape[1], x.shape[2], x.shape[3]))
        # Input encoding
        x = self.value_embedding(x) + self.position_embedding(x)
        return self.dropout(x), n_vars
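
A short walk-through of PatchEmbedding's shapes; the patching parameters below are illustrative assumptions.

import torch
from layers.Embed import PatchEmbedding

patch_len, stride = 16, 8
embed = PatchEmbedding(d_model=128, patch_len=patch_len, stride=stride, padding=stride, dropout=0.1)

x = torch.randn(32, 7, 96)        # [batch, n_vars, seq_len]
out, n_vars = embed(x)
# padded length 96 + 8 = 104, so (104 - 16) // 8 + 1 = 12 patches per variable,
# and the batch and variable dimensions are folded together
print(out.shape, n_vars)          # torch.Size([224, 12, 128]) 7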
layers/FourierCorrelation.py (new file, 162 lines)
@@ -0,0 +1,162 @@
# coding=utf-8
# author=maziqing
# email=maziqing.mzq@alibaba-inc.com

import numpy as np
import torch
import torch.nn as nn


def get_frequency_modes(seq_len, modes=64, mode_select_method='random'):
    """
    get modes on frequency domain:
    'random' means sampling randomly;
    'else' means sampling the lowest modes;
    """
    modes = min(modes, seq_len // 2)
    if mode_select_method == 'random':
        index = list(range(0, seq_len // 2))
        np.random.shuffle(index)
        index = index[:modes]
    else:
        index = list(range(0, modes))
    index.sort()
    return index


# ########## fourier layer #############
class FourierBlock(nn.Module):
    def __init__(self, in_channels, out_channels, n_heads, seq_len, modes=0, mode_select_method='random'):
        super(FourierBlock, self).__init__()
        print('fourier enhanced block used!')
        """
        1D Fourier block. It performs representation learning on frequency domain,
        it does FFT, linear transform, and Inverse FFT.
        """
        # get modes on frequency domain
        self.index = get_frequency_modes(seq_len, modes=modes, mode_select_method=mode_select_method)
        print('modes={}, index={}'.format(modes, self.index))

        self.n_heads = n_heads
        self.scale = (1 / (in_channels * out_channels))
        self.weights1 = nn.Parameter(
            self.scale * torch.rand(self.n_heads, in_channels // self.n_heads, out_channels // self.n_heads,
                                    len(self.index), dtype=torch.float))
        self.weights2 = nn.Parameter(
            self.scale * torch.rand(self.n_heads, in_channels // self.n_heads, out_channels // self.n_heads,
                                    len(self.index), dtype=torch.float))

    # Complex multiplication
    def compl_mul1d(self, order, x, weights):
        x_flag = True
        w_flag = True
        if not torch.is_complex(x):
            x_flag = False
            x = torch.complex(x, torch.zeros_like(x).to(x.device))
        if not torch.is_complex(weights):
            w_flag = False
            weights = torch.complex(weights, torch.zeros_like(weights).to(weights.device))
        if x_flag or w_flag:
            return torch.complex(torch.einsum(order, x.real, weights.real) - torch.einsum(order, x.imag, weights.imag),
                                 torch.einsum(order, x.real, weights.imag) + torch.einsum(order, x.imag, weights.real))
        else:
            return torch.einsum(order, x.real, weights.real)

    def forward(self, q, k, v, mask):
        # size = [B, L, H, E]
        B, L, H, E = q.shape
        x = q.permute(0, 2, 3, 1)
        # Compute Fourier coefficients
        x_ft = torch.fft.rfft(x, dim=-1)
        # Perform Fourier neural operations
        out_ft = torch.zeros(B, H, E, L // 2 + 1, device=x.device, dtype=torch.cfloat)
        for wi, i in enumerate(self.index):
            if i >= x_ft.shape[3] or wi >= out_ft.shape[3]:
                continue
            out_ft[:, :, :, wi] = self.compl_mul1d("bhi,hio->bho", x_ft[:, :, :, i],
                                                   torch.complex(self.weights1, self.weights2)[:, :, :, wi])
        # Return to time domain
        x = torch.fft.irfft(out_ft, n=x.size(-1))
        return (x, None)


# ########## Fourier Cross Former ####################
class FourierCrossAttention(nn.Module):
    def __init__(self, in_channels, out_channels, seq_len_q, seq_len_kv, modes=64, mode_select_method='random',
                 activation='tanh', policy=0, num_heads=8):
        super(FourierCrossAttention, self).__init__()
        print(' fourier enhanced cross attention used!')
        """
        1D Fourier Cross Attention layer. It does FFT, linear transform, attention mechanism and Inverse FFT.
        """
        self.activation = activation
        self.in_channels = in_channels
        self.out_channels = out_channels
        # get modes for queries and keys (& values) on frequency domain
        self.index_q = get_frequency_modes(seq_len_q, modes=modes, mode_select_method=mode_select_method)
        self.index_kv = get_frequency_modes(seq_len_kv, modes=modes, mode_select_method=mode_select_method)

        print('modes_q={}, index_q={}'.format(len(self.index_q), self.index_q))
        print('modes_kv={}, index_kv={}'.format(len(self.index_kv), self.index_kv))

        self.scale = (1 / (in_channels * out_channels))
        self.weights1 = nn.Parameter(
            self.scale * torch.rand(num_heads, in_channels // num_heads, out_channels // num_heads, len(self.index_q), dtype=torch.float))
        self.weights2 = nn.Parameter(
            self.scale * torch.rand(num_heads, in_channels // num_heads, out_channels // num_heads, len(self.index_q), dtype=torch.float))

    # Complex multiplication
    def compl_mul1d(self, order, x, weights):
        x_flag = True
        w_flag = True
        if not torch.is_complex(x):
            x_flag = False
            x = torch.complex(x, torch.zeros_like(x).to(x.device))
        if not torch.is_complex(weights):
            w_flag = False
            weights = torch.complex(weights, torch.zeros_like(weights).to(weights.device))
        if x_flag or w_flag:
            return torch.complex(torch.einsum(order, x.real, weights.real) - torch.einsum(order, x.imag, weights.imag),
                                 torch.einsum(order, x.real, weights.imag) + torch.einsum(order, x.imag, weights.real))
        else:
            return torch.einsum(order, x.real, weights.real)

    def forward(self, q, k, v, mask):
        # size = [B, L, H, E]
        B, L, H, E = q.shape
        xq = q.permute(0, 2, 3, 1)  # size = [B, H, E, L]
        xk = k.permute(0, 2, 3, 1)
        xv = v.permute(0, 2, 3, 1)

        # Compute Fourier coefficients
        xq_ft_ = torch.zeros(B, H, E, len(self.index_q), device=xq.device, dtype=torch.cfloat)
        xq_ft = torch.fft.rfft(xq, dim=-1)
        for i, j in enumerate(self.index_q):
            if j >= xq_ft.shape[3]:
                continue
            xq_ft_[:, :, :, i] = xq_ft[:, :, :, j]
        xk_ft_ = torch.zeros(B, H, E, len(self.index_kv), device=xq.device, dtype=torch.cfloat)
        xk_ft = torch.fft.rfft(xk, dim=-1)
        for i, j in enumerate(self.index_kv):
            if j >= xk_ft.shape[3]:
                continue
            xk_ft_[:, :, :, i] = xk_ft[:, :, :, j]

        # perform attention mechanism on frequency domain
        xqk_ft = (self.compl_mul1d("bhex,bhey->bhxy", xq_ft_, xk_ft_))
        if self.activation == 'tanh':
            xqk_ft = torch.complex(xqk_ft.real.tanh(), xqk_ft.imag.tanh())
        elif self.activation == 'softmax':
            xqk_ft = torch.softmax(abs(xqk_ft), dim=-1)
            xqk_ft = torch.complex(xqk_ft, torch.zeros_like(xqk_ft))
        else:
            raise Exception('{} activation function is not implemented'.format(self.activation))
        xqkv_ft = self.compl_mul1d("bhxy,bhey->bhex", xqk_ft, xk_ft_)
        xqkvw = self.compl_mul1d("bhex,heox->bhox", xqkv_ft, torch.complex(self.weights1, self.weights2))
        out_ft = torch.zeros(B, H, E, L // 2 + 1, device=xq.device, dtype=torch.cfloat)
        for i, j in enumerate(self.index_q):
            if i >= xqkvw.shape[3] or j >= out_ft.shape[3]:
                continue
            out_ft[:, :, :, j] = xqkvw[:, :, :, i]
        # Return to time domain
        out = torch.fft.irfft(out_ft / self.in_channels / self.out_channels, n=xq.size(-1))
        return (out, None)
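
get_frequency_modes only decides which rFFT bins a block keeps; a short illustration (the seed is only there to make the random branch reproducible):

import numpy as np
from layers.FourierCorrelation import get_frequency_modes

np.random.seed(0)
print(get_frequency_modes(seq_len=96, modes=8, mode_select_method='random'))
# eight distinct bins drawn from range(48), returned in ascending order
print(get_frequency_modes(seq_len=96, modes=8, mode_select_method='lowest'))
# [0, 1, 2, 3, 4, 5, 6, 7], i.e. the eight lowest modes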
layers/GraphMixer.py (new file, 83 lines)
@@ -0,0 +1,83 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class HierarchicalGraphMixer(nn.Module):
    """
    Hierarchical graph mixer that combines macro channel-level relations with micro patch-level attention.
    Input  z     : tensor of shape [B, C, N, D]
    Output z_out : same shape as the input
    """
    def __init__(self, n_channel: int, dim: int, k: int = 5, tau: float = 0.2):
        super().__init__()
        self.k = k
        self.tau = tau

        # Level 1: Channel Graph
        self.A = nn.Parameter(torch.zeros(n_channel, n_channel))
        self.se = nn.Sequential(
            nn.Linear(dim, dim // 4, bias=False), nn.ReLU(),
            nn.Linear(dim // 4, 1, bias=False), nn.Sigmoid()
        )

        # Level 2: Patch Cross-Attention
        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)
        self.out_proj = nn.Linear(dim, dim)
        self.norm = nn.LayerNorm(dim)

    def _row_sparse(self, logits: torch.Tensor) -> torch.Tensor:
        """Gumbel-Softmax based sparse attention"""
        g = -torch.empty_like(logits).exponential_().log()
        y = (logits + g) / self.tau
        probs = F.softmax(y, dim=-1)

        # Ensure k doesn't exceed the dimension size
        k_actual = min(self.k, probs.size(-1))
        if k_actual <= 0:
            return torch.zeros_like(probs)

        topk_val, _ = torch.topk(probs, k_actual, dim=-1)
        thr = topk_val[..., -1].unsqueeze(-1)
        sparse = torch.where(probs >= thr, probs, torch.zeros_like(probs))
        return sparse.detach() + probs - probs.detach()

    def forward(self, z):
        # z has shape [B, C, N, D]
        B, C, N, D = z.shape

        # --- Level 1: compute macro channel weights ---
        A_sparse = self._row_sparse(self.A)  # sparse channel-connection graph, A_sparse: [C, C]

        # --- Level 2: cross-channel patch interaction ---
        out_z = torch.zeros_like(z)
        for i in range(C):  # loop over every target channel i
            target_z = z[:, i, :, :]  # [B, N, D]

            # accumulate patch-level context coming from the other channels
            aggregated_context = torch.zeros_like(target_z)

            for j in range(C):  # loop over every source channel j
                if A_sparse[i, j] != 0:
                    source_z = z[:, j, :, :]  # [B, N, D]

                    # --- cross-attention ---
                    Q = self.q_proj(target_z)  # queries from target channel i
                    K = self.k_proj(source_z)  # keys from source channel j
                    V = self.v_proj(source_z)  # values from source channel j

                    attn_scores = torch.bmm(Q, K.transpose(1, 2)) / math.sqrt(D)
                    attn_probs = F.softmax(attn_scores, dim=-1)  # [B, N, N]

                    context = torch.bmm(attn_probs, V)  # [B, N, D], context aggregated from j into i

                    # weight the context by the channel-graph edge
                    weighted_context = A_sparse[i, j] * context
                    aggregated_context = aggregated_context + weighted_context

            # project the aggregated context and add it to the target representation (residual connection)
            out_z[:, i, :, :] = self.norm(target_z + self.out_proj(aggregated_context))

        return out_z
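
A minimal sketch of running the mixer on patch-level features; the batch size, channel count, patch count, and width are illustrative assumptions.

import torch
from layers.GraphMixer import HierarchicalGraphMixer

mixer = HierarchicalGraphMixer(n_channel=7, dim=64, k=3, tau=0.2)
z = torch.randn(8, 7, 12, 64)     # [batch, channels, patches, dim]
z_out = mixer(z)                  # same shape, with cross-channel context mixed in
print(z_out.shape)                # torch.Size([8, 7, 12, 64])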
587
layers/MultiWaveletCorrelation.py
Normal file
587
layers/MultiWaveletCorrelation.py
Normal file
@ -0,0 +1,587 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
from typing import List, Tuple
|
||||
import math
|
||||
from functools import partial
|
||||
from torch import nn, einsum, diagonal
|
||||
from math import log2, ceil
|
||||
import pdb
|
||||
from sympy import Poly, legendre, Symbol, chebyshevt
|
||||
from scipy.special import eval_legendre
|
||||
|
||||
|
||||
def legendreDer(k, x):
|
||||
def _legendre(k, x):
|
||||
return (2 * k + 1) * eval_legendre(k, x)
|
||||
|
||||
out = 0
|
||||
for i in np.arange(k - 1, -1, -2):
|
||||
out += _legendre(i, x)
|
||||
return out
|
||||
|
||||
|
||||
def phi_(phi_c, x, lb=0, ub=1):
|
||||
mask = np.logical_or(x < lb, x > ub) * 1.0
|
||||
return np.polynomial.polynomial.Polynomial(phi_c)(x) * (1 - mask)
|
||||
|
||||
|
||||
def get_phi_psi(k, base):
|
||||
x = Symbol('x')
|
||||
phi_coeff = np.zeros((k, k))
|
||||
phi_2x_coeff = np.zeros((k, k))
|
||||
if base == 'legendre':
|
||||
for ki in range(k):
|
||||
coeff_ = Poly(legendre(ki, 2 * x - 1), x).all_coeffs()
|
||||
phi_coeff[ki, :ki + 1] = np.flip(np.sqrt(2 * ki + 1) * np.array(coeff_).astype(np.float64))
|
||||
coeff_ = Poly(legendre(ki, 4 * x - 1), x).all_coeffs()
|
||||
phi_2x_coeff[ki, :ki + 1] = np.flip(np.sqrt(2) * np.sqrt(2 * ki + 1) * np.array(coeff_).astype(np.float64))
|
||||
|
||||
psi1_coeff = np.zeros((k, k))
|
||||
psi2_coeff = np.zeros((k, k))
|
||||
for ki in range(k):
|
||||
psi1_coeff[ki, :] = phi_2x_coeff[ki, :]
|
||||
for i in range(k):
|
||||
a = phi_2x_coeff[ki, :ki + 1]
|
||||
b = phi_coeff[i, :i + 1]
|
||||
prod_ = np.convolve(a, b)
|
||||
prod_[np.abs(prod_) < 1e-8] = 0
|
||||
proj_ = (prod_ * 1 / (np.arange(len(prod_)) + 1) * np.power(0.5, 1 + np.arange(len(prod_)))).sum()
|
||||
psi1_coeff[ki, :] -= proj_ * phi_coeff[i, :]
|
||||
psi2_coeff[ki, :] -= proj_ * phi_coeff[i, :]
|
||||
for j in range(ki):
|
||||
a = phi_2x_coeff[ki, :ki + 1]
|
||||
b = psi1_coeff[j, :]
|
||||
prod_ = np.convolve(a, b)
|
||||
prod_[np.abs(prod_) < 1e-8] = 0
|
||||
proj_ = (prod_ * 1 / (np.arange(len(prod_)) + 1) * np.power(0.5, 1 + np.arange(len(prod_)))).sum()
|
||||
psi1_coeff[ki, :] -= proj_ * psi1_coeff[j, :]
|
||||
psi2_coeff[ki, :] -= proj_ * psi2_coeff[j, :]
|
||||
|
||||
a = psi1_coeff[ki, :]
|
||||
prod_ = np.convolve(a, a)
|
||||
prod_[np.abs(prod_) < 1e-8] = 0
|
||||
norm1 = (prod_ * 1 / (np.arange(len(prod_)) + 1) * np.power(0.5, 1 + np.arange(len(prod_)))).sum()
|
||||
|
||||
a = psi2_coeff[ki, :]
|
||||
prod_ = np.convolve(a, a)
|
||||
prod_[np.abs(prod_) < 1e-8] = 0
|
||||
norm2 = (prod_ * 1 / (np.arange(len(prod_)) + 1) * (1 - np.power(0.5, 1 + np.arange(len(prod_))))).sum()
|
||||
norm_ = np.sqrt(norm1 + norm2)
|
||||
psi1_coeff[ki, :] /= norm_
|
||||
psi2_coeff[ki, :] /= norm_
|
||||
psi1_coeff[np.abs(psi1_coeff) < 1e-8] = 0
|
||||
psi2_coeff[np.abs(psi2_coeff) < 1e-8] = 0
|
||||
|
||||
phi = [np.poly1d(np.flip(phi_coeff[i, :])) for i in range(k)]
|
||||
psi1 = [np.poly1d(np.flip(psi1_coeff[i, :])) for i in range(k)]
|
||||
psi2 = [np.poly1d(np.flip(psi2_coeff[i, :])) for i in range(k)]
|
||||
|
||||
elif base == 'chebyshev':
|
||||
for ki in range(k):
|
||||
if ki == 0:
|
||||
phi_coeff[ki, :ki + 1] = np.sqrt(2 / np.pi)
|
||||
phi_2x_coeff[ki, :ki + 1] = np.sqrt(2 / np.pi) * np.sqrt(2)
|
||||
else:
|
||||
coeff_ = Poly(chebyshevt(ki, 2 * x - 1), x).all_coeffs()
|
||||
phi_coeff[ki, :ki + 1] = np.flip(2 / np.sqrt(np.pi) * np.array(coeff_).astype(np.float64))
|
||||
coeff_ = Poly(chebyshevt(ki, 4 * x - 1), x).all_coeffs()
|
||||
phi_2x_coeff[ki, :ki + 1] = np.flip(
|
||||
np.sqrt(2) * 2 / np.sqrt(np.pi) * np.array(coeff_).astype(np.float64))
|
||||
|
||||
phi = [partial(phi_, phi_coeff[i, :]) for i in range(k)]
|
||||
|
||||
x = Symbol('x')
|
||||
kUse = 2 * k
|
||||
roots = Poly(chebyshevt(kUse, 2 * x - 1)).all_roots()
|
||||
x_m = np.array([rt.evalf(20) for rt in roots]).astype(np.float64)
|
||||
# x_m[x_m==0.5] = 0.5 + 1e-8 # add small noise to avoid the case of 0.5 belonging to both phi(2x) and phi(2x-1)
|
||||
# not needed for our purpose here: kUse = 2 * k is always even, so 0.5 is never a quadrature node
|
||||
wm = np.pi / kUse / 2
|
||||
|
||||
psi1_coeff = np.zeros((k, k))
|
||||
psi2_coeff = np.zeros((k, k))
|
||||
|
||||
psi1 = [[] for _ in range(k)]
|
||||
psi2 = [[] for _ in range(k)]
|
||||
|
||||
for ki in range(k):
|
||||
psi1_coeff[ki, :] = phi_2x_coeff[ki, :]
|
||||
for i in range(k):
|
||||
proj_ = (wm * phi[i](x_m) * np.sqrt(2) * phi[ki](2 * x_m)).sum()
|
||||
psi1_coeff[ki, :] -= proj_ * phi_coeff[i, :]
|
||||
psi2_coeff[ki, :] -= proj_ * phi_coeff[i, :]
|
||||
|
||||
for j in range(ki):
|
||||
proj_ = (wm * psi1[j](x_m) * np.sqrt(2) * phi[ki](2 * x_m)).sum()
|
||||
psi1_coeff[ki, :] -= proj_ * psi1_coeff[j, :]
|
||||
psi2_coeff[ki, :] -= proj_ * psi2_coeff[j, :]
|
||||
|
||||
psi1[ki] = partial(phi_, psi1_coeff[ki, :], lb=0, ub=0.5)
|
||||
psi2[ki] = partial(phi_, psi2_coeff[ki, :], lb=0.5, ub=1)
|
||||
|
||||
norm1 = (wm * psi1[ki](x_m) * psi1[ki](x_m)).sum()
|
||||
norm2 = (wm * psi2[ki](x_m) * psi2[ki](x_m)).sum()
|
||||
|
||||
norm_ = np.sqrt(norm1 + norm2)
|
||||
psi1_coeff[ki, :] /= norm_
|
||||
psi2_coeff[ki, :] /= norm_
|
||||
psi1_coeff[np.abs(psi1_coeff) < 1e-8] = 0
|
||||
psi2_coeff[np.abs(psi2_coeff) < 1e-8] = 0
|
||||
|
||||
psi1[ki] = partial(phi_, psi1_coeff[ki, :], lb=0, ub=0.5 + 1e-16)
|
||||
psi2[ki] = partial(phi_, psi2_coeff[ki, :], lb=0.5 + 1e-16, ub=1)
|
||||
|
||||
return phi, psi1, psi2
|
||||
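# A minimal sanity-check sketch (illustrative, not part of the committed file): for the
# Legendre base, the scaling functions returned above should be close to orthonormal on
# [0, 1] under simple trapezoidal quadrature.
def _demo_phi_orthonormality(k=4):
    phi, psi1, psi2 = get_phi_psi(k, base='legendre')
    xs = np.linspace(0.0, 1.0, 10001)
    gram = np.array([[np.trapz(phi[i](xs) * phi[j](xs), xs) for j in range(k)]
                     for i in range(k)])
    print(np.round(gram, 3))  # expected: approximately the k x k identity matrix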
|
||||
|
||||
def get_filter(base, k):
|
||||
def psi(psi1, psi2, i, inp):
|
||||
mask = (inp <= 0.5) * 1.0
|
||||
return psi1[i](inp) * mask + psi2[i](inp) * (1 - mask)
|
||||
|
||||
if base not in ['legendre', 'chebyshev']:
|
||||
raise Exception('Base not supported')
|
||||
|
||||
x = Symbol('x')
|
||||
H0 = np.zeros((k, k))
|
||||
H1 = np.zeros((k, k))
|
||||
G0 = np.zeros((k, k))
|
||||
G1 = np.zeros((k, k))
|
||||
PHI0 = np.zeros((k, k))
|
||||
PHI1 = np.zeros((k, k))
|
||||
phi, psi1, psi2 = get_phi_psi(k, base)
|
||||
if base == 'legendre':
|
||||
roots = Poly(legendre(k, 2 * x - 1)).all_roots()
|
||||
x_m = np.array([rt.evalf(20) for rt in roots]).astype(np.float64)
|
||||
wm = 1 / k / legendreDer(k, 2 * x_m - 1) / eval_legendre(k - 1, 2 * x_m - 1)
|
||||
|
||||
for ki in range(k):
|
||||
for kpi in range(k):
|
||||
H0[ki, kpi] = 1 / np.sqrt(2) * (wm * phi[ki](x_m / 2) * phi[kpi](x_m)).sum()
|
||||
G0[ki, kpi] = 1 / np.sqrt(2) * (wm * psi(psi1, psi2, ki, x_m / 2) * phi[kpi](x_m)).sum()
|
||||
H1[ki, kpi] = 1 / np.sqrt(2) * (wm * phi[ki]((x_m + 1) / 2) * phi[kpi](x_m)).sum()
|
||||
G1[ki, kpi] = 1 / np.sqrt(2) * (wm * psi(psi1, psi2, ki, (x_m + 1) / 2) * phi[kpi](x_m)).sum()
|
||||
|
||||
PHI0 = np.eye(k)
|
||||
PHI1 = np.eye(k)
|
||||
|
||||
elif base == 'chebyshev':
|
||||
x = Symbol('x')
|
||||
kUse = 2 * k
|
||||
roots = Poly(chebyshevt(kUse, 2 * x - 1)).all_roots()
|
||||
x_m = np.array([rt.evalf(20) for rt in roots]).astype(np.float64)
|
||||
# x_m[x_m==0.5] = 0.5 + 1e-8 # add small noise to avoid the case of 0.5 belonging to both phi(2x) and phi(2x-1)
|
||||
# not needed for our purpose here: kUse = 2 * k is always even, so 0.5 is never a quadrature node
|
||||
wm = np.pi / kUse / 2
|
||||
|
||||
for ki in range(k):
|
||||
for kpi in range(k):
|
||||
H0[ki, kpi] = 1 / np.sqrt(2) * (wm * phi[ki](x_m / 2) * phi[kpi](x_m)).sum()
|
||||
G0[ki, kpi] = 1 / np.sqrt(2) * (wm * psi(psi1, psi2, ki, x_m / 2) * phi[kpi](x_m)).sum()
|
||||
H1[ki, kpi] = 1 / np.sqrt(2) * (wm * phi[ki]((x_m + 1) / 2) * phi[kpi](x_m)).sum()
|
||||
G1[ki, kpi] = 1 / np.sqrt(2) * (wm * psi(psi1, psi2, ki, (x_m + 1) / 2) * phi[kpi](x_m)).sum()
|
||||
|
||||
PHI0[ki, kpi] = (wm * phi[ki](2 * x_m) * phi[kpi](2 * x_m)).sum() * 2
|
||||
PHI1[ki, kpi] = (wm * phi[ki](2 * x_m - 1) * phi[kpi](2 * x_m - 1)).sum() * 2
|
||||
|
||||
PHI0[np.abs(PHI0) < 1e-8] = 0
|
||||
PHI1[np.abs(PHI1) < 1e-8] = 0
|
||||
|
||||
H0[np.abs(H0) < 1e-8] = 0
|
||||
H1[np.abs(H1) < 1e-8] = 0
|
||||
G0[np.abs(G0) < 1e-8] = 0
|
||||
G1[np.abs(G1) < 1e-8] = 0
|
||||
|
||||
return H0, H1, G0, G1, PHI0, PHI1
|
||||
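# A quick shape sketch (illustrative): every matrix returned by get_filter is k x k;
# they are consumed below as the ec_s / ec_d (decomposition) and rc_e / rc_o
# (reconstruction) buffers of the multiwavelet blocks.
def _demo_get_filter(k=3, base='legendre'):
    H0, H1, G0, G1, PHI0, PHI1 = get_filter(base, k)
    for name, mat in zip(('H0', 'H1', 'G0', 'G1', 'PHI0', 'PHI1'),
                         (H0, H1, G0, G1, PHI0, PHI1)):
        print(name, mat.shape)  # each: (k, k)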
|
||||
|
||||
class MultiWaveletTransform(nn.Module):
|
||||
"""
|
||||
1D multiwavelet block.
|
||||
"""
|
||||
|
||||
def __init__(self, ich=1, k=8, alpha=16, c=128,
|
||||
nCZ=1, L=0, base='legendre', attention_dropout=0.1):
|
||||
super(MultiWaveletTransform, self).__init__()
|
||||
print('base', base)
|
||||
self.k = k
|
||||
self.c = c
|
||||
self.L = L
|
||||
self.nCZ = nCZ
|
||||
self.Lk0 = nn.Linear(ich, c * k)
|
||||
self.Lk1 = nn.Linear(c * k, ich)
|
||||
self.ich = ich
|
||||
self.MWT_CZ = nn.ModuleList(MWT_CZ1d(k, alpha, L, c, base) for i in range(nCZ))
|
||||
|
||||
def forward(self, queries, keys, values, attn_mask):
|
||||
B, L, H, E = queries.shape
|
||||
_, S, _, D = values.shape
|
||||
if L > S:
|
||||
zeros = torch.zeros_like(queries[:, :(L - S), :]).float()
|
||||
values = torch.cat([values, zeros], dim=1)
|
||||
keys = torch.cat([keys, zeros], dim=1)
|
||||
else:
|
||||
values = values[:, :L, :, :]
|
||||
keys = keys[:, :L, :, :]
|
||||
values = values.view(B, L, -1)
|
||||
|
||||
V = self.Lk0(values).view(B, L, self.c, -1)
|
||||
for i in range(self.nCZ):
|
||||
V = self.MWT_CZ[i](V)
|
||||
if i < self.nCZ - 1:
|
||||
V = F.relu(V)
|
||||
|
||||
V = self.Lk1(V.view(B, L, -1))
|
||||
V = V.view(B, L, -1, D)
|
||||
return (V.contiguous(), None)
|
||||
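# A minimal forward-pass sketch (shapes are illustrative assumptions): the block takes
# attention-style inputs of shape [B, L, H, E] with ich = H * E and returns a tensor of
# the same shape, so it can stand in for a self-attention layer.
def _demo_multiwavelet_transform():
    B, L, H, E = 2, 32, 8, 8
    block = MultiWaveletTransform(ich=H * E, k=4, alpha=16, c=16, nCZ=1, base='legendre')
    q = torch.randn(B, L, H, E)
    out, _ = block(q, q, q, attn_mask=None)
    print(out.shape)  # torch.Size([2, 32, 8, 8])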
|
||||
|
||||
class MultiWaveletCross(nn.Module):
|
||||
"""
|
||||
1D Multiwavelet Cross Attention layer.
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels, out_channels, seq_len_q, seq_len_kv, modes, c=64,
|
||||
k=8, ich=512,
|
||||
L=0,
|
||||
base='legendre',
|
||||
mode_select_method='random',
|
||||
initializer=None, activation='tanh',
|
||||
**kwargs):
|
||||
super(MultiWaveletCross, self).__init__()
|
||||
print('base', base)
|
||||
|
||||
self.c = c
|
||||
self.k = k
|
||||
self.L = L
|
||||
H0, H1, G0, G1, PHI0, PHI1 = get_filter(base, k)
|
||||
H0r = H0 @ PHI0
|
||||
G0r = G0 @ PHI0
|
||||
H1r = H1 @ PHI1
|
||||
G1r = G1 @ PHI1
|
||||
|
||||
H0r[np.abs(H0r) < 1e-8] = 0
|
||||
H1r[np.abs(H1r) < 1e-8] = 0
|
||||
G0r[np.abs(G0r) < 1e-8] = 0
|
||||
G1r[np.abs(G1r) < 1e-8] = 0
|
||||
self.max_item = 3
|
||||
|
||||
self.attn1 = FourierCrossAttentionW(in_channels=in_channels, out_channels=out_channels, seq_len_q=seq_len_q,
|
||||
seq_len_kv=seq_len_kv, modes=modes, activation=activation,
|
||||
mode_select_method=mode_select_method)
|
||||
self.attn2 = FourierCrossAttentionW(in_channels=in_channels, out_channels=out_channels, seq_len_q=seq_len_q,
|
||||
seq_len_kv=seq_len_kv, modes=modes, activation=activation,
|
||||
mode_select_method=mode_select_method)
|
||||
self.attn3 = FourierCrossAttentionW(in_channels=in_channels, out_channels=out_channels, seq_len_q=seq_len_q,
|
||||
seq_len_kv=seq_len_kv, modes=modes, activation=activation,
|
||||
mode_select_method=mode_select_method)
|
||||
self.attn4 = FourierCrossAttentionW(in_channels=in_channels, out_channels=out_channels, seq_len_q=seq_len_q,
|
||||
seq_len_kv=seq_len_kv, modes=modes, activation=activation,
|
||||
mode_select_method=mode_select_method)
|
||||
self.T0 = nn.Linear(k, k)
|
||||
self.register_buffer('ec_s', torch.Tensor(
|
||||
np.concatenate((H0.T, H1.T), axis=0)))
|
||||
self.register_buffer('ec_d', torch.Tensor(
|
||||
np.concatenate((G0.T, G1.T), axis=0)))
|
||||
|
||||
self.register_buffer('rc_e', torch.Tensor(
|
||||
np.concatenate((H0r, G0r), axis=0)))
|
||||
self.register_buffer('rc_o', torch.Tensor(
|
||||
np.concatenate((H1r, G1r), axis=0)))
|
||||
|
||||
self.Lk = nn.Linear(ich, c * k)
|
||||
self.Lq = nn.Linear(ich, c * k)
|
||||
self.Lv = nn.Linear(ich, c * k)
|
||||
self.out = nn.Linear(c * k, ich)
|
||||
self.modes1 = modes
|
||||
|
||||
def forward(self, q, k, v, mask=None):
|
||||
B, N, H, E = q.shape # (B, N, H, E) torch.Size([3, 768, 8, 2])
|
||||
_, S, _, _ = k.shape # (B, S, H, E) torch.Size([3, 96, 8, 2])
|
||||
|
||||
q = q.view(q.shape[0], q.shape[1], -1)
|
||||
k = k.view(k.shape[0], k.shape[1], -1)
|
||||
v = v.view(v.shape[0], v.shape[1], -1)
|
||||
q = self.Lq(q)
|
||||
q = q.view(q.shape[0], q.shape[1], self.c, self.k)
|
||||
k = self.Lk(k)
|
||||
k = k.view(k.shape[0], k.shape[1], self.c, self.k)
|
||||
v = self.Lv(v)
|
||||
v = v.view(v.shape[0], v.shape[1], self.c, self.k)
|
||||
|
||||
if N > S:
|
||||
zeros = torch.zeros_like(q[:, :(N - S), :]).float()
|
||||
v = torch.cat([v, zeros], dim=1)
|
||||
k = torch.cat([k, zeros], dim=1)
|
||||
else:
|
||||
v = v[:, :N, :, :]
|
||||
k = k[:, :N, :, :]
|
||||
|
||||
ns = math.floor(np.log2(N))
|
||||
nl = pow(2, math.ceil(np.log2(N)))
|
||||
extra_q = q[:, 0:nl - N, :, :]
|
||||
extra_k = k[:, 0:nl - N, :, :]
|
||||
extra_v = v[:, 0:nl - N, :, :]
|
||||
q = torch.cat([q, extra_q], 1)
|
||||
k = torch.cat([k, extra_k], 1)
|
||||
v = torch.cat([v, extra_v], 1)
|
||||
|
||||
Ud_q = torch.jit.annotate(List[Tuple[Tensor]], [])
|
||||
Ud_k = torch.jit.annotate(List[Tuple[Tensor]], [])
|
||||
Ud_v = torch.jit.annotate(List[Tuple[Tensor]], [])
|
||||
|
||||
Us_q = torch.jit.annotate(List[Tensor], [])
|
||||
Us_k = torch.jit.annotate(List[Tensor], [])
|
||||
Us_v = torch.jit.annotate(List[Tensor], [])
|
||||
|
||||
Ud = torch.jit.annotate(List[Tensor], [])
|
||||
Us = torch.jit.annotate(List[Tensor], [])
|
||||
|
||||
# decompose
|
||||
for i in range(ns - self.L):
|
||||
# print('q shape',q.shape)
|
||||
d, q = self.wavelet_transform(q)
|
||||
Ud_q += [tuple([d, q])]
|
||||
Us_q += [d]
|
||||
for i in range(ns - self.L):
|
||||
d, k = self.wavelet_transform(k)
|
||||
Ud_k += [tuple([d, k])]
|
||||
Us_k += [d]
|
||||
for i in range(ns - self.L):
|
||||
d, v = self.wavelet_transform(v)
|
||||
Ud_v += [tuple([d, v])]
|
||||
Us_v += [d]
|
||||
for i in range(ns - self.L):
|
||||
dk, sk = Ud_k[i], Us_k[i]
|
||||
dq, sq = Ud_q[i], Us_q[i]
|
||||
dv, sv = Ud_v[i], Us_v[i]
|
||||
Ud += [self.attn1(dq[0], dk[0], dv[0], mask)[0] + self.attn2(dq[1], dk[1], dv[1], mask)[0]]
|
||||
Us += [self.attn3(sq, sk, sv, mask)[0]]
|
||||
v = self.attn4(q, k, v, mask)[0]
|
||||
|
||||
# reconstruct
|
||||
for i in range(ns - 1 - self.L, -1, -1):
|
||||
v = v + Us[i]
|
||||
v = torch.cat((v, Ud[i]), -1)
|
||||
v = self.evenOdd(v)
|
||||
v = self.out(v[:, :N, :, :].contiguous().view(B, N, -1))
|
||||
return (v.contiguous(), None)
|
||||
|
||||
def wavelet_transform(self, x):
|
||||
xa = torch.cat([x[:, ::2, :, :],
|
||||
x[:, 1::2, :, :],
|
||||
], -1)
|
||||
d = torch.matmul(xa, self.ec_d)
|
||||
s = torch.matmul(xa, self.ec_s)
|
||||
return d, s
|
||||
|
||||
def evenOdd(self, x):
|
||||
B, N, c, ich = x.shape # (B, N, c, k)
|
||||
assert ich == 2 * self.k
|
||||
x_e = torch.matmul(x, self.rc_e)
|
||||
x_o = torch.matmul(x, self.rc_o)
|
||||
|
||||
x = torch.zeros(B, N * 2, c, self.k,
|
||||
device=x.device)
|
||||
x[..., ::2, :, :] = x_e
|
||||
x[..., 1::2, :, :] = x_o
|
||||
return x
|
||||
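# A minimal cross-attention sketch (shapes are illustrative assumptions): queries of
# length N attend to keys/values of length S; the output comes back flattened to
# [B, N, H * E].
def _demo_multiwavelet_cross():
    B, N, S, H, E = 2, 64, 32, 8, 8          # ich = H * E = 64
    layer = MultiWaveletCross(in_channels=E, out_channels=E, seq_len_q=N, seq_len_kv=S,
                              modes=8, c=16, k=4, ich=H * E, base='legendre')
    q = torch.randn(B, N, H, E)
    k = torch.randn(B, S, H, E)
    v = torch.randn(B, S, H, E)
    out, _ = layer(q, k, v, mask=None)
    print(out.shape)  # torch.Size([2, 64, 64])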
|
||||
|
||||
class FourierCrossAttentionW(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, seq_len_q, seq_len_kv, modes=16, activation='tanh',
|
||||
mode_select_method='random'):
|
||||
super(FourierCrossAttentionW, self).__init__()
|
||||
print('cross fourier correlation used!')
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.modes1 = modes
|
||||
self.activation = activation
|
||||
|
||||
def compl_mul1d(self, order, x, weights):
|
||||
x_flag = True
|
||||
w_flag = True
|
||||
if not torch.is_complex(x):
|
||||
x_flag = False
|
||||
x = torch.complex(x, torch.zeros_like(x).to(x.device))
|
||||
if not torch.is_complex(weights):
|
||||
w_flag = False
|
||||
weights = torch.complex(weights, torch.zeros_like(weights).to(weights.device))
|
||||
if x_flag or w_flag:
|
||||
return torch.complex(torch.einsum(order, x.real, weights.real) - torch.einsum(order, x.imag, weights.imag),
|
||||
torch.einsum(order, x.real, weights.imag) + torch.einsum(order, x.imag, weights.real))
|
||||
else:
|
||||
return torch.einsum(order, x.real, weights.real)
|
||||
|
||||
def forward(self, q, k, v, mask):
|
||||
B, L, E, H = q.shape
|
||||
|
||||
xq = q.permute(0, 3, 2, 1) # size = [B, H, E, L] torch.Size([3, 8, 64, 512])
|
||||
xk = k.permute(0, 3, 2, 1)
|
||||
xv = v.permute(0, 3, 2, 1)
|
||||
self.index_q = list(range(0, min(int(L // 2), self.modes1)))
|
||||
self.index_k_v = list(range(0, min(int(xv.shape[3] // 2), self.modes1)))
|
||||
|
||||
# Compute Fourier coefficients
|
||||
xq_ft_ = torch.zeros(B, H, E, len(self.index_q), device=xq.device, dtype=torch.cfloat)
|
||||
xq_ft = torch.fft.rfft(xq, dim=-1)
|
||||
for i, j in enumerate(self.index_q):
|
||||
xq_ft_[:, :, :, i] = xq_ft[:, :, :, j]
|
||||
|
||||
xk_ft_ = torch.zeros(B, H, E, len(self.index_k_v), device=xq.device, dtype=torch.cfloat)
|
||||
xk_ft = torch.fft.rfft(xk, dim=-1)
|
||||
for i, j in enumerate(self.index_k_v):
|
||||
xk_ft_[:, :, :, i] = xk_ft[:, :, :, j]
|
||||
xqk_ft = (self.compl_mul1d("bhex,bhey->bhxy", xq_ft_, xk_ft_))
|
||||
if self.activation == 'tanh':
|
||||
xqk_ft = torch.complex(xqk_ft.real.tanh(), xqk_ft.imag.tanh())
|
||||
elif self.activation == 'softmax':
|
||||
xqk_ft = torch.softmax(abs(xqk_ft), dim=-1)
|
||||
xqk_ft = torch.complex(xqk_ft, torch.zeros_like(xqk_ft))
|
||||
else:
|
||||
raise Exception('{} activation function is not implemented'.format(self.activation))
|
||||
xqkv_ft = self.compl_mul1d("bhxy,bhey->bhex", xqk_ft, xk_ft_)
|
||||
|
||||
xqkvw = xqkv_ft
|
||||
out_ft = torch.zeros(B, H, E, L // 2 + 1, device=xq.device, dtype=torch.cfloat)
|
||||
for i, j in enumerate(self.index_q):
|
||||
out_ft[:, :, :, j] = xqkvw[:, :, :, i]
|
||||
|
||||
out = torch.fft.irfft(out_ft / self.in_channels / self.out_channels, n=xq.size(-1)).permute(0, 3, 2, 1)
|
||||
# size = [B, L, H, E]
|
||||
return (out, None)
|
||||
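# A small equivalence sketch (illustrative): compl_mul1d splits a complex einsum into
# real/imaginary parts; for complex inputs it should agree with torch.einsum directly.
def _demo_compl_mul1d():
    layer = FourierCrossAttentionW(in_channels=4, out_channels=4, seq_len_q=16, seq_len_kv=16)
    x = torch.randn(2, 3, 5, 6, dtype=torch.cfloat)
    y = torch.randn(2, 3, 5, 7, dtype=torch.cfloat)
    out = layer.compl_mul1d("bhex,bhey->bhxy", x, y)
    ref = torch.einsum("bhex,bhey->bhxy", x, y)
    print(torch.allclose(out, ref, atol=1e-4))  # True up to floating-point error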
|
||||
|
||||
class sparseKernelFT1d(nn.Module):
|
||||
def __init__(self,
|
||||
k, alpha, c=1,
|
||||
nl=1,
|
||||
initializer=None,
|
||||
**kwargs):
|
||||
super(sparseKernelFT1d, self).__init__()
|
||||
|
||||
self.modes1 = alpha
|
||||
self.scale = (1 / (c * k * c * k))
|
||||
self.weights1 = nn.Parameter(self.scale * torch.rand(c * k, c * k, self.modes1, dtype=torch.float))
|
||||
self.weights2 = nn.Parameter(self.scale * torch.rand(c * k, c * k, self.modes1, dtype=torch.float))
|
||||
self.weights1.requires_grad = True
|
||||
self.weights2.requires_grad = True
|
||||
self.k = k
|
||||
|
||||
def compl_mul1d(self, order, x, weights):
|
||||
x_flag = True
|
||||
w_flag = True
|
||||
if not torch.is_complex(x):
|
||||
x_flag = False
|
||||
x = torch.complex(x, torch.zeros_like(x).to(x.device))
|
||||
if not torch.is_complex(weights):
|
||||
w_flag = False
|
||||
weights = torch.complex(weights, torch.zeros_like(weights).to(weights.device))
|
||||
if x_flag or w_flag:
|
||||
return torch.complex(torch.einsum(order, x.real, weights.real) - torch.einsum(order, x.imag, weights.imag),
|
||||
torch.einsum(order, x.real, weights.imag) + torch.einsum(order, x.imag, weights.real))
|
||||
else:
|
||||
return torch.einsum(order, x.real, weights.real)
|
||||
|
||||
def forward(self, x):
|
||||
B, N, c, k = x.shape # (B, N, c, k)
|
||||
|
||||
x = x.view(B, N, -1)
|
||||
x = x.permute(0, 2, 1)
|
||||
x_fft = torch.fft.rfft(x)
|
||||
# Multiply relevant Fourier modes
|
||||
l = min(self.modes1, N // 2 + 1)
|
||||
out_ft = torch.zeros(B, c * k, N // 2 + 1, device=x.device, dtype=torch.cfloat)
|
||||
out_ft[:, :, :l] = self.compl_mul1d("bix,iox->box", x_fft[:, :, :l],
|
||||
torch.complex(self.weights1, self.weights2)[:, :, :l])
|
||||
x = torch.fft.irfft(out_ft, n=N)
|
||||
x = x.permute(0, 2, 1).view(B, N, c, k)
|
||||
return x
|
||||
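# A minimal shape sketch (illustrative): the sparse Fourier kernel keeps only the lowest
# min(alpha, N // 2 + 1) frequency modes and maps [B, N, c, k] back to the same shape.
def _demo_sparse_kernel_ft1d():
    B, N, c, k = 2, 32, 4, 3
    layer = sparseKernelFT1d(k=k, alpha=8, c=c)
    x = torch.randn(B, N, c, k)
    print(layer(x).shape)  # torch.Size([2, 32, 4, 3])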
|
||||
|
||||
# ##
|
||||
class MWT_CZ1d(nn.Module):
|
||||
def __init__(self,
|
||||
k=3, alpha=64,
|
||||
L=0, c=1,
|
||||
base='legendre',
|
||||
initializer=None,
|
||||
**kwargs):
|
||||
super(MWT_CZ1d, self).__init__()
|
||||
|
||||
self.k = k
|
||||
self.L = L
|
||||
H0, H1, G0, G1, PHI0, PHI1 = get_filter(base, k)
|
||||
H0r = H0 @ PHI0
|
||||
G0r = G0 @ PHI0
|
||||
H1r = H1 @ PHI1
|
||||
G1r = G1 @ PHI1
|
||||
|
||||
H0r[np.abs(H0r) < 1e-8] = 0
|
||||
H1r[np.abs(H1r) < 1e-8] = 0
|
||||
G0r[np.abs(G0r) < 1e-8] = 0
|
||||
G1r[np.abs(G1r) < 1e-8] = 0
|
||||
self.max_item = 3
|
||||
|
||||
self.A = sparseKernelFT1d(k, alpha, c)
|
||||
self.B = sparseKernelFT1d(k, alpha, c)
|
||||
self.C = sparseKernelFT1d(k, alpha, c)
|
||||
|
||||
self.T0 = nn.Linear(k, k)
|
||||
|
||||
self.register_buffer('ec_s', torch.Tensor(
|
||||
np.concatenate((H0.T, H1.T), axis=0)))
|
||||
self.register_buffer('ec_d', torch.Tensor(
|
||||
np.concatenate((G0.T, G1.T), axis=0)))
|
||||
|
||||
self.register_buffer('rc_e', torch.Tensor(
|
||||
np.concatenate((H0r, G0r), axis=0)))
|
||||
self.register_buffer('rc_o', torch.Tensor(
|
||||
np.concatenate((H1r, G1r), axis=0)))
|
||||
|
||||
def forward(self, x):
|
||||
B, N, c, k = x.shape  # (B, N, c, k)
|
||||
ns = math.floor(np.log2(N))
|
||||
nl = pow(2, math.ceil(np.log2(N)))
|
||||
extra_x = x[:, 0:nl - N, :, :]
|
||||
x = torch.cat([x, extra_x], 1)
|
||||
Ud = torch.jit.annotate(List[Tensor], [])
|
||||
Us = torch.jit.annotate(List[Tensor], [])
|
||||
for i in range(ns - self.L):
|
||||
d, x = self.wavelet_transform(x)
|
||||
Ud += [self.A(d) + self.B(x)]
|
||||
Us += [self.C(d)]
|
||||
x = self.T0(x) # coarsest scale transform
|
||||
|
||||
# reconstruct
|
||||
for i in range(ns - 1 - self.L, -1, -1):
|
||||
x = x + Us[i]
|
||||
x = torch.cat((x, Ud[i]), -1)
|
||||
x = self.evenOdd(x)
|
||||
x = x[:, :N, :, :]
|
||||
|
||||
return x
|
||||
|
||||
def wavelet_transform(self, x):
|
||||
xa = torch.cat([x[:, ::2, :, :],
|
||||
x[:, 1::2, :, :],
|
||||
], -1)
|
||||
d = torch.matmul(xa, self.ec_d)
|
||||
s = torch.matmul(xa, self.ec_s)
|
||||
return d, s
|
||||
|
||||
def evenOdd(self, x):
|
||||
|
||||
B, N, c, ich = x.shape # (B, N, c, k)
|
||||
assert ich == 2 * self.k
|
||||
x_e = torch.matmul(x, self.rc_e)
|
||||
x_o = torch.matmul(x, self.rc_o)
|
||||
|
||||
x = torch.zeros(B, N * 2, c, self.k,
|
||||
device=x.device)
|
||||
x[..., ::2, :, :] = x_e
|
||||
x[..., 1::2, :, :] = x_o
|
||||
return x
|
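# A minimal shape sketch (illustrative): MWT_CZ1d pads the sequence up to the next power
# of two internally by repeating the leading steps, so the input length does not need to
# be a power of two.
def _demo_mwt_cz1d():
    B, N, c, k = 2, 48, 4, 3
    layer = MWT_CZ1d(k=k, alpha=8, L=0, c=c, base='legendre')
    x = torch.randn(B, N, c, k)
    print(layer(x).shape)  # torch.Size([2, 48, 4, 3])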
218
layers/Pyraformer_EncDec.py
Normal file
@ -0,0 +1,218 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.modules.linear import Linear
|
||||
from layers.SelfAttention_Family import AttentionLayer, FullAttention
|
||||
from layers.Embed import DataEmbedding
|
||||
import math
|
||||
|
||||
|
||||
def get_mask(input_size, window_size, inner_size):
|
||||
"""Get the attention mask of PAM-Naive"""
|
||||
# Get the size of all layers
|
||||
all_size = []
|
||||
all_size.append(input_size)
|
||||
for i in range(len(window_size)):
|
||||
layer_size = math.floor(all_size[i] / window_size[i])
|
||||
all_size.append(layer_size)
|
||||
|
||||
seq_length = sum(all_size)
|
||||
mask = torch.zeros(seq_length, seq_length)
|
||||
|
||||
# get intra-scale mask
|
||||
inner_window = inner_size // 2
|
||||
for layer_idx in range(len(all_size)):
|
||||
start = sum(all_size[:layer_idx])
|
||||
for i in range(start, start + all_size[layer_idx]):
|
||||
left_side = max(i - inner_window, start)
|
||||
right_side = min(i + inner_window + 1, start + all_size[layer_idx])
|
||||
mask[i, left_side:right_side] = 1
|
||||
|
||||
# get inter-scale mask
|
||||
for layer_idx in range(1, len(all_size)):
|
||||
start = sum(all_size[:layer_idx])
|
||||
for i in range(start, start + all_size[layer_idx]):
|
||||
left_side = (start - all_size[layer_idx - 1]) + \
|
||||
(i - start) * window_size[layer_idx - 1]
|
||||
if i == (start + all_size[layer_idx] - 1):
|
||||
right_side = start
|
||||
else:
|
||||
right_side = (
|
||||
start - all_size[layer_idx - 1]) + (i - start + 1) * window_size[layer_idx - 1]
|
||||
mask[i, left_side:right_side] = 1
|
||||
mask[left_side:right_side, i] = 1
|
||||
|
||||
mask = (1 - mask).bool()
|
||||
|
||||
return mask, all_size
|
||||
|
||||
|
||||
def refer_points(all_sizes, window_size):
|
||||
"""Gather features from PAM's pyramid sequences"""
|
||||
input_size = all_sizes[0]
|
||||
indexes = torch.zeros(input_size, len(all_sizes))
|
||||
|
||||
for i in range(input_size):
|
||||
indexes[i][0] = i
|
||||
former_index = i
|
||||
for j in range(1, len(all_sizes)):
|
||||
start = sum(all_sizes[:j])
|
||||
inner_layer_idx = former_index - (start - all_sizes[j - 1])
|
||||
former_index = start + \
|
||||
min(inner_layer_idx // window_size[j - 1], all_sizes[j] - 1)
|
||||
indexes[i][j] = former_index
|
||||
|
||||
indexes = indexes.unsqueeze(0).unsqueeze(3)
|
||||
|
||||
return indexes.long()
|
||||
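# A small usage sketch (illustrative): build the PAM mask and the gather indexes for a
# toy pyramid with a 16-step input and two 4x coarsening levels.
def _demo_pam_mask():
    window_size, inner_size = [4, 4], 3
    mask, all_size = get_mask(input_size=16, window_size=window_size, inner_size=inner_size)
    indexes = refer_points(all_size, window_size)
    print(all_size)        # [16, 4, 1]
    print(mask.shape)      # torch.Size([21, 21]) boolean mask over the full pyramid
    print(indexes.shape)   # torch.Size([1, 16, 3, 1])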
|
||||
|
||||
class RegularMask():
|
||||
def __init__(self, mask):
|
||||
self._mask = mask.unsqueeze(1)
|
||||
|
||||
@property
|
||||
def mask(self):
|
||||
return self._mask
|
||||
|
||||
|
||||
class EncoderLayer(nn.Module):
|
||||
""" Compose with two layers """
|
||||
|
||||
def __init__(self, d_model, d_inner, n_head, dropout=0.1, normalize_before=True):
|
||||
super(EncoderLayer, self).__init__()
|
||||
|
||||
self.slf_attn = AttentionLayer(
|
||||
FullAttention(mask_flag=True, factor=0,
|
||||
attention_dropout=dropout, output_attention=False),
|
||||
d_model, n_head)
|
||||
self.pos_ffn = PositionwiseFeedForward(
|
||||
d_model, d_inner, dropout=dropout, normalize_before=normalize_before)
|
||||
|
||||
def forward(self, enc_input, slf_attn_mask=None):
|
||||
attn_mask = RegularMask(slf_attn_mask)
|
||||
enc_output, _ = self.slf_attn(
|
||||
enc_input, enc_input, enc_input, attn_mask=attn_mask)
|
||||
enc_output = self.pos_ffn(enc_output)
|
||||
return enc_output
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
""" A encoder model with self attention mechanism. """
|
||||
|
||||
def __init__(self, configs, window_size, inner_size):
|
||||
super().__init__()
|
||||
|
||||
d_bottleneck = configs.d_model//4
|
||||
|
||||
self.mask, self.all_size = get_mask(
|
||||
configs.seq_len, window_size, inner_size)
|
||||
self.indexes = refer_points(self.all_size, window_size)
|
||||
self.layers = nn.ModuleList([
|
||||
EncoderLayer(configs.d_model, configs.d_ff, configs.n_heads, dropout=configs.dropout,
|
||||
normalize_before=False) for _ in range(configs.e_layers)
|
||||
]) # naive pyramid attention
|
||||
|
||||
self.enc_embedding = DataEmbedding(
|
||||
configs.enc_in, configs.d_model, configs.dropout)
|
||||
self.conv_layers = Bottleneck_Construct(
|
||||
configs.d_model, window_size, d_bottleneck)
|
||||
|
||||
def forward(self, x_enc, x_mark_enc):
|
||||
seq_enc = self.enc_embedding(x_enc, x_mark_enc)
|
||||
|
||||
mask = self.mask.repeat(len(seq_enc), 1, 1).to(x_enc.device)
|
||||
seq_enc = self.conv_layers(seq_enc)
|
||||
|
||||
for i in range(len(self.layers)):
|
||||
seq_enc = self.layers[i](seq_enc, mask)
|
||||
|
||||
indexes = self.indexes.repeat(seq_enc.size(
|
||||
0), 1, 1, seq_enc.size(2)).to(seq_enc.device)
|
||||
indexes = indexes.view(seq_enc.size(0), -1, seq_enc.size(2))
|
||||
all_enc = torch.gather(seq_enc, 1, indexes)
|
||||
seq_enc = all_enc.view(seq_enc.size(0), self.all_size[0], -1)
|
||||
|
||||
return seq_enc
|
||||
|
||||
|
||||
class ConvLayer(nn.Module):
|
||||
def __init__(self, c_in, window_size):
|
||||
super(ConvLayer, self).__init__()
|
||||
self.downConv = nn.Conv1d(in_channels=c_in,
|
||||
out_channels=c_in,
|
||||
kernel_size=window_size,
|
||||
stride=window_size)
|
||||
self.norm = nn.BatchNorm1d(c_in)
|
||||
self.activation = nn.ELU()
|
||||
|
||||
def forward(self, x):
|
||||
x = self.downConv(x)
|
||||
x = self.norm(x)
|
||||
x = self.activation(x)
|
||||
return x
|
||||
|
||||
|
||||
class Bottleneck_Construct(nn.Module):
|
||||
"""Bottleneck convolution CSCM"""
|
||||
|
||||
def __init__(self, d_model, window_size, d_inner):
|
||||
super(Bottleneck_Construct, self).__init__()
|
||||
if not isinstance(window_size, list):
|
||||
self.conv_layers = nn.ModuleList([
|
||||
ConvLayer(d_inner, window_size),
|
||||
ConvLayer(d_inner, window_size),
|
||||
ConvLayer(d_inner, window_size)
|
||||
])
|
||||
else:
|
||||
self.conv_layers = []
|
||||
for i in range(len(window_size)):
|
||||
self.conv_layers.append(ConvLayer(d_inner, window_size[i]))
|
||||
self.conv_layers = nn.ModuleList(self.conv_layers)
|
||||
self.up = Linear(d_inner, d_model)
|
||||
self.down = Linear(d_model, d_inner)
|
||||
self.norm = nn.LayerNorm(d_model)
|
||||
|
||||
def forward(self, enc_input):
|
||||
temp_input = self.down(enc_input).permute(0, 2, 1)
|
||||
all_inputs = []
|
||||
for i in range(len(self.conv_layers)):
|
||||
temp_input = self.conv_layers[i](temp_input)
|
||||
all_inputs.append(temp_input)
|
||||
|
||||
all_inputs = torch.cat(all_inputs, dim=2).transpose(1, 2)
|
||||
all_inputs = self.up(all_inputs)
|
||||
all_inputs = torch.cat([enc_input, all_inputs], dim=1)
|
||||
|
||||
all_inputs = self.norm(all_inputs)
|
||||
return all_inputs
|
||||
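# A minimal shape sketch (illustrative): with window_size=[2, 2, 2] a length-32 input is
# coarsened to lengths 16, 8 and 4; the coarse sequences are concatenated after the
# original one, giving 32 + 16 + 8 + 4 = 60 positions.
def _demo_bottleneck_construct():
    B, L, d_model, d_inner = 2, 32, 32, 16
    cscm = Bottleneck_Construct(d_model, [2, 2, 2], d_inner)
    x = torch.randn(B, L, d_model)
    print(cscm(x).shape)  # torch.Size([2, 60, 32])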
|
||||
|
||||
class PositionwiseFeedForward(nn.Module):
|
||||
""" Two-layer position-wise feed-forward neural network. """
|
||||
|
||||
def __init__(self, d_in, d_hid, dropout=0.1, normalize_before=True):
|
||||
super().__init__()
|
||||
|
||||
self.normalize_before = normalize_before
|
||||
|
||||
self.w_1 = nn.Linear(d_in, d_hid)
|
||||
self.w_2 = nn.Linear(d_hid, d_in)
|
||||
|
||||
self.layer_norm = nn.LayerNorm(d_in, eps=1e-6)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
if self.normalize_before:
|
||||
x = self.layer_norm(x)
|
||||
|
||||
x = F.gelu(self.w_1(x))
|
||||
x = self.dropout(x)
|
||||
x = self.w_2(x)
|
||||
x = self.dropout(x)
|
||||
x = x + residual
|
||||
|
||||
if not self.normalize_before:
|
||||
x = self.layer_norm(x)
|
||||
return x
|
59
layers/RevIN.py
Normal file
@ -0,0 +1,59 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
class RevIN(nn.Module):
|
||||
"""
|
||||
Reversible Instance Normalization
|
||||
"""
|
||||
def __init__(self, num_features: int, eps=1e-5, affine=True, subtract_last=False):
|
||||
super(RevIN, self).__init__()
|
||||
self.num_features = num_features
|
||||
self.eps = eps
|
||||
self.affine = affine
|
||||
self.subtract_last = subtract_last
|
||||
if self.affine:
|
||||
self._init_params()
|
||||
|
||||
def forward(self, x, mode: str):
|
||||
if mode == 'norm':
|
||||
self._get_statistics(x)
|
||||
x = self._normalize(x)
|
||||
elif mode == 'denorm':
|
||||
x = self._denormalize(x)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return x
|
||||
|
||||
def _init_params(self):
|
||||
self.affine_weight = nn.Parameter(torch.ones(self.num_features))
|
||||
self.affine_bias = nn.Parameter(torch.zeros(self.num_features))
|
||||
|
||||
def _get_statistics(self, x):
|
||||
dim2reduce = tuple(range(1, x.ndim-1))
|
||||
if self.subtract_last:
|
||||
self.last = x[:, -1, :].unsqueeze(1)
|
||||
else:
|
||||
self.mean = torch.mean(x, dim=dim2reduce, keepdim=True).detach()
|
||||
self.stdev = torch.sqrt(torch.var(x, dim=dim2reduce, keepdim=True, unbiased=False) + self.eps).detach()
|
||||
|
||||
def _normalize(self, x):
|
||||
if self.subtract_last:
|
||||
x = x - self.last
|
||||
else:
|
||||
x = x - self.mean
|
||||
x = x / self.stdev
|
||||
if self.affine:
|
||||
x = x * self.affine_weight
|
||||
x = x + self.affine_bias
|
||||
return x
|
||||
|
||||
def _denormalize(self, x):
|
||||
if self.affine:
|
||||
x = x - self.affine_bias
|
||||
x = x / (self.affine_weight + self.eps * self.eps)
|
||||
x = x * self.stdev
|
||||
if self.subtract_last:
|
||||
x = x + self.last
|
||||
else:
|
||||
x = x + self.mean
|
||||
return x
|
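# A minimal round-trip sketch (illustrative): normalizing and then denormalizing should
# recover the input up to the numerical-stability epsilon.
def _demo_revin():
    revin = RevIN(num_features=7)
    x = torch.randn(4, 96, 7)
    y = revin(x, 'norm')
    x_hat = revin(y, 'denorm')
    print(torch.allclose(x, x_hat, atol=1e-4))  # True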
67
layers/SeasonPatch.py
Normal file
@ -0,0 +1,67 @@
|
||||
"""
|
||||
SeasonPatch = PatchTST (CI) + ChannelGraphMixer + Linear prediction head
|
||||
Adapted for Time-Series-Library-main style
|
||||
"""
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from layers.TSTEncoder import TSTiEncoder
|
||||
from layers.GraphMixer import HierarchicalGraphMixer
|
||||
|
||||
class SeasonPatch(nn.Module):
|
||||
def __init__(self,
|
||||
c_in: int,
|
||||
seq_len: int,
|
||||
pred_len: int,
|
||||
patch_len: int,
|
||||
stride: int,
|
||||
k_graph: int = 8,
|
||||
d_model: int = 128,
|
||||
n_layers: int = 3,
|
||||
n_heads: int = 16):
|
||||
super().__init__()
|
||||
|
||||
# Store patch parameters
|
||||
self.patch_len = patch_len
|
||||
self.stride = stride
|
||||
|
||||
# Calculate patch number
|
||||
patch_num = (seq_len - patch_len) // stride + 1
|
||||
|
||||
# PatchTST encoder (channel independent)
|
||||
self.encoder = TSTiEncoder(
|
||||
c_in=c_in,
|
||||
patch_num=patch_num,
|
||||
patch_len=patch_len,
|
||||
d_model=d_model,
|
||||
n_layers=n_layers,
|
||||
n_heads=n_heads
|
||||
)
|
||||
|
||||
# Cross-channel mixer
|
||||
self.mixer = HierarchicalGraphMixer(c_in, dim=d_model, k=k_graph)
|
||||
|
||||
# Prediction head
|
||||
self.head = nn.Linear(patch_num * d_model, pred_len)
|
||||
|
||||
def forward(self, x):
|
||||
# x: [B, L, C]
|
||||
x = x.permute(0, 2, 1) # → [B, C, L]
|
||||
|
||||
# Patch the input
|
||||
x_patch = x.unfold(-1, self.patch_len, self.stride) # [B, C, patch_num, patch_len]
|
||||
|
||||
# Encode patches
|
||||
z = self.encoder(x_patch) # [B, C, d_model, patch_num]
|
||||
|
||||
# z: [B, C, d_model, patch_num] → [B, C, patch_num, d_model]
|
||||
B, C, D, N = z.shape
|
||||
z = z.permute(0, 1, 3, 2) # [B, C, patch_num, d_model]
|
||||
|
||||
# Cross-channel mixing
|
||||
z_mix = self.mixer(z) # [B, C, patch_num, d_model]
|
||||
|
||||
# Flatten and predict
|
||||
z_mix = z_mix.view(B, C, N * D) # [B, C, patch_num * d_model]
|
||||
y_pred = self.head(z_mix) # [B, C, pred_len]
|
||||
|
||||
return y_pred
|
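# A small shape sketch (illustrative): how the unfold call above turns [B, C, L] into
# [B, C, patch_num, patch_len] with patch_num = (L - patch_len) // stride + 1.
def _demo_patching(seq_len=96, patch_len=16, stride=8):
    x = torch.randn(2, 7, seq_len)                 # [B, C, L]
    patches = x.unfold(-1, patch_len, stride)      # [B, C, patch_num, patch_len]
    print(patches.shape, (seq_len - patch_len) // stride + 1)  # torch.Size([2, 7, 11, 16]) 11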
302
layers/SelfAttention_Family.py
Normal file
@ -0,0 +1,302 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
from math import sqrt
|
||||
from utils.masking import TriangularCausalMask, ProbMask
|
||||
from reformer_pytorch import LSHSelfAttention
|
||||
from einops import rearrange, repeat
|
||||
|
||||
|
||||
class DSAttention(nn.Module):
|
||||
'''De-stationary Attention'''
|
||||
|
||||
def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False):
|
||||
super(DSAttention, self).__init__()
|
||||
self.scale = scale
|
||||
self.mask_flag = mask_flag
|
||||
self.output_attention = output_attention
|
||||
self.dropout = nn.Dropout(attention_dropout)
|
||||
|
||||
def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
|
||||
B, L, H, E = queries.shape
|
||||
_, S, _, D = values.shape
|
||||
scale = self.scale or 1. / sqrt(E)
|
||||
|
||||
tau = 1.0 if tau is None else tau.unsqueeze(
|
||||
1).unsqueeze(1) # B x 1 x 1 x 1
|
||||
delta = 0.0 if delta is None else delta.unsqueeze(
|
||||
1).unsqueeze(1) # B x 1 x 1 x S
|
||||
|
||||
# De-stationary Attention, rescaling pre-softmax score with learned de-stationary factors
|
||||
scores = torch.einsum("blhe,bshe->bhls", queries, keys) * tau + delta
|
||||
|
||||
if self.mask_flag:
|
||||
if attn_mask is None:
|
||||
attn_mask = TriangularCausalMask(B, L, device=queries.device)
|
||||
|
||||
scores.masked_fill_(attn_mask.mask, -np.inf)
|
||||
|
||||
A = self.dropout(torch.softmax(scale * scores, dim=-1))
|
||||
V = torch.einsum("bhls,bshd->blhd", A, values)
|
||||
|
||||
if self.output_attention:
|
||||
return V.contiguous(), A
|
||||
else:
|
||||
return V.contiguous(), None
|
||||
|
||||
|
||||
class FullAttention(nn.Module):
|
||||
def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False):
|
||||
super(FullAttention, self).__init__()
|
||||
self.scale = scale
|
||||
self.mask_flag = mask_flag
|
||||
self.output_attention = output_attention
|
||||
self.dropout = nn.Dropout(attention_dropout)
|
||||
|
||||
def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
|
||||
B, L, H, E = queries.shape
|
||||
_, S, _, D = values.shape
|
||||
scale = self.scale or 1. / sqrt(E)
|
||||
|
||||
scores = torch.einsum("blhe,bshe->bhls", queries, keys)
|
||||
|
||||
if self.mask_flag:
|
||||
if attn_mask is None:
|
||||
attn_mask = TriangularCausalMask(B, L, device=queries.device)
|
||||
|
||||
scores.masked_fill_(attn_mask.mask, -np.inf)
|
||||
|
||||
A = self.dropout(torch.softmax(scale * scores, dim=-1))
|
||||
V = torch.einsum("bhls,bshd->blhd", A, values)
|
||||
|
||||
if self.output_attention:
|
||||
return V.contiguous(), A
|
||||
else:
|
||||
return V.contiguous(), None
|
||||
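# A minimal usage sketch (illustrative): with a causal mask and output_attention=True the
# layer returns both the mixed values and the [B, H, L, L] attention map.
def _demo_full_attention():
    B, L, H, E = 2, 16, 4, 8
    attn = FullAttention(mask_flag=True, attention_dropout=0.0, output_attention=True)
    q = torch.randn(B, L, H, E)
    out, A = attn(q, q, q, attn_mask=None)
    print(out.shape, A.shape)  # torch.Size([2, 16, 4, 8]) torch.Size([2, 4, 16, 16])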
|
||||
|
||||
class ProbAttention(nn.Module):
|
||||
def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False):
|
||||
super(ProbAttention, self).__init__()
|
||||
self.factor = factor
|
||||
self.scale = scale
|
||||
self.mask_flag = mask_flag
|
||||
self.output_attention = output_attention
|
||||
self.dropout = nn.Dropout(attention_dropout)
|
||||
|
||||
def _prob_QK(self, Q, K, sample_k, n_top): # n_top: c*ln(L_q)
|
||||
# Q [B, H, L, D]
|
||||
B, H, L_K, E = K.shape
|
||||
_, _, L_Q, _ = Q.shape
|
||||
|
||||
# calculate the sampled Q_K
|
||||
K_expand = K.unsqueeze(-3).expand(B, H, L_Q, L_K, E)
|
||||
# real U = U_part(factor*ln(L_k))*L_q
|
||||
index_sample = torch.randint(L_K, (L_Q, sample_k))
|
||||
K_sample = K_expand[:, :, torch.arange(
|
||||
L_Q).unsqueeze(1), index_sample, :]
|
||||
Q_K_sample = torch.matmul(
|
||||
Q.unsqueeze(-2), K_sample.transpose(-2, -1)).squeeze()
|
||||
|
||||
# find the Top_k queries with the sparsity measurement
|
||||
M = Q_K_sample.max(-1)[0] - torch.div(Q_K_sample.sum(-1), L_K)
|
||||
M_top = M.topk(n_top, sorted=False)[1]
|
||||
|
||||
# use the reduced Q to calculate Q_K
|
||||
Q_reduce = Q[torch.arange(B)[:, None, None],
|
||||
torch.arange(H)[None, :, None],
|
||||
M_top, :] # factor*ln(L_q)
|
||||
Q_K = torch.matmul(Q_reduce, K.transpose(-2, -1)) # factor*ln(L_q)*L_k
|
||||
|
||||
return Q_K, M_top
|
||||
|
||||
def _get_initial_context(self, V, L_Q):
|
||||
B, H, L_V, D = V.shape
|
||||
if not self.mask_flag:
|
||||
# V_sum = V.sum(dim=-2)
|
||||
V_sum = V.mean(dim=-2)
|
||||
contex = V_sum.unsqueeze(-2).expand(B, H,
|
||||
L_Q, V_sum.shape[-1]).clone()
|
||||
else: # use mask
|
||||
# requires that L_Q == L_V, i.e. for self-attention only
|
||||
assert (L_Q == L_V)
|
||||
contex = V.cumsum(dim=-2)
|
||||
return contex
|
||||
|
||||
def _update_context(self, context_in, V, scores, index, L_Q, attn_mask):
|
||||
B, H, L_V, D = V.shape
|
||||
|
||||
if self.mask_flag:
|
||||
attn_mask = ProbMask(B, H, L_Q, index, scores, device=V.device)
|
||||
scores.masked_fill_(attn_mask.mask, -np.inf)
|
||||
|
||||
attn = torch.softmax(scores, dim=-1) # nn.Softmax(dim=-1)(scores)
|
||||
|
||||
context_in[torch.arange(B)[:, None, None],
|
||||
torch.arange(H)[None, :, None],
|
||||
index, :] = torch.matmul(attn, V).type_as(context_in)
|
||||
if self.output_attention:
|
||||
attns = (torch.ones([B, H, L_V, L_V]) /
|
||||
L_V).type_as(attn).to(attn.device)
|
||||
attns[torch.arange(B)[:, None, None], torch.arange(H)[
|
||||
None, :, None], index, :] = attn
|
||||
return context_in, attns
|
||||
else:
|
||||
return context_in, None
|
||||
|
||||
def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
|
||||
B, L_Q, H, D = queries.shape
|
||||
_, L_K, _, _ = keys.shape
|
||||
|
||||
queries = queries.transpose(2, 1)
|
||||
keys = keys.transpose(2, 1)
|
||||
values = values.transpose(2, 1)
|
||||
|
||||
U_part = self.factor * \
|
||||
np.ceil(np.log(L_K)).astype('int').item() # c*ln(L_k)
|
||||
u = self.factor * \
|
||||
np.ceil(np.log(L_Q)).astype('int').item() # c*ln(L_q)
|
||||
|
||||
U_part = U_part if U_part < L_K else L_K
|
||||
u = u if u < L_Q else L_Q
|
||||
|
||||
scores_top, index = self._prob_QK(
|
||||
queries, keys, sample_k=U_part, n_top=u)
|
||||
|
||||
# add scale factor
|
||||
scale = self.scale or 1. / sqrt(D)
|
||||
if scale is not None:
|
||||
scores_top = scores_top * scale
|
||||
# get the context
|
||||
context = self._get_initial_context(values, L_Q)
|
||||
# update the context with selected top_k queries
|
||||
context, attn = self._update_context(
|
||||
context, values, scores_top, index, L_Q, attn_mask)
|
||||
|
||||
return context.contiguous(), attn
|
||||
|
||||
|
||||
class AttentionLayer(nn.Module):
|
||||
def __init__(self, attention, d_model, n_heads, d_keys=None,
|
||||
d_values=None):
|
||||
super(AttentionLayer, self).__init__()
|
||||
|
||||
d_keys = d_keys or (d_model // n_heads)
|
||||
d_values = d_values or (d_model // n_heads)
|
||||
|
||||
self.inner_attention = attention
|
||||
self.query_projection = nn.Linear(d_model, d_keys * n_heads)
|
||||
self.key_projection = nn.Linear(d_model, d_keys * n_heads)
|
||||
self.value_projection = nn.Linear(d_model, d_values * n_heads)
|
||||
self.out_projection = nn.Linear(d_values * n_heads, d_model)
|
||||
self.n_heads = n_heads
|
||||
|
||||
def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
|
||||
B, L, _ = queries.shape
|
||||
_, S, _ = keys.shape
|
||||
H = self.n_heads
|
||||
|
||||
queries = self.query_projection(queries).view(B, L, H, -1)
|
||||
keys = self.key_projection(keys).view(B, S, H, -1)
|
||||
values = self.value_projection(values).view(B, S, H, -1)
|
||||
|
||||
out, attn = self.inner_attention(
|
||||
queries,
|
||||
keys,
|
||||
values,
|
||||
attn_mask,
|
||||
tau=tau,
|
||||
delta=delta
|
||||
)
|
||||
out = out.view(B, L, -1)
|
||||
|
||||
return self.out_projection(out), attn
|
||||
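# A minimal cross-attention sketch (illustrative): AttentionLayer projects queries, keys
# and values onto n_heads heads, runs the wrapped attention, and projects back to d_model.
def _demo_attention_layer():
    B, L, S, d_model, n_heads = 2, 24, 48, 64, 4
    layer = AttentionLayer(FullAttention(mask_flag=False, attention_dropout=0.0),
                           d_model, n_heads)
    q = torch.randn(B, L, d_model)
    kv = torch.randn(B, S, d_model)
    out, _ = layer(q, kv, kv, attn_mask=None)
    print(out.shape)  # torch.Size([2, 24, 64])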
|
||||
|
||||
class ReformerLayer(nn.Module):
|
||||
def __init__(self, attention, d_model, n_heads, d_keys=None,
|
||||
d_values=None, causal=False, bucket_size=4, n_hashes=4):
|
||||
super().__init__()
|
||||
self.bucket_size = bucket_size
|
||||
self.attn = LSHSelfAttention(
|
||||
dim=d_model,
|
||||
heads=n_heads,
|
||||
bucket_size=bucket_size,
|
||||
n_hashes=n_hashes,
|
||||
causal=causal
|
||||
)
|
||||
|
||||
def fit_length(self, queries):
|
||||
# inside reformer: assert N % (bucket_size * 2) == 0
|
||||
B, N, C = queries.shape
|
||||
if N % (self.bucket_size * 2) == 0:
|
||||
return queries
|
||||
else:
|
||||
# fill the time series
|
||||
fill_len = (self.bucket_size * 2) - (N % (self.bucket_size * 2))
|
||||
return torch.cat([queries, torch.zeros([B, fill_len, C]).to(queries.device)], dim=1)
|
||||
|
||||
def forward(self, queries, keys, values, attn_mask, tau, delta):
|
||||
# in Reformer: default queries=keys
|
||||
B, N, C = queries.shape
|
||||
queries = self.attn(self.fit_length(queries))[:, :N, :]
|
||||
return queries, None
|
||||
|
||||
|
||||
class TwoStageAttentionLayer(nn.Module):
|
||||
'''
|
||||
The Two Stage Attention (TSA) Layer
|
||||
input/output shape: [batch_size, Data_dim(D), Seg_num(L), d_model]
|
||||
'''
|
||||
|
||||
def __init__(self, configs,
|
||||
seg_num, factor, d_model, n_heads, d_ff=None, dropout=0.1):
|
||||
super(TwoStageAttentionLayer, self).__init__()
|
||||
d_ff = d_ff or 4 * d_model
|
||||
self.time_attention = AttentionLayer(FullAttention(False, configs.factor, attention_dropout=configs.dropout,
|
||||
output_attention=False), d_model, n_heads)
|
||||
self.dim_sender = AttentionLayer(FullAttention(False, configs.factor, attention_dropout=configs.dropout,
|
||||
output_attention=False), d_model, n_heads)
|
||||
self.dim_receiver = AttentionLayer(FullAttention(False, configs.factor, attention_dropout=configs.dropout,
|
||||
output_attention=False), d_model, n_heads)
|
||||
self.router = nn.Parameter(torch.randn(seg_num, factor, d_model))
|
||||
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
self.norm1 = nn.LayerNorm(d_model)
|
||||
self.norm2 = nn.LayerNorm(d_model)
|
||||
self.norm3 = nn.LayerNorm(d_model)
|
||||
self.norm4 = nn.LayerNorm(d_model)
|
||||
|
||||
self.MLP1 = nn.Sequential(nn.Linear(d_model, d_ff),
|
||||
nn.GELU(),
|
||||
nn.Linear(d_ff, d_model))
|
||||
self.MLP2 = nn.Sequential(nn.Linear(d_model, d_ff),
|
||||
nn.GELU(),
|
||||
nn.Linear(d_ff, d_model))
|
||||
|
||||
def forward(self, x, attn_mask=None, tau=None, delta=None):
|
||||
# Cross Time Stage: Directly apply MSA to each dimension
|
||||
batch = x.shape[0]
|
||||
time_in = rearrange(x, 'b ts_d seg_num d_model -> (b ts_d) seg_num d_model')
|
||||
time_enc, attn = self.time_attention(
|
||||
time_in, time_in, time_in, attn_mask=None, tau=None, delta=None
|
||||
)
|
||||
dim_in = time_in + self.dropout(time_enc)
|
||||
dim_in = self.norm1(dim_in)
|
||||
dim_in = dim_in + self.dropout(self.MLP1(dim_in))
|
||||
dim_in = self.norm2(dim_in)
|
||||
|
||||
# Cross Dimension Stage: use a small set of learnable vectors to aggregate and distribute messages to build the D-to-D connection
|
||||
dim_send = rearrange(dim_in, '(b ts_d) seg_num d_model -> (b seg_num) ts_d d_model', b=batch)
|
||||
batch_router = repeat(self.router, 'seg_num factor d_model -> (repeat seg_num) factor d_model', repeat=batch)
|
||||
dim_buffer, attn = self.dim_sender(batch_router, dim_send, dim_send, attn_mask=None, tau=None, delta=None)
|
||||
dim_receive, attn = self.dim_receiver(dim_send, dim_buffer, dim_buffer, attn_mask=None, tau=None, delta=None)
|
||||
dim_enc = dim_send + self.dropout(dim_receive)
|
||||
dim_enc = self.norm3(dim_enc)
|
||||
dim_enc = dim_enc + self.dropout(self.MLP2(dim_enc))
|
||||
dim_enc = self.norm4(dim_enc)
|
||||
|
||||
final_out = rearrange(dim_enc, '(b seg_num) ts_d d_model -> b ts_d seg_num d_model', b=batch)
|
||||
|
||||
return final_out
|
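# A minimal usage sketch (illustrative): the layer only reads `factor` and `dropout` from
# the configs object, so a SimpleNamespace stand-in is enough for a smoke test.
def _demo_two_stage_attention():
    from types import SimpleNamespace
    cfg = SimpleNamespace(factor=3, dropout=0.1)
    B, D, seg_num, d_model, n_heads = 2, 7, 12, 64, 4
    layer = TwoStageAttentionLayer(cfg, seg_num, factor=cfg.factor,
                                   d_model=d_model, n_heads=n_heads)
    x = torch.randn(B, D, seg_num, d_model)
    print(layer(x).shape)  # torch.Size([2, 7, 12, 64])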
68
layers/StandardNorm.py
Executable file
@ -0,0 +1,68 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class Normalize(nn.Module):
|
||||
def __init__(self, num_features: int, eps=1e-5, affine=False, subtract_last=False, non_norm=False):
|
||||
"""
|
||||
:param num_features: the number of features or channels
|
||||
:param eps: a value added for numerical stability
|
||||
:param affine: if True, RevIN has learnable affine parameters
|
||||
"""
|
||||
super(Normalize, self).__init__()
|
||||
self.num_features = num_features
|
||||
self.eps = eps
|
||||
self.affine = affine
|
||||
self.subtract_last = subtract_last
|
||||
self.non_norm = non_norm
|
||||
if self.affine:
|
||||
self._init_params()
|
||||
|
||||
def forward(self, x, mode: str):
|
||||
if mode == 'norm':
|
||||
self._get_statistics(x)
|
||||
x = self._normalize(x)
|
||||
elif mode == 'denorm':
|
||||
x = self._denormalize(x)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return x
|
||||
|
||||
def _init_params(self):
|
||||
# initialize RevIN params: (C,)
|
||||
self.affine_weight = nn.Parameter(torch.ones(self.num_features))
|
||||
self.affine_bias = nn.Parameter(torch.zeros(self.num_features))
|
||||
|
||||
def _get_statistics(self, x):
|
||||
dim2reduce = tuple(range(1, x.ndim - 1))
|
||||
if self.subtract_last:
|
||||
self.last = x[:, -1, :].unsqueeze(1)
|
||||
else:
|
||||
self.mean = torch.mean(x, dim=dim2reduce, keepdim=True).detach()
|
||||
self.stdev = torch.sqrt(torch.var(x, dim=dim2reduce, keepdim=True, unbiased=False) + self.eps).detach()
|
||||
|
||||
def _normalize(self, x):
|
||||
if self.non_norm:
|
||||
return x
|
||||
if self.subtract_last:
|
||||
x = x - self.last
|
||||
else:
|
||||
x = x - self.mean
|
||||
x = x / self.stdev
|
||||
if self.affine:
|
||||
x = x * self.affine_weight
|
||||
x = x + self.affine_bias
|
||||
return x
|
||||
|
||||
def _denormalize(self, x):
|
||||
if self.non_norm:
|
||||
return x
|
||||
if self.affine:
|
||||
x = x - self.affine_bias
|
||||
x = x / (self.affine_weight + self.eps * self.eps)
|
||||
x = x * self.stdev
|
||||
if self.subtract_last:
|
||||
x = x + self.last
|
||||
else:
|
||||
x = x + self.mean
|
||||
return x
|
91
layers/TSTEncoder.py
Normal file
@ -0,0 +1,91 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from layers.Embed import PositionalEmbedding
|
||||
from layers.SelfAttention_Family import FullAttention, AttentionLayer
|
||||
from layers.Transformer_EncDec import EncoderLayer
|
||||
|
||||
class TSTEncoder(nn.Module):
|
||||
"""
|
||||
Transformer encoder for PatchTST, adapted for Time-Series-Library-main style
|
||||
"""
|
||||
def __init__(self, q_len, d_model, n_heads, d_k=None, d_v=None, d_ff=None,
|
||||
norm='BatchNorm', attn_dropout=0., dropout=0., activation='gelu',
|
||||
n_layers=1):
|
||||
super().__init__()
|
||||
|
||||
d_k = d_model // n_heads if d_k is None else d_k
|
||||
d_v = d_model // n_heads if d_v is None else d_v
|
||||
d_ff = d_model * 4 if d_ff is None else d_ff
|
||||
|
||||
self.layers = nn.ModuleList([
|
||||
EncoderLayer(
|
||||
AttentionLayer(
|
||||
FullAttention(False, attention_dropout=attn_dropout),
|
||||
d_model, n_heads
|
||||
),
|
||||
d_model,
|
||||
d_ff,
|
||||
dropout=dropout,
|
||||
activation=activation
|
||||
) for i in range(n_layers)
|
||||
])
|
||||
|
||||
def forward(self, src, attn_mask=None):
|
||||
output = src
|
||||
attns = []
|
||||
for layer in self.layers:
|
||||
output, attn = layer(output, attn_mask)
|
||||
attns.append(attn)
|
||||
return output, attns
|
||||
|
||||
|
||||
class TSTiEncoder(nn.Module):
|
||||
"""
|
||||
Channel-independent TST Encoder adapted for Time-Series-Library-main
|
||||
"""
|
||||
def __init__(self, c_in, patch_num, patch_len, max_seq_len=1024,
|
||||
n_layers=3, d_model=128, n_heads=16, d_k=None, d_v=None,
|
||||
d_ff=256, norm='BatchNorm', attn_dropout=0., dropout=0.,
|
||||
activation="gelu"):
|
||||
super().__init__()
|
||||
|
||||
self.patch_num = patch_num
|
||||
self.patch_len = patch_len
|
||||
|
||||
# Input encoding - projection of feature vectors onto a d-dim vector space
|
||||
self.W_P = nn.Linear(patch_len, d_model)
|
||||
|
||||
# Positional encoding using Time-Series-Library-main's PositionalEmbedding
|
||||
self.pos_embedding = PositionalEmbedding(d_model, max_len=max_seq_len)
|
||||
|
||||
# Residual dropout
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
# Encoder
|
||||
self.encoder = TSTEncoder(patch_num, d_model, n_heads, d_k=d_k, d_v=d_v,
|
||||
d_ff=d_ff, norm=norm, attn_dropout=attn_dropout,
|
||||
dropout=dropout, activation=activation, n_layers=n_layers)
|
||||
|
||||
def forward(self, x):
|
||||
# x: [bs x nvars x patch_num x patch_len]
|
||||
bs, n_vars, patch_num, patch_len = x.shape
|
||||
|
||||
# Input encoding: project patch_len to d_model
|
||||
x = self.W_P(x) # x: [bs x nvars x patch_num x d_model]
|
||||
|
||||
# Reshape for attention: combine batch and channel dimensions
|
||||
u = torch.reshape(x, (bs * n_vars, patch_num, x.shape[-1])) # u: [bs * nvars x patch_num x d_model]
|
||||
|
||||
# Add positional encoding
|
||||
pos = self.pos_embedding(u) # Get positional encoding [bs*nvars x patch_num x d_model]
|
||||
u = self.dropout(u + pos[:, :patch_num, :]) # Add positional encoding
|
||||
|
||||
# Encoder
|
||||
z, attns = self.encoder(u) # z: [bs * nvars x patch_num x d_model]
|
||||
|
||||
# Reshape back to separate batch and channel dimensions
|
||||
z = torch.reshape(z, (bs, n_vars, patch_num, z.shape[-1])) # z: [bs x nvars x patch_num x d_model]
|
||||
z = z.permute(0, 1, 3, 2) # z: [bs x nvars x d_model x patch_num]
|
||||
|
||||
return z
|
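# A minimal shape sketch (illustrative): patched input [bs, nvars, patch_num, patch_len]
# comes back as [bs, nvars, d_model, patch_num].
def _demo_tsti_encoder():
    bs, n_vars, patch_num, patch_len = 2, 7, 11, 16
    enc = TSTiEncoder(c_in=n_vars, patch_num=patch_num, patch_len=patch_len,
                      n_layers=2, d_model=64, n_heads=4)
    x = torch.randn(bs, n_vars, patch_num, patch_len)
    print(enc(x).shape)  # torch.Size([2, 7, 64, 11])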
135
layers/Transformer_EncDec.py
Normal file
@ -0,0 +1,135 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class ConvLayer(nn.Module):
|
||||
def __init__(self, c_in):
|
||||
super(ConvLayer, self).__init__()
|
||||
self.downConv = nn.Conv1d(in_channels=c_in,
|
||||
out_channels=c_in,
|
||||
kernel_size=3,
|
||||
padding=2,
|
||||
padding_mode='circular')
|
||||
self.norm = nn.BatchNorm1d(c_in)
|
||||
self.activation = nn.ELU()
|
||||
self.maxPool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.downConv(x.permute(0, 2, 1))
|
||||
x = self.norm(x)
|
||||
x = self.activation(x)
|
||||
x = self.maxPool(x)
|
||||
x = x.transpose(1, 2)
|
||||
return x
|
||||
|
||||
|
||||
class EncoderLayer(nn.Module):
|
||||
def __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation="relu"):
|
||||
super(EncoderLayer, self).__init__()
|
||||
d_ff = d_ff or 4 * d_model
|
||||
self.attention = attention
|
||||
self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
|
||||
self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)
|
||||
self.norm1 = nn.LayerNorm(d_model)
|
||||
self.norm2 = nn.LayerNorm(d_model)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.activation = F.relu if activation == "relu" else F.gelu
|
||||
|
||||
def forward(self, x, attn_mask=None, tau=None, delta=None):
|
||||
new_x, attn = self.attention(
|
||||
x, x, x,
|
||||
attn_mask=attn_mask,
|
||||
tau=tau, delta=delta
|
||||
)
|
||||
x = x + self.dropout(new_x)
|
||||
|
||||
y = x = self.norm1(x)
|
||||
y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
|
||||
y = self.dropout(self.conv2(y).transpose(-1, 1))
|
||||
|
||||
return self.norm2(x + y), attn
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
def __init__(self, attn_layers, conv_layers=None, norm_layer=None):
|
||||
super(Encoder, self).__init__()
|
||||
self.attn_layers = nn.ModuleList(attn_layers)
|
||||
self.conv_layers = nn.ModuleList(conv_layers) if conv_layers is not None else None
|
||||
self.norm = norm_layer
|
||||
|
||||
def forward(self, x, attn_mask=None, tau=None, delta=None):
|
||||
# x [B, L, D]
|
||||
attns = []
|
||||
if self.conv_layers is not None:
|
||||
for i, (attn_layer, conv_layer) in enumerate(zip(self.attn_layers, self.conv_layers)):
|
||||
delta = delta if i == 0 else None
|
||||
x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta)
|
||||
x = conv_layer(x)
|
||||
attns.append(attn)
|
||||
x, attn = self.attn_layers[-1](x, tau=tau, delta=None)
|
||||
attns.append(attn)
|
||||
else:
|
||||
for attn_layer in self.attn_layers:
|
||||
x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta)
|
||||
attns.append(attn)
|
||||
|
||||
if self.norm is not None:
|
||||
x = self.norm(x)
|
||||
|
||||
return x, attns
|
||||
|
||||
|
||||
class DecoderLayer(nn.Module):
|
||||
def __init__(self, self_attention, cross_attention, d_model, d_ff=None,
|
||||
dropout=0.1, activation="relu"):
|
||||
super(DecoderLayer, self).__init__()
|
||||
d_ff = d_ff or 4 * d_model
|
||||
self.self_attention = self_attention
|
||||
self.cross_attention = cross_attention
|
||||
self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
|
||||
self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)
|
||||
self.norm1 = nn.LayerNorm(d_model)
|
||||
self.norm2 = nn.LayerNorm(d_model)
|
||||
self.norm3 = nn.LayerNorm(d_model)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.activation = F.relu if activation == "relu" else F.gelu
|
||||
|
||||
def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None):
|
||||
x = x + self.dropout(self.self_attention(
|
||||
x, x, x,
|
||||
attn_mask=x_mask,
|
||||
tau=tau, delta=None
|
||||
)[0])
|
||||
x = self.norm1(x)
|
||||
|
||||
x = x + self.dropout(self.cross_attention(
|
||||
x, cross, cross,
|
||||
attn_mask=cross_mask,
|
||||
tau=tau, delta=delta
|
||||
)[0])
|
||||
|
||||
y = x = self.norm2(x)
|
||||
y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
|
||||
y = self.dropout(self.conv2(y).transpose(-1, 1))
|
||||
|
||||
return self.norm3(x + y)
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
def __init__(self, layers, norm_layer=None, projection=None):
|
||||
super(Decoder, self).__init__()
|
||||
self.layers = nn.ModuleList(layers)
|
||||
self.norm = norm_layer
|
||||
self.projection = projection
|
||||
|
||||
def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None):
|
||||
for layer in self.layers:
|
||||
x = layer(x, cross, x_mask=x_mask, cross_mask=cross_mask, tau=tau, delta=delta)
|
||||
|
||||
if self.norm is not None:
|
||||
x = self.norm(x)
|
||||
|
||||
if self.projection is not None:
|
||||
x = self.projection(x)
|
||||
return x
|
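# A minimal usage sketch (illustrative): wiring the vanilla Encoder with attention layers
# from layers.SelfAttention_Family (the import below assumes this repository layout).
def _demo_transformer_encoder():
    from layers.SelfAttention_Family import FullAttention, AttentionLayer
    B, L, d_model, n_heads = 2, 32, 64, 4
    enc = Encoder(
        [EncoderLayer(AttentionLayer(FullAttention(False, attention_dropout=0.0),
                                     d_model, n_heads),
                      d_model, d_ff=128, dropout=0.0) for _ in range(2)],
        norm_layer=nn.LayerNorm(d_model))
    x = torch.randn(B, L, d_model)
    out, attns = enc(x)
    print(out.shape)  # torch.Size([2, 32, 64])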
0
layers/__init__.py
Normal file