feat: add TimesNet_Q and xPatch models with Q matrix transformations

This commit is contained in:
game-loader
2025-08-06 18:39:26 +08:00
parent 7fdf0f364d
commit 6bba6613c9
14 changed files with 872 additions and 3 deletions

146
models/xPatch_Q/xPatch_Q.py Normal file
View File

@ -0,0 +1,146 @@
import torch
import torch.nn as nn
import math
import numpy as np
import os
from layers.decomp import DECOMP
from .network import Network
from layers.revin import RevIN
class Model(nn.Module):
"""
xPatch with Q matrix transformation
- Applies RevIN normalization first
- Applies input Q matrix transformation after RevIN normalization (based on dataset and seq_len)
- Uses original xPatch logic (decomposition + dual stream network)
- Applies output Q matrix transformation before RevIN denormalization (based on dataset and pred_len)
"""
def __init__(self, configs):
super(Model, self).__init__()
# Parameters
seq_len = configs.seq_len # lookback window L
pred_len = configs.pred_len # prediction length (96, 192, 336, 720)
c_in = configs.enc_in # input channels
# Patching
patch_len = configs.patch_len
stride = configs.stride
padding_patch = configs.padding_patch
# Store for Q matrix transformations
self.seq_len = seq_len
self.pred_len = pred_len
# Load Q matrices
self.load_Q_matrices(configs)
# Normalization
self.revin = configs.revin
self.revin_layer = RevIN(c_in, affine=True, subtract_last=False)
# Moving Average
self.ma_type = configs.ma_type
alpha = configs.alpha # smoothing factor for EMA (Exponential Moving Average)
beta = configs.beta # smoothing factor for DEMA (Double Exponential Moving Average)
self.decomp = DECOMP(self.ma_type, alpha, beta)
self.net = Network(seq_len, pred_len, patch_len, stride, padding_patch)
def load_Q_matrices(self, configs):
"""Load pre-computed Q matrices for input and output transformations"""
# Get dataset name from configs, default to ETTm1 if not specified
dataset_name = getattr(configs, 'dataset', 'ETTm1')
# Input Q matrix (seq_len)
input_q_path = f'cov_mats/{dataset_name}/{dataset_name}_{configs.seq_len}_ratio1.0.npy'
# Output Q matrix (pred_len)
output_q_path = f'cov_mats/{dataset_name}/{dataset_name}_{configs.pred_len}_ratio1.0.npy'
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if os.path.exists(input_q_path):
Q_input = np.load(input_q_path)
self.register_buffer('Q_input', torch.FloatTensor(Q_input).to(device))
print(f"Loaded input Q matrix from {input_q_path}, shape: {Q_input.shape}")
else:
print(f"Warning: Input Q matrix not found at {input_q_path}, using identity matrix")
self.register_buffer('Q_input', torch.eye(configs.seq_len).to(device))
if os.path.exists(output_q_path):
Q_output = np.load(output_q_path)
self.register_buffer('Q_output', torch.FloatTensor(Q_output).to(device))
print(f"Loaded output Q matrix from {output_q_path}, shape: {Q_output.shape}")
else:
print(f"Warning: Output Q matrix not found at {output_q_path}, using identity matrix")
self.register_buffer('Q_output', torch.eye(configs.pred_len).to(device))
def apply_input_Q_transformation(self, x):
"""
Apply input Q matrix transformation after RevIN normalization
Input: x with shape [B, T, N] where T = seq_len
Output: transformed x with shape [B, T, N]
"""
B, T, N = x.size()
assert T == self.seq_len, f"Expected seq_len {self.seq_len}, got {T}"
# Transpose to [B, N, T] for matrix multiplication
x_transposed = x.transpose(-1, -2) # [B, N, T]
# Apply input Q transformation: einsum 'bnt,tv->bnv'
# x_transposed: [B, N, T], Q_input.T: [T, T] -> result: [B, N, T]
x_trans = torch.einsum('bnt,tv->bnv', x_transposed, self.Q_input.transpose(-1, -2))
# Transpose back to [B, T, N]
x_transformed = x_trans.transpose(-1, -2) # [B, T, N]
return x_transformed
def apply_output_Q_transformation(self, x):
"""
Apply output Q matrix transformation to prediction output
Input: x with shape [B, pred_len, N]
Output: transformed x with shape [B, pred_len, N]
"""
B, T, N = x.size()
assert T == self.pred_len, f"Expected pred_len {self.pred_len}, got {T}"
# Transpose to [B, N, T] for matrix multiplication
x_transposed = x.transpose(-1, -2) # [B, N, pred_len]
# Apply output Q transformation: einsum 'bnt,tv->bnv'
# x_transposed: [B, N, pred_len], Q_output: [pred_len, pred_len] -> result: [B, N, pred_len]
x_trans = torch.einsum('bnt,tv->bnv', x_transposed, self.Q_output)
# Transpose back to [B, pred_len, N]
x_transformed = x_trans.transpose(-1, -2) # [B, pred_len, N]
return x_transformed
def forward(self, x):
# x: [Batch, Input, Channel]
# RevIN Normalization
if self.revin:
x = self.revin_layer(x, 'norm')
# Apply input Q matrix transformation after RevIN normalization
x_transformed = self.apply_input_Q_transformation(x)
# xPatch processing with Q-transformed input
if self.ma_type == 'reg': # If no decomposition, directly pass the input to the network
output = self.net(x_transformed, x_transformed)
else:
seasonal_init, trend_init = self.decomp(x_transformed)
output = self.net(seasonal_init, trend_init)
# Apply output Q matrix transformation to the prediction
output_transformed = self.apply_output_Q_transformation(output)
# RevIN Denormalization
if self.revin:
output_transformed = self.revin_layer(output_transformed, 'denorm')
return output_transformed