feat(graph-mixer): implement L0 sparsity with Hard-Concrete gate for channel selection

This commit is contained in:
gameloader
2025-09-11 16:50:58 +08:00
parent 5fc0da4239
commit 204d17086a
4 changed files with 268 additions and 124 deletions

View File

@ -22,7 +22,7 @@ class Model(nn.Module):
self.pred_len = configs.pred_len
self.enc_in = configs.enc_in
# Model parameters
# Patch parameters
self.patch_len = getattr(configs, 'patch_len', 16)
self.stride = getattr(configs, 'stride', 8)
@ -37,19 +37,33 @@ class Model(nn.Module):
beta = getattr(configs, 'beta', torch.tensor(0.1))
self.decomp = DECOMP(ma_type, alpha, beta)
# Season network (PatchTST + Graph Mixer)
# Season network (PatchTST/Mamba2 + Graph Mixer)
# 透传新版 SeasonPatch 的参数(其中 GraphMixer 替换为非归一化 Hard-Concrete 门控)
self.season_net = SeasonPatch(
c_in=self.enc_in,
seq_len=self.seq_len,
pred_len=self.pred_len,
patch_len=self.patch_len,
stride=self.stride,
k_graph=getattr(configs, 'k_graph', 8),
# 编码器类型:'Transformer' or 'Mamba2'
encoder_type=getattr(configs, 'season_encoder', 'Transformer'),
# Patch相关
d_model=getattr(configs, 'd_model', 128),
n_layers=getattr(configs, 'e_layers', 3),
n_heads=getattr(configs, 'n_heads', 16),
# 读取选择的编码器类型('Transformer' 或 'Mamba2')
encoder_type = getattr(configs, 'season_encoder', 'Transformer')
# GraphMixer相关(非归一化)
k_graph=getattr(configs, 'k_graph', 8), # -> max_degree
thr_graph=getattr(configs, 'thr_graph', 0.5),
symmetric_graph=getattr(configs, 'symmetric_graph', True),
degree_rescale=getattr(configs, 'degree_rescale', 'count-sqrt'), # 'none' | 'count' | 'count-sqrt' | 'sum'
gate_temperature=getattr(configs, 'gate_temperature', 2.0/3.0),
tau_attn=getattr(configs, 'tau_attn', 1.0),
l0_lambda=getattr(configs, 'season_l0_lambda', 0.0),
# Mamba2相关
d_state=getattr(configs, 'd_state', 64),
d_conv=getattr(configs, 'd_conv', 4),
expand=getattr(configs, 'expand', 2),
headdim=getattr(configs, 'headdim', 64),
)
# Trend network (MLP)
@ -119,17 +133,12 @@ class Model(nn.Module):
def classification(self, x_enc, x_mark_enc):
"""Classification task"""
# Normalization
#if self.revin:
# x_enc = self.revin_layer(x_enc, 'norm')
# Decomposition
# Decomposition(分类任务通常可不做 RevIN,如需可自行打开)
seasonal_init, trend_init = self.decomp(x_enc)
# Season stream
y_season = self.season_net(seasonal_init) # [B, C, pred_len]
# print("shape:", trend_init.shape)
# Trend stream
B, L, C = trend_init.shape
trend = trend_init.permute(0, 2, 1).reshape(B * C, L) # [B*C, L]
@ -146,7 +155,7 @@ class Model(nn.Module):
season_attn_weights = torch.softmax(y_season, dim=-1)
season_pooled = (y_season * season_attn_weights).sum(dim=-1) # [B, C]
trend_attn_weights = torch.softmax(y_trend, dim=-1) # 时间维
trend_attn_weights = torch.softmax(y_trend, dim=-1)
trend_pooled = (y_trend * trend_attn_weights).sum(dim=-1) # [B, C]
# Combine features
@ -166,3 +175,12 @@ class Model(nn.Module):
return dec_out # [B, N]
else:
raise ValueError(f'Task {self.task_name} not supported by xPatch_SparseChannel')
def reg_loss(self):
    """Return the L0 regularization term contributed by the season network.

    Delegates to ``self.season_net.reg_loss()`` when the season network
    exposes such a hook. According to the original author's note, the term
    is non-zero only when the GraphMixer is enabled on the Transformer path.

    Intended training usage: ``total_loss = main_loss + model.reg_loss()``.

    Returns:
        Whatever ``season_net.reg_loss()`` returns, or a scalar zero tensor
        when no such hook exists. The zero tensor lives on the device of the
        model's first parameter, falling back to CPU when the model has no
        parameters at all.
    """
    season_net = getattr(self, "season_net", None)
    season_reg = getattr(season_net, "reg_loss", None)
    if callable(season_reg):
        return season_reg()

    # next() with a default avoids StopIteration on a parameter-less module.
    first_param = next(self.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    return torch.tensor(0.0, device=device)