feat(vla): align transformer training stack and rollout validation
This commit is contained in:

New file: tests/test_train_vla_swanlab_logging.py (699 lines)

@@ -0,0 +1,699 @@
|
||||
import importlib
|
||||
import importlib.util
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
import types
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from unittest import mock
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
# Repository root (this test file lives in <root>/tests/) and the two files
# the suite inspects: the training entry point loaded from its path, and the
# default Hydra config whose text is checked directly.
_REPO_ROOT = Path(__file__).resolve().parents[1]
_TRAIN_VLA_PATH = _REPO_ROOT / 'roboimi/demos/vla_scripts/train_vla.py'
_CONFIG_PATH = _REPO_ROOT / 'roboimi/vla/conf/config.yaml'
|
||||
|
||||
|
||||
class AttrDict(dict):
    """A dict whose entries can also be read and written as attributes.

    Mimics the attribute-style access of an OmegaConf node (``cfg.train.lr``)
    while staying a plain ``dict`` for equality checks in assertions.
    """

    def __getattr__(self, name):
        """Resolve attribute reads through item lookup."""
        try:
            return self[name]
        except KeyError as err:
            # Missing keys must surface as AttributeError (not KeyError) so
            # that getattr()/hasattr() semantics hold for callers.
            raise AttributeError(name) from err

    def __setattr__(self, name, value):
        """Store attribute writes as ordinary dict items."""
        self[name] = value
|
||||
|
||||
|
||||
def _to_attrdict(value):
    """Recursively convert plain dicts in a config tree into AttrDict nodes.

    Lists and tuples are rebuilt with their elements converted, so mapping
    nodes nested inside sequences also gain attribute access.  Scalars pass
    through unchanged.

    Generalization: the original handled only dicts and lists; configs in
    this file use tuples (e.g. ``camera_names=('front',)``), so tuples are
    now recursed as well.  Element-wise conversion preserves equality for
    tuples that contain no dicts, keeping prior behavior intact.
    """
    if isinstance(value, dict):
        return AttrDict({key: _to_attrdict(item) for key, item in value.items()})
    if isinstance(value, list):
        return [_to_attrdict(item) for item in value]
    if isinstance(value, tuple):
        return tuple(_to_attrdict(item) for item in value)
    return value
|
||||
|
||||
|
||||
class FakeDataset:
    """Minimal dataset stand-in; the trainer only needs its length."""

    def __len__(self):
        # Four samples: with batch_size=2 and val_split=0.25 this yields a
        # non-empty train and validation split.
        return 4
|
||||
|
||||
|
||||
class FakeLoader:
    """DataLoader replacement that yields one fixed batch per iteration."""

    def __init__(self, batch):
        # The single batch every epoch will produce.
        self.batch = batch

    def __len__(self):
        return 1

    def __iter__(self):
        # Fresh single-element iterator on every call, like a real loader.
        yield self.batch
|
||||
|
||||
|
||||
class FakeScheduler:
    """LR-scheduler stub that merely counts ``step()`` invocations."""

    def __init__(self):
        # Number of times step() has been called; asserted on by tests.
        self.step_calls = 0

    def step(self):
        """Record one scheduler tick."""
        self.step_calls = self.step_calls + 1

    def state_dict(self):
        """Nothing to persist — return an empty mapping."""
        return {}

    def load_state_dict(self, state_dict):
        """Accept and discard any persisted scheduler state."""
        return None
|
||||
|
||||
|
||||
class FakeOptimizer:
    """Optimizer stub exposing just the surface the training loop touches."""

    def __init__(self, lr=1e-3):
        # A single param group so lr reads/writes behave like a real optimizer.
        self.param_groups = [{'lr': lr}]
        # Captures whatever load_state_dict() received, for assertions.
        self.loaded_state_dict = None

    def zero_grad(self):
        """No gradients to clear."""
        return None

    def step(self):
        """No parameters to update."""
        return None

    def state_dict(self):
        """Nothing to persist — return an empty mapping."""
        return {}

    def load_state_dict(self, state_dict):
        """Remember the restored state so tests can inspect it."""
        self.loaded_state_dict = state_dict
        return None
