feat: add DiffusionTimeSeries and iTransformer models, introduce xPatch_SparseChannel
@@ -273,15 +273,21 @@ class Dataset_Custom(Dataset):
        self.data_stamp = data_stamp

    def __getitem__(self, index):
        # 1. Define the start and end positions of the input sequence seq_x
        s_begin = index
        s_end = s_begin + self.seq_len
        r_begin = s_end - self.label_len
        r_end = r_begin + self.label_len + self.pred_len

        # 2. Define the start and end positions of the target sequence seq_y
        # seq_y starts (r_begin) exactly where seq_x ends (s_end)
        r_begin = s_end
        # seq_y ends (r_end) at its start plus the prediction length (pred_len)
        r_end = r_begin + self.pred_len
        # 3. Slice the data according to these positions
        seq_x = self.data_x[s_begin:s_end]
        seq_y = self.data_y[r_begin:r_end]
        seq_x_mark = self.data_stamp[s_begin:s_end]
        seq_y_mark = self.data_stamp[r_begin:r_end]
        seq_x = seq_x.astype('float32')
        seq_y = seq_y.astype('float32')

        return seq_x, seq_y, seq_x_mark, seq_y_mark
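Note that the second pair of assignments overrides the first, so the returned seq_y covers only the pred_len forecast horizon (no label_len overlap). For concreteness, a small worked example of the slicing above (the numbers are hypothetical, not taken from any dataset in the repository):

# Suppose seq_len = 96 and pred_len = 24, and __getitem__ is called with index = 10.
s_begin = 10            # start of seq_x
s_end   = 10 + 96       # = 106, end of seq_x (exclusive)
r_begin = 106           # seq_y starts where seq_x ends
r_end   = 106 + 24      # = 130, end of seq_y (exclusive)
# seq_x = data_x[10:106]   -> 96 input steps
# seq_y = data_y[106:130]  -> 24 target steps to forecast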
0 models/DiffusionTimeSeries/__init__.py Normal file
323 models/DiffusionTimeSeries/diffusion_ts.py Normal file
@@ -0,0 +1,323 @@
import math
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


# ----------------------- Utilities: build / rescale linear betas -----------------------
def make_linear_betas(T: int, beta_start=1e-4, beta_end=2e-2, device='cpu'):
    return torch.linspace(beta_start, beta_end, T, device=device)

def cumprod_from_betas(betas: torch.Tensor):
    alphas = 1.0 - betas
    return torch.cumprod(alphas, dim=0)  # shape [T]

@torch.no_grad()
def scale_betas_to_target_cumprod(betas: torch.Tensor, target_cumprod: float, max_scale: float = 100.0):
    """
    Given betas[1..T], find a scale s > 0 such that prod(1 - s*beta_i) = target_cumprod.
    The search uses bisection on (0, s_max) while keeping 0 < s*beta_i < 1.
    """
    device = betas.device
    eps = 1e-12
    s_low = 0.0
    s_high = min(max_scale, (1.0 - 1e-6) / (betas.max().item() + eps))  # keep 1 - s*beta > 0

    def cumprod_with_scale(s: float):
        a = (1.0 - betas * s).clamp(min=1e-6, max=1.0-1e-6)
        return torch.cumprod(a, dim=0)[-1].item()

    # If the unscaled schedule is already close to the target, return it unchanged
    base = cumprod_with_scale(1.0)
    if abs(base - target_cumprod) / max(target_cumprod, 1e-12) < 1e-6:
        return betas

    # The target is reachable monotonically within (0, s_high); bisect
    for _ in range(60):
        mid = 0.5 * (s_low + s_high)
        val = cumprod_with_scale(mid)
        if val > target_cumprod:
            # Multiplier too small (noise too weak); need a larger s
            s_low = mid
        else:
            s_high = mid
    s_best = 0.5 * (s_low + s_high)
    return (betas * s_best).clamp(min=1e-8, max=1-1e-6)


# ------------------------------ DiT Blocks --------------------------------
class DiTBlock(nn.Module):
    def __init__(self, dim: int, heads: int, mlp_ratio=4.0):
        super().__init__()
        self.ln1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.ln2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, int(dim * mlp_ratio)),
            nn.GELU(),
            nn.Linear(int(dim * mlp_ratio), dim),
        )

    def forward(self, x):
        # x: [B, C_tokens, D]
        h = self.attn(self.ln1(x), self.ln1(x), self.ln1(x))[0]
        x = x + h
        x = x + self.mlp(self.ln2(x))
        return x


class DiTChannelTokens(nn.Module):
    """
    Token = one channel (variable).
    For each channel the input is a length-L time vector; two projections are used:
      - W_x : projects the x_t time vector into a token vector
      - W_n : projects the noise-level time vector (the schedule's ā or b̄) into a token bias
    Note: there is no learnable t-embedding; the noise condition is fully determined by the noise map.
    """
    def __init__(self, L: int, C: int, dim: int = 256, depth: int = 8, heads: int = 8):
        super().__init__()
        self.L = L
        self.C = C
        self.dim = dim

        # Channel embedding (optional, distinguishes variables)
        self.channel_embed = nn.Parameter(torch.randn(C, dim) * 0.02)

        # Map each channel's time series to a token
        self.proj_x = nn.Linear(L, dim, bias=False)
        # Map each channel's per-timestep noise level (e.g. [sqrt(ā), sqrt(1-ā)] concatenated, then one linear layer)
        self.proj_noise = nn.Linear(L, dim, bias=True)

        self.blocks = nn.ModuleList([DiTBlock(dim, heads) for _ in range(depth)])
        self.ln_f = nn.LayerNorm(dim)

        # Project back to time length L and predict ε (independent projection per channel)
        self.head = nn.Linear(dim, L, bias=False)

    def forward(self, x_t: torch.Tensor, noise_feat: torch.Tensor):
        """
        x_t       : [B, L, C]
        noise_feat: [B, L, C] (suggested input: sqrt(ā), or a concatenation merged into the L dimension; one projection suffices)
        returns ε̂ : [B, L, C]
        """
        B, L, C = x_t.shape
        assert L == self.L and C == self.C

        # Map each channel into a token:
        # reshape (B, L, C) to (B, C, L), then apply the linear layers
        x_tc = x_t.permute(0, 2, 1)         # [B, C, L]
        n_tc = noise_feat.permute(0, 2, 1)  # [B, C, L]

        tok = self.proj_x(x_tc) + self.proj_noise(n_tc)  # [B, C, D]
        tok = tok + self.channel_embed.unsqueeze(0)      # broadcast [1, C, D]

        for blk in self.blocks:
            tok = blk(tok)  # [B, C, D]

        tok = self.ln_f(tok)
        out = self.head(tok)              # [B, C, L]
        eps_pred = out.permute(0, 2, 1)   # [B, L, C]
        return eps_pred


# ----------------------- RAD two-phase diffusion (channels as tokens) -----------------------
class RADChannelDiT(nn.Module):
    def __init__(self,
                 past_len: int,
                 future_len: int,
                 channels: int,
                 T: int = 1000,
                 T1_ratio: float = 0.7,
                 model_dim: int = 256,
                 depth: int = 8,
                 heads: int = 8,
                 beta_start: float = 1e-4,
                 beta_end: float = 2e-2,
                 use_cosine_target: bool = True):
        """
        - Training : two phases (Phase-1 + Phase-2), t sampled uniformly from [1..T]
        - Inference: Phase-1 only (t: T1→1), updating only the future region
        - Token = channel; each token sees the whole time axis plus the noise-level time vector
        """
        super().__init__()
        self.P = past_len
        self.H = future_len
        self.C = channels
        self.L = past_len + future_len
        self.T = T
        self.T1 = max(1, int(T * T1_ratio))
        self.T2 = T - self.T1
        assert self.T2 >= 1, "T1_ratio is too large; leave at least 1 step for Phase-2"

        device = torch.device('cpu')

        # Target ā_T (used to normalize the two linear schedules to the same final noise level)
        if use_cosine_target:
            # Use the cosine schedule as the reference to obtain a global target ā_T
            steps = T + 1
            x = torch.linspace(0, T, steps, dtype=torch.float64)
            s = 0.008
            alphas_cum = torch.cos(((x / T) + s) / (1 + s) * math.pi / 2) ** 2
            alphas_cum = alphas_cum / alphas_cum[0]
            a_bar_target_T = float(alphas_cum[-1])
        else:
            # Use the result of DDPM's linear betas as the target
            betas_full = make_linear_betas(T, beta_start, beta_end, device)
            a_bar_target_T = cumprod_from_betas(betas_full)[-1].item()

        # Raw linear betas for Phase-1 & Phase-2
        betas1 = make_linear_betas(self.T1, beta_start, beta_end, device)
        betas2 = make_linear_betas(self.T2, beta_start, beta_end, device)

        # First compute ā1[T1] and ā2[T2] without scaling
        a_bar1 = cumprod_from_betas(betas1)  # shape [T1]
        a_bar2 = cumprod_from_betas(betas2)  # shape [T2]

        # Scale the Phase-2 betas so that ā1[T1] * ā2'[T2] = target ā_T
        target_a2 = a_bar_target_T / (a_bar1[-1].item() + 1e-12)
        betas2 = scale_betas_to_target_cumprod(betas2, target_a2)

        # Recompute
        # a_bar1 = cumprod_from_betas(betas1).float()  # [T1]
        a_bar2 = cumprod_from_betas(betas2).float()    # [T2]

        self.register_buffer("betas1", betas1.float())
        self.register_buffer("betas2", betas2.float())
        self.register_buffer("alphas1", 1.0 - betas1.float())
        self.register_buffer("alphas2", 1.0 - betas2.float())
        self.register_buffer("a_bar1", a_bar1)
        self.register_buffer("a_bar2", a_bar2)
        self.register_buffer("a_bar_target_T", torch.tensor(a_bar_target_T, dtype=torch.float32))

        # Backbone: token = channel
        self.backbone = DiTChannelTokens(L=self.L, C=self.C, dim=model_dim, depth=depth, heads=heads)

    # ------------------------ Internals: build mask & āt,i ------------------------
    def _mask_future(self, B, device):
        # mask: future region = 1, history = 0, shape [B, L, C] (aligned with the network input [B, L, C])
        m = torch.zeros(B, self.L, self.C, device=device)
        m[:, self.P:, :] = 1.0
        return m

    def _a_bar_map_at_t(self, t_scalar: int, B: int, device, mask_future: torch.Tensor):
        """
        Build the per-element āt,i of shape [B, L, C]:
        - if t <= T1: the future region uses ā1[t], the history region = 1
        - if t >  T1: the future region is fixed at ā1[T1], the history region uses ā2[t - T1]
        """
        if t_scalar <= self.T1:
            a_future = self.a_bar1[t_scalar - 1]  # indexing starts at 0
            a_past = torch.tensor(1.0, device=device)
        else:
            a_future = self.a_bar1[-1]
            a_past = self.a_bar2[t_scalar - self.T1 - 1]

        a_future_map = torch.full((B, self.L, self.C), float(a_future.item()), device=device)
        a_past_map = torch.full((B, self.L, self.C), float(a_past.item()), device=device)
        a_map = a_past_map * (1 - mask_future) + a_future_map * mask_future
        return a_map  # [B, L, C]

    # ----------------------------- Training forward -----------------------------
    def forward(self, x_hist: torch.Tensor, x_future: torch.Tensor) -> Tuple[torch.Tensor, dict]:
        """
        x_hist   : [B, P, C]
        x_future : [B, H, C]
        Training: sample t ∈ [1..T], build the two-phase āt,i, apply marginal noising to get xt,
        and predict ε with the channel-token DiT.
        """
        B = x_hist.size(0)
        device = x_hist.device
        x0 = torch.cat([x_hist, x_future], dim=1)  # [B, L, C]

        # Sample the training step t (1..T)
        t = torch.randint(1, self.T + 1, (B,), device=device, dtype=torch.long)

        # Build the mask and per-element āt,i
        mask_fut = self._mask_future(B, device)  # [B, L, C]
        # Build āt,i per sample (t differs per sample, so a loop or a vectorized trick is needed; B is usually small, a for loop is fine)
        a_bar_map = torch.stack([self._a_bar_map_at_t(int(tt.item()), 1, device, mask_fut[0:1])
                                 for tt in t], dim=0).squeeze(1)  # [B, L, C]

        # Marginal noising
        eps = torch.randn_like(x0)  # [B, L, C]
        x_t = a_bar_map.sqrt() * x0 + (1.0 - a_bar_map).sqrt() * eps

        # Spatial Noise Embedding: fully determined by the schedule
        # Pass each element's √ā and/or √(1-ā); here √ā is used
        noise_feat = a_bar_map.sqrt()  # [B, L, C]

        # Predict ε
        eps_pred = self.backbone(x_t, noise_feat)  # [B, L, C]

        loss = F.mse_loss(eps_pred, eps)
        return loss, {'t_mean': t.float().mean().item()}

    # ----------------------------- Sampling / inference -----------------------------
    @torch.no_grad()
    def sample(self, x_hist: torch.Tensor, steps: Optional[int] = None) -> torch.Tensor:
        """
        Phase-1-only inference: t = T1..1, only the future region is updated; the history keeps the observed values.
        x_hist : [B, P, C]
        return : [B, H, C]
        """
        B = x_hist.size(0)
        device = x_hist.device
        mask_fut = self._mask_future(B, device)  # [B, L, C]

        # Initialize x: history = observations, future = Gaussian noise
        x = torch.zeros(B, self.L, self.C, device=device)
        x[:, :self.P, :] = x_hist
        x[:, self.P:, :] = torch.randn(B, self.H, self.C, device=device)

        # Support sub-sampling: uniformly subsample [T1..1] down to `steps`
        T1 = self.T1
        steps = steps if steps is not None else T1
        steps = max(1, min(steps, T1))
        ts = torch.linspace(T1, 1, steps, device=device).long().tolist()
        # The DDPM update needs α_t and β_t (defined only for the future region)
        alphas1 = self.alphas1  # [T1]
        betas1 = self.betas1
        a_bar1 = self.a_bar1

        for idx, t_scalar in enumerate(ts):
            # Current āt,i (Phase-1: history = 1, future = ā1[t])
            a_bar_map = self._a_bar_map_at_t(int(t_scalar), B, device, mask_fut)  # [B, L, C]
            # Network condition: use √ā as the noise embedding
            noise_feat = a_bar_map.sqrt()
            # Predict ε
            eps_pred = self.backbone(x, noise_feat)  # [B, L, C]

            # One DDPM step on the future region (the history region keeps its values)
            # Standard DDPM formula (all future-region elements share the same α_t, β_t)
            t_idx = t_scalar - 1
            alpha_t = alphas1[t_idx]  # scalar
            beta_t = betas1[t_idx]
            a_bar_t = a_bar1[t_idx]
            if t_scalar > 1:
                a_bar_prev = a_bar1[t_idx - 1]
            else:
                a_bar_prev = torch.tensor(1.0, device=device)

            # x0 prediction (only used to derive the mean; the μ formula can also be used directly)
            x0_pred = (x - (1.0 - a_bar_t).sqrt() * eps_pred) / (a_bar_t.sqrt() + 1e-8)

            # Mean: μ_t = 1/sqrt(α_t) * (x_t - β_t / sqrt(1 - ā_t) * ε̂)
            mean = (x - (beta_t / (1.0 - a_bar_t).sqrt()) * eps_pred) / (alpha_t.sqrt() + 1e-8)

            # Sampling noise
            if t_scalar > 1:
                z = torch.randn_like(x)
            else:
                z = torch.zeros_like(x)

            # Variance term (DDPM): σ_t = sqrt(β_t)
            x_next = mean + z * beta_t.sqrt()

            # Replace only the future region
            x = x * (1 - mask_fut) + x_next * mask_fut
            # Force the history back to the observations
            x[:, :self.P, :] = x_hist

        return x[:, self.P:, :]  # [B, H, C]
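A minimal usage sketch of RADChannelDiT (not part of this commit; the sizes and the training step are illustrative):

import torch
from models.DiffusionTimeSeries.diffusion_ts import RADChannelDiT

# Hypothetical sizes: 96 past steps, 24 future steps, 7 channels
model = RADChannelDiT(past_len=96, future_len=24, channels=7, T=1000, T1_ratio=0.7)

x_hist = torch.randn(8, 96, 7)     # [B, P, C]
x_future = torch.randn(8, 24, 7)   # [B, H, C]

# Training: the forward pass returns the eps-prediction MSE plus a small log dict
loss, logs = model(x_hist, x_future)
loss.backward()

# Inference: Phase-1 reverse diffusion, updating only the future region
forecast = model.sample(x_hist, steps=50)   # [B, H, C]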
132 models/iTransformer/iTransformer.py Normal file
@@ -0,0 +1,132 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from layers.Transformer_EncDec import Encoder, EncoderLayer
from layers.SelfAttention_Family import FullAttention, AttentionLayer
from layers.Embed import DataEmbedding_inverted
import numpy as np


class Model(nn.Module):
    """
    Paper link: https://arxiv.org/abs/2310.06625
    """

    def __init__(self, configs):
        super(Model, self).__init__()
        self.task_name = configs.task_name
        self.seq_len = configs.seq_len
        self.pred_len = configs.pred_len
        # Embedding
        self.enc_embedding = DataEmbedding_inverted(configs.seq_len, configs.d_model, configs.embed, configs.freq,
                                                    configs.dropout)
        # Encoder
        self.encoder = Encoder(
            [
                EncoderLayer(
                    AttentionLayer(
                        FullAttention(False, configs.factor, attention_dropout=configs.dropout,
                                      output_attention=False), configs.d_model, configs.n_heads),
                    configs.d_model,
                    configs.d_ff,
                    dropout=configs.dropout,
                    activation=configs.activation
                ) for l in range(configs.e_layers)
            ],
            norm_layer=torch.nn.LayerNorm(configs.d_model)
        )
        # Decoder
        if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
            self.projection = nn.Linear(configs.d_model, configs.pred_len, bias=True)
        if self.task_name == 'imputation':
            self.projection = nn.Linear(configs.d_model, configs.seq_len, bias=True)
        if self.task_name == 'anomaly_detection':
            self.projection = nn.Linear(configs.d_model, configs.seq_len, bias=True)
        if self.task_name == 'classification':
            self.act = F.gelu
            self.dropout = nn.Dropout(configs.dropout)
            self.projection = nn.Linear(configs.d_model * configs.enc_in, configs.num_class)

    def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
        # Normalization from Non-stationary Transformer
        means = x_enc.mean(1, keepdim=True).detach()
        x_enc = x_enc - means
        stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
        x_enc /= stdev

        _, _, N = x_enc.shape

        # Embedding
        enc_out = self.enc_embedding(x_enc, x_mark_enc)
        enc_out, attns = self.encoder(enc_out, attn_mask=None)

        dec_out = self.projection(enc_out).permute(0, 2, 1)[:, :, :N]
        # De-Normalization from Non-stationary Transformer
        dec_out = dec_out * (stdev[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1))
        dec_out = dec_out + (means[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1))
        return dec_out

    def imputation(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask):
        # Normalization from Non-stationary Transformer
        means = x_enc.mean(1, keepdim=True).detach()
        x_enc = x_enc - means
        stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
        x_enc /= stdev

        _, L, N = x_enc.shape

        # Embedding
        enc_out = self.enc_embedding(x_enc, x_mark_enc)
        enc_out, attns = self.encoder(enc_out, attn_mask=None)

        dec_out = self.projection(enc_out).permute(0, 2, 1)[:, :, :N]
        # De-Normalization from Non-stationary Transformer
        dec_out = dec_out * (stdev[:, 0, :].unsqueeze(1).repeat(1, L, 1))
        dec_out = dec_out + (means[:, 0, :].unsqueeze(1).repeat(1, L, 1))
        return dec_out

    def anomaly_detection(self, x_enc):
        # Normalization from Non-stationary Transformer
        means = x_enc.mean(1, keepdim=True).detach()
        x_enc = x_enc - means
        stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
        x_enc /= stdev

        _, L, N = x_enc.shape

        # Embedding
        enc_out = self.enc_embedding(x_enc, None)
        enc_out, attns = self.encoder(enc_out, attn_mask=None)

        dec_out = self.projection(enc_out).permute(0, 2, 1)[:, :, :N]
        # De-Normalization from Non-stationary Transformer
        dec_out = dec_out * (stdev[:, 0, :].unsqueeze(1).repeat(1, L, 1))
        dec_out = dec_out + (means[:, 0, :].unsqueeze(1).repeat(1, L, 1))
        return dec_out

    def classification(self, x_enc, x_mark_enc):
        # Embedding
        enc_out = self.enc_embedding(x_enc, None)
        enc_out, attns = self.encoder(enc_out, attn_mask=None)

        # Output
        output = self.act(enc_out)  # the output transformer encoder/decoder embeddings don't include non-linearity
        output = self.dropout(output)
        output = output.reshape(output.shape[0], -1)  # (batch_size, c_in * d_model)
        output = self.projection(output)  # (batch_size, num_classes)
        return output

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
        if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
            dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
            return dec_out[:, -self.pred_len:, :]  # [B, L, D]
        if self.task_name == 'imputation':
            dec_out = self.imputation(x_enc, x_mark_enc, x_dec, x_mark_dec, mask)
            return dec_out  # [B, L, D]
        if self.task_name == 'anomaly_detection':
            dec_out = self.anomaly_detection(x_enc)
            return dec_out  # [B, L, D]
        if self.task_name == 'classification':
            dec_out = self.classification(x_enc, x_mark_enc)
            return dec_out  # [B, N]
        return None
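A usage sketch for the forecasting path (values are illustrative, and it assumes the repository's layers package is importable):

import torch
from types import SimpleNamespace
from models.iTransformer.iTransformer import Model

configs = SimpleNamespace(task_name='long_term_forecast', seq_len=96, pred_len=96,
                          d_model=128, d_ff=256, n_heads=8, e_layers=2, factor=1,
                          dropout=0.1, activation='gelu', embed='timeF', freq='h')
model = Model(configs)

x_enc = torch.randn(4, 96, 7)        # [B, seq_len, variates]
x_mark_enc = torch.randn(4, 96, 4)   # time features
out = model(x_enc, x_mark_enc, None, None)   # [B, pred_len, variates]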
3 models/xPatch_SparseChannel/__init__.py Normal file
@@ -0,0 +1,3 @@
from .xPatch import Model

__all__ = ['Model']
66 models/xPatch_SparseChannel/patchtst_ci.py Normal file
@@ -0,0 +1,66 @@
"""
SeasonPatch = PatchTST (CI) + ChannelGraphMixer + linear prediction head
"""
import torch
import torch.nn as nn
from layers.PatchTST_layers import positional_encoding  # already exists in the repo
from layers.mixer import HierarchicalGraphMixer         # created alongside this module

# ------ PatchTST CI encoder (equivalent to the official implementation, but without the head so the Mixer can be inserted) ------
class _Encoder(nn.Module):
    def __init__(self, c_in, seq_len, patch_len, stride,
                 d_model=128, n_layers=3, n_heads=16, dropout=0.):
        super().__init__()
        self.patch_len = patch_len
        self.stride = stride
        self.patch_num = (seq_len - patch_len)//stride + 1

        self.proj = nn.Linear(patch_len, d_model)
        self.pos = positional_encoding('zeros', True, self.patch_num, d_model)
        self.drop = nn.Dropout(dropout)

        from layers.PatchTST_backbone import TSTEncoder  # same name as the official implementation
        self.encoder = TSTEncoder(self.patch_num, d_model, n_heads,
                                  d_ff=d_model*2, dropout=dropout,
                                  n_layers=n_layers, norm='LayerNorm')

    def forward(self, x):                              # [B,C,L]
        x = x.unfold(-1, self.patch_len, self.stride)  # [B,C,patch_num,patch_len]
        B, C, N, P = x.shape
        z = self.proj(x)                               # [B,C,N,d]
        z = z.contiguous().view(B*C, N, -1)            # [B*C,N,d]
        z = self.drop(z + self.pos)
        z = self.encoder(z)                            # [B*C,N,d]
        return z.view(B, C, N, -1)                     # [B,C,N,d]

# ------------------------- SeasonPatch -----------------------------
class SeasonPatch(nn.Module):
    def __init__(self,
                 c_in: int,
                 seq_len: int,
                 pred_len: int,
                 patch_len: int,
                 stride: int,
                 k_graph: int = 5,
                 d_model: int = 128,
                 revin: bool = True):
        super().__init__()

        self.encoder = _Encoder(c_in, seq_len, patch_len, stride,
                                d_model=d_model)
        self.mixer = HierarchicalGraphMixer(c_in, dim=d_model, k=k_graph)

        # Calculate actual number of patches
        patch_num = (seq_len - patch_len) // stride + 1
        self.head = nn.Linear(patch_num * d_model, pred_len)

    def forward(self, x):                  # x [B,L,C]
        x = x.permute(0, 2, 1)             # → [B,C,L]
        z = self.encoder(x)                # [B,C,N,d]
        B, C, N, D = z.shape
        # Channel-independent encoding -> sparse cross-channel injection
        z_mix = self.mixer(z).view(B, C, N*D)  # [B,C,N*d]
        y_pred = self.head(z_mix)          # [B,C,T]

        return y_pred
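A shape walk-through of SeasonPatch with illustrative numbers (seq_len=96, patch_len=16, stride=8, so patch_num = (96-16)//8 + 1 = 11); it assumes layers.mixer.HierarchicalGraphMixer and the PatchTST layers are available in the repository:

import torch

season = SeasonPatch(c_in=7, seq_len=96, pred_len=24, patch_len=16, stride=8)
x = torch.randn(4, 96, 7)   # [B, L, C]
y = season(x)               # [B, C, pred_len] = [4, 7, 24]
# Inside: unfold -> [B, C, 11, 16]; encoder -> [B, C, 11, 128];
# mixer + flatten -> [B, C, 11*128]; head -> [B, C, 24]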
56 models/xPatch_SparseChannel/xPatch.py Normal file
@@ -0,0 +1,56 @@
import torch
import torch.nn as nn
import math

from layers.decomp import DECOMP
from .xPatch_SparseChannel import Network
from layers.revin import RevIN

class Model(nn.Module):
    def __init__(self, configs):
        super(Model, self).__init__()

        # Parameters
        seq_len = configs.seq_len    # lookback window L
        pred_len = configs.pred_len  # prediction length (96, 192, 336, 720)
        c_in = configs.enc_in        # input channels

        # Patching
        patch_len = configs.patch_len
        stride = configs.stride
        padding_patch = configs.padding_patch

        # Normalization
        self.revin = configs.revin
        self.revin_layer = RevIN(c_in, affine=True, subtract_last=False)

        # Moving Average
        self.ma_type = configs.ma_type
        alpha = configs.alpha  # smoothing factor for EMA (Exponential Moving Average)
        beta = configs.beta    # smoothing factor for DEMA (Double Exponential Moving Average)

        self.decomp = DECOMP(self.ma_type, alpha, beta)
        self.net = Network(seq_len, pred_len, patch_len, stride, padding_patch, c_in)
        # self.net_mlp = NetworkMLP(seq_len, pred_len)  # For ablation study with MLP-only stream
        # self.net_cnn = NetworkCNN(seq_len, pred_len, patch_len, stride, padding_patch)  # For ablation study with CNN-only stream

    def forward(self, x):
        # x: [Batch, Input, Channel]

        # Normalization
        if self.revin:
            x = self.revin_layer(x, 'norm')

        if self.ma_type == 'reg':  # If no decomposition, directly pass the input to the network
            x = self.net(x, x)
            # x = self.net_mlp(x)  # For ablation study with MLP-only stream
            # x = self.net_cnn(x)  # For ablation study with CNN-only stream
        else:
            seasonal_init, trend_init = self.decomp(x)
            x = self.net(seasonal_init, trend_init)

        # Denormalization
        if self.revin:
            x = self.revin_layer(x, 'denorm')

        return x
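A minimal sketch of the configuration fields read by the wrapper's constructor (the field names come from the code above; the values and the SimpleNamespace container are assumptions, and the repository's layers package is assumed to be importable):

import torch
from types import SimpleNamespace

configs = SimpleNamespace(seq_len=96, pred_len=96, enc_in=21,
                          patch_len=16, stride=8, padding_patch='end',
                          revin=1, ma_type='ema', alpha=0.3, beta=0.3)
model = Model(configs)                 # the xPatch_SparseChannel wrapper above
y = model(torch.randn(2, 96, 21))      # [B, pred_len, C]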
63 models/xPatch_SparseChannel/xPatch_SparseChannel.py Normal file
@@ -0,0 +1,63 @@
import torch
from torch import nn
from .patchtst_ci import SeasonPatch  # <<< new import

class Network(nn.Module):
    """
    trend  : the original MLP linear stream (kept unchanged)
    season : SeasonPatch (PatchTST + Mixer)
    """
    def __init__(self, seq_len, pred_len,
                 patch_len, stride, padding_patch,
                 c_in):
        super().__init__()

        # -------- Seasonal stream ---------------
        self.season_net = SeasonPatch(c_in=c_in,
                                      seq_len=seq_len,
                                      pred_len=pred_len,
                                      patch_len=patch_len,
                                      stride=stride,
                                      k_graph=5,
                                      d_model=128)

        # --------- Linear trend stream (original code, unchanged) ----------
        self.pred_len = pred_len
        self.fc5 = nn.Linear(seq_len, pred_len * 4)
        self.avgpool1 = nn.AvgPool1d(kernel_size=2)
        self.ln1 = nn.LayerNorm(pred_len * 2)

        self.fc6 = nn.Linear(pred_len * 2, pred_len)
        self.avgpool2 = nn.AvgPool1d(kernel_size=2)
        self.ln2 = nn.LayerNorm(pred_len // 2)

        self.fc7 = nn.Linear(pred_len // 2, pred_len)

        # Concatenate the two streams' outputs
        self.fc_final = nn.Linear(pred_len * 2, pred_len)

    # ---------------- forward --------------------
    def forward(self, s, t):
        # Input shape: [B,L,C]
        B, L, C = s.shape

        # ---------- Seasonality ------------
        y_season = self.season_net(s)            # [B,C,T]

        # ---------- Trend (original MLP) ----------
        t = t.permute(0, 2, 1).reshape(B*C, L)   # [B*C,L]
        t = self.fc5(t)
        t = self.avgpool1(t)
        t = self.ln1(t)
        t = self.fc6(t)
        t = self.avgpool2(t)
        t = self.ln2(t)
        t = self.fc7(t)                          # [B*C,T]
        y_trend = t.view(B, C, -1)               # [B,C,T]

        # --------- Concatenate & output --------------
        y = torch.cat([y_season, y_trend], dim=-1)  # [B,C,2T]
        y = self.fc_final(y)                        # [B,C,T]
        y = y.permute(0, 2, 1)                      # [B,T,C]
        return y
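Worked through for a hypothetical pred_len = 96, the trend stream's feature sizes line up as follows:

# fc5      : seq_len -> 96*4 = 384
# avgpool1 : 384 -> 192          (kernel_size=2 halves the feature length)
# ln1      : LayerNorm over 192  (= pred_len * 2)
# fc6      : 192 -> 96
# avgpool2 : 96 -> 48            (= pred_len // 2)
# ln2      : LayerNorm over 48
# fc7      : 48 -> 96            (back to pred_len)
# fc_final : concat(season 96, trend 96) = 192 -> 96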
177 train/train.py
@@ -8,6 +8,8 @@ from torch.utils.data import DataLoader, TensorDataset
import swanlab
from typing import Dict, Any, Optional, Callable, Union, Tuple
from dataflow import data_provider
from layers.ps_loss import PSLoss
from utils.tools import adjust_learning_rate, dotdict

class EarlyStopping:
    """Early stopping to stop training when validation performance doesn't improve."""
@@ -138,7 +140,9 @@ def create_data_loaders_from_dataflow(args, use_x_mark: bool = True) -> Dict[str
        'test': test_loader
    }

def create_data_loaders(data_path: str, batch_size: int = 32, use_x_mark: bool = True) -> Dict[str, DataLoader]:
def create_data_loaders(data_path: str, batch_size: int = 32, use_x_mark: bool = True,
                        num_workers: int = 4, pin_memory: bool = True,
                        persistent_workers: bool = True) -> Dict[str, DataLoader]:
    """
    Create PyTorch DataLoaders from an NPZ file

@@ -146,6 +150,9 @@ def create_data_loaders(data_path: str, batch_size: int = 32, use_x_mark: bool =
        data_path (str): Path to the NPZ file containing the data
        batch_size (int): Batch size for the DataLoaders
        use_x_mark (bool): Whether to use time features (x_mark) from the data file
        num_workers (int): Number of worker processes for data loading
        pin_memory (bool): Whether to pin memory for faster GPU transfer
        persistent_workers (bool): Whether to keep workers alive between epochs

    Returns:
        Dict[str, DataLoader]: Dictionary with train, val, and test DataLoaders
@@ -200,10 +207,34 @@ def create_data_loaders(data_path: str, batch_size: int = 32, use_x_mark: bool =
    val_dataset = TensorDataset(val_x, val_y)
    test_dataset = TensorDataset(test_x, test_y)

    # Create dataloaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    # Create dataloaders with performance optimizations
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=pin_memory,
        persistent_workers=persistent_workers if num_workers > 0 else False,
        drop_last=True  # Drop incomplete batches for training
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
        persistent_workers=persistent_workers if num_workers > 0 else False,
        drop_last=False
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
        persistent_workers=persistent_workers if num_workers > 0 else False,
        drop_last=False
    )

    return {
        'train': train_loader,
@@ -223,7 +254,12 @@ def train_forecasting_model(
    log_interval: int = 10,
    use_x_mark: bool = True,
    dataset_mode: str = "npz",
    dataflow_args = None
    dataflow_args = None,
    use_ps_loss: bool = False,
    ps_lambda: float = 5.0,
    patch_len_threshold: int = 64,
    use_gdw: bool = True,
    lr_adjust_strategy: str = "type1"
) -> Tuple[nn.Module, Dict[str, float]]:
    """
    Train a time series forecasting model
@@ -241,6 +277,11 @@ def train_forecasting_model(
        use_x_mark (bool): Whether to use time features (x_mark) from the data file
        dataset_mode (str): Dataset construction mode - "npz" or "dataflow"
        dataflow_args: Arguments object for dataflow mode (required if dataset_mode="dataflow")
        use_ps_loss (bool): Whether to use Patch-wise Structural (PS) loss instead of MSE
        ps_lambda (float): Weight for PS loss component when combined with MSE
        patch_len_threshold (int): Maximum patch length for adaptive patching
        use_gdw (bool): Whether to use Gradient-based Dynamic Weighting
        lr_adjust_strategy (str): Learning rate adjustment strategy - 'type1', 'type2', 'type3', 'sigmoid', 'constant', '3', '4', '5', '6'

    Returns:
        Tuple[nn.Module, Dict[str, float]]: Trained model and dictionary of evaluation metrics
@@ -271,7 +312,10 @@ def train_forecasting_model(
        dataloaders = create_data_loaders(
            data_path=data_path,
            batch_size=config.get('batch_size', 32),
            use_x_mark=use_x_mark
            use_x_mark=use_x_mark,
            num_workers=config.get('num_workers', 4),
            pin_memory=config.get('pin_memory', True),
            persistent_workers=config.get('persistent_workers', True)
        )

    # Construct the model
@@ -279,14 +323,24 @@ def train_forecasting_model(
    model = model.to(device)

    # Define loss function and optimizer
    criterion = nn.MSELoss()
    if use_ps_loss:
        criterion = PSLoss(
            patch_len_threshold=patch_len_threshold,
            lambda_ps=ps_lambda,
            use_gdw=use_gdw
        )
    else:
        criterion = nn.MSELoss()
    optimizer = optim.Adam(
        model.parameters(),
        lr=config.get('learning_rate', 1e-3),
    )

    # Add learning rate scheduler to halve LR after each epoch
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)
    # Create args object for learning rate adjustment
    lr_args = dotdict({
        'learning_rate': config.get('learning_rate', 1e-3),
        'lradj': lr_adjust_strategy
    })

    # Initialize early stopping
    early_stopping = EarlyStopping(
@@ -334,7 +388,11 @@ def train_forecasting_model(
                # For simple models without time features
                outputs = model(inputs)

            loss = criterion(outputs, targets)
            # Calculate loss
            if use_ps_loss:
                loss, loss_dict = criterion(outputs, targets, model)
            else:
                loss = criterion(outputs, targets)

            # Backward pass and optimize
            loss.backward()
@@ -345,10 +403,26 @@ def train_forecasting_model(
            interval_loss += loss.item()

            if (batch_idx + 1) % log_interval == 0:
                print(f"Batch {batch_idx+1}/{len(dataloaders['train'])}, Loss: {loss.item():.4f}")
                # Compute and log the average loss over this interval
                avg_interval_loss = interval_loss / log_interval
                swanlab_run.log({"batch_train_loss": avg_interval_loss})
                if use_ps_loss and 'loss_dict' in locals():
                    print(f"Batch {batch_idx+1}/{len(dataloaders['train'])}, "
                          f"Total Loss: {loss.item():.4f}, "
                          f"MSE: {loss_dict['mse_loss']:.4f}, "
                          f"PS: {loss_dict['ps_loss']:.4f}")
                    # Log detailed loss components
                    swanlab_run.log({
                        "batch_total_loss": loss.item(),
                        "batch_mse_loss": loss_dict['mse_loss'],
                        "batch_ps_loss": loss_dict['ps_loss'],
                        "batch_corr_loss": loss_dict['corr_loss'],
                        "batch_var_loss": loss_dict['var_loss'],
                        "batch_mean_loss": loss_dict['mean_loss'],
                        "alpha": loss_dict['alpha'],
                        "beta": loss_dict['beta'],
                        "gamma": loss_dict['gamma']
                    })
                else:
                    print(f"Batch {batch_idx+1}/{len(dataloaders['train'])}, Loss: {loss.item():.4f}")
                    swanlab_run.log({"batch_train_loss": loss.item()})

                # Reset the interval loss for the next interval
                interval_loss = 0.0
@@ -360,6 +434,7 @@ def train_forecasting_model(
        model.eval()
        val_loss = 0.0
        val_mse = 0.0
        val_mse_criterion = nn.MSELoss()  # Always use MSE for validation metrics

        with torch.no_grad():
            for batch_data in dataloaders['val']:
@@ -381,18 +456,28 @@ def train_forecasting_model(
                    # For simple models without time features
                    outputs = model(inputs)

                # Calculate loss
                loss = criterion(outputs, targets)
                val_loss += loss.item()
                # Calculate training loss (PS or MSE)
                if use_ps_loss:
                    loss, _ = criterion(outputs, targets, model)
                    val_loss += loss.item()
                else:
                    loss = criterion(outputs, targets)
                    val_loss += loss.item()

                # Always calculate MSE for validation metrics
                mse_loss = val_mse_criterion(outputs, targets)
                val_mse += mse_loss.item()

        avg_val_loss = val_loss / len(dataloaders['val'])
        avg_val_mse = val_mse / len(dataloaders['val'])
        current_lr = optimizer.param_groups[0]['lr']

        # Log metrics
        metrics_dict = {
            "train_loss": avg_train_loss,
            "val_loss": avg_val_loss,
            "val_mse": avg_val_mse,
            "learning_rate": current_lr,
            "epoch_time": epoch_time
        }
@@ -402,6 +487,7 @@ def train_forecasting_model(
        print(f"Epoch {epoch+1}/{max_epochs}, "
              f"Train Loss: {avg_train_loss:.4f}, "
              f"Val Loss: {avg_val_loss:.4f}, "
              f"Val MSE: {avg_val_mse:.4f}, "
              f"LR: {current_lr:.6f}, "
              f"Time: {epoch_time:.2f}s")

@@ -416,16 +502,17 @@ def train_forecasting_model(
            print("Early stopping triggered")
            break

        # Step the learning rate scheduler
        scheduler.step()
        # Adjust learning rate using utils.tools function
        adjust_learning_rate(optimizer, epoch, lr_args)

    # Load the best model
    model.load_state_dict(torch.load(checkpoint_path))

    # Test evaluation on the best model
    # Test evaluation on the best model - Always use MSE for final evaluation
    model.eval()
    test_loss = 0.0
    test_mse = 0.0
    mse_criterion = nn.MSELoss()  # Always use MSE for test evaluation

    print("Evaluating on test set...")
    with torch.no_grad():
@@ -448,16 +535,16 @@ def train_forecasting_model(
                # For simple models without time features
                outputs = model(inputs)

            # Calculate loss
            loss = criterion(outputs, targets)
            test_loss += loss.item()
            # Always calculate MSE for test evaluation (for fair comparison)
            mse_loss = mse_criterion(outputs, targets)
            test_loss += mse_loss.item()

    test_loss /= len(dataloaders['test'])

    print(f"Test evaluation completed!")
    print(f"Test Loss (MSE): {test_loss:.6f}")

    # Final validation for consistency
    # Final validation for consistency - Always use MSE for final metrics
    model.eval()
    final_val_loss = 0.0
    final_val_mse = 0.0
@@ -482,25 +569,31 @@ def train_forecasting_model(
                # For simple models without time features
                outputs = model(inputs)

            # Calculate loss
            loss = criterion(outputs, targets)
            final_val_loss += loss.item()
            # Always calculate MSE for final validation (for fair comparison)
            mse_loss = mse_criterion(outputs, targets)
            final_val_loss += mse_loss.item()

    final_val_loss /= len(dataloaders['val'])

    print(f"Final validation loss: {final_val_loss:.6f}")
    print(f"Final validation MSE: {final_val_loss:.6f}")
    print(f"Final test MSE: {test_loss:.6f}")

    if use_ps_loss:
        print("Note: Model was trained with PS Loss but evaluated with MSE for fair comparison")

    # Log final test results to swanlab
    final_metrics = {
        "final_test_loss": test_loss,
        "final_val_loss": final_val_loss
        "final_test_mse": test_loss,
        "final_val_mse": final_val_loss
    }
    swanlab_run.log(final_metrics)

    # Update metrics with final values
    # Update metrics with final values (always MSE for comparison)
    metrics["final_val_loss"] = final_val_loss
    metrics["final_test_loss"] = test_loss
    metrics["final_val_mse"] = final_val_loss  # Same as final_val_loss since we use MSE
    metrics["final_test_mse"] = test_loss  # Same as final_test_loss since we use MSE

    # Finish the swanlab run
    swanlab_run.finish()
@@ -519,7 +612,8 @@ def train_classification_model(
    log_interval: int = 10,
    use_x_mark: bool = True,
    dataset_mode: str = "npz",
    dataflow_args = None
    dataflow_args = None,
    lr_adjust_strategy: str = "type1"
) -> Tuple[nn.Module, Dict[str, float]]:
    """
    Train a time series classification model
@@ -537,6 +631,7 @@ def train_classification_model(
        use_x_mark (bool): Whether to use time features (x_mark) from the data file
        dataset_mode (str): Dataset construction mode - "npz" or "dataflow"
        dataflow_args: Arguments object for dataflow mode (required if dataset_mode="dataflow")
        lr_adjust_strategy (str): Learning rate adjustment strategy - 'type1', 'type2', 'type3', 'sigmoid', 'constant', '3', '4', '5', '6'

    Returns:
        Tuple[nn.Module, Dict[str, float]]: Trained model and dictionary of evaluation metrics
@@ -567,7 +662,10 @@ def train_classification_model(
        dataloaders = create_data_loaders(
            data_path=data_path,
            batch_size=config.get('batch_size', 32),
            use_x_mark=use_x_mark
            use_x_mark=use_x_mark,
            num_workers=config.get('num_workers', 4),
            pin_memory=config.get('pin_memory', True),
            persistent_workers=config.get('persistent_workers', True)
        )

    # Construct the model
@@ -582,8 +680,11 @@ def train_classification_model(
        weight_decay=config.get('weight_decay', 1e-4)
    )

    # Add learning rate scheduler to halve LR after each epoch
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5)
    # Create args object for learning rate adjustment
    lr_args = dotdict({
        'learning_rate': config.get('learning_rate', 1e-3),
        'lradj': lr_adjust_strategy
    })

    # Initialize early stopping
    early_stopping = EarlyStopping(
@@ -722,8 +823,8 @@ def train_classification_model(
            print("Early stopping triggered")
            break

        # Step the learning rate scheduler
        scheduler.step()
        # Adjust learning rate using utils.tools function
        adjust_learning_rate(optimizer, epoch, lr_args)

    # Load the best model
    model.load_state_dict(torch.load(checkpoint_path))
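For reference, a sketch of how the extended train_forecasting_model signature might be invoked with the new options (the leading keyword arguments are assumed from the parallel train_diffusion.py trainer; the constructor, paths, and project name are placeholders):

from train.train import train_forecasting_model

model, metrics = train_forecasting_model(
    model_constructor=lambda: build_model(),   # hypothetical constructor
    data_path='./data/dataset.npz',            # placeholder NPZ path
    project_name='example_run',                # placeholder
    config={'batch_size': 32, 'learning_rate': 1e-3,
            'num_workers': 4, 'pin_memory': True, 'persistent_workers': True},
    use_x_mark=True,
    use_ps_loss=True, ps_lambda=5.0, patch_len_threshold=64, use_gdw=True,
    lr_adjust_strategy='sigmoid',
)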
408 train/train_diffusion.py Normal file
@@ -0,0 +1,408 @@
import os
import time
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import swanlab
from typing import Dict, Any, Optional, Callable, Union, Tuple
from dataflow import data_provider

class EarlyStopping:
    """Early stopping to stop training when validation performance doesn't improve."""
    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt'):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = float('inf')
        self.delta = delta
        self.path = path

    def __call__(self, val_loss, model):
        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            if self.verbose:
                print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        """Save model when validation loss decreases."""
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

class DatasetWrapperWithoutTimeFeatures(torch.utils.data.Dataset):
    """Wrapper to remove time features from dataflow datasets when use_x_mark=False"""
    def __init__(self, original_dataset):
        self.original_dataset = original_dataset

    def __getitem__(self, index):
        seq_x, seq_y, seq_x_mark, seq_y_mark = self.original_dataset[index]
        return seq_x, seq_y

    def __len__(self):
        return len(self.original_dataset)

    def inverse_transform(self, data):
        if hasattr(self.original_dataset, 'inverse_transform'):
            return self.original_dataset.inverse_transform(data)
        return data

def create_data_loaders_from_dataflow(args, use_x_mark: bool = True) -> Dict[str, DataLoader]:
    """Create PyTorch DataLoaders using dataflow data_provider"""
    train_data, _ = data_provider(args, flag='train')
    val_data, _ = data_provider(args, flag='val')
    test_data, _ = data_provider(args, flag='test')

    if not use_x_mark:
        train_data = DatasetWrapperWithoutTimeFeatures(train_data)
        val_data = DatasetWrapperWithoutTimeFeatures(val_data)
        test_data = DatasetWrapperWithoutTimeFeatures(test_data)

    train_shuffle = True
    val_shuffle = False
    test_shuffle = False

    train_drop_last = True
    val_drop_last = True
    test_drop_last = True

    batch_size = args.batch_size
    num_workers = args.num_workers

    train_loader = DataLoader(
        train_data, batch_size=batch_size, shuffle=train_shuffle,
        num_workers=num_workers, drop_last=train_drop_last
    )
    val_loader = DataLoader(
        val_data, batch_size=batch_size, shuffle=val_shuffle,
        num_workers=num_workers, drop_last=val_drop_last
    )
    test_loader = DataLoader(
        test_data, batch_size=batch_size, shuffle=test_shuffle,
        num_workers=num_workers, drop_last=test_drop_last
    )

    return {'train': train_loader, 'val': val_loader, 'test': test_loader}

def create_data_loaders(data_path: str, batch_size: int = 32, use_x_mark: bool = True) -> Dict[str, DataLoader]:
    """
    Create PyTorch DataLoaders from an NPZ file

    Args:
        data_path (str): Path to the NPZ file containing the data
        batch_size (int): Batch size for the DataLoaders
        use_x_mark (bool): Whether to use time features (x_mark) from the data file

    Returns:
        Dict[str, DataLoader]: Dictionary with train, val, and test DataLoaders
    """
    # Load data from NPZ file
    data = np.load(data_path, allow_pickle=True)
    train_x = data['train_x']
    train_y = data['train_y']
    val_x = data['val_x']
    val_y = data['val_y']
    test_x = data['test_x']
    test_y = data['test_y']

    # Load time features if available and needed
    if use_x_mark:
        train_x_mark = data.get('train_x_mark', None)
        train_y_mark = data.get('train_y_mark', None)
        val_x_mark = data.get('val_x_mark', None)
        val_y_mark = data.get('val_y_mark', None)
        test_x_mark = data.get('test_x_mark', None)
        test_y_mark = data.get('test_y_mark', None)
    else:
        train_x_mark = None
        train_y_mark = None
        val_x_mark = None
        val_y_mark = None
        test_x_mark = None
        test_y_mark = None

    # Convert to PyTorch tensors
    train_x = torch.FloatTensor(train_x)
    train_y = torch.FloatTensor(train_y)
    val_x = torch.FloatTensor(val_x)
    val_y = torch.FloatTensor(val_y)
    test_x = torch.FloatTensor(test_x)
    test_y = torch.FloatTensor(test_y)

    # Create datasets based on whether time features are available
    if train_x_mark is not None:
        train_x_mark = torch.FloatTensor(train_x_mark)
        train_y_mark = torch.FloatTensor(train_y_mark)
        val_x_mark = torch.FloatTensor(val_x_mark)
        val_y_mark = torch.FloatTensor(val_y_mark)
        test_x_mark = torch.FloatTensor(test_x_mark)
        test_y_mark = torch.FloatTensor(test_y_mark)

        train_dataset = TensorDataset(train_x, train_y, train_x_mark, train_y_mark)
        val_dataset = TensorDataset(val_x, val_y, val_x_mark, val_y_mark)
        test_dataset = TensorDataset(test_x, test_y, test_x_mark, test_y_mark)
    else:
        train_dataset = TensorDataset(train_x, train_y)
        val_dataset = TensorDataset(val_x, val_y)
        test_dataset = TensorDataset(test_x, test_y)

    # Create dataloaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return {
        'train': train_loader,
        'val': val_loader,
        'test': test_loader
    }

def train_diffusion_model(
    model_constructor: Callable,
    data_path: str,
    project_name: str,
    config: Dict[str, Any],
    device: Optional[str] = None,
    early_stopping_patience: int = 10,
    max_epochs: int = 100,
    checkpoint_dir: str = "./checkpoints",
    log_interval: int = 10,
) -> Tuple[nn.Module, Dict[str, float]]:
    """
    Train a Diffusion time series forecasting model using NPZ data loading
    """
    # Setup device
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Initialize swanlab for experiment tracking
    swanlab_run = swanlab.init(
        project=project_name,
        config=config,
    )

    # Create checkpoint directory if it doesn't exist
    os.makedirs(checkpoint_dir, exist_ok=True)
    checkpoint_path = os.path.join(checkpoint_dir, f"{project_name}.pt")

    # Create data loaders using NPZ files (following other models' pattern)
    dataloaders = create_data_loaders(
        data_path=data_path,
        batch_size=config.get('batch_size', 32),
        use_x_mark=False  # DiffusionTimeSeries doesn't use time features
    )

    # Construct the model
    model = model_constructor()
    model = model.to(device)

    print(f"Model created with {model.get_num_params():,} parameters")

    # Define optimizer for diffusion training
    optimizer = optim.Adam(
        model.parameters(),
        lr=config.get('learning_rate', 1e-4),
        weight_decay=config.get('weight_decay', 1e-4)
    )

    # Learning rate scheduler
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', patience=5, factor=0.5
    )

    # Initialize early stopping
    early_stopping = EarlyStopping(
        patience=early_stopping_patience,
        verbose=True,
        path=checkpoint_path
    )

    # Training loop
    best_val_loss = float('inf')
    metrics = {}

    for epoch in range(max_epochs):
        print(f"\nEpoch {epoch+1}/{max_epochs}")
        print("-" * 50)

        # Training phase
        model.train()
        train_loss = 0.0
        train_samples = 0

        interval_loss = 0.0
        start_time = time.time()

        for batch_idx, (seq_x, seq_y) in enumerate(dataloaders['train']):
            seq_x, seq_y = seq_x.to(device), seq_y.to(device)

            optimizer.zero_grad()

            # Diffusion training: model returns loss directly when y is provided
            loss = model(seq_x, seq_y)

            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            interval_loss += loss.item()
            train_samples += 1

            # Log at intervals
            if (batch_idx + 1) % log_interval == 0:
                elapsed_time = time.time() - start_time
                avg_interval_loss = interval_loss / log_interval

                print(f'  Batch [{batch_idx+1}/{len(dataloaders["train"])}] '
                      f'Loss: {avg_interval_loss:.6f} '
                      f'Time: {elapsed_time:.2f}s')

                # Log to swanlab
                swanlab.log({
                    'batch_loss': avg_interval_loss,
                    'batch': epoch * len(dataloaders['train']) + batch_idx,
                    'learning_rate': optimizer.param_groups[0]['lr']
                })

                interval_loss = 0.0
                start_time = time.time()

        avg_train_loss = train_loss / train_samples

        # Validation phase - Use faster sampling for validation
        model.eval()
        val_loss = 0.0
        val_samples = 0
        criterion = nn.MSELoss()

        print("  Validating...")
        with torch.no_grad():
            # Temporarily reduce diffusion steps for faster validation
            original_timesteps = model.diffusion.num_timesteps
            model.diffusion.num_timesteps = 200  # Much faster validation

            for batch_idx, (seq_x, seq_y) in enumerate(dataloaders['val']):
                seq_x, seq_y = seq_x.to(device), seq_y.to(device)

                # Generate predictions (inference mode with reduced steps)
                pred = model(seq_x)

                # Compute MSE loss for validation
                loss = criterion(pred, seq_y)
                val_loss += loss.item()
                val_samples += 1

                # Print validation progress for first epoch
                if epoch == 0 and (batch_idx + 1) % 50 == 0:
                    print(f"    Val Batch [{batch_idx+1}/{len(dataloaders['val'])}]")

                # Early break for very first epoch to speed up
                if epoch == 0 and batch_idx >= 100:  # Only validate on first 100 batches for first epoch
                    break

            # Restore original timesteps
            model.diffusion.num_timesteps = original_timesteps

        avg_val_loss = val_loss / val_samples

        # Learning rate scheduling
        scheduler.step(avg_val_loss)
        current_lr = optimizer.param_groups[0]['lr']

        print(f"  Train Loss: {avg_train_loss:.6f}")
        print(f"  Val Loss: {avg_val_loss:.6f}")
        print(f"  Learning Rate: {current_lr:.2e}")

        # Log to swanlab
        swanlab.log({
            'epoch': epoch + 1,
            'train_loss': avg_train_loss,
            'val_loss': avg_val_loss,
            'learning_rate': current_lr
        })

        # Early stopping check
        early_stopping(avg_val_loss, model)

        if early_stopping.early_stop:
            print(f"Early stopping at epoch {epoch + 1}")
            break

    # Load best model
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))

    # Final evaluation on test set
    print("\nEvaluating on test set...")
    model.eval()
    test_loss = 0.0
    test_samples = 0
    all_preds = []
    all_targets = []

    with torch.no_grad():
        # Use reduced timesteps for faster testing
        original_timesteps = model.diffusion.num_timesteps
        model.diffusion.num_timesteps = 200  # Faster but still good quality

        for batch_idx, (seq_x, seq_y) in enumerate(dataloaders['test']):
            seq_x, seq_y = seq_x.to(device), seq_y.to(device)

            pred = model(seq_x)

            loss = criterion(pred, seq_y)
            test_loss += loss.item()
            test_samples += 1

            all_preds.append(pred.cpu().numpy())
            all_targets.append(seq_y.cpu().numpy())

            # Print progress every 50 batches
            if (batch_idx + 1) % 50 == 0:
                print(f"  Test Batch [{batch_idx+1}/{len(dataloaders['test'])}]")

        # Restore original timesteps
        model.diffusion.num_timesteps = original_timesteps

    avg_test_loss = test_loss / test_samples

    # Calculate additional metrics
    all_preds = np.concatenate(all_preds, axis=0)
    all_targets = np.concatenate(all_targets, axis=0)

    mse = np.mean((all_preds - all_targets) ** 2)
    mae = np.mean(np.abs(all_preds - all_targets))
    rmse = np.sqrt(mse)

    metrics = {
        'test_mse': mse,
        'test_mae': mae,
        'test_rmse': rmse,
        'test_loss': avg_test_loss
    }

    print(f"Test Results:")
    print(f"  MSE: {mse:.6f}")
    print(f"  MAE: {mae:.6f}")
    print(f"  RMSE: {rmse:.6f}")

    # Log final results
    swanlab.log(metrics)

    swanlab.finish()

    return model, metrics
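Below is a sketch of calling the diffusion trainer defined above. The loop requires the constructed model to return a training loss from model(seq_x, seq_y), predictions from model(seq_x), and to expose get_num_params() and a diffusion.num_timesteps attribute, so a wrapper with that interface is assumed (hypothetical names, placeholder paths):

from train.train_diffusion import train_diffusion_model

model, metrics = train_diffusion_model(
    model_constructor=lambda: DiffusionWrapper(past_len=96, future_len=24, channels=7),  # hypothetical wrapper
    data_path='./data/dataset.npz',    # placeholder NPZ file with train/val/test splits
    project_name='diffusion_example',
    config={'batch_size': 32, 'learning_rate': 1e-4, 'weight_decay': 1e-4},
    max_epochs=50,
)
print(metrics['test_mse'], metrics['test_mae'], metrics['test_rmse'])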
329
train_xpatch_sparse_multi_datasets.py
Normal file
329
train_xpatch_sparse_multi_datasets.py
Normal file
@ -0,0 +1,329 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Training script for xPatch_SparseChannel model on multiple datasets.
|
||||
Supports Weather, Traffic, Electricity, Exchange, and ILI datasets with sigmoid learning rate adjustment.
|
||||
"""
|
||||
|
||||
import os
|
||||
import math
|
||||
import argparse
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from train.train import train_forecasting_model
|
||||
from models.xPatch_SparseChannel.xPatch import Model as xPatchSparseChannel
|
||||
|
||||
# Dataset configurations
|
||||
DATASET_CONFIGS = {
|
||||
'Weather': {
|
||||
'csv_file': 'weather.csv',
|
||||
'enc_in': 21,
|
||||
'batch_size': 2048,
|
||||
'learning_rate': 0.0005,
|
||||
'target': 'OT',
|
||||
'data_path': 'weather/weather.csv',
|
||||
'seq_len': 96,
|
||||
'pred_lengths': [96, 192, 336, 720]
|
||||
},
|
||||
'Traffic': {
|
||||
'csv_file': 'traffic.csv',
|
||||
'enc_in': 862,
|
||||
'batch_size': 32,
|
||||
'learning_rate': 0.002,
|
||||
'target': 'OT',
|
||||
'data_path': 'traffic/traffic.csv',
|
||||
'seq_len': 96,
|
||||
'pred_lengths': [96, 192, 336, 720]
|
||||
},
|
||||
'Electricity': {
|
||||
'csv_file': 'electricity.csv',
|
||||
'enc_in': 321,
|
||||
'batch_size': 32,
|
||||
'learning_rate': 0.001,
|
||||
'target': 'OT',
|
||||
'data_path': 'electricity/electricity.csv',
|
||||
'seq_len': 96,
|
||||
'pred_lengths': [96, 192, 336, 720]
|
||||
},
|
||||
'Exchange': {
|
||||
'csv_file': 'exchange_rate.csv',
|
||||
'enc_in': 8,
|
||||
'batch_size': 128,
|
||||
'learning_rate': 0.00005,
|
||||
'target': 'OT',
|
||||
'data_path': 'exchange_rate/exchange_rate.csv',
|
||||
'seq_len': 96,
|
||||
'pred_lengths': [96, 192, 336, 720]
|
||||
},
|
||||
'ILI': {
|
||||
'csv_file': 'national_illness.csv',
|
||||
'enc_in': 7,
|
||||
'batch_size': 32,
|
||||
'learning_rate': 0.01,
|
||||
'target': 'ot',
|
||||
'data_path': 'illness/national_illness.csv',
|
||||
'seq_len': 36,
|
||||
'pred_lengths': [24, 36, 48, 60]
|
||||
}
|
||||
}
|
||||
|
||||
class Args:
    """Configuration class for xPatch_SparseChannel model parameters."""
    def __init__(self, dataset_name, pred_len):
        dataset_config = DATASET_CONFIGS[dataset_name]

        # Model architecture parameters
        self.task_name = 'long_term_forecast'
        self.seq_len = dataset_config['seq_len'] # Use dataset-specific seq_len
        self.label_len = self.seq_len // 2 # Half of seq_len as label length
        self.pred_len = pred_len
        self.enc_in = dataset_config['enc_in']
        self.c_out = dataset_config['enc_in']

        # xPatch specific parameters from reference
        self.patch_len = 16 # patch length
        self.stride = 8 # stride
        self.padding_patch = 'end' # padding on the end

        # Moving Average parameters
        self.ma_type = 'ema' # moving average type
        self.alpha = 0.3 # alpha parameter for EMA
        self.beta = 0.3 # beta parameter for DEMA

        # RevIN normalization
        self.revin = 1 # RevIN; True 1 False 0

        # Time features (not used by xPatch but required by data loader)
        self.embed = 'timeF' # Time feature embedding type
        self.freq = 'h' # Frequency for time features (hourly)

        # Dataset specific parameters
        self.data = 'custom'
        self.root_path = './data/'
        self.data_path = dataset_config['data_path']
        self.features = 'M' # Multivariate prediction
        self.target = dataset_config['target'] # Target column
        self.train_only = False

        # Required for dataflow - will be set by config
        self.batch_size = dataset_config['batch_size']
        self.num_workers = 8 # Will be overridden by config

        print(f"xPatch_SparseChannel Model configuration for {dataset_name}:")
        print(f" - Input channels (C): {self.enc_in}")
        print(f" - Patch length: {self.patch_len}")
        print(f" - Stride: {self.stride}")
        print(f" - Sequence length: {self.seq_len}") # Now dataset-specific
        print(f" - Prediction length: {pred_len}")
        print(f" - Moving average type: {self.ma_type}")
        print(f" - Alpha: {self.alpha}")
        print(f" - Beta: {self.beta}")
        print(f" - RevIN: {self.revin}")
        print(f" - Target: {self.target}")
        print(f" - Batch size: {self.batch_size}")

def create_xpatch_sparse_model(args):
    """Create xPatch_SparseChannel model with given configuration."""
    def model_constructor():
        return xPatchSparseChannel(args)
    return model_constructor

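# Illustrative only: the sigmoid learning-rate adjustment selected via
# lr_adjust_strategy="sigmoid" below is implemented inside
# train.train.train_forecasting_model and is not shown in this file.
# A plausible shape for such a schedule is a smooth decay from the base LR:
#
#     def sigmoid_lr(base_lr, epoch, midpoint=10, steepness=0.5):
#         # LR falls smoothly from ~base_lr toward 0 around `midpoint`
#         return base_lr / (1 + math.exp(steepness * (epoch - midpoint)))
#
# The constants (midpoint, steepness) are assumptions for illustration.
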
def train_single_dataset(dataset_name, pred_len, model_args, cmd_args, use_ps_loss=True):
    """Train xPatch_SparseChannel on specified dataset with given prediction length."""

    dataset_config = DATASET_CONFIGS[dataset_name]

    # Update args for current prediction length
    model_args.pred_len = pred_len

    # Update dataflow parameters from command line args
    model_args.num_workers = cmd_args.num_workers

    # Create model constructor
    model_constructor = create_xpatch_sparse_model(model_args)

    # Training configuration with dataset-specific parameters
    config = {
        'learning_rate': dataset_config['learning_rate'], # Dataset-specific learning rate
        'batch_size': dataset_config['batch_size'], # Dataset-specific batch size
        'weight_decay': 1e-4,
        'dataset': dataset_name,
        'pred_len': pred_len,
        'seq_len': model_args.seq_len,
        'patch_len': model_args.patch_len,
        'stride': model_args.stride,
        'ma_type': model_args.ma_type,
        'use_ps_loss': use_ps_loss,
        'num_workers': cmd_args.num_workers,
        'pin_memory': True,
        'persistent_workers': True
    }

    # Project name for tracking
    loss_suffix = "_PSLoss" if use_ps_loss else "_MSE"
    project_name = f"xPatch_SparseChannel_{dataset_name}_pred{pred_len}{loss_suffix}_sigmoid"

    print(f"\n{'='*60}")
    print(f"Training {dataset_name} with prediction length {pred_len}")
    print(f"Model: xPatch_SparseChannel")
    print(f"Loss function: {'PS_Loss' if use_ps_loss else 'MSE'}")
    print(f"Learning rate: {dataset_config['learning_rate']}")
    print(f"Batch size: {dataset_config['batch_size']}")
    print(f"Features: {dataset_config['enc_in']}")
    print(f"Data path: {model_args.root_path}{model_args.data_path}")
    print(f"LR adjustment: sigmoid")
    print(f"{'='*60}")

    # Train the model
    try:
        model, metrics = train_forecasting_model(
            model_constructor=model_constructor,
            data_path=f"{model_args.root_path}{model_args.data_path}",
            project_name=project_name,
            config=config,
            early_stopping_patience=5,
            max_epochs=100,
            checkpoint_dir="./checkpoints",
            log_interval=50,
            use_x_mark=False, # xPatch_SparseChannel doesn't use time features
            use_ps_loss=use_ps_loss,
            ps_lambda=cmd_args.ps_lambda,
            patch_len_threshold=64,
            use_gdw=True,
            dataset_mode="dataflow",
            dataflow_args=model_args,
            lr_adjust_strategy="sigmoid" # Use sigmoid learning rate adjustment
        )

        print(f"Training completed for {project_name}")
        if use_ps_loss:
            print(f"Final validation MSE: {metrics.get('final_val_mse', 'N/A'):.6f}")
        else:
            print(f"Final validation MSE: {metrics.get('final_val_loss', 'N/A'):.6f}")

        return model, metrics

    except Exception as e:
        print(f"Error training {project_name}: {e}")
        import traceback
        traceback.print_exc()
        return None, None

def main():
    parser = argparse.ArgumentParser(description='Train xPatch_SparseChannel on multiple datasets with sigmoid LR adjustment')
    parser.add_argument('--datasets', nargs='+', type=str,
                        default=['Weather', 'Traffic', 'Electricity', 'Exchange', 'ILI'],
                        choices=['Weather', 'Traffic', 'Electricity', 'Exchange', 'ILI'],
                        help='List of datasets to train on')
    parser.add_argument('--use_ps_loss', action='store_true', default=True,
                        help='Use PS_Loss instead of MSE')
    parser.add_argument('--ps_lambda', type=float, default=5.0,
                        help='Weight for PS loss component')
    parser.add_argument('--device', type=str, default=None,
                        help='Device to use for training (cuda/cpu)')
    parser.add_argument('--num_workers', type=int, default=8,
                        help='Number of data loading workers')

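    # Note: because --use_ps_loss uses action='store_true' with default=True, it is
    # effectively always True as written; passing the flag changes nothing, and no
    # flag is defined here to switch back to plain MSE.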
    args = parser.parse_args()

    print("xPatch_SparseChannel Multi-Dataset Training Script with Sigmoid LR Adjustment")
    print("=" * 80)
    print(f"Datasets: {args.datasets}")
    print(f"Use PS_Loss: {args.use_ps_loss}")
    print(f"PS_Lambda: {args.ps_lambda}")
    print(f"Number of workers: {args.num_workers}")
    print(f"Learning rate adjustment: sigmoid")

    # Display dataset configurations
    print("\nDataset Configurations:")
    for dataset in args.datasets:
        config = DATASET_CONFIGS[dataset]
        print(f" {dataset}:")
        print(f" - Features: {config['enc_in']}")
        print(f" - Batch size: {config['batch_size']}")
        print(f" - Learning rate: {config['learning_rate']}")
        print(f" - Sequence length: {config['seq_len']}")
        print(f" - Prediction lengths: {config['pred_lengths']}")
        print(f" - Data path: {config['data_path']}")

    # Check if data files exist
    missing_datasets = []
    for dataset in args.datasets:
        config = DATASET_CONFIGS[dataset]
        data_path = f"./data/{config['data_path']}"
        if not os.path.exists(data_path):
            missing_datasets.append(f"{dataset}: '{data_path}'")

    if missing_datasets:
        print(f"\nError: The following dataset files were not found:")
        for missing in missing_datasets:
            print(f" - {missing}")
        print("Please ensure all dataset files are available in the data/ directory.")
        return

    # Set device
    if args.device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    else:
        device = args.device
    print(f"\nUsing device: {device}")

    # Training results storage
    all_results = {}

    # Train on each dataset
    for dataset in args.datasets:
        print(f"\n{'#'*80}")
        print(f"STARTING TRAINING ON {dataset.upper()} DATASET")
        print(f"{'#'*80}")

        all_results[dataset] = {}
        config = DATASET_CONFIGS[dataset]

        # Train on each prediction length for this dataset
        for pred_len in config['pred_lengths']: # Use dataset-specific prediction lengths
            # Create model configuration for current dataset
            model_args = Args(
                dataset_name=dataset,
                pred_len=pred_len
            )

            # Train the model
            model, metrics = train_single_dataset(
                dataset_name=dataset,
                pred_len=pred_len,
                model_args=model_args,
                cmd_args=args,
                use_ps_loss=args.use_ps_loss
            )

            # Store results
            all_results[dataset][pred_len] = {
                'model': model,
                'metrics': metrics,
                'data_path': f"./data/{config['data_path']}"
            }

    # Print comprehensive summary
    print("\n" + "=" * 100)
    print("COMPREHENSIVE TRAINING SUMMARY")
    print("=" * 100)

    for dataset in args.datasets:
        config = DATASET_CONFIGS[dataset]
        print(f"\n{dataset} (Features: {config['enc_in']}, Batch: {config['batch_size']}, LR: {config['learning_rate']}, Seq: {config['seq_len']}):")
        for pred_len in all_results[dataset]:
            result = all_results[dataset][pred_len]
            if result['metrics'] is not None:
                if args.use_ps_loss:
                    mse = result['metrics'].get('final_val_mse', 'N/A')
                else:
                    mse = result['metrics'].get('final_val_loss', 'N/A')
                print(f" Pred Length {pred_len}: MSE = {mse}")
            else:
                print(f" Pred Length {pred_len}: Training failed")

    print(f"\nAll models saved in: ./checkpoints/")
    print("All datasets training completed!")

if __name__ == "__main__":
    main()