feat: add DiffusionTimeSeries and iTransformer models, introduce xPatch_SparseChannel
@@ -273,15 +273,21 @@ class Dataset_Custom(Dataset):
         self.data_stamp = data_stamp

     def __getitem__(self, index):
+        # 1. Define the start/end positions of the input sequence seq_x
         s_begin = index
         s_end = s_begin + self.seq_len
-        r_begin = s_end - self.label_len
-        r_end = r_begin + self.label_len + self.pred_len
+        # 2. Define the start/end positions of the target sequence seq_y
+        # seq_y starts (r_begin) exactly where seq_x ends (s_end)
+        r_begin = s_end
+        # seq_y ends (r_end) at its start plus the prediction length (pred_len)
+        r_end = r_begin + self.pred_len
+        # 3. Slice the data according to these positions
         seq_x = self.data_x[s_begin:s_end]
         seq_y = self.data_y[r_begin:r_end]
         seq_x_mark = self.data_stamp[s_begin:s_end]
         seq_y_mark = self.data_stamp[r_begin:r_end]
+
+        seq_x = seq_x.astype('float32')
+        seq_y = seq_y.astype('float32')
+
         return seq_x, seq_y, seq_x_mark, seq_y_mark
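For reference, a worked example of the new indexing (the numbers below are illustrative, not repository defaults):

seq_len, label_len, pred_len, index = 96, 48, 24, 0
s_begin, s_end = index, index + seq_len           # seq_x covers rows 0..95
old_r_begin = s_end - label_len                   # 48  (previous behaviour)
old_r_end = old_r_begin + label_len + pred_len    # 120 -> seq_y overlapped seq_x by label_len rows
new_r_begin = s_end                               # 96  (new behaviour)
new_r_end = new_r_begin + pred_len                # 120 -> seq_y is the 24-step horizon only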
models/DiffusionTimeSeries/__init__.py  (new file, 0 lines)

models/DiffusionTimeSeries/diffusion_ts.py  (new file, 323 lines)
@@ -0,0 +1,323 @@
import math
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


# ----------------------- Utilities: build / rescale linear betas -----------------------
def make_linear_betas(T: int, beta_start=1e-4, beta_end=2e-2, device='cpu'):
    return torch.linspace(beta_start, beta_end, T, device=device)


def cumprod_from_betas(betas: torch.Tensor):
    alphas = 1.0 - betas
    return torch.cumprod(alphas, dim=0)  # shape [T]


@torch.no_grad()
def scale_betas_to_target_cumprod(betas: torch.Tensor, target_cumprod: float, max_scale: float = 100.0):
    """
    Given betas[1..T], find a scale factor s > 0 such that prod(1 - s*beta_i) = target_cumprod.
    Search with bisection over (0, s_max), keeping 0 < s*beta_i < 1.
    """
    device = betas.device
    eps = 1e-12
    s_low = 0.0
    s_high = min(max_scale, (1.0 - 1e-6) / (betas.max().item() + eps))  # keep 1 - s*beta > 0

    def cumprod_with_scale(s: float):
        a = (1.0 - betas * s).clamp(min=1e-6, max=1.0 - 1e-6)
        return torch.cumprod(a, dim=0)[-1].item()

    # If the unscaled schedule is already close to the target, return it unchanged
    base = cumprod_with_scale(1.0)
    if abs(base - target_cumprod) / max(target_cumprod, 1e-12) < 1e-6:
        return betas

    # The target is reachable monotonically within (0, s_high); run bisection
    for _ in range(60):
        mid = 0.5 * (s_low + s_high)
        val = cumprod_with_scale(mid)
        if val > target_cumprod:
            # Noise still too weak (cumulative product above target); a larger s is needed
            s_low = mid
        else:
            s_high = mid
    s_best = 0.5 * (s_low + s_high)
    return (betas * s_best).clamp(min=1e-8, max=1 - 1e-6)
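A quick sanity check of these helpers (a minimal sketch; the target value below is arbitrary, not a model default):

betas = make_linear_betas(T=300, beta_start=1e-4, beta_end=2e-2)
scaled = scale_betas_to_target_cumprod(betas, target_cumprod=0.05)
final_a_bar = cumprod_from_betas(scaled)[-1].item()
print(f"final alpha-bar after rescaling: {final_a_bar:.4f}")  # should land near 0.05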
# ------------------------------ DiT Blocks --------------------------------
class DiTBlock(nn.Module):
    def __init__(self, dim: int, heads: int, mlp_ratio=4.0):
        super().__init__()
        self.ln1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.ln2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, int(dim * mlp_ratio)),
            nn.GELU(),
            nn.Linear(int(dim * mlp_ratio), dim),
        )

    def forward(self, x):
        # x: [B, C_tokens, D]
        h = self.attn(self.ln1(x), self.ln1(x), self.ln1(x))[0]
        x = x + h
        x = x + self.mlp(self.ln2(x))
        return x


class DiTChannelTokens(nn.Module):
    """
    One token = one channel (variable).
    For each channel the input is a time vector of length [L]; two projections are used:
    - W_x : projects the time vector of x_t into a token vector
    - W_n : projects the noise level (the per-time a-bar or b-bar vector from the schedule) into a token bias
    Note: no learnable t-embedding is used; the noise conditioning is determined entirely by the noise map.
    """
    def __init__(self, L: int, C: int, dim: int = 256, depth: int = 8, heads: int = 8):
        super().__init__()
        self.L = L
        self.C = C
        self.dim = dim

        # Channel embedding (optional, distinguishes the variables)
        self.channel_embed = nn.Parameter(torch.randn(C, dim) * 0.02)

        # Map each channel's time series to a token
        self.proj_x = nn.Linear(L, dim, bias=False)
        # Map each channel's per-time noise strength (e.g. [sqrt(a_bar), sqrt(1-a_bar)] concatenated, then one linear layer)
        self.proj_noise = nn.Linear(L, dim, bias=True)

        self.blocks = nn.ModuleList([DiTBlock(dim, heads) for _ in range(depth)])
        self.ln_f = nn.LayerNorm(dim)

        # Project back to time length L to predict epsilon (independent projection per channel)
        self.head = nn.Linear(dim, L, bias=False)

    def forward(self, x_t: torch.Tensor, noise_feat: torch.Tensor):
        """
        x_t       : [B, L, C]
        noise_feat: [B, L, C] (passing sqrt(a_bar) is recommended; if several features are concatenated,
                    merge them into the L dimension first - a single projection suffices here)
        returns epsilon_hat : [B, L, C]
        """
        B, L, C = x_t.shape
        assert L == self.L and C == self.C

        # Map each channel to a token:
        # reshape (B, L, C) to (B, C, L), then apply the linear projections
        x_tc = x_t.permute(0, 2, 1)         # [B, C, L]
        n_tc = noise_feat.permute(0, 2, 1)  # [B, C, L]

        tok = self.proj_x(x_tc) + self.proj_noise(n_tc)   # [B, C, D]
        tok = tok + self.channel_embed.unsqueeze(0)        # broadcast [1, C, D]

        for blk in self.blocks:
            tok = blk(tok)  # [B, C, D]

        tok = self.ln_f(tok)
        out = self.head(tok)             # [B, C, L]
        eps_pred = out.permute(0, 2, 1)  # [B, L, C]
        return eps_pred
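A minimal shape walk-through of the channel-token backbone (illustrative sizes, smaller than the defaults):

backbone = DiTChannelTokens(L=48, C=7, dim=64, depth=2, heads=4)
x_t = torch.randn(2, 48, 7)           # noised sequence   [B, L, C]
noise_feat = torch.rand(2, 48, 7)     # e.g. sqrt(a_bar) per position
eps_hat = backbone(x_t, noise_feat)   # predicted noise   [B, L, C]
assert eps_hat.shape == (2, 48, 7)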
# ----------------------- RAD two-phase diffusion (channels as tokens) -----------------------
class RADChannelDiT(nn.Module):
    def __init__(self,
                 past_len: int,
                 future_len: int,
                 channels: int,
                 T: int = 1000,
                 T1_ratio: float = 0.7,
                 model_dim: int = 256,
                 depth: int = 8,
                 heads: int = 8,
                 beta_start: float = 1e-4,
                 beta_end: float = 2e-2,
                 use_cosine_target: bool = True):
        """
        - Training : two phases (Phase-1 + Phase-2); t is sampled uniformly from [1..T]
        - Inference: Phase-1 only (t: T1 -> 1); only the future region is updated
        - Token = channel; every token sees the whole time axis plus the per-time noise-strength vector
        """
        super().__init__()
        self.P = past_len
        self.H = future_len
        self.C = channels
        self.L = past_len + future_len
        self.T = T
        self.T1 = max(1, int(T * T1_ratio))
        self.T2 = T - self.T1
        assert self.T2 >= 1, "T1_ratio is too large; leave at least 1 step for Phase-2"

        device = torch.device('cpu')

        # Target a_bar_T (used to normalize the two linear schedules to the same final noise level)
        if use_cosine_target:
            # Use the cosine schedule as a reference to obtain a global target a_bar_T
            steps = T + 1
            x = torch.linspace(0, T, steps, dtype=torch.float64)
            s = 0.008
            alphas_cum = torch.cos(((x / T) + s) / (1 + s) * math.pi / 2) ** 2
            alphas_cum = alphas_cum / alphas_cum[0]
            a_bar_target_T = float(alphas_cum[-1])
        else:
            # Use the final value of the DDPM linear-beta schedule as the target
            betas_full = make_linear_betas(T, beta_start, beta_end, device)
            a_bar_target_T = cumprod_from_betas(betas_full)[-1].item()

        # Raw linear betas for Phase-1 & Phase-2
        betas1 = make_linear_betas(self.T1, beta_start, beta_end, device)
        betas2 = make_linear_betas(self.T2, beta_start, beta_end, device)

        # First compute a_bar1[T1] and a_bar2[T2] without any rescaling
        a_bar1 = cumprod_from_betas(betas1)  # shape [T1]
        a_bar2 = cumprod_from_betas(betas2)  # shape [T2]

        # Rescale the Phase-2 betas so that a_bar1[T1] * a_bar2'[T2] equals the target a_bar_T
        target_a2 = a_bar_target_T / (a_bar1[-1].item() + 1e-12)
        betas2 = scale_betas_to_target_cumprod(betas2, target_a2)

        # Recompute
        # a_bar1 = cumprod_from_betas(betas1).float()  # [T1]
        a_bar2 = cumprod_from_betas(betas2).float()  # [T2]

        self.register_buffer("betas1", betas1.float())
        self.register_buffer("betas2", betas2.float())
        self.register_buffer("alphas1", 1.0 - betas1.float())
        self.register_buffer("alphas2", 1.0 - betas2.float())
        self.register_buffer("a_bar1", a_bar1)
        self.register_buffer("a_bar2", a_bar2)
        self.register_buffer("a_bar_target_T", torch.tensor(a_bar_target_T, dtype=torch.float32))

        # Backbone: token = channel
        self.backbone = DiTChannelTokens(L=self.L, C=self.C, dim=model_dim, depth=depth, heads=heads)

    # ------------------------ Internal: build mask & a_bar_{t,i} ------------------------
    def _mask_future(self, B, device):
        # mask: future region = 1, history = 0, shape [B, L, C] (aligned with the network input [B, L, C])
        m = torch.zeros(B, self.L, self.C, device=device)
        m[:, self.P:, :] = 1.0
        return m

    def _a_bar_map_at_t(self, t_scalar: int, B: int, device, mask_future: torch.Tensor):
        """
        Build the per-position a_bar_{t,i}, shape [B, L, C]
        - if t <= T1: the future region uses a_bar1[t], the history region = 1
        - if t >  T1: the future region stays at a_bar1[T1], the history region uses a_bar2[t - T1]
        """
        if t_scalar <= self.T1:
            a_future = self.a_bar1[t_scalar - 1]  # 0-based indexing
            a_past = torch.tensor(1.0, device=device)
        else:
            a_future = self.a_bar1[-1]
            a_past = self.a_bar2[t_scalar - self.T1 - 1]

        a_future_map = torch.full((B, self.L, self.C), float(a_future.item()), device=device)
        a_past_map = torch.full((B, self.L, self.C), float(a_past.item()), device=device)
        a_map = a_past_map * (1 - mask_future) + a_future_map * mask_future
        return a_map  # [B, L, C]

    # ----------------------------- Training forward pass -----------------------------
    def forward(self, x_hist: torch.Tensor, x_future: torch.Tensor) -> Tuple[torch.Tensor, dict]:
        """
        x_hist   : [B, P, C]
        x_future : [B, H, C]
        Training: sample t in [1..T], build the two-phase a_bar_{t,i}, apply marginal noising to get x_t,
        and predict epsilon with the channel-token DiT.
        """
        B = x_hist.size(0)
        device = x_hist.device
        x0 = torch.cat([x_hist, x_future], dim=1)  # [B, L, C]

        # Sample the training step t (1..T)
        t = torch.randint(1, self.T + 1, (B,), device=device, dtype=torch.long)

        # Build the mask and the per-position a_bar_{t,i}
        mask_fut = self._mask_future(B, device)  # [B, L, C]
        # Build a_bar_{t,i} per sample (each sample has its own t, so a loop or a vectorization
        # trick is needed; B is usually small, so a for loop is fine)
        a_bar_map = torch.stack([self._a_bar_map_at_t(int(tt.item()), 1, device, mask_fut[0:1])
                                 for tt in t], dim=0).squeeze(1)  # [B,L,C]

        # Marginal noising
        eps = torch.randn_like(x0)  # [B,L,C]
        x_t = a_bar_map.sqrt() * x0 + (1.0 - a_bar_map).sqrt() * eps

        # Spatial noise embedding: determined entirely by the schedule.
        # Pass each position's sqrt(a_bar) and/or sqrt(1 - a_bar); sqrt(a_bar) is used here.
        noise_feat = a_bar_map.sqrt()  # [B,L,C]

        # Predict epsilon
        eps_pred = self.backbone(x_t, noise_feat)  # [B,L,C]

        loss = F.mse_loss(eps_pred, eps)
        return loss, {'t_mean': t.float().mean().item()}

    # ----------------------------- Sampling / inference -----------------------------
    @torch.no_grad()
    def sample(self, x_hist: torch.Tensor, steps: Optional[int] = None) -> torch.Tensor:
        """
        Phase-1-only inference: t = T1..1; only the future region is updated, the history keeps the observed values.
        x_hist : [B,P,C]
        return : [B,H,C]
        """
        B = x_hist.size(0)
        device = x_hist.device
        mask_fut = self._mask_future(B, device)  # [B,L,C]

        # Initialize x: history = observations, future = Gaussian noise
        x = torch.zeros(B, self.L, self.C, device=device)
        x[:, :self.P, :] = x_hist
        x[:, self.P:, :] = torch.randn(B, self.H, self.C, device=device)

        # Support step sub-sampling: uniformly downsample [T1..1] to `steps`
        T1 = self.T1
        steps = steps if steps is not None else T1
        steps = max(1, min(steps, T1))
        ts = torch.linspace(T1, 1, steps, device=device).long().tolist()
        # The DDPM update needs alpha_t and beta_t (defined only over the future region)
        alphas1 = self.alphas1  # [T1]
        betas1 = self.betas1
        a_bar1 = self.a_bar1

        for idx, t_scalar in enumerate(ts):
            # Current a_bar_{t,i} (Phase-1: history = 1, future = a_bar1[t])
            a_bar_map = self._a_bar_map_at_t(int(t_scalar), B, device, mask_fut)  # [B,L,C]
            # Network conditioning: use sqrt(a_bar) as the noise embedding
            noise_feat = a_bar_map.sqrt()
            # Predict epsilon
            eps_pred = self.backbone(x, noise_feat)  # [B,L,C]

            # One DDPM step over the future region (the history region keeps its values);
            # standard DDPM formulas (all positions in the future region share the same alpha_t, beta_t)
            t_idx = t_scalar - 1
            alpha_t = alphas1[t_idx]  # scalar
            beta_t = betas1[t_idx]
            a_bar_t = a_bar1[t_idx]
            if t_scalar > 1:
                a_bar_prev = a_bar1[t_idx - 1]
            else:
                a_bar_prev = torch.tensor(1.0, device=device)

            # x0 prediction (only used to derive the mean; the mu formula could be used directly instead)
            x0_pred = (x - (1.0 - a_bar_t).sqrt() * eps_pred) / (a_bar_t.sqrt() + 1e-8)

            # Mean: mu_t = 1/sqrt(alpha_t) * (x_t - beta_t / sqrt(1 - a_bar_t) * eps_hat)
            mean = (x - (beta_t / (1.0 - a_bar_t).sqrt()) * eps_pred) / (alpha_t.sqrt() + 1e-8)

            # Sampling noise
            if t_scalar > 1:
                z = torch.randn_like(x)
            else:
                z = torch.zeros_like(x)

            # Variance term (DDPM): sigma_t = sqrt(beta_t)
            x_next = mean + z * beta_t.sqrt()

            # Replace the future region only
            x = x * (1 - mask_fut) + x_next * mask_fut
            # Force the history back to the observations
            x[:, :self.P, :] = x_hist

        return x[:, self.P:, :]  # [B,H,C]
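End-to-end usage might look roughly as follows (a sketch with made-up sizes; the actual training loop lives in train/train_diffusion.py):

model = RADChannelDiT(past_len=96, future_len=24, channels=7,
                      T=100, T1_ratio=0.7, model_dim=64, depth=2, heads=4)
x_hist = torch.randn(4, 96, 7)       # observed history     [B, P, C]
x_future = torch.randn(4, 24, 7)     # ground-truth future  [B, H, C]

loss, info = model(x_hist, x_future)        # training step: epsilon-prediction MSE
loss.backward()

forecast = model.sample(x_hist, steps=20)   # Phase-1 only, with sub-sampled steps
assert forecast.shape == (4, 24, 7)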
models/iTransformer/iTransformer.py  (new file, 132 lines)
@@ -0,0 +1,132 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from layers.Transformer_EncDec import Encoder, EncoderLayer
from layers.SelfAttention_Family import FullAttention, AttentionLayer
from layers.Embed import DataEmbedding_inverted
import numpy as np


class Model(nn.Module):
    """
    Paper link: https://arxiv.org/abs/2310.06625
    """

    def __init__(self, configs):
        super(Model, self).__init__()
        self.task_name = configs.task_name
        self.seq_len = configs.seq_len
        self.pred_len = configs.pred_len
        # Embedding
        self.enc_embedding = DataEmbedding_inverted(configs.seq_len, configs.d_model, configs.embed, configs.freq,
                                                    configs.dropout)
        # Encoder
        self.encoder = Encoder(
            [
                EncoderLayer(
                    AttentionLayer(
                        FullAttention(False, configs.factor, attention_dropout=configs.dropout,
                                      output_attention=False), configs.d_model, configs.n_heads),
                    configs.d_model,
                    configs.d_ff,
                    dropout=configs.dropout,
                    activation=configs.activation
                ) for l in range(configs.e_layers)
            ],
            norm_layer=torch.nn.LayerNorm(configs.d_model)
        )
        # Decoder
        if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
            self.projection = nn.Linear(configs.d_model, configs.pred_len, bias=True)
        if self.task_name == 'imputation':
            self.projection = nn.Linear(configs.d_model, configs.seq_len, bias=True)
        if self.task_name == 'anomaly_detection':
            self.projection = nn.Linear(configs.d_model, configs.seq_len, bias=True)
        if self.task_name == 'classification':
            self.act = F.gelu
            self.dropout = nn.Dropout(configs.dropout)
            self.projection = nn.Linear(configs.d_model * configs.enc_in, configs.num_class)

    def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
        # Normalization from Non-stationary Transformer
        means = x_enc.mean(1, keepdim=True).detach()
        x_enc = x_enc - means
        stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
        x_enc /= stdev

        _, _, N = x_enc.shape

        # Embedding
        enc_out = self.enc_embedding(x_enc, x_mark_enc)
        enc_out, attns = self.encoder(enc_out, attn_mask=None)

        dec_out = self.projection(enc_out).permute(0, 2, 1)[:, :, :N]
        # De-Normalization from Non-stationary Transformer
        dec_out = dec_out * (stdev[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1))
        dec_out = dec_out + (means[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1))
        return dec_out

    def imputation(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask):
        # Normalization from Non-stationary Transformer
        means = x_enc.mean(1, keepdim=True).detach()
        x_enc = x_enc - means
        stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
        x_enc /= stdev

        _, L, N = x_enc.shape

        # Embedding
        enc_out = self.enc_embedding(x_enc, x_mark_enc)
        enc_out, attns = self.encoder(enc_out, attn_mask=None)

        dec_out = self.projection(enc_out).permute(0, 2, 1)[:, :, :N]
        # De-Normalization from Non-stationary Transformer
        dec_out = dec_out * (stdev[:, 0, :].unsqueeze(1).repeat(1, L, 1))
        dec_out = dec_out + (means[:, 0, :].unsqueeze(1).repeat(1, L, 1))
        return dec_out

    def anomaly_detection(self, x_enc):
        # Normalization from Non-stationary Transformer
        means = x_enc.mean(1, keepdim=True).detach()
        x_enc = x_enc - means
        stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
        x_enc /= stdev

        _, L, N = x_enc.shape

        # Embedding
        enc_out = self.enc_embedding(x_enc, None)
        enc_out, attns = self.encoder(enc_out, attn_mask=None)

        dec_out = self.projection(enc_out).permute(0, 2, 1)[:, :, :N]
        # De-Normalization from Non-stationary Transformer
        dec_out = dec_out * (stdev[:, 0, :].unsqueeze(1).repeat(1, L, 1))
        dec_out = dec_out + (means[:, 0, :].unsqueeze(1).repeat(1, L, 1))
        return dec_out

    def classification(self, x_enc, x_mark_enc):
        # Embedding
        enc_out = self.enc_embedding(x_enc, None)
        enc_out, attns = self.encoder(enc_out, attn_mask=None)

        # Output
        output = self.act(enc_out)  # the output transformer encoder/decoder embeddings don't include non-linearity
        output = self.dropout(output)
        output = output.reshape(output.shape[0], -1)  # (batch_size, c_in * d_model)
        output = self.projection(output)  # (batch_size, num_classes)
        return output

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
        if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
            dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
            return dec_out[:, -self.pred_len:, :]  # [B, L, D]
        if self.task_name == 'imputation':
            dec_out = self.imputation(x_enc, x_mark_enc, x_dec, x_mark_dec, mask)
            return dec_out  # [B, L, D]
        if self.task_name == 'anomaly_detection':
            dec_out = self.anomaly_detection(x_enc)
            return dec_out  # [B, L, D]
        if self.task_name == 'classification':
            dec_out = self.classification(x_enc, x_mark_enc)
            return dec_out  # [B, N]
        return None
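The constructor only reads attributes off a configs object, so a minimal forecasting setup can be sketched as follows (field values are illustrative, not repository defaults):

from types import SimpleNamespace

configs = SimpleNamespace(
    task_name='long_term_forecast', seq_len=96, pred_len=24,
    d_model=128, embed='timeF', freq='h', dropout=0.1,
    factor=1, n_heads=8, d_ff=256, activation='gelu', e_layers=2,
    enc_in=7, num_class=0,
)
model = Model(configs)
# forward(x_enc, x_mark_enc, x_dec, x_mark_dec) with x_enc of shape [B, seq_len, enc_in]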
models/xPatch_SparseChannel/__init__.py  (new file, 3 lines)
@@ -0,0 +1,3 @@
from .xPatch import Model

__all__ = ['Model']
models/xPatch_SparseChannel/patchtst_ci.py  (new file, 66 lines)
@@ -0,0 +1,66 @@
"""
SeasonPatch = PatchTST (CI) + ChannelGraphMixer + linear prediction head
"""
import torch
import torch.nn as nn
from layers.PatchTST_layers import positional_encoding   # already exists in the repo
from layers.mixer import HierarchicalGraphMixer          # created alongside this model (layers/mixer.py)

# ------ PatchTST CI encoder (equivalent to the official implementation, but without the head, so the Mixer can be inserted) ------
class _Encoder(nn.Module):
    def __init__(self, c_in, seq_len, patch_len, stride,
                 d_model=128, n_layers=3, n_heads=16, dropout=0.):
        super().__init__()
        self.patch_len = patch_len
        self.stride = stride
        self.patch_num = (seq_len - patch_len) // stride + 1

        self.proj = nn.Linear(patch_len, d_model)
        self.pos = positional_encoding('zeros', True, self.patch_num, d_model)
        self.drop = nn.Dropout(dropout)

        from layers.PatchTST_backbone import TSTEncoder   # same name as the official implementation
        self.encoder = TSTEncoder(self.patch_num, d_model, n_heads,
                                  d_ff=d_model*2, dropout=dropout,
                                  n_layers=n_layers, norm='LayerNorm')

    def forward(self, x):                                  # [B,C,L]
        x = x.unfold(-1, self.patch_len, self.stride)      # [B,C,patch_num,patch_len]
        B, C, N, P = x.shape
        z = self.proj(x)                                   # [B,C,N,d]
        z = z.contiguous().view(B*C, N, -1)                # [B*C,N,d]
        z = self.drop(z + self.pos)
        z = self.encoder(z)                                # [B*C,N,d]
        return z.view(B, C, N, -1)                         # [B,C,N,d]

# ------------------------- SeasonPatch -----------------------------
class SeasonPatch(nn.Module):
    def __init__(self,
                 c_in: int,
                 seq_len: int,
                 pred_len: int,
                 patch_len: int,
                 stride: int,
                 k_graph: int = 5,
                 d_model: int = 128,
                 revin: bool = True):
        super().__init__()

        self.encoder = _Encoder(c_in, seq_len, patch_len, stride,
                                d_model=d_model)
        self.mixer = HierarchicalGraphMixer(c_in, dim=d_model, k=k_graph)

        # Calculate the actual number of patches
        patch_num = (seq_len - patch_len) // stride + 1
        self.head = nn.Linear(patch_num * d_model, pred_len)

    def forward(self, x):                          # x [B,L,C]
        x = x.permute(0, 2, 1)                     # -> [B,C,L]
        z = self.encoder(x)                        # [B,C,N,d]
        B, C, N, D = z.shape
        # Channel-independent encoding -> sparse cross-channel injection
        z_mix = self.mixer(z).view(B, C, N*D)      # [B,C,N*d]
        y_pred = self.head(z_mix)                  # [B,C,T]

        return y_pred
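A brief shape sketch of SeasonPatch (sizes are illustrative; HierarchicalGraphMixer and TSTEncoder come from the repository's layers package):

# patch_num = (seq_len - patch_len) // stride + 1 = (96 - 16) // 8 + 1 = 11 patches per channel
season = SeasonPatch(c_in=7, seq_len=96, pred_len=24, patch_len=16, stride=8)
x = torch.randn(4, 96, 7)     # [B, L, C]
y = season(x)                 # [B, C, pred_len] -> (4, 7, 24)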
models/xPatch_SparseChannel/xPatch.py  (new file, 56 lines)
@@ -0,0 +1,56 @@
import torch
import torch.nn as nn
import math

from layers.decomp import DECOMP
from .xPatch_SparseChannel import Network
from layers.revin import RevIN

class Model(nn.Module):
    def __init__(self, configs):
        super(Model, self).__init__()

        # Parameters
        seq_len = configs.seq_len      # lookback window L
        pred_len = configs.pred_len    # prediction length (96, 192, 336, 720)
        c_in = configs.enc_in          # input channels

        # Patching
        patch_len = configs.patch_len
        stride = configs.stride
        padding_patch = configs.padding_patch

        # Normalization
        self.revin = configs.revin
        self.revin_layer = RevIN(c_in, affine=True, subtract_last=False)

        # Moving Average
        self.ma_type = configs.ma_type
        alpha = configs.alpha          # smoothing factor for EMA (Exponential Moving Average)
        beta = configs.beta            # smoothing factor for DEMA (Double Exponential Moving Average)

        self.decomp = DECOMP(self.ma_type, alpha, beta)
        self.net = Network(seq_len, pred_len, patch_len, stride, padding_patch, c_in)
        # self.net_mlp = NetworkMLP(seq_len, pred_len)  # For ablation study with MLP-only stream
        # self.net_cnn = NetworkCNN(seq_len, pred_len, patch_len, stride, padding_patch)  # For ablation study with CNN-only stream

    def forward(self, x):
        # x: [Batch, Input, Channel]

        # Normalization
        if self.revin:
            x = self.revin_layer(x, 'norm')

        if self.ma_type == 'reg':      # If no decomposition, directly pass the input to the network
            x = self.net(x, x)
            # x = self.net_mlp(x)      # For ablation study with MLP-only stream
            # x = self.net_cnn(x)      # For ablation study with CNN-only stream
        else:
            seasonal_init, trend_init = self.decomp(x)
            x = self.net(seasonal_init, trend_init)

        # Denormalization
        if self.revin:
            x = self.revin_layer(x, 'denorm')

        return x
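As elsewhere, everything is read from a configs object; a hedged sketch of the fields this wrapper expects (values are illustrative, and the ma_type value is an assumption about what DECOMP accepts; 'reg' would skip the decomposition and feed x into both streams):

from types import SimpleNamespace

configs = SimpleNamespace(
    seq_len=96, pred_len=24, enc_in=7,
    patch_len=16, stride=8, padding_patch='end',
    revin=True, ma_type='ema', alpha=0.3, beta=0.3,
)
model = Model(configs)   # forward expects x of shape [Batch, Input, Channel]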
models/xPatch_SparseChannel/xPatch_SparseChannel.py  (new file, 63 lines)
@@ -0,0 +1,63 @@
import torch
from torch import nn
from .patchtst_ci import SeasonPatch  # <<< new import

class Network(nn.Module):
    """
    trend  : the original MLP linear stream (kept unchanged)
    season : SeasonPatch (PatchTST + Mixer)
    """
    def __init__(self, seq_len, pred_len,
                 patch_len, stride, padding_patch,
                 c_in):
        super().__init__()

        # -------- Seasonal stream ---------------
        self.season_net = SeasonPatch(c_in=c_in,
                                      seq_len=seq_len,
                                      pred_len=pred_len,
                                      patch_len=patch_len,
                                      stride=stride,
                                      k_graph=5,
                                      d_model=128)

        # --------- Linear trend stream (original code, unchanged) ----------
        self.pred_len = pred_len
        self.fc5 = nn.Linear(seq_len, pred_len * 4)
        self.avgpool1 = nn.AvgPool1d(kernel_size=2)
        self.ln1 = nn.LayerNorm(pred_len * 2)

        self.fc6 = nn.Linear(pred_len * 2, pred_len)
        self.avgpool2 = nn.AvgPool1d(kernel_size=2)
        self.ln2 = nn.LayerNorm(pred_len // 2)

        self.fc7 = nn.Linear(pred_len // 2, pred_len)

        # Fuse the concatenated stream outputs
        self.fc_final = nn.Linear(pred_len * 2, pred_len)

    # ---------------- forward --------------------
    def forward(self, s, t):
        # input shape: [B,L,C]
        B, L, C = s.shape

        # ---------- Seasonality ------------
        y_season = self.season_net(s)               # [B,C,T]

        # ---------- Trend (original MLP) ----------
        t = t.permute(0, 2, 1).reshape(B*C, L)      # [B*C,L]
        t = self.fc5(t)
        t = self.avgpool1(t)
        t = self.ln1(t)
        t = self.fc6(t)
        t = self.avgpool2(t)
        t = self.ln2(t)
        t = self.fc7(t)                             # [B*C,T]
        y_trend = t.view(B, C, -1)                  # [B,C,T]

        # --------- Concatenate & project --------------
        y = torch.cat([y_season, y_trend], dim=-1)  # [B,C,2T]
        y = self.fc_final(y)                        # [B,C,T]
        y = y.permute(0, 2, 1)                      # [B,T,C]
        return y
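For reference, the trend stream's widths track as follows (a worked trace assuming pred_len T = 24 and seq_len L = 96):

# fc5      : L = 96  -> 4*T = 96
# avgpool1 : 96      -> 2*T = 48   (kernel_size=2 halves the length)
# ln1      : LayerNorm over 2*T = 48
# fc6      : 48      -> T = 24
# avgpool2 : 24      -> T//2 = 12
# ln2      : LayerNorm over T//2 = 12
# fc7      : 12      -> T = 24
# The seasonal and trend outputs [B,C,T] are concatenated to [B,C,2T],
# fc_final maps them back to [B,C,T], and the final permute yields [B,T,C].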
train/train.py  (modified)
@@ -8,6 +8,8 @@ from torch.utils.data import DataLoader, TensorDataset
 import swanlab
 from typing import Dict, Any, Optional, Callable, Union, Tuple
 from dataflow import data_provider
+from layers.ps_loss import PSLoss
+from utils.tools import adjust_learning_rate, dotdict

 class EarlyStopping:
     """Early stopping to stop training when validation performance doesn't improve."""
@@ -138,7 +140,9 @@ def create_data_loaders_from_dataflow(args, use_x_mark: bool = True) -> Dict[str
         'test': test_loader
     }

-def create_data_loaders(data_path: str, batch_size: int = 32, use_x_mark: bool = True) -> Dict[str, DataLoader]:
+def create_data_loaders(data_path: str, batch_size: int = 32, use_x_mark: bool = True,
+                        num_workers: int = 4, pin_memory: bool = True,
+                        persistent_workers: bool = True) -> Dict[str, DataLoader]:
     """
     Create PyTorch DataLoaders from an NPZ file

@@ -146,6 +150,9 @@ def create_data_loaders(data_path: str, batch_size: int = 32, use_x_mark: bool =
         data_path (str): Path to the NPZ file containing the data
         batch_size (int): Batch size for the DataLoaders
         use_x_mark (bool): Whether to use time features (x_mark) from the data file
+        num_workers (int): Number of worker processes for data loading
+        pin_memory (bool): Whether to pin memory for faster GPU transfer
+        persistent_workers (bool): Whether to keep workers alive between epochs

     Returns:
         Dict[str, DataLoader]: Dictionary with train, val, and test DataLoaders
@@ -200,10 +207,34 @@ def create_data_loaders(data_path: str, batch_size: int = 32, use_x_mark: bool =
     val_dataset = TensorDataset(val_x, val_y)
     test_dataset = TensorDataset(test_x, test_y)

-    # Create dataloaders
-    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
-    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
-    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
+    # Create dataloaders with performance optimizations
+    train_loader = DataLoader(
+        train_dataset,
+        batch_size=batch_size,
+        shuffle=True,
+        num_workers=num_workers,
+        pin_memory=pin_memory,
+        persistent_workers=persistent_workers if num_workers > 0 else False,
+        drop_last=True  # Drop incomplete batches for training
+    )
+    val_loader = DataLoader(
+        val_dataset,
+        batch_size=batch_size,
+        shuffle=False,
+        num_workers=num_workers,
+        pin_memory=pin_memory,
+        persistent_workers=persistent_workers if num_workers > 0 else False,
+        drop_last=False
+    )
+    test_loader = DataLoader(
+        test_dataset,
+        batch_size=batch_size,
+        shuffle=False,
+        num_workers=num_workers,
+        pin_memory=pin_memory,
+        persistent_workers=persistent_workers if num_workers > 0 else False,
+        drop_last=False
+    )

     return {
         'train': train_loader,
@@ -223,7 +254,12 @@ def train_forecasting_model(
     log_interval: int = 10,
     use_x_mark: bool = True,
     dataset_mode: str = "npz",
-    dataflow_args = None
+    dataflow_args = None,
+    use_ps_loss: bool = False,
+    ps_lambda: float = 5.0,
+    patch_len_threshold: int = 64,
+    use_gdw: bool = True,
+    lr_adjust_strategy: str = "type1"
 ) -> Tuple[nn.Module, Dict[str, float]]:
     """
     Train a time series forecasting model
@@ -241,6 +277,11 @@
         use_x_mark (bool): Whether to use time features (x_mark) from the data file
         dataset_mode (str): Dataset construction mode - "npz" or "dataflow"
         dataflow_args: Arguments object for dataflow mode (required if dataset_mode="dataflow")
+        use_ps_loss (bool): Whether to use Patch-wise Structural (PS) loss instead of MSE
+        ps_lambda (float): Weight for PS loss component when combined with MSE
+        patch_len_threshold (int): Maximum patch length for adaptive patching
+        use_gdw (bool): Whether to use Gradient-based Dynamic Weighting
+        lr_adjust_strategy (str): Learning rate adjustment strategy - 'type1', 'type2', 'type3', 'sigmoid', 'constant', '3', '4', '5', '6'

     Returns:
         Tuple[nn.Module, Dict[str, float]]: Trained model and dictionary of evaluation metrics
@@ -271,7 +312,10 @@
     dataloaders = create_data_loaders(
         data_path=data_path,
         batch_size=config.get('batch_size', 32),
-        use_x_mark=use_x_mark
+        use_x_mark=use_x_mark,
+        num_workers=config.get('num_workers', 4),
+        pin_memory=config.get('pin_memory', True),
+        persistent_workers=config.get('persistent_workers', True)
     )

     # Construct the model
@@ -279,14 +323,24 @@
     model = model.to(device)

     # Define loss function and optimizer
-    criterion = nn.MSELoss()
+    if use_ps_loss:
+        criterion = PSLoss(
+            patch_len_threshold=patch_len_threshold,
+            lambda_ps=ps_lambda,
+            use_gdw=use_gdw
+        )
+    else:
+        criterion = nn.MSELoss()
     optimizer = optim.Adam(
         model.parameters(),
         lr=config.get('learning_rate', 1e-3),
     )

-    # Add learning rate scheduler to halve LR after each epoch
-    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)
+    # Create args object for learning rate adjustment
+    lr_args = dotdict({
+        'learning_rate': config.get('learning_rate', 1e-3),
+        'lradj': lr_adjust_strategy
+    })

     # Initialize early stopping
     early_stopping = EarlyStopping(
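Note that the two criteria have different call conventions in the loops below: nn.MSELoss returns a scalar tensor, while PSLoss (as used here) also takes the model and returns a (loss, components) pair. A hypothetical helper that mirrors this dual interface, shown only for clarity and not part of the commit:

def compute_loss(criterion, outputs, targets, model, use_ps_loss):
    # Hypothetical sketch mirroring how `criterion` is called in the training/validation loops.
    if use_ps_loss:
        loss, loss_dict = criterion(outputs, targets, model)   # PSLoss interface
        return loss, loss_dict
    return criterion(outputs, targets), None                   # plain MSE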
@@ -334,7 +388,11 @@
                 # For simple models without time features
                 outputs = model(inputs)

-            loss = criterion(outputs, targets)
+            # Calculate loss
+            if use_ps_loss:
+                loss, loss_dict = criterion(outputs, targets, model)
+            else:
+                loss = criterion(outputs, targets)

             # Backward pass and optimize
             loss.backward()
@@ -345,10 +403,26 @@
             interval_loss += loss.item()

             if (batch_idx + 1) % log_interval == 0:
-                print(f"Batch {batch_idx+1}/{len(dataloaders['train'])}, Loss: {loss.item():.4f}")
-                # Compute the average loss over this interval and log it
-                avg_interval_loss = interval_loss / log_interval
-                swanlab_run.log({"batch_train_loss": avg_interval_loss})
+                if use_ps_loss and 'loss_dict' in locals():
+                    print(f"Batch {batch_idx+1}/{len(dataloaders['train'])}, "
+                          f"Total Loss: {loss.item():.4f}, "
+                          f"MSE: {loss_dict['mse_loss']:.4f}, "
+                          f"PS: {loss_dict['ps_loss']:.4f}")
+                    # Log detailed loss components
+                    swanlab_run.log({
+                        "batch_total_loss": loss.item(),
+                        "batch_mse_loss": loss_dict['mse_loss'],
+                        "batch_ps_loss": loss_dict['ps_loss'],
+                        "batch_corr_loss": loss_dict['corr_loss'],
+                        "batch_var_loss": loss_dict['var_loss'],
+                        "batch_mean_loss": loss_dict['mean_loss'],
+                        "alpha": loss_dict['alpha'],
+                        "beta": loss_dict['beta'],
+                        "gamma": loss_dict['gamma']
+                    })
+                else:
+                    print(f"Batch {batch_idx+1}/{len(dataloaders['train'])}, Loss: {loss.item():.4f}")
+                    swanlab_run.log({"batch_train_loss": loss.item()})

                 # Reset the interval loss for the next window
                 interval_loss = 0.0
@@ -360,6 +434,7 @@
         model.eval()
         val_loss = 0.0
         val_mse = 0.0
+        val_mse_criterion = nn.MSELoss()  # Always use MSE for validation metrics

         with torch.no_grad():
             for batch_data in dataloaders['val']:
@@ -381,18 +456,28 @@
                     # For simple models without time features
                     outputs = model(inputs)

-                # Calculate loss
-                loss = criterion(outputs, targets)
-                val_loss += loss.item()
+                # Calculate training loss (PS or MSE)
+                if use_ps_loss:
+                    loss, _ = criterion(outputs, targets, model)
+                    val_loss += loss.item()
+                else:
+                    loss = criterion(outputs, targets)
+                    val_loss += loss.item()
+
+                # Always calculate MSE for validation metrics
+                mse_loss = val_mse_criterion(outputs, targets)
+                val_mse += mse_loss.item()

         avg_val_loss = val_loss / len(dataloaders['val'])
+        avg_val_mse = val_mse / len(dataloaders['val'])
         current_lr = optimizer.param_groups[0]['lr']

         # Log metrics
         metrics_dict = {
             "train_loss": avg_train_loss,
             "val_loss": avg_val_loss,
+            "val_mse": avg_val_mse,
             "learning_rate": current_lr,
             "epoch_time": epoch_time
         }
@@ -402,6 +487,7 @@
         print(f"Epoch {epoch+1}/{max_epochs}, "
               f"Train Loss: {avg_train_loss:.4f}, "
               f"Val Loss: {avg_val_loss:.4f}, "
+              f"Val MSE: {avg_val_mse:.4f}, "
               f"LR: {current_lr:.6f}, "
               f"Time: {epoch_time:.2f}s")

@@ -416,16 +502,17 @@
             print("Early stopping triggered")
             break

-        # Step the learning rate scheduler
-        scheduler.step()
+        # Adjust learning rate using utils.tools function
+        adjust_learning_rate(optimizer, epoch, lr_args)

     # Load the best model
     model.load_state_dict(torch.load(checkpoint_path))

-    # Test evaluation on the best model
+    # Test evaluation on the best model - Always use MSE for final evaluation
     model.eval()
     test_loss = 0.0
     test_mse = 0.0
+    mse_criterion = nn.MSELoss()  # Always use MSE for test evaluation

     print("Evaluating on test set...")
     with torch.no_grad():
@@ -448,16 +535,16 @@
                     # For simple models without time features
                     outputs = model(inputs)

-                # Calculate loss
-                loss = criterion(outputs, targets)
-                test_loss += loss.item()
+                # Always calculate MSE for test evaluation (for fair comparison)
+                mse_loss = mse_criterion(outputs, targets)
+                test_loss += mse_loss.item()

     test_loss /= len(dataloaders['test'])

     print(f"Test evaluation completed!")
     print(f"Test Loss (MSE): {test_loss:.6f}")

-    # Final validation for consistency
+    # Final validation for consistency - Always use MSE for final metrics
     model.eval()
     final_val_loss = 0.0
     final_val_mse = 0.0
@@ -482,25 +569,31 @@
                     # For simple models without time features
                     outputs = model(inputs)

-                # Calculate loss
-                loss = criterion(outputs, targets)
-                final_val_loss += loss.item()
+                # Always calculate MSE for final validation (for fair comparison)
+                mse_loss = mse_criterion(outputs, targets)
+                final_val_loss += mse_loss.item()

     final_val_loss /= len(dataloaders['val'])

-    print(f"Final validation loss: {final_val_loss:.6f}")
+    print(f"Final validation MSE: {final_val_loss:.6f}")
+    print(f"Final test MSE: {test_loss:.6f}")
+
+    if use_ps_loss:
+        print("Note: Model was trained with PS Loss but evaluated with MSE for fair comparison")

     # Log final test results to swanlab
     final_metrics = {
-        "final_test_loss": test_loss,
-        "final_val_loss": final_val_loss
+        "final_test_mse": test_loss,
+        "final_val_mse": final_val_loss
     }
     swanlab_run.log(final_metrics)

-    # Update metrics with final values
+    # Update metrics with final values (always MSE for comparison)
     metrics["final_val_loss"] = final_val_loss
     metrics["final_test_loss"] = test_loss
+    metrics["final_val_mse"] = final_val_loss  # Same as final_val_loss since we use MSE
+    metrics["final_test_mse"] = test_loss  # Same as final_test_loss since we use MSE

     # Finish the swanlab run
     swanlab_run.finish()
@@ -519,7 +612,8 @@ def train_classification_model(
     log_interval: int = 10,
     use_x_mark: bool = True,
     dataset_mode: str = "npz",
-    dataflow_args = None
+    dataflow_args = None,
+    lr_adjust_strategy: str = "type1"
 ) -> Tuple[nn.Module, Dict[str, float]]:
     """
     Train a time series classification model
@@ -537,6 +631,7 @@
         use_x_mark (bool): Whether to use time features (x_mark) from the data file
         dataset_mode (str): Dataset construction mode - "npz" or "dataflow"
         dataflow_args: Arguments object for dataflow mode (required if dataset_mode="dataflow")
+        lr_adjust_strategy (str): Learning rate adjustment strategy - 'type1', 'type2', 'type3', 'sigmoid', 'constant', '3', '4', '5', '6'

     Returns:
         Tuple[nn.Module, Dict[str, float]]: Trained model and dictionary of evaluation metrics
@@ -567,7 +662,10 @@
     dataloaders = create_data_loaders(
         data_path=data_path,
         batch_size=config.get('batch_size', 32),
-        use_x_mark=use_x_mark
+        use_x_mark=use_x_mark,
+        num_workers=config.get('num_workers', 4),
+        pin_memory=config.get('pin_memory', True),
+        persistent_workers=config.get('persistent_workers', True)
     )

     # Construct the model
@@ -582,8 +680,11 @@
         weight_decay=config.get('weight_decay', 1e-4)
     )

-    # Add learning rate scheduler to halve LR after each epoch
-    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5)
+    # Create args object for learning rate adjustment
+    lr_args = dotdict({
+        'learning_rate': config.get('learning_rate', 1e-3),
+        'lradj': lr_adjust_strategy
+    })

     # Initialize early stopping
     early_stopping = EarlyStopping(
@@ -722,8 +823,8 @@ def train_classification_model(
             print("Early stopping triggered")
             break

-        # Step the learning rate scheduler
-        scheduler.step()
+        # Adjust learning rate using utils.tools function
+        adjust_learning_rate(optimizer, epoch, lr_args)

     # Load the best model
     model.load_state_dict(torch.load(checkpoint_path))
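Taken together, the new behaviour is driven by keyword arguments to train_forecasting_model. A hedged usage sketch (the leading arguments are not shown in this diff and are assumed here; the data path and config values are placeholders):

model, metrics = train_forecasting_model(
    # Leading arguments assumed from context, not visible in this hunk.
    model_constructor=lambda: build_model(),      # hypothetical constructor
    data_path='data/dataset.npz',                 # placeholder NPZ path
    project_name='forecast-ps-loss',
    config={'batch_size': 64, 'learning_rate': 1e-3,
            'num_workers': 4, 'pin_memory': True, 'persistent_workers': True},
    use_ps_loss=True,            # switch from plain MSE to PS loss for training
    ps_lambda=5.0,
    patch_len_threshold=64,
    use_gdw=True,
    lr_adjust_strategy='type1',  # passed to adjust_learning_rate via the dotdict lr_args
)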
train/train_diffusion.py  (new file, 408 lines)
@@ -0,0 +1,408 @@
|
|||||||
|
import os
|
||||||
|
import time
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.optim as optim
|
||||||
|
from torch.utils.data import DataLoader, TensorDataset
|
||||||
|
import swanlab
|
||||||
|
from typing import Dict, Any, Optional, Callable, Union, Tuple
|
||||||
|
from dataflow import data_provider
|
||||||
|
|
||||||
|
class EarlyStopping:
|
||||||
|
"""Early stopping to stop training when validation performance doesn't improve."""
|
||||||
|
def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt'):
|
||||||
|
self.patience = patience
|
||||||
|
self.verbose = verbose
|
||||||
|
self.counter = 0
|
||||||
|
self.best_score = None
|
||||||
|
self.early_stop = False
|
||||||
|
self.val_loss_min = float('inf')
|
||||||
|
self.delta = delta
|
||||||
|
self.path = path
|
||||||
|
|
||||||
|
def __call__(self, val_loss, model):
|
||||||
|
score = -val_loss
|
||||||
|
|
||||||
|
if self.best_score is None:
|
||||||
|
self.best_score = score
|
||||||
|
self.save_checkpoint(val_loss, model)
|
||||||
|
elif score < self.best_score + self.delta:
|
||||||
|
self.counter += 1
|
||||||
|
if self.verbose:
|
||||||
|
print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
|
||||||
|
if self.counter >= self.patience:
|
||||||
|
self.early_stop = True
|
||||||
|
else:
|
||||||
|
self.best_score = score
|
||||||
|
self.save_checkpoint(val_loss, model)
|
||||||
|
self.counter = 0
|
||||||
|
|
||||||
|
def save_checkpoint(self, val_loss, model):
|
||||||
|
"""Save model when validation loss decreases."""
|
||||||
|
if self.verbose:
|
||||||
|
print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model...')
|
||||||
|
torch.save(model.state_dict(), self.path)
|
||||||
|
self.val_loss_min = val_loss
|
||||||
|
|
||||||
|
class DatasetWrapperWithoutTimeFeatures(torch.utils.data.Dataset):
|
||||||
|
"""Wrapper to remove time features from dataflow datasets when use_x_mark=False"""
|
||||||
|
def __init__(self, original_dataset):
|
||||||
|
self.original_dataset = original_dataset
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
seq_x, seq_y, seq_x_mark, seq_y_mark = self.original_dataset[index]
|
||||||
|
return seq_x, seq_y
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.original_dataset)
|
||||||
|
|
||||||
|
def inverse_transform(self, data):
|
||||||
|
if hasattr(self.original_dataset, 'inverse_transform'):
|
||||||
|
return self.original_dataset.inverse_transform(data)
|
||||||
|
return data
|
||||||
|
|
||||||
|
def create_data_loaders_from_dataflow(args, use_x_mark: bool = True) -> Dict[str, DataLoader]:
|
||||||
|
"""Create PyTorch DataLoaders using dataflow data_provider"""
|
||||||
|
train_data, _ = data_provider(args, flag='train')
|
||||||
|
val_data, _ = data_provider(args, flag='val')
|
||||||
|
test_data, _ = data_provider(args, flag='test')
|
||||||
|
|
||||||
|
if not use_x_mark:
|
||||||
|
train_data = DatasetWrapperWithoutTimeFeatures(train_data)
|
||||||
|
val_data = DatasetWrapperWithoutTimeFeatures(val_data)
|
||||||
|
test_data = DatasetWrapperWithoutTimeFeatures(test_data)
|
||||||
|
|
||||||
|
train_shuffle = True
|
||||||
|
val_shuffle = False
|
||||||
|
test_shuffle = False
|
||||||
|
|
||||||
|
train_drop_last = True
|
||||||
|
val_drop_last = True
|
||||||
|
test_drop_last = True
|
||||||
|
|
||||||
|
batch_size = args.batch_size
|
||||||
|
num_workers = args.num_workers
|
||||||
|
|
||||||
|
train_loader = DataLoader(
|
||||||
|
train_data, batch_size=batch_size, shuffle=train_shuffle,
|
||||||
|
num_workers=num_workers, drop_last=train_drop_last
|
||||||
|
)
|
||||||
|
val_loader = DataLoader(
|
||||||
|
val_data, batch_size=batch_size, shuffle=val_shuffle,
|
||||||
|
num_workers=num_workers, drop_last=val_drop_last
|
||||||
|
)
|
||||||
|
test_loader = DataLoader(
|
||||||
|
test_data, batch_size=batch_size, shuffle=test_shuffle,
|
||||||
|
num_workers=num_workers, drop_last=test_drop_last
|
||||||
|
)
|
||||||
|
|
||||||
|
return {'train': train_loader, 'val': val_loader, 'test': test_loader}
|
||||||
|
|
||||||
|
def create_data_loaders(data_path: str, batch_size: int = 32, use_x_mark: bool = True) -> Dict[str, DataLoader]:
|
||||||
|
"""
|
||||||
|
Create PyTorch DataLoaders from an NPZ file
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data_path (str): Path to the NPZ file containing the data
|
||||||
|
batch_size (int): Batch size for the DataLoaders
|
||||||
|
use_x_mark (bool): Whether to use time features (x_mark) from the data file
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict[str, DataLoader]: Dictionary with train, val, and test DataLoaders
|
||||||
|
"""
|
||||||
|
# Load data from NPZ file
|
||||||
|
data = np.load(data_path, allow_pickle=True)
|
||||||
|
train_x = data['train_x']
|
||||||
|
train_y = data['train_y']
|
||||||
|
val_x = data['val_x']
|
||||||
|
val_y = data['val_y']
|
||||||
|
test_x = data['test_x']
|
||||||
|
test_y = data['test_y']
|
||||||
|
|
||||||
|
# Load time features if available and needed
|
||||||
|
if use_x_mark:
|
||||||
|
train_x_mark = data.get('train_x_mark', None)
|
||||||
|
train_y_mark = data.get('train_y_mark', None)
|
||||||
|
val_x_mark = data.get('val_x_mark', None)
|
||||||
|
val_y_mark = data.get('val_y_mark', None)
|
||||||
|
test_x_mark = data.get('test_x_mark', None)
|
||||||
|
test_y_mark = data.get('test_y_mark', None)
|
||||||
|
else:
|
||||||
|
train_x_mark = None
|
||||||
|
train_y_mark = None
|
||||||
|
val_x_mark = None
|
||||||
|
val_y_mark = None
|
||||||
|
test_x_mark = None
|
||||||
|
test_y_mark = None
|
||||||
|
|
||||||
|
# Convert to PyTorch tensors
|
||||||
|
train_x = torch.FloatTensor(train_x)
|
||||||
|
train_y = torch.FloatTensor(train_y)
|
||||||
|
val_x = torch.FloatTensor(val_x)
|
||||||
|
val_y = torch.FloatTensor(val_y)
|
||||||
|
test_x = torch.FloatTensor(test_x)
|
||||||
|
test_y = torch.FloatTensor(test_y)
|
||||||
|
|
||||||
|
# Create datasets based on whether time features are available
|
||||||
|
if train_x_mark is not None:
|
||||||
|
train_x_mark = torch.FloatTensor(train_x_mark)
|
||||||
|
train_y_mark = torch.FloatTensor(train_y_mark)
|
||||||
|
val_x_mark = torch.FloatTensor(val_x_mark)
|
||||||
|
val_y_mark = torch.FloatTensor(val_y_mark)
|
||||||
|
test_x_mark = torch.FloatTensor(test_x_mark)
|
||||||
|
test_y_mark = torch.FloatTensor(test_y_mark)
|
||||||
|
|
||||||
|
train_dataset = TensorDataset(train_x, train_y, train_x_mark, train_y_mark)
|
||||||
|
val_dataset = TensorDataset(val_x, val_y, val_x_mark, val_y_mark)
|
||||||
|
test_dataset = TensorDataset(test_x, test_y, test_x_mark, test_y_mark)
|
||||||
|
else:
|
||||||
|
train_dataset = TensorDataset(train_x, train_y)
|
||||||
|
val_dataset = TensorDataset(val_x, val_y)
|
||||||
|
test_dataset = TensorDataset(test_x, test_y)
|
||||||
|
|
||||||
|
# Create dataloaders
|
||||||
|
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
|
||||||
|
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
|
||||||
|
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
|
||||||
|
|
||||||
|
return {
|
||||||
|
'train': train_loader,
|
||||||
|
'val': val_loader,
|
||||||
|
'test': test_loader
|
||||||
|
}
|
||||||
|
def train_diffusion_model(
    model_constructor: Callable,
    data_path: str,
    project_name: str,
    config: Dict[str, Any],
    device: Optional[str] = None,
    early_stopping_patience: int = 10,
    max_epochs: int = 100,
    checkpoint_dir: str = "./checkpoints",
    log_interval: int = 10,
) -> Tuple[nn.Module, Dict[str, float]]:
    """
    Train a Diffusion time series forecasting model using NPZ data loading
    """
    # Setup device
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Initialize swanlab for experiment tracking
    swanlab_run = swanlab.init(
        project=project_name,
        config=config,
    )

    # Create checkpoint directory if it doesn't exist
    os.makedirs(checkpoint_dir, exist_ok=True)
    checkpoint_path = os.path.join(checkpoint_dir, f"{project_name}.pt")

    # Create data loaders using NPZ files (following other models' pattern)
    dataloaders = create_data_loaders(
        data_path=data_path,
        batch_size=config.get('batch_size', 32),
        use_x_mark=False  # DiffusionTimeSeries doesn't use time features
    )

    # Construct the model
    model = model_constructor()
    model = model.to(device)

    print(f"Model created with {model.get_num_params():,} parameters")

    # Define optimizer for diffusion training
    optimizer = optim.Adam(
        model.parameters(),
        lr=config.get('learning_rate', 1e-4),
        weight_decay=config.get('weight_decay', 1e-4)
    )

    # Learning rate scheduler
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', patience=5, factor=0.5
    )

    # Initialize early stopping
    early_stopping = EarlyStopping(
        patience=early_stopping_patience,
        verbose=True,
        path=checkpoint_path
    )

    # Training loop
    best_val_loss = float('inf')
    metrics = {}

    for epoch in range(max_epochs):
        print(f"\nEpoch {epoch+1}/{max_epochs}")
        print("-" * 50)

        # Training phase
        model.train()
        train_loss = 0.0
        train_samples = 0

        interval_loss = 0.0
        start_time = time.time()

        for batch_idx, (seq_x, seq_y) in enumerate(dataloaders['train']):
            seq_x, seq_y = seq_x.to(device), seq_y.to(device)

            optimizer.zero_grad()

            # Diffusion training: model returns loss directly when y is provided
            loss = model(seq_x, seq_y)

            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            interval_loss += loss.item()
            train_samples += 1

            # Log at intervals
            if (batch_idx + 1) % log_interval == 0:
                elapsed_time = time.time() - start_time
                avg_interval_loss = interval_loss / log_interval

                print(f' Batch [{batch_idx+1}/{len(dataloaders["train"])}] '
                      f'Loss: {avg_interval_loss:.6f} '
                      f'Time: {elapsed_time:.2f}s')

                # Log to swanlab
                swanlab.log({
                    'batch_loss': avg_interval_loss,
                    'batch': epoch * len(dataloaders['train']) + batch_idx,
                    'learning_rate': optimizer.param_groups[0]['lr']
                })

                interval_loss = 0.0
                start_time = time.time()

        avg_train_loss = train_loss / train_samples

        # Validation phase - use faster sampling for validation
        model.eval()
        val_loss = 0.0
        val_samples = 0
        criterion = nn.MSELoss()

        print(" Validating...")
        with torch.no_grad():
            # Temporarily reduce diffusion steps for faster validation
            original_timesteps = model.diffusion.num_timesteps
            model.diffusion.num_timesteps = 200  # Much faster validation

            for batch_idx, (seq_x, seq_y) in enumerate(dataloaders['val']):
                seq_x, seq_y = seq_x.to(device), seq_y.to(device)

                # Generate predictions (inference mode with reduced steps)
                pred = model(seq_x)

                # Compute MSE loss for validation
                loss = criterion(pred, seq_y)
                val_loss += loss.item()
                val_samples += 1

                # Print validation progress for first epoch
                if epoch == 0 and (batch_idx + 1) % 50 == 0:
                    print(f" Val Batch [{batch_idx+1}/{len(dataloaders['val'])}]")

                # Early break for the very first epoch to speed things up
                if epoch == 0 and batch_idx >= 100:  # Only validate on the first 100 batches in epoch 0
                    break

            # Restore original timesteps
            model.diffusion.num_timesteps = original_timesteps

        avg_val_loss = val_loss / val_samples

        # Learning rate scheduling
        scheduler.step(avg_val_loss)
        current_lr = optimizer.param_groups[0]['lr']

        print(f" Train Loss: {avg_train_loss:.6f}")
        print(f" Val Loss: {avg_val_loss:.6f}")
        print(f" Learning Rate: {current_lr:.2e}")

        # Log to swanlab
        swanlab.log({
            'epoch': epoch + 1,
            'train_loss': avg_train_loss,
            'val_loss': avg_val_loss,
            'learning_rate': current_lr
        })

        # Early stopping check
        early_stopping(avg_val_loss, model)

        if early_stopping.early_stop:
            print(f"Early stopping at epoch {epoch + 1}")
            break

    # Load best model
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))

    # Final evaluation on test set
    print("\nEvaluating on test set...")
    model.eval()
    test_loss = 0.0
    test_samples = 0
    all_preds = []
    all_targets = []

    with torch.no_grad():
        # Use reduced timesteps for faster testing
        original_timesteps = model.diffusion.num_timesteps
        model.diffusion.num_timesteps = 200  # Faster but still good quality

        for batch_idx, (seq_x, seq_y) in enumerate(dataloaders['test']):
            seq_x, seq_y = seq_x.to(device), seq_y.to(device)

            pred = model(seq_x)

            loss = criterion(pred, seq_y)
            test_loss += loss.item()
            test_samples += 1

            all_preds.append(pred.cpu().numpy())
            all_targets.append(seq_y.cpu().numpy())

            # Print progress every 50 batches
            if (batch_idx + 1) % 50 == 0:
                print(f" Test Batch [{batch_idx+1}/{len(dataloaders['test'])}]")

        # Restore original timesteps
        model.diffusion.num_timesteps = original_timesteps

    avg_test_loss = test_loss / test_samples

    # Calculate additional metrics
    all_preds = np.concatenate(all_preds, axis=0)
    all_targets = np.concatenate(all_targets, axis=0)

    mse = np.mean((all_preds - all_targets) ** 2)
    mae = np.mean(np.abs(all_preds - all_targets))
    rmse = np.sqrt(mse)

    metrics = {
        'test_mse': mse,
        'test_mae': mae,
        'test_rmse': rmse,
        'test_loss': avg_test_loss
    }

    print("Test Results:")
    print(f" MSE: {mse:.6f}")
    print(f" MAE: {mae:.6f}")
    print(f" RMSE: {rmse:.6f}")

    # Log final results
    swanlab.log(metrics)

    swanlab.finish()

    return model, metrics

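# Illustrative call sketch (editorial, not part of this commit). The config keys match what the
# function reads via config.get(...) above; the DiffusionTS constructor name and its arguments
# are hypothetical placeholders for whatever model_constructor actually builds.
#
# model, metrics = train_diffusion_model(
#     model_constructor=lambda: DiffusionTS(seq_len=96, pred_len=96, enc_in=7),
#     data_path='./data/ETTh1_96_96.npz',
#     project_name='DiffusionTS_ETTh1_96',
#     config={'batch_size': 32, 'learning_rate': 1e-4, 'weight_decay': 1e-4},
# )
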
329
train_xpatch_sparse_multi_datasets.py
Normal file
329
train_xpatch_sparse_multi_datasets.py
Normal file
@ -0,0 +1,329 @@
#!/usr/bin/env python3
"""
Training script for the xPatch_SparseChannel model on multiple datasets.
Supports the Weather, Traffic, Electricity, Exchange, and ILI datasets with sigmoid learning rate adjustment.
"""

import os
import math
import argparse
import torch
import torch.nn as nn
from train.train import train_forecasting_model
from models.xPatch_SparseChannel.xPatch import Model as xPatchSparseChannel

# Dataset configurations
DATASET_CONFIGS = {
    'Weather': {
        'csv_file': 'weather.csv',
        'enc_in': 21,
        'batch_size': 2048,
        'learning_rate': 0.0005,
        'target': 'OT',
        'data_path': 'weather/weather.csv',
        'seq_len': 96,
        'pred_lengths': [96, 192, 336, 720]
    },
    'Traffic': {
        'csv_file': 'traffic.csv',
        'enc_in': 862,
        'batch_size': 32,
        'learning_rate': 0.002,
        'target': 'OT',
        'data_path': 'traffic/traffic.csv',
        'seq_len': 96,
        'pred_lengths': [96, 192, 336, 720]
    },
    'Electricity': {
        'csv_file': 'electricity.csv',
        'enc_in': 321,
        'batch_size': 32,
        'learning_rate': 0.001,
        'target': 'OT',
        'data_path': 'electricity/electricity.csv',
        'seq_len': 96,
        'pred_lengths': [96, 192, 336, 720]
    },
    'Exchange': {
        'csv_file': 'exchange_rate.csv',
        'enc_in': 8,
        'batch_size': 128,
        'learning_rate': 0.00005,
        'target': 'OT',
        'data_path': 'exchange_rate/exchange_rate.csv',
        'seq_len': 96,
        'pred_lengths': [96, 192, 336, 720]
    },
    'ILI': {
        'csv_file': 'national_illness.csv',
        'enc_in': 7,
        'batch_size': 32,
        'learning_rate': 0.01,
        'target': 'OT',  # column is named 'OT' in national_illness.csv (was lowercase 'ot')
        'data_path': 'illness/national_illness.csv',
        'seq_len': 36,
        'pred_lengths': [24, 36, 48, 60]
    }
}

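# Hedged sketch only (editorial, not part of this commit): the actual sigmoid schedule is
# implemented inside train.train.train_forecasting_model and selected via
# lr_adjust_strategy="sigmoid" below; it is not shown here. A common sigmoid-style decay
# looks like the following, where `k` and `midpoint` are illustrative hyperparameters,
# not values taken from this repository.
def sigmoid_lr_example(base_lr: float, epoch: int, k: float = 0.5, midpoint: int = 10) -> float:
    """Decay base_lr smoothly from roughly base_lr toward 0 around `midpoint` epochs."""
    return base_lr / (1.0 + math.exp(k * (epoch - midpoint)))
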
class Args:
    """Configuration class for xPatch_SparseChannel model parameters."""
    def __init__(self, dataset_name, pred_len):
        dataset_config = DATASET_CONFIGS[dataset_name]

        # Model architecture parameters
        self.task_name = 'long_term_forecast'
        self.seq_len = dataset_config['seq_len']  # Use dataset-specific seq_len
        self.label_len = self.seq_len // 2  # Half of seq_len as label length
        self.pred_len = pred_len
        self.enc_in = dataset_config['enc_in']
        self.c_out = dataset_config['enc_in']

        # xPatch specific parameters from reference
        self.patch_len = 16  # patch length
        self.stride = 8  # stride
        self.padding_patch = 'end'  # padding on the end

        # Moving Average parameters
        self.ma_type = 'ema'  # moving average type
        self.alpha = 0.3  # alpha parameter for EMA
        self.beta = 0.3  # beta parameter for DEMA

        # RevIN normalization
        self.revin = 1  # RevIN; True 1 False 0

        # Time features (not used by xPatch but required by data loader)
        self.embed = 'timeF'  # Time feature embedding type
        self.freq = 'h'  # Frequency for time features (hourly)

        # Dataset specific parameters
        self.data = 'custom'
        self.root_path = './data/'
        self.data_path = dataset_config['data_path']
        self.features = 'M'  # Multivariate prediction
        self.target = dataset_config['target']  # Target column
        self.train_only = False

        # Required for dataflow - will be set by config
        self.batch_size = dataset_config['batch_size']
        self.num_workers = 8  # Will be overridden by config

        print(f"xPatch_SparseChannel Model configuration for {dataset_name}:")
        print(f" - Input channels (C): {self.enc_in}")
        print(f" - Patch length: {self.patch_len}")
        print(f" - Stride: {self.stride}")
        print(f" - Sequence length: {self.seq_len}")  # Now dataset-specific
        print(f" - Prediction length: {pred_len}")
        print(f" - Moving average type: {self.ma_type}")
        print(f" - Alpha: {self.alpha}")
        print(f" - Beta: {self.beta}")
        print(f" - RevIN: {self.revin}")
        print(f" - Target: {self.target}")
        print(f" - Batch size: {self.batch_size}")

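# Worked example (editorial note, not part of this commit): with the usual PatchTST-style
# patching that xPatch variants follow, the number of patches per channel is
#   patch_num = (seq_len - patch_len) // stride + 1, plus one extra patch when
#   padding_patch == 'end'.
# For seq_len=96, patch_len=16, stride=8: (96 - 16) // 8 + 1 = 11, i.e. 12 patches with end
# padding; for the ILI setting seq_len=36: (36 - 16) // 8 + 1 = 3, i.e. 4 patches. The exact
# formula used by xPatch_SparseChannel lives in models/xPatch_SparseChannel/xPatch.py and may
# differ slightly.
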
def create_xpatch_sparse_model(args):
    """Create xPatch_SparseChannel model with given configuration."""
    def model_constructor():
        return xPatchSparseChannel(args)
    return model_constructor


def train_single_dataset(dataset_name, pred_len, model_args, cmd_args, use_ps_loss=True):
    """Train xPatch_SparseChannel on the specified dataset with the given prediction length."""

    dataset_config = DATASET_CONFIGS[dataset_name]

    # Update args for the current prediction length
    model_args.pred_len = pred_len

    # Update dataflow parameters from command line args
    model_args.num_workers = cmd_args.num_workers

    # Create model constructor
    model_constructor = create_xpatch_sparse_model(model_args)

    # Training configuration with dataset-specific parameters
    config = {
        'learning_rate': dataset_config['learning_rate'],  # Dataset-specific learning rate
        'batch_size': dataset_config['batch_size'],  # Dataset-specific batch size
        'weight_decay': 1e-4,
        'dataset': dataset_name,
        'pred_len': pred_len,
        'seq_len': model_args.seq_len,
        'patch_len': model_args.patch_len,
        'stride': model_args.stride,
        'ma_type': model_args.ma_type,
        'use_ps_loss': use_ps_loss,
        'num_workers': cmd_args.num_workers,
        'pin_memory': True,
        'persistent_workers': True
    }

    # Project name for tracking
    loss_suffix = "_PSLoss" if use_ps_loss else "_MSE"
    project_name = f"xPatch_SparseChannel_{dataset_name}_pred{pred_len}{loss_suffix}_sigmoid"

    print(f"\n{'='*60}")
    print(f"Training {dataset_name} with prediction length {pred_len}")
    print("Model: xPatch_SparseChannel")
    print(f"Loss function: {'PS_Loss' if use_ps_loss else 'MSE'}")
    print(f"Learning rate: {dataset_config['learning_rate']}")
    print(f"Batch size: {dataset_config['batch_size']}")
    print(f"Features: {dataset_config['enc_in']}")
    print(f"Data path: {model_args.root_path}{model_args.data_path}")
    print("LR adjustment: sigmoid")
    print(f"{'='*60}")

    # Train the model
    try:
        model, metrics = train_forecasting_model(
            model_constructor=model_constructor,
            data_path=f"{model_args.root_path}{model_args.data_path}",
            project_name=project_name,
            config=config,
            early_stopping_patience=5,
            max_epochs=100,
            checkpoint_dir="./checkpoints",
            log_interval=50,
            use_x_mark=False,  # xPatch_SparseChannel doesn't use time features
            use_ps_loss=use_ps_loss,
            ps_lambda=cmd_args.ps_lambda,
            patch_len_threshold=64,
            use_gdw=True,
            dataset_mode="dataflow",
            dataflow_args=model_args,
            lr_adjust_strategy="sigmoid"  # Use sigmoid learning rate adjustment
        )

        print(f"Training completed for {project_name}")
        # Guard the metric lookup so a missing key cannot crash the summary print
        # (formatting the fallback string 'N/A' with :.6f would raise).
        if use_ps_loss:
            final_mse = metrics.get('final_val_mse', float('nan'))
        else:
            final_mse = metrics.get('final_val_loss', float('nan'))
        print(f"Final validation MSE: {final_mse:.6f}")

        return model, metrics

    except Exception as e:
        print(f"Error training {project_name}: {e}")
        import traceback
        traceback.print_exc()
        return None, None

def main():
    parser = argparse.ArgumentParser(description='Train xPatch_SparseChannel on multiple datasets with sigmoid LR adjustment')
    parser.add_argument('--datasets', nargs='+', type=str,
                        default=['Weather', 'Traffic', 'Electricity', 'Exchange', 'ILI'],
                        choices=['Weather', 'Traffic', 'Electricity', 'Exchange', 'ILI'],
                        help='List of datasets to train on')
    parser.add_argument('--use_ps_loss', action='store_true', default=True,
                        help='Use PS_Loss instead of MSE')
    parser.add_argument('--ps_lambda', type=float, default=5.0,
                        help='Weight for PS loss component')
    parser.add_argument('--device', type=str, default=None,
                        help='Device to use for training (cuda/cpu)')
    parser.add_argument('--num_workers', type=int, default=8,
                        help='Number of data loading workers')

    args = parser.parse_args()

    print("xPatch_SparseChannel Multi-Dataset Training Script with Sigmoid LR Adjustment")
    print("=" * 80)
    print(f"Datasets: {args.datasets}")
    print(f"Use PS_Loss: {args.use_ps_loss}")
    print(f"PS_Lambda: {args.ps_lambda}")
    print(f"Number of workers: {args.num_workers}")
    print("Learning rate adjustment: sigmoid")

    # Display dataset configurations
    print("\nDataset Configurations:")
    for dataset in args.datasets:
        config = DATASET_CONFIGS[dataset]
        print(f" {dataset}:")
        print(f" - Features: {config['enc_in']}")
        print(f" - Batch size: {config['batch_size']}")
        print(f" - Learning rate: {config['learning_rate']}")
        print(f" - Sequence length: {config['seq_len']}")
        print(f" - Prediction lengths: {config['pred_lengths']}")
        print(f" - Data path: {config['data_path']}")

    # Check if data files exist
    missing_datasets = []
    for dataset in args.datasets:
        config = DATASET_CONFIGS[dataset]
        data_path = f"./data/{config['data_path']}"
        if not os.path.exists(data_path):
            missing_datasets.append(f"{dataset}: '{data_path}'")

    if missing_datasets:
        print("\nError: The following dataset files were not found:")
        for missing in missing_datasets:
            print(f" - {missing}")
        print("Please ensure all dataset files are available in the data/ directory.")
        return

    # Set device
    if args.device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    else:
        device = args.device
    print(f"\nUsing device: {device}")

    # Training results storage
    all_results = {}

    # Train on each dataset
    for dataset in args.datasets:
        print(f"\n{'#'*80}")
        print(f"STARTING TRAINING ON {dataset.upper()} DATASET")
        print(f"{'#'*80}")

        all_results[dataset] = {}
        config = DATASET_CONFIGS[dataset]

        # Train on each prediction length for this dataset
        for pred_len in config['pred_lengths']:  # Use dataset-specific prediction lengths
            # Create model configuration for the current dataset
            model_args = Args(
                dataset_name=dataset,
                pred_len=pred_len
            )

            # Train the model
            model, metrics = train_single_dataset(
                dataset_name=dataset,
                pred_len=pred_len,
                model_args=model_args,
                cmd_args=args,
                use_ps_loss=args.use_ps_loss
            )

            # Store results
            all_results[dataset][pred_len] = {
                'model': model,
                'metrics': metrics,
                'data_path': f"./data/{config['data_path']}"
            }

    # Print comprehensive summary
    print("\n" + "=" * 100)
    print("COMPREHENSIVE TRAINING SUMMARY")
    print("=" * 100)

    for dataset in args.datasets:
        config = DATASET_CONFIGS[dataset]
        print(f"\n{dataset} (Features: {config['enc_in']}, Batch: {config['batch_size']}, LR: {config['learning_rate']}, Seq: {config['seq_len']}):")
        for pred_len in all_results[dataset]:
            result = all_results[dataset][pred_len]
            if result['metrics'] is not None:
                if args.use_ps_loss:
                    mse = result['metrics'].get('final_val_mse', 'N/A')
                else:
                    mse = result['metrics'].get('final_val_loss', 'N/A')
                print(f" Pred Length {pred_len}: MSE = {mse}")
            else:
                print(f" Pred Length {pred_len}: Training failed")

    print("\nAll models saved in: ./checkpoints/")
    print("All datasets training completed!")


if __name__ == "__main__":
    main()

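# Example invocations (editorial sketch; the flag names match the argparse definitions above,
# and the dataset CSVs are assumed to exist under ./data/):
#
#   python train_xpatch_sparse_multi_datasets.py --datasets Weather Exchange --ps_lambda 5.0
#   python train_xpatch_sparse_multi_datasets.py --datasets ILI --num_workers 4 --device cuda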