From 926a78eb660e52bb3aeae3654a419ee754b48664 Mon Sep 17 00:00:00 2001 From: gouhanke <12219217+gouhanke@user.noreply.gitee.com> Date: Thu, 12 Feb 2026 19:31:44 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0finetune?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- roboimi/demos/vla_scripts/train_vla.py | 34 ++++++++++++++++++++++++++ roboimi/vla/conf/config.yaml | 3 +++ 2 files changed, 37 insertions(+) diff --git a/roboimi/demos/vla_scripts/train_vla.py b/roboimi/demos/vla_scripts/train_vla.py index d96ca29..4f8f48a 100644 --- a/roboimi/demos/vla_scripts/train_vla.py +++ b/roboimi/demos/vla_scripts/train_vla.py @@ -211,6 +211,40 @@ def main(cfg: DictConfig): log.error(f"❌ Agent 初始化失败: {e}") raise + # ========================================================================= + # 3.1 从预训练 checkpoint 加载权重(微调) + # ========================================================================= + pretrained_ckpt = cfg.train.get('pretrained_ckpt', None) + if pretrained_ckpt is not None: + ckpt_path = Path(pretrained_ckpt) + if ckpt_path.exists(): + log.info(f"🔄 [Finetune] 从预训练 checkpoint 加载权重: {ckpt_path}") + try: + checkpoint = torch.load(ckpt_path, map_location=cfg.train.device) + + # 只加载模型权重(不加载 optimizer、scheduler) + missing_keys, unexpected_keys = agent.load_state_dict( + checkpoint['model_state_dict'], + strict=False # 允许部分加载(结构不完全匹配时) + ) + + log.info(f"✅ [Finetune] 模型权重加载成功") + + if missing_keys: + log.warning(f"⚠️ [Finetune] 缺少的键 ({len(missing_keys)} 个): {missing_keys[:5]}...") + if unexpected_keys: + log.warning(f"⚠️ [Finetune] 多余的键 ({len(unexpected_keys)} 个): {unexpected_keys[:5]}...") + + log.info(f"📊 [Finetune] 预训练信息: 步骤={checkpoint.get('step', 'N/A')}, 损失={checkpoint.get('loss', 'N/A')}") + log.info(f"📈 [Finetune] 使用新的训练配置(lr={cfg.train.lr}, max_steps={cfg.train.max_steps})") + + except Exception as e: + log.error(f"❌ [Finetune] 加载 checkpoint 失败: {e}") + log.warning("⚠️ 将从头开始训练") + else: + log.error(f"❌ [Finetune] Checkpoint 文件不存在: {ckpt_path}") + log.warning("⚠️ 将从头开始训练") + # ========================================================================= # 4. 设置优化器与学习率调度器 # ========================================================================= diff --git a/roboimi/vla/conf/config.yaml b/roboimi/vla/conf/config.yaml index b4cf8c0..8d14c93 100644 --- a/roboimi/vla/conf/config.yaml +++ b/roboimi/vla/conf/config.yaml @@ -32,6 +32,9 @@ train: weight_decay: 1e-5 # 权重衰减(L2 正则化) grad_clip: 1.0 # 梯度裁剪阈值 + # 微调配置 + pretrained_ckpt: null # 预训练 checkpoint 路径(用于微调),例如: "checkpoints/vla_model_step_8000.pt" + # ==================== # 实验配置 # ====================