feat(inference): 添加动作平滑器

This commit is contained in:
gouhanke
2026-02-03 10:32:09 +08:00
parent a977cc4f5e
commit c1ce560b32
2 changed files with 93 additions and 0 deletions

View File

@@ -38,6 +38,11 @@ episode_len: # leave empty here by default
camera_names: [] # leave empty here by default
xml_dir: # leave empty here by default
# action smoothing settings (for GR00T)
use_action_smoothing: true
smooth_method: "ema" # Options: "ema", "moving_avg", "lowpass", "none"
smooth_alpha: 0.3 # Smoothing factor (0-1), smaller = smoother
# transformer settings
batch_size: 15
state_dim: 16

View File

@@ -12,6 +12,71 @@ from roboimi.envs.double_pos_ctrl_env import make_sim_env
from roboimi.utils.act_ex_utils import sample_transfer_pose
class ActionSmoother:
"""
动作平滑器,支持多种平滑策略
"""
def __init__(self, action_dim, method='ema', alpha=0.3, window_size=5):
"""
Args:
action_dim: 动作维度
method: 平滑方法 ('ema', 'moving_avg', 'lowpass', 'none')
alpha: EMA 平滑系数 (0-1),越小越平滑
window_size: 滑动窗口大小
"""
self.action_dim = action_dim
self.method = method
self.alpha = alpha
self.window_size = window_size
self.history = []
self.prev_action = None
def smooth(self, action):
"""
对动作进行平滑处理
Args:
action: 当前动作 [action_dim]
Returns:
smoothed_action: 平滑后的动作
"""
if self.method == 'none':
return action
if self.method == 'ema':
# 指数移动平均
if self.prev_action is None:
smoothed = action
else:
smoothed = self.alpha * action + (1 - self.alpha) * self.prev_action
self.prev_action = smoothed
return smoothed
elif self.method == 'moving_avg':
# 滑动平均
self.history.append(action.copy())
if len(self.history) > self.window_size:
self.history.pop(0)
return np.mean(self.history, axis=0)
elif self.method == 'lowpass':
# 一阶低通滤波器
if self.prev_action is None:
smoothed = action
else:
smoothed = self.prev_action + self.alpha * (action - self.prev_action)
self.prev_action = smoothed
return smoothed
else:
raise ValueError(f"Unknown smoothing method: {self.method}")
def reset(self):
"""重置平滑器状态"""
self.history = []
self.prev_action = None
#should be added into IOUtils
def get_image(obs,camera_names):
@@ -57,6 +122,19 @@ def run_episode(config, policy, stats, save_episode,num_rollouts):
pre_process = lambda s_qpos: (s_qpos - stats['qpos_mean']) / stats['qpos_std']
post_process = lambda a: a * stats['action_std'] + stats['action_mean']
box_pos = sample_transfer_pose()
# 初始化动作平滑器
action_dim = config['action_dim']
use_smoothing = config.get('use_action_smoothing', False)
smooth_method = config.get('smooth_method', 'ema')
smooth_alpha = config.get('smooth_alpha', 0.3)
if use_smoothing and config['policy_class'] == "GR00T":
smoother = ActionSmoother(action_dim, method=smooth_method, alpha=smooth_alpha)
print(f"Action smoothing enabled: method={smooth_method}, alpha={smooth_alpha}")
else:
smoother = None
for rollout_id in range(num_rollouts):
print("\nrollout_id===",rollout_id,"\n")
image_list = []
@@ -64,6 +142,11 @@ def run_episode(config, policy, stats, save_episode,num_rollouts):
query_frequency = config['policy_config'].get('num_queries', 1)
print("query_freq =====",query_frequency)
env.reset(box_pos)
# 重置平滑器
if smoother is not None:
smoother.reset()
with torch.inference_mode():
for t in range(700):
image_list.append(env._get_image_obs()['images'] if 'images' in env._get_image_obs() else {print("img error")})
@@ -83,6 +166,11 @@ def run_episode(config, policy, stats, save_episode,num_rollouts):
action = post_process(raw_action)
# 应用动作平滑(仅对 GR00T
if smoother is not None:
action = smoother.smooth(action)
print("action == ",action)
env.step_jnt(action)
rewards.append(env.rew)