# Source: roboimi/roboimi/demos/vla_scripts/eval_vla.py
"""
VLA Policy Evaluation Script (Hydra-based)
This script evaluates a trained Vision-Language-Action (VLA) policy
in the MuJoCo simulation environment.
Usage:
python roboimi/demos/eval_vla.py
python roboimi/demos/eval_vla.py ckpt_path=checkpoints/vla_model_step_8000.pt num_episodes=5
python roboimi/demos/eval_vla.py use_smoothing=true smooth_alpha=0.5
"""
import json
import logging
import os
import sys
from pathlib import Path
from typing import Dict, List, Optional, Sequence, Tuple

import hydra
import numpy as np
import torch
from einops import rearrange
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf
from tqdm import tqdm

from roboimi.envs.double_pos_ctrl_env import make_sim_env
from roboimi.utils.act_ex_utils import sample_transfer_pose
# NOTE(review): this sys.path tweak runs *after* the roboimi imports above, so
# it cannot have helped resolve them; presumably the script is launched from
# the repo root where those imports already resolve — TODO confirm, then move
# this before the imports or drop it.
sys.path.append(os.getcwd())

# Module-level logger, configured by Hydra at runtime.
log = logging.getLogger(__name__)

# Register resolver for list length in configs (e.g., ${len:${data.camera_names}}).
# Guarded so a second import of this module does not raise on re-registration.
if not OmegaConf.has_resolver("len"):
    OmegaConf.register_new_resolver("len", lambda x: len(x))
