feat(policy): add scripted air insertion policy
This commit is contained in:
68
roboimi/demos/diana_air_insert_policy.py
Normal file
68
roboimi/demos/diana_air_insert_policy.py
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
import numpy as np
|
||||||
|
from pyquaternion import Quaternion
|
||||||
|
|
||||||
|
from roboimi.demos.diana_policy import PolicyBase
|
||||||
|
|
||||||
|
|
||||||
|
class TestAirInsertPolicy(PolicyBase):
    """Scripted two-arm waypoint policy for the ring/bar air-insertion task.

    The left arm's early waypoints are placed relative to
    ``task_state["ring_pos"]`` and the right arm's relative to
    ``task_state["bar_pos"]``. Both arms then move to poses defined around a
    fixed meeting point: the left arm holds still while the right arm travels
    from 0.10 above that point to 0.02 below it.
    """

    def generate_trajectory(self, task_state):
        """Fill ``self.left_trajectory`` and ``self.right_trajectory``.

        Args:
            task_state: Mapping providing ``"ring_pos"`` and ``"bar_pos"``.
                Both are converted to float64 arrays and offset by 3-vectors,
                so they are expected to be xyz positions.

        Each trajectory is a list of waypoint dicts with the keys ``"t"``,
        ``"xyz"``, ``"quat"`` and ``"gripper"``.

        NOTE(review): how the base class consumes these waypoints (timing,
        interpolation, meaning of gripper 100 / -100) is not visible in this
        file -- confirm against ``PolicyBase``.
        """

        def above(base_xyz, dz):
            # Position shifted by ``dz`` along the third axis of ``base_xyz``.
            return base_xyz + np.array([0.0, 0.0, dz])

        def as_waypoints(rows):
            # Expand (t, xyz, quat, gripper) tuples into waypoint dicts.
            return [
                {"t": t, "xyz": xyz, "quat": quat, "gripper": gripper}
                for t, xyz, quat, gripper in rows
            ]

        ring_pos = np.asarray(task_state["ring_pos"], dtype=np.float64)
        bar_pos = np.asarray(task_state["bar_pos"], dtype=np.float64)

        # Hard-coded start poses: first three entries are the position, the
        # remaining four are the quaternion handed to pyquaternion.
        left_home = np.array(
            [
                -0.17297014,
                1.00485877,
                1.32773627,
                7.06825181e-01,
                8.20281078e-06,
                -7.07388269e-01,
                -5.20399313e-06,
            ],
            dtype=np.float64,
        )
        right_home = np.array(
            [
                0.17297014,
                0.9951369,
                1.32773623,
                2.59463975e-06,
                7.07388269e-01,
                5.59551158e-06,
                7.06825181e-01,
            ],
            dtype=np.float64,
        )
        left_home_xyz, left_home_quat = left_home[:3], left_home[3:]
        right_home_xyz, right_home_quat = right_home[:3], right_home[3:]

        # Pick / insert orientations: the start orientation composed with a
        # 90-degree rotation about the y axis (pick) or the x axis (insert).
        turn_about_y = Quaternion(axis=[0.0, 1.0, 0.0], degrees=90)
        turn_about_x = Quaternion(axis=[1.0, 0.0, 0.0], degrees=90)
        left_start = Quaternion(left_home_quat)
        right_start = Quaternion(right_home_quat)

        left_pick_quat = (left_start * turn_about_y).elements
        right_pick_quat = (right_start * turn_about_y).elements
        right_insert_quat = (right_start * turn_about_x).elements

        # Fixed meeting point and the poses derived from it.
        meet_xyz = np.array([0.0, 1.0, 1.30], dtype=np.float64)
        left_hold_xyz = meet_xyz + np.array([-0.02, 0.0, 0.08], dtype=np.float64)
        right_insert_start_xyz = meet_xyz + np.array([0.0, 0.0, 0.10], dtype=np.float64)
        right_insert_end_xyz = meet_xyz + np.array([0.0, 0.0, -0.02], dtype=np.float64)

        self.left_trajectory = as_waypoints(
            (
                (1, left_home_xyz, left_home_quat, 100),
                (80, above(ring_pos, 0.22), left_pick_quat, 100),
                (150, above(ring_pos, 0.08), left_pick_quat, 100),
                # Same position as t=150; only the gripper command flips.
                (180, above(ring_pos, 0.08), left_pick_quat, -100),
                (260, above(ring_pos, 0.24), left_pick_quat, -100),
                (420, left_hold_xyz, left_home_quat, -100),
                (700, left_hold_xyz, left_home_quat, -100),
            )
        )

        self.right_trajectory = as_waypoints(
            (
                (1, right_home_xyz, right_home_quat, 100),
                (80, above(bar_pos, 0.22), right_pick_quat, 100),
                (150, above(bar_pos, 0.08), right_pick_quat, 100),
                # Same position as t=150; only the gripper command flips.
                (180, above(bar_pos, 0.08), right_pick_quat, -100),
                (260, above(bar_pos, 0.26), right_pick_quat, -100),
                (420, right_insert_start_xyz, right_insert_quat, -100),
                (580, right_insert_end_xyz, right_insert_quat, -100),
                (700, right_insert_end_xyz, right_insert_quat, -100),
            )
        )
