feat: add rollout trajectory image artifacts and swanlab logging

This commit is contained in:
Logic
2026-04-03 09:39:16 +08:00
parent 48f0eb8dd0
commit 0586a6e6c7
8 changed files with 626 additions and 21 deletions

View File

@@ -115,13 +115,15 @@ class FakeAgent(nn.Module):
class FakeSwanLab:
def __init__(self, init_error=None, log_errors=None, finish_error=None):
def __init__(self, init_error=None, log_errors=None, finish_error=None, image_errors=None):
self.init_error = init_error
self.log_errors = list(log_errors or [])
self.finish_error = finish_error
self.image_errors = list(image_errors or [])
self.init_calls = []
self.log_calls = []
self.finish_calls = 0
self.image_calls = []
def init(self, project, experiment_name=None, config=None):
self.init_calls.append({
@@ -138,6 +140,18 @@ class FakeSwanLab:
if self.log_errors:
raise self.log_errors.pop(0)
def Image(self, path, caption=None):
self.image_calls.append({
'path': path,
'caption': caption,
})
if self.image_errors:
raise self.image_errors.pop(0)
return {
'path': path,
'caption': caption,
}
def finish(self):
self.finish_calls += 1
if self.finish_error is not None:
@@ -149,6 +163,119 @@ class TrainVLASwanLabLoggingTest(unittest.TestCase):
config_text = _CONFIG_PATH.read_text(encoding='utf-8')
self.assertIn('use_swanlab: false', config_text)
def test_log_rollout_trajectory_images_to_swanlab_uploads_episode_artifacts(self):
module = self._load_train_vla_module()
fake_swanlab = FakeSwanLab()
module._log_rollout_trajectory_images_to_swanlab(
fake_swanlab,
{
'episodes': [
{
'episode_index': 0,
'artifact_paths': {'trajectory_image': 'artifacts/episode_0_front.png'},
},
{
'episode_index': 3,
'artifact_paths': {'trajectory_image': 'artifacts/episode_3_front.png'},
},
{
'episode_index': 7,
'artifact_paths': {'trajectory_image': None},
},
{
'episode_index': 8,
'artifact_paths': {},
},
],
},
step=12,
context_label='epoch 1 rollout',
)
self.assertEqual(
fake_swanlab.image_calls,
[
{
'path': 'artifacts/episode_0_front.png',
'caption': 'epoch 1 rollout trajectory image - episode 0 (front)',
},
{
'path': 'artifacts/episode_3_front.png',
'caption': 'epoch 1 rollout trajectory image - episode 3 (front)',
},
],
)
self.assertIn(
(
{
'rollout/trajectory_image_episode_0': {
'path': 'artifacts/episode_0_front.png',
'caption': 'epoch 1 rollout trajectory image - episode 0 (front)',
},
'rollout/trajectory_image_episode_3': {
'path': 'artifacts/episode_3_front.png',
'caption': 'epoch 1 rollout trajectory image - episode 3 (front)',
},
},
12,
),
fake_swanlab.log_calls,
)
def test_log_rollout_trajectory_images_to_swanlab_is_best_effort(self):
module = self._load_train_vla_module()
fake_swanlab = FakeSwanLab(image_errors=[RuntimeError('decode failed')])
with mock.patch.object(module.log, 'warning') as warning_mock:
module._log_rollout_trajectory_images_to_swanlab(
fake_swanlab,
{
'episodes': [
{
'episode_index': 0,
'artifact_paths': {'trajectory_image': 'artifacts/bad_episode.png'},
},
{
'episode_index': 1,
'artifact_paths': {'trajectory_image': 'artifacts/good_episode.png'},
},
],
},
step=7,
context_label='checkpoint rollout',
)
self.assertEqual(
fake_swanlab.image_calls,
[
{
'path': 'artifacts/bad_episode.png',
'caption': 'checkpoint rollout trajectory image - episode 0 (front)',
},
{
'path': 'artifacts/good_episode.png',
'caption': 'checkpoint rollout trajectory image - episode 1 (front)',
},
],
)
self.assertIn(
(
{
'rollout/trajectory_image_episode_1': {
'path': 'artifacts/good_episode.png',
'caption': 'checkpoint rollout trajectory image - episode 1 (front)',
},
},
7,
),
fake_swanlab.log_calls,
)
warning_messages = [call.args[0] for call in warning_mock.call_args_list]
self.assertTrue(
any('SwanLab rollout trajectory image upload prep failed' in message for message in warning_messages)
)
def _load_train_vla_module(self):
hydra_module = types.ModuleType('hydra')
hydra_utils_module = types.ModuleType('hydra.utils')
@@ -356,8 +483,8 @@ class TrainVLASwanLabLoggingTest(unittest.TestCase):
final_payload, final_step = fake_swanlab.log_calls[-1]
self.assertEqual(final_step, cfg.train.max_steps)
self.assertEqual(final_payload['final/checkpoint_path'], 'checkpoints/vla_model_final.pt')
self.assertEqual(final_payload['final/best_checkpoint_path'], 'checkpoints/vla_model_best.pt')
self.assertTrue(final_payload['final/checkpoint_path'].endswith('checkpoints/vla_model_final.pt'))
self.assertTrue(final_payload['final/best_checkpoint_path'].endswith('checkpoints/vla_model_best.pt'))
self.assertEqual(fake_swanlab.finish_calls, 1)
def test_run_training_skips_swanlab_when_disabled(self):
@@ -512,10 +639,10 @@ class TrainVLASwanLabLoggingTest(unittest.TestCase):
def fake_torch_load(path, map_location=None):
del map_location
path = Path(path)
if path == resume_path:
path = Path(path).resolve()
if path == resume_path.resolve():
return resume_checkpoint_state
if path == best_path:
if path == best_path.resolve():
return best_checkpoint_state
raise AssertionError(f'unexpected load path: {path}')
@@ -538,8 +665,8 @@ class TrainVLASwanLabLoggingTest(unittest.TestCase):
final_payload, final_step = fake_swanlab.log_calls[-1]
self.assertEqual(final_step, cfg.train.max_steps)
self.assertEqual(final_payload['final/best_checkpoint_path'], 'checkpoints/vla_model_best.pt')
self.assertNotIn('checkpoints/vla_model_best.pt', saved_paths)
self.assertTrue(final_payload['final/best_checkpoint_path'].endswith('checkpoints/vla_model_best.pt'))
self.assertFalse(any(path.endswith('checkpoints/vla_model_best.pt') for path in saved_paths))
def test_run_training_resume_ignores_best_checkpoint_without_rollout_metric(self):
module = self._load_train_vla_module()
@@ -594,10 +721,10 @@ class TrainVLASwanLabLoggingTest(unittest.TestCase):
def fake_torch_load(path, map_location=None):
del map_location
path = Path(path)
if path == resume_path:
path = Path(path).resolve()
if path == resume_path.resolve():
return resume_checkpoint_state
if path == best_path:
if path == best_path.resolve():
return stale_best_checkpoint_state
raise AssertionError(f'unexpected load path: {path}')