submit code

Author: wangshuai6
Date: 2025-04-09 11:01:16 +08:00
Parent: 4fbcf9bd87
Commit: 06499f1caa

145 changed files with 14400 additions and 0 deletions

src/callbacks/grad.py (new file)

@@ -0,0 +1,22 @@
import torch
import lightning.pytorch as pl
from lightning.pytorch.utilities import grad_norm
from torch.optim import Optimizer


class GradientMonitor(pl.Callback):
    """Logs the largest per-parameter gradient norm and the total gradient norm before each optimizer step."""

    def __init__(self, norm_type: int = 2):
        super().__init__()
        norm_type = float(norm_type)
        if norm_type <= 0:
            raise ValueError(f"`norm_type` must be a positive number or 'inf' (infinity norm). Got {norm_type}")
        self.norm_type = norm_type

    def on_before_optimizer_step(
        self, trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        optimizer: Optimizer
    ) -> None:
        # grad_norm returns one entry per parameter plus a "..._norm_total" entry.
        norms = grad_norm(pl_module, norm_type=self.norm_type)
        total_key = f"grad_{self.norm_type}_norm_total"
        # Largest single-parameter gradient norm, excluding the total entry.
        max_grad = torch.tensor([v for k, v in norms.items() if k != total_key]).max()
        pl_module.log_dict({"train/grad/max": max_grad, "train/grad/total": norms[total_key]})
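
A minimal sketch of attaching the monitor to a Trainer; the module and dataloader names are placeholders, not part of this commit:

import lightning.pytorch as pl

# `my_module` / `my_loader` are hypothetical; any LightningModule works.
trainer = pl.Trainer(max_steps=100, callbacks=[GradientMonitor(norm_type=2)])
# trainer.fit(my_module, train_dataloaders=my_loader)
# Logged keys: "train/grad/max" and "train/grad/total".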


@@ -0,0 +1,25 @@
import os.path
from typing import Dict, Any

import lightning.pytorch as pl
from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint


class CheckpointHook(ModelCheckpoint):
    """Save checkpoints that contain only the incremental part of the model."""

    def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
        # Write checkpoints into the run directory and allow partial state_dicts on load.
        self.dirpath = trainer.default_root_dir
        self.exception_ckpt_path = os.path.join(self.dirpath, "on_exception.pt")
        pl_module.strict_loading = False

    def on_save_checkpoint(
        self, trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        checkpoint: Dict[str, Any]
    ) -> None:
        # Drop callback state to keep the checkpoint minimal.
        del checkpoint["callbacks"]

    # def on_exception(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", exception: BaseException) -> None:
    #     if "debug" not in self.exception_ckpt_path:
    #         trainer.save_checkpoint(self.exception_ckpt_path)
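
A rough sketch of wiring the hook into a Trainer; the checkpoint interval and directory below are illustrative assumptions, not values taken from this commit:

import lightning.pytorch as pl

ckpt_cb = CheckpointHook(every_n_train_steps=5000, save_last=True)  # assumed settings
trainer = pl.Trainer(default_root_dir="runs/exp1", callbacks=[ckpt_cb])
# setup() redirects ckpt_cb.dirpath to default_root_dir, so checkpoints land in runs/exp1.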


@@ -0,0 +1,105 @@
import os.path
from concurrent.futures import ThreadPoolExecutor
from typing import Sequence, Any, Dict

import numpy
from PIL import Image

import lightning.pytorch as pl
from lightning.pytorch import Callback
from lightning.pytorch.utilities.types import STEP_OUTPUT
from lightning_utilities.core.rank_zero import rank_zero_info


def process_fn(image, path):
    Image.fromarray(image).save(path)


class SaveImagesHook(Callback):
    """Saves generated images during validation/prediction and, optionally, a compressed .npz of all gathered samples."""

    def __init__(self, save_dir="val", max_save_num=0, compressed=True):
        super().__init__()
        self.save_dir = save_dir
        self.max_save_num = max_save_num
        self.compressed = compressed

    def save_start(self, target_dir):
        self.target_dir = target_dir
        # Disk writes are offloaded to a small thread pool.
        self.executor_pool = ThreadPoolExecutor(max_workers=8)
        if not os.path.exists(self.target_dir):
            os.makedirs(self.target_dir, exist_ok=True)
        elif os.listdir(self.target_dir) and "debug" not in str(self.target_dir):
            raise FileExistsError(f"{self.target_dir} already exists and is not empty!")
        self.samples = []
        self._have_saved_num = 0
        rank_zero_info(f"Save images to {self.target_dir}")

    def save_image(self, images, filenames):
        # (B, C, H, W) -> (B, H, W, C) so PIL can consume each sample.
        images = images.permute(0, 2, 3, 1).cpu().numpy()
        for sample, filename in zip(images, filenames):
            if isinstance(filename, Sequence):
                filename = filename[0]
            path = f"{self.target_dir}/{filename}"
            if self._have_saved_num >= self.max_save_num:
                break
            self.executor_pool.submit(process_fn, sample, path)
            self._have_saved_num += 1

    def process_batch(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        samples: STEP_OUTPUT,
        batch: Any,
    ) -> None:
        b, c, h, w = samples.shape
        xT, y, metadata = batch
        # Gather samples from all ranks; only rank zero accumulates them for the .npz dump.
        all_samples = pl_module.all_gather(samples).view(-1, c, h, w)
        self.save_image(samples, metadata)
        if trainer.is_global_zero:
            all_samples = all_samples.permute(0, 2, 3, 1).cpu().numpy()
            self.samples.append(all_samples)

    def save_end(self):
        if self.compressed and len(self.samples) > 0:
            samples = numpy.concatenate(self.samples)
            numpy.savez(f"{self.target_dir}/output.npz", arr_0=samples)
        self.executor_pool.shutdown(wait=True)
        self.samples = []
        self.target_dir = None
        self._have_saved_num = 0
        self.executor_pool = None

    def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        target_dir = os.path.join(trainer.default_root_dir, self.save_dir, f"iter_{trainer.global_step}")
        self.save_start(target_dir)

    def on_validation_batch_end(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        outputs: STEP_OUTPUT,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        return self.process_batch(trainer, pl_module, outputs, batch)

    def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        self.save_end()

    def on_predict_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        target_dir = os.path.join(trainer.default_root_dir, self.save_dir, "predict")
        self.save_start(target_dir)

    def on_predict_batch_end(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        samples: Any,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        return self.process_batch(trainer, pl_module, samples, batch)

    def on_predict_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        self.save_end()
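
A rough usage sketch; the batch contract (xT, y, metadata) is inferred from process_batch, and the module and dataloader names are placeholders, not part of this commit:

import lightning.pytorch as pl

# The validation dataloader is assumed to yield (xT, y, metadata) triples,
# where metadata carries one output filename per sample (e.g. "000001.png").
save_hook = SaveImagesHook(save_dir="val", max_save_num=64, compressed=True)
trainer = pl.Trainer(default_root_dir="runs/exp1", callbacks=[save_hook])
# trainer.validate(my_module, dataloaders=my_val_loader)  # my_module / my_val_loader are placeholders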


@@ -0,0 +1,79 @@
from typing import Any, Dict

import torch
import torch.nn as nn

import lightning.pytorch as pl
from lightning.pytorch import Callback
from lightning.pytorch.utilities.types import STEP_OUTPUT
from src.utils.copy import swap_tensors


class SimpleEMA(Callback):
    """Maintains an exponential moving average (EMA) copy of the model weights and swaps it in for evaluation."""

    def __init__(
        self,
        net: nn.Module,
        ema_net: nn.Module,
        decay: float = 0.9999,
        every_n_steps: int = 1,
        eval_original_model: bool = False,
    ):
        super().__init__()
        self.decay = decay
        self.every_n_steps = every_n_steps
        self.eval_original_model = eval_original_model
        # Run EMA updates on a side CUDA stream so they can overlap with training work.
        self._stream = torch.cuda.Stream() if torch.cuda.is_available() else None
        self.net_params = list(net.parameters())
        self.ema_params = list(ema_net.parameters())

    def swap_model(self):
        # Exchange the live and EMA weights in place (used around validation/prediction).
        for ema_p, p in zip(self.ema_params, self.net_params):
            swap_tensors(ema_p, p)

    def ema_step(self):
        @torch.no_grad()
        def ema_update(ema_model_tuple, current_model_tuple, decay):
            # ema = decay * ema + (1 - decay) * current
            torch._foreach_mul_(ema_model_tuple, decay)
            torch._foreach_add_(
                ema_model_tuple, current_model_tuple, alpha=(1.0 - decay),
            )

        if self._stream is not None:
            self._stream.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(self._stream):
                ema_update(self.ema_params, self.net_params, self.decay)
        else:
            ema_update(self.ema_params, self.net_params, self.decay)

    def on_train_batch_end(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int
    ) -> None:
        if trainer.global_step % self.every_n_steps == 0:
            self.ema_step()

    def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        if not self.eval_original_model:
            self.swap_model()

    def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        if not self.eval_original_model:
            self.swap_model()

    def on_predict_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        if not self.eval_original_model:
            self.swap_model()

    def on_predict_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        if not self.eval_original_model:
            self.swap_model()

    def state_dict(self) -> Dict[str, Any]:
        return {
            "decay": self.decay,
            "every_n_steps": self.every_n_steps,
            "eval_original_model": self.eval_original_model,
        }

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.decay = state_dict["decay"]
        self.every_n_steps = state_dict["every_n_steps"]
        self.eval_original_model = state_dict["eval_original_model"]
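
A rough usage sketch, assuming the live network and its deep-copied EMA twin are built outside the callback; the tiny nn.Linear and the commented LightningModule are illustrative placeholders, not part of this commit:

import copy
import torch.nn as nn
import lightning.pytorch as pl

net = nn.Linear(8, 8)            # stand-in for the real backbone
ema_net = copy.deepcopy(net)     # EMA copy that SimpleEMA updates in place
for p in ema_net.parameters():
    p.requires_grad_(False)

ema_cb = SimpleEMA(net, ema_net, decay=0.9999, every_n_steps=1)
trainer = pl.Trainer(callbacks=[ema_cb], max_steps=10)
# trainer.fit(lit_module, train_dataloaders=loader)  # lit_module wraps `net`; both are placeholders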