"""Tests for optimizer construction in roboimi/demos/vla_scripts/train_vla.py."""
import importlib.util
import os
import sys
import tempfile
import types
import unittest
from pathlib import Path
from unittest import mock

from torch import nn

_REPO_ROOT = Path(__file__).resolve().parents[1]
_TRAIN_VLA_PATH = _REPO_ROOT / 'roboimi/demos/vla_scripts/train_vla.py'


class AttrDict(dict):
    """Dict with attribute access, standing in for an OmegaConf DictConfig."""

    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError as exc:
            raise AttributeError(name) from exc

    def __setattr__(self, name, value):
        self[name] = value


class FakeDataset:
    def __len__(self):
        return 4


class FakeLoader:
    def __len__(self):
        return 1

    def __iter__(self):
        return iter(())


class FakeScheduler:
    def state_dict(self):
        return {}

    def load_state_dict(self, state_dict):
        return None


class RecordingAdamW:
    """AdamW stand-in that records every instantiation for later inspection."""

    created = []

    def __init__(self, params, lr, weight_decay):
        self.lr = lr
        self.weight_decay = weight_decay
        self.param_groups = self._normalize_param_groups(params, lr, weight_decay)
        RecordingAdamW.created.append(self)

    @staticmethod
    def _normalize_param_groups(params, lr, weight_decay):
        # Mirror torch.optim behaviour: accept either an iterable of parameters
        # or a list of explicit param-group dicts.
        if isinstance(params, (list, tuple)) and params and isinstance(params[0], dict):
            groups = []
            for group in params:
                normalized = dict(group)
                normalized['params'] = list(group['params'])
                normalized.setdefault('lr', lr)
                groups.append(normalized)
            return groups
        return [{
            'params': list(params),
            'lr': lr,
            'weight_decay': weight_decay,
        }]

    def state_dict(self):
        return {}

    def load_state_dict(self, state_dict):
        return None


class RecordingTransformerHead(nn.Module):
    """Head module exposing get_optim_groups() and recording its calls."""

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)
        self.norm = nn.LayerNorm(4)
        self.optim_group_calls = []

    def get_optim_groups(self, weight_decay):
        self.optim_group_calls.append(weight_decay)
        return [
            {
                'params': [self.proj.weight],
                'weight_decay': weight_decay,
            },
            {
                'params': [self.proj.bias, self.norm.weight, self.norm.bias],
                'weight_decay': 0.0,
            },
        ]


class FakeIMFAgent(nn.Module):
    def __init__(self):
        super().__init__()
        self.head_type = 'imf_transformer'
        self.noise_pred_net = RecordingTransformerHead()
        self.backbone = nn.Linear(4, 3)
        self.adapter = nn.Linear(3, 2, bias=False)


class FakeTransformerAgent(nn.Module):
    def __init__(self, *, head_type='transformer'):
        super().__init__()
        self.head_type = head_type
        self.noise_pred_net = RecordingTransformerHead()
        self.backbone = nn.Linear(4, 3)
        self.adapter = nn.Linear(3, 2, bias=False)
        self.frozen = nn.Linear(2, 2)
        for param in self.frozen.parameters():
            param.requires_grad = False

    def to(self, device):
        return self

    def get_normalization_stats(self):
        return {}


