debug: 保存stats到ckpt
This commit is contained in:
@@ -346,6 +346,8 @@ def main(cfg: DictConfig):
|
||||
log.info(f"步骤 {step}/{cfg.train.max_steps} | 验证损失: {val_loss:.4f}")
|
||||
|
||||
checkpoint_path = checkpoint_dir / f"vla_model_step_{step}.pt"
|
||||
# 使用agent的归一化统计信息(包含normalization_type)
|
||||
agent_stats = agent.get_normalization_stats()
|
||||
torch.save({
|
||||
'step': step,
|
||||
'model_state_dict': agent.state_dict(),
|
||||
@@ -353,7 +355,7 @@ def main(cfg: DictConfig):
|
||||
'scheduler_state_dict': scheduler.state_dict(),
|
||||
'loss': loss.item(),
|
||||
'val_loss': val_loss,
|
||||
'dataset_stats': dataset_stats,
|
||||
'dataset_stats': agent_stats, # 保存agent的统计信息
|
||||
'current_lr': optimizer.param_groups[0]['lr'],
|
||||
}, checkpoint_path)
|
||||
log.info(f"💾 检查点已保存: {checkpoint_path}")
|
||||
@@ -363,6 +365,7 @@ def main(cfg: DictConfig):
|
||||
if eval_loss < best_loss:
|
||||
best_loss = eval_loss
|
||||
best_model_path = checkpoint_dir / "vla_model_best.pt"
|
||||
agent_stats = agent.get_normalization_stats()
|
||||
torch.save({
|
||||
'step': step,
|
||||
'model_state_dict': agent.state_dict(),
|
||||
@@ -370,7 +373,7 @@ def main(cfg: DictConfig):
|
||||
'scheduler_state_dict': scheduler.state_dict(),
|
||||
'loss': loss.item(),
|
||||
'val_loss': val_loss,
|
||||
'dataset_stats': dataset_stats,
|
||||
'dataset_stats': agent_stats, # 保存agent的统计信息
|
||||
'current_lr': optimizer.param_groups[0]['lr'],
|
||||
}, best_model_path)
|
||||
log.info(f"🌟 最佳模型已更新: {best_model_path} (验证损失: {best_loss:.4f})")
|
||||
@@ -379,13 +382,14 @@ def main(cfg: DictConfig):
|
||||
# 6. 保存最终模型
|
||||
# =========================================================================
|
||||
final_model_path = checkpoint_dir / "vla_model_final.pt"
|
||||
agent_stats = agent.get_normalization_stats()
|
||||
torch.save({
|
||||
'step': cfg.train.max_steps,
|
||||
'model_state_dict': agent.state_dict(),
|
||||
'optimizer_state_dict': optimizer.state_dict(),
|
||||
'scheduler_state_dict': scheduler.state_dict(),
|
||||
'loss': loss.item(),
|
||||
'dataset_stats': dataset_stats,
|
||||
'dataset_stats': agent_stats, # 保存agent的统计信息
|
||||
'current_lr': optimizer.param_groups[0]['lr'],
|
||||
}, final_model_path)
|
||||
log.info(f"💾 最终模型已保存: {final_model_path}")
|
||||
|
||||
@@ -240,7 +240,12 @@ class VLAAgent(nn.Module):
|
||||
|
||||
if device is not None and self.normalization.enabled:
|
||||
# 确保归一化参数在同一设备上
|
||||
# 根据归一化类型获取正确的属性
|
||||
if self.normalization.normalization_type == 'gaussian':
|
||||
norm_device = self.normalization.qpos_mean.device
|
||||
else: # min_max
|
||||
norm_device = self.normalization.qpos_min.device
|
||||
|
||||
if device != norm_device:
|
||||
self.normalization.to(device)
|
||||
# 同时确保其他模块也在正确设备
|
||||
|
||||
Reference in New Issue
Block a user