345 lines
11 KiB
Python
345 lines
11 KiB
Python
"""
|
|
VLA Policy Evaluation Script (Hydra-based)
|
|
|
|
This script evaluates a trained Vision-Language-Action (VLA) policy
|
|
in the MuJoCo simulation environment.
|
|
|
|
Usage:
|
|
python roboimi/demos/eval_vla.py
|
|
python roboimi/demos/eval_vla.py ckpt_path=checkpoints/vla_model_step_8000.pt num_episodes=5
|
|
python roboimi/demos/eval_vla.py use_smoothing=true smooth_alpha=0.5
|
|
"""
|
|
|
|
import json
import logging
import os
import sys
from pathlib import Path
from typing import Dict, List

import hydra
import numpy as np
import torch
from einops import rearrange
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf
from tqdm import tqdm

# Ensure the project root is importable BEFORE pulling in project modules.
# (Previously this ran after the roboimi imports, which defeated its purpose.)
sys.path.append(os.getcwd())

from roboimi.envs.double_pos_ctrl_env import make_sim_env
from roboimi.utils.act_ex_utils import sample_transfer_pose

log = logging.getLogger(__name__)

# Register resolver for list length in configs (e.g., ${len:${data.camera_names}})
if not OmegaConf.has_resolver("len"):
    OmegaConf.register_new_resolver("len", len)
