diff --git a/src/lightning_model.py b/src/lightning_model.py index 4602e82..370d1a7 100644 --- a/src/lightning_model.py +++ b/src/lightning_model.py @@ -34,6 +34,7 @@ class LightningModel(pl.LightningModule): ema_tracker: Optional[EMACallable] = None, optimizer: OptimizerCallable = None, lr_scheduler: LRSchedulerCallable = None, + compile: bool = False ): super().__init__() self.vae = vae @@ -48,6 +49,7 @@ class LightningModel(pl.LightningModule): # self.model_loader = ModelLoader() self._strict_loading = False + self._compile = compile def configure_model(self) -> None: self.trainer.strategy.barrier() @@ -61,6 +63,11 @@ class LightningModel(pl.LightningModule): no_grad(self.diffusion_sampler) no_grad(self.ema_denoiser) + # add compile to speed up + if self._compile: + self.denoiser = torch.compile(self.denoiser) + self.ema_denoiser = torch.compile(self.ema_denoiser) + def configure_callbacks(self) -> Union[Sequence[Callback], Callback]: ema_tracker = self.ema_tracker(self.denoiser, self.ema_denoiser) return [ema_tracker]