feat(inference): 添加动作平滑器
This commit is contained in:
@@ -38,6 +38,11 @@ episode_len: # leave empty here by default
|
|||||||
camera_names: [] # leave empty here by default
|
camera_names: [] # leave empty here by default
|
||||||
xml_dir: # leave empty here by default
|
xml_dir: # leave empty here by default
|
||||||
|
|
||||||
|
# action smoothing settings (for GR00T)
|
||||||
|
use_action_smoothing: true
|
||||||
|
smooth_method: "ema" # Options: "ema", "moving_avg", "lowpass", "none"
|
||||||
|
smooth_alpha: 0.3 # Smoothing factor (0-1), smaller = smoother
|
||||||
|
|
||||||
# transformer settings
|
# transformer settings
|
||||||
batch_size: 15
|
batch_size: 15
|
||||||
state_dim: 16
|
state_dim: 16
|
||||||
|
|||||||
@@ -12,6 +12,71 @@ from roboimi.envs.double_pos_ctrl_env import make_sim_env
|
|||||||
from roboimi.utils.act_ex_utils import sample_transfer_pose
|
from roboimi.utils.act_ex_utils import sample_transfer_pose
|
||||||
|
|
||||||
|
|
||||||
|
class ActionSmoother:
    """Stateful filter that smooths a stream of action vectors.

    Supported methods:

    * ``'ema'``        - exponential moving average,
      ``alpha * action + (1 - alpha) * prev``.
    * ``'lowpass'``    - first-order low-pass filter,
      ``prev + alpha * (action - prev)`` (algebraically the same recurrence
      as ``'ema'``, written in incremental form).
    * ``'moving_avg'`` - mean of the last ``window_size`` actions.
    * ``'none'``       - pass-through; the input is returned untouched.

    The smoother never shares array storage with the caller: inputs are
    copied before being stored and outputs are copies of the internal state,
    so in-place edits made by the caller cannot corrupt the filter.
    """

    def __init__(self, action_dim, method='ema', alpha=0.3, window_size=5):
        """
        Args:
            action_dim: Dimension of the action vector. Stored only; it is
                not used to validate inputs.
            method: Smoothing method ('ema', 'moving_avg', 'lowpass', 'none').
                An unknown value is reported by ``smooth``, not here.
            alpha: Blend factor for 'ema' / 'lowpass' (0-1); smaller values
                give smoother output.
            window_size: Number of most recent actions averaged by
                'moving_avg'.
        """
        self.action_dim = action_dim
        self.method = method
        self.alpha = alpha
        self.window_size = window_size
        self.history = []        # recent actions, used by 'moving_avg' only
        self.prev_action = None  # last filter output, used by 'ema'/'lowpass'

    def smooth(self, action):
        """Smooth a single action.

        Args:
            action: Current action, array-like of shape [action_dim].

        Returns:
            The smoothed action. For every method except 'none' this is a
            freshly allocated NumPy array that the caller may modify freely.

        Raises:
            ValueError: If ``self.method`` is not a supported method.
        """
        if self.method == 'none':
            return action

        if self.method in ('ema', 'lowpass'):
            # Coerce so that list/tuple inputs work with the arithmetic below.
            current = np.asarray(action)
            if self.prev_action is None:
                # First sample seeds the filter. Copy it so the state does
                # not alias the caller's buffer.
                smoothed = np.array(current)
            elif self.method == 'ema':
                smoothed = self.alpha * current + (1 - self.alpha) * self.prev_action
            else:
                smoothed = self.prev_action + self.alpha * (current - self.prev_action)
            self.prev_action = smoothed
            # Hand out a copy: the caller may clip or overwrite the returned
            # action in place, which must not leak into the filter state.
            return smoothed.copy()

        if self.method == 'moving_avg':
            # np.array copies, so later edits to `action` do not alter history.
            self.history.append(np.array(action))
            if len(self.history) > self.window_size:
                self.history.pop(0)
            return np.mean(self.history, axis=0)

        raise ValueError(f"Unknown smoothing method: {self.method}")

    def reset(self):
        """Clear all filter state so the next action seeds the filter anew."""
        self.history = []
        self.prev_action = None
|
||||||
|
|
||||||
|
|
||||||
#should be added into IOUtils
|
#should be added into IOUtils
|
||||||
def get_image(obs,camera_names):
|
def get_image(obs,camera_names):
|
||||||
@@ -57,6 +122,19 @@ def run_episode(config, policy, stats, save_episode,num_rollouts):
|
|||||||
pre_process = lambda s_qpos: (s_qpos - stats['qpos_mean']) / stats['qpos_std']
|
pre_process = lambda s_qpos: (s_qpos - stats['qpos_mean']) / stats['qpos_std']
|
||||||
post_process = lambda a: a * stats['action_std'] + stats['action_mean']
|
post_process = lambda a: a * stats['action_std'] + stats['action_mean']
|
||||||
box_pos = sample_transfer_pose()
|
box_pos = sample_transfer_pose()
|
||||||
|
|
||||||
|
# 初始化动作平滑器
|
||||||
|
action_dim = config['action_dim']
|
||||||
|
use_smoothing = config.get('use_action_smoothing', False)
|
||||||
|
smooth_method = config.get('smooth_method', 'ema')
|
||||||
|
smooth_alpha = config.get('smooth_alpha', 0.3)
|
||||||
|
|
||||||
|
if use_smoothing and config['policy_class'] == "GR00T":
|
||||||
|
smoother = ActionSmoother(action_dim, method=smooth_method, alpha=smooth_alpha)
|
||||||
|
print(f"Action smoothing enabled: method={smooth_method}, alpha={smooth_alpha}")
|
||||||
|
else:
|
||||||
|
smoother = None
|
||||||
|
|
||||||
for rollout_id in range(num_rollouts):
|
for rollout_id in range(num_rollouts):
|
||||||
print("\nrollout_id===",rollout_id,"\n")
|
print("\nrollout_id===",rollout_id,"\n")
|
||||||
image_list = []
|
image_list = []
|
||||||
@@ -64,6 +142,11 @@ def run_episode(config, policy, stats, save_episode,num_rollouts):
|
|||||||
query_frequency = config['policy_config'].get('num_queries', 1)
|
query_frequency = config['policy_config'].get('num_queries', 1)
|
||||||
print("query_freq =====",query_frequency)
|
print("query_freq =====",query_frequency)
|
||||||
env.reset(box_pos)
|
env.reset(box_pos)
|
||||||
|
|
||||||
|
# 重置平滑器
|
||||||
|
if smoother is not None:
|
||||||
|
smoother.reset()
|
||||||
|
|
||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
for t in range(700):
|
for t in range(700):
|
||||||
image_list.append(env._get_image_obs()['images'] if 'images' in env._get_image_obs() else {print("img error")})
|
image_list.append(env._get_image_obs()['images'] if 'images' in env._get_image_obs() else {print("img error")})
|
||||||
@@ -83,6 +166,11 @@ def run_episode(config, policy, stats, save_episode,num_rollouts):
|
|||||||
|
|
||||||
|
|
||||||
action = post_process(raw_action)
|
action = post_process(raw_action)
|
||||||
|
|
||||||
|
# 应用动作平滑(仅对 GR00T)
|
||||||
|
if smoother is not None:
|
||||||
|
action = smoother.smooth(action)
|
||||||
|
|
||||||
print("action == ",action)
|
print("action == ",action)
|
||||||
env.step_jnt(action)
|
env.step_jnt(action)
|
||||||
rewards.append(env.rew)
|
rewards.append(env.rew)
|
||||||
|
|||||||
Reference in New Issue
Block a user