from typing import Optional
import pathlib
import numpy as np
import time
import shutil
import math
from multiprocessing.managers import SharedMemoryManager
from diffusion_policy.real_world.rtde_interpolation_controller import RTDEInterpolationController
from diffusion_policy.real_world.multi_realsense import MultiRealsense, SingleRealsense
from diffusion_policy.real_world.video_recorder import VideoRecorder
from diffusion_policy.common.timestamp_accumulator import (
    TimestampObsAccumulator,
    TimestampActionAccumulator,
    align_timestamps
)
from diffusion_policy.real_world.multi_camera_visualizer import MultiCameraVisualizer
from diffusion_policy.common.replay_buffer import ReplayBuffer
from diffusion_policy.common.cv2_util import (
    get_image_transform, optimal_row_cols)

# Maps keys of the robot state dict (as returned by the controller's
# get_all_state()) to the key names used in observations. Robot state keys
# that are not listed here are left out of the observation.
DEFAULT_OBS_KEY_MAP = {
    # robot
    'ActualTCPPose': 'robot_eef_pose',
    'ActualTCPSpeed': 'robot_eef_pose_vel',
    'ActualQ': 'robot_joint',
    'ActualQd': 'robot_joint_vel',
    # timestamps
    'step_idx': 'step_idx',
    'timestamp': 'timestamp'
}


