Compare commits
5 Commits
23088e5e33
...
2f9b99e0c4
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2f9b99e0c4 | ||
|
|
d5d5b53f71 | ||
|
|
d84bc6876e | ||
|
|
424c265823 | ||
|
|
cb79e00546 |
@@ -0,0 +1,42 @@
|
||||
# Streaming HDF5 EE Action Dataset Implementation Plan
|
||||
|
||||
> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking.
|
||||
|
||||
**Goal:** 将 Diana 仿真采集改为流式写入 HDF5,图像保存为 256x256 的四路相机视角,并把 `/action` 改为 IK 前的原始末端位姿动作。
|
||||
|
||||
**Architecture:** 新增一个独立的流式 HDF5 episode writer,负责逐帧写入 qpos、原始 action 和 resize 后图像,并在 episode 成功时原子提交、失败时删除临时文件。采集脚本只负责 rollout 和把每一步观测/动作交给 writer,避免整集数据先堆在内存里。
|
||||
|
||||
**Tech Stack:** Python, h5py, numpy, cv2, unittest, MuJoCo demo scripts
|
||||
|
||||
---
|
||||
|
||||
### Task 1: 为流式 writer 建立测试边界
|
||||
|
||||
**Files:**
|
||||
- Create: `tests/test_streaming_episode_writer.py`
|
||||
- Create: `roboimi/utils/streaming_episode_writer.py`
|
||||
|
||||
- [ ] **Step 1: Write the failing test**
|
||||
- [ ] **Step 2: Run `python -m unittest tests.test_streaming_episode_writer -v` and confirm it fails because the writer module does not exist**
|
||||
- [ ] **Step 3: Implement the minimal streaming writer with temp-file commit/discard, per-frame append, and 256x256 image resize**
|
||||
- [ ] **Step 4: Re-run `python -m unittest tests.test_streaming_episode_writer -v` and confirm it passes**
|
||||
|
||||
### Task 2: 接入 Diana 采集脚本
|
||||
|
||||
**Files:**
|
||||
- Modify: `roboimi/demos/diana_record_sim_episodes.py`
|
||||
- Reuse: `roboimi/utils/streaming_episode_writer.py`
|
||||
|
||||
- [ ] **Step 1: Replace in-memory `data_dict` / `obs` accumulation with per-episode streaming writer lifecycle**
|
||||
- [ ] **Step 2: Keep four cameras (`angle`, `r_vis`, `top`, `front`) and resize to 256x256 before persistence**
|
||||
- [ ] **Step 3: Capture raw policy output before IK and write that to `/action`**
|
||||
- [ ] **Step 4: On success commit to `episode_{idx}.hdf5`; on failure remove temp file**
|
||||
|
||||
### Task 3: 验证改动
|
||||
|
||||
**Files:**
|
||||
- Verify only
|
||||
|
||||
- [ ] **Step 1: Run unit tests for the writer**
|
||||
- [ ] **Step 2: Run one end-to-end collection episode and stop after `episode_0.hdf5` becomes readable**
|
||||
- [ ] **Step 3: Verify HDF5 keys and shapes: `action=(700,16)`, image datasets are `(700,256,256,3)`, and `/action` matches raw EE action semantics**
|
||||
@@ -0,0 +1,26 @@
|
||||
# Raw Action Trajectory Viewer Implementation Plan
|
||||
|
||||
> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking.
|
||||
|
||||
**Goal:** 在可交互 MuJoCo 仿真窗口中,把 rollout 导出的 raw EE action 轨迹用红色轨迹标出来并启动仿真供人工查看。
|
||||
|
||||
**Architecture:** 读取已有 trajectory artifact 中的 raw_action / step 数据,生成左右臂末端轨迹点,并在 viewer 渲染循环中持续注入红色 marker。实现尽量独立为一个可复用的小脚本,避免影响训练/评估主路径。
|
||||
|
||||
**Tech Stack:** Python, NumPy, MuJoCo viewer, unittest/mock.
|
||||
|
||||
---
|
||||
|
||||
### Task 1: 抽取 raw_action 轨迹并生成可视化点集
|
||||
- [ ] 写失败测试,验证从 trajectory.npz 提取左右臂轨迹点
|
||||
- [ ] 实现最小 helper
|
||||
- [ ] 运行测试确认通过
|
||||
|
||||
### Task 2: 在 viewer 中渲染红色轨迹并支持交互查看
|
||||
- [ ] 写失败测试,验证 marker 配置/调用
|
||||
- [ ] 实现 viewer 可视化脚本
|
||||
- [ ] 运行测试确认通过
|
||||
|
||||
### Task 3: 启动真实仿真窗口供人工查看
|
||||
- [ ] 用现有 trajectory artifact 启动 viewer
|
||||
- [ ] 确认窗口可交互、红线出现
|
||||
- [ ] 向用户汇报启动方式与脚本路径
|
||||
44
docs/superpowers/plans/2026-03-31-rollout-artifacts.md
Normal file
44
docs/superpowers/plans/2026-03-31-rollout-artifacts.md
Normal file
@@ -0,0 +1,44 @@
|
||||
# Rollout Artifacts Implementation Plan
|
||||
|
||||
> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking.
|
||||
|
||||
**Goal:** Extend rollout evaluation so one selected checkpoint can be run once with video capture, timing breakdown, and saved EE trajectory artifacts.
|
||||
|
||||
**Architecture:** Keep the implementation centered in `eval_vla.py` so existing training-time rollout validation remains compatible. Add config-gated artifact capture helpers, serialize outputs under the eval run directory, and add lightweight tests for helper behavior and summary wiring; default eval behavior must remain unchanged when artifact capture is off.
|
||||
|
||||
**Tech Stack:** Python, Hydra/OmegaConf, NumPy, OpenCV, JSON, PyTorch unittest/mocking.
|
||||
|
||||
---
|
||||
|
||||
### Task 1: Add artifact capture configuration and helper wiring
|
||||
|
||||
**Files:**
|
||||
- Modify: `roboimi/demos/vla_scripts/eval_vla.py`
|
||||
- Modify: `roboimi/vla/conf/eval/eval.yaml`
|
||||
- Test: `tests/test_eval_vla_rollout_artifacts.py`
|
||||
|
||||
- [ ] **Step 1: Write failing tests for optional artifact config / summary wiring**
|
||||
- [ ] **Step 2: Implement config-backed artifact flags and output paths with defaults that write nothing**
|
||||
- [ ] **Step 3: Verify existing eval call sites still work with defaults**
|
||||
|
||||
### Task 2: Add timing breakdown, video recording, and trajectory export
|
||||
|
||||
**Files:**
|
||||
- Modify: `roboimi/demos/vla_scripts/eval_vla.py`
|
||||
- Test: `tests/test_eval_vla_rollout_artifacts.py`
|
||||
|
||||
- [ ] **Step 1: Write failing tests for timing aggregation, trajectory serialization, and summary schema**
|
||||
- [ ] **Step 2: Implement per-step timing capture for `obs_read_ms`, `preprocess_ms`, `inference_ms`, `env_step_ms`, `loop_total_ms`**
|
||||
- [ ] **Step 3: Implement MP4 recording from a chosen camera stream and canonical `trajectory.npz` export using `left_link7/right_link7` executed poses after `env.step`**
|
||||
- [ ] **Step 4: Run focused tests and fix issues**
|
||||
|
||||
### Task 3: Stop training safely and execute one real rollout
|
||||
|
||||
**Files:**
|
||||
- Use: `roboimi/demos/vla_scripts/eval_vla.py`
|
||||
- Output: `runs/.../eval_artifacts/...`
|
||||
|
||||
- [ ] **Step 1: Stop the active training process, wait for exit, and confirm the target checkpoint is readable**
|
||||
- [ ] **Step 2: Select the latest completed checkpoint if an explicit one is not provided; fall back to prior completed / best checkpoint if needed**
|
||||
- [ ] **Step 3: Run one headless rollout with artifact capture enabled**
|
||||
- [ ] **Step 4: Verify the MP4 / timing summary / trajectory files exist and summarize findings**
|
||||
@@ -0,0 +1,241 @@
|
||||
# VLA Training + Headless Rollout + SwanLab Design
|
||||
|
||||
**Date:** 2026-03-30
|
||||
**Branch:** feat-align-dp-transformer-ee
|
||||
|
||||
## Goal
|
||||
在当前仓库中补齐默认 `resnet_transformer` / `Transformer1D` 路线的训练依赖,使用数据集 `/home/droid/project/diana_sim/sim_transfer` 启动训练;同时支持训练过程中的 SwanLab 标量日志上传,并为后续 rollout 验证提供 headless 模式,避免弹出 MuJoCo / OpenCV 图形界面。
|
||||
|
||||
## Non-Goals
|
||||
- 不重写整套训练框架
|
||||
- 不引入新的 workspace / callback 框架
|
||||
- 不在本轮做复杂的视频/媒体日志上传
|
||||
- 不修改数据集格式本身
|
||||
|
||||
## Current State
|
||||
- 默认训练配置已切到 `agent=resnet_transformer`,head 为 `Transformer1D`
|
||||
- 当前环境缺少训练所需的若干 Python 依赖:`diffusers`、`torchvision`、`einops`、`swanlab`
|
||||
- 评估环境 `make_sim_env(task_name)` 当前写死 `is_render=True`
|
||||
- 相机线程 `camera_viewer()` 默认会 `cv2.namedWindow/imshow`,即使只想拿图像也会弹窗
|
||||
- 训练脚本当前支持 train/val loss、checkpoint,但没有 SwanLab 集成
|
||||
- 数据集目录 `/home/droid/project/diana_sim/sim_transfer` 下已有 100 个 episode,但还没有 `dataset_stats.pkl`
|
||||
|
||||
## User Requirements
|
||||
1. 在现有 mamba 环境里补齐训练依赖
|
||||
2. 在 `/home/droid/project/diana_sim/sim_transfer` 上开始训练
|
||||
3. 如果训练中需要 rollout 验证,希望支持 headless,不弹 GUI
|
||||
4. 训练指标上传到 SwanLab
|
||||
5. 默认 SwanLab project 名为 `roboimi-vla`
|
||||
|
||||
## Proposed Approach
|
||||
采用“最小必要改造”方案:
|
||||
|
||||
### 1. Dependency Layer
|
||||
在现有 `roboimi` 环境中补齐缺失训练依赖,并优先保持现有环境名与脚本入口不变。
|
||||
|
||||
#### Install Plan
|
||||
- 环境:继续使用现有 mamba 环境 `roboimi`
|
||||
- 安装方式:
|
||||
- 优先使用当前 env 的 `python -m pip install`
|
||||
- 安装包:
|
||||
- `diffusers`
|
||||
- `torchvision`
|
||||
- `einops`
|
||||
- `swanlab`
|
||||
- 版本策略:
|
||||
- 优先选择与当前 `torch==2.4.0` 可兼容的最新可安装版本
|
||||
- 若出现兼容性问题,再回退到与 `torch 2.4` 对齐的稳定版本
|
||||
- 复现策略:
|
||||
- 本轮会把**实际安装成功的 resolved versions** 补写回仓库的环境定义文件,避免后续环境漂移
|
||||
|
||||
训练前验证以下 import:
|
||||
- `torch`
|
||||
- `hydra`
|
||||
- `omegaconf`
|
||||
- `diffusers`
|
||||
- `torchvision`
|
||||
- `einops`
|
||||
- `swanlab`
|
||||
- `cv2`
|
||||
- `h5py`
|
||||
- `mujoco`
|
||||
|
||||
### 2. Dataset Preparation
|
||||
直接复用现有 `SimpleRobotDataset`,仅将 `data.dataset_dir` 指向:
|
||||
- `/home/droid/project/diana_sim/sim_transfer`
|
||||
|
||||
训练前使用现有统计脚本生成:
|
||||
- `/home/droid/project/diana_sim/sim_transfer/dataset_stats.pkl`
|
||||
|
||||
统计文件生成命令目标为:
|
||||
- 从仓库根目录执行
|
||||
- 直接针对 `/home/droid/project/diana_sim/sim_transfer` 输出 stats
|
||||
- 训练脚本不再依赖默认数据目录
|
||||
|
||||
### 3. SwanLab Logging
|
||||
在训练脚本中增加一个轻量 logging 集成层:
|
||||
- 通过配置决定是否启用 SwanLab,默认启用
|
||||
- 默认 project:`roboimi-vla`
|
||||
- API key 不写入仓库,不写入配置文件,只通过本地登录状态或环境变量使用
|
||||
- 当 `train.use_swanlab=true` 时:
|
||||
- 若 `swanlab` 不可 import,训练直接 fail fast
|
||||
- 若未登录或认证失败,训练直接 fail fast
|
||||
- 每个训练日志点上传:
|
||||
- `train/loss`
|
||||
- `train/lr`
|
||||
- `train/best_loss`
|
||||
- `train/step`
|
||||
- 每次验证时上传:
|
||||
- `val/loss`
|
||||
- 训练结束时记录最终 checkpoint 路径与 best checkpoint 路径
|
||||
|
||||
### 4. Headless Rollout Design
|
||||
目标是让 rollout 验证可以“拿到图像观测,但不弹任何窗口”。
|
||||
|
||||
最小改造策略:
|
||||
- 给 `make_sim_env(...)` 增加 `headless` / `is_render` 参数
|
||||
- 给相机线程显示逻辑增加开关:
|
||||
- headless 时继续更新 `r_vis/top/front/...` 图像缓存
|
||||
- 但不执行 `cv2.namedWindow` / `cv2.imshow` / `cv2.waitKey`
|
||||
- 评估脚本中:
|
||||
- headless 时不调用 `env.render()`
|
||||
- 仍然允许 `env._get_image_obs()` 和 policy inference 正常运行
|
||||
|
||||
#### Training-Time Rollout Scope
|
||||
- 本轮**会提供一个可选的 checkpoint-time rollout validation 路径**,默认关闭
|
||||
- 启用后,在训练保存 checkpoint 时可以调用同仓库的 rollout/eval 逻辑做少量 episode 验证
|
||||
- 此路径要求支持**唯一权威开关** `eval.headless=true`,即:
|
||||
- 不弹 MuJoCo viewer
|
||||
- 不执行 `cv2.namedWindow / cv2.imshow / cv2.waitKey`
|
||||
- 仍可读取图像并完成策略推理
|
||||
- 默认情况下不增加频繁 rollout,以避免拖慢训练;只提供能力与配置开关
|
||||
|
||||
如果验证发现相机线程强依赖 GUI,我们的降级策略是:
|
||||
- 训练主流程 + SwanLab 必须先跑通
|
||||
- rollout validation 保持为显式可选能力
|
||||
- 但本轮仍要保证至少存在可调用的 headless 验证执行路径,而不是仅停留在文档层面
|
||||
|
||||
### 5. Training Execution Strategy
|
||||
分两步执行:
|
||||
|
||||
#### Step A: Smoke Run
|
||||
使用较小步数启动一次 smoke training,确认:
|
||||
- 数据集可正常读取
|
||||
- 统计文件可加载
|
||||
- 模型可实例化
|
||||
- 单步前后向正常
|
||||
- checkpoint 正常写出
|
||||
- SwanLab 成功上传标量
|
||||
|
||||
#### Step B: Real Training Run
|
||||
在 smoke run 成功后,再启动正式训练。
|
||||
|
||||
## Execution Commands
|
||||
|
||||
### A. Stats Generation
|
||||
从仓库根目录执行,生成:
|
||||
- `/home/droid/project/diana_sim/sim_transfer/dataset_stats.pkl`
|
||||
|
||||
命令模板:
|
||||
```bash
|
||||
/home/droid/.conda/envs/roboimi/bin/python roboimi/vla/scripts/calculate_stats.py \
|
||||
--dataset_dir /home/droid/project/diana_sim/sim_transfer
|
||||
```
|
||||
|
||||
### B. Smoke Training Command
|
||||
从仓库根目录执行,核心覆盖项包括:
|
||||
- `data.dataset_dir=/home/droid/project/diana_sim/sim_transfer`
|
||||
- 较小 `train.max_steps`
|
||||
- 较高日志频率
|
||||
- 启用 SwanLab
|
||||
- 输出目录使用当前运行目录下的 `checkpoints/`
|
||||
|
||||
命令模板:
|
||||
```bash
|
||||
/home/droid/.conda/envs/roboimi/bin/python roboimi/demos/vla_scripts/train_vla.py \
|
||||
data.dataset_dir=/home/droid/project/diana_sim/sim_transfer \
|
||||
train.max_steps=20 \
|
||||
train.log_freq=1 \
|
||||
train.save_freq=10 \
|
||||
train.use_swanlab=true \
|
||||
train.swanlab_project=roboimi-vla \
|
||||
train.rollout_validate_on_checkpoint=false
|
||||
```
|
||||
|
||||
### C. Real Training Command
|
||||
从仓库根目录执行,核心覆盖项包括:
|
||||
- `data.dataset_dir=/home/droid/project/diana_sim/sim_transfer`
|
||||
- 正式 `train.max_steps`
|
||||
- 默认 project=`roboimi-vla`
|
||||
- 若启用 rollout validation,则传入 `eval.headless=true` 以及训练侧 rollout 开关
|
||||
|
||||
命令模板:
|
||||
```bash
|
||||
/home/droid/.conda/envs/roboimi/bin/python roboimi/demos/vla_scripts/train_vla.py \
|
||||
data.dataset_dir=/home/droid/project/diana_sim/sim_transfer \
|
||||
train.use_swanlab=true \
|
||||
train.swanlab_project=roboimi-vla \
|
||||
train.rollout_validate_on_checkpoint=true \
|
||||
eval.headless=true
|
||||
```
|
||||
|
||||
### D. Output Behavior
|
||||
- checkpoint 输出目录:当前工作目录下的 `checkpoints/`
|
||||
- 关键文件:
|
||||
- `checkpoints/vla_model_step_<N>.pt`
|
||||
- `checkpoints/vla_model_best.pt`
|
||||
- `checkpoints/vla_model_final.pt`
|
||||
|
||||
## File-Level Changes
|
||||
- `environment.yml`
|
||||
- 补写新增训练依赖,保证后续可复现
|
||||
- `roboimi/demos/vla_scripts/train_vla.py`
|
||||
- 增加 SwanLab 集成
|
||||
- 增加更明确的数据集目录覆盖支持
|
||||
- 增加可选 checkpoint-time rollout validation 入口
|
||||
- 保持当前 optimizer 对齐逻辑不变
|
||||
- `roboimi/vla/conf/config.yaml`
|
||||
- 增加/扩展训练日志、SwanLab、rollout 相关配置项
|
||||
- `roboimi/vla/conf/eval/eval.yaml`
|
||||
- 增加 `headless` 等评估控制项
|
||||
- `roboimi/envs/double_pos_ctrl_env.py`
|
||||
- `make_sim_env` 支持 headless / no-render
|
||||
- `roboimi/envs/double_base.py`
|
||||
- 相机采集与 GUI 显示解耦
|
||||
- `roboimi/vla/scripts/calculate_stats.py`
|
||||
- 改为直接支持通过命令行传入外部 `dataset_dir`
|
||||
- tests(新增)
|
||||
- 覆盖 SwanLab 可选初始化路径
|
||||
- 覆盖 headless 环境下“不弹窗但可取图”的关键逻辑
|
||||
|
||||
## Validation Plan
|
||||
1. 补齐依赖后验证 import 全通过
|
||||
2. 生成 `dataset_stats.pkl`
|
||||
3. 运行训练 smoke run
|
||||
4. 确认 SwanLab dashboard 在 project `roboimi-vla` 下有标量更新
|
||||
5. 若启用 rollout 验证:确认 headless 下不弹 GUI,且 rollout 路径能真正执行
|
||||
6. 再启动正式训练
|
||||
|
||||
## Config Contract
|
||||
本轮新增/固定的配置键以以下形式为准:
|
||||
- `train.use_swanlab: true|false`
|
||||
- `train.swanlab_project: roboimi-vla`
|
||||
- `train.rollout_validate_on_checkpoint: true|false`
|
||||
- `eval.headless: true|false`
|
||||
|
||||
## Risks and Mitigations
|
||||
- **Risk:** GUI/相机线程与离屏渲染耦合
|
||||
- **Mitigation:** 先解耦显示与图像更新;必要时把 rollout 验证降级为第二阶段
|
||||
- **Risk:** 现有 env 依赖不完整
|
||||
- **Mitigation:** 先做 import 验证,再做 smoke run
|
||||
- **Risk:** 数据集过大导致 smoke run 也很慢
|
||||
- **Mitigation:** smoke run 只跑极小步数
|
||||
- **Risk:** SwanLab API key 泄漏
|
||||
- **Mitigation:** 不写入代码/配置,只保存在本地登录态或环境变量
|
||||
|
||||
## Success Criteria
|
||||
- 训练脚本能在 `/home/droid/project/diana_sim/sim_transfer` 上启动
|
||||
- 能成功写出 checkpoint 到 `checkpoints/`
|
||||
- SwanLab 在 `roboimi-vla` 项目下能看到 train/val 标量
|
||||
- headless rollout 具备不弹 GUI 的执行路径
|
||||
- 若训练侧启用 rollout validation,则该路径可以在 headless 模式下被实际调用
|
||||
@@ -0,0 +1,16 @@
|
||||
# Rollout Artifacts Design
|
||||
|
||||
**Goal:** Add a one-off evaluation path that can record rollout video, export per-step timing breakdowns, and save executed end-effector trajectories for a selected checkpoint while preserving default eval behavior when artifact capture is disabled.
|
||||
|
||||
**Approach:** Extend `roboimi/demos/vla_scripts/eval_vla.py` with optional evaluation-time artifact capture that stays backward compatible when disabled. Reuse existing environment observation and camera streams, record one camera stream to MP4, collect per-step timing around observation read / preprocessing / model inference / env step / total loop, and save per-step raw predicted EE actions plus executed EE poses after stepping.
|
||||
|
||||
**Artifact contract:**
|
||||
- `video.mp4`: optional MP4 encoded from a selected camera stream (`r_vis`, `top`, `front`, etc.), written only when recording is enabled.
|
||||
- `trajectory.npz`: canonical trajectory export containing at minimum `step`, `reward`, `raw_action`, `executed_left_link7_pos`, `executed_left_link7_quat`, `executed_right_link7_pos`, `executed_right_link7_quat`, and optional duplicated tool-body poses if captured.
|
||||
- `timing.json`: JSON-serializable per-episode timing summary with millisecond units for `obs_read_ms`, `preprocess_ms`, `inference_ms`, `env_step_ms`, `loop_total_ms`, plus aggregate mean/std/min/max and counts. Raw per-step timing arrays should also be persisted in the NPZ for later analysis.
|
||||
|
||||
**Checkpoint selection:** Prefer an explicitly requested checkpoint path. If the caller asks for “latest” or omits a path in the execution helper, select the newest fully written checkpoint file by mtime/name and fail clearly if none exists.
|
||||
|
||||
**Stop-training / execution safety:** Before rollout, stop any active training process using the target run, wait for process exit, then verify the chosen checkpoint exists and is readable. If the most recent checkpoint is missing or mid-write, fall back to the previous completed checkpoint or `vla_model_best.pt` with the decision logged.
|
||||
|
||||
**Backward compatibility:** With all new eval flags left at default values, `_run_eval` return shape must remain compatible with existing callers, training-time rollout validation should continue to work without passing new options, and no artifact files should be written.
|
||||
@@ -229,6 +229,11 @@ dependencies:
|
||||
- python-xxhash=3.6.0
|
||||
- python_abi=3.10
|
||||
- pytorch=2.4.0
|
||||
- hydra-core=1.3.2
|
||||
- omegaconf=2.3.0
|
||||
- einops=0.8.2
|
||||
- diffusers=0.36.0
|
||||
- torchvision=0.19.0
|
||||
- pytz=2024.1
|
||||
- pyyaml=6.0.3
|
||||
- qhull=2020.2
|
||||
@@ -321,12 +326,10 @@ dependencies:
|
||||
- datasets==4.5.0
|
||||
- decorator==5.2.1
|
||||
- deepdiff==8.6.1
|
||||
- diffusers==0.30.0
|
||||
- dill==0.4.0
|
||||
- docstring_parser==0.17.0
|
||||
- draccus==0.10.0
|
||||
- eigenpy==3.10.3
|
||||
- einops==0.8.1
|
||||
- etils==1.7.0
|
||||
- evdev==1.9.2
|
||||
- exceptiongroup==1.3.1
|
||||
@@ -350,7 +353,6 @@ dependencies:
|
||||
- httpcore==1.0.9
|
||||
- httpx==0.28.1
|
||||
- huggingface_hub==1.3.2
|
||||
- hydra-core==1.3.2
|
||||
- imageio==2.35.1
|
||||
- imageio-ffmpeg==0.6.0
|
||||
- importlib_metadata==8.7.1
|
||||
@@ -380,22 +382,6 @@ dependencies:
|
||||
- networkx==3.4.2
|
||||
- numcodecs==0.13.1
|
||||
- numpy==2.2.6
|
||||
- nvidia-cublas-cu12==12.4.5.8
|
||||
- nvidia-cuda-cupti-cu12==12.4.127
|
||||
- nvidia-cuda-nvrtc-cu12==12.4.127
|
||||
- nvidia-cuda-runtime-cu12==12.4.127
|
||||
- nvidia-cudnn-cu12==9.1.0.70
|
||||
- nvidia-cufft-cu12==11.2.1.3
|
||||
- nvidia-cufile-cu12==1.11.1.6
|
||||
- nvidia-curand-cu12==10.3.5.147
|
||||
- nvidia-cusolver-cu12==11.6.1.9
|
||||
- nvidia-cusparse-cu12==12.3.1.170
|
||||
- nvidia-cusparselt-cu12==0.6.3
|
||||
- nvidia-nccl-cu12==2.21.5
|
||||
- nvidia-nvjitlink-cu12==12.4.127
|
||||
- nvidia-nvshmem-cu12==3.3.20
|
||||
- nvidia-nvtx-cu12==12.4.127
|
||||
- omegaconf==2.3.0
|
||||
- opencv-contrib-python==4.10.0.84
|
||||
- opencv-python==4.13.0.90
|
||||
- orderly-set==5.5.0
|
||||
@@ -431,7 +417,7 @@ dependencies:
|
||||
- regex==2026.1.15
|
||||
- requests==2.32.5
|
||||
- rerun-sdk==0.26.2
|
||||
- rich==14.2.0
|
||||
- rich==13.9.4
|
||||
- ruckig==0.9.2
|
||||
- safehttpx==0.1.7
|
||||
- safetensors==0.7.0
|
||||
@@ -443,18 +429,16 @@ dependencies:
|
||||
- stack-data==0.6.3
|
||||
- starlette==0.50.0
|
||||
- sympy==1.13.1
|
||||
- swanlab==0.7.13
|
||||
- termcolor==3.3.0
|
||||
- timm==1.0.24
|
||||
- toml==0.10.2
|
||||
- tomli==2.4.0
|
||||
- tomlkit==0.13.3
|
||||
- torch==2.5.0
|
||||
- torchcodec==0.5
|
||||
- torchmetrics==1.8.2
|
||||
- torchvision==0.20.0
|
||||
- tqdm==4.67.1
|
||||
- traitlets==5.14.3
|
||||
- triton==3.1.0
|
||||
- typer==0.21.1
|
||||
- typer-slim==0.21.1
|
||||
- typeshed_client==2.8.2
|
||||
|
||||
@@ -1,8 +1,46 @@
|
||||
import mujoco
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from roboimi.utils.KDL_utils import KDL_utils
|
||||
|
||||
|
||||
def resolve_robot_asset_path(asset_path):
|
||||
if asset_path is None:
|
||||
return None
|
||||
|
||||
raw_path = Path(asset_path).expanduser()
|
||||
if raw_path.is_absolute():
|
||||
return str(raw_path.resolve())
|
||||
|
||||
current_dir = Path(__file__).resolve().parent
|
||||
package_root = current_dir.parents[1]
|
||||
repo_root = current_dir.parents[2]
|
||||
|
||||
candidates = []
|
||||
if raw_path.parts and raw_path.parts[0] == 'roboimi':
|
||||
candidates.append(repo_root / raw_path)
|
||||
|
||||
candidates.extend([
|
||||
current_dir / raw_path,
|
||||
package_root / raw_path,
|
||||
repo_root / raw_path,
|
||||
])
|
||||
|
||||
normalized_candidates = []
|
||||
seen = set()
|
||||
for candidate in candidates:
|
||||
resolved = candidate.resolve()
|
||||
if resolved not in seen:
|
||||
normalized_candidates.append(resolved)
|
||||
seen.add(resolved)
|
||||
|
||||
for candidate in normalized_candidates:
|
||||
if candidate.exists():
|
||||
return str(candidate)
|
||||
|
||||
return str(normalized_candidates[0])
|
||||
|
||||
|
||||
class ArmBase(object):
|
||||
def __init__(self,
|
||||
name=None,
|
||||
@@ -11,8 +49,8 @@ class ArmBase(object):
|
||||
gripper=None
|
||||
):
|
||||
self.name = name
|
||||
self.urdf_path = urdf_path
|
||||
self.xml_path = xml_path
|
||||
self.urdf_path = resolve_robot_asset_path(urdf_path)
|
||||
self.xml_path = resolve_robot_asset_path(xml_path)
|
||||
self.gripper = gripper
|
||||
self.robot_model = mujoco.MjModel.from_xml_path(filename=self.xml_path, assets=None)
|
||||
self.robot_data = mujoco.MjData(self.robot_model)
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
import time
|
||||
import os,collections,sys
|
||||
import os
|
||||
import numpy as np
|
||||
import h5py
|
||||
from roboimi.envs.double_pos_ctrl_env import make_sim_env
|
||||
from diana_policy import TestPickAndTransferPolicy
|
||||
import cv2
|
||||
from roboimi.utils.act_ex_utils import sample_transfer_pose
|
||||
from roboimi.utils.streaming_episode_writer import StreamingEpisodeWriter
|
||||
|
||||
import pathlib
|
||||
HOME_PATH = str(pathlib.Path(__file__).parent.resolve())
|
||||
@@ -16,14 +16,12 @@ def main():
|
||||
task_name = 'sim_transfer'
|
||||
dataset_dir = DATASET_DIR + '/sim_transfer' #SIM_TASK_CONFIGS[task_name]['dataset_dir']
|
||||
num_episodes = 100 #SIM_TASK_CONFIGS[task_name]['num_episodes']
|
||||
onscreen_render = None #config['onscreen_render']
|
||||
inject_noise = False
|
||||
render_cam_name = 'angle'
|
||||
|
||||
episode_len = 700 #SIM_TASK_CONFIGS[task_name]['episode_len']
|
||||
camera_names = ['angle','r_vis', 'top', 'front'] #SIM_TASK_CONFIGS[task_name]['camera_names']
|
||||
image_size = (256, 256)
|
||||
if task_name == 'sim_transfer':
|
||||
policy = TestPickAndTransferPolicy(inject_noise)
|
||||
print(task_name)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
@@ -39,62 +37,38 @@ def main():
|
||||
print("osmesa已就绪,开始收集数据...")
|
||||
|
||||
for episode_idx in range(num_episodes):
|
||||
obs = []
|
||||
reward_ee = []
|
||||
sum_reward = 0.0
|
||||
max_reward = float('-inf')
|
||||
print(f'\n{episode_idx=}')
|
||||
print('Rollout out EE space scripted policy')
|
||||
box_pos = sample_transfer_pose()
|
||||
env.reset(box_pos)
|
||||
episode_writer = StreamingEpisodeWriter(
|
||||
dataset_path=os.path.join(dataset_dir, f'episode_{episode_idx}.hdf5'),
|
||||
max_timesteps=episode_len,
|
||||
camera_names=camera_names,
|
||||
image_size=image_size,
|
||||
)
|
||||
for step in range(episode_len):
|
||||
|
||||
|
||||
action = policy.predict(box_pos,step)
|
||||
env.step(action)
|
||||
raw_action = policy.predict(box_pos,step)
|
||||
env.step(raw_action)
|
||||
env.render()
|
||||
reward_ee.append(env.rew)
|
||||
obs.append(env.obs)
|
||||
sum_reward = np.sum(reward_ee)
|
||||
max_reward = np.max(reward_ee)
|
||||
sum_reward += env.rew
|
||||
max_reward = max(max_reward, env.rew)
|
||||
episode_writer.append(
|
||||
qpos=env.obs['qpos'],
|
||||
action=raw_action,
|
||||
images=env.obs['images'],
|
||||
)
|
||||
if max_reward == env.max_reward:
|
||||
success.append(1)
|
||||
print(f"{episode_idx=} Successful, {sum_reward=}")
|
||||
t0 = time.time()
|
||||
data_dict = {
|
||||
'/observations/qpos': [],
|
||||
'/action': [],
|
||||
}
|
||||
|
||||
for cam_name in camera_names:
|
||||
data_dict[f'/observations/images/{cam_name}'] = []
|
||||
for i in range(episode_len):
|
||||
print("type qpos==",obs[i]['qpos'])
|
||||
data_dict['/observations/qpos'].append(obs[i]['qpos'])
|
||||
data_dict['/action'].append(obs[i]['action'])
|
||||
for cam_name in camera_names:
|
||||
data_dict[f'/observations/images/{cam_name}'].append(obs[i]['images'][cam_name])
|
||||
|
||||
dataset_path = os.path.join(dataset_dir, f'episode_{episode_idx}')
|
||||
|
||||
with h5py.File(dataset_path + '.hdf5', 'w', rdcc_nbytes=1024 ** 2 * 2) as root:
|
||||
max_timesteps = episode_len
|
||||
root.attrs['sim'] = True
|
||||
obs_ = root.create_group('observations')
|
||||
image = obs_.create_group('images')
|
||||
for cam_name in camera_names:
|
||||
_ = image.create_dataset(cam_name, (max_timesteps, 480, 640, 3), dtype='uint8',
|
||||
chunks=(1, 480, 640, 3), )
|
||||
qpos = obs_.create_dataset('qpos', (max_timesteps, 16))
|
||||
action = root.create_dataset('action', (max_timesteps, 16))
|
||||
for name, array in data_dict.items():
|
||||
root[name][...] = np.array(array)
|
||||
episode_writer.commit()
|
||||
else:
|
||||
success.append(0)
|
||||
print(f"{episode_idx=} Failed")
|
||||
print(max_reward)
|
||||
del obs
|
||||
del reward_ee
|
||||
del sum_reward
|
||||
del max_reward
|
||||
episode_writer.discard()
|
||||
|
||||
# del policy
|
||||
# env.viewer.close()
|
||||
@@ -108,4 +82,4 @@ def main():
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
main()
|
||||
|
||||
36
roboimi/demos/view_raw_action_trajectory.py
Normal file
36
roboimi/demos/view_raw_action_trajectory.py
Normal file
@@ -0,0 +1,36 @@
|
||||
import argparse
|
||||
import numpy as np
|
||||
|
||||
from roboimi.utils.raw_action_trajectory_viewer import launch_raw_action_trajectory_viewer
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="Launch an interactive MuJoCo viewer with raw-action trajectory overlay.")
|
||||
parser.add_argument("trajectory_path", help="Path to raw_action.npy or trajectory.npz")
|
||||
parser.add_argument("--task-name", default="sim_transfer")
|
||||
parser.add_argument("--line-radius", type=float, default=0.004)
|
||||
parser.add_argument("--max-markers", type=int, default=1500)
|
||||
parser.add_argument(
|
||||
"--box-pos",
|
||||
type=float,
|
||||
nargs=3,
|
||||
default=None,
|
||||
help="Optional box xyz to use when resetting the environment",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
box_pos = np.asarray(args.box_pos, dtype=np.float32) if args.box_pos is not None else None
|
||||
launch_raw_action_trajectory_viewer(
|
||||
args.trajectory_path,
|
||||
task_name=args.task_name,
|
||||
line_radius=args.line_radius,
|
||||
max_markers=args.max_markers,
|
||||
box_pos=box_pos,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -19,7 +19,7 @@ import torch
|
||||
import numpy as np
|
||||
import hydra
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
from typing import Any, Dict, Optional
|
||||
from tqdm import tqdm
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
from hydra.utils import instantiate
|
||||
@@ -27,6 +27,7 @@ from einops import rearrange
|
||||
|
||||
from roboimi.envs.double_pos_ctrl_env import make_sim_env
|
||||
from roboimi.utils.act_ex_utils import sample_transfer_pose
|
||||
from roboimi.vla.eval_utils import execute_policy_action
|
||||
|
||||
sys.path.append(os.getcwd())
|
||||
|
||||
@@ -121,6 +122,317 @@ def prepare_observation(obs: Dict, camera_names: list) -> Dict:
|
||||
return {'qpos': qpos, 'images': images}
|
||||
|
||||
|
||||
def _to_numpy_action(action: Any) -> np.ndarray:
|
||||
if isinstance(action, torch.Tensor):
|
||||
return action.detach().cpu().numpy().astype(np.float32, copy=True)
|
||||
return np.asarray(action, dtype=np.float32).copy()
|
||||
|
||||
|
||||
def _mean_or_zero(values: list[float]) -> float:
|
||||
return float(np.mean(values)) if values else 0.0
|
||||
|
||||
|
||||
def _stats_or_zero(values: list[float]) -> dict[str, float]:
|
||||
if not values:
|
||||
return {
|
||||
'mean': 0.0,
|
||||
'std': 0.0,
|
||||
'min': 0.0,
|
||||
'max': 0.0,
|
||||
}
|
||||
array = np.asarray(values, dtype=np.float64)
|
||||
return {
|
||||
'mean': float(array.mean()),
|
||||
'std': float(array.std()),
|
||||
'min': float(array.min()),
|
||||
'max': float(array.max()),
|
||||
}
|
||||
|
||||
|
||||
def _summarize_timing_breakdown(
|
||||
all_timings: dict[str, list[float]],
|
||||
model_forward_flags: list[bool],
|
||||
) -> dict[str, Any]:
|
||||
model_forward_flags = [bool(flag) for flag in model_forward_flags]
|
||||
return {
|
||||
'count': int(len(model_forward_flags)),
|
||||
'model_forward_count': int(sum(model_forward_flags)),
|
||||
'all_steps_ms': {
|
||||
stage: _stats_or_zero(values)
|
||||
for stage, values in all_timings.items()
|
||||
},
|
||||
'model_forward_steps_ms': {
|
||||
stage: _stats_or_zero(
|
||||
[value for value, should_keep in zip(values, model_forward_flags) if should_keep]
|
||||
)
|
||||
for stage, values in all_timings.items()
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _json_friendly(value: Any) -> Any:
|
||||
if isinstance(value, dict):
|
||||
return {str(key): _json_friendly(item) for key, item in value.items()}
|
||||
if isinstance(value, (list, tuple)):
|
||||
return [_json_friendly(item) for item in value]
|
||||
if isinstance(value, Path):
|
||||
return str(value)
|
||||
if isinstance(value, np.ndarray):
|
||||
return value.tolist()
|
||||
if isinstance(value, (np.integer, np.floating)):
|
||||
return value.item()
|
||||
return value
|
||||
|
||||
|
||||
def _resolve_artifact_paths(eval_cfg: DictConfig) -> dict[str, Optional[str]]:
|
||||
save_timing = bool(eval_cfg.get('save_timing', False))
|
||||
save_trajectory = bool(
|
||||
eval_cfg.get('save_trajectory', False) or eval_cfg.get('save_trajectory_npz', False)
|
||||
)
|
||||
wants_artifacts = any([
|
||||
bool(eval_cfg.get('save_artifacts', False)),
|
||||
save_timing,
|
||||
save_trajectory,
|
||||
bool(eval_cfg.get('record_video', False)),
|
||||
])
|
||||
output_dir: Optional[Path] = None
|
||||
if wants_artifacts:
|
||||
artifact_dir = eval_cfg.get('artifact_dir', None)
|
||||
if artifact_dir:
|
||||
output_dir = Path(str(artifact_dir)).expanduser().resolve()
|
||||
else:
|
||||
ckpt_stem = Path(str(eval_cfg.ckpt_path)).stem or 'rollout'
|
||||
timestamp = time.strftime('%Y%m%d-%H%M%S')
|
||||
output_dir = (Path.cwd() / 'rollout_artifacts' / f'{ckpt_stem}-{timestamp}').resolve()
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
video_camera_name = None
|
||||
if bool(eval_cfg.get('record_video', False)):
|
||||
configured_camera_name = eval_cfg.get('video_camera_name', None)
|
||||
if configured_camera_name is None:
|
||||
configured_camera_name = eval_cfg.get('video_camera', None)
|
||||
if configured_camera_name is not None:
|
||||
video_camera_name = str(configured_camera_name)
|
||||
elif eval_cfg.get('camera_names'):
|
||||
video_camera_name = str(eval_cfg.camera_names[0])
|
||||
else:
|
||||
raise ValueError('record_video=true requires eval.video_camera_name or a non-empty eval.camera_names')
|
||||
|
||||
return {
|
||||
'output_dir': str(output_dir) if output_dir is not None else None,
|
||||
'summary_json': (
|
||||
str(output_dir / 'rollout_summary.json')
|
||||
if output_dir is not None and bool(eval_cfg.get('save_summary_json', False))
|
||||
else None
|
||||
),
|
||||
'timing_json': (
|
||||
str(output_dir / 'timing.json')
|
||||
if output_dir is not None and save_timing
|
||||
else None
|
||||
),
|
||||
'trajectory_npz': (
|
||||
str(output_dir / 'trajectory.npz')
|
||||
if output_dir is not None and save_trajectory
|
||||
else None
|
||||
),
|
||||
'video_mp4': (
|
||||
str(output_dir / f'rollout_{video_camera_name}.mp4')
|
||||
if output_dir is not None and bool(eval_cfg.get('record_video', False))
|
||||
and video_camera_name is not None
|
||||
else None
|
||||
),
|
||||
'video_camera_name': video_camera_name,
|
||||
}
|
||||
|
||||
|
||||
def _get_video_frame(obs: Dict, camera_name: Optional[str]) -> Optional[np.ndarray]:
|
||||
if camera_name is None:
|
||||
return None
|
||||
frame = obs['images'][camera_name]
|
||||
frame = np.asarray(frame)
|
||||
if frame.ndim != 3 or frame.shape[2] != 3:
|
||||
raise ValueError(
|
||||
f'Video frame for camera {camera_name} must have shape (H, W, 3), got {frame.shape}'
|
||||
)
|
||||
if frame.dtype != np.uint8:
|
||||
frame = np.clip(frame, 0, 255).astype(np.uint8)
|
||||
return frame
|
||||
|
||||
|
||||
def _open_video_writer(output_path: str, frame_size: tuple[int, int], fps: int):
|
||||
import cv2
|
||||
|
||||
output_path = str(output_path)
|
||||
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
|
||||
writer = cv2.VideoWriter(output_path, fourcc, float(fps), frame_size)
|
||||
if not writer.isOpened():
|
||||
raise RuntimeError(f'无法打开视频输出: {output_path}')
|
||||
return writer
|
||||
|
||||
|
||||
class _RolloutVideoRecorder:
|
||||
def __init__(self, output_path: Optional[str], fps: int):
|
||||
self.output_path = output_path
|
||||
self.fps = int(fps)
|
||||
self.writer = None
|
||||
|
||||
def write(self, frame: Optional[np.ndarray]):
|
||||
if self.output_path is None or frame is None:
|
||||
return
|
||||
if self.writer is None:
|
||||
frame_size = (int(frame.shape[1]), int(frame.shape[0]))
|
||||
self.writer = _open_video_writer(self.output_path, frame_size, self.fps)
|
||||
self.writer.write(frame)
|
||||
|
||||
def close(self):
|
||||
if self.writer is not None:
|
||||
self.writer.release()
|
||||
self.writer = None
|
||||
|
||||
|
||||
def _read_body_pose(env, body_name: str):
|
||||
try:
|
||||
if callable(getattr(env, 'getBodyPos', None)) and callable(getattr(env, 'getBodyQuat', None)):
|
||||
pos = env.getBodyPos(body_name)
|
||||
quat = env.getBodyQuat(body_name)
|
||||
else:
|
||||
body = env.mj_data.body(body_name)
|
||||
pos = body.xpos
|
||||
quat = body.xquat
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
return {
|
||||
'pos': np.asarray(pos, dtype=np.float32).copy(),
|
||||
'quat': np.asarray(quat, dtype=np.float32).copy(),
|
||||
}
|
||||
|
||||
|
||||
def _get_executed_ee_poses(env) -> dict[str, np.ndarray]:
|
||||
candidates = {
|
||||
'left_link7': ('left_link7', 'eef_left'),
|
||||
'right_link7': ('right_link7', 'eef_right'),
|
||||
'eef_left': ('eef_left', 'left_link7'),
|
||||
'eef_right': ('eef_right', 'right_link7'),
|
||||
}
|
||||
poses = {}
|
||||
for body_key, body_names in candidates.items():
|
||||
pose = None
|
||||
for body_name in body_names:
|
||||
pose = _read_body_pose(env, body_name)
|
||||
if pose is not None:
|
||||
break
|
||||
if pose is None:
|
||||
pose = {
|
||||
'pos': np.full(3, np.nan, dtype=np.float32),
|
||||
'quat': np.full(4, np.nan, dtype=np.float32),
|
||||
}
|
||||
poses[f'{body_key}_pos'] = pose['pos']
|
||||
poses[f'{body_key}_quat'] = pose['quat']
|
||||
return poses
|
||||
|
||||
|
||||
def _empty_rollout_trajectory() -> dict[str, list]:
|
||||
return {
|
||||
'episode_index': [],
|
||||
'step': [],
|
||||
'reward': [],
|
||||
'raw_action': [],
|
||||
'applied_action': [],
|
||||
'executed_left_link7_pos': [],
|
||||
'executed_left_link7_quat': [],
|
||||
'executed_right_link7_pos': [],
|
||||
'executed_right_link7_quat': [],
|
||||
'executed_eef_left_pos': [],
|
||||
'executed_eef_left_quat': [],
|
||||
'executed_eef_right_pos': [],
|
||||
'executed_eef_right_quat': [],
|
||||
'model_inference_triggered': [],
|
||||
'obs_read_time_ms': [],
|
||||
'preprocess_time_ms': [],
|
||||
'inference_time_ms': [],
|
||||
'env_step_time_ms': [],
|
||||
'total_time_ms': [],
|
||||
}
|
||||
|
||||
|
||||
def _append_rollout_step(
|
||||
storage: dict[str, list],
|
||||
episode_index: int,
|
||||
timestep: int,
|
||||
reward: Optional[float],
|
||||
raw_action: np.ndarray,
|
||||
executed_action: np.ndarray,
|
||||
executed_poses: dict[str, np.ndarray],
|
||||
timing_ms: dict[str, float],
|
||||
model_inference_triggered: bool,
|
||||
):
|
||||
storage['episode_index'].append(int(episode_index))
|
||||
storage['step'].append(int(timestep))
|
||||
storage['reward'].append(float(reward) if reward is not None else np.nan)
|
||||
storage['raw_action'].append(raw_action.astype(np.float32, copy=True))
|
||||
storage['applied_action'].append(executed_action.astype(np.float32, copy=True))
|
||||
storage['executed_left_link7_pos'].append(executed_poses['left_link7_pos'])
|
||||
storage['executed_left_link7_quat'].append(executed_poses['left_link7_quat'])
|
||||
storage['executed_right_link7_pos'].append(executed_poses['right_link7_pos'])
|
||||
storage['executed_right_link7_quat'].append(executed_poses['right_link7_quat'])
|
||||
storage['executed_eef_left_pos'].append(executed_poses['eef_left_pos'])
|
||||
storage['executed_eef_left_quat'].append(executed_poses['eef_left_quat'])
|
||||
storage['executed_eef_right_pos'].append(executed_poses['eef_right_pos'])
|
||||
storage['executed_eef_right_quat'].append(executed_poses['eef_right_quat'])
|
||||
storage['model_inference_triggered'].append(bool(model_inference_triggered))
|
||||
for key, value in timing_ms.items():
|
||||
storage[key].append(float(value))
|
||||
|
||||
|
||||
def _save_rollout_trajectory_npz(output_path: str, storage: dict[str, list]):
|
||||
step = np.asarray(storage['step'], dtype=np.int32)
|
||||
raw_action = np.asarray(storage['raw_action'], dtype=np.float32)
|
||||
applied_action = np.asarray(storage['applied_action'], dtype=np.float32)
|
||||
executed_left_link7_pos = np.asarray(storage['executed_left_link7_pos'], dtype=np.float32)
|
||||
executed_left_link7_quat = np.asarray(storage['executed_left_link7_quat'], dtype=np.float32)
|
||||
executed_right_link7_pos = np.asarray(storage['executed_right_link7_pos'], dtype=np.float32)
|
||||
executed_right_link7_quat = np.asarray(storage['executed_right_link7_quat'], dtype=np.float32)
|
||||
executed_eef_left_pos = np.asarray(storage['executed_eef_left_pos'], dtype=np.float32)
|
||||
executed_eef_left_quat = np.asarray(storage['executed_eef_left_quat'], dtype=np.float32)
|
||||
executed_eef_right_pos = np.asarray(storage['executed_eef_right_pos'], dtype=np.float32)
|
||||
executed_eef_right_quat = np.asarray(storage['executed_eef_right_quat'], dtype=np.float32)
|
||||
np.savez_compressed(
|
||||
output_path,
|
||||
episode_index=np.asarray(storage['episode_index'], dtype=np.int32),
|
||||
step=step,
|
||||
timestep=step,
|
||||
reward=np.asarray(storage['reward'], dtype=np.float32),
|
||||
raw_action=raw_action,
|
||||
raw_predicted_ee_action=raw_action,
|
||||
applied_action=applied_action,
|
||||
executed_ee_action=applied_action,
|
||||
executed_left_link7_pos=executed_left_link7_pos,
|
||||
executed_left_link7_quat=executed_left_link7_quat,
|
||||
executed_right_link7_pos=executed_right_link7_pos,
|
||||
executed_right_link7_quat=executed_right_link7_quat,
|
||||
executed_eef_left_pos=executed_eef_left_pos,
|
||||
executed_eef_left_quat=executed_eef_left_quat,
|
||||
executed_eef_right_pos=executed_eef_right_pos,
|
||||
executed_eef_right_quat=executed_eef_right_quat,
|
||||
left_ee_pos=executed_eef_left_pos,
|
||||
left_ee_quat=executed_eef_left_quat,
|
||||
right_ee_pos=executed_eef_right_pos,
|
||||
right_ee_quat=executed_eef_right_quat,
|
||||
model_inference_triggered=np.asarray(storage['model_inference_triggered'], dtype=bool),
|
||||
obs_read_time_ms=np.asarray(storage['obs_read_time_ms'], dtype=np.float32),
|
||||
preprocess_time_ms=np.asarray(storage['preprocess_time_ms'], dtype=np.float32),
|
||||
inference_time_ms=np.asarray(storage['inference_time_ms'], dtype=np.float32),
|
||||
env_step_time_ms=np.asarray(storage['env_step_time_ms'], dtype=np.float32),
|
||||
total_time_ms=np.asarray(storage['total_time_ms'], dtype=np.float32),
|
||||
)
|
||||
|
||||
|
||||
def _save_summary_json(output_path: str, summary: dict[str, Any]):
|
||||
with open(output_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(_json_friendly(summary), f, ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
class ActionSmoother:
|
||||
"""
|
||||
动作平滑器(指数移动平均)
|
||||
@@ -157,8 +469,23 @@ class ActionSmoother:
|
||||
self.prev_action = None
|
||||
|
||||
|
||||
@hydra.main(version_base=None, config_path="../../vla/conf", config_name="config")
|
||||
def main(cfg: DictConfig):
|
||||
def _close_env(env):
|
||||
if env is None:
|
||||
return
|
||||
|
||||
if hasattr(env, 'exit_flag'):
|
||||
env.exit_flag = True
|
||||
|
||||
cam_thread = getattr(env, 'cam_thread', None)
|
||||
if cam_thread is not None and hasattr(cam_thread, 'join'):
|
||||
cam_thread.join(timeout=1.0)
|
||||
|
||||
viewer = getattr(env, 'viewer', None)
|
||||
if viewer is not None and hasattr(viewer, 'close'):
|
||||
viewer.close()
|
||||
|
||||
|
||||
def _run_eval(cfg: DictConfig):
|
||||
"""
|
||||
使用 agent 内置队列管理的简化版 VLA 评估
|
||||
|
||||
@@ -176,6 +503,18 @@ def main(cfg: DictConfig):
|
||||
eval_cfg = cfg.eval
|
||||
device = eval_cfg.device
|
||||
camera_names = list(eval_cfg.camera_names)
|
||||
artifact_paths = _resolve_artifact_paths(eval_cfg)
|
||||
video_recorder = _RolloutVideoRecorder(
|
||||
output_path=artifact_paths['video_mp4'],
|
||||
fps=int(eval_cfg.get('video_fps', 30)),
|
||||
)
|
||||
rollout_trajectory = _empty_rollout_trajectory()
|
||||
global_obs_read_times_ms = []
|
||||
global_preprocess_times_ms = []
|
||||
global_inference_times_ms = []
|
||||
global_env_step_times_ms = []
|
||||
global_total_times_ms = []
|
||||
global_model_forward_flags = []
|
||||
|
||||
# =========================================================================
|
||||
# 加载模型
|
||||
@@ -196,116 +535,261 @@ def main(cfg: DictConfig):
|
||||
# =========================================================================
|
||||
# 创建环境
|
||||
# =========================================================================
|
||||
env = make_sim_env(eval_cfg.task_name)
|
||||
env = make_sim_env(eval_cfg.task_name, headless=eval_cfg.headless)
|
||||
|
||||
# =========================================================================
|
||||
# 运行评估回合
|
||||
# =========================================================================
|
||||
all_stats = []
|
||||
episode_rewards = []
|
||||
episode_max_rewards = []
|
||||
try:
|
||||
for episode_idx in range(eval_cfg.num_episodes):
|
||||
print(f"\n{'='*60}")
|
||||
print(f"回合 {episode_idx + 1}/{eval_cfg.num_episodes}")
|
||||
print(f"{'='*60}\n")
|
||||
|
||||
for episode_idx in range(eval_cfg.num_episodes):
|
||||
box_pos = sample_transfer_pose()
|
||||
env.reset(box_pos)
|
||||
|
||||
# 为新回合重置 agent 队列
|
||||
agent.reset()
|
||||
if smoother:
|
||||
smoother.reset()
|
||||
|
||||
# 计时统计
|
||||
obs_read_times_ms = []
|
||||
preprocess_times_ms = []
|
||||
inference_times_ms = []
|
||||
env_step_times_ms = []
|
||||
total_times_ms = []
|
||||
model_forward_flags = []
|
||||
episode_reward = 0.0
|
||||
episode_max_reward = float('-inf')
|
||||
|
||||
with torch.inference_mode():
|
||||
for t in tqdm(range(eval_cfg.max_timesteps), desc=f"回合 {episode_idx + 1}"):
|
||||
start_total = time.perf_counter()
|
||||
|
||||
# 从环境获取观测
|
||||
obs = env._get_image_obs()
|
||||
qpos_obs = env._get_qpos_obs()
|
||||
obs['qpos'] = qpos_obs['qpos']
|
||||
end_obs_read = time.perf_counter()
|
||||
|
||||
video_frame = _get_video_frame(obs, artifact_paths['video_camera_name'])
|
||||
video_recorder.write(video_frame)
|
||||
|
||||
# 准备给 agent 的观测
|
||||
observation = prepare_observation(obs, camera_names)
|
||||
end_preprocess = time.perf_counter()
|
||||
|
||||
# 选择动作(agent 内部处理队列管理)
|
||||
action_queue = getattr(agent, '_queues', {}).get('action', None)
|
||||
model_inference_triggered = len(action_queue) == 0 if action_queue is not None else True
|
||||
start_inference = time.perf_counter()
|
||||
action = agent.select_action(observation)
|
||||
|
||||
if str(device).startswith('cuda') and torch.cuda.is_available():
|
||||
torch.cuda.synchronize()
|
||||
end_inference = time.perf_counter()
|
||||
|
||||
# 转换为 numpy
|
||||
raw_action = _to_numpy_action(action)
|
||||
|
||||
# 调试:打印当前时间步的动作(由配置控制)
|
||||
if eval_cfg.get('verbose_action', False):
|
||||
print(f"\n[Step {t:3d}] 预测动作: {raw_action}")
|
||||
print(f" - 动作形状: {raw_action.shape}")
|
||||
print(f" - 动作范围: [{raw_action.min():.4f}, {raw_action.max():.4f}]")
|
||||
print(f" - 动作均值: {raw_action.mean():.4f}, 标准差: {raw_action.std():.4f}")
|
||||
|
||||
# 可选:平滑动作
|
||||
executed_action = raw_action.copy()
|
||||
if smoother:
|
||||
executed_action = smoother.smooth(executed_action)
|
||||
|
||||
# 执行动作
|
||||
start_env_step = time.perf_counter()
|
||||
execute_policy_action(env, executed_action)
|
||||
end_env_step = time.perf_counter()
|
||||
executed_poses = _get_executed_ee_poses(env)
|
||||
reward = getattr(env, 'rew', None)
|
||||
if reward is not None:
|
||||
reward = float(reward)
|
||||
episode_reward += reward
|
||||
episode_max_reward = max(episode_max_reward, reward)
|
||||
if not eval_cfg.headless:
|
||||
env.render()
|
||||
|
||||
end_total = time.perf_counter()
|
||||
|
||||
step_timing_ms = {
|
||||
'obs_read_time_ms': (end_obs_read - start_total) * 1000.0,
|
||||
'preprocess_time_ms': (end_preprocess - end_obs_read) * 1000.0,
|
||||
'inference_time_ms': (end_inference - start_inference) * 1000.0,
|
||||
'env_step_time_ms': (end_env_step - start_env_step) * 1000.0,
|
||||
'total_time_ms': (end_total - start_total) * 1000.0,
|
||||
}
|
||||
|
||||
# 记录计时
|
||||
obs_read_times_ms.append(step_timing_ms['obs_read_time_ms'])
|
||||
preprocess_times_ms.append(step_timing_ms['preprocess_time_ms'])
|
||||
inference_times_ms.append(step_timing_ms['inference_time_ms'])
|
||||
env_step_times_ms.append(step_timing_ms['env_step_time_ms'])
|
||||
total_times_ms.append(step_timing_ms['total_time_ms'])
|
||||
model_forward_flags.append(bool(model_inference_triggered))
|
||||
global_obs_read_times_ms.append(step_timing_ms['obs_read_time_ms'])
|
||||
global_preprocess_times_ms.append(step_timing_ms['preprocess_time_ms'])
|
||||
global_inference_times_ms.append(step_timing_ms['inference_time_ms'])
|
||||
global_env_step_times_ms.append(step_timing_ms['env_step_time_ms'])
|
||||
global_total_times_ms.append(step_timing_ms['total_time_ms'])
|
||||
global_model_forward_flags.append(bool(model_inference_triggered))
|
||||
|
||||
if artifact_paths['trajectory_npz'] is not None:
|
||||
_append_rollout_step(
|
||||
rollout_trajectory,
|
||||
episode_index=episode_idx,
|
||||
timestep=t,
|
||||
reward=reward,
|
||||
raw_action=raw_action,
|
||||
executed_action=executed_action,
|
||||
executed_poses=executed_poses,
|
||||
timing_ms=step_timing_ms,
|
||||
model_inference_triggered=model_inference_triggered,
|
||||
)
|
||||
|
||||
# =========================================================================
|
||||
# 打印回合统计
|
||||
# =========================================================================
|
||||
avg_obs_read_time_ms = _mean_or_zero(obs_read_times_ms)
|
||||
avg_preprocess_time_ms = _mean_or_zero(preprocess_times_ms)
|
||||
avg_inference_time_ms = _mean_or_zero(inference_times_ms)
|
||||
avg_env_step_time_ms = _mean_or_zero(env_step_times_ms)
|
||||
avg_total_time_ms = _mean_or_zero(total_times_ms)
|
||||
timing_breakdown = _summarize_timing_breakdown(
|
||||
{
|
||||
'obs_read': obs_read_times_ms,
|
||||
'preprocess': preprocess_times_ms,
|
||||
'inference': inference_times_ms,
|
||||
'env_step': env_step_times_ms,
|
||||
'loop_total': total_times_ms,
|
||||
},
|
||||
model_forward_flags,
|
||||
)
|
||||
episode_artifact_paths = {
|
||||
'video': artifact_paths['video_mp4'],
|
||||
'trajectory': artifact_paths['trajectory_npz'],
|
||||
'timing': artifact_paths['timing_json'] or artifact_paths['summary_json'],
|
||||
}
|
||||
|
||||
stats = {
|
||||
'inference_fps': 1000.0 / avg_inference_time_ms if avg_inference_time_ms > 0 else 0.0,
|
||||
'control_fps': 1000.0 / avg_total_time_ms if avg_total_time_ms > 0 else 0.0,
|
||||
'avg_obs_read_time_ms': avg_obs_read_time_ms,
|
||||
'avg_preprocess_time_ms': avg_preprocess_time_ms,
|
||||
'avg_inference_time_ms': avg_inference_time_ms,
|
||||
'avg_env_step_time_ms': avg_env_step_time_ms,
|
||||
'avg_total_time_ms': avg_total_time_ms,
|
||||
'num_inferences': int(sum(model_forward_flags)),
|
||||
'num_model_forwards': int(sum(model_forward_flags)),
|
||||
'num_steps': len(total_times_ms),
|
||||
'episode_reward': float(episode_reward),
|
||||
'episode_max_reward': (
|
||||
float(episode_max_reward) if episode_max_reward != float('-inf') else None
|
||||
),
|
||||
'artifact_paths': episode_artifact_paths,
|
||||
'timing_breakdown_ms': timing_breakdown['all_steps_ms'],
|
||||
'timing_summary': timing_breakdown,
|
||||
}
|
||||
all_stats.append(stats)
|
||||
episode_rewards.append(float(episode_reward))
|
||||
if episode_max_reward != float('-inf'):
|
||||
episode_max_rewards.append(float(episode_max_reward))
|
||||
|
||||
print(f"\n回合 {episode_idx + 1} 完成 ({eval_cfg.max_timesteps} 时间步)")
|
||||
print(f" 模型推理 FPS: {stats['inference_fps']:.2f} Hz")
|
||||
print(f" 控制循环 FPS: {stats['control_fps']:.2f} Hz")
|
||||
print(f" 平均读观测时间: {stats['avg_obs_read_time_ms']:.2f} ms")
|
||||
print(f" 平均预处理时间: {stats['avg_preprocess_time_ms']:.2f} ms")
|
||||
print(f" 平均推理时间: {stats['avg_inference_time_ms']:.2f} ms")
|
||||
print(f" 平均环境步进时间: {stats['avg_env_step_time_ms']:.2f} ms")
|
||||
print(f" 平均总时间: {stats['avg_total_time_ms']:.2f} ms")
|
||||
print(f" 总推理次数: {stats['num_inferences']}")
|
||||
print(f" 回合累计奖励: {stats['episode_reward']:.2f}")
|
||||
|
||||
# =========================================================================
|
||||
# 总体统计
|
||||
# =========================================================================
|
||||
print(f"\n{'='*60}")
|
||||
print(f"回合 {episode_idx + 1}/{eval_cfg.num_episodes}")
|
||||
print(f"{'='*60}\n")
|
||||
print("评估完成!")
|
||||
print(f"{'='*60}")
|
||||
|
||||
box_pos = sample_transfer_pose()
|
||||
env.reset(box_pos)
|
||||
|
||||
# 为新回合重置 agent 队列
|
||||
agent.reset()
|
||||
if smoother:
|
||||
smoother.reset()
|
||||
|
||||
# 计时统计
|
||||
inference_times = []
|
||||
total_times = []
|
||||
|
||||
with torch.inference_mode():
|
||||
for t in tqdm(range(eval_cfg.max_timesteps), desc=f"回合 {episode_idx + 1}"):
|
||||
start_total = time.time()
|
||||
|
||||
# 从环境获取观测
|
||||
obs = env._get_image_obs()
|
||||
qpos_obs = env._get_qpos_obs()
|
||||
obs['qpos'] = qpos_obs['qpos']
|
||||
|
||||
# 准备给 agent 的观测
|
||||
observation = prepare_observation(obs, camera_names)
|
||||
|
||||
# 选择动作(agent 内部处理队列管理)
|
||||
start_inference = time.time()
|
||||
action = agent.select_action(observation)
|
||||
|
||||
if device == 'cuda':
|
||||
torch.cuda.synchronize()
|
||||
end_inference = time.time()
|
||||
|
||||
# 转换为 numpy
|
||||
action = action.cpu().numpy()
|
||||
|
||||
# 调试:打印当前时间步的动作(由配置控制)
|
||||
if eval_cfg.get('verbose_action', False):
|
||||
print(f"\n[Step {t:3d}] 预测动作: {action}")
|
||||
print(f" - 动作形状: {action.shape}")
|
||||
print(f" - 动作范围: [{action.min():.4f}, {action.max():.4f}]")
|
||||
print(f" - 动作均值: {action.mean():.4f}, 标准差: {action.std():.4f}")
|
||||
|
||||
# 可选:平滑动作
|
||||
if smoother:
|
||||
action = smoother.smooth(action)
|
||||
|
||||
# 执行动作
|
||||
env.step_jnt(action)
|
||||
env.render()
|
||||
|
||||
end_total = time.time()
|
||||
|
||||
# 记录计时
|
||||
inference_times.append(end_inference - start_inference)
|
||||
total_times.append(end_total - start_total)
|
||||
|
||||
# =========================================================================
|
||||
# 打印回合统计
|
||||
# =========================================================================
|
||||
avg_inference_time = np.mean(inference_times)
|
||||
avg_total_time = np.mean(total_times)
|
||||
|
||||
stats = {
|
||||
'inference_fps': 1.0 / avg_inference_time if avg_inference_time > 0 else 0.0,
|
||||
'control_fps': 1.0 / avg_total_time if avg_total_time > 0 else 0.0,
|
||||
'avg_inference_time_ms': avg_inference_time * 1000,
|
||||
'avg_total_time_ms': avg_total_time * 1000,
|
||||
'num_inferences': len([t for t in inference_times if t > 0.001]), # 统计实际推理次数
|
||||
'num_steps': len(total_times)
|
||||
summary = {
|
||||
'num_episodes': int(eval_cfg.num_episodes),
|
||||
'episode_rewards': episode_rewards,
|
||||
'episode_max_rewards': episode_max_rewards,
|
||||
'avg_reward': float(np.mean(episode_rewards)) if episode_rewards else 0.0,
|
||||
'avg_max_reward': float(np.mean(episode_max_rewards)) if episode_max_rewards else 0.0,
|
||||
'episodes': all_stats,
|
||||
'artifact_dir': artifact_paths['output_dir'],
|
||||
'artifacts': artifact_paths,
|
||||
}
|
||||
all_stats.append(stats)
|
||||
|
||||
print(f"\n回合 {episode_idx + 1} 完成 ({eval_cfg.max_timesteps} 时间步)")
|
||||
print(f" 模型推理 FPS: {stats['inference_fps']:.2f} Hz")
|
||||
print(f" 控制循环 FPS: {stats['control_fps']:.2f} Hz")
|
||||
print(f" 平均推理时间: {stats['avg_inference_time_ms']:.2f} ms")
|
||||
print(f" 平均总时间: {stats['avg_total_time_ms']:.2f} ms")
|
||||
print(f" 总推理次数: {stats['num_inferences']}")
|
||||
if all_stats:
|
||||
avg_inference_fps = np.mean([s['inference_fps'] for s in all_stats])
|
||||
avg_control_fps = np.mean([s['control_fps'] for s in all_stats])
|
||||
avg_obs_read_time = _mean_or_zero(global_obs_read_times_ms)
|
||||
avg_preprocess_time = _mean_or_zero(global_preprocess_times_ms)
|
||||
avg_inference_time = _mean_or_zero(global_inference_times_ms)
|
||||
avg_env_step_time = _mean_or_zero(global_env_step_times_ms)
|
||||
avg_total_time = _mean_or_zero(global_total_times_ms)
|
||||
summary.update({
|
||||
'avg_inference_fps': float(avg_inference_fps),
|
||||
'avg_control_fps': float(avg_control_fps),
|
||||
'avg_obs_read_time_ms': float(avg_obs_read_time),
|
||||
'avg_preprocess_time_ms': float(avg_preprocess_time),
|
||||
'avg_inference_time_ms': float(avg_inference_time),
|
||||
'avg_env_step_time_ms': float(avg_env_step_time),
|
||||
'avg_total_time_ms': float(avg_total_time),
|
||||
'timing_summary': _summarize_timing_breakdown(
|
||||
{
|
||||
'obs_read': global_obs_read_times_ms,
|
||||
'preprocess': global_preprocess_times_ms,
|
||||
'inference': global_inference_times_ms,
|
||||
'env_step': global_env_step_times_ms,
|
||||
'loop_total': global_total_times_ms,
|
||||
},
|
||||
global_model_forward_flags,
|
||||
),
|
||||
})
|
||||
|
||||
# =========================================================================
|
||||
# 总体统计
|
||||
# =========================================================================
|
||||
print(f"\n{'='*60}")
|
||||
print("评估完成!")
|
||||
print(f"{'='*60}")
|
||||
print(f"\n总体统计 ({eval_cfg.num_episodes} 个回合):")
|
||||
print(f" 平均模型推理 FPS: {avg_inference_fps:.2f} Hz")
|
||||
print(f" 平均控制循环 FPS: {avg_control_fps:.2f} Hz")
|
||||
print(f" 平均读观测时间: {avg_obs_read_time:.2f} ms")
|
||||
print(f" 平均预处理时间: {avg_preprocess_time:.2f} ms")
|
||||
print(f" 平均推理时间: {avg_inference_time:.2f} ms")
|
||||
print(f" 平均环境步进时间: {avg_env_step_time:.2f} ms")
|
||||
print(f" 平均总时间: {avg_total_time:.2f} ms")
|
||||
print(f" 平均累计奖励: {summary['avg_reward']:.2f}")
|
||||
|
||||
if all_stats:
|
||||
avg_inference_fps = np.mean([s['inference_fps'] for s in all_stats])
|
||||
avg_control_fps = np.mean([s['control_fps'] for s in all_stats])
|
||||
avg_inference_time = np.mean([s['avg_inference_time_ms'] for s in all_stats])
|
||||
avg_total_time = np.mean([s['avg_total_time_ms'] for s in all_stats])
|
||||
if artifact_paths['trajectory_npz'] is not None:
|
||||
_save_rollout_trajectory_npz(artifact_paths['trajectory_npz'], rollout_trajectory)
|
||||
if artifact_paths['summary_json'] is not None:
|
||||
_save_summary_json(artifact_paths['summary_json'], summary)
|
||||
if artifact_paths['timing_json'] is not None:
|
||||
_save_summary_json(artifact_paths['timing_json'], summary.get('timing_summary', {}))
|
||||
print()
|
||||
return _json_friendly(summary)
|
||||
finally:
|
||||
video_recorder.close()
|
||||
_close_env(env)
|
||||
|
||||
print(f"\n总体统计 ({eval_cfg.num_episodes} 个回合):")
|
||||
print(f" 平均模型推理 FPS: {avg_inference_fps:.2f} Hz")
|
||||
print(f" 平均控制循环 FPS: {avg_control_fps:.2f} Hz")
|
||||
print(f" 平均推理时间: {avg_inference_time:.2f} ms")
|
||||
print(f" 平均总时间: {avg_total_time:.2f} ms")
|
||||
print()
|
||||
|
||||
@hydra.main(version_base=None, config_path="../../vla/conf", config_name="config")
|
||||
def main(cfg: DictConfig):
|
||||
return _run_eval(cfg)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -213,7 +213,9 @@ class DualDianaMed(MujocoEnv):
|
||||
|
||||
def camera_viewer(self):
|
||||
img_renderer = mj.Renderer(self.mj_model,height=480,width=640)
|
||||
cv2.namedWindow('Cam view',cv2.WINDOW_NORMAL)
|
||||
show_gui = self.is_render
|
||||
if show_gui:
|
||||
cv2.namedWindow('Cam view',cv2.WINDOW_NORMAL)
|
||||
while not self.exit_flag:
|
||||
img_renderer.update_scene(self.mj_data,camera="rs_cam_right")
|
||||
self.r_vis = img_renderer.render()
|
||||
@@ -230,9 +232,10 @@ class DualDianaMed(MujocoEnv):
|
||||
img_renderer.update_scene(self.mj_data,camera="front")
|
||||
self.front = img_renderer.render()
|
||||
self.front = self.front[:, :, ::-1]
|
||||
if self.cam_view is not None:
|
||||
cv2.imshow('Cam view', self.cam_view)
|
||||
cv2.waitKey(1)
|
||||
if show_gui:
|
||||
if self.cam_view is not None:
|
||||
cv2.imshow('Cam view', self.cam_view)
|
||||
cv2.waitKey(1)
|
||||
|
||||
|
||||
def cam_start(self):
|
||||
@@ -300,4 +303,4 @@ if __name__ == "__main__":
|
||||
# print("quat_right =",quat_right,"\n")
|
||||
if env.is_render:
|
||||
env.render()
|
||||
|
||||
|
||||
|
||||
@@ -133,12 +133,12 @@ class DualDianaMed_Pos_Ctrl(DualDianaMed):
|
||||
return reward
|
||||
|
||||
|
||||
def make_sim_env(task_name):
|
||||
def make_sim_env(task_name, headless=False):
|
||||
if 'sim_transfer' in task_name:
|
||||
from roboimi.assets.robots.diana_med import BiDianaMed
|
||||
env = DualDianaMed_Pos_Ctrl(
|
||||
robot=BiDianaMed(),
|
||||
is_render=True,
|
||||
is_render=not headless,
|
||||
control_freq=30,
|
||||
is_interpolate=True,
|
||||
cam_view='angle'
|
||||
@@ -167,4 +167,4 @@ if __name__ == "__main__":
|
||||
env.step(action)
|
||||
if env.is_render:
|
||||
env.render()
|
||||
|
||||
|
||||
|
||||
176
roboimi/utils/raw_action_trajectory_viewer.py
Normal file
176
roboimi/utils/raw_action_trajectory_viewer.py
Normal file
@@ -0,0 +1,176 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Iterable
|
||||
|
||||
import cv2
|
||||
import mujoco
|
||||
import numpy as np
|
||||
|
||||
from roboimi.assets.robots.diana_med import BiDianaMed
|
||||
from roboimi.envs.mujoco_base import MujocoEnv
|
||||
from roboimi.envs.double_pos_ctrl_env import make_sim_env
|
||||
from roboimi.utils.act_ex_utils import sample_transfer_pose
|
||||
|
||||
|
||||
def _load_raw_action_array(path: str | Path) -> np.ndarray:
|
||||
path = Path(path)
|
||||
if path.suffix == ".npy":
|
||||
raw_action = np.load(path)
|
||||
elif path.suffix == ".npz":
|
||||
archive = np.load(path)
|
||||
if "raw_action" in archive:
|
||||
raw_action = archive["raw_action"]
|
||||
elif "raw_predicted_ee_action" in archive:
|
||||
raw_action = archive["raw_predicted_ee_action"]
|
||||
else:
|
||||
raise KeyError(f"{path} does not contain raw_action")
|
||||
else:
|
||||
raise ValueError(f"unsupported trajectory file: {path}")
|
||||
raw_action = np.asarray(raw_action, dtype=np.float32)
|
||||
if raw_action.ndim != 2 or raw_action.shape[1] < 10:
|
||||
raise ValueError(f"raw_action must have shape (T, 16)-like, got {raw_action.shape}")
|
||||
return raw_action
|
||||
|
||||
|
||||
def disable_cv2_highgui(cv2_module=cv2):
|
||||
original = {
|
||||
"namedWindow": cv2_module.namedWindow,
|
||||
"imshow": cv2_module.imshow,
|
||||
"waitKey": cv2_module.waitKey,
|
||||
}
|
||||
|
||||
cv2_module.namedWindow = lambda *args, **kwargs: None
|
||||
cv2_module.imshow = lambda *args, **kwargs: None
|
||||
cv2_module.waitKey = lambda *args, **kwargs: 1
|
||||
|
||||
def restore():
|
||||
cv2_module.namedWindow = original["namedWindow"]
|
||||
cv2_module.imshow = original["imshow"]
|
||||
cv2_module.waitKey = original["waitKey"]
|
||||
|
||||
return restore
|
||||
|
||||
|
||||
def set_transfer_box_pose(mj_data, box_pos: np.ndarray) -> None:
|
||||
box_pos = np.asarray(box_pos, dtype=np.float64)
|
||||
if box_pos.shape != (3,):
|
||||
raise ValueError(f"box_pos must have shape (3,), got {box_pos.shape}")
|
||||
joint = mj_data.joint("red_box_joint")
|
||||
joint.qpos[0] = box_pos[0]
|
||||
joint.qpos[1] = box_pos[1]
|
||||
joint.qpos[2] = box_pos[2]
|
||||
joint.qpos[3] = 1.0
|
||||
joint.qpos[4] = 0.0
|
||||
joint.qpos[5] = 0.0
|
||||
joint.qpos[6] = 0.0
|
||||
|
||||
|
||||
def load_raw_action_positions(path: str | Path) -> dict[str, np.ndarray]:
|
||||
raw_action = _load_raw_action_array(path)
|
||||
return {
|
||||
"left": raw_action[:, :3].astype(np.float32, copy=True),
|
||||
"right": raw_action[:, 7:10].astype(np.float32, copy=True),
|
||||
}
|
||||
|
||||
|
||||
def _downsample_points(points: np.ndarray, stride: int) -> np.ndarray:
|
||||
sampled = points[::stride]
|
||||
if len(sampled) == 0:
|
||||
return points
|
||||
if not np.array_equal(sampled[-1], points[-1]):
|
||||
sampled = np.concatenate([sampled, points[-1:]], axis=0)
|
||||
return sampled
|
||||
|
||||
|
||||
def build_trajectory_capsule_markers(
|
||||
positions: dict[str, np.ndarray],
|
||||
*,
|
||||
max_markers: int,
|
||||
radius: float = 0.003,
|
||||
rgba: tuple[float, float, float, float] = (1.0, 0.0, 0.0, 1.0),
|
||||
) -> list[dict]:
|
||||
total_segments = sum(max(len(points) - 1, 0) for points in positions.values())
|
||||
if total_segments == 0:
|
||||
return []
|
||||
stride = max(1, math.ceil(total_segments / max_markers))
|
||||
markers = []
|
||||
for points in positions.values():
|
||||
sampled = _downsample_points(np.asarray(points, dtype=np.float64), stride)
|
||||
for idx in range(len(sampled) - 1):
|
||||
markers.append(
|
||||
{
|
||||
"from": sampled[idx],
|
||||
"to": sampled[idx + 1],
|
||||
"rgba": rgba,
|
||||
"radius": float(radius),
|
||||
}
|
||||
)
|
||||
return markers[:max_markers]
|
||||
|
||||
|
||||
def apply_capsule_markers_to_scene(user_scn, markers: Iterable[dict]) -> None:
|
||||
user_scn.ngeom = 0
|
||||
for marker in markers:
|
||||
if user_scn.ngeom >= user_scn.maxgeom:
|
||||
break
|
||||
geom = user_scn.geoms[user_scn.ngeom]
|
||||
mujoco.mjv_initGeom(
|
||||
geom,
|
||||
mujoco.mjtGeom.mjGEOM_CAPSULE,
|
||||
np.zeros(3, dtype=np.float64),
|
||||
np.zeros(3, dtype=np.float64),
|
||||
np.eye(3, dtype=np.float64).reshape(-1),
|
||||
np.asarray(marker["rgba"], dtype=np.float32),
|
||||
)
|
||||
mujoco.mjv_connector(
|
||||
geom,
|
||||
mujoco.mjtGeom.mjGEOM_CAPSULE,
|
||||
float(marker["radius"]),
|
||||
np.asarray(marker["from"], dtype=np.float64),
|
||||
np.asarray(marker["to"], dtype=np.float64),
|
||||
)
|
||||
user_scn.ngeom += 1
|
||||
|
||||
|
||||
def launch_raw_action_trajectory_viewer(
|
||||
trajectory_path: str | Path,
|
||||
*,
|
||||
task_name: str = "sim_transfer",
|
||||
line_radius: float = 0.004,
|
||||
max_markers: int = 1500,
|
||||
box_pos: np.ndarray | None = None,
|
||||
disable_camera_window: bool = True,
|
||||
):
|
||||
positions = load_raw_action_positions(trajectory_path)
|
||||
if task_name != "sim_transfer":
|
||||
raise NotImplementedError(f"unsupported task_name: {task_name}")
|
||||
if box_pos is None:
|
||||
box_pos = sample_transfer_pose()
|
||||
|
||||
robot = BiDianaMed()
|
||||
viewer_env = MujocoEnv(robot=robot, is_render=True, renderer="viewer", control_freq=30)
|
||||
viewer_env.reset()
|
||||
set_transfer_box_pose(viewer_env.mj_data, box_pos)
|
||||
mujoco.mj_forward(viewer_env.mj_model, viewer_env.mj_data)
|
||||
markers = build_trajectory_capsule_markers(
|
||||
positions,
|
||||
max_markers=max_markers,
|
||||
radius=line_radius,
|
||||
)
|
||||
|
||||
if viewer_env.viewer is None or getattr(viewer_env.viewer, "user_scn", None) is None:
|
||||
raise RuntimeError("viewer does not expose user_scn; cannot render trajectory overlay")
|
||||
|
||||
try:
|
||||
while viewer_env.viewer.is_running() and not viewer_env.exit_flag:
|
||||
with viewer_env.viewer.lock():
|
||||
apply_capsule_markers_to_scene(viewer_env.viewer.user_scn, markers)
|
||||
viewer_env.render()
|
||||
time.sleep(1 / 60.0)
|
||||
finally:
|
||||
viewer_env.exit_flag = True
|
||||
if getattr(viewer_env, "viewer", None) is not None:
|
||||
viewer_env.viewer.close()
|
||||
113
roboimi/utils/streaming_episode_writer.py
Normal file
113
roboimi/utils/streaming_episode_writer.py
Normal file
@@ -0,0 +1,113 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import h5py
|
||||
import numpy as np
|
||||
|
||||
|
||||
class StreamingEpisodeWriter:
|
||||
"""逐帧写入 episode 数据,成功后提交,失败时丢弃临时文件。"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dataset_path: str | os.PathLike[str],
|
||||
max_timesteps: int,
|
||||
camera_names: list[str],
|
||||
image_size: tuple[int, int] = (256, 256),
|
||||
) -> None:
|
||||
self.dataset_path = Path(dataset_path)
|
||||
self.tmp_path = Path(f"{self.dataset_path}.tmp")
|
||||
self.max_timesteps = int(max_timesteps)
|
||||
self.camera_names = list(camera_names)
|
||||
self.image_height = int(image_size[0])
|
||||
self.image_width = int(image_size[1])
|
||||
self.frame_index = 0
|
||||
self._committed = False
|
||||
self._closed = False
|
||||
|
||||
self.dataset_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
if self.tmp_path.exists():
|
||||
self.tmp_path.unlink()
|
||||
|
||||
self._file = h5py.File(self.tmp_path, "w", rdcc_nbytes=1024**2 * 2)
|
||||
self._file.attrs["sim"] = True
|
||||
self._file.attrs["action_repr"] = "ee_pose_xyz_quat_gripper"
|
||||
self._file.attrs["image_height"] = self.image_height
|
||||
self._file.attrs["image_width"] = self.image_width
|
||||
self._file.attrs["camera_names"] = np.asarray(self.camera_names, dtype="S")
|
||||
|
||||
observations = self._file.create_group("observations")
|
||||
images = observations.create_group("images")
|
||||
for cam_name in self.camera_names:
|
||||
images.create_dataset(
|
||||
cam_name,
|
||||
(self.max_timesteps, self.image_height, self.image_width, 3),
|
||||
dtype="uint8",
|
||||
chunks=(1, self.image_height, self.image_width, 3),
|
||||
)
|
||||
observations.create_dataset(
|
||||
"qpos",
|
||||
(self.max_timesteps, 16),
|
||||
dtype="float32",
|
||||
chunks=(min(128, self.max_timesteps), 16),
|
||||
)
|
||||
self._file.create_dataset(
|
||||
"action",
|
||||
(self.max_timesteps, 16),
|
||||
dtype="float32",
|
||||
chunks=(min(128, self.max_timesteps), 16),
|
||||
)
|
||||
|
||||
def append(self, qpos: np.ndarray, action: np.ndarray, images: dict[str, np.ndarray]) -> None:
|
||||
if self._closed:
|
||||
raise RuntimeError("writer is already closed")
|
||||
if self.frame_index >= self.max_timesteps:
|
||||
raise IndexError("frame index exceeds max_timesteps")
|
||||
|
||||
qpos = np.asarray(qpos, dtype=np.float32)
|
||||
action = np.asarray(action, dtype=np.float32)
|
||||
if qpos.shape != (16,):
|
||||
raise ValueError(f"qpos shape must be (16,), got {qpos.shape}")
|
||||
if action.shape != (16,):
|
||||
raise ValueError(f"action shape must be (16,), got {action.shape}")
|
||||
|
||||
self._file["observations/qpos"][self.frame_index] = qpos
|
||||
self._file["action"][self.frame_index] = action
|
||||
|
||||
for cam_name in self.camera_names:
|
||||
if cam_name not in images:
|
||||
raise KeyError(f"missing image for camera '{cam_name}'")
|
||||
self._file[f"observations/images/{cam_name}"][self.frame_index] = self._resize_image(images[cam_name])
|
||||
|
||||
self.frame_index += 1
|
||||
|
||||
def commit(self) -> None:
|
||||
if self._closed:
|
||||
return
|
||||
self._file.flush()
|
||||
self._file.close()
|
||||
self._closed = True
|
||||
os.replace(self.tmp_path, self.dataset_path)
|
||||
self._committed = True
|
||||
|
||||
def discard(self) -> None:
|
||||
if not self._closed:
|
||||
self._file.close()
|
||||
self._closed = True
|
||||
if self.tmp_path.exists():
|
||||
self.tmp_path.unlink()
|
||||
|
||||
def _resize_image(self, image: np.ndarray) -> np.ndarray:
|
||||
image = np.asarray(image, dtype=np.uint8)
|
||||
if image.ndim != 3 or image.shape[2] != 3:
|
||||
raise ValueError(f"image shape must be HxWx3, got {image.shape}")
|
||||
if image.shape[:2] == (self.image_height, self.image_width):
|
||||
return image
|
||||
|
||||
interpolation = cv2.INTER_AREA
|
||||
if image.shape[0] < self.image_height or image.shape[1] < self.image_width:
|
||||
interpolation = cv2.INTER_LINEAR
|
||||
return cv2.resize(image, (self.image_width, self.image_height), interpolation=interpolation)
|
||||
@@ -3,10 +3,8 @@ import torch.nn as nn
|
||||
import numpy as np
|
||||
from collections import deque
|
||||
from typing import Dict, Optional, Any, Tuple
|
||||
from roboimi.vla.core.interfaces import VLABackbone, VLAProjector, VLAHead
|
||||
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
|
||||
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
|
||||
from roboimi.vla.models.heads.conditional_unet1d import ConditionalUnet1D
|
||||
from roboimi.vla.models.normalization import NormalizationModule
|
||||
|
||||
class VLAAgent(nn.Module):
|
||||
@@ -24,6 +22,7 @@ class VLAAgent(nn.Module):
|
||||
diffusion_steps=100, # DDPM 加噪步数
|
||||
inference_steps=10, # DDIM 推理步数
|
||||
num_cams=3, # 视觉输入的摄像头数量
|
||||
camera_names: Optional[Tuple[str, ...]] = None, # 条件相机顺序
|
||||
dataset_stats=None, # 数据集统计信息,用于归一化
|
||||
normalization_type='min_max', # 归一化类型: 'gaussian' 或 'min_max'
|
||||
num_action_steps=8, # 每次推理实际执行多少步动作
|
||||
@@ -39,6 +38,31 @@ class VLAAgent(nn.Module):
|
||||
self.num_action_steps = num_action_steps
|
||||
self.inference_steps = inference_steps
|
||||
self.head_type = head_type # 'unet' 或 'transformer'
|
||||
agent_camera_names = tuple(camera_names) if camera_names is not None else None
|
||||
backbone_camera_names = getattr(vision_backbone, 'camera_names', None)
|
||||
backbone_camera_names = tuple(backbone_camera_names) if backbone_camera_names is not None else None
|
||||
backbone_num_cameras = getattr(vision_backbone, 'num_cameras', None)
|
||||
if backbone_num_cameras is not None and backbone_num_cameras != self.num_cams:
|
||||
raise ValueError(
|
||||
f"agent.num_cams({self.num_cams}) 与 "
|
||||
f"vision_backbone.num_cameras({backbone_num_cameras}) 不一致"
|
||||
)
|
||||
if (
|
||||
agent_camera_names is not None
|
||||
and backbone_camera_names is not None
|
||||
and agent_camera_names != backbone_camera_names
|
||||
):
|
||||
raise ValueError(
|
||||
f"agent.camera_names({list(agent_camera_names)}) 与 "
|
||||
f"vision_backbone.camera_names({list(backbone_camera_names)}) 不一致"
|
||||
)
|
||||
self.camera_names = (
|
||||
agent_camera_names if agent_camera_names is not None else backbone_camera_names
|
||||
)
|
||||
if self.camera_names is not None and len(self.camera_names) != self.num_cams:
|
||||
raise ValueError(
|
||||
f"camera_names 长度({len(self.camera_names)})与 num_cams({self.num_cams})不一致"
|
||||
)
|
||||
|
||||
|
||||
# 归一化模块 - 统一训练和推理的归一化逻辑
|
||||
@@ -48,6 +72,8 @@ class VLAAgent(nn.Module):
|
||||
)
|
||||
|
||||
self.vision_encoder = vision_backbone
|
||||
if self.camera_names is not None:
|
||||
self.vision_encoder.camera_names = self.camera_names
|
||||
single_cam_feat_dim = self.vision_encoder.output_dim
|
||||
# global_cond_dim: 展平后的总维度(用于UNet)
|
||||
total_vision_dim = single_cam_feat_dim * num_cams * obs_horizon
|
||||
@@ -117,6 +143,34 @@ class VLAAgent(nn.Module):
|
||||
return tuple(self._move_to_device(v, device) for v in data)
|
||||
return data
|
||||
|
||||
def _order_images(self, images: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
||||
"""按显式配置的相机顺序返回图像字典。"""
|
||||
if self.camera_names is None:
|
||||
camera_names = tuple(sorted(images.keys()))
|
||||
if len(camera_names) != self.num_cams:
|
||||
raise ValueError(
|
||||
f"图像条件相机数量({len(camera_names)})与 num_cams({self.num_cams})不一致"
|
||||
)
|
||||
return {cam_name: images[cam_name] for cam_name in camera_names}
|
||||
|
||||
missing = [cam_name for cam_name in self.camera_names if cam_name not in images]
|
||||
if missing:
|
||||
raise ValueError(
|
||||
f"图像条件缺少必需相机。missing={missing}, expected={list(self.camera_names)}"
|
||||
)
|
||||
return {cam_name: images[cam_name] for cam_name in self.camera_names}
|
||||
|
||||
def _build_cond(self, images: Dict[str, torch.Tensor], states: torch.Tensor) -> torch.Tensor:
|
||||
"""构造每步条件,确保图像条件顺序稳定。"""
|
||||
ordered_images = self._order_images(images)
|
||||
visual_features = self.vision_encoder(ordered_images)
|
||||
state_features = self.state_encoder(states)
|
||||
cond = torch.cat([visual_features, state_features], dim=-1)
|
||||
if cond.shape[-1] != self.per_step_cond_dim:
|
||||
raise RuntimeError(
|
||||
f"条件维度不匹配: got {cond.shape[-1]}, expected {self.per_step_cond_dim}"
|
||||
)
|
||||
return cond
|
||||
|
||||
# ==========================
|
||||
# 训练阶段 (Training)
|
||||
@@ -136,10 +190,8 @@ class VLAAgent(nn.Module):
|
||||
states = self.normalization.normalize_qpos(states)
|
||||
actions = self.normalization.normalize_action(actions)
|
||||
|
||||
state_features = self.state_encoder(states)
|
||||
|
||||
# 1. 提取视觉特征
|
||||
visual_features = self.vision_encoder(images) # (B, obs_horizon, vision_dim)
|
||||
per_step_cond = self._build_cond(images, states)
|
||||
action_features = self.action_encoder(actions)
|
||||
|
||||
# 2. 采样噪声
|
||||
@@ -157,21 +209,16 @@ class VLAAgent(nn.Module):
|
||||
)
|
||||
|
||||
# 拼接全局条件并展平
|
||||
# visual_features: (B, obs_horizon, vision_dim)
|
||||
# state_features: (B, obs_horizon, obs_dim)
|
||||
# 拼接后展平为 (B, obs_horizon * (vision_dim + obs_dim))
|
||||
global_cond = torch.cat([visual_features, state_features], dim=-1)
|
||||
global_cond = global_cond.flatten(start_dim=1)
|
||||
# per_step_cond: (B, obs_horizon, vision_dim * num_cams + obs_dim)
|
||||
# 展平后用于 UNet,全序列形式用于 Transformer
|
||||
global_cond = per_step_cond.flatten(start_dim=1)
|
||||
|
||||
# 5. 网络预测噪声(根据head类型选择接口)
|
||||
if self.head_type == 'transformer':
|
||||
# Transformer需要序列格式的条件: (B, obs_horizon, cond_dim_per_step)
|
||||
# 将展平的global_cond reshape回序列格式
|
||||
cond = global_cond.reshape(B, self.obs_horizon, self.per_step_cond_dim)
|
||||
pred_noise = self.noise_pred_net(
|
||||
sample=noisy_actions,
|
||||
timestep=timesteps,
|
||||
cond=cond
|
||||
cond=per_step_cond
|
||||
)
|
||||
else: # 'unet'
|
||||
pred_noise = self.noise_pred_net(
|
||||
@@ -218,7 +265,8 @@ class VLAAgent(nn.Module):
|
||||
|
||||
# 添加图像
|
||||
if 'images' in observation:
|
||||
self._queues['images'].append({k: v.clone() for k, v in observation['images'].items()})
|
||||
ordered_images = self._order_images(observation['images'])
|
||||
self._queues['images'].append({k: v.clone() for k, v in ordered_images.items()})
|
||||
|
||||
def _prepare_observation_batch(self) -> Dict[str, torch.Tensor]:
|
||||
"""
|
||||
@@ -246,7 +294,8 @@ class VLAAgent(nn.Module):
|
||||
images_list.append(images_list[-1])
|
||||
|
||||
batch_images = {}
|
||||
for cam_name in images_list[0].keys():
|
||||
camera_names = self.camera_names if self.camera_names is not None else tuple(sorted(images_list[0].keys()))
|
||||
for cam_name in camera_names:
|
||||
batch_images[cam_name] = torch.stack([img[cam_name] for img in images_list], dim=0).unsqueeze(0)
|
||||
|
||||
return {'qpos': batch_qpos, 'images': batch_images}
|
||||
@@ -346,22 +395,18 @@ class VLAAgent(nn.Module):
|
||||
proprioception = self.normalization.normalize_qpos(proprioception)
|
||||
|
||||
# 1. 提取当前观测特征(只提取一次)
|
||||
visual_features = self.vision_encoder(images)
|
||||
state_features = self.state_encoder(proprioception)
|
||||
per_step_cond = self._build_cond(images, proprioception)
|
||||
|
||||
# 拼接条件(只计算一次)
|
||||
# visual_features: (B, obs_horizon, vision_dim)
|
||||
# state_features: (B, obs_horizon, obs_dim)
|
||||
global_cond = torch.cat([visual_features, state_features], dim=-1)
|
||||
global_cond_flat = global_cond.flatten(start_dim=1)
|
||||
global_cond_flat = per_step_cond.flatten(start_dim=1)
|
||||
if self.head_type == 'transformer':
|
||||
cond = global_cond.reshape(B, self.obs_horizon, self.per_step_cond_dim)
|
||||
cond = per_step_cond
|
||||
else:
|
||||
cond = None
|
||||
|
||||
# 2. 初始化纯高斯噪声动作
|
||||
# 形状: (B, pred_horizon, action_dim)
|
||||
device = visual_features.device
|
||||
device = per_step_cond.device
|
||||
current_actions = torch.randn(
|
||||
(B, self.pred_horizon, self.action_dim), device=device
|
||||
)
|
||||
|
||||
@@ -29,8 +29,13 @@ num_action_steps: 8 # 每次推理实际执行多少步动作(应 <= p
|
||||
# ====================
|
||||
# 相机配置
|
||||
# ====================
|
||||
camera_names: ${data.camera_names} # 条件相机顺序固定为 r_vis, top, front
|
||||
num_cams: 3 # 摄像头数量 (r_vis, top, front)
|
||||
|
||||
vision_backbone:
|
||||
num_cameras: ${agent.num_cams}
|
||||
camera_names: ${agent.camera_names}
|
||||
|
||||
# ====================
|
||||
# 扩散过程配置
|
||||
# ====================
|
||||
@@ -52,3 +57,6 @@ head:
|
||||
# ResNet18 + SpatialSoftmax(32 keypoints) = 64维/相机
|
||||
# 计算方式:单相机特征(64) * 相机数(3) + obs_dim(16) = 208
|
||||
cond_dim: 208
|
||||
causal_attn: false
|
||||
time_as_cond: true
|
||||
obs_as_cond: true
|
||||
|
||||
@@ -9,19 +9,25 @@ defaults:
|
||||
# ====================
|
||||
train:
|
||||
# 基础训练参数
|
||||
batch_size: 8 # 批次大小
|
||||
lr: 5e-5 # 学习率(Transformer建议更小)
|
||||
batch_size: 16 # 批次大小
|
||||
lr: 1e-4 # 学习率
|
||||
max_steps: 100000 # 最大训练步数
|
||||
device: "cuda" # 设备: "cuda" 或 "cpu"
|
||||
|
||||
# 数据加载
|
||||
num_workers: 8 # DataLoader 工作进程数(调试时设为 0,生产环境用 8)
|
||||
val_split: 0.1 # 验证集比例
|
||||
num_workers: 12 # DataLoader 工作进程数(调试时设为 0)
|
||||
val_split: 0.0 # 验证集比例;默认使用全量数据训练
|
||||
seed: 42 # 随机种子(用于数据划分)
|
||||
|
||||
# 日志和检查点
|
||||
log_freq: 100 # 日志记录频率(步数)
|
||||
save_freq: 2000 # 保存检查点频率(步数)
|
||||
use_swanlab: false # 是否启用 SwanLab 标量日志
|
||||
swanlab_project: "roboimi-vla" # SwanLab project 名称
|
||||
swanlab_run_name: null # 可选的 SwanLab 运行名
|
||||
rollout_val_freq_epochs: 50 # 每隔多少个 epoch 执行一次 rollout 验证
|
||||
rollout_validate_on_checkpoint: false # 是否在保存 checkpoint 后立即运行 rollout 验证
|
||||
rollout_num_episodes: 3 # rollout 验证的回合数
|
||||
|
||||
# 学习率调度器(带预热)
|
||||
warmup_steps: 2000 # 预热步数(Transformer建议更长)
|
||||
|
||||
@@ -29,6 +29,19 @@ smooth_alpha: 0.3
|
||||
# ====================
|
||||
# 调试选项
|
||||
# ====================
|
||||
headless: false # 是否禁用 MuJoCo / OpenCV GUI 渲染
|
||||
verbose_action: true # 是否打印每个时间步的动作信息
|
||||
|
||||
|
||||
# ====================
|
||||
# Rollout artifact 导出
|
||||
# ====================
|
||||
artifact_dir: null # 可选输出目录;为空时在启用导出时自动创建目录
|
||||
save_artifacts: false # 总开关;实际仍需搭配下面的具体导出项
|
||||
save_timing: false # 是否保存 timing.json(包含各阶段耗时统计)
|
||||
save_trajectory: false # 是否保存 trajectory.npz(原始 EE action + 执行后 EE pose)
|
||||
save_summary_json: false # 是否保存 JSON-friendly rollout summary
|
||||
save_trajectory_npz: false # 是否保存每步轨迹/时序/EE pose 为 NPZ
|
||||
record_video: false # 是否从单个相机流录制 rollout mp4
|
||||
video_camera: null # video_camera_name 的别名
|
||||
video_camera_name: null # 录制视频使用的相机名;为空时默认取 camera_names[0]
|
||||
video_fps: 30 # 导出 mp4 的目标帧率
|
||||
|
||||
@@ -5,7 +5,7 @@ _partial_: true
|
||||
# ====================
|
||||
# Transformer 架构配置
|
||||
# ====================
|
||||
n_layer: 4 # Transformer层数(先用小模型提高收敛稳定性)
|
||||
n_layer: 4 # Transformer层数(保持当前小模型配置)
|
||||
n_head: 4 # 注意力头数
|
||||
n_emb: 128 # 嵌入维度
|
||||
p_drop_emb: 0.05 # Embedding dropout
|
||||
@@ -14,9 +14,10 @@ p_drop_attn: 0.05 # Attention dropout
|
||||
# ====================
|
||||
# 条件配置
|
||||
# ====================
|
||||
causal_attn: false # 是否使用因果注意力(自回归生成)
|
||||
obs_as_cond: true # 观测作为条件(由cond_dim > 0决定)
|
||||
n_cond_layers: 1 # 条件编码器层数(1层先做稳定融合)
|
||||
causal_attn: false # 对齐 external TransformerForDiffusion 的 full-attention / nocausal 变体
|
||||
time_as_cond: true # 与 external 实现一致:时间步作为条件 token
|
||||
obs_as_cond: true # API 对齐;实际是否启用由 cond_dim > 0 决定
|
||||
n_cond_layers: 1 # 条件编码器层数(保留当前配置)
|
||||
|
||||
# ====================
|
||||
# 注意事项
|
||||
|
||||
@@ -105,7 +105,7 @@ class SimpleRobotDataset(Dataset):
|
||||
self._file_cache[key] = f
|
||||
return f
|
||||
|
||||
def _load_frame(self, idx: int) -> Dict:
|
||||
def _load_frame(self, idx: int, *, load_images: bool = True) -> Dict:
|
||||
"""从 HDF5 文件懒加载单帧数据"""
|
||||
meta = self.frame_meta[idx]
|
||||
f = self._get_h5_file(meta["hdf5_path"])
|
||||
@@ -118,21 +118,22 @@ class SimpleRobotDataset(Dataset):
|
||||
}
|
||||
|
||||
# 加载图像数据: observations/images/{cam_name} -> observation.{cam_name}
|
||||
for cam_name in self.camera_names:
|
||||
h5_path = f'observations/images/{cam_name}'
|
||||
if h5_path in f:
|
||||
img = f[h5_path][meta["frame_idx"]]
|
||||
# Resize图像到224x224(减少内存和I/O负担)
|
||||
import cv2
|
||||
img = cv2.resize(img, (224, 224), interpolation=cv2.INTER_LINEAR)
|
||||
# 转换为float并归一化到 [0, 1]
|
||||
img = torch.from_numpy(img).float() / 255.0
|
||||
frame[f"observation.{cam_name}"] = img.permute(2, 0, 1) # HWC -> CHW
|
||||
if load_images:
|
||||
for cam_name in self.camera_names:
|
||||
h5_path = f'observations/images/{cam_name}'
|
||||
if h5_path in f:
|
||||
img = f[h5_path][meta["frame_idx"]]
|
||||
# Resize图像到224x224(减少内存和I/O负担)
|
||||
import cv2
|
||||
img = cv2.resize(img, (224, 224), interpolation=cv2.INTER_LINEAR)
|
||||
# 转换为float并归一化到 [0, 1]
|
||||
img = torch.from_numpy(img).float() / 255.0
|
||||
frame[f"observation.{cam_name}"] = img.permute(2, 0, 1) # HWC -> CHW
|
||||
|
||||
return frame
|
||||
|
||||
def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
|
||||
frame = self._load_frame(idx)
|
||||
frame = self._load_frame(idx, load_images=False)
|
||||
ep_idx = frame["episode_index"]
|
||||
|
||||
# 获取当前 episode 的帧索引范围
|
||||
@@ -186,10 +187,10 @@ class SimpleRobotDataset(Dataset):
|
||||
target_idx = idx + delta
|
||||
|
||||
if target_idx <= ep_end:
|
||||
actions.append(self._load_frame(target_idx)["action"])
|
||||
actions.append(self._load_frame(target_idx, load_images=False)["action"])
|
||||
action_is_pad.append(False)
|
||||
else:
|
||||
actions.append(self._load_frame(ep_end)["action"])
|
||||
actions.append(self._load_frame(ep_end, load_images=False)["action"])
|
||||
action_is_pad.append(True)
|
||||
|
||||
# ============================================
|
||||
|
||||
3
roboimi/vla/eval_utils.py
Normal file
3
roboimi/vla/eval_utils.py
Normal file
@@ -0,0 +1,3 @@
|
||||
def execute_policy_action(env, action):
|
||||
"""Execute policy outputs using EE-action semantics."""
|
||||
env.step(action)
|
||||
@@ -178,12 +178,18 @@ class ResNetDiffusionBackbone(VLABackbone):
|
||||
spatial_softmax_num_keypoints: int = 32,
|
||||
use_separate_rgb_encoder_per_camera: bool = False, # 新增:是否为每个摄像头使用独立编码器
|
||||
num_cameras: int = 1, # 新增:摄像头数量(仅在独立编码器模式下使用)
|
||||
camera_names: Optional[Tuple[str, ...]] = None, # 显式相机顺序
|
||||
freeze_backbone: bool = True, # 新增:是否冻结ResNet backbone(推荐True)
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.use_separate_rgb_encoder_per_camera = use_separate_rgb_encoder_per_camera
|
||||
self.num_cameras = num_cameras
|
||||
self.camera_names = tuple(camera_names) if camera_names is not None else None
|
||||
if self.camera_names is not None and len(self.camera_names) != self.num_cameras:
|
||||
raise ValueError(
|
||||
f"camera_names 长度({len(self.camera_names)})与 num_cameras({self.num_cameras})不一致"
|
||||
)
|
||||
|
||||
if use_separate_rgb_encoder_per_camera:
|
||||
# 独立编码器模式:为每个摄像头创建独立的编码器
|
||||
@@ -217,6 +223,22 @@ class ResNetDiffusionBackbone(VLABackbone):
|
||||
)
|
||||
self.feature_dim = self.rgb_encoder.feature_dim
|
||||
|
||||
def _ordered_camera_names(self, images) -> Tuple[str, ...]:
|
||||
if self.camera_names is None:
|
||||
camera_names = tuple(sorted(images.keys()))
|
||||
if len(camera_names) != self.num_cameras:
|
||||
raise ValueError(
|
||||
f"图像输入相机数量({len(camera_names)})与 num_cameras({self.num_cameras})不一致"
|
||||
)
|
||||
return camera_names
|
||||
|
||||
missing = [cam_name for cam_name in self.camera_names if cam_name not in images]
|
||||
if missing:
|
||||
raise ValueError(
|
||||
f"图像输入缺少必需相机。missing={missing}, expected={list(self.camera_names)}"
|
||||
)
|
||||
return self.camera_names
|
||||
|
||||
def forward(self, images):
|
||||
"""
|
||||
Args:
|
||||
@@ -228,7 +250,7 @@ class ResNetDiffusionBackbone(VLABackbone):
|
||||
"""
|
||||
any_tensor = next(iter(images.values()))
|
||||
B, T = any_tensor.shape[:2]
|
||||
cam_names = sorted(images.keys())
|
||||
cam_names = self._ordered_camera_names(images)
|
||||
|
||||
if self.use_separate_rgb_encoder_per_camera:
|
||||
# 独立编码器模式:每个摄像头使用对应的编码器
|
||||
@@ -236,7 +258,7 @@ class ResNetDiffusionBackbone(VLABackbone):
|
||||
for cam_idx, cam_name in enumerate(cam_names):
|
||||
img = images[cam_name]
|
||||
encoder = self.rgb_encoder[cam_idx]
|
||||
features = encoder.forward_single_image(img.view(B * T, *img.shape[2:]))
|
||||
features = encoder.forward_single_image(img.reshape(B * T, *img.shape[2:]))
|
||||
features_all.append(features)
|
||||
return torch.cat(features_all, dim=1).view(B, T, -1)
|
||||
else:
|
||||
@@ -244,7 +266,7 @@ class ResNetDiffusionBackbone(VLABackbone):
|
||||
features_all = []
|
||||
for cam_name in cam_names:
|
||||
img = images[cam_name]
|
||||
features = self.rgb_encoder.forward_single_image(img.view(B * T, *img.shape[2:]))
|
||||
features = self.rgb_encoder.forward_single_image(img.reshape(B * T, *img.shape[2:]))
|
||||
features_all.append(features)
|
||||
return torch.cat(features_all, dim=1).view(B, T, -1)
|
||||
|
||||
@@ -369,4 +391,4 @@ if __name__ == "__main__":
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("🎉 All tests completed successfully!")
|
||||
print("=" * 60)
|
||||
print("=" * 60)
|
||||
|
||||
@@ -1,19 +1,35 @@
|
||||
"""
|
||||
Transformer-based Diffusion Policy Head
|
||||
"""Transformer-based diffusion head aligned with diffusion_policy's TransformerForDiffusion."""
|
||||
|
||||
使用Transformer架构(Encoder-Decoder)替代UNet进行噪声预测。
|
||||
支持通过Cross-Attention注入全局条件(观测特征)。
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import math
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from typing import Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ModuleAttrMixin(nn.Module):
|
||||
"""Minimal local copy of diffusion_policy's ModuleAttrMixin for state-dict parity."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self._dummy_variable = nn.Parameter()
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return next(iter(self.parameters())).device
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return next(iter(self.parameters())).dtype
|
||||
|
||||
|
||||
class SinusoidalPosEmb(nn.Module):
|
||||
"""正弦位置编码(用于时间步嵌入)"""
|
||||
def __init__(self, dim: int):
|
||||
def __init__(self, dim: int) -> None:
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
|
||||
@@ -27,35 +43,13 @@ class SinusoidalPosEmb(nn.Module):
|
||||
return emb
|
||||
|
||||
|
||||
class Transformer1D(nn.Module):
|
||||
"""
|
||||
Transformer-based 1D Diffusion Model
|
||||
|
||||
使用Encoder-Decoder架构:
|
||||
- Encoder: 处理条件(观测 + 时间步)
|
||||
- Decoder: 通过Cross-Attention预测噪声
|
||||
|
||||
Args:
|
||||
input_dim: 输入动作维度
|
||||
output_dim: 输出动作维度
|
||||
horizon: 预测horizon长度
|
||||
n_obs_steps: 观测步数
|
||||
cond_dim: 条件维度
|
||||
n_layer: Transformer层数
|
||||
n_head: 注意力头数
|
||||
n_emb: 嵌入维度
|
||||
p_drop_emb: Embedding dropout
|
||||
p_drop_attn: Attention dropout
|
||||
causal_attn: 是否使用因果注意力(自回归)
|
||||
n_cond_layers: Encoder层数(0表示使用MLP)
|
||||
"""
|
||||
|
||||
class Transformer1D(ModuleAttrMixin):
|
||||
def __init__(
|
||||
self,
|
||||
input_dim: int,
|
||||
output_dim: int,
|
||||
horizon: int,
|
||||
n_obs_steps: int = None,
|
||||
n_obs_steps: Optional[int] = None,
|
||||
cond_dim: int = 0,
|
||||
n_layer: int = 8,
|
||||
n_head: int = 8,
|
||||
@@ -63,57 +57,42 @@ class Transformer1D(nn.Module):
|
||||
p_drop_emb: float = 0.1,
|
||||
p_drop_attn: float = 0.1,
|
||||
causal_attn: bool = False,
|
||||
time_as_cond: bool = True,
|
||||
obs_as_cond: bool = False,
|
||||
n_cond_layers: int = 0
|
||||
):
|
||||
n_cond_layers: int = 0,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
# 计算序列长度
|
||||
if n_obs_steps is None:
|
||||
n_obs_steps = horizon
|
||||
|
||||
T = horizon
|
||||
T_cond = 1 # 时间步token数量
|
||||
|
||||
# 确定是否使用观测作为条件
|
||||
T_cond = 1
|
||||
if not time_as_cond:
|
||||
T += 1
|
||||
T_cond -= 1
|
||||
obs_as_cond = cond_dim > 0
|
||||
if obs_as_cond:
|
||||
assert time_as_cond
|
||||
T_cond += n_obs_steps
|
||||
|
||||
# 保存配置
|
||||
self.T = T
|
||||
self.T_cond = T_cond
|
||||
self.horizon = horizon
|
||||
self.obs_as_cond = obs_as_cond
|
||||
self.input_dim = input_dim
|
||||
self.output_dim = output_dim
|
||||
|
||||
# ==================== 输入嵌入 ====================
|
||||
self.input_emb = nn.Linear(input_dim, n_emb)
|
||||
self.pos_emb = nn.Parameter(torch.zeros(1, T, n_emb))
|
||||
self.drop = nn.Dropout(p_drop_emb)
|
||||
|
||||
# ==================== 条件编码 ====================
|
||||
# 时间步嵌入
|
||||
self.time_emb = SinusoidalPosEmb(n_emb)
|
||||
|
||||
# 观测条件嵌入(可选)
|
||||
self.cond_obs_emb = None
|
||||
if obs_as_cond:
|
||||
self.cond_obs_emb = nn.Linear(cond_dim, n_emb)
|
||||
|
||||
# 条件位置编码
|
||||
self.cond_pos_emb = None
|
||||
self.encoder = None
|
||||
self.decoder = None
|
||||
encoder_only = False
|
||||
|
||||
if T_cond > 0:
|
||||
self.cond_pos_emb = nn.Parameter(torch.zeros(1, T_cond, n_emb))
|
||||
|
||||
# ==================== Encoder ====================
|
||||
self.encoder = None
|
||||
self.encoder_only = False
|
||||
|
||||
if T_cond > 0:
|
||||
if n_cond_layers > 0:
|
||||
# 使用Transformer Encoder
|
||||
encoder_layer = nn.TransformerEncoderLayer(
|
||||
d_model=n_emb,
|
||||
nhead=n_head,
|
||||
@@ -121,61 +100,19 @@ class Transformer1D(nn.Module):
|
||||
dropout=p_drop_attn,
|
||||
activation='gelu',
|
||||
batch_first=True,
|
||||
norm_first=True # Pre-LN更稳定
|
||||
norm_first=True,
|
||||
)
|
||||
self.encoder = nn.TransformerEncoder(
|
||||
encoder_layer=encoder_layer,
|
||||
num_layers=n_cond_layers
|
||||
num_layers=n_cond_layers,
|
||||
)
|
||||
else:
|
||||
# 使用简单的MLP
|
||||
self.encoder = nn.Sequential(
|
||||
nn.Linear(n_emb, 4 * n_emb),
|
||||
nn.Mish(),
|
||||
nn.Linear(4 * n_emb, n_emb)
|
||||
nn.Linear(4 * n_emb, n_emb),
|
||||
)
|
||||
else:
|
||||
# Encoder-only模式(BERT风格)
|
||||
self.encoder_only = True
|
||||
encoder_layer = nn.TransformerEncoderLayer(
|
||||
d_model=n_emb,
|
||||
nhead=n_head,
|
||||
dim_feedforward=4 * n_emb,
|
||||
dropout=p_drop_attn,
|
||||
activation='gelu',
|
||||
batch_first=True,
|
||||
norm_first=True
|
||||
)
|
||||
self.encoder = nn.TransformerEncoder(
|
||||
encoder_layer=encoder_layer,
|
||||
num_layers=n_layer
|
||||
)
|
||||
|
||||
# ==================== Attention Mask ====================
|
||||
self.mask = None
|
||||
self.memory_mask = None
|
||||
|
||||
if causal_attn:
|
||||
# 因果mask:确保只关注左侧
|
||||
sz = T
|
||||
mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
|
||||
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
|
||||
self.register_buffer("mask", mask)
|
||||
|
||||
if obs_as_cond:
|
||||
# 交叉注意力mask
|
||||
S = T_cond
|
||||
t, s = torch.meshgrid(
|
||||
torch.arange(T),
|
||||
torch.arange(S),
|
||||
indexing='ij'
|
||||
)
|
||||
mask = t >= (s - 1)
|
||||
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
|
||||
self.register_buffer('memory_mask', mask)
|
||||
|
||||
# ==================== Decoder ====================
|
||||
if not self.encoder_only:
|
||||
decoder_layer = nn.TransformerDecoderLayer(
|
||||
d_model=n_emb,
|
||||
nhead=n_head,
|
||||
@@ -183,136 +120,199 @@ class Transformer1D(nn.Module):
|
||||
dropout=p_drop_attn,
|
||||
activation='gelu',
|
||||
batch_first=True,
|
||||
norm_first=True
|
||||
norm_first=True,
|
||||
)
|
||||
self.decoder = nn.TransformerDecoder(
|
||||
decoder_layer=decoder_layer,
|
||||
num_layers=n_layer
|
||||
num_layers=n_layer,
|
||||
)
|
||||
else:
|
||||
encoder_only = True
|
||||
encoder_layer = nn.TransformerEncoderLayer(
|
||||
d_model=n_emb,
|
||||
nhead=n_head,
|
||||
dim_feedforward=4 * n_emb,
|
||||
dropout=p_drop_attn,
|
||||
activation='gelu',
|
||||
batch_first=True,
|
||||
norm_first=True,
|
||||
)
|
||||
self.encoder = nn.TransformerEncoder(
|
||||
encoder_layer=encoder_layer,
|
||||
num_layers=n_layer,
|
||||
)
|
||||
|
||||
# ==================== 输出头 ====================
|
||||
if causal_attn:
|
||||
sz = T
|
||||
mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
|
||||
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
|
||||
self.register_buffer('mask', mask)
|
||||
|
||||
if time_as_cond and obs_as_cond:
|
||||
S = T_cond
|
||||
t, s = torch.meshgrid(torch.arange(T), torch.arange(S), indexing='ij')
|
||||
mask = t >= (s - 1)
|
||||
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
|
||||
self.register_buffer('memory_mask', mask)
|
||||
else:
|
||||
self.memory_mask = None
|
||||
else:
|
||||
self.mask = None
|
||||
self.memory_mask = None
|
||||
|
||||
self.ln_f = nn.LayerNorm(n_emb)
|
||||
self.head = nn.Linear(n_emb, output_dim)
|
||||
|
||||
# ==================== 初始化 ====================
|
||||
self.apply(self._init_weights)
|
||||
self.T = T
|
||||
self.T_cond = T_cond
|
||||
self.horizon = horizon
|
||||
self.time_as_cond = time_as_cond
|
||||
self.obs_as_cond = obs_as_cond
|
||||
self.encoder_only = encoder_only
|
||||
|
||||
# 打印参数量
|
||||
total_params = sum(p.numel() for p in self.parameters())
|
||||
print(f"Transformer1D parameters: {total_params:,}")
|
||||
self.apply(self._init_weights)
|
||||
logger.info('number of parameters: %e', sum(p.numel() for p in self.parameters()))
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""初始化权重"""
|
||||
ignore_types = (
|
||||
nn.Dropout,
|
||||
SinusoidalPosEmb,
|
||||
nn.TransformerEncoderLayer,
|
||||
nn.TransformerDecoderLayer,
|
||||
nn.TransformerEncoder,
|
||||
nn.TransformerDecoder,
|
||||
nn.ModuleList,
|
||||
nn.Mish,
|
||||
nn.Sequential,
|
||||
)
|
||||
if isinstance(module, (nn.Linear, nn.Embedding)):
|
||||
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
||||
if isinstance(module, nn.Linear) and module.bias is not None:
|
||||
torch.nn.init.zeros_(module.bias)
|
||||
elif isinstance(module, nn.MultiheadAttention):
|
||||
# MultiheadAttention的权重初始化
|
||||
for name in ['in_proj_weight', 'q_proj_weight', 'k_proj_weight', 'v_proj_weight']:
|
||||
weight = getattr(module, name, None)
|
||||
for name in ('in_proj_weight', 'q_proj_weight', 'k_proj_weight', 'v_proj_weight'):
|
||||
weight = getattr(module, name)
|
||||
if weight is not None:
|
||||
torch.nn.init.normal_(weight, mean=0.0, std=0.02)
|
||||
|
||||
for name in ['in_proj_bias', 'bias_k', 'bias_v']:
|
||||
bias = getattr(module, name, None)
|
||||
for name in ('in_proj_bias', 'bias_k', 'bias_v'):
|
||||
bias = getattr(module, name)
|
||||
if bias is not None:
|
||||
torch.nn.init.zeros_(bias)
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
torch.nn.init.zeros_(module.bias)
|
||||
torch.nn.init.ones_(module.weight)
|
||||
elif isinstance(module, Transformer1D):
|
||||
# 位置编码初始化
|
||||
torch.nn.init.normal_(self.pos_emb, mean=0.0, std=0.02)
|
||||
if self.cond_pos_emb is not None:
|
||||
torch.nn.init.normal_(self.cond_pos_emb, mean=0.0, std=0.02)
|
||||
torch.nn.init.normal_(module.pos_emb, mean=0.0, std=0.02)
|
||||
if module.cond_obs_emb is not None:
|
||||
torch.nn.init.normal_(module.cond_pos_emb, mean=0.0, std=0.02)
|
||||
elif isinstance(module, ignore_types):
|
||||
pass
|
||||
else:
|
||||
raise RuntimeError(f'Unaccounted module {module}')
|
||||
|
||||
def get_optim_groups(self, weight_decay: float = 1e-3):
|
||||
decay = set()
|
||||
no_decay = set()
|
||||
whitelist_weight_modules = (torch.nn.Linear, torch.nn.MultiheadAttention)
|
||||
blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
|
||||
|
||||
for module_name, module in self.named_modules():
|
||||
for param_name, _ in module.named_parameters():
|
||||
full_param_name = f'{module_name}.{param_name}' if module_name else param_name
|
||||
|
||||
if param_name.endswith('bias'):
|
||||
no_decay.add(full_param_name)
|
||||
elif param_name.startswith('bias'):
|
||||
no_decay.add(full_param_name)
|
||||
elif param_name.endswith('weight') and isinstance(module, whitelist_weight_modules):
|
||||
decay.add(full_param_name)
|
||||
elif param_name.endswith('weight') and isinstance(module, blacklist_weight_modules):
|
||||
no_decay.add(full_param_name)
|
||||
|
||||
no_decay.add('pos_emb')
|
||||
no_decay.add('_dummy_variable')
|
||||
if self.cond_pos_emb is not None:
|
||||
no_decay.add('cond_pos_emb')
|
||||
|
||||
param_dict = {name: param for name, param in self.named_parameters()}
|
||||
inter_params = decay & no_decay
|
||||
union_params = decay | no_decay
|
||||
assert len(inter_params) == 0, f'parameters {inter_params} made it into both decay/no_decay sets!'
|
||||
assert len(param_dict.keys() - union_params) == 0, (
|
||||
f'parameters {param_dict.keys() - union_params} were not separated into either decay/no_decay sets!'
|
||||
)
|
||||
|
||||
return [
|
||||
{
|
||||
'params': [param_dict[name] for name in sorted(decay)],
|
||||
'weight_decay': weight_decay,
|
||||
},
|
||||
{
|
||||
'params': [param_dict[name] for name in sorted(no_decay)],
|
||||
'weight_decay': 0.0,
|
||||
},
|
||||
]
|
||||
|
||||
def configure_optimizers(
|
||||
self,
|
||||
learning_rate: float = 1e-4,
|
||||
weight_decay: float = 1e-3,
|
||||
betas: Tuple[float, float] = (0.9, 0.95),
|
||||
):
|
||||
optim_groups = self.get_optim_groups(weight_decay=weight_decay)
|
||||
return torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.Tensor,
|
||||
timestep: torch.Tensor,
|
||||
timestep: Union[torch.Tensor, float, int],
|
||||
cond: Optional[torch.Tensor] = None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
前向传播
|
||||
|
||||
Args:
|
||||
sample: (B, T, input_dim) 输入序列(加噪动作)
|
||||
timestep: (B,) 时间步
|
||||
cond: (B, T', cond_dim) 条件序列(观测特征)
|
||||
|
||||
Returns:
|
||||
(B, T, output_dim) 预测的噪声
|
||||
"""
|
||||
# ==================== 处理时间步 ====================
|
||||
timesteps = timestep
|
||||
if not torch.is_tensor(timesteps):
|
||||
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
|
||||
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
|
||||
timesteps = timesteps[None].to(sample.device)
|
||||
|
||||
# 扩展到batch维度
|
||||
timesteps = timesteps.expand(sample.shape[0])
|
||||
time_emb = self.time_emb(timesteps).unsqueeze(1) # (B, 1, n_emb)
|
||||
time_emb = self.time_emb(timesteps).unsqueeze(1)
|
||||
|
||||
# ==================== 处理输入 ====================
|
||||
input_emb = self.input_emb(sample) # (B, T, n_emb)
|
||||
input_emb = self.input_emb(sample)
|
||||
|
||||
# ==================== Encoder-Decoder模式 ====================
|
||||
if not self.encoder_only:
|
||||
# --- Encoder: 处理条件 ---
|
||||
if self.encoder_only:
|
||||
token_embeddings = torch.cat([time_emb, input_emb], dim=1)
|
||||
t = token_embeddings.shape[1]
|
||||
position_embeddings = self.pos_emb[:, :t, :]
|
||||
x = self.drop(token_embeddings + position_embeddings)
|
||||
x = self.encoder(src=x, mask=self.mask)
|
||||
x = x[:, 1:, :]
|
||||
else:
|
||||
cond_embeddings = time_emb
|
||||
|
||||
if self.obs_as_cond and cond is not None:
|
||||
# 添加观测条件
|
||||
cond_obs_emb = self.cond_obs_emb(cond) # (B, T_cond-1, n_emb)
|
||||
if self.obs_as_cond:
|
||||
cond_obs_emb = self.cond_obs_emb(cond)
|
||||
cond_embeddings = torch.cat([cond_embeddings, cond_obs_emb], dim=1)
|
||||
|
||||
# 添加位置编码
|
||||
tc = cond_embeddings.shape[1]
|
||||
pos_emb = self.cond_pos_emb[:, :tc, :]
|
||||
x = self.drop(cond_embeddings + pos_emb)
|
||||
position_embeddings = self.cond_pos_emb[:, :tc, :]
|
||||
x = self.drop(cond_embeddings + position_embeddings)
|
||||
memory = self.encoder(x)
|
||||
|
||||
# 通过encoder
|
||||
memory = self.encoder(x) # (B, T_cond, n_emb)
|
||||
|
||||
# --- Decoder: 预测噪声 ---
|
||||
# 添加位置编码到输入
|
||||
token_embeddings = input_emb
|
||||
t = token_embeddings.shape[1]
|
||||
pos_emb = self.pos_emb[:, :t, :]
|
||||
x = self.drop(token_embeddings + pos_emb)
|
||||
|
||||
# Cross-Attention: Query来自输入,Key/Value来自memory
|
||||
position_embeddings = self.pos_emb[:, :t, :]
|
||||
x = self.drop(token_embeddings + position_embeddings)
|
||||
x = self.decoder(
|
||||
tgt=x,
|
||||
memory=memory,
|
||||
tgt_mask=self.mask,
|
||||
memory_mask=self.memory_mask
|
||||
memory_mask=self.memory_mask,
|
||||
)
|
||||
|
||||
# ==================== Encoder-Only模式 ====================
|
||||
else:
|
||||
# BERT风格:时间步作为特殊token
|
||||
token_embeddings = torch.cat([time_emb, input_emb], dim=1)
|
||||
t = token_embeddings.shape[1]
|
||||
pos_emb = self.pos_emb[:, :t, :]
|
||||
x = self.drop(token_embeddings + pos_emb)
|
||||
|
||||
x = self.encoder(src=x, mask=self.mask)
|
||||
x = x[:, 1:, :] # 移除时间步token
|
||||
|
||||
# ==================== 输出头 ====================
|
||||
x = self.ln_f(x)
|
||||
x = self.head(x) # (B, T, output_dim)
|
||||
|
||||
x = self.head(x)
|
||||
return x
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 便捷函数:创建Transformer1D模型
|
||||
# ============================================================================
|
||||
def create_transformer1d(
|
||||
input_dim: int,
|
||||
output_dim: int,
|
||||
@@ -322,26 +322,9 @@ def create_transformer1d(
|
||||
n_layer: int = 8,
|
||||
n_head: int = 8,
|
||||
n_emb: int = 256,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
) -> Transformer1D:
|
||||
"""
|
||||
创建Transformer1D模型的便捷函数
|
||||
|
||||
Args:
|
||||
input_dim: 输入动作维度
|
||||
output_dim: 输出动作维度
|
||||
horizon: 预测horizon
|
||||
n_obs_steps: 观测步数
|
||||
cond_dim: 条件维度
|
||||
n_layer: Transformer层数
|
||||
n_head: 注意力头数
|
||||
n_emb: 嵌入维度
|
||||
**kwargs: 其他参数
|
||||
|
||||
Returns:
|
||||
Transformer1D模型
|
||||
"""
|
||||
model = Transformer1D(
|
||||
return Transformer1D(
|
||||
input_dim=input_dim,
|
||||
output_dim=output_dim,
|
||||
horizon=horizon,
|
||||
@@ -350,47 +333,5 @@ def create_transformer1d(
|
||||
n_layer=n_layer,
|
||||
n_head=n_head,
|
||||
n_emb=n_emb,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
)
|
||||
return model
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("=" * 80)
|
||||
print("Testing Transformer1D")
|
||||
print("=" * 80)
|
||||
|
||||
# 配置
|
||||
B = 4
|
||||
T = 16
|
||||
action_dim = 16
|
||||
obs_horizon = 2
|
||||
cond_dim = 416 # vision + state特征维度
|
||||
|
||||
# 创建模型
|
||||
model = Transformer1D(
|
||||
input_dim=action_dim,
|
||||
output_dim=action_dim,
|
||||
horizon=T,
|
||||
n_obs_steps=obs_horizon,
|
||||
cond_dim=cond_dim,
|
||||
n_layer=4,
|
||||
n_head=8,
|
||||
n_emb=256,
|
||||
causal_attn=False
|
||||
)
|
||||
|
||||
# 测试前向传播
|
||||
sample = torch.randn(B, T, action_dim)
|
||||
timestep = torch.randint(0, 100, (B,))
|
||||
cond = torch.randn(B, obs_horizon, cond_dim)
|
||||
|
||||
output = model(sample, timestep, cond)
|
||||
|
||||
print(f"\n输入:")
|
||||
print(f" sample: {sample.shape}")
|
||||
print(f" timestep: {timestep.shape}")
|
||||
print(f" cond: {cond.shape}")
|
||||
print(f"\n输出:")
|
||||
print(f" output: {output.shape}")
|
||||
print(f"\n✅ 测试通过!")
|
||||
|
||||
@@ -1,8 +1,16 @@
|
||||
import argparse
|
||||
import glob
|
||||
import os
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
|
||||
import h5py
|
||||
import numpy as np
|
||||
import os
|
||||
import glob
|
||||
import pickle
|
||||
|
||||
|
||||
DEFAULT_DATASET_DIR = str(
|
||||
Path(__file__).resolve().parents[2] / "demos" / "dataset" / "sim_transfer"
|
||||
)
|
||||
|
||||
def get_data_stats(dataset_dir):
|
||||
"""
|
||||
@@ -23,6 +31,11 @@ def get_data_stats(dataset_dir):
|
||||
files = sorted(glob.glob(os.path.join(dataset_dir, 'episode_*.hdf5')))
|
||||
print(f"Found {len(files)} episodes in {dataset_dir}")
|
||||
|
||||
if not files:
|
||||
raise ValueError(
|
||||
f"No episode_*.hdf5 files found in dataset_dir: {dataset_dir}"
|
||||
)
|
||||
|
||||
all_actions = []
|
||||
all_qpos = []
|
||||
|
||||
@@ -70,18 +83,32 @@ def get_data_stats(dataset_dir):
|
||||
}
|
||||
return stats_flat
|
||||
|
||||
if __name__ == "__main__":
|
||||
DATASET_DIR = 'roboimi/demos/dataset/sim_transfer'
|
||||
OUTPUT_PATH = DATASET_DIR + "/dataset_stats.pkl"
|
||||
|
||||
stats_flat = get_data_stats(DATASET_DIR)
|
||||
def write_dataset_stats(dataset_dir):
|
||||
output_path = os.path.join(dataset_dir, "dataset_stats.pkl")
|
||||
stats_flat = get_data_stats(dataset_dir)
|
||||
|
||||
# 打印检查
|
||||
print("\n--- Stats Computed ---")
|
||||
print(f"Action Mean shape: {stats_flat['action_mean'].shape}")
|
||||
print(f"Action Std shape: {stats_flat['action_std'].shape}")
|
||||
|
||||
# 保存
|
||||
with open(OUTPUT_PATH, 'wb') as f:
|
||||
with open(output_path, 'wb') as f:
|
||||
pickle.dump(stats_flat, f)
|
||||
print(f"\nStats saved to {OUTPUT_PATH}")
|
||||
print(f"\nStats saved to {output_path}")
|
||||
|
||||
return output_path
|
||||
|
||||
|
||||
def main(argv=None):
|
||||
parser = argparse.ArgumentParser(description="Calculate dataset statistics.")
|
||||
parser.add_argument(
|
||||
"--dataset_dir",
|
||||
default=DEFAULT_DATASET_DIR,
|
||||
help="Directory containing episode_*.hdf5 files.",
|
||||
)
|
||||
args = parser.parse_args(argv)
|
||||
write_dataset_stats(args.dataset_dir)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
1
tests/__init__.py
Normal file
1
tests/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
88
tests/test_calculate_stats_cli.py
Normal file
88
tests/test_calculate_stats_cli.py
Normal file
@@ -0,0 +1,88 @@
|
||||
import pickle
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
import h5py
|
||||
import numpy as np
|
||||
|
||||
from roboimi.vla.scripts import calculate_stats
|
||||
|
||||
|
||||
class CalculateStatsCliTest(unittest.TestCase):
|
||||
def test_default_dataset_dir_is_absolute_and_package_relative(self):
|
||||
expected = (
|
||||
Path(calculate_stats.__file__).resolve().parents[2]
|
||||
/ "demos"
|
||||
/ "dataset"
|
||||
/ "sim_transfer"
|
||||
)
|
||||
|
||||
self.assertEqual(Path(calculate_stats.DEFAULT_DATASET_DIR), expected)
|
||||
self.assertTrue(Path(calculate_stats.DEFAULT_DATASET_DIR).is_absolute())
|
||||
|
||||
def test_main_writes_dataset_stats_pkl_to_dataset_dir(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
dataset_dir = Path(tmpdir)
|
||||
episode_path = dataset_dir / "episode_0.hdf5"
|
||||
|
||||
with h5py.File(episode_path, "w") as root:
|
||||
root.create_dataset(
|
||||
"action",
|
||||
data=np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32),
|
||||
)
|
||||
observations = root.create_group("observations")
|
||||
observations.create_dataset(
|
||||
"qpos",
|
||||
data=np.array([[5.0, 6.0], [7.0, 8.0]], dtype=np.float32),
|
||||
)
|
||||
|
||||
calculate_stats.main(["--dataset_dir", str(dataset_dir)])
|
||||
|
||||
stats_path = dataset_dir / "dataset_stats.pkl"
|
||||
self.assertTrue(stats_path.exists())
|
||||
|
||||
with stats_path.open("rb") as f:
|
||||
stats = pickle.load(f)
|
||||
|
||||
self.assertEqual(
|
||||
set(stats),
|
||||
{
|
||||
"action_mean",
|
||||
"action_std",
|
||||
"action_min",
|
||||
"action_max",
|
||||
"qpos_mean",
|
||||
"qpos_std",
|
||||
"qpos_min",
|
||||
"qpos_max",
|
||||
},
|
||||
)
|
||||
np.testing.assert_allclose(stats["action_mean"], np.array([2.0, 3.0]))
|
||||
np.testing.assert_allclose(stats["qpos_mean"], np.array([6.0, 7.0]))
|
||||
|
||||
def test_main_raises_clear_error_for_empty_dataset_dir(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
dataset_dir = Path(tmpdir)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r"No episode_\*\.hdf5 files found"
|
||||
) as ctx:
|
||||
calculate_stats.main(["--dataset_dir", str(dataset_dir)])
|
||||
|
||||
self.assertIn(str(dataset_dir), str(ctx.exception))
|
||||
|
||||
def test_main_raises_clear_error_for_missing_dataset_dir(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
dataset_dir = Path(tmpdir) / "missing"
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r"No episode_\*\.hdf5 files found"
|
||||
) as ctx:
|
||||
calculate_stats.main(["--dataset_dir", str(dataset_dir)])
|
||||
|
||||
self.assertIn(str(dataset_dir), str(ctx.exception))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
28
tests/test_eval_vla_execution.py
Normal file
28
tests/test_eval_vla_execution.py
Normal file
@@ -0,0 +1,28 @@
|
||||
import unittest
|
||||
|
||||
from roboimi.vla.eval_utils import execute_policy_action
|
||||
|
||||
|
||||
class _FakeEnv:
|
||||
def __init__(self):
|
||||
self.calls = []
|
||||
|
||||
def step(self, action):
|
||||
self.calls.append(("step", action))
|
||||
|
||||
def step_jnt(self, action):
|
||||
self.calls.append(("step_jnt", action))
|
||||
|
||||
|
||||
class EvalVLAExecutionTest(unittest.TestCase):
|
||||
def test_execute_policy_action_uses_ee_step(self):
|
||||
env = _FakeEnv()
|
||||
action = [1, 2, 3]
|
||||
|
||||
execute_policy_action(env, action)
|
||||
|
||||
self.assertEqual(env.calls, [("step", action)])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
259
tests/test_eval_vla_headless.py
Normal file
259
tests/test_eval_vla_headless.py
Normal file
@@ -0,0 +1,259 @@
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from unittest import mock
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from roboimi.demos.vla_scripts import eval_vla
|
||||
from roboimi.envs.double_base import DualDianaMed
|
||||
from roboimi.envs.double_pos_ctrl_env import make_sim_env
|
||||
|
||||
|
||||
class _FakeAgent:
|
||||
def __init__(self):
|
||||
self.reset_calls = 0
|
||||
self.last_observation = None
|
||||
|
||||
def eval(self):
|
||||
return self
|
||||
|
||||
def to(self, _device):
|
||||
return self
|
||||
|
||||
def reset(self):
|
||||
self.reset_calls += 1
|
||||
|
||||
def select_action(self, observation):
|
||||
self.last_observation = observation
|
||||
return torch.zeros(16)
|
||||
|
||||
|
||||
class _FakeEnv:
|
||||
def __init__(self):
|
||||
self.image_obs_calls = 0
|
||||
self.render_calls = 0
|
||||
self.reset_calls = []
|
||||
|
||||
def reset(self, box_pos):
|
||||
self.reset_calls.append(np.array(box_pos))
|
||||
|
||||
def _get_image_obs(self):
|
||||
self.image_obs_calls += 1
|
||||
return {
|
||||
"images": {
|
||||
"front": np.zeros((8, 8, 3), dtype=np.uint8),
|
||||
}
|
||||
}
|
||||
|
||||
def _get_qpos_obs(self):
|
||||
return {"qpos": np.zeros(16, dtype=np.float32)}
|
||||
|
||||
def render(self):
|
||||
self.render_calls += 1
|
||||
raise AssertionError("env.render() should be skipped when eval.headless=true")
|
||||
|
||||
|
||||
class _RewardTrackingEnv(_FakeEnv):
|
||||
def __init__(self, reward_sequences):
|
||||
super().__init__()
|
||||
self.reward_sequences = reward_sequences
|
||||
self.episode_index = -1
|
||||
self.step_index = 0
|
||||
self.rew = 0.0
|
||||
|
||||
def reset(self, box_pos):
|
||||
super().reset(box_pos)
|
||||
self.episode_index += 1
|
||||
self.step_index = 0
|
||||
|
||||
|
||||
class _FakeRenderer:
|
||||
def __init__(self, env):
|
||||
self._env = env
|
||||
self._frames = [
|
||||
np.full((4, 4, 3), fill_value=index, dtype=np.uint8)
|
||||
for index in range(5)
|
||||
]
|
||||
self._index = 0
|
||||
|
||||
def update_scene(self, _mj_data, camera=None):
|
||||
self._camera = camera
|
||||
|
||||
def render(self):
|
||||
frame = self._frames[self._index]
|
||||
self._index += 1
|
||||
if self._index >= len(self._frames):
|
||||
self._env.exit_flag = True
|
||||
return frame
|
||||
|
||||
|
||||
class EvalVLAHeadlessTest(unittest.TestCase):
|
||||
def test_eval_config_exposes_headless_default(self):
|
||||
eval_cfg = OmegaConf.load(Path("roboimi/vla/conf/eval/eval.yaml"))
|
||||
|
||||
self.assertIn("headless", eval_cfg)
|
||||
self.assertFalse(eval_cfg.headless)
|
||||
|
||||
def test_make_sim_env_accepts_headless_and_disables_render(self):
|
||||
fake_env = object()
|
||||
|
||||
with mock.patch(
|
||||
"roboimi.assets.robots.diana_med.BiDianaMed",
|
||||
return_value="robot",
|
||||
), mock.patch(
|
||||
"roboimi.envs.double_pos_ctrl_env.DualDianaMed_Pos_Ctrl",
|
||||
return_value=fake_env,
|
||||
) as env_cls:
|
||||
env = make_sim_env("sim_transfer", headless=True)
|
||||
|
||||
self.assertIs(env, fake_env)
|
||||
env_cls.assert_called_once_with(
|
||||
robot="robot",
|
||||
is_render=False,
|
||||
control_freq=30,
|
||||
is_interpolate=True,
|
||||
cam_view="angle",
|
||||
)
|
||||
|
||||
def test_camera_viewer_headless_updates_images_without_gui_calls(self):
|
||||
env = DualDianaMed.__new__(DualDianaMed)
|
||||
env.mj_model = object()
|
||||
env.mj_data = object()
|
||||
env.exit_flag = False
|
||||
env.is_render = False
|
||||
env.cam = "angle"
|
||||
env.r_vis = None
|
||||
env.l_vis = None
|
||||
env.top = None
|
||||
env.angle = None
|
||||
env.front = None
|
||||
|
||||
with mock.patch(
|
||||
"roboimi.envs.double_base.mj.Renderer",
|
||||
side_effect=lambda *args, **kwargs: _FakeRenderer(env),
|
||||
), mock.patch("roboimi.envs.double_base.cv2.namedWindow") as named_window, mock.patch(
|
||||
"roboimi.envs.double_base.cv2.imshow"
|
||||
) as imshow, mock.patch("roboimi.envs.double_base.cv2.waitKey") as wait_key:
|
||||
env.camera_viewer()
|
||||
|
||||
named_window.assert_not_called()
|
||||
imshow.assert_not_called()
|
||||
wait_key.assert_not_called()
|
||||
self.assertIsNotNone(env.r_vis)
|
||||
self.assertIsNotNone(env.l_vis)
|
||||
self.assertIsNotNone(env.top)
|
||||
self.assertIsNotNone(env.angle)
|
||||
self.assertIsNotNone(env.front)
|
||||
|
||||
def test_eval_main_headless_skips_render_and_still_executes_policy(self):
|
||||
fake_env = _FakeEnv()
|
||||
fake_agent = _FakeAgent()
|
||||
cfg = OmegaConf.create(
|
||||
{
|
||||
"agent": {},
|
||||
"eval": {
|
||||
"ckpt_path": "checkpoints/vla_model_best.pt",
|
||||
"num_episodes": 1,
|
||||
"max_timesteps": 1,
|
||||
"device": "cpu",
|
||||
"task_name": "sim_transfer",
|
||||
"camera_names": ["front"],
|
||||
"use_smoothing": False,
|
||||
"smooth_alpha": 0.3,
|
||||
"verbose_action": False,
|
||||
"headless": True,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
with mock.patch.object(
|
||||
eval_vla,
|
||||
"load_checkpoint",
|
||||
return_value=(fake_agent, None),
|
||||
), mock.patch.object(
|
||||
eval_vla,
|
||||
"make_sim_env",
|
||||
return_value=fake_env,
|
||||
) as make_env, mock.patch.object(
|
||||
eval_vla,
|
||||
"sample_transfer_pose",
|
||||
return_value=np.array([0.1, 0.2, 0.3]),
|
||||
), mock.patch.object(
|
||||
eval_vla,
|
||||
"execute_policy_action",
|
||||
) as execute_policy_action, mock.patch.object(
|
||||
eval_vla,
|
||||
"tqdm",
|
||||
side_effect=lambda iterable, **kwargs: iterable,
|
||||
):
|
||||
eval_vla.main.__wrapped__(cfg)
|
||||
|
||||
make_env.assert_called_once_with("sim_transfer", headless=True)
|
||||
execute_policy_action.assert_called_once()
|
||||
self.assertEqual(fake_env.image_obs_calls, 1)
|
||||
self.assertEqual(fake_env.render_calls, 0)
|
||||
self.assertIsNotNone(fake_agent.last_observation)
|
||||
self.assertIn("front", fake_agent.last_observation["images"])
|
||||
|
||||
def test_run_eval_returns_average_reward_summary(self):
|
||||
reward_sequences = [
|
||||
[1.0, 2.0],
|
||||
[0.5, 4.0],
|
||||
]
|
||||
fake_env = _RewardTrackingEnv(reward_sequences)
|
||||
fake_agent = _FakeAgent()
|
||||
cfg = OmegaConf.create(
|
||||
{
|
||||
"agent": {},
|
||||
"eval": {
|
||||
"ckpt_path": "checkpoints/vla_model_best.pt",
|
||||
"num_episodes": 2,
|
||||
"max_timesteps": 2,
|
||||
"device": "cpu",
|
||||
"task_name": "sim_transfer",
|
||||
"camera_names": ["front"],
|
||||
"use_smoothing": False,
|
||||
"smooth_alpha": 0.3,
|
||||
"verbose_action": False,
|
||||
"headless": True,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
def fake_execute_policy_action(env, action):
|
||||
del action
|
||||
env.rew = env.reward_sequences[env.episode_index][env.step_index]
|
||||
env.step_index += 1
|
||||
|
||||
with mock.patch.object(
|
||||
eval_vla,
|
||||
"load_checkpoint",
|
||||
return_value=(fake_agent, None),
|
||||
), mock.patch.object(
|
||||
eval_vla,
|
||||
"make_sim_env",
|
||||
return_value=fake_env,
|
||||
), mock.patch.object(
|
||||
eval_vla,
|
||||
"sample_transfer_pose",
|
||||
return_value=np.array([0.1, 0.2, 0.3]),
|
||||
), mock.patch.object(
|
||||
eval_vla,
|
||||
"execute_policy_action",
|
||||
side_effect=fake_execute_policy_action,
|
||||
), mock.patch.object(
|
||||
eval_vla,
|
||||
"tqdm",
|
||||
side_effect=lambda iterable, **kwargs: iterable,
|
||||
):
|
||||
summary = eval_vla._run_eval(cfg)
|
||||
|
||||
self.assertEqual(summary["episode_rewards"], [3.0, 4.5])
|
||||
self.assertAlmostEqual(summary["avg_reward"], 3.75)
|
||||
self.assertEqual(summary["num_episodes"], 2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
228
tests/test_eval_vla_rollout_artifacts.py
Normal file
228
tests/test_eval_vla_rollout_artifacts.py
Normal file
@@ -0,0 +1,228 @@
|
||||
import json
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from unittest import mock
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from roboimi.demos.vla_scripts import eval_vla
|
||||
|
||||
|
||||
class _FakeAgent:
|
||||
def __init__(self, actions):
|
||||
self._actions = [torch.tensor(action, dtype=torch.float32) for action in actions]
|
||||
self.reset_calls = 0
|
||||
|
||||
def eval(self):
|
||||
return self
|
||||
|
||||
def to(self, _device):
|
||||
return self
|
||||
|
||||
def reset(self):
|
||||
self.reset_calls += 1
|
||||
|
||||
def select_action(self, observation):
|
||||
del observation
|
||||
return self._actions.pop(0)
|
||||
|
||||
|
||||
class _FakeEnv:
|
||||
def __init__(self):
|
||||
self.step_count = 0
|
||||
self.rew = 0.0
|
||||
self.render_calls = 0
|
||||
self.reset_calls = []
|
||||
|
||||
def reset(self, box_pos):
|
||||
self.reset_calls.append(np.array(box_pos, copy=True))
|
||||
self.step_count = 0
|
||||
self.rew = 0.0
|
||||
|
||||
def _get_image_obs(self):
|
||||
frame_value = self.step_count
|
||||
front = np.full((6, 8, 3), fill_value=frame_value, dtype=np.uint8)
|
||||
top = np.full((6, 8, 3), fill_value=frame_value + 20, dtype=np.uint8)
|
||||
return {"images": {"front": front, "top": top}}
|
||||
|
||||
def _get_qpos_obs(self):
|
||||
return {"qpos": np.arange(16, dtype=np.float32)}
|
||||
|
||||
def step(self, action):
|
||||
del action
|
||||
self.step_count += 1
|
||||
self.rew = float(self.step_count)
|
||||
|
||||
def render(self):
|
||||
self.render_calls += 1
|
||||
|
||||
def getBodyPos(self, name):
|
||||
base = float(self.step_count)
|
||||
if name == 'eef_left':
|
||||
return np.array([base, base + 0.1, base + 0.2], dtype=np.float32)
|
||||
if name == 'eef_right':
|
||||
return np.array([base + 1.0, base + 1.1, base + 1.2], dtype=np.float32)
|
||||
raise KeyError(name)
|
||||
|
||||
def getBodyQuat(self, name):
|
||||
base = float(self.step_count)
|
||||
if name == 'eef_left':
|
||||
return np.array([1.0, base, 0.0, 0.0], dtype=np.float32)
|
||||
if name == 'eef_right':
|
||||
return np.array([1.0, 0.0, base, 0.0], dtype=np.float32)
|
||||
raise KeyError(name)
|
||||
|
||||
|
||||
class _FakeVideoWriter:
|
||||
def __init__(self, output_path):
|
||||
self.output_path = Path(output_path)
|
||||
self.output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
self.output_path.write_bytes(b'')
|
||||
self.frames = []
|
||||
self.released = False
|
||||
|
||||
def isOpened(self):
|
||||
return True
|
||||
|
||||
def write(self, frame):
|
||||
self.frames.append(np.array(frame, copy=True))
|
||||
|
||||
def release(self):
|
||||
self.released = True
|
||||
self.output_path.write_bytes(b'fake-mp4')
|
||||
|
||||
|
||||
class EvalVLARolloutArtifactsTest(unittest.TestCase):
|
||||
def test_eval_config_exposes_rollout_artifact_defaults(self):
|
||||
eval_cfg = OmegaConf.load(Path('roboimi/vla/conf/eval/eval.yaml'))
|
||||
|
||||
self.assertIn('artifact_dir', eval_cfg)
|
||||
self.assertFalse(eval_cfg.save_summary_json)
|
||||
self.assertFalse(eval_cfg.save_trajectory_npz)
|
||||
self.assertFalse(eval_cfg.record_video)
|
||||
self.assertIsNone(eval_cfg.artifact_dir)
|
||||
self.assertIsNone(eval_cfg.video_camera_name)
|
||||
self.assertEqual(eval_cfg.video_fps, 30)
|
||||
|
||||
def test_run_eval_exports_npz_summary_and_video_artifacts(self):
|
||||
actions = [
|
||||
np.arange(16, dtype=np.float32),
|
||||
np.arange(16, dtype=np.float32) + 10.0,
|
||||
]
|
||||
fake_agent = _FakeAgent(actions)
|
||||
fake_env = _FakeEnv()
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
cfg = OmegaConf.create(
|
||||
{
|
||||
'agent': {},
|
||||
'eval': {
|
||||
'ckpt_path': 'checkpoints/vla_model_best.pt',
|
||||
'num_episodes': 1,
|
||||
'max_timesteps': 2,
|
||||
'device': 'cpu',
|
||||
'task_name': 'sim_transfer',
|
||||
'camera_names': ['front', 'top'],
|
||||
'use_smoothing': True,
|
||||
'smooth_alpha': 0.5,
|
||||
'verbose_action': False,
|
||||
'headless': True,
|
||||
'artifact_dir': tmpdir,
|
||||
'save_summary_json': True,
|
||||
'save_trajectory_npz': True,
|
||||
'record_video': True,
|
||||
'video_camera_name': 'front',
|
||||
'video_fps': 12,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
writer_holder = {}
|
||||
|
||||
def fake_open_video_writer(output_path, frame_size, fps):
|
||||
self.assertEqual(frame_size, (8, 6))
|
||||
self.assertEqual(fps, 12)
|
||||
writer = _FakeVideoWriter(output_path)
|
||||
writer_holder['writer'] = writer
|
||||
return writer
|
||||
|
||||
with mock.patch.object(
|
||||
eval_vla,
|
||||
'load_checkpoint',
|
||||
return_value=(fake_agent, None),
|
||||
), mock.patch.object(
|
||||
eval_vla,
|
||||
'make_sim_env',
|
||||
return_value=fake_env,
|
||||
), mock.patch.object(
|
||||
eval_vla,
|
||||
'sample_transfer_pose',
|
||||
return_value=np.array([0.1, 0.2, 0.3], dtype=np.float32),
|
||||
), mock.patch.object(
|
||||
eval_vla,
|
||||
'tqdm',
|
||||
side_effect=lambda iterable, **kwargs: iterable,
|
||||
), mock.patch.object(
|
||||
eval_vla,
|
||||
'_open_video_writer',
|
||||
side_effect=fake_open_video_writer,
|
||||
):
|
||||
summary = eval_vla._run_eval(cfg)
|
||||
|
||||
artifacts = summary['artifacts']
|
||||
trajectory_path = Path(artifacts['trajectory_npz'])
|
||||
summary_path = Path(artifacts['summary_json'])
|
||||
video_path = Path(artifacts['video_mp4'])
|
||||
|
||||
self.assertEqual(Path(artifacts['output_dir']), Path(tmpdir))
|
||||
self.assertEqual(artifacts['video_camera_name'], 'front')
|
||||
self.assertTrue(trajectory_path.exists())
|
||||
self.assertTrue(summary_path.exists())
|
||||
self.assertTrue(video_path.exists())
|
||||
|
||||
rollout_npz = np.load(trajectory_path)
|
||||
np.testing.assert_array_equal(rollout_npz['episode_index'], np.array([0, 0]))
|
||||
np.testing.assert_array_equal(rollout_npz['timestep'], np.array([0, 1]))
|
||||
np.testing.assert_array_equal(rollout_npz['reward'], np.array([1.0, 2.0], dtype=np.float32))
|
||||
np.testing.assert_array_equal(rollout_npz['raw_predicted_ee_action'][0], actions[0])
|
||||
np.testing.assert_array_equal(rollout_npz['raw_predicted_ee_action'][1], actions[1])
|
||||
np.testing.assert_array_equal(rollout_npz['executed_ee_action'][0], actions[0])
|
||||
np.testing.assert_array_equal(
|
||||
rollout_npz['executed_ee_action'][1],
|
||||
(actions[0] + actions[1]) / 2.0,
|
||||
)
|
||||
np.testing.assert_array_equal(
|
||||
rollout_npz['left_ee_pos'],
|
||||
np.array([[1.0, 1.1, 1.2], [2.0, 2.1, 2.2]], dtype=np.float32),
|
||||
)
|
||||
np.testing.assert_array_equal(
|
||||
rollout_npz['right_ee_pos'],
|
||||
np.array([[2.0, 2.1, 2.2], [3.0, 3.1, 3.2]], dtype=np.float32),
|
||||
)
|
||||
self.assertEqual(rollout_npz['obs_read_time_ms'].shape, (2,))
|
||||
self.assertEqual(rollout_npz['preprocess_time_ms'].shape, (2,))
|
||||
self.assertEqual(rollout_npz['inference_time_ms'].shape, (2,))
|
||||
self.assertEqual(rollout_npz['env_step_time_ms'].shape, (2,))
|
||||
self.assertEqual(rollout_npz['total_time_ms'].shape, (2,))
|
||||
|
||||
writer = writer_holder['writer']
|
||||
self.assertTrue(writer.released)
|
||||
self.assertEqual(len(writer.frames), 2)
|
||||
np.testing.assert_array_equal(writer.frames[0], np.zeros((6, 8, 3), dtype=np.uint8))
|
||||
np.testing.assert_array_equal(writer.frames[1], np.full((6, 8, 3), 1, dtype=np.uint8))
|
||||
|
||||
with summary_path.open('r', encoding='utf-8') as fh:
|
||||
saved_summary = json.load(fh)
|
||||
self.assertEqual(saved_summary['artifacts']['trajectory_npz'], str(trajectory_path))
|
||||
self.assertEqual(saved_summary['artifacts']['video_mp4'], str(video_path))
|
||||
self.assertEqual(saved_summary['episode_rewards'], [3.0])
|
||||
self.assertAlmostEqual(summary['avg_reward'], 3.0)
|
||||
self.assertIn('avg_obs_read_time_ms', summary)
|
||||
self.assertIn('avg_env_step_time_ms', summary)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
119
tests/test_raw_action_trajectory_viewer.py
Normal file
119
tests/test_raw_action_trajectory_viewer.py
Normal file
@@ -0,0 +1,119 @@
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from unittest import mock
|
||||
|
||||
import numpy as np
|
||||
|
||||
from roboimi.utils import raw_action_trajectory_viewer as traj_view
|
||||
|
||||
|
||||
class RawActionTrajectoryViewerTest(unittest.TestCase):
|
||||
def test_set_transfer_box_pose_writes_joint_qpos(self):
|
||||
joint_qpos = np.zeros(7, dtype=np.float64)
|
||||
|
||||
class _FakeJoint:
|
||||
def __init__(self, qpos):
|
||||
self.qpos = qpos
|
||||
|
||||
class _FakeData:
|
||||
def joint(self, name):
|
||||
assert name == "red_box_joint"
|
||||
return _FakeJoint(joint_qpos)
|
||||
|
||||
traj_view.set_transfer_box_pose(_FakeData(), np.array([0.2, -0.1, 1.05], dtype=np.float64))
|
||||
|
||||
np.testing.assert_array_equal(
|
||||
joint_qpos,
|
||||
np.array([0.2, -0.1, 1.05, 1.0, 0.0, 0.0, 0.0], dtype=np.float64),
|
||||
)
|
||||
|
||||
def test_disable_cv2_highgui_temporarily_replaces_gui_calls(self):
|
||||
fake_cv2 = SimpleNamespace(
|
||||
namedWindow=lambda *args, **kwargs: "named",
|
||||
imshow=lambda *args, **kwargs: "imshow",
|
||||
waitKey=lambda *args, **kwargs: "wait",
|
||||
)
|
||||
|
||||
restore = traj_view.disable_cv2_highgui(fake_cv2)
|
||||
self.assertIsNone(fake_cv2.namedWindow("x"))
|
||||
self.assertIsNone(fake_cv2.imshow("x", None))
|
||||
self.assertEqual(fake_cv2.waitKey(1), 1)
|
||||
|
||||
restore()
|
||||
self.assertEqual(fake_cv2.namedWindow("x"), "named")
|
||||
self.assertEqual(fake_cv2.imshow("x", None), "imshow")
|
||||
self.assertEqual(fake_cv2.waitKey(1), "wait")
|
||||
|
||||
def test_load_raw_action_positions_from_npz(self):
|
||||
raw_action = np.array(
|
||||
[
|
||||
[1.0, 2.0, 3.0, 0, 0, 0, 1, 11.0, 12.0, 13.0, 0, 0, 0, 1, -1, -1],
|
||||
[4.0, 5.0, 6.0, 0, 0, 0, 1, 14.0, 15.0, 16.0, 0, 0, 0, 1, -1, -1],
|
||||
],
|
||||
dtype=np.float32,
|
||||
)
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
path = Path(tmpdir) / "trajectory.npz"
|
||||
np.savez(path, raw_action=raw_action)
|
||||
|
||||
positions = traj_view.load_raw_action_positions(path)
|
||||
|
||||
np.testing.assert_array_equal(
|
||||
positions["left"],
|
||||
np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=np.float32),
|
||||
)
|
||||
np.testing.assert_array_equal(
|
||||
positions["right"],
|
||||
np.array([[11.0, 12.0, 13.0], [14.0, 15.0, 16.0]], dtype=np.float32),
|
||||
)
|
||||
|
||||
def test_build_red_capsule_segments_downsamples_to_fit_scene_limit(self):
|
||||
left = np.stack([np.array([float(i), 0.0, 0.0], dtype=np.float32) for i in range(6)])
|
||||
right = np.stack([np.array([float(i), 1.0, 0.0], dtype=np.float32) for i in range(6)])
|
||||
|
||||
markers = traj_view.build_trajectory_capsule_markers(
|
||||
{"left": left, "right": right},
|
||||
max_markers=4,
|
||||
radius=0.01,
|
||||
)
|
||||
|
||||
self.assertLessEqual(len(markers), 4)
|
||||
self.assertTrue(all(marker["rgba"] == (1.0, 0.0, 0.0, 1.0) for marker in markers))
|
||||
self.assertTrue(all(marker["radius"] == 0.01 for marker in markers))
|
||||
|
||||
def test_apply_capsule_markers_populates_user_scene(self):
|
||||
fake_scene = SimpleNamespace(
|
||||
maxgeom=3,
|
||||
ngeom=99,
|
||||
geoms=[object(), object(), object()],
|
||||
)
|
||||
markers = [
|
||||
{
|
||||
"from": np.array([0.0, 0.0, 0.0], dtype=np.float64),
|
||||
"to": np.array([1.0, 0.0, 0.0], dtype=np.float64),
|
||||
"rgba": (1.0, 0.0, 0.0, 1.0),
|
||||
"radius": 0.01,
|
||||
},
|
||||
{
|
||||
"from": np.array([0.0, 1.0, 0.0], dtype=np.float64),
|
||||
"to": np.array([1.0, 1.0, 0.0], dtype=np.float64),
|
||||
"rgba": (1.0, 0.0, 0.0, 1.0),
|
||||
"radius": 0.01,
|
||||
},
|
||||
]
|
||||
|
||||
with mock.patch.object(traj_view.mujoco, "mjv_initGeom") as init_geom, mock.patch.object(
|
||||
traj_view.mujoco,
|
||||
"mjv_connector",
|
||||
) as connector:
|
||||
traj_view.apply_capsule_markers_to_scene(fake_scene, markers)
|
||||
|
||||
self.assertEqual(fake_scene.ngeom, 2)
|
||||
self.assertEqual(init_geom.call_count, 2)
|
||||
self.assertEqual(connector.call_count, 2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
387
tests/test_resnet_transformer_agent_wiring.py
Normal file
387
tests/test_resnet_transformer_agent_wiring.py
Normal file
@@ -0,0 +1,387 @@
|
||||
import contextlib
|
||||
import sys
|
||||
import types
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from hydra import compose, initialize_config_dir
|
||||
from hydra.errors import InstantiationException
|
||||
from hydra.core.global_hydra import GlobalHydra
|
||||
from hydra.utils import instantiate
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
|
||||
_REPO_ROOT = Path(__file__).resolve().parents[1]
|
||||
_CONFIG_DIR = str((_REPO_ROOT / 'roboimi/vla/conf').resolve())
|
||||
_EXPECTED_CAMERA_NAMES = ['r_vis', 'top', 'front']
|
||||
_MISSING = object()
|
||||
|
||||
|
||||
class _FakeScheduler:
|
||||
def __init__(self, num_train_timesteps=100, **kwargs):
|
||||
self.config = types.SimpleNamespace(num_train_timesteps=num_train_timesteps)
|
||||
self.timesteps = []
|
||||
|
||||
def add_noise(self, sample, noise, timestep):
|
||||
return sample + noise
|
||||
|
||||
def set_timesteps(self, num_inference_steps):
|
||||
self.timesteps = list(range(num_inference_steps - 1, -1, -1))
|
||||
|
||||
def step(self, noise_pred, timestep, sample):
|
||||
return types.SimpleNamespace(prev_sample=sample)
|
||||
|
||||
|
||||
class _IdentityCrop:
|
||||
def __init__(self, size):
|
||||
self.size = size
|
||||
|
||||
def __call__(self, x):
|
||||
return x
|
||||
|
||||
|
||||
class _FakeResNet(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.conv1 = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1)
|
||||
self.relu1 = torch.nn.ReLU()
|
||||
self.conv2 = torch.nn.Conv2d(8, 16, kernel_size=3, padding=1, stride=2)
|
||||
self.relu2 = torch.nn.ReLU()
|
||||
self.avgpool = torch.nn.AdaptiveAvgPool2d((1, 1))
|
||||
self.fc = torch.nn.Linear(16, 16)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.relu1(self.conv1(x))
|
||||
x = self.relu2(self.conv2(x))
|
||||
x = self.avgpool(x)
|
||||
x = torch.flatten(x, start_dim=1)
|
||||
return self.fc(x)
|
||||
|
||||
|
||||
class _FakeRearrange(torch.nn.Module):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x):
|
||||
return x
|
||||
|
||||
|
||||
class _CondCapturingHead(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.last_cond = None
|
||||
|
||||
def forward(self, sample, timestep, cond):
|
||||
self.last_cond = cond.detach().clone()
|
||||
return torch.zeros_like(sample)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _stub_optional_modules():
|
||||
previous_modules = {}
|
||||
|
||||
def inject(name, module):
|
||||
if name not in previous_modules:
|
||||
previous_modules[name] = sys.modules.get(name, _MISSING)
|
||||
sys.modules[name] = module
|
||||
|
||||
diffusers_module = types.ModuleType('diffusers')
|
||||
schedulers_module = types.ModuleType('diffusers.schedulers')
|
||||
ddpm_module = types.ModuleType('diffusers.schedulers.scheduling_ddpm')
|
||||
ddim_module = types.ModuleType('diffusers.schedulers.scheduling_ddim')
|
||||
ddpm_module.DDPMScheduler = _FakeScheduler
|
||||
ddim_module.DDIMScheduler = _FakeScheduler
|
||||
diffusers_module.DDPMScheduler = _FakeScheduler
|
||||
diffusers_module.DDIMScheduler = _FakeScheduler
|
||||
diffusers_module.schedulers = schedulers_module
|
||||
schedulers_module.scheduling_ddpm = ddpm_module
|
||||
schedulers_module.scheduling_ddim = ddim_module
|
||||
|
||||
torchvision_module = types.ModuleType('torchvision')
|
||||
models_module = types.ModuleType('torchvision.models')
|
||||
transforms_module = types.ModuleType('torchvision.transforms')
|
||||
models_module.resnet18 = lambda weights=None: _FakeResNet()
|
||||
transforms_module.CenterCrop = _IdentityCrop
|
||||
transforms_module.RandomCrop = _IdentityCrop
|
||||
torchvision_module.models = models_module
|
||||
torchvision_module.transforms = transforms_module
|
||||
|
||||
einops_module = types.ModuleType('einops')
|
||||
einops_module.rearrange = lambda x, *args, **kwargs: x
|
||||
einops_layers_module = types.ModuleType('einops.layers')
|
||||
einops_layers_torch_module = types.ModuleType('einops.layers.torch')
|
||||
einops_layers_torch_module.Rearrange = _FakeRearrange
|
||||
einops_module.layers = einops_layers_module
|
||||
einops_layers_module.torch = einops_layers_torch_module
|
||||
|
||||
try:
|
||||
inject('diffusers', diffusers_module)
|
||||
inject('diffusers.schedulers', schedulers_module)
|
||||
inject('diffusers.schedulers.scheduling_ddpm', ddpm_module)
|
||||
inject('diffusers.schedulers.scheduling_ddim', ddim_module)
|
||||
inject('torchvision', torchvision_module)
|
||||
inject('torchvision.models', models_module)
|
||||
inject('torchvision.transforms', transforms_module)
|
||||
inject('einops', einops_module)
|
||||
inject('einops.layers', einops_layers_module)
|
||||
inject('einops.layers.torch', einops_layers_torch_module)
|
||||
yield
|
||||
finally:
|
||||
for name, previous in reversed(list(previous_modules.items())):
|
||||
if previous is _MISSING:
|
||||
sys.modules.pop(name, None)
|
||||
else:
|
||||
sys.modules[name] = previous
|
||||
|
||||
|
||||
def _compose_cfg(overrides=None):
|
||||
if not OmegaConf.has_resolver('len'):
|
||||
OmegaConf.register_new_resolver('len', lambda x: len(x))
|
||||
|
||||
GlobalHydra.instance().clear()
|
||||
with initialize_config_dir(version_base=None, config_dir=_CONFIG_DIR):
|
||||
return compose(config_name='config', overrides=list(overrides or []))
|
||||
|
||||
|
||||
def _make_images(batch_size, obs_horizon, image_shape, per_camera_fill=None):
|
||||
channels, height, width = image_shape
|
||||
per_camera_fill = per_camera_fill or {
|
||||
'front': 30.0,
|
||||
'top': 20.0,
|
||||
'r_vis': 10.0,
|
||||
}
|
||||
return {
|
||||
name: torch.full(
|
||||
(batch_size, obs_horizon, channels, height, width),
|
||||
fill_value=fill_value,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
for name, fill_value in per_camera_fill.items()
|
||||
}
|
||||
|
||||
|
||||
def _patch_backbone_for_order_tracking(backbone):
|
||||
feature_dim = backbone.output_dim
|
||||
|
||||
def encode_mean(image_batch):
|
||||
mean_feature = image_batch.mean(dim=(1, 2, 3)).unsqueeze(-1)
|
||||
return mean_feature.repeat(1, feature_dim)
|
||||
|
||||
if backbone.use_separate_rgb_encoder_per_camera:
|
||||
for encoder in backbone.rgb_encoder:
|
||||
encoder.forward_single_image = encode_mean
|
||||
else:
|
||||
backbone.rgb_encoder.forward_single_image = encode_mean
|
||||
|
||||
|
||||
def _extract_camera_markers(cond, feature_dim, num_cams):
|
||||
camera_block = cond[0, 0, : feature_dim * num_cams].view(num_cams, feature_dim)
|
||||
return camera_block[:, 0]
|
||||
|
||||
|
||||
class ResNetTransformerAgentWiringTest(unittest.TestCase):
|
||||
def test_hydra_wiring_uses_required_three_camera_transformer_conditioning_in_agent_order_and_ignores_extra_keys(self):
|
||||
cfg = _compose_cfg(
|
||||
overrides=[
|
||||
'agent.vision_backbone.pretrained_backbone_weights=null',
|
||||
'agent.vision_backbone.input_shape=[3,16,16]',
|
||||
'agent.inference_steps=1',
|
||||
'agent.head.n_layer=1',
|
||||
'agent.head.n_cond_layers=0',
|
||||
'agent.head.n_emb=32',
|
||||
'agent.head.n_head=4',
|
||||
]
|
||||
)
|
||||
|
||||
self.assertEqual(list(cfg.data.camera_names), _EXPECTED_CAMERA_NAMES)
|
||||
self.assertEqual(list(cfg.eval.camera_names), _EXPECTED_CAMERA_NAMES)
|
||||
self.assertEqual(list(cfg.agent.camera_names), _EXPECTED_CAMERA_NAMES)
|
||||
self.assertEqual(list(cfg.agent.vision_backbone.camera_names), _EXPECTED_CAMERA_NAMES)
|
||||
self.assertEqual(cfg.agent.head_type, 'transformer')
|
||||
self.assertEqual(cfg.agent.num_cams, 3)
|
||||
self.assertTrue(cfg.agent.head.obs_as_cond)
|
||||
self.assertFalse(cfg.agent.head.causal_attn)
|
||||
|
||||
with _stub_optional_modules():
|
||||
agent = instantiate(cfg.agent)
|
||||
expected_cond_dim = agent.vision_encoder.output_dim * agent.num_cams + agent.obs_dim
|
||||
self.assertEqual(cfg.agent.head.cond_dim, expected_cond_dim)
|
||||
self.assertEqual(agent.per_step_cond_dim, expected_cond_dim)
|
||||
self.assertEqual(agent.noise_pred_net.cond_obs_emb.in_features, expected_cond_dim)
|
||||
|
||||
batch_size = 2
|
||||
image_shape = tuple(cfg.agent.vision_backbone.input_shape)
|
||||
images = _make_images(
|
||||
batch_size,
|
||||
cfg.agent.obs_horizon,
|
||||
image_shape,
|
||||
per_camera_fill={
|
||||
'front': 30.0,
|
||||
'top': 20.0,
|
||||
'r_vis': 10.0,
|
||||
'left_wrist': 99.0,
|
||||
},
|
||||
)
|
||||
proprioception = torch.randn(batch_size, cfg.agent.obs_horizon, cfg.agent.obs_dim)
|
||||
_patch_backbone_for_order_tracking(agent.vision_encoder)
|
||||
capturing_head = _CondCapturingHead()
|
||||
agent.noise_pred_net = capturing_head
|
||||
predicted_actions = agent.predict_action(images, proprioception)
|
||||
self.assertEqual(
|
||||
predicted_actions.shape,
|
||||
(batch_size, cfg.agent.pred_horizon, cfg.agent.action_dim),
|
||||
)
|
||||
self.assertIsNotNone(capturing_head.last_cond)
|
||||
self.assertEqual(capturing_head.last_cond.shape[-1], expected_cond_dim)
|
||||
camera_markers = _extract_camera_markers(
|
||||
capturing_head.last_cond,
|
||||
agent.vision_encoder.output_dim,
|
||||
agent.num_cams,
|
||||
)
|
||||
self.assertTrue(torch.allclose(camera_markers, torch.tensor([10.0, 20.0, 30.0])))
|
||||
|
||||
missing_images = dict(images)
|
||||
missing_images.pop('top')
|
||||
with self.assertRaisesRegex(ValueError, 'missing=.*top'):
|
||||
agent.predict_action(missing_images, proprioception)
|
||||
|
||||
def test_agent_rejects_conflicting_explicit_backbone_camera_names(self):
|
||||
cfg = _compose_cfg(
|
||||
overrides=[
|
||||
'agent.vision_backbone.pretrained_backbone_weights=null',
|
||||
'agent.vision_backbone.input_shape=[3,16,16]',
|
||||
]
|
||||
)
|
||||
cfg.agent.vision_backbone.camera_names = ['front', 'top', 'r_vis']
|
||||
|
||||
with _stub_optional_modules():
|
||||
with self.assertRaisesRegex(InstantiationException, 'camera_names'):
|
||||
instantiate(cfg.agent)
|
||||
|
||||
def test_backbone_uses_sorted_fallback_order_when_camera_names_unset(self):
|
||||
cfg = _compose_cfg(
|
||||
overrides=[
|
||||
'agent.vision_backbone.pretrained_backbone_weights=null',
|
||||
'agent.vision_backbone.input_shape=[3,16,16]',
|
||||
]
|
||||
)
|
||||
cfg.agent.vision_backbone.camera_names = None
|
||||
|
||||
with _stub_optional_modules():
|
||||
backbone = instantiate(cfg.agent.vision_backbone)
|
||||
_patch_backbone_for_order_tracking(backbone)
|
||||
images = _make_images(
|
||||
batch_size=1,
|
||||
obs_horizon=cfg.agent.obs_horizon,
|
||||
image_shape=tuple(cfg.agent.vision_backbone.input_shape),
|
||||
per_camera_fill={
|
||||
'top': 20.0,
|
||||
'front': 30.0,
|
||||
'r_vis': 10.0,
|
||||
},
|
||||
)
|
||||
ordered_features = backbone(images)
|
||||
camera_markers = _extract_camera_markers(
|
||||
ordered_features,
|
||||
backbone.output_dim,
|
||||
len(images),
|
||||
)
|
||||
self.assertTrue(torch.allclose(camera_markers, torch.tensor([30.0, 10.0, 20.0])))
|
||||
|
||||
def test_agent_queue_fallback_order_is_deterministic_when_camera_names_unset(self):
|
||||
cfg = _compose_cfg(
|
||||
overrides=[
|
||||
'agent.vision_backbone.pretrained_backbone_weights=null',
|
||||
'agent.vision_backbone.input_shape=[3,16,16]',
|
||||
]
|
||||
)
|
||||
cfg.agent.camera_names = None
|
||||
cfg.agent.vision_backbone.camera_names = None
|
||||
|
||||
with _stub_optional_modules():
|
||||
agent = instantiate(cfg.agent)
|
||||
observation = {
|
||||
'qpos': torch.randn(cfg.agent.obs_dim),
|
||||
'images': {
|
||||
'top': torch.full(tuple(cfg.agent.vision_backbone.input_shape), 20.0),
|
||||
'front': torch.full(tuple(cfg.agent.vision_backbone.input_shape), 30.0),
|
||||
'r_vis': torch.full(tuple(cfg.agent.vision_backbone.input_shape), 10.0),
|
||||
},
|
||||
}
|
||||
agent._populate_queues(observation)
|
||||
batch = agent._prepare_observation_batch()
|
||||
self.assertEqual(list(batch['images'].keys()), ['front', 'r_vis', 'top'])
|
||||
|
||||
def test_backbone_rejects_camera_count_mismatch_when_camera_names_unset(self):
|
||||
cfg = _compose_cfg(
|
||||
overrides=[
|
||||
'agent.vision_backbone.pretrained_backbone_weights=null',
|
||||
'agent.vision_backbone.input_shape=[3,16,16]',
|
||||
]
|
||||
)
|
||||
cfg.agent.vision_backbone.camera_names = None
|
||||
|
||||
with _stub_optional_modules():
|
||||
backbone = instantiate(cfg.agent.vision_backbone)
|
||||
images = _make_images(
|
||||
batch_size=1,
|
||||
obs_horizon=cfg.agent.obs_horizon,
|
||||
image_shape=tuple(cfg.agent.vision_backbone.input_shape),
|
||||
per_camera_fill={
|
||||
'front': 30.0,
|
||||
'r_vis': 10.0,
|
||||
},
|
||||
)
|
||||
with self.assertRaisesRegex(ValueError, 'num_cameras'):
|
||||
backbone(images)
|
||||
|
||||
def test_agent_rejects_camera_count_mismatch_when_camera_names_unset(self):
|
||||
cfg = _compose_cfg(
|
||||
overrides=[
|
||||
'agent.vision_backbone.pretrained_backbone_weights=null',
|
||||
'agent.vision_backbone.input_shape=[3,16,16]',
|
||||
'agent.inference_steps=1',
|
||||
'agent.head.n_layer=1',
|
||||
'agent.head.n_cond_layers=0',
|
||||
'agent.head.n_emb=32',
|
||||
'agent.head.n_head=4',
|
||||
]
|
||||
)
|
||||
cfg.agent.camera_names = None
|
||||
cfg.agent.vision_backbone.camera_names = None
|
||||
|
||||
with _stub_optional_modules():
|
||||
agent = instantiate(cfg.agent)
|
||||
images = _make_images(
|
||||
batch_size=1,
|
||||
obs_horizon=cfg.agent.obs_horizon,
|
||||
image_shape=tuple(cfg.agent.vision_backbone.input_shape),
|
||||
per_camera_fill={
|
||||
'front': 30.0,
|
||||
'r_vis': 10.0,
|
||||
},
|
||||
)
|
||||
proprioception = torch.randn(1, cfg.agent.obs_horizon, cfg.agent.obs_dim)
|
||||
with self.assertRaisesRegex(ValueError, 'num_cams'):
|
||||
agent.predict_action(images, proprioception)
|
||||
|
||||
def test_agent_rejects_num_cams_mismatch_with_backbone_when_camera_names_unset(self):
|
||||
cfg = _compose_cfg(
|
||||
overrides=[
|
||||
'agent.vision_backbone.pretrained_backbone_weights=null',
|
||||
'agent.vision_backbone.input_shape=[3,16,16]',
|
||||
]
|
||||
)
|
||||
cfg.agent.camera_names = None
|
||||
cfg.agent.vision_backbone.camera_names = None
|
||||
cfg.agent.num_cams = 2
|
||||
cfg.agent.vision_backbone.num_cameras = 3
|
||||
|
||||
with _stub_optional_modules():
|
||||
with self.assertRaisesRegex(InstantiationException, 'num_cams'):
|
||||
instantiate(cfg.agent)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
63
tests/test_robot_asset_paths.py
Normal file
63
tests/test_robot_asset_paths.py
Normal file
@@ -0,0 +1,63 @@
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from unittest import mock
|
||||
|
||||
from roboimi.assets.robots.diana_med import BiDianaMed
|
||||
|
||||
|
||||
class _FakeKDL:
|
||||
init_calls = []
|
||||
reset_calls = []
|
||||
|
||||
def __init__(self, urdf_path):
|
||||
self.__class__.init_calls.append(urdf_path)
|
||||
|
||||
def resetChain(self, base, end):
|
||||
self.__class__.reset_calls.append((base, end))
|
||||
|
||||
|
||||
class RobotAssetPathResolutionTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
_FakeKDL.init_calls = []
|
||||
_FakeKDL.reset_calls = []
|
||||
|
||||
def test_bidianamed_resolves_robot_asset_paths_independent_of_cwd(self):
|
||||
repo_root = Path(__file__).resolve().parents[1]
|
||||
expected_xml = repo_root / 'roboimi/assets/models/manipulators/DianaMed/bi_diana_transfer_ee.xml'
|
||||
expected_urdf = repo_root / 'roboimi/assets/models/manipulators/DianaMed/DualDianaMed.urdf'
|
||||
xml_calls = []
|
||||
|
||||
def fake_from_xml_path(*, filename, assets=None):
|
||||
xml_calls.append((filename, assets))
|
||||
return object()
|
||||
|
||||
with tempfile.TemporaryDirectory() as tempdir:
|
||||
previous_cwd = os.getcwd()
|
||||
try:
|
||||
os.chdir(tempdir)
|
||||
with mock.patch(
|
||||
'roboimi.assets.robots.arm_base.mujoco.MjModel.from_xml_path',
|
||||
side_effect=fake_from_xml_path,
|
||||
), mock.patch(
|
||||
'roboimi.assets.robots.arm_base.mujoco.MjData',
|
||||
return_value=object(),
|
||||
), mock.patch(
|
||||
'roboimi.assets.robots.arm_base.KDL_utils',
|
||||
_FakeKDL,
|
||||
):
|
||||
BiDianaMed()
|
||||
finally:
|
||||
os.chdir(previous_cwd)
|
||||
|
||||
self.assertEqual(len(xml_calls), 1)
|
||||
self.assertEqual(Path(xml_calls[0][0]), expected_xml)
|
||||
self.assertTrue(Path(xml_calls[0][0]).is_absolute())
|
||||
self.assertGreaterEqual(len(_FakeKDL.init_calls), 2)
|
||||
self.assertEqual({Path(path) for path in _FakeKDL.init_calls}, {expected_urdf})
|
||||
self.assertTrue(all(Path(path).is_absolute() for path in _FakeKDL.init_calls))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
58
tests/test_simple_robot_dataset_image_loading.py
Normal file
58
tests/test_simple_robot_dataset_image_loading.py
Normal file
@@ -0,0 +1,58 @@
|
||||
import sys
|
||||
import tempfile
|
||||
import types
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from unittest import mock
|
||||
|
||||
import h5py
|
||||
import numpy as np
|
||||
|
||||
from roboimi.vla.data.simpe_robot_dataset import SimpleRobotDataset
|
||||
|
||||
|
||||
class SimpleRobotDatasetImageLoadingTest(unittest.TestCase):
|
||||
def _write_episode(self, dataset_dir: Path) -> None:
|
||||
episode_path = dataset_dir / "episode_0.hdf5"
|
||||
with h5py.File(episode_path, "w") as root:
|
||||
root.create_dataset("action", data=np.arange(8, dtype=np.float32).reshape(4, 2))
|
||||
root.create_dataset(
|
||||
"observations/qpos",
|
||||
data=np.arange(16, dtype=np.float32).reshape(4, 4),
|
||||
)
|
||||
root.create_dataset("task", data=np.array([b"sim_transfer"]))
|
||||
root.create_dataset(
|
||||
"observations/images/front",
|
||||
data=np.arange(4 * 8 * 8 * 3, dtype=np.uint8).reshape(4, 8, 8, 3),
|
||||
)
|
||||
|
||||
def test_getitem_only_resizes_observation_horizon_images(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
dataset_dir = Path(tmpdir)
|
||||
self._write_episode(dataset_dir)
|
||||
dataset = SimpleRobotDataset(
|
||||
dataset_dir,
|
||||
obs_horizon=2,
|
||||
pred_horizon=3,
|
||||
camera_names=["front"],
|
||||
)
|
||||
|
||||
resize_calls = []
|
||||
|
||||
def fake_resize(image, size, interpolation=None):
|
||||
resize_calls.append(
|
||||
{
|
||||
"shape": tuple(image.shape),
|
||||
"size": size,
|
||||
"interpolation": interpolation,
|
||||
}
|
||||
)
|
||||
return image
|
||||
|
||||
fake_cv2 = types.SimpleNamespace(INTER_LINEAR=1, resize=fake_resize)
|
||||
|
||||
with mock.patch.dict(sys.modules, {"cv2": fake_cv2}):
|
||||
sample = dataset[1]
|
||||
|
||||
self.assertEqual(len(resize_calls), 2)
|
||||
self.assertEqual(tuple(sample["observation.front"].shape), (2, 3, 8, 8))
|
||||
79
tests/test_streaming_episode_writer.py
Normal file
79
tests/test_streaming_episode_writer.py
Normal file
@@ -0,0 +1,79 @@
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
import h5py
|
||||
import numpy as np
|
||||
|
||||
from roboimi.utils.streaming_episode_writer import StreamingEpisodeWriter
|
||||
|
||||
|
||||
class StreamingEpisodeWriterTest(unittest.TestCase):
|
||||
def test_commit_persists_raw_action_and_resized_images(self):
|
||||
camera_names = ["angle", "r_vis", "top", "front"]
|
||||
raw_action_0 = np.arange(16, dtype=np.float32)
|
||||
raw_action_1 = np.arange(16, dtype=np.float32) + 100.0
|
||||
qpos_0 = np.arange(16, dtype=np.float32) + 200.0
|
||||
qpos_1 = np.arange(16, dtype=np.float32) + 300.0
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
episode_path = Path(tmpdir) / "episode_0.hdf5"
|
||||
writer = StreamingEpisodeWriter(
|
||||
dataset_path=episode_path,
|
||||
max_timesteps=2,
|
||||
camera_names=camera_names,
|
||||
image_size=(256, 256),
|
||||
)
|
||||
|
||||
writer.append(
|
||||
qpos=qpos_0,
|
||||
action=raw_action_0,
|
||||
images={
|
||||
cam: np.full((480, 640, 3), fill_value=idx + 1, dtype=np.uint8)
|
||||
for idx, cam in enumerate(camera_names)
|
||||
},
|
||||
)
|
||||
writer.append(
|
||||
qpos=qpos_1,
|
||||
action=raw_action_1,
|
||||
images={
|
||||
cam: np.full((480, 640, 3), fill_value=idx + 11, dtype=np.uint8)
|
||||
for idx, cam in enumerate(camera_names)
|
||||
},
|
||||
)
|
||||
writer.commit()
|
||||
|
||||
self.assertTrue(episode_path.exists())
|
||||
self.assertFalse(Path(str(episode_path) + ".tmp").exists())
|
||||
|
||||
with h5py.File(episode_path, "r") as root:
|
||||
self.assertEqual(root["action"].shape, (2, 16))
|
||||
self.assertEqual(root["observations/qpos"].shape, (2, 16))
|
||||
np.testing.assert_allclose(root["action"][0], raw_action_0)
|
||||
np.testing.assert_allclose(root["action"][1], raw_action_1)
|
||||
np.testing.assert_allclose(root["observations/qpos"][0], qpos_0)
|
||||
np.testing.assert_allclose(root["observations/qpos"][1], qpos_1)
|
||||
for idx, cam_name in enumerate(camera_names):
|
||||
dataset = root[f"observations/images/{cam_name}"]
|
||||
self.assertEqual(dataset.shape, (2, 256, 256, 3))
|
||||
self.assertEqual(dataset.dtype, np.uint8)
|
||||
self.assertTrue(np.all(dataset[0] == idx + 1))
|
||||
self.assertTrue(np.all(dataset[1] == idx + 11))
|
||||
|
||||
def test_discard_removes_temporary_file(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
episode_path = Path(tmpdir) / "episode_0.hdf5"
|
||||
writer = StreamingEpisodeWriter(
|
||||
dataset_path=episode_path,
|
||||
max_timesteps=1,
|
||||
camera_names=["angle", "r_vis", "top", "front"],
|
||||
image_size=(256, 256),
|
||||
)
|
||||
writer.discard()
|
||||
|
||||
self.assertFalse(episode_path.exists())
|
||||
self.assertFalse(Path(str(episode_path) + ".tmp").exists())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
779
tests/test_train_vla_rollout_validation.py
Normal file
779
tests/test_train_vla_rollout_validation.py
Normal file
@@ -0,0 +1,779 @@
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
from unittest import mock
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from omegaconf import OmegaConf
|
||||
from torch import nn
|
||||
|
||||
from roboimi.demos.vla_scripts import eval_vla, train_vla
|
||||
|
||||
|
||||
class _FakeDataset:
|
||||
def __len__(self):
|
||||
return 4
|
||||
|
||||
|
||||
class _FakeLoader:
|
||||
def __init__(self, batch, length=1):
|
||||
self._batches = [batch] * length
|
||||
|
||||
def __len__(self):
|
||||
return len(self._batches)
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self._batches)
|
||||
|
||||
|
||||
class _FakeOptimizer:
|
||||
def __init__(self, lr=1e-3):
|
||||
self.param_groups = [{'lr': lr}]
|
||||
|
||||
def zero_grad(self):
|
||||
return None
|
||||
|
||||
def step(self):
|
||||
return None
|
||||
|
||||
def state_dict(self):
|
||||
return {}
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
del state_dict
|
||||
return None
|
||||
|
||||
|
||||
class _FakeScheduler:
|
||||
def __init__(self):
|
||||
self.step_calls = 0
|
||||
|
||||
def step(self):
|
||||
self.step_calls += 1
|
||||
|
||||
def state_dict(self):
|
||||
return {}
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
del state_dict
|
||||
return None
|
||||
|
||||
|
||||
class _FakeProgressBar:
|
||||
def __init__(self, iterable):
|
||||
self._items = list(iterable)
|
||||
self.postfix_calls = []
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self._items)
|
||||
|
||||
def set_postfix(self, values):
|
||||
self.postfix_calls.append(values)
|
||||
|
||||
|
||||
class _FakeAgent(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.tensor(0.0))
|
||||
|
||||
def to(self, device):
|
||||
del device
|
||||
return self
|
||||
|
||||
def compute_loss(self, agent_input):
|
||||
del agent_input
|
||||
return (self.weight - torch.tensor(0.5)).pow(2)
|
||||
|
||||
def get_normalization_stats(self):
|
||||
return {}
|
||||
|
||||
|
||||
class _SequentialLossAgent(nn.Module):
|
||||
def __init__(self, losses):
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.tensor(0.0))
|
||||
self._losses = list(losses)
|
||||
self._index = 0
|
||||
|
||||
def to(self, device):
|
||||
del device
|
||||
return self
|
||||
|
||||
def compute_loss(self, agent_input):
|
||||
del agent_input
|
||||
loss_value = self._losses[self._index]
|
||||
self._index += 1
|
||||
return (self.weight * 0) + torch.tensor(float(loss_value))
|
||||
|
||||
def get_normalization_stats(self):
|
||||
return {}
|
||||
|
||||
|
||||
class _FakeEvalAgent:
|
||||
def __init__(self):
|
||||
self.reset_calls = 0
|
||||
|
||||
def eval(self):
|
||||
return self
|
||||
|
||||
def to(self, device):
|
||||
del device
|
||||
return self
|
||||
|
||||
def reset(self):
|
||||
self.reset_calls += 1
|
||||
|
||||
def select_action(self, observation):
|
||||
del observation
|
||||
return torch.zeros(2)
|
||||
|
||||
|
||||
class _FakeEvalEnv:
|
||||
def reset(self, box_pos):
|
||||
self.box_pos = box_pos
|
||||
|
||||
def _get_image_obs(self):
|
||||
return {
|
||||
'images': {
|
||||
'front': np.zeros((8, 8, 3), dtype=np.uint8),
|
||||
}
|
||||
}
|
||||
|
||||
def _get_qpos_obs(self):
|
||||
return {'qpos': np.zeros(4, dtype=np.float32)}
|
||||
|
||||
def render(self):
|
||||
raise AssertionError('render should not be called in this helper delegation test')
|
||||
|
||||
|
||||
class TrainVLARolloutValidationTest(unittest.TestCase):
|
||||
def test_default_train_config_uses_full_dataset_and_epoch_rollout_validation(self):
|
||||
cfg = OmegaConf.load(Path('roboimi/vla/conf/config.yaml'))
|
||||
|
||||
self.assertEqual(cfg.train.val_split, 0.0)
|
||||
self.assertGreater(cfg.train.batch_size, 8)
|
||||
self.assertGreater(float(cfg.train.lr), 5e-5)
|
||||
self.assertGreater(cfg.train.num_workers, 8)
|
||||
self.assertEqual(cfg.train.rollout_val_freq_epochs, 50)
|
||||
|
||||
def test_eval_main_delegates_to_plain_run_eval_helper(self):
|
||||
cfg = OmegaConf.create(
|
||||
{
|
||||
'agent': {},
|
||||
'eval': {
|
||||
'ckpt_path': 'checkpoints/vla_model_step_1.pt',
|
||||
'num_episodes': 1,
|
||||
'max_timesteps': 1,
|
||||
'device': 'cpu',
|
||||
'task_name': 'sim_transfer',
|
||||
'camera_names': ['front'],
|
||||
'use_smoothing': False,
|
||||
'smooth_alpha': 0.3,
|
||||
'verbose_action': False,
|
||||
'headless': True,
|
||||
},
|
||||
}
|
||||
)
|
||||
run_eval_mock = mock.Mock()
|
||||
|
||||
with mock.patch.object(eval_vla, '_run_eval', run_eval_mock, create=True), \
|
||||
mock.patch.object(eval_vla, 'load_checkpoint', return_value=(_FakeEvalAgent(), None)), \
|
||||
mock.patch.object(eval_vla, 'make_sim_env', return_value=_FakeEvalEnv()), \
|
||||
mock.patch.object(eval_vla, 'sample_transfer_pose', return_value=np.zeros(3)), \
|
||||
mock.patch.object(eval_vla, 'execute_policy_action'), \
|
||||
mock.patch.object(eval_vla, 'tqdm', side_effect=lambda iterable, **kwargs: iterable):
|
||||
eval_vla.main.__wrapped__(cfg)
|
||||
|
||||
run_eval_mock.assert_called_once_with(cfg)
|
||||
|
||||
def test_run_training_rollout_validation_runs_every_50_epochs_and_uses_avg_reward_metric(self):
|
||||
cfg = OmegaConf.create(
|
||||
{
|
||||
'train': {
|
||||
'device': 'cpu',
|
||||
'batch_size': 1,
|
||||
'num_workers': 0,
|
||||
'val_split': 0.0,
|
||||
'seed': 0,
|
||||
'lr': 1e-3,
|
||||
'max_steps': 100,
|
||||
'log_freq': 1,
|
||||
'save_freq': 1000,
|
||||
'warmup_steps': 1,
|
||||
'scheduler_type': 'constant',
|
||||
'min_lr': 0.0,
|
||||
'grad_clip': 1.0,
|
||||
'weight_decay': 0.0,
|
||||
'pretrained_ckpt': None,
|
||||
'resume_ckpt': None,
|
||||
'use_swanlab': False,
|
||||
'rollout_val_freq_epochs': 50,
|
||||
'rollout_num_episodes': 3,
|
||||
},
|
||||
'data': {
|
||||
'camera_names': ['front'],
|
||||
},
|
||||
'agent': {
|
||||
'_target_': 'fake.agent',
|
||||
},
|
||||
'eval': {
|
||||
'ckpt_path': 'unused.pt',
|
||||
'num_episodes': 99,
|
||||
'max_timesteps': 1,
|
||||
'device': 'cpu',
|
||||
'task_name': 'sim_transfer',
|
||||
'camera_names': ['front'],
|
||||
'use_smoothing': False,
|
||||
'smooth_alpha': 0.3,
|
||||
'verbose_action': False,
|
||||
'headless': False,
|
||||
},
|
||||
}
|
||||
)
|
||||
agent = _FakeAgent()
|
||||
rollout_mock = mock.Mock(side_effect=[{'avg_reward': 2.0}, {'avg_reward': 1.0}])
|
||||
swanlab_log_mock = mock.Mock()
|
||||
saved_checkpoints = []
|
||||
|
||||
def fake_instantiate(config_node, **_kwargs):
|
||||
if config_node is cfg.data:
|
||||
return _FakeDataset()
|
||||
if config_node is cfg.agent:
|
||||
return agent
|
||||
raise AssertionError(f'unexpected instantiate config: {config_node!r}')
|
||||
|
||||
def fake_dataloader(_dataset, *, shuffle, **_kwargs):
|
||||
del shuffle, _kwargs
|
||||
return _FakeLoader(
|
||||
{
|
||||
'observation.front': torch.zeros(1, 3, 2, 2),
|
||||
'observation.state': torch.zeros(1, 4),
|
||||
'action': torch.zeros(1, 2),
|
||||
'action_is_pad': torch.zeros(1, 1, dtype=torch.bool),
|
||||
},
|
||||
length=1,
|
||||
)
|
||||
|
||||
def fake_torch_save(payload, path):
|
||||
saved_checkpoints.append((str(path), deepcopy(payload)))
|
||||
return None
|
||||
|
||||
with tempfile.TemporaryDirectory() as tempdir:
|
||||
previous_cwd = os.getcwd()
|
||||
try:
|
||||
os.chdir(tempdir)
|
||||
with mock.patch.object(train_vla, 'instantiate', side_effect=fake_instantiate), \
|
||||
mock.patch.object(train_vla, 'DataLoader', side_effect=fake_dataloader), \
|
||||
mock.patch.object(train_vla, 'build_training_optimizer', return_value=_FakeOptimizer(cfg.train.lr)), \
|
||||
mock.patch.object(train_vla, 'get_lr_schedule_with_warmup', return_value=_FakeScheduler()), \
|
||||
mock.patch.object(train_vla, 'tqdm', side_effect=lambda iterable, **kwargs: _FakeProgressBar(iterable)), \
|
||||
mock.patch.object(train_vla, '_log_to_swanlab', swanlab_log_mock), \
|
||||
mock.patch.object(train_vla.torch, 'save', side_effect=fake_torch_save), \
|
||||
mock.patch.object(eval_vla, '_run_eval', rollout_mock, create=True), \
|
||||
mock.patch.object(eval_vla.main, '__wrapped__', side_effect=AssertionError('training hook should call eval_vla._run_eval')):
|
||||
train_vla._run_training(cfg)
|
||||
finally:
|
||||
os.chdir(previous_cwd)
|
||||
|
||||
self.assertEqual(rollout_mock.call_count, 2)
|
||||
first_rollout_cfg = rollout_mock.call_args_list[0].args[0]
|
||||
second_rollout_cfg = rollout_mock.call_args_list[1].args[0]
|
||||
self.assertEqual(first_rollout_cfg.eval.ckpt_path, 'checkpoints/vla_model_step_49.pt')
|
||||
self.assertEqual(second_rollout_cfg.eval.ckpt_path, 'checkpoints/vla_model_step_99.pt')
|
||||
self.assertEqual(first_rollout_cfg.eval.num_episodes, 3)
|
||||
self.assertTrue(first_rollout_cfg.eval.headless)
|
||||
self.assertEqual(first_rollout_cfg.eval.device, 'cpu')
|
||||
self.assertFalse(first_rollout_cfg.eval.verbose_action)
|
||||
self.assertEqual(cfg.eval.ckpt_path, 'unused.pt')
|
||||
self.assertEqual(cfg.eval.num_episodes, 99)
|
||||
self.assertFalse(cfg.eval.headless)
|
||||
self.assertEqual(cfg.eval.device, 'cpu')
|
||||
self.assertFalse(cfg.eval.verbose_action)
|
||||
|
||||
rollout_reward_logs = [
|
||||
call.args[1]['rollout/avg_reward']
|
||||
for call in swanlab_log_mock.call_args_list
|
||||
if len(call.args) >= 2 and 'rollout/avg_reward' in call.args[1]
|
||||
]
|
||||
self.assertEqual(rollout_reward_logs, [2.0, 1.0])
|
||||
|
||||
best_model_saves = [
|
||||
payload for path, payload in saved_checkpoints
|
||||
if path.endswith('checkpoints/vla_model_best.pt')
|
||||
]
|
||||
self.assertEqual(len(best_model_saves), 1)
|
||||
self.assertEqual(best_model_saves[0]['rollout_avg_reward'], 2.0)
|
||||
|
||||
def test_run_training_keeps_loss_based_best_checkpoint_until_first_rollout_metric_exists(self):
|
||||
cfg = OmegaConf.create(
|
||||
{
|
||||
'train': {
|
||||
'device': 'cpu',
|
||||
'batch_size': 1,
|
||||
'num_workers': 0,
|
||||
'val_split': 0.0,
|
||||
'seed': 0,
|
||||
'lr': 1e-3,
|
||||
'max_steps': 5,
|
||||
'log_freq': 1,
|
||||
'save_freq': 2,
|
||||
'warmup_steps': 1,
|
||||
'scheduler_type': 'constant',
|
||||
'min_lr': 0.0,
|
||||
'grad_clip': 1.0,
|
||||
'weight_decay': 0.0,
|
||||
'pretrained_ckpt': None,
|
||||
'resume_ckpt': None,
|
||||
'use_swanlab': False,
|
||||
'rollout_val_freq_epochs': 50,
|
||||
'rollout_num_episodes': 3,
|
||||
},
|
||||
'data': {
|
||||
'camera_names': ['front'],
|
||||
},
|
||||
'agent': {
|
||||
'_target_': 'fake.agent',
|
||||
},
|
||||
'eval': {
|
||||
'ckpt_path': 'unused.pt',
|
||||
'num_episodes': 99,
|
||||
'max_timesteps': 1,
|
||||
'device': 'cpu',
|
||||
'task_name': 'sim_transfer',
|
||||
'camera_names': ['front'],
|
||||
'use_smoothing': False,
|
||||
'smooth_alpha': 0.3,
|
||||
'verbose_action': False,
|
||||
'headless': False,
|
||||
},
|
||||
}
|
||||
)
|
||||
saved_checkpoints = []
|
||||
rollout_mock = mock.Mock()
|
||||
|
||||
def fake_instantiate(config_node, **_kwargs):
|
||||
if config_node is cfg.data:
|
||||
return _FakeDataset()
|
||||
if config_node is cfg.agent:
|
||||
return _FakeAgent()
|
||||
raise AssertionError(f'unexpected instantiate config: {config_node!r}')
|
||||
|
||||
def fake_dataloader(_dataset, *, shuffle, **_kwargs):
|
||||
del shuffle, _kwargs
|
||||
return _FakeLoader(
|
||||
{
|
||||
'observation.front': torch.zeros(1, 3, 2, 2),
|
||||
'observation.state': torch.zeros(1, 4),
|
||||
'action': torch.zeros(1, 2),
|
||||
'action_is_pad': torch.zeros(1, 1, dtype=torch.bool),
|
||||
},
|
||||
length=5,
|
||||
)
|
||||
|
||||
def fake_torch_save(payload, path):
|
||||
saved_checkpoints.append((str(path), deepcopy(payload)))
|
||||
return None
|
||||
|
||||
with tempfile.TemporaryDirectory() as tempdir:
|
||||
previous_cwd = os.getcwd()
|
||||
try:
|
||||
os.chdir(tempdir)
|
||||
with mock.patch.object(train_vla, 'instantiate', side_effect=fake_instantiate), \
|
||||
mock.patch.object(train_vla, 'DataLoader', side_effect=fake_dataloader), \
|
||||
mock.patch.object(train_vla, 'build_training_optimizer', return_value=_FakeOptimizer(cfg.train.lr)), \
|
||||
mock.patch.object(train_vla, 'get_lr_schedule_with_warmup', return_value=_FakeScheduler()), \
|
||||
mock.patch.object(train_vla, 'tqdm', side_effect=lambda iterable, **kwargs: _FakeProgressBar(iterable)), \
|
||||
mock.patch.object(train_vla.torch, 'save', side_effect=fake_torch_save), \
|
||||
mock.patch.object(eval_vla, '_run_eval', rollout_mock, create=True):
|
||||
train_vla._run_training(cfg)
|
||||
finally:
|
||||
os.chdir(previous_cwd)
|
||||
|
||||
self.assertEqual(rollout_mock.call_count, 0)
|
||||
best_model_saves = [
|
||||
payload for path, payload in saved_checkpoints
|
||||
if path.endswith('checkpoints/vla_model_best.pt')
|
||||
]
|
||||
self.assertEqual(len(best_model_saves), 1)
|
||||
self.assertIsNone(best_model_saves[0]['rollout_avg_reward'])
|
||||
|
||||
def test_run_training_disables_drop_last_when_train_set_is_smaller_than_batch_size(self):
|
||||
cfg = OmegaConf.create(
|
||||
{
|
||||
'train': {
|
||||
'device': 'cpu',
|
||||
'batch_size': 8,
|
||||
'num_workers': 0,
|
||||
'val_split': 0.0,
|
||||
'seed': 0,
|
||||
'lr': 1e-3,
|
||||
'max_steps': 1,
|
||||
'log_freq': 1,
|
||||
'save_freq': 10,
|
||||
'warmup_steps': 1,
|
||||
'scheduler_type': 'constant',
|
||||
'min_lr': 0.0,
|
||||
'grad_clip': 1.0,
|
||||
'weight_decay': 0.0,
|
||||
'pretrained_ckpt': None,
|
||||
'resume_ckpt': None,
|
||||
'use_swanlab': False,
|
||||
'rollout_val_freq_epochs': 50,
|
||||
'rollout_num_episodes': 3,
|
||||
},
|
||||
'data': {
|
||||
'camera_names': ['front'],
|
||||
},
|
||||
'agent': {
|
||||
'_target_': 'fake.agent',
|
||||
},
|
||||
'eval': {
|
||||
'ckpt_path': 'unused.pt',
|
||||
'num_episodes': 99,
|
||||
'max_timesteps': 1,
|
||||
'device': 'cpu',
|
||||
'task_name': 'sim_transfer',
|
||||
'camera_names': ['front'],
|
||||
'use_smoothing': False,
|
||||
'smooth_alpha': 0.3,
|
||||
'verbose_action': False,
|
||||
'headless': False,
|
||||
},
|
||||
}
|
||||
)
|
||||
dataloader_calls = []
|
||||
|
||||
def fake_instantiate(config_node, **_kwargs):
|
||||
if config_node is cfg.data:
|
||||
return _FakeDataset()
|
||||
if config_node is cfg.agent:
|
||||
return _FakeAgent()
|
||||
raise AssertionError(f'unexpected instantiate config: {config_node!r}')
|
||||
|
||||
def fake_dataloader(dataset, *, shuffle, drop_last, **_kwargs):
|
||||
dataloader_calls.append({
|
||||
'shuffle': shuffle,
|
||||
'drop_last': drop_last,
|
||||
'dataset_len': len(dataset),
|
||||
})
|
||||
return _FakeLoader(
|
||||
{
|
||||
'observation.front': torch.zeros(1, 3, 2, 2),
|
||||
'observation.state': torch.zeros(1, 4),
|
||||
'action': torch.zeros(1, 2),
|
||||
'action_is_pad': torch.zeros(1, 1, dtype=torch.bool),
|
||||
},
|
||||
length=1,
|
||||
)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tempdir:
|
||||
previous_cwd = os.getcwd()
|
||||
try:
|
||||
os.chdir(tempdir)
|
||||
with mock.patch.object(train_vla, 'instantiate', side_effect=fake_instantiate), \
|
||||
mock.patch.object(train_vla, 'DataLoader', side_effect=fake_dataloader), \
|
||||
mock.patch.object(train_vla, 'build_training_optimizer', return_value=_FakeOptimizer(cfg.train.lr)), \
|
||||
mock.patch.object(train_vla, 'get_lr_schedule_with_warmup', return_value=_FakeScheduler()), \
|
||||
mock.patch.object(train_vla, 'tqdm', side_effect=lambda iterable, **kwargs: _FakeProgressBar(iterable)), \
|
||||
mock.patch.object(train_vla.torch, 'save', return_value=None):
|
||||
train_vla._run_training(cfg)
|
||||
finally:
|
||||
os.chdir(previous_cwd)
|
||||
|
||||
train_loader_calls = [call for call in dataloader_calls if call['shuffle']]
|
||||
self.assertEqual(len(train_loader_calls), 1)
|
||||
self.assertFalse(train_loader_calls[0]['drop_last'])
|
||||
|
||||
def test_run_training_disables_persistent_workers_for_train_and_val_loaders(self):
|
||||
cfg = OmegaConf.create(
|
||||
{
|
||||
'train': {
|
||||
'device': 'cpu',
|
||||
'batch_size': 2,
|
||||
'num_workers': 2,
|
||||
'val_split': 0.25,
|
||||
'seed': 0,
|
||||
'lr': 1e-3,
|
||||
'max_steps': 1,
|
||||
'log_freq': 1,
|
||||
'save_freq': 10,
|
||||
'warmup_steps': 1,
|
||||
'scheduler_type': 'constant',
|
||||
'min_lr': 0.0,
|
||||
'grad_clip': 1.0,
|
||||
'weight_decay': 0.0,
|
||||
'pretrained_ckpt': None,
|
||||
'resume_ckpt': None,
|
||||
'use_swanlab': False,
|
||||
'rollout_val_freq_epochs': 50,
|
||||
'rollout_num_episodes': 3,
|
||||
},
|
||||
'data': {
|
||||
'camera_names': ['front'],
|
||||
},
|
||||
'agent': {
|
||||
'_target_': 'fake.agent',
|
||||
},
|
||||
'eval': {
|
||||
'ckpt_path': 'unused.pt',
|
||||
'num_episodes': 99,
|
||||
'max_timesteps': 1,
|
||||
'device': 'cpu',
|
||||
'task_name': 'sim_transfer',
|
||||
'camera_names': ['front'],
|
||||
'use_smoothing': False,
|
||||
'smooth_alpha': 0.3,
|
||||
'verbose_action': False,
|
||||
'headless': False,
|
||||
},
|
||||
}
|
||||
)
|
||||
dataloader_calls = []
|
||||
|
||||
def fake_instantiate(config_node, **_kwargs):
|
||||
if config_node is cfg.data:
|
||||
return _FakeDataset()
|
||||
if config_node is cfg.agent:
|
||||
return _FakeAgent()
|
||||
raise AssertionError(f'unexpected instantiate config: {config_node!r}')
|
||||
|
||||
def fake_dataloader(_dataset, *, shuffle, persistent_workers, num_workers, **_kwargs):
|
||||
dataloader_calls.append({
|
||||
'shuffle': shuffle,
|
||||
'num_workers': num_workers,
|
||||
'persistent_workers': persistent_workers,
|
||||
})
|
||||
return _FakeLoader(
|
||||
{
|
||||
'observation.front': torch.zeros(1, 3, 2, 2),
|
||||
'observation.state': torch.zeros(1, 4),
|
||||
'action': torch.zeros(1, 2),
|
||||
'action_is_pad': torch.zeros(1, 1, dtype=torch.bool),
|
||||
},
|
||||
length=1,
|
||||
)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tempdir:
|
||||
previous_cwd = os.getcwd()
|
||||
try:
|
||||
os.chdir(tempdir)
|
||||
with mock.patch.object(train_vla, 'instantiate', side_effect=fake_instantiate), \
|
||||
mock.patch.object(train_vla, 'DataLoader', side_effect=fake_dataloader), \
|
||||
mock.patch.object(train_vla, 'build_training_optimizer', return_value=_FakeOptimizer(cfg.train.lr)), \
|
||||
mock.patch.object(train_vla, 'get_lr_schedule_with_warmup', return_value=_FakeScheduler()), \
|
||||
mock.patch.object(train_vla, 'tqdm', side_effect=lambda iterable, **kwargs: _FakeProgressBar(iterable)), \
|
||||
mock.patch.object(train_vla.torch, 'save', return_value=None):
|
||||
train_vla._run_training(cfg)
|
||||
finally:
|
||||
os.chdir(previous_cwd)
|
||||
|
||||
self.assertEqual(len(dataloader_calls), 2)
|
||||
self.assertEqual([call['shuffle'] for call in dataloader_calls], [True, False])
|
||||
self.assertTrue(all(call['num_workers'] == 2 for call in dataloader_calls))
|
||||
self.assertTrue(all(call['persistent_workers'] is False for call in dataloader_calls))
|
||||
|
||||
def test_run_training_uses_loss_best_until_first_rollout_then_prefers_rollout_reward(self):
|
||||
cfg = OmegaConf.create(
|
||||
{
|
||||
'train': {
|
||||
'device': 'cpu',
|
||||
'batch_size': 1,
|
||||
'num_workers': 0,
|
||||
'val_split': 0.0,
|
||||
'seed': 0,
|
||||
'lr': 1e-3,
|
||||
'max_steps': 6,
|
||||
'log_freq': 1,
|
||||
'save_freq': 1,
|
||||
'warmup_steps': 1,
|
||||
'scheduler_type': 'constant',
|
||||
'min_lr': 0.0,
|
||||
'grad_clip': 1.0,
|
||||
'weight_decay': 0.0,
|
||||
'pretrained_ckpt': None,
|
||||
'resume_ckpt': None,
|
||||
'use_swanlab': False,
|
||||
'rollout_val_freq_epochs': 2,
|
||||
'rollout_num_episodes': 1,
|
||||
},
|
||||
'data': {
|
||||
'camera_names': ['front'],
|
||||
},
|
||||
'agent': {
|
||||
'_target_': 'fake.agent',
|
||||
},
|
||||
'eval': {
|
||||
'ckpt_path': 'unused.pt',
|
||||
'num_episodes': 99,
|
||||
'max_timesteps': 1,
|
||||
'device': 'cpu',
|
||||
'task_name': 'sim_transfer',
|
||||
'camera_names': ['front'],
|
||||
'use_smoothing': False,
|
||||
'smooth_alpha': 0.3,
|
||||
'verbose_action': False,
|
||||
'headless': False,
|
||||
},
|
||||
}
|
||||
)
|
||||
agent = _SequentialLossAgent([10, 9, 8, 7, 6, 5])
|
||||
rollout_mock = mock.Mock(return_value={'avg_reward': 1.0})
|
||||
saved_checkpoints = []
|
||||
|
||||
def fake_instantiate(config_node, **_kwargs):
|
||||
if config_node is cfg.data:
|
||||
return _FakeDataset()
|
||||
if config_node is cfg.agent:
|
||||
return agent
|
||||
raise AssertionError(f'unexpected instantiate config: {config_node!r}')
|
||||
|
||||
def fake_dataloader(_dataset, *, shuffle, **_kwargs):
|
||||
del _kwargs
|
||||
return _FakeLoader(
|
||||
{
|
||||
'observation.front': torch.zeros(1, 3, 2, 2),
|
||||
'observation.state': torch.zeros(1, 4),
|
||||
'action': torch.zeros(1, 2),
|
||||
'action_is_pad': torch.zeros(1, 1, dtype=torch.bool),
|
||||
},
|
||||
length=2 if shuffle else 1,
|
||||
)
|
||||
|
||||
def fake_torch_save(payload, path):
|
||||
saved_checkpoints.append((str(path), deepcopy(payload)))
|
||||
return None
|
||||
|
||||
with tempfile.TemporaryDirectory() as tempdir:
|
||||
previous_cwd = os.getcwd()
|
||||
try:
|
||||
os.chdir(tempdir)
|
||||
with mock.patch.object(train_vla, 'instantiate', side_effect=fake_instantiate), \
|
||||
mock.patch.object(train_vla, 'DataLoader', side_effect=fake_dataloader), \
|
||||
mock.patch.object(train_vla, 'build_training_optimizer', return_value=_FakeOptimizer(cfg.train.lr)), \
|
||||
mock.patch.object(train_vla, 'get_lr_schedule_with_warmup', return_value=_FakeScheduler()), \
|
||||
mock.patch.object(train_vla, 'tqdm', side_effect=lambda iterable, **kwargs: _FakeProgressBar(iterable)), \
|
||||
mock.patch.object(train_vla.torch, 'save', side_effect=fake_torch_save), \
|
||||
mock.patch.object(eval_vla, '_run_eval', rollout_mock, create=True):
|
||||
train_vla._run_training(cfg)
|
||||
finally:
|
||||
os.chdir(previous_cwd)
|
||||
|
||||
best_model_saves = [
|
||||
(payload['step'], payload['rollout_avg_reward'])
|
||||
for path, payload in saved_checkpoints
|
||||
if path.endswith('checkpoints/vla_model_best.pt')
|
||||
]
|
||||
self.assertEqual(
|
||||
best_model_saves,
|
||||
[
|
||||
(1, None),
|
||||
(2, None),
|
||||
(3, None),
|
||||
(3, 1.0),
|
||||
],
|
||||
)
|
||||
self.assertEqual(rollout_mock.call_count, 1)
|
||||
|
||||
def test_run_training_keeps_tiny_train_dataset_batch_when_batch_size_is_larger(self):
|
||||
cfg = OmegaConf.create(
|
||||
{
|
||||
'train': {
|
||||
'device': 'cpu',
|
||||
'batch_size': 8,
|
||||
'num_workers': 0,
|
||||
'val_split': 0.0,
|
||||
'seed': 0,
|
||||
'lr': 1e-3,
|
||||
'max_steps': 1,
|
||||
'log_freq': 1,
|
||||
'save_freq': 1000,
|
||||
'warmup_steps': 1,
|
||||
'scheduler_type': 'constant',
|
||||
'min_lr': 0.0,
|
||||
'grad_clip': 1.0,
|
||||
'weight_decay': 0.0,
|
||||
'pretrained_ckpt': None,
|
||||
'resume_ckpt': None,
|
||||
'use_swanlab': False,
|
||||
'rollout_val_freq_epochs': 0,
|
||||
},
|
||||
'data': {
|
||||
'camera_names': ['front'],
|
||||
},
|
||||
'agent': {
|
||||
'_target_': 'fake.agent',
|
||||
},
|
||||
}
|
||||
)
|
||||
agent = _FakeAgent()
|
||||
dataloader_calls = []
|
||||
saved_checkpoints = []
|
||||
|
||||
class _TinyDataset:
|
||||
def __len__(self):
|
||||
return 1
|
||||
|
||||
def fake_instantiate(config_node, **_kwargs):
|
||||
if config_node is cfg.data:
|
||||
return _TinyDataset()
|
||||
if config_node is cfg.agent:
|
||||
return agent
|
||||
raise AssertionError(f'unexpected instantiate config: {config_node!r}')
|
||||
|
||||
def fake_dataloader(dataset, *, drop_last, shuffle, **_kwargs):
|
||||
del _kwargs
|
||||
dataloader_calls.append(
|
||||
{
|
||||
'shuffle': shuffle,
|
||||
'drop_last': drop_last,
|
||||
'dataset_len': len(dataset),
|
||||
}
|
||||
)
|
||||
loader_length = 0 if drop_last and len(dataset) < cfg.train.batch_size else 1
|
||||
return _FakeLoader(
|
||||
{
|
||||
'observation.front': torch.zeros(1, 3, 2, 2),
|
||||
'observation.state': torch.zeros(1, 4),
|
||||
'action': torch.zeros(1, 2),
|
||||
'action_is_pad': torch.zeros(1, 1, dtype=torch.bool),
|
||||
},
|
||||
length=loader_length,
|
||||
)
|
||||
|
||||
def fake_torch_save(payload, path):
|
||||
saved_checkpoints.append((str(path), deepcopy(payload)))
|
||||
return None
|
||||
|
||||
with tempfile.TemporaryDirectory() as tempdir:
|
||||
previous_cwd = os.getcwd()
|
||||
try:
|
||||
os.chdir(tempdir)
|
||||
with mock.patch.object(train_vla, 'instantiate', side_effect=fake_instantiate), \
|
||||
mock.patch.object(train_vla, 'DataLoader', side_effect=fake_dataloader), \
|
||||
mock.patch.object(train_vla, 'build_training_optimizer', return_value=_FakeOptimizer(cfg.train.lr)), \
|
||||
mock.patch.object(train_vla, 'get_lr_schedule_with_warmup', return_value=_FakeScheduler()), \
|
||||
mock.patch.object(train_vla, 'tqdm', side_effect=lambda iterable, **kwargs: _FakeProgressBar(iterable)), \
|
||||
mock.patch.object(train_vla.torch, 'save', side_effect=fake_torch_save):
|
||||
train_vla._run_training(cfg)
|
||||
finally:
|
||||
os.chdir(previous_cwd)
|
||||
|
||||
self.assertEqual(
|
||||
dataloader_calls[0],
|
||||
{
|
||||
'shuffle': True,
|
||||
'drop_last': False,
|
||||
'dataset_len': 1,
|
||||
},
|
||||
)
|
||||
self.assertEqual(
|
||||
[path for path, _payload in saved_checkpoints],
|
||||
['checkpoints/vla_model_final.pt'],
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
699
tests/test_train_vla_swanlab_logging.py
Normal file
699
tests/test_train_vla_swanlab_logging.py
Normal file
@@ -0,0 +1,699 @@
|
||||
import importlib
|
||||
import importlib.util
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
import types
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from unittest import mock
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
_REPO_ROOT = Path(__file__).resolve().parents[1]
|
||||
_TRAIN_VLA_PATH = _REPO_ROOT / 'roboimi/demos/vla_scripts/train_vla.py'
|
||||
_CONFIG_PATH = _REPO_ROOT / 'roboimi/vla/conf/config.yaml'
|
||||
|
||||
|
||||
class AttrDict(dict):
|
||||
def __getattr__(self, name):
|
||||
try:
|
||||
return self[name]
|
||||
except KeyError as exc:
|
||||
raise AttributeError(name) from exc
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
self[name] = value
|
||||
|
||||
|
||||
def _to_attrdict(value):
|
||||
if isinstance(value, dict):
|
||||
return AttrDict({key: _to_attrdict(item) for key, item in value.items()})
|
||||
if isinstance(value, list):
|
||||
return [_to_attrdict(item) for item in value]
|
||||
return value
|
||||
|
||||
|
||||
class FakeDataset:
|
||||
def __len__(self):
|
||||
return 4
|
||||
|
||||
|
||||
class FakeLoader:
|
||||
def __init__(self, batch):
|
||||
self.batch = batch
|
||||
|
||||
def __len__(self):
|
||||
return 1
|
||||
|
||||
def __iter__(self):
|
||||
return iter((self.batch,))
|
||||
|
||||
|
||||
class FakeScheduler:
|
||||
def __init__(self):
|
||||
self.step_calls = 0
|
||||
|
||||
def step(self):
|
||||
self.step_calls += 1
|
||||
|
||||
def state_dict(self):
|
||||
return {}
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
return None
|
||||
|
||||
|
||||
class FakeOptimizer:
|
||||
def __init__(self, lr=1e-3):
|
||||
self.param_groups = [{'lr': lr}]
|
||||
self.loaded_state_dict = None
|
||||
|
||||
def zero_grad(self):
|
||||
return None
|
||||
|
||||
def step(self):
|
||||
return None
|
||||
|
||||
def state_dict(self):
|
||||
return {}
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
self.loaded_state_dict = state_dict
|
||||
return None
|
||||
|
||||
|
||||
class FakeProgressBar:
|
||||
def __init__(self, iterable):
|
||||
self._items = list(iterable)
|
||||
self.postfix_calls = []
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self._items)
|
||||
|
||||
def set_postfix(self, values):
|
||||
self.postfix_calls.append(values)
|
||||
|
||||
|
||||
class FakeAgent(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.tensor(0.0))
|
||||
|
||||
def to(self, device):
|
||||
return self
|
||||
|
||||
def compute_loss(self, agent_input):
|
||||
del agent_input
|
||||
target = torch.tensor(0.25 if self.training else 0.1)
|
||||
return (self.weight - target).pow(2)
|
||||
|
||||
def get_normalization_stats(self):
|
||||
return {}
|
||||
|
||||
|
||||
class FakeSwanLab:
|
||||
def __init__(self, init_error=None, log_errors=None, finish_error=None):
|
||||
self.init_error = init_error
|
||||
self.log_errors = list(log_errors or [])
|
||||
self.finish_error = finish_error
|
||||
self.init_calls = []
|
||||
self.log_calls = []
|
||||
self.finish_calls = 0
|
||||
|
||||
def init(self, project, experiment_name=None, config=None):
|
||||
self.init_calls.append({
|
||||
'project': project,
|
||||
'experiment_name': experiment_name,
|
||||
'config': config,
|
||||
})
|
||||
if self.init_error is not None:
|
||||
raise self.init_error
|
||||
return object()
|
||||
|
||||
def log(self, payload, step=None):
|
||||
self.log_calls.append((dict(payload), step))
|
||||
if self.log_errors:
|
||||
raise self.log_errors.pop(0)
|
||||
|
||||
def finish(self):
|
||||
self.finish_calls += 1
|
||||
if self.finish_error is not None:
|
||||
raise self.finish_error
|
||||
|
||||
|
||||
class TrainVLASwanLabLoggingTest(unittest.TestCase):
|
||||
def test_default_config_keeps_swanlab_opt_in(self):
|
||||
config_text = _CONFIG_PATH.read_text(encoding='utf-8')
|
||||
self.assertIn('use_swanlab: false', config_text)
|
||||
|
||||
def _load_train_vla_module(self):
|
||||
hydra_module = types.ModuleType('hydra')
|
||||
hydra_utils_module = types.ModuleType('hydra.utils')
|
||||
hydra_utils_module.instantiate = lambda *args, **kwargs: None
|
||||
|
||||
def hydra_main(**_kwargs):
|
||||
def decorator(func):
|
||||
return func
|
||||
return decorator
|
||||
|
||||
hydra_module.main = hydra_main
|
||||
hydra_module.utils = hydra_utils_module
|
||||
|
||||
class OmegaConfStub:
|
||||
_resolvers = {}
|
||||
|
||||
@classmethod
|
||||
def has_resolver(cls, name):
|
||||
return name in cls._resolvers
|
||||
|
||||
@classmethod
|
||||
def register_new_resolver(cls, name, resolver):
|
||||
cls._resolvers[name] = resolver
|
||||
|
||||
@staticmethod
|
||||
def to_yaml(_cfg):
|
||||
return 'stub-config'
|
||||
|
||||
@staticmethod
|
||||
def to_container(cfg, resolve=False):
|
||||
del resolve
|
||||
return dict(cfg)
|
||||
|
||||
@staticmethod
|
||||
def create(cfg):
|
||||
return _to_attrdict(cfg)
|
||||
|
||||
omegaconf_module = types.ModuleType('omegaconf')
|
||||
omegaconf_module.DictConfig = dict
|
||||
omegaconf_module.OmegaConf = OmegaConfStub
|
||||
|
||||
module_name = 'train_vla_swanlab_test_module'
|
||||
spec = importlib.util.spec_from_file_location(module_name, _TRAIN_VLA_PATH)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
with mock.patch.dict(
|
||||
sys.modules,
|
||||
{
|
||||
'hydra': hydra_module,
|
||||
'hydra.utils': hydra_utils_module,
|
||||
'omegaconf': omegaconf_module,
|
||||
},
|
||||
):
|
||||
assert spec.loader is not None
|
||||
spec.loader.exec_module(module)
|
||||
return module
|
||||
|
||||
def _make_cfg(self, *, use_swanlab=True, swanlab_run_name='smoke-run'):
|
||||
return AttrDict(
|
||||
train=AttrDict(
|
||||
device='cpu',
|
||||
batch_size=2,
|
||||
num_workers=0,
|
||||
val_split=0.25,
|
||||
seed=0,
|
||||
lr=1e-3,
|
||||
max_steps=2,
|
||||
log_freq=1,
|
||||
save_freq=1,
|
||||
warmup_steps=1,
|
||||
scheduler_type='constant',
|
||||
min_lr=0.0,
|
||||
grad_clip=1.0,
|
||||
weight_decay=0.0,
|
||||
pretrained_ckpt=None,
|
||||
resume_ckpt=None,
|
||||
use_swanlab=use_swanlab,
|
||||
swanlab_project='roboimi-vla-tests',
|
||||
swanlab_run_name=swanlab_run_name,
|
||||
),
|
||||
data=AttrDict(
|
||||
camera_names=('front',),
|
||||
),
|
||||
agent=AttrDict(
|
||||
_target_='fake.agent',
|
||||
),
|
||||
eval=AttrDict(
|
||||
ckpt_path='unused.pt',
|
||||
num_episodes=1,
|
||||
max_timesteps=1,
|
||||
device='cpu',
|
||||
task_name='sim_transfer',
|
||||
camera_names=('front',),
|
||||
use_smoothing=False,
|
||||
smooth_alpha=0.3,
|
||||
verbose_action=False,
|
||||
headless=False,
|
||||
),
|
||||
)
|
||||
|
||||
def _get_run_training(self, module):
|
||||
run_training = getattr(module, '_run_training', None)
|
||||
self.assertIsNotNone(run_training, 'Expected train_vla.py to expose a _run_training(cfg) helper')
|
||||
return run_training
|
||||
|
||||
def _make_batch(self):
|
||||
return {
|
||||
'observation.front': torch.zeros(1, 3, 2, 2),
|
||||
'observation.state': torch.zeros(1, 4),
|
||||
'action': torch.zeros(1, 2),
|
||||
'action_is_pad': torch.zeros(1, 1, dtype=torch.bool),
|
||||
}
|
||||
|
||||
def _loader_factory(self):
|
||||
train_batch = self._make_batch()
|
||||
val_batch = self._make_batch()
|
||||
|
||||
def factory(_dataset, *, shuffle, **_kwargs):
|
||||
return FakeLoader(train_batch if shuffle else val_batch)
|
||||
|
||||
return factory
|
||||
|
||||
def test_run_training_logs_metrics_and_checkpoint_paths_to_swanlab(self):
|
||||
module = self._load_train_vla_module()
|
||||
run_training = self._get_run_training(module)
|
||||
cfg = self._make_cfg()
|
||||
agent = FakeAgent()
|
||||
fake_swanlab = FakeSwanLab()
|
||||
real_import_module = importlib.import_module
|
||||
|
||||
def fake_instantiate(config_node, **_kwargs):
|
||||
if config_node is cfg.data:
|
||||
return FakeDataset()
|
||||
if config_node is cfg.agent:
|
||||
return agent
|
||||
raise AssertionError(f'unexpected instantiate config: {config_node!r}')
|
||||
|
||||
def fake_import_module(name, package=None):
|
||||
if name == 'swanlab':
|
||||
return fake_swanlab
|
||||
return real_import_module(name, package)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tempdir:
|
||||
previous_cwd = os.getcwd()
|
||||
try:
|
||||
os.chdir(tempdir)
|
||||
with mock.patch.object(module, 'instantiate', side_effect=fake_instantiate), \
|
||||
mock.patch.object(module, 'DataLoader', side_effect=self._loader_factory()), \
|
||||
mock.patch.object(module, 'get_lr_schedule_with_warmup', return_value=FakeScheduler()), \
|
||||
mock.patch.object(module, 'tqdm', side_effect=lambda iterable, **kwargs: FakeProgressBar(iterable)), \
|
||||
mock.patch.object(module.torch, 'save', return_value=None), \
|
||||
mock.patch.object(module.importlib, 'import_module', side_effect=fake_import_module):
|
||||
run_training(cfg)
|
||||
finally:
|
||||
os.chdir(previous_cwd)
|
||||
|
||||
self.assertEqual(
|
||||
fake_swanlab.init_calls,
|
||||
[{
|
||||
'project': 'roboimi-vla-tests',
|
||||
'experiment_name': 'smoke-run',
|
||||
'config': {
|
||||
'train': {
|
||||
'device': 'cpu',
|
||||
'batch_size': 2,
|
||||
'num_workers': 0,
|
||||
'val_split': 0.25,
|
||||
'seed': 0,
|
||||
'lr': 1e-3,
|
||||
'max_steps': 2,
|
||||
'log_freq': 1,
|
||||
'save_freq': 1,
|
||||
'warmup_steps': 1,
|
||||
'scheduler_type': 'constant',
|
||||
'min_lr': 0.0,
|
||||
'grad_clip': 1.0,
|
||||
'weight_decay': 0.0,
|
||||
'pretrained_ckpt': None,
|
||||
'resume_ckpt': None,
|
||||
'use_swanlab': True,
|
||||
'swanlab_project': 'roboimi-vla-tests',
|
||||
'swanlab_run_name': 'smoke-run',
|
||||
},
|
||||
'data': {
|
||||
'camera_names': ('front',),
|
||||
},
|
||||
'agent': {
|
||||
'_target_': 'fake.agent',
|
||||
},
|
||||
},
|
||||
}],
|
||||
)
|
||||
|
||||
logged_keys = set().union(*(payload.keys() for payload, _step in fake_swanlab.log_calls))
|
||||
self.assertTrue(
|
||||
{
|
||||
'train/loss',
|
||||
'train/lr',
|
||||
'train/best_loss',
|
||||
'train/step',
|
||||
'val/loss',
|
||||
'final/checkpoint_path',
|
||||
'final/best_checkpoint_path',
|
||||
}.issubset(logged_keys)
|
||||
)
|
||||
|
||||
final_payload, final_step = fake_swanlab.log_calls[-1]
|
||||
self.assertEqual(final_step, cfg.train.max_steps)
|
||||
self.assertEqual(final_payload['final/checkpoint_path'], 'checkpoints/vla_model_final.pt')
|
||||
self.assertEqual(final_payload['final/best_checkpoint_path'], 'checkpoints/vla_model_best.pt')
|
||||
self.assertEqual(fake_swanlab.finish_calls, 1)
|
||||
|
||||
def test_run_training_skips_swanlab_when_disabled(self):
|
||||
module = self._load_train_vla_module()
|
||||
run_training = self._get_run_training(module)
|
||||
cfg = self._make_cfg(use_swanlab=False)
|
||||
agent = FakeAgent()
|
||||
|
||||
def fake_instantiate(config_node, **_kwargs):
|
||||
if config_node is cfg.data:
|
||||
return FakeDataset()
|
||||
if config_node is cfg.agent:
|
||||
return agent
|
||||
raise AssertionError(f'unexpected instantiate config: {config_node!r}')
|
||||
|
||||
with tempfile.TemporaryDirectory() as tempdir:
|
||||
previous_cwd = os.getcwd()
|
||||
try:
|
||||
os.chdir(tempdir)
|
||||
with mock.patch.object(module, 'instantiate', side_effect=fake_instantiate), \
|
||||
mock.patch.object(module, 'DataLoader', side_effect=self._loader_factory()), \
|
||||
mock.patch.object(module, 'get_lr_schedule_with_warmup', return_value=FakeScheduler()), \
|
||||
mock.patch.object(module, 'tqdm', side_effect=lambda iterable, **kwargs: FakeProgressBar(iterable)), \
|
||||
mock.patch.object(module.torch, 'save', return_value=None), \
|
||||
mock.patch.object(module.importlib, 'import_module', side_effect=AssertionError('swanlab import should not run')):
|
||||
run_training(cfg)
|
||||
finally:
|
||||
os.chdir(previous_cwd)
|
||||
|
||||
def test_run_training_finishes_swanlab_when_exception_happens_after_init(self):
|
||||
module = self._load_train_vla_module()
|
||||
run_training = self._get_run_training(module)
|
||||
cfg = self._make_cfg()
|
||||
fake_swanlab = FakeSwanLab()
|
||||
real_import_module = importlib.import_module
|
||||
|
||||
def fake_import_module(name, package=None):
|
||||
if name == 'swanlab':
|
||||
return fake_swanlab
|
||||
return real_import_module(name, package)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tempdir:
|
||||
previous_cwd = os.getcwd()
|
||||
try:
|
||||
os.chdir(tempdir)
|
||||
with mock.patch.object(module, 'instantiate', side_effect=RuntimeError('dataset boom')), \
|
||||
mock.patch.object(module.importlib, 'import_module', side_effect=fake_import_module):
|
||||
with self.assertRaisesRegex(RuntimeError, 'dataset boom'):
|
||||
run_training(cfg)
|
||||
finally:
|
||||
os.chdir(previous_cwd)
|
||||
|
||||
self.assertEqual(fake_swanlab.finish_calls, 1)
|
||||
|
||||
def test_run_training_warns_and_continues_when_swanlab_log_and_finish_fail(self):
|
||||
module = self._load_train_vla_module()
|
||||
run_training = self._get_run_training(module)
|
||||
cfg = self._make_cfg()
|
||||
agent = FakeAgent()
|
||||
fake_swanlab = FakeSwanLab(
|
||||
log_errors=[RuntimeError('log backend hiccup')],
|
||||
finish_error=RuntimeError('finish backend hiccup'),
|
||||
)
|
||||
real_import_module = importlib.import_module
|
||||
|
||||
def fake_instantiate(config_node, **_kwargs):
|
||||
if config_node is cfg.data:
|
||||
return FakeDataset()
|
||||
if config_node is cfg.agent:
|
||||
return agent
|
||||
raise AssertionError(f'unexpected instantiate config: {config_node!r}')
|
||||
|
||||
def fake_import_module(name, package=None):
|
||||
if name == 'swanlab':
|
||||
return fake_swanlab
|
||||
return real_import_module(name, package)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tempdir:
|
||||
previous_cwd = os.getcwd()
|
||||
try:
|
||||
os.chdir(tempdir)
|
||||
with mock.patch.object(module, 'instantiate', side_effect=fake_instantiate), \
|
||||
mock.patch.object(module, 'DataLoader', side_effect=self._loader_factory()), \
|
||||
mock.patch.object(module, 'get_lr_schedule_with_warmup', return_value=FakeScheduler()), \
|
||||
mock.patch.object(module, 'tqdm', side_effect=lambda iterable, **kwargs: FakeProgressBar(iterable)), \
|
||||
mock.patch.object(module.torch, 'save', return_value=None), \
|
||||
mock.patch.object(module.importlib, 'import_module', side_effect=fake_import_module), \
|
||||
mock.patch.object(module.log, 'warning') as warning_mock:
|
||||
run_training(cfg)
|
||||
finally:
|
||||
os.chdir(previous_cwd)
|
||||
|
||||
warning_messages = [call.args[0] for call in warning_mock.call_args_list]
|
||||
self.assertTrue(any('SwanLab log failed' in message for message in warning_messages))
|
||||
self.assertTrue(any('SwanLab finish failed' in message for message in warning_messages))
|
||||
self.assertEqual(fake_swanlab.finish_calls, 1)
|
||||
|
||||
def test_run_training_resume_restores_best_rollout_baseline_from_best_checkpoint(self):
|
||||
module = self._load_train_vla_module()
|
||||
run_training = self._get_run_training(module)
|
||||
cfg = self._make_cfg()
|
||||
cfg.train.max_steps = 2
|
||||
cfg.train.save_freq = 1
|
||||
cfg.train.rollout_validate_on_checkpoint = True
|
||||
fake_swanlab = FakeSwanLab()
|
||||
fake_optimizer = FakeOptimizer(lr=cfg.train.lr)
|
||||
fake_scheduler = FakeScheduler()
|
||||
real_import_module = importlib.import_module
|
||||
saved_paths = []
|
||||
|
||||
def fake_instantiate(config_node, **_kwargs):
|
||||
if config_node is cfg.data:
|
||||
return FakeDataset()
|
||||
if config_node is cfg.agent:
|
||||
return FakeAgent()
|
||||
raise AssertionError(f'unexpected instantiate config: {config_node!r}')
|
||||
|
||||
def fake_import_module(name, package=None):
|
||||
if name == 'swanlab':
|
||||
return fake_swanlab
|
||||
return real_import_module(name, package)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tempdir:
|
||||
previous_cwd = os.getcwd()
|
||||
try:
|
||||
os.chdir(tempdir)
|
||||
checkpoint_dir = Path('checkpoints')
|
||||
checkpoint_dir.mkdir()
|
||||
resume_path = checkpoint_dir / 'vla_model_step_0.pt'
|
||||
resume_path.write_bytes(b'resume')
|
||||
best_path = checkpoint_dir / 'vla_model_best.pt'
|
||||
best_path.write_bytes(b'best')
|
||||
cfg.train.resume_ckpt = str(resume_path)
|
||||
|
||||
resume_checkpoint_state = {
|
||||
'step': 0,
|
||||
'model_state_dict': FakeAgent().state_dict(),
|
||||
'optimizer_state_dict': {},
|
||||
'scheduler_state_dict': {},
|
||||
'loss': 0.5,
|
||||
'val_loss': 0.25,
|
||||
}
|
||||
best_checkpoint_state = {
|
||||
'step': 0,
|
||||
'model_state_dict': FakeAgent().state_dict(),
|
||||
'optimizer_state_dict': {},
|
||||
'scheduler_state_dict': {},
|
||||
'loss': 0.5,
|
||||
'val_loss': 0.25,
|
||||
'rollout_avg_reward': 5.0,
|
||||
}
|
||||
|
||||
def fake_torch_load(path, map_location=None):
|
||||
del map_location
|
||||
path = Path(path)
|
||||
if path == resume_path:
|
||||
return resume_checkpoint_state
|
||||
if path == best_path:
|
||||
return best_checkpoint_state
|
||||
raise AssertionError(f'unexpected load path: {path}')
|
||||
|
||||
def fake_torch_save(payload, path):
|
||||
saved_paths.append(str(path))
|
||||
return None
|
||||
|
||||
with mock.patch.object(module, 'instantiate', side_effect=fake_instantiate), \
|
||||
mock.patch.object(module, 'DataLoader', side_effect=self._loader_factory()), \
|
||||
mock.patch.object(module, 'build_training_optimizer', return_value=fake_optimizer), \
|
||||
mock.patch.object(module, 'get_lr_schedule_with_warmup', return_value=fake_scheduler), \
|
||||
mock.patch.object(module, 'tqdm', side_effect=lambda iterable, **kwargs: FakeProgressBar(iterable)), \
|
||||
mock.patch.object(module.torch, 'save', side_effect=fake_torch_save), \
|
||||
mock.patch.object(module.torch, 'load', side_effect=fake_torch_load), \
|
||||
mock.patch.object(module.importlib, 'import_module', side_effect=fake_import_module), \
|
||||
mock.patch('roboimi.demos.vla_scripts.eval_vla._run_eval', return_value={'avg_reward': 3.0}):
|
||||
run_training(cfg)
|
||||
finally:
|
||||
os.chdir(previous_cwd)
|
||||
|
||||
final_payload, final_step = fake_swanlab.log_calls[-1]
|
||||
self.assertEqual(final_step, cfg.train.max_steps)
|
||||
self.assertEqual(final_payload['final/best_checkpoint_path'], 'checkpoints/vla_model_best.pt')
|
||||
self.assertNotIn('checkpoints/vla_model_best.pt', saved_paths)
|
||||
|
||||
def test_run_training_resume_ignores_best_checkpoint_without_rollout_metric(self):
|
||||
module = self._load_train_vla_module()
|
||||
run_training = self._get_run_training(module)
|
||||
cfg = self._make_cfg()
|
||||
cfg.train.max_steps = 1
|
||||
fake_swanlab = FakeSwanLab()
|
||||
fake_optimizer = FakeOptimizer(lr=cfg.train.lr)
|
||||
fake_scheduler = FakeScheduler()
|
||||
real_import_module = importlib.import_module
|
||||
|
||||
def fake_instantiate(config_node, **_kwargs):
|
||||
if config_node is cfg.data:
|
||||
return FakeDataset()
|
||||
if config_node is cfg.agent:
|
||||
return FakeAgent()
|
||||
raise AssertionError(f'unexpected instantiate config: {config_node!r}')
|
||||
|
||||
def fake_import_module(name, package=None):
|
||||
if name == 'swanlab':
|
||||
return fake_swanlab
|
||||
return real_import_module(name, package)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tempdir:
|
||||
previous_cwd = os.getcwd()
|
||||
try:
|
||||
os.chdir(tempdir)
|
||||
checkpoint_dir = Path('checkpoints')
|
||||
checkpoint_dir.mkdir()
|
||||
resume_path = checkpoint_dir / 'vla_model_step_0.pt'
|
||||
resume_path.write_bytes(b'resume')
|
||||
best_path = checkpoint_dir / 'vla_model_best.pt'
|
||||
best_path.write_bytes(b'stale')
|
||||
cfg.train.resume_ckpt = str(resume_path)
|
||||
|
||||
resume_checkpoint_state = {
|
||||
'step': 0,
|
||||
'model_state_dict': FakeAgent().state_dict(),
|
||||
'optimizer_state_dict': {},
|
||||
'scheduler_state_dict': {},
|
||||
'loss': 0.5,
|
||||
'val_loss': 0.25,
|
||||
}
|
||||
stale_best_checkpoint_state = {
|
||||
'step': 0,
|
||||
'model_state_dict': FakeAgent().state_dict(),
|
||||
'optimizer_state_dict': {},
|
||||
'scheduler_state_dict': {},
|
||||
'loss': 0.4,
|
||||
'val_loss': 0.2,
|
||||
}
|
||||
|
||||
def fake_torch_load(path, map_location=None):
|
||||
del map_location
|
||||
path = Path(path)
|
||||
if path == resume_path:
|
||||
return resume_checkpoint_state
|
||||
if path == best_path:
|
||||
return stale_best_checkpoint_state
|
||||
raise AssertionError(f'unexpected load path: {path}')
|
||||
|
||||
with mock.patch.object(module, 'instantiate', side_effect=fake_instantiate), \
|
||||
mock.patch.object(module, 'DataLoader', side_effect=self._loader_factory()), \
|
||||
mock.patch.object(module, 'build_training_optimizer', return_value=fake_optimizer), \
|
||||
mock.patch.object(module, 'get_lr_schedule_with_warmup', return_value=fake_scheduler), \
|
||||
mock.patch.object(module, 'tqdm', side_effect=lambda iterable, **kwargs: FakeProgressBar(iterable)), \
|
||||
mock.patch.object(module.torch, 'save', return_value=None), \
|
||||
mock.patch.object(module.torch, 'load', side_effect=fake_torch_load), \
|
||||
mock.patch.object(module.importlib, 'import_module', side_effect=fake_import_module):
|
||||
run_training(cfg)
|
||||
finally:
|
||||
os.chdir(previous_cwd)
|
||||
|
||||
final_payload, final_step = fake_swanlab.log_calls[-1]
|
||||
self.assertEqual(final_step, cfg.train.max_steps)
|
||||
self.assertEqual(final_payload['final/best_checkpoint_path'], 'checkpoints/vla_model_step_0.pt')
|
||||
|
||||
def test_run_training_ignores_stale_best_checkpoint_file_on_fresh_non_resume_run(self):
|
||||
module = self._load_train_vla_module()
|
||||
run_training = self._get_run_training(module)
|
||||
cfg = self._make_cfg()
|
||||
cfg.train.max_steps = 1
|
||||
fake_swanlab = FakeSwanLab()
|
||||
real_import_module = importlib.import_module
|
||||
|
||||
def fake_instantiate(config_node, **_kwargs):
|
||||
if config_node is cfg.data:
|
||||
return FakeDataset()
|
||||
if config_node is cfg.agent:
|
||||
return FakeAgent()
|
||||
raise AssertionError(f'unexpected instantiate config: {config_node!r}')
|
||||
|
||||
def fake_import_module(name, package=None):
|
||||
if name == 'swanlab':
|
||||
return fake_swanlab
|
||||
return real_import_module(name, package)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tempdir:
|
||||
previous_cwd = os.getcwd()
|
||||
try:
|
||||
os.chdir(tempdir)
|
||||
checkpoint_dir = Path('checkpoints')
|
||||
checkpoint_dir.mkdir()
|
||||
(checkpoint_dir / 'vla_model_best.pt').write_bytes(b'stale-best')
|
||||
|
||||
with mock.patch.object(module, 'instantiate', side_effect=fake_instantiate), \
|
||||
mock.patch.object(module, 'DataLoader', side_effect=self._loader_factory()), \
|
||||
mock.patch.object(module, 'get_lr_schedule_with_warmup', return_value=FakeScheduler()), \
|
||||
mock.patch.object(module, 'tqdm', side_effect=lambda iterable, **kwargs: FakeProgressBar(iterable)), \
|
||||
mock.patch.object(module.torch, 'save', return_value=None), \
|
||||
mock.patch.object(module.importlib, 'import_module', side_effect=fake_import_module):
|
||||
run_training(cfg)
|
||||
finally:
|
||||
os.chdir(previous_cwd)
|
||||
|
||||
final_payload, final_step = fake_swanlab.log_calls[-1]
|
||||
self.assertEqual(final_step, cfg.train.max_steps)
|
||||
self.assertEqual(final_payload['final/best_checkpoint_path'], '')
|
||||
|
||||
def test_run_training_fails_fast_when_swanlab_import_is_unavailable(self):
|
||||
module = self._load_train_vla_module()
|
||||
run_training = self._get_run_training(module)
|
||||
cfg = self._make_cfg()
|
||||
real_import_module = importlib.import_module
|
||||
|
||||
def fake_import_module(name, package=None):
|
||||
if name == 'swanlab':
|
||||
raise ImportError('missing swanlab')
|
||||
return real_import_module(name, package)
|
||||
|
||||
with mock.patch.object(module, 'instantiate', side_effect=AssertionError('instantiate should not run')), \
|
||||
mock.patch.object(module.importlib, 'import_module', side_effect=fake_import_module):
|
||||
with self.assertRaisesRegex(RuntimeError, 'SwanLab'):
|
||||
run_training(cfg)
|
||||
|
||||
def test_run_training_fails_fast_when_swanlab_init_fails(self):
|
||||
module = self._load_train_vla_module()
|
||||
run_training = self._get_run_training(module)
|
||||
cfg = self._make_cfg()
|
||||
fake_swanlab = FakeSwanLab(init_error=RuntimeError('not logged in'))
|
||||
real_import_module = importlib.import_module
|
||||
|
||||
def fake_import_module(name, package=None):
|
||||
if name == 'swanlab':
|
||||
return fake_swanlab
|
||||
return real_import_module(name, package)
|
||||
|
||||
with mock.patch.object(module, 'instantiate', side_effect=AssertionError('instantiate should not run')), \
|
||||
mock.patch.object(module.importlib, 'import_module', side_effect=fake_import_module):
|
||||
with self.assertRaisesRegex(RuntimeError, 'not logged in'):
|
||||
run_training(cfg)
|
||||
|
||||
self.assertEqual(fake_swanlab.finish_calls, 0)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
310
tests/test_train_vla_transformer_optimizer.py
Normal file
310
tests/test_train_vla_transformer_optimizer.py
Normal file
@@ -0,0 +1,310 @@
|
||||
import importlib.util
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
import types
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from unittest import mock
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
_REPO_ROOT = Path(__file__).resolve().parents[1]
|
||||
_TRAIN_VLA_PATH = _REPO_ROOT / 'roboimi/demos/vla_scripts/train_vla.py'
|
||||
|
||||
|
||||
class AttrDict(dict):
|
||||
def __getattr__(self, name):
|
||||
try:
|
||||
return self[name]
|
||||
except KeyError as exc:
|
||||
raise AttributeError(name) from exc
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
self[name] = value
|
||||
|
||||
|
||||
class FakeDataset:
|
||||
def __len__(self):
|
||||
return 4
|
||||
|
||||
|
||||
class FakeLoader:
|
||||
def __len__(self):
|
||||
return 1
|
||||
|
||||
def __iter__(self):
|
||||
return iter(())
|
||||
|
||||
|
||||
class FakeScheduler:
|
||||
def state_dict(self):
|
||||
return {}
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
return None
|
||||
|
||||
|
||||
class RecordingAdamW:
|
||||
created = []
|
||||
|
||||
def __init__(self, params, lr, weight_decay):
|
||||
self.lr = lr
|
||||
self.weight_decay = weight_decay
|
||||
self.param_groups = self._normalize_param_groups(params, lr, weight_decay)
|
||||
RecordingAdamW.created.append(self)
|
||||
|
||||
@staticmethod
|
||||
def _normalize_param_groups(params, lr, weight_decay):
|
||||
if isinstance(params, (list, tuple)) and params and isinstance(params[0], dict):
|
||||
groups = []
|
||||
for group in params:
|
||||
normalized = dict(group)
|
||||
normalized['params'] = list(group['params'])
|
||||
normalized.setdefault('lr', lr)
|
||||
groups.append(normalized)
|
||||
return groups
|
||||
|
||||
return [{
|
||||
'params': list(params),
|
||||
'lr': lr,
|
||||
'weight_decay': weight_decay,
|
||||
}]
|
||||
|
||||
def state_dict(self):
|
||||
return {}
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
return None
|
||||
|
||||
|
||||
class RecordingTransformerHead(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.proj = nn.Linear(4, 4)
|
||||
self.norm = nn.LayerNorm(4)
|
||||
self.optim_group_calls = []
|
||||
|
||||
def get_optim_groups(self, weight_decay):
|
||||
self.optim_group_calls.append(weight_decay)
|
||||
return [
|
||||
{
|
||||
'params': [self.proj.weight],
|
||||
'weight_decay': weight_decay,
|
||||
},
|
||||
{
|
||||
'params': [self.proj.bias, self.norm.weight, self.norm.bias],
|
||||
'weight_decay': 0.0,
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
class FakeTransformerAgent(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.head_type = 'transformer'
|
||||
self.noise_pred_net = RecordingTransformerHead()
|
||||
self.backbone = nn.Linear(4, 3)
|
||||
self.adapter = nn.Linear(3, 2, bias=False)
|
||||
self.frozen = nn.Linear(2, 2)
|
||||
for param in self.frozen.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def to(self, device):
|
||||
return self
|
||||
|
||||
def get_normalization_stats(self):
|
||||
return {}
|
||||
|
||||
|
||||
class TrainVLATransformerOptimizerTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
RecordingAdamW.created = []
|
||||
|
||||
def _load_train_vla_module(self):
|
||||
hydra_module = types.ModuleType('hydra')
|
||||
hydra_utils_module = types.ModuleType('hydra.utils')
|
||||
hydra_utils_module.instantiate = lambda *args, **kwargs: None
|
||||
|
||||
def hydra_main(**_kwargs):
|
||||
def decorator(func):
|
||||
return func
|
||||
return decorator
|
||||
|
||||
hydra_module.main = hydra_main
|
||||
hydra_module.utils = hydra_utils_module
|
||||
|
||||
class OmegaConfStub:
|
||||
_resolvers = {}
|
||||
|
||||
@classmethod
|
||||
def has_resolver(cls, name):
|
||||
return name in cls._resolvers
|
||||
|
||||
@classmethod
|
||||
def register_new_resolver(cls, name, resolver):
|
||||
cls._resolvers[name] = resolver
|
||||
|
||||
@staticmethod
|
||||
def to_yaml(_cfg):
|
||||
return 'stub-config'
|
||||
|
||||
omegaconf_module = types.ModuleType('omegaconf')
|
||||
omegaconf_module.DictConfig = dict
|
||||
omegaconf_module.OmegaConf = OmegaConfStub
|
||||
|
||||
module_name = 'train_vla_optimizer_test_module'
|
||||
spec = importlib.util.spec_from_file_location(module_name, _TRAIN_VLA_PATH)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
with mock.patch.dict(
|
||||
sys.modules,
|
||||
{
|
||||
'hydra': hydra_module,
|
||||
'hydra.utils': hydra_utils_module,
|
||||
'omegaconf': omegaconf_module,
|
||||
},
|
||||
):
|
||||
assert spec.loader is not None
|
||||
spec.loader.exec_module(module)
|
||||
return module
|
||||
|
||||
def _make_cfg(self):
|
||||
return AttrDict(
|
||||
train=AttrDict(
|
||||
device='cpu',
|
||||
batch_size=2,
|
||||
num_workers=0,
|
||||
val_split=0,
|
||||
seed=0,
|
||||
lr=1e-4,
|
||||
max_steps=0,
|
||||
log_freq=1,
|
||||
save_freq=100,
|
||||
warmup_steps=1,
|
||||
scheduler_type='constant',
|
||||
min_lr=0.0,
|
||||
grad_clip=1.0,
|
||||
weight_decay=0.123,
|
||||
pretrained_ckpt=None,
|
||||
resume_ckpt=None,
|
||||
),
|
||||
data=AttrDict(
|
||||
camera_names=('front',),
|
||||
),
|
||||
agent=AttrDict(
|
||||
_target_='fake.agent',
|
||||
),
|
||||
)
|
||||
|
||||
def _group_names(self, agent, optimizer):
|
||||
names_by_param_id = {id(param): name for name, param in agent.named_parameters()}
|
||||
return [
|
||||
{names_by_param_id[id(param)] for param in group['params']}
|
||||
for group in optimizer.param_groups
|
||||
]
|
||||
|
||||
def test_transformer_training_prefers_head_optim_groups_and_keeps_remaining_trainable_params(self):
|
||||
module = self._load_train_vla_module()
|
||||
agent = FakeTransformerAgent()
|
||||
cfg = self._make_cfg()
|
||||
|
||||
def fake_instantiate(config_node, **_kwargs):
|
||||
if config_node is cfg.data:
|
||||
return FakeDataset()
|
||||
if config_node is cfg.agent:
|
||||
return agent
|
||||
raise AssertionError(f'unexpected instantiate config: {config_node!r}')
|
||||
|
||||
with tempfile.TemporaryDirectory() as tempdir:
|
||||
previous_cwd = os.getcwd()
|
||||
try:
|
||||
os.chdir(tempdir)
|
||||
with mock.patch.object(module, 'instantiate', side_effect=fake_instantiate), \
|
||||
mock.patch.object(module, 'DataLoader', side_effect=lambda *args, **kwargs: FakeLoader()), \
|
||||
mock.patch.object(module, 'get_lr_schedule_with_warmup', return_value=FakeScheduler()), \
|
||||
mock.patch.object(module, 'AdamW', RecordingAdamW), \
|
||||
mock.patch.object(module.torch, 'save', return_value=None), \
|
||||
mock.patch.object(module, 'tqdm', side_effect=lambda iterable, **kwargs: iterable):
|
||||
module.main(cfg)
|
||||
finally:
|
||||
os.chdir(previous_cwd)
|
||||
|
||||
self.assertEqual(agent.noise_pred_net.optim_group_calls, [cfg.train.weight_decay])
|
||||
|
||||
optimizer = RecordingAdamW.created[-1]
|
||||
trainable_names = {
|
||||
name for name, param in agent.named_parameters() if param.requires_grad
|
||||
}
|
||||
grouped_names = self._group_names(agent, optimizer)
|
||||
optimizer_names = set().union(*grouped_names)
|
||||
expected_head_names = {
|
||||
'noise_pred_net.proj.weight',
|
||||
'noise_pred_net.proj.bias',
|
||||
'noise_pred_net.norm.weight',
|
||||
'noise_pred_net.norm.bias',
|
||||
}
|
||||
expected_non_head_names = {
|
||||
'backbone.weight',
|
||||
'backbone.bias',
|
||||
'adapter.weight',
|
||||
}
|
||||
|
||||
self.assertEqual(grouped_names[0], {'noise_pred_net.proj.weight'})
|
||||
self.assertEqual(grouped_names[1], expected_head_names - {'noise_pred_net.proj.weight'})
|
||||
self.assertEqual(grouped_names[2], expected_non_head_names)
|
||||
self.assertEqual(optimizer.param_groups[0]['weight_decay'], cfg.train.weight_decay)
|
||||
self.assertEqual(optimizer.param_groups[1]['weight_decay'], 0.0)
|
||||
self.assertEqual(optimizer.param_groups[2]['weight_decay'], cfg.train.weight_decay)
|
||||
self.assertEqual(optimizer_names, trainable_names)
|
||||
|
||||
flattened_param_ids = [
|
||||
id(param)
|
||||
for group in optimizer.param_groups
|
||||
for param in group['params']
|
||||
]
|
||||
self.assertEqual(len(flattened_param_ids), len(set(flattened_param_ids)))
|
||||
self.assertNotIn('frozen.weight', optimizer_names)
|
||||
self.assertNotIn('frozen.bias', optimizer_names)
|
||||
|
||||
def test_transformer_optimizer_ignores_frozen_head_params_returned_by_head_groups(self):
|
||||
module = self._load_train_vla_module()
|
||||
agent = FakeTransformerAgent()
|
||||
agent.noise_pred_net.norm.bias.requires_grad = False
|
||||
cfg = self._make_cfg()
|
||||
|
||||
def fake_instantiate(config_node, **_kwargs):
|
||||
if config_node is cfg.data:
|
||||
return FakeDataset()
|
||||
if config_node is cfg.agent:
|
||||
return agent
|
||||
raise AssertionError(f'unexpected instantiate config: {config_node!r}')
|
||||
|
||||
with tempfile.TemporaryDirectory() as tempdir:
|
||||
previous_cwd = os.getcwd()
|
||||
try:
|
||||
os.chdir(tempdir)
|
||||
with mock.patch.object(module, 'instantiate', side_effect=fake_instantiate), \
|
||||
mock.patch.object(module, 'DataLoader', side_effect=lambda *args, **kwargs: FakeLoader()), \
|
||||
mock.patch.object(module, 'get_lr_schedule_with_warmup', return_value=FakeScheduler()), \
|
||||
mock.patch.object(module, 'AdamW', RecordingAdamW), \
|
||||
mock.patch.object(module.torch, 'save', return_value=None), \
|
||||
mock.patch.object(module, 'tqdm', side_effect=lambda iterable, **kwargs: iterable):
|
||||
module.main(cfg)
|
||||
finally:
|
||||
os.chdir(previous_cwd)
|
||||
|
||||
optimizer = RecordingAdamW.created[-1]
|
||||
optimizer_names = set().union(*self._group_names(agent, optimizer))
|
||||
trainable_names = {
|
||||
name for name, param in agent.named_parameters() if param.requires_grad
|
||||
}
|
||||
|
||||
self.assertEqual(agent.noise_pred_net.optim_group_calls, [cfg.train.weight_decay])
|
||||
self.assertEqual(optimizer_names, trainable_names)
|
||||
self.assertNotIn('noise_pred_net.norm.bias', optimizer_names)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
262
tests/test_transformer1d_external_alignment.py
Normal file
262
tests/test_transformer1d_external_alignment.py
Normal file
@@ -0,0 +1,262 @@
|
||||
import contextlib
|
||||
import importlib.util
|
||||
import inspect
|
||||
import sys
|
||||
import types
|
||||
import unittest
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
_REPO_ROOT = Path(__file__).resolve().parents[1]
|
||||
_LOCAL_MODULE_PATH = _REPO_ROOT / 'roboimi/vla/models/heads/transformer1d.py'
|
||||
_EXTERNAL_CHECKOUT_ROOT = _REPO_ROOT.parent / 'diffusion_policy'
|
||||
_TRANSFORMER_WARNING_MESSAGE = (
|
||||
r'enable_nested_tensor is True, but self.use_nested_tensor is False '
|
||||
r'because encoder_layer\.norm_first was True'
|
||||
)
|
||||
_MISSING = object()
|
||||
|
||||
|
||||
def _load_module_from_path(name: str, path: Path, *, register: bool = False):
|
||||
spec = importlib.util.spec_from_file_location(name, path)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
assert spec.loader is not None
|
||||
if register:
|
||||
sys.modules[name] = module
|
||||
spec.loader.exec_module(module)
|
||||
return module
|
||||
|
||||
|
||||
def _resolve_external_module_paths(external_checkout_root: Path):
|
||||
diffusion_policy_root = external_checkout_root / 'diffusion_policy'
|
||||
paths = {
|
||||
'positional_embedding': diffusion_policy_root / 'model/diffusion/positional_embedding.py',
|
||||
'module_attr_mixin': diffusion_policy_root / 'model/common/module_attr_mixin.py',
|
||||
'transformer_for_diffusion': diffusion_policy_root / 'model/diffusion/transformer_for_diffusion.py',
|
||||
}
|
||||
if not all(path.exists() for path in paths.values()):
|
||||
return None
|
||||
return paths
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _temporary_registered_modules():
|
||||
previous_modules = {}
|
||||
|
||||
def remember(name: str) -> None:
|
||||
if name not in previous_modules:
|
||||
previous_modules[name] = sys.modules.get(name, _MISSING)
|
||||
|
||||
def ensure_package(name: str) -> None:
|
||||
if not name or name in sys.modules:
|
||||
return
|
||||
remember(name)
|
||||
package = types.ModuleType(name)
|
||||
package.__path__ = []
|
||||
sys.modules[name] = package
|
||||
|
||||
def load(name: str, path: Path):
|
||||
package_parts = name.split('.')[:-1]
|
||||
for idx in range(1, len(package_parts) + 1):
|
||||
ensure_package('.'.join(package_parts[:idx]))
|
||||
|
||||
remember(name)
|
||||
return _load_module_from_path(name, path, register=True)
|
||||
|
||||
try:
|
||||
yield load
|
||||
finally:
|
||||
for name, previous in reversed(list(previous_modules.items())):
|
||||
if previous is _MISSING:
|
||||
sys.modules.pop(name, None)
|
||||
else:
|
||||
sys.modules[name] = previous
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _suppress_nested_tensor_warning():
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings(
|
||||
'ignore',
|
||||
message=_TRANSFORMER_WARNING_MESSAGE,
|
||||
category=UserWarning,
|
||||
module=r'torch\.nn\.modules\.transformer',
|
||||
)
|
||||
yield
|
||||
|
||||
|
||||
def _load_local_module():
|
||||
return _load_module_from_path('local_transformer1d_alignment', _LOCAL_MODULE_PATH)
|
||||
|
||||
|
||||
class Transformer1DExternalAlignmentTest(unittest.TestCase):
|
||||
def _load_transformer_classes_or_skip(self):
|
||||
external_paths = _resolve_external_module_paths(_EXTERNAL_CHECKOUT_ROOT)
|
||||
if external_paths is None:
|
||||
self.skipTest(f'external diffusion_policy checkout unavailable under {_EXTERNAL_CHECKOUT_ROOT}')
|
||||
|
||||
local_module = _load_local_module()
|
||||
with _temporary_registered_modules() as load_external:
|
||||
load_external(
|
||||
'diffusion_policy.model.diffusion.positional_embedding',
|
||||
external_paths['positional_embedding'],
|
||||
)
|
||||
load_external(
|
||||
'diffusion_policy.model.common.module_attr_mixin',
|
||||
external_paths['module_attr_mixin'],
|
||||
)
|
||||
external_module = load_external(
|
||||
'diffusion_policy.model.diffusion.transformer_for_diffusion',
|
||||
external_paths['transformer_for_diffusion'],
|
||||
)
|
||||
|
||||
return local_module.Transformer1D, local_module.create_transformer1d, external_module.TransformerForDiffusion
|
||||
|
||||
def _optim_group_names(self, model, groups):
|
||||
names_by_param = {id(param): name for name, param in model.named_parameters()}
|
||||
return [
|
||||
{names_by_param[id(param)] for param in group['params']}
|
||||
for group in groups
|
||||
]
|
||||
|
||||
def test_missing_external_checkout_resolution_returns_none(self):
|
||||
self.assertIsNone(_resolve_external_module_paths(_REPO_ROOT / '__missing_diffusion_policy_checkout__'))
|
||||
|
||||
def test_external_loader_restores_injected_sys_modules(self):
|
||||
external_paths = _resolve_external_module_paths(_EXTERNAL_CHECKOUT_ROOT)
|
||||
if external_paths is None:
|
||||
self.skipTest(f'external diffusion_policy checkout unavailable under {_EXTERNAL_CHECKOUT_ROOT}')
|
||||
|
||||
watched_names = [
|
||||
'diffusion_policy',
|
||||
'diffusion_policy.model',
|
||||
'diffusion_policy.model.common',
|
||||
'diffusion_policy.model.common.module_attr_mixin',
|
||||
'diffusion_policy.model.diffusion',
|
||||
'diffusion_policy.model.diffusion.positional_embedding',
|
||||
'diffusion_policy.model.diffusion.transformer_for_diffusion',
|
||||
]
|
||||
before = {name: sys.modules.get(name, _MISSING) for name in watched_names}
|
||||
|
||||
with _temporary_registered_modules() as load_external:
|
||||
load_external(
|
||||
'diffusion_policy.model.diffusion.positional_embedding',
|
||||
external_paths['positional_embedding'],
|
||||
)
|
||||
load_external(
|
||||
'diffusion_policy.model.common.module_attr_mixin',
|
||||
external_paths['module_attr_mixin'],
|
||||
)
|
||||
load_external(
|
||||
'diffusion_policy.model.diffusion.transformer_for_diffusion',
|
||||
external_paths['transformer_for_diffusion'],
|
||||
)
|
||||
|
||||
after = {name: sys.modules.get(name, _MISSING) for name in watched_names}
|
||||
self.assertEqual(after, before)
|
||||
|
||||
def test_transformer1d_preserves_local_direct_call_defaults(self):
|
||||
local_module = _load_local_module()
|
||||
ctor = inspect.signature(local_module.Transformer1D.__init__).parameters
|
||||
helper = inspect.signature(local_module.create_transformer1d).parameters
|
||||
|
||||
self.assertEqual(ctor['n_layer'].default, 8)
|
||||
self.assertEqual(ctor['n_head'].default, 8)
|
||||
self.assertEqual(ctor['n_emb'].default, 256)
|
||||
self.assertEqual(helper['n_layer'].default, 8)
|
||||
self.assertEqual(helper['n_head'].default, 8)
|
||||
self.assertEqual(helper['n_emb'].default, 256)
|
||||
|
||||
def test_time_as_cond_false_token_accounting_matches_external(self):
|
||||
Transformer1D, _, TransformerForDiffusion = self._load_transformer_classes_or_skip()
|
||||
self.assertIn('time_as_cond', inspect.signature(Transformer1D.__init__).parameters)
|
||||
|
||||
config = dict(
|
||||
input_dim=4,
|
||||
output_dim=4,
|
||||
horizon=6,
|
||||
n_obs_steps=3,
|
||||
cond_dim=0,
|
||||
n_layer=2,
|
||||
n_head=2,
|
||||
n_emb=8,
|
||||
p_drop_emb=0.0,
|
||||
p_drop_attn=0.0,
|
||||
causal_attn=False,
|
||||
time_as_cond=False,
|
||||
obs_as_cond=False,
|
||||
n_cond_layers=0,
|
||||
)
|
||||
|
||||
torch.manual_seed(5)
|
||||
with _suppress_nested_tensor_warning():
|
||||
external_model = TransformerForDiffusion(**config)
|
||||
local_model = Transformer1D(**config)
|
||||
external_model.eval()
|
||||
local_model.eval()
|
||||
|
||||
self.assertEqual(local_model.T, external_model.T)
|
||||
self.assertEqual(local_model.T_cond, external_model.T_cond)
|
||||
self.assertEqual(local_model.time_as_cond, external_model.time_as_cond)
|
||||
self.assertEqual(local_model.obs_as_cond, external_model.obs_as_cond)
|
||||
self.assertEqual(local_model.encoder_only, external_model.encoder_only)
|
||||
|
||||
def test_nocausal_state_dict_forward_and_optim_groups_match_external(self):
|
||||
Transformer1D, _, TransformerForDiffusion = self._load_transformer_classes_or_skip()
|
||||
config = dict(
|
||||
input_dim=4,
|
||||
output_dim=4,
|
||||
horizon=6,
|
||||
n_obs_steps=3,
|
||||
cond_dim=5,
|
||||
n_layer=2,
|
||||
n_head=2,
|
||||
n_emb=8,
|
||||
p_drop_emb=0.0,
|
||||
p_drop_attn=0.0,
|
||||
causal_attn=False,
|
||||
obs_as_cond=True,
|
||||
n_cond_layers=1,
|
||||
)
|
||||
|
||||
torch.manual_seed(7)
|
||||
with _suppress_nested_tensor_warning():
|
||||
external_model = TransformerForDiffusion(**config)
|
||||
local_model = Transformer1D(**config)
|
||||
external_model.eval()
|
||||
local_model.eval()
|
||||
|
||||
external_state_dict = external_model.state_dict()
|
||||
self.assertEqual(set(local_model.state_dict().keys()), set(external_state_dict.keys()))
|
||||
local_model.load_state_dict(external_state_dict, strict=True)
|
||||
|
||||
batch_size = 2
|
||||
sample = torch.randn(batch_size, config['horizon'], config['input_dim'])
|
||||
cond = torch.randn(batch_size, config['n_obs_steps'], config['cond_dim'])
|
||||
timestep = torch.tensor([11, 17], dtype=torch.long)
|
||||
|
||||
with torch.no_grad():
|
||||
external_out = external_model(sample=sample, timestep=timestep, cond=cond)
|
||||
local_out = local_model(sample=sample, timestep=timestep, cond=cond)
|
||||
|
||||
self.assertEqual(local_out.shape, (batch_size, config['horizon'], config['output_dim']))
|
||||
self.assertEqual(local_out.shape, external_out.shape)
|
||||
self.assertTrue(torch.allclose(local_out, external_out, atol=1e-6, rtol=1e-5))
|
||||
|
||||
weight_decay = 0.123
|
||||
external_groups = external_model.get_optim_groups(weight_decay=weight_decay)
|
||||
local_groups = local_model.get_optim_groups(weight_decay=weight_decay)
|
||||
|
||||
self.assertEqual(len(local_groups), len(external_groups))
|
||||
self.assertEqual([group['weight_decay'] for group in local_groups], [weight_decay, 0.0])
|
||||
self.assertEqual(
|
||||
self._optim_group_names(local_model, local_groups),
|
||||
self._optim_group_names(external_model, external_groups),
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user