fix bugs(admas timedeltas)
This commit is contained in:
@@ -59,7 +59,7 @@ class AdamLMSampler(BaseSampler):
|
|||||||
timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps)
|
timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps)
|
||||||
timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0)
|
timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0)
|
||||||
self.timesteps = shift_respace_fn(timesteps, timeshift)
|
self.timesteps = shift_respace_fn(timesteps, timeshift)
|
||||||
self.timedeltas = timesteps[1:] - self.timesteps[:-1]
|
self.timedeltas = self.timesteps[1:] - self.timesteps[:-1]
|
||||||
self._reparameterize_coeffs()
|
self._reparameterize_coeffs()
|
||||||
|
|
||||||
def _reparameterize_coeffs(self):
|
def _reparameterize_coeffs(self):
|
||||||
|
|||||||
@@ -61,7 +61,7 @@ class AdamLMSampler(BaseSampler):
|
|||||||
timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps)
|
timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps)
|
||||||
timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0)
|
timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0)
|
||||||
self.timesteps = shift_respace_fn(timesteps, timeshift)
|
self.timesteps = shift_respace_fn(timesteps, timeshift)
|
||||||
self.timedeltas = timesteps[1:] - self.timesteps[:-1]
|
self.timedeltas = self.timesteps[1:] - self.timesteps[:-1]
|
||||||
self._reparameterize_coeffs()
|
self._reparameterize_coeffs()
|
||||||
|
|
||||||
def _reparameterize_coeffs(self):
|
def _reparameterize_coeffs(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user