|
||||||
@@ -2,9 +2,11 @@ import time
|
|||||||
import os
|
import os
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from roboimi.envs.double_pos_ctrl_env import make_sim_env
|
from roboimi.envs.double_pos_ctrl_env import make_sim_env
|
||||||
from diana_policy import TestPickAndTransferPolicy
|
from roboimi.demos.diana_air_insert_policy import TestAirInsertPolicy
|
||||||
|
from roboimi.demos.diana_policy import TestPickAndTransferPolicy
|
||||||
import cv2
|
import cv2
|
||||||
from roboimi.utils.act_ex_utils import sample_transfer_pose
|
from roboimi.utils.act_ex_utils import sample_air_insert_ring_bar_state, sample_transfer_pose
|
||||||
|
from roboimi.utils.constants import SIM_TASK_CONFIGS
|
||||||
from roboimi.utils.streaming_episode_writer import StreamingEpisodeWriter
|
from roboimi.utils.streaming_episode_writer import StreamingEpisodeWriter
|
||||||
|
|
||||||
import pathlib
|
import pathlib
|
||||||
@@ -12,16 +14,32 @@ HOME_PATH = str(pathlib.Path(__file__).parent.resolve())
|
|||||||
DATASET_DIR = HOME_PATH + '/dataset'
|
DATASET_DIR = HOME_PATH + '/dataset'
|
||||||
|
|
||||||
|
|
||||||
def main():
|
# sample_task_state: return a freshly sampled task state for a scripted rollout.
#   'sim_transfer'            -> sample_transfer_pose()
#   'sim_air_insert_ring_bar' -> sample_air_insert_ring_bar_state()
#   any other task name       -> raises NotImplementedError
# NOTE(review): the bare assignments interleaved below (task_name / dataset_dir /
# num_episodes) appear to be removed-side lines of the previous main() from the
# side-by-side diff, not part of this function -- confirm against the real file.
def sample_task_state(task_name):
|
||||||
task_name = 'sim_transfer'
|
if task_name == 'sim_transfer':
|
||||||
dataset_dir = DATASET_DIR + '/sim_transfer' #SIM_TASK_CONFIGS[task_name]['dataset_dir']
|
return sample_transfer_pose()
|
||||||
num_episodes = 100 #SIM_TASK_CONFIGS[task_name]['num_episodes']
|
if task_name == 'sim_air_insert_ring_bar':
|
||||||
|
return sample_air_insert_ring_bar_state()
|
||||||
|
raise NotImplementedError(f'Unsupported scripted rollout task: {task_name}')
|
||||||
|
|
||||||
|
|
||||||
|
def make_policy(task_name, inject_noise=False):
    """Instantiate the scripted policy that matches ``task_name``.

    Args:
        task_name: ``'sim_transfer'`` selects ``TestPickAndTransferPolicy``;
            ``'sim_air_insert_ring_bar'`` selects ``TestAirInsertPolicy``.
        inject_noise: Forwarded positionally to the policy constructor.

    Returns:
        A new policy instance for the requested task.

    Raises:
        NotImplementedError: If ``task_name`` is not one of the two tasks above.
    """
    # Reject unknown tasks up front so the selection below only has two cases.
    if task_name not in ('sim_transfer', 'sim_air_insert_ring_bar'):
        raise NotImplementedError(f'Unsupported scripted rollout task: {task_name}')

    policy_cls = (
        TestPickAndTransferPolicy
        if task_name == 'sim_transfer'
        else TestAirInsertPolicy
    )
    return policy_cls(inject_noise)
