feat(vla): align transformer training stack and rollout validation

This commit is contained in:
Logic
2026-03-31 15:39:20 +08:00
parent 424c265823
commit d84bc6876e
25 changed files with 4043 additions and 706 deletions

View File

@@ -0,0 +1,699 @@
import importlib
import importlib.util
import os
import sys
import tempfile
import types
import unittest
from pathlib import Path
from unittest import mock
import torch
from torch import nn
_REPO_ROOT = Path(__file__).resolve().parents[1]
_TRAIN_VLA_PATH = _REPO_ROOT / 'roboimi/demos/vla_scripts/train_vla.py'
_CONFIG_PATH = _REPO_ROOT / 'roboimi/vla/conf/config.yaml'
class AttrDict(dict):
def __getattr__(self, name):
try:
return self[name]
except KeyError as exc:
raise AttributeError(name) from exc
def __setattr__(self, name, value):
self[name] = value
def _to_attrdict(value):
if isinstance(value, dict):
return AttrDict({key: _to_attrdict(item) for key, item in value.items()})
if isinstance(value, list):
return [_to_attrdict(item) for item in value]
return value
class FakeDataset:
def __len__(self):
return 4
class FakeLoader:
def __init__(self, batch):
self.batch = batch
def __len__(self):
return 1
def __iter__(self):
return iter((self.batch,))
class FakeScheduler:
def __init__(self):
self.step_calls = 0
def step(self):
self.step_calls += 1
def state_dict(self):
return {}
def load_state_dict(self, state_dict):
return None
class FakeOptimizer:
def __init__(self, lr=1e-3):
self.param_groups = [{'lr': lr}]
self.loaded_state_dict = None
def zero_grad(self):
return None
def step(self):
return None
def state_dict(self):
return {}
def load_state_dict(self, state_dict):
self.loaded_state_dict = state_dict
return None
class FakeProgressBar:
def __init__(self, iterable):
self._items = list(iterable)
self.postfix_calls = []
def __iter__(self):
return iter(self._items)
def set_postfix(self, values):
self.postfix_calls.append(values)
class FakeAgent(nn.Module):
def __init__(self):
super().__init__()
self.weight = nn.Parameter(torch.tensor(0.0))
def to(self, device):
return self
def compute_loss(self, agent_input):
del agent_input
target = torch.tensor(0.25 if self.training else 0.1)
return (self.weight - target).pow(2)
def get_normalization_stats(self):
return {}
class FakeSwanLab:
def __init__(self, init_error=None, log_errors=None, finish_error=None):
self.init_error = init_error
self.log_errors = list(log_errors or [])
self.finish_error = finish_error
self.init_calls = []
self.log_calls = []
self.finish_calls = 0
def init(self, project, experiment_name=None, config=None):
self.init_calls.append({
'project': project,
'experiment_name': experiment_name,
'config': config,
})
if self.init_error is not None:
raise self.init_error
return object()
def log(self, payload, step=None):
self.log_calls.append((dict(payload), step))
if self.log_errors:
raise self.log_errors.pop(0)
def finish(self):
self.finish_calls += 1
if self.finish_error is not None:
raise self.finish_error
class TrainVLASwanLabLoggingTest(unittest.TestCase):
def test_default_config_keeps_swanlab_opt_in(self):
config_text = _CONFIG_PATH.read_text(encoding='utf-8')
self.assertIn('use_swanlab: false', config_text)
def _load_train_vla_module(self):
hydra_module = types.ModuleType('hydra')
hydra_utils_module = types.ModuleType('hydra.utils')
hydra_utils_module.instantiate = lambda *args, **kwargs: None
def hydra_main(**_kwargs):
def decorator(func):
return func
return decorator
hydra_module.main = hydra_main
hydra_module.utils = hydra_utils_module
class OmegaConfStub:
_resolvers = {}
@classmethod
def has_resolver(cls, name):
return name in cls._resolvers
@classmethod
def register_new_resolver(cls, name, resolver):
cls._resolvers[name] = resolver
@staticmethod
def to_yaml(_cfg):
return 'stub-config'
@staticmethod
def to_container(cfg, resolve=False):
del resolve
return dict(cfg)
@staticmethod
def create(cfg):
return _to_attrdict(cfg)
omegaconf_module = types.ModuleType('omegaconf')
omegaconf_module.DictConfig = dict
omegaconf_module.OmegaConf = OmegaConfStub
module_name = 'train_vla_swanlab_test_module'
spec = importlib.util.spec_from_file_location(module_name, _TRAIN_VLA_PATH)
module = importlib.util.module_from_spec(spec)
with mock.patch.dict(
sys.modules,
{
'hydra': hydra_module,
'hydra.utils': hydra_utils_module,
'omegaconf': omegaconf_module,
},
):
assert spec.loader is not None
spec.loader.exec_module(module)
return module
def _make_cfg(self, *, use_swanlab=True, swanlab_run_name='smoke-run'):
return AttrDict(
train=AttrDict(
device='cpu',
batch_size=2,
num_workers=0,
val_split=0.25,
seed=0,
lr=1e-3,
max_steps=2,
log_freq=1,
save_freq=1,
warmup_steps=1,
scheduler_type='constant',
min_lr=0.0,
grad_clip=1.0,
weight_decay=0.0,
pretrained_ckpt=None,
resume_ckpt=None,
use_swanlab=use_swanlab,
swanlab_project='roboimi-vla-tests',
swanlab_run_name=swanlab_run_name,
),
data=AttrDict(
camera_names=('front',),
),
agent=AttrDict(
_target_='fake.agent',
),
eval=AttrDict(
ckpt_path='unused.pt',
num_episodes=1,
max_timesteps=1,
device='cpu',
task_name='sim_transfer',
camera_names=('front',),
use_smoothing=False,
smooth_alpha=0.3,
verbose_action=False,
headless=False,
),
)
def _get_run_training(self, module):
run_training = getattr(module, '_run_training', None)
self.assertIsNotNone(run_training, 'Expected train_vla.py to expose a _run_training(cfg) helper')
return run_training
def _make_batch(self):
return {
'observation.front': torch.zeros(1, 3, 2, 2),
'observation.state': torch.zeros(1, 4),
'action': torch.zeros(1, 2),
'action_is_pad': torch.zeros(1, 1, dtype=torch.bool),
}
def _loader_factory(self):
train_batch = self._make_batch()
val_batch = self._make_batch()
def factory(_dataset, *, shuffle, **_kwargs):
return FakeLoader(train_batch if shuffle else val_batch)
return factory
def test_run_training_logs_metrics_and_checkpoint_paths_to_swanlab(self):
module = self._load_train_vla_module()
run_training = self._get_run_training(module)
cfg = self._make_cfg()
agent = FakeAgent()
fake_swanlab = FakeSwanLab()
real_import_module = importlib.import_module
def fake_instantiate(config_node, **_kwargs):
if config_node is cfg.data:
return FakeDataset()
if config_node is cfg.agent:
return agent
raise AssertionError(f'unexpected instantiate config: {config_node!r}')
def fake_import_module(name, package=None):
if name == 'swanlab':
return fake_swanlab
return real_import_module(name, package)
with tempfile.TemporaryDirectory() as tempdir:
previous_cwd = os.getcwd()
try:
os.chdir(tempdir)
with mock.patch.object(module, 'instantiate', side_effect=fake_instantiate), \
mock.patch.object(module, 'DataLoader', side_effect=self._loader_factory()), \
mock.patch.object(module, 'get_lr_schedule_with_warmup', return_value=FakeScheduler()), \
mock.patch.object(module, 'tqdm', side_effect=lambda iterable, **kwargs: FakeProgressBar(iterable)), \
mock.patch.object(module.torch, 'save', return_value=None), \
mock.patch.object(module.importlib, 'import_module', side_effect=fake_import_module):
run_training(cfg)
finally:
os.chdir(previous_cwd)
self.assertEqual(
fake_swanlab.init_calls,
[{
'project': 'roboimi-vla-tests',
'experiment_name': 'smoke-run',
'config': {
'train': {
'device': 'cpu',
'batch_size': 2,
'num_workers': 0,
'val_split': 0.25,
'seed': 0,
'lr': 1e-3,
'max_steps': 2,
'log_freq': 1,
'save_freq': 1,
'warmup_steps': 1,
'scheduler_type': 'constant',
'min_lr': 0.0,
'grad_clip': 1.0,
'weight_decay': 0.0,
'pretrained_ckpt': None,
'resume_ckpt': None,
'use_swanlab': True,
'swanlab_project': 'roboimi-vla-tests',
'swanlab_run_name': 'smoke-run',
},
'data': {
'camera_names': ('front',),
},
'agent': {
'_target_': 'fake.agent',
},
},
}],
)
logged_keys = set().union(*(payload.keys() for payload, _step in fake_swanlab.log_calls))
self.assertTrue(
{
'train/loss',
'train/lr',
'train/best_loss',
'train/step',
'val/loss',
'final/checkpoint_path',
'final/best_checkpoint_path',
}.issubset(logged_keys)
)
final_payload, final_step = fake_swanlab.log_calls[-1]
self.assertEqual(final_step, cfg.train.max_steps)
self.assertEqual(final_payload['final/checkpoint_path'], 'checkpoints/vla_model_final.pt')
self.assertEqual(final_payload['final/best_checkpoint_path'], 'checkpoints/vla_model_best.pt')
self.assertEqual(fake_swanlab.finish_calls, 1)
def test_run_training_skips_swanlab_when_disabled(self):
module = self._load_train_vla_module()
run_training = self._get_run_training(module)
cfg = self._make_cfg(use_swanlab=False)
agent = FakeAgent()
def fake_instantiate(config_node, **_kwargs):
if config_node is cfg.data:
return FakeDataset()
if config_node is cfg.agent:
return agent
raise AssertionError(f'unexpected instantiate config: {config_node!r}')
with tempfile.TemporaryDirectory() as tempdir:
previous_cwd = os.getcwd()
try:
os.chdir(tempdir)
with mock.patch.object(module, 'instantiate', side_effect=fake_instantiate), \
mock.patch.object(module, 'DataLoader', side_effect=self._loader_factory()), \
mock.patch.object(module, 'get_lr_schedule_with_warmup', return_value=FakeScheduler()), \
mock.patch.object(module, 'tqdm', side_effect=lambda iterable, **kwargs: FakeProgressBar(iterable)), \
mock.patch.object(module.torch, 'save', return_value=None), \
mock.patch.object(module.importlib, 'import_module', side_effect=AssertionError('swanlab import should not run')):
run_training(cfg)
finally:
os.chdir(previous_cwd)
def test_run_training_finishes_swanlab_when_exception_happens_after_init(self):
module = self._load_train_vla_module()
run_training = self._get_run_training(module)
cfg = self._make_cfg()
fake_swanlab = FakeSwanLab()
real_import_module = importlib.import_module
def fake_import_module(name, package=None):
if name == 'swanlab':
return fake_swanlab
return real_import_module(name, package)
with tempfile.TemporaryDirectory() as tempdir:
previous_cwd = os.getcwd()
try:
os.chdir(tempdir)
with mock.patch.object(module, 'instantiate', side_effect=RuntimeError('dataset boom')), \
mock.patch.object(module.importlib, 'import_module', side_effect=fake_import_module):
with self.assertRaisesRegex(RuntimeError, 'dataset boom'):
run_training(cfg)
finally:
os.chdir(previous_cwd)
self.assertEqual(fake_swanlab.finish_calls, 1)
def test_run_training_warns_and_continues_when_swanlab_log_and_finish_fail(self):
module = self._load_train_vla_module()
run_training = self._get_run_training(module)
cfg = self._make_cfg()
agent = FakeAgent()
fake_swanlab = FakeSwanLab(
log_errors=[RuntimeError('log backend hiccup')],
finish_error=RuntimeError('finish backend hiccup'),
)
real_import_module = importlib.import_module
def fake_instantiate(config_node, **_kwargs):
if config_node is cfg.data:
return FakeDataset()
if config_node is cfg.agent:
return agent
raise AssertionError(f'unexpected instantiate config: {config_node!r}')
def fake_import_module(name, package=None):
if name == 'swanlab':
return fake_swanlab
return real_import_module(name, package)
with tempfile.TemporaryDirectory() as tempdir:
previous_cwd = os.getcwd()
try:
os.chdir(tempdir)
with mock.patch.object(module, 'instantiate', side_effect=fake_instantiate), \
mock.patch.object(module, 'DataLoader', side_effect=self._loader_factory()), \
mock.patch.object(module, 'get_lr_schedule_with_warmup', return_value=FakeScheduler()), \
mock.patch.object(module, 'tqdm', side_effect=lambda iterable, **kwargs: FakeProgressBar(iterable)), \
mock.patch.object(module.torch, 'save', return_value=None), \
mock.patch.object(module.importlib, 'import_module', side_effect=fake_import_module), \
mock.patch.object(module.log, 'warning') as warning_mock:
run_training(cfg)
finally:
os.chdir(previous_cwd)
warning_messages = [call.args[0] for call in warning_mock.call_args_list]
self.assertTrue(any('SwanLab log failed' in message for message in warning_messages))
self.assertTrue(any('SwanLab finish failed' in message for message in warning_messages))
self.assertEqual(fake_swanlab.finish_calls, 1)
def test_run_training_resume_restores_best_rollout_baseline_from_best_checkpoint(self):
module = self._load_train_vla_module()
run_training = self._get_run_training(module)
cfg = self._make_cfg()
cfg.train.max_steps = 2
cfg.train.save_freq = 1
cfg.train.rollout_validate_on_checkpoint = True
fake_swanlab = FakeSwanLab()
fake_optimizer = FakeOptimizer(lr=cfg.train.lr)
fake_scheduler = FakeScheduler()
real_import_module = importlib.import_module
saved_paths = []
def fake_instantiate(config_node, **_kwargs):
if config_node is cfg.data:
return FakeDataset()
if config_node is cfg.agent:
return FakeAgent()
raise AssertionError(f'unexpected instantiate config: {config_node!r}')
def fake_import_module(name, package=None):
if name == 'swanlab':
return fake_swanlab
return real_import_module(name, package)
with tempfile.TemporaryDirectory() as tempdir:
previous_cwd = os.getcwd()
try:
os.chdir(tempdir)
checkpoint_dir = Path('checkpoints')
checkpoint_dir.mkdir()
resume_path = checkpoint_dir / 'vla_model_step_0.pt'
resume_path.write_bytes(b'resume')
best_path = checkpoint_dir / 'vla_model_best.pt'
best_path.write_bytes(b'best')
cfg.train.resume_ckpt = str(resume_path)
resume_checkpoint_state = {
'step': 0,
'model_state_dict': FakeAgent().state_dict(),
'optimizer_state_dict': {},
'scheduler_state_dict': {},
'loss': 0.5,
'val_loss': 0.25,
}
best_checkpoint_state = {
'step': 0,
'model_state_dict': FakeAgent().state_dict(),
'optimizer_state_dict': {},
'scheduler_state_dict': {},
'loss': 0.5,
'val_loss': 0.25,
'rollout_avg_reward': 5.0,
}
def fake_torch_load(path, map_location=None):
del map_location
path = Path(path)
if path == resume_path:
return resume_checkpoint_state
if path == best_path:
return best_checkpoint_state
raise AssertionError(f'unexpected load path: {path}')
def fake_torch_save(payload, path):
saved_paths.append(str(path))
return None
with mock.patch.object(module, 'instantiate', side_effect=fake_instantiate), \
mock.patch.object(module, 'DataLoader', side_effect=self._loader_factory()), \
mock.patch.object(module, 'build_training_optimizer', return_value=fake_optimizer), \
mock.patch.object(module, 'get_lr_schedule_with_warmup', return_value=fake_scheduler), \
mock.patch.object(module, 'tqdm', side_effect=lambda iterable, **kwargs: FakeProgressBar(iterable)), \
mock.patch.object(module.torch, 'save', side_effect=fake_torch_save), \
mock.patch.object(module.torch, 'load', side_effect=fake_torch_load), \
mock.patch.object(module.importlib, 'import_module', side_effect=fake_import_module), \
mock.patch('roboimi.demos.vla_scripts.eval_vla._run_eval', return_value={'avg_reward': 3.0}):
run_training(cfg)
finally:
os.chdir(previous_cwd)
final_payload, final_step = fake_swanlab.log_calls[-1]
self.assertEqual(final_step, cfg.train.max_steps)
self.assertEqual(final_payload['final/best_checkpoint_path'], 'checkpoints/vla_model_best.pt')
self.assertNotIn('checkpoints/vla_model_best.pt', saved_paths)
def test_run_training_resume_ignores_best_checkpoint_without_rollout_metric(self):
module = self._load_train_vla_module()
run_training = self._get_run_training(module)
cfg = self._make_cfg()
cfg.train.max_steps = 1
fake_swanlab = FakeSwanLab()
fake_optimizer = FakeOptimizer(lr=cfg.train.lr)
fake_scheduler = FakeScheduler()
real_import_module = importlib.import_module
def fake_instantiate(config_node, **_kwargs):
if config_node is cfg.data:
return FakeDataset()
if config_node is cfg.agent:
return FakeAgent()
raise AssertionError(f'unexpected instantiate config: {config_node!r}')
def fake_import_module(name, package=None):
if name == 'swanlab':
return fake_swanlab
return real_import_module(name, package)
with tempfile.TemporaryDirectory() as tempdir:
previous_cwd = os.getcwd()
try:
os.chdir(tempdir)
checkpoint_dir = Path('checkpoints')
checkpoint_dir.mkdir()
resume_path = checkpoint_dir / 'vla_model_step_0.pt'
resume_path.write_bytes(b'resume')
best_path = checkpoint_dir / 'vla_model_best.pt'
best_path.write_bytes(b'stale')
cfg.train.resume_ckpt = str(resume_path)
resume_checkpoint_state = {
'step': 0,
'model_state_dict': FakeAgent().state_dict(),
'optimizer_state_dict': {},
'scheduler_state_dict': {},
'loss': 0.5,
'val_loss': 0.25,
}
stale_best_checkpoint_state = {
'step': 0,
'model_state_dict': FakeAgent().state_dict(),
'optimizer_state_dict': {},
'scheduler_state_dict': {},
'loss': 0.4,
'val_loss': 0.2,
}
def fake_torch_load(path, map_location=None):
del map_location
path = Path(path)
if path == resume_path:
return resume_checkpoint_state
if path == best_path:
return stale_best_checkpoint_state
raise AssertionError(f'unexpected load path: {path}')
with mock.patch.object(module, 'instantiate', side_effect=fake_instantiate), \
mock.patch.object(module, 'DataLoader', side_effect=self._loader_factory()), \
mock.patch.object(module, 'build_training_optimizer', return_value=fake_optimizer), \
mock.patch.object(module, 'get_lr_schedule_with_warmup', return_value=fake_scheduler), \
mock.patch.object(module, 'tqdm', side_effect=lambda iterable, **kwargs: FakeProgressBar(iterable)), \
mock.patch.object(module.torch, 'save', return_value=None), \
mock.patch.object(module.torch, 'load', side_effect=fake_torch_load), \
mock.patch.object(module.importlib, 'import_module', side_effect=fake_import_module):
run_training(cfg)
finally:
os.chdir(previous_cwd)
final_payload, final_step = fake_swanlab.log_calls[-1]
self.assertEqual(final_step, cfg.train.max_steps)
self.assertEqual(final_payload['final/best_checkpoint_path'], 'checkpoints/vla_model_step_0.pt')
def test_run_training_ignores_stale_best_checkpoint_file_on_fresh_non_resume_run(self):
module = self._load_train_vla_module()
run_training = self._get_run_training(module)
cfg = self._make_cfg()
cfg.train.max_steps = 1
fake_swanlab = FakeSwanLab()
real_import_module = importlib.import_module
def fake_instantiate(config_node, **_kwargs):
if config_node is cfg.data:
return FakeDataset()
if config_node is cfg.agent:
return FakeAgent()
raise AssertionError(f'unexpected instantiate config: {config_node!r}')
def fake_import_module(name, package=None):
if name == 'swanlab':
return fake_swanlab
return real_import_module(name, package)
with tempfile.TemporaryDirectory() as tempdir:
previous_cwd = os.getcwd()
try:
os.chdir(tempdir)
checkpoint_dir = Path('checkpoints')
checkpoint_dir.mkdir()
(checkpoint_dir / 'vla_model_best.pt').write_bytes(b'stale-best')
with mock.patch.object(module, 'instantiate', side_effect=fake_instantiate), \
mock.patch.object(module, 'DataLoader', side_effect=self._loader_factory()), \
mock.patch.object(module, 'get_lr_schedule_with_warmup', return_value=FakeScheduler()), \
mock.patch.object(module, 'tqdm', side_effect=lambda iterable, **kwargs: FakeProgressBar(iterable)), \
mock.patch.object(module.torch, 'save', return_value=None), \
mock.patch.object(module.importlib, 'import_module', side_effect=fake_import_module):
run_training(cfg)
finally:
os.chdir(previous_cwd)
final_payload, final_step = fake_swanlab.log_calls[-1]
self.assertEqual(final_step, cfg.train.max_steps)
self.assertEqual(final_payload['final/best_checkpoint_path'], '')
def test_run_training_fails_fast_when_swanlab_import_is_unavailable(self):
module = self._load_train_vla_module()
run_training = self._get_run_training(module)
cfg = self._make_cfg()
real_import_module = importlib.import_module
def fake_import_module(name, package=None):
if name == 'swanlab':
raise ImportError('missing swanlab')
return real_import_module(name, package)
with mock.patch.object(module, 'instantiate', side_effect=AssertionError('instantiate should not run')), \
mock.patch.object(module.importlib, 'import_module', side_effect=fake_import_module):
with self.assertRaisesRegex(RuntimeError, 'SwanLab'):
run_training(cfg)
def test_run_training_fails_fast_when_swanlab_init_fails(self):
module = self._load_train_vla_module()
run_training = self._get_run_training(module)
cfg = self._make_cfg()
fake_swanlab = FakeSwanLab(init_error=RuntimeError('not logged in'))
real_import_module = importlib.import_module
def fake_import_module(name, package=None):
if name == 'swanlab':
return fake_swanlab
return real_import_module(name, package)
with mock.patch.object(module, 'instantiate', side_effect=AssertionError('instantiate should not run')), \
mock.patch.object(module.importlib, 'import_module', side_effect=fake_import_module):
with self.assertRaisesRegex(RuntimeError, 'not logged in'):
run_training(cfg)
self.assertEqual(fake_swanlab.finish_calls, 0)
if __name__ == '__main__':
unittest.main()