feat(policy): add scripted air insertion policy

This commit is contained in:
Logic
2026-04-23 17:44:53 +08:00
parent a837a982f7
commit 8145c9eb62
3 changed files with 163 additions and 13 deletions

View File

@@ -0,0 +1,68 @@
import numpy as np
from pyquaternion import Quaternion
from roboimi.demos.diana_policy import PolicyBase
class TestAirInsertPolicy(PolicyBase):
def generate_trajectory(self, task_state):
ring_xyz = np.asarray(task_state["ring_pos"], dtype=np.float64)
bar_xyz = np.asarray(task_state["bar_pos"], dtype=np.float64)
init_mocap_pose_left = np.array(
[
-0.17297014,
1.00485877,
1.32773627,
7.06825181e-01,
8.20281078e-06,
-7.07388269e-01,
-5.20399313e-06,
],
dtype=np.float64,
)
init_mocap_pose_right = np.array(
[
0.17297014,
0.9951369,
1.32773623,
2.59463975e-06,
7.07388269e-01,
5.59551158e-06,
7.06825181e-01,
],
dtype=np.float64,
)
left_init_quat = Quaternion(init_mocap_pose_left[3:])
right_init_quat = Quaternion(init_mocap_pose_right[3:])
left_pick_quat = (left_init_quat * Quaternion(axis=[0.0, 1.0, 0.0], degrees=90)).elements
right_pick_quat = (right_init_quat * Quaternion(axis=[0.0, 1.0, 0.0], degrees=90)).elements
right_insert_quat = (right_init_quat * Quaternion(axis=[1.0, 0.0, 0.0], degrees=90)).elements
meet_xyz = np.array([0.0, 1.0, 1.30], dtype=np.float64)
left_hold_xyz = meet_xyz + np.array([-0.02, 0.0, 0.08], dtype=np.float64)
right_insert_start_xyz = meet_xyz + np.array([0.0, 0.0, 0.10], dtype=np.float64)
right_insert_end_xyz = meet_xyz + np.array([0.0, 0.0, -0.02], dtype=np.float64)
self.left_trajectory = [
{"t": 1, "xyz": init_mocap_pose_left[:3], "quat": init_mocap_pose_left[3:], "gripper": 100},
{"t": 80, "xyz": ring_xyz + np.array([0.0, 0.0, 0.22]), "quat": left_pick_quat, "gripper": 100},
{"t": 150, "xyz": ring_xyz + np.array([0.0, 0.0, 0.08]), "quat": left_pick_quat, "gripper": 100},
{"t": 180, "xyz": ring_xyz + np.array([0.0, 0.0, 0.08]), "quat": left_pick_quat, "gripper": -100},
{"t": 260, "xyz": ring_xyz + np.array([0.0, 0.0, 0.24]), "quat": left_pick_quat, "gripper": -100},
{"t": 420, "xyz": left_hold_xyz, "quat": init_mocap_pose_left[3:], "gripper": -100},
{"t": 700, "xyz": left_hold_xyz, "quat": init_mocap_pose_left[3:], "gripper": -100},
]
self.right_trajectory = [
{"t": 1, "xyz": init_mocap_pose_right[:3], "quat": init_mocap_pose_right[3:], "gripper": 100},
{"t": 80, "xyz": bar_xyz + np.array([0.0, 0.0, 0.22]), "quat": right_pick_quat, "gripper": 100},
{"t": 150, "xyz": bar_xyz + np.array([0.0, 0.0, 0.08]), "quat": right_pick_quat, "gripper": 100},
{"t": 180, "xyz": bar_xyz + np.array([0.0, 0.0, 0.08]), "quat": right_pick_quat, "gripper": -100},
{"t": 260, "xyz": bar_xyz + np.array([0.0, 0.0, 0.26]), "quat": right_pick_quat, "gripper": -100},
{"t": 420, "xyz": right_insert_start_xyz, "quat": right_insert_quat, "gripper": -100},
{"t": 580, "xyz": right_insert_end_xyz, "quat": right_insert_quat, "gripper": -100},
{"t": 700, "xyz": right_insert_end_xyz, "quat": right_insert_quat, "gripper": -100},
]

View File