|
||||||
|
|
||||||
|
|
||||||
|
def main(task_name='sim_transfer'):
|
||||||
|
task_cfg = SIM_TASK_CONFIGS[task_name]
|
||||||
|
dataset_dir = task_cfg['dataset_dir']
|
||||||
|
num_episodes = 100
|
||||||
inject_noise = False
|
inject_noise = False
|
||||||
|
|
||||||
episode_len = 700 #SIM_TASK_CONFIGS[task_name]['episode_len']
|
episode_len = task_cfg['episode_len']
|
||||||
camera_names = ['angle','r_vis', 'top', 'front'] #SIM_TASK_CONFIGS[task_name]['camera_names']
|
camera_names = ['angle', 'r_vis', 'top', 'front']
|
||||||
image_size = (256, 256)
|
image_size = (256, 256)
|
||||||
if task_name == 'sim_transfer':
|
if task_name in {'sim_transfer', 'sim_air_insert_ring_bar'}:
|
||||||
print(task_name)
|
print(task_name)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
@@ -29,7 +47,7 @@ def main():
|
|||||||
success = []
|
success = []
|
||||||
|
|
||||||
env = make_sim_env(task_name)
|
env = make_sim_env(task_name)
|
||||||
policy = TestPickAndTransferPolicy(inject_noise)
|
policy = make_policy(task_name, inject_noise=inject_noise)
|
||||||
|
|
||||||
# 等待osmesa完全启动后再开始收集数据
|
# 等待osmesa完全启动后再开始收集数据
|
||||||
print("等待osmesa线程启动...")
|
print("等待osmesa线程启动...")
|
||||||
@@ -41,8 +59,8 @@ def main():
|
|||||||
max_reward = float('-inf')
|
max_reward = float('-inf')
|
||||||
print(f'\n{episode_idx=}')
|
print(f'\n{episode_idx=}')
|
||||||
print('Rollout out EE space scripted policy')
|
print('Rollout out EE space scripted policy')
|
||||||
box_pos = sample_transfer_pose()
|
task_state = sample_task_state(task_name)
|
||||||
env.reset(box_pos)
|
env.reset(task_state)
|
||||||
episode_writer = StreamingEpisodeWriter(
|
episode_writer = StreamingEpisodeWriter(
|
||||||
dataset_path=os.path.join(dataset_dir, f'episode_{episode_idx}.hdf5'),
|
dataset_path=os.path.join(dataset_dir, f'episode_{episode_idx}.hdf5'),
|
||||||
max_timesteps=episode_len,
|
max_timesteps=episode_len,
|
||||||
@@ -50,7 +68,7 @@ def main():
|
|||||||
image_size=image_size,
|
image_size=image_size,
|
||||||
)
|
)
|
||||||
for step in range(episode_len):
|
for step in range(episode_len):
|
||||||
raw_action = policy.predict(box_pos,step)
|
raw_action = policy.predict(task_state, step)
|
||||||
env.step(raw_action)
|
env.step(raw_action)
|
||||||
env.render()
|
env.render()
|
||||||
sum_reward += env.rew
|
sum_reward += env.rew
|
||||||
|
|||||||
@@ -300,5 +300,69 @@ class AirInsertRewardAndSuccessTest(unittest.TestCase):
|
|||||||
self.assertEqual(reward, 5)
|
self.assertEqual(reward, 5)
|
||||||
|
|
||||||
|
|
||||||
|
class AirInsertPolicyAndSmokeTest(unittest.TestCase):
    """Tests for the scripted air-insert policy and the rollout entry points."""

    def _load_policy_cls(self):
        """Import the policy module and return its ``TestAirInsertPolicy``.

        Fails the calling test with an explicit message when the attribute is
        missing, so every test reports the same diagnostic.
        """
        policy_module = importlib.import_module("roboimi.demos.diana_air_insert_policy")
        policy_cls = getattr(policy_module, "TestAirInsertPolicy", None)
        self.assertIsNotNone(
            policy_cls,
            "Expected roboimi.demos.diana_air_insert_policy.TestAirInsertPolicy",
        )
        return policy_cls

    def test_air_insert_policy_emits_valid_16d_action(self):
        """The first predicted action has shape (16,) and ends with [100, 100]."""
        policy_cls = self._load_policy_cls()

        task_state = act_ex_utils.sample_air_insert_ring_bar_state()
        policy = policy_cls(inject_noise=False)
        action = policy.predict(task_state, 0)

        self.assertEqual(action.shape, (16,))
        # NOTE(review): the last two entries are presumably the two gripper
        # commands (the policy's first waypoints use gripper=100) -- confirm.
        np.testing.assert_array_equal(action[-2:], np.array([100, 100]))

    def test_scripted_rollout_entrypoint_selects_ring_bar_sampler_and_policy(self):
        """The rollout module exposes a sampler and a policy factory for the new task."""
        rollout_module = importlib.import_module("roboimi.demos.diana_record_sim_episodes")
        sampler_fn = getattr(rollout_module, "sample_task_state", None)
        policy_factory = getattr(rollout_module, "make_policy", None)
        self.assertIsNotNone(
            sampler_fn,
            "Expected roboimi.demos.diana_record_sim_episodes.sample_task_state",
        )
        self.assertIsNotNone(
            policy_factory,
            "Expected roboimi.demos.diana_record_sim_episodes.make_policy",
        )

        task_state = sampler_fn("sim_air_insert_ring_bar")
        # Iterating a dict yields its keys in insertion order.
        self.assertEqual(
            list(task_state),
            ["ring_pos", "ring_quat", "bar_pos", "bar_quat"],
        )

        policy = policy_factory("sim_air_insert_ring_bar", inject_noise=False)
        self.assertEqual(policy.__class__.__name__, "TestAirInsertPolicy")

    def test_real_headless_smoke_instantiates_resets_and_steps_new_task_once(self):
        """Create the headless env, reset it, take one policy step, check the observation."""
        policy_cls = self._load_policy_cls()

        task_state = act_ex_utils.sample_air_insert_ring_bar_state()
        env = make_sim_env("sim_air_insert_ring_bar", headless=True)

        try:
            # Constructed inside the try block so that a failing constructor
            # still runs the environment cleanup below.
            policy = policy_cls(inject_noise=False)
            env.reset(task_state)
            action = policy.predict(task_state, 0)
            env.step(action)
            self.assertIsNotNone(env.obs)
            self.assertIn("qpos", env.obs)
            self.assertIn("images", env.obs)
        finally:
            # NOTE(review): exit_flag presumably tells the env's background
            # camera thread to stop -- confirm against the env implementation.
            env.exit_flag = True
            cam_thread = getattr(env, "cam_thread", None)
            if cam_thread is not None:
                cam_thread.join(timeout=1.0)
            viewer = getattr(env, "viewer", None)
            if viewer is not None:
                viewer.close()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
Reference in New Issue
Block a user