feat: 添加resume机制
This commit is contained in:
@@ -5,6 +5,7 @@ import json
|
|||||||
import pickle
|
import pickle
|
||||||
import hydra
|
import hydra
|
||||||
import torch
|
import torch
|
||||||
|
import re
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from omegaconf import DictConfig, OmegaConf
|
from omegaconf import DictConfig, OmegaConf
|
||||||
from torch.utils.data import DataLoader, random_split
|
from torch.utils.data import DataLoader, random_split
|
||||||
@@ -44,6 +45,35 @@ def recursive_to_device(data, device):
|
|||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_resume_checkpoint(resume_ckpt, checkpoint_dir):
|
||||||
|
"""
|
||||||
|
解析恢复训练用的 checkpoint 路径。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
resume_ckpt: 配置中的 resume_ckpt,支持路径或 "auto"
|
||||||
|
checkpoint_dir: 默认检查点目录
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Path 或 None
|
||||||
|
"""
|
||||||
|
if resume_ckpt is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if str(resume_ckpt).lower() != "auto":
|
||||||
|
return Path(resume_ckpt)
|
||||||
|
|
||||||
|
pattern = re.compile(r"vla_model_step_(\d+)\.pt$")
|
||||||
|
candidates = []
|
||||||
|
for ckpt_path in checkpoint_dir.glob("vla_model_step_*.pt"):
|
||||||
|
match = pattern.search(ckpt_path.name)
|
||||||
|
if match:
|
||||||
|
candidates.append((int(match.group(1)), ckpt_path))
|
||||||
|
|
||||||
|
if not candidates:
|
||||||
|
return None
|
||||||
|
return max(candidates, key=lambda x: x[0])[1]
|
||||||
|
|
||||||
|
|
||||||
def get_lr_schedule_with_warmup(optimizer, warmup_steps, max_steps, scheduler_type='cosine', min_lr=0):
|
def get_lr_schedule_with_warmup(optimizer, warmup_steps, max_steps, scheduler_type='cosine', min_lr=0):
|
||||||
"""
|
"""
|
||||||
创建带预热的学习率调度器。
|
创建带预热的学习率调度器。
|
||||||
@@ -270,6 +300,52 @@ def main(cfg: DictConfig):
|
|||||||
)
|
)
|
||||||
log.info(f"📈 学习率调度器: {scheduler_type},{warmup_steps} 步预热 (最小学习率={min_lr})")
|
log.info(f"📈 学习率调度器: {scheduler_type},{warmup_steps} 步预热 (最小学习率={min_lr})")
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# 4.1 断点续训(恢复模型、优化器、调度器、步数)
|
||||||
|
# =========================================================================
|
||||||
|
start_step = 0
|
||||||
|
resume_loss = None
|
||||||
|
resume_best_loss = float('inf')
|
||||||
|
|
||||||
|
resume_ckpt = cfg.train.get('resume_ckpt', None)
|
||||||
|
resume_path = resolve_resume_checkpoint(resume_ckpt, checkpoint_dir)
|
||||||
|
if resume_ckpt is not None:
|
||||||
|
if pretrained_ckpt is not None:
|
||||||
|
log.warning("⚠️ [Resume] 同时设置了 pretrained_ckpt 与 resume_ckpt,将优先使用 resume_ckpt 进行断点续训")
|
||||||
|
if resume_path is None:
|
||||||
|
log.warning("⚠️ [Resume] 未找到可恢复的 checkpoint,将从头开始训练")
|
||||||
|
elif not resume_path.exists():
|
||||||
|
log.error(f"❌ [Resume] Checkpoint 文件不存在: {resume_path}")
|
||||||
|
log.warning("⚠️ 将从头开始训练")
|
||||||
|
else:
|
||||||
|
log.info(f"🔄 [Resume] 从 checkpoint 恢复训练: {resume_path}")
|
||||||
|
try:
|
||||||
|
checkpoint = torch.load(resume_path, map_location=cfg.train.device)
|
||||||
|
|
||||||
|
agent.load_state_dict(checkpoint['model_state_dict'], strict=True)
|
||||||
|
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
||||||
|
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
|
||||||
|
|
||||||
|
resume_step = int(checkpoint['step'])
|
||||||
|
start_step = resume_step + 1
|
||||||
|
|
||||||
|
loaded_loss = checkpoint.get('loss', None)
|
||||||
|
loaded_val_loss = checkpoint.get('val_loss', None)
|
||||||
|
resume_loss = float(loaded_loss) if loaded_loss is not None else None
|
||||||
|
if loaded_val_loss is not None:
|
||||||
|
resume_best_loss = float(loaded_val_loss)
|
||||||
|
elif loaded_loss is not None:
|
||||||
|
resume_best_loss = float(loaded_loss)
|
||||||
|
|
||||||
|
log.info(f"✅ [Resume] 恢复成功: 上次步骤={resume_step}, 本次从步骤 {start_step} 开始")
|
||||||
|
log.info(f"📈 [Resume] 当前学习率: {optimizer.param_groups[0]['lr']:.2e}")
|
||||||
|
except Exception as e:
|
||||||
|
log.error(f"❌ [Resume] 恢复失败: {e}")
|
||||||
|
log.warning("⚠️ 将从头开始训练")
|
||||||
|
start_step = 0
|
||||||
|
resume_loss = None
|
||||||
|
resume_best_loss = float('inf')
|
||||||
|
|
||||||
# =========================================================================
|
# =========================================================================
|
||||||
# 5. 训练循环
|
# 5. 训练循环
|
||||||
# =========================================================================
|
# =========================================================================
|
||||||
@@ -316,9 +392,15 @@ def main(cfg: DictConfig):
|
|||||||
return total_loss / max(num_batches, 1)
|
return total_loss / max(num_batches, 1)
|
||||||
|
|
||||||
data_iter = iter(train_loader)
|
data_iter = iter(train_loader)
|
||||||
pbar = tqdm(range(cfg.train.max_steps), desc="训练中", ncols=100)
|
pbar = tqdm(range(start_step, cfg.train.max_steps), desc="训练中", ncols=100)
|
||||||
|
|
||||||
best_loss = float('inf')
|
best_loss = resume_best_loss
|
||||||
|
last_loss = resume_loss
|
||||||
|
|
||||||
|
if start_step >= cfg.train.max_steps:
|
||||||
|
log.warning(
|
||||||
|
f"⚠️ [Resume] start_step={start_step} 已达到/超过 max_steps={cfg.train.max_steps},跳过训练循环"
|
||||||
|
)
|
||||||
|
|
||||||
for step in pbar:
|
for step in pbar:
|
||||||
try:
|
try:
|
||||||
@@ -351,6 +433,8 @@ def main(cfg: DictConfig):
|
|||||||
log.error(f"❌ 步骤 {step} 前向传播失败: {e}")
|
log.error(f"❌ 步骤 {step} 前向传播失败: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
last_loss = loss.item()
|
||||||
|
|
||||||
# =====================================================================
|
# =====================================================================
|
||||||
# 反向传播与优化
|
# 反向传播与优化
|
||||||
# =====================================================================
|
# =====================================================================
|
||||||
@@ -427,15 +511,21 @@ def main(cfg: DictConfig):
|
|||||||
'model_state_dict': agent.state_dict(),
|
'model_state_dict': agent.state_dict(),
|
||||||
'optimizer_state_dict': optimizer.state_dict(),
|
'optimizer_state_dict': optimizer.state_dict(),
|
||||||
'scheduler_state_dict': scheduler.state_dict(),
|
'scheduler_state_dict': scheduler.state_dict(),
|
||||||
'loss': loss.item(),
|
'loss': last_loss,
|
||||||
'dataset_stats': agent_stats, # 保存agent的统计信息
|
'dataset_stats': agent_stats, # 保存agent的统计信息
|
||||||
'current_lr': optimizer.param_groups[0]['lr'],
|
'current_lr': optimizer.param_groups[0]['lr'],
|
||||||
}, final_model_path)
|
}, final_model_path)
|
||||||
log.info(f"💾 最终模型已保存: {final_model_path}")
|
log.info(f"💾 最终模型已保存: {final_model_path}")
|
||||||
|
|
||||||
log.info("✅ 训练成功完成!")
|
log.info("✅ 训练成功完成!")
|
||||||
log.info(f"📊 最终损失: {loss.item():.4f}")
|
if last_loss is not None:
|
||||||
|
log.info(f"📊 最终损失: {last_loss:.4f}")
|
||||||
|
else:
|
||||||
|
log.info("📊 最终损失: N/A(未执行训练步)")
|
||||||
|
if best_loss != float('inf'):
|
||||||
log.info(f"📊 最佳损失: {best_loss:.4f}")
|
log.info(f"📊 最佳损失: {best_loss:.4f}")
|
||||||
|
else:
|
||||||
|
log.info("📊 最佳损失: N/A(无有效验证/训练损失)")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
Reference in New Issue
Block a user