857 lines
35 KiB
Python
857 lines
35 KiB
Python
import sys
|
||
import os
|
||
import logging
|
||
import json
|
||
import pickle
|
||
import importlib
|
||
import hydra
|
||
import torch
|
||
import re
|
||
from tqdm import tqdm
|
||
from omegaconf import DictConfig, OmegaConf
|
||
from torch.utils.data import DataLoader, random_split
|
||
from torch.optim import AdamW
|
||
from torch.optim.lr_scheduler import LambdaLR
|
||
from pathlib import Path
|
||

# Ensure the correct import path (make project-root imports resolvable).
sys.path.append(os.getcwd())

from hydra.utils import instantiate

log = logging.getLogger(__name__)

# Register a list-length resolver (used in configs as ${len:${data.camera_names}}).
if not OmegaConf.has_resolver("len"):
    OmegaConf.register_new_resolver("len", lambda x: len(x))
|
||
|
||
|
||
def recursive_to_device(data, device):
    """
    Recursively move every tensor inside a nested container to *device*.

    Args:
        data: a tensor, or an arbitrarily nested dict/list/tuple of tensors;
            non-tensor leaves are returned unchanged.
        device: target device (e.g. 'cuda', 'cpu')

    Returns:
        The same structure with all tensors moved to the requested device.
    """
    if isinstance(data, torch.Tensor):
        return data.to(device)
    elif isinstance(data, dict):
        return {k: recursive_to_device(v, device) for k, v in data.items()}
    elif isinstance(data, list):
        return [recursive_to_device(v, device) for v in data]
    elif isinstance(data, tuple):
        # Fix: tuples (e.g. produced by default collate) previously passed
        # through untouched, leaving any tensors inside on the old device.
        return tuple(recursive_to_device(v, device) for v in data)
    return data
|
||
|
||
|
||
def resolve_resume_checkpoint(resume_ckpt, checkpoint_dir):
    """
    Resolve the checkpoint path used to resume training.

    Args:
        resume_ckpt: value from the config; either an explicit path or "auto"
            (case-insensitive), or None to disable resuming.
        checkpoint_dir: default checkpoint directory scanned in "auto" mode.

    Returns:
        Path to the checkpoint to resume from, or None.
    """
    if resume_ckpt is None:
        return None

    if str(resume_ckpt).lower() != "auto":
        return Path(resume_ckpt)

    # "auto": pick the step checkpoint with the highest step number.
    step_pattern = re.compile(r"vla_model_step_(\d+)\.pt$")
    best_step = -1
    best_path = None
    for candidate in checkpoint_dir.glob("vla_model_step_*.pt"):
        matched = step_pattern.search(candidate.name)
        if matched is None:
            continue
        step = int(matched.group(1))
        if step > best_step:
            best_step, best_path = step, candidate
    return best_path
|
||
|
||
|
||
def get_lr_schedule_with_warmup(optimizer, warmup_steps, max_steps, scheduler_type='cosine', min_lr=0):
    """
    Create a learning-rate scheduler with linear warmup.

    Args:
        optimizer: PyTorch optimizer; its current lr is taken as the base lr.
        warmup_steps: number of linear warmup steps (lr ramps 0 -> base lr).
        max_steps: total number of training steps.
        scheduler_type: post-warmup schedule, 'cosine' or 'constant'.
        min_lr: floor learning rate for the cosine decay.

    Returns:
        LambdaLR scheduler.
    """
    import math
    # Capture the base learning rate before LambdaLR starts scaling it.
    base_lr = optimizer.param_groups[0]['lr']
    min_lr_ratio = min_lr / base_lr if base_lr > 0 else 0.0

    def lr_lambda(step):
        # Warmup phase: linear ramp from 0 to 1.
        if step < warmup_steps:
            return float(step) / float(max(1, warmup_steps))

        # Post-warmup phase
        if scheduler_type == 'cosine':
            # Cosine anneal from 1 down to min_lr_ratio. Clamp progress to 1.0
            # so steps beyond max_steps (possible after a resume) stay at the
            # floor instead of climbing back up the cosine curve.
            progress = float(step - warmup_steps) / float(max(1, max_steps - warmup_steps))
            progress = min(progress, 1.0)
            cosine_decay = 0.5 * (1.0 + math.cos(math.pi * progress))
            return max(min_lr_ratio, cosine_decay)
        else:
            # Constant learning rate after warmup.
            return 1.0

    return LambdaLR(optimizer, lr_lambda)
