feat: 添加保存模型的功能和推理脚本
This commit is contained in:
100
roboimi/demos/vla_scripts/eval_vla.py
Normal file
100
roboimi/demos/vla_scripts/eval_vla.py
Normal file
@@ -0,0 +1,100 @@
|
|||||||
|
import sys
|
||||||
|
import os
|
||||||
|
import hydra
|
||||||
|
import torch
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
|
from omegaconf import DictConfig, OmegaConf
|
||||||
|
from hydra.utils import instantiate
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
# 确保能导入 roboimi
|
||||||
|
sys.path.append(os.getcwd())
|
||||||
|
from roboimi.vla.agent import VLAAgent
|
||||||
|
|
||||||
|
def recursive_to_device(data, device):
    """Move every tensor found in a nested container onto *device*.

    Dicts, lists and tuples (including namedtuples) are walked recursively
    and rebuilt with the same container kind. Any other value (strings,
    numbers, ``None``, ...) is returned unchanged.

    Args:
        data: A ``torch.Tensor``, or an arbitrarily nested dict / list /
            tuple that may contain tensors.
        device: Target forwarded as-is to ``Tensor.to``.

    Returns:
        A structure mirroring ``data`` in which each tensor has been passed
        through ``Tensor.to(device)``.
    """
    if isinstance(data, torch.Tensor):
        return data.to(device)
    if isinstance(data, dict):
        return {key: recursive_to_device(value, device) for key, value in data.items()}
    if isinstance(data, (list, tuple)):
        # Sequences used to fall through untouched, leaving any tensors
        # inside them on their original device.
        moved = [recursive_to_device(item, device) for item in data]
        if isinstance(data, list):
            return moved
        if hasattr(data, "_fields"):
            # namedtuple constructors take positional fields, not an iterable.
            return type(data)(*moved)
        return tuple(moved)
    return data
|
||||||
|
|
||||||
|
@hydra.main(version_base=None, config_path="../../../roboimi/vla/conf", config_name="config")
def main(cfg: DictConfig):
    """Run a single inference pass of the VLA agent and plot it against ground truth.

    Steps: instantiate the agent from ``cfg.agent``, load the weights stored at
    ``checkpoints/vla_model_final.pt``, draw one random sample from the dataset
    built from ``cfg.data``, run the agent on it without gradients, and hand the
    predicted and ground-truth action chunks to ``plot_results``.

    Args:
        cfg: Hydra configuration; this function reads ``cfg.train.device``,
            ``cfg.agent``, ``cfg.data`` and ``cfg.data.obs_keys``.

    Returns:
        None. Returns early (after printing a message) if the checkpoint file
        does not exist.
    """
    print(">>> 🤖 Starting VLA Inference...")
    target_device = cfg.train.device

    # 1. Build the agent. Its structure must match the one used for training,
    #    otherwise the saved state dict will not fit. Config overrides
    #    (e.g. forcing freeze=True) could also be applied at this point.
    agent: VLAAgent = instantiate(cfg.agent)
    agent.to(target_device)
    agent.eval()  # important: switch to eval mode before inference

    # 2. Load the trained weights.
    weights_file = "checkpoints/vla_model_final.pt"
    if not os.path.exists(weights_file):
        print(f"❌ Checkpoint not found at {weights_file}. Run training first!")
        return

    print(f"Loading weights from {weights_file}...")
    # map_location remaps the stored tensors onto the configured device, so a
    # checkpoint saved on GPU can still be loaded when that device is "cpu".
    weights = torch.load(weights_file, map_location=target_device)
    agent.load_state_dict(weights)
    print("✅ Weights loaded successfully.")

    # 3. Prepare test data: pull one randomly chosen sample from the dataset.
    dataset = instantiate(cfg.data)
    loader = DataLoader(dataset, batch_size=1, shuffle=True)
    batch = next(iter(loader))

    # Mimic a live robot step: feed only the last frame of the first camera.
    camera_key = cfg.data.obs_keys[0]
    # Expected to be (1, C, H, W) after dropping the time axis -- TODO confirm
    # against the dataset's actual layout.
    latest_frame = batch['obs'][camera_key][:, -1, :, :, :]

    observation = {
        "image": latest_frame.to(target_device),
        # The text slot is kept even if the model does not use language.
        "text": batch["language"],
    }
    # NOTE: 'actions' is deliberately left out; presumably its absence is what
    # routes the agent into its inference branch -- verify against VLAAgent.
    model_inputs = {"obs": observation}

    # 4. Run inference (reverse diffusion, per the original author's notes).
    print("running reverse diffusion (this may take a moment)...")
    with torch.no_grad():
        result = agent(model_inputs)

    # 5. Collect results. The prediction is expected to be shaped
    #    (1, chunk_size, action_dim) -- TODO confirm; index 0 drops the batch axis.
    pred_batch = result['pred_actions'].cpu().numpy()
    pred_actions = pred_batch[0]
    gt_actions = batch['actions'][0].numpy()  # ground truth, for comparison only

    print(f"✅ Generated Action Chunk Shape: {pred_actions.shape}")

    # 6. Visual comparison (saved to an image file).
    plot_results(pred_actions, gt_actions)
