fix: sanitize problematic LD_PRELOAD for cuDNN
This commit is contained in:
@@ -24,6 +24,47 @@ def _ensure_repo_root_on_syspath():
|
|||||||
return repo_root
|
return repo_root
|
||||||
|
|
||||||
|
|
||||||
|
_PROBLEMATIC_LD_PRELOAD_SUBSTRINGS = ('/usr/NX/lib/libnxegl.so', 'libnxegl.so')
|
||||||
|
|
||||||
|
|
||||||
|
def _clean_ld_preload_value(value: str | None):
|
||||||
|
if not value:
|
||||||
|
return value, False
|
||||||
|
entries = [entry for entry in value.split() if entry]
|
||||||
|
filtered = [
|
||||||
|
entry for entry in entries
|
||||||
|
if not any(marker in entry for marker in _PROBLEMATIC_LD_PRELOAD_SUBSTRINGS)
|
||||||
|
]
|
||||||
|
changed = filtered != entries
|
||||||
|
cleaned = ' '.join(filtered) if filtered else None
|
||||||
|
return cleaned, changed
|
||||||
|
|
||||||
|
|
||||||
|
def _maybe_reexec_without_problematic_ld_preload():
|
||||||
|
if __name__ != '__main__':
|
||||||
|
return False
|
||||||
|
if os.environ.get('_ROBOIMI_LD_PRELOAD_SANITIZED') == '1':
|
||||||
|
return False
|
||||||
|
|
||||||
|
cleaned, changed = _clean_ld_preload_value(os.environ.get('LD_PRELOAD'))
|
||||||
|
if not changed:
|
||||||
|
return False
|
||||||
|
|
||||||
|
new_env = dict(os.environ)
|
||||||
|
new_env['_ROBOIMI_LD_PRELOAD_SANITIZED'] = '1'
|
||||||
|
if cleaned:
|
||||||
|
new_env['LD_PRELOAD'] = cleaned
|
||||||
|
else:
|
||||||
|
new_env.pop('LD_PRELOAD', None)
|
||||||
|
|
||||||
|
print(
|
||||||
|
'Detected problematic LD_PRELOAD entry for CUDA/cuDNN; re-executing train_vla.py without it.',
|
||||||
|
file=sys.stderr,
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
|
os.execvpe(sys.executable, [sys.executable, *sys.argv], new_env)
|
||||||
|
|
||||||
|
|
||||||
_REPO_ROOT = _ensure_repo_root_on_syspath()
|
_REPO_ROOT = _ensure_repo_root_on_syspath()
|
||||||
|
|
||||||
from hydra.utils import instantiate
|
from hydra.utils import instantiate
|
||||||
@@ -35,6 +76,9 @@ if not OmegaConf.has_resolver("len"):
|
|||||||
OmegaConf.register_new_resolver("len", lambda x: len(x))
|
OmegaConf.register_new_resolver("len", lambda x: len(x))
|
||||||
|
|
||||||
|
|
||||||
|
_maybe_reexec_without_problematic_ld_preload()
|
||||||
|
|
||||||
|
|
||||||
def _configure_cuda_runtime(cfg):
|
def _configure_cuda_runtime(cfg):
|
||||||
"""Apply process-level CUDA runtime switches required by this environment."""
|
"""Apply process-level CUDA runtime switches required by this environment."""
|
||||||
if str(cfg.train.device).startswith('cuda') and bool(cfg.train.get('disable_cudnn', False)):
|
if str(cfg.train.device).startswith('cuda') and bool(cfg.train.get('disable_cudnn', False)):
|
||||||
|
|||||||
@@ -214,6 +214,25 @@ class TrainVLATransformerOptimizerTest(unittest.TestCase):
|
|||||||
for group in optimizer.param_groups
|
for group in optimizer.param_groups
|
||||||
]
|
]
|
||||||
|
|
||||||
|
def test_clean_ld_preload_value_removes_problematic_nxegl_entry(self):
|
||||||
|
module = self._load_train_vla_module()
|
||||||
|
|
||||||
|
cleaned, changed = module._clean_ld_preload_value(
|
||||||
|
'/usr/lib/libfoo.so /usr/NX/lib/libnxegl.so /usr/lib/libbar.so'
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertTrue(changed)
|
||||||
|
self.assertEqual(cleaned, '/usr/lib/libfoo.so /usr/lib/libbar.so')
|
||||||
|
|
||||||
|
def test_clean_ld_preload_value_leaves_safe_entries_unchanged(self):
|
||||||
|
module = self._load_train_vla_module()
|
||||||
|
|
||||||
|
cleaned, changed = module._clean_ld_preload_value('/usr/lib/libfoo.so /usr/lib/libbar.so')
|
||||||
|
|
||||||
|
self.assertFalse(changed)
|
||||||
|
self.assertEqual(cleaned, '/usr/lib/libfoo.so /usr/lib/libbar.so')
|
||||||
|
|
||||||
|
|
||||||
def test_configure_cuda_runtime_can_disable_cudnn_for_training(self):
|
def test_configure_cuda_runtime_can_disable_cudnn_for_training(self):
|
||||||
module = self._load_train_vla_module()
|
module = self._load_train_vla_module()
|
||||||
cfg = AttrDict(train=AttrDict(device='cuda', disable_cudnn=True))
|
cfg = AttrDict(train=AttrDict(device='cuda', disable_cudnn=True))
|
||||||
|
|||||||
Reference in New Issue
Block a user