feat(callbacks): add Weights & Biases image logging to SaveImagesHook

- Introduce optional wandb_max_save_num parameter in __init__
- Add helper methods _get_wandb_logger and _log_wandb_samples
- Log sampled images as wandb.Image artifacts during validation epoch end (global zero only)
- Handle proper dtype conversion and clipping for wandb compatibility
- Respect both general max_save_num and wandb-specific limit

This enables visualization of validation samples directly in the W&B dashboard without manual intervention.
This commit is contained in:
4
.gitignore
vendored
Normal file
4
.gitignore
vendored
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
.venv/
|
||||||
|
.idea/
|
||||||
@@ -1,5 +1,6 @@
|
|||||||
import lightning.pytorch as pl
|
import lightning.pytorch as pl
|
||||||
from lightning.pytorch import Callback
|
from lightning.pytorch import Callback
|
||||||
|
from lightning.pytorch.loggers import WandbLogger
|
||||||
|
|
||||||
|
|
||||||
import os.path
|
import os.path
|
||||||
@@ -10,15 +11,20 @@ from concurrent.futures import ThreadPoolExecutor
|
|||||||
|
|
||||||
from lightning.pytorch.utilities.types import STEP_OUTPUT
|
from lightning.pytorch.utilities.types import STEP_OUTPUT
|
||||||
from lightning_utilities.core.rank_zero import rank_zero_info
|
from lightning_utilities.core.rank_zero import rank_zero_info
|
||||||
|
try:
|
||||||
|
import wandb
|
||||||
|
except ImportError: # pragma: no cover - optional dependency
|
||||||
|
wandb = None
|
||||||
|
|
||||||
def process_fn(image, path):
|
def process_fn(image, path):
|
||||||
Image.fromarray(image).save(path)
|
Image.fromarray(image).save(path)
|
||||||
|
|
||||||
class SaveImagesHook(Callback):
|
class SaveImagesHook(Callback):
|
||||||
def __init__(self, save_dir="val", max_save_num=100, compressed=True, wandb_max_save_num=None):
    """Configure where and how many validation sample images are saved.

    Args:
        save_dir: Subdirectory (under the trainer's root dir) for saved images.
        max_save_num: Overall cap on the number of images written to disk.
        compressed: Whether saved images are stored compressed.
        wandb_max_save_num: Optional W&B-specific cap on logged images;
            when None, ``max_save_num`` is used for W&B as well.
    """
    # Plain configuration capture — no I/O happens until an epoch starts.
    self.compressed = compressed
    self.save_dir = save_dir
    self.wandb_max_save_num = wandb_max_save_num
    self.max_save_num = max_save_num
|
|
||||||
def save_start(self, target_dir):
|
def save_start(self, target_dir):
|
||||||
self.target_dir = target_dir
|
self.target_dir = target_dir
|
||||||
@@ -68,6 +74,50 @@ class SaveImagesHook(Callback):
|
|||||||
self._have_saved_num = 0
|
self._have_saved_num = 0
|
||||||
self.executor_pool = None
|
self.executor_pool = None
|
||||||
|
|
||||||
|
def _get_wandb_logger(self, trainer: "pl.Trainer"):
    """Return the first ``WandbLogger`` attached to *trainer*, or ``None``.

    Checks ``trainer.logger`` first, then falls back to scanning
    ``trainer.loggers`` (when that attribute exists and is non-empty).
    """
    candidates = [trainer.logger]
    extra = getattr(trainer, "loggers", None)
    if extra:
        candidates.extend(extra)
    for candidate in candidates:
        if isinstance(candidate, WandbLogger):
            return candidate
    return None
|
|
||||||
|
def _log_wandb_samples(self, trainer: "pl.Trainer"):
    """Log collected sample batches to Weights & Biases (global zero only).

    Respects both the general ``max_save_num`` cap and the optional
    W&B-specific ``wandb_max_save_num`` override. Silently no-ops when
    wandb is not installed, no ``WandbLogger`` is attached, there are no
    samples, or the effective limit is non-positive.

    Args:
        trainer: The Lightning trainer whose logger and global step are used.
    """
    # Only rank zero logs, to avoid duplicate W&B entries in DDP runs.
    if not trainer.is_global_zero:
        return
    if wandb is None:  # optional dependency missing
        return
    if not self.samples:
        return
    # A general cap of exactly 0 disables W&B logging even when a
    # wandb-specific limit is configured.
    if self.max_save_num == 0:
        return
    wandb_logger = self._get_wandb_logger(trainer)
    if wandb_logger is None:
        return

    # wandb-specific limit takes precedence when set.
    max_num = self.wandb_max_save_num if self.wandb_max_save_num is not None else self.max_save_num
    if max_num <= 0:
        return

    images = []
    remaining = max_num
    for batch in self.samples:
        if remaining <= 0:
            break
        take = min(batch.shape[0], remaining)
        chunk = batch[:take]
        # wandb.Image expects uint8 pixel data; clip into the valid range
        # before converting. NOTE(review): assumes pixel values are already
        # on a 0-255 scale — float data in [0, 1] would be crushed to 0/1;
        # confirm against the producer of self.samples.
        if chunk.dtype != numpy.uint8:
            chunk = numpy.clip(chunk, 0, 255).astype(numpy.uint8)
        # Generator instead of a throwaway list inside extend().
        images.extend(wandb.Image(img) for img in chunk)
        remaining -= take

    if images:
        wandb_logger.experiment.log(
            {f"{self.save_dir}/samples": images},
            step=trainer.global_step,
        )
|
|
||||||
def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
    """Start a save round under ``<root_dir>/<save_dir>/iter_<global_step>``."""
    out_dir = os.path.join(
        trainer.default_root_dir,
        self.save_dir,
        f"iter_{trainer.global_step}",
    )
    self.save_start(out_dir)
||||||
@@ -84,6 +134,7 @@ class SaveImagesHook(Callback):
|
|||||||
return self.process_batch(trainer, pl_module, outputs, batch)
|
return self.process_batch(trainer, pl_module, outputs, batch)
|
||||||
|
|
||||||
def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
    """Flush the epoch's samples: first to W&B (rank zero), then to disk."""
    # W&B logging must happen before save_end(), which tears down the
    # collected samples and executor state.
    self._log_wandb_samples(trainer)
    self.save_end()
|
|
||||||
def on_predict_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
|
def on_predict_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
|
||||||
|
|||||||
Reference in New Issue
Block a user