diff --git a/src/diffusion/flow_matching/adam_sampling.py b/src/diffusion/flow_matching/adam_sampling.py index 1d6d119..f5824ae 100644 --- a/src/diffusion/flow_matching/adam_sampling.py +++ b/src/diffusion/flow_matching/adam_sampling.py @@ -93,7 +93,7 @@ class AdamLMSampler(BaseSampler): cfg_x = torch.cat([x, x], dim=0) cfg_t = t_cur.repeat(2) out = net(cfg_x, cfg_t, cfg_condition) - out = self.guidance_fn(out, self.guidances[i]) + out = self.guidance_fn(out, self.guidance) pred_trajectory.append(out) out = torch.zeros_like(out) order = len(self.solver_coeffs[i]) diff --git a/src/diffusion/stateful_flow_matching/adam_sampling.py b/src/diffusion/stateful_flow_matching/adam_sampling.py index 013504b..f3bb80c 100644 --- a/src/diffusion/stateful_flow_matching/adam_sampling.py +++ b/src/diffusion/stateful_flow_matching/adam_sampling.py @@ -98,7 +98,7 @@ class AdamLMSampler(BaseSampler): if i % self.state_refresh_rate == 0: state = None out, state = net(cfg_x, cfg_t, cfg_condition, state) - out = self.guidance_fn(out, self.guidances[i]) + out = self.guidance_fn(out, self.guidance) pred_trajectory.append(out) out = torch.zeros_like(out) order = len(self.solver_coeffs[i])