fix: support headless rollout on remote training hosts
This commit is contained in:
@@ -90,6 +90,18 @@ class _FakeRenderer:
|
||||
|
||||
|
||||
class EvalVLAHeadlessTest(unittest.TestCase):
|
||||
def test_headless_eval_sets_mujoco_gl_to_egl_when_display_missing(self):
|
||||
cfg = OmegaConf.create({"eval": {"headless": True}})
|
||||
with mock.patch.dict(eval_vla.os.environ, {}, clear=True):
|
||||
eval_vla._configure_headless_mujoco_gl(cfg.eval)
|
||||
self.assertEqual(eval_vla.os.environ.get("MUJOCO_GL"), "egl")
|
||||
|
||||
def test_headless_eval_preserves_existing_mujoco_gl(self):
|
||||
cfg = OmegaConf.create({"eval": {"headless": True}})
|
||||
with mock.patch.dict(eval_vla.os.environ, {"MUJOCO_GL": "osmesa"}, clear=True):
|
||||
eval_vla._configure_headless_mujoco_gl(cfg.eval)
|
||||
self.assertEqual(eval_vla.os.environ.get("MUJOCO_GL"), "osmesa")
|
||||
|
||||
def test_eval_config_exposes_headless_default(self):
|
||||
eval_cfg = OmegaConf.load(Path("roboimi/vla/conf/eval/eval.yaml"))
|
||||
|
||||
|
||||
@@ -246,6 +246,35 @@ class TrainVLATransformerOptimizerTest(unittest.TestCase):
|
||||
module.torch.backends.cudnn.enabled = original
|
||||
|
||||
|
||||
def test_resolve_run_output_dir_prefers_hydra_runtime_output_dir(self):
|
||||
module = self._load_train_vla_module()
|
||||
hydra_core_module = types.ModuleType('hydra.core')
|
||||
hydra_hydra_config_module = types.ModuleType('hydra.core.hydra_config')
|
||||
|
||||
class _Runtime:
|
||||
output_dir = '/tmp/hydra-output'
|
||||
|
||||
class _Cfg:
|
||||
runtime = _Runtime()
|
||||
|
||||
class HydraConfigStub:
|
||||
@staticmethod
|
||||
def initialized():
|
||||
return True
|
||||
@staticmethod
|
||||
def get():
|
||||
return _Cfg()
|
||||
|
||||
hydra_hydra_config_module.HydraConfig = HydraConfigStub
|
||||
with mock.patch.dict(sys.modules, {
|
||||
'hydra.core': hydra_core_module,
|
||||
'hydra.core.hydra_config': hydra_hydra_config_module,
|
||||
}):
|
||||
output_dir = module._resolve_run_output_dir()
|
||||
|
||||
self.assertEqual(Path(output_dir).resolve(), Path('/tmp/hydra-output').resolve())
|
||||
|
||||
|
||||
def test_train_script_uses_file_based_repo_root_on_sys_path(self):
|
||||
module = self._load_train_vla_module()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user