feat: add vision transfer backbones and IMF variants
This commit is contained in:
@@ -159,6 +159,92 @@ class TrainVLARolloutValidationTest(unittest.TestCase):
|
||||
self.assertGreater(cfg.train.num_workers, 8)
|
||||
self.assertEqual(cfg.train.rollout_val_freq_epochs, 50)
|
||||
|
||||
def test_training_passes_backbone_image_resize_override_to_dataset_instantiation(self):
|
||||
cfg = OmegaConf.create(
|
||||
{
|
||||
'agent': {
|
||||
'vision_backbone': {
|
||||
'dataset_image_resize_shape': None,
|
||||
},
|
||||
'normalization_type': 'min_max',
|
||||
},
|
||||
'data': {
|
||||
'dataset_dir': 'unused',
|
||||
'camera_names': ['front'],
|
||||
},
|
||||
'train': {
|
||||
'batch_size': 2,
|
||||
'lr': 1e-4,
|
||||
'max_steps': 0,
|
||||
'device': 'cpu',
|
||||
'disable_cudnn': False,
|
||||
'num_workers': 0,
|
||||
'val_split': 0.0,
|
||||
'seed': 42,
|
||||
'log_freq': 1,
|
||||
'save_freq': 10,
|
||||
'use_swanlab': False,
|
||||
'rollout_val_freq_epochs': 0,
|
||||
'rollout_validate_on_checkpoint': False,
|
||||
'rollout_num_episodes': 1,
|
||||
'warmup_steps': 1,
|
||||
'scheduler_type': 'constant',
|
||||
'min_lr': 1e-6,
|
||||
'weight_decay': 1e-5,
|
||||
'grad_clip': 1.0,
|
||||
'pretrained_ckpt': None,
|
||||
},
|
||||
'eval': {
|
||||
'ckpt_path': 'unused.pt',
|
||||
'num_episodes': 1,
|
||||
'headless': True,
|
||||
'device': 'cpu',
|
||||
'verbose_action': False,
|
||||
},
|
||||
'experiment': {},
|
||||
}
|
||||
)
|
||||
captured_dataset_kwargs = {}
|
||||
|
||||
def fake_instantiate(config_node, **kwargs):
|
||||
if config_node is cfg.data:
|
||||
captured_dataset_kwargs.update(kwargs)
|
||||
return _FakeDataset()
|
||||
if config_node is cfg.agent:
|
||||
return _FakeAgent()
|
||||
raise AssertionError(f'unexpected instantiate config: {config_node!r}')
|
||||
|
||||
def fake_dataloader(_dataset, *, shuffle, **_kwargs):
|
||||
del shuffle, _kwargs
|
||||
return _FakeLoader(
|
||||
{
|
||||
'observation.front': torch.zeros(1, 3, 2, 2),
|
||||
'observation.state': torch.zeros(1, 4),
|
||||
'action': torch.zeros(1, 2),
|
||||
'action_is_pad': torch.zeros(1, 1, dtype=torch.bool),
|
||||
},
|
||||
length=1,
|
||||
)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tempdir:
|
||||
previous_cwd = os.getcwd()
|
||||
try:
|
||||
os.chdir(tempdir)
|
||||
with mock.patch.object(train_vla, 'instantiate', side_effect=fake_instantiate), \
|
||||
mock.patch.object(train_vla, 'DataLoader', side_effect=fake_dataloader), \
|
||||
mock.patch.object(train_vla, 'build_training_optimizer', return_value=_FakeOptimizer(cfg.train.lr)), \
|
||||
mock.patch.object(train_vla, 'get_lr_schedule_with_warmup', return_value=_FakeScheduler()), \
|
||||
mock.patch.object(train_vla, 'tqdm', side_effect=lambda iterable, **kwargs: _FakeProgressBar(iterable)), \
|
||||
mock.patch.object(train_vla, '_init_swanlab', return_value=None), \
|
||||
mock.patch.object(train_vla, '_finish_swanlab', return_value=None), \
|
||||
mock.patch.object(train_vla.torch, 'save', return_value=None):
|
||||
train_vla._run_training(cfg)
|
||||
finally:
|
||||
os.chdir(previous_cwd)
|
||||
|
||||
self.assertIn('image_resize_shape', captured_dataset_kwargs)
|
||||
self.assertIsNone(captured_dataset_kwargs['image_resize_shape'])
|
||||
|
||||
def test_eval_main_delegates_to_plain_run_eval_helper(self):
|
||||
cfg = OmegaConf.create(
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user