feat: add TimesNet_Q and xPatch models with Q matrix transformations

layers/decomp.py (new file)
@@ -0,0 +1,21 @@
import torch
from torch import nn

from layers.ema import EMA
from layers.dema import DEMA

class DECOMP(nn.Module):
    """
    Series decomposition block
    """
    def __init__(self, ma_type, alpha, beta):
        super(DECOMP, self).__init__()
        if ma_type == 'ema':
            self.ma = EMA(alpha)
        elif ma_type == 'dema':
            self.ma = DEMA(alpha, beta)

    def forward(self, x):
        moving_average = self.ma(x)
        res = x - moving_average
        return res, moving_average
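
A minimal usage sketch for DECOMP (the shapes and the ma_type/alpha/beta values are illustrative; a CUDA device is assumed because EMA/DEMA move their coefficients to 'cuda'):

import torch
from layers.decomp import DECOMP

# Illustrative values; alpha/beta are passed as tensors so DEMA can .to('cuda') them.
decomp = DECOMP('dema', torch.tensor(0.3), torch.tensor(0.3))
x = torch.randn(8, 96, 7, device='cuda')     # [Batch, Input, Channel]
seasonal, trend = decomp(x)                   # residual and moving-average parts
print(seasonal.shape, trend.shape)            # both torch.Size([8, 96, 7])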

layers/dema.py (new file)
@@ -0,0 +1,27 @@
import torch
from torch import nn

class DEMA(nn.Module):
    """
    Double Exponential Moving Average (DEMA) block to highlight the trend of time series
    """
    def __init__(self, alpha, beta):
        super(DEMA, self).__init__()
        # self.alpha = nn.Parameter(alpha)  # Learnable alpha
        # self.beta = nn.Parameter(beta)    # Learnable beta
        self.alpha = alpha.to(device='cuda')
        self.beta = beta.to(device='cuda')

    def forward(self, x):
        # self.alpha.data.clamp_(0, 1)  # Clamp learnable alpha to [0, 1]
        # self.beta.data.clamp_(0, 1)   # Clamp learnable beta to [0, 1]
        s_prev = x[:, 0, :]
        b = x[:, 1, :] - s_prev
        res = [s_prev.unsqueeze(1)]
        for t in range(1, x.shape[1]):
            xt = x[:, t, :]
            s = self.alpha * xt + (1 - self.alpha) * (s_prev + b)
            b = self.beta * (s - s_prev) + (1 - self.beta) * b
            s_prev = s
            res.append(s.unsqueeze(1))
        return torch.cat(res, dim=1)
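
The recursion above is standard double exponential (Holt) smoothing: the extra slope term b lets the filter follow a trending series with far less lag than a single EMA. A tiny sketch on a linear ramp (illustrative values, CUDA assumed since the module moves alpha/beta to 'cuda'):

import torch
from layers.dema import DEMA

dema = DEMA(torch.tensor(0.3), torch.tensor(0.3))
ramp = torch.arange(96, dtype=torch.float32, device='cuda').reshape(1, 96, 1)  # pure linear trend
smoothed = dema(ramp)
print((ramp - smoothed).abs().max())   # ~0: the slope term b tracks the ramp exactly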

layers/ema.py (new file)
@@ -0,0 +1,37 @@
import torch
from torch import nn

class EMA(nn.Module):
    """
    Exponential Moving Average (EMA) block to highlight the trend of time series
    """
    def __init__(self, alpha):
        super(EMA, self).__init__()
        # self.alpha = nn.Parameter(alpha)  # Learnable alpha
        self.alpha = alpha

    # Optimized, vectorized implementation (no Python loop over time steps)
    def forward(self, x):
        # x: [Batch, Input, Channel]
        # self.alpha.data.clamp_(0, 1)  # Clamp learnable alpha to [0, 1]
        _, t, _ = x.shape
        powers = torch.flip(torch.arange(t, dtype=torch.double), dims=(0,))
        weights = torch.pow((1 - self.alpha), powers).to('cuda')
        divisor = weights.clone()
        weights[1:] = weights[1:] * self.alpha
        weights = weights.reshape(1, t, 1)
        divisor = divisor.reshape(1, t, 1)
        x = torch.cumsum(x * weights, dim=1)
        x = torch.div(x, divisor)
        return x.to(torch.float32)

    # # Naive implementation with an explicit Python loop over time steps
    # def forward(self, x):
    #     # self.alpha.data.clamp_(0, 1)  # Clamp learnable alpha to [0, 1]
    #     s = x[:, 0, :]
    #     res = [s.unsqueeze(1)]
    #     for t in range(1, x.shape[1]):
    #         xt = x[:, t, :]
    #         s = self.alpha * xt + (1 - self.alpha) * s
    #         res.append(s.unsqueeze(1))
    #     return torch.cat(res, dim=1)
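
Because the cumulative-sum weighting in the optimized forward is not obvious, here is a small CPU-only re-derivation (a sketch, not the module itself, so the hard-coded .to('cuda') is dropped) checking that it matches the recursion s_t = alpha * x_t + (1 - alpha) * s_{t-1} used in the commented-out naive version:

import torch

def ema_vectorized(x, alpha):                  # x: [Batch, Input, Channel]
    _, t, _ = x.shape
    powers = torch.flip(torch.arange(t, dtype=torch.double), dims=(0,))
    weights = torch.pow(1 - alpha, powers)     # (1 - alpha)^(t-1-j)
    divisor = weights.clone()
    weights[1:] = weights[1:] * alpha          # alpha * (1 - alpha)^(t-1-j) for j >= 1
    weights = weights.reshape(1, t, 1)
    divisor = divisor.reshape(1, t, 1)
    return (torch.cumsum(x.double() * weights, dim=1) / divisor).float()

def ema_recursive(x, alpha):
    s = x[:, 0, :]
    out = [s.unsqueeze(1)]
    for i in range(1, x.shape[1]):
        s = alpha * x[:, i, :] + (1 - alpha) * s
        out.append(s.unsqueeze(1))
    return torch.cat(out, dim=1)

x = torch.randn(2, 32, 3)
assert torch.allclose(ema_vectorized(x, 0.3), ema_recursive(x, 0.3), atol=1e-5)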

layers/revin.py (new file)
@@ -0,0 +1,61 @@
import torch
from torch import nn

class RevIN(nn.Module):
    def __init__(self, num_features: int, eps=1e-5, affine=True, subtract_last=False):
        """
        :param num_features: the number of features or channels
        :param eps: a value added for numerical stability
        :param affine: if True, RevIN has learnable affine parameters
        """
        super(RevIN, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.affine = affine
        self.subtract_last = subtract_last
        if self.affine:
            self._init_params()

    def forward(self, x, mode: str):
        if mode == 'norm':
            self._get_statistics(x)
            x = self._normalize(x)
        elif mode == 'denorm':
            x = self._denormalize(x)
        else:
            raise NotImplementedError
        return x

    def _init_params(self):
        # initialize RevIN params: (C,)
        self.affine_weight = nn.Parameter(torch.ones(self.num_features))
        self.affine_bias = nn.Parameter(torch.zeros(self.num_features))

    def _get_statistics(self, x):
        dim2reduce = tuple(range(1, x.ndim - 1))
        if self.subtract_last:
            self.last = x[:, -1, :].unsqueeze(1)
        else:
            self.mean = torch.mean(x, dim=dim2reduce, keepdim=True).detach()
        self.stdev = torch.sqrt(torch.var(x, dim=dim2reduce, keepdim=True, unbiased=False) + self.eps).detach()

    def _normalize(self, x):
        if self.subtract_last:
            x = x - self.last
        else:
            x = x - self.mean
        x = x / self.stdev
        if self.affine:
            x = x * self.affine_weight
            x = x + self.affine_bias
        return x

    def _denormalize(self, x):
        if self.affine:
            x = x - self.affine_bias
            x = x / (self.affine_weight + self.eps * self.eps)
        x = x * self.stdev
        if self.subtract_last:
            x = x + self.last
        else:
            x = x + self.mean
        return x
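
A quick round-trip sketch (shapes illustrative): with the default affine parameters, 'denorm' applied to the normalized tensor recovers the input up to the eps terms.

import torch
from layers.revin import RevIN

revin = RevIN(num_features=7)
x = torch.randn(4, 96, 7)          # [Batch, Input, Channel]
x_norm = revin(x, 'norm')          # stores per-series mean and stdev
x_back = revin(x_norm, 'denorm')   # inverts using the stored statistics
assert torch.allclose(x, x_back, atol=1e-4)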

layers/telu.py (new file)
@@ -0,0 +1,30 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

class TeLU(nn.Module):
    """
    Implements the TeLU activation function proposed in the paper
    "TeLU Activation Function for Fast and Stable Deep Learning".
    Formula: TeLU(x) = x * tanh(e^x)
    """
    def __init__(self):
        """
        TeLU has no learnable parameters, so __init__ is trivial.
        """
        super(TeLU, self).__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward computation.
        """
        # Apply the formula directly
        return x * torch.tanh(torch.exp(x))

    def __repr__(self):
        """
        (Optional but recommended) Readable string representation,
        so the module prints nicely in model summaries.
        """
        return f"{self.__class__.__name__}()"
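
A quick numeric check of the activation's shape: TeLU(x) = x * tanh(exp(x)) behaves like the identity for large positive inputs and decays toward zero for very negative ones (printed values approximate):

import torch
from layers.telu import TeLU

act = TeLU()
x = torch.tensor([-6.0, -1.0, 0.0, 1.0, 6.0])
print(act(x))   # ~[-0.015, -0.352, 0.000, 0.991, 6.000]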

(modified file, imports switched to package-relative)
@@ -3,8 +3,8 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import numpy as np
-from RevIN import RevIN
-from Trans_EncDec import Encoder_ori, LinearEncoder
+from ..RevIN import RevIN
+from .Trans_EncDec import Encoder_ori, LinearEncoder

(modified file, import switched to package-relative)
@@ -1 +1 @@
-from model import *
+from .model import *

models/TimesNet_Q/TimesNet_Q.py (new file)
@@ -0,0 +1,221 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.fft
import numpy as np
import os
from layers.Embed import DataEmbedding
from layers.Conv_Blocks import Inception_Block_V1


def FFT_for_Period(x, k=2):
    # [B, T, C]
    xf = torch.fft.rfft(x, dim=1)
    # find period by amplitudes
    frequency_list = abs(xf).mean(0).mean(-1)
    frequency_list[0] = 0
    _, top_list = torch.topk(frequency_list, k)
    top_list = top_list.detach().cpu().numpy()
    period = x.shape[1] // top_list
    return period, abs(xf).mean(-1)[:, top_list]


class TimesBlock(nn.Module):
    """Original TimesBlock without Q matrix transformation"""
    def __init__(self, configs):
        super(TimesBlock, self).__init__()
        self.seq_len = configs.seq_len
        self.pred_len = configs.pred_len
        self.k = configs.top_k
        # parameter-efficient design
        self.conv = nn.Sequential(
            Inception_Block_V1(configs.d_model, configs.d_ff,
                               num_kernels=configs.num_kernels),
            nn.GELU(),
            Inception_Block_V1(configs.d_ff, configs.d_model,
                               num_kernels=configs.num_kernels)
        )

    def forward(self, x):
        B, T, N = x.size()
        period_list, period_weight = FFT_for_Period(x, self.k)

        res = []
        for i in range(self.k):
            period = period_list[i]
            # padding
            if (self.seq_len + self.pred_len) % period != 0:
                length = (
                    ((self.seq_len + self.pred_len) // period) + 1) * period
                padding = torch.zeros([x.shape[0], (length - (self.seq_len + self.pred_len)), x.shape[2]]).to(x.device)
                out = torch.cat([x, padding], dim=1)
            else:
                length = (self.seq_len + self.pred_len)
                out = x
            # reshape
            out = out.reshape(B, length // period, period,
                              N).permute(0, 3, 1, 2).contiguous()
            # 2D conv: from 1d Variation to 2d Variation
            out = self.conv(out)
            # reshape back
            out = out.permute(0, 2, 3, 1).reshape(B, -1, N)
            res.append(out[:, :(self.seq_len + self.pred_len), :])
        res = torch.stack(res, dim=-1)
        # adaptive aggregation
        period_weight = F.softmax(period_weight, dim=1)
        period_weight = period_weight.unsqueeze(
            1).unsqueeze(1).repeat(1, T, N, 1)
        res = torch.sum(res * period_weight, -1)
        # residual connection
        res = res + x
        return res


class Model(nn.Module):
    """
    TimesNet with Q matrix transformation
    - Applies the input Q matrix transformation before embedding
    - Uses the original TimesBlock logic
    - Applies the output Q matrix transformation before de-normalization
    Only long/short-term forecasting is implemented.
    """

    def __init__(self, configs):
        super(Model, self).__init__()
        self.configs = configs
        self.task_name = configs.task_name
        self.seq_len = configs.seq_len
        self.label_len = configs.label_len
        self.pred_len = configs.pred_len

        # Load Q matrices
        self.load_Q_matrices(configs)

        # Model layers
        self.model = nn.ModuleList([TimesBlock(configs)
                                    for _ in range(configs.e_layers)])
        self.enc_embedding = DataEmbedding(configs.enc_in, configs.d_model, configs.embed, configs.freq,
                                           configs.dropout)
        self.layer = configs.e_layers
        self.layer_norm = nn.LayerNorm(configs.d_model)

        # Only forecast-related layers are implemented
        self.predict_linear = nn.Linear(
            self.seq_len, self.pred_len + self.seq_len)
        self.projection = nn.Linear(
            configs.d_model, configs.c_out, bias=True)

    def load_Q_matrices(self, configs):
        """Load pre-computed Q matrices for the input and output transformations"""
        # Get the dataset name from configs, defaulting to ETTm1 if not specified
        dataset_name = getattr(configs, 'dataset', 'ETTm1')

        # Input Q matrix (seq_len)
        input_q_path = f'cov_mats/{dataset_name}/{dataset_name}_{configs.seq_len}_ratio1.0.npy'
        # Output Q matrix (pred_len)
        output_q_path = f'cov_mats/{dataset_name}/{dataset_name}_{configs.pred_len}_ratio1.0.npy'

        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        if os.path.exists(input_q_path):
            Q_input = np.load(input_q_path)
            self.register_buffer('Q_input', torch.FloatTensor(Q_input).to(device))
            print(f"Loaded input Q matrix from {input_q_path}, shape: {Q_input.shape}")
        else:
            print(f"Warning: Input Q matrix not found at {input_q_path}, using identity matrix")
            self.register_buffer('Q_input', torch.eye(configs.seq_len).to(device))

        if os.path.exists(output_q_path):
            Q_output = np.load(output_q_path)
            self.register_buffer('Q_output', torch.FloatTensor(Q_output).to(device))
            print(f"Loaded output Q matrix from {output_q_path}, shape: {Q_output.shape}")
        else:
            print(f"Warning: Output Q matrix not found at {output_q_path}, using identity matrix")
            self.register_buffer('Q_output', torch.eye(configs.pred_len).to(device))

    def apply_input_Q_transformation(self, x):
        """
        Apply the input Q matrix transformation before embedding.
        Input: x with shape [B, T, N] where T = seq_len
        Output: transformed x with shape [B, T, N]
        """
        B, T, N = x.size()

        # Transpose to [B, N, T] for matrix multiplication
        x_transposed = x.transpose(-1, -2)  # [B, N, T]

        # Apply the input Q transformation: einsum 'bnt,tv->bnv'
        # x_transposed: [B, N, T], Q_input.T: [T, T] -> result: [B, N, T]
        x_trans = torch.einsum('bnt,tv->bnv', x_transposed, self.Q_input.transpose(-1, -2))

        # Transpose back to [B, T, N]
        x_transformed = x_trans.transpose(-1, -2)  # [B, T, N]

        return x_transformed

    def apply_output_Q_transformation(self, x):
        """
        Apply the output Q matrix transformation to the prediction output.
        Input: x with shape [B, pred_len, N]
        Output: transformed x with shape [B, pred_len, N]
        """
        B, T, N = x.size()
        assert T == self.pred_len, f"Expected pred_len {self.pred_len}, got {T}"

        # Transpose to [B, N, T] for matrix multiplication
        x_transposed = x.transpose(-1, -2)  # [B, N, pred_len]

        # Apply the output Q transformation: einsum 'bnt,tv->bnv'
        # x_transposed: [B, N, pred_len], Q_output: [pred_len, pred_len] -> result: [B, N, pred_len]
        x_trans = torch.einsum('bnt,tv->bnv', x_transposed, self.Q_output)

        # Transpose back to [B, pred_len, N]
        x_transformed = x_trans.transpose(-1, -2)  # [B, pred_len, N]

        return x_transformed

    def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
        # Normalization from Non-stationary Transformer
        means = x_enc.mean(1, keepdim=True).detach()
        x_enc = x_enc.sub(means)
        stdev = torch.sqrt(
            torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
        x_enc = x_enc.div(stdev)

        # Apply the input Q matrix transformation before embedding
        x_enc_transformed = self.apply_input_Q_transformation(x_enc)

        # embedding with transformed input
        enc_out = self.enc_embedding(x_enc_transformed, x_mark_enc)  # [B, T, C]
        enc_out = self.predict_linear(enc_out.permute(0, 2, 1)).permute(
            0, 2, 1)  # align temporal dimension

        # TimesNet blocks (original logic, no Q transformation)
        for i in range(self.layer):
            enc_out = self.layer_norm(self.model[i](enc_out))

        # project back
        dec_out = self.projection(enc_out)

        # Extract the prediction part and apply the output Q transformation
        pred_out = dec_out[:, -self.pred_len:, :]  # [B, pred_len, N]
        pred_out_transformed = self.apply_output_Q_transformation(pred_out)

        # De-Normalization from Non-stationary Transformer
        pred_out_transformed = pred_out_transformed.mul(
            (stdev[:, 0, :].unsqueeze(1).repeat(
                1, self.pred_len, 1)))
        pred_out_transformed = pred_out_transformed.add(
            (means[:, 0, :].unsqueeze(1).repeat(
                1, self.pred_len, 1)))

        return pred_out_transformed

    def forward(self, x_enc, x_mark_enc=None, x_dec=None, x_mark_dec=None, mask=None):
        # Only long_term_forecast and short_term_forecast are supported
        if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
            dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
            return dec_out  # [B, pred_len, D]
        else:
            raise NotImplementedError(f"Task {self.task_name} is not implemented in TimesNet_Q")
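
The Q transformations above are plain matrix multiplications along the time axis; the einsum is equivalent to a batched matmul. A small sketch with a random stand-in Q (the real matrices come from the cov_mats/*.npy files):

import torch

B, T, N = 2, 96, 7
x = torch.randn(B, T, N)          # [B, T, N], as in apply_input_Q_transformation
Q = torch.randn(T, T)             # stand-in for a loaded Q_input buffer

a = torch.einsum('bnt,tv->bnv', x.transpose(-1, -2), Q.transpose(-1, -2))
b = x.transpose(-1, -2) @ Q.transpose(-1, -2)    # the same batched matmul
assert torch.allclose(a, b, atol=1e-5)
assert a.transpose(-1, -2).shape == x.shape      # transposed back: still [B, T, N]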

models/TimesNet_Q/__init__.py (new file)
@@ -0,0 +1,3 @@
from .TimesNet_Q import Model

__all__ = ['Model']

models/xPatch/network.py (new file)
@@ -0,0 +1,132 @@
import torch
from torch import nn

class Network(nn.Module):
    def __init__(self, seq_len, pred_len, patch_len, stride, padding_patch):
        super(Network, self).__init__()

        # Parameters
        self.pred_len = pred_len

        # Non-linear Stream
        # Patching
        self.patch_len = patch_len
        self.stride = stride
        self.padding_patch = padding_patch
        self.dim = patch_len * patch_len
        self.patch_num = (seq_len - patch_len) // stride + 1
        if padding_patch == 'end':  # can be modified to the general case
            self.padding_patch_layer = nn.ReplicationPad1d((0, stride))
            self.patch_num += 1

        # Patch Embedding
        self.fc1 = nn.Linear(patch_len, self.dim)
        self.gelu1 = nn.GELU()
        self.bn1 = nn.BatchNorm1d(self.patch_num)

        # CNN Depthwise
        self.conv1 = nn.Conv1d(self.patch_num, self.patch_num,
                               patch_len, patch_len, groups=self.patch_num)
        self.gelu2 = nn.GELU()
        self.bn2 = nn.BatchNorm1d(self.patch_num)

        # Residual Stream
        self.fc2 = nn.Linear(self.dim, patch_len)

        # CNN Pointwise
        self.conv2 = nn.Conv1d(self.patch_num, self.patch_num, 1, 1)
        self.gelu3 = nn.GELU()
        self.bn3 = nn.BatchNorm1d(self.patch_num)

        # Flatten Head
        self.flatten1 = nn.Flatten(start_dim=-2)
        self.fc3 = nn.Linear(self.patch_num * patch_len, pred_len * 2)
        self.gelu4 = nn.GELU()
        self.fc4 = nn.Linear(pred_len * 2, pred_len)

        # Linear Stream
        # MLP
        self.fc5 = nn.Linear(seq_len, pred_len * 4)
        self.avgpool1 = nn.AvgPool1d(kernel_size=2)
        self.ln1 = nn.LayerNorm(pred_len * 2)

        self.fc6 = nn.Linear(pred_len * 2, pred_len)
        self.avgpool2 = nn.AvgPool1d(kernel_size=2)
        self.ln2 = nn.LayerNorm(pred_len // 2)

        self.fc7 = nn.Linear(pred_len // 2, pred_len)

        # Streams Concatenation
        self.fc8 = nn.Linear(pred_len * 2, pred_len)

    def forward(self, s, t):
        # x: [Batch, Input, Channel]
        # s - seasonality
        # t - trend

        s = s.permute(0, 2, 1)  # to [Batch, Channel, Input]
        t = t.permute(0, 2, 1)  # to [Batch, Channel, Input]

        # Channel split for channel independence
        B = s.shape[0]  # Batch size
        C = s.shape[1]  # Channel size
        I = s.shape[2]  # Input size
        s = torch.reshape(s, (B * C, I))  # [Batch and Channel, Input]
        t = torch.reshape(t, (B * C, I))  # [Batch and Channel, Input]

        # Non-linear Stream
        # Patching
        if self.padding_patch == 'end':
            s = self.padding_patch_layer(s)
        s = s.unfold(dimension=-1, size=self.patch_len, step=self.stride)
        # s: [Batch and Channel, Patch_num, Patch_len]

        # Patch Embedding
        s = self.fc1(s)
        s = self.gelu1(s)
        s = self.bn1(s)

        res = s

        # CNN Depthwise
        s = self.conv1(s)
        s = self.gelu2(s)
        s = self.bn2(s)

        # Residual Stream
        res = self.fc2(res)
        s = s + res

        # CNN Pointwise
        s = self.conv2(s)
        s = self.gelu3(s)
        s = self.bn3(s)

        # Flatten Head
        s = self.flatten1(s)
        s = self.fc3(s)
        s = self.gelu4(s)
        s = self.fc4(s)

        # Linear Stream
        # MLP
        t = self.fc5(t)
        t = self.avgpool1(t)
        t = self.ln1(t)

        t = self.fc6(t)
        t = self.avgpool2(t)
        t = self.ln2(t)

        t = self.fc7(t)

        # Streams Concatenation
        x = torch.cat((s, t), dim=1)
        x = self.fc8(x)

        # Channel concatenation
        x = torch.reshape(x, (B, C, self.pred_len))  # [Batch, Channel, Output]

        x = x.permute(0, 2, 1)  # to [Batch, Output, Channel]

        return x
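
A shape-trace sketch for the dual-stream network (the sizes seq_len=96, pred_len=96, patch_len=16, stride=8, padding_patch='end' are illustrative, not repo defaults): with these values patch_num = (96 - 16) // 8 + 1 + 1 = 12, and both streams end in a pred_len-sized vector per channel.

import torch
from models.xPatch.network import Network

net = Network(seq_len=96, pred_len=96, patch_len=16, stride=8, padding_patch='end')
s = torch.randn(4, 96, 7)    # seasonal part, [Batch, Input, Channel]
t = torch.randn(4, 96, 7)    # trend part,    [Batch, Input, Channel]
y = net(s, t)
print(y.shape)               # torch.Size([4, 96, 7]) -> [Batch, Output, Channel]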

models/xPatch/xPatch.py (new file)
@@ -0,0 +1,58 @@
import torch
import torch.nn as nn
import math

from layers.decomp import DECOMP
from .network import Network
# from layers.network_mlp import NetworkMLP  # For ablation study with MLP-only stream
# from layers.network_cnn import NetworkCNN  # For ablation study with CNN-only stream
from layers.revin import RevIN

class Model(nn.Module):
    def __init__(self, configs):
        super(Model, self).__init__()

        # Parameters
        seq_len = configs.seq_len    # lookback window L
        pred_len = configs.pred_len  # prediction length (96, 192, 336, 720)
        c_in = configs.enc_in        # input channels

        # Patching
        patch_len = configs.patch_len
        stride = configs.stride
        padding_patch = configs.padding_patch

        # Normalization
        self.revin = configs.revin
        self.revin_layer = RevIN(c_in, affine=True, subtract_last=False)

        # Moving Average
        self.ma_type = configs.ma_type
        alpha = configs.alpha  # smoothing factor for EMA (Exponential Moving Average)
        beta = configs.beta    # smoothing factor for DEMA (Double Exponential Moving Average)

        self.decomp = DECOMP(self.ma_type, alpha, beta)
        self.net = Network(seq_len, pred_len, patch_len, stride, padding_patch)
        # self.net_mlp = NetworkMLP(seq_len, pred_len)  # For ablation study with MLP-only stream
        # self.net_cnn = NetworkCNN(seq_len, pred_len, patch_len, stride, padding_patch)  # For ablation study with CNN-only stream

    def forward(self, x):
        # x: [Batch, Input, Channel]

        # Normalization
        if self.revin:
            x = self.revin_layer(x, 'norm')

        if self.ma_type == 'reg':  # If no decomposition, directly pass the input to the network
            x = self.net(x, x)
            # x = self.net_mlp(x)  # For ablation study with MLP-only stream
            # x = self.net_cnn(x)  # For ablation study with CNN-only stream
        else:
            seasonal_init, trend_init = self.decomp(x)
            x = self.net(seasonal_init, trend_init)

        # Denormalization
        if self.revin:
            x = self.revin_layer(x, 'denorm')

        return x
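
A hypothetical end-to-end sketch: the config attribute names are the ones read above, the values are illustrative, and ma_type='reg' skips decomposition so the EMA/DEMA CUDA requirement does not apply.

import torch
from types import SimpleNamespace
from models.xPatch.xPatch import Model

cfg = SimpleNamespace(seq_len=96, pred_len=96, enc_in=7,
                      patch_len=16, stride=8, padding_patch='end',
                      revin=1, ma_type='reg', alpha=0.3, beta=0.3)
model = Model(cfg)
x = torch.randn(4, 96, 7)      # [Batch, Input, Channel]
print(model(x).shape)          # torch.Size([4, 96, 7])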

models/xPatch_Q/__init__.py (new file)
@@ -0,0 +1 @@
from .xPatch_Q import Model

models/xPatch_Q/network.py (new file, identical to models/xPatch/network.py)
@@ -0,0 +1,132 @@
import torch
from torch import nn

class Network(nn.Module):
    def __init__(self, seq_len, pred_len, patch_len, stride, padding_patch):
        super(Network, self).__init__()

        # Parameters
        self.pred_len = pred_len

        # Non-linear Stream
        # Patching
        self.patch_len = patch_len
        self.stride = stride
        self.padding_patch = padding_patch
        self.dim = patch_len * patch_len
        self.patch_num = (seq_len - patch_len) // stride + 1
        if padding_patch == 'end':  # can be modified to the general case
            self.padding_patch_layer = nn.ReplicationPad1d((0, stride))
            self.patch_num += 1

        # Patch Embedding
        self.fc1 = nn.Linear(patch_len, self.dim)
        self.gelu1 = nn.GELU()
        self.bn1 = nn.BatchNorm1d(self.patch_num)

        # CNN Depthwise
        self.conv1 = nn.Conv1d(self.patch_num, self.patch_num,
                               patch_len, patch_len, groups=self.patch_num)
        self.gelu2 = nn.GELU()
        self.bn2 = nn.BatchNorm1d(self.patch_num)

        # Residual Stream
        self.fc2 = nn.Linear(self.dim, patch_len)

        # CNN Pointwise
        self.conv2 = nn.Conv1d(self.patch_num, self.patch_num, 1, 1)
        self.gelu3 = nn.GELU()
        self.bn3 = nn.BatchNorm1d(self.patch_num)

        # Flatten Head
        self.flatten1 = nn.Flatten(start_dim=-2)
        self.fc3 = nn.Linear(self.patch_num * patch_len, pred_len * 2)
        self.gelu4 = nn.GELU()
        self.fc4 = nn.Linear(pred_len * 2, pred_len)

        # Linear Stream
        # MLP
        self.fc5 = nn.Linear(seq_len, pred_len * 4)
        self.avgpool1 = nn.AvgPool1d(kernel_size=2)
        self.ln1 = nn.LayerNorm(pred_len * 2)

        self.fc6 = nn.Linear(pred_len * 2, pred_len)
        self.avgpool2 = nn.AvgPool1d(kernel_size=2)
        self.ln2 = nn.LayerNorm(pred_len // 2)

        self.fc7 = nn.Linear(pred_len // 2, pred_len)

        # Streams Concatenation
        self.fc8 = nn.Linear(pred_len * 2, pred_len)

    def forward(self, s, t):
        # x: [Batch, Input, Channel]
        # s - seasonality
        # t - trend

        s = s.permute(0, 2, 1)  # to [Batch, Channel, Input]
        t = t.permute(0, 2, 1)  # to [Batch, Channel, Input]

        # Channel split for channel independence
        B = s.shape[0]  # Batch size
        C = s.shape[1]  # Channel size
        I = s.shape[2]  # Input size
        s = torch.reshape(s, (B * C, I))  # [Batch and Channel, Input]
        t = torch.reshape(t, (B * C, I))  # [Batch and Channel, Input]

        # Non-linear Stream
        # Patching
        if self.padding_patch == 'end':
            s = self.padding_patch_layer(s)
        s = s.unfold(dimension=-1, size=self.patch_len, step=self.stride)
        # s: [Batch and Channel, Patch_num, Patch_len]

        # Patch Embedding
        s = self.fc1(s)
        s = self.gelu1(s)
        s = self.bn1(s)

        res = s

        # CNN Depthwise
        s = self.conv1(s)
        s = self.gelu2(s)
        s = self.bn2(s)

        # Residual Stream
        res = self.fc2(res)
        s = s + res

        # CNN Pointwise
        s = self.conv2(s)
        s = self.gelu3(s)
        s = self.bn3(s)

        # Flatten Head
        s = self.flatten1(s)
        s = self.fc3(s)
        s = self.gelu4(s)
        s = self.fc4(s)

        # Linear Stream
        # MLP
        t = self.fc5(t)
        t = self.avgpool1(t)
        t = self.ln1(t)

        t = self.fc6(t)
        t = self.avgpool2(t)
        t = self.ln2(t)

        t = self.fc7(t)

        # Streams Concatenation
        x = torch.cat((s, t), dim=1)
        x = self.fc8(x)

        # Channel concatenation
        x = torch.reshape(x, (B, C, self.pred_len))  # [Batch, Channel, Output]

        x = x.permute(0, 2, 1)  # to [Batch, Output, Channel]

        return x

models/xPatch_Q/xPatch_Q.py (new file)
@@ -0,0 +1,146 @@
import torch
import torch.nn as nn
import math
import numpy as np
import os

from layers.decomp import DECOMP
from .network import Network
from layers.revin import RevIN

class Model(nn.Module):
    """
    xPatch with Q matrix transformation
    - Applies RevIN normalization first
    - Applies the input Q matrix transformation after RevIN normalization (based on dataset and seq_len)
    - Uses the original xPatch logic (decomposition + dual-stream network)
    - Applies the output Q matrix transformation before RevIN denormalization (based on dataset and pred_len)
    """

    def __init__(self, configs):
        super(Model, self).__init__()

        # Parameters
        seq_len = configs.seq_len    # lookback window L
        pred_len = configs.pred_len  # prediction length (96, 192, 336, 720)
        c_in = configs.enc_in        # input channels

        # Patching
        patch_len = configs.patch_len
        stride = configs.stride
        padding_patch = configs.padding_patch

        # Store for Q matrix transformations
        self.seq_len = seq_len
        self.pred_len = pred_len

        # Load Q matrices
        self.load_Q_matrices(configs)

        # Normalization
        self.revin = configs.revin
        self.revin_layer = RevIN(c_in, affine=True, subtract_last=False)

        # Moving Average
        self.ma_type = configs.ma_type
        alpha = configs.alpha  # smoothing factor for EMA (Exponential Moving Average)
        beta = configs.beta    # smoothing factor for DEMA (Double Exponential Moving Average)

        self.decomp = DECOMP(self.ma_type, alpha, beta)
        self.net = Network(seq_len, pred_len, patch_len, stride, padding_patch)

    def load_Q_matrices(self, configs):
        """Load pre-computed Q matrices for the input and output transformations"""
        # Get the dataset name from configs, defaulting to ETTm1 if not specified
        dataset_name = getattr(configs, 'dataset', 'ETTm1')

        # Input Q matrix (seq_len)
        input_q_path = f'cov_mats/{dataset_name}/{dataset_name}_{configs.seq_len}_ratio1.0.npy'
        # Output Q matrix (pred_len)
        output_q_path = f'cov_mats/{dataset_name}/{dataset_name}_{configs.pred_len}_ratio1.0.npy'

        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        if os.path.exists(input_q_path):
            Q_input = np.load(input_q_path)
            self.register_buffer('Q_input', torch.FloatTensor(Q_input).to(device))
            print(f"Loaded input Q matrix from {input_q_path}, shape: {Q_input.shape}")
        else:
            print(f"Warning: Input Q matrix not found at {input_q_path}, using identity matrix")
            self.register_buffer('Q_input', torch.eye(configs.seq_len).to(device))

        if os.path.exists(output_q_path):
            Q_output = np.load(output_q_path)
            self.register_buffer('Q_output', torch.FloatTensor(Q_output).to(device))
            print(f"Loaded output Q matrix from {output_q_path}, shape: {Q_output.shape}")
        else:
            print(f"Warning: Output Q matrix not found at {output_q_path}, using identity matrix")
            self.register_buffer('Q_output', torch.eye(configs.pred_len).to(device))

    def apply_input_Q_transformation(self, x):
        """
        Apply the input Q matrix transformation after RevIN normalization.
        Input: x with shape [B, T, N] where T = seq_len
        Output: transformed x with shape [B, T, N]
        """
        B, T, N = x.size()
        assert T == self.seq_len, f"Expected seq_len {self.seq_len}, got {T}"

        # Transpose to [B, N, T] for matrix multiplication
        x_transposed = x.transpose(-1, -2)  # [B, N, T]

        # Apply the input Q transformation: einsum 'bnt,tv->bnv'
        # x_transposed: [B, N, T], Q_input.T: [T, T] -> result: [B, N, T]
        x_trans = torch.einsum('bnt,tv->bnv', x_transposed, self.Q_input.transpose(-1, -2))

        # Transpose back to [B, T, N]
        x_transformed = x_trans.transpose(-1, -2)  # [B, T, N]

        return x_transformed

    def apply_output_Q_transformation(self, x):
        """
        Apply the output Q matrix transformation to the prediction output.
        Input: x with shape [B, pred_len, N]
        Output: transformed x with shape [B, pred_len, N]
        """
        B, T, N = x.size()
        assert T == self.pred_len, f"Expected pred_len {self.pred_len}, got {T}"

        # Transpose to [B, N, T] for matrix multiplication
        x_transposed = x.transpose(-1, -2)  # [B, N, pred_len]

        # Apply the output Q transformation: einsum 'bnt,tv->bnv'
        # x_transposed: [B, N, pred_len], Q_output: [pred_len, pred_len] -> result: [B, N, pred_len]
        x_trans = torch.einsum('bnt,tv->bnv', x_transposed, self.Q_output)

        # Transpose back to [B, pred_len, N]
        x_transformed = x_trans.transpose(-1, -2)  # [B, pred_len, N]

        return x_transformed

    def forward(self, x):
        # x: [Batch, Input, Channel]

        # RevIN Normalization
        if self.revin:
            x = self.revin_layer(x, 'norm')

        # Apply the input Q matrix transformation after RevIN normalization
        x_transformed = self.apply_input_Q_transformation(x)

        # xPatch processing with the Q-transformed input
        if self.ma_type == 'reg':  # If no decomposition, directly pass the input to the network
            output = self.net(x_transformed, x_transformed)
        else:
            seasonal_init, trend_init = self.decomp(x_transformed)
            output = self.net(seasonal_init, trend_init)

        # Apply the output Q matrix transformation to the prediction
        output_transformed = self.apply_output_Q_transformation(output)

        # RevIN Denormalization
        if self.revin:
            output_transformed = self.revin_layer(output_transformed, 'denorm')

        return output_transformed
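
One property worth noting: when neither cov_mats/*.npy file exists, load_Q_matrices registers identity matrices, and an identity Q makes both transformations no-ops, so the Q-model should reduce to the plain xPatch pipeline above. A sketch of the identity case (Q here is just torch.eye, not a real loaded matrix):

import torch

x = torch.randn(2, 96, 7)          # [B, T, N]
Q = torch.eye(96)                  # identity fallback used when the .npy file is missing

x_t = torch.einsum('bnt,tv->bnv', x.transpose(-1, -2), Q.transpose(-1, -2)).transpose(-1, -2)
assert torch.allclose(x_t, x)      # identity Q leaves the series unchanged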