529 lines
21 KiB
Python
529 lines
21 KiB
Python
import torch
|
||
from torch import nn
|
||
import torch.nn.functional as F
|
||
from layers.SelfAttention_Family import FullAttention, AttentionLayer
|
||
|
||
# 需要 Mamba2 作为外层编码器
|
||
from mamba_ssm.modules.mamba2 import Mamba2
|
||
|
||
|
||
|
||
# -------------------- Routing(余弦路由,和论文一致) --------------------
|
||
class RoutingModule(nn.Module):
    """Cosine-similarity boundary router for dynamic chunking.

    For t >= 1 the boundary score is (1 - cos(q_{t-1}, k_t)) / 2, clamped to
    [0, 1], where q = q_proj(x) and k = k_proj(x). Position 0 is always scored
    1.0, i.e. it always starts a chunk.
    """

    def __init__(self, d_model):
        super().__init__()
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        # Both projections start as the identity, so the initial score is the
        # plain cosine dissimilarity of neighbouring hidden states.
        with torch.no_grad():
            for proj in (self.q_proj, self.k_proj):
                nn.init.eye_(proj.weight)
        for proj in (self.q_proj, self.k_proj):
            # NOTE(review): this flag is not read in this file; presumably it tells
            # an external weight-init routine to leave these weights alone.
            proj.weight._no_reinit = True

    def forward(self, x, mask=None):
        """Score every position as a chunk boundary.

        Args:
            x: (B, L, D) hidden states.
            mask: optional (B, L) bool, True = valid position.

        Returns:
            boundary_prob: (B, L, 2); last axis is [P(no boundary), P(boundary)].
            boundary_mask: (B, L) bool hard decision (argmax), restricted to ``mask``.
            selected_probs: (B, L, 1) probability of the chosen class.
        """
        _batch, _length, _dim = x.shape  # (B, L, D); also rejects non-3D input

        prev_q = F.normalize(self.q_proj(x[:, :-1]), dim=-1)  # (B, L-1, D)
        next_k = F.normalize(self.k_proj(x[:, 1:]), dim=-1)  # (B, L-1, D)
        cos_sim = torch.sum(prev_q * next_k, dim=-1)  # (B, L-1)
        dissim = ((1 - cos_sim) / 2).clamp(0.0, 1.0)

        # Prepend a score of 1.0 so the first position is always a boundary.
        p = torch.cat((dissim.new_ones(dissim.shape[0], 1), dissim), dim=1)  # (B, L)

        if mask is not None:
            p = p * mask.float()
            # Re-force position 0 to 1.0 wherever it is valid.
            first = torch.where(mask[:, 0], torch.ones_like(p[:, 0]), p[:, 0])
            p = torch.cat((first.unsqueeze(1), p[:, 1:]), dim=1)

        boundary_prob = torch.stack((1 - p, p), dim=-1)  # (B, L, 2)
        choice = boundary_prob.argmax(dim=-1)
        boundary_mask = choice == 1
        if mask is not None:
            boundary_mask = boundary_mask & mask

        selected_probs = torch.gather(boundary_prob, -1, choice.unsqueeze(-1))  # (B, L, 1)
        return boundary_prob, boundary_mask, selected_probs
|
||
|
||
|
||
# -------------------- 选择并右侧零pad(不丢弃、不重复填充) --------------------
|
||
def select_and_right_pad(x, boundary_mask):
    """Compact each row's boundary tokens to the left and zero-pad on the right.

    Selected tokens keep their original order; nothing is dropped or duplicated.
    A row without any boundary falls back to its first token, so every row has
    at least one valid slot.

    Fully vectorised: there is no Python loop over rows, so the number of
    host<->device synchronisations no longer grows with the number of rows.

    Args:
        x: (B, L, D) tensor.
        boundary_mask: (B, L) bool, True = keep the token.

    Returns:
        x_pad: (B, T_max, D); selected tokens first, zeros after them.
        key_padding_mask: (B, T_max) bool, True = valid slot.
        lengths: (B,) number of True entries in each row of ``boundary_mask``.
            NOTE: a row with no boundary reports 0 here even though it receives
            one fallback slot in ``key_padding_mask``.
    """
    B, L, D = x.shape
    lengths = boundary_mask.sum(dim=1)  # (B,)

    # Tokens actually kept: the boundaries, plus token 0 for rows that have none.
    keep = boundary_mask.to(dtype=torch.bool, copy=True)
    keep[:, 0] = keep[:, 0] | (lengths == 0)
    kept = keep.sum(dim=1)  # (B,), always >= 1
    # Equals max(lengths.max(), 1); guarded so an empty batch does not crash on max().
    T_max = int(kept.max().item()) if B > 0 else 1

    # Output slot of each kept token = number of kept tokens before it in its row.
    slot = keep.long().cumsum(dim=1) - 1  # (B, L)
    rows, cols = keep.nonzero(as_tuple=True)

    x_pad = x.new_zeros(B, T_max, D)
    # Slots are unique per row, so this plain index_put is a gather-with-compaction
    # and gradients flow back to exactly the kept positions of x.
    x_pad[rows, slot[rows, cols]] = x[rows, cols]
    key_padding_mask = torch.arange(T_max, device=x.device).unsqueeze(0) < kept.unsqueeze(1)
    return x_pad, key_padding_mask, lengths
