Compare commits: `dev` ... `feat-lewm-` (24 commits)
.gitignore (vendored, 3 changes)

```diff
@@ -126,3 +126,6 @@ GEMINI.md
 .github/copilot-instructions.md
+.hydra/
+# Local git worktrees
+.worktrees/
```

docs/lewm-imf-experiment-guide.md (new file, 471 lines)

# feat-lewm-imf-fusion Experiment Operations Guide

Applicable worktree: `/home/droid/project/roboimi/.worktrees/feat-lewm-imf-fusion`

## 0. Start from the current go-to recipe

The most frequently used train/validation recipe on this branch; refer to it directly:
`experiment_suites/2026-04-21-lewm-fromscratch-old9-epoch50-roll5-val-20260421-153037/`

Core conventions:
- agent: `lewm_resnet_query_imf_attnres`
- from scratch: `train.pretrained_ckpt=null`, `agent.lewm_pretrained_ckpt=null`
- training: `batch_size=32`, `lr=1e-4`, `max_steps=109350`, `save_freq=10000`
- numeric validation: `train.val_split=0.0` + `train.val_episode_indices=[100]`
- held-out numeric validation: `train.action_mse_val_freq_epochs=1`
- rollout validation: `train.rollout_val_freq_epochs=5`, `train.rollout_num_episodes=10`
- SwanLab: `train.use_swanlab=true`, project=`roboimi-vla`

---

## 1. Branch structure and key files

| Path | Purpose |
| --- | --- |
| `roboimi/demos/vla_scripts/train_vla.py` | Main training entry point; handles datasets, checkpoints, numeric validation, in-training rollout validation, SwanLab |
| `roboimi/demos/vla_scripts/eval_vla.py` | Single rollout / offline validation entry point; supports headless mode, summaries, trajectory image/video artifacts |
| `roboimi/vla/conf/config.yaml` | Global Hydra config; all training defaults live here |
| `roboimi/vla/conf/eval/eval.yaml` | Eval defaults; `eval.ckpt_path`, `eval.num_episodes`, and the artifact switches live here |
| `roboimi/vla/conf/agent/lewm_resnet_query_imf_attnres.yaml` | The branch's most-used agent; LeWM query fusion + IMF AttnRes head |
| `roboimi/vla/conf/backbone/lewm_resnet_query_fusion.yaml` | LeWM multi-view ResNet query fusion backbone config |
| `roboimi/vla/agent_imf.py` | `IMFVLAAgent` implementation; one-step IMF inference, LeWM loss, LeWM pretrained component loading |
| `roboimi/vla/data/simpe_robot_dataset.py` | Lazily loaded HDF5 dataset; also handles `episode_indices` filtering |
| `roboimi/vla/scripts/calculate_stats.py` | Recomputes `dataset_stats.pkl` |
| `experiment_suites/2026-04-21-lewm-fromscratch-old9-epoch50-roll5-val-20260421-153037/` | The current go-to suite; manifest, notes, launch log, and local launch script live here |

Additional notes:
- Run names on this branch typically look like `lewmimf-q08-ph08-ex08-emb384-l12-fromscratch-epoch50-step109350-5090g0-20260421-153037`
- Suffixes such as `q08/ph16/ex08` correspond to `agent.lewm_query_offsets`, `agent.pred_horizon`, and `agent.num_action_steps` respectively

---

## 2. The three machines and their environments

| Machine | GPU | repo / worktree | Python | Usual dataset path |
| --- | --- | --- | --- | --- |
| Local `droid-z790eagleax` | 1× RTX 5090 32GB | `/home/droid/project/roboimi/.worktrees/feat-lewm-imf-fusion` | `/home/droid/.conda/envs/roboimi/bin/python` | `/home/droid/project/diana_sim/sim_transfer` |
| 5880 node `100.73.14.65` | 2× RTX 5880 Ada 48GB | `/home/droid/roboimi_suite_20260416_lewm_imf_fusion` | `/home/droid/miniforge3/envs/roboimi/bin/python` | `/home/droid/sim_dataset/sim_transfer` |
| L20 node `100.119.99.14` | 8× NVIDIA L20 46GB | `/data/roboimi_suite_20260416_lewm_imf_fusion` | `/home/droid/miniforge3/envs/roboimi/bin/python` | `/data/simtransfer/current` |

Connecting:
- 5880: `ssh droid@100.73.14.65`
- L20: `ssh droid@100.119.99.14`

Rules of thumb:
- Local 5090: good for single smoke runs / small-scale main runs / local tuning
- 5880: good for 2 parallel main runs
- L20: good for large grids; keep both data and runs under `/data`

---

## 3. How the training flow works

The actual flow in `train_vla.py`:

1. Load the Hydra config and print the full cfg
2. Build train/val datasets via `build_train_val_datasets()`
3. Build train/val loaders with `DataLoader`
4. Load normalization statistics from `dataset_dir/dataset_stats.pkl`
5. Instantiate `IMFVLAAgent`
6. Optionally load:
   - `train.pretrained_ckpt`
   - `train.resume_ckpt`
   - `agent.lewm_pretrained_ckpt`
7. In the training loop, log train loss / lr every `log_freq` steps
8. Save `checkpoints/vla_model_step_*.pt` every `save_freq` steps
9. At the end of each epoch, run as configured:
   - held-out action MSE
   - rollout validation
10. Finally write:
    - `checkpoints/vla_model_best.pt`
    - `checkpoints/vla_model_final.pt`

Current best-model selection logic:

- **Before the first rollout reward is available**: pick best by `val_loss` (falling back to train loss)
- **After the first rollout**: prefer `rollout_avg_reward` to pick best

The output directory is usually pinned via `hydra.run.dir=...`; otherwise Hydra generates one.

---

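
The best-model selection rule above can be sketched as a small predicate. This is a hypothetical helper for illustration, not the literal code in `train_vla.py`:

```python
def is_new_best(metrics: dict, best: dict) -> bool:
    """Return True when `metrics` should replace `best` as the best checkpoint.

    Once any rollout reward has been seen, rewards decide (higher is better);
    before that, lower val_loss wins, falling back to train loss.
    """
    if metrics.get("rollout_avg_reward") is not None or best.get("rollout_avg_reward") is not None:
        new_r = metrics.get("rollout_avg_reward")
        old_r = best.get("rollout_avg_reward")
        return new_r is not None and (old_r is None or new_r > old_r)
    new_l = metrics.get("val_loss", metrics.get("train_loss"))
    old_l = best.get("val_loss", best.get("train_loss"))
    return new_l is not None and (old_l is None or new_l < old_l)
```

Note the asymmetry: once a rollout reward exists, a loss improvement alone can no longer replace the best checkpoint.
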
## 4. How the validation flow works

### 4.1 Held-out numeric validation

The current practice is not a random `val_split`, but rather:
- `train.val_split=0.0`
- `train.val_episode_indices=[100]`
- `train.action_mse_val_freq_epochs=1`

This runs `compute_action_mse_validation()` on `episode_100.hdf5` at the end of every epoch. The log keys are:
- console / `train_vla.log`: `held-out action MSE`
- SwanLab: `val/action_mse`

### 4.2 Rollout validation

In-training rollout validation is triggered via `train_vla.py -> run_rollout_validation() -> eval_vla._run_eval()`.

The usual in-training rollout constraints on this branch are:
- `train.rollout_val_freq_epochs=5`
- `train.rollout_num_episodes=10`
- `train.rollout_validate_on_checkpoint=false`
- headless forced on
- `verbose_action=false` forced
- `record_video=false` forced
- `save_trajectory_image=true` forced
- `trajectory_image_camera_name=front` forced
- `save_summary_json=true` forced

The rollout device / worker path has now been fixed to be **config-driven**:
- `train.rollout_device`: defaults to following `train.device`
- `train.rollout_num_workers`: defaults to `null`
  - when the rollout device is CPU, it degrades to `1`
  - when the rollout device is CUDA, it is inferred as `min(train.rollout_num_episodes, 8)`
- `train.rollout_cuda_devices`: defaults to `null`, equivalent to the currently visible logical GPU `[0]`
- `train.rollout_response_timeout_s`
- `train.rollout_server_startup_timeout_s`

So now:
- when training runs on `cuda`, **in-training rollouts run on the GPU by default**
- if `rollout_num_workers > 1`, rollouts automatically run in parallel
  - either as **multiple workers sharing one inference server on a single GPU**
  - or as **multiple servers across multiple GPUs, splitting the workers**

In-training rollout artifacts land by default in:
`<hydra.run.dir>/rollout_artifacts/<checkpoint_stem>/`

Typical files:
- `rollout_summary.json`
- `rollout_front_ep01_trajectory.png` ... `rollout_front_ep10_trajectory.png`

In the logs, watch for:
- `Epoch X rollout 平均奖励` (average reward)
- `最佳模型已更新` (best model updated)

---

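
The worker-count defaulting rule can be sketched as follows. This is a hypothetical helper mirroring the rules listed above, not the real resolver in the training script:

```python
def resolve_rollout_workers(rollout_device: str, num_episodes: int,
                            configured=None) -> int:
    """Default in-training rollout worker count.

    An explicit train.rollout_num_workers wins; otherwise a CPU rollout
    degrades to one worker and a CUDA rollout infers min(num_episodes, 8).
    """
    if configured is not None:
        return configured
    if not str(rollout_device).startswith("cuda"):
        return 1  # CPU rollouts run single-worker
    return min(num_episodes, 8)  # never spawn more workers than episodes, cap at 8
```
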
## 5. Dataset loading and the `val_episode_indices` mechanism

### 5.1 Dataset format

`SimpleRobotDataset` reads the `episode_*.hdf5` files under `dataset_dir`; each episode file must contain at least:
- `action`
- `observations/qpos`
- `observations/images/{cam_name}`

Commonly used cameras:
- `r_vis`
- `top`
- `front`

### 5.2 Lazy-loading behavior

`roboimi/vla/data/simpe_robot_dataset.py` lazily loads frame by frame; it never reads the whole HDF5 set into memory at once.

It:
- scans the directory for HDF5 files
- builds `available_episode_indices` from the episode number in each filename (e.g. `episode_100.hdf5` -> `100`)
- keeps an LRU cache of HDF5 file handles inside each worker

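
The per-worker handle cache can be sketched like this. `EpisodeFileCache` is a hypothetical class shown only to illustrate the caching behavior; `opener` would be `h5py.File` in the real dataset, and it is injected here so the eviction policy is testable without HDF5 files:

```python
from collections import OrderedDict


class EpisodeFileCache:
    """LRU cache of open file handles, one instance per DataLoader worker."""

    def __init__(self, opener, max_open: int = 8):
        self._opener = opener          # e.g. h5py.File in practice
        self._max_open = max_open
        self._handles = OrderedDict()  # path -> open handle, insertion-ordered

    def get(self, path):
        if path in self._handles:
            self._handles.move_to_end(path)  # mark as most recently used
            return self._handles[path]
        handle = self._handles[path] = self._opener(path)
        if len(self._handles) > self._max_open:
            _, evicted = self._handles.popitem(last=False)  # least recently used
            evicted.close()
        return handle
```
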
### 5.3 How `val_episode_indices` splits

The logic in `build_train_val_datasets()`:

1. Instantiate the full dataset once
2. Read `dataset.available_episode_indices`
3. Check that every entry in `train.val_episode_indices` exists
4. Instantiate twice more with `episode_indices=`:
   - train dataset = all episodes minus the held-out episodes
   - val dataset = only the held-out episodes

Therefore:
- `train.val_episode_indices=[100]` means "use all of `episode_100.hdf5` as the held-out val set"
- a nonexistent episode raises an error immediately
- putting every episode into `val_episode_indices` also raises, because the train set would be empty

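
The split rule above can be sketched as a pure function. This is a hypothetical helper, not the real `build_train_val_datasets()` API:

```python
def split_episode_indices(available, val_indices):
    """Split available episode indices into (train, val) per the rules above."""
    available = set(available)
    missing = sorted(set(val_indices) - available)
    if missing:
        # a nonexistent held-out episode fails fast
        raise ValueError(f"held-out episodes not found: {missing}")
    train = sorted(available - set(val_indices))
    if not train:
        # holding out everything would leave nothing to train on
        raise ValueError("val_episode_indices would leave the train set empty")
    return train, sorted(val_indices)
```
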
### 5.4 Image resize and extra LeWM fields

The dataset-side resize defaults come from:
- `data.image_resize_shape`
- overridden by `agent.vision_backbone.dataset_image_resize_shape` when the backbone sets it

Besides the usual batch keys:
- `observation.state`
- `observation.<cam>`
- `action`

the dataset also returns, when LeWM is enabled:
- `lewm.observation.state`
- `lewm.observation.<cam>`
- `lewm.future.state`
- `lewm.future.<cam>`

### 5.5 Statistics file

Both training and inference rely on `dataset_stats.pkl` by default. After a dataset update, recompute it:

```bash
/home/droid/.conda/envs/roboimi/bin/python roboimi/vla/scripts/calculate_stats.py \
  --dataset_dir /home/droid/project/diana_sim/sim_transfer
```

On the remote machines, just point `--dataset_dir` at that host's path.

---

## 6. SwanLab behavior

The config default is `train.use_swanlab=false`, but the go-to recipes on this branch almost always enable it explicitly:
- `train.use_swanlab=true`
- `train.swanlab_project=roboimi-vla`
- `train.swanlab_run_name=<run_name>`

SwanLab behavior in `train_vla.py`:
- on init, uploads the `train` / `data` / `agent` config sections
- during training, logs:
  - `train/loss`
  - `train/lr`
  - `train/best_loss`
  - `train/step`
- on checkpoint validation, logs:
  - `val/loss`
- on held-out numeric validation, logs:
  - `val/action_mse`
- on rollout validation, logs:
  - `rollout/avg_reward`
  - `rollout/epoch`
- at the end of training, logs:
  - `final/checkpoint_path`
  - `final/best_checkpoint_path`

Front-view trajectory PNGs produced by in-training rollouts are uploaded to SwanLab best-effort; a failure only logs a warning and never interrupts training.

---

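
The best-effort upload pattern can be sketched as follows. This is a hypothetical helper: `run.log_image` stands in for whatever the real SwanLab call is, and only the never-interrupt-training contract is the point:

```python
import logging

log = logging.getLogger(__name__)


def upload_images_best_effort(run, image_paths):
    """Upload each image, reducing any failure to a warning.

    Returns the number of successful uploads; never raises, so logging
    problems cannot take down the training loop.
    """
    uploaded = 0
    for path in image_paths:
        try:
            run.log_image(path)
            uploaded += 1
        except Exception as exc:  # deliberately broad: never propagate
            log.warning("trajectory image upload failed for %s: %s", path, exc)
    return uploaded
```
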
## 7. Parallel rollout notes

### 7.1 Where this capability comes from

Parallel rollout on this branch is not DataLoader parallelism; it is the **multiprocess rollout path in `eval_vla.py`**.
Reference source:
`/home/droid/project/roboimi/.worktrees/multiprocess-rollout/roboimi/demos/vla_scripts/eval_vla.py`

That path is controlled by:
- `eval.num_workers`
- `eval.cuda_devices`

Semantics:
- `eval.num_workers`: number of environment workers; episodes are split across them
- `eval.cuda_devices`: which logical GPUs the inference servers bind to

### 7.2 Two common modes

1. **Single machine, single GPU, multiple workers sharing one GPU**
   - typical: the local 5090 has only one GPU, but you want 4 rollout workers running environments in parallel
   - form: `eval.device=cuda eval.num_workers=4 'eval.cuda_devices=[0]'`
   - this gives **1 CUDA inference server + 4 env workers**

2. **Single machine, multiple GPUs, multiple servers splitting the workers**
   - typical: the 5880 node has 2 GPUs, the L20 node has more
   - form: `eval.device=cuda eval.num_workers=8 'eval.cuda_devices=[0,1]'`
   - workers are assigned to the servers round-robin

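
The round-robin assignment can be sketched in one line. This is a hypothetical helper; one inference server is assumed to run per entry of `cuda_devices`:

```python
def assign_workers_to_servers(num_workers, cuda_devices):
    """Map each env worker index to a logical GPU id, dealing them out in turn."""
    return {w: cuda_devices[w % len(cuda_devices)] for w in range(num_workers)}
```

For example, 8 workers over `[0, 1]` alternate between the two servers, so each GPU serves 4 workers.
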
### 7.3 Operational caveats

- Parallel rollout depends on the **multiprocess eval path**, not `train.num_workers`
- `train.num_workers` controls DataLoader workers; it has nothing to do with rollout parallelism
- `eval.num_workers > 1` requires `eval.headless=true`
- the worker count is automatically capped at `eval.num_episodes`
- multiprocess rollout now supports **per-episode trajectory image PNGs**; with multiple workers, each worker writes images under its own artifact subdirectory and the summary carries the corresponding paths back
- with multiple workers, still do not enable all of these at once:
  - `eval.record_video=true`
  - `eval.save_trajectory=true`
  - `eval.save_trajectory_npz=true`
- `eval.save_trajectory_image=true` is now safe to enable, and pairs well with running parallel reward evaluation and qualitative checks together

### 7.4 Parallel rollout command templates

**5090, single GPU, 4 workers:**

```bash
/home/droid/.conda/envs/roboimi/bin/python roboimi/demos/vla_scripts/eval_vla.py \
  agent=lewm_resnet_query_imf_attnres \
  data.dataset_dir=/home/droid/project/diana_sim/sim_transfer \
  train.device=cuda eval.device=cuda eval.headless=true eval.verbose_action=false \
  eval.ckpt_path=/home/droid/project/roboimi/.worktrees/feat-lewm-imf-fusion/runs/<run_name>/checkpoints/vla_model_best.pt \
  eval.num_episodes=10 eval.num_workers=4 'eval.cuda_devices=[0]' \
  eval.save_summary_json=true eval.artifact_dir=/tmp/lewm_parallel_eval_5090
```

**5880, two GPUs, 8 workers:**

```bash
/home/droid/miniforge3/envs/roboimi/bin/python roboimi/demos/vla_scripts/eval_vla.py \
  agent=lewm_resnet_query_imf_attnres \
  data.dataset_dir=/home/droid/sim_dataset/sim_transfer \
  train.device=cuda eval.device=cuda eval.headless=true eval.verbose_action=false \
  eval.ckpt_path=/home/droid/roboimi_suite_20260416_lewm_imf_fusion/runs/<run_name>/checkpoints/vla_model_best.pt \
  eval.num_episodes=10 eval.num_workers=8 'eval.cuda_devices=[0,1]' \
  eval.save_summary_json=true eval.artifact_dir=/tmp/lewm_parallel_eval_5880
```

---

## 8. Current go-to commands / scripts

### 8.1 Local 5090: use the suite script directly

Ready-made script:
`experiment_suites/2026-04-21-lewm-fromscratch-old9-epoch50-roll5-val-20260421-153037/launch_local_5090.sh`

Run:

```bash
bash experiment_suites/2026-04-21-lewm-fromscratch-old9-epoch50-roll5-val-20260421-153037/launch_local_5090.sh
```

### 8.2 Local 5090: launch the same recipe manually

```bash
/home/droid/.conda/envs/roboimi/bin/python roboimi/demos/vla_scripts/train_vla.py \
  agent=lewm_resnet_query_imf_attnres \
  data.dataset_dir=/home/droid/project/diana_sim/sim_transfer \
  'agent.lewm_query_offsets=[8]' \
  agent.pred_horizon=8 \
  agent.num_action_steps=8 \
  train.device=cuda \
  train.batch_size=32 \
  train.lr=0.0001 \
  train.max_steps=109350 \
  train.num_workers=4 \
  train.save_freq=10000 \
  train.rollout_validate_on_checkpoint=false \
  train.rollout_val_freq_epochs=5 \
  train.rollout_num_episodes=10 \
  train.val_split=0.0 \
  'train.val_episode_indices=[100]' \
  train.action_mse_val_freq_epochs=1 \
  train.use_swanlab=true \
  train.swanlab_project=roboimi-vla \
  train.swanlab_run_name=lewmimf-q08-ph08-ex08-emb384-l12-fromscratch-epoch50-step109350-5090g0-20260421-153037 \
  train.pretrained_ckpt=null \
  agent.lewm_pretrained_ckpt=null \
  hydra.run.dir=/home/droid/project/roboimi/.worktrees/feat-lewm-imf-fusion/runs/lewmimf-q08-ph08-ex08-emb384-l12-fromscratch-epoch50-step109350-5090g0-20260421-153037
```

### 8.3 5880: go-to command template

```bash
ssh droid@100.73.14.65
cd /home/droid/roboimi_suite_20260416_lewm_imf_fusion
/home/droid/miniforge3/envs/roboimi/bin/python roboimi/demos/vla_scripts/train_vla.py \
  agent=lewm_resnet_query_imf_attnres \
  data.dataset_dir=/home/droid/sim_dataset/sim_transfer \
  'agent.lewm_query_offsets=[8]' \
  agent.pred_horizon=16 \
  agent.num_action_steps=8 \
  train.device=cuda train.batch_size=32 train.lr=0.0001 train.max_steps=109350 \
  train.num_workers=4 train.save_freq=10000 train.rollout_validate_on_checkpoint=false \
  train.rollout_val_freq_epochs=5 train.rollout_num_episodes=10 train.val_split=0.0 \
  'train.val_episode_indices=[100]' train.action_mse_val_freq_epochs=1 \
  train.use_swanlab=true train.swanlab_project=roboimi-vla \
  train.swanlab_run_name=lewmimf-q08-ph16-ex08-emb384-l12-fromscratch-epoch50-step109350-5880g0-20260421-153037 \
  train.pretrained_ckpt=null agent.lewm_pretrained_ckpt=null \
  hydra.run.dir=/home/droid/roboimi_suite_20260416_lewm_imf_fusion/runs/lewmimf-q08-ph16-ex08-emb384-l12-fromscratch-epoch50-step109350-5880g0-20260421-153037
```

### 8.4 L20: go-to command template

```bash
ssh droid@100.119.99.14
cd /data/roboimi_suite_20260416_lewm_imf_fusion
/home/droid/miniforge3/envs/roboimi/bin/python roboimi/demos/vla_scripts/train_vla.py \
  agent=lewm_resnet_query_imf_attnres \
  data.dataset_dir=/data/simtransfer/current \
  'agent.lewm_query_offsets=[16]' \
  agent.pred_horizon=16 \
  agent.num_action_steps=16 \
  train.device=cuda train.batch_size=32 train.lr=0.0001 train.max_steps=109350 \
  train.num_workers=4 train.save_freq=10000 train.rollout_validate_on_checkpoint=false \
  train.rollout_val_freq_epochs=5 train.rollout_num_episodes=10 train.val_split=0.0 \
  'train.val_episode_indices=[100]' train.action_mse_val_freq_epochs=1 \
  train.use_swanlab=true train.swanlab_project=roboimi-vla \
  train.swanlab_run_name=lewmimf-q16-ph16-ex16-emb384-l12-fromscratch-epoch50-step109350-l20g0-20260421-153037 \
  train.pretrained_ckpt=null agent.lewm_pretrained_ckpt=null \
  hydra.run.dir=/data/roboimi_suite_20260416_lewm_imf_fusion/runs/lewmimf-q16-ph16-ex16-emb384-l12-fromscratch-epoch50-step109350-l20g0-20260421-153037
```

### 8.5 One-off offline validation (parallel now supported on this branch)

**Single GPU / 4 workers:**

```bash
/home/droid/.conda/envs/roboimi/bin/python roboimi/demos/vla_scripts/eval_vla.py \
  agent=lewm_resnet_query_imf_attnres \
  data.dataset_dir=/home/droid/project/diana_sim/sim_transfer \
  train.device=cuda eval.device=cuda \
  eval.ckpt_path=/home/droid/project/roboimi/.worktrees/feat-lewm-imf-fusion/runs/<run_name>/checkpoints/vla_model_best.pt \
  eval.num_episodes=10 eval.num_workers=4 'eval.cuda_devices=[0]' \
  eval.headless=true eval.verbose_action=false \
  eval.save_summary_json=true eval.save_trajectory_image=true \
  eval.trajectory_image_camera_name=front \
  eval.artifact_dir=/tmp/lewm_eval_front
```

**Enabling parallel GPU rollout inside training (recommended to spell everything out):**

```bash
/home/droid/.conda/envs/roboimi/bin/python roboimi/demos/vla_scripts/train_vla.py \
  agent=lewm_resnet_query_imf_attnres \
  data.dataset_dir=/home/droid/project/diana_sim/sim_transfer \
  'agent.lewm_query_offsets=[8]' \
  agent.pred_horizon=8 \
  agent.num_action_steps=8 \
  train.device=cuda \
  train.batch_size=32 \
  train.lr=0.0001 \
  train.max_steps=109350 \
  train.num_workers=4 \
  train.save_freq=10000 \
  train.rollout_val_freq_epochs=5 \
  train.rollout_num_episodes=10 \
  train.rollout_device=cuda \
  train.rollout_num_workers=4 \
  'train.rollout_cuda_devices=[0]' \
  train.rollout_validate_on_checkpoint=false \
  train.val_split=0.0 \
  'train.val_episode_indices=[100]' \
  train.action_mse_val_freq_epochs=1 \
  train.use_swanlab=true \
  train.swanlab_project=roboimi-vla \
  train.swanlab_run_name=<run_name> \
  hydra.run.dir=/home/droid/project/roboimi/.worktrees/feat-lewm-imf-fusion/runs/<run_name>
```

### 8.6 Monitoring logs

```bash
tail -f runs/<run_name>/launch.stdout.log
tail -f runs/<run_name>/train_vla.log
```

On remote machines, swap `runs/<run_name>` for the absolute path recorded in the manifest.

---

## 9. Operational advice

- **Treat the suite's `manifest.json` / `notes.md` / `launch_logs/*.launch.log` as the source of truth**; do not hand-write commands that diverge from historical runs
- For the current standard validation, explicitly add:
  - `train.val_split=0.0`
  - `train.val_episode_indices=[100]`
  - `train.action_mse_val_freq_epochs=1`
  - `train.rollout_val_freq_epochs=5`
  - `train.rollout_num_episodes=10`
- When comparing different horizons / action steps on this branch, change only:
  - `agent.lewm_query_offsets`
  - `agent.pred_horizon`
  - `agent.num_action_steps`
- To reproduce the 2026-04-21 from-scratch round, remember to also set:
  - `train.pretrained_ckpt=null`
  - `agent.lewm_pretrained_ckpt=null`

@@ -0,0 +1,42 @@

# Streaming HDF5 EE Action Dataset Implementation Plan

> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking.

**Goal:** Switch Diana simulation collection to streaming HDF5 writes, save images as 256x256 four-camera views, and change `/action` to the raw end-effector pose action from before IK.

**Architecture:** Add a standalone streaming HDF5 episode writer that writes qpos, raw actions, and resized images frame by frame, committing atomically when an episode succeeds and deleting the temp file when it fails. The collection script only runs the rollout and hands each step's observation/action to the writer, so full-episode data never piles up in memory.

**Tech Stack:** Python, h5py, numpy, cv2, unittest, MuJoCo demo scripts

---

### Task 1: Establish the test boundary for the streaming writer

**Files:**
- Create: `tests/test_streaming_episode_writer.py`
- Create: `roboimi/utils/streaming_episode_writer.py`

- [ ] **Step 1: Write the failing test**
- [ ] **Step 2: Run `python -m unittest tests.test_streaming_episode_writer -v` and confirm it fails because the writer module does not exist**
- [ ] **Step 3: Implement the minimal streaming writer with temp-file commit/discard, per-frame append, and 256x256 image resize**
- [ ] **Step 4: Re-run `python -m unittest tests.test_streaming_episode_writer -v` and confirm it passes**

### Task 2: Wire into the Diana collection script

**Files:**
- Modify: `roboimi/demos/diana_record_sim_episodes.py`
- Reuse: `roboimi/utils/streaming_episode_writer.py`

- [ ] **Step 1: Replace in-memory `data_dict` / `obs` accumulation with the per-episode streaming writer lifecycle**
- [ ] **Step 2: Keep four cameras (`angle`, `r_vis`, `top`, `front`) and resize to 256x256 before persistence**
- [ ] **Step 3: Capture raw policy output before IK and write that to `/action`**
- [ ] **Step 4: On success commit to `episode_{idx}.hdf5`; on failure remove the temp file**

### Task 3: Verify the changes

**Files:**
- Verify only

- [ ] **Step 1: Run unit tests for the writer**
- [ ] **Step 2: Run one end-to-end collection episode and stop after `episode_0.hdf5` becomes readable**
- [ ] **Step 3: Verify HDF5 keys and shapes: `action=(700,16)`, image datasets are `(700,256,256,3)`, and `/action` matches raw EE action semantics**

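
The commit/discard lifecycle from Task 1 can be sketched with plain files. `AtomicEpisodeFile` is a hypothetical class: the real writer would stream HDF5 datasets via h5py, and only the temp-file protocol is shown here:

```python
import os


class AtomicEpisodeFile:
    """Write to `<final>.tmp`; rename into place on success, delete on failure."""

    def __init__(self, final_path):
        self.final_path = final_path
        self.tmp_path = final_path + ".tmp"
        self._f = open(self.tmp_path, "wb")

    def append(self, chunk: bytes):
        # per-frame append; the real writer would resize images and grow datasets
        self._f.write(chunk)

    def commit(self):
        self._f.close()
        os.replace(self.tmp_path, self.final_path)  # atomic rename on POSIX

    def discard(self):
        self._f.close()
        os.remove(self.tmp_path)  # failed episodes leave nothing behind
```

A reader that only globs `episode_*.hdf5` therefore never sees a half-written episode: it either finds the committed file or nothing.
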
@@ -0,0 +1,26 @@

# Raw Action Trajectory Viewer Implementation Plan

> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking.

**Goal:** In an interactive MuJoCo simulation window, mark the raw EE action trajectory exported from a rollout as a red trace and launch the simulation for manual inspection.

**Architecture:** Read the raw_action / step data from an existing trajectory artifact, generate end-effector trajectory points for both arms, and continuously inject red markers in the viewer render loop. Keep the implementation as an independent, reusable small script so the training/eval main paths are untouched.

**Tech Stack:** Python, NumPy, MuJoCo viewer, unittest/mock.

---

### Task 1: Extract the raw_action trajectory and generate visualization points
- [ ] Write a failing test that validates extracting left/right-arm trajectory points from trajectory.npz
- [ ] Implement the minimal helper
- [ ] Run the tests and confirm they pass

### Task 2: Render the red trajectory in the viewer and support interactive inspection
- [ ] Write a failing test that validates the marker configuration/calls
- [ ] Implement the viewer visualization script
- [ ] Run the tests and confirm they pass

### Task 3: Launch the real simulation window for manual inspection
- [ ] Launch the viewer with an existing trajectory artifact
- [ ] Confirm the window is interactive and the red trace appears
- [ ] Report the launch procedure and script path to the user

docs/superpowers/plans/2026-03-31-rollout-artifacts.md (new file, 44 lines)

# Rollout Artifacts Implementation Plan

> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking.

**Goal:** Extend rollout evaluation so one selected checkpoint can be run once with video capture, timing breakdown, and saved EE trajectory artifacts.

**Architecture:** Keep the implementation centered in `eval_vla.py` so existing training-time rollout validation remains compatible. Add config-gated artifact capture helpers, serialize outputs under the eval run directory, and add lightweight tests for helper behavior and summary wiring; default eval behavior must remain unchanged when artifact capture is off.

**Tech Stack:** Python, Hydra/OmegaConf, NumPy, OpenCV, JSON, PyTorch unittest/mocking.

---

### Task 1: Add artifact capture configuration and helper wiring

**Files:**
- Modify: `roboimi/demos/vla_scripts/eval_vla.py`
- Modify: `roboimi/vla/conf/eval/eval.yaml`
- Test: `tests/test_eval_vla_rollout_artifacts.py`

- [ ] **Step 1: Write failing tests for optional artifact config / summary wiring**
- [ ] **Step 2: Implement config-backed artifact flags and output paths with defaults that write nothing**
- [ ] **Step 3: Verify existing eval call sites still work with defaults**

### Task 2: Add timing breakdown, video recording, and trajectory export

**Files:**
- Modify: `roboimi/demos/vla_scripts/eval_vla.py`
- Test: `tests/test_eval_vla_rollout_artifacts.py`

- [ ] **Step 1: Write failing tests for timing aggregation, trajectory serialization, and summary schema**
- [ ] **Step 2: Implement per-step timing capture for `obs_read_ms`, `preprocess_ms`, `inference_ms`, `env_step_ms`, `loop_total_ms`**
- [ ] **Step 3: Implement MP4 recording from a chosen camera stream and canonical `trajectory.npz` export using `left_link7/right_link7` executed poses after `env.step`**
- [ ] **Step 4: Run focused tests and fix issues**

### Task 3: Stop training safely and execute one real rollout

**Files:**
- Use: `roboimi/demos/vla_scripts/eval_vla.py`
- Output: `runs/.../eval_artifacts/...`

- [ ] **Step 1: Stop the active training process, wait for exit, and confirm the target checkpoint is readable**
- [ ] **Step 2: Select the latest completed checkpoint if an explicit one is not provided; fall back to a prior completed / best checkpoint if needed**
- [ ] **Step 3: Run one headless rollout with artifact capture enabled**
- [ ] **Step 4: Verify the MP4 / timing summary / trajectory files exist and summarize findings**

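
The per-step timing capture planned in Task 2 can be sketched with a context-manager helper. This is a hypothetical class; only the `*_ms` field names come from the plan above:

```python
import time
from collections import defaultdict
from contextlib import contextmanager


class StepTimer:
    """Accumulate wall-clock durations per named section of the rollout loop."""

    def __init__(self):
        self.samples = defaultdict(list)

    @contextmanager
    def section(self, name):
        start = time.perf_counter()
        try:
            yield
        finally:
            # store milliseconds, matching the *_ms naming in the plan
            self.samples[name].append((time.perf_counter() - start) * 1000.0)

    def summary(self):
        """Mean duration per section, suitable for a timing-summary JSON."""
        return {name: sum(ms) / len(ms) for name, ms in self.samples.items()}
```

In the rollout loop this would wrap each phase, e.g. `with timer.section("inference_ms"): action = agent.select_action(obs)`, and `timer.summary()` would feed the summary file.
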
@@ -0,0 +1,268 @@

# IMF-AttnRes Policy Migration Implementation Plan

> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking.

**Goal:** Migrate the IMF-AttnRes model, training objective, and one-step inference mechanism from the external `diffusion_policy@185ed659` into RoboIMI, and launch same-hyperparameter training while keeping three-camera vision conditioning and the existing training/rollout workflows.

**Architecture:** Keep RoboIMI's existing three-camera ResNet observation encoding, normalization, queue-based online rollout, and training scripts; add the AttnRes components plus an IMF transformer head, and add an IMF-specific agent that overrides the DDPM loss / DDIM inference semantics. The training script gets only minimal wiring changes so the new head/agent can use the existing optimizer, checkpoints, SwanLab, and headless rollout.

**Tech Stack:** PyTorch, Hydra, diffusers schedulers (kept only for compatible initialization), MuJoCo rollout, unittest, SwanLab

---

## File Map

### New files
- `roboimi/vla/models/heads/attnres_transformer_components.py` — local IMF AttnRes base components
- `roboimi/vla/models/heads/imf_transformer1d.py` — IMF transformer head, exposing `forward(sample, r, t, cond=None)`
- `roboimi/vla/agent_imf.py` — IMF-specific VLA agent reusing the existing observation/queue/normalization logic while overriding loss / inference
- `roboimi/vla/conf/head/imf_transformer1d.yaml` — IMF head config
- `roboimi/vla/conf/agent/resnet_imf_attnres.yaml` — IMF agent + backbone/head composite config
- `tests/test_imf_transformer1d_external_alignment.py` — alignment tests against external `185ed659`
- `tests/test_imf_vla_agent.py` — loss / inference / queue semantics tests for the IMF agent

### Modified files
- `roboimi/demos/vla_scripts/train_vla.py` — optimizer parameter-group wiring; make sure the new agent trains seamlessly
- `roboimi/vla/conf/config.yaml` — keep defaults unchanged; enable the IMF agent only via override
- `tests/test_train_vla_transformer_optimizer.py` — cover optimizer-group behavior for the IMF head
- (if needed) `roboimi/vla/models/heads/__init__.py` or a nearby export file — expose the new head

---

### Task 1: Write the IMF transformer alignment tests

**Files:**
- Create: `tests/test_imf_transformer1d_external_alignment.py`
- Reference: `/home/droid/project/diffusion_policy/diffusion_policy/model/diffusion/attnres_transformer_components.py`
- Reference: `/home/droid/project/diffusion_policy/diffusion_policy/model/diffusion/imf_transformer_for_diffusion.py`

- [ ] **Step 1: Write a failing test verifying that the local IMF head matches external `185ed659` on state-dict keys, forward shapes, forward values, and optim groups**

```python
with torch.no_grad():
    external_out = external_model(sample=sample, r=r, t=t, cond=cond)
    local_out = local_model(sample=sample, r=r, t=t, cond=cond)
assert torch.allclose(local_out, external_out, atol=1e-6, rtol=1e-5)
```

- [ ] **Step 2: Run the test and confirm it currently fails**

Run: `python -m unittest tests.test_imf_transformer1d_external_alignment -v`
Expected: FAIL, complaining that the `imf_transformer1d` / `attnres` modules do not exist

- [ ] **Step 3: If the test needs the existing external-loader logic, copy the minimal necessary helpers from `tests/test_transformer1d_external_alignment.py` to avoid depending on session context**

- [ ] **Step 4: Commit the test skeleton**

```bash
git add tests/test_imf_transformer1d_external_alignment.py
git commit -m "test: add IMF transformer external alignment coverage"
```

### Task 2: Implement the AttnRes components and the IMF transformer head

**Files:**
- Create: `roboimi/vla/models/heads/attnres_transformer_components.py`
- Create: `roboimi/vla/models/heads/imf_transformer1d.py`
- Modify: `tests/test_imf_transformer1d_external_alignment.py`

- [ ] **Step 1: Port the AttnRes base components from external `185ed659`, keeping names and parameter semantics identical**

Must include:
- `RMSNorm`
- `RMSNormNoWeight`
- `precompute_rope_freqs`
- `apply_rope`
- `GroupedQuerySelfAttention`
- `SwiGLUFFN`
- `AttnResOperator`
- `AttnResSubLayer`
- `AttnResTransformerBackbone`

- [ ] **Step 2: Implement the local IMF head in `imf_transformer1d.py`**

Must satisfy:
- `forward(sample, r, t, cond=None)`
- supports `backbone_type='attnres_full'` by default
- the token sequence is `[r_token, t_token, cond_tokens..., sample_tokens...]`
- the output slices back only the sample-token segment
- keeps `get_optim_groups()` for AdamW grouping

- [ ] **Step 3: Run the alignment tests and fix any state-dict key / init / no-decay parameter-group mismatches**

Run: `python -m unittest tests.test_imf_transformer1d_external_alignment -v`
Expected: PASS

- [ ] **Step 4: Commit the model component implementation**

```bash
git add roboimi/vla/models/heads/attnres_transformer_components.py \
    roboimi/vla/models/heads/imf_transformer1d.py \
    tests/test_imf_transformer1d_external_alignment.py
git commit -m "feat: add IMF AttnRes transformer head"
```

### Task 3: Write the IMF agent behavior tests

**Files:**
- Create: `tests/test_imf_vla_agent.py`
- Reference: `roboimi/vla/agent.py`
- Reference: `tests/test_resnet_transformer_agent_wiring.py`

- [ ] **Step 1: Write failing tests covering the IMF agent's core contract**

Must cover:
1. `compute_loss()` accepts the current batch structure and returns a scalar loss
2. `predict_action()` outputs `(B, pred_horizon, action_dim)`
3. `select_action()` still works with queue/chunk semantics
4. `predict_action()` does not run a multi-step DDIM loop; it triggers only a one-step IMF sample
5. when `action_is_pad` is present, the loss counts only valid actions

- [ ] **Step 2: Use a stub backbone / stub head that records call arguments, verifying that `r,t,cond` are passed through and the observation-conditioning dimensions are correct**

```python
self.assertEqual(recorded['cond'].shape, (B, obs_horizon, expected_cond_dim))
self.assertTrue(torch.allclose(recorded['r'], torch.zeros(B)))
self.assertTrue(torch.allclose(recorded['t'], torch.ones(B)))
```

- [ ] **Step 3: Run the tests and confirm they currently fail**

Run: `python -m unittest tests.test_imf_vla_agent -v`
Expected: FAIL, complaining that `roboimi.vla.agent_imf` does not exist

- [ ] **Step 4: Commit the test skeleton**

```bash
git add tests/test_imf_vla_agent.py
git commit -m "test: add IMF VLA agent behavior coverage"
```

### Task 4: Implement the IMF agent and Hydra wiring

**Files:**
- Create: `roboimi/vla/agent_imf.py`
- Create: `roboimi/vla/conf/head/imf_transformer1d.yaml`
- Create: `roboimi/vla/conf/agent/resnet_imf_attnres.yaml`
- Modify: `roboimi/demos/vla_scripts/train_vla.py`
- Modify: `tests/test_train_vla_transformer_optimizer.py`
- Modify: `tests/test_imf_vla_agent.py`

- [ ] **Step 1: Implement `IMFVLAAgent` on top of `VLAAgent`**

Implementation strategy:

- Reuse `VLAAgent.__init__`, `_build_cond()`, `reset()`, `_populate_queues()`, `_prepare_observation_batch()`, `select_action()`, `get_normalization_stats()`
- Override:
  - `compute_loss()` -> IMF objective
  - `predict_action()` -> one-step sample
- Provide internal helpers:
  - `_broadcast_batch_time`
  - `_apply_conditioning` (if needed)
  - `_compute_u_and_du_dt`
  - `_compound_velocity`
  - `_sample_one_step`
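The one-step sampling helper can be sketched as below. The `r=0, t=1` convention follows the stub-head assertions in Task 3; the head signature is the one specified in Task 2, and everything else (function name, noise initialization) is an assumption:

```python
import torch

def sample_one_step(head, cond, pred_horizon: int, action_dim: int):
    """Sketch of _sample_one_step: map pure noise to an action chunk with a
    single head call at r=0, t=1 -- no DDIM loop. `head` follows the plan's
    signature head(sample, r, t, cond=None)."""
    B = cond.shape[0]
    noise = torch.randn(B, pred_horizon, action_dim)
    r = torch.zeros(B)  # target time
    t = torch.ones(B)   # source time
    return head(noise, r, t, cond=cond)
```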
- [ ] **Step 2: Add a CUDA math-SDPA fallback on the JVP path, keeping the external repo's stability strategy**
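A minimal sketch of that fallback, assuming the older `torch.backends.cuda.sdp_kernel` context-manager API (newer torch versions expose `torch.nn.attention.sdpa_kernel` instead); the wrapper name is illustrative:

```python
import torch
from torch.func import jvp

def jvp_with_math_sdpa(fn, primal, tangent):
    """Run a forward-mode JVP through `fn`; on CUDA, force the math SDPA
    backend, since fused flash/mem-efficient kernels may not support
    forward-mode AD."""
    if torch.cuda.is_available():
        with torch.backends.cuda.sdp_kernel(
            enable_flash=False, enable_mem_efficient=False, enable_math=True
        ):
            return jvp(fn, (primal,), (tangent,))
    return jvp(fn, (primal,), (tangent,))
```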
- [ ] **Step 3: Add Hydra configs so `agent=resnet_imf_attnres` can be instantiated**

Key defaults:

- `_target_: roboimi.vla.agent_imf.IMFVLAAgent`
- `head._target_: roboimi.vla.models.heads.imf_transformer1d.IMFTransformer1D`
- `head.backbone_type: attnres_full`
- `head.causal_attn: false`
- `head.time_as_cond: true`
- `head.n_cond_layers: 0`
- `inference_steps: 1`
- `camera_names: ${data.camera_names}`
- `vision_backbone.camera_names: ${agent.camera_names}`

- [ ] **Step 4: Make the training script reuse parameter grouping for any head that exposes `get_optim_groups()`, instead of hard-coding the old transformer head_type**

Recommended minimal change:

```python
use_head_groups = callable(getattr(noise_pred_net, 'get_optim_groups', None))
```
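In context, that flag might be used as in the sketch below; the `get_optim_groups(weight_decay=...)` group schema and the fallback single-group path are assumptions, not the repo's actual code:

```python
import torch
import torch.nn as nn

def build_optimizer(noise_pred_net, lr=1e-4, weight_decay=1e-5):
    """Sketch: prefer the head's own get_optim_groups() when available,
    otherwise fall back to one flat parameter group."""
    use_head_groups = callable(getattr(noise_pred_net, 'get_optim_groups', None))
    if use_head_groups:
        groups = noise_pred_net.get_optim_groups(weight_decay=weight_decay)
    else:
        groups = [{"params": list(noise_pred_net.parameters()),
                   "weight_decay": weight_decay}]
    return torch.optim.AdamW(groups, lr=lr)
```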
- [ ] **Step 5: Run the tests and fix wiring issues**

Run:
- `python -m unittest tests.test_imf_vla_agent -v`
- `python -m unittest tests.test_train_vla_transformer_optimizer -v`

Expected: PASS

- [ ] **Step 6: Commit the agent / config / train-script wiring**

```bash
git add roboimi/vla/agent_imf.py \
    roboimi/vla/conf/head/imf_transformer1d.yaml \
    roboimi/vla/conf/agent/resnet_imf_attnres.yaml \
    roboimi/demos/vla_scripts/train_vla.py \
    tests/test_imf_vla_agent.py \
    tests/test_train_vla_transformer_optimizer.py
git commit -m "feat: add IMF VLA agent and training wiring"
```
### Task 5: Integration verification and training launch

**Files:**
- Modify: none required unless verification exposes real issues
- Use run artifacts under: `runs/`

- [ ] **Step 1: Run the focused test set**

Run:
```bash
python -m unittest \
    tests.test_imf_transformer1d_external_alignment \
    tests.test_imf_vla_agent \
    tests.test_resnet_transformer_agent_wiring \
    tests.test_train_vla_transformer_optimizer -v
```
Expected: PASS
- [ ] **Step 2: Run a minimal GPU training smoke job (no long run needed)**

Run:
```bash
/home/droid/.conda/envs/roboimi/bin/python roboimi/demos/vla_scripts/train_vla.py \
    agent=resnet_imf_attnres \
    data.dataset_dir=/home/droid/project/diana_sim/sim_transfer \
    data.camera_names=[r_vis,top,front] \
    train.device=cuda train.max_steps=2 train.batch_size=4 train.num_workers=2 \
    train.use_swanlab=false train.rollout_val_freq_epochs=0
```
Expected: completes 2 steps successfully, produces a checkpoint and logs, no shape or JVP errors
- [ ] **Step 3: Launch IMF training with the production parameters**

Run:
```bash
/home/droid/.conda/envs/roboimi/bin/python roboimi/demos/vla_scripts/train_vla.py \
    agent=resnet_imf_attnres \
    data.dataset_dir=/home/droid/project/diana_sim/sim_transfer \
    data.camera_names=[r_vis,top,front] \
    train.device=cuda train.val_split=0.0 train.seed=42 \
    train.batch_size=80 train.lr=5e-4 train.num_workers=12 train.max_steps=150000 \
    train.log_freq=100 train.save_freq=10000 train.use_swanlab=true \
    train.swanlab_project=roboimi-vla \
    train.rollout_val_freq_epochs=5 train.rollout_validate_on_checkpoint=false \
    train.rollout_num_episodes=5 train.warmup_steps=2000 \
    train.scheduler_type=cosine train.min_lr=1e-6 train.weight_decay=1e-5 train.grad_clip=1.0 \
    agent.pred_horizon=16 agent.inference_steps=1 \
    agent.head.n_emb=384 agent.head.n_layer=18 agent.head.n_head=1 agent.head.n_kv_head=1 \
    agent.vision_backbone.pretrained_backbone_weights=null \
    agent.vision_backbone.freeze_backbone=false \
    agent.vision_backbone.use_separate_rgb_encoder_per_camera=true
```
Expected: training launches successfully, SwanLab records the full config, one headless rollout every 5 epochs
- [ ] **Step 4: Record the run path, training PID, and SwanLab run name, and report them to the user**

- [ ] **Step 5: Commit any final wrap-up changes (if the smoke fix required an extra patch)**

```bash
git add <changed files>
git commit -m "chore: verify IMF AttnRes training launch"
```
@@ -0,0 +1,79 @@

# IMF Rollout Trajectory Images and Short-Horizon Training Implementation Plan

> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking.

**Goal:** Add training-time rollout front trajectory image export plus SwanLab image logging, then start a new local IMF training run with `emb=384`, `layer=12`, `pred_horizon=8`, `num_action_steps=4`, `max_steps=50000`.

**Architecture:** Extend `eval_vla.py` so a rollout can emit one per-episode static front-view image with a red EE trajectory overlay. Extend `train_vla.py` so rollout validation forces image export, forces video off, and uploads those per-episode images to SwanLab. Launch the requested new run through explicit command-line overrides rather than branch-default config changes.

**Tech Stack:** Python, PyTorch, Hydra/OmegaConf, MuJoCo, OpenCV, SwanLab.

---
### Task 1: Add and validate rollout image tests

**Files:**
- Modify: `tests/test_eval_vla_rollout_artifacts.py`
- Modify: `tests/test_train_vla_swanlab_logging.py`
- Modify: `tests/test_train_vla_rollout_validation.py`

- [ ] Add/adjust eval tests so they assert per-episode trajectory image paths are produced without requiring video export.
- [ ] Add/adjust training tests so they assert training-time rollout validation forces `record_video=false`.
- [ ] Add/adjust training tests so they assert trajectory image paths flow from the eval summary into SwanLab media logging.
- [ ] Add/adjust training tests so they assert image media is logged, not only scalar reward metrics.
### Task 2: Implement per-episode front trajectory image export in eval

**Files:**
- Modify: `roboimi/demos/vla_scripts/eval_vla.py`
- Reuse/Read: `roboimi/utils/raw_action_trajectory_viewer.py`
- Modify: `roboimi/vla/conf/eval/eval.yaml`

- [ ] Add config plumbing for `save_trajectory_image` and `trajectory_image_camera_name`.
- [ ] Ensure the default training-time camera resolution path is pinned to `front`.
- [ ] Implement distinct per-episode image naming so 5 rollout episodes create 5 distinct PNGs.
- [ ] Reuse the existing red trajectory representation logic when composing the PNG.
- [ ] Ensure headless eval works under EGL even on machines with `DISPLAY` set.
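The overlay step can be sketched in pure NumPy as below. The real implementation presumably uses OpenCV (e.g. `cv2.polylines`) and derives pixel coordinates from recorded EE poses; the function name and the `(N, 2)` pixel-trajectory input are assumptions:

```python
import numpy as np

def draw_trajectory_overlay(frame, traj_px):
    """Sketch: stamp a red EE trajectory onto a static front-view frame.
    `frame` is an (H, W, 3) BGR image; `traj_px` is a list of (x, y) pixel
    coordinates. Out-of-bounds points are skipped."""
    canvas = frame.copy()
    for x, y in np.asarray(traj_px, dtype=int):
        if 0 <= y < canvas.shape[0] and 0 <= x < canvas.shape[1]:
            canvas[y, x] = (0, 0, 255)  # red in BGR
    return canvas
```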
### Task 3: Implement SwanLab rollout image logging in training

**Files:**
- Modify: `roboimi/demos/vla_scripts/train_vla.py`
- Modify: `tests/test_train_vla_swanlab_logging.py`
- Modify: `tests/test_train_vla_rollout_validation.py`

- [ ] Make `run_rollout_validation()` force `record_video=false`.
- [ ] Make `run_rollout_validation()` force `save_trajectory_image=true` and `trajectory_image_camera_name=front`.
- [ ] Ensure rollout validation still uses 5 episodes per validation event for the requested run.
- [ ] Add a best-effort helper that converts per-episode image paths into SwanLab image media payloads.
- [ ] Keep image-upload failures non-fatal and warning-only.
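The best-effort helper might look like the sketch below. The `swanlab.Image`/`swanlab.log` usage is an assumption based on SwanLab's wandb-like API; the module is injected as a parameter here only so the sketch stays testable without SwanLab installed:

```python
import logging

log = logging.getLogger(__name__)

def log_rollout_images(swanlab_mod, image_paths, step):
    """Best-effort: turn per-episode trajectory PNG paths into SwanLab image
    media and log them; any failure is downgraded to a warning."""
    try:
        media = {f"rollout/episode_{i}": swanlab_mod.Image(p)
                 for i, p in enumerate(image_paths)}
        swanlab_mod.log(media, step=step)
        return True
    except Exception as exc:  # non-fatal by design
        log.warning("SwanLab image upload failed: %s", exc)
        return False
```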
### Task 4: Verify action-chunk semantics for the new run

**Files:**
- Verify: `roboimi/vla/agent.py`
- Verify: `roboimi/vla/agent_imf.py`
- Test: `tests/test_imf_vla_agent.py`

- [ ] Confirm the existing queue logic still means "predict 8, execute first 4".
- [ ] Do not change branch defaults unless strictly necessary; prefer launch-time overrides.
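The "predict 8, execute first 4" queue semantics can be sketched as follows; the class and method names are illustrative, not the repo's actual agent API:

```python
from collections import deque

class ActionChunkQueue:
    """Sketch: refill the queue with a fresh chunk only when it runs empty,
    and keep at most num_action_steps of each predicted chunk, so only the
    first few predicted actions are ever executed."""
    def __init__(self, num_action_steps: int):
        self.num_action_steps = num_action_steps
        self._queue = deque()

    def select_action(self, predict_chunk):
        if not self._queue:
            chunk = predict_chunk()  # e.g. pred_horizon = 8 actions
            self._queue.extend(chunk[: self.num_action_steps])
        return self._queue.popleft()
```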
### Task 5: Verify and launch the requested local training run

**Files:**
- Use: `roboimi/demos/vla_scripts/train_vla.py`
- Use: `roboimi/demos/vla_scripts/eval_vla.py`

- [ ] Run the targeted verification suite.
- [ ] Run one real headless smoke eval and confirm a front trajectory PNG is produced while `video_mp4` stays null.
- [ ] Launch the new local training run with explicit overrides including:
  - `agent=resnet_imf_attnres`
  - `agent.head.n_emb=384`
  - `agent.head.n_layer=12`
  - `agent.pred_horizon=8`
  - `agent.num_action_steps=4`
  - `train.max_steps=50000`
  - `train.rollout_num_episodes=5`
  - `train.use_swanlab=true`
  - the current local baseline dataset/camera/CUDA/batch/lr/num_workers/backbone settings
- [ ] Verify the PID, GPU allocation, log tail, and SwanLab run URL.
@@ -0,0 +1,68 @@

# IMF Horizon Grid and AttnRes Ablation Implementation Plan

> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking.

**Goal:** Run a 6-run Phase-1 IMF horizon/action-step experiment grid across available GPUs, monitor progress and collect best rollout metrics, then use the best horizon setting for a Phase-2 visual-attnres ablation.

**Architecture:** Use the current IMF training code as-is for Phase-1 by sweeping explicit `(pred_horizon, num_action_steps)` overrides while keeping emb=384, layer=12, and max_steps=50k fixed. Maintain a local experiment suite directory with a manifest and machine-readable status snapshots so progress can be resumed and summarized across turns. After Phase-1 completes, compare the current head-only attnres setup against a variant that also adds attnres into the visual ResNet path.

**Tech Stack:** Python, Hydra/OmegaConf, PyTorch, SSH/Tailscale, JSON/CSV status files, SwanLab.

---
### Task 1: Prepare the experiment suite manifest and state tracking

**Files:**
- Create: `experiment_suites/2026-04-04-imf-horizon-grid/manifest.json`
- Create: `experiment_suites/2026-04-04-imf-horizon-grid/status.json`
- Create: `experiment_suites/2026-04-04-imf-horizon-grid/notes.md`

- [ ] Define the 6 legal Phase-1 combinations: `(8,8)`, `(16,8)`, `(16,16)`, `(32,8)`, `(32,16)`, `(32,32)`.
- [ ] Record for each run: name, host, GPU slot, command, log path, SwanLab run name, and completion criteria.
- [ ] Define the comparison metric as the maximum rollout average reward seen during training (`max avg_reward`), preferably read from the best-checkpoint metadata and cross-checked against logs.
- [ ] Keep `status.json` updated with per-run state: queued / running / finished / failed, plus the latest parsed progress.
### Task 2: Prepare the remote 8-GPU execution target

**Files:**
- Remote working directory under `/home/droid/`
- Reuse or create a synced code directory for this suite

- [ ] Verify the remote dataset path and environment path.
- [ ] Verify GPU availability and reserve 6 GPUs for the Phase-1 launches.
- [ ] Sync the required code to a dedicated remote suite directory.
- [ ] Record the exact remote paths back into the local suite manifest.
### Task 3: Launch the 6 Phase-1 experiments in parallel

**Files:**
- Reuse: `roboimi/demos/vla_scripts/train_vla.py`
- Modify only local suite tracking files unless a launch bug is discovered

- [ ] Launch 6 runs concurrently with fixed settings: IMF, emb=384, layer=12, max_steps=50k.
- [ ] Keep all other relevant training hyperparameters aligned to the current strong baseline unless a concrete blocker appears.
- [ ] Assign one GPU per run on the 8xL20 host.
- [ ] Capture the PID, log path, and SwanLab URL for each run in `status.json`.
### Task 4: Monitor and summarize Phase-1 until all 6 runs finish

**Files:**
- Update: `experiment_suites/2026-04-04-imf-horizon-grid/status.json`
- Update: `experiment_suites/2026-04-04-imf-horizon-grid/notes.md`

- [ ] Periodically parse each run's logs/checkpoints to extract the latest step, the latest rollout reward, and the best rollout reward so far.
- [ ] Keep a resumable local summary so progress can be continued in later turns without rediscovery.
- [ ] After all 6 runs finish, rank them by `max avg_reward` and write a compact Phase-1 summary.
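A minimal log-parsing sketch for refreshing `status.json`, assuming a hypothetical `step=... avg_reward=...` log-line format (the real runs' rollout lines may be shaped differently):

```python
import re

# Hypothetical log-line format; adjust the pattern to the real training logs.
ROLLOUT_RE = re.compile(r"step=(\d+).*?avg_reward=([0-9.]+)")

def parse_rollout_progress(log_text: str):
    """Extract (latest_step, latest_reward, max_avg_reward) from a training
    log so per-run status can be refreshed without rediscovery."""
    hits = [(int(s), float(r)) for s, r in ROLLOUT_RE.findall(log_text)]
    if not hits:
        return None
    latest_step, latest_reward = hits[-1]
    return {"latest_step": latest_step,
            "latest_reward": latest_reward,
            "max_avg_reward": max(r for _, r in hits)}
```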
### Task 5: Prepare the Phase-2 visual-attnres ablation

**Files:**
- Likely modify: the vision backbone implementation and config files (to be confirmed after code inspection)
- Add/update targeted tests for the visual backbone path if code changes are needed

- [ ] Use the best Phase-1 `(pred_horizon, num_action_steps)` combination as the fixed rollout setting for Phase-2.
- [ ] Compare:
  1. current setup: attnres only in the IMF head
  2. ablation setup: attnres in both the IMF head and the visual encoder path
- [ ] Keep the rest of the training settings fixed.
- [ ] Launch and monitor the Phase-2 pair after the Phase-1 summary is complete.
@@ -0,0 +1,92 @@

# LEWM ViT Backbone Implementation Plan

> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking.

**Goal:** Replace the current ResNet visual encoder in roboimi VLA training with a frozen LEWM ViT visual backbone (encoder + projector) that consumes the three camera views jointly and outputs one 192-d CLS embedding per timestep, then launch two 50k runs on the 5880 machine.

**Architecture:** Add a new joint-multiview LEWM backbone that fuses `front/top/r_vis` into one LEWM-style image, reproduces LEWM preprocessing, loads frozen weights from the trained checkpoint, and exposes a `joint_output_dim=192`. Add a minimal `VLAAgent` compatibility branch so conditions can be sized from the joint visual dim instead of `output_dim * num_cams`, while leaving the rest of the diffusion pipeline unchanged.

**Tech Stack:** PyTorch, transformers `ViTModel`, Hydra configs, existing roboimi VLA training/eval scripts, remote SSH/rsync to 100.73.14.65.

---
### Task 1: Add failing tests for the LEWM joint-vision backbone contract

**Files:**
- Create: `tests/test_lewm_vit_backbone.py`
- Modify: `tests/test_imf_vla_agent.py`

- [ ] **Step 1: Write the failing backbone shape/load test**
- [ ] **Step 2: Run `pytest tests/test_lewm_vit_backbone.py -q` and verify it fails**
- [ ] **Step 3: Extend `tests/test_imf_vla_agent.py` with a failing joint-output backbone case**
- [ ] **Step 4: Run `pytest tests/test_imf_vla_agent.py -q` and verify it fails**
### Task 2: Implement the LEWM joint-multiview frozen backbone

**Files:**
- Create: `roboimi/vla/models/backbones/lewm_vit_backbone.py`
- Modify: `roboimi/vla/models/backbones/__init__.py` only if exports are needed

- [ ] **Step 1: Create `LEWMViTBackbone` with public attrs `camera_names`, `num_cameras`, `joint_output_dim=192`**
- [ ] **Step 2: Reproduce LEWM preprocessing and joint multiview fusion**
- [ ] **Step 3: Load checkpoint weights from `model.encoder.*` and `model.projector.*`**
- [ ] **Step 4: Freeze encoder/projector and keep them in eval mode via a `train()` override**
- [ ] **Step 5: Run `pytest tests/test_lewm_vit_backbone.py -q` and verify green**
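Step 4's freeze-and-pin-eval pattern can be sketched as follows; the tiny `nn.Sequential` encoder is a stand-in for the LEWM ViT encoder + projector:

```python
import torch.nn as nn

class FrozenEncoderSketch(nn.Module):
    """Freeze a submodule and pin it to eval mode even when the parent
    policy calls .train() during training epochs."""
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4))
        for p in self.encoder.parameters():
            p.requires_grad = False
        self.encoder.eval()

    def train(self, mode: bool = True):
        super().train(mode)       # flips every submodule to `mode`...
        self.encoder.eval()       # ...then force the frozen part back to eval
        return self
```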
### Task 3: Add minimal agent support for the joint visual dim

**Files:**
- Modify: `roboimi/vla/agent.py`
- Test: `tests/test_imf_vla_agent.py`

- [ ] **Step 1: Add a `joint_output_dim` branch in `VLAAgent.__init__` for `per_step_cond_dim` / `global_cond_dim`**
- [ ] **Step 2: Keep `_build_cond()` semantics unchanged except for matching the new dim contract**
- [ ] **Step 3: Run `pytest tests/test_imf_vla_agent.py -q` and verify green**
### Task 4: Add Hydra configs for LEWM backbone training

**Files:**
- Create: `roboimi/vla/conf/backbone/lewm_vit_diffusion.yaml`
- Create: `roboimi/vla/conf/agent/lewm_imf_attnres.yaml`

- [ ] **Step 1: Add a backbone config pointing to the new LEWM backbone**
- [ ] **Step 2: Add an `agent=lewm_imf_attnres` config with 3 cameras and `head.cond_dim=208`**
- [ ] **Step 3: Verify Hydra instantiation with a one-shot compose smoke**
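A minimal sketch of what the `agent=lewm_imf_attnres` config might look like; only the `_target_` paths, camera interpolations, and `cond_dim=208` come from this plan and the earlier tasks, and every other field is illustrative:

```yaml
# roboimi/vla/conf/agent/lewm_imf_attnres.yaml (sketch, not the real file)
_target_: roboimi.vla.agent_imf.IMFVLAAgent
camera_names: ${data.camera_names}     # front / top / r_vis
inference_steps: 1
vision_backbone:
  _target_: roboimi.vla.models.backbones.lewm_vit_backbone.LEWMViTBackbone
  camera_names: ${agent.camera_names}
  joint_output_dim: 192
head:
  _target_: roboimi.vla.models.heads.imf_transformer1d.IMFTransformer1D
  cond_dim: 208   # per the plan: joint visual 192-d plus the state dims
```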
### Task 5: Verify focused local tests

**Files:**
- Reuse the above

- [ ] **Step 1: Run `pytest tests/test_lewm_vit_backbone.py tests/test_imf_vla_agent.py tests/test_eval_vla_headless_import.py -q`**
- [ ] **Step 2: If needed, run one tiny local import/forward smoke**
### Task 6: Sync to 5880 and remote smoke with the real checkpoint

**Files:**
- Remote target: `/home/droid/roboimi_suite_20260404`

- [ ] **Step 1: Rsync the modified source/config files to `100.73.14.65:/home/droid/roboimi_suite_20260404`**
- [ ] **Step 2: Run a 2-step smoke on GPU0 with `agent.head.n_emb=384`, `train.rollout_num_episodes=10`, and the real LEWM checkpoint**
- [ ] **Step 3: Run a 2-step smoke on GPU1 with `agent.head.n_emb=256` and the same checkpoint**
### Task 7: Launch two real 50k runs on the 5880 machine

**Files:**
- Remote logs under `/home/droid/roboimi_suite_20260404/experiment_suite_launch_logs/`

- [ ] **Step 1: Launch embed384/layer12 on GPU0**
- [ ] **Step 2: Launch embed256/layer12 on GPU1**
- [ ] **Step 3: Ensure both use `data.camera_names=[r_vis,top,front]`, `pred_horizon=16`, `num_action_steps=8`, `train.rollout_num_episodes=10`, `max_steps=50000`**
- [ ] **Step 4: Record run names, PIDs, log paths, and SwanLab URLs**
### Task 8: Update experiment tracking docs and commit

**Files:**
- Create: `experiment_suites/2026-04-05-lewm-vit-transfer/manifest.json`
- Create: `experiment_suites/2026-04-05-lewm-vit-transfer/status.json`
- Create: `experiment_suites/2026-04-05-lewm-vit-transfer/notes.md`

- [ ] **Step 1: Record the checkpoint path, the frozen LEWM design, rollout=10, and both run configs**
- [ ] **Step 2: Record the running status after launch**
- [ ] **Step 3: Commit the implementation + docs with a focused message**
@@ -0,0 +1,64 @@

# Phase-2 Full-AttnRes Vision Implementation Plan

> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking.

**Goal:** Replace all ResNet residual units in the vision backbone with AttnRes-based image blocks while preserving the current IMF agent interfaces, and launch a Phase-2 experiment anchored on the best Phase-1 horizon setting.

**Architecture:** Keep the current multi-camera encoder shell and per-camera output contract, but introduce a new ResNet-like 2D AttnRes backbone that preserves stage-wise downsampling and final SpatialSoftmax conditioning. Wire it into the existing `ResNetDiffusionBackbone` via an opt-in mode and keep the agent/head/data interfaces unchanged.

**Tech Stack:** PyTorch, Hydra/OmegaConf, existing IMF AttnRes transformer components, pytest.

---
### Task 1: Add failing tests for the new full-AttnRes visual backbone

**Files:**
- Create: `tests/test_attnres_resnet2d_backbone.py`
- Update: `tests/test_imf_vla_agent.py`

- [ ] **Step 1: Write a failing backbone shape test**
- [ ] **Step 2: Run it to confirm the new backbone/config does not exist yet**
- [ ] **Step 3: Add a failing IMF agent wiring test for the unchanged cond_dim=208**
- [ ] **Step 4: Run the targeted tests and capture the failure**
### Task 2: Implement a ResNet-like 2D AttnRes backbone

**Files:**
- Create: `roboimi/vla/models/backbones/attnres_resnet2d.py`
- Modify: `roboimi/vla/models/backbones/resnet_diffusion.py`

- [ ] **Step 1: Add minimal 2D tokenization helpers and positional encoding / bias handling**
- [ ] **Step 2: Implement `AttnResImageBlock2D` for feature maps**
- [ ] **Step 3: Implement `AttnResResNetLikeBackbone2D` with stage-wise downsampling**
- [ ] **Step 4: Wire `_SingleRgbEncoder` to choose between the original ResNet trunk and the new full-AttnRes trunk**
- [ ] **Step 5: Run the new backbone tests**
||||
### Task 3: Expose config switches and agent wiring

**Files:**
- Modify: `roboimi/vla/conf/backbone/resnet_diffusion.yaml`
- Modify: `roboimi/vla/conf/agent/resnet_imf_attnres.yaml`

- [ ] **Step 1: Add a backbone mode/config flag for the full-AttnRes vision trunk**
- [ ] **Step 2: Add defaults for attnres image depth/heads/etc. if needed**
- [ ] **Step 3: Add a Phase-2 launch override path that enables the new visual trunk**
- [ ] **Step 4: Run the agent wiring tests again**
### Task 4: Smoke-verify the training path

**Files:**
- Reuse existing training scripts and configs

- [ ] **Step 1: Run a short CPU or tiny-step smoke instantiation / `compute_loss` test**
- [ ] **Step 2: If needed, run a very short training smoke launch**
- [ ] **Step 3: Verify no cond-dim or rollout-loading regressions**
### Task 5: Launch the Phase-2 experiment

**Files:**
- Update experiment tracking under `experiment_suites/`

- [ ] **Step 1: Use the Phase-1 best setting (`pred_horizon=16`, `num_action_steps=8`)**
- [ ] **Step 2: Launch a baseline reference or reuse the existing result**
- [ ] **Step 3: Launch the full-AttnRes vision experiment**
- [ ] **Step 4: Track rollout metrics and compare max avg_reward**
81
docs/superpowers/plans/2026-04-06-resnet-multitoken-imf.md
Normal file
@@ -0,0 +1,81 @@
# ResNet Multitoken IMF Implementation Plan

> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking.

**Goal:** Implement a standard-ResNet-18 multiview IMF variant that emits three condition tokens per obs step, and launch four L20 experiments for `n_emb in {256,384}` and `n_layer in {12,16}`.

**Architecture:** The ResNet backbone will optionally return one token per camera instead of concatenating all cameras into one token. `VLAAgent` will pair each camera token with the current state, project each pair into a condition token, flatten the per-step camera tokens into one cond sequence, and feed that sequence into the existing IMF/AttnRes head.

**Tech Stack:** PyTorch, torchvision ResNet-18, Hydra, pytest, SwanLab, SSH/Tailscale.

---
### Task 1: Add failing tests for multi-token conditioning

**Files:**
- Modify: `tests/test_imf_vla_agent.py`
- Modify: `tests/test_resnet_transformer_agent_wiring.py`

- [ ] **Step 1: Add a direct agent test**
  - Stub a vision backbone returning `(B,T,3,D)` and assert `_build_cond()` yields `(B, T*3, D_cond)`.
  - Assert state is paired with each camera token, not concatenated across cameras first.
- [ ] **Step 2: Add a Hydra wiring test**
  - Instantiate a new `agent=resnet_imf_attnres_multitoken` config with small dims.
  - Assert `condition_tokens_per_step == 3`, `condition_sequence_length == obs_horizon * 3`, and that the head's `n_obs_steps` receives that sequence length.
- [ ] **Step 3: Run the focused tests and verify RED**
  - `python -m pytest tests/test_imf_vla_agent.py tests/test_resnet_transformer_agent_wiring.py -q`
### Task 2: Implement the multi-token ResNet conditioning path

**Files:**
- Modify: `roboimi/vla/models/backbones/resnet_diffusion.py`
- Modify: `roboimi/vla/agent.py`
- Create: `roboimi/vla/conf/agent/resnet_imf_attnres_multitoken.yaml`

- [ ] **Step 1: Extend the ResNet backbone**
  - Add an opt-in flag to return `(B,T,num_cams,D)` camera tokens instead of one concatenated `(B,T,num_cams*D)` token.
  - Keep the standard ResNet-18 vision mode; do not switch to AttnRes vision.
- [ ] **Step 2: Extend VLAAgent condition building**
  - Support visual features with rank 4 `(B,T,K,D)`.
  - Broadcast state to `(B,T,K,D_state)`, concatenate per camera, apply the projector per token, then flatten to `(B,T*K,D_cond)`.
  - Track `condition_tokens_per_step` and `condition_sequence_length`.
- [ ] **Step 3: Update transformer-head instantiation**
  - Pass `n_obs_steps=condition_sequence_length` when building transformer heads.
- [ ] **Step 4: Add the Hydra config**
  - The new agent config uses:
    - a separate ResNet-18 per camera
    - the standard residual vision trunk (`vision_backbone_mode=resnet`)
    - a condition projector output dim tied to `${agent.head.n_emb}`
    - rollout episodes `10`, `pred_horizon=16`, `num_action_steps=8`
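Step 2's condition building can be sketched as a standalone function; the function name is illustrative and the projector is any per-token module:

```python
import torch
import torch.nn as nn

def build_multitoken_cond(visual, state, projector):
    """Pair each camera token with the state, project each pair into a
    condition token, then flatten cameras into the sequence dimension.
    visual: (B, T, K, D_vis); state: (B, T, D_state) -> (B, T*K, D_cond)."""
    B, T, K, _ = visual.shape
    state_b = state[:, :, None, :].expand(B, T, K, state.shape[-1])
    paired = torch.cat([visual, state_b], dim=-1)  # (B, T, K, D_vis+D_state)
    cond = projector(paired)                       # (B, T, K, D_cond)
    return cond.flatten(1, 2)                      # (B, T*K, D_cond)
```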
### Task 3: Verify locally

**Files:**
- Modify only if verification reveals issues

- [ ] **Step 1: Run the focused tests and make them pass**
  - `python -m pytest tests/test_imf_vla_agent.py tests/test_resnet_transformer_agent_wiring.py -q`
- [ ] **Step 2: Run the regression subset**
  - `python -m pytest tests/test_eval_vla_headless.py tests/test_train_vla_rollout_validation.py tests/test_simple_robot_dataset_image_loading.py -q`
- [ ] **Step 3: Run a local smoke instantiation**
  - Instantiate the new Hydra config and verify the cond shape / sequence length
### Task 4: Launch 4 L20 experiments

**Files:**
- Remote repo copy under `/home/droid/roboimi_suite_20260404`

- [ ] **Step 1: Sync code to `100.119.99.14`**
- [ ] **Step 2: Smoke the new config on the remote host**
- [ ] **Step 3: Launch the runs**
  - `(n_emb=256, n_layer=12)`
  - `(n_emb=256, n_layer=16)`
  - `(n_emb=384, n_layer=12)`
  - `(n_emb=384, n_layer=16)`
- [ ] **Step 4: Keep fixed across runs**
  - rollout episodes `10`
  - `pred_horizon=16`
  - `num_action_steps=8`
  - the standard ResNet-18 vision trunk
  - three separate camera weights
- [ ] **Step 5: Record PIDs, GPUs, log paths, and SwanLab URLs**
78
docs/superpowers/plans/2026-04-06-siglip2-multiview-vla.md
Normal file
@@ -0,0 +1,78 @@
# SigLIP2 Multiview VLA Implementation Plan

> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking.

**Goal:** Integrate a frozen shared SigLIP2 multiview encoder into the IMF/AttnRes policy, preserve raw-256 image handling, and launch two 50k-step experiments on the 5880 host with per-view projection dims 96 and 192.

**Architecture:** A new backbone will independently encode each camera view with SigLIP2 and project each 768-d pooled feature to a configurable per-view dimension. `VLAAgent` will concatenate visual features with the robot state, then optionally project the combined per-step condition to the head's required 384-d interface before diffusion training/inference.

**Tech Stack:** PyTorch, transformers SigLIP2, Hydra, pytest, SSH/Tailscale, SwanLab.

---
### Task 1: Add failing tests for the SigLIP2 backbone and projected conditioning

**Files:**
- Create: `tests/test_siglip2_diffusion_backbone.py`
- Modify: `tests/test_imf_vla_agent.py`

- [ ] **Step 1: Write failing backbone tests**
  - Instantiate the new backbone with a stub SigLIP2 vision model.
  - Assert the raw dataset resize is `None`, the eval resize is `(256, 256)`, and the output shape is `(B, T, 3 * per_view_output_dim)`.
  - Assert the three views are encoded independently and projected.
- [ ] **Step 2: Run the focused tests and verify RED**
  - Run `pytest tests/test_siglip2_diffusion_backbone.py tests/test_imf_vla_agent.py -q`
  - Expect failure because the backbone/config/projector do not exist yet.
- [ ] **Step 3: Extend the agent wiring tests**
  - Add a Hydra/instantiate test for a new SigLIP2 IMF config.
  - Assert the raw condition dim is `3 * per_view_output_dim + obs_dim`, the projected cond dim is `384`, and the head's `cond_dim == 384`.
### Task 2: Implement SigLIP2 backbone and optional condition projector

**Files:**

- Create: `roboimi/vla/models/backbones/siglip2_diffusion_backbone.py`
- Create: `roboimi/vla/conf/backbone/siglip2_diffusion.yaml`
- Create: `roboimi/vla/conf/agent/siglip2_imf_attnres.yaml`
- Create: `roboimi/vla/conf/modules/linear_condition_projector.yaml`
- Modify: `roboimi/vla/models/backbones/__init__.py`
- Modify: `roboimi/vla/agent.py`

- [ ] **Step 1: Implement backbone**
  - Load `SiglipVisionModel.from_pretrained("google/siglip2-base-patch16-256")`.
  - Normalize `[0,1]` pixels with mean/std `0.5` and encode each view independently.
  - Project each 768-d pooled feature to a configurable per-view dim and concatenate across cameras.
- [ ] **Step 2: Implement optional condition projector**
  - Allow `VLAAgent` to accept `cond_projector`.
  - Track `raw_per_step_cond_dim` and projected `per_step_cond_dim` / `global_cond_dim`.
  - Apply the projector in `_build_cond()` after visual+state concatenation.
- [ ] **Step 3: Add Hydra configs**
  - New agent config should default to `n_emb=384`, `n_layer=12`, `pred_horizon=16`, `num_action_steps=8`, `head.cond_dim=384`.
  - Backbone config should set `dataset_image_resize_shape: null` and `eval_image_resize_shape: [256, 256]`.
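The per-view encode → project → concatenate → condition-project wiring above can be sketched as follows. This is a minimal, self-contained illustration: `ToyPerViewEncoder` and `MultiViewCond` are hypothetical names, and the toy encoder stands in for the pretrained `SiglipVisionModel` (which would require a model download) while keeping the same 768-d pooled output and the mean/std `0.5` normalization.

```python
import torch
import torch.nn as nn

class ToyPerViewEncoder(nn.Module):
    """Stand-in for SiglipVisionModel: maps one view to a 768-d pooled feature."""
    def __init__(self, pooled_dim: int = 768):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.proj = nn.Linear(3, pooled_dim)

    def forward(self, x):  # x: (N, 3, H, W) in [0, 1]
        x = (x - 0.5) / 0.5                        # SigLIP-style mean/std 0.5 normalization
        return self.proj(self.pool(x).flatten(1))  # (N, 768)

class MultiViewCond(nn.Module):
    """Encode each view independently, project per-view, concat, then project cond."""
    def __init__(self, per_view_output_dim=96, obs_dim=16, cond_dim=384,
                 views=("r_vis", "top", "front")):
        super().__init__()
        self.views = views
        self.encoder = ToyPerViewEncoder()
        self.view_proj = nn.Linear(768, per_view_output_dim)
        raw = len(views) * per_view_output_dim + obs_dim   # raw_per_step_cond_dim
        self.cond_projector = nn.Linear(raw, cond_dim)     # optional projector to head cond_dim

    def forward(self, images: dict, qpos):  # images[v]: (B, T, 3, H, W); qpos: (B, T, obs_dim)
        B, T = qpos.shape[:2]
        feats = [self.view_proj(self.encoder(images[v].flatten(0, 1))) for v in self.views]
        visual = torch.cat(feats, dim=-1).view(B, T, -1)   # (B, T, 3 * per_view_output_dim)
        return self.cond_projector(torch.cat([visual, qpos], dim=-1))  # (B, T, cond_dim)

imgs = {v: torch.rand(2, 2, 3, 64, 64) for v in ("r_vis", "top", "front")}
cond = MultiViewCond()(imgs, torch.randn(2, 2, 16))
print(tuple(cond.shape))  # (2, 2, 384)
```

With `per_view_output_dim=96` and `obs_dim=16`, the raw condition dim is `3 * 96 + 16 = 304`, projected down to the head's `cond_dim=384` as the Task 2 tests assert.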
### Task 3: Verify locally and prepare remote execution

**Files:**

- Modify as needed only if tests/smoke reveal issues

- [ ] **Step 1: Run focused tests and make them pass**
  - `pytest tests/test_siglip2_diffusion_backbone.py tests/test_imf_vla_agent.py tests/test_eval_vla_headless.py tests/test_train_vla_rollout_validation.py tests/test_simple_robot_dataset_image_loading.py -q`
- [ ] **Step 2: Run a local smoke instantiation**
  - Instantiate the new Hydra config with stubbed optional modules or offline-safe monkeypatching.
- [ ] **Step 3: Review diffs for unintended LEWM/raw256 regressions**

### Task 4: Sync to 5880 and launch experiments

**Files:**

- Remote repo copy under `/home/droid/roboimi_suite_20260404`

- [ ] **Step 1: Stop superseded remote jobs**
- [ ] **Step 2: Sync updated code to remote**
  - Prefer `rsync` or `git push/pull` without overwriting unrelated files.
- [ ] **Step 3: Remote smoke test**
  - Confirm SigLIP2 model download/import works in `/home/droid/miniforge3/envs/roboimi/bin/python`.
  - Confirm headless rollout path still uses `256x256` eval resize.
- [ ] **Step 4: Launch experiment A**
  - `per_view_output_dim=96`, `embed=384`, `layer=12`, `pred=16`, `exec=8`, `steps=50000`.
- [ ] **Step 5: Launch experiment B**
  - `per_view_output_dim=192`, same other hyperparameters.
- [ ] **Step 6: Record PIDs, GPUs, log paths, and SwanLab run URLs.**

# VLA Training + Headless Rollout + SwanLab Design

**Date:** 2026-03-30
**Branch:** feat-align-dp-transformer-ee

## Goal

Fill in the missing training dependencies for the default `resnet_transformer` / `Transformer1D` path in the current repository and start training on the dataset at `/home/droid/project/diana_sim/sim_transfer`. At the same time, support uploading scalar training logs to SwanLab, and provide a headless mode for later rollout validation so that no MuJoCo / OpenCV GUI windows pop up.

## Non-Goals

- Do not rewrite the whole training framework
- Do not introduce a new workspace / callback framework
- Do not implement complex video/media log uploads in this round
- Do not change the dataset format itself

## Current State

- The default training config has switched to `agent=resnet_transformer`, with a `Transformer1D` head
- The current environment is missing several Python dependencies needed for training: `diffusers`, `torchvision`, `einops`, `swanlab`
- The eval environment `make_sim_env(task_name)` currently hardcodes `is_render=True`
- The camera thread `camera_viewer()` calls `cv2.namedWindow/imshow` by default, so a window pops up even when we only want the images
- The training script supports train/val loss and checkpoints but has no SwanLab integration
- The dataset directory `/home/droid/project/diana_sim/sim_transfer` already contains 100 episodes but no `dataset_stats.pkl` yet

## User Requirements

1. Install the missing training dependencies in the existing mamba environment
2. Start training on `/home/droid/project/diana_sim/sim_transfer`
3. If rollout validation runs during training, it must support headless mode with no GUI
4. Upload training metrics to SwanLab
5. The default SwanLab project name is `roboimi-vla`

## Proposed Approach

Take a "minimal necessary changes" approach:

### 1. Dependency Layer

Install the missing training dependencies into the existing `roboimi` environment, keeping the current environment name and script entry points unchanged.

#### Install Plan

- Environment: keep using the existing mamba environment `roboimi`
- Installation method:
  - Prefer `python -m pip install` from the current env
- Packages to install:
  - `diffusers`
  - `torchvision`
  - `einops`
  - `swanlab`
- Version policy:
  - Prefer the latest installable versions compatible with the current `torch==2.4.0`
  - If compatibility problems appear, fall back to stable versions aligned with `torch 2.4`
- Reproducibility policy:
  - Write the **actually installed resolved versions** back into the repository's environment definition file to avoid future environment drift

Verify the following imports before training:

- `torch`
- `hydra`
- `omegaconf`
- `diffusers`
- `torchvision`
- `einops`
- `swanlab`
- `cv2`
- `h5py`
- `mujoco`

### 2. Dataset Preparation

Reuse the existing `SimpleRobotDataset` as-is and only point `data.dataset_dir` at:

- `/home/droid/project/diana_sim/sim_transfer`

Before training, generate the stats file with the existing statistics script:

- `/home/droid/project/diana_sim/sim_transfer/dataset_stats.pkl`

The stats generation command should:

- Run from the repository root
- Write stats directly for `/home/droid/project/diana_sim/sim_transfer`
- Leave the training script independent of the default data directory

### 3. SwanLab Logging

Add a lightweight logging integration layer to the training script:

- Whether SwanLab is enabled is decided by config; enabled by default
- Default project: `roboimi-vla`
- The API key is never written into the repository or config files; it comes only from local login state or environment variables
- When `train.use_swanlab=true`:
  - If `swanlab` cannot be imported, training fails fast
  - If not logged in or authentication fails, training fails fast
- At every training log point, upload:
  - `train/loss`
  - `train/lr`
  - `train/best_loss`
  - `train/step`
- At every validation, upload:
  - `val/loss`
- At the end of training, record the final checkpoint path and the best checkpoint path

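The fail-fast semantics above can be sketched as a small helper. This is a hedged illustration, not the actual `train_vla.py` code: `init_swanlab` is a hypothetical function name, and it assumes `swanlab.init(project=..., config=...)` is the entry point (mirroring the common wandb-style API).

```python
def init_swanlab(use_swanlab: bool, project: str = "roboimi-vla", config=None):
    """Fail-fast SwanLab init: disabled -> None; enabled but unusable -> RuntimeError."""
    if not use_swanlab:
        return None  # logging disabled: training proceeds with no SwanLab run
    try:
        import swanlab
    except ImportError as exc:
        raise RuntimeError("train.use_swanlab=true but swanlab is not importable") from exc
    try:
        # assumed entry point; authentication failures surface here
        return swanlab.init(project=project, config=config or {})
    except Exception as exc:
        raise RuntimeError("swanlab.init failed (not logged in?)") from exc

print(init_swanlab(False))  # disabled path: returns None without importing swanlab
```

The scalar uploads (`train/loss`, `val/loss`, ...) would then go through the returned run object at each log point.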
### 4. Headless Rollout Design

The goal is rollout validation that can "read image observations without opening any windows".

Minimal-change strategy:

- Add a `headless` / `is_render` parameter to `make_sim_env(...)`
- Add a switch to the camera thread's display logic:
  - In headless mode, keep updating the `r_vis/top/front/...` image caches
  - But do not call `cv2.namedWindow` / `cv2.imshow` / `cv2.waitKey`
- In the eval script:
  - Do not call `env.render()` in headless mode
  - Still allow `env._get_image_obs()` and policy inference to run normally

#### Training-Time Rollout Scope

- This round **will provide an optional checkpoint-time rollout validation path**, disabled by default
- When enabled, saving a checkpoint during training can invoke the repo's rollout/eval logic to validate on a small number of episodes
- This path must honor the **single authoritative switch** `eval.headless=true`, meaning:
  - No MuJoCo viewer pops up
  - No `cv2.namedWindow / cv2.imshow / cv2.waitKey` calls
  - Images can still be read and policy inference still completes
- By default, no frequent rollouts are added, to avoid slowing down training; only the capability and the config switch are provided

If validation shows the camera thread has a hard GUI dependency, the fallback strategy is:

- The main training flow + SwanLab must work first
- Rollout validation stays an explicitly optional capability
- But this round must still ship at least one callable headless validation path, not just documentation

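The "update the cache, skip the GUI" switch can be sketched as below. `camera_viewer_step` is a hypothetical helper illustrating the intended decoupling inside `camera_viewer()`; the lazy `cv2` import makes the headless path runnable even without a GUI-capable OpenCV build.

```python
import numpy as np

def camera_viewer_step(frame_cache: dict, name: str, frame, headless: bool) -> None:
    """Update the shared image cache; touch GUI APIs only when not headless."""
    frame_cache[name] = frame  # rollout code can still read images from the cache
    if headless:
        return                 # no cv2.namedWindow / cv2.imshow / cv2.waitKey
    import cv2                 # imported lazily so headless runs need no GUI stack
    cv2.imshow(name, frame)
    cv2.waitKey(1)

cache = {}
camera_viewer_step(cache, "front", np.zeros((8, 8, 3), dtype=np.uint8), headless=True)
print(sorted(cache))  # ['front'] -- image cached, no window opened
```

`env._get_image_obs()` would read from `frame_cache` either way, so policy inference is unaffected by the switch.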
### 5. Training Execution Strategy

Execute in two steps:

#### Step A: Smoke Run

Launch a smoke training run with a small step count to confirm:

- The dataset loads correctly
- The stats file loads
- The model instantiates
- A single forward/backward step works
- Checkpoints are written correctly
- SwanLab scalar uploads succeed

#### Step B: Real Training Run

Start the real training run only after the smoke run succeeds.

## Execution Commands

### A. Stats Generation

Run from the repository root to generate:

- `/home/droid/project/diana_sim/sim_transfer/dataset_stats.pkl`

Command template:

```bash
/home/droid/.conda/envs/roboimi/bin/python roboimi/vla/scripts/calculate_stats.py \
    --dataset_dir /home/droid/project/diana_sim/sim_transfer
```

### B. Smoke Training Command

Run from the repository root; the key overrides are:

- `data.dataset_dir=/home/droid/project/diana_sim/sim_transfer`
- A small `train.max_steps`
- A high logging frequency
- SwanLab enabled
- Checkpoints written to `checkpoints/` under the current run directory

Command template:

```bash
/home/droid/.conda/envs/roboimi/bin/python roboimi/demos/vla_scripts/train_vla.py \
    data.dataset_dir=/home/droid/project/diana_sim/sim_transfer \
    train.max_steps=20 \
    train.log_freq=1 \
    train.save_freq=10 \
    train.use_swanlab=true \
    train.swanlab_project=roboimi-vla \
    train.rollout_validate_on_checkpoint=false
```

### C. Real Training Command

Run from the repository root; the key overrides are:

- `data.dataset_dir=/home/droid/project/diana_sim/sim_transfer`
- The real `train.max_steps`
- The default project `roboimi-vla`
- If rollout validation is enabled, pass `eval.headless=true` plus the training-side rollout switch

Command template:

```bash
/home/droid/.conda/envs/roboimi/bin/python roboimi/demos/vla_scripts/train_vla.py \
    data.dataset_dir=/home/droid/project/diana_sim/sim_transfer \
    train.use_swanlab=true \
    train.swanlab_project=roboimi-vla \
    train.rollout_validate_on_checkpoint=true \
    eval.headless=true
```

### D. Output Behavior

- Checkpoint output directory: `checkpoints/` under the current working directory
- Key files:
  - `checkpoints/vla_model_step_<N>.pt`
  - `checkpoints/vla_model_best.pt`
  - `checkpoints/vla_model_final.pt`

## File-Level Changes

- `environment.yml`
  - Record the new training dependencies for reproducibility
- `roboimi/demos/vla_scripts/train_vla.py`
  - Add SwanLab integration
  - Add clearer dataset-directory override support
  - Add an optional checkpoint-time rollout validation entry point
  - Keep the current optimizer alignment logic unchanged
- `roboimi/vla/conf/config.yaml`
  - Add/extend config entries for training logs, SwanLab, and rollouts
- `roboimi/vla/conf/eval/eval.yaml`
  - Add eval controls such as `headless`
- `roboimi/envs/double_pos_ctrl_env.py`
  - `make_sim_env` supports headless / no-render
- `roboimi/envs/double_base.py`
  - Decouple camera capture from GUI display
- `roboimi/vla/scripts/calculate_stats.py`
  - Accept an external `dataset_dir` directly from the command line
- tests (new)
  - Cover the optional SwanLab initialization path
  - Cover the headless "no windows, but images still readable" logic

## Validation Plan

1. After installing dependencies, verify all imports pass
2. Generate `dataset_stats.pkl`
3. Run the training smoke run
4. Confirm the SwanLab dashboard shows scalar updates under project `roboimi-vla`
5. If rollout validation is enabled: confirm no GUI appears in headless mode and the rollout path actually executes
6. Start the real training run

## Config Contract

The config keys added/fixed in this round are:

- `train.use_swanlab: true|false`
- `train.swanlab_project: roboimi-vla`
- `train.rollout_validate_on_checkpoint: true|false`
- `eval.headless: true|false`

## Risks and Mitigations

- **Risk:** GUI/camera threads are coupled to off-screen rendering
  - **Mitigation:** decouple display from image updates first; if necessary, defer rollout validation to a second phase
- **Risk:** Existing env dependencies are incomplete
  - **Mitigation:** verify imports first, then run the smoke run
- **Risk:** The dataset is large enough to make even the smoke run slow
  - **Mitigation:** the smoke run uses a tiny step count
- **Risk:** SwanLab API key leakage
  - **Mitigation:** never write it into code/config; keep it only in local login state or environment variables

## Success Criteria

- The training script starts on `/home/droid/project/diana_sim/sim_transfer`
- Checkpoints are written to `checkpoints/` successfully
- SwanLab shows train/val scalars under the `roboimi-vla` project
- Headless rollout has an execution path that opens no GUI
- If training-side rollout validation is enabled, that path can actually be invoked in headless mode

# Rollout Artifacts Design

**Goal:** Add a one-off evaluation path that can record rollout video, export per-step timing breakdowns, and save executed end-effector trajectories for a selected checkpoint while preserving default eval behavior when artifact capture is disabled.

**Approach:** Extend `roboimi/demos/vla_scripts/eval_vla.py` with optional evaluation-time artifact capture that stays backward compatible when disabled. Reuse existing environment observation and camera streams, record one camera stream to MP4, collect per-step timing around observation read / preprocessing / model inference / env step / total loop, and save per-step raw predicted EE actions plus executed EE poses after stepping.

**Artifact contract:**

- `video.mp4`: optional MP4 encoded from a selected camera stream (`r_vis`, `top`, `front`, etc.), written only when recording is enabled.
- `trajectory.npz`: canonical trajectory export containing at minimum `step`, `reward`, `raw_action`, `executed_left_link7_pos`, `executed_left_link7_quat`, `executed_right_link7_pos`, `executed_right_link7_quat`, and optional duplicated tool-body poses if captured.
- `timing.json`: JSON-serializable per-episode timing summary with millisecond units for `obs_read_ms`, `preprocess_ms`, `inference_ms`, `env_step_ms`, `loop_total_ms`, plus aggregate mean/std/min/max and counts. Raw per-step timing arrays should also be persisted in the NPZ for later analysis.

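The `timing.json` aggregation can be sketched with the standard library alone; `summarize_timing` is a hypothetical helper name, and the exact key set comes from the contract above.

```python
import json
import statistics

def summarize_timing(per_step_ms: dict) -> str:
    """Aggregate per-step millisecond samples into the timing.json payload."""
    summary = {}
    for key, samples in per_step_ms.items():
        summary[key] = {
            "count": len(samples),
            "mean": statistics.fmean(samples),
            "std": statistics.pstdev(samples),
            "min": min(samples),
            "max": max(samples),
        }
    return json.dumps(summary, indent=2)

timing = {"obs_read_ms": [1.0, 2.0], "inference_ms": [30.0, 34.0]}
print(summarize_timing(timing))
```

The raw `per_step_ms` arrays would additionally be written into the NPZ alongside the trajectory fields, per the contract.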
**Checkpoint selection:** Prefer an explicitly requested checkpoint path. If the caller asks for "latest" or omits a path in the execution helper, select the newest fully written checkpoint file by mtime/name and fail clearly if none exists.

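The "newest by mtime/name, fail clearly if none" selection can be sketched as below; `pick_latest_checkpoint` is a hypothetical helper, and the glob pattern assumes the `vla_model_step_<N>.pt` naming used elsewhere in this repo.

```python
import tempfile
import time
from pathlib import Path

def pick_latest_checkpoint(ckpt_dir: Path) -> Path:
    """Pick the newest step checkpoint by (mtime, name); fail clearly when none exists."""
    candidates = sorted(
        ckpt_dir.glob("vla_model_step_*.pt"),
        key=lambda p: (p.stat().st_mtime, p.name),
    )
    if not candidates:
        raise FileNotFoundError(f"no checkpoints under {ckpt_dir}")
    return candidates[-1]

# demo on a temporary directory
d = Path(tempfile.mkdtemp())
for step in (10, 20):
    (d / f"vla_model_step_{step}.pt").write_bytes(b"")
    time.sleep(0.01)  # ensure distinct mtimes
print(pick_latest_checkpoint(d).name)  # vla_model_step_20.pt
```

Guarding against a mid-write file (the "fully written" requirement) would additionally check file size stability or a sidecar marker, which this sketch omits.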
**Stop-training / execution safety:** Before rollout, stop any active training process using the target run, wait for process exit, then verify the chosen checkpoint exists and is readable. If the most recent checkpoint is missing or mid-write, fall back to the previous completed checkpoint or `vla_model_best.pt` with the decision logged.

**Backward compatibility:** With all new eval flags left at default values, `_run_eval` return shape must remain compatible with existing callers, training-time rollout validation should continue to work without passing new options, and no artifact files should be written.

272 docs/superpowers/specs/2026-04-01-imf-attnres-policy-design.md Normal file

# IMF-AttnRes Policy Migration Design

**Date:** 2026-04-01
**Status:** Approved in chat, written spec pending review

## Goal

Migrate the IMF-AttnRes diffusion policy from commit `185ed659` of `/home/droid/project/diffusion_policy` into the current `roboimi` repository as an alternative training option to the current DiT / Transformer diffusion policy. Also migrate its training objective and one-step inference mechanism, while keeping RoboIMI's existing simulation environment, three-camera visual input, dataset format, training scripts, and rollout validation workflow usable.

## Non-Goals

- Do not migrate obs encoders, datasets, env wrappers, or PushT-specific logic from the external repo that is irrelevant to the current task.
- Do not replicate the external repo's full directory structure; migrate only the model, loss, and inference semantics needed for current RoboIMI training.
- Do not keep the old DiT as the default training target in this work; old configs remain usable, but the new model gets its own config entry.

## User-Confirmed Requirements

1. The migration target is the **IMF-AttnRes model code** from `185ed659`.
2. Migrate more than the skeleton; also migrate:
   - the **training objective**
   - the **one-step inference mechanism**
3. Visual input matches the current RoboIMI diffusion policy:
   - Use three camera images as conditioning input
   - Image observations must be conditions, not concatenated into the prediction target
4. In the current task, the IMF policy replaces the existing DiT/Transformer diffusion policy for training.
5. Training hyperparameters roughly follow the most recent run (explicitly overridden later via the training command), but inference switches to IMF's one-step mechanism.
6. The user accepts the IMF implementation constraint of "full / non-causal attention".

## External Source of Truth

Migration semantics follow these files in the external repo:

- `diffusion_policy/model/diffusion/attnres_transformer_components.py`
- `diffusion_policy/model/diffusion/imf_transformer_for_diffusion.py`
- `diffusion_policy/policy/imf_transformer_hybrid_image_policy.py`
- Reference config: `image_pusht_diffusion_policy_dit_imf_attnres_full.yaml`

The most important difference: this policy is not multi-step DDPM/DDIM denoising, but an IMF training objective + one-step inference.

## Current RoboIMI Baseline

The RoboIMI baselines directly relevant to this task:

- Visual encoding: `ResNetDiffusionBackbone`
  - Three cameras: `r_vis`, `top`, `front`
  - At each timestep, camera features are concatenated with `qpos` to form the per-step condition
- Policy body: `VLAAgent`
  - `compute_loss()` uses the DDPM noise-prediction loss
  - `predict_action()` uses multi-step DDIM sampling
  - Online control triggers predictions per chunk via the action-queue mechanism in `select_action()`
- Training script: `roboimi/demos/vla_scripts/train_vla.py`
  - Supports GPU training, SwanLab logging, and headless rollout validation

So the core of this migration is not swapping the visual backbone, but replacing the **head + loss + inference semantics**.

## Recommended Integration Approach

Use a **minimally invasive integration**:

1. **Keep RoboIMI's current visual encoding, data loading, rollout/eval, and training-script framework.**
2. **Add an IMF-specific head module**, implemented locally in RoboIMI:
   - the AttnRes components
   - the IMF transformer body
3. **Add an IMF-specific agent** that reuses the current `VLAAgent`'s:
   - normalization logic
   - camera-order management
   - observation cache / action-chunk cache
   - rollout interface
   but overrides:
   - `compute_loss()`
   - `predict_action()`
4. **Add an independent Hydra config** so the IMF policy becomes a new agent option without breaking the existing resnet_transformer / gr00t_dit configs.

Reasons for this approach:

- Migrating IMF semantics does not disturb the current DDPM agent;
- rollout / eval / checkpoint logic stays reusable;
- A/B comparison training against the existing Transformer / DiT is straightforward.

## Architecture

### 1. Observation / Conditioning Path

Keep RoboIMI's current visual path:

- Input observations: `images={r_vis, top, front}` + `qpos`
- `ResNetDiffusionBackbone` encodes each camera into a per-camera feature
- `state_encoder` encodes `qpos`
- Concatenate the three camera features and the state feature per timestep into `per_step_cond`

We do not migrate the external repo's obs_encoder implementation; we only align with the semantics of **"images enter the transformer as condition tokens"**.

### 2. Condition Tokenization

Align with the external IMF transformer's token usage:

- Action trajectory tokens: `(B, pred_horizon, action_dim)` mapped to `n_emb` by a linear layer
- Time tokens: two scalars `r` and `t`, each turned into a token via sinusoidal embedding + linear projection
- Observation tokens: `per_step_cond` mapped to `n_emb` by a linear layer
- The final token sequence is:
  - `[r_token, t_token, obs_cond_tokens..., action_tokens...]`

In the current task, the number of obs tokens equals `obs_horizon`, and image observations are always conditioning input.

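The token assembly described above can be sketched shape-by-shape. This is an illustrative stand-in, not the migrated head: the projection layers are plain `nn.Linear` placeholders, and `sinusoidal_embed` is a generic sinusoidal embedding assumed to match the external repo's time embedding only in spirit.

```python
import math
import torch
import torch.nn as nn

def sinusoidal_embed(x: torch.Tensor, dim: int) -> torch.Tensor:
    """Generic sinusoidal embedding of one scalar per batch element: (B,) -> (B, dim)."""
    half = dim // 2
    freqs = torch.exp(-math.log(10000.0) * torch.arange(half, dtype=x.dtype) / half)
    ang = x[:, None] * freqs[None, :]
    return torch.cat([torch.sin(ang), torch.cos(ang)], dim=-1)

B, pred_horizon, action_dim, obs_horizon, cond_dim, n_emb = 2, 16, 16, 2, 208, 384
action_proj = nn.Linear(action_dim, n_emb)   # action trajectory -> tokens
cond_proj = nn.Linear(cond_dim, n_emb)       # per_step_cond -> obs tokens
time_proj = nn.Linear(n_emb, n_emb)          # sinusoidal time embedding -> token

actions = torch.randn(B, pred_horizon, action_dim)
per_step_cond = torch.randn(B, obs_horizon, cond_dim)
r, t = torch.zeros(B), torch.ones(B)

r_tok = time_proj(sinusoidal_embed(r, n_emb))[:, None]  # (B, 1, n_emb)
t_tok = time_proj(sinusoidal_embed(t, n_emb))[:, None]  # (B, 1, n_emb)
obs_tok = cond_proj(per_step_cond)                      # (B, obs_horizon, n_emb)
act_tok = action_proj(actions)                          # (B, pred_horizon, n_emb)
tokens = torch.cat([r_tok, t_tok, obs_tok, act_tok], dim=1)
print(tuple(tokens.shape))  # (2, 20, 384): 1 + 1 + obs_horizon + pred_horizon tokens
```

The backbone attends over this full sequence (non-causal) and slices only the trailing `pred_horizon` action tokens back out for the head.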
### 3. IMF-AttnRes Backbone

Add an AttnRes backbone implementation inside RoboIMI, preserving the key semantics of the external commit:

- `RMSNorm` / `RMSNormNoWeight`
- RoPE
- Grouped Query Self-Attention
- SwiGLU FFN
- AttnRes operator / residual source aggregation
- `AttnResTransformerBackbone`

And keep:

- **full attention** (no causal attention)
- `backbone_type='attnres_full'`
- The output slices back to the action-token portion only, then passes through the final norm + head to produce a velocity-like output

### 4. Training Objective

The training objective changes from the current DDPM epsilon prediction to the external IMF objective.

Given a ground-truth trajectory `x` and random noise `e`:

1. Sample `t ~ U(0,1)`, `r ~ U(0,1)`, and sort so that `t >= r`
2. Build the interpolated state:
   - `z_t = (1 - t) x + t e`
3. Compute with the model:
   - `v = f(z_t, t, t, cond)`
4. Take the JVP of `g(z, r, t) = f(z, r, t, cond)` to obtain:
   - `u, du_dt`
5. Build the compound velocity:
   - `V = u + (t - r) * du_dt`
6. The target is:
   - `target = e - x`
7. The final loss is the MSE over the action dimensions

Keep supporting `action_is_pad` in the existing RoboIMI batches; when padding exists, compute the loss only on valid actions.

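The loss recipe above can be sketched end to end with `torch.func.jvp`. This is a toy, literal reading of the listed steps on a stand-in MLP (`ToyVelocityNet` is hypothetical): the real migration must take the tangent construction, kernel switching, and any stop-gradient details from the external source rather than from this sketch.

```python
import torch
import torch.nn as nn

class ToyVelocityNet(nn.Module):
    """Stand-in f(z, r, t, cond): a tiny MLP instead of the AttnRes transformer."""
    def __init__(self, action_dim=4, cond_dim=8):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(action_dim + 2 + cond_dim, 64), nn.SiLU(),
            nn.Linear(64, action_dim),
        )

    def forward(self, z, r, t, cond):  # z: (B, H, A); r, t: (B,); cond: (B, cond_dim)
        B, H, _ = z.shape
        rt = torch.stack([r, t], dim=-1)[:, None, :].expand(B, H, 2)
        c = cond[:, None, :].expand(B, H, -1)
        return self.net(torch.cat([z, rt, c], dim=-1))

def imf_loss(f, x, cond):
    B = x.shape[0]
    e = torch.randn_like(x)
    rt = torch.sort(torch.rand(B, 2), dim=1).values
    r, t = rt[:, 0], rt[:, 1]                            # sorted so t >= r
    z_t = (1 - t)[:, None, None] * x + t[:, None, None] * e
    # JVP of g(t') = f(z_t, r, t', cond) along dt' = 1 gives u and du/dt in one pass
    u, du_dt = torch.func.jvp(lambda tt: f(z_t, r, tt, cond),
                              (t,), (torch.ones_like(t),))
    V = u + (t - r)[:, None, None] * du_dt               # compound velocity
    return ((V - (e - x)) ** 2).mean()                   # MSE against target e - x

f = ToyVelocityNet()
loss = imf_loss(f, torch.randn(2, 16, 4), torch.randn(2, 8))
print(float(loss) >= 0.0)  # True
```

With `action_is_pad`, the final mean would be taken only over valid action positions instead of the whole tensor.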
### 5. One-Step Inference

Inference switches to the external IMF one-step sampling semantics:

1. Initialize the action trajectory `z_t` from a standard Gaussian
2. Compute `u = f(z_t, r=0, t=1, cond)`
3. One-step update:
   - `x_hat = z_t - (t-r) * u = z_t - u`
4. Unnormalize to get the action sequence

This means:

- `num_inference_steps` is fixed to `1` for the IMF policy
- The DDIM scheduler's multi-step `step()` is no longer called
- Online control keeps the current chunk mechanism:
  - When the action queue is empty, trigger one `predict_action_chunk()`
  - Enqueue the slice `[obs_horizon-1 : obs_horizon-1+num_action_steps]` of the predicted sequence

In other words, **the rule for triggering a model forward pass is unchanged; what changes is how the action sequence is generated per trigger**.

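The one-step sampler is short enough to sketch in full; `one_step_sample` is a hypothetical helper, and the toy `f` here is a stand-in for the trained IMF head so the sketch is self-contained.

```python
import torch

@torch.no_grad()
def one_step_sample(f, cond, pred_horizon=16, action_dim=4):
    """One-step IMF sampling: x_hat = z - (t - r) * u with r=0, t=1."""
    B = cond.shape[0]
    z = torch.randn(B, pred_horizon, action_dim)  # Gaussian init of the trajectory
    r, t = torch.zeros(B), torch.ones(B)
    u = f(z, r, t, cond)                          # single network forward pass
    return z - (t - r)[:, None, None] * u         # == z - u; unnormalize downstream

# toy f: ignores r/t/cond and predicts zeros, so x_hat == z here
f = lambda z, r, t, cond: torch.zeros_like(z)
x_hat = one_step_sample(f, cond=torch.randn(2, 8))
print(tuple(x_hat.shape))  # (2, 16, 4)
```

`predict_action()` would wrap exactly this, while `select_action()` / `predict_action_chunk()` keep their existing queue semantics unchanged.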
## API / Code Structure

The planned code boundaries:

- `roboimi/vla/models/heads/attnres_transformer_components.py`
  - IMF AttnRes base components
- `roboimi/vla/models/heads/imf_transformer1d.py`
  - the RoboIMI version of the IMF transformer head
  - exposes `forward(sample, r, t, cond=None)`
  - exposes `get_optim_groups()` for AdamW parameter grouping
- `roboimi/vla/agent_imf.py`
  - reuses `VLAAgent`'s observation handling / normalization / queue infrastructure
  - overrides IMF's training loss and one-step prediction logic
- Hydra config
  - `roboimi/vla/conf/head/imf_transformer1d.yaml`
  - `roboimi/vla/conf/agent/resnet_imf_attnres.yaml`

The training-script main flow should change as little as possible; it only needs to instantiate the new agent and keep using the current rollout / checkpoint / swanlab logic.

## Compatibility Decisions

### Initial Config Defaults To Preserve

To avoid semantic drift during migration, the first IMF config explicitly fixes these defaults:

- `backbone_type: attnres_full`
- `n_head: 1`
- `n_kv_head: 1`
- `n_cond_layers: 0`
- `time_as_cond: true`
- `causal_attn: false`
- `num_inference_steps: 1`

These defaults match how the external `185ed659` IMF-AttnRes is used; later tuning may override them, but the first migration must run with exactly these semantics.

### Reuse From RoboIMI

Keep:

- the three-camera data loading path
- the ResNet visual backbone
- qpos / action normalization
- the training loop, optimizer, scheduler, SwanLab, headless rollout
- the online chunk execution in `select_action()`

### Replace With External IMF Semantics

Replace:

- the transformer head implementation
- the diffusion training objective
- the inference sampling semantics

### Intentionally Not Mirrored 1:1

Parts deliberately not kept identical to the external repo:

- the external repo's overall policy base-class hierarchy
- the external repo's obs encoder module tree
- the external repo's normalizer / mask generator framework

The reason: RoboIMI already has stable data interfaces and a rollout flow, and grafting into them directly is more robust.

## Testing / Verification Strategy

After the migration, verify at least the following:

1. **Unit / smoke checks**
   - The IMF head forward pass has the correct shapes
   - The IMF agent's `compute_loss()` runs forward and backward on a real batch
   - The IMF agent's `predict_action()` outputs `(B, pred_horizon, action_dim)`
2. **Training pipeline checks**
   - Run a short GPU training job and confirm:
     - the dataloader works
     - the optimizer / lr scheduler work
     - SwanLab records the config and training metrics
3. **Rollout checks**
   - Periodic headless rollouts during training run to completion
   - The environment still receives actions via the EE-style `step()`
4. **Final delivery**
   - Launch the real training run with the user-specified hyperparameters of the same kind

## Risks and Mitigations

### Risk 1: JVP is unstable on CUDA attention kernels

Mitigation: follow the external repo's strategy of switching to the math SDP kernel on the JVP path, falling back to `torch.autograd.functional.jvp` if necessary. The tangent construction and the `u, du_dt` computation must strictly match the external source; do not rewrite their math semantics in this migration.

### Risk 2: Optimizer parameter groups miss the new modules

Mitigation: the IMF head provides `get_optim_groups()`, and the training script uses it whenever the head exposes that interface, instead of binding to the old `head_type`.

### Risk 3: Existing rollout logic assumes multi-step DDIM sampling

Mitigation: keep the `select_action()` / `predict_action_chunk()` interfaces unchanged and replace only the inside of `predict_action()`, so eval code never needs to understand IMF details.

### Risk 4: Training-command parameters diverge from the new config

Mitigation: add an independent agent config and keep the previous training parameters as an explicit CLI override template.

## Success Criteria

The migration is successful when all of the following hold:

1. RoboIMI has a new IMF-AttnRes policy that can be enabled independently via Hydra config.
2. Training uses the external IMF loss, not the current DDPM epsilon loss.
3. Inference uses one-step IMF sampling, not multi-step DDIM sampling.
4. The three camera images always enter the model forward pass as conditions.
5. Online rollout runs in the headless simulation environment.
6. Training can be launched with the most recent experiment's parameter template.

# IMF Rollout Trajectory Images + Short-Horizon Training Design

## Background

The current RoboIMI IMF training flow can perform rollout validation and log scalar reward metrics to SwanLab, but it does not yet emit the qualitative rollout artifacts now required for analysis. The user wants training-time rollout validation to save front-view trajectory images with the model-generated trajectory drawn in red, upload those images to SwanLab, and then start a new local short-horizon IMF training run.

## Goals

1. During training-time rollout validation, save one **front-camera** trajectory image per rollout episode.
2. The image must show the rollout EE trajectory in red.
3. Reuse the existing repository trajectory visualization logic as much as practical, especially the existing red capsule-marker trajectory representation.
4. Save 5 rollout images locally for each validation event and upload the same 5 images to SwanLab.
5. Do **not** record rollout videos for this training-time validation flow.
6. Start a new local IMF-AttnRes training run with:
   - `agent.head.n_emb=384`
   - `agent.head.n_layer=12`
   - `agent.pred_horizon=8`
   - `agent.num_action_steps=4`
   - `train.max_steps=50000`
   - `train.rollout_num_episodes=5`
   - `train.use_swanlab=true`

## Non-Goals

- No IMF architecture or loss-function change.
- No dataset schema change.
- No rollout video generation for the new training flow.
- No interactive viewer requirement.

## Existing Relevant Code

- `roboimi/demos/vla_scripts/eval_vla.py`
  - already supports rollout summaries, optional trajectory export, and optional video export.
- `roboimi/utils/raw_action_trajectory_viewer.py`
  - already contains the red trajectory capsule-marker construction logic.
- `roboimi/demos/vla_scripts/train_vla.py`
  - already performs periodic rollout validation and scalar SwanLab logging.
- `roboimi/vla/agent.py`
  - already implements "predict pred_horizon, execute first num_action_steps" queue semantics.

## Design Decisions

### 1. Artifact contract

Each rollout episode will emit one distinct PNG file under the eval artifact directory. The file naming/path contract must be per-episode, not shared, so a 5-episode validation event yields 5 stable image paths without overwriting.

### 2. Trajectory definition

The red trajectory corresponds to the **actually executed model action sequence** over the rollout loop: the raw EE actions returned and consumed step-by-step by the policy loop. For the requested short-horizon run, this means the visualization reflects repeated execution of the first 4 actions from each predicted 8-action chunk, not every discarded future prediction from replanning.

### 3. Camera choice

The training-time image export path is explicitly pinned to the repo's concrete `front` camera key. It must not silently use `camera_names[0]` if that is not `front`.

### 4. Rendering path

`eval_vla.py` will add a lightweight headless image-export path that:

- renders the `front` camera frame,
- overlays the trajectory using the existing red trajectory representation,
- saves a static PNG per episode.

The implementation may reuse the existing marker-construction logic directly and add a minimal helper for final image composition/export.

### 5. Training-time behavior

`train_vla.py` rollout validation must explicitly:

- request/save trajectory images,
- keep `record_video=false`,
- return the 5 per-episode image paths in the rollout summary payload,
- upload those 5 images to SwanLab,
- keep image-upload failures non-fatal.

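The per-episode, non-overwriting path contract can be sketched as a tiny helper. Both the helper name `episode_image_path` and the exact filename pattern are hypothetical; the point is only that each episode index maps to a distinct, stable path.

```python
from pathlib import Path

def episode_image_path(artifact_dir: Path, episode_idx: int) -> Path:
    """Stable, non-overwriting per-episode PNG path for the front-view trajectory image."""
    artifact_dir.mkdir(parents=True, exist_ok=True)
    return artifact_dir / f"rollout_ep{episode_idx:03d}_front_traj.png"

paths = [episode_image_path(Path("/tmp/roboimi_demo_artifacts"), i) for i in range(5)]
print([p.name for p in paths])
```

A 5-episode validation event then yields 5 distinct paths that can be returned in the rollout summary payload and uploaded to SwanLab.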
## Expected User-Visible Outcome

For each scheduled validation event in the new training run:

- 5 rollout episodes execute,
- 5 front-view PNG trajectory images are saved locally,
- the same 5 images are uploaded to SwanLab,
- scalar reward metrics continue to be logged,
- no rollout videos are generated.

## Risks and Mitigations

- **Headless rendering conflicts from desktop env vars**: force headless eval onto EGL when `headless=true`.
- **Image overwrite risk**: use explicit per-episode artifact paths.
- **SwanLab media API mismatch**: isolate media logging in a small best-effort helper.

138 docs/superpowers/specs/2026-04-05-lewm-vit-backbone-design.md Normal file

# LEWM ViT Backbone Replacement Design

## Goal

Replace the ResNet visual encoder in the current roboimi VLA policy with the frozen ViT visual encoder (encoder + projector) from the LEWM checkpoint, using only the final CLS token's 192-d embedding as the visual feature.

## User constraints

- Use the trained checkpoint confirmed in `/home/droid/下载/lewm_sim_transfer_checkpoint_usage.md`
- Use only the visual encoding part: `encoder + projector`
- Weights frozen
- Keep the overall "visual feature + state concatenation, then diffusion transformer" processing scheme
- Input uses the three views: `[r_vis, top, front]`
- Launch two training runs on the 5880 machine: `embed=384/layer=12` and `embed=256/layer=12`
- `pred_horizon=16`
- `num_action_steps=8`
- `50k` steps per run
- Rollout validation uses `10` episodes each time, not the previous `5`

## Trusted existing facts

1. LEWM checkpoint path:
   - `/home/droid/le-wm/lewm-sim-transfer/pa1w85md8jop6bvol8oxp/checkpoints/epoch=99-step=47800.ckpt`
2. state_dict prefixes to load:
   - `model.encoder.*`
   - `model.projector.*`
3. LEWM ViT config:
   - encoder scale: `tiny`
   - hidden size: `192`
   - layers: `12`
   - attention heads: `3`
   - patch size: `14`
   - projector: `MLP(192 -> 2048 -> 192)` with `BatchNorm1d + GELU`
4. During LEWM training, the three views are first stitched into a single image and fed through one ViT encoder; the resulting overall visual embedding is **192-d**.

## Key design decision

### Chosen design: fuse 3 cameras into one LEWM-style image, output one 192-d visual vector per timestep

Do not treat the LEWM ViT as a "192-d encoder per camera"; instead follow LEWM's original training setup:

- Take the three-view image dict `{r_vis, top, front}`
- Stitch the views into one fused image in a fixed order
- Run the single frozen ViT + projector
- Obtain one **192-d total visual feature**

### Why this is the right replacement

The **total visual feature dimension** the current ResNet backbone hands to the policy head is:

- `64` per camera
- `192` total across three cameras

The LEWM checkpoint's CLS/projector embedding is also:

- `192` total

So the most natural "drop-in replacement for the current ResNet visual encoder" is:

- have the LEWM backbone produce one 192-d total visual vector
- concatenate with the `16-d` state as before, still yielding a `208-d` condition vector
- leave the diffusion head's overall interface and semantics unchanged

## Interface compatibility plan

The existing `VLAAgent` assumes the backbone exposes:

- `camera_names`
- `num_cameras`
- `output_dim` (semantically the "per-camera feature dimension")
- `forward(images_dict) -> (B, T, total_visual_dim)`

To stay compatible with the existing agent with minimal changes:

- The new LEWM backbone's `forward()` returns `(B, T, 192)`
- `camera_names = ('r_vis', 'top', 'front')`
- `num_cameras = 3`
- `output_dim = 64`

`VLAAgent` will then still compute:

- `per_step_cond_dim = output_dim * num_cams + obs_dim = 64*3 + 16 = 208`

which matches the actual `forward()` output of `192 + 16 = 208`.

> In other words: in this backbone, `output_dim` is kept as a "per-camera placeholder dimension equivalent to the old ResNet total feature", not the "true projector output dimension". It is a compatibility shim to avoid changing the agent's main logic.

## Image preprocessing design

The current roboimi dataset already loads each camera image as:

- `(C, 224, 224)`
- values in `[0, 1]`

The new LEWM backbone will:

1. Take `r_vis`, `top`, `front` in order
2. Concatenate along the width dimension into a fused image:
   - `(C, 224, 672)`
3. Apply LEWM-consistent ImageNet normalization:
   - mean `[0.485, 0.456, 0.406]`
   - std `[0.229, 0.224, 0.225]`
4. Call `ViTModel(..., interpolate_pos_encoding=True)`
5. Take `last_hidden_state[:, 0]`
6. Feed it into the frozen projector to obtain `(B*T, 192)`

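Steps 1–3 above can be sketched directly; `fuse_and_normalize` is a hypothetical helper name. Steps 4–6 (the `ViTModel` forward with `interpolate_pos_encoding=True`, CLS extraction, and the frozen projector) are omitted because they require the real checkpoint.

```python
import torch

IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

def fuse_and_normalize(images: dict) -> torch.Tensor:
    """Concatenate r_vis/top/front along the width axis, then ImageNet-normalize."""
    fused = torch.cat([images[k] for k in ("r_vis", "top", "front")], dim=-1)
    return (fused - IMAGENET_MEAN) / IMAGENET_STD  # (B*T, 3, 224, 672)

views = {k: torch.rand(4, 3, 224, 224) for k in ("r_vis", "top", "front")}
print(tuple(fuse_and_normalize(views).shape))  # (4, 3, 224, 672)
```

Getting this orientation right (224x672, width-wise) is exactly Risk 1 below: concatenating along the wrong axis would produce 672x224 and shift the input distribution away from LEWM training.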
## Files to create / modify

### New files

- `roboimi/vla/models/backbones/lewm_vit_backbone.py`
- `roboimi/vla/conf/backbone/lewm_vit_diffusion.yaml`
- `roboimi/vla/conf/agent/lewm_imf_attnres.yaml`
- `tests/test_lewm_vit_backbone.py`

### Modified files

- `roboimi/vla/models/backbones/__init__` (if an export is needed)
- `tests/test_imf_vla_agent.py` (add an integration case for the new backbone)
- `roboimi/demos/vla_scripts/train_vla.py` (only adjust rollout defaults/logging if needed; if command-line overrides suffice, avoid touching the main logic)
- Training/experiment suite docs (record this LEWM ViT training run)

## Testing plan

1. **Unit test: load + forward**
   - Use a synthetic checkpoint to verify the new backbone correctly loads `model.encoder.*` and `model.projector.*`
   - Input: 3 cameras `(B,T,C,224,224)`
   - Output: `(B,T,192)`
2. **Agent integration test**
   - backbone.output_dim=64, num_cameras=3
   - the agent's `_build_cond()` output has a final dimension of `208`
3. **Remote smoke test on 5880**
   - Use the real checkpoint
   - `max_steps=2`
   - Smoke each of the two experiments once
4. **Full run**
   - GPU0: `embed=384, layer=12`
   - GPU1: `embed=256, layer=12`
   - `rollout_num_episodes=10`

## Training launch contract

- host: `100.73.14.65`
- code dir: `/home/droid/roboimi_suite_20260404`
- python: `/home/droid/miniforge3/envs/roboimi/bin/python`
- dataset: `/home/droid/sim_dataset/sim_transfer`
- cameras: `[r_vis, top, front]`
- agent: new `lewm_imf_attnres`
- max_steps: `50000`
- rollout every `5` epochs
- rollout episodes: `10`

## Risks

1. If the fused-image preprocessing orientation is implemented wrongly (224x672 vs 672x224), it causes a distribution shift relative to LEWM training.
2. The current roboimi env must have `transformers` installed; `environment.yml` suggests it exists locally, but the remote training env needs a smoke check.
3. Because this is a frozen ViT + projector, the projector's BatchNorm statistics would drift if left in train mode, so the whole module must be set to `eval()` and frozen.

## Recommended first implementation path

- First implement a standalone `LEWMViTBackbone` class without touching the existing `ResNetDiffusionBackbone` main logic.
- Then wire it in via the new hydra backbone/agent configs.
- Prioritize "minimal intrusion + runnable smoke + trainable remotely".

# Phase-2 Full-AttnRes Vision Design

## Goal

In the current roboimi IMF policy, replace every residual unit previously provided by ResNet BasicBlock/Bottleneck in the visual backbone with AttnRes-style units, while keeping the existing agent / cond / rollout / training-script interfaces unchanged as much as possible.

## User requirement interpretation

Apply the strictest interpretation:

- Not "append an AttnRes module after the ResNet"
- Not "mix AttnRes into only some stages"
- But: everywhere the visual trunk relied on ResNet residual blocks, switch to blocks driven by the AttnRes residual operator
- The final output must keep the same per-camera feature interface as the existing `ResNetDiffusionBackbone`, so that `SpatialSoftmax -> Linear -> ReLU`, multi-camera concatenation, state concat, and the IMF head's conditioning input can all be reused

## Recommended design
|
||||
### Option A (recommended)
|
||||
保留 ResNet 的宏观 stage/stem 结构与通道/步幅规划,但把每个 stage 内的 BasicBlock/Bottleneck 替换为新的 `AttnResImageBlock2D`:
|
||||
- 输入仍是 `(B, C, H, W)` feature map
|
||||
- block 内先把空间维 flatten 成 token 序列 `(B, H*W, C)`
|
||||
- 用二维位置编码 / 可学习位置偏置 + AttnRes self-attention + AttnRes FFN 完成 block 变换
|
||||
- 再 reshape 回 `(B, C, H, W)`
|
||||
- stage 间下采样仍由显式 stride/downsample path 完成
|
||||
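The flatten-attend-reshape flow above can be sketched as follows. This is a minimal stand-in: the attention and FFN here are plain `torch.nn` modules in place of the actual AttnRes operators, whose definitions live elsewhere in the branch:

```python
import torch
import torch.nn as nn

class AttnResImageBlock2D(nn.Module):
    """Sketch: flatten HxW into tokens, apply residual attention + FFN,
    reshape back. Positional bias is omitted for brevity."""

    def __init__(self, channels: int, num_heads: int = 1):
        super().__init__()
        self.norm1 = nn.LayerNorm(channels)
        self.attn = nn.MultiheadAttention(channels, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(channels)
        self.ffn = nn.Sequential(
            nn.Linear(channels, 4 * channels), nn.GELU(),
            nn.Linear(4 * channels, channels),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, c, h, w = x.shape
        tokens = x.flatten(2).transpose(1, 2)           # (B, H*W, C)
        y = self.norm1(tokens)
        tokens = tokens + self.attn(y, y, y)[0]         # residual self-attention
        tokens = tokens + self.ffn(self.norm2(tokens))  # residual FFN
        return tokens.transpose(1, 2).reshape(b, c, h, w)
```

Because input and output shapes match `(B, C, H, W)`, the block is a drop-in replacement at each BasicBlock/Bottleneck site.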
Pros:

- Closest to the requirement "all residuals in the ResNet are replaced by AttnRes"
- Keeps the existing vision output interface and cond_dim, so no agent/head/data-pipeline changes
- Still fits the existing multi-camera encoder framework

Cons:

- Requires writing a new 2D AttnRes image block instead of directly reusing the 1D IMF head block

### Option B

Remove the ResNet stages entirely and switch to patchify + a ViT/AttnRes image transformer, followed by SpatialSoftmax/MLP.

Pros: conceptually more uniform.

Cons: this is no longer "replace the residuals inside ResNet" but a wholesale backbone swap, which does not fully match the user's request.

### Option C

Keep the existing ResNet blocks and only add AttnRes mixing around them.

Not recommended, since it does not satisfy "all residuals replaced by AttnRes".

## Concrete architecture choice

Adopt Option A:

1. Keep the stem (conv/bn-or-gn/relu/maxpool) and the stage boundaries
2. Add `AttnResImageBlock2D`
3. Add `AttnResResNetLikeBackbone2D`, responsible for stacking stages/blocks
4. Add an optional backbone mode to `ResNetDiffusionBackbone`, e.g.:
   - `vision_backbone_mode: resnet`
   - `vision_backbone_mode: attnres_resnet`
5. Add a Phase-2 variant of the `resnet_imf_attnres` agent config with `attnres_resnet` enabled by default
6. Still keep:
   - per-camera output `64`
   - total multi-camera visual output `3 * 64`
   - `cond_dim = 208` after concatenating state

## Files likely to change

- `roboimi/vla/models/backbones/resnet_diffusion.py`
- `roboimi/vla/conf/backbone/resnet_diffusion.yaml`
- `roboimi/vla/conf/agent/resnet_imf_attnres.yaml`
- new: `roboimi/vla/models/backbones/attnres_resnet2d.py`
- tests:
  - new: `tests/test_attnres_resnet2d_backbone.py`
  - update/add wiring test for agent cond dims

## Test plan

1. New backbone instantiates and forwards `(B,T,C,H,W)` multi-camera input
2. Output shape unchanged vs current backbone
3. `output_dim == 64`
4. 3-camera cond path still yields `208`
5. Phase-2 config instantiates the full IMF agent successfully
6. One short CPU smoke forward for `compute_loss`

## Phase-2 experiment plan

Fix the best Phase-1 combination:

- `pred_horizon=16`
- `num_action_steps=8`

Compare:

1. baseline: current IMF head-only AttnRes + original ResNet vision backbone
2. phase2: IMF head AttnRes + full AttnRes-replaced vision backbone

Training hyperparameters stay identical to the best Phase-1 setup; run one 50k-step comparison first.
@@ -0,0 +1,32 @@
# ResNet Multitoken IMF Design

**Status:** user-specified architecture, treated as approved on 2026-04-06.

## Goal

Keep a standard ResNet-18 visual trunk (no AttnRes in vision), but change IMF conditioning from one concatenated multiview token per obs step into three camera-specific condition tokens per obs step.

## Approved architecture

- Vision trunk: standard `resnet18` residual network
- Cameras: `front`, `top`, `r_vis`
- Each camera uses its **own** ResNet-18 weights (`use_separate_rgb_encoder_per_camera=true`)
- Each camera produces one visual token
- For each obs step and each camera:
  1. take that camera's visual token
  2. concatenate the robot state
  3. project to one condition token
- IMF input should receive **3 condition tokens per obs step**, not one concatenated token
- With `obs_horizon=2`, the IMF cond sequence length becomes `2 * 3 = 6`
- The IMF head remains on the existing IMF/AttnRes implementation path
- The vision trunk remains standard ResNet; **no AttnRes vision replacement**

## Design choices

- Extend `ResNetDiffusionBackbone` with an opt-in mode that returns per-camera tokens shaped `(B, T, num_cams, D)` instead of concatenating camera features into `(B, T, num_cams * D)`.
- Teach `VLAAgent` to detect multi-token visual features, broadcast state per camera token, apply the existing condition projector on each token, then flatten `(T, num_cams)` into one cond sequence for the IMF head.
- Keep `per_step_cond_dim` as the width of a single condition token, and add explicit token-count metadata so transformer heads get the correct cond-sequence length.
- For the new experiments, set the condition-token width equal to `n_emb` via `cond_projector.output_dim=${agent.head.n_emb}`.
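The broadcast-then-flatten path above can be illustrated with tensor shapes. A sketch only: `D=64` and the 16-d state are illustrative assumptions, while `T=2` obs steps and 3 cameras match the design:

```python
import torch

B, T, num_cams, D, state_dim = 4, 2, 3, 64, 16

vis = torch.randn(B, T, num_cams, D)                     # per-camera visual tokens
state = torch.randn(B, T, state_dim)
state = state.unsqueeze(2).expand(-1, -1, num_cams, -1)  # broadcast state per token
tokens = torch.cat([vis, state], dim=-1)                 # (B, T, cams, D+state)
cond_seq = tokens.flatten(1, 2)                          # (B, T*cams, D+state)
assert cond_seq.shape == (B, T * num_cams, D + state_dim)  # cond seq length 6
```

In the real agent, the existing condition projector would then map each token to `n_emb` before the IMF head consumes the length-6 sequence.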
## Files expected to change

- `roboimi/vla/models/backbones/resnet_diffusion.py`
- `roboimi/vla/agent.py`
- new Hydra agent config for the multitoken ResNet IMF variant
- focused tests in `tests/test_imf_vla_agent.py` and/or `tests/test_resnet_transformer_agent_wiring.py`
@@ -0,0 +1,41 @@
# SigLIP2 Multiview VLA Design

**Status:** user-specified architecture, treated as approved on 2026-04-06

## Goal

Replace the current vision encoder for the IMF/AttnRes diffusion policy with a frozen SigLIP2 image encoder while preserving the downstream action-diffusion stack and rollout behavior.

## Approved architecture

- Backbone model: `google/siglip2-base-patch16-256`
- Camera inputs: three views, encoded **independently** with a **shared** SigLIP2 vision encoder
- Input size:
  - dataset images stay at native `256x256` (no dataset-side resize)
  - eval/rollout images resize to `256x256` before SigLIP2 because env renders are larger
- Per-view feature: use the global pooled image feature (`pooler_output`, 768-d)
- Per-view projection experiments:
  1. `768 -> 96`
  2. `768 -> 192`
- Conditioning pipeline:
  1. concatenate 3 projected camera vectors
  2. concatenate robot state
  3. project the concatenated condition to `384`
  4. feed that `384`-d per-step condition into the existing IMF/AttnRes diffusion head
- Training/run defaults for requested experiments:
  - `n_emb=384`
  - `n_layer=12`
  - `pred_horizon=16`
  - `num_action_steps=8`
- rollout count for validation: keep current requested behavior on this branch unless explicitly overridden later
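The four-step conditioning pipeline can be sketched with shapes. Assumptions in this sketch: `proj_dim=192` picks the second projection experiment, and `state_dim=16` is illustrative; the random tensors stand in for SigLIP2 `pooler_output` vectors:

```python
import torch
import torch.nn as nn

proj_dim, state_dim, n_views = 192, 16, 3
view_proj = nn.Linear(768, proj_dim)                         # per-view projector
cond_proj = nn.Linear(n_views * proj_dim + state_dim, 384)   # final cond projector

pooled = [torch.randn(4, 768) for _ in range(n_views)]       # pooler_output per view
state = torch.randn(4, state_dim)
cond = torch.cat([view_proj(p) for p in pooled] + [state], dim=-1)
cond = cond_proj(cond)                                       # (B, 384) per-step condition
assert cond.shape == (4, 384)
```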
## Design decisions

- The condition projector lives in `VLAAgent._build_cond()` so the backbone owns only visual features, while the agent owns the final conditioning contract expected by the diffusion head.
- The SigLIP2 backbone is frozen by default; only the per-view projectors and downstream policy layers train.
- The backbone exposes `dataset_image_resize_shape=None` and `eval_image_resize_shape=(256, 256)` so existing train/eval plumbing can reuse the raw-256 path already added in this branch.
- One shared vision encoder is used across cameras to keep memory and download size reasonable and to match the user's request for per-view independent encoding rather than a fused multiview image.

## Files expected to change

- `roboimi/vla/models/backbones/` for the new SigLIP2 backbone
- `roboimi/vla/agent.py` for optional post-concat condition projection
- Hydra configs under `roboimi/vla/conf/{agent,backbone,modules}`
- tests for backbone wiring and agent conditioning dims
- remote launch commands/scripts only as needed for training
@@ -229,6 +229,11 @@ dependencies:
- python-xxhash=3.6.0
- python_abi=3.10
- pytorch=2.4.0
- hydra-core=1.3.2
- omegaconf=2.3.0
- einops=0.8.2
- diffusers=0.36.0
- torchvision=0.19.0
- pytz=2024.1
- pyyaml=6.0.3
- qhull=2020.2
@@ -321,12 +326,10 @@ dependencies:
- datasets==4.5.0
- decorator==5.2.1
- deepdiff==8.6.1
- diffusers==0.30.0
- dill==0.4.0
- docstring_parser==0.17.0
- draccus==0.10.0
- eigenpy==3.10.3
- einops==0.8.1
- etils==1.7.0
- evdev==1.9.2
- exceptiongroup==1.3.1
@@ -350,7 +353,6 @@ dependencies:
- httpcore==1.0.9
- httpx==0.28.1
- huggingface_hub==1.3.2
- hydra-core==1.3.2
- imageio==2.35.1
- imageio-ffmpeg==0.6.0
- importlib_metadata==8.7.1
@@ -380,22 +382,6 @@ dependencies:
- networkx==3.4.2
- numcodecs==0.13.1
- numpy==2.2.6
- nvidia-cublas-cu12==12.4.5.8
- nvidia-cuda-cupti-cu12==12.4.127
- nvidia-cuda-nvrtc-cu12==12.4.127
- nvidia-cuda-runtime-cu12==12.4.127
- nvidia-cudnn-cu12==9.1.0.70
- nvidia-cufft-cu12==11.2.1.3
- nvidia-cufile-cu12==1.11.1.6
- nvidia-curand-cu12==10.3.5.147
- nvidia-cusolver-cu12==11.6.1.9
- nvidia-cusparse-cu12==12.3.1.170
- nvidia-cusparselt-cu12==0.6.3
- nvidia-nccl-cu12==2.21.5
- nvidia-nvjitlink-cu12==12.4.127
- nvidia-nvshmem-cu12==3.3.20
- nvidia-nvtx-cu12==12.4.127
- omegaconf==2.3.0
- opencv-contrib-python==4.10.0.84
- opencv-python==4.13.0.90
- orderly-set==5.5.0
@@ -431,7 +417,7 @@ dependencies:
- regex==2026.1.15
- requests==2.32.5
- rerun-sdk==0.26.2
- rich==14.2.0
- rich==13.9.4
- ruckig==0.9.2
- safehttpx==0.1.7
- safetensors==0.7.0
@@ -443,18 +429,16 @@ dependencies:
- stack-data==0.6.3
- starlette==0.50.0
- sympy==1.13.1
- swanlab==0.7.13
- termcolor==3.3.0
- timm==1.0.24
- toml==0.10.2
- tomli==2.4.0
- tomlkit==0.13.3
- torch==2.5.0
- torchcodec==0.5
- torchmetrics==1.8.2
- torchvision==0.20.0
- tqdm==4.67.1
- traitlets==5.14.3
- triton==3.1.0
- typer==0.21.1
- typer-slim==0.21.1
- typeshed_client==2.8.2
@@ -0,0 +1,63 @@
# Phase-1 Final Report and Phase-2 Handoff

- Finalized: 2026-04-05 00:34:20 CST
- Scope: IMF AttnRes policy horizon/action-step grid on `sim_transfer`
- Fixed setup: `n_emb=384`, `n_layer=12`, batch size `80`, learning rate `2.5e-4`, `max_steps=50k`, rollout every 5 epochs with 5 episodes, 3 cameras `[r_vis, top, front]`.
- Main metric: the maximum training-time `rollout_avg_reward` recorded in `checkpoints/vla_model_best.pt`.
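Reading this metric back from a saved checkpoint can be sketched as below. This assumes the checkpoint is a plain dict with a `"rollout_avg_reward"` entry, as the metric description implies:

```python
import torch

def best_rollout_reward(ckpt_path: str) -> float:
    """Read the training-time best rollout reward from a checkpoint dict.
    Assumes the value lives under the "rollout_avg_reward" key."""
    ckpt = torch.load(ckpt_path, map_location="cpu")
    return float(ckpt["rollout_avg_reward"])
```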
## Final leaderboard

| Rank | Run ID | pred_horizon | executed action steps | Best avg_reward | Best step | Final loss |
|---:|---|---:|---:|---:|---:|---:|
| 1 | `ph16_ex8` | 16 | 8 | **610.8** | 21874 | 0.0034 |
| 2 | `ph16_ex16` | 16 | 16 | 561.2 | 48124 | 0.0045 |
| 3 | `ph32_ex32` | 32 | 32 | 513.2 | 43749 | 0.0040 |
| 4 | `ph8_ex8` | 8 | 8 | 415.6 | 48124 | 0.0070 |
| 5 | `ph32_ex8` | 32 | 8 | 361.6 | 43749 | 0.0048 |
| 6 | `ph32_ex16` | 32 | 16 | 239.6 | 48124 | 0.0038 |

## Final conclusions

1. **The best combination is `pred_horizon=16` + `num_action_steps=8`**, with a best average reward of **610.8** reached at **step 21874**.
2. Under `pred_horizon=16`, executing 8 steps beats executing 16 steps by roughly **+8.8%** (610.8 vs 561.2).
3. `pred_horizon=32` is very sensitive to the executed chunk length: `32/32` is clearly better than `32/8` and `32/16`, and `32/16` degrades the most.
4. A longer prediction window does not automatically yield higher reward; **the match between prediction window and executed window** is what matters.
5. The best checkpoint did not appear at the end of training but relatively early, at **21.9k steps** of the 50k run, which shows rollout validation matters more than watching train loss alone.
6. The Phase-2 comparison baseline is therefore fixed to **`ph16_ex8`**.

## Recommended baseline for follow-up experiments

- Baseline run: `ph16_ex8`
- Baseline best checkpoint: `step 21874`
- Baseline best avg_reward: `610.8`
- Baseline run dir: `/home/droid/roboimi_suite_20260404/runs/imf-p1-ph16-ex08-emb384-l12-ms50k-5880g1-20260404-131223`

## Phase-2 target: full-AttnRes vision backbone

Per your request, this phase no longer uses AttnRes only inside the IMF head; instead, **all residual units in the previous vision ResNet trunk are replaced with AttnRes residual units**. The current implementation keeps the ResNet-style stage / downsample macro structure, but the visual residual trunk has switched to AttnRes:

- implementation: `roboimi/vla/models/backbones/attnres_resnet2d.py`
- wiring: `roboimi/vla/models/backbones/resnet_diffusion.py`
- config: `roboimi/vla/conf/backbone/resnet_diffusion.yaml`

The relevant code has been committed:

- `a780068` — headless rollout fix + Phase-1 summary
- `2033169` — full-AttnRes vision backbone

## Phase-2 launch status (observed on 2026-04-05 00:36 CST)

- Run: `imf-p2-full-attnres-vision-ph16-ex08-emb384-l12-b40-lr1p25e4-ms50k-l20g3-20260405-002424`
- Host: `100.119.99.14`, GPU `3`
- Config anchor: `pred_horizon=16`, `num_action_steps=8`
- Vision backbone: `attnres_resnet`
- Because batch size `80` OOMed on both the local 5090 and the remote L20, Phase-2 currently uses:
  - batch size: `40`
  - learning rate: `1.25e-4`
- Latest confirmed progress: **step 1300**
- First rollout has **not happened yet** at this observation point.
- SwanLab: https://swanlab.cn/@game-loader/roboimi-vla/runs/xy7fjdmn0stdr19eu3gub
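The halved batch size and halved learning rate are consistent with linear batch-size scaling; that rationale is an inference from the numbers, not an explicit statement in the run config:

```python
# Linear LR scaling: lr is reduced proportionally with batch size.
base_batch, base_lr = 80, 2.5e-4
new_batch = 40
new_lr = base_lr * new_batch / base_batch
assert abs(new_lr - 1.25e-4) < 1e-12
```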
## Next action

Keep monitoring the Phase-2 full-AttnRes training; once it finishes, compare it directly against the Phase-1 baseline `610.8` to judge whether "replace the entire vision trunk with AttnRes" beats "AttnRes only inside the IMF head".
@@ -0,0 +1,7 @@
rank,run_id,status,pred_horizon,num_action_steps,best_rollout_avg_reward,best_step,final_step,final_loss,host,run_dir,latest_step
1,ph16_ex8,running,16,8,610.8,21874,50000,0.0034315965604037046,100.73.14.65,/home/droid/roboimi_suite_20260404/runs/imf-p1-ph16-ex08-emb384-l12-ms50k-5880g1-20260404-131223,50000
2,ph16_ex16,running,16,16,561.2,48124,50000,0.004544622730463743,100.119.99.14,/home/droid/roboimi_suite_20260404/runs/imf-p1-ph16-ex16-emb384-l12-ms50k-l20g0-20260404-131223,50000
3,ph32_ex32,finished,32,32,513.2,43749,50000,0.003953303210437298,local,/home/droid/project/roboimi/.worktrees/feat-imf-attnres-policy/runs/imf-p1-ph32-ex32-emb384-l12-ms50k-5090-20260404-131223,49900
4,ph8_ex8,running,8,8,415.6,48124,50000,0.007008877582848072,100.73.14.65,/home/droid/roboimi_suite_20260404/runs/imf-p1-ph08-ex08-emb384-l12-ms50k-5880g0-20260404-131223,50000
5,ph32_ex8,running,32,8,361.6,43749,50000,0.004788532387465239,100.119.99.14,/home/droid/roboimi_suite_20260404/runs/imf-p1-ph32-ex08-emb384-l12-ms50k-l20g1-20260404-131223,50000
6,ph32_ex16,running,32,16,239.6,48124,50000,0.0038348555099219084,100.119.99.14,/home/droid/roboimi_suite_20260404/runs/imf-p1-ph32-ex16-emb384-l12-ms50k-l20g2-20260404-131223,50000
115
experiment_suites/2026-04-04-imf-horizon-grid/manifest.json
Normal file
@@ -0,0 +1,115 @@
{
  "suite_name": "2026-04-04-imf-horizon-grid",
  "created_at": "2026-04-04 13:19:52",
  "updated_at": "2026-04-04 13:19:52",
  "phase": "phase1_launching",
  "metric": "max_avg_reward",
  "baseline": {
    "agent": "resnet_imf_attnres",
    "batch_size": 80,
    "lr": 0.00025,
    "num_workers": 12,
    "max_steps": 50000,
    "rollout_val_freq_epochs": 5,
    "rollout_num_episodes": 5,
    "val_split": 0.0,
    "seed": 42,
    "scheduler_type": "cosine",
    "warmup_steps": 2000,
    "min_lr": 1e-06,
    "weight_decay": 1e-05,
    "grad_clip": 1.0,
    "inference_steps": 1,
    "embed_dim": 384,
    "n_layer": 12,
    "n_head": 1,
    "n_kv_head": 1,
    "freeze_backbone": false,
    "pretrained_backbone_weights": null,
    "camera_names": [
      "r_vis",
      "top",
      "front"
    ]
  },
  "runs": [
    {
      "id": "ph8_ex8",
      "pred_horizon": 8,
      "num_action_steps": 8,
      "host": "100.73.14.65",
      "host_label": "tailnet-5880",
      "gpu": 0,
      "workdir": "/home/droid/roboimi_suite_20260404",
      "python": "/home/droid/miniforge3/envs/roboimi/bin/python",
      "dataset_dir": "/home/droid/sim_dataset/sim_transfer",
      "run_name": "imf-p1-ph08-ex08-emb384-l12-ms50k-5880g0-20260404-131223",
      "launch_state": "ready"
    },
    {
      "id": "ph16_ex8",
      "pred_horizon": 16,
      "num_action_steps": 8,
      "host": "100.73.14.65",
      "host_label": "tailnet-5880",
      "gpu": 1,
      "workdir": "/home/droid/roboimi_suite_20260404",
      "python": "/home/droid/miniforge3/envs/roboimi/bin/python",
      "dataset_dir": "/home/droid/sim_dataset/sim_transfer",
      "run_name": "imf-p1-ph16-ex08-emb384-l12-ms50k-5880g1-20260404-131223",
      "launch_state": "ready"
    },
    {
      "id": "ph16_ex16",
      "pred_horizon": 16,
      "num_action_steps": 16,
      "host": "100.119.99.14",
      "host_label": "tailnet-l20",
      "gpu": 0,
      "workdir": "/home/droid/roboimi_suite_20260404",
      "python": "/home/droid/miniforge3/envs/roboimi/bin/python",
      "dataset_dir": "/home/droid/sim_dataset/sim_transfer",
      "run_name": "imf-p1-ph16-ex16-emb384-l12-ms50k-l20g0-20260404-131223",
      "launch_state": "provisioning_required"
    },
    {
      "id": "ph32_ex8",
      "pred_horizon": 32,
      "num_action_steps": 8,
      "host": "100.119.99.14",
      "host_label": "tailnet-l20",
      "gpu": 1,
      "workdir": "/home/droid/roboimi_suite_20260404",
      "python": "/home/droid/miniforge3/envs/roboimi/bin/python",
      "dataset_dir": "/home/droid/sim_dataset/sim_transfer",
      "run_name": "imf-p1-ph32-ex08-emb384-l12-ms50k-l20g1-20260404-131223",
      "launch_state": "provisioning_required"
    },
    {
      "id": "ph32_ex16",
      "pred_horizon": 32,
      "num_action_steps": 16,
      "host": "100.119.99.14",
      "host_label": "tailnet-l20",
      "gpu": 2,
      "workdir": "/home/droid/roboimi_suite_20260404",
      "python": "/home/droid/miniforge3/envs/roboimi/bin/python",
      "dataset_dir": "/home/droid/sim_dataset/sim_transfer",
      "run_name": "imf-p1-ph32-ex16-emb384-l12-ms50k-l20g2-20260404-131223",
      "launch_state": "provisioning_required"
    },
    {
      "id": "ph32_ex32",
      "pred_horizon": 32,
      "num_action_steps": 32,
      "host": "local",
      "host_label": "local-5090",
      "gpu": 0,
      "workdir": "/home/droid/project/roboimi/.worktrees/feat-imf-attnres-policy",
      "python": "/home/droid/.conda/envs/roboimi/bin/python",
      "dataset_dir": "/home/droid/project/diana_sim/sim_transfer",
      "run_name": "imf-p1-ph32-ex32-emb384-l12-ms50k-5090-20260404-131223",
      "launch_state": "ready"
    }
  ]
}
20
experiment_suites/2026-04-04-imf-horizon-grid/notes.md
Normal file
@@ -0,0 +1,20 @@
# IMF Horizon Grid Suite Notes

- Created: 2026-04-04 13:19:52
- Phase-1 matrix: (8,8), (16,8), (16,16), (32,8), (32,16), (32,32)
- Fixed baseline: IMF AttnRes, n_emb=384, n_layer=12, batch_size=80, lr=2.5e-4, max_steps=50k, rollout every 5 epochs with 5 episodes.
- Host allocation:
  - local RTX 5090: ph32_ex32
  - 100.73.14.65 RTX 5880 GPU0: ph8_ex8
  - 100.73.14.65 RTX 5880 GPU1: ph16_ex8
  - 100.119.99.14 L20 GPU0: ph16_ex16
  - 100.119.99.14 L20 GPU1: ph32_ex8
  - 100.119.99.14 L20 GPU2: ph32_ex16
- 100.119.99.14 still needs env + dataset + swanlab credential copy before launch.

- 2026-04-04 13:23:43: launched local ph32_ex32 (pid 1437836), remote 100.73 ph8_ex8 (pid 931824), ph16_ex8 (pid 931826); started 100.119 bootstrap (local pid 1437837).
- 2026-04-04 13:25:43: first status sync — local ph32_ex32 step≈500; remote ph8_ex8 step≈400; remote ph16_ex8 step≈400.
- 2026-04-04 13:27:41: second status sync — 100.119 bootstrap finished env copy and entered dataset copy; local ph32_ex32 step≈900; remote ph8_ex8 step≈800; remote ph16_ex8 step≈800.
- 2026-04-04 13:35:31: 100.119 bootstrap data/env copy finished. The original validation command hit a quoting bug, so I manually revalidated torch+mujoco+swanlab and launched ph16_ex16/ph32_ex8/ph32_ex16 with pids 81129/81130/81131.
- 2026-04-04 13:37:36: all 6 Phase-1 runs are now up. SwanLab links recorded in status.json; latest observed steps ≈ local 900 / 5880 runs 800 / L20 runs 100.
- 2026-04-04 14:41:08: diagnosed the remote first-rollout crash as an early mujoco import (via raw_action_trajectory_viewer) before MUJOCO_GL=egl was set in eval_vla.py. Added regression test tests/test_eval_vla_headless_import.py, made the import lazy, verified a 20-step headless eval on 5880 and L20, then resumed the 5 failed runs from step 4374. Current resumed pids: ph8_ex8=938714, ph16_ex8=938717, ph16_ex16=90169, ph32_ex8=90173, ph32_ex16=90175.
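The headless-rollout fix amounts to ensuring `MUJOCO_GL` is set before mujoco is first imported, i.e. deferring the viewer import from module scope into the function that needs it. A hedged sketch (function and flow are illustrative, not the actual eval_vla.py code):

```python
import os

def lazy_viewer_import() -> str:
    """Set MUJOCO_GL before any mujoco import instead of importing the
    trajectory viewer (and hence mujoco) at module import time."""
    os.environ.setdefault("MUJOCO_GL", "egl")
    # The real code would now import mujoco / the trajectory viewer here.
    return os.environ["MUJOCO_GL"]
```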
@@ -0,0 +1,38 @@
# Phase-1 IMF Horizon Grid Summary

- Generated: 2026-04-04 23:43:38
- Fixed baseline: IMF AttnRes head, n_emb=384, n_layer=12, batch_size=80, lr=2.5e-4, max_steps=50k, rollout every 5 epochs with 5 episodes, 3 cameras `[r_vis, top, front]`.
- Primary metric: `checkpoints/vla_model_best.pt -> rollout_avg_reward` (max training-time rollout average reward).

## Ranked results

| Rank | Run ID | pred_horizon | num_action_steps | Best avg_reward | Best step | Final loss | Host |
|---:|---|---:|---:|---:|---:|---:|---|
| 1 | `ph16_ex8` | 16 | 8 | 610.8 | 21874 | 0.0034 | 100.73.14.65 |
| 2 | `ph16_ex16` | 16 | 16 | 561.2 | 48124 | 0.0045 | 100.119.99.14 |
| 3 | `ph32_ex32` | 32 | 32 | 513.2 | 43749 | 0.0040 | local |
| 4 | `ph8_ex8` | 8 | 8 | 415.6 | 48124 | 0.0070 | 100.73.14.65 |
| 5 | `ph32_ex8` | 32 | 8 | 361.6 | 43749 | 0.0048 | 100.119.99.14 |
| 6 | `ph32_ex16` | 32 | 16 | 239.6 | 48124 | 0.0038 | 100.119.99.14 |

## Main observations

- Best overall setting was **`pred_horizon=16`, `num_action_steps=8`** with **max avg_reward = 610.8** at step **21874**.
- Comparing horizon 16: executing 8 steps outperformed executing 16 steps (`ph16_ex8` > `ph16_ex16`).
- Comparing horizon 32: executing the full 32-step chunk was much better than executing 16 or 8 steps (`ph32_ex32` > `ph32_ex8` > `ph32_ex16`).
- Short horizon 8 with 8-step execution was competitive but clearly below the best 16/8 and 32/32 settings.
- In this sweep, increasing prediction horizon helped only when the executed chunk length matched a good control cadence; mismatch could hurt a lot (especially `ph32_ex16`).

## Raw results

- `ph16_ex8`: best avg_reward=610.8 @ step 21874, final_loss=0.0034, run_dir=`/home/droid/roboimi_suite_20260404/runs/imf-p1-ph16-ex08-emb384-l12-ms50k-5880g1-20260404-131223`
- `ph16_ex16`: best avg_reward=561.2 @ step 48124, final_loss=0.0045, run_dir=`/home/droid/roboimi_suite_20260404/runs/imf-p1-ph16-ex16-emb384-l12-ms50k-l20g0-20260404-131223`
- `ph32_ex32`: best avg_reward=513.2 @ step 43749, final_loss=0.0040, run_dir=`/home/droid/project/roboimi/.worktrees/feat-imf-attnres-policy/runs/imf-p1-ph32-ex32-emb384-l12-ms50k-5090-20260404-131223`
- `ph8_ex8`: best avg_reward=415.6 @ step 48124, final_loss=0.0070, run_dir=`/home/droid/roboimi_suite_20260404/runs/imf-p1-ph08-ex08-emb384-l12-ms50k-5880g0-20260404-131223`
- `ph32_ex8`: best avg_reward=361.6 @ step 43749, final_loss=0.0048, run_dir=`/home/droid/roboimi_suite_20260404/runs/imf-p1-ph32-ex08-emb384-l12-ms50k-l20g1-20260404-131223`
- `ph32_ex16`: best avg_reward=239.6 @ step 48124, final_loss=0.0038, run_dir=`/home/droid/roboimi_suite_20260404/runs/imf-p1-ph32-ex16-emb384-l12-ms50k-l20g2-20260404-131223`

## Recommendation for Phase-2 anchor

- Use **`pred_horizon=16`, `num_action_steps=8`** as the strongest Phase-1 baseline if the goal is purely maximizing rollout reward.
- If phase-2 needs a more conservative action execution budget, `ph16_ex8` is the strongest non-full-32 execution setting and may still be a good comparison anchor.
167
experiment_suites/2026-04-04-imf-horizon-grid/status.json
Normal file
@@ -0,0 +1,167 @@
|
||||
{
|
||||
"suite_name": "2026-04-04-imf-horizon-grid",
|
||||
"updated_at": "2026-04-05 00:34:20",
|
||||
"phase": "phase1_completed",
|
||||
"provisioning": {
|
||||
"100.119.99.14": {
|
||||
"state": "completed_manual_launch",
|
||||
"bootstrap_pid_local": 1437837,
|
||||
"log_path": "experiment_suites/2026-04-04-imf-horizon-grid/provision_logs/100.119.99.14-bootstrap-20260404-131223.log",
|
||||
"env_copy": "completed",
|
||||
"dataset_copy": "completed",
|
||||
"launch_watcher_pid_local": null,
|
||||
"launch_watcher_log": "experiment_suites/2026-04-04-imf-horizon-grid/launch_logs/100.119.99.14-launch-watcher-20260404-131223.log",
|
||||
"swanlab_copy": "completed",
|
||||
"bootstrap_validation_note": "initial validation command had a quoting bug; manual validation passed and launches were started successfully"
|
||||
}
|
||||
},
|
||||
"runs": {
|
||||
"ph8_ex8": {
|
||||
"status": "finished",
|
||||
"host": "100.73.14.65",
|
||||
"gpu": 0,
|
||||
"run_name": "imf-p1-ph08-ex08-emb384-l12-ms50k-5880g0-20260404-131223",
|
||||
"workdir": "/home/droid/roboimi_suite_20260404",
|
||||
"dataset_dir": "/home/droid/sim_dataset/sim_transfer",
|
||||
"log_path": "/home/droid/roboimi_suite_20260404/runs/imf-p1-ph08-ex08-emb384-l12-ms50k-5880g0-20260404-131223/train_vla.log",
|
||||
"run_dir": "/home/droid/roboimi_suite_20260404/runs/imf-p1-ph08-ex08-emb384-l12-ms50k-5880g0-20260404-131223",
|
||||
"pred_horizon": 8,
|
||||
"num_action_steps": 8,
|
||||
"pid": 938714,
|
||||
"launch_log": "experiment_suite_launch_logs/imf-p1-ph08-ex08-emb384-l12-ms50k-5880g0-20260404-131223.restartfix-20260404-143827.log",
|
||||
"latest_step": 50000,
|
||||
"latest_log_sync": "2026-04-05 00:34:20",
|
||||
"swanlab_url": "https://swanlab.cn/@game-loader/roboimi-vla/runs/i5syc57b6zq7rbkrtqy7b",
|
||||
"process_running": false,
|
||||
"best_step": 48124,
|
||||
"best_rollout_avg_reward": 415.6,
|
||||
"final_loss": 0.007008877582848072
|
||||
},
|
||||
"ph16_ex8": {
|
||||
"status": "finished",
|
||||
"host": "100.73.14.65",
|
||||
"gpu": 1,
|
||||
"run_name": "imf-p1-ph16-ex08-emb384-l12-ms50k-5880g1-20260404-131223",
|
||||
"workdir": "/home/droid/roboimi_suite_20260404",
|
||||
"dataset_dir": "/home/droid/sim_dataset/sim_transfer",
|
||||
"log_path": "/home/droid/roboimi_suite_20260404/runs/imf-p1-ph16-ex08-emb384-l12-ms50k-5880g1-20260404-131223/train_vla.log",
|
||||
"run_dir": "/home/droid/roboimi_suite_20260404/runs/imf-p1-ph16-ex08-emb384-l12-ms50k-5880g1-20260404-131223",
|
||||
"pred_horizon": 16,
|
||||
"num_action_steps": 8,
|
||||
"pid": 938717,
|
||||
"launch_log": "experiment_suite_launch_logs/imf-p1-ph16-ex08-emb384-l12-ms50k-5880g1-20260404-131223.restartfix-20260404-143827.log",
|
||||
"latest_step": 50000,
|
||||
"latest_log_sync": "2026-04-05 00:34:20",
|
||||
"swanlab_url": "https://swanlab.cn/@game-loader/roboimi-vla/runs/4rusbrpfxmw4ffii1ul5w",
|
||||
"process_running": false,
|
||||
"best_step": 21874,
|
||||
"best_rollout_avg_reward": 610.8,
|
||||
"final_loss": 0.0034315965604037046
|
||||
},
|
||||
"ph16_ex16": {
|
||||
"status": "finished",
|
||||
"host": "100.119.99.14",
|
||||
"gpu": 0,
|
||||
"run_name": "imf-p1-ph16-ex16-emb384-l12-ms50k-l20g0-20260404-131223",
|
||||
"workdir": "/home/droid/roboimi_suite_20260404",
|
||||
"dataset_dir": "/home/droid/sim_dataset/sim_transfer",
|
||||
"log_path": "/home/droid/roboimi_suite_20260404/runs/imf-p1-ph16-ex16-emb384-l12-ms50k-l20g0-20260404-131223/train_vla.log",
|
||||
"run_dir": "/home/droid/roboimi_suite_20260404/runs/imf-p1-ph16-ex16-emb384-l12-ms50k-l20g0-20260404-131223",
|
||||
"pred_horizon": 16,
|
||||
"num_action_steps": 16,
|
||||
"pid": 90169,
|
||||
"launch_log": "experiment_suite_launch_logs/imf-p1-ph16-ex16-emb384-l12-ms50k-l20g0-20260404-131223.restartfix-20260404-143827.log",
|
||||
"latest_log_sync": "2026-04-05 00:34:20",
|
||||
"latest_step": 50000,
|
||||
"swanlab_url": "https://swanlab.cn/@game-loader/roboimi-vla/runs/wwm232k6190gexnze8mg6",
|
||||
"process_running": false,
|
||||
"best_step": 48124,
|
||||
"best_rollout_avg_reward": 561.2,
|
||||
"final_loss": 0.004544622730463743
|
||||
},
|
||||
"ph32_ex8": {
|
||||
"status": "finished",
|
||||
"host": "100.119.99.14",
|
||||
"gpu": 1,
|
||||
"run_name": "imf-p1-ph32-ex08-emb384-l12-ms50k-l20g1-20260404-131223",
|
||||
"workdir": "/home/droid/roboimi_suite_20260404",
|
||||
"dataset_dir": "/home/droid/sim_dataset/sim_transfer",
|
||||
"log_path": "/home/droid/roboimi_suite_20260404/runs/imf-p1-ph32-ex08-emb384-l12-ms50k-l20g1-20260404-131223/train_vla.log",
|
||||
"run_dir": "/home/droid/roboimi_suite_20260404/runs/imf-p1-ph32-ex08-emb384-l12-ms50k-l20g1-20260404-131223",
|
||||
"pred_horizon": 32,
|
||||
"num_action_steps": 8,
|
||||
"pid": 90173,
|
||||
"launch_log": "experiment_suite_launch_logs/imf-p1-ph32-ex08-emb384-l12-ms50k-l20g1-20260404-131223.restartfix-20260404-143827.log",
|
||||
"latest_log_sync": "2026-04-05 00:34:20",
|
||||
"latest_step": 50000,
|
||||
"swanlab_url": "https://swanlab.cn/@game-loader/roboimi-vla/runs/o5y2xjb2rsb3lmfcuhy4p",
|
||||
"process_running": false,
|
||||
"best_step": 43749,
|
||||
"best_rollout_avg_reward": 361.6,
|
||||
"final_loss": 0.004788532387465239
|
||||
},
|
||||
"ph32_ex16": {
|
||||
"status": "finished",
|
||||
"host": "100.119.99.14",
|
||||
"gpu": 2,
|
||||
"run_name": "imf-p1-ph32-ex16-emb384-l12-ms50k-l20g2-20260404-131223",
|
||||
"workdir": "/home/droid/roboimi_suite_20260404",
|
||||
"dataset_dir": "/home/droid/sim_dataset/sim_transfer",
|
||||
"log_path": "/home/droid/roboimi_suite_20260404/runs/imf-p1-ph32-ex16-emb384-l12-ms50k-l20g2-20260404-131223/train_vla.log",
|
||||
      "run_dir": "/home/droid/roboimi_suite_20260404/runs/imf-p1-ph32-ex16-emb384-l12-ms50k-l20g2-20260404-131223",
      "pred_horizon": 32,
      "num_action_steps": 16,
      "pid": 90175,
      "launch_log": "experiment_suite_launch_logs/imf-p1-ph32-ex16-emb384-l12-ms50k-l20g2-20260404-131223.restartfix-20260404-143827.log",
      "latest_log_sync": "2026-04-05 00:34:20",
      "latest_step": 50000,
      "swanlab_url": "https://swanlab.cn/@game-loader/roboimi-vla/runs/54cjpgba9eqsopdm0l8d3",
      "process_running": false,
      "best_step": 48124,
      "best_rollout_avg_reward": 239.6,
      "final_loss": 0.0038348555099219084
    },
    "ph32_ex32": {
      "status": "finished",
      "host": "local",
      "gpu": 0,
      "run_name": "imf-p1-ph32-ex32-emb384-l12-ms50k-5090-20260404-131223",
      "workdir": "/home/droid/project/roboimi/.worktrees/feat-imf-attnres-policy",
      "dataset_dir": "/home/droid/project/diana_sim/sim_transfer",
      "log_path": "/home/droid/project/roboimi/.worktrees/feat-imf-attnres-policy/runs/imf-p1-ph32-ex32-emb384-l12-ms50k-5090-20260404-131223/train_vla.log",
      "run_dir": "/home/droid/project/roboimi/.worktrees/feat-imf-attnres-policy/runs/imf-p1-ph32-ex32-emb384-l12-ms50k-5090-20260404-131223",
      "pred_horizon": 32,
      "num_action_steps": 32,
      "pid": 1437836,
      "launch_log": "experiment_suites/2026-04-04-imf-horizon-grid/launch_logs/imf-p1-ph32-ex32-emb384-l12-ms50k-5090-20260404-131223.launch.log",
      "latest_step": 49900,
      "latest_log_sync": "2026-04-05 00:34:20",
      "swanlab_url": "https://swanlab.cn/@game-loader/roboimi-vla/runs/ajs2m218jd260hawhy5ns",
      "process_running": false,
      "latest_rollout_avg_reward": 513.2,
      "best_rollout_avg_reward": 513.2,
      "best_step": 43749,
      "final_loss": 0.003953303210437298
    }
  },
  "monitor": {
    "state": "stopped",
    "pid_local": null,
    "log_path": "experiment_suites/2026-04-04-imf-horizon-grid/monitor_logs/status-sync-20260404-131223.log",
    "interval_seconds": 300,
    "stopped_at": "2026-04-05 00:34:20",
    "stop_reason": "phase1 suite finalized after all six runs completed"
  },
  "debug": {
    "remote_rollout_failure_20260404": {
      "root_cause": "eval_vla.py imported raw_action_trajectory_viewer at module import time, which imported mujoco before MUJOCO_GL=egl was set; remote headless rollout then fell back to GLFW/X11 and crashed with mujoco.FatalError: gladLoadGL error during env.reset()->mj.Renderer(...)",
      "fixed_file": "roboimi/demos/vla_scripts/eval_vla.py",
      "verification": {
        "pytest": "tests/test_eval_vla_headless_import.py passed",
        "remote_eval_5880": "1 episode x 20 steps headless eval passed",
        "remote_eval_l20": "1 episode x 20 steps headless eval passed"
      }
    }
  },
  "phase1_summary_md": "/home/droid/project/roboimi/.worktrees/feat-imf-attnres-policy/experiment_suites/2026-04-04-imf-horizon-grid/phase1_summary.md"
}
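The `root_cause` recorded in the `debug` block above reflects a general MuJoCo pitfall: the rendering backend is chosen once, at the first `import mujoco`. A minimal sketch of the guard pattern, assuming hypothetical helper names (the actual fix is the deferred import inside `eval_vla.py`):

```python
import os

# mujoco reads MUJOCO_GL once at import time. If any module imports
# mujoco before this line runs, the setting below has no effect and a
# headless machine falls back to GLFW/X11 (the gladLoadGL crash above).
os.environ.setdefault("MUJOCO_GL", "egl")


def selected_render_backend() -> str:
    """Backend mujoco would pick up when it is eventually imported."""
    return os.environ.get("MUJOCO_GL", "glfw")


def run_headless_rollout() -> None:
    # Deferred import keeps this module free of import-time side effects,
    # so the MUJOCO_GL assignment above is guaranteed to happen first.
    import mujoco  # noqa: F401  (hypothetical usage site)
```

The key design point is that the environment variable assignment sits at module top level, before any import that could transitively pull in `mujoco`.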
69
experiment_suites/2026-04-05-camera-ablation-summary.md
Normal file
@@ -0,0 +1,69 @@
# Camera Ablation Summary (`pred_horizon=16`, `num_action_steps=8`, ResNet IMF)

- Generated: 2026-04-05
- Common setup: original ResNet vision backbone, `n_emb=384`, `n_layer=12`, `batch_size=80`, `lr=2.5e-4`, `max_steps=50k`, rollout every 5 epochs with 5 episodes, headless eval.
- Metric for comparison: `checkpoints/vla_model_best.pt -> rollout_avg_reward`.

## Leaderboard

| Rank | Cameras | Best avg_reward | Best step | Final loss | Run name |
|---:|---|---:|---:|---:|---|
| 1 | `top + front` | **274.8** | 48124 | 0.0056 | `imf-resnet-topfront-2cam-ph16-ex08-emb384-l12-ms50k-5090-20260405-085023` |
| 2 | `top` | **271.2** | 43749 | 0.0052 | `imf-resnet-top-1cam-ph16-ex08-emb384-l12-ms50k-l20g4-20260405-125844` |
| 3 | `r_vis + front` | **244.0** | 21874 | 0.0043 | `imf-resnet-frontrvis-2cam-ph16-ex08-emb384-l12-ms50k-l20g1-20260405-102029` |
| 4 | `r_vis` | **6.4** | 17499 | 0.0047 | `imf-resnet-rvis-1cam-ph16-ex08-emb384-l12-ms50k-l20g3-20260405-125844` |
| 5 | `r_vis + top` | **1.2** | 4374 | 0.0047 | `imf-resnet-rvistop-2cam-ph16-ex08-emb384-l12-ms50k-l20g2-20260405-125844` |
| 6 | `front` | **0.0** | 4374 | 0.0074 | `imf-resnet-front-1cam-ph16-ex08-emb384-l12-ms50k-l20g0-20260405-095607` |

## Main takeaways

1. **`top` is the single most important camera view**: `top only = 271.2`, nearly matching `top + front = 274.8`.
2. **`front` alone is almost useless**: `front only = 0.0`.
3. **`r_vis` alone is also largely ineffective**: `r_vis only = 6.4`.
4. **`r_vis + front` clearly beats `front` / `r_vis` alone**, so these two views are somewhat complementary, yet the pair is still well below every healthy configuration that includes `top`.
5. **`r_vis + top` is anomalously bad**: only `1.2`, far below `top only = 271.2`. Simply adding `r_vis` does not guarantee a gain and can even break learning under the current setup.
6. **Training loss and rollout reward clearly disagree**: e.g. `r_vis + top` and `r_vis only` both reach low final loss yet poor reward, so model selection for this batch must use rollout reward, not loss.

## Horizontal comparison views

### Single-camera comparison

- `top`: **271.2**
- `r_vis`: **6.4**
- `front`: **0.0**

Conclusion: **`top >>> r_vis > front`**.

### Two-camera comparison

- `top + front`: **274.8**
- `r_vis + front`: **244.0**
- `r_vis + top`: **1.2**

Conclusion:
- **The safest two-camera combination is `top + front`.**
- `r_vis + front` works, but is weaker than `top + front`.
- `r_vis + top` essentially fails under the current setup.

### Incremental effect of adding a second view

- Adding `front` on top of `top`: `271.2 -> 274.8`, **a very small gain**.
- Adding `r_vis` on top of `front`: `0.0 -> 244.0`, **a very large gain**.
- Adding `r_vis` on top of `top`: `271.2 -> 1.2`, **a severe regression**.

## Practical recommendation

Choosing only among these 6 experiments:

- **First choice**: `top + front`
- **Second choice**: `top only`
- If `top` must be excluded: `r_vis + front` is clearly better than `front only` / `r_vis only`
- **Not recommended**: `r_vis + top`

## Note relative to previous 3-camera baseline

The previous 3-camera `[r_vis, top, front]` setup reached a best reward of **610.8**.
So the best result of this 6-run camera ablation (`top + front = 274.8`) shows:

- in this training batch, **removing any single view falls well short of the earlier 3-camera best**;
- but under the constraint of dropping views, **`top` remains the most essential view to keep**.
@@ -0,0 +1,8 @@
# CHECKLIST

- [x] Confirm remote free GPU
- [x] Create front-only run contract
- [x] Remote smoke test passes
- [x] Launch 50k run on remote GPU0
- [x] Record pid / log / SwanLab
- [x] Report status back to user
28
experiment_suites/2026-04-05-front-only-resnet-1cam/PLAN.md
Normal file
@@ -0,0 +1,28 @@
# PLAN

## Goal
Train a 50k-step IMF baseline with the original ResNet vision backbone, using only the `front` camera as image conditioning.

## Fixed comparison contract
- Same as the active `top/front` run except image input is reduced to `[front]`
- Agent: `resnet_imf_attnres`
- Vision backbone mode: `resnet`
- `pred_horizon=16`, `num_action_steps=8`
- `n_emb=384`, `n_layer=12`, `n_head=1`, `n_kv_head=1`
- `inference_steps=1`
- `batch_size=80`, `lr=2.5e-4`, cosine, warmup=2000
- dataset: `/home/droid/sim_dataset/sim_transfer`
- cameras: `[front]` only
- rollout every 5 epochs with 5 episodes, headless

## Resource plan
- Host: `100.119.99.14`
- GPU: `0`

## Important dimension override
- Single-camera visual cond dim = `64 + 16 = 80`, so override `agent.head.cond_dim=80` and `agent.num_cams=1`.

## Execution path
1. 2-step smoke test on remote GPU0.
2. If smoke passes, launch 50k main run with SwanLab.
3. Record pid / run_dir / log / URL locally.
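The cond_dim arithmetic repeated across these camera-ablation plans can be made explicit. A small sketch, assuming the constants implied by the plans (64 visual features per camera plus 16 state features; the helper name is ours):

```python
PER_CAMERA_FEATURES = 64  # visual feature width contributed by each camera
STATE_FEATURES = 16       # proprioceptive state width


def head_cond_dim(num_cams: int) -> int:
    """Conditioning width for the IMF head given the camera count."""
    return PER_CAMERA_FEATURES * num_cams + STATE_FEATURES
```

This reproduces the overrides recorded in the plans: one camera gives `agent.head.cond_dim=80`, two cameras give `144`.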
@@ -0,0 +1,6 @@
# Notes

- 2026-04-05 09:55:27: remote 2-step smoke passed on `100.119.99.14` GPU0 with `front` only, batch=80, no OOM.
- 2026-04-05 09:56:26: launched main run `imf-resnet-front-1cam-ph16-ex08-emb384-l12-ms50k-l20g0-20260405-095607`.
- 2026-04-05 09:57:36: confirmed training is stable through step 200, latest loss 0.2830.
- SwanLab: https://swanlab.cn/@game-loader/roboimi-vla/runs/7kdii8oc6tjkcyu5y0lwq
@@ -0,0 +1,51 @@
{
  "suite_name": "2026-04-05-front-only-resnet-1cam",
  "updated_at": "2026-04-05 09:57:36",
  "phase": "running",
  "baseline_reference": {
    "source_run": "imf-resnet-topfront-2cam-ph16-ex08-emb384-l12-ms50k-5090-20260405-085023",
    "notes": "Same hyperparameters as the active top/front run, but image input is reduced to [front] only."
  },
  "smoke_test": {
    "status": "passed",
    "host": "100.119.99.14",
    "gpu": 0,
    "run_dir": "/home/droid/roboimi_suite_20260404/runs/smoke-frontonly-resnet-ph16-ex08-20260405-095509",
    "batch_size": 80,
    "max_steps": 2,
    "note": "2-step remote CUDA smoke passed on L20 GPU0 without OOM."
  },
  "main_run": {
    "status": "running",
    "host": "100.119.99.14",
    "gpu": 0,
    "launch_pid": 158874,
    "pid": 158877,
    "run_name": "imf-resnet-front-1cam-ph16-ex08-emb384-l12-ms50k-l20g0-20260405-095607",
    "run_dir": "/home/droid/roboimi_suite_20260404/runs/imf-resnet-front-1cam-ph16-ex08-emb384-l12-ms50k-l20g0-20260405-095607",
    "log_path": "/home/droid/roboimi_suite_20260404/runs/imf-resnet-front-1cam-ph16-ex08-emb384-l12-ms50k-l20g0-20260405-095607/train_vla.log",
    "launch_log": "/home/droid/roboimi_suite_20260404/experiment_suite_launch_logs/imf-resnet-front-1cam-ph16-ex08-emb384-l12-ms50k-l20g0-20260405-095607.launch.log",
    "dataset_dir": "/home/droid/sim_dataset/sim_transfer",
    "camera_names": [
      "front"
    ],
    "pred_horizon": 16,
    "num_action_steps": 8,
    "head_cond_dim": 80,
    "head_n_emb": 384,
    "head_n_layer": 12,
    "vision_backbone_mode": "resnet",
    "pretrained_backbone_weights": null,
    "freeze_backbone": false,
    "batch_size": 80,
    "lr": 0.00025,
    "num_workers": 12,
    "max_steps": 50000,
    "rollout_val_freq_epochs": 5,
    "rollout_num_episodes": 5,
    "swanlab_url": "https://swanlab.cn/@game-loader/roboimi-vla/runs/7kdii8oc6tjkcyu5y0lwq",
    "latest_step": 200,
    "latest_loss": 0.283,
    "process_running": true
  }
}
@@ -0,0 +1,8 @@
# CHECKLIST

- [x] Confirm camera mapping (`right` -> `r_vis`)
- [x] Create front+r_vis run contract
- [x] Remote smoke test passes
- [x] Launch 50k run on remote GPU1
- [x] Record pid / log / SwanLab
- [x] Report status back to user
23
experiment_suites/2026-04-05-front-rvis-resnet-2cam/PLAN.md
Normal file
@@ -0,0 +1,23 @@
# PLAN

## Goal
Train a 50k-step IMF baseline with the original ResNet vision backbone, using `front` + `r_vis` cameras only.

## Fixed comparison contract
- Same hyperparameters as the active top/front and front-only runs
- Agent: `resnet_imf_attnres`
- Vision backbone mode: `resnet`
- `pred_horizon=16`, `num_action_steps=8`
- `n_emb=384`, `n_layer=12`, `n_head=1`, `n_kv_head=1`
- `inference_steps=1`
- `batch_size=80`, `lr=2.5e-4`, cosine warmup 2000
- dataset: `/home/droid/sim_dataset/sim_transfer`
- cameras: `[r_vis, front]`
- rollout every 5 epochs with 5 episodes, headless

## Important dimension override
- Two-camera visual cond dim = `64*2 + 16 = 144`, so set `agent.num_cams=2`, `agent.head.cond_dim=144`.

## Resource plan
- Host: `100.119.99.14`
- GPU: `1`
@@ -0,0 +1,6 @@
# Notes

- 2026-04-05 10:20:09: remote 2-step smoke passed on `100.119.99.14` GPU1 with `r_vis + front`, batch=80, no OOM.
- 2026-04-05 10:20:49: launched main run `imf-resnet-frontrvis-2cam-ph16-ex08-emb384-l12-ms50k-l20g1-20260405-102029`.
- 2026-04-05 10:22:03: confirmed training is stable through step 200, latest loss 0.3321.
- SwanLab: https://swanlab.cn/@game-loader/roboimi-vla/runs/3fyzjfdcbiq7frtbqv6ss
@@ -0,0 +1,55 @@
{
  "suite_name": "2026-04-05-front-rvis-resnet-2cam",
  "updated_at": "2026-04-05 10:22:03",
  "phase": "running",
  "interpretation": {
    "right_camera_name": "r_vis"
  },
  "baseline_reference": {
    "source_run": "imf-resnet-topfront-2cam-ph16-ex08-emb384-l12-ms50k-5090-20260405-085023",
    "notes": "Same hyperparameters as the active top/front run, replacing top with r_vis."
  },
  "smoke_test": {
    "status": "passed",
    "host": "100.119.99.14",
    "gpu": 1,
    "run_dir": "/home/droid/roboimi_suite_20260404/runs/smoke-frontrvis-resnet-ph16-ex08-20260405-102001",
    "batch_size": 80,
    "max_steps": 2,
    "note": "2-step remote CUDA smoke passed on L20 GPU1 without OOM."
  },
  "main_run": {
    "status": "running",
    "host": "100.119.99.14",
    "gpu": 1,
    "launch_pid": 159910,
    "pid": 159913,
    "run_name": "imf-resnet-frontrvis-2cam-ph16-ex08-emb384-l12-ms50k-l20g1-20260405-102029",
    "run_dir": "/home/droid/roboimi_suite_20260404/runs/imf-resnet-frontrvis-2cam-ph16-ex08-emb384-l12-ms50k-l20g1-20260405-102029",
    "log_path": "/home/droid/roboimi_suite_20260404/runs/imf-resnet-frontrvis-2cam-ph16-ex08-emb384-l12-ms50k-l20g1-20260405-102029/train_vla.log",
    "launch_log": "/home/droid/roboimi_suite_20260404/experiment_suite_launch_logs/imf-resnet-frontrvis-2cam-ph16-ex08-emb384-l12-ms50k-l20g1-20260405-102029.launch.log",
    "dataset_dir": "/home/droid/sim_dataset/sim_transfer",
    "camera_names": [
      "r_vis",
      "front"
    ],
    "pred_horizon": 16,
    "num_action_steps": 8,
    "head_cond_dim": 144,
    "head_n_emb": 384,
    "head_n_layer": 12,
    "vision_backbone_mode": "resnet",
    "pretrained_backbone_weights": null,
    "freeze_backbone": false,
    "batch_size": 80,
    "lr": 0.00025,
    "num_workers": 12,
    "max_steps": 50000,
    "rollout_val_freq_epochs": 5,
    "rollout_num_episodes": 5,
    "swanlab_url": "https://swanlab.cn/@game-loader/roboimi-vla/runs/3fyzjfdcbiq7frtbqv6ss",
    "latest_step": 200,
    "latest_loss": 0.3321,
    "process_running": true
  }
}
@@ -0,0 +1,20 @@
{
  "suite_name": "2026-04-05-full-attnres-vision-phase2",
  "created_at": "2026-04-05 00:12:14",
  "phase": "phase2_running",
  "baseline_reference": {
    "run_id": "ph16_ex8",
    "best_rollout_avg_reward": 610.8,
    "best_step": 21874,
    "run_dir": "/home/droid/roboimi_suite_20260404/runs/imf-p1-ph16-ex08-emb384-l12-ms50k-5880g1-20260404-131223"
  },
  "candidate": {
    "run_name": "imf-p2-full-attnres-vision-ph16-ex08-emb384-l12-ms50k-20260405-001214",
    "host": "local",
    "gpu": 0,
    "pred_horizon": 16,
    "num_action_steps": 8,
    "vision_backbone_mode": "attnres_resnet",
    "notes": "Full-AttnRes vision backbone replacing ResNet residual units; IMF head unchanged."
  }
}
@@ -0,0 +1,9 @@
# Full-AttnRes Vision Phase-2

- Created: 2026-04-05 00:12:14
- Baseline reference: ph16_ex8 best avg_reward=610.8
- Candidate run: imf-p2-full-attnres-vision-ph16-ex08-emb384-l12-ms50k-20260405-001214
- 2026-04-05 00:23:03: batch=80 OOM on both 5090 and L20; using validated fallback batch=40, lr=1.25e-4 on remote L20 GPU3.
- 2026-04-05 00:24:24: launching candidate imf-p2-full-attnres-vision-ph16-ex08-emb384-l12-b40-lr1p25e4-ms50k-l20g3-20260405-002424 on 100.119.99.14 GPU3 with batch=40 lr=1.25e-4.
- 2026-04-05 00:27:17: remote phase2 run is active on 100.119.99.14 GPU3, validated at least to step 200. SwanLab: https://swanlab.cn/@game-loader/roboimi-vla/runs/xy7fjdmn0stdr19eu3gub
- 2026-04-05 00:36:54: latest confirmed progress is step 1300 on 100.119.99.14 GPU3; first rollout not reached yet.
@@ -0,0 +1,32 @@
{
  "suite_name": "2026-04-05-full-attnres-vision-phase2",
  "updated_at": "2026-04-05 00:36:54",
  "phase": "phase2_running",
  "baseline_reference": {
    "run_id": "ph16_ex8",
    "best_rollout_avg_reward": 610.8,
    "best_step": 21874,
    "run_dir": "/home/droid/roboimi_suite_20260404/runs/imf-p1-ph16-ex08-emb384-l12-ms50k-5880g1-20260404-131223"
  },
  "candidate": {
    "run_name": "imf-p2-full-attnres-vision-ph16-ex08-emb384-l12-b40-lr1p25e4-ms50k-l20g3-20260405-002424",
    "host": "100.119.99.14",
    "gpu": 3,
    "pred_horizon": 16,
    "num_action_steps": 8,
    "vision_backbone_mode": "attnres_resnet",
    "notes": "Full-AttnRes vision backbone replacing ResNet residual units; IMF head unchanged.",
    "status": "running",
    "run_dir": "/home/droid/roboimi_suite_20260404/runs/imf-p2-full-attnres-vision-ph16-ex08-emb384-l12-b40-lr1p25e4-ms50k-l20g3-20260405-002424",
    "log_path": "/home/droid/roboimi_suite_20260404/runs/imf-p2-full-attnres-vision-ph16-ex08-emb384-l12-b40-lr1p25e4-ms50k-l20g3-20260405-002424/train_vla.log",
    "pid": 151187,
    "batch_size": 40,
    "lr": 0.000125,
    "num_workers": 12,
    "launch_log": "/home/droid/roboimi_suite_20260404/experiment_suite_launch_logs/imf-p2-full-attnres-vision-ph16-ex08-emb384-l12-b40-lr1p25e4-ms50k-l20g3-20260405-002424.launch.log",
    "note": "Local 5090 and remote L20 both OOM at batch=80; switched to batch=40 and linearly scaled lr to 1.25e-4 after smoke validation on L20.",
    "swanlab_url": "https://swanlab.cn/@game-loader/roboimi-vla/runs/xy7fjdmn0stdr19eu3gub",
    "latest_step": 1300,
    "latest_log_sync": "2026-04-05 00:36:54"
  }
}
73
experiment_suites/2026-04-05-lewm-vit-transfer/manifest.json
Normal file
@@ -0,0 +1,73 @@
{
  "date": "2026-04-06",
  "branch": "feat-imf-attnres-policy",
  "worktree": "/home/droid/project/roboimi/.worktrees/feat-imf-attnres-policy",
  "model": "LEWM ViT frozen visual encoder + IMF AttnRes diffusion head",
  "checkpoint_path": "/home/droid/le-wm/lewm-sim-transfer/pa1w85md8jop6bvol8oxp/checkpoints/epoch=99-step=47800.ckpt",
  "visual_contract": {
    "input_camera_names": ["r_vis", "top", "front"],
    "fused_camera_names": ["front", "top", "r_vis"],
    "joint_output_dim": 192,
    "freeze_backbone": true,
    "dataset_image_resize_shape": null,
    "eval_image_resize_shape": [256, 256],
    "fused_short_side_resize": 224
  },
  "training_contract": {
    "pred_horizon": 16,
    "num_action_steps": 8,
    "max_steps": 50000,
    "rollout_val_freq_epochs": 5,
    "rollout_num_episodes": 10,
    "batch_size": 80,
    "lr": 0.00025,
    "num_workers": 12,
    "scheduler_type": "cosine",
    "warmup_steps": 2000,
    "min_lr": 1e-06,
    "weight_decay": 1e-05,
    "grad_clip": 1.0
  },
  "verification": {
    "local_tests": "38 passed",
    "remote_dataset_shape": [2, 3, 256, 256],
    "remote_eval_prepared_shape": [3, 256, 256],
    "remote_smoke_run": {
      "run_name": "smoke-lewm-imf-rawpath-emb384-20260406-002002",
      "result": "passed",
      "details": "2-step train + checkpoint-triggered 1-episode headless rollout succeeded with corrected raw256 path"
    }
  },
  "superseded_runs": [
    {
      "run_name": "lewm-vit-imf-sim-transfer-emb384-l12-ph16-ex08-step50k-roll10-5880g0-20260405-201914",
      "reason": "stopped due to incorrect early per-camera 224 resize"
    },
    {
      "run_name": "lewm-vit-imf-sim-transfer-emb256-l12-ph16-ex08-step50k-roll10-5880g1-20260405-201914",
      "reason": "stopped due to incorrect early per-camera 224 resize"
    }
  ],
  "full_runs": [
    {
      "host": "100.73.14.65",
      "gpu": 0,
      "run_name": "lewm-vit-imf-raw256fix-sim-transfer-emb384-l12-ph16-ex08-step50k-roll10-5880g0-20260406-002124",
      "pid": 1058589,
      "log_path": "/home/droid/roboimi_suite_20260404/experiment_suite_launch_logs/lewm-vit-imf-raw256fix-sim-transfer-emb384-l12-ph16-ex08-step50k-roll10-5880g0-20260406-002124.launch.log",
      "swanlab_url": "https://swanlab.cn/@game-loader/roboimi-vla/runs/y5tzgqe0u966w9ak41i31",
      "head_n_emb": 384,
      "head_n_layer": 12
    },
    {
      "host": "100.73.14.65",
      "gpu": 1,
      "run_name": "lewm-vit-imf-raw256fix-sim-transfer-emb256-l12-ph16-ex08-step50k-roll10-5880g1-20260406-002124",
      "pid": 1058590,
      "log_path": "/home/droid/roboimi_suite_20260404/experiment_suite_launch_logs/lewm-vit-imf-raw256fix-sim-transfer-emb256-l12-ph16-ex08-step50k-roll10-5880g1-20260406-002124.launch.log",
      "swanlab_url": "https://swanlab.cn/@game-loader/roboimi-vla/runs/2esr9y7t2dgesstgrn5i6",
      "head_n_emb": 256,
      "head_n_layer": 12
    }
  ]
}
25
experiment_suites/2026-04-05-lewm-vit-transfer/notes.md
Normal file
@@ -0,0 +1,25 @@
# 2026-04-06 LEWM ViT Transfer Notes

## Root-cause fix

The first LEWM runs were stopped because the data path still resized each camera view to `224x224` **before** multiview fusion. That preserved the final tensor shape but broke the original LEWM geometry.

Corrected path now is:

- **Training dataset**: keep stored per-view `256x256` images (`data.image_resize_shape=null` at launch; dataset instantiate override is `None` for LEWM)
- **Eval rollout input**: resize live MuJoCo `480x640` camera images to `256x256` per view
- **Backbone**: fuse `front, top, r_vis` on the LEWM axis, then resize fused short side to `224`
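The corrected image path above can be traced as a shape flow. A hedged sketch only: the fusion axis is an assumption (the notes just say "the LEWM axis"), and the helper name is ours; the real transform lives in the LEWM backbone.

```python
def lewm_image_shapes(num_cams: int = 3):
    """Trace the corrected LEWM image path as (H, W) shapes.

    Assumption for illustration: the fused image concatenates views
    along the width axis.
    """
    per_view = (256, 256)  # eval resizes each live 480x640 frame to this
    fused = (per_view[0], per_view[1] * num_cams)
    # Resize so the fused image's short side becomes 224, keeping aspect.
    scale = 224 / min(fused)
    backbone_input = (round(fused[0] * scale), round(fused[1] * scale))
    return per_view, fused, backbone_input
```

The point of the fix is that the `224` resize happens once, after fusion, instead of per camera view before fusion.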
## Verification

- Local tests passed (`38 passed` across the focused suite)
- Remote check:
  - dataset sample image shape: `(2, 3, 256, 256)`
  - eval-prepared live frame shape: `(3, 256, 256)`
- Remote smoke passed with real checkpoint:
  - `smoke-lewm-imf-rawpath-emb384-20260406-002002`

## Current runs

- `lewm-vit-imf-raw256fix-sim-transfer-emb384-l12-ph16-ex08-step50k-roll10-5880g0-20260406-002124`
- `lewm-vit-imf-raw256fix-sim-transfer-emb256-l12-ph16-ex08-step50k-roll10-5880g1-20260406-002124`
19
experiment_suites/2026-04-05-lewm-vit-transfer/status.json
Normal file
@@ -0,0 +1,19 @@
{
  "status": "running",
  "updated_at": "2026-04-06T00:22:10+08:00",
  "remote_host": "100.73.14.65",
  "runs": [
    {
      "run_name": "lewm-vit-imf-raw256fix-sim-transfer-emb384-l12-ph16-ex08-step50k-roll10-5880g0-20260406-002124",
      "pid": 1058589,
      "gpu": 0,
      "state": "running"
    },
    {
      "run_name": "lewm-vit-imf-raw256fix-sim-transfer-emb256-l12-ph16-ex08-step50k-roll10-5880g1-20260406-002124",
      "pid": 1058590,
      "gpu": 1,
      "state": "running"
    }
  ]
}
@@ -0,0 +1,7 @@
# CHECKLIST

- [x] Create run contract
- [x] Remote smoke test passes
- [x] Launch 50k main run
- [x] Record pid / log / SwanLab
- [x] Report status back to user
12
experiment_suites/2026-04-05-rvis-only-resnet-1cam/PLAN.md
Normal file
@@ -0,0 +1,12 @@
# PLAN

## Goal
Train a 50k-step IMF baseline with the original ResNet vision backbone, using `r_vis` as the only image conditioning.

## Fixed comparison contract
- same hyperparameters as the active top/front run
- cameras: ['r_vis']
- num_cams=1
- head.cond_dim=80
- host: 100.119.99.14
- gpu: 3
@@ -0,0 +1,6 @@
# Notes

- 2026-04-05 12:58:22: smoke passed for ['r_vis'] on 100.119.99.14 GPU3.
- 2026-04-05 12:59:24: launched main run `imf-resnet-rvis-1cam-ph16-ex08-emb384-l12-ms50k-l20g3-20260405-125844`.
- 2026-04-05 13:01:20: latest confirmed progress step=400, loss=0.1165.
- SwanLab: https://swanlab.cn/@game-loader/roboimi-vla/runs/qnuh7vln9mqomxxldyecq
@@ -0,0 +1,47 @@
{
  "suite_name": "2026-04-05-rvis-only-resnet-1cam",
  "updated_at": "2026-04-05 13:01:20",
  "phase": "running",
  "smoke_test": {
    "status": "passed",
    "host": "100.119.99.14",
    "gpu": 3,
    "run_dir": "/home/droid/roboimi_suite_20260404/runs/smoke-rvisonly-resnet-ph16-ex08-20260405-125812",
    "batch_size": 80,
    "max_steps": 2,
    "note": "2-step remote CUDA smoke passed without OOM."
  },
  "main_run": {
    "status": "running",
    "host": "100.119.99.14",
    "gpu": 3,
    "launch_pid": 164812,
    "pid": 164816,
    "run_name": "imf-resnet-rvis-1cam-ph16-ex08-emb384-l12-ms50k-l20g3-20260405-125844",
    "run_dir": "/home/droid/roboimi_suite_20260404/runs/imf-resnet-rvis-1cam-ph16-ex08-emb384-l12-ms50k-l20g3-20260405-125844",
    "log_path": "/home/droid/roboimi_suite_20260404/runs/imf-resnet-rvis-1cam-ph16-ex08-emb384-l12-ms50k-l20g3-20260405-125844/train_vla.log",
    "launch_log": "/home/droid/roboimi_suite_20260404/experiment_suite_launch_logs/imf-resnet-rvis-1cam-ph16-ex08-emb384-l12-ms50k-l20g3-20260405-125844.launch.log",
    "dataset_dir": "/home/droid/sim_dataset/sim_transfer",
    "camera_names": [
      "r_vis"
    ],
    "pred_horizon": 16,
    "num_action_steps": 8,
    "head_cond_dim": 80,
    "head_n_emb": 384,
    "head_n_layer": 12,
    "vision_backbone_mode": "resnet",
    "pretrained_backbone_weights": null,
    "freeze_backbone": false,
    "batch_size": 80,
    "lr": 0.00025,
    "num_workers": 12,
    "max_steps": 50000,
    "rollout_val_freq_epochs": 5,
    "rollout_num_episodes": 5,
    "swanlab_url": "https://swanlab.cn/@game-loader/roboimi-vla/runs/qnuh7vln9mqomxxldyecq",
    "latest_step": 400,
    "latest_loss": 0.1165,
    "process_running": true
  }
}
@@ -0,0 +1,7 @@
# CHECKLIST

- [x] Create run contract
- [x] Remote smoke test passes
- [x] Launch 50k main run
- [x] Record pid / log / SwanLab
- [x] Report status back to user
12
experiment_suites/2026-04-05-rvistop-resnet-2cam/PLAN.md
Normal file
@@ -0,0 +1,12 @@
# PLAN

## Goal
Train a 50k-step IMF baseline with the original ResNet vision backbone, using `r_vis` + `top` as the only image conditioning.

## Fixed comparison contract
- same hyperparameters as the active top/front run
- cameras: ['r_vis', 'top']
- num_cams=2
- head.cond_dim=144
- host: 100.119.99.14
- gpu: 2
@@ -0,0 +1,6 @@
# Notes

- 2026-04-05 12:58:22: smoke passed for ['r_vis', 'top'] on 100.119.99.14 GPU2.
- 2026-04-05 12:59:24: launched main run `imf-resnet-rvistop-2cam-ph16-ex08-emb384-l12-ms50k-l20g2-20260405-125844`.
- 2026-04-05 13:01:20: latest confirmed progress step=200, loss=0.2845.
- SwanLab: https://swanlab.cn/@game-loader/roboimi-vla/runs/umsm6402eb81et7wx7z4a
48
experiment_suites/2026-04-05-rvistop-resnet-2cam/status.json
Normal file
@@ -0,0 +1,48 @@
{
  "suite_name": "2026-04-05-rvistop-resnet-2cam",
  "updated_at": "2026-04-05 13:01:20",
  "phase": "running",
  "smoke_test": {
    "status": "passed",
    "host": "100.119.99.14",
    "gpu": 2,
    "run_dir": "/home/droid/roboimi_suite_20260404/runs/smoke-rvistop-resnet-ph16-ex08-20260405-125812",
    "batch_size": 80,
    "max_steps": 2,
    "note": "2-step remote CUDA smoke passed without OOM."
  },
  "main_run": {
    "status": "running",
    "host": "100.119.99.14",
    "gpu": 2,
    "launch_pid": 164745,
    "pid": 164749,
    "run_name": "imf-resnet-rvistop-2cam-ph16-ex08-emb384-l12-ms50k-l20g2-20260405-125844",
    "run_dir": "/home/droid/roboimi_suite_20260404/runs/imf-resnet-rvistop-2cam-ph16-ex08-emb384-l12-ms50k-l20g2-20260405-125844",
    "log_path": "/home/droid/roboimi_suite_20260404/runs/imf-resnet-rvistop-2cam-ph16-ex08-emb384-l12-ms50k-l20g2-20260405-125844/train_vla.log",
    "launch_log": "/home/droid/roboimi_suite_20260404/experiment_suite_launch_logs/imf-resnet-rvistop-2cam-ph16-ex08-emb384-l12-ms50k-l20g2-20260405-125844.launch.log",
    "dataset_dir": "/home/droid/sim_dataset/sim_transfer",
    "camera_names": [
      "r_vis",
      "top"
    ],
    "pred_horizon": 16,
    "num_action_steps": 8,
    "head_cond_dim": 144,
    "head_n_emb": 384,
    "head_n_layer": 12,
    "vision_backbone_mode": "resnet",
    "pretrained_backbone_weights": null,
    "freeze_backbone": false,
    "batch_size": 80,
    "lr": 0.00025,
    "num_workers": 12,
    "max_steps": 50000,
    "rollout_val_freq_epochs": 5,
    "rollout_num_episodes": 5,
    "swanlab_url": "https://swanlab.cn/@game-loader/roboimi-vla/runs/umsm6402eb81et7wx7z4a",
    "latest_step": 200,
    "latest_loss": 0.2845,
    "process_running": true
  }
}
@@ -0,0 +1,8 @@
# CHECKLIST

- [x] Confirm baseline hyperparameters from trusted prior run
- [x] Confirm local GPU availability
- [x] Smoke test with `top/front` cameras only
- [x] Launch 50k run
- [x] Record pid / run dir / log path / SwanLab URL
- [x] Report status back to user
30
experiment_suites/2026-04-05-top-front-resnet-2cam/PLAN.md
Normal file
@@ -0,0 +1,30 @@
# PLAN

## Goal
Train a 50k-step IMF baseline with the original ResNet vision backbone (no full-AttnRes vision replacement), using only `top` and `front` cameras as image conditioning.

## Fixed comparison contract
- Agent: `resnet_imf_attnres`
- Vision backbone mode: `resnet`
- `pred_horizon=16`
- `num_action_steps=8`
- `n_emb=384`, `n_layer=12`, `n_head=1`, `n_kv_head=1`
- `inference_steps=1`
- `batch_size=80`, `lr=2.5e-4`, cosine scheduler, warmup 2000
- dataset: `/home/droid/project/diana_sim/sim_transfer`
- cameras: `[top, front]` only
- training budget: `max_steps=50000`
- rollout validation: every 5 epochs, 5 episodes, headless

## Resource plan
- Host: local
- GPU: RTX 5090 (GPU 0)

## Execution path
1. Run a short 2-step smoke test on GPU with the exact 2-camera config.
2. If smoke passes, launch the 50k main run with durable log redirection.
3. Record run name, pid, log path, and SwanLab URL into suite status.

## Fallbacks
- If batch 80 OOMs, fall back to batch 64 with scaled lr 2.0e-4.
- If dataloader startup is unstable, reduce num_workers from 12 to 8.
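The OOM fallback in these plans applies linear learning-rate scaling with batch size. A tiny sketch of the rule (the helper name is ours):

```python
def scaled_lr(base_lr: float, base_batch: int, new_batch: int) -> float:
    """Linear scaling rule: keep lr/batch constant when changing the batch."""
    return base_lr * new_batch / base_batch
```

With the numbers above, `scaled_lr(2.5e-4, 80, 64)` gives the fallback `2.0e-4`; the phase-2 suite's drop to batch 40 similarly gives `1.25e-4`.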
@@ -0,0 +1,5 @@
# Notes

- 2026-04-05 08:50:04: 2-step smoke test passed locally on RTX 5090 with `top/front` cameras, batch=80, no OOM.
- 2026-04-05 08:50:42: launched main run `imf-resnet-topfront-2cam-ph16-ex08-emb384-l12-ms50k-5090-20260405-085023` on local GPU0.
- SwanLab: https://swanlab.cn/@game-loader/roboimi-vla/runs/vi77mn5dwd19z4nttxab8
@@ -0,0 +1,51 @@
{
  "suite_name": "2026-04-05-top-front-resnet-2cam",
  "updated_at": "2026-04-05 08:52:12",
  "phase": "running",
  "baseline_reference": {
    "source_run": "imf-p1-ph16-ex08-emb384-l12-ms50k-5880g1-20260404-131223",
    "best_rollout_avg_reward": 610.8,
    "best_step": 21874,
    "notes": "Same IMF baseline as Phase-1 best, but switch cameras from [r_vis, top, front] to [top, front] and keep the original ResNet vision backbone."
  },
  "smoke_test": {
    "status": "passed",
    "run_dir": "/home/droid/project/roboimi/.worktrees/feat-imf-attnres-policy/runs/smoke-topfront-resnet-ph16-ex08-20260405-085000",
    "batch_size": 80,
    "num_workers": 4,
    "max_steps": 2,
    "note": "2-step local CUDA smoke passed without OOM using top/front only."
  },
  "main_run": {
    "status": "running",
    "host": "local",
    "gpu": 0,
    "pid": 1693348,
    "run_name": "imf-resnet-topfront-2cam-ph16-ex08-emb384-l12-ms50k-5090-20260405-085023",
    "run_dir": "/home/droid/project/roboimi/.worktrees/feat-imf-attnres-policy/runs/imf-resnet-topfront-2cam-ph16-ex08-emb384-l12-ms50k-5090-20260405-085023",
    "log_path": "/home/droid/project/roboimi/.worktrees/feat-imf-attnres-policy/runs/imf-resnet-topfront-2cam-ph16-ex08-emb384-l12-ms50k-5090-20260405-085023/train_vla.log",
    "launch_log": "/home/droid/project/roboimi/.worktrees/feat-imf-attnres-policy/experiment_suites/2026-04-05-top-front-resnet-2cam/launch_logs/imf-resnet-topfront-2cam-ph16-ex08-emb384-l12-ms50k-5090-20260405-085023.launch.log",
    "dataset_dir": "/home/droid/project/diana_sim/sim_transfer",
    "camera_names": [
      "top",
      "front"
    ],
    "pred_horizon": 16,
    "num_action_steps": 8,
    "head_n_emb": 384,
    "head_n_layer": 12,
    "vision_backbone_mode": "resnet",
    "pretrained_backbone_weights": null,
    "freeze_backbone": false,
    "batch_size": 80,
    "lr": 0.00025,
    "num_workers": 12,
    "max_steps": 50000,
    "rollout_val_freq_epochs": 5,
    "rollout_num_episodes": 5,
    "swanlab_url": "https://swanlab.cn/@game-loader/roboimi-vla/runs/vi77mn5dwd19z4nttxab8",
    "latest_step": 500,
    "latest_loss": 0.0978,
    "process_running": true
  }
}
@@ -0,0 +1,7 @@
# CHECKLIST

- [x] Create run contract
- [x] Remote smoke test passes
- [x] Launch 50k main run
- [x] Record pid / log / SwanLab
- [x] Report status back to user
12
experiment_suites/2026-04-05-top-only-resnet-1cam/PLAN.md
Normal file
@@ -0,0 +1,12 @@
# PLAN

## Goal
Train a 50k-step IMF baseline with the original ResNet vision backbone, using `top` as the only image-conditioning camera.

## Fixed comparison contract
- same hyperparameters as the active top/front run
- cameras: ['top']
- num_cams=1
- head.cond_dim=80
- host: 100.119.99.14
- gpu: 4
@@ -0,0 +1,6 @@
# Notes

- 2026-04-05 12:58:22: smoke passed for ['top'] on 100.119.99.14 GPU4.
- 2026-04-05 12:59:24: launched main run `imf-resnet-top-1cam-ph16-ex08-emb384-l12-ms50k-l20g4-20260405-125844`.
- 2026-04-05 13:01:20: latest confirmed progress step=400, loss=0.1233.
- SwanLab: https://swanlab.cn/@game-loader/roboimi-vla/runs/egzo29l3z9ftsaunhf025
@@ -0,0 +1,47 @@
{
  "suite_name": "2026-04-05-top-only-resnet-1cam",
  "updated_at": "2026-04-05 13:01:20",
  "phase": "running",
  "smoke_test": {
    "status": "passed",
    "host": "100.119.99.14",
    "gpu": 4,
    "run_dir": "/home/droid/roboimi_suite_20260404/runs/smoke-toponly-resnet-ph16-ex08-20260405-125812",
    "batch_size": 80,
    "max_steps": 2,
    "note": "2-step remote CUDA smoke passed without OOM."
  },
  "main_run": {
    "status": "running",
    "host": "100.119.99.14",
    "gpu": 4,
    "launch_pid": 164808,
    "pid": 164813,
    "run_name": "imf-resnet-top-1cam-ph16-ex08-emb384-l12-ms50k-l20g4-20260405-125844",
    "run_dir": "/home/droid/roboimi_suite_20260404/runs/imf-resnet-top-1cam-ph16-ex08-emb384-l12-ms50k-l20g4-20260405-125844",
    "log_path": "/home/droid/roboimi_suite_20260404/runs/imf-resnet-top-1cam-ph16-ex08-emb384-l12-ms50k-l20g4-20260405-125844/train_vla.log",
    "launch_log": "/home/droid/roboimi_suite_20260404/experiment_suite_launch_logs/imf-resnet-top-1cam-ph16-ex08-emb384-l12-ms50k-l20g4-20260405-125844.launch.log",
    "dataset_dir": "/home/droid/sim_dataset/sim_transfer",
    "camera_names": [
      "top"
    ],
    "pred_horizon": 16,
    "num_action_steps": 8,
    "head_cond_dim": 80,
    "head_n_emb": 384,
    "head_n_layer": 12,
    "vision_backbone_mode": "resnet",
    "pretrained_backbone_weights": null,
    "freeze_backbone": false,
    "batch_size": 80,
    "lr": 0.00025,
    "num_workers": 12,
    "max_steps": 50000,
    "rollout_val_freq_epochs": 5,
    "rollout_num_episodes": 5,
    "swanlab_url": "https://swanlab.cn/@game-loader/roboimi-vla/runs/egzo29l3z9ftsaunhf025",
    "latest_step": 400,
    "latest_loss": 0.1233,
    "process_running": true
  }
}
@@ -1,8 +1,46 @@
import mujoco
import numpy as np
from pathlib import Path
from roboimi.utils.KDL_utils import KDL_utils


def resolve_robot_asset_path(asset_path):
    if asset_path is None:
        return None

    raw_path = Path(asset_path).expanduser()
    if raw_path.is_absolute():
        return str(raw_path.resolve())

    current_dir = Path(__file__).resolve().parent
    package_root = current_dir.parents[1]
    repo_root = current_dir.parents[2]

    candidates = []
    if raw_path.parts and raw_path.parts[0] == 'roboimi':
        candidates.append(repo_root / raw_path)

    candidates.extend([
        current_dir / raw_path,
        package_root / raw_path,
        repo_root / raw_path,
    ])

    normalized_candidates = []
    seen = set()
    for candidate in candidates:
        resolved = candidate.resolve()
        if resolved not in seen:
            normalized_candidates.append(resolved)
            seen.add(resolved)

    for candidate in normalized_candidates:
        if candidate.exists():
            return str(candidate)

    return str(normalized_candidates[0])


class ArmBase(object):
    def __init__(self,
                 name=None,
@@ -11,8 +49,8 @@ class ArmBase(object):
                  gripper=None
                  ):
         self.name = name
-        self.urdf_path = urdf_path
-        self.xml_path = xml_path
+        self.urdf_path = resolve_robot_asset_path(urdf_path)
+        self.xml_path = resolve_robot_asset_path(xml_path)
         self.gripper = gripper
         self.robot_model = mujoco.MjModel.from_xml_path(filename=self.xml_path, assets=None)
         self.robot_data = mujoco.MjData(self.robot_model)
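The hunk above resolves a relative robot-asset path by trying an ordered, de-duplicated list of candidate roots and falling back to the first candidate when none exists on disk. A standalone sketch of that candidate logic (directory names here are illustrative, not the repo's actual layout):

```python
import tempfile
from pathlib import Path


def resolve_against_roots(rel_path, roots):
    """Try rel_path under each root in order; dedupe while preserving order."""
    candidates = []
    seen = set()
    for root in roots:
        resolved = (root / rel_path).resolve()
        if resolved not in seen:
            candidates.append(resolved)
            seen.add(resolved)
    for candidate in candidates:
        if candidate.exists():
            return candidate
    return candidates[0]  # fall back to the first candidate even if missing


with tempfile.TemporaryDirectory() as tmp:
    pkg_root = Path(tmp) / "package"          # does not exist: skipped
    repo_root = Path(tmp) / "repo"
    (repo_root / "assets").mkdir(parents=True)
    (repo_root / "assets" / "arm.xml").write_text("<mujoco/>")
    hit = resolve_against_roots(Path("assets/arm.xml"), [pkg_root, repo_root])
    assert hit == (repo_root / "assets" / "arm.xml").resolve()
```

The same ordering rule is why a path starting with `roboimi/` is tried against the repo root first in the real function.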
@@ -1,11 +1,11 @@
 import time
-import os,collections,sys
+import os
 import numpy as np
-import h5py
 from roboimi.envs.double_pos_ctrl_env import make_sim_env
 from diana_policy import TestPickAndTransferPolicy
 import cv2
 from roboimi.utils.act_ex_utils import sample_transfer_pose
+from roboimi.utils.streaming_episode_writer import StreamingEpisodeWriter

 import pathlib
 HOME_PATH = str(pathlib.Path(__file__).parent.resolve())
@@ -16,14 +16,12 @@ def main():
     task_name = 'sim_transfer'
     dataset_dir = DATASET_DIR + '/sim_transfer'  #SIM_TASK_CONFIGS[task_name]['dataset_dir']
     num_episodes = 100  #SIM_TASK_CONFIGS[task_name]['num_episodes']
     onscreen_render = None  #config['onscreen_render']
     inject_noise = False
     render_cam_name = 'angle'

     episode_len = 700  #SIM_TASK_CONFIGS[task_name]['episode_len']
     camera_names = ['angle','r_vis', 'top', 'front']  #SIM_TASK_CONFIGS[task_name]['camera_names']
     image_size = (256, 256)
     if task_name == 'sim_transfer':
         policy = TestPickAndTransferPolicy(inject_noise)
         print(task_name)
     else:
         raise NotImplementedError
@@ -39,62 +37,38 @@ def main():
     print("osmesa is ready; starting data collection...")

     for episode_idx in range(num_episodes):
-        obs = []
-        reward_ee = []
+        sum_reward = 0.0
+        max_reward = float('-inf')
         print(f'\n{episode_idx=}')
         print('Rollout out EE space scripted policy')
         box_pos = sample_transfer_pose()
         env.reset(box_pos)
+        episode_writer = StreamingEpisodeWriter(
+            dataset_path=os.path.join(dataset_dir, f'episode_{episode_idx}.hdf5'),
+            max_timesteps=episode_len,
+            camera_names=camera_names,
+            image_size=image_size,
+        )
         for step in range(episode_len):
-            action = policy.predict(box_pos,step)
-            env.step(action)
+            raw_action = policy.predict(box_pos,step)
+            env.step(raw_action)
             env.render()
-            reward_ee.append(env.rew)
-            obs.append(env.obs)
-        sum_reward = np.sum(reward_ee)
-        max_reward = np.max(reward_ee)
+            sum_reward += env.rew
+            max_reward = max(max_reward, env.rew)
+            episode_writer.append(
+                qpos=env.obs['qpos'],
+                action=raw_action,
+                images=env.obs['images'],
+            )
         if max_reward == env.max_reward:
             success.append(1)
             print(f"{episode_idx=} Successful, {sum_reward=}")
-            t0 = time.time()
-            data_dict = {
-                '/observations/qpos': [],
-                '/action': [],
-            }
-
-            for cam_name in camera_names:
-                data_dict[f'/observations/images/{cam_name}'] = []
-            for i in range(episode_len):
-                print("type qpos==",obs[i]['qpos'])
-                data_dict['/observations/qpos'].append(obs[i]['qpos'])
-                data_dict['/action'].append(obs[i]['action'])
-                for cam_name in camera_names:
-                    data_dict[f'/observations/images/{cam_name}'].append(obs[i]['images'][cam_name])
-
-            dataset_path = os.path.join(dataset_dir, f'episode_{episode_idx}')
-
-            with h5py.File(dataset_path + '.hdf5', 'w', rdcc_nbytes=1024 ** 2 * 2) as root:
-                max_timesteps = episode_len
-                root.attrs['sim'] = True
-                obs_ = root.create_group('observations')
-                image = obs_.create_group('images')
-                for cam_name in camera_names:
-                    _ = image.create_dataset(cam_name, (max_timesteps, 480, 640, 3), dtype='uint8',
-                                             chunks=(1, 480, 640, 3), )
-                qpos = obs_.create_dataset('qpos', (max_timesteps, 16))
-                action = root.create_dataset('action', (max_timesteps, 16))
-                for name, array in data_dict.items():
-                    root[name][...] = np.array(array)
+            episode_writer.commit()
         else:
             success.append(0)
             print(f"{episode_idx=} Failed")
             print(max_reward)
-            del obs
-            del reward_ee
-            del sum_reward
-            del max_reward
+            episode_writer.discard()

     # del policy
     # env.viewer.close()
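The rewritten loop above relies on a commit/discard contract: each step is appended as it happens, and only successful episodes are persisted. A toy stand-in for `StreamingEpisodeWriter` (whose real implementation is not shown in this diff) to illustrate that contract:

```python
class ToyEpisodeWriter:
    """Minimal sketch of the append/commit/discard lifecycle (not the real writer)."""

    def __init__(self):
        self._steps = []
        self.committed = None  # stays None until a successful episode is committed

    def append(self, **step):
        self._steps.append(step)      # buffer one timestep of data

    def commit(self):
        self.committed = list(self._steps)  # persist the buffered episode

    def discard(self):
        self._steps = []              # drop a failed episode entirely


writer = ToyEpisodeWriter()
writer.append(qpos=[0.0], action=[0.1])
writer.discard()                      # failed episode: nothing persisted
writer.append(qpos=[0.2], action=[0.3])
writer.commit()                       # successful episode: one step persisted
assert writer.committed is not None and len(writer.committed) == 1
```

Compared with the removed h5py block, this moves I/O out of a post-hoc buffering pass and into the rollout loop itself.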
36
roboimi/demos/view_raw_action_trajectory.py
Normal file
@@ -0,0 +1,36 @@
import argparse
import numpy as np

from roboimi.utils.raw_action_trajectory_viewer import launch_raw_action_trajectory_viewer


def parse_args():
    parser = argparse.ArgumentParser(description="Launch an interactive MuJoCo viewer with raw-action trajectory overlay.")
    parser.add_argument("trajectory_path", help="Path to raw_action.npy or trajectory.npz")
    parser.add_argument("--task-name", default="sim_transfer")
    parser.add_argument("--line-radius", type=float, default=0.004)
    parser.add_argument("--max-markers", type=int, default=1500)
    parser.add_argument(
        "--box-pos",
        type=float,
        nargs=3,
        default=None,
        help="Optional box xyz to use when resetting the environment",
    )
    return parser.parse_args()


def main():
    args = parse_args()
    box_pos = np.asarray(args.box_pos, dtype=np.float32) if args.box_pos is not None else None
    launch_raw_action_trajectory_viewer(
        args.trajectory_path,
        task_name=args.task_name,
        line_radius=args.line_radius,
        max_markers=args.max_markers,
        box_pos=box_pos,
    )


if __name__ == "__main__":
    main()
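The viewer script's CLI surface can be checked in isolation. This rebuilds the same argparse spec standalone (without the roboimi import, which is not needed for parsing) and parses a sample command line:

```python
import argparse

# Same argument spec as parse_args() in the new script above.
parser = argparse.ArgumentParser()
parser.add_argument("trajectory_path")
parser.add_argument("--task-name", default="sim_transfer")
parser.add_argument("--line-radius", type=float, default=0.004)
parser.add_argument("--max-markers", type=int, default=1500)
parser.add_argument("--box-pos", type=float, nargs=3, default=None)

# Sample invocation; the path is illustrative.
args = parser.parse_args(["runs/demo/raw_action.npy", "--box-pos", "0", "0.5", "0.8"])
assert args.task_name == "sim_transfer"
assert args.box_pos == [0.0, 0.5, 0.8]  # nargs=3 yields a list of three floats
```

With `nargs=3`, argparse enforces exactly three values for `--box-pos`, which matches the xyz reset position the environment expects.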
File diff suppressed because it is too large
@@ -3,6 +3,7 @@ import os
 import logging
 import json
 import pickle
+import importlib
 import hydra
 import torch
 import re
@@ -13,8 +14,58 @@ from torch.optim import AdamW
 from torch.optim.lr_scheduler import LambdaLR
 from pathlib import Path

-# Make sure imports resolve correctly
-sys.path.append(os.getcwd())
+# Make sure imports resolve correctly (cannot rely on cwd: Hydra switches cwd at runtime)
 def _ensure_repo_root_on_syspath():
     repo_root = Path(__file__).resolve().parents[3]
     repo_root_str = str(repo_root)
     if repo_root_str in sys.path:
         sys.path.remove(repo_root_str)
     sys.path.insert(0, repo_root_str)
     return repo_root


 _PROBLEMATIC_LD_PRELOAD_SUBSTRINGS = ('/usr/NX/lib/libnxegl.so', 'libnxegl.so')


 def _clean_ld_preload_value(value: str | None):
     if not value:
         return value, False
     entries = [entry for entry in value.split() if entry]
     filtered = [
         entry for entry in entries
         if not any(marker in entry for marker in _PROBLEMATIC_LD_PRELOAD_SUBSTRINGS)
     ]
     changed = filtered != entries
     cleaned = ' '.join(filtered) if filtered else None
     return cleaned, changed


 def _maybe_reexec_without_problematic_ld_preload():
     if __name__ != '__main__':
         return False
     if os.environ.get('_ROBOIMI_LD_PRELOAD_SANITIZED') == '1':
         return False

     cleaned, changed = _clean_ld_preload_value(os.environ.get('LD_PRELOAD'))
     if not changed:
         return False

     new_env = dict(os.environ)
     new_env['_ROBOIMI_LD_PRELOAD_SANITIZED'] = '1'
     if cleaned:
         new_env['LD_PRELOAD'] = cleaned
     else:
         new_env.pop('LD_PRELOAD', None)

     print(
         'Detected problematic LD_PRELOAD entry for CUDA/cuDNN; re-executing train_vla.py without it.',
         file=sys.stderr,
         flush=True,
     )
     os.execvpe(sys.executable, [sys.executable, *sys.argv], new_env)


 _REPO_ROOT = _ensure_repo_root_on_syspath()

 from hydra.utils import instantiate
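The LD_PRELOAD sanitizer in the hunk above filters entries by substring match and collapses an empty result to `None`. Restated standalone, with the same substrings, so the rule can be checked directly:

```python
# Any LD_PRELOAD entry containing one of these substrings is dropped,
# and an empty remainder collapses to None (same rule as _clean_ld_preload_value).
_BAD = ('/usr/NX/lib/libnxegl.so', 'libnxegl.so')


def clean_ld_preload(value):
    if not value:
        return value, False
    entries = [entry for entry in value.split() if entry]
    kept = [entry for entry in entries if not any(marker in entry for marker in _BAD)]
    return (' '.join(kept) if kept else None), kept != entries


assert clean_ld_preload(None) == (None, False)
assert clean_ld_preload('/usr/NX/lib/libnxegl.so') == (None, True)
assert clean_ld_preload('libfoo.so /usr/NX/lib/libnxegl.so') == ('libfoo.so', True)
```

The second return value is the `changed` flag that decides whether the training process re-execs itself.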
@@ -25,6 +76,28 @@ if not OmegaConf.has_resolver("len"):
    OmegaConf.register_new_resolver("len", lambda x: len(x))


def _resolve_run_output_dir() -> Path:
    try:
        from hydra.core.hydra_config import HydraConfig
        if HydraConfig.initialized():
            output_dir = HydraConfig.get().runtime.output_dir
            if output_dir:
                return Path(output_dir).resolve()
    except Exception:
        pass
    return Path.cwd().resolve()


_maybe_reexec_without_problematic_ld_preload()


def _configure_cuda_runtime(cfg):
    """Apply process-level CUDA runtime switches required by this environment."""
    if str(cfg.train.device).startswith('cuda') and bool(cfg.train.get('disable_cudnn', False)):
        torch.backends.cudnn.enabled = False
        log.warning('⚠️ cuDNN disabled per config; GPU convolutions will fall back to non-cuDNN implementations')


def recursive_to_device(data, device):
    """
    Recursively move tensors inside nested dicts/lists to the given device.
@@ -45,6 +118,127 @@ def recursive_to_device(data, device):
    return data


def build_agent_input(batch_data):
    agent_input = {
        'images': {
            cam_name.replace('observation.', ''): value
            for cam_name, value in batch_data.items()
            if cam_name.startswith('observation.') and cam_name != 'observation.state'
        },
        'qpos': batch_data['observation.state'],
        'action': batch_data['action'],
    }

    if 'action_is_pad' in batch_data:
        agent_input['action_is_pad'] = batch_data['action_is_pad']

    lewm_images = {
        cam_name.replace('lewm.observation.', ''): value
        for cam_name, value in batch_data.items()
        if cam_name.startswith('lewm.observation.') and cam_name != 'lewm.observation.state'
    }
    if lewm_images:
        agent_input['lewm_images'] = lewm_images
    if 'lewm.observation.state' in batch_data:
        agent_input['lewm_qpos'] = batch_data['lewm.observation.state']

    lewm_future_images = {
        cam_name.replace('lewm.future.', ''): value
        for cam_name, value in batch_data.items()
        if cam_name.startswith('lewm.future.') and cam_name != 'lewm.future.state'
    }
    if lewm_future_images:
        agent_input['lewm_future_images'] = lewm_future_images
    if 'lewm.future.state' in batch_data:
        agent_input['lewm_future_qpos'] = batch_data['lewm.future.state']

    return agent_input


def _instantiate_dataset(cfg, dataset_image_resize_shape, episode_indices=None):
    kwargs = {'image_resize_shape': dataset_image_resize_shape}
    if episode_indices is not None:
        kwargs['episode_indices'] = episode_indices
    return instantiate(cfg.data, **kwargs)


def build_train_val_datasets(cfg, dataset_image_resize_shape):
    val_episode_indices = cfg.train.get('val_episode_indices', None)
    if val_episode_indices:
        dataset = _instantiate_dataset(cfg, dataset_image_resize_shape)
        available_episode_indices = list(getattr(dataset, 'available_episode_indices', []))
        if not available_episode_indices:
            raise ValueError('Explicit val_episode_indices requires the dataset to expose available_episode_indices')
        requested_val_episode_indices = sorted(int(idx) for idx in val_episode_indices)
        available_set = set(available_episode_indices)
        missing = sorted(set(requested_val_episode_indices) - available_set)
        if missing:
            raise ValueError(
                f"val_episode_indices {missing} do not exist among the dataset's available episodes {available_episode_indices}"
            )
        train_episode_indices = [
            idx for idx in available_episode_indices
            if idx not in set(requested_val_episode_indices)
        ]
        if not train_episode_indices:
            raise ValueError('Explicit val_episode_indices must not cover all episodes; the training set would be empty')

        train_dataset = _instantiate_dataset(
            cfg,
            dataset_image_resize_shape,
            episode_indices=train_episode_indices,
        )
        val_dataset = _instantiate_dataset(
            cfg,
            dataset_image_resize_shape,
            episode_indices=requested_val_episode_indices,
        )
        return dataset, train_dataset, val_dataset, requested_val_episode_indices

    dataset = _instantiate_dataset(cfg, dataset_image_resize_shape)
    val_split = float(cfg.train.get('val_split', 0.1))
    seed = int(cfg.train.get('seed', 42))
    val_size = int(len(dataset) * val_split)
    train_size = len(dataset) - val_size
    if val_size > 0:
        train_dataset, val_dataset = random_split(
            dataset,
            [train_size, val_size],
            generator=torch.Generator().manual_seed(seed)
        )
    else:
        train_dataset, val_dataset = dataset, None
    return dataset, train_dataset, val_dataset, None
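The explicit held-out split in `build_train_val_datasets` reduces to set arithmetic on episode indices. A toy trace of the three checks it performs (index values are illustrative; `100` mirrors the recommended `train.val_episode_indices=[100]`):

```python
available = [0, 1, 2, 100]       # stand-in for dataset.available_episode_indices
requested_val = sorted({100})    # stand-in for cfg.train.val_episode_indices

# 1) every requested validation episode must exist
missing = sorted(set(requested_val) - set(available))
assert not missing

# 2) the training set is everything that was not held out
train = [idx for idx in available if idx not in set(requested_val)]
assert train == [0, 1, 2]

# 3) the split must leave a non-empty training set
assert train
```

When any check fails, the real function raises `ValueError` instead of silently shrinking the split.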
def compute_action_mse_validation(agent, val_loader, device):
    if val_loader is None:
        return None

    was_training = agent.training
    agent.eval()
    total_squared_error = 0.0
    total_count = 0.0
    with torch.no_grad():
        for val_batch in val_loader:
            val_batch = recursive_to_device(val_batch, device)
            val_input = build_agent_input(val_batch)
            pred_actions = agent.predict_action_chunk(val_input)
            target_actions = val_input['action']
            squared_error = (pred_actions - target_actions).pow(2)
            action_is_pad = val_input.get('action_is_pad', None)
            if action_is_pad is not None:
                mask = (~action_is_pad).unsqueeze(-1).to(squared_error.dtype)
                total_squared_error += (squared_error * mask).sum().item()
                total_count += mask.sum().item() * squared_error.shape[-1]
            else:
                total_squared_error += squared_error.sum().item()
                total_count += target_actions.numel()
    if was_training:
        agent.train()
    return total_squared_error / max(total_count, 1.0)
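The masking rule in `compute_action_mse_validation` above excludes padded action steps from both the squared-error sum and the element count. A pure-Python toy (shapes `(T, D)`, no torch dependency) that re-states the same arithmetic:

```python
def masked_action_mse(pred, target, pad):
    """MSE over non-padded action steps; padded steps count toward nothing."""
    total_sq, total_cnt = 0.0, 0.0
    for p_step, t_step, is_pad in zip(pred, target, pad):
        if is_pad:
            continue  # padded step: skipped entirely
        for p, t in zip(p_step, t_step):
            total_sq += (p - t) ** 2
            total_cnt += 1
    return total_sq / max(total_cnt, 1.0)


pred = [[1.0, 1.0], [3.0, 3.0]]      # second step is padding and must not count
target = [[0.0, 0.0], [0.0, 0.0]]
assert masked_action_mse(pred, target, pad=[False, True]) == 1.0
assert masked_action_mse([], [], pad=[]) == 0.0  # max(..., 1.0) guards empty input
```

Without the mask, the padded step's error of 9 per dimension would dominate and make held-out MSE incomparable across horizon lengths.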
def resolve_resume_checkpoint(resume_ckpt, checkpoint_dir):
    """
    Resolve the checkpoint path used for resuming training.
@@ -111,8 +305,196 @@ def get_lr_schedule_with_warmup(optimizer, warmup_steps, max_steps, scheduler_ty
    return LambdaLR(optimizer, lr_lambda)


-@hydra.main(version_base=None, config_path="../../vla/conf", config_name="config")
-def main(cfg: DictConfig):
+def build_training_optimizer(agent, lr, weight_decay):
    """Build the optimizer for the training script, preferring parameter groups provided by the head."""
    trainable_params = [param for param in agent.parameters() if param.requires_grad]
    noise_pred_net = getattr(agent, 'noise_pred_net', None)
    get_optim_groups = getattr(noise_pred_net, 'get_optim_groups', None)
    use_head_groups = callable(get_optim_groups)

    if not use_head_groups:
        return AdamW(trainable_params, lr=lr, weight_decay=weight_decay)

    head_groups = []
    grouped_param_ids = set()
    for group in get_optim_groups(weight_decay=weight_decay):
        params = [param for param in group['params'] if param.requires_grad]
        if not params:
            continue
        normalized_group = dict(group)
        normalized_group['params'] = params
        head_groups.append(normalized_group)

        for param in params:
            param_id = id(param)
            if param_id in grouped_param_ids:
                raise ValueError('Head optimizer groups contain duplicate parameters')
            grouped_param_ids.add(param_id)

    head_trainable_param_ids = {
        id(param) for param in noise_pred_net.parameters() if param.requires_grad
    }
    missing_head_param_ids = head_trainable_param_ids - grouped_param_ids
    if missing_head_param_ids:
        raise ValueError('Head optimizer groups missed trainable head parameters')

    remaining_params = [
        param for param in trainable_params
        if id(param) not in grouped_param_ids
    ]

    optim_groups = head_groups
    if remaining_params:
        optim_groups = optim_groups + [{
            'params': remaining_params,
            'weight_decay': weight_decay,
        }]
        grouped_param_ids.update(id(param) for param in remaining_params)

    all_trainable_param_ids = {id(param) for param in trainable_params}
    if grouped_param_ids != all_trainable_param_ids:
        raise ValueError('Optimizer parameter groups must include each trainable parameter exactly once')

    return AdamW(optim_groups, lr=lr, weight_decay=weight_decay)


def load_state_dict_ignoring_shape_mismatches(module, incoming_state_dict):
    """Load only checkpoint tensors whose keys exist locally and whose shapes match."""
    current_state_dict = module.state_dict()
    compatible_state_dict = {}
    mismatched_keys = []
    missing_keys = []

    for key, value in incoming_state_dict.items():
        if key not in current_state_dict:
            missing_keys.append(key)
            continue
        if current_state_dict[key].shape != value.shape:
            mismatched_keys.append(key)
            continue
        compatible_state_dict[key] = value

    merged_state_dict = dict(current_state_dict)
    merged_state_dict.update(compatible_state_dict)
    module.load_state_dict(merged_state_dict, strict=True)
    return {
        'loaded_keys': sorted(compatible_state_dict.keys()),
        'missing_keys': sorted(missing_keys),
        'mismatched_keys': sorted(mismatched_keys),
    }
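`load_state_dict_ignoring_shape_mismatches` above merges only the compatible subset of an incoming checkpoint into the current state dict. A toy illustration with plain dicts, where each entry is a `(shape, value)` pair standing in for a tensor:

```python
def merge_compatible(current, incoming):
    """Keep incoming entries whose key exists locally and whose shape matches."""
    loaded, missing, mismatched = {}, [], []
    for key, (shape, value) in incoming.items():
        if key not in current:
            missing.append(key)          # checkpoint key the local model lacks
        elif current[key][0] != shape:
            mismatched.append(key)       # shape differs: keep the local tensor
        else:
            loaded[key] = (shape, value)
    merged = dict(current)
    merged.update(loaded)
    return merged, sorted(missing), sorted(mismatched)


current = {'head.w': ((3, 3), 0), 'head.b': ((3,), 0)}
incoming = {'head.w': ((3, 3), 1), 'head.b': ((4,), 1), 'extra': ((2,), 1)}
merged, missing, mismatched = merge_compatible(current, incoming)
assert merged['head.w'][1] == 1          # matching shape: loaded from checkpoint
assert merged['head.b'][1] == 0          # shape mismatch: local value kept
assert missing == ['extra'] and mismatched == ['head.b']
```

Because the merged dict always covers every local key, the real function can call `load_state_dict(..., strict=True)` and still tolerate partial checkpoints.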
def _init_swanlab(cfg):
    """Initialize SwanLab on demand, failing fast on missing dependencies or auth errors."""
    if not bool(cfg.train.get('use_swanlab', False)):
        return None

    try:
        swanlab = importlib.import_module("swanlab")
    except ImportError as exc:
        raise RuntimeError(
            "SwanLab logging is enabled, but the 'swanlab' package could not be imported."
        ) from exc

    def _to_plain_config(value):
        if isinstance(value, dict):
            return {key: _to_plain_config(val) for key, val in value.items()}
        if isinstance(value, list):
            return [_to_plain_config(item) for item in value]
        if isinstance(value, tuple):
            return tuple(_to_plain_config(item) for item in value)

        items_method = getattr(value, 'items', None)
        if callable(items_method):
            try:
                return {key: _to_plain_config(val) for key, val in items_method()}
            except Exception:
                pass

        return value

    swanlab_config = {
        key: _to_plain_config(cfg[key])
        for key in ('train', 'data', 'agent')
        if key in cfg
    }

    init_kwargs = {
        'project': cfg.train.get('swanlab_project', 'roboimi-vla'),
        'config': swanlab_config,
    }
    run_name = cfg.train.get('swanlab_run_name', None)
    if run_name:
        init_kwargs['experiment_name'] = run_name

    try:
        swanlab.init(**init_kwargs)
    except Exception as exc:
        raise RuntimeError(
            f"SwanLab logging is enabled, but SwanLab init/login failed: {exc}"
        ) from exc

    return swanlab


def _log_to_swanlab(swanlab_module, payload, step=None):
    if swanlab_module is None:
        return
    try:
        swanlab_module.log(payload, step=step)
    except Exception as exc:
        log.warning(f"SwanLab log failed at step {step}: {exc}")


def _log_rollout_trajectory_images_to_swanlab(
    swanlab_module,
    rollout_stats,
    step=None,
    context_label: str = 'rollout',
):
    if swanlab_module is None or not rollout_stats:
        return

    image_factory = getattr(swanlab_module, 'Image', None)
    if image_factory is None:
        return

    payload = {}
    for fallback_episode_index, episode in enumerate(rollout_stats.get('episodes', [])):
        if not isinstance(episode, dict):
            continue
        artifact_paths = episode.get('artifact_paths', {})
        if not isinstance(artifact_paths, dict):
            continue
        trajectory_image = artifact_paths.get('trajectory_image')
        if not trajectory_image:
            continue
        episode_index = int(episode.get('episode_index', fallback_episode_index))
        caption = f'{context_label} trajectory image - episode {episode_index} (front)'
        try:
            payload[f'rollout/trajectory_image_episode_{episode_index}'] = image_factory(
                str(trajectory_image),
                caption=caption,
            )
        except Exception as exc:
            log.warning(
                f"SwanLab rollout trajectory image upload prep failed at step {step}: {exc}"
            )

    if payload:
        _log_to_swanlab(swanlab_module, payload, step=step)


def _finish_swanlab(swanlab_module):
    if swanlab_module is None:
        return
    try:
        swanlab_module.finish()
    except Exception as exc:
        log.warning(f"SwanLab finish failed: {exc}")


def _run_training(cfg: DictConfig):
    """
    VLA training script (ResNet backbone + Diffusion policy)
@@ -131,57 +513,77 @@
     print("=" * 80)

     log.info(f"🚀 Starting VLA training (device: {cfg.train.device})")

+    _configure_cuda_runtime(cfg)
+    swanlab_module = _init_swanlab(cfg)
+    try:
         # Create the checkpoint directory
-        checkpoint_dir = Path("checkpoints")
-        checkpoint_dir.mkdir(exist_ok=True)
+        run_output_dir = _resolve_run_output_dir()
+        checkpoint_dir = run_output_dir / "checkpoints"
+        checkpoint_dir.mkdir(parents=True, exist_ok=True)
+        default_best_model_path = checkpoint_dir / "vla_model_best.pt"

         # =========================================================================
         # 1. Instantiate the dataset and DataLoader
         # =========================================================================
         log.info("📦 Loading dataset...")
         try:
-            dataset = instantiate(cfg.data)
+            dataset_image_resize_shape = cfg.data.get('image_resize_shape', (224, 224))
+            vision_backbone_cfg = cfg.agent.get('vision_backbone', None)
+            if vision_backbone_cfg is not None and 'dataset_image_resize_shape' in vision_backbone_cfg:
+                dataset_image_resize_shape = vision_backbone_cfg.get('dataset_image_resize_shape')
+            dataset, train_dataset, val_dataset, explicit_val_episode_indices = (
+                build_train_val_datasets(cfg, dataset_image_resize_shape)
+            )
             log.info(f"✅ Dataset loaded. Total samples: {len(dataset)}")
         except Exception as e:
             log.error(f"❌ Dataset loading failed: {e}")
             raise

-        # Train/validation split
-        val_split = float(cfg.train.get('val_split', 0.1))
-        seed = int(cfg.train.get('seed', 42))
-        val_size = int(len(dataset) * val_split)
-        train_size = len(dataset) - val_size
-        if val_size > 0:
-            train_dataset, val_dataset = random_split(
-                dataset,
-                [train_size, val_size],
-                generator=torch.Generator().manual_seed(seed)
-            )
+        if explicit_val_episode_indices is not None:
+            log.info(
+                "✅ Dataset split: train=%s, val=%s (explicit held-out episodes=%s)",
+                len(train_dataset),
+                len(val_dataset),
+                explicit_val_episode_indices,
+            )
         else:
+            val_split = float(cfg.train.get('val_split', 0.1))
+            val_size = len(val_dataset) if val_dataset is not None else 0
+            if val_size > 0:
+                log.info(
+                    f"✅ Dataset split: train={len(train_dataset)}, val={val_size} (val ratio={val_split})"
+                )
-            log.info(f"✅ Dataset split: train={train_size}, val={val_size} (val ratio={val_split})")
+            else:
-            train_dataset, val_dataset = dataset, None
+                log.info("✅ Dataset split: everything used for training, val=0 (val ratio=0)")

+        train_batch_size = int(cfg.train.batch_size)
+        train_drop_last = len(train_dataset) >= train_batch_size
+        if not train_drop_last:
+            log.warning(
+                "⚠️ Training set size (%s) is smaller than batch_size (%s); keeping the last incomplete batch to avoid an empty train loader",
+                len(train_dataset),
+                train_batch_size,
+            )

         train_loader = DataLoader(
             train_dataset,
-            batch_size=cfg.train.batch_size,
+            batch_size=train_batch_size,
             shuffle=True,
             num_workers=cfg.train.num_workers,
             pin_memory=(cfg.train.device != "cpu"),
-            persistent_workers=(cfg.train.num_workers > 0),
-            drop_last=True  # drop incomplete batches to stabilize training
+            persistent_workers=False,
+            drop_last=train_drop_last
         )

         val_loader = None
         if val_dataset is not None:
             val_loader = DataLoader(
                 val_dataset,
-                batch_size=cfg.train.batch_size,
+                batch_size=train_batch_size,
                 shuffle=False,
                 num_workers=cfg.train.num_workers,
                 pin_memory=(cfg.train.device != "cpu"),
-                persistent_workers=(cfg.train.num_workers > 0),
+                persistent_workers=False,
                 drop_last=False
             )
@@ -254,18 +656,23 @@
         try:
             checkpoint = torch.load(ckpt_path, map_location=cfg.train.device)

             # Load the model weights only (not the optimizer or scheduler)
-            missing_keys, unexpected_keys = agent.load_state_dict(
+            load_info = load_state_dict_ignoring_shape_mismatches(
+                agent,
                 checkpoint['model_state_dict'],
-                strict=False  # allow partial loading (when structures do not fully match)
             )

             log.info(f"✅ [Finetune] Model weights loaded successfully")

-            if missing_keys:
-                log.warning(f"⚠️ [Finetune] Missing keys ({len(missing_keys)}): {missing_keys[:5]}...")
-            if unexpected_keys:
-                log.warning(f"⚠️ [Finetune] Unexpected keys ({len(unexpected_keys)}): {unexpected_keys[:5]}...")
+            if load_info['missing_keys']:
+                log.warning(
+                    f"⚠️ [Finetune] Keys present in the checkpoint but absent from the local model ({len(load_info['missing_keys'])}): "
+                    f"{load_info['missing_keys'][:5]}..."
+                )
+            if load_info['mismatched_keys']:
+                log.warning(
+                    f"⚠️ [Finetune] Keys skipped due to shape mismatch ({len(load_info['mismatched_keys'])}): "
+                    f"{load_info['mismatched_keys'][:5]}..."
+                )

             log.info(f"📊 [Finetune] Pretrained info: step={checkpoint.get('step', 'N/A')}, loss={checkpoint.get('loss', 'N/A')}")
             log.info(f"📈 [Finetune] Using the new training config (lr={cfg.train.lr}, max_steps={cfg.train.max_steps})")
@@ -283,7 +690,7 @@ def main(cfg: DictConfig):
|
||||
weight_decay = float(cfg.train.get('weight_decay', 1e-5))
|
||||
grad_clip = float(cfg.train.get('grad_clip', 1.0))
|
||||
|
||||
optimizer = AdamW(agent.parameters(), lr=cfg.train.lr, weight_decay=weight_decay)
|
||||
optimizer = build_training_optimizer(agent, lr=cfg.train.lr, weight_decay=weight_decay)
|
||||
log.info(f"🔧 优化器: AdamW (学习率={cfg.train.lr}, weight_decay={weight_decay})")
|
||||
|
||||
# 设置带预热的学習率调度器
|
||||
@@ -303,9 +710,26 @@ def main(cfg: DictConfig):
|
||||
# =========================================================================
|
||||
# 4.1 断点续训(恢复模型、优化器、调度器、步数)
|
||||
# =========================================================================
|
||||
def extract_checkpoint_metric_baseline(checkpoint):
|
||||
checkpoint_loss = checkpoint.get('loss', None)
|
||||
checkpoint_val_loss = checkpoint.get('val_loss', None)
|
||||
checkpoint_rollout_reward = checkpoint.get('rollout_avg_reward', None)
|
||||
|
||||
baseline_loss = float('inf')
|
||||
baseline_rollout_reward = float('-inf')
|
||||
if checkpoint_rollout_reward is not None:
|
||||
baseline_rollout_reward = float(checkpoint_rollout_reward)
|
||||
if checkpoint_val_loss is not None:
|
||||
baseline_loss = float(checkpoint_val_loss)
|
||||
elif checkpoint_loss is not None:
|
||||
baseline_loss = float(checkpoint_loss)
|
||||
return baseline_loss, baseline_rollout_reward
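The resume baseline above prefers rollout reward, then `val_loss`, then training `loss`. A standalone sketch of that priority (the function name here is mine; the logic mirrors `extract_checkpoint_metric_baseline` in the diff):

```python
# Sketch of the metric-baseline fallback used when resuming: rollout reward
# wins if present; for loss, val_loss takes priority over training loss.
def extract_metric_baseline(checkpoint: dict):
    baseline_loss = float('inf')
    baseline_rollout_reward = float('-inf')
    if checkpoint.get('rollout_avg_reward') is not None:
        baseline_rollout_reward = float(checkpoint['rollout_avg_reward'])
    if checkpoint.get('val_loss') is not None:
        baseline_loss = float(checkpoint['val_loss'])
    elif checkpoint.get('loss') is not None:
        baseline_loss = float(checkpoint['loss'])
    return baseline_loss, baseline_rollout_reward

print(extract_metric_baseline({'loss': 0.12}))                    # train-loss fallback
print(extract_metric_baseline({'loss': 0.12, 'val_loss': 0.08}))  # val_loss preferred
print(extract_metric_baseline({'rollout_avg_reward': 3.0}))       # reward baseline only
```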
start_step = 0
resume_loss = None
resume_best_loss = float('inf')
resume_best_rollout_reward = float('-inf')
best_model_path = None

resume_ckpt = cfg.train.get('resume_ckpt', None)
resume_path = resolve_resume_checkpoint(resume_ckpt, checkpoint_dir)
@@ -330,12 +754,31 @@ def main(cfg: DictConfig):
start_step = resume_step + 1

loaded_loss = checkpoint.get('loss', None)
loaded_val_loss = checkpoint.get('val_loss', None)
resume_loss = float(loaded_loss) if loaded_loss is not None else None
if loaded_val_loss is not None:
    resume_best_loss = float(loaded_val_loss)
elif loaded_loss is not None:
    resume_best_loss = float(loaded_loss)
resume_best_loss, resume_best_rollout_reward = extract_checkpoint_metric_baseline(checkpoint)
if (
    resume_best_rollout_reward != float('-inf')
    or resume_best_loss != float('inf')
):
    best_model_path = resume_path

if default_best_model_path.exists():
    try:
        best_checkpoint = torch.load(default_best_model_path, map_location=cfg.train.device)
        _, best_checkpoint_rollout_reward = (
            extract_checkpoint_metric_baseline(best_checkpoint)
        )
        if best_checkpoint_rollout_reward != float('-inf'):
            resume_best_rollout_reward = best_checkpoint_rollout_reward
            best_model_path = default_best_model_path
            log.info(
                "📈 [Resume] 从最佳 checkpoint 恢复最佳 rollout 基线: %s",
                default_best_model_path,
            )
    except Exception as e:
        log.warning(
            f"⚠️ [Resume] 读取最佳 checkpoint 失败,将回退到恢复 checkpoint 的验证基线: {e}"
        )

log.info(f"✅ [Resume] 恢复成功: 上次步骤={resume_step}, 本次从步骤 {start_step} 开始")
log.info(f"📈 [Resume] 当前学习率: {optimizer.param_groups[0]['lr']:.2e}")
@@ -345,27 +788,27 @@ def main(cfg: DictConfig):
start_step = 0
resume_loss = None
resume_best_loss = float('inf')
resume_best_rollout_reward = float('-inf')

# =========================================================================
# 5. Training loop
# =========================================================================
log.info("🏋️ 开始训练循环...")

def build_agent_input(batch_data):
    """Build the agent input format."""
    images = {}
    # SimpleRobotDataset returns image keys in observation.{cam_name} format
    for cam_name in cfg.data.camera_names:
        key = f"observation.{cam_name}"
        if key in batch_data:
            images[cam_name] = batch_data[key]

    return {
        'images': images,
        'qpos': batch_data['observation.state'],  # SimpleRobotDataset uses observation.state
        'action': batch_data['action'],
        'action_is_pad': batch_data.get('action_is_pad', None)  # pass through the padding mask
    }

def save_checkpoint(checkpoint_path: Path, step: int, loss_value, val_loss=None, rollout_avg_reward=None):
    agent_stats = agent.get_normalization_stats()
    torch.save({
        'step': step,
        'model_state_dict': agent.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'loss': loss_value,
        'val_loss': val_loss,
        'rollout_avg_reward': rollout_avg_reward,
        'dataset_stats': agent_stats,  # save the agent's normalization statistics
        'current_lr': optimizer.param_groups[0]['lr'],
    }, checkpoint_path)
    return checkpoint_path

def run_validation():
    """Run validation."""
@@ -391,10 +834,64 @@ def main(cfg: DictConfig):
    agent.train()
    return total_loss / max(num_batches, 1)

def run_rollout_validation(checkpoint_path: Path):
    from roboimi.demos.vla_scripts import eval_vla

    rollout_cfg = OmegaConf.create(OmegaConf.to_container(cfg, resolve=False))
    rollout_num_episodes = int(cfg.train.get('rollout_num_episodes', 1))
    rollout_device = str(cfg.train.get('rollout_device', cfg.train.device))
    configured_rollout_workers = cfg.train.get('rollout_num_workers', None)
    if configured_rollout_workers is None:
        if rollout_device.startswith('cuda'):
            rollout_num_workers = min(max(rollout_num_episodes, 1), 8)
        else:
            rollout_num_workers = 1
    else:
        rollout_num_workers = int(configured_rollout_workers)
    rollout_cfg.eval.ckpt_path = str(checkpoint_path)
    rollout_cfg.eval.num_episodes = rollout_num_episodes
    rollout_cfg.eval.num_workers = rollout_num_workers
    rollout_cfg.eval.headless = True
    rollout_cfg.eval.device = rollout_device
    rollout_cfg.eval.cuda_devices = cfg.train.get('rollout_cuda_devices', None)
    rollout_cfg.eval.response_timeout_s = float(
        cfg.train.get('rollout_response_timeout_s', 300.0)
    )
    rollout_cfg.eval.server_startup_timeout_s = float(
        cfg.train.get('rollout_server_startup_timeout_s', 300.0)
    )
    rollout_cfg.eval.verbose_action = False
    rollout_cfg.eval.record_video = False
    rollout_cfg.eval.save_trajectory_image = True
    rollout_cfg.eval.trajectory_image_camera_name = 'front'
    rollout_cfg.eval.save_summary_json = True
    rollout_cfg.eval.artifact_dir = str(
        (run_output_dir / 'rollout_artifacts' / checkpoint_path.stem).resolve()
    )

    log.info(
        "🎯 开始 checkpoint rollout 验证: %s (episodes=%s, device=%s, workers=%s, headless=True)",
        checkpoint_path,
        rollout_cfg.eval.num_episodes,
        rollout_cfg.eval.device,
        rollout_cfg.eval.num_workers,
    )
    return eval_vla._run_eval(rollout_cfg)

def run_checkpoint_rollout_validation(checkpoint_path: Path):
    if not bool(cfg.train.get('rollout_validate_on_checkpoint', False)):
        return None
    return run_rollout_validation(checkpoint_path)
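When `rollout_num_workers` is not configured, `run_rollout_validation` above picks a default from the device and episode count. That heuristic isolated as a sketch (the helper name is mine):

```python
# On CUDA: one worker per episode, at least 1, capped at 8.
# On CPU: a single worker.
def default_rollout_workers(num_episodes: int, device: str) -> int:
    if device.startswith('cuda'):
        return min(max(num_episodes, 1), 8)
    return 1

print(default_rollout_workers(10, 'cuda:0'))  # capped at 8
print(default_rollout_workers(3, 'cpu'))      # CPU stays single-worker
```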
data_iter = iter(train_loader)
pbar = tqdm(range(start_step, cfg.train.max_steps), desc="训练中", ncols=100)

steps_per_epoch = len(train_loader)
action_mse_val_freq_epochs = int(cfg.train.get('action_mse_val_freq_epochs', 0) or 0)
rollout_val_freq_epochs = int(cfg.train.get('rollout_val_freq_epochs', 0) or 0)
rollout_validation_enabled = rollout_val_freq_epochs > 0
best_loss = resume_best_loss
best_rollout_reward = resume_best_rollout_reward
last_loss = resume_loss

if start_step >= cfg.train.max_steps:
@@ -452,80 +949,230 @@ def main(cfg: DictConfig):
# =====================================================================
if step % cfg.train.log_freq == 0:
    current_lr = optimizer.param_groups[0]['lr']
    best_loss_to_log = best_loss if best_loss != float('inf') else loss.item()
    pbar.set_postfix({
        "loss": f"{loss.item():.4f}",
        "lr": f"{current_lr:.2e}",
        "best_loss": f"{best_loss:.4f}"
        "best_loss": f"{best_loss_to_log:.4f}"
    })
    log.info(f"步骤 {step}/{cfg.train.max_steps} | 损失: {loss.item():.4f} | 学习率: {current_lr:.2e}")
    _log_to_swanlab(
        swanlab_module,
        {
            'train/loss': loss.item(),
            'train/lr': current_lr,
            'train/best_loss': best_loss_to_log,
            'train/step': step,
        },
        step=step,
    )
    if hasattr(agent, 'get_last_loss_breakdown'):
        loss_breakdown = agent.get_last_loss_breakdown()
        extra_train_metrics = {
            f"train/{key}": value
            for key, value in loss_breakdown.items()
            if value is not None and key != 'loss'
        }
        if extra_train_metrics:
            _log_to_swanlab(swanlab_module, extra_train_metrics, step=step)

# =====================================================================
# Checkpoint saving and validation
# =====================================================================
checkpoint_path = None
val_loss = None
if step > 0 and step % cfg.train.save_freq == 0:
    # Run validation
    val_loss = run_validation()
    if val_loss is not None:
        log.info(f"步骤 {step}/{cfg.train.max_steps} | 验证损失: {val_loss:.4f}")
        _log_to_swanlab(
            swanlab_module,
            {'val/loss': val_loss},
            step=step,
        )

    checkpoint_path = checkpoint_dir / f"vla_model_step_{step}.pt"
    # Use the agent's normalization stats (includes normalization_type)
    agent_stats = agent.get_normalization_stats()
    torch.save({
        'step': step,
        'model_state_dict': agent.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'loss': loss.item(),
        'val_loss': val_loss,
        'dataset_stats': agent_stats,  # save the agent's normalization statistics
        'current_lr': optimizer.param_groups[0]['lr'],
    }, checkpoint_path)
    save_checkpoint(
        checkpoint_path,
        step,
        loss.item(),
        val_loss=val_loss,
    )
    log.info(f"💾 检查点已保存: {checkpoint_path}")

    # Save the best model based on validation loss
    # Until the first rollout average reward arrives, fall back to loss as the best-model metric
    if best_rollout_reward == float('-inf'):
        eval_loss = val_loss if val_loss is not None else loss.item()
        if eval_loss < best_loss:
            best_loss = eval_loss
            best_model_path = checkpoint_dir / "vla_model_best.pt"
            agent_stats = agent.get_normalization_stats()
            torch.save({
                'step': step,
                'model_state_dict': agent.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'loss': loss.item(),
                'val_loss': val_loss,
                'dataset_stats': agent_stats,  # save the agent's normalization statistics
                'current_lr': optimizer.param_groups[0]['lr'],
            }, best_model_path)
            best_model_path = default_best_model_path
            save_checkpoint(
                best_model_path,
                step,
                loss.item(),
                val_loss=val_loss,
            )
            log.info(f"🌟 最佳模型已更新: {best_model_path} (验证损失: {best_loss:.4f})")

    checkpoint_rollout_stats = run_checkpoint_rollout_validation(checkpoint_path)
    checkpoint_rollout_avg_reward = (
        checkpoint_rollout_stats.get('avg_reward')
        if checkpoint_rollout_stats is not None else None
    )
    if checkpoint_rollout_avg_reward is not None:
        log.info(
            f"步骤 {step}/{cfg.train.max_steps} | checkpoint rollout 平均奖励: "
            f"{checkpoint_rollout_avg_reward:.4f}"
        )
        _log_to_swanlab(
            swanlab_module,
            {'rollout/avg_reward': checkpoint_rollout_avg_reward},
            step=step,
        )
        if checkpoint_rollout_avg_reward > best_rollout_reward:
            best_rollout_reward = checkpoint_rollout_avg_reward
            best_model_path = default_best_model_path
            save_checkpoint(
                best_model_path,
                step,
                loss.item(),
                val_loss=val_loss,
                rollout_avg_reward=checkpoint_rollout_avg_reward,
            )
            log.info(
                f"🌟 最佳模型已更新: {best_model_path} "
                f"(checkpoint rollout 平均奖励: {best_rollout_reward:.4f})"
            )

completed_steps = step + 1
completed_epoch = (
    completed_steps // steps_per_epoch
    if steps_per_epoch > 0 else 0
)
should_run_epoch_rollout = (
    rollout_validation_enabled
    and steps_per_epoch > 0
    and completed_steps % steps_per_epoch == 0
    and completed_epoch > 0
    and completed_epoch % rollout_val_freq_epochs == 0
)
should_run_action_mse_validation = (
    action_mse_val_freq_epochs > 0
    and val_loader is not None
    and steps_per_epoch > 0
    and completed_steps % steps_per_epoch == 0
    and completed_epoch > 0
    and completed_epoch % action_mse_val_freq_epochs == 0
)
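The two epoch-boundary predicates above reduce to the same modular arithmetic: fire only on the step that completes an epoch, and only every N epochs. A standalone sketch (names are mine, not from the diff):

```python
# Fires on the last step of every `freq_epochs`-th epoch, mirroring the
# should_run_epoch_rollout / should_run_action_mse_validation conditions.
def should_run_epoch_validation(step: int, steps_per_epoch: int, freq_epochs: int) -> bool:
    if steps_per_epoch <= 0 or freq_epochs <= 0:
        return False
    completed_steps = step + 1
    completed_epoch = completed_steps // steps_per_epoch
    return (
        completed_steps % steps_per_epoch == 0
        and completed_epoch > 0
        and completed_epoch % freq_epochs == 0
    )

# With 100 steps per epoch and freq=5, validation fires at the end of
# epochs 5, 10, 15, ... i.e. on steps 499, 999, 1499, ...
print([s for s in range(1500) if should_run_epoch_validation(s, 100, 5)])
```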
if should_run_action_mse_validation:
    action_mse = compute_action_mse_validation(
        agent,
        val_loader,
        cfg.train.device,
    )
    if action_mse is not None:
        log.info(
            f"步骤 {step}/{cfg.train.max_steps} | Epoch {completed_epoch} "
            f"held-out action MSE: {action_mse:.6f}"
        )
        _log_to_swanlab(
            swanlab_module,
            {
                'val/action_mse': action_mse,
                'val/action_mse_epoch': completed_epoch,
            },
            step=step,
        )
if should_run_epoch_rollout:
    if checkpoint_path is None:
        checkpoint_path = checkpoint_dir / f"vla_model_step_{step}.pt"
        save_checkpoint(
            checkpoint_path,
            step,
            loss.item(),
            val_loss=val_loss,
        )
        log.info(f"💾 Epoch rollout 验证前检查点已保存: {checkpoint_path}")

    rollout_stats = run_rollout_validation(checkpoint_path)
    rollout_avg_reward = (
        rollout_stats.get('avg_reward')
        if rollout_stats is not None else None
    )
    if rollout_avg_reward is not None:
        log.info(
            f"步骤 {step}/{cfg.train.max_steps} | Epoch {completed_epoch} "
            f"rollout 平均奖励: {rollout_avg_reward:.4f}"
        )
        _log_to_swanlab(
            swanlab_module,
            {
                'rollout/avg_reward': rollout_avg_reward,
                'rollout/epoch': completed_epoch,
            },
            step=step,
        )
        _log_rollout_trajectory_images_to_swanlab(
            swanlab_module,
            rollout_stats,
            step=step,
            context_label=f'epoch {completed_epoch} rollout',
        )
        if rollout_avg_reward > best_rollout_reward:
            best_rollout_reward = rollout_avg_reward
            best_model_path = default_best_model_path
            save_checkpoint(
                best_model_path,
                step,
                loss.item(),
                val_loss=val_loss,
                rollout_avg_reward=rollout_avg_reward,
            )
            log.info(
                f"🌟 最佳模型已更新: {best_model_path} "
                f"(Epoch {completed_epoch} rollout 平均奖励: {best_rollout_reward:.4f})"
            )

# =========================================================================
# 6. Save the final model
# =========================================================================
final_model_path = checkpoint_dir / "vla_model_final.pt"
agent_stats = agent.get_normalization_stats()
torch.save({
    'step': cfg.train.max_steps,
    'model_state_dict': agent.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'scheduler_state_dict': scheduler.state_dict(),
    'loss': last_loss,
    'dataset_stats': agent_stats,  # save the agent's normalization statistics
    'current_lr': optimizer.param_groups[0]['lr'],
}, final_model_path)
save_checkpoint(
    final_model_path,
    cfg.train.max_steps,
    last_loss,
)
log.info(f"💾 最终模型已保存: {final_model_path}")
_log_to_swanlab(
    swanlab_module,
    {
        'final/checkpoint_path': str(final_model_path),
        'final/best_checkpoint_path': (
            str(best_model_path) if best_model_path is not None else ''
        ),
    },
    step=cfg.train.max_steps,
)

log.info("✅ 训练成功完成!")
if last_loss is not None:
    log.info(f"📊 最终损失: {last_loss:.4f}")
else:
    log.info("📊 最终损失: N/A(未执行训练步)")
if best_loss != float('inf'):
if best_rollout_reward != float('-inf'):
    log.info(f"📊 最佳 rollout 平均奖励: {best_rollout_reward:.4f}")
elif best_loss != float('inf'):
    log.info(f"📊 最佳损失: {best_loss:.4f}")
else:
    log.info("📊 最佳损失: N/A(无有效验证/训练损失)")
    log.info("📊 最佳验证指标: N/A(无有效 rollout/验证损失)")
finally:
    _finish_swanlab(swanlab_module)


@hydra.main(version_base=None, config_path="../../vla/conf", config_name="config")
def main(cfg: DictConfig):
    _run_training(cfg)


if __name__ == "__main__":

@@ -57,6 +57,7 @@ class DualDianaMed(MujocoEnv):
self.obs = None

self.rew = None
self._offscreen_renderer = None


def actuate_J(self, q_target, qdot_target, Arm):
@@ -161,6 +162,8 @@ class DualDianaMed(MujocoEnv):


def _get_obs(self):
    if not self.is_render:
        self._update_camera_images_sync()
    obs = collections.OrderedDict()
    obs['qpos'] = self.get_obs_qpos
    obs['action'] = self.compute_qpos
@@ -173,6 +176,8 @@ class DualDianaMed(MujocoEnv):
    return obs

def _get_image_obs(self):
    if not self.is_render:
        self._update_camera_images_sync()
    obs = collections.OrderedDict()
    obs['images'] = dict()
    obs['images']['top'] = self.top
@@ -211,31 +216,46 @@ class DualDianaMed(MujocoEnv):
    raise AttributeError("please input right name")


def _get_or_create_offscreen_renderer(self):
    renderer = getattr(self, '_offscreen_renderer', None)
    if renderer is None:
        renderer = mj.Renderer(self.mj_model, height=480, width=640)
        self._offscreen_renderer = renderer
    return renderer

def _render_camera_set(self, img_renderer):
    img_renderer.update_scene(self.mj_data, camera="rs_cam_right")
    self.r_vis = img_renderer.render()[:, :, ::-1]
    img_renderer.update_scene(self.mj_data, camera="rs_cam_left")
    self.l_vis = img_renderer.render()[:, :, ::-1]
    img_renderer.update_scene(self.mj_data, camera="top")
    self.top = img_renderer.render()[:, :, ::-1]
    img_renderer.update_scene(self.mj_data, camera="angle")
    self.angle = img_renderer.render()[:, :, ::-1]
    img_renderer.update_scene(self.mj_data, camera="front")
    self.front = img_renderer.render()[:, :, ::-1]

def _update_camera_images_sync(self):
    img_renderer = self._get_or_create_offscreen_renderer()
    self._render_camera_set(img_renderer)

def camera_viewer(self):
    img_renderer = mj.Renderer(self.mj_model,height=480,width=640)
    img_renderer = self._get_or_create_offscreen_renderer()
    show_gui = self.is_render
    if show_gui:
        cv2.namedWindow('Cam view',cv2.WINDOW_NORMAL)
    while not self.exit_flag:
        img_renderer.update_scene(self.mj_data,camera="rs_cam_right")
        self.r_vis = img_renderer.render()
        self.r_vis = self.r_vis[:, :, ::-1]
        img_renderer.update_scene(self.mj_data,camera="rs_cam_left")
        self.l_vis = img_renderer.render()
        self.l_vis = self.l_vis[:, :, ::-1]
        img_renderer.update_scene(self.mj_data,camera="top")
        self.top = img_renderer.render()
        self.top = self.top[:, :, ::-1]
        img_renderer.update_scene(self.mj_data,camera="angle")
        self.angle = img_renderer.render()
        self.angle = self.angle[:, :, ::-1]
        img_renderer.update_scene(self.mj_data,camera="front")
        self.front = img_renderer.render()
        self.front = self.front[:, :, ::-1]
        self._render_camera_set(img_renderer)
        if show_gui:
            if self.cam_view is not None:
                cv2.imshow('Cam view', self.cam_view)
            cv2.waitKey(1)


def cam_start(self):
    if not self.is_render:
        self.cam_thread = None
        return
    self.cam_thread = threading.Thread(target=self.camera_viewer,daemon=True)
    self.cam_thread.start()


@@ -76,6 +76,9 @@ class DualDianaMed_Pos_Ctrl(DualDianaMed):
self.angle = None
self.r_vis = None
self.front = None
if not self.is_render:
    self._update_camera_images_sync()
    return
self.cam_flage = True
t=0
while self.cam_flage:
@@ -133,12 +136,12 @@ class DualDianaMed_Pos_Ctrl(DualDianaMed):
    return reward


def make_sim_env(task_name):
def make_sim_env(task_name, headless=False):
    if 'sim_transfer' in task_name:
        from roboimi.assets.robots.diana_med import BiDianaMed
        env = DualDianaMed_Pos_Ctrl(
            robot=BiDianaMed(),
            is_render=True,
            is_render=not headless,
            control_freq=30,
            is_interpolate=True,
            cam_view='angle'
267
roboimi/scripts/refresh_experiment_suite_status.py
Executable file
@@ -0,0 +1,267 @@
#!/usr/bin/env python3

from __future__ import annotations

import argparse
import datetime as dt
import json
import pathlib
import re
import shlex
import subprocess
from collections import defaultdict
from typing import Any


STEP_PAT = re.compile(r"步骤\s+(\d+)/(\d+)")
BAR_PAT = re.compile(r"\|\s*(\d+)/(\d+)")


def normalize_chunks(text: str):
    for part in re.split(r"[\r\n]+", text):
        part = part.strip()
        if part:
            yield part


def parse_latest_line(text: str) -> tuple[str, int | None]:
    latest_line = ""
    latest_step = None
    for line in normalize_chunks(text):
        if "步骤" not in line and "训练中:" not in line:
            continue
        latest_line = line
        match = STEP_PAT.search(line) or BAR_PAT.search(line)
        if match:
            latest_step = int(match.group(1))
    return latest_line, latest_step
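The parser above scans tqdm/log output for the most recent step number, matching either the `步骤 N/M` log line or the `| N/M` progress bar. A self-contained check of that behavior, re-implemented with the same regexes (the sample text imitates the train_vla.py log format from this diff):

```python
import re

# Same patterns as STEP_PAT / BAR_PAT in the script above.
step_pat = re.compile(r"步骤\s+(\d+)/(\d+)")
bar_pat = re.compile(r"\|\s*(\d+)/(\d+)")

def latest_step_from(text: str):
    latest = None
    # tqdm redraws use \r, so split on both \r and \n like normalize_chunks does.
    for line in (p.strip() for p in re.split(r"[\r\n]+", text) if p.strip()):
        if "步骤" not in line and "训练中:" not in line:
            continue
        m = step_pat.search(line) or bar_pat.search(line)
        if m:
            latest = int(m.group(1))
    return latest

sample = "步骤 1200/109350 | 损失: 0.0421 | 学习率: 9.87e-05\r训练中: | 1250/109350"
print(latest_step_from(sample))  # the later progress-bar line wins: 1250
```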
def now_iso() -> str:
    return dt.datetime.now(
        dt.timezone(dt.timedelta(hours=8)),
    ).isoformat(timespec="seconds")


def run_cmd(cmd: list[str], check: bool = True) -> subprocess.CompletedProcess[str]:
    return subprocess.run(cmd, capture_output=True, text=True, check=check)


def probe_local(run: dict[str, Any]) -> dict[str, Any]:
    pid = str(run["pid"])
    ps = run_cmd(["ps", "-p", pid, "-o", "pid=,stat=,etime=,args="], check=False)
    log_path = pathlib.Path(run["log_path"])
    latest_line = ""
    latest_step = None
    if log_path.exists():
        latest_line, latest_step = parse_latest_line(log_path.read_text(errors="replace"))
    return {
        "alive": bool(ps.stdout.strip()),
        "ps": ps.stdout.strip(),
        "log_exists": log_path.exists(),
        "latest_line": latest_line,
        "latest_step": latest_step,
    }


def remote_probe(host: str, remote_user: str, runs: list[dict[str, Any]]) -> dict[str, dict[str, Any]]:
    payload = [
        {
            "run_id": run["run_id"],
            "pid": str(run["pid"]),
            "log_path": run["log_path"],
        }
        for run in runs
    ]
    remote_py = r"""
import json
import pathlib
import re
import subprocess
import sys

payload = json.loads(sys.argv[1])
step_pat = re.compile(r"步骤\s+(\d+)/(\d+)")
bar_pat = re.compile(r"\|\s*(\d+)/(\d+)")

def normalize_chunks(text):
    for part in re.split(r"[\r\n]+", text):
        part = part.strip()
        if part:
            yield part

def parse_latest_line(text):
    latest_line = ""
    latest_step = None
    for line in normalize_chunks(text):
        if "步骤" not in line and "训练中:" not in line:
            continue
        latest_line = line
        match = step_pat.search(line) or bar_pat.search(line)
        if match:
            latest_step = int(match.group(1))
    return latest_line, latest_step

out = {}
for item in payload:
    try:
        ps = subprocess.run(
            ["ps", "-p", item["pid"], "-o", "pid=,stat=,etime=,args="],
            capture_output=True,
            text=True,
            check=False,
        )
        log_path = pathlib.Path(item["log_path"])
        latest_line = ""
        latest_step = None
        if log_path.exists():
            latest_line, latest_step = parse_latest_line(log_path.read_text(errors="replace"))
        out[item["run_id"]] = {
            "alive": bool(ps.stdout.strip()),
            "ps": ps.stdout.strip(),
            "log_exists": log_path.exists(),
            "latest_line": latest_line,
            "latest_step": latest_step,
        }
    except Exception as exc:
        out[item["run_id"]] = {
            "alive": False,
            "ps": "",
            "log_exists": False,
            "latest_line": "",
            "latest_step": None,
            "error": str(exc),
        }
print(json.dumps(out, ensure_ascii=False))
"""
    remote_target = host if "@" in host else f"{remote_user}@{host}"
    remote_cmd = (
        f"python3 -c {shlex.quote(remote_py)} "
        f"{shlex.quote(json.dumps(payload, ensure_ascii=False))}"
    )
    try:
        res = run_cmd(
            [
                "ssh",
                "-F",
                "/dev/null",
                "-o",
                "BatchMode=yes",
                "-o",
                "StrictHostKeyChecking=accept-new",
                remote_target,
                remote_cmd,
            ]
        )
        return json.loads(res.stdout)
    except subprocess.CalledProcessError as exc:
        error = (exc.stderr or exc.stdout or str(exc)).strip()
        return {
            run["run_id"]: {
                "alive": False,
                "ps": "",
                "log_exists": False,
                "latest_line": "",
                "latest_step": None,
                "error": f"ssh_failed: {error}",
            }
            for run in runs
        }


def append_notes(notes_path: pathlib.Path, snapshot_at: str, runs: list[dict[str, Any]]) -> None:
    lines = [f"\n## Status snapshot {snapshot_at}"]
    for run in runs:
        lines.append(
            (
                f"- {run['run_id']}: host={run['host']} gpu={run['gpu']} "
                f"alive={run.get('alive', False)} step={run.get('latest_step')} "
                f"pid={run['pid']}"
            )
        )
        if run.get("latest_line"):
            lines.append(f"  - latest_line: `{run['latest_line']}`")
        if run.get("error"):
            lines.append(f"  - error: `{run['error']}`")
    with notes_path.open("a", encoding="utf-8") as f:
        f.write("\n".join(lines) + "\n")


def main() -> int:
    parser = argparse.ArgumentParser()
    parser.add_argument("suite_dir", type=pathlib.Path)
    parser.add_argument("--remote-user", default="droid")
    parser.add_argument("--append-notes", action="store_true")
    args = parser.parse_args()

    suite_dir = args.suite_dir.resolve()
    status_path = suite_dir / "status.json"
    notes_path = suite_dir / "notes.md"
    monitor_dir = suite_dir / "monitor_logs"
    monitor_dir.mkdir(parents=True, exist_ok=True)

    status = json.loads(status_path.read_text(encoding="utf-8"))
    runs: list[dict[str, Any]] = status["runs"]
    snapshot_at = now_iso()

    by_host: dict[str, list[dict[str, Any]]] = defaultdict(list)
    for run in runs:
        by_host[run["host"]].append(run)

    results: dict[str, dict[str, Any]] = {}
    for host, host_runs in by_host.items():
        if host == "local":
            for run in host_runs:
                results[run["run_id"]] = probe_local(run)
        else:
            results.update(remote_probe(host, args.remote_user, host_runs))

    alive_count = 0
    for run in runs:
        result = results[run["run_id"]]
        run["alive"] = result["alive"]
        run["ps"] = result["ps"]
        run["log_exists"] = result["log_exists"]
        run["latest_line"] = result["latest_line"]
        run["latest_step"] = result["latest_step"]
        run["last_verified_at"] = snapshot_at
        if "error" in result:
            run["error"] = result["error"]
        else:
            run.pop("error", None)
        run["status"] = "running" if result["alive"] else "stopped"
        alive_count += int(result["alive"])

    status["last_verified_at"] = snapshot_at
    status["alive_count"] = alive_count
    status["total_runs"] = len(runs)

    status_path.write_text(json.dumps(status, ensure_ascii=False, indent=2) + "\n", encoding="utf-8")

    snapshot_payload = {
        "suite_name": status.get("suite_name"),
        "snapshot_at": snapshot_at,
        "alive_count": alive_count,
        "total_runs": len(runs),
        "runs": {run["run_id"]: results[run["run_id"]] for run in runs},
    }
    timestamp_slug = snapshot_at.replace(":", "").replace("+", "_").replace("-", "")
    snapshot_path = monitor_dir / f"status-{timestamp_slug}.json"
    snapshot_path.write_text(
        json.dumps(snapshot_payload, ensure_ascii=False, indent=2) + "\n",
        encoding="utf-8",
    )

    if args.append_notes:
        append_notes(notes_path, snapshot_at, runs)

    print(json.dumps(snapshot_payload, ensure_ascii=False, indent=2))
    print(f"\nstatus_json={status_path}")
    print(f"snapshot_json={snapshot_path}")
    if args.append_notes:
        print(f"notes_md={notes_path}")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
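The snapshot filename is derived from the ISO-8601 timestamp by plain character replacement, as in `main()` above. For a UTC+8 timestamp like the one this suite uses:

```python
# Slug derivation copied from the script: strip ':' and '-', map '+' to '_'.
snapshot_at = "2026-04-21T15:30:37+08:00"
timestamp_slug = snapshot_at.replace(":", "").replace("+", "_").replace("-", "")
print(f"status-{timestamp_slug}.json")  # status-20260421T153037_0800.json
```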
176
roboimi/utils/raw_action_trajectory_viewer.py
Normal file
@@ -0,0 +1,176 @@
from __future__ import annotations

import math
import time
from pathlib import Path
from typing import Iterable

import cv2
import mujoco
import numpy as np

from roboimi.assets.robots.diana_med import BiDianaMed
from roboimi.envs.mujoco_base import MujocoEnv
from roboimi.envs.double_pos_ctrl_env import make_sim_env
from roboimi.utils.act_ex_utils import sample_transfer_pose


def _load_raw_action_array(path: str | Path) -> np.ndarray:
    path = Path(path)
    if path.suffix == ".npy":
        raw_action = np.load(path)
    elif path.suffix == ".npz":
        archive = np.load(path)
        if "raw_action" in archive:
            raw_action = archive["raw_action"]
        elif "raw_predicted_ee_action" in archive:
            raw_action = archive["raw_predicted_ee_action"]
        else:
            raise KeyError(f"{path} does not contain raw_action")
    else:
        raise ValueError(f"unsupported trajectory file: {path}")
    raw_action = np.asarray(raw_action, dtype=np.float32)
    if raw_action.ndim != 2 or raw_action.shape[1] < 10:
        raise ValueError(f"raw_action must have shape (T, 16)-like, got {raw_action.shape}")
    return raw_action


def disable_cv2_highgui(cv2_module=cv2):
    original = {
        "namedWindow": cv2_module.namedWindow,
        "imshow": cv2_module.imshow,
        "waitKey": cv2_module.waitKey,
    }

    cv2_module.namedWindow = lambda *args, **kwargs: None
    cv2_module.imshow = lambda *args, **kwargs: None
    cv2_module.waitKey = lambda *args, **kwargs: 1

    def restore():
        cv2_module.namedWindow = original["namedWindow"]
        cv2_module.imshow = original["imshow"]
        cv2_module.waitKey = original["waitKey"]

    return restore
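`disable_cv2_highgui` above follows a patch-and-restore pattern: swap GUI entry points for stubs and hand back a closure that undoes the swap. A condensed, standalone rendition of the same pattern, exercised against a stand-in module so it runs without OpenCV installed (`disable_gui` and `fake_cv2` are my illustrative names, not from the diff):

```python
from types import SimpleNamespace

def disable_gui(mod):
    # Remember the original callable, install a stub, return an undo closure.
    original = {"waitKey": mod.waitKey}
    mod.waitKey = lambda *a, **k: 1
    def restore():
        mod.waitKey = original["waitKey"]
    return restore

fake_cv2 = SimpleNamespace(waitKey=lambda *a, **k: 0)
restore = disable_gui(fake_cv2)
print(fake_cv2.waitKey())  # stub is active: returns 1
restore()
print(fake_cv2.waitKey())  # original is back: returns 0
```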
def set_transfer_box_pose(mj_data, box_pos: np.ndarray) -> None:
    box_pos = np.asarray(box_pos, dtype=np.float64)
    if box_pos.shape != (3,):
        raise ValueError(f"box_pos must have shape (3,), got {box_pos.shape}")
    joint = mj_data.joint("red_box_joint")
    joint.qpos[0] = box_pos[0]
    joint.qpos[1] = box_pos[1]
    joint.qpos[2] = box_pos[2]
    joint.qpos[3] = 1.0
    joint.qpos[4] = 0.0
    joint.qpos[5] = 0.0
    joint.qpos[6] = 0.0


def load_raw_action_positions(path: str | Path) -> dict[str, np.ndarray]:
    raw_action = _load_raw_action_array(path)
    return {
        "left": raw_action[:, :3].astype(np.float32, copy=True),
        "right": raw_action[:, 7:10].astype(np.float32, copy=True),
    }


def _downsample_points(points: np.ndarray, stride: int) -> np.ndarray:
    sampled = points[::stride]
    if len(sampled) == 0:
        return points
    if not np.array_equal(sampled[-1], points[-1]):
        sampled = np.concatenate([sampled, points[-1:]], axis=0)
    return sampled
|
||||
|
||||
|
||||
def build_trajectory_capsule_markers(
|
||||
positions: dict[str, np.ndarray],
|
||||
*,
|
||||
max_markers: int,
|
||||
radius: float = 0.003,
|
||||
rgba: tuple[float, float, float, float] = (1.0, 0.0, 0.0, 1.0),
|
||||
) -> list[dict]:
|
||||
total_segments = sum(max(len(points) - 1, 0) for points in positions.values())
|
||||
if total_segments == 0:
|
||||
return []
|
||||
stride = max(1, math.ceil(total_segments / max_markers))
|
||||
markers = []
|
||||
for points in positions.values():
|
||||
sampled = _downsample_points(np.asarray(points, dtype=np.float64), stride)
|
||||
for idx in range(len(sampled) - 1):
|
||||
markers.append(
|
||||
{
|
||||
"from": sampled[idx],
|
||||
"to": sampled[idx + 1],
|
||||
"rgba": rgba,
|
||||
"radius": float(radius),
|
||||
}
|
||||
)
|
||||
return markers[:max_markers]
|
||||
|
||||
|
||||
def apply_capsule_markers_to_scene(user_scn, markers: Iterable[dict]) -> None:
|
||||
user_scn.ngeom = 0
|
||||
for marker in markers:
|
||||
if user_scn.ngeom >= user_scn.maxgeom:
|
||||
break
|
||||
geom = user_scn.geoms[user_scn.ngeom]
|
||||
mujoco.mjv_initGeom(
|
||||
geom,
|
||||
mujoco.mjtGeom.mjGEOM_CAPSULE,
|
||||
np.zeros(3, dtype=np.float64),
|
||||
np.zeros(3, dtype=np.float64),
|
||||
np.eye(3, dtype=np.float64).reshape(-1),
|
||||
np.asarray(marker["rgba"], dtype=np.float32),
|
||||
)
|
||||
mujoco.mjv_connector(
|
||||
geom,
|
||||
mujoco.mjtGeom.mjGEOM_CAPSULE,
|
||||
float(marker["radius"]),
|
||||
np.asarray(marker["from"], dtype=np.float64),
|
||||
np.asarray(marker["to"], dtype=np.float64),
|
||||
)
|
||||
user_scn.ngeom += 1
|
||||
|
||||
|
||||
def launch_raw_action_trajectory_viewer(
|
||||
trajectory_path: str | Path,
|
||||
*,
|
||||
task_name: str = "sim_transfer",
|
||||
line_radius: float = 0.004,
|
||||
max_markers: int = 1500,
|
||||
box_pos: np.ndarray | None = None,
|
||||
disable_camera_window: bool = True,
|
||||
):
|
||||
positions = load_raw_action_positions(trajectory_path)
|
||||
if task_name != "sim_transfer":
|
||||
raise NotImplementedError(f"unsupported task_name: {task_name}")
|
||||
if box_pos is None:
|
||||
box_pos = sample_transfer_pose()
|
||||
|
||||
robot = BiDianaMed()
|
||||
viewer_env = MujocoEnv(robot=robot, is_render=True, renderer="viewer", control_freq=30)
|
||||
viewer_env.reset()
|
||||
set_transfer_box_pose(viewer_env.mj_data, box_pos)
|
||||
mujoco.mj_forward(viewer_env.mj_model, viewer_env.mj_data)
|
||||
markers = build_trajectory_capsule_markers(
|
||||
positions,
|
||||
max_markers=max_markers,
|
||||
radius=line_radius,
|
||||
)
|
||||
|
||||
if viewer_env.viewer is None or getattr(viewer_env.viewer, "user_scn", None) is None:
|
||||
raise RuntimeError("viewer does not expose user_scn; cannot render trajectory overlay")
|
||||
|
||||
try:
|
||||
while viewer_env.viewer.is_running() and not viewer_env.exit_flag:
|
||||
with viewer_env.viewer.lock():
|
||||
apply_capsule_markers_to_scene(viewer_env.viewer.user_scn, markers)
|
||||
viewer_env.render()
|
||||
time.sleep(1 / 60.0)
|
||||
finally:
|
||||
viewer_env.exit_flag = True
|
||||
if getattr(viewer_env, "viewer", None) is not None:
|
||||
viewer_env.viewer.close()
|
||||
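The marker budget above works by picking a stride of `ceil(total_segments / max_markers)` and re-appending the final point whenever the stride skips it, so the rendered trajectory always reaches its endpoint. A standalone sketch of that downsampling step (NumPy only; re-implemented here for illustration, not imported from the repo):

```python
import math
import numpy as np


def downsample_points(points: np.ndarray, stride: int) -> np.ndarray:
    """Keep every `stride`-th point, always preserving the final point."""
    sampled = points[::stride]
    if len(sampled) == 0:
        return points
    if not np.array_equal(sampled[-1], points[-1]):
        # The stride skipped the endpoint; append it so the polyline stays complete.
        sampled = np.concatenate([sampled, points[-1:]], axis=0)
    return sampled


# A 10-point straight line budgeted to at most 4 segments:
points = np.stack([np.linspace(0.0, 1.0, 10)] * 3, axis=1)  # shape (10, 3)
stride = max(1, math.ceil((len(points) - 1) / 4))            # 9 segments / 4 -> stride 3
sampled = downsample_points(points, stride)
```

With stride 3 the slice keeps indices 0, 3, 6, 9, which already ends on the final point, giving 3 capsule segments within the budget of 4.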
roboimi/utils/streaming_episode_writer.py (new file, 113 lines)
@@ -0,0 +1,113 @@
from __future__ import annotations

import os
from pathlib import Path

import cv2
import h5py
import numpy as np


class StreamingEpisodeWriter:
    """Write episode data frame by frame; commit on success, drop the temp file on failure."""

    def __init__(
        self,
        dataset_path: str | os.PathLike[str],
        max_timesteps: int,
        camera_names: list[str],
        image_size: tuple[int, int] = (256, 256),
    ) -> None:
        self.dataset_path = Path(dataset_path)
        self.tmp_path = Path(f"{self.dataset_path}.tmp")
        self.max_timesteps = int(max_timesteps)
        self.camera_names = list(camera_names)
        self.image_height = int(image_size[0])
        self.image_width = int(image_size[1])
        self.frame_index = 0
        self._committed = False
        self._closed = False

        self.dataset_path.parent.mkdir(parents=True, exist_ok=True)
        if self.tmp_path.exists():
            self.tmp_path.unlink()

        self._file = h5py.File(self.tmp_path, "w", rdcc_nbytes=1024**2 * 2)
        self._file.attrs["sim"] = True
        self._file.attrs["action_repr"] = "ee_pose_xyz_quat_gripper"
        self._file.attrs["image_height"] = self.image_height
        self._file.attrs["image_width"] = self.image_width
        self._file.attrs["camera_names"] = np.asarray(self.camera_names, dtype="S")

        observations = self._file.create_group("observations")
        images = observations.create_group("images")
        for cam_name in self.camera_names:
            images.create_dataset(
                cam_name,
                (self.max_timesteps, self.image_height, self.image_width, 3),
                dtype="uint8",
                chunks=(1, self.image_height, self.image_width, 3),
            )
        observations.create_dataset(
            "qpos",
            (self.max_timesteps, 16),
            dtype="float32",
            chunks=(min(128, self.max_timesteps), 16),
        )
        self._file.create_dataset(
            "action",
            (self.max_timesteps, 16),
            dtype="float32",
            chunks=(min(128, self.max_timesteps), 16),
        )

    def append(self, qpos: np.ndarray, action: np.ndarray, images: dict[str, np.ndarray]) -> None:
        if self._closed:
            raise RuntimeError("writer is already closed")
        if self.frame_index >= self.max_timesteps:
            raise IndexError("frame index exceeds max_timesteps")

        qpos = np.asarray(qpos, dtype=np.float32)
        action = np.asarray(action, dtype=np.float32)
        if qpos.shape != (16,):
            raise ValueError(f"qpos shape must be (16,), got {qpos.shape}")
        if action.shape != (16,):
            raise ValueError(f"action shape must be (16,), got {action.shape}")

        self._file["observations/qpos"][self.frame_index] = qpos
        self._file["action"][self.frame_index] = action

        for cam_name in self.camera_names:
            if cam_name not in images:
                raise KeyError(f"missing image for camera '{cam_name}'")
            self._file[f"observations/images/{cam_name}"][self.frame_index] = self._resize_image(images[cam_name])

        self.frame_index += 1

    def commit(self) -> None:
        if self._closed:
            return
        self._file.flush()
        self._file.close()
        self._closed = True
        os.replace(self.tmp_path, self.dataset_path)
        self._committed = True

    def discard(self) -> None:
        if not self._closed:
            self._file.close()
            self._closed = True
        if self.tmp_path.exists():
            self.tmp_path.unlink()

    def _resize_image(self, image: np.ndarray) -> np.ndarray:
        image = np.asarray(image, dtype=np.uint8)
        if image.ndim != 3 or image.shape[2] != 3:
            raise ValueError(f"image shape must be HxWx3, got {image.shape}")
        if image.shape[:2] == (self.image_height, self.image_width):
            return image

        interpolation = cv2.INTER_AREA
        if image.shape[0] < self.image_height or image.shape[1] < self.image_width:
            interpolation = cv2.INTER_LINEAR
        return cv2.resize(image, (self.image_width, self.image_height), interpolation=interpolation)
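The `commit`/`discard` pair above is a standard atomic-write pattern: stream into `<path>.tmp`, then `os.replace` it onto the final path only on success, so a reader never sees a half-written episode. A minimal file-based sketch of the same pattern (plain text instead of HDF5; the `AtomicWriter` name and API here are hypothetical, mirroring the writer's shape only):

```python
import os
import tempfile
from pathlib import Path


class AtomicWriter:
    """Stream lines to `<path>.tmp`; commit renames into place, discard deletes the temp file."""

    def __init__(self, path: Path) -> None:
        self.path = Path(path)
        self.tmp_path = Path(f"{self.path}.tmp")
        self._fh = open(self.tmp_path, "w")

    def append(self, line: str) -> None:
        self._fh.write(line + "\n")

    def commit(self) -> None:
        self._fh.flush()
        self._fh.close()
        # os.replace is atomic on POSIX when source and target share a filesystem.
        os.replace(self.tmp_path, self.path)

    def discard(self) -> None:
        self._fh.close()
        if self.tmp_path.exists():
            self.tmp_path.unlink()


tmp_dir = Path(tempfile.mkdtemp())
target = tmp_dir / "episode_0.txt"
writer = AtomicWriter(target)
writer.append("frame 0")
writer.commit()
```

After `commit()` the `.tmp` file is gone and the target exists in full; after `discard()` neither survives, which is exactly the crash-safety contract the episode writer relies on.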
@@ -3,10 +3,8 @@ import torch.nn as nn
import numpy as np
from collections import deque
from typing import Dict, Optional, Any, Tuple
from roboimi.vla.core.interfaces import VLABackbone, VLAProjector, VLAHead
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from roboimi.vla.models.heads.conditional_unet1d import ConditionalUnet1D
from roboimi.vla.models.normalization import NormalizationModule

class VLAAgent(nn.Module):
@@ -24,10 +22,13 @@ class VLAAgent(nn.Module):
        diffusion_steps=100,  # number of DDPM noising steps
        inference_steps=10,  # number of DDIM inference steps
        num_cams=3,  # number of camera views in the visual input
        camera_names: Optional[Tuple[str, ...]] = None,  # camera order for conditioning
        dataset_stats=None,  # dataset statistics used for normalization
        normalization_type='min_max',  # normalization type: 'gaussian' or 'min_max'
        num_action_steps=8,  # how many action steps are actually executed per inference
        head_type='unet',  # policy head type: 'unet' or 'transformer'
        cond_projector=None,  # optional: project the vision+state condition to the head's expected dim
        extra_condition_tokens: int = 0,  # optional: number of extra condition tokens (e.g. future-prediction embeddings)
    ):
        super().__init__()
        # store parameters
@@ -39,6 +40,34 @@
        self.num_action_steps = num_action_steps
        self.inference_steps = inference_steps
        self.head_type = head_type  # 'unet' or 'transformer'
        self.extra_condition_tokens = int(extra_condition_tokens)
        if self.extra_condition_tokens < 0:
            raise ValueError(f"extra_condition_tokens must be >= 0, got {self.extra_condition_tokens}")
        agent_camera_names = tuple(camera_names) if camera_names is not None else None
        backbone_camera_names = getattr(vision_backbone, 'camera_names', None)
        backbone_camera_names = tuple(backbone_camera_names) if backbone_camera_names is not None else None
        backbone_num_cameras = getattr(vision_backbone, 'num_cameras', None)
        if backbone_num_cameras is not None and backbone_num_cameras != self.num_cams:
            raise ValueError(
                f"agent.num_cams({self.num_cams}) does not match "
                f"vision_backbone.num_cameras({backbone_num_cameras})"
            )
        if (
            agent_camera_names is not None
            and backbone_camera_names is not None
            and agent_camera_names != backbone_camera_names
        ):
            raise ValueError(
                f"agent.camera_names({list(agent_camera_names)}) does not match "
                f"vision_backbone.camera_names({list(backbone_camera_names)})"
            )
        self.camera_names = (
            agent_camera_names if agent_camera_names is not None else backbone_camera_names
        )
        if self.camera_names is not None and len(self.camera_names) != self.num_cams:
            raise ValueError(
                f"len(camera_names) ({len(self.camera_names)}) does not match num_cams ({self.num_cams})"
            )

        # normalization module - unifies the normalization logic for training and inference
@@ -46,17 +75,42 @@
            stats=dataset_stats,
            normalization_type=normalization_type
        )
        self.dataset_stats = dataset_stats

        self.vision_encoder = vision_backbone
        self.state_encoder = state_encoder
        if self.camera_names is not None:
            self.vision_encoder.camera_names = self.camera_names
        self.condition_tokens_per_step = int(getattr(self.vision_encoder, 'tokens_per_step', 1))
        self.state_feature_dim = int(getattr(self.state_encoder, 'output_dim', obs_dim))
        joint_vision_dim = getattr(self.vision_encoder, 'joint_output_dim', None)
        if joint_vision_dim is not None:
            per_token_vision_dim = int(joint_vision_dim)
            self.condition_tokens_per_step = 1
        else:
            single_cam_feat_dim = self.vision_encoder.output_dim
            # global_cond_dim: total flattened dimension (for the UNet)
            total_vision_dim = single_cam_feat_dim * num_cams * obs_horizon
            total_prop_dim = obs_dim * obs_horizon
            self.global_cond_dim = total_vision_dim + total_prop_dim
            if self.condition_tokens_per_step > 1:
                per_token_vision_dim = int(single_cam_feat_dim)
            else:
                per_token_vision_dim = int(single_cam_feat_dim) * int(num_cams)

        # per_step_cond_dim: condition dimension per step (for the Transformer)
        # note: not multiplied by obs_horizon, because the Transformer consumes a sequence
        self.per_step_cond_dim = single_cam_feat_dim * num_cams + obs_dim
        self.history_condition_sequence_length = self.obs_horizon * self.condition_tokens_per_step
        self.condition_sequence_length = (
            self.history_condition_sequence_length + self.extra_condition_tokens
        )
        self.raw_per_step_cond_dim = per_token_vision_dim + self.state_feature_dim
        if cond_projector is None:
            self.cond_projector = None
            self.per_step_cond_dim = self.raw_per_step_cond_dim
        else:
            if isinstance(cond_projector, nn.Module):
                self.cond_projector = cond_projector
            else:
                self.cond_projector = cond_projector(input_dim=self.raw_per_step_cond_dim)
            self.per_step_cond_dim = self._projector_output_dim(self.cond_projector, self.raw_per_step_cond_dim)

        # global_cond_dim: total flattened dimension (for the UNet)
        self.global_cond_dim = self.per_step_cond_dim * self.condition_sequence_length

        self.noise_scheduler = DDPMScheduler(
            num_train_timesteps=diffusion_steps,
@@ -85,7 +139,7 @@
                input_dim=action_dim,
                output_dim=action_dim,
                horizon=pred_horizon,
                n_obs_steps=obs_horizon,
                n_obs_steps=self.condition_sequence_length,
                cond_dim=self.per_step_cond_dim  # condition dimension per step
            )
        else:  # 'unet' (default)
@@ -95,7 +149,6 @@
                global_cond_dim=self.global_cond_dim
            )

        self.state_encoder = state_encoder
        self.action_encoder = action_encoder

        # initialize queues (for online inference)
@@ -117,6 +170,84 @@
            return tuple(self._move_to_device(v, device) for v in data)
        return data

    @staticmethod
    def _projector_output_dim(projector: nn.Module, fallback: int) -> int:
        output_dim = getattr(projector, 'output_dim', None)
        if output_dim is not None:
            return int(output_dim)
        out_features = getattr(projector, 'out_features', None)
        if out_features is not None:
            return int(out_features)
        linear = getattr(projector, 'linear', None)
        linear_out_features = getattr(linear, 'out_features', None)
        if linear_out_features is not None:
            return int(linear_out_features)
        return int(fallback)

    def _order_images(self, images: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Return the image dict in the explicitly configured camera order."""
        if self.camera_names is None:
            camera_names = tuple(sorted(images.keys()))
            if len(camera_names) != self.num_cams:
                raise ValueError(
                    f"number of conditioning cameras ({len(camera_names)}) does not match num_cams ({self.num_cams})"
                )
            return {cam_name: images[cam_name] for cam_name in camera_names}

        missing = [cam_name for cam_name in self.camera_names if cam_name not in images]
        if missing:
            raise ValueError(
                f"image condition is missing required cameras. missing={missing}, expected={list(self.camera_names)}"
            )
        return {cam_name: images[cam_name] for cam_name in self.camera_names}

    def _build_cond(self, images: Dict[str, torch.Tensor], states: torch.Tensor) -> torch.Tensor:
        """Build the per-step condition with a stable image-condition order."""
        ordered_images = self._order_images(images)
        visual_features = self.vision_encoder(ordered_images)
        state_features = self.state_encoder(states)
        if visual_features.ndim == 4:
            batch_size, obs_steps, token_count, _ = visual_features.shape
            if obs_steps != state_features.shape[1]:
                raise RuntimeError(
                    f"observation time dims do not match: visual={obs_steps}, state={state_features.shape[1]}"
                )
            if token_count != self.condition_tokens_per_step:
                raise RuntimeError(
                    f"condition token count mismatch: got {token_count}, expected {self.condition_tokens_per_step}"
                )
            state_features = state_features.unsqueeze(2).expand(-1, -1, token_count, -1)
            cond = torch.cat([visual_features, state_features], dim=-1)
            if cond.shape[-1] != self.raw_per_step_cond_dim:
                raise RuntimeError(
                    f"raw condition dim mismatch: got {cond.shape[-1]}, expected {self.raw_per_step_cond_dim}"
                )
            if self.cond_projector is not None:
                cond = self.cond_projector(cond)
            if cond.shape[-1] != self.per_step_cond_dim:
                raise RuntimeError(
                    f"condition dim mismatch: got {cond.shape[-1]}, expected {self.per_step_cond_dim}"
                )
            cond = cond.reshape(batch_size, obs_steps * token_count, self.per_step_cond_dim)
            expected_length = self.history_condition_sequence_length
            if cond.shape[1] != expected_length:
                raise RuntimeError(
                    f"condition sequence length mismatch: got {cond.shape[1]}, expected {expected_length}"
                )
            return cond

        cond = torch.cat([visual_features, state_features], dim=-1)
        if cond.shape[-1] != self.raw_per_step_cond_dim:
            raise RuntimeError(
                f"raw condition dim mismatch: got {cond.shape[-1]}, expected {self.raw_per_step_cond_dim}"
            )
        if self.cond_projector is not None:
            cond = self.cond_projector(cond)
        if cond.shape[-1] != self.per_step_cond_dim:
            raise RuntimeError(
                f"condition dim mismatch: got {cond.shape[-1]}, expected {self.per_step_cond_dim}"
            )
        return cond

    # ==========================
    # Training
@@ -136,10 +267,8 @@
        states = self.normalization.normalize_qpos(states)
        actions = self.normalization.normalize_action(actions)

        state_features = self.state_encoder(states)

        # 1. extract visual features
        visual_features = self.vision_encoder(images)  # (B, obs_horizon, vision_dim)
        per_step_cond = self._build_cond(images, states)
        action_features = self.action_encoder(actions)

        # 2. sample noise
@@ -157,21 +286,16 @@
        )

        # concatenate the global condition and flatten
        # visual_features: (B, obs_horizon, vision_dim)
        # state_features: (B, obs_horizon, obs_dim)
        # concatenated and flattened to (B, obs_horizon * (vision_dim + obs_dim))
        global_cond = torch.cat([visual_features, state_features], dim=-1)
        global_cond = global_cond.flatten(start_dim=1)
        # per_step_cond: (B, obs_horizon, vision_dim * num_cams + obs_dim)
        # flattened for the UNet; used in full-sequence form by the Transformer
        global_cond = per_step_cond.flatten(start_dim=1)

        # 5. predict noise with the network (interface chosen by head type)
        if self.head_type == 'transformer':
            # the Transformer needs the condition in sequence form: (B, obs_horizon, cond_dim_per_step)
            # reshape the flattened global_cond back into sequence form
            cond = global_cond.reshape(B, self.obs_horizon, self.per_step_cond_dim)
            pred_noise = self.noise_pred_net(
                sample=noisy_actions,
                timestep=timesteps,
                cond=cond
                cond=per_step_cond
            )
        else:  # 'unet'
            pred_noise = self.noise_pred_net(
@@ -218,7 +342,8 @@

        # append images
        if 'images' in observation:
            self._queues['images'].append({k: v.clone() for k, v in observation['images'].items()})
            ordered_images = self._order_images(observation['images'])
            self._queues['images'].append({k: v.clone() for k, v in ordered_images.items()})

    def _prepare_observation_batch(self) -> Dict[str, torch.Tensor]:
        """
@@ -246,7 +371,8 @@
            images_list.append(images_list[-1])

        batch_images = {}
        for cam_name in images_list[0].keys():
        camera_names = self.camera_names if self.camera_names is not None else tuple(sorted(images_list[0].keys()))
        for cam_name in camera_names:
            batch_images[cam_name] = torch.stack([img[cam_name] for img in images_list], dim=0).unsqueeze(0)

        return {'qpos': batch_qpos, 'images': batch_images}
@@ -346,22 +472,18 @@
        proprioception = self.normalization.normalize_qpos(proprioception)

        # 1. extract current observation features (computed only once)
        visual_features = self.vision_encoder(images)
        state_features = self.state_encoder(proprioception)
        per_step_cond = self._build_cond(images, proprioception)

        # concatenate the condition (computed only once)
        # visual_features: (B, obs_horizon, vision_dim)
        # state_features: (B, obs_horizon, obs_dim)
        global_cond = torch.cat([visual_features, state_features], dim=-1)
        global_cond_flat = global_cond.flatten(start_dim=1)
        global_cond_flat = per_step_cond.flatten(start_dim=1)
        if self.head_type == 'transformer':
            cond = global_cond.reshape(B, self.obs_horizon, self.per_step_cond_dim)
            cond = per_step_cond
        else:
            cond = None

        # 2. initialize pure Gaussian-noise actions
        # shape: (B, pred_horizon, action_dim)
        device = visual_features.device
        device = per_step_cond.device
        current_actions = torch.randn(
            (B, self.pred_horizon, self.action_dim), device=device
        )
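The `agent_imf.py` file that follows builds a one-step sampler on this agent: it integrates an average-velocity field in a single step, `z_r = z_t - (t - r) * u(z_t, r, t)`, with `r=0`, `t=1` by default. A toy sketch of just that integration step, with a hand-written velocity field standing in for the diffusion head (everything here is illustrative, not repo code):

```python
import torch


def one_step_sample(z_t: torch.Tensor, u_fn, r: float = 0.0, t: float = 1.0) -> torch.Tensor:
    """Single integration step of an average-velocity field: z_r = z_t - (t - r) * u."""
    batch = z_t.shape[0]
    r_vec = torch.full((batch,), r, dtype=z_t.dtype)
    t_vec = torch.full((batch,), t, dtype=z_t.dtype)
    u = u_fn(z_t, r_vec, t_vec)
    # Broadcast the per-sample (t - r) scalar over the remaining dims of z_t.
    delta = (t_vec - r_vec).view(-1, *([1] * (z_t.ndim - 1)))
    return z_t - delta * u


# With the identity field u(z, r, t) = z, one step maps z_t to z_t - (1 - 0) * z_t = 0.
z = torch.ones(2, 4)
out = one_step_sample(z, lambda z_t, r, t: z_t)
```

The broadcast helper mirrors `_broadcast_batch_time` in the file below; the real `u` is the conditioned noise-prediction network, so this sketch only shows the arithmetic shape of the one-step update.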
roboimi/vla/agent_imf.py (new file, 567 lines)
@@ -0,0 +1,567 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import nullcontext
|
||||
from collections import deque
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Mapping, Optional, Sequence
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from roboimi.vla.agent import VLAAgent
|
||||
|
||||
try:
|
||||
from torch.func import jvp as TORCH_FUNC_JVP
|
||||
except ImportError: # pragma: no cover
|
||||
TORCH_FUNC_JVP = None
|
||||
|
||||
|
||||
class IMFVLAAgent(VLAAgent):
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
inference_steps: int = 1,
|
||||
lewm_history_horizon: Optional[int] = None,
|
||||
lewm_query_offsets: Optional[Sequence[int]] = None,
|
||||
lewm_predictor: Optional[nn.Module] = None,
|
||||
lewm_pred_projector: Optional[nn.Module] = None,
|
||||
future_decoder: Optional[nn.Module] = None,
|
||||
future_query_init_std: float = 0.02,
|
||||
lewm_sigreg: Optional[nn.Module] = None,
|
||||
lewm_sigreg_weight: float = 0.09,
|
||||
lewm_loss_weight: float = 0.0,
|
||||
lewm_pretrained_ckpt: Optional[str | Path | Mapping[str, Any]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
if inference_steps != 1:
|
||||
raise ValueError(
|
||||
'IMFVLAAgent only supports one-step inference; '
|
||||
f'inference_steps must be 1, got {inference_steps}.'
|
||||
)
|
||||
lewm_query_offsets = tuple(int(offset) for offset in (lewm_query_offsets or ()))
|
||||
inferred_extra_condition_tokens = len(lewm_query_offsets) if lewm_query_offsets else 0
|
||||
default_extra_condition_tokens = (
|
||||
0 if future_decoder is not None else inferred_extra_condition_tokens
|
||||
)
|
||||
kwargs.setdefault('extra_condition_tokens', default_extra_condition_tokens)
|
||||
self.__dict__['lewm_history_horizon'] = int(lewm_history_horizon or kwargs.get('obs_horizon', 1))
|
||||
self.__dict__['lewm_query_offsets'] = lewm_query_offsets
|
||||
self.__dict__['lewm_predictor'] = lewm_predictor
|
||||
self.__dict__['lewm_pred_projector'] = lewm_pred_projector or nn.Identity()
|
||||
self.__dict__['future_decoder'] = future_decoder
|
||||
self.__dict__['future_query_tokens'] = None
|
||||
self.__dict__['future_query_init_std'] = float(future_query_init_std)
|
||||
self.__dict__['lewm_sigreg'] = lewm_sigreg
|
||||
self.__dict__['lewm_sigreg_weight'] = float(lewm_sigreg_weight)
|
||||
self.__dict__['lewm_loss_weight'] = float(lewm_loss_weight)
|
||||
self.__dict__['_last_loss_breakdown'] = {
|
||||
'action_loss': 0.0,
|
||||
'lewm_pred_loss': 0.0,
|
||||
'lewm_sigreg_loss': 0.0,
|
||||
'lewm_loss': 0.0,
|
||||
'loss': 0.0,
|
||||
}
|
||||
super().__init__(*args, inference_steps=inference_steps, **kwargs)
|
||||
self.inference_steps = 1
|
||||
self.lewm_history_horizon = int(lewm_history_horizon or self.obs_horizon)
|
||||
self.lewm_predictor = lewm_predictor
|
||||
self.lewm_pred_projector = lewm_pred_projector or nn.Identity()
|
||||
if future_decoder is not None and not isinstance(future_decoder, nn.Module):
|
||||
self.future_decoder = future_decoder()
|
||||
else:
|
||||
self.future_decoder = future_decoder
|
||||
self.future_query_tokens = None
|
||||
self.future_query_init_std = float(future_query_init_std)
|
||||
self.lewm_sigreg = lewm_sigreg
|
||||
self.lewm_sigreg_weight = float(lewm_sigreg_weight)
|
||||
if self.lewm_predictor is not None and self.future_decoder is not None:
|
||||
raise ValueError('lewm_predictor and future_decoder are mutually exclusive')
|
||||
if self.lewm_predictor is None and self.extra_condition_tokens > 0:
|
||||
raise ValueError(
|
||||
'extra_condition_tokens > 0 requires lewm_predictor to be provided'
|
||||
)
|
||||
if self.lewm_predictor is not None and self.extra_condition_tokens != inferred_extra_condition_tokens:
|
||||
raise ValueError(
|
||||
'extra_condition_tokens must equal len(lewm_query_offsets) when lewm_predictor is enabled'
|
||||
)
|
||||
if self.future_decoder is not None:
|
||||
if inferred_extra_condition_tokens <= 0:
|
||||
raise ValueError('future_decoder requires non-empty lewm_query_offsets')
|
||||
if self.extra_condition_tokens != 0:
|
||||
raise ValueError('future_decoder requires extra_condition_tokens=0')
|
||||
self.future_query_tokens = nn.Parameter(
|
||||
torch.randn(
|
||||
1,
|
||||
inferred_extra_condition_tokens,
|
||||
self.per_step_cond_dim,
|
||||
) * self.future_query_init_std
|
||||
)
|
||||
if lewm_pretrained_ckpt is not None:
|
||||
self.load_lewm_pretrained_components(lewm_pretrained_ckpt)
|
||||
|
||||
@staticmethod
|
||||
def _broadcast_batch_time(value: torch.Tensor, reference: torch.Tensor) -> torch.Tensor:
|
||||
while value.ndim < reference.ndim:
|
||||
value = value.unsqueeze(-1)
|
||||
return value
|
||||
|
||||
@staticmethod
|
||||
def _apply_conditioning(
|
||||
trajectory: torch.Tensor,
|
||||
condition_data: Optional[torch.Tensor] = None,
|
||||
condition_mask: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if condition_data is None or condition_mask is None:
|
||||
return trajectory
|
||||
conditioned = trajectory.clone()
|
||||
conditioned[condition_mask] = condition_data[condition_mask]
|
||||
return conditioned
|
||||
|
||||
@staticmethod
|
||||
def _jvp_math_sdp_context(z_t: torch.Tensor):
|
||||
if z_t.is_cuda:
|
||||
return torch.backends.cuda.sdp_kernel(
|
||||
enable_flash=False,
|
||||
enable_math=True,
|
||||
enable_mem_efficient=False,
|
||||
enable_cudnn=False,
|
||||
)
|
||||
return nullcontext()
|
||||
|
||||
@staticmethod
|
||||
def _jvp_tangents(v: torch.Tensor, r: torch.Tensor, t: torch.Tensor):
|
||||
return v.detach(), torch.zeros_like(r), torch.ones_like(t)
|
||||
|
||||
def fn(self, z: torch.Tensor, r: torch.Tensor, t: torch.Tensor, cond=None) -> torch.Tensor:
|
||||
return self.noise_pred_net(z, r, t, cond=cond)
|
||||
|
||||
def _compute_u_and_du_dt(
|
||||
self,
|
||||
z_t: torch.Tensor,
|
||||
r: torch.Tensor,
|
||||
t: torch.Tensor,
|
||||
cond,
|
||||
v: torch.Tensor,
|
||||
condition_data: Optional[torch.Tensor] = None,
|
||||
condition_mask: Optional[torch.Tensor] = None,
|
||||
):
|
||||
tangents = self._jvp_tangents(v, r, t)
|
||||
|
||||
def g(z, r_value, t_value):
|
||||
conditioned_z = self._apply_conditioning(z, condition_data, condition_mask)
|
||||
return self.fn(conditioned_z, r_value, t_value, cond=cond)
|
||||
|
||||
with self._jvp_math_sdp_context(z_t):
|
||||
if TORCH_FUNC_JVP is not None:
|
||||
try:
|
||||
return TORCH_FUNC_JVP(g, (z_t, r, t), tangents)
|
||||
except (RuntimeError, TypeError, NotImplementedError):
|
||||
pass
|
||||
|
||||
u = g(z_t, r, t)
|
||||
_, du_dt = torch.autograd.functional.jvp(
|
||||
g,
|
||||
(z_t, r, t),
|
||||
tangents,
|
||||
create_graph=False,
|
||||
strict=False,
|
||||
)
|
||||
return u, du_dt
|
||||
|
||||
def _compound_velocity(
|
||||
self,
|
||||
u: torch.Tensor,
|
||||
du_dt: torch.Tensor,
|
||||
r: torch.Tensor,
|
||||
t: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
delta = self._broadcast_batch_time(t - r, u)
|
||||
return u + delta * du_dt.detach()
|
||||
|
||||
def _sample_one_step(
|
||||
self,
|
||||
z_t: torch.Tensor,
|
||||
r: Optional[torch.Tensor] = None,
|
||||
t: Optional[torch.Tensor] = None,
|
||||
cond=None,
|
||||
) -> torch.Tensor:
|
||||
batch_size = z_t.shape[0]
|
||||
if t is None:
|
||||
t = torch.ones(batch_size, device=z_t.device, dtype=z_t.dtype)
|
||||
if r is None:
|
||||
r = torch.zeros(batch_size, device=z_t.device, dtype=z_t.dtype)
|
||||
u = self.fn(z_t, r, t, cond=cond)
|
||||
delta = self._broadcast_batch_time(t - r, z_t)
|
||||
return z_t - delta * u
|
||||
|
||||
def _normalize_qpos_for_lewm(self, qpos: torch.Tensor) -> torch.Tensor:
|
||||
if not self.normalization.enabled:
|
||||
return qpos
|
||||
|
||||
qpos_mean = getattr(self.normalization, 'qpos_mean', None)
|
||||
qpos_std = getattr(self.normalization, 'qpos_std', None)
|
||||
if qpos_mean is not None and qpos_std is not None:
|
||||
return (qpos - qpos_mean) / qpos_std
|
||||
if isinstance(self.dataset_stats, dict):
|
||||
mean = self.dataset_stats.get('qpos_mean', None)
|
||||
std = self.dataset_stats.get('qpos_std', None)
|
||||
if mean is not None and std is not None:
|
||||
mean = torch.as_tensor(mean, dtype=qpos.dtype, device=qpos.device)
|
||||
std = torch.as_tensor(std, dtype=qpos.dtype, device=qpos.device)
|
||||
return (qpos - mean) / std
|
||||
return self.normalization.normalize_qpos(qpos)
|
||||
|
||||
def _project_lewm_future_tokens(self, predicted_tokens: torch.Tensor) -> torch.Tensor:
|
||||
if predicted_tokens.ndim != 3:
|
||||
raise ValueError(
|
||||
f"expected predicted future tokens to be 3D, got rank {predicted_tokens.ndim}"
|
||||
)
|
||||
batch_size, token_count, token_dim = predicted_tokens.shape
|
||||
flattened = predicted_tokens.reshape(batch_size * token_count, token_dim)
|
||||
        projected = self.lewm_pred_projector(flattened)
        if projected.ndim != 2:
            raise ValueError(
                f"expected lewm_pred_projector to return rank-2 tensors, got rank {projected.ndim}"
            )
        return projected.reshape(batch_size, token_count, projected.shape[-1])

    @staticmethod
    def _load_checkpoint_payload(
        checkpoint_or_path: str | Path | Mapping[str, Any],
    ) -> Mapping[str, torch.Tensor]:
        if isinstance(checkpoint_or_path, (str, Path)):
            payload = torch.load(Path(checkpoint_or_path), map_location='cpu', weights_only=False)
        else:
            payload = checkpoint_or_path
        state_dict = payload.get('state_dict', payload)
        if not isinstance(state_dict, Mapping):
            raise TypeError('checkpoint payload must contain a mapping state_dict')
        return state_dict

    @staticmethod
    def _extract_prefixed_state_dict(
        state_dict: Mapping[str, torch.Tensor],
        prefix: str,
    ) -> Dict[str, torch.Tensor]:
        extracted = {
            key[len(prefix):]: value
            for key, value in state_dict.items()
            if key.startswith(prefix)
        }
        if not extracted:
            raise KeyError(f"checkpoint missing parameters with prefix {prefix!r}")
        return extracted

    @staticmethod
    def _adapt_and_load_state_dict(
        module: nn.Module,
        incoming_state_dict: Mapping[str, torch.Tensor],
        *,
        query_key: str = 'query_tokens',
        pos_key: str = 'pos_embedding',
    ) -> Dict[str, Sequence[str]]:
        current_state_dict = module.state_dict()
        adapted_state_dict = dict(current_state_dict)
        loaded_keys = []
        mismatched_keys = []
        missing_keys = []
        for key, current_tensor in current_state_dict.items():
            if key not in incoming_state_dict:
                continue
            source_tensor = incoming_state_dict[key]
            if source_tensor.shape == current_tensor.shape:
                adapted_state_dict[key] = source_tensor
                loaded_keys.append(key)
                continue

            if key in {query_key, pos_key} and source_tensor.ndim == current_tensor.ndim:
                patched = current_tensor.clone()
                overlap_slices = tuple(
                    slice(0, min(src_dim, cur_dim))
                    for src_dim, cur_dim in zip(source_tensor.shape, current_tensor.shape)
                )
                patched[overlap_slices] = source_tensor[overlap_slices]
                # query_tokens and pos_embedding receive the same tail
                # extension, so the formerly duplicated per-key branches
                # are collapsed into one block here.
                copy_count = min(source_tensor.shape[1], current_tensor.shape[1])
                if copy_count < current_tensor.shape[1] and copy_count > 0:
                    tail = source_tensor[:, copy_count - 1:copy_count, ...]
                    feature_dim = min(tail.shape[-1], patched.shape[-1])
                    patched[:, copy_count:, :feature_dim] = tail[:, :, :feature_dim]
                adapted_state_dict[key] = patched
                loaded_keys.append(key)
                continue
            mismatched_keys.append(key)

        for key in incoming_state_dict.keys():
            if key not in current_state_dict:
                missing_keys.append(key)
        module.load_state_dict(adapted_state_dict, strict=True)
        return {
            'loaded_keys': tuple(sorted(loaded_keys)),
            'mismatched_keys': tuple(sorted(set(mismatched_keys))),
            'missing_keys': tuple(sorted(set(missing_keys))),
        }

    @staticmethod
    def _load_state_dict_ignoring_shape_mismatches(
        module: nn.Module,
        incoming_state_dict: Mapping[str, torch.Tensor],
    ) -> Dict[str, Sequence[str]]:
        current_state_dict = module.state_dict()
        merged_state_dict = dict(current_state_dict)
        loaded_keys = []
        mismatched_keys = []
        missing_keys = []

        for key, value in incoming_state_dict.items():
            if key not in current_state_dict:
                missing_keys.append(key)
                continue
            if current_state_dict[key].shape != value.shape:
                mismatched_keys.append(key)
                continue
            merged_state_dict[key] = value
            loaded_keys.append(key)

        module.load_state_dict(merged_state_dict, strict=True)
        return {
            'loaded_keys': tuple(sorted(loaded_keys)),
            'mismatched_keys': tuple(sorted(mismatched_keys)),
            'missing_keys': tuple(sorted(missing_keys)),
        }

    def load_lewm_pretrained_components(
        self,
        checkpoint_or_path: str | Path | Mapping[str, Any],
    ) -> None:
        state_dict = self._load_checkpoint_payload(checkpoint_or_path)

        if hasattr(self.vision_encoder, 'load_lewm_checkpoint'):
            try:
                self.vision_encoder.load_lewm_checkpoint({'state_dict': state_dict})
            except RuntimeError:
                vision_state_dict = self._extract_prefixed_state_dict(state_dict, 'model.encoder.')
                self._load_state_dict_ignoring_shape_mismatches(self.vision_encoder, vision_state_dict)
        else:
            vision_state_dict = self._extract_prefixed_state_dict(state_dict, 'model.encoder.')
            self._load_state_dict_ignoring_shape_mismatches(self.vision_encoder, vision_state_dict)

        state_encoder_state_dict = self._extract_prefixed_state_dict(state_dict, 'model.state_encoder.')
        self._load_state_dict_ignoring_shape_mismatches(self.state_encoder, state_encoder_state_dict)

        projector_state_dict = self._extract_prefixed_state_dict(state_dict, 'model.projector.proj.')
        mapped_projector_state_dict = {
            f'linear.{key}': value
            for key, value in projector_state_dict.items()
        }
        self._load_state_dict_ignoring_shape_mismatches(self.cond_projector, mapped_projector_state_dict)

        if self.lewm_predictor is not None:
            predictor_state_dict = self._extract_prefixed_state_dict(state_dict, 'model.predictor.')
            self._adapt_and_load_state_dict(self.lewm_predictor, predictor_state_dict)

        if self.lewm_pred_projector is not None:
            pred_projector_state_dict = self._extract_prefixed_state_dict(state_dict, 'model.pred_proj.')
            self._load_state_dict_ignoring_shape_mismatches(
                self.lewm_pred_projector,
                pred_projector_state_dict,
            )

    def _predict_future_tokens_with_decoder(self, history_cond: torch.Tensor) -> torch.Tensor:
        if self.future_decoder is None or self.future_query_tokens is None:
            raise RuntimeError('future_decoder path requested but not initialized')
        batch_size = history_cond.shape[0]
        query_tokens = self.future_query_tokens.expand(batch_size, -1, -1)
        r = torch.zeros(batch_size, device=history_cond.device, dtype=history_cond.dtype)
        t = torch.ones(batch_size, device=history_cond.device, dtype=history_cond.dtype)
        return self.future_decoder(query_tokens, r, t, cond=history_cond)

    def _build_full_condition(
        self,
        images,
        proprioception,
        *,
        lewm_images=None,
        lewm_proprioception=None,
    ):
        normalized_proprioception = self.normalization.normalize_qpos(proprioception)
        history_cond = self._build_cond(images, normalized_proprioception)
        predicted_future_tokens = None
        lewm_history_cond = None

        if self.lewm_predictor is None and self.future_decoder is None:
            return history_cond, predicted_future_tokens, lewm_history_cond

        lewm_images = lewm_images if lewm_images is not None else images
        lewm_proprioception = (
            lewm_proprioception if lewm_proprioception is not None else proprioception
        )
        lewm_history_cond = self._build_cond(
            lewm_images,
            self._normalize_qpos_for_lewm(lewm_proprioception),
        )
        cond = history_cond
        if self.lewm_predictor is not None:
            predicted_future_tokens = self.lewm_predictor(lewm_history_cond)
            predicted_future_tokens = self._project_lewm_future_tokens(predicted_future_tokens)
            cond = torch.cat([history_cond, predicted_future_tokens], dim=1)
            if cond.shape[1] != self.condition_sequence_length:
                raise RuntimeError(
                    f"full condition sequence length mismatch: got {cond.shape[1]}, expected {self.condition_sequence_length}"
                )
            if cond.shape[-1] != self.per_step_cond_dim:
                raise RuntimeError(
                    f"full condition dim mismatch: got {cond.shape[-1]}, expected {self.per_step_cond_dim}"
                )
        elif self.future_decoder is not None:
            predicted_future_tokens = self._predict_future_tokens_with_decoder(lewm_history_cond)
        return cond, predicted_future_tokens, lewm_history_cond

    @staticmethod
    def _masked_mse_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # Despite the name, no padding mask is applied here yet; this is a
        # plain element-wise MSE.
        return F.mse_loss(pred, target)

    def compute_loss(self, batch):
        actions, states, images = batch['action'], batch['qpos'], batch['images']
        action_is_pad = batch.get('action_is_pad', None)
        batch_size = actions.shape[0]

        actions = self.normalization.normalize_action(actions)
        cond, predicted_future_tokens, lewm_history_cond = self._build_full_condition(
            images,
            states,
            lewm_images=batch.get('lewm_images', None),
            lewm_proprioception=batch.get('lewm_qpos', None),
        )

        x = actions
        e = torch.randn_like(x)
        t = torch.rand(batch_size, device=x.device, dtype=x.dtype)
        r = torch.rand(batch_size, device=x.device, dtype=x.dtype)
        t, r = torch.maximum(t, r), torch.minimum(t, r)

        t_broadcast = self._broadcast_batch_time(t, x)
        z_t = (1 - t_broadcast) * x + t_broadcast * e

        v = self.fn(z_t, t, t, cond=cond)
        u, du_dt = self._compute_u_and_du_dt(z_t, r, t, cond=cond, v=v)
        V = self._compound_velocity(u, du_dt, r, t)
        target = e - x

        loss = F.mse_loss(V, target, reduction='none')
        if action_is_pad is not None:
            mask = (~action_is_pad).unsqueeze(-1).to(loss.dtype)
            valid_count = mask.sum() * loss.shape[-1]
            action_loss = (loss * mask).sum() / valid_count.clamp_min(1.0)
        else:
            action_loss = loss.mean()

        lewm_pred_loss = torch.zeros((), device=action_loss.device, dtype=action_loss.dtype)
        lewm_sigreg_loss = torch.zeros((), device=action_loss.device, dtype=action_loss.dtype)
        if predicted_future_tokens is not None:
            lewm_future_images = batch.get('lewm_future_images', None)
            lewm_future_qpos = batch.get('lewm_future_qpos', None)
            if lewm_future_images is not None and lewm_future_qpos is not None:
                future_target = self._build_cond(
                    lewm_future_images,
                    self._normalize_qpos_for_lewm(lewm_future_qpos),
                )
                lewm_pred_loss = self._masked_mse_loss(predicted_future_tokens, future_target)
        if self.lewm_sigreg is not None and lewm_history_cond is not None:
            lewm_sigreg_loss = self.lewm_sigreg(lewm_history_cond.transpose(0, 1))

        lewm_loss = lewm_pred_loss + self.lewm_sigreg_weight * lewm_sigreg_loss
        total_loss = action_loss + self.lewm_loss_weight * lewm_loss
        self._last_loss_breakdown = {
            'action_loss': float(action_loss.detach().item()),
            'lewm_pred_loss': float(lewm_pred_loss.detach().item()),
            'lewm_sigreg_loss': float(lewm_sigreg_loss.detach().item()),
            'lewm_loss': float(lewm_loss.detach().item()),
            'loss': float(total_loss.detach().item()),
        }
        return total_loss

    def get_last_loss_breakdown(self) -> Dict[str, float]:
        return dict(self._last_loss_breakdown)

    def reset(self):
        super().reset()
        if self.lewm_predictor is not None:
            self._queues['lewm_qpos'] = deque(maxlen=self.lewm_history_horizon)
            self._queues['lewm_images'] = deque(maxlen=self.lewm_history_horizon)

    def _populate_queues(self, observation: Dict[str, torch.Tensor]) -> None:
        super()._populate_queues(observation)
        if self.lewm_predictor is None:
            return
        if 'qpos' in observation:
            self._queues['lewm_qpos'].append(observation['qpos'].clone())
        if 'images' in observation:
            ordered_images = self._order_images(observation['images'])
            self._queues['lewm_images'].append({k: v.clone() for k, v in ordered_images.items()})

    def _prepare_observation_batch(self) -> Dict[str, torch.Tensor]:
        batch = super()._prepare_observation_batch()
        if self.lewm_predictor is None:
            return batch

        qpos_list = list(self._queues['lewm_qpos'])
        images_list = list(self._queues['lewm_images'])
        if len(qpos_list) == 0 or len(images_list) == 0:
            raise ValueError("LeWM observation queues are empty; call _populate_queues first")
        while len(qpos_list) < self.lewm_history_horizon:
            qpos_list.append(qpos_list[-1])
        while len(images_list) < self.lewm_history_horizon:
            images_list.append(images_list[-1])

        batch['lewm_qpos'] = torch.stack(qpos_list, dim=0).unsqueeze(0)
        batch['lewm_images'] = {}
        camera_names = self.camera_names if self.camera_names is not None else tuple(sorted(images_list[0].keys()))
        for cam_name in camera_names:
            batch['lewm_images'][cam_name] = torch.stack(
                [img[cam_name] for img in images_list],
                dim=0,
            ).unsqueeze(0)
        return batch

    @torch.no_grad()
    def predict_action_chunk(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        return self.predict_action(
            batch['images'],
            batch['qpos'],
            lewm_images=batch.get('lewm_images', None),
            lewm_proprioception=batch.get('lewm_qpos', None),
        )

    @torch.no_grad()
    def predict_action(
        self,
        images,
        proprioception,
        *,
        lewm_images=None,
        lewm_proprioception=None,
    ):
        batch_size = proprioception.shape[0]
        if self.lewm_predictor is not None:
            cond, _predicted_future_tokens, _lewm_history_cond = self._build_full_condition(
                images,
                proprioception,
                lewm_images=lewm_images,
                lewm_proprioception=lewm_proprioception,
            )
        else:
            cond = self._build_cond(
                images,
                self.normalization.normalize_qpos(proprioception),
            )
        z_t = torch.randn((batch_size, self.pred_horizon, self.action_dim), device=cond.device, dtype=cond.dtype)
        action = self._sample_one_step(z_t, cond=cond)
        return self.normalization.denormalize_action(action)
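The two state-dict helpers above can be sketched with plain dicts, using shape tuples in place of real tensors; the function names below are illustrative stand-ins, not part of the repo:

```python
def extract_prefixed(state_dict, prefix):
    # Keep only keys under `prefix`, stripping the prefix itself,
    # mirroring _extract_prefixed_state_dict above.
    out = {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}
    if not out:
        raise KeyError(f"checkpoint missing parameters with prefix {prefix!r}")
    return out

def merge_skipping_mismatches(current, incoming):
    # Copy incoming values whose key exists and whose shape matches the
    # current module; report everything else instead of failing, as in
    # _load_state_dict_ignoring_shape_mismatches.
    merged, loaded, mismatched, missing = dict(current), [], [], []
    for key, shape in incoming.items():
        if key not in current:
            missing.append(key)
        elif current[key] != shape:
            mismatched.append(key)
        else:
            merged[key] = shape
            loaded.append(key)
    return merged, sorted(loaded), sorted(mismatched), sorted(missing)

ckpt = {"model.encoder.conv.weight": (64, 3), "model.encoder.fc.weight": (10, 64)}
encoder_sd = extract_prefixed(ckpt, "model.encoder.")
current = {"conv.weight": (64, 3), "fc.weight": (20, 64)}
merged, loaded, mismatched, missing = merge_skipping_mismatches(current, encoder_sd)
```

With the toy checkpoint above, `conv.weight` is loaded, `fc.weight` is reported as a shape mismatch and the current value is kept.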
41
roboimi/vla/conf/agent/lewm_imf_attnres.yaml
Normal file
@@ -0,0 +1,41 @@
# @package agent
defaults:
  - /backbone@vision_backbone: lewm_vit_diffusion
  - /modules@state_encoder: identity_state_encoder
  - /modules@action_encoder: identity_action_encoder
  - /head: imf_transformer1d
  - _self_

_target_: roboimi.vla.agent_imf.IMFVLAAgent

action_dim: 16
obs_dim: 16
normalization_type: "min_max"
pred_horizon: 16
obs_horizon: 2
num_action_steps: 8
camera_names: ${data.camera_names}
num_cams: 3

vision_backbone:
  num_cameras: ${agent.num_cams}
  camera_names: ${agent.camera_names}
  fused_camera_names: [front, top, r_vis]

diffusion_steps: 100
inference_steps: 1
head_type: "transformer"

head:
  input_dim: ${agent.action_dim}
  output_dim: ${agent.action_dim}
  horizon: ${agent.pred_horizon}
  n_obs_steps: ${agent.obs_horizon}
  cond_dim: 208
  causal_attn: false
  time_as_cond: true
  obs_as_cond: true
  n_cond_layers: 0
  backbone_type: attnres_full
  n_head: 1
  n_kv_head: 1

@@ -0,0 +1,74 @@
# @package agent
defaults:
  - /backbone@vision_backbone: lewm_resnet_query_fusion
  - /modules@state_encoder: lewm_state_encoder
  - /modules@action_encoder: identity_action_encoder
  - /modules@cond_projector: linear_condition_projector
  - /head: imf_transformer1d
  - /head@future_decoder: imf_transformer1d
  - _self_

_target_: roboimi.vla.agent_imf.IMFVLAAgent

action_dim: 16
obs_dim: 16
normalization_type: "min_max"
pred_horizon: 8
obs_horizon: 2
num_action_steps: 8
camera_names: ${data.camera_names}
num_cams: 3

vision_backbone:
  camera_names: ${agent.camera_names}
  num_views: ${agent.num_cams}

cond_projector:
  output_dim: 288

lewm_history_horizon: 3
lewm_query_offsets: [8]
extra_condition_tokens: 0
lewm_loss_weight: 1.0
lewm_sigreg_weight: 0.09
lewm_pretrained_ckpt: null
future_query_init_std: 0.02

lewm_sigreg:
  _target_: roboimi.vla.models.backbones.lewm_resnet_query_fusion.SIGReg
  knots: 17
  num_proj: 1024

diffusion_steps: 100
inference_steps: 1
head_type: "transformer"

head:
  input_dim: ${agent.action_dim}
  output_dim: ${agent.action_dim}
  horizon: ${agent.pred_horizon}
  n_obs_steps: ${agent.obs_horizon}
  cond_dim: ${agent.cond_projector.output_dim}
  n_emb: 384
  causal_attn: false
  time_as_cond: true
  obs_as_cond: true
  n_cond_layers: 0
  backbone_type: attnres_full
  n_head: 1
  n_kv_head: 1

future_decoder:
  input_dim: ${agent.cond_projector.output_dim}
  output_dim: ${agent.cond_projector.output_dim}
  horizon: ${len:${agent.lewm_query_offsets}}
  n_obs_steps: ${agent.lewm_history_horizon}
  cond_dim: ${agent.cond_projector.output_dim}
  n_emb: 384
  causal_attn: false
  time_as_cond: true
  obs_as_cond: true
  n_cond_layers: 0
  backbone_type: attnres_full
  n_head: 1
  n_kv_head: 1

77
roboimi/vla/conf/agent/lewm_resnet_query_imf_attnres.yaml
Normal file
@@ -0,0 +1,77 @@
# @package agent
defaults:
  - /backbone@vision_backbone: lewm_resnet_query_fusion
  - /modules@state_encoder: lewm_state_encoder
  - /modules@action_encoder: identity_action_encoder
  - /modules@cond_projector: linear_condition_projector
  - /head: imf_transformer1d
  - _self_

_target_: roboimi.vla.agent_imf.IMFVLAAgent

action_dim: 16
obs_dim: 16
normalization_type: "min_max"
pred_horizon: 8
obs_horizon: 2
num_action_steps: 8
camera_names: ${data.camera_names}
num_cams: 3

vision_backbone:
  camera_names: ${agent.camera_names}
  num_views: ${agent.num_cams}

cond_projector:
  output_dim: 288

lewm_history_horizon: 3
lewm_query_offsets: [8]
extra_condition_tokens: ${len:${agent.lewm_query_offsets}}
lewm_loss_weight: 1.0
lewm_sigreg_weight: 0.09
lewm_pretrained_ckpt: null

lewm_sigreg:
  _target_: roboimi.vla.models.backbones.lewm_resnet_query_fusion.SIGReg
  knots: 17
  num_proj: 1024

lewm_predictor:
  _target_: roboimi.vla.models.backbones.lewm_resnet_query_fusion.QueryTokenPredictor
  num_frames: ${agent.lewm_history_horizon}
  query_offsets: ${agent.lewm_query_offsets}
  input_dim: ${agent.cond_projector.output_dim}
  hidden_dim: ${agent.cond_projector.output_dim}
  output_dim: ${agent.cond_projector.output_dim}
  depth: 6
  heads: 16
  mlp_dim: 2048
  dim_head: 64
  dropout: 0.1
  emb_dropout: 0.0

lewm_pred_projector:
  _target_: roboimi.vla.models.backbones.lewm_resnet_query_fusion.LeWMProjectorMLP
  input_dim: ${agent.cond_projector.output_dim}
  hidden_dim: 2048
  output_dim: ${agent.cond_projector.output_dim}

diffusion_steps: 100
inference_steps: 1
head_type: "transformer"

head:
  input_dim: ${agent.action_dim}
  output_dim: ${agent.action_dim}
  horizon: ${agent.pred_horizon}
  n_obs_steps: ${agent.obs_horizon}
  cond_dim: 288
  n_emb: 384
  causal_attn: false
  time_as_cond: true
  obs_as_cond: true
  n_cond_layers: 0
  backbone_type: attnres_full
  n_head: 1
  n_kv_head: 1
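The interpolation `${len:${agent.lewm_query_offsets}}` above derives `extra_condition_tokens` from the number of query offsets. OmegaConf ships no built-in `len` resolver, so the repo presumably registers a custom one; its assumed effect reduces to the following sketch:

```python
def resolve_len(value):
    # Assumed custom-resolver behaviour: length of the interpolated list.
    return len(value)

lewm_query_offsets = [8]  # from the config above
extra_condition_tokens = resolve_len(lewm_query_offsets)
```

With `lewm_query_offsets: [8]` this resolves to 1 extra condition token; a two-offset list such as `[4, 8]` would resolve to 2.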
40
roboimi/vla/conf/agent/resnet_imf_attnres.yaml
Normal file
@@ -0,0 +1,40 @@
# @package agent
defaults:
  - /backbone@vision_backbone: resnet_diffusion
  - /modules@state_encoder: identity_state_encoder
  - /modules@action_encoder: identity_action_encoder
  - /head: imf_transformer1d
  - _self_

_target_: roboimi.vla.agent_imf.IMFVLAAgent

action_dim: 16
obs_dim: 16
normalization_type: "min_max"
pred_horizon: 16
obs_horizon: 2
num_action_steps: 8
camera_names: ${data.camera_names}
num_cams: 3

vision_backbone:
  num_cameras: ${agent.num_cams}
  camera_names: ${agent.camera_names}

diffusion_steps: 100
inference_steps: 1
head_type: "transformer"

head:
  input_dim: ${agent.action_dim}
  output_dim: ${agent.action_dim}
  horizon: ${agent.pred_horizon}
  n_obs_steps: ${agent.obs_horizon}
  cond_dim: 208
  causal_attn: false
  time_as_cond: true
  obs_as_cond: true
  n_cond_layers: 0
  backbone_type: attnres_full
  n_head: 1
  n_kv_head: 1

48
roboimi/vla/conf/agent/resnet_imf_attnres_multitoken.yaml
Normal file
@@ -0,0 +1,48 @@
# @package agent
defaults:
  - /backbone@vision_backbone: resnet_diffusion
  - /modules@state_encoder: identity_state_encoder
  - /modules@action_encoder: identity_action_encoder
  - /modules@cond_projector: linear_condition_projector
  - /head: imf_transformer1d
  - _self_

_target_: roboimi.vla.agent_imf.IMFVLAAgent

action_dim: 16
obs_dim: 16
normalization_type: "min_max"
pred_horizon: 16
obs_horizon: 2
num_action_steps: 8
camera_names: ${data.camera_names}
num_cams: ${len:${agent.camera_names}}

vision_backbone:
  num_cameras: ${agent.num_cams}
  camera_names: ${agent.camera_names}
  vision_backbone: "resnet18"
  vision_backbone_mode: "resnet"
  freeze_backbone: false
  use_separate_rgb_encoder_per_camera: true
  output_tokens_per_camera: true

cond_projector:
  output_dim: ${agent.head.n_emb}

diffusion_steps: 100
inference_steps: 1
head_type: "transformer"

head:
  input_dim: ${agent.action_dim}
  output_dim: ${agent.action_dim}
  horizon: ${agent.pred_horizon}
  cond_dim: ${agent.head.n_emb}
  causal_attn: false
  time_as_cond: true
  obs_as_cond: true
  n_cond_layers: 0
  backbone_type: attnres_full
  n_head: 1
  n_kv_head: 1

@@ -29,8 +29,13 @@ num_action_steps: 8 # number of action steps actually executed per inference (should be <= p
# ====================
# Camera configuration
# ====================
camera_names: ${data.camera_names} # conditioning camera order is fixed to r_vis, top, front
num_cams: 3 # number of cameras (r_vis, top, front)

vision_backbone:
  num_cameras: ${agent.num_cams}
  camera_names: ${agent.camera_names}

# ====================
# Diffusion process configuration
# ====================
@@ -52,3 +57,6 @@ head:
  # ResNet18 + SpatialSoftmax(32 keypoints) = 64 dims per camera
  # computed as: per-camera features (64) * num cameras (3) + obs_dim (16) = 208
  cond_dim: 208
  causal_attn: false
  time_as_cond: true
  obs_as_cond: true

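The cond_dim arithmetic in the comment above can be spelled out. SpatialSoftmax typically emits an (x, y) coordinate pair per keypoint, which matches the 64-dim-per-camera figure; that ×2 convention is an assumption about the implementation:

```python
# cond_dim = per-camera features * num cameras + proprioception dim
spatial_softmax_num_keypoints = 32
per_camera_feature_dim = spatial_softmax_num_keypoints * 2  # assumed (x, y) per keypoint
num_cams = 3
obs_dim = 16
cond_dim = per_camera_feature_dim * num_cams + obs_dim
```

This reproduces the 208 used by the `head.cond_dim` entries in these configs.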
44
roboimi/vla/conf/agent/siglip2_imf_attnres.yaml
Normal file
@@ -0,0 +1,44 @@
# @package agent
defaults:
  - /backbone@vision_backbone: siglip2_diffusion
  - /modules@state_encoder: identity_state_encoder
  - /modules@action_encoder: identity_action_encoder
  - /modules@cond_projector: linear_condition_projector
  - /head: imf_transformer1d
  - _self_

_target_: roboimi.vla.agent_imf.IMFVLAAgent

action_dim: 16
obs_dim: 16
normalization_type: "min_max"
pred_horizon: 16
obs_horizon: 2
num_action_steps: 8
camera_names: ${data.camera_names}
num_cams: ${len:${agent.camera_names}}

vision_backbone:
  num_cameras: ${agent.num_cams}
  camera_names: ${agent.camera_names}

cond_projector:
  output_dim: ${agent.head.cond_dim}

diffusion_steps: 100
inference_steps: 1
head_type: "transformer"

head:
  input_dim: ${agent.action_dim}
  output_dim: ${agent.action_dim}
  horizon: ${agent.pred_horizon}
  n_obs_steps: ${agent.obs_horizon}
  cond_dim: 384
  causal_attn: false
  time_as_cond: true
  obs_as_cond: true
  n_cond_layers: 0
  backbone_type: attnres_full
  n_head: 1
  n_kv_head: 1

7
roboimi/vla/conf/backbone/lewm_resnet_query_fusion.yaml
Normal file
@@ -0,0 +1,7 @@
_target_: roboimi.vla.models.backbones.lewm_resnet_query_fusion.LeWMMultiViewResNetBackbone

view_feature_dim: 96
num_views: ${agent.num_cams}
view_encoder_mode: separate
camera_names: ${agent.camera_names}
checkpoint_path: null

16
roboimi/vla/conf/backbone/lewm_vit_diffusion.yaml
Normal file
@@ -0,0 +1,16 @@
_target_: roboimi.vla.models.backbones.lewm_vit_backbone.LEWMViTBackbone

# LEWM checkpoint path; override this on the target machine.
checkpoint_path: null

# Input camera contract for roboimi; internal LEWM fusion order stays front/top/r_vis.
num_cameras: 3
camera_names: [r_vis, top, front]
fused_camera_names: [front, top, r_vis]

freeze_backbone: true
joint_output_dim: 192
output_dim: 192
image_size: 224
dataset_image_resize_shape: null
eval_image_resize_shape: [256, 256]

@@ -5,6 +5,7 @@ _target_: roboimi.vla.models.backbones.resnet_diffusion.ResNetDiffusionBackbone
# ====================
vision_backbone: "resnet18" # torchvision model name: resnet18, resnet34, resnet50
pretrained_backbone_weights: "IMAGENET1K_V1" # use ImageNet-pretrained weights (torchvision>=0.13)
vision_backbone_mode: "resnet" # resnet | attnres_resnet

# ====================
# Freezing settings
@@ -30,4 +31,20 @@ spatial_softmax_num_keypoints: 32 # number of Spatial Softmax keypoints
# false: shared encoder (all cameras share one ResNet; fewer params but limited capacity), recommended!
# true: separate encoders (one ResNet per camera; more params, larger capacity)
use_separate_rgb_encoder_per_camera: true
# false: concatenate all camera features into one condition token; true: one independent token per camera
output_tokens_per_camera: false
num_cameras: 3 # number of cameras

# ====================
# Full-AttnRes vision trunk (takes effect when vision_backbone_mode=attnres_resnet)
# ====================
attnres_stem_dim: 64
attnres_stage_dims: [64, 128, 256, 512]
attnres_stage_depths: [2, 2, 2, 2]
attnres_stage_heads: [4, 4, 8, 8]
attnres_stage_kv_heads: [1, 1, 1, 1]
attnres_stage_window_sizes: [7, 7, 7, 7]
attnres_dropout: 0.0
attnres_ffn_mult: 2.667
attnres_eps: 1.0e-06
attnres_rope_theta: 10000.0

10
roboimi/vla/conf/backbone/siglip2_diffusion.yaml
Normal file
@@ -0,0 +1,10 @@
_target_: roboimi.vla.models.backbones.siglip2_diffusion_backbone.SigLIP2DiffusionBackbone

model_name: google/siglip2-base-patch16-256
camera_names: [r_vis, top, front]
num_cameras: 3
per_view_output_dim: 96
freeze_backbone: true

dataset_image_resize_shape: null
eval_image_resize_shape: [256, 256]

@@ -9,19 +9,33 @@ defaults:
# ====================
train:
  # Basic training parameters
  batch_size: 8 # batch size
  lr: 5e-5 # learning rate (smaller recommended for Transformers)
  batch_size: 16 # batch size
  lr: 1e-4 # learning rate
  max_steps: 100000 # max training steps
  device: "cuda" # device: "cuda" or "cpu"
  disable_cudnn: false # set true when hitting cuDNN compatibility issues on the current machine

  # Data loading
  num_workers: 8 # number of DataLoader workers (0 for debugging, 8 in production)
  val_split: 0.1 # validation split ratio
  num_workers: 12 # number of DataLoader workers (0 for debugging)
  val_split: 0.0 # validation split ratio; defaults to training on the full dataset
  val_episode_indices: null # explicit per-episode validation split, e.g. [100]
  action_mse_val_freq_epochs: 0 # when >0, compute action MSE on held-out episodes every N epochs
  seed: 42 # random seed (for data splitting)

  # Logging and checkpoints
  log_freq: 100 # logging frequency (steps)
  save_freq: 2000 # checkpoint saving frequency (steps)
  use_swanlab: false # whether to enable SwanLab scalar logging
  swanlab_project: "roboimi-vla" # SwanLab project name
  swanlab_run_name: null # optional SwanLab run name
  rollout_val_freq_epochs: 50 # run rollout validation every N epochs
  rollout_validate_on_checkpoint: false # whether to run rollout validation right after saving a checkpoint
  rollout_num_episodes: 3 # number of episodes for rollout validation
  rollout_device: ${train.device} # device for rollouts; follows the training device by default
  rollout_num_workers: null # number of parallel rollout workers; null auto-infers on CUDA, stays 1 on CPU
  rollout_cuda_devices: null # logical CUDA device list for parallel rollouts; defaults to [0] when null
  rollout_response_timeout_s: 300.0 # timeout for rollout workers waiting on inference-server responses
  rollout_server_startup_timeout_s: 300.0 # timeout for the inference server to become ready

  # Learning-rate scheduler (with warmup)
  warmup_steps: 2000 # warmup steps (longer recommended for Transformers)

@@ -11,6 +11,8 @@ dataset_dir: "roboimi/demos/dataset/sim_transfer"
# ====================
pred_horizon: ${agent.pred_horizon} # prediction steps
obs_horizon: ${agent.obs_horizon} # observation steps
lewm_history_horizon: ${oc.select:agent.lewm_history_horizon,null}
lewm_query_offsets: ${oc.select:agent.lewm_query_offsets,null}

# ====================
# Camera configuration
@@ -19,3 +21,6 @@ camera_names:
  - r_vis # robot-view camera
  - top # top camera
  - front # front camera

# Per-view pre-resize shape; null keeps the dataset's original resolution
image_resize_shape: [224, 224]

@@ -2,6 +2,10 @@
# Evaluation configuration
ckpt_path: "checkpoints/vla_model_best.pt" # model checkpoint path
num_episodes: 3 # number of evaluation episodes
num_workers: 1 # number of parallel workers; 1 keeps single-process evaluation
cuda_devices: null # logical device list for parallel CUDA evaluation; null means [0]
response_timeout_s: 300.0 # timeout (s) for workers waiting on inference-server responses
server_startup_timeout_s: 300.0 # timeout (s) for the parent waiting for the inference server
max_timesteps: 700 # max timesteps per episode
device: ${train.device} # keep consistent with training
task_name: "sim_transfer" # environment task name
@@ -29,6 +33,22 @@ smooth_alpha: 0.3
# ====================
# Debug options
# ====================
headless: false # disable MuJoCo / OpenCV GUI rendering
verbose_action: true # print per-timestep action info


# ====================
# Rollout artifact export
# ====================
artifact_dir: null # optional output dir; auto-created when export is enabled and this is empty
save_artifacts: false # master switch; still requires the specific export flags below
save_timing: false # save timing.json (per-stage timing stats)
save_trajectory: false # save trajectory.npz (raw EE actions + post-execution EE poses)
save_summary_json: false # save a JSON-friendly rollout summary
save_trajectory_npz: false # save per-step trajectories/timing/EE poses as NPZ
save_trajectory_image: false # save a static PNG with the EE trajectory overlaid in red
trajectory_image_camera: null # alias of trajectory_image_camera_name
trajectory_image_camera_name: null # camera for trajectory images; defaults to camera_names[0] when empty
record_video: false # record a rollout mp4 from a single camera stream
video_camera: null # alias of video_camera_name
video_camera_name: null # camera for video recording; defaults to camera_names[0] when empty
video_fps: 30 # target fps of the exported mp4

22
roboimi/vla/conf/head/imf_transformer1d.yaml
Normal file
@@ -0,0 +1,22 @@
_target_: roboimi.vla.models.heads.imf_transformer1d.IMFTransformer1D
_partial_: true

input_dim: ${agent.action_dim}
output_dim: ${agent.action_dim}
horizon: ${agent.pred_horizon}
n_obs_steps: ${agent.obs_horizon}
cond_dim: 208
n_layer: 12
n_head: 1
n_emb: 768
p_drop_emb: 0.1
p_drop_attn: 0.1
causal_attn: false
time_as_cond: true
obs_as_cond: true
n_cond_layers: 0
backbone_type: attnres_full
n_kv_head: 1
attn_res_ffn_mult: 2.667
attn_res_eps: 1.0e-6
attn_res_rope_theta: 10000.0

@@ -5,7 +5,7 @@ _partial_: true
# ====================
# Transformer architecture configuration
# ====================
n_layer: 4 # number of Transformer layers (small model first for convergence stability)
n_layer: 4 # number of Transformer layers (keep the current small-model setting)
n_head: 4 # number of attention heads
n_emb: 128 # embedding dim
p_drop_emb: 0.05 # Embedding dropout
@@ -14,9 +14,10 @@ p_drop_attn: 0.05 # Attention dropout
# ====================
# Conditioning configuration
# ====================
causal_attn: false # whether to use causal attention (autoregressive generation)
obs_as_cond: true # observations as conditioning (actually decided by cond_dim > 0)
n_cond_layers: 1 # condition-encoder layers (one layer for stable fusion first)
causal_attn: false # matches the external TransformerForDiffusion full-attention / nocausal variant
time_as_cond: true # consistent with the external implementation: timestep as a condition token
obs_as_cond: true # API alignment; actually enabled iff cond_dim > 0
n_cond_layers: 1 # condition-encoder layers (keep the current setting)

# ====================
# Notes

5  roboimi/vla/conf/modules/lewm_state_encoder.yaml  Normal file
@@ -0,0 +1,5 @@
_target_: roboimi.vla.modules.encoders.LeWMStateEncoder

input_dim: ${agent.obs_dim}
hidden_dim: 256
output_dim: 64

5  roboimi/vla/conf/modules/linear_condition_projector.yaml  Normal file
@@ -0,0 +1,5 @@
_target_: roboimi.vla.modules.projectors.LinearConditionProjector
_partial_: true

output_dim: 384
bias: true
@@ -1,7 +1,7 @@
import torch
import h5py
from torch.utils.data import Dataset
-from typing import List, Dict, Union
+from typing import List, Dict, Union, Optional, Sequence
from pathlib import Path
from collections import OrderedDict

@@ -22,7 +22,11 @@ class SimpleRobotDataset(Dataset):
        obs_horizon: int = 2,
        pred_horizon: int = 8,
        camera_names: List[str] = None,
+        image_resize_shape: Optional[Sequence[int]] = (224, 224),
        max_open_files: int = 64,
+        lewm_history_horizon: Optional[int] = None,
+        lewm_query_offsets: Optional[Sequence[int]] = None,
+        episode_indices: Optional[Sequence[int]] = None,
    ):
        """
        Args:
@@ -30,6 +34,7 @@ class SimpleRobotDataset(Dataset):
            obs_horizon: how many past frames to observe
            pred_horizon: how many future action frames to predict
            camera_names: list of camera names, e.g. ["r_vis", "top", "front"]
+            image_resize_shape: image resize shape (W, H); None keeps the original resolution
            max_open_files: maximum number of HDF5 file handles cached per worker

        HDF5 file format:
@@ -40,8 +45,22 @@
        self.obs_horizon = obs_horizon
        self.pred_horizon = pred_horizon
        self.camera_names = camera_names or []
+        self.lewm_history_horizon = (
+            int(lewm_history_horizon) if lewm_history_horizon is not None else None
+        )
+        self.lewm_query_offsets = (
+            tuple(int(offset) for offset in lewm_query_offsets)
+            if lewm_query_offsets is not None else ()
+        )
+        self.image_resize_shape = (
+            tuple(int(v) for v in image_resize_shape)
+            if image_resize_shape is not None else None
+        )
        self.max_open_files = max(1, int(max_open_files))
        self._file_cache: "OrderedDict[str, h5py.File]" = OrderedDict()
+        self.requested_episode_indices = (
+            None if episode_indices is None else tuple(sorted(int(idx) for idx in episode_indices))
+        )

        self.dataset_dir = Path(dataset_dir)
        if not self.dataset_dir.exists():
@@ -54,20 +73,45 @@
        if not self.hdf5_files:
            raise FileNotFoundError(f"No HDF5 files found in {dataset_dir}")

+        if self.requested_episode_indices is not None:
+            requested = set(self.requested_episode_indices)
+            filtered = []
+            for hdf5_path in self.hdf5_files:
+                stem = hdf5_path.stem
+                if stem.startswith("episode_"):
+                    try:
+                        idx = int(stem.split("_")[-1])
+                    except ValueError:
+                        continue
+                    if idx in requested:
+                        filtered.append(hdf5_path)
+            self.hdf5_files = filtered
+            if not self.hdf5_files:
+                raise FileNotFoundError(
+                    f"No HDF5 files matching episode_indices={sorted(requested)} found in {dataset_dir}"
+                )

        # Build the episode index (store metadata only, do not load data)
        self.episodes = {}
        self.frame_meta = []  # stores (ep_idx, frame_idx, hdf5_path)
        for ep_idx, hdf5_path in enumerate(self.hdf5_files):
            with h5py.File(hdf5_path, 'r') as f:
                T = f['action'].shape[0]
+            dataset_episode_idx = ep_idx
+            stem = hdf5_path.stem
+            if stem.startswith("episode_"):
+                try:
+                    dataset_episode_idx = int(stem.split("_")[-1])
+                except ValueError:
+                    pass
            start_idx = len(self.frame_meta)
            for t in range(T):
                self.frame_meta.append({
-                    "ep_idx": ep_idx,
+                    "ep_idx": dataset_episode_idx,
                    "frame_idx": t,
                    "hdf5_path": hdf5_path,
                })
-            self.episodes[ep_idx] = list(range(start_idx, len(self.frame_meta)))
+            self.episodes[dataset_episode_idx] = list(range(start_idx, len(self.frame_meta)))

        print(f"Lazy-loading mode: {len(self.hdf5_files)} episodes, {len(self.frame_meta)} frames in total")
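The filtering and indexing above key episodes by the number embedded in the filename stem (`episode_<idx>.hdf5`) rather than by enumeration order, falling back to the enumeration index when the stem does not parse. That rule can be sketched on its own:

```python
def parse_episode_index(stem, fallback):
    """Extract the trailing integer from an 'episode_*' stem; keep fallback otherwise."""
    if stem.startswith("episode_"):
        try:
            return int(stem.split("_")[-1])
        except ValueError:
            pass  # non-numeric suffix: keep the enumeration index
    return fallback

print(parse_episode_index("episode_0042", 0))   # -> 42
print(parse_episode_index("episode_extra", 7))  # -> 7 (non-numeric suffix)
print(parse_episode_index("demo_3", 5))         # -> 5 (no episode_ prefix)
```

This is why `episode_indices=[100]` in the validation recipe selects `episode_0100.hdf5` regardless of how many files sort before it.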
@@ -105,7 +149,7 @@
            self._file_cache[key] = f
        return f

-    def _load_frame(self, idx: int) -> Dict:
+    def _load_frame(self, idx: int, *, load_images: bool = True) -> Dict:
        """Lazily load a single frame from the HDF5 file"""
        meta = self.frame_meta[idx]
        f = self._get_h5_file(meta["hdf5_path"])
@@ -118,13 +162,14 @@
        }

        # Load image data: observations/images/{cam_name} -> observation.{cam_name}
+        if load_images:
            for cam_name in self.camera_names:
                h5_path = f'observations/images/{cam_name}'
                if h5_path in f:
                    img = f[h5_path][meta["frame_idx"]]
                    # Resize the image (reduces memory and I/O load)
                    if self.image_resize_shape is not None:
                        import cv2
-                        img = cv2.resize(img, (224, 224), interpolation=cv2.INTER_LINEAR)
+                        img = cv2.resize(img, self.image_resize_shape, interpolation=cv2.INTER_LINEAR)
                    # Convert to float and normalize to [0, 1]
                    img = torch.from_numpy(img).float() / 255.0
                    frame[f"observation.{cam_name}"] = img.permute(2, 0, 1)  # HWC -> CHW
@@ -132,7 +177,7 @@
        return frame

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
-        frame = self._load_frame(idx)
+        frame = self._load_frame(idx, load_images=False)
        ep_idx = frame["episode_index"]

        # Get the frame-index range of the current episode
@@ -186,10 +231,10 @@
            target_idx = idx + delta

            if target_idx <= ep_end:
-                actions.append(self._load_frame(target_idx)["action"])
+                actions.append(self._load_frame(target_idx, load_images=False)["action"])
                action_is_pad.append(False)
            else:
-                actions.append(self._load_frame(ep_end)["action"])
+                actions.append(self._load_frame(ep_end, load_images=False)["action"])
                action_is_pad.append(True)
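The hunk above pads future actions past the episode end by repeating the final action and flagging the copies via `action_is_pad`. The same scheme on plain lists (the helper name and the zero-based delta loop are illustrative simplifications of the dataset code):

```python
def pad_future_actions(actions, idx, pred_horizon):
    """Collect pred_horizon future actions starting at idx, repeating the last one past the end."""
    ep_end = len(actions) - 1
    out, is_pad = [], []
    for delta in range(pred_horizon):
        target = idx + delta
        if target <= ep_end:
            out.append(actions[target])
            is_pad.append(False)
        else:
            out.append(actions[ep_end])  # clamp to the final action
            is_pad.append(True)          # mark the repeated frame as padding
    return out, is_pad

acts = [10, 11, 12]
print(pad_future_actions(acts, 1, 4))  # -> ([11, 12, 12, 12], [False, False, True, True])
```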
        # ============================================
@@ -213,6 +258,60 @@
        for cam_name in self.camera_names:
            result[f"observation.{cam_name}"] = torch.stack(observations[f"observation.{cam_name}"])

+        if self.lewm_history_horizon is not None and self.lewm_history_horizon > 0:
+            lewm_observations = {
+                "state": [],
+            }
+            for cam_name in self.camera_names:
+                lewm_observations[f"observation.{cam_name}"] = []
+
+            for delta in range(-self.lewm_history_horizon + 1, 1):
+                target_idx = idx + delta
+                if ep_start <= target_idx <= ep_end:
+                    target_frame = self._load_frame(target_idx)
+                else:
+                    boundary_idx = ep_start if target_idx < ep_start else ep_end
+                    target_frame = self._load_frame(boundary_idx)
+
+                lewm_observations["state"].append(target_frame["observation.state"])
+                for cam_name in self.camera_names:
+                    lewm_observations[f"observation.{cam_name}"].append(
+                        target_frame[f"observation.{cam_name}"]
+                    )
+
+            result["lewm.observation.state"] = torch.stack(lewm_observations["state"])
+            for cam_name in self.camera_names:
+                result[f"lewm.observation.{cam_name}"] = torch.stack(
+                    lewm_observations[f"observation.{cam_name}"]
+                )
+
+        if self.lewm_query_offsets:
+            lewm_future = {
+                "state": [],
+            }
+            for cam_name in self.camera_names:
+                lewm_future[f"observation.{cam_name}"] = []
+
+            for offset in self.lewm_query_offsets:
+                target_idx = idx + offset
+                if ep_start <= target_idx <= ep_end:
+                    target_frame = self._load_frame(target_idx)
+                else:
+                    boundary_idx = ep_start if target_idx < ep_start else ep_end
+                    target_frame = self._load_frame(boundary_idx)
+
+                lewm_future["state"].append(target_frame["observation.state"])
+                for cam_name in self.camera_names:
+                    lewm_future[f"observation.{cam_name}"].append(
+                        target_frame[f"observation.{cam_name}"]
+                    )
+
+            result["lewm.future.state"] = torch.stack(lewm_future["state"])
+            for cam_name in self.camera_names:
+                result[f"lewm.future.{cam_name}"] = torch.stack(
+                    lewm_future[f"observation.{cam_name}"]
+                )

        return result

    @property
@@ -220,6 +319,10 @@
        """Get all camera keys (LeRobotDataset format)"""
        return [f"observation.{cam_name}" for cam_name in self.camera_names]

+    @property
+    def available_episode_indices(self) -> List[int]:
+        return sorted(self.episodes.keys())

    @property
    def camera_info(self) -> dict:
        """Get camera info"""
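Both the LeWM history branch and the future-query branch above clamp out-of-range frame indices to the episode boundaries before loading. A sketch of that clamping for a history window and a set of query offsets:

```python
def clamp_indices(idx, deltas, ep_start, ep_end):
    """Map relative offsets around idx to frame indices, clamped to [ep_start, ep_end]."""
    out = []
    for delta in deltas:
        target = idx + delta
        if not (ep_start <= target <= ep_end):
            target = ep_start if target < ep_start else ep_end
        out.append(target)
    return out

# lewm_history_horizon = 3 samples deltas -2, -1, 0
history = clamp_indices(1, range(-3 + 1, 1), ep_start=0, ep_end=9)
# lewm_query_offsets = (1, 4) samples frames ahead of idx
future = clamp_indices(8, (1, 4), ep_start=0, ep_end=9)
print(history)  # -> [0, 0, 1]: early frames clamp to ep_start
print(future)   # -> [9, 9]:    late queries clamp to ep_end
```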
3  roboimi/vla/eval_utils.py  Normal file
@@ -0,0 +1,3 @@
def execute_policy_action(env, action):
    """Execute policy outputs using EE-action semantics."""
    env.step(action)
@@ -1,4 +1,37 @@
# Backbone models
from .resnet_diffusion import ResNetDiffusionBackbone
+__all__ = [
+    "LEWMViTBackbone",
+    "LeWMMultiViewResNetBackbone",
+    "QueryTokenPredictor",
+    "LeWMProjectorMLP",
+    "SIGReg",
+    "ResNetBackbone",
+    "ResNetDiffusionBackbone",
+    "SigLIP2DiffusionBackbone",
+]

-__all__ = ["ResNetBackbone", "ResNetDiffusionBackbone"]

+def __getattr__(name):
+    if name == "LEWMViTBackbone":
+        from .lewm_vit_backbone import LEWMViTBackbone
+        return LEWMViTBackbone
+    if name == "SigLIP2DiffusionBackbone":
+        from .siglip2_diffusion_backbone import SigLIP2DiffusionBackbone
+        return SigLIP2DiffusionBackbone
+    if name in {"LeWMMultiViewResNetBackbone", "QueryTokenPredictor", "LeWMProjectorMLP", "SIGReg"}:
+        from .lewm_resnet_query_fusion import (
+            LeWMMultiViewResNetBackbone,
+            QueryTokenPredictor,
+            LeWMProjectorMLP,
+            SIGReg,
+        )
+        return {
+            "LeWMMultiViewResNetBackbone": LeWMMultiViewResNetBackbone,
+            "QueryTokenPredictor": QueryTokenPredictor,
+            "LeWMProjectorMLP": LeWMProjectorMLP,
+            "SIGReg": SIGReg,
+        }[name]
+    if name in {"ResNetBackbone", "ResNetDiffusionBackbone"}:
+        from .resnet_diffusion import ResNetDiffusionBackbone
+        return ResNetDiffusionBackbone
+    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
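The `__init__.py` diff above uses a module-level `__getattr__` (PEP 562, Python 3.7+) so that heavy backbones are only imported on first attribute access. A minimal stand-alone version of the pattern, using a synthetic module rather than the real package:

```python
import sys
import types

# Build a throwaway module to demonstrate lazy exports (PEP 562).
mod = types.ModuleType("demo_backbones")
mod.__all__ = ["HeavyBackbone"]

def _module_getattr(name):
    if name == "HeavyBackbone":
        class HeavyBackbone:  # stands in for an expensive import
            pass
        return HeavyBackbone
    raise AttributeError(f"module 'demo_backbones' has no attribute {name!r}")

# Assigning into the module namespace is equivalent to defining
# __getattr__ at module top level.
mod.__getattr__ = _module_getattr
sys.modules["demo_backbones"] = mod

import demo_backbones
print(demo_backbones.HeavyBackbone.__name__)  # -> HeavyBackbone
```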
228  roboimi/vla/models/backbones/attnres_resnet2d.py  Normal file
@@ -0,0 +1,228 @@
from __future__ import annotations

from typing import Iterable, Sequence

import torch
import torch.nn as nn
import torch.nn.functional as F

from roboimi.vla.models.heads.attnres_transformer_components import AttnResTransformerBackbone


def _make_norm2d(num_channels: int, use_group_norm: bool) -> nn.Module:
    if use_group_norm:
        num_groups = max(1, num_channels // 16)
        return nn.GroupNorm(num_groups=num_groups, num_channels=num_channels)
    return nn.BatchNorm2d(num_channels)


class _ConvNormAct(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        *,
        kernel_size: int,
        stride: int,
        padding: int,
        use_group_norm: bool,
    ) -> None:
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                bias=False,
            ),
            _make_norm2d(out_channels, use_group_norm),
            nn.SiLU(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.block(x)


class AttnResImageBlock2D(nn.Module):
    def __init__(
        self,
        dim: int,
        *,
        window_size: int,
        n_heads: int,
        n_kv_heads: int,
        dropout: float,
        ffn_mult: float,
        eps: float,
        rope_theta: float,
    ) -> None:
        super().__init__()
        self.window_size = int(window_size)
        self.block = AttnResTransformerBackbone(
            d_model=dim,
            n_blocks=1,
            n_heads=n_heads,
            n_kv_heads=n_kv_heads,
            max_seq_len=self.window_size * self.window_size,
            dropout=dropout,
            ffn_mult=ffn_mult,
            eps=eps,
            rope_theta=rope_theta,
            causal_attn=False,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        bsz, channels, height, width = x.shape
        ws = self.window_size
        pad_h = (ws - height % ws) % ws
        pad_w = (ws - width % ws) % ws
        if pad_h or pad_w:
            x = F.pad(x, (0, pad_w, 0, pad_h))
        padded_height, padded_width = x.shape[-2:]
        num_h = padded_height // ws
        num_w = padded_width // ws

        windows = (
            x.permute(0, 2, 3, 1)
            .contiguous()
            .view(bsz, num_h, ws, num_w, ws, channels)
            .permute(0, 1, 3, 2, 4, 5)
            .contiguous()
            .view(bsz * num_h * num_w, ws * ws, channels)
        )
        windows = self.block(windows)
        x = (
            windows.view(bsz, num_h, num_w, ws, ws, channels)
            .permute(0, 1, 3, 2, 4, 5)
            .contiguous()
            .view(bsz, padded_height, padded_width, channels)
            .permute(0, 3, 1, 2)
            .contiguous()
        )
        return x[:, :, :height, :width]
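`AttnResImageBlock2D.forward` above pads the feature map, reshapes it into non-overlapping ws×ws windows for attention, then inverts the reshape and crops the padding. The partition/merge round-trip can be checked with NumPy alone (identity attention assumed; the helper names are illustrative):

```python
import numpy as np

def partition_windows(x, ws):
    """(B, C, H, W) -> (B*nH*nW, ws*ws, C); H and W assumed divisible by ws."""
    b, c, h, w = x.shape
    nh, nw = h // ws, w // ws
    t = x.transpose(0, 2, 3, 1).reshape(b, nh, ws, nw, ws, c)
    return t.transpose(0, 1, 3, 2, 4, 5).reshape(b * nh * nw, ws * ws, c)

def merge_windows(windows, b, h, w, ws, c):
    """Exact inverse of partition_windows."""
    nh, nw = h // ws, w // ws
    t = windows.reshape(b, nh, nw, ws, ws, c).transpose(0, 1, 3, 2, 4, 5)
    return t.reshape(b, h, w, c).transpose(0, 3, 1, 2)

x = np.arange(2 * 3 * 8 * 8, dtype=np.float32).reshape(2, 3, 8, 8)
win = partition_windows(x, ws=4)
print(win.shape)                                             # (8, 16, 3)
print(np.array_equal(merge_windows(win, 2, 8, 8, 4, 3), x))  # True
```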
class _AttnResStage2D(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        *,
        depth: int,
        downsample_stride: int,
        use_group_norm: bool,
        window_size: int,
        n_heads: int,
        n_kv_heads: int,
        dropout: float,
        ffn_mult: float,
        eps: float,
        rope_theta: float,
    ) -> None:
        super().__init__()
        self.downsample = None
        if in_channels != out_channels or downsample_stride != 1:
            kernel_size = 1 if downsample_stride == 1 else 3
            padding = 0 if downsample_stride == 1 else 1
            self.downsample = _ConvNormAct(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                stride=downsample_stride,
                padding=padding,
                use_group_norm=use_group_norm,
            )

        self.blocks = nn.ModuleList(
            [
                AttnResImageBlock2D(
                    out_channels,
                    window_size=window_size,
                    n_heads=n_heads,
                    n_kv_heads=n_kv_heads,
                    dropout=dropout,
                    ffn_mult=ffn_mult,
                    eps=eps,
                    rope_theta=rope_theta,
                )
                for _ in range(depth)
            ]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.downsample is not None:
            x = self.downsample(x)
        for block in self.blocks:
            x = block(x)
        return x


class AttnResResNetLikeBackbone2D(nn.Module):
    def __init__(
        self,
        *,
        input_channels: int = 3,
        stem_dim: int = 64,
        stage_dims: Sequence[int] = (64, 128, 256, 512),
        stage_depths: Sequence[int] = (2, 2, 2, 2),
        stage_heads: Sequence[int] = (4, 4, 8, 8),
        stage_kv_heads: Sequence[int] = (1, 1, 1, 1),
        stage_window_sizes: Sequence[int] = (7, 7, 7, 7),
        use_group_norm: bool = True,
        dropout: float = 0.0,
        ffn_mult: float = 2.667,
        eps: float = 1e-6,
        rope_theta: float = 10000.0,
    ) -> None:
        super().__init__()

        lengths = {
            len(stage_dims),
            len(stage_depths),
            len(stage_heads),
            len(stage_kv_heads),
            len(stage_window_sizes),
        }
        if len(lengths) != 1:
            raise ValueError('stage_dims/depths/heads/kv_heads/window_sizes must all have the same length')

        self.stem = nn.Sequential(
            nn.Conv2d(input_channels, stem_dim, kernel_size=7, stride=2, padding=3, bias=False),
            _make_norm2d(stem_dim, use_group_norm),
            nn.SiLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )

        in_channels = stem_dim
        stages = []
        for stage_idx, (out_channels, depth, n_heads, n_kv_heads, window_size) in enumerate(
            zip(stage_dims, stage_depths, stage_heads, stage_kv_heads, stage_window_sizes)
        ):
            stages.append(
                _AttnResStage2D(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    depth=int(depth),
                    downsample_stride=1 if stage_idx == 0 else 2,
                    use_group_norm=use_group_norm,
                    window_size=int(window_size),
                    n_heads=int(n_heads),
                    n_kv_heads=int(n_kv_heads),
                    dropout=float(dropout),
                    ffn_mult=float(ffn_mult),
                    eps=float(eps),
                    rope_theta=float(rope_theta),
                )
            )
            in_channels = out_channels

        self.stages = nn.ModuleList(stages)
        self.output_channels = in_channels

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.stem(x)
        for stage in self.stages:
            x = stage(x)
        return x
409  roboimi/vla/models/backbones/lewm_resnet_query_fusion.py  Normal file
@@ -0,0 +1,409 @@
from __future__ import annotations

from pathlib import Path
from typing import Any, Dict, Mapping, Optional, Sequence

import torch
from einops import rearrange
from torch import nn
import torch.nn.functional as F
from torchvision import models

from roboimi.vla.core.interfaces import VLABackbone


class SpatialSoftmax2D(nn.Module):
    """Convert a feature map into expected 2D keypoint coordinates per channel."""

    def forward(self, feature_map):
        if feature_map.ndim != 4:
            raise ValueError(
                f"SpatialSoftmax2D expects a 4D tensor, got rank {feature_map.ndim}"
            )

        batch, channels, height, width = feature_map.shape
        scores = feature_map.reshape(batch, channels, height * width)
        attention = F.softmax(scores, dim=-1)

        ys = torch.linspace(-1.0, 1.0, height, device=feature_map.device, dtype=feature_map.dtype)
        xs = torch.linspace(-1.0, 1.0, width, device=feature_map.device, dtype=feature_map.dtype)
        grid_y, grid_x = torch.meshgrid(ys, xs, indexing="ij")
        grid_x = grid_x.reshape(1, 1, height * width)
        grid_y = grid_y.reshape(1, 1, height * width)

        expected_x = (attention * grid_x).sum(dim=-1)
        expected_y = (attention * grid_y).sum(dim=-1)
        return torch.cat([expected_x, expected_y], dim=-1)
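`SpatialSoftmax2D` computes, per channel, the softmax-weighted expectation of a [-1, 1] coordinate grid, so a sharply peaked map should yield coordinates at the peak. A quick NumPy check mirroring the class above (a sketch, not a substitute for the torch module):

```python
import numpy as np

def spatial_softmax(fmap):
    """fmap: (C, H, W) -> per-channel expected (x, y) in [-1, 1]."""
    c, h, w = fmap.shape
    scores = fmap.reshape(c, h * w)
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    attn = np.exp(scores)
    attn /= attn.sum(axis=-1, keepdims=True)
    ys, xs = np.meshgrid(np.linspace(-1, 1, h), np.linspace(-1, 1, w), indexing="ij")
    ex = (attn * xs.reshape(1, -1)).sum(-1)
    ey = (attn * ys.reshape(1, -1)).sum(-1)
    return np.stack([ex, ey], axis=-1)

fmap = np.zeros((1, 5, 5))
fmap[0, 0, 4] = 50.0  # strong peak in the top-right corner
print(np.round(spatial_softmax(fmap)[0], 3))  # ~[ 1. -1.]: x = +1 (right), y = -1 (top)
```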
class ResNet18SpatialEncoder(nn.Module):
    """Encode one camera view into a fixed-dimensional spatial-softmax embedding."""

    def __init__(self, view_feature_dim=96):
        super().__init__()
        if view_feature_dim % 2 != 0:
            raise ValueError("view_feature_dim must be even for spatial softmax features")

        backbone = models.resnet18(weights=None)
        if all(
            hasattr(backbone, name)
            for name in ("conv1", "bn1", "relu", "maxpool", "layer1", "layer2", "layer3", "layer4")
        ):
            self.backbone = nn.Sequential(
                backbone.conv1,
                backbone.bn1,
                backbone.relu,
                backbone.maxpool,
                backbone.layer1,
                backbone.layer2,
                backbone.layer3,
                backbone.layer4,
            )
            feature_channels = 512
        else:
            children = list(backbone.children())
            if len(children) < 1:
                raise ValueError("resnet18 backbone must expose child modules")
            truncated = children[:-2] if len(children) > 2 else children
            self.backbone = nn.Sequential(*truncated)
            with torch.no_grad():
                dummy = torch.zeros(1, 3, 16, 16)
                feature_channels = int(self.backbone(dummy).shape[1])

        self.proj = nn.Conv2d(feature_channels, view_feature_dim // 2, kernel_size=1)
        self.spatial_softmax = SpatialSoftmax2D()
        self.output_dim = int(view_feature_dim)

    def forward(self, pixels):
        if pixels.ndim not in (4, 5):
            raise ValueError(
                f"ResNet18SpatialEncoder expects a 4D or 5D tensor, got rank {pixels.ndim}"
            )

        needs_unflatten = pixels.ndim == 5
        if needs_unflatten:
            batch, steps, channels, height, width = pixels.shape
            pixels = rearrange(pixels, "b t c h w -> (b t) c h w")

        features = self.backbone(pixels.float())
        features = self.proj(features)
        embeddings = self.spatial_softmax(features)

        if needs_unflatten:
            embeddings = rearrange(embeddings, "(b t) d -> b t d", b=batch, t=steps)
        return embeddings
class LeWMMultiViewResNetBackbone(VLABackbone):
    """RoboIMI-side LeWM multiview ResNet spatial-softmax encoder."""

    def __init__(
        self,
        view_feature_dim: int = 96,
        num_views: int = 3,
        view_encoder_mode: str = "shared",
        camera_names: Sequence[str] = ("r_vis", "top", "front"),
        checkpoint_path: str | Path | None = None,
    ) -> None:
        super().__init__()
        if view_encoder_mode not in {"shared", "separate"}:
            raise ValueError(
                f"view_encoder_mode must be 'shared' or 'separate', got {view_encoder_mode}"
            )

        self.view_feature_dim = int(view_feature_dim)
        self.num_views = int(num_views)
        self.view_encoder_mode = view_encoder_mode
        self.camera_names = tuple(camera_names)
        if len(self.camera_names) != self.num_views:
            raise ValueError(
                f"camera_names length({len(self.camera_names)}) must equal num_views({self.num_views})"
            )
        self.output_dim = self.view_feature_dim * self.num_views
        self.joint_output_dim = self.output_dim
        self.tokens_per_step = 1

        if view_encoder_mode == "shared":
            self.single_view_encoder = ResNet18SpatialEncoder(
                view_feature_dim=view_feature_dim
            )
            self.view_encoders = None
        else:
            self.single_view_encoder = None
            self.view_encoders = nn.ModuleList(
                [
                    ResNet18SpatialEncoder(view_feature_dim=view_feature_dim)
                    for _ in range(num_views)
                ]
            )

        if checkpoint_path is not None:
            self.load_lewm_checkpoint(checkpoint_path)

    @staticmethod
    def _unwrap_state_dict(payload: Mapping[str, Any]) -> Mapping[str, torch.Tensor]:
        state_dict = payload.get("state_dict", payload)
        if not isinstance(state_dict, Mapping):
            raise TypeError("checkpoint payload must contain a mapping state_dict")
        return state_dict

    @staticmethod
    def _extract_prefixed_state_dict(
        state_dict: Mapping[str, torch.Tensor],
        prefix: str,
    ) -> Dict[str, torch.Tensor]:
        extracted = {
            key[len(prefix):]: value
            for key, value in state_dict.items()
            if key.startswith(prefix)
        }
        if not extracted:
            raise KeyError(f"checkpoint missing parameters with prefix {prefix!r}")
        return extracted

    def load_lewm_checkpoint(self, checkpoint_or_path: str | Path | Mapping[str, Any]) -> None:
        if isinstance(checkpoint_or_path, (str, Path)):
            payload = torch.load(Path(checkpoint_or_path), map_location="cpu", weights_only=False)
        else:
            payload = checkpoint_or_path
        state_dict = self._unwrap_state_dict(payload)
        encoder_state_dict = self._extract_prefixed_state_dict(state_dict, "model.encoder.")
        self.load_state_dict(encoder_state_dict, strict=True)

    def forward(self, images):
        missing = [camera_name for camera_name in self.camera_names if camera_name not in images]
        if missing:
            raise ValueError(
                f"image input missing required cameras. missing={missing}, expected={list(self.camera_names)}"
            )

        first_image = images[self.camera_names[0]]
        batch_size, steps = first_image.shape[:2]
        view_embeddings = []
        if self.view_encoder_mode == "shared":
            for camera_name in self.camera_names:
                view_embeddings.append(self.single_view_encoder(images[camera_name]))
        else:
            for single_view_encoder, camera_name in zip(self.view_encoders, self.camera_names):
                view_embeddings.append(single_view_encoder(images[camera_name]))

        embeddings = torch.cat(view_embeddings, dim=-1)
        return embeddings.reshape(batch_size, steps, self.output_dim)
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout=0.0):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)


class Attention(nn.Module):
    def __init__(self, dim, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)
        self.heads = heads
        self.dropout = dropout
        self.norm = nn.LayerNorm(dim)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        self.to_out = (
            nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
            if project_out
            else nn.Identity()
        )

    def forward(self, x, causal=True):
        x = self.norm(x)
        drop = self.dropout if self.training else 0.0
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = (rearrange(t, "b t (h d) -> b h t d", h=self.heads) for t in qkv)
        out = F.scaled_dot_product_attention(q, k, v, dropout_p=drop, is_causal=causal)
        out = rearrange(out, "b h t d -> b t (h d)")
        return self.to_out(out)


class Block(nn.Module):
    def __init__(self, dim, heads, dim_head, mlp_dim, dropout=0.0):
        super().__init__()
        self.attn = Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)
        self.mlp = FeedForward(dim, mlp_dim, dropout=dropout)
        self.norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x


class Transformer(nn.Module):
    def __init__(
        self,
        input_dim,
        hidden_dim,
        output_dim,
        depth,
        heads,
        dim_head,
        mlp_dim,
        dropout=0.0,
        block_class=Block,
    ):
        super().__init__()
        self.norm = nn.LayerNorm(hidden_dim)
        self.layers = nn.ModuleList([])

        self.input_proj = (
            nn.Linear(input_dim, hidden_dim)
            if input_dim != hidden_dim
            else nn.Identity()
        )
        self.cond_proj = (
            nn.Linear(input_dim, hidden_dim)
            if input_dim != hidden_dim
            else nn.Identity()
        )
        self.output_proj = (
            nn.Linear(hidden_dim, output_dim)
            if hidden_dim != output_dim
            else nn.Identity()
        )

        for _ in range(depth):
            self.layers.append(block_class(hidden_dim, heads, dim_head, mlp_dim, dropout))

    def forward(self, x, c=None):
        x = self.input_proj(x)
        if c is not None:
            c = self.cond_proj(c)
        for block in self.layers:
            x = block(x)
        x = self.norm(x)
        return self.output_proj(x)
class QueryTokenPredictor(nn.Module):
|
||||
"""History-only transformer predictor that decodes learned query tokens."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
num_frames,
|
||||
query_offsets,
|
||||
depth,
|
||||
heads,
|
||||
mlp_dim,
|
||||
input_dim,
|
||||
hidden_dim,
|
||||
output_dim=None,
|
||||
dim_head=64,
|
||||
dropout=0.0,
|
||||
emb_dropout=0.0,
|
||||
):
|
||||
super().__init__()
|
||||
if num_frames <= 0:
|
||||
raise ValueError(f"num_frames must be positive, got {num_frames}")
|
||||
|
||||
query_offsets = tuple(query_offsets)
|
||||
if not query_offsets:
|
||||
raise ValueError("query_offsets must contain at least one offset")
|
||||
if any(offset <= 0 for offset in query_offsets):
|
||||
raise ValueError(f"query_offsets must be positive, got {query_offsets}")
|
||||
|
||||
self.num_frames = int(num_frames)
|
||||
self.query_offsets = query_offsets
|
||||
self.num_query_tokens = len(query_offsets)
|
||||
self.pos_embedding = nn.Parameter(
|
||||
torch.randn(1, self.num_frames + self.num_query_tokens, input_dim)
|
||||
)
|
||||
self.query_tokens = nn.Parameter(
|
||||
torch.randn(1, self.num_query_tokens, input_dim)
|
||||
)
|
||||
self.dropout = nn.Dropout(emb_dropout)
|
||||
self.transformer = Transformer(
|
||||
input_dim,
|
||||
hidden_dim,
|
||||
output_dim or input_dim,
|
||||
depth,
|
||||
heads,
|
||||
dim_head,
|
||||
mlp_dim,
|
||||
dropout,
|
||||
block_class=Block,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
if x.ndim != 3:
|
||||
raise ValueError(
|
||||
f"QueryTokenPredictor expects a 3D tensor, got rank {x.ndim}"
|
||||
)
|
||||
|
||||
T = x.size(1)
|
||||
if T > self.num_frames:
|
||||
raise ValueError(
|
||||
f"input sequence length {T} exceeds configured num_frames {self.num_frames}"
|
||||
)
|
||||
|
||||
query_tokens = self.query_tokens.expand(x.size(0), -1, -1)
|
||||
tokens = torch.cat([x, query_tokens], dim=1)
|
||||
tokens = tokens + self.pos_embedding[:, : tokens.size(1)]
|
||||
tokens = self.dropout(tokens)
|
||||
tokens = self.transformer(tokens)
|
||||
return tokens[:, -self.num_query_tokens :]
|
||||
|
||||
|
||||
class LeWMProjectorMLP(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
input_dim: int = 288,
|
||||
hidden_dim: int = 2048,
|
||||
output_dim: int = 288,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.output_dim = int(output_dim)
|
||||
self.net = nn.Sequential(
|
||||
nn.Linear(int(input_dim), int(hidden_dim)),
|
||||
nn.BatchNorm1d(int(hidden_dim)),
|
||||
nn.GELU(),
|
||||
nn.Linear(int(hidden_dim), self.output_dim),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
|
||||

class SIGReg(nn.Module):
    """Sketch Isotropic Gaussian Regularizer, matching the original LeWM design."""

    def __init__(self, knots: int = 17, num_proj: int = 1024) -> None:
        super().__init__()
        self.num_proj = int(num_proj)
        t = torch.linspace(0, 3, int(knots), dtype=torch.float32)
        dt = 3 / (int(knots) - 1)
        weights = torch.full((int(knots),), 2 * dt, dtype=torch.float32)
        weights[[0, -1]] = dt
        window = torch.exp(-t.square() / 2.0)
        self.register_buffer("t", t)
        self.register_buffer("phi", window)
        self.register_buffer("weights", weights * window)

    def forward(self, proj: torch.Tensor) -> torch.Tensor:
        """proj: (T, B, D)"""
        A = torch.randn(proj.size(-1), self.num_proj, device=proj.device)
        A = A.div_(A.norm(p=2, dim=0))
        x_t = (proj @ A).unsqueeze(-1) * self.t
        err = (x_t.cos().mean(-3) - self.phi).square() + x_t.sin().mean(-3).square()
        statistic = (err @ self.weights) * proj.size(-2)
        return statistic.mean()
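The statistic above compares the empirical characteristic function of random 1-d projections of the embeddings against that of a standard Gaussian, `exp(-t^2/2)`. A minimal NumPy sketch of the same test (not the torch module above; `sigreg_statistic` is a hypothetical helper name):

```python
import numpy as np

def sigreg_statistic(z, knots=17, num_proj=256, seed=0):
    """SIGReg-style score for samples z of shape (N, D): near 0 when projections look N(0, 1)."""
    rng = np.random.default_rng(seed)
    A = rng.standard_normal((z.shape[1], num_proj))
    A /= np.linalg.norm(A, axis=0)            # unit-norm random projection directions
    t = np.linspace(0.0, 3.0, knots)
    dt = 3.0 / (knots - 1)
    w = np.full(knots, 2.0 * dt)
    w[[0, -1]] = dt                           # reduced weight at the endpoints
    phi = np.exp(-t ** 2 / 2.0)               # characteristic function of N(0, 1)
    w = w * phi
    x_t = (z @ A)[..., None] * t              # (N, num_proj, knots)
    err = (np.cos(x_t).mean(0) - phi) ** 2 + np.sin(x_t).mean(0) ** 2
    return float((err @ w).mean() * z.shape[0])
```

Standard-normal samples should score much lower than samples with the wrong scale, which is exactly what makes the term usable as a regularizer.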
roboimi/vla/models/backbones/lewm_vit_backbone.py (new file, 230 lines)
@@ -0,0 +1,230 @@
from __future__ import annotations

from pathlib import Path
from typing import Any, Dict, Mapping, Sequence

import torch
import torch.nn as nn
import torch.nn.functional as F

from roboimi.vla.core.interfaces import VLABackbone


class _LEWMProjector(nn.Module):
    """LEWM projector MLP: 192 -> 2048 -> 192 with BatchNorm1d + GELU."""

    def __init__(self, input_dim: int = 192, hidden_dim: int = 2048, output_dim: int = 192) -> None:
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


class LEWMViTBackbone(VLABackbone):
    """Frozen LEWM joint-multiview ViT backbone.

    The backbone fuses the three camera views into a single LeWM-style image,
    runs a ViT-tiny encoder plus the LEWM projector, and returns one joint
    192-d embedding per timestep.
    """

    def __init__(
        self,
        checkpoint_path: str | Path | None = None,
        *,
        checkpoint: Mapping[str, Any] | None = None,
        camera_names: Sequence[str] = ("r_vis", "top", "front"),
        fused_camera_names: Sequence[str] = ("front", "top", "r_vis"),
        num_cameras: int | None = None,
        dataset_image_resize_shape: Sequence[int] | None = None,
        eval_image_resize_shape: Sequence[int] | None = (256, 256),
        freeze_backbone: bool = True,
        joint_output_dim: int = 192,
        image_size: int = 224,
        output_dim: int = 192,
    ) -> None:
        super().__init__()

        self.camera_names = tuple(camera_names)
        self.fused_camera_names = tuple(fused_camera_names)
        self.num_cameras = int(num_cameras) if num_cameras is not None else len(self.camera_names)
        self.freeze_backbone = bool(freeze_backbone)
        self.joint_output_dim = int(joint_output_dim)
        self.image_size = int(image_size)
        self._output_dim = int(output_dim)
        self.dataset_image_resize_shape = (
            tuple(int(v) for v in dataset_image_resize_shape)
            if dataset_image_resize_shape is not None else None
        )
        self.eval_image_resize_shape = (
            tuple(int(v) for v in eval_image_resize_shape)
            if eval_image_resize_shape is not None else None
        )
        if self.num_cameras != len(self.camera_names):
            raise ValueError(
                f"num_cameras({self.num_cameras}) must match len(camera_names)({len(self.camera_names)})"
            )
        if set(self.fused_camera_names) != set(self.camera_names):
            raise ValueError(
                "fused_camera_names must contain the same cameras as camera_names. "
                f"got camera_names={list(self.camera_names)}, fused_camera_names={list(self.fused_camera_names)}"
            )

        self.encoder = self._build_encoder(self.image_size)
        self.projector = _LEWMProjector(
            input_dim=self.encoder.config.hidden_size,
            hidden_dim=2048,
            output_dim=self.joint_output_dim,
        )

        self.register_buffer(
            "mean",
            torch.tensor([0.485, 0.456, 0.406], dtype=torch.float32).view(1, 3, 1, 1),
        )
        self.register_buffer(
            "std",
            torch.tensor([0.229, 0.224, 0.225], dtype=torch.float32).view(1, 3, 1, 1),
        )

        if checkpoint_path is not None and checkpoint is not None:
            raise ValueError("checkpoint_path and checkpoint cannot both be provided")
        if checkpoint_path is not None:
            self.load_lewm_checkpoint(checkpoint_path)
        elif checkpoint is not None:
            self.load_lewm_checkpoint(checkpoint)

        if self.freeze_backbone:
            self._freeze_encoder_and_projector()

    @staticmethod
    def _build_encoder_config(image_size: int):
        from transformers import ViTConfig

        return ViTConfig(
            image_size=image_size,
            patch_size=14,
            num_channels=3,
            hidden_size=192,
            intermediate_size=768,
            num_hidden_layers=12,
            num_attention_heads=3,
            qkv_bias=True,
            hidden_dropout_prob=0.0,
            attention_probs_dropout_prob=0.0,
        )

    @classmethod
    def _build_encoder(cls, image_size: int) -> nn.Module:
        from transformers import ViTModel

        return ViTModel(cls._build_encoder_config(image_size), add_pooling_layer=False)

    @staticmethod
    def _unwrap_state_dict(payload: Mapping[str, Any]) -> Mapping[str, torch.Tensor]:
        state_dict = payload.get("state_dict", payload)
        if not isinstance(state_dict, Mapping):
            raise TypeError("checkpoint payload must contain a mapping state_dict")
        return state_dict

    @staticmethod
    def _extract_prefixed_state_dict(
        state_dict: Mapping[str, torch.Tensor],
        prefix: str,
    ) -> Dict[str, torch.Tensor]:
        extracted = {
            key[len(prefix) :]: value
            for key, value in state_dict.items()
            if key.startswith(prefix)
        }
        if not extracted:
            raise KeyError(f"checkpoint missing parameters with prefix {prefix!r}")
        return extracted

    def load_lewm_checkpoint(self, checkpoint_or_path: str | Path | Mapping[str, Any]) -> None:
        if isinstance(checkpoint_or_path, (str, Path)):
            payload = torch.load(Path(checkpoint_or_path), map_location="cpu", weights_only=False)
        else:
            payload = checkpoint_or_path

        state_dict = self._unwrap_state_dict(payload)
        encoder_state_dict = self._extract_prefixed_state_dict(state_dict, "model.encoder.")
        projector_state_dict = self._extract_prefixed_state_dict(state_dict, "model.projector.")

        self.encoder.load_state_dict(encoder_state_dict, strict=True)
        self.projector.load_state_dict(projector_state_dict, strict=True)

    def _freeze_encoder_and_projector(self) -> None:
        for module in (self.encoder, self.projector):
            module.eval()
            for parameter in module.parameters():
                parameter.requires_grad = False

    def train(self, mode: bool = True) -> "LEWMViTBackbone":
        super().train(mode)
        if self.freeze_backbone:
            self._freeze_encoder_and_projector()
        return self
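The checkpoint-loading path above hinges on stripping a module prefix (`model.encoder.`, `model.projector.`) from flat state-dict keys before `load_state_dict`. The key surgery in isolation, as a plain-dict sketch (`extract_prefixed` is a hypothetical stand-in for `_extract_prefixed_state_dict`):

```python
def extract_prefixed(state_dict, prefix):
    """Return entries whose keys start with prefix, with the prefix removed."""
    extracted = {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}
    if not extracted:
        raise KeyError(f"no parameters found with prefix {prefix!r}")
    return extracted

# Toy checkpoint with the same key layout the loader expects.
ckpt = {
    "model.encoder.layer.0.weight": 1,
    "model.encoder.layer.0.bias": 2,
    "model.projector.net.0.weight": 3,
}
encoder_sd = extract_prefixed(ckpt, "model.encoder.")   # keys: layer.0.weight, layer.0.bias
```

Raising on an empty result (rather than silently loading nothing) is what lets a mismatched checkpoint fail fast.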

    def _ordered_images(self, images: Dict[str, torch.Tensor]) -> list[torch.Tensor]:
        missing = [camera_name for camera_name in self.camera_names if camera_name not in images]
        if missing:
            raise ValueError(
                f"image input missing required cameras. missing={missing}, expected={list(self.camera_names)}"
            )

        ordered = [images[camera_name] for camera_name in self.camera_names]
        reference_shape = ordered[0].shape
        if len(reference_shape) != 5:
            raise ValueError(f"expected image tensors shaped (B, T, C, H, W), got {reference_shape}")

        for camera_name, image in zip(self.camera_names[1:], ordered[1:]):
            if image.shape != reference_shape:
                raise ValueError(
                    f"camera {camera_name!r} shape {tuple(image.shape)} does not match {tuple(reference_shape)}"
                )

        return ordered

    def _prepare_pixels(self, images: Dict[str, torch.Tensor]) -> tuple[torch.Tensor, int, int]:
        self._ordered_images(images)  # return value unused; validates camera presence and shapes
        fused = torch.cat([images[camera_name] for camera_name in self.fused_camera_names], dim=-2)
        bsz, steps = fused.shape[:2]
        fused = fused.reshape(bsz * steps, *fused.shape[2:]).contiguous().float()

        fused = fused.clamp(0.0, 1.0)
        fused = (fused - self.mean) / self.std

        height, width = fused.shape[-2:]
        short_side = min(height, width)
        if short_side <= 0:
            raise ValueError(f"invalid fused image shape: {tuple(fused.shape)}")
        scale = self.image_size / float(short_side)
        resized_height = int(round(height * scale))
        resized_width = int(round(width * scale))
        if (resized_height, resized_width) != (height, width):
            fused = F.interpolate(
                fused,
                size=(resized_height, resized_width),
                mode="bilinear",
                align_corners=False,
                antialias=True,
            )
        return fused, bsz, steps

    def forward(self, images: Dict[str, torch.Tensor]) -> torch.Tensor:
        pixels, bsz, steps = self._prepare_pixels(images)
        with torch.set_grad_enabled(torch.is_grad_enabled() and not self.freeze_backbone):
            output = self.encoder(pixel_values=pixels, interpolate_pos_encoding=True)
        cls = output.last_hidden_state[:, 0]
        embedding = self.projector(cls)
        return embedding.view(bsz, steps, self.joint_output_dim)

    @property
    def output_dim(self) -> int:
        return self._output_dim
@@ -6,6 +6,8 @@ import torchvision
 import numpy as np
 from typing import Callable, Optional, Tuple, Union
 
+from .attnres_resnet2d import AttnResResNetLikeBackbone2D
+
 def _replace_submodules(
     root_module: nn.Module, predicate: Callable[[nn.Module], bool], func: Callable[[nn.Module], nn.Module]
 ) -> nn.Module:
@@ -103,6 +105,17 @@ class _SingleRgbEncoder(nn.Module):
         use_group_norm: bool,
         spatial_softmax_num_keypoints: int,
         freeze_backbone: bool = True,  # whether to freeze the backbone
+        vision_backbone_mode: str = "resnet",
+        attnres_stem_dim: int = 64,
+        attnres_stage_dims: Optional[Tuple[int, ...]] = None,
+        attnres_stage_depths: Optional[Tuple[int, ...]] = None,
+        attnres_stage_heads: Optional[Tuple[int, ...]] = None,
+        attnres_stage_kv_heads: Optional[Tuple[int, ...]] = None,
+        attnres_stage_window_sizes: Optional[Tuple[int, ...]] = None,
+        attnres_dropout: float = 0.0,
+        attnres_ffn_mult: float = 2.667,
+        attnres_eps: float = 1e-6,
+        attnres_rope_theta: float = 10000.0,
     ):
         super().__init__()
 
@@ -119,6 +132,7 @@ class _SingleRgbEncoder(nn.Module):
         self.do_crop = False
         crop_shape = input_shape[1:]
 
+        if vision_backbone_mode == "resnet":
             # set up the backbone network
             backbone_model = getattr(torchvision.models, vision_backbone)(
                 weights=pretrained_backbone_weights
@@ -131,8 +145,28 @@ class _SingleRgbEncoder(nn.Module):
             self.backbone = _replace_submodules(
                 root_module=self.backbone,
                 predicate=lambda x: isinstance(x, nn.BatchNorm2d),
-                func=lambda x: nn.GroupNorm(num_groups=x.num_features // 16, num_channels=x.num_features),
+                func=lambda x: nn.GroupNorm(
+                    num_groups=max(1, x.num_features // 16),
+                    num_channels=x.num_features,
+                ),
             )
+        elif vision_backbone_mode == "attnres_resnet":
+            self.backbone = AttnResResNetLikeBackbone2D(
+                input_channels=input_shape[0],
+                stem_dim=attnres_stem_dim,
+                stage_dims=tuple(attnres_stage_dims or (64, 128, 256, 512)),
+                stage_depths=tuple(attnres_stage_depths or (2, 2, 2, 2)),
+                stage_heads=tuple(attnres_stage_heads or (4, 4, 8, 8)),
+                stage_kv_heads=tuple(attnres_stage_kv_heads or (1, 1, 1, 1)),
+                stage_window_sizes=tuple(attnres_stage_window_sizes or (7, 7, 7, 7)),
+                use_group_norm=use_group_norm,
+                dropout=attnres_dropout,
+                ffn_mult=attnres_ffn_mult,
+                eps=attnres_eps,
+                rope_theta=attnres_rope_theta,
+            )
+        else:
+            raise ValueError(f"unsupported vision_backbone_mode: {vision_backbone_mode}")
 
         # optionally freeze the backbone parameters
         if freeze_backbone:
@@ -177,13 +211,33 @@ class ResNetDiffusionBackbone(VLABackbone):
         use_group_norm: bool = True,
         spatial_softmax_num_keypoints: int = 32,
         use_separate_rgb_encoder_per_camera: bool = False,  # whether to use an independent encoder per camera
+        output_tokens_per_camera: bool = False,  # return one token per camera instead of one concatenated token
         num_cameras: int = 1,  # number of cameras (only used in independent-encoder mode)
+        camera_names: Optional[Tuple[str, ...]] = None,  # explicit camera ordering
         freeze_backbone: bool = True,  # whether to freeze the ResNet backbone (True recommended)
+        vision_backbone_mode: str = "resnet",
+        attnres_stem_dim: int = 64,
+        attnres_stage_dims: Optional[Tuple[int, ...]] = None,
+        attnres_stage_depths: Optional[Tuple[int, ...]] = None,
+        attnres_stage_heads: Optional[Tuple[int, ...]] = None,
+        attnres_stage_kv_heads: Optional[Tuple[int, ...]] = None,
+        attnres_stage_window_sizes: Optional[Tuple[int, ...]] = None,
+        attnres_dropout: float = 0.0,
+        attnres_ffn_mult: float = 2.667,
+        attnres_eps: float = 1e-6,
+        attnres_rope_theta: float = 10000.0,
     ):
         super().__init__()
 
         self.use_separate_rgb_encoder_per_camera = use_separate_rgb_encoder_per_camera
+        self.output_tokens_per_camera = bool(output_tokens_per_camera)
         self.num_cameras = num_cameras
+        self.tokens_per_step = self.num_cameras if self.output_tokens_per_camera else 1
+        self.camera_names = tuple(camera_names) if camera_names is not None else None
+        if self.camera_names is not None and len(self.camera_names) != self.num_cameras:
+            raise ValueError(
+                f"camera_names length ({len(self.camera_names)}) does not match num_cameras ({self.num_cameras})"
+            )
 
         if use_separate_rgb_encoder_per_camera:
             # independent-encoder mode: create a dedicated encoder per camera
@@ -197,6 +251,17 @@ class ResNetDiffusionBackbone(VLABackbone):
                     use_group_norm=use_group_norm,
                     spatial_softmax_num_keypoints=spatial_softmax_num_keypoints,
                     freeze_backbone=freeze_backbone,
+                    vision_backbone_mode=vision_backbone_mode,
+                    attnres_stem_dim=attnres_stem_dim,
+                    attnres_stage_dims=attnres_stage_dims,
+                    attnres_stage_depths=attnres_stage_depths,
+                    attnres_stage_heads=attnres_stage_heads,
+                    attnres_stage_kv_heads=attnres_stage_kv_heads,
+                    attnres_stage_window_sizes=attnres_stage_window_sizes,
+                    attnres_dropout=attnres_dropout,
+                    attnres_ffn_mult=attnres_ffn_mult,
+                    attnres_eps=attnres_eps,
+                    attnres_rope_theta=attnres_rope_theta,
                 )
                 for _ in range(num_cameras)
             ]
@@ -214,9 +279,36 @@ class ResNetDiffusionBackbone(VLABackbone):
                 use_group_norm=use_group_norm,
                 spatial_softmax_num_keypoints=spatial_softmax_num_keypoints,
                 freeze_backbone=freeze_backbone,
+                vision_backbone_mode=vision_backbone_mode,
+                attnres_stem_dim=attnres_stem_dim,
+                attnres_stage_dims=attnres_stage_dims,
+                attnres_stage_depths=attnres_stage_depths,
+                attnres_stage_heads=attnres_stage_heads,
+                attnres_stage_kv_heads=attnres_stage_kv_heads,
+                attnres_stage_window_sizes=attnres_stage_window_sizes,
+                attnres_dropout=attnres_dropout,
+                attnres_ffn_mult=attnres_ffn_mult,
+                attnres_eps=attnres_eps,
+                attnres_rope_theta=attnres_rope_theta,
             )
             self.feature_dim = self.rgb_encoder.feature_dim
 
+    def _ordered_camera_names(self, images) -> Tuple[str, ...]:
+        if self.camera_names is None:
+            camera_names = tuple(sorted(images.keys()))
+            if len(camera_names) != self.num_cameras:
+                raise ValueError(
+                    f"number of cameras in image input ({len(camera_names)}) does not match num_cameras ({self.num_cameras})"
+                )
+            return camera_names
+
+        missing = [cam_name for cam_name in self.camera_names if cam_name not in images]
+        if missing:
+            raise ValueError(
+                f"image input is missing required cameras. missing={missing}, expected={list(self.camera_names)}"
+            )
+        return self.camera_names
+
     def forward(self, images):
         """
         Args:
@@ -228,24 +320,26 @@ class ResNetDiffusionBackbone(VLABackbone):
         """
         any_tensor = next(iter(images.values()))
         B, T = any_tensor.shape[:2]
-        cam_names = sorted(images.keys())
+        cam_names = self._ordered_camera_names(images)
 
         features_all = []
         if self.use_separate_rgb_encoder_per_camera:
             # independent-encoder mode: each camera uses its own encoder
             features_all = []
             for cam_idx, cam_name in enumerate(cam_names):
                 img = images[cam_name]
                 encoder = self.rgb_encoder[cam_idx]
-                features = encoder.forward_single_image(img.view(B * T, *img.shape[2:]))
+                features = encoder.forward_single_image(img.reshape(B * T, *img.shape[2:]))
                 features_all.append(features)
             return torch.cat(features_all, dim=1).view(B, T, -1)
         else:
             # shared-encoder mode: all cameras share one encoder
             features_all = []
             for cam_name in cam_names:
                 img = images[cam_name]
-                features = self.rgb_encoder.forward_single_image(img.view(B * T, *img.shape[2:]))
+                features = self.rgb_encoder.forward_single_image(img.reshape(B * T, *img.shape[2:]))
                 features_all.append(features)
 
+            if self.output_tokens_per_camera:
+                stacked = torch.stack(features_all, dim=1)  # (B*T, num_cams, feature_dim)
+                return stacked.view(B, T, len(cam_names), self.feature_dim)
             return torch.cat(features_all, dim=1).view(B, T, -1)
 
     @property
roboimi/vla/models/backbones/siglip2_diffusion_backbone.py (new file, 124 lines)
@@ -0,0 +1,124 @@
from __future__ import annotations

from typing import Dict, Optional, Sequence, Tuple

import torch
from torch import nn
from transformers import SiglipVisionModel

from roboimi.vla.core.interfaces import VLABackbone


class SigLIP2DiffusionBackbone(VLABackbone):
    """Shared SigLIP vision tower for multiview diffusion-policy conditioning.

    We intentionally load the checkpoint `google/siglip2-base-patch16-256` through
    `SiglipVisionModel.from_pretrained(...)` so each camera can be fed as a normal
    `(B, C, H, W)` image tensor and produce one pooled global feature vector.
    """

    def __init__(
        self,
        model_name: str = 'google/siglip2-base-patch16-256',
        *,
        model_name_or_path: str | None = None,
        vision_model: nn.Module | None = None,
        camera_names: Sequence[str] = ('r_vis', 'top', 'front'),
        num_cameras: Optional[int] = None,
        per_view_output_dim: int = 96,
        output_dim: int | None = None,
        freeze_backbone: bool = True,
        dataset_image_resize_shape: Sequence[int] | None = None,
        eval_image_resize_shape: Sequence[int] | None = (256, 256),
    ) -> None:
        super().__init__()
        if model_name_or_path is not None:
            model_name = model_name_or_path
        if output_dim is not None:
            per_view_output_dim = output_dim

        self.model_name = str(model_name)
        self.camera_names = tuple(camera_names)
        self.num_cameras = int(num_cameras) if num_cameras is not None else len(self.camera_names)
        if len(self.camera_names) != self.num_cameras:
            raise ValueError(
                f'camera_names length ({len(self.camera_names)}) must match num_cameras ({self.num_cameras})'
            )

        self._output_dim = int(per_view_output_dim)
        self.joint_output_dim = self._output_dim * self.num_cameras
        self.freeze_backbone = bool(freeze_backbone)
        self.dataset_image_resize_shape = self._normalize_resize_shape(dataset_image_resize_shape)
        self.eval_image_resize_shape = self._normalize_resize_shape(eval_image_resize_shape)

        self.encoder = vision_model if vision_model is not None else SiglipVisionModel.from_pretrained(self.model_name)
        hidden_size = int(getattr(self.encoder.config, 'hidden_size'))
        self.view_projector = nn.Linear(hidden_size, self._output_dim)
        self.projector = self.view_projector

        self.register_buffer('mean', torch.tensor([0.5, 0.5, 0.5], dtype=torch.float32).view(1, 3, 1, 1))
        self.register_buffer('std', torch.tensor([0.5, 0.5, 0.5], dtype=torch.float32).view(1, 3, 1, 1))

        if self.freeze_backbone:
            self._freeze_encoder()

    @staticmethod
    def _normalize_resize_shape(shape: Sequence[int] | None) -> tuple[int, int] | None:
        if shape is None:
            return None
        normalized = tuple(int(v) for v in shape)
        if len(normalized) != 2:
            raise ValueError(f'resize shape must contain exactly two values, got {normalized}')
        return normalized

    @property
    def output_dim(self) -> int:
        return self._output_dim

    def _freeze_encoder(self) -> None:
        self.encoder.eval()
        for param in self.encoder.parameters():
            param.requires_grad = False

    def train(self, mode: bool = True):
        super().train(mode)
        if self.freeze_backbone:
            self._freeze_encoder()
        return self

    def _ordered_camera_names(self, images: Dict[str, torch.Tensor]) -> Tuple[str, ...]:
        missing = [camera_name for camera_name in self.camera_names if camera_name not in images]
        if missing:
            raise ValueError(
                f'image input missing required cameras. missing={missing}, expected={list(self.camera_names)}'
            )
        return self.camera_names

    def _prepare_pixels(self, image: torch.Tensor) -> torch.Tensor:
        if image.ndim != 5:
            raise ValueError(f'expected image tensor shaped (B, T, C, H, W), got {tuple(image.shape)}')
        pixels = image.reshape(-1, *image.shape[2:]).contiguous().float()
        pixels = pixels.clamp(0.0, 1.0)
        return (pixels - self.mean) / self.std

    def forward(self, images: Dict[str, torch.Tensor]) -> torch.Tensor:
        camera_names = self._ordered_camera_names(images)
        reference_shape = images[camera_names[0]].shape
        batch_size, steps = reference_shape[:2]
        per_view_features = []
        for camera_name in camera_names:
            image = images[camera_name]
            if image.shape != reference_shape:
                raise ValueError(
                    f'camera {camera_name!r} shape {tuple(image.shape)} does not match {tuple(reference_shape)}'
                )
            pixels = self._prepare_pixels(image)
            with torch.set_grad_enabled(torch.is_grad_enabled() and not self.freeze_backbone):
                encoded = self.encoder(pixel_values=pixels)
            pooled = encoded.pooler_output
            per_view_features.append(self.view_projector(pooled))
        features = torch.cat(per_view_features, dim=-1)
        return features.view(batch_size, steps, self.joint_output_dim)


Siglip2DiffusionBackbone = SigLIP2DiffusionBackbone
roboimi/vla/models/heads/attnres_transformer_components.py (new file, 249 lines)
@@ -0,0 +1,249 @@
from __future__ import annotations

import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor


class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: Tensor) -> Tensor:
        rms = torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + self.eps)
        return (x.float() * rms).to(x.dtype) * self.weight


class RMSNormNoWeight(nn.Module):
    def __init__(self, eps: float = 1e-6) -> None:
        super().__init__()
        self.eps = eps

    def forward(self, x: Tensor) -> Tensor:
        rms = torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + self.eps)
        return (x.float() * rms).to(x.dtype)


def precompute_rope_freqs(
    dim: int,
    max_seq_len: int,
    theta: float = 10000.0,
    device: Optional[torch.device] = None,
) -> Tensor:
    if dim % 2 != 0:
        raise ValueError(f'RoPE requires an even head dimension, got {dim}.')
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, device=device).float() / dim))
    positions = torch.arange(max_seq_len, device=device).float()
    angles = torch.outer(positions, freqs)
    return torch.polar(torch.ones_like(angles), angles)


def apply_rope(x: Tensor, freqs: Tensor) -> Tensor:
    x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    freqs = freqs.unsqueeze(0).unsqueeze(2)
    x_rotated = x_complex * freqs
    return torch.view_as_real(x_rotated).reshape_as(x).to(x.dtype)
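The defining property of the rotary embedding built by `precompute_rope_freqs` / `apply_rope` is that inner products between rotated queries and keys depend only on the position offset, not the absolute positions. A torch-free NumPy sketch of the same per-pair rotation (`rope_rotate` is a hypothetical illustration, not the module's API):

```python
import numpy as np

def rope_rotate(x, pos, theta=10000.0):
    """Rotate consecutive (even, odd) pairs of an even-dim vector by position-dependent angles."""
    d = x.shape[0]
    freqs = 1.0 / theta ** (np.arange(0, d, 2) / d)
    ang = pos * freqs
    c, s = np.cos(ang), np.sin(ang)
    pairs = x.reshape(-1, 2)
    out = np.empty_like(pairs, dtype=float)
    out[:, 0] = pairs[:, 0] * c - pairs[:, 1] * s
    out[:, 1] = pairs[:, 0] * s + pairs[:, 1] * c
    return out.reshape(-1)
```

Shifting both positions by the same amount leaves the dot product unchanged, which is why the attention scores below become relative-position-aware with no learned position table.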


class GroupedQuerySelfAttention(nn.Module):
    def __init__(
        self,
        d_model: int,
        n_heads: int,
        n_kv_heads: int,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()
        if d_model % n_heads != 0:
            raise ValueError(f'd_model={d_model} must be divisible by n_heads={n_heads}.')
        if n_heads % n_kv_heads != 0:
            raise ValueError(f'n_heads={n_heads} must be divisible by n_kv_heads={n_kv_heads}.')

        self.d_model = d_model
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.n_kv_groups = n_heads // n_kv_heads
        self.d_head = d_model // n_heads
        self.attn_dropout = nn.Dropout(dropout)
        self.out_dropout = nn.Dropout(dropout)

        self.w_q = nn.Linear(d_model, n_heads * self.d_head, bias=False)
        self.w_k = nn.Linear(d_model, n_kv_heads * self.d_head, bias=False)
        self.w_v = nn.Linear(d_model, n_kv_heads * self.d_head, bias=False)
        self.w_o = nn.Linear(n_heads * self.d_head, d_model, bias=False)

    def forward(
        self,
        x: Tensor,
        rope_freqs: Tensor,
        mask: Optional[Tensor] = None,
    ) -> Tensor:
        batch_size, seq_len, _ = x.shape

        q = self.w_q(x).view(batch_size, seq_len, self.n_heads, self.d_head)
        k = self.w_k(x).view(batch_size, seq_len, self.n_kv_heads, self.d_head)
        v = self.w_v(x).view(batch_size, seq_len, self.n_kv_heads, self.d_head)

        q = apply_rope(q, rope_freqs)
        k = apply_rope(k, rope_freqs)

        if self.n_kv_heads != self.n_heads:
            k = k.unsqueeze(3).expand(
                batch_size, seq_len, self.n_kv_heads, self.n_kv_groups, self.d_head
            )
            k = k.reshape(batch_size, seq_len, self.n_heads, self.d_head)
            v = v.unsqueeze(3).expand(
                batch_size, seq_len, self.n_kv_heads, self.n_kv_groups, self.d_head
            )
            v = v.reshape(batch_size, seq_len, self.n_heads, self.d_head)

        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        scale = 1.0 / math.sqrt(self.d_head)
        attn_weights = torch.matmul(q, k.transpose(-2, -1)) * scale
        if mask is not None:
            attn_weights = attn_weights + mask
        attn_weights = F.softmax(attn_weights, dim=-1)
        attn_weights = self.attn_dropout(attn_weights)

        out = torch.matmul(attn_weights, v)
        out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        return self.out_dropout(self.w_o(out))
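Grouped-query attention stores fewer key/value heads than query heads and repeats each KV head across its group of query heads, which is what the `unsqueeze(3).expand(...).reshape(...)` sequence above does. The repetition in isolation, as a NumPy sketch (`expand_kv_heads` is a hypothetical name):

```python
import numpy as np

def expand_kv_heads(kv, n_heads):
    """kv: (B, S, n_kv_heads, d_head) -> (B, S, n_heads, d_head) by repeating each KV head."""
    b, s, n_kv, d = kv.shape
    assert n_heads % n_kv == 0
    groups = n_heads // n_kv
    out = np.repeat(kv[:, :, :, None, :], groups, axis=3)  # (B, S, n_kv, groups, d)
    return out.reshape(b, s, n_heads, d)
```

With `n_kv_heads=1` (the defaults used in this branch) every query head attends against the same single KV projection, shrinking the KV parameter count by a factor of `n_heads`.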


class SwiGLUFFN(nn.Module):
    def __init__(self, d_model: int, dropout: float = 0.0, mult: float = 2.667) -> None:
        super().__init__()
        raw = int(mult * d_model)
        d_ff = ((raw + 7) // 8) * 8  # round the hidden width up to a multiple of 8
        self.w_gate = nn.Linear(d_model, d_ff, bias=False)
        self.w_up = nn.Linear(d_model, d_ff, bias=False)
        self.w_down = nn.Linear(d_ff, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: Tensor) -> Tensor:
        return self.dropout(self.w_down(F.silu(self.w_gate(x)) * self.w_up(x)))


class AttnResOperator(nn.Module):
    def __init__(self, d_model: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.pseudo_query = nn.Parameter(torch.zeros(d_model))
        self.key_norm = RMSNormNoWeight(eps=eps)

    def forward(self, sources: Tensor) -> Tensor:
        keys = self.key_norm(sources)
        logits = torch.einsum('d,nbtd->nbt', self.pseudo_query, keys)
        weights = F.softmax(logits, dim=0)
        return torch.einsum('nbt,nbtd->btd', weights, sources)
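`AttnResOperator` replaces the usual residual `x + f(x)` with a learned softmax mixture over all previous layer outputs, scored by a single pseudo-query against RMS-normalized keys. A NumPy sketch of the mixing step; note that with the pseudo-query initialized to zeros (as above) the operator starts out as a uniform average of its sources:

```python
import numpy as np

def attn_res_mix(sources, pseudo_query, eps=1e-6):
    """sources: (N, B, T, D) stacked layer outputs -> (B, T, D) softmax-weighted sum over N."""
    rms = np.sqrt((sources ** 2).mean(-1, keepdims=True) + eps)
    keys = sources / rms                                   # RMSNorm without a learned weight
    logits = np.einsum('d,nbtd->nbt', pseudo_query, keys)
    logits -= logits.max(0, keepdims=True)                 # numerically stable softmax over N
    w = np.exp(logits)
    w /= w.sum(0, keepdims=True)
    return np.einsum('nbt,nbtd->btd', w, sources)
```

During training the pseudo-query learns which earlier layers to emphasize per token, giving a data-dependent shortcut structure instead of a fixed residual path.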


class AttnResSubLayer(nn.Module):
    def __init__(
        self,
        d_model: int,
        n_heads: int,
        n_kv_heads: int,
        dropout: float,
        ffn_mult: float,
        eps: float,
        is_attention: bool,
    ) -> None:
        super().__init__()
        self.norm = RMSNorm(d_model, eps=eps)
        self.attn_res = AttnResOperator(d_model, eps=eps)
        self.is_attention = is_attention
        if self.is_attention:
            self.fn = GroupedQuerySelfAttention(
                d_model=d_model,
                n_heads=n_heads,
                n_kv_heads=n_kv_heads,
                dropout=dropout,
            )
        else:
            self.fn = SwiGLUFFN(d_model=d_model, dropout=dropout, mult=ffn_mult)

    def forward(self, sources: Tensor, rope_freqs: Tensor, mask: Optional[Tensor] = None) -> Tensor:
        h = self.attn_res(sources)
        normed = self.norm(h)
        if self.is_attention:
            return self.fn(normed, rope_freqs, mask)
        return self.fn(normed)


class AttnResTransformerBackbone(nn.Module):
    def __init__(
        self,
        d_model: int,
        n_blocks: int,
        n_heads: int,
        n_kv_heads: int,
        max_seq_len: int,
        dropout: float = 0.0,
        ffn_mult: float = 2.667,
        eps: float = 1e-6,
        rope_theta: float = 10000.0,
        causal_attn: bool = False,
    ) -> None:
        super().__init__()
        self.causal_attn = causal_attn
        self.layers = nn.ModuleList()
        for _ in range(n_blocks):
            self.layers.append(
                AttnResSubLayer(
                    d_model=d_model,
                    n_heads=n_heads,
                    n_kv_heads=n_kv_heads,
                    dropout=dropout,
                    ffn_mult=ffn_mult,
                    eps=eps,
                    is_attention=True,
                )
            )
            self.layers.append(
                AttnResSubLayer(
                    d_model=d_model,
                    n_heads=n_heads,
                    n_kv_heads=n_kv_heads,
                    dropout=dropout,
                    ffn_mult=ffn_mult,
                    eps=eps,
                    is_attention=False,
                )
            )

        rope_freqs = precompute_rope_freqs(
            dim=d_model // n_heads,
            max_seq_len=max_seq_len,
            theta=rope_theta,
        )
        self.register_buffer('rope_freqs', rope_freqs, persistent=False)

    @staticmethod
    def _build_causal_mask(seq_len: int, device: torch.device) -> Tensor:
        mask = torch.full((seq_len, seq_len), float('-inf'), device=device)
        mask = torch.triu(mask, diagonal=1)
        return mask.unsqueeze(0).unsqueeze(0)

    def forward(self, x: Tensor) -> Tensor:
        seq_len = x.shape[1]
        rope_freqs = self.rope_freqs[:seq_len]
        mask = None
        if self.causal_attn:
            mask = self._build_causal_mask(seq_len, x.device)

        layer_outputs = [x]
        for layer in self.layers:
            sources = torch.stack(layer_outputs, dim=0)
            output = layer(sources, rope_freqs, mask)
            layer_outputs.append(output)

        return torch.stack(layer_outputs, dim=0).sum(dim=0)
roboimi/vla/models/heads/imf_transformer1d.py (new file, 379 lines)
@@ -0,0 +1,379 @@
```python
"""Local IMF-AttnRes transformer head aligned with diffusion_policy@185ed659."""

from __future__ import annotations

import logging
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn

from .attnres_transformer_components import (
    AttnResOperator,
    AttnResSubLayer,
    AttnResTransformerBackbone,
    GroupedQuerySelfAttention,
    RMSNorm,
    RMSNormNoWeight,
    SwiGLUFFN,
)
from .transformer1d import ModuleAttrMixin, SinusoidalPosEmb

logger = logging.getLogger(__name__)


class IMFTransformer1D(ModuleAttrMixin):
    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        horizon: int,
        n_obs_steps: Optional[int] = None,
        cond_dim: int = 0,
        n_layer: int = 12,
        n_head: int = 1,
        n_emb: int = 768,
        p_drop_emb: float = 0.1,
        p_drop_attn: float = 0.1,
        causal_attn: bool = False,
        time_as_cond: bool = True,
        obs_as_cond: bool = False,
        n_cond_layers: int = 0,
        backbone_type: str = 'attnres_full',
        n_kv_head: int = 1,
        attn_res_ffn_mult: float = 2.667,
        attn_res_eps: float = 1e-6,
        attn_res_rope_theta: float = 10000.0,
    ) -> None:
        super().__init__()

        if n_head != 1:
            raise AssertionError('IMFTransformer1D currently supports single-head attention only.')
        if n_obs_steps is None:
            n_obs_steps = horizon

        self.backbone_type = backbone_type

        T = horizon
        T_cond = 2
        if not time_as_cond:
            T += 2
            T_cond -= 2
        obs_as_cond = cond_dim > 0
        if obs_as_cond:
            assert time_as_cond
            T_cond += n_obs_steps

        self.input_emb = nn.Linear(input_dim, n_emb)
        self.drop = nn.Dropout(p_drop_emb)
        self.time_emb = SinusoidalPosEmb(n_emb)
        self.cond_obs_emb = nn.Linear(cond_dim, n_emb) if obs_as_cond else None
        self.time_token_proj = None
        self.cond_pos_emb = None
        self.pos_emb = None
        self.encoder = None
        self.decoder = None
        self.attnres_backbone = None
        encoder_only = False

        if backbone_type == 'attnres_full':
            if not time_as_cond:
                raise ValueError('attnres_full backbone requires time_as_cond=True.')
            if n_cond_layers != 0:
                raise ValueError('attnres_full backbone does not support n_cond_layers > 0.')

            self.time_token_proj = nn.Linear(n_emb, n_emb)
            self.attnres_backbone = AttnResTransformerBackbone(
                d_model=n_emb,
                n_blocks=n_layer,
                n_heads=n_head,
                n_kv_heads=n_kv_head,
                max_seq_len=T + T_cond,
                dropout=p_drop_attn,
                ffn_mult=attn_res_ffn_mult,
                eps=attn_res_eps,
                rope_theta=attn_res_rope_theta,
                causal_attn=causal_attn,
            )
            self.ln_f = RMSNorm(n_emb, eps=attn_res_eps)
        else:
            self.pos_emb = nn.Parameter(torch.zeros(1, T, n_emb))
            if T_cond > 0:
                self.cond_pos_emb = nn.Parameter(torch.zeros(1, T_cond, n_emb))
                if n_cond_layers > 0:
                    encoder_layer = nn.TransformerEncoderLayer(
                        d_model=n_emb,
                        nhead=n_head,
                        dim_feedforward=4 * n_emb,
                        dropout=p_drop_attn,
                        activation='gelu',
                        batch_first=True,
                        norm_first=True,
                    )
                    self.encoder = nn.TransformerEncoder(
                        encoder_layer=encoder_layer,
                        num_layers=n_cond_layers,
                    )
                else:
                    self.encoder = nn.Sequential(
                        nn.Linear(n_emb, 4 * n_emb),
                        nn.Mish(),
                        nn.Linear(4 * n_emb, n_emb),
                    )

                decoder_layer = nn.TransformerDecoderLayer(
                    d_model=n_emb,
                    nhead=n_head,
                    dim_feedforward=4 * n_emb,
                    dropout=p_drop_attn,
                    activation='gelu',
                    batch_first=True,
                    norm_first=True,
                )
                self.decoder = nn.TransformerDecoder(
                    decoder_layer=decoder_layer,
                    num_layers=n_layer,
                )
            else:
                encoder_only = True
                encoder_layer = nn.TransformerEncoderLayer(
                    d_model=n_emb,
                    nhead=n_head,
                    dim_feedforward=4 * n_emb,
                    dropout=p_drop_attn,
                    activation='gelu',
                    batch_first=True,
                    norm_first=True,
                )
                self.encoder = nn.TransformerEncoder(
                    encoder_layer=encoder_layer,
                    num_layers=n_layer,
                )

            self.ln_f = nn.LayerNorm(n_emb)

        if causal_attn and backbone_type != 'attnres_full':
            sz = T
            mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
            mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
            self.register_buffer('mask', mask)

            if time_as_cond and obs_as_cond:
                S = T_cond
                t_idx, s_idx = torch.meshgrid(
                    torch.arange(T),
                    torch.arange(S),
                    indexing='ij',
                )
                mask = t_idx >= (s_idx - 2)
                mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
                self.register_buffer('memory_mask', mask)
            else:
                self.memory_mask = None
        else:
            self.mask = None
            self.memory_mask = None

        self.head = nn.Linear(n_emb, output_dim)

        self.T = T
        self.T_cond = T_cond
        self.horizon = horizon
        self.time_as_cond = time_as_cond
        self.obs_as_cond = obs_as_cond
        self.encoder_only = encoder_only

        self.apply(self._init_weights)
        logger.info('number of parameters: %e', sum(p.numel() for p in self.parameters()))

    def _init_weights(self, module):
        ignore_types = (
            nn.Dropout,
            SinusoidalPosEmb,
            nn.TransformerEncoderLayer,
            nn.TransformerDecoderLayer,
            nn.TransformerEncoder,
            nn.TransformerDecoder,
            nn.ModuleList,
            nn.Mish,
            nn.Sequential,
            AttnResTransformerBackbone,
            AttnResSubLayer,
            GroupedQuerySelfAttention,
            SwiGLUFFN,
            RMSNormNoWeight,
        )
        if isinstance(module, (nn.Linear, nn.Embedding)):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.MultiheadAttention):
            for name in ('in_proj_weight', 'q_proj_weight', 'k_proj_weight', 'v_proj_weight'):
                weight = getattr(module, name)
                if weight is not None:
                    torch.nn.init.normal_(weight, mean=0.0, std=0.02)

            for name in ('in_proj_bias', 'bias_k', 'bias_v'):
                bias = getattr(module, name)
                if bias is not None:
                    torch.nn.init.zeros_(bias)
        elif isinstance(module, (nn.LayerNorm, RMSNorm)):
            if getattr(module, 'bias', None) is not None:
                torch.nn.init.zeros_(module.bias)
            torch.nn.init.ones_(module.weight)
        elif isinstance(module, AttnResOperator):
            torch.nn.init.zeros_(module.pseudo_query)
        elif isinstance(module, IMFTransformer1D):
            if module.pos_emb is not None:
                torch.nn.init.normal_(module.pos_emb, mean=0.0, std=0.02)
            if module.cond_pos_emb is not None:
                torch.nn.init.normal_(module.cond_pos_emb, mean=0.0, std=0.02)
        elif isinstance(module, ignore_types):
            pass
        else:
            raise RuntimeError(f'Unaccounted module {module}')

    def get_optim_groups(self, weight_decay: float = 1e-3):
        decay = set()
        no_decay = set()
        whitelist_weight_modules = (torch.nn.Linear, torch.nn.MultiheadAttention)
        blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding, RMSNorm)
        for mn, m in self.named_modules():
            for pn, _ in m.named_parameters(recurse=False):
                fpn = f'{mn}.{pn}' if mn else pn

                if pn.endswith('bias'):
                    no_decay.add(fpn)
                elif pn.startswith('bias'):
                    no_decay.add(fpn)
                elif pn == 'pseudo_query':
                    no_decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
                    decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
                    no_decay.add(fpn)

        if self.pos_emb is not None:
            no_decay.add('pos_emb')
        no_decay.add('_dummy_variable')
        if self.cond_pos_emb is not None:
            no_decay.add('cond_pos_emb')

        param_dict = {pn: p for pn, p in self.named_parameters()}
        inter_params = decay & no_decay
        union_params = decay | no_decay
        assert len(inter_params) == 0, f'parameters {inter_params} made it into both decay/no_decay sets!'
        assert len(param_dict.keys() - union_params) == 0, (
            f'parameters {param_dict.keys() - union_params} were not separated into either decay/no_decay sets!'
        )

        return [
            {
                'params': [param_dict[pn] for pn in sorted(list(decay))],
                'weight_decay': weight_decay,
            },
            {
                'params': [param_dict[pn] for pn in sorted(list(no_decay))],
                'weight_decay': 0.0,
            },
        ]

    def configure_optimizers(
        self,
        learning_rate: float = 1e-4,
        weight_decay: float = 1e-3,
        betas: Tuple[float, float] = (0.9, 0.95),
    ):
        optim_groups = self.get_optim_groups(weight_decay=weight_decay)
        return torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas)

    def _prepare_time_input(self, value: Union[torch.Tensor, float, int], sample: torch.Tensor) -> torch.Tensor:
        if not torch.is_tensor(value):
            value = torch.tensor([value], dtype=sample.dtype, device=sample.device)
        elif value.ndim == 0:
            value = value[None].to(device=sample.device, dtype=sample.dtype)
        else:
            value = value.to(device=sample.device, dtype=sample.dtype)
        return value.expand(sample.shape[0])

    def _forward_attnres_full(
        self,
        sample: torch.Tensor,
        r: torch.Tensor,
        t: torch.Tensor,
        cond: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        sample_tokens = self.input_emb(sample)
        token_parts = [
            self.time_token_proj(self.time_emb(r)).unsqueeze(1),
            self.time_token_proj(self.time_emb(t)).unsqueeze(1),
        ]
        if self.obs_as_cond:
            if cond is None:
                raise ValueError('cond is required when obs_as_cond=True for attnres_full backbone.')
            token_parts.append(self.cond_obs_emb(cond))
        token_parts.append(sample_tokens)
        x = torch.cat(token_parts, dim=1)
        x = self.drop(x)
        x = self.attnres_backbone(x)
        x = x[:, -sample_tokens.shape[1]:, :]
        return x

    def _forward_vanilla(
        self,
        sample: torch.Tensor,
        r: torch.Tensor,
        t: torch.Tensor,
        cond: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        r_emb = self.time_emb(r).unsqueeze(1)
        t_emb = self.time_emb(t).unsqueeze(1)
        input_emb = self.input_emb(sample)

        if self.encoder_only:
            token_embeddings = torch.cat([r_emb, t_emb, input_emb], dim=1)
            token_count = token_embeddings.shape[1]
            position_embeddings = self.pos_emb[:, :token_count, :]
            x = self.drop(token_embeddings + position_embeddings)
            x = self.encoder(src=x, mask=self.mask)
            x = x[:, 2:, :]
        else:
            cond_embeddings = torch.cat([r_emb, t_emb], dim=1)
            if self.obs_as_cond:
                cond_embeddings = torch.cat([cond_embeddings, self.cond_obs_emb(cond)], dim=1)
            token_count = cond_embeddings.shape[1]
            position_embeddings = self.cond_pos_emb[:, :token_count, :]
            x = self.drop(cond_embeddings + position_embeddings)
            x = self.encoder(x)
            memory = x

            token_embeddings = input_emb
            token_count = token_embeddings.shape[1]
            position_embeddings = self.pos_emb[:, :token_count, :]
            x = self.drop(token_embeddings + position_embeddings)
            x = self.decoder(
                tgt=x,
                memory=memory,
                tgt_mask=self.mask,
                memory_mask=self.memory_mask,
            )
        return x

    def forward(
        self,
        sample: torch.Tensor,
        r: Union[torch.Tensor, float, int],
        t: Union[torch.Tensor, float, int],
        cond: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        r = self._prepare_time_input(r, sample)
        t = self._prepare_time_input(t, sample)

        if self.backbone_type == 'attnres_full':
            x = self._forward_attnres_full(sample, r, t, cond=cond)
        else:
            x = self._forward_vanilla(sample, r, t, cond=cond)

        x = self.ln_f(x)
        x = self.head(x)
        return x
```
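The `get_optim_groups` method above splits parameters into a weight-decay group and a no-decay group before building AdamW. A self-contained sketch of the same pattern on a toy module (the two-layer model here is illustrative, not from the repo):

```python
import torch
import torch.nn as nn

# Linear weights get weight decay; biases and LayerNorm weights do not.
model = nn.Sequential(nn.Linear(4, 8), nn.LayerNorm(8))
decay, no_decay = set(), set()
for mn, m in model.named_modules():
    for pn, _ in m.named_parameters(recurse=False):
        fpn = f'{mn}.{pn}' if mn else pn
        if pn.endswith('bias') or isinstance(m, nn.LayerNorm):
            no_decay.add(fpn)
        elif pn.endswith('weight') and isinstance(m, nn.Linear):
            decay.add(fpn)

params = dict(model.named_parameters())
groups = [
    {'params': [params[n] for n in sorted(decay)], 'weight_decay': 1e-3},
    {'params': [params[n] for n in sorted(no_decay)], 'weight_decay': 0.0},
]
opt = torch.optim.AdamW(groups, lr=1e-4, betas=(0.9, 0.95))
print(sorted(decay))  # ['0.weight']
```

The real method additionally exempts `pseudo_query`, `pos_emb`, `cond_pos_emb`, and `_dummy_variable`, and asserts the two sets partition all parameters.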
roboimi/vla/models/heads/transformer1d.py

```diff
@@ -1,19 +1,35 @@
-"""
-Transformer-based Diffusion Policy Head
+"""Transformer-based diffusion head aligned with diffusion_policy's TransformerForDiffusion."""
 
-Uses a Transformer (encoder-decoder) architecture in place of a UNet for noise prediction.
-Supports injecting global conditioning (observation features) via cross-attention.
-"""
+from __future__ import annotations
+
+import logging
 import math
+from typing import Optional, Tuple, Union
+
 import torch
 import torch.nn as nn
-from typing import Optional
+
+logger = logging.getLogger(__name__)
+
+
+class ModuleAttrMixin(nn.Module):
+    """Minimal local copy of diffusion_policy's ModuleAttrMixin for state-dict parity."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        self._dummy_variable = nn.Parameter()
+
+    @property
+    def device(self):
+        return next(iter(self.parameters())).device
+
+    @property
+    def dtype(self):
+        return next(iter(self.parameters())).dtype
 
 
 class SinusoidalPosEmb(nn.Module):
     """Sinusoidal positional encoding (used for timestep embedding)."""
-    def __init__(self, dim: int):
+    def __init__(self, dim: int) -> None:
         super().__init__()
         self.dim = dim
```
```diff
@@ -27,35 +43,13 @@ class SinusoidalPosEmb(nn.Module):
         return emb
 
 
-class Transformer1D(nn.Module):
-    """
-    Transformer-based 1D Diffusion Model
-
-    Encoder-decoder architecture:
-    - Encoder: processes the conditioning (observations + timestep)
-    - Decoder: predicts the noise via cross-attention
-
-    Args:
-        input_dim: input action dimension
-        output_dim: output action dimension
-        horizon: prediction horizon length
-        n_obs_steps: number of observation steps
-        cond_dim: conditioning dimension
-        n_layer: number of transformer layers
-        n_head: number of attention heads
-        n_emb: embedding dimension
-        p_drop_emb: embedding dropout
-        p_drop_attn: attention dropout
-        causal_attn: whether to use causal (autoregressive) attention
-        n_cond_layers: number of encoder layers (0 means use an MLP)
-    """
-
+class Transformer1D(ModuleAttrMixin):
     def __init__(
         self,
         input_dim: int,
         output_dim: int,
         horizon: int,
-        n_obs_steps: int = None,
+        n_obs_steps: Optional[int] = None,
         cond_dim: int = 0,
         n_layer: int = 8,
         n_head: int = 8,
```
```diff
@@ -63,57 +57,42 @@ class Transformer1D(nn.Module):
         p_drop_emb: float = 0.1,
         p_drop_attn: float = 0.1,
         causal_attn: bool = False,
         time_as_cond: bool = True,
         obs_as_cond: bool = False,
-        n_cond_layers: int = 0
-    ):
+        n_cond_layers: int = 0,
+    ) -> None:
         super().__init__()
 
-        # compute sequence lengths
         if n_obs_steps is None:
             n_obs_steps = horizon
 
         T = horizon
-        T_cond = 1  # number of timestep tokens
-
-        # decide whether observations are used as conditioning
+        T_cond = 1
         if not time_as_cond:
             T += 1
             T_cond -= 1
         obs_as_cond = cond_dim > 0
         if obs_as_cond:
             assert time_as_cond
             T_cond += n_obs_steps
 
-        # store configuration
-        self.T = T
-        self.T_cond = T_cond
-        self.horizon = horizon
-        self.obs_as_cond = obs_as_cond
-        self.input_dim = input_dim
-        self.output_dim = output_dim
-
-        # ==================== input embedding ====================
         self.input_emb = nn.Linear(input_dim, n_emb)
         self.pos_emb = nn.Parameter(torch.zeros(1, T, n_emb))
         self.drop = nn.Dropout(p_drop_emb)
-
-        # ==================== conditioning ====================
-        # timestep embedding
         self.time_emb = SinusoidalPosEmb(n_emb)
-
-        # observation conditioning embedding (optional)
         self.cond_obs_emb = None
         if obs_as_cond:
             self.cond_obs_emb = nn.Linear(cond_dim, n_emb)
 
-        # conditioning positional embedding
         self.cond_pos_emb = None
+        self.encoder = None
+        self.decoder = None
+        encoder_only = False
+
         if T_cond > 0:
             self.cond_pos_emb = nn.Parameter(torch.zeros(1, T_cond, n_emb))
-
-        # ==================== Encoder ====================
-        self.encoder = None
-        self.encoder_only = False
-
-        if T_cond > 0:
             if n_cond_layers > 0:
-                # use a Transformer encoder
                 encoder_layer = nn.TransformerEncoderLayer(
                     d_model=n_emb,
                     nhead=n_head,
```
```diff
@@ -121,61 +100,19 @@ class Transformer1D(nn.Module):
                     dropout=p_drop_attn,
                     activation='gelu',
                     batch_first=True,
-                    norm_first=True  # Pre-LN is more stable
+                    norm_first=True,
                 )
                 self.encoder = nn.TransformerEncoder(
                     encoder_layer=encoder_layer,
-                    num_layers=n_cond_layers
+                    num_layers=n_cond_layers,
                 )
             else:
-                # use a simple MLP
                 self.encoder = nn.Sequential(
                     nn.Linear(n_emb, 4 * n_emb),
                     nn.Mish(),
-                    nn.Linear(4 * n_emb, n_emb)
-                )
-        else:
-            # encoder-only mode (BERT style)
-            self.encoder_only = True
-            encoder_layer = nn.TransformerEncoderLayer(
-                d_model=n_emb,
-                nhead=n_head,
-                dim_feedforward=4 * n_emb,
-                dropout=p_drop_attn,
-                activation='gelu',
-                batch_first=True,
-                norm_first=True
-            )
-            self.encoder = nn.TransformerEncoder(
-                encoder_layer=encoder_layer,
-                num_layers=n_layer
-            )
-
-        # ==================== attention mask ====================
-        self.mask = None
-        self.memory_mask = None
-
-        if causal_attn:
-            # causal mask: attend to the left only
-            sz = T
-            mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
-            mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
-            self.register_buffer("mask", mask)
-
-            if obs_as_cond:
-                # cross-attention mask
-                S = T_cond
-                t, s = torch.meshgrid(
-                    torch.arange(T),
-                    torch.arange(S),
-                    indexing='ij'
-                )
-                mask = t >= (s - 1)
-                mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
-                self.register_buffer('memory_mask', mask)
-
-        # ==================== Decoder ====================
-        if not self.encoder_only:
+                    nn.Linear(4 * n_emb, n_emb),
+                )
+
             decoder_layer = nn.TransformerDecoderLayer(
                 d_model=n_emb,
                 nhead=n_head,
```
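The cross-attention memory mask in this hunk is built with `torch.meshgrid`: decoder position `t` may attend to conditioning token `s` only when `t >= s - 1` (one timestep token is prepended to the conditioning sequence). A standalone sketch with small illustrative sizes:

```python
import torch

T, S = 4, 3  # illustrative horizon and conditioning lengths
t, s = torch.meshgrid(torch.arange(T), torch.arange(S), indexing='ij')
mask = t >= (s - 1)
# Convert the boolean mask to additive form: allowed -> 0.0, blocked -> -inf.
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, 0.0)
```

Only the first row (decoder position 0) blocks anything here; later positions can see every conditioning token.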
```diff
@@ -183,136 +120,199 @@ class Transformer1D(nn.Module):
                 dropout=p_drop_attn,
                 activation='gelu',
                 batch_first=True,
-                norm_first=True
+                norm_first=True,
             )
             self.decoder = nn.TransformerDecoder(
                 decoder_layer=decoder_layer,
-                num_layers=n_layer
+                num_layers=n_layer,
             )
+        else:
+            encoder_only = True
+            encoder_layer = nn.TransformerEncoderLayer(
+                d_model=n_emb,
+                nhead=n_head,
+                dim_feedforward=4 * n_emb,
+                dropout=p_drop_attn,
+                activation='gelu',
+                batch_first=True,
+                norm_first=True,
+            )
+            self.encoder = nn.TransformerEncoder(
+                encoder_layer=encoder_layer,
+                num_layers=n_layer,
+            )
 
-        # ==================== output head ====================
+        if causal_attn:
+            sz = T
+            mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
+            mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
+            self.register_buffer('mask', mask)
+
+            if time_as_cond and obs_as_cond:
+                S = T_cond
+                t, s = torch.meshgrid(torch.arange(T), torch.arange(S), indexing='ij')
+                mask = t >= (s - 1)
+                mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
+                self.register_buffer('memory_mask', mask)
+            else:
+                self.memory_mask = None
+        else:
+            self.mask = None
+            self.memory_mask = None
+
         self.ln_f = nn.LayerNorm(n_emb)
         self.head = nn.Linear(n_emb, output_dim)
 
-        # ==================== initialization ====================
-        self.apply(self._init_weights)
+        self.T = T
+        self.T_cond = T_cond
+        self.horizon = horizon
+        self.time_as_cond = time_as_cond
+        self.obs_as_cond = obs_as_cond
+        self.encoder_only = encoder_only
 
-        # print parameter count
-        total_params = sum(p.numel() for p in self.parameters())
-        print(f"Transformer1D parameters: {total_params:,}")
+        self.apply(self._init_weights)
+        logger.info('number of parameters: %e', sum(p.numel() for p in self.parameters()))
 
     def _init_weights(self, module):
-        """Initialize weights."""
         ignore_types = (
             nn.Dropout,
             SinusoidalPosEmb,
             nn.TransformerEncoderLayer,
             nn.TransformerDecoderLayer,
             nn.TransformerEncoder,
             nn.TransformerDecoder,
             nn.ModuleList,
             nn.Mish,
             nn.Sequential,
         )
         if isinstance(module, (nn.Linear, nn.Embedding)):
             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
             if isinstance(module, nn.Linear) and module.bias is not None:
                 torch.nn.init.zeros_(module.bias)
         elif isinstance(module, nn.MultiheadAttention):
-            # weight initialization for MultiheadAttention
-            for name in ['in_proj_weight', 'q_proj_weight', 'k_proj_weight', 'v_proj_weight']:
-                weight = getattr(module, name, None)
+            for name in ('in_proj_weight', 'q_proj_weight', 'k_proj_weight', 'v_proj_weight'):
+                weight = getattr(module, name)
                 if weight is not None:
                     torch.nn.init.normal_(weight, mean=0.0, std=0.02)
 
-            for name in ['in_proj_bias', 'bias_k', 'bias_v']:
-                bias = getattr(module, name, None)
+            for name in ('in_proj_bias', 'bias_k', 'bias_v'):
+                bias = getattr(module, name)
                 if bias is not None:
                     torch.nn.init.zeros_(bias)
         elif isinstance(module, nn.LayerNorm):
             torch.nn.init.zeros_(module.bias)
             torch.nn.init.ones_(module.weight)
         elif isinstance(module, Transformer1D):
-            # positional embedding initialization
-            torch.nn.init.normal_(self.pos_emb, mean=0.0, std=0.02)
-            if self.cond_pos_emb is not None:
-                torch.nn.init.normal_(self.cond_pos_emb, mean=0.0, std=0.02)
+            torch.nn.init.normal_(module.pos_emb, mean=0.0, std=0.02)
+            if module.cond_obs_emb is not None:
+                torch.nn.init.normal_(module.cond_pos_emb, mean=0.0, std=0.02)
         elif isinstance(module, ignore_types):
             pass
         else:
             raise RuntimeError(f'Unaccounted module {module}')
 
+    def get_optim_groups(self, weight_decay: float = 1e-3):
+        decay = set()
+        no_decay = set()
+        whitelist_weight_modules = (torch.nn.Linear, torch.nn.MultiheadAttention)
+        blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
+
+        for module_name, module in self.named_modules():
+            for param_name, _ in module.named_parameters():
+                full_param_name = f'{module_name}.{param_name}' if module_name else param_name
+
+                if param_name.endswith('bias'):
+                    no_decay.add(full_param_name)
+                elif param_name.startswith('bias'):
+                    no_decay.add(full_param_name)
+                elif param_name.endswith('weight') and isinstance(module, whitelist_weight_modules):
+                    decay.add(full_param_name)
+                elif param_name.endswith('weight') and isinstance(module, blacklist_weight_modules):
+                    no_decay.add(full_param_name)
+
+        no_decay.add('pos_emb')
+        no_decay.add('_dummy_variable')
+        if self.cond_pos_emb is not None:
+            no_decay.add('cond_pos_emb')
+
+        param_dict = {name: param for name, param in self.named_parameters()}
+        inter_params = decay & no_decay
+        union_params = decay | no_decay
+        assert len(inter_params) == 0, f'parameters {inter_params} made it into both decay/no_decay sets!'
+        assert len(param_dict.keys() - union_params) == 0, (
+            f'parameters {param_dict.keys() - union_params} were not separated into either decay/no_decay sets!'
+        )
+
+        return [
+            {
+                'params': [param_dict[name] for name in sorted(decay)],
+                'weight_decay': weight_decay,
+            },
+            {
+                'params': [param_dict[name] for name in sorted(no_decay)],
+                'weight_decay': 0.0,
+            },
+        ]
+
+    def configure_optimizers(
+        self,
+        learning_rate: float = 1e-4,
+        weight_decay: float = 1e-3,
+        betas: Tuple[float, float] = (0.9, 0.95),
+    ):
+        optim_groups = self.get_optim_groups(weight_decay=weight_decay)
+        return torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas)
+
     def forward(
         self,
         sample: torch.Tensor,
-        timestep: torch.Tensor,
+        timestep: Union[torch.Tensor, float, int],
         cond: Optional[torch.Tensor] = None,
-        **kwargs
+        **kwargs,
     ):
-        """
-        Forward pass.
-
-        Args:
-            sample: (B, T, input_dim) input sequence (noised actions)
-            timestep: (B,) timesteps
-            cond: (B, T', cond_dim) conditioning sequence (observation features)
-
-        Returns:
-            (B, T, output_dim) predicted noise
-        """
-        # handle timesteps
         timesteps = timestep
         if not torch.is_tensor(timesteps):
             timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
         elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
             timesteps = timesteps[None].to(sample.device)
 
-        # expand across the batch dimension
         timesteps = timesteps.expand(sample.shape[0])
-        time_emb = self.time_emb(timesteps).unsqueeze(1)  # (B, 1, n_emb)
+        time_emb = self.time_emb(timesteps).unsqueeze(1)
 
-        # handle inputs
-        input_emb = self.input_emb(sample)  # (B, T, n_emb)
+        input_emb = self.input_emb(sample)
 
-        # encoder-decoder mode
-        if not self.encoder_only:
-            # --- encoder: process the conditioning ---
+        if self.encoder_only:
+            token_embeddings = torch.cat([time_emb, input_emb], dim=1)
+            t = token_embeddings.shape[1]
+            position_embeddings = self.pos_emb[:, :t, :]
+            x = self.drop(token_embeddings + position_embeddings)
+            x = self.encoder(src=x, mask=self.mask)
+            x = x[:, 1:, :]
+        else:
             cond_embeddings = time_emb
-
-            if self.obs_as_cond and cond is not None:
-                # append observation conditioning
-                cond_obs_emb = self.cond_obs_emb(cond)  # (B, T_cond-1, n_emb)
+            if self.obs_as_cond:
+                cond_obs_emb = self.cond_obs_emb(cond)
                 cond_embeddings = torch.cat([cond_embeddings, cond_obs_emb], dim=1)
-
-            # add positional embeddings
             tc = cond_embeddings.shape[1]
-            pos_emb = self.cond_pos_emb[:, :tc, :]
-            x = self.drop(cond_embeddings + pos_emb)
+            position_embeddings = self.cond_pos_emb[:, :tc, :]
+            x = self.drop(cond_embeddings + position_embeddings)
+            memory = self.encoder(x)
 
-            # pass through the encoder
-            memory = self.encoder(x)  # (B, T_cond, n_emb)
-
-            # --- decoder: predict the noise ---
-            # add positional embeddings to the input
             token_embeddings = input_emb
             t = token_embeddings.shape[1]
-            pos_emb = self.pos_emb[:, :t, :]
-            x = self.drop(token_embeddings + pos_emb)
-
-            # cross-attention: queries from the input, keys/values from memory
+            position_embeddings = self.pos_emb[:, :t, :]
+            x = self.drop(token_embeddings + position_embeddings)
             x = self.decoder(
                 tgt=x,
                 memory=memory,
                 tgt_mask=self.mask,
-                memory_mask=self.memory_mask
+                memory_mask=self.memory_mask,
             )
-
-        # encoder-only mode
-        else:
-            # BERT style: the timestep is a special token
-            token_embeddings = torch.cat([time_emb, input_emb], dim=1)
-            t = token_embeddings.shape[1]
-            pos_emb = self.pos_emb[:, :t, :]
-            x = self.drop(token_embeddings + pos_emb)
-
-            x = self.encoder(src=x, mask=self.mask)
-            x = x[:, 1:, :]  # drop the timestep token
-
-        # output head
         x = self.ln_f(x)
-        x = self.head(x)  # (B, T, output_dim)
-
+        x = self.head(x)
         return x
 
 
-# ============================================================================
-# convenience function for creating a Transformer1D model
-# ============================================================================
 def create_transformer1d(
     input_dim: int,
     output_dim: int,
```
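The timestep handling at the top of `Transformer1D.forward` promotes a scalar timestep to a 1-element tensor and broadcasts it across the batch. A standalone sketch of that logic (the helper name here is illustrative):

```python
import torch

def prepare_timesteps(timestep, sample):
    # Scalars become a 1-element long tensor on the sample's device,
    # then expand() broadcasts the singleton across the batch dimension.
    if not torch.is_tensor(timestep):
        timestep = torch.tensor([timestep], dtype=torch.long, device=sample.device)
    elif timestep.ndim == 0:
        timestep = timestep[None].to(sample.device)
    return timestep.expand(sample.shape[0])

sample = torch.randn(4, 16, 16)
ts = prepare_timesteps(10, sample)
print(ts.tolist())  # [10, 10, 10, 10]
```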
```diff
@@ -322,26 +322,9 @@ def create_transformer1d(
     n_layer: int = 8,
     n_head: int = 8,
     n_emb: int = 256,
-    **kwargs
+    **kwargs,
 ) -> Transformer1D:
-    """
-    Convenience function for building a Transformer1D model.
-
-    Args:
-        input_dim: input action dimension
-        output_dim: output action dimension
-        horizon: prediction horizon
-        n_obs_steps: number of observation steps
-        cond_dim: conditioning dimension
-        n_layer: number of transformer layers
-        n_head: number of attention heads
-        n_emb: embedding dimension
-        **kwargs: additional arguments
-
-    Returns:
-        a Transformer1D model
-    """
-    model = Transformer1D(
+    return Transformer1D(
         input_dim=input_dim,
         output_dim=output_dim,
         horizon=horizon,
@@ -350,47 +333,5 @@ def create_transformer1d(
         n_layer=n_layer,
         n_head=n_head,
         n_emb=n_emb,
-        **kwargs
+        **kwargs,
     )
-    return model
-
-
-if __name__ == "__main__":
-    print("=" * 80)
-    print("Testing Transformer1D")
-    print("=" * 80)
-
-    # configuration
-    B = 4
-    T = 16
-    action_dim = 16
-    obs_horizon = 2
-    cond_dim = 416  # vision + state feature dimension
-
-    # build the model
-    model = Transformer1D(
-        input_dim=action_dim,
-        output_dim=action_dim,
-        horizon=T,
-        n_obs_steps=obs_horizon,
-        cond_dim=cond_dim,
-        n_layer=4,
-        n_head=8,
-        n_emb=256,
-        causal_attn=False
-    )
-
-    # test the forward pass
-    sample = torch.randn(B, T, action_dim)
-    timestep = torch.randint(0, 100, (B,))
-    cond = torch.randn(B, obs_horizon, cond_dim)
-
-    output = model(sample, timestep, cond)
-
-    print(f"\nInputs:")
-    print(f"  sample: {sample.shape}")
-    print(f"  timestep: {timestep.shape}")
-    print(f"  cond: {cond.shape}")
-    print(f"\nOutputs:")
-    print(f"  output: {output.shape}")
-    print(f"\n✅ Test passed!")
```
```diff
@@ -16,3 +16,23 @@ class IdentityActionEncoder(nn.Module):
 
     def forward(self, action):
         return action
+
+
+class LeWMStateEncoder(nn.Module):
+    def __init__(
+        self,
+        input_dim: int = 16,
+        hidden_dim: int = 256,
+        output_dim: int = 64,
+    ):
+        super().__init__()
+        self.output_dim = int(output_dim)
+        self.net = nn.Sequential(
+            nn.Linear(int(input_dim), int(hidden_dim)),
+            nn.LayerNorm(int(hidden_dim)),
+            nn.GELU(),
+            nn.Linear(int(hidden_dim), self.output_dim),
+        )
+
+    def forward(self, state):
+        return self.net(state)
```
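The `LeWMStateEncoder` added above is a plain MLP. A standalone sketch of the same Linear → LayerNorm → GELU → Linear stack with its default sizes (16-dim state in, 64-dim features out; the batch size here is illustrative):

```python
import torch
import torch.nn as nn

encoder = nn.Sequential(
    nn.Linear(16, 256),
    nn.LayerNorm(256),
    nn.GELU(),
    nn.Linear(256, 64),
)
state = torch.randn(8, 16)   # batch of 8 proprioceptive states
features = encoder(state)
print(features.shape)  # torch.Size([8, 64])
```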