This commit is contained in:
wangshuai6
2025-04-11 19:14:51 +08:00
parent c1c8043ed1
commit 687d650175

View File

@@ -56,7 +56,7 @@ class EulerSampler(BaseSampler):
logger.warning("current sampler is ODE sampler, but w_scheduler is enabled")
# init recompute
self.num_recompute_timesteps = int(self.num_steps / self.state_refresh_rate)
self.recompute_timesteps = list(range(self.num_steps))
def sharing_dp(self, net, noise, condition, uncondition):
@@ -143,6 +143,7 @@ class EulerSampler(BaseSampler):
return x, pooled_state_list
def __call__(self, net, noise, condition, uncondition):
self.num_recompute_timesteps = int(self.num_steps / self.state_refresh_rate)
if len(self.recompute_timesteps) != self.num_recompute_timesteps:
self.recompute_timesteps = self.sharing_dp(net, noise, condition, uncondition)
denoised, _ = self._impl_sampling(net, noise, condition, uncondition)