"""Unit tests for SwanLab logging integration in roboimi's train_vla script."""
import importlib
|
|
import importlib.util
|
|
import os
|
|
import sys
|
|
import tempfile
|
|
import types
|
|
import unittest
|
|
from pathlib import Path
|
|
from unittest import mock
|
|
|
|
import torch
|
|
from torch import nn
|
|
|
|
|
|
# Repository root, derived from this test file's location on disk.
_REPO_ROOT = Path(__file__).resolve().parents[1]
# Training entry point under test; loaded dynamically with stubbed dependencies.
_TRAIN_VLA_PATH = _REPO_ROOT / 'roboimi/demos/vla_scripts/train_vla.py'
# Shipped default Hydra config, checked for the SwanLab opt-in default.
_CONFIG_PATH = _REPO_ROOT / 'roboimi/vla/conf/config.yaml'
|
|
|
|
|
|
class AttrDict(dict):
    """dict subclass exposing its keys as attributes, config-object style."""

    def __getattr__(self, name):
        """Resolve attribute reads through the mapping.

        A missing key is re-raised as AttributeError (with the KeyError
        chained) so the normal attribute protocol — hasattr, getattr with a
        default — keeps working.
        """
        try:
            return self[name]
        except KeyError as err:
            raise AttributeError(name) from err

    def __setattr__(self, name, value):
        """Route attribute writes straight into the underlying mapping."""
        self[name] = value
|
|
|
|
|
|
def _to_attrdict(value):
    """Recursively convert nested dicts (and lists of them) into AttrDicts.

    Non-dict, non-list values are returned unchanged; lists keep their
    ordering with each element converted in turn.
    """
    if isinstance(value, dict):
        converted = AttrDict()
        for key, item in value.items():
            converted[key] = _to_attrdict(item)
        return converted
    if isinstance(value, list):
        return [_to_attrdict(item) for item in value]
    return value
|
|
|
|
|
|
class FakeDataset:
    """Minimal dataset stand-in reporting a fixed size of four samples."""

    def __len__(self):
        """Return the canned sample count used to drive the train/val split."""
        return 4
|
|
|
|
|
|
class FakeLoader:
    """Single-batch DataLoader replacement that yields one canned batch."""

    def __init__(self, batch):
        # The one batch every iteration pass will produce.
        self.batch = batch

    def __len__(self):
        """Report exactly one batch per epoch."""
        return 1

    def __iter__(self):
        """Yield the stored batch once per iteration pass."""
        yield self.batch
|
|
|
|
|
|
class FakeScheduler:
    """LR-scheduler stub that only counts how often step() was invoked."""

    def __init__(self):
        # Number of step() calls observed so far.
        self.step_calls = 0

    def step(self):
        """Record one scheduler step."""
        self.step_calls += 1

    def state_dict(self):
        """Return an empty state dict (nothing to persist)."""
        return {}

    def load_state_dict(self, state_dict):
        """Accept and discard a restored state dict."""
        return None
|
|
|
|
|
|
class FakeOptimizer:
    """Optimizer stub that records the state dict passed to load_state_dict."""

    def __init__(self, lr=1e-3):
        # Mimic torch optimizer API: a single param group carrying the lr.
        self.param_groups = [{'lr': lr}]
        # Captures whatever checkpoint state the code under test restores.
        self.loaded_state_dict = None

    def zero_grad(self):
        """No-op gradient reset."""
        return None

    def step(self):
        """No-op parameter update."""
        return None

    def state_dict(self):
        """Return an empty state dict (nothing to persist)."""
        return {}

    def load_state_dict(self, state_dict):
        """Remember the restored state dict for later assertions."""
        self.loaded_state_dict = state_dict
        return None
|
|
|
|
|
|
class FakeProgressBar:
    """tqdm replacement that records every set_postfix payload."""

    def __init__(self, iterable):
        # Materialize the wrapped iterable so it can be replayed.
        self._items = list(iterable)
        # One entry per set_postfix call, in order.
        self.postfix_calls = []

    def __iter__(self):
        """Iterate over the captured items."""
        yield from self._items

    def set_postfix(self, values):
        """Record the postfix payload instead of rendering it."""
        self.postfix_calls.append(values)
