"""Tests for the SwanLab logging hooks exposed by ``train_vla.py``.

The training script is loaded straight from its file path with stubbed
``hydra`` / ``omegaconf`` modules, and every heavy collaborator (dataset,
data loader, scheduler, progress bar, checkpoint I/O, the ``swanlab`` package)
is replaced by one of the small fakes defined in this module.
"""

import contextlib
import importlib
import importlib.util
import os
import sys
import tempfile
import types
import unittest
from pathlib import Path
from unittest import mock

import torch
from torch import nn

# Star-imports expose only the reusable fakes.  Leaving the TestCase out keeps
# pytest from collecting it a second time in a module that star-imports them.
__all__ = [
    'AttrDict',
    'FakeAgent',
    'FakeDataset',
    'FakeLoader',
    'FakeOptimizer',
    'FakeProgressBar',
    'FakeScheduler',
    'FakeSwanLab',
]

_REPO_ROOT = Path(__file__).resolve().parents[1]
_TRAIN_VLA_PATH = _REPO_ROOT / 'roboimi/demos/vla_scripts/train_vla.py'
_CONFIG_PATH = _REPO_ROOT / 'roboimi/vla/conf/config.yaml'


class AttrDict(dict):
    """A ``dict`` whose items can also be read and written as attributes."""

    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError as exc:
            # Attribute protocol (hasattr/getattr) expects AttributeError.
            raise AttributeError(name) from exc

    def __setattr__(self, name, value):
        self[name] = value


def _to_attrdict(value):
    """Recursively convert dicts (also inside lists) into ``AttrDict``."""
    if isinstance(value, dict):
        return AttrDict({key: _to_attrdict(item) for key, item in value.items()})
    if isinstance(value, list):
        return [_to_attrdict(item) for item in value]
    return value


@contextlib.contextmanager
def _temporary_cwd():
    """Run the ``with`` body inside a fresh temporary working directory.

    The previous working directory is restored before the temporary
    directory is removed, even when the body raises.
    """
    previous_cwd = os.getcwd()
    with tempfile.TemporaryDirectory() as tempdir:
        try:
            os.chdir(tempdir)
            yield Path(tempdir)
        finally:
            os.chdir(previous_cwd)


def _make_swanlab_import_hook(swanlab_module=None, *, error=None):
    """Build a replacement for ``importlib.import_module``.

    Only the name ``'swanlab'`` is intercepted: the hook raises ``error`` when
    one is given and otherwise returns ``swanlab_module``.  Every other import
    is delegated to the real ``importlib.import_module``, which is captured
    here, i.e. before the patch using this hook is entered.
    """
    real_import_module = importlib.import_module

    def fake_import_module(name, package=None):
        if name == 'swanlab':
            if error is not None:
                raise error
            return swanlab_module
        return real_import_module(name, package)

    return fake_import_module


def _make_fake_instantiate(cfg, agent_factory):
    """Build a stand-in for ``instantiate`` bound to ``cfg``.

    ``cfg.data`` yields a ``FakeDataset`` and ``cfg.agent`` yields
    ``agent_factory()``; any other config node fails the test.
    """

    def fake_instantiate(config_node, **_kwargs):
        if config_node is cfg.data:
            return FakeDataset()
        if config_node is cfg.agent:
            return agent_factory()
        raise AssertionError(f'unexpected instantiate config: {config_node!r}')

    return fake_instantiate


def _make_fake_torch_load(states_by_path):
    """Build a stand-in for ``torch.load`` serving canned checkpoint states.

    ``states_by_path`` maps checkpoint paths to the state to return.  Paths
    are resolved at call time so relative paths follow the working directory
    in effect when the load happens; an unknown path fails the test.
    """

    def fake_torch_load(path, map_location=None):
        del map_location
        resolved = Path(path).resolve()
        for candidate, state in states_by_path.items():
            if resolved == Path(candidate).resolve():
                return state
        raise AssertionError(f'unexpected load path: {resolved}')

    return fake_torch_load


def _fake_tqdm(iterable, **_kwargs):
    """Stand-in for ``tqdm`` that ignores all display options."""
    return FakeProgressBar(iterable)


class FakeDataset:
    """Dataset stub that only reports a length of four samples."""

    def __len__(self):
        return 4


