"""Training entry point for the VLA policy (ResNet backbone + diffusion head).

File: roboimi/roboimi/demos/vla_scripts/train_vla.py
"""
import sys
import os
import logging
import json
import pickle
import importlib
import hydra
import torch
import re
from tqdm import tqdm
from omegaconf import DictConfig, OmegaConf
from torch.utils.data import DataLoader, random_split
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR
from pathlib import Path
# 确保正确的导入路径
sys.path.append(os.getcwd())
from hydra.utils import instantiate
log = logging.getLogger(__name__)

# Register a list-length resolver (used in configs like ${len:${data.camera_names}}).
if not OmegaConf.has_resolver("len"):
    OmegaConf.register_new_resolver("len", lambda x: len(x))
def recursive_to_device(data, device):
    """
    Recursively move every tensor inside a nested container to a device.

    Args:
        data: A tensor, dict, list, or tuple (arbitrarily nested), or any
            other object, which is returned unchanged.
        device: Target device (e.g. 'cuda', 'cpu').

    Returns:
        The same structure with all contained tensors moved to ``device``.
    """
    if isinstance(data, torch.Tensor):
        return data.to(device)
    if isinstance(data, dict):
        return {k: recursive_to_device(v, device) for k, v in data.items()}
    if isinstance(data, list):
        return [recursive_to_device(v, device) for v in data]
    if isinstance(data, tuple):
        # Fix: tuples (e.g. from custom collate functions) previously passed
        # through unmoved; recurse into them as well.
        return tuple(recursive_to_device(v, device) for v in data)
    return data
def resolve_resume_checkpoint(resume_ckpt, checkpoint_dir):
    """
    Resolve the checkpoint path used to resume training.

    Args:
        resume_ckpt: Config value; an explicit path, or the string "auto"
            (case-insensitive) to pick the highest-step checkpoint.
        checkpoint_dir: Directory searched when "auto" is requested.

    Returns:
        Path to the chosen checkpoint, or None when nothing is resumable.
    """
    if resume_ckpt is None:
        return None
    if str(resume_ckpt).lower() != "auto":
        return Path(resume_ckpt)
    # "auto": scan the checkpoint directory and keep the largest step number.
    step_pattern = re.compile(r"vla_model_step_(\d+)\.pt$")
    best_step = -1
    best_path = None
    for candidate in checkpoint_dir.glob("vla_model_step_*.pt"):
        found = step_pattern.search(candidate.name)
        if found is None:
            continue
        step = int(found.group(1))
        if step > best_step:
            best_step = step
            best_path = candidate
    return best_path
def get_lr_schedule_with_warmup(optimizer, warmup_steps, max_steps, scheduler_type='cosine', min_lr=0):
    """
    Build a warmup-then-decay learning-rate scheduler.

    Args:
        optimizer: PyTorch optimizer.
        warmup_steps: Number of linear warmup steps.
        max_steps: Total number of training steps.
        scheduler_type: Post-warmup schedule ('cosine' or 'constant').
        min_lr: Floor learning rate for the cosine decay.

    Returns:
        A LambdaLR scheduler.
    """
    import math
    # Capture the initial LR before LambdaLR starts scaling it.
    initial_lr = optimizer.param_groups[0]['lr']
    floor_ratio = min_lr / initial_lr if initial_lr > 0 else 0.0

    def scale_at(step):
        # Warmup phase: ramp linearly from 0 up to 1.
        if step < warmup_steps:
            return step / max(1, warmup_steps)
        if scheduler_type != 'cosine':
            # Constant LR after warmup.
            return 1.0
        # Cosine anneal from 1 down to floor_ratio.
        frac = (step - warmup_steps) / max(1, max_steps - warmup_steps)
        return max(floor_ratio, 0.5 * (1.0 + math.cos(math.pi * frac)))

    return LambdaLR(optimizer, scale_at)