|
|
|
|
|
|
class FakeAgent(nn.Module):
    """Trainable agent stub whose loss target depends on train/eval mode.

    The single scalar parameter starts at 0, so the squared-error loss is
    0.25**2 in training mode and 0.1**2 in eval mode — nonzero and
    mode-dependent, which lets the tests tell the two code paths apart.
    """

    def __init__(self):
        super().__init__()
        # Lone learnable scalar; gives the optimizer something to touch.
        self.weight = nn.Parameter(torch.tensor(0.0))

    def to(self, device):
        """Ignore device moves; the stub always stays where it is."""
        return self

    def compute_loss(self, agent_input):
        """Return squared distance from a mode-dependent target scalar."""
        del agent_input
        if self.training:
            target = torch.tensor(0.25)
        else:
            target = torch.tensor(0.1)
        diff = self.weight - target
        return diff.pow(2)

    def get_normalization_stats(self):
        """No normalization statistics for the stub."""
        return {}
|
|
|
|
|
|
class FakeSwanLab:
    """In-memory double for the swanlab module, with failure injection.

    ``init_error`` / ``finish_error`` are raised from init()/finish() after
    the call has been recorded; ``log_errors`` is a FIFO of exceptions, one
    raised per log() call until the queue is drained.
    """

    def __init__(self, init_error=None, log_errors=None, finish_error=None):
        self.init_error = init_error
        self.log_errors = list(log_errors or [])
        self.finish_error = finish_error
        # Recorded interactions, inspected by the tests.
        self.init_calls = []
        self.log_calls = []
        self.finish_calls = 0

    def init(self, project, experiment_name=None, config=None):
        """Record an init call; raise the injected error if configured."""
        call = {
            'project': project,
            'experiment_name': experiment_name,
            'config': config,
        }
        self.init_calls.append(call)
        if self.init_error is not None:
            raise self.init_error
        # Opaque run handle, mirroring the real API's return value.
        return object()

    def log(self, payload, step=None):
        """Record a (payload copy, step) pair; pop and raise one queued error."""
        self.log_calls.append((dict(payload), step))
        if self.log_errors:
            raise self.log_errors.pop(0)

    def finish(self):
        """Count a finish call; raise the injected error if configured."""
        self.finish_calls += 1
        if self.finish_error is not None:
            raise self.finish_error
|
|
|
|
|
|
class TrainVLASwanLabLoggingTest(unittest.TestCase):
    """Behavioral tests for SwanLab logging inside train_vla's training loop."""

    def test_default_config_keeps_swanlab_opt_in(self):
        # SwanLab must stay opt-in: the shipped config has it disabled.
        config_text = _CONFIG_PATH.read_text(encoding='utf-8')
        self.assertIn('use_swanlab: false', config_text)
|
|
|
|
    def _load_train_vla_module(self):
        """Import train_vla.py from disk with hydra/omegaconf stubbed out.

        The real dependencies are replaced in ``sys.modules`` only for the
        duration of ``exec_module``, so the script imports lightweight fakes
        instead of pulling in the full hydra/omegaconf stacks.
        """
        hydra_module = types.ModuleType('hydra')
        hydra_utils_module = types.ModuleType('hydra.utils')
        hydra_utils_module.instantiate = lambda *args, **kwargs: None

        def hydra_main(**_kwargs):
            # Stand-in for hydra.main: a decorator factory that returns the
            # function unchanged, keeping the script's entry point callable.
            def decorator(func):
                return func

            return decorator

        hydra_module.main = hydra_main
        hydra_module.utils = hydra_utils_module

        class OmegaConfStub:
            # Minimal OmegaConf surface used by train_vla.py.
            _resolvers = {}

            @classmethod
            def has_resolver(cls, name):
                return name in cls._resolvers

            @classmethod
            def register_new_resolver(cls, name, resolver):
                cls._resolvers[name] = resolver

            @staticmethod
            def to_yaml(_cfg):
                return 'stub-config'

            @staticmethod
            def to_container(cfg, resolve=False):
                del resolve
                return dict(cfg)

            @staticmethod
            def create(cfg):
                return _to_attrdict(cfg)

        omegaconf_module = types.ModuleType('omegaconf')
        omegaconf_module.DictConfig = dict
        omegaconf_module.OmegaConf = OmegaConfStub

        # Load the script under a private module name so repeated loads in
        # different tests do not collide in sys.modules.
        module_name = 'train_vla_swanlab_test_module'
        spec = importlib.util.spec_from_file_location(module_name, _TRAIN_VLA_PATH)
        module = importlib.util.module_from_spec(spec)
        with mock.patch.dict(
            sys.modules,
            {
                'hydra': hydra_module,
                'hydra.utils': hydra_utils_module,
                'omegaconf': omegaconf_module,
            },
        ):
            assert spec.loader is not None
            spec.loader.exec_module(module)
        return module
