first commit
layers/TSTEncoder.py (new file, 91 lines)
@@ -0,0 +1,91 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from layers.Embed import PositionalEmbedding
from layers.SelfAttention_Family import FullAttention, AttentionLayer
from layers.Transformer_EncDec import EncoderLayer


class TSTEncoder(nn.Module):
    """
    Transformer encoder for PatchTST, adapted for Time-Series-Library-main style
    """
    def __init__(self, q_len, d_model, n_heads, d_k=None, d_v=None, d_ff=None,
                 norm='BatchNorm', attn_dropout=0., dropout=0., activation='gelu',
                 n_layers=1):
        super().__init__()

        # d_k, d_v and d_ff default to the usual Transformer sizes; q_len, d_k,
        # d_v and norm are kept for compatibility with the PatchTST signature
        # but are not passed on to the Time-Series-Library layers below.
        d_k = d_model // n_heads if d_k is None else d_k
        d_v = d_model // n_heads if d_v is None else d_v
        d_ff = d_model * 4 if d_ff is None else d_ff

        self.layers = nn.ModuleList([
            EncoderLayer(
                AttentionLayer(
                    FullAttention(False, attention_dropout=attn_dropout),
                    d_model, n_heads
                ),
                d_model,
                d_ff,
                dropout=dropout,
                activation=activation
            ) for _ in range(n_layers)
        ])

    def forward(self, src, attn_mask=None):
        output = src
        attns = []
        for layer in self.layers:
            output, attn = layer(output, attn_mask)
            attns.append(attn)
        return output, attns

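# Illustrative usage of TSTEncoder (the sizes below are assumptions, not part
# of the original commit): it consumes a [batch x seq_len x d_model] tensor
# and returns the encoded tensor plus the per-layer attention outputs, e.g.
#
#     enc = TSTEncoder(q_len=12, d_model=128, n_heads=16, n_layers=2)
#     out, attns = enc(torch.randn(32, 12, 128))  # out: [32 x 12 x 128]

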
class TSTiEncoder(nn.Module):
    """
    Channel-independent TST Encoder adapted for Time-Series-Library-main
    """
    def __init__(self, c_in, patch_num, patch_len, max_seq_len=1024,
                 n_layers=3, d_model=128, n_heads=16, d_k=None, d_v=None,
                 d_ff=256, norm='BatchNorm', attn_dropout=0., dropout=0.,
                 activation="gelu"):
        super().__init__()

        self.patch_num = patch_num
        self.patch_len = patch_len

        # Input encoding - projection of feature vectors onto a d-dim vector space
        self.W_P = nn.Linear(patch_len, d_model)

        # Positional encoding using Time-Series-Library-main's PositionalEmbedding
        self.pos_embedding = PositionalEmbedding(d_model, max_len=max_seq_len)

        # Residual dropout
        self.dropout = nn.Dropout(dropout)

        # Encoder
        self.encoder = TSTEncoder(patch_num, d_model, n_heads, d_k=d_k, d_v=d_v,
                                  d_ff=d_ff, norm=norm, attn_dropout=attn_dropout,
                                  dropout=dropout, activation=activation, n_layers=n_layers)

    def forward(self, x):
        # x: [bs x nvars x patch_num x patch_len]
        bs, n_vars, patch_num, patch_len = x.shape

        # Input encoding: project patch_len to d_model
        x = self.W_P(x)  # x: [bs x nvars x patch_num x d_model]

        # Reshape for attention: combine batch and channel dimensions
        u = torch.reshape(x, (bs * n_vars, patch_num, x.shape[-1]))  # u: [bs * nvars x patch_num x d_model]

        # Add positional encoding (broadcast over the combined batch dimension)
        pos = self.pos_embedding(u)  # pos: [1 x patch_num x d_model]
        u = self.dropout(u + pos[:, :patch_num, :])

        # Encoder
        z, attns = self.encoder(u)  # z: [bs * nvars x patch_num x d_model]

        # Reshape back to separate batch and channel dimensions
        z = torch.reshape(z, (bs, n_vars, patch_num, z.shape[-1]))  # z: [bs x nvars x patch_num x d_model]
        z = z.permute(0, 1, 3, 2)  # z: [bs x nvars x d_model x patch_num]

        return z
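
# A minimal smoke test (sizes are illustrative assumptions, not part of the
# original commit); it assumes the Time-Series-Library-main modules imported
# above are importable. Shapes follow the comments in TSTiEncoder.forward.
if __name__ == "__main__":
    bs, n_vars, patch_num, patch_len = 4, 7, 12, 16
    encoder = TSTiEncoder(c_in=n_vars, patch_num=patch_num, patch_len=patch_len,
                          n_layers=2, d_model=128, n_heads=16, d_ff=256)
    x = torch.randn(bs, n_vars, patch_num, patch_len)
    z = encoder(x)
    print(z.shape)  # expected: torch.Size([4, 7, 128, 12]) = [bs, nvars, d_model, patch_num]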