feat: add rollout trajectory image artifacts and swanlab logging

This commit is contained in:
Logic
2026-04-03 09:39:16 +08:00
parent 48f0eb8dd0
commit 0586a6e6c7
8 changed files with 626 additions and 21 deletions

View File

@@ -234,7 +234,28 @@ class TrainVLARolloutValidationTest(unittest.TestCase):
}
)
agent = _FakeAgent()
rollout_mock = mock.Mock(side_effect=[{'avg_reward': 2.0}, {'avg_reward': 1.0}])
rollout_mock = mock.Mock(
side_effect=[
{
'avg_reward': 2.0,
'episodes': [
{
'episode_index': 0,
'artifact_paths': {'trajectory_image': 'artifacts/epoch_49_front.png'},
},
],
},
{
'avg_reward': 1.0,
'episodes': [
{
'episode_index': 0,
'artifact_paths': {'trajectory_image': 'artifacts/epoch_99_front.png'},
},
],
},
]
)
swanlab_log_mock = mock.Mock()
saved_checkpoints = []
@@ -281,17 +302,22 @@ class TrainVLARolloutValidationTest(unittest.TestCase):
self.assertEqual(rollout_mock.call_count, 2)
first_rollout_cfg = rollout_mock.call_args_list[0].args[0]
second_rollout_cfg = rollout_mock.call_args_list[1].args[0]
self.assertEqual(first_rollout_cfg.eval.ckpt_path, 'checkpoints/vla_model_step_49.pt')
self.assertEqual(second_rollout_cfg.eval.ckpt_path, 'checkpoints/vla_model_step_99.pt')
self.assertTrue(first_rollout_cfg.eval.ckpt_path.endswith('checkpoints/vla_model_step_49.pt'))
self.assertTrue(second_rollout_cfg.eval.ckpt_path.endswith('checkpoints/vla_model_step_99.pt'))
self.assertEqual(first_rollout_cfg.eval.num_episodes, 3)
self.assertTrue(first_rollout_cfg.eval.headless)
self.assertEqual(first_rollout_cfg.eval.device, 'cpu')
self.assertFalse(first_rollout_cfg.eval.verbose_action)
self.assertFalse(first_rollout_cfg.eval.record_video)
self.assertTrue(first_rollout_cfg.eval.save_trajectory_image)
self.assertEqual(first_rollout_cfg.eval.trajectory_image_camera_name, 'front')
self.assertEqual(cfg.eval.ckpt_path, 'unused.pt')
self.assertEqual(cfg.eval.num_episodes, 99)
self.assertFalse(cfg.eval.headless)
self.assertEqual(cfg.eval.device, 'cpu')
self.assertFalse(cfg.eval.verbose_action)
self.assertNotIn('save_trajectory_image', cfg.eval)
self.assertNotIn('trajectory_image_camera_name', cfg.eval)
rollout_reward_logs = [
call.args[1]['rollout/avg_reward']
@@ -769,10 +795,8 @@ class TrainVLARolloutValidationTest(unittest.TestCase):
'dataset_len': 1,
},
)
self.assertEqual(
[path for path, _payload in saved_checkpoints],
['checkpoints/vla_model_final.pt'],
)
self.assertEqual(len(saved_checkpoints), 1)
self.assertTrue(saved_checkpoints[0][0].endswith('checkpoints/vla_model_final.pt'))
if __name__ == '__main__':