debug: 保存stats到ckpt
This commit is contained in:
@@ -346,6 +346,8 @@ def main(cfg: DictConfig):
|
|||||||
log.info(f"步骤 {step}/{cfg.train.max_steps} | 验证损失: {val_loss:.4f}")
|
log.info(f"步骤 {step}/{cfg.train.max_steps} | 验证损失: {val_loss:.4f}")
|
||||||
|
|
||||||
checkpoint_path = checkpoint_dir / f"vla_model_step_{step}.pt"
|
checkpoint_path = checkpoint_dir / f"vla_model_step_{step}.pt"
|
||||||
|
# 使用agent的归一化统计信息(包含normalization_type)
|
||||||
|
agent_stats = agent.get_normalization_stats()
|
||||||
torch.save({
|
torch.save({
|
||||||
'step': step,
|
'step': step,
|
||||||
'model_state_dict': agent.state_dict(),
|
'model_state_dict': agent.state_dict(),
|
||||||
@@ -353,7 +355,7 @@ def main(cfg: DictConfig):
|
|||||||
'scheduler_state_dict': scheduler.state_dict(),
|
'scheduler_state_dict': scheduler.state_dict(),
|
||||||
'loss': loss.item(),
|
'loss': loss.item(),
|
||||||
'val_loss': val_loss,
|
'val_loss': val_loss,
|
||||||
'dataset_stats': dataset_stats,
|
'dataset_stats': agent_stats, # 保存agent的统计信息
|
||||||
'current_lr': optimizer.param_groups[0]['lr'],
|
'current_lr': optimizer.param_groups[0]['lr'],
|
||||||
}, checkpoint_path)
|
}, checkpoint_path)
|
||||||
log.info(f"💾 检查点已保存: {checkpoint_path}")
|
log.info(f"💾 检查点已保存: {checkpoint_path}")
|
||||||
@@ -363,6 +365,7 @@ def main(cfg: DictConfig):
|
|||||||
if eval_loss < best_loss:
|
if eval_loss < best_loss:
|
||||||
best_loss = eval_loss
|
best_loss = eval_loss
|
||||||
best_model_path = checkpoint_dir / "vla_model_best.pt"
|
best_model_path = checkpoint_dir / "vla_model_best.pt"
|
||||||
|
agent_stats = agent.get_normalization_stats()
|
||||||
torch.save({
|
torch.save({
|
||||||
'step': step,
|
'step': step,
|
||||||
'model_state_dict': agent.state_dict(),
|
'model_state_dict': agent.state_dict(),
|
||||||
@@ -370,7 +373,7 @@ def main(cfg: DictConfig):
|
|||||||
'scheduler_state_dict': scheduler.state_dict(),
|
'scheduler_state_dict': scheduler.state_dict(),
|
||||||
'loss': loss.item(),
|
'loss': loss.item(),
|
||||||
'val_loss': val_loss,
|
'val_loss': val_loss,
|
||||||
'dataset_stats': dataset_stats,
|
'dataset_stats': agent_stats, # 保存agent的统计信息
|
||||||
'current_lr': optimizer.param_groups[0]['lr'],
|
'current_lr': optimizer.param_groups[0]['lr'],
|
||||||
}, best_model_path)
|
}, best_model_path)
|
||||||
log.info(f"🌟 最佳模型已更新: {best_model_path} (验证损失: {best_loss:.4f})")
|
log.info(f"🌟 最佳模型已更新: {best_model_path} (验证损失: {best_loss:.4f})")
|
||||||
@@ -379,13 +382,14 @@ def main(cfg: DictConfig):
|
|||||||
# 6. 保存最终模型
|
# 6. 保存最终模型
|
||||||
# =========================================================================
|
# =========================================================================
|
||||||
final_model_path = checkpoint_dir / "vla_model_final.pt"
|
final_model_path = checkpoint_dir / "vla_model_final.pt"
|
||||||
|
agent_stats = agent.get_normalization_stats()
|
||||||
torch.save({
|
torch.save({
|
||||||
'step': cfg.train.max_steps,
|
'step': cfg.train.max_steps,
|
||||||
'model_state_dict': agent.state_dict(),
|
'model_state_dict': agent.state_dict(),
|
||||||
'optimizer_state_dict': optimizer.state_dict(),
|
'optimizer_state_dict': optimizer.state_dict(),
|
||||||
'scheduler_state_dict': scheduler.state_dict(),
|
'scheduler_state_dict': scheduler.state_dict(),
|
||||||
'loss': loss.item(),
|
'loss': loss.item(),
|
||||||
'dataset_stats': dataset_stats,
|
'dataset_stats': agent_stats, # 保存agent的统计信息
|
||||||
'current_lr': optimizer.param_groups[0]['lr'],
|
'current_lr': optimizer.param_groups[0]['lr'],
|
||||||
}, final_model_path)
|
}, final_model_path)
|
||||||
log.info(f"💾 最终模型已保存: {final_model_path}")
|
log.info(f"💾 最终模型已保存: {final_model_path}")
|
||||||
|
|||||||
@@ -240,7 +240,12 @@ class VLAAgent(nn.Module):
|
|||||||
|
|
||||||
if device is not None and self.normalization.enabled:
|
if device is not None and self.normalization.enabled:
|
||||||
# 确保归一化参数在同一设备上
|
# 确保归一化参数在同一设备上
|
||||||
norm_device = self.normalization.qpos_mean.device
|
# 根据归一化类型获取正确的属性
|
||||||
|
if self.normalization.normalization_type == 'gaussian':
|
||||||
|
norm_device = self.normalization.qpos_mean.device
|
||||||
|
else: # min_max
|
||||||
|
norm_device = self.normalization.qpos_min.device
|
||||||
|
|
||||||
if device != norm_device:
|
if device != norm_device:
|
||||||
self.normalization.to(device)
|
self.normalization.to(device)
|
||||||
# 同时确保其他模块也在正确设备
|
# 同时确保其他模块也在正确设备
|
||||||
|
|||||||
Reference in New Issue
Block a user