torch.compile
@@ -34,6 +34,7 @@ class LightningModel(pl.LightningModule):
         ema_tracker: Optional[EMACallable] = None,
         optimizer: OptimizerCallable = None,
         lr_scheduler: LRSchedulerCallable = None,
+        compile: bool = False
     ):
         super().__init__()
         self.vae = vae
@@ -48,6 +49,7 @@ class LightningModel(pl.LightningModule):
         # self.model_loader = ModelLoader()

         self._strict_loading = False
+        self._compile = compile

     def configure_model(self) -> None:
         self.trainer.strategy.barrier()
@@ -61,6 +63,11 @@ class LightningModel(pl.LightningModule):
         no_grad(self.diffusion_sampler)
         no_grad(self.ema_denoiser)

+        # optionally compile the denoisers to speed up training
+        if self._compile:
+            self.denoiser = torch.compile(self.denoiser)
+            self.ema_denoiser = torch.compile(self.ema_denoiser)
+
     def configure_callbacks(self) -> Union[Sequence[Callback], Callback]:
         ema_tracker = self.ema_tracker(self.denoiser, self.ema_denoiser)
         return [ema_tracker]
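The pattern this commit introduces is a constructor flag that is stored in __init__ and applied in configure_model(), after the model has been built on each rank. Below is a minimal, self-contained sketch of that pattern; ToyDenoiser and CompiledLightningModel are hypothetical stand-ins, and the real model's VAE, diffusion sampler, and EMA-callback wiring are omitted.

# Minimal sketch: gate torch.compile behind a constructor flag and apply
# it in configure_model(), so the compiled wrapper is created on every
# rank once the model is set up. ToyDenoiser is a stand-in module.
import torch
import torch.nn as nn
import lightning.pytorch as pl


class ToyDenoiser(nn.Module):
    def __init__(self, dim: int = 64):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim, dim), nn.SiLU(), nn.Linear(dim, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


class CompiledLightningModel(pl.LightningModule):
    def __init__(self, compile: bool = False):
        super().__init__()
        self.denoiser = ToyDenoiser()
        self._compile = compile

    def configure_model(self) -> None:
        # torch.compile returns a wrapper around the module; the underlying
        # parameters are shared, so optimizers and EMA trackers built from
        # self.denoiser's parameters keep working.
        if self._compile:
            self.denoiser = torch.compile(self.denoiser)

    def training_step(self, batch, batch_idx):
        # Toy objective: reconstruct the input; expect slower first steps
        # while torch.compile traces and compiles kernels.
        loss = nn.functional.mse_loss(self.denoiser(batch), batch)
        return loss

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=1e-4)


# Usage sketch (dataloader omitted):
# model = CompiledLightningModel(compile=True)
# trainer = pl.Trainer(max_steps=100)
# trainer.fit(model, train_dataloaders=...)

One consequence worth noting: reassigning self.denoiser = torch.compile(self.denoiser) wraps the module in an OptimizedModule, so its checkpoint state-dict keys gain an _orig_mod. prefix; the pre-existing self._strict_loading = False in the diff above would also tolerate that key mismatch when restoring.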