Files
TSlib/models/Koopa.py
2025-08-28 10:17:59 +00:00

338 lines
13 KiB
Python

import math
import torch
import torch.nn as nn
from data_provider.data_factory import data_provider
class FourierFilter(nn.Module):
    """
    Fourier Filter: split a series into a time-variant and a time-invariant term.

    The real FFT is taken along dim 1, the frequency bins listed in
    ``mask_spectrum`` are zeroed, and the inverse transform of what remains is
    returned as the time-variant term. The time-invariant term is whatever the
    input has left after subtracting it, so the two always sum to the input.
    """

    def __init__(self, mask_spectrum):
        """
        mask_spectrum: indices of the rfft bins (along dim 1) to remove from
            the time-variant term.
        """
        super(FourierFilter, self).__init__()
        self.mask_spectrum = mask_spectrum

    def forward(self, x):
        """
        x: 3-D tensor, filtered along dim 1.

        Returns (x_var, x_inv), both shaped like ``x``.
        """
        xf = torch.fft.rfft(x, dim=1)
        mask = torch.ones_like(xf)
        mask[:, self.mask_spectrum, :] = 0
        # Pass the original length explicitly: without ``n`` irfft always
        # produces an even-length output, which is one step short for
        # odd-length inputs and makes the subtraction below fail.
        x_var = torch.fft.irfft(xf * mask, n=x.shape[1], dim=1)
        x_inv = x - x_var
        return x_var, x_inv
class MLP(nn.Module):
    '''
    Multilayer perceptron used as encoder/decoder for the high-dimensional
    representation of sequential data.

    The network is Linear(f_in, hidden_dim) -> activation -> Dropout, followed
    by ``hidden_layers - 2`` further Linear/activation/Dropout groups of width
    hidden_dim, and a final Linear(hidden_dim, f_out). One activation module
    instance is reused at every position.
    '''

    def __init__(self,
                 f_in,
                 f_out,
                 hidden_dim=128,
                 hidden_layers=2,
                 dropout=0.05,
                 activation='tanh'):
        super(MLP, self).__init__()
        self.f_in = f_in
        self.f_out = f_out
        self.hidden_dim = hidden_dim
        self.hidden_layers = hidden_layers
        self.dropout = dropout

        # Only 'relu' and 'tanh' are supported; anything else is rejected.
        for name, factory in (('relu', nn.ReLU), ('tanh', nn.Tanh)):
            if activation == name:
                self.activation = factory()
                break
        else:
            raise NotImplementedError

        # Input projection.
        stack = [nn.Linear(f_in, hidden_dim), self.activation, nn.Dropout(dropout)]
        # Extra hidden groups (none for the default hidden_layers=2).
        for _ in range(hidden_layers - 2):
            stack.extend((nn.Linear(hidden_dim, hidden_dim),
                          self.activation,
                          nn.Dropout(dropout)))
        # Output projection.
        stack.append(nn.Linear(hidden_dim, f_out))
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        # x: B x S x f_in  ->  B x S x f_out
        return self.layers(x)
class KPLayer(nn.Module):
    """
    A demonstration of finding one step transition of linear system by DMD iteratively

    The operator K is the least-squares solution of ``z[:, :-1] @ K = z[:, 1:]``
    and is re-estimated on every call; the latest estimate is kept in ``self.K``.
    """

    def __init__(self):
        super(KPLayer, self).__init__()
        self.K = None  # B E E

    def one_step_forward(self, z, return_rec=False, return_K=False):
        """
        Fit K on the snapshots in ``z`` and advance the last snapshot one step.

        z: B L E tensor of snapshots, L > 1.
        return_rec: when True, also return the one-step reconstruction of ``z``.
        return_K: accepted for interface compatibility; currently ignored.

        Returns z_pred (B 1 E), or (z_rec, z_pred) with z_rec of shape B L E
        when ``return_rec`` is True.
        """
        B, input_len, E = z.shape
        assert input_len > 1, 'snapshots number should be larger than 1'
        x, y = z[:, :-1], z[:, 1:]

        # solve linear system
        self.K = torch.linalg.lstsq(x, y).solution  # B E E
        if torch.isnan(self.K).any():
            print('Encounter K with nan, replace K by identity matrix')
            # Build the fallback with K's dtype and device so the bmm calls
            # below keep working for non-float32 inputs.
            self.K = torch.eye(E, dtype=self.K.dtype, device=self.K.device).unsqueeze(0).repeat(B, 1, 1)

        z_pred = torch.bmm(z[:, -1:], self.K)
        if return_rec:
            # The first snapshot has no predecessor, so it is copied through.
            z_rec = torch.cat((z[:, :1], torch.bmm(x, self.K)), dim=1)
            return z_rec, z_pred
        return z_pred

    def forward(self, z, pred_len=1):
        """
        z: B L E tensor of snapshots.
        pred_len: number of future snapshots to roll out, >= 1.

        Returns (z_rec, z_preds) of shapes B L E and B pred_len E.
        """
        assert pred_len >= 1, 'prediction length should not be less than 1'
        z_rec, z_pred = self.one_step_forward(z, return_rec=True)
        z_preds = [z_pred]
        # Roll the fitted operator forward one step at a time.
        for _ in range(pred_len - 1):
            z_pred = torch.bmm(z_pred, self.K)
            z_preds.append(z_pred)
        z_preds = torch.cat(z_preds, dim=1)
        return z_rec, z_preds
