33 lines
966 B
Python
33 lines
966 B
Python
import pathlib
|
|
|
|
from omegaconf import OmegaConf
|
|
|
|
|
|
# Parent of the directory that contains this test file (symlinks resolved);
# the YAML configs checked below are loaded relative to this directory.
# NOTE(review): presumably the repository/config root — confirm against layout.
ROOT_DIR = pathlib.Path(__file__).resolve().parents[1]
|
|
|
|
|
|
def _load_cfg(name: str):
    """Load and return the OmegaConf config stored at ``ROOT_DIR / name``.

    Args:
        name: File name (or path relative to ``ROOT_DIR``) of the YAML config.

    Returns:
        Whatever ``OmegaConf.load`` produces for that file.
    """
    config_path = ROOT_DIR / name
    return OmegaConf.load(config_path)
|
|
|
|
|
|
def _assert_swanlab_logging_isolated(config_name: str) -> None:
    """Assert the SwanLab logging contract for the config file *config_name*.

    Loads the config via ``_load_cfg`` and checks that its ``logging`` section:

    * uses the ``'swanlab'`` backend in ``'online'`` mode;
    * sets both the run ``name`` and ``group`` to the top-level ``exp_name``;
    * does not pin a run to resume: ``resume`` is ``False`` and ``id`` is
      ``None``.

    Both per-config tests below call this helper, so every config is held to
    exactly the same contract instead of maintaining duplicated assertions.

    Args:
        config_name: File name of the YAML config, relative to ``ROOT_DIR``.

    Raises:
        AssertionError: If any of the checks above fails.
    """
    cfg = _load_cfg(config_name)
    logging_cfg = cfg.logging

    assert logging_cfg.backend == 'swanlab'
    assert logging_cfg.mode == 'online'
    assert logging_cfg.name == cfg.exp_name
    # Identity checks (not ==) so falsy stand-ins such as 0 or '' do not pass.
    assert logging_cfg.resume is False
    assert logging_cfg.id is None
    assert logging_cfg.group == cfg.exp_name


def test_image_pusht_dit_swanlab_config_uses_exp_name_and_no_resume_collision():
    """Check SwanLab logging settings of ``image_pusht_diffusion_policy_dit.yaml``."""
    _assert_swanlab_logging_isolated('image_pusht_diffusion_policy_dit.yaml')


def test_image_pusht_dit_imf_swanlab_config_uses_exp_name_and_no_resume_collision():
    """Check SwanLab logging settings of ``image_pusht_diffusion_policy_dit_imf.yaml``."""
    _assert_swanlab_logging_isolated('image_pusht_diffusion_policy_dit_imf.yaml')
|