torch.compile

This commit is contained in:
wangshuai6
2025-07-03 20:23:59 +08:00
parent 2c16b8f423
commit ae7df43ecc

View File

@@ -34,6 +34,7 @@ class LightningModel(pl.LightningModule):
ema_tracker: Optional[EMACallable] = None, ema_tracker: Optional[EMACallable] = None,
optimizer: OptimizerCallable = None, optimizer: OptimizerCallable = None,
lr_scheduler: LRSchedulerCallable = None, lr_scheduler: LRSchedulerCallable = None,
compile: bool = False
): ):
super().__init__() super().__init__()
self.vae = vae self.vae = vae
@@ -48,6 +49,7 @@ class LightningModel(pl.LightningModule):
# self.model_loader = ModelLoader() # self.model_loader = ModelLoader()
self._strict_loading = False self._strict_loading = False
self._compile = compile
def configure_model(self) -> None: def configure_model(self) -> None:
self.trainer.strategy.barrier() self.trainer.strategy.barrier()
@@ -61,6 +63,11 @@ class LightningModel(pl.LightningModule):
no_grad(self.diffusion_sampler) no_grad(self.diffusion_sampler)
no_grad(self.ema_denoiser) no_grad(self.ema_denoiser)
# add compile to speed up
if self._compile:
self.denoiser = torch.compile(self.denoiser)
self.ema_denoiser = torch.compile(self.ema_denoiser)
def configure_callbacks(self) -> Union[Sequence[Callback], Callback]: def configure_callbacks(self) -> Union[Sequence[Callback], Callback]:
ema_tracker = self.ema_tracker(self.denoiser, self.ema_denoiser) ema_tracker = self.ema_tracker(self.denoiser, self.ema_denoiser)
return [ema_tracker] return [ema_tracker]