"""Unit tests for the rollout-validation hooks in ``train_vla`` / ``eval_vla``.

All heavy collaborators (dataset, agent, DataLoader, optimizer, scheduler,
progress bar, ``torch.save``, rollout evaluation) are replaced by lightweight
fakes so ``train_vla._run_training`` can be exercised quickly on CPU.
"""

import contextlib
import os
import tempfile
import unittest
from copy import deepcopy
from pathlib import Path
from unittest import mock

import numpy as np
import torch
from omegaconf import OmegaConf
from torch import nn

from roboimi.demos.vla_scripts import eval_vla, train_vla


_DEFAULT_CONFIG_RELATIVE_PATH = Path('roboimi/vla/conf/config.yaml')


class _FakeDataset:
    """Dataset stand-in that only reports a fixed length of 4."""

    def __len__(self):
        return 4


class _FakeLoader:
    """Iterable loader that yields the same ``batch`` object ``length`` times."""

    def __init__(self, batch, length=1):
        self._batches = [batch] * length

    def __len__(self):
        return len(self._batches)

    def __iter__(self):
        return iter(self._batches)


class _FakeOptimizer:
    """Optimizer stand-in exposing ``param_groups`` and no-op step/zero_grad."""

    def __init__(self, lr=1e-3):
        self.param_groups = [{'lr': lr}]

    def zero_grad(self):
        return None

    def step(self):
        return None

    def state_dict(self):
        return {}

    def load_state_dict(self, state_dict):
        del state_dict
        return None


class _FakeScheduler:
    """LR-scheduler stand-in that only counts ``step()`` calls."""

    def __init__(self):
        self.step_calls = 0

    def step(self):
        self.step_calls += 1

    def state_dict(self):
        return {}

    def load_state_dict(self, state_dict):
        del state_dict
        return None


class _FakeProgressBar:
    """``tqdm`` replacement that materialises the iterable and records postfixes."""

    def __init__(self, iterable):
        self._items = list(iterable)
        self.postfix_calls = []

    def __iter__(self):
        return iter(self._items)

    def set_postfix(self, values):
        self.postfix_calls.append(values)


class _FakeAgent(nn.Module):
    """Agent with a single trainable scalar; loss is ``(weight - 0.5) ** 2``."""

    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.tensor(0.0))

    def to(self, device):
        del device
        return self

    def compute_loss(self, agent_input):
        del agent_input
        return (self.weight - torch.tensor(0.5)).pow(2)

    def get_normalization_stats(self):
        return {}


class _SequentialLossAgent(nn.Module):
    """Agent whose successive ``compute_loss`` calls return the given ``losses``.

    Raises ``IndexError`` if called more times than there are losses.
    """

    def __init__(self, losses):
        super().__init__()
        self.weight = nn.Parameter(torch.tensor(0.0))
        self._losses = list(losses)
        self._index = 0

    def to(self, device):
        del device
        return self

    def compute_loss(self, agent_input):
        del agent_input
        loss_value = self._losses[self._index]
        self._index += 1
        # Multiply by 0 so the returned loss stays attached to the parameter graph.
        return (self.weight * 0) + torch.tensor(float(loss_value))

    def get_normalization_stats(self):
        return {}


class _FakeEvalAgent:
    """Evaluation-agent stand-in that always emits a zero action of size 2."""

    def __init__(self):
        self.reset_calls = 0

    def eval(self):
        return self

    def to(self, device):
        del device
        return self

    def reset(self):
        self.reset_calls += 1

    def select_action(self, observation):
        del observation
        return torch.zeros(2)


class _FakeEvalEnv:
    """Simulation-env stand-in returning fixed image/qpos observations."""

    def reset(self, box_pos):
        self.box_pos = box_pos

    def _get_image_obs(self):
        return {
            'images': {
                'front': np.zeros((8, 8, 3), dtype=np.uint8),
            }
        }

    def _get_qpos_obs(self):
        return {'qpos': np.zeros(4, dtype=np.float32)}

    def render(self):
        raise AssertionError('render should not be called in this helper delegation test')


