45 lines
1.4 KiB
Python
45 lines
1.4 KiB
Python
import pathlib
|
|
|
|
from omegaconf import OmegaConf
|
|
|
|
|
|
# Absolute, symlink-resolved path of this test module.
_THIS_FILE = pathlib.Path(__file__).resolve()
# Directory two levels above this file; the YAML configs under test are
# loaded relative to it (presumably the repository root -- confirm layout).
ROOT_DIR = _THIS_FILE.parents[1]
|
|
|
|
|
|
def _load_cfg(name: str):
    """Load and return the OmegaConf config stored at ``ROOT_DIR / name``.

    Args:
        name: File name (or path relative to ``ROOT_DIR``) of the YAML config.

    Returns:
        Whatever ``OmegaConf.load`` produces for that file.
    """
    config_path = ROOT_DIR / name
    return OmegaConf.load(config_path)
|
|
|
|
|
|
def _assert_swanlab_logging_uses_exp_name(cfg):
    """Assert the SwanLab logging settings shared by the PushT DiT configs.

    The ``logging`` section must:

    - select the ``swanlab`` backend in ``online`` mode;
    - take both the run ``name`` and ``group`` from the top-level ``exp_name``;
    - set ``resume`` to False with no explicit ``id``. Per the test names,
      this is presumably so a new launch does not resume into or collide
      with an earlier run.

    Args:
        cfg: Config object returned by ``_load_cfg``.
    """
    log_cfg = cfg.logging
    assert log_cfg.backend == 'swanlab'
    assert log_cfg.mode == 'online'
    assert log_cfg.name == cfg.exp_name
    # Identity checks: only a real boolean False / None passes
    # (for example, 0 == False would otherwise be accepted).
    assert log_cfg.resume is False
    assert log_cfg.id is None
    assert log_cfg.group == cfg.exp_name


def test_image_pusht_dit_swanlab_config_uses_exp_name_and_no_resume_collision():
    """The DiT PushT config uses the shared SwanLab logging settings."""
    cfg = _load_cfg('image_pusht_diffusion_policy_dit.yaml')
    _assert_swanlab_logging_uses_exp_name(cfg)


def test_image_pusht_dit_imf_swanlab_config_uses_exp_name_and_no_resume_collision():
    """The DiT-IMF PushT config uses the shared SwanLab logging settings."""
    cfg = _load_cfg('image_pusht_diffusion_policy_dit_imf.yaml')
    _assert_swanlab_logging_uses_exp_name(cfg)


def test_image_pusht_dit_imf_fullattn_config_uses_exp_name_and_disables_causal_attention():
    """The full-attention DiT-IMF config shares the logging settings and
    turns off ``policy.causal_attn``."""
    cfg = _load_cfg('image_pusht_diffusion_policy_dit_imf_fullattn.yaml')
    _assert_swanlab_logging_uses_exp_name(cfg)
    assert cfg.policy.causal_attn is False
|