fix bugs(admas timedeltas)

This commit is contained in:
wangshuai6
2025-05-20 12:27:34 +08:00
parent 699d89a772
commit 3693640ca3
2 changed files with 2 additions and 2 deletions

View File

@@ -93,7 +93,7 @@ class AdamLMSampler(BaseSampler):
cfg_x = torch.cat([x, x], dim=0) cfg_x = torch.cat([x, x], dim=0)
cfg_t = t_cur.repeat(2) cfg_t = t_cur.repeat(2)
out = net(cfg_x, cfg_t, cfg_condition) out = net(cfg_x, cfg_t, cfg_condition)
out = self.guidance_fn(out, self.guidances[i]) out = self.guidance_fn(out, self.guidance)
pred_trajectory.append(out) pred_trajectory.append(out)
out = torch.zeros_like(out) out = torch.zeros_like(out)
order = len(self.solver_coeffs[i]) order = len(self.solver_coeffs[i])

View File

@@ -98,7 +98,7 @@ class AdamLMSampler(BaseSampler):
if i % self.state_refresh_rate == 0: if i % self.state_refresh_rate == 0:
state = None state = None
out, state = net(cfg_x, cfg_t, cfg_condition, state) out, state = net(cfg_x, cfg_t, cfg_condition, state)
out = self.guidance_fn(out, self.guidances[i]) out = self.guidance_fn(out, self.guidance)
pred_trajectory.append(out) pred_trajectory.append(out)
out = torch.zeros_like(out) out = torch.zeros_like(out)
order = len(self.solver_coeffs[i]) order = len(self.solver_coeffs[i])