Files
tsmodel/models/DiffusionTimeSeries/diffusion_ts.py

324 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import math
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
# ----------------- Utilities: build / scale linear betas -----------------
def make_linear_betas(T: int, beta_start=1e-4, beta_end=2e-2, device='cpu'):
    """Build a linear variance (beta) schedule of length T on the given device."""
    schedule = torch.linspace(beta_start, beta_end, T, device=device)
    return schedule
def cumprod_from_betas(betas: torch.Tensor):
    """Cumulative product of alphas (alpha_i = 1 - beta_i); returns shape [T]."""
    return torch.cumprod(1.0 - betas, dim=0)
@torch.no_grad()
def scale_betas_to_target_cumprod(betas: torch.Tensor, target_cumprod: float, max_scale: float = 100.0):
    """
    Find a scale s > 0 such that prod_i(1 - s * beta_i) == target_cumprod,
    searching by bisection on (0, s_max] while keeping 0 < s*beta_i < 1.

    Args:
        betas: 1-D tensor of per-step betas.
        target_cumprod: desired final cumulative product of alphas.
        max_scale: upper bound for the bisection interval.

    Returns:
        Scaled betas clamped to (1e-8, 1 - 1e-6). If the unscaled betas
        already reach the target (relative error < 1e-6), the input tensor
        is returned unchanged.
    """
    eps = 1e-12
    s_low = 0.0
    # Cap s so that 1 - s*beta stays strictly positive for the largest beta.
    s_high = min(max_scale, (1.0 - 1e-6) / (betas.max().item() + eps))

    def cumprod_with_scale(s: float) -> float:
        a = (1.0 - betas * s).clamp(min=1e-6, max=1.0 - 1e-6)
        return torch.cumprod(a, dim=0)[-1].item()

    # Already close enough without scaling: return as-is.
    base = cumprod_with_scale(1.0)
    if abs(base - target_cumprod) / max(target_cumprod, 1e-12) < 1e-6:
        return betas

    # The final cumprod is monotonically decreasing in s, so the target is
    # reachable by bisection on (0, s_high).
    for _ in range(60):
        mid = 0.5 * (s_low + s_high)
        val = cumprod_with_scale(mid)
        if val > target_cumprod:
            # Product too large -> noise too weak -> need a larger s.
            s_low = mid
        else:
            s_high = mid
    s_best = 0.5 * (s_low + s_high)
    return (betas * s_best).clamp(min=1e-8, max=1 - 1e-6)
# ------------------------------ DiT Blocks --------------------------------
class DiTBlock(nn.Module):
def __init__(self, dim: int, heads: int, mlp_ratio=4.0):
super().__init__()
self.ln1 = nn.LayerNorm(dim)
self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
self.ln2 = nn.LayerNorm(dim)
self.mlp = nn.Sequential(
nn.Linear(dim, int(dim * mlp_ratio)),
nn.GELU(),
nn.Linear(int(dim * mlp_ratio), dim),
)
def forward(self, x):
# x: [B, C_tokens, D]
h = self.attn(self.ln1(x), self.ln1(x), self.ln1(x))[0]
x = x + h
x = x + self.mlp(self.ln2(x))
return x
class DiTChannelTokens(nn.Module):
    """
    Channel-as-token DiT backbone.

    Each token is one channel (variable). Per channel, the input is a
    length-L time vector; two projections are used:
      - proj_x     : maps the x_t time vector of a channel to a token vector
      - proj_noise : maps the per-time noise-level vector (e.g. sqrt(a_bar)
                     from the schedule) to a token bias
    Note: no learnable t-embedding is used — the noise condition comes
    entirely from the noise map.
    """

    def __init__(self, L: int, C: int, dim: int = 256, depth: int = 8, heads: int = 8):
        super().__init__()
        self.L = L
        self.C = C
        self.dim = dim
        # Optional channel embedding so tokens can distinguish variables.
        self.channel_embed = nn.Parameter(torch.randn(C, dim) * 0.02)
        # Per-channel time series -> token vector.
        self.proj_x = nn.Linear(L, dim, bias=False)
        # Per-channel per-time noise strength -> token bias.
        self.proj_noise = nn.Linear(L, dim, bias=True)
        self.blocks = nn.ModuleList(DiTBlock(dim, heads) for _ in range(depth))
        self.ln_f = nn.LayerNorm(dim)
        # Project tokens back to length L to predict epsilon per channel.
        self.head = nn.Linear(dim, L, bias=False)

    def forward(self, x_t: torch.Tensor, noise_feat: torch.Tensor):
        """
        Args:
            x_t:        [B, L, C] noisy input.
            noise_feat: [B, L, C] per-position noise strength (e.g. sqrt(a_bar)).
        Returns:
            Predicted epsilon, shape [B, L, C].
        """
        B, L, C = x_t.shape
        assert L == self.L and C == self.C
        # Reshape to [B, C, L] so each channel becomes one token after the
        # linear projections over the time axis.
        tok = self.proj_x(x_t.transpose(1, 2)) + self.proj_noise(noise_feat.transpose(1, 2))
        tok = tok + self.channel_embed.unsqueeze(0)  # broadcast [1, C, D]
        for block in self.blocks:
            tok = block(tok)  # [B, C, D]
        out = self.head(self.ln_f(tok))  # [B, C, L]
        return out.transpose(1, 2)  # back to [B, L, C]
