fix: support headless rollout on remote training hosts

This commit is contained in:
Logic
2026-04-02 08:15:54 +08:00
parent dffd92f82d
commit 0514f86c36
4 changed files with 68 additions and 2 deletions

View File

@@ -37,6 +37,17 @@ if not OmegaConf.has_resolver("len"):
OmegaConf.register_new_resolver("len", lambda x: len(x)) OmegaConf.register_new_resolver("len", lambda x: len(x))
def _configure_headless_mujoco_gl(eval_cfg: DictConfig) -> None:
    """Default MuJoCo to the EGL backend for headless runs with no X display.

    The environment is left untouched when the eval is not headless, when the
    user has already chosen a MuJoCo backend via MUJOCO_GL, or when a DISPLAY
    is available (a windowed GL context can be created normally).
    """
    headless = bool(eval_cfg.get('headless', False))
    backend_chosen = bool(os.environ.get('MUJOCO_GL'))
    display_present = bool(os.environ.get('DISPLAY'))
    if not headless or backend_chosen or display_present:
        return
    os.environ['MUJOCO_GL'] = 'egl'
    log.info('headless eval detected without DISPLAY; set MUJOCO_GL=egl')
def load_checkpoint( def load_checkpoint(
ckpt_path: str, ckpt_path: str,
agent_cfg: DictConfig, agent_cfg: DictConfig,
@@ -501,6 +512,7 @@ def _run_eval(cfg: DictConfig):
print("=" * 80) print("=" * 80)
eval_cfg = cfg.eval eval_cfg = cfg.eval
_configure_headless_mujoco_gl(eval_cfg)
device = eval_cfg.device device = eval_cfg.device
camera_names = list(eval_cfg.camera_names) camera_names = list(eval_cfg.camera_names)
artifact_paths = _resolve_artifact_paths(eval_cfg) artifact_paths = _resolve_artifact_paths(eval_cfg)

View File

@@ -76,6 +76,18 @@ if not OmegaConf.has_resolver("len"):
OmegaConf.register_new_resolver("len", lambda x: len(x)) OmegaConf.register_new_resolver("len", lambda x: len(x))
def _resolve_run_output_dir() -> Path:
try:
from hydra.core.hydra_config import HydraConfig
if HydraConfig.initialized():
output_dir = HydraConfig.get().runtime.output_dir
if output_dir:
return Path(output_dir).resolve()
except Exception:
pass
return Path.cwd().resolve()
_maybe_reexec_without_problematic_ld_preload() _maybe_reexec_without_problematic_ld_preload()
@@ -319,8 +331,9 @@ def _run_training(cfg: DictConfig):
swanlab_module = _init_swanlab(cfg) swanlab_module = _init_swanlab(cfg)
try: try:
# 创建检查点目录 # 创建检查点目录
checkpoint_dir = Path("checkpoints") run_output_dir = _resolve_run_output_dir()
checkpoint_dir.mkdir(exist_ok=True) checkpoint_dir = run_output_dir / "checkpoints"
checkpoint_dir.mkdir(parents=True, exist_ok=True)
default_best_model_path = checkpoint_dir / "vla_model_best.pt" default_best_model_path = checkpoint_dir / "vla_model_best.pt"
# ========================================================================= # =========================================================================

View File

@@ -90,6 +90,18 @@ class _FakeRenderer:
class EvalVLAHeadlessTest(unittest.TestCase): class EvalVLAHeadlessTest(unittest.TestCase):
def test_headless_eval_sets_mujoco_gl_to_egl_when_display_missing(self):
    """Headless eval in a clean environment should select the EGL backend."""
    headless_cfg = OmegaConf.create({"eval": {"headless": True}})
    # clear=True gives an empty environment: no MUJOCO_GL, no DISPLAY.
    with mock.patch.dict(eval_vla.os.environ, {}, clear=True):
        eval_vla._configure_headless_mujoco_gl(headless_cfg.eval)
        self.assertEqual(eval_vla.os.environ.get("MUJOCO_GL"), "egl")
def test_headless_eval_preserves_existing_mujoco_gl(self):
    """An explicit MUJOCO_GL choice must win over the headless EGL default."""
    headless_cfg = OmegaConf.create({"eval": {"headless": True}})
    preset_env = {"MUJOCO_GL": "osmesa"}
    with mock.patch.dict(eval_vla.os.environ, preset_env, clear=True):
        eval_vla._configure_headless_mujoco_gl(headless_cfg.eval)
        # The user-selected backend survives untouched.
        self.assertEqual(eval_vla.os.environ.get("MUJOCO_GL"), "osmesa")
def test_eval_config_exposes_headless_default(self): def test_eval_config_exposes_headless_default(self):
eval_cfg = OmegaConf.load(Path("roboimi/vla/conf/eval/eval.yaml")) eval_cfg = OmegaConf.load(Path("roboimi/vla/conf/eval/eval.yaml"))

View File

@@ -246,6 +246,35 @@ class TrainVLATransformerOptimizerTest(unittest.TestCase):
module.torch.backends.cudnn.enabled = original module.torch.backends.cudnn.enabled = original
def test_resolve_run_output_dir_prefers_hydra_runtime_output_dir(self):
    """With an initialized HydraConfig stub, the resolver must return
    Hydra's runtime output dir rather than falling back to cwd."""
    module = self._load_train_vla_module()
    # Fake module objects stand in for the real hydra packages so the
    # function's `from hydra.core.hydra_config import HydraConfig` resolves
    # to our stub via sys.modules.
    hydra_core_module = types.ModuleType('hydra.core')
    hydra_hydra_config_module = types.ModuleType('hydra.core.hydra_config')
    class _Runtime:
        # Mirrors HydraConfig.get().runtime.output_dir in a real Hydra run.
        output_dir = '/tmp/hydra-output'
    class _Cfg:
        runtime = _Runtime()
    class HydraConfigStub:
        @staticmethod
        def initialized():
            # Pretend a Hydra app is active so the resolver takes the
            # Hydra branch instead of the cwd fallback.
            return True
        @staticmethod
        def get():
            return _Cfg()
    hydra_hydra_config_module.HydraConfig = HydraConfigStub
    # NOTE(review): only 'hydra.core' and 'hydra.core.hydra_config' are
    # patched; the top-level 'hydra' package is presumably importable in the
    # test environment — confirm if this suite must run without hydra.
    with mock.patch.dict(sys.modules, {
        'hydra.core': hydra_core_module,
        'hydra.core.hydra_config': hydra_hydra_config_module,
    }):
        output_dir = module._resolve_run_output_dir()
    # Resolve both sides so symlinked /tmp paths still compare equal.
    self.assertEqual(Path(output_dir).resolve(), Path('/tmp/hydra-output').resolve())
def test_train_script_uses_file_based_repo_root_on_sys_path(self): def test_train_script_uses_file_based_repo_root_on_sys_path(self):
module = self._load_train_vla_module() module = self._load_train_vla_module()