class KPLayerApprox(nn.Module):
    """
    Find koopman transition of linear system by DMD with multistep K approximation

    Instead of rolling K forward one step at a time, a matrix power of K is
    applied to a whole window of snapshots at once.
    """

    def __init__(self):
        super(KPLayerApprox, self).__init__()
        self.K = None  # B E E
        self.K_step = None  # B E E

    @staticmethod
    def _nan_to_identity(K, message):
        """Return K unchanged, or a batch of identities (same shape, dtype, device) if K has any NaN."""
        if torch.isnan(K).any():
            print(message)
            K = torch.eye(K.shape[1], dtype=K.dtype, device=K.device).unsqueeze(0).repeat(K.shape[0], 1, 1)
        return K

    def forward(self, z, pred_len=1):
        """
        z: B L E, koopman invariance space representation, L > 1.
        pred_len: number of future snapshots to produce.

        Returns:
            z_rec: B L E, reconstructed representation
            z_pred: B S E, forecasting representation (S = pred_len)
        """
        B, input_len, E = z.shape
        assert input_len > 1, 'snapshots number should be larger than 1'
        x, y = z[:, :-1], z[:, 1:]

        # solve linear system
        self.K = self._nan_to_identity(
            torch.linalg.lstsq(x, y).solution,
            'Encounter K with nan, replace K by identity matrix')  # B E E

        # The first snapshot has no predecessor, so it is copied through.
        z_rec = torch.cat((z[:, :1], torch.bmm(x, self.K)), dim=1)  # B L E

        # One application of K_step jumps min(pred_len, input_len) steps ahead.
        self.K_step = self._nan_to_identity(
            torch.linalg.matrix_power(self.K, min(pred_len, input_len)),
            'Encounter multistep K with nan, replace it by identity matrix')

        if pred_len <= input_len:
            # Shift the last pred_len snapshots forward by pred_len steps.
            z_pred = torch.bmm(z[:, -pred_len:, :], self.K_step)
        else:
            # Shift the whole window by input_len steps repeatedly until at
            # least pred_len snapshots exist, then trim.
            temp_z_pred, all_pred = z, []
            for _ in range(math.ceil(pred_len / input_len)):
                temp_z_pred = torch.bmm(temp_z_pred, self.K_step)
                all_pred.append(temp_z_pred)
            z_pred = torch.cat(all_pred, dim=1)[:, :pred_len, :]
        return z_rec, z_pred
