From b9c3720eba257ffb9fc036cd2e0475f80ead30e3 Mon Sep 17 00:00:00 2001
From: gameloader
Date: Mon, 19 Jan 2026 16:03:54 +0800
Subject: [PATCH] feat(callbacks): add Weights & Biases image logging to SaveImagesHook

- Introduce optional wandb_max_save_num parameter in __init__
- Add helper methods _get_wandb_logger and _log_wandb_samples
- Log sampled images as wandb.Image artifacts during validation epoch end (global zero only)
- Handle proper dtype conversion and clipping for wandb compatibility
- Respect both general max_save_num and wandb-specific limit

This enables visualization of validation samples directly in the W&B dashboard without manual intervention.
---
 .gitignore                   |  4 +++
 src/callbacks/save_images.py | 55 ++++++++++++++++++++++++++++++++++--
 2 files changed, 57 insertions(+), 2 deletions(-)
 create mode 100644 .gitignore

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..a1e299f
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,4 @@
+__pycache__/
+*.py[cod]
+.venv/
+.idea/
diff --git a/src/callbacks/save_images.py b/src/callbacks/save_images.py
index 85e9c0b..9c7520c 100644
--- a/src/callbacks/save_images.py
+++ b/src/callbacks/save_images.py
@@ -1,5 +1,6 @@
 import lightning.pytorch as pl
 from lightning.pytorch import Callback
+from lightning.pytorch.loggers import WandbLogger
 
 import os.path
 
@@ -10,15 +11,20 @@
 from concurrent.futures import ThreadPoolExecutor
 from lightning.pytorch.utilities.types import STEP_OUTPUT
 from lightning_utilities.core.rank_zero import rank_zero_info
+try:
+    import wandb
+except ImportError: # pragma: no cover - optional dependency
+    wandb = None
 
 def process_fn(image, path):
     Image.fromarray(image).save(path)
 
 class SaveImagesHook(Callback):
-    def __init__(self, save_dir="val", max_save_num=100, compressed=True):
+    def __init__(self, save_dir="val", max_save_num=100, compressed=True, wandb_max_save_num=None):
         self.save_dir = save_dir
         self.max_save_num = max_save_num
         self.compressed = compressed
+        self.wandb_max_save_num = wandb_max_save_num
 
     def save_start(self, target_dir):
         self.target_dir = target_dir
@@ -68,6 +74,50 @@ class SaveImagesHook(Callback):
         self._have_saved_num = 0
         self.executor_pool = None
 
+    def _get_wandb_logger(self, trainer: "pl.Trainer"):
+        if isinstance(trainer.logger, WandbLogger):
+            return trainer.logger
+        if getattr(trainer, "loggers", None):
+            for logger in trainer.loggers:
+                if isinstance(logger, WandbLogger):
+                    return logger
+        return None
+
+    def _log_wandb_samples(self, trainer: "pl.Trainer"):
+        if not trainer.is_global_zero:
+            return
+        if wandb is None:
+            return
+        if not self.samples:
+            return
+        if self.max_save_num == 0:
+            return
+        wandb_logger = self._get_wandb_logger(trainer)
+        if wandb_logger is None:
+            return
+
+        max_num = self.wandb_max_save_num if self.wandb_max_save_num is not None else self.max_save_num
+        if max_num <= 0:
+            return
+
+        images = []
+        remaining = max_num
+        for batch in self.samples:
+            if remaining <= 0:
+                break
+            take = min(batch.shape[0], remaining)
+            chunk = batch[:take]
+            if chunk.dtype != numpy.uint8:
+                chunk = numpy.clip(chunk, 0, 255).astype(numpy.uint8)
+            images.extend([wandb.Image(img) for img in chunk])
+            remaining -= take
+
+        if images:
+            wandb_logger.experiment.log(
+                {f"{self.save_dir}/samples": images},
+                step=trainer.global_step,
+            )
+
     def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         target_dir = os.path.join(trainer.default_root_dir, self.save_dir, f"iter_{trainer.global_step}")
         self.save_start(target_dir)
@@ -84,6 +134,7 @@ class SaveImagesHook(Callback):
         return self.process_batch(trainer, pl_module, outputs, batch)
 
     def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        self._log_wandb_samples(trainer)
         self.save_end()
 
     def on_predict_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
@@ -102,4 +153,4 @@ class SaveImagesHook(Callback):
         return self.process_batch(trainer, pl_module, samples, batch)
 
     def on_predict_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
-        self.save_end()
\ No newline at end of file
+        self.save_end()
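
Usage note for reviewers (not part of the patch): a minimal sketch of how the new option is expected to be wired up. It assumes the callback is importable as src.callbacks.save_images.SaveImagesHook; the W&B project name and the MyModel/MyDataModule placeholders are illustrative, not from this repository.

    # Minimal usage sketch; import path, project name, and the model/datamodule
    # names below are assumptions for illustration only.
    import lightning.pytorch as pl
    from lightning.pytorch.loggers import WandbLogger

    from src.callbacks.save_images import SaveImagesHook

    wandb_logger = WandbLogger(project="my-project")  # assumed project name

    save_images = SaveImagesHook(
        save_dir="val",         # images go under <default_root_dir>/val/iter_<global_step>
        max_save_num=100,       # cap on images written to disk per validation epoch
        wandb_max_save_num=16,  # tighter cap on images pushed to the W&B dashboard
    )

    trainer = pl.Trainer(logger=wandb_logger, callbacks=[save_images])
    # trainer.fit(MyModel(), datamodule=MyDataModule())  # placeholder model/datamodule

Because _log_wandb_samples returns early when wandb is not installed or no WandbLogger is configured, the callback degrades to the existing disk-only behaviour in setups without W&B.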