"""Config checks for the Push-T DiT image diffusion-policy experiments.

Each test loads one experiment YAML file from ``ROOT_DIR``. It checks that
SwanLab logging is named and grouped by the config's ``exp_name`` and is set
up with ``resume`` False and ``id`` None, so a new launch does not collide
with a previous run.
"""

import pathlib

from omegaconf import OmegaConf

# Directory one level above this file, where the experiment YAMLs are expected.
# NOTE(review): assumes this test file sits one directory below the configs — confirm.
ROOT_DIR = pathlib.Path(__file__).resolve().parents[1]


def _load_cfg(name: str):
    """Load and return the OmegaConf config stored at ``ROOT_DIR / name``."""
    return OmegaConf.load(ROOT_DIR / name)


def _assert_swanlab_logging_uses_exp_name(cfg) -> None:
    """Assert that ``cfg.logging`` is an online SwanLab setup keyed on ``cfg.exp_name``.

    The run name and group must both equal ``cfg.exp_name``. ``resume`` must be
    False and ``id`` must be None, so each launch does not attach to an
    existing run id (the "resume collision" the test names refer to).
    """
    logging_cfg = cfg.logging
    assert logging_cfg.backend == 'swanlab'
    assert logging_cfg.mode == 'online'
    assert logging_cfg.name == cfg.exp_name
    assert logging_cfg.resume is False
    assert logging_cfg.id is None
    assert logging_cfg.group == cfg.exp_name


def test_image_pusht_dit_swanlab_config_uses_exp_name_and_no_resume_collision():
    """The base DiT config logs to SwanLab under ``exp_name`` without resuming."""
    cfg = _load_cfg('image_pusht_diffusion_policy_dit.yaml')
    _assert_swanlab_logging_uses_exp_name(cfg)


def test_image_pusht_dit_imf_swanlab_config_uses_exp_name_and_no_resume_collision():
    """The DiT-IMF config logs to SwanLab under ``exp_name`` without resuming."""
    cfg = _load_cfg('image_pusht_diffusion_policy_dit_imf.yaml')
    _assert_swanlab_logging_uses_exp_name(cfg)


def test_image_pusht_dit_imf_fullattn_config_uses_exp_name_and_disables_causal_attention():
    """The full-attention DiT-IMF config shares the logging setup and sets causal_attn False."""
    cfg = _load_cfg('image_pusht_diffusion_policy_dit_imf_fullattn.yaml')
    _assert_swanlab_logging_uses_exp_name(cfg)
    assert cfg.policy.causal_attn is False