diff --git a/roboimi/demos/vla_scripts/eval_vla.py b/roboimi/demos/vla_scripts/eval_vla.py
index a87e991..8fba2bd 100644
--- a/roboimi/demos/vla_scripts/eval_vla.py
+++ b/roboimi/demos/vla_scripts/eval_vla.py
@@ -14,6 +14,7 @@ import sys
 import os
 import json
 import logging
+import time
 import torch
 import numpy as np
 import hydra
@@ -95,6 +96,10 @@ class VLAEvaluator:
         self.cached_actions = None
         self.query_step = 0
 
+        # Timing statistics
+        self.inference_times = []  # Model inference time only
+        self.total_times = []  # Total prediction time (including preprocessing)
+
     def reset(self):
         """Reset evaluator state"""
         self.obs_buffer = {
@@ -106,6 +111,10 @@ class VLAEvaluator:
         if self.smoother is not None:
             self.smoother.reset()
 
+        # Reset timing stats for each episode
+        self.inference_times = []
+        self.total_times = []
+
     def _get_image_dict(self, obs: Dict) -> Dict[str, torch.Tensor]:
         images = {}
         for cam_name in self.camera_names:
@@ -157,6 +166,8 @@
 
     @torch.no_grad()
     def predict_action(self, obs: Dict) -> np.ndarray:
+        start_total = time.perf_counter()
+
         images = self._get_image_dict(obs)
         qpos = self._get_qpos_dict(obs)
 
@@ -164,11 +175,21 @@
         images = {k: v.to(self.device) for k, v in images.items()}
         qpos = qpos.to(self.device)
 
+        # Measure pure model inference time
+        start_inference = time.perf_counter()
         predicted_actions = self.agent.predict_action(
             images=images,
             proprioception=qpos
         )
+        # Synchronize CUDA if using GPU to get accurate timing
+        if 'cuda' in str(self.device):
+            torch.cuda.synchronize()
+        end_inference = time.perf_counter()
+
+        inference_time = end_inference - start_inference
+        self.inference_times.append(inference_time)
+
 
         # Denormalize actions
         if self.stats is not None:
             if self.normalization_type == 'gaussian':
@@ -185,8 +206,34 @@
         if self.smoother is not None:
             raw_action = self.smoother.smooth(raw_action)
 
+        end_total = time.perf_counter()
+        total_time = end_total - start_total
+        self.total_times.append(total_time)
+
         return raw_action
 
+    def get_timing_stats(self) -> Dict:
+        """Get timing statistics"""
+        if len(self.inference_times) == 0:
+            return {
+                'inference_fps': 0.0,
+                'control_fps': 0.0,
+                'avg_inference_time_ms': 0.0,
+                'avg_total_time_ms': 0.0
+            }
+
+        avg_inference_time = np.mean(self.inference_times)
+        avg_total_time = np.mean(self.total_times)
+
+        return {
+            'inference_fps': 1.0 / avg_inference_time if avg_inference_time > 0 else 0.0,
+            'control_fps': 1.0 / avg_total_time if avg_total_time > 0 else 0.0,
+            'avg_inference_time_ms': avg_inference_time * 1000,
+            'avg_total_time_ms': avg_total_time * 1000,
+            'num_inferences': len(self.inference_times),
+            'num_steps': len(self.total_times)
+        }
+
 
 class ActionSmoother:
     """Action smoothing for smoother execution"""
@@ -313,6 +360,8 @@ def main(cfg: DictConfig):
     env = make_sim_env(eval_cfg.task_name)
 
     # Run episodes
+    all_stats = []
+
     for episode_idx in range(eval_cfg.num_episodes):
         print(f"\n{'='*60}")
         print(f"Episode {episode_idx + 1}/{eval_cfg.num_episodes}")
@@ -333,11 +382,34 @@ def main(cfg: DictConfig):
             env.render()
 
-        print(f"\nEpisode {episode_idx + 1} completed ({eval_cfg.max_timesteps} timesteps)")
+        # Get timing statistics for this episode
+        stats = evaluator.get_timing_stats()
+        all_stats.append(stats)
+        print(f"\nEpisode {episode_idx + 1} completed ({eval_cfg.max_timesteps} timesteps)")
+        print(f"  Model Inference FPS: {stats['inference_fps']:.2f} Hz")
+        print(f"  Control Loop FPS: {stats['control_fps']:.2f} Hz")
+        print(f"  Avg Inference Time: {stats['avg_inference_time_ms']:.2f} ms")
+        print(f"  Avg Total Time: {stats['avg_total_time_ms']:.2f} ms")
+        print(f"  Total Inferences: {stats['num_inferences']}")
+
 
+    # Print overall statistics
     print(f"\n{'='*60}")
     print("Evaluation complete!")
-    print(f"{'='*60}\n")
+    print(f"{'='*60}")
+
+    if all_stats:
+        avg_inference_fps = np.mean([s['inference_fps'] for s in all_stats])
+        avg_control_fps = np.mean([s['control_fps'] for s in all_stats])
+        avg_inference_time = np.mean([s['avg_inference_time_ms'] for s in all_stats])
+        avg_total_time = np.mean([s['avg_total_time_ms'] for s in all_stats])
+
+        print(f"\nOverall Statistics ({eval_cfg.num_episodes} episodes):")
+        print(f"  Average Model Inference FPS: {avg_inference_fps:.2f} Hz")
+        print(f"  Average Control Loop FPS: {avg_control_fps:.2f} Hz")
+        print(f"  Average Inference Time: {avg_inference_time:.2f} ms")
+        print(f"  Average Total Time: {avg_total_time:.2f} ms")
+        print()
 
 
 if __name__ == '__main__':
     main()