chore: 删除多余脚本
This commit is contained in:
@@ -1,532 +0,0 @@
|
||||
"""
|
||||
VLA Policy Evaluation Script
|
||||
|
||||
This script evaluates a trained Vision-Language-Action (VLA) policy
|
||||
in the MuJoCo simulation environment.
|
||||
|
||||
Usage:
|
||||
python roboimi/demos/eval_vla.py --ckpt_path checkpoints/vla_model_best.pt --num_episodes 3
|
||||
"""
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
from tqdm import tqdm
|
||||
|
||||
from roboimi.envs.double_pos_ctrl_env import make_sim_env
|
||||
from roboimi.utils.act_ex_utils import sample_transfer_pose
|
||||
from einops import rearrange
|
||||
|
||||
|
||||
class VLAEvaluator:
    """
    VLA Policy Evaluator for MuJoCo Simulation.

    Keeps a rolling buffer of the last ``obs_horizon`` observations, queries
    the policy every ``num_queries`` timesteps, and replays the cached action
    chunk in between (optionally smoothed).
    """

    def __init__(
        self,
        agent: torch.nn.Module,
        device: str = 'cuda',
        camera_names: List[str] = None,
        num_queries: int = 1,
        obs_horizon: int = 2,
        pred_horizon: int = 16,
        use_smoothing: bool = False,
        smooth_method: str = 'ema',
        smooth_alpha: float = 0.3
    ):
        """
        Args:
            agent: Trained VLAAgent
            device: Device for inference
            camera_names: List of camera names to use
                (defaults to ['r_vis', 'top', 'front'])
            num_queries: How often to query the policy (in timesteps)
            obs_horizon: Number of observations to use as context
            pred_horizon: Number of future actions to predict
            use_smoothing: Whether to apply action smoothing
            smooth_method: Smoothing method ('ema', 'moving_avg', 'lowpass')
            smooth_alpha: Smoothing coefficient
        """
        # Fix: the default used to be a shared mutable list literal; build a
        # fresh list per call instead.
        if camera_names is None:
            camera_names = ['r_vis', 'top', 'front']

        self.agent = agent.to(device)
        self.device = device
        self.camera_names = camera_names
        self.num_queries = num_queries
        self.obs_horizon = obs_horizon
        self.pred_horizon = pred_horizon

        # Action smoothing
        self.use_smoothing = use_smoothing
        self.smooth_method = smooth_method
        self.smooth_alpha = smooth_alpha
        self.smoother = ActionSmoother(
            action_dim=16,  # NOTE(review): hard-coded 16-dim actions — confirm against agent config
            method=smooth_method,
            alpha=smooth_alpha
        ) if use_smoothing else None

        # Rolling observation buffer (most recent obs_horizon frames/qpos)
        self.obs_buffer = {
            'images': {cam: [] for cam in camera_names},
            'qpos': []
        }
        self.cached_actions = None  # (pred_horizon, action_dim) once queried
        self.query_step = 0         # index of the next cached action to replay

    def reset(self):
        """Reset evaluator state (call at the start of every episode)."""
        self.obs_buffer = {
            'images': {cam: [] for cam in self.camera_names},
            'qpos': []
        }
        self.cached_actions = None
        self.query_step = 0
        if self.smoother is not None:
            self.smoother.reset()

    def _get_image_dict(self, obs: Dict) -> Dict[str, torch.Tensor]:
        """
        Extract and preprocess images and update the image history buffer.

        Args:
            obs: Environment observation dict; ``obs['images'][cam]`` is an
                (H, W, C) numpy array (uint8 — assumed from the /255 scaling).

        Returns:
            Dict mapping camera names to image tensors (1, obs_horizon, C, H, W)
        """
        images = {}
        for cam_name in self.camera_names:
            # (H, W, C) -> (C, H, W), scaled to [0, 1].
            # ndarray.transpose is equivalent to rearrange 'h w c -> c h w'
            # and drops the einops dependency for this step.
            img = obs['images'][cam_name]
            img = img.transpose(2, 0, 1)
            img = torch.from_numpy(img / 255.0).float()
            images[cam_name] = img  # (C, H, W)

        image_dict = {}
        for cam_name in self.camera_names:
            cam_images = self.obs_buffer['images'][cam_name]
            cam_images.append(images[cam_name])

            # Pad to obs_horizon by duplicating the oldest frame
            while len(cam_images) < self.obs_horizon:
                cam_images.insert(0, cam_images[0])

            # Keep only the most recent obs_horizon frames
            if len(cam_images) > self.obs_horizon:
                cam_images = cam_images[-self.obs_horizon:]

            # Stack: (obs_horizon, C, H, W) -> (1, obs_horizon, C, H, W)
            img_tensor = torch.stack(cam_images, dim=0).unsqueeze(0)
            image_dict[cam_name] = img_tensor

            # Persist the trimmed window (may contain padded duplicates on the
            # first few steps — they are re-derived identically next call).
            self.obs_buffer['images'][cam_name] = cam_images[-self.obs_horizon:]

        return image_dict

    def _get_qpos_dict(self, obs: Dict) -> torch.Tensor:
        """
        Extract and preprocess qpos and update the proprioception buffer.

        Args:
            obs: Environment observation dict with a numpy ``obs['qpos']``.

        Returns:
            qpos tensor: (1, obs_horizon, obs_dim)
        """
        qpos = torch.from_numpy(obs['qpos']).float()

        self.obs_buffer['qpos'].append(qpos)

        # Pad to obs_horizon by duplicating the oldest entry
        while len(self.obs_buffer['qpos']) < self.obs_horizon:
            self.obs_buffer['qpos'].insert(0, self.obs_buffer['qpos'][0])

        # Keep only the most recent obs_horizon entries
        if len(self.obs_buffer['qpos']) > self.obs_horizon:
            self.obs_buffer['qpos'] = self.obs_buffer['qpos'][-self.obs_horizon:]

        # Stack: (obs_horizon, obs_dim) -> (1, obs_horizon, obs_dim)
        return torch.stack(self.obs_buffer['qpos'], dim=0).unsqueeze(0)

    @torch.no_grad()
    def predict_action(self, obs: Dict) -> np.ndarray:
        """
        Predict the next action using the VLA policy.

        Queries the policy when the cache is empty, ``num_queries`` steps have
        elapsed, or the cached chunk is exhausted; otherwise replays the next
        cached action.

        Args:
            obs: Current environment observation

        Returns:
            action: numpy array of shape (action_dim,)
        """
        # 1. Prepare observations
        images = self._get_image_dict(obs)  # Dict[str, (1, obs_horizon, C, H, W)]
        qpos = self._get_qpos_dict(obs)     # (1, obs_horizon, obs_dim)

        # 2. Re-query when needed. The extra exhaustion check fixes an
        # IndexError when num_queries > pred_horizon (cache shorter than the
        # replay window).
        if (self.cached_actions is None
                or self.query_step % self.num_queries == 0
                or self.query_step >= len(self.cached_actions)):
            # VLAAgent.predict_action expects:
            #  - images: Dict[str, Tensor] with shape (B, obs_horizon, C, H, W)
            #  - proprioception: Tensor with shape (B, obs_horizon, obs_dim)
            images = {k: v.to(self.device) for k, v in images.items()}
            qpos = qpos.to(self.device)

            # Returns: (B, pred_horizon, action_dim)
            predicted_actions = self.agent.predict_action(
                images=images,
                proprioception=qpos
            )

            # Cache predicted actions as a (pred_horizon, action_dim) array
            self.cached_actions = predicted_actions.squeeze(0).cpu().numpy()
            self.query_step = 0

        # 3. Get action from cache
        raw_action = self.cached_actions[self.query_step]
        self.query_step += 1

        # 4. Apply smoothing if enabled
        if self.smoother is not None:
            raw_action = self.smoother.smooth(raw_action)

        return raw_action
|
||||
|
||||
|
||||
class ActionSmoother:
    """Exponential-moving-average smoothing over a stream of actions.

    Only the 'ema' method is implemented; any other method name makes
    ``smooth`` a pass-through.
    """

    def __init__(self, action_dim: int, method: str = 'ema', alpha: float = 0.3):
        self.action_dim = action_dim
        self.method = method
        self.alpha = alpha
        # Previously emitted (smoothed) action; None until first smooth() call.
        self.prev_action = None

    def smooth(self, action: np.ndarray) -> np.ndarray:
        """Blend the incoming action with the previous smoothed one."""
        if self.method != 'ema':
            return action
        prev = self.prev_action
        blended = action if prev is None else self.alpha * action + (1 - self.alpha) * prev
        self.prev_action = blended
        return blended

    def reset(self):
        """Forget smoothing history so the next call starts fresh."""
        self.prev_action = None
