diff --git a/roboimi/demos/vla_scripts/eval_vla.py b/roboimi/demos/vla_scripts/eval_vla.py index de7e7d7..2ada76c 100644 --- a/roboimi/demos/vla_scripts/eval_vla.py +++ b/roboimi/demos/vla_scripts/eval_vla.py @@ -37,6 +37,17 @@ if not OmegaConf.has_resolver("len"): OmegaConf.register_new_resolver("len", lambda x: len(x)) +def _configure_headless_mujoco_gl(eval_cfg: DictConfig) -> None: + if not bool(eval_cfg.get('headless', False)): + return + if os.environ.get('MUJOCO_GL'): + return + if os.environ.get('DISPLAY'): + return + os.environ['MUJOCO_GL'] = 'egl' + log.info('headless eval detected without DISPLAY; set MUJOCO_GL=egl') + + def load_checkpoint( ckpt_path: str, agent_cfg: DictConfig, @@ -501,6 +512,7 @@ def _run_eval(cfg: DictConfig): print("=" * 80) eval_cfg = cfg.eval + _configure_headless_mujoco_gl(eval_cfg) device = eval_cfg.device camera_names = list(eval_cfg.camera_names) artifact_paths = _resolve_artifact_paths(eval_cfg) diff --git a/roboimi/demos/vla_scripts/train_vla.py b/roboimi/demos/vla_scripts/train_vla.py index 059f4ea..b11331c 100644 --- a/roboimi/demos/vla_scripts/train_vla.py +++ b/roboimi/demos/vla_scripts/train_vla.py @@ -76,6 +76,18 @@ if not OmegaConf.has_resolver("len"): OmegaConf.register_new_resolver("len", lambda x: len(x)) +def _resolve_run_output_dir() -> Path: + try: + from hydra.core.hydra_config import HydraConfig + if HydraConfig.initialized(): + output_dir = HydraConfig.get().runtime.output_dir + if output_dir: + return Path(output_dir).resolve() + except Exception: + pass + return Path.cwd().resolve() + + _maybe_reexec_without_problematic_ld_preload() @@ -319,8 +331,9 @@ def _run_training(cfg: DictConfig): swanlab_module = _init_swanlab(cfg) try: # 创建检查点目录 - checkpoint_dir = Path("checkpoints") - checkpoint_dir.mkdir(exist_ok=True) + run_output_dir = _resolve_run_output_dir() + checkpoint_dir = run_output_dir / "checkpoints" + checkpoint_dir.mkdir(parents=True, exist_ok=True) default_best_model_path = checkpoint_dir / "vla_model_best.pt" # ========================================================================= diff --git a/tests/test_eval_vla_headless.py b/tests/test_eval_vla_headless.py index e6f4abb..f84a5f2 100644 --- a/tests/test_eval_vla_headless.py +++ b/tests/test_eval_vla_headless.py @@ -90,6 +90,18 @@ class _FakeRenderer: class EvalVLAHeadlessTest(unittest.TestCase): + def test_headless_eval_sets_mujoco_gl_to_egl_when_display_missing(self): + cfg = OmegaConf.create({"eval": {"headless": True}}) + with mock.patch.dict(eval_vla.os.environ, {}, clear=True): + eval_vla._configure_headless_mujoco_gl(cfg.eval) + self.assertEqual(eval_vla.os.environ.get("MUJOCO_GL"), "egl") + + def test_headless_eval_preserves_existing_mujoco_gl(self): + cfg = OmegaConf.create({"eval": {"headless": True}}) + with mock.patch.dict(eval_vla.os.environ, {"MUJOCO_GL": "osmesa"}, clear=True): + eval_vla._configure_headless_mujoco_gl(cfg.eval) + self.assertEqual(eval_vla.os.environ.get("MUJOCO_GL"), "osmesa") + def test_eval_config_exposes_headless_default(self): eval_cfg = OmegaConf.load(Path("roboimi/vla/conf/eval/eval.yaml")) diff --git a/tests/test_train_vla_transformer_optimizer.py b/tests/test_train_vla_transformer_optimizer.py index 7ac7657..56ee107 100644 --- a/tests/test_train_vla_transformer_optimizer.py +++ b/tests/test_train_vla_transformer_optimizer.py @@ -246,6 +246,35 @@ class TrainVLATransformerOptimizerTest(unittest.TestCase): module.torch.backends.cudnn.enabled = original + def test_resolve_run_output_dir_prefers_hydra_runtime_output_dir(self): + module = self._load_train_vla_module() + hydra_core_module = types.ModuleType('hydra.core') + hydra_hydra_config_module = types.ModuleType('hydra.core.hydra_config') + + class _Runtime: + output_dir = '/tmp/hydra-output' + + class _Cfg: + runtime = _Runtime() + + class HydraConfigStub: + @staticmethod + def initialized(): + return True + @staticmethod + def get(): + return _Cfg() + + hydra_hydra_config_module.HydraConfig = HydraConfigStub + with mock.patch.dict(sys.modules, { + 'hydra.core': hydra_core_module, + 'hydra.core.hydra_config': hydra_hydra_config_module, + }): + output_dir = module._resolve_run_output_dir() + + self.assertEqual(Path(output_dir).resolve(), Path('/tmp/hydra-output').resolve()) + + def test_train_script_uses_file_based_repo_root_on_sys_path(self): module = self._load_train_vla_module()