class RealEnv:
    """Environment wrapper around cameras and a robot controller.

    Wires together a MultiRealsense camera group, an
    RTDEInterpolationController and (optionally) a MultiCameraVisualizer.
    Episodes are recorded into a ReplayBuffer stored at
    ``<output_dir>/replay_buffer.zarr`` with per-episode videos under
    ``<output_dir>/videos/<episode_id>/``.

    The instance can be used as a context manager: entering calls
    ``start()`` and leaving calls ``stop()``.
    """

    def __init__(self,
            # required params
            output_dir,
            robot_ip,
            # env params
            frequency=10,
            n_obs_steps=2,
            # obs
            obs_image_resolution=(640,480),
            max_obs_buffer_size=30,
            camera_serial_numbers=None,
            obs_key_map=DEFAULT_OBS_KEY_MAP,
            obs_float32=False,
            # action
            max_pos_speed=0.25,
            max_rot_speed=0.6,
            # robot
            tcp_offset=0.13,
            init_joints=False,
            # video capture params
            video_capture_fps=30,
            video_capture_resolution=(1280,720),
            # saving params
            record_raw_video=True,
            thread_per_video=2,
            video_crf=21,
            # vis params
            enable_multi_cam_vis=True,
            multi_cam_vis_resolution=(1280,720),
            # shared memory
            shm_manager=None
            ):
        """Create the camera, robot and recording components (not started).

        Args:
            output_dir: Directory for the replay buffer and videos. Its
                parent directory must already exist.
            robot_ip: Forwarded to RTDEInterpolationController.
            frequency: Observation/action rate; must not exceed
                ``video_capture_fps``. Used as the time step (1/frequency)
                for observation alignment and for the accumulators.
            n_obs_steps: Number of time steps returned by ``get_obs``.
            obs_image_resolution: Output resolution of observation images.
            max_obs_buffer_size: Passed as ``get_max_k`` to both the camera
                group and the robot controller.
            camera_serial_numbers: Cameras to use; when None, the connected
                devices reported by SingleRealsense are used.
            obs_key_map: Robot state key -> observation key mapping.
            obs_float32: When True, observation images are converted to
                float32 and divided by 255.
            max_pos_speed: Stored on the instance; passed to the controller
                multiplied by the norm of [1, 1, 1].
            max_rot_speed: Same treatment as ``max_pos_speed``.
            tcp_offset: Third element of the controller's
                ``tcp_offset_pose``; the other elements are 0.
            init_joints: When True, the controller receives the joint
                configuration [0, -90, -90, -90, 90, 0] degrees (converted
                to radians) as ``joints_init``; otherwise None.
            video_capture_fps: Camera capture rate.
            video_capture_resolution: Camera capture resolution.
            record_raw_video: When True, videos are recorded without a
                recording transform at ``video_capture_fps`` as 'bgr24';
                otherwise the observation transform is applied and videos
                are recorded at ``frequency`` as 'rgb24'.
            thread_per_video: ``thread_count`` for the video recorder.
            video_crf: ``crf`` for the video recorder.
            enable_multi_cam_vis: Whether to create the visualizer.
            multi_cam_vis_resolution: Maximum resolution given to
                ``optimal_row_cols`` for the visualizer layout.
            shm_manager: Shared memory manager to use. When None, a new
                SharedMemoryManager is created and started here.

        Raises:
            AssertionError: If ``frequency > video_capture_fps`` or the
                parent of ``output_dir`` is not a directory.
        """
        assert frequency <= video_capture_fps
        output_dir = pathlib.Path(output_dir)
        assert output_dir.parent.is_dir()
        video_dir = output_dir.joinpath('videos')
        video_dir.mkdir(parents=True, exist_ok=True)
        zarr_path = str(output_dir.joinpath('replay_buffer.zarr').absolute())
        replay_buffer = ReplayBuffer.create_from_path(
            zarr_path=zarr_path, mode='a')

        if shm_manager is None:
            # NOTE(review): a manager created here is started but nothing
            # in this class shuts it down - confirm who owns its shutdown.
            shm_manager = SharedMemoryManager()
            shm_manager.start()
        if camera_serial_numbers is None:
            camera_serial_numbers = SingleRealsense.get_connected_devices_serial()

        # Observation transform: capture resolution -> obs resolution.
        color_tf = get_image_transform(
            input_res=video_capture_resolution,
            output_res=obs_image_resolution,
            # obs output rgb
            bgr_to_rgb=True)
        if obs_float32:
            def color_transform(x):
                return color_tf(x).astype(np.float32) / 255
        else:
            color_transform = color_tf

        def transform(data):
            data['color'] = color_transform(data['color'])
            return data

        # Visualization transform: capture resolution -> grid cell size.
        rw, rh, col, row = optimal_row_cols(
            n_cameras=len(camera_serial_numbers),
            in_wh_ratio=obs_image_resolution[0]/obs_image_resolution[1],
            max_resolution=multi_cam_vis_resolution
        )
        vis_color_transform = get_image_transform(
            input_res=video_capture_resolution,
            output_res=(rw,rh),
            bgr_to_rgb=False
        )

        def vis_transform(data):
            data['color'] = vis_color_transform(data['color'])
            return data

        # Raw recording keeps the camera frames untouched; otherwise the
        # recorded video goes through the observation transform.
        if record_raw_video:
            recording_transform = None
            recording_fps = video_capture_fps
            recording_pix_fmt = 'bgr24'
        else:
            recording_transform = transform
            recording_fps = frequency
            recording_pix_fmt = 'rgb24'

        video_recorder = VideoRecorder.create_h264(
            fps=recording_fps,
            codec='h264',
            input_pix_fmt=recording_pix_fmt,
            crf=video_crf,
            thread_type='FRAME',
            thread_count=thread_per_video)

        realsense = MultiRealsense(
            serial_numbers=camera_serial_numbers,
            shm_manager=shm_manager,
            resolution=video_capture_resolution,
            capture_fps=video_capture_fps,
            put_fps=video_capture_fps,
            # send every frame immediately after arrival
            # ignores put_fps
            put_downsample=False,
            record_fps=recording_fps,
            enable_color=True,
            enable_depth=False,
            enable_infrared=False,
            get_max_k=max_obs_buffer_size,
            transform=transform,
            vis_transform=vis_transform,
            recording_transform=recording_transform,
            video_recorder=video_recorder,
            verbose=False
        )

        multi_cam_vis = None
        if enable_multi_cam_vis:
            multi_cam_vis = MultiCameraVisualizer(
                realsense=realsense,
                row=row,
                col=col,
                rgb_to_bgr=False
            )

        # Speed limits handed to the controller are scaled by |[1, 1, 1]|.
        cube_diag = np.linalg.norm([1,1,1])
        j_init = None
        if init_joints:
            # degrees -> radians
            j_init = np.array([0,-90,-90,-90,90,0]) / 180 * np.pi

        robot = RTDEInterpolationController(
            shm_manager=shm_manager,
            robot_ip=robot_ip,
            frequency=125,  # UR5 CB3 RTDE
            lookahead_time=0.1,
            gain=300,
            max_pos_speed=max_pos_speed*cube_diag,
            max_rot_speed=max_rot_speed*cube_diag,
            launch_timeout=3,
            tcp_offset_pose=[0,0,tcp_offset,0,0,0],
            payload_mass=None,
            payload_cog=None,
            joints_init=j_init,
            joints_init_speed=1.05,
            soft_real_time=False,
            verbose=False,
            receive_keys=None,
            get_max_k=max_obs_buffer_size
        )

        self.realsense = realsense
        self.robot = robot
        self.multi_cam_vis = multi_cam_vis
        self.video_capture_fps = video_capture_fps
        self.frequency = frequency
        self.n_obs_steps = n_obs_steps
        self.max_obs_buffer_size = max_obs_buffer_size
        self.max_pos_speed = max_pos_speed
        self.max_rot_speed = max_rot_speed
        self.obs_key_map = obs_key_map
        # recording
        self.output_dir = output_dir
        self.video_dir = video_dir
        self.replay_buffer = replay_buffer
        # temp memory buffers
        self.last_realsense_data = None
        # recording buffers
        self.obs_accumulator = None
        self.action_accumulator = None
        self.stage_accumulator = None

        self.start_time = None

    # ======== start-stop API =============
    @property
    def is_ready(self):
        """True when both the camera group and the robot report ready."""
        return self.realsense.is_ready and self.robot.is_ready

    def start(self, wait=True):
        """Start cameras, robot and visualizer; optionally wait for them."""
        self.realsense.start(wait=False)
        self.robot.start(wait=False)
        if self.multi_cam_vis is not None:
            self.multi_cam_vis.start(wait=False)
        if wait:
            self.start_wait()

    def stop(self, wait=True):
        """End the current episode and stop all components.

        The components are asked to stop even when ``end_episode`` raises
        (for example because the environment is not ready), so a failure
        while finishing the episode does not leave them running. The
        exception from ``end_episode`` still propagates to the caller.
        """
        try:
            self.end_episode()
        finally:
            if self.multi_cam_vis is not None:
                self.multi_cam_vis.stop(wait=False)
            self.robot.stop(wait=False)
            self.realsense.stop(wait=False)
            if wait:
                self.stop_wait()

    def start_wait(self):
        """Block until cameras, robot and visualizer finished starting."""
        self.realsense.start_wait()
        self.robot.start_wait()
        if self.multi_cam_vis is not None:
            self.multi_cam_vis.start_wait()

    def stop_wait(self):
        """Block until robot, cameras and visualizer finished stopping."""
        self.robot.stop_wait()
        self.realsense.stop_wait()
        if self.multi_cam_vis is not None:
            self.multi_cam_vis.stop_wait()

    # ========= context manager ===========
    def __enter__(self):
        self.start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.stop()

    # ========= async env API ===========
    @staticmethod
    def _latest_idxs_before(timestamps, query_timestamps):
        """For each query time, index of the last sample strictly before it.

        Falls back to index 0 when no sample is earlier than the query.
        Returns a list with one index per entry of ``query_timestamps``.
        """
        idxs = list()
        for t in query_timestamps:
            before = np.nonzero(timestamps < t)[0]
            idxs.append(before[-1] if len(before) > 0 else 0)
        return idxs

    def get_obs(self) -> dict:
        """Return the latest ``n_obs_steps`` observations as a dict.

        Camera frames and robot states are sampled at timestamps spaced
        ``1 / frequency`` apart and ending at the newest camera timestamp.
        For each of these timestamps the last sample strictly before it is
        used (or the first available sample if there is none).

        The result contains ``camera_<idx>`` image arrays, the robot
        state entries renamed through ``obs_key_map`` and ``timestamp``
        (the aligned timestamps). When an episode is being recorded, the
        full robot state history fetched here is also pushed into the
        observation accumulator.

        Raises:
            AssertionError: If the environment is not ready.
        """
        assert self.is_ready

        # get data
        # 30 Hz, camera_receive_timestamp
        n_frames = math.ceil(
            self.n_obs_steps * (self.video_capture_fps / self.frequency))
        self.last_realsense_data = self.realsense.get(
            k=n_frames,
            out=self.last_realsense_data)

        # 125 hz, robot_receive_timestamp
        last_robot_data = self.robot.get_all_state()
        # both have more than n_obs_steps data

        # align camera obs timestamps
        dt = 1 / self.frequency
        last_timestamp = np.max(
            [x['timestamp'][-1] for x in self.last_realsense_data.values()])
        obs_align_timestamps = last_timestamp - (
            np.arange(self.n_obs_steps)[::-1] * dt)

        camera_obs = dict()
        for camera_idx, value in self.last_realsense_data.items():
            cam_idxs = self._latest_idxs_before(
                value['timestamp'], obs_align_timestamps)
            # remap key
            camera_obs[f'camera_{camera_idx}'] = value['color'][cam_idxs]

        # align robot obs
        robot_timestamps = last_robot_data['robot_receive_timestamp']
        robot_idxs = self._latest_idxs_before(
            robot_timestamps, obs_align_timestamps)

        # keep only the mapped robot keys, under their observation names
        robot_obs_raw = dict()
        for key, value in last_robot_data.items():
            if key in self.obs_key_map:
                robot_obs_raw[self.obs_key_map[key]] = value

        robot_obs = dict()
        for key, value in robot_obs_raw.items():
            robot_obs[key] = value[robot_idxs]

        # accumulate obs
        if self.obs_accumulator is not None:
            self.obs_accumulator.put(
                robot_obs_raw,
                robot_timestamps
            )

        # return obs
        obs_data = dict(camera_obs)
        obs_data.update(robot_obs)
        obs_data['timestamp'] = obs_align_timestamps
        return obs_data

    def exec_actions(self,
            actions: np.ndarray,
            timestamps: np.ndarray,
            stages: Optional[np.ndarray]=None):
        """Schedule future actions on the robot and record them.

        Only entries whose timestamp is later than the current
        ``time.time()`` are scheduled and recorded; the rest are dropped.

        Args:
            actions: One pose per entry; converted to ndarray if needed.
            timestamps: Target time per action; converted to ndarray if
                needed.
            stages: Optional integer stage per action. Defaults to zeros.

        Raises:
            AssertionError: If the environment is not ready.
        """
        assert self.is_ready
        if not isinstance(actions, np.ndarray):
            actions = np.array(actions)
        if not isinstance(timestamps, np.ndarray):
            timestamps = np.array(timestamps)
        if stages is None:
            stages = np.zeros_like(timestamps, dtype=np.int64)
        elif not isinstance(stages, np.ndarray):
            stages = np.array(stages, dtype=np.int64)

        # convert action to pose
        receive_time = time.time()
        is_new = timestamps > receive_time
        new_actions = actions[is_new]
        new_timestamps = timestamps[is_new]
        new_stages = stages[is_new]

        # schedule waypoints
        for pose, target_time in zip(new_actions, new_timestamps):
            self.robot.schedule_waypoint(
                pose=pose,
                target_time=target_time
            )

        # record actions
        if self.action_accumulator is not None:
            self.action_accumulator.put(
                new_actions,
                new_timestamps
            )
        if self.stage_accumulator is not None:
            self.stage_accumulator.put(
                new_stages,
                new_timestamps
            )

    def get_robot_state(self):
        """Return the robot controller's current state (``get_state()``)."""
        return self.robot.get_state()

    # recording API
    def start_episode(self, start_time=None):
        """Start recording a new episode.

        Creates the video directory for the next episode id, restarts
        frame delivery and starts video recording on the cameras, and
        creates fresh observation/action/stage accumulators. Returns None.

        Args:
            start_time: Episode start time; defaults to ``time.time()``.

        Raises:
            AssertionError: If the environment is not ready.
        """
        if start_time is None:
            start_time = time.time()
        self.start_time = start_time

        assert self.is_ready

        # prepare recording stuff
        episode_id = self.replay_buffer.n_episodes
        this_video_dir = self.video_dir.joinpath(str(episode_id))
        this_video_dir.mkdir(parents=True, exist_ok=True)
        n_cameras = self.realsense.n_cameras
        video_paths = [
            str(this_video_dir.joinpath(f'{i}.mp4').absolute())
            for i in range(n_cameras)
        ]

        # start recording on realsense
        self.realsense.restart_put(start_time=start_time)
        self.realsense.start_recording(
            video_path=video_paths, start_time=start_time)

        # create accumulators
        self.obs_accumulator = TimestampObsAccumulator(
            start_time=start_time,
            dt=1/self.frequency
        )
        self.action_accumulator = TimestampActionAccumulator(
            start_time=start_time,
            dt=1/self.frequency
        )
        self.stage_accumulator = TimestampActionAccumulator(
            start_time=start_time,
            dt=1/self.frequency
        )
        print(f'Episode {episode_id} started!')

    def end_episode(self):
        """Stop recording and save the accumulated episode, if any.

        Video recording is always stopped. If accumulators exist, the
        data they hold is truncated to the shorter of the observation and
        action timelines and, when non-empty, added to the replay buffer.
        The accumulators are then cleared.

        Raises:
            AssertionError: If the environment is not ready.
        """
        assert self.is_ready

        # stop video recorder
        self.realsense.stop_recording()

        if self.obs_accumulator is not None:
            # recording
            assert self.action_accumulator is not None
            assert self.stage_accumulator is not None

            # Since the only way to accumulate obs and action is by calling
            # get_obs and exec_actions, which will be in the same thread.
            # We don't need to worry new data come in here.
            obs_data = self.obs_accumulator.data
            obs_timestamps = self.obs_accumulator.timestamps

            actions = self.action_accumulator.actions
            action_timestamps = self.action_accumulator.timestamps
            stages = self.stage_accumulator.actions
            n_steps = min(len(obs_timestamps), len(action_timestamps))
            if n_steps > 0:
                episode = dict()
                episode['timestamp'] = obs_timestamps[:n_steps]
                episode['action'] = actions[:n_steps]
                episode['stage'] = stages[:n_steps]
                for key, value in obs_data.items():
                    episode[key] = value[:n_steps]
                self.replay_buffer.add_episode(episode, compressors='disk')
                episode_id = self.replay_buffer.n_episodes - 1
                print(f'Episode {episode_id} saved!')

            self.obs_accumulator = None
            self.action_accumulator = None
            self.stage_accumulator = None

    def drop_episode(self):
        """End the current episode, then drop an episode from the buffer.

        After ``replay_buffer.drop_episode()``, the video directory named
        after the resulting episode count is removed if it exists.
        """
        self.end_episode()
        self.replay_buffer.drop_episode()
        episode_id = self.replay_buffer.n_episodes
        this_video_dir = self.video_dir.joinpath(str(episode_id))
        if this_video_dir.exists():
            shutil.rmtree(str(this_video_dir))
        print(f'Episode {episode_id} dropped!')