# -------------- RAD two-phase diffusion (channel-as-token) ---------------
class RADChannelDiT(nn.Module):
    """
    Two-phase (RAD-style) diffusion forecaster with a channel-as-token DiT.

    - Training uses both phases (Phase-1 + Phase-2); t is sampled uniformly
      from [1..T].
    - Inference uses Phase-1 only (t: T1 -> 1) and updates the future region
      only; the history keeps its observed values.
    - Token = channel; each token sees the whole time axis plus the per-time
      noise-strength vector.
    """

    def __init__(self,
                 past_len: int,
                 future_len: int,
                 channels: int,
                 T: int = 1000,
                 T1_ratio: float = 0.7,
                 model_dim: int = 256,
                 depth: int = 8,
                 heads: int = 8,
                 beta_start: float = 1e-4,
                 beta_end: float = 2e-2,
                 use_cosine_target: bool = True):
        super().__init__()
        self.P = past_len
        self.H = future_len
        self.C = channels
        self.L = past_len + future_len
        self.T = T
        self.T1 = max(1, int(T * T1_ratio))
        self.T2 = T - self.T1
        assert self.T2 >= 1, "T1_ratio 不能太大,至少留下 1 步给 Phase-2"
        device = torch.device('cpu')
        # Target a_bar_T: normalizes the two linear schedules to the same
        # final overall noise strength.
        if use_cosine_target:
            # Derive a global target a_bar_T from the cosine schedule.
            steps = T + 1
            x = torch.linspace(0, T, steps, dtype=torch.float64)
            s = 0.008
            alphas_cum = torch.cos(((x / T) + s) / (1 + s) * math.pi / 2) ** 2
            alphas_cum = alphas_cum / alphas_cum[0]
            a_bar_target_T = float(alphas_cum[-1])
        else:
            # Use the final cumprod of a plain linear-beta DDPM as the target.
            betas_full = make_linear_betas(T, beta_start, beta_end, device)
            a_bar_target_T = cumprod_from_betas(betas_full)[-1].item()
        # Raw linear betas for Phase-1 and Phase-2.
        betas1 = make_linear_betas(self.T1, beta_start, beta_end, device)
        betas2 = make_linear_betas(self.T2, beta_start, beta_end, device)
        # Unscaled cumulative products a_bar1[T1], a_bar2[T2].
        a_bar1 = cumprod_from_betas(betas1)  # shape [T1]
        a_bar2 = cumprod_from_betas(betas2)  # shape [T2]
        # Rescale Phase-2 betas so that a_bar1[T1] * a_bar2'[T2] == target a_bar_T.
        target_a2 = a_bar_target_T / (a_bar1[-1].item() + 1e-12)
        betas2 = scale_betas_to_target_cumprod(betas2, target_a2)
        # Recompute Phase-2 cumprod with the scaled betas (Phase-1 unchanged).
        a_bar2 = cumprod_from_betas(betas2).float()  # [T2]
        self.register_buffer("betas1", betas1.float())
        self.register_buffer("betas2", betas2.float())
        self.register_buffer("alphas1", 1.0 - betas1.float())
        self.register_buffer("alphas2", 1.0 - betas2.float())
        self.register_buffer("a_bar1", a_bar1)
        self.register_buffer("a_bar2", a_bar2)
        self.register_buffer("a_bar_target_T", torch.tensor(a_bar_target_T, dtype=torch.float32))
        # Backbone: token = channel.
        self.backbone = DiTChannelTokens(L=self.L, C=self.C, dim=model_dim, depth=depth, heads=heads)

    # ---------------------- internals: mask & a_bar map ----------------------
    def _mask_future(self, B, device):
        # mask: future region = 1, history = 0; shape [B, L, C] to match the
        # network input layout.
        m = torch.zeros(B, self.L, self.C, device=device)
        m[:, self.P:, :] = 1.0
        return m

    def _a_bar_map_at_t(self, t_scalar: int, B: int, device, mask_future: torch.Tensor):
        """
        Build the per-position a_bar_{t,i} map, shape [B, L, C]:
          - if t <= T1: future region uses a_bar1[t], history region stays 1
          - if t >  T1: future region is fixed at a_bar1[T1], history region
                        uses a_bar2[t - T1]
        """
        if t_scalar <= self.T1:
            a_future = self.a_bar1[t_scalar - 1]  # schedule is 0-indexed
            a_past = torch.tensor(1.0, device=device)
        else:
            a_future = self.a_bar1[-1]
            a_past = self.a_bar2[t_scalar - self.T1 - 1]
        a_future_map = torch.full((B, self.L, self.C), float(a_future.item()), device=device)
        a_past_map = torch.full((B, self.L, self.C), float(a_past.item()), device=device)
        a_map = a_past_map * (1 - mask_future) + a_future_map * mask_future
        return a_map  # [B, L, C]

    # ------------------------------ training ------------------------------
    def forward(self, x_hist: torch.Tensor, x_future: torch.Tensor) -> Tuple[torch.Tensor, dict]:
        """
        Args:
            x_hist:   [B, P, C] observed history.
            x_future: [B, H, C] ground-truth future.
        Training step: sample t in [1..T], build the two-phase a_bar_{t,i}
        map, apply marginal noising to get x_t, and predict epsilon with the
        channel-token DiT.
        Returns:
            (mse loss, stats dict with mean sampled t).
        """
        B = x_hist.size(0)
        device = x_hist.device
        x0 = torch.cat([x_hist, x_future], dim=1)  # [B, L, C]
        # Sample a training step t in 1..T per sample.
        t = torch.randint(1, self.T + 1, (B,), device=device, dtype=torch.long)
        mask_fut = self._mask_future(B, device)  # [B, L, C]
        # Per-sample a_bar maps; t differs per sample, so a small Python loop
        # is used (B is typically small).
        a_bar_map = torch.stack([self._a_bar_map_at_t(int(tt.item()), 1, device, mask_fut[0:1])
                                 for tt in t], dim=0).squeeze(1)  # [B, L, C]
        # Marginal noising: x_t = sqrt(a_bar) * x0 + sqrt(1 - a_bar) * eps.
        eps = torch.randn_like(x0)  # [B, L, C]
        x_t = a_bar_map.sqrt() * x0 + (1.0 - a_bar_map).sqrt() * eps
        # Spatial noise embedding, fully determined by the schedule; sqrt(a_bar)
        # is passed per position.
        noise_feat = a_bar_map.sqrt()  # [B, L, C]
        eps_pred = self.backbone(x_t, noise_feat)  # [B, L, C]
        loss = F.mse_loss(eps_pred, eps)
        return loss, {'t_mean': t.float().mean().item()}

    # ------------------------------ sampling ------------------------------
    @torch.no_grad()
    def sample(self, x_hist: torch.Tensor, steps: Optional[int] = None) -> torch.Tensor:
        """
        Phase-1-only inference: t = T1..1; only the future region is updated,
        the history is forced back to the observation every step.

        Args:
            x_hist: [B, P, C] observed history.
            steps:  optional number of timesteps; [T1..1] is uniformly
                    subsampled down to this count.
        Returns:
            Forecast, shape [B, H, C].
        """
        B = x_hist.size(0)
        device = x_hist.device
        mask_fut = self._mask_future(B, device)  # [B, L, C]
        # Initialize x: history = observation, future = Gaussian noise.
        x = torch.zeros(B, self.L, self.C, device=device)
        x[:, :self.P, :] = x_hist
        x[:, self.P:, :] = torch.randn(B, self.H, self.C, device=device)
        T1 = self.T1
        steps = steps if steps is not None else T1
        steps = max(1, min(steps, T1))
        ts = torch.linspace(T1, 1, steps, device=device).long().tolist()
        # alpha_t, beta_t for the DDPM update (defined on the future region).
        alphas1 = self.alphas1  # [T1]
        betas1 = self.betas1
        a_bar1 = self.a_bar1
        # NOTE(review): when steps < T1 the update below still uses the
        # single-step alpha_t/beta_t at each subsampled t — confirm this
        # subsampled DDPM update is intended (a DDIM-style step would use the
        # a_bar of the next subsampled t).
        for t_scalar in ts:
            # Phase-1 a_bar map: history = 1, future = a_bar1[t].
            a_bar_map = self._a_bar_map_at_t(int(t_scalar), B, device, mask_fut)  # [B, L, C]
            # Network condition: sqrt(a_bar) as noise embedding.
            noise_feat = a_bar_map.sqrt()
            eps_pred = self.backbone(x, noise_feat)  # [B, L, C]
            # Standard DDPM step; all future positions share the scalar
            # alpha_t / beta_t (history is restored below).
            t_idx = t_scalar - 1
            alpha_t = alphas1[t_idx]  # scalar
            beta_t = betas1[t_idx]
            a_bar_t = a_bar1[t_idx]
            # mu_t = 1/sqrt(alpha_t) * (x_t - beta_t / sqrt(1 - a_bar_t) * eps_hat)
            mean = (x - (beta_t / (1.0 - a_bar_t).sqrt()) * eps_pred) / (alpha_t.sqrt() + 1e-8)
            # No noise at the final step.
            z = torch.randn_like(x) if t_scalar > 1 else torch.zeros_like(x)
            # DDPM variance choice: sigma_t = sqrt(beta_t).
            x_next = mean + z * beta_t.sqrt()
            # Replace the future region only.
            x = x * (1 - mask_fut) + x_next * mask_fut
            # Force the history back to the observation.
            x[:, :self.P, :] = x_hist
        return x[:, self.P:, :]  # [B, H, C]