86 lines
2.7 KiB
Python
86 lines
2.7 KiB
Python
import time
|
||
import os
|
||
import numpy as np
|
||
from roboimi.envs.double_pos_ctrl_env import make_sim_env
|
||
from diana_policy import TestPickAndTransferPolicy
|
||
import cv2
|
||
from roboimi.utils.act_ex_utils import sample_transfer_pose
|
||
from roboimi.utils.streaming_episode_writer import StreamingEpisodeWriter
|
||
|
||
import pathlib
|
||
# Absolute path (as a plain string) of the directory that contains this script.
HOME_PATH = os.fspath(pathlib.Path(__file__).parent.resolve())
# Root folder for recorded datasets; task-specific subfolders are placed under it.
DATASET_DIR = f'{HOME_PATH}/dataset'
|
||
|
||
|
||
def _rollout_episode(env, policy, box_pos, episode_writer, episode_len):
    """Run one scripted episode and stream every step into ``episode_writer``.

    Args:
        env: Environment from ``make_sim_env``; expected to have been
            ``reset(box_pos)`` already.
        policy: Scripted policy exposing ``predict(box_pos, step)``.
        box_pos: Pose from ``sample_transfer_pose``, forwarded to the policy.
        episode_writer: Writer that receives one ``append`` call per step.
        episode_len: Number of control steps to execute.

    Returns:
        ``(sum_reward, max_reward)`` accumulated from ``env.rew`` after each
        step. ``max_reward`` stays ``-inf`` when ``episode_len`` is 0.
    """
    sum_reward = 0.0
    max_reward = float('-inf')
    for step in range(episode_len):
        raw_action = policy.predict(box_pos, step)
        env.step(raw_action)
        env.render()
        sum_reward += env.rew
        max_reward = max(max_reward, env.rew)
        episode_writer.append(
            qpos=env.obs['qpos'],
            action=raw_action,
            images=env.obs['images'],
        )
    return sum_reward, max_reward


def _shutdown_env(env):
    """Stop the environment's camera thread and close OpenCV windows and the viewer."""
    # NOTE(review): presumably the flag tells env.cam_thread to leave its loop;
    # the join below relies on that — confirm in double_pos_ctrl_env.
    env.exit_flag = True
    cv2.destroyAllWindows()
    # Give HighGUI one event-loop tick so the destroyed windows actually close.
    cv2.waitKey(1)
    env.cam_thread.join()
    env.viewer.close()


def main(*, num_episodes=100, episode_len=700, inject_noise=False,
         warmup_seconds=60):
    """Collect scripted-policy demonstrations for the ``sim_transfer`` task.

    For each episode a box pose is sampled and the environment is reset to it.
    ``TestPickAndTransferPolicy`` is then rolled out for ``episode_len`` steps,
    and qpos/action/images are streamed to
    ``<DATASET_DIR>/sim_transfer/episode_<idx>.hdf5``. The file is committed
    only when the episode's maximum reward equals ``env.max_reward``; otherwise
    it is discarded. A success tally is printed at the end. Environment
    teardown (camera thread, OpenCV windows, viewer) always runs, even if
    collection fails.

    Keyword Args:
        num_episodes: Number of episodes to attempt.
        episode_len: Control steps per episode (also the writer's max timesteps).
        inject_noise: Forwarded to ``TestPickAndTransferPolicy``.
        warmup_seconds: Seconds to sleep before collecting, waiting for the
            OSMesa render thread to start.
    """
    task_name = 'sim_transfer'
    if task_name != 'sim_transfer':
        raise NotImplementedError
    print(task_name)

    dataset_dir = DATASET_DIR + '/sim_transfer'
    camera_names = ['angle', 'r_vis', 'top', 'front']
    image_size = (256, 256)
    # In case the writer does not create missing parent directories itself.
    os.makedirs(dataset_dir, exist_ok=True)

    success = []
    env = make_sim_env(task_name)
    try:
        policy = TestPickAndTransferPolicy(inject_noise)

        # Wait until OSMesa has fully started before collecting data.
        # NOTE(review): this is a fixed delay, not a readiness check.
        print("等待osmesa线程启动...")
        time.sleep(warmup_seconds)
        print("osmesa已就绪,开始收集数据...")

        for episode_idx in range(num_episodes):
            print(f'\n{episode_idx=}')
            print('Rollout out EE space scripted policy')
            box_pos = sample_transfer_pose()
            env.reset(box_pos)

            episode_writer = StreamingEpisodeWriter(
                dataset_path=os.path.join(dataset_dir, f'episode_{episode_idx}.hdf5'),
                max_timesteps=episode_len,
                camera_names=camera_names,
                image_size=image_size,
            )
            try:
                sum_reward, max_reward = _rollout_episode(
                    env, policy, box_pos, episode_writer, episode_len)
            except BaseException:
                # Do not leave a half-written episode behind if the rollout
                # crashes or is interrupted (e.g. Ctrl-C).
                episode_writer.discard()
                raise

            if max_reward == env.max_reward:
                success.append(1)
                print(f"{episode_idx=} Successful, {sum_reward=}")
                episode_writer.commit()
            else:
                success.append(0)
                print(f"{episode_idx=} Failed")
                print(max_reward)
                episode_writer.discard()

        print(f'Success: {np.sum(success)} / {len(success)}')
    finally:
        _shutdown_env(env)


if __name__ == '__main__':
    main()
|