feat(inference): 添加动作平滑器
This commit is contained in:
@@ -38,6 +38,11 @@ episode_len: # leave empty here by default
|
||||
camera_names: [] # leave empty here by default
|
||||
xml_dir: # leave empty here by default
|
||||
|
||||
# action smoothing settings (for GR00T)
|
||||
use_action_smoothing: true
|
||||
smooth_method: "ema" # Options: "ema", "moving_avg", "lowpass", "none"
|
||||
smooth_alpha: 0.3 # Smoothing factor (0-1), smaller = smoother
|
||||
|
||||
# transformer settings
|
||||
batch_size: 15
|
||||
state_dim: 16
|
||||
|
||||
@@ -12,6 +12,71 @@ from roboimi.envs.double_pos_ctrl_env import make_sim_env
|
||||
from roboimi.utils.act_ex_utils import sample_transfer_pose
|
||||
|
||||
|
||||
class ActionSmoother:
    """Smooth a stream of per-step actions with a selectable filter.

    Supported methods:
        'ema'        exponential moving average
        'moving_avg' mean over the last ``window_size`` actions
        'lowpass'    first-order low-pass filter
        'none'       pass the action through untouched

    'ema' and 'lowpass' are algebraically the same recurrence; both
    formulas are kept so numeric results match the original code.
    """

    def __init__(self, action_dim, method='ema', alpha=0.3, window_size=5):
        """
        Args:
            action_dim: Dimension of the action vector (stored only; not
                used by the filters themselves).
            method: Smoothing method ('ema', 'moving_avg', 'lowpass', 'none').
                An unknown value is reported when ``smooth`` is called.
            alpha: Smoothing coefficient in (0-1) for 'ema' and 'lowpass';
                smaller values give smoother output.
            window_size: Number of recent actions averaged by 'moving_avg'.
        """
        self.action_dim = action_dim
        self.method = method
        self.alpha = alpha
        self.window_size = window_size
        self.history = []
        self.prev_action = None

    def smooth(self, action):
        """Return the smoothed version of the current action.

        Args:
            action: Current action, array-like of shape [action_dim].

        Returns:
            The smoothed action. For every method except 'none' this is a
            fresh array that the caller may modify freely; 'none' returns
            ``action`` itself.

        Raises:
            ValueError: If ``self.method`` is not a known smoothing method.
        """
        if self.method == 'none':
            return action

        if self.method in ('ema', 'lowpass'):
            current = np.asarray(action)
            if self.prev_action is None:
                # First sample seeds the filter. Copy so the internal state
                # never aliases the caller's array.
                state = current.copy()
            elif self.method == 'ema':
                state = self.alpha * current + (1 - self.alpha) * self.prev_action
            else:
                state = self.prev_action + self.alpha * (current - self.prev_action)
            self.prev_action = state
            # Hand back a copy: in-place edits by the caller must not leak
            # into the filter state used for the next step.
            return state.copy()

        if self.method == 'moving_avg':
            self.history.append(np.asarray(action).copy())
            if len(self.history) > self.window_size:
                self.history.pop(0)
            return np.mean(self.history, axis=0)

        raise ValueError(f"Unknown smoothing method: {self.method}")

    def reset(self):
        """Clear the filter state so the next action starts a new sequence."""
        self.history = []
        self.prev_action = None
|
||||
|
||||
|
||||
# TODO: get_image should be moved into IOUtils
|
||||
def get_image(obs,camera_names):
|
||||
@@ -57,6 +122,19 @@ def run_episode(config, policy, stats, save_episode,num_rollouts):
|
||||
pre_process = lambda s_qpos: (s_qpos - stats['qpos_mean']) / stats['qpos_std']
|
||||
post_process = lambda a: a * stats['action_std'] + stats['action_mean']
|
||||
box_pos = sample_transfer_pose()
|
||||
|
||||
# 初始化动作平滑器
|
||||
action_dim = config['action_dim']
|
||||
use_smoothing = config.get('use_action_smoothing', False)
|
||||
smooth_method = config.get('smooth_method', 'ema')
|
||||
smooth_alpha = config.get('smooth_alpha', 0.3)
|
||||
|
||||
if use_smoothing and config['policy_class'] == "GR00T":
|
||||
smoother = ActionSmoother(action_dim, method=smooth_method, alpha=smooth_alpha)
|
||||
print(f"Action smoothing enabled: method={smooth_method}, alpha={smooth_alpha}")
|
||||
else:
|
||||
smoother = None
|
||||
|
||||
for rollout_id in range(num_rollouts):
|
||||
print("\nrollout_id===",rollout_id,"\n")
|
||||
image_list = []
|
||||
@@ -64,6 +142,11 @@ def run_episode(config, policy, stats, save_episode,num_rollouts):
|
||||
query_frequency = config['policy_config'].get('num_queries', 1)
|
||||
print("query_freq =====",query_frequency)
|
||||
env.reset(box_pos)
|
||||
|
||||
# 重置平滑器
|
||||
if smoother is not None:
|
||||
smoother.reset()
|
||||
|
||||
with torch.inference_mode():
|
||||
for t in range(700):
|
||||
image_list.append(env._get_image_obs()['images'] if 'images' in env._get_image_obs() else {print("img error")})
|
||||
@@ -83,6 +166,11 @@ def run_episode(config, policy, stats, save_episode,num_rollouts):
|
||||
|
||||
|
||||
action = post_process(raw_action)
|
||||
|
||||
# 应用动作平滑(仅对 GR00T)
|
||||
if smoother is not None:
|
||||
action = smoother.smooth(action)
|
||||
|
||||
print("action == ",action)
|
||||
env.step_jnt(action)
|
||||
rewards.append(env.rew)
|
||||
|
||||
Reference in New Issue
Block a user