import torch
import torch.nn as nn
import torch.nn.functional as F

from layers.Embed import PositionalEmbedding
from layers.SelfAttention_Family import FullAttention, AttentionLayer
from layers.Transformer_EncDec import EncoderLayer


class TSTEncoder(nn.Module):
    """
    Transformer encoder for PatchTST, adapted to the Time-Series-Library-main style.
    """

    def __init__(self, q_len, d_model, n_heads, d_k=None, d_v=None, d_ff=None,
                 norm='BatchNorm', attn_dropout=0., dropout=0., activation='gelu',
                 n_layers=1):
        super().__init__()
        # q_len, d_k, d_v and norm are kept for interface compatibility with the
        # original PatchTST encoder; the Time-Series-Library layers below derive
        # the head dimension from d_model // n_heads and use LayerNorm internally.
        d_k = d_model // n_heads if d_k is None else d_k
        d_v = d_model // n_heads if d_v is None else d_v
        d_ff = d_model * 4 if d_ff is None else d_ff

        self.layers = nn.ModuleList([
            EncoderLayer(
                AttentionLayer(
                    FullAttention(False, attention_dropout=attn_dropout),
                    d_model, n_heads
                ),
                d_model, d_ff, dropout=dropout, activation=activation
            ) for _ in range(n_layers)
        ])

    def forward(self, src, attn_mask=None):
        # src: [bs * nvars x patch_num x d_model]
        output = src
        attns = []
        for layer in self.layers:
            output, attn = layer(output, attn_mask=attn_mask)
            attns.append(attn)
        return output, attns


class TSTiEncoder(nn.Module):
    """
    Channel-independent TST encoder adapted to Time-Series-Library-main:
    every variable (channel) is encoded independently while sharing the same
    Transformer weights across channels.
    """

    def __init__(self, c_in, patch_num, patch_len, max_seq_len=1024, n_layers=3,
                 d_model=128, n_heads=16, d_k=None, d_v=None, d_ff=256,
                 norm='BatchNorm', attn_dropout=0., dropout=0., activation="gelu"):
        super().__init__()
        self.patch_num = patch_num
        self.patch_len = patch_len

        # Input encoding: project each patch of length patch_len onto a
        # d_model-dimensional vector space
        self.W_P = nn.Linear(patch_len, d_model)

        # Positional encoding using Time-Series-Library-main's PositionalEmbedding
        self.pos_embedding = PositionalEmbedding(d_model, max_len=max_seq_len)

        # Residual dropout
        self.dropout = nn.Dropout(dropout)

        # Encoder
        self.encoder = TSTEncoder(patch_num, d_model, n_heads, d_k=d_k, d_v=d_v,
                                  d_ff=d_ff, norm=norm, attn_dropout=attn_dropout,
                                  dropout=dropout, activation=activation,
                                  n_layers=n_layers)

    def forward(self, x):
        # x: [bs x nvars x patch_num x patch_len]
        bs, n_vars, patch_num, patch_len = x.shape

        # Input encoding: project patch_len to d_model
        x = self.W_P(x)                                              # [bs x nvars x patch_num x d_model]

        # Channel independence: fold the channel dimension into the batch so each
        # variable attends over its own patch sequence
        u = torch.reshape(x, (bs * n_vars, patch_num, x.shape[-1]))  # [bs * nvars x patch_num x d_model]

        # Add positional encoding; PositionalEmbedding returns [1 x patch_num x d_model]
        # and is broadcast over the (batch * channel) dimension
        pos = self.pos_embedding(u)
        u = self.dropout(u + pos[:, :patch_num, :])

        # Encoder
        z, attns = self.encoder(u)                                   # [bs * nvars x patch_num x d_model]

        # Unfold the batch and channel dimensions again
        z = torch.reshape(z, (bs, n_vars, patch_num, z.shape[-1]))   # [bs x nvars x patch_num x d_model]
        z = z.permute(0, 1, 3, 2)                                    # [bs x nvars x d_model x patch_num]

        return z
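

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the adapted module): runs TSTiEncoder on
# random patched input to illustrate the expected tensor shapes documented in
# the comments above. The batch size, channel count, and patch settings below
# are arbitrary, and the `layers.*` imports must resolve against the
# Time-Series-Library-main package layout for this to execute.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    bs, n_vars, patch_num, patch_len = 4, 7, 12, 16

    encoder = TSTiEncoder(
        c_in=n_vars,
        patch_num=patch_num,
        patch_len=patch_len,
        n_layers=2,
        d_model=128,
        n_heads=16,
        d_ff=256,
        dropout=0.1,
    )

    # Random patched input: [bs x nvars x patch_num x patch_len]
    x = torch.randn(bs, n_vars, patch_num, patch_len)
    z = encoder(x)

    # Expected output: [bs x nvars x d_model x patch_num]
    print(z.shape)  # torch.Size([4, 7, 128, 12])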