diff --git a/roboimi/demos/vla_scripts/train_vla.py b/roboimi/demos/vla_scripts/train_vla.py index f7c8e57..32115fb 100644 --- a/roboimi/demos/vla_scripts/train_vla.py +++ b/roboimi/demos/vla_scripts/train_vla.py @@ -270,40 +270,39 @@ def main(cfg: DictConfig): log.info(f"Step {step}/{cfg.train.max_steps} | Loss: {loss.item():.4f}") # ===================================================================== - # Validation + # Checkpoint saving & Validation # ===================================================================== - val_freq = int(cfg.train.get('val_freq', cfg.train.log_freq)) - if val_loader is not None and val_freq > 0 and step % val_freq == 0: + if step > 0 and step % cfg.train.save_freq == 0: + # Run validation val_loss = run_validation() if val_loss is not None: log.info(f"Step {step}/{cfg.train.max_steps} | Val Loss: {val_loss:.4f}") - # ===================================================================== - # Checkpoint saving - # ===================================================================== - if step > 0 and step % cfg.train.save_freq == 0: checkpoint_path = checkpoint_dir / f"vla_model_step_{step}.pt" torch.save({ 'step': step, 'model_state_dict': agent.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'loss': loss.item(), + 'val_loss': val_loss, 'dataset_stats': dataset_stats, }, checkpoint_path) log.info(f"💾 Checkpoint saved: {checkpoint_path}") - # Save best model - if loss.item() < best_loss: - best_loss = loss.item() + # Save best model based on validation loss + eval_loss = val_loss if val_loss is not None else loss.item() + if eval_loss < best_loss: + best_loss = eval_loss best_model_path = checkpoint_dir / "vla_model_best.pt" torch.save({ 'step': step, 'model_state_dict': agent.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'loss': loss.item(), + 'val_loss': val_loss, 'dataset_stats': dataset_stats, }, best_model_path) - log.info(f"🌟 Best model updated: {best_model_path} (loss: {best_loss:.4f})") + log.info(f"🌟 Best model updated: {best_model_path} (val_loss: {best_loss:.4f})") # ========================================================================= # 5. Save Final Model