# File: roboimi/tests/test_train_vla_swanlab_logging.py (700 lines, 27 KiB, Python)

import importlib
import importlib.util
import os
import sys
import tempfile
import types
import unittest
from pathlib import Path
from unittest import mock
import torch
from torch import nn
# Locations of the training entry point and default Hydra config under test.
# NOTE(review): `parents[1]` of this test file is the directory one level above
# `tests/`; since the joined paths themselves start with `roboimi/`, confirm
# that `parents[1]` really is the repository root and not the `roboimi/`
# package directory (which would yield `roboimi/roboimi/...`).
_REPO_ROOT = Path(__file__).resolve().parents[1]
_TRAIN_VLA_PATH = _REPO_ROOT / 'roboimi/demos/vla_scripts/train_vla.py'
_CONFIG_PATH = _REPO_ROOT / 'roboimi/vla/conf/config.yaml'
class AttrDict(dict):
    """Dict that also exposes its keys as attributes, mimicking OmegaConf nodes."""

    def __setattr__(self, name, value):
        # Attribute writes land in the mapping so dict-style reads see them too.
        self[name] = value

    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError as missing:
            # Surface missing keys as AttributeError so getattr(..., default) works.
            raise AttributeError(name) from missing
def _to_attrdict(value):
if isinstance(value, dict):
return AttrDict({key: _to_attrdict(item) for key, item in value.items()})
if isinstance(value, list):
return [_to_attrdict(item) for item in value]
return value
class FakeDataset:
    """Dataset stand-in that only reports a fixed length of four samples."""

    _NUM_SAMPLES = 4

    def __len__(self):
        return type(self)._NUM_SAMPLES
class FakeLoader:
    """DataLoader stand-in that yields a single stored batch per iteration pass."""

    def __init__(self, batch):
        self.batch = batch

    def __len__(self):
        # One batch per "epoch".
        return 1

    def __iter__(self):
        # Fresh generator each call, so the loader is re-iterable like a real one.
        yield self.batch
class FakeScheduler:
    """LR scheduler stand-in that merely counts how often step() is called."""

    def __init__(self):
        self.step_calls = 0

    def step(self):
        self.step_calls = self.step_calls + 1

    def state_dict(self):
        # Nothing to persist for the fake.
        return {}

    def load_state_dict(self, state_dict):
        # Accept and discard any restored state.
        return None
class FakeOptimizer:
    """Optimizer stand-in exposing param_groups and recording loaded state."""

    def __init__(self, lr=1e-3):
        self.param_groups = [dict(lr=lr)]
        # Remembers the last state dict passed to load_state_dict().
        self.loaded_state_dict = None

    def zero_grad(self):
        pass

    def step(self):
        pass

    def state_dict(self):
        return {}

    def load_state_dict(self, state_dict):
        self.loaded_state_dict = state_dict
class FakeProgressBar:
    """tqdm stand-in: materializes the iterable and records set_postfix payloads."""

    def __init__(self, iterable):
        self.postfix_calls = []
        self._items = list(iterable)

    def __iter__(self):
        yield from self._items

    def set_postfix(self, values):
        self.postfix_calls.append(values)
class FakeAgent(nn.Module):
    """Tiny one-parameter agent whose loss target flips with train/eval mode."""

    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.tensor(0.0))

    def to(self, device):
        # Device moves are a no-op so CPU-only tests never copy tensors.
        del device
        return self

    def compute_loss(self, agent_input):
        del agent_input
        # Target 0.25 while training, 0.1 in eval, so the two losses differ.
        target_value = 0.25 if self.training else 0.1
        diff = self.weight - torch.tensor(target_value)
        return diff.pow(2)

    def get_normalization_stats(self):
        return {}