|
||||
|
||||
|
||||
class FakeProgressBar:
    """tqdm replacement: plain iteration plus a record of postfix updates."""

    def __init__(self, iterable):
        # Materialize once so iteration is repeatable and deterministic.
        self._items = list(iterable)
        # Every set_postfix() payload, in call order.
        self.postfix_calls = []

    def __iter__(self):
        return iter(self._items)

    def set_postfix(self, values):
        """Record the postfix payload instead of rendering it."""
        self.postfix_calls.append(values)
|
||||
|
||||
|
||||
class FakeAgent(nn.Module):
    """Tiny trainable module standing in for the real VLA agent.

    Owns a single scalar parameter and returns a deterministic quadratic
    loss whose target differs between train and eval mode, so the training
    loop gets a real gradient signal and train/val losses are telling apart.
    """

    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.tensor(0.0))

    def to(self, *args, **kwargs):
        # No-op device/dtype move: these tests always run on CPU.
        # Fix: accept the full nn.Module.to() signature instead of a single
        # positional `device`, so calls like .to(device=...) or .to(dtype)
        # from the code under test do not raise TypeError.
        return self

    def compute_loss(self, agent_input):
        """Return (weight - target)^2; the input batch is ignored."""
        del agent_input
        # Train and eval modes target different constants so callers can
        # tell which mode produced a given loss value.
        target = torch.tensor(0.25 if self.training else 0.1)
        return (self.weight - target).pow(2)

    def get_normalization_stats(self):
        """No normalization in the fake — empty stats."""
        return {}
|
||||
|
||||
|
||||
class FakeSwanLab:
    """Recording double for the ``swanlab`` module.

    Captures every init/log/finish call and can be armed with exceptions
    to simulate backend failures at each stage of the run lifecycle.
    """

    def __init__(self, init_error=None, log_errors=None, finish_error=None):
        # Exceptions to raise: once on init, one per log call (FIFO), and
        # on every finish, respectively.  None/empty means "succeed".
        self.init_error = init_error
        self.log_errors = list(log_errors or [])
        self.finish_error = finish_error
        # Call records asserted on by the tests.
        self.init_calls = []
        self.log_calls = []
        self.finish_calls = 0

    def init(self, project, experiment_name=None, config=None):
        """Record the run configuration, optionally failing like a backend."""
        record = {
            'project': project,
            'experiment_name': experiment_name,
            'config': config,
        }
        self.init_calls.append(record)
        if self.init_error is not None:
            raise self.init_error
        # Opaque run handle, as the real client returns a run object.
        return object()

    def log(self, payload, step=None):
        """Snapshot the payload (so later caller mutation is invisible)."""
        self.log_calls.append((dict(payload), step))
        if self.log_errors:
            raise self.log_errors.pop(0)

    def finish(self):
        """Count the shutdown call, optionally failing like a backend."""
        self.finish_calls += 1
        if self.finish_error is not None:
            raise self.finish_error
