From 4d4b3262c379a679399b2ff240d5c2a44ba5721b Mon Sep 17 00:00:00 2001 From: wangshuai6 Date: Tue, 20 May 2025 12:25:50 +0800 Subject: [PATCH] fix bugs(admas timedeltas) --- src/diffusion/flow_matching/adam_sampling.py | 2 +- src/diffusion/stateful_flow_matching/adam_sampling.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusion/flow_matching/adam_sampling.py b/src/diffusion/flow_matching/adam_sampling.py index 15d0c78..1d6d119 100644 --- a/src/diffusion/flow_matching/adam_sampling.py +++ b/src/diffusion/flow_matching/adam_sampling.py @@ -59,7 +59,7 @@ class AdamLMSampler(BaseSampler): timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps) timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0) self.timesteps = shift_respace_fn(timesteps, timeshift) - self.timedeltas = timesteps[1:] - self.timesteps[:-1] + self.timedeltas = self.timesteps[1:] - self.timesteps[:-1] self._reparameterize_coeffs() def _reparameterize_coeffs(self): diff --git a/src/diffusion/stateful_flow_matching/adam_sampling.py b/src/diffusion/stateful_flow_matching/adam_sampling.py index fb2e95b..013504b 100644 --- a/src/diffusion/stateful_flow_matching/adam_sampling.py +++ b/src/diffusion/stateful_flow_matching/adam_sampling.py @@ -61,7 +61,7 @@ class AdamLMSampler(BaseSampler): timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps) timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0) self.timesteps = shift_respace_fn(timesteps, timeshift) - self.timedeltas = timesteps[1:] - self.timesteps[:-1] + self.timedeltas = self.timesteps[1:] - self.timesteps[:-1] self._reparameterize_coeffs() def _reparameterize_coeffs(self):