From f4a5c77b7ce84d9199d0ad2a77865ac3be8a96cd Mon Sep 17 00:00:00 2001 From: gouhanke <12219217+gouhanke@user.noreply.gitee.com> Date: Fri, 6 Feb 2026 14:29:36 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E5=BD=92=E4=B8=80=E5=8C=96?= =?UTF-8?q?=E4=BB=8Eagent=E8=A7=A3=E8=80=A6=E5=88=B0=E8=AE=AD=E7=BB=83?= =?UTF-8?q?=E3=80=81=E6=8E=A8=E7=90=86=E8=84=9A=E6=9C=AC=E4=B8=AD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- roboimi/demos/vla_scripts/eval_vla.py | 92 +++++++++++++++----------- roboimi/demos/vla_scripts/train_vla.py | 40 +++++------ roboimi/vla/agent.py | 20 +----- 3 files changed, 72 insertions(+), 80 deletions(-) diff --git a/roboimi/demos/vla_scripts/eval_vla.py b/roboimi/demos/vla_scripts/eval_vla.py index 225fe4e..8264b28 100644 --- a/roboimi/demos/vla_scripts/eval_vla.py +++ b/roboimi/demos/vla_scripts/eval_vla.py @@ -48,7 +48,8 @@ class VLAEvaluator: pred_horizon: int = 16, use_smoothing: bool = False, smooth_method: str = 'ema', - smooth_alpha: float = 0.3 + smooth_alpha: float = 0.3, + dataset_stats: dict = None ): self.agent = agent.to(device) self.device = device @@ -57,6 +58,21 @@ class VLAEvaluator: self.obs_horizon = obs_horizon self.pred_horizon = pred_horizon + # Dataset statistics for normalization/denormalization + self.stats = dataset_stats + if self.stats is not None: + self.normalization_type = self.stats.get('normalization_type', 'gaussian') + self.qpos_mean = torch.tensor(self.stats['qpos_mean'], dtype=torch.float32) + self.qpos_std = torch.tensor(self.stats['qpos_std'], dtype=torch.float32) + self.qpos_min = torch.tensor(self.stats.get('qpos_min', []), dtype=torch.float32) + self.qpos_max = torch.tensor(self.stats.get('qpos_max', []), dtype=torch.float32) + self.action_mean = torch.tensor(self.stats['action_mean'], dtype=torch.float32) + self.action_std = torch.tensor(self.stats['action_std'], dtype=torch.float32) + self.action_min = torch.tensor(self.stats.get('action_min', []), dtype=torch.float32) + self.action_max = torch.tensor(self.stats.get('action_max', []), dtype=torch.float32) + else: + self.normalization_type = None + # Action smoothing self.use_smoothing = use_smoothing self.smooth_method = smooth_method @@ -124,7 +140,15 @@ class VLAEvaluator: if len(self.obs_buffer['qpos']) > self.obs_horizon: self.obs_buffer['qpos'] = self.obs_buffer['qpos'][-self.obs_horizon:] - qpos_tensor = torch.stack(self.obs_buffer['qpos'], dim=0).unsqueeze(0) + qpos_tensor = torch.stack(self.obs_buffer['qpos'], dim=0).unsqueeze(0) # (1, obs_horizon, obs_dim) + + # Normalize qpos + if self.stats is not None: + if self.normalization_type == 'gaussian': + qpos_tensor = (qpos_tensor - self.qpos_mean) / self.qpos_std + else: # min_max: normalize to [-1, 1] + qpos_tensor = 2 * (qpos_tensor - self.qpos_min) / (self.qpos_max - self.qpos_min) - 1 + return qpos_tensor @torch.no_grad() @@ -141,6 +165,13 @@ class VLAEvaluator: proprioception=qpos ) + # Denormalize actions + if self.stats is not None: + if self.normalization_type == 'gaussian': + predicted_actions = predicted_actions * self.action_std.to(self.device) + self.action_mean.to(self.device) + else: # min_max + predicted_actions = (predicted_actions + 1) / 2 * (self.action_max.to(self.device) - self.action_min.to(self.device)) + self.action_min.to(self.device) + self.cached_actions = predicted_actions.squeeze(0).cpu().numpy() self.query_step = 0 @@ -208,36 +239,29 @@ def load_checkpoint( agent = instantiate(agent_cfg) # Load model state - if 'model_state_dict' in checkpoint: - agent.load_state_dict(checkpoint['model_state_dict']) - log.info(f"✅ Model state loaded (step: {checkpoint.get('step', 'unknown')})") - elif 'state_dict' in checkpoint: - agent.load_state_dict(checkpoint['state_dict']) - log.info("✅ Model state loaded") - else: - agent.load_state_dict(checkpoint) - log.info("✅ Model state loaded") + agent.load_state_dict(checkpoint['model_state_dict']) + log.info(f"✅ Model state loaded (step: {checkpoint.get('step', 'unknown')})") # Load dataset statistics for denormalization - stats_path = ckpt_path.parent / 'dataset_stats.json' - if stats_path.exists(): - with open(stats_path, 'r') as f: - stats = json.load(f) - agent.action_mean = np.array(stats['action_mean']) - agent.action_std = np.array(stats['action_std']) - agent.qpos_mean = np.array(stats['qpos_mean']) - agent.qpos_std = np.array(stats['qpos_std']) - log.info("✅ Dataset statistics loaded for denormalization") + stats = checkpoint.get('dataset_stats', None) + + if stats is not None: + log.info(f"✅ Dataset statistics loaded (normalization: {stats.get('normalization_type', 'gaussian')})") else: - log.warning(f"⚠️ {stats_path} not found. Actions will not be denormalized!") - agent.action_mean = None - agent.action_std = None + # Fallback: try external JSON file (兼容旧 checkpoint) + stats_path = ckpt_path.parent / 'dataset_stats.json' + if stats_path.exists(): + with open(stats_path, 'r') as f: + stats = json.load(f) + log.info("✅ Dataset statistics loaded from external JSON (legacy)") + else: + log.warning("⚠️ No dataset statistics found. Actions will not be denormalized!") agent.eval() agent.to(device) log.info(f"✅ Model loaded successfully on {device}") - return agent + return agent, stats @hydra.main(version_base=None, config_path="../../vla/conf", config_name="config") @@ -262,7 +286,7 @@ def main(cfg: DictConfig): # Load model log.info(f"🚀 Loading model from {eval_cfg.ckpt_path}...") - agent = load_checkpoint( + agent, dataset_stats = load_checkpoint( ckpt_path=eval_cfg.ckpt_path, agent_cfg=cfg.agent, device=device @@ -277,7 +301,8 @@ def main(cfg: DictConfig): obs_horizon=eval_cfg.obs_horizon, use_smoothing=eval_cfg.use_smoothing, smooth_method=eval_cfg.smooth_method, - smooth_alpha=eval_cfg.smooth_alpha + smooth_alpha=eval_cfg.smooth_alpha, + dataset_stats=dataset_stats ) # Create environment @@ -293,9 +318,6 @@ def main(cfg: DictConfig): env.reset(box_pos) evaluator.reset() - success = False - success_timestep = 0 - with torch.inference_mode(): for t in tqdm(range(eval_cfg.max_timesteps), desc=f"Episode {episode_idx + 1}"): obs = env._get_image_obs() @@ -307,17 +329,7 @@ def main(cfg: DictConfig): env.render() - if env.rew == 1.0: - success = True - success_timestep = t - print(f"\n✅ Task completed at timestep {t}!") - break - - print(f"\nEpisode {episode_idx + 1} Summary:") - print(f" Success: {success}") - if success: - print(f" Success Timestep: {success_timestep}") - print(f" Length: {t + 1} timesteps") + print(f"\nEpisode {episode_idx + 1} completed ({eval_cfg.max_timesteps} timesteps)") print(f"\n{'='*60}") print("Evaluation complete!") diff --git a/roboimi/demos/vla_scripts/train_vla.py b/roboimi/demos/vla_scripts/train_vla.py index 169a1b8..348d8fd 100644 --- a/roboimi/demos/vla_scripts/train_vla.py +++ b/roboimi/demos/vla_scripts/train_vla.py @@ -106,43 +106,36 @@ def main(cfg: DictConfig): raise # ========================================================================= - # 2.5. Save Dataset Statistics as JSON + # 2.5. Load Dataset Statistics (will be saved into checkpoints) # ========================================================================= - log.info("💾 Saving dataset statistics...") + log.info("💾 Loading dataset statistics...") + dataset_stats = None try: - # Get dataset_dir from config dataset_dir = cfg.data.get('dataset_dir', 'roboimi/demos/dataset/sim_transfer') stats_path = Path(dataset_dir) / 'data_stats.pkl' if stats_path.exists(): - # Load pickle file with open(stats_path, 'rb') as f: stats = pickle.load(f) - # Extract action statistics - action_mean = stats['action']['mean'].tolist() if 'action' in stats else [] - action_std = stats['action']['std'].tolist() if 'action' in stats else [] - qpos_mean = stats['qpos']['mean'].tolist() if 'qpos' in stats else [] - qpos_std = stats['qpos']['std'].tolist() if 'qpos' in stats else [] - - # Save as JSON - json_stats = { - 'action_mean': action_mean, - 'action_std': action_std, - 'qpos_mean': qpos_mean, - 'qpos_std': qpos_std + dataset_stats = { + 'normalization_type': cfg.data.get('normalization_type', 'gaussian'), + 'action_mean': stats['action']['mean'].tolist(), + 'action_std': stats['action']['std'].tolist(), + 'action_min': stats['action']['min'].tolist(), + 'action_max': stats['action']['max'].tolist(), + 'qpos_mean': stats['qpos']['mean'].tolist(), + 'qpos_std': stats['qpos']['std'].tolist(), + 'qpos_min': stats['qpos']['min'].tolist(), + 'qpos_max': stats['qpos']['max'].tolist(), } - json_path = checkpoint_dir / 'dataset_stats.json' - with open(json_path, 'w') as f: - json.dump(json_stats, f, indent=2) - - log.info(f"✅ Dataset statistics saved to {json_path}") + log.info(f"✅ Dataset statistics loaded (normalization: {dataset_stats['normalization_type']})") else: log.warning(f"⚠️ Statistics file not found: {stats_path}") log.warning("⚠️ Actions will not be denormalized during inference!") except Exception as e: - log.warning(f"⚠️ Failed to save statistics as JSON: {e}") + log.warning(f"⚠️ Failed to load statistics: {e}") log.warning("⚠️ Training will continue, but inference may not work correctly") # ========================================================================= @@ -234,6 +227,7 @@ def main(cfg: DictConfig): 'model_state_dict': agent.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'loss': loss.item(), + 'dataset_stats': dataset_stats, }, checkpoint_path) log.info(f"💾 Checkpoint saved: {checkpoint_path}") @@ -246,6 +240,7 @@ def main(cfg: DictConfig): 'model_state_dict': agent.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'loss': loss.item(), + 'dataset_stats': dataset_stats, }, best_model_path) log.info(f"🌟 Best model updated: {best_model_path} (loss: {best_loss:.4f})") @@ -258,6 +253,7 @@ def main(cfg: DictConfig): 'model_state_dict': agent.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'loss': loss.item(), + 'dataset_stats': dataset_stats, }, final_model_path) log.info(f"💾 Final model saved: {final_model_path}") diff --git a/roboimi/vla/agent.py b/roboimi/vla/agent.py index 2e6a2ee..f29901c 100644 --- a/roboimi/vla/agent.py +++ b/roboimi/vla/agent.py @@ -107,14 +107,6 @@ class VLAAgent(nn.Module): # 1. 提取当前观测特征 (只做一次) visual_features = self.vision_encoder(images).view(B, -1) proprioception = proprioception.view(B, -1) - if hasattr(self, 'qpos_mean') and hasattr(self, 'qpos_std') and self.qpos_mean is not None: - # Convert to tensor for normalization - qpos_mean = torch.from_numpy(self.qpos_mean).float().to(proprioception.device) - qpos_std = torch.from_numpy(self.qpos_std).float().to(proprioception.device) - qpos_mean = qpos_mean.repeat(2) - qpos_std = qpos_std.repeat(2) - # Normalize: (qpos - mean) / std - proprioception = (proprioception - qpos_mean.unsqueeze(0)) / qpos_std.unsqueeze(0) global_cond = torch.cat([visual_features, proprioception], dim=-1) # 2. 初始化纯高斯噪声动作 @@ -141,13 +133,5 @@ class VLAAgent(nn.Module): noise_pred, t, current_actions ).prev_sample - # 4. 反归一化动作 (Denormalize actions) - if hasattr(self, 'action_mean') and hasattr(self, 'action_std') and self.action_mean is not None: - # Convert to numpy for denormalization - action_mean = torch.from_numpy(self.action_mean).float().to(current_actions.device) - action_std = torch.from_numpy(self.action_std).float().to(current_actions.device) - # Denormalize: action * std + mean - current_actions = current_actions * action_std.unsqueeze(0).unsqueeze(0) + action_mean.unsqueeze(0).unsqueeze(0) - - # 5. 输出最终动作序列 - return current_actions # 返回去噪后的干净动作 \ No newline at end of file + # 4. 输出最终动作序列(归一化空间,由调用方负责反归一化) + return current_actions \ No newline at end of file