780 lines
29 KiB
Python
780 lines
29 KiB
Python
import os
|
|
import tempfile
|
|
import unittest
|
|
from copy import deepcopy
|
|
from pathlib import Path
|
|
from unittest import mock
|
|
|
|
import numpy as np
|
|
import torch
|
|
from omegaconf import OmegaConf
|
|
from torch import nn
|
|
|
|
from roboimi.demos.vla_scripts import eval_vla, train_vla
|
|
|
|
|
|
class _FakeDataset:
|
|
def __len__(self):
|
|
return 4
|
|
|
|
|
|
class _FakeLoader:
|
|
def __init__(self, batch, length=1):
|
|
self._batches = [batch] * length
|
|
|
|
def __len__(self):
|
|
return len(self._batches)
|
|
|
|
def __iter__(self):
|
|
return iter(self._batches)
|
|
|
|
|
|
class _FakeOptimizer:
|
|
def __init__(self, lr=1e-3):
|
|
self.param_groups = [{'lr': lr}]
|
|
|
|
def zero_grad(self):
|
|
return None
|
|
|
|
def step(self):
|
|
return None
|
|
|
|
def state_dict(self):
|
|
return {}
|
|
|
|
def load_state_dict(self, state_dict):
|
|
del state_dict
|
|
return None
|
|
|
|
|
|
class _FakeScheduler:
|
|
def __init__(self):
|
|
self.step_calls = 0
|
|
|
|
def step(self):
|
|
self.step_calls += 1
|
|
|
|
def state_dict(self):
|
|
return {}
|
|
|
|
def load_state_dict(self, state_dict):
|
|
del state_dict
|
|
return None
|
|
|
|
|
|
class _FakeProgressBar:
|
|
def __init__(self, iterable):
|
|
self._items = list(iterable)
|
|
self.postfix_calls = []
|
|
|
|
def __iter__(self):
|
|
return iter(self._items)
|
|
|
|
def set_postfix(self, values):
|
|
self.postfix_calls.append(values)
|
|
|
|
|
|
class _FakeAgent(nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.weight = nn.Parameter(torch.tensor(0.0))
|
|
|
|
def to(self, device):
|
|
del device
|
|
return self
|
|
|
|
def compute_loss(self, agent_input):
|
|
del agent_input
|
|
return (self.weight - torch.tensor(0.5)).pow(2)
|
|
|
|
def get_normalization_stats(self):
|
|
return {}
|
|
|
|
|
|
class _SequentialLossAgent(nn.Module):
|
|
def __init__(self, losses):
|
|
super().__init__()
|
|
self.weight = nn.Parameter(torch.tensor(0.0))
|
|
self._losses = list(losses)
|
|
self._index = 0
|
|
|
|
def to(self, device):
|
|
del device
|
|
return self
|
|
|
|
def compute_loss(self, agent_input):
|
|
del agent_input
|
|
loss_value = self._losses[self._index]
|
|
self._index += 1
|
|
return (self.weight * 0) + torch.tensor(float(loss_value))
|
|
|
|
def get_normalization_stats(self):
|
|
return {}
|
|
|
|
|
|
class _FakeEvalAgent:
|
|
def __init__(self):
|
|
self.reset_calls = 0
|
|
|
|
def eval(self):
|
|
return self
|
|
|
|
def to(self, device):
|
|
del device
|
|
return self
|
|
|
|
def reset(self):
|
|
self.reset_calls += 1
|
|
|
|
def select_action(self, observation):
|
|
del observation
|
|
return torch.zeros(2)
|
|
|
|
|
|
class _FakeEvalEnv:
|
|
def reset(self, box_pos):
|
|
self.box_pos = box_pos
|
|
|
|
def _get_image_obs(self):
|
|
return {
|
|
'images': {
|
|
'front': np.zeros((8, 8, 3), dtype=np.uint8),
|
|
}
|
|
}
|
|
|
|
def _get_qpos_obs(self):
|
|
return {'qpos': np.zeros(4, dtype=np.float32)}
|
|
|
|
def render(self):
|
|
raise AssertionError('render should not be called in this helper delegation test')
|
|
|
|
|
|
# Baseline ``train`` section shared by the ``_run_training`` tests; each test
# overrides only the knobs it actually exercises.
_BASE_TRAIN_CFG = {
    'device': 'cpu',
    'batch_size': 1,
    'num_workers': 0,
    'val_split': 0.0,
    'seed': 0,
    'lr': 1e-3,
    'max_steps': 1,
    'log_freq': 1,
    'save_freq': 1000,
    'warmup_steps': 1,
    'scheduler_type': 'constant',
    'min_lr': 0.0,
    'grad_clip': 1.0,
    'weight_decay': 0.0,
    'pretrained_ckpt': None,
    'resume_ckpt': None,
    'use_swanlab': False,
    'rollout_val_freq_epochs': 50,
    'rollout_num_episodes': 3,
}


def _find_repo_file(relative_path):
    """Resolve a repository-relative path independently of the working directory.

    The path is tried as given first (relative to the cwd, which is what the
    tests originally relied on), then against every ancestor directory of this
    test file. If nothing matches, the unresolved path is returned so the caller
    fails with the usual "file not found" error.
    """
    candidate = Path(relative_path)
    if candidate.exists():
        return candidate
    for ancestor in Path(__file__).resolve().parents:
        candidate = ancestor / relative_path
        if candidate.exists():
            return candidate
    return Path(relative_path)


def _make_training_cfg(*, include_eval=True, omit_train_keys=(), **train_overrides):
    """Build the OmegaConf config passed to ``train_vla._run_training`` in these tests.

    Args:
        include_eval: When True, add the ``eval`` section consumed by the rollout hook.
        omit_train_keys: ``train`` keys to drop entirely (to exercise missing-key paths).
        **train_overrides: Values replacing entries of ``_BASE_TRAIN_CFG``.

    Raises:
        KeyError: If an override or omitted key is not part of ``_BASE_TRAIN_CFG``,
            so a typo cannot silently add an unused key.
    """
    unknown = set(train_overrides) - set(_BASE_TRAIN_CFG)
    if unknown:
        raise KeyError(f'unknown train override(s): {sorted(unknown)}')
    train = dict(_BASE_TRAIN_CFG)
    train.update(train_overrides)
    for key in omit_train_keys:
        del train[key]

    config = {
        'train': train,
        'data': {'camera_names': ['front']},
        'agent': {'_target_': 'fake.agent'},
    }
    if include_eval:
        config['eval'] = {
            'ckpt_path': 'unused.pt',
            'num_episodes': 99,
            'max_timesteps': 1,
            'device': 'cpu',
            'task_name': 'sim_transfer',
            'camera_names': ['front'],
            'use_smoothing': False,
            'smooth_alpha': 0.3,
            'verbose_action': False,
            'headless': False,
        }
    return OmegaConf.create(config)


def _fake_batch():
    """Return a fresh single-sample batch with the observation/action keys used by these tests."""
    return {
        'observation.front': torch.zeros(1, 3, 2, 2),
        'observation.state': torch.zeros(1, 4),
        'action': torch.zeros(1, 2),
        'action_is_pad': torch.zeros(1, 1, dtype=torch.bool),
    }


def _make_fake_instantiate(cfg, *, make_dataset=None, make_agent=None):
    """Return a replacement for ``train_vla.instantiate`` keyed on *cfg*'s nodes.

    ``cfg.data`` yields ``make_dataset()`` (default: a new ``_FakeDataset``) and
    ``cfg.agent`` yields ``make_agent()`` (default: a new ``_FakeAgent``); any
    other node is a test failure.
    """
    dataset_factory = _FakeDataset if make_dataset is None else make_dataset
    agent_factory = _FakeAgent if make_agent is None else make_agent

    def fake_instantiate(config_node, **_kwargs):
        if config_node is cfg.data:
            return dataset_factory()
        if config_node is cfg.agent:
            return agent_factory()
        raise AssertionError(f'unexpected instantiate config: {config_node!r}')

    return fake_instantiate


def _best_model_payloads(saved_checkpoints):
    """Return the payloads of every ``checkpoints/vla_model_best.pt`` save, in save order."""
    return [
        payload for path, payload in saved_checkpoints
        if path.endswith('checkpoints/vla_model_best.pt')
    ]


def _run_training_with_fakes(cfg, *, instantiate, dataloader, saved_checkpoints=None, extra_patches=()):
    """Run ``train_vla._run_training(cfg)`` in a throwaway cwd with its collaborators faked.

    ``instantiate``, ``DataLoader``, the optimizer/scheduler builders, ``tqdm`` and
    ``torch.save`` are always patched. When *saved_checkpoints* is a list, every
    ``torch.save`` call appends ``(str(path), deepcopy(payload))`` to it; otherwise
    saves are discarded. *extra_patches* are additional ``mock.patch`` objects kept
    active for the run. The working directory is restored even if training raises.
    """
    from contextlib import ExitStack

    if saved_checkpoints is None:
        save_patch = mock.patch.object(train_vla.torch, 'save', return_value=None)
    else:
        def fake_torch_save(payload, path):
            # Deep-copy so the recorded payload is a snapshot taken at save time.
            saved_checkpoints.append((str(path), deepcopy(payload)))
            return None

        save_patch = mock.patch.object(train_vla.torch, 'save', side_effect=fake_torch_save)

    patchers = [
        mock.patch.object(train_vla, 'instantiate', side_effect=instantiate),
        mock.patch.object(train_vla, 'DataLoader', side_effect=dataloader),
        mock.patch.object(train_vla, 'build_training_optimizer', return_value=_FakeOptimizer(cfg.train.lr)),
        mock.patch.object(train_vla, 'get_lr_schedule_with_warmup', return_value=_FakeScheduler()),
        mock.patch.object(train_vla, 'tqdm', side_effect=lambda iterable, **kwargs: _FakeProgressBar(iterable)),
        save_patch,
        *extra_patches,
    ]

    with tempfile.TemporaryDirectory() as tempdir:
        previous_cwd = os.getcwd()
        try:
            os.chdir(tempdir)
            with ExitStack() as stack:
                for patcher in patchers:
                    stack.enter_context(patcher)
                train_vla._run_training(cfg)
        finally:
            os.chdir(previous_cwd)


class TrainVLARolloutValidationTest(unittest.TestCase):
    """Tests for rollout validation, best-checkpoint selection and DataLoader settings."""

    def test_default_train_config_uses_full_dataset_and_epoch_rollout_validation(self):
        cfg = OmegaConf.load(_find_repo_file('roboimi/vla/conf/config.yaml'))

        self.assertEqual(cfg.train.val_split, 0.0)
        self.assertGreater(cfg.train.batch_size, 8)
        self.assertGreater(float(cfg.train.lr), 5e-5)
        self.assertGreater(cfg.train.num_workers, 8)
        self.assertEqual(cfg.train.rollout_val_freq_epochs, 50)

    def test_eval_main_delegates_to_plain_run_eval_helper(self):
        cfg = OmegaConf.create(
            {
                'agent': {},
                'eval': {
                    'ckpt_path': 'checkpoints/vla_model_step_1.pt',
                    'num_episodes': 1,
                    'max_timesteps': 1,
                    'device': 'cpu',
                    'task_name': 'sim_transfer',
                    'camera_names': ['front'],
                    'use_smoothing': False,
                    'smooth_alpha': 0.3,
                    'verbose_action': False,
                    'headless': True,
                },
            }
        )
        run_eval_mock = mock.Mock()

        with mock.patch.object(eval_vla, '_run_eval', run_eval_mock, create=True), \
                mock.patch.object(eval_vla, 'load_checkpoint', return_value=(_FakeEvalAgent(), None)), \
                mock.patch.object(eval_vla, 'make_sim_env', return_value=_FakeEvalEnv()), \
                mock.patch.object(eval_vla, 'sample_transfer_pose', return_value=np.zeros(3)), \
                mock.patch.object(eval_vla, 'execute_policy_action'), \
                mock.patch.object(eval_vla, 'tqdm', side_effect=lambda iterable, **kwargs: iterable):
            eval_vla.main.__wrapped__(cfg)

        run_eval_mock.assert_called_once_with(cfg)

    def test_run_training_rollout_validation_runs_every_50_epochs_and_uses_avg_reward_metric(self):
        cfg = _make_training_cfg(max_steps=100, save_freq=1000)
        agent = _FakeAgent()
        rollout_mock = mock.Mock(side_effect=[{'avg_reward': 2.0}, {'avg_reward': 1.0}])
        swanlab_log_mock = mock.Mock()
        saved_checkpoints = []

        def fake_dataloader(_dataset, *, shuffle, **_kwargs):
            del shuffle, _kwargs
            return _FakeLoader(_fake_batch(), length=1)

        _run_training_with_fakes(
            cfg,
            instantiate=_make_fake_instantiate(cfg, make_agent=lambda: agent),
            dataloader=fake_dataloader,
            saved_checkpoints=saved_checkpoints,
            extra_patches=(
                mock.patch.object(train_vla, '_log_to_swanlab', swanlab_log_mock),
                mock.patch.object(eval_vla, '_run_eval', rollout_mock, create=True),
                mock.patch.object(
                    eval_vla.main,
                    '__wrapped__',
                    side_effect=AssertionError('training hook should call eval_vla._run_eval'),
                ),
            ),
        )

        self.assertEqual(rollout_mock.call_count, 2)
        first_rollout_cfg = rollout_mock.call_args_list[0].args[0]
        second_rollout_cfg = rollout_mock.call_args_list[1].args[0]
        self.assertEqual(first_rollout_cfg.eval.ckpt_path, 'checkpoints/vla_model_step_49.pt')
        self.assertEqual(second_rollout_cfg.eval.ckpt_path, 'checkpoints/vla_model_step_99.pt')
        self.assertEqual(first_rollout_cfg.eval.num_episodes, 3)
        self.assertTrue(first_rollout_cfg.eval.headless)
        self.assertEqual(first_rollout_cfg.eval.device, 'cpu')
        self.assertFalse(first_rollout_cfg.eval.verbose_action)
        # The caller's own config must be left untouched by the rollout hook.
        self.assertEqual(cfg.eval.ckpt_path, 'unused.pt')
        self.assertEqual(cfg.eval.num_episodes, 99)
        self.assertFalse(cfg.eval.headless)
        self.assertEqual(cfg.eval.device, 'cpu')
        self.assertFalse(cfg.eval.verbose_action)

        rollout_reward_logs = [
            call.args[1]['rollout/avg_reward']
            for call in swanlab_log_mock.call_args_list
            if len(call.args) >= 2 and 'rollout/avg_reward' in call.args[1]
        ]
        self.assertEqual(rollout_reward_logs, [2.0, 1.0])

        best_model_saves = _best_model_payloads(saved_checkpoints)
        self.assertEqual(len(best_model_saves), 1)
        self.assertEqual(best_model_saves[0]['rollout_avg_reward'], 2.0)

    def test_run_training_keeps_loss_based_best_checkpoint_until_first_rollout_metric_exists(self):
        cfg = _make_training_cfg(max_steps=5, save_freq=2)
        saved_checkpoints = []
        rollout_mock = mock.Mock()

        def fake_dataloader(_dataset, *, shuffle, **_kwargs):
            del shuffle, _kwargs
            return _FakeLoader(_fake_batch(), length=5)

        _run_training_with_fakes(
            cfg,
            instantiate=_make_fake_instantiate(cfg),
            dataloader=fake_dataloader,
            saved_checkpoints=saved_checkpoints,
            extra_patches=(mock.patch.object(eval_vla, '_run_eval', rollout_mock, create=True),),
        )

        self.assertEqual(rollout_mock.call_count, 0)
        best_model_saves = _best_model_payloads(saved_checkpoints)
        self.assertEqual(len(best_model_saves), 1)
        self.assertIsNone(best_model_saves[0]['rollout_avg_reward'])

    def test_run_training_disables_drop_last_when_train_set_is_smaller_than_batch_size(self):
        cfg = _make_training_cfg(batch_size=8, max_steps=1, save_freq=10)
        dataloader_calls = []

        def fake_dataloader(dataset, *, shuffle, drop_last, **_kwargs):
            dataloader_calls.append({
                'shuffle': shuffle,
                'drop_last': drop_last,
                'dataset_len': len(dataset),
            })
            return _FakeLoader(_fake_batch(), length=1)

        _run_training_with_fakes(
            cfg,
            instantiate=_make_fake_instantiate(cfg),
            dataloader=fake_dataloader,
        )

        train_loader_calls = [call for call in dataloader_calls if call['shuffle']]
        self.assertEqual(len(train_loader_calls), 1)
        self.assertFalse(train_loader_calls[0]['drop_last'])

    def test_run_training_disables_persistent_workers_for_train_and_val_loaders(self):
        cfg = _make_training_cfg(batch_size=2, num_workers=2, val_split=0.25, max_steps=1, save_freq=10)
        dataloader_calls = []

        def fake_dataloader(_dataset, *, shuffle, persistent_workers, num_workers, **_kwargs):
            dataloader_calls.append({
                'shuffle': shuffle,
                'num_workers': num_workers,
                'persistent_workers': persistent_workers,
            })
            return _FakeLoader(_fake_batch(), length=1)

        _run_training_with_fakes(
            cfg,
            instantiate=_make_fake_instantiate(cfg),
            dataloader=fake_dataloader,
        )

        self.assertEqual(len(dataloader_calls), 2)
        self.assertEqual([call['shuffle'] for call in dataloader_calls], [True, False])
        self.assertTrue(all(call['num_workers'] == 2 for call in dataloader_calls))
        self.assertTrue(all(call['persistent_workers'] is False for call in dataloader_calls))

    def test_run_training_uses_loss_best_until_first_rollout_then_prefers_rollout_reward(self):
        cfg = _make_training_cfg(
            max_steps=6,
            save_freq=1,
            rollout_val_freq_epochs=2,
            rollout_num_episodes=1,
        )
        agent = _SequentialLossAgent([10, 9, 8, 7, 6, 5])
        rollout_mock = mock.Mock(return_value={'avg_reward': 1.0})
        saved_checkpoints = []

        def fake_dataloader(_dataset, *, shuffle, **_kwargs):
            del _kwargs
            # Shuffled (training) loaders get two batches; non-shuffled ones get one.
            return _FakeLoader(_fake_batch(), length=2 if shuffle else 1)

        _run_training_with_fakes(
            cfg,
            instantiate=_make_fake_instantiate(cfg, make_agent=lambda: agent),
            dataloader=fake_dataloader,
            saved_checkpoints=saved_checkpoints,
            extra_patches=(mock.patch.object(eval_vla, '_run_eval', rollout_mock, create=True),),
        )

        best_model_saves = [
            (payload['step'], payload['rollout_avg_reward'])
            for payload in _best_model_payloads(saved_checkpoints)
        ]
        self.assertEqual(
            best_model_saves,
            [
                (1, None),
                (2, None),
                (3, None),
                (3, 1.0),
            ],
        )
        self.assertEqual(rollout_mock.call_count, 1)

    def test_run_training_keeps_tiny_train_dataset_batch_when_batch_size_is_larger(self):
        cfg = _make_training_cfg(
            batch_size=8,
            max_steps=1,
            save_freq=1000,
            rollout_val_freq_epochs=0,
            include_eval=False,
            omit_train_keys=('rollout_num_episodes',),
        )
        agent = _FakeAgent()
        dataloader_calls = []
        saved_checkpoints = []

        class _TinyDataset:
            def __len__(self):
                return 1

        def fake_dataloader(dataset, *, drop_last, shuffle, **_kwargs):
            del _kwargs
            dataloader_calls.append(
                {
                    'shuffle': shuffle,
                    'drop_last': drop_last,
                    'dataset_len': len(dataset),
                }
            )
            # Like torch's DataLoader: drop_last with fewer samples than
            # batch_size yields no batches at all.
            loader_length = 0 if drop_last and len(dataset) < cfg.train.batch_size else 1
            return _FakeLoader(_fake_batch(), length=loader_length)

        _run_training_with_fakes(
            cfg,
            instantiate=_make_fake_instantiate(cfg, make_dataset=_TinyDataset, make_agent=lambda: agent),
            dataloader=fake_dataloader,
            saved_checkpoints=saved_checkpoints,
        )

        self.assertEqual(
            dataloader_calls[0],
            {
                'shuffle': True,
                'drop_last': False,
                'dataset_len': 1,
            },
        )
        self.assertEqual(
            [path for path, _payload in saved_checkpoints],
            ['checkpoints/vla_model_final.pt'],
        )
|
|
|
|
|
|
if __name__ == '__main__':
    # Allow running this module directly; ``module='__main__'`` is unittest's default.
    unittest.main(module='__main__')
|