# Python source listing: 111 lines, 3.4 KiB.
import pathlib
|
|
import sys
|
|
|
|
import gym
|
|
from gym import spaces
|
|
import numpy as np
|
|
import pytest
|
|
import torch
|
|
|
|
# Directory two levels above this file (parents[1]); it is put on sys.path so
# that the `diffusion_policy` imports further down resolve when the tests are
# run without the package being installed.
ROOT_DIR = pathlib.Path(__file__).resolve().parents[1]
if str(ROOT_DIR) not in sys.path:
    # Guarded so repeated imports of this module do not add duplicate entries.
    sys.path.append(str(ROOT_DIR))
|
|
|
|
import diffusion_policy.env_runner.pusht_image_runner as runner_module
|
|
from diffusion_policy.env_runner.pusht_image_runner import summarize_rollout_metrics
|
|
|
|
|
|
class FakePushTImageEnv(gym.Env):
    """Tiny deterministic stand-in for the PushT image environment.

    The runner test in this file monkeypatches it over
    ``runner_module.PushTImageEnv``. Every episode terminates after a single
    step, and the reward depends only on the seed most recently passed to
    :meth:`seed`, so expected metrics can be asserted exactly.
    """

    metadata = {'render.modes': ['rgb_array']}

    # Shape of the fake image observation; shared by the space definition and
    # by the arrays returned from reset()/step().
    _IMAGE_SHAPE = (3, 4, 4)

    def __init__(self, legacy=False, render_size=96):
        """Build the observation/action spaces.

        ``legacy`` and ``render_size`` are accepted only so the constructor
        can be called with those keywords; both are discarded.
        """
        del legacy, render_size
        self.observation_space = spaces.Dict({
            'image': spaces.Box(
                low=0, high=255, shape=self._IMAGE_SHAPE, dtype=np.uint8),
        })
        self.action_space = spaces.Box(
            low=-1.0, high=1.0, shape=(2,), dtype=np.float32)
        self.seed_value = 0
        self.step_count = 0

    def seed(self, seed=None):
        """Remember ``seed`` (``None`` is stored as 0); it selects the reward."""
        self.seed_value = 0 if seed is None else seed

    def reset(self):
        """Restart the step counter and return an all-zero image observation."""
        self.step_count = 0
        return {'image': np.zeros(self._IMAGE_SHAPE, dtype=np.uint8)}

    def step(self, action):
        """Advance one step, ignoring ``action``.

        Returns:
            ``(obs, reward, done, info)`` where ``obs['image']`` is filled
            with the current step count, ``reward`` is 0.1 for seeds below
            10000 and 0.9 otherwise, ``done`` is true from the first step
            onwards, and ``info`` is an empty dict.
        """
        del action
        self.step_count += 1
        reward = 0.1 if self.seed_value < 10000 else 0.9
        done = self.step_count >= 1
        obs = {
            'image': np.full(
                self._IMAGE_SHAPE, self.step_count, dtype=np.uint8),
        }
        return obs, reward, done, {}

    def render(self, *args, **kwargs):
        """Fail loudly: any render call means the runner tried to record video."""
        raise AssertionError('render should not be called for scalar-only PushT image rollouts')
|
|
|
|
|
|
class FakePolicy:
    """Stub policy exposing the attributes and methods the tests call.

    It always predicts all-zero actions, so rollout results depend only on
    the environment.
    """

    # Class-level device/dtype; predict_action() builds its output from them.
    device = torch.device('cpu')
    dtype = torch.float32

    def reset(self):
        """No-op; the stub keeps no per-episode state."""
        return None

    def predict_action(self, obs_dict):
        """Return a zero action tensor sized to the observation batch.

        Args:
            obs_dict: mapping whose first value exposes ``.shape``; its
                leading dimension is taken as the number of environments.

        Returns:
            ``{'action': tensor}`` with a zero tensor of shape
            ``(n_envs, 2, 2)`` -- presumably (batch, action steps, action
            dim); confirm against the runner.
        """
        n_envs = next(iter(obs_dict.values())).shape[0]
        return {
            'action': torch.zeros(
                (n_envs, 2, 2), dtype=self.dtype, device=self.device),
        }
|
|
|
|
|
|
def test_summarize_rollout_metrics_keeps_scalar_rewards_renames_means_and_omits_videos():
    """Check the keys and values produced by ``summarize_rollout_metrics``.

    Expects one ``<prefix>sim_max_reward_<seed>`` entry per environment
    holding the maximum of that environment's rewards, mean scores under
    ``train_mean_score`` / ``test_mean_score``, and no ``sim_video_`` entries
    even though video paths are supplied.
    """
    log_data = summarize_rollout_metrics(
        env_seeds=[11, 12, 101],
        env_prefixs=['train/', 'train/', 'test/'],
        all_rewards=[
            [0.2, 0.8],
            [0.1, 0.4],
            [0.5, 0.9],
        ],
        all_video_paths=[
            '/tmp/train-11.mp4',
            '/tmp/train-12.mp4',
            '/tmp/test-101.mp4',
        ],
    )

    # Floats are compared with a tolerance, as in the runner test below.
    assert log_data['train/sim_max_reward_11'] == pytest.approx(0.8)
    assert log_data['train/sim_max_reward_12'] == pytest.approx(0.4)
    assert log_data['test/sim_max_reward_101'] == pytest.approx(0.9)
    # Train mean is the average of the two train maxima: (0.8 + 0.4) / 2.
    assert log_data['train_mean_score'] == pytest.approx(0.6)
    assert log_data['test_mean_score'] == pytest.approx(0.9)

    video_prefixes = ('train/sim_video_', 'test/sim_video_')
    assert not any(key.startswith(video_prefixes) for key in log_data)
|
|
|
|
|
|
def test_runner_ignores_vis_flags_and_never_emits_sim_videos(tmp_path, monkeypatch):
    """Run ``PushTImageRunner`` end to end against the fake env and policy.

    ``n_train_vis`` and ``n_test_vis`` are both set to 1, yet the result must
    contain no ``sim_video`` key. The fake environment's ``render`` raises,
    so the rollout also fails if the runner tries to render a frame.
    """
    monkeypatch.setattr(runner_module, 'PushTImageEnv', FakePushTImageEnv)

    runner = runner_module.PushTImageRunner(
        output_dir=tmp_path,
        n_train=1,
        n_train_vis=1,
        n_test=1,
        n_test_vis=1,
        n_envs=2,
        max_steps=2,
        n_obs_steps=2,
        n_action_steps=2,
        tqdm_interval_sec=0.0,
    )

    log_data = runner.run(FakePolicy())

    # The fake env pays 0.1 for seeds below 10000 and 0.9 otherwise; the keys
    # below imply the train env ran with seed 0 and the test env with 10000.
    assert log_data['train/sim_max_reward_0'] == pytest.approx(0.1)
    assert log_data['test/sim_max_reward_10000'] == pytest.approx(0.9)
    assert log_data['train_mean_score'] == pytest.approx(0.1)
    assert log_data['test_mean_score'] == pytest.approx(0.9)
    assert not any('sim_video' in key for key in log_data)
|