feat(model): add initial PatchTST model architecture and utilities

game-loader
2025-08-28 13:23:06 +08:00
parent 4129832f98
commit 59b23d4637
6 changed files with 1142 additions and 0 deletions

layers/PatchTST_backbone.py (new file, +379 lines)
@@ -0,0 +1,379 @@
__all__ = ['PatchTST_backbone']
# Cell
from typing import Callable, Optional
import torch
from torch import nn
from torch import Tensor
import torch.nn.functional as F
import numpy as np
#from collections import OrderedDict
from layers.PatchTST_layers import *
from layers.revin import RevIN
# Cell
class PatchTST_backbone(nn.Module):
def __init__(self, c_in:int, context_window:int, target_window:int, patch_len:int, stride:int, max_seq_len:Optional[int]=1024,
n_layers:int=3, d_model=128, n_heads=16, d_k:Optional[int]=None, d_v:Optional[int]=None,
                 d_ff:int=256, norm:str='BatchNorm', attn_dropout:float=0., dropout:float=0., act:str="gelu", key_padding_mask:str='auto',
padding_var:Optional[int]=None, attn_mask:Optional[Tensor]=None, res_attention:bool=True, pre_norm:bool=False, store_attn:bool=False,
pe:str='zeros', learn_pe:bool=True, fc_dropout:float=0., head_dropout = 0, padding_patch = None,
pretrain_head:bool=False, head_type = 'flatten', individual = False, revin = True, affine = True, subtract_last = False,
verbose:bool=False, **kwargs):
super().__init__()
# RevIn
self.revin = revin
if self.revin: self.revin_layer = RevIN(c_in, affine=affine, subtract_last=subtract_last)
# Patching
self.patch_len = patch_len
self.stride = stride
self.padding_patch = padding_patch
patch_num = int((context_window - patch_len)/stride + 1)
if padding_patch == 'end': # can be modified to general case
self.padding_patch_layer = nn.ReplicationPad1d((0, stride))
patch_num += 1
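        # Worked example (illustrative numbers): context_window=336, patch_len=16, stride=8
        # gives (336 - 16) // 8 + 1 = 41 patches, plus one more (42) when padding_patch == 'end'
        # replicates the last `stride` time steps before unfolding.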
# Backbone
self.backbone = TSTiEncoder(c_in, patch_num=patch_num, patch_len=patch_len, max_seq_len=max_seq_len,
n_layers=n_layers, d_model=d_model, n_heads=n_heads, d_k=d_k, d_v=d_v, d_ff=d_ff,
attn_dropout=attn_dropout, dropout=dropout, act=act, key_padding_mask=key_padding_mask, padding_var=padding_var,
attn_mask=attn_mask, res_attention=res_attention, pre_norm=pre_norm, store_attn=store_attn,
pe=pe, learn_pe=learn_pe, verbose=verbose, **kwargs)
# Head
self.head_nf = d_model * patch_num
self.n_vars = c_in
self.pretrain_head = pretrain_head
self.head_type = head_type
self.individual = individual
if self.pretrain_head:
self.head = self.create_pretrain_head(self.head_nf, c_in, fc_dropout) # custom head passed as a partial func with all its kwargs
elif head_type == 'flatten':
self.head = Flatten_Head(self.individual, self.n_vars, self.head_nf, target_window, head_dropout=head_dropout)
def forward(self, z): # z: [bs x nvars x seq_len]
# norm
if self.revin:
z = z.permute(0,2,1)
z = self.revin_layer(z, 'norm')
z = z.permute(0,2,1)
# do patching
if self.padding_patch == 'end':
z = self.padding_patch_layer(z)
z = z.unfold(dimension=-1, size=self.patch_len, step=self.stride) # z: [bs x nvars x patch_num x patch_len]
z = z.permute(0,1,3,2) # z: [bs x nvars x patch_len x patch_num]
# model
z = self.backbone(z) # z: [bs x nvars x d_model x patch_num]
z = self.head(z) # z: [bs x nvars x target_window]
# denorm
if self.revin:
z = z.permute(0,2,1)
z = self.revin_layer(z, 'denorm')
z = z.permute(0,2,1)
return z
    def create_pretrain_head(self, head_nf, n_vars, dropout):
        return nn.Sequential(nn.Dropout(dropout),
                             nn.Conv1d(head_nf, n_vars, 1)
                             )
class Flatten_Head(nn.Module):
def __init__(self, individual, n_vars, nf, target_window, head_dropout=0):
super().__init__()
self.individual = individual
self.n_vars = n_vars
if self.individual:
self.linears = nn.ModuleList()
self.dropouts = nn.ModuleList()
self.flattens = nn.ModuleList()
for i in range(self.n_vars):
self.flattens.append(nn.Flatten(start_dim=-2))
self.linears.append(nn.Linear(nf, target_window))
self.dropouts.append(nn.Dropout(head_dropout))
else:
self.flatten = nn.Flatten(start_dim=-2)
self.linear = nn.Linear(nf, target_window)
self.dropout = nn.Dropout(head_dropout)
def forward(self, x): # x: [bs x nvars x d_model x patch_num]
if self.individual:
x_out = []
for i in range(self.n_vars):
z = self.flattens[i](x[:,i,:,:]) # z: [bs x d_model * patch_num]
z = self.linears[i](z) # z: [bs x target_window]
z = self.dropouts[i](z)
x_out.append(z)
x = torch.stack(x_out, dim=1) # x: [bs x nvars x target_window]
else:
x = self.flatten(x)
x = self.linear(x)
x = self.dropout(x)
return x
class TSTiEncoder(nn.Module): #i means channel-independent
def __init__(self, c_in, patch_num, patch_len, max_seq_len=1024,
n_layers=3, d_model=128, n_heads=16, d_k=None, d_v=None,
d_ff=256, norm='BatchNorm', attn_dropout=0., dropout=0., act="gelu", store_attn=False,
key_padding_mask='auto', padding_var=None, attn_mask=None, res_attention=True, pre_norm=False,
pe='zeros', learn_pe=True, verbose=False, **kwargs):
super().__init__()
self.patch_num = patch_num
self.patch_len = patch_len
# Input encoding
q_len = patch_num
self.W_P = nn.Linear(patch_len, d_model) # Eq 1: projection of feature vectors onto a d-dim vector space
self.seq_len = q_len
# Positional encoding
self.W_pos = positional_encoding(pe, learn_pe, q_len, d_model)
# Residual dropout
self.dropout = nn.Dropout(dropout)
# Encoder
self.encoder = TSTEncoder(q_len, d_model, n_heads, d_k=d_k, d_v=d_v, d_ff=d_ff, norm=norm, attn_dropout=attn_dropout, dropout=dropout,
pre_norm=pre_norm, activation=act, res_attention=res_attention, n_layers=n_layers, store_attn=store_attn)
def forward(self, x) -> Tensor: # x: [bs x nvars x patch_len x patch_num]
n_vars = x.shape[1]
# Input encoding
x = x.permute(0,1,3,2) # x: [bs x nvars x patch_num x patch_len]
x = self.W_P(x) # x: [bs x nvars x patch_num x d_model]
u = torch.reshape(x, (x.shape[0]*x.shape[1],x.shape[2],x.shape[3])) # u: [bs * nvars x patch_num x d_model]
u = self.dropout(u + self.W_pos) # u: [bs * nvars x patch_num x d_model]
# Encoder
z = self.encoder(u) # z: [bs * nvars x patch_num x d_model]
z = torch.reshape(z, (-1,n_vars,z.shape[-2],z.shape[-1])) # z: [bs x nvars x patch_num x d_model]
z = z.permute(0,1,3,2) # z: [bs x nvars x d_model x patch_num]
return z
# Cell
class TSTEncoder(nn.Module):
def __init__(self, q_len, d_model, n_heads, d_k=None, d_v=None, d_ff=None,
norm='BatchNorm', attn_dropout=0., dropout=0., activation='gelu',
res_attention=False, n_layers=1, pre_norm=False, store_attn=False):
super().__init__()
self.layers = nn.ModuleList([TSTEncoderLayer(q_len, d_model, n_heads=n_heads, d_k=d_k, d_v=d_v, d_ff=d_ff, norm=norm,
attn_dropout=attn_dropout, dropout=dropout,
activation=activation, res_attention=res_attention,
pre_norm=pre_norm, store_attn=store_attn) for i in range(n_layers)])
self.res_attention = res_attention
def forward(self, src:Tensor, key_padding_mask:Optional[Tensor]=None, attn_mask:Optional[Tensor]=None):
output = src
scores = None
if self.res_attention:
for mod in self.layers: output, scores = mod(output, prev=scores, key_padding_mask=key_padding_mask, attn_mask=attn_mask)
return output
else:
for mod in self.layers: output = mod(output, key_padding_mask=key_padding_mask, attn_mask=attn_mask)
return output
class TSTEncoderLayer(nn.Module):
def __init__(self, q_len, d_model, n_heads, d_k=None, d_v=None, d_ff=256, store_attn=False,
norm='BatchNorm', attn_dropout=0, dropout=0., bias=True, activation="gelu", res_attention=False, pre_norm=False):
super().__init__()
assert not d_model%n_heads, f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
d_k = d_model // n_heads if d_k is None else d_k
d_v = d_model // n_heads if d_v is None else d_v
# Multi-Head attention
self.res_attention = res_attention
self.self_attn = _MultiheadAttention(d_model, n_heads, d_k, d_v, attn_dropout=attn_dropout, proj_dropout=dropout, res_attention=res_attention)
# Add & Norm
self.dropout_attn = nn.Dropout(dropout)
if "batch" in norm.lower():
self.norm_attn = nn.Sequential(Transpose(1,2), nn.BatchNorm1d(d_model), Transpose(1,2))
else:
self.norm_attn = nn.LayerNorm(d_model)
# Position-wise Feed-Forward
self.ff = nn.Sequential(nn.Linear(d_model, d_ff, bias=bias),
get_activation_fn(activation),
nn.Dropout(dropout),
nn.Linear(d_ff, d_model, bias=bias))
# Add & Norm
self.dropout_ffn = nn.Dropout(dropout)
if "batch" in norm.lower():
self.norm_ffn = nn.Sequential(Transpose(1,2), nn.BatchNorm1d(d_model), Transpose(1,2))
else:
self.norm_ffn = nn.LayerNorm(d_model)
self.pre_norm = pre_norm
self.store_attn = store_attn
def forward(self, src:Tensor, prev:Optional[Tensor]=None, key_padding_mask:Optional[Tensor]=None, attn_mask:Optional[Tensor]=None) -> Tensor:
# Multi-Head attention sublayer
if self.pre_norm:
src = self.norm_attn(src)
## Multi-Head attention
if self.res_attention:
src2, attn, scores = self.self_attn(src, src, src, prev, key_padding_mask=key_padding_mask, attn_mask=attn_mask)
else:
src2, attn = self.self_attn(src, src, src, key_padding_mask=key_padding_mask, attn_mask=attn_mask)
if self.store_attn:
self.attn = attn
## Add & Norm
src = src + self.dropout_attn(src2) # Add: residual connection with residual dropout
if not self.pre_norm:
src = self.norm_attn(src)
# Feed-forward sublayer
if self.pre_norm:
src = self.norm_ffn(src)
## Position-wise Feed-Forward
src2 = self.ff(src)
## Add & Norm
src = src + self.dropout_ffn(src2) # Add: residual connection with residual dropout
if not self.pre_norm:
src = self.norm_ffn(src)
if self.res_attention:
return src, scores
else:
return src
class _MultiheadAttention(nn.Module):
def __init__(self, d_model, n_heads, d_k=None, d_v=None, res_attention=False, attn_dropout=0., proj_dropout=0., qkv_bias=True, lsa=False):
"""Multi Head Attention Layer
Input shape:
Q: [batch_size (bs) x max_q_len x d_model]
K, V: [batch_size (bs) x q_len x d_model]
mask: [q_len x q_len]
"""
super().__init__()
d_k = d_model // n_heads if d_k is None else d_k
d_v = d_model // n_heads if d_v is None else d_v
self.n_heads, self.d_k, self.d_v = n_heads, d_k, d_v
self.W_Q = nn.Linear(d_model, d_k * n_heads, bias=qkv_bias)
self.W_K = nn.Linear(d_model, d_k * n_heads, bias=qkv_bias)
self.W_V = nn.Linear(d_model, d_v * n_heads, bias=qkv_bias)
# Scaled Dot-Product Attention (multiple heads)
self.res_attention = res_attention
self.sdp_attn = _ScaledDotProductAttention(d_model, n_heads, attn_dropout=attn_dropout, res_attention=self.res_attention, lsa=lsa)
        # Project output
self.to_out = nn.Sequential(nn.Linear(n_heads * d_v, d_model), nn.Dropout(proj_dropout))
def forward(self, Q:Tensor, K:Optional[Tensor]=None, V:Optional[Tensor]=None, prev:Optional[Tensor]=None,
key_padding_mask:Optional[Tensor]=None, attn_mask:Optional[Tensor]=None):
bs = Q.size(0)
if K is None: K = Q
if V is None: V = Q
# Linear (+ split in multiple heads)
q_s = self.W_Q(Q).view(bs, -1, self.n_heads, self.d_k).transpose(1,2) # q_s : [bs x n_heads x max_q_len x d_k]
k_s = self.W_K(K).view(bs, -1, self.n_heads, self.d_k).permute(0,2,3,1) # k_s : [bs x n_heads x d_k x q_len] - transpose(1,2) + transpose(2,3)
v_s = self.W_V(V).view(bs, -1, self.n_heads, self.d_v).transpose(1,2) # v_s : [bs x n_heads x q_len x d_v]
# Apply Scaled Dot-Product Attention (multiple heads)
if self.res_attention:
output, attn_weights, attn_scores = self.sdp_attn(q_s, k_s, v_s, prev=prev, key_padding_mask=key_padding_mask, attn_mask=attn_mask)
else:
output, attn_weights = self.sdp_attn(q_s, k_s, v_s, key_padding_mask=key_padding_mask, attn_mask=attn_mask)
# output: [bs x n_heads x q_len x d_v], attn: [bs x n_heads x q_len x q_len], scores: [bs x n_heads x max_q_len x q_len]
# back to the original inputs dimensions
output = output.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * self.d_v) # output: [bs x q_len x n_heads * d_v]
output = self.to_out(output)
if self.res_attention: return output, attn_weights, attn_scores
else: return output, attn_weights
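A minimal shape check for the attention block above (a sketch with illustrative sizes; it assumes the imports at the top of this file, including layers.revin, resolve in the repository):

import torch
from layers.PatchTST_backbone import _MultiheadAttention

mha = _MultiheadAttention(d_model=128, n_heads=16)   # d_k = d_v = 128 // 16 = 8
q = torch.randn(4, 42, 128)                          # [bs x q_len x d_model]
out, attn = mha(q)                                   # K and V default to Q (self-attention)
print(out.shape, attn.shape)                         # [4, 42, 128], [4, 16, 42, 42]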
class _ScaledDotProductAttention(nn.Module):
r"""Scaled Dot-Product Attention module (Attention is all you need by Vaswani et al., 2017) with optional residual attention from previous layer
(Realformer: Transformer likes residual attention by He et al, 2020) and locality self sttention (Vision Transformer for Small-Size Datasets
by Lee et al, 2021)"""
def __init__(self, d_model, n_heads, attn_dropout=0., res_attention=False, lsa=False):
super().__init__()
self.attn_dropout = nn.Dropout(attn_dropout)
self.res_attention = res_attention
head_dim = d_model // n_heads
self.scale = nn.Parameter(torch.tensor(head_dim ** -0.5), requires_grad=lsa)
self.lsa = lsa
def forward(self, q:Tensor, k:Tensor, v:Tensor, prev:Optional[Tensor]=None, key_padding_mask:Optional[Tensor]=None, attn_mask:Optional[Tensor]=None):
'''
Input shape:
q : [bs x n_heads x max_q_len x d_k]
k : [bs x n_heads x d_k x seq_len]
v : [bs x n_heads x seq_len x d_v]
prev : [bs x n_heads x q_len x seq_len]
key_padding_mask: [bs x seq_len]
attn_mask : [1 x seq_len x seq_len]
Output shape:
output: [bs x n_heads x q_len x d_v]
attn : [bs x n_heads x q_len x seq_len]
scores : [bs x n_heads x q_len x seq_len]
'''
# Scaled MatMul (q, k) - similarity scores for all pairs of positions in an input sequence
attn_scores = torch.matmul(q, k) * self.scale # attn_scores : [bs x n_heads x max_q_len x q_len]
# Add pre-softmax attention scores from the previous layer (optional)
if prev is not None: attn_scores = attn_scores + prev
# Attention mask (optional)
if attn_mask is not None: # attn_mask with shape [q_len x seq_len] - only used when q_len == seq_len
if attn_mask.dtype == torch.bool:
attn_scores.masked_fill_(attn_mask, -np.inf)
else:
attn_scores += attn_mask
# Key padding mask (optional)
        if key_padding_mask is not None:  # mask with shape [bs x q_len] (only when max_q_len == q_len)
attn_scores.masked_fill_(key_padding_mask.unsqueeze(1).unsqueeze(2), -np.inf)
# normalize the attention weights
attn_weights = F.softmax(attn_scores, dim=-1) # attn_weights : [bs x n_heads x max_q_len x q_len]
attn_weights = self.attn_dropout(attn_weights)
# compute the new values given the attention weights
output = torch.matmul(attn_weights, v) # output: [bs x n_heads x max_q_len x d_v]
if self.res_attention: return output, attn_weights, attn_scores
else: return output, attn_weights
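A minimal end-to-end shape sketch for PatchTST_backbone (the hyperparameter values are illustrative, not taken from the commit; it assumes layers.revin.RevIN exists in the repository, as the import at the top of the file expects):

import torch
from layers.PatchTST_backbone import PatchTST_backbone

model = PatchTST_backbone(c_in=7, context_window=336, target_window=96,
                          patch_len=16, stride=8, padding_patch='end')
x = torch.randn(32, 7, 336)   # [bs x nvars x seq_len]
y = model(x)
print(y.shape)                # torch.Size([32, 7, 96]) = [bs x nvars x target_window]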

layers/PatchTST_layers.py (new file, +121 lines)
@@ -0,0 +1,121 @@
__all__ = ['Transpose', 'get_activation_fn', 'moving_avg', 'series_decomp', 'PositionalEncoding', 'SinCosPosEncoding', 'Coord2dPosEncoding', 'Coord1dPosEncoding', 'positional_encoding']
import torch
from torch import nn
import math
class Transpose(nn.Module):
def __init__(self, *dims, contiguous=False):
super().__init__()
self.dims, self.contiguous = dims, contiguous
def forward(self, x):
if self.contiguous: return x.transpose(*self.dims).contiguous()
else: return x.transpose(*self.dims)
def get_activation_fn(activation):
if callable(activation): return activation()
elif activation.lower() == "relu": return nn.ReLU()
elif activation.lower() == "gelu": return nn.GELU()
raise ValueError(f'{activation} is not available. You can use "relu", "gelu", or a callable')
# decomposition
class moving_avg(nn.Module):
"""
Moving average block to highlight the trend of time series
"""
def __init__(self, kernel_size, stride):
super(moving_avg, self).__init__()
self.kernel_size = kernel_size
self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0)
def forward(self, x):
# padding on the both ends of time series
front = x[:, 0:1, :].repeat(1, (self.kernel_size - 1) // 2, 1)
end = x[:, -1:, :].repeat(1, (self.kernel_size - 1) // 2, 1)
x = torch.cat([front, x, end], dim=1)
x = self.avg(x.permute(0, 2, 1))
x = x.permute(0, 2, 1)
return x
class series_decomp(nn.Module):
"""
Series decomposition block
"""
def __init__(self, kernel_size):
super(series_decomp, self).__init__()
self.moving_avg = moving_avg(kernel_size, stride=1)
def forward(self, x):
moving_mean = self.moving_avg(x)
res = x - moving_mean
return res, moving_mean
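A quick usage sketch of the decomposition block (shapes and kernel size are illustrative):

import torch
from layers.PatchTST_layers import series_decomp

decomp = series_decomp(kernel_size=25)
x = torch.randn(8, 96, 7)        # [B, L, C]
res, trend = decomp(x)
print(res.shape, trend.shape)    # both torch.Size([8, 96, 7])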
# pos_encoding
def PositionalEncoding(q_len, d_model, normalize=True):
pe = torch.zeros(q_len, d_model)
position = torch.arange(0, q_len).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
if normalize:
pe = pe - pe.mean()
pe = pe / (pe.std() * 10)
return pe
SinCosPosEncoding = PositionalEncoding
def Coord2dPosEncoding(q_len, d_model, exponential=False, normalize=True, eps=1e-3, verbose=False):
x = .5 if exponential else 1
for i in range(100):
cpe = 2 * (torch.linspace(0, 1, q_len).reshape(-1, 1) ** x) * (torch.linspace(0, 1, d_model).reshape(1, -1) ** x) - 1
# pv(f'{i:4.0f} {x:5.3f} {cpe.mean():+6.3f}', verbose)
if abs(cpe.mean()) <= eps: break
elif cpe.mean() > eps: x += .001
else: x -= .001
if normalize:
cpe = cpe - cpe.mean()
cpe = cpe / (cpe.std() * 10)
return cpe
def Coord1dPosEncoding(q_len, exponential=False, normalize=True):
cpe = (2 * (torch.linspace(0, 1, q_len).reshape(-1, 1)**(.5 if exponential else 1)) - 1)
if normalize:
cpe = cpe - cpe.mean()
cpe = cpe / (cpe.std() * 10)
return cpe
def positional_encoding(pe, learn_pe, q_len, d_model):
# Positional encoding
    if pe is None:
W_pos = torch.empty((q_len, d_model)) # pe = None and learn_pe = False can be used to measure impact of pe
nn.init.uniform_(W_pos, -0.02, 0.02)
learn_pe = False
elif pe == 'zero':
W_pos = torch.empty((q_len, 1))
nn.init.uniform_(W_pos, -0.02, 0.02)
elif pe == 'zeros':
W_pos = torch.empty((q_len, d_model))
nn.init.uniform_(W_pos, -0.02, 0.02)
elif pe == 'normal' or pe == 'gauss':
W_pos = torch.zeros((q_len, 1))
torch.nn.init.normal_(W_pos, mean=0.0, std=0.1)
elif pe == 'uniform':
W_pos = torch.zeros((q_len, 1))
nn.init.uniform_(W_pos, a=0.0, b=0.1)
elif pe == 'lin1d': W_pos = Coord1dPosEncoding(q_len, exponential=False, normalize=True)
elif pe == 'exp1d': W_pos = Coord1dPosEncoding(q_len, exponential=True, normalize=True)
elif pe == 'lin2d': W_pos = Coord2dPosEncoding(q_len, d_model, exponential=False, normalize=True)
elif pe == 'exp2d': W_pos = Coord2dPosEncoding(q_len, d_model, exponential=True, normalize=True)
elif pe == 'sincos': W_pos = PositionalEncoding(q_len, d_model, normalize=True)
    else: raise ValueError(f"{pe} is not a valid positional encoder (pe). Available types: 'gauss'=='normal', \
        'zeros', 'zero', 'uniform', 'lin1d', 'exp1d', 'lin2d', 'exp2d', 'sincos', None.")
return nn.Parameter(W_pos, requires_grad=learn_pe)
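A small sketch of how the helper is used (the 'zeros' mode and the sizes below are illustrative):

import torch
from layers.PatchTST_layers import positional_encoding

W_pos = positional_encoding('zeros', learn_pe=True, q_len=42, d_model=128)
print(W_pos.shape, W_pos.requires_grad)   # torch.Size([42, 128]) True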

@@ -0,0 +1,135 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
class ConvLayer(nn.Module):
def __init__(self, c_in):
super(ConvLayer, self).__init__()
self.downConv = nn.Conv1d(in_channels=c_in,
out_channels=c_in,
kernel_size=3,
padding=2,
padding_mode='circular')
self.norm = nn.BatchNorm1d(c_in)
self.activation = nn.ELU()
self.maxPool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)
def forward(self, x):
x = self.downConv(x.permute(0, 2, 1))
x = self.norm(x)
x = self.activation(x)
x = self.maxPool(x)
x = x.transpose(1, 2)
return x
class EncoderLayer(nn.Module):
def __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation="relu"):
super(EncoderLayer, self).__init__()
d_ff = d_ff or 4 * d_model
self.attention = attention
self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
self.activation = F.relu if activation == "relu" else F.gelu
def forward(self, x, attn_mask=None, tau=None, delta=None):
new_x, attn = self.attention(
x, x, x,
attn_mask=attn_mask,
tau=tau, delta=delta
)
x = x + self.dropout(new_x)
y = x = self.norm1(x)
y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
y = self.dropout(self.conv2(y).transpose(-1, 1))
return self.norm2(x + y), attn
class Encoder(nn.Module):
def __init__(self, attn_layers, conv_layers=None, norm_layer=None):
super(Encoder, self).__init__()
self.attn_layers = nn.ModuleList(attn_layers)
self.conv_layers = nn.ModuleList(conv_layers) if conv_layers is not None else None
self.norm = norm_layer
def forward(self, x, attn_mask=None, tau=None, delta=None):
# x [B, L, D]
attns = []
if self.conv_layers is not None:
for i, (attn_layer, conv_layer) in enumerate(zip(self.attn_layers, self.conv_layers)):
delta = delta if i == 0 else None
x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta)
x = conv_layer(x)
attns.append(attn)
x, attn = self.attn_layers[-1](x, tau=tau, delta=None)
attns.append(attn)
else:
for attn_layer in self.attn_layers:
x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta)
attns.append(attn)
if self.norm is not None:
x = self.norm(x)
return x, attns
class DecoderLayer(nn.Module):
def __init__(self, self_attention, cross_attention, d_model, d_ff=None,
dropout=0.1, activation="relu"):
super(DecoderLayer, self).__init__()
d_ff = d_ff or 4 * d_model
self.self_attention = self_attention
self.cross_attention = cross_attention
self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
self.activation = F.relu if activation == "relu" else F.gelu
def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None):
x = x + self.dropout(self.self_attention(
x, x, x,
attn_mask=x_mask,
tau=tau, delta=None
)[0])
x = self.norm1(x)
x = x + self.dropout(self.cross_attention(
x, cross, cross,
attn_mask=cross_mask,
tau=tau, delta=delta
)[0])
y = x = self.norm2(x)
y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
y = self.dropout(self.conv2(y).transpose(-1, 1))
return self.norm3(x + y)
class Decoder(nn.Module):
def __init__(self, layers, norm_layer=None, projection=None):
super(Decoder, self).__init__()
self.layers = nn.ModuleList(layers)
self.norm = norm_layer
self.projection = projection
def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None):
for layer in self.layers:
x = layer(x, cross, x_mask=x_mask, cross_mask=cross_mask, tau=tau, delta=delta)
if self.norm is not None:
x = self.norm(x)
if self.projection is not None:
x = self.projection(x)
return x
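A minimal smoke test for the Encoder/EncoderLayer above. The attention module here is a hypothetical stand-in (the real attention classes are expected to come from elsewhere in the repository); sizes are illustrative:

import torch
import torch.nn as nn

class _IdentityAttention(nn.Module):
    # stand-in: returns the values unchanged and no attention map
    def forward(self, q, k, v, attn_mask=None, tau=None, delta=None):
        return v, None

layer = EncoderLayer(_IdentityAttention(), d_model=64, d_ff=128)
enc = Encoder([layer], norm_layer=nn.LayerNorm(64))
out, attns = enc(torch.randn(8, 24, 64))   # x: [B, L, D]
print(out.shape)                           # torch.Size([8, 24, 64])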

@@ -0,0 +1,132 @@
import torch
from torch import nn
import math
class CrossChannelAttention(nn.Module):
"""
    For each channel i:
        Query: learnable vectors for the pred_len forecast steps (independent of the historical values)
        Key/Value: the scalar values at every time step of the other channels -> linear projection
    Output: [B, pred_len, C]
    Complexity: O(C * pred_len * (C-1) * L); usually acceptable when C is small (e.g. <= 64)
"""
def __init__(self, seq_len, pred_len, c_in,
d_model=64, n_heads=4,
dropout=0.1, use_layernorm=True):
super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
self.seq_len = seq_len
self.pred_len = pred_len
self.c_in = c_in
self.d_model = d_model
self.n_heads = n_heads
self.head_dim = d_model // n_heads
        # Learnable query embeddings for the forecast steps: [pred_len, d_model]
self.query_embed = nn.Parameter(torch.randn(pred_len, d_model))
        # Project scalar values to d_model (used for Key/Value)
self.key_proj = nn.Linear(1, d_model)
self.value_proj = nn.Linear(1, d_model)
        # Compress the output back to a scalar
self.out_proj = nn.Linear(d_model, 1)
self.attn_dropout = nn.Dropout(dropout)
self.proj_dropout = nn.Dropout(dropout)
self.use_ln = use_layernorm
if use_layernorm:
self.ln_q = nn.LayerNorm(d_model)
self.ln_kv = nn.LayerNorm(d_model)
        # Optional time + channel positional encodings (simple learnable vectors)
self.time_pos = nn.Parameter(torch.zeros(seq_len, d_model))
self.channel_pos = nn.Parameter(torch.zeros(c_in, d_model))
nn.init.normal_(self.time_pos, std=0.02)
nn.init.normal_(self.channel_pos, std=0.02)
nn.init.normal_(self.query_embed, std=0.02)
def split_heads(self, x):
# x: [B, T, d_model] -> [B, n_heads, T, head_dim]
B, T, D = x.shape
return x.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
def merge_heads(self, x):
# x: [B, n_heads, T, head_dim] -> [B, T, d_model]
B, H, T, Hd = x.shape
return x.transpose(1, 2).contiguous().view(B, T, H * Hd)
def forward(self, x):
"""
x: [B, L, C]
        Returns: cross_out [B, pred_len, C]
"""
B, L, C = x.shape
assert L == self.seq_len and C == self.c_in
        # Prepare K, V: project each channel's time series,
        # first reshaping to [B, C, L, 1] -> Linear -> [B, C, L, d]
xc = x.permute(0, 2, 1).unsqueeze(-1) # [B, C, L, 1]
K = self.key_proj(xc) # [B, C, L, d_model]
V = self.value_proj(xc) # [B, C, L, d_model]
        # Add positional encodings (channel + time)
# broadcast: time_pos [L, d_model] -> [1,1,L,d]; channel_pos [C,d_model]->[1,C,1,d]
K = K + self.time_pos.unsqueeze(0).unsqueeze(0) + self.channel_pos.unsqueeze(0).unsqueeze(2)
V = V + self.time_pos.unsqueeze(0).unsqueeze(0) + self.channel_pos.unsqueeze(0).unsqueeze(2)
if self.use_ln:
K = self.ln_kv(K)
V = self.ln_kv(V)
cross_outputs = []
        # Prepare the Query (all channels share the base query; an optional per-channel offset is added below)
base_q = self.query_embed # [pred_len, d_model]
for ci in range(C):
            # Build the indices of the other channels
if C == 1:
                # Degenerate single-channel case: fall back to the channel's own last pred_len values
zero_out = x[:, -self.pred_len:, ci:ci+1]
cross_outputs.append(zero_out)
continue
other_idx = [j for j in range(C) if j != ci]
K_i = K[:, other_idx, :, :] # [B, C-1, L, d_model]
V_i = V[:, other_idx, :, :] # [B, C-1, L, d_model]
            # Flatten into a single token dimension
K_i = K_i.reshape(B, (C-1)*L, self.d_model) # [B, (C-1)*L, d]
V_i = V_i.reshape(B, (C-1)*L, self.d_model)
            # Query: expand to the batch and add the (optional) channel offset
Q_i = base_q.unsqueeze(0).expand(B, self.pred_len, self.d_model) # [B, pred_len, d_model]
Q_i = Q_i + self.channel_pos[ci].unsqueeze(0).unsqueeze(0)
if self.use_ln:
Q_i = self.ln_q(Q_i)
            # Split into heads
Qh = self.split_heads(Q_i) # [B, H, pred_len, head_dim]
Kh = self.split_heads(K_i) # [B, H, (C-1)*L, head_dim]
Vh = self.split_heads(V_i) # [B, H, (C-1)*L, head_dim]
# Attention: [B, H, pred_len, (C-1)*L]
scores = torch.matmul(Qh, Kh.transpose(-2, -1)) / math.sqrt(self.head_dim)
attn = torch.softmax(scores, dim=-1)
attn = self.attn_dropout(attn)
            # Context
ctx = torch.matmul(attn, Vh) # [B, H, pred_len, head_dim]
ctx = self.merge_heads(ctx) # [B, pred_len, d_model]
ctx = self.proj_dropout(ctx)
            # Compress the output to one scalar per step: [B, pred_len, 1]
out_ci = self.out_proj(ctx)
cross_outputs.append(out_ci) # list of [B, pred_len, 1]
cross_out = torch.cat(cross_outputs, dim=-1) # [B, pred_len, C]
return cross_out
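A minimal shape check for CrossChannelAttention (illustrative sizes; the module path is not shown in this diff, so the class is assumed to be in scope):

import torch

attn = CrossChannelAttention(seq_len=96, pred_len=24, c_in=7, d_model=64, n_heads=4)
x = torch.randn(16, 96, 7)    # [B, L, C]
out = attn(x)
print(out.shape)              # torch.Size([16, 24, 7])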

layers/mixer.py (new file, +136 lines)
@@ -0,0 +1,136 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class ChannelGraphMixer(nn.Module):
"""
    Applies one learnable, sparse cross-channel interaction on top of PatchTST's channel-independent outputs.
    Input  z_list : a list of length M, each element of shape [B, D] (one representation per channel)
    Output: a list with the same shapes as the input
"""
def __init__(self, n_channel: int, dim: int, k: int = 5, tau: float = 0.2):
super().__init__()
self.k = k
self.tau = tau
self.dim = dim
        self.A = nn.Parameter(torch.zeros(n_channel, n_channel))  # learnable adjacency matrix
        self.mix = nn.Linear(dim, dim, bias=False)                 # channel mapping
        # Channel Attention Filter (CAF)
self.se = nn.Sequential(
nn.Linear(dim, dim // 4, bias=False),
nn.ReLU(),
nn.Linear(dim // 4, 1, bias=False),
nn.Sigmoid()
)
    # -------- util: build a row-wise top-k sparse adjacency matrix -------------
def _row_sparse(self, logits: torch.Tensor) -> torch.Tensor:
"""
        logits: [M, M]. Returns a sparse matrix in which each row keeps only its top-k entries; the rest are zero.
"""
        # Straight-through Gumbel-Softmax
g = -torch.empty_like(logits).exponential_().log()
y = (logits + g) / self.tau
        probs = F.softmax(y, dim=-1)  # differentiable
topk_val, _ = torch.topk(probs, self.k, dim=-1)
        thr = topk_val[..., -1].unsqueeze(-1)  # per-row threshold
sparse = torch.where(probs >= thr, probs, torch.zeros_like(probs))
return sparse.detach() + probs - probs.detach() # ST-estimator
# --------------------------------------------------------
def forward(self, z_list):
M = len(z_list)
B = z_list[0].shape[0]
Z = torch.stack(z_list, dim=1) # [B,M,D]
        alpha = self.se(Z).squeeze(-1)  # [B, M] channel importance
A_sparse = self._row_sparse(self.A) # [M,M]
out = []
for i in range(M):
            # Aggregate representations from the other channels
agg = 0
for j in range(M):
if A_sparse[i, j] != 0:
agg = agg + alpha[:, j:j+1] * A_sparse[i, j] * self.mix(z_list[j])
            out.append(z_list[i] + agg)  # residual connection
return out
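A small usage sketch of ChannelGraphMixer (illustrative sizes):

import torch
from layers.mixer import ChannelGraphMixer

mixer = ChannelGraphMixer(n_channel=6, dim=128, k=3)
z_list = [torch.randn(32, 128) for _ in range(6)]   # list of [B, D]
z_mixed = mixer(z_list)
print(len(z_mixed), z_mixed[0].shape)                # 6 torch.Size([32, 128])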
class HierarchicalGraphMixer(nn.Module):
"""
    Hierarchical graph mixer that combines macro channel-level relations with micro patch-level attention.
    Input  z     : tensor of shape [B, C, N, D]
    Output z_out : tensor with the same shape as the input
"""
def __init__(self, n_channel: int, dim: int, k: int = 5, tau: float = 0.2):
super().__init__()
self.k = k
self.tau = tau
# Level 1: Channel Graph
self.A = nn.Parameter(torch.zeros(n_channel, n_channel))
self.se = nn.Sequential(
nn.Linear(dim, dim // 4, bias=False), nn.ReLU(),
nn.Linear(dim // 4, 1, bias=False), nn.Sigmoid()
)
# Level 2: Patch Cross-Attention
self.q_proj = nn.Linear(dim, dim)
self.k_proj = nn.Linear(dim, dim)
self.v_proj = nn.Linear(dim, dim)
self.out_proj = nn.Linear(dim, dim)
self.norm = nn.LayerNorm(dim)
def _row_sparse(self, logits: torch.Tensor) -> torch.Tensor:
        # (Gumbel-Softmax sampling, same as in ChannelGraphMixer)
g = -torch.empty_like(logits).exponential_().log()
y = (logits + g) / self.tau
probs = F.softmax(y, dim=-1)
topk_val, _ = torch.topk(probs, self.k, dim=-1)
thr = topk_val[..., -1].unsqueeze(-1)
sparse = torch.where(probs >= thr, probs, torch.zeros_like(probs))
return sparse.detach() + probs - probs.detach()
def forward(self, z):
        # z shape: [B, C, N, D]
B, C, N, D = z.shape
        # --- Level 1: compute the macro (channel-level) weights ---
        A_sparse = self._row_sparse(self.A)  # sparse channel-connection graph A_sparse: [C, C]
        # --- Level 2: cross-channel patch interaction ---
out_z = torch.zeros_like(z)
        for i in range(C):  # iterate over each target channel i
target_z = z[:, i, :, :] # [B, N, D]
            # Accumulate patch-level context aggregated from the other channels
aggregated_context = torch.zeros_like(target_z)
            for j in range(C):  # iterate over each source channel j
if A_sparse[i, j] != 0:
source_z = z[:, j, :, :] # [B, N, D]
                    # --- cross-attention ---
                    Q = self.q_proj(target_z)   # Query from the target channel i
                    K = self.k_proj(source_z)   # Key from the source channel j
                    V = self.v_proj(source_z)   # Value from the source channel j
attn_scores = torch.bmm(Q, K.transpose(1, 2)) / math.sqrt(D)
attn_probs = F.softmax(attn_scores, dim=-1) # [B, N, N]
                    context = torch.bmm(attn_probs, V)   # [B, N, D], context aggregated from channel j to channel i
                    # Macro weight (channel-connection strength): A_sparse[i, j] -> scalar
                    # weighted context
weighted_context = A_sparse[i, j] * context
aggregated_context = aggregated_context + weighted_context
            # Pass the aggregated context through the output projection and add it to the original target representation (residual connection);
            # LayerNorm for stability
out_z[:, i, :, :] = self.norm(target_z + self.out_proj(aggregated_context))
return out_z
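And a matching sketch for HierarchicalGraphMixer (illustrative sizes; N would be the number of patches per channel):

import torch
from layers.mixer import HierarchicalGraphMixer

mixer = HierarchicalGraphMixer(n_channel=6, dim=128, k=3)
z = torch.randn(8, 6, 42, 128)   # [B, C, N, D]
z_out = mixer(z)
print(z_out.shape)               # torch.Size([8, 6, 42, 128])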

layers/ps_loss.py (new file, +239 lines)
@@ -0,0 +1,239 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
class PSLoss(nn.Module):
"""
Patch-wise Structural (PS) Loss for time series forecasting.
Implements the loss function described in the paper that combines:
1. Fourier-based Adaptive Patching (FAP)
2. Patch-wise Structural Loss (correlation, variance, mean)
3. Gradient-based Dynamic Weighting (GDW)
"""
def __init__(self, patch_len_threshold=64, lambda_ps=5.0, use_gdw=True):
super(PSLoss, self).__init__()
self.patch_len_threshold = patch_len_threshold
self.lambda_ps = lambda_ps
self.use_gdw = use_gdw
self.kl_loss = nn.KLDivLoss(reduction='none')
self.mse_loss = nn.MSELoss()
def create_patches(self, x, patch_len, stride):
"""
Create patches from time series data.
Args:
x: Input tensor of shape [B, L, C]
patch_len: Length of each patch
stride: Stride for patching
Returns:
patches: Tensor of shape [B, C, N, P] where N is number of patches, P is patch length
"""
B, L, C = x.shape
num_patches = (L - patch_len) // stride + 1
patches = x.unfold(1, patch_len, stride) # [B, N, C, P]
patches = patches.permute(0, 2, 1, 3) # [B, C, N, P]
return patches
def fourier_based_adaptive_patching(self, true, pred):
"""
Fourier-based Adaptive Patching (FAP) to determine optimal patch length.
Args:
true: Ground truth tensor [B, L, C]
pred: Prediction tensor [B, L, C]
Returns:
true_patch: Patches from ground truth [B, C, N, P]
pred_patch: Patches from prediction [B, C, N, P]
"""
# Get dominant frequency from ground truth
true_fft = torch.fft.rfft(true, dim=1)
frequency_list = torch.abs(true_fft).mean(0).mean(-1)
frequency_list[:1] = 0.0 # Remove DC component
top_index = torch.argmax(frequency_list)
# Calculate period and patch length
        period = true.shape[1] // int(top_index) if top_index > 0 else self.patch_len_threshold  # cast: unfold() expects plain ints
patch_len = min(period // 2, self.patch_len_threshold)
patch_len = max(patch_len, 4) # Minimum patch length
stride = max(patch_len // 2, 1)
# Create patches
true_patch = self.create_patches(true, patch_len, stride)
pred_patch = self.create_patches(pred, patch_len, stride)
return true_patch, pred_patch
def patch_wise_structural_loss(self, true_patch, pred_patch):
"""
Calculate patch-wise structural loss components.
Args:
true_patch: Ground truth patches [B, C, N, P]
pred_patch: Prediction patches [B, C, N, P]
Returns:
corr_loss: Correlation loss
var_loss: Variance loss
mean_loss: Mean loss
"""
# Calculate patch statistics
true_patch_mean = torch.mean(true_patch, dim=-1, keepdim=True) # [B, C, N, 1]
pred_patch_mean = torch.mean(pred_patch, dim=-1, keepdim=True) # [B, C, N, 1]
true_patch_var = torch.var(true_patch, dim=-1, keepdim=True, unbiased=False)
pred_patch_var = torch.var(pred_patch, dim=-1, keepdim=True, unbiased=False)
true_patch_std = torch.sqrt(true_patch_var + 1e-8)
pred_patch_std = torch.sqrt(pred_patch_var + 1e-8)
# Calculate covariance
true_centered = true_patch - true_patch_mean
pred_centered = pred_patch - pred_patch_mean
covariance = torch.mean(true_centered * pred_centered, dim=-1, keepdim=True)
# 1. Correlation Loss (based on Pearson correlation coefficient)
correlation = covariance / (true_patch_std * pred_patch_std + 1e-8)
corr_loss = (1.0 - correlation).mean()
# 2. Variance Loss (using KL divergence of softmax distributions)
true_patch_softmax = F.softmax(true_patch, dim=-1)
pred_patch_log_softmax = F.log_softmax(pred_patch, dim=-1)
var_loss = self.kl_loss(pred_patch_log_softmax, true_patch_softmax).sum(dim=-1).mean()
# 3. Mean Loss (MAE between patch means)
mean_loss = torch.abs(true_patch_mean - pred_patch_mean).mean()
return corr_loss, var_loss, mean_loss
def gradient_based_dynamic_weighting(self, model, corr_loss, var_loss, mean_loss, true, pred):
"""
Gradient-based Dynamic Weighting (GDW) for balancing loss components.
Args:
model: The neural network model
corr_loss, var_loss, mean_loss: Loss components
true: Ground truth tensor [B, L, C]
pred: Prediction tensor [B, L, C]
Returns:
alpha, beta, gamma: Dynamic weights for the three loss components
"""
if not self.use_gdw or not self.training:
return 1.0, 1.0, 1.0
try:
# Get model parameters for gradient calculation
if hasattr(model, 'projector'):
params = list(model.projector.parameters())
elif hasattr(model, 'projection'):
params = list(model.projection.parameters())
else:
# Use output layer parameters
params = list(model.parameters())[-2:] # Last linear layer
if not params:
return 1.0, 1.0, 1.0
# Calculate gradients for each loss component
corr_grad = torch.autograd.grad(corr_loss, params, retain_graph=True, allow_unused=True)
var_grad = torch.autograd.grad(var_loss, params, retain_graph=True, allow_unused=True)
mean_grad = torch.autograd.grad(mean_loss, params, retain_graph=True, allow_unused=True)
# Filter out None gradients and calculate norms
corr_grad_norm = sum(g.norm().item() for g in corr_grad if g is not None)
var_grad_norm = sum(g.norm().item() for g in var_grad if g is not None)
mean_grad_norm = sum(g.norm().item() for g in mean_grad if g is not None)
if corr_grad_norm == 0 or var_grad_norm == 0 or mean_grad_norm == 0:
return 1.0, 1.0, 1.0
# Calculate average gradient magnitude
grad_avg = (corr_grad_norm + var_grad_norm + mean_grad_norm) / 3.0
# Calculate dynamic weights
alpha = grad_avg / corr_grad_norm
beta = grad_avg / var_grad_norm
gamma = grad_avg / mean_grad_norm
# Calculate scaling factors for gamma
true_flat = true.view(-1, true.shape[-1])
pred_flat = pred.view(-1, pred.shape[-1])
true_mean = torch.mean(true_flat, dim=0, keepdim=True)
pred_mean = torch.mean(pred_flat, dim=0, keepdim=True)
true_std = torch.std(true_flat, dim=0, keepdim=True) + 1e-8
pred_std = torch.std(pred_flat, dim=0, keepdim=True) + 1e-8
covariance = torch.mean((true_flat - true_mean) * (pred_flat - pred_mean), dim=0, keepdim=True)
correlation_sim = covariance / (true_std * pred_std)
correlation_sim = (1.0 + correlation_sim.mean()) * 0.5
variance_sim = (2 * true_std.mean() * pred_std.mean()) / (true_std.mean()**2 + pred_std.mean()**2)
c = 0.5 * (1 + correlation_sim)
v = variance_sim
gamma = gamma * c * v
            return float(alpha), float(beta), float(gamma)  # alpha/beta are Python floats; gamma is a 0-dim tensor after scaling
except Exception:
# Fallback to equal weights if gradient calculation fails
return 1.0, 1.0, 1.0
def forward(self, pred, true, model=None):
"""
Forward pass of PS Loss.
Args:
pred: Predictions [B, L, C]
true: Ground truth [B, L, C]
model: Neural network model (for gradient-based weighting)
Returns:
total_loss: Combined MSE + PS loss
loss_dict: Dictionary with individual loss components
"""
# Standard MSE loss
mse_loss = self.mse_loss(pred, true)
# Fourier-based adaptive patching
true_patch, pred_patch = self.fourier_based_adaptive_patching(true, pred)
# Patch-wise structural loss
corr_loss, var_loss, mean_loss = self.patch_wise_structural_loss(true_patch, pred_patch)
# Gradient-based dynamic weighting
if model is not None and self.use_gdw:
alpha, beta, gamma = self.gradient_based_dynamic_weighting(
model, corr_loss, var_loss, mean_loss, true, pred
)
else:
alpha, beta, gamma = 1.0, 1.0, 1.0
# Combine PS loss components
ps_loss = alpha * corr_loss + beta * var_loss + gamma * mean_loss
# Total loss: MSE + λ * PS_loss
total_loss = mse_loss + self.lambda_ps * ps_loss
loss_dict = {
'mse_loss': mse_loss.item(),
'ps_loss': ps_loss.item(),
'corr_loss': corr_loss.item(),
'var_loss': var_loss.item(),
'mean_loss': mean_loss.item(),
'alpha': alpha,
'beta': beta,
'gamma': gamma,
'total_loss': total_loss.item()
}
return total_loss, loss_dict
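A minimal usage sketch for PSLoss (illustrative shapes; `model` only matters when use_gdw=True and a projection/projector head is exposed):

import torch
from layers.ps_loss import PSLoss

criterion = PSLoss(patch_len_threshold=64, lambda_ps=5.0, use_gdw=False)
pred = torch.randn(32, 96, 7, requires_grad=True)   # [B, L, C]
true = torch.randn(32, 96, 7)
total_loss, parts = criterion(pred, true)
total_loss.backward()
print(parts['mse_loss'], parts['ps_loss'], parts['corr_loss'])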