|
|
|
|
    def _make_cfg(self, *, use_swanlab=True, swanlab_run_name='smoke-run'):
        """Build a minimal attribute-style config mirroring the Hydra schema.

        Keyword-only flags let individual tests toggle SwanLab usage and the
        run name; everything else is a fixed two-step CPU smoke setup.
        """
        return AttrDict(
            train=AttrDict(
                device='cpu',
                batch_size=2,
                num_workers=0,
                val_split=0.25,
                seed=0,
                lr=1e-3,
                max_steps=2,
                log_freq=1,
                save_freq=1,
                warmup_steps=1,
                scheduler_type='constant',
                min_lr=0.0,
                grad_clip=1.0,
                weight_decay=0.0,
                pretrained_ckpt=None,
                resume_ckpt=None,
                use_swanlab=use_swanlab,
                swanlab_project='roboimi-vla-tests',
                swanlab_run_name=swanlab_run_name,
            ),
            data=AttrDict(
                camera_names=('front',),
            ),
            agent=AttrDict(
                _target_='fake.agent',
            ),
            eval=AttrDict(
                ckpt_path='unused.pt',
                num_episodes=1,
                max_timesteps=1,
                device='cpu',
                task_name='sim_transfer',
                camera_names=('front',),
                use_smoothing=False,
                smooth_alpha=0.3,
                verbose_action=False,
                headless=False,
            ),
        )
|
|
|
|
    def _get_run_training(self, module):
        """Fetch the module's _run_training helper, failing loudly if absent."""
        run_training = getattr(module, '_run_training', None)
        self.assertIsNotNone(run_training, 'Expected train_vla.py to expose a _run_training(cfg) helper')
        return run_training
|
|
|
|
    def _make_batch(self):
        """Build one minimal batch of zero tensors in the expected key layout."""
        return {
            'observation.front': torch.zeros(1, 3, 2, 2),
            'observation.state': torch.zeros(1, 4),
            'action': torch.zeros(1, 2),
            'action_is_pad': torch.zeros(1, 1, dtype=torch.bool),
        }
|
|
|
|
    def _loader_factory(self):
        """Return a DataLoader stand-in: shuffle=True -> train batch, else val."""
        train_batch = self._make_batch()
        val_batch = self._make_batch()

        def factory(_dataset, *, shuffle, **_kwargs):
            # The training loop shuffles only its train loader, so the
            # shuffle flag distinguishes train from validation here.
            return FakeLoader(train_batch if shuffle else val_batch)

        return factory
