feat(graph-mixer): implement L0 sparsity with Hard-Concrete gate for channel selection
This commit is contained in:
@ -27,7 +27,15 @@ class SeasonPatch(nn.Module):
|
||||
d_state: int = 64,
|
||||
d_conv: int = 4,
|
||||
expand: int = 2,
|
||||
headdim: int = 64):
|
||||
headdim: int = 64,
|
||||
# Mixergraph 可选超参数
|
||||
thr_graph: float = 0.5,
|
||||
symmetric_graph: bool = True,
|
||||
degree_rescale: str = "count-sqrt", # "none" | "count" | "count-sqrt" | "sum"
|
||||
gate_temperature: float = 2./3.,
|
||||
tau_attn: float = 1.0,
|
||||
l0_lambda: float = 1e-4):
|
||||
|
||||
super().__init__()
|
||||
|
||||
# Store patch parameters
|
||||
@ -46,7 +54,17 @@ class SeasonPatch(nn.Module):
|
||||
c_in=c_in, patch_num=patch_num, patch_len=patch_len,
|
||||
d_model=d_model, n_layers=n_layers, n_heads=n_heads
|
||||
)
|
||||
self.mixer = HierarchicalGraphMixer(c_in, dim=d_model, k=k_graph)
|
||||
# 集成新 HierarchicalGraphMixer(非归一化)
|
||||
self.mixer = HierarchicalGraphMixer(
|
||||
n_channel=c_in,
|
||||
dim=d_model,
|
||||
max_degree=k_graph,
|
||||
thr=thr_graph,
|
||||
temperature=gate_temperature,
|
||||
tau_attn=tau_attn,
|
||||
symmetric=symmetric_graph,
|
||||
degree_rescale=degree_rescale
|
||||
)
|
||||
# Prediction head(Transformer 路径用到,输入维度为 patch_num * d_model)
|
||||
self.head = nn.Sequential(
|
||||
nn.Linear(patch_num * d_model, patch_num * d_model),
|
||||
@ -97,3 +115,11 @@ class SeasonPatch(nn.Module):
|
||||
y_pred = self.head(z_last) # y_pred: [B, C, pred_len]
|
||||
|
||||
return y_pred # [B, C, pred_len]
|
||||
|
||||
def reg_loss(self):
|
||||
"""
|
||||
可选:把 L0 正则暴露出去,训练时加到总loss。
|
||||
"""
|
||||
if self.encoder_type == "Transformer" and hasattr(self, "mixer"):
|
||||
return self.mixer.l0_loss(self.l0_lambda)
|
||||
return torch.tensor(0.0, device=self.head[0].weight.device)
|
||||
|
Reference in New Issue
Block a user