# LEWM ViT Backbone Replacement Design ## Goal 将当前 roboimi VLA policy 中的 ResNet 视觉编码器替换为来自 LEWM checkpoint 的冻结 ViT 视觉编码器(encoder + projector),仅使用最终 CLS token 的 192 维 embedding 作为视觉特征。 ## User constraints - 使用 `/home/droid/下载/lewm_sim_transfer_checkpoint_usage.md` 中确认的训练好 checkpoint - 只使用视觉编码部分:`encoder + projector` - 权重冻结 - 维持“视觉特征 + state 拼接,再送入 diffusion transformer”这一总体处理方式 - 输入使用三视角:`[r_vis, top, front]` - 在 5880 机器上启动两个训练:`embed=384/layer=12` 和 `embed=256/layer=12` - `pred_horizon=16` - `num_action_steps=8` - 每个训练 `50k` steps - rollout 验证每次用 `10` 个 episodes,不是之前的 `5` ## Trusted existing facts 1. LEWM checkpoint 路径: - `/home/droid/le-wm/lewm-sim-transfer/pa1w85md8jop6bvol8oxp/checkpoints/epoch=99-step=47800.ckpt` 2. 需要加载的 state_dict 前缀: - `model.encoder.*` - `model.projector.*` 3. LEWM ViT 配置: - encoder scale: `tiny` - hidden size: `192` - layers: `12` - attention heads: `3` - patch size: `14` - projector: `MLP(192 -> 2048 -> 192)` with `BatchNorm1d + GELU` 4. LEWM 训练时三视角先拼成单图,再送入单个 ViT encoder;输出整体视觉 embedding 是 **192 维**。 ## Key design decision ### Chosen design: fuse 3 cameras into one LEWM-style image, output one 192-d visual vector per timestep 不是把 LEWM ViT 当成“每相机一个 192-d encoder”,而是按 LEWM 原训练方式: - 输入三视角图像字典 `{r_vis, top, front}` - 按固定顺序拼成一张 fused image - 走单个 frozen ViT + projector - 得到一个 **192 维总视觉特征** ### Why this is the right replacement 当前 ResNet backbone 对外给到 policy head 的**总视觉特征维度**是: - 每相机 `64` - 三相机总计 `192` 而 LEWM checkpoint 输出的 CLS/projector embedding 也是: - 总计 `192` 因此,最自然的“直接平替当前 ResNet 视觉编码器”的方式是: - 用 LEWM backbone 直接产出一个 192-d 总视觉向量 - 后续和 state `16-d` 拼接后,依旧得到 `208-d` 条件向量 - 不改 diffusion head 的总体接口和语义 ## Interface compatibility plan 现有 `VLAAgent` 假设 backbone 暴露: - `camera_names` - `num_cameras` - `output_dim`(语义上是“每相机特征维度”) - `forward(images_dict) -> (B, T, total_visual_dim)` 为了最小改动兼容现有 agent: - 新 LEWM backbone 的 `forward()` 返回 `(B, T, 192)` - `camera_names = ('r_vis', 'top', 'front')` - `num_cameras = 3` - `output_dim = 64` 这样 `VLAAgent` 内部仍会计算: - `per_step_cond_dim = output_dim * num_cams + obs_dim = 64*3 + 16 = 208` 与实际 `forward()` 输出的 `192 + 16 = 208` 保持一致。 > 也就是说:`output_dim` 在这个 backbone 里保留为“与旧 ResNet 总特征等价的单相机占位维度”,而不是“真实 projector 输出维度”。这是一个兼容性 shim,用来避免改 agent 主逻辑。 ## Image preprocessing design 当前 roboimi dataset 已经把每个相机图像读成: - `(C, 224, 224)` - 值域 `[0, 1]` 新 LEWM backbone 将: 1. 按顺序取 `r_vis`, `top`, `front` 2. 在宽度方向拼接,得到 fused image: - `(C, 224, 672)` 3. 使用 LEWM 一致的 ImageNet normalize: - mean `[0.485, 0.456, 0.406]` - std `[0.229, 0.224, 0.225]` 4. 调用 `ViTModel(..., interpolate_pos_encoding=True)` 5. 取 `last_hidden_state[:, 0]` 6. 送入 frozen projector,得到 `(B*T, 192)` ## Files to create / modify ### New files - `roboimi/vla/models/backbones/lewm_vit_backbone.py` - `roboimi/vla/conf/backbone/lewm_vit_diffusion.yaml` - `roboimi/vla/conf/agent/lewm_imf_attnres.yaml` - `tests/test_lewm_vit_backbone.py` ### Modified files - `roboimi/vla/models/backbones/__init__`(如果需要导出) - `tests/test_imf_vla_agent.py`(增加新 backbone 集成用例) - `roboimi/demos/vla_scripts/train_vla.py`(如需仅调整 rollout 默认/日志;如果命令覆盖足够,则尽量不改主逻辑) - 训练/实验 suite 文档(新增本次 LEWM ViT 训练记录) ## Testing plan 1. **Unit test: load + forward** - 用 synthetic checkpoint 验证新 backbone 能正确加载 `model.encoder.*` 与 `model.projector.*` - 输入 3 相机 `(B,T,C,224,224)` - 输出 `(B,T,192)` 2. **Agent integration test** - backbone.output_dim=64, num_cameras=3 - agent `_build_cond()` 输出最后维度为 `208` 3. **Remote smoke test on 5880** - 使用真实 checkpoint - `max_steps=2` - 两个实验各自 smoke 一次 4. **Full run** - GPU0: `embed=384, layer=12` - GPU1: `embed=256, layer=12` - `rollout_num_episodes=10` ## Training launch contract - host: `100.73.14.65` - code dir: `/home/droid/roboimi_suite_20260404` - python: `/home/droid/miniforge3/envs/roboimi/bin/python` - dataset: `/home/droid/sim_dataset/sim_transfer` - cameras: `[r_vis, top, front]` - agent: new `lewm_imf_attnres` - max_steps: `50000` - rollout every `5` epochs - rollout episodes: `10` ## Risks 1. LEWM 训练时的 fused image 预处理如果方向实现错了(224x672 vs 672x224),会导致分布偏移。 2. 当前 roboimi env 需确保安装 `transformers`;从 `environment.yml` 看本地已有该依赖,但远端训练环境要 smoke 确认。 3. 因为这是 frozen ViT + projector,若 projector BN 仍保持 train 模式,统计量会漂移,所以必须整体 `eval()` 并冻结。 ## Recommended first implementation path - 先实现一个独立 `LEWMViTBackbone` 类,不改现有 `ResNetDiffusionBackbone` 主逻辑。 - 再通过新的 hydra backbone/agent 配置接入。 - 优先做到“最少侵入 + smoke 可跑 + 远端可训”。