|
|
|
|
    def test_run_training_logs_metrics_and_checkpoint_paths_to_swanlab(self):
        """Happy path: init config, per-step metrics, and final paths are logged."""
        module = self._load_train_vla_module()
        run_training = self._get_run_training(module)
        cfg = self._make_cfg()
        agent = FakeAgent()
        fake_swanlab = FakeSwanLab()
        # Keep a handle on the real import_module before patching it below.
        real_import_module = importlib.import_module

        def fake_instantiate(config_node, **_kwargs):
            # Route hydra instantiation of the data/agent config nodes to fakes.
            if config_node is cfg.data:
                return FakeDataset()
            if config_node is cfg.agent:
                return agent
            raise AssertionError(f'unexpected instantiate config: {config_node!r}')

        def fake_import_module(name, package=None):
            # Only the swanlab import is intercepted; everything else is real.
            if name == 'swanlab':
                return fake_swanlab
            return real_import_module(name, package)

        with tempfile.TemporaryDirectory() as tempdir:
            previous_cwd = os.getcwd()
            try:
                # Run from a scratch cwd so checkpoints land in a temp dir.
                os.chdir(tempdir)
                with mock.patch.object(module, 'instantiate', side_effect=fake_instantiate), \
                        mock.patch.object(module, 'DataLoader', side_effect=self._loader_factory()), \
                        mock.patch.object(module, 'get_lr_schedule_with_warmup', return_value=FakeScheduler()), \
                        mock.patch.object(module, 'tqdm', side_effect=lambda iterable, **kwargs: FakeProgressBar(iterable)), \
                        mock.patch.object(module.torch, 'save', return_value=None), \
                        mock.patch.object(module.importlib, 'import_module', side_effect=fake_import_module):
                    run_training(cfg)
            finally:
                os.chdir(previous_cwd)

        # Exactly one init call, forwarding project, run name, and full config.
        self.assertEqual(
            fake_swanlab.init_calls,
            [{
                'project': 'roboimi-vla-tests',
                'experiment_name': 'smoke-run',
                'config': {
                    'train': {
                        'device': 'cpu',
                        'batch_size': 2,
                        'num_workers': 0,
                        'val_split': 0.25,
                        'seed': 0,
                        'lr': 1e-3,
                        'max_steps': 2,
                        'log_freq': 1,
                        'save_freq': 1,
                        'warmup_steps': 1,
                        'scheduler_type': 'constant',
                        'min_lr': 0.0,
                        'grad_clip': 1.0,
                        'weight_decay': 0.0,
                        'pretrained_ckpt': None,
                        'resume_ckpt': None,
                        'use_swanlab': True,
                        'swanlab_project': 'roboimi-vla-tests',
                        'swanlab_run_name': 'smoke-run',
                    },
                    'data': {
                        'camera_names': ('front',),
                    },
                    'agent': {
                        '_target_': 'fake.agent',
                    },
                },
            }],
        )

        # Every expected metric family must show up across the log calls.
        logged_keys = set().union(*(payload.keys() for payload, _step in fake_swanlab.log_calls))
        self.assertTrue(
            {
                'train/loss',
                'train/lr',
                'train/best_loss',
                'train/step',
                'val/loss',
                'final/checkpoint_path',
                'final/best_checkpoint_path',
            }.issubset(logged_keys)
        )

        # The last log call carries the final checkpoint paths at max_steps.
        final_payload, final_step = fake_swanlab.log_calls[-1]
        self.assertEqual(final_step, cfg.train.max_steps)
        self.assertEqual(final_payload['final/checkpoint_path'], 'checkpoints/vla_model_final.pt')
        self.assertEqual(final_payload['final/best_checkpoint_path'], 'checkpoints/vla_model_best.pt')
        self.assertEqual(fake_swanlab.finish_calls, 1)
|
|
|
|
    def test_run_training_skips_swanlab_when_disabled(self):
        """use_swanlab=False must avoid importing swanlab entirely."""
        module = self._load_train_vla_module()
        run_training = self._get_run_training(module)
        cfg = self._make_cfg(use_swanlab=False)
        agent = FakeAgent()

        def fake_instantiate(config_node, **_kwargs):
            if config_node is cfg.data:
                return FakeDataset()
            if config_node is cfg.agent:
                return agent
            raise AssertionError(f'unexpected instantiate config: {config_node!r}')

        with tempfile.TemporaryDirectory() as tempdir:
            previous_cwd = os.getcwd()
            try:
                os.chdir(tempdir)
                # import_module is patched to raise, so any swanlab import
                # attempt would abort the run and fail this test.
                with mock.patch.object(module, 'instantiate', side_effect=fake_instantiate), \
                        mock.patch.object(module, 'DataLoader', side_effect=self._loader_factory()), \
                        mock.patch.object(module, 'get_lr_schedule_with_warmup', return_value=FakeScheduler()), \
                        mock.patch.object(module, 'tqdm', side_effect=lambda iterable, **kwargs: FakeProgressBar(iterable)), \
                        mock.patch.object(module.torch, 'save', return_value=None), \
                        mock.patch.object(module.importlib, 'import_module', side_effect=AssertionError('swanlab import should not run')):
                    run_training(cfg)
            finally:
                os.chdir(previous_cwd)
