fix: support headless MuJoCo image capture during rollout

This commit is contained in:
Logic
2026-04-02 08:25:01 +08:00
parent 0514f86c36
commit 3a17744dcf
3 changed files with 79 additions and 16 deletions

View File

@@ -57,6 +57,7 @@ class DualDianaMed(MujocoEnv):
self.obs = None self.obs = None
self.rew = None self.rew = None
self._offscreen_renderer = None
def actuate_J(self, q_target, qdot_target, Arm): def actuate_J(self, q_target, qdot_target, Arm):
@@ -161,6 +162,8 @@ class DualDianaMed(MujocoEnv):
def _get_obs(self): def _get_obs(self):
if not self.is_render:
self._update_camera_images_sync()
obs = collections.OrderedDict() obs = collections.OrderedDict()
obs['qpos'] = self.get_obs_qpos obs['qpos'] = self.get_obs_qpos
obs['action'] = self.compute_qpos obs['action'] = self.compute_qpos
@@ -173,6 +176,8 @@ class DualDianaMed(MujocoEnv):
return obs return obs
def _get_image_obs(self): def _get_image_obs(self):
if not self.is_render:
self._update_camera_images_sync()
obs = collections.OrderedDict() obs = collections.OrderedDict()
obs['images'] = dict() obs['images'] = dict()
obs['images']['top'] = self.top obs['images']['top'] = self.top
@@ -211,27 +216,36 @@ class DualDianaMed(MujocoEnv):
raise AttributeError("please input right name") raise AttributeError("please input right name")
def _get_or_create_offscreen_renderer(self):
renderer = getattr(self, '_offscreen_renderer', None)
if renderer is None:
renderer = mj.Renderer(self.mj_model, height=480, width=640)
self._offscreen_renderer = renderer
return renderer
def _render_camera_set(self, img_renderer):
img_renderer.update_scene(self.mj_data, camera="rs_cam_right")
self.r_vis = img_renderer.render()[:, :, ::-1]
img_renderer.update_scene(self.mj_data, camera="rs_cam_left")
self.l_vis = img_renderer.render()[:, :, ::-1]
img_renderer.update_scene(self.mj_data, camera="top")
self.top = img_renderer.render()[:, :, ::-1]
img_renderer.update_scene(self.mj_data, camera="angle")
self.angle = img_renderer.render()[:, :, ::-1]
img_renderer.update_scene(self.mj_data, camera="front")
self.front = img_renderer.render()[:, :, ::-1]
def _update_camera_images_sync(self):
img_renderer = self._get_or_create_offscreen_renderer()
self._render_camera_set(img_renderer)
def camera_viewer(self): def camera_viewer(self):
img_renderer = mj.Renderer(self.mj_model,height=480,width=640) img_renderer = self._get_or_create_offscreen_renderer()
show_gui = self.is_render show_gui = self.is_render
if show_gui: if show_gui:
cv2.namedWindow('Cam view',cv2.WINDOW_NORMAL) cv2.namedWindow('Cam view',cv2.WINDOW_NORMAL)
while not self.exit_flag: while not self.exit_flag:
img_renderer.update_scene(self.mj_data,camera="rs_cam_right") self._render_camera_set(img_renderer)
self.r_vis = img_renderer.render()
self.r_vis = self.r_vis[:, :, ::-1]
img_renderer.update_scene(self.mj_data,camera="rs_cam_left")
self.l_vis = img_renderer.render()
self.l_vis = self.l_vis[:, :, ::-1]
img_renderer.update_scene(self.mj_data,camera="top")
self.top = img_renderer.render()
self.top = self.top[:, :, ::-1]
img_renderer.update_scene(self.mj_data,camera="angle")
self.angle = img_renderer.render()
self.angle = self.angle[:, :, ::-1]
img_renderer.update_scene(self.mj_data,camera="front")
self.front = img_renderer.render()
self.front = self.front[:, :, ::-1]
if show_gui: if show_gui:
if self.cam_view is not None: if self.cam_view is not None:
cv2.imshow('Cam view', self.cam_view) cv2.imshow('Cam view', self.cam_view)
@@ -239,6 +253,9 @@ class DualDianaMed(MujocoEnv):
def cam_start(self): def cam_start(self):
if not self.is_render:
self.cam_thread = None
return
self.cam_thread = threading.Thread(target=self.camera_viewer,daemon=True) self.cam_thread = threading.Thread(target=self.camera_viewer,daemon=True)
self.cam_thread.start() self.cam_thread.start()

View File

@@ -76,6 +76,9 @@ class DualDianaMed_Pos_Ctrl(DualDianaMed):
self.angle = None self.angle = None
self.r_vis = None self.r_vis = None
self.front = None self.front = None
if not self.is_render:
self._update_camera_images_sync()
return
self.cam_flage = True self.cam_flage = True
t=0 t=0
while self.cam_flage: while self.cam_flage:

View File

@@ -129,6 +129,49 @@ class EvalVLAHeadlessTest(unittest.TestCase):
cam_view="angle", cam_view="angle",
) )
def test_headless_sync_camera_capture_populates_images_without_gui_calls(self):
env = DualDianaMed.__new__(DualDianaMed)
env.mj_model = object()
env.mj_data = object()
env.exit_flag = False
env.is_render = False
env.cam = 'angle'
env.r_vis = None
env.l_vis = None
env.top = None
env.angle = None
env.front = None
env._offscreen_renderer = None
with mock.patch(
'roboimi.envs.double_base.mj.Renderer',
side_effect=lambda *args, **kwargs: _FakeRenderer(env),
) as renderer_cls, mock.patch('roboimi.envs.double_base.cv2.namedWindow') as named_window, mock.patch(
'roboimi.envs.double_base.cv2.imshow'
) as imshow, mock.patch('roboimi.envs.double_base.cv2.waitKey') as wait_key:
env._update_camera_images_sync()
renderer_cls.assert_called_once()
named_window.assert_not_called()
imshow.assert_not_called()
wait_key.assert_not_called()
self.assertIsNotNone(env.r_vis)
self.assertIsNotNone(env.l_vis)
self.assertIsNotNone(env.top)
self.assertIsNotNone(env.angle)
self.assertIsNotNone(env.front)
def test_cam_start_skips_background_thread_when_headless(self):
env = DualDianaMed.__new__(DualDianaMed)
env.is_render = False
env.cam_thread = None
with mock.patch('roboimi.envs.double_base.threading.Thread') as thread_cls:
env.cam_start()
thread_cls.assert_not_called()
self.assertIsNone(env.cam_thread)
def test_camera_viewer_headless_updates_images_without_gui_calls(self): def test_camera_viewer_headless_updates_images_without_gui_calls(self):
env = DualDianaMed.__new__(DualDianaMed) env = DualDianaMed.__new__(DualDianaMed)
env.mj_model = object() env.mj_model = object()