# Patch summary (explanatory preamble placed before the first `diff --git`
# header; no diff content below is altered).
#
# NOTE(review): in this copy the patch's newlines and indentation appear to
# have been collapsed (e.g. the first hunk declares 68 added lines but they
# sit on two physical lines). Line structure presumably has to be restored
# before the patch can be applied -- confirm against the original commit.
#
# 1. New file roboimi/demos/diana_air_insert_policy.py
#    - Defines TestAirInsertPolicy(PolicyBase) with
#      generate_trajectory(task_state).
#    - Reads task_state["ring_pos"] and task_state["bar_pos"] as float64
#      arrays.
#    - Two hard-coded 7-element initial mocap poses (left/right); elements
#      [:3] are used as xyz and [3:] as the quaternion.
#    - Pick orientations are the initial quaternions multiplied by a
#      90-degree rotation about axis [0, 1, 0]; the right-arm insert
#      orientation uses a 90-degree rotation about axis [1, 0, 0].
#    - Sets self.left_trajectory (7 waypoints) and self.right_trajectory
#      (8 waypoints); each waypoint is a dict with keys "t", "xyz", "quat",
#      "gripper". "t" runs from 1 to 700; "gripper" switches from 100 to
#      -100 at t=180 on both arms.
#
# 2. roboimi/demos/diana_record_sim_episodes.py
#    - Imports TestAirInsertPolicy, sample_air_insert_ring_bar_state and
#      SIM_TASK_CONFIGS; the TestPickAndTransferPolicy import becomes the
#      absolute roboimi.demos.diana_policy path.
#    - Adds sample_task_state(task_name) and make_policy(task_name,
#      inject_noise=False); both dispatch on 'sim_transfer' /
#      'sim_air_insert_ring_bar' and raise NotImplementedError otherwise.
#    - main() gains a task_name parameter (default 'sim_transfer') and now
#      takes dataset_dir and episode_len from SIM_TASK_CONFIGS[task_name].
#    - The rollout loop passes the sampled task_state to env.reset() and
#      policy.predict() instead of box_pos.
#    - The Chinese context comment in main() reads: "wait for osmesa to
#      fully start before starting to collect data". It is a context line
#      and is left untouched so the hunk still matches the target file.
#
# 3. tests/test_air_insert_env.py
#    - Adds AirInsertPolicyAndSmokeTest with three tests:
#      * predict(task_state, 0) returns an array of shape (16,) whose last
#        two entries equal [100, 100];
#      * sample_task_state("sim_air_insert_ring_bar") returns keys, in
#        order, ring_pos, ring_quat, bar_pos, bar_quat, and make_policy
#        returns an instance whose class name is "TestAirInsertPolicy";
#      * a headless env is created, reset, stepped once, env.obs is checked
#        for "qpos" and "images", and cam_thread / viewer are cleaned up in
#        a finally block.
#
# Review notes (assumptions to confirm, not applied here):
#    - NOTE(review): main() indexes SIM_TASK_CONFIGS[task_name] before the
#      `task_name in {...}` guard. If SIM_TASK_CONFIGS is a plain dict
#      (presumably -- confirm), an unknown task fails on that lookup first
#      and the guard's NotImplementedError is never reached.
#    - NOTE(review): the supported task names are now listed in three
#      places (sample_task_state, make_policy, the set in main) and must be
#      kept in sync by hand.
#    - NOTE(review): num_episodes stays hard-coded at 100 and camera_names
#      stays a literal list, although the removed lines carried comments
#      pointing at SIM_TASK_CONFIGS for both.
#    - NOTE(review): for 'sim_transfer', dataset_dir changes from
#      DATASET_DIR + '/sim_transfer' to task_cfg['dataset_dir']; confirm the
#      configured path is the intended output location.
#    - NOTE(review): both trajectories end at t=700 while episode_len now
#      comes from config; confirm the configured episode_len for the new
#      task does not exceed 700, or that PolicyBase handles steps past the
#      last waypoint.
diff --git a/roboimi/demos/diana_air_insert_policy.py b/roboimi/demos/diana_air_insert_policy.py new file mode 100644 index 0000000..7834ac7 --- /dev/null +++ b/roboimi/demos/diana_air_insert_policy.py @@ -0,0 +1,68 @@ +import numpy as np +from pyquaternion import Quaternion + +from roboimi.demos.diana_policy import PolicyBase + + +class TestAirInsertPolicy(PolicyBase): + def generate_trajectory(self, task_state): + ring_xyz = np.asarray(task_state["ring_pos"], dtype=np.float64) + bar_xyz = np.asarray(task_state["bar_pos"], dtype=np.float64) + + init_mocap_pose_left = np.array( + [ + -0.17297014, + 1.00485877, + 1.32773627, + 7.06825181e-01, + 8.20281078e-06, + -7.07388269e-01, + -5.20399313e-06, + ], + dtype=np.float64, + ) + init_mocap_pose_right = np.array( + [ + 0.17297014, + 0.9951369, + 1.32773623, + 2.59463975e-06, + 7.07388269e-01, + 5.59551158e-06, + 7.06825181e-01, + ], + dtype=np.float64, + ) + + left_init_quat = Quaternion(init_mocap_pose_left[3:]) + right_init_quat = Quaternion(init_mocap_pose_right[3:]) + + left_pick_quat = (left_init_quat * Quaternion(axis=[0.0, 1.0, 0.0], degrees=90)).elements + right_pick_quat = (right_init_quat * Quaternion(axis=[0.0, 1.0, 0.0], degrees=90)).elements + right_insert_quat = (right_init_quat * Quaternion(axis=[1.0, 0.0, 0.0], degrees=90)).elements + + meet_xyz = np.array([0.0, 1.0, 1.30], dtype=np.float64) + left_hold_xyz = meet_xyz + np.array([-0.02, 0.0, 0.08], dtype=np.float64) + right_insert_start_xyz = meet_xyz + np.array([0.0, 0.0, 0.10], dtype=np.float64) + right_insert_end_xyz = meet_xyz + np.array([0.0, 0.0, -0.02], dtype=np.float64) + + self.left_trajectory = [ + {"t": 1, "xyz": init_mocap_pose_left[:3], "quat": init_mocap_pose_left[3:], "gripper": 100}, + {"t": 80, "xyz": ring_xyz + np.array([0.0, 0.0, 0.22]), "quat": left_pick_quat, "gripper": 100}, + {"t": 150, "xyz": ring_xyz + np.array([0.0, 0.0, 0.08]), "quat": left_pick_quat, "gripper": 100}, + {"t": 180, "xyz": ring_xyz + np.array([0.0, 0.0, 0.08]), 
"quat": left_pick_quat, "gripper": -100}, + {"t": 260, "xyz": ring_xyz + np.array([0.0, 0.0, 0.24]), "quat": left_pick_quat, "gripper": -100}, + {"t": 420, "xyz": left_hold_xyz, "quat": init_mocap_pose_left[3:], "gripper": -100}, + {"t": 700, "xyz": left_hold_xyz, "quat": init_mocap_pose_left[3:], "gripper": -100}, + ] + + self.right_trajectory = [ + {"t": 1, "xyz": init_mocap_pose_right[:3], "quat": init_mocap_pose_right[3:], "gripper": 100}, + {"t": 80, "xyz": bar_xyz + np.array([0.0, 0.0, 0.22]), "quat": right_pick_quat, "gripper": 100}, + {"t": 150, "xyz": bar_xyz + np.array([0.0, 0.0, 0.08]), "quat": right_pick_quat, "gripper": 100}, + {"t": 180, "xyz": bar_xyz + np.array([0.0, 0.0, 0.08]), "quat": right_pick_quat, "gripper": -100}, + {"t": 260, "xyz": bar_xyz + np.array([0.0, 0.0, 0.26]), "quat": right_pick_quat, "gripper": -100}, + {"t": 420, "xyz": right_insert_start_xyz, "quat": right_insert_quat, "gripper": -100}, + {"t": 580, "xyz": right_insert_end_xyz, "quat": right_insert_quat, "gripper": -100}, + {"t": 700, "xyz": right_insert_end_xyz, "quat": right_insert_quat, "gripper": -100}, + ] diff --git a/roboimi/demos/diana_record_sim_episodes.py b/roboimi/demos/diana_record_sim_episodes.py index d9d2e2e..19a9a86 100644 --- a/roboimi/demos/diana_record_sim_episodes.py +++ b/roboimi/demos/diana_record_sim_episodes.py @@ -2,9 +2,11 @@ import time import os import numpy as np from roboimi.envs.double_pos_ctrl_env import make_sim_env -from diana_policy import TestPickAndTransferPolicy +from roboimi.demos.diana_air_insert_policy import TestAirInsertPolicy +from roboimi.demos.diana_policy import TestPickAndTransferPolicy import cv2 -from roboimi.utils.act_ex_utils import sample_transfer_pose +from roboimi.utils.act_ex_utils import sample_air_insert_ring_bar_state, sample_transfer_pose +from roboimi.utils.constants import SIM_TASK_CONFIGS from roboimi.utils.streaming_episode_writer import StreamingEpisodeWriter import pathlib @@ -12,16 +14,32 @@ HOME_PATH = 
str(pathlib.Path(__file__).parent.resolve()) DATASET_DIR = HOME_PATH + '/dataset' -def main(): - task_name = 'sim_transfer' - dataset_dir = DATASET_DIR + '/sim_transfer' #SIM_TASK_CONFIGS[task_name]['dataset_dir'] - num_episodes = 100 #SIM_TASK_CONFIGS[task_name]['num_episodes'] +def sample_task_state(task_name): + if task_name == 'sim_transfer': + return sample_transfer_pose() + if task_name == 'sim_air_insert_ring_bar': + return sample_air_insert_ring_bar_state() + raise NotImplementedError(f'Unsupported scripted rollout task: {task_name}') + + +def make_policy(task_name, inject_noise=False): + if task_name == 'sim_transfer': + return TestPickAndTransferPolicy(inject_noise) + if task_name == 'sim_air_insert_ring_bar': + return TestAirInsertPolicy(inject_noise) + raise NotImplementedError(f'Unsupported scripted rollout task: {task_name}') + + +def main(task_name='sim_transfer'): + task_cfg = SIM_TASK_CONFIGS[task_name] + dataset_dir = task_cfg['dataset_dir'] + num_episodes = 100 inject_noise = False - episode_len = 700 #SIM_TASK_CONFIGS[task_name]['episode_len'] - camera_names = ['angle','r_vis', 'top', 'front'] #SIM_TASK_CONFIGS[task_name]['camera_names'] + episode_len = task_cfg['episode_len'] + camera_names = ['angle', 'r_vis', 'top', 'front'] image_size = (256, 256) - if task_name == 'sim_transfer': + if task_name in {'sim_transfer', 'sim_air_insert_ring_bar'}: print(task_name) else: raise NotImplementedError @@ -29,7 +47,7 @@ def main(): success = [] env = make_sim_env(task_name) - policy = TestPickAndTransferPolicy(inject_noise) + policy = make_policy(task_name, inject_noise=inject_noise) # 等待osmesa完全启动后再开始收集数据 print("等待osmesa线程启动...") @@ -41,8 +59,8 @@ def main(): max_reward = float('-inf') print(f'\n{episode_idx=}') print('Rollout out EE space scripted policy') - box_pos = sample_transfer_pose() - env.reset(box_pos) + task_state = sample_task_state(task_name) + env.reset(task_state) episode_writer = StreamingEpisodeWriter( 
dataset_path=os.path.join(dataset_dir, f'episode_{episode_idx}.hdf5'), max_timesteps=episode_len, @@ -50,7 +68,7 @@ def main(): image_size=image_size, ) for step in range(episode_len): - raw_action = policy.predict(box_pos,step) + raw_action = policy.predict(task_state, step) env.step(raw_action) env.render() sum_reward += env.rew diff --git a/tests/test_air_insert_env.py b/tests/test_air_insert_env.py index 8811ba9..236c5e6 100644 --- a/tests/test_air_insert_env.py +++ b/tests/test_air_insert_env.py @@ -300,5 +300,69 @@ class AirInsertRewardAndSuccessTest(unittest.TestCase): self.assertEqual(reward, 5) +class AirInsertPolicyAndSmokeTest(unittest.TestCase): + def test_air_insert_policy_emits_valid_16d_action(self): + policy_module = importlib.import_module("roboimi.demos.diana_air_insert_policy") + policy_cls = getattr(policy_module, "TestAirInsertPolicy", None) + self.assertIsNotNone( + policy_cls, + "Expected roboimi.demos.diana_air_insert_policy.TestAirInsertPolicy", + ) + + task_state = act_ex_utils.sample_air_insert_ring_bar_state() + policy = policy_cls(inject_noise=False) + action = policy.predict(task_state, 0) + + self.assertEqual(action.shape, (16,)) + np.testing.assert_array_equal(action[-2:], np.array([100, 100])) + + def test_scripted_rollout_entrypoint_selects_ring_bar_sampler_and_policy(self): + rollout_module = importlib.import_module("roboimi.demos.diana_record_sim_episodes") + sampler_fn = getattr(rollout_module, "sample_task_state", None) + policy_factory = getattr(rollout_module, "make_policy", None) + self.assertIsNotNone( + sampler_fn, + "Expected roboimi.demos.diana_record_sim_episodes.sample_task_state", + ) + self.assertIsNotNone( + policy_factory, + "Expected roboimi.demos.diana_record_sim_episodes.make_policy", + ) + + task_state = sampler_fn("sim_air_insert_ring_bar") + self.assertEqual( + list(task_state.keys()), + ["ring_pos", "ring_quat", "bar_pos", "bar_quat"], + ) + + policy = policy_factory("sim_air_insert_ring_bar", 
inject_noise=False) + self.assertEqual(policy.__class__.__name__, "TestAirInsertPolicy") + + def test_real_headless_smoke_instantiates_resets_and_steps_new_task_once(self): + policy_module = importlib.import_module("roboimi.demos.diana_air_insert_policy") + policy_cls = getattr(policy_module, "TestAirInsertPolicy", None) + self.assertIsNotNone(policy_cls) + + task_state = act_ex_utils.sample_air_insert_ring_bar_state() + env = make_sim_env("sim_air_insert_ring_bar", headless=True) + policy = policy_cls(inject_noise=False) + + try: + env.reset(task_state) + action = policy.predict(task_state, 0) + env.step(action) + self.assertIsNotNone(env.obs) + self.assertIn("qpos", env.obs) + self.assertIn("images", env.obs) + finally: + env.exit_flag = True + cam_thread = getattr(env, "cam_thread", None) + if cam_thread is not None: + cam_thread.join(timeout=1.0) + viewer = getattr(env, "viewer", None) + if viewer is not None: + viewer.close() + + if __name__ == "__main__": unittest.main()