@@ -2,9 +2,11 @@ import time
import os import os
import numpy as np import numpy as np
from roboimi.envs.double_pos_ctrl_env import make_sim_env from roboimi.envs.double_pos_ctrl_env import make_sim_env
from diana_policy import TestPickAndTransferPolicy from roboimi.demos.diana_air_insert_policy import TestAirInsertPolicy
from roboimi.demos.diana_policy import TestPickAndTransferPolicy
import cv2 import cv2
from roboimi.utils.act_ex_utils import sample_transfer_pose from roboimi.utils.act_ex_utils import sample_air_insert_ring_bar_state, sample_transfer_pose
from roboimi.utils.constants import SIM_TASK_CONFIGS
from roboimi.utils.streaming_episode_writer import StreamingEpisodeWriter from roboimi.utils.streaming_episode_writer import StreamingEpisodeWriter
import pathlib import pathlib
@@ -12,16 +14,32 @@ HOME_PATH = str(pathlib.Path(__file__).parent.resolve())
DATASET_DIR = HOME_PATH + '/dataset' DATASET_DIR = HOME_PATH + '/dataset'
def main(): def sample_task_state(task_name):
task_name = 'sim_transfer' if task_name == 'sim_transfer':
dataset_dir = DATASET_DIR + '/sim_transfer' #SIM_TASK_CONFIGS[task_name]['dataset_dir'] return sample_transfer_pose()
num_episodes = 100 #SIM_TASK_CONFIGS[task_name]['num_episodes'] if task_name == 'sim_air_insert_ring_bar':
return sample_air_insert_ring_bar_state()
raise NotImplementedError(f'Unsupported scripted rollout task: {task_name}')
def make_policy(task_name, inject_noise=False):
if task_name == 'sim_transfer':
return TestPickAndTransferPolicy(inject_noise)
if task_name == 'sim_air_insert_ring_bar':
return TestAirInsertPolicy(inject_noise)
raise NotImplementedError(f'Unsupported scripted rollout task: {task_name}')
def main(task_name='sim_transfer'):
task_cfg = SIM_TASK_CONFIGS[task_name]
dataset_dir = task_cfg['dataset_dir']
num_episodes = 100
inject_noise = False inject_noise = False
episode_len = 700 #SIM_TASK_CONFIGS[task_name]['episode_len'] episode_len = task_cfg['episode_len']
camera_names = ['angle','r_vis', 'top', 'front'] #SIM_TASK_CONFIGS[task_name]['camera_names'] camera_names = ['angle', 'r_vis', 'top', 'front']
image_size = (256, 256) image_size = (256, 256)
if task_name == 'sim_transfer': if task_name in {'sim_transfer', 'sim_air_insert_ring_bar'}:
print(task_name) print(task_name)
else: else:
raise NotImplementedError raise NotImplementedError
@@ -29,7 +47,7 @@ def main():
success = [] success = []
env = make_sim_env(task_name) env = make_sim_env(task_name)
policy = TestPickAndTransferPolicy(inject_noise) policy = make_policy(task_name, inject_noise=inject_noise)
# 等待osmesa完全启动后再开始收集数据 # 等待osmesa完全启动后再开始收集数据
print("等待osmesa线程启动...") print("等待osmesa线程启动...")
@@ -41,8 +59,8 @@ def main():
max_reward = float('-inf') max_reward = float('-inf')
print(f'\n{episode_idx=}') print(f'\n{episode_idx=}')
print('Rollout out EE space scripted policy') print('Rollout out EE space scripted policy')
box_pos = sample_transfer_pose() task_state = sample_task_state(task_name)
env.reset(box_pos) env.reset(task_state)
episode_writer = StreamingEpisodeWriter( episode_writer = StreamingEpisodeWriter(
dataset_path=os.path.join(dataset_dir, f'episode_{episode_idx}.hdf5'), dataset_path=os.path.join(dataset_dir, f'episode_{episode_idx}.hdf5'),
max_timesteps=episode_len, max_timesteps=episode_len,
@@ -50,7 +68,7 @@ def main():
image_size=image_size, image_size=image_size,
) )
for step in range(episode_len): for step in range(episode_len):
raw_action = policy.predict(box_pos,step) raw_action = policy.predict(task_state, step)
env.step(raw_action) env.step(raw_action)
env.render() env.render()
sum_reward += env.rew sum_reward += env.rew

View File

@@ -300,5 +300,69 @@ class AirInsertRewardAndSuccessTest(unittest.TestCase):
self.assertEqual(reward, 5) self.assertEqual(reward, 5)
class AirInsertPolicyAndSmokeTest(unittest.TestCase):
def test_air_insert_policy_emits_valid_16d_action(self):
policy_module = importlib.import_module("roboimi.demos.diana_air_insert_policy")
policy_cls = getattr(policy_module, "TestAirInsertPolicy", None)
self.assertIsNotNone(
policy_cls,
"Expected roboimi.demos.diana_air_insert_policy.TestAirInsertPolicy",
)
task_state = act_ex_utils.sample_air_insert_ring_bar_state()
policy = policy_cls(inject_noise=False)
action = policy.predict(task_state, 0)
self.assertEqual(action.shape, (16,))
np.testing.assert_array_equal(action[-2:], np.array([100, 100]))
def test_scripted_rollout_entrypoint_selects_ring_bar_sampler_and_policy(self):
rollout_module = importlib.import_module("roboimi.demos.diana_record_sim_episodes")
sampler_fn = getattr(rollout_module, "sample_task_state", None)
policy_factory = getattr(rollout_module, "make_policy", None)
self.assertIsNotNone(
sampler_fn,
"Expected roboimi.demos.diana_record_sim_episodes.sample_task_state",
)
self.assertIsNotNone(
policy_factory,
"Expected roboimi.demos.diana_record_sim_episodes.make_policy",
)
task_state = sampler_fn("sim_air_insert_ring_bar")
self.assertEqual(
list(task_state.keys()),
["ring_pos", "ring_quat", "bar_pos", "bar_quat"],
)
policy = policy_factory("sim_air_insert_ring_bar", inject_noise=False)
self.assertEqual(policy.__class__.__name__, "TestAirInsertPolicy")
def test_real_headless_smoke_instantiates_resets_and_steps_new_task_once(self):
policy_module = importlib.import_module("roboimi.demos.diana_air_insert_policy")
policy_cls = getattr(policy_module, "TestAirInsertPolicy", None)
self.assertIsNotNone(policy_cls)
task_state = act_ex_utils.sample_air_insert_ring_bar_state()
env = make_sim_env("sim_air_insert_ring_bar", headless=True)
policy = policy_cls(inject_noise=False)
try:
env.reset(task_state)
action = policy.predict(task_state, 0)
env.step(action)
self.assertIsNotNone(env.obs)
self.assertIn("qpos", env.obs)
self.assertIn("images", env.obs)
finally:
env.exit_flag = True
cam_thread = getattr(env, "cam_thread", None)
if cam_thread is not None:
cam_thread.join(timeout=1.0)
viewer = getattr(env, "viewer", None)
if viewer is not None:
viewer.close()
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()