diff --git a/roboimi/demos/vla_scripts/eval_vla.py b/roboimi/demos/vla_scripts/eval_vla.py
new file mode 100644
index 0000000..848ded6
--- /dev/null
+++ b/roboimi/demos/vla_scripts/eval_vla.py
@@ -0,0 +1,113 @@
+import sys
+import os
+import hydra
+import torch
+import matplotlib.pyplot as plt
+import numpy as np
+from omegaconf import DictConfig, OmegaConf
+from hydra.utils import instantiate
+from torch.utils.data import DataLoader
+
+# Ensure `roboimi` can be imported (adds the current working directory to sys.path).
+sys.path.append(os.getcwd())
+from roboimi.vla.agent import VLAAgent
+
+def recursive_to_device(data, device):
+    """Move every tensor in a (possibly nested) dict to `device`; return other values unchanged."""
+    if isinstance(data, torch.Tensor):
+        return data.to(device)
+    elif isinstance(data, dict):
+        return {k: recursive_to_device(v, device) for k, v in data.items()}
+    return data
+
+@hydra.main(version_base=None, config_path="../../../roboimi/vla/conf", config_name="config")
+def main(cfg: DictConfig):
+    """Load the saved checkpoint, run inference on one dataset sample and plot it against the ground truth."""
+    print(">>> 🤖 Starting VLA Inference...")
+    device = cfg.train.device
+
+    # 1. Instantiate the agent (structure must exactly match the one used in training).
+    #    The config can also be overridden here, e.g. forcing freeze=True.
+    agent: VLAAgent = instantiate(cfg.agent)
+    agent.to(device)
+    agent.eval()  # Key step: switch to eval mode
+
+    # 2. Load the weights.
+    ckpt_path = "checkpoints/vla_model_final.pt"
+    if not os.path.exists(ckpt_path):
+        print(f"❌ Checkpoint not found at {ckpt_path}. Run training first!")
+        return
+
+    print(f"Loading weights from {ckpt_path}...")
+    # map_location=device remaps the stored tensors onto the configured device,
+    # so weights saved on a GPU still load when cfg.train.device is "cpu".
+    state_dict = torch.load(ckpt_path, map_location=device)
+    agent.load_state_dict(state_dict)
+    print("✅ Weights loaded successfully.")
+
+    # 3. Prepare test data (take one sample from the dataset).
+    dataset = instantiate(cfg.data)
+    dataloader = DataLoader(dataset, batch_size=1, shuffle=True)
+    sample = next(iter(dataloader))
+
+    # Build the input (simulating the robot running in real time).
+    # Note: sample['actions'] is not passed at inference time.
+    primary_cam_key = cfg.data.obs_keys[0]
+    input_img = sample['obs'][primary_cam_key][:, -1, :, :, :]  # (1, C, H, W)
+
+    # recursive_to_device moves every tensor in the input to the model's device;
+    # non-tensor values are passed through unchanged.
+    agent_input = recursive_to_device({
+        "obs": {
+            "image": input_img,
+            "text": sample["language"]  # placeholder must stay even if text is unused
+        }
+        # ⚠️ Key point: 'actions' is omitted so the agent takes its inference branch.
+    }, device)
+
+    # 4. Run inference (reverse diffusion).
+    print("running reverse diffusion (this may take a moment)...")
+    with torch.no_grad():
+        # This triggers branch B of the DiffusionHead (loop over timesteps).
+        outputs = agent(agent_input)
+
+    # 5. Fetch the result.
+    # Output shape: (1, Chunk_Size, Action_Dim)
+    pred_actions = outputs['pred_actions'].cpu().numpy()[0]
+    gt_actions = sample['actions'][0].numpy()  # used for comparison
+
+    print(f"✅ Generated Action Chunk Shape: {pred_actions.shape}")
+
+    # 6. Visual comparison (saves an image).
+    plot_results(pred_actions, gt_actions)
+
+def plot_results(pred, gt, dims_to_plot=3, save_path="inference_result.png"):
+    """Plot predicted vs. ground-truth trajectories for the leading action dimensions.
+
+    Args:
+        pred: predicted actions, indexed as [step, action_dim].
+        gt: ground-truth actions, indexed the same way.
+        dims_to_plot: maximum number of leading action dimensions to draw.
+        save_path: path the figure is written to.
+    """
+    # Clamp to the dimensions both arrays actually have, so action spaces with
+    # fewer than `dims_to_plot` dimensions do not raise IndexError.
+    n_dims = min(dims_to_plot, pred.shape[1], gt.shape[1])
+    fig = plt.figure(figsize=(10, 5))
+
+    for i in range(n_dims):
+        plt.subplot(1, n_dims, i + 1)
+        plt.plot(gt[:, i], 'g--', label='Ground Truth')
+        plt.plot(pred[:, i], 'b-', label='Diffusion Pred')
+        plt.title(f"Action Dim {i}")
+        if i == 0:
+            plt.legend()
+        plt.ylim(-1, 1)  # assumes actions are normalized
+
+    plt.tight_layout()
+    plt.savefig(save_path)
+    plt.close(fig)  # release the figure once it has been written
+    print(f"📊 Result plot saved to '{save_path}'")
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/roboimi/demos/vla_scripts/train_vla.py b/roboimi/demos/vla_scripts/train_vla.py
index 2e54b9a..8206c1d 100644
--- a/roboimi/demos/vla_scripts/train_vla.py
+++ b/roboimi/demos/vla_scripts/train_vla.py
@@ -95,6 +95,17 @@ def main(cfg: DictConfig):
 
     log.info("✅ Training Loop with Real HDF5 Finished!")
 
+    # --- 6. Save Checkpoint ---
+    save_dir = "checkpoints"
+    os.makedirs(save_dir, exist_ok=True)
+    save_path = os.path.join(save_dir, "vla_model_final.pt")
+
+    # Save the state_dict of the whole agent.
+    torch.save(agent.state_dict(), save_path)
+    log.info(f"💾 Model saved to {save_path}")
+
+    log.info("✅ Training Loop Finished!")
+
 def recursive_to_device(data, device):
     if isinstance(data, torch.Tensor):
         return data.to(device)
diff --git a/roboimi/vla/conf/config.yaml b/roboimi/vla/conf/config.yaml
index 65ebea6..89661f2 100644
--- a/roboimi/vla/conf/config.yaml
+++ b/roboimi/vla/conf/config.yaml
@@ -6,7 +6,7 @@ defaults:
 train:
   batch_size: 4 # 减小 batch size 方便调试
   lr: 1e-4
-  max_steps: 100
+  max_steps: 10
   log_freq: 10
   device: "cpu"
   num_workers: 0 # 调试设为0,验证通过后改为 2 或 4
\ No newline at end of file