feat(model): add initial PatchTST model architecture and utilities
This commit is contained in:
379
layers/PatchTST_backbone.py
Normal file
379
layers/PatchTST_backbone.py
Normal file
@ -0,0 +1,379 @@
|
||||
__all__ = ['PatchTST_backbone']
|
||||
|
||||
# Cell
|
||||
from typing import Callable, Optional
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch import Tensor
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
|
||||
#from collections import OrderedDict
|
||||
from layers.PatchTST_layers import *
|
||||
from layers.revin import RevIN
|
||||
|
||||
# Cell
|
||||
class PatchTST_backbone(nn.Module):
|
||||
def __init__(self, c_in:int, context_window:int, target_window:int, patch_len:int, stride:int, max_seq_len:Optional[int]=1024,
|
||||
n_layers:int=3, d_model=128, n_heads=16, d_k:Optional[int]=None, d_v:Optional[int]=None,
|
||||
d_ff:int=256, norm:str='BatchNorm', attn_dropout:float=0., dropout:float=0., act:str="gelu", key_padding_mask:bool='auto',
|
||||
padding_var:Optional[int]=None, attn_mask:Optional[Tensor]=None, res_attention:bool=True, pre_norm:bool=False, store_attn:bool=False,
|
||||
pe:str='zeros', learn_pe:bool=True, fc_dropout:float=0., head_dropout = 0, padding_patch = None,
|
||||
pretrain_head:bool=False, head_type = 'flatten', individual = False, revin = True, affine = True, subtract_last = False,
|
||||
verbose:bool=False, **kwargs):
|
||||
|
||||
super().__init__()
|
||||
|
||||
# RevIn
|
||||
self.revin = revin
|
||||
if self.revin: self.revin_layer = RevIN(c_in, affine=affine, subtract_last=subtract_last)
|
||||
|
||||
# Patching
|
||||
self.patch_len = patch_len
|
||||
self.stride = stride
|
||||
self.padding_patch = padding_patch
|
||||
patch_num = int((context_window - patch_len)/stride + 1)
|
||||
if padding_patch == 'end': # can be modified to general case
|
||||
self.padding_patch_layer = nn.ReplicationPad1d((0, stride))
|
||||
patch_num += 1
|
||||
|
||||
# Backbone
|
||||
self.backbone = TSTiEncoder(c_in, patch_num=patch_num, patch_len=patch_len, max_seq_len=max_seq_len,
|
||||
n_layers=n_layers, d_model=d_model, n_heads=n_heads, d_k=d_k, d_v=d_v, d_ff=d_ff,
|
||||
attn_dropout=attn_dropout, dropout=dropout, act=act, key_padding_mask=key_padding_mask, padding_var=padding_var,
|
||||
attn_mask=attn_mask, res_attention=res_attention, pre_norm=pre_norm, store_attn=store_attn,
|
||||
pe=pe, learn_pe=learn_pe, verbose=verbose, **kwargs)
|
||||
|
||||
# Head
|
||||
self.head_nf = d_model * patch_num
|
||||
self.n_vars = c_in
|
||||
self.pretrain_head = pretrain_head
|
||||
self.head_type = head_type
|
||||
self.individual = individual
|
||||
|
||||
if self.pretrain_head:
|
||||
self.head = self.create_pretrain_head(self.head_nf, c_in, fc_dropout) # custom head passed as a partial func with all its kwargs
|
||||
elif head_type == 'flatten':
|
||||
self.head = Flatten_Head(self.individual, self.n_vars, self.head_nf, target_window, head_dropout=head_dropout)
|
||||
|
||||
|
||||
def forward(self, z): # z: [bs x nvars x seq_len]
|
||||
# norm
|
||||
if self.revin:
|
||||
z = z.permute(0,2,1)
|
||||
z = self.revin_layer(z, 'norm')
|
||||
z = z.permute(0,2,1)
|
||||
|
||||
# do patching
|
||||
if self.padding_patch == 'end':
|
||||
z = self.padding_patch_layer(z)
|
||||
z = z.unfold(dimension=-1, size=self.patch_len, step=self.stride) # z: [bs x nvars x patch_num x patch_len]
|
||||
z = z.permute(0,1,3,2) # z: [bs x nvars x patch_len x patch_num]
|
||||
|
||||
# model
|
||||
z = self.backbone(z) # z: [bs x nvars x d_model x patch_num]
|
||||
z = self.head(z) # z: [bs x nvars x target_window]
|
||||
|
||||
# denorm
|
||||
if self.revin:
|
||||
z = z.permute(0,2,1)
|
||||
z = self.revin_layer(z, 'denorm')
|
||||
z = z.permute(0,2,1)
|
||||
return z
|
||||
|
||||
def create_pretrain_head(self, head_nf, vars, dropout):
|
||||
return nn.Sequential(nn.Dropout(dropout),
|
||||
nn.Conv1d(head_nf, vars, 1)
|
||||
)
|
||||
|
||||
|
||||
class Flatten_Head(nn.Module):
|
||||
def __init__(self, individual, n_vars, nf, target_window, head_dropout=0):
|
||||
super().__init__()
|
||||
|
||||
self.individual = individual
|
||||
self.n_vars = n_vars
|
||||
|
||||
if self.individual:
|
||||
self.linears = nn.ModuleList()
|
||||
self.dropouts = nn.ModuleList()
|
||||
self.flattens = nn.ModuleList()
|
||||
for i in range(self.n_vars):
|
||||
self.flattens.append(nn.Flatten(start_dim=-2))
|
||||
self.linears.append(nn.Linear(nf, target_window))
|
||||
self.dropouts.append(nn.Dropout(head_dropout))
|
||||
else:
|
||||
self.flatten = nn.Flatten(start_dim=-2)
|
||||
self.linear = nn.Linear(nf, target_window)
|
||||
self.dropout = nn.Dropout(head_dropout)
|
||||
|
||||
def forward(self, x): # x: [bs x nvars x d_model x patch_num]
|
||||
if self.individual:
|
||||
x_out = []
|
||||
for i in range(self.n_vars):
|
||||
z = self.flattens[i](x[:,i,:,:]) # z: [bs x d_model * patch_num]
|
||||
z = self.linears[i](z) # z: [bs x target_window]
|
||||
z = self.dropouts[i](z)
|
||||
x_out.append(z)
|
||||
x = torch.stack(x_out, dim=1) # x: [bs x nvars x target_window]
|
||||
else:
|
||||
x = self.flatten(x)
|
||||
x = self.linear(x)
|
||||
x = self.dropout(x)
|
||||
return x
|
||||
|
||||
|
||||
|
||||
|
||||
class TSTiEncoder(nn.Module): #i means channel-independent
|
||||
def __init__(self, c_in, patch_num, patch_len, max_seq_len=1024,
|
||||
n_layers=3, d_model=128, n_heads=16, d_k=None, d_v=None,
|
||||
d_ff=256, norm='BatchNorm', attn_dropout=0., dropout=0., act="gelu", store_attn=False,
|
||||
key_padding_mask='auto', padding_var=None, attn_mask=None, res_attention=True, pre_norm=False,
|
||||
pe='zeros', learn_pe=True, verbose=False, **kwargs):
|
||||
|
||||
|
||||
super().__init__()
|
||||
|
||||
self.patch_num = patch_num
|
||||
self.patch_len = patch_len
|
||||
|
||||
# Input encoding
|
||||
q_len = patch_num
|
||||
self.W_P = nn.Linear(patch_len, d_model) # Eq 1: projection of feature vectors onto a d-dim vector space
|
||||
self.seq_len = q_len
|
||||
|
||||
# Positional encoding
|
||||
self.W_pos = positional_encoding(pe, learn_pe, q_len, d_model)
|
||||
|
||||
# Residual dropout
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
# Encoder
|
||||
self.encoder = TSTEncoder(q_len, d_model, n_heads, d_k=d_k, d_v=d_v, d_ff=d_ff, norm=norm, attn_dropout=attn_dropout, dropout=dropout,
|
||||
pre_norm=pre_norm, activation=act, res_attention=res_attention, n_layers=n_layers, store_attn=store_attn)
|
||||
|
||||
|
||||
def forward(self, x) -> Tensor: # x: [bs x nvars x patch_len x patch_num]
|
||||
|
||||
n_vars = x.shape[1]
|
||||
# Input encoding
|
||||
x = x.permute(0,1,3,2) # x: [bs x nvars x patch_num x patch_len]
|
||||
x = self.W_P(x) # x: [bs x nvars x patch_num x d_model]
|
||||
|
||||
u = torch.reshape(x, (x.shape[0]*x.shape[1],x.shape[2],x.shape[3])) # u: [bs * nvars x patch_num x d_model]
|
||||
u = self.dropout(u + self.W_pos) # u: [bs * nvars x patch_num x d_model]
|
||||
|
||||
# Encoder
|
||||
z = self.encoder(u) # z: [bs * nvars x patch_num x d_model]
|
||||
z = torch.reshape(z, (-1,n_vars,z.shape[-2],z.shape[-1])) # z: [bs x nvars x patch_num x d_model]
|
||||
z = z.permute(0,1,3,2) # z: [bs x nvars x d_model x patch_num]
|
||||
|
||||
return z
|
||||
|
||||
|
||||
|
||||
# Cell
|
||||
class TSTEncoder(nn.Module):
|
||||
def __init__(self, q_len, d_model, n_heads, d_k=None, d_v=None, d_ff=None,
|
||||
norm='BatchNorm', attn_dropout=0., dropout=0., activation='gelu',
|
||||
res_attention=False, n_layers=1, pre_norm=False, store_attn=False):
|
||||
super().__init__()
|
||||
|
||||
self.layers = nn.ModuleList([TSTEncoderLayer(q_len, d_model, n_heads=n_heads, d_k=d_k, d_v=d_v, d_ff=d_ff, norm=norm,
|
||||
attn_dropout=attn_dropout, dropout=dropout,
|
||||
activation=activation, res_attention=res_attention,
|
||||
pre_norm=pre_norm, store_attn=store_attn) for i in range(n_layers)])
|
||||
self.res_attention = res_attention
|
||||
|
||||
def forward(self, src:Tensor, key_padding_mask:Optional[Tensor]=None, attn_mask:Optional[Tensor]=None):
|
||||
output = src
|
||||
scores = None
|
||||
if self.res_attention:
|
||||
for mod in self.layers: output, scores = mod(output, prev=scores, key_padding_mask=key_padding_mask, attn_mask=attn_mask)
|
||||
return output
|
||||
else:
|
||||
for mod in self.layers: output = mod(output, key_padding_mask=key_padding_mask, attn_mask=attn_mask)
|
||||
return output
|
||||
|
||||
|
||||
|
||||
class TSTEncoderLayer(nn.Module):
|
||||
def __init__(self, q_len, d_model, n_heads, d_k=None, d_v=None, d_ff=256, store_attn=False,
|
||||
norm='BatchNorm', attn_dropout=0, dropout=0., bias=True, activation="gelu", res_attention=False, pre_norm=False):
|
||||
super().__init__()
|
||||
assert not d_model%n_heads, f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
|
||||
d_k = d_model // n_heads if d_k is None else d_k
|
||||
d_v = d_model // n_heads if d_v is None else d_v
|
||||
|
||||
# Multi-Head attention
|
||||
self.res_attention = res_attention
|
||||
self.self_attn = _MultiheadAttention(d_model, n_heads, d_k, d_v, attn_dropout=attn_dropout, proj_dropout=dropout, res_attention=res_attention)
|
||||
|
||||
# Add & Norm
|
||||
self.dropout_attn = nn.Dropout(dropout)
|
||||
if "batch" in norm.lower():
|
||||
self.norm_attn = nn.Sequential(Transpose(1,2), nn.BatchNorm1d(d_model), Transpose(1,2))
|
||||
else:
|
||||
self.norm_attn = nn.LayerNorm(d_model)
|
||||
|
||||
# Position-wise Feed-Forward
|
||||
self.ff = nn.Sequential(nn.Linear(d_model, d_ff, bias=bias),
|
||||
get_activation_fn(activation),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(d_ff, d_model, bias=bias))
|
||||
|
||||
# Add & Norm
|
||||
self.dropout_ffn = nn.Dropout(dropout)
|
||||
if "batch" in norm.lower():
|
||||
self.norm_ffn = nn.Sequential(Transpose(1,2), nn.BatchNorm1d(d_model), Transpose(1,2))
|
||||
else:
|
||||
self.norm_ffn = nn.LayerNorm(d_model)
|
||||
|
||||
self.pre_norm = pre_norm
|
||||
self.store_attn = store_attn
|
||||
|
||||
|
||||
def forward(self, src:Tensor, prev:Optional[Tensor]=None, key_padding_mask:Optional[Tensor]=None, attn_mask:Optional[Tensor]=None) -> Tensor:
|
||||
|
||||
# Multi-Head attention sublayer
|
||||
if self.pre_norm:
|
||||
src = self.norm_attn(src)
|
||||
## Multi-Head attention
|
||||
if self.res_attention:
|
||||
src2, attn, scores = self.self_attn(src, src, src, prev, key_padding_mask=key_padding_mask, attn_mask=attn_mask)
|
||||
else:
|
||||
src2, attn = self.self_attn(src, src, src, key_padding_mask=key_padding_mask, attn_mask=attn_mask)
|
||||
if self.store_attn:
|
||||
self.attn = attn
|
||||
## Add & Norm
|
||||
src = src + self.dropout_attn(src2) # Add: residual connection with residual dropout
|
||||
if not self.pre_norm:
|
||||
src = self.norm_attn(src)
|
||||
|
||||
# Feed-forward sublayer
|
||||
if self.pre_norm:
|
||||
src = self.norm_ffn(src)
|
||||
## Position-wise Feed-Forward
|
||||
src2 = self.ff(src)
|
||||
## Add & Norm
|
||||
src = src + self.dropout_ffn(src2) # Add: residual connection with residual dropout
|
||||
if not self.pre_norm:
|
||||
src = self.norm_ffn(src)
|
||||
|
||||
if self.res_attention:
|
||||
return src, scores
|
||||
else:
|
||||
return src
|
||||
|
||||
|
||||
|
||||
|
||||
class _MultiheadAttention(nn.Module):
|
||||
def __init__(self, d_model, n_heads, d_k=None, d_v=None, res_attention=False, attn_dropout=0., proj_dropout=0., qkv_bias=True, lsa=False):
|
||||
"""Multi Head Attention Layer
|
||||
Input shape:
|
||||
Q: [batch_size (bs) x max_q_len x d_model]
|
||||
K, V: [batch_size (bs) x q_len x d_model]
|
||||
mask: [q_len x q_len]
|
||||
"""
|
||||
super().__init__()
|
||||
d_k = d_model // n_heads if d_k is None else d_k
|
||||
d_v = d_model // n_heads if d_v is None else d_v
|
||||
|
||||
self.n_heads, self.d_k, self.d_v = n_heads, d_k, d_v
|
||||
|
||||
self.W_Q = nn.Linear(d_model, d_k * n_heads, bias=qkv_bias)
|
||||
self.W_K = nn.Linear(d_model, d_k * n_heads, bias=qkv_bias)
|
||||
self.W_V = nn.Linear(d_model, d_v * n_heads, bias=qkv_bias)
|
||||
|
||||
# Scaled Dot-Product Attention (multiple heads)
|
||||
self.res_attention = res_attention
|
||||
self.sdp_attn = _ScaledDotProductAttention(d_model, n_heads, attn_dropout=attn_dropout, res_attention=self.res_attention, lsa=lsa)
|
||||
|
||||
# Poject output
|
||||
self.to_out = nn.Sequential(nn.Linear(n_heads * d_v, d_model), nn.Dropout(proj_dropout))
|
||||
|
||||
|
||||
def forward(self, Q:Tensor, K:Optional[Tensor]=None, V:Optional[Tensor]=None, prev:Optional[Tensor]=None,
|
||||
key_padding_mask:Optional[Tensor]=None, attn_mask:Optional[Tensor]=None):
|
||||
|
||||
bs = Q.size(0)
|
||||
if K is None: K = Q
|
||||
if V is None: V = Q
|
||||
|
||||
# Linear (+ split in multiple heads)
|
||||
q_s = self.W_Q(Q).view(bs, -1, self.n_heads, self.d_k).transpose(1,2) # q_s : [bs x n_heads x max_q_len x d_k]
|
||||
k_s = self.W_K(K).view(bs, -1, self.n_heads, self.d_k).permute(0,2,3,1) # k_s : [bs x n_heads x d_k x q_len] - transpose(1,2) + transpose(2,3)
|
||||
v_s = self.W_V(V).view(bs, -1, self.n_heads, self.d_v).transpose(1,2) # v_s : [bs x n_heads x q_len x d_v]
|
||||
|
||||
# Apply Scaled Dot-Product Attention (multiple heads)
|
||||
if self.res_attention:
|
||||
output, attn_weights, attn_scores = self.sdp_attn(q_s, k_s, v_s, prev=prev, key_padding_mask=key_padding_mask, attn_mask=attn_mask)
|
||||
else:
|
||||
output, attn_weights = self.sdp_attn(q_s, k_s, v_s, key_padding_mask=key_padding_mask, attn_mask=attn_mask)
|
||||
# output: [bs x n_heads x q_len x d_v], attn: [bs x n_heads x q_len x q_len], scores: [bs x n_heads x max_q_len x q_len]
|
||||
|
||||
# back to the original inputs dimensions
|
||||
output = output.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * self.d_v) # output: [bs x q_len x n_heads * d_v]
|
||||
output = self.to_out(output)
|
||||
|
||||
if self.res_attention: return output, attn_weights, attn_scores
|
||||
else: return output, attn_weights
|
||||
|
||||
|
||||
class _ScaledDotProductAttention(nn.Module):
|
||||
r"""Scaled Dot-Product Attention module (Attention is all you need by Vaswani et al., 2017) with optional residual attention from previous layer
|
||||
(Realformer: Transformer likes residual attention by He et al, 2020) and locality self sttention (Vision Transformer for Small-Size Datasets
|
||||
by Lee et al, 2021)"""
|
||||
|
||||
def __init__(self, d_model, n_heads, attn_dropout=0., res_attention=False, lsa=False):
|
||||
super().__init__()
|
||||
self.attn_dropout = nn.Dropout(attn_dropout)
|
||||
self.res_attention = res_attention
|
||||
head_dim = d_model // n_heads
|
||||
self.scale = nn.Parameter(torch.tensor(head_dim ** -0.5), requires_grad=lsa)
|
||||
self.lsa = lsa
|
||||
|
||||
def forward(self, q:Tensor, k:Tensor, v:Tensor, prev:Optional[Tensor]=None, key_padding_mask:Optional[Tensor]=None, attn_mask:Optional[Tensor]=None):
|
||||
'''
|
||||
Input shape:
|
||||
q : [bs x n_heads x max_q_len x d_k]
|
||||
k : [bs x n_heads x d_k x seq_len]
|
||||
v : [bs x n_heads x seq_len x d_v]
|
||||
prev : [bs x n_heads x q_len x seq_len]
|
||||
key_padding_mask: [bs x seq_len]
|
||||
attn_mask : [1 x seq_len x seq_len]
|
||||
Output shape:
|
||||
output: [bs x n_heads x q_len x d_v]
|
||||
attn : [bs x n_heads x q_len x seq_len]
|
||||
scores : [bs x n_heads x q_len x seq_len]
|
||||
'''
|
||||
|
||||
# Scaled MatMul (q, k) - similarity scores for all pairs of positions in an input sequence
|
||||
attn_scores = torch.matmul(q, k) * self.scale # attn_scores : [bs x n_heads x max_q_len x q_len]
|
||||
|
||||
# Add pre-softmax attention scores from the previous layer (optional)
|
||||
if prev is not None: attn_scores = attn_scores + prev
|
||||
|
||||
# Attention mask (optional)
|
||||
if attn_mask is not None: # attn_mask with shape [q_len x seq_len] - only used when q_len == seq_len
|
||||
if attn_mask.dtype == torch.bool:
|
||||
attn_scores.masked_fill_(attn_mask, -np.inf)
|
||||
else:
|
||||
attn_scores += attn_mask
|
||||
|
||||
# Key padding mask (optional)
|
||||
if key_padding_mask is not None: # mask with shape [bs x q_len] (only when max_w_len == q_len)
|
||||
attn_scores.masked_fill_(key_padding_mask.unsqueeze(1).unsqueeze(2), -np.inf)
|
||||
|
||||
# normalize the attention weights
|
||||
attn_weights = F.softmax(attn_scores, dim=-1) # attn_weights : [bs x n_heads x max_q_len x q_len]
|
||||
attn_weights = self.attn_dropout(attn_weights)
|
||||
|
||||
# compute the new values given the attention weights
|
||||
output = torch.matmul(attn_weights, v) # output: [bs x n_heads x max_q_len x d_v]
|
||||
|
||||
if self.res_attention: return output, attn_weights, attn_scores
|
||||
else: return output, attn_weights
|
||||
|
121
layers/PatchTST_layers.py
Normal file
121
layers/PatchTST_layers.py
Normal file
@ -0,0 +1,121 @@
|
||||
__all__ = ['Transpose', 'get_activation_fn', 'moving_avg', 'series_decomp', 'PositionalEncoding', 'SinCosPosEncoding', 'Coord2dPosEncoding', 'Coord1dPosEncoding', 'positional_encoding']
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import math
|
||||
|
||||
class Transpose(nn.Module):
|
||||
def __init__(self, *dims, contiguous=False):
|
||||
super().__init__()
|
||||
self.dims, self.contiguous = dims, contiguous
|
||||
def forward(self, x):
|
||||
if self.contiguous: return x.transpose(*self.dims).contiguous()
|
||||
else: return x.transpose(*self.dims)
|
||||
|
||||
|
||||
def get_activation_fn(activation):
|
||||
if callable(activation): return activation()
|
||||
elif activation.lower() == "relu": return nn.ReLU()
|
||||
elif activation.lower() == "gelu": return nn.GELU()
|
||||
raise ValueError(f'{activation} is not available. You can use "relu", "gelu", or a callable')
|
||||
|
||||
|
||||
# decomposition
|
||||
|
||||
class moving_avg(nn.Module):
|
||||
"""
|
||||
Moving average block to highlight the trend of time series
|
||||
"""
|
||||
def __init__(self, kernel_size, stride):
|
||||
super(moving_avg, self).__init__()
|
||||
self.kernel_size = kernel_size
|
||||
self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0)
|
||||
|
||||
def forward(self, x):
|
||||
# padding on the both ends of time series
|
||||
front = x[:, 0:1, :].repeat(1, (self.kernel_size - 1) // 2, 1)
|
||||
end = x[:, -1:, :].repeat(1, (self.kernel_size - 1) // 2, 1)
|
||||
x = torch.cat([front, x, end], dim=1)
|
||||
x = self.avg(x.permute(0, 2, 1))
|
||||
x = x.permute(0, 2, 1)
|
||||
return x
|
||||
|
||||
|
||||
class series_decomp(nn.Module):
|
||||
"""
|
||||
Series decomposition block
|
||||
"""
|
||||
def __init__(self, kernel_size):
|
||||
super(series_decomp, self).__init__()
|
||||
self.moving_avg = moving_avg(kernel_size, stride=1)
|
||||
|
||||
def forward(self, x):
|
||||
moving_mean = self.moving_avg(x)
|
||||
res = x - moving_mean
|
||||
return res, moving_mean
|
||||
|
||||
|
||||
|
||||
# pos_encoding
|
||||
|
||||
def PositionalEncoding(q_len, d_model, normalize=True):
|
||||
pe = torch.zeros(q_len, d_model)
|
||||
position = torch.arange(0, q_len).unsqueeze(1)
|
||||
div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
|
||||
pe[:, 0::2] = torch.sin(position * div_term)
|
||||
pe[:, 1::2] = torch.cos(position * div_term)
|
||||
if normalize:
|
||||
pe = pe - pe.mean()
|
||||
pe = pe / (pe.std() * 10)
|
||||
return pe
|
||||
|
||||
SinCosPosEncoding = PositionalEncoding
|
||||
|
||||
def Coord2dPosEncoding(q_len, d_model, exponential=False, normalize=True, eps=1e-3, verbose=False):
|
||||
x = .5 if exponential else 1
|
||||
i = 0
|
||||
for i in range(100):
|
||||
cpe = 2 * (torch.linspace(0, 1, q_len).reshape(-1, 1) ** x) * (torch.linspace(0, 1, d_model).reshape(1, -1) ** x) - 1
|
||||
# pv(f'{i:4.0f} {x:5.3f} {cpe.mean():+6.3f}', verbose)
|
||||
if abs(cpe.mean()) <= eps: break
|
||||
elif cpe.mean() > eps: x += .001
|
||||
else: x -= .001
|
||||
i += 1
|
||||
if normalize:
|
||||
cpe = cpe - cpe.mean()
|
||||
cpe = cpe / (cpe.std() * 10)
|
||||
return cpe
|
||||
|
||||
def Coord1dPosEncoding(q_len, exponential=False, normalize=True):
|
||||
cpe = (2 * (torch.linspace(0, 1, q_len).reshape(-1, 1)**(.5 if exponential else 1)) - 1)
|
||||
if normalize:
|
||||
cpe = cpe - cpe.mean()
|
||||
cpe = cpe / (cpe.std() * 10)
|
||||
return cpe
|
||||
|
||||
def positional_encoding(pe, learn_pe, q_len, d_model):
|
||||
# Positional encoding
|
||||
if pe == None:
|
||||
W_pos = torch.empty((q_len, d_model)) # pe = None and learn_pe = False can be used to measure impact of pe
|
||||
nn.init.uniform_(W_pos, -0.02, 0.02)
|
||||
learn_pe = False
|
||||
elif pe == 'zero':
|
||||
W_pos = torch.empty((q_len, 1))
|
||||
nn.init.uniform_(W_pos, -0.02, 0.02)
|
||||
elif pe == 'zeros':
|
||||
W_pos = torch.empty((q_len, d_model))
|
||||
nn.init.uniform_(W_pos, -0.02, 0.02)
|
||||
elif pe == 'normal' or pe == 'gauss':
|
||||
W_pos = torch.zeros((q_len, 1))
|
||||
torch.nn.init.normal_(W_pos, mean=0.0, std=0.1)
|
||||
elif pe == 'uniform':
|
||||
W_pos = torch.zeros((q_len, 1))
|
||||
nn.init.uniform_(W_pos, a=0.0, b=0.1)
|
||||
elif pe == 'lin1d': W_pos = Coord1dPosEncoding(q_len, exponential=False, normalize=True)
|
||||
elif pe == 'exp1d': W_pos = Coord1dPosEncoding(q_len, exponential=True, normalize=True)
|
||||
elif pe == 'lin2d': W_pos = Coord2dPosEncoding(q_len, d_model, exponential=False, normalize=True)
|
||||
elif pe == 'exp2d': W_pos = Coord2dPosEncoding(q_len, d_model, exponential=True, normalize=True)
|
||||
elif pe == 'sincos': W_pos = PositionalEncoding(q_len, d_model, normalize=True)
|
||||
else: raise ValueError(f"{pe} is not a valid pe (positional encoder. Available types: 'gauss'=='normal', \
|
||||
'zeros', 'zero', uniform', 'lin1d', 'exp1d', 'lin2d', 'exp2d', 'sincos', None.)")
|
||||
return nn.Parameter(W_pos, requires_grad=learn_pe)
|
135
layers/Transformer_EncDec.py
Normal file
135
layers/Transformer_EncDec.py
Normal file
@ -0,0 +1,135 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class ConvLayer(nn.Module):
|
||||
def __init__(self, c_in):
|
||||
super(ConvLayer, self).__init__()
|
||||
self.downConv = nn.Conv1d(in_channels=c_in,
|
||||
out_channels=c_in,
|
||||
kernel_size=3,
|
||||
padding=2,
|
||||
padding_mode='circular')
|
||||
self.norm = nn.BatchNorm1d(c_in)
|
||||
self.activation = nn.ELU()
|
||||
self.maxPool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.downConv(x.permute(0, 2, 1))
|
||||
x = self.norm(x)
|
||||
x = self.activation(x)
|
||||
x = self.maxPool(x)
|
||||
x = x.transpose(1, 2)
|
||||
return x
|
||||
|
||||
|
||||
class EncoderLayer(nn.Module):
|
||||
def __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation="relu"):
|
||||
super(EncoderLayer, self).__init__()
|
||||
d_ff = d_ff or 4 * d_model
|
||||
self.attention = attention
|
||||
self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
|
||||
self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)
|
||||
self.norm1 = nn.LayerNorm(d_model)
|
||||
self.norm2 = nn.LayerNorm(d_model)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.activation = F.relu if activation == "relu" else F.gelu
|
||||
|
||||
def forward(self, x, attn_mask=None, tau=None, delta=None):
|
||||
new_x, attn = self.attention(
|
||||
x, x, x,
|
||||
attn_mask=attn_mask,
|
||||
tau=tau, delta=delta
|
||||
)
|
||||
x = x + self.dropout(new_x)
|
||||
|
||||
y = x = self.norm1(x)
|
||||
y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
|
||||
y = self.dropout(self.conv2(y).transpose(-1, 1))
|
||||
|
||||
return self.norm2(x + y), attn
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
def __init__(self, attn_layers, conv_layers=None, norm_layer=None):
|
||||
super(Encoder, self).__init__()
|
||||
self.attn_layers = nn.ModuleList(attn_layers)
|
||||
self.conv_layers = nn.ModuleList(conv_layers) if conv_layers is not None else None
|
||||
self.norm = norm_layer
|
||||
|
||||
def forward(self, x, attn_mask=None, tau=None, delta=None):
|
||||
# x [B, L, D]
|
||||
attns = []
|
||||
if self.conv_layers is not None:
|
||||
for i, (attn_layer, conv_layer) in enumerate(zip(self.attn_layers, self.conv_layers)):
|
||||
delta = delta if i == 0 else None
|
||||
x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta)
|
||||
x = conv_layer(x)
|
||||
attns.append(attn)
|
||||
x, attn = self.attn_layers[-1](x, tau=tau, delta=None)
|
||||
attns.append(attn)
|
||||
else:
|
||||
for attn_layer in self.attn_layers:
|
||||
x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta)
|
||||
attns.append(attn)
|
||||
|
||||
if self.norm is not None:
|
||||
x = self.norm(x)
|
||||
|
||||
return x, attns
|
||||
|
||||
|
||||
class DecoderLayer(nn.Module):
|
||||
def __init__(self, self_attention, cross_attention, d_model, d_ff=None,
|
||||
dropout=0.1, activation="relu"):
|
||||
super(DecoderLayer, self).__init__()
|
||||
d_ff = d_ff or 4 * d_model
|
||||
self.self_attention = self_attention
|
||||
self.cross_attention = cross_attention
|
||||
self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
|
||||
self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)
|
||||
self.norm1 = nn.LayerNorm(d_model)
|
||||
self.norm2 = nn.LayerNorm(d_model)
|
||||
self.norm3 = nn.LayerNorm(d_model)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.activation = F.relu if activation == "relu" else F.gelu
|
||||
|
||||
def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None):
|
||||
x = x + self.dropout(self.self_attention(
|
||||
x, x, x,
|
||||
attn_mask=x_mask,
|
||||
tau=tau, delta=None
|
||||
)[0])
|
||||
x = self.norm1(x)
|
||||
|
||||
x = x + self.dropout(self.cross_attention(
|
||||
x, cross, cross,
|
||||
attn_mask=cross_mask,
|
||||
tau=tau, delta=delta
|
||||
)[0])
|
||||
|
||||
y = x = self.norm2(x)
|
||||
y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
|
||||
y = self.dropout(self.conv2(y).transpose(-1, 1))
|
||||
|
||||
return self.norm3(x + y)
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
def __init__(self, layers, norm_layer=None, projection=None):
|
||||
super(Decoder, self).__init__()
|
||||
self.layers = nn.ModuleList(layers)
|
||||
self.norm = norm_layer
|
||||
self.projection = projection
|
||||
|
||||
def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None):
|
||||
for layer in self.layers:
|
||||
x = layer(x, cross, x_mask=x_mask, cross_mask=cross_mask, tau=tau, delta=delta)
|
||||
|
||||
if self.norm is not None:
|
||||
x = self.norm(x)
|
||||
|
||||
if self.projection is not None:
|
||||
x = self.projection(x)
|
||||
return x
|
132
layers/cross_channel_attn.py
Normal file
132
layers/cross_channel_attn.py
Normal file
@ -0,0 +1,132 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
import math
|
||||
|
||||
class CrossChannelAttention(nn.Module):
|
||||
"""
|
||||
对每个通道 i:
|
||||
Query: 预测区间长度 pred_len 的可学习向量 (不依赖历史值)
|
||||
Key/Value: 其它通道所有时间点的标量 -> 线性投影
|
||||
输出: [B, pred_len, C]
|
||||
复杂度: O(C * pred_len * (C-1) * L),当 C 很小(如 <= 64)通常可接受
|
||||
"""
|
||||
def __init__(self, seq_len, pred_len, c_in,
|
||||
d_model=64, n_heads=4,
|
||||
dropout=0.1, use_layernorm=True):
|
||||
super().__init__()
|
||||
assert d_model % n_heads == 0, "d_model 必须能整除 n_heads"
|
||||
self.seq_len = seq_len
|
||||
self.pred_len = pred_len
|
||||
self.c_in = c_in
|
||||
self.d_model = d_model
|
||||
self.n_heads = n_heads
|
||||
self.head_dim = d_model // n_heads
|
||||
|
||||
# 可学习的预测步 Query Embeddings: [pred_len, d_model]
|
||||
self.query_embed = nn.Parameter(torch.randn(pred_len, d_model))
|
||||
|
||||
# 标量值 -> d_model 投影 (共享给 Key/Value)
|
||||
self.key_proj = nn.Linear(1, d_model)
|
||||
self.value_proj = nn.Linear(1, d_model)
|
||||
|
||||
# 输出压缩成标量
|
||||
self.out_proj = nn.Linear(d_model, 1)
|
||||
|
||||
self.attn_dropout = nn.Dropout(dropout)
|
||||
self.proj_dropout = nn.Dropout(dropout)
|
||||
|
||||
self.use_ln = use_layernorm
|
||||
if use_layernorm:
|
||||
self.ln_q = nn.LayerNorm(d_model)
|
||||
self.ln_kv = nn.LayerNorm(d_model)
|
||||
|
||||
# 可选的时间 + 通道位置编码(简单可学习向量)
|
||||
self.time_pos = nn.Parameter(torch.zeros(seq_len, d_model))
|
||||
self.channel_pos = nn.Parameter(torch.zeros(c_in, d_model))
|
||||
nn.init.normal_(self.time_pos, std=0.02)
|
||||
nn.init.normal_(self.channel_pos, std=0.02)
|
||||
nn.init.normal_(self.query_embed, std=0.02)
|
||||
|
||||
def split_heads(self, x):
|
||||
# x: [B, T, d_model] -> [B, n_heads, T, head_dim]
|
||||
B, T, D = x.shape
|
||||
return x.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
def merge_heads(self, x):
|
||||
# x: [B, n_heads, T, head_dim] -> [B, T, d_model]
|
||||
B, H, T, Hd = x.shape
|
||||
return x.transpose(1, 2).contiguous().view(B, T, H * Hd)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
x: [B, L, C]
|
||||
返回: cross_out [B, pred_len, C]
|
||||
"""
|
||||
B, L, C = x.shape
|
||||
assert L == self.seq_len and C == self.c_in
|
||||
|
||||
# 准备 K,V: 对每个通道的时间序列做投影
|
||||
# 先变成 [B, C, L, 1] -> Linear -> [B, C, L, d]
|
||||
xc = x.permute(0, 2, 1).unsqueeze(-1) # [B, C, L, 1]
|
||||
K = self.key_proj(xc) # [B, C, L, d_model]
|
||||
V = self.value_proj(xc) # [B, C, L, d_model]
|
||||
|
||||
# 加位置编码(通道 + 时间)
|
||||
# broadcast: time_pos [L, d_model] -> [1,1,L,d]; channel_pos [C,d_model]->[1,C,1,d]
|
||||
K = K + self.time_pos.unsqueeze(0).unsqueeze(0) + self.channel_pos.unsqueeze(0).unsqueeze(2)
|
||||
V = V + self.time_pos.unsqueeze(0).unsqueeze(0) + self.channel_pos.unsqueeze(0).unsqueeze(2)
|
||||
|
||||
if self.use_ln:
|
||||
K = self.ln_kv(K)
|
||||
V = self.ln_kv(V)
|
||||
|
||||
cross_outputs = []
|
||||
|
||||
# 预备 Query(所有通道共享 query 基形,再可选加通道偏移)
|
||||
base_q = self.query_embed # [pred_len, d_model]
|
||||
|
||||
for ci in range(C):
|
||||
# 构造其它通道索引
|
||||
if C == 1:
|
||||
# 单通道退化: 直接输出零或复制自身
|
||||
zero_out = x[:, -self.pred_len:, ci:ci+1]
|
||||
cross_outputs.append(zero_out)
|
||||
continue
|
||||
other_idx = [j for j in range(C) if j != ci]
|
||||
|
||||
K_i = K[:, other_idx, :, :] # [B, C-1, L, d_model]
|
||||
V_i = V[:, other_idx, :, :] # [B, C-1, L, d_model]
|
||||
|
||||
# 拉平成 token 维度
|
||||
K_i = K_i.reshape(B, (C-1)*L, self.d_model) # [B, (C-1)*L, d]
|
||||
V_i = V_i.reshape(B, (C-1)*L, self.d_model)
|
||||
|
||||
# Query: 复制到 batch,并加通道偏移(可选)
|
||||
Q_i = base_q.unsqueeze(0).expand(B, self.pred_len, self.d_model) # [B, pred_len, d_model]
|
||||
Q_i = Q_i + self.channel_pos[ci].unsqueeze(0).unsqueeze(0)
|
||||
|
||||
if self.use_ln:
|
||||
Q_i = self.ln_q(Q_i)
|
||||
|
||||
# 分头
|
||||
Qh = self.split_heads(Q_i) # [B, H, pred_len, head_dim]
|
||||
Kh = self.split_heads(K_i) # [B, H, (C-1)*L, head_dim]
|
||||
Vh = self.split_heads(V_i) # [B, H, (C-1)*L, head_dim]
|
||||
|
||||
# Attention: [B, H, pred_len, (C-1)*L]
|
||||
scores = torch.matmul(Qh, Kh.transpose(-2, -1)) / math.sqrt(self.head_dim)
|
||||
attn = torch.softmax(scores, dim=-1)
|
||||
attn = self.attn_dropout(attn)
|
||||
|
||||
# 上下文
|
||||
ctx = torch.matmul(attn, Vh) # [B, H, pred_len, head_dim]
|
||||
ctx = self.merge_heads(ctx) # [B, pred_len, d_model]
|
||||
ctx = self.proj_dropout(ctx)
|
||||
|
||||
# 输出压缩到标量: [B, pred_len, 1]
|
||||
out_ci = self.out_proj(ctx)
|
||||
cross_outputs.append(out_ci) # list of [B, pred_len, 1]
|
||||
|
||||
cross_out = torch.cat(cross_outputs, dim=-1) # [B, pred_len, C]
|
||||
return cross_out
|
||||
|
136
layers/mixer.py
Normal file
136
layers/mixer.py
Normal file
@ -0,0 +1,136 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import math
|
||||
|
||||
class ChannelGraphMixer(nn.Module):
|
||||
"""
|
||||
在 PatchTST 的通道独立输出上做一次可学习的稀疏跨通道交互.
|
||||
输入 z_list : 长度=M 的 list, 每个元素形状 [B, D] (单通道表示)
|
||||
输出 list, 形状同输入
|
||||
"""
|
||||
def __init__(self, n_channel: int, dim: int, k: int = 5, tau: float = 0.2):
|
||||
super().__init__()
|
||||
self.k = k
|
||||
self.tau = tau
|
||||
self.dim = dim
|
||||
self.A = nn.Parameter(torch.zeros(n_channel, n_channel)) # 可学习邻接
|
||||
self.mix = nn.Linear(dim, dim, bias=False) # 通道映射
|
||||
# 通道注意力过滤 CAF
|
||||
self.se = nn.Sequential(
|
||||
nn.Linear(dim, dim // 4, bias=False),
|
||||
nn.ReLU(),
|
||||
nn.Linear(dim // 4, 1, bias=False),
|
||||
nn.Sigmoid()
|
||||
)
|
||||
|
||||
# -------- util: 生成 row-wise top-k 稀疏邻接 -------------
|
||||
def _row_sparse(self, logits: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
logits: [M,M]. 返回一个稀疏矩阵, 每行只保留 top-k, 其余为 0
|
||||
"""
|
||||
# Straight-through Gumbel–Softmax
|
||||
g = -torch.empty_like(logits).exponential_().log()
|
||||
y = (logits + g) / self.tau
|
||||
probs = F.softmax(y, dim=-1) # 可导
|
||||
topk_val, _ = torch.topk(probs, self.k, dim=-1)
|
||||
thr = topk_val[..., -1].unsqueeze(-1) # 每行阈值
|
||||
sparse = torch.where(probs >= thr, probs, torch.zeros_like(probs))
|
||||
return sparse.detach() + probs - probs.detach() # ST-estimator
|
||||
|
||||
# --------------------------------------------------------
|
||||
def forward(self, z_list):
|
||||
M = len(z_list)
|
||||
B = z_list[0].shape[0]
|
||||
Z = torch.stack(z_list, dim=1) # [B,M,D]
|
||||
alpha = self.se(Z).squeeze(-1) # [B,M] 通道重要性
|
||||
|
||||
A_sparse = self._row_sparse(self.A) # [M,M]
|
||||
|
||||
out = []
|
||||
for i in range(M):
|
||||
# 汇聚来自其余通道的表示
|
||||
agg = 0
|
||||
for j in range(M):
|
||||
if A_sparse[i, j] != 0:
|
||||
agg = agg + alpha[:, j:j+1] * A_sparse[i, j] * self.mix(z_list[j])
|
||||
out.append(z_list[i] + agg) # 残差
|
||||
return out
|
||||
|
||||
|
||||
class HierarchicalGraphMixer(nn.Module):
|
||||
"""
|
||||
分层图混合器,同时考虑宏观通道关系和微观 Patch 级别注意力。
|
||||
输入 z : 形状为 [B, C, N, D] 的张量
|
||||
输出 z_out : 形状同输入
|
||||
"""
|
||||
def __init__(self, n_channel: int, dim: int, k: int = 5, tau: float = 0.2):
|
||||
super().__init__()
|
||||
self.k = k
|
||||
self.tau = tau
|
||||
|
||||
# Level 1: Channel Graph
|
||||
self.A = nn.Parameter(torch.zeros(n_channel, n_channel))
|
||||
self.se = nn.Sequential(
|
||||
nn.Linear(dim, dim // 4, bias=False), nn.ReLU(),
|
||||
nn.Linear(dim // 4, 1, bias=False), nn.Sigmoid()
|
||||
)
|
||||
|
||||
# Level 2: Patch Cross-Attention
|
||||
self.q_proj = nn.Linear(dim, dim)
|
||||
self.k_proj = nn.Linear(dim, dim)
|
||||
self.v_proj = nn.Linear(dim, dim)
|
||||
self.out_proj = nn.Linear(dim, dim)
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
|
||||
def _row_sparse(self, logits: torch.Tensor) -> torch.Tensor:
|
||||
# (Gumbel-Softmax 函数,无需改变)
|
||||
g = -torch.empty_like(logits).exponential_().log()
|
||||
y = (logits + g) / self.tau
|
||||
probs = F.softmax(y, dim=-1)
|
||||
topk_val, _ = torch.topk(probs, self.k, dim=-1)
|
||||
thr = topk_val[..., -1].unsqueeze(-1)
|
||||
sparse = torch.where(probs >= thr, probs, torch.zeros_like(probs))
|
||||
return sparse.detach() + probs - probs.detach()
|
||||
|
||||
def forward(self, z):
|
||||
# z 的形状: [B, C, N, D]
|
||||
B, C, N, D = z.shape
|
||||
|
||||
# --- Level 1: 计算宏观权重 ---
|
||||
A_sparse = self._row_sparse(self.A) # 通道连接稀疏图 A_sparse: [C, C]
|
||||
|
||||
# --- Level 2: 跨通道 Patch 交互 ---
|
||||
out_z = torch.zeros_like(z)
|
||||
for i in range(C): # 遍历每个目标通道 i
|
||||
target_z = z[:, i, :, :] # [B, N, D]
|
||||
|
||||
# 准备聚合来自其他通道的 patch 级别上下文
|
||||
aggregated_context = torch.zeros_like(target_z)
|
||||
|
||||
for j in range(C): # 遍历每个源通道 j
|
||||
if A_sparse[i, j] != 0:
|
||||
source_z = z[:, j, :, :] # [B, N, D]
|
||||
|
||||
# --- 执行交叉注意力 ---
|
||||
Q = self.q_proj(target_z) # Query 来自目标通道 i
|
||||
K = self.k_proj(source_z) # Key 来自源通道 j
|
||||
V = self.v_proj(source_z) # Value 来自源通道 j
|
||||
|
||||
attn_scores = torch.bmm(Q, K.transpose(1, 2)) / math.sqrt(D)
|
||||
attn_probs = F.softmax(attn_scores, dim=-1) # [B, N, N]
|
||||
|
||||
context = torch.bmm(attn_probs, V) # [B, N, D], 从 j 聚合到 i 的上下文
|
||||
|
||||
# 宏观权重2 (通道连接强度): A_sparse[i, j] -> 标量
|
||||
|
||||
|
||||
# 加权上下文
|
||||
weighted_context = A_sparse[i, j] * context
|
||||
aggregated_context = aggregated_context + weighted_context
|
||||
|
||||
# 将聚合后的上下文通过输出层,并与原始目标表示相加(残差连接)
|
||||
# LayerNorm 增加稳定性
|
||||
out_z[:, i, :, :] = self.norm(target_z + self.out_proj(aggregated_context))
|
||||
|
||||
return out_z
|
239
layers/ps_loss.py
Normal file
239
layers/ps_loss.py
Normal file
@ -0,0 +1,239 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
|
||||
|
||||
class PSLoss(nn.Module):
|
||||
"""
|
||||
Patch-wise Structural (PS) Loss for time series forecasting.
|
||||
|
||||
Implements the loss function described in the paper that combines:
|
||||
1. Fourier-based Adaptive Patching (FAP)
|
||||
2. Patch-wise Structural Loss (correlation, variance, mean)
|
||||
3. Gradient-based Dynamic Weighting (GDW)
|
||||
"""
|
||||
|
||||
def __init__(self, patch_len_threshold=64, lambda_ps=5.0, use_gdw=True):
|
||||
super(PSLoss, self).__init__()
|
||||
self.patch_len_threshold = patch_len_threshold
|
||||
self.lambda_ps = lambda_ps
|
||||
self.use_gdw = use_gdw
|
||||
self.kl_loss = nn.KLDivLoss(reduction='none')
|
||||
self.mse_loss = nn.MSELoss()
|
||||
|
||||
def create_patches(self, x, patch_len, stride):
|
||||
"""
|
||||
Create patches from time series data.
|
||||
|
||||
Args:
|
||||
x: Input tensor of shape [B, L, C]
|
||||
patch_len: Length of each patch
|
||||
stride: Stride for patching
|
||||
|
||||
Returns:
|
||||
patches: Tensor of shape [B, C, N, P] where N is number of patches, P is patch length
|
||||
"""
|
||||
B, L, C = x.shape
|
||||
|
||||
num_patches = (L - patch_len) // stride + 1
|
||||
patches = x.unfold(1, patch_len, stride) # [B, N, C, P]
|
||||
patches = patches.permute(0, 2, 1, 3) # [B, C, N, P]
|
||||
|
||||
return patches
|
||||
|
||||
def fourier_based_adaptive_patching(self, true, pred):
|
||||
"""
|
||||
Fourier-based Adaptive Patching (FAP) to determine optimal patch length.
|
||||
|
||||
Args:
|
||||
true: Ground truth tensor [B, L, C]
|
||||
pred: Prediction tensor [B, L, C]
|
||||
|
||||
Returns:
|
||||
true_patch: Patches from ground truth [B, C, N, P]
|
||||
pred_patch: Patches from prediction [B, C, N, P]
|
||||
"""
|
||||
# Get dominant frequency from ground truth
|
||||
true_fft = torch.fft.rfft(true, dim=1)
|
||||
frequency_list = torch.abs(true_fft).mean(0).mean(-1)
|
||||
frequency_list[:1] = 0.0 # Remove DC component
|
||||
top_index = torch.argmax(frequency_list)
|
||||
|
||||
# Calculate period and patch length
|
||||
period = true.shape[1] // top_index if top_index > 0 else self.patch_len_threshold
|
||||
patch_len = min(period // 2, self.patch_len_threshold)
|
||||
patch_len = max(patch_len, 4) # Minimum patch length
|
||||
stride = max(patch_len // 2, 1)
|
||||
|
||||
# Create patches
|
||||
true_patch = self.create_patches(true, patch_len, stride)
|
||||
pred_patch = self.create_patches(pred, patch_len, stride)
|
||||
|
||||
return true_patch, pred_patch
|
||||
|
||||
def patch_wise_structural_loss(self, true_patch, pred_patch):
|
||||
"""
|
||||
Calculate patch-wise structural loss components.
|
||||
|
||||
Args:
|
||||
true_patch: Ground truth patches [B, C, N, P]
|
||||
pred_patch: Prediction patches [B, C, N, P]
|
||||
|
||||
Returns:
|
||||
corr_loss: Correlation loss
|
||||
var_loss: Variance loss
|
||||
mean_loss: Mean loss
|
||||
"""
|
||||
# Calculate patch statistics
|
||||
true_patch_mean = torch.mean(true_patch, dim=-1, keepdim=True) # [B, C, N, 1]
|
||||
pred_patch_mean = torch.mean(pred_patch, dim=-1, keepdim=True) # [B, C, N, 1]
|
||||
|
||||
true_patch_var = torch.var(true_patch, dim=-1, keepdim=True, unbiased=False)
|
||||
pred_patch_var = torch.var(pred_patch, dim=-1, keepdim=True, unbiased=False)
|
||||
true_patch_std = torch.sqrt(true_patch_var + 1e-8)
|
||||
pred_patch_std = torch.sqrt(pred_patch_var + 1e-8)
|
||||
|
||||
# Calculate covariance
|
||||
true_centered = true_patch - true_patch_mean
|
||||
pred_centered = pred_patch - pred_patch_mean
|
||||
covariance = torch.mean(true_centered * pred_centered, dim=-1, keepdim=True)
|
||||
|
||||
# 1. Correlation Loss (based on Pearson correlation coefficient)
|
||||
correlation = covariance / (true_patch_std * pred_patch_std + 1e-8)
|
||||
corr_loss = (1.0 - correlation).mean()
|
||||
|
||||
# 2. Variance Loss (using KL divergence of softmax distributions)
|
||||
true_patch_softmax = F.softmax(true_patch, dim=-1)
|
||||
pred_patch_log_softmax = F.log_softmax(pred_patch, dim=-1)
|
||||
var_loss = self.kl_loss(pred_patch_log_softmax, true_patch_softmax).sum(dim=-1).mean()
|
||||
|
||||
# 3. Mean Loss (MAE between patch means)
|
||||
mean_loss = torch.abs(true_patch_mean - pred_patch_mean).mean()
|
||||
|
||||
return corr_loss, var_loss, mean_loss
|
||||
|
||||
def gradient_based_dynamic_weighting(self, model, corr_loss, var_loss, mean_loss, true, pred):
|
||||
"""
|
||||
Gradient-based Dynamic Weighting (GDW) for balancing loss components.
|
||||
|
||||
Args:
|
||||
model: The neural network model
|
||||
corr_loss, var_loss, mean_loss: Loss components
|
||||
true: Ground truth tensor [B, L, C]
|
||||
pred: Prediction tensor [B, L, C]
|
||||
|
||||
Returns:
|
||||
alpha, beta, gamma: Dynamic weights for the three loss components
|
||||
"""
|
||||
if not self.use_gdw or not self.training:
|
||||
return 1.0, 1.0, 1.0
|
||||
|
||||
try:
|
||||
# Get model parameters for gradient calculation
|
||||
if hasattr(model, 'projector'):
|
||||
params = list(model.projector.parameters())
|
||||
elif hasattr(model, 'projection'):
|
||||
params = list(model.projection.parameters())
|
||||
else:
|
||||
# Use output layer parameters
|
||||
params = list(model.parameters())[-2:] # Last linear layer
|
||||
|
||||
if not params:
|
||||
return 1.0, 1.0, 1.0
|
||||
|
||||
# Calculate gradients for each loss component
|
||||
corr_grad = torch.autograd.grad(corr_loss, params, retain_graph=True, allow_unused=True)
|
||||
var_grad = torch.autograd.grad(var_loss, params, retain_graph=True, allow_unused=True)
|
||||
mean_grad = torch.autograd.grad(mean_loss, params, retain_graph=True, allow_unused=True)
|
||||
|
||||
# Filter out None gradients and calculate norms
|
||||
corr_grad_norm = sum(g.norm().item() for g in corr_grad if g is not None)
|
||||
var_grad_norm = sum(g.norm().item() for g in var_grad if g is not None)
|
||||
mean_grad_norm = sum(g.norm().item() for g in mean_grad if g is not None)
|
||||
|
||||
if corr_grad_norm == 0 or var_grad_norm == 0 or mean_grad_norm == 0:
|
||||
return 1.0, 1.0, 1.0
|
||||
|
||||
# Calculate average gradient magnitude
|
||||
grad_avg = (corr_grad_norm + var_grad_norm + mean_grad_norm) / 3.0
|
||||
|
||||
# Calculate dynamic weights
|
||||
alpha = grad_avg / corr_grad_norm
|
||||
beta = grad_avg / var_grad_norm
|
||||
gamma = grad_avg / mean_grad_norm
|
||||
|
||||
# Calculate scaling factors for gamma
|
||||
true_flat = true.view(-1, true.shape[-1])
|
||||
pred_flat = pred.view(-1, pred.shape[-1])
|
||||
|
||||
true_mean = torch.mean(true_flat, dim=0, keepdim=True)
|
||||
pred_mean = torch.mean(pred_flat, dim=0, keepdim=True)
|
||||
true_std = torch.std(true_flat, dim=0, keepdim=True) + 1e-8
|
||||
pred_std = torch.std(pred_flat, dim=0, keepdim=True) + 1e-8
|
||||
|
||||
covariance = torch.mean((true_flat - true_mean) * (pred_flat - pred_mean), dim=0, keepdim=True)
|
||||
correlation_sim = covariance / (true_std * pred_std)
|
||||
correlation_sim = (1.0 + correlation_sim.mean()) * 0.5
|
||||
|
||||
variance_sim = (2 * true_std.mean() * pred_std.mean()) / (true_std.mean()**2 + pred_std.mean()**2)
|
||||
|
||||
c = 0.5 * (1 + correlation_sim)
|
||||
v = variance_sim
|
||||
gamma = gamma * c * v
|
||||
|
||||
return alpha.item(), beta.item(), gamma.item()
|
||||
|
||||
except Exception:
|
||||
# Fallback to equal weights if gradient calculation fails
|
||||
return 1.0, 1.0, 1.0
|
||||
|
||||
def forward(self, pred, true, model=None):
|
||||
"""
|
||||
Forward pass of PS Loss.
|
||||
|
||||
Args:
|
||||
pred: Predictions [B, L, C]
|
||||
true: Ground truth [B, L, C]
|
||||
model: Neural network model (for gradient-based weighting)
|
||||
|
||||
Returns:
|
||||
total_loss: Combined MSE + PS loss
|
||||
loss_dict: Dictionary with individual loss components
|
||||
"""
|
||||
# Standard MSE loss
|
||||
mse_loss = self.mse_loss(pred, true)
|
||||
|
||||
# Fourier-based adaptive patching
|
||||
true_patch, pred_patch = self.fourier_based_adaptive_patching(true, pred)
|
||||
|
||||
# Patch-wise structural loss
|
||||
corr_loss, var_loss, mean_loss = self.patch_wise_structural_loss(true_patch, pred_patch)
|
||||
|
||||
# Gradient-based dynamic weighting
|
||||
if model is not None and self.use_gdw:
|
||||
alpha, beta, gamma = self.gradient_based_dynamic_weighting(
|
||||
model, corr_loss, var_loss, mean_loss, true, pred
|
||||
)
|
||||
else:
|
||||
alpha, beta, gamma = 1.0, 1.0, 1.0
|
||||
|
||||
# Combine PS loss components
|
||||
ps_loss = alpha * corr_loss + beta * var_loss + gamma * mean_loss
|
||||
|
||||
# Total loss: MSE + λ * PS_loss
|
||||
total_loss = mse_loss + self.lambda_ps * ps_loss
|
||||
|
||||
loss_dict = {
|
||||
'mse_loss': mse_loss.item(),
|
||||
'ps_loss': ps_loss.item(),
|
||||
'corr_loss': corr_loss.item(),
|
||||
'var_loss': var_loss.item(),
|
||||
'mean_loss': mean_loss.item(),
|
||||
'alpha': alpha,
|
||||
'beta': beta,
|
||||
'gamma': gamma,
|
||||
'total_loss': total_loss.item()
|
||||
}
|
||||
|
||||
return total_loss, loss_dict
|
Reference in New Issue
Block a user