# File listing metadata: 206 lines, 6.8 KiB, Python
import numpy as np
|
|
import torch
|
|
import collections
|
|
import tqdm
|
|
import dill
|
|
import math
|
|
from diffusion_policy.env.pusht.pusht_image_env import PushTImageEnv
|
|
from diffusion_policy.gym_util.sync_vector_env import SyncVectorEnv
|
|
from diffusion_policy.gym_util.multistep_wrapper import MultiStepWrapper
|
|
|
|
from diffusion_policy.policy.base_image_policy import BaseImagePolicy
|
|
from diffusion_policy.common.pytorch_util import dict_apply
|
|
from diffusion_policy.env_runner.base_image_runner import BaseImageRunner
|
|
|
|
|
|
def summarize_rollout_metrics(env_seeds, env_prefixs, all_rewards, all_video_paths=None):
    """Build a flat log dict of per-rollout and per-prefix reward statistics.

    The three sequences are consumed in lockstep (``zip``), so entry ``i`` of
    each describes the same rollout.

    Args:
        env_seeds: One seed per rollout; embedded in the per-rollout log key.
        env_prefixs: One key prefix per rollout (e.g. ``'train/'``, ``'test/'``).
        all_rewards: One reward collection per rollout; reduced with ``np.max``.
        all_video_paths: Accepted for call-signature compatibility and ignored.

    Returns:
        A dict holding ``<prefix>sim_max_reward_<seed>`` for every rollout,
        followed by one mean-of-maxima entry per distinct prefix. The mean is
        stored under ``train_mean_score`` / ``test_mean_score`` for the
        ``'train/'`` / ``'test/'`` prefixes and ``<prefix>mean_score`` otherwise.
    """
    # Video paths play no part in the summary; drop the reference right away.
    del all_video_paths

    metrics = dict()
    peaks_by_prefix = collections.defaultdict(list)

    # Pass 1: best reward of each rollout, logged individually and grouped.
    for rollout_seed, group, reward_trace in zip(env_seeds, env_prefixs, all_rewards):
        peak = np.max(reward_trace)
        peaks_by_prefix[group].append(peak)
        metrics[group + f'sim_max_reward_{rollout_seed}'] = peak

    # Pass 2: one aggregate score per prefix, in first-seen order.
    for group, peaks in peaks_by_prefix.items():
        if group == 'train/':
            summary_key = 'train_mean_score'
        elif group == 'test/':
            summary_key = 'test_mean_score'
        else:
            summary_key = group + 'mean_score'
        metrics[summary_key] = np.mean(peaks)

    return metrics
