From 3d0c2ec5b1af207c2f27492b75f019dce3e97dc2 Mon Sep 17 00:00:00 2001 From: gouhanke <12219217+gouhanke@user.noreply.gitee.com> Date: Fri, 6 Feb 2026 18:00:09 +0800 Subject: [PATCH] =?UTF-8?q?feat(train):=20=E6=B7=BB=E5=8A=A0=E9=AA=8C?= =?UTF-8?q?=E8=AF=81=E9=9B=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- roboimi/demos/vla_scripts/train_vla.py | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/roboimi/demos/vla_scripts/train_vla.py b/roboimi/demos/vla_scripts/train_vla.py index f7c8e57..32115fb 100644 --- a/roboimi/demos/vla_scripts/train_vla.py +++ b/roboimi/demos/vla_scripts/train_vla.py @@ -270,40 +270,39 @@ def main(cfg: DictConfig): log.info(f"Step {step}/{cfg.train.max_steps} | Loss: {loss.item():.4f}") # ===================================================================== - # Validation + # Checkpoint saving & Validation # ===================================================================== - val_freq = int(cfg.train.get('val_freq', cfg.train.log_freq)) - if val_loader is not None and val_freq > 0 and step % val_freq == 0: + if step > 0 and step % cfg.train.save_freq == 0: + # Run validation val_loss = run_validation() if val_loss is not None: log.info(f"Step {step}/{cfg.train.max_steps} | Val Loss: {val_loss:.4f}") - # ===================================================================== - # Checkpoint saving - # ===================================================================== - if step > 0 and step % cfg.train.save_freq == 0: checkpoint_path = checkpoint_dir / f"vla_model_step_{step}.pt" torch.save({ 'step': step, 'model_state_dict': agent.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'loss': loss.item(), + 'val_loss': val_loss, 'dataset_stats': dataset_stats, }, checkpoint_path) log.info(f"💾 Checkpoint saved: {checkpoint_path}") - # Save best model - if loss.item() < best_loss: - best_loss = loss.item() + # Save best model based on validation loss + eval_loss = val_loss if val_loss is not None else loss.item() + if eval_loss < best_loss: + best_loss = eval_loss best_model_path = checkpoint_dir / "vla_model_best.pt" torch.save({ 'step': step, 'model_state_dict': agent.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'loss': loss.item(), + 'val_loss': val_loss, 'dataset_stats': dataset_stats, }, best_model_path) - log.info(f"🌟 Best model updated: {best_model_path} (loss: {best_loss:.4f})") + log.info(f"🌟 Best model updated: {best_model_path} (val_loss: {best_loss:.4f})") # ========================================================================= # 5. Save Final Model