From 59b23d4637956759f1200d1bd26cd34e4936721b Mon Sep 17 00:00:00 2001 From: game-loader Date: Thu, 28 Aug 2025 13:23:06 +0800 Subject: [PATCH] feat(model): add initial PatchTST model architecture and utilities --- layers/PatchTST_backbone.py | 379 +++++++++++++++++++++++++++++++++++ layers/PatchTST_layers.py | 121 +++++++++++ layers/Transformer_EncDec.py | 135 +++++++++++++ layers/cross_channel_attn.py | 132 ++++++++++++ layers/mixer.py | 136 +++++++++++++ layers/ps_loss.py | 239 ++++++++++++++++++++++ 6 files changed, 1142 insertions(+) create mode 100644 layers/PatchTST_backbone.py create mode 100644 layers/PatchTST_layers.py create mode 100644 layers/Transformer_EncDec.py create mode 100644 layers/cross_channel_attn.py create mode 100644 layers/mixer.py create mode 100644 layers/ps_loss.py diff --git a/layers/PatchTST_backbone.py b/layers/PatchTST_backbone.py new file mode 100644 index 0000000..14d292e --- /dev/null +++ b/layers/PatchTST_backbone.py @@ -0,0 +1,379 @@ +__all__ = ['PatchTST_backbone'] + +# Cell +from typing import Callable, Optional +import torch +from torch import nn +from torch import Tensor +import torch.nn.functional as F +import numpy as np + +#from collections import OrderedDict +from layers.PatchTST_layers import * +from layers.revin import RevIN + +# Cell +class PatchTST_backbone(nn.Module): + def __init__(self, c_in:int, context_window:int, target_window:int, patch_len:int, stride:int, max_seq_len:Optional[int]=1024, + n_layers:int=3, d_model=128, n_heads=16, d_k:Optional[int]=None, d_v:Optional[int]=None, + d_ff:int=256, norm:str='BatchNorm', attn_dropout:float=0., dropout:float=0., act:str="gelu", key_padding_mask:bool='auto', + padding_var:Optional[int]=None, attn_mask:Optional[Tensor]=None, res_attention:bool=True, pre_norm:bool=False, store_attn:bool=False, + pe:str='zeros', learn_pe:bool=True, fc_dropout:float=0., head_dropout = 0, padding_patch = None, + pretrain_head:bool=False, head_type = 'flatten', individual = False, revin = True, affine = True, subtract_last = False, + verbose:bool=False, **kwargs): + + super().__init__() + + # RevIn + self.revin = revin + if self.revin: self.revin_layer = RevIN(c_in, affine=affine, subtract_last=subtract_last) + + # Patching + self.patch_len = patch_len + self.stride = stride + self.padding_patch = padding_patch + patch_num = int((context_window - patch_len)/stride + 1) + if padding_patch == 'end': # can be modified to general case + self.padding_patch_layer = nn.ReplicationPad1d((0, stride)) + patch_num += 1 + + # Backbone + self.backbone = TSTiEncoder(c_in, patch_num=patch_num, patch_len=patch_len, max_seq_len=max_seq_len, + n_layers=n_layers, d_model=d_model, n_heads=n_heads, d_k=d_k, d_v=d_v, d_ff=d_ff, + attn_dropout=attn_dropout, dropout=dropout, act=act, key_padding_mask=key_padding_mask, padding_var=padding_var, + attn_mask=attn_mask, res_attention=res_attention, pre_norm=pre_norm, store_attn=store_attn, + pe=pe, learn_pe=learn_pe, verbose=verbose, **kwargs) + + # Head + self.head_nf = d_model * patch_num + self.n_vars = c_in + self.pretrain_head = pretrain_head + self.head_type = head_type + self.individual = individual + + if self.pretrain_head: + self.head = self.create_pretrain_head(self.head_nf, c_in, fc_dropout) # custom head passed as a partial func with all its kwargs + elif head_type == 'flatten': + self.head = Flatten_Head(self.individual, self.n_vars, self.head_nf, target_window, head_dropout=head_dropout) + + + def forward(self, z): # z: [bs x nvars x seq_len] + # norm + if self.revin: + z = z.permute(0,2,1) + z = self.revin_layer(z, 'norm') + z = z.permute(0,2,1) + + # do patching + if self.padding_patch == 'end': + z = self.padding_patch_layer(z) + z = z.unfold(dimension=-1, size=self.patch_len, step=self.stride) # z: [bs x nvars x patch_num x patch_len] + z = z.permute(0,1,3,2) # z: [bs x nvars x patch_len x patch_num] + + # model + z = self.backbone(z) # z: [bs x nvars x d_model x patch_num] + z = self.head(z) # z: [bs x nvars x target_window] + + # denorm + if self.revin: + z = z.permute(0,2,1) + z = self.revin_layer(z, 'denorm') + z = z.permute(0,2,1) + return z + + def create_pretrain_head(self, head_nf, vars, dropout): + return nn.Sequential(nn.Dropout(dropout), + nn.Conv1d(head_nf, vars, 1) + ) + + +class Flatten_Head(nn.Module): + def __init__(self, individual, n_vars, nf, target_window, head_dropout=0): + super().__init__() + + self.individual = individual + self.n_vars = n_vars + + if self.individual: + self.linears = nn.ModuleList() + self.dropouts = nn.ModuleList() + self.flattens = nn.ModuleList() + for i in range(self.n_vars): + self.flattens.append(nn.Flatten(start_dim=-2)) + self.linears.append(nn.Linear(nf, target_window)) + self.dropouts.append(nn.Dropout(head_dropout)) + else: + self.flatten = nn.Flatten(start_dim=-2) + self.linear = nn.Linear(nf, target_window) + self.dropout = nn.Dropout(head_dropout) + + def forward(self, x): # x: [bs x nvars x d_model x patch_num] + if self.individual: + x_out = [] + for i in range(self.n_vars): + z = self.flattens[i](x[:,i,:,:]) # z: [bs x d_model * patch_num] + z = self.linears[i](z) # z: [bs x target_window] + z = self.dropouts[i](z) + x_out.append(z) + x = torch.stack(x_out, dim=1) # x: [bs x nvars x target_window] + else: + x = self.flatten(x) + x = self.linear(x) + x = self.dropout(x) + return x + + + + +class TSTiEncoder(nn.Module): #i means channel-independent + def __init__(self, c_in, patch_num, patch_len, max_seq_len=1024, + n_layers=3, d_model=128, n_heads=16, d_k=None, d_v=None, + d_ff=256, norm='BatchNorm', attn_dropout=0., dropout=0., act="gelu", store_attn=False, + key_padding_mask='auto', padding_var=None, attn_mask=None, res_attention=True, pre_norm=False, + pe='zeros', learn_pe=True, verbose=False, **kwargs): + + + super().__init__() + + self.patch_num = patch_num + self.patch_len = patch_len + + # Input encoding + q_len = patch_num + self.W_P = nn.Linear(patch_len, d_model) # Eq 1: projection of feature vectors onto a d-dim vector space + self.seq_len = q_len + + # Positional encoding + self.W_pos = positional_encoding(pe, learn_pe, q_len, d_model) + + # Residual dropout + self.dropout = nn.Dropout(dropout) + + # Encoder + self.encoder = TSTEncoder(q_len, d_model, n_heads, d_k=d_k, d_v=d_v, d_ff=d_ff, norm=norm, attn_dropout=attn_dropout, dropout=dropout, + pre_norm=pre_norm, activation=act, res_attention=res_attention, n_layers=n_layers, store_attn=store_attn) + + + def forward(self, x) -> Tensor: # x: [bs x nvars x patch_len x patch_num] + + n_vars = x.shape[1] + # Input encoding + x = x.permute(0,1,3,2) # x: [bs x nvars x patch_num x patch_len] + x = self.W_P(x) # x: [bs x nvars x patch_num x d_model] + + u = torch.reshape(x, (x.shape[0]*x.shape[1],x.shape[2],x.shape[3])) # u: [bs * nvars x patch_num x d_model] + u = self.dropout(u + self.W_pos) # u: [bs * nvars x patch_num x d_model] + + # Encoder + z = self.encoder(u) # z: [bs * nvars x patch_num x d_model] + z = torch.reshape(z, (-1,n_vars,z.shape[-2],z.shape[-1])) # z: [bs x nvars x patch_num x d_model] + z = z.permute(0,1,3,2) # z: [bs x nvars x d_model x patch_num] + + return z + + + +# Cell +class TSTEncoder(nn.Module): + def __init__(self, q_len, d_model, n_heads, d_k=None, d_v=None, d_ff=None, + norm='BatchNorm', attn_dropout=0., dropout=0., activation='gelu', + res_attention=False, n_layers=1, pre_norm=False, store_attn=False): + super().__init__() + + self.layers = nn.ModuleList([TSTEncoderLayer(q_len, d_model, n_heads=n_heads, d_k=d_k, d_v=d_v, d_ff=d_ff, norm=norm, + attn_dropout=attn_dropout, dropout=dropout, + activation=activation, res_attention=res_attention, + pre_norm=pre_norm, store_attn=store_attn) for i in range(n_layers)]) + self.res_attention = res_attention + + def forward(self, src:Tensor, key_padding_mask:Optional[Tensor]=None, attn_mask:Optional[Tensor]=None): + output = src + scores = None + if self.res_attention: + for mod in self.layers: output, scores = mod(output, prev=scores, key_padding_mask=key_padding_mask, attn_mask=attn_mask) + return output + else: + for mod in self.layers: output = mod(output, key_padding_mask=key_padding_mask, attn_mask=attn_mask) + return output + + + +class TSTEncoderLayer(nn.Module): + def __init__(self, q_len, d_model, n_heads, d_k=None, d_v=None, d_ff=256, store_attn=False, + norm='BatchNorm', attn_dropout=0, dropout=0., bias=True, activation="gelu", res_attention=False, pre_norm=False): + super().__init__() + assert not d_model%n_heads, f"d_model ({d_model}) must be divisible by n_heads ({n_heads})" + d_k = d_model // n_heads if d_k is None else d_k + d_v = d_model // n_heads if d_v is None else d_v + + # Multi-Head attention + self.res_attention = res_attention + self.self_attn = _MultiheadAttention(d_model, n_heads, d_k, d_v, attn_dropout=attn_dropout, proj_dropout=dropout, res_attention=res_attention) + + # Add & Norm + self.dropout_attn = nn.Dropout(dropout) + if "batch" in norm.lower(): + self.norm_attn = nn.Sequential(Transpose(1,2), nn.BatchNorm1d(d_model), Transpose(1,2)) + else: + self.norm_attn = nn.LayerNorm(d_model) + + # Position-wise Feed-Forward + self.ff = nn.Sequential(nn.Linear(d_model, d_ff, bias=bias), + get_activation_fn(activation), + nn.Dropout(dropout), + nn.Linear(d_ff, d_model, bias=bias)) + + # Add & Norm + self.dropout_ffn = nn.Dropout(dropout) + if "batch" in norm.lower(): + self.norm_ffn = nn.Sequential(Transpose(1,2), nn.BatchNorm1d(d_model), Transpose(1,2)) + else: + self.norm_ffn = nn.LayerNorm(d_model) + + self.pre_norm = pre_norm + self.store_attn = store_attn + + + def forward(self, src:Tensor, prev:Optional[Tensor]=None, key_padding_mask:Optional[Tensor]=None, attn_mask:Optional[Tensor]=None) -> Tensor: + + # Multi-Head attention sublayer + if self.pre_norm: + src = self.norm_attn(src) + ## Multi-Head attention + if self.res_attention: + src2, attn, scores = self.self_attn(src, src, src, prev, key_padding_mask=key_padding_mask, attn_mask=attn_mask) + else: + src2, attn = self.self_attn(src, src, src, key_padding_mask=key_padding_mask, attn_mask=attn_mask) + if self.store_attn: + self.attn = attn + ## Add & Norm + src = src + self.dropout_attn(src2) # Add: residual connection with residual dropout + if not self.pre_norm: + src = self.norm_attn(src) + + # Feed-forward sublayer + if self.pre_norm: + src = self.norm_ffn(src) + ## Position-wise Feed-Forward + src2 = self.ff(src) + ## Add & Norm + src = src + self.dropout_ffn(src2) # Add: residual connection with residual dropout + if not self.pre_norm: + src = self.norm_ffn(src) + + if self.res_attention: + return src, scores + else: + return src + + + + +class _MultiheadAttention(nn.Module): + def __init__(self, d_model, n_heads, d_k=None, d_v=None, res_attention=False, attn_dropout=0., proj_dropout=0., qkv_bias=True, lsa=False): + """Multi Head Attention Layer + Input shape: + Q: [batch_size (bs) x max_q_len x d_model] + K, V: [batch_size (bs) x q_len x d_model] + mask: [q_len x q_len] + """ + super().__init__() + d_k = d_model // n_heads if d_k is None else d_k + d_v = d_model // n_heads if d_v is None else d_v + + self.n_heads, self.d_k, self.d_v = n_heads, d_k, d_v + + self.W_Q = nn.Linear(d_model, d_k * n_heads, bias=qkv_bias) + self.W_K = nn.Linear(d_model, d_k * n_heads, bias=qkv_bias) + self.W_V = nn.Linear(d_model, d_v * n_heads, bias=qkv_bias) + + # Scaled Dot-Product Attention (multiple heads) + self.res_attention = res_attention + self.sdp_attn = _ScaledDotProductAttention(d_model, n_heads, attn_dropout=attn_dropout, res_attention=self.res_attention, lsa=lsa) + + # Poject output + self.to_out = nn.Sequential(nn.Linear(n_heads * d_v, d_model), nn.Dropout(proj_dropout)) + + + def forward(self, Q:Tensor, K:Optional[Tensor]=None, V:Optional[Tensor]=None, prev:Optional[Tensor]=None, + key_padding_mask:Optional[Tensor]=None, attn_mask:Optional[Tensor]=None): + + bs = Q.size(0) + if K is None: K = Q + if V is None: V = Q + + # Linear (+ split in multiple heads) + q_s = self.W_Q(Q).view(bs, -1, self.n_heads, self.d_k).transpose(1,2) # q_s : [bs x n_heads x max_q_len x d_k] + k_s = self.W_K(K).view(bs, -1, self.n_heads, self.d_k).permute(0,2,3,1) # k_s : [bs x n_heads x d_k x q_len] - transpose(1,2) + transpose(2,3) + v_s = self.W_V(V).view(bs, -1, self.n_heads, self.d_v).transpose(1,2) # v_s : [bs x n_heads x q_len x d_v] + + # Apply Scaled Dot-Product Attention (multiple heads) + if self.res_attention: + output, attn_weights, attn_scores = self.sdp_attn(q_s, k_s, v_s, prev=prev, key_padding_mask=key_padding_mask, attn_mask=attn_mask) + else: + output, attn_weights = self.sdp_attn(q_s, k_s, v_s, key_padding_mask=key_padding_mask, attn_mask=attn_mask) + # output: [bs x n_heads x q_len x d_v], attn: [bs x n_heads x q_len x q_len], scores: [bs x n_heads x max_q_len x q_len] + + # back to the original inputs dimensions + output = output.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * self.d_v) # output: [bs x q_len x n_heads * d_v] + output = self.to_out(output) + + if self.res_attention: return output, attn_weights, attn_scores + else: return output, attn_weights + + +class _ScaledDotProductAttention(nn.Module): + r"""Scaled Dot-Product Attention module (Attention is all you need by Vaswani et al., 2017) with optional residual attention from previous layer + (Realformer: Transformer likes residual attention by He et al, 2020) and locality self sttention (Vision Transformer for Small-Size Datasets + by Lee et al, 2021)""" + + def __init__(self, d_model, n_heads, attn_dropout=0., res_attention=False, lsa=False): + super().__init__() + self.attn_dropout = nn.Dropout(attn_dropout) + self.res_attention = res_attention + head_dim = d_model // n_heads + self.scale = nn.Parameter(torch.tensor(head_dim ** -0.5), requires_grad=lsa) + self.lsa = lsa + + def forward(self, q:Tensor, k:Tensor, v:Tensor, prev:Optional[Tensor]=None, key_padding_mask:Optional[Tensor]=None, attn_mask:Optional[Tensor]=None): + ''' + Input shape: + q : [bs x n_heads x max_q_len x d_k] + k : [bs x n_heads x d_k x seq_len] + v : [bs x n_heads x seq_len x d_v] + prev : [bs x n_heads x q_len x seq_len] + key_padding_mask: [bs x seq_len] + attn_mask : [1 x seq_len x seq_len] + Output shape: + output: [bs x n_heads x q_len x d_v] + attn : [bs x n_heads x q_len x seq_len] + scores : [bs x n_heads x q_len x seq_len] + ''' + + # Scaled MatMul (q, k) - similarity scores for all pairs of positions in an input sequence + attn_scores = torch.matmul(q, k) * self.scale # attn_scores : [bs x n_heads x max_q_len x q_len] + + # Add pre-softmax attention scores from the previous layer (optional) + if prev is not None: attn_scores = attn_scores + prev + + # Attention mask (optional) + if attn_mask is not None: # attn_mask with shape [q_len x seq_len] - only used when q_len == seq_len + if attn_mask.dtype == torch.bool: + attn_scores.masked_fill_(attn_mask, -np.inf) + else: + attn_scores += attn_mask + + # Key padding mask (optional) + if key_padding_mask is not None: # mask with shape [bs x q_len] (only when max_w_len == q_len) + attn_scores.masked_fill_(key_padding_mask.unsqueeze(1).unsqueeze(2), -np.inf) + + # normalize the attention weights + attn_weights = F.softmax(attn_scores, dim=-1) # attn_weights : [bs x n_heads x max_q_len x q_len] + attn_weights = self.attn_dropout(attn_weights) + + # compute the new values given the attention weights + output = torch.matmul(attn_weights, v) # output: [bs x n_heads x max_q_len x d_v] + + if self.res_attention: return output, attn_weights, attn_scores + else: return output, attn_weights + diff --git a/layers/PatchTST_layers.py b/layers/PatchTST_layers.py new file mode 100644 index 0000000..9190fff --- /dev/null +++ b/layers/PatchTST_layers.py @@ -0,0 +1,121 @@ +__all__ = ['Transpose', 'get_activation_fn', 'moving_avg', 'series_decomp', 'PositionalEncoding', 'SinCosPosEncoding', 'Coord2dPosEncoding', 'Coord1dPosEncoding', 'positional_encoding'] + +import torch +from torch import nn +import math + +class Transpose(nn.Module): + def __init__(self, *dims, contiguous=False): + super().__init__() + self.dims, self.contiguous = dims, contiguous + def forward(self, x): + if self.contiguous: return x.transpose(*self.dims).contiguous() + else: return x.transpose(*self.dims) + + +def get_activation_fn(activation): + if callable(activation): return activation() + elif activation.lower() == "relu": return nn.ReLU() + elif activation.lower() == "gelu": return nn.GELU() + raise ValueError(f'{activation} is not available. You can use "relu", "gelu", or a callable') + + +# decomposition + +class moving_avg(nn.Module): + """ + Moving average block to highlight the trend of time series + """ + def __init__(self, kernel_size, stride): + super(moving_avg, self).__init__() + self.kernel_size = kernel_size + self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0) + + def forward(self, x): + # padding on the both ends of time series + front = x[:, 0:1, :].repeat(1, (self.kernel_size - 1) // 2, 1) + end = x[:, -1:, :].repeat(1, (self.kernel_size - 1) // 2, 1) + x = torch.cat([front, x, end], dim=1) + x = self.avg(x.permute(0, 2, 1)) + x = x.permute(0, 2, 1) + return x + + +class series_decomp(nn.Module): + """ + Series decomposition block + """ + def __init__(self, kernel_size): + super(series_decomp, self).__init__() + self.moving_avg = moving_avg(kernel_size, stride=1) + + def forward(self, x): + moving_mean = self.moving_avg(x) + res = x - moving_mean + return res, moving_mean + + + +# pos_encoding + +def PositionalEncoding(q_len, d_model, normalize=True): + pe = torch.zeros(q_len, d_model) + position = torch.arange(0, q_len).unsqueeze(1) + div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model)) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + if normalize: + pe = pe - pe.mean() + pe = pe / (pe.std() * 10) + return pe + +SinCosPosEncoding = PositionalEncoding + +def Coord2dPosEncoding(q_len, d_model, exponential=False, normalize=True, eps=1e-3, verbose=False): + x = .5 if exponential else 1 + i = 0 + for i in range(100): + cpe = 2 * (torch.linspace(0, 1, q_len).reshape(-1, 1) ** x) * (torch.linspace(0, 1, d_model).reshape(1, -1) ** x) - 1 + # pv(f'{i:4.0f} {x:5.3f} {cpe.mean():+6.3f}', verbose) + if abs(cpe.mean()) <= eps: break + elif cpe.mean() > eps: x += .001 + else: x -= .001 + i += 1 + if normalize: + cpe = cpe - cpe.mean() + cpe = cpe / (cpe.std() * 10) + return cpe + +def Coord1dPosEncoding(q_len, exponential=False, normalize=True): + cpe = (2 * (torch.linspace(0, 1, q_len).reshape(-1, 1)**(.5 if exponential else 1)) - 1) + if normalize: + cpe = cpe - cpe.mean() + cpe = cpe / (cpe.std() * 10) + return cpe + +def positional_encoding(pe, learn_pe, q_len, d_model): + # Positional encoding + if pe == None: + W_pos = torch.empty((q_len, d_model)) # pe = None and learn_pe = False can be used to measure impact of pe + nn.init.uniform_(W_pos, -0.02, 0.02) + learn_pe = False + elif pe == 'zero': + W_pos = torch.empty((q_len, 1)) + nn.init.uniform_(W_pos, -0.02, 0.02) + elif pe == 'zeros': + W_pos = torch.empty((q_len, d_model)) + nn.init.uniform_(W_pos, -0.02, 0.02) + elif pe == 'normal' or pe == 'gauss': + W_pos = torch.zeros((q_len, 1)) + torch.nn.init.normal_(W_pos, mean=0.0, std=0.1) + elif pe == 'uniform': + W_pos = torch.zeros((q_len, 1)) + nn.init.uniform_(W_pos, a=0.0, b=0.1) + elif pe == 'lin1d': W_pos = Coord1dPosEncoding(q_len, exponential=False, normalize=True) + elif pe == 'exp1d': W_pos = Coord1dPosEncoding(q_len, exponential=True, normalize=True) + elif pe == 'lin2d': W_pos = Coord2dPosEncoding(q_len, d_model, exponential=False, normalize=True) + elif pe == 'exp2d': W_pos = Coord2dPosEncoding(q_len, d_model, exponential=True, normalize=True) + elif pe == 'sincos': W_pos = PositionalEncoding(q_len, d_model, normalize=True) + else: raise ValueError(f"{pe} is not a valid pe (positional encoder. Available types: 'gauss'=='normal', \ + 'zeros', 'zero', uniform', 'lin1d', 'exp1d', 'lin2d', 'exp2d', 'sincos', None.)") + return nn.Parameter(W_pos, requires_grad=learn_pe) diff --git a/layers/Transformer_EncDec.py b/layers/Transformer_EncDec.py new file mode 100644 index 0000000..dabf4c2 --- /dev/null +++ b/layers/Transformer_EncDec.py @@ -0,0 +1,135 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class ConvLayer(nn.Module): + def __init__(self, c_in): + super(ConvLayer, self).__init__() + self.downConv = nn.Conv1d(in_channels=c_in, + out_channels=c_in, + kernel_size=3, + padding=2, + padding_mode='circular') + self.norm = nn.BatchNorm1d(c_in) + self.activation = nn.ELU() + self.maxPool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1) + + def forward(self, x): + x = self.downConv(x.permute(0, 2, 1)) + x = self.norm(x) + x = self.activation(x) + x = self.maxPool(x) + x = x.transpose(1, 2) + return x + + +class EncoderLayer(nn.Module): + def __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation="relu"): + super(EncoderLayer, self).__init__() + d_ff = d_ff or 4 * d_model + self.attention = attention + self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1) + self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1) + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.dropout = nn.Dropout(dropout) + self.activation = F.relu if activation == "relu" else F.gelu + + def forward(self, x, attn_mask=None, tau=None, delta=None): + new_x, attn = self.attention( + x, x, x, + attn_mask=attn_mask, + tau=tau, delta=delta + ) + x = x + self.dropout(new_x) + + y = x = self.norm1(x) + y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1)))) + y = self.dropout(self.conv2(y).transpose(-1, 1)) + + return self.norm2(x + y), attn + + +class Encoder(nn.Module): + def __init__(self, attn_layers, conv_layers=None, norm_layer=None): + super(Encoder, self).__init__() + self.attn_layers = nn.ModuleList(attn_layers) + self.conv_layers = nn.ModuleList(conv_layers) if conv_layers is not None else None + self.norm = norm_layer + + def forward(self, x, attn_mask=None, tau=None, delta=None): + # x [B, L, D] + attns = [] + if self.conv_layers is not None: + for i, (attn_layer, conv_layer) in enumerate(zip(self.attn_layers, self.conv_layers)): + delta = delta if i == 0 else None + x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta) + x = conv_layer(x) + attns.append(attn) + x, attn = self.attn_layers[-1](x, tau=tau, delta=None) + attns.append(attn) + else: + for attn_layer in self.attn_layers: + x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta) + attns.append(attn) + + if self.norm is not None: + x = self.norm(x) + + return x, attns + + +class DecoderLayer(nn.Module): + def __init__(self, self_attention, cross_attention, d_model, d_ff=None, + dropout=0.1, activation="relu"): + super(DecoderLayer, self).__init__() + d_ff = d_ff or 4 * d_model + self.self_attention = self_attention + self.cross_attention = cross_attention + self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1) + self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1) + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.norm3 = nn.LayerNorm(d_model) + self.dropout = nn.Dropout(dropout) + self.activation = F.relu if activation == "relu" else F.gelu + + def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None): + x = x + self.dropout(self.self_attention( + x, x, x, + attn_mask=x_mask, + tau=tau, delta=None + )[0]) + x = self.norm1(x) + + x = x + self.dropout(self.cross_attention( + x, cross, cross, + attn_mask=cross_mask, + tau=tau, delta=delta + )[0]) + + y = x = self.norm2(x) + y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1)))) + y = self.dropout(self.conv2(y).transpose(-1, 1)) + + return self.norm3(x + y) + + +class Decoder(nn.Module): + def __init__(self, layers, norm_layer=None, projection=None): + super(Decoder, self).__init__() + self.layers = nn.ModuleList(layers) + self.norm = norm_layer + self.projection = projection + + def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None): + for layer in self.layers: + x = layer(x, cross, x_mask=x_mask, cross_mask=cross_mask, tau=tau, delta=delta) + + if self.norm is not None: + x = self.norm(x) + + if self.projection is not None: + x = self.projection(x) + return x diff --git a/layers/cross_channel_attn.py b/layers/cross_channel_attn.py new file mode 100644 index 0000000..b2cc62e --- /dev/null +++ b/layers/cross_channel_attn.py @@ -0,0 +1,132 @@ +import torch +from torch import nn +import math + +class CrossChannelAttention(nn.Module): + """ + 对每个通道 i: + Query: 预测区间长度 pred_len 的可学习向量 (不依赖历史值) + Key/Value: 其它通道所有时间点的标量 -> 线性投影 + 输出: [B, pred_len, C] + 复杂度: O(C * pred_len * (C-1) * L),当 C 很小(如 <= 64)通常可接受 + """ + def __init__(self, seq_len, pred_len, c_in, + d_model=64, n_heads=4, + dropout=0.1, use_layernorm=True): + super().__init__() + assert d_model % n_heads == 0, "d_model 必须能整除 n_heads" + self.seq_len = seq_len + self.pred_len = pred_len + self.c_in = c_in + self.d_model = d_model + self.n_heads = n_heads + self.head_dim = d_model // n_heads + + # 可学习的预测步 Query Embeddings: [pred_len, d_model] + self.query_embed = nn.Parameter(torch.randn(pred_len, d_model)) + + # 标量值 -> d_model 投影 (共享给 Key/Value) + self.key_proj = nn.Linear(1, d_model) + self.value_proj = nn.Linear(1, d_model) + + # 输出压缩成标量 + self.out_proj = nn.Linear(d_model, 1) + + self.attn_dropout = nn.Dropout(dropout) + self.proj_dropout = nn.Dropout(dropout) + + self.use_ln = use_layernorm + if use_layernorm: + self.ln_q = nn.LayerNorm(d_model) + self.ln_kv = nn.LayerNorm(d_model) + + # 可选的时间 + 通道位置编码(简单可学习向量) + self.time_pos = nn.Parameter(torch.zeros(seq_len, d_model)) + self.channel_pos = nn.Parameter(torch.zeros(c_in, d_model)) + nn.init.normal_(self.time_pos, std=0.02) + nn.init.normal_(self.channel_pos, std=0.02) + nn.init.normal_(self.query_embed, std=0.02) + + def split_heads(self, x): + # x: [B, T, d_model] -> [B, n_heads, T, head_dim] + B, T, D = x.shape + return x.view(B, T, self.n_heads, self.head_dim).transpose(1, 2) + + def merge_heads(self, x): + # x: [B, n_heads, T, head_dim] -> [B, T, d_model] + B, H, T, Hd = x.shape + return x.transpose(1, 2).contiguous().view(B, T, H * Hd) + + def forward(self, x): + """ + x: [B, L, C] + 返回: cross_out [B, pred_len, C] + """ + B, L, C = x.shape + assert L == self.seq_len and C == self.c_in + + # 准备 K,V: 对每个通道的时间序列做投影 + # 先变成 [B, C, L, 1] -> Linear -> [B, C, L, d] + xc = x.permute(0, 2, 1).unsqueeze(-1) # [B, C, L, 1] + K = self.key_proj(xc) # [B, C, L, d_model] + V = self.value_proj(xc) # [B, C, L, d_model] + + # 加位置编码(通道 + 时间) + # broadcast: time_pos [L, d_model] -> [1,1,L,d]; channel_pos [C,d_model]->[1,C,1,d] + K = K + self.time_pos.unsqueeze(0).unsqueeze(0) + self.channel_pos.unsqueeze(0).unsqueeze(2) + V = V + self.time_pos.unsqueeze(0).unsqueeze(0) + self.channel_pos.unsqueeze(0).unsqueeze(2) + + if self.use_ln: + K = self.ln_kv(K) + V = self.ln_kv(V) + + cross_outputs = [] + + # 预备 Query(所有通道共享 query 基形,再可选加通道偏移) + base_q = self.query_embed # [pred_len, d_model] + + for ci in range(C): + # 构造其它通道索引 + if C == 1: + # 单通道退化: 直接输出零或复制自身 + zero_out = x[:, -self.pred_len:, ci:ci+1] + cross_outputs.append(zero_out) + continue + other_idx = [j for j in range(C) if j != ci] + + K_i = K[:, other_idx, :, :] # [B, C-1, L, d_model] + V_i = V[:, other_idx, :, :] # [B, C-1, L, d_model] + + # 拉平成 token 维度 + K_i = K_i.reshape(B, (C-1)*L, self.d_model) # [B, (C-1)*L, d] + V_i = V_i.reshape(B, (C-1)*L, self.d_model) + + # Query: 复制到 batch,并加通道偏移(可选) + Q_i = base_q.unsqueeze(0).expand(B, self.pred_len, self.d_model) # [B, pred_len, d_model] + Q_i = Q_i + self.channel_pos[ci].unsqueeze(0).unsqueeze(0) + + if self.use_ln: + Q_i = self.ln_q(Q_i) + + # 分头 + Qh = self.split_heads(Q_i) # [B, H, pred_len, head_dim] + Kh = self.split_heads(K_i) # [B, H, (C-1)*L, head_dim] + Vh = self.split_heads(V_i) # [B, H, (C-1)*L, head_dim] + + # Attention: [B, H, pred_len, (C-1)*L] + scores = torch.matmul(Qh, Kh.transpose(-2, -1)) / math.sqrt(self.head_dim) + attn = torch.softmax(scores, dim=-1) + attn = self.attn_dropout(attn) + + # 上下文 + ctx = torch.matmul(attn, Vh) # [B, H, pred_len, head_dim] + ctx = self.merge_heads(ctx) # [B, pred_len, d_model] + ctx = self.proj_dropout(ctx) + + # 输出压缩到标量: [B, pred_len, 1] + out_ci = self.out_proj(ctx) + cross_outputs.append(out_ci) # list of [B, pred_len, 1] + + cross_out = torch.cat(cross_outputs, dim=-1) # [B, pred_len, C] + return cross_out + diff --git a/layers/mixer.py b/layers/mixer.py new file mode 100644 index 0000000..f814183 --- /dev/null +++ b/layers/mixer.py @@ -0,0 +1,136 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import math + +class ChannelGraphMixer(nn.Module): + """ + 在 PatchTST 的通道独立输出上做一次可学习的稀疏跨通道交互. + 输入 z_list : 长度=M 的 list, 每个元素形状 [B, D] (单通道表示) + 输出 list, 形状同输入 + """ + def __init__(self, n_channel: int, dim: int, k: int = 5, tau: float = 0.2): + super().__init__() + self.k = k + self.tau = tau + self.dim = dim + self.A = nn.Parameter(torch.zeros(n_channel, n_channel)) # 可学习邻接 + self.mix = nn.Linear(dim, dim, bias=False) # 通道映射 + # 通道注意力过滤 CAF + self.se = nn.Sequential( + nn.Linear(dim, dim // 4, bias=False), + nn.ReLU(), + nn.Linear(dim // 4, 1, bias=False), + nn.Sigmoid() + ) + + # -------- util: 生成 row-wise top-k 稀疏邻接 ------------- + def _row_sparse(self, logits: torch.Tensor) -> torch.Tensor: + """ + logits: [M,M]. 返回一个稀疏矩阵, 每行只保留 top-k, 其余为 0 + """ + # Straight-through Gumbel–Softmax + g = -torch.empty_like(logits).exponential_().log() + y = (logits + g) / self.tau + probs = F.softmax(y, dim=-1) # 可导 + topk_val, _ = torch.topk(probs, self.k, dim=-1) + thr = topk_val[..., -1].unsqueeze(-1) # 每行阈值 + sparse = torch.where(probs >= thr, probs, torch.zeros_like(probs)) + return sparse.detach() + probs - probs.detach() # ST-estimator + + # -------------------------------------------------------- + def forward(self, z_list): + M = len(z_list) + B = z_list[0].shape[0] + Z = torch.stack(z_list, dim=1) # [B,M,D] + alpha = self.se(Z).squeeze(-1) # [B,M] 通道重要性 + + A_sparse = self._row_sparse(self.A) # [M,M] + + out = [] + for i in range(M): + # 汇聚来自其余通道的表示 + agg = 0 + for j in range(M): + if A_sparse[i, j] != 0: + agg = agg + alpha[:, j:j+1] * A_sparse[i, j] * self.mix(z_list[j]) + out.append(z_list[i] + agg) # 残差 + return out + + +class HierarchicalGraphMixer(nn.Module): + """ + 分层图混合器,同时考虑宏观通道关系和微观 Patch 级别注意力。 + 输入 z : 形状为 [B, C, N, D] 的张量 + 输出 z_out : 形状同输入 + """ + def __init__(self, n_channel: int, dim: int, k: int = 5, tau: float = 0.2): + super().__init__() + self.k = k + self.tau = tau + + # Level 1: Channel Graph + self.A = nn.Parameter(torch.zeros(n_channel, n_channel)) + self.se = nn.Sequential( + nn.Linear(dim, dim // 4, bias=False), nn.ReLU(), + nn.Linear(dim // 4, 1, bias=False), nn.Sigmoid() + ) + + # Level 2: Patch Cross-Attention + self.q_proj = nn.Linear(dim, dim) + self.k_proj = nn.Linear(dim, dim) + self.v_proj = nn.Linear(dim, dim) + self.out_proj = nn.Linear(dim, dim) + self.norm = nn.LayerNorm(dim) + + def _row_sparse(self, logits: torch.Tensor) -> torch.Tensor: + # (Gumbel-Softmax 函数,无需改变) + g = -torch.empty_like(logits).exponential_().log() + y = (logits + g) / self.tau + probs = F.softmax(y, dim=-1) + topk_val, _ = torch.topk(probs, self.k, dim=-1) + thr = topk_val[..., -1].unsqueeze(-1) + sparse = torch.where(probs >= thr, probs, torch.zeros_like(probs)) + return sparse.detach() + probs - probs.detach() + + def forward(self, z): + # z 的形状: [B, C, N, D] + B, C, N, D = z.shape + + # --- Level 1: 计算宏观权重 --- + A_sparse = self._row_sparse(self.A) # 通道连接稀疏图 A_sparse: [C, C] + + # --- Level 2: 跨通道 Patch 交互 --- + out_z = torch.zeros_like(z) + for i in range(C): # 遍历每个目标通道 i + target_z = z[:, i, :, :] # [B, N, D] + + # 准备聚合来自其他通道的 patch 级别上下文 + aggregated_context = torch.zeros_like(target_z) + + for j in range(C): # 遍历每个源通道 j + if A_sparse[i, j] != 0: + source_z = z[:, j, :, :] # [B, N, D] + + # --- 执行交叉注意力 --- + Q = self.q_proj(target_z) # Query 来自目标通道 i + K = self.k_proj(source_z) # Key 来自源通道 j + V = self.v_proj(source_z) # Value 来自源通道 j + + attn_scores = torch.bmm(Q, K.transpose(1, 2)) / math.sqrt(D) + attn_probs = F.softmax(attn_scores, dim=-1) # [B, N, N] + + context = torch.bmm(attn_probs, V) # [B, N, D], 从 j 聚合到 i 的上下文 + + # 宏观权重2 (通道连接强度): A_sparse[i, j] -> 标量 + + + # 加权上下文 + weighted_context = A_sparse[i, j] * context + aggregated_context = aggregated_context + weighted_context + + # 将聚合后的上下文通过输出层,并与原始目标表示相加(残差连接) + # LayerNorm 增加稳定性 + out_z[:, i, :, :] = self.norm(target_z + self.out_proj(aggregated_context)) + + return out_z diff --git a/layers/ps_loss.py b/layers/ps_loss.py new file mode 100644 index 0000000..8cdb1cc --- /dev/null +++ b/layers/ps_loss.py @@ -0,0 +1,239 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np + + +class PSLoss(nn.Module): + """ + Patch-wise Structural (PS) Loss for time series forecasting. + + Implements the loss function described in the paper that combines: + 1. Fourier-based Adaptive Patching (FAP) + 2. Patch-wise Structural Loss (correlation, variance, mean) + 3. Gradient-based Dynamic Weighting (GDW) + """ + + def __init__(self, patch_len_threshold=64, lambda_ps=5.0, use_gdw=True): + super(PSLoss, self).__init__() + self.patch_len_threshold = patch_len_threshold + self.lambda_ps = lambda_ps + self.use_gdw = use_gdw + self.kl_loss = nn.KLDivLoss(reduction='none') + self.mse_loss = nn.MSELoss() + + def create_patches(self, x, patch_len, stride): + """ + Create patches from time series data. + + Args: + x: Input tensor of shape [B, L, C] + patch_len: Length of each patch + stride: Stride for patching + + Returns: + patches: Tensor of shape [B, C, N, P] where N is number of patches, P is patch length + """ + B, L, C = x.shape + + num_patches = (L - patch_len) // stride + 1 + patches = x.unfold(1, patch_len, stride) # [B, N, C, P] + patches = patches.permute(0, 2, 1, 3) # [B, C, N, P] + + return patches + + def fourier_based_adaptive_patching(self, true, pred): + """ + Fourier-based Adaptive Patching (FAP) to determine optimal patch length. + + Args: + true: Ground truth tensor [B, L, C] + pred: Prediction tensor [B, L, C] + + Returns: + true_patch: Patches from ground truth [B, C, N, P] + pred_patch: Patches from prediction [B, C, N, P] + """ + # Get dominant frequency from ground truth + true_fft = torch.fft.rfft(true, dim=1) + frequency_list = torch.abs(true_fft).mean(0).mean(-1) + frequency_list[:1] = 0.0 # Remove DC component + top_index = torch.argmax(frequency_list) + + # Calculate period and patch length + period = true.shape[1] // top_index if top_index > 0 else self.patch_len_threshold + patch_len = min(period // 2, self.patch_len_threshold) + patch_len = max(patch_len, 4) # Minimum patch length + stride = max(patch_len // 2, 1) + + # Create patches + true_patch = self.create_patches(true, patch_len, stride) + pred_patch = self.create_patches(pred, patch_len, stride) + + return true_patch, pred_patch + + def patch_wise_structural_loss(self, true_patch, pred_patch): + """ + Calculate patch-wise structural loss components. + + Args: + true_patch: Ground truth patches [B, C, N, P] + pred_patch: Prediction patches [B, C, N, P] + + Returns: + corr_loss: Correlation loss + var_loss: Variance loss + mean_loss: Mean loss + """ + # Calculate patch statistics + true_patch_mean = torch.mean(true_patch, dim=-1, keepdim=True) # [B, C, N, 1] + pred_patch_mean = torch.mean(pred_patch, dim=-1, keepdim=True) # [B, C, N, 1] + + true_patch_var = torch.var(true_patch, dim=-1, keepdim=True, unbiased=False) + pred_patch_var = torch.var(pred_patch, dim=-1, keepdim=True, unbiased=False) + true_patch_std = torch.sqrt(true_patch_var + 1e-8) + pred_patch_std = torch.sqrt(pred_patch_var + 1e-8) + + # Calculate covariance + true_centered = true_patch - true_patch_mean + pred_centered = pred_patch - pred_patch_mean + covariance = torch.mean(true_centered * pred_centered, dim=-1, keepdim=True) + + # 1. Correlation Loss (based on Pearson correlation coefficient) + correlation = covariance / (true_patch_std * pred_patch_std + 1e-8) + corr_loss = (1.0 - correlation).mean() + + # 2. Variance Loss (using KL divergence of softmax distributions) + true_patch_softmax = F.softmax(true_patch, dim=-1) + pred_patch_log_softmax = F.log_softmax(pred_patch, dim=-1) + var_loss = self.kl_loss(pred_patch_log_softmax, true_patch_softmax).sum(dim=-1).mean() + + # 3. Mean Loss (MAE between patch means) + mean_loss = torch.abs(true_patch_mean - pred_patch_mean).mean() + + return corr_loss, var_loss, mean_loss + + def gradient_based_dynamic_weighting(self, model, corr_loss, var_loss, mean_loss, true, pred): + """ + Gradient-based Dynamic Weighting (GDW) for balancing loss components. + + Args: + model: The neural network model + corr_loss, var_loss, mean_loss: Loss components + true: Ground truth tensor [B, L, C] + pred: Prediction tensor [B, L, C] + + Returns: + alpha, beta, gamma: Dynamic weights for the three loss components + """ + if not self.use_gdw or not self.training: + return 1.0, 1.0, 1.0 + + try: + # Get model parameters for gradient calculation + if hasattr(model, 'projector'): + params = list(model.projector.parameters()) + elif hasattr(model, 'projection'): + params = list(model.projection.parameters()) + else: + # Use output layer parameters + params = list(model.parameters())[-2:] # Last linear layer + + if not params: + return 1.0, 1.0, 1.0 + + # Calculate gradients for each loss component + corr_grad = torch.autograd.grad(corr_loss, params, retain_graph=True, allow_unused=True) + var_grad = torch.autograd.grad(var_loss, params, retain_graph=True, allow_unused=True) + mean_grad = torch.autograd.grad(mean_loss, params, retain_graph=True, allow_unused=True) + + # Filter out None gradients and calculate norms + corr_grad_norm = sum(g.norm().item() for g in corr_grad if g is not None) + var_grad_norm = sum(g.norm().item() for g in var_grad if g is not None) + mean_grad_norm = sum(g.norm().item() for g in mean_grad if g is not None) + + if corr_grad_norm == 0 or var_grad_norm == 0 or mean_grad_norm == 0: + return 1.0, 1.0, 1.0 + + # Calculate average gradient magnitude + grad_avg = (corr_grad_norm + var_grad_norm + mean_grad_norm) / 3.0 + + # Calculate dynamic weights + alpha = grad_avg / corr_grad_norm + beta = grad_avg / var_grad_norm + gamma = grad_avg / mean_grad_norm + + # Calculate scaling factors for gamma + true_flat = true.view(-1, true.shape[-1]) + pred_flat = pred.view(-1, pred.shape[-1]) + + true_mean = torch.mean(true_flat, dim=0, keepdim=True) + pred_mean = torch.mean(pred_flat, dim=0, keepdim=True) + true_std = torch.std(true_flat, dim=0, keepdim=True) + 1e-8 + pred_std = torch.std(pred_flat, dim=0, keepdim=True) + 1e-8 + + covariance = torch.mean((true_flat - true_mean) * (pred_flat - pred_mean), dim=0, keepdim=True) + correlation_sim = covariance / (true_std * pred_std) + correlation_sim = (1.0 + correlation_sim.mean()) * 0.5 + + variance_sim = (2 * true_std.mean() * pred_std.mean()) / (true_std.mean()**2 + pred_std.mean()**2) + + c = 0.5 * (1 + correlation_sim) + v = variance_sim + gamma = gamma * c * v + + return alpha.item(), beta.item(), gamma.item() + + except Exception: + # Fallback to equal weights if gradient calculation fails + return 1.0, 1.0, 1.0 + + def forward(self, pred, true, model=None): + """ + Forward pass of PS Loss. + + Args: + pred: Predictions [B, L, C] + true: Ground truth [B, L, C] + model: Neural network model (for gradient-based weighting) + + Returns: + total_loss: Combined MSE + PS loss + loss_dict: Dictionary with individual loss components + """ + # Standard MSE loss + mse_loss = self.mse_loss(pred, true) + + # Fourier-based adaptive patching + true_patch, pred_patch = self.fourier_based_adaptive_patching(true, pred) + + # Patch-wise structural loss + corr_loss, var_loss, mean_loss = self.patch_wise_structural_loss(true_patch, pred_patch) + + # Gradient-based dynamic weighting + if model is not None and self.use_gdw: + alpha, beta, gamma = self.gradient_based_dynamic_weighting( + model, corr_loss, var_loss, mean_loss, true, pred + ) + else: + alpha, beta, gamma = 1.0, 1.0, 1.0 + + # Combine PS loss components + ps_loss = alpha * corr_loss + beta * var_loss + gamma * mean_loss + + # Total loss: MSE + λ * PS_loss + total_loss = mse_loss + self.lambda_ps * ps_loss + + loss_dict = { + 'mse_loss': mse_loss.item(), + 'ps_loss': ps_loss.item(), + 'corr_loss': corr_loss.item(), + 'var_loss': var_loss.item(), + 'mean_loss': mean_loss.item(), + 'alpha': alpha, + 'beta': beta, + 'gamma': gamma, + 'total_loss': total_loss.item() + } + + return total_loss, loss_dict