class FakeLoader:
    """Data loader stub that yields one fixed batch per pass."""

    def __init__(self, batch):
        self.batch = batch

    def __len__(self):
        return 1

    def __iter__(self):
        return iter((self.batch,))


class FakeScheduler:
    """LR scheduler stub that counts ``step`` calls."""

    def __init__(self):
        self.step_calls = 0

    def step(self):
        self.step_calls += 1

    def state_dict(self):
        return {}

    def load_state_dict(self, state_dict):
        return None


class FakeOptimizer:
    """Optimizer stub exposing ``param_groups`` and recording loaded state."""

    def __init__(self, lr=1e-3):
        self.param_groups = [{'lr': lr}]
        self.loaded_state_dict = None

    def zero_grad(self):
        return None

    def step(self):
        return None

    def state_dict(self):
        return {}

    def load_state_dict(self, state_dict):
        self.loaded_state_dict = state_dict
        return None


class FakeProgressBar:
    """Progress bar stub that records every ``set_postfix`` payload."""

    def __init__(self, iterable):
        self._items = list(iterable)
        self.postfix_calls = []

    def __iter__(self):
        return iter(self._items)

    def set_postfix(self, values):
        self.postfix_calls.append(values)


class FakeAgent(nn.Module):
    """One-parameter agent whose loss depends only on train/eval mode."""

    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.tensor(0.0))

    def to(self, device):
        # Device moves are a no-op for the fake.
        return self

    def compute_loss(self, agent_input):
        del agent_input
        # Different targets make training and validation losses distinguishable.
        target = torch.tensor(0.25 if self.training else 0.1)
        return (self.weight - target).pow(2)

    def get_normalization_stats(self):
        return {}


class FakeSwanLab:
    """In-memory double for the ``swanlab`` module.

    Every call is recorded.  ``init_error`` / ``finish_error`` are raised on
    each ``init`` / ``finish`` call; ``log_errors`` and ``image_errors`` are
    queues from which one error is consumed (and raised) per call until empty.
    """

    def __init__(self, init_error=None, log_errors=None, finish_error=None, image_errors=None):
        self.init_error = init_error
        self.log_errors = list(log_errors or [])
        self.finish_error = finish_error
        self.image_errors = list(image_errors or [])
        self.init_calls = []
        self.log_calls = []
        self.finish_calls = 0
        self.image_calls = []

    def init(self, project, experiment_name=None, config=None):
        self.init_calls.append({
            'project': project,
            'experiment_name': experiment_name,
            'config': config,
        })
        if self.init_error is not None:
            raise self.init_error
        return object()

    def log(self, payload, step=None):
        # The call is recorded before a queued error is raised.
        self.log_calls.append((dict(payload), step))
        if self.log_errors:
            raise self.log_errors.pop(0)

    def Image(self, path, caption=None):
        self.image_calls.append({
            'path': path,
            'caption': caption,
        })
        if self.image_errors:
            raise self.image_errors.pop(0)
        return {
            'path': path,
            'caption': caption,
        }

    def finish(self):
        self.finish_calls += 1
        if self.finish_error is not None:
            raise self.finish_error


