submit code
src/utils/__init__.py (new, empty file)
src/utils/copy.py (new file, 13 lines)
@@ -0,0 +1,13 @@
import torch


@torch.no_grad()
def copy_params(src_model, dst_model):
    # Copy every parameter of src_model into dst_model in place, without tracking gradients.
    for src_param, dst_param in zip(src_model.parameters(), dst_model.parameters()):
        dst_param.data.copy_(src_param.data)


@torch.no_grad()
def swap_tensors(tensor1, tensor2):
    # Exchange the contents of two same-shaped tensors in place via a temporary buffer.
    tmp = torch.empty_like(tensor1)
    tmp.copy_(tensor1)
    tensor1.copy_(tensor2)
    tensor2.copy_(tmp)
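A minimal usage sketch (the two nn.Linear modules below are placeholders, not part of this commit): copy_params overwrites the target's weights in place, and swap_tensors exchanges the contents of two same-shaped tensors.

import torch.nn as nn
from src.utils.copy import copy_params, swap_tensors

model = nn.Linear(4, 4)
ema_model = nn.Linear(4, 4)
copy_params(model, ema_model)                            # ema_model now mirrors model
swap_tensors(model.weight.data, ema_model.weight.data)   # contents exchanged in place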
src/utils/model_loader.py (new file, 29 lines)
@@ -0,0 +1,29 @@
from typing import Dict, Any, Optional

import torch
import torch.nn as nn
from lightning.fabric.utilities.types import _PATH

import logging

logger = logging.getLogger(__name__)


class ModelLoader:
    def __init__(self):
        super().__init__()

    def load(self, denoiser, prefix=""):
        # Load checkpoint weights onto the denoiser, if a weight path is configured.
        if denoiser.weight_path:
            weight = torch.load(denoiser.weight_path, map_location=torch.device("cpu"))

            # Select EMA or raw weights by adjusting the state-dict key prefix.
            if denoiser.load_ema:
                prefix = "ema_denoiser." + prefix
            else:
                prefix = "denoiser." + prefix

            # Copy matching entries in place; log and skip keys that are missing or mismatched.
            for k, v in denoiser.state_dict().items():
                try:
                    v.copy_(weight["state_dict"][prefix + k])
                except (KeyError, RuntimeError):
                    logger.warning(f"Failed to copy {prefix + k} to denoiser weight")
        return denoiser
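Usage sketch, assuming a denoiser module that carries weight_path and load_ema attributes as read by load(); the stand-in module and checkpoint path below are hypothetical, not part of this commit.

import torch.nn as nn
from src.utils.model_loader import ModelLoader

denoiser = nn.Linear(8, 8)                  # stand-in for the real denoiser module
denoiser.weight_path = "outputs/last.ckpt"  # hypothetical checkpoint path
denoiser.load_ema = True                    # load the "ema_denoiser." weights
denoiser = ModelLoader().load(denoiser)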
src/utils/no_grad.py (new file, 16 lines)
@@ -0,0 +1,16 @@
import torch


@torch.no_grad()
def no_grad(net):
    # Freeze all parameters of a module and switch it to eval mode.
    for param in net.parameters():
        param.requires_grad = False
    net.eval()
    return net


@torch.no_grad()
def filter_nograd_tensors(params_list):
    # Keep only the parameters that still require gradients (e.g. for an optimizer).
    filtered_params_list = []
    for param in params_list:
        if param.requires_grad:
            filtered_params_list.append(param)
    return filtered_params_list
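Typical usage sketch (the teacher/student modules and the AdamW settings are placeholder choices, not part of this commit): freeze a module that should never be trained, and hand the optimizer only the tensors that still require gradients.

import torch
import torch.nn as nn
from src.utils.no_grad import no_grad, filter_nograd_tensors

teacher = no_grad(nn.Linear(4, 4))   # frozen and switched to eval mode
student = nn.Linear(4, 4)
optimizer = torch.optim.AdamW(filter_nograd_tensors(list(student.parameters())), lr=1e-4)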
src/utils/patch_bugs.py (new file, 17 lines)
@@ -0,0 +1,17 @@
import torch
import lightning.pytorch.loggers.wandb as wandb

# Force-enable the wandb logger and allow lower-precision (TF32) float32 matmuls.
setattr(wandb, '_WANDB_AVAILABLE', True)
torch.set_float32_matmul_precision('medium')

# Silence noisy wandb logging.
import logging

logger = logging.getLogger("wandb")
logger.setLevel(logging.WARNING)

# Reduce NCCL and TensorFlow log verbosity.
import os

os.environ["NCCL_DEBUG"] = "WARN"
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

# Suppress recurring framework warnings.
import warnings

warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=UserWarning)
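This module works entirely through import-time side effects, so a typical entry point would import it before anything else (sketch only; the train.py filename is an assumption, not part of this commit).

# train.py (hypothetical entry point)
import src.utils.patch_bugs  # noqa: F401  -- applies the logging/env patches on import
from lightning.pytorch import Trainer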