feat: add resume mechanism

JiajunLI
2026-03-06 11:19:30 +08:00
parent 7d39933a5b
commit 642d41dd8f


@@ -5,6 +5,7 @@ import json
 import pickle
 import hydra
 import torch
+import re
 from tqdm import tqdm
 from omegaconf import DictConfig, OmegaConf
 from torch.utils.data import DataLoader, random_split
@@ -44,6 +45,35 @@ def recursive_to_device(data, device):
     return data
 
+
+def resolve_resume_checkpoint(resume_ckpt, checkpoint_dir):
+    """
+    Resolve the checkpoint path used to resume training.
+
+    Args:
+        resume_ckpt: resume_ckpt from the config; either a filesystem path or "auto"
+        checkpoint_dir: default checkpoint directory
+
+    Returns:
+        Path or None
+    """
+    if resume_ckpt is None:
+        return None
+    if str(resume_ckpt).lower() != "auto":
+        return Path(resume_ckpt)
+    pattern = re.compile(r"vla_model_step_(\d+)\.pt$")
+    candidates = []
+    for ckpt_path in checkpoint_dir.glob("vla_model_step_*.pt"):
+        match = pattern.search(ckpt_path.name)
+        if match:
+            candidates.append((int(match.group(1)), ckpt_path))
+    if not candidates:
+        return None
+    return max(candidates, key=lambda x: x[0])[1]
+
+
 def get_lr_schedule_with_warmup(optimizer, warmup_steps, max_steps, scheduler_type='cosine', min_lr=0):
     """
     Create a learning-rate scheduler with warmup.
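Reviewer note, not part of the patch: a minimal sketch of how the three resume_ckpt modes resolve, assuming the function above is in scope and a hypothetical checkpoint directory that follows the vla_model_step_*.pt naming scheme.

from pathlib import Path

# Hypothetical layout: vla_model_step_500.pt, vla_model_step_1000.pt, vla_model_step_2500.pt
checkpoint_dir = Path("outputs/checkpoints")

resolve_resume_checkpoint(None, checkpoint_dir)               # -> None (resume disabled)
resolve_resume_checkpoint("auto", checkpoint_dir)             # -> .../vla_model_step_2500.pt (highest step wins)
resolve_resume_checkpoint("runs/my_ckpt.pt", checkpoint_dir)  # -> Path("runs/my_ckpt.pt"), taken verbatim

Sorting on int(match.group(1)) rather than on filenames matters here: a lexicographic max would rank step_999 above step_2500.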
@@ -270,6 +300,52 @@ def main(cfg: DictConfig):
     )
     log.info(f"📈 LR scheduler: {scheduler_type}, {warmup_steps} warmup steps (min LR={min_lr})")
 
+    # =========================================================================
+    # 4.1 Resume training (restore model, optimizer, scheduler, and step count)
+    # =========================================================================
+    start_step = 0
+    resume_loss = None
+    resume_best_loss = float('inf')
+    resume_ckpt = cfg.train.get('resume_ckpt', None)
+    resume_path = resolve_resume_checkpoint(resume_ckpt, checkpoint_dir)
+    if resume_ckpt is not None:
+        if pretrained_ckpt is not None:
+            log.warning("⚠️ [Resume] Both pretrained_ckpt and resume_ckpt are set; resume_ckpt takes precedence for resuming")
+        if resume_path is None:
+            log.warning("⚠️ [Resume] No resumable checkpoint found; training from scratch")
+        elif not resume_path.exists():
+            log.error(f"❌ [Resume] Checkpoint file does not exist: {resume_path}")
+            log.warning("⚠️ Training from scratch")
+        else:
+            log.info(f"🔄 [Resume] Resuming training from checkpoint: {resume_path}")
+            try:
+                checkpoint = torch.load(resume_path, map_location=cfg.train.device)
+                agent.load_state_dict(checkpoint['model_state_dict'], strict=True)
+                optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+                scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
+                resume_step = int(checkpoint['step'])
+                start_step = resume_step + 1
+                loaded_loss = checkpoint.get('loss', None)
+                loaded_val_loss = checkpoint.get('val_loss', None)
+                resume_loss = float(loaded_loss) if loaded_loss is not None else None
+                if loaded_val_loss is not None:
+                    resume_best_loss = float(loaded_val_loss)
+                elif loaded_loss is not None:
+                    resume_best_loss = float(loaded_loss)
+                log.info(f"✅ [Resume] Restored: last step={resume_step}, resuming from step {start_step}")
+                log.info(f"📈 [Resume] Current learning rate: {optimizer.param_groups[0]['lr']:.2e}")
+            except Exception as e:
+                log.error(f"❌ [Resume] Restore failed: {e}")
+                log.warning("⚠️ Training from scratch")
+                start_step = 0
+                resume_loss = None
+                resume_best_loss = float('inf')
+
     # =========================================================================
     # 5. Training loop
     # =========================================================================
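Reviewer note, not part of the patch: the resume block only succeeds if earlier checkpoints were saved with the keys it reads. A minimal sketch of a compatible save call, assuming `agent`, `optimizer`, `scheduler`, `step`, `last_loss`, and `checkpoint_dir` from the surrounding script; `best_val_loss` is a hypothetical name for whatever validation metric the run tracks.

import torch

torch.save({
    'step': step,                                    # required: resume restarts at step + 1
    'model_state_dict': agent.state_dict(),          # required (loaded with strict=True)
    'optimizer_state_dict': optimizer.state_dict(),  # required
    'scheduler_state_dict': scheduler.state_dict(),  # required
    'loss': last_loss,                               # optional: seeds resume_loss / resume_best_loss
    'val_loss': best_val_loss,                       # optional: preferred seed for resume_best_loss
}, checkpoint_dir / f"vla_model_step_{step}.pt")

A missing required key raises KeyError inside the try block, which the except branch converts into a clean from-scratch restart.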
@@ -316,9 +392,15 @@ def main(cfg: DictConfig):
         return total_loss / max(num_batches, 1)
 
     data_iter = iter(train_loader)
-    pbar = tqdm(range(cfg.train.max_steps), desc="Training", ncols=100)
+    pbar = tqdm(range(start_step, cfg.train.max_steps), desc="Training", ncols=100)
-    best_loss = float('inf')
+    best_loss = resume_best_loss
+    last_loss = resume_loss
+
+    if start_step >= cfg.train.max_steps:
+        log.warning(
+            f"⚠️ [Resume] start_step={start_step} has reached/exceeded max_steps={cfg.train.max_steps}; skipping the training loop"
+        )
 
     for step in pbar:
         try:
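Reviewer note, not part of the patch: the start_step >= max_steps guard only logs a warning because range() already handles the degenerate case, as this small sketch shows.

# range(start_step, max_steps) is empty once start_step >= max_steps,
# so the for-loop body is skipped without any extra control flow.
max_steps = 1000
for start_step in (0, 400, 1000, 1200):
    print(start_step, "->", len(range(start_step, max_steps)), "remaining steps")
# prints: 0 -> 1000, 400 -> 600, 1000 -> 0, 1200 -> 0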
@@ -351,6 +433,8 @@ def main(cfg: DictConfig):
             log.error(f"❌ Step {step}: forward pass failed: {e}")
             raise
 
+        last_loss = loss.item()
+
         # =====================================================================
         # Backward pass and optimization
         # =====================================================================
@@ -427,15 +511,21 @@
         'model_state_dict': agent.state_dict(),
         'optimizer_state_dict': optimizer.state_dict(),
         'scheduler_state_dict': scheduler.state_dict(),
-        'loss': loss.item(),
+        'loss': last_loss,
         'dataset_stats': agent_stats,  # save the agent's dataset statistics
         'current_lr': optimizer.param_groups[0]['lr'],
     }, final_model_path)
     log.info(f"💾 Final model saved: {final_model_path}")
 
     log.info("✅ Training finished successfully!")
-    log.info(f"📊 Final loss: {loss.item():.4f}")
-    log.info(f"📊 Best loss: {best_loss:.4f}")
+    if last_loss is not None:
+        log.info(f"📊 Final loss: {last_loss:.4f}")
+    else:
+        log.info("📊 Final loss: N/A (no training steps were executed)")
+    if best_loss != float('inf'):
+        log.info(f"📊 Best loss: {best_loss:.4f}")
+    else:
+        log.info("📊 Best loss: N/A (no valid validation/training loss)")
 
 
 if __name__ == "__main__":
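Usage note (an assumption about the config layout; only cfg.train.resume_ckpt is confirmed by the diff): resuming is driven by a single config field, which with Hydra would typically be overridden on the command line as train.resume_ckpt=auto. A minimal OmegaConf sketch of the three modes:

from omegaconf import OmegaConf

cfg = OmegaConf.create({"train": {
    "resume_ckpt": "auto",              # "auto": newest vla_model_step_*.pt in checkpoint_dir
    # "resume_ckpt": "runs/ckpt.pt",    # explicit path: resume exactly this file
    # "resume_ckpt": None,              # null/absent: train from scratch
}})
print(cfg.train.get("resume_ckpt", None))  # -> auto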