class FakeSwanLab:
    """Recording stand-in for the swanlab module with injectable failures.

    init/log/finish record their arguments; the *_error/log_errors knobs let
    tests simulate backend failures at each lifecycle stage.
    """

    def __init__(self, init_error=None, log_errors=None, finish_error=None):
        self.init_error = init_error
        # Copy so tests can pass tuples; errors are consumed one per log() call.
        self.log_errors = list(log_errors) if log_errors else []
        self.finish_error = finish_error
        self.init_calls = []
        self.log_calls = []
        self.finish_calls = 0

    def init(self, project, experiment_name=None, config=None):
        call = dict(project=project, experiment_name=experiment_name, config=config)
        self.init_calls.append(call)
        if self.init_error is not None:
            raise self.init_error
        # Opaque run handle, like the real swanlab.init().
        return object()

    def log(self, payload, step=None):
        # Snapshot the payload so later mutation by the caller is invisible.
        self.log_calls.append((dict(payload), step))
        if self.log_errors:
            raise self.log_errors.pop(0)

    def finish(self):
        self.finish_calls += 1
        if self.finish_error is not None:
            raise self.finish_error
class TrainVLASwanLabLoggingTest(unittest.TestCase):
    """Tests for SwanLab logging behavior inside train_vla's _run_training.

    The train_vla module is loaded from its file path with hydra/omegaconf
    replaced by lightweight stubs, and its collaborators (instantiate,
    DataLoader, tqdm, torch.save/load, importlib.import_module) are patched so
    training runs entirely against the fakes defined above.
    """

    def test_default_config_keeps_swanlab_opt_in(self):
        """The shipped YAML config must leave SwanLab disabled by default."""
        config_text = _CONFIG_PATH.read_text(encoding='utf-8')
        self.assertIn('use_swanlab: false', config_text)

    def _load_train_vla_module(self):
        """Import train_vla.py with hydra and omegaconf stubbed in sys.modules."""
        hydra_module = types.ModuleType('hydra')
        hydra_utils_module = types.ModuleType('hydra.utils')
        hydra_utils_module.instantiate = lambda *args, **kwargs: None

        def hydra_main(**_kwargs):
            # Replacement for @hydra.main that leaves the function undecorated.
            def decorator(func):
                return func
            return decorator

        hydra_module.main = hydra_main
        hydra_module.utils = hydra_utils_module

        class OmegaConfStub:
            # Minimal OmegaConf surface that train_vla touches.
            _resolvers = {}

            @classmethod
            def has_resolver(cls, name):
                return name in cls._resolvers

            @classmethod
            def register_new_resolver(cls, name, resolver):
                cls._resolvers[name] = resolver

            @staticmethod
            def to_yaml(_cfg):
                return 'stub-config'

            @staticmethod
            def to_container(cfg, resolve=False):
                del resolve
                return dict(cfg)

            @staticmethod
            def create(cfg):
                return _to_attrdict(cfg)

        omegaconf_module = types.ModuleType('omegaconf')
        omegaconf_module.DictConfig = dict
        omegaconf_module.OmegaConf = OmegaConfStub
        module_name = 'train_vla_swanlab_test_module'
        spec = importlib.util.spec_from_file_location(module_name, _TRAIN_VLA_PATH)
        module = importlib.util.module_from_spec(spec)
        # Patch sys.modules only for the duration of exec_module so any real
        # hydra/omegaconf installation is untouched afterwards.
        with mock.patch.dict(
            sys.modules,
            {
                'hydra': hydra_module,
                'hydra.utils': hydra_utils_module,
                'omegaconf': omegaconf_module,
            },
        ):
            assert spec.loader is not None
            spec.loader.exec_module(module)
        return module

    def _make_cfg(self, *, use_swanlab=True, swanlab_run_name='smoke-run'):
        """Build a minimal attribute-style config covering train/data/agent/eval."""
        return AttrDict(
            train=AttrDict(
                device='cpu',
                batch_size=2,
                num_workers=0,
                val_split=0.25,
                seed=0,
                lr=1e-3,
                max_steps=2,
                log_freq=1,
                save_freq=1,
                warmup_steps=1,
                scheduler_type='constant',
                min_lr=0.0,
                grad_clip=1.0,
                weight_decay=0.0,
                pretrained_ckpt=None,
                resume_ckpt=None,
                use_swanlab=use_swanlab,
                swanlab_project='roboimi-vla-tests',
                swanlab_run_name=swanlab_run_name,
            ),
            data=AttrDict(
                camera_names=('front',),
            ),
            agent=AttrDict(
                _target_='fake.agent',
            ),
            eval=AttrDict(
                ckpt_path='unused.pt',
                num_episodes=1,
                max_timesteps=1,
                device='cpu',
                task_name='sim_transfer',
                camera_names=('front',),
                use_smoothing=False,
                smooth_alpha=0.3,
                verbose_action=False,
                headless=False,
            ),
        )

    def _get_run_training(self, module):
        """Fetch the _run_training helper, failing clearly if it is absent."""
        run_training = getattr(module, '_run_training', None)
        self.assertIsNotNone(run_training, 'Expected train_vla.py to expose a _run_training(cfg) helper')
        return run_training

    def _make_batch(self):
        """Return one LeRobot-style batch of zero tensors (batch size 1)."""
        return {
            'observation.front': torch.zeros(1, 3, 2, 2),
            'observation.state': torch.zeros(1, 4),
            'action': torch.zeros(1, 2),
            'action_is_pad': torch.zeros(1, 1, dtype=torch.bool),
        }

    def _loader_factory(self):
        """DataLoader replacement: shuffle=True gets the train batch, else val."""
        train_batch = self._make_batch()
        val_batch = self._make_batch()

        def factory(_dataset, *, shuffle, **_kwargs):
            return FakeLoader(train_batch if shuffle else val_batch)
        return factory

    def test_run_training_logs_metrics_and_checkpoint_paths_to_swanlab(self):
        """Happy path: init config, per-step metrics, and final paths are logged."""
        module = self._load_train_vla_module()
        run_training = self._get_run_training(module)
        cfg = self._make_cfg()
        agent = FakeAgent()
        fake_swanlab = FakeSwanLab()
        real_import_module = importlib.import_module

        def fake_instantiate(config_node, **_kwargs):
            if config_node is cfg.data:
                return FakeDataset()
            if config_node is cfg.agent:
                return agent
            raise AssertionError(f'unexpected instantiate config: {config_node!r}')

        def fake_import_module(name, package=None):
            # Only the swanlab import is intercepted; everything else is real.
            if name == 'swanlab':
                return fake_swanlab
            return real_import_module(name, package)

        with tempfile.TemporaryDirectory() as tempdir:
            previous_cwd = os.getcwd()
            try:
                # Training writes checkpoints relative to the CWD.
                os.chdir(tempdir)
                with mock.patch.object(module, 'instantiate', side_effect=fake_instantiate), \
                        mock.patch.object(module, 'DataLoader', side_effect=self._loader_factory()), \
                        mock.patch.object(module, 'get_lr_schedule_with_warmup', return_value=FakeScheduler()), \
                        mock.patch.object(module, 'tqdm', side_effect=lambda iterable, **kwargs: FakeProgressBar(iterable)), \
                        mock.patch.object(module.torch, 'save', return_value=None), \
                        mock.patch.object(module.importlib, 'import_module', side_effect=fake_import_module):
                    run_training(cfg)
            finally:
                os.chdir(previous_cwd)
        # swanlab.init must receive the full (resolved) config exactly once.
        self.assertEqual(
            fake_swanlab.init_calls,
            [{
                'project': 'roboimi-vla-tests',
                'experiment_name': 'smoke-run',
                'config': {
                    'train': {
                        'device': 'cpu',
                        'batch_size': 2,
                        'num_workers': 0,
                        'val_split': 0.25,
                        'seed': 0,
                        'lr': 1e-3,
                        'max_steps': 2,
                        'log_freq': 1,
                        'save_freq': 1,
                        'warmup_steps': 1,
                        'scheduler_type': 'constant',
                        'min_lr': 0.0,
                        'grad_clip': 1.0,
                        'weight_decay': 0.0,
                        'pretrained_ckpt': None,
                        'resume_ckpt': None,
                        'use_swanlab': True,
                        'swanlab_project': 'roboimi-vla-tests',
                        'swanlab_run_name': 'smoke-run',
                    },
                    'data': {
                        'camera_names': ('front',),
                    },
                    'agent': {
                        '_target_': 'fake.agent',
                    },
                },
            }],
        )
        # Union of all keys logged across every swanlab.log call.
        logged_keys = set().union(*(payload.keys() for payload, _step in fake_swanlab.log_calls))
        self.assertTrue(
            {
                'train/loss',
                'train/lr',
                'train/best_loss',
                'train/step',
                'val/loss',
                'final/checkpoint_path',
                'final/best_checkpoint_path',
            }.issubset(logged_keys)
        )
        final_payload, final_step = fake_swanlab.log_calls[-1]
        self.assertEqual(final_step, cfg.train.max_steps)
        self.assertEqual(final_payload['final/checkpoint_path'], 'checkpoints/vla_model_final.pt')
        self.assertEqual(final_payload['final/best_checkpoint_path'], 'checkpoints/vla_model_best.pt')
        self.assertEqual(fake_swanlab.finish_calls, 1)

    def test_run_training_skips_swanlab_when_disabled(self):
        """With use_swanlab=False the swanlab import must never even be attempted."""
        module = self._load_train_vla_module()
        run_training = self._get_run_training(module)
        cfg = self._make_cfg(use_swanlab=False)
        agent = FakeAgent()

        def fake_instantiate(config_node, **_kwargs):
            if config_node is cfg.data:
                return FakeDataset()
            if config_node is cfg.agent:
                return agent
            raise AssertionError(f'unexpected instantiate config: {config_node!r}')

        with tempfile.TemporaryDirectory() as tempdir:
            previous_cwd = os.getcwd()
            try:
                os.chdir(tempdir)
                # import_module is rigged to fail the test if it is ever called.
                with mock.patch.object(module, 'instantiate', side_effect=fake_instantiate), \
                        mock.patch.object(module, 'DataLoader', side_effect=self._loader_factory()), \
                        mock.patch.object(module, 'get_lr_schedule_with_warmup', return_value=FakeScheduler()), \
                        mock.patch.object(module, 'tqdm', side_effect=lambda iterable, **kwargs: FakeProgressBar(iterable)), \
                        mock.patch.object(module.torch, 'save', return_value=None), \
                        mock.patch.object(module.importlib, 'import_module', side_effect=AssertionError('swanlab import should not run')):
                    run_training(cfg)
            finally:
                os.chdir(previous_cwd)

    def test_run_training_finishes_swanlab_when_exception_happens_after_init(self):
        """A crash after swanlab.init must still trigger exactly one finish()."""
        module = self._load_train_vla_module()
        run_training = self._get_run_training(module)
        cfg = self._make_cfg()
        fake_swanlab = FakeSwanLab()
        real_import_module = importlib.import_module

        def fake_import_module(name, package=None):
            if name == 'swanlab':
                return fake_swanlab
            return real_import_module(name, package)

        with tempfile.TemporaryDirectory() as tempdir:
            previous_cwd = os.getcwd()
            try:
                os.chdir(tempdir)
                # instantiate blows up, simulating a dataset construction failure.
                with mock.patch.object(module, 'instantiate', side_effect=RuntimeError('dataset boom')), \
                        mock.patch.object(module.importlib, 'import_module', side_effect=fake_import_module):
                    with self.assertRaisesRegex(RuntimeError, 'dataset boom'):
                        run_training(cfg)
            finally:
                os.chdir(previous_cwd)
        self.assertEqual(fake_swanlab.finish_calls, 1)

    def test_run_training_warns_and_continues_when_swanlab_log_and_finish_fail(self):
        """Backend hiccups in log()/finish() only warn; training still completes."""
        module = self._load_train_vla_module()
        run_training = self._get_run_training(module)
        cfg = self._make_cfg()
        agent = FakeAgent()
        fake_swanlab = FakeSwanLab(
            log_errors=[RuntimeError('log backend hiccup')],
            finish_error=RuntimeError('finish backend hiccup'),
        )
        real_import_module = importlib.import_module

        def fake_instantiate(config_node, **_kwargs):
            if config_node is cfg.data:
                return FakeDataset()
            if config_node is cfg.agent:
                return agent
            raise AssertionError(f'unexpected instantiate config: {config_node!r}')

        def fake_import_module(name, package=None):
            if name == 'swanlab':
                return fake_swanlab
            return real_import_module(name, package)

        with tempfile.TemporaryDirectory() as tempdir:
            previous_cwd = os.getcwd()
            try:
                os.chdir(tempdir)
                with mock.patch.object(module, 'instantiate', side_effect=fake_instantiate), \
                        mock.patch.object(module, 'DataLoader', side_effect=self._loader_factory()), \
                        mock.patch.object(module, 'get_lr_schedule_with_warmup', return_value=FakeScheduler()), \
                        mock.patch.object(module, 'tqdm', side_effect=lambda iterable, **kwargs: FakeProgressBar(iterable)), \
                        mock.patch.object(module.torch, 'save', return_value=None), \
                        mock.patch.object(module.importlib, 'import_module', side_effect=fake_import_module), \
                        mock.patch.object(module.log, 'warning') as warning_mock:
                    run_training(cfg)
            finally:
                os.chdir(previous_cwd)
        warning_messages = [call.args[0] for call in warning_mock.call_args_list]
        self.assertTrue(any('SwanLab log failed' in message for message in warning_messages))
        self.assertTrue(any('SwanLab finish failed' in message for message in warning_messages))
        self.assertEqual(fake_swanlab.finish_calls, 1)

    def test_run_training_resume_restores_best_rollout_baseline_from_best_checkpoint(self):
        """Resuming restores the rollout baseline, so a worse rollout (3.0 < 5.0)
        must not overwrite the existing best checkpoint."""
        module = self._load_train_vla_module()
        run_training = self._get_run_training(module)
        cfg = self._make_cfg()
        cfg.train.max_steps = 2
        cfg.train.save_freq = 1
        cfg.train.rollout_validate_on_checkpoint = True
        fake_swanlab = FakeSwanLab()
        fake_optimizer = FakeOptimizer(lr=cfg.train.lr)
        fake_scheduler = FakeScheduler()
        real_import_module = importlib.import_module
        saved_paths = []

        def fake_instantiate(config_node, **_kwargs):
            if config_node is cfg.data:
                return FakeDataset()
            if config_node is cfg.agent:
                return FakeAgent()
            raise AssertionError(f'unexpected instantiate config: {config_node!r}')

        def fake_import_module(name, package=None):
            if name == 'swanlab':
                return fake_swanlab
            return real_import_module(name, package)

        with tempfile.TemporaryDirectory() as tempdir:
            previous_cwd = os.getcwd()
            try:
                os.chdir(tempdir)
                # Pre-create resume and best checkpoint files on disk; their
                # contents are served by fake_torch_load below.
                checkpoint_dir = Path('checkpoints')
                checkpoint_dir.mkdir()
                resume_path = checkpoint_dir / 'vla_model_step_0.pt'
                resume_path.write_bytes(b'resume')
                best_path = checkpoint_dir / 'vla_model_best.pt'
                best_path.write_bytes(b'best')
                cfg.train.resume_ckpt = str(resume_path)
                resume_checkpoint_state = {
                    'step': 0,
                    'model_state_dict': FakeAgent().state_dict(),
                    'optimizer_state_dict': {},
                    'scheduler_state_dict': {},
                    'loss': 0.5,
                    'val_loss': 0.25,
                }
                # Existing best carries a rollout reward baseline of 5.0.
                best_checkpoint_state = {
                    'step': 0,
                    'model_state_dict': FakeAgent().state_dict(),
                    'optimizer_state_dict': {},
                    'scheduler_state_dict': {},
                    'loss': 0.5,
                    'val_loss': 0.25,
                    'rollout_avg_reward': 5.0,
                }

                def fake_torch_load(path, map_location=None):
                    del map_location
                    path = Path(path)
                    if path == resume_path:
                        return resume_checkpoint_state
                    if path == best_path:
                        return best_checkpoint_state
                    raise AssertionError(f'unexpected load path: {path}')

                def fake_torch_save(payload, path):
                    # Record every checkpoint write instead of touching disk.
                    saved_paths.append(str(path))
                    return None

                with mock.patch.object(module, 'instantiate', side_effect=fake_instantiate), \
                        mock.patch.object(module, 'DataLoader', side_effect=self._loader_factory()), \
                        mock.patch.object(module, 'build_training_optimizer', return_value=fake_optimizer), \
                        mock.patch.object(module, 'get_lr_schedule_with_warmup', return_value=fake_scheduler), \
                        mock.patch.object(module, 'tqdm', side_effect=lambda iterable, **kwargs: FakeProgressBar(iterable)), \
                        mock.patch.object(module.torch, 'save', side_effect=fake_torch_save), \
                        mock.patch.object(module.torch, 'load', side_effect=fake_torch_load), \
                        mock.patch.object(module.importlib, 'import_module', side_effect=fake_import_module), \
                        mock.patch('roboimi.demos.vla_scripts.eval_vla._run_eval', return_value={'avg_reward': 3.0}):
                    run_training(cfg)
            finally:
                os.chdir(previous_cwd)
        final_payload, final_step = fake_swanlab.log_calls[-1]
        self.assertEqual(final_step, cfg.train.max_steps)
        self.assertEqual(final_payload['final/best_checkpoint_path'], 'checkpoints/vla_model_best.pt')
        # The stale-but-better best checkpoint must not have been rewritten.
        self.assertNotIn('checkpoints/vla_model_best.pt', saved_paths)

    def test_run_training_resume_ignores_best_checkpoint_without_rollout_metric(self):
        """A best checkpoint lacking rollout_avg_reward gives no baseline, so the
        resumed run's own checkpoint becomes the reported best."""
        module = self._load_train_vla_module()
        run_training = self._get_run_training(module)
        cfg = self._make_cfg()
        cfg.train.max_steps = 1
        fake_swanlab = FakeSwanLab()
        fake_optimizer = FakeOptimizer(lr=cfg.train.lr)
        fake_scheduler = FakeScheduler()
        real_import_module = importlib.import_module

        def fake_instantiate(config_node, **_kwargs):
            if config_node is cfg.data:
                return FakeDataset()
            if config_node is cfg.agent:
                return FakeAgent()
            raise AssertionError(f'unexpected instantiate config: {config_node!r}')

        def fake_import_module(name, package=None):
            if name == 'swanlab':
                return fake_swanlab
            return real_import_module(name, package)

        with tempfile.TemporaryDirectory() as tempdir:
            previous_cwd = os.getcwd()
            try:
                os.chdir(tempdir)
                checkpoint_dir = Path('checkpoints')
                checkpoint_dir.mkdir()
                resume_path = checkpoint_dir / 'vla_model_step_0.pt'
                resume_path.write_bytes(b'resume')
                best_path = checkpoint_dir / 'vla_model_best.pt'
                best_path.write_bytes(b'stale')
                cfg.train.resume_ckpt = str(resume_path)
                resume_checkpoint_state = {
                    'step': 0,
                    'model_state_dict': FakeAgent().state_dict(),
                    'optimizer_state_dict': {},
                    'scheduler_state_dict': {},
                    'loss': 0.5,
                    'val_loss': 0.25,
                }
                # Deliberately missing 'rollout_avg_reward'.
                stale_best_checkpoint_state = {
                    'step': 0,
                    'model_state_dict': FakeAgent().state_dict(),
                    'optimizer_state_dict': {},
                    'scheduler_state_dict': {},
                    'loss': 0.4,
                    'val_loss': 0.2,
                }

                def fake_torch_load(path, map_location=None):
                    del map_location
                    path = Path(path)
                    if path == resume_path:
                        return resume_checkpoint_state
                    if path == best_path:
                        return stale_best_checkpoint_state
                    raise AssertionError(f'unexpected load path: {path}')

                with mock.patch.object(module, 'instantiate', side_effect=fake_instantiate), \
                        mock.patch.object(module, 'DataLoader', side_effect=self._loader_factory()), \
                        mock.patch.object(module, 'build_training_optimizer', return_value=fake_optimizer), \
                        mock.patch.object(module, 'get_lr_schedule_with_warmup', return_value=fake_scheduler), \
                        mock.patch.object(module, 'tqdm', side_effect=lambda iterable, **kwargs: FakeProgressBar(iterable)), \
                        mock.patch.object(module.torch, 'save', return_value=None), \
                        mock.patch.object(module.torch, 'load', side_effect=fake_torch_load), \
                        mock.patch.object(module.importlib, 'import_module', side_effect=fake_import_module):
                    run_training(cfg)
            finally:
                os.chdir(previous_cwd)
        final_payload, final_step = fake_swanlab.log_calls[-1]
        self.assertEqual(final_step, cfg.train.max_steps)
        self.assertEqual(final_payload['final/best_checkpoint_path'], 'checkpoints/vla_model_step_0.pt')

    def test_run_training_ignores_stale_best_checkpoint_file_on_fresh_non_resume_run(self):
        """Without resume_ckpt, a leftover best-checkpoint file must be ignored."""
        module = self._load_train_vla_module()
        run_training = self._get_run_training(module)
        cfg = self._make_cfg()
        cfg.train.max_steps = 1
        fake_swanlab = FakeSwanLab()
        real_import_module = importlib.import_module

        def fake_instantiate(config_node, **_kwargs):
            if config_node is cfg.data:
                return FakeDataset()
            if config_node is cfg.agent:
                return FakeAgent()
            raise AssertionError(f'unexpected instantiate config: {config_node!r}')

        def fake_import_module(name, package=None):
            if name == 'swanlab':
                return fake_swanlab
            return real_import_module(name, package)

        with tempfile.TemporaryDirectory() as tempdir:
            previous_cwd = os.getcwd()
            try:
                os.chdir(tempdir)
                checkpoint_dir = Path('checkpoints')
                checkpoint_dir.mkdir()
                (checkpoint_dir / 'vla_model_best.pt').write_bytes(b'stale-best')
                with mock.patch.object(module, 'instantiate', side_effect=fake_instantiate), \
                        mock.patch.object(module, 'DataLoader', side_effect=self._loader_factory()), \
                        mock.patch.object(module, 'get_lr_schedule_with_warmup', return_value=FakeScheduler()), \
                        mock.patch.object(module, 'tqdm', side_effect=lambda iterable, **kwargs: FakeProgressBar(iterable)), \
                        mock.patch.object(module.torch, 'save', return_value=None), \
                        mock.patch.object(module.importlib, 'import_module', side_effect=fake_import_module):
                    run_training(cfg)
            finally:
                os.chdir(previous_cwd)
        final_payload, final_step = fake_swanlab.log_calls[-1]
        self.assertEqual(final_step, cfg.train.max_steps)
        # No best checkpoint was produced by this run, so the path is empty.
        self.assertEqual(final_payload['final/best_checkpoint_path'], '')

    def test_run_training_fails_fast_when_swanlab_import_is_unavailable(self):
        """Missing swanlab with use_swanlab=True must abort before training setup."""
        module = self._load_train_vla_module()
        run_training = self._get_run_training(module)
        cfg = self._make_cfg()
        real_import_module = importlib.import_module

        def fake_import_module(name, package=None):
            if name == 'swanlab':
                raise ImportError('missing swanlab')
            return real_import_module(name, package)

        with mock.patch.object(module, 'instantiate', side_effect=AssertionError('instantiate should not run')), \
                mock.patch.object(module.importlib, 'import_module', side_effect=fake_import_module):
            with self.assertRaisesRegex(RuntimeError, 'SwanLab'):
                run_training(cfg)

    def test_run_training_fails_fast_when_swanlab_init_fails(self):
        """A failing swanlab.init propagates, and finish() is never called."""
        module = self._load_train_vla_module()
        run_training = self._get_run_training(module)
        cfg = self._make_cfg()
        fake_swanlab = FakeSwanLab(init_error=RuntimeError('not logged in'))
        real_import_module = importlib.import_module

        def fake_import_module(name, package=None):
            if name == 'swanlab':
                return fake_swanlab
            return real_import_module(name, package)

        with mock.patch.object(module, 'instantiate', side_effect=AssertionError('instantiate should not run')), \
                mock.patch.object(module.importlib, 'import_module', side_effect=fake_import_module):
            with self.assertRaisesRegex(RuntimeError, 'not logged in'):
                run_training(cfg)
        self.assertEqual(fake_swanlab.finish_calls, 0)
# Allow running this test module directly (outside a pytest/unittest runner).
if __name__ == '__main__':
    unittest.main()