# roboimi/tests/test_train_vla_transformer_optimizer.py
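"""Tests for the optimizer that train_vla.main() builds for transformer heads.

The train script is loaded with hydra/omegaconf stubbed out and its
collaborators (dataset, loader, scheduler, AdamW, tqdm) replaced with
recording fakes. The resulting AdamW param groups are then checked:
head-provided groups come first, one group holds the remaining trainable
parameters, and frozen parameters never appear.
"""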

import importlib.util
import os
import sys
import tempfile
import types
import unittest
from pathlib import Path
from unittest import mock

import torch
from torch import nn

_REPO_ROOT = Path(__file__).resolve().parents[1]
_TRAIN_VLA_PATH = _REPO_ROOT / 'roboimi/demos/vla_scripts/train_vla.py'


class AttrDict(dict):
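    """dict exposing keys as attributes, mimicking the Hydra/omegaconf config passed to main()."""
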
    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError as exc:
            raise AttributeError(name) from exc

    def __setattr__(self, name, value):
        self[name] = value


class FakeDataset:
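    """Dataset stub; a small fixed length is all main() needs to size its loaders."""
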
    def __len__(self):
        return 4


class FakeLoader:
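    """DataLoader stub that yields nothing, so the training loop body never runs."""
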
    def __len__(self):
        return 1

    def __iter__(self):
        return iter(())


class FakeScheduler:
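    """LR-scheduler stub exposing the state_dict interface checkpointing expects."""
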
    def state_dict(self):
        return {}

    def load_state_dict(self, state_dict):
        return None


class RecordingAdamW:
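    """AdamW test double that records every constructed instance.

    Param groups are normalized roughly the way torch.optim.AdamW does it:
    a plain iterable of tensors becomes a single group, and per-group dicts
    are copied with missing defaults filled in.
    """
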
    # Every constructed optimizer; reset in TrainVLATransformerOptimizerTest.setUp().
    created = []

    def __init__(self, params, lr, weight_decay):
        self.lr = lr
        self.weight_decay = weight_decay
        self.param_groups = self._normalize_param_groups(params, lr, weight_decay)
        RecordingAdamW.created.append(self)

    @staticmethod
    def _normalize_param_groups(params, lr, weight_decay):
        if isinstance(params, (list, tuple)) and params and isinstance(params[0], dict):
            groups = []
            for group in params:
                normalized = dict(group)
                normalized['params'] = list(group['params'])
                # Mirror torch.optim.AdamW, which fills per-group defaults.
                normalized.setdefault('lr', lr)
                normalized.setdefault('weight_decay', weight_decay)
                groups.append(normalized)
            return groups
        return [{
            'params': list(params),
            'lr': lr,
            'weight_decay': weight_decay,
        }]

    def state_dict(self):
        return {}

    def load_state_dict(self, state_dict):
        return None


class RecordingTransformerHead(nn.Module):
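    """Head stub whose get_optim_groups() mimics a transformer policy head.

    It records every weight_decay it is asked for and returns the usual
    decay / no-decay split: Linear weights decay; biases and LayerNorm do not.
    """
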
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)
        self.norm = nn.LayerNorm(4)
        self.optim_group_calls = []

    def get_optim_groups(self, weight_decay):
        self.optim_group_calls.append(weight_decay)
        return [
            {
                'params': [self.proj.weight],
                'weight_decay': weight_decay,
            },
            {
                'params': [self.proj.bias, self.norm.weight, self.norm.bias],
                'weight_decay': 0.0,
            },
        ]


class FakeTransformerAgent(nn.Module):
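    """Agent stub with a transformer head plus extra trainable and frozen layers."""
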
    def __init__(self):
        super().__init__()
        self.head_type = 'transformer'
        self.noise_pred_net = RecordingTransformerHead()
        self.backbone = nn.Linear(4, 3)
        self.adapter = nn.Linear(3, 2, bias=False)
        self.frozen = nn.Linear(2, 2)
        # Frozen parameters must never reach the optimizer.
        for param in self.frozen.parameters():
            param.requires_grad = False

    def to(self, device):
        return self

    def get_normalization_stats(self):
        return {}


class TrainVLATransformerOptimizerTest(unittest.TestCase):
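    """Checks the AdamW param groups that train_vla.main() builds for a transformer head."""
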
    def setUp(self):
        RecordingAdamW.created = []

    def _load_train_vla_module(self):
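        """Load train_vla.py with hydra and omegaconf stubbed out in sys.modules."""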
        hydra_module = types.ModuleType('hydra')
        hydra_utils_module = types.ModuleType('hydra.utils')
        hydra_utils_module.instantiate = lambda *args, **kwargs: None

        def hydra_main(**_kwargs):
            def decorator(func):
                return func
            return decorator

        hydra_module.main = hydra_main
        hydra_module.utils = hydra_utils_module

        class OmegaConfStub:
            _resolvers = {}

            @classmethod
            def has_resolver(cls, name):
                return name in cls._resolvers

            @classmethod
            def register_new_resolver(cls, name, resolver):
                cls._resolvers[name] = resolver

            @staticmethod
            def to_yaml(_cfg):
                return 'stub-config'

        omegaconf_module = types.ModuleType('omegaconf')
        omegaconf_module.DictConfig = dict
        omegaconf_module.OmegaConf = OmegaConfStub

        module_name = 'train_vla_optimizer_test_module'
        spec = importlib.util.spec_from_file_location(module_name, _TRAIN_VLA_PATH)
        assert spec is not None and spec.loader is not None
        module = importlib.util.module_from_spec(spec)
        with mock.patch.dict(
            sys.modules,
            {
                'hydra': hydra_module,
                'hydra.utils': hydra_utils_module,
                'omegaconf': omegaconf_module,
            },
        ):
            spec.loader.exec_module(module)
        return module

    def _make_cfg(self):
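        """Minimal config; max_steps=0 so main() should stop before taking an optimizer step."""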
        return AttrDict(
            train=AttrDict(
                device='cpu',
                batch_size=2,
                num_workers=0,
                val_split=0,
                seed=0,
                lr=1e-4,
                max_steps=0,
                log_freq=1,
                save_freq=100,
                warmup_steps=1,
                scheduler_type='constant',
                min_lr=0.0,
                grad_clip=1.0,
                weight_decay=0.123,
                pretrained_ckpt=None,
                resume_ckpt=None,
            ),
            data=AttrDict(
                camera_names=('front',),
            ),
            agent=AttrDict(
                _target_='fake.agent',
            ),
        )

    def _group_names(self, agent, optimizer):
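        """Return, per optimizer param group, the set of parameter names it holds."""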
        names_by_param_id = {id(param): name for name, param in agent.named_parameters()}
        return [
            {names_by_param_id[id(param)] for param in group['params']}
            for group in optimizer.param_groups
        ]

    def test_transformer_training_prefers_head_optim_groups_and_keeps_remaining_trainable_params(self):
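        """The head's optim groups are used as-is, plus one group holding the
        remaining trainable (non-head, non-frozen) parameters."""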
        module = self._load_train_vla_module()
        agent = FakeTransformerAgent()
        cfg = self._make_cfg()

        def fake_instantiate(config_node, **_kwargs):
            if config_node is cfg.data:
                return FakeDataset()
            if config_node is cfg.agent:
                return agent
            raise AssertionError(f'unexpected instantiate config: {config_node!r}')

        with tempfile.TemporaryDirectory() as tempdir:
            previous_cwd = os.getcwd()
            try:
                os.chdir(tempdir)
                with mock.patch.object(module, 'instantiate', side_effect=fake_instantiate), \
                        mock.patch.object(module, 'DataLoader', side_effect=lambda *args, **kwargs: FakeLoader()), \
                        mock.patch.object(module, 'get_lr_schedule_with_warmup', return_value=FakeScheduler()), \
                        mock.patch.object(module, 'AdamW', RecordingAdamW), \
                        mock.patch.object(module.torch, 'save', return_value=None), \
                        mock.patch.object(module, 'tqdm', side_effect=lambda iterable, **kwargs: iterable):
                    module.main(cfg)
            finally:
                os.chdir(previous_cwd)

        self.assertEqual(agent.noise_pred_net.optim_group_calls, [cfg.train.weight_decay])

        optimizer = RecordingAdamW.created[-1]
        trainable_names = {
            name for name, param in agent.named_parameters() if param.requires_grad
        }
        grouped_names = self._group_names(agent, optimizer)
        optimizer_names = set().union(*grouped_names)
        expected_head_names = {
            'noise_pred_net.proj.weight',
            'noise_pred_net.proj.bias',
            'noise_pred_net.norm.weight',
            'noise_pred_net.norm.bias',
        }
        expected_non_head_names = {
            'backbone.weight',
            'backbone.bias',
            'adapter.weight',
        }
        # Groups 0 and 1 come straight from the head; group 2 holds the rest.
        self.assertEqual(grouped_names[0], {'noise_pred_net.proj.weight'})
        self.assertEqual(grouped_names[1], expected_head_names - {'noise_pred_net.proj.weight'})
        self.assertEqual(grouped_names[2], expected_non_head_names)
        self.assertEqual(optimizer.param_groups[0]['weight_decay'], cfg.train.weight_decay)
        self.assertEqual(optimizer.param_groups[1]['weight_decay'], 0.0)
        self.assertEqual(optimizer.param_groups[2]['weight_decay'], cfg.train.weight_decay)
        self.assertEqual(optimizer_names, trainable_names)
        # No parameter may appear in more than one group.
        flattened_param_ids = [
            id(param)
            for group in optimizer.param_groups
            for param in group['params']
        ]
        self.assertEqual(len(flattened_param_ids), len(set(flattened_param_ids)))
        self.assertNotIn('frozen.weight', optimizer_names)
        self.assertNotIn('frozen.bias', optimizer_names)

    def test_transformer_optimizer_ignores_frozen_head_params_returned_by_head_groups(self):
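        """A head parameter frozen after construction is filtered out, even
        though get_optim_groups() still returns it."""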
        module = self._load_train_vla_module()
        agent = FakeTransformerAgent()
        # Freeze one head parameter after construction; the head's
        # get_optim_groups() will still return it, so main() has to drop it.
        agent.noise_pred_net.norm.bias.requires_grad = False
        cfg = self._make_cfg()

        def fake_instantiate(config_node, **_kwargs):
            if config_node is cfg.data:
                return FakeDataset()
            if config_node is cfg.agent:
                return agent
            raise AssertionError(f'unexpected instantiate config: {config_node!r}')

        with tempfile.TemporaryDirectory() as tempdir:
            previous_cwd = os.getcwd()
            try:
                os.chdir(tempdir)
                with mock.patch.object(module, 'instantiate', side_effect=fake_instantiate), \
                        mock.patch.object(module, 'DataLoader', side_effect=lambda *args, **kwargs: FakeLoader()), \
                        mock.patch.object(module, 'get_lr_schedule_with_warmup', return_value=FakeScheduler()), \
                        mock.patch.object(module, 'AdamW', RecordingAdamW), \
                        mock.patch.object(module.torch, 'save', return_value=None), \
                        mock.patch.object(module, 'tqdm', side_effect=lambda iterable, **kwargs: iterable):
                    module.main(cfg)
            finally:
                os.chdir(previous_cwd)

        optimizer = RecordingAdamW.created[-1]
        optimizer_names = set().union(*self._group_names(agent, optimizer))
        trainable_names = {
            name for name, param in agent.named_parameters() if param.requires_grad
        }
        self.assertEqual(agent.noise_pred_net.optim_group_calls, [cfg.train.weight_decay])
        self.assertEqual(optimizer_names, trainable_names)
        self.assertNotIn('noise_pred_net.norm.bias', optimizer_names)


if __name__ == '__main__':
    unittest.main()