From 99d92c94e7e7ef27994507e4fb623a930ec28c53 Mon Sep 17 00:00:00 2001
From: wangshuai6
Date: Tue, 20 May 2025 12:30:40 +0800
Subject: [PATCH] fix bugs (adams timedeltas)

Apply classifier-free guidance only inside a configurable time interval:
AdamLMSampler gains guidance_interval_min/guidance_interval_max and falls
back to a guidance weight of 1.0 outside that interval. The w_scheduler
argument is removed; step_fn is now required to be ode_step_fn.
---
 src/diffusion/flow_matching/adam_sampling.py | 14 ++++++++++----
 .../stateful_flow_matching/adam_sampling.py  | 14 ++++++++++----
 2 files changed, 20 insertions(+), 8 deletions(-)

diff --git a/src/diffusion/flow_matching/adam_sampling.py b/src/diffusion/flow_matching/adam_sampling.py
index f5824ae..be3c6c3 100644
--- a/src/diffusion/flow_matching/adam_sampling.py
+++ b/src/diffusion/flow_matching/adam_sampling.py
@@ -41,23 +41,25 @@ class AdamLMSampler(BaseSampler):
         self,
         order: int = 2,
         timeshift: float = 1.0,
+        guidance_interval_min: float = 0.0,
+        guidance_interval_max: float = 1.0,
         lms_transform_fn: Callable = nop,
-        w_scheduler: BaseScheduler = None,
         step_fn: Callable = ode_step_fn,
         *args,
         **kwargs
     ):
         super().__init__(*args, **kwargs)
         self.step_fn = step_fn
-        self.w_scheduler = w_scheduler
 
         assert self.scheduler is not None
-        assert self.w_scheduler is not None or self.step_fn in [ode_step_fn, ]
+        assert self.step_fn in [ode_step_fn, ]
 
         self.order = order
         self.lms_transform_fn = lms_transform_fn
         timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps)
         timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0)
+        self.guidance_interval_min = guidance_interval_min
+        self.guidance_interval_max = guidance_interval_max
         self.timesteps = shift_respace_fn(timesteps, timeshift)
         self.timedeltas = self.timesteps[1:] - self.timesteps[:-1]
         self._reparameterize_coeffs()
@@ -93,7 +95,11 @@
             cfg_x = torch.cat([x, x], dim=0)
             cfg_t = t_cur.repeat(2)
             out = net(cfg_x, cfg_t, cfg_condition)
-            out = self.guidance_fn(out, self.guidance)
+            if t_cur[0] > self.guidance_interval_min and t_cur[0] < self.guidance_interval_max:
+                guidance = self.guidance
+                out = self.guidance_fn(out, guidance)
+            else:
+                out = self.guidance_fn(out, 1.0)
             pred_trajectory.append(out)
             out = torch.zeros_like(out)
             order = len(self.solver_coeffs[i])
diff --git a/src/diffusion/stateful_flow_matching/adam_sampling.py b/src/diffusion/stateful_flow_matching/adam_sampling.py
index f3bb80c..6497adf 100644
--- a/src/diffusion/stateful_flow_matching/adam_sampling.py
+++ b/src/diffusion/stateful_flow_matching/adam_sampling.py
@@ -41,22 +41,24 @@ class AdamLMSampler(BaseSampler):
         self,
         order: int = 2,
         timeshift: float = 1.0,
+        guidance_interval_min: float = 0.0,
+        guidance_interval_max: float = 1.0,
         state_refresh_rate: int = 1,
         lms_transform_fn: Callable = nop,
-        w_scheduler: BaseScheduler = None,
         step_fn: Callable = ode_step_fn,
         *args,
         **kwargs
     ):
         super().__init__(*args, **kwargs)
         self.step_fn = step_fn
-        self.w_scheduler = w_scheduler
         self.state_refresh_rate = state_refresh_rate
 
         assert self.scheduler is not None
-        assert self.w_scheduler is not None or self.step_fn in [ode_step_fn, ]
+        assert self.step_fn in [ode_step_fn, ]
 
         self.order = order
         self.lms_transform_fn = lms_transform_fn
+        self.guidance_interval_min = guidance_interval_min
+        self.guidance_interval_max = guidance_interval_max
         timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps)
         timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0)
@@ -98,7 +100,11 @@
             if i % self.state_refresh_rate == 0:
                 state = None
             out, state = net(cfg_x, cfg_t, cfg_condition, state)
-            out = self.guidance_fn(out, self.guidance)
+            if t_cur[0] > self.guidance_interval_min and t_cur[0] < self.guidance_interval_max:
+                guidance = self.guidance
+                out = self.guidance_fn(out, guidance)
+            else:
+                out = self.guidance_fn(out, 1.0)
             pred_trajectory.append(out)
             out = torch.zeros_like(out)
             order = len(self.solver_coeffs[i])
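
Note on the guidance gating above: classifier-free guidance is applied at
full strength only while the current time lies strictly inside
(guidance_interval_min, guidance_interval_max); outside the interval the
weight falls back to 1.0. The sketch below is a minimal, self-contained
illustration of that logic, not code from this repository: toy_guidance_fn
and interval_gated_guidance are assumed names, and it presumes t runs from
0 to 1 and that `out` stacks the two CFG branches along dim 0, as the
cfg_x = torch.cat([x, x], dim=0) batching above suggests.

    import torch

    def toy_guidance_fn(out: torch.Tensor, guidance: float) -> torch.Tensor:
        # Standard CFG combine, assuming the conditional prediction sits in
        # the first half of the batch; with guidance == 1.0 this reduces to
        # the conditional prediction alone.
        cond, uncond = out.chunk(2, dim=0)
        return uncond + guidance * (cond - uncond)

    def interval_gated_guidance(out, t_cur, guidance,
                                interval_min=0.0, interval_max=1.0,
                                guidance_fn=toy_guidance_fn):
        # Every batch entry shares one timestep per solver step, so checking
        # t_cur[0] suffices; the chained comparison is the idiomatic form of
        # the patch's two-sided `and` condition.
        if interval_min < float(t_cur[0]) < interval_max:
            return guidance_fn(out, guidance)
        return guidance_fn(out, 1.0)

    out = torch.randn(2, 4)  # [cond; uncond] stacked along dim 0
    # Guidance is active mid-trajectory...
    mid = interval_gated_guidance(out, torch.tensor([0.5]), guidance=4.0,
                                  interval_min=0.1, interval_max=0.9)
    # ...but falls back to weight 1.0 near the endpoints.
    end = interval_gated_guidance(out, torch.tensor([0.95]), guidance=4.0,
                                  interval_min=0.1, interval_max=0.9)

One design note: because the fallback branch still calls guidance_fn with a
weight of 1.0 rather than skipping the combine, both paths return the same
shape (half the doubled batch), so downstream solver code is unaffected.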