debug: 保存stats到ckpt

This commit is contained in:
gouhanke
2026-02-12 13:00:43 +08:00
parent ab971b3f96
commit 37a47ac2dd
2 changed files with 13 additions and 4 deletions

View File

@@ -346,6 +346,8 @@ def main(cfg: DictConfig):
log.info(f"步骤 {step}/{cfg.train.max_steps} | 验证损失: {val_loss:.4f}") log.info(f"步骤 {step}/{cfg.train.max_steps} | 验证损失: {val_loss:.4f}")
checkpoint_path = checkpoint_dir / f"vla_model_step_{step}.pt" checkpoint_path = checkpoint_dir / f"vla_model_step_{step}.pt"
# 使用agent的归一化统计信息包含normalization_type
agent_stats = agent.get_normalization_stats()
torch.save({ torch.save({
'step': step, 'step': step,
'model_state_dict': agent.state_dict(), 'model_state_dict': agent.state_dict(),
@@ -353,7 +355,7 @@ def main(cfg: DictConfig):
'scheduler_state_dict': scheduler.state_dict(), 'scheduler_state_dict': scheduler.state_dict(),
'loss': loss.item(), 'loss': loss.item(),
'val_loss': val_loss, 'val_loss': val_loss,
'dataset_stats': dataset_stats, 'dataset_stats': agent_stats, # 保存agent的统计信息
'current_lr': optimizer.param_groups[0]['lr'], 'current_lr': optimizer.param_groups[0]['lr'],
}, checkpoint_path) }, checkpoint_path)
log.info(f"💾 检查点已保存: {checkpoint_path}") log.info(f"💾 检查点已保存: {checkpoint_path}")
@@ -363,6 +365,7 @@ def main(cfg: DictConfig):
if eval_loss < best_loss: if eval_loss < best_loss:
best_loss = eval_loss best_loss = eval_loss
best_model_path = checkpoint_dir / "vla_model_best.pt" best_model_path = checkpoint_dir / "vla_model_best.pt"
agent_stats = agent.get_normalization_stats()
torch.save({ torch.save({
'step': step, 'step': step,
'model_state_dict': agent.state_dict(), 'model_state_dict': agent.state_dict(),
@@ -370,7 +373,7 @@ def main(cfg: DictConfig):
'scheduler_state_dict': scheduler.state_dict(), 'scheduler_state_dict': scheduler.state_dict(),
'loss': loss.item(), 'loss': loss.item(),
'val_loss': val_loss, 'val_loss': val_loss,
'dataset_stats': dataset_stats, 'dataset_stats': agent_stats, # 保存agent的统计信息
'current_lr': optimizer.param_groups[0]['lr'], 'current_lr': optimizer.param_groups[0]['lr'],
}, best_model_path) }, best_model_path)
log.info(f"🌟 最佳模型已更新: {best_model_path} (验证损失: {best_loss:.4f})") log.info(f"🌟 最佳模型已更新: {best_model_path} (验证损失: {best_loss:.4f})")
@@ -379,13 +382,14 @@ def main(cfg: DictConfig):
# 6. 保存最终模型 # 6. 保存最终模型
# ========================================================================= # =========================================================================
final_model_path = checkpoint_dir / "vla_model_final.pt" final_model_path = checkpoint_dir / "vla_model_final.pt"
agent_stats = agent.get_normalization_stats()
torch.save({ torch.save({
'step': cfg.train.max_steps, 'step': cfg.train.max_steps,
'model_state_dict': agent.state_dict(), 'model_state_dict': agent.state_dict(),
'optimizer_state_dict': optimizer.state_dict(), 'optimizer_state_dict': optimizer.state_dict(),
'scheduler_state_dict': scheduler.state_dict(), 'scheduler_state_dict': scheduler.state_dict(),
'loss': loss.item(), 'loss': loss.item(),
'dataset_stats': dataset_stats, 'dataset_stats': agent_stats, # 保存agent的统计信息
'current_lr': optimizer.param_groups[0]['lr'], 'current_lr': optimizer.param_groups[0]['lr'],
}, final_model_path) }, final_model_path)
log.info(f"💾 最终模型已保存: {final_model_path}") log.info(f"💾 最终模型已保存: {final_model_path}")

View File

@@ -240,7 +240,12 @@ class VLAAgent(nn.Module):
if device is not None and self.normalization.enabled: if device is not None and self.normalization.enabled:
# 确保归一化参数在同一设备上 # 确保归一化参数在同一设备上
norm_device = self.normalization.qpos_mean.device # 根据归一化类型获取正确的属性
if self.normalization.normalization_type == 'gaussian':
norm_device = self.normalization.qpos_mean.device
else: # min_max
norm_device = self.normalization.qpos_min.device
if device != norm_device: if device != norm_device:
self.normalization.to(device) self.normalization.to(device)
# 同时确保其他模块也在正确设备 # 同时确保其他模块也在正确设备