Files
roboimi/docs/superpowers/specs/2026-04-05-lewm-vit-backbone-design.md

139 lines
5.7 KiB
Markdown
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# LEWM ViT Backbone Replacement Design
## Goal
将当前 roboimi VLA policy 中的 ResNet 视觉编码器替换为来自 LEWM checkpoint 的冻结 ViT 视觉编码器encoder + projector仅使用最终 CLS token 的 192 维 embedding 作为视觉特征。
## User constraints
- 使用 `/home/droid/下载/lewm_sim_transfer_checkpoint_usage.md` 中确认的训练好 checkpoint
- 只使用视觉编码部分:`encoder + projector`
- 权重冻结
- 维持“视觉特征 + state 拼接,再送入 diffusion transformer”这一总体处理方式
- 输入使用三视角:`[r_vis, top, front]`
- 在 5880 机器上启动两个训练:`embed=384/layer=12``embed=256/layer=12`
- `pred_horizon=16`
- `num_action_steps=8`
- 每个训练 `50k` steps
- rollout 验证每次用 `10` 个 episodes不是之前的 `5`
## Trusted existing facts
1. LEWM checkpoint 路径:
- `/home/droid/le-wm/lewm-sim-transfer/pa1w85md8jop6bvol8oxp/checkpoints/epoch=99-step=47800.ckpt`
2. 需要加载的 state_dict 前缀:
- `model.encoder.*`
- `model.projector.*`
3. LEWM ViT 配置:
- encoder scale: `tiny`
- hidden size: `192`
- layers: `12`
- attention heads: `3`
- patch size: `14`
- projector: `MLP(192 -> 2048 -> 192)` with `BatchNorm1d + GELU`
4. LEWM 训练时三视角先拼成单图,再送入单个 ViT encoder输出整体视觉 embedding 是 **192 维**
## Key design decision
### Chosen design: fuse 3 cameras into one LEWM-style image, output one 192-d visual vector per timestep
不是把 LEWM ViT 当成“每相机一个 192-d encoder”而是按 LEWM 原训练方式:
- 输入三视角图像字典 `{r_vis, top, front}`
- 按固定顺序拼成一张 fused image
- 走单个 frozen ViT + projector
- 得到一个 **192 维总视觉特征**
### Why this is the right replacement
当前 ResNet backbone 对外给到 policy head 的**总视觉特征维度**是:
- 每相机 `64`
- 三相机总计 `192`
而 LEWM checkpoint 输出的 CLS/projector embedding 也是:
- 总计 `192`
因此,最自然的“直接平替当前 ResNet 视觉编码器”的方式是:
- 用 LEWM backbone 直接产出一个 192-d 总视觉向量
- 后续和 state `16-d` 拼接后,依旧得到 `208-d` 条件向量
- 不改 diffusion head 的总体接口和语义
## Interface compatibility plan
现有 `VLAAgent` 假设 backbone 暴露:
- `camera_names`
- `num_cameras`
- `output_dim`(语义上是“每相机特征维度”)
- `forward(images_dict) -> (B, T, total_visual_dim)`
为了最小改动兼容现有 agent
- 新 LEWM backbone 的 `forward()` 返回 `(B, T, 192)`
- `camera_names = ('r_vis', 'top', 'front')`
- `num_cameras = 3`
- `output_dim = 64`
这样 `VLAAgent` 内部仍会计算:
- `per_step_cond_dim = output_dim * num_cams + obs_dim = 64*3 + 16 = 208`
与实际 `forward()` 输出的 `192 + 16 = 208` 保持一致。
> 也就是说:`output_dim` 在这个 backbone 里保留为“与旧 ResNet 总特征等价的单相机占位维度”,而不是“真实 projector 输出维度”。这是一个兼容性 shim用来避免改 agent 主逻辑。
## Image preprocessing design
当前 roboimi dataset 已经把每个相机图像读成:
- `(C, 224, 224)`
- 值域 `[0, 1]`
新 LEWM backbone 将:
1. 按顺序取 `r_vis`, `top`, `front`
2. 在宽度方向拼接,得到 fused image
- `(C, 224, 672)`
3. 使用 LEWM 一致的 ImageNet normalize
- mean `[0.485, 0.456, 0.406]`
- std `[0.229, 0.224, 0.225]`
4. 调用 `ViTModel(..., interpolate_pos_encoding=True)`
5.`last_hidden_state[:, 0]`
6. 送入 frozen projector得到 `(B*T, 192)`
## Files to create / modify
### New files
- `roboimi/vla/models/backbones/lewm_vit_backbone.py`
- `roboimi/vla/conf/backbone/lewm_vit_diffusion.yaml`
- `roboimi/vla/conf/agent/lewm_imf_attnres.yaml`
- `tests/test_lewm_vit_backbone.py`
### Modified files
- `roboimi/vla/models/backbones/__init__`(如果需要导出)
- `tests/test_imf_vla_agent.py`(增加新 backbone 集成用例)
- `roboimi/demos/vla_scripts/train_vla.py`(如需仅调整 rollout 默认/日志;如果命令覆盖足够,则尽量不改主逻辑)
- 训练/实验 suite 文档(新增本次 LEWM ViT 训练记录)
## Testing plan
1. **Unit test: load + forward**
- 用 synthetic checkpoint 验证新 backbone 能正确加载 `model.encoder.*``model.projector.*`
- 输入 3 相机 `(B,T,C,224,224)`
- 输出 `(B,T,192)`
2. **Agent integration test**
- backbone.output_dim=64, num_cameras=3
- agent `_build_cond()` 输出最后维度为 `208`
3. **Remote smoke test on 5880**
- 使用真实 checkpoint
- `max_steps=2`
- 两个实验各自 smoke 一次
4. **Full run**
- GPU0: `embed=384, layer=12`
- GPU1: `embed=256, layer=12`
- `rollout_num_episodes=10`
## Training launch contract
- host: `100.73.14.65`
- code dir: `/home/droid/roboimi_suite_20260404`
- python: `/home/droid/miniforge3/envs/roboimi/bin/python`
- dataset: `/home/droid/sim_dataset/sim_transfer`
- cameras: `[r_vis, top, front]`
- agent: new `lewm_imf_attnres`
- max_steps: `50000`
- rollout every `5` epochs
- rollout episodes: `10`
## Risks
1. LEWM 训练时的 fused image 预处理如果方向实现错了224x672 vs 672x224会导致分布偏移。
2. 当前 roboimi env 需确保安装 `transformers`;从 `environment.yml` 看本地已有该依赖,但远端训练环境要 smoke 确认。
3. 因为这是 frozen ViT + projector若 projector BN 仍保持 train 模式,统计量会漂移,所以必须整体 `eval()` 并冻结。
## Recommended first implementation path
- 先实现一个独立 `LEWMViTBackbone` 类,不改现有 `ResNetDiffusionBackbone` 主逻辑。
- 再通过新的 hydra backbone/agent 配置接入。
- 优先做到“最少侵入 + smoke 可跑 + 远端可训”。