torch.compile
This commit is contained in:
@@ -34,6 +34,7 @@ class LightningModel(pl.LightningModule):
|
|||||||
ema_tracker: Optional[EMACallable] = None,
|
ema_tracker: Optional[EMACallable] = None,
|
||||||
optimizer: OptimizerCallable = None,
|
optimizer: OptimizerCallable = None,
|
||||||
lr_scheduler: LRSchedulerCallable = None,
|
lr_scheduler: LRSchedulerCallable = None,
|
||||||
|
compile: bool = False
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.vae = vae
|
self.vae = vae
|
||||||
@@ -48,6 +49,7 @@ class LightningModel(pl.LightningModule):
|
|||||||
# self.model_loader = ModelLoader()
|
# self.model_loader = ModelLoader()
|
||||||
|
|
||||||
self._strict_loading = False
|
self._strict_loading = False
|
||||||
|
self._compile = compile
|
||||||
|
|
||||||
def configure_model(self) -> None:
|
def configure_model(self) -> None:
|
||||||
self.trainer.strategy.barrier()
|
self.trainer.strategy.barrier()
|
||||||
@@ -61,6 +63,11 @@ class LightningModel(pl.LightningModule):
|
|||||||
no_grad(self.diffusion_sampler)
|
no_grad(self.diffusion_sampler)
|
||||||
no_grad(self.ema_denoiser)
|
no_grad(self.ema_denoiser)
|
||||||
|
|
||||||
|
# add compile to speed up
|
||||||
|
if self._compile:
|
||||||
|
self.denoiser = torch.compile(self.denoiser)
|
||||||
|
self.ema_denoiser = torch.compile(self.ema_denoiser)
|
||||||
|
|
||||||
def configure_callbacks(self) -> Union[Sequence[Callback], Callback]:
|
def configure_callbacks(self) -> Union[Sequence[Callback], Callback]:
|
||||||
ema_tracker = self.ema_tracker(self.denoiser, self.ema_denoiser)
|
ema_tracker = self.ema_tracker(self.denoiser, self.ema_denoiser)
|
||||||
return [ema_tracker]
|
return [ema_tracker]
|
||||||
|
|||||||
Reference in New Issue
Block a user