|
||||
|
||||
|
||||
def _find_vla_config_dir():
    """
    Locate the ``vla/conf`` Hydra config directory.

    Candidates are tried in order: <cwd>/vla/conf, <cwd>/roboimi/vla/conf,
    <script_dir>/../vla/conf, then every ancestor of the cwd.

    Returns:
        Path of the config directory.

    Raises:
        FileNotFoundError: If no candidate directory exists.
    """
    import os
    from pathlib import Path as PathLib

    script_dir = PathLib(__file__).resolve().parent
    current_dir = PathLib(os.getcwd()).absolute()

    # Option 1: running from the roboimi directory
    if (current_dir / 'vla' / 'conf').exists():
        return current_dir / 'vla' / 'conf'
    # Option 2: running from the project root
    if (current_dir / 'roboimi' / 'vla' / 'conf').exists():
        return current_dir / 'roboimi' / 'vla' / 'conf'
    # Option 3: relative to the script location
    if (script_dir / '../vla' / 'conf').exists():
        return (script_dir / '../vla' / 'conf').resolve()
    # Option 4: search upwards from the cwd
    search_dir = current_dir
    while search_dir != search_dir.parent:
        if (search_dir / 'vla' / 'conf').exists():
            return search_dir / 'vla' / 'conf'
        search_dir = search_dir.parent

    raise FileNotFoundError(
        f"Could not find VLA config directory.\n"
        f"Current directory: {current_dir}\n"
        f"Script location: {script_dir}\n"
        f"Please ensure you're running from the roboimi directory."
    )


def load_checkpoint(
    ckpt_path: str,
    device: str = 'cuda'
) -> torch.nn.Module:
    """
    Load trained VLA model from checkpoint.

    Args:
        ckpt_path: Path to checkpoint file (.pt)
        device: Device to load model on

    Returns:
        Loaded VLAAgent model

    Raises:
        FileNotFoundError: If the checkpoint or the config directory is missing.
    """
    # NOTE(review): VLAAgent is never referenced below; kept in case importing
    # it has registration side effects needed by hydra instantiate — verify.
    from roboimi.vla.agent import VLAAgent  # noqa: F401
    from hydra import initialize_config_dir, compose
    from hydra.utils import instantiate
    import json
    from pathlib import Path as PathLib

    ckpt_path = PathLib(ckpt_path).absolute()
    if not ckpt_path.exists():
        raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")

    # weights_only=False: the checkpoint may contain pickled objects beyond
    # raw tensors — only load trusted checkpoint files.
    print(f"Loading checkpoint from {ckpt_path}")
    checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False)
    print(f"Checkpoint keys: {checkpoint.keys()}")

    # Resolve the Hydra config directory. (The previous inline version logged
    # "Loading config from ..." twice and re-checked the existence of a
    # directory already known to exist.)
    config_dir = _find_vla_config_dir()
    config_abs_path = str(config_dir.absolute())
    print(f"Loading config from {config_abs_path}")

    # Initialize Hydra with an absolute path and compose the config
    with initialize_config_dir(config_dir=config_abs_path, version_base=None):
        cfg = compose(config_name="config")

    # Instantiate agent from config
    print("Instantiating agent from config...")
    agent = instantiate(cfg.agent)

    # Load model state — support several checkpoint layouts
    if 'model_state_dict' in checkpoint:
        agent.load_state_dict(checkpoint['model_state_dict'])
        print(f"✅ Model state loaded (step: {checkpoint.get('step', 'unknown')})")
    elif 'state_dict' in checkpoint:
        agent.load_state_dict(checkpoint['state_dict'])
        print("✅ Model state loaded")
    else:
        # Assume the checkpoint is the state_dict itself
        agent.load_state_dict(checkpoint)
        print("✅ Model state loaded")

    # Load dataset statistics for denormalization (expected next to the ckpt)
    stats_path = ckpt_path.parent / 'dataset_stats.json'
    if stats_path.exists():
        with open(stats_path, 'r') as f:
            stats = json.load(f)
        # Convert lists to numpy arrays
        agent.action_mean = np.array(stats['action_mean'])
        agent.action_std = np.array(stats['action_std'])
        agent.qpos_mean = np.array(stats['qpos_mean'])
        agent.qpos_std = np.array(stats['qpos_std'])
        print("✅ Dataset statistics loaded for denormalization")
    else:
        print(f"⚠️ Warning: {stats_path} not found. Actions will not be denormalized!")
        agent.action_mean = None
        agent.action_std = None

    agent.eval()
    agent.to(device)

    print(f"✅ Model loaded successfully on {device}")
    return agent
|
||||
|
||||
|
||||
def evaluate_policy(
    agent: torch.nn.Module,
    num_episodes: int = 3,
    max_timesteps: int = 700,
    task_name: str = 'sim_transfer',
    device: str = 'cuda',
    camera_names: List[str] = None,
    num_queries: int = 1,
    obs_horizon: int = 2,
    save_video: bool = True
):
    """
    Evaluate VLA policy in simulation.

    Args:
        agent: Trained VLAAgent
        num_episodes: Number of episodes to run
        max_timesteps: Maximum timesteps per episode
        task_name: Task name for environment creation
        device: Device for inference
        camera_names: List of camera names
            (defaults to ['r_vis', 'top', 'front'])
        num_queries: Policy query frequency
        obs_horizon: Observation horizon
        save_video: Whether to save video
    """
    # Fix: the default used to be a shared mutable list literal.
    if camera_names is None:
        camera_names = ['r_vis', 'top', 'front']

    # Create evaluator
    evaluator = VLAEvaluator(
        agent=agent,
        device=device,
        camera_names=camera_names,
        num_queries=num_queries,
        obs_horizon=obs_horizon,
        use_smoothing=False,
        smooth_method='ema',
        smooth_alpha=0.3
    )

    # Create environment
    env = make_sim_env(task_name)

    # Run episodes
    for episode_idx in range(num_episodes):
        print(f"\n{'='*60}")
        print(f"Episode {episode_idx + 1}/{num_episodes}")
        print(f"{'='*60}\n")

        # Reset environment (randomized object pose) and evaluator state
        box_pos = sample_transfer_pose()
        env.reset(box_pos)
        evaluator.reset()

        # Storage for visualization
        episode_images = []
        success = False
        success_timestep = 0
        steps_taken = 0  # true episode length, independent of save_video

        with torch.inference_mode():
            for t in tqdm(range(max_timesteps), desc=f"Episode {episode_idx + 1}"):
                # Get observation
                obs = env._get_image_obs()
                qpos_obs = env._get_qpos_obs()

                # Merge observations
                obs['qpos'] = qpos_obs['qpos']

                # Predict action
                action = evaluator.predict_action(obs)

                # Execute action
                env.step_jnt(action)
                steps_taken = t + 1

                # Save images for video
                if save_video:
                    episode_images.append(obs['images'])

                # Render
                env.render()

                # Check if episode is done (reward 1.0 marks success)
                if env.rew == 1.0:
                    success = True
                    success_timestep = t
                    print(f"\n✅ Task completed at timestep {t}!")
                    break

        # Episode summary
        print(f"\nEpisode {episode_idx + 1} Summary:")
        print(f" Success: {success}")
        if success:
            print(f" Success Timestep: {success_timestep}")
        # Fix: previously reported len(episode_images), which is 0 whenever
        # save_video is disabled; report the actual episode length instead.
        print(f" Length: {steps_taken} timesteps")

        # Save video
        if save_video and episode_images:
            save_video_episode(
                episode_images,
                save_path=f"outputs/eval_vla_episode_{episode_idx}.mp4"
            )
            print(f" Video saved: outputs/eval_vla_episode_{episode_idx}.mp4")

    print(f"\n{'='*60}")
    print("Evaluation complete!")
    print(f"{'='*60}\n")
|
||||
|
||||
|
||||
def save_video_episode(images: List[Dict], save_path: str, fps: int = 20):
    """
    Save an episode as an mp4 video using the first camera's frames.

    Args:
        images: List of per-timestep dicts mapping camera name to an
            (H, W, C) RGB frame (as collected from ``obs['images']``)
        save_path: Path to save video
        fps: Frames per second

    Notes:
        Skips saving (with a console warning) when opencv-python is not
        installed.
    """
    try:
        import cv2
        # Fix: dropped the redundant function-local `from tqdm import tqdm`;
        # tqdm is already imported at module level.

        # Ensure the output directory exists
        Path(save_path).parent.mkdir(parents=True, exist_ok=True)

        # Use the first camera (e.g., 'r_vis') for visualization
        cam_name = next(iter(images[0]))

        # Frame geometry drives the video writer
        H, W, C = images[0][cam_name].shape

        # Create video writer
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        video_writer = cv2.VideoWriter(save_path, fourcc, fps, (W, H))

        # Write frames (convert RGB -> BGR for OpenCV)
        for img_dict in tqdm(images, desc="Saving video"):
            frame = img_dict[cam_name]
            frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
            video_writer.write(frame_bgr)

        video_writer.release()
        print(f"Video saved to {save_path}")

    except ImportError:
        print("Warning: opencv-python not installed, skipping video save")
        print("Install with: pip install opencv-python")
