import math
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


# ----------------------- Utilities: linear beta schedules -----------------------
def make_linear_betas(T: int, beta_start=1e-4, beta_end=2e-2, device='cpu'):
    """Return a linear beta schedule of length T over [beta_start, beta_end]."""
    return torch.linspace(beta_start, beta_end, T, device=device)


def cumprod_from_betas(betas: torch.Tensor):
    """Return the cumulative product of alphas = 1 - betas; shape [T]."""
    alphas = 1.0 - betas
    return torch.cumprod(alphas, dim=0)  # shape [T]


@torch.no_grad()
def scale_betas_to_target_cumprod(betas: torch.Tensor,
                                  target_cumprod: float,
                                  max_scale: float = 100.0):
    """
    Given betas[1..T], find a scale s > 0 such that prod(1 - s*beta_i) == target_cumprod.

    Bisection on (0, s_max); the upper bound keeps 0 < s*beta_i < 1 throughout.
    Returns the scaled (and clamped) betas; if the unscaled schedule already hits
    the target, the input is returned unchanged.
    """
    eps = 1e-12
    s_low = 0.0
    # Upper bound keeps 1 - s*beta strictly positive.
    s_high = min(max_scale, (1.0 - 1e-6) / (betas.max().item() + eps))

    def cumprod_with_scale(s: float):
        a = (1.0 - betas * s).clamp(min=1e-6, max=1.0 - 1e-6)
        return torch.cumprod(a, dim=0)[-1].item()

    # If the unscaled product is already close to the target, return as-is.
    base = cumprod_with_scale(1.0)
    if abs(base - target_cumprod) / max(target_cumprod, 1e-12) < 1e-6:
        return betas

    # The final cumprod is monotonically decreasing in s, so bisect.
    for _ in range(60):
        mid = 0.5 * (s_low + s_high)
        val = cumprod_with_scale(mid)
        if val > target_cumprod:
            # Product too large -> noise too weak -> need a larger s.
            s_low = mid
        else:
            s_high = mid
    s_best = 0.5 * (s_low + s_high)
    return (betas * s_best).clamp(min=1e-8, max=1 - 1e-6)


# ------------------------------ DiT Blocks --------------------------------
class DiTBlock(nn.Module):
    """Pre-LayerNorm transformer block: residual self-attention + residual MLP."""

    def __init__(self, dim: int, heads: int, mlp_ratio=4.0):
        super().__init__()
        self.ln1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.ln2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, int(dim * mlp_ratio)),
            nn.GELU(),
            nn.Linear(int(dim * mlp_ratio), dim),
        )

    def forward(self, x):
        # x: [B, C_tokens, D]
        # Fix: normalize once instead of calling ln1(x) three times
        # (the three calls were identical, so values are unchanged).
        q = self.ln1(x)
        x = x + self.attn(q, q, q)[0]
        x = x + self.mlp(self.ln2(x))
        return x


class DiTChannelTokens(nn.Module):
    """
    DiT backbone where each token is one channel (variable).

    For every channel the input is a length-L time vector; two projections are used:
      - proj_x     : maps the x_t time vector into a token vector
      - proj_noise : maps the per-time noise level (from the schedule's a_bar)
                     into a token bias
    Note: no learnable t-embedding is used; the noise conditioning comes entirely
    from the noise map supplied by the schedule.
    """

    def __init__(self, L: int, C: int, dim: int = 256, depth: int = 8, heads: int = 8):
        super().__init__()
        self.L = L
        self.C = C
        self.dim = dim
        # Channel embedding (optional; lets the model distinguish variables).
        self.channel_embed = nn.Parameter(torch.randn(C, dim) * 0.02)
        # Map each channel's time series to a token.
        self.proj_x = nn.Linear(L, dim, bias=False)
        # Map each channel's per-time noise strength (e.g. sqrt(a_bar)) to a token bias.
        self.proj_noise = nn.Linear(L, dim, bias=True)
        self.blocks = nn.ModuleList([DiTBlock(dim, heads) for _ in range(depth)])
        self.ln_f = nn.LayerNorm(dim)
        # Project tokens back to length L to predict epsilon (one projection per channel).
        self.head = nn.Linear(dim, L, bias=False)

    def forward(self, x_t: torch.Tensor, noise_feat: torch.Tensor):
        """
        x_t        : [B, L, C]
        noise_feat : [B, L, C]  (e.g. sqrt(a_bar) per position; a single projection
                     of this time vector is used as the noise conditioning)
        returns eps_hat : [B, L, C]
        """
        B, L, C = x_t.shape
        assert L == self.L and C == self.C
        # (B, L, C) -> (B, C, L): each channel becomes one token.
        x_tc = x_t.permute(0, 2, 1)        # [B, C, L]
        n_tc = noise_feat.permute(0, 2, 1)  # [B, C, L]
        tok = self.proj_x(x_tc) + self.proj_noise(n_tc)  # [B, C, D]
        tok = tok + self.channel_embed.unsqueeze(0)      # broadcast [1, C, D]
        for blk in self.blocks:
            tok = blk(tok)                               # [B, C, D]
        tok = self.ln_f(tok)
        out = self.head(tok)                             # [B, C, L]
        eps_pred = out.permute(0, 2, 1)                  # [B, L, C]
        return eps_pred


