feat: 添加保存模型的功能和推理脚本

This commit is contained in:
gouhanke
2026-02-03 18:03:47 +08:00
parent f5e2eca809
commit 3465782256
3 changed files with 112 additions and 1 deletions

View File

@@ -0,0 +1,100 @@
import sys
import os
import hydra
import torch
import matplotlib.pyplot as plt
import numpy as np
from omegaconf import DictConfig, OmegaConf
from hydra.utils import instantiate
from torch.utils.data import DataLoader
# 确保能导入 roboimi
sys.path.append(os.getcwd())
from roboimi.vla.agent import VLAAgent
def recursive_to_device(data, device):
    """Recursively move all tensors inside a nested container onto *device*.

    Args:
        data: A ``torch.Tensor``, or a dict/list/tuple arbitrarily nesting
            tensors. Non-tensor leaves are returned unchanged.
        device: Target device spec (e.g. ``"cpu"``, ``"cuda:0"``).

    Returns:
        The same structure as *data* with every tensor moved to *device*.
    """
    if isinstance(data, torch.Tensor):
        return data.to(device)
    if isinstance(data, dict):
        return {k: recursive_to_device(v, device) for k, v in data.items()}
    # Generalization: also descend into lists/tuples, which DataLoader batches
    # commonly contain; previously such tensors were silently left in place.
    # (Plain list/tuple only — a NamedTuple subclass would need field-wise construction.)
    if isinstance(data, (list, tuple)):
        return type(data)(recursive_to_device(v, device) for v in data)
    return data
@hydra.main(version_base=None, config_path="../../../roboimi/vla/conf", config_name="config")
def main(cfg: DictConfig):
    """Run one inference pass of the trained VLA agent and plot the result.

    Loads the checkpoint written by the training script, draws a single
    sample from the configured dataset, runs reverse diffusion, and saves
    a predicted-vs-ground-truth comparison plot.
    """
    print(">>> 🤖 Starting VLA Inference...")
    device = cfg.train.device
    # 1. Instantiate the agent (architecture must exactly match training).
    #    Config overrides (e.g. forcing freeze=True) could be applied here.
    agent: VLAAgent = instantiate(cfg.agent)
    agent.to(device)
    agent.eval()  # Key step: switch to eval mode
    # 2. Load the trained weights.
    # NOTE(review): path is hard-coded; it must match the training script's save path.
    ckpt_path = "checkpoints/vla_model_final.pt"
    if not os.path.exists(ckpt_path):
        print(f"❌ Checkpoint not found at {ckpt_path}. Run training first!")
        return
    print(f"Loading weights from {ckpt_path}...")
    # map_location=device remaps GPU-saved weights onto the configured device,
    # so loading works on CPU-only machines when cfg.train.device is "cpu".
    state_dict = torch.load(ckpt_path, map_location=device)
    agent.load_state_dict(state_dict)
    print("✅ Weights loaded successfully.")
    # 3. Prepare test data (take one sample from the dataset).
    dataset = instantiate(cfg.data)
    dataloader = DataLoader(dataset, batch_size=1, shuffle=True)
    sample = next(iter(dataloader))
    # Build the model input (simulating real-time robot operation).
    # Note: sample['actions'] is deliberately NOT passed at inference time.
    primary_cam_key = cfg.data.obs_keys[0]
    input_img = sample['obs'][primary_cam_key][:, -1, :, :, :]  # (1, C, H, W) — last frame of the obs window
    agent_input = {
        "obs": {
            "image": input_img.to(device),
            "text": sample["language"]  # placeholder kept even when text is unused
        }
        # ⚠️ Key point: omitting 'actions' routes the agent into its inference branch.
    }
    # 4. Run inference (reverse diffusion).
    print("running reverse diffusion (this may take a moment)...")
    with torch.no_grad():
        # Triggers the DiffusionHead's sampling branch (loop over timesteps).
        outputs = agent(agent_input)
    # 5. Collect results.
    # Output shape: (1, Chunk_Size, Action_Dim)
    pred_actions = outputs['pred_actions'].cpu().numpy()[0]
    gt_actions = sample['actions'][0].numpy()  # ground truth, used only for comparison
    print(f"✅ Generated Action Chunk Shape: {pred_actions.shape}")
    # 6. Visualize the comparison (saved as an image file).
    plot_results(pred_actions, gt_actions)
def plot_results(pred, gt, dims_to_plot=3, save_path="inference_result.png"):
    """Plot predicted vs. ground-truth action trajectories and save to disk.

    Args:
        pred: (T, D) array of predicted actions.
        gt: (T, D) array of ground-truth actions.
        dims_to_plot: Maximum number of leading action dimensions to plot
            (default 3, e.g. x/y/z). Clamped to the dimensions available.
        save_path: Output image path (default matches the original behavior).
    """
    # Fix: clamp instead of crashing when the action space has fewer than
    # `dims_to_plot` dimensions (the original always indexed dims 0..2).
    dims_to_plot = min(dims_to_plot, pred.shape[1], gt.shape[1])
    if dims_to_plot == 0:
        print("⚠️ Nothing to plot: zero action dimensions.")
        return
    fig = plt.figure(figsize=(10, 5))
    for i in range(dims_to_plot):
        plt.subplot(1, dims_to_plot, i + 1)
        plt.plot(gt[:, i], 'g--', label='Ground Truth')
        plt.plot(pred[:, i], 'b-', label='Diffusion Pred')
        plt.title(f"Action Dim {i}")
        if i == 0:
            plt.legend()
        plt.ylim(-1, 1)  # assumes actions are normalized to [-1, 1] — TODO confirm
    plt.tight_layout()
    plt.savefig(save_path)
    # Fix: close the figure so repeated calls don't accumulate open figures.
    plt.close(fig)
    print(f"📊 Result plot saved to '{save_path}'")
# Script entry point: hydra parses CLI overrides and injects the composed config.
if __name__ == "__main__":
    main()

View File

@@ -95,6 +95,17 @@ def main(cfg: DictConfig):
log.info("✅ Training Loop with Real HDF5 Finished!")
# --- 6. Save Checkpoint ---
save_dir = "checkpoints"
os.makedirs(save_dir, exist_ok=True)
save_path = os.path.join(save_dir, "vla_model_final.pt")
# 保存整个 Agent 的 state_dict
torch.save(agent.state_dict(), save_path)
log.info(f"💾 Model saved to {save_path}")
log.info("✅ Training Loop Finished!")
def recursive_to_device(data, device):
    if isinstance(data, torch.Tensor):
        return data.to(device)

View File

@@ -6,7 +6,7 @@ defaults:
train:
  batch_size: 4 # 减小 batch size 方便调试
  lr: 1e-4
- max_steps: 100
+ max_steps: 10
  log_freq: 10
  device: "cpu"
  num_workers: 0 # 调试设为0验证通过后改为 2 或 4