feat: add rollout trajectory image artifacts and swanlab logging
This commit is contained in:
@@ -102,8 +102,10 @@ class EvalVLARolloutArtifactsTest(unittest.TestCase):
|
||||
self.assertIn('artifact_dir', eval_cfg)
|
||||
self.assertFalse(eval_cfg.save_summary_json)
|
||||
self.assertFalse(eval_cfg.save_trajectory_npz)
|
||||
self.assertFalse(eval_cfg.save_trajectory_image)
|
||||
self.assertFalse(eval_cfg.record_video)
|
||||
self.assertIsNone(eval_cfg.artifact_dir)
|
||||
self.assertIsNone(eval_cfg.trajectory_image_camera_name)
|
||||
self.assertIsNone(eval_cfg.video_camera_name)
|
||||
self.assertEqual(eval_cfg.video_fps, 30)
|
||||
|
||||
@@ -133,6 +135,8 @@ class EvalVLARolloutArtifactsTest(unittest.TestCase):
|
||||
'artifact_dir': tmpdir,
|
||||
'save_summary_json': True,
|
||||
'save_trajectory_npz': True,
|
||||
'save_trajectory_image': True,
|
||||
'trajectory_image_camera_name': 'front',
|
||||
'record_video': True,
|
||||
'video_camera_name': 'front',
|
||||
'video_fps': 12,
|
||||
@@ -176,12 +180,14 @@ class EvalVLARolloutArtifactsTest(unittest.TestCase):
|
||||
trajectory_path = Path(artifacts['trajectory_npz'])
|
||||
summary_path = Path(artifacts['summary_json'])
|
||||
video_path = Path(artifacts['video_mp4'])
|
||||
trajectory_image_path = Path(summary['episodes'][0]['artifact_paths']['trajectory_image'])
|
||||
|
||||
self.assertEqual(Path(artifacts['output_dir']), Path(tmpdir))
|
||||
self.assertEqual(artifacts['video_camera_name'], 'front')
|
||||
self.assertTrue(trajectory_path.exists())
|
||||
self.assertTrue(summary_path.exists())
|
||||
self.assertTrue(video_path.exists())
|
||||
self.assertTrue(trajectory_image_path.exists())
|
||||
|
||||
rollout_npz = np.load(trajectory_path)
|
||||
np.testing.assert_array_equal(rollout_npz['episode_index'], np.array([0, 0]))
|
||||
@@ -218,11 +224,121 @@ class EvalVLARolloutArtifactsTest(unittest.TestCase):
|
||||
saved_summary = json.load(fh)
|
||||
self.assertEqual(saved_summary['artifacts']['trajectory_npz'], str(trajectory_path))
|
||||
self.assertEqual(saved_summary['artifacts']['video_mp4'], str(video_path))
|
||||
self.assertEqual(
|
||||
saved_summary['episodes'][0]['artifact_paths']['trajectory_image'],
|
||||
str(trajectory_image_path),
|
||||
)
|
||||
self.assertEqual(saved_summary['episode_rewards'], [3.0])
|
||||
self.assertAlmostEqual(summary['avg_reward'], 3.0)
|
||||
self.assertIn('avg_obs_read_time_ms', summary)
|
||||
self.assertIn('avg_env_step_time_ms', summary)
|
||||
|
||||
def test_run_eval_exports_front_trajectory_images_without_video_dependency(self):
    """Trajectory images are written per episode while video recording stays off.

    ``eval_vla._run_eval`` runs against a fake agent and environment. The image
    writer is swapped for a stub that records its arguments and drops a
    placeholder file; ``_open_video_writer`` is patched so that any call to it
    is detected.
    """
    actions = [
        np.arange(16, dtype=np.float32) + shift
        for shift in (0.0, 10.0, 100.0, 110.0)
    ]
    fake_agent = _FakeAgent(actions)
    fake_env = _FakeEnv()

    # Everything below stays inside the temporary directory scope because the
    # assertions check that the written files still exist.
    with tempfile.TemporaryDirectory() as tmpdir:
        eval_settings = {
            'ckpt_path': 'checkpoints/vla_model_best.pt',
            'num_episodes': 2,
            'max_timesteps': 2,
            'device': 'cpu',
            'task_name': 'sim_transfer',
            'camera_names': ['top', 'front'],
            'use_smoothing': True,
            'smooth_alpha': 0.5,
            'verbose_action': False,
            'headless': True,
            'artifact_dir': tmpdir,
            'save_trajectory_image': True,
            'record_video': False,
        }
        cfg = OmegaConf.create({'agent': {}, 'eval': eval_settings})

        trajectory_image_calls = []

        def fake_save_rollout_trajectory_image(
            env,
            output_path,
            raw_actions,
            camera_name,
            *,
            line_radius=0.004,
            max_markers=1500,
        ):
            """Record the call and write a placeholder file instead of rendering."""
            del env, line_radius, max_markers
            # Copy each action so later mutation by the caller cannot alter the record.
            copied_actions = [np.array(action, copy=True) for action in raw_actions]
            trajectory_image_calls.append(
                {
                    'output_path': output_path,
                    'camera_name': camera_name,
                    'raw_actions': copied_actions,
                }
            )
            if output_path is None:
                return None
            target = Path(output_path)
            target.parent.mkdir(parents=True, exist_ok=True)
            target.write_bytes(b'fake-png')
            return str(target)

        def patch_eval(attribute, **patch_kwargs):
            """Build a patcher for one attribute of the ``eval_vla`` module."""
            return mock.patch.object(eval_vla, attribute, **patch_kwargs)

        with patch_eval('load_checkpoint', return_value=(fake_agent, None)), \
                patch_eval('make_sim_env', return_value=fake_env), \
                patch_eval(
                    'sample_transfer_pose',
                    return_value=np.array([0.1, 0.2, 0.3], dtype=np.float32),
                ), \
                patch_eval('tqdm', side_effect=lambda iterable, **kwargs: iterable), \
                patch_eval(
                    '_save_rollout_trajectory_image',
                    side_effect=fake_save_rollout_trajectory_image,
                ) as save_trajectory_image_mock, \
                patch_eval('_open_video_writer') as open_video_writer_mock:
            summary = eval_vla._run_eval(cfg)

        # One image per episode, and the video writer is never touched.
        self.assertEqual(save_trajectory_image_mock.call_count, 2)
        open_video_writer_mock.assert_not_called()
        artifacts = summary['artifacts']
        self.assertIsNone(artifacts['video_mp4'])
        self.assertEqual(artifacts['trajectory_image_camera_name'], 'front')
        recorded_cameras = [call['camera_name'] for call in trajectory_image_calls]
        self.assertEqual(recorded_cameras, ['front', 'front'])

        episodes = summary['episodes']
        first_episode_path = Path(episodes[0]['artifact_paths']['trajectory_image'])
        second_episode_path = Path(episodes[1]['artifact_paths']['trajectory_image'])
        self.assertTrue(first_episode_path.exists())
        self.assertTrue(second_episode_path.exists())
        self.assertNotEqual(first_episode_path, second_episode_path)
        self.assertEqual(first_episode_path.parent, Path(tmpdir))
        self.assertEqual(second_episode_path.parent, Path(tmpdir))

        # Two timesteps per episode: actions 0-1 belong to episode 0, 2-3 to episode 1.
        for flat_index, expected_action in enumerate(actions):
            episode_index, step_index = divmod(flat_index, 2)
            np.testing.assert_array_equal(
                trajectory_image_calls[episode_index]['raw_actions'][step_index],
                expected_action,
            )