|
class VLAEvaluator:
    """
    VLA Policy Evaluator for MuJoCo Simulation.

    Keeps a rolling window of the last ``obs_horizon`` observations
    (per-camera images and proprioception), queries the agent for a chunk
    of future actions, and replays that chunk one action per step,
    re-querying every ``num_queries`` steps.
    """

    def __init__(
        self,
        agent: torch.nn.Module,
        device: str = 'cuda',
        camera_names: List[str] = ('r_vis', 'top', 'front'),
        num_queries: int = 1,
        obs_horizon: int = 2,
        pred_horizon: int = 16,
        use_smoothing: bool = False,
        smooth_method: str = 'ema',
        smooth_alpha: float = 0.3,
        dataset_stats: dict = None
    ):
        """
        Args:
            agent: Trained policy exposing ``predict_action(images=..., proprioception=...)``.
            device: Torch device string; agent and inputs are moved here.
            camera_names: Cameras to read from the observation dict.
            num_queries: Re-query the policy every this many env steps.
            obs_horizon: Number of past observations stacked per query.
            pred_horizon: Length of the predicted action chunk (informational).
            use_smoothing: Apply an ActionSmoother to executed actions.
            smooth_method: Smoothing method name (currently only 'ema' smooths).
            smooth_alpha: EMA coefficient for smoothing.
            dataset_stats: Normalization statistics saved at training time;
                if None, qpos/actions are used raw (no (de)normalization).
        """
        self.agent = agent.to(device)
        self.device = device
        # Copy into a fresh list: the default is an immutable tuple and a
        # caller-supplied sequence is never aliased (fixes the
        # mutable-default-argument pitfall of the previous version).
        self.camera_names = list(camera_names)
        self.num_queries = num_queries
        self.obs_horizon = obs_horizon
        self.pred_horizon = pred_horizon

        # Dataset statistics for normalization/denormalization
        self.stats = dataset_stats
        if self.stats is not None:
            self.normalization_type = self.stats.get('normalization_type', 'gaussian')
            self.qpos_mean = torch.tensor(self.stats['qpos_mean'], dtype=torch.float32)
            self.qpos_std = torch.tensor(self.stats['qpos_std'], dtype=torch.float32)
            self.qpos_min = torch.tensor(self.stats.get('qpos_min', []), dtype=torch.float32)
            self.qpos_max = torch.tensor(self.stats.get('qpos_max', []), dtype=torch.float32)
            self.action_mean = torch.tensor(self.stats['action_mean'], dtype=torch.float32)
            self.action_std = torch.tensor(self.stats['action_std'], dtype=torch.float32)
            self.action_min = torch.tensor(self.stats.get('action_min', []), dtype=torch.float32)
            self.action_max = torch.tensor(self.stats.get('action_max', []), dtype=torch.float32)
        else:
            self.normalization_type = None

        # Optional action smoothing
        self.use_smoothing = use_smoothing
        self.smooth_method = smooth_method
        self.smooth_alpha = smooth_alpha
        # NOTE(review): action_dim=16 is hard-coded — confirm it matches the
        # policy's actual action dimension.
        self.smoother = ActionSmoother(
            action_dim=16,
            method=smooth_method,
            alpha=smooth_alpha
        ) if use_smoothing else None

        # Rolling observation buffer (one list of length <= obs_horizon per modality)
        self.obs_buffer = {
            'images': {cam: [] for cam in self.camera_names},
            'qpos': []
        }
        self.cached_actions = None  # last predicted chunk, (chunk_len, action_dim)
        self.query_step = 0         # index of the next action within cached_actions

    def reset(self):
        """Clear observation history, cached actions, and smoother state."""
        self.obs_buffer = {
            'images': {cam: [] for cam in self.camera_names},
            'qpos': []
        }
        self.cached_actions = None
        self.query_step = 0
        if self.smoother is not None:
            self.smoother.reset()

    def _get_image_dict(self, obs: Dict) -> Dict[str, torch.Tensor]:
        """
        Push the current camera images into the buffer and return, per camera,
        a (1, obs_horizon, C, H, W) float tensor scaled by 1/255.

        The buffer is left-padded by repeating the oldest frame until
        ``obs_horizon`` frames are available.
        """
        images = {}
        for cam_name in self.camera_names:
            # assumes obs['images'][cam] is an HWC uint8 array — TODO confirm
            img = obs['images'][cam_name]
            img = rearrange(img, 'h w c -> c h w')
            img = torch.from_numpy(img / 255.0).float()
            images[cam_name] = img

        image_dict = {}
        for cam_name in self.camera_names:
            cam_images = self.obs_buffer['images'][cam_name]
            cam_images.append(images[cam_name])

            # Left-pad with the oldest frame during the first few steps.
            while len(cam_images) < self.obs_horizon:
                cam_images.insert(0, cam_images[0])

            # Keep only the most recent obs_horizon frames.
            if len(cam_images) > self.obs_horizon:
                cam_images = cam_images[-self.obs_horizon:]

            img_tensor = torch.stack(cam_images, dim=0).unsqueeze(0)
            image_dict[cam_name] = img_tensor

            self.obs_buffer['images'][cam_name] = cam_images[-self.obs_horizon:]

        return image_dict

    def _get_qpos_dict(self, obs: Dict) -> torch.Tensor:
        """
        Push current proprioception into the buffer and return a
        (1, obs_horizon, obs_dim) tensor, normalized if stats are available.
        """
        qpos = torch.from_numpy(obs['qpos']).float()

        self.obs_buffer['qpos'].append(qpos)

        # Left-pad with the oldest reading until the horizon is full.
        while len(self.obs_buffer['qpos']) < self.obs_horizon:
            self.obs_buffer['qpos'].insert(0, self.obs_buffer['qpos'][0])

        if len(self.obs_buffer['qpos']) > self.obs_horizon:
            self.obs_buffer['qpos'] = self.obs_buffer['qpos'][-self.obs_horizon:]

        qpos_tensor = torch.stack(self.obs_buffer['qpos'], dim=0).unsqueeze(0)  # (1, obs_horizon, obs_dim)

        # Normalize qpos the same way the training data was normalized.
        if self.stats is not None:
            if self.normalization_type == 'gaussian':
                qpos_tensor = (qpos_tensor - self.qpos_mean) / self.qpos_std
            else:  # min_max: normalize to [-1, 1]
                qpos_tensor = 2 * (qpos_tensor - self.qpos_min) / (self.qpos_max - self.qpos_min) - 1

        return qpos_tensor

    @torch.no_grad()
    def predict_action(self, obs: Dict) -> np.ndarray:
        """
        Return the next action to execute for the current observation.

        Queries the policy on the first call and thereafter every
        ``num_queries`` steps; in between, replays the cached chunk.
        """
        images = self._get_image_dict(obs)
        qpos = self._get_qpos_dict(obs)

        if self.cached_actions is None or self.query_step % self.num_queries == 0:
            images = {k: v.to(self.device) for k, v in images.items()}
            qpos = qpos.to(self.device)

            predicted_actions = self.agent.predict_action(
                images=images,
                proprioception=qpos
            )

            # Denormalize actions back to the environment's scale.
            if self.stats is not None:
                if self.normalization_type == 'gaussian':
                    predicted_actions = predicted_actions * self.action_std.to(self.device) + self.action_mean.to(self.device)
                else:  # min_max
                    predicted_actions = (predicted_actions + 1) / 2 * (self.action_max.to(self.device) - self.action_min.to(self.device)) + self.action_min.to(self.device)

            self.cached_actions = predicted_actions.squeeze(0).cpu().numpy()
            self.query_step = 0

        raw_action = self.cached_actions[self.query_step]
        self.query_step += 1

        if self.smoother is not None:
            raw_action = self.smoother.smooth(raw_action)

        return raw_action
|
|
|
|
|
|
class ActionSmoother:
    """Exponential-moving-average smoother for executed actions."""

    def __init__(self, action_dim: int, method: str = 'ema', alpha: float = 0.3):
        """
        Args:
            action_dim: Dimensionality of the action vector (informational).
            method: Smoothing method; only 'ema' smooths, anything else passes through.
            alpha: EMA weight on the newest action (higher = less smoothing).
        """
        self.action_dim = action_dim
        self.method = method
        self.alpha = alpha
        self.prev_action = None  # last smoothed action, or None before the first call

    def smooth(self, action: np.ndarray) -> np.ndarray:
        """Blend the new action with the running EMA and return the result."""
        # Non-EMA methods are a no-op pass-through (history is not updated).
        if self.method != 'ema':
            return action
        prev = self.prev_action
        if prev is None:
            blended = action
        else:
            blended = self.alpha * action + (1.0 - self.alpha) * prev
        self.prev_action = blended
        return blended

    def reset(self):
        """Forget any smoothing history."""
        self.prev_action = None