|
|
|
|
class PushTImageRunner(BaseImageRunner):
    """Evaluate an image policy on seeded PushT environments.

    The constructor builds a ``SyncVectorEnv`` of ``n_envs`` wrapped PushT
    environments plus one dill-serialized seeding function per train/test
    initial condition. ``run`` rolls the policy out over those initial
    conditions in chunks of ``n_envs`` and returns summarized reward metrics.
    """

    def __init__(
        self,
        output_dir,
        n_train=10,
        n_train_vis=3,
        train_start_seed=0,
        n_test=22,
        n_test_vis=6,
        legacy_test=False,
        test_start_seed=10000,
        max_steps=200,
        n_obs_steps=8,
        n_action_steps=8,
        fps=10,
        crf=22,
        render_size=96,
        past_action=False,
        tqdm_interval_sec=5.0,
        n_envs=None
    ):
        """Create the vectorized environment and the seeding plan.

        Args:
            output_dir: Forwarded to ``BaseImageRunner.__init__``.
            n_train: Number of 'train/' initial conditions.
            n_train_vis: Accepted for interface compatibility; not used here.
            train_start_seed: Seed of the first 'train/' initial condition;
                subsequent ones increment by one.
            n_test: Number of 'test/' initial conditions.
            n_test_vis: Accepted for interface compatibility; not used here.
            legacy_test: Passed as ``legacy`` to every ``PushTImageEnv``.
            test_start_seed: Seed of the first 'test/' initial condition.
            max_steps: ``max_episode_steps`` of each wrapper and the total of
                the rollout progress bar.
            n_obs_steps: Passed to ``MultiStepWrapper``; also used to slice
                past actions in ``run``.
            n_action_steps: Passed to ``MultiStepWrapper``.
            fps: Stored on the instance; not otherwise used in this class.
            crf: Stored on the instance; not otherwise used in this class.
            render_size: Passed to every ``PushTImageEnv``.
            past_action: When true, ``run`` adds a ``'past_action'`` entry to
                the observations after the first step.
            tqdm_interval_sec: ``mininterval`` of the rollout progress bar.
            n_envs: Number of environments stepped together; defaults to
                ``n_train + n_test``.

        Raises:
            ValueError: If the resolved ``n_envs`` is not positive.
        """
        super().__init__(output_dir)
        if n_envs is None:
            n_envs = n_train + n_test
        if n_envs <= 0:
            # Fail early: run() divides by the number of environments, so a
            # zero-sized vector env would otherwise surface as an opaque
            # ZeroDivisionError much later.
            raise ValueError(
                f"n_envs must be positive, got {n_envs} "
                f"(n_train={n_train}, n_test={n_test})")

        def env_fn():
            return MultiStepWrapper(
                PushTImageEnv(
                    legacy=legacy_test,
                    render_size=render_size
                ),
                n_obs_steps=n_obs_steps,
                n_action_steps=n_action_steps,
                max_episode_steps=max_steps
            )

        env_fns = [env_fn] * n_envs
        env_seeds = list()
        env_prefixs = list()
        env_init_fn_dills = list()

        # Train initial conditions come first, then test, each with
        # consecutive seeds counted up from its own start seed.
        seed_plan = (
            ('train/', train_start_seed, n_train),
            ('test/', test_start_seed, n_test),
        )
        for prefix, start_seed, count in seed_plan:
            for i in range(count):
                seed = start_seed + i

                # The seed is bound as a default argument so that every
                # serialized function keeps its own value.
                def init_fn(env, seed=seed):
                    # set seed
                    assert isinstance(env, MultiStepWrapper)
                    env.seed(seed)

                env_seeds.append(seed)
                env_prefixs.append(prefix)
                env_init_fn_dills.append(dill.dumps(init_fn))

        # This environment can run without multiprocessing, which avoids
        # shared-memory and semaphore restrictions on some machines.
        env = SyncVectorEnv(env_fns)

        self.env = env
        self.env_fns = env_fns
        self.env_seeds = env_seeds
        self.env_prefixs = env_prefixs
        self.env_init_fn_dills = env_init_fn_dills
        self.fps = fps
        self.crf = crf
        self.n_obs_steps = n_obs_steps
        self.n_action_steps = n_action_steps
        self.past_action = past_action
        self.max_steps = max_steps
        self.tqdm_interval_sec = tqdm_interval_sec

    def run(self, policy: BaseImagePolicy):
        """Roll out ``policy`` on every initial condition and summarize rewards.

        Initial conditions are processed in chunks of ``n_envs``. A final,
        partially filled chunk is padded with the first initial condition so
        that every environment has something to run; the padded results are
        discarded.

        Args:
            policy: Policy providing ``device``, ``reset()`` and
                ``predict_action(obs_dict)``; the returned mapping must
                contain an ``'action'`` entry.

        Returns:
            The log dict produced by ``summarize_rollout_metrics`` for all
            initial conditions.
        """
        device = policy.device
        env = self.env

        # plan for rollout
        n_envs = len(self.env_fns)
        n_inits = len(self.env_init_fn_dills)
        n_chunks = math.ceil(n_inits / n_envs)

        # allocate data
        all_rewards = [None] * n_inits

        for chunk_idx in range(n_chunks):
            start = chunk_idx * n_envs
            end = min(n_inits, start + n_envs)
            this_global_slice = slice(start, end)
            this_n_active_envs = end - start
            this_local_slice = slice(0, this_n_active_envs)

            # Slicing copies the list, so padding below never touches the
            # stored seeding plan.
            this_init_fns = self.env_init_fn_dills[this_global_slice]
            n_diff = n_envs - len(this_init_fns)
            if n_diff > 0:
                this_init_fns.extend([self.env_init_fn_dills[0]] * n_diff)
            assert len(this_init_fns) == n_envs

            # init envs
            env.call_each(
                'run_dill_function',
                args_list=[(x,) for x in this_init_fns])

            # start rollout
            obs = env.reset()
            past_action = None
            policy.reset()

            # The context manager closes the bar even if the policy or the
            # environment raises in the middle of a rollout.
            with tqdm.tqdm(
                total=self.max_steps,
                desc=f"Eval PushtImageRunner {chunk_idx+1}/{n_chunks}",
                leave=False,
                mininterval=self.tqdm_interval_sec,
            ) as pbar:
                done = False
                while not done:
                    # create obs dict
                    np_obs_dict = dict(obs)
                    if self.past_action and (past_action is not None):
                        # TODO: not tested
                        # NOTE(review): with n_obs_steps == 1 the slice start
                        # is -0, which selects every step rather than none;
                        # confirm the intended behavior.
                        np_obs_dict['past_action'] = past_action[
                            :, -(self.n_obs_steps-1):].astype(np.float32)

                    # device transfer
                    obs_dict = dict_apply(
                        np_obs_dict,
                        lambda x: torch.from_numpy(x).to(device=device))

                    # run policy
                    with torch.no_grad():
                        action_dict = policy.predict_action(obs_dict)

                    # device transfer
                    np_action_dict = dict_apply(
                        action_dict,
                        lambda x: x.detach().to('cpu').numpy())

                    action = np_action_dict['action']

                    # step env; the chunk ends once every env reports done
                    obs, reward, done, info = env.step(action)
                    done = np.all(done)
                    past_action = action

                    # update pbar (presumably action.shape[1] is the number
                    # of action steps executed per call; verify)
                    pbar.update(action.shape[1])

            # Keep only the results of real (non-padded) initial conditions.
            all_rewards[this_global_slice] = env.call(
                'get_attr', 'reward')[this_local_slice]
            # reset env state between evaluation calls
            _ = env.reset()

        # Metrics are reported and averaged over all n_inits initial
        # conditions and seeds, not only over the first n_envs of them.
        return summarize_rollout_metrics(
            env_seeds=self.env_seeds[:n_inits],
            env_prefixs=self.env_prefixs[:n_inits],
            all_rewards=all_rewards[:n_inits],
        )
|