Files
diffusion_policy/tests/test_pusht_image_runner_metrics.py

111 lines
3.4 KiB
Python

import pathlib
import sys
import gym
from gym import spaces
import numpy as np
import pytest
import torch
# Directory two levels above this file (the parent of the directory holding
# this test). It is put on sys.path so the ``diffusion_policy`` imports below
# resolve even when pytest is launched from elsewhere. It is appended only
# once, so repeated imports do not grow sys.path.
ROOT_DIR = pathlib.Path(__file__).resolve().parent.parent
if str(ROOT_DIR) not in sys.path:
    sys.path.append(str(ROOT_DIR))
import diffusion_policy.env_runner.pusht_image_runner as runner_module
from diffusion_policy.env_runner.pusht_image_runner import summarize_rollout_metrics
class FakePushTImageEnv(gym.Env):
    """Lightweight stand-in for ``PushTImageEnv`` used by the runner tests.

    Every episode terminates after a single step. The reward is 0.1 when
    the env was seeded below 10000 and 0.9 otherwise, so rollouts seeded in
    the two ranges report different scores. ``render`` raises, so any
    attempt to grab frames fails loudly.
    """

    metadata = {'render.modes': ['rgb_array']}

    # Shape of the single 'image' observation produced by reset/step.
    _OBS_SHAPE = (3, 4, 4)

    def __init__(self, legacy=False, render_size=96):
        # Mirror the real env's constructor signature, but ignore its arguments.
        del legacy, render_size
        self.action_space = spaces.Box(
            low=-1.0, high=1.0, shape=(2,), dtype=np.float32)
        image_space = spaces.Box(
            low=0, high=255, shape=self._OBS_SHAPE, dtype=np.uint8)
        self.observation_space = spaces.Dict({'image': image_space})
        self.step_count = 0
        self.seed_value = 0

    def seed(self, seed=None):
        """Remember the seed (``None`` is treated as 0); it selects the reward."""
        if seed is None:
            seed = 0
        self.seed_value = seed

    def reset(self):
        """Restart the episode and return an all-zero image observation."""
        self.step_count = 0
        blank = np.zeros(self._OBS_SHAPE, dtype=np.uint8)
        return {'image': blank}

    def step(self, action):
        """Advance one step; the action is ignored and the episode ends at once."""
        del action
        self.step_count += 1
        if self.seed_value < 10000:
            reward = 0.1
        else:
            reward = 0.9
        # Fill the image with the step index so successive frames differ.
        frame = np.full(self._OBS_SHAPE, self.step_count, dtype=np.uint8)
        finished = self.step_count >= 1
        return {'image': frame}, reward, finished, {}

    def render(self, *args, **kwargs):
        raise AssertionError('render should not be called for scalar-only PushT image rollouts')
class FakePolicy:
    """Policy stub that runs on the CPU and always predicts zero actions."""

    device = torch.device('cpu')
    dtype = torch.float32

    def reset(self):
        """Nothing to reset; returns ``None``."""

    def predict_action(self, obs_dict):
        """Return ``{'action': zeros(batch, 2, 2)}`` as float32.

        The batch size is taken from the leading dimension of the first
        value in ``obs_dict``. The trailing ``(2, 2)`` presumably matches
        ``n_action_steps=2`` and the 2-D action space used by the runner
        test; confirm before reusing elsewhere.
        """
        first_obs = next(iter(obs_dict.values()))
        batch_size = first_obs.shape[0]
        zero_actions = torch.zeros((batch_size, 2, 2), dtype=torch.float32)
        return {'action': zero_actions}
def test_summarize_rollout_metrics_keeps_scalar_rewards_renames_means_and_omits_videos():
    """Check that ``summarize_rollout_metrics`` reduces rollouts to scalars.

    It should emit one max-reward scalar per seed (keyed by prefix and
    seed) and a per-split mean score keyed ``<split>_mean_score``. Video
    paths are passed in deliberately: nothing video-related may reach the
    returned log, neither as a key nor as a value.
    """
    video_paths = [
        '/tmp/train-11.mp4',
        '/tmp/train-12.mp4',
        '/tmp/test-101.mp4',
    ]
    log_data = summarize_rollout_metrics(
        env_seeds=[11, 12, 101],
        env_prefixs=['train/', 'train/', 'test/'],
        all_rewards=[
            [0.2, 0.8],
            [0.1, 0.4],
            [0.5, 0.9],
        ],
        all_video_paths=video_paths,
    )

    # Per-seed scalar: the maximum reward of each rollout.
    assert log_data['train/sim_max_reward_11'] == pytest.approx(0.8)
    assert log_data['train/sim_max_reward_12'] == pytest.approx(0.4)
    assert log_data['test/sim_max_reward_101'] == pytest.approx(0.9)

    # Per-split mean of the per-seed maxima: mean(0.8, 0.4) and mean(0.9).
    assert log_data['train_mean_score'] == pytest.approx(0.6)
    assert log_data['test_mean_score'] == pytest.approx(0.9)

    # Match 'sim_video' anywhere in the key, as the runner test does. The
    # check then catches a video entry under any prefix, not only the two
    # hard-coded ones.
    assert not any('sim_video' in key for key in log_data)
    # The supplied video paths must not leak through under some other key.
    leaked = [
        value for value in log_data.values()
        if isinstance(value, str) and value in video_paths
    ]
    assert not leaked
def test_runner_ignores_vis_flags_and_never_emits_sim_videos(tmp_path, monkeypatch):
    """Check that the runner emits no videos even when vis flags are set.

    ``PushTImageRunner.run`` should report only scalar metrics and write no
    video files, even with ``n_train_vis``/``n_test_vis`` > 0.
    ``FakePushTImageEnv.render`` raises, so any frame capture would error.
    """
    monkeypatch.setattr(runner_module, 'PushTImageEnv', FakePushTImageEnv)
    runner = runner_module.PushTImageRunner(
        output_dir=tmp_path,
        n_train=1,
        n_train_vis=1,
        n_test=1,
        n_test_vis=1,
        n_envs=2,
        max_steps=2,
        n_obs_steps=2,
        n_action_steps=2,
        tqdm_interval_sec=0.0,
    )
    log_data = runner.run(FakePolicy())

    # The fake env rewards 0.1 for seeds below 10000 and 0.9 otherwise; the
    # expected keys imply the train rollout uses seed 0 and the test rollout
    # uses seed 10000.
    assert log_data['train/sim_max_reward_0'] == pytest.approx(0.1)
    assert log_data['test/sim_max_reward_10000'] == pytest.approx(0.9)
    assert log_data['train_mean_score'] == pytest.approx(0.1)
    assert log_data['test_mean_score'] == pytest.approx(0.9)

    # No video entries in the returned log...
    assert not any('sim_video' in key for key in log_data)
    # ...and no video artifacts written anywhere under the output directory.
    assert not list(tmp_path.rglob('*.mp4'))