import torch
import torch.nn as nn

from layers.Embed import PositionalEmbedding
from layers.SelfAttention_Family import FullAttention, AttentionLayer
from layers.Transformer_EncDec import EncoderLayer


class TSTEncoder(nn.Module):
    """
    Transformer encoder for PatchTST, adapted to the Time-Series-Library-main
    building blocks (EncoderLayer / AttentionLayer / FullAttention).
    """

    def __init__(self, q_len, d_model, n_heads, d_k=None, d_v=None, d_ff=None,
                 norm='BatchNorm', attn_dropout=0., dropout=0., activation='gelu',
                 n_layers=1):
        super().__init__()

        # Per-head dimensions and feed-forward width default to the usual
        # Transformer choices. d_k, d_v and norm are accepted for interface
        # compatibility with the original PatchTST code; the library's
        # AttentionLayer and EncoderLayer use their own defaults
        # (d_model // n_heads per head, LayerNorm).
        d_k = d_model // n_heads if d_k is None else d_k
        d_v = d_model // n_heads if d_v is None else d_v
        d_ff = d_model * 4 if d_ff is None else d_ff

        self.layers = nn.ModuleList([
            EncoderLayer(
                AttentionLayer(
                    FullAttention(False, attention_dropout=attn_dropout),
                    d_model, n_heads
                ),
                d_model,
                d_ff,
                dropout=dropout,
                activation=activation
            ) for _ in range(n_layers)
        ])

    def forward(self, src, attn_mask=None):
        # src: [batch x q_len x d_model]
        output = src
        attns = []
        for layer in self.layers:
            output, attn = layer(output, attn_mask)
            attns.append(attn)
        return output, attns


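# Example (a minimal sketch with made-up sizes, assuming the library layers
# imported above): a 2-layer encoder over 12 patches of a 128-dim embedding.
#
#   enc = TSTEncoder(q_len=12, d_model=128, n_heads=16, n_layers=2)
#   out, attns = enc(torch.randn(8 * 7, 12, 128))   # 8 series x 7 channels
#   # out: [56 x 12 x 128]; attns holds one entry per layer (None when
#   # FullAttention is built with output_attention=False, the default).
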
class TSTiEncoder(nn.Module):
    """
    Channel-independent TST encoder adapted for Time-Series-Library-main:
    every channel (variable) is encoded independently by a shared backbone.
    """

    def __init__(self, c_in, patch_num, patch_len, max_seq_len=1024,
                 n_layers=3, d_model=128, n_heads=16, d_k=None, d_v=None,
                 d_ff=256, norm='BatchNorm', attn_dropout=0., dropout=0.,
                 activation="gelu"):
        super().__init__()

        self.patch_num = patch_num
        self.patch_len = patch_len

        # Input encoding - projection of each patch (length patch_len) onto a
        # d_model-dimensional vector space
        self.W_P = nn.Linear(patch_len, d_model)

        # Positional encoding using Time-Series-Library-main's PositionalEmbedding
        self.pos_embedding = PositionalEmbedding(d_model, max_len=max_seq_len)

        # Residual dropout
        self.dropout = nn.Dropout(dropout)

        # Encoder backbone shared across channels
        self.encoder = TSTEncoder(patch_num, d_model, n_heads, d_k=d_k, d_v=d_v,
                                  d_ff=d_ff, norm=norm, attn_dropout=attn_dropout,
                                  dropout=dropout, activation=activation,
                                  n_layers=n_layers)

    def forward(self, x):
        # x: [bs x nvars x patch_num x patch_len]
        bs, n_vars, patch_num, patch_len = x.shape

        # Input encoding: project patch_len to d_model
        x = self.W_P(x)                              # [bs x nvars x patch_num x d_model]

        # Reshape for attention: fold the channel dimension into the batch so
        # each channel is processed independently (channel independence)
        u = torch.reshape(x, (bs * n_vars, patch_num, x.shape[-1]))  # [bs*nvars x patch_num x d_model]

        # Add positional encoding (PositionalEmbedding returns [1 x patch_num x d_model],
        # which broadcasts over the folded batch dimension)
        pos = self.pos_embedding(u)
        u = self.dropout(u + pos[:, :patch_num, :])

        # Encoder
        z, attns = self.encoder(u)                   # [bs*nvars x patch_num x d_model]

        # Reshape back to separate batch and channel dimensions
        z = torch.reshape(z, (bs, n_vars, patch_num, z.shape[-1]))   # [bs x nvars x patch_num x d_model]
        z = z.permute(0, 1, 3, 2)                    # [bs x nvars x d_model x patch_num]

        return z
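

if __name__ == "__main__":
    # Quick shape check (a minimal sketch with made-up sizes): 8 series,
    # 7 channels, 12 patches of length 16, projected to d_model=128.
    # Requires the Time-Series-Library `layers` package on PYTHONPATH.
    encoder = TSTiEncoder(c_in=7, patch_num=12, patch_len=16,
                          n_layers=2, d_model=128, n_heads=16, d_ff=256)
    dummy = torch.randn(8, 7, 12, 16)    # [bs x nvars x patch_num x patch_len]
    out = encoder(dummy)
    print(out.shape)                     # expected: torch.Size([8, 7, 128, 12])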