class VLAEvaluator:
    """
    VLA Policy Evaluator for the MuJoCo simulation.

    Keeps a rolling buffer of the last ``obs_horizon`` observations, queries
    the policy once every ``num_queries`` environment steps (replaying the
    cached action chunk in between), and applies the dataset
    normalization/denormalization plus optional action smoothing.
    """

    def __init__(
        self,
        agent: torch.nn.Module,
        device: str = 'cuda',
        camera_names: Sequence[str] = ('r_vis', 'top', 'front'),
        num_queries: int = 1,
        obs_horizon: int = 2,
        pred_horizon: int = 16,
        use_smoothing: bool = False,
        smooth_method: str = 'ema',
        smooth_alpha: float = 0.3,
        dataset_stats: Optional[dict] = None
    ):
        """
        Args:
            agent: Trained policy exposing ``predict_action(images=..., proprioception=...)``.
            device: Torch device string; the agent and inputs are moved there.
            camera_names: Cameras read from ``obs['images']``. Immutable tuple
                default — avoids the shared-mutable-default-argument pitfall.
            num_queries: Re-query the policy every this many steps; steps in
                between replay the cached action chunk.
            obs_horizon: Number of past observations stacked for the model.
            pred_horizon: Length of the predicted chunk (stored for interface
                compatibility; not used directly in this class).
            use_smoothing: Enable the ActionSmoother on emitted actions.
            smooth_method: Smoothing method name (see ActionSmoother).
            smooth_alpha: EMA weight for the newest action.
            dataset_stats: Normalization statistics saved at training time,
                or None to skip (de)normalization entirely.
        """
        self.agent = agent.to(device)
        self.device = device
        self.camera_names = list(camera_names)
        self.num_queries = num_queries
        self.obs_horizon = obs_horizon
        self.pred_horizon = pred_horizon

        # Dataset statistics for normalization/denormalization.
        self.stats = dataset_stats
        if self.stats is not None:
            self.normalization_type = self.stats.get('normalization_type', 'gaussian')
            self.qpos_mean = torch.tensor(self.stats['qpos_mean'], dtype=torch.float32)
            self.qpos_std = torch.tensor(self.stats['qpos_std'], dtype=torch.float32)
            self.qpos_min = torch.tensor(self.stats.get('qpos_min', []), dtype=torch.float32)
            self.qpos_max = torch.tensor(self.stats.get('qpos_max', []), dtype=torch.float32)
            self.action_mean = torch.tensor(self.stats['action_mean'], dtype=torch.float32)
            self.action_std = torch.tensor(self.stats['action_std'], dtype=torch.float32)
            self.action_min = torch.tensor(self.stats.get('action_min', []), dtype=torch.float32)
            self.action_max = torch.tensor(self.stats.get('action_max', []), dtype=torch.float32)
        else:
            self.normalization_type = None

        # Optional action smoothing.
        # NOTE(review): action_dim=16 is hard-coded — confirm it matches the
        # policy's actual action dimension.
        self.use_smoothing = use_smoothing
        self.smooth_method = smooth_method
        self.smooth_alpha = smooth_alpha
        self.smoother = ActionSmoother(
            action_dim=16,
            method=smooth_method,
            alpha=smooth_alpha
        ) if use_smoothing else None

        # Rolling observation buffer: per-camera image history + qpos history.
        self.obs_buffer = {
            'images': {cam: [] for cam in self.camera_names},
            'qpos': []
        }
        self.cached_actions = None  # (chunk_len, action_dim) ndarray after first query
        self.query_step = 0         # index of the next action within the cached chunk

    def reset(self):
        """Reset all per-episode state (buffers, action cache, smoother)."""
        self.obs_buffer = {
            'images': {cam: [] for cam in self.camera_names},
            'qpos': []
        }
        self.cached_actions = None
        self.query_step = 0
        if self.smoother is not None:
            self.smoother.reset()

    def _get_image_dict(self, obs: Dict) -> Dict[str, torch.Tensor]:
        """Push current frames into the buffer and return per-camera tensors.

        Returns:
            Dict mapping camera name -> float tensor of shape
            (1, obs_horizon, C, H, W), scaled to [0, 1].
        """
        images = {}
        for cam_name in self.camera_names:
            # (H, W, C) -> (C, H, W); ndarray.transpose is the numpy
            # equivalent of the previous einops rearrange('h w c -> c h w').
            img = obs['images'][cam_name].transpose(2, 0, 1)
            images[cam_name] = torch.from_numpy(img / 255.0).float()
        image_dict = {}
        for cam_name in self.camera_names:
            cam_images = self.obs_buffer['images'][cam_name]
            cam_images.append(images[cam_name])
            # Left-pad with the oldest frame until the horizon is filled
            # (first step of an episode has no history yet).
            while len(cam_images) < self.obs_horizon:
                cam_images.insert(0, cam_images[0])
            if len(cam_images) > self.obs_horizon:
                cam_images = cam_images[-self.obs_horizon:]
            img_tensor = torch.stack(cam_images, dim=0).unsqueeze(0)
            image_dict[cam_name] = img_tensor
            self.obs_buffer['images'][cam_name] = cam_images[-self.obs_horizon:]
        return image_dict

    def _get_qpos_dict(self, obs: Dict) -> torch.Tensor:
        """Push current qpos into the buffer and return the stacked history.

        Returns:
            Tensor of shape (1, obs_horizon, obs_dim), normalized with the
            training statistics when available.
        """
        qpos = torch.from_numpy(obs['qpos']).float()
        self.obs_buffer['qpos'].append(qpos)
        # Left-pad with the oldest qpos until the horizon is filled.
        while len(self.obs_buffer['qpos']) < self.obs_horizon:
            self.obs_buffer['qpos'].insert(0, self.obs_buffer['qpos'][0])
        if len(self.obs_buffer['qpos']) > self.obs_horizon:
            self.obs_buffer['qpos'] = self.obs_buffer['qpos'][-self.obs_horizon:]
        qpos_tensor = torch.stack(self.obs_buffer['qpos'], dim=0).unsqueeze(0)  # (1, obs_horizon, obs_dim)
        # Normalize qpos with the same scheme used at training time.
        if self.stats is not None:
            if self.normalization_type == 'gaussian':
                qpos_tensor = (qpos_tensor - self.qpos_mean) / self.qpos_std
            else:  # min_max: normalize to [-1, 1]
                qpos_tensor = 2 * (qpos_tensor - self.qpos_min) / (self.qpos_max - self.qpos_min) - 1
        return qpos_tensor

    @torch.no_grad()
    def predict_action(self, obs: Dict) -> np.ndarray:
        """Return the next (denormalized, optionally smoothed) action.

        The policy is queried only when no chunk is cached or every
        ``num_queries`` steps; other calls replay the cached chunk.
        """
        images = self._get_image_dict(obs)
        qpos = self._get_qpos_dict(obs)
        if self.cached_actions is None or self.query_step % self.num_queries == 0:
            images = {k: v.to(self.device) for k, v in images.items()}
            qpos = qpos.to(self.device)
            predicted_actions = self.agent.predict_action(
                images=images,
                proprioception=qpos
            )
            # Denormalize actions back into the environment's action space.
            if self.stats is not None:
                if self.normalization_type == 'gaussian':
                    predicted_actions = predicted_actions * self.action_std.to(self.device) + self.action_mean.to(self.device)
                else:  # min_max
                    predicted_actions = (predicted_actions + 1) / 2 * (self.action_max.to(self.device) - self.action_min.to(self.device)) + self.action_min.to(self.device)
            self.cached_actions = predicted_actions.squeeze(0).cpu().numpy()
            self.query_step = 0
        raw_action = self.cached_actions[self.query_step]
        self.query_step += 1
        if self.smoother is not None:
            raw_action = self.smoother.smooth(raw_action)
        return raw_action
class ActionSmoother:
    """Smooths a stream of actions for less jerky execution."""

    def __init__(self, action_dim: int, method: str = 'ema', alpha: float = 0.3):
        """
        Args:
            action_dim: Dimensionality of the action vector (informational).
            method: 'ema' applies an exponential moving average; any other
                value makes the smoother a pass-through.
            alpha: EMA weight on the newest action (higher = less smoothing).
        """
        self.action_dim = action_dim
        self.method = method
        self.alpha = alpha
        self.prev_action = None

    def smooth(self, action: np.ndarray) -> np.ndarray:
        """Return the smoothed action and update internal state."""
        if self.method != 'ema':
            # Unknown method: pass the action through untouched.
            return action
        previous = self.prev_action
        blended = action if previous is None else self.alpha * action + (1 - self.alpha) * previous
        self.prev_action = blended
        return blended

    def reset(self):
        """Forget the previous action (call at episode boundaries)."""
        self.prev_action = None