|
||||
|
||||
|
||||
class TrainVLASwanLabLoggingTest(unittest.TestCase):
    """End-to-end tests of SwanLab logging inside train_vla's training loop.

    The training module is loaded straight from its file path with hydra and
    omegaconf replaced by in-memory stubs, and all heavy collaborators
    (dataset, agent, DataLoader, scheduler, tqdm, torch.save/load, swanlab)
    patched with the fakes defined above.  Each test chdirs into a temp dir
    so checkpoint paths like 'checkpoints/...' never touch the repo.
    """

    def test_default_config_keeps_swanlab_opt_in(self):
        """The shipped YAML config must keep SwanLab disabled by default."""
        config_text = _CONFIG_PATH.read_text(encoding='utf-8')
        self.assertIn('use_swanlab: false', config_text)

    def _load_train_vla_module(self):
        """Import train_vla.py from its path with hydra/omegaconf stubbed.

        Returns the freshly-executed module object.  The stubs only cover
        the surface train_vla.py touches at import time and inside
        _run_training (main decorator, instantiate, OmegaConf helpers).
        """
        hydra_module = types.ModuleType('hydra')
        hydra_utils_module = types.ModuleType('hydra.utils')
        hydra_utils_module.instantiate = lambda *args, **kwargs: None

        # hydra.main(...) is used as a decorator factory; the stub returns
        # the wrapped function unchanged.
        def hydra_main(**_kwargs):
            def decorator(func):
                return func
            return decorator

        hydra_module.main = hydra_main
        hydra_module.utils = hydra_utils_module

        class OmegaConfStub:
            # Shared resolver registry, mirroring OmegaConf's global one.
            _resolvers = {}

            @classmethod
            def has_resolver(cls, name):
                return name in cls._resolvers

            @classmethod
            def register_new_resolver(cls, name, resolver):
                cls._resolvers[name] = resolver

            @staticmethod
            def to_yaml(_cfg):
                # Only used for logging; the content is irrelevant here.
                return 'stub-config'

            @staticmethod
            def to_container(cfg, resolve=False):
                # Shallow conversion is enough: cfg nodes are AttrDicts,
                # which already behave as dicts.
                del resolve
                return dict(cfg)

            @staticmethod
            def create(cfg):
                return _to_attrdict(cfg)

        omegaconf_module = types.ModuleType('omegaconf')
        omegaconf_module.DictConfig = dict
        omegaconf_module.OmegaConf = OmegaConfStub

        module_name = 'train_vla_swanlab_test_module'
        spec = importlib.util.spec_from_file_location(module_name, _TRAIN_VLA_PATH)
        module = importlib.util.module_from_spec(spec)
        # Patch sys.modules only for the duration of the import so the
        # stubs never leak into other tests.
        with mock.patch.dict(
            sys.modules,
            {
                'hydra': hydra_module,
                'hydra.utils': hydra_utils_module,
                'omegaconf': omegaconf_module,
            },
        ):
            assert spec.loader is not None
            spec.loader.exec_module(module)
        return module

    def _make_cfg(self, *, use_swanlab=True, swanlab_run_name='smoke-run'):
        """Build a minimal-but-complete training config as nested AttrDicts.

        max_steps/save_freq/log_freq are tiny so a full run finishes in a
        couple of loop iterations.
        """
        return AttrDict(
            train=AttrDict(
                device='cpu',
                batch_size=2,
                num_workers=0,
                val_split=0.25,
                seed=0,
                lr=1e-3,
                max_steps=2,
                log_freq=1,
                save_freq=1,
                warmup_steps=1,
                scheduler_type='constant',
                min_lr=0.0,
                grad_clip=1.0,
                weight_decay=0.0,
                pretrained_ckpt=None,
                resume_ckpt=None,
                use_swanlab=use_swanlab,
                swanlab_project='roboimi-vla-tests',
                swanlab_run_name=swanlab_run_name,
            ),
            data=AttrDict(
                camera_names=('front',),
            ),
            agent=AttrDict(
                _target_='fake.agent',
            ),
            eval=AttrDict(
                ckpt_path='unused.pt',
                num_episodes=1,
                max_timesteps=1,
                device='cpu',
                task_name='sim_transfer',
                camera_names=('front',),
                use_smoothing=False,
                smooth_alpha=0.3,
                verbose_action=False,
                headless=False,
            ),
        )

    def _get_run_training(self, module):
        """Fetch the _run_training helper, failing with a clear message."""
        run_training = getattr(module, '_run_training', None)
        self.assertIsNotNone(run_training, 'Expected train_vla.py to expose a _run_training(cfg) helper')
        return run_training

    def _make_batch(self):
        """One LeRobot-style batch of zeros matching the cfg camera names."""
        return {
            'observation.front': torch.zeros(1, 3, 2, 2),
            'observation.state': torch.zeros(1, 4),
            'action': torch.zeros(1, 2),
            'action_is_pad': torch.zeros(1, 1, dtype=torch.bool),
        }

    def _loader_factory(self):
        """Return a DataLoader-compatible factory.

        The trainer builds the train loader with shuffle=True and the val
        loader with shuffle=False; distinct batches let tests tell which
        loader was used.
        """
        train_batch = self._make_batch()
        val_batch = self._make_batch()

        def factory(_dataset, *, shuffle, **_kwargs):
            return FakeLoader(train_batch if shuffle else val_batch)

        return factory

    def test_run_training_logs_metrics_and_checkpoint_paths_to_swanlab(self):
        """Happy path: init config, per-step metrics, and final paths land in SwanLab."""
        module = self._load_train_vla_module()
        run_training = self._get_run_training(module)
        cfg = self._make_cfg()
        agent = FakeAgent()
        fake_swanlab = FakeSwanLab()
        real_import_module = importlib.import_module

        def fake_instantiate(config_node, **_kwargs):
            # Identity-compare against the cfg nodes to make sure the
            # trainer instantiates exactly the configured targets.
            if config_node is cfg.data:
                return FakeDataset()
            if config_node is cfg.agent:
                return agent
            raise AssertionError(f'unexpected instantiate config: {config_node!r}')

        def fake_import_module(name, package=None):
            if name == 'swanlab':
                return fake_swanlab
            return real_import_module(name, package)

        with tempfile.TemporaryDirectory() as tempdir:
            previous_cwd = os.getcwd()
            try:
                os.chdir(tempdir)
                with mock.patch.object(module, 'instantiate', side_effect=fake_instantiate), \
                     mock.patch.object(module, 'DataLoader', side_effect=self._loader_factory()), \
                     mock.patch.object(module, 'get_lr_schedule_with_warmup', return_value=FakeScheduler()), \
                     mock.patch.object(module, 'tqdm', side_effect=lambda iterable, **kwargs: FakeProgressBar(iterable)), \
                     mock.patch.object(module.torch, 'save', return_value=None), \
                     mock.patch.object(module.importlib, 'import_module', side_effect=fake_import_module):
                    run_training(cfg)
            finally:
                os.chdir(previous_cwd)

        # The full resolved config must be forwarded to swanlab.init once.
        self.assertEqual(
            fake_swanlab.init_calls,
            [{
                'project': 'roboimi-vla-tests',
                'experiment_name': 'smoke-run',
                'config': {
                    'train': {
                        'device': 'cpu',
                        'batch_size': 2,
                        'num_workers': 0,
                        'val_split': 0.25,
                        'seed': 0,
                        'lr': 1e-3,
                        'max_steps': 2,
                        'log_freq': 1,
                        'save_freq': 1,
                        'warmup_steps': 1,
                        'scheduler_type': 'constant',
                        'min_lr': 0.0,
                        'grad_clip': 1.0,
                        'weight_decay': 0.0,
                        'pretrained_ckpt': None,
                        'resume_ckpt': None,
                        'use_swanlab': True,
                        'swanlab_project': 'roboimi-vla-tests',
                        'swanlab_run_name': 'smoke-run',
                    },
                    'data': {
                        'camera_names': ('front',),
                    },
                    'agent': {
                        '_target_': 'fake.agent',
                    },
                },
            }],
        )

        # Union of every logged payload's keys across all log() calls.
        logged_keys = set().union(*(payload.keys() for payload, _step in fake_swanlab.log_calls))
        self.assertTrue(
            {
                'train/loss',
                'train/lr',
                'train/best_loss',
                'train/step',
                'val/loss',
                'final/checkpoint_path',
                'final/best_checkpoint_path',
            }.issubset(logged_keys)
        )

        # The last log call carries the final artifact paths at max_steps.
        final_payload, final_step = fake_swanlab.log_calls[-1]
        self.assertEqual(final_step, cfg.train.max_steps)
        self.assertEqual(final_payload['final/checkpoint_path'], 'checkpoints/vla_model_final.pt')
        self.assertEqual(final_payload['final/best_checkpoint_path'], 'checkpoints/vla_model_best.pt')
        self.assertEqual(fake_swanlab.finish_calls, 1)

    def test_run_training_skips_swanlab_when_disabled(self):
        """With use_swanlab=False the swanlab module must never be imported."""
        module = self._load_train_vla_module()
        run_training = self._get_run_training(module)
        cfg = self._make_cfg(use_swanlab=False)
        agent = FakeAgent()

        def fake_instantiate(config_node, **_kwargs):
            if config_node is cfg.data:
                return FakeDataset()
            if config_node is cfg.agent:
                return agent
            raise AssertionError(f'unexpected instantiate config: {config_node!r}')

        with tempfile.TemporaryDirectory() as tempdir:
            previous_cwd = os.getcwd()
            try:
                os.chdir(tempdir)
                # import_module is armed to explode, so any swanlab import
                # (or any dynamic import at all) fails the test loudly.
                with mock.patch.object(module, 'instantiate', side_effect=fake_instantiate), \
                     mock.patch.object(module, 'DataLoader', side_effect=self._loader_factory()), \
                     mock.patch.object(module, 'get_lr_schedule_with_warmup', return_value=FakeScheduler()), \
                     mock.patch.object(module, 'tqdm', side_effect=lambda iterable, **kwargs: FakeProgressBar(iterable)), \
                     mock.patch.object(module.torch, 'save', return_value=None), \
                     mock.patch.object(module.importlib, 'import_module', side_effect=AssertionError('swanlab import should not run')):
                    run_training(cfg)
            finally:
                os.chdir(previous_cwd)

    def test_run_training_finishes_swanlab_when_exception_happens_after_init(self):
        """A crash after swanlab.init must still trigger exactly one finish()."""
        module = self._load_train_vla_module()
        run_training = self._get_run_training(module)
        cfg = self._make_cfg()
        fake_swanlab = FakeSwanLab()
        real_import_module = importlib.import_module

        def fake_import_module(name, package=None):
            if name == 'swanlab':
                return fake_swanlab
            return real_import_module(name, package)

        with tempfile.TemporaryDirectory() as tempdir:
            previous_cwd = os.getcwd()
            try:
                os.chdir(tempdir)
                # instantiate blows up right after init, simulating a
                # dataset construction failure.
                with mock.patch.object(module, 'instantiate', side_effect=RuntimeError('dataset boom')), \
                     mock.patch.object(module.importlib, 'import_module', side_effect=fake_import_module):
                    with self.assertRaisesRegex(RuntimeError, 'dataset boom'):
                        run_training(cfg)
            finally:
                os.chdir(previous_cwd)

        self.assertEqual(fake_swanlab.finish_calls, 1)

    def test_run_training_warns_and_continues_when_swanlab_log_and_finish_fail(self):
        """Backend hiccups in log()/finish() must warn, not abort training."""
        module = self._load_train_vla_module()
        run_training = self._get_run_training(module)
        cfg = self._make_cfg()
        agent = FakeAgent()
        fake_swanlab = FakeSwanLab(
            log_errors=[RuntimeError('log backend hiccup')],
            finish_error=RuntimeError('finish backend hiccup'),
        )
        real_import_module = importlib.import_module

        def fake_instantiate(config_node, **_kwargs):
            if config_node is cfg.data:
                return FakeDataset()
            if config_node is cfg.agent:
                return agent
            raise AssertionError(f'unexpected instantiate config: {config_node!r}')

        def fake_import_module(name, package=None):
            if name == 'swanlab':
                return fake_swanlab
            return real_import_module(name, package)

        with tempfile.TemporaryDirectory() as tempdir:
            previous_cwd = os.getcwd()
            try:
                os.chdir(tempdir)
                with mock.patch.object(module, 'instantiate', side_effect=fake_instantiate), \
                     mock.patch.object(module, 'DataLoader', side_effect=self._loader_factory()), \
                     mock.patch.object(module, 'get_lr_schedule_with_warmup', return_value=FakeScheduler()), \
                     mock.patch.object(module, 'tqdm', side_effect=lambda iterable, **kwargs: FakeProgressBar(iterable)), \
                     mock.patch.object(module.torch, 'save', return_value=None), \
                     mock.patch.object(module.importlib, 'import_module', side_effect=fake_import_module), \
                     mock.patch.object(module.log, 'warning') as warning_mock:
                    run_training(cfg)
            finally:
                os.chdir(previous_cwd)

        # Both failures must be downgraded to log.warning calls.
        warning_messages = [call.args[0] for call in warning_mock.call_args_list]
        self.assertTrue(any('SwanLab log failed' in message for message in warning_messages))
        self.assertTrue(any('SwanLab finish failed' in message for message in warning_messages))
        self.assertEqual(fake_swanlab.finish_calls, 1)

    def test_run_training_resume_restores_best_rollout_baseline_from_best_checkpoint(self):
        """Resuming must re-read the rollout baseline from the best checkpoint.

        The existing best checkpoint carries rollout_avg_reward=5.0; rollout
        validation during the resumed run only scores 3.0, so the best
        checkpoint must NOT be overwritten.
        """
        module = self._load_train_vla_module()
        run_training = self._get_run_training(module)
        cfg = self._make_cfg()
        cfg.train.max_steps = 2
        cfg.train.save_freq = 1
        cfg.train.rollout_validate_on_checkpoint = True
        fake_swanlab = FakeSwanLab()
        fake_optimizer = FakeOptimizer(lr=cfg.train.lr)
        fake_scheduler = FakeScheduler()
        real_import_module = importlib.import_module
        saved_paths = []

        def fake_instantiate(config_node, **_kwargs):
            if config_node is cfg.data:
                return FakeDataset()
            if config_node is cfg.agent:
                return FakeAgent()
            raise AssertionError(f'unexpected instantiate config: {config_node!r}')

        def fake_import_module(name, package=None):
            if name == 'swanlab':
                return fake_swanlab
            return real_import_module(name, package)

        with tempfile.TemporaryDirectory() as tempdir:
            previous_cwd = os.getcwd()
            try:
                os.chdir(tempdir)
                # Pre-create the checkpoint files a previous run would have
                # left behind; torch.load is faked, so contents are dummies.
                checkpoint_dir = Path('checkpoints')
                checkpoint_dir.mkdir()
                resume_path = checkpoint_dir / 'vla_model_step_0.pt'
                resume_path.write_bytes(b'resume')
                best_path = checkpoint_dir / 'vla_model_best.pt'
                best_path.write_bytes(b'best')
                cfg.train.resume_ckpt = str(resume_path)

                resume_checkpoint_state = {
                    'step': 0,
                    'model_state_dict': FakeAgent().state_dict(),
                    'optimizer_state_dict': {},
                    'scheduler_state_dict': {},
                    'loss': 0.5,
                    'val_loss': 0.25,
                }
                # The best checkpoint additionally records the rollout
                # baseline that must survive the resume.
                best_checkpoint_state = {
                    'step': 0,
                    'model_state_dict': FakeAgent().state_dict(),
                    'optimizer_state_dict': {},
                    'scheduler_state_dict': {},
                    'loss': 0.5,
                    'val_loss': 0.25,
                    'rollout_avg_reward': 5.0,
                }

                def fake_torch_load(path, map_location=None):
                    del map_location
                    path = Path(path)
                    if path == resume_path:
                        return resume_checkpoint_state
                    if path == best_path:
                        return best_checkpoint_state
                    raise AssertionError(f'unexpected load path: {path}')

                def fake_torch_save(payload, path):
                    # Track destinations so the test can assert the best
                    # checkpoint was never rewritten.
                    saved_paths.append(str(path))
                    return None

                with mock.patch.object(module, 'instantiate', side_effect=fake_instantiate), \
                     mock.patch.object(module, 'DataLoader', side_effect=self._loader_factory()), \
                     mock.patch.object(module, 'build_training_optimizer', return_value=fake_optimizer), \
                     mock.patch.object(module, 'get_lr_schedule_with_warmup', return_value=fake_scheduler), \
                     mock.patch.object(module, 'tqdm', side_effect=lambda iterable, **kwargs: FakeProgressBar(iterable)), \
                     mock.patch.object(module.torch, 'save', side_effect=fake_torch_save), \
                     mock.patch.object(module.torch, 'load', side_effect=fake_torch_load), \
                     mock.patch.object(module.importlib, 'import_module', side_effect=fake_import_module), \
                     mock.patch('roboimi.demos.vla_scripts.eval_vla._run_eval', return_value={'avg_reward': 3.0}):
                    run_training(cfg)
            finally:
                os.chdir(previous_cwd)

        final_payload, final_step = fake_swanlab.log_calls[-1]
        self.assertEqual(final_step, cfg.train.max_steps)
        self.assertEqual(final_payload['final/best_checkpoint_path'], 'checkpoints/vla_model_best.pt')
        # 3.0 < 5.0 baseline, so the pre-existing best must not be replaced.
        self.assertNotIn('checkpoints/vla_model_best.pt', saved_paths)

    def test_run_training_resume_ignores_best_checkpoint_without_rollout_metric(self):
        """A best checkpoint lacking rollout_avg_reward is treated as stale on resume."""
        module = self._load_train_vla_module()
        run_training = self._get_run_training(module)
        cfg = self._make_cfg()
        cfg.train.max_steps = 1
        fake_swanlab = FakeSwanLab()
        fake_optimizer = FakeOptimizer(lr=cfg.train.lr)
        fake_scheduler = FakeScheduler()
        real_import_module = importlib.import_module

        def fake_instantiate(config_node, **_kwargs):
            if config_node is cfg.data:
                return FakeDataset()
            if config_node is cfg.agent:
                return FakeAgent()
            raise AssertionError(f'unexpected instantiate config: {config_node!r}')

        def fake_import_module(name, package=None):
            if name == 'swanlab':
                return fake_swanlab
            return real_import_module(name, package)

        with tempfile.TemporaryDirectory() as tempdir:
            previous_cwd = os.getcwd()
            try:
                os.chdir(tempdir)
                checkpoint_dir = Path('checkpoints')
                checkpoint_dir.mkdir()
                resume_path = checkpoint_dir / 'vla_model_step_0.pt'
                resume_path.write_bytes(b'resume')
                best_path = checkpoint_dir / 'vla_model_best.pt'
                best_path.write_bytes(b'stale')
                cfg.train.resume_ckpt = str(resume_path)

                resume_checkpoint_state = {
                    'step': 0,
                    'model_state_dict': FakeAgent().state_dict(),
                    'optimizer_state_dict': {},
                    'scheduler_state_dict': {},
                    'loss': 0.5,
                    'val_loss': 0.25,
                }
                # Note: no 'rollout_avg_reward' key — the stale case.
                stale_best_checkpoint_state = {
                    'step': 0,
                    'model_state_dict': FakeAgent().state_dict(),
                    'optimizer_state_dict': {},
                    'scheduler_state_dict': {},
                    'loss': 0.4,
                    'val_loss': 0.2,
                }

                def fake_torch_load(path, map_location=None):
                    del map_location
                    path = Path(path)
                    if path == resume_path:
                        return resume_checkpoint_state
                    if path == best_path:
                        return stale_best_checkpoint_state
                    raise AssertionError(f'unexpected load path: {path}')

                with mock.patch.object(module, 'instantiate', side_effect=fake_instantiate), \
                     mock.patch.object(module, 'DataLoader', side_effect=self._loader_factory()), \
                     mock.patch.object(module, 'build_training_optimizer', return_value=fake_optimizer), \
                     mock.patch.object(module, 'get_lr_schedule_with_warmup', return_value=fake_scheduler), \
                     mock.patch.object(module, 'tqdm', side_effect=lambda iterable, **kwargs: FakeProgressBar(iterable)), \
                     mock.patch.object(module.torch, 'save', return_value=None), \
                     mock.patch.object(module.torch, 'load', side_effect=fake_torch_load), \
                     mock.patch.object(module.importlib, 'import_module', side_effect=fake_import_module):
                    run_training(cfg)
            finally:
                os.chdir(previous_cwd)

        final_payload, final_step = fake_swanlab.log_calls[-1]
        self.assertEqual(final_step, cfg.train.max_steps)
        # The stale best is ignored; the resume checkpoint remains "best".
        self.assertEqual(final_payload['final/best_checkpoint_path'], 'checkpoints/vla_model_step_0.pt')

    def test_run_training_ignores_stale_best_checkpoint_file_on_fresh_non_resume_run(self):
        """A leftover best-checkpoint file must not count on a fresh run."""
        module = self._load_train_vla_module()
        run_training = self._get_run_training(module)
        cfg = self._make_cfg()
        cfg.train.max_steps = 1
        fake_swanlab = FakeSwanLab()
        real_import_module = importlib.import_module

        def fake_instantiate(config_node, **_kwargs):
            if config_node is cfg.data:
                return FakeDataset()
            if config_node is cfg.agent:
                return FakeAgent()
            raise AssertionError(f'unexpected instantiate config: {config_node!r}')

        def fake_import_module(name, package=None):
            if name == 'swanlab':
                return fake_swanlab
            return real_import_module(name, package)

        with tempfile.TemporaryDirectory() as tempdir:
            previous_cwd = os.getcwd()
            try:
                os.chdir(tempdir)
                checkpoint_dir = Path('checkpoints')
                checkpoint_dir.mkdir()
                # Leftover file from some earlier, unrelated run.
                (checkpoint_dir / 'vla_model_best.pt').write_bytes(b'stale-best')

                with mock.patch.object(module, 'instantiate', side_effect=fake_instantiate), \
                     mock.patch.object(module, 'DataLoader', side_effect=self._loader_factory()), \
                     mock.patch.object(module, 'get_lr_schedule_with_warmup', return_value=FakeScheduler()), \
                     mock.patch.object(module, 'tqdm', side_effect=lambda iterable, **kwargs: FakeProgressBar(iterable)), \
                     mock.patch.object(module.torch, 'save', return_value=None), \
                     mock.patch.object(module.importlib, 'import_module', side_effect=fake_import_module):
                    run_training(cfg)
            finally:
                os.chdir(previous_cwd)

        final_payload, final_step = fake_swanlab.log_calls[-1]
        self.assertEqual(final_step, cfg.train.max_steps)
        # No best checkpoint was produced by THIS run, so the path is empty.
        self.assertEqual(final_payload['final/best_checkpoint_path'], '')

    def test_run_training_fails_fast_when_swanlab_import_is_unavailable(self):
        """Missing swanlab with use_swanlab=True must abort before any training."""
        module = self._load_train_vla_module()
        run_training = self._get_run_training(module)
        cfg = self._make_cfg()
        real_import_module = importlib.import_module

        def fake_import_module(name, package=None):
            if name == 'swanlab':
                raise ImportError('missing swanlab')
            return real_import_module(name, package)

        # instantiate is armed to explode: the failure must happen before
        # any dataset/agent construction.
        with mock.patch.object(module, 'instantiate', side_effect=AssertionError('instantiate should not run')), \
             mock.patch.object(module.importlib, 'import_module', side_effect=fake_import_module):
            with self.assertRaisesRegex(RuntimeError, 'SwanLab'):
                run_training(cfg)

    def test_run_training_fails_fast_when_swanlab_init_fails(self):
        """An init failure must propagate and must not trigger finish()."""
        module = self._load_train_vla_module()
        run_training = self._get_run_training(module)
        cfg = self._make_cfg()
        fake_swanlab = FakeSwanLab(init_error=RuntimeError('not logged in'))
        real_import_module = importlib.import_module

        def fake_import_module(name, package=None):
            if name == 'swanlab':
                return fake_swanlab
            return real_import_module(name, package)

        with mock.patch.object(module, 'instantiate', side_effect=AssertionError('instantiate should not run')), \
             mock.patch.object(module.importlib, 'import_module', side_effect=fake_import_module):
            with self.assertRaisesRegex(RuntimeError, 'not logged in'):
                run_training(cfg)

        # finish() belongs to a successfully-initialized run only.
        self.assertEqual(fake_swanlab.finish_calls, 0)
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Allow running this file directly, outside a pytest/unittest runner.
    unittest.main()
|
||||
Reference in New Issue
Block a user