def build_training_optimizer(agent, lr, weight_decay):
    """Build the training optimizer, preferring the parameter groups that a
    transformer head exposes via its own ``get_optim_groups``."""
    trainables = [p for p in agent.parameters() if p.requires_grad]
    head = getattr(agent, 'noise_pred_net', None)
    group_factory = getattr(head, 'get_optim_groups', None)
    if getattr(agent, 'head_type', None) != 'transformer' or not callable(group_factory):
        # No head-provided grouping: one flat parameter group suffices.
        return AdamW(trainables, lr=lr, weight_decay=weight_decay)
    seen_ids = set()
    groups = []
    for raw_group in group_factory(weight_decay=weight_decay):
        kept = [p for p in raw_group['params'] if p.requires_grad]
        if not kept:
            continue
        rebuilt = dict(raw_group)
        rebuilt['params'] = kept
        groups.append(rebuilt)
        # Guard against the same parameter appearing in two groups.
        for p in kept:
            if id(p) in seen_ids:
                raise ValueError('Transformer optimizer groups contain duplicate parameters')
            seen_ids.add(id(p))
    # Every trainable head parameter must be covered by the head's groups.
    head_ids = {id(p) for p in head.parameters() if p.requires_grad}
    if head_ids - seen_ids:
        raise ValueError('Transformer optimizer groups missed trainable head parameters')
    # Everything outside the head goes into one extra group.
    leftovers = [p for p in trainables if id(p) not in seen_ids]
    if leftovers:
        groups = groups + [{
            'params': leftovers,
            'weight_decay': weight_decay,
        }]
        seen_ids.update(id(p) for p in leftovers)
    # Final sanity check: exact one-to-one coverage of trainable parameters.
    if seen_ids != {id(p) for p in trainables}:
        raise ValueError('Optimizer parameter groups must include each trainable parameter exactly once')
    return AdamW(groups, lr=lr, weight_decay=weight_decay)
