feat(train): 添加验证集
This commit is contained in:
@@ -270,40 +270,39 @@ def main(cfg: DictConfig):
|
|||||||
log.info(f"Step {step}/{cfg.train.max_steps} | Loss: {loss.item():.4f}")
|
log.info(f"Step {step}/{cfg.train.max_steps} | Loss: {loss.item():.4f}")
|
||||||
|
|
||||||
# =====================================================================
|
# =====================================================================
|
||||||
# Validation
|
# Checkpoint saving & Validation
|
||||||
# =====================================================================
|
# =====================================================================
|
||||||
val_freq = int(cfg.train.get('val_freq', cfg.train.log_freq))
|
if step > 0 and step % cfg.train.save_freq == 0:
|
||||||
if val_loader is not None and val_freq > 0 and step % val_freq == 0:
|
# Run validation
|
||||||
val_loss = run_validation()
|
val_loss = run_validation()
|
||||||
if val_loss is not None:
|
if val_loss is not None:
|
||||||
log.info(f"Step {step}/{cfg.train.max_steps} | Val Loss: {val_loss:.4f}")
|
log.info(f"Step {step}/{cfg.train.max_steps} | Val Loss: {val_loss:.4f}")
|
||||||
|
|
||||||
# =====================================================================
|
|
||||||
# Checkpoint saving
|
|
||||||
# =====================================================================
|
|
||||||
if step > 0 and step % cfg.train.save_freq == 0:
|
|
||||||
checkpoint_path = checkpoint_dir / f"vla_model_step_{step}.pt"
|
checkpoint_path = checkpoint_dir / f"vla_model_step_{step}.pt"
|
||||||
torch.save({
|
torch.save({
|
||||||
'step': step,
|
'step': step,
|
||||||
'model_state_dict': agent.state_dict(),
|
'model_state_dict': agent.state_dict(),
|
||||||
'optimizer_state_dict': optimizer.state_dict(),
|
'optimizer_state_dict': optimizer.state_dict(),
|
||||||
'loss': loss.item(),
|
'loss': loss.item(),
|
||||||
|
'val_loss': val_loss,
|
||||||
'dataset_stats': dataset_stats,
|
'dataset_stats': dataset_stats,
|
||||||
}, checkpoint_path)
|
}, checkpoint_path)
|
||||||
log.info(f"💾 Checkpoint saved: {checkpoint_path}")
|
log.info(f"💾 Checkpoint saved: {checkpoint_path}")
|
||||||
|
|
||||||
# Save best model
|
# Save best model based on validation loss
|
||||||
if loss.item() < best_loss:
|
eval_loss = val_loss if val_loss is not None else loss.item()
|
||||||
best_loss = loss.item()
|
if eval_loss < best_loss:
|
||||||
|
best_loss = eval_loss
|
||||||
best_model_path = checkpoint_dir / "vla_model_best.pt"
|
best_model_path = checkpoint_dir / "vla_model_best.pt"
|
||||||
torch.save({
|
torch.save({
|
||||||
'step': step,
|
'step': step,
|
||||||
'model_state_dict': agent.state_dict(),
|
'model_state_dict': agent.state_dict(),
|
||||||
'optimizer_state_dict': optimizer.state_dict(),
|
'optimizer_state_dict': optimizer.state_dict(),
|
||||||
'loss': loss.item(),
|
'loss': loss.item(),
|
||||||
|
'val_loss': val_loss,
|
||||||
'dataset_stats': dataset_stats,
|
'dataset_stats': dataset_stats,
|
||||||
}, best_model_path)
|
}, best_model_path)
|
||||||
log.info(f"🌟 Best model updated: {best_model_path} (loss: {best_loss:.4f})")
|
log.info(f"🌟 Best model updated: {best_model_path} (val_loss: {best_loss:.4f})")
|
||||||
|
|
||||||
# =========================================================================
|
# =========================================================================
|
||||||
# 5. Save Final Model
|
# 5. Save Final Model
|
||||||
|
|||||||
Reference in New Issue
Block a user