import torch
import torch.nn as nn
import math
import numpy as np
import os

from layers.decomp import DECOMP
from .network import Network
from layers.revin import RevIN


class Model(nn.Module):
    """
    xPatch with Q matrix transformation.

    Pipeline (see ``forward``):
      1. RevIN normalization (only when ``configs.revin`` is truthy).
      2. Input Q matrix transformation along the time axis, using ``Q_input``
         ([seq_len, seq_len]) selected by dataset name and seq_len.
      3. Original xPatch logic (decomposition + dual-stream network).
      4. Output Q matrix transformation along the time axis, using
         ``Q_output`` ([pred_len, pred_len]) selected by dataset name and
         pred_len.
      5. RevIN denormalization (only when ``configs.revin`` is truthy).

    Q matrices are read from
    ``cov_mats/{dataset}/{dataset}_{length}_ratio1.0.npy``; a missing file
    falls back to the identity matrix (i.e. no transformation).
    """

    def __init__(self, configs):
        super(Model, self).__init__()

        # Parameters
        seq_len = configs.seq_len  # lookback window L
        pred_len = configs.pred_len  # prediction length (96, 192, 336, 720)
        c_in = configs.enc_in  # input channels

        # Patching
        patch_len = configs.patch_len
        stride = configs.stride
        padding_patch = configs.padding_patch

        # Stored for the Q matrix transformations and their shape checks
        self.seq_len = seq_len
        self.pred_len = pred_len

        # Load Q matrices (registers the Q_input / Q_output buffers)
        self.load_Q_matrices(configs)

        # Normalization
        self.revin = configs.revin
        self.revin_layer = RevIN(c_in, affine=True, subtract_last=False)

        # Moving Average
        self.ma_type = configs.ma_type
        alpha = configs.alpha  # smoothing factor for EMA (Exponential Moving Average)
        beta = configs.beta  # smoothing factor for DEMA (Double Exponential Moving Average)

        self.decomp = DECOMP(self.ma_type, alpha, beta)
        self.net = Network(seq_len, pred_len, patch_len, stride, padding_patch)

    def load_Q_matrices(self, configs):
        """Load pre-computed Q matrices and register them as buffers.

        Registers ``Q_input`` ([seq_len, seq_len]) and ``Q_output``
        ([pred_len, pred_len]). The dataset name is taken from
        ``configs.dataset`` and defaults to ``'ETTm1'``. A missing file is
        replaced by the identity matrix (with a printed warning).

        The buffers are created on the CPU and follow the module when it is
        moved with ``model.to(device)``; no device is hard-coded here, so the
        buffers always live on the same device as the model's parameters.

        Raises:
            ValueError: if a loaded matrix is not square with the expected size.
        """
        # Get dataset name from configs, default to ETTm1 if not specified
        dataset_name = getattr(configs, 'dataset', 'ETTm1')

        # Input Q matrix (seq_len)
        input_q_path = f'cov_mats/{dataset_name}/{dataset_name}_{configs.seq_len}_ratio1.0.npy'
        # Output Q matrix (pred_len)
        output_q_path = f'cov_mats/{dataset_name}/{dataset_name}_{configs.pred_len}_ratio1.0.npy'

        self.register_buffer('Q_input', self._load_Q(input_q_path, configs.seq_len, 'input'))
        self.register_buffer('Q_output', self._load_Q(output_q_path, configs.pred_len, 'output'))

    @staticmethod
    def _load_Q(path, size, kind):
        """Return a float32 [size, size] Q matrix read from ``path``.

        Falls back to ``torch.eye(size)`` when ``path`` does not exist.
        ``kind`` ('input' / 'output') is only used in the printed messages.

        Raises:
            ValueError: if the stored array is not of shape (size, size).
        """
        if not os.path.exists(path):
            print(f"Warning: {kind.capitalize()} Q matrix not found at {path}, using identity matrix")
            return torch.eye(size)

        q = np.load(path)
        # Fail early with a clear message instead of a shape error deep in forward()
        if q.shape != (size, size):
            raise ValueError(
                f"{kind.capitalize()} Q matrix at {path} has shape {q.shape}, "
                f"expected ({size}, {size})"
            )
        # Copy into native-endian, contiguous float32 before handing to torch
        tensor = torch.from_numpy(np.ascontiguousarray(q, dtype=np.float32))
        print(f"Loaded {kind} Q matrix from {path}, shape: {q.shape}")
        return tensor

    def apply_input_Q_transformation(self, x):
        """
        Apply the input Q matrix along the time axis (after RevIN normalization).

        For every batch element b and channel n this computes
        ``out[b, :, n] = Q_input @ x[b, :, n]``.

        Input:  x with shape [B, T, N] where T = seq_len
        Output: transformed x with shape [B, T, N]

        Raises:
            ValueError: if T != seq_len (or x is not 3-D).
        """
        _, T, _ = x.shape
        if T != self.seq_len:
            raise ValueError(f"Expected seq_len {self.seq_len}, got {T}")
        # [T, T] @ [B, T, N] broadcasts over the batch dimension -> [B, T, N]
        return torch.matmul(self.Q_input, x)

    def apply_output_Q_transformation(self, x):
        """
        Apply the output Q matrix along the time axis of the prediction.

        For every batch element b and channel n this computes
        ``out[b, :, n] = Q_output^T @ x[b, :, n]``. Note the transpose: the
        input transform multiplies by Q while this one multiplies by Q^T.
        If the Q matrices are orthogonal (presumably, since they are loaded
        from ``cov_mats`` — confirm against how they are generated), this
        maps predictions back out of the transformed basis.

        Input:  x with shape [B, pred_len, N]
        Output: transformed x with shape [B, pred_len, N]

        Raises:
            ValueError: if the time dimension != pred_len (or x is not 3-D).
        """
        _, T, _ = x.shape
        if T != self.pred_len:
            raise ValueError(f"Expected pred_len {self.pred_len}, got {T}")
        # [P, P]^T @ [B, P, N] broadcasts over the batch dimension -> [B, P, N]
        return torch.matmul(self.Q_output.transpose(0, 1), x)

    def forward(self, x):
        """
        Args:
            x: [Batch, Input (seq_len), Channel]

        Returns:
            Prediction of shape [Batch, pred_len, Channel] (pred_len on dim 1
            is enforced by the output Q transformation's check).
        """
        # RevIN Normalization
        if self.revin:
            x = self.revin_layer(x, 'norm')

        # Apply input Q matrix transformation after RevIN normalization
        x_transformed = self.apply_input_Q_transformation(x)

        # xPatch processing with Q-transformed input
        if self.ma_type == 'reg':
            # If no decomposition, directly pass the input to the network
            output = self.net(x_transformed, x_transformed)
        else:
            seasonal_init, trend_init = self.decomp(x_transformed)
            output = self.net(seasonal_init, trend_init)

        # Apply output Q matrix transformation to the prediction
        output_transformed = self.apply_output_Q_transformation(output)

        # RevIN Denormalization
        if self.revin:
            output_transformed = self.revin_layer(output_transformed, 'denorm')

        return output_transformed