feat: 添加resume机制
This commit is contained in:
@@ -5,6 +5,7 @@ import json
|
||||
import pickle
|
||||
import hydra
|
||||
import torch
|
||||
import re
|
||||
from tqdm import tqdm
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
from torch.utils.data import DataLoader, random_split
|
||||
@@ -44,6 +45,35 @@ def recursive_to_device(data, device):
|
||||
return data
|
||||
|
||||
|
||||
def resolve_resume_checkpoint(resume_ckpt, checkpoint_dir):
    """Resolve the checkpoint path used to resume training.

    Args:
        resume_ckpt: ``resume_ckpt`` value from the config — an explicit
            checkpoint path, the string "auto" (case-insensitive), or None.
        checkpoint_dir: default directory searched when "auto" is given.

    Returns:
        Path to the checkpoint to resume from, or None when resuming is
        disabled or no candidate checkpoint exists.
    """
    if resume_ckpt is None:
        return None

    # Anything other than the literal "auto" is treated as an explicit path.
    if str(resume_ckpt).lower() != "auto":
        return Path(resume_ckpt)

    # "auto" mode: scan checkpoint_dir and keep files whose name carries a
    # numeric step suffix, e.g. vla_model_step_1000.pt.
    step_re = re.compile(r"vla_model_step_(\d+)\.pt$")
    numbered = []
    for path in checkpoint_dir.glob("vla_model_step_*.pt"):
        m = step_re.search(path.name)
        if m is None:
            continue  # glob matched but the suffix is not purely numeric
        numbered.append((int(m.group(1)), path))

    if not numbered:
        return None
    # Resume from the checkpoint with the largest step number.
    return max(numbered, key=lambda item: item[0])[1]
|
||||
|
||||
|
||||
def get_lr_schedule_with_warmup(optimizer, warmup_steps, max_steps, scheduler_type='cosine', min_lr=0):
|
||||
"""
|
||||
创建带预热的学习率调度器。
|
||||
@@ -270,6 +300,52 @@ def main(cfg: DictConfig):
|
||||
)
|
||||
log.info(f"📈 学习率调度器: {scheduler_type},{warmup_steps} 步预热 (最小学习率={min_lr})")
|
||||
|
||||
# =========================================================================
|
||||
# 4.1 断点续训(恢复模型、优化器、调度器、步数)
|
||||
# =========================================================================
|
||||
start_step = 0
|
||||
resume_loss = None
|
||||
resume_best_loss = float('inf')
|
||||
|
||||
resume_ckpt = cfg.train.get('resume_ckpt', None)
|
||||
resume_path = resolve_resume_checkpoint(resume_ckpt, checkpoint_dir)
|
||||
if resume_ckpt is not None:
|
||||
if pretrained_ckpt is not None:
|
||||
log.warning("⚠️ [Resume] 同时设置了 pretrained_ckpt 与 resume_ckpt,将优先使用 resume_ckpt 进行断点续训")
|
||||
if resume_path is None:
|
||||
log.warning("⚠️ [Resume] 未找到可恢复的 checkpoint,将从头开始训练")
|
||||
elif not resume_path.exists():
|
||||
log.error(f"❌ [Resume] Checkpoint 文件不存在: {resume_path}")
|
||||
log.warning("⚠️ 将从头开始训练")
|
||||
else:
|
||||
log.info(f"🔄 [Resume] 从 checkpoint 恢复训练: {resume_path}")
|
||||
try:
|
||||
checkpoint = torch.load(resume_path, map_location=cfg.train.device)
|
||||
|
||||
agent.load_state_dict(checkpoint['model_state_dict'], strict=True)
|
||||
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
||||
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
|
||||
|
||||
resume_step = int(checkpoint['step'])
|
||||
start_step = resume_step + 1
|
||||
|
||||
loaded_loss = checkpoint.get('loss', None)
|
||||
loaded_val_loss = checkpoint.get('val_loss', None)
|
||||
resume_loss = float(loaded_loss) if loaded_loss is not None else None
|
||||
if loaded_val_loss is not None:
|
||||
resume_best_loss = float(loaded_val_loss)
|
||||
elif loaded_loss is not None:
|
||||
resume_best_loss = float(loaded_loss)
|
||||
|
||||
log.info(f"✅ [Resume] 恢复成功: 上次步骤={resume_step}, 本次从步骤 {start_step} 开始")
|
||||
log.info(f"📈 [Resume] 当前学习率: {optimizer.param_groups[0]['lr']:.2e}")
|
||||
except Exception as e:
|
||||
log.error(f"❌ [Resume] 恢复失败: {e}")
|
||||
log.warning("⚠️ 将从头开始训练")
|
||||
start_step = 0
|
||||
resume_loss = None
|
||||
resume_best_loss = float('inf')
|
||||
|
||||
# =========================================================================
|
||||
# 5. 训练循环
|
||||
# =========================================================================
|
||||
@@ -316,9 +392,15 @@ def main(cfg: DictConfig):
|
||||
return total_loss / max(num_batches, 1)
|
||||
|
||||
data_iter = iter(train_loader)
|
||||
pbar = tqdm(range(cfg.train.max_steps), desc="训练中", ncols=100)
|
||||
pbar = tqdm(range(start_step, cfg.train.max_steps), desc="训练中", ncols=100)
|
||||
|
||||
best_loss = float('inf')
|
||||
best_loss = resume_best_loss
|
||||
last_loss = resume_loss
|
||||
|
||||
if start_step >= cfg.train.max_steps:
|
||||
log.warning(
|
||||
f"⚠️ [Resume] start_step={start_step} 已达到/超过 max_steps={cfg.train.max_steps},跳过训练循环"
|
||||
)
|
||||
|
||||
for step in pbar:
|
||||
try:
|
||||
@@ -351,6 +433,8 @@ def main(cfg: DictConfig):
|
||||
log.error(f"❌ 步骤 {step} 前向传播失败: {e}")
|
||||
raise
|
||||
|
||||
last_loss = loss.item()
|
||||
|
||||
# =====================================================================
|
||||
# 反向传播与优化
|
||||
# =====================================================================
|
||||
@@ -427,15 +511,21 @@ def main(cfg: DictConfig):
|
||||
'model_state_dict': agent.state_dict(),
|
||||
'optimizer_state_dict': optimizer.state_dict(),
|
||||
'scheduler_state_dict': scheduler.state_dict(),
|
||||
'loss': loss.item(),
|
||||
'loss': last_loss,
|
||||
'dataset_stats': agent_stats, # 保存agent的统计信息
|
||||
'current_lr': optimizer.param_groups[0]['lr'],
|
||||
}, final_model_path)
|
||||
log.info(f"💾 最终模型已保存: {final_model_path}")
|
||||
|
||||
log.info("✅ 训练成功完成!")
|
||||
log.info(f"📊 最终损失: {loss.item():.4f}")
|
||||
log.info(f"📊 最佳损失: {best_loss:.4f}")
|
||||
if last_loss is not None:
|
||||
log.info(f"📊 最终损失: {last_loss:.4f}")
|
||||
else:
|
||||
log.info("📊 最终损失: N/A(未执行训练步)")
|
||||
if best_loss != float('inf'):
|
||||
log.info(f"📊 最佳损失: {best_loss:.4f}")
|
||||
else:
|
||||
log.info("📊 最佳损失: N/A(无有效验证/训练损失)")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user