Files
roboimi/roboimi/demos/diana_record_sim_episodes.py

86 lines
2.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import time
import os
import numpy as np
from roboimi.envs.double_pos_ctrl_env import make_sim_env
from diana_policy import TestPickAndTransferPolicy
import cv2
from roboimi.utils.act_ex_utils import sample_transfer_pose
from roboimi.utils.streaming_episode_writer import StreamingEpisodeWriter
import pathlib
# Absolute path of the directory that contains this script.
HOME_PATH = str(pathlib.Path(__file__).parent.resolve())
# Root folder, next to this script, under which recorded episodes are stored.
DATASET_DIR = f'{HOME_PATH}/dataset'
def _record_episode(env, policy, box_pos, episode_writer, episode_len):
    """Roll out the scripted policy for one episode and stream each step to the writer.

    Args:
        env: Simulation environment exposing ``step``, ``render``, ``rew`` and ``obs``.
        policy: Scripted policy queried through ``predict(box_pos, step)``.
        box_pos: Pose the environment was reset with for this episode.
        episode_writer: Writer whose ``append`` receives qpos, action and images.
        episode_len: Number of steps to execute.

    Returns:
        Tuple ``(sum_reward, max_reward)`` accumulated from ``env.rew`` over the
        episode; ``max_reward`` stays ``-inf`` when ``episode_len`` is 0.
    """
    sum_reward = 0.0
    max_reward = float('-inf')
    for step in range(episode_len):
        raw_action = policy.predict(box_pos, step)
        env.step(raw_action)
        env.render()
        sum_reward += env.rew
        max_reward = max(max_reward, env.rew)
        episode_writer.append(
            qpos=env.obs['qpos'],
            action=raw_action,
            images=env.obs['images'],
        )
    return sum_reward, max_reward


def _shutdown_env(env):
    """Signal the environment to exit, close OpenCV windows and release the viewer."""
    env.exit_flag = True
    cv2.destroyAllWindows()
    cv2.waitKey(1)
    env.cam_thread.join()
    env.viewer.close()


def main():
    """Record scripted pick-and-transfer episodes from the simulator to HDF5 files.

    For every episode a box pose is sampled, the environment is reset with it,
    and the scripted policy is rolled out for a fixed number of steps while
    each step is streamed to ``episode_<idx>.hdf5``. An episode is committed
    only when its maximum reward equals ``env.max_reward``; otherwise the
    writer is discarded. The environment is shut down even when recording is
    interrupted by an exception.

    Raises:
        NotImplementedError: If the configured task is not ``'sim_transfer'``.
    """
    task_name = 'sim_transfer'
    dataset_dir = DATASET_DIR + '/sim_transfer'  # cf. SIM_TASK_CONFIGS[task_name]['dataset_dir']
    num_episodes = 100  # cf. SIM_TASK_CONFIGS[task_name]['num_episodes']
    inject_noise = False
    episode_len = 700  # cf. SIM_TASK_CONFIGS[task_name]['episode_len']
    camera_names = ['angle', 'r_vis', 'top', 'front']  # cf. SIM_TASK_CONFIGS[task_name]['camera_names']
    image_size = (256, 256)

    if task_name != 'sim_transfer':
        raise NotImplementedError
    print(task_name)

    # Make sure the output directory exists before any writer targets it.
    os.makedirs(dataset_dir, exist_ok=True)

    success = []
    env = make_sim_env(task_name)
    try:
        policy = TestPickAndTransferPolicy(inject_noise)

        # Wait for osmesa to finish starting up before collecting data.
        print("等待osmesa线程启动...")
        time.sleep(60)
        print("osmesa已就绪开始收集数据...")

        for episode_idx in range(num_episodes):
            print(f'\n{episode_idx=}')
            print('Rollout out EE space scripted policy')
            box_pos = sample_transfer_pose()
            env.reset(box_pos)
            episode_writer = StreamingEpisodeWriter(
                dataset_path=os.path.join(dataset_dir, f'episode_{episode_idx}.hdf5'),
                max_timesteps=episode_len,
                camera_names=camera_names,
                image_size=image_size,
            )
            try:
                sum_reward, max_reward = _record_episode(
                    env, policy, box_pos, episode_writer, episode_len
                )
            except BaseException:
                # Do not leave a half-written episode behind on errors or Ctrl-C.
                episode_writer.discard()
                raise

            if max_reward == env.max_reward:
                success.append(1)
                print(f"{episode_idx=} Successful, {sum_reward=}")
                episode_writer.commit()
            else:
                success.append(0)
                print(f"{episode_idx=} Failed")
                print(max_reward)
                episode_writer.discard()

        print(f'Success: {np.sum(success)} / {len(success)}')
    finally:
        # Always stop the camera thread and close the viewer, even after a failure.
        _shutdown_env(env)
# Run the recorder only when this file is executed as a script.
if __name__ == '__main__':
    main()