class TimeVarKP(nn.Module):
    """
    Koopman Predictor with DMD (analytical solution of Koopman operator)
    Utilize local variations within individual sliding window to predict the future of time-variant term

    The lookback window is cut into segments of seg_len steps, each segment is
    flattened and encoded, a DMD layer reconstructs and extrapolates the
    segment sequence, and the decoder maps the results back to series.
    """

    def __init__(self,
                 enc_in=8,
                 input_len=96,
                 pred_len=96,
                 seg_len=24,
                 dynamic_dim=128,
                 encoder=None,
                 decoder=None,
                 multistep=False,
                 ):
        super(TimeVarKP, self).__init__()
        self.input_len = input_len
        self.pred_len = pred_len
        self.enc_in = enc_in
        self.seg_len = seg_len
        self.dynamic_dim = dynamic_dim
        self.multistep = multistep
        self.encoder = encoder
        self.decoder = decoder

        # Number of segments covering the input / the forecast horizon.
        self.freq = math.ceil(self.input_len / self.seg_len)
        self.step = math.ceil(self.pred_len / self.seg_len)
        # Steps needed in front of the input to fill a whole number of segments.
        self.padding_len = self.seg_len * self.freq - self.input_len

        # Approximate multistep K by KPLayerApprox when pred_len is large
        if self.multistep:
            self.dynamics = KPLayerApprox()
        else:
            self.dynamics = KPLayer()

    def _segments_to_series(self, decoded, batch, n_seg, length):
        """Unflatten decoded segments (B n_seg P*C) into a series and keep the first ``length`` steps."""
        series = decoded.reshape(batch, n_seg, self.seg_len, self.enc_in)
        series = series.reshape(batch, -1, self.enc_in)
        return series[:, :length, :]

    def forward(self, x):
        """
        x: B L C

        Returns (x_rec, x_pred) of shapes B L C and B S C, where L is input_len
        and S is pred_len.
        """
        batch, length, _ = x.shape

        # Front-pad with the tail of the window so it splits into freq segments.
        # NOTE(review): x_rec below keeps the first input_len steps of the
        # padded sequence, so when padding_len > 0 it appears to be shifted
        # relative to x -- confirm whether that is intended.
        head = x[:, length - self.padding_len:, :]
        padded = torch.cat((head, x), dim=1)
        pieces = padded.chunk(self.freq, dim=1)  # F x B P C, P means seg_len
        tokens = torch.stack(pieces, dim=1).reshape(batch, self.freq, -1)  # B F PC

        latent = self.encoder(tokens)  # B F H
        latent_rec, latent_pred = self.dynamics(latent, self.step)  # B F H, B S H

        # Reconstruction is decoded before the forecast.
        x_rec = self._segments_to_series(self.decoder(latent_rec), batch, self.freq, self.input_len)  # B L C
        x_pred = self._segments_to_series(self.decoder(latent_pred), batch, self.step, self.pred_len)  # B S C
        return x_rec, x_pred
class TimeInvKP(nn.Module):
    """
    Koopman Predictor with learnable Koopman operator
    Utilize lookback and forecast window snapshots to predict the future of time-invariant term

    The operator is a bias-free linear layer of size dynamic_dim x dynamic_dim
    applied between the given encoder and decoder.
    """

    def __init__(self,
                 input_len=96,
                 pred_len=96,
                 dynamic_dim=128,
                 encoder=None,
                 decoder=None):
        super(TimeInvKP, self).__init__()
        self.dynamic_dim = dynamic_dim
        self.input_len = input_len
        self.pred_len = pred_len
        self.encoder = encoder
        self.decoder = decoder

        # Stable initialization: start from the orthogonal factor U @ Vh of a
        # random matrix. torch.linalg.svd replaces the deprecated torch.svd
        # and returns Vh directly, so no transpose is needed.
        K_init = torch.randn(self.dynamic_dim, self.dynamic_dim)
        U, _, Vh = torch.linalg.svd(K_init)
        self.K = nn.Linear(self.dynamic_dim, self.dynamic_dim, bias=False)
        with torch.no_grad():
            self.K.weight.copy_(torch.mm(U, Vh))

    def forward(self, x):
        """
        x: B L C

        Returns the forecast, B S C.
        """
        res = x.transpose(1, 2)  # B C L
        res = self.encoder(res)  # B C H
        res = self.K(res)  # B C H
        res = self.decoder(res)  # B C S
        res = res.transpose(1, 2)  # B S C
        return res