|
||
|
||
|
||
def build_training_optimizer(agent, lr, weight_decay):
    """Build the training optimizer, preferring the transformer head's own
    parameter grouping when the head provides one.

    Args:
        agent: model whose trainable parameters are optimized.
        lr: base learning rate applied to every group.
        weight_decay: default weight decay, also forwarded to the head groups.

    Returns:
        An AdamW optimizer covering exactly the trainable parameters of *agent*.

    Raises:
        ValueError: if the head groups duplicate or miss trainable parameters.
    """
    trainable = [p for p in agent.parameters() if p.requires_grad]

    head = getattr(agent, 'noise_pred_net', None)
    group_factory = getattr(head, 'get_optim_groups', None)
    if getattr(agent, 'head_type', None) != 'transformer' or not callable(group_factory):
        # No custom grouping available: a single flat parameter group.
        return AdamW(trainable, lr=lr, weight_decay=weight_decay)

    groups = []
    seen_ids = set()
    for raw_group in group_factory(weight_decay=weight_decay):
        kept = [p for p in raw_group['params'] if p.requires_grad]
        if not kept:
            continue
        group = dict(raw_group)
        group['params'] = kept
        groups.append(group)

        for p in kept:
            if id(p) in seen_ids:
                raise ValueError('Transformer optimizer groups contain duplicate parameters')
            seen_ids.add(id(p))

    # Every trainable head parameter must be covered by the head groups.
    head_ids = {id(p) for p in head.parameters() if p.requires_grad}
    if head_ids - seen_ids:
        raise ValueError('Transformer optimizer groups missed trainable head parameters')

    # Everything outside the head goes into one default-weight-decay group.
    leftovers = [p for p in trainable if id(p) not in seen_ids]
    if leftovers:
        groups = groups + [{
            'params': leftovers,
            'weight_decay': weight_decay,
        }]
        seen_ids.update(id(p) for p in leftovers)

    if seen_ids != {id(p) for p in trainable}:
        raise ValueError('Optimizer parameter groups must include each trainable parameter exactly once')

    return AdamW(groups, lr=lr, weight_decay=weight_decay)
