feat: add rollout trajectory image artifacts and swanlab logging
This commit is contained in:
@@ -234,7 +234,28 @@ class TrainVLARolloutValidationTest(unittest.TestCase):
|
||||
}
|
||||
)
|
||||
agent = _FakeAgent()
|
||||
rollout_mock = mock.Mock(side_effect=[{'avg_reward': 2.0}, {'avg_reward': 1.0}])
|
||||
rollout_mock = mock.Mock(
|
||||
side_effect=[
|
||||
{
|
||||
'avg_reward': 2.0,
|
||||
'episodes': [
|
||||
{
|
||||
'episode_index': 0,
|
||||
'artifact_paths': {'trajectory_image': 'artifacts/epoch_49_front.png'},
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
'avg_reward': 1.0,
|
||||
'episodes': [
|
||||
{
|
||||
'episode_index': 0,
|
||||
'artifact_paths': {'trajectory_image': 'artifacts/epoch_99_front.png'},
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
)
|
||||
swanlab_log_mock = mock.Mock()
|
||||
saved_checkpoints = []
|
||||
|
||||
@@ -281,17 +302,22 @@ class TrainVLARolloutValidationTest(unittest.TestCase):
|
||||
self.assertEqual(rollout_mock.call_count, 2)
|
||||
first_rollout_cfg = rollout_mock.call_args_list[0].args[0]
|
||||
second_rollout_cfg = rollout_mock.call_args_list[1].args[0]
|
||||
self.assertEqual(first_rollout_cfg.eval.ckpt_path, 'checkpoints/vla_model_step_49.pt')
|
||||
self.assertEqual(second_rollout_cfg.eval.ckpt_path, 'checkpoints/vla_model_step_99.pt')
|
||||
self.assertTrue(first_rollout_cfg.eval.ckpt_path.endswith('checkpoints/vla_model_step_49.pt'))
|
||||
self.assertTrue(second_rollout_cfg.eval.ckpt_path.endswith('checkpoints/vla_model_step_99.pt'))
|
||||
self.assertEqual(first_rollout_cfg.eval.num_episodes, 3)
|
||||
self.assertTrue(first_rollout_cfg.eval.headless)
|
||||
self.assertEqual(first_rollout_cfg.eval.device, 'cpu')
|
||||
self.assertFalse(first_rollout_cfg.eval.verbose_action)
|
||||
self.assertFalse(first_rollout_cfg.eval.record_video)
|
||||
self.assertTrue(first_rollout_cfg.eval.save_trajectory_image)
|
||||
self.assertEqual(first_rollout_cfg.eval.trajectory_image_camera_name, 'front')
|
||||
self.assertEqual(cfg.eval.ckpt_path, 'unused.pt')
|
||||
self.assertEqual(cfg.eval.num_episodes, 99)
|
||||
self.assertFalse(cfg.eval.headless)
|
||||
self.assertEqual(cfg.eval.device, 'cpu')
|
||||
self.assertFalse(cfg.eval.verbose_action)
|
||||
self.assertNotIn('save_trajectory_image', cfg.eval)
|
||||
self.assertNotIn('trajectory_image_camera_name', cfg.eval)
|
||||
|
||||
rollout_reward_logs = [
|
||||
call.args[1]['rollout/avg_reward']
|
||||
@@ -769,10 +795,8 @@ class TrainVLARolloutValidationTest(unittest.TestCase):
|
||||
'dataset_len': 1,
|
||||
},
|
||||
)
|
||||
self.assertEqual(
|
||||
[path for path, _payload in saved_checkpoints],
|
||||
['checkpoints/vla_model_final.pt'],
|
||||
)
|
||||
self.assertEqual(len(saved_checkpoints), 1)
|
||||
self.assertTrue(saved_checkpoints[0][0].endswith('checkpoints/vla_model_final.pt'))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
Reference in New Issue
Block a user