fix bugs(admas timedeltas)
This commit is contained in:
@@ -93,7 +93,7 @@ class AdamLMSampler(BaseSampler):
|
|||||||
cfg_x = torch.cat([x, x], dim=0)
|
cfg_x = torch.cat([x, x], dim=0)
|
||||||
cfg_t = t_cur.repeat(2)
|
cfg_t = t_cur.repeat(2)
|
||||||
out = net(cfg_x, cfg_t, cfg_condition)
|
out = net(cfg_x, cfg_t, cfg_condition)
|
||||||
out = self.guidance_fn(out, self.guidances[i])
|
out = self.guidance_fn(out, self.guidance)
|
||||||
pred_trajectory.append(out)
|
pred_trajectory.append(out)
|
||||||
out = torch.zeros_like(out)
|
out = torch.zeros_like(out)
|
||||||
order = len(self.solver_coeffs[i])
|
order = len(self.solver_coeffs[i])
|
||||||
|
|||||||
@@ -98,7 +98,7 @@ class AdamLMSampler(BaseSampler):
|
|||||||
if i % self.state_refresh_rate == 0:
|
if i % self.state_refresh_rate == 0:
|
||||||
state = None
|
state = None
|
||||||
out, state = net(cfg_x, cfg_t, cfg_condition, state)
|
out, state = net(cfg_x, cfg_t, cfg_condition, state)
|
||||||
out = self.guidance_fn(out, self.guidances[i])
|
out = self.guidance_fn(out, self.guidance)
|
||||||
pred_trajectory.append(out)
|
pred_trajectory.append(out)
|
||||||
out = torch.zeros_like(out)
|
out = torch.zeros_like(out)
|
||||||
order = len(self.solver_coeffs[i])
|
order = len(self.solver_coeffs[i])
|
||||||
|
|||||||
Reference in New Issue
Block a user