# roboimi/tests/test_train_vla_transformer_optimizer.py
import importlib.util
import os
import sys
import tempfile
import types
import unittest
from pathlib import Path
from unittest import mock

import torch
from torch import nn

_REPO_ROOT = Path(__file__).resolve().parents[1]
_TRAIN_VLA_PATH = _REPO_ROOT / 'roboimi/demos/vla_scripts/train_vla.py'
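
# The tests below pin down how train_vla.py builds its AdamW optimizer:
# heads that expose get_optim_groups() supply their own parameter groups,
# frozen parameters are excluded everywhere, and every remaining trainable
# parameter lands in a default group with the configured weight decay.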


class AttrDict(dict):
    """Dict with attribute access, standing in for an OmegaConf-style config."""

    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError as exc:
            raise AttributeError(name) from exc

    def __setattr__(self, name, value):
        self[name] = value


# Lightweight stand-ins for the dataset, dataloader, and LR scheduler that
# main() would otherwise instantiate for real.
class FakeDataset:
    def __len__(self):
        return 4


class FakeLoader:
    def __len__(self):
        return 1

    def __iter__(self):
        return iter(())


class FakeScheduler:
    def state_dict(self):
        return {}

    def load_state_dict(self, state_dict):
        return None


class RecordingAdamW:
    """AdamW stand-in that records every construction and its param groups."""

    created = []

    def __init__(self, params, lr, weight_decay):
        self.lr = lr
        self.weight_decay = weight_decay
        self.param_groups = self._normalize_param_groups(params, lr, weight_decay)
        RecordingAdamW.created.append(self)

    @staticmethod
    def _normalize_param_groups(params, lr, weight_decay):
        # Mirror torch.optim behaviour: accept either explicit param-group
        # dicts or a flat iterable of parameters.
        if isinstance(params, (list, tuple)) and params and isinstance(params[0], dict):
            groups = []
            for group in params:
                normalized = dict(group)
                normalized['params'] = list(group['params'])
                normalized.setdefault('lr', lr)
                groups.append(normalized)
            return groups
        return [{
            'params': list(params),
            'lr': lr,
            'weight_decay': weight_decay,
        }]

    def state_dict(self):
        return {}

    def load_state_dict(self, state_dict):
        return None


class RecordingTransformerHead(nn.Module):
    """Head that records each get_optim_groups() call and splits its params
    into a decayed group (weights) and a no-decay group (biases, norms)."""

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)
        self.norm = nn.LayerNorm(4)
        self.optim_group_calls = []

    def get_optim_groups(self, weight_decay):
        self.optim_group_calls.append(weight_decay)
        return [
            {
                'params': [self.proj.weight],
                'weight_decay': weight_decay,
            },
            {
                'params': [self.proj.bias, self.norm.weight, self.norm.bias],
                'weight_decay': 0.0,
            },
        ]


class FakeIMFAgent(nn.Module):
    def __init__(self):
        super().__init__()
        self.head_type = 'imf_transformer'
        self.noise_pred_net = RecordingTransformerHead()
        self.backbone = nn.Linear(4, 3)
        self.adapter = nn.Linear(3, 2, bias=False)


class FakeTransformerAgent(nn.Module):
    def __init__(self, *, head_type='transformer'):
        super().__init__()
        self.head_type = head_type
        self.noise_pred_net = RecordingTransformerHead()
        self.backbone = nn.Linear(4, 3)
        self.adapter = nn.Linear(3, 2, bias=False)
        # A deliberately frozen submodule: its params must never reach the
        # optimizer.
        self.frozen = nn.Linear(2, 2)
        for param in self.frozen.parameters():
            param.requires_grad = False

    def to(self, device):
        return self

    def get_normalization_stats(self):
        return {}
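

# For orientation, a hypothetical sketch of the contract the assertions below
# imply for train_vla.build_training_optimizer (inferred from the tests, not
# copied from the real implementation):
#
#     def build_training_optimizer(agent, lr, weight_decay):
#         head = agent.noise_pred_net
#         if hasattr(head, 'get_optim_groups'):
#             groups = head.get_optim_groups(weight_decay)
#             for group in groups:  # drop frozen head params
#                 group['params'] = [p for p in group['params'] if p.requires_grad]
#             head_ids = {id(p) for g in groups for p in g['params']}
#             rest = [p for p in agent.parameters()
#                     if p.requires_grad and id(p) not in head_ids]
#             groups.append({'params': rest, 'weight_decay': weight_decay})
#         else:
#             groups = [p for p in agent.parameters() if p.requires_grad]
#         return AdamW(groups, lr=lr, weight_decay=weight_decay)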


class TrainVLATransformerOptimizerTest(unittest.TestCase):
    def setUp(self):
        RecordingAdamW.created = []

    def _load_train_vla_module(self):
        # Import train_vla.py from its file path with hydra and omegaconf
        # replaced by minimal stubs, so the script loads without the real
        # dependencies installed.
        hydra_module = types.ModuleType('hydra')
        hydra_utils_module = types.ModuleType('hydra.utils')
        hydra_utils_module.instantiate = lambda *args, **kwargs: None

        def hydra_main(**_kwargs):
            def decorator(func):
                return func
            return decorator

        hydra_module.main = hydra_main
        hydra_module.utils = hydra_utils_module

        class OmegaConfStub:
            _resolvers = {}

            @classmethod
            def has_resolver(cls, name):
                return name in cls._resolvers

            @classmethod
            def register_new_resolver(cls, name, resolver):
                cls._resolvers[name] = resolver

            @staticmethod
            def to_yaml(_cfg):
                return 'stub-config'

        omegaconf_module = types.ModuleType('omegaconf')
        omegaconf_module.DictConfig = dict
        omegaconf_module.OmegaConf = OmegaConfStub

        module_name = 'train_vla_optimizer_test_module'
        spec = importlib.util.spec_from_file_location(module_name, _TRAIN_VLA_PATH)
        module = importlib.util.module_from_spec(spec)
        with mock.patch.dict(
            sys.modules,
            {
                'hydra': hydra_module,
                'hydra.utils': hydra_utils_module,
                'omegaconf': omegaconf_module,
            },
        ):
            assert spec.loader is not None
            spec.loader.exec_module(module)
        return module

    def _make_cfg(self):
        return AttrDict(
            train=AttrDict(
                device='cpu',
                batch_size=2,
                num_workers=0,
                val_split=0,
                seed=0,
                lr=1e-4,
                max_steps=0,
                log_freq=1,
                save_freq=100,
                warmup_steps=1,
                scheduler_type='constant',
                min_lr=0.0,
                grad_clip=1.0,
                weight_decay=0.123,
                pretrained_ckpt=None,
                resume_ckpt=None,
            ),
            data=AttrDict(
                camera_names=('front',),
            ),
            agent=AttrDict(
                _target_='fake.agent',
            ),
        )

    def _group_names(self, agent, optimizer):
        # Map each optimizer param group back to parameter names so the
        # assertions stay readable.
        names_by_param_id = {id(param): name for name, param in agent.named_parameters()}
        return [
            {names_by_param_id[id(param)] for param in group['params']}
            for group in optimizer.param_groups
        ]

    def test_configure_cuda_runtime_can_disable_cudnn_for_training(self):
        module = self._load_train_vla_module()
        cfg = AttrDict(train=AttrDict(device='cuda', disable_cudnn=True))
        original = module.torch.backends.cudnn.enabled
        try:
            module.torch.backends.cudnn.enabled = True
            module._configure_cuda_runtime(cfg)
            self.assertFalse(module.torch.backends.cudnn.enabled)
        finally:
            module.torch.backends.cudnn.enabled = original

    def test_train_script_uses_file_based_repo_root_on_sys_path(self):
        module = self._load_train_vla_module()
        fake_sys_path = ['/tmp/site-packages', '/another/path']
        with mock.patch.object(module.sys, 'path', fake_sys_path):
            repo_root = module._ensure_repo_root_on_syspath()
        self.assertEqual(Path(repo_root).resolve(), _REPO_ROOT.resolve())
        self.assertEqual(Path(fake_sys_path[0]).resolve(), _REPO_ROOT.resolve())

    def test_non_transformer_head_with_get_optim_groups_still_uses_custom_groups(self):
        module = self._load_train_vla_module()
        agent = FakeIMFAgent()
        optimizer = module.build_training_optimizer(agent, lr=1e-4, weight_decay=0.123)
        self.assertEqual(agent.noise_pred_net.optim_group_calls, [0.123])
        group_names = self._group_names(agent, optimizer)
        self.assertEqual(group_names[0], {'noise_pred_net.proj.weight'})
        self.assertEqual(group_names[1], {
            'noise_pred_net.proj.bias',
            'noise_pred_net.norm.weight',
            'noise_pred_net.norm.bias',
        })
        self.assertEqual(group_names[2], {'backbone.weight', 'backbone.bias', 'adapter.weight'})

    def test_transformer_training_prefers_head_optim_groups_and_keeps_remaining_trainable_params(self):
        module = self._load_train_vla_module()
        agent = FakeTransformerAgent()
        cfg = self._make_cfg()

        def fake_instantiate(config_node, **_kwargs):
            if config_node is cfg.data:
                return FakeDataset()
            if config_node is cfg.agent:
                return agent
            raise AssertionError(f'unexpected instantiate config: {config_node!r}')

        with tempfile.TemporaryDirectory() as tempdir:
            previous_cwd = os.getcwd()
            try:
                os.chdir(tempdir)
                with mock.patch.object(module, 'instantiate', side_effect=fake_instantiate), \
                        mock.patch.object(module, 'DataLoader', side_effect=lambda *args, **kwargs: FakeLoader()), \
                        mock.patch.object(module, 'get_lr_schedule_with_warmup', return_value=FakeScheduler()), \
                        mock.patch.object(module, 'AdamW', RecordingAdamW), \
                        mock.patch.object(module.torch, 'save', return_value=None), \
                        mock.patch.object(module, 'tqdm', side_effect=lambda iterable, **kwargs: iterable):
                    module.main(cfg)
            finally:
                os.chdir(previous_cwd)

        self.assertEqual(agent.noise_pred_net.optim_group_calls, [cfg.train.weight_decay])
        optimizer = RecordingAdamW.created[-1]
        trainable_names = {
            name for name, param in agent.named_parameters() if param.requires_grad
        }
        grouped_names = self._group_names(agent, optimizer)
        optimizer_names = set().union(*grouped_names)
        expected_head_names = {
            'noise_pred_net.proj.weight',
            'noise_pred_net.proj.bias',
            'noise_pred_net.norm.weight',
            'noise_pred_net.norm.bias',
        }
        expected_non_head_names = {
            'backbone.weight',
            'backbone.bias',
            'adapter.weight',
        }
        self.assertEqual(grouped_names[0], {'noise_pred_net.proj.weight'})
        self.assertEqual(grouped_names[1], expected_head_names - {'noise_pred_net.proj.weight'})
        self.assertEqual(grouped_names[2], expected_non_head_names)
        self.assertEqual(optimizer.param_groups[0]['weight_decay'], cfg.train.weight_decay)
        self.assertEqual(optimizer.param_groups[1]['weight_decay'], 0.0)
        self.assertEqual(optimizer.param_groups[2]['weight_decay'], cfg.train.weight_decay)
        self.assertEqual(optimizer_names, trainable_names)
        # No parameter may appear in more than one group.
        flattened_param_ids = [
            id(param)
            for group in optimizer.param_groups
            for param in group['params']
        ]
        self.assertEqual(len(flattened_param_ids), len(set(flattened_param_ids)))
        self.assertNotIn('frozen.weight', optimizer_names)
        self.assertNotIn('frozen.bias', optimizer_names)

    def test_any_head_with_get_optim_groups_uses_custom_groups_even_without_transformer_head_type(self):
        module = self._load_train_vla_module()
        agent = FakeTransformerAgent(head_type='imf')
        with mock.patch.object(module, 'AdamW', RecordingAdamW):
            optimizer = module.build_training_optimizer(agent, lr=1e-4, weight_decay=0.123)
        self.assertEqual(agent.noise_pred_net.optim_group_calls, [0.123])
        grouped_names = self._group_names(agent, optimizer)
        self.assertEqual(grouped_names[0], {'noise_pred_net.proj.weight'})
        self.assertEqual(
            grouped_names[1],
            {'noise_pred_net.proj.bias', 'noise_pred_net.norm.weight', 'noise_pred_net.norm.bias'},
        )
        self.assertEqual(grouped_names[2], {'backbone.weight', 'backbone.bias', 'adapter.weight'})

    def test_transformer_optimizer_ignores_frozen_head_params_returned_by_head_groups(self):
        module = self._load_train_vla_module()
        agent = FakeTransformerAgent()
        # Freeze one head parameter; the head still returns it from
        # get_optim_groups, so the trainer must filter it out.
        agent.noise_pred_net.norm.bias.requires_grad = False
        cfg = self._make_cfg()

        def fake_instantiate(config_node, **_kwargs):
            if config_node is cfg.data:
                return FakeDataset()
            if config_node is cfg.agent:
                return agent
            raise AssertionError(f'unexpected instantiate config: {config_node!r}')

        with tempfile.TemporaryDirectory() as tempdir:
            previous_cwd = os.getcwd()
            try:
                os.chdir(tempdir)
                with mock.patch.object(module, 'instantiate', side_effect=fake_instantiate), \
                        mock.patch.object(module, 'DataLoader', side_effect=lambda *args, **kwargs: FakeLoader()), \
                        mock.patch.object(module, 'get_lr_schedule_with_warmup', return_value=FakeScheduler()), \
                        mock.patch.object(module, 'AdamW', RecordingAdamW), \
                        mock.patch.object(module.torch, 'save', return_value=None), \
                        mock.patch.object(module, 'tqdm', side_effect=lambda iterable, **kwargs: iterable):
                    module.main(cfg)
            finally:
                os.chdir(previous_cwd)

        optimizer = RecordingAdamW.created[-1]
        optimizer_names = set().union(*self._group_names(agent, optimizer))
        trainable_names = {
            name for name, param in agent.named_parameters() if param.requires_grad
        }
        self.assertEqual(agent.noise_pred_net.optim_group_calls, [cfg.train.weight_decay])
        self.assertEqual(optimizer_names, trainable_names)
        self.assertNotIn('noise_pred_net.norm.bias', optimizer_names)
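

# The suite runs standalone, e.g.:
#   python roboimi/tests/test_train_vla_transformer_optimizer.py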
if __name__ == '__main__':
    unittest.main()