fix: support headless rollout on remote training hosts
This commit is contained in:
@@ -37,6 +37,17 @@ if not OmegaConf.has_resolver("len"):
|
||||
OmegaConf.register_new_resolver("len", lambda x: len(x))
|
||||
|
||||
|
||||
def _configure_headless_mujoco_gl(eval_cfg: DictConfig) -> None:
    """Pick the EGL MuJoCo rendering backend for headless evaluation.

    Sets ``MUJOCO_GL=egl`` only when every one of these holds:
    the eval config requests headless mode, the user has not already
    chosen a backend via ``MUJOCO_GL``, and no X display is available
    (``DISPLAY`` unset/empty). Otherwise the environment is untouched.
    """
    headless_requested = bool(eval_cfg.get('headless', False))
    backend_preselected = bool(os.environ.get('MUJOCO_GL'))
    display_present = bool(os.environ.get('DISPLAY'))
    if not headless_requested or backend_preselected or display_present:
        return
    os.environ['MUJOCO_GL'] = 'egl'
    log.info('headless eval detected without DISPLAY; set MUJOCO_GL=egl')
|
||||
|
||||
|
||||
def load_checkpoint(
|
||||
ckpt_path: str,
|
||||
agent_cfg: DictConfig,
|
||||
@@ -501,6 +512,7 @@ def _run_eval(cfg: DictConfig):
|
||||
print("=" * 80)
|
||||
|
||||
eval_cfg = cfg.eval
|
||||
_configure_headless_mujoco_gl(eval_cfg)
|
||||
device = eval_cfg.device
|
||||
camera_names = list(eval_cfg.camera_names)
|
||||
artifact_paths = _resolve_artifact_paths(eval_cfg)
|
||||
|
||||
@@ -76,6 +76,18 @@ if not OmegaConf.has_resolver("len"):
|
||||
OmegaConf.register_new_resolver("len", lambda x: len(x))
|
||||
|
||||
|
||||
def _resolve_run_output_dir() -> Path:
|
||||
try:
|
||||
from hydra.core.hydra_config import HydraConfig
|
||||
if HydraConfig.initialized():
|
||||
output_dir = HydraConfig.get().runtime.output_dir
|
||||
if output_dir:
|
||||
return Path(output_dir).resolve()
|
||||
except Exception:
|
||||
pass
|
||||
return Path.cwd().resolve()
|
||||
|
||||
|
||||
_maybe_reexec_without_problematic_ld_preload()
|
||||
|
||||
|
||||
@@ -319,8 +331,9 @@ def _run_training(cfg: DictConfig):
|
||||
swanlab_module = _init_swanlab(cfg)
|
||||
try:
|
||||
# 创建检查点目录
|
||||
checkpoint_dir = Path("checkpoints")
|
||||
checkpoint_dir.mkdir(exist_ok=True)
|
||||
run_output_dir = _resolve_run_output_dir()
|
||||
checkpoint_dir = run_output_dir / "checkpoints"
|
||||
checkpoint_dir.mkdir(parents=True, exist_ok=True)
|
||||
default_best_model_path = checkpoint_dir / "vla_model_best.pt"
|
||||
|
||||
# =========================================================================
|
||||
|
||||
@@ -90,6 +90,18 @@ class _FakeRenderer:
|
||||
|
||||
|
||||
class EvalVLAHeadlessTest(unittest.TestCase):
|
||||
def test_headless_eval_sets_mujoco_gl_to_egl_when_display_missing(self):
    """Headless eval with no DISPLAY and no preset backend must set MUJOCO_GL=egl."""
    cfg = OmegaConf.create({"eval": {"headless": True}})
    # clear=True empties os.environ for the duration of the patch, so
    # neither MUJOCO_GL nor DISPLAY is present when the helper runs.
    with mock.patch.dict(eval_vla.os.environ, {}, clear=True):
        eval_vla._configure_headless_mujoco_gl(cfg.eval)
        self.assertEqual(eval_vla.os.environ.get("MUJOCO_GL"), "egl")
|
||||
|
||||
def test_headless_eval_preserves_existing_mujoco_gl(self):
    """A user-chosen MUJOCO_GL backend must never be overridden by headless eval."""
    cfg = OmegaConf.create({"eval": {"headless": True}})
    # Pre-seed MUJOCO_GL=osmesa (clear=True drops everything else,
    # including DISPLAY); the helper must leave the existing value alone.
    with mock.patch.dict(eval_vla.os.environ, {"MUJOCO_GL": "osmesa"}, clear=True):
        eval_vla._configure_headless_mujoco_gl(cfg.eval)
        self.assertEqual(eval_vla.os.environ.get("MUJOCO_GL"), "osmesa")
|
||||
|
||||
def test_eval_config_exposes_headless_default(self):
|
||||
eval_cfg = OmegaConf.load(Path("roboimi/vla/conf/eval/eval.yaml"))
|
||||
|
||||
|
||||
@@ -246,6 +246,35 @@ class TrainVLATransformerOptimizerTest(unittest.TestCase):
|
||||
module.torch.backends.cudnn.enabled = original
|
||||
|
||||
|
||||
def test_resolve_run_output_dir_prefers_hydra_runtime_output_dir(self):
    """_resolve_run_output_dir must return Hydra's runtime output dir when Hydra is initialized."""
    module = self._load_train_vla_module()
    # Fabricate stub modules so that the helper's function-local
    # `from hydra.core.hydra_config import HydraConfig` resolves to
    # our stub instead of (possibly absent) real Hydra.
    hydra_core_module = types.ModuleType('hydra.core')
    hydra_hydra_config_module = types.ModuleType('hydra.core.hydra_config')

    class _Runtime:
        # Mirrors the attribute path HydraConfig.get().runtime.output_dir.
        output_dir = '/tmp/hydra-output'

    class _Cfg:
        runtime = _Runtime()

    class HydraConfigStub:
        @staticmethod
        def initialized():
            # Pretend a Hydra app is active so the helper takes the
            # Hydra branch rather than the cwd fallback.
            return True

        @staticmethod
        def get():
            return _Cfg()

    hydra_hydra_config_module.HydraConfig = HydraConfigStub
    # Patch sys.modules so the import inside the helper finds the stubs;
    # mock restores the real entries on exit.
    with mock.patch.dict(sys.modules, {
        'hydra.core': hydra_core_module,
        'hydra.core.hydra_config': hydra_hydra_config_module,
    }):
        output_dir = module._resolve_run_output_dir()

    # Compare resolved paths so symlinks/relative segments don't matter.
    self.assertEqual(Path(output_dir).resolve(), Path('/tmp/hydra-output').resolve())
|
||||
|
||||
|
||||
def test_train_script_uses_file_based_repo_root_on_sys_path(self):
|
||||
module = self._load_train_vla_module()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user