# File: diffusion_policy/tests/test_pusht_swanlab_config.py (33 lines, 966 B, Python)

import pathlib
from omegaconf import OmegaConf
# Directory two levels above this (symlink-resolved) test file, i.e. the
# parent of the directory containing this file; the YAML configs checked
# below are loaded relative to it.
ROOT_DIR = pathlib.Path(__file__).resolve().parent.parent
def _load_cfg(name: str):
    """Load a YAML config file located relative to ``ROOT_DIR``.

    Args:
        name: Path of the config file, relative to ``ROOT_DIR``.

    Returns:
        The config object returned by ``OmegaConf.load``.
    """
    config_path = ROOT_DIR.joinpath(name)
    return OmegaConf.load(config_path)
def _assert_swanlab_logging_uses_exp_name(cfg):
    """Assert that ``cfg.logging`` is an online SwanLab setup keyed on ``cfg.exp_name``.

    Shared by the per-config tests below so every checked config is held to
    exactly the same logging contract.

    Args:
        cfg: Loaded config exposing ``exp_name`` and a ``logging`` section via
            attribute access (as returned by ``_load_cfg``).

    Raises:
        AssertionError: if ``exp_name`` is empty/None or any logging field
            deviates from the expected value.
    """
    exp_name = cfg.exp_name
    # Without this guard the ``name``/``group`` equality checks below would
    # pass vacuously when exp_name is missing or null (None == None).
    assert exp_name, f"exp_name must be a non-empty value, got {exp_name!r}"

    logging_cfg = cfg.logging
    assert logging_cfg.backend == 'swanlab'
    assert logging_cfg.mode == 'online'
    assert logging_cfg.name == exp_name
    # Per the test names ("no_resume_collision"): resume is disabled and no
    # fixed run id is pinned — presumably so separate launches do not resume
    # or collide with an existing run.
    assert logging_cfg.resume is False
    assert logging_cfg.id is None
    assert logging_cfg.group == exp_name


def test_image_pusht_dit_swanlab_config_uses_exp_name_and_no_resume_collision():
    """The DiT PushT image config logs to SwanLab under its exp_name."""
    cfg = _load_cfg('image_pusht_diffusion_policy_dit.yaml')
    _assert_swanlab_logging_uses_exp_name(cfg)


def test_image_pusht_dit_imf_swanlab_config_uses_exp_name_and_no_resume_collision():
    """The DiT-IMF PushT image config logs to SwanLab under its exp_name."""
    cfg = _load_cfg('image_pusht_diffusion_policy_dit_imf.yaml')
    _assert_swanlab_logging_uses_exp_name(cfg)