diff --git a/src/diffusion/flow_matching/adam_sampling.py b/src/diffusion/flow_matching/adam_sampling.py index 15d0c78..1d6d119 100644 --- a/src/diffusion/flow_matching/adam_sampling.py +++ b/src/diffusion/flow_matching/adam_sampling.py @@ -59,7 +59,7 @@ class AdamLMSampler(BaseSampler): timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps) timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0) self.timesteps = shift_respace_fn(timesteps, timeshift) - self.timedeltas = timesteps[1:] - self.timesteps[:-1] + self.timedeltas = self.timesteps[1:] - self.timesteps[:-1] self._reparameterize_coeffs() def _reparameterize_coeffs(self): diff --git a/src/diffusion/stateful_flow_matching/adam_sampling.py b/src/diffusion/stateful_flow_matching/adam_sampling.py index fb2e95b..013504b 100644 --- a/src/diffusion/stateful_flow_matching/adam_sampling.py +++ b/src/diffusion/stateful_flow_matching/adam_sampling.py @@ -61,7 +61,7 @@ class AdamLMSampler(BaseSampler): timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps) timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0) self.timesteps = shift_respace_fn(timesteps, timeshift) - self.timedeltas = timesteps[1:] - self.timesteps[:-1] + self.timedeltas = self.timesteps[1:] - self.timesteps[:-1] self._reparameterize_coeffs() def _reparameterize_coeffs(self):