feat: 添加finetune
This commit is contained in:
@@ -211,6 +211,40 @@ def main(cfg: DictConfig):
|
|||||||
log.error(f"❌ Agent 初始化失败: {e}")
|
log.error(f"❌ Agent 初始化失败: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# 3.1 从预训练 checkpoint 加载权重(微调)
|
||||||
|
# =========================================================================
|
||||||
|
pretrained_ckpt = cfg.train.get('pretrained_ckpt', None)
|
||||||
|
if pretrained_ckpt is not None:
|
||||||
|
ckpt_path = Path(pretrained_ckpt)
|
||||||
|
if ckpt_path.exists():
|
||||||
|
log.info(f"🔄 [Finetune] 从预训练 checkpoint 加载权重: {ckpt_path}")
|
||||||
|
try:
|
||||||
|
checkpoint = torch.load(ckpt_path, map_location=cfg.train.device)
|
||||||
|
|
||||||
|
# 只加载模型权重(不加载 optimizer、scheduler)
|
||||||
|
missing_keys, unexpected_keys = agent.load_state_dict(
|
||||||
|
checkpoint['model_state_dict'],
|
||||||
|
strict=False # 允许部分加载(结构不完全匹配时)
|
||||||
|
)
|
||||||
|
|
||||||
|
log.info(f"✅ [Finetune] 模型权重加载成功")
|
||||||
|
|
||||||
|
if missing_keys:
|
||||||
|
log.warning(f"⚠️ [Finetune] 缺少的键 ({len(missing_keys)} 个): {missing_keys[:5]}...")
|
||||||
|
if unexpected_keys:
|
||||||
|
log.warning(f"⚠️ [Finetune] 多余的键 ({len(unexpected_keys)} 个): {unexpected_keys[:5]}...")
|
||||||
|
|
||||||
|
log.info(f"📊 [Finetune] 预训练信息: 步骤={checkpoint.get('step', 'N/A')}, 损失={checkpoint.get('loss', 'N/A')}")
|
||||||
|
log.info(f"📈 [Finetune] 使用新的训练配置(lr={cfg.train.lr}, max_steps={cfg.train.max_steps})")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
log.error(f"❌ [Finetune] 加载 checkpoint 失败: {e}")
|
||||||
|
log.warning("⚠️ 将从头开始训练")
|
||||||
|
else:
|
||||||
|
log.error(f"❌ [Finetune] Checkpoint 文件不存在: {ckpt_path}")
|
||||||
|
log.warning("⚠️ 将从头开始训练")
|
||||||
|
|
||||||
# =========================================================================
|
# =========================================================================
|
||||||
# 4. 设置优化器与学习率调度器
|
# 4. 设置优化器与学习率调度器
|
||||||
# =========================================================================
|
# =========================================================================
|
||||||
|
|||||||
@@ -32,6 +32,9 @@ train:
|
|||||||
weight_decay: 1e-5 # 权重衰减(L2 正则化)
|
weight_decay: 1e-5 # 权重衰减(L2 正则化)
|
||||||
grad_clip: 1.0 # 梯度裁剪阈值
|
grad_clip: 1.0 # 梯度裁剪阈值
|
||||||
|
|
||||||
|
# 微调配置
|
||||||
|
pretrained_ckpt: null # 预训练 checkpoint 路径(用于微调),例如: "checkpoints/vla_model_step_8000.pt"
|
||||||
|
|
||||||
# ====================
|
# ====================
|
||||||
# 实验配置
|
# 实验配置
|
||||||
# ====================
|
# ====================
|
||||||
|
|||||||
Reference in New Issue
Block a user