feat: 添加finetune

This commit is contained in:
gouhanke
2026-02-12 19:31:44 +08:00
parent efbe4b6ac9
commit 926a78eb66
2 changed files with 37 additions and 0 deletions

View File

@@ -211,6 +211,40 @@ def main(cfg: DictConfig):
log.error(f"❌ Agent 初始化失败: {e}") log.error(f"❌ Agent 初始化失败: {e}")
raise raise
# =========================================================================
# 3.1 从预训练 checkpoint 加载权重(微调)
# =========================================================================
pretrained_ckpt = cfg.train.get('pretrained_ckpt', None)
if pretrained_ckpt is not None:
ckpt_path = Path(pretrained_ckpt)
if ckpt_path.exists():
log.info(f"🔄 [Finetune] 从预训练 checkpoint 加载权重: {ckpt_path}")
try:
checkpoint = torch.load(ckpt_path, map_location=cfg.train.device)
# 只加载模型权重(不加载 optimizer、scheduler
missing_keys, unexpected_keys = agent.load_state_dict(
checkpoint['model_state_dict'],
strict=False # 允许部分加载(结构不完全匹配时)
)
log.info(f"✅ [Finetune] 模型权重加载成功")
if missing_keys:
log.warning(f"⚠️ [Finetune] 缺少的键 ({len(missing_keys)} 个): {missing_keys[:5]}...")
if unexpected_keys:
log.warning(f"⚠️ [Finetune] 多余的键 ({len(unexpected_keys)} 个): {unexpected_keys[:5]}...")
log.info(f"📊 [Finetune] 预训练信息: 步骤={checkpoint.get('step', 'N/A')}, 损失={checkpoint.get('loss', 'N/A')}")
log.info(f"📈 [Finetune] 使用新的训练配置lr={cfg.train.lr}, max_steps={cfg.train.max_steps}")
except Exception as e:
log.error(f"❌ [Finetune] 加载 checkpoint 失败: {e}")
log.warning("⚠️ 将从头开始训练")
else:
log.error(f"❌ [Finetune] Checkpoint 文件不存在: {ckpt_path}")
log.warning("⚠️ 将从头开始训练")
# ========================================================================= # =========================================================================
# 4. 设置优化器与学习率调度器 # 4. 设置优化器与学习率调度器
# ========================================================================= # =========================================================================

View File

@@ -32,6 +32,9 @@ train:
weight_decay: 1e-5 # 权重衰减L2 正则化) weight_decay: 1e-5 # 权重衰减L2 正则化)
grad_clip: 1.0 # 梯度裁剪阈值 grad_clip: 1.0 # 梯度裁剪阈值
# 微调配置
pretrained_ckpt: null # 预训练 checkpoint 路径(用于微调),例如: "checkpoints/vla_model_step_8000.pt"
# ==================== # ====================
# 实验配置 # 实验配置
# ==================== # ====================