feat: 添加resume机制

This commit is contained in:
JiajunLI
2026-03-06 11:19:30 +08:00
parent 7d39933a5b
commit 642d41dd8f

View File

@@ -5,6 +5,7 @@ import json
import pickle import pickle
import hydra import hydra
import torch import torch
import re
from tqdm import tqdm from tqdm import tqdm
from omegaconf import DictConfig, OmegaConf from omegaconf import DictConfig, OmegaConf
from torch.utils.data import DataLoader, random_split from torch.utils.data import DataLoader, random_split
@@ -44,6 +45,35 @@ def recursive_to_device(data, device):
return data return data
def resolve_resume_checkpoint(resume_ckpt, checkpoint_dir):
    """Resolve the checkpoint path used to resume training.

    Args:
        resume_ckpt: Value from the config. Either an explicit checkpoint
            path, or the string "auto" (case-insensitive) to pick the
            checkpoint with the largest step number in ``checkpoint_dir``.
        checkpoint_dir: Directory (``Path``) scanned for
            ``vla_model_step_<N>.pt`` files in "auto" mode.

    Returns:
        A ``Path`` to the checkpoint, or ``None`` when resuming is
        disabled or no matching checkpoint exists. Existence of an
        explicitly given path is NOT checked here — the caller does that.
    """
    if resume_ckpt is None:
        return None
    # Anything other than the literal "auto" is treated as a path.
    if str(resume_ckpt).lower() != "auto":
        return Path(resume_ckpt)
    step_re = re.compile(r"vla_model_step_(\d+)\.pt$")
    # Single pass: keep the (step, path) pair with the highest step seen.
    latest = None
    for ckpt in checkpoint_dir.glob("vla_model_step_*.pt"):
        hit = step_re.search(ckpt.name)
        if hit is None:
            continue
        step = int(hit.group(1))
        if latest is None or step > latest[0]:
            latest = (step, ckpt)
    return None if latest is None else latest[1]
def get_lr_schedule_with_warmup(optimizer, warmup_steps, max_steps, scheduler_type='cosine', min_lr=0): def get_lr_schedule_with_warmup(optimizer, warmup_steps, max_steps, scheduler_type='cosine', min_lr=0):
""" """
创建带预热的学习率调度器。 创建带预热的学习率调度器。
@@ -270,6 +300,52 @@ def main(cfg: DictConfig):
) )
log.info(f"📈 学习率调度器: {scheduler_type}{warmup_steps} 步预热 (最小学习率={min_lr})") log.info(f"📈 学习率调度器: {scheduler_type}{warmup_steps} 步预热 (最小学习率={min_lr})")
# =========================================================================
# 4.1 断点续训(恢复模型、优化器、调度器、步数)
# =========================================================================
start_step = 0
resume_loss = None
resume_best_loss = float('inf')
resume_ckpt = cfg.train.get('resume_ckpt', None)
resume_path = resolve_resume_checkpoint(resume_ckpt, checkpoint_dir)
if resume_ckpt is not None:
if pretrained_ckpt is not None:
log.warning("⚠️ [Resume] 同时设置了 pretrained_ckpt 与 resume_ckpt将优先使用 resume_ckpt 进行断点续训")
if resume_path is None:
log.warning("⚠️ [Resume] 未找到可恢复的 checkpoint将从头开始训练")
elif not resume_path.exists():
log.error(f"❌ [Resume] Checkpoint 文件不存在: {resume_path}")
log.warning("⚠️ 将从头开始训练")
else:
log.info(f"🔄 [Resume] 从 checkpoint 恢复训练: {resume_path}")
try:
checkpoint = torch.load(resume_path, map_location=cfg.train.device)
agent.load_state_dict(checkpoint['model_state_dict'], strict=True)
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
resume_step = int(checkpoint['step'])
start_step = resume_step + 1
loaded_loss = checkpoint.get('loss', None)
loaded_val_loss = checkpoint.get('val_loss', None)
resume_loss = float(loaded_loss) if loaded_loss is not None else None
if loaded_val_loss is not None:
resume_best_loss = float(loaded_val_loss)
elif loaded_loss is not None:
resume_best_loss = float(loaded_loss)
log.info(f"✅ [Resume] 恢复成功: 上次步骤={resume_step}, 本次从步骤 {start_step} 开始")
log.info(f"📈 [Resume] 当前学习率: {optimizer.param_groups[0]['lr']:.2e}")
except Exception as e:
log.error(f"❌ [Resume] 恢复失败: {e}")
log.warning("⚠️ 将从头开始训练")
start_step = 0
resume_loss = None
resume_best_loss = float('inf')
# ========================================================================= # =========================================================================
# 5. 训练循环 # 5. 训练循环
# ========================================================================= # =========================================================================
@@ -316,9 +392,15 @@ def main(cfg: DictConfig):
return total_loss / max(num_batches, 1) return total_loss / max(num_batches, 1)
data_iter = iter(train_loader) data_iter = iter(train_loader)
pbar = tqdm(range(cfg.train.max_steps), desc="训练中", ncols=100) pbar = tqdm(range(start_step, cfg.train.max_steps), desc="训练中", ncols=100)
best_loss = float('inf') best_loss = resume_best_loss
last_loss = resume_loss
if start_step >= cfg.train.max_steps:
log.warning(
f"⚠️ [Resume] start_step={start_step} 已达到/超过 max_steps={cfg.train.max_steps},跳过训练循环"
)
for step in pbar: for step in pbar:
try: try:
@@ -351,6 +433,8 @@ def main(cfg: DictConfig):
log.error(f"❌ 步骤 {step} 前向传播失败: {e}") log.error(f"❌ 步骤 {step} 前向传播失败: {e}")
raise raise
last_loss = loss.item()
# ===================================================================== # =====================================================================
# 反向传播与优化 # 反向传播与优化
# ===================================================================== # =====================================================================
@@ -427,15 +511,21 @@ def main(cfg: DictConfig):
'model_state_dict': agent.state_dict(), 'model_state_dict': agent.state_dict(),
'optimizer_state_dict': optimizer.state_dict(), 'optimizer_state_dict': optimizer.state_dict(),
'scheduler_state_dict': scheduler.state_dict(), 'scheduler_state_dict': scheduler.state_dict(),
'loss': loss.item(), 'loss': last_loss,
'dataset_stats': agent_stats, # 保存agent的统计信息 'dataset_stats': agent_stats, # 保存agent的统计信息
'current_lr': optimizer.param_groups[0]['lr'], 'current_lr': optimizer.param_groups[0]['lr'],
}, final_model_path) }, final_model_path)
log.info(f"💾 最终模型已保存: {final_model_path}") log.info(f"💾 最终模型已保存: {final_model_path}")
log.info("✅ 训练成功完成!") log.info("✅ 训练成功完成!")
log.info(f"📊 最终损失: {loss.item():.4f}") if last_loss is not None:
log.info(f"📊 最损失: {best_loss:.4f}") log.info(f"📊 最损失: {last_loss:.4f}")
else:
log.info("📊 最终损失: N/A未执行训练步")
if best_loss != float('inf'):
log.info(f"📊 最佳损失: {best_loss:.4f}")
else:
log.info("📊 最佳损失: N/A无有效验证/训练损失)")
if __name__ == "__main__": if __name__ == "__main__":