|
|
|
|
def load_checkpoint(
    ckpt_path: str,
    agent_cfg: DictConfig,
    device: str = 'cuda'
) -> tuple:
    """
    Load a trained VLA model and its dataset statistics from a checkpoint.

    Args:
        ckpt_path: Path to checkpoint file (.pt)
        agent_cfg: Hydra agent config for instantiation
        device: Device to load model on

    Returns:
        Tuple ``(agent, stats)``: the instantiated model in eval mode on
        ``device``, and the dataset-statistics dict (``None`` if neither the
        checkpoint nor a legacy ``dataset_stats.json`` provides them).
        (The previous ``-> torch.nn.Module`` annotation was wrong: a tuple
        has always been returned.)

    Raises:
        FileNotFoundError: If ``ckpt_path`` does not exist.
    """
    # Path is already imported at module level; no local re-import needed.
    ckpt_path = Path(ckpt_path).absolute()
    if not ckpt_path.exists():
        raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")

    log.info(f"Loading checkpoint from {ckpt_path}")
    # SECURITY: weights_only=False unpickles arbitrary objects — only load
    # checkpoints from trusted sources.
    checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False)
    log.info(f"Checkpoint keys: {checkpoint.keys()}")

    # Instantiate agent from Hydra config
    log.info("Instantiating agent from config...")
    agent = instantiate(agent_cfg)

    # Load model state
    agent.load_state_dict(checkpoint['model_state_dict'])
    log.info(f"✅ Model state loaded (step: {checkpoint.get('step', 'unknown')})")

    # Load dataset statistics for denormalization
    stats = checkpoint.get('dataset_stats', None)

    if stats is not None:
        log.info(f"✅ Dataset statistics loaded (normalization: {stats.get('normalization_type', 'gaussian')})")
    else:
        # Fallback: try external JSON file (compatibility with old checkpoints)
        stats_path = ckpt_path.parent / 'dataset_stats.json'
        if stats_path.exists():
            with open(stats_path, 'r') as f:
                stats = json.load(f)
            log.info("✅ Dataset statistics loaded from external JSON (legacy)")
        else:
            log.warning("⚠️ No dataset statistics found. Actions will not be denormalized!")

    agent.eval()
    agent.to(device)

    log.info(f"✅ Model loaded successfully on {device}")
    return agent, stats
|
|
|
|
|
|
@hydra.main(version_base=None, config_path="../../vla/conf", config_name="config")
def main(cfg: DictConfig):
    """
    Hydra entry point for VLA policy evaluation.

    All eval parameters come from vla/conf/eval.yaml, merged into cfg.
    Override on command line: python eval_vla.py eval.ckpt_path=... eval.num_episodes=5
    """

    # Dump the fully-resolved configuration for reproducibility.
    banner = "=" * 80
    print(banner)
    print("VLA Evaluation Configuration:")
    print(banner)
    print(OmegaConf.to_yaml(cfg))
    print(banner)

    eval_cfg = cfg.eval
    device = eval_cfg.device
    camera_names = list(eval_cfg.camera_names)

    # Load model checkpoint and its training-time dataset statistics.
    log.info(f"🚀 Loading model from {eval_cfg.ckpt_path}...")
    agent, dataset_stats = load_checkpoint(
        ckpt_path=eval_cfg.ckpt_path,
        agent_cfg=cfg.agent,
        device=device
    )

    # Wrap the agent in the evaluator (buffering, chunk replay, smoothing).
    evaluator = VLAEvaluator(
        agent=agent,
        device=device,
        camera_names=camera_names,
        num_queries=eval_cfg.num_queries,
        obs_horizon=eval_cfg.obs_horizon,
        use_smoothing=eval_cfg.use_smoothing,
        smooth_method=eval_cfg.smooth_method,
        smooth_alpha=eval_cfg.smooth_alpha,
        dataset_stats=dataset_stats
    )

    # Simulation environment for the configured task.
    env = make_sim_env(eval_cfg.task_name)

    # Roll out the requested number of episodes.
    for episode_idx in range(eval_cfg.num_episodes):
        print(f"\n{'='*60}")
        print(f"Episode {episode_idx + 1}/{eval_cfg.num_episodes}")
        print(f"{'='*60}\n")

        # Randomize the box pose, then clear env and evaluator state.
        env.reset(sample_transfer_pose())
        evaluator.reset()

        with torch.inference_mode():
            for _ in tqdm(range(eval_cfg.max_timesteps), desc=f"Episode {episode_idx + 1}"):
                obs = env._get_image_obs()
                obs['qpos'] = env._get_qpos_obs()['qpos']

                env.step_jnt(evaluator.predict_action(obs))

                env.render()

        print(f"\nEpisode {episode_idx + 1} completed ({eval_cfg.max_timesteps} timesteps)")

    print(f"\n{'='*60}")
    print("Evaluation complete!")
    print(f"{'='*60}\n")
|
|
|
|
|
|
# Script entry point: Hydra parses command-line overrides inside main().
if __name__ == '__main__':
    main()
|