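"""Tests for optimizer construction and runtime-setup helpers in roboimi/demos/vla_scripts/train_vla.py."""
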
import importlib.util
import os
import sys
import tempfile
import types
import unittest
from pathlib import Path
from unittest import mock

import torch
from torch import nn


_REPO_ROOT = Path(__file__).resolve().parents[1]
_TRAIN_VLA_PATH = _REPO_ROOT / 'roboimi/demos/vla_scripts/train_vla.py'


class AttrDict(dict):
    """Dict with attribute access, used as a lightweight stand-in for config objects in these tests."""

    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError as exc:
            raise AttributeError(name) from exc

    def __setattr__(self, name, value):
        self[name] = value


class FakeDataset:
    def __len__(self):
        return 4


class FakeLoader:
    def __len__(self):
        return 1

    def __iter__(self):
        return iter(())


class FakeScheduler:
    def state_dict(self):
        return {}

    def load_state_dict(self, state_dict):
        return None


class RecordingAdamW:
    """Stand-in for AdamW that records every constructed instance and normalizes its param groups."""

    created = []

    def __init__(self, params, lr, weight_decay):
        self.lr = lr
        self.weight_decay = weight_decay
        self.param_groups = self._normalize_param_groups(params, lr, weight_decay)
        RecordingAdamW.created.append(self)

    @staticmethod
    def _normalize_param_groups(params, lr, weight_decay):
        if isinstance(params, (list, tuple)) and params and isinstance(params[0], dict):
            groups = []
            for group in params:
                normalized = dict(group)
                normalized['params'] = list(group['params'])
                normalized.setdefault('lr', lr)
                groups.append(normalized)
            return groups

        return [{
            'params': list(params),
            'lr': lr,
            'weight_decay': weight_decay,
        }]

    def state_dict(self):
        return {}

    def load_state_dict(self, state_dict):
        return None


class RecordingTransformerHead(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)
        self.norm = nn.LayerNorm(4)
        self.optim_group_calls = []

    def get_optim_groups(self, weight_decay):
        self.optim_group_calls.append(weight_decay)
        return [
            {
                'params': [self.proj.weight],
                'weight_decay': weight_decay,
            },
            {
                'params': [self.proj.bias, self.norm.weight, self.norm.bias],
                'weight_decay': 0.0,
            },
        ]


class FakeIMFAgent(nn.Module):
    def __init__(self):
        super().__init__()
        self.head_type = 'imf_transformer'
        self.noise_pred_net = RecordingTransformerHead()
        self.backbone = nn.Linear(4, 3)
        self.adapter = nn.Linear(3, 2, bias=False)


class FakeTransformerAgent(nn.Module):
    def __init__(self, *, head_type='transformer'):
        super().__init__()
        self.head_type = head_type
        self.noise_pred_net = RecordingTransformerHead()
        self.backbone = nn.Linear(4, 3)
        self.adapter = nn.Linear(3, 2, bias=False)
        self.frozen = nn.Linear(2, 2)
        for param in self.frozen.parameters():
            param.requires_grad = False

    def to(self, device):
        return self

    def get_normalization_stats(self):
        return {}


class TrainVLATransformerOptimizerTest(unittest.TestCase):
    def setUp(self):
        RecordingAdamW.created = []

    def _load_train_vla_module(self):
        """Import train_vla.py from its file path with stub hydra/omegaconf modules in sys.modules."""
        hydra_module = types.ModuleType('hydra')
        hydra_utils_module = types.ModuleType('hydra.utils')
        hydra_utils_module.instantiate = lambda *args, **kwargs: None

        def hydra_main(**_kwargs):
            def decorator(func):
                return func
            return decorator

        hydra_module.main = hydra_main
        hydra_module.utils = hydra_utils_module

        class OmegaConfStub:
            _resolvers = {}

            @classmethod
            def has_resolver(cls, name):
                return name in cls._resolvers

            @classmethod
            def register_new_resolver(cls, name, resolver):
                cls._resolvers[name] = resolver

            @staticmethod
            def to_yaml(_cfg):
                return 'stub-config'

        omegaconf_module = types.ModuleType('omegaconf')
        omegaconf_module.DictConfig = dict
        omegaconf_module.OmegaConf = OmegaConfStub

        module_name = 'train_vla_optimizer_test_module'
        spec = importlib.util.spec_from_file_location(module_name, _TRAIN_VLA_PATH)
        module = importlib.util.module_from_spec(spec)
        with mock.patch.dict(
            sys.modules,
            {
                'hydra': hydra_module,
                'hydra.utils': hydra_utils_module,
                'omegaconf': omegaconf_module,
            },
        ):
            assert spec.loader is not None
            spec.loader.exec_module(module)
        return module

    def _make_cfg(self):
        return AttrDict(
            train=AttrDict(
                device='cpu',
                batch_size=2,
                num_workers=0,
                val_split=0,
                seed=0,
                lr=1e-4,
                max_steps=0,
                log_freq=1,
                save_freq=100,
                warmup_steps=1,
                scheduler_type='constant',
                min_lr=0.0,
                grad_clip=1.0,
                weight_decay=0.123,
                pretrained_ckpt=None,
                resume_ckpt=None,
            ),
            data=AttrDict(
                camera_names=('front',),
            ),
            agent=AttrDict(
                _target_='fake.agent',
            ),
        )

    def _group_names(self, agent, optimizer):
        names_by_param_id = {id(param): name for name, param in agent.named_parameters()}
        return [
            {names_by_param_id[id(param)] for param in group['params']}
            for group in optimizer.param_groups
        ]

    def test_clean_ld_preload_value_removes_problematic_nxegl_entry(self):
        module = self._load_train_vla_module()

        cleaned, changed = module._clean_ld_preload_value(
            '/usr/lib/libfoo.so /usr/NX/lib/libnxegl.so /usr/lib/libbar.so'
        )

        self.assertTrue(changed)
        self.assertEqual(cleaned, '/usr/lib/libfoo.so /usr/lib/libbar.so')

    def test_clean_ld_preload_value_leaves_safe_entries_unchanged(self):
        module = self._load_train_vla_module()

        cleaned, changed = module._clean_ld_preload_value('/usr/lib/libfoo.so /usr/lib/libbar.so')

        self.assertFalse(changed)
        self.assertEqual(cleaned, '/usr/lib/libfoo.so /usr/lib/libbar.so')

    def test_configure_cuda_runtime_can_disable_cudnn_for_training(self):
        module = self._load_train_vla_module()
        cfg = AttrDict(train=AttrDict(device='cuda', disable_cudnn=True))

        original = module.torch.backends.cudnn.enabled
        try:
            module.torch.backends.cudnn.enabled = True
            module._configure_cuda_runtime(cfg)
            self.assertFalse(module.torch.backends.cudnn.enabled)
        finally:
            module.torch.backends.cudnn.enabled = original

    def test_train_script_uses_file_based_repo_root_on_sys_path(self):
        module = self._load_train_vla_module()

        fake_sys_path = ['/tmp/site-packages', '/another/path']
        with mock.patch.object(module.sys, 'path', fake_sys_path):
            repo_root = module._ensure_repo_root_on_syspath()

        self.assertEqual(Path(repo_root).resolve(), _REPO_ROOT.resolve())
        self.assertEqual(Path(fake_sys_path[0]).resolve(), _REPO_ROOT.resolve())

    def test_non_transformer_head_with_get_optim_groups_still_uses_custom_groups(self):
        module = self._load_train_vla_module()
        agent = FakeIMFAgent()

        optimizer = module.build_training_optimizer(agent, lr=1e-4, weight_decay=0.123)

        self.assertEqual(agent.noise_pred_net.optim_group_calls, [0.123])
        group_names = self._group_names(agent, optimizer)
        self.assertEqual(group_names[0], {'noise_pred_net.proj.weight'})
        self.assertEqual(group_names[1], {
            'noise_pred_net.proj.bias',
            'noise_pred_net.norm.weight',
            'noise_pred_net.norm.bias',
        })
        self.assertEqual(group_names[2], {'backbone.weight', 'backbone.bias', 'adapter.weight'})

    def test_transformer_training_prefers_head_optim_groups_and_keeps_remaining_trainable_params(self):
        module = self._load_train_vla_module()
        agent = FakeTransformerAgent()
        cfg = self._make_cfg()

        def fake_instantiate(config_node, **_kwargs):
            if config_node is cfg.data:
                return FakeDataset()
            if config_node is cfg.agent:
                return agent
            raise AssertionError(f'unexpected instantiate config: {config_node!r}')

        with tempfile.TemporaryDirectory() as tempdir:
            previous_cwd = os.getcwd()
            try:
                os.chdir(tempdir)
                with mock.patch.object(module, 'instantiate', side_effect=fake_instantiate), \
                        mock.patch.object(module, 'DataLoader', side_effect=lambda *args, **kwargs: FakeLoader()), \
                        mock.patch.object(module, 'get_lr_schedule_with_warmup', return_value=FakeScheduler()), \
                        mock.patch.object(module, 'AdamW', RecordingAdamW), \
                        mock.patch.object(module.torch, 'save', return_value=None), \
                        mock.patch.object(module, 'tqdm', side_effect=lambda iterable, **kwargs: iterable):
                    module.main(cfg)
            finally:
                os.chdir(previous_cwd)

        self.assertEqual(agent.noise_pred_net.optim_group_calls, [cfg.train.weight_decay])

        optimizer = RecordingAdamW.created[-1]
        trainable_names = {
            name for name, param in agent.named_parameters() if param.requires_grad
        }
        grouped_names = self._group_names(agent, optimizer)
        optimizer_names = set().union(*grouped_names)
        expected_head_names = {
            'noise_pred_net.proj.weight',
            'noise_pred_net.proj.bias',
            'noise_pred_net.norm.weight',
            'noise_pred_net.norm.bias',
        }
        expected_non_head_names = {
            'backbone.weight',
            'backbone.bias',
            'adapter.weight',
        }

        self.assertEqual(grouped_names[0], {'noise_pred_net.proj.weight'})
        self.assertEqual(grouped_names[1], expected_head_names - {'noise_pred_net.proj.weight'})
        self.assertEqual(grouped_names[2], expected_non_head_names)
        self.assertEqual(optimizer.param_groups[0]['weight_decay'], cfg.train.weight_decay)
        self.assertEqual(optimizer.param_groups[1]['weight_decay'], 0.0)
        self.assertEqual(optimizer.param_groups[2]['weight_decay'], cfg.train.weight_decay)
        self.assertEqual(optimizer_names, trainable_names)

        flattened_param_ids = [
            id(param)
            for group in optimizer.param_groups
            for param in group['params']
        ]
        self.assertEqual(len(flattened_param_ids), len(set(flattened_param_ids)))
        self.assertNotIn('frozen.weight', optimizer_names)
        self.assertNotIn('frozen.bias', optimizer_names)

    def test_any_head_with_get_optim_groups_uses_custom_groups_even_without_transformer_head_type(self):
        module = self._load_train_vla_module()
        agent = FakeTransformerAgent(head_type='imf')

        with mock.patch.object(module, 'AdamW', RecordingAdamW):
            optimizer = module.build_training_optimizer(agent, lr=1e-4, weight_decay=0.123)

        self.assertEqual(agent.noise_pred_net.optim_group_calls, [0.123])
        grouped_names = self._group_names(agent, optimizer)
        self.assertEqual(grouped_names[0], {'noise_pred_net.proj.weight'})
        self.assertEqual(
            grouped_names[1],
            {'noise_pred_net.proj.bias', 'noise_pred_net.norm.weight', 'noise_pred_net.norm.bias'},
        )
        self.assertEqual(grouped_names[2], {'backbone.weight', 'backbone.bias', 'adapter.weight'})

    def test_transformer_optimizer_ignores_frozen_head_params_returned_by_head_groups(self):
        module = self._load_train_vla_module()
        agent = FakeTransformerAgent()
        agent.noise_pred_net.norm.bias.requires_grad = False
        cfg = self._make_cfg()

        def fake_instantiate(config_node, **_kwargs):
            if config_node is cfg.data:
                return FakeDataset()
            if config_node is cfg.agent:
                return agent
            raise AssertionError(f'unexpected instantiate config: {config_node!r}')

        with tempfile.TemporaryDirectory() as tempdir:
            previous_cwd = os.getcwd()
            try:
                os.chdir(tempdir)
                with mock.patch.object(module, 'instantiate', side_effect=fake_instantiate), \
                        mock.patch.object(module, 'DataLoader', side_effect=lambda *args, **kwargs: FakeLoader()), \
                        mock.patch.object(module, 'get_lr_schedule_with_warmup', return_value=FakeScheduler()), \
                        mock.patch.object(module, 'AdamW', RecordingAdamW), \
                        mock.patch.object(module.torch, 'save', return_value=None), \
                        mock.patch.object(module, 'tqdm', side_effect=lambda iterable, **kwargs: iterable):
                    module.main(cfg)
            finally:
                os.chdir(previous_cwd)

        optimizer = RecordingAdamW.created[-1]
        optimizer_names = set().union(*self._group_names(agent, optimizer))
        trainable_names = {
            name for name, param in agent.named_parameters() if param.requires_grad
        }

        self.assertEqual(agent.noise_pred_net.optim_group_calls, [cfg.train.weight_decay])
        self.assertEqual(optimizer_names, trainable_names)
        self.assertNotIn('noise_pred_net.norm.bias', optimizer_names)


if __name__ == '__main__':
    unittest.main()