# (source listing metadata: 105 lines, 3.7 KiB, Python)
import collections
import os
import pathlib
import sys
import time

import cv2
import h5py
import numpy as np

from diana_policy import TestPickAndTransferPolicy
from roboimi.envs.double_pos_ctrl_env import make_sim_env
from roboimi.utils.act_ex_utils import sample_transfer_pose
|
HOME_PATH = str(pathlib.Path(__file__).parent.resolve())
|
|
DATASET_DIR = HOME_PATH + '/dataset'
|
|
|
|
|
|
def main():
    """Collect scripted-policy demonstrations for the sim_transfer task.

    Rolls out ``TestPickAndTransferPolicy`` in the simulated environment for
    ``num_episodes`` episodes. Every episode that reaches the environment's
    maximum reward is saved as an HDF5 file under ``DATASET_DIR``; failed
    episodes are only counted. Prints a success summary and shuts the
    environment (camera thread + viewer) down at the end.
    """
    task_name = 'sim_transfer'
    dataset_dir = DATASET_DIR + '/sim_transfer'
    num_episodes = 100
    inject_noise = False
    episode_len = 700  # fixed number of control steps per episode
    camera_names = ['angle', 'r_vis', 'top', 'front']

    # Only the transfer task has a scripted policy; fail fast on anything else.
    if task_name != 'sim_transfer':
        raise NotImplementedError
    print(task_name)

    # Fix: h5py.File(..., 'w') raises if the target directory is missing.
    os.makedirs(dataset_dir, exist_ok=True)

    env = make_sim_env(task_name)
    # Fix: the original constructed the policy twice; one instance suffices.
    policy = TestPickAndTransferPolicy(inject_noise)

    success = []
    for episode_idx in range(num_episodes):
        print(f'\n{episode_idx=}')
        print('Rollout out EE space scripted policy')

        box_pos = sample_transfer_pose()
        env.reset(box_pos)

        obs = []
        reward_ee = []
        for step in range(episode_len):
            action = policy.predict(box_pos, step)
            env.step(action)
            env.render()
            reward_ee.append(env.rew)
            obs.append(env.obs)

        sum_reward = np.sum(reward_ee)
        max_reward = np.max(reward_ee)
        # An episode counts as successful when it hit the env's max reward.
        if max_reward == env.max_reward:
            success.append(1)
            print(f"{episode_idx=} Successful, {sum_reward=}")
            _save_episode(dataset_dir, episode_idx, obs, camera_names, episode_len)
        else:
            success.append(0)
            print(f"{episode_idx=} Failed")
            print(max_reward)

    print(f'Success: {np.sum(success)} / {len(success)}')

    # Shut down the environment: signal the camera thread to stop, flush the
    # OpenCV windows, then join the thread and close the viewer.
    env.exit_flag = True
    cv2.destroyAllWindows()
    cv2.waitKey(1)
    env.cam_thread.join()
    env.viewer.close()


def _save_episode(dataset_dir, episode_idx, obs, camera_names, episode_len):
    """Write one successful episode to ``<dataset_dir>/episode_<idx>.hdf5``.

    HDF5 layout:
      /observations/qpos               (T, 16) float
      /observations/images/<cam_name>  (T, 480, 640, 3) uint8
      /action                          (T, 16) float
    with root attribute ``sim=True``.

    NOTE(review): the 16-dim qpos/action and 480x640 image size are assumed
    from the original dataset schema — confirm against the env if it changes.
    """
    data_dict = {
        '/observations/qpos': [],
        '/action': [],
    }
    for cam_name in camera_names:
        data_dict[f'/observations/images/{cam_name}'] = []
    # Flatten the per-step observations into the per-key lists.
    # (The original also printed every qpos here — debug noise, removed.)
    for frame in obs[:episode_len]:
        data_dict['/observations/qpos'].append(frame['qpos'])
        data_dict['/action'].append(frame['action'])
        for cam_name in camera_names:
            data_dict[f'/observations/images/{cam_name}'].append(frame['images'][cam_name])

    dataset_path = os.path.join(dataset_dir, f'episode_{episode_idx}')
    # rdcc_nbytes: 2 MiB chunk cache, one chunk per image frame.
    with h5py.File(dataset_path + '.hdf5', 'w', rdcc_nbytes=1024 ** 2 * 2) as root:
        root.attrs['sim'] = True
        obs_group = root.create_group('observations')
        image_group = obs_group.create_group('images')
        for cam_name in camera_names:
            image_group.create_dataset(cam_name, (episode_len, 480, 640, 3), dtype='uint8',
                                       chunks=(1, 480, 640, 3))
        obs_group.create_dataset('qpos', (episode_len, 16))
        root.create_dataset('action', (episode_len, 16))
        for name, array in data_dict.items():
            root[name][...] = np.array(array)
|
|
if __name__ == '__main__':
|
|
main() |