|
|
|
|
    def test_run_training_finishes_swanlab_when_exception_happens_after_init(self):
        """An exception after swanlab.init must still trigger swanlab.finish."""
        module = self._load_train_vla_module()
        run_training = self._get_run_training(module)
        cfg = self._make_cfg()
        fake_swanlab = FakeSwanLab()
        real_import_module = importlib.import_module

        def fake_import_module(name, package=None):
            if name == 'swanlab':
                return fake_swanlab
            return real_import_module(name, package)

        with tempfile.TemporaryDirectory() as tempdir:
            previous_cwd = os.getcwd()
            try:
                os.chdir(tempdir)
                # instantiate blows up right after init, simulating a dataset
                # construction failure mid-setup.
                with mock.patch.object(module, 'instantiate', side_effect=RuntimeError('dataset boom')), \
                        mock.patch.object(module.importlib, 'import_module', side_effect=fake_import_module):
                    with self.assertRaisesRegex(RuntimeError, 'dataset boom'):
                        run_training(cfg)
            finally:
                os.chdir(previous_cwd)

        # finish() must run exactly once despite the propagated exception.
        self.assertEqual(fake_swanlab.finish_calls, 1)
|
|
|
|
    def test_run_training_warns_and_continues_when_swanlab_log_and_finish_fail(self):
        """log()/finish() failures are downgraded to warnings, not crashes."""
        module = self._load_train_vla_module()
        run_training = self._get_run_training(module)
        cfg = self._make_cfg()
        agent = FakeAgent()
        # One log() call fails, and finish() fails too; training must survive.
        fake_swanlab = FakeSwanLab(
            log_errors=[RuntimeError('log backend hiccup')],
            finish_error=RuntimeError('finish backend hiccup'),
        )
        real_import_module = importlib.import_module

        def fake_instantiate(config_node, **_kwargs):
            if config_node is cfg.data:
                return FakeDataset()
            if config_node is cfg.agent:
                return agent
            raise AssertionError(f'unexpected instantiate config: {config_node!r}')

        def fake_import_module(name, package=None):
            if name == 'swanlab':
                return fake_swanlab
            return real_import_module(name, package)

        with tempfile.TemporaryDirectory() as tempdir:
            previous_cwd = os.getcwd()
            try:
                os.chdir(tempdir)
                with mock.patch.object(module, 'instantiate', side_effect=fake_instantiate), \
                        mock.patch.object(module, 'DataLoader', side_effect=self._loader_factory()), \
                        mock.patch.object(module, 'get_lr_schedule_with_warmup', return_value=FakeScheduler()), \
                        mock.patch.object(module, 'tqdm', side_effect=lambda iterable, **kwargs: FakeProgressBar(iterable)), \
                        mock.patch.object(module.torch, 'save', return_value=None), \
                        mock.patch.object(module.importlib, 'import_module', side_effect=fake_import_module), \
                        mock.patch.object(module.log, 'warning') as warning_mock:
                    run_training(cfg)
            finally:
                os.chdir(previous_cwd)

        # Both failure modes must be surfaced via log.warning messages.
        warning_messages = [call.args[0] for call in warning_mock.call_args_list]
        self.assertTrue(any('SwanLab log failed' in message for message in warning_messages))
        self.assertTrue(any('SwanLab finish failed' in message for message in warning_messages))
        self.assertEqual(fake_swanlab.finish_calls, 1)
