submit code
123
src/lightning_model.py
Normal file
@@ -0,0 +1,123 @@
from typing import Callable, Iterable, Any, Optional, Union, Sequence, Mapping, Dict
import os.path
import copy

import torch
import torch.nn as nn
import lightning.pytorch as pl
from lightning.pytorch.utilities.types import OptimizerLRScheduler, STEP_OUTPUT
from torch.optim.lr_scheduler import LRScheduler
from torch.optim import Optimizer
from lightning.pytorch.callbacks import Callback

from src.models.vae import BaseVAE, fp2uint8
from src.models.conditioner import BaseConditioner
from src.utils.model_loader import ModelLoader
from src.callbacks.simple_ema import SimpleEMA
from src.diffusion.base.sampling import BaseSampler
from src.diffusion.base.training import BaseTrainer
from src.utils.no_grad import no_grad, filter_nograd_tensors
from src.utils.copy import copy_params

# Factory callables: each builds its object from the arguments injected at runtime
# (the EMA tracker from the live/EMA modules, the optimizer from parameters,
# the scheduler from the optimizer).
EMACallable = Callable[[nn.Module, nn.Module], SimpleEMA]
OptimizerCallable = Callable[[Iterable], Optimizer]
LRSchedulerCallable = Callable[[Optimizer], LRScheduler]

class LightningModel(pl.LightningModule):
    def __init__(self,
                 vae: BaseVAE,
                 conditioner: BaseConditioner,
                 denoiser: nn.Module,
                 diffusion_trainer: BaseTrainer,
                 diffusion_sampler: BaseSampler,
                 ema_tracker: Optional[EMACallable] = None,
                 optimizer: Optional[OptimizerCallable] = None,
                 lr_scheduler: Optional[LRSchedulerCallable] = None,
                 ):
        super().__init__()
        self.vae = vae
        self.conditioner = conditioner
        self.denoiser = denoiser
        # Keep an EMA copy of the denoiser; its weights are synced in configure_model
        # and then updated by the SimpleEMA callback during training.
        self.ema_denoiser = copy.deepcopy(self.denoiser)
        self.diffusion_sampler = diffusion_sampler
        self.diffusion_trainer = diffusion_trainer
        self.ema_tracker = ema_tracker
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        # self.model_loader = ModelLoader()

        # Checkpoints only contain the denoiser/EMA/trainer weights (see state_dict below),
        # so loading must not be strict.
        self._strict_loading = False

    def configure_model(self) -> None:
        self.trainer.strategy.barrier()
        # self.denoiser = self.model_loader.load(self.denoiser)
        copy_params(src_model=self.denoiser, dst_model=self.ema_denoiser)

        # self.denoiser = torch.compile(self.denoiser)
        # disable grad for the frozen parts: conditioner, vae, sampler and the EMA copy
        no_grad(self.conditioner)
        no_grad(self.vae)
        no_grad(self.diffusion_sampler)
        no_grad(self.ema_denoiser)

    def configure_callbacks(self) -> Union[Sequence[Callback], Callback]:
        if self.ema_tracker is None:  # EMA tracking is optional
            return []
        ema_tracker = self.ema_tracker(self.denoiser, self.ema_denoiser)
        return [ema_tracker]

    def configure_optimizers(self) -> OptimizerLRScheduler:
        # Only tensors that still require grad are optimized; frozen modules are filtered out.
        params_denoiser = filter_nograd_tensors(self.denoiser.parameters())
        params_trainer = filter_nograd_tensors(self.diffusion_trainer.parameters())
        optimizer: torch.optim.Optimizer = self.optimizer([*params_trainer, *params_denoiser])
        if self.lr_scheduler is None:
            return dict(
                optimizer=optimizer
            )
        else:
            lr_scheduler = self.lr_scheduler(optimizer)
            return dict(
                optimizer=optimizer,
                lr_scheduler=lr_scheduler
            )

    def training_step(self, batch, batch_idx):
        raw_images, x, y = batch
        # VAE encoding and conditioning are frozen, so they run without grad;
        # only the diffusion trainer / denoiser produce gradients.
        with torch.no_grad():
            x = self.vae.encode(x)
            condition, uncondition = self.conditioner(y)
        loss = self.diffusion_trainer(self.denoiser, self.ema_denoiser, raw_images, x, condition, uncondition)
        self.log_dict(loss, prog_bar=True, on_step=True, sync_dist=False)
        return loss["loss"]

    def predict_step(self, batch, batch_idx):
        xT, y, metadata = batch
        with torch.no_grad():
            condition, uncondition = self.conditioner(y)
            # Sample images:
            samples = self.diffusion_sampler(self.denoiser, xT, condition, uncondition)
            samples = self.vae.decode(samples)
            # fp32 -1,1 -> uint8 0,255
            samples = fp2uint8(samples)
        return samples

    def validation_step(self, batch, batch_idx):
        samples = self.predict_step(batch, batch_idx)
        return samples

    def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
        # Checkpoint only the trainable and EMA parts; the frozen vae and conditioner
        # are deliberately left out of the state dict.
        if destination is None:
            destination = {}
        self._save_to_state_dict(destination, prefix, keep_vars)
        self.denoiser.state_dict(
            destination=destination,
            prefix=prefix + "denoiser.",
            keep_vars=keep_vars)
        self.ema_denoiser.state_dict(
            destination=destination,
            prefix=prefix + "ema_denoiser.",
            keep_vars=keep_vars)
        self.diffusion_trainer.state_dict(
            destination=destination,
            prefix=prefix + "diffusion_trainer.",
            keep_vars=keep_vars)
        return destination
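
For context, a minimal wiring sketch (not part of this commit): it shows how the factory-callable pattern above is meant to be used. The build_model helper, the AdamW/CosineAnnealingLR choices, and the commented Trainer call are illustrative assumptions; only LightningModel's constructor signature and the EMACallable/OptimizerCallable/LRSchedulerCallable aliases come from the file itself.

# Hypothetical usage sketch; concrete src.* component instances are left to the caller.
from functools import partial

import torch
import lightning.pytorch as pl

from src.lightning_model import LightningModel
from src.callbacks.simple_ema import SimpleEMA


def build_model(vae, conditioner, denoiser, diffusion_trainer, diffusion_sampler):
    # optimizer / lr_scheduler / ema_tracker are passed as factories, matching the
    # OptimizerCallable / LRSchedulerCallable / EMACallable aliases defined above.
    return LightningModel(
        vae=vae,
        conditioner=conditioner,
        denoiser=denoiser,
        diffusion_trainer=diffusion_trainer,
        diffusion_sampler=diffusion_sampler,
        ema_tracker=SimpleEMA,  # invoked as SimpleEMA(denoiser, ema_denoiser); extra kwargs could be bound with partial
        optimizer=partial(torch.optim.AdamW, lr=1e-4),
        lr_scheduler=partial(torch.optim.lr_scheduler.CosineAnnealingLR, T_max=1_000),
    )


# model = build_model(...)                              # fill in concrete src.* implementations
# pl.Trainer(max_steps=1_000).fit(model, datamodule=...)  # training; predict/validation use the sampler path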