# TSlib/models/DC_PatchTST.py

import torch
from torch import nn
import torch.nn.functional as F
from layers.SelfAttention_Family import FullAttention, AttentionLayer
# Mamba2 is required as the outer encoder
from mamba_ssm.modules.mamba2 import Mamba2
# -------------------- Routing (cosine routing, consistent with the paper) --------------------
class RoutingModule(nn.Module):
def __init__(self, d_model):
super().__init__()
self.q_proj = nn.Linear(d_model, d_model, bias=False)
self.k_proj = nn.Linear(d_model, d_model, bias=False)
with torch.no_grad():
nn.init.eye_(self.q_proj.weight)
nn.init.eye_(self.k_proj.weight)
self.q_proj.weight._no_reinit = True
self.k_proj.weight._no_reinit = True
def forward(self, x, mask=None):
"""
x: (B, L, D)
mask: (B, L) bool, True=有效
返回:
boundary_prob: (B, L, 2)
boundary_mask: (B, L) bool
selected_probs: (B, L, 1)
"""
B, L, D = x.shape
q = F.normalize(self.q_proj(x[:, :-1]), dim=-1) # (B, L-1, D)
k = F.normalize(self.k_proj(x[:, 1:]), dim=-1) # (B, L-1, D)
cos_sim = (q * k).sum(dim=-1) # (B, L-1)
p = torch.clamp((1 - cos_sim) / 2, 0.0, 1.0) # (B, L-1)
p = F.pad(p, (1, 0), value=1.0) # force the first position to be a boundary
if mask is not None:
p = p * mask.float()
p[:, 0] = torch.where(mask[:, 0], torch.ones_like(p[:, 0]), p[:, 0])
boundary_prob = torch.stack([1 - p, p], dim=-1) # (B, L, 2)
selected_idx = boundary_prob.argmax(dim=-1)
boundary_mask = (selected_idx == 1)
if mask is not None:
boundary_mask = boundary_mask & mask
selected_probs = boundary_prob.gather(-1, selected_idx.unsqueeze(-1)) # (B, L, 1)
return boundary_prob, boundary_mask, selected_probs
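# Example (sketch): expected shapes on random input.
#   router = RoutingModule(d_model=16)
#   x = torch.randn(2, 10, 16)                                 # (B, L, D)
#   boundary_prob, boundary_mask, selected_probs = router(x)   # (2, 10, 2), (2, 10) bool, (2, 10, 1)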
# -------------------- Select and right-zero-pad (nothing dropped, no duplicated filler) --------------------
def select_and_right_pad(x, boundary_mask):
"""
内存优化版本减少临时tensor创建
x: (B, L, D), boundary_mask: (B, L) bool
返回:
x_pad: (B, T_max, D)
key_padding_mask: (B, T_max) bool, True=有效
lengths: (B,)
"""
B, L, D = x.shape
device = x.device
lengths = boundary_mask.sum(dim=1) # (B,)
T_max = int(lengths.max().item()) if lengths.max() > 0 else 1
x_pad = x.new_zeros(B, T_max, D)
key_padding_mask = torch.zeros(B, T_max, dtype=torch.bool, device=device)
# pre-create the default index tensor to avoid repeated allocation
default_idx = torch.tensor([0], device=device)
for b in range(B):
mask_b = boundary_mask[b]
if mask_b.any():
idx = mask_b.nonzero(as_tuple=True)[0] # cheaper form of nonzero
t = idx.numel()
x_pad[b, :t] = x[b, idx]
key_padding_mask[b, :t] = True
else:
# fall back to the pre-created index so every row keeps at least one token
x_pad[b, 0] = x[b, default_idx]
key_padding_mask[b, 0] = True
return x_pad, key_padding_mask, lengths
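# Example (sketch): selected tokens are packed to the left; the rest is zero padding.
#   x = torch.randn(2, 3, 4)
#   bm = torch.tensor([[True, False, True], [False, True, False]])
#   x_pad, kpm, lengths = select_and_right_pad(x, bm)
#   # x_pad: (2, 2, 4); kpm: [[True, True], [True, False]]; lengths: [2, 1]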
# -------------------- Stacked Mamba2 (outer encoder) --------------------
class Mamba2Encoder(nn.Module):
def __init__(self, d_model, depth=4, dropout=0.0):
super().__init__()
self.layers = nn.ModuleList([Mamba2(d_model=d_model) for _ in range(depth)])
self.norm = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
for layer in self.layers:
x = layer(x)
x = self.norm(x)
x = self.dropout(x)
return x
# -------------------- Two-stage encoder + DC (variable length, no information dropped, ratio-constrained compression) --------------------
class DCEmbedding2StageVarLen(nn.Module):
"""
- Stage 0: (B*nvars, L, 1) -> Linear(D0) -> Mamba2(D0) -> Routing -> 选择 -> 扩宽到 D1
- Stage 1: (B*nvars, L0_sel, D1) -> Mamba2(D1) -> Routing -> 选择
输出:
enc_out: (B*nvars, T_max, D1)
key_padding_mask: (B*nvars, T_max)
n_vars: int
aux: dict含两层ratio loss与边界信息
"""
def __init__(self, d_model_out, d_model_stage0, depth_enc0=4, depth_enc1=4, dropout=0.0,
target_ratio0=0.25, target_ratio1=0.5):
super().__init__()
assert d_model_out >= d_model_stage0, "requires D0 <= D1"
self.d0 = d_model_stage0
self.d1 = d_model_out
# scalar input -> D0
self.input_proj = nn.Linear(1, self.d0)
# Stage 0
self.enc0 = Mamba2Encoder(self.d0, depth=depth_enc0, dropout=dropout)
self.router0 = RoutingModule(self.d0)
delta = self.d1 - self.d0
self.pad_vec = nn.Parameter(torch.zeros(delta)) if delta > 0 else None
self.target_ratio0 = target_ratio0
# Stage 1
self.enc1 = Mamba2Encoder(self.d1, depth=depth_enc1, dropout=dropout)
self.router1 = RoutingModule(self.d1)
self.target_ratio1 = target_ratio1
def _expand_width(self, x):
if self.pad_vec is None:
return x
B, L, _ = x.shape
return torch.cat([x, self.pad_vec.view(1, 1, -1).expand(B, L, -1)], dim=-1)
@staticmethod
def _ratio_loss(boundary_mask: torch.Tensor, boundary_prob: torch.Tensor, target_ratio: float) -> torch.Tensor:
eps = 1e-6
F_act = boundary_mask.float().mean(dim=1) # (B,)
G_prob = boundary_prob[..., 1].mean(dim=1) # (B,)
N = 1.0 / max(target_ratio, eps)
loss = N / (N - 1.0 + eps) * (((N - 1.0) * F_act * G_prob) + (1.0 - F_act) * (1.0 - G_prob))
return loss.mean()
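# Note: with N = 1 / target_ratio, F = mean(boundary_mask) and G = mean(boundary_prob[..., 1]),
# the expression above equals N/(N-1) * ((N-1)*F*G + (1-F)*(1-G)); for F ~= G it is minimized
# at F = 1/N, i.e. when the selected-boundary fraction matches the target compression ratio.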
def forward(self, x):
"""
x: (B, nvars, L)
内存优化版本及时删除中间tensor
"""
B, nvars, L = x.shape
x = x.reshape(B * nvars, L, 1)
x = self.input_proj(x) # (B*nvars, L, D0)
# Stage 0
h0 = self.enc0(x) # (B*nvars, L, D0)
p0, bm0, _ = self.router0(h0)
h0_sel, mask0, len0 = select_and_right_pad(h0, bm0) # (B*nvars, L0_max, D0)
# free tensors that are no longer needed
del h0
h0_sel = self._expand_width(h0_sel) # (B*nvars, L0_max, D1): widen D0 -> D1 so the output matches the encoder's d_model
# Stage 1
#h1 = self.enc1(h0_sel) # (B*nvars, L0_max, D1)
#p1, bm1, _ = self.router1(h1)
#bm1 = bm1 & mask0
#h1_sel, mask1, len1 = select_and_right_pad(h1, bm1) # (B*nvars, L1_max, D1)
# free intermediate tensors early
#del h1, h0_sel
# the ratio loss keeps its gradient; only the detached copies stored in aux below are for logging
ratio_loss0 = self._ratio_loss(bm0, p0, target_ratio=self.target_ratio0)
# ratio_loss1 = self._ratio_loss(bm1, p1, target_ratio=self.target_ratio1)
# keep the aux dict small: store only what is needed
aux = {
"stage0": {"boundary_mask": bm0.detach(), "boundary_prob": p0.detach(), "lengths": len0.detach()},
# "stage1": {"boundary_mask": bm1.detach(), "boundary_prob": p1.detach(), "lengths": len1.detach()},
"ratio_loss0": ratio_loss0,
# "ratio_loss1": ratio_loss1,
}
return h0_sel, mask0, nvars, aux
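# Usage sketch (assumes mamba_ssm is installed and tensors are on a CUDA device, which Mamba2 needs):
#   emb = DCEmbedding2StageVarLen(d_model_out=128, d_model_stage0=64).cuda()
#   x = torch.randn(4, 7, 96, device="cuda")   # (B, nvars, L)
#   enc_out, kpm, n_vars, aux = emb(x)
#   # enc_out: (4*7, T_max, 128); kpm: (4*7, T_max); aux["ratio_loss0"]: scalar tensor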
# -------------------- Encoder/EncoderLayer (with key_padding_mask pass-through) --------------------
class EncoderLayerWithMask(nn.Module):
"""
与原EncoderLayer结构一致但 forward 增加 key_padding_mask并传入 AttentionLayer。
FFN 用简单的 MLP与常规Transformer一致
"""
def __init__(self, attention: AttentionLayer, d_model, d_ff, dropout=0.1, activation="gelu"):
super().__init__()
self.attention = attention
self.dropout = nn.Dropout(dropout)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
if activation == "relu":
act = nn.ReLU()
elif activation == "gelu":
act = nn.GELU()
else:
raise ValueError(f"Unsupported activation: {activation}")
self.ffn = nn.Sequential(
nn.Linear(d_model, d_ff),
act,
nn.Dropout(dropout),
nn.Linear(d_ff, d_model),
)
def forward(self, x, attn_mask=None, tau=None, delta=None, key_padding_mask=None):
# Multi-head attention with key padding mask
attn_out, attn = self.attention(
x, x, x, attn_mask, tau=tau, delta=delta, key_padding_mask=key_padding_mask
)
x = x + self.dropout(attn_out)
x = self.norm1(x)
# FFN
y = self.ffn(x)
x = x + self.dropout(y)
x = self.norm2(x)
return x, attn
class EncoderWithMask(nn.Module):
"""
与原Encoder类似但 forward 支持 key_padding_mask并传递给每一层的注意力。
"""
def __init__(self, attn_layers, norm_layer=None):
super().__init__()
self.attn_layers = nn.ModuleList(attn_layers)
self.norm = norm_layer
def forward(self, x, attn_mask=None, key_padding_mask=None):
attns = []
for attn_layer in self.attn_layers:
x, attn = attn_layer(x, attn_mask=attn_mask, key_padding_mask=key_padding_mask)
attns.append(attn)
if self.norm is not None:
x = self.norm(x)
return x, attns
# -------------------- Gated attention aggregation + task heads (token-count independent, information preserving) --------------------
def masked_softmax(logits: torch.Tensor, mask: torch.Tensor, dim: int = -1) -> torch.Tensor:
"""
logits: (..., T)
mask: (..., T) bool, True = valid
"""
neg_inf = torch.finfo(logits.dtype).min
logits = logits.masked_fill(~mask, neg_inf)
return torch.softmax(logits, dim=dim)
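# Example (sketch): masked-out positions receive zero weight.
#   masked_softmax(torch.zeros(1, 4), torch.tensor([[True, True, False, False]]))
#   # -> approximately [[0.5, 0.5, 0.0, 0.0]]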
class GatedAttnAggregator(nn.Module):
"""
门控注意力聚合器(可学习查询 + mask softmax + 值端门控)
输入: x: (B*, T, D), mask: (B*, T) bool
输出: slots: (B*, R, D) 其中 R 为聚合插槽数(可配置)
"""
def __init__(self, d_model: int, num_slots: int = 4, d_att: int = None, dropout: float = 0.1):
super().__init__()
self.d_model = d_model
self.R = num_slots
self.d_att = d_att or d_model
# R learnable queries
self.query = nn.Parameter(torch.randn(self.R, self.d_att) / (self.d_att ** 0.5))
# linear projections
self.key_proj = nn.Linear(d_model, self.d_att)
self.val_proj = nn.Linear(d_model, d_model)
# value-side gating (per-token scalar gate)
self.gate = nn.Sequential(
nn.Linear(d_model, d_model // 2),
nn.GELU(),
nn.Linear(d_model // 2, 1),
nn.Sigmoid()
)
self.dropout = nn.Dropout(dropout)
def forward(self, x_bt_t_d: torch.Tensor, mask_bt: torch.Tensor) -> torch.Tensor:
"""
x_bt_t_d: (B*, T, D)
mask_bt: (B*, T) bool
return: slots (B*, R, D)
"""
BStar, T, D = x_bt_t_d.shape
K = self.key_proj(x_bt_t_d) # (B*, T, d_att)
V = self.val_proj(x_bt_t_d) # (B*, T, D)
g = self.gate(x_bt_t_d) # (B*, T, 1)
Vg = V * g # gated values
Q = self.query.unsqueeze(0).expand(BStar, -1, -1) # (B*, R, d_att)
logits = torch.matmul(Q, K.transpose(1, 2)) / (self.d_att ** 0.5) # (B*, R, T)
attn_mask = mask_bt.unsqueeze(1) # (B*, 1, T)
attn = masked_softmax(logits, attn_mask, dim=-1)
attn = self.dropout(attn)
slots = torch.matmul(attn, Vg) # (B*, R, D)
return slots
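# Example (sketch): pooling a padded batch into a fixed number of slots.
#   agg = GatedAttnAggregator(d_model=32, num_slots=4)
#   x = torch.randn(8, 20, 32)                   # (B*, T, D)
#   mask = torch.ones(8, 20, dtype=torch.bool)   # True = valid token
#   slots = agg(x, mask)                         # (8, 4, 32)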
class AttnPoolHeadForecast(nn.Module):
"""
预测任务头:门控注意力聚合到 R 个slots再映射到 target_windowpred_len
输出:(B, pred_len, nvars)
"""
def __init__(self, d_model: int, target_window: int, num_slots: int = 4, dropout: float = 0.1):
super().__init__()
self.agg = GatedAttnAggregator(d_model, num_slots=num_slots, dropout=dropout)
self.proj = nn.Sequential(
nn.LayerNorm(num_slots * d_model),
nn.Linear(num_slots * d_model, target_window),
)
self.dropout = nn.Dropout(dropout)
def forward(self, enc_out_bt_t_d: torch.Tensor, key_padding_mask_bt: torch.Tensor, n_vars: int, B: int):
slots = self.agg(enc_out_bt_t_d, key_padding_mask_bt) # (B*, R, D)
slots = slots.reshape(B, n_vars, -1) # (B, nvars, R*D)
out = self.proj(self.dropout(slots)) # (B, nvars, pred_len)
return out.permute(0, 2, 1) # (B, pred_len, nvars)
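# Example (sketch): with B=2, n_vars=7, D=32 and pred_len=96,
#   head = AttnPoolHeadForecast(d_model=32, target_window=96, num_slots=4)
#   enc = torch.randn(2 * 7, 15, 32)
#   kpm = torch.ones(2 * 7, 15, dtype=torch.bool)
#   head(enc, kpm, n_vars=7, B=2).shape          # torch.Size([2, 96, 7])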
class AttnPoolHeadSeq(nn.Module):
"""
序列重建头:门控注意力聚合后映射到 seq_len
输出:(B, seq_len, nvars)
"""
def __init__(self, d_model: int, target_window: int, num_slots: int = 4, dropout: float = 0.1):
super().__init__()
self.agg = GatedAttnAggregator(d_model, num_slots=num_slots, dropout=dropout)
self.proj = nn.Sequential(
nn.LayerNorm(num_slots * d_model),
nn.Linear(num_slots * d_model, target_window),
)
self.dropout = nn.Dropout(dropout)
def forward(self, enc_out_bt_t_d: torch.Tensor, key_padding_mask_bt: torch.Tensor, n_vars: int, B: int):
slots = self.agg(enc_out_bt_t_d, key_padding_mask_bt) # (B*, R, D)
slots = slots.reshape(B, n_vars, -1) # (B, nvars, R*D)
out = self.proj(self.dropout(slots)) # (B, nvars, seq_len)
return out.permute(0, 2, 1) # (B, seq_len, nvars)
class AttnPoolHeadCls(nn.Module):
"""
分类头:每变量先门控注意力聚合到 R 个slots拼接所有变量后线性分类。
输出:(B, num_class)
"""
def __init__(self, d_model: int, n_vars: int, num_class: int, num_slots: int = 4, dropout: float = 0.1):
super().__init__()
self.agg = GatedAttnAggregator(d_model, num_slots=num_slots, dropout=dropout)
self.dropout = nn.Dropout(dropout)
self.proj = nn.Sequential(
nn.LayerNorm(n_vars * num_slots * d_model),
nn.Linear(n_vars * num_slots * d_model, num_class),
)
self.n_vars = n_vars
self.num_slots = num_slots
self.d_model = d_model
def forward(self, enc_out_bt_t_d: torch.Tensor, key_padding_mask_bt: torch.Tensor, n_vars: int, B: int):
slots = self.agg(enc_out_bt_t_d, key_padding_mask_bt) # (B*, R, D)
slots = slots.reshape(B, n_vars, self.num_slots * self.d_model) # (B, nvars, R*D)
flat = self.dropout(slots.reshape(B, -1)) # (B, nvars*R*D)
return self.proj(flat)
# -------------------- Main model: two-stage DC (ratio-controlled) + masked Encoder + gated aggregation heads --------------------
class Transpose(nn.Module):
def __init__(self, *dims, contiguous=False):
super().__init__()
self.dims, self.contiguous = dims, contiguous
def forward(self, x):
return x.transpose(*self.dims).contiguous() if self.contiguous else x.transpose(*self.dims)
class Model(nn.Module):
"""
PatchTST with DC and masked attention + gated heads:
- 用两层 Mamba2 编码器 + 动态分块 替代 PatchEmbedding
- DC 使用 ratio losstarget_ratio0/1控制压缩强度随层级加深序列变短d_model 变大D0->D1
- 注意力传入 key_padding_mask 屏蔽pad
- 头部使用门控注意力聚合不依赖token数信息保留更充分
"""
def __init__(
self, configs,
d_model_stage0=None, # D0; defaults to d_model // 2
depth_enc0=1, depth_enc1=1,
target_ratio0=0.25, # roughly 1/N0
target_ratio1=0.5, # roughly 1/N1
agg_slots=4, # number of gated-aggregation slots
):
super().__init__()
self.task_name = configs.task_name
self.seq_len = configs.seq_len
self.pred_len = configs.pred_len
self.enc_in = configs.enc_in
# DC embedding
D1 = configs.d_model
D0 = d_model_stage0 if d_model_stage0 is not None else max(16, D1 // 2)
assert D1 >= D0, "requires D0 <= D1"
self.dc_embedding = DCEmbedding2StageVarLen(
d_model_out=D1,
d_model_stage0=D0,
depth_enc0=depth_enc0,
depth_enc1=depth_enc1,
dropout=configs.dropout,
target_ratio0=target_ratio0,
target_ratio1=target_ratio1,
)
# Encoder with key_padding_mask support
attn_layers = [
EncoderLayerWithMask(
AttentionLayer(
FullAttention(False, configs.factor, attention_dropout=configs.dropout, output_attention=False),
D1, configs.n_heads
),
d_model=D1,
d_ff=configs.d_ff,
dropout=configs.dropout,
activation=configs.activation
) for _ in range(configs.e_layers)
]
self.encoder = EncoderWithMask(
attn_layers,
norm_layer=nn.Sequential(Transpose(1, 2), nn.BatchNorm1d(D1), Transpose(1, 2))
)
# gated aggregation heads (independent of the token count)
if self.task_name in ('long_term_forecast', 'short_term_forecast'):
self.head = AttnPoolHeadForecast(D1, self.pred_len, num_slots=agg_slots, dropout=configs.dropout)
elif self.task_name in ('imputation', 'anomaly_detection'):
self.head = AttnPoolHeadSeq(D1, self.seq_len, num_slots=agg_slots, dropout=configs.dropout)
elif self.task_name == 'classification':
self.head_cls = AttnPoolHeadCls(D1, n_vars=self.enc_in, num_class=configs.num_class, num_slots=agg_slots, dropout=configs.dropout)
# --------- Normalization / de-normalization ---------
def _pre_norm(self, x):
means = x.mean(1, keepdim=True).detach()
x = x - means
stdev = torch.sqrt(torch.var(x, dim=1, keepdim=True, unbiased=False) + 1e-5)
x = x / stdev
return x, means, stdev
def _denorm(self, y, means, stdev, length):
return y * (stdev[:, 0, :].unsqueeze(1).repeat(1, length, 1)) + \
(means[:, 0, :].unsqueeze(1).repeat(1, length, 1))
# --------- DC + Transformer Encoder (carrying key_padding_mask) ----------
def _embed_and_encode(self, x_enc):
"""
x_enc: (B, L, C)
返回:
enc_out: (B*nvars, T_max, D1)
n_vars: int
key_padding_mask: (B*nvars, T_max)
aux: dict
"""
B, L, C = x_enc.shape
x_vars = x_enc.permute(0, 2, 1) # (B, nvars, L)
enc_out, key_padding_mask, n_vars, aux = self.dc_embedding(x_vars)
enc_out, _ = self.encoder(enc_out, attn_mask=None, key_padding_mask=key_padding_mask)
return enc_out, n_vars, key_padding_mask, B, aux
# --------- Per-task forward passes ---------
def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
x_enc, means, stdev = self._pre_norm(x_enc)
enc_out, n_vars, key_padding_mask, B, aux = self._embed_and_encode(x_enc)
dec_out = self.head(enc_out, key_padding_mask, n_vars, B) # (B, pred_len, nvars)
dec_out = self._denorm(dec_out, means, stdev, self.pred_len)
return dec_out, aux
def imputation(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask):
means = torch.sum(x_enc, dim=1) / torch.sum(mask == 1, dim=1)
means = means.unsqueeze(1).detach()
x = x_enc - means
x = x.masked_fill(mask == 0, 0)
stdev = torch.sqrt(torch.sum(x * x, dim=1) / torch.sum(mask == 1, dim=1) + 1e-5)
stdev = stdev.unsqueeze(1).detach()
x = x / stdev
enc_out, n_vars, key_padding_mask, B, aux = self._embed_and_encode(x)
dec_out = self.head(enc_out, key_padding_mask, n_vars, B) # (B, seq_len, nvars)
dec_out = self._denorm(dec_out, means, stdev, self.seq_len)
return dec_out, aux
def anomaly_detection(self, x_enc):
x_enc, means, stdev = self._pre_norm(x_enc)
enc_out, n_vars, key_padding_mask, B, aux = self._embed_and_encode(x_enc)
dec_out = self.head(enc_out, key_padding_mask, n_vars, B) # (B, seq_len, nvars)
dec_out = self._denorm(dec_out, means, stdev, self.seq_len)
return dec_out, aux
def classification(self, x_enc, x_mark_enc):
x_enc, _, _ = self._pre_norm(x_enc)
enc_out, n_vars, key_padding_mask, B, aux = self._embed_and_encode(x_enc)
logits = self.head_cls(enc_out, key_padding_mask, n_vars, B) # (B, num_class)
return logits, aux
def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
if self.task_name in ('long_term_forecast', 'short_term_forecast'):
dec_out, aux = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
return dec_out[:, -self.pred_len:, :], aux # (B, pred_len, nvars); aux carries the ratio losses
if self.task_name == 'imputation':
dec_out, aux = self.imputation(x_enc, x_mark_enc, x_dec, x_mark_dec, mask)
return dec_out, aux
if self.task_name == 'anomaly_detection':
dec_out, aux = self.anomaly_detection(x_enc)
return dec_out, aux
if self.task_name == 'classification':
logits, aux = self.classification(x_enc, x_mark_enc)
return logits, aux
return None, None
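if __name__ == "__main__":
    # Minimal CPU smoke test (sketch) covering only the pure-PyTorch components; the full
    # Model additionally needs mamba_ssm (CUDA) and a TSlib-style `configs` object.
    torch.manual_seed(0)
    x = torch.randn(2, 12, 8)   # (B, L, D)
    router = RoutingModule(d_model=8)
    prob, bmask, _ = router(x)
    x_sel, kpm, lengths = select_and_right_pad(x, bmask)
    slots = GatedAttnAggregator(d_model=8, num_slots=2)(x_sel, kpm)
    print(prob.shape, x_sel.shape, slots.shape, lengths.tolist())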