feat: add rollout trajectory image artifacts and swanlab logging
This commit is contained in:
@@ -115,13 +115,15 @@ class FakeAgent(nn.Module):
|
||||
|
||||
|
||||
class FakeSwanLab:
|
||||
def __init__(self, init_error=None, log_errors=None, finish_error=None):
|
||||
def __init__(self, init_error=None, log_errors=None, finish_error=None, image_errors=None):
|
||||
self.init_error = init_error
|
||||
self.log_errors = list(log_errors or [])
|
||||
self.finish_error = finish_error
|
||||
self.image_errors = list(image_errors or [])
|
||||
self.init_calls = []
|
||||
self.log_calls = []
|
||||
self.finish_calls = 0
|
||||
self.image_calls = []
|
||||
|
||||
def init(self, project, experiment_name=None, config=None):
|
||||
self.init_calls.append({
|
||||
@@ -138,6 +140,18 @@ class FakeSwanLab:
|
||||
if self.log_errors:
|
||||
raise self.log_errors.pop(0)
|
||||
|
||||
def Image(self, path, caption=None):
|
||||
self.image_calls.append({
|
||||
'path': path,
|
||||
'caption': caption,
|
||||
})
|
||||
if self.image_errors:
|
||||
raise self.image_errors.pop(0)
|
||||
return {
|
||||
'path': path,
|
||||
'caption': caption,
|
||||
}
|
||||
|
||||
def finish(self):
|
||||
self.finish_calls += 1
|
||||
if self.finish_error is not None:
|
||||
@@ -149,6 +163,119 @@ class TrainVLASwanLabLoggingTest(unittest.TestCase):
|
||||
config_text = _CONFIG_PATH.read_text(encoding='utf-8')
|
||||
self.assertIn('use_swanlab: false', config_text)
|
||||
|
||||
def test_log_rollout_trajectory_images_to_swanlab_uploads_episode_artifacts(self):
|
||||
module = self._load_train_vla_module()
|
||||
fake_swanlab = FakeSwanLab()
|
||||
|
||||
module._log_rollout_trajectory_images_to_swanlab(
|
||||
fake_swanlab,
|
||||
{
|
||||
'episodes': [
|
||||
{
|
||||
'episode_index': 0,
|
||||
'artifact_paths': {'trajectory_image': 'artifacts/episode_0_front.png'},
|
||||
},
|
||||
{
|
||||
'episode_index': 3,
|
||||
'artifact_paths': {'trajectory_image': 'artifacts/episode_3_front.png'},
|
||||
},
|
||||
{
|
||||
'episode_index': 7,
|
||||
'artifact_paths': {'trajectory_image': None},
|
||||
},
|
||||
{
|
||||
'episode_index': 8,
|
||||
'artifact_paths': {},
|
||||
},
|
||||
],
|
||||
},
|
||||
step=12,
|
||||
context_label='epoch 1 rollout',
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
fake_swanlab.image_calls,
|
||||
[
|
||||
{
|
||||
'path': 'artifacts/episode_0_front.png',
|
||||
'caption': 'epoch 1 rollout trajectory image - episode 0 (front)',
|
||||
},
|
||||
{
|
||||
'path': 'artifacts/episode_3_front.png',
|
||||
'caption': 'epoch 1 rollout trajectory image - episode 3 (front)',
|
||||
},
|
||||
],
|
||||
)
|
||||
self.assertIn(
|
||||
(
|
||||
{
|
||||
'rollout/trajectory_image_episode_0': {
|
||||
'path': 'artifacts/episode_0_front.png',
|
||||
'caption': 'epoch 1 rollout trajectory image - episode 0 (front)',
|
||||
},
|
||||
'rollout/trajectory_image_episode_3': {
|
||||
'path': 'artifacts/episode_3_front.png',
|
||||
'caption': 'epoch 1 rollout trajectory image - episode 3 (front)',
|
||||
},
|
||||
},
|
||||
12,
|
||||
),
|
||||
fake_swanlab.log_calls,
|
||||
)
|
||||
|
||||
def test_log_rollout_trajectory_images_to_swanlab_is_best_effort(self):
|
||||
module = self._load_train_vla_module()
|
||||
fake_swanlab = FakeSwanLab(image_errors=[RuntimeError('decode failed')])
|
||||
|
||||
with mock.patch.object(module.log, 'warning') as warning_mock:
|
||||
module._log_rollout_trajectory_images_to_swanlab(
|
||||
fake_swanlab,
|
||||
{
|
||||
'episodes': [
|
||||
{
|
||||
'episode_index': 0,
|
||||
'artifact_paths': {'trajectory_image': 'artifacts/bad_episode.png'},
|
||||
},
|
||||
{
|
||||
'episode_index': 1,
|
||||
'artifact_paths': {'trajectory_image': 'artifacts/good_episode.png'},
|
||||
},
|
||||
],
|
||||
},
|
||||
step=7,
|
||||
context_label='checkpoint rollout',
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
fake_swanlab.image_calls,
|
||||
[
|
||||
{
|
||||
'path': 'artifacts/bad_episode.png',
|
||||
'caption': 'checkpoint rollout trajectory image - episode 0 (front)',
|
||||
},
|
||||
{
|
||||
'path': 'artifacts/good_episode.png',
|
||||
'caption': 'checkpoint rollout trajectory image - episode 1 (front)',
|
||||
},
|
||||
],
|
||||
)
|
||||
self.assertIn(
|
||||
(
|
||||
{
|
||||
'rollout/trajectory_image_episode_1': {
|
||||
'path': 'artifacts/good_episode.png',
|
||||
'caption': 'checkpoint rollout trajectory image - episode 1 (front)',
|
||||
},
|
||||
},
|
||||
7,
|
||||
),
|
||||
fake_swanlab.log_calls,
|
||||
)
|
||||
warning_messages = [call.args[0] for call in warning_mock.call_args_list]
|
||||
self.assertTrue(
|
||||
any('SwanLab rollout trajectory image upload prep failed' in message for message in warning_messages)
|
||||
)
|
||||
|
||||
def _load_train_vla_module(self):
|
||||
hydra_module = types.ModuleType('hydra')
|
||||
hydra_utils_module = types.ModuleType('hydra.utils')
|
||||
@@ -356,8 +483,8 @@ class TrainVLASwanLabLoggingTest(unittest.TestCase):
|
||||
|
||||
final_payload, final_step = fake_swanlab.log_calls[-1]
|
||||
self.assertEqual(final_step, cfg.train.max_steps)
|
||||
self.assertEqual(final_payload['final/checkpoint_path'], 'checkpoints/vla_model_final.pt')
|
||||
self.assertEqual(final_payload['final/best_checkpoint_path'], 'checkpoints/vla_model_best.pt')
|
||||
self.assertTrue(final_payload['final/checkpoint_path'].endswith('checkpoints/vla_model_final.pt'))
|
||||
self.assertTrue(final_payload['final/best_checkpoint_path'].endswith('checkpoints/vla_model_best.pt'))
|
||||
self.assertEqual(fake_swanlab.finish_calls, 1)
|
||||
|
||||
def test_run_training_skips_swanlab_when_disabled(self):
|
||||
@@ -512,10 +639,10 @@ class TrainVLASwanLabLoggingTest(unittest.TestCase):
|
||||
|
||||
def fake_torch_load(path, map_location=None):
|
||||
del map_location
|
||||
path = Path(path)
|
||||
if path == resume_path:
|
||||
path = Path(path).resolve()
|
||||
if path == resume_path.resolve():
|
||||
return resume_checkpoint_state
|
||||
if path == best_path:
|
||||
if path == best_path.resolve():
|
||||
return best_checkpoint_state
|
||||
raise AssertionError(f'unexpected load path: {path}')
|
||||
|
||||
@@ -538,8 +665,8 @@ class TrainVLASwanLabLoggingTest(unittest.TestCase):
|
||||
|
||||
final_payload, final_step = fake_swanlab.log_calls[-1]
|
||||
self.assertEqual(final_step, cfg.train.max_steps)
|
||||
self.assertEqual(final_payload['final/best_checkpoint_path'], 'checkpoints/vla_model_best.pt')
|
||||
self.assertNotIn('checkpoints/vla_model_best.pt', saved_paths)
|
||||
self.assertTrue(final_payload['final/best_checkpoint_path'].endswith('checkpoints/vla_model_best.pt'))
|
||||
self.assertFalse(any(path.endswith('checkpoints/vla_model_best.pt') for path in saved_paths))
|
||||
|
||||
def test_run_training_resume_ignores_best_checkpoint_without_rollout_metric(self):
|
||||
module = self._load_train_vla_module()
|
||||
@@ -594,10 +721,10 @@ class TrainVLASwanLabLoggingTest(unittest.TestCase):
|
||||
|
||||
def fake_torch_load(path, map_location=None):
|
||||
del map_location
|
||||
path = Path(path)
|
||||
if path == resume_path:
|
||||
path = Path(path).resolve()
|
||||
if path == resume_path.resolve():
|
||||
return resume_checkpoint_state
|
||||
if path == best_path:
|
||||
if path == best_path.resolve():
|
||||
return stale_best_checkpoint_state
|
||||
raise AssertionError(f'unexpected load path: {path}')
|
||||
|
||||
|
||||
Reference in New Issue
Block a user