fix bugs(admas timedeltas)
This commit is contained in:
@@ -93,7 +93,7 @@ class AdamLMSampler(BaseSampler):
|
||||
cfg_x = torch.cat([x, x], dim=0)
|
||||
cfg_t = t_cur.repeat(2)
|
||||
out = net(cfg_x, cfg_t, cfg_condition)
|
||||
out = self.guidance_fn(out, self.guidances[i])
|
||||
out = self.guidance_fn(out, self.guidance)
|
||||
pred_trajectory.append(out)
|
||||
out = torch.zeros_like(out)
|
||||
order = len(self.solver_coeffs[i])
|
||||
|
||||
@@ -98,7 +98,7 @@ class AdamLMSampler(BaseSampler):
|
||||
if i % self.state_refresh_rate == 0:
|
||||
state = None
|
||||
out, state = net(cfg_x, cfg_t, cfg_condition, state)
|
||||
out = self.guidance_fn(out, self.guidances[i])
|
||||
out = self.guidance_fn(out, self.guidance)
|
||||
pred_trajectory.append(out)
|
||||
out = torch.zeros_like(out)
|
||||
order = len(self.solver_coeffs[i])
|
||||
|
||||
Reference in New Issue
Block a user