|
||||
|
||||
|
||||
def main():
    """CLI entry point: parse arguments, load the checkpoint, run evaluation."""
    cli = argparse.ArgumentParser(description='Evaluate VLA Policy')
    cli.add_argument('--ckpt_path', type=str, required=True,
                     help='Path to model checkpoint')
    cli.add_argument('--num_episodes', type=int, default=3,
                     help='Number of evaluation episodes')
    cli.add_argument('--max_timesteps', type=int, default=700,
                     help='Maximum timesteps per episode')
    cli.add_argument('--device', type=str, default='cuda',
                     help='Device for inference')
    cli.add_argument('--camera_names', nargs='+', default=['r_vis', 'top', 'front'],
                     help='Camera names to use')
    cli.add_argument('--num_queries', type=int, default=16,
                     help='Policy query frequency (timesteps)')
    cli.add_argument('--obs_horizon', type=int, default=2,
                     help='Observation horizon')
    cli.add_argument('--no_video', action='store_true',
                     help='Do not save episode videos')
    opts = cli.parse_args()

    # Load the trained agent, then hand everything to the evaluation loop.
    print(f"Loading model from {opts.ckpt_path}...")
    agent = load_checkpoint(opts.ckpt_path, device=opts.device)

    evaluate_policy(
        agent=agent,
        num_episodes=opts.num_episodes,
        max_timesteps=opts.max_timesteps,
        device=opts.device,
        camera_names=opts.camera_names,
        num_queries=opts.num_queries,
        obs_horizon=opts.obs_horizon,
        save_video=not opts.no_video
    )


if __name__ == '__main__':
    main()
