chore: 计算推理频率
This commit is contained in:
@@ -14,6 +14,7 @@ import sys
|
||||
import os
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
import torch
|
||||
import numpy as np
|
||||
import hydra
|
||||
@@ -95,6 +96,10 @@ class VLAEvaluator:
|
||||
self.cached_actions = None
|
||||
self.query_step = 0
|
||||
|
||||
# Timing statistics
|
||||
self.inference_times = [] # Model inference time only
|
||||
self.total_times = [] # Total prediction time (including preprocessing)
|
||||
|
||||
def reset(self):
|
||||
"""Reset evaluator state"""
|
||||
self.obs_buffer = {
|
||||
@@ -106,6 +111,10 @@ class VLAEvaluator:
|
||||
if self.smoother is not None:
|
||||
self.smoother.reset()
|
||||
|
||||
# Reset timing stats for each episode
|
||||
self.inference_times = []
|
||||
self.total_times = []
|
||||
|
||||
def _get_image_dict(self, obs: Dict) -> Dict[str, torch.Tensor]:
|
||||
images = {}
|
||||
for cam_name in self.camera_names:
|
||||
@@ -157,6 +166,8 @@ class VLAEvaluator:
|
||||
|
||||
@torch.no_grad()
|
||||
def predict_action(self, obs: Dict) -> np.ndarray:
|
||||
start_total = time.time()
|
||||
|
||||
images = self._get_image_dict(obs)
|
||||
qpos = self._get_qpos_dict(obs)
|
||||
|
||||
@@ -164,11 +175,21 @@ class VLAEvaluator:
|
||||
images = {k: v.to(self.device) for k, v in images.items()}
|
||||
qpos = qpos.to(self.device)
|
||||
|
||||
# Measure pure model inference time
|
||||
start_inference = time.time()
|
||||
predicted_actions = self.agent.predict_action(
|
||||
images=images,
|
||||
proprioception=qpos
|
||||
)
|
||||
|
||||
# Synchronize CUDA if using GPU to get accurate timing
|
||||
if self.device == 'cuda':
|
||||
torch.cuda.synchronize()
|
||||
end_inference = time.time()
|
||||
|
||||
inference_time = end_inference - start_inference
|
||||
self.inference_times.append(inference_time)
|
||||
|
||||
# Denormalize actions
|
||||
if self.stats is not None:
|
||||
if self.normalization_type == 'gaussian':
|
||||
@@ -185,8 +206,34 @@ class VLAEvaluator:
|
||||
if self.smoother is not None:
|
||||
raw_action = self.smoother.smooth(raw_action)
|
||||
|
||||
end_total = time.time()
|
||||
total_time = end_total - start_total
|
||||
self.total_times.append(total_time)
|
||||
|
||||
return raw_action
|
||||
|
||||
def get_timing_stats(self) -> Dict:
|
||||
"""Get timing statistics"""
|
||||
if len(self.inference_times) == 0:
|
||||
return {
|
||||
'inference_fps': 0.0,
|
||||
'control_fps': 0.0,
|
||||
'avg_inference_time_ms': 0.0,
|
||||
'avg_total_time_ms': 0.0
|
||||
}
|
||||
|
||||
avg_inference_time = np.mean(self.inference_times)
|
||||
avg_total_time = np.mean(self.total_times)
|
||||
|
||||
return {
|
||||
'inference_fps': 1.0 / avg_inference_time if avg_inference_time > 0 else 0.0,
|
||||
'control_fps': 1.0 / avg_total_time if avg_total_time > 0 else 0.0,
|
||||
'avg_inference_time_ms': avg_inference_time * 1000,
|
||||
'avg_total_time_ms': avg_total_time * 1000,
|
||||
'num_inferences': len(self.inference_times),
|
||||
'num_steps': len(self.total_times)
|
||||
}
|
||||
|
||||
|
||||
class ActionSmoother:
|
||||
"""Action smoothing for smoother execution"""
|
||||
@@ -313,6 +360,8 @@ def main(cfg: DictConfig):
|
||||
env = make_sim_env(eval_cfg.task_name)
|
||||
|
||||
# Run episodes
|
||||
all_stats = []
|
||||
|
||||
for episode_idx in range(eval_cfg.num_episodes):
|
||||
print(f"\n{'='*60}")
|
||||
print(f"Episode {episode_idx + 1}/{eval_cfg.num_episodes}")
|
||||
@@ -333,11 +382,34 @@ def main(cfg: DictConfig):
|
||||
|
||||
env.render()
|
||||
|
||||
print(f"\nEpisode {episode_idx + 1} completed ({eval_cfg.max_timesteps} timesteps)")
|
||||
# Get timing statistics for this episode
|
||||
stats = evaluator.get_timing_stats()
|
||||
all_stats.append(stats)
|
||||
|
||||
print(f"\nEpisode {episode_idx + 1} completed ({eval_cfg.max_timesteps} timesteps)")
|
||||
print(f" Model Inference FPS: {stats['inference_fps']:.2f} Hz")
|
||||
print(f" Control Loop FPS: {stats['control_fps']:.2f} Hz")
|
||||
print(f" Avg Inference Time: {stats['avg_inference_time_ms']:.2f} ms")
|
||||
print(f" Avg Total Time: {stats['avg_total_time_ms']:.2f} ms")
|
||||
print(f" Total Inferences: {stats['num_inferences']}")
|
||||
|
||||
# Print overall statistics
|
||||
print(f"\n{'='*60}")
|
||||
print("Evaluation complete!")
|
||||
print(f"{'='*60}\n")
|
||||
print(f"{'='*60}")
|
||||
|
||||
if all_stats:
|
||||
avg_inference_fps = np.mean([s['inference_fps'] for s in all_stats])
|
||||
avg_control_fps = np.mean([s['control_fps'] for s in all_stats])
|
||||
avg_inference_time = np.mean([s['avg_inference_time_ms'] for s in all_stats])
|
||||
avg_total_time = np.mean([s['avg_total_time_ms'] for s in all_stats])
|
||||
|
||||
print(f"\nOverall Statistics ({eval_cfg.num_episodes} episodes):")
|
||||
print(f" Average Model Inference FPS: {avg_inference_fps:.2f} Hz")
|
||||
print(f" Average Control Loop FPS: {avg_control_fps:.2f} Hz")
|
||||
print(f" Average Inference Time: {avg_inference_time:.2f} ms")
|
||||
print(f" Average Total Time: {avg_total_time:.2f} ms")
|
||||
print()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
Reference in New Issue
Block a user