feat(train): 添加验证集

This commit is contained in:
gouhanke
2026-02-06 18:00:09 +08:00
parent ea49e63eb7
commit 3d0c2ec5b1

View File

@@ -270,40 +270,39 @@ def main(cfg: DictConfig):
log.info(f"Step {step}/{cfg.train.max_steps} | Loss: {loss.item():.4f}")
# =====================================================================
# Validation
# Checkpoint saving & Validation
# =====================================================================
val_freq = int(cfg.train.get('val_freq', cfg.train.log_freq))
if val_loader is not None and val_freq > 0 and step % val_freq == 0:
if step > 0 and step % cfg.train.save_freq == 0:
# Run validation
val_loss = run_validation()
if val_loss is not None:
log.info(f"Step {step}/{cfg.train.max_steps} | Val Loss: {val_loss:.4f}")
# =====================================================================
# Checkpoint saving
# =====================================================================
if step > 0 and step % cfg.train.save_freq == 0:
checkpoint_path = checkpoint_dir / f"vla_model_step_{step}.pt"
torch.save({
'step': step,
'model_state_dict': agent.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss.item(),
'val_loss': val_loss,
'dataset_stats': dataset_stats,
}, checkpoint_path)
log.info(f"💾 Checkpoint saved: {checkpoint_path}")
# Save best model
if loss.item() < best_loss:
best_loss = loss.item()
# Save best model based on validation loss
eval_loss = val_loss if val_loss is not None else loss.item()
if eval_loss < best_loss:
best_loss = eval_loss
best_model_path = checkpoint_dir / "vla_model_best.pt"
torch.save({
'step': step,
'model_state_dict': agent.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss.item(),
'val_loss': val_loss,
'dataset_stats': dataset_stats,
}, best_model_path)
log.info(f"🌟 Best model updated: {best_model_path} (loss: {best_loss:.4f})")
log.info(f"🌟 Best model updated: {best_model_path} (val_loss: {best_loss:.4f})")
# =========================================================================
# 5. Save Final Model