diff --git a/roboimi/demos/config.yaml b/roboimi/demos/config.yaml index 3b16eb1..efb6f1c 100644 --- a/roboimi/demos/config.yaml +++ b/roboimi/demos/config.yaml @@ -38,6 +38,11 @@ episode_len: # leave empty here by default camera_names: [] # leave empty here by default xml_dir: # leave empty here by default +# action smoothing settings (for GR00T) +use_action_smoothing: true +smooth_method: "ema" # Options: "ema", "moving_avg", "lowpass", "none" +smooth_alpha: 0.3 # Smoothing factor (0-1), smaller = smoother + # transformer settings batch_size: 15 state_dim: 16 diff --git a/roboimi/demos/diana_eval.py b/roboimi/demos/diana_eval.py index a5e71e5..e6994d4 100644 --- a/roboimi/demos/diana_eval.py +++ b/roboimi/demos/diana_eval.py @@ -12,6 +12,71 @@ from roboimi.envs.double_pos_ctrl_env import make_sim_env from roboimi.utils.act_ex_utils import sample_transfer_pose +class ActionSmoother: + """ + 动作平滑器,支持多种平滑策略 + """ + def __init__(self, action_dim, method='ema', alpha=0.3, window_size=5): + """ + Args: + action_dim: 动作维度 + method: 平滑方法 ('ema', 'moving_avg', 'lowpass', 'none') + alpha: EMA 平滑系数 (0-1),越小越平滑 + window_size: 滑动窗口大小 + """ + self.action_dim = action_dim + self.method = method + self.alpha = alpha + self.window_size = window_size + self.history = [] + self.prev_action = None + + def smooth(self, action): + """ + 对动作进行平滑处理 + + Args: + action: 当前动作 [action_dim] + + Returns: + smoothed_action: 平滑后的动作 + """ + if self.method == 'none': + return action + + if self.method == 'ema': + # 指数移动平均 + if self.prev_action is None: + smoothed = action + else: + smoothed = self.alpha * action + (1 - self.alpha) * self.prev_action + self.prev_action = smoothed + return smoothed + + elif self.method == 'moving_avg': + # 滑动平均 + self.history.append(action.copy()) + if len(self.history) > self.window_size: + self.history.pop(0) + return np.mean(self.history, axis=0) + + elif self.method == 'lowpass': + # 一阶低通滤波器 + if self.prev_action is None: + smoothed = action + else: + smoothed = self.prev_action + self.alpha * (action - self.prev_action) + self.prev_action = smoothed + return smoothed + + else: + raise ValueError(f"Unknown smoothing method: {self.method}") + + def reset(self): + """重置平滑器状态""" + self.history = [] + self.prev_action = None + #should be added into IOUtils def get_image(obs,camera_names): @@ -57,6 +122,19 @@ def run_episode(config, policy, stats, save_episode,num_rollouts): pre_process = lambda s_qpos: (s_qpos - stats['qpos_mean']) / stats['qpos_std'] post_process = lambda a: a * stats['action_std'] + stats['action_mean'] box_pos = sample_transfer_pose() + + # 初始化动作平滑器 + action_dim = config['action_dim'] + use_smoothing = config.get('use_action_smoothing', False) + smooth_method = config.get('smooth_method', 'ema') + smooth_alpha = config.get('smooth_alpha', 0.3) + + if use_smoothing and config['policy_class'] == "GR00T": + smoother = ActionSmoother(action_dim, method=smooth_method, alpha=smooth_alpha) + print(f"Action smoothing enabled: method={smooth_method}, alpha={smooth_alpha}") + else: + smoother = None + for rollout_id in range(num_rollouts): print("\nrollout_id===",rollout_id,"\n") image_list = [] @@ -64,6 +142,11 @@ def run_episode(config, policy, stats, save_episode,num_rollouts): query_frequency = config['policy_config'].get('num_queries', 1) print("query_freq =====",query_frequency) env.reset(box_pos) + + # 重置平滑器 + if smoother is not None: + smoother.reset() + with torch.inference_mode(): for t in range(700): image_list.append(env._get_image_obs()['images'] if 'images' in env._get_image_obs() else {print("img error")}) @@ -83,6 +166,11 @@ def run_episode(config, policy, stats, save_episode,num_rollouts): action = post_process(raw_action) + + # 应用动作平滑(仅对 GR00T) + if smoother is not None: + action = smoother.smooth(action) + print("action == ",action) env.step_jnt(action) rewards.append(env.rew)