|
|
|
|
    def test_run_training_resume_restores_best_rollout_baseline_from_best_checkpoint(self):
        """Resuming reuses the best checkpoint's rollout reward as the baseline.

        The stored rollout_avg_reward (5.0) beats the mocked eval result
        (3.0), so the pre-existing best checkpoint must not be overwritten.
        """
        module = self._load_train_vla_module()
        run_training = self._get_run_training(module)
        cfg = self._make_cfg()
        cfg.train.max_steps = 2
        cfg.train.save_freq = 1
        cfg.train.rollout_validate_on_checkpoint = True
        fake_swanlab = FakeSwanLab()
        fake_optimizer = FakeOptimizer(lr=cfg.train.lr)
        fake_scheduler = FakeScheduler()
        real_import_module = importlib.import_module
        # Collects every path torch.save is asked to write to.
        saved_paths = []

        def fake_instantiate(config_node, **_kwargs):
            if config_node is cfg.data:
                return FakeDataset()
            if config_node is cfg.agent:
                return FakeAgent()
            raise AssertionError(f'unexpected instantiate config: {config_node!r}')

        def fake_import_module(name, package=None):
            if name == 'swanlab':
                return fake_swanlab
            return real_import_module(name, package)

        with tempfile.TemporaryDirectory() as tempdir:
            previous_cwd = os.getcwd()
            try:
                os.chdir(tempdir)
                # Pre-create on-disk resume and best checkpoints; contents
                # are irrelevant because torch.load is mocked below.
                checkpoint_dir = Path('checkpoints')
                checkpoint_dir.mkdir()
                resume_path = checkpoint_dir / 'vla_model_step_0.pt'
                resume_path.write_bytes(b'resume')
                best_path = checkpoint_dir / 'vla_model_best.pt'
                best_path.write_bytes(b'best')
                cfg.train.resume_ckpt = str(resume_path)

                resume_checkpoint_state = {
                    'step': 0,
                    'model_state_dict': FakeAgent().state_dict(),
                    'optimizer_state_dict': {},
                    'scheduler_state_dict': {},
                    'loss': 0.5,
                    'val_loss': 0.25,
                }
                # Best checkpoint carries a rollout metric better than the
                # mocked eval result (5.0 > 3.0).
                best_checkpoint_state = {
                    'step': 0,
                    'model_state_dict': FakeAgent().state_dict(),
                    'optimizer_state_dict': {},
                    'scheduler_state_dict': {},
                    'loss': 0.5,
                    'val_loss': 0.25,
                    'rollout_avg_reward': 5.0,
                }

                def fake_torch_load(path, map_location=None):
                    # Serve the matching canned state for each known path.
                    del map_location
                    path = Path(path)
                    if path == resume_path:
                        return resume_checkpoint_state
                    if path == best_path:
                        return best_checkpoint_state
                    raise AssertionError(f'unexpected load path: {path}')

                def fake_torch_save(payload, path):
                    saved_paths.append(str(path))
                    return None

                with mock.patch.object(module, 'instantiate', side_effect=fake_instantiate), \
                        mock.patch.object(module, 'DataLoader', side_effect=self._loader_factory()), \
                        mock.patch.object(module, 'build_training_optimizer', return_value=fake_optimizer), \
                        mock.patch.object(module, 'get_lr_schedule_with_warmup', return_value=fake_scheduler), \
                        mock.patch.object(module, 'tqdm', side_effect=lambda iterable, **kwargs: FakeProgressBar(iterable)), \
                        mock.patch.object(module.torch, 'save', side_effect=fake_torch_save), \
                        mock.patch.object(module.torch, 'load', side_effect=fake_torch_load), \
                        mock.patch.object(module.importlib, 'import_module', side_effect=fake_import_module), \
                        mock.patch('roboimi.demos.vla_scripts.eval_vla._run_eval', return_value={'avg_reward': 3.0}):
                    run_training(cfg)
            finally:
                os.chdir(previous_cwd)

        final_payload, final_step = fake_swanlab.log_calls[-1]
        self.assertEqual(final_step, cfg.train.max_steps)
        self.assertEqual(final_payload['final/best_checkpoint_path'], 'checkpoints/vla_model_best.pt')
        # The pre-existing best checkpoint must never have been re-saved.
        self.assertNotIn('checkpoints/vla_model_best.pt', saved_paths)
