feat(callbacks): add Weights & Biases image logging to SaveImagesHook

- Introduce optional wandb_max_save_num parameter in __init__
- Add helper methods _get_wandb_logger and _log_wandb_samples
- Log sampled images as wandb.Image artifacts during validation epoch end (global zero only)
- Handle proper dtype conversion and clipping for wandb compatibility
- Respect both general max_save_num and wandb-specific limit

This enables visualization of validation samples directly in the W&B dashboard without manual intervention.
This commit is contained in:
4
.gitignore
vendored
Normal file
4
.gitignore
vendored
Normal file
@@ -0,0 +1,4 @@
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
.venv/
|
||||
.idea/
|
||||
@@ -1,5 +1,6 @@
|
||||
import lightning.pytorch as pl
|
||||
from lightning.pytorch import Callback
|
||||
from lightning.pytorch.loggers import WandbLogger
|
||||
|
||||
|
||||
import os.path
|
||||
@@ -10,15 +11,20 @@ from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
from lightning.pytorch.utilities.types import STEP_OUTPUT
|
||||
from lightning_utilities.core.rank_zero import rank_zero_info
|
||||
# wandb is an optional dependency: if it is not installed, bind the name to
# None so later code can feature-detect with `wandb is None` instead of
# re-attempting the import at every use site.
try:
    import wandb
except ImportError:  # pragma: no cover - optional dependency
    wandb = None
|
||||
|
||||
def process_fn(image, path):
    """Persist a single image array to *path* via PIL."""
    pil_image = Image.fromarray(image)
    pil_image.save(path)
|
||||
|
||||
class SaveImagesHook(Callback):
|
||||
def __init__(self, save_dir="val", max_save_num=100, compressed=True, wandb_max_save_num=None):
    """Store configuration for saving validation images.

    Args:
        save_dir: Subdirectory name (under the trainer root) where sampled
            images are written.
        max_save_num: Upper bound on the number of images saved.
        compressed: Flag stored for later use (presumably selects a
            compressed on-disk format — usage not visible in this hunk).
        wandb_max_save_num: Optional separate cap for images logged to
            Weights & Biases; when None, ``max_save_num`` is used instead.
    """
    # Plain attribute storage; all behavior lives in the hook callbacks.
    self.wandb_max_save_num = wandb_max_save_num
    self.compressed = compressed
    self.max_save_num = max_save_num
    self.save_dir = save_dir
|
||||
|
||||
def save_start(self, target_dir):
|
||||
self.target_dir = target_dir
|
||||
@@ -68,6 +74,50 @@ class SaveImagesHook(Callback):
|
||||
self._have_saved_num = 0
|
||||
self.executor_pool = None
|
||||
|
||||
def _get_wandb_logger(self, trainer: "pl.Trainer"):
    """Return the WandbLogger attached to *trainer*, or None if absent.

    Checks the primary ``trainer.logger`` first, then scans the
    ``trainer.loggers`` collection (when present) for the first match.
    """
    primary = trainer.logger
    if isinstance(primary, WandbLogger):
        return primary
    # Multi-logger setups expose the full list via `trainer.loggers`.
    for candidate in getattr(trainer, "loggers", None) or ():
        if isinstance(candidate, WandbLogger):
            return candidate
    return None
|
||||
|
||||
def _log_wandb_samples(self, trainer: "pl.Trainer"):
    """Log buffered sample images to Weights & Biases, bounded by config.

    Runs only on the global-zero rank, only when the optional wandb
    dependency is importable, only when samples were collected, and only
    when a WandbLogger is attached to the trainer.
    """
    # Guard clauses, in the same short-circuit order as the checks below.
    if not trainer.is_global_zero or wandb is None:
        return
    if not self.samples or self.max_save_num == 0:
        return
    wandb_logger = self._get_wandb_logger(trainer)
    if wandb_logger is None:
        return

    # W&B-specific cap falls back to the general on-disk cap when unset.
    limit = self.max_save_num if self.wandb_max_save_num is None else self.wandb_max_save_num
    if limit <= 0:
        return

    collected = []
    for batch in self.samples:
        budget = limit - len(collected)
        if budget <= 0:
            break
        # Slicing clamps automatically, so min() just documents the intent.
        chunk = batch[:min(batch.shape[0], budget)]
        # wandb.Image renders uint8 most predictably; clip other dtypes
        # into the displayable 0-255 range before conversion.
        if chunk.dtype != numpy.uint8:
            chunk = numpy.clip(chunk, 0, 255).astype(numpy.uint8)
        collected.extend(wandb.Image(img) for img in chunk)

    if collected:
        wandb_logger.experiment.log(
            {f"{self.save_dir}/samples": collected},
            step=trainer.global_step,
        )
|
||||
|
||||
def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
    """Open a fresh save directory for this validation pass."""
    # One subdirectory per validation run, keyed by the current global step.
    step_dir = f"iter_{trainer.global_step}"
    target = os.path.join(trainer.default_root_dir, self.save_dir, step_dir)
    self.save_start(target)
|
||||
@@ -84,6 +134,7 @@ class SaveImagesHook(Callback):
|
||||
return self.process_batch(trainer, pl_module, outputs, batch)
|
||||
|
||||
def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
    """Push collected samples to W&B, then run the normal end-of-epoch cleanup.

    W&B logging happens before save_end() so the buffered samples are
    still available — this ordering matters.
    """
    self._log_wandb_samples(trainer)
    self.save_end()
|
||||
|
||||
def on_predict_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
|
||||
|
||||
Reference in New Issue
Block a user