def _find_default_config_path():
    """Locate ``roboimi/vla/conf/config.yaml`` without depending solely on the cwd.

    The current working directory is tried first (the historical behaviour),
    then every ancestor directory of this test file. If nothing matches, the
    bare relative path is returned so the failure message stays familiar.
    """
    for base in (Path.cwd(), *Path(__file__).resolve().parents):
        candidate = base / _DEFAULT_CONFIG_RELATIVE_PATH
        if candidate.is_file():
            return candidate
    return _DEFAULT_CONFIG_RELATIVE_PATH


def _make_fake_batch():
    """Return a minimal single-sample training batch understood by ``_run_training``."""
    return {
        'observation.front': torch.zeros(1, 3, 2, 2),
        'observation.state': torch.zeros(1, 4),
        'action': torch.zeros(1, 2),
        'action_is_pad': torch.zeros(1, 1, dtype=torch.bool),
    }


def _make_fake_instantiate(cfg, *, dataset_factory=_FakeDataset, agent_factory=_FakeAgent,
                           on_dataset_kwargs=None):
    """Build a replacement for ``train_vla.instantiate``.

    Args:
        cfg: The config whose ``data`` / ``agent`` nodes are recognised by identity.
        dataset_factory: Zero-arg callable producing the dataset for ``cfg.data``.
        agent_factory: Zero-arg callable producing the agent for ``cfg.agent``.
        on_dataset_kwargs: Optional callback receiving the kwargs passed when
            the dataset is instantiated (used to capture overrides).

    Returns:
        A callable that raises ``AssertionError`` for any other config node.
    """

    def fake_instantiate(config_node, **kwargs):
        if config_node is cfg.data:
            if on_dataset_kwargs is not None:
                on_dataset_kwargs(kwargs)
            return dataset_factory()
        if config_node is cfg.agent:
            return agent_factory()
        raise AssertionError(f'unexpected instantiate config: {config_node!r}')

    return fake_instantiate


def _make_fake_dataloader(length=1):
    """Build a ``DataLoader`` replacement yielding ``length`` copies of the fake batch.

    ``shuffle`` is a required keyword so the test still checks that the
    training code passes it explicitly.
    """

    def fake_dataloader(_dataset, *, shuffle, **_kwargs):
        del shuffle, _kwargs
        return _FakeLoader(_make_fake_batch(), length=length)

    return fake_dataloader


def _make_recording_torch_save(saved_checkpoints):
    """Build a ``torch.save`` replacement that appends ``(str(path), deep-copied payload)``."""

    def fake_torch_save(payload, path):
        saved_checkpoints.append((str(path), deepcopy(payload)))
        return None

    return fake_torch_save


@contextlib.contextmanager
def _working_directory(path):
    """Temporarily change the process working directory to ``path``."""
    previous_cwd = os.getcwd()
    os.chdir(path)
    try:
        yield
    finally:
        os.chdir(previous_cwd)


def _run_training_in_tempdir(cfg, *, instantiate, dataloader, torch_save=None, extra_patches=()):
    """Run ``train_vla._run_training(cfg)`` inside a throwaway cwd with the standard fakes.

    Args:
        cfg: Training config passed straight to ``_run_training``.
        instantiate: Replacement for ``train_vla.instantiate``.
        dataloader: Replacement for ``train_vla.DataLoader``.
        torch_save: Optional replacement for ``torch.save``; when ``None`` the
            patched ``save`` simply returns ``None``.
        extra_patches: Additional (not yet started) ``mock.patch`` objects to
            apply for the duration of the run.
    """
    if torch_save is None:
        save_patch_kwargs = {'return_value': None}
    else:
        save_patch_kwargs = {'side_effect': torch_save}

    # Exit order: patches are undone first, then the cwd is restored, and only
    # then is the temporary directory removed (it must not be the cwd then).
    with tempfile.TemporaryDirectory() as tempdir, \
            _working_directory(tempdir), \
            contextlib.ExitStack() as stack:
        stack.enter_context(mock.patch.object(train_vla, 'instantiate', side_effect=instantiate))
        stack.enter_context(mock.patch.object(train_vla, 'DataLoader', side_effect=dataloader))
        stack.enter_context(
            mock.patch.object(train_vla, 'build_training_optimizer', return_value=_FakeOptimizer(cfg.train.lr))
        )
        stack.enter_context(
            mock.patch.object(train_vla, 'get_lr_schedule_with_warmup', return_value=_FakeScheduler())
        )
        stack.enter_context(
            mock.patch.object(
                train_vla, 'tqdm', side_effect=lambda iterable, **kwargs: _FakeProgressBar(iterable)
            )
        )
        stack.enter_context(mock.patch.object(train_vla.torch, 'save', **save_patch_kwargs))
        for patcher in extra_patches:
            stack.enter_context(patcher)
        train_vla._run_training(cfg)