def load_checkpoint(
    ckpt_path: str,
    agent_cfg: "DictConfig",
    device: str = 'cuda'
) -> "Tuple[torch.nn.Module, Optional[dict]]":
    """
    Load a trained VLA model and its dataset statistics from a checkpoint.

    Args:
        ckpt_path: Path to checkpoint file (.pt)
        agent_cfg: Hydra agent config used to instantiate the model
        device: Device to load model on

    Returns:
        Tuple of (agent in eval mode on ``device``, dataset-statistics dict
        or ``None`` when no statistics could be found). The previous
        annotation/docstring claimed a bare model was returned.

    Raises:
        FileNotFoundError: If ``ckpt_path`` does not exist.
    """
    ckpt_path = Path(ckpt_path).absolute()
    if not ckpt_path.exists():
        raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")
    log.info(f"Loading checkpoint from {ckpt_path}")
    # NOTE(review): weights_only=False unpickles arbitrary objects — only load
    # checkpoints from trusted sources.
    checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False)
    log.info(f"Checkpoint keys: {checkpoint.keys()}")
    # Instantiate agent from Hydra config
    log.info("Instantiating agent from config...")
    agent = instantiate(agent_cfg)
    # Load model state
    agent.load_state_dict(checkpoint['model_state_dict'])
    log.info(f"✅ Model state loaded (step: {checkpoint.get('step', 'unknown')})")
    # Load dataset statistics for denormalization
    stats = checkpoint.get('dataset_stats', None)
    if stats is not None:
        log.info(f"✅ Dataset statistics loaded (normalization: {stats.get('normalization_type', 'gaussian')})")
    else:
        # Fallback: try an external JSON file (compatibility with old checkpoints)
        stats_path = ckpt_path.parent / 'dataset_stats.json'
        if stats_path.exists():
            with open(stats_path, 'r') as f:
                stats = json.load(f)
            log.info("✅ Dataset statistics loaded from external JSON (legacy)")
        else:
            log.warning("⚠️ No dataset statistics found. Actions will not be denormalized!")
    agent.eval()
    agent.to(device)
    log.info(f"✅ Model loaded successfully on {device}")
    return agent, stats
@hydra.main(version_base=None, config_path="../../vla/conf", config_name="config")
def main(cfg: DictConfig):
    """
    VLA Evaluation Script with Hydra Configuration.
    All eval parameters come from vla/conf/eval.yaml, merged into cfg.
    Override on command line: python eval_vla.py eval.ckpt_path=... eval.num_episodes=5
    """
    # Show the fully-resolved configuration up front.
    banner = "=" * 80
    print(banner)
    print("VLA Evaluation Configuration:")
    print(banner)
    print(OmegaConf.to_yaml(cfg))
    print(banner)

    eval_cfg = cfg.eval
    device = eval_cfg.device

    # Load the trained model and its normalization statistics.
    log.info(f"🚀 Loading model from {eval_cfg.ckpt_path}...")
    agent, dataset_stats = load_checkpoint(
        ckpt_path=eval_cfg.ckpt_path,
        agent_cfg=cfg.agent,
        device=device
    )

    # Wrap the policy for closed-loop evaluation.
    evaluator = VLAEvaluator(
        agent=agent,
        device=device,
        camera_names=list(eval_cfg.camera_names),
        num_queries=eval_cfg.num_queries,
        obs_horizon=eval_cfg.obs_horizon,
        use_smoothing=eval_cfg.use_smoothing,
        smooth_method=eval_cfg.smooth_method,
        smooth_alpha=eval_cfg.smooth_alpha,
        dataset_stats=dataset_stats
    )

    # Simulation environment for the requested task.
    env = make_sim_env(eval_cfg.task_name)

    # Roll out the requested number of episodes.
    for ep in range(eval_cfg.num_episodes):
        print(f"\n{'='*60}")
        print(f"Episode {ep + 1}/{eval_cfg.num_episodes}")
        print(f"{'='*60}\n")
        env.reset(sample_transfer_pose())
        evaluator.reset()
        with torch.inference_mode():
            for _ in tqdm(range(eval_cfg.max_timesteps), desc=f"Episode {ep + 1}"):
                obs = env._get_image_obs()
                obs['qpos'] = env._get_qpos_obs()['qpos']
                env.step_jnt(evaluator.predict_action(obs))
                env.render()
        print(f"\nEpisode {ep + 1} completed ({eval_cfg.max_timesteps} timesteps)")

    print(f"\n{'='*60}")
    print("Evaluation complete!")
    print(f"{'='*60}\n")


if __name__ == '__main__':
    main()