feat: add DiffusionTimeSeries and iTransformer models, introduce xPatch_SparseChannel
models/DiffusionTimeSeries/diffusion_ts.py · 323 additions · Normal file
@@ -0,0 +1,323 @@
import math
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


# ----------------- Utilities: build / rescale linear betas -----------------
def make_linear_betas(T: int, beta_start=1e-4, beta_end=2e-2, device='cpu'):
    return torch.linspace(beta_start, beta_end, T, device=device)


def cumprod_from_betas(betas: torch.Tensor):
    alphas = 1.0 - betas
    return torch.cumprod(alphas, dim=0)  # shape [T]


@torch.no_grad()
def scale_betas_to_target_cumprod(betas: torch.Tensor, target_cumprod: float, max_scale: float = 100.0):
    """
    Given betas[1..T], find a scale factor s > 0 such that ∏(1 - s*beta_i) = target_cumprod.
    Search for s on (0, s_max) by bisection, keeping 0 < s*beta_i < 1.
    """
    device = betas.device
    eps = 1e-12
    s_low = 0.0
    s_high = min(max_scale, (1.0 - 1e-6) / (betas.max().item() + eps))  # keep 1 - s*beta > 0

    def cumprod_with_scale(s: float):
        a = (1.0 - betas * s).clamp(min=1e-6, max=1.0 - 1e-6)
        return torch.cumprod(a, dim=0)[-1].item()

    # If the unscaled schedule already hits the target, return it unchanged.
    base = cumprod_with_scale(1.0)
    if abs(base - target_cumprod) / max(target_cumprod, 1e-12) < 1e-6:
        return betas

    # The cumulative product is monotone in s on (0, s_high), so bisect.
    for _ in range(60):
        mid = 0.5 * (s_low + s_high)
        val = cumprod_with_scale(mid)
        if val > target_cumprod:
            # product still above the target (noise too weak): need a larger s
            s_low = mid
        else:
            s_high = mid
    s_best = 0.5 * (s_low + s_high)
    return (betas * s_best).clamp(min=1e-8, max=1 - 1e-6)
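
# Example (illustrative comment only, not part of the original module):
#   betas  = make_linear_betas(100)                       # linear 1e-4 .. 2e-2
#   scaled = scale_betas_to_target_cumprod(betas, 0.25)   # rescale the schedule
#   cumprod_from_betas(scaled)[-1]                        # ≈ 0.25, up to bisection tolerance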


# ------------------------------ DiT Blocks --------------------------------
class DiTBlock(nn.Module):
    def __init__(self, dim: int, heads: int, mlp_ratio=4.0):
        super().__init__()
        self.ln1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.ln2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, int(dim * mlp_ratio)),
            nn.GELU(),
            nn.Linear(int(dim * mlp_ratio), dim),
        )

    def forward(self, x):
        # x: [B, C_tokens, D]
        xn = self.ln1(x)
        h = self.attn(xn, xn, xn)[0]
        x = x + h
        x = x + self.mlp(self.ln2(x))
        return x


class DiTChannelTokens(nn.Module):
    """
    One token = one channel (variable).
    For each channel the input is a length-L time vector; two projections are used:
      - W_x : projects the x_t time vector into the token vector
      - W_n : projects the noise level (the time vector of ā or b̄ from the schedule) into a token bias
    Note: no learnable t-embedding is used; the noise condition comes entirely from the noise map.
    """
    def __init__(self, L: int, C: int, dim: int = 256, depth: int = 8, heads: int = 8):
        super().__init__()
        self.L = L
        self.C = C
        self.dim = dim

        # channel embedding (optional, distinguishes the variables)
        self.channel_embed = nn.Parameter(torch.randn(C, dim) * 0.02)

        # map each channel's time series to a token
        self.proj_x = nn.Linear(L, dim, bias=False)
        # map each channel's per-timestep noise level (e.g. √ā, or a concatenation such as
        # [√ā, √(1-ā)] merged into the L dimension) through one linear layer
        self.proj_noise = nn.Linear(L, dim, bias=True)

        self.blocks = nn.ModuleList([DiTBlock(dim, heads) for _ in range(depth)])
        self.ln_f = nn.LayerNorm(dim)

        # project back to the time length L to predict ε (applied to each channel token independently)
        self.head = nn.Linear(dim, L, bias=False)

    def forward(self, x_t: torch.Tensor, noise_feat: torch.Tensor):
        """
        x_t       : [B, L, C]
        noise_feat: [B, L, C] (recommended: pass √ā; a concatenation can also be merged into the
                    L dimension first -- a single projection suffices here)
        returns ε̂ : [B, L, C]
        """
        B, L, C = x_t.shape
        assert L == self.L and C == self.C

        # map each channel to a token:
        # reshape (B, L, C) to (B, C, L), then apply the linear projections
        x_tc = x_t.permute(0, 2, 1)         # [B, C, L]
        n_tc = noise_feat.permute(0, 2, 1)  # [B, C, L]

        tok = self.proj_x(x_tc) + self.proj_noise(n_tc)  # [B, C, D]
        tok = tok + self.channel_embed.unsqueeze(0)      # broadcast [1, C, D]

        for blk in self.blocks:
            tok = blk(tok)  # [B, C, D]

        tok = self.ln_f(tok)
        out = self.head(tok)             # [B, C, L]
        eps_pred = out.permute(0, 2, 1)  # [B, L, C]
        return eps_pred
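
# Shape sketch (illustrative comment only; the sizes below are assumptions, not repository defaults):
#   net = DiTChannelTokens(L=96 + 24, C=7)
#   x_t, noise_feat : [B, 120, 7]      # noisy series and per-element √ā
#   eps_hat = net(x_t, noise_feat)     # -> [B, 120, 7]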


# ------------- RAD two-phase diffusion (channels as tokens) -------------
class RADChannelDiT(nn.Module):
    def __init__(self,
                 past_len: int,
                 future_len: int,
                 channels: int,
                 T: int = 1000,
                 T1_ratio: float = 0.7,
                 model_dim: int = 256,
                 depth: int = 8,
                 heads: int = 8,
                 beta_start: float = 1e-4,
                 beta_end: float = 2e-2,
                 use_cosine_target: bool = True):
        """
        - Training : two phases (Phase-1 + Phase-2), t sampled uniformly from [1..T]
        - Inference: Phase-1 only (t: T1 -> 1), only the future region is updated
        - Token = channel; each token sees the whole time axis plus the per-timestep noise-level vector
        """
        super().__init__()
        self.P = past_len
        self.H = future_len
        self.C = channels
        self.L = past_len + future_len
        self.T = T
        self.T1 = max(1, int(T * T1_ratio))
        self.T2 = T - self.T1
        assert self.T2 >= 1, "T1_ratio too large: at least 1 step must be left for Phase-2"

        device = torch.device('cpu')

        # target ā_T (normalises the two linear schedules to the same final noise level)
        if use_cosine_target:
            # use the cosine schedule as reference to obtain a global target ā_T
            steps = T + 1
            x = torch.linspace(0, T, steps, dtype=torch.float64)
            s = 0.008
            alphas_cum = torch.cos(((x / T) + s) / (1 + s) * math.pi / 2) ** 2
            alphas_cum = alphas_cum / alphas_cum[0]
            a_bar_target_T = float(alphas_cum[-1])
        else:
            # use the plain DDPM linear-beta schedule to set the target
            betas_full = make_linear_betas(T, beta_start, beta_end, device)
            a_bar_target_T = cumprod_from_betas(betas_full)[-1].item()

        # raw linear betas for Phase-1 & Phase-2
        betas1 = make_linear_betas(self.T1, beta_start, beta_end, device)
        betas2 = make_linear_betas(self.T2, beta_start, beta_end, device)

        # first compute ā1[T1], ā2[T2] without scaling
        a_bar1 = cumprod_from_betas(betas1)  # shape [T1]
        a_bar2 = cumprod_from_betas(betas2)  # shape [T2]

        # rescale the Phase-2 betas so that ā1[T1] * ā2'[T2] = target ā_T
        target_a2 = a_bar_target_T / (a_bar1[-1].item() + 1e-12)
        betas2 = scale_betas_to_target_cumprod(betas2, target_a2)

        # recompute
        # a_bar1 = cumprod_from_betas(betas1).float()  # [T1]
        a_bar2 = cumprod_from_betas(betas2).float()  # [T2]

        self.register_buffer("betas1", betas1.float())
        self.register_buffer("betas2", betas2.float())
        self.register_buffer("alphas1", 1.0 - betas1.float())
        self.register_buffer("alphas2", 1.0 - betas2.float())
        self.register_buffer("a_bar1", a_bar1)
        self.register_buffer("a_bar2", a_bar2)
        self.register_buffer("a_bar_target_T", torch.tensor(a_bar_target_T, dtype=torch.float32))

        # backbone: token = channel
        self.backbone = DiTChannelTokens(L=self.L, C=self.C, dim=model_dim, depth=depth, heads=heads)

    # ---------------- internal: build the mask & per-element āt,i ----------------
    def _mask_future(self, B, device):
        # mask: future region = 1, history = 0, shape [B, L, C] (aligned with the network input [B, L, C])
        m = torch.zeros(B, self.L, self.C, device=device)
        m[:, self.P:, :] = 1.0
        return m

    def _a_bar_map_at_t(self, t_scalar: int, B: int, device, mask_future: torch.Tensor):
        """
        Build the per-element āt,i, shape [B, L, C]
        - if t <= T1: the future region uses ā1[t], the history region = 1
        - if t >  T1: the future region stays at ā1[T1], the history region uses ā2[t-T1]
        """
        if t_scalar <= self.T1:
            a_future = self.a_bar1[t_scalar - 1]  # 0-based index
            a_past = torch.tensor(1.0, device=device)
        else:
            a_future = self.a_bar1[-1]
            a_past = self.a_bar2[t_scalar - self.T1 - 1]

        a_future_map = torch.full((B, self.L, self.C), float(a_future.item()), device=device)
        a_past_map = torch.full((B, self.L, self.C), float(a_past.item()), device=device)
        a_map = a_past_map * (1 - mask_future) + a_future_map * mask_future
        return a_map  # [B, L, C]
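
    # Illustration (comment only; assumes the defaults T=1000, T1_ratio=0.7, i.e. T1=700):
    #   t = 350 (<= T1): future elements get ā1[350], history elements get 1.0
    #   t = 900 (>  T1): future elements stay at ā1[T1] = ā1[700], history elements get ā2[200]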

    # ----------------------------- training forward -----------------------------
    def forward(self, x_hist: torch.Tensor, x_future: torch.Tensor) -> Tuple[torch.Tensor, dict]:
        """
        x_hist   : [B, P, C]
        x_future : [B, H, C]
        Training: sample t ∈ [1..T], build the two-phase āt,i, add marginal noise to get xt,
        and predict ε with the channel-token DiT.
        """
        B = x_hist.size(0)
        device = x_hist.device
        x0 = torch.cat([x_hist, x_future], dim=1)  # [B, L, C]

        # sample the training step t (1..T)
        t = torch.randint(1, self.T + 1, (B,), device=device, dtype=torch.long)

        # build the mask and the per-element āt,i
        mask_fut = self._mask_future(B, device)  # [B, L, C]
        # build āt,i per sample (t differs across samples, so use a loop or a vectorised trick;
        # B is usually small, so a for-loop is fine)
        a_bar_map = torch.stack([self._a_bar_map_at_t(int(tt.item()), 1, device, mask_fut[0:1])
                                 for tt in t], dim=0).squeeze(1)  # [B,L,C]

        # marginal noising
        eps = torch.randn_like(x0)  # [B,L,C]
        x_t = a_bar_map.sqrt() * x0 + (1.0 - a_bar_map).sqrt() * eps

        # spatial noise embedding: fully determined by the schedule
        # per-element √ā and √(1-ā) could be passed (or either one); √ā is used here
        noise_feat = a_bar_map.sqrt()  # [B,L,C]

        # predict ε
        eps_pred = self.backbone(x_t, noise_feat)  # [B,L,C]

        loss = F.mse_loss(eps_pred, eps)
        return loss, {'t_mean': t.float().mean().item()}

    # ----------------------------- sampling / inference -----------------------------
    @torch.no_grad()
    def sample(self, x_hist: torch.Tensor, steps: Optional[int] = None) -> torch.Tensor:
        """
        Phase-1-only inference: t = T1..1; only the future region is updated,
        the history keeps its observed values.
        x_hist : [B,P,C]
        return : [B,H,C]
        """
        B = x_hist.size(0)
        device = x_hist.device
        mask_fut = self._mask_future(B, device)  # [B,L,C]

        # initialise x: history = observations, future = Gaussian noise
        x = torch.zeros(B, self.L, self.C, device=device)
        x[:, :self.P, :] = x_hist
        x[:, self.P:, :] = torch.randn(B, self.H, self.C, device=device)

        # support sub-sampling: uniformly subsample [T1..1] down to `steps`
        T1 = self.T1
        steps = steps if steps is not None else T1
        steps = max(1, min(steps, T1))
        ts = torch.linspace(T1, 1, steps, device=device).long().tolist()
        # the DDPM update needs α_t, β_t (defined on the future region only)
        alphas1 = self.alphas1  # [T1]
        betas1 = self.betas1
        a_bar1 = self.a_bar1

        for idx, t_scalar in enumerate(ts):
            # current āt,i (Phase-1: history = 1, future = ā1[t])
            a_bar_map = self._a_bar_map_at_t(int(t_scalar), B, device, mask_fut)  # [B,L,C]
            # network condition: √ā as the noise embedding
            noise_feat = a_bar_map.sqrt()
            # predict ε
            eps_pred = self.backbone(x, noise_feat)  # [B,L,C]

            # one DDPM step on the future region (the history region keeps its values)
            # standard DDPM formula (all elements in the future region share the same α_t, β_t)
            t_idx = t_scalar - 1
            alpha_t = alphas1[t_idx]  # scalar
            beta_t = betas1[t_idx]
            a_bar_t = a_bar1[t_idx]
            if t_scalar > 1:
                a_bar_prev = a_bar1[t_idx - 1]
            else:
                a_bar_prev = torch.tensor(1.0, device=device)

            # x0 prediction (kept for reference only; the μ formula below is used directly)
            x0_pred = (x - (1.0 - a_bar_t).sqrt() * eps_pred) / (a_bar_t.sqrt() + 1e-8)

            # mean: μ_t = 1/sqrt(α_t) * (x_t - β_t / sqrt(1 - ā_t) * ε̂)
            mean = (x - (beta_t / (1.0 - a_bar_t).sqrt()) * eps_pred) / (alpha_t.sqrt() + 1e-8)

            # sample noise
            if t_scalar > 1:
                z = torch.randn_like(x)
            else:
                z = torch.zeros_like(x)

            # variance term (DDPM): σ_t = sqrt(β_t)
            x_next = mean + z * beta_t.sqrt()

            # only replace the future region
            x = x * (1 - mask_fut) + x_next * mask_fut
            # force the history back to the observations
            x[:, :self.P, :] = x_hist

        return x[:, self.P:, :]  # [B,H,C]
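

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only): the batch size, lengths, channel count and
    # reduced model size below are assumptions for a quick smoke test, not repository defaults.
    torch.manual_seed(0)
    B, P, H, C = 4, 96, 24, 7
    model = RADChannelDiT(past_len=P, future_len=H, channels=C,
                          T=1000, T1_ratio=0.7, model_dim=64, depth=2, heads=4)

    x_hist = torch.randn(B, P, C)
    x_future = torch.randn(B, H, C)

    # training step: two-phase marginal noising + ε-prediction MSE
    loss, info = model(x_hist, x_future)
    loss.backward()
    print("train loss:", loss.item(), "| mean sampled t:", info['t_mean'])

    # inference: Phase-1-only reverse diffusion, updating the future region only
    forecast = model.sample(x_hist, steps=50)  # [B, H, C]
    print("forecast shape:", tuple(forecast.shape))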