feat: add DiffusionTimeSeries and iTransformer models, introduce xPatch_SparseChannel

This commit is contained in:
game-loader
2025-08-26 20:53:35 +08:00
parent 44bd5c8f29
commit c3713f5c0b
11 changed files with 1528 additions and 41 deletions

View File

@ -273,15 +273,21 @@ class Dataset_Custom(Dataset):
self.data_stamp = data_stamp
def __getitem__(self, index):
# 1. Define the start and end positions of the input sequence seq_x
s_begin = index
s_end = s_begin + self.seq_len
r_begin = s_end - self.label_len
r_end = r_begin + self.label_len + self.pred_len
# 2. Define the start and end positions of the target sequence seq_y
# seq_y starts (r_begin) exactly where seq_x ends (s_end)
r_begin = s_end
# seq_y ends (r_end) at its start plus the prediction length (pred_len)
r_end = r_begin + self.pred_len
# 3. Slice the data according to these positions
seq_x = self.data_x[s_begin:s_end]
seq_y = self.data_y[r_begin:r_end]
seq_x_mark = self.data_stamp[s_begin:s_end]
seq_y_mark = self.data_stamp[r_begin:r_end]
seq_x = seq_x.astype('float32')
seq_y = seq_y.astype('float32')
return seq_x, seq_y, seq_x_mark, seq_y_mark
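For reference, a minimal sketch of the new windowing with toy numbers (seq_len=8, pred_len=4; the index arithmetic mirrors the fields above):

import numpy as np

data = np.arange(20, dtype='float32').reshape(-1, 1)  # toy series, 1 channel
seq_len, pred_len = 8, 4
index = 3

s_begin, s_end = index, index + seq_len      # input window  [3, 11)
r_begin, r_end = s_end, s_end + pred_len     # target window [11, 15)

seq_x = data[s_begin:s_end]   # shape (8, 1)
seq_y = data[r_begin:r_end]   # shape (4, 1), starts right after seq_x

Note that with this change seq_y no longer carries the label_len overlap that decoder-style models consume.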

View File

View File

@ -0,0 +1,323 @@
import math
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
# ----------------------- Utilities: build / rescale linear betas -----------------------
def make_linear_betas(T: int, beta_start=1e-4, beta_end=2e-2, device='cpu'):
return torch.linspace(beta_start, beta_end, T, device=device)
def cumprod_from_betas(betas: torch.Tensor):
alphas = 1.0 - betas
return torch.cumprod(alphas, dim=0) # shape [T]
@torch.no_grad()
def scale_betas_to_target_cumprod(betas: torch.Tensor, target_cumprod: float, max_scale: float = 100.0):
"""
Given betas[1..T], find a scale factor s > 0 such that prod(1 - s*beta_i) = target_cumprod.
Searched by bisection on (0, s_max), keeping 0 < s*beta_i < 1.
"""
device = betas.device
eps = 1e-12
s_low = 0.0
s_high = min(max_scale, (1.0 - 1e-6) / (betas.max().item() + eps)) # keep 1 - s*beta > 0
def cumprod_with_scale(s: float):
a = (1.0 - betas * s).clamp(min=1e-6, max=1.0-1e-6)
return torch.cumprod(a, dim=0)[-1].item()
# If the unscaled betas already hit the target, return them unchanged
base = cumprod_with_scale(1.0)
if abs(base - target_cumprod) / max(target_cumprod, 1e-12) < 1e-6:
return betas
# The target is monotonically reachable within (0, s_high); bisect
for _ in range(60):
mid = 0.5 * (s_low + s_high)
val = cumprod_with_scale(mid)
if val > target_cumprod:
# scale s too small (noise too weak), need a larger s
s_low = mid
else:
s_high = mid
s_best = 0.5 * (s_low + s_high)
return (betas * s_best).clamp(min=1e-8, max=1-1e-6)
# ------------------------------ DiT Blocks --------------------------------
class DiTBlock(nn.Module):
def __init__(self, dim: int, heads: int, mlp_ratio=4.0):
super().__init__()
self.ln1 = nn.LayerNorm(dim)
self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
self.ln2 = nn.LayerNorm(dim)
self.mlp = nn.Sequential(
nn.Linear(dim, int(dim * mlp_ratio)),
nn.GELU(),
nn.Linear(int(dim * mlp_ratio), dim),
)
def forward(self, x):
# x: [B, C_tokens, D]
x_norm = self.ln1(x)
h = self.attn(x_norm, x_norm, x_norm)[0]
x = x + h
x = x + self.mlp(self.ln2(x))
return x
class DiTChannelTokens(nn.Module):
"""
Token = one channel (variable).
For each channel the input is a length-L time vector; two projections are used:
- W_x : projects the time vector of x_t into the token vector
- W_n : projects the noise-level time vector (ā or b̄ from the schedule) into a token bias
Note: no learnable t-embedding is used; the noise condition is determined entirely by the noise map.
"""
def __init__(self, L: int, C: int, dim: int = 256, depth: int = 8, heads: int = 8):
super().__init__()
self.L = L
self.C = C
self.dim = dim
# Channel embedding (optional, distinguishes variables)
self.channel_embed = nn.Parameter(torch.randn(C, dim) * 0.02)
# Project each channel's time series into a token
self.proj_x = nn.Linear(L, dim, bias=False)
# Project each channel's per-timestep noise level (e.g. [sqrt(ā), sqrt(1-ā)] concatenated, then one linear layer)
self.proj_noise = nn.Linear(L, dim, bias=True)
self.blocks = nn.ModuleList([DiTBlock(dim, heads) for _ in range(depth)])
self.ln_f = nn.LayerNorm(dim)
# Project back to time length L to predict ε (independent projection per channel)
self.head = nn.Linear(dim, L, bias=False)
def forward(self, x_t: torch.Tensor, noise_feat: torch.Tensor):
"""
x_t : [B, L, C]
noise_feat: [B, L, C] (recommended: pass sqrt(ā), or merge concatenated features into the L dimension first; a single projection is enough here)
Returns ε̂ : [B, L, C]
"""
B, L, C = x_t.shape
assert L == self.L and C == self.C
# Map each channel to a token
# Reshape (B, L, C) to (B, C, L), then apply the linear projections
x_tc = x_t.permute(0, 2, 1) # [B, C, L]
n_tc = noise_feat.permute(0, 2, 1) # [B, C, L]
tok = self.proj_x(x_tc) + self.proj_noise(n_tc) # [B, C, D]
tok = tok + self.channel_embed.unsqueeze(0) # broadcast [1, C, D]
for blk in self.blocks:
tok = blk(tok) # [B, C, D]
tok = self.ln_f(tok)
out = self.head(tok) # [B, C, L]
eps_pred = out.permute(0, 2, 1) # [B, L, C]
return eps_pred
# ----------------------- RAD two-phase diffusion (channel as token) -----------------------
class RADChannelDiT(nn.Module):
def __init__(self,
past_len: int,
future_len: int,
channels: int,
T: int = 1000,
T1_ratio: float = 0.7,
model_dim: int = 256,
depth: int = 8,
heads: int = 8,
beta_start: float = 1e-4,
beta_end: float = 2e-2,
use_cosine_target: bool = True):
"""
- Training: two phases (Phase-1 + Phase-2); t is sampled uniformly from [1..T]
- Inference: Phase-1 only (t: T1→1), updating only the future region
- Token = channel; each token sees the full time axis plus the noise-level time vector
"""
super().__init__()
self.P = past_len
self.H = future_len
self.C = channels
self.L = past_len + future_len
self.T = T
self.T1 = max(1, int(T * T1_ratio))
self.T2 = T - self.T1
assert self.T2 >= 1, "T1_ratio is too large; leave at least 1 step for Phase-2"
device = torch.device('cpu')
# Target ā_T (used to normalize the two linear schedules to the same final noise level)
if use_cosine_target:
# Follow the cosine schedule to obtain a single global target ā_T
steps = T + 1
x = torch.linspace(0, T, steps, dtype=torch.float64)
s = 0.008
alphas_cum = torch.cos(((x / T) + s) / (1 + s) * math.pi / 2) ** 2
alphas_cum = alphas_cum / alphas_cum[0]
a_bar_target_T = float(alphas_cum[-1])
else:
# Directly use the result of the DDPM linear betas as the target
betas_full = make_linear_betas(T, beta_start, beta_end, device)
a_bar_target_T = cumprod_from_betas(betas_full)[-1].item()
# Raw linear betas for Phase-1 & Phase-2
betas1 = make_linear_betas(self.T1, beta_start, beta_end, device)
betas2 = make_linear_betas(self.T2, beta_start, beta_end, device)
# First compute ā1[T1], ā2[T2] without scaling
a_bar1 = cumprod_from_betas(betas1) # shape [T1]
a_bar2 = cumprod_from_betas(betas2) # shape [T2]
# Rescale the Phase-2 betas so that ā1[T1] * ā2'[T2] = target ā_T
target_a2 = a_bar_target_T / (a_bar1[-1].item() + 1e-12)
betas2 = scale_betas_to_target_cumprod(betas2, target_a2)
# Recompute
# a_bar1 = cumprod_from_betas(betas1).float() # [T1]
a_bar2 = cumprod_from_betas(betas2).float() # [T2]
self.register_buffer("betas1", betas1.float())
self.register_buffer("betas2", betas2.float())
self.register_buffer("alphas1", 1.0 - betas1.float())
self.register_buffer("alphas2", 1.0 - betas2.float())
self.register_buffer("a_bar1", a_bar1)
self.register_buffer("a_bar2", a_bar2)
self.register_buffer("a_bar_target_T", torch.tensor(a_bar_target_T, dtype=torch.float32))
# Backbone: token = channel
self.backbone = DiTChannelTokens(L=self.L, C=self.C, dim=model_dim, depth=depth, heads=heads)
# ------------------------ Internal: build mask & ā_{t,i} ------------------------
def _mask_future(self, B, device):
# mask: future region = 1, history = 0, shape [B, L, C] (aligned with the network input [B, L, C])
m = torch.zeros(B, self.L, self.C, device=device)
m[:, self.P:, :] = 1.0
return m
def _a_bar_map_at_t(self, t_scalar: int, B: int, device, mask_future: torch.Tensor):
"""
Build the per-element ā_{t,i}, shape [B, L, C]:
- If t <= T1: the future region uses ā1[t], the history region = 1
- If t >  T1: the future region is fixed at ā1[T1], the history region uses ā2[t-T1]
"""
if t_scalar <= self.T1:
a_future = self.a_bar1[t_scalar - 1] # zero-based indexing
a_past = torch.tensor(1.0, device=device)
else:
a_future = self.a_bar1[-1]
a_past = self.a_bar2[t_scalar - self.T1 - 1]
a_future_map = torch.full((B, self.L, self.C), float(a_future.item()), device=device)
a_past_map = torch.full((B, self.L, self.C), float(a_past.item()), device=device)
a_map = a_past_map * (1 - mask_future) + a_future_map * mask_future
return a_map # [B, L, C]
# ----------------------------- Forward (training) -----------------------------
def forward(self, x_hist: torch.Tensor, x_future: torch.Tensor) -> Tuple[torch.Tensor, dict]:
"""
x_hist : [B, P, C]
x_future : [B, H, C]
Training: sample t ∈ [1..T], build the two-phase ā_{t,i}, apply marginal noising to get x_t, and predict ε with the channel-token DiT.
"""
B = x_hist.size(0)
device = x_hist.device
x0 = torch.cat([x_hist, x_future], dim=1) # [B, L, C]
# Sample the training step t (1..T)
t = torch.randint(1, self.T + 1, (B,), device=device, dtype=torch.long)
# Build the mask and the per-element ā_{t,i}
mask_fut = self._mask_future(B, device) # [B, L, C]
# Build ā_{t,i} per sample (each sample has its own t; use a loop or a vectorized trick; B is usually small, so a for loop is fine)
a_bar_map = torch.stack([self._a_bar_map_at_t(int(tt.item()), 1, device, mask_fut[0:1])
for tt in t], dim=0).squeeze(1) # [B,L,C]
# Marginal noising
eps = torch.randn_like(x0) # [B,L,C]
x_t = a_bar_map.sqrt() * x0 + (1.0 - a_bar_map).sqrt() * eps
# Spatial noise embedding (determined entirely by the schedule)
# Pass the per-element √ā and √(1-ā) (or either one); √ā is used here
noise_feat = a_bar_map.sqrt() # [B,L,C]
# Predict ε
eps_pred = self.backbone(x_t, noise_feat) # [B,L,C]
loss = F.mse_loss(eps_pred, eps)
return loss, {'t_mean': t.float().mean().item()}
# ----------------------------- Sampling (inference) -----------------------------
@torch.no_grad()
def sample(self, x_hist: torch.Tensor, steps: Optional[int] = None) -> torch.Tensor:
"""
Phase-1 inference only (t = T1..1); only the future region is updated, the history keeps its observed values.
x_hist : [B,P,C]
return : [B,H,C]
"""
B = x_hist.size(0)
device = x_hist.device
mask_fut = self._mask_future(B, device) # [B,L,C]
# Initialize x: history = observations, future = Gaussian noise
x = torch.zeros(B, self.L, self.C, device=device)
x[:, :self.P, :] = x_hist
x[:, self.P:, :] = torch.randn(B, self.H, self.C, device=device)
# Support sub-sampling: uniformly down-sample [T1..1] to `steps` timesteps
T1 = self.T1
steps = steps if steps is not None else T1
steps = max(1, min(steps, T1))
ts = torch.linspace(T1, 1, steps, device=device).long().tolist()
# The DDPM update needs α_t, β_t (defined for the future region only)
alphas1 = self.alphas1 # [T1]
betas1 = self.betas1
a_bar1 = self.a_bar1
for idx, t_scalar in enumerate(ts):
# Current ā_{t,i} (Phase-1): history = 1, future = ā1[t]
a_bar_map = self._a_bar_map_at_t(int(t_scalar), B, device, mask_fut) # [B,L,C]
# Network conditioning: use √ā as the noise embedding
noise_feat = a_bar_map.sqrt()
# Predict ε
eps_pred = self.backbone(x, noise_feat) # [B,L,C]
# One DDPM step on the future region (the history region keeps its values)
# Standard DDPM formulas (all positions in the future region share the same α_t, β_t)
t_idx = t_scalar - 1
alpha_t = alphas1[t_idx] # scalar
beta_t = betas1[t_idx]
a_bar_t = a_bar1[t_idx]
if t_scalar > 1:
a_bar_prev = a_bar1[t_idx - 1]
else:
a_bar_prev = torch.tensor(1.0, device=device)
# x0 prediction (only used to derive the mean; the μ formula can also be used directly)
x0_pred = (x - (1.0 - a_bar_t).sqrt() * eps_pred) / (a_bar_t.sqrt() + 1e-8)
# Mean: μ_t = 1/sqrt(α_t) * (x_t - β_t / sqrt(1 - ā_t) * ε̂)
mean = (x - (beta_t / (1.0 - a_bar_t).sqrt()) * eps_pred) / (alpha_t.sqrt() + 1e-8)
# Sampling noise
if t_scalar > 1:
z = torch.randn_like(x)
else:
z = torch.zeros_like(x)
# Variance term (DDPM): σ_t = sqrt(β_t)
x_next = mean + z * beta_t.sqrt()
# Replace the future region only
x = x * (1 - mask_fut) + x_next * mask_fut
# Force the history back to the observations
x[:, :self.P, :] = x_hist
return x[:, self.P:, :] # [B,H,C]
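A minimal usage sketch for the module above (argument names follow the class signature; the tensors are random placeholders):

import torch

model = RADChannelDiT(past_len=96, future_len=24, channels=7,
                      T=1000, T1_ratio=0.7, model_dim=128, depth=4, heads=8)

x_hist = torch.randn(8, 96, 7)    # [B, P, C]
x_future = torch.randn(8, 24, 7)  # [B, H, C]

loss, info = model(x_hist, x_future)  # training: two-phase marginal noising + ε-MSE
loss.backward()

y_hat = model.sample(x_hist, steps=50)  # inference: Phase-1 only, returns [B, H, C]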

View File

@ -0,0 +1,132 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from layers.Transformer_EncDec import Encoder, EncoderLayer
from layers.SelfAttention_Family import FullAttention, AttentionLayer
from layers.Embed import DataEmbedding_inverted
import numpy as np
class Model(nn.Module):
"""
Paper link: https://arxiv.org/abs/2310.06625
"""
def __init__(self, configs):
super(Model, self).__init__()
self.task_name = configs.task_name
self.seq_len = configs.seq_len
self.pred_len = configs.pred_len
# Embedding
self.enc_embedding = DataEmbedding_inverted(configs.seq_len, configs.d_model, configs.embed, configs.freq,
configs.dropout)
# Encoder
self.encoder = Encoder(
[
EncoderLayer(
AttentionLayer(
FullAttention(False, configs.factor, attention_dropout=configs.dropout,
output_attention=False), configs.d_model, configs.n_heads),
configs.d_model,
configs.d_ff,
dropout=configs.dropout,
activation=configs.activation
) for l in range(configs.e_layers)
],
norm_layer=torch.nn.LayerNorm(configs.d_model)
)
# Decoder
if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
self.projection = nn.Linear(configs.d_model, configs.pred_len, bias=True)
if self.task_name == 'imputation':
self.projection = nn.Linear(configs.d_model, configs.seq_len, bias=True)
if self.task_name == 'anomaly_detection':
self.projection = nn.Linear(configs.d_model, configs.seq_len, bias=True)
if self.task_name == 'classification':
self.act = F.gelu
self.dropout = nn.Dropout(configs.dropout)
self.projection = nn.Linear(configs.d_model * configs.enc_in, configs.num_class)
def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
# Normalization from Non-stationary Transformer
means = x_enc.mean(1, keepdim=True).detach()
x_enc = x_enc - means
stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
x_enc /= stdev
_, _, N = x_enc.shape
# Embedding
enc_out = self.enc_embedding(x_enc, x_mark_enc)
enc_out, attns = self.encoder(enc_out, attn_mask=None)
dec_out = self.projection(enc_out).permute(0, 2, 1)[:, :, :N]
# De-Normalization from Non-stationary Transformer
dec_out = dec_out * (stdev[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1))
dec_out = dec_out + (means[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1))
return dec_out
def imputation(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask):
# Normalization from Non-stationary Transformer
means = x_enc.mean(1, keepdim=True).detach()
x_enc = x_enc - means
stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
x_enc /= stdev
_, L, N = x_enc.shape
# Embedding
enc_out = self.enc_embedding(x_enc, x_mark_enc)
enc_out, attns = self.encoder(enc_out, attn_mask=None)
dec_out = self.projection(enc_out).permute(0, 2, 1)[:, :, :N]
# De-Normalization from Non-stationary Transformer
dec_out = dec_out * (stdev[:, 0, :].unsqueeze(1).repeat(1, L, 1))
dec_out = dec_out + (means[:, 0, :].unsqueeze(1).repeat(1, L, 1))
return dec_out
def anomaly_detection(self, x_enc):
# Normalization from Non-stationary Transformer
means = x_enc.mean(1, keepdim=True).detach()
x_enc = x_enc - means
stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
x_enc /= stdev
_, L, N = x_enc.shape
# Embedding
enc_out = self.enc_embedding(x_enc, None)
enc_out, attns = self.encoder(enc_out, attn_mask=None)
dec_out = self.projection(enc_out).permute(0, 2, 1)[:, :, :N]
# De-Normalization from Non-stationary Transformer
dec_out = dec_out * (stdev[:, 0, :].unsqueeze(1).repeat(1, L, 1))
dec_out = dec_out + (means[:, 0, :].unsqueeze(1).repeat(1, L, 1))
return dec_out
def classification(self, x_enc, x_mark_enc):
# Embedding
enc_out = self.enc_embedding(x_enc, None)
enc_out, attns = self.encoder(enc_out, attn_mask=None)
# Output
output = self.act(enc_out) # the output transformer encoder/decoder embeddings don't include non-linearity
output = self.dropout(output)
output = output.reshape(output.shape[0], -1) # (batch_size, c_in * d_model)
output = self.projection(output) # (batch_size, num_classes)
return output
def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
return dec_out[:, -self.pred_len:, :] # [B, L, D]
if self.task_name == 'imputation':
dec_out = self.imputation(x_enc, x_mark_enc, x_dec, x_mark_dec, mask)
return dec_out # [B, L, D]
if self.task_name == 'anomaly_detection':
dec_out = self.anomaly_detection(x_enc)
return dec_out # [B, L, D]
if self.task_name == 'classification':
dec_out = self.classification(x_enc, x_mark_enc)
return dec_out # [B, N]
return None
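A hedged usage sketch for forecasting, assuming the layers package imports above resolve; the config attributes below are exactly the ones the class reads, with illustrative values:

from types import SimpleNamespace
import torch

configs = SimpleNamespace(
    task_name='long_term_forecast', seq_len=96, pred_len=24,
    d_model=128, n_heads=8, e_layers=2, d_ff=256,
    factor=1, dropout=0.1, activation='gelu',
    embed='timeF', freq='h', enc_in=7, num_class=2)

model = Model(configs)
x_enc = torch.randn(4, 96, 7)       # [B, L, C]
x_mark_enc = torch.randn(4, 96, 4)  # time features (4 per step for freq='h')
out = model(x_enc, x_mark_enc, None, None)  # [B, pred_len, C]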

View File

@ -0,0 +1,3 @@
from .xPatch import Model
__all__ = ['Model']

View File

@ -0,0 +1,66 @@
"""
SeasonPatch = PatchTST (CI) + ChannelGraphMixer + linear prediction head
"""
import torch
import torch.nn as nn
from layers.PatchTST_layers import positional_encoding # already exists
from layers.mixer import HierarchicalGraphMixer # newly created
# ------ PatchTST CI encoder (equivalent to the official implementation, but without the head so the Mixer can be inserted) ------
class _Encoder(nn.Module):
def __init__(self, c_in, seq_len, patch_len, stride,
d_model=128, n_layers=3, n_heads=16, dropout=0.):
super().__init__()
self.patch_len = patch_len
self.stride = stride
self.patch_num = (seq_len - patch_len)//stride + 1
self.proj = nn.Linear(patch_len, d_model)
self.pos = positional_encoding('zeros', True, self.patch_num, d_model)
self.drop = nn.Dropout(dropout)
from layers.PatchTST_backbone import TSTEncoder # same name as the official implementation
self.encoder = TSTEncoder(self.patch_num, d_model, n_heads,
d_ff=d_model*2, dropout=dropout,
n_layers=n_layers, norm='LayerNorm')
def forward(self, x): # [B,C,L]
x = x.unfold(-1, self.patch_len, self.stride) # [B,C,patch_num,patch_len]
B,C,N,P = x.shape
z = self.proj(x) # [B,C,N,d]
z = z.contiguous().view(B*C, N, -1) # [B*C,N,d]
z = self.drop(z + self.pos)
z = self.encoder(z) # [B*C,N,d]
return z.view(B,C,N,-1) # [B,C,N,d]
# ------------------------- SeasonPatch -----------------------------
class SeasonPatch(nn.Module):
def __init__(self,
c_in: int,
seq_len: int,
pred_len: int,
patch_len: int,
stride: int,
k_graph: int = 5,
d_model: int = 128,
revin: bool = True):
super().__init__()
self.encoder = _Encoder(c_in, seq_len, patch_len, stride,
d_model=d_model)
self.mixer = HierarchicalGraphMixer(c_in, dim=d_model, k=k_graph)
# Calculate actual number of patches
patch_num = (seq_len - patch_len) // stride + 1
self.head = nn.Linear(patch_num * d_model, pred_len)
def forward(self, x): # x [B,L,C]
x = x.permute(0,2,1) # → [B,C,L]
z = self.encoder(x) # [B,C,N,d]
B,C,N,D = z.shape
# Channel-independent encoding -> sparse cross-channel injection
z_mix = self.mixer(z).view(B,C,N*D) # [B,C,N*D]
y_pred = self.head(z_mix) # [B,C,T]
return y_pred
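A shape-check sketch for SeasonPatch, assuming layers.mixer.HierarchicalGraphMixer and layers.PatchTST_backbone.TSTEncoder are importable as in the file above:

import torch

season = SeasonPatch(c_in=7, seq_len=96, pred_len=24,
                     patch_len=16, stride=8, k_graph=5, d_model=128)

x = torch.randn(2, 96, 7)  # [B, L, C]
y = season(x)              # [B, C, pred_len] (channel-first output)
assert y.shape == (2, 7, 24)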

View File

@ -0,0 +1,56 @@
import torch
import torch.nn as nn
import math
from layers.decomp import DECOMP
from .xPatch_SparseChannel import Network
from layers.revin import RevIN
class Model(nn.Module):
def __init__(self, configs):
super(Model, self).__init__()
# Parameters
seq_len = configs.seq_len # lookback window L
pred_len = configs.pred_len # prediction length (96, 192, 336, 720)
c_in = configs.enc_in # input channels
# Patching
patch_len = configs.patch_len
stride = configs.stride
padding_patch = configs.padding_patch
# Normalization
self.revin = configs.revin
self.revin_layer = RevIN(c_in,affine=True,subtract_last=False)
# Moving Average
self.ma_type = configs.ma_type
alpha = configs.alpha # smoothing factor for EMA (Exponential Moving Average)
beta = configs.beta # smoothing factor for DEMA (Double Exponential Moving Average)
self.decomp = DECOMP(self.ma_type, alpha, beta)
self.net = Network(seq_len, pred_len, patch_len, stride, padding_patch,c_in)
# self.net_mlp = NetworkMLP(seq_len, pred_len) # For ablation study with MLP-only stream
# self.net_cnn = NetworkCNN(seq_len, pred_len, patch_len, stride, padding_patch) # For ablation study with CNN-only stream
def forward(self, x):
# x: [Batch, Input, Channel]
# Normalization
if self.revin:
x = self.revin_layer(x, 'norm')
if self.ma_type == 'reg': # If no decomposition, directly pass the input to the network
x = self.net(x, x)
# x = self.net_mlp(x) # For ablation study with MLP-only stream
# x = self.net_cnn(x) # For ablation study with CNN-only stream
else:
seasonal_init, trend_init = self.decomp(x)
x = self.net(seasonal_init, trend_init)
# Denormalization
if self.revin:
x = self.revin_layer(x, 'denorm')
return x
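A hedged construction sketch for the wrapper above (the attributes listed are the ones the constructor reads; DECOMP, RevIN and Network come from the imports in this file):

from types import SimpleNamespace
import torch

configs = SimpleNamespace(
    seq_len=96, pred_len=24, enc_in=7,
    patch_len=16, stride=8, padding_patch='end',
    revin=1, ma_type='ema', alpha=0.3, beta=0.3)

model = Model(configs)
x = torch.randn(4, 96, 7)  # [Batch, Input, Channel]
y = model(x)               # [Batch, pred_len, Channel]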

View File

@ -0,0 +1,63 @@
import torch
from torch import nn
from .patchtst_ci import SeasonPatch # <<< new import
class Network(nn.Module):
"""
trend : the original MLP linear stream (kept unchanged)
season : SeasonPatch (PatchTST + Mixer)
"""
def __init__(self, seq_len, pred_len,
patch_len, stride, padding_patch,
c_in):
super().__init__()
# -------- Seasonal stream ---------------
self.season_net = SeasonPatch(c_in=c_in,
seq_len=seq_len,
pred_len=pred_len,
patch_len=patch_len,
stride=stride,
k_graph=5,
d_model=128)
# --------- Linear trend stream (original code unchanged) ----------
self.pred_len = pred_len
self.fc5 = nn.Linear(seq_len, pred_len * 4)
self.avgpool1 = nn.AvgPool1d(kernel_size=2)
self.ln1 = nn.LayerNorm(pred_len * 2)
self.fc6 = nn.Linear(pred_len * 2, pred_len)
self.avgpool2 = nn.AvgPool1d(kernel_size=2)
self.ln2 = nn.LayerNorm(pred_len // 2)
self.fc7 = nn.Linear(pred_len // 2, pred_len)
# Concatenate the stream outputs
self.fc_final = nn.Linear(pred_len * 2, pred_len)
# ---------------- forward --------------------
def forward(self, s, t):
# Input shape: [B,L,C]
B,L,C = s.shape
# ---------- Seasonality ------------
y_season = self.season_net(s) # [B,C,T]
# ---------- Trend (original MLP) ----------
t = t.permute(0,2,1).reshape(B*C, L) # [B*C,L]
t = self.fc5(t)
t = self.avgpool1(t)
t = self.ln1(t)
t = self.fc6(t)
t = self.avgpool2(t)
t = self.ln2(t)
t = self.fc7(t) # [B*C,T]
y_trend = t.view(B, C, -1) # [B,C,T]
# --------- Concatenate & output --------------
y = torch.cat([y_season, y_trend], dim=-1) # [B,C,2T]
y = self.fc_final(y) # [B,C,T]
y = y.permute(0,2,1) # [B,T,C]
return y
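A direct shape check for Network itself (it takes the seasonal and trend components separately, as produced by DECOMP):

import torch

net = Network(seq_len=96, pred_len=24, patch_len=16, stride=8,
              padding_patch='end', c_in=7)

s = torch.randn(2, 96, 7)  # seasonal component [B, L, C]
t = torch.randn(2, 96, 7)  # trend component    [B, L, C]
y = net(s, t)              # [B, pred_len, C]
assert y.shape == (2, 24, 7)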

View File

@ -8,6 +8,8 @@ from torch.utils.data import DataLoader, TensorDataset
import swanlab
from typing import Dict, Any, Optional, Callable, Union, Tuple
from dataflow import data_provider
from layers.ps_loss import PSLoss
from utils.tools import adjust_learning_rate, dotdict
class EarlyStopping:
"""Early stopping to stop training when validation performance doesn't improve."""
@ -138,7 +140,9 @@ def create_data_loaders_from_dataflow(args, use_x_mark: bool = True) -> Dict[str
'test': test_loader
}
def create_data_loaders(data_path: str, batch_size: int = 32, use_x_mark: bool = True) -> Dict[str, DataLoader]:
def create_data_loaders(data_path: str, batch_size: int = 32, use_x_mark: bool = True,
num_workers: int = 4, pin_memory: bool = True,
persistent_workers: bool = True) -> Dict[str, DataLoader]:
"""
Create PyTorch DataLoaders from an NPZ file
@ -146,6 +150,9 @@ def create_data_loaders(data_path: str, batch_size: int = 32, use_x_mark: bool =
data_path (str): Path to the NPZ file containing the data
batch_size (int): Batch size for the DataLoaders
use_x_mark (bool): Whether to use time features (x_mark) from the data file
num_workers (int): Number of worker processes for data loading
pin_memory (bool): Whether to pin memory for faster GPU transfer
persistent_workers (bool): Whether to keep workers alive between epochs
Returns:
Dict[str, DataLoader]: Dictionary with train, val, and test DataLoaders
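As a sketch, the new loader options can be driven from the training config; the keys mirror the config.get calls later in this file, and the NPZ path is a placeholder:

dataloaders = create_data_loaders(
    data_path='./data/dataset.npz',
    batch_size=64,
    use_x_mark=False,
    num_workers=4,
    pin_memory=True,
    persistent_workers=True,
)
train_loader = dataloaders['train']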
@ -200,10 +207,34 @@ def create_data_loaders(data_path: str, batch_size: int = 32, use_x_mark: bool =
val_dataset = TensorDataset(val_x, val_y)
test_dataset = TensorDataset(test_x, test_y)
# Create dataloaders
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
# Create dataloaders with performance optimizations
train_loader = DataLoader(
train_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=num_workers,
pin_memory=pin_memory,
persistent_workers=persistent_workers if num_workers > 0 else False,
drop_last=True # Drop incomplete batches for training
)
val_loader = DataLoader(
val_dataset,
batch_size=batch_size,
shuffle=False,
num_workers=num_workers,
pin_memory=pin_memory,
persistent_workers=persistent_workers if num_workers > 0 else False,
drop_last=False
)
test_loader = DataLoader(
test_dataset,
batch_size=batch_size,
shuffle=False,
num_workers=num_workers,
pin_memory=pin_memory,
persistent_workers=persistent_workers if num_workers > 0 else False,
drop_last=False
)
return {
'train': train_loader,
@ -223,7 +254,12 @@ def train_forecasting_model(
log_interval: int = 10,
use_x_mark: bool = True,
dataset_mode: str = "npz",
dataflow_args = None
dataflow_args = None,
use_ps_loss: bool = False,
ps_lambda: float = 5.0,
patch_len_threshold: int = 64,
use_gdw: bool = True,
lr_adjust_strategy: str = "type1"
) -> Tuple[nn.Module, Dict[str, float]]:
"""
Train a time series forecasting model
@ -241,6 +277,11 @@ def train_forecasting_model(
use_x_mark (bool): Whether to use time features (x_mark) from the data file
dataset_mode (str): Dataset construction mode - "npz" or "dataflow"
dataflow_args: Arguments object for dataflow mode (required if dataset_mode="dataflow")
use_ps_loss (bool): Whether to use Patch-wise Structural (PS) loss instead of MSE
ps_lambda (float): Weight for PS loss component when combined with MSE
patch_len_threshold (int): Maximum patch length for adaptive patching
use_gdw (bool): Whether to use Gradient-based Dynamic Weighting
lr_adjust_strategy (str): Learning rate adjustment strategy - 'type1', 'type2', 'type3', 'sigmoid', 'constant', '3', '4', '5', '6'
Returns:
Tuple[nn.Module, Dict[str, float]]: Trained model and dictionary of evaluation metrics
@ -271,7 +312,10 @@ def train_forecasting_model(
dataloaders = create_data_loaders(
data_path=data_path,
batch_size=config.get('batch_size', 32),
use_x_mark=use_x_mark
use_x_mark=use_x_mark,
num_workers=config.get('num_workers', 4),
pin_memory=config.get('pin_memory', True),
persistent_workers=config.get('persistent_workers', True)
)
# Construct the model
@ -279,14 +323,24 @@ def train_forecasting_model(
model = model.to(device)
# Define loss function and optimizer
criterion = nn.MSELoss()
if use_ps_loss:
criterion = PSLoss(
patch_len_threshold=patch_len_threshold,
lambda_ps=ps_lambda,
use_gdw=use_gdw
)
else:
criterion = nn.MSELoss()
optimizer = optim.Adam(
model.parameters(),
lr=config.get('learning_rate', 1e-3),
)
# Add learning rate scheduler to halve LR after each epoch
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)
# Create args object for learning rate adjustment
lr_args = dotdict({
'learning_rate': config.get('learning_rate', 1e-3),
'lradj': lr_adjust_strategy
})
# Initialize early stopping
early_stopping = EarlyStopping(
@ -334,7 +388,11 @@ def train_forecasting_model(
# For simple models without time features
outputs = model(inputs)
loss = criterion(outputs, targets)
# Calculate loss
if use_ps_loss:
loss, loss_dict = criterion(outputs, targets, model)
else:
loss = criterion(outputs, targets)
# Backward pass and optimize
loss.backward()
@ -345,10 +403,26 @@ def train_forecasting_model(
interval_loss += loss.item()
if (batch_idx + 1) % log_interval == 0:
print(f"Batch {batch_idx+1}/{len(dataloaders['train'])}, Loss: {loss.item():.4f}")
# Compute and log the average loss for this interval
avg_interval_loss = interval_loss / log_interval
swanlab_run.log({"batch_train_loss": avg_interval_loss})
if use_ps_loss and 'loss_dict' in locals():
print(f"Batch {batch_idx+1}/{len(dataloaders['train'])}, "
f"Total Loss: {loss.item():.4f}, "
f"MSE: {loss_dict['mse_loss']:.4f}, "
f"PS: {loss_dict['ps_loss']:.4f}")
# Log detailed loss components
swanlab_run.log({
"batch_total_loss": loss.item(),
"batch_mse_loss": loss_dict['mse_loss'],
"batch_ps_loss": loss_dict['ps_loss'],
"batch_corr_loss": loss_dict['corr_loss'],
"batch_var_loss": loss_dict['var_loss'],
"batch_mean_loss": loss_dict['mean_loss'],
"alpha": loss_dict['alpha'],
"beta": loss_dict['beta'],
"gamma": loss_dict['gamma']
})
else:
print(f"Batch {batch_idx+1}/{len(dataloaders['train'])}, Loss: {loss.item():.4f}")
swanlab_run.log({"batch_train_loss": loss.item()})
# Reset the interval loss for the next interval
interval_loss = 0.0
@ -360,6 +434,7 @@ def train_forecasting_model(
model.eval()
val_loss = 0.0
val_mse = 0.0
val_mse_criterion = nn.MSELoss() # Always use MSE for validation metrics
with torch.no_grad():
for batch_data in dataloaders['val']:
@ -381,18 +456,28 @@ def train_forecasting_model(
# For simple models without time features
outputs = model(inputs)
# Calculate loss
loss = criterion(outputs, targets)
val_loss += loss.item()
# Calculate training loss (PS or MSE)
if use_ps_loss:
loss, _ = criterion(outputs, targets, model)
val_loss += loss.item()
else:
loss = criterion(outputs, targets)
val_loss += loss.item()
# Always calculate MSE for validation metrics
mse_loss = val_mse_criterion(outputs, targets)
val_mse += mse_loss.item()
avg_val_loss = val_loss / len(dataloaders['val'])
avg_val_mse = val_mse / len(dataloaders['val'])
current_lr = optimizer.param_groups[0]['lr']
# Log metrics
metrics_dict = {
"train_loss": avg_train_loss,
"val_loss": avg_val_loss,
"val_mse": avg_val_mse,
"learning_rate": current_lr,
"epoch_time": epoch_time
}
@ -402,6 +487,7 @@ def train_forecasting_model(
print(f"Epoch {epoch+1}/{max_epochs}, "
f"Train Loss: {avg_train_loss:.4f}, "
f"Val Loss: {avg_val_loss:.4f}, "
f"Val MSE: {avg_val_mse:.4f}, "
f"LR: {current_lr:.6f}, "
f"Time: {epoch_time:.2f}s")
@ -416,16 +502,17 @@ def train_forecasting_model(
print("Early stopping triggered")
break
# Step the learning rate scheduler
scheduler.step()
# Adjust learning rate using utils.tools function
adjust_learning_rate(optimizer, epoch, lr_args)
# Load the best model
model.load_state_dict(torch.load(checkpoint_path))
# Test evaluation on the best model
# Test evaluation on the best model - Always use MSE for final evaluation
model.eval()
test_loss = 0.0
test_mse = 0.0
mse_criterion = nn.MSELoss() # Always use MSE for test evaluation
print("Evaluating on test set...")
with torch.no_grad():
@ -448,16 +535,16 @@ def train_forecasting_model(
# For simple models without time features
outputs = model(inputs)
# Calculate loss
loss = criterion(outputs, targets)
test_loss += loss.item()
# Always calculate MSE for test evaluation (for fair comparison)
mse_loss = mse_criterion(outputs, targets)
test_loss += mse_loss.item()
test_loss /= len(dataloaders['test'])
print(f"Test evaluation completed!")
print(f"Test Loss (MSE): {test_loss:.6f}")
# Final validation for consistency
# Final validation for consistency - Always use MSE for final metrics
model.eval()
final_val_loss = 0.0
final_val_mse = 0.0
@ -482,25 +569,31 @@ def train_forecasting_model(
# For simple models without time features
outputs = model(inputs)
# Calculate loss
loss = criterion(outputs, targets)
final_val_loss += loss.item()
# Always calculate MSE for final validation (for fair comparison)
mse_loss = mse_criterion(outputs, targets)
final_val_loss += mse_loss.item()
final_val_loss /= len(dataloaders['val'])
print(f"Final validation loss: {final_val_loss:.6f}")
print(f"Final validation MSE: {final_val_loss:.6f}")
print(f"Final test MSE: {test_loss:.6f}")
if use_ps_loss:
print("Note: Model was trained with PS Loss but evaluated with MSE for fair comparison")
# Log final test results to swanlab
final_metrics = {
"final_test_loss": test_loss,
"final_val_loss": final_val_loss
"final_test_mse": test_loss,
"final_val_mse": final_val_loss
}
swanlab_run.log(final_metrics)
# Update metrics with final values
# Update metrics with final values (always MSE for comparison)
metrics["final_val_loss"] = final_val_loss
metrics["final_test_loss"] = test_loss
metrics["final_val_mse"] = final_val_loss # Same as final_val_loss since we use MSE
metrics["final_test_mse"] = test_loss # Same as final_test_loss since we use MSE
# Finish the swanlab run
swanlab_run.finish()
@ -519,7 +612,8 @@ def train_classification_model(
log_interval: int = 10,
use_x_mark: bool = True,
dataset_mode: str = "npz",
dataflow_args = None
dataflow_args = None,
lr_adjust_strategy: str = "type1"
) -> Tuple[nn.Module, Dict[str, float]]:
"""
Train a time series classification model
@ -537,6 +631,7 @@ def train_classification_model(
use_x_mark (bool): Whether to use time features (x_mark) from the data file
dataset_mode (str): Dataset construction mode - "npz" or "dataflow"
dataflow_args: Arguments object for dataflow mode (required if dataset_mode="dataflow")
lr_adjust_strategy (str): Learning rate adjustment strategy - 'type1', 'type2', 'type3', 'sigmoid', 'constant', '3', '4', '5', '6'
Returns:
Tuple[nn.Module, Dict[str, float]]: Trained model and dictionary of evaluation metrics
@ -567,7 +662,10 @@ def train_classification_model(
dataloaders = create_data_loaders(
data_path=data_path,
batch_size=config.get('batch_size', 32),
use_x_mark=use_x_mark
use_x_mark=use_x_mark,
num_workers=config.get('num_workers', 4),
pin_memory=config.get('pin_memory', True),
persistent_workers=config.get('persistent_workers', True)
)
# Construct the model
@ -582,8 +680,11 @@ def train_classification_model(
weight_decay=config.get('weight_decay', 1e-4)
)
# Add learning rate scheduler to halve LR after each epoch
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5)
# Create args object for learning rate adjustment
lr_args = dotdict({
'learning_rate': config.get('learning_rate', 1e-3),
'lradj': lr_adjust_strategy
})
# Initialize early stopping
early_stopping = EarlyStopping(
@ -722,8 +823,8 @@ def train_classification_model(
print("Early stopping triggered")
break
# Step the learning rate scheduler
scheduler.step()
# Adjust learning rate using utils.tools function
adjust_learning_rate(optimizer, epoch, lr_args)
# Load the best model
model.load_state_dict(torch.load(checkpoint_path))

train/train_diffusion.py (new file, 408 lines)
View File

@ -0,0 +1,408 @@
import os
import time
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import swanlab
from typing import Dict, Any, Optional, Callable, Union, Tuple
from dataflow import data_provider
class EarlyStopping:
"""Early stopping to stop training when validation performance doesn't improve."""
def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt'):
self.patience = patience
self.verbose = verbose
self.counter = 0
self.best_score = None
self.early_stop = False
self.val_loss_min = float('inf')
self.delta = delta
self.path = path
def __call__(self, val_loss, model):
score = -val_loss
if self.best_score is None:
self.best_score = score
self.save_checkpoint(val_loss, model)
elif score < self.best_score + self.delta:
self.counter += 1
if self.verbose:
print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
if self.counter >= self.patience:
self.early_stop = True
else:
self.best_score = score
self.save_checkpoint(val_loss, model)
self.counter = 0
def save_checkpoint(self, val_loss, model):
"""Save model when validation loss decreases."""
if self.verbose:
print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model...')
torch.save(model.state_dict(), self.path)
self.val_loss_min = val_loss
class DatasetWrapperWithoutTimeFeatures(torch.utils.data.Dataset):
"""Wrapper to remove time features from dataflow datasets when use_x_mark=False"""
def __init__(self, original_dataset):
self.original_dataset = original_dataset
def __getitem__(self, index):
seq_x, seq_y, seq_x_mark, seq_y_mark = self.original_dataset[index]
return seq_x, seq_y
def __len__(self):
return len(self.original_dataset)
def inverse_transform(self, data):
if hasattr(self.original_dataset, 'inverse_transform'):
return self.original_dataset.inverse_transform(data)
return data
def create_data_loaders_from_dataflow(args, use_x_mark: bool = True) -> Dict[str, DataLoader]:
"""Create PyTorch DataLoaders using dataflow data_provider"""
train_data, _ = data_provider(args, flag='train')
val_data, _ = data_provider(args, flag='val')
test_data, _ = data_provider(args, flag='test')
if not use_x_mark:
train_data = DatasetWrapperWithoutTimeFeatures(train_data)
val_data = DatasetWrapperWithoutTimeFeatures(val_data)
test_data = DatasetWrapperWithoutTimeFeatures(test_data)
train_shuffle = True
val_shuffle = False
test_shuffle = False
train_drop_last = True
val_drop_last = True
test_drop_last = True
batch_size = args.batch_size
num_workers = args.num_workers
train_loader = DataLoader(
train_data, batch_size=batch_size, shuffle=train_shuffle,
num_workers=num_workers, drop_last=train_drop_last
)
val_loader = DataLoader(
val_data, batch_size=batch_size, shuffle=val_shuffle,
num_workers=num_workers, drop_last=val_drop_last
)
test_loader = DataLoader(
test_data, batch_size=batch_size, shuffle=test_shuffle,
num_workers=num_workers, drop_last=test_drop_last
)
return {'train': train_loader, 'val': val_loader, 'test': test_loader}
def create_data_loaders(data_path: str, batch_size: int = 32, use_x_mark: bool = True) -> Dict[str, DataLoader]:
"""
Create PyTorch DataLoaders from an NPZ file
Args:
data_path (str): Path to the NPZ file containing the data
batch_size (int): Batch size for the DataLoaders
use_x_mark (bool): Whether to use time features (x_mark) from the data file
Returns:
Dict[str, DataLoader]: Dictionary with train, val, and test DataLoaders
"""
# Load data from NPZ file
data = np.load(data_path, allow_pickle=True)
train_x = data['train_x']
train_y = data['train_y']
val_x = data['val_x']
val_y = data['val_y']
test_x = data['test_x']
test_y = data['test_y']
# Load time features if available and needed
if use_x_mark:
train_x_mark = data.get('train_x_mark', None)
train_y_mark = data.get('train_y_mark', None)
val_x_mark = data.get('val_x_mark', None)
val_y_mark = data.get('val_y_mark', None)
test_x_mark = data.get('test_x_mark', None)
test_y_mark = data.get('test_y_mark', None)
else:
train_x_mark = None
train_y_mark = None
val_x_mark = None
val_y_mark = None
test_x_mark = None
test_y_mark = None
# Convert to PyTorch tensors
train_x = torch.FloatTensor(train_x)
train_y = torch.FloatTensor(train_y)
val_x = torch.FloatTensor(val_x)
val_y = torch.FloatTensor(val_y)
test_x = torch.FloatTensor(test_x)
test_y = torch.FloatTensor(test_y)
# Create datasets based on whether time features are available
if train_x_mark is not None:
train_x_mark = torch.FloatTensor(train_x_mark)
train_y_mark = torch.FloatTensor(train_y_mark)
val_x_mark = torch.FloatTensor(val_x_mark)
val_y_mark = torch.FloatTensor(val_y_mark)
test_x_mark = torch.FloatTensor(test_x_mark)
test_y_mark = torch.FloatTensor(test_y_mark)
train_dataset = TensorDataset(train_x, train_y, train_x_mark, train_y_mark)
val_dataset = TensorDataset(val_x, val_y, val_x_mark, val_y_mark)
test_dataset = TensorDataset(test_x, test_y, test_x_mark, test_y_mark)
else:
train_dataset = TensorDataset(train_x, train_y)
val_dataset = TensorDataset(val_x, val_y)
test_dataset = TensorDataset(test_x, test_y)
# Create dataloaders
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
return {
'train': train_loader,
'val': val_loader,
'test': test_loader
}
def train_diffusion_model(
model_constructor: Callable,
data_path: str,
project_name: str,
config: Dict[str, Any],
device: Optional[str] = None,
early_stopping_patience: int = 10,
max_epochs: int = 100,
checkpoint_dir: str = "./checkpoints",
log_interval: int = 10,
) -> Tuple[nn.Module, Dict[str, float]]:
"""
Train a Diffusion time series forecasting model using NPZ data loading
"""
# Setup device
if device is None:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Initialize swanlab for experiment tracking
swanlab_run = swanlab.init(
project=project_name,
config=config,
)
# Create checkpoint directory if it doesn't exist
os.makedirs(checkpoint_dir, exist_ok=True)
checkpoint_path = os.path.join(checkpoint_dir, f"{project_name}.pt")
# Create data loaders using NPZ files (following other models' pattern)
dataloaders = create_data_loaders(
data_path=data_path,
batch_size=config.get('batch_size', 32),
use_x_mark=False # DiffusionTimeSeries doesn't use time features
)
# Construct the model
model = model_constructor()
model = model.to(device)
print(f"Model created with {model.get_num_params():,} parameters")
# Define optimizer for diffusion training
optimizer = optim.Adam(
model.parameters(),
lr=config.get('learning_rate', 1e-4),
weight_decay=config.get('weight_decay', 1e-4)
)
# Learning rate scheduler
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
optimizer, mode='min', patience=5, factor=0.5
)
# Initialize early stopping
early_stopping = EarlyStopping(
patience=early_stopping_patience,
verbose=True,
path=checkpoint_path
)
# Training loop
best_val_loss = float('inf')
metrics = {}
for epoch in range(max_epochs):
print(f"\nEpoch {epoch+1}/{max_epochs}")
print("-" * 50)
# Training phase
model.train()
train_loss = 0.0
train_samples = 0
interval_loss = 0.0
start_time = time.time()
for batch_idx, (seq_x, seq_y) in enumerate(dataloaders['train']):
seq_x, seq_y = seq_x.to(device), seq_y.to(device)
optimizer.zero_grad()
# Diffusion training: model returns loss directly when y is provided
loss = model(seq_x, seq_y)
loss.backward()
optimizer.step()
train_loss += loss.item()
interval_loss += loss.item()
train_samples += 1
# Log at intervals
if (batch_idx + 1) % log_interval == 0:
elapsed_time = time.time() - start_time
avg_interval_loss = interval_loss / log_interval
print(f' Batch [{batch_idx+1}/{len(dataloaders["train"])}] '
f'Loss: {avg_interval_loss:.6f} '
f'Time: {elapsed_time:.2f}s')
# Log to swanlab
swanlab.log({
'batch_loss': avg_interval_loss,
'batch': epoch * len(dataloaders['train']) + batch_idx,
'learning_rate': optimizer.param_groups[0]['lr']
})
interval_loss = 0.0
start_time = time.time()
avg_train_loss = train_loss / train_samples
# Validation phase - Use faster sampling for validation
model.eval()
val_loss = 0.0
val_samples = 0
criterion = nn.MSELoss()
print(" Validating...")
with torch.no_grad():
# Temporarily reduce diffusion steps for faster validation
original_timesteps = model.diffusion.num_timesteps
model.diffusion.num_timesteps = 200 # Much faster validation
for batch_idx, (seq_x, seq_y) in enumerate(dataloaders['val']):
seq_x, seq_y = seq_x.to(device), seq_y.to(device)
# Generate predictions (inference mode with reduced steps)
pred = model(seq_x)
# Compute MSE loss for validation
loss = criterion(pred, seq_y)
val_loss += loss.item()
val_samples += 1
# Print validation progress for first epoch
if epoch == 0 and (batch_idx + 1) % 50 == 0:
print(f" Val Batch [{batch_idx+1}/{len(dataloaders['val'])}]")
# Early break for very first epoch to speed up
if epoch == 0 and batch_idx >= 100: # Only validate on first 100 batches for first epoch
break
# Restore original timesteps
model.diffusion.num_timesteps = original_timesteps
avg_val_loss = val_loss / val_samples
# Learning rate scheduling
scheduler.step(avg_val_loss)
current_lr = optimizer.param_groups[0]['lr']
print(f" Train Loss: {avg_train_loss:.6f}")
print(f" Val Loss: {avg_val_loss:.6f}")
print(f" Learning Rate: {current_lr:.2e}")
# Log to swanlab
swanlab.log({
'epoch': epoch + 1,
'train_loss': avg_train_loss,
'val_loss': avg_val_loss,
'learning_rate': current_lr
})
# Early stopping check
early_stopping(avg_val_loss, model)
if early_stopping.early_stop:
print(f"Early stopping at epoch {epoch + 1}")
break
# Load best model
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
# Final evaluation on test set
print("\nEvaluating on test set...")
model.eval()
test_loss = 0.0
test_samples = 0
all_preds = []
all_targets = []
with torch.no_grad():
# Use reduced timesteps for faster testing
original_timesteps = model.diffusion.num_timesteps
model.diffusion.num_timesteps = 200 # Faster but still good quality
for batch_idx, (seq_x, seq_y) in enumerate(dataloaders['test']):
seq_x, seq_y = seq_x.to(device), seq_y.to(device)
pred = model(seq_x)
loss = criterion(pred, seq_y)
test_loss += loss.item()
test_samples += 1
all_preds.append(pred.cpu().numpy())
all_targets.append(seq_y.cpu().numpy())
# Print progress every 50 batches
if (batch_idx + 1) % 50 == 0:
print(f" Test Batch [{batch_idx+1}/{len(dataloaders['test'])}]")
# Restore original timesteps
model.diffusion.num_timesteps = original_timesteps
avg_test_loss = test_loss / test_samples
# Calculate additional metrics
all_preds = np.concatenate(all_preds, axis=0)
all_targets = np.concatenate(all_targets, axis=0)
mse = np.mean((all_preds - all_targets) ** 2)
mae = np.mean(np.abs(all_preds - all_targets))
rmse = np.sqrt(mse)
metrics = {
'test_mse': mse,
'test_mae': mae,
'test_rmse': rmse,
'test_loss': avg_test_loss
}
print(f"Test Results:")
print(f" MSE: {mse:.6f}")
print(f" MAE: {mae:.6f}")
print(f" RMSE: {rmse:.6f}")
# Log final results
swanlab.log(metrics)
swanlab.finish()
return model, metrics
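A hedged sketch of driving this trainer; make_model is a hypothetical constructor and must return a module whose forward yields the diffusion loss when given (x, y), returns predictions when given x alone, and exposes diffusion.num_timesteps and get_num_params() as used above:

def make_model():
    # hypothetical model satisfying the interface described above
    return MyDiffusionForecaster(seq_len=96, pred_len=24, channels=7)

model, metrics = train_diffusion_model(
    model_constructor=make_model,
    data_path='./data/dataset.npz',  # placeholder NPZ path
    project_name='diffusion_demo',
    config={'batch_size': 32, 'learning_rate': 1e-4, 'weight_decay': 1e-4},
    early_stopping_patience=10,
    max_epochs=100,
)
print(metrics['test_mse'], metrics['test_mae'])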

View File

@ -0,0 +1,329 @@
#!/usr/bin/env python3
"""
Training script for xPatch_SparseChannel model on multiple datasets.
Supports Weather, Traffic, Electricity, Exchange, and ILI datasets with sigmoid learning rate adjustment.
"""
import os
import math
import argparse
import torch
import torch.nn as nn
from train.train import train_forecasting_model
from models.xPatch_SparseChannel.xPatch import Model as xPatchSparseChannel
# Dataset configurations
DATASET_CONFIGS = {
'Weather': {
'csv_file': 'weather.csv',
'enc_in': 21,
'batch_size': 2048,
'learning_rate': 0.0005,
'target': 'OT',
'data_path': 'weather/weather.csv',
'seq_len': 96,
'pred_lengths': [96, 192, 336, 720]
},
'Traffic': {
'csv_file': 'traffic.csv',
'enc_in': 862,
'batch_size': 32,
'learning_rate': 0.002,
'target': 'OT',
'data_path': 'traffic/traffic.csv',
'seq_len': 96,
'pred_lengths': [96, 192, 336, 720]
},
'Electricity': {
'csv_file': 'electricity.csv',
'enc_in': 321,
'batch_size': 32,
'learning_rate': 0.001,
'target': 'OT',
'data_path': 'electricity/electricity.csv',
'seq_len': 96,
'pred_lengths': [96, 192, 336, 720]
},
'Exchange': {
'csv_file': 'exchange_rate.csv',
'enc_in': 8,
'batch_size': 128,
'learning_rate': 0.00005,
'target': 'OT',
'data_path': 'exchange_rate/exchange_rate.csv',
'seq_len': 96,
'pred_lengths': [96, 192, 336, 720]
},
'ILI': {
'csv_file': 'national_illness.csv',
'enc_in': 7,
'batch_size': 32,
'learning_rate': 0.01,
'target': 'OT',
'data_path': 'illness/national_illness.csv',
'seq_len': 36,
'pred_lengths': [24, 36, 48, 60]
}
}
class Args:
"""Configuration class for xPatch_SparseChannel model parameters."""
def __init__(self, dataset_name, pred_len):
dataset_config = DATASET_CONFIGS[dataset_name]
# Model architecture parameters
self.task_name = 'long_term_forecast'
self.seq_len = dataset_config['seq_len'] # Use dataset-specific seq_len
self.label_len = self.seq_len // 2 # Half of seq_len as label length
self.pred_len = pred_len
self.enc_in = dataset_config['enc_in']
self.c_out = dataset_config['enc_in']
# xPatch specific parameters from reference
self.patch_len = 16 # patch length
self.stride = 8 # stride
self.padding_patch = 'end' # padding on the end
# Moving Average parameters
self.ma_type = 'ema' # moving average type
self.alpha = 0.3 # alpha parameter for EMA
self.beta = 0.3 # beta parameter for DEMA
# RevIN normalization
self.revin = 1 # RevIN; True 1 False 0
# Time features (not used by xPatch but required by data loader)
self.embed = 'timeF' # Time feature embedding type
self.freq = 'h' # Frequency for time features (hourly)
# Dataset specific parameters
self.data = 'custom'
self.root_path = './data/'
self.data_path = dataset_config['data_path']
self.features = 'M' # Multivariate prediction
self.target = dataset_config['target'] # Target column
self.train_only = False
# Required for dataflow - will be set by config
self.batch_size = dataset_config['batch_size']
self.num_workers = 8 # Will be overridden by config
print(f"xPatch_SparseChannel Model configuration for {dataset_name}:")
print(f" - Input channels (C): {self.enc_in}")
print(f" - Patch length: {self.patch_len}")
print(f" - Stride: {self.stride}")
print(f" - Sequence length: {self.seq_len}") # Now dataset-specific
print(f" - Prediction length: {pred_len}")
print(f" - Moving average type: {self.ma_type}")
print(f" - Alpha: {self.alpha}")
print(f" - Beta: {self.beta}")
print(f" - RevIN: {self.revin}")
print(f" - Target: {self.target}")
print(f" - Batch size: {self.batch_size}")
def create_xpatch_sparse_model(args):
"""Create xPatch_SparseChannel model with given configuration."""
def model_constructor():
return xPatchSparseChannel(args)
return model_constructor
def train_single_dataset(dataset_name, pred_len, model_args, cmd_args, use_ps_loss=True):
"""Train xPatch_SparseChannel on specified dataset with given prediction length."""
dataset_config = DATASET_CONFIGS[dataset_name]
# Update args for current prediction length
model_args.pred_len = pred_len
# Update dataflow parameters from command line args
model_args.num_workers = cmd_args.num_workers
# Create model constructor
model_constructor = create_xpatch_sparse_model(model_args)
# Training configuration with dataset-specific parameters
config = {
'learning_rate': dataset_config['learning_rate'], # Dataset-specific learning rate
'batch_size': dataset_config['batch_size'], # Dataset-specific batch size
'weight_decay': 1e-4,
'dataset': dataset_name,
'pred_len': pred_len,
'seq_len': model_args.seq_len,
'patch_len': model_args.patch_len,
'stride': model_args.stride,
'ma_type': model_args.ma_type,
'use_ps_loss': use_ps_loss,
'num_workers': cmd_args.num_workers,
'pin_memory': True,
'persistent_workers': True
}
# Project name for tracking
loss_suffix = "_PSLoss" if use_ps_loss else "_MSE"
project_name = f"xPatch_SparseChannel_{dataset_name}_pred{pred_len}{loss_suffix}_sigmoid"
print(f"\n{'='*60}")
print(f"Training {dataset_name} with prediction length {pred_len}")
print(f"Model: xPatch_SparseChannel")
print(f"Loss function: {'PS_Loss' if use_ps_loss else 'MSE'}")
print(f"Learning rate: {dataset_config['learning_rate']}")
print(f"Batch size: {dataset_config['batch_size']}")
print(f"Features: {dataset_config['enc_in']}")
print(f"Data path: {model_args.root_path}{model_args.data_path}")
print(f"LR adjustment: sigmoid")
print(f"{'='*60}")
# Train the model
try:
model, metrics = train_forecasting_model(
model_constructor=model_constructor,
data_path=f"{model_args.root_path}{model_args.data_path}",
project_name=project_name,
config=config,
early_stopping_patience=5,
max_epochs=100,
checkpoint_dir="./checkpoints",
log_interval=50,
use_x_mark=False, # xPatch_SparseChannel doesn't use time features
use_ps_loss=use_ps_loss,
ps_lambda=cmd_args.ps_lambda,
patch_len_threshold=64,
use_gdw=True,
dataset_mode="dataflow",
dataflow_args=model_args,
lr_adjust_strategy="sigmoid" # Use sigmoid learning rate adjustment
)
print(f"Training completed for {project_name}")
if use_ps_loss:
print(f"Final validation MSE: {metrics.get('final_val_mse', 'N/A'):.6f}")
else:
print(f"Final validation MSE: {metrics.get('final_val_loss', 'N/A'):.6f}")
return model, metrics
except Exception as e:
print(f"Error training {project_name}: {e}")
import traceback
traceback.print_exc()
return None, None
def main():
parser = argparse.ArgumentParser(description='Train xPatch_SparseChannel on multiple datasets with sigmoid LR adjustment')
parser.add_argument('--datasets', nargs='+', type=str,
default=['Weather', 'Traffic', 'Electricity', 'Exchange', 'ILI'],
choices=['Weather', 'Traffic', 'Electricity', 'Exchange', 'ILI'],
help='List of datasets to train on')
parser.add_argument('--use_ps_loss', action='store_true', default=True,
help='Use PS_Loss instead of MSE')
parser.add_argument('--ps_lambda', type=float, default=5.0,
help='Weight for PS loss component')
parser.add_argument('--device', type=str, default=None,
help='Device to use for training (cuda/cpu)')
parser.add_argument('--num_workers', type=int, default=8,
help='Number of data loading workers')
args = parser.parse_args()
print("xPatch_SparseChannel Multi-Dataset Training Script with Sigmoid LR Adjustment")
print("=" * 80)
print(f"Datasets: {args.datasets}")
print(f"Use PS_Loss: {args.use_ps_loss}")
print(f"PS_Lambda: {args.ps_lambda}")
print(f"Number of workers: {args.num_workers}")
print(f"Learning rate adjustment: sigmoid")
# Display dataset configurations
print("\nDataset Configurations:")
for dataset in args.datasets:
config = DATASET_CONFIGS[dataset]
print(f" {dataset}:")
print(f" - Features: {config['enc_in']}")
print(f" - Batch size: {config['batch_size']}")
print(f" - Learning rate: {config['learning_rate']}")
print(f" - Sequence length: {config['seq_len']}")
print(f" - Prediction lengths: {config['pred_lengths']}")
print(f" - Data path: {config['data_path']}")
# Check if data files exist
missing_datasets = []
for dataset in args.datasets:
config = DATASET_CONFIGS[dataset]
data_path = f"./data/{config['data_path']}"
if not os.path.exists(data_path):
missing_datasets.append(f"{dataset}: '{data_path}'")
if missing_datasets:
print(f"\nError: The following dataset files were not found:")
for missing in missing_datasets:
print(f" - {missing}")
print("Please ensure all dataset files are available in the data/ directory.")
return
# Set device
if args.device is None:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
else:
device = args.device
print(f"\nUsing device: {device}")
# Training results storage
all_results = {}
# Train on each dataset
for dataset in args.datasets:
print(f"\n{'#'*80}")
print(f"STARTING TRAINING ON {dataset.upper()} DATASET")
print(f"{'#'*80}")
all_results[dataset] = {}
config = DATASET_CONFIGS[dataset]
# Train on each prediction length for this dataset
for pred_len in config['pred_lengths']: # Use dataset-specific prediction lengths
# Create model configuration for current dataset
model_args = Args(
dataset_name=dataset,
pred_len=pred_len
)
# Train the model
model, metrics = train_single_dataset(
dataset_name=dataset,
pred_len=pred_len,
model_args=model_args,
cmd_args=args,
use_ps_loss=args.use_ps_loss
)
# Store results
all_results[dataset][pred_len] = {
'model': model,
'metrics': metrics,
'data_path': f"./data/{config['data_path']}"
}
# Print comprehensive summary
print("\n" + "=" * 100)
print("COMPREHENSIVE TRAINING SUMMARY")
print("=" * 100)
for dataset in args.datasets:
config = DATASET_CONFIGS[dataset]
print(f"\n{dataset} (Features: {config['enc_in']}, Batch: {config['batch_size']}, LR: {config['learning_rate']}, Seq: {config['seq_len']}):")
for pred_len in all_results[dataset]:
result = all_results[dataset][pred_len]
if result['metrics'] is not None:
if args.use_ps_loss:
mse = result['metrics'].get('final_val_mse', 'N/A')
else:
mse = result['metrics'].get('final_val_loss', 'N/A')
print(f" Pred Length {pred_len}: MSE = {mse}")
else:
print(f" Pred Length {pred_len}: Training failed")
print(f"\nAll models saved in: ./checkpoints/")
print("All datasets training completed!")
if __name__ == "__main__":
main()