"""Tests for scalar-only rollout logging in the PushT image env runner.

The real simulator is replaced by a tiny fake environment whose ``render``
raises, so any attempt by the runner to produce videos fails the test.
"""
import pathlib
import sys

import gym
from gym import spaces
import numpy as np
import pytest
import torch

# Make the repository root importable before pulling in project modules.
ROOT_DIR = pathlib.Path(__file__).resolve().parents[1]
if str(ROOT_DIR) not in sys.path:
    sys.path.append(str(ROOT_DIR))

import diffusion_policy.env_runner.pusht_image_runner as runner_module
from diffusion_policy.env_runner.pusht_image_runner import summarize_rollout_metrics


class FakePushTImageEnv(gym.Env):
    """Minimal stand-in for ``PushTImageEnv`` with deterministic rewards.

    Every episode ends after a single step. The per-step reward is 0.1 when
    the seed passed to ``seed`` is below 10000 and 0.9 otherwise.
    """

    metadata = {'render.modes': ['rgb_array']}

    def __init__(self, legacy=False, render_size=96):
        # Both arguments are accepted for signature compatibility only.
        del legacy, render_size
        image_space = spaces.Box(low=0, high=255, shape=(3, 4, 4), dtype=np.uint8)
        self.observation_space = spaces.Dict({'image': image_space})
        self.action_space = spaces.Box(
            low=-1.0,
            high=1.0,
            shape=(2,),
            dtype=np.float32,
        )
        self.seed_value = 0
        self.step_count = 0

    def seed(self, seed=None):
        """Remember ``seed`` (``None`` is stored as 0) to select the reward."""
        if seed is None:
            seed = 0
        self.seed_value = seed

    def reset(self):
        """Zero the step counter and return an all-zero image observation."""
        self.step_count = 0
        blank = np.zeros((3, 4, 4), dtype=np.uint8)
        return {'image': blank}

    def step(self, action):
        """Advance one step; return ``(obs, reward, done, info)``.

        The action is ignored. The observation image is filled with the
        current step count.
        """
        del action
        self.step_count += 1
        if self.seed_value < 10000:
            reward = 0.1
        else:
            reward = 0.9
        # Always true after the increment above: episodes last one step.
        done = self.step_count >= 1
        frame = np.full((3, 4, 4), self.step_count, dtype=np.uint8)
        return {'image': frame}, reward, done, {}

    def render(self, *args, **kwargs):
        """Fail loudly: scalar-only rollouts must never render frames."""
        raise AssertionError('render should not be called for scalar-only PushT image rollouts')


class FakePolicy:
    """Policy stub that always predicts all-zero actions on the CPU."""

    device = torch.device('cpu')
    dtype = torch.float32

    def reset(self):
        """No internal state to clear."""
        return None

    def predict_action(self, obs_dict):
        """Return a zero action tensor of shape ``(n_envs, 2, 2)``.

        ``n_envs`` is read from the leading dimension of the first value in
        ``obs_dict``.
        """
        first_obs = next(iter(obs_dict.values()))
        batch_size = first_obs.shape[0]
        zero_actions = torch.zeros((batch_size, 2, 2), dtype=torch.float32)
        return {'action': zero_actions}


def test_summarize_rollout_metrics_keeps_scalar_rewards_renames_means_and_omits_videos():
    """Per-env max rewards and mean scores are logged; video keys are not."""
    rewards_per_env = [
        [0.2, 0.8],
        [0.1, 0.4],
        [0.5, 0.9],
    ]
    video_paths = [
        '/tmp/train-11.mp4',
        '/tmp/train-12.mp4',
        '/tmp/test-101.mp4',
    ]

    log_data = summarize_rollout_metrics(
        env_seeds=[11, 12, 101],
        env_prefixs=['train/', 'train/', 'test/'],
        all_rewards=rewards_per_env,
        all_video_paths=video_paths,
    )

    assert log_data['train/sim_max_reward_11'] == 0.8
    assert log_data['train/sim_max_reward_12'] == 0.4
    assert log_data['test/sim_max_reward_101'] == 0.9
    assert log_data['train_mean_score'] == pytest.approx(0.6)
    assert log_data['test_mean_score'] == pytest.approx(0.9)

    # Video paths were supplied above but must not show up in the log dict.
    for video_prefix in ('train/sim_video_', 'test/sim_video_'):
        assert not any(key.startswith(video_prefix) for key in log_data)


def test_runner_ignores_vis_flags_and_never_emits_sim_videos(tmp_path, monkeypatch):
    """A full runner pass with vis flags set still logs only scalar metrics."""
    monkeypatch.setattr(runner_module, 'PushTImageEnv', FakePushTImageEnv)

    runner_kwargs = dict(
        output_dir=tmp_path,
        n_train=1,
        n_train_vis=1,
        n_test=1,
        n_test_vis=1,
        n_envs=2,
        max_steps=2,
        n_obs_steps=2,
        n_action_steps=2,
        tqdm_interval_sec=0.0,
    )
    runner = runner_module.PushTImageRunner(**runner_kwargs)

    metrics = runner.run(FakePolicy())

    # The seed suffixes (0 and 10000) are presumably the runner's default
    # train/test start seeds -- confirm against PushTImageRunner.
    expected_scalars = {
        'train/sim_max_reward_0': 0.1,
        'test/sim_max_reward_10000': 0.9,
        'train_mean_score': 0.1,
        'test_mean_score': 0.9,
    }
    for key, value in expected_scalars.items():
        assert metrics[key] == pytest.approx(value)

    assert not any('sim_video' in key for key in metrics)