Files
tsmodel/models/xPatch_Q/xPatch_Q.py

147 lines
5.8 KiB
Python

import torch
import torch.nn as nn
import math
import numpy as np
import os
from layers.decomp import DECOMP
from .network import Network
from layers.revin import RevIN
class Model(nn.Module):
    """
    xPatch with Q matrix transformation.

    Pipeline in ``forward``:
    - RevIN normalization (only when ``configs.revin`` is truthy)
    - input Q transformation along the time axis (``Q_input``, seq_len x seq_len)
    - original xPatch logic (decomposition + dual stream network)
    - output Q transformation along the time axis (``Q_output``, pred_len x pred_len)
    - RevIN denormalization (only when ``configs.revin`` is truthy)

    Q matrices are read from
    ``<q_mat_dir>/<dataset>/<dataset>_<length>_ratio1.0.npy`` where ``q_mat_dir``
    defaults to ``'cov_mats'`` and ``dataset`` to ``'ETTm1'``; a missing file falls
    back to the identity matrix.
    """

    def __init__(self, configs):
        super(Model, self).__init__()

        # Parameters
        seq_len = configs.seq_len  # lookback window L
        pred_len = configs.pred_len  # prediction length (96, 192, 336, 720)
        c_in = configs.enc_in  # input channels

        # Patching
        patch_len = configs.patch_len
        stride = configs.stride
        padding_patch = configs.padding_patch

        # Stored for the Q matrix shape checks and transformations
        self.seq_len = seq_len
        self.pred_len = pred_len

        # Load Q matrices (registered as buffers: saved in state_dict, moved by .to())
        self.load_Q_matrices(configs)

        # Normalization
        self.revin = configs.revin
        self.revin_layer = RevIN(c_in, affine=True, subtract_last=False)

        # Moving Average
        self.ma_type = configs.ma_type
        alpha = configs.alpha  # smoothing factor for EMA (Exponential Moving Average)
        beta = configs.beta  # smoothing factor for DEMA (Double Exponential Moving Average)
        self.decomp = DECOMP(self.ma_type, alpha, beta)

        self.net = Network(seq_len, pred_len, patch_len, stride, padding_patch)

    def load_Q_matrices(self, configs):
        """
        Load pre-computed Q matrices and register them as buffers.

        Registers ``Q_input`` ([seq_len, seq_len]) and ``Q_output``
        ([pred_len, pred_len]) as float32 buffers.

        Reads ``configs.seq_len`` and ``configs.pred_len``, plus the optional
        ``configs.dataset`` (default ``'ETTm1'``) and ``configs.q_mat_dir``
        (default ``'cov_mats'``).

        Raises ValueError if a stored matrix does not have the expected square shape.
        """
        dataset_name = getattr(configs, 'dataset', 'ETTm1')
        q_dir = getattr(configs, 'q_mat_dir', 'cov_mats')

        input_q_path = f'{q_dir}/{dataset_name}/{dataset_name}_{configs.seq_len}_ratio1.0.npy'
        output_q_path = f'{q_dir}/{dataset_name}/{dataset_name}_{configs.pred_len}_ratio1.0.npy'

        # Buffers are intentionally created on CPU: register_buffer'd tensors follow
        # model.to(device), so pinning them to a fixed GPU here would only break
        # CPU / other-GPU usage until the model is moved.
        self.register_buffer('Q_input', self._load_Q_matrix(input_q_path, configs.seq_len, 'Input'))
        self.register_buffer('Q_output', self._load_Q_matrix(output_q_path, configs.pred_len, 'Output'))

    @staticmethod
    def _load_Q_matrix(path, size, label):
        """
        Return the float32 [size, size] matrix stored at *path*, or identity if absent.

        *label* ('Input' / 'Output') is used only in the log messages.
        Raises ValueError if the stored array's shape is not (size, size).
        """
        if not os.path.exists(path):
            print(f"Warning: {label} Q matrix not found at {path}, using identity matrix")
            return torch.eye(size)
        Q = np.load(path)
        if Q.shape != (size, size):
            # Fail at construction instead of with an opaque matmul error in forward().
            raise ValueError(
                f"{label} Q matrix at {path} has shape {Q.shape}, expected ({size}, {size})"
            )
        print(f"Loaded {label.lower()} Q matrix from {path}, shape: {Q.shape}")
        return torch.as_tensor(Q, dtype=torch.float32)

    @staticmethod
    def _check_time_len(x, expected, name):
        """Raise ValueError unless *x* is 3-D with ``x.size(1) == expected``."""
        if x.dim() != 3:
            raise ValueError(f"Expected a 3-D tensor [B, T, N], got shape {tuple(x.shape)}")
        if x.size(1) != expected:
            raise ValueError(f"Expected {name} {expected}, got {x.size(1)}")

    def apply_input_Q_transformation(self, x):
        """
        Apply the input Q matrix along the time axis.

        Computes x'[:, t, :] = sum_s Q_input[t, s] * x[:, s, :].
        Input: x with shape [B, seq_len, N]
        Output: tensor with shape [B, seq_len, N]
        Raises ValueError if x is not 3-D or its time length differs from seq_len.
        """
        self._check_time_len(x, self.seq_len, 'seq_len')
        # [T, T] @ [B, T, N] broadcasts over the batch -> [B, T, N]. Same contraction as
        # einsum('bnt,tv->bnv', x.transpose(-1, -2), Q_input.T) followed by a transpose.
        return torch.matmul(self.Q_input, x)

    def apply_output_Q_transformation(self, x):
        """
        Apply the transposed output Q matrix along the time axis.

        Computes y'[:, t, :] = sum_s Q_output[s, t] * y[:, s, :].
        Input: x with shape [B, pred_len, N]
        Output: tensor with shape [B, pred_len, N]
        Raises ValueError if x is not 3-D or its time length differs from pred_len.
        """
        self._check_time_len(x, self.pred_len, 'pred_len')
        # NOTE(review): this side applies Q_output^T while the input side applies
        # Q_input itself. The asymmetry is preserved from the original; presumably
        # Q^T is meant to undo an orthogonal Q — confirm against how cov_mats are built.
        return torch.matmul(self.Q_output.transpose(-1, -2), x)

    def forward(self, x):
        """
        Forecast with Q-transformed xPatch.

        x: [Batch, seq_len, Channel] -> returns [Batch, pred_len, Channel].
        """
        # RevIN Normalization
        if self.revin:
            x = self.revin_layer(x, 'norm')

        # Input Q transformation happens after normalization
        x_transformed = self.apply_input_Q_transformation(x)

        # xPatch processing with Q-transformed input
        if self.ma_type == 'reg':  # no decomposition: both streams get the same input
            output = self.net(x_transformed, x_transformed)
        else:
            seasonal_init, trend_init = self.decomp(x_transformed)
            output = self.net(seasonal_init, trend_init)

        # Output Q transformation before denormalization
        output_transformed = self.apply_output_Q_transformation(output)

        # RevIN Denormalization
        if self.revin:
            output_transformed = self.revin_layer(output_transformed, 'denorm')

        return output_transformed