diff --git a/layers/decomp.py b/layers/decomp.py new file mode 100644 index 0000000..99bb5c5 --- /dev/null +++ b/layers/decomp.py @@ -0,0 +1,21 @@ +import torch +from torch import nn + +from layers.ema import EMA +from layers.dema import DEMA + +class DECOMP(nn.Module): + """ + Series decomposition block + """ + def __init__(self, ma_type, alpha, beta): + super(DECOMP, self).__init__() + if ma_type == 'ema': + self.ma = EMA(alpha) + elif ma_type == 'dema': + self.ma = DEMA(alpha, beta) + + def forward(self, x): + moving_average = self.ma(x) + res = x - moving_average + return res, moving_average diff --git a/layers/dema.py b/layers/dema.py new file mode 100644 index 0000000..4e7004d --- /dev/null +++ b/layers/dema.py @@ -0,0 +1,27 @@ +import torch +from torch import nn + +class DEMA(nn.Module): + """ + Double Exponential Moving Average (DEMA) block to highlight the trend of time series + """ + def __init__(self, alpha, beta): + super(DEMA, self).__init__() + # self.alpha = nn.Parameter(alpha) # Learnable alpha + # self.beta = nn.Parameter(beta) # Learnable beta + self.alpha = alpha.to(device='cuda') + self.beta = beta.to(device='cuda') + + def forward(self, x): + # self.alpha.data.clamp_(0, 1) # Clamp learnable alpha to [0, 1] + # self.beta.data.clamp_(0, 1) # Clamp learnable beta to [0, 1] + s_prev = x[:, 0, :] + b = x[:, 1, :] - s_prev + res = [s_prev.unsqueeze(1)] + for t in range(1, x.shape[1]): + xt = x[:, t, :] + s = self.alpha * xt + (1 - self.alpha) * (s_prev + b) + b = self.beta * (s - s_prev) + (1 - self.beta) * b + s_prev = s + res.append(s.unsqueeze(1)) + return torch.cat(res, dim=1) diff --git a/layers/ema.py b/layers/ema.py new file mode 100644 index 0000000..3e5cc95 --- /dev/null +++ b/layers/ema.py @@ -0,0 +1,37 @@ +import torch +from torch import nn + +class EMA(nn.Module): + """ + Exponential Moving Average (EMA) block to highlight the trend of time series + """ + def __init__(self, alpha): + super(EMA, self).__init__() + # self.alpha = nn.Parameter(alpha) # Learnable alpha + self.alpha = alpha + + # Optimized implementation with O(1) time complexity + def forward(self, x): + # x: [Batch, Input, Channel] + # self.alpha.data.clamp_(0, 1) # Clamp learnable alpha to [0, 1] + _, t, _ = x.shape + powers = torch.flip(torch.arange(t, dtype=torch.double), dims=(0,)) + weights = torch.pow((1 - self.alpha), powers).to('cuda') + divisor = weights.clone() + weights[1:] = weights[1:] * self.alpha + weights = weights.reshape(1, t, 1) + divisor = divisor.reshape(1, t, 1) + x = torch.cumsum(x * weights, dim=1) + x = torch.div(x, divisor) + return x.to(torch.float32) + + # # Naive implementation with O(n) time complexity + # def forward(self, x): + # # self.alpha.data.clamp_(0, 1) # Clamp learnable alpha to [0, 1] + # s = x[:, 0, :] + # res = [s.unsqueeze(1)] + # for t in range(1, x.shape[1]): + # xt = x[:, t, :] + # s = self.alpha * xt + (1 - self.alpha) * s + # res.append(s.unsqueeze(1)) + # return torch.cat(res, dim=1) diff --git a/layers/revin.py b/layers/revin.py new file mode 100644 index 0000000..113d29b --- /dev/null +++ b/layers/revin.py @@ -0,0 +1,61 @@ +import torch +from torch import nn + +class RevIN(nn.Module): + def __init__(self, num_features: int, eps=1e-5, affine=True, subtract_last=False): + """ + :param num_features: the number of features or channels + :param eps: a value added for numerical stability + :param affine: if True, RevIN has learnable affine parameters + """ + super(RevIN, self).__init__() + self.num_features = num_features + self.eps = eps + self.affine = affine + self.subtract_last = subtract_last + if self.affine: + self._init_params() + + def forward(self, x, mode:str): + if mode == 'norm': + self._get_statistics(x) + x = self._normalize(x) + elif mode == 'denorm': + x = self._denormalize(x) + else: raise NotImplementedError + return x + + def _init_params(self): + # initialize RevIN params: (C,) + self.affine_weight = nn.Parameter(torch.ones(self.num_features)) + self.affine_bias = nn.Parameter(torch.zeros(self.num_features)) + + def _get_statistics(self, x): + dim2reduce = tuple(range(1, x.ndim-1)) + if self.subtract_last: + self.last = x[:,-1,:].unsqueeze(1) + else: + self.mean = torch.mean(x, dim=dim2reduce, keepdim=True).detach() + self.stdev = torch.sqrt(torch.var(x, dim=dim2reduce, keepdim=True, unbiased=False) + self.eps).detach() + + def _normalize(self, x): + if self.subtract_last: + x = x - self.last + else: + x = x - self.mean + x = x / self.stdev + if self.affine: + x = x * self.affine_weight + x = x + self.affine_bias + return x + + def _denormalize(self, x): + if self.affine: + x = x - self.affine_bias + x = x / (self.affine_weight + self.eps*self.eps) + x = x * self.stdev + if self.subtract_last: + x = x + self.last + else: + x =x + self.mean + return x diff --git a/layers/telu.py b/layers/telu.py new file mode 100644 index 0000000..c61214b --- /dev/null +++ b/layers/telu.py @@ -0,0 +1,30 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +class TeLU(nn.Module): + """ + 实现论文中提出的 TeLU 激活函数。 + 论文: TeLU Activation Function for Fast and Stable Deep Learning + 公式: TeLU(x) = x * tanh(e^x) + """ + def __init__(self): + """ + TeLU 激活函数没有可学习的参数,所以 __init__ 方法很简单。 + """ + super(TeLU, self).__init__() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + 前向传播的计算逻辑。 + """ + # 直接应用公式 + return x * torch.tanh(torch.exp(x)) + + def __repr__(self): + """ + (可选但推荐) 定义一个好的字符串表示,方便打印模型结构。 + """ + return f"{self.__class__.__name__}()" + + diff --git a/models/OLinear/model.py b/models/OLinear/model.py index 1974e97..355f352 100644 --- a/models/OLinear/model.py +++ b/models/OLinear/model.py @@ -3,8 +3,8 @@ import torch import torch.nn as nn import torch.nn.functional as F import numpy as np -from RevIN import RevIN -from Trans_EncDec import Encoder_ori, LinearEncoder +from ..RevIN import RevIN +from .Trans_EncDec import Encoder_ori, LinearEncoder diff --git a/models/RevIN/__init__.py b/models/RevIN/__init__.py index 4a10a05..87cb736 100644 --- a/models/RevIN/__init__.py +++ b/models/RevIN/__init__.py @@ -1 +1 @@ -from model import * +from .model import * diff --git a/models/TimesNet_Q/TimesNet_Q.py b/models/TimesNet_Q/TimesNet_Q.py new file mode 100644 index 0000000..6e9fa7d --- /dev/null +++ b/models/TimesNet_Q/TimesNet_Q.py @@ -0,0 +1,221 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.fft +import numpy as np +import os +from layers.Embed import DataEmbedding +from layers.Conv_Blocks import Inception_Block_V1 + + +def FFT_for_Period(x, k=2): + # [B, T, C] + xf = torch.fft.rfft(x, dim=1) + # find period by amplitudes + frequency_list = abs(xf).mean(0).mean(-1) + frequency_list[0] = 0 + _, top_list = torch.topk(frequency_list, k) + top_list = top_list.detach().cpu().numpy() + period = x.shape[1] // top_list + return period, abs(xf).mean(-1)[:, top_list] + + +class TimesBlock(nn.Module): + """Original TimesBlock without Q matrix transformation""" + def __init__(self, configs): + super(TimesBlock, self).__init__() + self.seq_len = configs.seq_len + self.pred_len = configs.pred_len + self.k = configs.top_k + # parameter-efficient design + self.conv = nn.Sequential( + Inception_Block_V1(configs.d_model, configs.d_ff, + num_kernels=configs.num_kernels), + nn.GELU(), + Inception_Block_V1(configs.d_ff, configs.d_model, + num_kernels=configs.num_kernels) + ) + + def forward(self, x): + B, T, N = x.size() + period_list, period_weight = FFT_for_Period(x, self.k) + + res = [] + for i in range(self.k): + period = period_list[i] + # padding + if (self.seq_len + self.pred_len) % period != 0: + length = ( + ((self.seq_len + self.pred_len) // period) + 1) * period + padding = torch.zeros([x.shape[0], (length - (self.seq_len + self.pred_len)), x.shape[2]]).to(x.device) + out = torch.cat([x, padding], dim=1) + else: + length = (self.seq_len + self.pred_len) + out = x + # reshape + out = out.reshape(B, length // period, period, + N).permute(0, 3, 1, 2).contiguous() + # 2D conv: from 1d Variation to 2d Variation + out = self.conv(out) + # reshape back + out = out.permute(0, 2, 3, 1).reshape(B, -1, N) + res.append(out[:, :(self.seq_len + self.pred_len), :]) + res = torch.stack(res, dim=-1) + # adaptive aggregation + period_weight = F.softmax(period_weight, dim=1) + period_weight = period_weight.unsqueeze( + 1).unsqueeze(1).repeat(1, T, N, 1) + res = torch.sum(res * period_weight, -1) + # residual connection + res = res + x + return res + + +class Model(nn.Module): + """ + TimesNet with Q matrix transformation + - Applies input Q matrix transformation before embedding + - Uses original TimesBlock logic + - Applies output Q matrix transformation before De-Normalization + Only implements long/short term forecasting + """ + + def __init__(self, configs): + super(Model, self).__init__() + self.configs = configs + self.task_name = configs.task_name + self.seq_len = configs.seq_len + self.label_len = configs.label_len + self.pred_len = configs.pred_len + + # Load Q matrices + self.load_Q_matrices(configs) + + # Model layers + self.model = nn.ModuleList([TimesBlock(configs) + for _ in range(configs.e_layers)]) + self.enc_embedding = DataEmbedding(configs.enc_in, configs.d_model, configs.embed, configs.freq, + configs.dropout) + self.layer = configs.e_layers + self.layer_norm = nn.LayerNorm(configs.d_model) + + # Only implement forecast-related layers + self.predict_linear = nn.Linear( + self.seq_len, self.pred_len + self.seq_len) + self.projection = nn.Linear( + configs.d_model, configs.c_out, bias=True) + + def load_Q_matrices(self, configs): + """Load pre-computed Q matrices for input and output transformations""" + # Get dataset name from configs, default to ETTm1 if not specified + dataset_name = getattr(configs, 'dataset', 'ETTm1') + + # Input Q matrix (seq_len) + input_q_path = f'cov_mats/{dataset_name}/{dataset_name}_{configs.seq_len}_ratio1.0.npy' + # Output Q matrix (pred_len) + output_q_path = f'cov_mats/{dataset_name}/{dataset_name}_{configs.pred_len}_ratio1.0.npy' + + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + if os.path.exists(input_q_path): + Q_input = np.load(input_q_path) + self.register_buffer('Q_input', torch.FloatTensor(Q_input).to(device)) + print(f"Loaded input Q matrix from {input_q_path}, shape: {Q_input.shape}") + else: + print(f"Warning: Input Q matrix not found at {input_q_path}, using identity matrix") + self.register_buffer('Q_input', torch.eye(configs.seq_len).to(device)) + + if os.path.exists(output_q_path): + Q_output = np.load(output_q_path) + self.register_buffer('Q_output', torch.FloatTensor(Q_output).to(device)) + print(f"Loaded output Q matrix from {output_q_path}, shape: {Q_output.shape}") + else: + print(f"Warning: Output Q matrix not found at {output_q_path}, using identity matrix") + self.register_buffer('Q_output', torch.eye(configs.pred_len).to(device)) + + def apply_input_Q_transformation(self, x): + """ + Apply input Q matrix transformation before embedding + Input: x with shape [B, T, N] where T = seq_len + Output: transformed x with shape [B, T, N] + """ + B, T, N = x.size() + + # Transpose to [B, N, T] for matrix multiplication + x_transposed = x.transpose(-1, -2) # [B, N, T] + + # Apply input Q transformation: einsum 'bnt,tv->bnv' + # x_transposed: [B, N, T], Q_input.T: [T, T] -> result: [B, N, T] + x_trans = torch.einsum('bnt,tv->bnv', x_transposed, self.Q_input.transpose(-1, -2)) + + # Transpose back to [B, T, N] + x_transformed = x_trans.transpose(-1, -2) # [B, T, N] + + return x_transformed + + def apply_output_Q_transformation(self, x): + """ + Apply output Q matrix transformation to prediction output + Input: x with shape [B, pred_len, N] + Output: transformed x with shape [B, pred_len, N] + """ + B, T, N = x.size() + assert T == self.pred_len, f"Expected pred_len {self.pred_len}, got {T}" + + # Transpose to [B, N, T] for matrix multiplication + x_transposed = x.transpose(-1, -2) # [B, N, pred_len] + + # Apply output Q transformation: einsum 'bnt,tv->bnv' + # x_transposed: [B, N, pred_len], Q_output: [pred_len, pred_len] -> result: [B, N, pred_len] + x_trans = torch.einsum('bnt,tv->bnv', x_transposed, self.Q_output) + + # Transpose back to [B, pred_len, N] + x_transformed = x_trans.transpose(-1, -2) # [B, pred_len, N] + + return x_transformed + + def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec): + # Normalization from Non-stationary Transformer + means = x_enc.mean(1, keepdim=True).detach() + x_enc = x_enc.sub(means) + stdev = torch.sqrt( + torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5) + x_enc = x_enc.div(stdev) + + # Apply input Q matrix transformation before embedding + x_enc_transformed = self.apply_input_Q_transformation(x_enc) + + # embedding with transformed input + enc_out = self.enc_embedding(x_enc_transformed, x_mark_enc) # [B,T,C] + enc_out = self.predict_linear(enc_out.permute(0, 2, 1)).permute( + 0, 2, 1) # align temporal dimension + + # TimesNet blocks (original logic, no Q transformation) + for i in range(self.layer): + enc_out = self.layer_norm(self.model[i](enc_out)) + + # project back + dec_out = self.projection(enc_out) + + # Extract prediction part and apply output Q transformation + pred_out = dec_out[:, -self.pred_len:, :] # [B, pred_len, N] + pred_out_transformed = self.apply_output_Q_transformation(pred_out) + + # De-Normalization from Non-stationary Transformer + pred_out_transformed = pred_out_transformed.mul( + (stdev[:, 0, :].unsqueeze(1).repeat( + 1, self.pred_len, 1))) + pred_out_transformed = pred_out_transformed.add( + (means[:, 0, :].unsqueeze(1).repeat( + 1, self.pred_len, 1))) + + return pred_out_transformed + + def forward(self, x_enc, x_mark_enc=None, x_dec=None, x_mark_dec=None, mask=None): + # Only support long_term_forecast and short_term_forecast + if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast': + dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec) + return dec_out # [B, pred_len, D] + else: + raise NotImplementedError(f"Task {self.task_name} is not implemented in TimesNet_Q") + return None \ No newline at end of file diff --git a/models/TimesNet_Q/__init__.py b/models/TimesNet_Q/__init__.py new file mode 100644 index 0000000..55e1ee5 --- /dev/null +++ b/models/TimesNet_Q/__init__.py @@ -0,0 +1,3 @@ +from .TimesNet_Q import Model + +__all__ = ['Model'] \ No newline at end of file diff --git a/models/xPatch/network.py b/models/xPatch/network.py new file mode 100644 index 0000000..fbbc5d6 --- /dev/null +++ b/models/xPatch/network.py @@ -0,0 +1,132 @@ +import torch +from torch import nn + +class Network(nn.Module): + def __init__(self, seq_len, pred_len, patch_len, stride, padding_patch): + super(Network, self).__init__() + + # Parameters + self.pred_len = pred_len + + # Non-linear Stream + # Patching + self.patch_len = patch_len + self.stride = stride + self.padding_patch = padding_patch + self.dim = patch_len * patch_len + self.patch_num = (seq_len - patch_len)//stride + 1 + if padding_patch == 'end': # can be modified to general case + self.padding_patch_layer = nn.ReplicationPad1d((0, stride)) + self.patch_num += 1 + + # Patch Embedding + self.fc1 = nn.Linear(patch_len, self.dim) + self.gelu1 = nn.GELU() + self.bn1 = nn.BatchNorm1d(self.patch_num) + + # CNN Depthwise + self.conv1 = nn.Conv1d(self.patch_num, self.patch_num, + patch_len, patch_len, groups=self.patch_num) + self.gelu2 = nn.GELU() + self.bn2 = nn.BatchNorm1d(self.patch_num) + + # Residual Stream + self.fc2 = nn.Linear(self.dim, patch_len) + + # CNN Pointwise + self.conv2 = nn.Conv1d(self.patch_num, self.patch_num, 1, 1) + self.gelu3 = nn.GELU() + self.bn3 = nn.BatchNorm1d(self.patch_num) + + # Flatten Head + self.flatten1 = nn.Flatten(start_dim=-2) + self.fc3 = nn.Linear(self.patch_num * patch_len, pred_len * 2) + self.gelu4 = nn.GELU() + self.fc4 = nn.Linear(pred_len * 2, pred_len) + + # Linear Stream + # MLP + self.fc5 = nn.Linear(seq_len, pred_len * 4) + self.avgpool1 = nn.AvgPool1d(kernel_size=2) + self.ln1 = nn.LayerNorm(pred_len * 2) + + self.fc6 = nn.Linear(pred_len * 2, pred_len) + self.avgpool2 = nn.AvgPool1d(kernel_size=2) + self.ln2 = nn.LayerNorm(pred_len // 2) + + self.fc7 = nn.Linear(pred_len // 2, pred_len) + + # Streams Concatination + self.fc8 = nn.Linear(pred_len * 2, pred_len) + + def forward(self, s, t): + # x: [Batch, Input, Channel] + # s - seasonality + # t - trend + + s = s.permute(0,2,1) # to [Batch, Channel, Input] + t = t.permute(0,2,1) # to [Batch, Channel, Input] + + # Channel split for channel independence + B = s.shape[0] # Batch size + C = s.shape[1] # Channel size + I = s.shape[2] # Input size + s = torch.reshape(s, (B*C, I)) # [Batch and Channel, Input] + t = torch.reshape(t, (B*C, I)) # [Batch and Channel, Input] + + # Non-linear Stream + # Patching + if self.padding_patch == 'end': + s = self.padding_patch_layer(s) + s = s.unfold(dimension=-1, size=self.patch_len, step=self.stride) + # s: [Batch and Channel, Patch_num, Patch_len] + + # Patch Embedding + s = self.fc1(s) + s = self.gelu1(s) + s = self.bn1(s) + + res = s + + # CNN Depthwise + s = self.conv1(s) + s = self.gelu2(s) + s = self.bn2(s) + + # Residual Stream + res = self.fc2(res) + s = s + res + + # CNN Pointwise + s = self.conv2(s) + s = self.gelu3(s) + s = self.bn3(s) + + # Flatten Head + s = self.flatten1(s) + s = self.fc3(s) + s = self.gelu4(s) + s = self.fc4(s) + + # Linear Stream + # MLP + t = self.fc5(t) + t = self.avgpool1(t) + t = self.ln1(t) + + t = self.fc6(t) + t = self.avgpool2(t) + t = self.ln2(t) + + t = self.fc7(t) + + # Streams Concatination + x = torch.cat((s, t), dim=1) + x = self.fc8(x) + + # Channel concatination + x = torch.reshape(x, (B, C, self.pred_len)) # [Batch, Channel, Output] + + x = x.permute(0,2,1) # to [Batch, Output, Channel] + + return x diff --git a/models/xPatch/xPatch.py b/models/xPatch/xPatch.py new file mode 100644 index 0000000..01f0d10 --- /dev/null +++ b/models/xPatch/xPatch.py @@ -0,0 +1,58 @@ +import torch +import torch.nn as nn +import math + +from layers.decomp import DECOMP +from .network import Network +# from layers.network_mlp import NetworkMLP # For ablation study with MLP-only stream +# from layers.network_cnn import NetworkCNN # For ablation study with CNN-only stream +from layers.revin import RevIN + +class Model(nn.Module): + def __init__(self, configs): + super(Model, self).__init__() + + # Parameters + seq_len = configs.seq_len # lookback window L + pred_len = configs.pred_len # prediction length (96, 192, 336, 720) + c_in = configs.enc_in # input channels + + # Patching + patch_len = configs.patch_len + stride = configs.stride + padding_patch = configs.padding_patch + + # Normalization + self.revin = configs.revin + self.revin_layer = RevIN(c_in,affine=True,subtract_last=False) + + # Moving Average + self.ma_type = configs.ma_type + alpha = configs.alpha # smoothing factor for EMA (Exponential Moving Average) + beta = configs.beta # smoothing factor for DEMA (Double Exponential Moving Average) + + self.decomp = DECOMP(self.ma_type, alpha, beta) + self.net = Network(seq_len, pred_len, patch_len, stride, padding_patch) + # self.net_mlp = NetworkMLP(seq_len, pred_len) # For ablation study with MLP-only stream + # self.net_cnn = NetworkCNN(seq_len, pred_len, patch_len, stride, padding_patch) # For ablation study with CNN-only stream + + def forward(self, x): + # x: [Batch, Input, Channel] + + # Normalization + if self.revin: + x = self.revin_layer(x, 'norm') + + if self.ma_type == 'reg': # If no decomposition, directly pass the input to the network + x = self.net(x, x) + # x = self.net_mlp(x) # For ablation study with MLP-only stream + # x = self.net_cnn(x) # For ablation study with CNN-only stream + else: + seasonal_init, trend_init = self.decomp(x) + x = self.net(seasonal_init, trend_init) + + # Denormalization + if self.revin: + x = self.revin_layer(x, 'denorm') + + return x diff --git a/models/xPatch_Q/__init__.py b/models/xPatch_Q/__init__.py new file mode 100644 index 0000000..18d7ca9 --- /dev/null +++ b/models/xPatch_Q/__init__.py @@ -0,0 +1 @@ +from .xPatch_Q import Model \ No newline at end of file diff --git a/models/xPatch_Q/network.py b/models/xPatch_Q/network.py new file mode 100644 index 0000000..fbbc5d6 --- /dev/null +++ b/models/xPatch_Q/network.py @@ -0,0 +1,132 @@ +import torch +from torch import nn + +class Network(nn.Module): + def __init__(self, seq_len, pred_len, patch_len, stride, padding_patch): + super(Network, self).__init__() + + # Parameters + self.pred_len = pred_len + + # Non-linear Stream + # Patching + self.patch_len = patch_len + self.stride = stride + self.padding_patch = padding_patch + self.dim = patch_len * patch_len + self.patch_num = (seq_len - patch_len)//stride + 1 + if padding_patch == 'end': # can be modified to general case + self.padding_patch_layer = nn.ReplicationPad1d((0, stride)) + self.patch_num += 1 + + # Patch Embedding + self.fc1 = nn.Linear(patch_len, self.dim) + self.gelu1 = nn.GELU() + self.bn1 = nn.BatchNorm1d(self.patch_num) + + # CNN Depthwise + self.conv1 = nn.Conv1d(self.patch_num, self.patch_num, + patch_len, patch_len, groups=self.patch_num) + self.gelu2 = nn.GELU() + self.bn2 = nn.BatchNorm1d(self.patch_num) + + # Residual Stream + self.fc2 = nn.Linear(self.dim, patch_len) + + # CNN Pointwise + self.conv2 = nn.Conv1d(self.patch_num, self.patch_num, 1, 1) + self.gelu3 = nn.GELU() + self.bn3 = nn.BatchNorm1d(self.patch_num) + + # Flatten Head + self.flatten1 = nn.Flatten(start_dim=-2) + self.fc3 = nn.Linear(self.patch_num * patch_len, pred_len * 2) + self.gelu4 = nn.GELU() + self.fc4 = nn.Linear(pred_len * 2, pred_len) + + # Linear Stream + # MLP + self.fc5 = nn.Linear(seq_len, pred_len * 4) + self.avgpool1 = nn.AvgPool1d(kernel_size=2) + self.ln1 = nn.LayerNorm(pred_len * 2) + + self.fc6 = nn.Linear(pred_len * 2, pred_len) + self.avgpool2 = nn.AvgPool1d(kernel_size=2) + self.ln2 = nn.LayerNorm(pred_len // 2) + + self.fc7 = nn.Linear(pred_len // 2, pred_len) + + # Streams Concatination + self.fc8 = nn.Linear(pred_len * 2, pred_len) + + def forward(self, s, t): + # x: [Batch, Input, Channel] + # s - seasonality + # t - trend + + s = s.permute(0,2,1) # to [Batch, Channel, Input] + t = t.permute(0,2,1) # to [Batch, Channel, Input] + + # Channel split for channel independence + B = s.shape[0] # Batch size + C = s.shape[1] # Channel size + I = s.shape[2] # Input size + s = torch.reshape(s, (B*C, I)) # [Batch and Channel, Input] + t = torch.reshape(t, (B*C, I)) # [Batch and Channel, Input] + + # Non-linear Stream + # Patching + if self.padding_patch == 'end': + s = self.padding_patch_layer(s) + s = s.unfold(dimension=-1, size=self.patch_len, step=self.stride) + # s: [Batch and Channel, Patch_num, Patch_len] + + # Patch Embedding + s = self.fc1(s) + s = self.gelu1(s) + s = self.bn1(s) + + res = s + + # CNN Depthwise + s = self.conv1(s) + s = self.gelu2(s) + s = self.bn2(s) + + # Residual Stream + res = self.fc2(res) + s = s + res + + # CNN Pointwise + s = self.conv2(s) + s = self.gelu3(s) + s = self.bn3(s) + + # Flatten Head + s = self.flatten1(s) + s = self.fc3(s) + s = self.gelu4(s) + s = self.fc4(s) + + # Linear Stream + # MLP + t = self.fc5(t) + t = self.avgpool1(t) + t = self.ln1(t) + + t = self.fc6(t) + t = self.avgpool2(t) + t = self.ln2(t) + + t = self.fc7(t) + + # Streams Concatination + x = torch.cat((s, t), dim=1) + x = self.fc8(x) + + # Channel concatination + x = torch.reshape(x, (B, C, self.pred_len)) # [Batch, Channel, Output] + + x = x.permute(0,2,1) # to [Batch, Output, Channel] + + return x diff --git a/models/xPatch_Q/xPatch_Q.py b/models/xPatch_Q/xPatch_Q.py new file mode 100644 index 0000000..8598b3a --- /dev/null +++ b/models/xPatch_Q/xPatch_Q.py @@ -0,0 +1,146 @@ +import torch +import torch.nn as nn +import math +import numpy as np +import os + +from layers.decomp import DECOMP +from .network import Network +from layers.revin import RevIN + +class Model(nn.Module): + """ + xPatch with Q matrix transformation + - Applies RevIN normalization first + - Applies input Q matrix transformation after RevIN normalization (based on dataset and seq_len) + - Uses original xPatch logic (decomposition + dual stream network) + - Applies output Q matrix transformation before RevIN denormalization (based on dataset and pred_len) + """ + + def __init__(self, configs): + super(Model, self).__init__() + + # Parameters + seq_len = configs.seq_len # lookback window L + pred_len = configs.pred_len # prediction length (96, 192, 336, 720) + c_in = configs.enc_in # input channels + + # Patching + patch_len = configs.patch_len + stride = configs.stride + padding_patch = configs.padding_patch + + # Store for Q matrix transformations + self.seq_len = seq_len + self.pred_len = pred_len + + # Load Q matrices + self.load_Q_matrices(configs) + + # Normalization + self.revin = configs.revin + self.revin_layer = RevIN(c_in, affine=True, subtract_last=False) + + # Moving Average + self.ma_type = configs.ma_type + alpha = configs.alpha # smoothing factor for EMA (Exponential Moving Average) + beta = configs.beta # smoothing factor for DEMA (Double Exponential Moving Average) + + self.decomp = DECOMP(self.ma_type, alpha, beta) + self.net = Network(seq_len, pred_len, patch_len, stride, padding_patch) + + def load_Q_matrices(self, configs): + """Load pre-computed Q matrices for input and output transformations""" + # Get dataset name from configs, default to ETTm1 if not specified + dataset_name = getattr(configs, 'dataset', 'ETTm1') + + # Input Q matrix (seq_len) + input_q_path = f'cov_mats/{dataset_name}/{dataset_name}_{configs.seq_len}_ratio1.0.npy' + # Output Q matrix (pred_len) + output_q_path = f'cov_mats/{dataset_name}/{dataset_name}_{configs.pred_len}_ratio1.0.npy' + + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + if os.path.exists(input_q_path): + Q_input = np.load(input_q_path) + self.register_buffer('Q_input', torch.FloatTensor(Q_input).to(device)) + print(f"Loaded input Q matrix from {input_q_path}, shape: {Q_input.shape}") + else: + print(f"Warning: Input Q matrix not found at {input_q_path}, using identity matrix") + self.register_buffer('Q_input', torch.eye(configs.seq_len).to(device)) + + if os.path.exists(output_q_path): + Q_output = np.load(output_q_path) + self.register_buffer('Q_output', torch.FloatTensor(Q_output).to(device)) + print(f"Loaded output Q matrix from {output_q_path}, shape: {Q_output.shape}") + else: + print(f"Warning: Output Q matrix not found at {output_q_path}, using identity matrix") + self.register_buffer('Q_output', torch.eye(configs.pred_len).to(device)) + + def apply_input_Q_transformation(self, x): + """ + Apply input Q matrix transformation after RevIN normalization + Input: x with shape [B, T, N] where T = seq_len + Output: transformed x with shape [B, T, N] + """ + B, T, N = x.size() + assert T == self.seq_len, f"Expected seq_len {self.seq_len}, got {T}" + + # Transpose to [B, N, T] for matrix multiplication + x_transposed = x.transpose(-1, -2) # [B, N, T] + + # Apply input Q transformation: einsum 'bnt,tv->bnv' + # x_transposed: [B, N, T], Q_input.T: [T, T] -> result: [B, N, T] + x_trans = torch.einsum('bnt,tv->bnv', x_transposed, self.Q_input.transpose(-1, -2)) + + # Transpose back to [B, T, N] + x_transformed = x_trans.transpose(-1, -2) # [B, T, N] + + return x_transformed + + def apply_output_Q_transformation(self, x): + """ + Apply output Q matrix transformation to prediction output + Input: x with shape [B, pred_len, N] + Output: transformed x with shape [B, pred_len, N] + """ + B, T, N = x.size() + assert T == self.pred_len, f"Expected pred_len {self.pred_len}, got {T}" + + # Transpose to [B, N, T] for matrix multiplication + x_transposed = x.transpose(-1, -2) # [B, N, pred_len] + + # Apply output Q transformation: einsum 'bnt,tv->bnv' + # x_transposed: [B, N, pred_len], Q_output: [pred_len, pred_len] -> result: [B, N, pred_len] + x_trans = torch.einsum('bnt,tv->bnv', x_transposed, self.Q_output) + + # Transpose back to [B, pred_len, N] + x_transformed = x_trans.transpose(-1, -2) # [B, pred_len, N] + + return x_transformed + + def forward(self, x): + # x: [Batch, Input, Channel] + + # RevIN Normalization + if self.revin: + x = self.revin_layer(x, 'norm') + + # Apply input Q matrix transformation after RevIN normalization + x_transformed = self.apply_input_Q_transformation(x) + + # xPatch processing with Q-transformed input + if self.ma_type == 'reg': # If no decomposition, directly pass the input to the network + output = self.net(x_transformed, x_transformed) + else: + seasonal_init, trend_init = self.decomp(x_transformed) + output = self.net(seasonal_init, trend_init) + + # Apply output Q matrix transformation to the prediction + output_transformed = self.apply_output_Q_transformation(output) + + # RevIN Denormalization + if self.revin: + output_transformed = self.revin_layer(output_transformed, 'denorm') + + return output_transformed