5.7 KiB
5.7 KiB
LEWM ViT Backbone Replacement Design
Goal
将当前 roboimi VLA policy 中的 ResNet 视觉编码器替换为来自 LEWM checkpoint 的冻结 ViT 视觉编码器(encoder + projector),仅使用最终 CLS token 的 192 维 embedding 作为视觉特征。
User constraints
- 使用
/home/droid/下载/lewm_sim_transfer_checkpoint_usage.md中确认的训练好 checkpoint - 只使用视觉编码部分:
encoder + projector - 权重冻结
- 维持“视觉特征 + state 拼接,再送入 diffusion transformer”这一总体处理方式
- 输入使用三视角:
[r_vis, top, front] - 在 5880 机器上启动两个训练:
embed=384/layer=12和embed=256/layer=12 pred_horizon=16num_action_steps=8- 每个训练
50ksteps - rollout 验证每次用
10个 episodes,不是之前的5
Trusted existing facts
- LEWM checkpoint 路径:
/home/droid/le-wm/lewm-sim-transfer/pa1w85md8jop6bvol8oxp/checkpoints/epoch=99-step=47800.ckpt
- 需要加载的 state_dict 前缀:
model.encoder.*model.projector.*
- LEWM ViT 配置:
- encoder scale:
tiny - hidden size:
192 - layers:
12 - attention heads:
3 - patch size:
14 - projector:
MLP(192 -> 2048 -> 192)withBatchNorm1d + GELU
- encoder scale:
- LEWM 训练时三视角先拼成单图,再送入单个 ViT encoder;输出整体视觉 embedding 是 192 维。
Key design decision
Chosen design: fuse 3 cameras into one LEWM-style image, output one 192-d visual vector per timestep
不是把 LEWM ViT 当成“每相机一个 192-d encoder”,而是按 LEWM 原训练方式:
- 输入三视角图像字典
{r_vis, top, front} - 按固定顺序拼成一张 fused image
- 走单个 frozen ViT + projector
- 得到一个 192 维总视觉特征
Why this is the right replacement
当前 ResNet backbone 对外给到 policy head 的总视觉特征维度是:
- 每相机
64 - 三相机总计
192
而 LEWM checkpoint 输出的 CLS/projector embedding 也是:
- 总计
192
因此,最自然的“直接平替当前 ResNet 视觉编码器”的方式是:
- 用 LEWM backbone 直接产出一个 192-d 总视觉向量
- 后续和 state
16-d拼接后,依旧得到208-d条件向量 - 不改 diffusion head 的总体接口和语义
Interface compatibility plan
现有 VLAAgent 假设 backbone 暴露:
camera_namesnum_camerasoutput_dim(语义上是“每相机特征维度”)forward(images_dict) -> (B, T, total_visual_dim)
为了最小改动兼容现有 agent:
- 新 LEWM backbone 的
forward()返回(B, T, 192) camera_names = ('r_vis', 'top', 'front')num_cameras = 3output_dim = 64
这样 VLAAgent 内部仍会计算:
per_step_cond_dim = output_dim * num_cams + obs_dim = 64*3 + 16 = 208与实际forward()输出的192 + 16 = 208保持一致。
也就是说:
output_dim在这个 backbone 里保留为“与旧 ResNet 总特征等价的单相机占位维度”,而不是“真实 projector 输出维度”。这是一个兼容性 shim,用来避免改 agent 主逻辑。
Image preprocessing design
当前 roboimi dataset 已经把每个相机图像读成:
(C, 224, 224)- 值域
[0, 1]
新 LEWM backbone 将:
- 按顺序取
r_vis,top,front - 在宽度方向拼接,得到 fused image:
(C, 224, 672)
- 使用 LEWM 一致的 ImageNet normalize:
- mean
[0.485, 0.456, 0.406] - std
[0.229, 0.224, 0.225]
- mean
- 调用
ViTModel(..., interpolate_pos_encoding=True) - 取
last_hidden_state[:, 0] - 送入 frozen projector,得到
(B*T, 192)
Files to create / modify
New files
roboimi/vla/models/backbones/lewm_vit_backbone.pyroboimi/vla/conf/backbone/lewm_vit_diffusion.yamlroboimi/vla/conf/agent/lewm_imf_attnres.yamltests/test_lewm_vit_backbone.py
Modified files
roboimi/vla/models/backbones/__init__(如果需要导出)tests/test_imf_vla_agent.py(增加新 backbone 集成用例)roboimi/demos/vla_scripts/train_vla.py(如需仅调整 rollout 默认/日志;如果命令覆盖足够,则尽量不改主逻辑)- 训练/实验 suite 文档(新增本次 LEWM ViT 训练记录)
Testing plan
- Unit test: load + forward
- 用 synthetic checkpoint 验证新 backbone 能正确加载
model.encoder.*与model.projector.* - 输入 3 相机
(B,T,C,224,224) - 输出
(B,T,192)
- 用 synthetic checkpoint 验证新 backbone 能正确加载
- Agent integration test
- backbone.output_dim=64, num_cameras=3
- agent
_build_cond()输出最后维度为208
- Remote smoke test on 5880
- 使用真实 checkpoint
max_steps=2- 两个实验各自 smoke 一次
- Full run
- GPU0:
embed=384, layer=12 - GPU1:
embed=256, layer=12 rollout_num_episodes=10
- GPU0:
Training launch contract
- host:
100.73.14.65 - code dir:
/home/droid/roboimi_suite_20260404 - python:
/home/droid/miniforge3/envs/roboimi/bin/python - dataset:
/home/droid/sim_dataset/sim_transfer - cameras:
[r_vis, top, front] - agent: new
lewm_imf_attnres - max_steps:
50000 - rollout every
5epochs - rollout episodes:
10
Risks
- LEWM 训练时的 fused image 预处理如果方向实现错了(224x672 vs 672x224),会导致分布偏移。
- 当前 roboimi env 需确保安装
transformers;从environment.yml看本地已有该依赖,但远端训练环境要 smoke 确认。 - 因为这是 frozen ViT + projector,若 projector BN 仍保持 train 模式,统计量会漂移,所以必须整体
eval()并冻结。
Recommended first implementation path
- 先实现一个独立
LEWMViTBackbone类,不改现有ResNetDiffusionBackbone主逻辑。 - 再通过新的 hydra backbone/agent 配置接入。
- 优先做到“最少侵入 + smoke 可跑 + 远端可训”。