|
||
|
||
|
||
# -------------------- Mamba2 堆叠(外层编码器) --------------------
|
||
class Mamba2Encoder(nn.Module):
    """Stack of ``depth`` Mamba2 mixers followed by a single LayerNorm and dropout.

    NOTE(review): the Mamba2 layers are chained directly (output of one is the
    input of the next) with no residual connection or per-layer norm in this
    module; confirm this is intended for depth > 1 (unless Mamba2 adds one internally).
    """

    def __init__(self, d_model, depth=4, dropout=0.0):
        super().__init__()
        self.layers = nn.ModuleList(Mamba2(d_model=d_model) for _ in range(depth))
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Run every Mamba2 layer in order, then the final norm and dropout."""
        hidden = x
        for mixer in self.layers:
            hidden = mixer(hidden)
        return self.dropout(self.norm(hidden))
|
||
|
||
|
||
# -------------------- 两层Encoder + DC(变长,不丢信息;以比率约束压缩) --------------------
|
||
class DCEmbedding2StageVarLen(nn.Module):
    """Dynamic-chunking (DC) embedding with variable-length output.

    - Stage 0: (B*nvars, L, 1) -> Linear(D0) -> Mamba2(D0) -> Routing -> select -> widen to D1
    - Stage 1 (only when ``use_stage1=True``): (B*nvars, L0_sel, D1) -> Mamba2(D1) -> Routing -> select

    Outputs:
        enc_out: (B*nvars, T_max, D1)
        key_padding_mask: (B*nvars, T_max) bool, True = valid token
        n_vars: int
        aux: dict with per-stage boundary info and the ratio loss of each stage run
    """

    def __init__(self, d_model_out, d_model_stage0, depth_enc0=4, depth_enc1=4, dropout=0.0,
                 target_ratio0=0.25, target_ratio1=0.5, use_stage1=False):
        super().__init__()
        assert d_model_out >= d_model_stage0, "要求 D0 <= D1"
        self.d0 = d_model_stage0
        self.d1 = d_model_out
        self.use_stage1 = use_stage1

        # scalar -> D0
        self.input_proj = nn.Linear(1, self.d0)

        # Stage 0
        self.enc0 = Mamba2Encoder(self.d0, depth=depth_enc0, dropout=dropout)
        self.router0 = RoutingModule(self.d0)
        delta = self.d1 - self.d0
        # Learnable filler appended to every stage-0 token to widen D0 -> D1.
        self.pad_vec = nn.Parameter(torch.zeros(delta)) if delta > 0 else None
        self.target_ratio0 = target_ratio0

        # Stage 1: always constructed so state_dict keys stay stable; only run when
        # use_stage1=True.
        # NOTE(review): with use_stage1=False these parameters get no gradient, which
        # DistributedDataParallel rejects unless find_unused_parameters=True.
        self.enc1 = Mamba2Encoder(self.d1, depth=depth_enc1, dropout=dropout)
        self.router1 = RoutingModule(self.d1)
        self.target_ratio1 = target_ratio1

    def _expand_width(self, x):
        """Append ``pad_vec`` to every token: (B, L, D0) -> (B, L, D1); identity if D0 == D1."""
        if self.pad_vec is None:
            return x
        B, L, _ = x.shape
        return torch.cat([x, self.pad_vec.view(1, 1, -1).expand(B, L, -1)], dim=-1)

    @staticmethod
    def _ratio_loss(boundary_mask: torch.Tensor, boundary_prob: torch.Tensor, target_ratio: float,
                    valid_mask: torch.Tensor = None) -> torch.Tensor:
        """Compression-ratio loss N/(N-1) * ((N-1)*F*G + (1-F)*(1-G)), with N = 1/target_ratio.

        F is the fraction of positions selected as boundaries (hard), G the mean
        boundary probability (soft), both averaged per row then over rows.
        If ``valid_mask`` (bool, True = real token) is given, F and G are averaged
        over valid positions only, so right-padding does not bias the loss.
        """
        eps = 1e-6
        if valid_mask is None:
            F_act = boundary_mask.float().mean(dim=1)  # (B,)
            G_prob = boundary_prob[..., 1].mean(dim=1)  # (B,)
        else:
            w = valid_mask.to(boundary_prob.dtype)
            denom = w.sum(dim=1).clamp_min(1.0)
            F_act = (boundary_mask.to(w.dtype) * w).sum(dim=1) / denom
            G_prob = (boundary_prob[..., 1] * w).sum(dim=1) / denom
        N = 1.0 / max(target_ratio, eps)
        loss = N / (N - 1.0 + eps) * (((N - 1.0) * F_act * G_prob) + (1.0 - F_act) * (1.0 - G_prob))
        return loss.mean()

    def forward(self, x):
        """
        x: (B, nvars, L)

        Returns (enc_out, key_padding_mask, nvars, aux); enc_out has width D1.
        """
        B, nvars, L = x.shape
        x = x.reshape(B * nvars, L, 1)
        x = self.input_proj(x)  # (B*nvars, L, D0)

        # ---- Stage 0 ----
        h0 = self.enc0(x)  # (B*nvars, L, D0)
        p0, bm0, _ = self.router0(h0)
        h0_sel, mask0, len0 = select_and_right_pad(h0, bm0)  # (B*nvars, L0_max, D0)
        del h0  # free the full-length activations early

        # Widen D0 -> D1: the documented output width, and the width the downstream
        # attention encoder / heads are built with.
        h0_sel = self._expand_width(h0_sel)  # (B*nvars, L0_max, D1)

        # No detach here on purpose: the ratio loss must backpropagate into router0.
        ratio_loss0 = self._ratio_loss(bm0, p0, target_ratio=self.target_ratio0)
        aux = {
            "stage0": {"boundary_mask": bm0.detach(), "boundary_prob": p0.detach(), "lengths": len0.detach()},
            "ratio_loss0": ratio_loss0,
        }
        if not self.use_stage1:
            return h0_sel, mask0, nvars, aux

        # ---- Stage 1 (opt-in) ----
        h1 = self.enc1(h0_sel)  # (B*nvars, L0_max, D1)
        # Passing mask0 zeroes padded scores and restricts boundaries to real tokens.
        p1, bm1, _ = self.router1(h1, mask=mask0)
        h1_sel, mask1, len1 = select_and_right_pad(h1, bm1)  # (B*nvars, L1_max, D1)
        del h1, h0_sel

        ratio_loss1 = self._ratio_loss(bm1, p1, target_ratio=self.target_ratio1, valid_mask=mask0)
        aux["stage1"] = {"boundary_mask": bm1.detach(), "boundary_prob": p1.detach(), "lengths": len1.detach()}
        aux["ratio_loss1"] = ratio_loss1
        return h1_sel, mask1, nvars, aux
|
||
|
||
|
||
# -------------------- Encoder/EncoderLayer(带 key_padding_mask 透传) --------------------
|
||
class EncoderLayerWithMask(nn.Module):
    """Post-norm Transformer encoder layer that forwards ``key_padding_mask`` to attention.

    Layout: attention -> dropout + residual -> LayerNorm -> MLP -> dropout +
    residual -> LayerNorm. The feed-forward network is a two-layer MLP.
    """

    def __init__(self, attention: AttentionLayer, d_model, d_ff, dropout=0.1, activation="gelu"):
        super().__init__()
        self.attention = attention
        self.dropout = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        if activation == "relu":
            act_cls = nn.ReLU
        elif activation == "gelu":
            act_cls = nn.GELU
        else:
            raise ValueError(f"Unsupported activation: {activation}")

        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            act_cls(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
        )

    def forward(self, x, attn_mask=None, tau=None, delta=None, key_padding_mask=None):
        """Return (output, attention); ``key_padding_mask`` is passed unchanged to the attention layer."""
        attended, attn = self.attention(
            x, x, x, attn_mask, tau=tau, delta=delta, key_padding_mask=key_padding_mask
        )
        hidden = self.norm1(x + self.dropout(attended))
        hidden = self.norm2(hidden + self.dropout(self.ffn(hidden)))
        return hidden, attn
|
||
|
||
|
||
class EncoderWithMask(nn.Module):
    """Stack of mask-aware encoder layers with an optional final norm.

    ``key_padding_mask`` is forwarded to every layer. This file uses the
    convention True = valid token (as produced by ``select_and_right_pad``).
    The final norm sees only the valid tokens, so padding cannot skew
    batch-level statistics (e.g. a BatchNorm ``norm_layer``). Padded positions
    are returned as they left the last layer, un-normalised.
    """

    def __init__(self, attn_layers, norm_layer=None):
        super().__init__()
        self.attn_layers = nn.ModuleList(attn_layers)
        self.norm = norm_layer

    def forward(self, x, attn_mask=None, key_padding_mask=None):
        """Run all layers, then the optional norm; return (x, list of per-layer attentions)."""
        attns = []
        for attn_layer in self.attn_layers:
            x, attn = attn_layer(x, attn_mask=attn_mask, key_padding_mask=key_padding_mask)
            attns.append(attn)
        if self.norm is not None:
            x = self._apply_norm(x, key_padding_mask)
        return x, attns

    def _apply_norm(self, x, key_padding_mask):
        """Apply ``self.norm`` to the valid (True) tokens of x (B, T, D); all tokens if no mask."""
        if key_padding_mask is None:
            return self.norm(x)
        valid = key_padding_mask.bool()
        if bool(valid.all()):
            # No padding: identical to normalising the whole tensor.
            return self.norm(x)
        # Normalise the valid tokens as one (1, N_valid, D) sequence, then scatter them back.
        normed = self.norm(x[valid].unsqueeze(0)).squeeze(0)  # (N_valid, D)
        # Match the norm's output dtype (e.g. fp32 under autocast) before scattering.
        return x.to(normed.dtype).masked_scatter(valid.unsqueeze(-1), normed)
|
||
|
||
|
||
# -------------------- 门控注意力聚合 + 任务头(不依赖token数;保留信息) --------------------
|
||
def masked_softmax(logits: torch.Tensor, mask: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """Softmax over ``dim`` that ignores positions where ``mask`` is False.

    logits: (..., T)
    mask: bool, broadcastable to ``logits``, True = valid.

    Masked positions are filled with the dtype's most negative finite value
    (not -inf), so a fully masked slice yields a uniform distribution rather
    than NaN.
    """
    fill_value = torch.finfo(logits.dtype).min
    blocked = logits.masked_fill(~mask, fill_value)
    return torch.softmax(blocked, dim=dim)
|
||
|
||
class GatedAttnAggregator(nn.Module):
    """Pool a variable-length token sequence into R learned slots.

    Each of the R learnable queries attends over the tokens with a masked
    softmax. The values are first scaled by a per-token sigmoid gate.

    Input:  x (B*, T, D), mask (B*, T) bool, True = valid token.
    Output: slots (B*, R, D), with R = ``num_slots``.
    """

    def __init__(self, d_model: int, num_slots: int = 4, d_att: int = None, dropout: float = 0.1):
        super().__init__()
        self.d_model = d_model
        self.R = num_slots
        self.d_att = d_att or d_model

        # R learnable queries, scaled by 1/sqrt(d_att) at init.
        scale = self.d_att ** 0.5
        self.query = nn.Parameter(torch.randn(self.R, self.d_att) / scale)

        self.key_proj = nn.Linear(d_model, self.d_att)
        self.val_proj = nn.Linear(d_model, d_model)

        # Per-token scalar gate in (0, 1) that multiplies the values.
        hidden = d_model // 2
        self.gate = nn.Sequential(
            nn.Linear(d_model, hidden),
            nn.GELU(),
            nn.Linear(hidden, 1),
            nn.Sigmoid(),
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, x_bt_t_d: torch.Tensor, mask_bt: torch.Tensor) -> torch.Tensor:
        """Return slots (B*, R, D) pooled from x_bt_t_d (B*, T, D) under mask_bt (B*, T)."""
        n_rows, _, _ = x_bt_t_d.shape

        keys = self.key_proj(x_bt_t_d)  # (B*, T, d_att)
        gated_vals = self.val_proj(x_bt_t_d) * self.gate(x_bt_t_d)  # (B*, T, D)

        queries = self.query.unsqueeze(0).expand(n_rows, -1, -1)  # (B*, R, d_att)
        scores = (queries @ keys.transpose(1, 2)) / (self.d_att ** 0.5)  # (B*, R, T)

        # The mask is broadcast over the R query rows.
        weights = self.dropout(masked_softmax(scores, mask_bt.unsqueeze(1), dim=-1))
        return weights @ gated_vals  # (B*, R, D)
|
||
|
||
class AttnPoolHeadForecast(nn.Module):
    """Forecasting head: gated attention pooling into R slots, then a linear map to pred_len.

    Output: (B, pred_len, nvars).
    """

    def __init__(self, d_model: int, target_window: int, num_slots: int = 4, dropout: float = 0.1):
        super().__init__()
        self.agg = GatedAttnAggregator(d_model, num_slots=num_slots, dropout=dropout)
        flat_dim = num_slots * d_model
        self.proj = nn.Sequential(
            nn.LayerNorm(flat_dim),
            nn.Linear(flat_dim, target_window),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, enc_out_bt_t_d: torch.Tensor, key_padding_mask_bt: torch.Tensor, n_vars: int, B: int):
        """Map encoder tokens (B*nvars, T, D) to a forecast (B, pred_len, nvars).

        Assumes the B*nvars rows are ordered batch-major (row = b * n_vars + v).
        """
        pooled = self.agg(enc_out_bt_t_d, key_padding_mask_bt)  # (B*nvars, R, D)
        per_var = pooled.reshape(B, n_vars, -1)  # (B, nvars, R*D)
        forecast = self.proj(self.dropout(per_var))  # (B, nvars, pred_len)
        return forecast.transpose(1, 2)  # (B, pred_len, nvars)
|
||
|
||
class AttnPoolHeadSeq(nn.Module):
    """Sequence-reconstruction head: gated attention pooling, then a linear map to seq_len.

    Output: (B, seq_len, nvars).
    """

    def __init__(self, d_model: int, target_window: int, num_slots: int = 4, dropout: float = 0.1):
        super().__init__()
        self.agg = GatedAttnAggregator(d_model, num_slots=num_slots, dropout=dropout)
        flat_dim = num_slots * d_model
        self.proj = nn.Sequential(
            nn.LayerNorm(flat_dim),
            nn.Linear(flat_dim, target_window),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, enc_out_bt_t_d: torch.Tensor, key_padding_mask_bt: torch.Tensor, n_vars: int, B: int):
        """Map encoder tokens (B*nvars, T, D) to a reconstruction (B, seq_len, nvars).

        Assumes the B*nvars rows are ordered batch-major (row = b * n_vars + v).
        """
        pooled = self.agg(enc_out_bt_t_d, key_padding_mask_bt)  # (B*nvars, R, D)
        per_var = pooled.reshape(B, n_vars, -1)  # (B, nvars, R*D)
        recon = self.proj(self.dropout(per_var))  # (B, nvars, seq_len)
        return recon.transpose(1, 2)  # (B, seq_len, nvars)
|
||
|
||
class AttnPoolHeadCls(nn.Module):
    """Classification head.

    Each variable is pooled into R slots with gated attention. The slots of all
    variables are concatenated and classified by a linear layer.

    Output: (B, num_class) logits.
    """

    def __init__(self, d_model: int, n_vars: int, num_class: int, num_slots: int = 4, dropout: float = 0.1):
        super().__init__()
        self.agg = GatedAttnAggregator(d_model, num_slots=num_slots, dropout=dropout)
        self.dropout = nn.Dropout(dropout)
        feat_dim = n_vars * num_slots * d_model
        self.proj = nn.Sequential(
            nn.LayerNorm(feat_dim),
            nn.Linear(feat_dim, num_class),
        )
        self.n_vars = n_vars
        self.num_slots = num_slots
        self.d_model = d_model

    def forward(self, enc_out_bt_t_d: torch.Tensor, key_padding_mask_bt: torch.Tensor, n_vars: int, B: int):
        """Map encoder tokens (B*nvars, T, D) to class logits (B, num_class).

        ``n_vars`` must equal the value given at construction, or the final
        linear layer's input width will not match.
        """
        pooled = self.agg(enc_out_bt_t_d, key_padding_mask_bt)  # (B*nvars, R, D)
        per_var = pooled.reshape(B, n_vars, self.num_slots * self.d_model)  # (B, nvars, R*D)
        features = self.dropout(per_var.reshape(B, -1))  # (B, nvars*R*D)
        return self.proj(features)
|
||
|
||
|
||
# -------------------- 主模型:两层DC(比率控制) + 带mask的Encoder + 门控聚合头 --------------------
|
||
class Transpose(nn.Module):
    """Module wrapper around ``Tensor.transpose``, optionally followed by ``.contiguous()``."""

    def __init__(self, *dims, contiguous=False):
        super().__init__()
        self.dims = dims
        self.contiguous = contiguous

    def forward(self, x):
        swapped = x.transpose(*self.dims)
        if self.contiguous:
            swapped = swapped.contiguous()
        return swapped
|
||
|
||
class Model(nn.Module):
    """PatchTST variant with dynamic chunking (DC), masked attention and gated pooling heads.

    - A Mamba2-based dynamic-chunking embedding replaces PatchEmbedding.
    - DC compression strength is steered by ratio losses (target_ratio0/1); deeper
      stages work on shorter sequences with a wider model dimension (D0 -> D1).
    - ``key_padding_mask`` is passed to the attention encoder so padding is masked.
    - Heads pool with gated attention, so they do not depend on the token count.

    Every task method returns (output, aux); aux holds the DC ratio losses.
    """

    def __init__(
        self, configs,
        d_model_stage0=None,  # D0; default = max(16, d_model // 2), capped at d_model
        depth_enc0=1, depth_enc1=1,
        target_ratio0=0.25,  # roughly 1 / N0
        target_ratio1=0.5,  # roughly 1 / N1
        agg_slots=4,  # number of slots in the gated attention pooling heads
    ):
        super().__init__()
        self.task_name = configs.task_name
        self.seq_len = configs.seq_len
        self.pred_len = configs.pred_len
        self.enc_in = configs.enc_in

        # ---- DC embedding ----
        D1 = configs.d_model
        if d_model_stage0 is not None:
            D0 = d_model_stage0
        else:
            # Cap the default at D1: for d_model < 16 the uncapped default (16)
            # would violate D0 <= D1 and abort construction.
            D0 = min(D1, max(16, D1 // 2))
        assert D1 >= D0, "要求 D0 <= D1"
        self.dc_embedding = DCEmbedding2StageVarLen(
            d_model_out=D1,
            d_model_stage0=D0,
            depth_enc0=depth_enc0,
            depth_enc1=depth_enc1,
            dropout=configs.dropout,
            target_ratio0=target_ratio0,
            target_ratio1=target_ratio1,
        )

        # ---- Transformer encoder that forwards key_padding_mask ----
        attn_layers = [
            EncoderLayerWithMask(
                AttentionLayer(
                    FullAttention(False, configs.factor, attention_dropout=configs.dropout, output_attention=False),
                    D1, configs.n_heads,
                ),
                d_model=D1,
                d_ff=configs.d_ff,
                dropout=configs.dropout,
                activation=configs.activation,
            )
            for _ in range(configs.e_layers)
        ]
        self.encoder = EncoderWithMask(
            attn_layers,
            norm_layer=nn.Sequential(Transpose(1, 2), nn.BatchNorm1d(D1), Transpose(1, 2)),
        )

        # ---- Gated attention pooling heads (independent of the token count) ----
        if self.task_name in ('long_term_forecast', 'short_term_forecast'):
            self.head = AttnPoolHeadForecast(D1, self.pred_len, num_slots=agg_slots, dropout=configs.dropout)
        elif self.task_name in ('imputation', 'anomaly_detection'):
            self.head = AttnPoolHeadSeq(D1, self.seq_len, num_slots=agg_slots, dropout=configs.dropout)
        elif self.task_name == 'classification':
            self.head_cls = AttnPoolHeadCls(D1, n_vars=self.enc_in, num_class=configs.num_class,
                                            num_slots=agg_slots, dropout=configs.dropout)

    # --------- normalisation / de-normalisation ---------
    def _pre_norm(self, x):
        """Standardise each channel over the time axis (dim 1); returns (x_norm, means, stdev).

        ``means`` is detached; ``stdev`` is not.
        """
        means = x.mean(1, keepdim=True).detach()
        x = x - means
        stdev = torch.sqrt(torch.var(x, dim=1, keepdim=True, unbiased=False) + 1e-5)
        x = x / stdev
        return x, means, stdev

    def _denorm(self, y, means, stdev, length):
        """Undo ``_pre_norm`` on y (B, length, C) using the per-channel statistics."""
        return y * (stdev[:, 0, :].unsqueeze(1).repeat(1, length, 1)) + \
            (means[:, 0, :].unsqueeze(1).repeat(1, length, 1))

    # --------- DC embedding + Transformer encoder (carrying key_padding_mask) ---------
    def _embed_and_encode(self, x_enc):
        """
        x_enc: (B, L, C)

        Returns:
            enc_out: (B*nvars, T_max, D1)
            n_vars: int
            key_padding_mask: (B*nvars, T_max) bool, True = valid token
            B: int batch size
            aux: dict from the DC embedding
        """
        B, L, C = x_enc.shape
        x_vars = x_enc.permute(0, 2, 1)  # (B, nvars, L)
        enc_out, key_padding_mask, n_vars, aux = self.dc_embedding(x_vars)
        enc_out, _ = self.encoder(enc_out, attn_mask=None, key_padding_mask=key_padding_mask)
        return enc_out, n_vars, key_padding_mask, B, aux

    # --------- per-task forward passes ---------
    def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
        x_enc, means, stdev = self._pre_norm(x_enc)
        enc_out, n_vars, key_padding_mask, B, aux = self._embed_and_encode(x_enc)
        dec_out = self.head(enc_out, key_padding_mask, n_vars, B)  # (B, pred_len, nvars)
        dec_out = self._denorm(dec_out, means, stdev, self.pred_len)
        return dec_out, aux

    def imputation(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask):
        # Statistics over observed (mask == 1) points only. The count is clamped to
        # >= 1 so a fully-missing channel yields finite stats instead of 0/0 = NaN.
        n_obs = torch.sum(mask == 1, dim=1).clamp_min(1)
        means = torch.sum(x_enc, dim=1) / n_obs
        means = means.unsqueeze(1).detach()
        x = x_enc - means
        x = x.masked_fill(mask == 0, 0)
        stdev = torch.sqrt(torch.sum(x * x, dim=1) / n_obs + 1e-5)
        stdev = stdev.unsqueeze(1).detach()
        x = x / stdev

        enc_out, n_vars, key_padding_mask, B, aux = self._embed_and_encode(x)
        dec_out = self.head(enc_out, key_padding_mask, n_vars, B)  # (B, seq_len, nvars)
        dec_out = self._denorm(dec_out, means, stdev, self.seq_len)
        return dec_out, aux

    def anomaly_detection(self, x_enc):
        x_enc, means, stdev = self._pre_norm(x_enc)
        enc_out, n_vars, key_padding_mask, B, aux = self._embed_and_encode(x_enc)
        dec_out = self.head(enc_out, key_padding_mask, n_vars, B)  # (B, seq_len, nvars)
        dec_out = self._denorm(dec_out, means, stdev, self.seq_len)
        return dec_out, aux

    def classification(self, x_enc, x_mark_enc):
        x_enc, _, _ = self._pre_norm(x_enc)
        enc_out, n_vars, key_padding_mask, B, aux = self._embed_and_encode(x_enc)
        logits = self.head_cls(enc_out, key_padding_mask, n_vars, B)  # (B, num_class)
        return logits, aux

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
        """Dispatch on ``task_name``; returns (output, aux), or (None, None) for an unknown task."""
        if self.task_name in ('long_term_forecast', 'short_term_forecast'):
            dec_out, aux = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
            return dec_out[:, -self.pred_len:, :], aux  # (B, pred_len, C); aux carries the ratio losses
        if self.task_name == 'imputation':
            dec_out, aux = self.imputation(x_enc, x_mark_enc, x_dec, x_mark_dec, mask)
            return dec_out, aux
        if self.task_name == 'anomaly_detection':
            dec_out, aux = self.anomaly_detection(x_enc)
            return dec_out, aux
        if self.task_name == 'classification':
            logits, aux = self.classification(x_enc, x_mark_enc)
            return logits, aux
        return None, None
|