submit code
0
src/callbacks/__init__.py
Normal file
22
src/callbacks/grad.py
Normal file
@@ -0,0 +1,22 @@
import torch
import lightning.pytorch as pl
from lightning.pytorch.utilities import grad_norm
from torch.optim import Optimizer

class GradientMonitor(pl.Callback):
    """Logs the gradient norm"""

    def __init__(self, norm_type: int = 2):
        norm_type = float(norm_type)
        if norm_type <= 0:
            raise ValueError(f"`norm_type` must be a positive number or 'inf' (infinity norm). Got {norm_type}")
        self.norm_type = norm_type

    def on_before_optimizer_step(
        self, trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        optimizer: Optimizer
    ) -> None:
        norms = grad_norm(pl_module, norm_type=self.norm_type)
        max_grad = torch.tensor([v for k, v in norms.items() if k != f"grad_{self.norm_type}_norm_total"]).max()
        pl_module.log_dict({'train/grad/max': max_grad, 'train/grad/total': norms[f"grad_{self.norm_type}_norm_total"]})
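A minimal wiring sketch for GradientMonitor; the Trainer settings are illustrative, and the fit call is left commented out because the module and datamodule are project-specific.

import lightning.pytorch as pl
from src.callbacks.grad import GradientMonitor

# On every optimizer step the callback logs the largest per-parameter gradient
# norm as 'train/grad/max' and the total norm as 'train/grad/total'.
trainer = pl.Trainer(
    max_steps=10_000,                           # illustrative value
    callbacks=[GradientMonitor(norm_type=2)],
)
# trainer.fit(model, datamodule=dm)  # `model` and `dm` come from the project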
24
src/callbacks/model_checkpoint.py
Normal file
@@ -0,0 +1,24 @@
import os.path
from typing import Dict, Any

import lightning.pytorch as pl
from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint


class CheckpointHook(ModelCheckpoint):
    """Save checkpoint with only the incremental part of the model"""
    def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
        self.dirpath = trainer.default_root_dir
        self.exception_ckpt_path = os.path.join(self.dirpath, "on_exception.pt")
        pl_module.strict_loading = False

    def on_save_checkpoint(
        self, trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        checkpoint: Dict[str, Any]
    ) -> None:
        del checkpoint["callbacks"]

    # def on_exception(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", exception: BaseException) -> None:
    #     if not "debug" in self.exception_ckpt_path:
    #         trainer.save_checkpoint(self.exception_ckpt_path)
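A usage sketch for CheckpointHook. It accepts the usual ModelCheckpoint keyword arguments (the values below are illustrative), and setup() redirects dirpath to trainer.default_root_dir, so no dirpath needs to be passed; the "outputs/run" path is a placeholder.

import lightning.pytorch as pl
from src.callbacks.model_checkpoint import CheckpointHook

# Checkpoints are written under default_root_dir; callback state is stripped
# from each checkpoint and strict_loading is disabled on the module.
ckpt_cb = CheckpointHook(save_top_k=-1, every_n_train_steps=1_000)
trainer = pl.Trainer(default_root_dir="outputs/run", callbacks=[ckpt_cb])
# trainer.fit(model, datamodule=dm)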
105
src/callbacks/save_images.py
Normal file
@@ -0,0 +1,105 @@
import lightning.pytorch as pl
from lightning.pytorch import Callback


import os.path
import numpy
from PIL import Image
from typing import Sequence, Any, Dict
from concurrent.futures import ThreadPoolExecutor

from lightning.pytorch.utilities.types import STEP_OUTPUT
from lightning_utilities.core.rank_zero import rank_zero_info


def process_fn(image, path):
    Image.fromarray(image).save(path)


class SaveImagesHook(Callback):
    def __init__(self, save_dir="val", max_save_num=0, compressed=True):
        self.save_dir = save_dir
        self.max_save_num = max_save_num
        self.compressed = compressed

    def save_start(self, target_dir):
        self.target_dir = target_dir
        self.executor_pool = ThreadPoolExecutor(max_workers=8)
        if not os.path.exists(self.target_dir):
            os.makedirs(self.target_dir, exist_ok=True)
        else:
            if os.listdir(target_dir) and "debug" not in str(target_dir):
                raise FileExistsError(f'{self.target_dir} already exists and is not empty!')
        self.samples = []
        self._have_saved_num = 0
        rank_zero_info(f"Save images to {self.target_dir}")

    def save_image(self, images, filenames):
        images = images.permute(0, 2, 3, 1).cpu().numpy()
        for sample, filename in zip(images, filenames):
            if isinstance(filename, Sequence):
                filename = filename[0]
            path = f'{self.target_dir}/{filename}'
            if self._have_saved_num >= self.max_save_num:
                break
            self.executor_pool.submit(process_fn, sample, path)
            self._have_saved_num += 1

    def process_batch(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        samples: STEP_OUTPUT,
        batch: Any,
    ) -> None:
        b, c, h, w = samples.shape
        xT, y, metadata = batch
        all_samples = pl_module.all_gather(samples).view(-1, c, h, w)
        self.save_image(samples, metadata)
        if trainer.is_global_zero:
            all_samples = all_samples.permute(0, 2, 3, 1).cpu().numpy()
            self.samples.append(all_samples)

    def save_end(self):
        if self.compressed and len(self.samples) > 0:
            samples = numpy.concatenate(self.samples)
            numpy.savez(f'{self.target_dir}/output.npz', arr_0=samples)
        self.executor_pool.shutdown(wait=True)
        self.samples = []
        self.target_dir = None
        self._have_saved_num = 0
        self.executor_pool = None

    def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        target_dir = os.path.join(trainer.default_root_dir, self.save_dir, f"iter_{trainer.global_step}")
        self.save_start(target_dir)

    def on_validation_batch_end(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        outputs: STEP_OUTPUT,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        return self.process_batch(trainer, pl_module, outputs, batch)

    def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        self.save_end()

    def on_predict_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        target_dir = os.path.join(trainer.default_root_dir, self.save_dir, "predict")
        self.save_start(target_dir)

    def on_predict_batch_end(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        samples: Any,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        return self.process_batch(trainer, pl_module, samples, batch)

    def on_predict_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        self.save_end()
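A usage sketch for SaveImagesHook. It assumes the validation/predict step returns an image tensor of shape (B, C, H, W) with a dtype PIL can write (e.g. uint8), and that each batch unpacks as (xT, y, metadata) with metadata providing output filenames; the paths and max_save_num value below are placeholders.

import lightning.pytorch as pl
from src.callbacks.save_images import SaveImagesHook

# Per-rank image files go to <default_root_dir>/val/iter_<global_step>/; on rank
# zero the gathered samples are also written to output.npz when compressed=True.
save_cb = SaveImagesHook(save_dir="val", max_save_num=64, compressed=True)
trainer = pl.Trainer(default_root_dir="outputs/run", callbacks=[save_cb])
# trainer.validate(model, datamodule=dm)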
78
src/callbacks/simple_ema.py
Normal file
@@ -0,0 +1,78 @@
from typing import Any, Dict

import torch
import torch.nn as nn
import lightning.pytorch as pl
from lightning.pytorch import Callback
from lightning.pytorch.utilities.types import STEP_OUTPUT

from src.utils.copy import swap_tensors

class SimpleEMA(Callback):
    def __init__(self, net: nn.Module, ema_net: nn.Module,
                 decay: float = 0.9999,
                 every_n_steps: int = 1,
                 eval_original_model: bool = False
                 ):
        super().__init__()
        self.decay = decay
        self.every_n_steps = every_n_steps
        self.eval_original_model = eval_original_model
        # Fall back to a synchronous update when no CUDA device is available.
        self._stream = torch.cuda.Stream() if torch.cuda.is_available() else None

        self.net_params = list(net.parameters())
        self.ema_params = list(ema_net.parameters())

    def swap_model(self):
        for ema_p, p in zip(self.ema_params, self.net_params):
            swap_tensors(ema_p, p)

    def ema_step(self):
        @torch.no_grad()
        def ema_update(ema_model_tuple, current_model_tuple, decay):
            torch._foreach_mul_(ema_model_tuple, decay)
            torch._foreach_add_(
                ema_model_tuple, current_model_tuple, alpha=(1.0 - decay),
            )

        if self._stream is not None:
            self._stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(self._stream):
            ema_update(self.ema_params, self.net_params, self.decay)


    def on_train_batch_end(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int
    ) -> None:
        if trainer.global_step % self.every_n_steps == 0:
            self.ema_step()

    def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        if not self.eval_original_model:
            self.swap_model()

    def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        if not self.eval_original_model:
            self.swap_model()

    def on_predict_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        if not self.eval_original_model:
            self.swap_model()

    def on_predict_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        if not self.eval_original_model:
            self.swap_model()


    def state_dict(self) -> Dict[str, Any]:
        return {
            "decay": self.decay,
            "every_n_steps": self.every_n_steps,
            "eval_original_model": self.eval_original_model,
        }

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.decay = state_dict["decay"]
        self.every_n_steps = state_dict["every_n_steps"]
        self.eval_original_model = state_dict["eval_original_model"]
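A sketch of how SimpleEMA could be attached. The nn.Linear network is a stand-in for illustration only; in practice net would be the LightningModule's trainable network and ema_net a frozen deep copy kept on the same module so it is checkpointed with it.

import copy
import torch.nn as nn
import lightning.pytorch as pl
from src.callbacks.simple_ema import SimpleEMA

net = nn.Linear(8, 8)                               # stand-in for the real model
ema_net = copy.deepcopy(net).requires_grad_(False)  # frozen EMA copy
# EMA weights follow ema = decay * ema + (1 - decay) * net every step, and are
# swapped in for validation/prediction unless eval_original_model=True.
ema_cb = SimpleEMA(net, ema_net, decay=0.9999, every_n_steps=1)
trainer = pl.Trainer(callbacks=[ema_cb])
# trainer.fit(model, datamodule=dm)  # `model` and `dm` come from the project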