5 Commits

23 changed files with 4853 additions and 183 deletions

View File

@@ -0,0 +1,471 @@
# feat-lewm-imf-fusion Experiment Playbook
Applies to worktree: `/home/droid/project/roboimi/.worktrees/feat-lewm-imf-fusion`
## 0. Start with the current go-to recipe
The most frequently used training/validation recipe on this branch; refer to it directly:
`experiment_suites/2026-04-21-lewm-fromscratch-old9-epoch50-roll5-val-20260421-153037/`
Core conventions:
- agent: `lewm_resnet_query_imf_attnres`
- from scratch: `train.pretrained_ckpt=null` + `agent.lewm_pretrained_ckpt=null`
- training: `batch_size=32`, `lr=1e-4`, `max_steps=109350`, `save_freq=10000`
- numeric validation: `train.val_split=0.0` + `train.val_episode_indices=[100]`
- held-out numeric validation: `train.action_mse_val_freq_epochs=1`
- rollout validation: `train.rollout_val_freq_epochs=5`, `train.rollout_num_episodes=10`
- SwanLab: `train.use_swanlab=true`, project = `roboimi-vla`
---
## 1. Branch layout and key files
| Path | Purpose |
| --- | --- |
| `roboimi/demos/vla_scripts/train_vla.py` | Main training entry point; handles datasets, checkpoints, numeric validation, in-training rollout validation, and SwanLab |
| `roboimi/demos/vla_scripts/eval_vla.py` | One-shot rollout / offline validation entry point; supports headless mode, summaries, and trajectory image/video artifacts |
| `roboimi/vla/conf/config.yaml` | Global Hydra config; all training defaults live here |
| `roboimi/vla/conf/eval/eval.yaml` | Default eval config; `eval.ckpt_path`, `eval.num_episodes`, and the artifact switches live here |
| `roboimi/vla/conf/agent/lewm_resnet_query_imf_attnres.yaml` | Most-used agent on this branch (LeWM query fusion + IMF AttnRes head) |
| `roboimi/vla/conf/backbone/lewm_resnet_query_fusion.yaml` | Config for the LeWM multi-view ResNet query fusion backbone |
| `roboimi/vla/agent_imf.py` | `IMFVLAAgent` implementation: one-step IMF inference, LeWM loss, loading of LeWM pretrained components |
| `roboimi/vla/data/simpe_robot_dataset.py` | Lazy-loading HDF5 dataset; also handles `episode_indices` filtering |
| `roboimi/vla/scripts/calculate_stats.py` | Recomputes `dataset_stats.pkl` |
| `experiment_suites/2026-04-21-lewm-fromscratch-old9-epoch50-roll5-val-20260421-153037/` | Current go-to suite; manifest, notes, launch logs, and the local launch script all live here |
Notes:
- Run names on this branch typically look like `lewmimf-q08-ph08-ex08-emb384-l12-fromscratch-epoch50-step109350-5090g0-20260421-153037`
- Suffixes like `q08/ph16/ex08` map to `agent.lewm_query_offsets`, `agent.pred_horizon`, and `agent.num_action_steps` respectively
---
## 2. The three machines and their environments
| Machine | GPU | repo / worktree | Python | Common dataset path |
| --- | --- | --- | --- | --- |
| Local `droid-z790eagleax` | 1× RTX 5090 32GB | `/home/droid/project/roboimi/.worktrees/feat-lewm-imf-fusion` | `/home/droid/.conda/envs/roboimi/bin/python` | `/home/droid/project/diana_sim/sim_transfer` |
| 5880 node `100.73.14.65` | 2× RTX 5880 Ada 48GB | `/home/droid/roboimi_suite_20260416_lewm_imf_fusion` | `/home/droid/miniforge3/envs/roboimi/bin/python` | `/home/droid/sim_dataset/sim_transfer` |
| L20 node `100.119.99.14` | 8× NVIDIA L20 46GB | `/data/roboimi_suite_20260416_lewm_imf_fusion` | `/home/droid/miniforge3/envs/roboimi/bin/python` | `/data/simtransfer/current` |
Connections:
- 5880: `ssh droid@100.73.14.65`
- L20: `ssh droid@100.119.99.14`
Rules of thumb:
- Local 5090: good for single smoke runs, small-scale main runs, and local tuning
- 5880: good for 2 parallel main runs
- L20: good for large grids; keep both data and runs under `/data`
---
## 3. How the training flow works
The actual flow of `train_vla.py`:
1. Load the Hydra config and print the full cfg
2. Build the train/val datasets via `build_train_val_datasets()`
3. Build the train/val loaders with `DataLoader`
4. Read the normalization stats from `dataset_dir/dataset_stats.pkl`
5. Instantiate `IMFVLAAgent`
6. Optionally load:
   - `train.pretrained_ckpt`
   - `train.resume_ckpt`
   - `agent.lewm_pretrained_ckpt`
7. Inside the training loop, log train loss / lr every `log_freq` steps
8. Save `checkpoints/vla_model_step_*.pt` every `save_freq` steps
9. At the end of each epoch, run as configured:
   - held-out action MSE
   - rollout validation
10. Finally write:
    - `checkpoints/vla_model_best.pt`
    - `checkpoints/vla_model_final.pt`
Current best-model selection logic (sketched below):
- **Before the first rollout reward arrives**: pick best by `val_loss` (falling back to train loss)
- **After the first rollout**: prefer `rollout_avg_reward`
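A minimal sketch of that selection rule, as a hypothetical standalone helper (the real logic is inlined in `train_vla.py`'s training loop):
```python
from typing import Optional

def is_new_best(candidate_reward: Optional[float], best_reward: Optional[float],
                candidate_loss: float, best_loss: float) -> bool:
    """Prefer rollout reward once any reward exists; otherwise compare losses."""
    if candidate_reward is not None or best_reward is not None:
        # After the first rollout: higher average reward wins.
        if candidate_reward is None:
            return False
        return best_reward is None or candidate_reward > best_reward
    # Before any rollout reward: lower val loss (or train-loss fallback) wins.
    return candidate_loss < best_loss
```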
The output directory is usually pinned via `hydra.run.dir=...`; otherwise Hydra generates one itself.
---
## 4. How the validation flow works
### 4.1 Held-out numeric validation
The current practice is not a random `val_split` cut, but:
- `train.val_split=0.0`
- `train.val_episode_indices=[100]`
- `train.action_mse_val_freq_epochs=1`
With this, `compute_action_mse_validation()` runs on `episode_100.hdf5` at the end of every epoch. The log keys are:
- console / `train_vla.log`: `held-out action MSE`
- SwanLab: `val/action_mse`
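The MSE is masked over padded action steps; a condensed sketch following `compute_action_mse_validation()` from this diff:
```python
import torch

@torch.no_grad()
def masked_action_mse(pred: torch.Tensor, target: torch.Tensor,
                      action_is_pad: torch.Tensor | None) -> float:
    """MSE over valid (non-padded) action steps; pred/target are (B, T, action_dim)."""
    squared_error = (pred - target).pow(2)
    if action_is_pad is None:
        return squared_error.mean().item()
    mask = (~action_is_pad).unsqueeze(-1).to(squared_error.dtype)  # (B, T, 1)
    total = (squared_error * mask).sum()
    count = mask.sum() * squared_error.shape[-1]
    return (total / count.clamp_min(1.0)).item()
```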
### 4.2 Rollout validation
In-training rollout validation is triggered via `train_vla.py -> run_rollout_validation() -> eval_vla._run_eval()`.
The usual in-training rollout constraints on this branch are:
- `train.rollout_val_freq_epochs=5`
- `train.rollout_num_episodes=10`
- `train.rollout_validate_on_checkpoint=false`
- forced headless
- forced `verbose_action=false`
- forced `record_video=false`
- forced `save_trajectory_image=true`
- forced `trajectory_image_camera_name=front`
- forced `save_summary_json=true`
This has since been fixed to a **config-driven rollout device / worker path**:
- `train.rollout_device`: defaults to following `train.device`
- `train.rollout_num_workers`: defaults to `null`
  - when the rollout device is CPU, it degrades to `1`
  - when the rollout device is CUDA, it is inferred as `min(train.rollout_num_episodes, 8)`
- `train.rollout_cuda_devices`: defaults to `null`, equivalent to the currently visible logical GPU `[0]`
- `train.rollout_response_timeout_s`
- `train.rollout_server_startup_timeout_s`
So now:
- when training runs on `cuda`, **in-training rollouts default to the GPU**
- if `rollout_num_workers > 1`, rollout automatically runs in parallel
  - either as **multiple workers sharing one inference server on a single GPU**
  - or as **multiple servers across multiple GPUs splitting the workers**
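The worker-count inference follows this rule (a sketch matching the `rollout_num_workers` resolution in `run_rollout_validation` later in this diff; the function name here is illustrative):
```python
def resolve_rollout_num_workers(configured: int | None, rollout_device: str,
                                rollout_num_episodes: int) -> int:
    """null -> 1 on CPU, min(episodes, 8) on CUDA; an explicit value always wins."""
    if configured is not None:
        return int(configured)
    if rollout_device.startswith("cuda"):
        return min(max(rollout_num_episodes, 1), 8)
    return 1
```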
In-training rollout artifacts land by default in:
`<hydra.run.dir>/rollout_artifacts/<checkpoint_stem>/`
Typical files:
- `rollout_summary.json`
- `rollout_front_ep01_trajectory.png` ... `rollout_front_ep10_trajectory.png`
Key log lines to watch:
- `Epoch X rollout 平均奖励` (average rollout reward for the epoch)
- `最佳模型已更新` (best model updated)
---
## 5. Dataset loading and the `val_episode_indices` mechanism
### 5.1 Dataset format
`SimpleRobotDataset` reads the `episode_*.hdf5` files under `dataset_dir`; each episode file must contain at least:
- `action`
- `observations/qpos`
- `observations/images/{cam_name}`
Commonly used cameras:
- `r_vis`
- `top`
- `front`
### 5.2 Lazy-loading behavior
`roboimi/vla/data/simpe_robot_dataset.py` lazy-loads frame by frame; it never reads the whole HDF5 set into memory at once.
It will:
- scan the HDF5 files in the directory
- build `available_episode_indices` from the episode number in each filename (e.g. `episode_100.hdf5` -> `100`)
- keep an LRU cache of HDF5 file handles inside each worker
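A sketch of the filename scan, assuming the `episode_*.hdf5` naming shown above (the per-worker LRU handle cache is omitted):
```python
import re
from pathlib import Path

_EPISODE_PAT = re.compile(r"episode_(\d+)\.hdf5$")

def scan_available_episode_indices(dataset_dir: str) -> list[int]:
    """Map episode_100.hdf5 -> 100 for every matching HDF5 file in the directory."""
    indices = []
    for path in sorted(Path(dataset_dir).glob("episode_*.hdf5")):
        match = _EPISODE_PAT.search(path.name)
        if match:
            indices.append(int(match.group(1)))
    return sorted(indices)
```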
### 5.3 How `val_episode_indices` splits
The logic of `build_train_val_datasets()`:
1. Instantiate the full dataset once
2. Read `dataset.available_episode_indices`
3. Check that every index in `train.val_episode_indices` exists
4. Instantiate twice more with `episode_indices=`:
   - train dataset = all episodes minus the held-out episodes
   - val dataset = only the held-out episodes
Therefore:
- `train.val_episode_indices=[100]` means "hold out the whole of `episode_100.hdf5` as the val set"
- if an episode does not exist, it raises immediately
- if you put every episode into `val_episode_indices`, it also raises, because the training set would be empty
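The split itself is plain set arithmetic over episode indices; a condensed view of the checks in `build_train_val_datasets()`:
```python
def split_episode_indices(available: list[int], held_out: list[int]) -> list[int]:
    """Return the train indices; raise on missing episodes or an empty train set."""
    missing = sorted(set(held_out) - set(available))
    if missing:
        raise ValueError(f"val_episode_indices {missing} not in the dataset")
    train = [idx for idx in available if idx not in set(held_out)]
    if not train:
        raise ValueError("val_episode_indices covers all episodes; train set would be empty")
    return train
```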
### 5.4 Image resize and the extra LeWM fields
The dataset-side resize comes by default from:
- `data.image_resize_shape`
- if the backbone overrides it, `agent.vision_backbone.dataset_image_resize_shape` takes precedence
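That precedence as a small sketch (mirroring the resolution at the top of `_run_training` in this diff):
```python
def resolve_resize_shape(cfg) -> tuple:
    """The backbone-level override wins over the data-level default."""
    backbone_cfg = cfg.agent.get("vision_backbone", None)
    if backbone_cfg is not None and "dataset_image_resize_shape" in backbone_cfg:
        return tuple(backbone_cfg["dataset_image_resize_shape"])
    return tuple(cfg.data.image_resize_shape)
```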
Besides the regular keys:
- `observation.state`
- `observation.<cam>`
- `action`
the returned batch also carries, when LeWM is enabled:
- `lewm.observation.state`
- `lewm.observation.<cam>`
- `lewm.future.state`
- `lewm.future.<cam>`
### 5.5 Stats file
Both training and inference rely on `dataset_stats.pkl` by default. Recompute it after a dataset update:
```bash
/home/droid/.conda/envs/roboimi/bin/python roboimi/vla/scripts/calculate_stats.py \
--dataset_dir /home/droid/project/diana_sim/sim_transfer
```
On the remote nodes, just point `--dataset_dir` at the corresponding host path.
---
## 6. SwanLab behavior
The config default is `train.use_swanlab=false`, but the common recipes on this branch explicitly turn it on:
- `train.use_swanlab=true`
- `train.swanlab_project=roboimi-vla`
- `train.swanlab_run_name=<run_name>`
SwanLab behavior in `train_vla.py`:
- at init, upload the `train` / `data` / `agent` config sections
- during training, record:
  - `train/loss`
  - `train/lr`
  - `train/best_loss`
  - `train/step`
- on checkpoint validation, record:
  - `val/loss`
- on held-out numeric validation, record:
  - `val/action_mse`
- on rollout validation, record:
  - `rollout/avg_reward`
  - `rollout/epoch`
- at the end of training, record:
  - `final/checkpoint_path`
  - `final/best_checkpoint_path`
The front-view trajectory PNGs produced by in-training rollouts are uploaded to SwanLab best-effort; a failure only logs a warning and never interrupts training.
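The best-effort upload amounts to a guarded log call like this sketch (function name and metric key are illustrative, and `swanlab.Image` usage is an assumption; the point is warn-and-continue, never raise):
```python
import logging

log = logging.getLogger(__name__)

def upload_trajectory_image_best_effort(swanlab_module, image_path: str, step: int) -> None:
    """Try to log a rollout trajectory PNG; a failure must never stop training."""
    try:
        swanlab_module.log(
            {"rollout/trajectory_image": swanlab_module.Image(image_path)},
            step=step,
        )
    except Exception as exc:  # intentionally broad: the upload is best-effort
        log.warning("SwanLab trajectory image upload failed: %s", exc)
```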
---
## 7. Parallel rollout notes
### 7.1 Where this capability comes from
Parallel rollout on this branch is not DataLoader parallelism; it is the **multiprocess rollout path of `eval_vla.py`**.
Reference source:
`/home/droid/project/roboimi/.worktrees/multiprocess-rollout/roboimi/demos/vla_scripts/eval_vla.py`
That path is controlled by:
- `eval.num_workers`
- `eval.cuda_devices`
Semantics:
- `eval.num_workers`: number of environment workers; episodes are split across them
- `eval.cuda_devices`: which logical GPUs the inference servers bind to
### 7.2 Two common modes
1. **Single machine, single GPU: multiple workers sharing one GPU**
   - Typical: the local 5090 has only one GPU, but you want 4 rollout workers running environments in parallel
   - Form: `eval.device=cuda eval.num_workers=4 'eval.cuda_devices=[0]'`
   - This gives **1 CUDA inference server + 4 env workers**
2. **Single machine, multiple GPUs: multiple servers splitting the workers**
   - Typical: the 5880 has 2 GPUs, the L20 has many
   - Form: `eval.device=cuda eval.num_workers=8 'eval.cuda_devices=[0,1]'`
   - Workers are assigned round-robin across the servers (see the sketch below)
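Round-robin assignment in sketch form (worker `i` queries the server on `cuda_devices[i % len(cuda_devices)]`; the real mapping lives inside `eval_vla.py`):
```python
def assign_workers_round_robin(num_workers: int, cuda_devices: list[int]) -> dict[int, int]:
    """Map each env worker index to the logical GPU of the server it will query."""
    return {worker: cuda_devices[worker % len(cuda_devices)] for worker in range(num_workers)}

# Example: 8 workers over GPUs [0, 1] -> workers 0,2,4,6 hit GPU 0; 1,3,5,7 hit GPU 1.
assert assign_workers_round_robin(8, [0, 1])[3] == 1
```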
### 7.3 Operational caveats
- Parallel rollout relies on the **multiprocess eval path**, not `train.num_workers`
- `train.num_workers` is the DataLoader worker count; it has nothing to do with rollout parallelism
- `eval.num_workers > 1` requires `eval.headless=true`
- the worker count is automatically capped at `eval.num_episodes`
- multiprocess rollout now supports **per-episode trajectory image PNGs**; with multiple workers, each worker writes images into its own artifact subdirectory and the summary carries the corresponding paths back
- but with multiple workers, still do not request all of these at the same time:
  - `eval.record_video=true`
  - `eval.save_trajectory=true`
  - `eval.save_trajectory_npz=true`
- `eval.save_trajectory_image=true` is now safe to enable, and suits combining parallel reward measurement with qualitative checks
### 7.4 Parallel rollout command templates
**5090, single GPU, 4 workers**
```bash
/home/droid/.conda/envs/roboimi/bin/python roboimi/demos/vla_scripts/eval_vla.py \
agent=lewm_resnet_query_imf_attnres \
data.dataset_dir=/home/droid/project/diana_sim/sim_transfer \
train.device=cuda eval.device=cuda eval.headless=true eval.verbose_action=false \
eval.ckpt_path=/home/droid/project/roboimi/.worktrees/feat-lewm-imf-fusion/runs/<run_name>/checkpoints/vla_model_best.pt \
eval.num_episodes=10 eval.num_workers=4 'eval.cuda_devices=[0]' \
eval.save_summary_json=true eval.artifact_dir=/tmp/lewm_parallel_eval_5090
```
**5880, dual GPU, 8 workers**
```bash
/home/droid/miniforge3/envs/roboimi/bin/python roboimi/demos/vla_scripts/eval_vla.py \
agent=lewm_resnet_query_imf_attnres \
data.dataset_dir=/home/droid/sim_dataset/sim_transfer \
train.device=cuda eval.device=cuda eval.headless=true eval.verbose_action=false \
eval.ckpt_path=/home/droid/roboimi_suite_20260416_lewm_imf_fusion/runs/<run_name>/checkpoints/vla_model_best.pt \
eval.num_episodes=10 eval.num_workers=8 'eval.cuda_devices=[0,1]' \
eval.save_summary_json=true eval.artifact_dir=/tmp/lewm_parallel_eval_5880
```
---
## 8. Current common commands / scripts
### 8.1 Local 5090: use the suite script directly
Ready-made script:
`experiment_suites/2026-04-21-lewm-fromscratch-old9-epoch50-roll5-val-20260421-153037/launch_local_5090.sh`
Run:
```bash
bash experiment_suites/2026-04-21-lewm-fromscratch-old9-epoch50-roll5-val-20260421-153037/launch_local_5090.sh
```
### 8.2 Local 5090: launch the same recipe manually
```bash
/home/droid/.conda/envs/roboimi/bin/python roboimi/demos/vla_scripts/train_vla.py \
agent=lewm_resnet_query_imf_attnres \
data.dataset_dir=/home/droid/project/diana_sim/sim_transfer \
'agent.lewm_query_offsets=[8]' \
agent.pred_horizon=8 \
agent.num_action_steps=8 \
train.device=cuda \
train.batch_size=32 \
train.lr=0.0001 \
train.max_steps=109350 \
train.num_workers=4 \
train.save_freq=10000 \
train.rollout_validate_on_checkpoint=false \
train.rollout_val_freq_epochs=5 \
train.rollout_num_episodes=10 \
train.val_split=0.0 \
'train.val_episode_indices=[100]' \
train.action_mse_val_freq_epochs=1 \
train.use_swanlab=true \
train.swanlab_project=roboimi-vla \
train.swanlab_run_name=lewmimf-q08-ph08-ex08-emb384-l12-fromscratch-epoch50-step109350-5090g0-20260421-153037 \
train.pretrained_ckpt=null \
agent.lewm_pretrained_ckpt=null \
hydra.run.dir=/home/droid/project/roboimi/.worktrees/feat-lewm-imf-fusion/runs/lewmimf-q08-ph08-ex08-emb384-l12-fromscratch-epoch50-step109350-5090g0-20260421-153037
```
### 8.3 5880: common command template
```bash
ssh droid@100.73.14.65
cd /home/droid/roboimi_suite_20260416_lewm_imf_fusion
/home/droid/miniforge3/envs/roboimi/bin/python roboimi/demos/vla_scripts/train_vla.py \
agent=lewm_resnet_query_imf_attnres \
data.dataset_dir=/home/droid/sim_dataset/sim_transfer \
'agent.lewm_query_offsets=[8]' \
agent.pred_horizon=16 \
agent.num_action_steps=8 \
train.device=cuda train.batch_size=32 train.lr=0.0001 train.max_steps=109350 \
train.num_workers=4 train.save_freq=10000 train.rollout_validate_on_checkpoint=false \
train.rollout_val_freq_epochs=5 train.rollout_num_episodes=10 train.val_split=0.0 \
'train.val_episode_indices=[100]' train.action_mse_val_freq_epochs=1 \
train.use_swanlab=true train.swanlab_project=roboimi-vla \
train.swanlab_run_name=lewmimf-q08-ph16-ex08-emb384-l12-fromscratch-epoch50-step109350-5880g0-20260421-153037 \
train.pretrained_ckpt=null agent.lewm_pretrained_ckpt=null \
hydra.run.dir=/home/droid/roboimi_suite_20260416_lewm_imf_fusion/runs/lewmimf-q08-ph16-ex08-emb384-l12-fromscratch-epoch50-step109350-5880g0-20260421-153037
```
### 8.4 L20: common command template
```bash
ssh droid@100.119.99.14
cd /data/roboimi_suite_20260416_lewm_imf_fusion
/home/droid/miniforge3/envs/roboimi/bin/python roboimi/demos/vla_scripts/train_vla.py \
agent=lewm_resnet_query_imf_attnres \
data.dataset_dir=/data/simtransfer/current \
'agent.lewm_query_offsets=[16]' \
agent.pred_horizon=16 \
agent.num_action_steps=16 \
train.device=cuda train.batch_size=32 train.lr=0.0001 train.max_steps=109350 \
train.num_workers=4 train.save_freq=10000 train.rollout_validate_on_checkpoint=false \
train.rollout_val_freq_epochs=5 train.rollout_num_episodes=10 train.val_split=0.0 \
'train.val_episode_indices=[100]' train.action_mse_val_freq_epochs=1 \
train.use_swanlab=true train.swanlab_project=roboimi-vla \
train.swanlab_run_name=lewmimf-q16-ph16-ex16-emb384-l12-fromscratch-epoch50-step109350-l20g0-20260421-153037 \
train.pretrained_ckpt=null agent.lewm_pretrained_ckpt=null \
hydra.run.dir=/data/roboimi_suite_20260416_lewm_imf_fusion/runs/lewmimf-q16-ph16-ex16-emb384-l12-fromscratch-epoch50-step109350-l20g0-20260421-153037
```
### 8.5 One-shot offline validation (this branch now supports parallelism)
**Single GPU / 4 workers**
```bash
/home/droid/.conda/envs/roboimi/bin/python roboimi/demos/vla_scripts/eval_vla.py \
agent=lewm_resnet_query_imf_attnres \
data.dataset_dir=/home/droid/project/diana_sim/sim_transfer \
train.device=cuda eval.device=cuda \
eval.ckpt_path=/home/droid/project/roboimi/.worktrees/feat-lewm-imf-fusion/runs/<run_name>/checkpoints/vla_model_best.pt \
eval.num_episodes=10 eval.num_workers=4 'eval.cuda_devices=[0]' \
eval.headless=true eval.verbose_action=false \
eval.save_summary_json=true eval.save_trajectory_image=true \
eval.trajectory_image_camera_name=front \
eval.artifact_dir=/tmp/lewm_eval_front
```
**Enable parallel GPU rollout during training (recommended: spell it out explicitly)**
```bash
/home/droid/.conda/envs/roboimi/bin/python roboimi/demos/vla_scripts/train_vla.py \
agent=lewm_resnet_query_imf_attnres \
data.dataset_dir=/home/droid/project/diana_sim/sim_transfer \
'agent.lewm_query_offsets=[8]' \
agent.pred_horizon=8 \
agent.num_action_steps=8 \
train.device=cuda \
train.batch_size=32 \
train.lr=0.0001 \
train.max_steps=109350 \
train.num_workers=4 \
train.save_freq=10000 \
train.rollout_val_freq_epochs=5 \
train.rollout_num_episodes=10 \
train.rollout_device=cuda \
train.rollout_num_workers=4 \
'train.rollout_cuda_devices=[0]' \
train.rollout_validate_on_checkpoint=false \
train.val_split=0.0 \
'train.val_episode_indices=[100]' \
train.action_mse_val_freq_epochs=1 \
train.use_swanlab=true \
train.swanlab_project=roboimi-vla \
train.swanlab_run_name=<run_name> \
hydra.run.dir=/home/droid/project/roboimi/.worktrees/feat-lewm-imf-fusion/runs/<run_name>
```
### 8.6 Monitoring logs
```bash
tail -f runs/<run_name>/launch.stdout.log
tail -f runs/<run_name>/train_vla.log
```
On remote nodes, replace `runs/<run_name>` with the absolute path from the manifest.
---
## 9. Operational advice
- **Treat the suite's `manifest.json` / `notes.md` / `launch_logs/*.launch.log` as the source of truth**; do not hand-write a set of commands that diverges from historical runs
- For the current standard validation, explicitly add:
  - `train.val_split=0.0`
  - `train.val_episode_indices=[100]`
  - `train.action_mse_val_freq_epochs=1`
  - `train.rollout_val_freq_epochs=5`
  - `train.rollout_num_episodes=10`
- When comparing different horizons / action steps on this branch, try to change only:
  - `agent.lewm_query_offsets`
  - `agent.pred_horizon`
  - `agent.num_action_steps`
- To reproduce the 2026-04-21 from-scratch round, remember to set both:
  - `train.pretrained_ckpt=null`
  - `agent.lewm_pretrained_ckpt=null`

File diff suppressed because it is too large

View File

@@ -118,6 +118,127 @@ def recursive_to_device(data, device):
return data
def build_agent_input(batch_data):
agent_input = {
'images': {
cam_name.replace('observation.', ''): value
for cam_name, value in batch_data.items()
if cam_name.startswith('observation.') and cam_name != 'observation.state'
},
'qpos': batch_data['observation.state'],
'action': batch_data['action'],
}
if 'action_is_pad' in batch_data:
agent_input['action_is_pad'] = batch_data['action_is_pad']
lewm_images = {
cam_name.replace('lewm.observation.', ''): value
for cam_name, value in batch_data.items()
if cam_name.startswith('lewm.observation.') and cam_name != 'lewm.observation.state'
}
if lewm_images:
agent_input['lewm_images'] = lewm_images
if 'lewm.observation.state' in batch_data:
agent_input['lewm_qpos'] = batch_data['lewm.observation.state']
lewm_future_images = {
cam_name.replace('lewm.future.', ''): value
for cam_name, value in batch_data.items()
if cam_name.startswith('lewm.future.') and cam_name != 'lewm.future.state'
}
if lewm_future_images:
agent_input['lewm_future_images'] = lewm_future_images
if 'lewm.future.state' in batch_data:
agent_input['lewm_future_qpos'] = batch_data['lewm.future.state']
return agent_input
def _instantiate_dataset(cfg, dataset_image_resize_shape, episode_indices=None):
kwargs = {'image_resize_shape': dataset_image_resize_shape}
if episode_indices is not None:
kwargs['episode_indices'] = episode_indices
return instantiate(cfg.data, **kwargs)
def build_train_val_datasets(cfg, dataset_image_resize_shape):
val_episode_indices = cfg.train.get('val_episode_indices', None)
if val_episode_indices:
dataset = _instantiate_dataset(cfg, dataset_image_resize_shape)
available_episode_indices = list(getattr(dataset, 'available_episode_indices', []))
if not available_episode_indices:
raise ValueError('显式 val_episode_indices 需要数据集暴露 available_episode_indices')
requested_val_episode_indices = sorted(int(idx) for idx in val_episode_indices)
available_set = set(available_episode_indices)
missing = sorted(set(requested_val_episode_indices) - available_set)
if missing:
raise ValueError(
f'val_episode_indices {missing} 不存在于数据集可用 episodes {available_episode_indices}'
)
train_episode_indices = [
idx for idx in available_episode_indices
if idx not in set(requested_val_episode_indices)
]
if not train_episode_indices:
raise ValueError('显式 val_episode_indices 不能覆盖全部 episodes训练集将为空')
train_dataset = _instantiate_dataset(
cfg,
dataset_image_resize_shape,
episode_indices=train_episode_indices,
)
val_dataset = _instantiate_dataset(
cfg,
dataset_image_resize_shape,
episode_indices=requested_val_episode_indices,
)
return dataset, train_dataset, val_dataset, requested_val_episode_indices
dataset = _instantiate_dataset(cfg, dataset_image_resize_shape)
val_split = float(cfg.train.get('val_split', 0.1))
seed = int(cfg.train.get('seed', 42))
val_size = int(len(dataset) * val_split)
train_size = len(dataset) - val_size
if val_size > 0:
train_dataset, val_dataset = random_split(
dataset,
[train_size, val_size],
generator=torch.Generator().manual_seed(seed)
)
else:
train_dataset, val_dataset = dataset, None
return dataset, train_dataset, val_dataset, None
def compute_action_mse_validation(agent, val_loader, device):
if val_loader is None:
return None
was_training = agent.training
agent.eval()
total_squared_error = 0.0
total_count = 0.0
with torch.no_grad():
for val_batch in val_loader:
val_batch = recursive_to_device(val_batch, device)
val_input = build_agent_input(val_batch)
pred_actions = agent.predict_action_chunk(val_input)
target_actions = val_input['action']
squared_error = (pred_actions - target_actions).pow(2)
action_is_pad = val_input.get('action_is_pad', None)
if action_is_pad is not None:
mask = (~action_is_pad).unsqueeze(-1).to(squared_error.dtype)
total_squared_error += (squared_error * mask).sum().item()
total_count += mask.sum().item() * squared_error.shape[-1]
else:
total_squared_error += squared_error.sum().item()
total_count += target_actions.numel()
if was_training:
agent.train()
return total_squared_error / max(total_count, 1.0)
def resolve_resume_checkpoint(resume_ckpt, checkpoint_dir):
"""
Resolve the checkpoint path used to resume training.
@@ -237,6 +358,32 @@ def build_training_optimizer(agent, lr, weight_decay):
return AdamW(optim_groups, lr=lr, weight_decay=weight_decay)
def load_state_dict_ignoring_shape_mismatches(module, incoming_state_dict):
"""Load only checkpoint tensors whose keys exist locally and whose shapes match."""
current_state_dict = module.state_dict()
compatible_state_dict = {}
mismatched_keys = []
missing_keys = []
for key, value in incoming_state_dict.items():
if key not in current_state_dict:
missing_keys.append(key)
continue
if current_state_dict[key].shape != value.shape:
mismatched_keys.append(key)
continue
compatible_state_dict[key] = value
merged_state_dict = dict(current_state_dict)
merged_state_dict.update(compatible_state_dict)
module.load_state_dict(merged_state_dict, strict=True)
return {
'loaded_keys': sorted(compatible_state_dict.keys()),
'missing_keys': sorted(missing_keys),
'mismatched_keys': sorted(mismatched_keys),
}
def _init_swanlab(cfg):
"""按需初始化 SwanLab并在缺少依赖或认证失败时快速失败。"""
if not bool(cfg.train.get('use_swanlab', False)):
@@ -384,29 +531,29 @@ def _run_training(cfg: DictConfig):
vision_backbone_cfg = cfg.agent.get('vision_backbone', None)
if vision_backbone_cfg is not None and 'dataset_image_resize_shape' in vision_backbone_cfg:
dataset_image_resize_shape = vision_backbone_cfg.get('dataset_image_resize_shape')
dataset = instantiate(
cfg.data,
image_resize_shape=dataset_image_resize_shape,
dataset, train_dataset, val_dataset, explicit_val_episode_indices = (
build_train_val_datasets(cfg, dataset_image_resize_shape)
)
log.info(f"✅ 数据集加载成功。总样本数: {len(dataset)}")
except Exception as e:
log.error(f"❌ 数据集加载失败: {e}")
raise
# Train/val split
val_split = float(cfg.train.get('val_split', 0.1))
seed = int(cfg.train.get('seed', 42))
val_size = int(len(dataset) * val_split)
train_size = len(dataset) - val_size
if val_size > 0:
train_dataset, val_dataset = random_split(
dataset,
[train_size, val_size],
generator=torch.Generator().manual_seed(seed)
if explicit_val_episode_indices is not None:
log.info(
"✅ 数据集划分: 训练集=%s, 验证集=%s (显式 held-out episodes=%s)",
len(train_dataset),
len(val_dataset),
explicit_val_episode_indices,
)
else:
val_split = float(cfg.train.get('val_split', 0.1))
val_size = len(val_dataset) if val_dataset is not None else 0
if val_size > 0:
log.info(
f"✅ 数据集划分: 训练集={len(train_dataset)}, 验证集={val_size} (验证比例={val_split})"
)
log.info(f"✅ 数据集划分: 训练集={train_size}, 验证集={val_size} (验证比例={val_split})")
else:
train_dataset, val_dataset = dataset, None
log.info("✅ 数据集划分: 全部用于训练, 验证集=0 (验证比例=0)")
train_batch_size = int(cfg.train.batch_size)
@@ -509,18 +656,23 @@ def _run_training(cfg: DictConfig):
try:
checkpoint = torch.load(ckpt_path, map_location=cfg.train.device)
# Load model weights only (do not load optimizer/scheduler)
missing_keys, unexpected_keys = agent.load_state_dict(
load_info = load_state_dict_ignoring_shape_mismatches(
agent,
checkpoint['model_state_dict'],
strict=False # allow partial loading (when structures do not fully match)
)
log.info(f"✅ [Finetune] 模型权重加载成功")
if missing_keys:
log.warning(f"⚠️ [Finetune] 缺少的键 ({len(missing_keys)} 个): {missing_keys[:5]}...")
if unexpected_keys:
log.warning(f"⚠️ [Finetune] 多余的键 ({len(unexpected_keys)} 个): {unexpected_keys[:5]}...")
if load_info['missing_keys']:
log.warning(
f"⚠️ [Finetune] checkpoint 中存在本地模型没有的键 ({len(load_info['missing_keys'])} 个): "
f"{load_info['missing_keys'][:5]}..."
)
if load_info['mismatched_keys']:
log.warning(
f"⚠️ [Finetune] 因形状不匹配而跳过的键 ({len(load_info['mismatched_keys'])} 个): "
f"{load_info['mismatched_keys'][:5]}..."
)
log.info(f"📊 [Finetune] 预训练信息: 步骤={checkpoint.get('step', 'N/A')}, 损失={checkpoint.get('loss', 'N/A')}")
log.info(f"📈 [Finetune] 使用新的训练配置lr={cfg.train.lr}, max_steps={cfg.train.max_steps}")
@@ -643,22 +795,6 @@ def _run_training(cfg: DictConfig):
# =========================================================================
log.info("🏋️ 开始训练循环...")
def build_agent_input(batch_data):
"""构建 agent 输入格式"""
images = {}
# SimpleRobotDataset returns keys in the observation.{cam_name} format
for cam_name in cfg.data.camera_names:
key = f"observation.{cam_name}"
if key in batch_data:
images[cam_name] = batch_data[key]
return {
'images': images,
'qpos': batch_data['observation.state'], # SimpleRobotDataset uses observation.state
'action': batch_data['action'],
'action_is_pad': batch_data.get('action_is_pad', None) # pass through the padding mask
}
def save_checkpoint(checkpoint_path: Path, step: int, loss_value, val_loss=None, rollout_avg_reward=None):
agent_stats = agent.get_normalization_stats()
torch.save({
@@ -702,10 +838,28 @@ def _run_training(cfg: DictConfig):
from roboimi.demos.vla_scripts import eval_vla
rollout_cfg = OmegaConf.create(OmegaConf.to_container(cfg, resolve=False))
rollout_num_episodes = int(cfg.train.get('rollout_num_episodes', 1))
rollout_device = str(cfg.train.get('rollout_device', cfg.train.device))
configured_rollout_workers = cfg.train.get('rollout_num_workers', None)
if configured_rollout_workers is None:
if rollout_device.startswith('cuda'):
rollout_num_workers = min(max(rollout_num_episodes, 1), 8)
else:
rollout_num_workers = 1
else:
rollout_num_workers = int(configured_rollout_workers)
rollout_cfg.eval.ckpt_path = str(checkpoint_path)
rollout_cfg.eval.num_episodes = int(cfg.train.get('rollout_num_episodes', 1))
rollout_cfg.eval.num_episodes = rollout_num_episodes
rollout_cfg.eval.num_workers = rollout_num_workers
rollout_cfg.eval.headless = True
rollout_cfg.eval.device = 'cpu'
rollout_cfg.eval.device = rollout_device
rollout_cfg.eval.cuda_devices = cfg.train.get('rollout_cuda_devices', None)
rollout_cfg.eval.response_timeout_s = float(
cfg.train.get('rollout_response_timeout_s', 300.0)
)
rollout_cfg.eval.server_startup_timeout_s = float(
cfg.train.get('rollout_server_startup_timeout_s', 300.0)
)
rollout_cfg.eval.verbose_action = False
rollout_cfg.eval.record_video = False
rollout_cfg.eval.save_trajectory_image = True
@@ -716,9 +870,11 @@ def _run_training(cfg: DictConfig):
)
log.info(
"🎯 开始 checkpoint rollout 验证: %s (episodes=%s, headless=True)",
"🎯 开始 checkpoint rollout 验证: %s (episodes=%s, device=%s, workers=%s, headless=True)",
checkpoint_path,
rollout_cfg.eval.num_episodes,
rollout_cfg.eval.device,
rollout_cfg.eval.num_workers,
)
return eval_vla._run_eval(rollout_cfg)
@@ -731,6 +887,7 @@ def _run_training(cfg: DictConfig):
pbar = tqdm(range(start_step, cfg.train.max_steps), desc="训练中", ncols=100)
steps_per_epoch = len(train_loader)
action_mse_val_freq_epochs = int(cfg.train.get('action_mse_val_freq_epochs', 0) or 0)
rollout_val_freq_epochs = int(cfg.train.get('rollout_val_freq_epochs', 0) or 0)
rollout_validation_enabled = rollout_val_freq_epochs > 0
best_loss = resume_best_loss
@@ -809,6 +966,15 @@ def _run_training(cfg: DictConfig):
},
step=step,
)
if hasattr(agent, 'get_last_loss_breakdown'):
loss_breakdown = agent.get_last_loss_breakdown()
extra_train_metrics = {
f"train/{key}": value
for key, value in loss_breakdown.items()
if value is not None and key != 'loss'
}
if extra_train_metrics:
_log_to_swanlab(swanlab_module, extra_train_metrics, step=step)
# =====================================================================
# Checkpoint saving and validation
@@ -891,6 +1057,33 @@ def _run_training(cfg: DictConfig):
and completed_epoch > 0
and completed_epoch % rollout_val_freq_epochs == 0
)
should_run_action_mse_validation = (
action_mse_val_freq_epochs > 0
and val_loader is not None
and steps_per_epoch > 0
and completed_steps % steps_per_epoch == 0
and completed_epoch > 0
and completed_epoch % action_mse_val_freq_epochs == 0
)
if should_run_action_mse_validation:
action_mse = compute_action_mse_validation(
agent,
val_loader,
cfg.train.device,
)
if action_mse is not None:
log.info(
f"步骤 {step}/{cfg.train.max_steps} | Epoch {completed_epoch} "
f"held-out action MSE: {action_mse:.6f}"
)
_log_to_swanlab(
swanlab_module,
{
'val/action_mse': action_mse,
'val/action_mse_epoch': completed_epoch,
},
step=step,
)
if should_run_epoch_rollout:
if checkpoint_path is None:
checkpoint_path = checkpoint_dir / f"vla_model_step_{step}.pt"

View File

@@ -0,0 +1,267 @@
#!/usr/bin/env python3
from __future__ import annotations
import argparse
import datetime as dt
import json
import pathlib
import re
import shlex
import subprocess
from collections import defaultdict
from typing import Any
STEP_PAT = re.compile(r"步骤\s+(\d+)/(\d+)")
BAR_PAT = re.compile(r"\|\s*(\d+)/(\d+)")
def normalize_chunks(text: str):
for part in re.split(r"[\r\n]+", text):
part = part.strip()
if part:
yield part
def parse_latest_line(text: str) -> tuple[str, int | None]:
latest_line = ""
latest_step = None
for line in normalize_chunks(text):
if "步骤" not in line and "训练中:" not in line:
continue
latest_line = line
match = STEP_PAT.search(line) or BAR_PAT.search(line)
if match:
latest_step = int(match.group(1))
return latest_line, latest_step
def now_iso() -> str:
return dt.datetime.now(
dt.timezone(dt.timedelta(hours=8)),
).isoformat(timespec="seconds")
def run_cmd(cmd: list[str], check: bool = True) -> subprocess.CompletedProcess[str]:
return subprocess.run(cmd, capture_output=True, text=True, check=check)
def probe_local(run: dict[str, Any]) -> dict[str, Any]:
pid = str(run["pid"])
ps = run_cmd(["ps", "-p", pid, "-o", "pid=,stat=,etime=,args="], check=False)
log_path = pathlib.Path(run["log_path"])
latest_line = ""
latest_step = None
if log_path.exists():
latest_line, latest_step = parse_latest_line(log_path.read_text(errors="replace"))
return {
"alive": bool(ps.stdout.strip()),
"ps": ps.stdout.strip(),
"log_exists": log_path.exists(),
"latest_line": latest_line,
"latest_step": latest_step,
}
def remote_probe(host: str, remote_user: str, runs: list[dict[str, Any]]) -> dict[str, dict[str, Any]]:
payload = [
{
"run_id": run["run_id"],
"pid": str(run["pid"]),
"log_path": run["log_path"],
}
for run in runs
]
remote_py = r"""
import json
import pathlib
import re
import subprocess
import sys
payload = json.loads(sys.argv[1])
step_pat = re.compile(r"步骤\s+(\d+)/(\d+)")
bar_pat = re.compile(r"\|\s*(\d+)/(\d+)")
def normalize_chunks(text):
for part in re.split(r"[\r\n]+", text):
part = part.strip()
if part:
yield part
def parse_latest_line(text):
latest_line = ""
latest_step = None
for line in normalize_chunks(text):
if "步骤" not in line and "训练中:" not in line:
continue
latest_line = line
match = step_pat.search(line) or bar_pat.search(line)
if match:
latest_step = int(match.group(1))
return latest_line, latest_step
out = {}
for item in payload:
try:
ps = subprocess.run(
["ps", "-p", item["pid"], "-o", "pid=,stat=,etime=,args="],
capture_output=True,
text=True,
check=False,
)
log_path = pathlib.Path(item["log_path"])
latest_line = ""
latest_step = None
if log_path.exists():
latest_line, latest_step = parse_latest_line(log_path.read_text(errors="replace"))
out[item["run_id"]] = {
"alive": bool(ps.stdout.strip()),
"ps": ps.stdout.strip(),
"log_exists": log_path.exists(),
"latest_line": latest_line,
"latest_step": latest_step,
}
except Exception as exc:
out[item["run_id"]] = {
"alive": False,
"ps": "",
"log_exists": False,
"latest_line": "",
"latest_step": None,
"error": str(exc),
}
print(json.dumps(out, ensure_ascii=False))
"""
remote_target = host if "@" in host else f"{remote_user}@{host}"
remote_cmd = (
f"python3 -c {shlex.quote(remote_py)} "
f"{shlex.quote(json.dumps(payload, ensure_ascii=False))}"
)
try:
res = run_cmd(
[
"ssh",
"-F",
"/dev/null",
"-o",
"BatchMode=yes",
"-o",
"StrictHostKeyChecking=accept-new",
remote_target,
remote_cmd,
]
)
return json.loads(res.stdout)
except subprocess.CalledProcessError as exc:
error = (exc.stderr or exc.stdout or str(exc)).strip()
return {
run["run_id"]: {
"alive": False,
"ps": "",
"log_exists": False,
"latest_line": "",
"latest_step": None,
"error": f"ssh_failed: {error}",
}
for run in runs
}
def append_notes(notes_path: pathlib.Path, snapshot_at: str, runs: list[dict[str, Any]]) -> None:
lines = [f"\n## Status snapshot {snapshot_at}"]
for run in runs:
lines.append(
(
f"- {run['run_id']}: host={run['host']} gpu={run['gpu']} "
f"alive={run.get('alive', False)} step={run.get('latest_step')} "
f"pid={run['pid']}"
)
)
if run.get("latest_line"):
lines.append(f" - latest_line: `{run['latest_line']}`")
if run.get("error"):
lines.append(f" - error: `{run['error']}`")
with notes_path.open("a", encoding="utf-8") as f:
f.write("\n".join(lines) + "\n")
def main() -> int:
parser = argparse.ArgumentParser()
parser.add_argument("suite_dir", type=pathlib.Path)
parser.add_argument("--remote-user", default="droid")
parser.add_argument("--append-notes", action="store_true")
args = parser.parse_args()
suite_dir = args.suite_dir.resolve()
status_path = suite_dir / "status.json"
notes_path = suite_dir / "notes.md"
monitor_dir = suite_dir / "monitor_logs"
monitor_dir.mkdir(parents=True, exist_ok=True)
status = json.loads(status_path.read_text(encoding="utf-8"))
runs: list[dict[str, Any]] = status["runs"]
snapshot_at = now_iso()
by_host: dict[str, list[dict[str, Any]]] = defaultdict(list)
for run in runs:
by_host[run["host"]].append(run)
results: dict[str, dict[str, Any]] = {}
for host, host_runs in by_host.items():
if host == "local":
for run in host_runs:
results[run["run_id"]] = probe_local(run)
else:
results.update(remote_probe(host, args.remote_user, host_runs))
alive_count = 0
for run in runs:
result = results[run["run_id"]]
run["alive"] = result["alive"]
run["ps"] = result["ps"]
run["log_exists"] = result["log_exists"]
run["latest_line"] = result["latest_line"]
run["latest_step"] = result["latest_step"]
run["last_verified_at"] = snapshot_at
if "error" in result:
run["error"] = result["error"]
else:
run.pop("error", None)
run["status"] = "running" if result["alive"] else "stopped"
alive_count += int(result["alive"])
status["last_verified_at"] = snapshot_at
status["alive_count"] = alive_count
status["total_runs"] = len(runs)
status_path.write_text(json.dumps(status, ensure_ascii=False, indent=2) + "\n", encoding="utf-8")
snapshot_payload = {
"suite_name": status.get("suite_name"),
"snapshot_at": snapshot_at,
"alive_count": alive_count,
"total_runs": len(runs),
"runs": {run["run_id"]: results[run["run_id"]] for run in runs},
}
timestamp_slug = snapshot_at.replace(":", "").replace("+", "_").replace("-", "")
snapshot_path = monitor_dir / f"status-{timestamp_slug}.json"
snapshot_path.write_text(
json.dumps(snapshot_payload, ensure_ascii=False, indent=2) + "\n",
encoding="utf-8",
)
if args.append_notes:
append_notes(notes_path, snapshot_at, runs)
print(json.dumps(snapshot_payload, ensure_ascii=False, indent=2))
print(f"\nstatus_json={status_path}")
print(f"snapshot_json={snapshot_path}")
if args.append_notes:
print(f"notes_md={notes_path}")
return 0
if __name__ == "__main__":
raise SystemExit(main())

View File

@@ -28,6 +28,7 @@ class VLAAgent(nn.Module):
num_action_steps=8, # how many action steps are actually executed per inference
head_type='unet', # policy head type: 'unet' or 'transformer'
cond_projector=None, # optional: project the vision+state condition to the head's expected dim
extra_condition_tokens: int = 0, # optional number of extra condition tokens, e.g. future-prediction embeddings
):
super().__init__()
# Save parameters
@@ -39,6 +40,9 @@ class VLAAgent(nn.Module):
self.num_action_steps = num_action_steps
self.inference_steps = inference_steps
self.head_type = head_type # 'unet' or 'transformer'
self.extra_condition_tokens = int(extra_condition_tokens)
if self.extra_condition_tokens < 0:
raise ValueError(f"extra_condition_tokens must be >= 0, got {self.extra_condition_tokens}")
agent_camera_names = tuple(camera_names) if camera_names is not None else None
backbone_camera_names = getattr(vision_backbone, 'camera_names', None)
backbone_camera_names = tuple(backbone_camera_names) if backbone_camera_names is not None else None
@@ -71,11 +75,14 @@ class VLAAgent(nn.Module):
stats=dataset_stats,
normalization_type=normalization_type
)
self.dataset_stats = dataset_stats
self.vision_encoder = vision_backbone
self.state_encoder = state_encoder
if self.camera_names is not None:
self.vision_encoder.camera_names = self.camera_names
self.condition_tokens_per_step = int(getattr(self.vision_encoder, 'tokens_per_step', 1))
self.state_feature_dim = int(getattr(self.state_encoder, 'output_dim', obs_dim))
joint_vision_dim = getattr(self.vision_encoder, 'joint_output_dim', None)
if joint_vision_dim is not None:
per_token_vision_dim = int(joint_vision_dim)
@@ -87,8 +94,11 @@ class VLAAgent(nn.Module):
else:
per_token_vision_dim = int(single_cam_feat_dim) * int(num_cams)
self.condition_sequence_length = self.obs_horizon * self.condition_tokens_per_step
self.raw_per_step_cond_dim = per_token_vision_dim + obs_dim
self.history_condition_sequence_length = self.obs_horizon * self.condition_tokens_per_step
self.condition_sequence_length = (
self.history_condition_sequence_length + self.extra_condition_tokens
)
self.raw_per_step_cond_dim = per_token_vision_dim + self.state_feature_dim
if cond_projector is None:
self.cond_projector = None
self.per_step_cond_dim = self.raw_per_step_cond_dim
@@ -139,7 +149,6 @@ class VLAAgent(nn.Module):
global_cond_dim=self.global_cond_dim
)
self.state_encoder = state_encoder
self.action_encoder = action_encoder
# Initialize queues (for online inference)
@@ -220,7 +229,7 @@ class VLAAgent(nn.Module):
f"条件维度不匹配: got {cond.shape[-1]}, expected {self.per_step_cond_dim}"
)
cond = cond.reshape(batch_size, obs_steps * token_count, self.per_step_cond_dim)
expected_length = self.condition_sequence_length
expected_length = self.history_condition_sequence_length
if cond.shape[1] != expected_length:
raise RuntimeError(
f"条件序列长度不匹配: got {cond.shape[1]}, expected {expected_length}"

View File

@@ -1,9 +1,12 @@
from __future__ import annotations
from contextlib import nullcontext
from typing import Dict, Optional
from collections import deque
from pathlib import Path
from typing import Any, Dict, Mapping, Optional, Sequence
import torch
import torch.nn as nn
import torch.nn.functional as F
from roboimi.vla.agent import VLAAgent
@@ -15,14 +18,87 @@ except ImportError: # pragma: no cover
class IMFVLAAgent(VLAAgent):
def __init__(self, *args, inference_steps: int = 1, **kwargs):
def __init__(
self,
*args,
inference_steps: int = 1,
lewm_history_horizon: Optional[int] = None,
lewm_query_offsets: Optional[Sequence[int]] = None,
lewm_predictor: Optional[nn.Module] = None,
lewm_pred_projector: Optional[nn.Module] = None,
future_decoder: Optional[nn.Module] = None,
future_query_init_std: float = 0.02,
lewm_sigreg: Optional[nn.Module] = None,
lewm_sigreg_weight: float = 0.09,
lewm_loss_weight: float = 0.0,
lewm_pretrained_ckpt: Optional[str | Path | Mapping[str, Any]] = None,
**kwargs,
):
if inference_steps != 1:
raise ValueError(
'IMFVLAAgent only supports one-step inference; '
f'inference_steps must be 1, got {inference_steps}.'
)
lewm_query_offsets = tuple(int(offset) for offset in (lewm_query_offsets or ()))
inferred_extra_condition_tokens = len(lewm_query_offsets) if lewm_query_offsets else 0
default_extra_condition_tokens = (
0 if future_decoder is not None else inferred_extra_condition_tokens
)
kwargs.setdefault('extra_condition_tokens', default_extra_condition_tokens)
self.__dict__['lewm_history_horizon'] = int(lewm_history_horizon or kwargs.get('obs_horizon', 1))
self.__dict__['lewm_query_offsets'] = lewm_query_offsets
self.__dict__['lewm_predictor'] = lewm_predictor
self.__dict__['lewm_pred_projector'] = lewm_pred_projector or nn.Identity()
self.__dict__['future_decoder'] = future_decoder
self.__dict__['future_query_tokens'] = None
self.__dict__['future_query_init_std'] = float(future_query_init_std)
self.__dict__['lewm_sigreg'] = lewm_sigreg
self.__dict__['lewm_sigreg_weight'] = float(lewm_sigreg_weight)
self.__dict__['lewm_loss_weight'] = float(lewm_loss_weight)
self.__dict__['_last_loss_breakdown'] = {
'action_loss': 0.0,
'lewm_pred_loss': 0.0,
'lewm_sigreg_loss': 0.0,
'lewm_loss': 0.0,
'loss': 0.0,
}
super().__init__(*args, inference_steps=inference_steps, **kwargs)
self.inference_steps = 1
self.lewm_history_horizon = int(lewm_history_horizon or self.obs_horizon)
self.lewm_predictor = lewm_predictor
self.lewm_pred_projector = lewm_pred_projector or nn.Identity()
if future_decoder is not None and not isinstance(future_decoder, nn.Module):
self.future_decoder = future_decoder()
else:
self.future_decoder = future_decoder
self.future_query_tokens = None
self.future_query_init_std = float(future_query_init_std)
self.lewm_sigreg = lewm_sigreg
self.lewm_sigreg_weight = float(lewm_sigreg_weight)
if self.lewm_predictor is not None and self.future_decoder is not None:
raise ValueError('lewm_predictor and future_decoder are mutually exclusive')
if self.lewm_predictor is None and self.extra_condition_tokens > 0:
raise ValueError(
'extra_condition_tokens > 0 requires lewm_predictor to be provided'
)
if self.lewm_predictor is not None and self.extra_condition_tokens != inferred_extra_condition_tokens:
raise ValueError(
'extra_condition_tokens must equal len(lewm_query_offsets) when lewm_predictor is enabled'
)
if self.future_decoder is not None:
if inferred_extra_condition_tokens <= 0:
raise ValueError('future_decoder requires non-empty lewm_query_offsets')
if self.extra_condition_tokens != 0:
raise ValueError('future_decoder requires extra_condition_tokens=0')
self.future_query_tokens = nn.Parameter(
torch.randn(
1,
inferred_extra_condition_tokens,
self.per_step_cond_dim,
) * self.future_query_init_std
)
if lewm_pretrained_ckpt is not None:
self.load_lewm_pretrained_components(lewm_pretrained_ckpt)
@staticmethod
def _broadcast_batch_time(value: torch.Tensor, reference: torch.Tensor) -> torch.Tensor:
@@ -119,14 +195,251 @@ class IMFVLAAgent(VLAAgent):
delta = self._broadcast_batch_time(t - r, z_t)
return z_t - delta * u
def _normalize_qpos_for_lewm(self, qpos: torch.Tensor) -> torch.Tensor:
if not self.normalization.enabled:
return qpos
qpos_mean = getattr(self.normalization, 'qpos_mean', None)
qpos_std = getattr(self.normalization, 'qpos_std', None)
if qpos_mean is not None and qpos_std is not None:
return (qpos - qpos_mean) / qpos_std
if isinstance(self.dataset_stats, dict):
mean = self.dataset_stats.get('qpos_mean', None)
std = self.dataset_stats.get('qpos_std', None)
if mean is not None and std is not None:
mean = torch.as_tensor(mean, dtype=qpos.dtype, device=qpos.device)
std = torch.as_tensor(std, dtype=qpos.dtype, device=qpos.device)
return (qpos - mean) / std
return self.normalization.normalize_qpos(qpos)
def _project_lewm_future_tokens(self, predicted_tokens: torch.Tensor) -> torch.Tensor:
if predicted_tokens.ndim != 3:
raise ValueError(
f"expected predicted future tokens to be 3D, got rank {predicted_tokens.ndim}"
)
batch_size, token_count, token_dim = predicted_tokens.shape
flattened = predicted_tokens.reshape(batch_size * token_count, token_dim)
projected = self.lewm_pred_projector(flattened)
if projected.ndim != 2:
raise ValueError(
f"expected lewm_pred_projector to return rank-2 tensors, got rank {projected.ndim}"
)
return projected.reshape(batch_size, token_count, projected.shape[-1])
@staticmethod
def _load_checkpoint_payload(
checkpoint_or_path: str | Path | Mapping[str, Any],
) -> Mapping[str, torch.Tensor]:
if isinstance(checkpoint_or_path, (str, Path)):
payload = torch.load(Path(checkpoint_or_path), map_location='cpu', weights_only=False)
else:
payload = checkpoint_or_path
state_dict = payload.get('state_dict', payload)
if not isinstance(state_dict, Mapping):
raise TypeError('checkpoint payload must contain a mapping state_dict')
return state_dict
@staticmethod
def _extract_prefixed_state_dict(
state_dict: Mapping[str, torch.Tensor],
prefix: str,
) -> Dict[str, torch.Tensor]:
extracted = {
key[len(prefix):]: value
for key, value in state_dict.items()
if key.startswith(prefix)
}
if not extracted:
raise KeyError(f"checkpoint missing parameters with prefix {prefix!r}")
return extracted
@staticmethod
def _adapt_and_load_state_dict(
module: nn.Module,
incoming_state_dict: Mapping[str, torch.Tensor],
*,
query_key: str = 'query_tokens',
pos_key: str = 'pos_embedding',
) -> Dict[str, Sequence[str]]:
current_state_dict = module.state_dict()
adapted_state_dict = dict(current_state_dict)
loaded_keys = []
mismatched_keys = []
missing_keys = []
for key, current_tensor in current_state_dict.items():
if key not in incoming_state_dict:
continue
source_tensor = incoming_state_dict[key]
if source_tensor.shape == current_tensor.shape:
adapted_state_dict[key] = source_tensor
loaded_keys.append(key)
continue
if key in {query_key, pos_key} and source_tensor.ndim == current_tensor.ndim:
patched = current_tensor.clone()
overlap_slices = tuple(
slice(0, min(src_dim, cur_dim))
for src_dim, cur_dim in zip(source_tensor.shape, current_tensor.shape)
)
patched[overlap_slices] = source_tensor[overlap_slices]
if key == query_key:
copy_count = min(source_tensor.shape[1], current_tensor.shape[1])
if copy_count < current_tensor.shape[1] and copy_count > 0:
tail = source_tensor[:, copy_count - 1:copy_count, ...]
feature_dim = min(tail.shape[-1], patched.shape[-1])
patched[:, copy_count:, :feature_dim] = tail[:, :, :feature_dim]
else:
copy_count = min(source_tensor.shape[1], current_tensor.shape[1])
if copy_count < current_tensor.shape[1] and copy_count > 0:
tail = source_tensor[:, copy_count - 1:copy_count, ...]
feature_dim = min(tail.shape[-1], patched.shape[-1])
patched[:, copy_count:, :feature_dim] = tail[:, :, :feature_dim]
adapted_state_dict[key] = patched
loaded_keys.append(key)
continue
mismatched_keys.append(key)
for key in incoming_state_dict.keys():
if key not in current_state_dict:
missing_keys.append(key)
module.load_state_dict(adapted_state_dict, strict=True)
return {
'loaded_keys': tuple(sorted(loaded_keys)),
'mismatched_keys': tuple(sorted(set(mismatched_keys))),
'missing_keys': tuple(sorted(set(missing_keys))),
}
@staticmethod
def _load_state_dict_ignoring_shape_mismatches(
module: nn.Module,
incoming_state_dict: Mapping[str, torch.Tensor],
) -> Dict[str, Sequence[str]]:
current_state_dict = module.state_dict()
merged_state_dict = dict(current_state_dict)
loaded_keys = []
mismatched_keys = []
missing_keys = []
for key, value in incoming_state_dict.items():
if key not in current_state_dict:
missing_keys.append(key)
continue
if current_state_dict[key].shape != value.shape:
mismatched_keys.append(key)
continue
merged_state_dict[key] = value
loaded_keys.append(key)
module.load_state_dict(merged_state_dict, strict=True)
return {
'loaded_keys': tuple(sorted(loaded_keys)),
'mismatched_keys': tuple(sorted(mismatched_keys)),
'missing_keys': tuple(sorted(missing_keys)),
}
def load_lewm_pretrained_components(
self,
checkpoint_or_path: str | Path | Mapping[str, Any],
) -> None:
state_dict = self._load_checkpoint_payload(checkpoint_or_path)
if hasattr(self.vision_encoder, 'load_lewm_checkpoint'):
try:
self.vision_encoder.load_lewm_checkpoint({'state_dict': state_dict})
except RuntimeError:
vision_state_dict = self._extract_prefixed_state_dict(state_dict, 'model.encoder.')
self._load_state_dict_ignoring_shape_mismatches(self.vision_encoder, vision_state_dict)
else:
vision_state_dict = self._extract_prefixed_state_dict(state_dict, 'model.encoder.')
self._load_state_dict_ignoring_shape_mismatches(self.vision_encoder, vision_state_dict)
state_encoder_state_dict = self._extract_prefixed_state_dict(state_dict, 'model.state_encoder.')
self._load_state_dict_ignoring_shape_mismatches(self.state_encoder, state_encoder_state_dict)
projector_state_dict = self._extract_prefixed_state_dict(state_dict, 'model.projector.proj.')
mapped_projector_state_dict = {
f'linear.{key}': value
for key, value in projector_state_dict.items()
}
self._load_state_dict_ignoring_shape_mismatches(self.cond_projector, mapped_projector_state_dict)
if self.lewm_predictor is not None:
predictor_state_dict = self._extract_prefixed_state_dict(state_dict, 'model.predictor.')
self._adapt_and_load_state_dict(self.lewm_predictor, predictor_state_dict)
if self.lewm_pred_projector is not None:
pred_projector_state_dict = self._extract_prefixed_state_dict(state_dict, 'model.pred_proj.')
self._load_state_dict_ignoring_shape_mismatches(
self.lewm_pred_projector,
pred_projector_state_dict,
)
def _predict_future_tokens_with_decoder(self, history_cond: torch.Tensor) -> torch.Tensor:
if self.future_decoder is None or self.future_query_tokens is None:
raise RuntimeError('future_decoder path requested but not initialized')
batch_size = history_cond.shape[0]
query_tokens = self.future_query_tokens.expand(batch_size, -1, -1)
r = torch.zeros(batch_size, device=history_cond.device, dtype=history_cond.dtype)
t = torch.ones(batch_size, device=history_cond.device, dtype=history_cond.dtype)
return self.future_decoder(query_tokens, r, t, cond=history_cond)
def _build_full_condition(
self,
images,
proprioception,
*,
lewm_images=None,
lewm_proprioception=None,
):
normalized_proprioception = self.normalization.normalize_qpos(proprioception)
history_cond = self._build_cond(images, normalized_proprioception)
predicted_future_tokens = None
lewm_history_cond = None
if self.lewm_predictor is None and self.future_decoder is None:
return history_cond, predicted_future_tokens, lewm_history_cond
lewm_images = lewm_images if lewm_images is not None else images
lewm_proprioception = (
lewm_proprioception if lewm_proprioception is not None else proprioception
)
lewm_history_cond = self._build_cond(
lewm_images,
self._normalize_qpos_for_lewm(lewm_proprioception),
)
cond = history_cond
if self.lewm_predictor is not None:
predicted_future_tokens = self.lewm_predictor(lewm_history_cond)
predicted_future_tokens = self._project_lewm_future_tokens(predicted_future_tokens)
cond = torch.cat([history_cond, predicted_future_tokens], dim=1)
if cond.shape[1] != self.condition_sequence_length:
raise RuntimeError(
f"完整条件序列长度不匹配: got {cond.shape[1]}, expected {self.condition_sequence_length}"
)
if cond.shape[-1] != self.per_step_cond_dim:
raise RuntimeError(
f"完整条件维度不匹配: got {cond.shape[-1]}, expected {self.per_step_cond_dim}"
)
elif self.future_decoder is not None:
predicted_future_tokens = self._predict_future_tokens_with_decoder(lewm_history_cond)
return cond, predicted_future_tokens, lewm_history_cond
@staticmethod
def _masked_mse_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
return F.mse_loss(pred, target)
def compute_loss(self, batch):
actions, states, images = batch['action'], batch['qpos'], batch['images']
action_is_pad = batch.get('action_is_pad', None)
batch_size = actions.shape[0]
states = self.normalization.normalize_qpos(states)
actions = self.normalization.normalize_action(actions)
cond = self._build_cond(images, states)
cond, predicted_future_tokens, lewm_history_cond = self._build_full_condition(
images,
states,
lewm_images=batch.get('lewm_images', None),
lewm_proprioception=batch.get('lewm_qpos', None),
)
x = actions
e = torch.randn_like(x)
@@ -146,16 +459,109 @@ class IMFVLAAgent(VLAAgent):
if action_is_pad is not None:
mask = (~action_is_pad).unsqueeze(-1).to(loss.dtype)
valid_count = mask.sum() * loss.shape[-1]
loss = (loss * mask).sum() / valid_count.clamp_min(1.0)
action_loss = (loss * mask).sum() / valid_count.clamp_min(1.0)
else:
loss = loss.mean()
return loss
action_loss = loss.mean()
lewm_pred_loss = torch.zeros((), device=action_loss.device, dtype=action_loss.dtype)
lewm_sigreg_loss = torch.zeros((), device=action_loss.device, dtype=action_loss.dtype)
if predicted_future_tokens is not None:
lewm_future_images = batch.get('lewm_future_images', None)
lewm_future_qpos = batch.get('lewm_future_qpos', None)
if lewm_future_images is not None and lewm_future_qpos is not None:
future_target = self._build_cond(
lewm_future_images,
self._normalize_qpos_for_lewm(lewm_future_qpos),
)
lewm_pred_loss = self._masked_mse_loss(predicted_future_tokens, future_target)
if self.lewm_sigreg is not None and lewm_history_cond is not None:
lewm_sigreg_loss = self.lewm_sigreg(lewm_history_cond.transpose(0, 1))
lewm_loss = lewm_pred_loss + self.lewm_sigreg_weight * lewm_sigreg_loss
total_loss = action_loss + self.lewm_loss_weight * lewm_loss
self._last_loss_breakdown = {
'action_loss': float(action_loss.detach().item()),
'lewm_pred_loss': float(lewm_pred_loss.detach().item()),
'lewm_sigreg_loss': float(lewm_sigreg_loss.detach().item()),
'lewm_loss': float(lewm_loss.detach().item()),
'loss': float(total_loss.detach().item()),
}
return total_loss
def get_last_loss_breakdown(self) -> Dict[str, float]:
return dict(self._last_loss_breakdown)
def reset(self):
super().reset()
if self.lewm_predictor is not None:
self._queues['lewm_qpos'] = deque(maxlen=self.lewm_history_horizon)
self._queues['lewm_images'] = deque(maxlen=self.lewm_history_horizon)
def _populate_queues(self, observation: Dict[str, torch.Tensor]) -> None:
super()._populate_queues(observation)
if self.lewm_predictor is None:
return
if 'qpos' in observation:
self._queues['lewm_qpos'].append(observation['qpos'].clone())
if 'images' in observation:
ordered_images = self._order_images(observation['images'])
self._queues['lewm_images'].append({k: v.clone() for k, v in ordered_images.items()})
def _prepare_observation_batch(self) -> Dict[str, torch.Tensor]:
batch = super()._prepare_observation_batch()
if self.lewm_predictor is None:
return batch
qpos_list = list(self._queues['lewm_qpos'])
images_list = list(self._queues['lewm_images'])
if len(qpos_list) == 0 or len(images_list) == 0:
raise ValueError("LeWM 观测队列为空,请先调用 _populate_queues 添加观测")
while len(qpos_list) < self.lewm_history_horizon:
qpos_list.append(qpos_list[-1])
while len(images_list) < self.lewm_history_horizon:
images_list.append(images_list[-1])
batch['lewm_qpos'] = torch.stack(qpos_list, dim=0).unsqueeze(0)
batch['lewm_images'] = {}
camera_names = self.camera_names if self.camera_names is not None else tuple(sorted(images_list[0].keys()))
for cam_name in camera_names:
batch['lewm_images'][cam_name] = torch.stack(
[img[cam_name] for img in images_list],
dim=0,
).unsqueeze(0)
return batch
@torch.no_grad()
def predict_action(self, images, proprioception):
def predict_action_chunk(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
return self.predict_action(
batch['images'],
batch['qpos'],
lewm_images=batch.get('lewm_images', None),
lewm_proprioception=batch.get('lewm_qpos', None),
)
@torch.no_grad()
def predict_action(
self,
images,
proprioception,
*,
lewm_images=None,
lewm_proprioception=None,
):
batch_size = proprioception.shape[0]
proprioception = self.normalization.normalize_qpos(proprioception)
cond = self._build_cond(images, proprioception)
if self.lewm_predictor is not None:
cond, _predicted_future_tokens, _lewm_history_cond = self._build_full_condition(
images,
proprioception,
lewm_images=lewm_images,
lewm_proprioception=lewm_proprioception,
)
else:
cond = self._build_cond(
images,
self.normalization.normalize_qpos(proprioception),
)
z_t = torch.randn((batch_size, self.pred_horizon, self.action_dim), device=cond.device, dtype=cond.dtype)
action = self._sample_one_step(z_t, cond=cond)
return self.normalization.denormalize_action(action)

View File

@@ -0,0 +1,74 @@
# @package agent
defaults:
- /backbone@vision_backbone: lewm_resnet_query_fusion
- /modules@state_encoder: lewm_state_encoder
- /modules@action_encoder: identity_action_encoder
- /modules@cond_projector: linear_condition_projector
- /head: imf_transformer1d
- /head@future_decoder: imf_transformer1d
- _self_
_target_: roboimi.vla.agent_imf.IMFVLAAgent
action_dim: 16
obs_dim: 16
normalization_type: "min_max"
pred_horizon: 8
obs_horizon: 2
num_action_steps: 8
camera_names: ${data.camera_names}
num_cams: 3
vision_backbone:
camera_names: ${agent.camera_names}
num_views: ${agent.num_cams}
cond_projector:
output_dim: 288
lewm_history_horizon: 3
lewm_query_offsets: [8]
extra_condition_tokens: 0
lewm_loss_weight: 1.0
lewm_sigreg_weight: 0.09
lewm_pretrained_ckpt: null
future_query_init_std: 0.02
lewm_sigreg:
_target_: roboimi.vla.models.backbones.lewm_resnet_query_fusion.SIGReg
knots: 17
num_proj: 1024
diffusion_steps: 100
inference_steps: 1
head_type: "transformer"
head:
input_dim: ${agent.action_dim}
output_dim: ${agent.action_dim}
horizon: ${agent.pred_horizon}
n_obs_steps: ${agent.obs_horizon}
cond_dim: ${agent.cond_projector.output_dim}
n_emb: 384
causal_attn: false
time_as_cond: true
obs_as_cond: true
n_cond_layers: 0
backbone_type: attnres_full
n_head: 1
n_kv_head: 1
future_decoder:
input_dim: ${agent.cond_projector.output_dim}
output_dim: ${agent.cond_projector.output_dim}
horizon: ${len:${agent.lewm_query_offsets}}
n_obs_steps: ${agent.lewm_history_horizon}
cond_dim: ${agent.cond_projector.output_dim}
n_emb: 384
causal_attn: false
time_as_cond: true
obs_as_cond: true
n_cond_layers: 0
backbone_type: attnres_full
n_head: 1
n_kv_head: 1

View File

@@ -0,0 +1,77 @@
# @package agent
defaults:
- /backbone@vision_backbone: lewm_resnet_query_fusion
- /modules@state_encoder: lewm_state_encoder
- /modules@action_encoder: identity_action_encoder
- /modules@cond_projector: linear_condition_projector
- /head: imf_transformer1d
- _self_
_target_: roboimi.vla.agent_imf.IMFVLAAgent
action_dim: 16
obs_dim: 16
normalization_type: "min_max"
pred_horizon: 8
obs_horizon: 2
num_action_steps: 8
camera_names: ${data.camera_names}
num_cams: 3
vision_backbone:
camera_names: ${agent.camera_names}
num_views: ${agent.num_cams}
cond_projector:
output_dim: 288
lewm_history_horizon: 3
lewm_query_offsets: [8]
extra_condition_tokens: ${len:${agent.lewm_query_offsets}}
lewm_loss_weight: 1.0
lewm_sigreg_weight: 0.09
lewm_pretrained_ckpt: null
lewm_sigreg:
_target_: roboimi.vla.models.backbones.lewm_resnet_query_fusion.SIGReg
knots: 17
num_proj: 1024
lewm_predictor:
_target_: roboimi.vla.models.backbones.lewm_resnet_query_fusion.QueryTokenPredictor
num_frames: ${agent.lewm_history_horizon}
query_offsets: ${agent.lewm_query_offsets}
input_dim: ${agent.cond_projector.output_dim}
hidden_dim: ${agent.cond_projector.output_dim}
output_dim: ${agent.cond_projector.output_dim}
depth: 6
heads: 16
mlp_dim: 2048
dim_head: 64
dropout: 0.1
emb_dropout: 0.0
lewm_pred_projector:
_target_: roboimi.vla.models.backbones.lewm_resnet_query_fusion.LeWMProjectorMLP
input_dim: ${agent.cond_projector.output_dim}
hidden_dim: 2048
output_dim: ${agent.cond_projector.output_dim}
diffusion_steps: 100
inference_steps: 1
head_type: "transformer"
head:
input_dim: ${agent.action_dim}
output_dim: ${agent.action_dim}
horizon: ${agent.pred_horizon}
n_obs_steps: ${agent.obs_horizon}
cond_dim: 288
n_emb: 384
causal_attn: false
time_as_cond: true
obs_as_cond: true
n_cond_layers: 0
backbone_type: attnres_full
n_head: 1
n_kv_head: 1
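
Both agent configs rely on a custom `len` resolver (`extra_condition_tokens` and `future_decoder.horizon` track `len(lewm_query_offsets)`). A minimal sketch of the behavior, under the assumption that the repo registers the resolver roughly like this:

```python
from omegaconf import OmegaConf

# Assumption: the repo registers a `len` resolver along these lines.
OmegaConf.register_new_resolver("len", lambda seq: len(seq), replace=True)

cfg = OmegaConf.create({
    "lewm_query_offsets": [8],
    "extra_condition_tokens": "${len:${lewm_query_offsets}}",
})
assert cfg.extra_condition_tokens == 1  # one extra condition token per query offset
```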

View File

@@ -0,0 +1,7 @@
_target_: roboimi.vla.models.backbones.lewm_resnet_query_fusion.LeWMMultiViewResNetBackbone
view_feature_dim: 96
num_views: ${agent.num_cams}
view_encoder_mode: separate
camera_names: ${agent.camera_names}
checkpoint_path: null

View File

@@ -18,6 +18,8 @@ train:
# Data loading
num_workers: 12 # number of DataLoader worker processes (set to 0 when debugging)
val_split: 0.0 # validation split ratio; by default the full dataset is used for training
val_episode_indices: null # explicit episode-level validation set, e.g. [100]
action_mse_val_freq_epochs: 0 # when >0, compute action MSE on held-out episodes every this many epochs
seed: 42 # random seed (used for the data split)
# Logging and checkpoints
@@ -29,6 +31,11 @@ train:
rollout_val_freq_epochs: 50 # run rollout validation every this many epochs
rollout_validate_on_checkpoint: false # whether to run rollout validation right after saving a checkpoint
rollout_num_episodes: 3 # number of episodes per rollout validation
rollout_device: ${train.device} # device used for rollouts; defaults to the training device
rollout_num_workers: null # number of parallel rollout workers; null auto-infers on CUDA and keeps 1 on CPU
rollout_cuda_devices: null # logical CUDA device list for parallel rollouts; null defaults to [0]
rollout_response_timeout_s: 300.0 # timeout for rollout workers waiting on an inference server response
rollout_server_startup_timeout_s: 300.0 # timeout while waiting for the inference server to become ready
# Learning-rate scheduler (with warmup)
warmup_steps: 2000 # warmup steps (longer is recommended for Transformers)

View File

@@ -11,6 +11,8 @@ dataset_dir: "roboimi/demos/dataset/sim_transfer"
# ====================
pred_horizon: ${agent.pred_horizon} # number of prediction steps
obs_horizon: ${agent.obs_horizon} # number of observation steps
lewm_history_horizon: ${oc.select:agent.lewm_history_horizon,null}
lewm_query_offsets: ${oc.select:agent.lewm_query_offsets,null}
# ====================
# Camera configuration

View File

@@ -2,6 +2,10 @@
# Evaluation configuration
ckpt_path: "checkpoints/vla_model_best.pt" # model checkpoint path
num_episodes: 3 # number of evaluation episodes
num_workers: 1 # number of parallel workers; 1 keeps single-process evaluation
cuda_devices: null # logical device list for parallel CUDA evaluation; null defaults to [0]
response_timeout_s: 300.0 # timeout (s) for workers waiting on an inference server response
server_startup_timeout_s: 300.0 # timeout (s) for the parent waiting for the inference server to become ready
max_timesteps: 700 # max timesteps per episode
device: ${train.device} # keep consistent with training
task_name: "sim_transfer" # environment task name

View File

@@ -0,0 +1,5 @@
_target_: roboimi.vla.modules.encoders.LeWMStateEncoder
input_dim: ${agent.obs_dim}
hidden_dim: 256
output_dim: 64

View File

@@ -24,6 +24,9 @@ class SimpleRobotDataset(Dataset):
camera_names: List[str] = None,
image_resize_shape: Optional[Sequence[int]] = (224, 224),
max_open_files: int = 64,
lewm_history_horizon: Optional[int] = None,
lewm_query_offsets: Optional[Sequence[int]] = None,
episode_indices: Optional[Sequence[int]] = None,
):
"""
Args:
@@ -42,12 +45,22 @@ class SimpleRobotDataset(Dataset):
self.obs_horizon = obs_horizon
self.pred_horizon = pred_horizon
self.camera_names = camera_names or []
self.lewm_history_horizon = (
int(lewm_history_horizon) if lewm_history_horizon is not None else None
)
self.lewm_query_offsets = (
tuple(int(offset) for offset in lewm_query_offsets)
if lewm_query_offsets is not None else ()
)
self.image_resize_shape = (
tuple(int(v) for v in image_resize_shape)
if image_resize_shape is not None else None
)
self.max_open_files = max(1, int(max_open_files))
self._file_cache: "OrderedDict[str, h5py.File]" = OrderedDict()
self.requested_episode_indices = (
None if episode_indices is None else tuple(sorted(int(idx) for idx in episode_indices))
)
self.dataset_dir = Path(dataset_dir)
if not self.dataset_dir.exists():
@@ -60,20 +73,45 @@ class SimpleRobotDataset(Dataset):
if not self.hdf5_files:
raise FileNotFoundError(f"{dataset_dir} 中未找到 HDF5 文件")
if self.requested_episode_indices is not None:
requested = set(self.requested_episode_indices)
filtered = []
for hdf5_path in self.hdf5_files:
stem = hdf5_path.stem
if stem.startswith("episode_"):
try:
idx = int(stem.split("_")[-1])
except ValueError:
continue
if idx in requested:
filtered.append(hdf5_path)
self.hdf5_files = filtered
if not self.hdf5_files:
raise FileNotFoundError(
f"No HDF5 files matching episode_indices={sorted(requested)} found in {dataset_dir}"
)
# Build the episode index (metadata only; no frame data is loaded)
self.episodes = {}
self.frame_meta = [] # stores (ep_idx, frame_idx, hdf5_path)
for ep_idx, hdf5_path in enumerate(self.hdf5_files):
with h5py.File(hdf5_path, 'r') as f:
T = f['action'].shape[0]
dataset_episode_idx = ep_idx
stem = hdf5_path.stem
if stem.startswith("episode_"):
try:
dataset_episode_idx = int(stem.split("_")[-1])
except ValueError:
pass
start_idx = len(self.frame_meta)
for t in range(T):
self.frame_meta.append({
"ep_idx": ep_idx,
"ep_idx": dataset_episode_idx,
"frame_idx": t,
"hdf5_path": hdf5_path,
})
- self.episodes[ep_idx] = list(range(start_idx, len(self.frame_meta)))
self.episodes[dataset_episode_idx] = list(range(start_idx, len(self.frame_meta)))
print(f"懒加载模式: {len(self.hdf5_files)} 个 episodes, 共 {len(self.frame_meta)}")
@@ -220,6 +258,60 @@ class SimpleRobotDataset(Dataset):
for cam_name in self.camera_names:
result[f"observation.{cam_name}"] = torch.stack(observations[f"observation.{cam_name}"])
if self.lewm_history_horizon is not None and self.lewm_history_horizon > 0:
lewm_observations = {
"state": [],
}
for cam_name in self.camera_names:
lewm_observations[f"observation.{cam_name}"] = []
for delta in range(-self.lewm_history_horizon + 1, 1):
target_idx = idx + delta
if ep_start <= target_idx <= ep_end:
target_frame = self._load_frame(target_idx)
else:
boundary_idx = ep_start if target_idx < ep_start else ep_end
target_frame = self._load_frame(boundary_idx)
lewm_observations["state"].append(target_frame["observation.state"])
for cam_name in self.camera_names:
lewm_observations[f"observation.{cam_name}"].append(
target_frame[f"observation.{cam_name}"]
)
result["lewm.observation.state"] = torch.stack(lewm_observations["state"])
for cam_name in self.camera_names:
result[f"lewm.observation.{cam_name}"] = torch.stack(
lewm_observations[f"observation.{cam_name}"]
)
if self.lewm_query_offsets:
lewm_future = {
"state": [],
}
for cam_name in self.camera_names:
lewm_future[f"observation.{cam_name}"] = []
for offset in self.lewm_query_offsets:
target_idx = idx + offset
if ep_start <= target_idx <= ep_end:
target_frame = self._load_frame(target_idx)
else:
boundary_idx = ep_start if target_idx < ep_start else ep_end
target_frame = self._load_frame(boundary_idx)
lewm_future["state"].append(target_frame["observation.state"])
for cam_name in self.camera_names:
lewm_future[f"observation.{cam_name}"].append(
target_frame[f"observation.{cam_name}"]
)
result["lewm.future.state"] = torch.stack(lewm_future["state"])
for cam_name in self.camera_names:
result[f"lewm.future.{cam_name}"] = torch.stack(
lewm_future[f"observation.{cam_name}"]
)
return result
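
The clamping above reuses the episode-boundary frame whenever a history or future index falls outside the episode; a standalone illustration of the rule:

```python
# With ep_start=0, ep_end=3 and lewm_query_offsets=[1, 2], sampling idx=3
# resolves both future offsets to the last frame of the episode.
ep_start, ep_end, idx = 0, 3, 3
for offset in (1, 2):
    target_idx = idx + offset
    if not (ep_start <= target_idx <= ep_end):
        target_idx = ep_start if target_idx < ep_start else ep_end
    print(offset, "->", target_idx)  # 1 -> 3, 2 -> 3
```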
@property
@@ -227,6 +319,10 @@ class SimpleRobotDataset(Dataset):
"""获取所有相机键名 (LeRobotDataset 格式)"""
return [f"observation.{cam_name}" for cam_name in self.camera_names]
@property
def available_episode_indices(self) -> List[int]:
return sorted(self.episodes.keys())
@property
def camera_info(self) -> dict:
"""获取相机信息"""

View File

@@ -1,5 +1,14 @@
# Backbone models
__all__ = ["LEWMViTBackbone", "ResNetBackbone", "ResNetDiffusionBackbone", "SigLIP2DiffusionBackbone"]
__all__ = [
"LEWMViTBackbone",
"LeWMMultiViewResNetBackbone",
"QueryTokenPredictor",
"LeWMProjectorMLP",
"SIGReg",
"ResNetBackbone",
"ResNetDiffusionBackbone",
"SigLIP2DiffusionBackbone",
]
def __getattr__(name):
@@ -9,6 +18,19 @@ def __getattr__(name):
if name == "SigLIP2DiffusionBackbone":
from .siglip2_diffusion_backbone import SigLIP2DiffusionBackbone
return SigLIP2DiffusionBackbone
if name in {"LeWMMultiViewResNetBackbone", "QueryTokenPredictor", "LeWMProjectorMLP", "SIGReg"}:
from .lewm_resnet_query_fusion import (
LeWMMultiViewResNetBackbone,
QueryTokenPredictor,
LeWMProjectorMLP,
SIGReg,
)
return {
"LeWMMultiViewResNetBackbone": LeWMMultiViewResNetBackbone,
"QueryTokenPredictor": QueryTokenPredictor,
"LeWMProjectorMLP": LeWMProjectorMLP,
"SIGReg": SIGReg,
}[name]
if name in {"ResNetBackbone", "ResNetDiffusionBackbone"}:
from .resnet_diffusion import ResNetDiffusionBackbone
return ResNetDiffusionBackbone
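
This module-level `__getattr__` (PEP 562) defers the heavyweight imports until first attribute access. A sketch, assuming the package path matches the `_target_` strings used in the configs above:

```python
# Nothing from lewm_resnet_query_fusion is imported at this point...
from roboimi.vla.models import backbones

# ...until the first attribute access triggers the module-level __getattr__:
sigreg_cls = backbones.SIGReg
```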

View File

@@ -0,0 +1,409 @@
from __future__ import annotations
from pathlib import Path
from typing import Any, Dict, Mapping, Optional, Sequence
import torch
from einops import rearrange
from torch import nn
import torch.nn.functional as F
from torchvision import models
from roboimi.vla.core.interfaces import VLABackbone
class SpatialSoftmax2D(nn.Module):
"""Convert a feature map into expected 2D keypoint coordinates per channel."""
def forward(self, feature_map):
if feature_map.ndim != 4:
raise ValueError(
f"SpatialSoftmax2D expects a 4D tensor, got rank {feature_map.ndim}"
)
batch, channels, height, width = feature_map.shape
scores = feature_map.reshape(batch, channels, height * width)
attention = F.softmax(scores, dim=-1)
ys = torch.linspace(-1.0, 1.0, height, device=feature_map.device, dtype=feature_map.dtype)
xs = torch.linspace(-1.0, 1.0, width, device=feature_map.device, dtype=feature_map.dtype)
grid_y, grid_x = torch.meshgrid(ys, xs, indexing="ij")
grid_x = grid_x.reshape(1, 1, height * width)
grid_y = grid_y.reshape(1, 1, height * width)
expected_x = (attention * grid_x).sum(dim=-1)
expected_y = (attention * grid_y).sum(dim=-1)
return torch.cat([expected_x, expected_y], dim=-1)
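
A quick shape check (a sketch, not from the source): each channel contributes one expected `(x, y)` pair, so a `(B, C, H, W)` feature map becomes `(B, 2*C)` coordinates in `[-1, 1]`:

```python
import torch

layer = SpatialSoftmax2D()
coords = layer(torch.randn(4, 48, 7, 7))
assert coords.shape == (4, 96)  # 2 * 48; with the 1x1 proj to 48 channels this yields view_feature_dim=96
```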
class ResNet18SpatialEncoder(nn.Module):
"""Encode one camera view into a fixed-dimensional spatial-softmax embedding."""
def __init__(self, view_feature_dim=96):
super().__init__()
if view_feature_dim % 2 != 0:
raise ValueError("view_feature_dim must be even for spatial softmax features")
backbone = models.resnet18(weights=None)
if all(
hasattr(backbone, name)
for name in ("conv1", "bn1", "relu", "maxpool", "layer1", "layer2", "layer3", "layer4")
):
self.backbone = nn.Sequential(
backbone.conv1,
backbone.bn1,
backbone.relu,
backbone.maxpool,
backbone.layer1,
backbone.layer2,
backbone.layer3,
backbone.layer4,
)
feature_channels = 512
else:
children = list(backbone.children())
if len(children) < 1:
raise ValueError("resnet18 backbone must expose child modules")
truncated = children[:-2] if len(children) > 2 else children
self.backbone = nn.Sequential(*truncated)
with torch.no_grad():
dummy = torch.zeros(1, 3, 16, 16)
feature_channels = int(self.backbone(dummy).shape[1])
self.proj = nn.Conv2d(feature_channels, view_feature_dim // 2, kernel_size=1)
self.spatial_softmax = SpatialSoftmax2D()
self.output_dim = int(view_feature_dim)
def forward(self, pixels):
if pixels.ndim not in (4, 5):
raise ValueError(
f"ResNet18SpatialEncoder expects a 4D or 5D tensor, got rank {pixels.ndim}"
)
needs_unflatten = pixels.ndim == 5
if needs_unflatten:
batch, steps, channels, height, width = pixels.shape
pixels = rearrange(pixels, "b t c h w -> (b t) c h w")
features = self.backbone(pixels.float())
features = self.proj(features)
embeddings = self.spatial_softmax(features)
if needs_unflatten:
embeddings = rearrange(embeddings, "(b t) d -> b t d", b=batch, t=steps)
return embeddings
class LeWMMultiViewResNetBackbone(VLABackbone):
"""RoboIMI-side LeWM multiview ResNet spatial-softmax encoder."""
def __init__(
self,
view_feature_dim: int = 96,
num_views: int = 3,
view_encoder_mode: str = "shared",
camera_names: Sequence[str] = ("r_vis", "top", "front"),
checkpoint_path: str | Path | None = None,
) -> None:
super().__init__()
if view_encoder_mode not in {"shared", "separate"}:
raise ValueError(
f"view_encoder_mode must be 'shared' or 'separate', got {view_encoder_mode}"
)
self.view_feature_dim = int(view_feature_dim)
self.num_views = int(num_views)
self.view_encoder_mode = view_encoder_mode
self.camera_names = tuple(camera_names)
if len(self.camera_names) != self.num_views:
raise ValueError(
f"camera_names length({len(self.camera_names)}) must equal num_views({self.num_views})"
)
self.output_dim = self.view_feature_dim * self.num_views
self.joint_output_dim = self.output_dim
self.tokens_per_step = 1
if view_encoder_mode == "shared":
self.single_view_encoder = ResNet18SpatialEncoder(
view_feature_dim=view_feature_dim
)
self.view_encoders = None
else:
self.single_view_encoder = None
self.view_encoders = nn.ModuleList(
[
ResNet18SpatialEncoder(view_feature_dim=view_feature_dim)
for _ in range(num_views)
]
)
if checkpoint_path is not None:
self.load_lewm_checkpoint(checkpoint_path)
@staticmethod
def _unwrap_state_dict(payload: Mapping[str, Any]) -> Mapping[str, torch.Tensor]:
state_dict = payload.get("state_dict", payload)
if not isinstance(state_dict, Mapping):
raise TypeError("checkpoint payload must contain a mapping state_dict")
return state_dict
@staticmethod
def _extract_prefixed_state_dict(
state_dict: Mapping[str, torch.Tensor],
prefix: str,
) -> Dict[str, torch.Tensor]:
extracted = {
key[len(prefix):]: value
for key, value in state_dict.items()
if key.startswith(prefix)
}
if not extracted:
raise KeyError(f"checkpoint missing parameters with prefix {prefix!r}")
return extracted
def load_lewm_checkpoint(self, checkpoint_or_path: str | Path | Mapping[str, Any]) -> None:
if isinstance(checkpoint_or_path, (str, Path)):
payload = torch.load(Path(checkpoint_or_path), map_location="cpu", weights_only=False)
else:
payload = checkpoint_or_path
state_dict = self._unwrap_state_dict(payload)
encoder_state_dict = self._extract_prefixed_state_dict(state_dict, "model.encoder.")
self.load_state_dict(encoder_state_dict, strict=True)
def forward(self, images):
missing = [camera_name for camera_name in self.camera_names if camera_name not in images]
if missing:
raise ValueError(
f"image input missing required cameras. missing={missing}, expected={list(self.camera_names)}"
)
first_image = images[self.camera_names[0]]
batch_size, steps = first_image.shape[:2]
view_embeddings = []
if self.view_encoder_mode == "shared":
for camera_name in self.camera_names:
view_embeddings.append(self.single_view_encoder(images[camera_name]))
else:
for single_view_encoder, camera_name in zip(self.view_encoders, self.camera_names):
view_embeddings.append(single_view_encoder(images[camera_name]))
embeddings = torch.cat(view_embeddings, dim=-1)
return embeddings.reshape(batch_size, steps, self.output_dim)
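
Putting the pieces together, a hedged shape sketch for the default three-view setup (assuming `VLABackbone` behaves like a plain `nn.Module` here):

```python
import torch

backbone = LeWMMultiViewResNetBackbone(
    view_feature_dim=96,
    num_views=3,
    view_encoder_mode="separate",
    camera_names=("r_vis", "top", "front"),
)
images = {cam: torch.randn(2, 2, 3, 224, 224) for cam in ("r_vis", "top", "front")}
features = backbone(images)
assert features.shape == (2, 2, 288)  # num_views * view_feature_dim per time step
```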
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout=0.0):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout),
)
def forward(self, x):
return self.net(x)
class Attention(nn.Module):
def __init__(self, dim, heads=8, dim_head=64, dropout=0.0):
super().__init__()
inner_dim = dim_head * heads
project_out = not (heads == 1 and dim_head == dim)
self.heads = heads
self.dropout = dropout
self.norm = nn.LayerNorm(dim)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
self.to_out = (
nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
if project_out
else nn.Identity()
)
def forward(self, x, causal=True):
x = self.norm(x)
drop = self.dropout if self.training else 0.0
qkv = self.to_qkv(x).chunk(3, dim=-1)
q, k, v = (rearrange(t, "b t (h d) -> b h t d", h=self.heads) for t in qkv)
out = F.scaled_dot_product_attention(q, k, v, dropout_p=drop, is_causal=causal)
out = rearrange(out, "b h t d -> b t (h d)")
return self.to_out(out)
class Block(nn.Module):
def __init__(self, dim, heads, dim_head, mlp_dim, dropout=0.0):
super().__init__()
self.attn = Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)
self.mlp = FeedForward(dim, mlp_dim, dropout=dropout)
self.norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
def forward(self, x):
x = x + self.attn(self.norm1(x))
x = x + self.mlp(self.norm2(x))
return x
class Transformer(nn.Module):
def __init__(
self,
input_dim,
hidden_dim,
output_dim,
depth,
heads,
dim_head,
mlp_dim,
dropout=0.0,
block_class=Block,
):
super().__init__()
self.norm = nn.LayerNorm(hidden_dim)
self.layers = nn.ModuleList([])
self.input_proj = (
nn.Linear(input_dim, hidden_dim)
if input_dim != hidden_dim
else nn.Identity()
)
self.cond_proj = (
nn.Linear(input_dim, hidden_dim)
if input_dim != hidden_dim
else nn.Identity()
)
self.output_proj = (
nn.Linear(hidden_dim, output_dim)
if hidden_dim != output_dim
else nn.Identity()
)
for _ in range(depth):
self.layers.append(block_class(hidden_dim, heads, dim_head, mlp_dim, dropout))
def forward(self, x, c=None):
x = self.input_proj(x)
if c is not None:
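# NOTE: cond tokens are projected here for interface parity but are not consumed by the blocks below.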
c = self.cond_proj(c)
for block in self.layers:
x = block(x)
x = self.norm(x)
return self.output_proj(x)
class QueryTokenPredictor(nn.Module):
"""History-only transformer predictor that decodes learned query tokens."""
def __init__(
self,
*,
num_frames,
query_offsets,
depth,
heads,
mlp_dim,
input_dim,
hidden_dim,
output_dim=None,
dim_head=64,
dropout=0.0,
emb_dropout=0.0,
):
super().__init__()
if num_frames <= 0:
raise ValueError(f"num_frames must be positive, got {num_frames}")
query_offsets = tuple(query_offsets)
if not query_offsets:
raise ValueError("query_offsets must contain at least one offset")
if any(offset <= 0 for offset in query_offsets):
raise ValueError(f"query_offsets must be positive, got {query_offsets}")
self.num_frames = int(num_frames)
self.query_offsets = query_offsets
self.num_query_tokens = len(query_offsets)
self.pos_embedding = nn.Parameter(
torch.randn(1, self.num_frames + self.num_query_tokens, input_dim)
)
self.query_tokens = nn.Parameter(
torch.randn(1, self.num_query_tokens, input_dim)
)
self.dropout = nn.Dropout(emb_dropout)
self.transformer = Transformer(
input_dim,
hidden_dim,
output_dim or input_dim,
depth,
heads,
dim_head,
mlp_dim,
dropout,
block_class=Block,
)
def forward(self, x):
if x.ndim != 3:
raise ValueError(
f"QueryTokenPredictor expects a 3D tensor, got rank {x.ndim}"
)
T = x.size(1)
if T > self.num_frames:
raise ValueError(
f"input sequence length {T} exceeds configured num_frames {self.num_frames}"
)
query_tokens = self.query_tokens.expand(x.size(0), -1, -1)
tokens = torch.cat([x, query_tokens], dim=1)
tokens = tokens + self.pos_embedding[:, : tokens.size(1)]
tokens = self.dropout(tokens)
tokens = self.transformer(tokens)
return tokens[:, -self.num_query_tokens :]
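
A small usage sketch (hyperparameters chosen only for illustration): three history tokens go in, and one query token comes out per configured offset:

```python
import torch

predictor = QueryTokenPredictor(
    num_frames=3, query_offsets=[8], depth=1, heads=4, mlp_dim=64,
    input_dim=288, hidden_dim=288,
)
out = predictor(torch.randn(2, 3, 288))
assert out.shape == (2, 1, 288)  # (batch, len(query_offsets), output_dim)
```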
class LeWMProjectorMLP(nn.Module):
def __init__(
self,
input_dim: int = 288,
hidden_dim: int = 2048,
output_dim: int = 288,
) -> None:
super().__init__()
self.output_dim = int(output_dim)
self.net = nn.Sequential(
nn.Linear(int(input_dim), int(hidden_dim)),
nn.BatchNorm1d(int(hidden_dim)),
nn.GELU(),
nn.Linear(int(hidden_dim), self.output_dim),
)
def forward(self, x):
return self.net(x)
class SIGReg(nn.Module):
"""Sketch Isotropic Gaussian Regularizer, matching the original LeWM design."""
def __init__(self, knots: int = 17, num_proj: int = 1024) -> None:
super().__init__()
self.num_proj = int(num_proj)
t = torch.linspace(0, 3, int(knots), dtype=torch.float32)
dt = 3 / (int(knots) - 1)
weights = torch.full((int(knots),), 2 * dt, dtype=torch.float32)
weights[[0, -1]] = dt
window = torch.exp(-t.square() / 2.0)
self.register_buffer("t", t)
self.register_buffer("phi", window)
self.register_buffer("weights", weights * window)
def forward(self, proj: torch.Tensor) -> torch.Tensor:
"""
proj: (T, B, D)
"""
A = torch.randn(proj.size(-1), self.num_proj, device=proj.device)
A = A.div_(A.norm(p=2, dim=0))
x_t = (proj @ A).unsqueeze(-1) * self.t
err = (x_t.cos().mean(-3) - self.phi).square() + x_t.sin().mean(-3).square()
statistic = (err @ self.weights) * proj.size(-2)
return statistic.mean()
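
In math terms (a reading aid inferred from the code above, not a formula from the source): with unit-norm random directions $a_j$, batch size $B$, and $\varphi(t)=e^{-t^2/2}$ the characteristic function of $\mathcal{N}(0,1)$, each projection is scored by

```latex
\mathrm{err}_j(t_k) =
  \Big(\tfrac{1}{B}\textstyle\sum_{i=1}^{B}\cos\!\big(t_k\, x_i^{\top} a_j\big) - \varphi(t_k)\Big)^{2}
  + \Big(\tfrac{1}{B}\textstyle\sum_{i=1}^{B}\sin\!\big(t_k\, x_i^{\top} a_j\big)\Big)^{2},
\qquad
\mathrm{stat} = B \cdot \operatorname*{mean}_{j,\ \mathrm{steps}} \sum_{k} w_k\, \varphi(t_k)\, \mathrm{err}_j(t_k)
```

i.e., a Gaussian-windowed test that the projected embeddings have the characteristic function of a standard normal (the `weights` buffer already folds the quadrature weights and the window together).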

View File

@@ -16,3 +16,23 @@ class IdentityActionEncoder(nn.Module):
def forward(self, action):
return action
class LeWMStateEncoder(nn.Module):
def __init__(
self,
input_dim: int = 16,
hidden_dim: int = 256,
output_dim: int = 64,
):
super().__init__()
self.output_dim = int(output_dim)
self.net = nn.Sequential(
nn.Linear(int(input_dim), int(hidden_dim)),
nn.LayerNorm(int(hidden_dim)),
nn.GELU(),
nn.Linear(int(hidden_dim), self.output_dim),
)
def forward(self, state):
return self.net(state)
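
A one-line shape check (sketch): the encoder maps per-frame qpos features to compact state tokens:

```python
import torch

enc = LeWMStateEncoder(input_dim=16, hidden_dim=256, output_dim=64)
assert enc(torch.randn(2, 3, 16)).shape == (2, 3, 64)
```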

View File

@@ -1,5 +1,11 @@
import unittest
from unittest import mock
import numpy as np
import torch
from omegaconf import OmegaConf
from roboimi.demos.vla_scripts import eval_vla
from roboimi.vla.eval_utils import execute_policy_action
@@ -14,6 +20,48 @@ class _FakeEnv:
self.calls.append(("step_jnt", action))
class _FakeQueue:
def __init__(self, initial_items=None):
self.items = list(initial_items or [])
self.put_calls = []
def put(self, item):
self.put_calls.append(item)
self.items.append(item)
def get(self, timeout=None):
del timeout
if not self.items:
raise AssertionError("queue unexpectedly empty")
return self.items.pop(0)
def _make_parallel_cfg(**eval_overrides):
eval_cfg = {
"ckpt_path": "checkpoints/vla_model_best.pt",
"num_episodes": 5,
"num_workers": 2,
"max_timesteps": 1,
"device": "cpu",
"task_name": "sim_transfer",
"camera_names": ["front"],
"use_smoothing": False,
"smooth_alpha": 0.3,
"verbose_action": False,
"headless": True,
"artifact_dir": None,
"save_artifacts": False,
"save_summary_json": False,
"save_timing": False,
"save_trajectory": False,
"save_trajectory_npz": False,
"record_video": False,
"save_trajectory_image": False,
}
eval_cfg.update(eval_overrides)
return OmegaConf.create({"agent": {}, "eval": eval_cfg})
class EvalVLAExecutionTest(unittest.TestCase):
def test_execute_policy_action_uses_ee_step(self):
env = _FakeEnv()
@@ -23,6 +71,662 @@ class EvalVLAExecutionTest(unittest.TestCase):
self.assertEqual(env.calls, [("step", action)])
def test_split_episode_indices_balances_workers(self):
self.assertEqual(
eval_vla._split_episode_indices(num_episodes=10, num_workers=3),
[[0, 1, 2, 3], [4, 5, 6], [7, 8, 9]],
)
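
For reference, a hypothetical implementation consistent with this balancing expectation (the real `_split_episode_indices` lives in `eval_vla`; this only spells out the rule the test encodes):

```python
def split_episode_indices(num_episodes: int, num_workers: int) -> list:
    # Earlier workers absorb the remainder, one extra episode each.
    base, extra = divmod(num_episodes, num_workers)
    splits, start = [], 0
    for worker in range(num_workers):
        size = base + (1 if worker < extra else 0)
        splits.append(list(range(start, start + size)))
        start += size
    return splits

assert split_episode_indices(10, 3) == [[0, 1, 2, 3], [4, 5, 6], [7, 8, 9]]
```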
def test_normalize_num_workers_caps_worker_count_to_episode_count(self):
self.assertEqual(eval_vla._normalize_num_workers(num_workers=5, num_episodes=2), 2)
def test_plan_episode_box_poses_uses_global_episode_order(self):
planned_poses = [
np.array([0.1, 0.2, 0.3], dtype=np.float32),
np.array([1.1, 1.2, 1.3], dtype=np.float32),
np.array([2.1, 2.2, 2.3], dtype=np.float32),
]
sampler = mock.Mock(side_effect=planned_poses)
result = eval_vla._plan_episode_box_poses(num_episodes=3, sampler=sampler)
self.assertEqual(sampler.call_count, 3)
self.assertEqual(len(result), 3)
for expected, actual in zip(planned_poses, result):
np.testing.assert_array_equal(actual, expected)
def test_resolve_policy_camera_names_matches_vlaagent_fallback_sorting(self):
cfg = OmegaConf.create(
{
"agent": {
"_target_": "roboimi.vla.agent.VLAAgent",
},
"eval": {
"camera_names": ["r_vis", "top", "front"],
},
}
)
self.assertEqual(
eval_vla._resolve_policy_camera_names(cfg),
["front", "r_vis", "top"],
)
def test_resolve_policy_camera_names_matches_gr00t_fallback_input_order(self):
cfg = OmegaConf.create(
{
"agent": {
"_target_": "roboimi.vla.agent_gr00t_dit.VLAAgentGr00tDiT",
},
"eval": {
"camera_names": ["r_vis", "top", "front"],
},
}
)
self.assertEqual(
eval_vla._resolve_policy_camera_names(cfg),
["r_vis", "top", "front"],
)
def test_build_episode_plans_without_box_poses_keeps_serial_sampling_lazy(self):
plans = eval_vla._build_episode_plans(num_episodes=3)
self.assertEqual(
plans,
[
{"episode_index": 0},
{"episode_index": 1},
{"episode_index": 2},
],
)
def test_prepare_local_policy_batch_pads_latest_observation_to_obs_horizon(self):
queues = eval_vla._new_local_policy_queues(obs_horizon=3)
observation = {
"qpos": torch.tensor([1.0, 2.0], dtype=torch.float32),
"images": {
"front": torch.tensor([[[1.0]]], dtype=torch.float32),
},
}
eval_vla._populate_local_policy_queues(queues, observation)
batch = eval_vla._prepare_local_policy_batch(
queues,
obs_horizon=3,
camera_names=["front"],
)
self.assertEqual(tuple(batch["qpos"].shape), (1, 3, 2))
self.assertEqual(tuple(batch["images"]["front"].shape), (1, 3, 1, 1, 1))
np.testing.assert_array_equal(
batch["qpos"][0].cpu().numpy(),
np.array([[1.0, 2.0], [1.0, 2.0], [1.0, 2.0]], dtype=np.float32),
)
np.testing.assert_array_equal(
batch["images"]["front"][0].cpu().numpy(),
np.array([[[[1.0]]], [[[1.0]]], [[[1.0]]]], dtype=np.float32),
)
def test_enqueue_predicted_actions_uses_executable_slice(self):
queues = eval_vla._new_local_policy_queues(obs_horizon=2)
predicted_actions = torch.tensor(
[[[10.0], [20.0], [30.0], [40.0]]],
dtype=torch.float32,
)
eval_vla._enqueue_predicted_actions(
queues,
predicted_actions=predicted_actions,
obs_horizon=2,
num_action_steps=2,
)
self.assertEqual(len(queues["action"]), 2)
np.testing.assert_array_equal(queues["action"].popleft().numpy(), np.array([20.0], dtype=np.float32))
np.testing.assert_array_equal(queues["action"].popleft().numpy(), np.array([30.0], dtype=np.float32))
def test_remote_policy_runner_only_requests_server_inference_when_local_action_queue_is_empty(self):
request_queue = _FakeQueue()
response_queue = _FakeQueue(
[
{
"type": "predict_chunk_result",
"actions": np.asarray([[[10.0], [20.0], [30.0]]], dtype=np.float32),
}
]
)
runner = eval_vla._RemotePolicyRunner(
worker_index=3,
server_index=1,
request_queue=request_queue,
response_queue=response_queue,
camera_names=["front"],
obs_horizon=2,
num_action_steps=2,
)
first_observation = {
"qpos": torch.tensor([1.0, 2.0], dtype=torch.float32),
"images": {"front": torch.tensor([[[1.0]]], dtype=torch.float32)},
}
second_observation = {
"qpos": torch.tensor([3.0, 4.0], dtype=torch.float32),
"images": {"front": torch.tensor([[[2.0]]], dtype=torch.float32)},
}
first_action, first_forward = runner.select_action(
first_observation,
episode_index=7,
timestep=0,
)
second_action, second_forward = runner.select_action(
second_observation,
episode_index=7,
timestep=1,
)
self.assertTrue(first_forward)
self.assertFalse(second_forward)
self.assertEqual(len(request_queue.put_calls), 1)
self.assertEqual(request_queue.put_calls[0]["type"], "predict_chunk")
self.assertEqual(request_queue.put_calls[0]["worker_index"], 3)
self.assertEqual(request_queue.put_calls[0]["server_index"], 1)
np.testing.assert_array_equal(first_action.numpy(), np.array([20.0], dtype=np.float32))
np.testing.assert_array_equal(second_action.numpy(), np.array([30.0], dtype=np.float32))
def test_merge_worker_summaries_sorts_episodes_and_recomputes_aggregates(self):
worker_summaries = [
{
"avg_inference_fps": 999.0,
"avg_control_fps": 999.0,
"avg_obs_read_time_ms": 999.0,
"avg_total_time_ms": 999.0,
"timing_summary": {"count": 999, "model_forward_count": 999},
"episodes": [
{
"episode_index": 2,
"episode_reward": 9.0,
"episode_max_reward": 4.0,
"inference_fps": 30.0,
"control_fps": 15.0,
}
],
"_merge_state": {
"obs_read_time_ms": [9.0],
"preprocess_time_ms": [1.0],
"inference_time_ms": [3.0],
"env_step_time_ms": [4.0],
"total_time_ms": [10.0],
"model_forward_flags": [False],
},
},
{
"avg_inference_fps": 888.0,
"avg_control_fps": 888.0,
"avg_obs_read_time_ms": 888.0,
"avg_total_time_ms": 888.0,
"timing_summary": {"count": 888, "model_forward_count": 888},
"episodes": [
{
"episode_index": 1,
"episode_reward": 6.0,
"episode_max_reward": 3.0,
"inference_fps": 20.0,
"control_fps": 10.0,
},
{
"episode_index": 0,
"episode_reward": 5.0,
"episode_max_reward": 2.0,
"inference_fps": 10.0,
"control_fps": 5.0,
},
],
"_merge_state": {
"obs_read_time_ms": [1.0, 2.0, 12.0],
"preprocess_time_ms": [2.0, 3.0, 4.0],
"inference_time_ms": [4.0, 5.0, 6.0],
"env_step_time_ms": [6.0, 7.0, 8.0],
"total_time_ms": [8.0, 9.0, 20.0],
"model_forward_flags": [True, False, True],
},
},
]
artifact_paths = {
"output_dir": "/tmp/merged",
"summary_json": "/tmp/merged/rollout_summary.json",
"timing_json": "/tmp/merged/timing.json",
"trajectory_npz": None,
"video_mp4": None,
"video_camera_name": None,
}
merged = eval_vla._merge_worker_summaries(worker_summaries, artifact_paths)
self.assertEqual([episode["episode_index"] for episode in merged["episodes"]], [0, 1, 2])
self.assertEqual(merged["episode_rewards"], [5.0, 6.0, 9.0])
self.assertEqual(merged["episode_max_rewards"], [2.0, 3.0, 4.0])
self.assertAlmostEqual(merged["avg_reward"], 20.0 / 3.0)
self.assertAlmostEqual(merged["avg_max_reward"], 3.0)
self.assertAlmostEqual(merged["avg_inference_fps"], 20.0)
self.assertAlmostEqual(merged["avg_control_fps"], 10.0)
self.assertAlmostEqual(merged["avg_obs_read_time_ms"], 6.0)
self.assertAlmostEqual(merged["avg_total_time_ms"], 47.0 / 4.0)
self.assertEqual(merged["timing_summary"]["count"], 4)
self.assertEqual(merged["timing_summary"]["model_forward_count"], 2)
self.assertEqual(merged["artifact_dir"], "/tmp/merged")
self.assertEqual(merged["artifacts"], artifact_paths)
def test_build_cuda_server_payloads_uses_round_robin_worker_assignment(self):
cfg = _make_parallel_cfg(num_episodes=4, num_workers=4, device="cuda", cuda_devices=[0, 1])
artifact_paths = {"output_dir": None}
with mock.patch.object(
eval_vla,
"sample_transfer_pose",
side_effect=[
np.array([0.1, 0.2, 0.3], dtype=np.float32),
np.array([0.4, 0.5, 0.6], dtype=np.float32),
np.array([0.7, 0.8, 0.9], dtype=np.float32),
np.array([1.0, 1.1, 1.2], dtype=np.float32),
],
):
worker_payloads, _ = eval_vla._build_parallel_worker_payloads(cfg, artifact_paths)
server_payloads, assigned_workers = eval_vla._build_cuda_server_payloads(
cfg,
worker_payloads=worker_payloads,
cuda_devices=[0, 1],
)
self.assertEqual([payload["device_index"] for payload in server_payloads], [0, 1])
self.assertEqual([payload["worker_index"] for payload in assigned_workers], [0, 1, 2, 3])
self.assertEqual([payload["server_index"] for payload in assigned_workers], [0, 1, 0, 1])
self.assertEqual(server_payloads[0]["worker_indices"], [0, 2])
self.assertEqual(server_payloads[1]["worker_indices"], [1, 3])
def test_run_eval_parallel_dispatches_episode_splits_and_box_poses(self):
cfg = _make_parallel_cfg(num_episodes=5, num_workers=2, artifact_dir="/tmp/parallel-root")
planned_poses = [
np.array([float(index), float(index) + 0.1, float(index) + 0.2], dtype=np.float32)
for index in range(5)
]
observed_payloads = []
def fake_run_spawn_jobs(payloads, max_workers, worker_fn):
del worker_fn
self.assertEqual(max_workers, 2)
observed_payloads.extend(payloads)
return [
{
"episodes": [
{
"episode_index": 4,
"episode_reward": 5.0,
"episode_max_reward": 5.0,
"inference_fps": 50.0,
"control_fps": 25.0,
},
{
"episode_index": 3,
"episode_reward": 4.0,
"episode_max_reward": 4.0,
"inference_fps": 40.0,
"control_fps": 20.0,
},
],
"_merge_state": {
"obs_read_time_ms": [4.0, 5.0],
"preprocess_time_ms": [1.0, 1.0],
"inference_time_ms": [2.0, 2.0],
"env_step_time_ms": [3.0, 3.0],
"total_time_ms": [4.0, 5.0],
"model_forward_flags": [True, True],
},
},
{
"episodes": [
{
"episode_index": 2,
"episode_reward": 3.0,
"episode_max_reward": 3.0,
"inference_fps": 30.0,
"control_fps": 15.0,
},
{
"episode_index": 1,
"episode_reward": 2.0,
"episode_max_reward": 2.0,
"inference_fps": 20.0,
"control_fps": 10.0,
},
{
"episode_index": 0,
"episode_reward": 1.0,
"episode_max_reward": 1.0,
"inference_fps": 10.0,
"control_fps": 5.0,
},
],
"_merge_state": {
"obs_read_time_ms": [1.0, 2.0, 3.0],
"preprocess_time_ms": [1.0, 1.0, 1.0],
"inference_time_ms": [2.0, 2.0, 2.0],
"env_step_time_ms": [3.0, 3.0, 3.0],
"total_time_ms": [1.0, 2.0, 3.0],
"model_forward_flags": [False, True, False],
},
},
]
with mock.patch.object(
eval_vla,
"sample_transfer_pose",
side_effect=planned_poses,
), mock.patch.object(
eval_vla,
"_run_spawn_jobs",
side_effect=fake_run_spawn_jobs,
):
summary = eval_vla._run_eval_parallel(cfg)
self.assertEqual(len(observed_payloads), 2)
self.assertEqual(
[[plan["episode_index"] for plan in payload["episode_plans"]] for payload in observed_payloads],
[[0, 1, 2], [3, 4]],
)
for payload in observed_payloads:
for plan in payload["episode_plans"]:
np.testing.assert_array_equal(
np.asarray(plan["box_pos"], dtype=np.float32),
planned_poses[plan["episode_index"]],
)
self.assertEqual([episode["episode_index"] for episode in summary["episodes"]], [0, 1, 2, 3, 4])
self.assertEqual(summary["episode_rewards"], [1.0, 2.0, 3.0, 4.0, 5.0])
self.assertEqual(summary["num_episodes"], 5)
def test_run_eval_parallel_allows_trajectory_images_and_keeps_worker_artifact_paths(self):
cfg = _make_parallel_cfg(
num_episodes=2,
num_workers=2,
artifact_dir="/tmp/parallel-images",
save_summary_json=True,
save_trajectory_image=True,
)
observed_payloads = []
def fake_run_spawn_jobs(payloads, max_workers, worker_fn):
del worker_fn
self.assertEqual(max_workers, 2)
observed_payloads.extend(payloads)
return [
{
"episodes": [
{
"episode_index": 0,
"episode_reward": 1.0,
"episode_max_reward": 1.0,
"inference_fps": 10.0,
"control_fps": 5.0,
"artifact_paths": {
"trajectory_image": f"{payloads[0]['artifact_dir']}/rollout_front_ep01_trajectory.png",
},
},
],
"_merge_state": {
"obs_read_time_ms": [1.0],
"preprocess_time_ms": [1.0],
"inference_time_ms": [1.0],
"env_step_time_ms": [1.0],
"total_time_ms": [1.0],
"model_forward_flags": [True],
},
},
{
"episodes": [
{
"episode_index": 1,
"episode_reward": 2.0,
"episode_max_reward": 2.0,
"inference_fps": 20.0,
"control_fps": 10.0,
"artifact_paths": {
"trajectory_image": f"{payloads[1]['artifact_dir']}/rollout_front_ep02_trajectory.png",
},
},
],
"_merge_state": {
"obs_read_time_ms": [2.0],
"preprocess_time_ms": [2.0],
"inference_time_ms": [2.0],
"env_step_time_ms": [2.0],
"total_time_ms": [2.0],
"model_forward_flags": [False],
},
},
]
with mock.patch.object(
eval_vla,
"sample_transfer_pose",
side_effect=[
np.array([0.1, 0.2, 0.3], dtype=np.float32),
np.array([0.4, 0.5, 0.6], dtype=np.float32),
],
), mock.patch.object(
eval_vla,
"_run_spawn_jobs",
side_effect=fake_run_spawn_jobs,
):
summary = eval_vla._run_eval_parallel(cfg)
self.assertEqual(len(observed_payloads), 2)
self.assertTrue(observed_payloads[0]["artifact_dir"].endswith("workers/worker_00"))
self.assertTrue(observed_payloads[1]["artifact_dir"].endswith("workers/worker_01"))
self.assertTrue(
summary["episodes"][0]["artifact_paths"]["trajectory_image"].endswith(
"workers/worker_00/rollout_front_ep01_trajectory.png"
)
)
self.assertTrue(
summary["episodes"][1]["artifact_paths"]["trajectory_image"].endswith(
"workers/worker_01/rollout_front_ep02_trajectory.png"
)
)
def test_run_eval_parallel_surfaces_worker_failures(self):
cfg = _make_parallel_cfg(num_episodes=2, num_workers=2)
with mock.patch.object(
eval_vla,
"sample_transfer_pose",
side_effect=[
np.array([0.1, 0.2, 0.3], dtype=np.float32),
np.array([0.4, 0.5, 0.6], dtype=np.float32),
],
), mock.patch.object(
eval_vla,
"_run_spawn_jobs",
side_effect=RuntimeError("boom"),
):
with self.assertRaisesRegex(RuntimeError, "Parallel rollout worker failed"):
eval_vla._run_eval_parallel(cfg)
def test_run_eval_parallel_cuda_builds_server_payloads_and_merges_worker_results(self):
cfg = _make_parallel_cfg(
num_episodes=4,
num_workers=4,
device="cuda",
cuda_devices=[0],
artifact_dir="/tmp/cuda-root",
)
observed_server_payloads = []
observed_worker_payloads = []
def fake_run_cuda_parallel_processes(server_payloads, worker_payloads):
observed_server_payloads.extend(server_payloads)
observed_worker_payloads.extend(worker_payloads)
return [
{
"episodes": [
{
"episode_index": 2,
"episode_reward": 3.0,
"episode_max_reward": 3.0,
"inference_fps": 30.0,
"control_fps": 15.0,
},
{
"episode_index": 0,
"episode_reward": 1.0,
"episode_max_reward": 1.0,
"inference_fps": 10.0,
"control_fps": 5.0,
},
],
"_merge_state": {
"obs_read_time_ms": [1.0, 2.0],
"preprocess_time_ms": [1.0, 1.0],
"inference_time_ms": [2.0, 2.0],
"env_step_time_ms": [3.0, 3.0],
"total_time_ms": [4.0, 4.0],
"model_forward_flags": [True, False],
},
},
{
"episodes": [
{
"episode_index": 3,
"episode_reward": 4.0,
"episode_max_reward": 4.0,
"inference_fps": 40.0,
"control_fps": 20.0,
},
{
"episode_index": 1,
"episode_reward": 2.0,
"episode_max_reward": 2.0,
"inference_fps": 20.0,
"control_fps": 10.0,
},
],
"_merge_state": {
"obs_read_time_ms": [3.0, 4.0],
"preprocess_time_ms": [1.0, 1.0],
"inference_time_ms": [2.0, 2.0],
"env_step_time_ms": [3.0, 3.0],
"total_time_ms": [4.0, 4.0],
"model_forward_flags": [True, True],
},
},
]
with mock.patch.object(
eval_vla,
"sample_transfer_pose",
side_effect=[
np.array([0.1, 0.2, 0.3], dtype=np.float32),
np.array([0.4, 0.5, 0.6], dtype=np.float32),
np.array([0.7, 0.8, 0.9], dtype=np.float32),
np.array([1.0, 1.1, 1.2], dtype=np.float32),
],
), mock.patch.object(
eval_vla,
"_run_cuda_parallel_processes",
side_effect=fake_run_cuda_parallel_processes,
create=True,
):
summary = eval_vla._run_eval_parallel_cuda(cfg)
self.assertEqual(len(observed_server_payloads), 1)
self.assertEqual(observed_server_payloads[0]["device_index"], 0)
self.assertEqual(len(observed_worker_payloads), 4)
self.assertTrue(all(payload["server_index"] == 0 for payload in observed_worker_payloads))
self.assertEqual([episode["episode_index"] for episode in summary["episodes"]], [0, 1, 2, 3])
self.assertEqual(summary["episode_rewards"], [1.0, 2.0, 3.0, 4.0])
self.assertEqual(summary["num_episodes"], 4)
def test_run_eval_parallel_cuda_surfaces_server_failures(self):
cfg = _make_parallel_cfg(num_episodes=2, num_workers=2, device="cuda", cuda_devices=[0])
with mock.patch.object(
eval_vla,
"sample_transfer_pose",
side_effect=[
np.array([0.1, 0.2, 0.3], dtype=np.float32),
np.array([0.4, 0.5, 0.6], dtype=np.float32),
],
), mock.patch.object(
eval_vla,
"_run_cuda_parallel_processes",
side_effect=RuntimeError("server boom"),
create=True,
):
with self.assertRaisesRegex(RuntimeError, "Parallel CUDA rollout failed"):
eval_vla._run_eval_parallel_cuda(cfg)
def test_run_spawn_jobs_supports_real_spawn_with_actual_eval_worker_entry(self):
payloads = [
{"_spawn_probe": True, "probe_value": 1, "worker_index": 0},
{"_spawn_probe": True, "probe_value": 2, "worker_index": 1},
]
results = eval_vla._run_spawn_jobs(
payloads=payloads,
max_workers=2,
worker_fn=eval_vla._run_eval_worker_entry,
)
self.assertEqual(sorted(result["probe_value"] for result in results), [1, 2])
self.assertEqual(sorted(result["worker_index"] for result in results), [0, 1])
def test_cuda_server_and_env_worker_entrypoints_support_real_spawn_probe(self):
ctx = eval_vla.multiprocessing.get_context("spawn")
request_queue = ctx.Queue()
response_queue = ctx.Queue()
result_queue = ctx.Queue()
server = ctx.Process(
target=eval_vla._inference_server_main,
args=(
{
"_spawn_probe": True,
"server_index": 0,
"request_queue": request_queue,
"response_queues": [response_queue],
},
),
)
worker = ctx.Process(
target=eval_vla._env_worker_main,
args=(
{
"_spawn_probe": True,
"worker_index": 0,
"server_index": 0,
"request_queue": request_queue,
"response_queue": response_queue,
"result_queue": result_queue,
},
),
)
server.start()
worker.start()
result = result_queue.get(timeout=10.0)
worker.join(timeout=10.0)
request_queue.put({"type": "shutdown_server"})
server.join(timeout=10.0)
self.assertEqual(result["kind"], "worker_result")
self.assertEqual(result["summary"]["probe_worker_index"], 0)
self.assertEqual(result["summary"]["probe_server_index"], 0)
self.assertEqual(result["summary"]["probe_actions"], [[[11.0], [22.0], [33.0]]])
self.assertEqual(worker.exitcode, 0)
self.assertEqual(server.exitcode, 0)
if __name__ == "__main__":
unittest.main()

View File

@@ -126,6 +126,26 @@ class EvalVLAHeadlessTest(unittest.TestCase):
self.assertIn("headless", eval_cfg)
self.assertFalse(eval_cfg.headless)
def test_eval_config_exposes_num_workers_default(self):
eval_cfg = OmegaConf.load(Path("roboimi/vla/conf/eval/eval.yaml"))
self.assertIn("num_workers", eval_cfg)
self.assertEqual(eval_cfg.num_workers, 1)
def test_eval_config_exposes_cuda_devices_default(self):
eval_cfg = OmegaConf.load(Path("roboimi/vla/conf/eval/eval.yaml"))
self.assertIn("cuda_devices", eval_cfg)
self.assertIsNone(eval_cfg.cuda_devices)
def test_eval_config_exposes_parallel_timeout_defaults(self):
eval_cfg = OmegaConf.load(Path("roboimi/vla/conf/eval/eval.yaml"))
self.assertIn("response_timeout_s", eval_cfg)
self.assertIn("server_startup_timeout_s", eval_cfg)
self.assertEqual(eval_cfg.response_timeout_s, 300.0)
self.assertEqual(eval_cfg.server_startup_timeout_s, 300.0)
def test_make_sim_env_accepts_headless_and_disables_render(self):
fake_env = object()
@@ -327,6 +347,172 @@ class EvalVLAHeadlessTest(unittest.TestCase):
self.assertAlmostEqual(summary["avg_reward"], 3.75)
self.assertEqual(summary["num_episodes"], 2)
def test_run_eval_uses_serial_path_when_num_workers_is_one(self):
cfg = OmegaConf.create(
{
"eval": {
"num_workers": 1,
"num_episodes": 3,
}
}
)
with mock.patch.object(
eval_vla,
"_run_eval_serial",
return_value={"mode": "serial"},
) as run_eval_serial, mock.patch.object(
eval_vla,
"_run_eval_parallel",
) as run_eval_parallel:
result = eval_vla._run_eval(cfg)
self.assertEqual(result, {"mode": "serial"})
run_eval_serial.assert_called_once_with(cfg)
run_eval_parallel.assert_not_called()
def test_run_eval_uses_serial_path_when_requested_workers_collapse_to_one(self):
cfg = OmegaConf.create(
{
"eval": {
"num_workers": 8,
"num_episodes": 1,
}
}
)
with mock.patch.object(
eval_vla,
"_run_eval_serial",
return_value={"mode": "serial"},
) as run_eval_serial, mock.patch.object(
eval_vla,
"_run_eval_parallel",
) as run_eval_parallel:
result = eval_vla._run_eval(cfg)
self.assertEqual(result, {"mode": "serial"})
run_eval_serial.assert_called_once_with(cfg)
run_eval_parallel.assert_not_called()
def test_run_eval_parallel_requires_headless_true(self):
cfg = OmegaConf.create(
{
"agent": {},
"eval": {
"ckpt_path": "checkpoints/vla_model_best.pt",
"num_episodes": 2,
"num_workers": 2,
"max_timesteps": 1,
"device": "cpu",
"task_name": "sim_transfer",
"camera_names": ["front"],
"use_smoothing": False,
"smooth_alpha": 0.3,
"verbose_action": False,
"headless": False,
},
}
)
with self.assertRaisesRegex(ValueError, "headless=true"):
eval_vla._run_eval_parallel(cfg)
def test_run_eval_parallel_dispatches_to_cpu_workers_when_device_is_cpu(self):
cfg = OmegaConf.create(
{
"agent": {},
"eval": {
"ckpt_path": "checkpoints/vla_model_best.pt",
"num_episodes": 2,
"num_workers": 2,
"max_timesteps": 1,
"device": "cpu",
"task_name": "sim_transfer",
"camera_names": ["front"],
"use_smoothing": False,
"smooth_alpha": 0.3,
"verbose_action": False,
"headless": True,
"cuda_devices": None,
},
}
)
with mock.patch.object(
eval_vla,
"_run_eval_parallel_cpu",
return_value={"mode": "cpu"},
create=True,
) as run_cpu_parallel, mock.patch.object(
eval_vla,
"_run_eval_parallel_cuda",
create=True,
) as run_cuda_parallel:
result = eval_vla._run_eval_parallel(cfg)
self.assertEqual(result, {"mode": "cpu"})
run_cpu_parallel.assert_called_once_with(cfg)
run_cuda_parallel.assert_not_called()
def test_run_eval_parallel_dispatches_to_cuda_servers_when_device_is_cuda(self):
cfg = OmegaConf.create(
{
"agent": {},
"eval": {
"ckpt_path": "checkpoints/vla_model_best.pt",
"num_episodes": 2,
"num_workers": 2,
"max_timesteps": 1,
"device": "cuda",
"task_name": "sim_transfer",
"camera_names": ["front"],
"use_smoothing": False,
"smooth_alpha": 0.3,
"verbose_action": False,
"headless": True,
"cuda_devices": [0],
},
}
)
with mock.patch.object(
eval_vla,
"_run_eval_parallel_cpu",
create=True,
) as run_cpu_parallel, mock.patch.object(
eval_vla,
"_run_eval_parallel_cuda",
return_value={"mode": "cuda"},
create=True,
) as run_cuda_parallel:
result = eval_vla._run_eval_parallel(cfg)
self.assertEqual(result, {"mode": "cuda"})
run_cpu_parallel.assert_not_called()
run_cuda_parallel.assert_called_once_with(cfg)
def test_resolve_cuda_devices_defaults_to_single_logical_gpu(self):
cfg = OmegaConf.create(
{
"device": "cuda",
"cuda_devices": None,
}
)
self.assertEqual(eval_vla._resolve_cuda_devices(cfg), [0])
def test_resolve_cuda_devices_rejects_empty_selection(self):
cfg = OmegaConf.create(
{
"device": "cuda",
"cuda_devices": [],
}
)
with self.assertRaisesRegex(ValueError, "cuda_devices"):
eval_vla._resolve_cuda_devices(cfg)
if __name__ == "__main__":
unittest.main()

View File

@@ -376,6 +376,49 @@ class _ForbiddenScheduler:
raise AssertionError('IMF inference should not use DDIM scheduler step')
class _StubFutureTokenPredictor(nn.Module):
def __init__(self, num_future_tokens=1):
super().__init__()
self.num_future_tokens = int(num_future_tokens)
self.calls = []
def forward(self, history_tokens):
self.calls.append(history_tokens.detach().clone())
summary = history_tokens.mean(dim=1, keepdim=True)
return summary.repeat(1, self.num_future_tokens, 1)
class _RecordingDirectFutureDecoder(nn.Module):
def __init__(self):
super().__init__()
self.scale = nn.Parameter(torch.tensor(0.5))
self.calls = []
def forward(self, sample, r, t, cond=None):
record = {
'sample': sample.detach().clone(),
'r': r.detach().clone(),
't': t.detach().clone(),
'cond': None if cond is None else cond.detach().clone(),
}
self.calls.append(record)
cond_term = 0.0
if cond is not None:
cond_term = cond.mean(dim=1, keepdim=True)
return self.scale * sample + cond_term
class _RecordingSigReg(nn.Module):
def __init__(self, value=0.5):
super().__init__()
self.value = float(value)
self.calls = []
def forward(self, embeddings):
self.calls.append(embeddings.detach().clone())
return embeddings.new_tensor(self.value)
def _make_images(batch_size, obs_horizon, per_camera_fill):
return {
name: torch.full((batch_size, obs_horizon, 1, 2, 2), fill_value=value, dtype=torch.float32)
@@ -501,6 +544,311 @@ class IMFVLAAgentTest(unittest.TestCase):
self.assertTrue(torch.allclose(head.calls[0]['t'], torch.ones(2)))
self.assertTrue(torch.allclose(head.calls[0]['cond'], expected_cond))
def test_predict_action_appends_lewm_future_tokens_to_history_conditioning(self):
agent_cls, agent_module = _load_imf_agent_class()
head = _RecordingLinearIMFHead()
future_predictor = _StubFutureTokenPredictor(num_future_tokens=1)
agent = agent_cls(
vision_backbone=_StubVisionBackbone(),
state_encoder=nn.Identity(),
action_encoder=nn.Identity(),
head=head,
action_dim=2,
obs_dim=1,
pred_horizon=3,
obs_horizon=2,
diffusion_steps=10,
inference_steps=1,
num_cams=len(_CAMERA_NAMES),
camera_names=_CAMERA_NAMES,
num_action_steps=2,
head_type='transformer',
extra_condition_tokens=1,
lewm_history_horizon=3,
lewm_query_offsets=[8],
lewm_predictor=future_predictor,
lewm_pred_projector=nn.Identity(),
lewm_loss_weight=0.5,
)
agent.infer_scheduler = _ForbiddenScheduler()
images = _make_images(
batch_size=1,
obs_horizon=2,
per_camera_fill={'r_vis': 10.0, 'top': 20.0, 'front': 30.0},
)
qpos = torch.tensor([[[1.0], [2.0]]], dtype=torch.float32)
lewm_images = _make_images(
batch_size=1,
obs_horizon=3,
per_camera_fill={'r_vis': 10.0, 'top': 20.0, 'front': 30.0},
)
lewm_qpos = torch.tensor([[[0.5], [1.5], [2.5]]], dtype=torch.float32)
initial_noise = torch.tensor(
[[[1.0, -1.0], [0.0, 2.0], [3.0, -2.0]]],
dtype=torch.float32,
)
with mock.patch.object(agent_module.torch, 'randn', return_value=initial_noise):
_ = agent.predict_action(
images,
qpos,
lewm_images=lewm_images,
lewm_proprioception=lewm_qpos,
)
expected_history = torch.tensor(
[[[10.0, 20.0, 30.0, 1.0], [10.0, 20.0, 30.0, 2.0]]],
dtype=torch.float32,
)
expected_future = torch.tensor([[[10.0, 20.0, 30.0, 1.5]]], dtype=torch.float32)
expected_cond = torch.cat([expected_history, expected_future], dim=1)
self.assertEqual(agent.condition_sequence_length, 3)
self.assertEqual(agent.per_step_cond_dim, 4)
self.assertEqual(len(head.calls), 1)
self.assertTrue(torch.allclose(head.calls[0]['cond'], expected_cond))
self.assertEqual(len(future_predictor.calls), 1)
def test_compute_loss_tracks_action_and_lewm_loss_breakdown(self):
agent_cls, agent_module = _load_imf_agent_class()
head = _RecordingLinearIMFHead()
future_predictor = _StubFutureTokenPredictor(num_future_tokens=1)
sigreg = _RecordingSigReg(value=0.75)
agent = agent_cls(
vision_backbone=_StubVisionBackbone(),
state_encoder=nn.Identity(),
action_encoder=nn.Identity(),
head=head,
action_dim=2,
obs_dim=1,
pred_horizon=3,
obs_horizon=2,
diffusion_steps=10,
inference_steps=1,
num_cams=len(_CAMERA_NAMES),
camera_names=_CAMERA_NAMES,
num_action_steps=2,
head_type='transformer',
extra_condition_tokens=1,
lewm_history_horizon=3,
lewm_query_offsets=[8],
lewm_predictor=future_predictor,
lewm_pred_projector=nn.Identity(),
lewm_sigreg=sigreg,
lewm_sigreg_weight=0.09,
lewm_loss_weight=0.25,
)
images = _make_images(
batch_size=1,
obs_horizon=2,
per_camera_fill={'r_vis': 1.0, 'top': 2.0, 'front': 3.0},
)
qpos = torch.tensor([[[0.25], [0.75]]], dtype=torch.float32)
actions = torch.tensor(
[[[1.0, -1.0], [0.5, 0.25], [-0.5, 1.5]]],
dtype=torch.float32,
)
lewm_images = _make_images(
batch_size=1,
obs_horizon=3,
per_camera_fill={'r_vis': 1.0, 'top': 2.0, 'front': 3.0},
)
lewm_qpos = torch.tensor([[[0.1], [0.2], [0.3]]], dtype=torch.float32)
lewm_future_images = _make_images(
batch_size=1,
obs_horizon=1,
per_camera_fill={'r_vis': 1.0, 'top': 2.0, 'front': 3.0},
)
lewm_future_qpos = torch.tensor([[[0.4]]], dtype=torch.float32)
noise = torch.tensor(
[[[0.2, -0.4], [0.1, 0.3], [0.5, -0.2]]],
dtype=torch.float32,
)
t_sample = torch.tensor([0.8], dtype=torch.float32)
r_sample = torch.tensor([0.25], dtype=torch.float32)
with mock.patch.object(agent_module.torch, 'randn_like', return_value=noise), \
mock.patch.object(agent_module.torch, 'rand', side_effect=[t_sample, r_sample]):
loss = agent.compute_loss(
{
'images': images,
'qpos': qpos,
'action': actions,
'lewm_images': lewm_images,
'lewm_qpos': lewm_qpos,
'lewm_future_images': lewm_future_images,
'lewm_future_qpos': lewm_future_qpos,
}
)
metrics = agent.get_last_loss_breakdown()
self.assertAlmostEqual(loss.item(), metrics['loss'], places=6)
self.assertIn('action_loss', metrics)
self.assertIn('lewm_pred_loss', metrics)
self.assertIn('lewm_sigreg_loss', metrics)
self.assertIn('lewm_loss', metrics)
self.assertAlmostEqual(metrics['lewm_sigreg_loss'], 0.75, places=6)
self.assertAlmostEqual(
metrics['lewm_loss'],
metrics['lewm_pred_loss'] + 0.09 * metrics['lewm_sigreg_loss'],
places=5,
)
self.assertAlmostEqual(
metrics['loss'],
metrics['action_loss'] + 0.25 * metrics['lewm_loss'],
places=5,
)
self.assertEqual(len(sigreg.calls), 1)
expected_lewm_history = torch.tensor(
[[[1.0, 2.0, 3.0, 0.1], [1.0, 2.0, 3.0, 0.2], [1.0, 2.0, 3.0, 0.3]]],
dtype=torch.float32,
)
torch.testing.assert_close(sigreg.calls[0], expected_lewm_history.transpose(0, 1))
def test_predict_action_with_dual_decoder_keeps_action_condition_history_only(self):
agent_cls, agent_module = _load_imf_agent_class()
head = _RecordingLinearIMFHead()
future_decoder = _RecordingDirectFutureDecoder()
agent = agent_cls(
vision_backbone=_StubVisionBackbone(),
state_encoder=nn.Identity(),
action_encoder=nn.Identity(),
head=head,
future_decoder=future_decoder,
action_dim=2,
obs_dim=1,
pred_horizon=3,
obs_horizon=2,
diffusion_steps=10,
inference_steps=1,
num_cams=len(_CAMERA_NAMES),
camera_names=_CAMERA_NAMES,
num_action_steps=2,
head_type='transformer',
lewm_history_horizon=3,
lewm_query_offsets=[8],
lewm_loss_weight=1.0,
)
agent.infer_scheduler = _ForbiddenScheduler()
with torch.no_grad():
agent.future_query_tokens.copy_(torch.tensor([[[0.1, 0.2, 0.3, 0.4]]], dtype=torch.float32))
images = _make_images(
batch_size=1,
obs_horizon=2,
per_camera_fill={'r_vis': 10.0, 'top': 20.0, 'front': 30.0},
)
qpos = torch.tensor([[[1.0], [2.0]]], dtype=torch.float32)
initial_noise = torch.tensor(
[[[1.0, -1.0], [0.0, 2.0], [3.0, -2.0]]],
dtype=torch.float32,
)
with mock.patch.object(agent_module.torch, 'randn', return_value=initial_noise):
_ = agent.predict_action(images, qpos)
expected_history = torch.tensor(
[[[10.0, 20.0, 30.0, 1.0], [10.0, 20.0, 30.0, 2.0]]],
dtype=torch.float32,
)
self.assertEqual(len(head.calls), 1)
self.assertTrue(torch.allclose(head.calls[0]['cond'], expected_history))
self.assertEqual(len(future_decoder.calls), 0)
def test_compute_loss_with_dual_decoder_tracks_lewm_loss_breakdown(self):
agent_cls, agent_module = _load_imf_agent_class()
head = _RecordingLinearIMFHead()
future_decoder = _RecordingDirectFutureDecoder()
sigreg = _RecordingSigReg(value=0.75)
agent = agent_cls(
vision_backbone=_StubVisionBackbone(),
state_encoder=nn.Identity(),
action_encoder=nn.Identity(),
head=head,
future_decoder=future_decoder,
action_dim=2,
obs_dim=1,
pred_horizon=3,
obs_horizon=2,
diffusion_steps=10,
inference_steps=1,
num_cams=len(_CAMERA_NAMES),
camera_names=_CAMERA_NAMES,
num_action_steps=2,
head_type='transformer',
lewm_history_horizon=3,
lewm_query_offsets=[8],
lewm_sigreg=sigreg,
lewm_sigreg_weight=0.09,
lewm_loss_weight=1.0,
)
with torch.no_grad():
agent.future_query_tokens.copy_(torch.tensor([[[0.2, 0.4, 0.6, 0.8]]], dtype=torch.float32))
images = _make_images(
batch_size=1,
obs_horizon=2,
per_camera_fill={'r_vis': 1.0, 'top': 2.0, 'front': 3.0},
)
qpos = torch.tensor([[[0.25], [0.75]]], dtype=torch.float32)
actions = torch.tensor(
[[[1.0, -1.0], [0.5, 0.25], [-0.5, 1.5]]],
dtype=torch.float32,
)
lewm_images = _make_images(
batch_size=1,
obs_horizon=3,
per_camera_fill={'r_vis': 1.0, 'top': 2.0, 'front': 3.0},
)
lewm_qpos = torch.tensor([[[0.1], [0.2], [0.3]]], dtype=torch.float32)
lewm_future_images = _make_images(
batch_size=1,
obs_horizon=1,
per_camera_fill={'r_vis': 1.0, 'top': 2.0, 'front': 3.0},
)
lewm_future_qpos = torch.tensor([[[0.4]]], dtype=torch.float32)
noise = torch.tensor(
[[[0.2, -0.4], [0.1, 0.3], [0.5, -0.2]]],
dtype=torch.float32,
)
t_sample = torch.tensor([0.8], dtype=torch.float32)
r_sample = torch.tensor([0.25], dtype=torch.float32)
with mock.patch.object(agent_module.torch, 'randn_like', return_value=noise), \
mock.patch.object(agent_module.torch, 'rand', side_effect=[t_sample, r_sample]):
loss = agent.compute_loss(
{
'images': images,
'qpos': qpos,
'action': actions,
'lewm_images': lewm_images,
'lewm_qpos': lewm_qpos,
'lewm_future_images': lewm_future_images,
'lewm_future_qpos': lewm_future_qpos,
}
)
metrics = agent.get_last_loss_breakdown()
self.assertAlmostEqual(loss.item(), metrics['loss'], places=6)
self.assertEqual(len(head.calls), 2)
self.assertEqual(head.calls[0]['cond'].shape, (1, 2, 4))
self.assertEqual(len(future_decoder.calls), 1)
self.assertEqual(future_decoder.calls[0]['cond'].shape, (1, 3, 4))
self.assertAlmostEqual(
metrics['loss'],
metrics['action_loss'] + metrics['lewm_loss'],
places=5,
)
self.assertAlmostEqual(
metrics['lewm_loss'],
metrics['lewm_pred_loss'] + 0.09 * metrics['lewm_sigreg_loss'],
places=5,
)
self.assertGreater(metrics['lewm_pred_loss'], 0.0)
self.assertAlmostEqual(metrics['lewm_sigreg_loss'], 0.75, places=6)
def test_select_action_only_regenerates_when_action_queue_is_empty(self):
agent, _head, _agent_module = self._make_agent(pred_horizon=4, obs_horizon=2, num_action_steps=2)
observation = {
@@ -851,6 +1199,80 @@ class IMFVLAAgentTest(unittest.TestCase):
self.assertEqual(agent.vision_encoder.output_dim, 96)
self.assertEqual(agent.vision_encoder.eval_image_resize_shape, (256, 256))
def test_hydra_config_instantiates_lewm_resnet_query_imf_attnres_with_future_tokens(self):
cfg = _compose_cfg(
overrides=[
'agent=lewm_resnet_query_imf_attnres',
'agent.head.n_layer=1',
'agent.head.n_emb=16',
'agent.lewm_query_offsets=[8]',
]
)
self.assertEqual(cfg.agent._target_, 'roboimi.vla.agent_imf.IMFVLAAgent')
self.assertEqual(
cfg.agent.vision_backbone._target_,
'roboimi.vla.models.backbones.lewm_resnet_query_fusion.LeWMMultiViewResNetBackbone',
)
self.assertEqual(
cfg.agent.state_encoder._target_,
'roboimi.vla.modules.encoders.LeWMStateEncoder',
)
self.assertEqual(cfg.agent.head.cond_dim, 288)
self.assertEqual(cfg.agent.cond_projector.output_dim, 288)
self.assertEqual(cfg.agent.extra_condition_tokens, 1)
self.assertEqual(
cfg.agent.lewm_sigreg._target_,
'roboimi.vla.models.backbones.lewm_resnet_query_fusion.SIGReg',
)
self.assertAlmostEqual(cfg.agent.lewm_sigreg_weight, 0.09)
with _stub_optional_modules(include_imf_head=True):
agent = instantiate(cfg.agent)
self.assertEqual(agent.per_step_cond_dim, 288)
self.assertEqual(agent.condition_sequence_length, agent.obs_horizon + 1)
self.assertEqual(agent.noise_pred_net.constructor_kwargs['cond_dim'], 288)
self.assertEqual(
agent.noise_pred_net.constructor_kwargs['n_obs_steps'],
agent.condition_sequence_length,
)
self.assertIsNotNone(agent.lewm_sigreg)
def test_hydra_config_instantiates_lewm_resnet_dual_decoder_imf_attnres(self):
cfg = _compose_cfg(
overrides=[
'agent=lewm_resnet_dual_decoder_imf_attnres',
'agent.head.n_layer=1',
'agent.head.n_emb=16',
'agent.future_decoder.n_layer=1',
'agent.future_decoder.n_emb=16',
'agent.lewm_query_offsets=[8]',
]
)
self.assertEqual(cfg.agent._target_, 'roboimi.vla.agent_imf.IMFVLAAgent')
self.assertEqual(cfg.agent.extra_condition_tokens, 0)
self.assertEqual(
cfg.agent.future_decoder._target_,
'roboimi.vla.models.heads.imf_transformer1d.IMFTransformer1D',
)
self.assertEqual(cfg.agent.head.cond_dim, 288)
self.assertEqual(cfg.agent.future_decoder.cond_dim, 288)
with _stub_optional_modules(include_imf_head=True):
agent = instantiate(cfg.agent)
self.assertEqual(agent.per_step_cond_dim, 288)
self.assertEqual(agent.condition_sequence_length, agent.obs_horizon)
self.assertEqual(agent.noise_pred_net.constructor_kwargs['n_obs_steps'], agent.obs_horizon)
self.assertEqual(agent.future_decoder.constructor_kwargs['cond_dim'], 288)
self.assertEqual(
agent.future_decoder.constructor_kwargs['n_obs_steps'],
agent.lewm_history_horizon,
)
self.assertEqual(agent.future_query_tokens.shape, (1, 1, 288))
def test_hydra_config_instantiates_resnet_imf_attnres_multitoken_with_sequence_length_three_times_obs_horizon(self):
cfg = _compose_cfg(

View File

@@ -12,18 +12,21 @@ from roboimi.vla.data.simpe_robot_dataset import SimpleRobotDataset
class SimpleRobotDatasetImageLoadingTest(unittest.TestCase):
-   def _write_episode(self, dataset_dir: Path) -> None:
-       episode_path = dataset_dir / "episode_0.hdf5"
+   def _write_episode(self, dataset_dir: Path, episode_idx: int = 0, *, base_value: float = 0.0) -> None:
+       episode_path = dataset_dir / f"episode_{episode_idx}.hdf5"
        with h5py.File(episode_path, "w") as root:
-           root.create_dataset("action", data=np.arange(8, dtype=np.float32).reshape(4, 2))
+           root.create_dataset(
+               "action",
+               data=(np.arange(8, dtype=np.float32).reshape(4, 2) + base_value),
+           )
            root.create_dataset(
                "observations/qpos",
-               data=np.arange(16, dtype=np.float32).reshape(4, 4),
+               data=(np.arange(16, dtype=np.float32).reshape(4, 4) + base_value),
            )
            root.create_dataset("task", data=np.array([b"sim_transfer"]))
            root.create_dataset(
                "observations/images/front",
-               data=np.arange(4 * 8 * 8 * 3, dtype=np.uint8).reshape(4, 8, 8, 3),
+               data=((np.arange(4 * 8 * 8 * 3, dtype=np.uint8) + int(base_value)) % 255).reshape(4, 8, 8, 3),
            )
def test_getitem_only_resizes_observation_horizon_images(self):
@@ -79,3 +82,46 @@ class SimpleRobotDatasetImageLoadingTest(unittest.TestCase):
fake_cv2.resize.assert_not_called()
self.assertEqual(tuple(sample["observation.front"].shape), (2, 3, 8, 8))
def test_getitem_can_emit_lewm_history_and_future_observations(self):
with tempfile.TemporaryDirectory() as tmpdir:
dataset_dir = Path(tmpdir)
self._write_episode(dataset_dir)
dataset = SimpleRobotDataset(
dataset_dir,
obs_horizon=2,
pred_horizon=3,
camera_names=["front"],
image_resize_shape=None,
lewm_history_horizon=3,
lewm_query_offsets=[1, 2],
)
sample = dataset[1]
self.assertEqual(tuple(sample["lewm.observation.state"].shape), (3, 4))
self.assertEqual(tuple(sample["lewm.observation.front"].shape), (3, 3, 8, 8))
self.assertEqual(tuple(sample["lewm.future.state"].shape), (2, 4))
self.assertEqual(tuple(sample["lewm.future.front"].shape), (2, 3, 8, 8))
def test_dataset_can_limit_loading_to_specific_episode_indices(self):
with tempfile.TemporaryDirectory() as tmpdir:
dataset_dir = Path(tmpdir)
self._write_episode(dataset_dir, episode_idx=0, base_value=0.0)
self._write_episode(dataset_dir, episode_idx=1, base_value=100.0)
dataset = SimpleRobotDataset(
dataset_dir,
obs_horizon=2,
pred_horizon=3,
camera_names=["front"],
image_resize_shape=None,
episode_indices=[1],
)
sample = dataset[0]
self.assertEqual(len(dataset.hdf5_files), 1)
self.assertEqual(dataset.available_episode_indices, [1])
self.assertEqual(len(dataset), 4)
self.assertTrue(np.allclose(sample["observation.state"][0].numpy(), np.array([100.0, 101.0, 102.0, 103.0])))
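This filter is what the held-out numeric validation relies on. A minimal usage sketch mirroring `train.val_episode_indices=[100]`, using the same constructor arguments the test exercises (`dataset_dir` is a placeholder):

```python
# Minimal sketch: restrict the dataset to a single held-out episode.
from pathlib import Path
from roboimi.vla.data.simpe_robot_dataset import SimpleRobotDataset

dataset_dir = Path('/path/to/sim_transfer')  # placeholder dataset root
val_dataset = SimpleRobotDataset(
    dataset_dir,
    obs_horizon=2,
    pred_horizon=3,
    camera_names=["front"],
    image_resize_shape=None,
    episode_indices=[100],  # only episode_100.hdf5 is loaded
)
```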

View File

@@ -158,6 +158,106 @@ class TrainVLARolloutValidationTest(unittest.TestCase):
self.assertGreater(float(cfg.train.lr), 5e-5)
self.assertGreater(cfg.train.num_workers, 8)
self.assertEqual(cfg.train.rollout_val_freq_epochs, 50)
self.assertEqual(cfg.train.rollout_device, cfg.train.device)
self.assertIsNone(cfg.train.rollout_num_workers)
self.assertIsNone(cfg.train.rollout_cuda_devices)
def test_run_training_rollout_validation_propagates_gpu_parallel_settings(self):
cfg = OmegaConf.create(
{
'train': {
'device': 'cpu',
'batch_size': 1,
'num_workers': 0,
'val_split': 0.0,
'seed': 0,
'lr': 1e-3,
'max_steps': 2,
'log_freq': 1,
'save_freq': 1000,
'warmup_steps': 1,
'scheduler_type': 'constant',
'min_lr': 0.0,
'grad_clip': 1.0,
'weight_decay': 0.0,
'pretrained_ckpt': None,
'resume_ckpt': None,
'use_swanlab': False,
'rollout_val_freq_epochs': 2,
'rollout_num_episodes': 5,
'rollout_device': 'cuda',
'rollout_num_workers': 4,
'rollout_cuda_devices': [0, 1],
'rollout_response_timeout_s': 123.0,
'rollout_server_startup_timeout_s': 456.0,
},
'data': {
'camera_names': ['front'],
},
'agent': {
'_target_': 'fake.agent',
},
'eval': {
'ckpt_path': 'unused.pt',
'num_episodes': 99,
'max_timesteps': 1,
'device': 'cpu',
'task_name': 'sim_transfer',
'camera_names': ['front'],
'use_smoothing': False,
'smooth_alpha': 0.3,
'verbose_action': False,
'headless': False,
},
}
)
rollout_mock = mock.Mock(return_value={'avg_reward': 1.0})
def fake_instantiate(config_node, **_kwargs):
if config_node is cfg.data:
return _FakeDataset()
if config_node is cfg.agent:
return _FakeAgent()
raise AssertionError(f'unexpected instantiate config: {config_node!r}')
def fake_dataloader(_dataset, *, shuffle, **_kwargs):
del shuffle, _kwargs
return _FakeLoader(
{
'observation.front': torch.zeros(1, 3, 2, 2),
'observation.state': torch.zeros(1, 4),
'action': torch.zeros(1, 2),
'action_is_pad': torch.zeros(1, 1, dtype=torch.bool),
},
length=1,
)
with tempfile.TemporaryDirectory() as tempdir:
previous_cwd = os.getcwd()
try:
os.chdir(tempdir)
with mock.patch.object(train_vla, 'instantiate', side_effect=fake_instantiate), \
mock.patch.object(train_vla, 'DataLoader', side_effect=fake_dataloader), \
mock.patch.object(train_vla, 'build_training_optimizer', return_value=_FakeOptimizer(cfg.train.lr)), \
mock.patch.object(train_vla, 'get_lr_schedule_with_warmup', return_value=_FakeScheduler()), \
mock.patch.object(train_vla, 'tqdm', side_effect=lambda iterable, **kwargs: _FakeProgressBar(iterable)), \
mock.patch.object(train_vla.torch, 'save', return_value=None), \
mock.patch.object(eval_vla, '_run_eval', rollout_mock, create=True):
train_vla._run_training(cfg)
finally:
os.chdir(previous_cwd)
rollout_cfg = rollout_mock.call_args.args[0]
self.assertEqual(rollout_cfg.eval.device, 'cuda')
self.assertEqual(rollout_cfg.eval.num_workers, 4)
self.assertEqual(list(rollout_cfg.eval.cuda_devices), [0, 1])
self.assertEqual(float(rollout_cfg.eval.response_timeout_s), 123.0)
self.assertEqual(float(rollout_cfg.eval.server_startup_timeout_s), 456.0)
self.assertTrue(rollout_cfg.eval.headless)
self.assertEqual(rollout_cfg.eval.num_episodes, 5)
self.assertFalse(rollout_cfg.eval.record_video)
self.assertTrue(rollout_cfg.eval.save_summary_json)
self.assertTrue(rollout_cfg.eval.save_trajectory_image)
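Taken together, these assertions describe a one-way projection of `train.rollout_*` fields onto the eval config before `_run_eval` is invoked. A sketch of that mapping (the helper name is hypothetical; the field pairs and fixed flags are exactly the ones asserted above):

```python
# Hypothetical helper: project train.rollout_* settings onto cfg.eval
# before invoking eval_vla._run_eval during training.
def apply_rollout_overrides(cfg):
    cfg.eval.device = cfg.train.rollout_device
    cfg.eval.num_workers = cfg.train.rollout_num_workers
    cfg.eval.cuda_devices = cfg.train.rollout_cuda_devices
    cfg.eval.response_timeout_s = cfg.train.rollout_response_timeout_s
    cfg.eval.server_startup_timeout_s = cfg.train.rollout_server_startup_timeout_s
    cfg.eval.num_episodes = cfg.train.rollout_num_episodes
    cfg.eval.headless = True           # training-time rollouts run headless
    cfg.eval.record_video = False      # but keep summary json + trajectory image
    cfg.eval.save_summary_json = True
    cfg.eval.save_trajectory_image = True
    return cfg
```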
def test_training_passes_backbone_image_resize_override_to_dataset_instantiation(self):
cfg = OmegaConf.create(

View File

@@ -41,6 +41,19 @@ class FakeDataset:
return 4
class SplitAwareFakeDataset(FakeDataset):
def __init__(self, episode_indices=None):
self.episode_indices = None if episode_indices is None else list(episode_indices)
if self.episode_indices is None:
self.episodes = {0: [0], 1: [1], 2: [2]}
else:
self.episodes = {idx: [idx] for idx in self.episode_indices}
@property
def available_episode_indices(self):
return sorted(self.episodes.keys())
class FakeLoader:
def __init__(self, batch):
self.batch = batch
@@ -114,6 +127,26 @@ class FakeAgent(nn.Module):
return {}
class RecordingAgent(FakeAgent):
def __init__(self):
super().__init__()
self.seen_inputs = []
def compute_loss(self, agent_input):
self.seen_inputs.append(agent_input)
return super().compute_loss(agent_input)
def predict_action_chunk(self, agent_input):
self.seen_inputs.append({'predict_action_chunk': agent_input})
return torch.ones_like(agent_input['action'])
class ShapeMixedFakeAgent(FakeAgent):
def __init__(self):
super().__init__()
self.bias = nn.Parameter(torch.zeros(2))
class FakeSwanLab:
def __init__(self, init_error=None, log_errors=None, finish_error=None, image_errors=None):
self.init_error = init_error
@@ -339,6 +372,8 @@ class TrainVLASwanLabLoggingTest(unittest.TestCase):
batch_size=2,
num_workers=0,
val_split=0.25,
val_episode_indices=None,
action_mse_val_freq_epochs=0,
seed=0,
lr=1e-3,
max_steps=2,
@@ -388,6 +423,18 @@ class TrainVLASwanLabLoggingTest(unittest.TestCase):
'action_is_pad': torch.zeros(1, 1, dtype=torch.bool),
}
def _make_lewm_batch(self):
batch = self._make_batch()
batch.update(
{
'lewm.observation.front': torch.ones(1, 3, 2, 2),
'lewm.observation.state': torch.ones(1, 4),
'lewm.future.front': torch.full((1, 3, 2, 2), 2.0),
'lewm.future.state': torch.full((1, 4), 2.0),
}
)
return batch
def _loader_factory(self):
train_batch = self._make_batch()
val_batch = self._make_batch()
@@ -397,6 +444,15 @@ class TrainVLASwanLabLoggingTest(unittest.TestCase):
return factory
def _lewm_loader_factory(self):
train_batch = self._make_lewm_batch()
val_batch = self._make_lewm_batch()
def factory(_dataset, *, shuffle, **_kwargs):
return FakeLoader(train_batch if shuffle else val_batch)
return factory
def test_run_training_logs_metrics_and_checkpoint_paths_to_swanlab(self):
module = self._load_train_vla_module()
run_training = self._get_run_training(module)
@@ -442,6 +498,8 @@ class TrainVLASwanLabLoggingTest(unittest.TestCase):
'batch_size': 2,
'num_workers': 0,
'val_split': 0.25,
'val_episode_indices': None,
'action_mse_val_freq_epochs': 0,
'seed': 0,
'lr': 1e-3,
'max_steps': 2,
@@ -487,6 +545,95 @@ class TrainVLASwanLabLoggingTest(unittest.TestCase):
self.assertTrue(final_payload['final/best_checkpoint_path'].endswith('checkpoints/vla_model_best.pt'))
self.assertEqual(fake_swanlab.finish_calls, 1)
def test_run_training_passes_lewm_history_and_future_batches_into_agent_input(self):
module = self._load_train_vla_module()
run_training = self._get_run_training(module)
cfg = self._make_cfg(use_swanlab=False)
cfg.train.max_steps = 1
cfg.train.save_freq = 100
agent = RecordingAgent()
def fake_instantiate(config_node, **_kwargs):
if config_node is cfg.data:
return FakeDataset()
if config_node is cfg.agent:
return agent
raise AssertionError(f'unexpected instantiate config: {config_node!r}')
with tempfile.TemporaryDirectory() as tempdir:
previous_cwd = os.getcwd()
try:
os.chdir(tempdir)
with mock.patch.object(module, 'instantiate', side_effect=fake_instantiate), \
mock.patch.object(module, 'DataLoader', side_effect=self._lewm_loader_factory()), \
mock.patch.object(module, 'get_lr_schedule_with_warmup', return_value=FakeScheduler()), \
mock.patch.object(module, 'tqdm', side_effect=lambda iterable, **kwargs: FakeProgressBar(iterable)), \
mock.patch.object(module.torch, 'save', return_value=None):
run_training(cfg)
finally:
os.chdir(previous_cwd)
self.assertGreaterEqual(len(agent.seen_inputs), 1)
first_input = agent.seen_inputs[0]
self.assertIn('lewm_images', first_input)
self.assertIn('lewm_qpos', first_input)
self.assertIn('lewm_future_images', first_input)
self.assertIn('lewm_future_qpos', first_input)
self.assertIn('front', first_input['lewm_images'])
self.assertIn('front', first_input['lewm_future_images'])
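The assumed regrouping behind these checks: flat `lewm.*` batch keys become per-camera dicts and state tensors on the agent input. A sketch, generalized over `camera_names`:

```python
# Assumed regrouping of flat 'lewm.*' batch keys into the agent input dict.
def regroup_lewm_keys(batch, camera_names=('front',)):
    return {
        'lewm_images': {cam: batch[f'lewm.observation.{cam}'] for cam in camera_names},
        'lewm_qpos': batch['lewm.observation.state'],
        'lewm_future_images': {cam: batch[f'lewm.future.{cam}'] for cam in camera_names},
        'lewm_future_qpos': batch['lewm.future.state'],
    }
```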
def test_run_training_logs_epoch_action_mse_for_held_out_val_episode(self):
module = self._load_train_vla_module()
run_training = self._get_run_training(module)
cfg = self._make_cfg()
cfg.train.max_steps = 1
cfg.train.save_freq = 100
cfg.train.val_split = 0.0
cfg.train.val_episode_indices = [2]
cfg.train.action_mse_val_freq_epochs = 1
agent = RecordingAgent()
fake_swanlab = FakeSwanLab()
real_import_module = importlib.import_module
def fake_instantiate(config_node, **kwargs):
if config_node is cfg.data:
return SplitAwareFakeDataset(kwargs.get('episode_indices'))
if config_node is cfg.agent:
return agent
raise AssertionError(f'unexpected instantiate config: {config_node!r}')
def fake_loader_factory(dataset, *, shuffle, **_kwargs):
action_value = 0.0 if shuffle else 2.0
batch = {
'observation.front': torch.zeros(1, 3, 2, 2),
'observation.state': torch.zeros(1, 4),
'action': torch.full((1, 1, 2), action_value),
'action_is_pad': torch.zeros(1, 1, dtype=torch.bool),
}
return FakeLoader(batch)
def fake_import_module(name, package=None):
if name == 'swanlab':
return fake_swanlab
return real_import_module(name, package)
with tempfile.TemporaryDirectory() as tempdir:
previous_cwd = os.getcwd()
try:
os.chdir(tempdir)
with mock.patch.object(module, 'instantiate', side_effect=fake_instantiate), \
mock.patch.object(module, 'DataLoader', side_effect=fake_loader_factory), \
mock.patch.object(module, 'get_lr_schedule_with_warmup', return_value=FakeScheduler()), \
mock.patch.object(module, 'tqdm', side_effect=lambda iterable, **kwargs: FakeProgressBar(iterable)), \
mock.patch.object(module.torch, 'save', return_value=None), \
mock.patch.object(module.importlib, 'import_module', side_effect=fake_import_module):
run_training(cfg)
finally:
os.chdir(previous_cwd)
logged_keys = set().union(*(payload.keys() for payload, _ in fake_swanlab.log_calls))
self.assertIn('val/action_mse', logged_keys)
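A minimal sketch of the metric itself, assuming `predict_action_chunk` returns a tensor shaped like `batch['action']` (as `RecordingAgent` does above); the result is what gets logged as `val/action_mse`:

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def held_out_action_mse(agent, val_loader):
    # Mean action-chunk MSE over the held-out episode loader.
    total, count = 0.0, 0
    for batch in val_loader:
        pred = agent.predict_action_chunk(batch)  # shaped like batch['action']
        total += F.mse_loss(pred, batch['action']).item()
        count += 1
    return total / max(count, 1)
```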
def test_run_training_skips_swanlab_when_disabled(self):
module = self._load_train_vla_module()
run_training = self._get_run_training(module)
@@ -668,6 +815,52 @@ class TrainVLASwanLabLoggingTest(unittest.TestCase):
self.assertTrue(final_payload['final/best_checkpoint_path'].endswith('checkpoints/vla_model_best.pt'))
self.assertFalse(any(path.endswith('checkpoints/vla_model_best.pt') for path in saved_paths))
def test_run_training_pretrained_ckpt_loads_matching_keys_even_if_some_shapes_mismatch(self):
module = self._load_train_vla_module()
run_training = self._get_run_training(module)
cfg = self._make_cfg(use_swanlab=False)
cfg.train.max_steps = 0
cfg.train.save_freq = 100
cfg.train.pretrained_ckpt = 'pretrained.pt'
agent = ShapeMixedFakeAgent()
def fake_instantiate(config_node, **_kwargs):
if config_node is cfg.data:
return FakeDataset()
if config_node is cfg.agent:
return agent
raise AssertionError(f'unexpected instantiate config: {config_node!r}')
def fake_torch_load(path, map_location=None):
del map_location
if Path(path).name != 'pretrained.pt':
raise AssertionError(f'unexpected load path: {path}')
return {
'model_state_dict': {
'weight': torch.tensor(3.0),
'bias': torch.tensor([1.0, 2.0, 3.0]),
},
'step': 123,
'loss': 0.5,
}
with tempfile.TemporaryDirectory() as tempdir:
previous_cwd = os.getcwd()
try:
os.chdir(tempdir)
Path('pretrained.pt').write_bytes(b'pretend')
with mock.patch.object(module, 'instantiate', side_effect=fake_instantiate), \
mock.patch.object(module, 'DataLoader', side_effect=self._loader_factory()), \
mock.patch.object(module, 'get_lr_schedule_with_warmup', return_value=FakeScheduler()), \
mock.patch.object(module, 'tqdm', side_effect=lambda iterable, **kwargs: FakeProgressBar(iterable)), \
mock.patch.object(module.torch, 'save', return_value=None), \
mock.patch.object(module.torch, 'load', side_effect=fake_torch_load):
run_training(cfg)
finally:
os.chdir(previous_cwd)
self.assertAlmostEqual(agent.weight.item(), 3.0, places=6)
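A sketch of the shape-tolerant loading behaviour this test locks in: keys present in both state dicts are copied only when their shapes agree, so the 3-element `bias` from the checkpoint is skipped while the scalar `weight` loads (the helper name is hypothetical):

```python
import torch

def load_pretrained_partial(agent, ckpt_path):
    # Copy only checkpoint keys whose shapes match the agent's parameters.
    state = torch.load(ckpt_path, map_location='cpu')['model_state_dict']
    own = agent.state_dict()
    compatible = {
        k: v for k, v in state.items()
        if k in own and own[k].shape == v.shape
    }
    own.update(compatible)
    agent.load_state_dict(own)
    return sorted(set(state) - set(compatible))  # report skipped keys
```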
def test_run_training_resume_ignores_best_checkpoint_without_rollout_metric(self):
module = self._load_train_vla_module()
run_training = self._get_run_training(module)