submit code

This commit is contained in:
wangshuai6
2025-04-09 11:01:16 +08:00
parent 4fbcf9bd87
commit 06499f1caa
145 changed files with 14400 additions and 0 deletions

0
src/utils/__init__.py Normal file
View File

13
src/utils/copy.py Normal file
View File

@@ -0,0 +1,13 @@
import torch
@torch.no_grad()
def copy_params(src_model, dst_model):
for src_param, dst_param in zip(src_model.parameters(), dst_model.parameters()):
dst_param.data.copy_(src_param.data)
@torch.no_grad()
def swap_tensors(tensor1, tensor2):
tmp = torch.empty_like(tensor1)
tmp.copy_(tensor1)
tensor1.copy_(tensor2)
tensor2.copy_(tmp)

29
src/utils/model_loader.py Normal file
View File

@@ -0,0 +1,29 @@
from typing import Dict, Any, Optional
import torch
import torch.nn as nn
from lightning.fabric.utilities.types import _PATH
import logging
logger = logging.getLogger(__name__)
class ModelLoader:
def __init__(self,):
super().__init__()
def load(self, denoiser, prefix=""):
if denoiser.weight_path:
weight = torch.load(denoiser.weight_path, map_location=torch.device('cpu'))
if denoiser.load_ema:
prefix = "ema_denoiser." + prefix
else:
prefix = "denoiser." + prefix
for k, v in denoiser.state_dict().items():
try:
v.copy_(weight["state_dict"][prefix+k])
except:
logger.warning(f"Failed to copy {prefix+k} to denoiser weight")
return denoiser

16
src/utils/no_grad.py Normal file
View File

@@ -0,0 +1,16 @@
import torch
@torch.no_grad()
def no_grad(net):
for param in net.parameters():
param.requires_grad = False
net.eval()
return net
@torch.no_grad()
def filter_nograd_tensors(params_list):
filtered_params_list = []
for param in params_list:
if param.requires_grad:
filtered_params_list.append(param)
return filtered_params_list

17
src/utils/patch_bugs.py Normal file
View File

@@ -0,0 +1,17 @@
import torch
import lightning.pytorch.loggers.wandb as wandb
setattr(wandb, '_WANDB_AVAILABLE', True)
torch.set_float32_matmul_precision('medium')
import logging
logger = logging.getLogger("wandb")
logger.setLevel(logging.WARNING)
import os
os.environ["NCCL_DEBUG"] = "WARN"
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=UserWarning)