From 37a47ac2dde88915c53661f883f799bf895eb35c Mon Sep 17 00:00:00 2001 From: gouhanke <12219217+gouhanke@user.noreply.gitee.com> Date: Thu, 12 Feb 2026 13:00:43 +0800 Subject: [PATCH] =?UTF-8?q?debug:=20=E4=BF=9D=E5=AD=98stats=E5=88=B0ckpt?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- roboimi/demos/vla_scripts/train_vla.py | 10 +++++++--- roboimi/vla/agent.py | 7 ++++++- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/roboimi/demos/vla_scripts/train_vla.py b/roboimi/demos/vla_scripts/train_vla.py index 358cb5e..d96ca29 100644 --- a/roboimi/demos/vla_scripts/train_vla.py +++ b/roboimi/demos/vla_scripts/train_vla.py @@ -346,6 +346,8 @@ def main(cfg: DictConfig): log.info(f"步骤 {step}/{cfg.train.max_steps} | 验证损失: {val_loss:.4f}") checkpoint_path = checkpoint_dir / f"vla_model_step_{step}.pt" + # 使用agent的归一化统计信息(包含normalization_type) + agent_stats = agent.get_normalization_stats() torch.save({ 'step': step, 'model_state_dict': agent.state_dict(), @@ -353,7 +355,7 @@ def main(cfg: DictConfig): 'scheduler_state_dict': scheduler.state_dict(), 'loss': loss.item(), 'val_loss': val_loss, - 'dataset_stats': dataset_stats, + 'dataset_stats': agent_stats, # 保存agent的统计信息 'current_lr': optimizer.param_groups[0]['lr'], }, checkpoint_path) log.info(f"💾 检查点已保存: {checkpoint_path}") @@ -363,6 +365,7 @@ def main(cfg: DictConfig): if eval_loss < best_loss: best_loss = eval_loss best_model_path = checkpoint_dir / "vla_model_best.pt" + agent_stats = agent.get_normalization_stats() torch.save({ 'step': step, 'model_state_dict': agent.state_dict(), @@ -370,7 +373,7 @@ def main(cfg: DictConfig): 'scheduler_state_dict': scheduler.state_dict(), 'loss': loss.item(), 'val_loss': val_loss, - 'dataset_stats': dataset_stats, + 'dataset_stats': agent_stats, # 保存agent的统计信息 'current_lr': optimizer.param_groups[0]['lr'], }, best_model_path) log.info(f"🌟 最佳模型已更新: {best_model_path} (验证损失: {best_loss:.4f})") @@ -379,13 +382,14 @@ def main(cfg: DictConfig): # 6. 保存最终模型 # ========================================================================= final_model_path = checkpoint_dir / "vla_model_final.pt" + agent_stats = agent.get_normalization_stats() torch.save({ 'step': cfg.train.max_steps, 'model_state_dict': agent.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'scheduler_state_dict': scheduler.state_dict(), 'loss': loss.item(), - 'dataset_stats': dataset_stats, + 'dataset_stats': agent_stats, # 保存agent的统计信息 'current_lr': optimizer.param_groups[0]['lr'], }, final_model_path) log.info(f"💾 最终模型已保存: {final_model_path}") diff --git a/roboimi/vla/agent.py b/roboimi/vla/agent.py index c1ac1cd..34fa47c 100644 --- a/roboimi/vla/agent.py +++ b/roboimi/vla/agent.py @@ -240,7 +240,12 @@ class VLAAgent(nn.Module): if device is not None and self.normalization.enabled: # 确保归一化参数在同一设备上 - norm_device = self.normalization.qpos_mean.device + # 根据归一化类型获取正确的属性 + if self.normalization.normalization_type == 'gaussian': + norm_device = self.normalization.qpos_mean.device + else: # min_max + norm_device = self.normalization.qpos_min.device + if device != norm_device: self.normalization.to(device) # 同时确保其他模块也在正确设备