def _init_swanlab(cfg):
"""按需初始化 SwanLab并在缺少依赖或认证失败时快速失败。"""
if not bool(cfg.train.get('use_swanlab', False)):
return None
try:
swanlab = importlib.import_module("swanlab")
except ImportError as exc:
raise RuntimeError(
"SwanLab logging is enabled, but the 'swanlab' package could not be imported."
) from exc
def _to_plain_config(value):
if isinstance(value, dict):
return {key: _to_plain_config(val) for key, val in value.items()}
if isinstance(value, list):
return [_to_plain_config(item) for item in value]
if isinstance(value, tuple):
return tuple(_to_plain_config(item) for item in value)
items_method = getattr(value, 'items', None)
if callable(items_method):
try:
return {key: _to_plain_config(val) for key, val in items_method()}
except Exception:
pass
return value
swanlab_config = {
key: _to_plain_config(cfg[key])
for key in ('train', 'data', 'agent')
if key in cfg
}
init_kwargs = {
'project': cfg.train.get('swanlab_project', 'roboimi-vla'),
'config': swanlab_config,
}
run_name = cfg.train.get('swanlab_run_name', None)
if run_name:
init_kwargs['experiment_name'] = run_name
try:
swanlab.init(**init_kwargs)
except Exception as exc:
raise RuntimeError(
f"SwanLab logging is enabled, but SwanLab init/login failed: {exc}"
) from exc
return swanlab
def _log_to_swanlab(swanlab_module, payload, step=None):
if swanlab_module is None:
return
try:
swanlab_module.log(payload, step=step)
except Exception as exc:
log.warning(f"SwanLab log failed at step {step}: {exc}")
def _finish_swanlab(swanlab_module):
if swanlab_module is None:
return
try:
swanlab_module.finish()
except Exception as exc:
log.warning(f"SwanLab finish failed: {exc}")
def _run_training(cfg: DictConfig):
    """
    VLA training loop (ResNet backbone + diffusion policy).

    This routine:
    1. Loads the dataset from HDF5 files
    2. Instantiates a VLAAgent with a ResNet vision encoder
    3. Trains the diffusion-based action-prediction model
    4. Saves checkpoints periodically (plus "best" and "final" models)

    Args:
        cfg: Hydra/OmegaConf config with ``train``, ``data``, ``agent`` (and,
            for rollout validation, ``eval``) sections.
    """
    # Print the fully resolved configuration up front for reproducibility.
    print("=" * 80)
    print("VLA 训练配置:")
    print("=" * 80)
    print(OmegaConf.to_yaml(cfg))
    print("=" * 80)
    log.info(f"🚀 开始 VLA 训练 (设备: {cfg.train.device})")
    swanlab_module = _init_swanlab(cfg)
    try:
        # Create the checkpoint directory (relative to the current run dir).
        checkpoint_dir = Path("checkpoints")
        checkpoint_dir.mkdir(exist_ok=True)
        default_best_model_path = checkpoint_dir / "vla_model_best.pt"
        # =====================================================================
        # 1. Instantiate the dataset and DataLoaders
        # =====================================================================
        log.info("📦 加载数据集...")
        try:
            dataset = instantiate(cfg.data)
            log.info(f"✅ 数据集加载成功。总样本数: {len(dataset)}")
        except Exception as e:
            log.error(f"❌ 数据集加载失败: {e}")
            raise
        # Train/validation split (seeded so a resumed run reproduces it).
        val_split = float(cfg.train.get('val_split', 0.1))
        seed = int(cfg.train.get('seed', 42))
        val_size = int(len(dataset) * val_split)
        train_size = len(dataset) - val_size
        if val_size > 0:
            train_dataset, val_dataset = random_split(
                dataset,
                [train_size, val_size],
                generator=torch.Generator().manual_seed(seed)
            )
            log.info(f"✅ 数据集划分: 训练集={train_size}, 验证集={val_size} (验证比例={val_split})")
        else:
            train_dataset, val_dataset = dataset, None
            log.info("✅ 数据集划分: 全部用于训练, 验证集=0 (验证比例=0)")
        train_batch_size = int(cfg.train.batch_size)
        # Only drop the last incomplete batch when the dataset is large enough;
        # otherwise the train loader would be empty.
        train_drop_last = len(train_dataset) >= train_batch_size
        if not train_drop_last:
            log.warning(
                "⚠️ 训练集样本数 (%s) 小于 batch_size (%s),将保留最后一个不完整批次以避免空训练加载器",
                len(train_dataset),
                train_batch_size,
            )
        train_loader = DataLoader(
            train_dataset,
            batch_size=train_batch_size,
            shuffle=True,
            num_workers=cfg.train.num_workers,
            pin_memory=(cfg.train.device != "cpu"),
            persistent_workers=False,
            drop_last=train_drop_last
        )
        val_loader = None
        if val_dataset is not None:
            val_loader = DataLoader(
                val_dataset,
                batch_size=train_batch_size,
                shuffle=False,
                num_workers=cfg.train.num_workers,
                pin_memory=(cfg.train.device != "cpu"),
                persistent_workers=False,
                drop_last=False
            )
        log.info(f"✅ 训练加载器每轮批次数: {len(train_loader)}")
        if val_loader is not None:
            log.info(f"✅ 验证加载器每轮批次数: {len(val_loader)}")
        # =====================================================================
        # 2. Load dataset statistics (passed to the agent for normalization)
        # =====================================================================
        log.info("💾 加载数据集统计信息...")
        dataset_stats = None
        try:
            dataset_dir = cfg.data.get('dataset_dir', 'roboimi/demos/dataset/sim_transfer')
            stats_path = Path(dataset_dir) / 'dataset_stats.pkl'
            if stats_path.exists():
                with open(stats_path, 'rb') as f:
                    stats = pickle.load(f)
                # Flatten the stats dict (nested structure -> flat keys) to
                # match the format the NormalizationModule expects.
                dataset_stats = {
                    'action_mean': stats['action_mean'].tolist(),
                    'action_std': stats['action_std'].tolist(),
                    'action_min': stats['action_min'].tolist(),
                    'action_max': stats['action_max'].tolist(),
                    'qpos_mean': stats['qpos_mean'].tolist(),
                    'qpos_std': stats['qpos_std'].tolist(),
                    'qpos_min': stats['qpos_min'].tolist(),
                    'qpos_max': stats['qpos_max'].tolist(),
                }
                log.info(f"✅ 数据集统计信息加载完成 (归一化: {cfg.agent.normalization_type})")
            else:
                log.warning(f"⚠️ 统计文件未找到: {stats_path}")
                log.warning("⚠️ 推理时动作将无法反归一化!")
        except Exception as e:
            # Missing stats are non-fatal for training, but inference will
            # not be able to de-normalize actions.
            log.warning(f"⚠️ 统计信息加载失败: {e}")
            log.warning("⚠️ 训练将继续,但推理可能无法正常工作")
        # =====================================================================
        # 3. Instantiate the VLA agent
        # =====================================================================
        log.info("🤖 初始化 VLA Agent...")
        try:
            # Pass dataset_stats through so the agent can normalize its I/O.
            agent = instantiate(cfg.agent, dataset_stats=dataset_stats)
            agent.to(cfg.train.device)
            agent.train()
            log.info(f"✅ Agent 初始化完成并已移至 {cfg.train.device}")
            # Parameter counts, for a quick sanity check in the logs.
            total_params = sum(p.numel() for p in agent.parameters())
            trainable_params = sum(p.numel() for p in agent.parameters() if p.requires_grad)
            log.info(f"📊 总参数量: {total_params:,}")
            log.info(f"📊 可训练参数量: {trainable_params:,}")
        except Exception as e:
            log.error(f"❌ Agent 初始化失败: {e}")
            raise
        # =====================================================================
        # 3.1 Load weights from a pretrained checkpoint (fine-tuning)
        # =====================================================================
        pretrained_ckpt = cfg.train.get('pretrained_ckpt', None)
        if pretrained_ckpt is not None:
            ckpt_path = Path(pretrained_ckpt)
            if ckpt_path.exists():
                log.info(f"🔄 [Finetune] 从预训练 checkpoint 加载权重: {ckpt_path}")
                try:
                    checkpoint = torch.load(ckpt_path, map_location=cfg.train.device)
                    # Load model weights only; optimizer/scheduler state is
                    # intentionally NOT restored when fine-tuning.
                    missing_keys, unexpected_keys = agent.load_state_dict(
                        checkpoint['model_state_dict'],
                        strict=False  # allow partial loads when structures differ
                    )
                    log.info(f"✅ [Finetune] 模型权重加载成功")
                    if missing_keys:
                        log.warning(f"⚠️ [Finetune] 缺少的键 ({len(missing_keys)} 个): {missing_keys[:5]}...")
                    if unexpected_keys:
                        log.warning(f"⚠️ [Finetune] 多余的键 ({len(unexpected_keys)} 个): {unexpected_keys[:5]}...")
                    log.info(f"📊 [Finetune] 预训练信息: 步骤={checkpoint.get('step', 'N/A')}, 损失={checkpoint.get('loss', 'N/A')}")
                    log.info(f"📈 [Finetune] 使用新的训练配置lr={cfg.train.lr}, max_steps={cfg.train.max_steps}")
                except Exception as e:
                    log.error(f"❌ [Finetune] 加载 checkpoint 失败: {e}")
                    log.warning("⚠️ 将从头开始训练")
            else:
                log.error(f"❌ [Finetune] Checkpoint 文件不存在: {ckpt_path}")
                log.warning("⚠️ 将从头开始训练")
        # =====================================================================
        # 4. Optimizer and learning-rate scheduler
        # =====================================================================
        weight_decay = float(cfg.train.get('weight_decay', 1e-5))
        grad_clip = float(cfg.train.get('grad_clip', 1.0))
        optimizer = build_training_optimizer(agent, lr=cfg.train.lr, weight_decay=weight_decay)
        log.info(f"🔧 优化器: AdamW (学习率={cfg.train.lr}, weight_decay={weight_decay})")
        # Warmup + (cosine|constant) learning-rate schedule.
        warmup_steps = int(cfg.train.get('warmup_steps', 500))
        scheduler_type = cfg.train.get('scheduler_type', 'cosine')
        min_lr = float(cfg.train.get('min_lr', 1e-6))
        scheduler = get_lr_schedule_with_warmup(
            optimizer,
            warmup_steps=warmup_steps,
            max_steps=cfg.train.max_steps,
            scheduler_type=scheduler_type,
            min_lr=min_lr
        )
        log.info(f"📈 学习率调度器: {scheduler_type}{warmup_steps} 步预热 (最小学习率={min_lr})")
        # =====================================================================
        # 4.1 Resume training (restore model, optimizer, scheduler and step)
        # =====================================================================
        def extract_checkpoint_metric_baseline(checkpoint):
            # Recover the "best so far" baselines stored in a checkpoint:
            # prefer val_loss over train loss; rollout reward is tracked
            # separately (higher is better).
            checkpoint_loss = checkpoint.get('loss', None)
            checkpoint_val_loss = checkpoint.get('val_loss', None)
            checkpoint_rollout_reward = checkpoint.get('rollout_avg_reward', None)
            baseline_loss = float('inf')
            baseline_rollout_reward = float('-inf')
            if checkpoint_rollout_reward is not None:
                baseline_rollout_reward = float(checkpoint_rollout_reward)
            if checkpoint_val_loss is not None:
                baseline_loss = float(checkpoint_val_loss)
            elif checkpoint_loss is not None:
                baseline_loss = float(checkpoint_loss)
            return baseline_loss, baseline_rollout_reward
        start_step = 0
        resume_loss = None
        resume_best_loss = float('inf')
        resume_best_rollout_reward = float('-inf')
        best_model_path = None
        resume_ckpt = cfg.train.get('resume_ckpt', None)
        resume_path = resolve_resume_checkpoint(resume_ckpt, checkpoint_dir)
        if resume_ckpt is not None:
            if pretrained_ckpt is not None:
                log.warning("⚠️ [Resume] 同时设置了 pretrained_ckpt 与 resume_ckpt将优先使用 resume_ckpt 进行断点续训")
            if resume_path is None:
                log.warning("⚠️ [Resume] 未找到可恢复的 checkpoint将从头开始训练")
            elif not resume_path.exists():
                log.error(f"❌ [Resume] Checkpoint 文件不存在: {resume_path}")
                log.warning("⚠️ 将从头开始训练")
            else:
                log.info(f"🔄 [Resume] 从 checkpoint 恢复训练: {resume_path}")
                try:
                    checkpoint = torch.load(resume_path, map_location=cfg.train.device)
                    agent.load_state_dict(checkpoint['model_state_dict'], strict=True)
                    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
                    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
                    resume_step = int(checkpoint['step'])
                    start_step = resume_step + 1
                    loaded_loss = checkpoint.get('loss', None)
                    resume_loss = float(loaded_loss) if loaded_loss is not None else None
                    resume_best_loss, resume_best_rollout_reward = extract_checkpoint_metric_baseline(checkpoint)
                    if (
                        resume_best_rollout_reward != float('-inf')
                        or resume_best_loss != float('inf')
                    ):
                        best_model_path = resume_path
                    if default_best_model_path.exists():
                        # Prefer the rollout baseline stored in the saved
                        # "best" checkpoint, when one exists.
                        try:
                            best_checkpoint = torch.load(default_best_model_path, map_location=cfg.train.device)
                            _, best_checkpoint_rollout_reward = (
                                extract_checkpoint_metric_baseline(best_checkpoint)
                            )
                            if best_checkpoint_rollout_reward != float('-inf'):
                                resume_best_rollout_reward = best_checkpoint_rollout_reward
                                best_model_path = default_best_model_path
                                log.info(
                                    "📈 [Resume] 从最佳 checkpoint 恢复最佳 rollout 基线: %s",
                                    default_best_model_path,
                                )
                        except Exception as e:
                            log.warning(
                                f"⚠️ [Resume] 读取最佳 checkpoint 失败,将回退到恢复 checkpoint 的验证基线: {e}"
                            )
                    log.info(f"✅ [Resume] 恢复成功: 上次步骤={resume_step}, 本次从步骤 {start_step} 开始")
                    log.info(f"📈 [Resume] 当前学习率: {optimizer.param_groups[0]['lr']:.2e}")
                except Exception as e:
                    # On any restore failure, fall back to a fresh run.
                    log.error(f"❌ [Resume] 恢复失败: {e}")
                    log.warning("⚠️ 将从头开始训练")
                    start_step = 0
                    resume_loss = None
                    resume_best_loss = float('inf')
                    resume_best_rollout_reward = float('-inf')
        # =====================================================================
        # 5. Training loop
        # =====================================================================
        log.info("🏋️ 开始训练循环...")
        def build_agent_input(batch_data):
            """Map a dataset batch into the input dict the agent expects."""
            images = {}
            # SimpleRobotDataset returns keys in "observation.{cam_name}" form.
            for cam_name in cfg.data.camera_names:
                key = f"observation.{cam_name}"
                if key in batch_data:
                    images[cam_name] = batch_data[key]
            return {
                'images': images,
                'qpos': batch_data['observation.state'],  # SimpleRobotDataset uses observation.state
                'action': batch_data['action'],
                'action_is_pad': batch_data.get('action_is_pad', None)  # pass the padding mask through
            }
        def save_checkpoint(checkpoint_path: Path, step: int, loss_value, val_loss=None, rollout_avg_reward=None):
            # Persist everything needed to resume: weights, optimizer,
            # scheduler, metric baselines and the agent's normalization stats.
            agent_stats = agent.get_normalization_stats()
            torch.save({
                'step': step,
                'model_state_dict': agent.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'loss': loss_value,
                'val_loss': val_loss,
                'rollout_avg_reward': rollout_avg_reward,
                'dataset_stats': agent_stats,  # save the agent's statistics
                'current_lr': optimizer.param_groups[0]['lr'],
            }, checkpoint_path)
            return checkpoint_path
        def run_validation():
            """Run one pass over the validation loader; return mean loss or None."""
            if val_loader is None:
                return None
            agent.eval()
            # Fix the RNG seed for a reproducible loss so validation losses
            # are comparable across different training steps.
            torch.manual_seed(42)
            if torch.cuda.is_available():
                torch.cuda.manual_seed(42)
            total_loss = 0.0
            num_batches = 0
            with torch.no_grad():
                for val_batch in val_loader:
                    val_batch = recursive_to_device(val_batch, cfg.train.device)
                    val_input = build_agent_input(val_batch)
                    val_loss = agent.compute_loss(val_input)
                    total_loss += val_loss.item()
                    num_batches += 1
            agent.train()
            return total_loss / max(num_batches, 1)
        def run_rollout_validation(checkpoint_path: Path):
            # Evaluate a saved checkpoint in the simulator (headless, on CPU)
            # and return eval_vla's stats dict.
            from roboimi.demos.vla_scripts import eval_vla
            rollout_cfg = OmegaConf.create(OmegaConf.to_container(cfg, resolve=False))
            rollout_cfg.eval.ckpt_path = str(checkpoint_path)
            rollout_cfg.eval.num_episodes = int(cfg.train.get('rollout_num_episodes', 1))
            rollout_cfg.eval.headless = True
            rollout_cfg.eval.device = 'cpu'
            rollout_cfg.eval.verbose_action = False
            log.info(
                "🎯 开始 checkpoint rollout 验证: %s (episodes=%s, headless=True)",
                checkpoint_path,
                rollout_cfg.eval.num_episodes,
            )
            return eval_vla._run_eval(rollout_cfg)
        def run_checkpoint_rollout_validation(checkpoint_path: Path):
            # Per-checkpoint rollout validation is opt-in via config.
            if not bool(cfg.train.get('rollout_validate_on_checkpoint', False)):
                return None
            return run_rollout_validation(checkpoint_path)
        data_iter = iter(train_loader)
        pbar = tqdm(range(start_step, cfg.train.max_steps), desc="训练中", ncols=100)
        steps_per_epoch = len(train_loader)
        rollout_val_freq_epochs = int(cfg.train.get('rollout_val_freq_epochs', 0) or 0)
        rollout_validation_enabled = rollout_val_freq_epochs > 0
        best_loss = resume_best_loss
        best_rollout_reward = resume_best_rollout_reward
        last_loss = resume_loss
        if start_step >= cfg.train.max_steps:
            log.warning(
                f"⚠️ [Resume] start_step={start_step} 已达到/超过 max_steps={cfg.train.max_steps},跳过训练循环"
            )
        for step in pbar:
            try:
                batch = next(data_iter)
            except StopIteration:
                # Restart the iterator at the end of each epoch.
                data_iter = iter(train_loader)
                batch = next(data_iter)
            # =================================================================
            # Move the batch to the training device
            # =================================================================
            batch = recursive_to_device(batch, cfg.train.device)
            # =================================================================
            # Prepare the agent input
            # =================================================================
            # Dataset yields: {action, qpos, image_<cam_name>, ...}
            # Agent expects: {images: dict, qpos: tensor, action: tensor}
            agent_input = build_agent_input(batch)
            # =================================================================
            # Forward pass and loss
            # =================================================================
            try:
                loss = agent.compute_loss(agent_input)
            except Exception as e:
                log.error(f"❌ 步骤 {step} 前向传播失败: {e}")
                raise
            last_loss = loss.item()
            # =================================================================
            # Backward pass and optimization
            # =================================================================
            optimizer.zero_grad()
            loss.backward()
            # Clip gradients to stabilize training.
            torch.nn.utils.clip_grad_norm_(agent.parameters(), max_norm=grad_clip)
            optimizer.step()
            scheduler.step()
            # =================================================================
            # Logging
            # =================================================================
            if step % cfg.train.log_freq == 0:
                current_lr = optimizer.param_groups[0]['lr']
                best_loss_to_log = best_loss if best_loss != float('inf') else loss.item()
                pbar.set_postfix({
                    "loss": f"{loss.item():.4f}",
                    "lr": f"{current_lr:.2e}",
                    "best_loss": f"{best_loss_to_log:.4f}"
                })
                log.info(f"步骤 {step}/{cfg.train.max_steps} | 损失: {loss.item():.4f} | 学习率: {current_lr:.2e}")
                _log_to_swanlab(
                    swanlab_module,
                    {
                        'train/loss': loss.item(),
                        'train/lr': current_lr,
                        'train/best_loss': best_loss_to_log,
                        'train/step': step,
                    },
                    step=step,
                )
            # =================================================================
            # Checkpoint saving and validation
            # =================================================================
            checkpoint_path = None
            val_loss = None
            if step > 0 and step % cfg.train.save_freq == 0:
                # Run held-out validation first.
                val_loss = run_validation()
                if val_loss is not None:
                    log.info(f"步骤 {step}/{cfg.train.max_steps} | 验证损失: {val_loss:.4f}")
                    _log_to_swanlab(
                        swanlab_module,
                        {'val/loss': val_loss},
                        step=step,
                    )
                checkpoint_path = checkpoint_dir / f"vla_model_step_{step}.pt"
                save_checkpoint(
                    checkpoint_path,
                    step,
                    loss.item(),
                    val_loss=val_loss,
                )
                log.info(f"💾 检查点已保存: {checkpoint_path}")
                # Until a rollout average reward is available, fall back to
                # loss as the best-model selection metric.
                if best_rollout_reward == float('-inf'):
                    eval_loss = val_loss if val_loss is not None else loss.item()
                    if eval_loss < best_loss:
                        best_loss = eval_loss
                        best_model_path = default_best_model_path
                        save_checkpoint(
                            best_model_path,
                            step,
                            loss.item(),
                            val_loss=val_loss,
                        )
                        log.info(f"🌟 最佳模型已更新: {best_model_path} (验证损失: {best_loss:.4f})")
                checkpoint_rollout_stats = run_checkpoint_rollout_validation(checkpoint_path)
                checkpoint_rollout_avg_reward = (
                    checkpoint_rollout_stats.get('avg_reward')
                    if checkpoint_rollout_stats is not None else None
                )
                if checkpoint_rollout_avg_reward is not None:
                    log.info(
                        f"步骤 {step}/{cfg.train.max_steps} | checkpoint rollout 平均奖励: "
                        f"{checkpoint_rollout_avg_reward:.4f}"
                    )
                    _log_to_swanlab(
                        swanlab_module,
                        {'rollout/avg_reward': checkpoint_rollout_avg_reward},
                        step=step,
                    )
                    if checkpoint_rollout_avg_reward > best_rollout_reward:
                        best_rollout_reward = checkpoint_rollout_avg_reward
                        best_model_path = default_best_model_path
                        save_checkpoint(
                            best_model_path,
                            step,
                            loss.item(),
                            val_loss=val_loss,
                            rollout_avg_reward=checkpoint_rollout_avg_reward,
                        )
                        log.info(
                            f"🌟 最佳模型已更新: {best_model_path} "
                            f"(checkpoint rollout 平均奖励: {best_rollout_reward:.4f})"
                        )
            # Per-epoch rollout validation (every rollout_val_freq_epochs epochs).
            completed_steps = step + 1
            completed_epoch = (
                completed_steps // steps_per_epoch
                if steps_per_epoch > 0 else 0
            )
            should_run_epoch_rollout = (
                rollout_validation_enabled
                and steps_per_epoch > 0
                and completed_steps % steps_per_epoch == 0
                and completed_epoch > 0
                and completed_epoch % rollout_val_freq_epochs == 0
            )
            if should_run_epoch_rollout:
                if checkpoint_path is None:
                    # Rollout evaluation loads from disk; save a checkpoint
                    # if this step didn't already produce one.
                    checkpoint_path = checkpoint_dir / f"vla_model_step_{step}.pt"
                    save_checkpoint(
                        checkpoint_path,
                        step,
                        loss.item(),
                        val_loss=val_loss,
                    )
                    log.info(f"💾 Epoch rollout 验证前检查点已保存: {checkpoint_path}")
                rollout_stats = run_rollout_validation(checkpoint_path)
                rollout_avg_reward = (
                    rollout_stats.get('avg_reward')
                    if rollout_stats is not None else None
                )
                if rollout_avg_reward is not None:
                    log.info(
                        f"步骤 {step}/{cfg.train.max_steps} | Epoch {completed_epoch} "
                        f"rollout 平均奖励: {rollout_avg_reward:.4f}"
                    )
                    _log_to_swanlab(
                        swanlab_module,
                        {
                            'rollout/avg_reward': rollout_avg_reward,
                            'rollout/epoch': completed_epoch,
                        },
                        step=step,
                    )
                    if rollout_avg_reward > best_rollout_reward:
                        best_rollout_reward = rollout_avg_reward
                        best_model_path = default_best_model_path
                        save_checkpoint(
                            best_model_path,
                            step,
                            loss.item(),
                            val_loss=val_loss,
                            rollout_avg_reward=rollout_avg_reward,
                        )
                        log.info(
                            f"🌟 最佳模型已更新: {best_model_path} "
                            f"(Epoch {completed_epoch} rollout 平均奖励: {best_rollout_reward:.4f})"
                        )
        # =====================================================================
        # 6. Save the final model
        # =====================================================================
        final_model_path = checkpoint_dir / "vla_model_final.pt"
        save_checkpoint(
            final_model_path,
            cfg.train.max_steps,
            last_loss,
        )
        log.info(f"💾 最终模型已保存: {final_model_path}")
        _log_to_swanlab(
            swanlab_module,
            {
                'final/checkpoint_path': str(final_model_path),
                'final/best_checkpoint_path': (
                    str(best_model_path) if best_model_path is not None else ''
                ),
            },
            step=cfg.train.max_steps,
        )
        log.info("✅ 训练成功完成!")
        if last_loss is not None:
            log.info(f"📊 最终损失: {last_loss:.4f}")
        else:
            log.info("📊 最终损失: N/A未执行训练步")
        if best_rollout_reward != float('-inf'):
            log.info(f"📊 最佳 rollout 平均奖励: {best_rollout_reward:.4f}")
        elif best_loss != float('inf'):
            log.info(f"📊 最佳损失: {best_loss:.4f}")
        else:
            log.info("📊 最佳验证指标: N/A无有效 rollout/验证损失)")
    finally:
        # Always close the SwanLab run, even when training raised.
        _finish_swanlab(swanlab_module)
@hydra.main(version_base=None, config_path="../../vla/conf", config_name="config")
def main(cfg: DictConfig):
    """Hydra entry point: delegate to the actual training routine."""
    _run_training(cfg)


if __name__ == "__main__":
    main()