feat: add TimesNet_Q and xPatch models with Q matrix transformations
models/xPatch_Q/__init__.py (new file, +1)
@@ -0,0 +1 @@
from .xPatch_Q import Model
models/xPatch_Q/network.py (new file, +132)
@@ -0,0 +1,132 @@
import torch
from torch import nn

class Network(nn.Module):
    def __init__(self, seq_len, pred_len, patch_len, stride, padding_patch):
        super(Network, self).__init__()

        # Parameters
        self.pred_len = pred_len

        # Non-linear Stream
        # Patching
        self.patch_len = patch_len
        self.stride = stride
        self.padding_patch = padding_patch
        self.dim = patch_len * patch_len
        self.patch_num = (seq_len - patch_len)//stride + 1
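        # e.g. with hypothetical sizes seq_len=96, patch_len=16, stride=8:
        # (96 - 16)//8 + 1 = 11 patches, or 12 after the 'end' padding below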
        if padding_patch == 'end': # can be modified to general case
            self.padding_patch_layer = nn.ReplicationPad1d((0, stride))
            self.patch_num += 1

        # Patch Embedding
        self.fc1 = nn.Linear(patch_len, self.dim)
        self.gelu1 = nn.GELU()
        self.bn1 = nn.BatchNorm1d(self.patch_num)

        # CNN Depthwise
        self.conv1 = nn.Conv1d(self.patch_num, self.patch_num,
                               patch_len, patch_len, groups=self.patch_num)
        self.gelu2 = nn.GELU()
        self.bn2 = nn.BatchNorm1d(self.patch_num)

        # Residual Stream
        self.fc2 = nn.Linear(self.dim, patch_len)

        # CNN Pointwise
        self.conv2 = nn.Conv1d(self.patch_num, self.patch_num, 1, 1)
        self.gelu3 = nn.GELU()
        self.bn3 = nn.BatchNorm1d(self.patch_num)

        # Flatten Head
        self.flatten1 = nn.Flatten(start_dim=-2)
        self.fc3 = nn.Linear(self.patch_num * patch_len, pred_len * 2)
        self.gelu4 = nn.GELU()
        self.fc4 = nn.Linear(pred_len * 2, pred_len)

        # Linear Stream
        # MLP
        self.fc5 = nn.Linear(seq_len, pred_len * 4)
        self.avgpool1 = nn.AvgPool1d(kernel_size=2)
        self.ln1 = nn.LayerNorm(pred_len * 2)

        self.fc6 = nn.Linear(pred_len * 2, pred_len)
        self.avgpool2 = nn.AvgPool1d(kernel_size=2)
        self.ln2 = nn.LayerNorm(pred_len // 2)

        self.fc7 = nn.Linear(pred_len // 2, pred_len)

        # Streams Concatenation
        self.fc8 = nn.Linear(pred_len * 2, pred_len)

    def forward(self, s, t):
        # s, t: [Batch, Input, Channel]
        # s - seasonality
        # t - trend

        s = s.permute(0,2,1) # to [Batch, Channel, Input]
        t = t.permute(0,2,1) # to [Batch, Channel, Input]

        # Channel split for channel independence
        B = s.shape[0] # Batch size
        C = s.shape[1] # Channel size
        I = s.shape[2] # Input size
        s = torch.reshape(s, (B*C, I)) # [Batch and Channel, Input]
        t = torch.reshape(t, (B*C, I)) # [Batch and Channel, Input]

        # Non-linear Stream
        # Patching
        if self.padding_patch == 'end':
            s = self.padding_patch_layer(s)
        s = s.unfold(dimension=-1, size=self.patch_len, step=self.stride)
        # s: [Batch and Channel, Patch_num, Patch_len]

        # Patch Embedding
        s = self.fc1(s)
        s = self.gelu1(s)
        s = self.bn1(s)

        res = s

        # CNN Depthwise
        s = self.conv1(s)
        s = self.gelu2(s)
        s = self.bn2(s)

        # Residual Stream
        res = self.fc2(res)
        s = s + res

        # CNN Pointwise
        s = self.conv2(s)
        s = self.gelu3(s)
        s = self.bn3(s)

        # Flatten Head
        s = self.flatten1(s)
        s = self.fc3(s)
        s = self.gelu4(s)
        s = self.fc4(s)

        # Linear Stream
        # MLP
        t = self.fc5(t)
        t = self.avgpool1(t)
        t = self.ln1(t)

        t = self.fc6(t)
        t = self.avgpool2(t)
        t = self.ln2(t)

        t = self.fc7(t)

        # Streams Concatenation
        x = torch.cat((s, t), dim=1)
        x = self.fc8(x)
        # Channel concatenation
        x = torch.reshape(x, (B, C, self.pred_len)) # [Batch, Channel, Output]

        x = x.permute(0,2,1) # to [Batch, Output, Channel]

        return x
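For reference, a minimal shape-check for the network above (not part of the commit). The sizes are hypothetical; any seq_len/patch_len/stride combination consistent with the patching arithmetic works:

import torch
from models.xPatch_Q.network import Network

# hypothetical sizes: 96-step lookback, 96-step horizon, patches of 16 with stride 8
net = Network(seq_len=96, pred_len=96, patch_len=16, stride=8, padding_patch='end')
s = torch.randn(4, 96, 7) # seasonality component: [Batch, Input, Channel]
t = torch.randn(4, 96, 7) # trend component: [Batch, Input, Channel]
y = net(s, t)
print(y.shape) # torch.Size([4, 96, 7]) -> [Batch, Output, Channel]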
models/xPatch_Q/xPatch_Q.py (new file, +146)
@@ -0,0 +1,146 @@
import torch
import torch.nn as nn
import math
import numpy as np
import os

from layers.decomp import DECOMP
from .network import Network
from layers.revin import RevIN

class Model(nn.Module):
    """
    xPatch with Q matrix transformation
    - Applies RevIN normalization first
    - Applies input Q matrix transformation after RevIN normalization (based on dataset and seq_len)
    - Uses original xPatch logic (decomposition + dual stream network)
    - Applies output Q matrix transformation before RevIN denormalization (based on dataset and pred_len)
    """

    def __init__(self, configs):
        super(Model, self).__init__()

        # Parameters
        seq_len = configs.seq_len # lookback window L
        pred_len = configs.pred_len # prediction length (96, 192, 336, 720)
        c_in = configs.enc_in # input channels

        # Patching
        patch_len = configs.patch_len
        stride = configs.stride
        padding_patch = configs.padding_patch

        # Store for Q matrix transformations
        self.seq_len = seq_len
        self.pred_len = pred_len

        # Load Q matrices
        self.load_Q_matrices(configs)

        # Normalization
        self.revin = configs.revin
        self.revin_layer = RevIN(c_in, affine=True, subtract_last=False)

        # Moving Average
        self.ma_type = configs.ma_type
        alpha = configs.alpha # smoothing factor for EMA (Exponential Moving Average)
        beta = configs.beta # smoothing factor for DEMA (Double Exponential Moving Average)

        self.decomp = DECOMP(self.ma_type, alpha, beta)
        self.net = Network(seq_len, pred_len, patch_len, stride, padding_patch)

    def load_Q_matrices(self, configs):
        """Load pre-computed Q matrices for input and output transformations"""
        # Get dataset name from configs, default to ETTm1 if not specified
        dataset_name = getattr(configs, 'dataset', 'ETTm1')

        # Input Q matrix (seq_len)
        input_q_path = f'cov_mats/{dataset_name}/{dataset_name}_{configs.seq_len}_ratio1.0.npy'
        # Output Q matrix (pred_len)
        output_q_path = f'cov_mats/{dataset_name}/{dataset_name}_{configs.pred_len}_ratio1.0.npy'

        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        if os.path.exists(input_q_path):
            Q_input = np.load(input_q_path)
            self.register_buffer('Q_input', torch.FloatTensor(Q_input).to(device))
            print(f"Loaded input Q matrix from {input_q_path}, shape: {Q_input.shape}")
        else:
            print(f"Warning: Input Q matrix not found at {input_q_path}, using identity matrix")
            self.register_buffer('Q_input', torch.eye(configs.seq_len).to(device))

        if os.path.exists(output_q_path):
            Q_output = np.load(output_q_path)
            self.register_buffer('Q_output', torch.FloatTensor(Q_output).to(device))
            print(f"Loaded output Q matrix from {output_q_path}, shape: {Q_output.shape}")
        else:
            print(f"Warning: Output Q matrix not found at {output_q_path}, using identity matrix")
            self.register_buffer('Q_output', torch.eye(configs.pred_len).to(device))

    def apply_input_Q_transformation(self, x):
        """
        Apply input Q matrix transformation after RevIN normalization
        Input: x with shape [B, T, N] where T = seq_len
        Output: transformed x with shape [B, T, N]
        """
        B, T, N = x.size()
        assert T == self.seq_len, f"Expected seq_len {self.seq_len}, got {T}"

        # Transpose to [B, N, T] for matrix multiplication
        x_transposed = x.transpose(-1, -2) # [B, N, T]

        # Apply input Q transformation: einsum 'bnt,tv->bnv'
        # x_transposed: [B, N, T], Q_input.T: [T, T] -> result: [B, N, T]
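        # Equivalently, x_trans[b, n] = Q_input @ x_transposed[b, n]:
        # each univariate series is left-multiplied by Q_input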
        x_trans = torch.einsum('bnt,tv->bnv', x_transposed, self.Q_input.transpose(-1, -2))

        # Transpose back to [B, T, N]
        x_transformed = x_trans.transpose(-1, -2) # [B, T, N]

        return x_transformed

    def apply_output_Q_transformation(self, x):
        """
        Apply output Q matrix transformation to prediction output
        Input: x with shape [B, pred_len, N]
        Output: transformed x with shape [B, pred_len, N]
        """
        B, T, N = x.size()
        assert T == self.pred_len, f"Expected pred_len {self.pred_len}, got {T}"

        # Transpose to [B, N, T] for matrix multiplication
        x_transposed = x.transpose(-1, -2) # [B, N, pred_len]

        # Apply output Q transformation: einsum 'bnt,tv->bnv'
        # x_transposed: [B, N, pred_len], Q_output: [pred_len, pred_len] -> result: [B, N, pred_len]
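        # Equivalently, x_trans[b, n] = x_transposed[b, n] @ Q_output, i.e. each series
        # is right-multiplied by Q_output (the input side used Q_input.T instead)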
        x_trans = torch.einsum('bnt,tv->bnv', x_transposed, self.Q_output)

        # Transpose back to [B, pred_len, N]
        x_transformed = x_trans.transpose(-1, -2) # [B, pred_len, N]

        return x_transformed

    def forward(self, x):
        # x: [Batch, Input, Channel]

        # RevIN Normalization
        if self.revin:
            x = self.revin_layer(x, 'norm')

        # Apply input Q matrix transformation after RevIN normalization
        x_transformed = self.apply_input_Q_transformation(x)

        # xPatch processing with Q-transformed input
        if self.ma_type == 'reg': # If no decomposition, directly pass the input to the network
            output = self.net(x_transformed, x_transformed)
        else:
            seasonal_init, trend_init = self.decomp(x_transformed)
            output = self.net(seasonal_init, trend_init)

        # Apply output Q matrix transformation to the prediction
        output_transformed = self.apply_output_Q_transformation(output)

        # RevIN Denormalization
        if self.revin:
            output_transformed = self.revin_layer(output_transformed, 'denorm')

        return output_transformed
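For reference, a minimal end-to-end sketch of the new model (not part of the commit). It assumes the repo's layers.decomp.DECOMP and layers.revin.RevIN are importable, and the config values are hypothetical; if no .npy files exist under cov_mats/, load_Q_matrices falls back to identity Q matrices:

from types import SimpleNamespace
import torch
from models.xPatch_Q import Model

# hypothetical configs; the field names follow the attributes read in Model.__init__
configs = SimpleNamespace(
    seq_len=96, pred_len=96, enc_in=7,           # lengths and channel count
    patch_len=16, stride=8, padding_patch='end', # patching
    revin=1,                                     # enable RevIN
    ma_type='ema', alpha=0.3, beta=0.3,          # decomposition settings
    dataset='ETTm1',                             # selects cov_mats/ETTm1/*.npy
)

model = Model(configs)
x = torch.randn(4, configs.seq_len, configs.enc_in) # [Batch, Input, Channel]
# note: load_Q_matrices registers its buffers on CUDA when available,
# so move x (or the model) to a single device before calling forward
y = model(x)
print(y.shape) # expected: torch.Size([4, 96, 7])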