feat(graph-mixer): implement L0 sparsity with Hard-Concrete gate for channel selection
@@ -1,27 +1,86 @@
-import math
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import math
+
+
+class HardConcreteGate(nn.Module):
+    """
+    Hard-Concrete gate for L0-style sparsity (Louizos et al., 2017).
+    Produces z in [0,1] without row-wise normalization.
+    """
+    def __init__(self, shape, temperature=2./3., gamma=-0.1, zeta=1.1, init_log_alpha=-2.0):
+        super().__init__()
+        self.log_alpha = nn.Parameter(torch.full(shape, init_log_alpha))
+        self.temperature = temperature
+        self.gamma = gamma
+        self.zeta = zeta
+
+    def sample(self, training=True):
+        if training:
+            u = torch.rand_like(self.log_alpha).clamp(1e-6, 1 - 1e-6)  # keep log(u), log(1-u) finite
+            s = torch.sigmoid((self.log_alpha + torch.log(u) - torch.log(1 - u)) / self.temperature)
+        else:
+            # deterministic mean gate at eval
+            s = torch.sigmoid(self.log_alpha)
+        s_bar = s * (self.zeta - self.gamma) + self.gamma
+        z = torch.clamp(s_bar, 0., 1.)
+        return z
+
+    def expected_l0(self):
+        """
+        E[1_{z>0}] closed-form for hard-concrete.
+        Useful for L0 penalty: lambda * expected_l0().sum()
+        """
+        # s > t0 => z > 0, where t0 = -gamma / (zeta - gamma)
+        t0 = -self.gamma / (self.zeta - self.gamma)
+        # logit(t0)
+        logit_t0 = math.log(t0) - math.log(1 - t0)
+        # z > 0 iff logit(u) > temperature * logit_t0 - log_alpha, with logit(u) ~ Logistic(0, 1),
+        # so P(z > 0) = sigmoid(log_alpha - temperature * logit_t0)
+        p_open = torch.sigmoid(self.log_alpha - self.temperature * logit_t0)
+        return p_open
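
For intuition, a minimal standalone check of the gate (shape and the 1e-4 weight are illustrative, assuming the HardConcreteGate class above is in scope):

    import torch

    gate = HardConcreteGate(shape=(4, 4), init_log_alpha=-2.0)
    z_train = gate.sample(training=True)    # stochastic; stretch-and-clamp can produce exact zeros
    z_eval = gate.sample(training=False)    # deterministic mean gate for inference
    penalty = 1e-4 * gate.expected_l0().sum()  # differentiable expected number of open edges
    print(z_train.shape, z_eval.min().item(), float(penalty))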
 class HierarchicalGraphMixer(nn.Module):
     """
-    Hierarchical graph mixer that jointly models macro channel relations and micro patch-level attention.
-    Input  z : [B, C, N, D]
-    Output z_out : same shape
+    Hierarchical graph mixer with Hard-Concrete edge gating:
+    - Level 1: an unnormalized, thresholdable channel graph that may be empty
+    - Level 2: patch-level cross-attention only over the selected edges
+
+    Input:  z [B, C, N, D]
+    Output: z_out, same shape
     """
-    def __init__(self, n_channel: int, dim: int, k: int = 5, tau_fw: float = 0.3, tau_bw: float = 3.0):
+    def __init__(
+        self,
+        n_channel: int,
+        dim: int,
+        max_degree: int = None,        # optional: cap the number of edges per row
+        thr: float = 0.5,              # threshold for keeping an edge, e.g. 0.5/0.7
+        temperature: float = 2./3.,
+        tau_attn: float = 1.0,         # patch-attention temperature (optional)
+        symmetric: bool = True,        # whether to symmetrize the channel graph
+        degree_rescale: str = "none",  # "none" | "count" | "count-sqrt" | "sum"
+        init_log_alpha: float = -2.0
+    ):
         super().__init__()
-        self.k = k
-        self.tau_fw = tau_fw  # forward temperature (small)
-        self.tau_bw = tau_bw  # backward temperature (large)
-
-        # Level 1: Channel Graph (logits)
-        self.A = nn.Parameter(torch.zeros(n_channel, n_channel))
+        self.C = n_channel
+        self.dim = dim
+        self.max_degree = max_degree
+        self.thr = thr
+        self.tau_attn = tau_attn
+        self.symmetric = symmetric
+        self.degree_rescale = degree_rescale
+
+        # Level 1: unnormalized gating
+        self.gate = HardConcreteGate(
+            shape=(n_channel, n_channel),
+            temperature=temperature,
+            init_log_alpha=init_log_alpha
+        )
+
+        # optional SE (the original se block can produce sample-dependent channel priorities; the interface is kept for now)
         self.se = nn.Sequential(
             nn.Linear(dim, dim // 4, bias=False), nn.SiLU(),
             nn.Linear(dim // 4, 1, bias=False), nn.Sigmoid()
         )
 
         # Level 2: Patch Cross-Attention
         self.q_proj = nn.Linear(dim, dim)
         self.k_proj = nn.Linear(dim, dim)
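
An illustrative instantiation of the new constructor (the values are hypothetical, not project defaults):

    mixer = HierarchicalGraphMixer(
        n_channel=7, dim=128,
        max_degree=5,                 # keep at most 5 edges per channel row
        thr=0.6,                      # an edge survives only if its gate value exceeds 0.6
        degree_rescale="count-sqrt",  # stabilize sums over neighborhoods of varying size
    )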
@@ -29,96 +88,108 @@ class HierarchicalGraphMixer(nn.Module):
         self.out_proj = nn.Linear(dim, dim)
         self.norm = nn.LayerNorm(dim)
 
-    @torch.no_grad()
-    def _mask_self_logits_(self, logits: torch.Tensor):
-        """Set the diagonal to -inf so a channel never selects itself."""
-        C = logits.size(0)
-        eye = torch.eye(C, device=logits.device, dtype=torch.bool)
-        logits.masked_fill_(eye, float("-inf"))
-
-    def _gumbel_topk_select(self, logits: torch.Tensor):
-        """
-        Returns:
-        - idx: [C, k_actual] top-k channel indices per row (excluding self)
-        - w_st: [C, k_actual] selected edge weights (forward = probabilities at tau_fw; backward gradient = probabilities at tau_bw)
-        """
-        C = logits.size(0)
-        k_actual = min(self.k, C - 1)
-        if k_actual <= 0:
-            idx = torch.empty((C, 0), dtype=torch.long, device=logits.device)
-            w_st = torch.empty((C, 0), dtype=logits.dtype, device=logits.device)
-            return idx, w_st
-
-        # share one draw of Gumbel noise; build forward/backward distributions at different temperatures
-        g = -torch.empty_like(logits).exponential_().log()
-        y_fw = (logits + g) / self.tau_fw
-        y_bw = (logits + g) / self.tau_bw
-
-        # exclude self
-        y_fw = y_fw.clone()
-        y_bw = y_bw.clone()
-        self._mask_self_logits_(y_fw)
-        self._mask_self_logits_(y_bw)
+    def _build_sparse_neighbors(self, z_gate):
+        """
+        Build per-row adjacency lists from z_gate (threshold plus optional top-k).
+        Returns:
+        - idx_list: list of length C; each entry is a LongTensor of neighbor indices
+        - w_list: list of length C; each entry is a FloatTensor of (unnormalized) edge weights
+        """
+        C = z_gate.size(0)
+        # zero out the diagonal
+        z_gate = z_gate.clone()
+        z_gate.fill_diagonal_(0.0)
+
+        if self.symmetric:
+            z_gate = 0.5 * (z_gate + z_gate.t())
+            z_gate.fill_diagonal_(0.0)
+
+        idx_list, w_list = [], []
+        for i in range(C):
+            row = z_gate[i]  # [C]
+            # threshold filtering
+            mask = row > self.thr
+            if mask.any():
+                vals = row[mask]
+                idxs = torch.nonzero(mask, as_tuple=False).squeeze(-1)
+                # optional cap on the degree
+                if (self.max_degree is not None) and (idxs.numel() > self.max_degree):
+                    topk = torch.topk(vals, k=self.max_degree, dim=0)
+                    vals = topk.values
+                    idxs = idxs[topk.indices]
+            else:
+                idxs = torch.empty((0,), dtype=torch.long, device=row.device)
+                vals = torch.empty((0,), dtype=row.dtype, device=row.device)
+            idx_list.append(idxs)
+            w_list.append(vals)
+        return idx_list, w_list
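
A toy trace of the threshold-plus-top-k rule in _build_sparse_neighbors (hypothetical gate values):

    import torch

    row = torch.tensor([0.0, 0.9, 0.55, 0.2, 0.7])  # gate row for channel i, diagonal already zeroed
    thr, max_degree = 0.5, 2

    mask = row > thr                                         # [False, True, True, False, True]
    idxs = torch.nonzero(mask, as_tuple=False).squeeze(-1)   # tensor([1, 2, 4])
    vals = row[mask]                                         # tensor([0.90, 0.55, 0.70])
    topk = torch.topk(vals, k=max_degree)                    # keep the 2 strongest surviving edges
    print(idxs[topk.indices], topk.values)                   # tensor([1, 4]) tensor([0.90, 0.70])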
-        # pick the forward top-k (hard selection)
-        topk_val, idx = torch.topk(y_fw, k_actual, dim=-1)  # [C, k]
-        # compute forward/backward soft probabilities and gather only the selected k
-        p_fw = F.softmax(y_fw, dim=-1)  # [C, C]
-        p_bw = F.softmax(y_bw, dim=-1)  # [C, C]
-        w_fw = torch.gather(p_fw, -1, idx)  # [C, k]
-        w_bw = torch.gather(p_bw, -1, idx)  # [C, k]
-
-        # normalize within the selected set to stabilize training
-        eps = 1e-9
-        w_fw = w_fw / (w_fw.sum(-1, keepdim=True) + eps)
-        w_bw = w_bw / (w_bw.sum(-1, keepdim=True) + eps)
-
-        # straight-through: forward uses w_fw, backward gradient comes from w_bw
-        w_st = w_fw.detach() + w_bw - w_bw.detach()  # [C, k]
-        return idx, w_st
+    def _degree_rescale(self, ctx, w_sel):
+        """
+        Stability handling for unnormalized aggregation: optionally degree-normalize
+        the aggregate to keep its scale steady.
+        ctx: [B, k, N, D]
+        w_sel: [k]
+        """
+        if self.degree_rescale == "none":
+            return (ctx * w_sel.view(1, -1, 1, 1)).sum(dim=1)
+        elif self.degree_rescale == "count":
+            k = max(1, w_sel.numel())
+            return (ctx * w_sel.view(1, -1, 1, 1)).sum(dim=1) / float(k)
+        elif self.degree_rescale == "count-sqrt":
+            k = max(1, w_sel.numel())
+            return (ctx * w_sel.view(1, -1, 1, 1)).sum(dim=1) / math.sqrt(k)
+        elif self.degree_rescale == "sum":
+            s = float(w_sel.sum().clamp(min=1e-6))
+            return (ctx * w_sel.view(1, -1, 1, 1)).sum(dim=1) / s
+        else:
+            return (ctx * w_sel.view(1, -1, 1, 1)).sum(dim=1)
+
+    def l0_loss(self, lam: float = 1e-4):
+        """
+        Expected-L0 regularizer: encourages a sparse adjacency (strength is tunable).
+        """
+        return lam * self.gate.expected_l0().sum()
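
To make the rescale options concrete, a small numeric check (toy tensors; the mode strings mirror the branches above):

    import math
    import torch

    ctx = torch.ones(1, 3, 2, 4)            # [B=1, k=3, N=2, D=4]
    w = torch.tensor([0.9, 0.7, 0.6])       # unnormalized edge weights, sum = 2.2

    raw = (ctx * w.view(1, -1, 1, 1)).sum(dim=1)   # every entry = 2.2   ("none")
    by_count = raw / 3                              # 0.733...            ("count")
    by_sqrt = raw / math.sqrt(3)                    # 1.270...            ("count-sqrt")
    by_sum = raw / w.sum()                          # 1.0                 ("sum": degree-independent scale)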
     def forward(self, z):
         # z: [B, C, N, D]
         B, C, N, D = z.shape
+        assert C == self.C and D == self.dim
 
-        # --- Level 1: pick each channel's top-k related channels (excluding self) and get ST weights ---
-        idx, w_st = self._gumbel_topk_select(self.A)  # idx:[C,k], w_st:[C,k]
+        # Level 1: sample the unnormalized gate z_gate ∈ [0, 1]
+        z_gate = self.gate.sample(training=self.training)  # [C, C]
 
-        # --- Level 2: cross-channel patch interaction only for the selected channels ---
+        # build sparse neighborhoods (threshold + optional top-k)
+        idx_list, w_list = self._build_sparse_neighbors(z_gate)
+
+        # Level 2: cross-channel patch interaction only over the kept edges
         out_z = torch.zeros_like(z)
 
         for i in range(C):
             target_z = z[:, i, :, :]  # [B, N, D]
-            # if this channel has no selectable neighbors, residual only
-            if idx.size(1) == 0:
+            idx = idx_list[i]
+            if idx.numel() == 0:
+                # empty neighborhood: "no related channels" is allowed; residual/norm only
                 out_z[:, i, :, :] = self.norm(target_z)
                 continue
 
-            sel_idx = idx[i]  # [k]
-            sel_w = w_st[i]  # [k]
-            k_i = sel_idx.numel()
-
-            # source channel block: [B, k, N, D]
-            source_z = z[:, sel_idx, :, :]
-
-            # linear projections
-            Q = self.q_proj(target_z)  # [B, N, D]
+            w_sel = w_list[i]  # [k], unnormalized weights in [0, 1]
+            k_i = idx.numel()
+
+            source_z = z[:, idx, :, :]  # [B, k, N, D]
+
+            Q = self.q_proj(target_z)  # [B, N, D]
             K = self.k_proj(source_z.reshape(B * k_i, N, D)).reshape(B, k_i, N, D)
             V = self.v_proj(source_z.reshape(B * k_i, N, D)).reshape(B, k_i, N, D)
 
-            # cross-attention (over all k source channels at once)
-            # attn_scores: [B, k, N, N]
+            # cross-channel patch attention
             attn_scores = torch.einsum('bnd,bkmd->bknm', Q, K) / math.sqrt(D)
-            attn_probs = F.softmax(attn_scores, dim=-1)  # [B, k, N, N]
-            context = torch.einsum('bknm,bkmd->bknd', attn_probs, V)  # [B, k, N, D]
+            if self.tau_attn != 1.0:
+                attn_scores = attn_scores / self.tau_attn
+            attn_probs = F.softmax(attn_scores, dim=-1)  # [B, k, N, N]
+            context = torch.einsum('bknm,bkmd->bknd', attn_probs, V)  # [B, k, N, D]
 
-            # aggregate with ST channel weights (forward = small-temperature weights; backward gradient = large temperature)
-            w = sel_w.view(1, k_i, 1, 1)  # [1, k, 1, 1]
-            aggregated_context = (context * w).sum(dim=1)  # [B, N, D]
+            # aggregate with unnormalized channel weights + optional degree rescaling
+            # (numerical stability only; the "unnormalized" semantics are preserved)
+            aggregated_context = self._degree_rescale(context, w_sel)  # [B, N, D]
 
-            # output and residual
             out_z[:, i, :, :] = self.norm(target_z + self.out_proj(aggregated_context))
 
         return out_z
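
A smoke-test sketch for the whole mixer (random tensors; the B/C/N/D values are arbitrary):

    import torch

    B, C, N, D = 2, 7, 12, 128
    mixer = HierarchicalGraphMixer(n_channel=C, dim=D, max_degree=5, thr=0.6)

    z = torch.randn(B, C, N, D)
    out = mixer(z)                          # [2, 7, 12, 128], same shape as z
    loss = out.mean() + mixer.l0_loss(lam=1e-4)
    loss.backward()                         # gradients reach mixer.gate.log_alpha via the L0 term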
@@ -27,7 +27,15 @@ class SeasonPatch(nn.Module):
                  d_state: int = 64,
                  d_conv: int = 4,
                  expand: int = 2,
-                 headdim: int = 64):
+                 headdim: int = 64,
+                 # optional graph-mixer hyperparameters
+                 thr_graph: float = 0.5,
+                 symmetric_graph: bool = True,
+                 degree_rescale: str = "count-sqrt",  # "none" | "count" | "count-sqrt" | "sum"
+                 gate_temperature: float = 2./3.,
+                 tau_attn: float = 1.0,
+                 l0_lambda: float = 1e-4):
 
         super().__init__()
 
         # Store patch parameters
@@ -46,7 +54,17 @@ class SeasonPatch(nn.Module):
             c_in=c_in, patch_num=patch_num, patch_len=patch_len,
             d_model=d_model, n_layers=n_layers, n_heads=n_heads
         )
-        self.mixer = HierarchicalGraphMixer(c_in, dim=d_model, k=k_graph)
+        # integrate the new (unnormalized) HierarchicalGraphMixer
+        self.mixer = HierarchicalGraphMixer(
+            n_channel=c_in,
+            dim=d_model,
+            max_degree=k_graph,
+            thr=thr_graph,
+            temperature=gate_temperature,
+            tau_attn=tau_attn,
+            symmetric=symmetric_graph,
+            degree_rescale=degree_rescale
+        )
         # Prediction head (used on the Transformer path; input dim is patch_num * d_model)
         self.head = nn.Sequential(
             nn.Linear(patch_num * d_model, patch_num * d_model),
@@ -97,3 +115,11 @@ class SeasonPatch(nn.Module):
             y_pred = self.head(z_last)  # y_pred: [B, C, pred_len]
 
         return y_pred  # [B, C, pred_len]
+
+    def reg_loss(self):
+        """
+        Optional: expose the L0 regularizer so it can be added to the total training loss.
+        """
+        if self.encoder_type == "Transformer" and hasattr(self, "mixer"):
+            return self.mixer.l0_loss(self.l0_lambda)
+        return torch.tensor(0.0, device=self.head[0].weight.device)
@@ -22,7 +22,7 @@ class Model(nn.Module):
         self.pred_len = configs.pred_len
         self.enc_in = configs.enc_in
 
-        # Model parameters
+        # Patch parameters
         self.patch_len = getattr(configs, 'patch_len', 16)
         self.stride = getattr(configs, 'stride', 8)
 
@@ -37,19 +37,33 @@ class Model(nn.Module):
         beta = getattr(configs, 'beta', torch.tensor(0.1))
         self.decomp = DECOMP(ma_type, alpha, beta)
 
-        # Season network (PatchTST + Graph Mixer)
+        # Season network (PatchTST/Mamba2 + Graph Mixer)
+        # pass the new SeasonPatch arguments through (the GraphMixer is replaced by unnormalized Hard-Concrete gating)
         self.season_net = SeasonPatch(
             c_in=self.enc_in,
             seq_len=self.seq_len,
             pred_len=self.pred_len,
             patch_len=self.patch_len,
             stride=self.stride,
-            k_graph=getattr(configs, 'k_graph', 8),
+            # encoder type: 'Transformer' or 'Mamba2'
+            encoder_type=getattr(configs, 'season_encoder', 'Transformer'),
+            # patch-related
             d_model=getattr(configs, 'd_model', 128),
             n_layers=getattr(configs, 'e_layers', 3),
             n_heads=getattr(configs, 'n_heads', 16),
-            # read the selected encoder type ('Transformer' or 'Mamba2')
-            encoder_type = getattr(configs, 'season_encoder', 'Transformer')
+            # GraphMixer-related (unnormalized)
+            k_graph=getattr(configs, 'k_graph', 8),  # -> max_degree
+            thr_graph=getattr(configs, 'thr_graph', 0.5),
+            symmetric_graph=getattr(configs, 'symmetric_graph', True),
+            degree_rescale=getattr(configs, 'degree_rescale', 'count-sqrt'),  # 'none' | 'count' | 'count-sqrt' | 'sum'
+            gate_temperature=getattr(configs, 'gate_temperature', 2.0/3.0),
+            tau_attn=getattr(configs, 'tau_attn', 1.0),
+            l0_lambda=getattr(configs, 'season_l0_lambda', 0.0),
+            # Mamba2-related
+            d_state=getattr(configs, 'd_state', 64),
+            d_conv=getattr(configs, 'd_conv', 4),
+            expand=getattr(configs, 'expand', 2),
+            headdim=getattr(configs, 'headdim', 64),
         )
 
         # Trend network (MLP)
@@ -119,17 +133,12 @@ class Model(nn.Module):
 
     def classification(self, x_enc, x_mark_enc):
         """Classification task"""
-        # Normalization
-        #if self.revin:
-        #    x_enc = self.revin_layer(x_enc, 'norm')
-
-        # Decomposition
+        # Decomposition (classification usually skips RevIN; enable it here if needed)
         seasonal_init, trend_init = self.decomp(x_enc)
 
         # Season stream
         y_season = self.season_net(seasonal_init)  # [B, C, pred_len]
 
-        # print("shape:", trend_init.shape)
         # Trend stream
         B, L, C = trend_init.shape
         trend = trend_init.permute(0, 2, 1).reshape(B * C, L)  # [B*C, L]
@@ -146,7 +155,7 @@ class Model(nn.Module):
         season_attn_weights = torch.softmax(y_season, dim=-1)
         season_pooled = (y_season * season_attn_weights).sum(dim=-1)  # [B, C]
 
-        trend_attn_weights = torch.softmax(y_trend, dim=-1)  # over the time dimension
+        trend_attn_weights = torch.softmax(y_trend, dim=-1)
         trend_pooled = (y_trend * trend_attn_weights).sum(dim=-1)  # [B, C]
 
         # Combine features
@@ -166,3 +175,12 @@ class Model(nn.Module):
             return dec_out  # [B, N]
         else:
             raise ValueError(f'Task {self.task_name} not supported by xPatch_SparseChannel')
+
+    def reg_loss(self):
+        """
+        L0 regularization term (nonzero only when the Transformer path enables the GraphMixer).
+        During training: total_loss = main_loss + model.reg_loss()
+        """
+        if hasattr(self, "season_net") and hasattr(self.season_net, "reg_loss"):
+            return self.season_net.reg_loss()
+        return torch.tensor(0.0, device=next(self.parameters()).device)
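
As the docstring suggests, the regularizer is meant to be added to the task loss each step. A schematic training step (model, criterion, optimizer, and the batch tensors are placeholders; the forward signature is assumed, not taken from this diff):

    optimizer.zero_grad()
    y_hat = model(x_enc, x_mark_enc, x_dec, x_mark_dec)  # assumed run.py-style forward
    loss = criterion(y_hat, y_true) + model.reg_loss()   # main loss + expected-L0 penalty
    loss.backward()
    optimizer.step()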
@@ -2,6 +2,45 @@
 
 model_name=xPatch_SparseChannel
 
+# ETTm1 dataset
+for pred_len in 96 192 336 720
+do
+python -u run.py \
+  --task_name long_term_forecast \
+  --is_training 1 \
+  --root_path ./dataset/ETT-small/ \
+  --data_path ETTm1.csv \
+  --model_id ETTm1_$pred_len'_'$pred_len \
+  --model $model_name \
+  --data ETTm1 \
+  --features M \
+  --seq_len 96 \
+  --label_len 48 \
+  --pred_len $pred_len \
+  --e_layers 2 \
+  --d_layers 1 \
+  --enc_in 7 \
+  --c_out 7 \
+  --d_model 128 \
+  --lradj 'sigmoid' \
+  --d_ff 256 \
+  --n_heads 16 \
+  --patch_len 16 \
+  --stride 8 \
+  --k_graph 5 \
+  --dropout 0.1 \
+  --revin 1 \
+  --des 'Exp' \
+  --itr 1 \
+  --season_encoder 'Transformer' \
+  --thr_graph 0.6 \
+  --symmetric_graph 1 \
+  --degree_rescale 'none' \
+  --gate_temperature 0.6667 \
+  --tau_attn 1.0 \
+  --season_l0_lambda 0.0000
+done
+
 # Weather dataset
 for pred_len in 96 192 336 720
 do
@@ -32,7 +71,14 @@ python -u run.py \
   --dropout 0.1 \
   --revin 1 \
   --des 'Exp' \
-  --itr 1
+  --itr 1 \
+  --season_encoder 'Transformer' \
+  --thr_graph 0.6 \
+  --symmetric_graph 1 \
+  --degree_rescale 'none' \
+  --gate_temperature 0.6667 \
+  --tau_attn 1.0 \
+  --season_l0_lambda 0.0000
 done
 
 # Exchange dataset
@@ -64,40 +110,16 @@ python -u run.py \
   --dropout 0.1 \
   --revin 1 \
   --des 'Exp' \
-  --itr 1
+  --itr 1 \
+  --season_encoder 'Transformer' \
+  --thr_graph 0.6 \
+  --symmetric_graph 1 \
+  --degree_rescale 'none' \
+  --gate_temperature 0.6667 \
+  --tau_attn 1.0 \
+  --season_l0_lambda 0.0000
 done
 
-# ETTm1 dataset
-for pred_len in 96 192 336 720
-do
-python -u run.py \
-  --task_name long_term_forecast \
-  --is_training 1 \
-  --root_path ./dataset/ETT-small/ \
-  --data_path ETTm1.csv \
-  --model_id ETTm1_$pred_len'_'$pred_len \
-  --model $model_name \
-  --data ETTm1 \
-  --features M \
-  --seq_len 96 \
-  --label_len 48 \
-  --pred_len $pred_len \
-  --e_layers 2 \
-  --d_layers 1 \
-  --enc_in 7 \
-  --c_out 7 \
-  --d_model 128 \
-  --lradj 'sigmoid' \
-  --d_ff 256 \
-  --n_heads 16 \
-  --patch_len 16 \
-  --stride 8 \
-  --k_graph 5 \
-  --dropout 0.1 \
-  --revin 1 \
-  --des 'Exp' \
-  --itr 1
-done
-
 # ETTm2 dataset
 for pred_len in 96 192 336 720
@@ -128,7 +150,14 @@ python -u run.py \
   --dropout 0.1 \
   --revin 1 \
   --des 'Exp' \
-  --itr 1
+  --itr 1 \
+  --season_encoder 'Transformer' \
+  --thr_graph 0.6 \
+  --symmetric_graph 1 \
+  --degree_rescale 'none' \
+  --gate_temperature 0.6667 \
+  --tau_attn 1.0 \
+  --season_l0_lambda 0.0000
 done
 
 # ETTh1 dataset