chore: 计算推理频率

This commit is contained in:
gouhanke
2026-02-09 15:39:22 +08:00
parent 8b700b6d99
commit ac870f6110

View File

@@ -14,6 +14,7 @@ import sys
import os
import json
import logging
import time
import torch
import numpy as np
import hydra
@@ -95,6 +96,10 @@ class VLAEvaluator:
self.cached_actions = None
self.query_step = 0
# Timing statistics
self.inference_times = [] # Model inference time only
self.total_times = [] # Total prediction time (including preprocessing)
def reset(self):
"""Reset evaluator state"""
self.obs_buffer = {
@@ -106,6 +111,10 @@ class VLAEvaluator:
if self.smoother is not None:
self.smoother.reset()
# Reset timing stats for each episode
self.inference_times = []
self.total_times = []
def _get_image_dict(self, obs: Dict) -> Dict[str, torch.Tensor]:
images = {}
for cam_name in self.camera_names:
@@ -157,6 +166,8 @@ class VLAEvaluator:
@torch.no_grad()
def predict_action(self, obs: Dict) -> np.ndarray:
start_total = time.time()
images = self._get_image_dict(obs)
qpos = self._get_qpos_dict(obs)
@@ -164,11 +175,21 @@ class VLAEvaluator:
images = {k: v.to(self.device) for k, v in images.items()}
qpos = qpos.to(self.device)
# Measure pure model inference time
start_inference = time.time()
predicted_actions = self.agent.predict_action(
images=images,
proprioception=qpos
)
# Synchronize CUDA if using GPU to get accurate timing
if self.device == 'cuda':
torch.cuda.synchronize()
end_inference = time.time()
inference_time = end_inference - start_inference
self.inference_times.append(inference_time)
# Denormalize actions
if self.stats is not None:
if self.normalization_type == 'gaussian':
@@ -185,8 +206,34 @@ class VLAEvaluator:
if self.smoother is not None:
raw_action = self.smoother.smooth(raw_action)
end_total = time.time()
total_time = end_total - start_total
self.total_times.append(total_time)
return raw_action
def get_timing_stats(self) -> Dict:
"""Get timing statistics"""
if len(self.inference_times) == 0:
return {
'inference_fps': 0.0,
'control_fps': 0.0,
'avg_inference_time_ms': 0.0,
'avg_total_time_ms': 0.0
}
avg_inference_time = np.mean(self.inference_times)
avg_total_time = np.mean(self.total_times)
return {
'inference_fps': 1.0 / avg_inference_time if avg_inference_time > 0 else 0.0,
'control_fps': 1.0 / avg_total_time if avg_total_time > 0 else 0.0,
'avg_inference_time_ms': avg_inference_time * 1000,
'avg_total_time_ms': avg_total_time * 1000,
'num_inferences': len(self.inference_times),
'num_steps': len(self.total_times)
}
class ActionSmoother:
"""Action smoothing for smoother execution"""
@@ -313,6 +360,8 @@ def main(cfg: DictConfig):
env = make_sim_env(eval_cfg.task_name)
# Run episodes
all_stats = []
for episode_idx in range(eval_cfg.num_episodes):
print(f"\n{'='*60}")
print(f"Episode {episode_idx + 1}/{eval_cfg.num_episodes}")
@@ -333,11 +382,34 @@ def main(cfg: DictConfig):
env.render()
print(f"\nEpisode {episode_idx + 1} completed ({eval_cfg.max_timesteps} timesteps)")
# Get timing statistics for this episode
stats = evaluator.get_timing_stats()
all_stats.append(stats)
print(f"\nEpisode {episode_idx + 1} completed ({eval_cfg.max_timesteps} timesteps)")
print(f" Model Inference FPS: {stats['inference_fps']:.2f} Hz")
print(f" Control Loop FPS: {stats['control_fps']:.2f} Hz")
print(f" Avg Inference Time: {stats['avg_inference_time_ms']:.2f} ms")
print(f" Avg Total Time: {stats['avg_total_time_ms']:.2f} ms")
print(f" Total Inferences: {stats['num_inferences']}")
# Print overall statistics
print(f"\n{'='*60}")
print("Evaluation complete!")
print(f"{'='*60}\n")
print(f"{'='*60}")
if all_stats:
avg_inference_fps = np.mean([s['inference_fps'] for s in all_stats])
avg_control_fps = np.mean([s['control_fps'] for s in all_stats])
avg_inference_time = np.mean([s['avg_inference_time_ms'] for s in all_stats])
avg_total_time = np.mean([s['avg_total_time_ms'] for s in all_stats])
print(f"\nOverall Statistics ({eval_cfg.num_episodes} episodes):")
print(f" Average Model Inference FPS: {avg_inference_fps:.2f} Hz")
print(f" Average Control Loop FPS: {avg_control_fps:.2f} Hz")
print(f" Average Inference Time: {avg_inference_time:.2f} ms")
print(f" Average Total Time: {avg_total_time:.2f} ms")
print()
if __name__ == '__main__':