feat(train): 添加验证集
This commit is contained in:
@@ -270,40 +270,39 @@ def main(cfg: DictConfig):
|
||||
log.info(f"Step {step}/{cfg.train.max_steps} | Loss: {loss.item():.4f}")
|
||||
|
||||
# =====================================================================
|
||||
# Validation
|
||||
# Checkpoint saving & Validation
|
||||
# =====================================================================
|
||||
val_freq = int(cfg.train.get('val_freq', cfg.train.log_freq))
|
||||
if val_loader is not None and val_freq > 0 and step % val_freq == 0:
|
||||
if step > 0 and step % cfg.train.save_freq == 0:
|
||||
# Run validation
|
||||
val_loss = run_validation()
|
||||
if val_loss is not None:
|
||||
log.info(f"Step {step}/{cfg.train.max_steps} | Val Loss: {val_loss:.4f}")
|
||||
|
||||
# =====================================================================
|
||||
# Checkpoint saving
|
||||
# =====================================================================
|
||||
if step > 0 and step % cfg.train.save_freq == 0:
|
||||
checkpoint_path = checkpoint_dir / f"vla_model_step_{step}.pt"
|
||||
torch.save({
|
||||
'step': step,
|
||||
'model_state_dict': agent.state_dict(),
|
||||
'optimizer_state_dict': optimizer.state_dict(),
|
||||
'loss': loss.item(),
|
||||
'val_loss': val_loss,
|
||||
'dataset_stats': dataset_stats,
|
||||
}, checkpoint_path)
|
||||
log.info(f"💾 Checkpoint saved: {checkpoint_path}")
|
||||
|
||||
# Save best model
|
||||
if loss.item() < best_loss:
|
||||
best_loss = loss.item()
|
||||
# Save best model based on validation loss
|
||||
eval_loss = val_loss if val_loss is not None else loss.item()
|
||||
if eval_loss < best_loss:
|
||||
best_loss = eval_loss
|
||||
best_model_path = checkpoint_dir / "vla_model_best.pt"
|
||||
torch.save({
|
||||
'step': step,
|
||||
'model_state_dict': agent.state_dict(),
|
||||
'optimizer_state_dict': optimizer.state_dict(),
|
||||
'loss': loss.item(),
|
||||
'val_loss': val_loss,
|
||||
'dataset_stats': dataset_stats,
|
||||
}, best_model_path)
|
||||
log.info(f"🌟 Best model updated: {best_model_path} (loss: {best_loss:.4f})")
|
||||
log.info(f"🌟 Best model updated: {best_model_path} (val_loss: {best_loss:.4f})")
|
||||
|
||||
# =========================================================================
|
||||
# 5. Save Final Model
|
||||
|
||||
Reference in New Issue
Block a user