|
|
|
|
    def test_run_training_resume_ignores_best_checkpoint_without_rollout_metric(self):
        """A best checkpoint lacking rollout_avg_reward must not be trusted.

        The stale best file has no rollout metric, so resuming should fall
        back to the resume checkpoint as the best-so-far reference.
        """
        module = self._load_train_vla_module()
        run_training = self._get_run_training(module)
        cfg = self._make_cfg()
        cfg.train.max_steps = 1
        fake_swanlab = FakeSwanLab()
        fake_optimizer = FakeOptimizer(lr=cfg.train.lr)
        fake_scheduler = FakeScheduler()
        real_import_module = importlib.import_module

        def fake_instantiate(config_node, **_kwargs):
            if config_node is cfg.data:
                return FakeDataset()
            if config_node is cfg.agent:
                return FakeAgent()
            raise AssertionError(f'unexpected instantiate config: {config_node!r}')

        def fake_import_module(name, package=None):
            if name == 'swanlab':
                return fake_swanlab
            return real_import_module(name, package)

        with tempfile.TemporaryDirectory() as tempdir:
            previous_cwd = os.getcwd()
            try:
                os.chdir(tempdir)
                checkpoint_dir = Path('checkpoints')
                checkpoint_dir.mkdir()
                resume_path = checkpoint_dir / 'vla_model_step_0.pt'
                resume_path.write_bytes(b'resume')
                best_path = checkpoint_dir / 'vla_model_best.pt'
                best_path.write_bytes(b'stale')
                cfg.train.resume_ckpt = str(resume_path)

                resume_checkpoint_state = {
                    'step': 0,
                    'model_state_dict': FakeAgent().state_dict(),
                    'optimizer_state_dict': {},
                    'scheduler_state_dict': {},
                    'loss': 0.5,
                    'val_loss': 0.25,
                }
                # Note: deliberately no 'rollout_avg_reward' key here.
                stale_best_checkpoint_state = {
                    'step': 0,
                    'model_state_dict': FakeAgent().state_dict(),
                    'optimizer_state_dict': {},
                    'scheduler_state_dict': {},
                    'loss': 0.4,
                    'val_loss': 0.2,
                }

                def fake_torch_load(path, map_location=None):
                    del map_location
                    path = Path(path)
                    if path == resume_path:
                        return resume_checkpoint_state
                    if path == best_path:
                        return stale_best_checkpoint_state
                    raise AssertionError(f'unexpected load path: {path}')

                with mock.patch.object(module, 'instantiate', side_effect=fake_instantiate), \
                        mock.patch.object(module, 'DataLoader', side_effect=self._loader_factory()), \
                        mock.patch.object(module, 'build_training_optimizer', return_value=fake_optimizer), \
                        mock.patch.object(module, 'get_lr_schedule_with_warmup', return_value=fake_scheduler), \
                        mock.patch.object(module, 'tqdm', side_effect=lambda iterable, **kwargs: FakeProgressBar(iterable)), \
                        mock.patch.object(module.torch, 'save', return_value=None), \
                        mock.patch.object(module.torch, 'load', side_effect=fake_torch_load), \
                        mock.patch.object(module.importlib, 'import_module', side_effect=fake_import_module):
                    run_training(cfg)
            finally:
                os.chdir(previous_cwd)

        # The resume checkpoint — not the stale best — wins as best path.
        final_payload, final_step = fake_swanlab.log_calls[-1]
        self.assertEqual(final_step, cfg.train.max_steps)
        self.assertEqual(final_payload['final/best_checkpoint_path'], 'checkpoints/vla_model_step_0.pt')
