feat(vla): align transformer training stack and rollout validation
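The tests below pin down the AdamW parameter-group contract that train_vla.py is expected to honor for transformer heads. The following is a rough sketch of that contract only; the helper name and exact control flow are illustrative assumptions, not the actual train_vla.py code:

def build_param_groups(agent, weight_decay):
    # Hypothetical helper; applies when agent.head_type == 'transformer' and
    # the head exposes get_optim_groups. Head-provided groups come first and
    # keep their own decay / no-decay split.
    groups = agent.noise_pred_net.get_optim_groups(weight_decay)
    seen = set()
    for group in groups:
        # Frozen parameters returned by the head are dropped, never optimized.
        group['params'] = [p for p in group['params'] if p.requires_grad]
        seen.update(id(p) for p in group['params'])
    # Every remaining trainable parameter (backbone, adapters, ...) lands in
    # one final group with the configured weight decay; already-grouped and
    # frozen parameters are skipped, so no parameter appears twice.
    remaining = [p for p in agent.parameters()
                 if p.requires_grad and id(p) not in seen]
    if remaining:
        groups.append({'params': remaining, 'weight_decay': weight_decay})
    return groups

Run against the FakeTransformerAgent defined in the tests, this yields the three groups the assertions check: the head's decayed weights, the head's zero-decay biases and norm parameters, and the remaining backbone/adapter parameters.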
tests/test_train_vla_transformer_optimizer.py (new file, 310 lines)
@@ -0,0 +1,310 @@
import importlib.util
import os
import sys
import tempfile
import types
import unittest
from pathlib import Path
from unittest import mock

import torch
from torch import nn


_REPO_ROOT = Path(__file__).resolve().parents[1]
_TRAIN_VLA_PATH = _REPO_ROOT / 'roboimi/demos/vla_scripts/train_vla.py'

class AttrDict(dict):
    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError as exc:
            raise AttributeError(name) from exc

    def __setattr__(self, name, value):
        self[name] = value


class FakeDataset:
    def __len__(self):
        return 4


class FakeLoader:
    def __len__(self):
        return 1

    def __iter__(self):
        return iter(())


class FakeScheduler:
    def state_dict(self):
        return {}

    def load_state_dict(self, state_dict):
        return None

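# RecordingAdamW stands in for torch.optim.AdamW: it performs the same
# param-group normalization (a flat iterable of tensors becomes one group;
# dict groups pass through, defaulting 'lr') and records every constructed
# instance so the tests can inspect the groups the training script built.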
class RecordingAdamW:
    created = []

    def __init__(self, params, lr, weight_decay):
        self.lr = lr
        self.weight_decay = weight_decay
        self.param_groups = self._normalize_param_groups(params, lr, weight_decay)
        RecordingAdamW.created.append(self)

    @staticmethod
    def _normalize_param_groups(params, lr, weight_decay):
        if isinstance(params, (list, tuple)) and params and isinstance(params[0], dict):
            groups = []
            for group in params:
                normalized = dict(group)
                normalized['params'] = list(group['params'])
                normalized.setdefault('lr', lr)
                groups.append(normalized)
            return groups

        return [{
            'params': list(params),
            'lr': lr,
            'weight_decay': weight_decay,
        }]

    def state_dict(self):
        return {}

    def load_state_dict(self, state_dict):
        return None

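# Minimal transformer head exposing the get_optim_groups(weight_decay)
# contract: weights go in a decayed group, biases and norm parameters in a
# zero-decay group, and every call's weight_decay argument is recorded.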
class RecordingTransformerHead(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)
        self.norm = nn.LayerNorm(4)
        self.optim_group_calls = []

    def get_optim_groups(self, weight_decay):
        self.optim_group_calls.append(weight_decay)
        return [
            {
                'params': [self.proj.weight],
                'weight_decay': weight_decay,
            },
            {
                'params': [self.proj.bias, self.norm.weight, self.norm.bias],
                'weight_decay': 0.0,
            },
        ]

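# Agent stub covering every case the grouping logic must handle: a head that
# supplies its own optim groups, extra trainable modules (backbone, adapter)
# outside the head, and a fully frozen module that must stay out of the
# optimizer.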
class FakeTransformerAgent(nn.Module):
    def __init__(self):
        super().__init__()
        self.head_type = 'transformer'
        self.noise_pred_net = RecordingTransformerHead()
        self.backbone = nn.Linear(4, 3)
        self.adapter = nn.Linear(3, 2, bias=False)
        self.frozen = nn.Linear(2, 2)
        for param in self.frozen.parameters():
            param.requires_grad = False

    def to(self, device):
        return self

    def get_normalization_stats(self):
        return {}

class TrainVLATransformerOptimizerTest(unittest.TestCase):
    def setUp(self):
        RecordingAdamW.created = []

    def _load_train_vla_module(self):
        hydra_module = types.ModuleType('hydra')
        hydra_utils_module = types.ModuleType('hydra.utils')
        hydra_utils_module.instantiate = lambda *args, **kwargs: None

        def hydra_main(**_kwargs):
            def decorator(func):
                return func
            return decorator

        hydra_module.main = hydra_main
        hydra_module.utils = hydra_utils_module

        class OmegaConfStub:
            _resolvers = {}

            @classmethod
            def has_resolver(cls, name):
                return name in cls._resolvers

            @classmethod
            def register_new_resolver(cls, name, resolver):
                cls._resolvers[name] = resolver

            @staticmethod
            def to_yaml(_cfg):
                return 'stub-config'

        omegaconf_module = types.ModuleType('omegaconf')
        omegaconf_module.DictConfig = dict
        omegaconf_module.OmegaConf = OmegaConfStub

        module_name = 'train_vla_optimizer_test_module'
        spec = importlib.util.spec_from_file_location(module_name, _TRAIN_VLA_PATH)
        module = importlib.util.module_from_spec(spec)
        with mock.patch.dict(
            sys.modules,
            {
                'hydra': hydra_module,
                'hydra.utils': hydra_utils_module,
                'omegaconf': omegaconf_module,
            },
        ):
            assert spec.loader is not None
            spec.loader.exec_module(module)
        return module

    def _make_cfg(self):
        return AttrDict(
            train=AttrDict(
                device='cpu',
                batch_size=2,
                num_workers=0,
                val_split=0,
                seed=0,
                lr=1e-4,
                max_steps=0,
                log_freq=1,
                save_freq=100,
                warmup_steps=1,
                scheduler_type='constant',
                min_lr=0.0,
                grad_clip=1.0,
                weight_decay=0.123,
                pretrained_ckpt=None,
                resume_ckpt=None,
            ),
            data=AttrDict(
                camera_names=('front',),
            ),
            agent=AttrDict(
                _target_='fake.agent',
            ),
        )

    def _group_names(self, agent, optimizer):
        names_by_param_id = {id(param): name for name, param in agent.named_parameters()}
        return [
            {names_by_param_id[id(param)] for param in group['params']}
            for group in optimizer.param_groups
        ]

    def test_transformer_training_prefers_head_optim_groups_and_keeps_remaining_trainable_params(self):
        module = self._load_train_vla_module()
        agent = FakeTransformerAgent()
        cfg = self._make_cfg()

        def fake_instantiate(config_node, **_kwargs):
            if config_node is cfg.data:
                return FakeDataset()
            if config_node is cfg.agent:
                return agent
            raise AssertionError(f'unexpected instantiate config: {config_node!r}')

        with tempfile.TemporaryDirectory() as tempdir:
            previous_cwd = os.getcwd()
            try:
                os.chdir(tempdir)
                with mock.patch.object(module, 'instantiate', side_effect=fake_instantiate), \
                        mock.patch.object(module, 'DataLoader', side_effect=lambda *args, **kwargs: FakeLoader()), \
                        mock.patch.object(module, 'get_lr_schedule_with_warmup', return_value=FakeScheduler()), \
                        mock.patch.object(module, 'AdamW', RecordingAdamW), \
                        mock.patch.object(module.torch, 'save', return_value=None), \
                        mock.patch.object(module, 'tqdm', side_effect=lambda iterable, **kwargs: iterable):
                    module.main(cfg)
            finally:
                os.chdir(previous_cwd)

        self.assertEqual(agent.noise_pred_net.optim_group_calls, [cfg.train.weight_decay])

        optimizer = RecordingAdamW.created[-1]
        trainable_names = {
            name for name, param in agent.named_parameters() if param.requires_grad
        }
        grouped_names = self._group_names(agent, optimizer)
        optimizer_names = set().union(*grouped_names)
        expected_head_names = {
            'noise_pred_net.proj.weight',
            'noise_pred_net.proj.bias',
            'noise_pred_net.norm.weight',
            'noise_pred_net.norm.bias',
        }
        expected_non_head_names = {
            'backbone.weight',
            'backbone.bias',
            'adapter.weight',
        }

        self.assertEqual(grouped_names[0], {'noise_pred_net.proj.weight'})
        self.assertEqual(grouped_names[1], expected_head_names - {'noise_pred_net.proj.weight'})
        self.assertEqual(grouped_names[2], expected_non_head_names)
        self.assertEqual(optimizer.param_groups[0]['weight_decay'], cfg.train.weight_decay)
        self.assertEqual(optimizer.param_groups[1]['weight_decay'], 0.0)
        self.assertEqual(optimizer.param_groups[2]['weight_decay'], cfg.train.weight_decay)
        self.assertEqual(optimizer_names, trainable_names)

        flattened_param_ids = [
            id(param)
            for group in optimizer.param_groups
            for param in group['params']
        ]
        self.assertEqual(len(flattened_param_ids), len(set(flattened_param_ids)))
        self.assertNotIn('frozen.weight', optimizer_names)
        self.assertNotIn('frozen.bias', optimizer_names)

    def test_transformer_optimizer_ignores_frozen_head_params_returned_by_head_groups(self):
        module = self._load_train_vla_module()
        agent = FakeTransformerAgent()
        agent.noise_pred_net.norm.bias.requires_grad = False
        cfg = self._make_cfg()

        def fake_instantiate(config_node, **_kwargs):
            if config_node is cfg.data:
                return FakeDataset()
            if config_node is cfg.agent:
                return agent
            raise AssertionError(f'unexpected instantiate config: {config_node!r}')

        with tempfile.TemporaryDirectory() as tempdir:
            previous_cwd = os.getcwd()
            try:
                os.chdir(tempdir)
                with mock.patch.object(module, 'instantiate', side_effect=fake_instantiate), \
                        mock.patch.object(module, 'DataLoader', side_effect=lambda *args, **kwargs: FakeLoader()), \
                        mock.patch.object(module, 'get_lr_schedule_with_warmup', return_value=FakeScheduler()), \
                        mock.patch.object(module, 'AdamW', RecordingAdamW), \
                        mock.patch.object(module.torch, 'save', return_value=None), \
                        mock.patch.object(module, 'tqdm', side_effect=lambda iterable, **kwargs: iterable):
                    module.main(cfg)
            finally:
                os.chdir(previous_cwd)

        optimizer = RecordingAdamW.created[-1]
        optimizer_names = set().union(*self._group_names(agent, optimizer))
        trainable_names = {
            name for name, param in agent.named_parameters() if param.requires_grad
        }

        self.assertEqual(agent.noise_pred_net.optim_group_calls, [cfg.train.weight_decay])
        self.assertEqual(optimizer_names, trainable_names)
        self.assertNotIn('noise_pred_net.norm.bias', optimizer_names)


if __name__ == '__main__':
    unittest.main()
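Apart from torch itself, every heavy dependency (hydra, omegaconf, DataLoader, the LR scheduler, AdamW, tqdm) is stubbed or patched in, so the suite runs on CPU without touching the real training loop. From the repository root, python -m unittest tests.test_train_vla_transformer_optimizer should run both tests; the unittest.main() guard also allows invoking the file directly.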