debug: 保存stats到ckpt

This commit is contained in:
gouhanke
2026-02-12 13:00:43 +08:00
parent ab971b3f96
commit 37a47ac2dd
2 changed files with 13 additions and 4 deletions

View File

@@ -346,6 +346,8 @@ def main(cfg: DictConfig):
log.info(f"步骤 {step}/{cfg.train.max_steps} | 验证损失: {val_loss:.4f}")
checkpoint_path = checkpoint_dir / f"vla_model_step_{step}.pt"
# 使用agent的归一化统计信息包含normalization_type
agent_stats = agent.get_normalization_stats()
torch.save({
'step': step,
'model_state_dict': agent.state_dict(),
@@ -353,7 +355,7 @@ def main(cfg: DictConfig):
'scheduler_state_dict': scheduler.state_dict(),
'loss': loss.item(),
'val_loss': val_loss,
'dataset_stats': dataset_stats,
'dataset_stats': agent_stats, # 保存agent的统计信息
'current_lr': optimizer.param_groups[0]['lr'],
}, checkpoint_path)
log.info(f"💾 检查点已保存: {checkpoint_path}")
@@ -363,6 +365,7 @@ def main(cfg: DictConfig):
if eval_loss < best_loss:
best_loss = eval_loss
best_model_path = checkpoint_dir / "vla_model_best.pt"
agent_stats = agent.get_normalization_stats()
torch.save({
'step': step,
'model_state_dict': agent.state_dict(),
@@ -370,7 +373,7 @@ def main(cfg: DictConfig):
'scheduler_state_dict': scheduler.state_dict(),
'loss': loss.item(),
'val_loss': val_loss,
'dataset_stats': dataset_stats,
'dataset_stats': agent_stats, # 保存agent的统计信息
'current_lr': optimizer.param_groups[0]['lr'],
}, best_model_path)
log.info(f"🌟 最佳模型已更新: {best_model_path} (验证损失: {best_loss:.4f})")
@@ -379,13 +382,14 @@ def main(cfg: DictConfig):
# 6. 保存最终模型
# =========================================================================
final_model_path = checkpoint_dir / "vla_model_final.pt"
agent_stats = agent.get_normalization_stats()
torch.save({
'step': cfg.train.max_steps,
'model_state_dict': agent.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'scheduler_state_dict': scheduler.state_dict(),
'loss': loss.item(),
'dataset_stats': dataset_stats,
'dataset_stats': agent_stats, # 保存agent的统计信息
'current_lr': optimizer.param_groups[0]['lr'],
}, final_model_path)
log.info(f"💾 最终模型已保存: {final_model_path}")

View File

@@ -240,7 +240,12 @@ class VLAAgent(nn.Module):
if device is not None and self.normalization.enabled:
# 确保归一化参数在同一设备上
# 根据归一化类型获取正确的属性
if self.normalization.normalization_type == 'gaussian':
norm_device = self.normalization.qpos_mean.device
else: # min_max
norm_device = self.normalization.qpos_min.device
if device != norm_device:
self.normalization.to(device)
# 同时确保其他模块也在正确设备