feat: add vision transfer backbones and IMF variants
@@ -0,0 +1,92 @@

# LEWM ViT Backbone Implementation Plan

> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking.

**Goal:** Replace the current ResNet visual encoder in roboimi VLA training with a frozen LEWM ViT visual backbone (encoder + projector) that consumes the three camera views jointly and outputs one 192-d CLS embedding per timestep, then launch two 50k-step runs on the 5880 machine.

**Architecture:** Add a new joint-multiview LEWM backbone that fuses `front/top/r_vis` into one LEWM-style image, reproduces LEWM preprocessing, loads frozen weights from the trained checkpoint, and exposes a `joint_output_dim=192`. Add a minimal `VLAAgent` compatibility branch so conditions can be sized from the joint visual dim instead of `output_dim * num_cams`, while leaving the rest of the diffusion pipeline unchanged.

**Tech Stack:** PyTorch, transformers `ViTModel`, Hydra configs, existing roboimi VLA training/eval scripts, remote SSH/rsync to 100.73.14.65.

---

### Task 1: Add failing tests for LEWM joint-vision backbone contract

**Files:**

- Create: `tests/test_lewm_vit_backbone.py`
- Modify: `tests/test_imf_vla_agent.py`

- [ ] **Step 1: Write the failing backbone shape/load test**
- [ ] **Step 2: Run `pytest tests/test_lewm_vit_backbone.py -q` and verify it fails**
- [ ] **Step 3: Extend `tests/test_imf_vla_agent.py` with a failing joint-output backbone case**
- [ ] **Step 4: Run `pytest tests/test_imf_vla_agent.py -q` and verify it fails**

### Task 2: Implement LEWM joint-multiview frozen backbone

**Files:**

- Create: `roboimi/vla/models/backbones/lewm_vit_backbone.py`
- Modify: `roboimi/vla/models/backbones/__init__.py` only if exports are needed

- [ ] **Step 1: Create `LEWMViTBackbone` with public attrs `camera_names`, `num_cameras`, `joint_output_dim=192`**
- [ ] **Step 2: Reproduce LEWM preprocessing and joint multiview fusion**
- [ ] **Step 3: Load checkpoint weights from `model.encoder.*` and `model.projector.*`**
- [ ] **Step 4: Freeze encoder/projector and keep them in eval mode via a `train()` override**
- [ ] **Step 5: Run `pytest tests/test_lewm_vit_backbone.py -q` and verify green**
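
The freeze contract of Steps 1–4 can be sketched as below. The module internals are placeholders (the real encoder is the LEWM ViT and the real projector is the checkpoint MLP); the key detail is the `train()` override that keeps the frozen BatchNorm statistics from drifting:

```python
import torch.nn as nn

class FrozenBackboneSketch(nn.Module):
    """Sketch of the freeze/eval contract; internals are stand-ins."""
    def __init__(self, joint_output_dim: int = 192):
        super().__init__()
        self.encoder = nn.Linear(8, joint_output_dim)        # stands in for the ViT
        self.projector = nn.Sequential(                      # MLP head with BatchNorm
            nn.Linear(joint_output_dim, joint_output_dim),
            nn.BatchNorm1d(joint_output_dim),
            nn.GELU(),
        )
        self.joint_output_dim = joint_output_dim
        for p in self.parameters():
            p.requires_grad_(False)                          # freeze all weights
        self.eval()

    def train(self, mode: bool = True):
        # Ignore the requested mode: the frozen backbone must stay in eval so
        # BatchNorm running statistics never update during policy training.
        return super().train(False)
```

Even when the surrounding agent calls `.train()`, this module stays in eval mode and keeps `requires_grad=False` everywhere.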

### Task 3: Add minimal agent support for joint visual dim

**Files:**

- Modify: `roboimi/vla/agent.py`
- Test: `tests/test_imf_vla_agent.py`

- [ ] **Step 1: Add a `joint_output_dim` branch in `VLAAgent.__init__` for `per_step_cond_dim` / `global_cond_dim`**
- [ ] **Step 2: Keep `_build_cond()` semantics unchanged except for matching the new dim contract**
- [ ] **Step 3: Run `pytest tests/test_imf_vla_agent.py -q` and verify green**

### Task 4: Add Hydra configs for LEWM backbone training

**Files:**

- Create: `roboimi/vla/conf/backbone/lewm_vit_diffusion.yaml`
- Create: `roboimi/vla/conf/agent/lewm_imf_attnres.yaml`

- [ ] **Step 1: Add a backbone config pointing to the new LEWM backbone**
- [ ] **Step 2: Add an `agent=lewm_imf_attnres` config with 3 cameras and `head.cond_dim=208`**
- [ ] **Step 3: Verify Hydra instantiation with a one-shot compose smoke**
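
The key invariant the compose smoke should check, written against a minimal stand-in for the composed config (the real smoke would use `hydra.compose` plus `hydra.utils.instantiate`; the 16-d state dim is an assumption carried over from the existing agent):

```python
# Stand-in values for the composed config; head.cond_dim=208 comes from Step 2.
cfg = {
    "backbone": {"joint_output_dim": 192},
    "agent": {"head": {"cond_dim": 208}},
    "obs_dim": 16,
}
# Joint visual dim + state dim must equal the head's condition dim.
assert cfg["backbone"]["joint_output_dim"] + cfg["obs_dim"] == cfg["agent"]["head"]["cond_dim"]
```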

### Task 5: Verify focused local tests

**Files:**

- Reuse the files above

- [ ] **Step 1: Run `pytest tests/test_lewm_vit_backbone.py tests/test_imf_vla_agent.py tests/test_eval_vla_headless_import.py -q`**
- [ ] **Step 2: If needed, run one tiny local import/forward smoke**

### Task 6: Sync to 5880 and remote smoke with real checkpoint

**Files:**

- Remote target: `/home/droid/roboimi_suite_20260404`

- [ ] **Step 1: Rsync modified source/config files to `100.73.14.65:/home/droid/roboimi_suite_20260404`**
- [ ] **Step 2: Run a 2-step smoke on GPU0 with `agent.head.n_emb=384`, `train.rollout_num_episodes=10`, real LEWM checkpoint**
- [ ] **Step 3: Run a 2-step smoke on GPU1 with `agent.head.n_emb=256`, same checkpoint**

### Task 7: Launch two real 50k runs on the 5880 machine

**Files:**

- Remote logs under `/home/droid/roboimi_suite_20260404/experiment_suite_launch_logs/`

- [ ] **Step 1: Launch embed384/layer12 on GPU0**
- [ ] **Step 2: Launch embed256/layer12 on GPU1**
- [ ] **Step 3: Ensure both use `data.camera_names=[r_vis,top,front]`, `pred_horizon=16`, `num_action_steps=8`, `train.rollout_num_episodes=10`, `max_steps=50000`**
- [ ] **Step 4: Record run names, pids, log paths, SwanLab URLs**

### Task 8: Update experiment tracking docs and commit

**Files:**

- Create: `experiment_suites/2026-04-05-lewm-vit-transfer/manifest.json`
- Create: `experiment_suites/2026-04-05-lewm-vit-transfer/status.json`
- Create: `experiment_suites/2026-04-05-lewm-vit-transfer/notes.md`

- [ ] **Step 1: Record checkpoint path, frozen LEWM design, rollout=10, and both run configs**
- [ ] **Step 2: Record running status after launch**
- [ ] **Step 3: Commit implementation + docs with a focused message**
81
docs/superpowers/plans/2026-04-06-resnet-multitoken-imf.md
Normal file
@@ -0,0 +1,81 @@

# ResNet Multitoken IMF Implementation Plan

> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking.

**Goal:** Implement a standard-ResNet-18 multiview IMF variant that emits three condition tokens per obs step and launch four L20 experiments for `n_emb in {256,384}` and `n_layer in {12,16}`.

**Architecture:** The ResNet backbone will optionally return one token per camera instead of concatenating all cameras into one token. `VLAAgent` will pair each camera token with the current state, project each pair into a condition token, flatten the per-step camera tokens into one cond sequence, and feed that sequence into the existing IMF/AttnRes head.

**Tech Stack:** PyTorch, torchvision ResNet-18, Hydra, pytest, SwanLab, SSH/Tailscale.

---

### Task 1: Add failing tests for multi-token conditioning

**Files:**

- Modify: `tests/test_imf_vla_agent.py`
- Modify: `tests/test_resnet_transformer_agent_wiring.py`

- [ ] **Step 1: Add a direct agent test**
  - Stub a vision backbone returning `(B,T,3,D)` and assert `_build_cond()` yields `(B, T*3, D_cond)`.
  - Assert state is paired with each camera token, not concatenated across cameras first.
- [ ] **Step 2: Add a Hydra wiring test**
  - Instantiate a new `agent=resnet_imf_attnres_multitoken` config with small dims.
  - Assert `condition_tokens_per_step == 3`, `condition_sequence_length == obs_horizon * 3`, and that the head's `n_obs_steps` receives that sequence length.
- [ ] **Step 3: Run focused tests and verify RED**
  - `python -m pytest tests/test_imf_vla_agent.py tests/test_resnet_transformer_agent_wiring.py -q`

### Task 2: Implement multi-token ResNet conditioning path

**Files:**

- Modify: `roboimi/vla/models/backbones/resnet_diffusion.py`
- Modify: `roboimi/vla/agent.py`
- Create: `roboimi/vla/conf/agent/resnet_imf_attnres_multitoken.yaml`

- [ ] **Step 1: Extend the ResNet backbone**
  - Add an opt-in flag to return `(B,T,num_cams,D)` camera tokens instead of one concatenated `(B,T,num_cams*D)` token.
  - Keep the standard ResNet-18 vision mode; do not switch to AttnRes vision.
- [ ] **Step 2: Extend VLAAgent condition building**
  - Support visual features with rank 4 `(B,T,K,D)`.
  - Broadcast state to `(B,T,K,D_state)`, concatenate per camera, apply the projector per token, then flatten to `(B,T*K,D_cond)`.
  - Track `condition_tokens_per_step` and `condition_sequence_length`.
- [ ] **Step 3: Update transformer-head instantiation**
  - Pass `n_obs_steps=condition_sequence_length` when building transformer heads.
- [ ] **Step 4: Add the Hydra config**
  - The new agent config uses:
    - a separate ResNet-18 per camera
    - the standard residual vision trunk (`vision_backbone_mode=resnet`)
    - a condition projector output dim tied to `${agent.head.n_emb}`
    - rollout episodes `10`, `pred_horizon=16`, `num_action_steps=8`
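
The reshape contract from Step 2 can be sketched as follows; the function name and the small dims are illustrative, not the real `VLAAgent` API:

```python
import torch

def build_multitoken_cond(vis, state, projector):
    """vis: (B,T,K,D) per-camera tokens; state: (B,T,Ds) robot state."""
    B, T, K, _ = vis.shape
    # Pair the state with each camera token rather than concatenating cameras first.
    state_per_cam = state.unsqueeze(2).expand(B, T, K, state.shape[-1])
    tokens = projector(torch.cat([vis, state_per_cam], dim=-1))  # (B,T,K,Dc)
    return tokens.flatten(1, 2)                                  # (B,T*K,Dc)

B, T, K, D, Ds, Dc = 2, 2, 3, 8, 16, 32
projector = torch.nn.Linear(D + Ds, Dc)
cond = build_multitoken_cond(torch.randn(B, T, K, D), torch.randn(B, T, Ds), projector)
assert cond.shape == (B, T * K, Dc)  # obs_horizon=2 with 3 cameras -> 6 cond tokens
```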

### Task 3: Verify locally

**Files:**

- Modify only if verification reveals issues

- [ ] **Step 1: Run focused tests and make them pass**
  - `python -m pytest tests/test_imf_vla_agent.py tests/test_resnet_transformer_agent_wiring.py -q`
- [ ] **Step 2: Run the regression subset**
  - `python -m pytest tests/test_eval_vla_headless.py tests/test_train_vla_rollout_validation.py tests/test_simple_robot_dataset_image_loading.py -q`
- [ ] **Step 3: Run a local smoke instantiation**
  - Instantiate the new Hydra config and verify cond shape / sequence length.

### Task 4: Launch 4 L20 experiments

**Files:**

- Remote repo copy under `/home/droid/roboimi_suite_20260404`

- [ ] **Step 1: Sync code to `100.119.99.14`**
- [ ] **Step 2: Smoke the new config on the remote host**
- [ ] **Step 3: Launch runs**
  - `(n_emb=256, n_layer=12)`
  - `(n_emb=256, n_layer=16)`
  - `(n_emb=384, n_layer=12)`
  - `(n_emb=384, n_layer=16)`
- [ ] **Step 4: Keep fixed across runs**
  - rollout episodes `10`
  - `pred_horizon=16`
  - `num_action_steps=8`
  - standard ResNet-18 vision trunk
  - three separate per-camera weights
- [ ] **Step 5: Record PIDs, GPUs, log paths, SwanLab URLs**
78
docs/superpowers/plans/2026-04-06-siglip2-multiview-vla.md
Normal file
@@ -0,0 +1,78 @@

# SigLIP2 Multiview VLA Implementation Plan

> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking.

**Goal:** Integrate a frozen shared SigLIP2 multiview encoder into the IMF/AttnRes policy, preserve raw-256 image handling, and launch two 50k-step experiments on the 5880 host with per-view projection dims 96 and 192.

**Architecture:** A new backbone will independently encode each camera view with SigLIP2 and project each 768-d pooled feature to a configurable per-view dimension. `VLAAgent` will concatenate visual features with robot state, then optionally project the combined per-step condition to the head's required 384-d interface before diffusion training/inference.

**Tech Stack:** PyTorch, transformers SigLIP2, Hydra, pytest, SSH/Tailscale, SwanLab.

---

### Task 1: Add failing tests for SigLIP2 backbone and projected conditioning

**Files:**

- Create: `tests/test_siglip2_diffusion_backbone.py`
- Modify: `tests/test_imf_vla_agent.py`

- [ ] **Step 1: Write failing backbone tests**
  - Instantiate the new backbone with a stub SigLIP2 vision model.
  - Assert the raw dataset resize is `None`, the eval resize is `(256, 256)`, and the output shape is `(B, T, 3 * per_view_output_dim)`.
  - Assert the three views are encoded independently and projected.
- [ ] **Step 2: Run focused tests and verify RED**
  - Run `pytest tests/test_siglip2_diffusion_backbone.py tests/test_imf_vla_agent.py -q`.
  - Expect failure because the backbone/config/projector do not exist yet.
- [ ] **Step 3: Extend agent wiring tests**
  - Add a Hydra/instantiate test for a new SigLIP2 IMF config.
  - Assert the raw condition dim is `3 * per_view_output_dim + obs_dim`, the projected cond dim is `384`, and the head's `cond_dim == 384`.
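
Written out with a stand-in projector (the per-view dim of 96 and a 16-d state are assumptions used here for illustration):

```python
import torch

per_view, obs_dim, cond_dim = 96, 16, 384
raw_dim = 3 * per_view + obs_dim                     # raw condition dim: 3*96 + 16 = 304
cond_projector = torch.nn.Linear(raw_dim, cond_dim)  # the optional condition projector

step_cond = cond_projector(torch.randn(2, raw_dim))  # one obs step, batch of 2
assert raw_dim == 304
assert step_cond.shape[-1] == cond_dim == 384        # the head sees cond_dim == 384
```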

### Task 2: Implement SigLIP2 backbone and optional condition projector

**Files:**

- Create: `roboimi/vla/models/backbones/siglip2_diffusion_backbone.py`
- Create: `roboimi/vla/conf/backbone/siglip2_diffusion.yaml`
- Create: `roboimi/vla/conf/agent/siglip2_imf_attnres.yaml`
- Create: `roboimi/vla/conf/modules/linear_condition_projector.yaml`
- Modify: `roboimi/vla/models/backbones/__init__.py`
- Modify: `roboimi/vla/agent.py`

- [ ] **Step 1: Implement the backbone**
  - Load `SiglipVisionModel.from_pretrained("google/siglip2-base-patch16-256")`.
  - Normalize `[0,1]` pixels with mean/std `0.5` and encode each view independently.
  - Project each 768-d pooled feature to the configurable per-view dim and concatenate across cameras.
- [ ] **Step 2: Implement the optional condition projector**
  - Allow `VLAAgent` to accept a `cond_projector`.
  - Track `raw_per_step_cond_dim` and the projected `per_step_cond_dim` / `global_cond_dim`.
  - Apply the projector in `_build_cond()` after visual+state concatenation.
- [ ] **Step 3: Add Hydra configs**
  - The new agent config should default to `n_emb=384`, `n_layer=12`, `pred_horizon=16`, `num_action_steps=8`, `head.cond_dim=384`.
  - The backbone config should set `dataset_image_resize_shape: null` and `eval_image_resize_shape: [256, 256]`.
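
A shape-level sketch of Step 1, with a stub in place of `SiglipVisionModel` (the stub mimics only the 768-d `pooler_output` the plan relies on; the class name is illustrative):

```python
import torch
import torch.nn as nn

class MultiviewSiglipSketch(nn.Module):
    """Shared-encoder, per-view-projection sketch; `vision_model` stands in
    for the real SigLIP2 encoder and returns (N, 768) pooled features."""
    def __init__(self, vision_model, per_view_output_dim: int):
        super().__init__()
        self.vision_model = vision_model
        self.proj = nn.Linear(768, per_view_output_dim)

    def forward(self, views):  # views: list of (B, T, C, H, W), one per camera
        feats = []
        for v in views:
            B, T = v.shape[:2]
            # Normalize [0,1] pixels with mean/std 0.5, as the plan specifies.
            pixels = (v.flatten(0, 1) - 0.5) / 0.5
            pooled = self.vision_model(pixels)           # (B*T, 768) in this stub
            feats.append(self.proj(pooled).view(B, T, -1))
        return torch.cat(feats, dim=-1)                  # (B, T, 3 * per_view_dim)

stub = lambda x: torch.zeros(x.shape[0], 768)            # stand-in encoder
bb = MultiviewSiglipSketch(stub, per_view_output_dim=96)
out = bb([torch.rand(2, 4, 3, 256, 256) for _ in range(3)])
assert out.shape == (2, 4, 3 * 96)
```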

### Task 3: Verify locally and prepare remote execution

**Files:**

- Modify as needed only if tests/smoke reveal issues

- [ ] **Step 1: Run focused tests and make them pass**
  - `pytest tests/test_siglip2_diffusion_backbone.py tests/test_imf_vla_agent.py tests/test_eval_vla_headless.py tests/test_train_vla_rollout_validation.py tests/test_simple_robot_dataset_image_loading.py -q`
- [ ] **Step 2: Run a local smoke instantiation**
  - Instantiate the new Hydra config with stubbed optional modules or offline-safe monkeypatching.
- [ ] **Step 3: Review diffs for unintended LEWM/raw256 regressions**

### Task 4: Sync to 5880 and launch experiments

**Files:**

- Remote repo copy under `/home/droid/roboimi_suite_20260404`

- [ ] **Step 1: Stop superseded remote jobs**
- [ ] **Step 2: Sync updated code to the remote host**
  - Prefer `rsync` or `git push/pull` without overwriting unrelated files.
- [ ] **Step 3: Remote smoke test**
  - Confirm the SigLIP2 model download/import works in `/home/droid/miniforge3/envs/roboimi/bin/python`.
  - Confirm the headless rollout path still uses the `256x256` eval resize.
- [ ] **Step 4: Launch experiment A**
  - `per_view_output_dim=96`, `embed=384`, `layer=12`, `pred=16`, `exec=8`, `steps=50000`.
- [ ] **Step 5: Launch experiment B**
  - `per_view_output_dim=192`, same other hyperparameters.
- [ ] **Step 6: Record PIDs, GPUs, log paths, and SwanLab run URLs**
138
docs/superpowers/specs/2026-04-05-lewm-vit-backbone-design.md
Normal file
@@ -0,0 +1,138 @@

# LEWM ViT Backbone Replacement Design

## Goal

Replace the ResNet visual encoder in the current roboimi VLA policy with the frozen ViT visual encoder (encoder + projector) from the LEWM checkpoint, using only the final CLS token's 192-d embedding as the visual feature.

## User constraints

- Use the trained checkpoint confirmed in `/home/droid/下载/lewm_sim_transfer_checkpoint_usage.md`
- Use only the visual encoding part: `encoder + projector`
- Keep the weights frozen
- Preserve the overall pipeline of "concatenate visual features with state, then feed into the diffusion transformer"
- Use three camera views as input: `[r_vis, top, front]`
- Launch two training runs on the 5880 machine: `embed=384/layer=12` and `embed=256/layer=12`
- `pred_horizon=16`
- `num_action_steps=8`
- `50k` steps per run
- Rollout validation uses `10` episodes per evaluation, not the previous `5`

## Trusted existing facts

1. LEWM checkpoint path:
   - `/home/droid/le-wm/lewm-sim-transfer/pa1w85md8jop6bvol8oxp/checkpoints/epoch=99-step=47800.ckpt`
2. state_dict prefixes to load:
   - `model.encoder.*`
   - `model.projector.*`
3. LEWM ViT configuration:
   - encoder scale: `tiny`
   - hidden size: `192`
   - layers: `12`
   - attention heads: `3`
   - patch size: `14`
   - projector: `MLP(192 -> 2048 -> 192)` with `BatchNorm1d + GELU`
4. During LEWM training, the three camera views were first tiled into a single image and fed through one ViT encoder; the resulting overall visual embedding is **192-d**.
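
Loading only those prefixes can be sketched with a small helper (the function name is hypothetical; it assumes a Lightning-style checkpoint whose `state_dict` keys carry the prefixes above):

```python
def split_lewm_state_dict(state):
    """Split a checkpoint state_dict into encoder/projector sub-dicts,
    stripping the `model.encoder.` / `model.projector.` prefixes."""
    def strip(prefix):
        return {k[len(prefix):]: v for k, v in state.items() if k.startswith(prefix)}
    return strip("model.encoder."), strip("model.projector.")

# Keys outside the two prefixes (e.g. the world-model head) are dropped.
state = {"model.encoder.w": 1, "model.projector.b": 2, "model.decoder.x": 3}
enc, proj = split_lewm_state_dict(state)
assert enc == {"w": 1} and proj == {"b": 2}
```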

## Key design decision

### Chosen design: fuse 3 cameras into one LEWM-style image, output one 192-d visual vector per timestep

Rather than treating the LEWM ViT as "one 192-d encoder per camera", follow the original LEWM training recipe:

- take the three-view image dict `{r_vis, top, front}`
- concatenate the views into one fused image in a fixed order
- run the single frozen ViT + projector
- obtain one **192-d overall visual feature**

### Why this is the right replacement

The **total visual feature dim** the current ResNet backbone hands to the policy head is:

- `64` per camera
- `192` total over three cameras

And the CLS/projector embedding from the LEWM checkpoint is likewise:

- `192` total

So the most natural drop-in replacement for the current ResNet visual encoder is:

- let the LEWM backbone emit one 192-d total visual vector
- concatenating it with the `16-d` state still yields a `208-d` condition vector
- the diffusion head's overall interface and semantics stay unchanged

## Interface compatibility plan

The existing `VLAAgent` assumes the backbone exposes:

- `camera_names`
- `num_cameras`
- `output_dim` (semantically the "per-camera feature dim")
- `forward(images_dict) -> (B, T, total_visual_dim)`

To stay compatible with minimal changes, the new LEWM backbone sets:

- `forward()` returns `(B, T, 192)`
- `camera_names = ('r_vis', 'top', 'front')`
- `num_cameras = 3`
- `output_dim = 64`

`VLAAgent` then still computes:

- `per_step_cond_dim = output_dim * num_cams + obs_dim = 64*3 + 16 = 208`

which matches the actual `forward()` output of `192 + 16 = 208`.

> In other words, `output_dim` in this backbone is kept as a "per-camera placeholder dim equivalent to the old ResNet total", not the real projector output dim. It is a compatibility shim that avoids touching the agent's main logic.

## Image preprocessing design

The current roboimi dataset already loads each camera image as:

- `(C, 224, 224)`
- values in `[0, 1]`

The new LEWM backbone will:

1. Take `r_vis`, `top`, `front` in order
2. Concatenate them along the width dimension into a fused image:
   - `(C, 224, 672)`
3. Apply the same ImageNet normalization LEWM uses:
   - mean `[0.485, 0.456, 0.406]`
   - std `[0.229, 0.224, 0.225]`
4. Call `ViTModel(..., interpolate_pos_encoding=True)`
5. Take `last_hidden_state[:, 0]`
6. Feed it through the frozen projector to obtain `(B*T, 192)`
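
Steps 1–3 above can be sketched as follows (the frozen ViT/projector calls of steps 4–6 are omitted; the function name is illustrative):

```python
import torch

IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

def fuse_and_normalize(images):
    """images: dict of camera name -> (B*T, 3, 224, 224) tensors in [0, 1]."""
    # Fixed camera order, concatenated along the width dimension.
    fused = torch.cat([images[k] for k in ("r_vis", "top", "front")], dim=-1)
    return (fused - IMAGENET_MEAN) / IMAGENET_STD       # (B*T, 3, 224, 672)

x = fuse_and_normalize({k: torch.rand(4, 3, 224, 224) for k in ("r_vis", "top", "front")})
assert x.shape == (4, 3, 224, 672)
```

Getting the concat axis wrong here produces the 672x224 layout called out in the Risks section, so the shape assertion is worth keeping in the unit test.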

## Files to create / modify

### New files

- `roboimi/vla/models/backbones/lewm_vit_backbone.py`
- `roboimi/vla/conf/backbone/lewm_vit_diffusion.yaml`
- `roboimi/vla/conf/agent/lewm_imf_attnres.yaml`
- `tests/test_lewm_vit_backbone.py`

### Modified files

- `roboimi/vla/models/backbones/__init__.py` (if exports are needed)
- `tests/test_imf_vla_agent.py` (add an integration case for the new backbone)
- `roboimi/demos/vla_scripts/train_vla.py` (only to adjust rollout defaults/logging if needed; prefer command-line overrides over touching the main logic)
- training/experiment suite docs (record this LEWM ViT training round)

## Testing plan

1. **Unit test: load + forward**
   - Use a synthetic checkpoint to verify the new backbone correctly loads `model.encoder.*` and `model.projector.*`
   - Input: 3 cameras of `(B,T,C,224,224)`
   - Output: `(B,T,192)`
2. **Agent integration test**
   - `backbone.output_dim=64`, `num_cameras=3`
   - agent `_build_cond()` output has final dim `208`
3. **Remote smoke test on 5880**
   - use the real checkpoint
   - `max_steps=2`
   - one smoke per experiment
4. **Full run**
   - GPU0: `embed=384, layer=12`
   - GPU1: `embed=256, layer=12`
   - `rollout_num_episodes=10`

## Training launch contract

- host: `100.73.14.65`
- code dir: `/home/droid/roboimi_suite_20260404`
- python: `/home/droid/miniforge3/envs/roboimi/bin/python`
- dataset: `/home/droid/sim_dataset/sim_transfer`
- cameras: `[r_vis, top, front]`
- agent: new `lewm_imf_attnres`
- max_steps: `50000`
- rollout every `5` epochs
- rollout episodes: `10`

## Risks

1. If the fused-image preprocessing gets the orientation wrong relative to LEWM training (224x672 vs 672x224), it will cause distribution shift.
2. The roboimi env must have `transformers` installed; `environment.yml` suggests it exists locally, but the remote training env needs a smoke check to confirm.
3. Because this is a frozen ViT + projector, the projector's BatchNorm statistics would drift if it stayed in train mode, so the whole module must be frozen and kept in `eval()`.

## Recommended first implementation path

- First implement a standalone `LEWMViTBackbone` class without touching the existing `ResNetDiffusionBackbone` logic.
- Then wire it in via the new Hydra backbone/agent configs.
- Prioritize minimal intrusion, a runnable smoke, and remote trainability.
@@ -0,0 +1,32 @@

# ResNet Multitoken IMF Design

**Status:** user-specified architecture, treated as approved on 2026-04-06.

## Goal

Keep a standard ResNet-18 visual trunk (no AttnRes in vision), but change IMF conditioning from one concatenated multiview token per obs step into three camera-specific condition tokens per obs step.

## Approved architecture

- Vision trunk: standard `resnet18` residual network
- Cameras: `front`, `top`, `r_vis`
- Each camera uses its **own** ResNet-18 weights (`use_separate_rgb_encoder_per_camera=true`)
- Each camera produces one visual token
- For each obs step and each camera:
  1. take that camera's visual token
  2. concatenate the robot state
  3. project to one condition token
- The IMF input receives **3 condition tokens per obs step**, not one concatenated token
- With `obs_horizon=2`, the IMF cond sequence length becomes `2 * 3 = 6`
- The IMF head remains on the existing IMF/AttnRes implementation path
- The vision trunk remains standard ResNet; **no AttnRes vision replacement**

## Design choices

- Extend `ResNetDiffusionBackbone` with an opt-in mode that returns per-camera tokens shaped `(B, T, num_cams, D)` instead of concatenating camera features into `(B, T, num_cams * D)`.
- Teach `VLAAgent` to detect multi-token visual features, broadcast state per camera token, apply the existing condition projector on each token, then flatten `(T, num_cams)` into one cond sequence for the IMF head.
- Keep `per_step_cond_dim` as the width of a single condition token, and add explicit token-count metadata so transformer heads get the correct cond-sequence length.
- For the new experiments, set the condition-token width equal to `n_emb` via `cond_projector.output_dim=${agent.head.n_emb}`.

## Files expected to change

- `roboimi/vla/models/backbones/resnet_diffusion.py`
- `roboimi/vla/agent.py`
- a new Hydra agent config for the multitoken ResNet IMF variant
- focused tests in `tests/test_imf_vla_agent.py` and/or `tests/test_resnet_transformer_agent_wiring.py`
@@ -0,0 +1,41 @@

# SigLIP2 Multiview VLA Design

**Status:** user-specified architecture, treated as approved on 2026-04-06

## Goal

Replace the current vision encoder for the IMF/AttnRes diffusion policy with a frozen SigLIP2 image encoder while preserving the downstream action-diffusion stack and rollout behavior.

## Approved architecture

- Backbone model: `google/siglip2-base-patch16-256`
- Camera inputs: three views, encoded **independently** with a **shared** SigLIP2 vision encoder
- Input size:
  - dataset images stay at native `256x256` (no dataset-side resize)
  - eval/rollout images resize to `256x256` before SigLIP2 because env renders are larger
- Per-view feature: use the global pooled image feature (`pooler_output`, 768-d)
- Per-view projection experiments:
  1. `768 -> 96`
  2. `768 -> 192`
- Conditioning pipeline:
  1. concatenate the 3 projected camera vectors
  2. concatenate the robot state
  3. project the concatenated condition to `384`
  4. feed that `384`-d per-step condition into the existing IMF/AttnRes diffusion head
- Training/run defaults for the requested experiments:
  - `n_emb=384`
  - `n_layer=12`
  - `pred_horizon=16`
  - `num_action_steps=8`
  - rollout count for validation: keep the current behavior on this branch unless explicitly overridden later

## Design decisions

- The condition projector lives in `VLAAgent._build_cond()` so the backbone owns only visual features, while the agent owns the final conditioning contract expected by the diffusion head.
- The SigLIP2 backbone is frozen by default; only the per-view projectors and downstream policy layers train.
- The backbone exposes `dataset_image_resize_shape=None` and `eval_image_resize_shape=(256, 256)` so existing train/eval plumbing can reuse the raw-256 path already added in this branch.
- One shared vision encoder is used across cameras to keep memory and download size reasonable and to match the user's request for per-view independent encoding rather than a fused multiview image.

## Files expected to change

- `roboimi/vla/models/backbones/` for the new SigLIP2 backbone
- `roboimi/vla/agent.py` for optional post-concat condition projection
- Hydra configs under `roboimi/vla/conf/{agent,backbone,modules}`
- tests for backbone wiring and agent conditioning dims
- remote launch commands/scripts only as needed for training
69
experiment_suites/2026-04-05-camera-ablation-summary.md
Normal file
@@ -0,0 +1,69 @@

# Camera Ablation Summary (`pred_horizon=16`, `num_action_steps=8`, ResNet IMF)

- Generated: 2026-04-05
- Common setup: original ResNet vision backbone, `n_emb=384`, `n_layer=12`, `batch_size=80`, `lr=2.5e-4`, `max_steps=50k`, rollout every 5 epochs with 5 episodes, headless eval.
- Metric for comparison: `checkpoints/vla_model_best.pt -> rollout_avg_reward`.

## Leaderboard

| Rank | Cameras | Best avg_reward | Best step | Final loss | Run name |
|---:|---|---:|---:|---:|---|
| 1 | `top + front` | **274.8** | 48124 | 0.0056 | `imf-resnet-topfront-2cam-ph16-ex08-emb384-l12-ms50k-5090-20260405-085023` |
| 2 | `top` | **271.2** | 43749 | 0.0052 | `imf-resnet-top-1cam-ph16-ex08-emb384-l12-ms50k-l20g4-20260405-125844` |
| 3 | `r_vis + front` | **244.0** | 21874 | 0.0043 | `imf-resnet-frontrvis-2cam-ph16-ex08-emb384-l12-ms50k-l20g1-20260405-102029` |
| 4 | `r_vis` | **6.4** | 17499 | 0.0047 | `imf-resnet-rvis-1cam-ph16-ex08-emb384-l12-ms50k-l20g3-20260405-125844` |
| 5 | `r_vis + top` | **1.2** | 4374 | 0.0047 | `imf-resnet-rvistop-2cam-ph16-ex08-emb384-l12-ms50k-l20g2-20260405-125844` |
| 6 | `front` | **0.0** | 4374 | 0.0074 | `imf-resnet-front-1cam-ph16-ex08-emb384-l12-ms50k-l20g0-20260405-095607` |

## Main takeaways

1. **`top` is the single most important camera view**: `top only = 271.2`, nearly matching `top + front = 274.8`.
2. **`front` alone is almost useless**: `front only = 0.0`.
3. **`r_vis` alone is also largely ineffective**: `r_vis only = 6.4`.
4. **`r_vis + front` clearly beats `front` or `r_vis` alone**, so the two views are somewhat complementary, yet the pair remains well below any healthy configuration that includes `top`.
5. **`r_vis + top` performed anomalously badly**: only `1.2`, far below `top only = 271.2`. Simply adding `r_vis` is no guarantee of a gain and can even break learning under the current setup.
6. **Training loss and rollout reward clearly disagree**: e.g., `r_vis + top` and `r_vis only` both reach reasonable final losses yet poor rewards, so model selection in this suite must use rollout reward, not loss.

## Horizontal comparison views

### Single-camera comparison

- `top`: **271.2**
- `r_vis`: **6.4**
- `front`: **0.0**

Conclusion: **`top >>> r_vis > front`**.

### Two-camera comparison

- `top + front`: **274.8**
- `r_vis + front`: **244.0**
- `r_vis + top`: **1.2**

Conclusions:

- **The safest two-camera combination is `top + front`**.
- `r_vis + front` works, but not as well as `top + front`.
- `r_vis + top` is nearly useless under the current setup.

### Incremental effect of adding a second view

- Adding `front` to `top`: `271.2 -> 274.8`, **a very small gain**.
- Adding `r_vis` to `front`: `0.0 -> 244.0`, **a very large gain**.
- Adding `r_vis` to `top`: `271.2 -> 1.2`, **a severe regression**.

## Practical recommendation

Choosing among these six experiments only:

- **First choice**: `top + front`
- **Second choice**: `top only`
- If `top` must be excluded: `r_vis + front` clearly beats `front only` / `r_vis only`
- **Not recommended**: `r_vis + top`

## Note relative to previous 3-camera baseline

The previous 3-camera `[r_vis, top, front]` best reward was **610.8**.
So the best result in this 6-run camera ablation (`top + front = 274.8`) shows:

- in this training batch, **removing any single view falls far short of the earlier 3-camera optimum**;
- but under the constraint of dropping a view, **`top` remains the most essential one to keep**.
@@ -0,0 +1,8 @@

# CHECKLIST

- [x] Confirm remote free GPU
- [x] Create front-only run contract
- [x] Remote smoke test passes
- [x] Launch 50k run on remote GPU0
- [x] Record pid / log / SwanLab
- [x] Report status back to user
28
experiment_suites/2026-04-05-front-only-resnet-1cam/PLAN.md
Normal file
@@ -0,0 +1,28 @@

# PLAN

## Goal

Train a 50k-step IMF baseline with the original ResNet vision backbone, using only the `front` camera as image conditioning.

## Fixed comparison contract

- Same as the active `top/front` run except image input is reduced to `[front]`
- Agent: `resnet_imf_attnres`
- Vision backbone mode: `resnet`
- `pred_horizon=16`, `num_action_steps=8`
- `n_emb=384`, `n_layer=12`, `n_head=1`, `n_kv_head=1`
- `inference_steps=1`
- `batch_size=80`, `lr=2.5e-4`, cosine schedule, warmup=2000
- dataset: `/home/droid/sim_dataset/sim_transfer`
- cameras: `[front]` only
- rollout every 5 epochs with 5 episodes, headless

## Resource plan

- Host: `100.119.99.14`
- GPU: `0`

## Important dimension override

- Single-camera visual cond dim = `64 + 16 = 80`, so override `agent.head.cond_dim=80` and `agent.num_cams=1`.

## Execution path

1. 2-step smoke test on remote GPU0.
2. If the smoke passes, launch the 50k main run with SwanLab.
3. Record pid / run_dir / log / URL locally.
@@ -0,0 +1,6 @@

# Notes

- 2026-04-05 09:55:27: remote 2-step smoke passed on `100.119.99.14` GPU0 with `front` only, batch=80, no OOM.
- 2026-04-05 09:56:26: launched main run `imf-resnet-front-1cam-ph16-ex08-emb384-l12-ms50k-l20g0-20260405-095607`.
- 2026-04-05 09:57:36: confirmed training is stable through step 200, latest loss 0.2830.
- SwanLab: https://swanlab.cn/@game-loader/roboimi-vla/runs/7kdii8oc6tjkcyu5y0lwq
@@ -0,0 +1,51 @@
{
  "suite_name": "2026-04-05-front-only-resnet-1cam",
  "updated_at": "2026-04-05 09:57:36",
  "phase": "running",
  "baseline_reference": {
    "source_run": "imf-resnet-topfront-2cam-ph16-ex08-emb384-l12-ms50k-5090-20260405-085023",
    "notes": "Same hyperparameters as the active top/front run, but image input is reduced to [front] only."
  },
  "smoke_test": {
    "status": "passed",
    "host": "100.119.99.14",
    "gpu": 0,
    "run_dir": "/home/droid/roboimi_suite_20260404/runs/smoke-frontonly-resnet-ph16-ex08-20260405-095509",
    "batch_size": 80,
    "max_steps": 2,
    "note": "2-step remote CUDA smoke passed on L20 GPU0 without OOM."
  },
  "main_run": {
    "status": "running",
    "host": "100.119.99.14",
    "gpu": 0,
    "launch_pid": 158874,
    "pid": 158877,
    "run_name": "imf-resnet-front-1cam-ph16-ex08-emb384-l12-ms50k-l20g0-20260405-095607",
    "run_dir": "/home/droid/roboimi_suite_20260404/runs/imf-resnet-front-1cam-ph16-ex08-emb384-l12-ms50k-l20g0-20260405-095607",
    "log_path": "/home/droid/roboimi_suite_20260404/runs/imf-resnet-front-1cam-ph16-ex08-emb384-l12-ms50k-l20g0-20260405-095607/train_vla.log",
    "launch_log": "/home/droid/roboimi_suite_20260404/experiment_suite_launch_logs/imf-resnet-front-1cam-ph16-ex08-emb384-l12-ms50k-l20g0-20260405-095607.launch.log",
    "dataset_dir": "/home/droid/sim_dataset/sim_transfer",
    "camera_names": [
      "front"
    ],
    "pred_horizon": 16,
    "num_action_steps": 8,
    "head_cond_dim": 80,
    "head_n_emb": 384,
    "head_n_layer": 12,
    "vision_backbone_mode": "resnet",
    "pretrained_backbone_weights": null,
    "freeze_backbone": false,
    "batch_size": 80,
    "lr": 0.00025,
    "num_workers": 12,
    "max_steps": 50000,
    "rollout_val_freq_epochs": 5,
    "rollout_num_episodes": 5,
    "swanlab_url": "https://swanlab.cn/@game-loader/roboimi-vla/runs/7kdii8oc6tjkcyu5y0lwq",
    "latest_step": 200,
    "latest_loss": 0.283,
    "process_running": true
  }
}
@@ -0,0 +1,8 @@

# CHECKLIST

- [x] Confirm camera mapping (`right` -> `r_vis`)
- [x] Create front+r_vis run contract
- [x] Remote smoke test passes
- [x] Launch 50k run on remote GPU1
- [x] Record pid / log / SwanLab
- [x] Report status back to user
||||
23
experiment_suites/2026-04-05-front-rvis-resnet-2cam/PLAN.md
Normal file
@@ -0,0 +1,23 @@

# PLAN

## Goal
Train a 50k-step IMF baseline with the original ResNet vision backbone, using `front` + `r_vis` cameras only.

## Fixed comparison contract

- Same hyperparameters as the active top/front and front-only runs
- Agent: `resnet_imf_attnres`
- Vision backbone mode: `resnet`
- `pred_horizon=16`, `num_action_steps=8`
- `n_emb=384`, `n_layer=12`, `n_head=1`, `n_kv_head=1`
- `inference_steps=1`
- `batch_size=80`, `lr=2.5e-4`, cosine warmup 2000
- dataset: `/home/droid/sim_dataset/sim_transfer`
- cameras: `[r_vis, front]`
- rollout every 5 epochs with 5 episodes, headless

## Important dimension override

- Two-camera visual cond dim = `64*2 + 16 = 144`, so set `agent.num_cams=2`, `agent.head.cond_dim=144`.
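The same rule generates the override pair for any camera count. A hypothetical helper (the 64/16 dims come from this contract, not from the codebase):

```python
def camera_overrides(num_cams: int, visual_dim: int = 64, state_dim: int = 16) -> list:
    """Derive the Hydra overrides this plan sets by hand (illustrative helper)."""
    cond_dim = visual_dim * num_cams + state_dim
    return [f"agent.num_cams={num_cams}", f"agent.head.cond_dim={cond_dim}"]

print(camera_overrides(2))  # ['agent.num_cams=2', 'agent.head.cond_dim=144']
```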

## Resource plan

- Host: `100.119.99.14`
- GPU: `1`
@@ -0,0 +1,6 @@

# Notes

- 2026-04-05 10:20:09: remote 2-step smoke passed on `100.119.99.14` GPU1 with `r_vis + front`, batch=80, no OOM.
- 2026-04-05 10:20:49: launched main run `imf-resnet-frontrvis-2cam-ph16-ex08-emb384-l12-ms50k-l20g1-20260405-102029`.
- 2026-04-05 10:22:03: confirmed training is stable through step 200, latest loss 0.3321.
- SwanLab: https://swanlab.cn/@game-loader/roboimi-vla/runs/3fyzjfdcbiq7frtbqv6ss

@@ -0,0 +1,55 @@

{
  "suite_name": "2026-04-05-front-rvis-resnet-2cam",
  "updated_at": "2026-04-05 10:22:03",
  "phase": "running",
  "interpretation": {
    "right_camera_name": "r_vis"
  },
  "baseline_reference": {
    "source_run": "imf-resnet-topfront-2cam-ph16-ex08-emb384-l12-ms50k-5090-20260405-085023",
    "notes": "Same hyperparameters as the active top/front run, replacing top with r_vis."
  },
  "smoke_test": {
    "status": "passed",
    "host": "100.119.99.14",
    "gpu": 1,
    "run_dir": "/home/droid/roboimi_suite_20260404/runs/smoke-frontrvis-resnet-ph16-ex08-20260405-102001",
    "batch_size": 80,
    "max_steps": 2,
    "note": "2-step remote CUDA smoke passed on L20 GPU1 without OOM."
  },
  "main_run": {
    "status": "running",
    "host": "100.119.99.14",
    "gpu": 1,
    "launch_pid": 159910,
    "pid": 159913,
    "run_name": "imf-resnet-frontrvis-2cam-ph16-ex08-emb384-l12-ms50k-l20g1-20260405-102029",
    "run_dir": "/home/droid/roboimi_suite_20260404/runs/imf-resnet-frontrvis-2cam-ph16-ex08-emb384-l12-ms50k-l20g1-20260405-102029",
    "log_path": "/home/droid/roboimi_suite_20260404/runs/imf-resnet-frontrvis-2cam-ph16-ex08-emb384-l12-ms50k-l20g1-20260405-102029/train_vla.log",
    "launch_log": "/home/droid/roboimi_suite_20260404/experiment_suite_launch_logs/imf-resnet-frontrvis-2cam-ph16-ex08-emb384-l12-ms50k-l20g1-20260405-102029.launch.log",
    "dataset_dir": "/home/droid/sim_dataset/sim_transfer",
    "camera_names": ["r_vis", "front"],
    "pred_horizon": 16,
    "num_action_steps": 8,
    "head_cond_dim": 144,
    "head_n_emb": 384,
    "head_n_layer": 12,
    "vision_backbone_mode": "resnet",
    "pretrained_backbone_weights": null,
    "freeze_backbone": false,
    "batch_size": 80,
    "lr": 0.00025,
    "num_workers": 12,
    "max_steps": 50000,
    "rollout_val_freq_epochs": 5,
    "rollout_num_episodes": 5,
    "swanlab_url": "https://swanlab.cn/@game-loader/roboimi-vla/runs/3fyzjfdcbiq7frtbqv6ss",
    "latest_step": 200,
    "latest_loss": 0.3321,
    "process_running": true
  }
}

73
experiment_suites/2026-04-05-lewm-vit-transfer/manifest.json
Normal file
@@ -0,0 +1,73 @@

{
  "date": "2026-04-06",
  "branch": "feat-imf-attnres-policy",
  "worktree": "/home/droid/project/roboimi/.worktrees/feat-imf-attnres-policy",
  "model": "LEWM ViT frozen visual encoder + IMF AttnRes diffusion head",
  "checkpoint_path": "/home/droid/le-wm/lewm-sim-transfer/pa1w85md8jop6bvol8oxp/checkpoints/epoch=99-step=47800.ckpt",
  "visual_contract": {
    "input_camera_names": ["r_vis", "top", "front"],
    "fused_camera_names": ["front", "top", "r_vis"],
    "joint_output_dim": 192,
    "freeze_backbone": true,
    "dataset_image_resize_shape": null,
    "eval_image_resize_shape": [256, 256],
    "fused_short_side_resize": 224
  },
  "training_contract": {
    "pred_horizon": 16,
    "num_action_steps": 8,
    "max_steps": 50000,
    "rollout_val_freq_epochs": 5,
    "rollout_num_episodes": 10,
    "batch_size": 80,
    "lr": 0.00025,
    "num_workers": 12,
    "scheduler_type": "cosine",
    "warmup_steps": 2000,
    "min_lr": 1e-06,
    "weight_decay": 1e-05,
    "grad_clip": 1.0
  },
  "verification": {
    "local_tests": "38 passed",
    "remote_dataset_shape": [2, 3, 256, 256],
    "remote_eval_prepared_shape": [3, 256, 256],
    "remote_smoke_run": {
      "run_name": "smoke-lewm-imf-rawpath-emb384-20260406-002002",
      "result": "passed",
      "details": "2-step train + checkpoint-triggered 1-episode headless rollout succeeded with corrected raw256 path"
    }
  },
  "superseded_runs": [
    {
      "run_name": "lewm-vit-imf-sim-transfer-emb384-l12-ph16-ex08-step50k-roll10-5880g0-20260405-201914",
      "reason": "stopped due to incorrect early per-camera 224 resize"
    },
    {
      "run_name": "lewm-vit-imf-sim-transfer-emb256-l12-ph16-ex08-step50k-roll10-5880g1-20260405-201914",
      "reason": "stopped due to incorrect early per-camera 224 resize"
    }
  ],
  "full_runs": [
    {
      "host": "100.73.14.65",
      "gpu": 0,
      "run_name": "lewm-vit-imf-raw256fix-sim-transfer-emb384-l12-ph16-ex08-step50k-roll10-5880g0-20260406-002124",
      "pid": 1058589,
      "log_path": "/home/droid/roboimi_suite_20260404/experiment_suite_launch_logs/lewm-vit-imf-raw256fix-sim-transfer-emb384-l12-ph16-ex08-step50k-roll10-5880g0-20260406-002124.launch.log",
      "swanlab_url": "https://swanlab.cn/@game-loader/roboimi-vla/runs/y5tzgqe0u966w9ak41i31",
      "head_n_emb": 384,
      "head_n_layer": 12
    },
    {
      "host": "100.73.14.65",
      "gpu": 1,
      "run_name": "lewm-vit-imf-raw256fix-sim-transfer-emb256-l12-ph16-ex08-step50k-roll10-5880g1-20260406-002124",
      "pid": 1058590,
      "log_path": "/home/droid/roboimi_suite_20260404/experiment_suite_launch_logs/lewm-vit-imf-raw256fix-sim-transfer-emb256-l12-ph16-ex08-step50k-roll10-5880g1-20260406-002124.launch.log",
      "swanlab_url": "https://swanlab.cn/@game-loader/roboimi-vla/runs/2esr9y7t2dgesstgrn5i6",
      "head_n_emb": 256,
      "head_n_layer": 12
    }
  ]
}

25
experiment_suites/2026-04-05-lewm-vit-transfer/notes.md
Normal file
@@ -0,0 +1,25 @@

# 2026-04-06 LEWM ViT Transfer Notes

## Root-cause fix

The first LEWM runs were stopped because the data path still resized each camera view to `224x224` **before** multiview fusion. That preserved the final tensor shape but broke the original LEWM geometry.

The corrected path is now:

- **Training dataset**: keep stored per-view `256x256` images (`data.image_resize_shape=null` at launch; the dataset instantiate override is `None` for LEWM)
- **Eval rollout input**: resize live MuJoCo `480x640` camera images to `256x256` per view
- **Backbone**: fuse `front, top, r_vis` on the LEWM axis, then resize the fused short side to `224`
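The corrected order of operations can be sketched roughly as follows. The fusion axis (width) and the nearest-neighbor resize here are illustrative stand-ins; the real backbone reproduces LEWM's own preprocessing:

```python
import numpy as np

def fuse_and_resize(views, short_side=224):
    """Concatenate per-view 256x256 frames, then scale so the short side is 224."""
    fused = np.concatenate(views, axis=1)  # (256, 256*n_views, 3); the axis is an assumption
    h, w = fused.shape[:2]
    scale = short_side / min(h, w)
    new_h, new_w = round(h * scale), round(w * scale)
    rows = (np.arange(new_h) * h / new_h).astype(int)  # nearest-neighbor index sampling
    cols = (np.arange(new_w) * w / new_w).astype(int)
    return fused[rows][:, cols]

views = [np.zeros((256, 256, 3), dtype=np.uint8) for _ in ("front", "top", "r_vis")]
print(fuse_and_resize(views).shape)  # (224, 672, 3)
```

The key point the fix enforces: the 224 resize happens once, on the *fused* image, never per camera.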

## Verification

- Local tests passed (`38 passed` across the focused suite)
- Remote check:
  - dataset sample image shape: `(2, 3, 256, 256)`
  - eval-prepared live frame shape: `(3, 256, 256)`
- Remote smoke passed with real checkpoint:
  - `smoke-lewm-imf-rawpath-emb384-20260406-002002`

## Current runs

- `lewm-vit-imf-raw256fix-sim-transfer-emb384-l12-ph16-ex08-step50k-roll10-5880g0-20260406-002124`
- `lewm-vit-imf-raw256fix-sim-transfer-emb256-l12-ph16-ex08-step50k-roll10-5880g1-20260406-002124`

19
experiment_suites/2026-04-05-lewm-vit-transfer/status.json
Normal file
@@ -0,0 +1,19 @@

{
  "status": "running",
  "updated_at": "2026-04-06T00:22:10+08:00",
  "remote_host": "100.73.14.65",
  "runs": [
    {
      "run_name": "lewm-vit-imf-raw256fix-sim-transfer-emb384-l12-ph16-ex08-step50k-roll10-5880g0-20260406-002124",
      "pid": 1058589,
      "gpu": 0,
      "state": "running"
    },
    {
      "run_name": "lewm-vit-imf-raw256fix-sim-transfer-emb256-l12-ph16-ex08-step50k-roll10-5880g1-20260406-002124",
      "pid": 1058590,
      "gpu": 1,
      "state": "running"
    }
  ]
}

@@ -0,0 +1,7 @@

# CHECKLIST

- [x] Create run contract
- [x] Remote smoke test passes
- [x] Launch 50k main run
- [x] Record pid / log / SwanLab
- [x] Report status back to user

12
experiment_suites/2026-04-05-rvis-only-resnet-1cam/PLAN.md
Normal file
@@ -0,0 +1,12 @@

# PLAN

## Goal
Train a 50k-step IMF baseline with the original ResNet vision backbone, using `r_vis` as the only image conditioning.

## Fixed comparison contract

- same hyperparameters as the active top/front run
- cameras: `['r_vis']`
- `num_cams=1`
- `head.cond_dim=80`
- host: `100.119.99.14`
- gpu: `3`

@@ -0,0 +1,6 @@

# Notes

- 2026-04-05 12:58:22: smoke passed for `['r_vis']` on `100.119.99.14` GPU3.
- 2026-04-05 12:59:24: launched main run `imf-resnet-rvis-1cam-ph16-ex08-emb384-l12-ms50k-l20g3-20260405-125844`.
- 2026-04-05 13:01:20: latest confirmed progress step=400, loss=0.1165.
- SwanLab: https://swanlab.cn/@game-loader/roboimi-vla/runs/qnuh7vln9mqomxxldyecq

@@ -0,0 +1,47 @@

{
  "suite_name": "2026-04-05-rvis-only-resnet-1cam",
  "updated_at": "2026-04-05 13:01:20",
  "phase": "running",
  "smoke_test": {
    "status": "passed",
    "host": "100.119.99.14",
    "gpu": 3,
    "run_dir": "/home/droid/roboimi_suite_20260404/runs/smoke-rvisonly-resnet-ph16-ex08-20260405-125812",
    "batch_size": 80,
    "max_steps": 2,
    "note": "2-step remote CUDA smoke passed without OOM."
  },
  "main_run": {
    "status": "running",
    "host": "100.119.99.14",
    "gpu": 3,
    "launch_pid": 164812,
    "pid": 164816,
    "run_name": "imf-resnet-rvis-1cam-ph16-ex08-emb384-l12-ms50k-l20g3-20260405-125844",
    "run_dir": "/home/droid/roboimi_suite_20260404/runs/imf-resnet-rvis-1cam-ph16-ex08-emb384-l12-ms50k-l20g3-20260405-125844",
    "log_path": "/home/droid/roboimi_suite_20260404/runs/imf-resnet-rvis-1cam-ph16-ex08-emb384-l12-ms50k-l20g3-20260405-125844/train_vla.log",
    "launch_log": "/home/droid/roboimi_suite_20260404/experiment_suite_launch_logs/imf-resnet-rvis-1cam-ph16-ex08-emb384-l12-ms50k-l20g3-20260405-125844.launch.log",
    "dataset_dir": "/home/droid/sim_dataset/sim_transfer",
    "camera_names": ["r_vis"],
    "pred_horizon": 16,
    "num_action_steps": 8,
    "head_cond_dim": 80,
    "head_n_emb": 384,
    "head_n_layer": 12,
    "vision_backbone_mode": "resnet",
    "pretrained_backbone_weights": null,
    "freeze_backbone": false,
    "batch_size": 80,
    "lr": 0.00025,
    "num_workers": 12,
    "max_steps": 50000,
    "rollout_val_freq_epochs": 5,
    "rollout_num_episodes": 5,
    "swanlab_url": "https://swanlab.cn/@game-loader/roboimi-vla/runs/qnuh7vln9mqomxxldyecq",
    "latest_step": 400,
    "latest_loss": 0.1165,
    "process_running": true
  }
}

@@ -0,0 +1,7 @@

# CHECKLIST

- [x] Create run contract
- [x] Remote smoke test passes
- [x] Launch 50k main run
- [x] Record pid / log / SwanLab
- [x] Report status back to user

12
experiment_suites/2026-04-05-rvistop-resnet-2cam/PLAN.md
Normal file
@@ -0,0 +1,12 @@

# PLAN

## Goal
Train a 50k-step IMF baseline with the original ResNet vision backbone, using `r_vis` + `top` as the only image conditioning.

## Fixed comparison contract

- same hyperparameters as the active top/front run
- cameras: `['r_vis', 'top']`
- `num_cams=2`
- `head.cond_dim=144`
- host: `100.119.99.14`
- gpu: `2`

@@ -0,0 +1,6 @@

# Notes

- 2026-04-05 12:58:22: smoke passed for `['r_vis', 'top']` on `100.119.99.14` GPU2.
- 2026-04-05 12:59:24: launched main run `imf-resnet-rvistop-2cam-ph16-ex08-emb384-l12-ms50k-l20g2-20260405-125844`.
- 2026-04-05 13:01:20: latest confirmed progress step=200, loss=0.2845.
- SwanLab: https://swanlab.cn/@game-loader/roboimi-vla/runs/umsm6402eb81et7wx7z4a

48
experiment_suites/2026-04-05-rvistop-resnet-2cam/status.json
Normal file
@@ -0,0 +1,48 @@

{
  "suite_name": "2026-04-05-rvistop-resnet-2cam",
  "updated_at": "2026-04-05 13:01:20",
  "phase": "running",
  "smoke_test": {
    "status": "passed",
    "host": "100.119.99.14",
    "gpu": 2,
    "run_dir": "/home/droid/roboimi_suite_20260404/runs/smoke-rvistop-resnet-ph16-ex08-20260405-125812",
    "batch_size": 80,
    "max_steps": 2,
    "note": "2-step remote CUDA smoke passed without OOM."
  },
  "main_run": {
    "status": "running",
    "host": "100.119.99.14",
    "gpu": 2,
    "launch_pid": 164745,
    "pid": 164749,
    "run_name": "imf-resnet-rvistop-2cam-ph16-ex08-emb384-l12-ms50k-l20g2-20260405-125844",
    "run_dir": "/home/droid/roboimi_suite_20260404/runs/imf-resnet-rvistop-2cam-ph16-ex08-emb384-l12-ms50k-l20g2-20260405-125844",
    "log_path": "/home/droid/roboimi_suite_20260404/runs/imf-resnet-rvistop-2cam-ph16-ex08-emb384-l12-ms50k-l20g2-20260405-125844/train_vla.log",
    "launch_log": "/home/droid/roboimi_suite_20260404/experiment_suite_launch_logs/imf-resnet-rvistop-2cam-ph16-ex08-emb384-l12-ms50k-l20g2-20260405-125844.launch.log",
    "dataset_dir": "/home/droid/sim_dataset/sim_transfer",
    "camera_names": ["r_vis", "top"],
    "pred_horizon": 16,
    "num_action_steps": 8,
    "head_cond_dim": 144,
    "head_n_emb": 384,
    "head_n_layer": 12,
    "vision_backbone_mode": "resnet",
    "pretrained_backbone_weights": null,
    "freeze_backbone": false,
    "batch_size": 80,
    "lr": 0.00025,
    "num_workers": 12,
    "max_steps": 50000,
    "rollout_val_freq_epochs": 5,
    "rollout_num_episodes": 5,
    "swanlab_url": "https://swanlab.cn/@game-loader/roboimi-vla/runs/umsm6402eb81et7wx7z4a",
    "latest_step": 200,
    "latest_loss": 0.2845,
    "process_running": true
  }
}

@@ -0,0 +1,8 @@

# CHECKLIST

- [x] Confirm baseline hyperparameters from trusted prior run
- [x] Confirm local GPU availability
- [x] Smoke test with `top/front` cameras only
- [x] Launch 50k run
- [x] Record pid / run dir / log path / SwanLab URL
- [x] Report status back to user

30
experiment_suites/2026-04-05-top-front-resnet-2cam/PLAN.md
Normal file
@@ -0,0 +1,30 @@

# PLAN

## Goal
Train a 50k-step IMF baseline with the original ResNet vision backbone (no full-AttnRes vision replacement), using only `top` and `front` cameras as image conditioning.

## Fixed comparison contract

- Agent: `resnet_imf_attnres`
- Vision backbone mode: `resnet`
- `pred_horizon=16`
- `num_action_steps=8`
- `n_emb=384`, `n_layer=12`, `n_head=1`, `n_kv_head=1`
- `inference_steps=1`
- `batch_size=80`, `lr=2.5e-4`, cosine scheduler, warmup 2000
- dataset: `/home/droid/project/diana_sim/sim_transfer`
- cameras: `[top, front]` only
- training budget: `max_steps=50000`
- rollout validation: every 5 epochs, 5 episodes, headless
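A sketch of the named schedule, assuming linear warmup followed by cosine decay (`min_lr=1e-6` is taken from the LEWM suite's training contract; the real scheduler's exact curve may differ):

```python
import math

def lr_at(step, base_lr=2.5e-4, warmup=2000, max_steps=50000, min_lr=1e-6):
    """Linear warmup to base_lr, then cosine decay to min_lr by max_steps (sketch)."""
    if step < warmup:
        return base_lr * step / warmup
    progress = (step - warmup) / (max_steps - warmup)
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * progress))

assert lr_at(0) == 0.0
assert abs(lr_at(2000) - 2.5e-4) < 1e-12  # end of warmup hits base_lr
assert abs(lr_at(50000) - 1e-6) < 1e-15   # cosine floor at max_steps
```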

## Resource plan

- Host: local
- GPU: RTX 5090 (GPU 0)

## Execution path

1. Run a short 2-step smoke test on GPU with the exact 2-camera config.
2. If smoke passes, launch the 50k main run with durable log redirection.
3. Record run name, pid, log path, and SwanLab URL into suite status.

## Fallbacks

- If batch 80 OOMs, fall back to batch 64 with scaled lr 2.0e-4.
- If dataloader startup is unstable, reduce `num_workers` from 12 to 8.

@@ -0,0 +1,5 @@

# Notes

- 2026-04-05 08:50:04: 2-step smoke test passed locally on RTX 5090 with `top/front` cameras, batch=80, no OOM.
- 2026-04-05 08:50:42: launched main run `imf-resnet-topfront-2cam-ph16-ex08-emb384-l12-ms50k-5090-20260405-085023` on local GPU0.
- SwanLab: https://swanlab.cn/@game-loader/roboimi-vla/runs/vi77mn5dwd19z4nttxab8

@@ -0,0 +1,51 @@

{
  "suite_name": "2026-04-05-top-front-resnet-2cam",
  "updated_at": "2026-04-05 08:52:12",
  "phase": "running",
  "baseline_reference": {
    "source_run": "imf-p1-ph16-ex08-emb384-l12-ms50k-5880g1-20260404-131223",
    "best_rollout_avg_reward": 610.8,
    "best_step": 21874,
    "notes": "Same IMF baseline as Phase-1 best, but switch cameras from [r_vis, top, front] to [top, front] and keep the original ResNet vision backbone."
  },
  "smoke_test": {
    "status": "passed",
    "run_dir": "/home/droid/project/roboimi/.worktrees/feat-imf-attnres-policy/runs/smoke-topfront-resnet-ph16-ex08-20260405-085000",
    "batch_size": 80,
    "num_workers": 4,
    "max_steps": 2,
    "note": "2-step local CUDA smoke passed without OOM using top/front only."
  },
  "main_run": {
    "status": "running",
    "host": "local",
    "gpu": 0,
    "pid": 1693348,
    "run_name": "imf-resnet-topfront-2cam-ph16-ex08-emb384-l12-ms50k-5090-20260405-085023",
    "run_dir": "/home/droid/project/roboimi/.worktrees/feat-imf-attnres-policy/runs/imf-resnet-topfront-2cam-ph16-ex08-emb384-l12-ms50k-5090-20260405-085023",
    "log_path": "/home/droid/project/roboimi/.worktrees/feat-imf-attnres-policy/runs/imf-resnet-topfront-2cam-ph16-ex08-emb384-l12-ms50k-5090-20260405-085023/train_vla.log",
    "launch_log": "/home/droid/project/roboimi/.worktrees/feat-imf-attnres-policy/experiment_suites/2026-04-05-top-front-resnet-2cam/launch_logs/imf-resnet-topfront-2cam-ph16-ex08-emb384-l12-ms50k-5090-20260405-085023.launch.log",
    "dataset_dir": "/home/droid/project/diana_sim/sim_transfer",
    "camera_names": ["top", "front"],
    "pred_horizon": 16,
    "num_action_steps": 8,
    "head_n_emb": 384,
    "head_n_layer": 12,
    "vision_backbone_mode": "resnet",
    "pretrained_backbone_weights": null,
    "freeze_backbone": false,
    "batch_size": 80,
    "lr": 0.00025,
    "num_workers": 12,
    "max_steps": 50000,
    "rollout_val_freq_epochs": 5,
    "rollout_num_episodes": 5,
    "swanlab_url": "https://swanlab.cn/@game-loader/roboimi-vla/runs/vi77mn5dwd19z4nttxab8",
    "latest_step": 500,
    "latest_loss": 0.0978,
    "process_running": true
  }
}

@@ -0,0 +1,7 @@

# CHECKLIST

- [x] Create run contract
- [x] Remote smoke test passes
- [x] Launch 50k main run
- [x] Record pid / log / SwanLab
- [x] Report status back to user

12
experiment_suites/2026-04-05-top-only-resnet-1cam/PLAN.md
Normal file
@@ -0,0 +1,12 @@

# PLAN

## Goal
Train a 50k-step IMF baseline with the original ResNet vision backbone, using `top` as the only image conditioning.

## Fixed comparison contract

- same hyperparameters as the active top/front run
- cameras: `['top']`
- `num_cams=1`
- `head.cond_dim=80`
- host: `100.119.99.14`
- gpu: `4`

@@ -0,0 +1,6 @@

# Notes

- 2026-04-05 12:58:22: smoke passed for `['top']` on `100.119.99.14` GPU4.
- 2026-04-05 12:59:24: launched main run `imf-resnet-top-1cam-ph16-ex08-emb384-l12-ms50k-l20g4-20260405-125844`.
- 2026-04-05 13:01:20: latest confirmed progress step=400, loss=0.1233.
- SwanLab: https://swanlab.cn/@game-loader/roboimi-vla/runs/egzo29l3z9ftsaunhf025

@@ -0,0 +1,47 @@

{
  "suite_name": "2026-04-05-top-only-resnet-1cam",
  "updated_at": "2026-04-05 13:01:20",
  "phase": "running",
  "smoke_test": {
    "status": "passed",
    "host": "100.119.99.14",
    "gpu": 4,
    "run_dir": "/home/droid/roboimi_suite_20260404/runs/smoke-toponly-resnet-ph16-ex08-20260405-125812",
    "batch_size": 80,
    "max_steps": 2,
    "note": "2-step remote CUDA smoke passed without OOM."
  },
  "main_run": {
    "status": "running",
    "host": "100.119.99.14",
    "gpu": 4,
    "launch_pid": 164808,
    "pid": 164813,
    "run_name": "imf-resnet-top-1cam-ph16-ex08-emb384-l12-ms50k-l20g4-20260405-125844",
    "run_dir": "/home/droid/roboimi_suite_20260404/runs/imf-resnet-top-1cam-ph16-ex08-emb384-l12-ms50k-l20g4-20260405-125844",
    "log_path": "/home/droid/roboimi_suite_20260404/runs/imf-resnet-top-1cam-ph16-ex08-emb384-l12-ms50k-l20g4-20260405-125844/train_vla.log",
    "launch_log": "/home/droid/roboimi_suite_20260404/experiment_suite_launch_logs/imf-resnet-top-1cam-ph16-ex08-emb384-l12-ms50k-l20g4-20260405-125844.launch.log",
    "dataset_dir": "/home/droid/sim_dataset/sim_transfer",
    "camera_names": ["top"],
    "pred_horizon": 16,
    "num_action_steps": 8,
    "head_cond_dim": 80,
    "head_n_emb": 384,
    "head_n_layer": 12,
    "vision_backbone_mode": "resnet",
    "pretrained_backbone_weights": null,
    "freeze_backbone": false,
    "batch_size": 80,
    "lr": 0.00025,
    "num_workers": 12,
    "max_steps": 50000,
    "rollout_val_freq_epochs": 5,
    "rollout_num_episodes": 5,
    "swanlab_url": "https://swanlab.cn/@game-loader/roboimi-vla/runs/egzo29l3z9ftsaunhf025",
    "latest_step": 400,
    "latest_loss": 0.1233,
    "process_running": true
  }
}

@@ -106,7 +106,11 @@ def load_checkpoint(
     return agent, stats


-def prepare_observation(obs: Dict, camera_names: list) -> Dict:
+def prepare_observation(
+    obs: Dict,
+    camera_names: list,
+    image_resize_shape: Optional[tuple[int, int]] = (224, 224),
+) -> Dict:
     """
     Convert an environment observation into the agent's format.

@@ -117,14 +121,13 @@ def prepare_observation(obs: Dict, camera_names: list) -> Dict:
     Returns:
         Observation dict in the agent's format.
     """
-    import cv2
-
     # Convert images: numpy -> tensor, HWC -> CHW
     images = {}
     for cam_name in camera_names:
         img = obs['images'][cam_name]
-        # Resize to 224x224 (match training)
-        img = cv2.resize(img, (224, 224), interpolation=cv2.INTER_LINEAR)
+        if image_resize_shape is not None:
+            import cv2
+            img = cv2.resize(img, tuple(image_resize_shape), interpolation=cv2.INTER_LINEAR)
         img = rearrange(img, 'h w c -> c h w')
         img = torch.from_numpy(img / 255.0).float()
         images[cam_name] = img

@@ -668,6 +671,8 @@ def _run_eval(cfg: DictConfig):
         agent_cfg=cfg.agent,
         device=device
     )
+    vision_encoder = getattr(agent, 'vision_encoder', None)
+    image_resize_shape = getattr(vision_encoder, 'eval_image_resize_shape', (224, 224))

     # Reset the agent's queues
     agent.reset()

@@ -725,7 +730,11 @@ def _run_eval(cfg: DictConfig):
             video_recorder.write(video_frame)

         # Prepare the observation for the agent
-        observation = prepare_observation(obs, camera_names)
+        observation = prepare_observation(
+            obs,
+            camera_names,
+            image_resize_shape=image_resize_shape,
+        )
         end_preprocess = time.perf_counter()

         # Select the action (the agent manages the queues internally)

@@ -380,7 +380,14 @@ def _run_training(cfg: DictConfig):
     # =========================================================================
     log.info("📦 Loading dataset...")
     try:
-        dataset = instantiate(cfg.data)
+        dataset_image_resize_shape = cfg.data.get('image_resize_shape', (224, 224))
+        vision_backbone_cfg = cfg.agent.get('vision_backbone', None)
+        if vision_backbone_cfg is not None and 'dataset_image_resize_shape' in vision_backbone_cfg:
+            dataset_image_resize_shape = vision_backbone_cfg.get('dataset_image_resize_shape')
+        dataset = instantiate(
+            cfg.data,
+            image_resize_shape=dataset_image_resize_shape,
+        )
         log.info(f"✅ Dataset loaded. Total samples: {len(dataset)}")
     except Exception as e:
         log.error(f"❌ Failed to load dataset: {e}")

@@ -27,6 +27,7 @@ class VLAAgent(nn.Module):
         normalization_type='min_max',  # Normalization type: 'gaussian' or 'min_max'
         num_action_steps=8,            # Number of action steps actually executed per inference
         head_type='unet',              # Policy head type: 'unet' or 'transformer'
+        cond_projector=None,           # Optional: project visual+state condition to the head's expected dim
     ):
         super().__init__()
         # Save parameters

@@ -74,15 +75,32 @@ class VLAAgent(nn.Module):
         self.vision_encoder = vision_backbone
         if self.camera_names is not None:
             self.vision_encoder.camera_names = self.camera_names
-        single_cam_feat_dim = self.vision_encoder.output_dim
-        # global_cond_dim: total flattened dimension (for the UNet)
-        total_vision_dim = single_cam_feat_dim * num_cams * obs_horizon
-        total_prop_dim = obs_dim * obs_horizon
-        self.global_cond_dim = total_vision_dim + total_prop_dim
+        self.condition_tokens_per_step = int(getattr(self.vision_encoder, 'tokens_per_step', 1))
+        joint_vision_dim = getattr(self.vision_encoder, 'joint_output_dim', None)
+        if joint_vision_dim is not None:
+            per_token_vision_dim = int(joint_vision_dim)
+            self.condition_tokens_per_step = 1
+        else:
+            single_cam_feat_dim = self.vision_encoder.output_dim
+            if self.condition_tokens_per_step > 1:
+                per_token_vision_dim = int(single_cam_feat_dim)
+            else:
+                per_token_vision_dim = int(single_cam_feat_dim) * int(num_cams)

         # per_step_cond_dim: condition dimension per step (for the Transformer)
         # Note: not multiplied by obs_horizon, since the Transformer input is a sequence
-        self.per_step_cond_dim = single_cam_feat_dim * num_cams + obs_dim
+        self.condition_sequence_length = self.obs_horizon * self.condition_tokens_per_step
+        self.raw_per_step_cond_dim = per_token_vision_dim + obs_dim
+        if cond_projector is None:
+            self.cond_projector = None
+            self.per_step_cond_dim = self.raw_per_step_cond_dim
+        else:
+            if isinstance(cond_projector, nn.Module):
+                self.cond_projector = cond_projector
+            else:
+                self.cond_projector = cond_projector(input_dim=self.raw_per_step_cond_dim)
+            self.per_step_cond_dim = self._projector_output_dim(self.cond_projector, self.raw_per_step_cond_dim)
+
+        # global_cond_dim: total flattened dimension (for the UNet)
+        self.global_cond_dim = self.per_step_cond_dim * self.condition_sequence_length

         self.noise_scheduler = DDPMScheduler(
             num_train_timesteps=diffusion_steps,

@@ -111,7 +129,7 @@ class VLAAgent(nn.Module):
                 input_dim=action_dim,
                 output_dim=action_dim,
                 horizon=pred_horizon,
-                n_obs_steps=obs_horizon,
+                n_obs_steps=self.condition_sequence_length,
                 cond_dim=self.per_step_cond_dim  # condition dimension per step
             )
         else:  # 'unet' (default)

@@ -143,6 +161,20 @@ class VLAAgent(nn.Module):
             return tuple(self._move_to_device(v, device) for v in data)
         return data

+    @staticmethod
+    def _projector_output_dim(projector: nn.Module, fallback: int) -> int:
+        output_dim = getattr(projector, 'output_dim', None)
+        if output_dim is not None:
+            return int(output_dim)
+        out_features = getattr(projector, 'out_features', None)
+        if out_features is not None:
+            return int(out_features)
+        linear = getattr(projector, 'linear', None)
+        linear_out_features = getattr(linear, 'out_features', None)
+        if linear_out_features is not None:
+            return int(linear_out_features)
+        return int(fallback)
+
     def _order_images(self, images: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
         """Return the image dict in the explicitly configured camera order."""
         if self.camera_names is None:

@@ -165,7 +197,43 @@ class VLAAgent(nn.Module):
         ordered_images = self._order_images(images)
         visual_features = self.vision_encoder(ordered_images)
         state_features = self.state_encoder(states)
+        if visual_features.ndim == 4:
+            batch_size, obs_steps, token_count, _ = visual_features.shape
+            if obs_steps != state_features.shape[1]:
+                raise RuntimeError(
+                    f"Observation time dims do not match: visual={obs_steps}, state={state_features.shape[1]}"
+                )
+            if token_count != self.condition_tokens_per_step:
+                raise RuntimeError(
+                    f"Condition token count mismatch: got {token_count}, expected {self.condition_tokens_per_step}"
+                )
+            state_features = state_features.unsqueeze(2).expand(-1, -1, token_count, -1)
+            cond = torch.cat([visual_features, state_features], dim=-1)
+            if cond.shape[-1] != self.raw_per_step_cond_dim:
+                raise RuntimeError(
+                    f"Raw condition dim mismatch: got {cond.shape[-1]}, expected {self.raw_per_step_cond_dim}"
+                )
+            if self.cond_projector is not None:
+                cond = self.cond_projector(cond)
+            if cond.shape[-1] != self.per_step_cond_dim:
+                raise RuntimeError(
+                    f"Condition dim mismatch: got {cond.shape[-1]}, expected {self.per_step_cond_dim}"
+                )
+            cond = cond.reshape(batch_size, obs_steps * token_count, self.per_step_cond_dim)
+            expected_length = self.condition_sequence_length
+            if cond.shape[1] != expected_length:
+                raise RuntimeError(
+                    f"Condition sequence length mismatch: got {cond.shape[1]}, expected {expected_length}"
+                )
+            return cond
+
         cond = torch.cat([visual_features, state_features], dim=-1)
+        if cond.shape[-1] != self.raw_per_step_cond_dim:
+            raise RuntimeError(
+                f"Raw condition dim mismatch: got {cond.shape[-1]}, expected {self.raw_per_step_cond_dim}"
+            )
+        if self.cond_projector is not None:
+            cond = self.cond_projector(cond)
+        if cond.shape[-1] != self.per_step_cond_dim:
+            raise RuntimeError(
+                f"Condition dim mismatch: got {cond.shape[-1]}, expected {self.per_step_cond_dim}"

roboimi/vla/conf/agent/lewm_imf_attnres.yaml (new file, 41 lines)
@@ -0,0 +1,41 @@
# @package agent
defaults:
  - /backbone@vision_backbone: lewm_vit_diffusion
  - /modules@state_encoder: identity_state_encoder
  - /modules@action_encoder: identity_action_encoder
  - /head: imf_transformer1d
  - _self_

_target_: roboimi.vla.agent_imf.IMFVLAAgent

action_dim: 16
obs_dim: 16
normalization_type: "min_max"
pred_horizon: 16
obs_horizon: 2
num_action_steps: 8
camera_names: ${data.camera_names}
num_cams: 3

vision_backbone:
  num_cameras: ${agent.num_cams}
  camera_names: ${agent.camera_names}
  fused_camera_names: [front, top, r_vis]

diffusion_steps: 100
inference_steps: 1
head_type: "transformer"

head:
  input_dim: ${agent.action_dim}
  output_dim: ${agent.action_dim}
  horizon: ${agent.pred_horizon}
  n_obs_steps: ${agent.obs_horizon}
  cond_dim: 208
  causal_attn: false
  time_as_cond: true
  obs_as_cond: true
  n_cond_layers: 0
  backbone_type: attnres_full
  n_head: 1
  n_kv_head: 1
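The hard-coded `cond_dim: 208` appears to follow from the joint visual width plus the state width; a quick sanity check, under the assumption that the identity state encoder passes `obs_dim` through unchanged:

```python
joint_visual_dim = 192  # LEWM backbone joint_output_dim
obs_dim = 16            # passed through by the identity state encoder
cond_dim = joint_visual_dim + obs_dim
print(cond_dim)  # 208, matching head.cond_dim in the config above
```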

roboimi/vla/conf/agent/resnet_imf_attnres_multitoken.yaml (new file, 48 lines)
@@ -0,0 +1,48 @@
# @package agent
defaults:
  - /backbone@vision_backbone: resnet_diffusion
  - /modules@state_encoder: identity_state_encoder
  - /modules@action_encoder: identity_action_encoder
  - /modules@cond_projector: linear_condition_projector
  - /head: imf_transformer1d
  - _self_

_target_: roboimi.vla.agent_imf.IMFVLAAgent

action_dim: 16
obs_dim: 16
normalization_type: "min_max"
pred_horizon: 16
obs_horizon: 2
num_action_steps: 8
camera_names: ${data.camera_names}
num_cams: ${len:${agent.camera_names}}

vision_backbone:
  num_cameras: ${agent.num_cams}
  camera_names: ${agent.camera_names}
  vision_backbone: "resnet18"
  vision_backbone_mode: "resnet"
  freeze_backbone: false
  use_separate_rgb_encoder_per_camera: true
  output_tokens_per_camera: true

cond_projector:
  output_dim: ${agent.head.n_emb}

diffusion_steps: 100
inference_steps: 1
head_type: "transformer"

head:
  input_dim: ${agent.action_dim}
  output_dim: ${agent.action_dim}
  horizon: ${agent.pred_horizon}
  cond_dim: ${agent.head.n_emb}
  causal_attn: false
  time_as_cond: true
  obs_as_cond: true
  n_cond_layers: 0
  backbone_type: attnres_full
  n_head: 1
  n_kv_head: 1

roboimi/vla/conf/agent/siglip2_imf_attnres.yaml (new file, 44 lines)
@@ -0,0 +1,44 @@
# @package agent
defaults:
  - /backbone@vision_backbone: siglip2_diffusion
  - /modules@state_encoder: identity_state_encoder
  - /modules@action_encoder: identity_action_encoder
  - /modules@cond_projector: linear_condition_projector
  - /head: imf_transformer1d
  - _self_

_target_: roboimi.vla.agent_imf.IMFVLAAgent

action_dim: 16
obs_dim: 16
normalization_type: "min_max"
pred_horizon: 16
obs_horizon: 2
num_action_steps: 8
camera_names: ${data.camera_names}
num_cams: ${len:${agent.camera_names}}

vision_backbone:
  num_cameras: ${agent.num_cams}
  camera_names: ${agent.camera_names}

cond_projector:
  output_dim: ${agent.head.cond_dim}

diffusion_steps: 100
inference_steps: 1
head_type: "transformer"

head:
  input_dim: ${agent.action_dim}
  output_dim: ${agent.action_dim}
  horizon: ${agent.pred_horizon}
  n_obs_steps: ${agent.obs_horizon}
  cond_dim: 384
  causal_attn: false
  time_as_cond: true
  obs_as_cond: true
  n_cond_layers: 0
  backbone_type: attnres_full
  n_head: 1
  n_kv_head: 1

roboimi/vla/conf/backbone/lewm_vit_diffusion.yaml (new file, 16 lines)
@@ -0,0 +1,16 @@
_target_: roboimi.vla.models.backbones.lewm_vit_backbone.LEWMViTBackbone

# LEWM checkpoint path; override this on the target machine.
checkpoint_path: null

# Input camera contract for roboimi; internal LEWM fusion order stays front/top/r_vis.
num_cameras: 3
camera_names: [r_vis, top, front]
fused_camera_names: [front, top, r_vis]

freeze_backbone: true
joint_output_dim: 192
output_dim: 192
image_size: 224
dataset_image_resize_shape: null
eval_image_resize_shape: [256, 256]
@@ -31,6 +31,8 @@ spatial_softmax_num_keypoints: 32  # number of Spatial Softmax keypoints
# false: shared encoder (all cameras share one ResNet; fewer parameters but limited capacity) - recommended!
# true: separate encoders (one ResNet per camera; more parameters but larger capacity)
use_separate_rgb_encoder_per_camera: true
# false: concatenate all camera features into one condition token; true: one independent token per camera
output_tokens_per_camera: false
num_cameras: 3  # number of cameras

# ====================

roboimi/vla/conf/backbone/siglip2_diffusion.yaml (new file, 10 lines)
@@ -0,0 +1,10 @@
_target_: roboimi.vla.models.backbones.siglip2_diffusion_backbone.SigLIP2DiffusionBackbone

model_name: google/siglip2-base-patch16-256
camera_names: [r_vis, top, front]
num_cameras: 3
per_view_output_dim: 96
freeze_backbone: true

dataset_image_resize_shape: null
eval_image_resize_shape: [256, 256]
@@ -19,3 +19,6 @@ camera_names:
  - r_vis   # robot-view camera
  - top     # top camera
  - front   # front camera

# Per-view pre-resize shape; null keeps the dataset's original resolution.
image_resize_shape: [224, 224]

roboimi/vla/conf/modules/linear_condition_projector.yaml (new file, 5 lines)
@@ -0,0 +1,5 @@
_target_: roboimi.vla.modules.projectors.LinearConditionProjector
_partial_: true

output_dim: 384
bias: true
@@ -1,7 +1,7 @@
import torch
import h5py
from torch.utils.data import Dataset
from typing import List, Dict, Union
from typing import List, Dict, Union, Optional, Sequence
from pathlib import Path
from collections import OrderedDict

@@ -22,6 +22,7 @@ class SimpleRobotDataset(Dataset):
        obs_horizon: int = 2,
        pred_horizon: int = 8,
        camera_names: List[str] = None,
        image_resize_shape: Optional[Sequence[int]] = (224, 224),
        max_open_files: int = 64,
    ):
        """
@@ -30,6 +31,7 @@ class SimpleRobotDataset(Dataset):
            obs_horizon: number of past frames to observe
            pred_horizon: number of future action frames to predict
            camera_names: list of camera names, e.g. ["r_vis", "top", "front"]
            image_resize_shape: target image size (W, H); None keeps the original resolution
            max_open_files: maximum number of HDF5 file handles cached per worker

        HDF5 file layout:
@@ -40,6 +42,10 @@ class SimpleRobotDataset(Dataset):
        self.obs_horizon = obs_horizon
        self.pred_horizon = pred_horizon
        self.camera_names = camera_names or []
        self.image_resize_shape = (
            tuple(int(v) for v in image_resize_shape)
            if image_resize_shape is not None else None
        )
        self.max_open_files = max(1, int(max_open_files))
        self._file_cache: "OrderedDict[str, h5py.File]" = OrderedDict()

@@ -123,9 +129,9 @@ class SimpleRobotDataset(Dataset):
                h5_path = f'observations/images/{cam_name}'
                if h5_path in f:
                    img = f[h5_path][meta["frame_idx"]]
                    # Resize images to 224x224 (reduces memory and I/O load)
                    import cv2
                    img = cv2.resize(img, (224, 224), interpolation=cv2.INTER_LINEAR)
                    if self.image_resize_shape is not None:
                        import cv2
                        img = cv2.resize(img, self.image_resize_shape, interpolation=cv2.INTER_LINEAR)
                    # Convert to float and normalize to [0, 1]
                    img = torch.from_numpy(img).float() / 255.0
                    frame[f"observation.{cam_name}"] = img.permute(2, 0, 1)  # HWC -> CHW
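The new `max_open_files` parameter bounds a per-worker HDF5 handle cache. The eviction logic can be sketched with an `OrderedDict` used as an LRU; `open_fn` stands in for `h5py.File` so this sketch stays stdlib-only (a real cache would also close the evicted handle):

```python
from collections import OrderedDict

class FileHandleCache:
    def __init__(self, max_open_files: int, open_fn):
        self.max_open_files = max(1, int(max_open_files))
        self.open_fn = open_fn
        self._cache = OrderedDict()

    def get(self, path: str):
        if path in self._cache:
            self._cache.move_to_end(path)  # mark as most recently used
            return self._cache[path]
        handle = self.open_fn(path)
        self._cache[path] = handle
        if len(self._cache) > self.max_open_files:
            self._cache.popitem(last=False)  # evict least recently used
        return handle

cache = FileHandleCache(2, open_fn=lambda p: f"handle:{p}")
cache.get("a.h5"); cache.get("b.h5"); cache.get("a.h5"); cache.get("c.h5")
print(list(cache._cache))  # ['a.h5', 'c.h5'] -- 'b.h5' was evicted
```

Keeping the cap per worker matters because each DataLoader worker holds its own cache, so the process-wide file-descriptor count is roughly `num_workers * max_open_files`.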
@@ -1,4 +1,15 @@
# Backbone models
from .resnet_diffusion import ResNetDiffusionBackbone
__all__ = ["ResNetBackbone", "ResNetDiffusionBackbone"]
__all__ = ["LEWMViTBackbone", "ResNetBackbone", "ResNetDiffusionBackbone", "SigLIP2DiffusionBackbone"]

def __getattr__(name):
    if name == "LEWMViTBackbone":
        from .lewm_vit_backbone import LEWMViTBackbone
        return LEWMViTBackbone
    if name == "SigLIP2DiffusionBackbone":
        from .siglip2_diffusion_backbone import SigLIP2DiffusionBackbone
        return SigLIP2DiffusionBackbone
    if name in {"ResNetBackbone", "ResNetDiffusionBackbone"}:
        from .resnet_diffusion import ResNetDiffusionBackbone
        return ResNetDiffusionBackbone
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
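The module-level `__getattr__` here uses PEP 562 lazy loading, so importing the backbones package does not pull in `transformers` until `LEWMViTBackbone` or `SigLIP2DiffusionBackbone` is actually touched. The mechanism in isolation, on a synthetic module (`lazy_pkg` and `Heavy` are illustrative names):

```python
import sys
import types

lazy_pkg = types.ModuleType("lazy_pkg")

def _module_getattr(name):
    if name == "Heavy":
        # Imagine an expensive import happening here, only on first access.
        value = "loaded"
        setattr(lazy_pkg, name, value)  # cache so later lookups skip __getattr__
        return value
    raise AttributeError(f"module 'lazy_pkg' has no attribute {name!r}")

# PEP 562: attribute misses on a module fall back to its __getattr__.
lazy_pkg.__getattr__ = _module_getattr
sys.modules["lazy_pkg"] = lazy_pkg

import lazy_pkg as pkg
print(pkg.Heavy)  # loaded
```

The repo's version additionally keeps `ResNetDiffusionBackbone` as an eager import, which is safe because it has no heavyweight dependencies.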

roboimi/vla/models/backbones/lewm_vit_backbone.py (new file, 230 lines)
@@ -0,0 +1,230 @@
from __future__ import annotations

from pathlib import Path
from typing import Any, Dict, Mapping, Sequence

import torch
import torch.nn as nn
import torch.nn.functional as F

from roboimi.vla.core.interfaces import VLABackbone


class _LEWMProjector(nn.Module):
    """LEWM projector MLP: 192 -> 2048 -> 192 with BatchNorm1d + GELU."""

    def __init__(self, input_dim: int = 192, hidden_dim: int = 2048, output_dim: int = 192) -> None:
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


class LEWMViTBackbone(VLABackbone):
    """Frozen LEWM joint-multiview ViT backbone.

    The backbone fuses the three camera views into a single LEWM-style image,
    runs a ViT-tiny encoder plus the LEWM projector, and returns one joint
    192-d embedding per timestep.
    """

    def __init__(
        self,
        checkpoint_path: str | Path | None = None,
        *,
        checkpoint: Mapping[str, Any] | None = None,
        camera_names: Sequence[str] = ("r_vis", "top", "front"),
        fused_camera_names: Sequence[str] = ("front", "top", "r_vis"),
        num_cameras: int | None = None,
        dataset_image_resize_shape: Sequence[int] | None = None,
        eval_image_resize_shape: Sequence[int] | None = (256, 256),
        freeze_backbone: bool = True,
        joint_output_dim: int = 192,
        image_size: int = 224,
        output_dim: int = 192,
    ) -> None:
        super().__init__()

        self.camera_names = tuple(camera_names)
        self.fused_camera_names = tuple(fused_camera_names)
        self.num_cameras = int(num_cameras) if num_cameras is not None else len(self.camera_names)
        self.freeze_backbone = bool(freeze_backbone)
        self.joint_output_dim = int(joint_output_dim)
        self.image_size = int(image_size)
        self._output_dim = int(output_dim)
        self.dataset_image_resize_shape = (
            tuple(int(v) for v in dataset_image_resize_shape)
            if dataset_image_resize_shape is not None else None
        )
        self.eval_image_resize_shape = (
            tuple(int(v) for v in eval_image_resize_shape)
            if eval_image_resize_shape is not None else None
        )
        if self.num_cameras != len(self.camera_names):
            raise ValueError(
                f"num_cameras({self.num_cameras}) must match len(camera_names)({len(self.camera_names)})"
            )
        if set(self.fused_camera_names) != set(self.camera_names):
            raise ValueError(
                "fused_camera_names must contain the same cameras as camera_names. "
                f"got camera_names={list(self.camera_names)}, fused_camera_names={list(self.fused_camera_names)}"
            )

        self.encoder = self._build_encoder(self.image_size)
        self.projector = _LEWMProjector(
            input_dim=self.encoder.config.hidden_size,
            hidden_dim=2048,
            output_dim=self.joint_output_dim,
        )

        self.register_buffer(
            "mean",
            torch.tensor([0.485, 0.456, 0.406], dtype=torch.float32).view(1, 3, 1, 1),
        )
        self.register_buffer(
            "std",
            torch.tensor([0.229, 0.224, 0.225], dtype=torch.float32).view(1, 3, 1, 1),
        )

        if checkpoint_path is not None and checkpoint is not None:
            raise ValueError("checkpoint_path and checkpoint cannot both be provided")
        if checkpoint_path is not None:
            self.load_lewm_checkpoint(checkpoint_path)
        elif checkpoint is not None:
            self.load_lewm_checkpoint(checkpoint)

        if self.freeze_backbone:
            self._freeze_encoder_and_projector()

    @staticmethod
    def _build_encoder_config(image_size: int):
        from transformers import ViTConfig

        return ViTConfig(
            image_size=image_size,
            patch_size=14,
            num_channels=3,
            hidden_size=192,
            intermediate_size=768,
            num_hidden_layers=12,
            num_attention_heads=3,
            qkv_bias=True,
            hidden_dropout_prob=0.0,
            attention_probs_dropout_prob=0.0,
        )

    @classmethod
    def _build_encoder(cls, image_size: int) -> nn.Module:
        from transformers import ViTModel

        return ViTModel(cls._build_encoder_config(image_size), add_pooling_layer=False)

    @staticmethod
    def _unwrap_state_dict(payload: Mapping[str, Any]) -> Mapping[str, torch.Tensor]:
        state_dict = payload.get("state_dict", payload)
        if not isinstance(state_dict, Mapping):
            raise TypeError("checkpoint payload must contain a mapping state_dict")
        return state_dict

    @staticmethod
    def _extract_prefixed_state_dict(
        state_dict: Mapping[str, torch.Tensor],
        prefix: str,
    ) -> Dict[str, torch.Tensor]:
        extracted = {
            key[len(prefix):]: value
            for key, value in state_dict.items()
            if key.startswith(prefix)
        }
        if not extracted:
            raise KeyError(f"checkpoint missing parameters with prefix {prefix!r}")
        return extracted

    def load_lewm_checkpoint(self, checkpoint_or_path: str | Path | Mapping[str, Any]) -> None:
        if isinstance(checkpoint_or_path, (str, Path)):
            payload = torch.load(Path(checkpoint_or_path), map_location="cpu", weights_only=False)
        else:
            payload = checkpoint_or_path

        state_dict = self._unwrap_state_dict(payload)
        encoder_state_dict = self._extract_prefixed_state_dict(state_dict, "model.encoder.")
        projector_state_dict = self._extract_prefixed_state_dict(state_dict, "model.projector.")

        self.encoder.load_state_dict(encoder_state_dict, strict=True)
        self.projector.load_state_dict(projector_state_dict, strict=True)

    def _freeze_encoder_and_projector(self) -> None:
        for module in (self.encoder, self.projector):
            module.eval()
            for parameter in module.parameters():
                parameter.requires_grad = False

    def train(self, mode: bool = True) -> "LEWMViTBackbone":
        super().train(mode)
        if self.freeze_backbone:
            self._freeze_encoder_and_projector()
        return self

    def _ordered_images(self, images: Dict[str, torch.Tensor]) -> list[torch.Tensor]:
        missing = [camera_name for camera_name in self.camera_names if camera_name not in images]
        if missing:
            raise ValueError(
                f"image input missing required cameras. missing={missing}, expected={list(self.camera_names)}"
            )

        ordered = [images[camera_name] for camera_name in self.camera_names]
        reference_shape = ordered[0].shape
        if len(reference_shape) != 5:
            raise ValueError(f"expected image tensors shaped (B, T, C, H, W), got {reference_shape}")

        for camera_name, image in zip(self.camera_names[1:], ordered[1:]):
            if image.shape != reference_shape:
                raise ValueError(
                    f"camera {camera_name!r} shape {tuple(image.shape)} does not match {tuple(reference_shape)}"
                )

        return ordered

    def _prepare_pixels(self, images: Dict[str, torch.Tensor]) -> tuple[torch.Tensor, int, int]:
        self._ordered_images(images)
        fused = torch.cat([images[camera_name] for camera_name in self.fused_camera_names], dim=-2)
        bsz, steps = fused.shape[:2]
        fused = fused.reshape(bsz * steps, *fused.shape[2:]).contiguous().float()

        fused = fused.clamp(0.0, 1.0)
        fused = (fused - self.mean) / self.std

        height, width = fused.shape[-2:]
        short_side = min(height, width)
        if short_side <= 0:
            raise ValueError(f"invalid fused image shape: {tuple(fused.shape)}")
        scale = self.image_size / float(short_side)
        resized_height = int(round(height * scale))
        resized_width = int(round(width * scale))
        if (resized_height, resized_width) != (height, width):
            fused = F.interpolate(
                fused,
                size=(resized_height, resized_width),
                mode="bilinear",
                align_corners=False,
                antialias=True,
            )
        return fused, bsz, steps

    def forward(self, images: Dict[str, torch.Tensor]) -> torch.Tensor:
        pixels, bsz, steps = self._prepare_pixels(images)
        with torch.set_grad_enabled(torch.is_grad_enabled() and not self.freeze_backbone):
            output = self.encoder(pixel_values=pixels, interpolate_pos_encoding=True)
            cls = output.last_hidden_state[:, 0]
            embedding = self.projector(cls)
        return embedding.view(bsz, steps, self.joint_output_dim)

    @property
    def output_dim(self) -> int:
        return self._output_dim
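`_prepare_pixels` stacks the views along height and then rescales so the short side equals `image_size`, relying on `interpolate_pos_encoding=True` to handle the tall fused input. The scale arithmetic in isolation, stdlib-only:

```python
def short_side_resize(height: int, width: int, image_size: int) -> tuple:
    # Mirror the backbone's resize rule: scale so min(H, W) == image_size.
    short_side = min(height, width)
    scale = image_size / float(short_side)
    return int(round(height * scale)), int(round(width * scale))

# Three 224x224 views stacked along height -> 672x224; short side is already 224.
print(short_side_resize(3 * 224, 224, image_size=224))  # (672, 224), resize is a no-op
# Eval-time 256x256 inputs stack to 768x256 and shrink until the width hits 224.
print(short_side_resize(3 * 256, 256, image_size=224))  # (672, 224)
```

Both dataset-resolution and eval-resolution inputs therefore reach the encoder as 672x224 fused images.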
@@ -211,6 +211,7 @@ class ResNetDiffusionBackbone(VLABackbone):
        use_group_norm: bool = True,
        spatial_softmax_num_keypoints: int = 32,
        use_separate_rgb_encoder_per_camera: bool = False,  # new: whether each camera gets its own encoder
        output_tokens_per_camera: bool = False,  # whether to return one token per camera instead of concatenating into one
        num_cameras: int = 1,  # new: number of cameras (only used in separate-encoder mode)
        camera_names: Optional[Tuple[str, ...]] = None,  # explicit camera order
        freeze_backbone: bool = True,  # new: whether to freeze the ResNet backbone (True recommended)
@@ -229,7 +230,9 @@ class ResNetDiffusionBackbone(VLABackbone):
        super().__init__()

        self.use_separate_rgb_encoder_per_camera = use_separate_rgb_encoder_per_camera
        self.output_tokens_per_camera = bool(output_tokens_per_camera)
        self.num_cameras = num_cameras
        self.tokens_per_step = self.num_cameras if self.output_tokens_per_camera else 1
        self.camera_names = tuple(camera_names) if camera_names is not None else None
        if self.camera_names is not None and len(self.camera_names) != self.num_cameras:
            raise ValueError(
@@ -319,23 +322,25 @@ class ResNetDiffusionBackbone(VLABackbone):
        B, T = any_tensor.shape[:2]
        cam_names = self._ordered_camera_names(images)

        features_all = []
        if self.use_separate_rgb_encoder_per_camera:
            # separate-encoder mode: each camera uses its own encoder
            features_all = []
            for cam_idx, cam_name in enumerate(cam_names):
                img = images[cam_name]
                encoder = self.rgb_encoder[cam_idx]
                features = encoder.forward_single_image(img.reshape(B * T, *img.shape[2:]))
                features_all.append(features)
            return torch.cat(features_all, dim=1).view(B, T, -1)
        else:
            # shared-encoder mode: all cameras share one encoder
            features_all = []
            for cam_name in cam_names:
                img = images[cam_name]
                features = self.rgb_encoder.forward_single_image(img.reshape(B * T, *img.shape[2:]))
                features_all.append(features)
            return torch.cat(features_all, dim=1).view(B, T, -1)

        if self.output_tokens_per_camera:
            stacked = torch.stack(features_all, dim=1)  # (B*T, num_cams, feature_dim)
            return stacked.view(B, T, len(cam_names), self.feature_dim)
        return torch.cat(features_all, dim=1).view(B, T, -1)

    @property
    def output_dim(self):

roboimi/vla/models/backbones/siglip2_diffusion_backbone.py (new file, 124 lines)
@@ -0,0 +1,124 @@
from __future__ import annotations

from typing import Dict, Optional, Sequence, Tuple

import torch
from torch import nn
from transformers import SiglipVisionModel

from roboimi.vla.core.interfaces import VLABackbone


class SigLIP2DiffusionBackbone(VLABackbone):
    """Shared SigLIP vision tower for multiview diffusion-policy conditioning.

    We intentionally load the checkpoint `google/siglip2-base-patch16-256` through
    `SiglipVisionModel.from_pretrained(...)` so each camera can be fed as a normal
    `(B, C, H, W)` image tensor and produce one pooled global feature vector.
    """

    def __init__(
        self,
        model_name: str = 'google/siglip2-base-patch16-256',
        *,
        model_name_or_path: str | None = None,
        vision_model: nn.Module | None = None,
        camera_names: Sequence[str] = ('r_vis', 'top', 'front'),
        num_cameras: Optional[int] = None,
        per_view_output_dim: int = 96,
        output_dim: int | None = None,
        freeze_backbone: bool = True,
        dataset_image_resize_shape: Sequence[int] | None = None,
        eval_image_resize_shape: Sequence[int] | None = (256, 256),
    ) -> None:
        super().__init__()
        if model_name_or_path is not None:
            model_name = model_name_or_path
        if output_dim is not None:
            per_view_output_dim = output_dim

        self.model_name = str(model_name)
        self.camera_names = tuple(camera_names)
        self.num_cameras = int(num_cameras) if num_cameras is not None else len(self.camera_names)
        if len(self.camera_names) != self.num_cameras:
            raise ValueError(
                f'camera_names length ({len(self.camera_names)}) must match num_cameras ({self.num_cameras})'
            )

        self._output_dim = int(per_view_output_dim)
        self.joint_output_dim = self._output_dim * self.num_cameras
        self.freeze_backbone = bool(freeze_backbone)
        self.dataset_image_resize_shape = self._normalize_resize_shape(dataset_image_resize_shape)
        self.eval_image_resize_shape = self._normalize_resize_shape(eval_image_resize_shape)

        self.encoder = vision_model if vision_model is not None else SiglipVisionModel.from_pretrained(self.model_name)
        hidden_size = int(getattr(self.encoder.config, 'hidden_size'))
        self.view_projector = nn.Linear(hidden_size, self._output_dim)
        self.projector = self.view_projector

        self.register_buffer('mean', torch.tensor([0.5, 0.5, 0.5], dtype=torch.float32).view(1, 3, 1, 1))
        self.register_buffer('std', torch.tensor([0.5, 0.5, 0.5], dtype=torch.float32).view(1, 3, 1, 1))

        if self.freeze_backbone:
            self._freeze_encoder()

    @staticmethod
    def _normalize_resize_shape(shape: Sequence[int] | None) -> tuple[int, int] | None:
        if shape is None:
            return None
        normalized = tuple(int(v) for v in shape)
        if len(normalized) != 2:
            raise ValueError(f'resize shape must contain exactly two values, got {normalized}')
        return normalized

    @property
    def output_dim(self) -> int:
        return self._output_dim

    def _freeze_encoder(self) -> None:
        self.encoder.eval()
        for param in self.encoder.parameters():
            param.requires_grad = False

    def train(self, mode: bool = True):
        super().train(mode)
        if self.freeze_backbone:
            self._freeze_encoder()
        return self

    def _ordered_camera_names(self, images: Dict[str, torch.Tensor]) -> Tuple[str, ...]:
        missing = [camera_name for camera_name in self.camera_names if camera_name not in images]
        if missing:
            raise ValueError(
                f'image input missing required cameras. missing={missing}, expected={list(self.camera_names)}'
            )
        return self.camera_names

    def _prepare_pixels(self, image: torch.Tensor) -> torch.Tensor:
        if image.ndim != 5:
            raise ValueError(f'expected image tensor shaped (B, T, C, H, W), got {tuple(image.shape)}')
        pixels = image.reshape(-1, *image.shape[2:]).contiguous().float()
        pixels = pixels.clamp(0.0, 1.0)
        return (pixels - self.mean) / self.std

    def forward(self, images: Dict[str, torch.Tensor]) -> torch.Tensor:
        camera_names = self._ordered_camera_names(images)
        reference_shape = images[camera_names[0]].shape
        batch_size, steps = reference_shape[:2]
        per_view_features = []
        for camera_name in camera_names:
            image = images[camera_name]
            if image.shape != reference_shape:
                raise ValueError(
                    f'camera {camera_name!r} shape {tuple(image.shape)} does not match {tuple(reference_shape)}'
                )
            pixels = self._prepare_pixels(image)
            with torch.set_grad_enabled(torch.is_grad_enabled() and not self.freeze_backbone):
                encoded = self.encoder(pixel_values=pixels)
                pooled = encoded.pooler_output
            per_view_features.append(self.view_projector(pooled))
        features = torch.cat(per_view_features, dim=-1)
        return features.view(batch_size, steps, self.joint_output_dim)


Siglip2DiffusionBackbone = SigLIP2DiffusionBackbone
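Unlike the LEWM backbone's single fused embedding, the SigLIP2 backbone keeps one pooled vector per view and concatenates the projections, so its joint width is the per-view width times the camera count:

```python
per_view_output_dim = 96  # width of each view_projector output
num_cameras = 3           # r_vis, top, front
joint_output_dim = per_view_output_dim * num_cameras
print(joint_output_dim)  # 288
```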

roboimi/vla/modules/projectors.py (new file, 17 lines)
@@ -0,0 +1,17 @@
from __future__ import annotations

import torch
from torch import nn


class LinearConditionProjector(nn.Module):
    """Projects per-step visual+state conditioning to the head conditioning width."""

    def __init__(self, input_dim: int, output_dim: int, bias: bool = True):
        super().__init__()
        self.input_dim = int(input_dim)
        self.output_dim = int(output_dim)
        self.linear = nn.Linear(self.input_dim, self.output_dim, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)
@@ -90,6 +90,24 @@ class _FakeRenderer:


class EvalVLAHeadlessTest(unittest.TestCase):
    def test_prepare_observation_skips_resize_when_image_resize_shape_is_none(self):
        obs = {
            "images": {
                "front": np.arange(8 * 8 * 3, dtype=np.uint8).reshape(8, 8, 3),
            },
            "qpos": np.zeros(16, dtype=np.float32),
        }

        with mock.patch("cv2.resize", side_effect=AssertionError("resize should be skipped")):
            prepared = eval_vla.prepare_observation(
                obs,
                ["front"],
                image_resize_shape=None,
            )

        self.assertEqual(tuple(prepared["images"]["front"].shape), (3, 8, 8))
        self.assertEqual(tuple(prepared["qpos"].shape), (16,))

    def test_headless_eval_sets_mujoco_gl_to_egl_when_display_missing(self):
        cfg = OmegaConf.create({"eval": {"headless": True}})
        with mock.patch.dict(eval_vla.os.environ, {}, clear=True):
@@ -1,5 +1,6 @@
|
||||
import contextlib
|
||||
import importlib
|
||||
import importlib.machinery
|
||||
import sys
|
||||
import types
|
||||
import unittest
|
||||
@@ -69,6 +70,68 @@ class _FakeRearrange(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
class _FakeViTConfig:
|
||||
def __init__(self, **kwargs):
|
||||
for key, value in kwargs.items():
|
||||
setattr(self, key, value)
|
||||
|
||||
|
||||
class _FakeViTModel(nn.Module):
|
||||
def __init__(self, config, add_pooling_layer=False):
|
||||
super().__init__()
|
||||
del add_pooling_layer
|
||||
self.config = config
|
||||
hidden_size = int(getattr(config, 'hidden_size', 192))
|
||||
self.proj = nn.Linear(hidden_size, hidden_size)
|
||||
|
||||
def forward(self, pixel_values=None, interpolate_pos_encoding=False, **kwargs):
|
||||
del interpolate_pos_encoding, kwargs
|
||||
batch_size = pixel_values.shape[0]
|
||||
hidden_size = int(getattr(self.config, 'hidden_size', 192))
|
||||
seq_len = 2
|
||||
last_hidden_state = torch.zeros(batch_size, seq_len, hidden_size, dtype=pixel_values.dtype, device=pixel_values.device)
|
||||
return types.SimpleNamespace(last_hidden_state=last_hidden_state)
|
||||
|
||||
|
||||
class _FakeSiglipVisionOutput:
|
||||
def __init__(self, pooler_output):
|
||||
self.pooler_output = pooler_output
|
||||
|
||||
|
||||
class _FakeSiglipVisionConfig:
|
||||
def __init__(self, hidden_size=768, image_size=256):
|
||||
self.hidden_size = hidden_size
|
||||
self.image_size = image_size
|
||||
|
||||
|
||||
class _FakeSiglipVisionModel(nn.Module):
|
||||
load_calls = []
|
||||
|
||||
def __init__(self, hidden_size=768):
|
||||
super().__init__()
|
||||
self.config = _FakeSiglipVisionConfig(hidden_size=hidden_size)
|
||||
self.scale = nn.Parameter(torch.tensor(1.0))
|
||||
self.forward_calls = []
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
|
||||
model = cls()
|
||||
cls.load_calls.append({
|
||||
'pretrained_model_name_or_path': pretrained_model_name_or_path,
|
||||
'args': args,
|
||||
'kwargs': kwargs,
|
||||
})
|
||||
return model
|
||||
|
||||
def forward(self, pixel_values=None, **kwargs):
|
||||
self.forward_calls.append({
|
||||
'pixel_values': pixel_values.detach().clone(),
|
||||
'kwargs': dict(kwargs),
|
||||
})
|
||||
pooled = pixel_values.mean(dim=(2, 3), keepdim=False) * self.scale
|
||||
return _FakeSiglipVisionOutput(pooler_output=pooled)
|
||||
|
||||
|
||||
class _StubIMFHead(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -105,6 +168,11 @@ class _StubIMFHead(nn.Module):
def _stub_optional_modules(include_imf_head=False):
    previous_modules = {}

    def remember_and_remove(name):
        if name not in previous_modules:
            previous_modules[name] = sys.modules.get(name, _MISSING)
        sys.modules.pop(name, None)

    def inject(name, module):
        if name not in previous_modules:
            previous_modules[name] = sys.modules.get(name, _MISSING)
@@ -125,6 +193,9 @@ def _stub_optional_modules(include_imf_head=False):
    torchvision_module = types.ModuleType('torchvision')
    models_module = types.ModuleType('torchvision.models')
    transforms_module = types.ModuleType('torchvision.transforms')
    torchvision_module.__spec__ = importlib.machinery.ModuleSpec('torchvision', loader=None)
    models_module.__spec__ = importlib.machinery.ModuleSpec('torchvision.models', loader=None)
    transforms_module.__spec__ = importlib.machinery.ModuleSpec('torchvision.transforms', loader=None)
    models_module.resnet18 = lambda weights=None: _FakeResNet()
    transforms_module.CenterCrop = _IdentityCrop
    transforms_module.RandomCrop = _IdentityCrop
@@ -139,7 +210,14 @@ def _stub_optional_modules(include_imf_head=False):
    einops_module.layers = einops_layers_module
    einops_layers_module.torch = einops_layers_torch_module

    transformers_module = types.ModuleType('transformers')
    transformers_module.__spec__ = importlib.machinery.ModuleSpec('transformers', loader=None)
    transformers_module.ViTConfig = _FakeViTConfig
    transformers_module.ViTModel = _FakeViTModel
    transformers_module.SiglipVisionModel = _FakeSiglipVisionModel

    try:
        remember_and_remove('roboimi.vla.models.backbones.siglip2_diffusion_backbone')
        inject('diffusers', diffusers_module)
        inject('diffusers.schedulers', schedulers_module)
        inject('diffusers.schedulers.scheduling_ddpm', ddpm_module)
@@ -150,6 +228,7 @@ def _stub_optional_modules(include_imf_head=False):
        inject('einops', einops_module)
        inject('einops.layers', einops_layers_module)
        inject('einops.layers.torch', einops_layers_torch_module)
        inject('transformers', transformers_module)

        if include_imf_head:
            import roboimi.vla.models.heads as heads_package
@@ -200,6 +279,67 @@ class _StubVisionBackbone(nn.Module):
        return torch.cat(per_camera_features, dim=-1)


class _StubJointVisionBackbone(nn.Module):
    joint_output_dim = 5
    output_dim = 5

    def __init__(self, camera_names=_CAMERA_NAMES):
        super().__init__()
        self.camera_names = tuple(camera_names)
        self.num_cameras = len(self.camera_names)

    def forward(self, images):
        batch_size, obs_horizon = next(iter(images.values())).shape[:2]
        features = []
        for camera_name in ('front', 'top', 'r_vis'):
            image_batch = images[camera_name]
            features.append(image_batch.mean(dim=(2, 3, 4), keepdim=False).unsqueeze(-1))
        joint_features = torch.cat(features, dim=-1)
        front_top_sum = joint_features[..., :2].sum(dim=-1, keepdim=True)
        r_vis_minus_front = joint_features[..., 2:] - joint_features[..., :1]
        time_marker = torch.arange(obs_horizon, dtype=joint_features.dtype).view(1, obs_horizon, 1)
        time_marker = time_marker.expand(batch_size, -1, -1)
        return torch.cat([joint_features, front_top_sum, r_vis_minus_front + time_marker], dim=-1)


class _StubMultiTokenVisionBackbone(nn.Module):
    output_dim = 2
    tokens_per_step = 3

    def __init__(self, camera_names=_CAMERA_NAMES):
        super().__init__()
        self.camera_names = tuple(camera_names)
        self.num_cameras = len(self.camera_names)

    def forward(self, images):
        per_camera = []
        for camera_name in self.camera_names:
            image_batch = images[camera_name]
            base = image_batch.mean(dim=(2, 3, 4), keepdim=False)
            per_camera.append(torch.stack([base, base + 0.5], dim=-1))
        return torch.stack(per_camera, dim=2)


class _RecordingLinearIMFHead(nn.Module):
    def __init__(self):
        super().__init__()
@@ -390,6 +530,178 @@ class IMFVLAAgentTest(unittest.TestCase):
        self.assertTrue(torch.equal(third_action, second_chunk[0, 1]))
        self.assertEqual(mock_predict_chunk.call_count, 2)

    def test_joint_visual_backbone_uses_joint_output_dim_for_conditioning(self):
        agent_cls, _agent_module = _load_imf_agent_class()
        head = _RecordingLinearIMFHead()
        vision_backbone = _StubJointVisionBackbone()
        agent = agent_cls(
            vision_backbone=vision_backbone,
            state_encoder=nn.Identity(),
            action_encoder=nn.Identity(),
            head=head,
            action_dim=2,
            obs_dim=1,
            pred_horizon=3,
            obs_horizon=2,
            diffusion_steps=10,
            inference_steps=1,
            num_cams=len(_CAMERA_NAMES),
            camera_names=_CAMERA_NAMES,
            num_action_steps=2,
            head_type='transformer',
        )

        self.assertEqual(agent.per_step_cond_dim, vision_backbone.joint_output_dim + agent.obs_dim)
        self.assertEqual(
            agent.global_cond_dim,
            vision_backbone.joint_output_dim * agent.obs_horizon + agent.obs_dim * agent.obs_horizon,
        )

        images = _make_images(
            batch_size=1,
            obs_horizon=2,
            per_camera_fill={'r_vis': 10.0, 'top': 20.0, 'front': 30.0},
        )
        qpos = torch.tensor([[[1.0], [2.0]]], dtype=torch.float32)
        initial_noise = torch.tensor(
            [[[1.0, -1.0], [0.0, 2.0], [3.0, -2.0]]],
            dtype=torch.float32,
        )

        with mock.patch.object(torch, 'randn', return_value=initial_noise):
            predicted_actions = agent.predict_action(images, qpos)

        self.assertEqual(predicted_actions.shape, (1, 3, 2))
        self.assertEqual(len(head.calls), 1)
        expected_cond = torch.tensor(
            [[[30.0, 20.0, 10.0, 50.0, -20.0, 1.0], [30.0, 20.0, 10.0, 50.0, -19.0, 2.0]]],
            dtype=torch.float32,
        )
        self.assertEqual(head.calls[0]['cond'].shape[-1], 6)
        self.assertTrue(torch.allclose(head.calls[0]['cond'], expected_cond))

    def test_multitoken_visual_backbone_flattens_camera_tokens_and_projects_each_with_state(self):
        agent_cls, _agent_module = _load_imf_agent_class()
        head = _RecordingLinearIMFHead()
        projector = nn.Linear(3, 4, bias=False)
        with torch.no_grad():
            projector.weight.copy_(
                torch.tensor(
                    [
                        [1.0, 0.0, 0.0],
                        [0.0, 1.0, 0.0],
                        [0.0, 0.0, 1.0],
                        [1.0, 0.0, 1.0],
                    ],
                    dtype=torch.float32,
                )
            )
        agent = agent_cls(
            vision_backbone=_StubMultiTokenVisionBackbone(),
            state_encoder=nn.Identity(),
            action_encoder=nn.Identity(),
            head=head,
            action_dim=2,
            obs_dim=1,
            pred_horizon=3,
            obs_horizon=2,
            diffusion_steps=10,
            inference_steps=1,
            num_cams=len(_CAMERA_NAMES),
            camera_names=_CAMERA_NAMES,
            num_action_steps=2,
            head_type='transformer',
            cond_projector=projector,
        )

        self.assertEqual(agent.condition_tokens_per_step, 3)
        self.assertEqual(agent.condition_sequence_length, 6)
        self.assertEqual(agent.per_step_cond_dim, 4)
        self.assertEqual(agent.global_cond_dim, 24)

        images = _make_images(
            batch_size=1,
            obs_horizon=2,
            per_camera_fill={'r_vis': 10.0, 'top': 20.0, 'front': 30.0},
        )
        qpos = torch.tensor([[[1.0], [2.0]]], dtype=torch.float32)
        cond = agent._build_cond(images, qpos)

        expected = torch.tensor(
            [
                [
                    [10.0, 10.5, 1.0, 11.0],
                    [20.0, 20.5, 1.0, 21.0],
                    [30.0, 30.5, 1.0, 31.0],
                    [10.0, 10.5, 2.0, 12.0],
                    [20.0, 20.5, 2.0, 22.0],
                    [30.0, 30.5, 2.0, 32.0],
                ]
            ],
            dtype=torch.float32,
        )
        self.assertEqual(cond.shape, (1, 6, 4))
        self.assertTrue(torch.allclose(cond, expected))

    def test_multi_token_visual_backbone_pairs_state_per_camera_and_flattens_condition_sequence(self):
        agent_cls, agent_module = _load_imf_agent_class()
        head = _RecordingLinearIMFHead()
        cond_projector = nn.Linear(3, 4, bias=False)
        with torch.no_grad():
            cond_projector.weight.copy_(torch.tensor([
                [1.0, 0.0, 0.0],
                [0.0, 1.0, 0.0],
                [0.0, 0.0, 1.0],
                [1.0, 0.0, 1.0],
            ], dtype=torch.float32))

        agent = agent_cls(
            vision_backbone=_StubMultiTokenVisionBackbone(),
            state_encoder=nn.Identity(),
            action_encoder=nn.Identity(),
            head=head,
            action_dim=2,
            obs_dim=1,
            pred_horizon=3,
            obs_horizon=2,
            diffusion_steps=10,
            inference_steps=1,
            num_cams=len(_CAMERA_NAMES),
            camera_names=_CAMERA_NAMES,
            num_action_steps=2,
            head_type='transformer',
            cond_projector=cond_projector,
        )
        agent.infer_scheduler = _ForbiddenScheduler()

        images = _make_images(
            batch_size=1,
            obs_horizon=2,
            per_camera_fill={'r_vis': 10.0, 'top': 20.0, 'front': 30.0},
        )
        qpos = torch.tensor([[[1.0], [2.0]]], dtype=torch.float32)
        initial_noise = torch.tensor([[[1.0, -1.0], [0.0, 2.0], [3.0, -2.0]]], dtype=torch.float32)

        with mock.patch.object(agent_module.torch, 'randn', return_value=initial_noise):
            predicted_actions = agent.predict_action(images, qpos)

        expected_cond = torch.tensor([[[10.0, 10.5, 1.0, 11.0],
                                       [20.0, 20.5, 1.0, 21.0],
                                       [30.0, 30.5, 1.0, 31.0],
                                       [10.0, 10.5, 2.0, 12.0],
                                       [20.0, 20.5, 2.0, 22.0],
                                       [30.0, 30.5, 2.0, 32.0]]], dtype=torch.float32)

        self.assertEqual(agent.condition_tokens_per_step, 3)
        self.assertEqual(agent.condition_sequence_length, 6)
        self.assertEqual(agent.raw_per_step_cond_dim, 3)
        self.assertEqual(agent.per_step_cond_dim, 4)
        self.assertEqual(agent.global_cond_dim, 24)
        self.assertEqual(predicted_actions.shape, (1, 3, 2))
        self.assertEqual(len(head.calls), 1)
        self.assertEqual(head.calls[0]['cond'].shape, (1, 6, 4))
        self.assertTrue(torch.allclose(head.calls[0]['cond'], expected_cond))

    def test_hydra_config_instantiates_resnet_imf_attnres_with_stub_head(self):
        cfg = _compose_cfg(
            overrides=[
@@ -448,6 +760,130 @@ class IMFVLAAgentTest(unittest.TestCase):
        self.assertEqual(agent.per_step_cond_dim, 64 * agent.num_cams + agent.obs_dim)
        self.assertEqual(agent.noise_pred_net.constructor_kwargs['cond_dim'], agent.per_step_cond_dim)

    def test_hydra_config_instantiates_lewm_imf_attnres_with_joint_visual_condition_dim(self):
        cfg = _compose_cfg(
            overrides=[
                'agent=lewm_imf_attnres',
                'agent.vision_backbone.checkpoint_path=null',
                'agent.head.n_layer=1',
                'agent.head.n_emb=16',
            ]
        )

        self.assertEqual(cfg.agent._target_, 'roboimi.vla.agent_imf.IMFVLAAgent')
        self.assertEqual(cfg.agent.vision_backbone._target_, 'roboimi.vla.models.backbones.lewm_vit_backbone.LEWMViTBackbone')
        self.assertEqual(list(cfg.agent.camera_names), list(_CAMERA_NAMES))
        self.assertEqual(list(cfg.agent.vision_backbone.camera_names), list(_CAMERA_NAMES))
        self.assertEqual(list(cfg.agent.vision_backbone.fused_camera_names), ['front', 'top', 'r_vis'])
        self.assertIsNone(cfg.agent.vision_backbone.dataset_image_resize_shape)
        self.assertEqual(list(cfg.agent.vision_backbone.eval_image_resize_shape), [256, 256])
        self.assertEqual(cfg.agent.head.cond_dim, 208)

        with _stub_optional_modules(include_imf_head=True):
            agent = instantiate(cfg.agent)

        self.assertEqual(agent.per_step_cond_dim, agent.vision_encoder.joint_output_dim + agent.obs_dim)
        self.assertEqual(agent.per_step_cond_dim, 208)
        self.assertEqual(agent.global_cond_dim, agent.obs_horizon * 208)
        self.assertIsNone(agent.vision_encoder.dataset_image_resize_shape)
        self.assertEqual(agent.vision_encoder.eval_image_resize_shape, (256, 256))
        self.assertIsInstance(agent.noise_pred_net, _StubIMFHead)
        self.assertEqual(agent.noise_pred_net.constructor_kwargs['cond_dim'], 208)

    def test_hydra_config_instantiates_resnet_imf_attnres_multitoken_with_projected_camera_tokens(self):
        cfg = _compose_cfg(
            overrides=[
                'agent=resnet_imf_attnres_multitoken',
                'agent.vision_backbone.pretrained_backbone_weights=null',
                'agent.vision_backbone.input_shape=[3,16,16]',
                'agent.head.n_layer=1',
                'agent.head.n_emb=32',
            ]
        )

        self.assertEqual(cfg.agent._target_, 'roboimi.vla.agent_imf.IMFVLAAgent')
        self.assertEqual(cfg.agent.vision_backbone.vision_backbone_mode, 'resnet')
        self.assertTrue(cfg.agent.vision_backbone.use_separate_rgb_encoder_per_camera)
        self.assertTrue(cfg.agent.vision_backbone.output_tokens_per_camera)
        self.assertEqual(cfg.agent.cond_projector.output_dim, 32)
        self.assertEqual(cfg.agent.head.cond_dim, 32)

        with _stub_optional_modules(include_imf_head=True):
            agent = instantiate(cfg.agent)

        self.assertEqual(agent.condition_tokens_per_step, 3)
        self.assertEqual(agent.condition_sequence_length, agent.obs_horizon * 3)
        self.assertEqual(agent.per_step_cond_dim, 32)
        self.assertEqual(agent.global_cond_dim, agent.condition_sequence_length * 32)
        self.assertIsInstance(agent.noise_pred_net, _StubIMFHead)
        self.assertEqual(agent.noise_pred_net.constructor_kwargs['cond_dim'], 32)
        self.assertEqual(agent.noise_pred_net.constructor_kwargs['n_obs_steps'], 6)

    def test_hydra_config_instantiates_siglip2_imf_attnres_with_condition_projection(self):
        cfg = _compose_cfg(
            overrides=[
                'agent=siglip2_imf_attnres',
                'agent.vision_backbone.per_view_output_dim=96',
                'agent.head.n_layer=1',
                'agent.head.n_emb=16',
                'agent.cond_projector.output_dim=384',
            ]
        )

        self.assertEqual(cfg.agent._target_, 'roboimi.vla.agent_imf.IMFVLAAgent')
        self.assertEqual(
            cfg.agent.vision_backbone._target_,
            'roboimi.vla.models.backbones.siglip2_diffusion_backbone.SigLIP2DiffusionBackbone',
        )
        self.assertEqual(list(cfg.agent.camera_names), list(_CAMERA_NAMES))
        self.assertIsNone(cfg.agent.vision_backbone.dataset_image_resize_shape)
        self.assertEqual(list(cfg.agent.vision_backbone.eval_image_resize_shape), [256, 256])
        self.assertEqual(cfg.agent.head.cond_dim, 384)

        with _stub_optional_modules(include_imf_head=True):
            agent = instantiate(cfg.agent)

        self.assertEqual(agent.raw_per_step_cond_dim, 3 * 96 + agent.obs_dim)
        self.assertEqual(agent.per_step_cond_dim, 384)
        self.assertEqual(agent.global_cond_dim, agent.obs_horizon * 384)
        self.assertEqual(agent.noise_pred_net.constructor_kwargs['cond_dim'], 384)
        self.assertEqual(agent.vision_encoder.output_dim, 96)
        self.assertEqual(agent.vision_encoder.eval_image_resize_shape, (256, 256))

    def test_hydra_config_instantiates_resnet_imf_attnres_multitoken_with_sequence_length_three_times_obs_horizon(self):
        cfg = _compose_cfg(
            overrides=[
                'agent=resnet_imf_attnres_multitoken',
                'agent.vision_backbone.pretrained_backbone_weights=null',
                'agent.vision_backbone.input_shape=[3,16,16]',
                'agent.vision_backbone.freeze_backbone=false',
                'agent.head.n_layer=1',
                'agent.head.n_emb=16',
            ]
        )

        self.assertEqual(cfg.agent._target_, 'roboimi.vla.agent_imf.IMFVLAAgent')
        self.assertEqual(list(cfg.agent.camera_names), list(_CAMERA_NAMES))
        self.assertTrue(cfg.agent.vision_backbone.use_separate_rgb_encoder_per_camera)
        self.assertTrue(cfg.agent.vision_backbone.output_tokens_per_camera)
        self.assertEqual(cfg.agent.vision_backbone.vision_backbone_mode, 'resnet')
        self.assertEqual(cfg.agent.cond_projector.output_dim, 16)
        self.assertEqual(cfg.agent.head.cond_dim, 16)

        with _stub_optional_modules(include_imf_head=True):
            agent = instantiate(cfg.agent)

        self.assertEqual(agent.condition_tokens_per_step, 3)
        self.assertEqual(agent.condition_sequence_length, agent.obs_horizon * 3)
        self.assertEqual(agent.per_step_cond_dim, 16)
        self.assertEqual(agent.global_cond_dim, agent.condition_sequence_length * 16)
        self.assertEqual(agent.vision_encoder.tokens_per_step, 3)
        self.assertIsInstance(agent.noise_pred_net, _StubIMFHead)
        self.assertEqual(agent.noise_pred_net.constructor_kwargs['cond_dim'], 16)
        self.assertEqual(agent.noise_pred_net.constructor_kwargs['n_obs_steps'], agent.condition_sequence_length)


if __name__ == '__main__':
    unittest.main()
220
tests/test_lewm_vit_backbone.py
Normal file
@@ -0,0 +1,220 @@
import tempfile
import types
import unittest
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import ViTConfig, ViTModel


_INPUT_CAMERA_NAMES = ("r_vis", "top", "front")
_FUSED_CAMERA_NAMES = ("front", "top", "r_vis")


class _ReferenceProjector(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(192, 2048),
            nn.BatchNorm1d(2048),
            nn.GELU(),
            nn.Linear(2048, 192),
        )

    def forward(self, x):
        return self.net(x)


def _build_reference_encoder() -> ViTModel:
    return ViTModel(
        ViTConfig(
            image_size=224,
            patch_size=14,
            num_channels=3,
            hidden_size=192,
            intermediate_size=768,
            num_hidden_layers=12,
            num_attention_heads=3,
            qkv_bias=True,
        ),
        add_pooling_layer=False,
    )


def _write_synthetic_lightning_ckpt(path: Path):
    torch.manual_seed(7)
    encoder = _build_reference_encoder()
    projector = _ReferenceProjector()
    lightning_state_dict = {}
    for key, value in encoder.state_dict().items():
        lightning_state_dict[f"model.encoder.{key}"] = value.detach().clone()
    for key, value in projector.state_dict().items():
        lightning_state_dict[f"model.projector.{key}"] = value.detach().clone()
    torch.save({"state_dict": lightning_state_dict}, path)
    return encoder.state_dict(), projector.state_dict()


class LEWMViTBackboneTest(unittest.TestCase):
    def test_loads_lightning_encoder_and_projector_checkpoint_and_emits_joint_embedding(self):
        from roboimi.vla.models.backbones.lewm_vit_backbone import LEWMViTBackbone

        with tempfile.TemporaryDirectory() as tmpdir:
            ckpt_path = Path(tmpdir) / "synthetic-lewm.ckpt"
            reference_encoder_state, reference_projector_state = _write_synthetic_lightning_ckpt(
                ckpt_path
            )

            backbone = LEWMViTBackbone(
                checkpoint_path=ckpt_path,
                camera_names=_INPUT_CAMERA_NAMES,
                fused_camera_names=_FUSED_CAMERA_NAMES,
                freeze_backbone=True,
            )

        self.assertEqual(backbone.camera_names, _INPUT_CAMERA_NAMES)
        self.assertEqual(backbone.fused_camera_names, _FUSED_CAMERA_NAMES)
        self.assertEqual(backbone.num_cameras, 3)
        self.assertEqual(backbone.joint_output_dim, 192)
        self.assertEqual(backbone.output_dim, 192)
        self.assertEqual(backbone.encoder.config.hidden_size, 192)
        self.assertEqual(backbone.encoder.config.patch_size, 14)
        self.assertEqual(backbone.encoder.config.num_hidden_layers, 12)
        self.assertEqual(backbone.encoder.config.num_attention_heads, 3)

        for key, value in reference_encoder_state.items():
            self.assertTrue(torch.equal(backbone.encoder.state_dict()[key], value), key)
        for key, value in reference_projector_state.items():
            self.assertTrue(torch.equal(backbone.projector.state_dict()[key], value), key)

        images = {
            cam_name: torch.rand(1, 1, 3, 224, 224)
            for cam_name in _INPUT_CAMERA_NAMES
        }
        output = backbone(images)

        self.assertEqual(output.shape, (1, 1, 192))
        self.assertFalse(output.requires_grad)

    def test_forward_uses_front_top_rvis_fusion_order_and_exact_lewm_cwh_resize_path(self):
        from roboimi.vla.models.backbones.lewm_vit_backbone import LEWMViTBackbone

        with tempfile.TemporaryDirectory() as tmpdir:
            ckpt_path = Path(tmpdir) / "synthetic-lewm.ckpt"
            _write_synthetic_lightning_ckpt(ckpt_path)

            backbone = LEWMViTBackbone(
                checkpoint_path=ckpt_path,
                camera_names=_INPUT_CAMERA_NAMES,
                fused_camera_names=_FUSED_CAMERA_NAMES,
                freeze_backbone=True,
            )
        captured = {}

        def fake_encoder_forward(module, pixel_values, interpolate_pos_encoding=False, **kwargs):
            del module, kwargs
            captured["pixel_values"] = pixel_values.detach().clone()
            captured["interpolate_pos_encoding"] = interpolate_pos_encoding
            batch = pixel_values.shape[0]
            patch_tokens = (pixel_values.shape[-2] // 14) * (pixel_values.shape[-1] // 14)
            cls = (
                torch.arange(192, dtype=pixel_values.dtype, device=pixel_values.device)
                .unsqueeze(0)
                .expand(batch, -1)
            )
            last_hidden_state = torch.zeros(
                batch,
                patch_tokens + 1,
                192,
                dtype=pixel_values.dtype,
                device=pixel_values.device,
            )
            last_hidden_state[:, 0] = cls
            return types.SimpleNamespace(last_hidden_state=last_hidden_state)

        backbone.encoder.forward = types.MethodType(fake_encoder_forward, backbone.encoder)

        r_vis = torch.full((1, 1, 3, 256, 256), 0.30)
        top = torch.full((1, 1, 3, 256, 256), 0.20)
        front = torch.full((1, 1, 3, 256, 256), 0.10)
        bn = backbone.projector.net[1]
        running_mean_before = bn.running_mean.detach().clone()
        running_var_before = bn.running_var.detach().clone()

        backbone.train()
        self.assertFalse(backbone.encoder.training)
        self.assertFalse(backbone.projector.training)

        output = backbone({"r_vis": r_vis, "top": top, "front": front})

        self.assertEqual(output.shape, (1, 1, 192))
        self.assertEqual(captured["pixel_values"].shape, (1, 3, 672, 224))
        self.assertTrue(captured["interpolate_pos_encoding"])

        normalized_views = [
            ((view.reshape(-1, *view.shape[2:]).float()).clamp(0.0, 1.0) - backbone.mean) / backbone.std
            for view in (front, top, r_vis)
        ]
        expected_fuse_then_resize = F.interpolate(
            torch.cat(normalized_views, dim=-2),
            size=(672, 224),
            mode="bilinear",
            align_corners=False,
            antialias=True,
        )
        expected_pre_resize_then_fuse = torch.cat(
            [
                F.interpolate(
                    view,
                    size=(224, 224),
                    mode="bilinear",
                    align_corners=False,
                    antialias=True,
                )
                for view in normalized_views
            ],
            dim=-2,
        )

        self.assertTrue(
            torch.allclose(captured["pixel_values"], expected_fuse_then_resize, atol=1e-6, rtol=1e-6)
        )
        self.assertFalse(
            torch.allclose(
                expected_fuse_then_resize,
                expected_pre_resize_then_fuse,
                atol=1e-6,
                rtol=1e-6,
            )
        )
        self.assertFalse(
            torch.allclose(
                captured["pixel_values"],
                expected_pre_resize_then_fuse,
                atol=1e-6,
                rtol=1e-6,
            )
        )
        self.assertTrue(
            torch.allclose(
                captured["pixel_values"][0, :, 223, :],
                expected_fuse_then_resize[0, :, 223, :],
                atol=1e-6,
                rtol=1e-6,
            )
        )
        self.assertTrue(
            torch.allclose(
                captured["pixel_values"][0, :, 447, :],
                expected_fuse_then_resize[0, :, 447, :],
                atol=1e-6,
                rtol=1e-6,
            )
        )
        self.assertTrue(torch.equal(bn.running_mean, running_mean_before))
        self.assertTrue(torch.equal(bn.running_var, running_var_before))


if __name__ == "__main__":
    unittest.main()
@@ -180,6 +180,14 @@ def _extract_camera_markers(cond, feature_dim, num_cams):
    return camera_block[:, 0]


def _extract_token_camera_markers(tokens):
    return tokens[0, 0, :, 0]


def _extract_token_markers(token_sequence):
    return token_sequence[0, 0, :, 0]


class ResNetTransformerAgentWiringTest(unittest.TestCase):
    def test_hydra_wiring_uses_required_three_camera_transformer_conditioning_in_agent_order_and_ignores_extra_keys(self):
        cfg = _compose_cfg(
@@ -246,6 +254,36 @@ class ResNetTransformerAgentWiringTest(unittest.TestCase):
        with self.assertRaisesRegex(ValueError, 'missing=.*top'):
            agent.predict_action(missing_images, proprioception)

    def test_multitoken_resnet_backbone_emits_one_token_per_camera_in_agent_order(self):
        cfg = _compose_cfg(
            overrides=[
                'agent=resnet_imf_attnres_multitoken',
                'agent.vision_backbone.pretrained_backbone_weights=null',
                'agent.vision_backbone.input_shape=[3,16,16]',
            ]
        )

        with _stub_optional_modules():
            backbone = instantiate(cfg.agent.vision_backbone)
            _patch_backbone_for_order_tracking(backbone)
            images = _make_images(
                batch_size=1,
                obs_horizon=cfg.agent.obs_horizon,
                image_shape=tuple(cfg.agent.vision_backbone.input_shape),
                per_camera_fill={
                    'front': 30.0,
                    'top': 20.0,
                    'r_vis': 10.0,
                    'left_wrist': 99.0,
                },
            )
            tokens = backbone(images)

        self.assertEqual(tokens.shape, (1, cfg.agent.obs_horizon, 3, backbone.output_dim))
        self.assertEqual(backbone.tokens_per_step, 3)
        camera_markers = _extract_token_camera_markers(tokens)
        self.assertTrue(torch.allclose(camera_markers, torch.tensor([10.0, 20.0, 30.0])))

    def test_agent_rejects_conflicting_explicit_backbone_camera_names(self):
        cfg = _compose_cfg(
            overrides=[
@@ -382,6 +420,36 @@ class ResNetTransformerAgentWiringTest(unittest.TestCase):
        with self.assertRaisesRegex(InstantiationException, 'num_cams'):
            instantiate(cfg.agent)

    def test_multitoken_resnet_backbone_emits_one_token_per_camera_in_agent_order(self):
        cfg = _compose_cfg(
            overrides=[
                'agent=resnet_imf_attnres_multitoken',
                'agent.vision_backbone.pretrained_backbone_weights=null',
                'agent.vision_backbone.input_shape=[3,16,16]',
                'agent.head.n_layer=1',
                'agent.head.n_emb=32',
            ]
        )

        with _stub_optional_modules():
            backbone = instantiate(cfg.agent.vision_backbone)
            _patch_backbone_for_order_tracking(backbone)
            images = _make_images(
                batch_size=1,
                obs_horizon=cfg.agent.obs_horizon,
                image_shape=tuple(cfg.agent.vision_backbone.input_shape),
                per_camera_fill={
                    'front': 30.0,
                    'top': 20.0,
                    'r_vis': 10.0,
                },
            )
            output = backbone(images)

        self.assertEqual(output.shape, (1, cfg.agent.obs_horizon, 3, backbone.output_dim))
        token_markers = _extract_token_markers(output)
        self.assertTrue(torch.allclose(token_markers, torch.tensor([10.0, 20.0, 30.0])))


if __name__ == '__main__':
    unittest.main()
121
tests/test_siglip2_diffusion_backbone.py
Normal file
@@ -0,0 +1,121 @@
import types
import unittest
from unittest import mock

import torch
from torch import nn


_CAMERA_NAMES = ("r_vis", "top", "front")


class _FakeSiglipVisionOutput:
    def __init__(self, pooler_output):
        self.pooler_output = pooler_output


class _FakeSiglipVisionConfig:
    def __init__(self, hidden_size=768, image_size=256):
        self.hidden_size = hidden_size
        self.image_size = image_size


class _FakeSiglipVisionModel(nn.Module):
    def __init__(self, hidden_size=768):
        super().__init__()
        self.config = _FakeSiglipVisionConfig(hidden_size=hidden_size)
        self.forward_calls = []

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        del args, kwargs
        return cls()

    def forward(self, pixel_values=None, **kwargs):
        self.forward_calls.append({
            "pixel_values": pixel_values.detach().clone(),
            "kwargs": dict(kwargs),
        })
        pooled = pixel_values.mean(dim=(2, 3), keepdim=False)
        return _FakeSiglipVisionOutput(pooler_output=pooled)


class SigLIP2DiffusionBackboneTest(unittest.TestCase):
    def test_forward_encodes_each_view_independently_and_concatenates_projected_features(self):
        from roboimi.vla.models.backbones.siglip2_diffusion_backbone import SigLIP2DiffusionBackbone

        fake_model = _FakeSiglipVisionModel(hidden_size=3)
        with mock.patch(
            "roboimi.vla.models.backbones.siglip2_diffusion_backbone.SiglipVisionModel.from_pretrained",
            return_value=fake_model,
        ) as mock_from_pretrained:
            backbone = SigLIP2DiffusionBackbone(
                model_name="google/siglip2-base-patch16-256",
                camera_names=_CAMERA_NAMES,
                num_cameras=3,
                per_view_output_dim=2,
                freeze_backbone=True,
            )

        self.assertEqual(backbone.camera_names, _CAMERA_NAMES)
        self.assertEqual(backbone.num_cameras, 3)
        self.assertEqual(backbone.output_dim, 2)
        self.assertEqual(backbone.joint_output_dim, 6)
        self.assertIsNone(backbone.dataset_image_resize_shape)
        self.assertEqual(backbone.eval_image_resize_shape, (256, 256))
        mock_from_pretrained.assert_called_once_with("google/siglip2-base-patch16-256")
        self.assertTrue(all(not p.requires_grad for p in backbone.encoder.parameters()))
        self.assertFalse(backbone.encoder.training)

        with torch.no_grad():
            backbone.view_projector.weight.zero_()
            backbone.view_projector.bias.zero_()
            backbone.view_projector.weight[0, 0] = 1.0
            backbone.view_projector.weight[1, 1] = 1.0

        images = {
            "r_vis": torch.full((1, 2, 3, 256, 256), 0.25),
            "top": torch.full((1, 2, 3, 256, 256), 0.50),
            "front": torch.full((1, 2, 3, 256, 256), 0.75),
        }
        output = backbone(images)

        self.assertEqual(output.shape, (1, 2, 6))
        self.assertEqual(len(fake_model.forward_calls), 3)

        expected_per_camera = []
        for cam_name in _CAMERA_NAMES:
            img = images[cam_name].reshape(2, 3, 256, 256)
            normalized = (img - 0.5) / 0.5
            expected_per_camera.append(normalized.mean(dim=(2, 3))[:, :2])
        expected = torch.cat(expected_per_camera, dim=-1).view(1, 2, 6)
        self.assertTrue(torch.allclose(output, expected, atol=1e-6, rtol=1e-6))

        for call, cam_name in zip(fake_model.forward_calls, _CAMERA_NAMES):
            pixels = call["pixel_values"]
            self.assertEqual(tuple(pixels.shape), (2, 3, 256, 256))
            self.assertTrue(
                torch.allclose(
                    pixels,
                    (images[cam_name].reshape(2, 3, 256, 256) - 0.5) / 0.5,
                )
            )

    def test_forward_rejects_missing_required_camera(self):
        from roboimi.vla.models.backbones.siglip2_diffusion_backbone import SigLIP2DiffusionBackbone

        backbone = SigLIP2DiffusionBackbone(
            vision_model=_FakeSiglipVisionModel(hidden_size=4),
            camera_names=_CAMERA_NAMES,
            num_cameras=3,
        )

        with self.assertRaisesRegex(ValueError, "missing"):
            backbone({
                "r_vis": torch.rand(1, 1, 3, 256, 256),
                "top": torch.rand(1, 1, 3, 256, 256),
            })


if __name__ == "__main__":
    unittest.main()
@@ -56,3 +56,26 @@ class SimpleRobotDatasetImageLoadingTest(unittest.TestCase):

        self.assertEqual(len(resize_calls), 2)
        self.assertEqual(tuple(sample["observation.front"].shape), (2, 3, 8, 8))

    def test_getitem_skips_resize_when_image_resize_shape_is_none(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            dataset_dir = Path(tmpdir)
            self._write_episode(dataset_dir)
            dataset = SimpleRobotDataset(
                dataset_dir,
                obs_horizon=2,
                pred_horizon=3,
                camera_names=["front"],
                image_resize_shape=None,
            )

            fake_cv2 = types.SimpleNamespace(
                INTER_LINEAR=1,
                resize=mock.Mock(side_effect=AssertionError("resize should be skipped when image_resize_shape=None")),
            )

            with mock.patch.dict(sys.modules, {"cv2": fake_cv2}):
                sample = dataset[1]

            fake_cv2.resize.assert_not_called()
            self.assertEqual(tuple(sample["observation.front"].shape), (2, 3, 8, 8))
@@ -159,6 +159,92 @@ class TrainVLARolloutValidationTest(unittest.TestCase):

        self.assertGreater(cfg.train.num_workers, 8)
        self.assertEqual(cfg.train.rollout_val_freq_epochs, 50)

    def test_training_passes_backbone_image_resize_override_to_dataset_instantiation(self):
        cfg = OmegaConf.create(
            {
                'agent': {
                    'vision_backbone': {
                        'dataset_image_resize_shape': None,
                    },
                    'normalization_type': 'min_max',
                },
                'data': {
                    'dataset_dir': 'unused',
                    'camera_names': ['front'],
                },
                'train': {
                    'batch_size': 2,
                    'lr': 1e-4,
                    'max_steps': 0,
                    'device': 'cpu',
                    'disable_cudnn': False,
                    'num_workers': 0,
                    'val_split': 0.0,
                    'seed': 42,
                    'log_freq': 1,
                    'save_freq': 10,
                    'use_swanlab': False,
                    'rollout_val_freq_epochs': 0,
                    'rollout_validate_on_checkpoint': False,
                    'rollout_num_episodes': 1,
                    'warmup_steps': 1,
                    'scheduler_type': 'constant',
                    'min_lr': 1e-6,
                    'weight_decay': 1e-5,
                    'grad_clip': 1.0,
                    'pretrained_ckpt': None,
                },
                'eval': {
                    'ckpt_path': 'unused.pt',
                    'num_episodes': 1,
                    'headless': True,
                    'device': 'cpu',
                    'verbose_action': False,
                },
                'experiment': {},
            }
        )
        captured_dataset_kwargs = {}

        def fake_instantiate(config_node, **kwargs):
            if config_node is cfg.data:
                captured_dataset_kwargs.update(kwargs)
                return _FakeDataset()
            if config_node is cfg.agent:
                return _FakeAgent()
            raise AssertionError(f'unexpected instantiate config: {config_node!r}')

        def fake_dataloader(_dataset, *, shuffle, **_kwargs):
            del shuffle, _kwargs
            return _FakeLoader(
                {
                    'observation.front': torch.zeros(1, 3, 2, 2),
                    'observation.state': torch.zeros(1, 4),
                    'action': torch.zeros(1, 2),
                    'action_is_pad': torch.zeros(1, 1, dtype=torch.bool),
                },
                length=1,
            )

        with tempfile.TemporaryDirectory() as tempdir:
            previous_cwd = os.getcwd()
            try:
                os.chdir(tempdir)
                with mock.patch.object(train_vla, 'instantiate', side_effect=fake_instantiate), \
                        mock.patch.object(train_vla, 'DataLoader', side_effect=fake_dataloader), \
                        mock.patch.object(train_vla, 'build_training_optimizer', return_value=_FakeOptimizer(cfg.train.lr)), \
                        mock.patch.object(train_vla, 'get_lr_schedule_with_warmup', return_value=_FakeScheduler()), \
                        mock.patch.object(train_vla, 'tqdm', side_effect=lambda iterable, **kwargs: _FakeProgressBar(iterable)), \
                        mock.patch.object(train_vla, '_init_swanlab', return_value=None), \
                        mock.patch.object(train_vla, '_finish_swanlab', return_value=None), \
                        mock.patch.object(train_vla.torch, 'save', return_value=None):
                    train_vla._run_training(cfg)
            finally:
                os.chdir(previous_cwd)

        self.assertIn('image_resize_shape', captured_dataset_kwargs)
        self.assertIsNone(captured_dataset_kwargs['image_resize_shape'])

    def test_eval_main_delegates_to_plain_run_eval_helper(self):
        cfg = OmegaConf.create(
            {