From 642d41dd8f9e3424b1f2e2e6c963cab601ce8be5 Mon Sep 17 00:00:00 2001
From: JiajunLI
Date: Fri, 6 Mar 2026 11:19:30 +0800
Subject: [PATCH] feat: add resume mechanism
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 roboimi/demos/vla_scripts/train_vla.py | 100 +++++++++++++++++++++++--
 1 file changed, 95 insertions(+), 5 deletions(-)

diff --git a/roboimi/demos/vla_scripts/train_vla.py b/roboimi/demos/vla_scripts/train_vla.py
index c4656ca..058776e 100644
--- a/roboimi/demos/vla_scripts/train_vla.py
+++ b/roboimi/demos/vla_scripts/train_vla.py
@@ -5,6 +5,7 @@ import json
 import pickle
 import hydra
 import torch
+import re
 from tqdm import tqdm
 from omegaconf import DictConfig, OmegaConf
 from torch.utils.data import DataLoader, random_split
@@ -44,6 +45,35 @@ def recursive_to_device(data, device):
     return data
 
 
+def resolve_resume_checkpoint(resume_ckpt, checkpoint_dir):
+    """
+    Resolve the checkpoint path used to resume training.
+
+    Args:
+        resume_ckpt: resume_ckpt from the config; a path or "auto"
+        checkpoint_dir: default checkpoint directory
+
+    Returns:
+        Path or None
+    """
+    if resume_ckpt is None:
+        return None
+
+    if str(resume_ckpt).lower() != "auto":
+        return Path(resume_ckpt)
+
+    pattern = re.compile(r"vla_model_step_(\d+)\.pt$")
+    candidates = []
+    for ckpt_path in checkpoint_dir.glob("vla_model_step_*.pt"):
+        match = pattern.search(ckpt_path.name)
+        if match:
+            candidates.append((int(match.group(1)), ckpt_path))
+
+    if not candidates:
+        return None
+    return max(candidates, key=lambda x: x[0])[1]
+
+
 def get_lr_schedule_with_warmup(optimizer, warmup_steps, max_steps, scheduler_type='cosine', min_lr=0):
     """
     Create a learning-rate scheduler with warmup.
@@ -270,6 +300,52 @@ def main(cfg: DictConfig):
     )
     log.info(f"📈 LR scheduler: {scheduler_type}, {warmup_steps} warmup steps (min LR={min_lr})")
 
+    # =========================================================================
+    # 4.1 Resume from checkpoint (restore model, optimizer, scheduler, step)
+    # =========================================================================
+    start_step = 0
+    resume_loss = None
+    resume_best_loss = float('inf')
+
+    resume_ckpt = cfg.train.get('resume_ckpt', None)
+    resume_path = resolve_resume_checkpoint(resume_ckpt, checkpoint_dir)
+    if resume_ckpt is not None:
+        if pretrained_ckpt is not None:
+            log.warning("⚠️ [Resume] Both pretrained_ckpt and resume_ckpt are set; resume_ckpt takes precedence for resuming")
+        if resume_path is None:
+            log.warning("⚠️ [Resume] No resumable checkpoint found; training from scratch")
+        elif not resume_path.exists():
+            log.error(f"❌ [Resume] Checkpoint file does not exist: {resume_path}")
+            log.warning("⚠️ Training from scratch")
+        else:
+            log.info(f"🔄 [Resume] Resuming training from checkpoint: {resume_path}")
+            try:
+                checkpoint = torch.load(resume_path, map_location=cfg.train.device)
+
+                agent.load_state_dict(checkpoint['model_state_dict'], strict=True)
+                optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+                scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
+
+                resume_step = int(checkpoint['step'])
+                start_step = resume_step + 1
+
+                loaded_loss = checkpoint.get('loss', None)
+                loaded_val_loss = checkpoint.get('val_loss', None)
+                resume_loss = float(loaded_loss) if loaded_loss is not None else None
+                if loaded_val_loss is not None:
+                    resume_best_loss = float(loaded_val_loss)
+                elif loaded_loss is not None:
+                    resume_best_loss = float(loaded_loss)
+
+                log.info(f"✅ [Resume] Restored successfully: last step={resume_step}, resuming from step {start_step}")
+                log.info(f"📈 [Resume] Current learning rate: {optimizer.param_groups[0]['lr']:.2e}")
+            except Exception as e:
+                log.error(f"❌ [Resume] Restore failed: {e}")
+                log.warning("⚠️ Training from scratch")
+                start_step = 0
+                resume_loss = None
+                resume_best_loss = float('inf')
+
     # =========================================================================
     # 5. Training loop
     # =========================================================================
@@ -316,9 +392,15 @@ def main(cfg: DictConfig):
         return total_loss / max(num_batches, 1)
 
     data_iter = iter(train_loader)
-    pbar = tqdm(range(cfg.train.max_steps), desc="Training", ncols=100)
+    pbar = tqdm(range(start_step, cfg.train.max_steps), desc="Training", ncols=100)
 
-    best_loss = float('inf')
+    best_loss = resume_best_loss
+    last_loss = resume_loss
+
+    if start_step >= cfg.train.max_steps:
+        log.warning(
+            f"⚠️ [Resume] start_step={start_step} has reached or exceeded max_steps={cfg.train.max_steps}; skipping the training loop"
+        )
 
     for step in pbar:
         try:
@@ -351,6 +433,8 @@
             log.error(f"❌ Step {step}: forward pass failed: {e}")
             raise
 
+        last_loss = loss.item()
+
         # =====================================================================
         # Backward pass and optimization
         # =====================================================================
@@ -427,15 +511,21 @@
         'model_state_dict': agent.state_dict(),
         'optimizer_state_dict': optimizer.state_dict(),
         'scheduler_state_dict': scheduler.state_dict(),
-        'loss': loss.item(),
+        'loss': last_loss,
         'dataset_stats': agent_stats,  # save the agent's dataset statistics
         'current_lr': optimizer.param_groups[0]['lr'],
     }, final_model_path)
     log.info(f"💾 Final model saved: {final_model_path}")
 
     log.info("✅ Training completed successfully!")
-    log.info(f"📊 Final loss: {loss.item():.4f}")
-    log.info(f"📊 Best loss: {best_loss:.4f}")
+    if last_loss is not None:
+        log.info(f"📊 Final loss: {last_loss:.4f}")
+    else:
+        log.info("📊 Final loss: N/A (no training steps executed)")
+    if best_loss != float('inf'):
+        log.info(f"📊 Best loss: {best_loss:.4f}")
+    else:
+        log.info("📊 Best loss: N/A (no valid validation/training loss)")
 
 
 if __name__ == "__main__":
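
Reviewer note (not part of the patch): below is a minimal standalone sketch of the "auto" resolution behavior this patch introduces. It copies resolve_resume_checkpoint verbatim so it can run outside train_vla.py; the temporary directory and step numbers are illustrative only.

    import re
    import tempfile
    from pathlib import Path

    def resolve_resume_checkpoint(resume_ckpt, checkpoint_dir):
        # Copied from the patch: None disables resuming, "auto" scans for
        # the newest vla_model_step_<N>.pt, anything else is a literal path.
        if resume_ckpt is None:
            return None
        if str(resume_ckpt).lower() != "auto":
            return Path(resume_ckpt)
        pattern = re.compile(r"vla_model_step_(\d+)\.pt$")
        candidates = []
        for ckpt_path in checkpoint_dir.glob("vla_model_step_*.pt"):
            match = pattern.search(ckpt_path.name)
            if match:
                candidates.append((int(match.group(1)), ckpt_path))
        if not candidates:
            return None
        return max(candidates, key=lambda x: x[0])[1]

    with tempfile.TemporaryDirectory() as tmp:
        ckpt_dir = Path(tmp)
        for step in (100, 2000, 500):
            (ckpt_dir / f"vla_model_step_{step}.pt").touch()
        # "auto" picks the highest step number, not the newest mtime
        # or the lexicographically last file name:
        assert resolve_resume_checkpoint("auto", ckpt_dir).name == "vla_model_step_2000.pt"
        # None disables resuming; an explicit path is returned as-is:
        assert resolve_resume_checkpoint(None, ckpt_dir) is None
        assert resolve_resume_checkpoint("a/b.pt", ckpt_dir) == Path("a/b.pt")

Since the script reads cfg.train.resume_ckpt via Hydra, the mechanism would presumably be driven from the command line, e.g. `python train_vla.py train.resume_ckpt=auto` (the exact config group name is an assumption based on the cfg.train accesses in the patch).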