|
||||
|
||||
|
||||
# Run this module's test cases when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()
|
||||
|
||||
@@ -234,7 +234,28 @@ class TrainVLARolloutValidationTest(unittest.TestCase):
|
||||
}
|
||||
)
|
||||
agent = _FakeAgent()
|
||||
rollout_mock = mock.Mock(side_effect=[{'avg_reward': 2.0}, {'avg_reward': 1.0}])
|
||||
rollout_mock = mock.Mock(
|
||||
side_effect=[
|
||||
{
|
||||
'avg_reward': 2.0,
|
||||
'episodes': [
|
||||
{
|
||||
'episode_index': 0,
|
||||
'artifact_paths': {'trajectory_image': 'artifacts/epoch_49_front.png'},
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
'avg_reward': 1.0,
|
||||
'episodes': [
|
||||
{
|
||||
'episode_index': 0,
|
||||
'artifact_paths': {'trajectory_image': 'artifacts/epoch_99_front.png'},
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
)
|
||||
swanlab_log_mock = mock.Mock()
|
||||
saved_checkpoints = []
|
||||
|
||||
@@ -281,17 +302,22 @@ class TrainVLARolloutValidationTest(unittest.TestCase):
|
||||
self.assertEqual(rollout_mock.call_count, 2)
|
||||
first_rollout_cfg = rollout_mock.call_args_list[0].args[0]
|
||||
second_rollout_cfg = rollout_mock.call_args_list[1].args[0]
|
||||
self.assertEqual(first_rollout_cfg.eval.ckpt_path, 'checkpoints/vla_model_step_49.pt')
|
||||
self.assertEqual(second_rollout_cfg.eval.ckpt_path, 'checkpoints/vla_model_step_99.pt')
|
||||
self.assertTrue(first_rollout_cfg.eval.ckpt_path.endswith('checkpoints/vla_model_step_49.pt'))
|
||||
self.assertTrue(second_rollout_cfg.eval.ckpt_path.endswith('checkpoints/vla_model_step_99.pt'))
|
||||
self.assertEqual(first_rollout_cfg.eval.num_episodes, 3)
|
||||
self.assertTrue(first_rollout_cfg.eval.headless)
|
||||
self.assertEqual(first_rollout_cfg.eval.device, 'cpu')
|
||||
self.assertFalse(first_rollout_cfg.eval.verbose_action)
|
||||
self.assertFalse(first_rollout_cfg.eval.record_video)
|
||||
self.assertTrue(first_rollout_cfg.eval.save_trajectory_image)
|
||||
self.assertEqual(first_rollout_cfg.eval.trajectory_image_camera_name, 'front')
|
||||
self.assertEqual(cfg.eval.ckpt_path, 'unused.pt')
|
||||
self.assertEqual(cfg.eval.num_episodes, 99)
|
||||
self.assertFalse(cfg.eval.headless)
|
||||
self.assertEqual(cfg.eval.device, 'cpu')
|
||||
self.assertFalse(cfg.eval.verbose_action)
|
||||
self.assertNotIn('save_trajectory_image', cfg.eval)
|
||||
self.assertNotIn('trajectory_image_camera_name', cfg.eval)
|
||||
|
||||
rollout_reward_logs = [
|
||||
call.args[1]['rollout/avg_reward']
|
||||
@@ -769,10 +795,8 @@ class TrainVLARolloutValidationTest(unittest.TestCase):
|
||||
'dataset_len': 1,
|
||||
},
|
||||
)
|
||||
self.assertEqual(
|
||||
[path for path, _payload in saved_checkpoints],
|
||||
['checkpoints/vla_model_final.pt'],
|
||||
)
|
||||
self.assertEqual(len(saved_checkpoints), 1)
|
||||
self.assertTrue(saved_checkpoints[0][0].endswith('checkpoints/vla_model_final.pt'))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@@ -115,13 +115,15 @@ class FakeAgent(nn.Module):
|
||||
|
||||
|
||||
class FakeSwanLab:
|
||||
def __init__(self, init_error=None, log_errors=None, finish_error=None):
|
||||
def __init__(self, init_error=None, log_errors=None, finish_error=None, image_errors=None):
|
||||
self.init_error = init_error
|
||||
self.log_errors = list(log_errors or [])
|
||||
self.finish_error = finish_error
|
||||
self.image_errors = list(image_errors or [])
|
||||
self.init_calls = []
|
||||
self.log_calls = []
|
||||
self.finish_calls = 0
|
||||
self.image_calls = []
|
||||
|
||||
def init(self, project, experiment_name=None, config=None):
|
||||
self.init_calls.append({
|
||||
@@ -138,6 +140,18 @@ class FakeSwanLab:
|
||||
if self.log_errors:
|
||||
raise self.log_errors.pop(0)
|
||||
|
||||
def Image(self, path, caption=None):
    """Record an image request and return a plain-dict stand-in for the image.

    The request is appended to ``self.image_calls`` first, so a failing
    request is still visible to the test. If ``self.image_errors`` holds
    queued exceptions, the oldest one is removed and raised instead of
    returning.
    """
    request = {'path': path, 'caption': caption}
    self.image_calls.append(request)
    if not self.image_errors:
        # Return a separate dict so the caller never aliases the recorded entry.
        return dict(request)
    raise self.image_errors.pop(0)
|
||||
|
||||
def finish(self):
|
||||
self.finish_calls += 1
|
||||
if self.finish_error is not None:
|
||||
@@ -149,6 +163,119 @@ class TrainVLASwanLabLoggingTest(unittest.TestCase):
|
||||
config_text = _CONFIG_PATH.read_text(encoding='utf-8')
|
||||
self.assertIn('use_swanlab: false', config_text)
|
||||
|
||||
def test_log_rollout_trajectory_images_to_swanlab_uploads_episode_artifacts(self):
    """Episodes with a trajectory image path are wrapped in ``Image`` and logged.

    Episodes whose ``trajectory_image`` entry is ``None`` or absent must not
    produce an ``Image`` call. The remaining images are expected in a single
    log payload at the given step.
    """
    module = self._load_train_vla_module()
    fake_swanlab = FakeSwanLab()

    rollout_summary = {
        'episodes': [
            {
                'episode_index': 0,
                'artifact_paths': {'trajectory_image': 'artifacts/episode_0_front.png'},
            },
            {
                'episode_index': 3,
                'artifact_paths': {'trajectory_image': 'artifacts/episode_3_front.png'},
            },
            {
                'episode_index': 7,
                'artifact_paths': {'trajectory_image': None},
            },
            {
                'episode_index': 8,
                'artifact_paths': {},
            },
        ],
    }
    module._log_rollout_trajectory_images_to_swanlab(
        fake_swanlab,
        rollout_summary,
        step=12,
        context_label='epoch 1 rollout',
    )

    # FakeSwanLab.Image records and returns dicts of the same shape, so one
    # expected dict per episode serves both assertions.
    episode_0_image = {
        'path': 'artifacts/episode_0_front.png',
        'caption': 'epoch 1 rollout trajectory image - episode 0 (front)',
    }
    episode_3_image = {
        'path': 'artifacts/episode_3_front.png',
        'caption': 'epoch 1 rollout trajectory image - episode 3 (front)',
    }
    self.assertEqual(fake_swanlab.image_calls, [episode_0_image, episode_3_image])

    expected_log_call = (
        {
            'rollout/trajectory_image_episode_0': episode_0_image,
            'rollout/trajectory_image_episode_3': episode_3_image,
        },
        12,
    )
    self.assertIn(expected_log_call, fake_swanlab.log_calls)
|
||||
|
||||
def test_log_rollout_trajectory_images_to_swanlab_is_best_effort(self):
    """A failing ``Image`` call is warned about and does not block later episodes.

    The first ``Image`` call raises. The second episode's image must still be
    logged at the given step, and a warning mentioning the failed upload prep
    must be emitted.
    """
    module = self._load_train_vla_module()
    fake_swanlab = FakeSwanLab(image_errors=[RuntimeError('decode failed')])

    rollout_summary = {
        'episodes': [
            {
                'episode_index': 0,
                'artifact_paths': {'trajectory_image': 'artifacts/bad_episode.png'},
            },
            {
                'episode_index': 1,
                'artifact_paths': {'trajectory_image': 'artifacts/good_episode.png'},
            },
        ],
    }
    with mock.patch.object(module.log, 'warning') as warning_mock:
        module._log_rollout_trajectory_images_to_swanlab(
            fake_swanlab,
            rollout_summary,
            step=7,
            context_label='checkpoint rollout',
        )

    bad_image = {
        'path': 'artifacts/bad_episode.png',
        'caption': 'checkpoint rollout trajectory image - episode 0 (front)',
    }
    good_image = {
        'path': 'artifacts/good_episode.png',
        'caption': 'checkpoint rollout trajectory image - episode 1 (front)',
    }
    # Both episodes were attempted, even though the first attempt raised.
    self.assertEqual(fake_swanlab.image_calls, [bad_image, good_image])

    # Only the successful image reaches the log payload.
    expected_log_call = ({'rollout/trajectory_image_episode_1': good_image}, 7)
    self.assertIn(expected_log_call, fake_swanlab.log_calls)

    logged_warnings = [call.args[0] for call in warning_mock.call_args_list]
    upload_failure_warned = any(
        'SwanLab rollout trajectory image upload prep failed' in text
        for text in logged_warnings
    )
    self.assertTrue(upload_failure_warned)
|
||||
|
||||
def _load_train_vla_module(self):
|
||||
hydra_module = types.ModuleType('hydra')
|
||||
hydra_utils_module = types.ModuleType('hydra.utils')
|
||||
@@ -356,8 +483,8 @@ class TrainVLASwanLabLoggingTest(unittest.TestCase):
|
||||
|
||||
final_payload, final_step = fake_swanlab.log_calls[-1]
|
||||
self.assertEqual(final_step, cfg.train.max_steps)
|
||||
self.assertEqual(final_payload['final/checkpoint_path'], 'checkpoints/vla_model_final.pt')
|
||||
self.assertEqual(final_payload['final/best_checkpoint_path'], 'checkpoints/vla_model_best.pt')
|
||||
self.assertTrue(final_payload['final/checkpoint_path'].endswith('checkpoints/vla_model_final.pt'))
|
||||
self.assertTrue(final_payload['final/best_checkpoint_path'].endswith('checkpoints/vla_model_best.pt'))
|
||||
self.assertEqual(fake_swanlab.finish_calls, 1)
|
||||
|
||||
def test_run_training_skips_swanlab_when_disabled(self):
|
||||
@@ -512,10 +639,10 @@ class TrainVLASwanLabLoggingTest(unittest.TestCase):
|
||||
|
||||
def fake_torch_load(path, map_location=None):
    """Return the canned checkpoint state that matches ``path``.

    Both sides of each comparison are resolved, so a relative and an absolute
    spelling of the same checkpoint location match.

    Raises:
        AssertionError: If ``path`` is neither the resume nor the best checkpoint.
    """
    del map_location
    path = Path(path).resolve()
    if path == resume_path.resolve():
        return resume_checkpoint_state
    if path == best_path.resolve():
        return best_checkpoint_state
    raise AssertionError(f'unexpected load path: {path}')
|
||||
|
||||
@@ -538,8 +665,8 @@ class TrainVLASwanLabLoggingTest(unittest.TestCase):
|
||||
|
||||
final_payload, final_step = fake_swanlab.log_calls[-1]
|
||||
self.assertEqual(final_step, cfg.train.max_steps)
|
||||
self.assertEqual(final_payload['final/best_checkpoint_path'], 'checkpoints/vla_model_best.pt')
|
||||
self.assertNotIn('checkpoints/vla_model_best.pt', saved_paths)
|
||||
self.assertTrue(final_payload['final/best_checkpoint_path'].endswith('checkpoints/vla_model_best.pt'))
|
||||
self.assertFalse(any(path.endswith('checkpoints/vla_model_best.pt') for path in saved_paths))
|
||||
|
||||
def test_run_training_resume_ignores_best_checkpoint_without_rollout_metric(self):
|
||||
module = self._load_train_vla_module()
|
||||
@@ -594,10 +721,10 @@ class TrainVLASwanLabLoggingTest(unittest.TestCase):
|
||||
|
||||
def fake_torch_load(path, map_location=None):
    """Return the canned checkpoint state that matches ``path``.

    Both sides of each comparison are resolved, so a relative and an absolute
    spelling of the same checkpoint location match. The best-checkpoint path
    yields the stale state this test is built around.

    Raises:
        AssertionError: If ``path`` is neither the resume nor the best checkpoint.
    """
    del map_location
    path = Path(path).resolve()
    if path == resume_path.resolve():
        return resume_checkpoint_state
    if path == best_path.resolve():
        return stale_best_checkpoint_state
    raise AssertionError(f'unexpected load path: {path}')
|
||||
|
||||
|
||||
Reference in New Issue
Block a user