|
||
|
||
|
||
def _init_swanlab(cfg):
|
||
"""按需初始化 SwanLab,并在缺少依赖或认证失败时快速失败。"""
|
||
if not bool(cfg.train.get('use_swanlab', False)):
|
||
return None
|
||
|
||
try:
|
||
swanlab = importlib.import_module("swanlab")
|
||
except ImportError as exc:
|
||
raise RuntimeError(
|
||
"SwanLab logging is enabled, but the 'swanlab' package could not be imported."
|
||
) from exc
|
||
|
||
def _to_plain_config(value):
|
||
if isinstance(value, dict):
|
||
return {key: _to_plain_config(val) for key, val in value.items()}
|
||
if isinstance(value, list):
|
||
return [_to_plain_config(item) for item in value]
|
||
if isinstance(value, tuple):
|
||
return tuple(_to_plain_config(item) for item in value)
|
||
|
||
items_method = getattr(value, 'items', None)
|
||
if callable(items_method):
|
||
try:
|
||
return {key: _to_plain_config(val) for key, val in items_method()}
|
||
except Exception:
|
||
pass
|
||
|
||
return value
|
||
|
||
swanlab_config = {
|
||
key: _to_plain_config(cfg[key])
|
||
for key in ('train', 'data', 'agent')
|
||
if key in cfg
|
||
}
|
||
|
||
init_kwargs = {
|
||
'project': cfg.train.get('swanlab_project', 'roboimi-vla'),
|
||
'config': swanlab_config,
|
||
}
|
||
run_name = cfg.train.get('swanlab_run_name', None)
|
||
if run_name:
|
||
init_kwargs['experiment_name'] = run_name
|
||
|
||
try:
|
||
swanlab.init(**init_kwargs)
|
||
except Exception as exc:
|
||
raise RuntimeError(
|
||
f"SwanLab logging is enabled, but SwanLab init/login failed: {exc}"
|
||
) from exc
|
||
|
||
return swanlab
|
||
|
||
|
||
def _log_to_swanlab(swanlab_module, payload, step=None):
|
||
if swanlab_module is None:
|
||
return
|
||
try:
|
||
swanlab_module.log(payload, step=step)
|
||
except Exception as exc:
|
||
log.warning(f"SwanLab log failed at step {step}: {exc}")
|
||
|
||
|
||
def _finish_swanlab(swanlab_module):
|
||
if swanlab_module is None:
|
||
return
|
||
try:
|
||
swanlab_module.finish()
|
||
except Exception as exc:
|
||
log.warning(f"SwanLab finish failed: {exc}")
|
||
|
||
|
||
def _run_training(cfg: DictConfig):
    """
    VLA training script (ResNet backbone + Diffusion policy).

    This script:
    1. Loads the dataset from HDF5 files
    2. Instantiates a VLAAgent with a ResNet vision encoder
    3. Trains the diffusion-based action-prediction model
    4. Saves checkpoints periodically
    """

    # Print the composed config for the run log.
    print("=" * 80)
    print("VLA 训练配置:")
    print("=" * 80)
    print(OmegaConf.to_yaml(cfg))
    print("=" * 80)

    log.info(f"🚀 开始 VLA 训练 (设备: {cfg.train.device})")
    swanlab_module = _init_swanlab(cfg)
    try:
        # Create the checkpoint directory.
        checkpoint_dir = Path("checkpoints")
        checkpoint_dir.mkdir(exist_ok=True)
        default_best_model_path = checkpoint_dir / "vla_model_best.pt"

        # =========================================================================
        # 1. Instantiate the dataset and DataLoaders
        # =========================================================================
        log.info("📦 加载数据集...")
        try:
            dataset = instantiate(cfg.data)
            log.info(f"✅ 数据集加载成功。总样本数: {len(dataset)}")
        except Exception as e:
            log.error(f"❌ 数据集加载失败: {e}")
            raise

        # Train/validation split (seeded so the split is reproducible).
        val_split = float(cfg.train.get('val_split', 0.1))
        seed = int(cfg.train.get('seed', 42))
        val_size = int(len(dataset) * val_split)
        train_size = len(dataset) - val_size
        if val_size > 0:
            train_dataset, val_dataset = random_split(
                dataset,
                [train_size, val_size],
                generator=torch.Generator().manual_seed(seed)
            )
            log.info(f"✅ 数据集划分: 训练集={train_size}, 验证集={val_size} (验证比例={val_split})")
        else:
            train_dataset, val_dataset = dataset, None
            log.info("✅ 数据集划分: 全部用于训练, 验证集=0 (验证比例=0)")

        train_batch_size = int(cfg.train.batch_size)
        # Keep the last incomplete batch only when the dataset is smaller than
        # one batch; otherwise an empty train loader would result.
        train_drop_last = len(train_dataset) >= train_batch_size
        if not train_drop_last:
            log.warning(
                "⚠️ 训练集样本数 (%s) 小于 batch_size (%s),将保留最后一个不完整批次以避免空训练加载器",
                len(train_dataset),
                train_batch_size,
            )

        train_loader = DataLoader(
            train_dataset,
            batch_size=train_batch_size,
            shuffle=True,
            num_workers=cfg.train.num_workers,
            pin_memory=(cfg.train.device != "cpu"),
            persistent_workers=False,
            drop_last=train_drop_last
        )

        val_loader = None
        if val_dataset is not None:
            val_loader = DataLoader(
                val_dataset,
                batch_size=train_batch_size,
                shuffle=False,
                num_workers=cfg.train.num_workers,
                pin_memory=(cfg.train.device != "cpu"),
                persistent_workers=False,
                drop_last=False
            )

        log.info(f"✅ 训练加载器每轮批次数: {len(train_loader)}")
        if val_loader is not None:
            log.info(f"✅ 验证加载器每轮批次数: {len(val_loader)}")

        # =========================================================================
        # 2. Load dataset statistics (passed to the agent for normalization)
        # =========================================================================
        log.info("💾 加载数据集统计信息...")
        dataset_stats = None
        try:
            dataset_dir = cfg.data.get('dataset_dir', 'roboimi/demos/dataset/sim_transfer')
            stats_path = Path(dataset_dir) / 'dataset_stats.pkl'

            if stats_path.exists():
                with open(stats_path, 'rb') as f:
                    stats = pickle.load(f)

                # Flatten the stats dict (nested -> flat) to match the format
                # expected by the NormalizationModule.
                # NOTE(review): assumes each stats entry has .tolist()
                # (numpy array / tensor) — confirm against the stats writer.
                dataset_stats = {
                    'action_mean': stats['action_mean'].tolist(),
                    'action_std': stats['action_std'].tolist(),
                    'action_min': stats['action_min'].tolist(),
                    'action_max': stats['action_max'].tolist(),
                    'qpos_mean': stats['qpos_mean'].tolist(),
                    'qpos_std': stats['qpos_std'].tolist(),
                    'qpos_min': stats['qpos_min'].tolist(),
                    'qpos_max': stats['qpos_max'].tolist(),
                }
                log.info(f"✅ 数据集统计信息加载完成 (归一化: {cfg.agent.normalization_type})")
            else:
                log.warning(f"⚠️ 统计文件未找到: {stats_path}")
                log.warning("⚠️ 推理时动作将无法反归一化!")

        except Exception as e:
            # Best-effort: training proceeds without stats, inference may break.
            log.warning(f"⚠️ 统计信息加载失败: {e}")
            log.warning("⚠️ 训练将继续,但推理可能无法正常工作")

        # =========================================================================
        # 3. Instantiate the VLA Agent
        # =========================================================================
        log.info("🤖 初始化 VLA Agent...")
        try:
            # Pass dataset_stats (and the configured normalization) to the agent.
            agent = instantiate(cfg.agent, dataset_stats=dataset_stats)
            agent.to(cfg.train.device)
            agent.train()
            log.info(f"✅ Agent 初始化完成并已移至 {cfg.train.device}")

            # Report parameter counts.
            total_params = sum(p.numel() for p in agent.parameters())
            trainable_params = sum(p.numel() for p in agent.parameters() if p.requires_grad)
            log.info(f"📊 总参数量: {total_params:,}")
            log.info(f"📊 可训练参数量: {trainable_params:,}")

        except Exception as e:
            log.error(f"❌ Agent 初始化失败: {e}")
            raise

        # =========================================================================
        # 3.1 Load weights from a pretrained checkpoint (finetuning)
        # =========================================================================
        pretrained_ckpt = cfg.train.get('pretrained_ckpt', None)
        if pretrained_ckpt is not None:
            ckpt_path = Path(pretrained_ckpt)
            if ckpt_path.exists():
                log.info(f"🔄 [Finetune] 从预训练 checkpoint 加载权重: {ckpt_path}")
                try:
                    checkpoint = torch.load(ckpt_path, map_location=cfg.train.device)

                    # Load model weights only (not optimizer/scheduler state).
                    missing_keys, unexpected_keys = agent.load_state_dict(
                        checkpoint['model_state_dict'],
                        strict=False  # allow partial loads when shapes differ
                    )

                    log.info(f"✅ [Finetune] 模型权重加载成功")

                    if missing_keys:
                        log.warning(f"⚠️ [Finetune] 缺少的键 ({len(missing_keys)} 个): {missing_keys[:5]}...")
                    if unexpected_keys:
                        log.warning(f"⚠️ [Finetune] 多余的键 ({len(unexpected_keys)} 个): {unexpected_keys[:5]}...")

                    log.info(f"📊 [Finetune] 预训练信息: 步骤={checkpoint.get('step', 'N/A')}, 损失={checkpoint.get('loss', 'N/A')}")
                    log.info(f"📈 [Finetune] 使用新的训练配置(lr={cfg.train.lr}, max_steps={cfg.train.max_steps})")

                except Exception as e:
                    log.error(f"❌ [Finetune] 加载 checkpoint 失败: {e}")
                    log.warning("⚠️ 将从头开始训练")
            else:
                log.error(f"❌ [Finetune] Checkpoint 文件不存在: {ckpt_path}")
                log.warning("⚠️ 将从头开始训练")

        # =========================================================================
        # 4. Set up the optimizer and learning-rate scheduler
        # =========================================================================
        weight_decay = float(cfg.train.get('weight_decay', 1e-5))
        grad_clip = float(cfg.train.get('grad_clip', 1.0))

        optimizer = build_training_optimizer(agent, lr=cfg.train.lr, weight_decay=weight_decay)
        log.info(f"🔧 优化器: AdamW (学习率={cfg.train.lr}, weight_decay={weight_decay})")

        # Learning-rate scheduler with warmup.
        warmup_steps = int(cfg.train.get('warmup_steps', 500))
        scheduler_type = cfg.train.get('scheduler_type', 'cosine')
        min_lr = float(cfg.train.get('min_lr', 1e-6))

        scheduler = get_lr_schedule_with_warmup(
            optimizer,
            warmup_steps=warmup_steps,
            max_steps=cfg.train.max_steps,
            scheduler_type=scheduler_type,
            min_lr=min_lr
        )
        log.info(f"📈 学习率调度器: {scheduler_type},{warmup_steps} 步预热 (最小学习率={min_lr})")

        # =========================================================================
        # 4.1 Resume training (restore model, optimizer, scheduler, step count)
        # =========================================================================
        def extract_checkpoint_metric_baseline(checkpoint):
            # Derive (best_loss, best_rollout_reward) baselines from a saved
            # checkpoint, preferring val_loss over train loss.
            checkpoint_loss = checkpoint.get('loss', None)
            checkpoint_val_loss = checkpoint.get('val_loss', None)
            checkpoint_rollout_reward = checkpoint.get('rollout_avg_reward', None)

            baseline_loss = float('inf')
            baseline_rollout_reward = float('-inf')
            if checkpoint_rollout_reward is not None:
                baseline_rollout_reward = float(checkpoint_rollout_reward)
            if checkpoint_val_loss is not None:
                baseline_loss = float(checkpoint_val_loss)
            elif checkpoint_loss is not None:
                baseline_loss = float(checkpoint_loss)
            return baseline_loss, baseline_rollout_reward

        start_step = 0
        resume_loss = None
        resume_best_loss = float('inf')
        resume_best_rollout_reward = float('-inf')
        best_model_path = None

        resume_ckpt = cfg.train.get('resume_ckpt', None)
        resume_path = resolve_resume_checkpoint(resume_ckpt, checkpoint_dir)
        if resume_ckpt is not None:
            if pretrained_ckpt is not None:
                log.warning("⚠️ [Resume] 同时设置了 pretrained_ckpt 与 resume_ckpt,将优先使用 resume_ckpt 进行断点续训")
            if resume_path is None:
                log.warning("⚠️ [Resume] 未找到可恢复的 checkpoint,将从头开始训练")
            elif not resume_path.exists():
                log.error(f"❌ [Resume] Checkpoint 文件不存在: {resume_path}")
                log.warning("⚠️ 将从头开始训练")
            else:
                log.info(f"🔄 [Resume] 从 checkpoint 恢复训练: {resume_path}")
                try:
                    checkpoint = torch.load(resume_path, map_location=cfg.train.device)

                    agent.load_state_dict(checkpoint['model_state_dict'], strict=True)
                    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
                    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

                    resume_step = int(checkpoint['step'])
                    start_step = resume_step + 1

                    loaded_loss = checkpoint.get('loss', None)
                    resume_loss = float(loaded_loss) if loaded_loss is not None else None
                    resume_best_loss, resume_best_rollout_reward = extract_checkpoint_metric_baseline(checkpoint)
                    if (
                        resume_best_rollout_reward != float('-inf')
                        or resume_best_loss != float('inf')
                    ):
                        best_model_path = resume_path

                    # Prefer the dedicated best-model checkpoint's rollout
                    # baseline when it exists.
                    if default_best_model_path.exists():
                        try:
                            best_checkpoint = torch.load(default_best_model_path, map_location=cfg.train.device)
                            _, best_checkpoint_rollout_reward = (
                                extract_checkpoint_metric_baseline(best_checkpoint)
                            )
                            if best_checkpoint_rollout_reward != float('-inf'):
                                resume_best_rollout_reward = best_checkpoint_rollout_reward
                                best_model_path = default_best_model_path
                                log.info(
                                    "📈 [Resume] 从最佳 checkpoint 恢复最佳 rollout 基线: %s",
                                    default_best_model_path,
                                )
                        except Exception as e:
                            log.warning(
                                f"⚠️ [Resume] 读取最佳 checkpoint 失败,将回退到恢复 checkpoint 的验证基线: {e}"
                            )

                    log.info(f"✅ [Resume] 恢复成功: 上次步骤={resume_step}, 本次从步骤 {start_step} 开始")
                    log.info(f"📈 [Resume] 当前学习率: {optimizer.param_groups[0]['lr']:.2e}")
                except Exception as e:
                    # On any failure, fall back to a fresh run.
                    log.error(f"❌ [Resume] 恢复失败: {e}")
                    log.warning("⚠️ 将从头开始训练")
                    start_step = 0
                    resume_loss = None
                    resume_best_loss = float('inf')
                    resume_best_rollout_reward = float('-inf')

        # =========================================================================
        # 5. Training loop
        # =========================================================================
        log.info("🏋️ 开始训练循环...")

        def build_agent_input(batch_data):
            """Assemble a dataset batch into the agent's expected input dict."""
            images = {}
            # SimpleRobotDataset returns keys in "observation.{cam_name}" format.
            for cam_name in cfg.data.camera_names:
                key = f"observation.{cam_name}"
                if key in batch_data:
                    images[cam_name] = batch_data[key]

            return {
                'images': images,
                'qpos': batch_data['observation.state'],  # SimpleRobotDataset uses observation.state
                'action': batch_data['action'],
                'action_is_pad': batch_data.get('action_is_pad', None)  # forward the padding mask
            }

        def save_checkpoint(checkpoint_path: Path, step: int, loss_value, val_loss=None, rollout_avg_reward=None):
            """Save model/optimizer/scheduler state plus metrics to *checkpoint_path*."""
            agent_stats = agent.get_normalization_stats()
            torch.save({
                'step': step,
                'model_state_dict': agent.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'loss': loss_value,
                'val_loss': val_loss,
                'rollout_avg_reward': rollout_avg_reward,
                'dataset_stats': agent_stats,  # persist the agent's normalization stats
                'current_lr': optimizer.param_groups[0]['lr'],
            }, checkpoint_path)
            return checkpoint_path

        def run_validation():
            """Run a full pass over the validation loader; returns mean loss or None."""
            if val_loader is None:
                return None
            agent.eval()

            # Fix the RNG so validation losses are comparable across steps
            # (the diffusion loss samples noise internally).
            torch.manual_seed(42)
            if torch.cuda.is_available():
                torch.cuda.manual_seed(42)

            total_loss = 0.0
            num_batches = 0
            with torch.no_grad():
                for val_batch in val_loader:
                    val_batch = recursive_to_device(val_batch, cfg.train.device)
                    val_input = build_agent_input(val_batch)
                    val_loss = agent.compute_loss(val_input)
                    total_loss += val_loss.item()
                    num_batches += 1
            agent.train()
            return total_loss / max(num_batches, 1)

        def run_rollout_validation(checkpoint_path: Path):
            """Evaluate a saved checkpoint in the simulator (headless, on CPU)."""
            from roboimi.demos.vla_scripts import eval_vla

            # Clone the config unresolved so eval overrides don't leak back.
            rollout_cfg = OmegaConf.create(OmegaConf.to_container(cfg, resolve=False))
            rollout_cfg.eval.ckpt_path = str(checkpoint_path)
            rollout_cfg.eval.num_episodes = int(cfg.train.get('rollout_num_episodes', 1))
            rollout_cfg.eval.headless = True
            rollout_cfg.eval.device = 'cpu'
            rollout_cfg.eval.verbose_action = False

            log.info(
                "🎯 开始 checkpoint rollout 验证: %s (episodes=%s, headless=True)",
                checkpoint_path,
                rollout_cfg.eval.num_episodes,
            )
            return eval_vla._run_eval(rollout_cfg)

        def run_checkpoint_rollout_validation(checkpoint_path: Path):
            """Rollout-validate a checkpoint only when enabled in the config."""
            if not bool(cfg.train.get('rollout_validate_on_checkpoint', False)):
                return None
            return run_rollout_validation(checkpoint_path)

        data_iter = iter(train_loader)
        pbar = tqdm(range(start_step, cfg.train.max_steps), desc="训练中", ncols=100)

        steps_per_epoch = len(train_loader)
        rollout_val_freq_epochs = int(cfg.train.get('rollout_val_freq_epochs', 0) or 0)
        rollout_validation_enabled = rollout_val_freq_epochs > 0
        best_loss = resume_best_loss
        best_rollout_reward = resume_best_rollout_reward
        last_loss = resume_loss

        if start_step >= cfg.train.max_steps:
            log.warning(
                f"⚠️ [Resume] start_step={start_step} 已达到/超过 max_steps={cfg.train.max_steps},跳过训练循环"
            )

        for step in pbar:
            try:
                batch = next(data_iter)
            except StopIteration:
                # Restart the iterator at the end of an epoch.
                data_iter = iter(train_loader)
                batch = next(data_iter)

            # =====================================================================
            # Move the batch to the device
            # =====================================================================
            batch = recursive_to_device(batch, cfg.train.device)

            # =====================================================================
            # Prepare the agent input
            # =====================================================================
            # Dataset yields: {action, qpos, image_<cam_name>, ...}
            # Agent expects: {images: dict, qpos: tensor, action: tensor}

            agent_input = build_agent_input(batch)

            # =====================================================================
            # Forward pass and loss computation
            # =====================================================================
            try:
                loss = agent.compute_loss(agent_input)
            except Exception as e:
                log.error(f"❌ 步骤 {step} 前向传播失败: {e}")
                raise

            last_loss = loss.item()

            # =====================================================================
            # Backward pass and optimization
            # =====================================================================
            optimizer.zero_grad()
            loss.backward()

            # Clip gradients to stabilize training.
            torch.nn.utils.clip_grad_norm_(agent.parameters(), max_norm=grad_clip)

            optimizer.step()
            scheduler.step()

            # =====================================================================
            # Logging
            # =====================================================================
            if step % cfg.train.log_freq == 0:
                current_lr = optimizer.param_groups[0]['lr']
                best_loss_to_log = best_loss if best_loss != float('inf') else loss.item()
                pbar.set_postfix({
                    "loss": f"{loss.item():.4f}",
                    "lr": f"{current_lr:.2e}",
                    "best_loss": f"{best_loss_to_log:.4f}"
                })
                log.info(f"步骤 {step}/{cfg.train.max_steps} | 损失: {loss.item():.4f} | 学习率: {current_lr:.2e}")
                _log_to_swanlab(
                    swanlab_module,
                    {
                        'train/loss': loss.item(),
                        'train/lr': current_lr,
                        'train/best_loss': best_loss_to_log,
                        'train/step': step,
                    },
                    step=step,
                )

            # =====================================================================
            # Checkpoint saving and validation
            # =====================================================================
            checkpoint_path = None
            val_loss = None
            if step > 0 and step % cfg.train.save_freq == 0:
                # Run validation first so its loss can be stored in the checkpoint.
                val_loss = run_validation()
                if val_loss is not None:
                    log.info(f"步骤 {step}/{cfg.train.max_steps} | 验证损失: {val_loss:.4f}")
                    _log_to_swanlab(
                        swanlab_module,
                        {'val/loss': val_loss},
                        step=step,
                    )

                checkpoint_path = checkpoint_dir / f"vla_model_step_{step}.pt"
                save_checkpoint(
                    checkpoint_path,
                    step,
                    loss.item(),
                    val_loss=val_loss,
                )
                log.info(f"💾 检查点已保存: {checkpoint_path}")

                # Until a rollout reward exists, fall back to loss as the
                # best-model selection metric.
                if best_rollout_reward == float('-inf'):
                    eval_loss = val_loss if val_loss is not None else loss.item()
                    if eval_loss < best_loss:
                        best_loss = eval_loss
                        best_model_path = default_best_model_path
                        save_checkpoint(
                            best_model_path,
                            step,
                            loss.item(),
                            val_loss=val_loss,
                        )
                        log.info(f"🌟 最佳模型已更新: {best_model_path} (验证损失: {best_loss:.4f})")

                # Optional per-checkpoint rollout validation.
                checkpoint_rollout_stats = run_checkpoint_rollout_validation(checkpoint_path)
                checkpoint_rollout_avg_reward = (
                    checkpoint_rollout_stats.get('avg_reward')
                    if checkpoint_rollout_stats is not None else None
                )
                if checkpoint_rollout_avg_reward is not None:
                    log.info(
                        f"步骤 {step}/{cfg.train.max_steps} | checkpoint rollout 平均奖励: "
                        f"{checkpoint_rollout_avg_reward:.4f}"
                    )
                    _log_to_swanlab(
                        swanlab_module,
                        {'rollout/avg_reward': checkpoint_rollout_avg_reward},
                        step=step,
                    )
                    if checkpoint_rollout_avg_reward > best_rollout_reward:
                        best_rollout_reward = checkpoint_rollout_avg_reward
                        best_model_path = default_best_model_path
                        save_checkpoint(
                            best_model_path,
                            step,
                            loss.item(),
                            val_loss=val_loss,
                            rollout_avg_reward=checkpoint_rollout_avg_reward,
                        )
                        log.info(
                            f"🌟 最佳模型已更新: {best_model_path} "
                            f"(checkpoint rollout 平均奖励: {best_rollout_reward:.4f})"
                        )

            # Epoch-boundary rollout validation (every rollout_val_freq_epochs).
            completed_steps = step + 1
            completed_epoch = (
                completed_steps // steps_per_epoch
                if steps_per_epoch > 0 else 0
            )
            should_run_epoch_rollout = (
                rollout_validation_enabled
                and steps_per_epoch > 0
                and completed_steps % steps_per_epoch == 0
                and completed_epoch > 0
                and completed_epoch % rollout_val_freq_epochs == 0
            )
            if should_run_epoch_rollout:
                # Reuse the save-freq checkpoint when one was just written.
                if checkpoint_path is None:
                    checkpoint_path = checkpoint_dir / f"vla_model_step_{step}.pt"
                    save_checkpoint(
                        checkpoint_path,
                        step,
                        loss.item(),
                        val_loss=val_loss,
                    )
                    log.info(f"💾 Epoch rollout 验证前检查点已保存: {checkpoint_path}")

                rollout_stats = run_rollout_validation(checkpoint_path)
                rollout_avg_reward = (
                    rollout_stats.get('avg_reward')
                    if rollout_stats is not None else None
                )
                if rollout_avg_reward is not None:
                    log.info(
                        f"步骤 {step}/{cfg.train.max_steps} | Epoch {completed_epoch} "
                        f"rollout 平均奖励: {rollout_avg_reward:.4f}"
                    )
                    _log_to_swanlab(
                        swanlab_module,
                        {
                            'rollout/avg_reward': rollout_avg_reward,
                            'rollout/epoch': completed_epoch,
                        },
                        step=step,
                    )
                    if rollout_avg_reward > best_rollout_reward:
                        best_rollout_reward = rollout_avg_reward
                        best_model_path = default_best_model_path
                        save_checkpoint(
                            best_model_path,
                            step,
                            loss.item(),
                            val_loss=val_loss,
                            rollout_avg_reward=rollout_avg_reward,
                        )
                        log.info(
                            f"🌟 最佳模型已更新: {best_model_path} "
                            f"(Epoch {completed_epoch} rollout 平均奖励: {best_rollout_reward:.4f})"
                        )

        # =========================================================================
        # 6. Save the final model
        # =========================================================================
        final_model_path = checkpoint_dir / "vla_model_final.pt"
        save_checkpoint(
            final_model_path,
            cfg.train.max_steps,
            last_loss,
        )
        log.info(f"💾 最终模型已保存: {final_model_path}")
        _log_to_swanlab(
            swanlab_module,
            {
                'final/checkpoint_path': str(final_model_path),
                'final/best_checkpoint_path': (
                    str(best_model_path) if best_model_path is not None else ''
                ),
            },
            step=cfg.train.max_steps,
        )

        log.info("✅ 训练成功完成!")
        if last_loss is not None:
            log.info(f"📊 最终损失: {last_loss:.4f}")
        else:
            log.info("📊 最终损失: N/A(未执行训练步)")
        if best_rollout_reward != float('-inf'):
            log.info(f"📊 最佳 rollout 平均奖励: {best_rollout_reward:.4f}")
        elif best_loss != float('inf'):
            log.info(f"📊 最佳损失: {best_loss:.4f}")
        else:
            log.info("📊 最佳验证指标: N/A(无有效 rollout/验证损失)")
    finally:
        # Always close the SwanLab run, even on failure.
        _finish_swanlab(swanlab_module)
|
||
|
||
|
||
@hydra.main(version_base=None, config_path="../../vla/conf", config_name="config")
def main(cfg: DictConfig):
    """Hydra CLI entry point: composes the config and runs training."""
    _run_training(cfg)


if __name__ == "__main__":
    main()
|