feat: 添加finetune
This commit is contained in:
@@ -211,6 +211,40 @@ def main(cfg: DictConfig):
|
||||
log.error(f"❌ Agent 初始化失败: {e}")
|
||||
raise
|
||||
|
||||
# =========================================================================
|
||||
# 3.1 从预训练 checkpoint 加载权重(微调)
|
||||
# =========================================================================
|
||||
pretrained_ckpt = cfg.train.get('pretrained_ckpt', None)
|
||||
if pretrained_ckpt is not None:
|
||||
ckpt_path = Path(pretrained_ckpt)
|
||||
if ckpt_path.exists():
|
||||
log.info(f"🔄 [Finetune] 从预训练 checkpoint 加载权重: {ckpt_path}")
|
||||
try:
|
||||
checkpoint = torch.load(ckpt_path, map_location=cfg.train.device)
|
||||
|
||||
# 只加载模型权重(不加载 optimizer、scheduler)
|
||||
missing_keys, unexpected_keys = agent.load_state_dict(
|
||||
checkpoint['model_state_dict'],
|
||||
strict=False # 允许部分加载(结构不完全匹配时)
|
||||
)
|
||||
|
||||
log.info(f"✅ [Finetune] 模型权重加载成功")
|
||||
|
||||
if missing_keys:
|
||||
log.warning(f"⚠️ [Finetune] 缺少的键 ({len(missing_keys)} 个): {missing_keys[:5]}...")
|
||||
if unexpected_keys:
|
||||
log.warning(f"⚠️ [Finetune] 多余的键 ({len(unexpected_keys)} 个): {unexpected_keys[:5]}...")
|
||||
|
||||
log.info(f"📊 [Finetune] 预训练信息: 步骤={checkpoint.get('step', 'N/A')}, 损失={checkpoint.get('loss', 'N/A')}")
|
||||
log.info(f"📈 [Finetune] 使用新的训练配置(lr={cfg.train.lr}, max_steps={cfg.train.max_steps})")
|
||||
|
||||
except Exception as e:
|
||||
log.error(f"❌ [Finetune] 加载 checkpoint 失败: {e}")
|
||||
log.warning("⚠️ 将从头开始训练")
|
||||
else:
|
||||
log.error(f"❌ [Finetune] Checkpoint 文件不存在: {ckpt_path}")
|
||||
log.warning("⚠️ 将从头开始训练")
|
||||
|
||||
# =========================================================================
|
||||
# 4. 设置优化器与学习率调度器
|
||||
# =========================================================================
|
||||
|
||||
@@ -32,6 +32,9 @@ train:
|
||||
weight_decay: 1e-5 # 权重衰减(L2 正则化)
|
||||
grad_clip: 1.0 # 梯度裁剪阈值
|
||||
|
||||
# 微调配置
|
||||
pretrained_ckpt: null # 预训练 checkpoint 路径(用于微调),例如: "checkpoints/vla_model_step_8000.pt"
|
||||
|
||||
# ====================
|
||||
# 实验配置
|
||||
# ====================
|
||||
|
||||
Reference in New Issue
Block a user