class Model(nn.Module):
    '''
    Paper link: https://arxiv.org/pdf/2305.18803.pdf
    '''

    def __init__(self, configs, dynamic_dim=128, hidden_dim=64, hidden_layers=2, num_blocks=3, multistep=False,
                 alpha=0.2):
        """
        configs: run configuration; task_name, enc_in, seq_len and pred_len are
            read here, and the object is passed on to data_provider
        dynamic_dim: int, latent dimension of koopman embedding
        hidden_dim: int, hidden dimension of en/decoder
        hidden_layers: int, number of hidden layers of en/decoder
        num_blocks: int, number of Koopa blocks
        multistep: bool, whether to use approximation for multistep K
        alpha: float, spectrum filter ratio

        The segment length of the time-variant predictors is set to pred_len.
        """
        super(Model, self).__init__()
        self.task_name = configs.task_name
        self.enc_in = configs.enc_in
        self.input_len = configs.seq_len
        self.pred_len = configs.pred_len

        self.seg_len = self.pred_len
        self.num_blocks = num_blocks
        self.dynamic_dim = dynamic_dim
        self.hidden_dim = hidden_dim
        self.hidden_layers = hidden_layers
        self.multistep = multistep
        self.alpha = alpha
        # Shared frequency spectrums, estimated from the training data.
        self.mask_spectrum = self._get_mask_spectrum(configs)

        self.disentanglement = FourierFilter(self.mask_spectrum)

        # shared encoder/decoder to make koopman embedding consistent
        self.time_inv_encoder = MLP(f_in=self.input_len, f_out=self.dynamic_dim, activation='relu',
                                    hidden_dim=self.hidden_dim, hidden_layers=self.hidden_layers)
        self.time_inv_decoder = MLP(f_in=self.dynamic_dim, f_out=self.pred_len, activation='relu',
                                    hidden_dim=self.hidden_dim, hidden_layers=self.hidden_layers)
        self.time_inv_kps = nn.ModuleList([
            TimeInvKP(input_len=self.input_len,
                      pred_len=self.pred_len,
                      dynamic_dim=self.dynamic_dim,
                      encoder=self.time_inv_encoder,
                      decoder=self.time_inv_decoder)
            for _ in range(self.num_blocks)])

        # shared encoder/decoder to make koopman embedding consistent
        self.time_var_encoder = MLP(f_in=self.seg_len * self.enc_in, f_out=self.dynamic_dim, activation='tanh',
                                    hidden_dim=self.hidden_dim, hidden_layers=self.hidden_layers)
        self.time_var_decoder = MLP(f_in=self.dynamic_dim, f_out=self.seg_len * self.enc_in, activation='tanh',
                                    hidden_dim=self.hidden_dim, hidden_layers=self.hidden_layers)
        self.time_var_kps = nn.ModuleList([
            TimeVarKP(enc_in=configs.enc_in,
                      input_len=self.input_len,
                      pred_len=self.pred_len,
                      seg_len=self.seg_len,
                      dynamic_dim=self.dynamic_dim,
                      encoder=self.time_var_encoder,
                      decoder=self.time_var_decoder,
                      multistep=self.multistep)
            for _ in range(self.num_blocks)])

    def _get_mask_spectrum(self, configs):
        """
        get shared frequency spectrums

        Accumulates the rfft amplitude of every training lookback window
        (averaged over dims 0 and 2 of each batch) and returns the indices of
        the top ``alpha`` fraction of frequency bins.

        Raises ValueError when the training loader yields no batches.
        """
        _, train_loader = data_provider(configs, 'train')
        amps = None
        for data in train_loader:
            lookback_window = data[0]
            batch_amps = torch.fft.rfft(lookback_window, dim=1).abs().mean(dim=0).mean(dim=1)
            amps = batch_amps if amps is None else amps + batch_amps
        if amps is None:
            raise ValueError('training loader is empty, cannot estimate the shared frequency spectrum')
        mask_spectrum = amps.topk(int(amps.shape[0] * self.alpha)).indices
        return mask_spectrum  # as the spectrums of time-invariant component

    def forecast(self, x_enc):
        """
        x_enc: B L C lookback window.

        Returns the forecast in the original scale of ``x_enc``.
        """
        # Series Stationarization adopted from NSformer
        mean_enc = x_enc.mean(1, keepdim=True).detach()  # B x 1 x E
        x_enc = x_enc - mean_enc
        std_enc = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5).detach()
        x_enc = x_enc / std_enc

        # Koopman Forecasting
        residual, forecast = x_enc, None
        for i in range(self.num_blocks):
            time_var_input, time_inv_input = self.disentanglement(residual)
            time_inv_output = self.time_inv_kps[i](time_inv_input)
            time_var_backcast, time_var_output = self.time_var_kps[i](time_var_input)
            # Each block only sees what earlier blocks failed to reconstruct.
            residual = residual - time_var_backcast
            block_forecast = time_inv_output + time_var_output
            forecast = block_forecast if forecast is None else forecast + block_forecast

        # Series Stationarization adopted from NSformer
        res = forecast * std_enc + mean_enc
        return res

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
        """
        Only ``x_enc`` is used. Returns the last pred_len steps of the forecast
        for the 'long_term_forecast' task and None for any other task.
        """
        if self.task_name == 'long_term_forecast':
            dec_out = self.forecast(x_enc)
            return dec_out[:, -self.pred_len:, :]  # [B, L, D]
        return None