Compare commits
5 commits: feat-imf-a ... feat-lewm-

| SHA1 |
| --- |
| 61522d9ae5 |
| 4cd33258d2 |
| d8066823e2 |
| 395f5a1645 |
| 74f4963613 |

docs/lewm-imf-experiment-guide.md (new file, 471 lines)
@@ -0,0 +1,471 @@
# feat-lewm-imf-fusion Experiment Operations Guide

Applicable worktree: `/home/droid/project/roboimi/.worktrees/feat-lewm-imf-fusion`

## 0. Start With the Current Go-To Recipe

The most commonly used training/validation recipe on this branch; refer to it directly:

`experiment_suites/2026-04-21-lewm-fromscratch-old9-epoch50-roll5-val-20260421-153037/`

Core conventions:

- agent: `lewm_resnet_query_imf_attnres`
- from scratch: `train.pretrained_ckpt=null`, `agent.lewm_pretrained_ckpt=null`
- training: `batch_size=32`, `lr=1e-4`, `max_steps=109350`, `save_freq=10000`
- numeric validation: `train.val_split=0.0` + `train.val_episode_indices=[100]`
- held-out numeric validation: `train.action_mse_val_freq_epochs=1`
- rollout validation: `train.rollout_val_freq_epochs=5`, `train.rollout_num_episodes=10`
- SwanLab: `train.use_swanlab=true`, project=`roboimi-vla`

---
## 1. Branch Layout and Key Files

| Path | Purpose |
| --- | --- |
| `roboimi/demos/vla_scripts/train_vla.py` | Main training entry point; owns dataset setup, checkpoints, numeric validation, in-training rollout validation, SwanLab |
| `roboimi/demos/vla_scripts/eval_vla.py` | One-off rollout / offline validation entry point; supports headless mode, summary, trajectory image/video artifacts |
| `roboimi/vla/conf/config.yaml` | Global Hydra config; all training defaults live here |
| `roboimi/vla/conf/eval/eval.yaml` | Default eval config; `eval.ckpt_path`, `eval.num_episodes`, and artifact toggles live here |
| `roboimi/vla/conf/agent/lewm_resnet_query_imf_attnres.yaml` | Most-used agent on this branch; LeWM query fusion + IMF AttnRes head |
| `roboimi/vla/conf/backbone/lewm_resnet_query_fusion.yaml` | LeWM multi-view ResNet query fusion backbone config |
| `roboimi/vla/agent_imf.py` | `IMFVLAAgent` implementation; one-step IMF inference, LeWM loss, loading of LeWM pretrained components |
| `roboimi/vla/data/simpe_robot_dataset.py` | Lazy-loading HDF5 dataset; also handles `episode_indices` filtering |
| `roboimi/vla/scripts/calculate_stats.py` | Recomputes `dataset_stats.pkl` |
| `experiment_suites/2026-04-21-lewm-fromscratch-old9-epoch50-roll5-val-20260421-153037/` | Current go-to suite; manifest, notes, launch log, and local launch script live here |

Notes:

- Typical run names on this branch look like `lewmimf-q08-ph08-ex08-emb384-l12-fromscratch-epoch50-step109350-5090g0-20260421-153037`
- Suffixes like `q08/ph16/ex08` map to `agent.lewm_query_offsets`, `agent.pred_horizon`, and `agent.num_action_steps` respectively

---
## 2. The Three Machines and Their Environments

| Machine | GPU | repo / worktree | Python | Usual dataset path |
| --- | --- | --- | --- | --- |
| Local `droid-z790eagleax` | 1× RTX 5090 32GB | `/home/droid/project/roboimi/.worktrees/feat-lewm-imf-fusion` | `/home/droid/.conda/envs/roboimi/bin/python` | `/home/droid/project/diana_sim/sim_transfer` |
| 5880 node `100.73.14.65` | 2× RTX 5880 Ada 48GB | `/home/droid/roboimi_suite_20260416_lewm_imf_fusion` | `/home/droid/miniforge3/envs/roboimi/bin/python` | `/home/droid/sim_dataset/sim_transfer` |
| L20 node `100.119.99.14` | 8× NVIDIA L20 46GB | `/data/roboimi_suite_20260416_lewm_imf_fusion` | `/home/droid/miniforge3/envs/roboimi/bin/python` | `/data/simtransfer/current` |

Connecting:

- 5880: `ssh droid@100.73.14.65`
- L20: `ssh droid@100.119.99.14`

Rules of thumb:

- Local 5090: good for single smoke runs / small main runs / local tuning
- 5880: good for 2 parallel main runs
- L20: good for large grids; keep both data and runs under `/data`

---
## 3. How the Training Flow Works

What `train_vla.py` actually does:

1. Load the Hydra config and print the full cfg
2. Build the train/val datasets via `build_train_val_datasets()`
3. Build train/val loaders with `DataLoader`
4. Read normalization stats from `dataset_dir/dataset_stats.pkl`
5. Instantiate `IMFVLAAgent`
6. Optionally load:
   - `train.pretrained_ckpt`
   - `train.resume_ckpt`
   - `agent.lewm_pretrained_ckpt`
7. In the training loop, log train loss / lr every `log_freq` steps
8. Save `checkpoints/vla_model_step_*.pt` every `save_freq` steps
9. At the end of each epoch, as configured, run:
   - held-out action MSE
   - rollout validation
10. At the end, write:
    - `checkpoints/vla_model_best.pt`
    - `checkpoints/vla_model_final.pt`

Current best-model selection logic:

- **Before the first rollout reward arrives**: pick the best by `val_loss` (falling back to train loss)
- **After the first rollout**: prefer `rollout_avg_reward` when picking the best
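The two-phase selection rule above can be sketched as follows. This is a minimal illustration only; `is_new_best` and its argument names are hypothetical, not the actual variables in `train_vla.py`:

```python
def is_new_best(metrics, best, seen_rollout_reward):
    """Two-phase best-model rule: loss-based until the first rollout
    reward exists, reward-based afterwards."""
    if seen_rollout_reward:
        # After the first rollout: higher average rollout reward wins.
        return (metrics.get("rollout_avg_reward", float("-inf"))
                > best.get("rollout_avg_reward", float("-inf")))
    # Before any rollout reward: lower val_loss wins, falling back to train loss.
    candidate = metrics.get("val_loss", metrics.get("train_loss"))
    incumbent = best.get("val_loss", best.get("train_loss", float("inf")))
    return candidate is not None and candidate < incumbent
```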
The output directory is usually pinned via `hydra.run.dir=...`; otherwise Hydra generates one itself.

---
## 4. How the Validation Flow Works

### 4.1 Held-out numeric validation

The current practice is not a random `val_split`, but:

- `train.val_split=0.0`
- `train.val_episode_indices=[100]`
- `train.action_mse_val_freq_epochs=1`

With this, `compute_action_mse_validation()` runs on `episode_100.hdf5` at the end of every epoch. The log keys are:

- console / `train_vla.log`: `held-out action MSE`
- SwanLab: `val/action_mse`
### 4.2 Rollout validation

In-training rollout validation is triggered via `train_vla.py -> run_rollout_validation() -> eval_vla._run_eval()`.

The usual in-training rollout constraints on this branch are:

- `train.rollout_val_freq_epochs=5`
- `train.rollout_num_episodes=10`
- `train.rollout_validate_on_checkpoint=false`
- headless is forced on
- `verbose_action=false` is forced
- `record_video=false` is forced
- `save_trajectory_image=true` is forced
- `trajectory_image_camera_name=front` is forced
- `save_summary_json=true` is forced

The rollout device / worker path is now **config-driven**:

- `train.rollout_device`: defaults to following `train.device`
- `train.rollout_num_workers`: defaults to `null`
  - when the rollout device is CPU, it degrades to `1`
  - when the rollout device is CUDA, it is inferred as `min(train.rollout_num_episodes, 8)`
- `train.rollout_cuda_devices`: defaults to `null`, equivalent to the currently visible logical GPU `[0]`
- `train.rollout_response_timeout_s`
- `train.rollout_server_startup_timeout_s`

So now:

- when training runs on `cuda`, **in-training rollouts go to the GPU by default**
- with `rollout_num_workers > 1`, rollouts automatically run in parallel
  - either as **multiple workers on one GPU sharing a single inference server**
  - or as **multiple servers on multiple GPUs splitting the workers**

In-training rollout artifacts land in:

`<hydra.run.dir>/rollout_artifacts/<checkpoint_stem>/`

Typical files:

- `rollout_summary.json`
- `rollout_front_ep01_trajectory.png` ... `rollout_front_ep10_trajectory.png`

Key log lines to watch:

- `Epoch X rollout 平均奖励`
- `最佳模型已更新`

---
## 5. Dataset Loading and the `val_episode_indices` Mechanism

### 5.1 Dataset format

`SimpleRobotDataset` reads the `episode_*.hdf5` files under `dataset_dir`; each episode file must contain at least:

- `action`
- `observations/qpos`
- `observations/images/{cam_name}`

Cameras currently in use:

- `r_vis`
- `top`
- `front`
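The layout above can be read back with plain `h5py`. A minimal sketch, reading a whole episode at once for clarity (the real `SimpleRobotDataset` loads lazily per frame); `load_episode` is a hypothetical helper, not part of the repo:

```python
import h5py


def load_episode(path, camera_names=("r_vis", "top", "front")):
    """Read one episode_*.hdf5 into memory: action, qpos, and per-camera images."""
    with h5py.File(path, "r") as f:
        return {
            "action": f["action"][()],
            "qpos": f["observations/qpos"][()],
            "images": {cam: f[f"observations/images/{cam}"][()] for cam in camera_names},
        }
```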
### 5.2 Lazy-loading behavior

`roboimi/vla/data/simpe_robot_dataset.py` lazily loads per frame; it never reads the whole HDF5 set into memory at once.

It will:

- scan the directory for HDF5 files
- build `available_episode_indices` from the episode number in each filename (e.g. `episode_100.hdf5` -> `100`)
- keep an LRU cache of HDF5 file handles inside each worker
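The two mechanisms above (filename-to-index mapping and a per-worker handle LRU) can be sketched like this. Both `episode_index` and `HandleCache` are illustrative names under the filename convention described here; the real dataset class's internals may differ:

```python
import re
from collections import OrderedDict


def episode_index(filename):
    """episode_100.hdf5 -> 100; returns None for non-episode files."""
    m = re.match(r"episode_(\d+)\.hdf5$", filename)
    return int(m.group(1)) if m else None


class HandleCache:
    """Tiny LRU cache for open file handles, one instance per worker."""

    def __init__(self, opener, max_open=8):
        self.opener = opener
        self.max_open = max_open
        self._handles = OrderedDict()

    def get(self, path):
        if path in self._handles:
            self._handles.move_to_end(path)  # mark as most recently used
            return self._handles[path]
        if len(self._handles) >= self.max_open:
            _, oldest = self._handles.popitem(last=False)  # evict least recently used
            oldest.close()
        handle = self.opener(path)
        self._handles[path] = handle
        return handle
```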
### 5.3 How `val_episode_indices` splits the data

`build_train_val_datasets()` works like this:

1. Instantiate the full dataset once
2. Read `dataset.available_episode_indices`
3. Check that every index in `train.val_episode_indices` exists
4. Instantiate once more per side with `episode_indices=`:
   - train dataset = all episodes minus the held-out episodes
   - val dataset = only the held-out episodes

Therefore:

- `train.val_episode_indices=[100]` means "use all of `episode_100.hdf5` as held-out val"
- if an episode does not exist, it raises immediately
- if you put every episode into `val_episode_indices`, it also raises, because the training set would be empty
### 5.4 Image resize and extra LeWM fields

The dataset-side resize defaults come from:

- `data.image_resize_shape`
- if the backbone overrides it, `agent.vision_backbone.dataset_image_resize_shape` takes priority

Besides the usual keys:

- `observation.state`
- `observation.<cam>`
- `action`

the returned batch also carries, when LeWM is enabled:

- `lewm.observation.state`
- `lewm.observation.<cam>`
- `lewm.future.state`
- `lewm.future.<cam>`
### 5.5 Stats file

Training and inference both rely on `dataset_stats.pkl` by default. Recompute it after the dataset changes:

```bash
/home/droid/.conda/envs/roboimi/bin/python roboimi/vla/scripts/calculate_stats.py \
  --dataset_dir /home/droid/project/diana_sim/sim_transfer
```

On the remote nodes, just swap `--dataset_dir` for the corresponding host path.

---
## 6. SwanLab Behavior

The config default is `train.use_swanlab=false`, but the recipes on this branch almost always enable it explicitly:

- `train.use_swanlab=true`
- `train.swanlab_project=roboimi-vla`
- `train.swanlab_run_name=<run_name>`

What `train_vla.py` does with SwanLab:

- on init, uploads the `train` / `data` / `agent` config sections
- during training, records:
  - `train/loss`
  - `train/lr`
  - `train/best_loss`
  - `train/step`
- at checkpoint validation, records:
  - `val/loss`
- at held-out numeric validation, records:
  - `val/action_mse`
- at rollout validation, records:
  - `rollout/avg_reward`
  - `rollout/epoch`
- at the end of training, records:
  - `final/checkpoint_path`
  - `final/best_checkpoint_path`

Front-view trajectory PNGs from in-training rollouts are uploaded to SwanLab best-effort; a failure only logs a warning and never interrupts training.
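The best-effort upload pattern can be sketched as below. The helper name `log_images_best_effort` is hypothetical; `swanlab.log` and `swanlab.Image` are assumed to behave as in SwanLab's wandb-style API:

```python
import logging

log = logging.getLogger(__name__)


def log_images_best_effort(swanlab_module, image_paths, step):
    """Upload trajectory PNGs if possible; on any failure, warn and carry on."""
    if swanlab_module is None:
        return False
    try:
        swanlab_module.log(
            {"rollout/trajectory_images": [swanlab_module.Image(p) for p in image_paths]},
            step=step,
        )
        return True
    except Exception as exc:  # never let logging failures kill the training run
        log.warning("SwanLab image upload failed (ignored): %s", exc)
        return False
```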
---
## 7. Parallel Rollout Notes

### 7.1 Where this capability comes from

Parallel rollout on this branch is not DataLoader parallelism; it is the **multiprocess rollout path in `eval_vla.py`**.
Reference source:
`/home/droid/project/roboimi/.worktrees/multiprocess-rollout/roboimi/demos/vla_scripts/eval_vla.py`

That path is controlled by:

- `eval.num_workers`
- `eval.cuda_devices`

Their semantics:

- `eval.num_workers`: number of environment workers; episodes are split across them
- `eval.cuda_devices`: which logical GPUs the inference servers bind to
### 7.2 Two common modes

1. **Single machine, single GPU, multiple workers sharing one GPU**
   - Typical: the local 5090 has only 1 GPU, but you want 4 rollout workers running environments in parallel
   - Form: `eval.device=cuda eval.num_workers=4 'eval.cuda_devices=[0]'`
   - This gives **1 CUDA inference server + 4 env workers**

2. **Single machine, multiple GPUs, multiple servers splitting the workers**
   - Typical: the 5880 node has 2 GPUs, the L20 node has more
   - Form: `eval.device=cuda eval.num_workers=8 'eval.cuda_devices=[0,1]'`
   - Workers are assigned round-robin across the servers
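The round-robin assignment described above amounts to a modulo mapping of worker index to GPU. A minimal sketch (the function name is illustrative, not the actual `eval_vla.py` code):

```python
def assign_workers_to_servers(num_workers, cuda_devices):
    """Map each env worker to an inference server, one server per listed GPU,
    cycling round-robin through cuda_devices."""
    return {w: cuda_devices[w % len(cuda_devices)] for w in range(num_workers)}
```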
### 7.3 Operational caveats

- Parallel rollout relies on the **multiprocess eval path**, not `train.num_workers`
- `train.num_workers` is the DataLoader worker count; it has nothing to do with rollout parallelism
- `eval.num_workers > 1` requires `eval.headless=true`
- the worker count is automatically capped at `eval.num_episodes`
- multiprocess rollout now supports **per-episode trajectory image PNGs**; with multiple workers, each worker writes images under its own artifact subdirectory, and the summary carries the corresponding paths back
- with multiple workers, still do not request these at the same time:
  - `eval.record_video=true`
  - `eval.save_trajectory=true`
  - `eval.save_trajectory_npz=true`
- `eval.save_trajectory_image=true` is now safe to enable, handy for combining a parallel reward run with a qualitative check
### 7.4 Parallel rollout command templates

**5090, single GPU, 4 workers:**

```bash
/home/droid/.conda/envs/roboimi/bin/python roboimi/demos/vla_scripts/eval_vla.py \
  agent=lewm_resnet_query_imf_attnres \
  data.dataset_dir=/home/droid/project/diana_sim/sim_transfer \
  train.device=cuda eval.device=cuda eval.headless=true eval.verbose_action=false \
  eval.ckpt_path=/home/droid/project/roboimi/.worktrees/feat-lewm-imf-fusion/runs/<run_name>/checkpoints/vla_model_best.pt \
  eval.num_episodes=10 eval.num_workers=4 'eval.cuda_devices=[0]' \
  eval.save_summary_json=true eval.artifact_dir=/tmp/lewm_parallel_eval_5090
```

**5880, two GPUs, 8 workers:**

```bash
/home/droid/miniforge3/envs/roboimi/bin/python roboimi/demos/vla_scripts/eval_vla.py \
  agent=lewm_resnet_query_imf_attnres \
  data.dataset_dir=/home/droid/sim_dataset/sim_transfer \
  train.device=cuda eval.device=cuda eval.headless=true eval.verbose_action=false \
  eval.ckpt_path=/home/droid/roboimi_suite_20260416_lewm_imf_fusion/runs/<run_name>/checkpoints/vla_model_best.pt \
  eval.num_episodes=10 eval.num_workers=8 'eval.cuda_devices=[0,1]' \
  eval.save_summary_json=true eval.artifact_dir=/tmp/lewm_parallel_eval_5880
```

---
## 8. Current Go-To Commands / Scripts

### 8.1 Local 5090: use the suite script directly

Ready-made script:
`experiment_suites/2026-04-21-lewm-fromscratch-old9-epoch50-roll5-val-20260421-153037/launch_local_5090.sh`

Run:

```bash
bash experiment_suites/2026-04-21-lewm-fromscratch-old9-epoch50-roll5-val-20260421-153037/launch_local_5090.sh
```
### 8.2 Local 5090: launch the same recipe by hand

```bash
/home/droid/.conda/envs/roboimi/bin/python roboimi/demos/vla_scripts/train_vla.py \
  agent=lewm_resnet_query_imf_attnres \
  data.dataset_dir=/home/droid/project/diana_sim/sim_transfer \
  'agent.lewm_query_offsets=[8]' \
  agent.pred_horizon=8 \
  agent.num_action_steps=8 \
  train.device=cuda \
  train.batch_size=32 \
  train.lr=0.0001 \
  train.max_steps=109350 \
  train.num_workers=4 \
  train.save_freq=10000 \
  train.rollout_validate_on_checkpoint=false \
  train.rollout_val_freq_epochs=5 \
  train.rollout_num_episodes=10 \
  train.val_split=0.0 \
  'train.val_episode_indices=[100]' \
  train.action_mse_val_freq_epochs=1 \
  train.use_swanlab=true \
  train.swanlab_project=roboimi-vla \
  train.swanlab_run_name=lewmimf-q08-ph08-ex08-emb384-l12-fromscratch-epoch50-step109350-5090g0-20260421-153037 \
  train.pretrained_ckpt=null \
  agent.lewm_pretrained_ckpt=null \
  hydra.run.dir=/home/droid/project/roboimi/.worktrees/feat-lewm-imf-fusion/runs/lewmimf-q08-ph08-ex08-emb384-l12-fromscratch-epoch50-step109350-5090g0-20260421-153037
```
### 8.3 5880: command template

```bash
ssh droid@100.73.14.65
cd /home/droid/roboimi_suite_20260416_lewm_imf_fusion
/home/droid/miniforge3/envs/roboimi/bin/python roboimi/demos/vla_scripts/train_vla.py \
  agent=lewm_resnet_query_imf_attnres \
  data.dataset_dir=/home/droid/sim_dataset/sim_transfer \
  'agent.lewm_query_offsets=[8]' \
  agent.pred_horizon=16 \
  agent.num_action_steps=8 \
  train.device=cuda train.batch_size=32 train.lr=0.0001 train.max_steps=109350 \
  train.num_workers=4 train.save_freq=10000 train.rollout_validate_on_checkpoint=false \
  train.rollout_val_freq_epochs=5 train.rollout_num_episodes=10 train.val_split=0.0 \
  'train.val_episode_indices=[100]' train.action_mse_val_freq_epochs=1 \
  train.use_swanlab=true train.swanlab_project=roboimi-vla \
  train.swanlab_run_name=lewmimf-q08-ph16-ex08-emb384-l12-fromscratch-epoch50-step109350-5880g0-20260421-153037 \
  train.pretrained_ckpt=null agent.lewm_pretrained_ckpt=null \
  hydra.run.dir=/home/droid/roboimi_suite_20260416_lewm_imf_fusion/runs/lewmimf-q08-ph16-ex08-emb384-l12-fromscratch-epoch50-step109350-5880g0-20260421-153037
```
### 8.4 L20: command template

```bash
ssh droid@100.119.99.14
cd /data/roboimi_suite_20260416_lewm_imf_fusion
/home/droid/miniforge3/envs/roboimi/bin/python roboimi/demos/vla_scripts/train_vla.py \
  agent=lewm_resnet_query_imf_attnres \
  data.dataset_dir=/data/simtransfer/current \
  'agent.lewm_query_offsets=[16]' \
  agent.pred_horizon=16 \
  agent.num_action_steps=16 \
  train.device=cuda train.batch_size=32 train.lr=0.0001 train.max_steps=109350 \
  train.num_workers=4 train.save_freq=10000 train.rollout_validate_on_checkpoint=false \
  train.rollout_val_freq_epochs=5 train.rollout_num_episodes=10 train.val_split=0.0 \
  'train.val_episode_indices=[100]' train.action_mse_val_freq_epochs=1 \
  train.use_swanlab=true train.swanlab_project=roboimi-vla \
  train.swanlab_run_name=lewmimf-q16-ph16-ex16-emb384-l12-fromscratch-epoch50-step109350-l20g0-20260421-153037 \
  train.pretrained_ckpt=null agent.lewm_pretrained_ckpt=null \
  hydra.run.dir=/data/roboimi_suite_20260416_lewm_imf_fusion/runs/lewmimf-q16-ph16-ex16-emb384-l12-fromscratch-epoch50-step109350-l20g0-20260421-153037
```
### 8.5 One-off offline validation (parallel is supported on this branch)

**Single GPU / 4 workers:**

```bash
/home/droid/.conda/envs/roboimi/bin/python roboimi/demos/vla_scripts/eval_vla.py \
  agent=lewm_resnet_query_imf_attnres \
  data.dataset_dir=/home/droid/project/diana_sim/sim_transfer \
  train.device=cuda eval.device=cuda \
  eval.ckpt_path=/home/droid/project/roboimi/.worktrees/feat-lewm-imf-fusion/runs/<run_name>/checkpoints/vla_model_best.pt \
  eval.num_episodes=10 eval.num_workers=4 'eval.cuda_devices=[0]' \
  eval.headless=true eval.verbose_action=false \
  eval.save_summary_json=true eval.save_trajectory_image=true \
  eval.trajectory_image_camera_name=front \
  eval.artifact_dir=/tmp/lewm_eval_front
```

**Enabling parallel GPU rollout inside training (best spelled out explicitly):**

```bash
/home/droid/.conda/envs/roboimi/bin/python roboimi/demos/vla_scripts/train_vla.py \
  agent=lewm_resnet_query_imf_attnres \
  data.dataset_dir=/home/droid/project/diana_sim/sim_transfer \
  'agent.lewm_query_offsets=[8]' \
  agent.pred_horizon=8 \
  agent.num_action_steps=8 \
  train.device=cuda \
  train.batch_size=32 \
  train.lr=0.0001 \
  train.max_steps=109350 \
  train.num_workers=4 \
  train.save_freq=10000 \
  train.rollout_val_freq_epochs=5 \
  train.rollout_num_episodes=10 \
  train.rollout_device=cuda \
  train.rollout_num_workers=4 \
  'train.rollout_cuda_devices=[0]' \
  train.rollout_validate_on_checkpoint=false \
  train.val_split=0.0 \
  'train.val_episode_indices=[100]' \
  train.action_mse_val_freq_epochs=1 \
  train.use_swanlab=true \
  train.swanlab_project=roboimi-vla \
  train.swanlab_run_name=<run_name> \
  hydra.run.dir=/home/droid/project/roboimi/.worktrees/feat-lewm-imf-fusion/runs/<run_name>
```
### 8.6 Tailing logs

```bash
tail -f runs/<run_name>/launch.stdout.log
tail -f runs/<run_name>/train_vla.log
```

On remote nodes, swap `runs/<run_name>` for the absolute path from the manifest.

---
## 9. Operating Advice

- **Treat the suite's `manifest.json` / `notes.md` / `launch_logs/*.launch.log` as the source of truth**; do not hand-write commands that diverge from historical runs
- For the current standard validation setup, explicitly add:
  - `train.val_split=0.0`
  - `train.val_episode_indices=[100]`
  - `train.action_mse_val_freq_epochs=1`
  - `train.rollout_val_freq_epochs=5`
  - `train.rollout_num_episodes=10`
- When comparing horizons / action steps on this branch, change only:
  - `agent.lewm_query_offsets`
  - `agent.pred_horizon`
  - `agent.num_action_steps`
- To reproduce the 2026-04-21 from-scratch round, remember to also set:
  - `train.pretrained_ckpt=null`
  - `agent.lewm_pretrained_ckpt=null`
roboimi/demos/vla_scripts/train_vla.py (diff suppressed as too large in the compare view; recovered hunks follow)

@@ -118,6 +118,127 @@ def recursive_to_device(data, device):
     return data
 
 
+def build_agent_input(batch_data):
+    agent_input = {
+        'images': {
+            cam_name.replace('observation.', ''): value
+            for cam_name, value in batch_data.items()
+            if cam_name.startswith('observation.') and cam_name != 'observation.state'
+        },
+        'qpos': batch_data['observation.state'],
+        'action': batch_data['action'],
+    }
+
+    if 'action_is_pad' in batch_data:
+        agent_input['action_is_pad'] = batch_data['action_is_pad']
+
+    lewm_images = {
+        cam_name.replace('lewm.observation.', ''): value
+        for cam_name, value in batch_data.items()
+        if cam_name.startswith('lewm.observation.') and cam_name != 'lewm.observation.state'
+    }
+    if lewm_images:
+        agent_input['lewm_images'] = lewm_images
+    if 'lewm.observation.state' in batch_data:
+        agent_input['lewm_qpos'] = batch_data['lewm.observation.state']
+
+    lewm_future_images = {
+        cam_name.replace('lewm.future.', ''): value
+        for cam_name, value in batch_data.items()
+        if cam_name.startswith('lewm.future.') and cam_name != 'lewm.future.state'
+    }
+    if lewm_future_images:
+        agent_input['lewm_future_images'] = lewm_future_images
+    if 'lewm.future.state' in batch_data:
+        agent_input['lewm_future_qpos'] = batch_data['lewm.future.state']
+
+    return agent_input
+
+
+def _instantiate_dataset(cfg, dataset_image_resize_shape, episode_indices=None):
+    kwargs = {'image_resize_shape': dataset_image_resize_shape}
+    if episode_indices is not None:
+        kwargs['episode_indices'] = episode_indices
+    return instantiate(cfg.data, **kwargs)
+
+
+def build_train_val_datasets(cfg, dataset_image_resize_shape):
+    val_episode_indices = cfg.train.get('val_episode_indices', None)
+    if val_episode_indices:
+        dataset = _instantiate_dataset(cfg, dataset_image_resize_shape)
+        available_episode_indices = list(getattr(dataset, 'available_episode_indices', []))
+        if not available_episode_indices:
+            raise ValueError('显式 val_episode_indices 需要数据集暴露 available_episode_indices')
+        requested_val_episode_indices = sorted(int(idx) for idx in val_episode_indices)
+        available_set = set(available_episode_indices)
+        missing = sorted(set(requested_val_episode_indices) - available_set)
+        if missing:
+            raise ValueError(
+                f'val_episode_indices {missing} 不存在于数据集可用 episodes {available_episode_indices}'
+            )
+        train_episode_indices = [
+            idx for idx in available_episode_indices
+            if idx not in set(requested_val_episode_indices)
+        ]
+        if not train_episode_indices:
+            raise ValueError('显式 val_episode_indices 不能覆盖全部 episodes,训练集将为空')
+
+        train_dataset = _instantiate_dataset(
+            cfg,
+            dataset_image_resize_shape,
+            episode_indices=train_episode_indices,
+        )
+        val_dataset = _instantiate_dataset(
+            cfg,
+            dataset_image_resize_shape,
+            episode_indices=requested_val_episode_indices,
+        )
+        return dataset, train_dataset, val_dataset, requested_val_episode_indices
+
+    dataset = _instantiate_dataset(cfg, dataset_image_resize_shape)
+    val_split = float(cfg.train.get('val_split', 0.1))
+    seed = int(cfg.train.get('seed', 42))
+    val_size = int(len(dataset) * val_split)
+    train_size = len(dataset) - val_size
+    if val_size > 0:
+        train_dataset, val_dataset = random_split(
+            dataset,
+            [train_size, val_size],
+            generator=torch.Generator().manual_seed(seed)
+        )
+    else:
+        train_dataset, val_dataset = dataset, None
+    return dataset, train_dataset, val_dataset, None
+
+
+def compute_action_mse_validation(agent, val_loader, device):
+    if val_loader is None:
+        return None
+
+    was_training = agent.training
+    agent.eval()
+    total_squared_error = 0.0
+    total_count = 0.0
+    with torch.no_grad():
+        for val_batch in val_loader:
+            val_batch = recursive_to_device(val_batch, device)
+            val_input = build_agent_input(val_batch)
+            pred_actions = agent.predict_action_chunk(val_input)
+            target_actions = val_input['action']
+            squared_error = (pred_actions - target_actions).pow(2)
+            action_is_pad = val_input.get('action_is_pad', None)
+            if action_is_pad is not None:
+                mask = (~action_is_pad).unsqueeze(-1).to(squared_error.dtype)
+                total_squared_error += (squared_error * mask).sum().item()
+                total_count += mask.sum().item() * squared_error.shape[-1]
+            else:
+                total_squared_error += squared_error.sum().item()
+                total_count += target_actions.numel()
+    if was_training:
+        agent.train()
+    return total_squared_error / max(total_count, 1.0)
+
+
 def resolve_resume_checkpoint(resume_ckpt, checkpoint_dir):
     """
     解析恢复训练用的 checkpoint 路径。
@@ -237,6 +358,32 @@ def build_training_optimizer(agent, lr, weight_decay):
     return AdamW(optim_groups, lr=lr, weight_decay=weight_decay)
 
 
+def load_state_dict_ignoring_shape_mismatches(module, incoming_state_dict):
+    """Load only checkpoint tensors whose keys exist locally and whose shapes match."""
+    current_state_dict = module.state_dict()
+    compatible_state_dict = {}
+    mismatched_keys = []
+    missing_keys = []
+
+    for key, value in incoming_state_dict.items():
+        if key not in current_state_dict:
+            missing_keys.append(key)
+            continue
+        if current_state_dict[key].shape != value.shape:
+            mismatched_keys.append(key)
+            continue
+        compatible_state_dict[key] = value
+
+    merged_state_dict = dict(current_state_dict)
+    merged_state_dict.update(compatible_state_dict)
+    module.load_state_dict(merged_state_dict, strict=True)
+    return {
+        'loaded_keys': sorted(compatible_state_dict.keys()),
+        'missing_keys': sorted(missing_keys),
+        'mismatched_keys': sorted(mismatched_keys),
+    }
+
+
 def _init_swanlab(cfg):
     """按需初始化 SwanLab,并在缺少依赖或认证失败时快速失败。"""
     if not bool(cfg.train.get('use_swanlab', False)):
@@ -384,30 +531,30 @@ def _run_training(cfg: DictConfig):
         vision_backbone_cfg = cfg.agent.get('vision_backbone', None)
         if vision_backbone_cfg is not None and 'dataset_image_resize_shape' in vision_backbone_cfg:
             dataset_image_resize_shape = vision_backbone_cfg.get('dataset_image_resize_shape')
-        dataset = instantiate(
-            cfg.data,
-            image_resize_shape=dataset_image_resize_shape,
+        dataset, train_dataset, val_dataset, explicit_val_episode_indices = (
+            build_train_val_datasets(cfg, dataset_image_resize_shape)
         )
         log.info(f"✅ 数据集加载成功。总样本数: {len(dataset)}")
     except Exception as e:
         log.error(f"❌ 数据集加载失败: {e}")
         raise
 
-    # 训练/验证集划分
-    val_split = float(cfg.train.get('val_split', 0.1))
-    seed = int(cfg.train.get('seed', 42))
-    val_size = int(len(dataset) * val_split)
-    train_size = len(dataset) - val_size
-    if val_size > 0:
-        train_dataset, val_dataset = random_split(
-            dataset,
-            [train_size, val_size],
-            generator=torch.Generator().manual_seed(seed)
+    if explicit_val_episode_indices is not None:
+        log.info(
+            "✅ 数据集划分: 训练集=%s, 验证集=%s (显式 held-out episodes=%s)",
+            len(train_dataset),
+            len(val_dataset),
+            explicit_val_episode_indices,
         )
-        log.info(f"✅ 数据集划分: 训练集={train_size}, 验证集={val_size} (验证比例={val_split})")
     else:
-        train_dataset, val_dataset = dataset, None
-        log.info("✅ 数据集划分: 全部用于训练, 验证集=0 (验证比例=0)")
+        val_split = float(cfg.train.get('val_split', 0.1))
+        val_size = len(val_dataset) if val_dataset is not None else 0
+        if val_size > 0:
+            log.info(
+                f"✅ 数据集划分: 训练集={len(train_dataset)}, 验证集={val_size} (验证比例={val_split})"
+            )
+        else:
+            log.info("✅ 数据集划分: 全部用于训练, 验证集=0 (验证比例=0)")
 
     train_batch_size = int(cfg.train.batch_size)
     train_drop_last = len(train_dataset) >= train_batch_size
@@ -509,18 +656,23 @@ def _run_training(cfg: DictConfig):
     try:
         checkpoint = torch.load(ckpt_path, map_location=cfg.train.device)
 
         # 只加载模型权重(不加载 optimizer、scheduler)
-        missing_keys, unexpected_keys = agent.load_state_dict(
+        load_info = load_state_dict_ignoring_shape_mismatches(
+            agent,
             checkpoint['model_state_dict'],
-            strict=False  # 允许部分加载(结构不完全匹配时)
         )
 
         log.info(f"✅ [Finetune] 模型权重加载成功")
 
-        if missing_keys:
-            log.warning(f"⚠️ [Finetune] 缺少的键 ({len(missing_keys)} 个): {missing_keys[:5]}...")
-        if unexpected_keys:
-            log.warning(f"⚠️ [Finetune] 多余的键 ({len(unexpected_keys)} 个): {unexpected_keys[:5]}...")
+        if load_info['missing_keys']:
+            log.warning(
+                f"⚠️ [Finetune] checkpoint 中存在本地模型没有的键 ({len(load_info['missing_keys'])} 个): "
+                f"{load_info['missing_keys'][:5]}..."
+            )
+        if load_info['mismatched_keys']:
+            log.warning(
+                f"⚠️ [Finetune] 因形状不匹配而跳过的键 ({len(load_info['mismatched_keys'])} 个): "
+                f"{load_info['mismatched_keys'][:5]}..."
+            )
 
         log.info(f"📊 [Finetune] 预训练信息: 步骤={checkpoint.get('step', 'N/A')}, 损失={checkpoint.get('loss', 'N/A')}")
         log.info(f"📈 [Finetune] 使用新的训练配置(lr={cfg.train.lr}, max_steps={cfg.train.max_steps})")
@@ -643,22 +795,6 @@ def _run_training(cfg: DictConfig):
     # =========================================================================
     log.info("🏋️ 开始训练循环...")
 
-    def build_agent_input(batch_data):
-        """构建 agent 输入格式"""
-        images = {}
-        # SimpleRobotDataset 返回 observation.{cam_name} 格式
-        for cam_name in cfg.data.camera_names:
-            key = f"observation.{cam_name}"
-            if key in batch_data:
-                images[cam_name] = batch_data[key]
-
-        return {
-            'images': images,
-            'qpos': batch_data['observation.state'],  # SimpleRobotDataset 使用 observation.state
-            'action': batch_data['action'],
-            'action_is_pad': batch_data.get('action_is_pad', None)  # 传递padding mask
-        }
-
     def save_checkpoint(checkpoint_path: Path, step: int, loss_value, val_loss=None, rollout_avg_reward=None):
         agent_stats = agent.get_normalization_stats()
         torch.save({
@@ -702,10 +838,28 @@ def _run_training(cfg: DictConfig):
         from roboimi.demos.vla_scripts import eval_vla
 
         rollout_cfg = OmegaConf.create(OmegaConf.to_container(cfg, resolve=False))
+        rollout_num_episodes = int(cfg.train.get('rollout_num_episodes', 1))
+        rollout_device = str(cfg.train.get('rollout_device', cfg.train.device))
+        configured_rollout_workers = cfg.train.get('rollout_num_workers', None)
+        if configured_rollout_workers is None:
+            if rollout_device.startswith('cuda'):
+                rollout_num_workers = min(max(rollout_num_episodes, 1), 8)
+            else:
+                rollout_num_workers = 1
+        else:
+            rollout_num_workers = int(configured_rollout_workers)
         rollout_cfg.eval.ckpt_path = str(checkpoint_path)
-        rollout_cfg.eval.num_episodes = int(cfg.train.get('rollout_num_episodes', 1))
+        rollout_cfg.eval.num_episodes = rollout_num_episodes
+        rollout_cfg.eval.num_workers = rollout_num_workers
         rollout_cfg.eval.headless = True
-        rollout_cfg.eval.device = 'cpu'
+        rollout_cfg.eval.device = rollout_device
+        rollout_cfg.eval.cuda_devices = cfg.train.get('rollout_cuda_devices', None)
+        rollout_cfg.eval.response_timeout_s = float(
+            cfg.train.get('rollout_response_timeout_s', 300.0)
+        )
+        rollout_cfg.eval.server_startup_timeout_s = float(
+            cfg.train.get('rollout_server_startup_timeout_s', 300.0)
+        )
         rollout_cfg.eval.verbose_action = False
         rollout_cfg.eval.record_video = False
         rollout_cfg.eval.save_trajectory_image = True
@@ -716,9 +870,11 @@ def _run_training(cfg: DictConfig):
         )
 
         log.info(
-            "🎯 开始 checkpoint rollout 验证: %s (episodes=%s, headless=True)",
+            "🎯 开始 checkpoint rollout 验证: %s (episodes=%s, device=%s, workers=%s, headless=True)",
             checkpoint_path,
             rollout_cfg.eval.num_episodes,
+            rollout_cfg.eval.device,
+            rollout_cfg.eval.num_workers,
         )
         return eval_vla._run_eval(rollout_cfg)
 
@@ -731,6 +887,7 @@ def _run_training(cfg: DictConfig):
     pbar = tqdm(range(start_step, cfg.train.max_steps), desc="训练中", ncols=100)
 
     steps_per_epoch = len(train_loader)
+    action_mse_val_freq_epochs = int(cfg.train.get('action_mse_val_freq_epochs', 0) or 0)
     rollout_val_freq_epochs = int(cfg.train.get('rollout_val_freq_epochs', 0) or 0)
     rollout_validation_enabled = rollout_val_freq_epochs > 0
     best_loss = resume_best_loss
@@ -809,6 +966,15 @@ def _run_training(cfg: DictConfig):
                },
                step=step,
            )
+           if hasattr(agent, 'get_last_loss_breakdown'):
+               loss_breakdown = agent.get_last_loss_breakdown()
+               extra_train_metrics = {
+                   f"train/{key}": value
+                   for key, value in loss_breakdown.items()
+                   if value is not None and key != 'loss'
+               }
+               if extra_train_metrics:
+                   _log_to_swanlab(swanlab_module, extra_train_metrics, step=step)

        # =====================================================================
        # 检查点保存与验证
@@ -891,6 +1057,33 @@ def _run_training(cfg: DictConfig):
            and completed_epoch > 0
            and completed_epoch % rollout_val_freq_epochs == 0
        )
+       should_run_action_mse_validation = (
+           action_mse_val_freq_epochs > 0
+           and val_loader is not None
+           and steps_per_epoch > 0
+           and completed_steps % steps_per_epoch == 0
+           and completed_epoch > 0
+           and completed_epoch % action_mse_val_freq_epochs == 0
+       )
+       if should_run_action_mse_validation:
+           action_mse = compute_action_mse_validation(
+               agent,
+               val_loader,
+               cfg.train.device,
+           )
+           if action_mse is not None:
+               log.info(
+                   f"步骤 {step}/{cfg.train.max_steps} | Epoch {completed_epoch} "
+                   f"held-out action MSE: {action_mse:.6f}"
+               )
+               _log_to_swanlab(
+                   swanlab_module,
+                   {
+                       'val/action_mse': action_mse,
+                       'val/action_mse_epoch': completed_epoch,
+                   },
+                   step=step,
+               )
        if should_run_epoch_rollout:
            if checkpoint_path is None:
                checkpoint_path = checkpoint_dir / f"vla_model_step_{step}.pt"
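The epoch-boundary gating that decides when the held-out action-MSE pass fires can be checked in isolation (a minimal sketch; `epoch_validation_due` is a hypothetical name wrapping the same conditions):

```python
def epoch_validation_due(completed_steps: int, steps_per_epoch: int, freq_epochs: int) -> bool:
    # Fire only on the exact step that closes an epoch, and only every
    # `freq_epochs` epochs; freq_epochs <= 0 disables the check entirely.
    if freq_epochs <= 0 or steps_per_epoch <= 0:
        return False
    if completed_steps % steps_per_epoch != 0:
        return False
    completed_epoch = completed_steps // steps_per_epoch
    return completed_epoch > 0 and completed_epoch % freq_epochs == 0
```

With `steps_per_epoch=100` and a frequency of 5 epochs, validation runs at step 500, 1000, …, but not at step 100 or at any mid-epoch step.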

267  roboimi/scripts/refresh_experiment_suite_status.py  (Executable file)
@@ -0,0 +1,267 @@
#!/usr/bin/env python3

from __future__ import annotations

import argparse
import datetime as dt
import json
import pathlib
import re
import shlex
import subprocess
from collections import defaultdict
from typing import Any


STEP_PAT = re.compile(r"步骤\s+(\d+)/(\d+)")
BAR_PAT = re.compile(r"\|\s*(\d+)/(\d+)")


def normalize_chunks(text: str):
    for part in re.split(r"[\r\n]+", text):
        part = part.strip()
        if part:
            yield part


def parse_latest_line(text: str) -> tuple[str, int | None]:
    latest_line = ""
    latest_step = None
    for line in normalize_chunks(text):
        if "步骤" not in line and "训练中:" not in line:
            continue
        latest_line = line
        match = STEP_PAT.search(line) or BAR_PAT.search(line)
        if match:
            latest_step = int(match.group(1))
    return latest_line, latest_step


def now_iso() -> str:
    return dt.datetime.now(
        dt.timezone(dt.timedelta(hours=8)),
    ).isoformat(timespec="seconds")


def run_cmd(cmd: list[str], check: bool = True) -> subprocess.CompletedProcess[str]:
    return subprocess.run(cmd, capture_output=True, text=True, check=check)


def probe_local(run: dict[str, Any]) -> dict[str, Any]:
    pid = str(run["pid"])
    ps = run_cmd(["ps", "-p", pid, "-o", "pid=,stat=,etime=,args="], check=False)
    log_path = pathlib.Path(run["log_path"])
    latest_line = ""
    latest_step = None
    if log_path.exists():
        latest_line, latest_step = parse_latest_line(log_path.read_text(errors="replace"))
    return {
        "alive": bool(ps.stdout.strip()),
        "ps": ps.stdout.strip(),
        "log_exists": log_path.exists(),
        "latest_line": latest_line,
        "latest_step": latest_step,
    }


def remote_probe(host: str, remote_user: str, runs: list[dict[str, Any]]) -> dict[str, dict[str, Any]]:
    payload = [
        {
            "run_id": run["run_id"],
            "pid": str(run["pid"]),
            "log_path": run["log_path"],
        }
        for run in runs
    ]
    remote_py = r"""
import json
import pathlib
import re
import subprocess
import sys

payload = json.loads(sys.argv[1])
step_pat = re.compile(r"步骤\s+(\d+)/(\d+)")
bar_pat = re.compile(r"\|\s*(\d+)/(\d+)")

def normalize_chunks(text):
    for part in re.split(r"[\r\n]+", text):
        part = part.strip()
        if part:
            yield part

def parse_latest_line(text):
    latest_line = ""
    latest_step = None
    for line in normalize_chunks(text):
        if "步骤" not in line and "训练中:" not in line:
            continue
        latest_line = line
        match = step_pat.search(line) or bar_pat.search(line)
        if match:
            latest_step = int(match.group(1))
    return latest_line, latest_step

out = {}
for item in payload:
    try:
        ps = subprocess.run(
            ["ps", "-p", item["pid"], "-o", "pid=,stat=,etime=,args="],
            capture_output=True,
            text=True,
            check=False,
        )
        log_path = pathlib.Path(item["log_path"])
        latest_line = ""
        latest_step = None
        if log_path.exists():
            latest_line, latest_step = parse_latest_line(log_path.read_text(errors="replace"))
        out[item["run_id"]] = {
            "alive": bool(ps.stdout.strip()),
            "ps": ps.stdout.strip(),
            "log_exists": log_path.exists(),
            "latest_line": latest_line,
            "latest_step": latest_step,
        }
    except Exception as exc:
        out[item["run_id"]] = {
            "alive": False,
            "ps": "",
            "log_exists": False,
            "latest_line": "",
            "latest_step": None,
            "error": str(exc),
        }
print(json.dumps(out, ensure_ascii=False))
"""
    remote_target = host if "@" in host else f"{remote_user}@{host}"
    remote_cmd = (
        f"python3 -c {shlex.quote(remote_py)} "
        f"{shlex.quote(json.dumps(payload, ensure_ascii=False))}"
    )
    try:
        res = run_cmd(
            [
                "ssh",
                "-F",
                "/dev/null",
                "-o",
                "BatchMode=yes",
                "-o",
                "StrictHostKeyChecking=accept-new",
                remote_target,
                remote_cmd,
            ]
        )
        return json.loads(res.stdout)
    except subprocess.CalledProcessError as exc:
        error = (exc.stderr or exc.stdout or str(exc)).strip()
        return {
            run["run_id"]: {
                "alive": False,
                "ps": "",
                "log_exists": False,
                "latest_line": "",
                "latest_step": None,
                "error": f"ssh_failed: {error}",
            }
            for run in runs
        }
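`remote_probe` ships the probe as an inline `python3 -c` script over ssh; the quoting it relies on can be sketched with only the standard library (the snippet and payload below are made-up stand-ins):

```python
import json
import shlex

payload = json.dumps({"run_id": "r1", "pid": "1234"}, ensure_ascii=False)
snippet = "import sys, json; print(json.loads(sys.argv[1])['pid'])"

# shlex.quote makes each piece survive the remote shell unchanged,
# so the JSON arrives in sys.argv[1] exactly as serialized.
remote_cmd = f"python3 -c {shlex.quote(snippet)} {shlex.quote(payload)}"

# Splitting the command back apart recovers the original argv.
argv = shlex.split(remote_cmd)
assert argv[:2] == ["python3", "-c"]
assert argv[2] == snippet
assert json.loads(argv[3])["pid"] == "1234"
```

The same round-trip holds after ssh forwards `remote_cmd` to the remote login shell, which is why the JSON payload needs no extra escaping.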


def append_notes(notes_path: pathlib.Path, snapshot_at: str, runs: list[dict[str, Any]]) -> None:
    lines = [f"\n## Status snapshot {snapshot_at}"]
    for run in runs:
        lines.append(
            (
                f"- {run['run_id']}: host={run['host']} gpu={run['gpu']} "
                f"alive={run.get('alive', False)} step={run.get('latest_step')} "
                f"pid={run['pid']}"
            )
        )
        if run.get("latest_line"):
            lines.append(f" - latest_line: `{run['latest_line']}`")
        if run.get("error"):
            lines.append(f" - error: `{run['error']}`")
    with notes_path.open("a", encoding="utf-8") as f:
        f.write("\n".join(lines) + "\n")


def main() -> int:
    parser = argparse.ArgumentParser()
    parser.add_argument("suite_dir", type=pathlib.Path)
    parser.add_argument("--remote-user", default="droid")
    parser.add_argument("--append-notes", action="store_true")
    args = parser.parse_args()

    suite_dir = args.suite_dir.resolve()
    status_path = suite_dir / "status.json"
    notes_path = suite_dir / "notes.md"
    monitor_dir = suite_dir / "monitor_logs"
    monitor_dir.mkdir(parents=True, exist_ok=True)

    status = json.loads(status_path.read_text(encoding="utf-8"))
    runs: list[dict[str, Any]] = status["runs"]
    snapshot_at = now_iso()

    by_host: dict[str, list[dict[str, Any]]] = defaultdict(list)
    for run in runs:
        by_host[run["host"]].append(run)

    results: dict[str, dict[str, Any]] = {}
    for host, host_runs in by_host.items():
        if host == "local":
            for run in host_runs:
                results[run["run_id"]] = probe_local(run)
        else:
            results.update(remote_probe(host, args.remote_user, host_runs))

    alive_count = 0
    for run in runs:
        result = results[run["run_id"]]
        run["alive"] = result["alive"]
        run["ps"] = result["ps"]
        run["log_exists"] = result["log_exists"]
        run["latest_line"] = result["latest_line"]
        run["latest_step"] = result["latest_step"]
        run["last_verified_at"] = snapshot_at
        if "error" in result:
            run["error"] = result["error"]
        else:
            run.pop("error", None)
        run["status"] = "running" if result["alive"] else "stopped"
        alive_count += int(result["alive"])

    status["last_verified_at"] = snapshot_at
    status["alive_count"] = alive_count
    status["total_runs"] = len(runs)

    status_path.write_text(json.dumps(status, ensure_ascii=False, indent=2) + "\n", encoding="utf-8")

    snapshot_payload = {
        "suite_name": status.get("suite_name"),
        "snapshot_at": snapshot_at,
        "alive_count": alive_count,
        "total_runs": len(runs),
        "runs": {run["run_id"]: results[run["run_id"]] for run in runs},
    }
    timestamp_slug = snapshot_at.replace(":", "").replace("+", "_").replace("-", "")
    snapshot_path = monitor_dir / f"status-{timestamp_slug}.json"
    snapshot_path.write_text(
        json.dumps(snapshot_payload, ensure_ascii=False, indent=2) + "\n",
        encoding="utf-8",
    )

    if args.append_notes:
        append_notes(notes_path, snapshot_at, runs)

    print(json.dumps(snapshot_payload, ensure_ascii=False, indent=2))
    print(f"\nstatus_json={status_path}")
    print(f"snapshot_json={snapshot_path}")
    if args.append_notes:
        print(f"notes_md={notes_path}")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
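The log-scraping helpers can be exercised standalone (adapted from the script above into a self-contained snippet; the sample log text below is made up):

```python
import re

STEP_PAT = re.compile(r"步骤\s+(\d+)/(\d+)")
BAR_PAT = re.compile(r"\|\s*(\d+)/(\d+)")

def parse_latest_line(text):
    # Scan progress lines (both the "步骤 N/M" log format and the tqdm
    # "| N/M" bar) and keep the last step number seen.
    latest_line, latest_step = "", None
    for part in re.split(r"[\r\n]+", text):
        line = part.strip()
        if not line or ("步骤" not in line and "训练中:" not in line):
            continue
        latest_line = line
        match = STEP_PAT.search(line) or BAR_PAT.search(line)
        if match:
            latest_step = int(match.group(1))
    return latest_line, latest_step

log = (
    "步骤 100/109350 | loss 0.42\r"
    "训练中:  55%|█ | 60000/109350\n"
    "步骤 60500/109350 | loss 0.12"
)
line, step = parse_latest_line(log)
# The last matching line wins: step == 60500.
```

Splitting on `[\r\n]+` is what makes this robust to tqdm, which redraws its bar with carriage returns rather than newlines.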

@@ -28,6 +28,7 @@ class VLAAgent(nn.Module):
        num_action_steps=8,  # 每次推理实际执行多少步动作
        head_type='unet',  # Policy head类型: 'unet' 或 'transformer'
        cond_projector=None,  # 可选:将视觉+状态条件投影到head期望维度
+       extra_condition_tokens: int = 0,  # 可选:额外条件token数量(例如未来预测embedding)
    ):
        super().__init__()
        # 保存参数
@@ -39,6 +40,9 @@
        self.num_action_steps = num_action_steps
        self.inference_steps = inference_steps
        self.head_type = head_type  # 'unet' 或 'transformer'
+       self.extra_condition_tokens = int(extra_condition_tokens)
+       if self.extra_condition_tokens < 0:
+           raise ValueError(f"extra_condition_tokens must be >= 0, got {self.extra_condition_tokens}")
        agent_camera_names = tuple(camera_names) if camera_names is not None else None
        backbone_camera_names = getattr(vision_backbone, 'camera_names', None)
        backbone_camera_names = tuple(backbone_camera_names) if backbone_camera_names is not None else None
@@ -71,11 +75,14 @@
            stats=dataset_stats,
            normalization_type=normalization_type
        )
+       self.dataset_stats = dataset_stats

        self.vision_encoder = vision_backbone
+       self.state_encoder = state_encoder
        if self.camera_names is not None:
            self.vision_encoder.camera_names = self.camera_names
        self.condition_tokens_per_step = int(getattr(self.vision_encoder, 'tokens_per_step', 1))
+       self.state_feature_dim = int(getattr(self.state_encoder, 'output_dim', obs_dim))
        joint_vision_dim = getattr(self.vision_encoder, 'joint_output_dim', None)
        if joint_vision_dim is not None:
            per_token_vision_dim = int(joint_vision_dim)
@@ -87,8 +94,11 @@
        else:
            per_token_vision_dim = int(single_cam_feat_dim) * int(num_cams)

-       self.condition_sequence_length = self.obs_horizon * self.condition_tokens_per_step
-       self.raw_per_step_cond_dim = per_token_vision_dim + obs_dim
+       self.history_condition_sequence_length = self.obs_horizon * self.condition_tokens_per_step
+       self.condition_sequence_length = (
+           self.history_condition_sequence_length + self.extra_condition_tokens
+       )
+       self.raw_per_step_cond_dim = per_token_vision_dim + self.state_feature_dim
        if cond_projector is None:
            self.cond_projector = None
            self.per_step_cond_dim = self.raw_per_step_cond_dim
@@ -139,7 +149,6 @@
            global_cond_dim=self.global_cond_dim
        )

-       self.state_encoder = state_encoder
        self.action_encoder = action_encoder

        # 初始化队列(用于在线推理)
@@ -220,7 +229,7 @@
            f"条件维度不匹配: got {cond.shape[-1]}, expected {self.per_step_cond_dim}"
        )
        cond = cond.reshape(batch_size, obs_steps * token_count, self.per_step_cond_dim)
-       expected_length = self.condition_sequence_length
+       expected_length = self.history_condition_sequence_length
        if cond.shape[1] != expected_length:
            raise RuntimeError(
                f"条件序列长度不匹配: got {cond.shape[1]}, expected {expected_length}"

@@ -1,9 +1,12 @@
from __future__ import annotations

from contextlib import nullcontext
- from typing import Dict, Optional
+ from collections import deque
+ from pathlib import Path
+ from typing import Any, Dict, Mapping, Optional, Sequence

import torch
import torch.nn as nn
+ import torch.nn.functional as F

from roboimi.vla.agent import VLAAgent
@@ -15,14 +18,87 @@ except ImportError:  # pragma: no cover


class IMFVLAAgent(VLAAgent):
-   def __init__(self, *args, inference_steps: int = 1, **kwargs):
+   def __init__(
+       self,
+       *args,
+       inference_steps: int = 1,
+       lewm_history_horizon: Optional[int] = None,
+       lewm_query_offsets: Optional[Sequence[int]] = None,
+       lewm_predictor: Optional[nn.Module] = None,
+       lewm_pred_projector: Optional[nn.Module] = None,
+       future_decoder: Optional[nn.Module] = None,
+       future_query_init_std: float = 0.02,
+       lewm_sigreg: Optional[nn.Module] = None,
+       lewm_sigreg_weight: float = 0.09,
+       lewm_loss_weight: float = 0.0,
+       lewm_pretrained_ckpt: Optional[str | Path | Mapping[str, Any]] = None,
+       **kwargs,
+   ):
        if inference_steps != 1:
            raise ValueError(
                'IMFVLAAgent only supports one-step inference; '
                f'inference_steps must be 1, got {inference_steps}.'
            )
        lewm_query_offsets = tuple(int(offset) for offset in (lewm_query_offsets or ()))
        inferred_extra_condition_tokens = len(lewm_query_offsets) if lewm_query_offsets else 0
        default_extra_condition_tokens = (
            0 if future_decoder is not None else inferred_extra_condition_tokens
        )
        kwargs.setdefault('extra_condition_tokens', default_extra_condition_tokens)
        self.__dict__['lewm_history_horizon'] = int(lewm_history_horizon or kwargs.get('obs_horizon', 1))
        self.__dict__['lewm_query_offsets'] = lewm_query_offsets
        self.__dict__['lewm_predictor'] = lewm_predictor
        self.__dict__['lewm_pred_projector'] = lewm_pred_projector or nn.Identity()
        self.__dict__['future_decoder'] = future_decoder
        self.__dict__['future_query_tokens'] = None
        self.__dict__['future_query_init_std'] = float(future_query_init_std)
        self.__dict__['lewm_sigreg'] = lewm_sigreg
        self.__dict__['lewm_sigreg_weight'] = float(lewm_sigreg_weight)
        self.__dict__['lewm_loss_weight'] = float(lewm_loss_weight)
        self.__dict__['_last_loss_breakdown'] = {
            'action_loss': 0.0,
            'lewm_pred_loss': 0.0,
            'lewm_sigreg_loss': 0.0,
            'lewm_loss': 0.0,
            'loss': 0.0,
        }
        super().__init__(*args, inference_steps=inference_steps, **kwargs)
        self.inference_steps = 1
        self.lewm_history_horizon = int(lewm_history_horizon or self.obs_horizon)
        self.lewm_predictor = lewm_predictor
        self.lewm_pred_projector = lewm_pred_projector or nn.Identity()
        if future_decoder is not None and not isinstance(future_decoder, nn.Module):
            self.future_decoder = future_decoder()
        else:
            self.future_decoder = future_decoder
        self.future_query_tokens = None
        self.future_query_init_std = float(future_query_init_std)
        self.lewm_sigreg = lewm_sigreg
        self.lewm_sigreg_weight = float(lewm_sigreg_weight)
        if self.lewm_predictor is not None and self.future_decoder is not None:
            raise ValueError('lewm_predictor and future_decoder are mutually exclusive')
        if self.lewm_predictor is None and self.extra_condition_tokens > 0:
            raise ValueError(
                'extra_condition_tokens > 0 requires lewm_predictor to be provided'
            )
        if self.lewm_predictor is not None and self.extra_condition_tokens != inferred_extra_condition_tokens:
            raise ValueError(
                'extra_condition_tokens must equal len(lewm_query_offsets) when lewm_predictor is enabled'
            )
        if self.future_decoder is not None:
            if inferred_extra_condition_tokens <= 0:
                raise ValueError('future_decoder requires non-empty lewm_query_offsets')
            if self.extra_condition_tokens != 0:
                raise ValueError('future_decoder requires extra_condition_tokens=0')
            self.future_query_tokens = nn.Parameter(
                torch.randn(
                    1,
                    inferred_extra_condition_tokens,
                    self.per_step_cond_dim,
                ) * self.future_query_init_std
            )
        if lewm_pretrained_ckpt is not None:
            self.load_lewm_pretrained_components(lewm_pretrained_ckpt)

    @staticmethod
    def _broadcast_batch_time(value: torch.Tensor, reference: torch.Tensor) -> torch.Tensor:
@@ -119,14 +195,251 @@ class IMFVLAAgent(VLAAgent):
        delta = self._broadcast_batch_time(t - r, z_t)
        return z_t - delta * u

    def _normalize_qpos_for_lewm(self, qpos: torch.Tensor) -> torch.Tensor:
        if not self.normalization.enabled:
            return qpos

        qpos_mean = getattr(self.normalization, 'qpos_mean', None)
        qpos_std = getattr(self.normalization, 'qpos_std', None)
        if qpos_mean is not None and qpos_std is not None:
            return (qpos - qpos_mean) / qpos_std
        if isinstance(self.dataset_stats, dict):
            mean = self.dataset_stats.get('qpos_mean', None)
            std = self.dataset_stats.get('qpos_std', None)
            if mean is not None and std is not None:
                mean = torch.as_tensor(mean, dtype=qpos.dtype, device=qpos.device)
                std = torch.as_tensor(std, dtype=qpos.dtype, device=qpos.device)
                return (qpos - mean) / std
        return self.normalization.normalize_qpos(qpos)

    def _project_lewm_future_tokens(self, predicted_tokens: torch.Tensor) -> torch.Tensor:
        if predicted_tokens.ndim != 3:
            raise ValueError(
                f"expected predicted future tokens to be 3D, got rank {predicted_tokens.ndim}"
            )
        batch_size, token_count, token_dim = predicted_tokens.shape
        flattened = predicted_tokens.reshape(batch_size * token_count, token_dim)
        projected = self.lewm_pred_projector(flattened)
        if projected.ndim != 2:
            raise ValueError(
                f"expected lewm_pred_projector to return rank-2 tensors, got rank {projected.ndim}"
            )
        return projected.reshape(batch_size, token_count, projected.shape[-1])

    @staticmethod
    def _load_checkpoint_payload(
        checkpoint_or_path: str | Path | Mapping[str, Any],
    ) -> Mapping[str, torch.Tensor]:
        if isinstance(checkpoint_or_path, (str, Path)):
            payload = torch.load(Path(checkpoint_or_path), map_location='cpu', weights_only=False)
        else:
            payload = checkpoint_or_path
        state_dict = payload.get('state_dict', payload)
        if not isinstance(state_dict, Mapping):
            raise TypeError('checkpoint payload must contain a mapping state_dict')
        return state_dict

    @staticmethod
    def _extract_prefixed_state_dict(
        state_dict: Mapping[str, torch.Tensor],
        prefix: str,
    ) -> Dict[str, torch.Tensor]:
        extracted = {
            key[len(prefix):]: value
            for key, value in state_dict.items()
            if key.startswith(prefix)
        }
        if not extracted:
            raise KeyError(f"checkpoint missing parameters with prefix {prefix!r}")
        return extracted

    @staticmethod
    def _adapt_and_load_state_dict(
        module: nn.Module,
        incoming_state_dict: Mapping[str, torch.Tensor],
        *,
        query_key: str = 'query_tokens',
        pos_key: str = 'pos_embedding',
    ) -> Dict[str, Sequence[str]]:
        current_state_dict = module.state_dict()
        adapted_state_dict = dict(current_state_dict)
        loaded_keys = []
        mismatched_keys = []
        missing_keys = []
        for key, current_tensor in current_state_dict.items():
            if key not in incoming_state_dict:
                continue
            source_tensor = incoming_state_dict[key]
            if source_tensor.shape == current_tensor.shape:
                adapted_state_dict[key] = source_tensor
                loaded_keys.append(key)
                continue

            if key in {query_key, pos_key} and source_tensor.ndim == current_tensor.ndim:
                patched = current_tensor.clone()
                overlap_slices = tuple(
                    slice(0, min(src_dim, cur_dim))
                    for src_dim, cur_dim in zip(source_tensor.shape, current_tensor.shape)
                )
                patched[overlap_slices] = source_tensor[overlap_slices]
                if key == query_key:
                    copy_count = min(source_tensor.shape[1], current_tensor.shape[1])
                    if copy_count < current_tensor.shape[1] and copy_count > 0:
                        tail = source_tensor[:, copy_count - 1:copy_count, ...]
                        feature_dim = min(tail.shape[-1], patched.shape[-1])
                        patched[:, copy_count:, :feature_dim] = tail[:, :, :feature_dim]
                else:
                    copy_count = min(source_tensor.shape[1], current_tensor.shape[1])
                    if copy_count < current_tensor.shape[1] and copy_count > 0:
                        tail = source_tensor[:, copy_count - 1:copy_count, ...]
                        feature_dim = min(tail.shape[-1], patched.shape[-1])
                        patched[:, copy_count:, :feature_dim] = tail[:, :, :feature_dim]
                adapted_state_dict[key] = patched
                loaded_keys.append(key)
                continue
            mismatched_keys.append(key)

        for key in incoming_state_dict.keys():
            if key not in current_state_dict:
                missing_keys.append(key)
        module.load_state_dict(adapted_state_dict, strict=True)
        return {
            'loaded_keys': tuple(sorted(loaded_keys)),
            'mismatched_keys': tuple(sorted(set(mismatched_keys))),
            'missing_keys': tuple(sorted(set(missing_keys))),
        }

    @staticmethod
    def _load_state_dict_ignoring_shape_mismatches(
        module: nn.Module,
        incoming_state_dict: Mapping[str, torch.Tensor],
    ) -> Dict[str, Sequence[str]]:
        current_state_dict = module.state_dict()
        merged_state_dict = dict(current_state_dict)
        loaded_keys = []
        mismatched_keys = []
        missing_keys = []

        for key, value in incoming_state_dict.items():
            if key not in current_state_dict:
                missing_keys.append(key)
                continue
            if current_state_dict[key].shape != value.shape:
                mismatched_keys.append(key)
                continue
            merged_state_dict[key] = value
            loaded_keys.append(key)

        module.load_state_dict(merged_state_dict, strict=True)
        return {
            'loaded_keys': tuple(sorted(loaded_keys)),
            'mismatched_keys': tuple(sorted(mismatched_keys)),
            'missing_keys': tuple(sorted(missing_keys)),
        }
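`_load_state_dict_ignoring_shape_mismatches` boils down to a dictionary merge keyed on shape equality; a dependency-free sketch (a toy `Param` stands in for a tensor, only `.shape` matters, and all names here are hypothetical):

```python
from dataclasses import dataclass

@dataclass
class Param:
    shape: tuple
    value: str = "init"

def merge_ignoring_mismatches(current: dict, incoming: dict):
    # Keys absent from the module are reported as unexpected, shape
    # mismatches are skipped (the module keeps its init), exact matches copy.
    merged = dict(current)
    loaded, mismatched, unexpected = [], [], []
    for key, param in incoming.items():
        if key not in current:
            unexpected.append(key)
        elif current[key].shape != param.shape:
            mismatched.append(key)
        else:
            merged[key] = param
            loaded.append(key)
    return merged, sorted(loaded), sorted(mismatched), sorted(unexpected)

current = {"w": Param((4, 4)), "b": Param((4,))}
incoming = {"w": Param((4, 4), "ckpt"), "b": Param((8,), "ckpt"), "extra": Param((1,))}
merged, loaded, mismatched, unexpected = merge_ignoring_mismatches(current, incoming)
# loaded == ["w"], mismatched == ["b"], unexpected == ["extra"];
# merged["b"] keeps its freshly initialized value.
```

Merging into a copy of the module's own state dict is what lets the real method call `load_state_dict(..., strict=True)` afterwards without missing-key errors.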

    def load_lewm_pretrained_components(
        self,
        checkpoint_or_path: str | Path | Mapping[str, Any],
    ) -> None:
        state_dict = self._load_checkpoint_payload(checkpoint_or_path)

        if hasattr(self.vision_encoder, 'load_lewm_checkpoint'):
            try:
                self.vision_encoder.load_lewm_checkpoint({'state_dict': state_dict})
            except RuntimeError:
                vision_state_dict = self._extract_prefixed_state_dict(state_dict, 'model.encoder.')
                self._load_state_dict_ignoring_shape_mismatches(self.vision_encoder, vision_state_dict)
        else:
            vision_state_dict = self._extract_prefixed_state_dict(state_dict, 'model.encoder.')
            self._load_state_dict_ignoring_shape_mismatches(self.vision_encoder, vision_state_dict)

        state_encoder_state_dict = self._extract_prefixed_state_dict(state_dict, 'model.state_encoder.')
        self._load_state_dict_ignoring_shape_mismatches(self.state_encoder, state_encoder_state_dict)

        projector_state_dict = self._extract_prefixed_state_dict(state_dict, 'model.projector.proj.')
        mapped_projector_state_dict = {
            f'linear.{key}': value
            for key, value in projector_state_dict.items()
        }
        self._load_state_dict_ignoring_shape_mismatches(self.cond_projector, mapped_projector_state_dict)

        if self.lewm_predictor is not None:
            predictor_state_dict = self._extract_prefixed_state_dict(state_dict, 'model.predictor.')
            self._adapt_and_load_state_dict(self.lewm_predictor, predictor_state_dict)

        if self.lewm_pred_projector is not None:
            pred_projector_state_dict = self._extract_prefixed_state_dict(state_dict, 'model.pred_proj.')
            self._load_state_dict_ignoring_shape_mismatches(
                self.lewm_pred_projector,
                pred_projector_state_dict,
            )

    def _predict_future_tokens_with_decoder(self, history_cond: torch.Tensor) -> torch.Tensor:
        if self.future_decoder is None or self.future_query_tokens is None:
            raise RuntimeError('future_decoder path requested but not initialized')
        batch_size = history_cond.shape[0]
        query_tokens = self.future_query_tokens.expand(batch_size, -1, -1)
        r = torch.zeros(batch_size, device=history_cond.device, dtype=history_cond.dtype)
        t = torch.ones(batch_size, device=history_cond.device, dtype=history_cond.dtype)
        return self.future_decoder(query_tokens, r, t, cond=history_cond)

    def _build_full_condition(
        self,
        images,
        proprioception,
        *,
        lewm_images=None,
        lewm_proprioception=None,
    ):
        normalized_proprioception = self.normalization.normalize_qpos(proprioception)
        history_cond = self._build_cond(images, normalized_proprioception)
        predicted_future_tokens = None
        lewm_history_cond = None

        if self.lewm_predictor is None and self.future_decoder is None:
            return history_cond, predicted_future_tokens, lewm_history_cond

        lewm_images = lewm_images if lewm_images is not None else images
        lewm_proprioception = (
            lewm_proprioception if lewm_proprioception is not None else proprioception
        )
        lewm_history_cond = self._build_cond(
            lewm_images,
            self._normalize_qpos_for_lewm(lewm_proprioception),
        )
        cond = history_cond
        if self.lewm_predictor is not None:
            predicted_future_tokens = self.lewm_predictor(lewm_history_cond)
            predicted_future_tokens = self._project_lewm_future_tokens(predicted_future_tokens)
            cond = torch.cat([history_cond, predicted_future_tokens], dim=1)
            if cond.shape[1] != self.condition_sequence_length:
                raise RuntimeError(
                    f"完整条件序列长度不匹配: got {cond.shape[1]}, expected {self.condition_sequence_length}"
                )
            if cond.shape[-1] != self.per_step_cond_dim:
                raise RuntimeError(
                    f"完整条件维度不匹配: got {cond.shape[-1]}, expected {self.per_step_cond_dim}"
                )
        elif self.future_decoder is not None:
            predicted_future_tokens = self._predict_future_tokens_with_decoder(lewm_history_cond)
        return cond, predicted_future_tokens, lewm_history_cond

    @staticmethod
    def _masked_mse_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        return F.mse_loss(pred, target)

    def compute_loss(self, batch):
        actions, states, images = batch['action'], batch['qpos'], batch['images']
        action_is_pad = batch.get('action_is_pad', None)
        batch_size = actions.shape[0]

        states = self.normalization.normalize_qpos(states)
        actions = self.normalization.normalize_action(actions)
-       cond = self._build_cond(images, states)
+       cond, predicted_future_tokens, lewm_history_cond = self._build_full_condition(
+           images,
+           states,
+           lewm_images=batch.get('lewm_images', None),
+           lewm_proprioception=batch.get('lewm_qpos', None),
+       )

        x = actions
        e = torch.randn_like(x)
@@ -146,16 +459,109 @@ class IMFVLAAgent(VLAAgent):
        if action_is_pad is not None:
            mask = (~action_is_pad).unsqueeze(-1).to(loss.dtype)
            valid_count = mask.sum() * loss.shape[-1]
-           loss = (loss * mask).sum() / valid_count.clamp_min(1.0)
+           action_loss = (loss * mask).sum() / valid_count.clamp_min(1.0)
        else:
-           loss = loss.mean()
-           return loss
+           action_loss = loss.mean()

        lewm_pred_loss = torch.zeros((), device=action_loss.device, dtype=action_loss.dtype)
        lewm_sigreg_loss = torch.zeros((), device=action_loss.device, dtype=action_loss.dtype)
        if predicted_future_tokens is not None:
            lewm_future_images = batch.get('lewm_future_images', None)
            lewm_future_qpos = batch.get('lewm_future_qpos', None)
            if lewm_future_images is not None and lewm_future_qpos is not None:
                future_target = self._build_cond(
                    lewm_future_images,
                    self._normalize_qpos_for_lewm(lewm_future_qpos),
                )
                lewm_pred_loss = self._masked_mse_loss(predicted_future_tokens, future_target)
        if self.lewm_sigreg is not None and lewm_history_cond is not None:
            lewm_sigreg_loss = self.lewm_sigreg(lewm_history_cond.transpose(0, 1))

        lewm_loss = lewm_pred_loss + self.lewm_sigreg_weight * lewm_sigreg_loss
        total_loss = action_loss + self.lewm_loss_weight * lewm_loss
        self._last_loss_breakdown = {
            'action_loss': float(action_loss.detach().item()),
            'lewm_pred_loss': float(lewm_pred_loss.detach().item()),
            'lewm_sigreg_loss': float(lewm_sigreg_loss.detach().item()),
            'lewm_loss': float(lewm_loss.detach().item()),
            'loss': float(total_loss.detach().item()),
        }
        return total_loss

    def get_last_loss_breakdown(self) -> Dict[str, float]:
        return dict(self._last_loss_breakdown)
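The weighted combination and the detached breakdown dict reduce to plain arithmetic; a minimal sketch with floats instead of tensors (`combine_losses` is a hypothetical name, the default weights mirror the constructor defaults above):

```python
def combine_losses(action_loss: float, lewm_pred_loss: float, lewm_sigreg_loss: float,
                   sigreg_weight: float = 0.09, lewm_weight: float = 0.0):
    # total = action + w_lewm * (pred + w_sigreg * sigreg); with lewm_weight=0.0
    # the auxiliary terms are still logged but do not influence the total.
    lewm_loss = lewm_pred_loss + sigreg_weight * lewm_sigreg_loss
    total = action_loss + lewm_weight * lewm_loss
    breakdown = {
        'action_loss': action_loss,
        'lewm_pred_loss': lewm_pred_loss,
        'lewm_sigreg_loss': lewm_sigreg_loss,
        'lewm_loss': lewm_loss,
        'loss': total,
    }
    return total, breakdown

total, breakdown = combine_losses(0.5, 0.2, 1.0, sigreg_weight=0.09, lewm_weight=0.1)
# lewm_loss = 0.2 + 0.09 * 1.0 ≈ 0.29; total = 0.5 + 0.1 * 0.29 ≈ 0.529
```

Keeping the breakdown as plain floats (the real code calls `.detach().item()`) is what lets the training loop forward it to SwanLab without holding onto the autograd graph.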
|
||||
|
||||
def reset(self):
|
||||
super().reset()
|
||||
if self.lewm_predictor is not None:
|
||||
self._queues['lewm_qpos'] = deque(maxlen=self.lewm_history_horizon)
|
||||
self._queues['lewm_images'] = deque(maxlen=self.lewm_history_horizon)
|
||||
|
||||
def _populate_queues(self, observation: Dict[str, torch.Tensor]) -> None:
|
||||
super()._populate_queues(observation)
|
||||
if self.lewm_predictor is None:
|
||||
return
|
||||
if 'qpos' in observation:
|
||||
self._queues['lewm_qpos'].append(observation['qpos'].clone())
|
||||
if 'images' in observation:
|
||||
ordered_images = self._order_images(observation['images'])
|
||||
self._queues['lewm_images'].append({k: v.clone() for k, v in ordered_images.items()})
|
||||
|
||||
def _prepare_observation_batch(self) -> Dict[str, torch.Tensor]:
|
||||
batch = super()._prepare_observation_batch()
|
||||
if self.lewm_predictor is None:
|
||||
return batch
|
||||
|
||||
qpos_list = list(self._queues['lewm_qpos'])
|
||||
images_list = list(self._queues['lewm_images'])
|
||||
if len(qpos_list) == 0 or len(images_list) == 0:
|
||||
raise ValueError("LeWM 观测队列为空,请先调用 _populate_queues 添加观测")
|
||||
while len(qpos_list) < self.lewm_history_horizon:
|
||||
qpos_list.append(qpos_list[-1])
|
||||
while len(images_list) < self.lewm_history_horizon:
|
||||
images_list.append(images_list[-1])
|
||||
|
||||
batch['lewm_qpos'] = torch.stack(qpos_list, dim=0).unsqueeze(0)
|
||||
batch['lewm_images'] = {}
|
||||
camera_names = self.camera_names if self.camera_names is not None else tuple(sorted(images_list[0].keys()))
|
||||
for cam_name in camera_names:
|
||||
batch['lewm_images'][cam_name] = torch.stack(
|
||||
[img[cam_name] for img in images_list],
|
||||
dim=0,
|
||||
).unsqueeze(0)
|
||||
return batch
|
||||
|
||||
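The queue padding above repeats the newest frame until the history reaches `lewm_history_horizon` entries; a minimal stand-alone sketch of that rule (plain lists instead of tensors):

```python
def pad_history(frames, horizon):
    """Repeat the last (newest) entry until the history has `horizon`
    elements, mirroring the while-loops in _prepare_observation_batch."""
    frames = list(frames)
    if not frames:
        raise ValueError("history is empty")
    while len(frames) < horizon:
        frames.append(frames[-1])
    return frames
```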
    @torch.no_grad()
    def predict_action_chunk(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        return self.predict_action(
            batch['images'],
            batch['qpos'],
            lewm_images=batch.get('lewm_images', None),
            lewm_proprioception=batch.get('lewm_qpos', None),
        )

    @torch.no_grad()
    def predict_action(
        self,
        images,
        proprioception,
        *,
        lewm_images=None,
        lewm_proprioception=None,
    ):
        batch_size = proprioception.shape[0]
        if self.lewm_predictor is not None:
            cond, _predicted_future_tokens, _lewm_history_cond = self._build_full_condition(
                images,
                proprioception,
                lewm_images=lewm_images,
                lewm_proprioception=lewm_proprioception,
            )
        else:
            cond = self._build_cond(
                images,
                self.normalization.normalize_qpos(proprioception),
            )
        z_t = torch.randn((batch_size, self.pred_horizon, self.action_dim), device=cond.device, dtype=cond.dtype)
        action = self._sample_one_step(z_t, cond=cond)
        return self.normalization.denormalize_action(action)
@@ -0,0 +1,74 @@
# @package agent
defaults:
  - /backbone@vision_backbone: lewm_resnet_query_fusion
  - /modules@state_encoder: lewm_state_encoder
  - /modules@action_encoder: identity_action_encoder
  - /modules@cond_projector: linear_condition_projector
  - /head: imf_transformer1d
  - /head@future_decoder: imf_transformer1d
  - _self_

_target_: roboimi.vla.agent_imf.IMFVLAAgent

action_dim: 16
obs_dim: 16
normalization_type: "min_max"
pred_horizon: 8
obs_horizon: 2
num_action_steps: 8
camera_names: ${data.camera_names}
num_cams: 3

vision_backbone:
  camera_names: ${agent.camera_names}
  num_views: ${agent.num_cams}

cond_projector:
  output_dim: 288

lewm_history_horizon: 3
lewm_query_offsets: [8]
extra_condition_tokens: 0
lewm_loss_weight: 1.0
lewm_sigreg_weight: 0.09
lewm_pretrained_ckpt: null
future_query_init_std: 0.02

lewm_sigreg:
  _target_: roboimi.vla.models.backbones.lewm_resnet_query_fusion.SIGReg
  knots: 17
  num_proj: 1024

diffusion_steps: 100
inference_steps: 1
head_type: "transformer"

head:
  input_dim: ${agent.action_dim}
  output_dim: ${agent.action_dim}
  horizon: ${agent.pred_horizon}
  n_obs_steps: ${agent.obs_horizon}
  cond_dim: ${agent.cond_projector.output_dim}
  n_emb: 384
  causal_attn: false
  time_as_cond: true
  obs_as_cond: true
  n_cond_layers: 0
  backbone_type: attnres_full
  n_head: 1
  n_kv_head: 1

future_decoder:
  input_dim: ${agent.cond_projector.output_dim}
  output_dim: ${agent.cond_projector.output_dim}
  horizon: ${len:${agent.lewm_query_offsets}}
  n_obs_steps: ${agent.lewm_history_horizon}
  cond_dim: ${agent.cond_projector.output_dim}
  n_emb: 384
  causal_attn: false
  time_as_cond: true
  obs_as_cond: true
  n_cond_layers: 0
  backbone_type: attnres_full
  n_head: 1
  n_kv_head: 1
77
roboimi/vla/conf/agent/lewm_resnet_query_imf_attnres.yaml
Normal file
@@ -0,0 +1,77 @@
# @package agent
defaults:
  - /backbone@vision_backbone: lewm_resnet_query_fusion
  - /modules@state_encoder: lewm_state_encoder
  - /modules@action_encoder: identity_action_encoder
  - /modules@cond_projector: linear_condition_projector
  - /head: imf_transformer1d
  - _self_

_target_: roboimi.vla.agent_imf.IMFVLAAgent

action_dim: 16
obs_dim: 16
normalization_type: "min_max"
pred_horizon: 8
obs_horizon: 2
num_action_steps: 8
camera_names: ${data.camera_names}
num_cams: 3

vision_backbone:
  camera_names: ${agent.camera_names}
  num_views: ${agent.num_cams}

cond_projector:
  output_dim: 288

lewm_history_horizon: 3
lewm_query_offsets: [8]
extra_condition_tokens: ${len:${agent.lewm_query_offsets}}
lewm_loss_weight: 1.0
lewm_sigreg_weight: 0.09
lewm_pretrained_ckpt: null

lewm_sigreg:
  _target_: roboimi.vla.models.backbones.lewm_resnet_query_fusion.SIGReg
  knots: 17
  num_proj: 1024

lewm_predictor:
  _target_: roboimi.vla.models.backbones.lewm_resnet_query_fusion.QueryTokenPredictor
  num_frames: ${agent.lewm_history_horizon}
  query_offsets: ${agent.lewm_query_offsets}
  input_dim: ${agent.cond_projector.output_dim}
  hidden_dim: ${agent.cond_projector.output_dim}
  output_dim: ${agent.cond_projector.output_dim}
  depth: 6
  heads: 16
  mlp_dim: 2048
  dim_head: 64
  dropout: 0.1
  emb_dropout: 0.0

lewm_pred_projector:
  _target_: roboimi.vla.models.backbones.lewm_resnet_query_fusion.LeWMProjectorMLP
  input_dim: ${agent.cond_projector.output_dim}
  hidden_dim: 2048
  output_dim: ${agent.cond_projector.output_dim}

diffusion_steps: 100
inference_steps: 1
head_type: "transformer"

head:
  input_dim: ${agent.action_dim}
  output_dim: ${agent.action_dim}
  horizon: ${agent.pred_horizon}
  n_obs_steps: ${agent.obs_horizon}
  cond_dim: 288
  n_emb: 384
  causal_attn: false
  time_as_cond: true
  obs_as_cond: true
  n_cond_layers: 0
  backbone_type: attnres_full
  n_head: 1
  n_kv_head: 1
7
roboimi/vla/conf/backbone/lewm_resnet_query_fusion.yaml
Normal file
@@ -0,0 +1,7 @@
_target_: roboimi.vla.models.backbones.lewm_resnet_query_fusion.LeWMMultiViewResNetBackbone

view_feature_dim: 96
num_views: ${agent.num_cams}
view_encoder_mode: separate
camera_names: ${agent.camera_names}
checkpoint_path: null
@@ -18,6 +18,8 @@ train:
  # Data loading
  num_workers: 12                        # DataLoader worker count (set to 0 when debugging)
  val_split: 0.0                         # validation split ratio; train on the full dataset by default
  val_episode_indices: null              # explicit episode-level validation set, e.g. [100]
  action_mse_val_freq_epochs: 0          # when >0, compute action MSE on held-out episodes every N epochs
  seed: 42                               # random seed (used for the data split)

  # Logging and checkpoints
@@ -29,6 +31,11 @@ train:
  rollout_val_freq_epochs: 50            # run rollout validation every N epochs
  rollout_validate_on_checkpoint: false  # run rollout validation right after saving a checkpoint
  rollout_num_episodes: 3                # number of episodes per rollout validation
  rollout_device: ${train.device}        # device for rollouts; follows the training device by default
  rollout_num_workers: null              # rollout worker count; null auto-infers on CUDA, stays 1 on CPU
  rollout_cuda_devices: null             # logical CUDA device list for parallel rollouts; null defaults to [0]
  rollout_response_timeout_s: 300.0      # timeout for rollout workers waiting on the inference server
  rollout_server_startup_timeout_s: 300.0  # timeout for the inference server to become ready

  # Learning-rate scheduler (with warmup)
  warmup_steps: 2000                     # warmup steps (longer is recommended for Transformers)
@@ -11,6 +11,8 @@ dataset_dir: "roboimi/demos/dataset/sim_transfer"
# ====================
pred_horizon: ${agent.pred_horizon}  # prediction horizon (steps)
obs_horizon: ${agent.obs_horizon}    # observation horizon (steps)
lewm_history_horizon: ${oc.select:agent.lewm_history_horizon,null}
lewm_query_offsets: ${oc.select:agent.lewm_query_offsets,null}

# ====================
# Camera configuration
@@ -2,6 +2,10 @@
# Evaluation config
ckpt_path: "checkpoints/vla_model_best.pt"  # model checkpoint path
num_episodes: 3                    # number of evaluation episodes
num_workers: 1                     # parallel worker count; 1 keeps single-process evaluation
cuda_devices: null                 # logical CUDA device list for parallel evaluation; null defaults to [0]
response_timeout_s: 300.0          # timeout (s) for workers waiting on the inference server
server_startup_timeout_s: 300.0    # timeout (s) for the parent waiting on the inference server to come up
max_timesteps: 700                 # max timesteps per episode
device: ${train.device}            # keep consistent with training
task_name: "sim_transfer"          # environment task name
5
roboimi/vla/conf/modules/lewm_state_encoder.yaml
Normal file
@@ -0,0 +1,5 @@
_target_: roboimi.vla.modules.encoders.LeWMStateEncoder

input_dim: ${agent.obs_dim}
hidden_dim: 256
output_dim: 64
@@ -24,6 +24,9 @@ class SimpleRobotDataset(Dataset):
        camera_names: List[str] = None,
        image_resize_shape: Optional[Sequence[int]] = (224, 224),
        max_open_files: int = 64,
        lewm_history_horizon: Optional[int] = None,
        lewm_query_offsets: Optional[Sequence[int]] = None,
        episode_indices: Optional[Sequence[int]] = None,
    ):
        """
        Args:
@@ -42,12 +45,22 @@ class SimpleRobotDataset(Dataset):
        self.obs_horizon = obs_horizon
        self.pred_horizon = pred_horizon
        self.camera_names = camera_names or []
        self.lewm_history_horizon = (
            int(lewm_history_horizon) if lewm_history_horizon is not None else None
        )
        self.lewm_query_offsets = (
            tuple(int(offset) for offset in lewm_query_offsets)
            if lewm_query_offsets is not None else ()
        )
        self.image_resize_shape = (
            tuple(int(v) for v in image_resize_shape)
            if image_resize_shape is not None else None
        )
        self.max_open_files = max(1, int(max_open_files))
        self._file_cache: "OrderedDict[str, h5py.File]" = OrderedDict()
        self.requested_episode_indices = (
            None if episode_indices is None else tuple(sorted(int(idx) for idx in episode_indices))
        )

        self.dataset_dir = Path(dataset_dir)
        if not self.dataset_dir.exists():
@@ -60,20 +73,45 @@ class SimpleRobotDataset(Dataset):
        if not self.hdf5_files:
            raise FileNotFoundError(f"no HDF5 files found in {dataset_dir}")

        if self.requested_episode_indices is not None:
            requested = set(self.requested_episode_indices)
            filtered = []
            for hdf5_path in self.hdf5_files:
                stem = hdf5_path.stem
                if stem.startswith("episode_"):
                    try:
                        idx = int(stem.split("_")[-1])
                    except ValueError:
                        continue
                    if idx in requested:
                        filtered.append(hdf5_path)
            self.hdf5_files = filtered
            if not self.hdf5_files:
                raise FileNotFoundError(
                    f"no HDF5 files in {dataset_dir} match episode_indices={sorted(requested)}"
                )

        # Build the episode index (metadata only; no data is loaded here)
        self.episodes = {}
        self.frame_meta = []  # stores (ep_idx, frame_idx, hdf5_path)
        for ep_idx, hdf5_path in enumerate(self.hdf5_files):
            with h5py.File(hdf5_path, 'r') as f:
                T = f['action'].shape[0]
            dataset_episode_idx = ep_idx
            stem = hdf5_path.stem
            if stem.startswith("episode_"):
                try:
                    dataset_episode_idx = int(stem.split("_")[-1])
                except ValueError:
                    pass
            start_idx = len(self.frame_meta)
            for t in range(T):
                self.frame_meta.append({
                    "ep_idx": dataset_episode_idx,
                    "frame_idx": t,
                    "hdf5_path": hdf5_path,
                })
            self.episodes[dataset_episode_idx] = list(range(start_idx, len(self.frame_meta)))

        print(f"Lazy-load mode: {len(self.hdf5_files)} episodes, {len(self.frame_meta)} frames total")
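Both the filter and the index build parse the episode id from the file stem; the rule can be isolated as a small helper (hypothetical name, shown for illustration only):

```python
from pathlib import Path

def parse_episode_index(path):
    """Return the integer in an `episode_<idx>.hdf5` stem, or None when
    the filename does not follow that convention."""
    stem = Path(path).stem
    if not stem.startswith("episode_"):
        return None
    try:
        return int(stem.split("_")[-1])
    except ValueError:
        return None
```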
@@ -220,6 +258,60 @@ class SimpleRobotDataset(Dataset):
        for cam_name in self.camera_names:
            result[f"observation.{cam_name}"] = torch.stack(observations[f"observation.{cam_name}"])

        if self.lewm_history_horizon is not None and self.lewm_history_horizon > 0:
            lewm_observations = {
                "state": [],
            }
            for cam_name in self.camera_names:
                lewm_observations[f"observation.{cam_name}"] = []

            for delta in range(-self.lewm_history_horizon + 1, 1):
                target_idx = idx + delta
                if ep_start <= target_idx <= ep_end:
                    target_frame = self._load_frame(target_idx)
                else:
                    boundary_idx = ep_start if target_idx < ep_start else ep_end
                    target_frame = self._load_frame(boundary_idx)

                lewm_observations["state"].append(target_frame["observation.state"])
                for cam_name in self.camera_names:
                    lewm_observations[f"observation.{cam_name}"].append(
                        target_frame[f"observation.{cam_name}"]
                    )

            result["lewm.observation.state"] = torch.stack(lewm_observations["state"])
            for cam_name in self.camera_names:
                result[f"lewm.observation.{cam_name}"] = torch.stack(
                    lewm_observations[f"observation.{cam_name}"]
                )

        if self.lewm_query_offsets:
            lewm_future = {
                "state": [],
            }
            for cam_name in self.camera_names:
                lewm_future[f"observation.{cam_name}"] = []

            for offset in self.lewm_query_offsets:
                target_idx = idx + offset
                if ep_start <= target_idx <= ep_end:
                    target_frame = self._load_frame(target_idx)
                else:
                    boundary_idx = ep_start if target_idx < ep_start else ep_end
                    target_frame = self._load_frame(boundary_idx)

                lewm_future["state"].append(target_frame["observation.state"])
                for cam_name in self.camera_names:
                    lewm_future[f"observation.{cam_name}"].append(
                        target_frame[f"observation.{cam_name}"]
                    )

            result["lewm.future.state"] = torch.stack(lewm_future["state"])
            for cam_name in self.camera_names:
                result[f"lewm.future.{cam_name}"] = torch.stack(
                    lewm_future[f"observation.{cam_name}"]
                )

        return result
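Both the history and future branches clamp out-of-range frame indices to the episode boundaries; the shared rule, extracted as a sketch:

```python
def clamp_to_episode(idx, offset, ep_start, ep_end):
    """Resolve idx + offset, clamping to [ep_start, ep_end] exactly as the
    LeWM history/future sampling does: indices before the episode start
    reuse the first frame, indices past the end reuse the last frame."""
    target = idx + offset
    if ep_start <= target <= ep_end:
        return target
    return ep_start if target < ep_start else ep_end
```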
    @property
@@ -227,6 +319,10 @@ class SimpleRobotDataset(Dataset):
        """All camera key names (LeRobotDataset format)."""
        return [f"observation.{cam_name}" for cam_name in self.camera_names]

    @property
    def available_episode_indices(self) -> List[int]:
        return sorted(self.episodes.keys())

    @property
    def camera_info(self) -> dict:
        """Camera info."""
@@ -1,5 +1,14 @@
# Backbone models
__all__ = [
    "LEWMViTBackbone",
    "LeWMMultiViewResNetBackbone",
    "QueryTokenPredictor",
    "LeWMProjectorMLP",
    "SIGReg",
    "ResNetBackbone",
    "ResNetDiffusionBackbone",
    "SigLIP2DiffusionBackbone",
]


def __getattr__(name):
@@ -9,6 +18,19 @@ def __getattr__(name):
    if name == "SigLIP2DiffusionBackbone":
        from .siglip2_diffusion_backbone import SigLIP2DiffusionBackbone
        return SigLIP2DiffusionBackbone
    if name in {"LeWMMultiViewResNetBackbone", "QueryTokenPredictor", "LeWMProjectorMLP", "SIGReg"}:
        from .lewm_resnet_query_fusion import (
            LeWMMultiViewResNetBackbone,
            QueryTokenPredictor,
            LeWMProjectorMLP,
            SIGReg,
        )
        return {
            "LeWMMultiViewResNetBackbone": LeWMMultiViewResNetBackbone,
            "QueryTokenPredictor": QueryTokenPredictor,
            "LeWMProjectorMLP": LeWMProjectorMLP,
            "SIGReg": SIGReg,
        }[name]
    if name in {"ResNetBackbone", "ResNetDiffusionBackbone"}:
        from .resnet_diffusion import ResNetDiffusionBackbone
        return ResNetDiffusionBackbone
409
roboimi/vla/models/backbones/lewm_resnet_query_fusion.py
Normal file
@@ -0,0 +1,409 @@
from __future__ import annotations

from pathlib import Path
from typing import Any, Dict, Mapping, Optional, Sequence

import torch
from einops import rearrange
from torch import nn
import torch.nn.functional as F
from torchvision import models

from roboimi.vla.core.interfaces import VLABackbone


class SpatialSoftmax2D(nn.Module):
    """Convert a feature map into expected 2D keypoint coordinates per channel."""

    def forward(self, feature_map):
        if feature_map.ndim != 4:
            raise ValueError(
                f"SpatialSoftmax2D expects a 4D tensor, got rank {feature_map.ndim}"
            )

        batch, channels, height, width = feature_map.shape
        scores = feature_map.reshape(batch, channels, height * width)
        attention = F.softmax(scores, dim=-1)

        ys = torch.linspace(-1.0, 1.0, height, device=feature_map.device, dtype=feature_map.dtype)
        xs = torch.linspace(-1.0, 1.0, width, device=feature_map.device, dtype=feature_map.dtype)
        grid_y, grid_x = torch.meshgrid(ys, xs, indexing="ij")
        grid_x = grid_x.reshape(1, 1, height * width)
        grid_y = grid_y.reshape(1, 1, height * width)

        expected_x = (attention * grid_x).sum(dim=-1)
        expected_y = (attention * grid_y).sum(dim=-1)
        return torch.cat([expected_x, expected_y], dim=-1)
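For intuition, here is a dependency-free single-channel version of the same computation; a sharp peak in the scores drives the expected coordinates toward that cell's grid position:

```python
import math

def spatial_softmax_2d(feature_map):
    """Single-channel sketch of SpatialSoftmax2D: softmax over the H*W
    scores, then the attention-weighted mean of a [-1, 1] coordinate grid."""
    h, w = len(feature_map), len(feature_map[0])
    flat = [v for row in feature_map for v in row]
    m = max(flat)  # subtract the max for numerical stability
    att = [math.exp(v - m) for v in flat]
    z = sum(att)
    att = [a / z for a in att]
    xs = [-1.0 + 2.0 * j / (w - 1) for j in range(w)]
    ys = [-1.0 + 2.0 * i / (h - 1) for i in range(h)]
    expected_x = sum(att[i * w + j] * xs[j] for i in range(h) for j in range(w))
    expected_y = sum(att[i * w + j] * ys[i] for i in range(h) for j in range(w))
    return expected_x, expected_y
```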
class ResNet18SpatialEncoder(nn.Module):
    """Encode one camera view into a fixed-dimensional spatial-softmax embedding."""

    def __init__(self, view_feature_dim=96):
        super().__init__()
        if view_feature_dim % 2 != 0:
            raise ValueError("view_feature_dim must be even for spatial softmax features")

        backbone = models.resnet18(weights=None)
        if all(
            hasattr(backbone, name)
            for name in ("conv1", "bn1", "relu", "maxpool", "layer1", "layer2", "layer3", "layer4")
        ):
            self.backbone = nn.Sequential(
                backbone.conv1,
                backbone.bn1,
                backbone.relu,
                backbone.maxpool,
                backbone.layer1,
                backbone.layer2,
                backbone.layer3,
                backbone.layer4,
            )
            feature_channels = 512
        else:
            children = list(backbone.children())
            if len(children) < 1:
                raise ValueError("resnet18 backbone must expose child modules")
            truncated = children[:-2] if len(children) > 2 else children
            self.backbone = nn.Sequential(*truncated)
            with torch.no_grad():
                dummy = torch.zeros(1, 3, 16, 16)
                feature_channels = int(self.backbone(dummy).shape[1])

        self.proj = nn.Conv2d(feature_channels, view_feature_dim // 2, kernel_size=1)
        self.spatial_softmax = SpatialSoftmax2D()
        self.output_dim = int(view_feature_dim)

    def forward(self, pixels):
        if pixels.ndim not in (4, 5):
            raise ValueError(
                f"ResNet18SpatialEncoder expects a 4D or 5D tensor, got rank {pixels.ndim}"
            )

        needs_unflatten = pixels.ndim == 5
        if needs_unflatten:
            batch, steps, channels, height, width = pixels.shape
            pixels = rearrange(pixels, "b t c h w -> (b t) c h w")

        features = self.backbone(pixels.float())
        features = self.proj(features)
        embeddings = self.spatial_softmax(features)

        if needs_unflatten:
            embeddings = rearrange(embeddings, "(b t) d -> b t d", b=batch, t=steps)
        return embeddings
class LeWMMultiViewResNetBackbone(VLABackbone):
    """RoboIMI-side LeWM multiview ResNet spatial-softmax encoder."""

    def __init__(
        self,
        view_feature_dim: int = 96,
        num_views: int = 3,
        view_encoder_mode: str = "shared",
        camera_names: Sequence[str] = ("r_vis", "top", "front"),
        checkpoint_path: str | Path | None = None,
    ) -> None:
        super().__init__()
        if view_encoder_mode not in {"shared", "separate"}:
            raise ValueError(
                f"view_encoder_mode must be 'shared' or 'separate', got {view_encoder_mode}"
            )

        self.view_feature_dim = int(view_feature_dim)
        self.num_views = int(num_views)
        self.view_encoder_mode = view_encoder_mode
        self.camera_names = tuple(camera_names)
        if len(self.camera_names) != self.num_views:
            raise ValueError(
                f"camera_names length({len(self.camera_names)}) must equal num_views({self.num_views})"
            )
        self.output_dim = self.view_feature_dim * self.num_views
        self.joint_output_dim = self.output_dim
        self.tokens_per_step = 1

        if view_encoder_mode == "shared":
            self.single_view_encoder = ResNet18SpatialEncoder(
                view_feature_dim=view_feature_dim
            )
            self.view_encoders = None
        else:
            self.single_view_encoder = None
            self.view_encoders = nn.ModuleList(
                [
                    ResNet18SpatialEncoder(view_feature_dim=view_feature_dim)
                    for _ in range(num_views)
                ]
            )

        if checkpoint_path is not None:
            self.load_lewm_checkpoint(checkpoint_path)

    @staticmethod
    def _unwrap_state_dict(payload: Mapping[str, Any]) -> Mapping[str, torch.Tensor]:
        state_dict = payload.get("state_dict", payload)
        if not isinstance(state_dict, Mapping):
            raise TypeError("checkpoint payload must contain a mapping state_dict")
        return state_dict

    @staticmethod
    def _extract_prefixed_state_dict(
        state_dict: Mapping[str, torch.Tensor],
        prefix: str,
    ) -> Dict[str, torch.Tensor]:
        extracted = {
            key[len(prefix):]: value
            for key, value in state_dict.items()
            if key.startswith(prefix)
        }
        if not extracted:
            raise KeyError(f"checkpoint missing parameters with prefix {prefix!r}")
        return extracted

    def load_lewm_checkpoint(self, checkpoint_or_path: str | Path | Mapping[str, Any]) -> None:
        if isinstance(checkpoint_or_path, (str, Path)):
            payload = torch.load(Path(checkpoint_or_path), map_location="cpu", weights_only=False)
        else:
            payload = checkpoint_or_path
        state_dict = self._unwrap_state_dict(payload)
        encoder_state_dict = self._extract_prefixed_state_dict(state_dict, "model.encoder.")
        self.load_state_dict(encoder_state_dict, strict=True)

    def forward(self, images):
        missing = [camera_name for camera_name in self.camera_names if camera_name not in images]
        if missing:
            raise ValueError(
                f"image input missing required cameras. missing={missing}, expected={list(self.camera_names)}"
            )

        first_image = images[self.camera_names[0]]
        batch_size, steps = first_image.shape[:2]
        view_embeddings = []
        if self.view_encoder_mode == "shared":
            for camera_name in self.camera_names:
                view_embeddings.append(self.single_view_encoder(images[camera_name]))
        else:
            for single_view_encoder, camera_name in zip(self.view_encoders, self.camera_names):
                view_embeddings.append(single_view_encoder(images[camera_name]))

        embeddings = torch.cat(view_embeddings, dim=-1)
        return embeddings.reshape(batch_size, steps, self.output_dim)
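The checkpoint-surgery helpers above reduce to a simple key transform; a plain-dict sketch of what `_extract_prefixed_state_dict` does for the `"model.encoder."` prefix:

```python
def extract_prefixed(state_dict, prefix):
    """Keep only the keys under `prefix` and strip it, as the backbone does
    when pulling encoder weights out of a LeWM pretraining checkpoint."""
    out = {key[len(prefix):]: value
           for key, value in state_dict.items()
           if key.startswith(prefix)}
    if not out:
        raise KeyError(f"no parameters with prefix {prefix!r}")
    return out
```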
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout=0.0):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)


class Attention(nn.Module):
    def __init__(self, dim, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)
        self.heads = heads
        self.dropout = dropout
        self.norm = nn.LayerNorm(dim)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        self.to_out = (
            nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
            if project_out
            else nn.Identity()
        )

    def forward(self, x, causal=True):
        x = self.norm(x)
        drop = self.dropout if self.training else 0.0
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = (rearrange(t, "b t (h d) -> b h t d", h=self.heads) for t in qkv)
        out = F.scaled_dot_product_attention(q, k, v, dropout_p=drop, is_causal=causal)
        out = rearrange(out, "b h t d -> b t (h d)")
        return self.to_out(out)
class Block(nn.Module):
    def __init__(self, dim, heads, dim_head, mlp_dim, dropout=0.0):
        super().__init__()
        self.attn = Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)
        self.mlp = FeedForward(dim, mlp_dim, dropout=dropout)
        self.norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x


class Transformer(nn.Module):
    def __init__(
        self,
        input_dim,
        hidden_dim,
        output_dim,
        depth,
        heads,
        dim_head,
        mlp_dim,
        dropout=0.0,
        block_class=Block,
    ):
        super().__init__()
        self.norm = nn.LayerNorm(hidden_dim)
        self.layers = nn.ModuleList([])

        self.input_proj = (
            nn.Linear(input_dim, hidden_dim)
            if input_dim != hidden_dim
            else nn.Identity()
        )
        self.cond_proj = (
            nn.Linear(input_dim, hidden_dim)
            if input_dim != hidden_dim
            else nn.Identity()
        )
        self.output_proj = (
            nn.Linear(hidden_dim, output_dim)
            if hidden_dim != output_dim
            else nn.Identity()
        )

        for _ in range(depth):
            self.layers.append(block_class(hidden_dim, heads, dim_head, mlp_dim, dropout))

    def forward(self, x, c=None):
        x = self.input_proj(x)
        if c is not None:
            c = self.cond_proj(c)
        for block in self.layers:
            x = block(x)
        x = self.norm(x)
        return self.output_proj(x)
class QueryTokenPredictor(nn.Module):
    """History-only transformer predictor that decodes learned query tokens."""

    def __init__(
        self,
        *,
        num_frames,
        query_offsets,
        depth,
        heads,
        mlp_dim,
        input_dim,
        hidden_dim,
        output_dim=None,
        dim_head=64,
        dropout=0.0,
        emb_dropout=0.0,
    ):
        super().__init__()
        if num_frames <= 0:
            raise ValueError(f"num_frames must be positive, got {num_frames}")

        query_offsets = tuple(query_offsets)
        if not query_offsets:
            raise ValueError("query_offsets must contain at least one offset")
        if any(offset <= 0 for offset in query_offsets):
            raise ValueError(f"query_offsets must be positive, got {query_offsets}")

        self.num_frames = int(num_frames)
        self.query_offsets = query_offsets
        self.num_query_tokens = len(query_offsets)
        self.pos_embedding = nn.Parameter(
            torch.randn(1, self.num_frames + self.num_query_tokens, input_dim)
        )
        self.query_tokens = nn.Parameter(
            torch.randn(1, self.num_query_tokens, input_dim)
        )
        self.dropout = nn.Dropout(emb_dropout)
        self.transformer = Transformer(
            input_dim,
            hidden_dim,
            output_dim or input_dim,
            depth,
            heads,
            dim_head,
            mlp_dim,
            dropout,
            block_class=Block,
        )

    def forward(self, x):
        if x.ndim != 3:
            raise ValueError(
                f"QueryTokenPredictor expects a 3D tensor, got rank {x.ndim}"
            )

        T = x.size(1)
        if T > self.num_frames:
            raise ValueError(
                f"input sequence length {T} exceeds configured num_frames {self.num_frames}"
            )

        query_tokens = self.query_tokens.expand(x.size(0), -1, -1)
        tokens = torch.cat([x, query_tokens], dim=1)
        tokens = tokens + self.pos_embedding[:, : tokens.size(1)]
        tokens = self.dropout(tokens)
        tokens = self.transformer(tokens)
        return tokens[:, -self.num_query_tokens :]
class LeWMProjectorMLP(nn.Module):
    def __init__(
        self,
        input_dim: int = 288,
        hidden_dim: int = 2048,
        output_dim: int = 288,
    ) -> None:
        super().__init__()
        self.output_dim = int(output_dim)
        self.net = nn.Sequential(
            nn.Linear(int(input_dim), int(hidden_dim)),
            nn.BatchNorm1d(int(hidden_dim)),
            nn.GELU(),
            nn.Linear(int(hidden_dim), self.output_dim),
        )

    def forward(self, x):
        return self.net(x)
class SIGReg(nn.Module):
    """Sketch Isotropic Gaussian Regularizer, matching the original LeWM design."""

    def __init__(self, knots: int = 17, num_proj: int = 1024) -> None:
        super().__init__()
        self.num_proj = int(num_proj)
        t = torch.linspace(0, 3, int(knots), dtype=torch.float32)
        dt = 3 / (int(knots) - 1)
        weights = torch.full((int(knots),), 2 * dt, dtype=torch.float32)
        weights[[0, -1]] = dt
        window = torch.exp(-t.square() / 2.0)
        self.register_buffer("t", t)
        self.register_buffer("phi", window)
        self.register_buffer("weights", weights * window)

    def forward(self, proj: torch.Tensor) -> torch.Tensor:
        """
        proj: (T, B, D)
        """
        A = torch.randn(proj.size(-1), self.num_proj, device=proj.device)
        A = A.div_(A.norm(p=2, dim=0))
        x_t = (proj @ A).unsqueeze(-1) * self.t
        err = (x_t.cos().mean(-3) - self.phi).square() + x_t.sin().mean(-3).square()
        statistic = (err @ self.weights) * proj.size(-2)
        return statistic.mean()
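The buffers built in `__init__` amount to a Gaussian-windowed trapezoid quadrature on [0, 3], with half weight at the endpoints; a dependency-free sketch of that setup:

```python
import math

def sigreg_quadrature(knots=17):
    """Rebuild SIGReg's t / phi / weights buffers with plain floats:
    evenly spaced knots on [0, 3], trapezoid weights (endpoints halved),
    each scaled by the Gaussian window exp(-t^2 / 2)."""
    t = [3.0 * i / (knots - 1) for i in range(knots)]
    dt = 3.0 / (knots - 1)
    weights = [2.0 * dt] * knots
    weights[0] = weights[-1] = dt
    phi = [math.exp(-x * x / 2.0) for x in t]
    return t, phi, [w * g for w, g in zip(weights, phi)]
```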
@@ -15,4 +15,24 @@ class IdentityActionEncoder(nn.Module):
        super().__init__()

    def forward(self, action):
        return action


class LeWMStateEncoder(nn.Module):
    def __init__(
        self,
        input_dim: int = 16,
        hidden_dim: int = 256,
        output_dim: int = 64,
    ):
        super().__init__()
        self.output_dim = int(output_dim)
        self.net = nn.Sequential(
            nn.Linear(int(input_dim), int(hidden_dim)),
            nn.LayerNorm(int(hidden_dim)),
            nn.GELU(),
            nn.Linear(int(hidden_dim), self.output_dim),
        )

    def forward(self, state):
        return self.net(state)
@@ -1,5 +1,11 @@
import unittest
from unittest import mock

import numpy as np
import torch
from omegaconf import OmegaConf

from roboimi.demos.vla_scripts import eval_vla
from roboimi.vla.eval_utils import execute_policy_action

@@ -14,6 +20,48 @@ class _FakeEnv:
        self.calls.append(("step_jnt", action))


class _FakeQueue:
    def __init__(self, initial_items=None):
        self.items = list(initial_items or [])
        self.put_calls = []

    def put(self, item):
        self.put_calls.append(item)
        self.items.append(item)

    def get(self, timeout=None):
        del timeout
        if not self.items:
            raise AssertionError("queue unexpectedly empty")
        return self.items.pop(0)


def _make_parallel_cfg(**eval_overrides):
    eval_cfg = {
        "ckpt_path": "checkpoints/vla_model_best.pt",
        "num_episodes": 5,
        "num_workers": 2,
        "max_timesteps": 1,
        "device": "cpu",
        "task_name": "sim_transfer",
        "camera_names": ["front"],
        "use_smoothing": False,
        "smooth_alpha": 0.3,
        "verbose_action": False,
        "headless": True,
        "artifact_dir": None,
        "save_artifacts": False,
        "save_summary_json": False,
        "save_timing": False,
        "save_trajectory": False,
        "save_trajectory_npz": False,
        "record_video": False,
        "save_trajectory_image": False,
    }
    eval_cfg.update(eval_overrides)
    return OmegaConf.create({"agent": {}, "eval": eval_cfg})


class EvalVLAExecutionTest(unittest.TestCase):
    def test_execute_policy_action_uses_ee_step(self):
        env = _FakeEnv()
@@ -23,6 +71,662 @@ class EvalVLAExecutionTest(unittest.TestCase):
        self.assertEqual(env.calls, [("step", action)])

    def test_split_episode_indices_balances_workers(self):
        self.assertEqual(
            eval_vla._split_episode_indices(num_episodes=10, num_workers=3),
            [[0, 1, 2, 3], [4, 5, 6], [7, 8, 9]],
        )

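The expected split above (10 episodes across 3 workers, giving contiguous in-order chunks of sizes 4/3/3) can be produced by a simple remainder-spreading helper. A hypothetical reimplementation consistent with the tested contract, not the repository's actual `_split_episode_indices`:

```python
def split_episode_indices(num_episodes: int, num_workers: int) -> list[list[int]]:
    """Contiguous, in-order split; the first (num_episodes % num_workers)
    workers each receive one extra episode."""
    base, extra = divmod(num_episodes, num_workers)
    chunks, start = [], 0
    for worker in range(num_workers):
        size = base + (1 if worker < extra else 0)
        chunks.append(list(range(start, start + size)))
        start += size
    return chunks


print(split_episode_indices(10, 3))  # -> [[0, 1, 2, 3], [4, 5, 6], [7, 8, 9]]
```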
    def test_normalize_num_workers_caps_worker_count_to_episode_count(self):
        self.assertEqual(eval_vla._normalize_num_workers(num_workers=5, num_episodes=2), 2)

    def test_plan_episode_box_poses_uses_global_episode_order(self):
        planned_poses = [
            np.array([0.1, 0.2, 0.3], dtype=np.float32),
            np.array([1.1, 1.2, 1.3], dtype=np.float32),
            np.array([2.1, 2.2, 2.3], dtype=np.float32),
        ]
        sampler = mock.Mock(side_effect=planned_poses)

        result = eval_vla._plan_episode_box_poses(num_episodes=3, sampler=sampler)

        self.assertEqual(sampler.call_count, 3)
        self.assertEqual(len(result), 3)
        for expected, actual in zip(planned_poses, result):
            np.testing.assert_array_equal(actual, expected)

    def test_resolve_policy_camera_names_matches_vlaagent_fallback_sorting(self):
        cfg = OmegaConf.create(
            {
                "agent": {
                    "_target_": "roboimi.vla.agent.VLAAgent",
                },
                "eval": {
                    "camera_names": ["r_vis", "top", "front"],
                },
            }
        )

        self.assertEqual(
            eval_vla._resolve_policy_camera_names(cfg),
            ["front", "r_vis", "top"],
        )

    def test_resolve_policy_camera_names_matches_gr00t_fallback_input_order(self):
        cfg = OmegaConf.create(
            {
                "agent": {
                    "_target_": "roboimi.vla.agent_gr00t_dit.VLAAgentGr00tDiT",
                },
                "eval": {
                    "camera_names": ["r_vis", "top", "front"],
                },
            }
        )

        self.assertEqual(
            eval_vla._resolve_policy_camera_names(cfg),
            ["r_vis", "top", "front"],
        )

    def test_build_episode_plans_without_box_poses_keeps_serial_sampling_lazy(self):
        plans = eval_vla._build_episode_plans(num_episodes=3)

        self.assertEqual(
            plans,
            [
                {"episode_index": 0},
                {"episode_index": 1},
                {"episode_index": 2},
            ],
        )

    def test_prepare_local_policy_batch_pads_latest_observation_to_obs_horizon(self):
        queues = eval_vla._new_local_policy_queues(obs_horizon=3)
        observation = {
            "qpos": torch.tensor([1.0, 2.0], dtype=torch.float32),
            "images": {
                "front": torch.tensor([[[1.0]]], dtype=torch.float32),
            },
        }

        eval_vla._populate_local_policy_queues(queues, observation)
        batch = eval_vla._prepare_local_policy_batch(
            queues,
            obs_horizon=3,
            camera_names=["front"],
        )

        self.assertEqual(tuple(batch["qpos"].shape), (1, 3, 2))
        self.assertEqual(tuple(batch["images"]["front"].shape), (1, 3, 1, 1, 1))
        np.testing.assert_array_equal(
            batch["qpos"][0].cpu().numpy(),
            np.array([[1.0, 2.0], [1.0, 2.0], [1.0, 2.0]], dtype=np.float32),
        )
        np.testing.assert_array_equal(
            batch["images"]["front"][0].cpu().numpy(),
            np.array([[[[1.0]]], [[[1.0]]], [[[1.0]]]], dtype=np.float32),
        )

    def test_enqueue_predicted_actions_uses_executable_slice(self):
        queues = eval_vla._new_local_policy_queues(obs_horizon=2)
        predicted_actions = torch.tensor(
            [[[10.0], [20.0], [30.0], [40.0]]],
            dtype=torch.float32,
        )

        eval_vla._enqueue_predicted_actions(
            queues,
            predicted_actions=predicted_actions,
            obs_horizon=2,
            num_action_steps=2,
        )

        self.assertEqual(len(queues["action"]), 2)
        np.testing.assert_array_equal(queues["action"].popleft().numpy(), np.array([20.0], dtype=np.float32))
        np.testing.assert_array_equal(queues["action"].popleft().numpy(), np.array([30.0], dtype=np.float32))

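The slicing behaviour asserted above (obs_horizon=2, num_action_steps=2, predicted chunk `[10, 20, 30, 40]` yields executed actions `[20, 30]`) corresponds to skipping the first `obs_horizon - 1` predicted steps, which overlap timesteps already covered by the observation history. A hypothetical helper with the same contract (the real `_enqueue_predicted_actions` lives in `eval_vla`):

```python
from collections import deque

import torch


def enqueue_predicted_actions(action_queue, predicted_actions, obs_horizon, num_action_steps):
    # predicted_actions: (B=1, pred_horizon, action_dim). The first
    # obs_horizon - 1 predicted steps overlap already-observed timesteps,
    # so the executable slice starts at index obs_horizon - 1.
    start = obs_horizon - 1
    for action in predicted_actions[0, start:start + num_action_steps]:
        action_queue.append(action)


queue = deque()
chunk = torch.tensor([[[10.0], [20.0], [30.0], [40.0]]])
enqueue_predicted_actions(queue, chunk, obs_horizon=2, num_action_steps=2)
print([a.item() for a in queue])  # -> [20.0, 30.0]
```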
    def test_remote_policy_runner_only_requests_server_inference_when_local_action_queue_is_empty(self):
        request_queue = _FakeQueue()
        response_queue = _FakeQueue(
            [
                {
                    "type": "predict_chunk_result",
                    "actions": np.asarray([[[10.0], [20.0], [30.0]]], dtype=np.float32),
                }
            ]
        )
        runner = eval_vla._RemotePolicyRunner(
            worker_index=3,
            server_index=1,
            request_queue=request_queue,
            response_queue=response_queue,
            camera_names=["front"],
            obs_horizon=2,
            num_action_steps=2,
        )
        first_observation = {
            "qpos": torch.tensor([1.0, 2.0], dtype=torch.float32),
            "images": {"front": torch.tensor([[[1.0]]], dtype=torch.float32)},
        }
        second_observation = {
            "qpos": torch.tensor([3.0, 4.0], dtype=torch.float32),
            "images": {"front": torch.tensor([[[2.0]]], dtype=torch.float32)},
        }

        first_action, first_forward = runner.select_action(
            first_observation,
            episode_index=7,
            timestep=0,
        )
        second_action, second_forward = runner.select_action(
            second_observation,
            episode_index=7,
            timestep=1,
        )

        self.assertTrue(first_forward)
        self.assertFalse(second_forward)
        self.assertEqual(len(request_queue.put_calls), 1)
        self.assertEqual(request_queue.put_calls[0]["type"], "predict_chunk")
        self.assertEqual(request_queue.put_calls[0]["worker_index"], 3)
        self.assertEqual(request_queue.put_calls[0]["server_index"], 1)
        np.testing.assert_array_equal(first_action.numpy(), np.array([20.0], dtype=np.float32))
        np.testing.assert_array_equal(second_action.numpy(), np.array([30.0], dtype=np.float32))

    def test_merge_worker_summaries_sorts_episodes_and_recomputes_aggregates(self):
        worker_summaries = [
            {
                "avg_inference_fps": 999.0,
                "avg_control_fps": 999.0,
                "avg_obs_read_time_ms": 999.0,
                "avg_total_time_ms": 999.0,
                "timing_summary": {"count": 999, "model_forward_count": 999},
                "episodes": [
                    {
                        "episode_index": 2,
                        "episode_reward": 9.0,
                        "episode_max_reward": 4.0,
                        "inference_fps": 30.0,
                        "control_fps": 15.0,
                    }
                ],
                "_merge_state": {
                    "obs_read_time_ms": [9.0],
                    "preprocess_time_ms": [1.0],
                    "inference_time_ms": [3.0],
                    "env_step_time_ms": [4.0],
                    "total_time_ms": [10.0],
                    "model_forward_flags": [False],
                },
            },
            {
                "avg_inference_fps": 888.0,
                "avg_control_fps": 888.0,
                "avg_obs_read_time_ms": 888.0,
                "avg_total_time_ms": 888.0,
                "timing_summary": {"count": 888, "model_forward_count": 888},
                "episodes": [
                    {
                        "episode_index": 1,
                        "episode_reward": 6.0,
                        "episode_max_reward": 3.0,
                        "inference_fps": 20.0,
                        "control_fps": 10.0,
                    },
                    {
                        "episode_index": 0,
                        "episode_reward": 5.0,
                        "episode_max_reward": 2.0,
                        "inference_fps": 10.0,
                        "control_fps": 5.0,
                    },
                ],
                "_merge_state": {
                    "obs_read_time_ms": [1.0, 2.0, 12.0],
                    "preprocess_time_ms": [2.0, 3.0, 4.0],
                    "inference_time_ms": [4.0, 5.0, 6.0],
                    "env_step_time_ms": [6.0, 7.0, 8.0],
                    "total_time_ms": [8.0, 9.0, 20.0],
                    "model_forward_flags": [True, False, True],
                },
            },
        ]
        artifact_paths = {
            "output_dir": "/tmp/merged",
            "summary_json": "/tmp/merged/rollout_summary.json",
            "timing_json": "/tmp/merged/timing.json",
            "trajectory_npz": None,
            "video_mp4": None,
            "video_camera_name": None,
        }

        merged = eval_vla._merge_worker_summaries(worker_summaries, artifact_paths)

        self.assertEqual([episode["episode_index"] for episode in merged["episodes"]], [0, 1, 2])
        self.assertEqual(merged["episode_rewards"], [5.0, 6.0, 9.0])
        self.assertEqual(merged["episode_max_rewards"], [2.0, 3.0, 4.0])
        self.assertAlmostEqual(merged["avg_reward"], 20.0 / 3.0)
        self.assertAlmostEqual(merged["avg_max_reward"], 3.0)
        self.assertAlmostEqual(merged["avg_inference_fps"], 20.0)
        self.assertAlmostEqual(merged["avg_control_fps"], 10.0)
        self.assertAlmostEqual(merged["avg_obs_read_time_ms"], 6.0)
        self.assertAlmostEqual(merged["avg_total_time_ms"], 47.0 / 4.0)
        self.assertEqual(merged["timing_summary"]["count"], 4)
        self.assertEqual(merged["timing_summary"]["model_forward_count"], 2)
        self.assertEqual(merged["artifact_dir"], "/tmp/merged")
        self.assertEqual(merged["artifacts"], artifact_paths)

    def test_build_cuda_server_payloads_uses_round_robin_worker_assignment(self):
        cfg = _make_parallel_cfg(num_episodes=4, num_workers=4, device="cuda", cuda_devices=[0, 1])
        artifact_paths = {"output_dir": None}

        with mock.patch.object(
            eval_vla,
            "sample_transfer_pose",
            side_effect=[
                np.array([0.1, 0.2, 0.3], dtype=np.float32),
                np.array([0.4, 0.5, 0.6], dtype=np.float32),
                np.array([0.7, 0.8, 0.9], dtype=np.float32),
                np.array([1.0, 1.1, 1.2], dtype=np.float32),
            ],
        ):
            worker_payloads, _ = eval_vla._build_parallel_worker_payloads(cfg, artifact_paths)

        server_payloads, assigned_workers = eval_vla._build_cuda_server_payloads(
            cfg,
            worker_payloads=worker_payloads,
            cuda_devices=[0, 1],
        )

        self.assertEqual([payload["device_index"] for payload in server_payloads], [0, 1])
        self.assertEqual([payload["worker_index"] for payload in assigned_workers], [0, 1, 2, 3])
        self.assertEqual([payload["server_index"] for payload in assigned_workers], [0, 1, 0, 1])
        self.assertEqual(server_payloads[0]["worker_indices"], [0, 2])
        self.assertEqual(server_payloads[1]["worker_indices"], [1, 3])

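The assignment asserted above is plain round-robin: worker `i` talks to server `i % num_servers`, so with two CUDA servers the four workers map to servers `[0, 1, 0, 1]` and server 0 owns workers `[0, 2]`. A hypothetical sketch of that assignment logic (not the repository's `_build_cuda_server_payloads`):

```python
def assign_workers_round_robin(num_workers: int, cuda_devices: list[int]):
    # Worker i is served by CUDA server i % num_servers.
    num_servers = len(cuda_devices)
    worker_to_server = [worker % num_servers for worker in range(num_workers)]
    server_workers = [
        [worker for worker in range(num_workers) if worker % num_servers == server]
        for server in range(num_servers)
    ]
    return worker_to_server, server_workers


worker_to_server, server_workers = assign_workers_round_robin(4, cuda_devices=[0, 1])
print(worker_to_server)  # -> [0, 1, 0, 1]
print(server_workers)    # -> [[0, 2], [1, 3]]
```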
    def test_run_eval_parallel_dispatches_episode_splits_and_box_poses(self):
        cfg = _make_parallel_cfg(num_episodes=5, num_workers=2, artifact_dir="/tmp/parallel-root")
        planned_poses = [
            np.array([float(index), float(index) + 0.1, float(index) + 0.2], dtype=np.float32)
            for index in range(5)
        ]
        observed_payloads = []

        def fake_run_spawn_jobs(payloads, max_workers, worker_fn):
            del worker_fn
            self.assertEqual(max_workers, 2)
            observed_payloads.extend(payloads)
            return [
                {
                    "episodes": [
                        {
                            "episode_index": 4,
                            "episode_reward": 5.0,
                            "episode_max_reward": 5.0,
                            "inference_fps": 50.0,
                            "control_fps": 25.0,
                        },
                        {
                            "episode_index": 3,
                            "episode_reward": 4.0,
                            "episode_max_reward": 4.0,
                            "inference_fps": 40.0,
                            "control_fps": 20.0,
                        },
                    ],
                    "_merge_state": {
                        "obs_read_time_ms": [4.0, 5.0],
                        "preprocess_time_ms": [1.0, 1.0],
                        "inference_time_ms": [2.0, 2.0],
                        "env_step_time_ms": [3.0, 3.0],
                        "total_time_ms": [4.0, 5.0],
                        "model_forward_flags": [True, True],
                    },
                },
                {
                    "episodes": [
                        {
                            "episode_index": 2,
                            "episode_reward": 3.0,
                            "episode_max_reward": 3.0,
                            "inference_fps": 30.0,
                            "control_fps": 15.0,
                        },
                        {
                            "episode_index": 1,
                            "episode_reward": 2.0,
                            "episode_max_reward": 2.0,
                            "inference_fps": 20.0,
                            "control_fps": 10.0,
                        },
                        {
                            "episode_index": 0,
                            "episode_reward": 1.0,
                            "episode_max_reward": 1.0,
                            "inference_fps": 10.0,
                            "control_fps": 5.0,
                        },
                    ],
                    "_merge_state": {
                        "obs_read_time_ms": [1.0, 2.0, 3.0],
                        "preprocess_time_ms": [1.0, 1.0, 1.0],
                        "inference_time_ms": [2.0, 2.0, 2.0],
                        "env_step_time_ms": [3.0, 3.0, 3.0],
                        "total_time_ms": [1.0, 2.0, 3.0],
                        "model_forward_flags": [False, True, False],
                    },
                },
            ]

        with mock.patch.object(
            eval_vla,
            "sample_transfer_pose",
            side_effect=planned_poses,
        ), mock.patch.object(
            eval_vla,
            "_run_spawn_jobs",
            side_effect=fake_run_spawn_jobs,
        ):
            summary = eval_vla._run_eval_parallel(cfg)

        self.assertEqual(len(observed_payloads), 2)
        self.assertEqual(
            [[plan["episode_index"] for plan in payload["episode_plans"]] for payload in observed_payloads],
            [[0, 1, 2], [3, 4]],
        )
        for payload in observed_payloads:
            for plan in payload["episode_plans"]:
                np.testing.assert_array_equal(
                    np.asarray(plan["box_pos"], dtype=np.float32),
                    planned_poses[plan["episode_index"]],
                )
        self.assertEqual([episode["episode_index"] for episode in summary["episodes"]], [0, 1, 2, 3, 4])
        self.assertEqual(summary["episode_rewards"], [1.0, 2.0, 3.0, 4.0, 5.0])
        self.assertEqual(summary["num_episodes"], 5)

    def test_run_eval_parallel_allows_trajectory_images_and_keeps_worker_artifact_paths(self):
        cfg = _make_parallel_cfg(
            num_episodes=2,
            num_workers=2,
            artifact_dir="/tmp/parallel-images",
            save_summary_json=True,
            save_trajectory_image=True,
        )
        observed_payloads = []

        def fake_run_spawn_jobs(payloads, max_workers, worker_fn):
            del worker_fn
            self.assertEqual(max_workers, 2)
            observed_payloads.extend(payloads)
            return [
                {
                    "episodes": [
                        {
                            "episode_index": 0,
                            "episode_reward": 1.0,
                            "episode_max_reward": 1.0,
                            "inference_fps": 10.0,
                            "control_fps": 5.0,
                            "artifact_paths": {
                                "trajectory_image": f"{payloads[0]['artifact_dir']}/rollout_front_ep01_trajectory.png",
                            },
                        },
                    ],
                    "_merge_state": {
                        "obs_read_time_ms": [1.0],
                        "preprocess_time_ms": [1.0],
                        "inference_time_ms": [1.0],
                        "env_step_time_ms": [1.0],
                        "total_time_ms": [1.0],
                        "model_forward_flags": [True],
                    },
                },
                {
                    "episodes": [
                        {
                            "episode_index": 1,
                            "episode_reward": 2.0,
                            "episode_max_reward": 2.0,
                            "inference_fps": 20.0,
                            "control_fps": 10.0,
                            "artifact_paths": {
                                "trajectory_image": f"{payloads[1]['artifact_dir']}/rollout_front_ep02_trajectory.png",
                            },
                        },
                    ],
                    "_merge_state": {
                        "obs_read_time_ms": [2.0],
                        "preprocess_time_ms": [2.0],
                        "inference_time_ms": [2.0],
                        "env_step_time_ms": [2.0],
                        "total_time_ms": [2.0],
                        "model_forward_flags": [False],
                    },
                },
            ]

        with mock.patch.object(
            eval_vla,
            "sample_transfer_pose",
            side_effect=[
                np.array([0.1, 0.2, 0.3], dtype=np.float32),
                np.array([0.4, 0.5, 0.6], dtype=np.float32),
            ],
        ), mock.patch.object(
            eval_vla,
            "_run_spawn_jobs",
            side_effect=fake_run_spawn_jobs,
        ):
            summary = eval_vla._run_eval_parallel(cfg)

        self.assertEqual(len(observed_payloads), 2)
        self.assertTrue(observed_payloads[0]["artifact_dir"].endswith("workers/worker_00"))
        self.assertTrue(observed_payloads[1]["artifact_dir"].endswith("workers/worker_01"))
        self.assertTrue(
            summary["episodes"][0]["artifact_paths"]["trajectory_image"].endswith(
                "workers/worker_00/rollout_front_ep01_trajectory.png"
            )
        )
        self.assertTrue(
            summary["episodes"][1]["artifact_paths"]["trajectory_image"].endswith(
                "workers/worker_01/rollout_front_ep02_trajectory.png"
            )
        )

    def test_run_eval_parallel_surfaces_worker_failures(self):
        cfg = _make_parallel_cfg(num_episodes=2, num_workers=2)

        with mock.patch.object(
            eval_vla,
            "sample_transfer_pose",
            side_effect=[
                np.array([0.1, 0.2, 0.3], dtype=np.float32),
                np.array([0.4, 0.5, 0.6], dtype=np.float32),
            ],
        ), mock.patch.object(
            eval_vla,
            "_run_spawn_jobs",
            side_effect=RuntimeError("boom"),
        ):
            with self.assertRaisesRegex(RuntimeError, "Parallel rollout worker failed"):
                eval_vla._run_eval_parallel(cfg)

    def test_run_eval_parallel_cuda_builds_server_payloads_and_merges_worker_results(self):
        cfg = _make_parallel_cfg(
            num_episodes=4,
            num_workers=4,
            device="cuda",
            cuda_devices=[0],
            artifact_dir="/tmp/cuda-root",
        )
        observed_server_payloads = []
        observed_worker_payloads = []

        def fake_run_cuda_parallel_processes(server_payloads, worker_payloads):
            observed_server_payloads.extend(server_payloads)
            observed_worker_payloads.extend(worker_payloads)
            return [
                {
                    "episodes": [
                        {
                            "episode_index": 2,
                            "episode_reward": 3.0,
                            "episode_max_reward": 3.0,
                            "inference_fps": 30.0,
                            "control_fps": 15.0,
                        },
                        {
                            "episode_index": 0,
                            "episode_reward": 1.0,
                            "episode_max_reward": 1.0,
                            "inference_fps": 10.0,
                            "control_fps": 5.0,
                        },
                    ],
                    "_merge_state": {
                        "obs_read_time_ms": [1.0, 2.0],
                        "preprocess_time_ms": [1.0, 1.0],
                        "inference_time_ms": [2.0, 2.0],
                        "env_step_time_ms": [3.0, 3.0],
                        "total_time_ms": [4.0, 4.0],
                        "model_forward_flags": [True, False],
                    },
                },
                {
                    "episodes": [
                        {
                            "episode_index": 3,
                            "episode_reward": 4.0,
                            "episode_max_reward": 4.0,
                            "inference_fps": 40.0,
                            "control_fps": 20.0,
                        },
                        {
                            "episode_index": 1,
                            "episode_reward": 2.0,
                            "episode_max_reward": 2.0,
                            "inference_fps": 20.0,
                            "control_fps": 10.0,
                        },
                    ],
                    "_merge_state": {
                        "obs_read_time_ms": [3.0, 4.0],
                        "preprocess_time_ms": [1.0, 1.0],
                        "inference_time_ms": [2.0, 2.0],
                        "env_step_time_ms": [3.0, 3.0],
                        "total_time_ms": [4.0, 4.0],
                        "model_forward_flags": [True, True],
                    },
                },
            ]

        with mock.patch.object(
            eval_vla,
            "sample_transfer_pose",
            side_effect=[
                np.array([0.1, 0.2, 0.3], dtype=np.float32),
                np.array([0.4, 0.5, 0.6], dtype=np.float32),
                np.array([0.7, 0.8, 0.9], dtype=np.float32),
                np.array([1.0, 1.1, 1.2], dtype=np.float32),
            ],
        ), mock.patch.object(
            eval_vla,
            "_run_cuda_parallel_processes",
            side_effect=fake_run_cuda_parallel_processes,
            create=True,
        ):
            summary = eval_vla._run_eval_parallel_cuda(cfg)

        self.assertEqual(len(observed_server_payloads), 1)
        self.assertEqual(observed_server_payloads[0]["device_index"], 0)
        self.assertEqual(len(observed_worker_payloads), 4)
        self.assertTrue(all(payload["server_index"] == 0 for payload in observed_worker_payloads))
        self.assertEqual([episode["episode_index"] for episode in summary["episodes"]], [0, 1, 2, 3])
        self.assertEqual(summary["episode_rewards"], [1.0, 2.0, 3.0, 4.0])
        self.assertEqual(summary["num_episodes"], 4)

    def test_run_eval_parallel_cuda_surfaces_server_failures(self):
        cfg = _make_parallel_cfg(num_episodes=2, num_workers=2, device="cuda", cuda_devices=[0])

        with mock.patch.object(
            eval_vla,
            "sample_transfer_pose",
            side_effect=[
                np.array([0.1, 0.2, 0.3], dtype=np.float32),
                np.array([0.4, 0.5, 0.6], dtype=np.float32),
            ],
        ), mock.patch.object(
            eval_vla,
            "_run_cuda_parallel_processes",
            side_effect=RuntimeError("server boom"),
            create=True,
        ):
            with self.assertRaisesRegex(RuntimeError, "Parallel CUDA rollout failed"):
                eval_vla._run_eval_parallel_cuda(cfg)

    def test_run_spawn_jobs_supports_real_spawn_with_actual_eval_worker_entry(self):
        payloads = [
            {"_spawn_probe": True, "probe_value": 1, "worker_index": 0},
            {"_spawn_probe": True, "probe_value": 2, "worker_index": 1},
        ]

        results = eval_vla._run_spawn_jobs(
            payloads=payloads,
            max_workers=2,
            worker_fn=eval_vla._run_eval_worker_entry,
        )

        self.assertEqual(sorted(result["probe_value"] for result in results), [1, 2])
        self.assertEqual(sorted(result["worker_index"] for result in results), [0, 1])

    def test_cuda_server_and_env_worker_entrypoints_support_real_spawn_probe(self):
        ctx = eval_vla.multiprocessing.get_context("spawn")
        request_queue = ctx.Queue()
        response_queue = ctx.Queue()
        result_queue = ctx.Queue()

        server = ctx.Process(
            target=eval_vla._inference_server_main,
            args=(
                {
                    "_spawn_probe": True,
                    "server_index": 0,
                    "request_queue": request_queue,
                    "response_queues": [response_queue],
                },
            ),
        )
        worker = ctx.Process(
            target=eval_vla._env_worker_main,
            args=(
                {
                    "_spawn_probe": True,
                    "worker_index": 0,
                    "server_index": 0,
                    "request_queue": request_queue,
                    "response_queue": response_queue,
                    "result_queue": result_queue,
                },
            ),
        )

        server.start()
        worker.start()

        result = result_queue.get(timeout=10.0)

        worker.join(timeout=10.0)
        request_queue.put({"type": "shutdown_server"})
        server.join(timeout=10.0)

        self.assertEqual(result["kind"], "worker_result")
        self.assertEqual(result["summary"]["probe_worker_index"], 0)
        self.assertEqual(result["summary"]["probe_server_index"], 0)
        self.assertEqual(result["summary"]["probe_actions"], [[[11.0], [22.0], [33.0]]])
        self.assertEqual(worker.exitcode, 0)
        self.assertEqual(server.exitcode, 0)


if __name__ == "__main__":
    unittest.main()

@@ -126,6 +126,26 @@ class EvalVLAHeadlessTest(unittest.TestCase):
        self.assertIn("headless", eval_cfg)
        self.assertFalse(eval_cfg.headless)

    def test_eval_config_exposes_num_workers_default(self):
        eval_cfg = OmegaConf.load(Path("roboimi/vla/conf/eval/eval.yaml"))

        self.assertIn("num_workers", eval_cfg)
        self.assertEqual(eval_cfg.num_workers, 1)

    def test_eval_config_exposes_cuda_devices_default(self):
        eval_cfg = OmegaConf.load(Path("roboimi/vla/conf/eval/eval.yaml"))

        self.assertIn("cuda_devices", eval_cfg)
        self.assertIsNone(eval_cfg.cuda_devices)

    def test_eval_config_exposes_parallel_timeout_defaults(self):
        eval_cfg = OmegaConf.load(Path("roboimi/vla/conf/eval/eval.yaml"))

        self.assertIn("response_timeout_s", eval_cfg)
        self.assertIn("server_startup_timeout_s", eval_cfg)
        self.assertEqual(eval_cfg.response_timeout_s, 300.0)
        self.assertEqual(eval_cfg.server_startup_timeout_s, 300.0)

    def test_make_sim_env_accepts_headless_and_disables_render(self):
        fake_env = object()

@@ -327,6 +347,172 @@ class EvalVLAHeadlessTest(unittest.TestCase):
        self.assertAlmostEqual(summary["avg_reward"], 3.75)
        self.assertEqual(summary["num_episodes"], 2)

    def test_run_eval_uses_serial_path_when_num_workers_is_one(self):
        cfg = OmegaConf.create(
            {
                "eval": {
                    "num_workers": 1,
                    "num_episodes": 3,
                }
            }
        )

        with mock.patch.object(
            eval_vla,
            "_run_eval_serial",
            return_value={"mode": "serial"},
        ) as run_eval_serial, mock.patch.object(
            eval_vla,
            "_run_eval_parallel",
        ) as run_eval_parallel:
            result = eval_vla._run_eval(cfg)

        self.assertEqual(result, {"mode": "serial"})
        run_eval_serial.assert_called_once_with(cfg)
        run_eval_parallel.assert_not_called()

    def test_run_eval_uses_serial_path_when_requested_workers_collapse_to_one(self):
        cfg = OmegaConf.create(
            {
                "eval": {
                    "num_workers": 8,
                    "num_episodes": 1,
                }
            }
        )

        with mock.patch.object(
            eval_vla,
            "_run_eval_serial",
            return_value={"mode": "serial"},
        ) as run_eval_serial, mock.patch.object(
            eval_vla,
            "_run_eval_parallel",
        ) as run_eval_parallel:
            result = eval_vla._run_eval(cfg)

        self.assertEqual(result, {"mode": "serial"})
        run_eval_serial.assert_called_once_with(cfg)
        run_eval_parallel.assert_not_called()

    def test_run_eval_parallel_requires_headless_true(self):
        cfg = OmegaConf.create(
            {
                "agent": {},
                "eval": {
                    "ckpt_path": "checkpoints/vla_model_best.pt",
                    "num_episodes": 2,
                    "num_workers": 2,
                    "max_timesteps": 1,
                    "device": "cpu",
                    "task_name": "sim_transfer",
                    "camera_names": ["front"],
                    "use_smoothing": False,
                    "smooth_alpha": 0.3,
                    "verbose_action": False,
                    "headless": False,
                },
            }
        )

        with self.assertRaisesRegex(ValueError, "headless=true"):
            eval_vla._run_eval_parallel(cfg)

    def test_run_eval_parallel_dispatches_to_cpu_workers_when_device_is_cpu(self):
        cfg = OmegaConf.create(
            {
                "agent": {},
                "eval": {
                    "ckpt_path": "checkpoints/vla_model_best.pt",
                    "num_episodes": 2,
                    "num_workers": 2,
                    "max_timesteps": 1,
                    "device": "cpu",
                    "task_name": "sim_transfer",
                    "camera_names": ["front"],
                    "use_smoothing": False,
                    "smooth_alpha": 0.3,
                    "verbose_action": False,
                    "headless": True,
                    "cuda_devices": None,
                },
            }
        )

        with mock.patch.object(
            eval_vla,
            "_run_eval_parallel_cpu",
            return_value={"mode": "cpu"},
            create=True,
        ) as run_cpu_parallel, mock.patch.object(
            eval_vla,
            "_run_eval_parallel_cuda",
            create=True,
        ) as run_cuda_parallel:
            result = eval_vla._run_eval_parallel(cfg)

        self.assertEqual(result, {"mode": "cpu"})
        run_cpu_parallel.assert_called_once_with(cfg)
        run_cuda_parallel.assert_not_called()

    def test_run_eval_parallel_dispatches_to_cuda_servers_when_device_is_cuda(self):
        cfg = OmegaConf.create(
            {
                "agent": {},
                "eval": {
                    "ckpt_path": "checkpoints/vla_model_best.pt",
                    "num_episodes": 2,
                    "num_workers": 2,
                    "max_timesteps": 1,
                    "device": "cuda",
                    "task_name": "sim_transfer",
                    "camera_names": ["front"],
                    "use_smoothing": False,
                    "smooth_alpha": 0.3,
                    "verbose_action": False,
                    "headless": True,
                    "cuda_devices": [0],
                },
            }
        )

        with mock.patch.object(
            eval_vla,
            "_run_eval_parallel_cpu",
            create=True,
        ) as run_cpu_parallel, mock.patch.object(
            eval_vla,
            "_run_eval_parallel_cuda",
            return_value={"mode": "cuda"},
            create=True,
        ) as run_cuda_parallel:
            result = eval_vla._run_eval_parallel(cfg)

        self.assertEqual(result, {"mode": "cuda"})
        run_cpu_parallel.assert_not_called()
        run_cuda_parallel.assert_called_once_with(cfg)

    def test_resolve_cuda_devices_defaults_to_single_logical_gpu(self):
        cfg = OmegaConf.create(
            {
                "device": "cuda",
                "cuda_devices": None,
            }
        )

        self.assertEqual(eval_vla._resolve_cuda_devices(cfg), [0])

    def test_resolve_cuda_devices_rejects_empty_selection(self):
        cfg = OmegaConf.create(
            {
                "device": "cuda",
                "cuda_devices": [],
            }
        )

        with self.assertRaisesRegex(ValueError, "cuda_devices"):
            eval_vla._resolve_cuda_devices(cfg)


if __name__ == "__main__":
    unittest.main()

@@ -376,6 +376,49 @@ class _ForbiddenScheduler:
        raise AssertionError('IMF inference should not use DDIM scheduler step')


class _StubFutureTokenPredictor(nn.Module):
    def __init__(self, num_future_tokens=1):
        super().__init__()
        self.num_future_tokens = int(num_future_tokens)
        self.calls = []

    def forward(self, history_tokens):
        self.calls.append(history_tokens.detach().clone())
        summary = history_tokens.mean(dim=1, keepdim=True)
        return summary.repeat(1, self.num_future_tokens, 1)


class _RecordingDirectFutureDecoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.scale = nn.Parameter(torch.tensor(0.5))
        self.calls = []

    def forward(self, sample, r, t, cond=None):
        record = {
            'sample': sample.detach().clone(),
            'r': r.detach().clone(),
            't': t.detach().clone(),
            'cond': None if cond is None else cond.detach().clone(),
        }
        self.calls.append(record)
        cond_term = 0.0
        if cond is not None:
            cond_term = cond.mean(dim=1, keepdim=True)
        return self.scale * sample + cond_term


class _RecordingSigReg(nn.Module):
    def __init__(self, value=0.5):
        super().__init__()
        self.value = float(value)
        self.calls = []

    def forward(self, embeddings):
        self.calls.append(embeddings.detach().clone())
        return embeddings.new_tensor(self.value)


def _make_images(batch_size, obs_horizon, per_camera_fill):
    return {
        name: torch.full((batch_size, obs_horizon, 1, 2, 2), fill_value=value, dtype=torch.float32)
@@ -501,6 +544,311 @@ class IMFVLAAgentTest(unittest.TestCase):
        self.assertTrue(torch.allclose(head.calls[0]['t'], torch.ones(2)))
        self.assertTrue(torch.allclose(head.calls[0]['cond'], expected_cond))

    def test_predict_action_appends_lewm_future_tokens_to_history_conditioning(self):
        agent_cls, agent_module = _load_imf_agent_class()
        head = _RecordingLinearIMFHead()
        future_predictor = _StubFutureTokenPredictor(num_future_tokens=1)
        agent = agent_cls(
            vision_backbone=_StubVisionBackbone(),
            state_encoder=nn.Identity(),
            action_encoder=nn.Identity(),
            head=head,
            action_dim=2,
            obs_dim=1,
            pred_horizon=3,
            obs_horizon=2,
            diffusion_steps=10,
            inference_steps=1,
            num_cams=len(_CAMERA_NAMES),
            camera_names=_CAMERA_NAMES,
            num_action_steps=2,
            head_type='transformer',
            extra_condition_tokens=1,
            lewm_history_horizon=3,
            lewm_query_offsets=[8],
            lewm_predictor=future_predictor,
            lewm_pred_projector=nn.Identity(),
            lewm_loss_weight=0.5,
        )
        agent.infer_scheduler = _ForbiddenScheduler()

        images = _make_images(
            batch_size=1,
            obs_horizon=2,
            per_camera_fill={'r_vis': 10.0, 'top': 20.0, 'front': 30.0},
        )
        qpos = torch.tensor([[[1.0], [2.0]]], dtype=torch.float32)
        lewm_images = _make_images(
            batch_size=1,
            obs_horizon=3,
            per_camera_fill={'r_vis': 10.0, 'top': 20.0, 'front': 30.0},
        )
        lewm_qpos = torch.tensor([[[0.5], [1.5], [2.5]]], dtype=torch.float32)
        initial_noise = torch.tensor(
            [[[1.0, -1.0], [0.0, 2.0], [3.0, -2.0]]],
            dtype=torch.float32,
        )

        with mock.patch.object(agent_module.torch, 'randn', return_value=initial_noise):
            _ = agent.predict_action(
                images,
                qpos,
                lewm_images=lewm_images,
                lewm_proprioception=lewm_qpos,
            )

        expected_history = torch.tensor(
            [[[10.0, 20.0, 30.0, 1.0], [10.0, 20.0, 30.0, 2.0]]],
            dtype=torch.float32,
        )
        expected_future = torch.tensor([[[10.0, 20.0, 30.0, 1.5]]], dtype=torch.float32)
        expected_cond = torch.cat([expected_history, expected_future], dim=1)

        self.assertEqual(agent.condition_sequence_length, 3)
        self.assertEqual(agent.per_step_cond_dim, 4)
        self.assertEqual(len(head.calls), 1)
        self.assertTrue(torch.allclose(head.calls[0]['cond'], expected_cond))
        self.assertEqual(len(future_predictor.calls), 1)

    def test_compute_loss_tracks_action_and_lewm_loss_breakdown(self):
        agent_cls, agent_module = _load_imf_agent_class()
        head = _RecordingLinearIMFHead()
        future_predictor = _StubFutureTokenPredictor(num_future_tokens=1)
        sigreg = _RecordingSigReg(value=0.75)
        agent = agent_cls(
            vision_backbone=_StubVisionBackbone(),
            state_encoder=nn.Identity(),
            action_encoder=nn.Identity(),
            head=head,
            action_dim=2,
            obs_dim=1,
            pred_horizon=3,
            obs_horizon=2,
            diffusion_steps=10,
            inference_steps=1,
            num_cams=len(_CAMERA_NAMES),
            camera_names=_CAMERA_NAMES,
            num_action_steps=2,
            head_type='transformer',
            extra_condition_tokens=1,
            lewm_history_horizon=3,
            lewm_query_offsets=[8],
            lewm_predictor=future_predictor,
            lewm_pred_projector=nn.Identity(),
            lewm_sigreg=sigreg,
            lewm_sigreg_weight=0.09,
            lewm_loss_weight=0.25,
        )

        images = _make_images(
            batch_size=1,
            obs_horizon=2,
            per_camera_fill={'r_vis': 1.0, 'top': 2.0, 'front': 3.0},
        )
        qpos = torch.tensor([[[0.25], [0.75]]], dtype=torch.float32)
        actions = torch.tensor(
            [[[1.0, -1.0], [0.5, 0.25], [-0.5, 1.5]]],
            dtype=torch.float32,
        )
        lewm_images = _make_images(
            batch_size=1,
            obs_horizon=3,
            per_camera_fill={'r_vis': 1.0, 'top': 2.0, 'front': 3.0},
        )
        lewm_qpos = torch.tensor([[[0.1], [0.2], [0.3]]], dtype=torch.float32)
        lewm_future_images = _make_images(
            batch_size=1,
            obs_horizon=1,
            per_camera_fill={'r_vis': 1.0, 'top': 2.0, 'front': 3.0},
        )
        lewm_future_qpos = torch.tensor([[[0.4]]], dtype=torch.float32)
        noise = torch.tensor(
            [[[0.2, -0.4], [0.1, 0.3], [0.5, -0.2]]],
            dtype=torch.float32,
        )
        t_sample = torch.tensor([0.8], dtype=torch.float32)
        r_sample = torch.tensor([0.25], dtype=torch.float32)

        with mock.patch.object(agent_module.torch, 'randn_like', return_value=noise), \
                mock.patch.object(agent_module.torch, 'rand', side_effect=[t_sample, r_sample]):
            loss = agent.compute_loss(
                {
                    'images': images,
                    'qpos': qpos,
                    'action': actions,
                    'lewm_images': lewm_images,
                    'lewm_qpos': lewm_qpos,
                    'lewm_future_images': lewm_future_images,
                    'lewm_future_qpos': lewm_future_qpos,
                }
            )

        metrics = agent.get_last_loss_breakdown()
        self.assertAlmostEqual(loss.item(), metrics['loss'], places=6)
        self.assertIn('action_loss', metrics)
        self.assertIn('lewm_pred_loss', metrics)
        self.assertIn('lewm_sigreg_loss', metrics)
        self.assertIn('lewm_loss', metrics)
        self.assertAlmostEqual(metrics['lewm_sigreg_loss'], 0.75, places=6)
        self.assertAlmostEqual(
            metrics['lewm_loss'],
            metrics['lewm_pred_loss'] + 0.09 * metrics['lewm_sigreg_loss'],
            places=5,
        )
        self.assertAlmostEqual(
            metrics['loss'],
            metrics['action_loss'] + 0.25 * metrics['lewm_loss'],
            places=5,
        )
        self.assertEqual(len(sigreg.calls), 1)
        expected_lewm_history = torch.tensor(
            [[[1.0, 2.0, 3.0, 0.1], [1.0, 2.0, 3.0, 0.2], [1.0, 2.0, 3.0, 0.3]]],
            dtype=torch.float32,
        )
        torch.testing.assert_close(sigreg.calls[0], expected_lewm_history.transpose(0, 1))

    def test_predict_action_with_dual_decoder_keeps_action_condition_history_only(self):
        agent_cls, agent_module = _load_imf_agent_class()
        head = _RecordingLinearIMFHead()
        future_decoder = _RecordingDirectFutureDecoder()
        agent = agent_cls(
            vision_backbone=_StubVisionBackbone(),
            state_encoder=nn.Identity(),
            action_encoder=nn.Identity(),
            head=head,
            future_decoder=future_decoder,
            action_dim=2,
            obs_dim=1,
            pred_horizon=3,
            obs_horizon=2,
            diffusion_steps=10,
            inference_steps=1,
            num_cams=len(_CAMERA_NAMES),
            camera_names=_CAMERA_NAMES,
            num_action_steps=2,
            head_type='transformer',
            lewm_history_horizon=3,
            lewm_query_offsets=[8],
            lewm_loss_weight=1.0,
        )
        agent.infer_scheduler = _ForbiddenScheduler()
        with torch.no_grad():
            agent.future_query_tokens.copy_(torch.tensor([[[0.1, 0.2, 0.3, 0.4]]], dtype=torch.float32))

        images = _make_images(
            batch_size=1,
            obs_horizon=2,
            per_camera_fill={'r_vis': 10.0, 'top': 20.0, 'front': 30.0},
        )
        qpos = torch.tensor([[[1.0], [2.0]]], dtype=torch.float32)
        initial_noise = torch.tensor(
            [[[1.0, -1.0], [0.0, 2.0], [3.0, -2.0]]],
            dtype=torch.float32,
        )

        with mock.patch.object(agent_module.torch, 'randn', return_value=initial_noise):
            _ = agent.predict_action(images, qpos)

        expected_history = torch.tensor(
            [[[10.0, 20.0, 30.0, 1.0], [10.0, 20.0, 30.0, 2.0]]],
            dtype=torch.float32,
        )
        self.assertEqual(len(head.calls), 1)
        self.assertTrue(torch.allclose(head.calls[0]['cond'], expected_history))
        self.assertEqual(len(future_decoder.calls), 0)

    def test_compute_loss_with_dual_decoder_tracks_lewm_loss_breakdown(self):
        agent_cls, agent_module = _load_imf_agent_class()
        head = _RecordingLinearIMFHead()
        future_decoder = _RecordingDirectFutureDecoder()
        sigreg = _RecordingSigReg(value=0.75)
        agent = agent_cls(
            vision_backbone=_StubVisionBackbone(),
            state_encoder=nn.Identity(),
            action_encoder=nn.Identity(),
            head=head,
            future_decoder=future_decoder,
            action_dim=2,
            obs_dim=1,
            pred_horizon=3,
            obs_horizon=2,
            diffusion_steps=10,
            inference_steps=1,
            num_cams=len(_CAMERA_NAMES),
            camera_names=_CAMERA_NAMES,
            num_action_steps=2,
            head_type='transformer',
            lewm_history_horizon=3,
            lewm_query_offsets=[8],
            lewm_sigreg=sigreg,
            lewm_sigreg_weight=0.09,
            lewm_loss_weight=1.0,
        )
        with torch.no_grad():
            agent.future_query_tokens.copy_(torch.tensor([[[0.2, 0.4, 0.6, 0.8]]], dtype=torch.float32))

        images = _make_images(
            batch_size=1,
            obs_horizon=2,
            per_camera_fill={'r_vis': 1.0, 'top': 2.0, 'front': 3.0},
        )
        qpos = torch.tensor([[[0.25], [0.75]]], dtype=torch.float32)
        actions = torch.tensor(
            [[[1.0, -1.0], [0.5, 0.25], [-0.5, 1.5]]],
            dtype=torch.float32,
        )
        lewm_images = _make_images(
            batch_size=1,
            obs_horizon=3,
            per_camera_fill={'r_vis': 1.0, 'top': 2.0, 'front': 3.0},
        )
        lewm_qpos = torch.tensor([[[0.1], [0.2], [0.3]]], dtype=torch.float32)
        lewm_future_images = _make_images(
            batch_size=1,
            obs_horizon=1,
            per_camera_fill={'r_vis': 1.0, 'top': 2.0, 'front': 3.0},
        )
        lewm_future_qpos = torch.tensor([[[0.4]]], dtype=torch.float32)
        noise = torch.tensor(
            [[[0.2, -0.4], [0.1, 0.3], [0.5, -0.2]]],
            dtype=torch.float32,
        )
        t_sample = torch.tensor([0.8], dtype=torch.float32)
        r_sample = torch.tensor([0.25], dtype=torch.float32)

        with mock.patch.object(agent_module.torch, 'randn_like', return_value=noise), \
                mock.patch.object(agent_module.torch, 'rand', side_effect=[t_sample, r_sample]):
            loss = agent.compute_loss(
                {
                    'images': images,
                    'qpos': qpos,
                    'action': actions,
                    'lewm_images': lewm_images,
                    'lewm_qpos': lewm_qpos,
                    'lewm_future_images': lewm_future_images,
                    'lewm_future_qpos': lewm_future_qpos,
                }
            )

        metrics = agent.get_last_loss_breakdown()
        self.assertAlmostEqual(loss.item(), metrics['loss'], places=6)
        self.assertEqual(len(head.calls), 2)
        self.assertEqual(head.calls[0]['cond'].shape, (1, 2, 4))
        self.assertEqual(len(future_decoder.calls), 1)
        self.assertEqual(future_decoder.calls[0]['cond'].shape, (1, 3, 4))
        self.assertAlmostEqual(
            metrics['loss'],
            metrics['action_loss'] + metrics['lewm_loss'],
            places=5,
        )
        self.assertAlmostEqual(
            metrics['lewm_loss'],
            metrics['lewm_pred_loss'] + 0.09 * metrics['lewm_sigreg_loss'],
            places=5,
        )
        self.assertGreater(metrics['lewm_pred_loss'], 0.0)
        self.assertAlmostEqual(metrics['lewm_sigreg_loss'], 0.75, places=6)

    def test_select_action_only_regenerates_when_action_queue_is_empty(self):
        agent, _head, _agent_module = self._make_agent(pred_horizon=4, obs_horizon=2, num_action_steps=2)
        observation = {
@@ -851,6 +1199,80 @@ class IMFVLAAgentTest(unittest.TestCase):
        self.assertEqual(agent.vision_encoder.output_dim, 96)
        self.assertEqual(agent.vision_encoder.eval_image_resize_shape, (256, 256))

    def test_hydra_config_instantiates_lewm_resnet_query_imf_attnres_with_future_tokens(self):
        cfg = _compose_cfg(
            overrides=[
                'agent=lewm_resnet_query_imf_attnres',
                'agent.head.n_layer=1',
                'agent.head.n_emb=16',
                'agent.lewm_query_offsets=[8]',
            ]
        )

        self.assertEqual(cfg.agent._target_, 'roboimi.vla.agent_imf.IMFVLAAgent')
        self.assertEqual(
            cfg.agent.vision_backbone._target_,
            'roboimi.vla.models.backbones.lewm_resnet_query_fusion.LeWMMultiViewResNetBackbone',
        )
        self.assertEqual(
            cfg.agent.state_encoder._target_,
            'roboimi.vla.modules.encoders.LeWMStateEncoder',
        )
        self.assertEqual(cfg.agent.head.cond_dim, 288)
        self.assertEqual(cfg.agent.cond_projector.output_dim, 288)
        self.assertEqual(cfg.agent.extra_condition_tokens, 1)
        self.assertEqual(
            cfg.agent.lewm_sigreg._target_,
            'roboimi.vla.models.backbones.lewm_resnet_query_fusion.SIGReg',
        )
        self.assertAlmostEqual(cfg.agent.lewm_sigreg_weight, 0.09)

        with _stub_optional_modules(include_imf_head=True):
            agent = instantiate(cfg.agent)

        self.assertEqual(agent.per_step_cond_dim, 288)
        self.assertEqual(agent.condition_sequence_length, agent.obs_horizon + 1)
        self.assertEqual(agent.noise_pred_net.constructor_kwargs['cond_dim'], 288)
        self.assertEqual(
            agent.noise_pred_net.constructor_kwargs['n_obs_steps'],
            agent.condition_sequence_length,
        )
        self.assertIsNotNone(agent.lewm_sigreg)

    def test_hydra_config_instantiates_lewm_resnet_dual_decoder_imf_attnres(self):
        cfg = _compose_cfg(
            overrides=[
                'agent=lewm_resnet_dual_decoder_imf_attnres',
                'agent.head.n_layer=1',
                'agent.head.n_emb=16',
                'agent.future_decoder.n_layer=1',
                'agent.future_decoder.n_emb=16',
                'agent.lewm_query_offsets=[8]',
            ]
        )

        self.assertEqual(cfg.agent._target_, 'roboimi.vla.agent_imf.IMFVLAAgent')
        self.assertEqual(cfg.agent.extra_condition_tokens, 0)
        self.assertEqual(
            cfg.agent.future_decoder._target_,
            'roboimi.vla.models.heads.imf_transformer1d.IMFTransformer1D',
        )
        self.assertEqual(cfg.agent.head.cond_dim, 288)
        self.assertEqual(cfg.agent.future_decoder.cond_dim, 288)

        with _stub_optional_modules(include_imf_head=True):
            agent = instantiate(cfg.agent)

        self.assertEqual(agent.per_step_cond_dim, 288)
        self.assertEqual(agent.condition_sequence_length, agent.obs_horizon)
        self.assertEqual(agent.noise_pred_net.constructor_kwargs['n_obs_steps'], agent.obs_horizon)
        self.assertEqual(agent.future_decoder.constructor_kwargs['cond_dim'], 288)
        self.assertEqual(
            agent.future_decoder.constructor_kwargs['n_obs_steps'],
            agent.lewm_history_horizon,
        )
        self.assertEqual(agent.future_query_tokens.shape, (1, 1, 288))


    def test_hydra_config_instantiates_resnet_imf_attnres_multitoken_with_sequence_length_three_times_obs_horizon(self):
        cfg = _compose_cfg(
@@ -12,18 +12,21 @@ from roboimi.vla.data.simpe_robot_dataset import SimpleRobotDataset
 
 
 class SimpleRobotDatasetImageLoadingTest(unittest.TestCase):
-    def _write_episode(self, dataset_dir: Path) -> None:
-        episode_path = dataset_dir / "episode_0.hdf5"
+    def _write_episode(self, dataset_dir: Path, episode_idx: int = 0, *, base_value: float = 0.0) -> None:
+        episode_path = dataset_dir / f"episode_{episode_idx}.hdf5"
         with h5py.File(episode_path, "w") as root:
-            root.create_dataset("action", data=np.arange(8, dtype=np.float32).reshape(4, 2))
+            root.create_dataset(
+                "action",
+                data=(np.arange(8, dtype=np.float32).reshape(4, 2) + base_value),
+            )
             root.create_dataset(
                 "observations/qpos",
-                data=np.arange(16, dtype=np.float32).reshape(4, 4),
+                data=(np.arange(16, dtype=np.float32).reshape(4, 4) + base_value),
             )
             root.create_dataset("task", data=np.array([b"sim_transfer"]))
             root.create_dataset(
                 "observations/images/front",
-                data=np.arange(4 * 8 * 8 * 3, dtype=np.uint8).reshape(4, 8, 8, 3),
+                data=((np.arange(4 * 8 * 8 * 3, dtype=np.uint8) + int(base_value)) % 255).reshape(4, 8, 8, 3),
             )
 
     def test_getitem_only_resizes_observation_horizon_images(self):
@@ -79,3 +82,46 @@ class SimpleRobotDatasetImageLoadingTest(unittest.TestCase):

        fake_cv2.resize.assert_not_called()
        self.assertEqual(tuple(sample["observation.front"].shape), (2, 3, 8, 8))

    def test_getitem_can_emit_lewm_history_and_future_observations(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            dataset_dir = Path(tmpdir)
            self._write_episode(dataset_dir)
            dataset = SimpleRobotDataset(
                dataset_dir,
                obs_horizon=2,
                pred_horizon=3,
                camera_names=["front"],
                image_resize_shape=None,
                lewm_history_horizon=3,
                lewm_query_offsets=[1, 2],
            )

            sample = dataset[1]

            self.assertEqual(tuple(sample["lewm.observation.state"].shape), (3, 4))
            self.assertEqual(tuple(sample["lewm.observation.front"].shape), (3, 3, 8, 8))
            self.assertEqual(tuple(sample["lewm.future.state"].shape), (2, 4))
            self.assertEqual(tuple(sample["lewm.future.front"].shape), (2, 3, 8, 8))

    def test_dataset_can_limit_loading_to_specific_episode_indices(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            dataset_dir = Path(tmpdir)
            self._write_episode(dataset_dir, episode_idx=0, base_value=0.0)
            self._write_episode(dataset_dir, episode_idx=1, base_value=100.0)

            dataset = SimpleRobotDataset(
                dataset_dir,
                obs_horizon=2,
                pred_horizon=3,
                camera_names=["front"],
                image_resize_shape=None,
                episode_indices=[1],
            )

            sample = dataset[0]

            self.assertEqual(len(dataset.hdf5_files), 1)
            self.assertEqual(dataset.available_episode_indices, [1])
            self.assertEqual(len(dataset), 4)
            self.assertTrue(np.allclose(sample["observation.state"][0].numpy(), np.array([100.0, 101.0, 102.0, 103.0])))
@@ -158,6 +158,106 @@ class TrainVLARolloutValidationTest(unittest.TestCase):
        self.assertGreater(float(cfg.train.lr), 5e-5)
        self.assertGreater(cfg.train.num_workers, 8)
        self.assertEqual(cfg.train.rollout_val_freq_epochs, 50)
        self.assertEqual(cfg.train.rollout_device, cfg.train.device)
        self.assertIsNone(cfg.train.rollout_num_workers)
        self.assertIsNone(cfg.train.rollout_cuda_devices)

    def test_run_training_rollout_validation_propagates_gpu_parallel_settings(self):
        cfg = OmegaConf.create(
            {
                'train': {
                    'device': 'cpu',
                    'batch_size': 1,
                    'num_workers': 0,
                    'val_split': 0.0,
                    'seed': 0,
                    'lr': 1e-3,
                    'max_steps': 2,
                    'log_freq': 1,
                    'save_freq': 1000,
                    'warmup_steps': 1,
                    'scheduler_type': 'constant',
                    'min_lr': 0.0,
                    'grad_clip': 1.0,
                    'weight_decay': 0.0,
                    'pretrained_ckpt': None,
                    'resume_ckpt': None,
                    'use_swanlab': False,
                    'rollout_val_freq_epochs': 2,
                    'rollout_num_episodes': 5,
                    'rollout_device': 'cuda',
                    'rollout_num_workers': 4,
                    'rollout_cuda_devices': [0, 1],
                    'rollout_response_timeout_s': 123.0,
                    'rollout_server_startup_timeout_s': 456.0,
                },
                'data': {
                    'camera_names': ['front'],
                },
                'agent': {
                    '_target_': 'fake.agent',
                },
                'eval': {
                    'ckpt_path': 'unused.pt',
                    'num_episodes': 99,
                    'max_timesteps': 1,
                    'device': 'cpu',
                    'task_name': 'sim_transfer',
                    'camera_names': ['front'],
                    'use_smoothing': False,
                    'smooth_alpha': 0.3,
                    'verbose_action': False,
                    'headless': False,
                },
            }
        )
        rollout_mock = mock.Mock(return_value={'avg_reward': 1.0})

        def fake_instantiate(config_node, **_kwargs):
            if config_node is cfg.data:
                return _FakeDataset()
            if config_node is cfg.agent:
                return _FakeAgent()
            raise AssertionError(f'unexpected instantiate config: {config_node!r}')

        def fake_dataloader(_dataset, *, shuffle, **_kwargs):
            del shuffle, _kwargs
            return _FakeLoader(
                {
                    'observation.front': torch.zeros(1, 3, 2, 2),
                    'observation.state': torch.zeros(1, 4),
                    'action': torch.zeros(1, 2),
                    'action_is_pad': torch.zeros(1, 1, dtype=torch.bool),
                },
                length=1,
            )

        with tempfile.TemporaryDirectory() as tempdir:
            previous_cwd = os.getcwd()
            try:
                os.chdir(tempdir)
                with mock.patch.object(train_vla, 'instantiate', side_effect=fake_instantiate), \
                        mock.patch.object(train_vla, 'DataLoader', side_effect=fake_dataloader), \
                        mock.patch.object(train_vla, 'build_training_optimizer', return_value=_FakeOptimizer(cfg.train.lr)), \
                        mock.patch.object(train_vla, 'get_lr_schedule_with_warmup', return_value=_FakeScheduler()), \
                        mock.patch.object(train_vla, 'tqdm', side_effect=lambda iterable, **kwargs: _FakeProgressBar(iterable)), \
                        mock.patch.object(train_vla.torch, 'save', return_value=None), \
                        mock.patch.object(eval_vla, '_run_eval', rollout_mock, create=True):
                    train_vla._run_training(cfg)
            finally:
                os.chdir(previous_cwd)

        rollout_cfg = rollout_mock.call_args.args[0]
        self.assertEqual(rollout_cfg.eval.device, 'cuda')
        self.assertEqual(rollout_cfg.eval.num_workers, 4)
        self.assertEqual(list(rollout_cfg.eval.cuda_devices), [0, 1])
        self.assertEqual(float(rollout_cfg.eval.response_timeout_s), 123.0)
        self.assertEqual(float(rollout_cfg.eval.server_startup_timeout_s), 456.0)
        self.assertTrue(rollout_cfg.eval.headless)
        self.assertEqual(rollout_cfg.eval.num_episodes, 5)
        self.assertFalse(rollout_cfg.eval.record_video)
        self.assertTrue(rollout_cfg.eval.save_summary_json)
        self.assertTrue(rollout_cfg.eval.save_trajectory_image)

    def test_training_passes_backbone_image_resize_override_to_dataset_instantiation(self):
        cfg = OmegaConf.create(
@@ -41,6 +41,19 @@ class FakeDataset:
        return 4


class SplitAwareFakeDataset(FakeDataset):
    def __init__(self, episode_indices=None):
        self.episode_indices = None if episode_indices is None else list(episode_indices)
        if self.episode_indices is None:
            self.episodes = {0: [0], 1: [1], 2: [2]}
        else:
            self.episodes = {idx: [idx] for idx in self.episode_indices}

    @property
    def available_episode_indices(self):
        return sorted(self.episodes.keys())


class FakeLoader:
    def __init__(self, batch):
        self.batch = batch
@@ -114,6 +127,26 @@ class FakeAgent(nn.Module):
        return {}


class RecordingAgent(FakeAgent):
    def __init__(self):
        super().__init__()
        self.seen_inputs = []

    def compute_loss(self, agent_input):
        self.seen_inputs.append(agent_input)
        return super().compute_loss(agent_input)

    def predict_action_chunk(self, agent_input):
        self.seen_inputs.append({'predict_action_chunk': agent_input})
        return torch.ones_like(agent_input['action'])


class ShapeMixedFakeAgent(FakeAgent):
    def __init__(self):
        super().__init__()
        self.bias = nn.Parameter(torch.zeros(2))


class FakeSwanLab:
    def __init__(self, init_error=None, log_errors=None, finish_error=None, image_errors=None):
        self.init_error = init_error
@@ -339,6 +372,8 @@ class TrainVLASwanLabLoggingTest(unittest.TestCase):
            batch_size=2,
            num_workers=0,
            val_split=0.25,
            val_episode_indices=None,
            action_mse_val_freq_epochs=0,
            seed=0,
            lr=1e-3,
            max_steps=2,
@@ -388,6 +423,18 @@ class TrainVLASwanLabLoggingTest(unittest.TestCase):
            'action_is_pad': torch.zeros(1, 1, dtype=torch.bool),
        }

    def _make_lewm_batch(self):
        batch = self._make_batch()
        batch.update(
            {
                'lewm.observation.front': torch.ones(1, 3, 2, 2),
                'lewm.observation.state': torch.ones(1, 4),
                'lewm.future.front': torch.full((1, 3, 2, 2), 2.0),
                'lewm.future.state': torch.full((1, 4), 2.0),
            }
        )
        return batch

    def _loader_factory(self):
        train_batch = self._make_batch()
        val_batch = self._make_batch()
@@ -397,6 +444,15 @@ class TrainVLASwanLabLoggingTest(unittest.TestCase):

        return factory

    def _lewm_loader_factory(self):
        train_batch = self._make_lewm_batch()
        val_batch = self._make_lewm_batch()

        def factory(_dataset, *, shuffle, **_kwargs):
            return FakeLoader(train_batch if shuffle else val_batch)

        return factory

    def test_run_training_logs_metrics_and_checkpoint_paths_to_swanlab(self):
        module = self._load_train_vla_module()
        run_training = self._get_run_training(module)
@@ -442,6 +498,8 @@ class TrainVLASwanLabLoggingTest(unittest.TestCase):
            'batch_size': 2,
            'num_workers': 0,
            'val_split': 0.25,
            'val_episode_indices': None,
            'action_mse_val_freq_epochs': 0,
            'seed': 0,
            'lr': 1e-3,
            'max_steps': 2,
@@ -487,6 +545,95 @@ class TrainVLASwanLabLoggingTest(unittest.TestCase):
        self.assertTrue(final_payload['final/best_checkpoint_path'].endswith('checkpoints/vla_model_best.pt'))
        self.assertEqual(fake_swanlab.finish_calls, 1)

    def test_run_training_passes_lewm_history_and_future_batches_into_agent_input(self):
        module = self._load_train_vla_module()
        run_training = self._get_run_training(module)
        cfg = self._make_cfg(use_swanlab=False)
        cfg.train.max_steps = 1
        cfg.train.save_freq = 100
        agent = RecordingAgent()

        def fake_instantiate(config_node, **_kwargs):
            if config_node is cfg.data:
                return FakeDataset()
            if config_node is cfg.agent:
                return agent
            raise AssertionError(f'unexpected instantiate config: {config_node!r}')

        with tempfile.TemporaryDirectory() as tempdir:
            previous_cwd = os.getcwd()
            try:
                os.chdir(tempdir)
                with mock.patch.object(module, 'instantiate', side_effect=fake_instantiate), \
                        mock.patch.object(module, 'DataLoader', side_effect=self._lewm_loader_factory()), \
                        mock.patch.object(module, 'get_lr_schedule_with_warmup', return_value=FakeScheduler()), \
                        mock.patch.object(module, 'tqdm', side_effect=lambda iterable, **kwargs: FakeProgressBar(iterable)), \
                        mock.patch.object(module.torch, 'save', return_value=None):
                    run_training(cfg)
            finally:
                os.chdir(previous_cwd)

        self.assertGreaterEqual(len(agent.seen_inputs), 1)
        first_input = agent.seen_inputs[0]
        self.assertIn('lewm_images', first_input)
        self.assertIn('lewm_qpos', first_input)
        self.assertIn('lewm_future_images', first_input)
        self.assertIn('lewm_future_qpos', first_input)
        self.assertIn('front', first_input['lewm_images'])
        self.assertIn('front', first_input['lewm_future_images'])

    def test_run_training_logs_epoch_action_mse_for_held_out_val_episode(self):
        module = self._load_train_vla_module()
        run_training = self._get_run_training(module)
        cfg = self._make_cfg()
        cfg.train.max_steps = 1
        cfg.train.save_freq = 100
        cfg.train.val_split = 0.0
        cfg.train.val_episode_indices = [2]
        cfg.train.action_mse_val_freq_epochs = 1
        agent = RecordingAgent()
        fake_swanlab = FakeSwanLab()
        real_import_module = importlib.import_module

        def fake_instantiate(config_node, **kwargs):
            if config_node is cfg.data:
                return SplitAwareFakeDataset(kwargs.get('episode_indices'))
            if config_node is cfg.agent:
                return agent
            raise AssertionError(f'unexpected instantiate config: {config_node!r}')

        def fake_loader_factory(dataset, *, shuffle, **_kwargs):
            action_value = 0.0 if shuffle else 2.0
            batch = {
                'observation.front': torch.zeros(1, 3, 2, 2),
                'observation.state': torch.zeros(1, 4),
                'action': torch.full((1, 1, 2), action_value),
                'action_is_pad': torch.zeros(1, 1, dtype=torch.bool),
            }
            return FakeLoader(batch)

        def fake_import_module(name, package=None):
            if name == 'swanlab':
                return fake_swanlab
            return real_import_module(name, package)

        with tempfile.TemporaryDirectory() as tempdir:
            previous_cwd = os.getcwd()
            try:
                os.chdir(tempdir)
                with mock.patch.object(module, 'instantiate', side_effect=fake_instantiate), \
                        mock.patch.object(module, 'DataLoader', side_effect=fake_loader_factory), \
                        mock.patch.object(module, 'get_lr_schedule_with_warmup', return_value=FakeScheduler()), \
                        mock.patch.object(module, 'tqdm', side_effect=lambda iterable, **kwargs: FakeProgressBar(iterable)), \
                        mock.patch.object(module.torch, 'save', return_value=None), \
                        mock.patch.object(module.importlib, 'import_module', side_effect=fake_import_module):
                    run_training(cfg)
            finally:
                os.chdir(previous_cwd)

        logged_keys = set().union(*(payload.keys() for payload, _ in fake_swanlab.log_calls))
        self.assertIn('val/action_mse', logged_keys)

    def test_run_training_skips_swanlab_when_disabled(self):
        module = self._load_train_vla_module()
        run_training = self._get_run_training(module)
@@ -668,6 +815,52 @@ class TrainVLASwanLabLoggingTest(unittest.TestCase):
        self.assertTrue(final_payload['final/best_checkpoint_path'].endswith('checkpoints/vla_model_best.pt'))
        self.assertFalse(any(path.endswith('checkpoints/vla_model_best.pt') for path in saved_paths))

    def test_run_training_pretrained_ckpt_loads_matching_keys_even_if_some_shapes_mismatch(self):
        module = self._load_train_vla_module()
        run_training = self._get_run_training(module)
        cfg = self._make_cfg(use_swanlab=False)
        cfg.train.max_steps = 0
        cfg.train.save_freq = 100
        cfg.train.pretrained_ckpt = 'pretrained.pt'
        agent = ShapeMixedFakeAgent()

        def fake_instantiate(config_node, **_kwargs):
            if config_node is cfg.data:
                return FakeDataset()
            if config_node is cfg.agent:
                return agent
            raise AssertionError(f'unexpected instantiate config: {config_node!r}')

        def fake_torch_load(path, map_location=None):
            del map_location
            if Path(path).name != 'pretrained.pt':
                raise AssertionError(f'unexpected load path: {path}')
            return {
                'model_state_dict': {
                    'weight': torch.tensor(3.0),
                    'bias': torch.tensor([1.0, 2.0, 3.0]),
                },
                'step': 123,
                'loss': 0.5,
            }

        with tempfile.TemporaryDirectory() as tempdir:
            previous_cwd = os.getcwd()
            try:
                os.chdir(tempdir)
                Path('pretrained.pt').write_bytes(b'pretend')
                with mock.patch.object(module, 'instantiate', side_effect=fake_instantiate), \
                        mock.patch.object(module, 'DataLoader', side_effect=self._loader_factory()), \
                        mock.patch.object(module, 'get_lr_schedule_with_warmup', return_value=FakeScheduler()), \
                        mock.patch.object(module, 'tqdm', side_effect=lambda iterable, **kwargs: FakeProgressBar(iterable)), \
                        mock.patch.object(module.torch, 'save', return_value=None), \
                        mock.patch.object(module.torch, 'load', side_effect=fake_torch_load):
                    run_training(cfg)
            finally:
                os.chdir(previous_cwd)

        self.assertAlmostEqual(agent.weight.item(), 3.0, places=6)

    def test_run_training_resume_ignores_best_checkpoint_without_rollout_metric(self):
        module = self._load_train_vla_module()
        run_training = self._get_run_training(module)