class TrainVLATransformerOptimizerTest(unittest.TestCase):
    def setUp(self):
        RecordingAdamW.created = []

    def _load_train_vla_module(self):
        # Stub out hydra and omegaconf so the training script can be imported
        # without those dependencies installed.
        hydra_module = types.ModuleType('hydra')
        hydra_utils_module = types.ModuleType('hydra.utils')
        hydra_utils_module.instantiate = lambda *args, **kwargs: None

        def hydra_main(**_kwargs):
            def decorator(func):
                return func

            return decorator

        hydra_module.main = hydra_main
        hydra_module.utils = hydra_utils_module

        class OmegaConfStub:
            _resolvers = {}

            @classmethod
            def has_resolver(cls, name):
                return name in cls._resolvers

            @classmethod
            def register_new_resolver(cls, name, resolver):
                cls._resolvers[name] = resolver

            @staticmethod
            def to_yaml(_cfg):
                return 'stub-config'

        omegaconf_module = types.ModuleType('omegaconf')
        omegaconf_module.DictConfig = dict
        omegaconf_module.OmegaConf = OmegaConfStub

        module_name = 'train_vla_optimizer_test_module'
        spec = importlib.util.spec_from_file_location(module_name, _TRAIN_VLA_PATH)
        module = importlib.util.module_from_spec(spec)
        with mock.patch.dict(
            sys.modules,
            {
                'hydra': hydra_module,
                'hydra.utils': hydra_utils_module,
                'omegaconf': omegaconf_module,
            },
        ):
            assert spec.loader is not None
            spec.loader.exec_module(module)
        return module

    def _make_cfg(self):
        return AttrDict(
            train=AttrDict(
                device='cpu',
                batch_size=2,
                num_workers=0,
                val_split=0,
                seed=0,
                lr=1e-4,
                max_steps=0,
                log_freq=1,
                save_freq=100,
                warmup_steps=1,
                scheduler_type='constant',
                min_lr=0.0,
                grad_clip=1.0,
                weight_decay=0.123,
                pretrained_ckpt=None,
                resume_ckpt=None,
            ),
            data=AttrDict(
                camera_names=('front',),
            ),
            agent=AttrDict(
                _target_='fake.agent',
            ),
        )

    def _group_names(self, agent, optimizer):
        names_by_param_id = {id(param): name for name, param in agent.named_parameters()}
        return [
            {names_by_param_id[id(param)] for param in group['params']}
            for group in optimizer.param_groups
        ]

    def test_clean_ld_preload_value_removes_problematic_nxegl_entry(self):
        module = self._load_train_vla_module()
        cleaned, changed = module._clean_ld_preload_value(
            '/usr/lib/libfoo.so /usr/NX/lib/libnxegl.so /usr/lib/libbar.so'
        )
        self.assertTrue(changed)
        self.assertEqual(cleaned, '/usr/lib/libfoo.so /usr/lib/libbar.so')

    def test_clean_ld_preload_value_leaves_safe_entries_unchanged(self):
        module = self._load_train_vla_module()
        cleaned, changed = module._clean_ld_preload_value('/usr/lib/libfoo.so /usr/lib/libbar.so')
        self.assertFalse(changed)
        self.assertEqual(cleaned, '/usr/lib/libfoo.so /usr/lib/libbar.so')

    def test_configure_cuda_runtime_can_disable_cudnn_for_training(self):
        module = self._load_train_vla_module()
        cfg = AttrDict(train=AttrDict(device='cuda', disable_cudnn=True))
        original = module.torch.backends.cudnn.enabled
        try:
            module.torch.backends.cudnn.enabled = True
            module._configure_cuda_runtime(cfg)
            self.assertFalse(module.torch.backends.cudnn.enabled)
        finally:
            module.torch.backends.cudnn.enabled = original

    def test_train_script_uses_file_based_repo_root_on_sys_path(self):
        module = self._load_train_vla_module()
        fake_sys_path = ['/tmp/site-packages', '/another/path']
        with mock.patch.object(module.sys, 'path', fake_sys_path):
            repo_root = module._ensure_repo_root_on_syspath()
        self.assertEqual(Path(repo_root).resolve(), _REPO_ROOT.resolve())
        self.assertEqual(Path(fake_sys_path[0]).resolve(), _REPO_ROOT.resolve())

    def test_non_transformer_head_with_get_optim_groups_still_uses_custom_groups(self):
        module = self._load_train_vla_module()
        agent = FakeIMFAgent()
        optimizer = module.build_training_optimizer(agent, lr=1e-4, weight_decay=0.123)
        self.assertEqual(agent.noise_pred_net.optim_group_calls, [0.123])
        group_names = self._group_names(agent, optimizer)
        self.assertEqual(group_names[0], {'noise_pred_net.proj.weight'})
        self.assertEqual(group_names[1], {
            'noise_pred_net.proj.bias',
            'noise_pred_net.norm.weight',
            'noise_pred_net.norm.bias',
        })
        self.assertEqual(group_names[2], {'backbone.weight', 'backbone.bias', 'adapter.weight'})

    def test_transformer_training_prefers_head_optim_groups_and_keeps_remaining_trainable_params(self):
        module = self._load_train_vla_module()
        agent = FakeTransformerAgent()
        cfg = self._make_cfg()

        def fake_instantiate(config_node, **_kwargs):
            if config_node is cfg.data:
                return FakeDataset()
            if config_node is cfg.agent:
                return agent
            raise AssertionError(f'unexpected instantiate config: {config_node!r}')

        with tempfile.TemporaryDirectory() as tempdir:
            previous_cwd = os.getcwd()
            try:
                os.chdir(tempdir)
                with mock.patch.object(module, 'instantiate', side_effect=fake_instantiate), \
                        mock.patch.object(module, 'DataLoader', side_effect=lambda *args, **kwargs: FakeLoader()), \
                        mock.patch.object(module, 'get_lr_schedule_with_warmup', return_value=FakeScheduler()), \
                        mock.patch.object(module, 'AdamW', RecordingAdamW), \
                        mock.patch.object(module.torch, 'save', return_value=None), \
                        mock.patch.object(module, 'tqdm', side_effect=lambda iterable, **kwargs: iterable):
                    module.main(cfg)
            finally:
                os.chdir(previous_cwd)

        self.assertEqual(agent.noise_pred_net.optim_group_calls, [cfg.train.weight_decay])
        optimizer = RecordingAdamW.created[-1]
        trainable_names = {
            name for name, param in agent.named_parameters() if param.requires_grad
        }
        grouped_names = self._group_names(agent, optimizer)
        optimizer_names = set().union(*grouped_names)
        expected_head_names = {
            'noise_pred_net.proj.weight',
            'noise_pred_net.proj.bias',
            'noise_pred_net.norm.weight',
            'noise_pred_net.norm.bias',
        }
        expected_non_head_names = {
            'backbone.weight',
            'backbone.bias',
            'adapter.weight',
        }
        self.assertEqual(grouped_names[0], {'noise_pred_net.proj.weight'})
        self.assertEqual(grouped_names[1], expected_head_names - {'noise_pred_net.proj.weight'})
        self.assertEqual(grouped_names[2], expected_non_head_names)
        self.assertEqual(optimizer.param_groups[0]['weight_decay'], cfg.train.weight_decay)
        self.assertEqual(optimizer.param_groups[1]['weight_decay'], 0.0)
        self.assertEqual(optimizer.param_groups[2]['weight_decay'], cfg.train.weight_decay)
        self.assertEqual(optimizer_names, trainable_names)
        # No parameter should appear in more than one optimizer group.
        flattened_param_ids = [
            id(param) for group in optimizer.param_groups for param in group['params']
        ]
        self.assertEqual(len(flattened_param_ids), len(set(flattened_param_ids)))
        self.assertNotIn('frozen.weight', optimizer_names)
        self.assertNotIn('frozen.bias', optimizer_names)

    def test_any_head_with_get_optim_groups_uses_custom_groups_even_without_transformer_head_type(self):
        module = self._load_train_vla_module()
        agent = FakeTransformerAgent(head_type='imf')
        with mock.patch.object(module, 'AdamW', RecordingAdamW):
            optimizer = module.build_training_optimizer(agent, lr=1e-4, weight_decay=0.123)
        self.assertEqual(agent.noise_pred_net.optim_group_calls, [0.123])
        grouped_names = self._group_names(agent, optimizer)
        self.assertEqual(grouped_names[0], {'noise_pred_net.proj.weight'})
        self.assertEqual(
            grouped_names[1],
            {'noise_pred_net.proj.bias', 'noise_pred_net.norm.weight', 'noise_pred_net.norm.bias'},
        )
        self.assertEqual(grouped_names[2], {'backbone.weight', 'backbone.bias', 'adapter.weight'})

    def test_transformer_optimizer_ignores_frozen_head_params_returned_by_head_groups(self):
        module = self._load_train_vla_module()
        agent = FakeTransformerAgent()
        agent.noise_pred_net.norm.bias.requires_grad = False
        cfg = self._make_cfg()

        def fake_instantiate(config_node, **_kwargs):
            if config_node is cfg.data:
                return FakeDataset()
            if config_node is cfg.agent:
                return agent
            raise AssertionError(f'unexpected instantiate config: {config_node!r}')

        with tempfile.TemporaryDirectory() as tempdir:
            previous_cwd = os.getcwd()
            try:
                os.chdir(tempdir)
                with mock.patch.object(module, 'instantiate', side_effect=fake_instantiate), \
                        mock.patch.object(module, 'DataLoader', side_effect=lambda *args, **kwargs: FakeLoader()), \
                        mock.patch.object(module, 'get_lr_schedule_with_warmup', return_value=FakeScheduler()), \
                        mock.patch.object(module, 'AdamW', RecordingAdamW), \
                        mock.patch.object(module.torch, 'save', return_value=None), \
                        mock.patch.object(module, 'tqdm', side_effect=lambda iterable, **kwargs: iterable):
                    module.main(cfg)
            finally:
                os.chdir(previous_cwd)

        optimizer = RecordingAdamW.created[-1]
        optimizer_names = set().union(*self._group_names(agent, optimizer))
        trainable_names = {
            name
            for name, param in agent.named_parameters()
            if param.requires_grad
        }
        self.assertEqual(agent.noise_pred_net.optim_group_calls, [cfg.train.weight_decay])
        self.assertEqual(optimizer_names, trainable_names)
        self.assertNotIn('noise_pred_net.norm.bias', optimizer_names)


if __name__ == '__main__':
    unittest.main()