Files
roboimi/docs/superpowers/specs/2026-04-05-lewm-vit-backbone-design.md

5.7 KiB
Raw Blame History

LEWM ViT Backbone Replacement Design

Goal

将当前 roboimi VLA policy 中的 ResNet 视觉编码器替换为来自 LEWM checkpoint 的冻结 ViT 视觉编码器encoder + projector仅使用最终 CLS token 的 192 维 embedding 作为视觉特征。

User constraints

  • 使用 /home/droid/下载/lewm_sim_transfer_checkpoint_usage.md 中确认的训练好 checkpoint
  • 只使用视觉编码部分:encoder + projector
  • 权重冻结
  • 维持“视觉特征 + state 拼接,再送入 diffusion transformer”这一总体处理方式
  • 输入使用三视角:[r_vis, top, front]
  • 在 5880 机器上启动两个训练:embed=384/layer=12embed=256/layer=12
  • pred_horizon=16
  • num_action_steps=8
  • 每个训练 50k steps
  • rollout 验证每次用 10 个 episodes不是之前的 5

Trusted existing facts

  1. LEWM checkpoint 路径:
    • /home/droid/le-wm/lewm-sim-transfer/pa1w85md8jop6bvol8oxp/checkpoints/epoch=99-step=47800.ckpt
  2. 需要加载的 state_dict 前缀:
    • model.encoder.*
    • model.projector.*
  3. LEWM ViT 配置:
    • encoder scale: tiny
    • hidden size: 192
    • layers: 12
    • attention heads: 3
    • patch size: 14
    • projector: MLP(192 -> 2048 -> 192) with BatchNorm1d + GELU
  4. LEWM 训练时三视角先拼成单图,再送入单个 ViT encoder输出整体视觉 embedding 是 192 维

Key design decision

Chosen design: fuse 3 cameras into one LEWM-style image, output one 192-d visual vector per timestep

不是把 LEWM ViT 当成“每相机一个 192-d encoder”而是按 LEWM 原训练方式:

  • 输入三视角图像字典 {r_vis, top, front}
  • 按固定顺序拼成一张 fused image
  • 走单个 frozen ViT + projector
  • 得到一个 192 维总视觉特征

Why this is the right replacement

当前 ResNet backbone 对外给到 policy head 的总视觉特征维度是:

  • 每相机 64
  • 三相机总计 192

而 LEWM checkpoint 输出的 CLS/projector embedding 也是:

  • 总计 192

因此,最自然的“直接平替当前 ResNet 视觉编码器”的方式是:

  • 用 LEWM backbone 直接产出一个 192-d 总视觉向量
  • 后续和 state 16-d 拼接后,依旧得到 208-d 条件向量
  • 不改 diffusion head 的总体接口和语义

Interface compatibility plan

现有 VLAAgent 假设 backbone 暴露:

  • camera_names
  • num_cameras
  • output_dim(语义上是“每相机特征维度”)
  • forward(images_dict) -> (B, T, total_visual_dim)

为了最小改动兼容现有 agent

  • 新 LEWM backbone 的 forward() 返回 (B, T, 192)
  • camera_names = ('r_vis', 'top', 'front')
  • num_cameras = 3
  • output_dim = 64

这样 VLAAgent 内部仍会计算:

  • per_step_cond_dim = output_dim * num_cams + obs_dim = 64*3 + 16 = 208 与实际 forward() 输出的 192 + 16 = 208 保持一致。

也就是说:output_dim 在这个 backbone 里保留为“与旧 ResNet 总特征等价的单相机占位维度”,而不是“真实 projector 输出维度”。这是一个兼容性 shim用来避免改 agent 主逻辑。

Image preprocessing design

当前 roboimi dataset 已经把每个相机图像读成:

  • (C, 224, 224)
  • 值域 [0, 1]

新 LEWM backbone 将:

  1. 按顺序取 r_vis, top, front
  2. 在宽度方向拼接,得到 fused image
    • (C, 224, 672)
  3. 使用 LEWM 一致的 ImageNet normalize
    • mean [0.485, 0.456, 0.406]
    • std [0.229, 0.224, 0.225]
  4. 调用 ViTModel(..., interpolate_pos_encoding=True)
  5. last_hidden_state[:, 0]
  6. 送入 frozen projector得到 (B*T, 192)

Files to create / modify

New files

  • roboimi/vla/models/backbones/lewm_vit_backbone.py
  • roboimi/vla/conf/backbone/lewm_vit_diffusion.yaml
  • roboimi/vla/conf/agent/lewm_imf_attnres.yaml
  • tests/test_lewm_vit_backbone.py

Modified files

  • roboimi/vla/models/backbones/__init__(如果需要导出)
  • tests/test_imf_vla_agent.py(增加新 backbone 集成用例)
  • roboimi/demos/vla_scripts/train_vla.py(如需仅调整 rollout 默认/日志;如果命令覆盖足够,则尽量不改主逻辑)
  • 训练/实验 suite 文档(新增本次 LEWM ViT 训练记录)

Testing plan

  1. Unit test: load + forward
    • 用 synthetic checkpoint 验证新 backbone 能正确加载 model.encoder.*model.projector.*
    • 输入 3 相机 (B,T,C,224,224)
    • 输出 (B,T,192)
  2. Agent integration test
    • backbone.output_dim=64, num_cameras=3
    • agent _build_cond() 输出最后维度为 208
  3. Remote smoke test on 5880
    • 使用真实 checkpoint
    • max_steps=2
    • 两个实验各自 smoke 一次
  4. Full run
    • GPU0: embed=384, layer=12
    • GPU1: embed=256, layer=12
    • rollout_num_episodes=10

Training launch contract

  • host: 100.73.14.65
  • code dir: /home/droid/roboimi_suite_20260404
  • python: /home/droid/miniforge3/envs/roboimi/bin/python
  • dataset: /home/droid/sim_dataset/sim_transfer
  • cameras: [r_vis, top, front]
  • agent: new lewm_imf_attnres
  • max_steps: 50000
  • rollout every 5 epochs
  • rollout episodes: 10

Risks

  1. LEWM 训练时的 fused image 预处理如果方向实现错了224x672 vs 672x224会导致分布偏移。
  2. 当前 roboimi env 需确保安装 transformers;从 environment.yml 看本地已有该依赖,但远端训练环境要 smoke 确认。
  3. 因为这是 frozen ViT + projector若 projector BN 仍保持 train 模式,统计量会漂移,所以必须整体 eval() 并冻结。
  • 先实现一个独立 LEWMViTBackbone 类,不改现有 ResNetDiffusionBackbone 主逻辑。
  • 再通过新的 hydra backbone/agent 配置接入。
  • 优先做到“最少侵入 + smoke 可跑 + 远端可训”。