class TrainVLARolloutValidationTest(unittest.TestCase):
    """Covers rollout-validation scheduling, best-checkpoint selection and loader setup."""

    def test_default_train_config_uses_full_dataset_and_epoch_rollout_validation(self):
        cfg = OmegaConf.load(_find_default_config_path())

        self.assertEqual(cfg.train.val_split, 0.0)
        self.assertGreater(cfg.train.batch_size, 8)
        self.assertGreater(float(cfg.train.lr), 5e-5)
        self.assertGreater(cfg.train.num_workers, 8)
        self.assertEqual(cfg.train.rollout_val_freq_epochs, 50)
        self.assertEqual(cfg.train.rollout_device, cfg.train.device)
        self.assertIsNone(cfg.train.rollout_num_workers)
        self.assertIsNone(cfg.train.rollout_cuda_devices)

    def test_run_training_rollout_validation_propagates_gpu_parallel_settings(self):
        cfg = OmegaConf.create(
            {
                'train': {
                    'device': 'cpu',
                    'batch_size': 1,
                    'num_workers': 0,
                    'val_split': 0.0,
                    'seed': 0,
                    'lr': 1e-3,
                    'max_steps': 2,
                    'log_freq': 1,
                    'save_freq': 1000,
                    'warmup_steps': 1,
                    'scheduler_type': 'constant',
                    'min_lr': 0.0,
                    'grad_clip': 1.0,
                    'weight_decay': 0.0,
                    'pretrained_ckpt': None,
                    'resume_ckpt': None,
                    'use_swanlab': False,
                    'rollout_val_freq_epochs': 2,
                    'rollout_num_episodes': 5,
                    'rollout_device': 'cuda',
                    'rollout_num_workers': 4,
                    'rollout_cuda_devices': [0, 1],
                    'rollout_response_timeout_s': 123.0,
                    'rollout_server_startup_timeout_s': 456.0,
                },
                'data': {
                    'camera_names': ['front'],
                },
                'agent': {
                    '_target_': 'fake.agent',
                },
                'eval': {
                    'ckpt_path': 'unused.pt',
                    'num_episodes': 99,
                    'max_timesteps': 1,
                    'device': 'cpu',
                    'task_name': 'sim_transfer',
                    'camera_names': ['front'],
                    'use_smoothing': False,
                    'smooth_alpha': 0.3,
                    'verbose_action': False,
                    'headless': False,
                },
            }
        )
        rollout_mock = mock.Mock(return_value={'avg_reward': 1.0})

        _run_training_in_tempdir(
            cfg,
            instantiate=_make_fake_instantiate(cfg),
            dataloader=_make_fake_dataloader(length=1),
            extra_patches=[
                mock.patch.object(eval_vla, '_run_eval', rollout_mock, create=True),
            ],
        )

        # Fail with a clear message instead of AttributeError on ``None.args``.
        rollout_mock.assert_called()
        rollout_cfg = rollout_mock.call_args.args[0]
        self.assertEqual(rollout_cfg.eval.device, 'cuda')
        self.assertEqual(rollout_cfg.eval.num_workers, 4)
        self.assertEqual(list(rollout_cfg.eval.cuda_devices), [0, 1])
        self.assertEqual(float(rollout_cfg.eval.response_timeout_s), 123.0)
        self.assertEqual(float(rollout_cfg.eval.server_startup_timeout_s), 456.0)
        self.assertTrue(rollout_cfg.eval.headless)
        self.assertEqual(rollout_cfg.eval.num_episodes, 5)
        self.assertFalse(rollout_cfg.eval.record_video)
        self.assertTrue(rollout_cfg.eval.save_summary_json)
        self.assertTrue(rollout_cfg.eval.save_trajectory_image)

    def test_training_passes_backbone_image_resize_override_to_dataset_instantiation(self):
        cfg = OmegaConf.create(
            {
                'agent': {
                    'vision_backbone': {
                        'dataset_image_resize_shape': None,
                    },
                    'normalization_type': 'min_max',
                },
                'data': {
                    'dataset_dir': 'unused',
                    'camera_names': ['front'],
                },
                'train': {
                    'batch_size': 2,
                    'lr': 1e-4,
                    'max_steps': 0,
                    'device': 'cpu',
                    'disable_cudnn': False,
                    'num_workers': 0,
                    'val_split': 0.0,
                    'seed': 42,
                    'log_freq': 1,
                    'save_freq': 10,
                    'use_swanlab': False,
                    'rollout_val_freq_epochs': 0,
                    'rollout_validate_on_checkpoint': False,
                    'rollout_num_episodes': 1,
                    'warmup_steps': 1,
                    'scheduler_type': 'constant',
                    'min_lr': 1e-6,
                    'weight_decay': 1e-5,
                    'grad_clip': 1.0,
                    'pretrained_ckpt': None,
                },
                'eval': {
                    'ckpt_path': 'unused.pt',
                    'num_episodes': 1,
                    'headless': True,
                    'device': 'cpu',
                    'verbose_action': False,
                },
                'experiment': {},
            }
        )
        captured_dataset_kwargs = {}

        _run_training_in_tempdir(
            cfg,
            instantiate=_make_fake_instantiate(cfg, on_dataset_kwargs=captured_dataset_kwargs.update),
            dataloader=_make_fake_dataloader(length=1),
            extra_patches=[
                mock.patch.object(train_vla, '_init_swanlab', return_value=None),
                mock.patch.object(train_vla, '_finish_swanlab', return_value=None),
            ],
        )

        self.assertIn('image_resize_shape', captured_dataset_kwargs)
        self.assertIsNone(captured_dataset_kwargs['image_resize_shape'])

    def test_eval_main_delegates_to_plain_run_eval_helper(self):
        cfg = OmegaConf.create(
            {
                'agent': {},
                'eval': {
                    'ckpt_path': 'checkpoints/vla_model_step_1.pt',
                    'num_episodes': 1,
                    'max_timesteps': 1,
                    'device': 'cpu',
                    'task_name': 'sim_transfer',
                    'camera_names': ['front'],
                    'use_smoothing': False,
                    'smooth_alpha': 0.3,
                    'verbose_action': False,
                    'headless': True,
                },
            }
        )
        run_eval_mock = mock.Mock()

        with mock.patch.object(eval_vla, '_run_eval', run_eval_mock, create=True), \
                mock.patch.object(eval_vla, 'load_checkpoint', return_value=(_FakeEvalAgent(), None)), \
                mock.patch.object(eval_vla, 'make_sim_env', return_value=_FakeEvalEnv()), \
                mock.patch.object(eval_vla, 'sample_transfer_pose', return_value=np.zeros(3)), \
                mock.patch.object(eval_vla, 'execute_policy_action'), \
                mock.patch.object(eval_vla, 'tqdm', side_effect=lambda iterable, **kwargs: iterable):
            eval_vla.main.__wrapped__(cfg)

        run_eval_mock.assert_called_once_with(cfg)

    def test_run_training_rollout_validation_runs_every_50_epochs_and_uses_avg_reward_metric(self):
        cfg = OmegaConf.create(
            {
                'train': {
                    'device': 'cpu',
                    'batch_size': 1,
                    'num_workers': 0,
                    'val_split': 0.0,
                    'seed': 0,
                    'lr': 1e-3,
                    'max_steps': 100,
                    'log_freq': 1,
                    'save_freq': 1000,
                    'warmup_steps': 1,
                    'scheduler_type': 'constant',
                    'min_lr': 0.0,
                    'grad_clip': 1.0,
                    'weight_decay': 0.0,
                    'pretrained_ckpt': None,
                    'resume_ckpt': None,
                    'use_swanlab': False,
                    'rollout_val_freq_epochs': 50,
                    'rollout_num_episodes': 3,
                },
                'data': {
                    'camera_names': ['front'],
                },
                'agent': {
                    '_target_': 'fake.agent',
                },
                'eval': {
                    'ckpt_path': 'unused.pt',
                    'num_episodes': 99,
                    'max_timesteps': 1,
                    'device': 'cpu',
                    'task_name': 'sim_transfer',
                    'camera_names': ['front'],
                    'use_smoothing': False,
                    'smooth_alpha': 0.3,
                    'verbose_action': False,
                    'headless': False,
                },
            }
        )
        agent = _FakeAgent()
        rollout_mock = mock.Mock(
            side_effect=[
                {
                    'avg_reward': 2.0,
                    'episodes': [
                        {
                            'episode_index': 0,
                            'artifact_paths': {'trajectory_image': 'artifacts/epoch_49_front.png'},
                        },
                    ],
                },
                {
                    'avg_reward': 1.0,
                    'episodes': [
                        {
                            'episode_index': 0,
                            'artifact_paths': {'trajectory_image': 'artifacts/epoch_99_front.png'},
                        },
                    ],
                },
            ]
        )
        swanlab_log_mock = mock.Mock()
        saved_checkpoints = []

        _run_training_in_tempdir(
            cfg,
            instantiate=_make_fake_instantiate(cfg, agent_factory=lambda: agent),
            dataloader=_make_fake_dataloader(length=1),
            torch_save=_make_recording_torch_save(saved_checkpoints),
            extra_patches=[
                mock.patch.object(train_vla, '_log_to_swanlab', swanlab_log_mock),
                mock.patch.object(eval_vla, '_run_eval', rollout_mock, create=True),
                # The training hook must call the plain helper, not the Hydra-wrapped main.
                mock.patch.object(
                    eval_vla.main,
                    '__wrapped__',
                    side_effect=AssertionError('training hook should call eval_vla._run_eval'),
                ),
            ],
        )

        self.assertEqual(rollout_mock.call_count, 2)
        first_rollout_cfg = rollout_mock.call_args_list[0].args[0]
        second_rollout_cfg = rollout_mock.call_args_list[1].args[0]
        self.assertTrue(first_rollout_cfg.eval.ckpt_path.endswith('checkpoints/vla_model_step_49.pt'))
        self.assertTrue(second_rollout_cfg.eval.ckpt_path.endswith('checkpoints/vla_model_step_99.pt'))
        self.assertEqual(first_rollout_cfg.eval.num_episodes, 3)
        self.assertTrue(first_rollout_cfg.eval.headless)
        self.assertEqual(first_rollout_cfg.eval.device, 'cpu')
        self.assertFalse(first_rollout_cfg.eval.verbose_action)
        self.assertFalse(first_rollout_cfg.eval.record_video)
        self.assertTrue(first_rollout_cfg.eval.save_trajectory_image)
        self.assertEqual(first_rollout_cfg.eval.trajectory_image_camera_name, 'front')

        # The caller's config must not be mutated by the rollout overrides.
        self.assertEqual(cfg.eval.ckpt_path, 'unused.pt')
        self.assertEqual(cfg.eval.num_episodes, 99)
        self.assertFalse(cfg.eval.headless)
        self.assertEqual(cfg.eval.device, 'cpu')
        self.assertFalse(cfg.eval.verbose_action)
        self.assertNotIn('save_trajectory_image', cfg.eval)
        self.assertNotIn('trajectory_image_camera_name', cfg.eval)

        rollout_reward_logs = [
            call.args[1]['rollout/avg_reward']
            for call in swanlab_log_mock.call_args_list
            if len(call.args) >= 2 and 'rollout/avg_reward' in call.args[1]
        ]
        self.assertEqual(rollout_reward_logs, [2.0, 1.0])

        best_model_saves = [
            payload
            for path, payload in saved_checkpoints
            if path.endswith('checkpoints/vla_model_best.pt')
        ]
        self.assertEqual(len(best_model_saves), 1)
        self.assertEqual(best_model_saves[0]['rollout_avg_reward'], 2.0)

    def test_run_training_keeps_loss_based_best_checkpoint_until_first_rollout_metric_exists(self):
        cfg = OmegaConf.create(
            {
                'train': {
                    'device': 'cpu',
                    'batch_size': 1,
                    'num_workers': 0,
                    'val_split': 0.0,
                    'seed': 0,
                    'lr': 1e-3,
                    'max_steps': 5,
                    'log_freq': 1,
                    'save_freq': 2,
                    'warmup_steps': 1,
                    'scheduler_type': 'constant',
                    'min_lr': 0.0,
                    'grad_clip': 1.0,
                    'weight_decay': 0.0,
                    'pretrained_ckpt': None,
                    'resume_ckpt': None,
                    'use_swanlab': False,
                    'rollout_val_freq_epochs': 50,
                    'rollout_num_episodes': 3,
                },
                'data': {
                    'camera_names': ['front'],
                },
                'agent': {
                    '_target_': 'fake.agent',
                },
                'eval': {
                    'ckpt_path': 'unused.pt',
                    'num_episodes': 99,
                    'max_timesteps': 1,
                    'device': 'cpu',
                    'task_name': 'sim_transfer',
                    'camera_names': ['front'],
                    'use_smoothing': False,
                    'smooth_alpha': 0.3,
                    'verbose_action': False,
                    'headless': False,
                },
            }
        )
        saved_checkpoints = []
        rollout_mock = mock.Mock()

        _run_training_in_tempdir(
            cfg,
            instantiate=_make_fake_instantiate(cfg),
            dataloader=_make_fake_dataloader(length=5),
            torch_save=_make_recording_torch_save(saved_checkpoints),
            extra_patches=[
                mock.patch.object(eval_vla, '_run_eval', rollout_mock, create=True),
            ],
        )

        self.assertEqual(rollout_mock.call_count, 0)
        best_model_saves = [
            payload
            for path, payload in saved_checkpoints
            if path.endswith('checkpoints/vla_model_best.pt')
        ]
        self.assertEqual(len(best_model_saves), 1)
        self.assertIsNone(best_model_saves[0]['rollout_avg_reward'])

    def test_run_training_disables_drop_last_when_train_set_is_smaller_than_batch_size(self):
        cfg = OmegaConf.create(
            {
                'train': {
                    'device': 'cpu',
                    'batch_size': 8,
                    'num_workers': 0,
                    'val_split': 0.0,
                    'seed': 0,
                    'lr': 1e-3,
                    'max_steps': 1,
                    'log_freq': 1,
                    'save_freq': 10,
                    'warmup_steps': 1,
                    'scheduler_type': 'constant',
                    'min_lr': 0.0,
                    'grad_clip': 1.0,
                    'weight_decay': 0.0,
                    'pretrained_ckpt': None,
                    'resume_ckpt': None,
                    'use_swanlab': False,
                    'rollout_val_freq_epochs': 50,
                    'rollout_num_episodes': 3,
                },
                'data': {
                    'camera_names': ['front'],
                },
                'agent': {
                    '_target_': 'fake.agent',
                },
                'eval': {
                    'ckpt_path': 'unused.pt',
                    'num_episodes': 99,
                    'max_timesteps': 1,
                    'device': 'cpu',
                    'task_name': 'sim_transfer',
                    'camera_names': ['front'],
                    'use_smoothing': False,
                    'smooth_alpha': 0.3,
                    'verbose_action': False,
                    'headless': False,
                },
            }
        )
        dataloader_calls = []

        # ``drop_last`` is a required keyword: the training code must pass it explicitly.
        def fake_dataloader(dataset, *, shuffle, drop_last, **_kwargs):
            dataloader_calls.append({
                'shuffle': shuffle,
                'drop_last': drop_last,
                'dataset_len': len(dataset),
            })
            return _FakeLoader(_make_fake_batch(), length=1)

        _run_training_in_tempdir(
            cfg,
            instantiate=_make_fake_instantiate(cfg),
            dataloader=fake_dataloader,
        )

        train_loader_calls = [call for call in dataloader_calls if call['shuffle']]
        self.assertEqual(len(train_loader_calls), 1)
        self.assertFalse(train_loader_calls[0]['drop_last'])

    def test_run_training_disables_persistent_workers_for_train_and_val_loaders(self):
        cfg = OmegaConf.create(
            {
                'train': {
                    'device': 'cpu',
                    'batch_size': 2,
                    'num_workers': 2,
                    'val_split': 0.25,
                    'seed': 0,
                    'lr': 1e-3,
                    'max_steps': 1,
                    'log_freq': 1,
                    'save_freq': 10,
                    'warmup_steps': 1,
                    'scheduler_type': 'constant',
                    'min_lr': 0.0,
                    'grad_clip': 1.0,
                    'weight_decay': 0.0,
                    'pretrained_ckpt': None,
                    'resume_ckpt': None,
                    'use_swanlab': False,
                    'rollout_val_freq_epochs': 50,
                    'rollout_num_episodes': 3,
                },
                'data': {
                    'camera_names': ['front'],
                },
                'agent': {
                    '_target_': 'fake.agent',
                },
                'eval': {
                    'ckpt_path': 'unused.pt',
                    'num_episodes': 99,
                    'max_timesteps': 1,
                    'device': 'cpu',
                    'task_name': 'sim_transfer',
                    'camera_names': ['front'],
                    'use_smoothing': False,
                    'smooth_alpha': 0.3,
                    'verbose_action': False,
                    'headless': False,
                },
            }
        )
        dataloader_calls = []

        # ``persistent_workers`` / ``num_workers`` are required keywords on purpose.
        def fake_dataloader(_dataset, *, shuffle, persistent_workers, num_workers, **_kwargs):
            dataloader_calls.append({
                'shuffle': shuffle,
                'num_workers': num_workers,
                'persistent_workers': persistent_workers,
            })
            return _FakeLoader(_make_fake_batch(), length=1)

        _run_training_in_tempdir(
            cfg,
            instantiate=_make_fake_instantiate(cfg),
            dataloader=fake_dataloader,
        )

        self.assertEqual(len(dataloader_calls), 2)
        self.assertEqual([call['shuffle'] for call in dataloader_calls], [True, False])
        self.assertTrue(all(call['num_workers'] == 2 for call in dataloader_calls))
        self.assertTrue(all(call['persistent_workers'] is False for call in dataloader_calls))

    def test_run_training_uses_loss_best_until_first_rollout_then_prefers_rollout_reward(self):
        cfg = OmegaConf.create(
            {
                'train': {
                    'device': 'cpu',
                    'batch_size': 1,
                    'num_workers': 0,
                    'val_split': 0.0,
                    'seed': 0,
                    'lr': 1e-3,
                    'max_steps': 6,
                    'log_freq': 1,
                    'save_freq': 1,
                    'warmup_steps': 1,
                    'scheduler_type': 'constant',
                    'min_lr': 0.0,
                    'grad_clip': 1.0,
                    'weight_decay': 0.0,
                    'pretrained_ckpt': None,
                    'resume_ckpt': None,
                    'use_swanlab': False,
                    'rollout_val_freq_epochs': 2,
                    'rollout_num_episodes': 1,
                },
                'data': {
                    'camera_names': ['front'],
                },
                'agent': {
                    '_target_': 'fake.agent',
                },
                'eval': {
                    'ckpt_path': 'unused.pt',
                    'num_episodes': 99,
                    'max_timesteps': 1,
                    'device': 'cpu',
                    'task_name': 'sim_transfer',
                    'camera_names': ['front'],
                    'use_smoothing': False,
                    'smooth_alpha': 0.3,
                    'verbose_action': False,
                    'headless': False,
                },
            }
        )
        agent = _SequentialLossAgent([10, 9, 8, 7, 6, 5])
        rollout_mock = mock.Mock(return_value={'avg_reward': 1.0})
        saved_checkpoints = []

        # Two batches per training epoch (shuffle=True), one for validation.
        def fake_dataloader(_dataset, *, shuffle, **_kwargs):
            del _kwargs
            return _FakeLoader(_make_fake_batch(), length=2 if shuffle else 1)

        _run_training_in_tempdir(
            cfg,
            instantiate=_make_fake_instantiate(cfg, agent_factory=lambda: agent),
            dataloader=fake_dataloader,
            torch_save=_make_recording_torch_save(saved_checkpoints),
            extra_patches=[
                mock.patch.object(eval_vla, '_run_eval', rollout_mock, create=True),
            ],
        )

        best_model_saves = [
            (payload['step'], payload['rollout_avg_reward'])
            for path, payload in saved_checkpoints
            if path.endswith('checkpoints/vla_model_best.pt')
        ]
        self.assertEqual(
            best_model_saves,
            [
                (1, None),
                (2, None),
                (3, None),
                (3, 1.0),
            ],
        )
        self.assertEqual(rollout_mock.call_count, 1)

    def test_run_training_keeps_tiny_train_dataset_batch_when_batch_size_is_larger(self):
        cfg = OmegaConf.create(
            {
                'train': {
                    'device': 'cpu',
                    'batch_size': 8,
                    'num_workers': 0,
                    'val_split': 0.0,
                    'seed': 0,
                    'lr': 1e-3,
                    'max_steps': 1,
                    'log_freq': 1,
                    'save_freq': 1000,
                    'warmup_steps': 1,
                    'scheduler_type': 'constant',
                    'min_lr': 0.0,
                    'grad_clip': 1.0,
                    'weight_decay': 0.0,
                    'pretrained_ckpt': None,
                    'resume_ckpt': None,
                    'use_swanlab': False,
                    'rollout_val_freq_epochs': 0,
                },
                'data': {
                    'camera_names': ['front'],
                },
                'agent': {
                    '_target_': 'fake.agent',
                },
            }
        )
        agent = _FakeAgent()
        dataloader_calls = []
        saved_checkpoints = []

        class _TinyDataset:
            def __len__(self):
                return 1

        # Mimics real DataLoader semantics: with drop_last=True a dataset smaller
        # than one batch yields zero batches.
        def fake_dataloader(dataset, *, drop_last, shuffle, **_kwargs):
            del _kwargs
            dataloader_calls.append(
                {
                    'shuffle': shuffle,
                    'drop_last': drop_last,
                    'dataset_len': len(dataset),
                }
            )
            loader_length = 0 if drop_last and len(dataset) < cfg.train.batch_size else 1
            return _FakeLoader(_make_fake_batch(), length=loader_length)

        _run_training_in_tempdir(
            cfg,
            instantiate=_make_fake_instantiate(
                cfg, dataset_factory=_TinyDataset, agent_factory=lambda: agent
            ),
            dataloader=fake_dataloader,
            torch_save=_make_recording_torch_save(saved_checkpoints),
        )

        self.assertEqual(
            dataloader_calls[0],
            {
                'shuffle': True,
                'drop_last': False,
                'dataset_len': 1,
            },
        )
        self.assertEqual(len(saved_checkpoints), 1)
        self.assertTrue(saved_checkpoints[0][0].endswith('checkpoints/vla_model_final.pt'))


if __name__ == '__main__':
    unittest.main()