|
|
|
|
    def test_run_training_ignores_stale_best_checkpoint_file_on_fresh_non_resume_run(self):
        """A leftover best checkpoint on disk must not leak into a fresh run."""
        module = self._load_train_vla_module()
        run_training = self._get_run_training(module)
        cfg = self._make_cfg()
        cfg.train.max_steps = 1
        fake_swanlab = FakeSwanLab()
        real_import_module = importlib.import_module

        def fake_instantiate(config_node, **_kwargs):
            if config_node is cfg.data:
                return FakeDataset()
            if config_node is cfg.agent:
                return FakeAgent()
            raise AssertionError(f'unexpected instantiate config: {config_node!r}')

        def fake_import_module(name, package=None):
            if name == 'swanlab':
                return fake_swanlab
            return real_import_module(name, package)

        with tempfile.TemporaryDirectory() as tempdir:
            previous_cwd = os.getcwd()
            try:
                os.chdir(tempdir)
                # A stale best file exists, but resume_ckpt is None so the
                # run must treat it as irrelevant.
                checkpoint_dir = Path('checkpoints')
                checkpoint_dir.mkdir()
                (checkpoint_dir / 'vla_model_best.pt').write_bytes(b'stale-best')

                with mock.patch.object(module, 'instantiate', side_effect=fake_instantiate), \
                        mock.patch.object(module, 'DataLoader', side_effect=self._loader_factory()), \
                        mock.patch.object(module, 'get_lr_schedule_with_warmup', return_value=FakeScheduler()), \
                        mock.patch.object(module, 'tqdm', side_effect=lambda iterable, **kwargs: FakeProgressBar(iterable)), \
                        mock.patch.object(module.torch, 'save', return_value=None), \
                        mock.patch.object(module.importlib, 'import_module', side_effect=fake_import_module):
                    run_training(cfg)
            finally:
                os.chdir(previous_cwd)

        # With no best produced in one step, the best path logs as empty.
        final_payload, final_step = fake_swanlab.log_calls[-1]
        self.assertEqual(final_step, cfg.train.max_steps)
        self.assertEqual(final_payload['final/best_checkpoint_path'], '')
|
|
|
|
    def test_run_training_fails_fast_when_swanlab_import_is_unavailable(self):
        """Missing swanlab with use_swanlab=True aborts before any training."""
        module = self._load_train_vla_module()
        run_training = self._get_run_training(module)
        cfg = self._make_cfg()
        real_import_module = importlib.import_module

        def fake_import_module(name, package=None):
            if name == 'swanlab':
                raise ImportError('missing swanlab')
            return real_import_module(name, package)

        # instantiate is patched to fail: nothing past the import may run.
        with mock.patch.object(module, 'instantiate', side_effect=AssertionError('instantiate should not run')), \
                mock.patch.object(module.importlib, 'import_module', side_effect=fake_import_module):
            with self.assertRaisesRegex(RuntimeError, 'SwanLab'):
                run_training(cfg)
|
|
|
|
    def test_run_training_fails_fast_when_swanlab_init_fails(self):
        """A failing swanlab.init propagates, and finish() is never called."""
        module = self._load_train_vla_module()
        run_training = self._get_run_training(module)
        cfg = self._make_cfg()
        fake_swanlab = FakeSwanLab(init_error=RuntimeError('not logged in'))
        real_import_module = importlib.import_module

        def fake_import_module(name, package=None):
            if name == 'swanlab':
                return fake_swanlab
            return real_import_module(name, package)

        # instantiate is patched to fail: nothing past init may run.
        with mock.patch.object(module, 'instantiate', side_effect=AssertionError('instantiate should not run')), \
                mock.patch.object(module.importlib, 'import_module', side_effect=fake_import_module):
            with self.assertRaisesRegex(RuntimeError, 'not logged in'):
                run_training(cfg)

        # No successful init means no finish() call.
        self.assertEqual(fake_swanlab.finish_calls, 0)
|
|
|
|
|
|
if __name__ == '__main__':
    # Allow running this test file directly with `python <file>`.
    unittest.main()
|