fix: support headless rollout on remote training hosts
This commit is contained in:
@@ -246,6 +246,35 @@ class TrainVLATransformerOptimizerTest(unittest.TestCase):
|
||||
module.torch.backends.cudnn.enabled = original
|
||||
|
||||
|
||||
def test_resolve_run_output_dir_prefers_hydra_runtime_output_dir(self):
    """_resolve_run_output_dir returns Hydra's runtime output dir when Hydra is initialized.

    Builds a minimal stub of the ``hydra.core.hydra_config.HydraConfig``
    API and patches it into ``sys.modules`` so the module under test sees
    an "initialized" Hydra whose runtime output_dir is '/tmp/hydra-output'.
    """
    module = self._load_train_vla_module()

    # Stub the full hydra package hierarchy.  The parent 'hydra' module
    # must be stubbed too: a dotted import such as
    # `from hydra.core.hydra_config import HydraConfig` resolves parent
    # packages first, so without a 'hydra' entry in sys.modules the test
    # would either import a real installed Hydra (leaking global state)
    # or fail with ImportError when Hydra is not installed.
    hydra_module = types.ModuleType('hydra')
    hydra_core_module = types.ModuleType('hydra.core')
    hydra_hydra_config_module = types.ModuleType('hydra.core.hydra_config')

    class _Runtime:
        # Directory Hydra reports as the active run's output location.
        output_dir = '/tmp/hydra-output'

    class _Cfg:
        runtime = _Runtime()

    class HydraConfigStub:
        @staticmethod
        def initialized():
            # Pretend a Hydra run is active.
            return True

        @staticmethod
        def get():
            return _Cfg()

    hydra_hydra_config_module.HydraConfig = HydraConfigStub
    # Link parents to children so attribute-style access
    # (hydra.core.hydra_config) also resolves through the stubs.
    hydra_core_module.hydra_config = hydra_hydra_config_module
    hydra_module.core = hydra_core_module

    with mock.patch.dict(sys.modules, {
        'hydra': hydra_module,
        'hydra.core': hydra_core_module,
        'hydra.core.hydra_config': hydra_hydra_config_module,
    }):
        output_dir = module._resolve_run_output_dir()

    # Compare resolved paths so symlinks / relative forms don't cause
    # spurious mismatches.
    self.assertEqual(Path(output_dir).resolve(), Path('/tmp/hydra-output').resolve())
||||
def test_train_script_uses_file_based_repo_root_on_sys_path(self):
|
||||
module = self._load_train_vla_module()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user