class TrainVLASwanLabLoggingTest(unittest.TestCase):
    """Exercise the SwanLab-related behaviour of ``train_vla.py``."""

    def test_default_config_keeps_swanlab_opt_in(self):
        """The shipped config must leave SwanLab disabled."""
        config_text = _CONFIG_PATH.read_text(encoding='utf-8')

        self.assertIn('use_swanlab: false', config_text)

    def test_log_rollout_trajectory_images_to_swanlab_uploads_episode_artifacts(self):
        """Episodes with a trajectory image are uploaded; the others are skipped."""
        module = self._load_train_vla_module()
        fake_swanlab = FakeSwanLab()

        module._log_rollout_trajectory_images_to_swanlab(
            fake_swanlab,
            {
                'episodes': [
                    {
                        'episode_index': 0,
                        'artifact_paths': {'trajectory_image': 'artifacts/episode_0_front.png'},
                    },
                    {
                        'episode_index': 3,
                        'artifact_paths': {'trajectory_image': 'artifacts/episode_3_front.png'},
                    },
                    {
                        'episode_index': 7,
                        'artifact_paths': {'trajectory_image': None},
                    },
                    {
                        'episode_index': 8,
                        'artifact_paths': {},
                    },
                ],
            },
            step=12,
            context_label='epoch 1 rollout',
        )

        # FakeSwanLab.Image returns the same mapping it records, so the
        # expected call records double as the expected logged values.
        episode_0_image = {
            'path': 'artifacts/episode_0_front.png',
            'caption': 'epoch 1 rollout trajectory image - episode 0 (front)',
        }
        episode_3_image = {
            'path': 'artifacts/episode_3_front.png',
            'caption': 'epoch 1 rollout trajectory image - episode 3 (front)',
        }
        self.assertEqual(fake_swanlab.image_calls, [episode_0_image, episode_3_image])
        self.assertIn(
            (
                {
                    'rollout/trajectory_image_episode_0': episode_0_image,
                    'rollout/trajectory_image_episode_3': episode_3_image,
                },
                12,
            ),
            fake_swanlab.log_calls,
        )

    def test_log_rollout_trajectory_images_to_swanlab_is_best_effort(self):
        """A failing image is skipped with a warning; later images still upload."""
        module = self._load_train_vla_module()
        fake_swanlab = FakeSwanLab(image_errors=[RuntimeError('decode failed')])

        with mock.patch.object(module.log, 'warning') as warning_mock:
            module._log_rollout_trajectory_images_to_swanlab(
                fake_swanlab,
                {
                    'episodes': [
                        {
                            'episode_index': 0,
                            'artifact_paths': {'trajectory_image': 'artifacts/bad_episode.png'},
                        },
                        {
                            'episode_index': 1,
                            'artifact_paths': {'trajectory_image': 'artifacts/good_episode.png'},
                        },
                    ],
                },
                step=7,
                context_label='checkpoint rollout',
            )

        bad_image = {
            'path': 'artifacts/bad_episode.png',
            'caption': 'checkpoint rollout trajectory image - episode 0 (front)',
        }
        good_image = {
            'path': 'artifacts/good_episode.png',
            'caption': 'checkpoint rollout trajectory image - episode 1 (front)',
        }
        self.assertEqual(fake_swanlab.image_calls, [bad_image, good_image])
        self.assertIn(
            ({'rollout/trajectory_image_episode_1': good_image}, 7),
            fake_swanlab.log_calls,
        )
        warning_messages = [call.args[0] for call in warning_mock.call_args_list]
        self.assertTrue(
            any('SwanLab rollout trajectory image upload prep failed' in message for message in warning_messages)
        )

    def _load_train_vla_module(self):
        """Execute ``train_vla.py`` from disk with stubbed hydra/omegaconf.

        The stubs are only present in ``sys.modules`` while the script body
        runs; a fresh module object is returned on every call.
        """
        hydra_module = types.ModuleType('hydra')
        hydra_utils_module = types.ModuleType('hydra.utils')
        hydra_utils_module.instantiate = lambda *args, **kwargs: None

        def hydra_main(**_kwargs):
            # ``@hydra.main(...)`` becomes a pass-through decorator.
            def decorator(func):
                return func

            return decorator

        hydra_module.main = hydra_main
        hydra_module.utils = hydra_utils_module

        class OmegaConfStub:
            _resolvers = {}

            @classmethod
            def has_resolver(cls, name):
                return name in cls._resolvers

            @classmethod
            def register_new_resolver(cls, name, resolver):
                cls._resolvers[name] = resolver

            @staticmethod
            def to_yaml(_cfg):
                return 'stub-config'

            @staticmethod
            def to_container(cfg, resolve=False):
                del resolve
                return dict(cfg)

            @staticmethod
            def create(cfg):
                return _to_attrdict(cfg)

        omegaconf_module = types.ModuleType('omegaconf')
        omegaconf_module.DictConfig = dict
        omegaconf_module.OmegaConf = OmegaConfStub

        module_name = 'train_vla_swanlab_test_module'
        spec = importlib.util.spec_from_file_location(module_name, _TRAIN_VLA_PATH)
        self.assertIsNotNone(spec, f'Could not build an import spec for {_TRAIN_VLA_PATH}')
        self.assertIsNotNone(spec.loader, f'No loader available for {_TRAIN_VLA_PATH}')
        module = importlib.util.module_from_spec(spec)

        stubbed_modules = {
            'hydra': hydra_module,
            'hydra.utils': hydra_utils_module,
            'omegaconf': omegaconf_module,
        }
        with mock.patch.dict(sys.modules, stubbed_modules):
            spec.loader.exec_module(module)
        return module

    def _make_cfg(self, *, use_swanlab=True, swanlab_run_name='smoke-run'):
        """Return a minimal attribute-style config for a two-step CPU run."""
        return AttrDict(
            train=AttrDict(
                device='cpu',
                batch_size=2,
                num_workers=0,
                val_split=0.25,
                seed=0,
                lr=1e-3,
                max_steps=2,
                log_freq=1,
                save_freq=1,
                warmup_steps=1,
                scheduler_type='constant',
                min_lr=0.0,
                grad_clip=1.0,
                weight_decay=0.0,
                pretrained_ckpt=None,
                resume_ckpt=None,
                use_swanlab=use_swanlab,
                swanlab_project='roboimi-vla-tests',
                swanlab_run_name=swanlab_run_name,
            ),
            data=AttrDict(
                camera_names=('front',),
            ),
            agent=AttrDict(
                _target_='fake.agent',
            ),
            eval=AttrDict(
                ckpt_path='unused.pt',
                num_episodes=1,
                max_timesteps=1,
                device='cpu',
                task_name='sim_transfer',
                camera_names=('front',),
                use_smoothing=False,
                smooth_alpha=0.3,
                verbose_action=False,
                headless=False,
            ),
        )

    def _get_run_training(self, module):
        """Fetch ``module._run_training``, failing the test if it is missing."""
        run_training = getattr(module, '_run_training', None)
        self.assertIsNotNone(run_training, 'Expected train_vla.py to expose a _run_training(cfg) helper')
        return run_training

    def _make_batch(self):
        """Return one all-zero batch holding a single sample."""
        return {
            'observation.front': torch.zeros(1, 3, 2, 2),
            'observation.state': torch.zeros(1, 4),
            'action': torch.zeros(1, 2),
            'action_is_pad': torch.zeros(1, 1, dtype=torch.bool),
        }

    def _loader_factory(self):
        """Build a ``DataLoader`` stand-in.

        ``shuffle=True`` selects the training batch, anything else the
        validation batch.
        """
        train_batch = self._make_batch()
        val_batch = self._make_batch()

        def factory(_dataset, *, shuffle, **_kwargs):
            return FakeLoader(train_batch if shuffle else val_batch)

        return factory

    @staticmethod
    def _make_checkpoint_state(*, loss=0.5, val_loss=0.25, **extra):
        """Return a fresh step-0 checkpoint payload with optional extra keys."""
        state = {
            'step': 0,
            'model_state_dict': FakeAgent().state_dict(),
            'optimizer_state_dict': {},
            'scheduler_state_dict': {},
            'loss': loss,
            'val_loss': val_loss,
        }
        state.update(extra)
        return state

    def test_run_training_logs_metrics_and_checkpoint_paths_to_swanlab(self):
        """A full run initialises SwanLab, logs metrics and final paths, then finishes."""
        module = self._load_train_vla_module()
        run_training = self._get_run_training(module)
        cfg = self._make_cfg()
        agent = FakeAgent()
        fake_swanlab = FakeSwanLab()
        fake_instantiate = _make_fake_instantiate(cfg, lambda: agent)
        fake_import_module = _make_swanlab_import_hook(fake_swanlab)

        with _temporary_cwd():
            with mock.patch.object(module, 'instantiate', side_effect=fake_instantiate), \
                    mock.patch.object(module, 'DataLoader', side_effect=self._loader_factory()), \
                    mock.patch.object(module, 'get_lr_schedule_with_warmup', return_value=FakeScheduler()), \
                    mock.patch.object(module, 'tqdm', side_effect=_fake_tqdm), \
                    mock.patch.object(module.torch, 'save', return_value=None), \
                    mock.patch.object(module.importlib, 'import_module', side_effect=fake_import_module):
                run_training(cfg)

        self.assertEqual(
            fake_swanlab.init_calls,
            [{
                'project': 'roboimi-vla-tests',
                'experiment_name': 'smoke-run',
                'config': {
                    'train': {
                        'device': 'cpu',
                        'batch_size': 2,
                        'num_workers': 0,
                        'val_split': 0.25,
                        'seed': 0,
                        'lr': 1e-3,
                        'max_steps': 2,
                        'log_freq': 1,
                        'save_freq': 1,
                        'warmup_steps': 1,
                        'scheduler_type': 'constant',
                        'min_lr': 0.0,
                        'grad_clip': 1.0,
                        'weight_decay': 0.0,
                        'pretrained_ckpt': None,
                        'resume_ckpt': None,
                        'use_swanlab': True,
                        'swanlab_project': 'roboimi-vla-tests',
                        'swanlab_run_name': 'smoke-run',
                    },
                    'data': {
                        'camera_names': ('front',),
                    },
                    'agent': {
                        '_target_': 'fake.agent',
                    },
                },
            }],
        )

        logged_keys = set().union(*(payload.keys() for payload, _step in fake_swanlab.log_calls))
        expected_keys = {
            'train/loss',
            'train/lr',
            'train/best_loss',
            'train/step',
            'val/loss',
            'final/checkpoint_path',
            'final/best_checkpoint_path',
        }
        # Comparing the difference names the missing keys on failure.
        self.assertEqual(expected_keys - logged_keys, set(), 'SwanLab never logged these keys')

        final_payload, final_step = fake_swanlab.log_calls[-1]
        self.assertEqual(final_step, cfg.train.max_steps)
        self.assertTrue(final_payload['final/checkpoint_path'].endswith('checkpoints/vla_model_final.pt'))
        self.assertTrue(final_payload['final/best_checkpoint_path'].endswith('checkpoints/vla_model_best.pt'))
        self.assertEqual(fake_swanlab.finish_calls, 1)

    def test_run_training_skips_swanlab_when_disabled(self):
        """With ``use_swanlab=False`` the ``swanlab`` package is never imported."""
        module = self._load_train_vla_module()
        run_training = self._get_run_training(module)
        cfg = self._make_cfg(use_swanlab=False)
        agent = FakeAgent()
        fake_instantiate = _make_fake_instantiate(cfg, lambda: agent)
        # ``module.importlib`` is the process-wide importlib, so only the
        # 'swanlab' name is rejected; unrelated dynamic imports made while
        # training runs must keep working.
        guarded_import_module = _make_swanlab_import_hook(
            error=AssertionError('swanlab import should not run'),
        )

        with _temporary_cwd():
            with mock.patch.object(module, 'instantiate', side_effect=fake_instantiate), \
                    mock.patch.object(module, 'DataLoader', side_effect=self._loader_factory()), \
                    mock.patch.object(module, 'get_lr_schedule_with_warmup', return_value=FakeScheduler()), \
                    mock.patch.object(module, 'tqdm', side_effect=_fake_tqdm), \
                    mock.patch.object(module.torch, 'save', return_value=None), \
                    mock.patch.object(module.importlib, 'import_module', side_effect=guarded_import_module):
                run_training(cfg)

    def test_run_training_finishes_swanlab_when_exception_happens_after_init(self):
        """``finish`` still runs once when training fails after SwanLab init."""
        module = self._load_train_vla_module()
        run_training = self._get_run_training(module)
        cfg = self._make_cfg()
        fake_swanlab = FakeSwanLab()
        fake_import_module = _make_swanlab_import_hook(fake_swanlab)

        with _temporary_cwd():
            with mock.patch.object(module, 'instantiate', side_effect=RuntimeError('dataset boom')), \
                    mock.patch.object(module.importlib, 'import_module', side_effect=fake_import_module):
                with self.assertRaisesRegex(RuntimeError, 'dataset boom'):
                    run_training(cfg)

        self.assertEqual(fake_swanlab.finish_calls, 1)

    def test_run_training_warns_and_continues_when_swanlab_log_and_finish_fail(self):
        """SwanLab ``log`` / ``finish`` errors become warnings, not failures."""
        module = self._load_train_vla_module()
        run_training = self._get_run_training(module)
        cfg = self._make_cfg()
        agent = FakeAgent()
        fake_swanlab = FakeSwanLab(
            log_errors=[RuntimeError('log backend hiccup')],
            finish_error=RuntimeError('finish backend hiccup'),
        )
        fake_instantiate = _make_fake_instantiate(cfg, lambda: agent)
        fake_import_module = _make_swanlab_import_hook(fake_swanlab)

        with _temporary_cwd():
            with mock.patch.object(module, 'instantiate', side_effect=fake_instantiate), \
                    mock.patch.object(module, 'DataLoader', side_effect=self._loader_factory()), \
                    mock.patch.object(module, 'get_lr_schedule_with_warmup', return_value=FakeScheduler()), \
                    mock.patch.object(module, 'tqdm', side_effect=_fake_tqdm), \
                    mock.patch.object(module.torch, 'save', return_value=None), \
                    mock.patch.object(module.importlib, 'import_module', side_effect=fake_import_module), \
                    mock.patch.object(module.log, 'warning') as warning_mock:
                run_training(cfg)

        warning_messages = [call.args[0] for call in warning_mock.call_args_list]
        self.assertTrue(any('SwanLab log failed' in message for message in warning_messages))
        self.assertTrue(any('SwanLab finish failed' in message for message in warning_messages))
        self.assertEqual(fake_swanlab.finish_calls, 1)

    def test_run_training_resume_restores_best_rollout_baseline_from_best_checkpoint(self):
        """A resumed run keeps the existing best checkpoint when rollouts score lower.

        The stored best checkpoint reports a rollout reward of 5.0 while the
        patched evaluation returns 3.0, so the best file must not be rewritten.
        """
        module = self._load_train_vla_module()
        run_training = self._get_run_training(module)
        cfg = self._make_cfg()
        cfg.train.max_steps = 2
        cfg.train.save_freq = 1
        cfg.train.rollout_validate_on_checkpoint = True
        fake_swanlab = FakeSwanLab()
        fake_optimizer = FakeOptimizer(lr=cfg.train.lr)
        fake_scheduler = FakeScheduler()
        fake_instantiate = _make_fake_instantiate(cfg, FakeAgent)
        fake_import_module = _make_swanlab_import_hook(fake_swanlab)
        saved_paths = []

        def fake_torch_save(payload, path):
            del payload
            saved_paths.append(str(path))
            return None

        with _temporary_cwd():
            checkpoint_dir = Path('checkpoints')
            checkpoint_dir.mkdir()
            resume_path = checkpoint_dir / 'vla_model_step_0.pt'
            resume_path.write_bytes(b'resume')
            best_path = checkpoint_dir / 'vla_model_best.pt'
            best_path.write_bytes(b'best')
            cfg.train.resume_ckpt = str(resume_path)

            fake_torch_load = _make_fake_torch_load({
                resume_path: self._make_checkpoint_state(),
                best_path: self._make_checkpoint_state(rollout_avg_reward=5.0),
            })

            with mock.patch.object(module, 'instantiate', side_effect=fake_instantiate), \
                    mock.patch.object(module, 'DataLoader', side_effect=self._loader_factory()), \
                    mock.patch.object(module, 'build_training_optimizer', return_value=fake_optimizer), \
                    mock.patch.object(module, 'get_lr_schedule_with_warmup', return_value=fake_scheduler), \
                    mock.patch.object(module, 'tqdm', side_effect=_fake_tqdm), \
                    mock.patch.object(module.torch, 'save', side_effect=fake_torch_save), \
                    mock.patch.object(module.torch, 'load', side_effect=fake_torch_load), \
                    mock.patch.object(module.importlib, 'import_module', side_effect=fake_import_module), \
                    mock.patch('roboimi.demos.vla_scripts.eval_vla._run_eval', return_value={'avg_reward': 3.0}):
                run_training(cfg)

        final_payload, final_step = fake_swanlab.log_calls[-1]
        self.assertEqual(final_step, cfg.train.max_steps)
        self.assertTrue(final_payload['final/best_checkpoint_path'].endswith('checkpoints/vla_model_best.pt'))
        self.assertFalse(any(path.endswith('checkpoints/vla_model_best.pt') for path in saved_paths))

    def test_run_training_resume_ignores_best_checkpoint_without_rollout_metric(self):
        """A best checkpoint lacking a rollout metric is not used on resume.

        The final payload must report the resume checkpoint as the best path.
        """
        module = self._load_train_vla_module()
        run_training = self._get_run_training(module)
        cfg = self._make_cfg()
        cfg.train.max_steps = 1
        fake_swanlab = FakeSwanLab()
        fake_optimizer = FakeOptimizer(lr=cfg.train.lr)
        fake_scheduler = FakeScheduler()
        fake_instantiate = _make_fake_instantiate(cfg, FakeAgent)
        fake_import_module = _make_swanlab_import_hook(fake_swanlab)

        with _temporary_cwd():
            checkpoint_dir = Path('checkpoints')
            checkpoint_dir.mkdir()
            resume_path = checkpoint_dir / 'vla_model_step_0.pt'
            resume_path.write_bytes(b'resume')
            best_path = checkpoint_dir / 'vla_model_best.pt'
            best_path.write_bytes(b'stale')
            cfg.train.resume_ckpt = str(resume_path)

            fake_torch_load = _make_fake_torch_load({
                resume_path: self._make_checkpoint_state(),
                best_path: self._make_checkpoint_state(loss=0.4, val_loss=0.2),
            })

            with mock.patch.object(module, 'instantiate', side_effect=fake_instantiate), \
                    mock.patch.object(module, 'DataLoader', side_effect=self._loader_factory()), \
                    mock.patch.object(module, 'build_training_optimizer', return_value=fake_optimizer), \
                    mock.patch.object(module, 'get_lr_schedule_with_warmup', return_value=fake_scheduler), \
                    mock.patch.object(module, 'tqdm', side_effect=_fake_tqdm), \
                    mock.patch.object(module.torch, 'save', return_value=None), \
                    mock.patch.object(module.torch, 'load', side_effect=fake_torch_load), \
                    mock.patch.object(module.importlib, 'import_module', side_effect=fake_import_module):
                run_training(cfg)

        final_payload, final_step = fake_swanlab.log_calls[-1]
        self.assertEqual(final_step, cfg.train.max_steps)
        self.assertEqual(final_payload['final/best_checkpoint_path'], 'checkpoints/vla_model_step_0.pt')

    def test_run_training_ignores_stale_best_checkpoint_file_on_fresh_non_resume_run(self):
        """A leftover best-checkpoint file is not reported by a fresh run."""
        module = self._load_train_vla_module()
        run_training = self._get_run_training(module)
        cfg = self._make_cfg()
        cfg.train.max_steps = 1
        fake_swanlab = FakeSwanLab()
        fake_instantiate = _make_fake_instantiate(cfg, FakeAgent)
        fake_import_module = _make_swanlab_import_hook(fake_swanlab)

        with _temporary_cwd():
            checkpoint_dir = Path('checkpoints')
            checkpoint_dir.mkdir()
            (checkpoint_dir / 'vla_model_best.pt').write_bytes(b'stale-best')

            with mock.patch.object(module, 'instantiate', side_effect=fake_instantiate), \
                    mock.patch.object(module, 'DataLoader', side_effect=self._loader_factory()), \
                    mock.patch.object(module, 'get_lr_schedule_with_warmup', return_value=FakeScheduler()), \
                    mock.patch.object(module, 'tqdm', side_effect=_fake_tqdm), \
                    mock.patch.object(module.torch, 'save', return_value=None), \
                    mock.patch.object(module.importlib, 'import_module', side_effect=fake_import_module):
                run_training(cfg)

        final_payload, final_step = fake_swanlab.log_calls[-1]
        self.assertEqual(final_step, cfg.train.max_steps)
        self.assertEqual(final_payload['final/best_checkpoint_path'], '')

    def test_run_training_fails_fast_when_swanlab_import_is_unavailable(self):
        """A missing ``swanlab`` package aborts before anything is instantiated."""
        module = self._load_train_vla_module()
        run_training = self._get_run_training(module)
        cfg = self._make_cfg()
        fake_import_module = _make_swanlab_import_hook(error=ImportError('missing swanlab'))

        with mock.patch.object(module, 'instantiate', side_effect=AssertionError('instantiate should not run')), \
                mock.patch.object(module.importlib, 'import_module', side_effect=fake_import_module):
            with self.assertRaisesRegex(RuntimeError, 'SwanLab'):
                run_training(cfg)

    def test_run_training_fails_fast_when_swanlab_init_fails(self):
        """A failing ``swanlab.init`` aborts the run without calling ``finish``."""
        module = self._load_train_vla_module()
        run_training = self._get_run_training(module)
        cfg = self._make_cfg()
        fake_swanlab = FakeSwanLab(init_error=RuntimeError('not logged in'))
        fake_import_module = _make_swanlab_import_hook(fake_swanlab)

        with mock.patch.object(module, 'instantiate', side_effect=AssertionError('instantiate should not run')), \
                mock.patch.object(module.importlib, 'import_module', side_effect=fake_import_module):
            with self.assertRaisesRegex(RuntimeError, 'not logged in'):
                run_training(cfg)

        self.assertEqual(fake_swanlab.finish_calls, 0)


if __name__ == '__main__':
    unittest.main()