first commit
131
layers/Crossformer_EncDec.py
Normal file
@@ -0,0 +1,131 @@
import torch
import torch.nn as nn
from einops import rearrange, repeat
from layers.SelfAttention_Family import TwoStageAttentionLayer


class SegMerging(nn.Module):
    """
    Merge win_size adjacent segments into one along the segment dimension,
    reducing the number of segments between encoder scales while keeping the
    feature width at d_model.
    """

    def __init__(self, d_model, win_size, norm_layer=nn.LayerNorm):
        super().__init__()
        self.d_model = d_model
        self.win_size = win_size
        self.linear_trans = nn.Linear(win_size * d_model, d_model)
        self.norm = norm_layer(win_size * d_model)

    def forward(self, x):
        # x: [batch_size, ts_d, seg_num, d_model]
        batch_size, ts_d, seg_num, d_model = x.shape
        pad_num = seg_num % self.win_size
        if pad_num != 0:
            # repeat the last segments so seg_num becomes divisible by win_size
            pad_num = self.win_size - pad_num
            x = torch.cat((x, x[:, :, -pad_num:, :]), dim=-2)

        # gather every win_size-th segment and concatenate along the feature dim
        seg_to_merge = []
        for i in range(self.win_size):
            seg_to_merge.append(x[:, :, i::self.win_size, :])
        x = torch.cat(seg_to_merge, -1)  # [B, ts_d, seg_num / win_size, win_size * d_model]

        x = self.norm(x)
        x = self.linear_trans(x)  # project back to d_model

        return x


class scale_block(nn.Module):
    """
    One encoder scale: optional segment merging (applied when win_size > 1),
    followed by `depth` TwoStageAttentionLayer blocks at this resolution.
    """

    def __init__(self, configs, win_size, d_model, n_heads, d_ff, depth, dropout,
                 seg_num=10, factor=10):
        super(scale_block, self).__init__()

        if win_size > 1:
            self.merge_layer = SegMerging(d_model, win_size, nn.LayerNorm)
        else:
            self.merge_layer = None

        self.encode_layers = nn.ModuleList()

        for i in range(depth):
            self.encode_layers.append(TwoStageAttentionLayer(configs, seg_num, factor, d_model,
                                                             n_heads, d_ff, dropout))

    def forward(self, x, attn_mask=None, tau=None, delta=None):
        _, ts_dim, _, _ = x.shape

        if self.merge_layer is not None:
            x = self.merge_layer(x)

        for layer in self.encode_layers:
            x = layer(x)

        return x, None


class Encoder(nn.Module):
    """
    Hierarchical encoder: runs the input through a stack of scale blocks and
    returns the representation at every scale, including the input itself.
    """

    def __init__(self, attn_layers):
        super(Encoder, self).__init__()
        self.encode_blocks = nn.ModuleList(attn_layers)

    def forward(self, x):
        encode_x = []
        encode_x.append(x)

        for block in self.encode_blocks:
            x, attns = block(x)
            encode_x.append(x)

        return encode_x, None


class DecoderLayer(nn.Module):
    """
    One decoder layer: self-attention over the decoder segments, cross-attention
    against the encoder output of the matching scale, a two-layer MLP, and a
    linear head that maps each segment embedding to a seg_len-long prediction.
    """

    def __init__(self, self_attention, cross_attention, seg_len, d_model, d_ff=None, dropout=0.1):
        super(DecoderLayer, self).__init__()
        self.self_attention = self_attention
        self.cross_attention = cross_attention
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.MLP1 = nn.Sequential(nn.Linear(d_model, d_model),
                                  nn.GELU(),
                                  nn.Linear(d_model, d_model))
        self.linear_pred = nn.Linear(d_model, seg_len)

    def forward(self, x, cross):
        # x: [batch, ts_d, out_seg_num, d_model], cross: [batch, ts_d, in_seg_num, d_model]
        batch = x.shape[0]
        x = self.self_attention(x)
        x = rearrange(x, 'b ts_d out_seg_num d_model -> (b ts_d) out_seg_num d_model')

        cross = rearrange(cross, 'b ts_d in_seg_num d_model -> (b ts_d) in_seg_num d_model')
        tmp, attn = self.cross_attention(x, cross, cross, None, None, None)
        x = x + self.dropout(tmp)
        y = x = self.norm1(x)
        y = self.MLP1(y)
        dec_output = self.norm2(x + y)

        dec_output = rearrange(dec_output, '(b ts_d) seg_dec_num d_model -> b ts_d seg_dec_num d_model', b=batch)
        layer_predict = self.linear_pred(dec_output)
        layer_predict = rearrange(layer_predict, 'b out_d seg_num seg_len -> b (out_d seg_num) seg_len')

        return dec_output, layer_predict


class Decoder(nn.Module):
    """
    Hierarchical decoder: each layer attends to the encoder output of one scale,
    and the per-layer segment predictions are summed to form the final forecast.
    """

    def __init__(self, layers):
        super(Decoder, self).__init__()
        self.decode_layers = nn.ModuleList(layers)

    def forward(self, x, cross):
        final_predict = None
        i = 0

        ts_d = x.shape[1]
        for layer in self.decode_layers:
            cross_enc = cross[i]
            x, layer_predict = layer(x, cross_enc)
            if final_predict is None:
                final_predict = layer_predict
            else:
                final_predict = final_predict + layer_predict
            i += 1

        # [b, ts_d * seg_num, seg_len] -> [b, seg_num * seg_len, ts_d]
        final_predict = rearrange(final_predict, 'b (out_d seg_num) seg_len -> b (seg_num seg_len) out_d', out_d=ts_d)

        return final_predict
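
As a quick sanity check (not part of the committed file), a minimal sketch of SegMerging's shape behaviour; the sizes below are made up for illustration:

import torch
from layers.Crossformer_EncDec import SegMerging

merge = SegMerging(d_model=16, win_size=2)
x = torch.randn(2, 7, 11, 16)   # [batch, n_vars, seg_num, d_model]; 11 segments, not divisible by 2
out = merge(x)                  # the last segment is repeated once, then pairs of segments are merged
print(out.shape)                # torch.Size([2, 7, 6, 16])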