feat: add TimesNet_Q and xPatch models with Q matrix transformations

This commit is contained in:
game-loader
2025-08-06 18:39:26 +08:00
parent 7fdf0f364d
commit 6bba6613c9
14 changed files with 872 additions and 3 deletions

View File

@ -0,0 +1,221 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.fft
import numpy as np
import os
from layers.Embed import DataEmbedding
from layers.Conv_Blocks import Inception_Block_V1
def FFT_for_Period(x, k=2):
# [B, T, C]
xf = torch.fft.rfft(x, dim=1)
# find period by amplitudes
frequency_list = abs(xf).mean(0).mean(-1)
frequency_list[0] = 0
_, top_list = torch.topk(frequency_list, k)
top_list = top_list.detach().cpu().numpy()
period = x.shape[1] // top_list
return period, abs(xf).mean(-1)[:, top_list]
class TimesBlock(nn.Module):
"""Original TimesBlock without Q matrix transformation"""
def __init__(self, configs):
super(TimesBlock, self).__init__()
self.seq_len = configs.seq_len
self.pred_len = configs.pred_len
self.k = configs.top_k
# parameter-efficient design
self.conv = nn.Sequential(
Inception_Block_V1(configs.d_model, configs.d_ff,
num_kernels=configs.num_kernels),
nn.GELU(),
Inception_Block_V1(configs.d_ff, configs.d_model,
num_kernels=configs.num_kernels)
)
def forward(self, x):
B, T, N = x.size()
period_list, period_weight = FFT_for_Period(x, self.k)
res = []
for i in range(self.k):
period = period_list[i]
# padding
if (self.seq_len + self.pred_len) % period != 0:
length = (
((self.seq_len + self.pred_len) // period) + 1) * period
padding = torch.zeros([x.shape[0], (length - (self.seq_len + self.pred_len)), x.shape[2]]).to(x.device)
out = torch.cat([x, padding], dim=1)
else:
length = (self.seq_len + self.pred_len)
out = x
# reshape
out = out.reshape(B, length // period, period,
N).permute(0, 3, 1, 2).contiguous()
# 2D conv: from 1d Variation to 2d Variation
out = self.conv(out)
# reshape back
out = out.permute(0, 2, 3, 1).reshape(B, -1, N)
res.append(out[:, :(self.seq_len + self.pred_len), :])
res = torch.stack(res, dim=-1)
# adaptive aggregation
period_weight = F.softmax(period_weight, dim=1)
period_weight = period_weight.unsqueeze(
1).unsqueeze(1).repeat(1, T, N, 1)
res = torch.sum(res * period_weight, -1)
# residual connection
res = res + x
return res
class Model(nn.Module):
"""
TimesNet with Q matrix transformation
- Applies input Q matrix transformation before embedding
- Uses original TimesBlock logic
- Applies output Q matrix transformation before De-Normalization
Only implements long/short term forecasting
"""
def __init__(self, configs):
super(Model, self).__init__()
self.configs = configs
self.task_name = configs.task_name
self.seq_len = configs.seq_len
self.label_len = configs.label_len
self.pred_len = configs.pred_len
# Load Q matrices
self.load_Q_matrices(configs)
# Model layers
self.model = nn.ModuleList([TimesBlock(configs)
for _ in range(configs.e_layers)])
self.enc_embedding = DataEmbedding(configs.enc_in, configs.d_model, configs.embed, configs.freq,
configs.dropout)
self.layer = configs.e_layers
self.layer_norm = nn.LayerNorm(configs.d_model)
# Only implement forecast-related layers
self.predict_linear = nn.Linear(
self.seq_len, self.pred_len + self.seq_len)
self.projection = nn.Linear(
configs.d_model, configs.c_out, bias=True)
def load_Q_matrices(self, configs):
"""Load pre-computed Q matrices for input and output transformations"""
# Get dataset name from configs, default to ETTm1 if not specified
dataset_name = getattr(configs, 'dataset', 'ETTm1')
# Input Q matrix (seq_len)
input_q_path = f'cov_mats/{dataset_name}/{dataset_name}_{configs.seq_len}_ratio1.0.npy'
# Output Q matrix (pred_len)
output_q_path = f'cov_mats/{dataset_name}/{dataset_name}_{configs.pred_len}_ratio1.0.npy'
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if os.path.exists(input_q_path):
Q_input = np.load(input_q_path)
self.register_buffer('Q_input', torch.FloatTensor(Q_input).to(device))
print(f"Loaded input Q matrix from {input_q_path}, shape: {Q_input.shape}")
else:
print(f"Warning: Input Q matrix not found at {input_q_path}, using identity matrix")
self.register_buffer('Q_input', torch.eye(configs.seq_len).to(device))
if os.path.exists(output_q_path):
Q_output = np.load(output_q_path)
self.register_buffer('Q_output', torch.FloatTensor(Q_output).to(device))
print(f"Loaded output Q matrix from {output_q_path}, shape: {Q_output.shape}")
else:
print(f"Warning: Output Q matrix not found at {output_q_path}, using identity matrix")
self.register_buffer('Q_output', torch.eye(configs.pred_len).to(device))
def apply_input_Q_transformation(self, x):
"""
Apply input Q matrix transformation before embedding
Input: x with shape [B, T, N] where T = seq_len
Output: transformed x with shape [B, T, N]
"""
B, T, N = x.size()
# Transpose to [B, N, T] for matrix multiplication
x_transposed = x.transpose(-1, -2) # [B, N, T]
# Apply input Q transformation: einsum 'bnt,tv->bnv'
# x_transposed: [B, N, T], Q_input.T: [T, T] -> result: [B, N, T]
x_trans = torch.einsum('bnt,tv->bnv', x_transposed, self.Q_input.transpose(-1, -2))
# Transpose back to [B, T, N]
x_transformed = x_trans.transpose(-1, -2) # [B, T, N]
return x_transformed
def apply_output_Q_transformation(self, x):
"""
Apply output Q matrix transformation to prediction output
Input: x with shape [B, pred_len, N]
Output: transformed x with shape [B, pred_len, N]
"""
B, T, N = x.size()
assert T == self.pred_len, f"Expected pred_len {self.pred_len}, got {T}"
# Transpose to [B, N, T] for matrix multiplication
x_transposed = x.transpose(-1, -2) # [B, N, pred_len]
# Apply output Q transformation: einsum 'bnt,tv->bnv'
# x_transposed: [B, N, pred_len], Q_output: [pred_len, pred_len] -> result: [B, N, pred_len]
x_trans = torch.einsum('bnt,tv->bnv', x_transposed, self.Q_output)
# Transpose back to [B, pred_len, N]
x_transformed = x_trans.transpose(-1, -2) # [B, pred_len, N]
return x_transformed
def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
# Normalization from Non-stationary Transformer
means = x_enc.mean(1, keepdim=True).detach()
x_enc = x_enc.sub(means)
stdev = torch.sqrt(
torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
x_enc = x_enc.div(stdev)
# Apply input Q matrix transformation before embedding
x_enc_transformed = self.apply_input_Q_transformation(x_enc)
# embedding with transformed input
enc_out = self.enc_embedding(x_enc_transformed, x_mark_enc) # [B,T,C]
enc_out = self.predict_linear(enc_out.permute(0, 2, 1)).permute(
0, 2, 1) # align temporal dimension
# TimesNet blocks (original logic, no Q transformation)
for i in range(self.layer):
enc_out = self.layer_norm(self.model[i](enc_out))
# project back
dec_out = self.projection(enc_out)
# Extract prediction part and apply output Q transformation
pred_out = dec_out[:, -self.pred_len:, :] # [B, pred_len, N]
pred_out_transformed = self.apply_output_Q_transformation(pred_out)
# De-Normalization from Non-stationary Transformer
pred_out_transformed = pred_out_transformed.mul(
(stdev[:, 0, :].unsqueeze(1).repeat(
1, self.pred_len, 1)))
pred_out_transformed = pred_out_transformed.add(
(means[:, 0, :].unsqueeze(1).repeat(
1, self.pred_len, 1)))
return pred_out_transformed
def forward(self, x_enc, x_mark_enc=None, x_dec=None, x_mark_dec=None, mask=None):
# Only support long_term_forecast and short_term_forecast
if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
return dec_out # [B, pred_len, D]
else:
raise NotImplementedError(f"Task {self.task_name} is not implemented in TimesNet_Q")
return None