# ----------------------- RAD two-phase diffusion (channel tokens) -----------------------
class RADChannelDiT(nn.Module):
    """
    Two-phase region-adaptive diffusion with a channel-token DiT backbone.

    - Training : two phases (Phase-1 + Phase-2); t is sampled uniformly from [1..T].
    - Sampling : Phase-1 only (t: T1 -> 1); only the future region is updated.
    - Token = channel; each token sees the whole time axis plus a per-time
      noise-strength vector.
    """

    def __init__(self, past_len: int, future_len: int, channels: int,
                 T: int = 1000, T1_ratio: float = 0.7,
                 model_dim: int = 256, depth: int = 8, heads: int = 8,
                 beta_start: float = 1e-4, beta_end: float = 2e-2,
                 use_cosine_target: bool = True):
        super().__init__()
        self.P = past_len
        self.H = future_len
        self.C = channels
        self.L = past_len + future_len
        self.T = T
        self.T1 = max(1, int(T * T1_ratio))
        self.T2 = T - self.T1
        assert self.T2 >= 1, "T1_ratio 不能太大,至少留下 1 步给 Phase-2"

        device = torch.device('cpu')

        # Target a_bar_T, used to normalize the two linear schedules to the same
        # final noise strength.
        if use_cosine_target:
            # Use the cosine schedule's final a_bar as the global target.
            steps = T + 1
            x = torch.linspace(0, T, steps, dtype=torch.float64)
            s = 0.008
            alphas_cum = torch.cos(((x / T) + s) / (1 + s) * math.pi / 2) ** 2
            alphas_cum = alphas_cum / alphas_cum[0]
            a_bar_target_T = float(alphas_cum[-1])
        else:
            # Use the plain DDPM linear-beta endpoint as the target.
            betas_full = make_linear_betas(T, beta_start, beta_end, device)
            a_bar_target_T = cumprod_from_betas(betas_full)[-1].item()

        # Raw linear betas for Phase-1 and Phase-2.
        betas1 = make_linear_betas(self.T1, beta_start, beta_end, device)
        betas2 = make_linear_betas(self.T2, beta_start, beta_end, device)

        # Unscaled a_bar1[T1] (Phase-1 is kept as-is).
        a_bar1 = cumprod_from_betas(betas1)  # [T1]

        # Scale Phase-2 betas so that a_bar1[T1] * a_bar2'[T2] == target a_bar_T.
        target_a2 = a_bar_target_T / (a_bar1[-1].item() + 1e-12)
        betas2 = scale_betas_to_target_cumprod(betas2, target_a2)
        a_bar2 = cumprod_from_betas(betas2).float()  # [T2]

        self.register_buffer("betas1", betas1.float())
        self.register_buffer("betas2", betas2.float())
        self.register_buffer("alphas1", 1.0 - betas1.float())
        self.register_buffer("alphas2", 1.0 - betas2.float())
        self.register_buffer("a_bar1", a_bar1)
        self.register_buffer("a_bar2", a_bar2)
        self.register_buffer("a_bar_target_T",
                             torch.tensor(a_bar_target_T, dtype=torch.float32))

        # Backbone: token = channel.
        self.backbone = DiTChannelTokens(L=self.L, C=self.C, dim=model_dim,
                                         depth=depth, heads=heads)

    # ------------------------ internals: mask & a_bar_{t,i} ------------------------
    def _mask_future(self, B, device):
        # Mask: future region = 1, history = 0; shape [B, L, C] (matches the
        # network's [B, L, C] input layout).
        m = torch.zeros(B, self.L, self.C, device=device)
        m[:, self.P:, :] = 1.0
        return m

    def _a_bar_map_at_t(self, t_scalar: int, B: int, device, mask_future: torch.Tensor):
        """
        Build the per-position a_bar_{t,i} map, shape [B, L, C].
        - t <= T1 : future region uses a_bar1[t], history stays at 1
        - t >  T1 : future region frozen at a_bar1[T1], history uses a_bar2[t - T1]
        """
        if t_scalar <= self.T1:
            a_future = float(self.a_bar1[t_scalar - 1].item())  # 0-based index
            a_past = 1.0
        else:
            a_future = float(self.a_bar1[-1].item())
            a_past = float(self.a_bar2[t_scalar - self.T1 - 1].item())
        a_future_map = torch.full((B, self.L, self.C), a_future, device=device)
        a_past_map = torch.full((B, self.L, self.C), a_past, device=device)
        a_map = a_past_map * (1 - mask_future) + a_future_map * mask_future
        return a_map  # [B, L, C]

    # ----------------------------- training forward -----------------------------
    def forward(self, x_hist: torch.Tensor, x_future: torch.Tensor) -> Tuple[torch.Tensor, dict]:
        """
        x_hist   : [B, P, C]
        x_future : [B, H, C]
        Training: sample t in [1..T], build the two-phase a_bar_{t,i} map, apply
        marginal noising to get x_t, and predict eps with the channel-token DiT.
        Returns (MSE loss on eps, info dict with the batch-mean t).
        """
        B = x_hist.size(0)
        device = x_hist.device
        x0 = torch.cat([x_hist, x_future], dim=1)  # [B, L, C]

        # Sample training step t in [1..T].
        t = torch.randint(1, self.T + 1, (B,), device=device, dtype=torch.long)

        mask_fut = self._mask_future(B, device)  # [B, L, C]

        # Per-sample a_bar map, vectorized (fix: replaces a Python loop that
        # allocated B full-size [1,L,C] maps and stacked them).
        in_p1 = t <= self.T1
        idx1 = (t - 1).clamp(min=0, max=self.T1 - 1)
        idx2 = (t - self.T1 - 1).clamp(min=0, max=self.T2 - 1)
        a_fut = torch.where(in_p1, self.a_bar1[idx1], self.a_bar1[-1])            # [B]
        a_past = torch.where(in_p1, self.a_bar1.new_ones(()), self.a_bar2[idx2])  # [B]
        a_bar_map = (a_past.view(B, 1, 1) * (1 - mask_fut)
                     + a_fut.view(B, 1, 1) * mask_fut)  # [B, L, C]

        # Marginal noising.
        eps = torch.randn_like(x0)  # [B, L, C]
        x_t = a_bar_map.sqrt() * x0 + (1.0 - a_bar_map).sqrt() * eps

        # Spatial noise embedding: fully determined by the schedule.
        # sqrt(a_bar) per position is used (sqrt(1 - a_bar) would also work).
        noise_feat = a_bar_map.sqrt()  # [B, L, C]

        eps_pred = self.backbone(x_t, noise_feat)  # [B, L, C]
        loss = F.mse_loss(eps_pred, eps)
        return loss, {'t_mean': t.float().mean().item()}

    # ----------------------------- sampling -----------------------------
    @torch.no_grad()
    def sample(self, x_hist: torch.Tensor, steps: Optional[int] = None) -> torch.Tensor:
        """
        Phase-1-only inference: t = T1..1; only the future region is updated,
        the history is clamped to the observed values each step.
        x_hist : [B, P, C]
        returns: [B, H, C]
        """
        B = x_hist.size(0)
        device = x_hist.device
        mask_fut = self._mask_future(B, device)  # [B, L, C]

        # Init x: history = observation, future = Gaussian noise.
        x = torch.zeros(B, self.L, self.C, device=device)
        x[:, :self.P, :] = x_hist
        x[:, self.P:, :] = torch.randn(B, self.H, self.C, device=device)

        # Optional sub-sampling: [T1..1] uniformly down-sampled to `steps`.
        T1 = self.T1
        steps = steps if steps is not None else T1
        steps = max(1, min(steps, T1))
        ts = torch.linspace(T1, 1, steps, device=device).long().tolist()

        # DDPM update needs alpha_t, beta_t (defined for the future region only).
        alphas1 = self.alphas1  # [T1]
        betas1 = self.betas1
        a_bar1 = self.a_bar1

        for t_scalar in ts:
            # Current a_bar_{t,i} (Phase-1: history = 1, future = a_bar1[t]).
            a_bar_map = self._a_bar_map_at_t(int(t_scalar), B, device, mask_fut)
            # Network conditioning: sqrt(a_bar) as the noise embedding.
            noise_feat = a_bar_map.sqrt()
            eps_pred = self.backbone(x, noise_feat)  # [B, L, C]

            # One standard DDPM step on the future region (all future positions
            # share the same alpha_t / beta_t; history keeps its value).
            # NOTE(review): when steps < T1 this still applies the single-step
            # posterior of the current t — an approximation, not a DDIM-style
            # skip with a recomputed a_bar_prev. Confirm this is intended.
            t_idx = t_scalar - 1
            alpha_t = alphas1[t_idx]  # scalar tensors
            beta_t = betas1[t_idx]
            a_bar_t = a_bar1[t_idx]

            # mu_t = (x_t - beta_t / sqrt(1 - a_bar_t) * eps_hat) / sqrt(alpha_t)
            mean = (x - (beta_t / (1.0 - a_bar_t).sqrt()) * eps_pred) / (alpha_t.sqrt() + 1e-8)

            # Sampling noise: none at the last step.
            z = torch.randn_like(x) if t_scalar > 1 else torch.zeros_like(x)

            # DDPM variance: sigma_t = sqrt(beta_t).
            x_next = mean + z * beta_t.sqrt()

            # Replace only the future region, then force the history to the observation.
            x = x * (1 - mask_fut) + x_next * mask_fut
            x[:, :self.P, :] = x_hist

        return x[:, self.P:, :]  # [B, H, C]