app demo
This commit is contained in:
@@ -56,7 +56,7 @@ class EulerSampler(BaseSampler):
|
|||||||
logger.warning("current sampler is ODE sampler, but w_scheduler is enabled")
|
logger.warning("current sampler is ODE sampler, but w_scheduler is enabled")
|
||||||
|
|
||||||
# init recompute
|
# init recompute
|
||||||
self.num_recompute_timesteps = int(self.num_steps / self.state_refresh_rate)
|
|
||||||
self.recompute_timesteps = list(range(self.num_steps))
|
self.recompute_timesteps = list(range(self.num_steps))
|
||||||
|
|
||||||
def sharing_dp(self, net, noise, condition, uncondition):
|
def sharing_dp(self, net, noise, condition, uncondition):
|
||||||
@@ -143,6 +143,7 @@ class EulerSampler(BaseSampler):
|
|||||||
return x, pooled_state_list
|
return x, pooled_state_list
|
||||||
|
|
||||||
def __call__(self, net, noise, condition, uncondition):
|
def __call__(self, net, noise, condition, uncondition):
|
||||||
|
self.num_recompute_timesteps = int(self.num_steps / self.state_refresh_rate)
|
||||||
if len(self.recompute_timesteps) != self.num_recompute_timesteps:
|
if len(self.recompute_timesteps) != self.num_recompute_timesteps:
|
||||||
self.recompute_timesteps = self.sharing_dp(net, noise, condition, uncondition)
|
self.recompute_timesteps = self.sharing_dp(net, noise, condition, uncondition)
|
||||||
denoised, _ = self._impl_sampling(net, noise, condition, uncondition)
|
denoised, _ = self._impl_sampling(net, noise, condition, uncondition)
|
||||||
|
|||||||
Reference in New Issue
Block a user