feat(callbacks): add Weights & Biases image logging to SaveImagesHook

- Introduce optional wandb_max_save_num parameter in __init__
- Add helper methods _get_wandb_logger and _log_wandb_samples
- Log sampled images as wandb.Image objects at validation epoch end (global zero only)
- Handle dtype conversion and value clipping for wandb compatibility
- Respect both the general max_save_num and the wandb-specific limit

This enables visualization of validation samples directly in the W&B dashboard without manual intervention.

commit b9c3720eba
parent 6bf32b08fd
Author: gameloader
Date:   2026-01-19 16:03:54 +08:00

2 changed files with 57 additions and 2 deletions
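
A minimal usage sketch of the new option; the import path, project name, and model are placeholders, not part of this commit:

    import lightning.pytorch as pl
    from lightning.pytorch.loggers import WandbLogger
    from callbacks import SaveImagesHook  # hypothetical import path

    trainer = pl.Trainer(
        logger=WandbLogger(project="my-project"),  # placeholder project name
        callbacks=[
            # Save up to 100 images to disk, but log at most 16 of them to W&B.
            SaveImagesHook(save_dir="val", max_save_num=100, wandb_max_save_num=16),
        ],
    )
    # trainer.validate(model)  # samples appear under "val/samples" in the W&B run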

.gitignore (new file, 4 additions)

@@ -0,0 +1,4 @@
+__pycache__/
+*.py[cod]
+.venv/
+.idea/

callbacks module containing SaveImagesHook (file path not shown)

@@ -1,5 +1,6 @@
 import lightning.pytorch as pl
 from lightning.pytorch import Callback
+from lightning.pytorch.loggers import WandbLogger
 import os.path
@@ -10,15 +11,20 @@ from concurrent.futures import ThreadPoolExecutor
 from lightning.pytorch.utilities.types import STEP_OUTPUT
 from lightning_utilities.core.rank_zero import rank_zero_info
 
+try:
+    import wandb
+except ImportError:  # pragma: no cover - optional dependency
+    wandb = None
 
 def process_fn(image, path):
     Image.fromarray(image).save(path)
 
 class SaveImagesHook(Callback):
-    def __init__(self, save_dir="val", max_save_num=100, compressed=True):
+    def __init__(self, save_dir="val", max_save_num=100, compressed=True, wandb_max_save_num=None):
         self.save_dir = save_dir
         self.max_save_num = max_save_num
         self.compressed = compressed
+        self.wandb_max_save_num = wandb_max_save_num
 
     def save_start(self, target_dir):
         self.target_dir = target_dir
@@ -68,6 +74,50 @@ class SaveImagesHook(Callback):
         self._have_saved_num = 0
         self.executor_pool = None
 
+    def _get_wandb_logger(self, trainer: "pl.Trainer"):
+        if isinstance(trainer.logger, WandbLogger):
+            return trainer.logger
+        if getattr(trainer, "loggers", None):
+            for logger in trainer.loggers:
+                if isinstance(logger, WandbLogger):
+                    return logger
+        return None
+
+    def _log_wandb_samples(self, trainer: "pl.Trainer"):
+        if not trainer.is_global_zero:
+            return
+        if wandb is None:
+            return
+        if not self.samples:
+            return
+        if self.max_save_num == 0:
+            return
+        wandb_logger = self._get_wandb_logger(trainer)
+        if wandb_logger is None:
+            return
+        max_num = self.wandb_max_save_num if self.wandb_max_save_num is not None else self.max_save_num
+        if max_num <= 0:
+            return
+        images = []
+        remaining = max_num
+        for batch in self.samples:
+            if remaining <= 0:
+                break
+            take = min(batch.shape[0], remaining)
+            chunk = batch[:take]
+            if chunk.dtype != numpy.uint8:
+                chunk = numpy.clip(chunk, 0, 255).astype(numpy.uint8)
+            images.extend([wandb.Image(img) for img in chunk])
+            remaining -= take
+        if images:
+            wandb_logger.experiment.log(
+                {f"{self.save_dir}/samples": images},
+                step=trainer.global_step,
+            )
+
     def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         target_dir = os.path.join(trainer.default_root_dir, self.save_dir, f"iter_{trainer.global_step}")
         self.save_start(target_dir)
@@ -84,6 +134,7 @@ class SaveImagesHook(Callback):
         return self.process_batch(trainer, pl_module, outputs, batch)
 
     def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        self._log_wandb_samples(trainer)
         self.save_end()
 
     def on_predict_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
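
For reference, a standalone sketch of the dtype handling performed in _log_wandb_samples; the batch below is made-up data for illustration:

    import numpy

    # Non-uint8 batches are clipped to [0, 255] and cast to uint8 before
    # being wrapped in wandb.Image, mirroring the conversion above.
    batch = numpy.random.rand(4, 64, 64, 3) * 300.0  # some values exceed 255
    if batch.dtype != numpy.uint8:
        batch = numpy.clip(batch, 0, 255).astype(numpy.uint8)
    assert batch.dtype == numpy.uint8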