|
||||||
|
|
||||||
|
def plot_results(pred, gt, dims_to_plot=3, save_path="inference_result.png"):
    """Plot predicted vs. ground-truth action trajectories and save the figure.

    One subplot is drawn per action dimension, with the ground truth as a
    dashed green line and the prediction as a solid blue line.

    Args:
        pred: Predicted actions, indexed as ``pred[:, i]``; assumed to be a
            2-D array of shape (steps, action_dim) -- confirm with the caller.
        gt: Ground-truth actions with the same layout as ``pred``.
        dims_to_plot: Maximum number of leading action dimensions to draw.
            Clamped to the number of dimensions actually present, so inputs
            with fewer dimensions no longer raise ``IndexError``.
        save_path: Where the image is written.

    Raises:
        ValueError: If there is no action dimension to plot.
    """
    n_panels = min(dims_to_plot, pred.shape[1], gt.shape[1])
    if n_panels < 1:
        raise ValueError("plot_results needs at least one action dimension to plot")

    fig = plt.figure(figsize=(10, 5))
    try:
        for i in range(n_panels):
            plt.subplot(1, n_panels, i + 1)
            plt.plot(gt[:, i], 'g--', label='Ground Truth')
            plt.plot(pred[:, i], 'b-', label='Diffusion Pred')
            plt.title(f"Action Dim {i}")
            if i == 0:
                plt.legend()
            # Fixed range: assumes actions are normalized to [-1, 1]; values
            # outside that range are clipped from view.
            plt.ylim(-1, 1)

        plt.tight_layout()
        plt.savefig(save_path)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly.
        plt.close(fig)

    print(f"📊 Result plot saved to '{save_path}'")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # No arguments are passed here: the hydra.main decorator on main()
    # builds the config and supplies it.
    main()
|
||||||
@@ -95,6 +95,17 @@ def main(cfg: DictConfig):
|
|||||||
|
|
||||||
log.info("✅ Training Loop with Real HDF5 Finished!")
|
log.info("✅ Training Loop with Real HDF5 Finished!")
|
||||||
|
|
||||||
|
# --- 6. Save Checkpoint ---
|
||||||
|
save_dir = "checkpoints"
|
||||||
|
os.makedirs(save_dir, exist_ok=True)
|
||||||
|
save_path = os.path.join(save_dir, "vla_model_final.pt")
|
||||||
|
|
||||||
|
# 保存整个 Agent 的 state_dict
|
||||||
|
torch.save(agent.state_dict(), save_path)
|
||||||
|
log.info(f"💾 Model saved to {save_path}")
|
||||||
|
|
||||||
|
log.info("✅ Training Loop Finished!")
|
||||||
|
|
||||||
def recursive_to_device(data, device):
|
def recursive_to_device(data, device):
|
||||||
if isinstance(data, torch.Tensor):
|
if isinstance(data, torch.Tensor):
|
||||||
return data.to(device)
|
return data.to(device)
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ defaults:
|
|||||||
train:
|
train:
|
||||||
batch_size: 4 # 减小 batch size 方便调试
|
batch_size: 4 # 减小 batch size 方便调试
|
||||||
lr: 1e-4
|
lr: 1e-4
|
||||||
max_steps: 100
|
max_steps: 10
|
||||||
log_freq: 10
|
log_freq: 10
|
||||||
device: "cpu"
|
device: "cpu"
|
||||||
num_workers: 0 # 调试设为0,验证通过后改为 2 或 4
|
num_workers: 0 # 调试设为0,验证通过后改为 2 或 4
|
||||||
Reference in New Issue
Block a user