feat(lewm): enable gpu parallel rollout validation
docs/lewm-imf-experiment-guide.md | 471 (new file)
@@ -0,0 +1,471 @@
# feat-lewm-imf-fusion Experiment Playbook

Applies to worktree: `/home/droid/project/roboimi/.worktrees/feat-lewm-imf-fusion`

## 0. Start with the current go-to recipe

The most frequently used train/validate recipe on this branch lives in:

`experiment_suites/2026-04-21-lewm-fromscratch-old9-epoch50-roll5-val-20260421-153037/`

Core conventions:

- agent: `lewm_resnet_query_imf_attnres`
- from scratch: `train.pretrained_ckpt=null`, `agent.lewm_pretrained_ckpt=null`
- training: `batch_size=32`, `lr=1e-4`, `max_steps=109350`, `save_freq=10000`
- numeric validation: `train.val_split=0.0` + `train.val_episode_indices=[100]`
- held-out numeric validation: `train.action_mse_val_freq_epochs=1`
- rollout validation: `train.rollout_val_freq_epochs=5`, `train.rollout_num_episodes=10`
- SwanLab: `train.use_swanlab=true`, project=`roboimi-vla`
---

## 1. Branch layout and key files

| Path | Purpose |
| --- | --- |
| `roboimi/demos/vla_scripts/train_vla.py` | Main training entry point; handles datasets, checkpoints, numeric validation, in-training rollout validation, SwanLab |
| `roboimi/demos/vla_scripts/eval_vla.py` | Single-rollout / offline validation entry point; supports headless mode, summaries, trajectory image/video artifacts |
| `roboimi/vla/conf/config.yaml` | Global Hydra config; all training defaults live here |
| `roboimi/vla/conf/eval/eval.yaml` | Eval defaults; `eval.ckpt_path`, `eval.num_episodes`, and the artifact switches live here |
| `roboimi/vla/conf/agent/lewm_resnet_query_imf_attnres.yaml` | The agent this branch uses most; LeWM query fusion + IMF AttnRes head |
| `roboimi/vla/conf/backbone/lewm_resnet_query_fusion.yaml` | LeWM multi-view ResNet query fusion backbone config |
| `roboimi/vla/agent_imf.py` | `IMFVLAAgent` implementation; one-step IMF inference, LeWM loss, loading of LeWM pretrained components |
| `roboimi/vla/data/simpe_robot_dataset.py` | Lazy-loading HDF5 dataset; also handles `episode_indices` filtering |
| `roboimi/vla/scripts/calculate_stats.py` | Recomputes `dataset_stats.pkl` |
| `experiment_suites/2026-04-21-lewm-fromscratch-old9-epoch50-roll5-val-20260421-153037/` | Current go-to suite; manifest, notes, launch log, and local launch script all live here |

Notes:

- Run names on this branch look like `lewmimf-q08-ph08-ex08-emb384-l12-fromscratch-epoch50-step109350-5090g0-20260421-153037`
- Suffixes such as `q08/ph16/ex08` map to `agent.lewm_query_offsets`, `agent.pred_horizon`, and `agent.num_action_steps` respectively

---
## 2. The three machines and their environments

| Machine | GPU | repo / worktree | Python | Usual dataset path |
| --- | --- | --- | --- | --- |
| Local `droid-z790eagleax` | 1× RTX 5090 32GB | `/home/droid/project/roboimi/.worktrees/feat-lewm-imf-fusion` | `/home/droid/.conda/envs/roboimi/bin/python` | `/home/droid/project/diana_sim/sim_transfer` |
| 5880 node `100.73.14.65` | 2× RTX 5880 Ada 48GB | `/home/droid/roboimi_suite_20260416_lewm_imf_fusion` | `/home/droid/miniforge3/envs/roboimi/bin/python` | `/home/droid/sim_dataset/sim_transfer` |
| L20 node `100.119.99.14` | 8× NVIDIA L20 46GB | `/data/roboimi_suite_20260416_lewm_imf_fusion` | `/home/droid/miniforge3/envs/roboimi/bin/python` | `/data/simtransfer/current` |

Connecting:

- 5880: `ssh droid@100.73.14.65`
- L20: `ssh droid@100.119.99.14`

Rules of thumb:

- Local 5090: good for single smoke runs, small main runs, and local tuning
- 5880: good for two parallel main runs
- L20: good for large grids; keep both data and runs under `/data`

---
## 3. How the training flow works

The actual flow in `train_vla.py`:

1. Load the Hydra config and print the full cfg
2. Build train/val datasets via `build_train_val_datasets()`
3. Build train/val loaders with `DataLoader`
4. Load normalization statistics from `dataset_dir/dataset_stats.pkl`
5. Instantiate `IMFVLAAgent`
6. Optionally load:
   - `train.pretrained_ckpt`
   - `train.resume_ckpt`
   - `agent.lewm_pretrained_ckpt`
7. In the training loop, log train loss / lr every `log_freq` steps
8. Save `checkpoints/vla_model_step_*.pt` every `save_freq` steps
9. At the end of each epoch, run as configured:
   - held-out action MSE
   - rollout validation
10. Finally write:
    - `checkpoints/vla_model_best.pt`
    - `checkpoints/vla_model_final.pt`

Current best-model selection logic:

- **Before the first rollout reward arrives**: pick the best by `val_loss` (falling back to train loss)
- **After the first rollout**: prefer `rollout_avg_reward` for picking the best
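As a minimal sketch of the selection rule above (the helper name and dict-based signature are hypothetical, not the actual `train_vla.py` code):

```python
def is_new_best(metrics, best):
    """Return True if this checkpoint should replace the current best.

    `metrics` / `best` are dicts with optional keys 'rollout_avg_reward'
    and 'val_loss' (where 'val_loss' may really be a train-loss fallback).
    Hypothetical sketch of the rule described in the guide.
    """
    reward = metrics.get("rollout_avg_reward")
    best_reward = best.get("rollout_avg_reward")
    if reward is not None:
        # Once any rollout reward exists, higher reward wins.
        return best_reward is None or reward > best_reward
    if best_reward is not None:
        # Never demote a reward-selected best on loss alone.
        return False
    # Before the first rollout: lower (val or train) loss wins.
    return metrics["val_loss"] < best.get("val_loss", float("inf"))
```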
The output directory is usually pinned via `hydra.run.dir=...`; otherwise Hydra generates one itself.

---

## 4. How the validation flow works

### 4.1 Held-out numeric validation

The current practice is not a random `val_split` but:

- `train.val_split=0.0`
- `train.val_episode_indices=[100]`
- `train.action_mse_val_freq_epochs=1`

With this, `compute_action_mse_validation()` runs on `episode_100.hdf5` at the end of every epoch. The log keys are:

- console / `train_vla.log`: `held-out action MSE`
- SwanLab: `val/action_mse`
### 4.2 Rollout validation

In-training rollout validation is triggered via `train_vla.py -> run_rollout_validation() -> eval_vla._run_eval()`.

The usual in-training rollout constraints on this branch are:

- `train.rollout_val_freq_epochs=5`
- `train.rollout_num_episodes=10`
- `train.rollout_validate_on_checkpoint=false`
- headless forced on
- `verbose_action=false` forced
- `record_video=false` forced
- `save_trajectory_image=true` forced
- `trajectory_image_camera_name=front` forced
- `save_summary_json=true` forced

The rollout device / worker path is now **config-driven**:

- `train.rollout_device`: defaults to following `train.device`
- `train.rollout_num_workers`: defaults to `null`
  - when the rollout device is CPU, it degrades to `1`
  - when the rollout device is CUDA, it is inferred as `min(train.rollout_num_episodes, 8)`
- `train.rollout_cuda_devices`: defaults to `null`, equivalent to the currently visible logical GPU `[0]`
- `train.rollout_response_timeout_s`
- `train.rollout_server_startup_timeout_s`
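The worker-count defaulting rule can be sketched as follows (the standalone function name is illustrative; the rule itself mirrors the `train_vla.py` hunk in this commit):

```python
def resolve_rollout_num_workers(configured, rollout_device, rollout_num_episodes):
    # An explicit config value always wins; otherwise CUDA infers the
    # worker count from the episode count (capped at 8) and CPU stays
    # serial with a single worker.
    if configured is not None:
        return int(configured)
    if str(rollout_device).startswith("cuda"):
        return min(max(int(rollout_num_episodes), 1), 8)
    return 1
```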
So now:

- when training runs on `cuda`, **in-training rollout goes to the GPU by default**
- if `rollout_num_workers > 1`, rollout automatically runs in parallel:
  - either as **multiple workers on a single GPU sharing one inference server**
  - or as **multiple GPUs with multiple servers splitting the workers**

In-training rollout artifacts land in:

`<hydra.run.dir>/rollout_artifacts/<checkpoint_stem>/`

Typical files:

- `rollout_summary.json`
- `rollout_front_ep01_trajectory.png` ... `rollout_front_ep10_trajectory.png`

In the logs, watch for:

- `Epoch X rollout 平均奖励` (epoch rollout average reward)
- `最佳模型已更新` (best model updated)

---
## 5. Dataset loading and the `val_episode_indices` mechanism

### 5.1 Dataset format

`SimpleRobotDataset` reads the `episode_*.hdf5` files under `dataset_dir`; each episode file must contain at least:

- `action`
- `observations/qpos`
- `observations/images/{cam_name}`

Commonly used cameras:

- `r_vis`
- `top`
- `front`

### 5.2 Lazy-loading behavior

`roboimi/vla/data/simpe_robot_dataset.py` lazy-loads frame by frame; it never reads the whole HDF5 set into memory at once.

It will:

- scan the directory for HDF5 files
- build `available_episode_indices` from the episode number in each filename (e.g. `episode_100.hdf5` -> `100`)
- keep an LRU cache of HDF5 file handles inside each worker
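The filename-to-index mapping above can be sketched like this (the function name is illustrative, not the dataset's actual helper):

```python
import re


def episode_index_from_filename(filename):
    # "episode_100.hdf5" -> 100; returns None for non-episode files.
    match = re.fullmatch(r"episode_(\d+)\.hdf5", filename)
    return int(match.group(1)) if match else None
```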
### 5.3 How `val_episode_indices` splits

The logic in `build_train_val_datasets()`:

1. Instantiate the full dataset once
2. Read `dataset.available_episode_indices`
3. Check that every index in `train.val_episode_indices` exists
4. Instantiate twice more with `episode_indices=`:
   - train dataset = all episodes minus the held-out episodes
   - val dataset = only the held-out episodes

Therefore:

- `train.val_episode_indices=[100]` means "hold out the entirety of `episode_100.hdf5` as val"
- if an episode does not exist, it errors out immediately
- if you put every episode into `val_episode_indices`, it also errors out, because the train set would be empty
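The split semantics (including both error cases) can be sketched as pure index bookkeeping; `build_train_val_datasets()` itself re-instantiates the dataset with `episode_indices=`, so the function below is only an illustrative model:

```python
def split_train_val_episodes(available, val_episode_indices):
    # Held-out indices must all exist in the dataset directory.
    missing = [i for i in val_episode_indices if i not in available]
    if missing:
        raise ValueError(f"held-out episodes not found: {missing}")
    train = [i for i in available if i not in val_episode_indices]
    # Holding out everything would leave an empty train set.
    if not train:
        raise ValueError("all episodes held out; train set would be empty")
    return train, list(val_episode_indices)
```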
### 5.4 Image resize and extra LeWM fields

The dataset-side resize defaults come from:

- `data.image_resize_shape`
- if the backbone overrides it, `agent.vision_backbone.dataset_image_resize_shape` takes priority

Beyond the regular batch keys:

- `observation.state`
- `observation.<cam>`
- `action`

the batch also carries these when LeWM is enabled:

- `lewm.observation.state`
- `lewm.observation.<cam>`
- `lewm.future.state`
- `lewm.future.<cam>`

### 5.5 Statistics file

Training and inference both rely on `dataset_stats.pkl` by default. After a dataset update, recompute it:

```bash
/home/droid/.conda/envs/roboimi/bin/python roboimi/vla/scripts/calculate_stats.py \
  --dataset_dir /home/droid/project/diana_sim/sim_transfer
```

On remote machines, just swap `--dataset_dir` for that host's path.

---
## 6. SwanLab behavior

The config default is `train.use_swanlab=false`, but the usual recipes on this branch turn it on explicitly:

- `train.use_swanlab=true`
- `train.swanlab_project=roboimi-vla`
- `train.swanlab_run_name=<run_name>`

SwanLab behavior in `train_vla.py`:

- on init, uploads the `train` / `data` / `agent` config sections
- during training, logs:
  - `train/loss`
  - `train/lr`
  - `train/best_loss`
  - `train/step`
- at checkpoint validation, logs:
  - `val/loss`
- at held-out numeric validation, logs:
  - `val/action_mse`
- at rollout validation, logs:
  - `rollout/avg_reward`
  - `rollout/epoch`
- at the end of training, logs:
  - `final/checkpoint_path`
  - `final/best_checkpoint_path`

The front-view trajectory PNGs produced by in-training rollouts are uploaded to SwanLab on a best-effort basis; a failure only logs a warning and never interrupts training.

---
## 7. Parallel rollout notes

### 7.1 Where this capability comes from

Parallel rollout on this branch is not DataLoader parallelism but the **multiprocess rollout path in `eval_vla.py`**.
Reference source:
`/home/droid/project/roboimi/.worktrees/multiprocess-rollout/roboimi/demos/vla_scripts/eval_vla.py`

That path is controlled by:

- `eval.num_workers`
- `eval.cuda_devices`

Their semantics:

- `eval.num_workers`: number of environment workers; episodes are split across them
- `eval.cuda_devices`: which logical GPUs the inference servers bind to
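The episode split is balanced across workers. A sketch consistent with the `_split_episode_indices` unit test added in this commit (the standalone function here is illustrative):

```python
def split_episode_indices(num_episodes, num_workers):
    # Earlier workers absorb the remainder, so chunk sizes differ by at
    # most one episode.
    base, extra = divmod(num_episodes, num_workers)
    chunks, start = [], 0
    for worker in range(num_workers):
        size = base + (1 if worker < extra else 0)
        chunks.append(list(range(start, start + size)))
        start += size
    return chunks
```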
### 7.2 Two common modes

1. **Single machine, single GPU, multiple workers sharing one GPU**
   - Typical: the local 5090 has only one GPU, but you want 4 rollout workers running environments in parallel
   - Form: `eval.device=cuda eval.num_workers=4 'eval.cuda_devices=[0]'`
   - This gives **1 CUDA inference server + 4 env workers**

2. **Single machine, multiple GPUs, multiple servers splitting the workers**
   - Typical: the 5880 has 2 GPUs, the L20 node has more
   - Form: `eval.device=cuda eval.num_workers=8 'eval.cuda_devices=[0,1]'`
   - Workers are assigned to the servers round-robin
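The round-robin assignment can be sketched as below (illustrative function; the expected values match the `_build_cuda_server_payloads` round-robin unit test added in this commit):

```python
def assign_workers_to_servers(num_workers, num_servers):
    # Worker w goes to server w % num_servers: with 4 workers and 2
    # servers the per-worker server indices are [0, 1, 0, 1], so server 0
    # owns workers [0, 2] and server 1 owns workers [1, 3].
    server_of_worker = [w % num_servers for w in range(num_workers)]
    workers_of_server = [
        [w for w, s in enumerate(server_of_worker) if s == srv]
        for srv in range(num_servers)
    ]
    return server_of_worker, workers_of_server
```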
### 7.3 Operational caveats

- Parallel rollout relies on the **multiprocess eval path**, not on `train.num_workers`
- `train.num_workers` is the DataLoader worker count; it has nothing to do with rollout parallelism
- `eval.num_workers > 1` requires `eval.headless=true`
- The worker count is automatically capped at `eval.num_episodes`
- Multiprocess rollout now supports **per-episode trajectory image PNGs**; with multiple workers, each worker writes images into its own artifact subdirectory, and the summary carries the corresponding paths back
- With multiple workers, still do not request all of these at once:
  - `eval.record_video=true`
  - `eval.save_trajectory=true`
  - `eval.save_trajectory_npz=true`
- `eval.save_trajectory_image=true` is now safe to enable, and pairs well with running parallel reward measurement and qualitative checks together
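The worker cap mentioned above can be sketched as follows (illustrative function; the expected value matches the `_normalize_num_workers` unit test added in this commit):

```python
def normalize_num_workers(num_workers, num_episodes):
    # Never spawn more env workers than there are episodes to run,
    # and always keep at least one worker.
    return max(1, min(int(num_workers), int(num_episodes)))
```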
### 7.4 Parallel rollout command templates

**5090, single GPU, 4 workers:**

```bash
/home/droid/.conda/envs/roboimi/bin/python roboimi/demos/vla_scripts/eval_vla.py \
  agent=lewm_resnet_query_imf_attnres \
  data.dataset_dir=/home/droid/project/diana_sim/sim_transfer \
  train.device=cuda eval.device=cuda eval.headless=true eval.verbose_action=false \
  eval.ckpt_path=/home/droid/project/roboimi/.worktrees/feat-lewm-imf-fusion/runs/<run_name>/checkpoints/vla_model_best.pt \
  eval.num_episodes=10 eval.num_workers=4 'eval.cuda_devices=[0]' \
  eval.save_summary_json=true eval.artifact_dir=/tmp/lewm_parallel_eval_5090
```

**5880, two GPUs, 8 workers:**

```bash
/home/droid/miniforge3/envs/roboimi/bin/python roboimi/demos/vla_scripts/eval_vla.py \
  agent=lewm_resnet_query_imf_attnres \
  data.dataset_dir=/home/droid/sim_dataset/sim_transfer \
  train.device=cuda eval.device=cuda eval.headless=true eval.verbose_action=false \
  eval.ckpt_path=/home/droid/roboimi_suite_20260416_lewm_imf_fusion/runs/<run_name>/checkpoints/vla_model_best.pt \
  eval.num_episodes=10 eval.num_workers=8 'eval.cuda_devices=[0,1]' \
  eval.save_summary_json=true eval.artifact_dir=/tmp/lewm_parallel_eval_5880
```

---
## 8. Current go-to commands / scripts

### 8.1 Local 5090: use the suite script directly

Ready-made script:
`experiment_suites/2026-04-21-lewm-fromscratch-old9-epoch50-roll5-val-20260421-153037/launch_local_5090.sh`

Run it:

```bash
bash experiment_suites/2026-04-21-lewm-fromscratch-old9-epoch50-roll5-val-20260421-153037/launch_local_5090.sh
```
### 8.2 Local 5090: launch the same recipe by hand

```bash
/home/droid/.conda/envs/roboimi/bin/python roboimi/demos/vla_scripts/train_vla.py \
  agent=lewm_resnet_query_imf_attnres \
  data.dataset_dir=/home/droid/project/diana_sim/sim_transfer \
  'agent.lewm_query_offsets=[8]' \
  agent.pred_horizon=8 \
  agent.num_action_steps=8 \
  train.device=cuda \
  train.batch_size=32 \
  train.lr=0.0001 \
  train.max_steps=109350 \
  train.num_workers=4 \
  train.save_freq=10000 \
  train.rollout_validate_on_checkpoint=false \
  train.rollout_val_freq_epochs=5 \
  train.rollout_num_episodes=10 \
  train.val_split=0.0 \
  'train.val_episode_indices=[100]' \
  train.action_mse_val_freq_epochs=1 \
  train.use_swanlab=true \
  train.swanlab_project=roboimi-vla \
  train.swanlab_run_name=lewmimf-q08-ph08-ex08-emb384-l12-fromscratch-epoch50-step109350-5090g0-20260421-153037 \
  train.pretrained_ckpt=null \
  agent.lewm_pretrained_ckpt=null \
  hydra.run.dir=/home/droid/project/roboimi/.worktrees/feat-lewm-imf-fusion/runs/lewmimf-q08-ph08-ex08-emb384-l12-fromscratch-epoch50-step109350-5090g0-20260421-153037
```
### 8.3 5880: usual command template

```bash
ssh droid@100.73.14.65
cd /home/droid/roboimi_suite_20260416_lewm_imf_fusion
/home/droid/miniforge3/envs/roboimi/bin/python roboimi/demos/vla_scripts/train_vla.py \
  agent=lewm_resnet_query_imf_attnres \
  data.dataset_dir=/home/droid/sim_dataset/sim_transfer \
  'agent.lewm_query_offsets=[8]' \
  agent.pred_horizon=16 \
  agent.num_action_steps=8 \
  train.device=cuda train.batch_size=32 train.lr=0.0001 train.max_steps=109350 \
  train.num_workers=4 train.save_freq=10000 train.rollout_validate_on_checkpoint=false \
  train.rollout_val_freq_epochs=5 train.rollout_num_episodes=10 train.val_split=0.0 \
  'train.val_episode_indices=[100]' train.action_mse_val_freq_epochs=1 \
  train.use_swanlab=true train.swanlab_project=roboimi-vla \
  train.swanlab_run_name=lewmimf-q08-ph16-ex08-emb384-l12-fromscratch-epoch50-step109350-5880g0-20260421-153037 \
  train.pretrained_ckpt=null agent.lewm_pretrained_ckpt=null \
  hydra.run.dir=/home/droid/roboimi_suite_20260416_lewm_imf_fusion/runs/lewmimf-q08-ph16-ex08-emb384-l12-fromscratch-epoch50-step109350-5880g0-20260421-153037
```
### 8.4 L20: usual command template

```bash
ssh droid@100.119.99.14
cd /data/roboimi_suite_20260416_lewm_imf_fusion
/home/droid/miniforge3/envs/roboimi/bin/python roboimi/demos/vla_scripts/train_vla.py \
  agent=lewm_resnet_query_imf_attnres \
  data.dataset_dir=/data/simtransfer/current \
  'agent.lewm_query_offsets=[16]' \
  agent.pred_horizon=16 \
  agent.num_action_steps=16 \
  train.device=cuda train.batch_size=32 train.lr=0.0001 train.max_steps=109350 \
  train.num_workers=4 train.save_freq=10000 train.rollout_validate_on_checkpoint=false \
  train.rollout_val_freq_epochs=5 train.rollout_num_episodes=10 train.val_split=0.0 \
  'train.val_episode_indices=[100]' train.action_mse_val_freq_epochs=1 \
  train.use_swanlab=true train.swanlab_project=roboimi-vla \
  train.swanlab_run_name=lewmimf-q16-ph16-ex16-emb384-l12-fromscratch-epoch50-step109350-l20g0-20260421-153037 \
  train.pretrained_ckpt=null agent.lewm_pretrained_ckpt=null \
  hydra.run.dir=/data/roboimi_suite_20260416_lewm_imf_fusion/runs/lewmimf-q16-ph16-ex16-emb384-l12-fromscratch-epoch50-step109350-l20g0-20260421-153037
```
### 8.5 One-off offline validation (parallel now supported on this branch)

**Single GPU / 4 workers:**

```bash
/home/droid/.conda/envs/roboimi/bin/python roboimi/demos/vla_scripts/eval_vla.py \
  agent=lewm_resnet_query_imf_attnres \
  data.dataset_dir=/home/droid/project/diana_sim/sim_transfer \
  train.device=cuda eval.device=cuda \
  eval.ckpt_path=/home/droid/project/roboimi/.worktrees/feat-lewm-imf-fusion/runs/<run_name>/checkpoints/vla_model_best.pt \
  eval.num_episodes=10 eval.num_workers=4 'eval.cuda_devices=[0]' \
  eval.headless=true eval.verbose_action=false \
  eval.save_summary_json=true eval.save_trajectory_image=true \
  eval.trajectory_image_camera_name=front \
  eval.artifact_dir=/tmp/lewm_eval_front
```

**Enabling parallel GPU rollout inside training (recommended: spell it out explicitly):**

```bash
/home/droid/.conda/envs/roboimi/bin/python roboimi/demos/vla_scripts/train_vla.py \
  agent=lewm_resnet_query_imf_attnres \
  data.dataset_dir=/home/droid/project/diana_sim/sim_transfer \
  'agent.lewm_query_offsets=[8]' \
  agent.pred_horizon=8 \
  agent.num_action_steps=8 \
  train.device=cuda \
  train.batch_size=32 \
  train.lr=0.0001 \
  train.max_steps=109350 \
  train.num_workers=4 \
  train.save_freq=10000 \
  train.rollout_val_freq_epochs=5 \
  train.rollout_num_episodes=10 \
  train.rollout_device=cuda \
  train.rollout_num_workers=4 \
  'train.rollout_cuda_devices=[0]' \
  train.rollout_validate_on_checkpoint=false \
  train.val_split=0.0 \
  'train.val_episode_indices=[100]' \
  train.action_mse_val_freq_epochs=1 \
  train.use_swanlab=true \
  train.swanlab_project=roboimi-vla \
  train.swanlab_run_name=<run_name> \
  hydra.run.dir=/home/droid/project/roboimi/.worktrees/feat-lewm-imf-fusion/runs/<run_name>
```
### 8.6 Monitoring logs

```bash
tail -f runs/<run_name>/launch.stdout.log
tail -f runs/<run_name>/train_vla.log
```

On remote machines, replace `runs/<run_name>` with the absolute path from the manifest.

---
## 9. Operating advice

- **Treat the suite's `manifest.json` / `notes.md` / `launch_logs/*.launch.log` as the source of truth**; do not hand-write a command set that diverges from historical runs
- For the current standard validation, explicitly add:
  - `train.val_split=0.0`
  - `train.val_episode_indices=[100]`
  - `train.action_mse_val_freq_epochs=1`
  - `train.rollout_val_freq_epochs=5`
  - `train.rollout_num_episodes=10`
- When comparing different horizons / action steps on this branch, change only:
  - `agent.lewm_query_offsets`
  - `agent.pred_horizon`
  - `agent.num_action_steps`
- To reproduce the 2026-04-21 from-scratch round, remember to also set:
  - `train.pretrained_ckpt=null`
  - `agent.lewm_pretrained_ckpt=null`
File diff suppressed because it is too large
@@ -838,10 +838,28 @@ def _run_training(cfg: DictConfig):
     from roboimi.demos.vla_scripts import eval_vla

     rollout_cfg = OmegaConf.create(OmegaConf.to_container(cfg, resolve=False))
+    rollout_num_episodes = int(cfg.train.get('rollout_num_episodes', 1))
+    rollout_device = str(cfg.train.get('rollout_device', cfg.train.device))
+    configured_rollout_workers = cfg.train.get('rollout_num_workers', None)
+    if configured_rollout_workers is None:
+        if rollout_device.startswith('cuda'):
+            rollout_num_workers = min(max(rollout_num_episodes, 1), 8)
+        else:
+            rollout_num_workers = 1
+    else:
+        rollout_num_workers = int(configured_rollout_workers)
     rollout_cfg.eval.ckpt_path = str(checkpoint_path)
-    rollout_cfg.eval.num_episodes = int(cfg.train.get('rollout_num_episodes', 1))
+    rollout_cfg.eval.num_episodes = rollout_num_episodes
+    rollout_cfg.eval.num_workers = rollout_num_workers
     rollout_cfg.eval.headless = True
-    rollout_cfg.eval.device = 'cpu'
+    rollout_cfg.eval.device = rollout_device
+    rollout_cfg.eval.cuda_devices = cfg.train.get('rollout_cuda_devices', None)
+    rollout_cfg.eval.response_timeout_s = float(
+        cfg.train.get('rollout_response_timeout_s', 300.0)
+    )
+    rollout_cfg.eval.server_startup_timeout_s = float(
+        cfg.train.get('rollout_server_startup_timeout_s', 300.0)
+    )
     rollout_cfg.eval.verbose_action = False
     rollout_cfg.eval.record_video = False
     rollout_cfg.eval.save_trajectory_image = True
@@ -852,9 +870,11 @@ def _run_training(cfg: DictConfig):
     )

     log.info(
-        "🎯 开始 checkpoint rollout 验证: %s (episodes=%s, headless=True)",
+        "🎯 开始 checkpoint rollout 验证: %s (episodes=%s, device=%s, workers=%s, headless=True)",
         checkpoint_path,
         rollout_cfg.eval.num_episodes,
+        rollout_cfg.eval.device,
+        rollout_cfg.eval.num_workers,
     )
     return eval_vla._run_eval(rollout_cfg)
@@ -31,6 +31,11 @@ train:
   rollout_val_freq_epochs: 50 # 每隔多少个 epoch 执行一次 rollout 验证
   rollout_validate_on_checkpoint: false # 是否在保存 checkpoint 后立即运行 rollout 验证
   rollout_num_episodes: 3 # rollout 验证的回合数
+  rollout_device: ${train.device} # rollout 使用的设备;默认跟随训练设备
+  rollout_num_workers: null # rollout 并行 worker 数;null 时 CUDA 自动推断,CPU 保持 1
+  rollout_cuda_devices: null # rollout CUDA 并行使用的逻辑 device 列表;null 时默认 [0]
+  rollout_response_timeout_s: 300.0 # rollout worker 等待 inference server 响应的超时时间
+  rollout_server_startup_timeout_s: 300.0 # rollout 等待 inference server 就绪的超时时间

   # 学习率调度器(带预热)
   warmup_steps: 2000 # 预热步数(Transformer建议更长)
@@ -2,6 +2,10 @@
 # 评估配置
 ckpt_path: "checkpoints/vla_model_best.pt" # 模型检查点路径
 num_episodes: 3 # 评估回合数
+num_workers: 1 # 并行 worker 数;1 表示保持单进程评估
+cuda_devices: null # CUDA 并行评估时使用的逻辑设备列表;null 表示默认 [0]
+response_timeout_s: 300.0 # worker 等待 inference server 响应的超时时间(秒)
+server_startup_timeout_s: 300.0 # parent 等待 inference server 就绪的超时时间(秒)
 max_timesteps: 700 # 每回合最大时间步
 device: ${train.device} # 与训练保持一致
 task_name: "sim_transfer" # 环境任务名称
@@ -1,5 +1,11 @@
 import unittest
+from unittest import mock

+import numpy as np
 import torch
+from omegaconf import OmegaConf

+from roboimi.demos.vla_scripts import eval_vla
 from roboimi.vla.eval_utils import execute_policy_action
@@ -14,6 +20,48 @@ class _FakeEnv:
         self.calls.append(("step_jnt", action))


+class _FakeQueue:
+    def __init__(self, initial_items=None):
+        self.items = list(initial_items or [])
+        self.put_calls = []
+
+    def put(self, item):
+        self.put_calls.append(item)
+        self.items.append(item)
+
+    def get(self, timeout=None):
+        del timeout
+        if not self.items:
+            raise AssertionError("queue unexpectedly empty")
+        return self.items.pop(0)
+
+
+def _make_parallel_cfg(**eval_overrides):
+    eval_cfg = {
+        "ckpt_path": "checkpoints/vla_model_best.pt",
+        "num_episodes": 5,
+        "num_workers": 2,
+        "max_timesteps": 1,
+        "device": "cpu",
+        "task_name": "sim_transfer",
+        "camera_names": ["front"],
+        "use_smoothing": False,
+        "smooth_alpha": 0.3,
+        "verbose_action": False,
+        "headless": True,
+        "artifact_dir": None,
+        "save_artifacts": False,
+        "save_summary_json": False,
+        "save_timing": False,
+        "save_trajectory": False,
+        "save_trajectory_npz": False,
+        "record_video": False,
+        "save_trajectory_image": False,
+    }
+    eval_cfg.update(eval_overrides)
+    return OmegaConf.create({"agent": {}, "eval": eval_cfg})
+
+
 class EvalVLAExecutionTest(unittest.TestCase):
     def test_execute_policy_action_uses_ee_step(self):
         env = _FakeEnv()
||||
def test_execute_policy_action_uses_ee_step(self):
|
||||
env = _FakeEnv()
|
||||
@@ -23,6 +71,662 @@ class EvalVLAExecutionTest(unittest.TestCase):
|
||||
|
||||
self.assertEqual(env.calls, [("step", action)])
|
||||
|
||||
def test_split_episode_indices_balances_workers(self):
|
||||
self.assertEqual(
|
||||
eval_vla._split_episode_indices(num_episodes=10, num_workers=3),
|
||||
[[0, 1, 2, 3], [4, 5, 6], [7, 8, 9]],
|
||||
)
|
||||
|
||||
def test_normalize_num_workers_caps_worker_count_to_episode_count(self):
|
||||
self.assertEqual(eval_vla._normalize_num_workers(num_workers=5, num_episodes=2), 2)
|
||||
|
||||
def test_plan_episode_box_poses_uses_global_episode_order(self):
|
||||
planned_poses = [
|
||||
np.array([0.1, 0.2, 0.3], dtype=np.float32),
|
||||
np.array([1.1, 1.2, 1.3], dtype=np.float32),
|
||||
np.array([2.1, 2.2, 2.3], dtype=np.float32),
|
||||
]
|
||||
sampler = mock.Mock(side_effect=planned_poses)
|
||||
|
||||
result = eval_vla._plan_episode_box_poses(num_episodes=3, sampler=sampler)
|
||||
|
||||
self.assertEqual(sampler.call_count, 3)
|
||||
self.assertEqual(len(result), 3)
|
||||
for expected, actual in zip(planned_poses, result):
|
||||
np.testing.assert_array_equal(actual, expected)
|
||||
|
||||
def test_resolve_policy_camera_names_matches_vlaagent_fallback_sorting(self):
|
||||
cfg = OmegaConf.create(
|
||||
{
|
||||
"agent": {
|
||||
"_target_": "roboimi.vla.agent.VLAAgent",
|
||||
},
|
||||
"eval": {
|
||||
"camera_names": ["r_vis", "top", "front"],
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
eval_vla._resolve_policy_camera_names(cfg),
|
||||
["front", "r_vis", "top"],
|
||||
)
|
||||
|
||||
def test_resolve_policy_camera_names_matches_gr00t_fallback_input_order(self):
|
||||
cfg = OmegaConf.create(
|
||||
{
|
||||
"agent": {
|
||||
"_target_": "roboimi.vla.agent_gr00t_dit.VLAAgentGr00tDiT",
|
||||
},
|
||||
"eval": {
|
||||
"camera_names": ["r_vis", "top", "front"],
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
eval_vla._resolve_policy_camera_names(cfg),
|
||||
["r_vis", "top", "front"],
|
||||
)
|
||||
|
||||
def test_build_episode_plans_without_box_poses_keeps_serial_sampling_lazy(self):
|
||||
plans = eval_vla._build_episode_plans(num_episodes=3)
|
||||
|
||||
self.assertEqual(
|
||||
plans,
|
||||
[
|
||||
{"episode_index": 0},
|
||||
{"episode_index": 1},
|
||||
{"episode_index": 2},
|
||||
],
|
||||
)
|
||||
|
||||
def test_prepare_local_policy_batch_pads_latest_observation_to_obs_horizon(self):
|
||||
queues = eval_vla._new_local_policy_queues(obs_horizon=3)
|
||||
observation = {
|
||||
"qpos": torch.tensor([1.0, 2.0], dtype=torch.float32),
|
||||
"images": {
|
||||
"front": torch.tensor([[[1.0]]], dtype=torch.float32),
|
||||
},
|
||||
}
|
||||
|
||||
eval_vla._populate_local_policy_queues(queues, observation)
|
||||
batch = eval_vla._prepare_local_policy_batch(
|
||||
queues,
|
||||
obs_horizon=3,
|
||||
camera_names=["front"],
|
||||
)
|
||||
|
||||
self.assertEqual(tuple(batch["qpos"].shape), (1, 3, 2))
|
||||
self.assertEqual(tuple(batch["images"]["front"].shape), (1, 3, 1, 1, 1))
|
||||
np.testing.assert_array_equal(
|
||||
batch["qpos"][0].cpu().numpy(),
|
||||
np.array([[1.0, 2.0], [1.0, 2.0], [1.0, 2.0]], dtype=np.float32),
|
||||
)
|
||||
np.testing.assert_array_equal(
|
||||
batch["images"]["front"][0].cpu().numpy(),
|
||||
np.array([[[[1.0]]], [[[1.0]]], [[[1.0]]]], dtype=np.float32),
|
||||
)
|
||||
|
||||
def test_enqueue_predicted_actions_uses_executable_slice(self):
|
||||
queues = eval_vla._new_local_policy_queues(obs_horizon=2)
|
||||
predicted_actions = torch.tensor(
|
||||
[[[10.0], [20.0], [30.0], [40.0]]],
|
||||
dtype=torch.float32,
|
||||
)
|
||||
|
||||
eval_vla._enqueue_predicted_actions(
|
||||
queues,
|
||||
predicted_actions=predicted_actions,
|
||||
obs_horizon=2,
|
||||
num_action_steps=2,
|
||||
)
|
||||
|
||||
self.assertEqual(len(queues["action"]), 2)
|
||||
np.testing.assert_array_equal(queues["action"].popleft().numpy(), np.array([20.0], dtype=np.float32))
|
||||
np.testing.assert_array_equal(queues["action"].popleft().numpy(), np.array([30.0], dtype=np.float32))
|
||||
|
||||
def test_remote_policy_runner_only_requests_server_inference_when_local_action_queue_is_empty(self):
|
||||
request_queue = _FakeQueue()
|
||||
response_queue = _FakeQueue(
|
||||
[
|
||||
{
|
||||
"type": "predict_chunk_result",
|
||||
"actions": np.asarray([[[10.0], [20.0], [30.0]]], dtype=np.float32),
|
||||
}
|
||||
]
|
||||
)
|
||||
runner = eval_vla._RemotePolicyRunner(
|
||||
worker_index=3,
|
||||
server_index=1,
|
||||
request_queue=request_queue,
|
||||
response_queue=response_queue,
|
||||
camera_names=["front"],
|
||||
obs_horizon=2,
|
||||
num_action_steps=2,
|
||||
)
|
||||
first_observation = {
|
||||
"qpos": torch.tensor([1.0, 2.0], dtype=torch.float32),
|
||||
"images": {"front": torch.tensor([[[1.0]]], dtype=torch.float32)},
|
||||
}
|
||||
second_observation = {
|
||||
"qpos": torch.tensor([3.0, 4.0], dtype=torch.float32),
|
||||
"images": {"front": torch.tensor([[[2.0]]], dtype=torch.float32)},
|
||||
}
|
||||
|
||||
first_action, first_forward = runner.select_action(
|
||||
first_observation,
|
||||
episode_index=7,
|
||||
timestep=0,
|
||||
)
|
||||
second_action, second_forward = runner.select_action(
|
||||
second_observation,
|
||||
episode_index=7,
|
||||
timestep=1,
|
||||
)
|
||||
|
||||
self.assertTrue(first_forward)
|
||||
self.assertFalse(second_forward)
|
||||
self.assertEqual(len(request_queue.put_calls), 1)
|
||||
self.assertEqual(request_queue.put_calls[0]["type"], "predict_chunk")
|
||||
self.assertEqual(request_queue.put_calls[0]["worker_index"], 3)
|
||||
self.assertEqual(request_queue.put_calls[0]["server_index"], 1)
|
||||
np.testing.assert_array_equal(first_action.numpy(), np.array([20.0], dtype=np.float32))
|
||||
np.testing.assert_array_equal(second_action.numpy(), np.array([30.0], dtype=np.float32))
|
||||
|
||||
def test_merge_worker_summaries_sorts_episodes_and_recomputes_aggregates(self):
|
||||
worker_summaries = [
|
||||
{
|
||||
"avg_inference_fps": 999.0,
|
||||
"avg_control_fps": 999.0,
|
||||
"avg_obs_read_time_ms": 999.0,
|
||||
"avg_total_time_ms": 999.0,
|
||||
"timing_summary": {"count": 999, "model_forward_count": 999},
|
||||
"episodes": [
|
||||
{
|
||||
"episode_index": 2,
|
||||
"episode_reward": 9.0,
|
||||
"episode_max_reward": 4.0,
|
||||
"inference_fps": 30.0,
|
||||
"control_fps": 15.0,
|
||||
}
|
||||
],
|
||||
"_merge_state": {
|
||||
"obs_read_time_ms": [9.0],
|
||||
"preprocess_time_ms": [1.0],
|
||||
"inference_time_ms": [3.0],
|
||||
"env_step_time_ms": [4.0],
|
||||
"total_time_ms": [10.0],
|
||||
"model_forward_flags": [False],
|
||||
},
|
||||
},
|
||||
{
|
||||
"avg_inference_fps": 888.0,
|
||||
"avg_control_fps": 888.0,
|
||||
"avg_obs_read_time_ms": 888.0,
|
||||
"avg_total_time_ms": 888.0,
|
||||
"timing_summary": {"count": 888, "model_forward_count": 888},
|
||||
"episodes": [
|
||||
{
|
||||
"episode_index": 1,
|
||||
"episode_reward": 6.0,
|
||||
"episode_max_reward": 3.0,
|
||||
"inference_fps": 20.0,
|
||||
"control_fps": 10.0,
|
||||
},
|
||||
{
|
||||
"episode_index": 0,
|
||||
"episode_reward": 5.0,
|
||||
"episode_max_reward": 2.0,
|
||||
"inference_fps": 10.0,
|
||||
"control_fps": 5.0,
|
||||
},
|
||||
],
|
||||
"_merge_state": {
|
||||
"obs_read_time_ms": [1.0, 2.0, 12.0],
|
||||
"preprocess_time_ms": [2.0, 3.0, 4.0],
|
||||
"inference_time_ms": [4.0, 5.0, 6.0],
|
||||
"env_step_time_ms": [6.0, 7.0, 8.0],
|
||||
"total_time_ms": [8.0, 9.0, 20.0],
|
||||
"model_forward_flags": [True, False, True],
|
||||
},
|
||||
},
|
||||
]
|
        artifact_paths = {
            "output_dir": "/tmp/merged",
            "summary_json": "/tmp/merged/rollout_summary.json",
            "timing_json": "/tmp/merged/timing.json",
            "trajectory_npz": None,
            "video_mp4": None,
            "video_camera_name": None,
        }

        merged = eval_vla._merge_worker_summaries(worker_summaries, artifact_paths)

        self.assertEqual([episode["episode_index"] for episode in merged["episodes"]], [0, 1, 2])
        self.assertEqual(merged["episode_rewards"], [5.0, 6.0, 9.0])
        self.assertEqual(merged["episode_max_rewards"], [2.0, 3.0, 4.0])
        self.assertAlmostEqual(merged["avg_reward"], 20.0 / 3.0)
        self.assertAlmostEqual(merged["avg_max_reward"], 3.0)
        self.assertAlmostEqual(merged["avg_inference_fps"], 20.0)
        self.assertAlmostEqual(merged["avg_control_fps"], 10.0)
        self.assertAlmostEqual(merged["avg_obs_read_time_ms"], 6.0)
        self.assertAlmostEqual(merged["avg_total_time_ms"], 47.0 / 4.0)
        self.assertEqual(merged["timing_summary"]["count"], 4)
        self.assertEqual(merged["timing_summary"]["model_forward_count"], 2)
        self.assertEqual(merged["artifact_dir"], "/tmp/merged")
        self.assertEqual(merged["artifacts"], artifact_paths)

    def test_build_cuda_server_payloads_uses_round_robin_worker_assignment(self):
        cfg = _make_parallel_cfg(num_episodes=4, num_workers=4, device="cuda", cuda_devices=[0, 1])
        artifact_paths = {"output_dir": None}

        with mock.patch.object(
            eval_vla,
            "sample_transfer_pose",
            side_effect=[
                np.array([0.1, 0.2, 0.3], dtype=np.float32),
                np.array([0.4, 0.5, 0.6], dtype=np.float32),
                np.array([0.7, 0.8, 0.9], dtype=np.float32),
                np.array([1.0, 1.1, 1.2], dtype=np.float32),
            ],
        ):
            worker_payloads, _ = eval_vla._build_parallel_worker_payloads(cfg, artifact_paths)

        server_payloads, assigned_workers = eval_vla._build_cuda_server_payloads(
            cfg,
            worker_payloads=worker_payloads,
            cuda_devices=[0, 1],
        )

        self.assertEqual([payload["device_index"] for payload in server_payloads], [0, 1])
        self.assertEqual([payload["worker_index"] for payload in assigned_workers], [0, 1, 2, 3])
        self.assertEqual([payload["server_index"] for payload in assigned_workers], [0, 1, 0, 1])
        self.assertEqual(server_payloads[0]["worker_indices"], [0, 2])
        self.assertEqual(server_payloads[1]["worker_indices"], [1, 3])

    def test_run_eval_parallel_dispatches_episode_splits_and_box_poses(self):
        cfg = _make_parallel_cfg(num_episodes=5, num_workers=2, artifact_dir="/tmp/parallel-root")
        planned_poses = [
            np.array([float(index), float(index) + 0.1, float(index) + 0.2], dtype=np.float32)
            for index in range(5)
        ]
        observed_payloads = []

        def fake_run_spawn_jobs(payloads, max_workers, worker_fn):
            del worker_fn
            self.assertEqual(max_workers, 2)
            observed_payloads.extend(payloads)
            return [
                {
                    "episodes": [
                        {
                            "episode_index": 4,
                            "episode_reward": 5.0,
                            "episode_max_reward": 5.0,
                            "inference_fps": 50.0,
                            "control_fps": 25.0,
                        },
                        {
                            "episode_index": 3,
                            "episode_reward": 4.0,
                            "episode_max_reward": 4.0,
                            "inference_fps": 40.0,
                            "control_fps": 20.0,
                        },
                    ],
                    "_merge_state": {
                        "obs_read_time_ms": [4.0, 5.0],
                        "preprocess_time_ms": [1.0, 1.0],
                        "inference_time_ms": [2.0, 2.0],
                        "env_step_time_ms": [3.0, 3.0],
                        "total_time_ms": [4.0, 5.0],
                        "model_forward_flags": [True, True],
                    },
                },
                {
                    "episodes": [
                        {
                            "episode_index": 2,
                            "episode_reward": 3.0,
                            "episode_max_reward": 3.0,
                            "inference_fps": 30.0,
                            "control_fps": 15.0,
                        },
                        {
                            "episode_index": 1,
                            "episode_reward": 2.0,
                            "episode_max_reward": 2.0,
                            "inference_fps": 20.0,
                            "control_fps": 10.0,
                        },
                        {
                            "episode_index": 0,
                            "episode_reward": 1.0,
                            "episode_max_reward": 1.0,
                            "inference_fps": 10.0,
                            "control_fps": 5.0,
                        },
                    ],
                    "_merge_state": {
                        "obs_read_time_ms": [1.0, 2.0, 3.0],
                        "preprocess_time_ms": [1.0, 1.0, 1.0],
                        "inference_time_ms": [2.0, 2.0, 2.0],
                        "env_step_time_ms": [3.0, 3.0, 3.0],
                        "total_time_ms": [1.0, 2.0, 3.0],
                        "model_forward_flags": [False, True, False],
                    },
                },
            ]

        with mock.patch.object(
            eval_vla,
            "sample_transfer_pose",
            side_effect=planned_poses,
        ), mock.patch.object(
            eval_vla,
            "_run_spawn_jobs",
            side_effect=fake_run_spawn_jobs,
        ):
            summary = eval_vla._run_eval_parallel(cfg)

        self.assertEqual(len(observed_payloads), 2)
        self.assertEqual(
            [[plan["episode_index"] for plan in payload["episode_plans"]] for payload in observed_payloads],
            [[0, 1, 2], [3, 4]],
        )
        for payload in observed_payloads:
            for plan in payload["episode_plans"]:
                np.testing.assert_array_equal(
                    np.asarray(plan["box_pos"], dtype=np.float32),
                    planned_poses[plan["episode_index"]],
                )
        self.assertEqual([episode["episode_index"] for episode in summary["episodes"]], [0, 1, 2, 3, 4])
        self.assertEqual(summary["episode_rewards"], [1.0, 2.0, 3.0, 4.0, 5.0])
        self.assertEqual(summary["num_episodes"], 5)

    def test_run_eval_parallel_allows_trajectory_images_and_keeps_worker_artifact_paths(self):
        cfg = _make_parallel_cfg(
            num_episodes=2,
            num_workers=2,
            artifact_dir="/tmp/parallel-images",
            save_summary_json=True,
            save_trajectory_image=True,
        )
        observed_payloads = []

        def fake_run_spawn_jobs(payloads, max_workers, worker_fn):
            del worker_fn
            self.assertEqual(max_workers, 2)
            observed_payloads.extend(payloads)
            return [
                {
                    "episodes": [
                        {
                            "episode_index": 0,
                            "episode_reward": 1.0,
                            "episode_max_reward": 1.0,
                            "inference_fps": 10.0,
                            "control_fps": 5.0,
                            "artifact_paths": {
                                "trajectory_image": f"{payloads[0]['artifact_dir']}/rollout_front_ep01_trajectory.png",
                            },
                        },
                    ],
                    "_merge_state": {
                        "obs_read_time_ms": [1.0],
                        "preprocess_time_ms": [1.0],
                        "inference_time_ms": [1.0],
                        "env_step_time_ms": [1.0],
                        "total_time_ms": [1.0],
                        "model_forward_flags": [True],
                    },
                },
                {
                    "episodes": [
                        {
                            "episode_index": 1,
                            "episode_reward": 2.0,
                            "episode_max_reward": 2.0,
                            "inference_fps": 20.0,
                            "control_fps": 10.0,
                            "artifact_paths": {
                                "trajectory_image": f"{payloads[1]['artifact_dir']}/rollout_front_ep02_trajectory.png",
                            },
                        },
                    ],
                    "_merge_state": {
                        "obs_read_time_ms": [2.0],
                        "preprocess_time_ms": [2.0],
                        "inference_time_ms": [2.0],
                        "env_step_time_ms": [2.0],
                        "total_time_ms": [2.0],
                        "model_forward_flags": [False],
                    },
                },
            ]

        with mock.patch.object(
            eval_vla,
            "sample_transfer_pose",
            side_effect=[
                np.array([0.1, 0.2, 0.3], dtype=np.float32),
                np.array([0.4, 0.5, 0.6], dtype=np.float32),
            ],
        ), mock.patch.object(
            eval_vla,
            "_run_spawn_jobs",
            side_effect=fake_run_spawn_jobs,
        ):
            summary = eval_vla._run_eval_parallel(cfg)

        self.assertEqual(len(observed_payloads), 2)
        self.assertTrue(observed_payloads[0]["artifact_dir"].endswith("workers/worker_00"))
        self.assertTrue(observed_payloads[1]["artifact_dir"].endswith("workers/worker_01"))
        self.assertTrue(
            summary["episodes"][0]["artifact_paths"]["trajectory_image"].endswith(
                "workers/worker_00/rollout_front_ep01_trajectory.png"
            )
        )
        self.assertTrue(
            summary["episodes"][1]["artifact_paths"]["trajectory_image"].endswith(
                "workers/worker_01/rollout_front_ep02_trajectory.png"
            )
        )

    def test_run_eval_parallel_surfaces_worker_failures(self):
        cfg = _make_parallel_cfg(num_episodes=2, num_workers=2)

        with mock.patch.object(
            eval_vla,
            "sample_transfer_pose",
            side_effect=[
                np.array([0.1, 0.2, 0.3], dtype=np.float32),
                np.array([0.4, 0.5, 0.6], dtype=np.float32),
            ],
        ), mock.patch.object(
            eval_vla,
            "_run_spawn_jobs",
            side_effect=RuntimeError("boom"),
        ):
            with self.assertRaisesRegex(RuntimeError, "Parallel rollout worker failed"):
                eval_vla._run_eval_parallel(cfg)

    def test_run_eval_parallel_cuda_builds_server_payloads_and_merges_worker_results(self):
        cfg = _make_parallel_cfg(
            num_episodes=4,
            num_workers=4,
            device="cuda",
            cuda_devices=[0],
            artifact_dir="/tmp/cuda-root",
        )
        observed_server_payloads = []
        observed_worker_payloads = []

        def fake_run_cuda_parallel_processes(server_payloads, worker_payloads):
            observed_server_payloads.extend(server_payloads)
            observed_worker_payloads.extend(worker_payloads)
            return [
                {
                    "episodes": [
                        {
                            "episode_index": 2,
                            "episode_reward": 3.0,
                            "episode_max_reward": 3.0,
                            "inference_fps": 30.0,
                            "control_fps": 15.0,
                        },
                        {
                            "episode_index": 0,
                            "episode_reward": 1.0,
                            "episode_max_reward": 1.0,
                            "inference_fps": 10.0,
                            "control_fps": 5.0,
                        },
                    ],
                    "_merge_state": {
                        "obs_read_time_ms": [1.0, 2.0],
                        "preprocess_time_ms": [1.0, 1.0],
                        "inference_time_ms": [2.0, 2.0],
                        "env_step_time_ms": [3.0, 3.0],
                        "total_time_ms": [4.0, 4.0],
                        "model_forward_flags": [True, False],
                    },
                },
                {
                    "episodes": [
                        {
                            "episode_index": 3,
                            "episode_reward": 4.0,
                            "episode_max_reward": 4.0,
                            "inference_fps": 40.0,
                            "control_fps": 20.0,
                        },
                        {
                            "episode_index": 1,
                            "episode_reward": 2.0,
                            "episode_max_reward": 2.0,
                            "inference_fps": 20.0,
                            "control_fps": 10.0,
                        },
                    ],
                    "_merge_state": {
                        "obs_read_time_ms": [3.0, 4.0],
                        "preprocess_time_ms": [1.0, 1.0],
                        "inference_time_ms": [2.0, 2.0],
                        "env_step_time_ms": [3.0, 3.0],
                        "total_time_ms": [4.0, 4.0],
                        "model_forward_flags": [True, True],
                    },
                },
            ]

        with mock.patch.object(
            eval_vla,
            "sample_transfer_pose",
            side_effect=[
                np.array([0.1, 0.2, 0.3], dtype=np.float32),
                np.array([0.4, 0.5, 0.6], dtype=np.float32),
                np.array([0.7, 0.8, 0.9], dtype=np.float32),
                np.array([1.0, 1.1, 1.2], dtype=np.float32),
            ],
        ), mock.patch.object(
            eval_vla,
            "_run_cuda_parallel_processes",
            side_effect=fake_run_cuda_parallel_processes,
            create=True,
        ):
            summary = eval_vla._run_eval_parallel_cuda(cfg)

        self.assertEqual(len(observed_server_payloads), 1)
        self.assertEqual(observed_server_payloads[0]["device_index"], 0)
        self.assertEqual(len(observed_worker_payloads), 4)
        self.assertTrue(all(payload["server_index"] == 0 for payload in observed_worker_payloads))
        self.assertEqual([episode["episode_index"] for episode in summary["episodes"]], [0, 1, 2, 3])
        self.assertEqual(summary["episode_rewards"], [1.0, 2.0, 3.0, 4.0])
        self.assertEqual(summary["num_episodes"], 4)

    def test_run_eval_parallel_cuda_surfaces_server_failures(self):
        cfg = _make_parallel_cfg(num_episodes=2, num_workers=2, device="cuda", cuda_devices=[0])

        with mock.patch.object(
            eval_vla,
            "sample_transfer_pose",
            side_effect=[
                np.array([0.1, 0.2, 0.3], dtype=np.float32),
                np.array([0.4, 0.5, 0.6], dtype=np.float32),
            ],
        ), mock.patch.object(
            eval_vla,
            "_run_cuda_parallel_processes",
            side_effect=RuntimeError("server boom"),
            create=True,
        ):
            with self.assertRaisesRegex(RuntimeError, "Parallel CUDA rollout failed"):
                eval_vla._run_eval_parallel_cuda(cfg)

    def test_run_spawn_jobs_supports_real_spawn_with_actual_eval_worker_entry(self):
        payloads = [
            {"_spawn_probe": True, "probe_value": 1, "worker_index": 0},
            {"_spawn_probe": True, "probe_value": 2, "worker_index": 1},
        ]

        results = eval_vla._run_spawn_jobs(
            payloads=payloads,
            max_workers=2,
            worker_fn=eval_vla._run_eval_worker_entry,
        )

        self.assertEqual(sorted(result["probe_value"] for result in results), [1, 2])
        self.assertEqual(sorted(result["worker_index"] for result in results), [0, 1])

    def test_cuda_server_and_env_worker_entrypoints_support_real_spawn_probe(self):
        ctx = eval_vla.multiprocessing.get_context("spawn")
        request_queue = ctx.Queue()
        response_queue = ctx.Queue()
        result_queue = ctx.Queue()

        server = ctx.Process(
            target=eval_vla._inference_server_main,
            args=(
                {
                    "_spawn_probe": True,
                    "server_index": 0,
                    "request_queue": request_queue,
                    "response_queues": [response_queue],
                },
            ),
        )
        worker = ctx.Process(
            target=eval_vla._env_worker_main,
            args=(
                {
                    "_spawn_probe": True,
                    "worker_index": 0,
                    "server_index": 0,
                    "request_queue": request_queue,
                    "response_queue": response_queue,
                    "result_queue": result_queue,
                },
            ),
        )

        server.start()
        worker.start()

        result = result_queue.get(timeout=10.0)

        worker.join(timeout=10.0)
        request_queue.put({"type": "shutdown_server"})
        server.join(timeout=10.0)

        self.assertEqual(result["kind"], "worker_result")
        self.assertEqual(result["summary"]["probe_worker_index"], 0)
        self.assertEqual(result["summary"]["probe_server_index"], 0)
        self.assertEqual(result["summary"]["probe_actions"], [[[11.0], [22.0], [33.0]]])
        self.assertEqual(worker.exitcode, 0)
        self.assertEqual(server.exitcode, 0)


if __name__ == "__main__":
    unittest.main()
@@ -126,6 +126,26 @@ class EvalVLAHeadlessTest(unittest.TestCase):
        self.assertIn("headless", eval_cfg)
        self.assertFalse(eval_cfg.headless)

    def test_eval_config_exposes_num_workers_default(self):
        eval_cfg = OmegaConf.load(Path("roboimi/vla/conf/eval/eval.yaml"))

        self.assertIn("num_workers", eval_cfg)
        self.assertEqual(eval_cfg.num_workers, 1)

    def test_eval_config_exposes_cuda_devices_default(self):
        eval_cfg = OmegaConf.load(Path("roboimi/vla/conf/eval/eval.yaml"))

        self.assertIn("cuda_devices", eval_cfg)
        self.assertIsNone(eval_cfg.cuda_devices)

    def test_eval_config_exposes_parallel_timeout_defaults(self):
        eval_cfg = OmegaConf.load(Path("roboimi/vla/conf/eval/eval.yaml"))

        self.assertIn("response_timeout_s", eval_cfg)
        self.assertIn("server_startup_timeout_s", eval_cfg)
        self.assertEqual(eval_cfg.response_timeout_s, 300.0)
        self.assertEqual(eval_cfg.server_startup_timeout_s, 300.0)

    def test_make_sim_env_accepts_headless_and_disables_render(self):
        fake_env = object()

@@ -327,6 +347,172 @@ class EvalVLAHeadlessTest(unittest.TestCase):
        self.assertAlmostEqual(summary["avg_reward"], 3.75)
        self.assertEqual(summary["num_episodes"], 2)

    def test_run_eval_uses_serial_path_when_num_workers_is_one(self):
        cfg = OmegaConf.create(
            {
                "eval": {
                    "num_workers": 1,
                    "num_episodes": 3,
                }
            }
        )

        with mock.patch.object(
            eval_vla,
            "_run_eval_serial",
            return_value={"mode": "serial"},
        ) as run_eval_serial, mock.patch.object(
            eval_vla,
            "_run_eval_parallel",
        ) as run_eval_parallel:
            result = eval_vla._run_eval(cfg)

        self.assertEqual(result, {"mode": "serial"})
        run_eval_serial.assert_called_once_with(cfg)
        run_eval_parallel.assert_not_called()

    def test_run_eval_uses_serial_path_when_requested_workers_collapse_to_one(self):
        cfg = OmegaConf.create(
            {
                "eval": {
                    "num_workers": 8,
                    "num_episodes": 1,
                }
            }
        )

        with mock.patch.object(
            eval_vla,
            "_run_eval_serial",
            return_value={"mode": "serial"},
        ) as run_eval_serial, mock.patch.object(
            eval_vla,
            "_run_eval_parallel",
        ) as run_eval_parallel:
            result = eval_vla._run_eval(cfg)

        self.assertEqual(result, {"mode": "serial"})
        run_eval_serial.assert_called_once_with(cfg)
        run_eval_parallel.assert_not_called()

    def test_run_eval_parallel_requires_headless_true(self):
        cfg = OmegaConf.create(
            {
                "agent": {},
                "eval": {
                    "ckpt_path": "checkpoints/vla_model_best.pt",
                    "num_episodes": 2,
                    "num_workers": 2,
                    "max_timesteps": 1,
                    "device": "cpu",
                    "task_name": "sim_transfer",
                    "camera_names": ["front"],
                    "use_smoothing": False,
                    "smooth_alpha": 0.3,
                    "verbose_action": False,
                    "headless": False,
                },
            }
        )

        with self.assertRaisesRegex(ValueError, "headless=true"):
            eval_vla._run_eval_parallel(cfg)

    def test_run_eval_parallel_dispatches_to_cpu_workers_when_device_is_cpu(self):
        cfg = OmegaConf.create(
            {
                "agent": {},
                "eval": {
                    "ckpt_path": "checkpoints/vla_model_best.pt",
                    "num_episodes": 2,
                    "num_workers": 2,
                    "max_timesteps": 1,
                    "device": "cpu",
                    "task_name": "sim_transfer",
                    "camera_names": ["front"],
                    "use_smoothing": False,
                    "smooth_alpha": 0.3,
                    "verbose_action": False,
                    "headless": True,
                    "cuda_devices": None,
                },
            }
        )

        with mock.patch.object(
            eval_vla,
            "_run_eval_parallel_cpu",
            return_value={"mode": "cpu"},
            create=True,
        ) as run_cpu_parallel, mock.patch.object(
            eval_vla,
            "_run_eval_parallel_cuda",
            create=True,
        ) as run_cuda_parallel:
            result = eval_vla._run_eval_parallel(cfg)

        self.assertEqual(result, {"mode": "cpu"})
        run_cpu_parallel.assert_called_once_with(cfg)
        run_cuda_parallel.assert_not_called()

    def test_run_eval_parallel_dispatches_to_cuda_servers_when_device_is_cuda(self):
        cfg = OmegaConf.create(
            {
                "agent": {},
                "eval": {
                    "ckpt_path": "checkpoints/vla_model_best.pt",
                    "num_episodes": 2,
                    "num_workers": 2,
                    "max_timesteps": 1,
                    "device": "cuda",
                    "task_name": "sim_transfer",
                    "camera_names": ["front"],
                    "use_smoothing": False,
                    "smooth_alpha": 0.3,
                    "verbose_action": False,
                    "headless": True,
                    "cuda_devices": [0],
                },
            }
        )

        with mock.patch.object(
            eval_vla,
            "_run_eval_parallel_cpu",
            create=True,
        ) as run_cpu_parallel, mock.patch.object(
            eval_vla,
            "_run_eval_parallel_cuda",
            return_value={"mode": "cuda"},
            create=True,
        ) as run_cuda_parallel:
            result = eval_vla._run_eval_parallel(cfg)

        self.assertEqual(result, {"mode": "cuda"})
        run_cpu_parallel.assert_not_called()
        run_cuda_parallel.assert_called_once_with(cfg)

    def test_resolve_cuda_devices_defaults_to_single_logical_gpu(self):
        cfg = OmegaConf.create(
            {
                "device": "cuda",
                "cuda_devices": None,
            }
        )

        self.assertEqual(eval_vla._resolve_cuda_devices(cfg), [0])

    def test_resolve_cuda_devices_rejects_empty_selection(self):
        cfg = OmegaConf.create(
            {
                "device": "cuda",
                "cuda_devices": [],
            }
        )

        with self.assertRaisesRegex(ValueError, "cuda_devices"):
            eval_vla._resolve_cuda_devices(cfg)


if __name__ == "__main__":
    unittest.main()

@@ -158,6 +158,106 @@ class TrainVLARolloutValidationTest(unittest.TestCase):
        self.assertGreater(float(cfg.train.lr), 5e-5)
        self.assertGreater(cfg.train.num_workers, 8)
        self.assertEqual(cfg.train.rollout_val_freq_epochs, 50)
        self.assertEqual(cfg.train.rollout_device, cfg.train.device)
        self.assertIsNone(cfg.train.rollout_num_workers)
        self.assertIsNone(cfg.train.rollout_cuda_devices)

    def test_run_training_rollout_validation_propagates_gpu_parallel_settings(self):
        cfg = OmegaConf.create(
            {
                'train': {
                    'device': 'cpu',
                    'batch_size': 1,
                    'num_workers': 0,
                    'val_split': 0.0,
                    'seed': 0,
                    'lr': 1e-3,
                    'max_steps': 2,
                    'log_freq': 1,
                    'save_freq': 1000,
                    'warmup_steps': 1,
                    'scheduler_type': 'constant',
                    'min_lr': 0.0,
                    'grad_clip': 1.0,
                    'weight_decay': 0.0,
                    'pretrained_ckpt': None,
                    'resume_ckpt': None,
                    'use_swanlab': False,
                    'rollout_val_freq_epochs': 2,
                    'rollout_num_episodes': 5,
                    'rollout_device': 'cuda',
                    'rollout_num_workers': 4,
                    'rollout_cuda_devices': [0, 1],
                    'rollout_response_timeout_s': 123.0,
                    'rollout_server_startup_timeout_s': 456.0,
                },
                'data': {
                    'camera_names': ['front'],
                },
                'agent': {
                    '_target_': 'fake.agent',
                },
                'eval': {
                    'ckpt_path': 'unused.pt',
                    'num_episodes': 99,
                    'max_timesteps': 1,
                    'device': 'cpu',
                    'task_name': 'sim_transfer',
                    'camera_names': ['front'],
                    'use_smoothing': False,
                    'smooth_alpha': 0.3,
                    'verbose_action': False,
                    'headless': False,
                },
            }
        )
        rollout_mock = mock.Mock(return_value={'avg_reward': 1.0})

        def fake_instantiate(config_node, **_kwargs):
            if config_node is cfg.data:
                return _FakeDataset()
            if config_node is cfg.agent:
                return _FakeAgent()
            raise AssertionError(f'unexpected instantiate config: {config_node!r}')

        def fake_dataloader(_dataset, *, shuffle, **_kwargs):
            del shuffle, _kwargs
            return _FakeLoader(
                {
                    'observation.front': torch.zeros(1, 3, 2, 2),
                    'observation.state': torch.zeros(1, 4),
                    'action': torch.zeros(1, 2),
                    'action_is_pad': torch.zeros(1, 1, dtype=torch.bool),
                },
                length=1,
            )

        with tempfile.TemporaryDirectory() as tempdir:
            previous_cwd = os.getcwd()
            try:
                os.chdir(tempdir)
                with mock.patch.object(train_vla, 'instantiate', side_effect=fake_instantiate), \
                        mock.patch.object(train_vla, 'DataLoader', side_effect=fake_dataloader), \
                        mock.patch.object(train_vla, 'build_training_optimizer', return_value=_FakeOptimizer(cfg.train.lr)), \
                        mock.patch.object(train_vla, 'get_lr_schedule_with_warmup', return_value=_FakeScheduler()), \
                        mock.patch.object(train_vla, 'tqdm', side_effect=lambda iterable, **kwargs: _FakeProgressBar(iterable)), \
                        mock.patch.object(train_vla.torch, 'save', return_value=None), \
                        mock.patch.object(eval_vla, '_run_eval', rollout_mock, create=True):
                    train_vla._run_training(cfg)
            finally:
                os.chdir(previous_cwd)

        rollout_cfg = rollout_mock.call_args.args[0]
        self.assertEqual(rollout_cfg.eval.device, 'cuda')
        self.assertEqual(rollout_cfg.eval.num_workers, 4)
        self.assertEqual(list(rollout_cfg.eval.cuda_devices), [0, 1])
        self.assertEqual(float(rollout_cfg.eval.response_timeout_s), 123.0)
        self.assertEqual(float(rollout_cfg.eval.server_startup_timeout_s), 456.0)
        self.assertTrue(rollout_cfg.eval.headless)
        self.assertEqual(rollout_cfg.eval.num_episodes, 5)
        self.assertFalse(rollout_cfg.eval.record_video)
        self.assertTrue(rollout_cfg.eval.save_summary_json)
        self.assertTrue(rollout_cfg.eval.save_trajectory_image)

    def test_training_passes_backbone_image_resize_override_to_dataset_instantiation(self):
        cfg = OmegaConf.create(