|
||||
328
roboimi/demos/vla_scripts/eval_vla.py
Normal file
328
roboimi/demos/vla_scripts/eval_vla.py
Normal file
@@ -0,0 +1,328 @@
|
||||
"""
|
||||
VLA Policy Evaluation Script (Hydra-based)
|
||||
|
||||
This script evaluates a trained Vision-Language-Action (VLA) policy
|
||||
in the MuJoCo simulation environment.
|
||||
|
||||
Usage:
|
||||
    python roboimi/demos/vla_scripts/eval_vla.py
    python roboimi/demos/vla_scripts/eval_vla.py eval.ckpt_path=checkpoints/vla_model_step_8000.pt eval.num_episodes=5
    python roboimi/demos/vla_scripts/eval_vla.py eval.use_smoothing=true eval.smooth_alpha=0.5
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import json
|
||||
import logging
|
||||
import torch
|
||||
import numpy as np
|
||||
import hydra
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
from tqdm import tqdm
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
from hydra.utils import instantiate
|
||||
|
||||
from roboimi.envs.double_pos_ctrl_env import make_sim_env
|
||||
from roboimi.utils.act_ex_utils import sample_transfer_pose
|
||||
from einops import rearrange
|
||||
|
||||
# Ensure correct import path
|
||||
sys.path.append(os.getcwd())
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class VLAEvaluator:
    """
    VLA Policy Evaluator for MuJoCo Simulation.

    Keeps a rolling buffer of the last ``obs_horizon`` observations, queries
    the policy every ``num_queries`` timesteps, and replays the cached action
    chunk in between (optionally smoothed).
    """

    def __init__(
        self,
        agent: torch.nn.Module,
        device: str = 'cuda',
        camera_names: List[str] = ['r_vis', 'top', 'front'],
        num_queries: int = 1,
        obs_horizon: int = 2,
        pred_horizon: int = 16,
        use_smoothing: bool = False,
        smooth_method: str = 'ema',
        smooth_alpha: float = 0.3
    ):
        """
        Args:
            agent: Trained agent exposing ``predict_action(images, proprioception)``
            device: Device for inference
            camera_names: Camera names whose images are fed to the policy
            num_queries: How often to query the policy (in timesteps); cached
                actions are replayed in between
            obs_horizon: Number of past observations used as context
            pred_horizon: Number of future actions the policy predicts
            use_smoothing: Whether to apply action smoothing
            smooth_method: Smoothing method ('ema', 'moving_avg', 'lowpass')
            smooth_alpha: Smoothing coefficient
        """
        self.agent = agent.to(device)
        self.device = device
        self.camera_names = camera_names
        self.num_queries = num_queries
        self.obs_horizon = obs_horizon
        self.pred_horizon = pred_horizon

        # Action smoothing
        self.use_smoothing = use_smoothing
        self.smooth_method = smooth_method
        self.smooth_alpha = smooth_alpha
        # NOTE(review): action_dim is hard-coded to 16 — confirm it matches
        # the agent's actual action dimension.
        self.smoother = ActionSmoother(
            action_dim=16,
            method=smooth_method,
            alpha=smooth_alpha
        ) if use_smoothing else None

        # Rolling buffer of the most recent obs_horizon frames per camera,
        # plus the matching qpos vectors.
        self.obs_buffer = {
            'images': {cam: [] for cam in camera_names},
            'qpos': []
        }
        self.cached_actions = None  # (pred_horizon, action_dim) once queried
        self.query_step = 0         # index of the next cached action to replay

    def reset(self):
        """Reset evaluator state"""
        self.obs_buffer = {
            'images': {cam: [] for cam in self.camera_names},
            'qpos': []
        }
        self.cached_actions = None
        self.query_step = 0
        if self.smoother is not None:
            self.smoother.reset()

    def _get_image_dict(self, obs: Dict) -> Dict[str, torch.Tensor]:
        """
        Preprocess camera images and update the image history buffer.

        Args:
            obs: Environment observation dict; ``obs['images'][cam]`` is an
                (H, W, C) numpy array (uint8 — assumed from the /255 scaling).

        Returns:
            Dict mapping camera names to tensors (1, obs_horizon, C, H, W)
        """
        images = {}
        for cam_name in self.camera_names:
            img = obs['images'][cam_name]
            # (H, W, C) -> (C, H, W), scaled to [0, 1]
            img = rearrange(img, 'h w c -> c h w')
            img = torch.from_numpy(img / 255.0).float()
            images[cam_name] = img

        image_dict = {}
        for cam_name in self.camera_names:
            cam_images = self.obs_buffer['images'][cam_name]
            cam_images.append(images[cam_name])

            # Pad to obs_horizon by duplicating the oldest frame
            while len(cam_images) < self.obs_horizon:
                cam_images.insert(0, cam_images[0])

            # Keep only the most recent obs_horizon frames
            if len(cam_images) > self.obs_horizon:
                cam_images = cam_images[-self.obs_horizon:]

            # Stack: (obs_horizon, C, H, W) -> (1, obs_horizon, C, H, W)
            img_tensor = torch.stack(cam_images, dim=0).unsqueeze(0)
            image_dict[cam_name] = img_tensor

            # Persist the trimmed window for the next step
            self.obs_buffer['images'][cam_name] = cam_images[-self.obs_horizon:]

        return image_dict

    def _get_qpos_dict(self, obs: Dict) -> torch.Tensor:
        """
        Preprocess qpos and update the proprioception history buffer.

        Args:
            obs: Environment observation dict with a numpy ``obs['qpos']``.

        Returns:
            qpos tensor of shape (1, obs_horizon, obs_dim)
        """
        qpos = obs['qpos']
        qpos = torch.from_numpy(qpos).float()

        self.obs_buffer['qpos'].append(qpos)

        # Pad to obs_horizon by duplicating the oldest entry
        while len(self.obs_buffer['qpos']) < self.obs_horizon:
            self.obs_buffer['qpos'].insert(0, self.obs_buffer['qpos'][0])

        # Keep only the most recent obs_horizon entries
        if len(self.obs_buffer['qpos']) > self.obs_horizon:
            self.obs_buffer['qpos'] = self.obs_buffer['qpos'][-self.obs_horizon:]

        # Stack: (obs_horizon, obs_dim) -> (1, obs_horizon, obs_dim)
        qpos_tensor = torch.stack(self.obs_buffer['qpos'], dim=0).unsqueeze(0)
        return qpos_tensor

    @torch.no_grad()
    def predict_action(self, obs: Dict) -> np.ndarray:
        """
        Return the next action for the current observation.

        Queries the policy when the cache is empty or num_queries steps have
        elapsed; otherwise replays the next cached action.

        Args:
            obs: Current environment observation

        Returns:
            action: numpy array of shape (action_dim,)
        """
        images = self._get_image_dict(obs)  # Dict[str, (1, obs_horizon, C, H, W)]
        qpos = self._get_qpos_dict(obs)     # (1, obs_horizon, obs_dim)

        # NOTE(review): if num_queries > pred_horizon this indexes past the
        # cached chunk (IndexError) — confirm configs keep num_queries <= pred_horizon.
        if self.cached_actions is None or self.query_step % self.num_queries == 0:
            images = {k: v.to(self.device) for k, v in images.items()}
            qpos = qpos.to(self.device)

            # Returns: (B, pred_horizon, action_dim)
            predicted_actions = self.agent.predict_action(
                images=images,
                proprioception=qpos
            )

            # Cache as a (pred_horizon, action_dim) numpy array
            self.cached_actions = predicted_actions.squeeze(0).cpu().numpy()
            self.query_step = 0

        raw_action = self.cached_actions[self.query_step]
        self.query_step += 1

        if self.smoother is not None:
            raw_action = self.smoother.smooth(raw_action)

        return raw_action
|
||||
|
||||
|
||||
class ActionSmoother:
    """Action smoothing for smoother execution"""

    def __init__(self, action_dim: int, method: str = 'ema', alpha: float = 0.3):
        # Only 'ema' is implemented; any other method name makes smooth() a
        # pass-through (see smooth below).
        self.action_dim = action_dim
        self.method = method
        self.alpha = alpha
        # Previous smoothed action; None until the first smooth() call
        self.prev_action = None

    def smooth(self, action: np.ndarray) -> np.ndarray:
        """Blend the incoming action with the previous smoothed one (EMA)."""
        if self.method == 'ema':
            if self.prev_action is None:
                smoothed = action
            else:
                smoothed = self.alpha * action + (1 - self.alpha) * self.prev_action
            self.prev_action = smoothed
            return smoothed
        else:
            return action

    def reset(self):
        """Clear smoothing history (call at episode start)."""
        self.prev_action = None
|
||||
|
||||
|
||||
def load_checkpoint(
    ckpt_path: str,
    agent_cfg: DictConfig,
    device: str = 'cuda'
) -> torch.nn.Module:
    """
    Load trained VLA model from checkpoint using Hydra agent config.

    Args:
        ckpt_path: Path to checkpoint file (.pt)
        agent_cfg: Hydra agent config for instantiation
        device: Device to load model on

    Returns:
        Loaded VLAAgent model

    Raises:
        FileNotFoundError: If the checkpoint file does not exist.
    """
    from pathlib import Path as PathLib

    ckpt_path = PathLib(ckpt_path).absolute()
    if not ckpt_path.exists():
        raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")

    log.info(f"Loading checkpoint from {ckpt_path}")
    # weights_only=False: checkpoint may contain pickled objects beyond raw
    # tensors — only load trusted checkpoint files.
    checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False)
    log.info(f"Checkpoint keys: {checkpoint.keys()}")

    # Instantiate agent from Hydra config
    log.info("Instantiating agent from config...")
    agent = instantiate(agent_cfg)

    # Load model state — support several checkpoint layouts
    if 'model_state_dict' in checkpoint:
        agent.load_state_dict(checkpoint['model_state_dict'])
        log.info(f"✅ Model state loaded (step: {checkpoint.get('step', 'unknown')})")
    elif 'state_dict' in checkpoint:
        agent.load_state_dict(checkpoint['state_dict'])
        log.info("✅ Model state loaded")
    else:
        # Assume the checkpoint is the state_dict itself
        agent.load_state_dict(checkpoint)
        log.info("✅ Model state loaded")

    # Load dataset statistics for denormalization (expected next to the ckpt)
    stats_path = ckpt_path.parent / 'dataset_stats.json'
    if stats_path.exists():
        with open(stats_path, 'r') as f:
            stats = json.load(f)
        # Convert lists to numpy arrays
        agent.action_mean = np.array(stats['action_mean'])
        agent.action_std = np.array(stats['action_std'])
        agent.qpos_mean = np.array(stats['qpos_mean'])
        agent.qpos_std = np.array(stats['qpos_std'])
        log.info("✅ Dataset statistics loaded for denormalization")
    else:
        log.warning(f"⚠️ {stats_path} not found. Actions will not be denormalized!")
        agent.action_mean = None
        agent.action_std = None

    agent.eval()
    agent.to(device)

    log.info(f"✅ Model loaded successfully on {device}")
    return agent
|
||||
|
||||
|
||||
@hydra.main(version_base=None, config_path="../../vla/conf", config_name="config")
def main(cfg: DictConfig):
    """
    VLA Evaluation Script with Hydra Configuration.

    All eval parameters come from vla/conf/eval.yaml, merged into cfg.
    Override on command line: python eval_vla.py eval.ckpt_path=... eval.num_episodes=5
    """

    # Print configuration
    print("=" * 80)
    print("VLA Evaluation Configuration:")
    print("=" * 80)
    print(OmegaConf.to_yaml(cfg))
    print("=" * 80)

    eval_cfg = cfg.eval
    device = eval_cfg.device
    # Convert OmegaConf ListConfig to a plain Python list
    camera_names = list(eval_cfg.camera_names)

    # Load model
    log.info(f"🚀 Loading model from {eval_cfg.ckpt_path}...")
    agent = load_checkpoint(
        ckpt_path=eval_cfg.ckpt_path,
        agent_cfg=cfg.agent,
        device=device
    )

    # Create evaluator
    evaluator = VLAEvaluator(
        agent=agent,
        device=device,
        camera_names=camera_names,
        num_queries=eval_cfg.num_queries,
        obs_horizon=eval_cfg.obs_horizon,
        use_smoothing=eval_cfg.use_smoothing,
        smooth_method=eval_cfg.smooth_method,
        smooth_alpha=eval_cfg.smooth_alpha
    )

    # Create environment
    env = make_sim_env(eval_cfg.task_name)

    # Run episodes
    for episode_idx in range(eval_cfg.num_episodes):
        print(f"\n{'='*60}")
        print(f"Episode {episode_idx + 1}/{eval_cfg.num_episodes}")
        print(f"{'='*60}\n")

        # Randomize the object pose, then reset env and evaluator state
        box_pos = sample_transfer_pose()
        env.reset(box_pos)
        evaluator.reset()

        success = False
        success_timestep = 0

        with torch.inference_mode():
            for t in tqdm(range(eval_cfg.max_timesteps), desc=f"Episode {episode_idx + 1}"):
                # Gather and merge image + proprioception observations
                obs = env._get_image_obs()
                qpos_obs = env._get_qpos_obs()
                obs['qpos'] = qpos_obs['qpos']

                # Predict and execute one action
                action = evaluator.predict_action(obs)
                env.step_jnt(action)

                env.render()

                # A reward of exactly 1.0 marks task success
                if env.rew == 1.0:
                    success = True
                    success_timestep = t
                    print(f"\n✅ Task completed at timestep {t}!")
                    break

        # Episode summary
        # NOTE(review): `t` is referenced after the loop — assumes
        # max_timesteps >= 1, otherwise this raises NameError. Confirm.
        print(f"\nEpisode {episode_idx + 1} Summary:")
        print(f" Success: {success}")
        if success:
            print(f" Success Timestep: {success_timestep}")
        print(f" Length: {t + 1} timesteps")

    print(f"\n{'='*60}")
    print("Evaluation complete!")
    print(f"{'='*60}\n")


if __name__ == '__main__':
    main()
|
||||
@@ -1,238 +0,0 @@
|
||||
# ResNet VLA Training Guide
|
||||
|
||||
This guide explains how to train the VLA agent with ResNet backbone and action_dim=16, obs_dim=16.
|
||||
|
||||
## Configuration Overview
|
||||
|
||||
### 1. Backbone Configuration
|
||||
**File**: `roboimi/vla/conf/backbone/resnet.yaml`
|
||||
- Model: microsoft/resnet-18
|
||||
- Output dim: 1024 (512 channels × 2 from SpatialSoftmax)
|
||||
- Frozen by default for faster training
|
||||
|
||||
### 2. Agent Configuration
|
||||
**File**: `roboimi/vla/conf/agent/resnet_diffusion.yaml`
|
||||
- Vision backbone: ResNet-18 with SpatialSoftmax
|
||||
- Action dimension: 16
|
||||
- Observation dimension: 16
|
||||
- Prediction horizon: 16 steps
|
||||
- Observation horizon: 2 steps
|
||||
- Diffusion steps: 100
|
||||
- Number of cameras: 2
|
||||
|
||||
### 3. Dataset Configuration
|
||||
**File**: `roboimi/vla/conf/data/resnet_dataset.yaml`
|
||||
- Dataset class: RobotDiffusionDataset
|
||||
- Prediction horizon: 16
|
||||
- Observation horizon: 2
|
||||
- Camera names: [r_vis, top]
|
||||
- Normalization: gaussian (mean/std)
|
||||
|
||||
### 4. Training Configuration
|
||||
**File**: `roboimi/vla/conf/config.yaml`
|
||||
- Batch size: 8
|
||||
- Learning rate: 1e-4
|
||||
- Max steps: 10000
|
||||
- Log frequency: 100 steps
|
||||
- Save frequency: 1000 steps
|
||||
- Device: cuda
|
||||
- Num workers: 4
|
||||
|
||||
## Prerequisites
|
||||
|
||||
### 1. Prepare Dataset
|
||||
Your dataset should be organized as:
|
||||
```
|
||||
/path/to/your/dataset/
|
||||
├── episode_0.hdf5
|
||||
├── episode_1.hdf5
|
||||
├── ...
|
||||
└── data_stats.pkl
|
||||
```
|
||||
|
||||
Each HDF5 file should contain:
|
||||
```
|
||||
episode_N.hdf5
|
||||
├── action # (T, 16) float32
|
||||
└── observations/
|
||||
├── qpos # (T, 16) float32
|
||||
└── images/
|
||||
├── r_vis/ # (T, H, W, 3) uint8
|
||||
└── top/ # (T, H, W, 3) uint8
|
||||
```
|
||||
|
||||
### 2. Generate Dataset Statistics
|
||||
Create `data_stats.pkl` with:
|
||||
```python
|
||||
import pickle
|
||||
import numpy as np
|
||||
|
||||
stats = {
|
||||
'action': {
|
||||
'mean': np.zeros(16),
|
||||
'std': np.ones(16)
|
||||
},
|
||||
'qpos': {
|
||||
'mean': np.zeros(16),
|
||||
'std': np.ones(16)
|
||||
}
|
||||
}
|
||||
|
||||
with open('/path/to/your/dataset/data_stats.pkl', 'wb') as f:
|
||||
pickle.dump(stats, f)
|
||||
```
|
||||
|
||||
Or use the provided script:
|
||||
```bash
|
||||
python -m roboimi.vla.scripts.calculate_stats --dataset_dir /path/to/your/dataset
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
### 1. Update Dataset Path
|
||||
Edit `roboimi/vla/conf/data/resnet_dataset.yaml`:
|
||||
```yaml
|
||||
dataset_dir: "/path/to/your/dataset" # CHANGE THIS
|
||||
camera_names:
|
||||
- r_vis # CHANGE TO YOUR CAMERA NAMES
|
||||
- top
|
||||
```
|
||||
|
||||
### 2. Run Training
|
||||
```bash
|
||||
# Basic training
|
||||
python roboimi/demos/vla_scripts/train_vla.py
|
||||
|
||||
# Override configurations
|
||||
python roboimi/demos/vla_scripts/train_vla.py train.batch_size=16
|
||||
python roboimi/demos/vla_scripts/train_vla.py train.device=cpu
|
||||
python roboimi/demos/vla_scripts/train_vla.py train.max_steps=20000
|
||||
python roboimi/demos/vla_scripts/train_vla.py data.dataset_dir=/custom/path
|
||||
|
||||
# Debug mode (CPU, small batch, few steps)
|
||||
python roboimi/demos/vla_scripts/train_vla.py \
|
||||
train.device=cpu \
|
||||
train.batch_size=2 \
|
||||
train.max_steps=10 \
|
||||
train.num_workers=0
|
||||
```
|
||||
|
||||
### 3. Monitor Training
|
||||
Checkpoints are saved to:
|
||||
- `checkpoints/vla_model_step_1000.pt` - Periodic checkpoints
|
||||
- `checkpoints/vla_model_best.pt` - Best model (lowest loss)
|
||||
- `checkpoints/vla_model_final.pt` - Final model
|
||||
|
||||
## Architecture Details
|
||||
|
||||
### Data Flow
|
||||
1. **Input**: Images from multiple cameras + proprioception (qpos)
|
||||
2. **Vision Encoder**: ResNet-18 → SpatialSoftmax → (B, T, 1024) per camera
|
||||
3. **Feature Concatenation**: All cameras + qpos → Global conditioning
|
||||
4. **Diffusion Policy**: 1D U-Net predicts noise on action sequences
|
||||
5. **Output**: Clean action sequence (B, 16, 16)
|
||||
|
||||
### Training Process
|
||||
1. Sample a random timestep t uniformly from [0, 100) (one of the 100 diffusion steps)
|
||||
2. Add noise to ground truth actions
|
||||
3. Predict noise using vision + proprioception conditioning
|
||||
4. Compute MSE loss between predicted and actual noise
|
||||
5. Backpropagate and update weights
|
||||
|
||||
### Inference Process
|
||||
1. Extract visual features from current observation
|
||||
2. Start with random noise action sequence
|
||||
3. Iteratively denoise over 10 steps (DDPM scheduler)
|
||||
4. Return clean action sequence
|
||||
|
||||
## Common Issues
|
||||
|
||||
### Issue: Out of Memory
|
||||
**Solution**: Reduce batch size or use CPU
|
||||
```bash
|
||||
python train_vla.py train.batch_size=4 train.device=cpu
|
||||
```
|
||||
|
||||
### Issue: Dataset not found
|
||||
**Solution**: Check dataset_dir path in config
|
||||
```bash
|
||||
python train_vla.py data.dataset_dir=/absolute/path/to/dataset
|
||||
```
|
||||
|
||||
### Issue: Camera names mismatch
|
||||
**Solution**: Update camera_names in data config
|
||||
```yaml
|
||||
# roboimi/vla/conf/data/resnet_dataset.yaml
|
||||
camera_names:
|
||||
- your_camera_1
|
||||
- your_camera_2
|
||||
```
|
||||
|
||||
### Issue: data_stats.pkl missing
|
||||
**Solution**: Generate statistics file
|
||||
```bash
|
||||
python -m roboimi.vla.scripts.calculate_stats --dataset_dir /path/to/dataset
|
||||
```
|
||||
|
||||
## Model Files Created
|
||||
|
||||
```
|
||||
roboimi/vla/
|
||||
├── conf/
|
||||
│ ├── config.yaml (UPDATED)
|
||||
│ ├── backbone/
|
||||
│ │ └── resnet.yaml (NEW)
|
||||
│ ├── agent/
|
||||
│ │ └── resnet_diffusion.yaml (NEW)
|
||||
│ └── data/
|
||||
│ └── resnet_dataset.yaml (NEW)
|
||||
├── models/
|
||||
│ └── backbones/
|
||||
│ ├── __init__.py (UPDATED - added resnet export)
|
||||
│ └── resnet.py (EXISTING)
|
||||
└── demos/vla_scripts/
|
||||
└── train_vla.py (REWRITTEN)
|
||||
```
|
||||
|
||||
## Next Steps
|
||||
|
||||
1. **Prepare your dataset** in the required HDF5 format
|
||||
2. **Update dataset_dir** in `roboimi/vla/conf/data/resnet_dataset.yaml`
|
||||
3. **Run training** with `python roboimi/demos/vla_scripts/train_vla.py`
|
||||
4. **Monitor checkpoints** in `checkpoints/` directory
|
||||
5. **Evaluate** the trained model using the best checkpoint
|
||||
|
||||
## Advanced Configuration
|
||||
|
||||
### Use Different ResNet Variant
|
||||
Edit `roboimi/vla/conf/agent/resnet_diffusion.yaml`:
|
||||
```yaml
|
||||
vision_backbone:
|
||||
model_name: "microsoft/resnet-50" # or resnet-34, resnet-101
|
||||
```
|
||||
|
||||
### Adjust Diffusion Steps
|
||||
```yaml
|
||||
# More steps = better quality, slower training
|
||||
diffusion_steps: 200 # default: 100
|
||||
```
|
||||
|
||||
### Change Horizons
|
||||
```yaml
|
||||
pred_horizon: 32 # Predict more future steps
|
||||
obs_horizon: 4 # Use more history
|
||||
```
|
||||
|
||||
### Multi-GPU Training
|
||||
```bash
|
||||
# Use CUDA device 1
|
||||
python train_vla.py train.device=cuda:1
|
||||
|
||||
# For multi-GPU, use torch.distributed (requires code modification)
|
||||
```
|
||||
|
||||
## References
|
||||
|
||||
- ResNet Paper: https://arxiv.org/abs/1512.03385
|
||||
- Diffusion Policy: https://diffusion-policy.cs.columbia.edu/
|
||||
- VLA Framework Documentation: See CLAUDE.md in project root
|
||||
@@ -1,239 +0,0 @@
|
||||
# VLA Evaluation Guide
|
||||
|
||||
This guide explains how to evaluate a trained Vision-Language-Action (VLA) policy in the MuJoCo simulation environment.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
1. **Trained Model**: Train your VLA model first using `train_vla.py`
|
||||
2. **Checkpoints**: Ensure you have saved model checkpoints in `checkpoints/` directory
|
||||
3. **Dependencies**: Install required dependencies:
|
||||
```bash
|
||||
pip install opencv-python tqdm
|
||||
```
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Basic Evaluation
|
||||
|
||||
```bash
|
||||
# Evaluate with default settings (3 episodes)
|
||||
python roboimi/demos/eval_vla.py \
|
||||
--ckpt_path checkpoints/vla_model_best.pt
|
||||
|
||||
# Evaluate with custom settings
|
||||
python roboimi/demos/eval_vla.py \
|
||||
--ckpt_path checkpoints/vla_model_step_5000.pt \
|
||||
--num_episodes 5 \
|
||||
--max_timesteps 700 \
|
||||
--camera_names r_vis top angle \
|
||||
--num_queries 1 \
|
||||
--obs_horizon 2
|
||||
```
|
||||
|
||||
### Parameters
|
||||
|
||||
| Parameter | Description | Default |
|
||||
|-----------|-------------|---------|
|
||||
| `--ckpt_path` | Path to model checkpoint (.pt file) | Required |
|
||||
| `--num_episodes` | Number of evaluation episodes | 3 |
|
||||
| `--max_timesteps` | Maximum timesteps per episode | 700 |
|
||||
| `--device` | Device for inference (`cuda` or `cpu`) | `cuda` |
|
||||
| `--camera_names` | Camera names to use (space-separated) | `r_vis top` |
|
||||
| `--num_queries` | Policy query frequency (every N timesteps) | 1 |
|
||||
| `--obs_horizon` | Observation history length | 2 |
|
||||
| `--no_video` | Disable video saving | False |
|
||||
|
||||
## Usage Details
|
||||
|
||||
### Policy Query Frequency
|
||||
|
||||
The `--num_queries` parameter controls how often the policy is queried:
|
||||
|
||||
- `--num_queries 1`: Query every timestep (default, most accurate)
|
||||
- `--num_queries 4`: Query every 4 timesteps (faster, but uses cached actions)
|
||||
|
||||
When using cached actions (num_queries > 1), the policy predicts a chunk of actions (pred_horizon=16), and these actions are executed sequentially until the next query.
|
||||
|
||||
### Camera Selection
|
||||
|
||||
Available cameras depend on your environment:
|
||||
- `r_vis`: Right arm RealSense camera
|
||||
- `top`: Top-down view camera
|
||||
- `angle`: Angled view camera
|
||||
|
||||
Use `--camera_names` to specify which cameras to use:
|
||||
```bash
|
||||
--camera_names r_vis top # Use 2 cameras
|
||||
--camera_names r_vis top angle # Use all 3 cameras
|
||||
```
|
||||
|
||||
### Observation Horizon
|
||||
|
||||
The `--obs_horizon` parameter determines how many past observations to use as context:
|
||||
|
||||
```bash
|
||||
--obs_horizon 1 # Use only current observation
|
||||
--obs_horizon 2 # Use current + 1 past observation (default)
|
||||
--obs_horizon 4 # Use current + 3 past observations
|
||||
```
|
||||
|
||||
**Note**: Must match the value used during training.
|
||||
|
||||
## Output
|
||||
|
||||
### Console Output
|
||||
|
||||
During evaluation, you'll see:
|
||||
|
||||
```
|
||||
============================================================
|
||||
Episode 1/3
|
||||
============================================================
|
||||
|
||||
Episode 1: 100%|████████████████████| 700/700 [02:30<00:00, 4.64it/s]
|
||||
|
||||
✅ Task completed at timestep 453!
|
||||
|
||||
Episode 1 Summary:
|
||||
Total Reward: 1.0000
|
||||
Max Reward: 1.0000
|
||||
Length: 453 timesteps
|
||||
Video saved: outputs/eval_vla_episode_0.mp4
|
||||
```
|
||||
|
||||
### Video Output
|
||||
|
||||
Videos are saved to `outputs/eval_vla_episode_{N}.mp4` showing the robot's execution.
|
||||
|
||||
### Metrics
|
||||
|
||||
- **Total Reward**: Sum of rewards throughout the episode
|
||||
- **Max Reward**: Maximum reward achieved (1.0 = success)
|
||||
- **Length**: Number of timesteps executed
|
||||
|
||||
## Action Smoothing
|
||||
|
||||
The evaluator supports EMA (Exponential Moving Average) smoothing to reduce jitter (disabled by default — enable it via `use_smoothing=True`):
|
||||
|
||||
```python
|
||||
# Default smoothing parameters
|
||||
smooth_method = 'ema'
|
||||
smooth_alpha = 0.3 # Lower = more smoothing
|
||||
```
|
||||
|
||||
To disable or modify smoothing, edit the `evaluate_policy()` call in `eval_vla.py`:
|
||||
|
||||
```python
|
||||
evaluator = VLAEvaluator(
|
||||
agent=agent,
|
||||
use_smoothing=False, # Disable smoothing
|
||||
# or
|
||||
smooth_method='moving_avg', # Use different method
|
||||
smooth_alpha=0.5 # Adjust smoothing strength
|
||||
)
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Issue: Checkpoint not found
|
||||
|
||||
```
|
||||
FileNotFoundError: Checkpoint not found: checkpoints/vla_model_best.pt
|
||||
```
|
||||
|
||||
**Solution**: Ensure you've trained the model and checkpoints exist:
|
||||
```bash
|
||||
ls -la checkpoints/
|
||||
# Should show: vla_model_best.pt, vla_model_final.pt, etc.
|
||||
```
|
||||
|
||||
### Issue: CUDA out of memory
|
||||
|
||||
**Solution**: Use CPU for inference:
|
||||
```bash
|
||||
python eval_vla.py --ckpt_path checkpoints/vla_model_best.pt --device cpu
|
||||
```
|
||||
|
||||
### Issue: Camera names don't match
|
||||
|
||||
**Solution**: Check your HDF5 files for available cameras:
|
||||
```python
|
||||
import h5py
|
||||
with h5py.File('roboimi/demos/dataset/sim_transfer/episode_0.hdf5', 'r') as f:
|
||||
print(list(f['observations/images'].keys()))
|
||||
# Output: ['angle', 'r_vis', 'top']
|
||||
```
|
||||
|
||||
Then use the correct camera names in your eval command.
|
||||
|
||||
### Issue: Mismatched obs_horizon
|
||||
|
||||
```
|
||||
RuntimeError: Tensor shape mismatch
|
||||
```
|
||||
|
||||
**Solution**: Ensure `--obs_horizon` matches the training config (`data.obs_horizon`).
|
||||
|
||||
## Advanced Usage
|
||||
|
||||
### Custom Evaluation Script
|
||||
|
||||
You can also use the evaluator in your own scripts:
|
||||
|
||||
```python
|
||||
from roboimi.demos.eval_vla import VLAEvaluator, load_checkpoint
|
||||
from roboimi.envs.double_pos_ctrl_env import make_sim_env
|
||||
|
||||
# Load model
|
||||
agent = load_checkpoint('checkpoints/vla_model_best.pt')
|
||||
|
||||
# Create evaluator
|
||||
evaluator = VLAEvaluator(
|
||||
agent=agent,
|
||||
device='cuda',
|
||||
camera_names=['r_vis', 'top'],
|
||||
num_queries=1,
|
||||
obs_horizon=2
|
||||
)
|
||||
|
||||
# Create environment
|
||||
env = make_sim_env('sim_transfer')
|
||||
env.reset()
|
||||
evaluator.reset()
|
||||
|
||||
# Run episode
|
||||
obs = env._get_image_obs()
|
||||
obs['qpos'] = env._get_qpos_obs()['qpos']
|
||||
|
||||
# Predict and execute action
|
||||
action = evaluator.predict_action(obs)
|
||||
env.step_jnt(action)
|
||||
```
|
||||
|
||||
### Batch Evaluation
|
||||
|
||||
Evaluate multiple checkpoints:
|
||||
|
||||
```bash
|
||||
for ckpt in checkpoints/vla_model_step_*.pt; do
|
||||
echo "Evaluating $ckpt"
|
||||
python roboimi/demos/eval_vla.py \
|
||||
--ckpt_path "$ckpt" \
|
||||
--num_episodes 1 \
|
||||
--no_video
|
||||
done
|
||||
```
|
||||
|
||||
## Next Steps
|
||||
|
||||
1. **Train your model**: See [RESNET_TRAINING_GUIDE.md](roboimi/vla/RESNET_TRAINING_GUIDE.md)
|
||||
2. **Evaluate performance**: Use this evaluation script
|
||||
3. **Analyze results**: Compare different checkpoints
|
||||
4. **Deploy to real robot**: Adapt the evaluator for real robot control
|
||||
|
||||
## References
|
||||
|
||||
- Training Guide: [roboimi/vla/RESNET_TRAINING_GUIDE.md](roboimi/vla/RESNET_TRAINING_GUIDE.md)
|
||||
- Project Documentation: [CLAUDE.md](CLAUDE.md)
|
||||
- Original ACT Paper: https://arxiv.org/abs/2304.13705
|
||||
- Diffusion Policy: https://diffusion-policy.cs.columbia.edu/
|
||||
@@ -1,25 +0,0 @@
|
||||
# @package agent
|
||||
_target_: roboimi.vla.agent.VLAAgent
|
||||
|
||||
# --- Real Vision Backbone ---
|
||||
backbone:
|
||||
_target_: roboimi.vla.models.backbones.siglip.SigLIPBackbone
|
||||
# Google SigLIP (SOTA Vision Encoder)
|
||||
# 第一次运行会自动下载 (~1.5GB)
|
||||
model_name: "google/siglip-so400m-patch14-384"
|
||||
freeze: true # 初始阶段冻结视觉层,只训练 Head
|
||||
embed_dim: 1152 # SigLIP so400m-patch14-384 的 hidden_size
|
||||
|
||||
# --- Adapter ---
|
||||
projector:
|
||||
_target_: roboimi.vla.models.projectors.mlp.MLPProjector
|
||||
# 自动读取 SigLIP 的 1152 维
|
||||
input_dim: ${..backbone.embed_dim}
|
||||
output_dim: 384 # 压缩到 384 或 512 给 Policy 用
|
||||
|
||||
# --- Policy Head ---
|
||||
head:
|
||||
_target_: roboimi.vla.models.heads.debug.DebugHead
|
||||
input_dim: ${..projector.output_dim}
|
||||
action_dim: 16
|
||||
chunk_size: 16
|
||||
@@ -1,24 +0,0 @@
|
||||
_target_: roboimi.vla.agent.VLAAgent
|
||||
|
||||
# 1. Backbone Configuration
|
||||
backbone:
|
||||
_target_: roboimi.vla.models.backbones.debug.DebugBackbone
|
||||
embed_dim: 768 # Variable A
|
||||
seq_len: 10
|
||||
|
||||
# 2. Projector Configuration
|
||||
projector:
|
||||
_target_: roboimi.vla.models.projectors.mlp.MLPProjector
|
||||
# Dependency Injection via Interpolation:
|
||||
# Takes 'embed_dim' from the sibling 'backbone' config above.
|
||||
input_dim: ${..backbone.embed_dim}
|
||||
output_dim: 512 # Variable B (The bottleneck size)
|
||||
|
||||
# 3. Head Configuration
|
||||
head:
|
||||
_target_: roboimi.vla.models.heads.debug.DebugHead
|
||||
# Dependency Injection via Interpolation:
|
||||
# Takes 'output_dim' from the sibling 'projector' config above.
|
||||
input_dim: ${..projector.output_dim}
|
||||
action_dim: 7 # (x,y,z, r,p,y, gripper)
|
||||
chunk_size: 16
|
||||
@@ -1,30 +0,0 @@
|
||||
# @package _global_
|
||||
defaults:
|
||||
# 1. 将 backbone 配置挂载到 agent.vlm_backbone 节点
|
||||
- /backbone@vlm_backbone: siglip
|
||||
|
||||
# 2. 将 projector 配置挂载到 agent.img_projector 节点 (新增)
|
||||
- /projector@img_projector: mlp
|
||||
|
||||
# 3. 将 head 配置挂载到 agent.action_head 节点
|
||||
- /head@action_head: diffusion
|
||||
|
||||
# 4. 允许当前文件覆盖上述配置
|
||||
- _self_
|
||||
|
||||
_target_: roboimi.vla.agent.VLAAgent
|
||||
|
||||
# 核心超参数:单一真值源
|
||||
state_dim: 14
|
||||
embed_dim: 512
|
||||
|
||||
# --- 参数一致性绑定 (Interpolation) ---
|
||||
|
||||
# 强制 Projector 输出维度 = Agent 嵌入维度
|
||||
img_projector:
|
||||
input_dim: ${..vlm_backbone.output_dim} # 自动获取 backbone 的输出维度
|
||||
output_dim: ${..embed_dim} # 引用上方的 embed_dim
|
||||
|
||||
# 强制 Head 输入维度 = Agent 嵌入维度
|
||||
action_head:
|
||||
input_dim: ${..embed_dim} # 引用上方的 embed_dim
|
||||
@@ -8,15 +8,18 @@ vision_backbone:
|
||||
freeze: true
|
||||
|
||||
# Action and Observation Dimensions
|
||||
action_dim: 16 # Robot action dimension
|
||||
obs_dim: 16 # Proprioception dimension (qpos)
|
||||
action_dim: 16
|
||||
obs_dim: 16
|
||||
|
||||
# Prediction Horizons
|
||||
pred_horizon: 16 # How many future actions to predict
|
||||
obs_horizon: 2 # How many historical observations to use
|
||||
# Prediction and Observation Horizons
|
||||
pred_horizon: 16
|
||||
obs_horizon: 2
|
||||
|
||||
# Diffusion Parameters
|
||||
diffusion_steps: 100 # Number of diffusion timesteps for training
|
||||
|
||||
# Camera Configuration
|
||||
num_cams: 3 # Number of cameras (e.g., r_vis, top)
|
||||
# num_cams 应与 data.camera_names 列表长度一致
|
||||
# 可使用 Hydra OmegaConf resolver: ${oc.len:data.camera_names}
|
||||
# 但部分版本不支持,这里手动保持同步
|
||||
num_cams: 3 # len(data.camera_names) = 3
|
||||
|
||||
@@ -1,24 +0,0 @@
|
||||
# @package agent
|
||||
_target_: roboimi.vla.agent.VLAAgent
|
||||
|
||||
# 1. Vision
|
||||
backbone:
|
||||
_target_: roboimi.vla.models.backbones.siglip.SigLIPBackbone
|
||||
model_name: "google/siglip-so400m-patch14-384"
|
||||
embed_dim: 1152
|
||||
freeze: true
|
||||
|
||||
# 2. Adapter
|
||||
projector:
|
||||
_target_: roboimi.vla.models.projectors.mlp.MLPProjector
|
||||
input_dim: ${..backbone.embed_dim}
|
||||
output_dim: 256 # 压缩给 Diffusion 用
|
||||
|
||||
# 3. Diffusion Policy Head
|
||||
head:
|
||||
_target_: roboimi.vla.models.heads.diffusion.DiffusionHead
|
||||
input_dim: ${..projector.output_dim}
|
||||
action_dim: 16
|
||||
chunk_size: 16
|
||||
n_timesteps: 50 # 训练用100,这里调试用50快一点
|
||||
hidden_dim: 256
|
||||
@@ -1,26 +0,0 @@
|
||||
# 调试用小模型
|
||||
# @package agent
|
||||
_target_: roboimi.vla.agent.VLAAgent
|
||||
|
||||
# --- 1. Backbone (VLM) ---
|
||||
backbone:
|
||||
_target_: roboimi.vla.models.backbones.debug.DebugBackbone
|
||||
embed_dim: 768 # 定义源头维度
|
||||
seq_len: 10
|
||||
|
||||
# --- 2. Projector (Adapter) ---
|
||||
projector:
|
||||
_target_: roboimi.vla.models.projectors.mlp.MLPProjector
|
||||
# 【关键】依赖注入:自动读取 backbone 的 embed_dim
|
||||
input_dim: ${..backbone.embed_dim}
|
||||
output_dim: 128 # 瓶颈层维度 (Tiny scale)
|
||||
|
||||
# --- 3. Head (Policy) ---
|
||||
head:
|
||||
_target_: roboimi.vla.models.heads.debug.DebugHead
|
||||
input_dim: ${..projector.output_dim}
|
||||
|
||||
# 【关键修改】改为 16 以匹配你的 Sim 数据
|
||||
action_dim: 16
|
||||
|
||||
chunk_size: 16
|
||||
@@ -1 +0,0 @@
|
||||
# CLIP Backbone 配置
|
||||
@@ -3,8 +3,3 @@ _target_: roboimi.vla.models.backbones.resnet.ResNetBackbone
|
||||
|
||||
model_name: "microsoft/resnet-18"
|
||||
freeze: true
|
||||
|
||||
# Output dimension calculation:
|
||||
# ResNet-18 final layer has 512 channels
|
||||
# After SpatialSoftmax: 512 * 2 = 1024 (x,y coordinates per channel)
|
||||
# output_dim: 1024
|
||||
|
||||
@@ -1,4 +0,0 @@
|
||||
_target_: roboimi.vla.models.backbones.SigLIPBackbone
|
||||
model_name: "google/siglip-so400m-patch14-384"
|
||||
frozen: true
|
||||
output_dim: 1152 # SigLIP Large 的特征维度,需显式声明供 Projector 引用
|
||||
@@ -1,7 +1,8 @@
|
||||
defaults:
|
||||
- _self_
|
||||
- agent: resnet_diffusion
|
||||
- data: resnet_dataset
|
||||
- eval: eval
|
||||
- _self_
|
||||
|
||||
train:
|
||||
batch_size: 16 # Batch size for training
|
||||
|
||||
@@ -1,16 +0,0 @@
|
||||
_target_: roboimi.vla.data.dataset.VLADataset
|
||||
dataset_dir: "/path/to/your/roboimi/demos/dataset/collected_data"
|
||||
pred_horizon: 16
|
||||
obs_horizon: 2
|
||||
|
||||
# 这里展示了 Hydra 的嵌套实例化:Transform 作为参数传入
|
||||
transform:
|
||||
_target_: roboimi.vla.data.image_transforms.VLAImageProcessor
|
||||
size: [224, 224]
|
||||
mean: [0.5, 0.5, 0.5] # SigLIP/CLIP 常用归一化
|
||||
std: [0.5, 0.5, 0.5]
|
||||
|
||||
# 如果需要 Tokenizer
|
||||
tokenizer: null
|
||||
# _target_: roboimi.vla.data.text_processing.SimpleTokenizer
|
||||
# max_length: 77
|
||||
@@ -4,9 +4,9 @@ _target_: roboimi.vla.data.dataset.RobotDiffusionDataset
|
||||
# Dataset Directory (CHANGE THIS TO YOUR DATA PATH)
|
||||
dataset_dir: "roboimi/demos/dataset/sim_transfer" # Path to your dataset directory
|
||||
|
||||
# Horizon Parameters
|
||||
pred_horizon: 16 # Prediction horizon (matches agent.pred_horizon)
|
||||
obs_horizon: 2 # Observation horizon (matches agent.obs_horizon)
|
||||
# Horizon Parameters — 使用 Hydra 插值,从 agent 配置中引用,保持一致性
|
||||
pred_horizon: ${agent.pred_horizon}
|
||||
obs_horizon: ${agent.obs_horizon}
|
||||
action_horizon: 8 # Action execution horizon (used during evaluation)
|
||||
|
||||
# Camera Names (CHANGE THIS TO MATCH YOUR CAMERAS)
|
||||
|
||||
@@ -1,8 +0,0 @@
|
||||
_target_: roboimi.vla.data.dataset.RobotDiffusionDataset
|
||||
|
||||
dataset_dir: "/home/d51/workspace/work/robo-imi-act/roboimi/demos/dataset/sim_transfer"
|
||||
pred_horizon: 16
|
||||
obs_horizon: 1
|
||||
action_horizon: 8
|
||||
camera_names: ['r_vis', 'top', 'front'] # ['angle', 'r_vis', 'top']
|
||||
normalization_type: 'gaussian' # 'min_max' or 'gaussian'
|
||||
21
roboimi/vla/conf/eval/eval.yaml
Normal file
21
roboimi/vla/conf/eval/eval.yaml
Normal file
@@ -0,0 +1,21 @@
|
||||
# @package eval
|
||||
# Evaluation Configuration
|
||||
ckpt_path: "checkpoints/vla_model_best.pt" # Path to model checkpoint
|
||||
num_episodes: 3 # Number of evaluation episodes
|
||||
max_timesteps: 700 # Maximum timesteps per episode
|
||||
device: ${train.device} # 与训练保持一致
|
||||
task_name: "sim_transfer" # Task name for environment creation
|
||||
|
||||
# Policy execution — 从 agent 配置中引用,保持一致性
|
||||
num_queries: ${agent.pred_horizon} # 每次预测 pred_horizon 步后重新查询
|
||||
obs_horizon: ${agent.obs_horizon}
|
||||
|
||||
# Camera names — 从 data 配置中引用,保持一致性
|
||||
camera_names: ${data.camera_names}
|
||||
|
||||
# Action smoothing
|
||||
use_smoothing: false
|
||||
smooth_method: "ema"
|
||||
smooth_alpha: 0.3
|
||||
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
# ACT-VAE Head 配置
|
||||
@@ -1,7 +1,7 @@
|
||||
_target_: roboimi.vla.models.heads.DiffusionActionHead
|
||||
|
||||
# 显式声明必填参数
|
||||
input_dim: ??? # 【修复】必须存在,等待 agent/default.yaml 填充
|
||||
input_dim: ??? # 等待 agent/default.yaml 填充
|
||||
action_dim: 7
|
||||
obs_horizon: 2
|
||||
pred_horizon: 16
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
# Debug 训练超参数
|
||||
@@ -1 +0,0 @@
|
||||
# GPU 训练超参数
|
||||
@@ -1,75 +0,0 @@
|
||||
# 图像预处理
|
||||
import torch
|
||||
import numpy as np
|
||||
import torchvision.transforms as T
|
||||
from PIL import Image
|
||||
from typing import Union, List
|
||||
|
||||
class VLAImageProcessor:
    """Image preprocessor for ViT-style vision backbones (SigLIP / CLIP).

    Pipeline:
        1. Convert the input (numpy HWC uint8, torch.Tensor, or PIL) to PIL.
        2. Optional color-jitter augmentation (intended for training only).
        3. Resize to (resolution, resolution) with bicubic interpolation.
        4. ToTensor ([0, 255] -> [0.0, 1.0]) then Normalize
           (SigLIP/CLIP convention: mean=0.5, std=0.5 per channel).

    Args:
        resolution: Target square edge length in pixels.
        mean: Per-channel normalization mean.
        std: Per-channel normalization std.
        enable_augmentation: Apply ColorJitter in __call__ when True.
        aug_strength: Jitter strength; 0.1–0.2 is a safe range for
            robot-learning data.
    """
    def __init__(
        self,
        resolution: int = 384,
        mean: List[float] = [0.5, 0.5, 0.5],
        std: List[float] = [0.5, 0.5, 0.5],
        enable_augmentation: bool = True,
        aug_strength: float = 0.1  # augmentation strength; 0.1~0.2 is safe
    ):
        self.resolution = resolution
        self.enable_augmentation = enable_augmentation

        # --- 1. Base transforms (shared by all modes) ---
        # Defined step-by-step because augmentation is usually cheaper on
        # PIL images than on tensors.
        self.resize = T.Resize((resolution, resolution), interpolation=T.InterpolationMode.BICUBIC, antialias=True)
        self.to_tensor = T.ToTensor()
        self.normalize = T.Normalize(mean=mean, std=std)

        # --- 2. Data augmentation (training only) ---
        # Robot learning usually avoids RandomCrop (it destroys absolute
        # spatial information), so only photometric jitter is applied.
        if enable_augmentation:
            self.aug = T.ColorJitter(
                brightness=aug_strength,
                contrast=aug_strength,
                saturation=aug_strength,
                hue=aug_strength / 2
            )
        else:
            self.aug = torch.nn.Identity()

    def __call__(self, img: Union[np.ndarray, Image.Image, torch.Tensor]) -> torch.Tensor:
        """
        Args:
            img: (H, W, C) uint8 numpy array (from HDF5), PIL Image,
                or (H, W, C) uint8 torch.Tensor.
        Returns:
            tensor: (C, H, W) float32, normalized.
        """
        # 1. Normalize the input type to PIL (needed for Resize / Jitter).
        if isinstance(img, np.ndarray):
            img = Image.fromarray(img)
        elif isinstance(img, torch.Tensor):
            # BUGFIX: the original silently passed Tensors through, which
            # later crashed in T.ToTensor (it only accepts PIL/ndarray)
            # despite the signature advertising Tensor support. Convert
            # explicitly, assuming the same (H, W, C) uint8 layout as the
            # HDF5 data — TODO(review): confirm against callers.
            img = Image.fromarray(img.cpu().numpy())

        # 2. Photometric augmentation (if enabled).
        if self.enable_augmentation:
            img = self.aug(img)

        # 3. Resize to the backbone's expected resolution.
        img = self.resize(img)

        # 4. To tensor & normalize.
        # ToTensor maps [0, 255] -> [0.0, 1.0].
        tensor = self.to_tensor(img)
        tensor = self.normalize(tensor)

        return tensor

    def __repr__(self):
        return f"VLAImageProcessor(res={self.resolution}, aug={self.enable_augmentation})"
|
||||
@@ -1 +0,0 @@
|
||||
# 文本 Tokenizer 包装
|
||||
@@ -1,10 +1,4 @@
|
||||
# Backbone models
|
||||
from .siglip import SigLIPBackbone
|
||||
from .resnet import ResNetBackbone
|
||||
# from .clip import CLIPBackbone
|
||||
# from .dinov2 import DinoV2Backbone
|
||||
|
||||
__all__ = ["SigLIPBackbone", "ResNetBackbone"]
|
||||
|
||||
# from .debug import DebugBackbone
|
||||
# __all__ = ["DebugBackbone"]
|
||||
__all__ = ["ResNetBackbone"]
|
||||
|
||||
@@ -1,62 +0,0 @@
|
||||
# SigLIP Backbone 实现
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import AutoModel, AutoProcessor, SiglipVisionModel
|
||||
from typing import Dict, Optional
|
||||
from roboimi.vla.core.interfaces import VLABackbone
|
||||
|
||||
class SigLIPBackbone(VLABackbone):
    """Vision backbone wrapping Google's SigLIP vision encoder.

    Only the vision tower is loaded (no text tower): SigLIP features are
    already language-aligned, so the image encoder alone is enough for
    feature extraction.
    Example HuggingFace model id: "google/siglip-so400m-patch14-384".
    """

    def __init__(
        self,
        model_name: str = "google/siglip-so400m-patch14-384",
        freeze: bool = True,
        embed_dim: Optional[int] = None
    ):
        super().__init__()
        print(f"Loading SigLIP: {model_name} ...")

        # Load the vision tower only; the text tower is never used here.
        self.vision_model = SiglipVisionModel.from_pretrained(model_name)

        # Prefer an explicitly configured embed_dim; otherwise read it from
        # the pretrained config (1152 for the so400m checkpoint).
        if embed_dim is None:
            self._embed_dim = self.vision_model.config.hidden_size
            print(f"✓ Auto-detected embed_dim: {self._embed_dim}")
        else:
            self._embed_dim = embed_dim
            print(f"✓ Using configured embed_dim: {embed_dim}")

        if freeze:
            self._freeze_parameters()

    def _freeze_parameters(self):
        """Disable gradients on the vision tower and switch it to eval mode."""
        print("❄️ Freezing Vision Backbone parameters")
        for p in self.vision_model.parameters():
            p.requires_grad = False
        self.vision_model.eval()

    def forward(self, obs: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Extract patch features for a batch of images.

        Args:
            obs['image']: (B, C, H, W) normalized pixel tensor.
        Returns:
            features: (B, num_patches, embed_dim) — the encoder's full
            last_hidden_state (the pooled vector is deliberately unused).
        """
        pixels = obs['image']
        # HF vision models return a BaseModelOutputWithPooling; keep the
        # whole patch sequence rather than the pooled summary.
        return self.vision_model(pixel_values=pixels).last_hidden_state

    @property
    def embed_dim(self) -> int:
        """Feature dimensionality of the returned patch embeddings."""
        return self._embed_dim
|
||||
@@ -1,8 +1,4 @@
|
||||
# # Action Head models
|
||||
from .diffusion import ConditionalUnet1D
|
||||
# from .act import ACTHead
|
||||
|
||||
__all__ = ["ConditionalUnet1D"]
|
||||
|
||||
# from .debug import DebugHead
|
||||
# __all__ = ["DebugHead"]
|
||||
Reference in New Issue
Block a user