feat(graph-mixer): implement L0 sparsity with Hard-Concrete gate for channel selection
This commit is contained in:
@ -22,7 +22,7 @@ class Model(nn.Module):
|
||||
self.pred_len = configs.pred_len
|
||||
self.enc_in = configs.enc_in
|
||||
|
||||
# Model parameters
|
||||
# Patch parameters
|
||||
self.patch_len = getattr(configs, 'patch_len', 16)
|
||||
self.stride = getattr(configs, 'stride', 8)
|
||||
|
||||
@ -37,19 +37,33 @@ class Model(nn.Module):
|
||||
beta = getattr(configs, 'beta', torch.tensor(0.1))
|
||||
self.decomp = DECOMP(ma_type, alpha, beta)
|
||||
|
||||
# Season network (PatchTST + Graph Mixer)
|
||||
# Season network (PatchTST/Mamba2 + Graph Mixer)
|
||||
# 透传新版 SeasonPatch 的参数(其中 GraphMixer 替换为非归一化 Hard-Concrete 门控)
|
||||
self.season_net = SeasonPatch(
|
||||
c_in=self.enc_in,
|
||||
seq_len=self.seq_len,
|
||||
pred_len=self.pred_len,
|
||||
patch_len=self.patch_len,
|
||||
stride=self.stride,
|
||||
k_graph=getattr(configs, 'k_graph', 8),
|
||||
# 编码器类型:'Transformer' or 'Mamba2'
|
||||
encoder_type=getattr(configs, 'season_encoder', 'Transformer'),
|
||||
# Patch相关
|
||||
d_model=getattr(configs, 'd_model', 128),
|
||||
n_layers=getattr(configs, 'e_layers', 3),
|
||||
n_heads=getattr(configs, 'n_heads', 16),
|
||||
# 读取选择的编码器类型('Transformer' 或 'Mamba2')
|
||||
encoder_type = getattr(configs, 'season_encoder', 'Transformer')
|
||||
# GraphMixer相关(非归一化)
|
||||
k_graph=getattr(configs, 'k_graph', 8), # -> max_degree
|
||||
thr_graph=getattr(configs, 'thr_graph', 0.5),
|
||||
symmetric_graph=getattr(configs, 'symmetric_graph', True),
|
||||
degree_rescale=getattr(configs, 'degree_rescale', 'count-sqrt'), # 'none' | 'count' | 'count-sqrt' | 'sum'
|
||||
gate_temperature=getattr(configs, 'gate_temperature', 2.0/3.0),
|
||||
tau_attn=getattr(configs, 'tau_attn', 1.0),
|
||||
l0_lambda=getattr(configs, 'season_l0_lambda', 0.0),
|
||||
# Mamba2相关
|
||||
d_state=getattr(configs, 'd_state', 64),
|
||||
d_conv=getattr(configs, 'd_conv', 4),
|
||||
expand=getattr(configs, 'expand', 2),
|
||||
headdim=getattr(configs, 'headdim', 64),
|
||||
)
|
||||
|
||||
# Trend network (MLP)
|
||||
@ -119,17 +133,12 @@ class Model(nn.Module):
|
||||
|
||||
def classification(self, x_enc, x_mark_enc):
|
||||
"""Classification task"""
|
||||
# Normalization
|
||||
#if self.revin:
|
||||
# x_enc = self.revin_layer(x_enc, 'norm')
|
||||
|
||||
# Decomposition
|
||||
# Decomposition(分类任务通常可不做 RevIN,如需可自行打开)
|
||||
seasonal_init, trend_init = self.decomp(x_enc)
|
||||
|
||||
# Season stream
|
||||
y_season = self.season_net(seasonal_init) # [B, C, pred_len]
|
||||
|
||||
# print("shape:", trend_init.shape)
|
||||
# Trend stream
|
||||
B, L, C = trend_init.shape
|
||||
trend = trend_init.permute(0, 2, 1).reshape(B * C, L) # [B*C, L]
|
||||
@ -146,7 +155,7 @@ class Model(nn.Module):
|
||||
season_attn_weights = torch.softmax(y_season, dim=-1)
|
||||
season_pooled = (y_season * season_attn_weights).sum(dim=-1) # [B, C]
|
||||
|
||||
trend_attn_weights = torch.softmax(y_trend, dim=-1) # 时间维
|
||||
trend_attn_weights = torch.softmax(y_trend, dim=-1)
|
||||
trend_pooled = (y_trend * trend_attn_weights).sum(dim=-1) # [B, C]
|
||||
|
||||
# Combine features
|
||||
@ -166,3 +175,12 @@ class Model(nn.Module):
|
||||
return dec_out # [B, N]
|
||||
else:
|
||||
raise ValueError(f'Task {self.task_name} not supported by xPatch_SparseChannel')
|
||||
|
||||
def reg_loss(self):
    """Return the L0 regularization term exposed by the season network.

    Delegates to ``self.season_net.reg_loss()`` when both the attribute and
    the method exist. Per the original author's note, the term is only
    expected to be non-zero when the Transformer path enables the
    GraphMixer; that logic lives in ``season_net`` and is not visible here.

    Intended use during training::

        total_loss = main_loss + model.reg_loss()

    Returns:
        The value returned by ``season_net.reg_loss()`` if available,
        otherwise a scalar zero tensor placed on the device of the first
        model parameter (or on the default device when the module has no
        parameters at all).
    """
    if hasattr(self, "season_net") and hasattr(self.season_net, "reg_loss"):
        return self.season_net.reg_loss()

    # A bare next() on an empty parameter iterator raises StopIteration;
    # supply a default so parameter-free modules still get a zero tensor.
    first_param = next(self.parameters(), None)
    device = first_param.device if first_param is not None else None
    return torch.tensor(0.0, device=device)
|
||||
|
Reference in New Issue
Block a user