diff --git a/dataflow/data_loader.py b/dataflow/data_loader.py index 386ad09..215d986 100644 --- a/dataflow/data_loader.py +++ b/dataflow/data_loader.py @@ -273,15 +273,21 @@ class Dataset_Custom(Dataset): self.data_stamp = data_stamp def __getitem__(self, index): + # 1. 定义输入序列 seq_x 的起止位置 s_begin = index s_end = s_begin + self.seq_len - r_begin = s_end - self.label_len - r_end = r_begin + self.label_len + self.pred_len - + # 2. 定义目标序列 seq_y 的起止位置 + # seq_y 的开始 (r_begin) 就是 seq_x 的结束 (s_end) + r_begin = s_end + # seq_y 的结束 (r_end) 是其开始位置加上预测长度 (pred_len) + r_end = r_begin + self.pred_len + # 3. 根据起止位置切片数据 seq_x = self.data_x[s_begin:s_end] seq_y = self.data_y[r_begin:r_end] seq_x_mark = self.data_stamp[s_begin:s_end] seq_y_mark = self.data_stamp[r_begin:r_end] + seq_x = seq_x.astype('float32') + seq_y = seq_y.astype('float32') return seq_x, seq_y, seq_x_mark, seq_y_mark diff --git a/models/DiffusionTimeSeries/__init__.py b/models/DiffusionTimeSeries/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/models/DiffusionTimeSeries/diffusion_ts.py b/models/DiffusionTimeSeries/diffusion_ts.py new file mode 100644 index 0000000..fbaf47b --- /dev/null +++ b/models/DiffusionTimeSeries/diffusion_ts.py @@ -0,0 +1,323 @@ +import math +from typing import Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +# ----------------------- 工具:构造/缩放线性 betas ----------------------- +def make_linear_betas(T: int, beta_start=1e-4, beta_end=2e-2, device='cpu'): + return torch.linspace(beta_start, beta_end, T, device=device) + +def cumprod_from_betas(betas: torch.Tensor): + alphas = 1.0 - betas + return torch.cumprod(alphas, dim=0) # shape [T] + +@torch.no_grad() +def scale_betas_to_target_cumprod(betas: torch.Tensor, target_cumprod: float, max_scale: float = 100.0): + """ + 给定一段 betas[1..T],寻找缩放系数 s>0,使得 ∏(1 - s*beta_i) = target_cumprod + 用二分法在 (0, s_max) 上搜索。确保 0 < s*beta_i < 1。 + """ + device = betas.device + eps = 1e-12 + s_low = 0.0 + s_high = min(max_scale, (1.0 - 1e-6) / (betas.max().item() + eps)) # 使 1 - s*beta > 0 + + def cumprod_with_scale(s: float): + a = (1.0 - betas * s).clamp(min=1e-6, max=1.0-1e-6) + return torch.cumprod(a, dim=0)[-1].item() + + # 若不缩放已接近目标,直接返回 + base = cumprod_with_scale(1.0) + if abs(base - target_cumprod) / max(target_cumprod, 1e-12) < 1e-6: + return betas + + # 目标在 (0, s_high) 内单调可达,进行二分 + for _ in range(60): + mid = 0.5 * (s_low + s_high) + val = cumprod_with_scale(mid) + if val > target_cumprod: + # 乘子太小(噪声弱),需要更大 s + s_low = mid + else: + s_high = mid + s_best = 0.5 * (s_low + s_high) + return (betas * s_best).clamp(min=1e-8, max=1-1e-6) + + +# ------------------------------ DiT Blocks -------------------------------- +class DiTBlock(nn.Module): + def __init__(self, dim: int, heads: int, mlp_ratio=4.0): + super().__init__() + self.ln1 = nn.LayerNorm(dim) + self.attn = nn.MultiheadAttention(dim, heads, batch_first=True) + self.ln2 = nn.LayerNorm(dim) + self.mlp = nn.Sequential( + nn.Linear(dim, int(dim * mlp_ratio)), + nn.GELU(), + nn.Linear(int(dim * mlp_ratio), dim), + ) + + def forward(self, x): + # x: [B, C_tokens, D] + h = self.attn(self.ln1(x), self.ln1(x), self.ln1(x))[0] + x = x + h + x = x + self.mlp(self.ln2(x)) + return x + + +class DiTChannelTokens(nn.Module): + """ + Token = 一个通道(变量)。 + 对于每个通道,输入是 [L] 的时间向量;我们用两条投影: + - W_x : 把 x_t 的时间向量投影成 token 向量 + - W_n : 把 noise-level(来自 schedule 的 ā 或 b̄ 的时间向量)投影成 token 偏置 + 注意:不再使用可学习的 t-embedding;噪声条件完全由 noise map 决定。 + """ + def __init__(self, L: int, C: int, dim: 
int = 256, depth: int = 8, heads: int = 8): + super().__init__() + self.L = L + self.C = C + self.dim = dim + + # 通道嵌入(可选,用于区分变量) + self.channel_embed = nn.Parameter(torch.randn(C, dim) * 0.02) + + # 将每个通道的时间序列映射到 token + self.proj_x = nn.Linear(L, dim, bias=False) + # 将每个通道的逐时间噪声强度(例如 [sqrt(ā), sqrt(1-ā)] 拼接后经一层线性) + self.proj_noise = nn.Linear(L, dim, bias=True) + + self.blocks = nn.ModuleList([DiTBlock(dim, heads) for _ in range(depth)]) + self.ln_f = nn.LayerNorm(dim) + + # 反投影回时间长度 L,预测 ε(每通道独立投影) + self.head = nn.Linear(dim, L, bias=False) + + def forward(self, x_t: torch.Tensor, noise_feat: torch.Tensor): + """ + x_t : [B, L, C] + noise_feat: [B, L, C] (建议传入 sqrt(ā) 或 concat 后先合并到 L 维度,这里用一条投影即可) + 返回 ε̂ : [B, L, C] + """ + B, L, C = x_t.shape + assert L == self.L and C == self.C + + # 逐通道映射成 token + # 把 (B, L, C) 变 (B, C, L) 再线性 + x_tc = x_t.permute(0, 2, 1) # [B, C, L] + n_tc = noise_feat.permute(0, 2, 1) # [B, C, L] + + tok = self.proj_x(x_tc) + self.proj_noise(n_tc) # [B, C, D] + tok = tok + self.channel_embed.unsqueeze(0) # broadcast [1, C, D] + + for blk in self.blocks: + tok = blk(tok) # [B, C, D] + + tok = self.ln_f(tok) + out = self.head(tok) # [B, C, L] + eps_pred = out.permute(0, 2, 1) # [B, L, C] + return eps_pred + + +# ----------------------- RAD 两阶段扩散(通道为token) ----------------------- +class RADChannelDiT(nn.Module): + def __init__(self, + past_len: int, + future_len: int, + channels: int, + T: int = 1000, + T1_ratio: float = 0.7, + model_dim: int = 256, + depth: int = 8, + heads: int = 8, + beta_start: float = 1e-4, + beta_end: float = 2e-2, + use_cosine_target: bool = True): + """ + - 训练:两阶段(Phase-1 + Phase-2),t 从 [1..T] 均匀采样 + - 推理:仅使用 Phase-1(t: T1→1),只更新未来区域 + - Token=通道,每个 token 见到整个时间轴 + 噪声强度时间向量 + """ + super().__init__() + self.P = past_len + self.H = future_len + self.C = channels + self.L = past_len + future_len + self.T = T + self.T1 = max(1, int(T * T1_ratio)) + self.T2 = T - self.T1 + assert self.T2 >= 1, "T1_ratio 不能太大,至少留下 1 步给 Phase-2" + + device = torch.device('cpu') + + # 目标 ā_T(用于把两段线性 schedule 归一到同一最终噪声强度) + if use_cosine_target: + # 参考 cosine 计划,得到一条“全局目标 ā_T” + steps = T + 1 + x = torch.linspace(0, T, steps, dtype=torch.float64) + s = 0.008 + alphas_cum = torch.cos(((x / T) + s) / (1 + s) * math.pi / 2) ** 2 + alphas_cum = alphas_cum / alphas_cum[0] + a_bar_target_T = float(alphas_cum[-1]) + else: + # 直接用 DDPM 线性 beta 的结果作为目标 + betas_full = make_linear_betas(T, beta_start, beta_end, device) + a_bar_target_T = cumprod_from_betas(betas_full)[-1].item() + + # Phase-1 & Phase-2 原始线性 beta + betas1 = make_linear_betas(self.T1, beta_start, beta_end, device) + betas2 = make_linear_betas(self.T2, beta_start, beta_end, device) + + # 首先不缩放,计算 ā1[T1], ā2[T2] + a_bar1 = cumprod_from_betas(betas1) # shape [T1] + a_bar2 = cumprod_from_betas(betas2) # shape [T2] + + # 缩放 Phase-2 的 betas,使 ā1[T1] * ā2'[T2] = 目标 ā_T + target_a2 = a_bar_target_T / (a_bar1[-1].item() + 1e-12) + betas2 = scale_betas_to_target_cumprod(betas2, target_a2) + + # 重新计算 + # a_bar1 = cumprod_from_betas(betas1).float() # [T1] + a_bar2 = cumprod_from_betas(betas2).float() # [T2] + + self.register_buffer("betas1", betas1.float()) + self.register_buffer("betas2", betas2.float()) + self.register_buffer("alphas1", 1.0 - betas1.float()) + self.register_buffer("alphas2", 1.0 - betas2.float()) + self.register_buffer("a_bar1", a_bar1) + self.register_buffer("a_bar2", a_bar2) + self.register_buffer("a_bar_target_T", torch.tensor(a_bar_target_T, dtype=torch.float32)) + + # Backbone: token=通道 + 
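Note: the Phase-2 betas are rescaled so that the cumulative noise levels of the two phases compose to the single global target, i.e. ā1[T1] · ā2'[T2] ≈ ā_T. A minimal sanity-check sketch of this composition, assuming the schedule helpers above are importable from models/DiffusionTimeSeries/diffusion_ts.py:

import torch
from models.DiffusionTimeSeries.diffusion_ts import (
    make_linear_betas, cumprod_from_betas, scale_betas_to_target_cumprod)

T, T1 = 1000, 700
betas1 = make_linear_betas(T1)                      # Phase-1 linear betas
betas2 = make_linear_betas(T - T1)                  # Phase-2 linear betas, unscaled
a_bar1_T1 = cumprod_from_betas(betas1)[-1].item()

# Global target: final alpha-bar of a single full-length linear schedule
a_bar_target_T = cumprod_from_betas(make_linear_betas(T))[-1].item()

# Rescale Phase-2 so the two phases compose to the global target
betas2 = scale_betas_to_target_cumprod(betas2, a_bar_target_T / a_bar1_T1)
a_bar2_T2 = cumprod_from_betas(betas2)[-1].item()

print(a_bar1_T1 * a_bar2_T2, a_bar_target_T)        # the two values should be close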
self.backbone = DiTChannelTokens(L=self.L, C=self.C, dim=model_dim, depth=depth, heads=heads) + + # ------------------------ 内部:构造 mask & āt,i ------------------------ + def _mask_future(self, B, device): + # mask: 未来区域=1,历史=0,形状 [B, L, C](与网络输入 [B,L,C] 对齐) + m = torch.zeros(B, self.L, self.C, device=device) + m[:, self.P:, :] = 1.0 + return m + + def _a_bar_map_at_t(self, t_scalar: int, B: int, device, mask_future: torch.Tensor): + """ + 构造逐像素 āt,i,形状 [B, L, C] + - 若 t<=T1:未来区域用 ā1[t],历史区域=1 + - 若 t> T1:未来区域固定 ā1[T1],历史区域用 ā2[t-T1] + """ + if t_scalar <= self.T1: + a_future = self.a_bar1[t_scalar - 1] # 索引从 0 开始 + a_past = torch.tensor(1.0, device=device) + else: + a_future = self.a_bar1[-1] + a_past = self.a_bar2[t_scalar - self.T1 - 1] + + a_future_map = torch.full((B, self.L, self.C), float(a_future.item()), device=device) + a_past_map = torch.full((B, self.L, self.C), float(a_past.item()), device=device) + a_map = a_past_map * (1 - mask_future) + a_future_map * mask_future + return a_map # [B, L, C] + + # ----------------------------- 前向训练 ----------------------------- + def forward(self, x_hist: torch.Tensor, x_future: torch.Tensor) -> Tuple[torch.Tensor, dict]: + """ + x_hist : [B, P, C] + x_future : [B, H, C] + 训练:采样 t∈[1..T],构造两阶段 āt,i,边际加噪 xt,并用逐通道 token 的 DiT 预测 ε + """ + B = x_hist.size(0) + device = x_hist.device + x0 = torch.cat([x_hist, x_future], dim=1) # [B, L, C] + + # 采样训练步 t (1..T) + t = torch.randint(1, self.T + 1, (B,), device=device, dtype=torch.long) + + # 构造 mask 和逐像素 āt,i + mask_fut = self._mask_future(B, device) # [B, L, C] + # 逐样本构造 āt,i(不同样本 t 不同,只能用循环或向量化 trick;B 通常不大,for 循环即可) + a_bar_map = torch.stack([self._a_bar_map_at_t(int(tt.item()), 1, device, mask_fut[0:1]) + for tt in t], dim=0).squeeze(1) # [B,L,C] + + # 边际加噪 + eps = torch.randn_like(x0) # [B,L,C] + x_t = a_bar_map.sqrt() * x0 + (1.0 - a_bar_map).sqrt() * eps + + # Spatial Noise Embedding:完全由 schedule 决定 + # 传入每个像素的 √ā 和 √(1-ā)(或任选其一);这里用 √ā + noise_feat = a_bar_map.sqrt() # [B,L,C] + + # 预测 ε + eps_pred = self.backbone(x_t, noise_feat) # [B,L,C] + + loss = F.mse_loss(eps_pred, eps) + return loss, {'t_mean': t.float().mean().item()} + + # ----------------------------- 采样推理 ----------------------------- + @torch.no_grad() + def sample(self, x_hist: torch.Tensor, steps: Optional[int] = None) -> torch.Tensor: + """ + 仅 Phase-1 推理:t = T1..1,只更新未来区域,历史保持观测值 + x_hist : [B,P,C] + return : [B,H,C] + """ + B = x_hist.size(0) + device = x_hist.device + mask_fut = self._mask_future(B, device) # [B,L,C] + + # 初始化 x:历史=观测,未来=高斯噪声 + x = torch.zeros(B, self.L, self.C, device=device) + x[:, :self.P, :] = x_hist + x[:, self.P:, :] = torch.randn(B, self.H, self.C, device=device) + + # 支持子采样:把 [T1..1] 均匀下采样到 steps + T1 = self.T1 + steps = steps if steps is not None else T1 + steps = max(1, min(steps, T1)) + ts = torch.linspace(T1, 1, steps, device=device).long().tolist() + # 为 DDPM 更新需要 α_t, β_t(仅对未来区域定义) + alphas1 = self.alphas1 # [T1] + betas1 = self.betas1 + a_bar1 = self.a_bar1 + + for idx, t_scalar in enumerate(ts): + # 当前 āt,i(Phase-1:历史=1,未来=ā1[t]) + a_bar_map = self._a_bar_map_at_t(int(t_scalar), B, device, mask_fut) # [B,L,C] + # 网络条件:用 √ā 作为噪声嵌入 + noise_feat = a_bar_map.sqrt() + # 预测 ε + eps_pred = self.backbone(x, noise_feat) # [B,L,C] + + # 对未来区域做 DDPM 一步(历史区保持原值) + # 标准 DDPM 公式(像素在未来区域共享同一 α_t、β_t) + t_idx = t_scalar - 1 + alpha_t = alphas1[t_idx] # 标量 + beta_t = betas1[t_idx] + a_bar_t = a_bar1[t_idx] + if t_scalar > 1: + a_bar_prev = a_bar1[t_idx - 1] + else: + a_bar_prev = torch.tensor(1.0, 
device=device) + + # x0 预测(仅用于推导均值,也可直接用μ公式) + x0_pred = (x - (1.0 - a_bar_t).sqrt() * eps_pred) / (a_bar_t.sqrt() + 1e-8) + + # 均值:μ_t = 1/sqrt(α_t) * (x_t - β_t / sqrt(1 - ā_t) * ε̂) + mean = (x - (beta_t / (1.0 - a_bar_t).sqrt()) * eps_pred) / (alpha_t.sqrt() + 1e-8) + + # 采样噪声 + if t_scalar > 1: + z = torch.randn_like(x) + else: + z = torch.zeros_like(x) + + # 方差项(DDPM):σ_t = sqrt(β_t) + x_next = mean + z * beta_t.sqrt() + + # 仅替换未来区域 + x = x * (1 - mask_fut) + x_next * mask_fut + # 历史强制为观测 + x[:, :self.P, :] = x_hist + + return x[:, self.P:, :] # [B,H,C] + diff --git a/models/iTransformer/iTransformer.py b/models/iTransformer/iTransformer.py new file mode 100644 index 0000000..4833a69 --- /dev/null +++ b/models/iTransformer/iTransformer.py @@ -0,0 +1,132 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from layers.Transformer_EncDec import Encoder, EncoderLayer +from layers.SelfAttention_Family import FullAttention, AttentionLayer +from layers.Embed import DataEmbedding_inverted +import numpy as np + + +class Model(nn.Module): + """ + Paper link: https://arxiv.org/abs/2310.06625 + """ + + def __init__(self, configs): + super(Model, self).__init__() + self.task_name = configs.task_name + self.seq_len = configs.seq_len + self.pred_len = configs.pred_len + # Embedding + self.enc_embedding = DataEmbedding_inverted(configs.seq_len, configs.d_model, configs.embed, configs.freq, + configs.dropout) + # Encoder + self.encoder = Encoder( + [ + EncoderLayer( + AttentionLayer( + FullAttention(False, configs.factor, attention_dropout=configs.dropout, + output_attention=False), configs.d_model, configs.n_heads), + configs.d_model, + configs.d_ff, + dropout=configs.dropout, + activation=configs.activation + ) for l in range(configs.e_layers) + ], + norm_layer=torch.nn.LayerNorm(configs.d_model) + ) + # Decoder + if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast': + self.projection = nn.Linear(configs.d_model, configs.pred_len, bias=True) + if self.task_name == 'imputation': + self.projection = nn.Linear(configs.d_model, configs.seq_len, bias=True) + if self.task_name == 'anomaly_detection': + self.projection = nn.Linear(configs.d_model, configs.seq_len, bias=True) + if self.task_name == 'classification': + self.act = F.gelu + self.dropout = nn.Dropout(configs.dropout) + self.projection = nn.Linear(configs.d_model * configs.enc_in, configs.num_class) + + def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec): + # Normalization from Non-stationary Transformer + means = x_enc.mean(1, keepdim=True).detach() + x_enc = x_enc - means + stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5) + x_enc /= stdev + + _, _, N = x_enc.shape + + # Embedding + enc_out = self.enc_embedding(x_enc, x_mark_enc) + enc_out, attns = self.encoder(enc_out, attn_mask=None) + + dec_out = self.projection(enc_out).permute(0, 2, 1)[:, :, :N] + # De-Normalization from Non-stationary Transformer + dec_out = dec_out * (stdev[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1)) + dec_out = dec_out + (means[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1)) + return dec_out + + def imputation(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask): + # Normalization from Non-stationary Transformer + means = x_enc.mean(1, keepdim=True).detach() + x_enc = x_enc - means + stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5) + x_enc /= stdev + + _, L, N = x_enc.shape + + # Embedding + enc_out = self.enc_embedding(x_enc, x_mark_enc) + 
enc_out, attns = self.encoder(enc_out, attn_mask=None) + + dec_out = self.projection(enc_out).permute(0, 2, 1)[:, :, :N] + # De-Normalization from Non-stationary Transformer + dec_out = dec_out * (stdev[:, 0, :].unsqueeze(1).repeat(1, L, 1)) + dec_out = dec_out + (means[:, 0, :].unsqueeze(1).repeat(1, L, 1)) + return dec_out + + def anomaly_detection(self, x_enc): + # Normalization from Non-stationary Transformer + means = x_enc.mean(1, keepdim=True).detach() + x_enc = x_enc - means + stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5) + x_enc /= stdev + + _, L, N = x_enc.shape + + # Embedding + enc_out = self.enc_embedding(x_enc, None) + enc_out, attns = self.encoder(enc_out, attn_mask=None) + + dec_out = self.projection(enc_out).permute(0, 2, 1)[:, :, :N] + # De-Normalization from Non-stationary Transformer + dec_out = dec_out * (stdev[:, 0, :].unsqueeze(1).repeat(1, L, 1)) + dec_out = dec_out + (means[:, 0, :].unsqueeze(1).repeat(1, L, 1)) + return dec_out + + def classification(self, x_enc, x_mark_enc): + # Embedding + enc_out = self.enc_embedding(x_enc, None) + enc_out, attns = self.encoder(enc_out, attn_mask=None) + + # Output + output = self.act(enc_out) # the output transformer encoder/decoder embeddings don't include non-linearity + output = self.dropout(output) + output = output.reshape(output.shape[0], -1) # (batch_size, c_in * d_model) + output = self.projection(output) # (batch_size, num_classes) + return output + + def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None): + if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast': + dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec) + return dec_out[:, -self.pred_len:, :] # [B, L, D] + if self.task_name == 'imputation': + dec_out = self.imputation(x_enc, x_mark_enc, x_dec, x_mark_dec, mask) + return dec_out # [B, L, D] + if self.task_name == 'anomaly_detection': + dec_out = self.anomaly_detection(x_enc) + return dec_out # [B, L, D] + if self.task_name == 'classification': + dec_out = self.classification(x_enc, x_mark_enc) + return dec_out # [B, N] + return None diff --git a/models/xPatch_SparseChannel/__init__.py b/models/xPatch_SparseChannel/__init__.py new file mode 100644 index 0000000..977f249 --- /dev/null +++ b/models/xPatch_SparseChannel/__init__.py @@ -0,0 +1,3 @@ +from .xPatch import Model + +__all__ = ['Model'] \ No newline at end of file diff --git a/models/xPatch_SparseChannel/patchtst_ci.py b/models/xPatch_SparseChannel/patchtst_ci.py new file mode 100644 index 0000000..3a3f144 --- /dev/null +++ b/models/xPatch_SparseChannel/patchtst_ci.py @@ -0,0 +1,66 @@ +""" +SeasonPatch = PatchTST (CI) + ChannelGraphMixer + 线性预测头 +""" +import torch +import torch.nn as nn +from layers.PatchTST_layers import positional_encoding # 已存在 +from layers.mixer import HierarchicalGraphMixer # 刚才创建 + +# ------ PatchTST CI 编码器(与官方实现等价, 但去掉 head 便于插入 Mixer) ------ +class _Encoder(nn.Module): + def __init__(self, c_in, seq_len, patch_len, stride, + d_model=128, n_layers=3, n_heads=16, dropout=0.): + super().__init__() + self.patch_len = patch_len + self.stride = stride + self.patch_num = (seq_len - patch_len)//stride + 1 + + self.proj = nn.Linear(patch_len, d_model) + self.pos = positional_encoding('zeros', True, self.patch_num, d_model) + self.drop = nn.Dropout(dropout) + + from layers.PatchTST_backbone import TSTEncoder # 与官方同名 + self.encoder = TSTEncoder(self.patch_num, d_model, n_heads, + d_ff=d_model*2, dropout=dropout, + n_layers=n_layers, norm='LayerNorm') 
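Note: the patch count used throughout is patch_num = (seq_len - patch_len)//stride + 1, which matches what torch.Tensor.unfold produces in the forward pass below. A small shape-check sketch (the concrete values seq_len=96, patch_len=16, stride=8 are only an illustration):

import torch

B, C, seq_len, patch_len, stride = 4, 21, 96, 16, 8
x = torch.randn(B, C, seq_len)
patches = x.unfold(-1, patch_len, stride)            # [B, C, patch_num, patch_len]
patch_num = (seq_len - patch_len) // stride + 1      # = 11 for these values
assert patches.shape == (B, C, patch_num, patch_len)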
+ + def forward(self, x): # [B,C,L] + x = x.unfold(-1, self.patch_len, self.stride) # [B,C,patch_num,patch_len] + B,C,N,P = x.shape + z = self.proj(x) # [B,C,N,d] + z = z.contiguous().view(B*C, N, -1) # [B*C,N,d] + z = self.drop(z + self.pos) + z = self.encoder(z) # [B*C,N,d] + return z.view(B,C,N,-1) # [B,C,N,d] + +# ------------------------- SeasonPatch ----------------------------- +class SeasonPatch(nn.Module): + def __init__(self, + c_in: int, + seq_len: int, + pred_len: int, + patch_len: int, + stride: int, + k_graph: int = 5, + d_model: int = 128, + revin: bool = True): + super().__init__() + + self.encoder = _Encoder(c_in, seq_len, patch_len, stride, + d_model=d_model) + self.mixer = HierarchicalGraphMixer(c_in, dim=d_model, k=k_graph) + + # Calculate actual number of patches + patch_num = (seq_len - patch_len) // stride + 1 + self.head = nn.Linear(patch_num * d_model, pred_len) + + def forward(self, x): # x [B,L,C] + x = x.permute(0,2,1) # → [B,C,L] + z = self.encoder(x) # [B,C,N,d] + B,C,N,D = z.shape + # 通道独立 -> 稀疏跨通道注入 + z_mix = self.mixer(z).view(B,C,N*D) # [B,C,N,d] + y_pred = self.head(z_mix) # [B,C,T] + + return y_pred + diff --git a/models/xPatch_SparseChannel/xPatch.py b/models/xPatch_SparseChannel/xPatch.py new file mode 100644 index 0000000..375f82a --- /dev/null +++ b/models/xPatch_SparseChannel/xPatch.py @@ -0,0 +1,56 @@ +import torch +import torch.nn as nn +import math + +from layers.decomp import DECOMP +from .xPatch_SparseChannel import Network +from layers.revin import RevIN + +class Model(nn.Module): + def __init__(self, configs): + super(Model, self).__init__() + + # Parameters + seq_len = configs.seq_len # lookback window L + pred_len = configs.pred_len # prediction length (96, 192, 336, 720) + c_in = configs.enc_in # input channels + + # Patching + patch_len = configs.patch_len + stride = configs.stride + padding_patch = configs.padding_patch + + # Normalization + self.revin = configs.revin + self.revin_layer = RevIN(c_in,affine=True,subtract_last=False) + + # Moving Average + self.ma_type = configs.ma_type + alpha = configs.alpha # smoothing factor for EMA (Exponential Moving Average) + beta = configs.beta # smoothing factor for DEMA (Double Exponential Moving Average) + + self.decomp = DECOMP(self.ma_type, alpha, beta) + self.net = Network(seq_len, pred_len, patch_len, stride, padding_patch,c_in) + # self.net_mlp = NetworkMLP(seq_len, pred_len) # For ablation study with MLP-only stream + # self.net_cnn = NetworkCNN(seq_len, pred_len, patch_len, stride, padding_patch) # For ablation study with CNN-only stream + + def forward(self, x): + # x: [Batch, Input, Channel] + + # Normalization + if self.revin: + x = self.revin_layer(x, 'norm') + + if self.ma_type == 'reg': # If no decomposition, directly pass the input to the network + x = self.net(x, x) + # x = self.net_mlp(x) # For ablation study with MLP-only stream + # x = self.net_cnn(x) # For ablation study with CNN-only stream + else: + seasonal_init, trend_init = self.decomp(x) + x = self.net(seasonal_init, trend_init) + + # Denormalization + if self.revin: + x = self.revin_layer(x, 'denorm') + + return x diff --git a/models/xPatch_SparseChannel/xPatch_SparseChannel.py b/models/xPatch_SparseChannel/xPatch_SparseChannel.py new file mode 100644 index 0000000..9ebdc4e --- /dev/null +++ b/models/xPatch_SparseChannel/xPatch_SparseChannel.py @@ -0,0 +1,63 @@ +import torch +from torch import nn +from .patchtst_ci import SeasonPatch # <<< 新导入 + +class Network(nn.Module): + """ + trend : 原 MLP 线性流 (完全保留) + season : 
SeasonPatch (PatchTST + Mixer) + """ + def __init__(self, seq_len, pred_len, + patch_len, stride, padding_patch, + c_in): + super().__init__() + + # -------- 季节性流 --------------- + self.season_net = SeasonPatch(c_in=c_in, + seq_len=seq_len, + pred_len=pred_len, + patch_len=patch_len, + stride=stride, + k_graph=5, + d_model=128) + + # --------- 线性趋势流 (原代码保持不变) ---------- + self.pred_len = pred_len + self.fc5 = nn.Linear(seq_len, pred_len * 4) + self.avgpool1 = nn.AvgPool1d(kernel_size=2) + self.ln1 = nn.LayerNorm(pred_len * 2) + + self.fc6 = nn.Linear(pred_len * 2, pred_len) + self.avgpool2 = nn.AvgPool1d(kernel_size=2) + self.ln2 = nn.LayerNorm(pred_len // 2) + + self.fc7 = nn.Linear(pred_len // 2, pred_len) + + # 流结果拼接 + self.fc_final = nn.Linear(pred_len * 2, pred_len) + + # ---------------- forward -------------------- + def forward(self, s, t): + # 输入形状: [B,L,C] + B,L,C = s.shape + + # ---------- Seasonality ------------ + y_season = self.season_net(s) # [B,C,T] + + # ---------- Trend (原 MLP) ---------- + t = t.permute(0,2,1).reshape(B*C, L) # [B*C,L] + t = self.fc5(t) + t = self.avgpool1(t) + t = self.ln1(t) + t = self.fc6(t) + t = self.avgpool2(t) + t = self.ln2(t) + t = self.fc7(t) # [B*C,T] + y_trend = t.view(B, C, -1) # [B,C,T] + + # --------- 拼接 & 输出 -------------- + y = torch.cat([y_season, y_trend], dim=-1) # [B,C,2T] + y = self.fc_final(y) # [B,C,T] + y = y.permute(0,2,1) # [B,T,C] + return y + diff --git a/train/train.py b/train/train.py index f1f05bd..d881efc 100644 --- a/train/train.py +++ b/train/train.py @@ -8,6 +8,8 @@ from torch.utils.data import DataLoader, TensorDataset import swanlab from typing import Dict, Any, Optional, Callable, Union, Tuple from dataflow import data_provider +from layers.ps_loss import PSLoss +from utils.tools import adjust_learning_rate, dotdict class EarlyStopping: """Early stopping to stop training when validation performance doesn't improve.""" @@ -138,7 +140,9 @@ def create_data_loaders_from_dataflow(args, use_x_mark: bool = True) -> Dict[str 'test': test_loader } -def create_data_loaders(data_path: str, batch_size: int = 32, use_x_mark: bool = True) -> Dict[str, DataLoader]: +def create_data_loaders(data_path: str, batch_size: int = 32, use_x_mark: bool = True, + num_workers: int = 4, pin_memory: bool = True, + persistent_workers: bool = True) -> Dict[str, DataLoader]: """ Create PyTorch DataLoaders from an NPZ file @@ -146,6 +150,9 @@ def create_data_loaders(data_path: str, batch_size: int = 32, use_x_mark: bool = data_path (str): Path to the NPZ file containing the data batch_size (int): Batch size for the DataLoaders use_x_mark (bool): Whether to use time features (x_mark) from the data file + num_workers (int): Number of worker processes for data loading + pin_memory (bool): Whether to pin memory for faster GPU transfer + persistent_workers (bool): Whether to keep workers alive between epochs Returns: Dict[str, DataLoader]: Dictionary with train, val, and test DataLoaders @@ -200,10 +207,34 @@ def create_data_loaders(data_path: str, batch_size: int = 32, use_x_mark: bool = val_dataset = TensorDataset(val_x, val_y) test_dataset = TensorDataset(test_x, test_y) - # Create dataloaders - train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) - val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False) - test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False) + # Create dataloaders with performance optimizations + train_loader = DataLoader( + train_dataset, + batch_size=batch_size, 
+ shuffle=True, + num_workers=num_workers, + pin_memory=pin_memory, + persistent_workers=persistent_workers if num_workers > 0 else False, + drop_last=True # Drop incomplete batches for training + ) + val_loader = DataLoader( + val_dataset, + batch_size=batch_size, + shuffle=False, + num_workers=num_workers, + pin_memory=pin_memory, + persistent_workers=persistent_workers if num_workers > 0 else False, + drop_last=False + ) + test_loader = DataLoader( + test_dataset, + batch_size=batch_size, + shuffle=False, + num_workers=num_workers, + pin_memory=pin_memory, + persistent_workers=persistent_workers if num_workers > 0 else False, + drop_last=False + ) return { 'train': train_loader, @@ -223,7 +254,12 @@ def train_forecasting_model( log_interval: int = 10, use_x_mark: bool = True, dataset_mode: str = "npz", - dataflow_args = None + dataflow_args = None, + use_ps_loss: bool = False, + ps_lambda: float = 5.0, + patch_len_threshold: int = 64, + use_gdw: bool = True, + lr_adjust_strategy: str = "type1" ) -> Tuple[nn.Module, Dict[str, float]]: """ Train a time series forecasting model @@ -241,6 +277,11 @@ def train_forecasting_model( use_x_mark (bool): Whether to use time features (x_mark) from the data file dataset_mode (str): Dataset construction mode - "npz" or "dataflow" dataflow_args: Arguments object for dataflow mode (required if dataset_mode="dataflow") + use_ps_loss (bool): Whether to use Patch-wise Structural (PS) loss instead of MSE + ps_lambda (float): Weight for PS loss component when combined with MSE + patch_len_threshold (int): Maximum patch length for adaptive patching + use_gdw (bool): Whether to use Gradient-based Dynamic Weighting + lr_adjust_strategy (str): Learning rate adjustment strategy - 'type1', 'type2', 'type3', 'sigmoid', 'constant', '3', '4', '5', '6' Returns: Tuple[nn.Module, Dict[str, float]]: Trained model and dictionary of evaluation metrics @@ -271,7 +312,10 @@ def train_forecasting_model( dataloaders = create_data_loaders( data_path=data_path, batch_size=config.get('batch_size', 32), - use_x_mark=use_x_mark + use_x_mark=use_x_mark, + num_workers=config.get('num_workers', 4), + pin_memory=config.get('pin_memory', True), + persistent_workers=config.get('persistent_workers', True) ) # Construct the model @@ -279,14 +323,24 @@ def train_forecasting_model( model = model.to(device) # Define loss function and optimizer - criterion = nn.MSELoss() + if use_ps_loss: + criterion = PSLoss( + patch_len_threshold=patch_len_threshold, + lambda_ps=ps_lambda, + use_gdw=use_gdw + ) + else: + criterion = nn.MSELoss() optimizer = optim.Adam( model.parameters(), lr=config.get('learning_rate', 1e-3), ) - # Add learning rate scheduler to halve LR after each epoch - scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5) + # Create args object for learning rate adjustment + lr_args = dotdict({ + 'learning_rate': config.get('learning_rate', 1e-3), + 'lradj': lr_adjust_strategy + }) # Initialize early stopping early_stopping = EarlyStopping( @@ -334,7 +388,11 @@ def train_forecasting_model( # For simple models without time features outputs = model(inputs) - loss = criterion(outputs, targets) + # Calculate loss + if use_ps_loss: + loss, loss_dict = criterion(outputs, targets, model) + else: + loss = criterion(outputs, targets) # Backward pass and optimize loss.backward() @@ -345,10 +403,26 @@ def train_forecasting_model( interval_loss += loss.item() if (batch_idx + 1) % log_interval == 0: - print(f"Batch {batch_idx+1}/{len(dataloaders['train'])}, Loss: 
{loss.item():.4f}") - # 计算这一个 interval 的平均损失并记录 - avg_interval_loss = interval_loss / log_interval - swanlab_run.log({"batch_train_loss": avg_interval_loss}) + if use_ps_loss and 'loss_dict' in locals(): + print(f"Batch {batch_idx+1}/{len(dataloaders['train'])}, " + f"Total Loss: {loss.item():.4f}, " + f"MSE: {loss_dict['mse_loss']:.4f}, " + f"PS: {loss_dict['ps_loss']:.4f}") + # Log detailed loss components + swanlab_run.log({ + "batch_total_loss": loss.item(), + "batch_mse_loss": loss_dict['mse_loss'], + "batch_ps_loss": loss_dict['ps_loss'], + "batch_corr_loss": loss_dict['corr_loss'], + "batch_var_loss": loss_dict['var_loss'], + "batch_mean_loss": loss_dict['mean_loss'], + "alpha": loss_dict['alpha'], + "beta": loss_dict['beta'], + "gamma": loss_dict['gamma'] + }) + else: + print(f"Batch {batch_idx+1}/{len(dataloaders['train'])}, Loss: {loss.item():.4f}") + swanlab_run.log({"batch_train_loss": loss.item()}) # 重置 interval loss 以进行下一次计算 interval_loss = 0.0 @@ -360,6 +434,7 @@ def train_forecasting_model( model.eval() val_loss = 0.0 val_mse = 0.0 + val_mse_criterion = nn.MSELoss() # Always use MSE for validation metrics with torch.no_grad(): for batch_data in dataloaders['val']: @@ -381,18 +456,28 @@ def train_forecasting_model( # For simple models without time features outputs = model(inputs) - # Calculate loss - loss = criterion(outputs, targets) - val_loss += loss.item() + # Calculate training loss (PS or MSE) + if use_ps_loss: + loss, _ = criterion(outputs, targets, model) + val_loss += loss.item() + else: + loss = criterion(outputs, targets) + val_loss += loss.item() + + # Always calculate MSE for validation metrics + mse_loss = val_mse_criterion(outputs, targets) + val_mse += mse_loss.item() avg_val_loss = val_loss / len(dataloaders['val']) + avg_val_mse = val_mse / len(dataloaders['val']) current_lr = optimizer.param_groups[0]['lr'] # Log metrics metrics_dict = { "train_loss": avg_train_loss, "val_loss": avg_val_loss, + "val_mse": avg_val_mse, "learning_rate": current_lr, "epoch_time": epoch_time } @@ -402,6 +487,7 @@ def train_forecasting_model( print(f"Epoch {epoch+1}/{max_epochs}, " f"Train Loss: {avg_train_loss:.4f}, " f"Val Loss: {avg_val_loss:.4f}, " + f"Val MSE: {avg_val_mse:.4f}, " f"LR: {current_lr:.6f}, " f"Time: {epoch_time:.2f}s") @@ -416,16 +502,17 @@ def train_forecasting_model( print("Early stopping triggered") break - # Step the learning rate scheduler - scheduler.step() + # Adjust learning rate using utils.tools function + adjust_learning_rate(optimizer, epoch, lr_args) # Load the best model model.load_state_dict(torch.load(checkpoint_path)) - # Test evaluation on the best model + # Test evaluation on the best model - Always use MSE for final evaluation model.eval() test_loss = 0.0 test_mse = 0.0 + mse_criterion = nn.MSELoss() # Always use MSE for test evaluation print("Evaluating on test set...") with torch.no_grad(): @@ -448,16 +535,16 @@ def train_forecasting_model( # For simple models without time features outputs = model(inputs) - # Calculate loss - loss = criterion(outputs, targets) - test_loss += loss.item() + # Always calculate MSE for test evaluation (for fair comparison) + mse_loss = mse_criterion(outputs, targets) + test_loss += mse_loss.item() test_loss /= len(dataloaders['test']) print(f"Test evaluation completed!") print(f"Test Loss (MSE): {test_loss:.6f}") - # Final validation for consistency + # Final validation for consistency - Always use MSE for final metrics model.eval() final_val_loss = 0.0 final_val_mse = 0.0 @@ -482,25 +569,31 @@ def 
train_forecasting_model( # For simple models without time features outputs = model(inputs) - # Calculate loss - loss = criterion(outputs, targets) - final_val_loss += loss.item() + # Always calculate MSE for final validation (for fair comparison) + mse_loss = mse_criterion(outputs, targets) + final_val_loss += mse_loss.item() final_val_loss /= len(dataloaders['val']) - print(f"Final validation loss: {final_val_loss:.6f}") + print(f"Final validation MSE: {final_val_loss:.6f}") + print(f"Final test MSE: {test_loss:.6f}") + + if use_ps_loss: + print("Note: Model was trained with PS Loss but evaluated with MSE for fair comparison") # Log final test results to swanlab final_metrics = { - "final_test_loss": test_loss, - "final_val_loss": final_val_loss + "final_test_mse": test_loss, + "final_val_mse": final_val_loss } swanlab_run.log(final_metrics) - # Update metrics with final values + # Update metrics with final values (always MSE for comparison) metrics["final_val_loss"] = final_val_loss metrics["final_test_loss"] = test_loss + metrics["final_val_mse"] = final_val_loss # Same as final_val_loss since we use MSE + metrics["final_test_mse"] = test_loss # Same as final_test_loss since we use MSE # Finish the swanlab run swanlab_run.finish() @@ -519,7 +612,8 @@ def train_classification_model( log_interval: int = 10, use_x_mark: bool = True, dataset_mode: str = "npz", - dataflow_args = None + dataflow_args = None, + lr_adjust_strategy: str = "type1" ) -> Tuple[nn.Module, Dict[str, float]]: """ Train a time series classification model @@ -537,6 +631,7 @@ def train_classification_model( use_x_mark (bool): Whether to use time features (x_mark) from the data file dataset_mode (str): Dataset construction mode - "npz" or "dataflow" dataflow_args: Arguments object for dataflow mode (required if dataset_mode="dataflow") + lr_adjust_strategy (str): Learning rate adjustment strategy - 'type1', 'type2', 'type3', 'sigmoid', 'constant', '3', '4', '5', '6' Returns: Tuple[nn.Module, Dict[str, float]]: Trained model and dictionary of evaluation metrics @@ -567,7 +662,10 @@ def train_classification_model( dataloaders = create_data_loaders( data_path=data_path, batch_size=config.get('batch_size', 32), - use_x_mark=use_x_mark + use_x_mark=use_x_mark, + num_workers=config.get('num_workers', 4), + pin_memory=config.get('pin_memory', True), + persistent_workers=config.get('persistent_workers', True) ) # Construct the model @@ -582,8 +680,11 @@ def train_classification_model( weight_decay=config.get('weight_decay', 1e-4) ) - # Add learning rate scheduler to halve LR after each epoch - scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5) + # Create args object for learning rate adjustment + lr_args = dotdict({ + 'learning_rate': config.get('learning_rate', 1e-3), + 'lradj': lr_adjust_strategy + }) # Initialize early stopping early_stopping = EarlyStopping( @@ -722,8 +823,8 @@ def train_classification_model( print("Early stopping triggered") break - # Step the learning rate scheduler - scheduler.step() + # Adjust learning rate using utils.tools function + adjust_learning_rate(optimizer, epoch, lr_args) # Load the best model model.load_state_dict(torch.load(checkpoint_path)) diff --git a/train/train_diffusion.py b/train/train_diffusion.py new file mode 100644 index 0000000..15277d1 --- /dev/null +++ b/train/train_diffusion.py @@ -0,0 +1,408 @@ +import os +import time +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim +from torch.utils.data import DataLoader, 
TensorDataset +import swanlab +from typing import Dict, Any, Optional, Callable, Union, Tuple +from dataflow import data_provider + +class EarlyStopping: + """Early stopping to stop training when validation performance doesn't improve.""" + def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt'): + self.patience = patience + self.verbose = verbose + self.counter = 0 + self.best_score = None + self.early_stop = False + self.val_loss_min = float('inf') + self.delta = delta + self.path = path + + def __call__(self, val_loss, model): + score = -val_loss + + if self.best_score is None: + self.best_score = score + self.save_checkpoint(val_loss, model) + elif score < self.best_score + self.delta: + self.counter += 1 + if self.verbose: + print(f'EarlyStopping counter: {self.counter} out of {self.patience}') + if self.counter >= self.patience: + self.early_stop = True + else: + self.best_score = score + self.save_checkpoint(val_loss, model) + self.counter = 0 + + def save_checkpoint(self, val_loss, model): + """Save model when validation loss decreases.""" + if self.verbose: + print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model...') + torch.save(model.state_dict(), self.path) + self.val_loss_min = val_loss + +class DatasetWrapperWithoutTimeFeatures(torch.utils.data.Dataset): + """Wrapper to remove time features from dataflow datasets when use_x_mark=False""" + def __init__(self, original_dataset): + self.original_dataset = original_dataset + + def __getitem__(self, index): + seq_x, seq_y, seq_x_mark, seq_y_mark = self.original_dataset[index] + return seq_x, seq_y + + def __len__(self): + return len(self.original_dataset) + + def inverse_transform(self, data): + if hasattr(self.original_dataset, 'inverse_transform'): + return self.original_dataset.inverse_transform(data) + return data + +def create_data_loaders_from_dataflow(args, use_x_mark: bool = True) -> Dict[str, DataLoader]: + """Create PyTorch DataLoaders using dataflow data_provider""" + train_data, _ = data_provider(args, flag='train') + val_data, _ = data_provider(args, flag='val') + test_data, _ = data_provider(args, flag='test') + + if not use_x_mark: + train_data = DatasetWrapperWithoutTimeFeatures(train_data) + val_data = DatasetWrapperWithoutTimeFeatures(val_data) + test_data = DatasetWrapperWithoutTimeFeatures(test_data) + + train_shuffle = True + val_shuffle = False + test_shuffle = False + + train_drop_last = True + val_drop_last = True + test_drop_last = True + + batch_size = args.batch_size + num_workers = args.num_workers + + train_loader = DataLoader( + train_data, batch_size=batch_size, shuffle=train_shuffle, + num_workers=num_workers, drop_last=train_drop_last + ) + val_loader = DataLoader( + val_data, batch_size=batch_size, shuffle=val_shuffle, + num_workers=num_workers, drop_last=val_drop_last + ) + test_loader = DataLoader( + test_data, batch_size=batch_size, shuffle=test_shuffle, + num_workers=num_workers, drop_last=test_drop_last + ) + + return {'train': train_loader, 'val': val_loader, 'test': test_loader} + +def create_data_loaders(data_path: str, batch_size: int = 32, use_x_mark: bool = True) -> Dict[str, DataLoader]: + """ + Create PyTorch DataLoaders from an NPZ file + + Args: + data_path (str): Path to the NPZ file containing the data + batch_size (int): Batch size for the DataLoaders + use_x_mark (bool): Whether to use time features (x_mark) from the data file + + Returns: + Dict[str, DataLoader]: Dictionary with train, val, and test DataLoaders + """ + 
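Note: the loader below expects the NPZ file to contain train_x/train_y, val_x/val_y and test_x/test_y, with optional *_x_mark/*_y_mark arrays for time features. An illustrative sketch of writing such a file (the shapes are placeholders, not taken from any real dataset):

import numpy as np

# Placeholder shapes: [num_windows, seq_len, channels] for *_x and
# [num_windows, pred_len, channels] for *_y.
np.savez(
    "dataset.npz",
    train_x=np.zeros((1000, 96, 7), dtype=np.float32),
    train_y=np.zeros((1000, 24, 7), dtype=np.float32),
    val_x=np.zeros((200, 96, 7), dtype=np.float32),
    val_y=np.zeros((200, 24, 7), dtype=np.float32),
    test_x=np.zeros((200, 96, 7), dtype=np.float32),
    test_y=np.zeros((200, 24, 7), dtype=np.float32),
)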
# Load data from NPZ file + data = np.load(data_path, allow_pickle=True) + train_x = data['train_x'] + train_y = data['train_y'] + val_x = data['val_x'] + val_y = data['val_y'] + test_x = data['test_x'] + test_y = data['test_y'] + + # Load time features if available and needed + if use_x_mark: + train_x_mark = data.get('train_x_mark', None) + train_y_mark = data.get('train_y_mark', None) + val_x_mark = data.get('val_x_mark', None) + val_y_mark = data.get('val_y_mark', None) + test_x_mark = data.get('test_x_mark', None) + test_y_mark = data.get('test_y_mark', None) + else: + train_x_mark = None + train_y_mark = None + val_x_mark = None + val_y_mark = None + test_x_mark = None + test_y_mark = None + + # Convert to PyTorch tensors + train_x = torch.FloatTensor(train_x) + train_y = torch.FloatTensor(train_y) + val_x = torch.FloatTensor(val_x) + val_y = torch.FloatTensor(val_y) + test_x = torch.FloatTensor(test_x) + test_y = torch.FloatTensor(test_y) + + # Create datasets based on whether time features are available + if train_x_mark is not None: + train_x_mark = torch.FloatTensor(train_x_mark) + train_y_mark = torch.FloatTensor(train_y_mark) + val_x_mark = torch.FloatTensor(val_x_mark) + val_y_mark = torch.FloatTensor(val_y_mark) + test_x_mark = torch.FloatTensor(test_x_mark) + test_y_mark = torch.FloatTensor(test_y_mark) + + train_dataset = TensorDataset(train_x, train_y, train_x_mark, train_y_mark) + val_dataset = TensorDataset(val_x, val_y, val_x_mark, val_y_mark) + test_dataset = TensorDataset(test_x, test_y, test_x_mark, test_y_mark) + else: + train_dataset = TensorDataset(train_x, train_y) + val_dataset = TensorDataset(val_x, val_y) + test_dataset = TensorDataset(test_x, test_y) + + # Create dataloaders + train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) + val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False) + test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False) + + return { + 'train': train_loader, + 'val': val_loader, + 'test': test_loader + } + +def train_diffusion_model( + model_constructor: Callable, + data_path: str, + project_name: str, + config: Dict[str, Any], + device: Optional[str] = None, + early_stopping_patience: int = 10, + max_epochs: int = 100, + checkpoint_dir: str = "./checkpoints", + log_interval: int = 10, +) -> Tuple[nn.Module, Dict[str, float]]: + """ + Train a Diffusion time series forecasting model using NPZ data loading + """ + # Setup device + if device is None: + device = 'cuda' if torch.cuda.is_available() else 'cpu' + + # Initialize swanlab for experiment tracking + swanlab_run = swanlab.init( + project=project_name, + config=config, + ) + + # Create checkpoint directory if it doesn't exist + os.makedirs(checkpoint_dir, exist_ok=True) + checkpoint_path = os.path.join(checkpoint_dir, f"{project_name}.pt") + + # Create data loaders using NPZ files (following other models' pattern) + dataloaders = create_data_loaders( + data_path=data_path, + batch_size=config.get('batch_size', 32), + use_x_mark=False # DiffusionTimeSeries doesn't use time features + ) + + # Construct the model + model = model_constructor() + model = model.to(device) + + print(f"Model created with {model.get_num_params():,} parameters") + + # Define optimizer for diffusion training + optimizer = optim.Adam( + model.parameters(), + lr=config.get('learning_rate', 1e-4), + weight_decay=config.get('weight_decay', 1e-4) + ) + + # Learning rate scheduler + scheduler = optim.lr_scheduler.ReduceLROnPlateau( + optimizer, 
mode='min', patience=5, factor=0.5 + ) + + # Initialize early stopping + early_stopping = EarlyStopping( + patience=early_stopping_patience, + verbose=True, + path=checkpoint_path + ) + + # Training loop + best_val_loss = float('inf') + metrics = {} + + for epoch in range(max_epochs): + print(f"\nEpoch {epoch+1}/{max_epochs}") + print("-" * 50) + + # Training phase + model.train() + train_loss = 0.0 + train_samples = 0 + + interval_loss = 0.0 + start_time = time.time() + + for batch_idx, (seq_x, seq_y) in enumerate(dataloaders['train']): + seq_x, seq_y = seq_x.to(device), seq_y.to(device) + + optimizer.zero_grad() + + # Diffusion training: model returns loss directly when y is provided + loss = model(seq_x, seq_y) + + loss.backward() + optimizer.step() + + train_loss += loss.item() + interval_loss += loss.item() + train_samples += 1 + + # Log at intervals + if (batch_idx + 1) % log_interval == 0: + elapsed_time = time.time() - start_time + avg_interval_loss = interval_loss / log_interval + + print(f' Batch [{batch_idx+1}/{len(dataloaders["train"])}] ' + f'Loss: {avg_interval_loss:.6f} ' + f'Time: {elapsed_time:.2f}s') + + # Log to swanlab + swanlab.log({ + 'batch_loss': avg_interval_loss, + 'batch': epoch * len(dataloaders['train']) + batch_idx, + 'learning_rate': optimizer.param_groups[0]['lr'] + }) + + interval_loss = 0.0 + start_time = time.time() + + avg_train_loss = train_loss / train_samples + + # Validation phase - Use faster sampling for validation + model.eval() + val_loss = 0.0 + val_samples = 0 + criterion = nn.MSELoss() + + print(" Validating...") + with torch.no_grad(): + # Temporarily reduce diffusion steps for faster validation + original_timesteps = model.diffusion.num_timesteps + model.diffusion.num_timesteps = 200# Much faster validation + + for batch_idx, (seq_x, seq_y) in enumerate(dataloaders['val']): + seq_x, seq_y = seq_x.to(device), seq_y.to(device) + + # Generate predictions (inference mode with reduced steps) + pred = model(seq_x) + + # Compute MSE loss for validation + loss = criterion(pred, seq_y) + val_loss += loss.item() + val_samples += 1 + + # Print validation progress for first epoch + if epoch == 0 and (batch_idx + 1) % 50 == 0: + print(f" Val Batch [{batch_idx+1}/{len(dataloaders['val'])}]") + + # Early break for very first epoch to speed up + if epoch == 0 and batch_idx >= 100: # Only validate on first 100 batches for first epoch + break + + # Restore original timesteps + model.diffusion.num_timesteps = original_timesteps + + avg_val_loss = val_loss / val_samples + + # Learning rate scheduling + scheduler.step(avg_val_loss) + current_lr = optimizer.param_groups[0]['lr'] + + print(f" Train Loss: {avg_train_loss:.6f}") + print(f" Val Loss: {avg_val_loss:.6f}") + print(f" Learning Rate: {current_lr:.2e}") + + # Log to swanlab + swanlab.log({ + 'epoch': epoch + 1, + 'train_loss': avg_train_loss, + 'val_loss': avg_val_loss, + 'learning_rate': current_lr + }) + + # Early stopping check + early_stopping(avg_val_loss, model) + + if early_stopping.early_stop: + print(f"Early stopping at epoch {epoch + 1}") + break + + # Load best model + model.load_state_dict(torch.load(checkpoint_path, map_location=device)) + + # Final evaluation on test set + print("\nEvaluating on test set...") + model.eval() + test_loss = 0.0 + test_samples = 0 + all_preds = [] + all_targets = [] + + with torch.no_grad(): + # Use reduced timesteps for faster testing + original_timesteps = model.diffusion.num_timesteps + model.diffusion.num_timesteps = 200 # Faster but still good quality + + 
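Note: the save/lower/restore of model.diffusion.num_timesteps is written out twice (validation and test). A small context manager would make the restore exception-safe; this is only a sketch and assumes nothing beyond the diffusion.num_timesteps attribute already used above:

from contextlib import contextmanager

@contextmanager
def reduced_timesteps(model, steps=200):
    # Temporarily lower the number of diffusion steps for faster sampling.
    original = model.diffusion.num_timesteps
    model.diffusion.num_timesteps = steps
    try:
        yield
    finally:
        model.diffusion.num_timesteps = original    # restored even if an error occurs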
for batch_idx, (seq_x, seq_y) in enumerate(dataloaders['test']): + seq_x, seq_y = seq_x.to(device), seq_y.to(device) + + pred = model(seq_x) + + loss = criterion(pred, seq_y) + test_loss += loss.item() + test_samples += 1 + + all_preds.append(pred.cpu().numpy()) + all_targets.append(seq_y.cpu().numpy()) + + # Print progress every 50 batches + if (batch_idx + 1) % 50 == 0: + print(f" Test Batch [{batch_idx+1}/{len(dataloaders['test'])}]") + + # Restore original timesteps + model.diffusion.num_timesteps = original_timesteps + + avg_test_loss = test_loss / test_samples + + # Calculate additional metrics + all_preds = np.concatenate(all_preds, axis=0) + all_targets = np.concatenate(all_targets, axis=0) + + mse = np.mean((all_preds - all_targets) ** 2) + mae = np.mean(np.abs(all_preds - all_targets)) + rmse = np.sqrt(mse) + + metrics = { + 'test_mse': mse, + 'test_mae': mae, + 'test_rmse': rmse, + 'test_loss': avg_test_loss + } + + print(f"Test Results:") + print(f" MSE: {mse:.6f}") + print(f" MAE: {mae:.6f}") + print(f" RMSE: {rmse:.6f}") + + # Log final results + swanlab.log(metrics) + + swanlab.finish() + + return model, metrics diff --git a/train_xpatch_sparse_multi_datasets.py b/train_xpatch_sparse_multi_datasets.py new file mode 100644 index 0000000..4dfb269 --- /dev/null +++ b/train_xpatch_sparse_multi_datasets.py @@ -0,0 +1,329 @@ +#!/usr/bin/env python3 +""" +Training script for xPatch_SparseChannel model on multiple datasets. +Supports Weather, Traffic, Electricity, Exchange, and ILI datasets with sigmoid learning rate adjustment. +""" + +import os +import math +import argparse +import torch +import torch.nn as nn +from train.train import train_forecasting_model +from models.xPatch_SparseChannel.xPatch import Model as xPatchSparseChannel + +# Dataset configurations +DATASET_CONFIGS = { + 'Weather': { + 'csv_file': 'weather.csv', + 'enc_in': 21, + 'batch_size': 2048, + 'learning_rate': 0.0005, + 'target': 'OT', + 'data_path': 'weather/weather.csv', + 'seq_len': 96, + 'pred_lengths': [96, 192, 336, 720] + }, + 'Traffic': { + 'csv_file': 'traffic.csv', + 'enc_in': 862, + 'batch_size': 32, + 'learning_rate': 0.002, + 'target': 'OT', + 'data_path': 'traffic/traffic.csv', + 'seq_len': 96, + 'pred_lengths': [96, 192, 336, 720] + }, + 'Electricity': { + 'csv_file': 'electricity.csv', + 'enc_in': 321, + 'batch_size': 32, + 'learning_rate': 0.001, + 'target': 'OT', + 'data_path': 'electricity/electricity.csv', + 'seq_len': 96, + 'pred_lengths': [96, 192, 336, 720] + }, + 'Exchange': { + 'csv_file': 'exchange_rate.csv', + 'enc_in': 8, + 'batch_size': 128, + 'learning_rate': 0.00005, + 'target': 'OT', + 'data_path': 'exchange_rate/exchange_rate.csv', + 'seq_len': 96, + 'pred_lengths': [96, 192, 336, 720] + }, + 'ILI': { + 'csv_file': 'national_illness.csv', + 'enc_in': 7, + 'batch_size': 32, + 'learning_rate': 0.01, + 'target': 'ot', + 'data_path': 'illness/national_illness.csv', + 'seq_len': 36, + 'pred_lengths': [24, 36, 48, 60] + } +} + +class Args: + """Configuration class for xPatch_SparseChannel model parameters.""" + def __init__(self, dataset_name, pred_len): + dataset_config = DATASET_CONFIGS[dataset_name] + + # Model architecture parameters + self.task_name = 'long_term_forecast' + self.seq_len = dataset_config['seq_len'] # Use dataset-specific seq_len + self.label_len = self.seq_len // 2 # Half of seq_len as label length + self.pred_len = pred_len + self.enc_in = dataset_config['enc_in'] + self.c_out = dataset_config['enc_in'] + + # xPatch specific parameters from reference + 
self.patch_len = 16 # patch length + self.stride = 8 # stride + self.padding_patch = 'end' # padding on the end + + # Moving Average parameters + self.ma_type = 'ema' # moving average type + self.alpha = 0.3 # alpha parameter for EMA + self.beta = 0.3 # beta parameter for DEMA + + # RevIN normalization + self.revin = 1 # RevIN; True 1 False 0 + + # Time features (not used by xPatch but required by data loader) + self.embed = 'timeF' # Time feature embedding type + self.freq = 'h' # Frequency for time features (hourly) + + # Dataset specific parameters + self.data = 'custom' + self.root_path = './data/' + self.data_path = dataset_config['data_path'] + self.features = 'M' # Multivariate prediction + self.target = dataset_config['target'] # Target column + self.train_only = False + + # Required for dataflow - will be set by config + self.batch_size = dataset_config['batch_size'] + self.num_workers = 8 # Will be overridden by config + + print(f"xPatch_SparseChannel Model configuration for {dataset_name}:") + print(f" - Input channels (C): {self.enc_in}") + print(f" - Patch length: {self.patch_len}") + print(f" - Stride: {self.stride}") + print(f" - Sequence length: {self.seq_len}") # Now dataset-specific + print(f" - Prediction length: {pred_len}") + print(f" - Moving average type: {self.ma_type}") + print(f" - Alpha: {self.alpha}") + print(f" - Beta: {self.beta}") + print(f" - RevIN: {self.revin}") + print(f" - Target: {self.target}") + print(f" - Batch size: {self.batch_size}") + +def create_xpatch_sparse_model(args): + """Create xPatch_SparseChannel model with given configuration.""" + def model_constructor(): + return xPatchSparseChannel(args) + return model_constructor + +def train_single_dataset(dataset_name, pred_len, model_args, cmd_args, use_ps_loss=True): + """Train xPatch_SparseChannel on specified dataset with given prediction length.""" + + dataset_config = DATASET_CONFIGS[dataset_name] + + # Update args for current prediction length + model_args.pred_len = pred_len + + # Update dataflow parameters from command line args + model_args.num_workers = cmd_args.num_workers + + # Create model constructor + model_constructor = create_xpatch_sparse_model(model_args) + + # Training configuration with dataset-specific parameters + config = { + 'learning_rate': dataset_config['learning_rate'], # Dataset-specific learning rate + 'batch_size': dataset_config['batch_size'], # Dataset-specific batch size + 'weight_decay': 1e-4, + 'dataset': dataset_name, + 'pred_len': pred_len, + 'seq_len': model_args.seq_len, + 'patch_len': model_args.patch_len, + 'stride': model_args.stride, + 'ma_type': model_args.ma_type, + 'use_ps_loss': use_ps_loss, + 'num_workers': cmd_args.num_workers, + 'pin_memory': True, + 'persistent_workers': True + } + + # Project name for tracking + loss_suffix = "_PSLoss" if use_ps_loss else "_MSE" + project_name = f"xPatch_SparseChannel_{dataset_name}_pred{pred_len}{loss_suffix}_sigmoid" + + print(f"\n{'='*60}") + print(f"Training {dataset_name} with prediction length {pred_len}") + print(f"Model: xPatch_SparseChannel") + print(f"Loss function: {'PS_Loss' if use_ps_loss else 'MSE'}") + print(f"Learning rate: {dataset_config['learning_rate']}") + print(f"Batch size: {dataset_config['batch_size']}") + print(f"Features: {dataset_config['enc_in']}") + print(f"Data path: {model_args.root_path}{model_args.data_path}") + print(f"LR adjustment: sigmoid") + print(f"{'='*60}") + + # Train the model + try: + model, metrics = train_forecasting_model( + model_constructor=model_constructor, + 
data_path=f"{model_args.root_path}{model_args.data_path}", + project_name=project_name, + config=config, + early_stopping_patience=5, + max_epochs=100, + checkpoint_dir="./checkpoints", + log_interval=50, + use_x_mark=False, # xPatch_SparseChannel doesn't use time features + use_ps_loss=use_ps_loss, + ps_lambda=cmd_args.ps_lambda, + patch_len_threshold=64, + use_gdw=True, + dataset_mode="dataflow", + dataflow_args=model_args, + lr_adjust_strategy="sigmoid" # Use sigmoid learning rate adjustment + ) + + print(f"Training completed for {project_name}") + if use_ps_loss: + print(f"Final validation MSE: {metrics.get('final_val_mse', 'N/A'):.6f}") + else: + print(f"Final validation MSE: {metrics.get('final_val_loss', 'N/A'):.6f}") + + return model, metrics + + except Exception as e: + print(f"Error training {project_name}: {e}") + import traceback + traceback.print_exc() + return None, None + +def main(): + parser = argparse.ArgumentParser(description='Train xPatch_SparseChannel on multiple datasets with sigmoid LR adjustment') + parser.add_argument('--datasets', nargs='+', type=str, + default=['Weather', 'Traffic', 'Electricity', 'Exchange', 'ILI'], + choices=['Weather', 'Traffic', 'Electricity', 'Exchange', 'ILI'], + help='List of datasets to train on') + parser.add_argument('--use_ps_loss', action='store_true', default=True, + help='Use PS_Loss instead of MSE') + parser.add_argument('--ps_lambda', type=float, default=5.0, + help='Weight for PS loss component') + parser.add_argument('--device', type=str, default=None, + help='Device to use for training (cuda/cpu)') + parser.add_argument('--num_workers', type=int, default=8, + help='Number of data loading workers') + + args = parser.parse_args() + + print("xPatch_SparseChannel Multi-Dataset Training Script with Sigmoid LR Adjustment") + print("=" * 80) + print(f"Datasets: {args.datasets}") + print(f"Use PS_Loss: {args.use_ps_loss}") + print(f"PS_Lambda: {args.ps_lambda}") + print(f"Number of workers: {args.num_workers}") + print(f"Learning rate adjustment: sigmoid") + + # Display dataset configurations + print("\nDataset Configurations:") + for dataset in args.datasets: + config = DATASET_CONFIGS[dataset] + print(f" {dataset}:") + print(f" - Features: {config['enc_in']}") + print(f" - Batch size: {config['batch_size']}") + print(f" - Learning rate: {config['learning_rate']}") + print(f" - Sequence length: {config['seq_len']}") + print(f" - Prediction lengths: {config['pred_lengths']}") + print(f" - Data path: {config['data_path']}") + + # Check if data files exist + missing_datasets = [] + for dataset in args.datasets: + config = DATASET_CONFIGS[dataset] + data_path = f"./data/{config['data_path']}" + if not os.path.exists(data_path): + missing_datasets.append(f"{dataset}: '{data_path}'") + + if missing_datasets: + print(f"\nError: The following dataset files were not found:") + for missing in missing_datasets: + print(f" - {missing}") + print("Please ensure all dataset files are available in the data/ directory.") + return + + # Set device + if args.device is None: + device = 'cuda' if torch.cuda.is_available() else 'cpu' + else: + device = args.device + print(f"\nUsing device: {device}") + + # Training results storage + all_results = {} + + # Train on each dataset + for dataset in args.datasets: + print(f"\n{'#'*80}") + print(f"STARTING TRAINING ON {dataset.upper()} DATASET") + print(f"{'#'*80}") + + all_results[dataset] = {} + config = DATASET_CONFIGS[dataset] + + # Train on each prediction length for this dataset + for pred_len in 
config['pred_lengths']: # Use dataset-specific prediction lengths + # Create model configuration for current dataset + model_args = Args( + dataset_name=dataset, + pred_len=pred_len + ) + + # Train the model + model, metrics = train_single_dataset( + dataset_name=dataset, + pred_len=pred_len, + model_args=model_args, + cmd_args=args, + use_ps_loss=args.use_ps_loss + ) + + # Store results + all_results[dataset][pred_len] = { + 'model': model, + 'metrics': metrics, + 'data_path': f"./data/{config['data_path']}" + } + + # Print comprehensive summary + print("\n" + "=" * 100) + print("COMPREHENSIVE TRAINING SUMMARY") + print("=" * 100) + + for dataset in args.datasets: + config = DATASET_CONFIGS[dataset] + print(f"\n{dataset} (Features: {config['enc_in']}, Batch: {config['batch_size']}, LR: {config['learning_rate']}, Seq: {config['seq_len']}):") + for pred_len in all_results[dataset]: + result = all_results[dataset][pred_len] + if result['metrics'] is not None: + if args.use_ps_loss: + mse = result['metrics'].get('final_val_mse', 'N/A') + else: + mse = result['metrics'].get('final_val_loss', 'N/A') + print(f" Pred Length {pred_len}: MSE = {mse}") + else: + print(f" Pred Length {pred_len}: Training failed") + + print(f"\nAll models saved in: ./checkpoints/") + print("All datasets training completed!") + +if __name__ == "__main__": + main()
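For a single configuration outside the CLI loop, the same building blocks can be driven directly. A minimal usage sketch (argparse.Namespace simply stands in for the parsed command-line arguments that train_single_dataset reads):

import argparse

cmd_args = argparse.Namespace(use_ps_loss=True, ps_lambda=5.0,
                              device=None, num_workers=4)
model_args = Args(dataset_name='Weather', pred_len=96)
model, metrics = train_single_dataset(
    dataset_name='Weather', pred_len=96,
    model_args=model_args, cmd_args=cmd_args,
    use_ps_loss=cmd_args.use_ps_loss,
)
print(metrics)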