Compare commits

10 commits on branch `feat-imf-a` (head `8d6060224a`)

| SHA1 |
|---|
| `ff7c9c1f2a` |
| `d51b3ecafa` |
| `2033169840` |
| `a78006808a` |
| `0586a6e6c7` |
| `48f0eb8dd0` |
| `3a17744dcf` |
| `0514f86c36` |
| `dffd92f82d` |
| `c2000b5533` |
@@ -0,0 +1,79 @@

# IMF Rollout Trajectory Images and Short-Horizon Training Implementation Plan

> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking.

**Goal:** Add training-time export of rollout front-view trajectory images plus SwanLab image logging, then start a new local IMF training run with `emb=384`, `layer=12`, `pred_horizon=8`, `num_action_steps=4`, `max_steps=50000`.

**Architecture:** Extend `eval_vla.py` so a rollout can emit one static front-view image per episode with a red EE trajectory overlay. Extend `train_vla.py` so rollout validation forces image export on, forces video export off, and uploads the per-episode images to SwanLab. Launch the requested run through explicit command-line overrides rather than branch-default config changes.

**Tech Stack:** Python, PyTorch, Hydra/OmegaConf, MuJoCo, OpenCV, SwanLab.

---
### Task 1: Add and validate rollout image tests

**Files:**

- Modify: `tests/test_eval_vla_rollout_artifacts.py`
- Modify: `tests/test_train_vla_swanlab_logging.py`
- Modify: `tests/test_train_vla_rollout_validation.py`

- [ ] Add/adjust eval tests so they assert per-episode trajectory image paths are produced without requiring video export.
- [ ] Add/adjust training tests so they assert training-time rollout validation forces `record_video=false`.
- [ ] Add/adjust training tests so they assert trajectory image paths flow from the eval summary into SwanLab media logging.
- [ ] Add/adjust training tests so they assert image media is logged, not only scalar reward metrics.
### Task 2: Implement per-episode front trajectory image export in eval

**Files:**

- Modify: `roboimi/demos/vla_scripts/eval_vla.py`
- Reuse/Read: `roboimi/utils/raw_action_trajectory_viewer.py`
- Modify: `roboimi/vla/conf/eval/eval.yaml`

- [ ] Add config plumbing for `save_trajectory_image` and `trajectory_image_camera_name`.
- [ ] Ensure the default training-time camera resolution path is pinned to `front`.
- [ ] Implement distinct per-episode image naming so 5 rollout episodes create 5 distinct PNGs.
- [ ] Reuse the existing red trajectory representation logic when composing the PNG.
- [ ] Ensure headless eval works under EGL even on machines with `DISPLAY` set.
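The per-episode naming and red-overlay contract above can be sketched as follows. This is a minimal illustration: the filename pattern and `overlay_red_trajectory` are assumptions, and the real implementation should reuse the capsule-marker logic in `raw_action_trajectory_viewer.py` rather than this simple pixel stamp.

```python
import numpy as np
from pathlib import Path

def trajectory_image_path(artifact_dir, episode_idx, camera_name="front"):
    # Distinct per-episode naming: 5 episodes -> 5 PNG paths, no overwriting.
    return Path(artifact_dir) / f"rollout_ep{episode_idx:03d}_{camera_name}_trajectory.png"

def overlay_red_trajectory(frame, pixel_points, radius=2):
    # frame: (H, W, 3) uint8 RGB; pixel_points: executed EE positions projected to pixels.
    out = frame.copy()
    h, w = out.shape[:2]
    for x, y in pixel_points:
        y0, y1 = max(0, y - radius), min(h, y + radius + 1)
        x0, x1 = max(0, x - radius), min(w, x + radius + 1)
        out[y0:y1, x0:x1] = (255, 0, 0)  # red marker, echoing the capsule-marker style
    return out
```

The key property to test is the naming contract: distinct episodes must never collide on one path.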
### Task 3: Implement SwanLab rollout image logging in training

**Files:**

- Modify: `roboimi/demos/vla_scripts/train_vla.py`
- Modify: `tests/test_train_vla_swanlab_logging.py`
- Modify: `tests/test_train_vla_rollout_validation.py`

- [ ] Make `run_rollout_validation()` force `record_video=false`.
- [ ] Make `run_rollout_validation()` force `save_trajectory_image=true` and `trajectory_image_camera_name=front`.
- [ ] Ensure rollout validation still uses 5 episodes per validation event for the requested run.
- [ ] Add a best-effort helper that converts per-episode image paths into SwanLab image media payloads.
- [ ] Keep image-upload failures non-fatal and warning-only.
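A minimal sketch of the best-effort helper. It assumes SwanLab's wandb-style `swanlab.Image` and `log` calls (verify against the installed version); the helper name and metric key are illustrative.

```python
def log_rollout_images(swanlab_run, image_paths, step, logger=None):
    """Best-effort upload of per-episode rollout images; failures are warning-only."""
    import swanlab  # imported lazily so training still runs without SwanLab installed

    media = []
    for path in image_paths:
        try:
            media.append(swanlab.Image(path))
        except Exception as exc:  # non-fatal by design
            if logger:
                logger.warning("skipping rollout image %s: %s", path, exc)
    if media:
        try:
            swanlab_run.log({"rollout/trajectory_images": media}, step=step)
        except Exception as exc:
            if logger:
                logger.warning("SwanLab image upload failed: %s", exc)
```

Wrapping both the per-image conversion and the final `log` call keeps any SwanLab failure from aborting the training loop.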
### Task 4: Verify action-chunk semantics for the new run

**Files:**

- Verify: `roboimi/vla/agent.py`
- Verify: `roboimi/vla/agent_imf.py`
- Test: `tests/test_imf_vla_agent.py`

- [ ] Confirm the existing queue logic still means “predict 8, execute first 4”.
- [ ] Do not change branch defaults unless strictly necessary; prefer launch-time overrides.
### Task 5: Verify and launch the requested local training run

**Files:**

- Use: `roboimi/demos/vla_scripts/train_vla.py`
- Use: `roboimi/demos/vla_scripts/eval_vla.py`

- [ ] Run the targeted verification suite.
- [ ] Run one real headless smoke eval and confirm a front trajectory PNG is produced while `video_mp4` stays null.
- [ ] Launch the new local training run with explicit overrides including:
  - `agent=resnet_imf_attnres`
  - `agent.head.n_emb=384`
  - `agent.head.n_layer=12`
  - `agent.pred_horizon=8`
  - `agent.num_action_steps=4`
  - `train.max_steps=50000`
  - `train.rollout_num_episodes=5`
  - `train.use_swanlab=true`
  - current local baseline dataset/camera/CUDA/batch/lr/num_workers/backbone settings
- [ ] Verify PID, GPU allocation, log tail, and SwanLab run URL.
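Assuming `train_vla.py` is a Hydra entry point that accepts these overrides (as the plan implies), the launch command can be assembled programmatically; the detach/log-redirection details are illustrative, not prescribed by the plan.

```python
import subprocess

overrides = [
    "agent=resnet_imf_attnres",
    "agent.head.n_emb=384",
    "agent.head.n_layer=12",
    "agent.pred_horizon=8",
    "agent.num_action_steps=4",
    "train.max_steps=50000",
    "train.rollout_num_episodes=5",
    "train.use_swanlab=true",
]
cmd = ["python", "roboimi/demos/vla_scripts/train_vla.py", *overrides]
# In practice, launch detached and capture the log tail, e.g.:
# subprocess.Popen(cmd, stdout=open("train.log", "w"), stderr=subprocess.STDOUT)
```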
@@ -0,0 +1,68 @@

# IMF Horizon Grid and AttnRes Ablation Implementation Plan

> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking.

**Goal:** Run a 6-run Phase-1 IMF horizon/action-step experiment grid across available GPUs, monitor progress and collect best rollout metrics, then use the best horizon setting for a Phase-2 visual-attnres ablation.

**Architecture:** Use the current IMF training code as-is for Phase-1 by sweeping explicit `(pred_horizon, num_action_steps)` overrides while keeping `emb=384`, `layer=12`, and `max_steps=50k` fixed. Maintain a local experiment suite directory with a manifest and machine-readable status snapshots so progress can be resumed and summarized across turns. After Phase-1 completes, compare the current head-only attnres setup against a variant that also adds attnres into the visual ResNet path.

**Tech Stack:** Python, Hydra/OmegaConf, PyTorch, SSH/Tailscale, JSON/CSV status files, SwanLab.

---
### Task 1: Prepare the experiment suite manifest and state tracking

**Files:**

- Create: `experiment_suites/2026-04-04-imf-horizon-grid/manifest.json`
- Create: `experiment_suites/2026-04-04-imf-horizon-grid/status.json`
- Create: `experiment_suites/2026-04-04-imf-horizon-grid/notes.md`

- [ ] Define the 6 legal Phase-1 combinations: `(8,8)`, `(16,8)`, `(16,16)`, `(32,8)`, `(32,16)`, `(32,32)`.
- [ ] Record for each run: name, host, GPU slot, command, log path, SwanLab run name, and completion criteria.
- [ ] Define the comparison metric as the maximum rollout average reward seen during training (`max avg_reward`), preferably read from the best-checkpoint metadata and cross-checked against logs.
- [ ] Keep `status.json` updated with per-run state (queued / running / finished / failed) plus the latest parsed progress.
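One way to seed the `status.json` records described above; the run-name scheme and exact field set are illustrative assumptions, not a fixed schema.

```python
import json

# The 6 legal (pred_horizon, num_action_steps) combinations from Task 1.
GRID = [(8, 8), (16, 8), (16, 16), (32, 8), (32, 16), (32, 32)]

def make_status(grid):
    # One machine-readable record per run; state moves through
    # queued -> running -> finished/failed as the suite progresses.
    return {
        f"imf_h{p}_a{a}": {
            "pred_horizon": p,
            "num_action_steps": a,
            "state": "queued",
            "gpu": None,
            "pid": None,
            "log_path": None,
            "swanlab_url": None,
            "latest_step": 0,
            "best_avg_reward": None,  # ranking metric: max avg_reward
        }
        for p, a in grid
    }

status = make_status(GRID)
```

Writing this dict with `json.dump` gives a snapshot that later turns can reload and update without rediscovery.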
### Task 2: Prepare the remote 8-GPU execution target

**Files:**

- Remote working directory under `/home/droid/`
- Reuse or create a synced code directory for this suite

- [ ] Verify the remote dataset path and environment path.
- [ ] Verify GPU availability and reserve 6 GPUs for the Phase-1 launches.
- [ ] Sync the required code to a dedicated remote suite directory.
- [ ] Record the exact remote paths back into the local suite manifest.
### Task 3: Launch the 6 Phase-1 experiments in parallel

**Files:**

- Reuse: `roboimi/demos/vla_scripts/train_vla.py`
- Modify only local suite tracking files unless a launch bug is discovered

- [ ] Launch 6 runs concurrently with fixed settings: IMF, `emb=384`, `layer=12`, `max_steps=50k`.
- [ ] Keep all other relevant training hyperparameters aligned to the current strong baseline unless a concrete blocker appears.
- [ ] Assign one GPU per run on the 8xL20 host.
- [ ] Capture PID, log path, and SwanLab URL for each run in `status.json`.
### Task 4: Monitor and summarize Phase-1 until all 6 finish

**Files:**

- Update: `experiment_suites/2026-04-04-imf-horizon-grid/status.json`
- Update: `experiment_suites/2026-04-04-imf-horizon-grid/notes.md`

- [ ] Periodically parse each run’s log/checkpoints to extract the latest step, latest rollout reward, and best rollout reward so far.
- [ ] Keep a resumable local summary so progress can be continued in later turns without rediscovery.
- [ ] After all 6 runs finish, rank them by `max avg_reward` and write a compact Phase-1 summary.
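A minimal log-parsing sketch for the monitoring step. The line format matched here (`step=... avg_reward=...`) is a guess and must be adjusted to the actual `train_vla.py` output before use.

```python
import re

# Hypothetical log-line patterns; align them with the real training logs.
STEP_RE = re.compile(r"step[=\s](\d+)")
REWARD_RE = re.compile(r"avg_reward[=\s]([-+]?\d*\.?\d+)")

def parse_progress(log_text):
    steps = [int(m.group(1)) for m in STEP_RE.finditer(log_text)]
    rewards = [float(m.group(1)) for m in REWARD_RE.finditer(log_text)]
    return {
        "latest_step": max(steps) if steps else 0,
        "latest_avg_reward": rewards[-1] if rewards else None,
        # Ranking metric from Task 1: the maximum avg_reward seen so far.
        "best_avg_reward": max(rewards) if rewards else None,
    }
```

Each monitoring pass can feed this result straight into the per-run record in `status.json`.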
### Task 5: Prepare the Phase-2 visual-attnres ablation

**Files:**

- Likely modify: vision backbone implementation and config files (to be confirmed after code inspection)
- Add/update targeted tests for the visual backbone path if code changes are needed

- [ ] Use the best Phase-1 `(pred_horizon, num_action_steps)` combination as the fixed rollout setting for Phase-2.
- [ ] Compare:
  1. current setup: attnres only in the IMF head
  2. ablation setup: attnres in both the IMF head and the visual encoder path
- [ ] Keep the rest of the training settings fixed.
- [ ] Launch and monitor the Phase-2 pair after the Phase-1 summary is complete.
@@ -0,0 +1,92 @@

# LEWM ViT Backbone Implementation Plan

> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking.

**Goal:** Replace the current ResNet visual encoder in roboimi VLA training with a frozen LEWM ViT visual backbone (encoder + projector) that consumes the three camera views jointly and outputs one 192-d CLS embedding per timestep, then launch two 50k runs on the 5880 machine.

**Architecture:** Add a new joint-multiview LEWM backbone that fuses `front/top/r_vis` into one LEWM-style image, reproduces LEWM preprocessing, loads frozen weights from the trained checkpoint, and exposes `joint_output_dim=192`. Add a minimal `VLAAgent` compatibility branch so conditions can be sized from the joint visual dim instead of `output_dim * num_cams`, while leaving the rest of the diffusion pipeline unchanged.

**Tech Stack:** PyTorch, transformers `ViTModel`, Hydra configs, existing roboimi VLA training/eval scripts, remote SSH/rsync to 100.73.14.65.

---
### Task 1: Add failing tests for the LEWM joint-vision backbone contract

**Files:**

- Create: `tests/test_lewm_vit_backbone.py`
- Modify: `tests/test_imf_vla_agent.py`

- [ ] **Step 1: Write the failing backbone shape/load test**
- [ ] **Step 2: Run `pytest tests/test_lewm_vit_backbone.py -q` and verify it fails**
- [ ] **Step 3: Extend `tests/test_imf_vla_agent.py` with a failing joint-output backbone case**
- [ ] **Step 4: Run `pytest tests/test_imf_vla_agent.py -q` and verify it fails**
### Task 2: Implement the LEWM joint-multiview frozen backbone

**Files:**

- Create: `roboimi/vla/models/backbones/lewm_vit_backbone.py`
- Modify: `roboimi/vla/models/backbones/__init__.py` only if exports are needed

- [ ] **Step 1: Create `LEWMViTBackbone` with public attrs `camera_names`, `num_cameras`, `joint_output_dim=192`**
- [ ] **Step 2: Reproduce LEWM preprocessing and joint multiview fusion**
- [ ] **Step 3: Load checkpoint weights from `model.encoder.*` and `model.projector.*`**
- [ ] **Step 4: Freeze encoder/projector and keep them in eval mode via a `train()` override**
- [ ] **Step 5: Run `pytest tests/test_lewm_vit_backbone.py -q` and verify green**
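Step 4's freeze-and-pin pattern is a generic PyTorch idiom, sketched here as a standalone mixin rather than the repo's actual class: parameters stop receiving gradients, and the `train()` override keeps normalization/dropout layers in eval mode even when the surrounding policy switches to training.

```python
import torch.nn as nn

class FrozenBackboneMixin(nn.Module):
    """Freeze parameters and pin submodules to eval mode even when .train() is called."""

    def freeze_(self):
        for p in self.parameters():
            p.requires_grad_(False)
        self.eval()

    def train(self, mode: bool = True):
        # Ignore the requested mode so BatchNorm statistics and Dropout masks
        # of the frozen trunk never change during policy training.
        super().train(False)
        return self
```

Without the override, a top-level `model.train()` would silently flip the frozen trunk back into training mode.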
### Task 3: Add minimal agent support for the joint visual dim

**Files:**

- Modify: `roboimi/vla/agent.py`
- Test: `tests/test_imf_vla_agent.py`

- [ ] **Step 1: Add a `joint_output_dim` branch in `VLAAgent.__init__` for `per_step_cond_dim` / `global_cond_dim`**
- [ ] **Step 2: Keep `_build_cond()` semantics unchanged except for matching the new dim contract**
- [ ] **Step 3: Run `pytest tests/test_imf_vla_agent.py -q` and verify green**
### Task 4: Add Hydra configs for LEWM backbone training

**Files:**

- Create: `roboimi/vla/conf/backbone/lewm_vit_diffusion.yaml`
- Create: `roboimi/vla/conf/agent/lewm_imf_attnres.yaml`

- [ ] **Step 1: Add a backbone config pointing to the new LEWM backbone**
- [ ] **Step 2: Add an `agent=lewm_imf_attnres` config with 3 cameras and `head.cond_dim=208`**
- [ ] **Step 3: Verify Hydra instantiation with a one-shot compose smoke**

### Task 5: Verify focused local tests

**Files:**

- Reuse the above

- [ ] **Step 1: Run `pytest tests/test_lewm_vit_backbone.py tests/test_imf_vla_agent.py tests/test_eval_vla_headless_import.py -q`**
- [ ] **Step 2: If needed, run one tiny local import/forward smoke**
### Task 6: Sync to 5880 and remote smoke with the real checkpoint

**Files:**

- Remote target: `/home/droid/roboimi_suite_20260404`

- [ ] **Step 1: Rsync modified source/config files to `100.73.14.65:/home/droid/roboimi_suite_20260404`**
- [ ] **Step 2: Run a 2-step smoke on GPU0 with `agent.head.n_emb=384`, `train.rollout_num_episodes=10`, and the real LEWM checkpoint**
- [ ] **Step 3: Run a 2-step smoke on GPU1 with `agent.head.n_emb=256` and the same checkpoint**

### Task 7: Launch two real 50k runs on the 5880 machine

**Files:**

- Remote logs under `/home/droid/roboimi_suite_20260404/experiment_suite_launch_logs/`

- [ ] **Step 1: Launch embed384/layer12 on GPU0**
- [ ] **Step 2: Launch embed256/layer12 on GPU1**
- [ ] **Step 3: Ensure both use `data.camera_names=[r_vis,top,front]`, `pred_horizon=16`, `num_action_steps=8`, `train.rollout_num_episodes=10`, `max_steps=50000`**
- [ ] **Step 4: Record run names, PIDs, log paths, and SwanLab URLs**

### Task 8: Update experiment tracking docs and commit

**Files:**

- Create: `experiment_suites/2026-04-05-lewm-vit-transfer/manifest.json`
- Create: `experiment_suites/2026-04-05-lewm-vit-transfer/status.json`
- Create: `experiment_suites/2026-04-05-lewm-vit-transfer/notes.md`

- [ ] **Step 1: Record the checkpoint path, frozen LEWM design, rollout=10, and both run configs**
- [ ] **Step 2: Record running status after launch**
- [ ] **Step 3: Commit implementation + docs with a focused message**
@@ -0,0 +1,64 @@

# Phase-2 Full-AttnRes Vision Implementation Plan

> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking.

**Goal:** Replace all ResNet residual units in the vision backbone with AttnRes-based image blocks while preserving the current IMF agent interfaces, then launch a Phase-2 experiment anchored on the best Phase-1 horizon setting.

**Architecture:** Keep the current multi-camera encoder shell and per-camera output contract, but introduce a new ResNet-like 2D AttnRes backbone that preserves stage-wise downsampling and the final SpatialSoftmax conditioning. Wire it into the existing `ResNetDiffusionBackbone` via an opt-in mode and keep the agent/head/data interfaces unchanged.

**Tech Stack:** PyTorch, Hydra/OmegaConf, existing IMF AttnRes transformer components, pytest.

---
### Task 1: Add failing tests for the new full-AttnRes visual backbone

**Files:**

- Create: `tests/test_attnres_resnet2d_backbone.py`
- Update: `tests/test_imf_vla_agent.py`

- [ ] **Step 1: Write a failing backbone shape test**
- [ ] **Step 2: Run it to confirm the new backbone/config does not exist yet**
- [ ] **Step 3: Add a failing IMF agent wiring test for unchanged `cond_dim=208`**
- [ ] **Step 4: Run the targeted tests and capture the failure**
### Task 2: Implement a ResNet-like 2D AttnRes backbone

**Files:**

- Create: `roboimi/vla/models/backbones/attnres_resnet2d.py`
- Modify: `roboimi/vla/models/backbones/resnet_diffusion.py`

- [ ] **Step 1: Add minimal 2D tokenization helpers and positional encoding / bias handling**
- [ ] **Step 2: Implement `AttnResImageBlock2D` for feature maps**
- [ ] **Step 3: Implement `AttnResResNetLikeBackbone2D` with stage-wise downsampling**
- [ ] **Step 4: Wire `_SingleRgbEncoder` to choose between the original ResNet trunk and the new full-AttnRes trunk**
- [ ] **Step 5: Run the new backbone tests**
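Steps 1 and 2 amount to flattening the feature map into H*W tokens, attending, and reshaping back. A minimal sketch under that assumption, with positional encoding omitted; the real block should reuse the repo's existing IMF AttnRes components rather than a bare `MultiheadAttention`:

```python
import torch
import torch.nn as nn

class AttnResImageBlock2D(nn.Module):
    """Self-attention over flattened (H*W) tokens with a residual path ("AttnRes")."""

    def __init__(self, channels: int, num_heads: int = 4):
        super().__init__()
        self.norm = nn.LayerNorm(channels)
        self.attn = nn.MultiheadAttention(channels, num_heads, batch_first=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, c, h, w = x.shape
        tokens = x.flatten(2).transpose(1, 2)          # (B, H*W, C)
        y = self.norm(tokens)
        y, _ = self.attn(y, y, y, need_weights=False)
        tokens = tokens + y                            # residual connection
        return tokens.transpose(1, 2).reshape(b, c, h, w)
```

Because the block maps `(B, C, H, W)` to the same shape, it can stand in for a residual unit inside a stage-wise downsampling trunk.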
### Task 3: Expose config switches and agent wiring

**Files:**

- Modify: `roboimi/vla/conf/backbone/resnet_diffusion.yaml`
- Modify: `roboimi/vla/conf/agent/resnet_imf_attnres.yaml`

- [ ] **Step 1: Add a backbone mode/config flag for the full-AttnRes vision trunk**
- [ ] **Step 2: Add defaults for attnres image depth/heads/etc. if needed**
- [ ] **Step 3: Add a Phase-2 launch override path that enables the new visual trunk**
- [ ] **Step 4: Run the agent wiring tests again**

### Task 4: Smoke-verify the training path

**Files:**

- Reuse existing training scripts and configs

- [ ] **Step 1: Run a short CPU or tiny-step smoke instantiation / `compute_loss` test**
- [ ] **Step 2: If needed, run a very short training smoke launch**
- [ ] **Step 3: Verify no cond-dim or rollout-loading regressions**

### Task 5: Launch the Phase-2 experiment

**Files:**

- Update experiment tracking under `experiment_suites/`

- [ ] **Step 1: Use the Phase-1 best setting (`pred_horizon=16`, `num_action_steps=8`)**
- [ ] **Step 2: Launch a baseline reference or reuse the existing result**
- [ ] **Step 3: Launch the full-AttnRes vision experiment**
- [ ] **Step 4: Track rollout metrics and compare `max avg_reward`**
docs/superpowers/plans/2026-04-06-resnet-multitoken-imf.md (new file, 81 lines)
@@ -0,0 +1,81 @@
# ResNet Multitoken IMF Implementation Plan

> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking.

**Goal:** Implement a standard-ResNet-18 multiview IMF variant that emits three condition tokens per obs step, then launch four L20 experiments covering `n_emb in {256,384}` and `n_layer in {12,16}`.

**Architecture:** The ResNet backbone will optionally return one token per camera instead of concatenating all cameras into one token. `VLAAgent` will pair each camera token with the current state, project each pair into a condition token, flatten the per-step camera tokens into one cond sequence, and feed that sequence into the existing IMF/AttnRes head.

**Tech Stack:** PyTorch, torchvision ResNet-18, Hydra, pytest, SwanLab, SSH/Tailscale.

---
### Task 1: Add failing tests for multi-token conditioning

**Files:**

- Modify: `tests/test_imf_vla_agent.py`
- Modify: `tests/test_resnet_transformer_agent_wiring.py`

- [ ] **Step 1: Add a direct agent test**
  - Stub a vision backbone returning `(B,T,3,D)` and assert `_build_cond()` yields `(B, T*3, D_cond)`.
  - Assert state is paired with each camera token, not concatenated across cameras first.
- [ ] **Step 2: Add a Hydra wiring test**
  - Instantiate a new `agent=resnet_imf_attnres_multitoken` config with small dims.
  - Assert `condition_tokens_per_step == 3`, `condition_sequence_length == obs_horizon * 3`, and that the head's `n_obs_steps` receives that sequence length.
- [ ] **Step 3: Run the focused tests and verify RED**
  - `python -m pytest tests/test_imf_vla_agent.py tests/test_resnet_transformer_agent_wiring.py -q`
### Task 2: Implement the multi-token ResNet conditioning path

**Files:**

- Modify: `roboimi/vla/models/backbones/resnet_diffusion.py`
- Modify: `roboimi/vla/agent.py`
- Create: `roboimi/vla/conf/agent/resnet_imf_attnres_multitoken.yaml`

- [ ] **Step 1: Extend the ResNet backbone**
  - Add an opt-in flag to return `(B,T,num_cams,D)` camera tokens instead of one concatenated `(B,T,num_cams*D)` token.
  - Keep the standard ResNet-18 vision mode; do not switch to AttnRes vision.
- [ ] **Step 2: Extend `VLAAgent` condition building**
  - Support visual features with rank 4 `(B,T,K,D)`.
  - Broadcast state to `(B,T,K,D_state)`, concatenate per camera, apply the projector per token, then flatten to `(B,T*K,D_cond)`.
  - Track `condition_tokens_per_step` and `condition_sequence_length`.
- [ ] **Step 3: Update transformer-head instantiation**
  - Pass `n_obs_steps=condition_sequence_length` when building transformer heads.
- [ ] **Step 4: Add the Hydra config**
  - The new agent config uses:
    - a separate ResNet-18 per camera
    - the standard residual vision trunk (`vision_backbone_mode=resnet`)
    - a condition projector output dim tied to `${agent.head.n_emb}`
    - rollout episodes `10`, `pred_horizon=16`, `num_action_steps=8`
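Step 2's broadcast-pair-project-flatten sequence can be sketched as a standalone function; the name `build_multitoken_cond` is illustrative, and the real logic lives inside `VLAAgent._build_cond()`.

```python
import torch
import torch.nn as nn

def build_multitoken_cond(visual, state, projector):
    """visual: (B, T, K, D_vis) per-camera tokens; state: (B, T, D_state).

    Pairs each camera token with the state, projects each pair into a
    condition token, and flattens per-step tokens to (B, T*K, D_cond).
    """
    b, t, k, _ = visual.shape
    # Broadcast the state so every camera token sees the same robot state.
    state_b = state.unsqueeze(2).expand(b, t, k, state.shape[-1])
    pairs = torch.cat([visual, state_b], dim=-1)   # (B, T, K, D_vis + D_state)
    cond = projector(pairs)                        # projector applied per token
    return cond.flatten(1, 2)                      # (B, T*K, D_cond)
```

With `K=3` cameras this yields exactly `condition_sequence_length = obs_horizon * 3`, as asserted by the Task 1 tests.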
### Task 3: Verify locally

**Files:**

- Modify only if verification reveals issues

- [ ] **Step 1: Run the focused tests and make them pass**
  - `python -m pytest tests/test_imf_vla_agent.py tests/test_resnet_transformer_agent_wiring.py -q`
- [ ] **Step 2: Run the regression subset**
  - `python -m pytest tests/test_eval_vla_headless.py tests/test_train_vla_rollout_validation.py tests/test_simple_robot_dataset_image_loading.py -q`
- [ ] **Step 3: Run a local smoke instantiation**
  - Instantiate the new Hydra config and verify the cond shape / sequence length.
### Task 4: Launch 4 L20 experiments

**Files:**

- Remote repo copy under `/home/droid/roboimi_suite_20260404`

- [ ] **Step 1: Sync code to `100.119.99.14`**
- [ ] **Step 2: Smoke the new config on the remote host**
- [ ] **Step 3: Launch the runs**
  - `(n_emb=256, n_layer=12)`
  - `(n_emb=256, n_layer=16)`
  - `(n_emb=384, n_layer=12)`
  - `(n_emb=384, n_layer=16)`
- [ ] **Step 4: Keep fixed across runs**
  - rollout episodes `10`
  - `pred_horizon=16`
  - `num_action_steps=8`
  - the standard ResNet-18 vision trunk
  - three separate camera weights
- [ ] **Step 5: Record PIDs, GPUs, log paths, and SwanLab URLs**
docs/superpowers/plans/2026-04-06-siglip2-multiview-vla.md (new file, 78 lines)
@@ -0,0 +1,78 @@
# SigLIP2 Multiview VLA Implementation Plan

> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking.

**Goal:** Integrate a frozen shared SigLIP2 multiview encoder into the IMF/AttnRes policy, preserve raw-256 image handling, and launch two 50k-step experiments on the 5880 host with per-view projection dims 96 and 192.

**Architecture:** A new backbone will independently encode each camera view with SigLIP2 and project each 768-d pooled feature to a configurable per-view dimension. `VLAAgent` will concatenate the visual features with the robot state, then optionally project the combined per-step condition to the head's required 384-d interface before diffusion training/inference.

**Tech Stack:** PyTorch, transformers SigLIP2, Hydra, pytest, SSH/Tailscale, SwanLab.

---
### Task 1: Add failing tests for the SigLIP2 backbone and projected conditioning

**Files:**

- Create: `tests/test_siglip2_diffusion_backbone.py`
- Modify: `tests/test_imf_vla_agent.py`

- [ ] **Step 1: Write failing backbone tests**
  - Instantiate the new backbone with a stub SigLIP2 vision model.
  - Assert the raw dataset resize is `None`, the eval resize is `(256, 256)`, and the output shape is `(B, T, 3 * per_view_output_dim)`.
  - Assert the three views are encoded independently and projected.
- [ ] **Step 2: Run the focused tests and verify RED**
  - Run `pytest tests/test_siglip2_diffusion_backbone.py tests/test_imf_vla_agent.py -q`.
  - Expect failure because the backbone/config/projector do not exist yet.
- [ ] **Step 3: Extend the agent wiring tests**
  - Add a Hydra/instantiate test for a new SigLIP2 IMF config.
  - Assert the raw condition dim is `3 * per_view_output_dim + obs_dim`, the projected cond dim is `384`, and the head's `cond_dim == 384`.
### Task 2: Implement the SigLIP2 backbone and optional condition projector

**Files:**

- Create: `roboimi/vla/models/backbones/siglip2_diffusion_backbone.py`
- Create: `roboimi/vla/conf/backbone/siglip2_diffusion.yaml`
- Create: `roboimi/vla/conf/agent/siglip2_imf_attnres.yaml`
- Create: `roboimi/vla/conf/modules/linear_condition_projector.yaml`
- Modify: `roboimi/vla/models/backbones/__init__.py`
- Modify: `roboimi/vla/agent.py`

- [ ] **Step 1: Implement the backbone**
  - Load `SiglipVisionModel.from_pretrained("google/siglip2-base-patch16-256")`.
  - Normalize `[0,1]` pixels with mean/std `0.5` and encode each view independently.
  - Project each 768-d pooled feature to the configurable per-view dim and concatenate across cameras.
- [ ] **Step 2: Implement the optional condition projector**
  - Allow `VLAAgent` to accept a `cond_projector`.
  - Track `raw_per_step_cond_dim` and the projected `per_step_cond_dim` / `global_cond_dim`.
  - Apply the projector in `_build_cond()` after the visual+state concatenation.
- [ ] **Step 3: Add the Hydra configs**
  - The new agent config should default to `n_emb=384`, `n_layer=12`, `pred_horizon=16`, `num_action_steps=8`, `head.cond_dim=384`.
  - The backbone config should set `dataset_image_resize_shape: null` and `eval_image_resize_shape: [256, 256]`.
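The encode-independently, project, concatenate structure of Step 1 can be sketched generically. The class below is an illustration, not the planned `siglip2_diffusion_backbone.py`: the encoder is passed in as any callable returning pooled `(N, 768)` features, so a frozen `SiglipVisionModel` (or a test stub) can be plugged in.

```python
import torch
import torch.nn as nn

class MultiviewPooledBackbone(nn.Module):
    """Encode each camera view independently with a shared encoder, project each
    pooled feature to per_view_output_dim, and concatenate across views."""

    def __init__(self, encoder, encoder_dim=768, per_view_output_dim=96):
        super().__init__()
        self.encoder = encoder  # e.g. a frozen SigLIP2 vision model (assumed interface)
        self.proj = nn.Linear(encoder_dim, per_view_output_dim)

    def forward(self, views):  # views: (B, T, V, C, H, W), pixels in [0, 1]
        b, t, v, c, h, w = views.shape
        x = (views - 0.5) / 0.5                 # SigLIP-style mean/std 0.5 normalization
        x = x.reshape(b * t * v, c, h, w)       # each view encoded independently
        pooled = self.encoder(x)                # (B*T*V, encoder_dim) pooled features
        feats = self.proj(pooled).reshape(b, t, v, -1)
        return feats.flatten(2)                 # (B, T, V * per_view_output_dim)
```

With 3 views and `per_view_output_dim=96`, the output dim is 288 per timestep, matching the `(B, T, 3 * per_view_output_dim)` contract the Task 1 tests assert.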
### Task 3: Verify locally and prepare remote execution

**Files:**

- Modify as needed only if tests/smoke reveal issues

- [ ] **Step 1: Run the focused tests and make them pass**
  - `pytest tests/test_siglip2_diffusion_backbone.py tests/test_imf_vla_agent.py tests/test_eval_vla_headless.py tests/test_train_vla_rollout_validation.py tests/test_simple_robot_dataset_image_loading.py -q`
- [ ] **Step 2: Run a local smoke instantiation**
  - Instantiate the new Hydra config with stubbed optional modules or offline-safe monkeypatching.
- [ ] **Step 3: Review the diffs for unintended LEWM/raw256 regressions**
### Task 4: Sync to 5880 and launch experiments

**Files:**

- Remote repo copy under `/home/droid/roboimi_suite_20260404`

- [ ] **Step 1: Stop superseded remote jobs**
- [ ] **Step 2: Sync the updated code to the remote host**
  - Prefer `rsync` or `git push/pull` without overwriting unrelated files.
- [ ] **Step 3: Run a remote smoke test**
  - Confirm the SigLIP2 model download/import works in `/home/droid/miniforge3/envs/roboimi/bin/python`.
  - Confirm the headless rollout path still uses the `256x256` eval resize.
- [ ] **Step 4: Launch experiment A**
  - `per_view_output_dim=96`, `embed=384`, `layer=12`, `pred=16`, `exec=8`, `steps=50000`.
- [ ] **Step 5: Launch experiment B**
  - `per_view_output_dim=192`, same other hyperparameters.
- [ ] **Step 6: Record PIDs, GPUs, log paths, and SwanLab run URLs.**
@@ -0,0 +1,75 @@

# IMF Rollout Trajectory Images + Short-Horizon Training Design

## Background

The current RoboIMI IMF training flow can perform rollout validation and log scalar reward metrics to SwanLab, but it does not yet emit the qualitative rollout artifacts now required for analysis. The user wants training-time rollout validation to save front-view trajectory images with the model-generated trajectory drawn in red, upload those images to SwanLab, and then start a new local short-horizon IMF training run.
## Goals

1. During training-time rollout validation, save one **front-camera** trajectory image per rollout episode.
2. The image must show the rollout EE trajectory in red.
3. Reuse the existing repository trajectory visualization logic as much as practical, especially the existing red capsule-marker trajectory representation.
4. Save 5 rollout images locally for each validation event and upload the same 5 images to SwanLab.
5. Do **not** record rollout videos for this training-time validation flow.
6. Start a new local IMF-AttnRes training run with:
   - `agent.head.n_emb=384`
   - `agent.head.n_layer=12`
   - `agent.pred_horizon=8`
   - `agent.num_action_steps=4`
   - `train.max_steps=50000`
   - `train.rollout_num_episodes=5`
   - `train.use_swanlab=true`
## Non-Goals
|
||||||
|
- No IMF architecture or loss-function change.
|
||||||
|
- No dataset schema change.
|
||||||
|
- No rollout video generation for the new training flow.
|
||||||
|
- No interactive viewer requirement.
|
||||||
|
|
||||||
|
## Existing Relevant Code
|
||||||
|
- `roboimi/demos/vla_scripts/eval_vla.py`
|
||||||
|
- already supports rollout summaries, optional trajectory export, and optional video export.
|
||||||
|
- `roboimi/utils/raw_action_trajectory_viewer.py`
|
||||||
|
- already contains the red trajectory capsule-marker construction logic.
|
||||||
|
- `roboimi/demos/vla_scripts/train_vla.py`
|
||||||
|
- already performs periodic rollout validation and scalar SwanLab logging.
|
||||||
|
- `roboimi/vla/agent.py`
|
||||||
|
- already implements “predict pred_horizon, execute first num_action_steps” queue semantics.
|
||||||
|
|
||||||
|
## Design Decisions
|
||||||
|
|
||||||
|
### 1. Artifact contract
|
||||||
|
Each rollout episode will emit one distinct PNG file under the eval artifact directory. The file naming/path contract must be per-episode, not shared, so a 5-episode validation event yields 5 stable image paths without overwriting.
|
||||||
|
|
||||||
|
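The per-episode contract can be sketched as a tiny path helper. The naming scheme here is hypothetical, not the repo's actual convention; the point is only that the path is a pure function of the episode index:

```python
from pathlib import Path

def episode_image_path(artifact_dir: str, episode_idx: int) -> Path:
    # Hypothetical naming scheme: one stable, distinct PNG per episode,
    # so a 5-episode validation event can never overwrite its own images.
    return Path(artifact_dir) / f"rollout_ep{episode_idx:03d}_front_traj.png"
```

Because each episode index maps to a distinct filename, summaries can carry these paths around safely before the images even exist.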
### 2. Trajectory definition

The red trajectory corresponds to the **actually executed model action sequence** over the rollout loop: the raw EE actions returned and consumed step-by-step by the policy loop. For the requested short-horizon run, this means the visualization reflects repeated execution of the first 4 actions from each predicted 8-action chunk, not every discarded future prediction from replanning.

### 3. Camera choice

The training-time image export path is explicitly pinned to the repo's concrete `front` camera key. It must not silently use `camera_names[0]` if that is not `front`.

### 4. Rendering path

`eval_vla.py` will add a lightweight headless image-export path that:

- renders the `front` camera frame,
- overlays the trajectory using the existing red trajectory representation,
- saves a static PNG per episode.

The implementation may reuse the existing marker-construction logic directly and add a minimal helper for final image composition/export.
### 5. Training-time behavior

`train_vla.py` rollout validation must explicitly:

- request/save trajectory images,
- keep `record_video=false`,
- return the 5 per-episode image paths in the rollout summary payload,
- upload those 5 images to SwanLab,
- keep image-upload failures non-fatal.
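A best-effort upload helper in this spirit might look as follows. The `swanlab.Image`/`swanlab.log` calls mirror the wandb-style API SwanLab exposes; treat the exact signatures and the metric key as assumptions:

```python
def log_rollout_images(image_paths, step=None):
    # Best-effort SwanLab media upload: any failure (missing package,
    # missing file, API mismatch) is swallowed so training never dies
    # because of a logging problem.
    try:
        import swanlab
        media = [swanlab.Image(p) for p in image_paths]
        swanlab.log({"rollout/front_trajectory": media}, step=step)
    except Exception as exc:  # non-fatal by design
        print(f"[warn] SwanLab image upload skipped: {exc}")
```

Isolating the media call behind one `try/except` keeps the rest of the validation loop oblivious to logging backends.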
## Expected User-Visible Outcome

For each scheduled validation event in the new training run:

- 5 rollout episodes execute,
- 5 front-view PNG trajectory images are saved locally,
- the same 5 images are uploaded to SwanLab,
- scalar reward metrics continue to be logged,
- no rollout videos are generated.

## Risks and Mitigations

- **Headless rendering conflicts from desktop env vars**: force headless eval onto EGL when `headless=true`.
- **Image overwrite risk**: use explicit per-episode artifact paths.
- **SwanLab media API mismatch**: isolate media logging in a small best-effort helper.
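The EGL mitigation is a few lines. `MUJOCO_GL` is MuJoCo's standard backend switch and must be set before the first `import mujoco` in the process; `PYOPENGL_PLATFORM` is included as a common companion setting for EGL stacks, not something this repo is confirmed to read:

```python
import os

def force_headless_egl() -> None:
    # Overwriting (rather than setdefault) is deliberate: inherited desktop
    # values like MUJOCO_GL=glfw are exactly what breaks offscreen rendering.
    os.environ["MUJOCO_GL"] = "egl"
    os.environ.setdefault("PYOPENGL_PLATFORM", "egl")
```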
138
docs/superpowers/specs/2026-04-05-lewm-vit-backbone-design.md
Normal file
@@ -0,0 +1,138 @@
# LEWM ViT Backbone Replacement Design

## Goal

Replace the ResNet visual encoder in the current roboimi VLA policy with the frozen ViT visual encoder (encoder + projector) from the LEWM checkpoint, using only the final CLS token's 192-d embedding as the visual feature.

## User constraints

- Use the trained checkpoint confirmed in `/home/droid/下载/lewm_sim_transfer_checkpoint_usage.md`
- Use only the visual encoding part: `encoder + projector`
- Weights stay frozen
- Keep the overall "visual feature + state concatenation, fed into the diffusion transformer" processing scheme
- Inputs use three views: `[r_vis, top, front]`
- Launch two trainings on the 5880 machine: `embed=384/layer=12` and `embed=256/layer=12`
- `pred_horizon=16`
- `num_action_steps=8`
- `50k` steps per training
- rollout validation uses `10` episodes per event, not the previous `5`

## Trusted existing facts

1. LEWM checkpoint path:
   - `/home/droid/le-wm/lewm-sim-transfer/pa1w85md8jop6bvol8oxp/checkpoints/epoch=99-step=47800.ckpt`
2. state_dict prefixes to load:
   - `model.encoder.*`
   - `model.projector.*`
3. LEWM ViT configuration:
   - encoder scale: `tiny`
   - hidden size: `192`
   - layers: `12`
   - attention heads: `3`
   - patch size: `14`
   - projector: `MLP(192 -> 2048 -> 192)` with `BatchNorm1d + GELU`
4. During LEWM training, the three views are first stitched into a single image and fed to one ViT encoder; the overall visual embedding is **192-d**.

## Key design decision

### Chosen design: fuse 3 cameras into one LEWM-style image, output one 192-d visual vector per timestep

Rather than treating the LEWM ViT as a "per-camera 192-d encoder", follow the original LEWM training recipe:

- take the three-view image dict `{r_vis, top, front}`
- stitch the views into one fused image in a fixed order
- run a single frozen ViT + projector
- obtain one **192-d total visual feature**

### Why this is the right replacement

The **total visual feature dimension** the current ResNet backbone hands to the policy head is:

- `64` per camera
- `192` in total across three cameras

The CLS/projector embedding produced by the LEWM checkpoint is likewise:

- `192` in total

The most natural drop-in replacement for the current ResNet visual encoder is therefore:

- let the LEWM backbone directly produce one 192-d total visual vector
- after concatenating the `16-d` state, still obtain a `208-d` condition vector
- leave the diffusion head's overall interface and semantics unchanged
## Interface compatibility plan

The existing `VLAAgent` assumes the backbone exposes:

- `camera_names`
- `num_cameras`
- `output_dim` (semantically the "per-camera feature dimension")
- `forward(images_dict) -> (B, T, total_visual_dim)`

For minimal-change compatibility with the existing agent:

- the new LEWM backbone's `forward()` returns `(B, T, 192)`
- `camera_names = ('r_vis', 'top', 'front')`
- `num_cameras = 3`
- `output_dim = 64`

`VLAAgent` then still computes:

- `per_step_cond_dim = output_dim * num_cams + obs_dim = 64*3 + 16 = 208`

which matches the actual `forward()` output of `192 + 16 = 208`.

> In other words: in this backbone, `output_dim` is kept as a per-camera placeholder dimension equivalent to the old ResNet total feature, not the real projector output dimension. It is a compatibility shim that avoids changing the agent's main logic.
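The shim's bookkeeping identity can be checked in one line (the function name is descriptive, not an actual agent attribute):

```python
def per_step_cond_dim(output_dim: int = 64, num_cams: int = 3, obs_dim: int = 16) -> int:
    # What VLAAgent derives from the shim metadata; it must equal the real
    # 192-d fused visual feature plus the 16-d state.
    return output_dim * num_cams + obs_dim
```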
## Image preprocessing design

The current roboimi dataset already reads each camera image as:

- `(C, 224, 224)`
- value range `[0, 1]`

The new LEWM backbone will:

1. take `r_vis`, `top`, `front` in order
2. concatenate along the width axis to obtain the fused image:
   - `(C, 224, 672)`
3. apply the LEWM-consistent ImageNet normalization:
   - mean `[0.485, 0.456, 0.406]`
   - std `[0.229, 0.224, 0.225]`
4. call `ViTModel(..., interpolate_pos_encoding=True)`
5. take `last_hidden_state[:, 0]`
6. feed the frozen projector to obtain `(B*T, 192)`
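Steps 1-3 above can be sketched in PyTorch; the ViT and projector steps are omitted, and the function name is hypothetical:

```python
import torch

# ImageNet statistics used by LEWM, broadcast over (C, H, W).
_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
_STD = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

def fuse_and_normalize(images: dict) -> torch.Tensor:
    # images: camera name -> (B, 3, 224, 224) tensor in [0, 1].
    # Width-wise concatenation in the fixed LEWM camera order yields the
    # fused (B, 3, 224, 672) ViT input.
    fused = torch.cat([images[k] for k in ("r_vis", "top", "front")], dim=-1)
    return (fused - _MEAN) / _STD
```

Keeping the camera order hard-coded here is exactly what risk 1 below is about: a transposed (672x224) fusion would silently shift the input distribution.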
## Files to create / modify

### New files

- `roboimi/vla/models/backbones/lewm_vit_backbone.py`
- `roboimi/vla/conf/backbone/lewm_vit_diffusion.yaml`
- `roboimi/vla/conf/agent/lewm_imf_attnres.yaml`
- `tests/test_lewm_vit_backbone.py`

### Modified files

- `roboimi/vla/models/backbones/__init__` (if exports are needed)
- `tests/test_imf_vla_agent.py` (add integration cases for the new backbone)
- `roboimi/demos/vla_scripts/train_vla.py` (only if rollout defaults/logging need adjusting; if command-line overrides suffice, avoid touching the main logic)
- training/experiment suite docs (add records for this LEWM ViT training round)

## Testing plan

1. **Unit test: load + forward**
   - use a synthetic checkpoint to verify the new backbone correctly loads `model.encoder.*` and `model.projector.*`
   - input: 3 cameras, `(B,T,C,224,224)`
   - output: `(B,T,192)`
2. **Agent integration test**
   - backbone.output_dim=64, num_cameras=3
   - agent `_build_cond()` output has last dimension `208`
3. **Remote smoke test on 5880**
   - use the real checkpoint
   - `max_steps=2`
   - one smoke run per experiment
4. **Full run**
   - GPU0: `embed=384, layer=12`
   - GPU1: `embed=256, layer=12`
   - `rollout_num_episodes=10`
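The synthetic-checkpoint loading in step 1 hinges on prefix stripping, which can be sketched without any model code (the helper name is hypothetical):

```python
def extract_prefixed(state_dict: dict, prefix: str) -> dict:
    # Select the LEWM submodule weights (e.g. "model.encoder.") and strip
    # the prefix so load_state_dict() sees the submodule's own key names.
    return {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}
```

A synthetic test only needs a fake dict with `model.encoder.*`, `model.projector.*`, and distractor keys to verify the selection.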
## Training launch contract

- host: `100.73.14.65`
- code dir: `/home/droid/roboimi_suite_20260404`
- python: `/home/droid/miniforge3/envs/roboimi/bin/python`
- dataset: `/home/droid/sim_dataset/sim_transfer`
- cameras: `[r_vis, top, front]`
- agent: new `lewm_imf_attnres`
- max_steps: `50000`
- rollout every `5` epochs
- rollout episodes: `10`

## Risks

1. If the fused-image preprocessing orientation differs from LEWM training (224x672 vs 672x224), it introduces a distribution shift.
2. The roboimi env must have `transformers` installed; `environment.yml` shows it is present locally, but the remote training env needs a smoke check.
3. Because this is a frozen ViT + projector, the projector's BatchNorm statistics would drift if it stayed in train mode, so the whole module must be put in `eval()` and frozen.

## Recommended first implementation path

- First implement a standalone `LEWMViTBackbone` class without changing the existing `ResNetDiffusionBackbone` main logic.
- Then wire it in through new hydra backbone/agent configs.
- Prioritize minimal intrusion, a runnable smoke test, and remote trainability.
@@ -0,0 +1,81 @@
# Phase-2 Full-AttnRes Vision Design

## Goal

In the current roboimi IMF policy, replace every residual unit the vision backbone previously obtained from ResNet BasicBlock/Bottleneck with AttnRes-style units, while keeping the existing agent / cond / rollout / training-script interfaces unchanged as far as possible.

## User requirement interpretation

The strictest reading is adopted here:

- not "append an AttnRes module after the ResNet"
- not "mix AttnRes into only a few stages"
- but: wherever the visual trunk originally relied on a ResNet residual block, switch uniformly to a block driven by the AttnRes residual operator
- the final output still matches the per-camera feature interface of the existing `ResNetDiffusionBackbone`, so `SpatialSoftmax -> Linear -> ReLU`, multi-camera concatenation, state concat, and the IMF head condition input can all be reused

## Recommended design

### Option A (recommended)

Keep ResNet's macro stem/stage structure and channel/stride plan, but replace the BasicBlock/Bottleneck inside each stage with a new `AttnResImageBlock2D`:

- the input is still a `(B, C, H, W)` feature map
- inside the block, flatten the spatial dims into a token sequence `(B, H*W, C)`
- apply 2D positional encoding / learnable positional bias + AttnRes self-attention + AttnRes FFN for the block transform
- reshape back to `(B, C, H, W)`
- downsampling between stages is still done by an explicit stride/downsample path

Pros:

- closest to the requirement that "all residuals in the ResNet are replaced by AttnRes"
- preserves the existing visual output interface and cond_dim, so no agent/head/data-pipeline changes
- still reuses the existing multi-camera encoder framework

Cons:

- requires writing a new 2D AttnRes image block instead of directly reusing the 1D IMF head block
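The `(B, C, H, W)` to token-sequence plumbing of Option A can be sketched as follows. This is a sketch only: it substitutes standard multi-head attention for the repo's AttnRes operator and omits the 2D positional encoding for brevity:

```python
import torch
from torch import nn

class AttnResImageBlock2D(nn.Module):
    """Sketch: flatten the spatial map to tokens, apply residual
    self-attention + residual FFN, reshape back to a feature map."""

    def __init__(self, channels: int, num_heads: int = 4):
        super().__init__()
        self.norm1 = nn.LayerNorm(channels)
        self.attn = nn.MultiheadAttention(channels, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(channels)
        self.ffn = nn.Sequential(
            nn.Linear(channels, 4 * channels), nn.GELU(),
            nn.Linear(4 * channels, channels),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, c, h, w = x.shape
        t = x.flatten(2).transpose(1, 2)          # (B, H*W, C)
        q = self.norm1(t)
        a, _ = self.attn(q, q, q)                 # residual self-attention
        t = t + a
        t = t + self.ffn(self.norm2(t))           # residual FFN
        return t.transpose(1, 2).reshape(b, c, h, w)
```

Because input and output shapes match, the block drops into a ResNet-style stage without touching the stride/downsample path between stages.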
### Option B

Remove the ResNet stages entirely and switch to patchify + a ViT/AttnRes image transformer, followed by SpatialSoftmax/MLP.

Pros: conceptually more uniform.

Cons: this is no longer "replacing the residuals inside the ResNet" but swapping the backbone outright, which does not fully match the user's request.

### Option C

Keep the existing ResNet blocks and only add AttnRes mixing outside the blocks.

Not recommended, because it fails "all residuals replaced by AttnRes".

## Concrete architecture choice

Adopt Option A:

1. keep the stem (conv/bn-or-gn/relu/maxpool) and the stage boundaries
2. add `AttnResImageBlock2D`
3. add `AttnResResNetLikeBackbone2D`, responsible for stacking stages/blocks
4. add an optional backbone mode to `ResNetDiffusionBackbone`, e.g.:
   - `vision_backbone_mode: resnet`
   - `vision_backbone_mode: attnres_resnet`
5. add a Phase-2 variant of the `resnet_imf_attnres` agent config with `attnres_resnet` enabled by default
6. still keep:
   - `64` output per camera
   - `3 * 64` total multi-camera visual output
   - `cond_dim = 208` after state concatenation

## Files likely to change

- `roboimi/vla/models/backbones/resnet_diffusion.py`
- `roboimi/vla/conf/backbone/resnet_diffusion.yaml`
- `roboimi/vla/conf/agent/resnet_imf_attnres.yaml`
- new: `roboimi/vla/models/backbones/attnres_resnet2d.py`
- tests:
  - new: `tests/test_attnres_resnet2d_backbone.py`
  - update/add wiring test for agent cond dims

## Test plan

1. New backbone instantiates and forwards `(B,T,C,H,W)` multi-camera input
2. Output shape unchanged vs current backbone
3. `output_dim == 64`
4. 3-camera cond path still yields `208`
5. Phase-2 config instantiates the full IMF agent successfully
6. One short CPU smoke forward for `compute_loss`

## Phase-2 experiment plan

Fix the Phase-1 best combination:

- `pred_horizon=16`
- `num_action_steps=8`

Compare:

1. baseline: current IMF head-only AttnRes + original ResNet vision backbone
2. phase2: IMF head AttnRes + full AttnRes-replaced vision backbone

Training hyperparameters stay identical to the Phase-1 best setting; run one 50k-step comparison first.
@@ -0,0 +1,32 @@
# ResNet Multitoken IMF Design

**Status:** user-specified architecture, treated as approved on 2026-04-06.

## Goal

Keep a standard ResNet-18 visual trunk (no AttnRes in vision), but change IMF conditioning from one concatenated multiview token per obs step into three camera-specific condition tokens per obs step.

## Approved architecture

- Vision trunk: standard `resnet18` residual network
- Cameras: `front`, `top`, `r_vis`
- Each camera uses its **own** ResNet-18 weights (`use_separate_rgb_encoder_per_camera=true`)
- Each camera produces one visual token
- For each obs step and each camera:
  1. take that camera's visual token
  2. concatenate the robot state
  3. project to one condition token
- IMF input should receive **3 condition tokens per obs step**, not one concatenated token
- With `obs_horizon=2`, the IMF cond sequence length becomes `2 * 3 = 6`
- The IMF head remains on the existing IMF/AttnRes implementation path
- The vision trunk remains standard ResNet; **no AttnRes vision replacement**

## Design choices

- Extend `ResNetDiffusionBackbone` with an opt-in mode that returns per-camera tokens shaped `(B, T, num_cams, D)` instead of concatenating camera features into `(B, T, num_cams * D)`.
- Teach `VLAAgent` to detect multi-token visual features, broadcast state per camera token, apply the existing condition projector on each token, then flatten `(T, num_cams)` into one cond sequence for the IMF head.
- Keep `per_step_cond_dim` as the width of a single condition token, and add explicit token-count metadata so transformer heads get the correct cond-sequence length.
- For the new experiments, set the condition-token width equal to `n_emb` via `cond_projector.output_dim=${agent.head.n_emb}`.
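The broadcast-project-flatten path above can be sketched in a few lines (a sketch with hypothetical names, not the agent's actual code):

```python
import torch

def build_multitoken_cond(vis: torch.Tensor, state: torch.Tensor, proj) -> torch.Tensor:
    # vis:   (B, T, num_cams, D) per-camera visual tokens.
    # state: (B, T, S) robot state, broadcast onto every camera token.
    # proj:  the shared condition projector, applied per token.
    b, t, n, _ = vis.shape
    state_rep = state.unsqueeze(2).expand(b, t, n, state.shape[-1])
    tokens = proj(torch.cat([vis, state_rep], dim=-1))  # (B, T, num_cams, n_emb)
    return tokens.reshape(b, t * n, -1)                 # (B, T * num_cams, n_emb)
```

With `obs_horizon=2` and 3 cameras, the flattened cond sequence has length 6, matching the architecture bullet above.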
## Files expected to change

- `roboimi/vla/models/backbones/resnet_diffusion.py`
- `roboimi/vla/agent.py`
- new Hydra agent config for the multitoken ResNet IMF variant
- focused tests in `tests/test_imf_vla_agent.py` and/or `tests/test_resnet_transformer_agent_wiring.py`
@@ -0,0 +1,41 @@
# SigLIP2 Multiview VLA Design

**Status:** user-specified architecture, treated as approved on 2026-04-06.

## Goal

Replace the current vision encoder for the IMF/AttnRes diffusion policy with a frozen SigLIP2 image encoder while preserving the downstream action-diffusion stack and rollout behavior.

## Approved architecture

- Backbone model: `google/siglip2-base-patch16-256`
- Camera inputs: three views, encoded **independently** with a **shared** SigLIP2 vision encoder
- Input size:
  - dataset images stay at native `256x256` (no dataset-side resize)
  - eval/rollout images resize to `256x256` before SigLIP2 because env renders are larger
- Per-view feature: use the global pooled image feature (`pooler_output`, 768-d)
- Per-view projection experiments:
  1. `768 -> 96`
  2. `768 -> 192`
- Conditioning pipeline:
  1. concatenate the 3 projected camera vectors
  2. concatenate the robot state
  3. project the concatenated condition to `384`
  4. feed that `384`-d per-step condition into the existing IMF/AttnRes diffusion head
- Training/run defaults for the requested experiments:
  - `n_emb=384`
  - `n_layer=12`
  - `pred_horizon=16`
  - `num_action_steps=8`
  - rollout count for validation: keep the current requested behavior on this branch unless explicitly overridden later
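The conditioning pipeline can be sketched as one module. The module and attribute names are hypothetical, and the SigLIP2 encoder itself is assumed to have already produced the 768-d pooled features:

```python
import torch
from torch import nn

class SigLIP2Conditioner(nn.Module):
    """Sketch: per-view 768-d pooled features -> shared per-view projection
    -> concat with state -> final 384-d per-step condition."""

    def __init__(self, per_view_dim: int = 96, state_dim: int = 16, n_emb: int = 384):
        super().__init__()
        self.view_proj = nn.Linear(768, per_view_dim)
        self.cond_proj = nn.Linear(3 * per_view_dim + state_dim, n_emb)

    def forward(self, views: list, state: torch.Tensor) -> torch.Tensor:
        # views: list of 3 tensors (B, 768), one per camera, shared projection.
        projected = [self.view_proj(v) for v in views]
        return self.cond_proj(torch.cat(projected + [state], dim=-1))
```

Swapping `per_view_dim` between 96 and 192 reproduces the two projection experiments while the rest of the pipeline stays fixed.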
## Design decisions

- The condition projector lives in `VLAAgent._build_cond()` so the backbone owns only visual features, while the agent owns the final conditioning contract expected by the diffusion head.
- The SigLIP2 backbone is frozen by default; only the per-view projectors and downstream policy layers train.
- The backbone exposes `dataset_image_resize_shape=None` and `eval_image_resize_shape=(256, 256)` so existing train/eval plumbing can reuse the raw-256 path already added in this branch.
- One shared vision encoder is used across cameras to keep memory and download size reasonable and to match the user's request for per-view independent encoding rather than a fused multiview image.

## Files expected to change

- `roboimi/vla/models/backbones/` for the new SigLIP2 backbone
- `roboimi/vla/agent.py` for optional post-concat condition projection
- Hydra configs under `roboimi/vla/conf/{agent,backbone,modules}`
- tests for backbone wiring and agent conditioning dims
- remote launch commands/scripts only as needed for training
@@ -0,0 +1,63 @@
# Phase-1 Final Report and Phase-2 Handoff

- Finalized: 2026-04-05 00:34:20 CST
- Scope: IMF AttnRes policy horizon/action-step grid on `sim_transfer`
- Fixed setup: `n_emb=384`, `n_layer=12`, batch size `80`, learning rate `2.5e-4`, `max_steps=50k`, rollout every 5 epochs with 5 episodes, 3 cameras `[r_vis, top, front]`.
- Main metric: the maximum training-time `rollout_avg_reward` recorded in `checkpoints/vla_model_best.pt`.

## Final leaderboard

| Rank | Run ID | pred_horizon | executed action steps | Best avg_reward | Best step | Final loss |
|---:|---|---:|---:|---:|---:|---:|
| 1 | `ph16_ex8` | 16 | 8 | **610.8** | 21874 | 0.0034 |
| 2 | `ph16_ex16` | 16 | 16 | 561.2 | 48124 | 0.0045 |
| 3 | `ph32_ex32` | 32 | 32 | 513.2 | 43749 | 0.0040 |
| 4 | `ph8_ex8` | 8 | 8 | 415.6 | 48124 | 0.0070 |
| 5 | `ph32_ex8` | 32 | 8 | 361.6 | 43749 | 0.0048 |
| 6 | `ph32_ex16` | 32 | 16 | 239.6 | 48124 | 0.0038 |

## Final conclusions

1. **The best combination is `pred_horizon=16` + `num_action_steps=8`**, with a best average reward of **610.8** at **step 21874**.
2. At `pred_horizon=16`, executing 8 steps beats executing 16 steps by about **+8.8%** (610.8 vs 561.2).
3. At `pred_horizon=32`, results are very sensitive to execution length: `32/32` clearly beats `32/8` and `32/16`, with `32/16` degrading the most.
4. A longer prediction window does not automatically bring higher reward; **the match between the prediction window and the actually executed window** is the key factor.
5. The best checkpoint did not appear at the end of training but comparatively early, at **step 21.9k** of the 50k run, which shows rollout validation matters more than watching train loss alone.
6. The Phase-2 comparison baseline is therefore fixed to **`ph16_ex8`**.

## Recommended baseline for follow-up experiments

- Baseline run: `ph16_ex8`
- Baseline best checkpoint: `step 21874`
- Baseline best avg_reward: `610.8`
- Baseline run dir: `/home/droid/roboimi_suite_20260404/runs/imf-p1-ph16-ex08-emb384-l12-ms50k-5880g1-20260404-131223`

## Phase-2 target: full-AttnRes vision backbone

Per your request, this phase no longer uses AttnRes only inside the IMF head; instead, **every residual unit in the previous visual ResNet trunk is replaced with an AttnRes residual unit**. The current implementation keeps the ResNet-style stage / downsample macro structure, but the visual residual trunk has switched to AttnRes:

- implementation: `roboimi/vla/models/backbones/attnres_resnet2d.py`
- wiring: `roboimi/vla/models/backbones/resnet_diffusion.py`
- config: `roboimi/vla/conf/backbone/resnet_diffusion.yaml`

The related code has been committed:

- `a780068` — headless rollout fix + Phase-1 summary
- `2033169` — full-AttnRes vision backbone

## Phase-2 launch status (observed on 2026-04-05 00:36 CST)

- Run: `imf-p2-full-attnres-vision-ph16-ex08-emb384-l12-b40-lr1p25e4-ms50k-l20g3-20260405-002424`
- Host: `100.119.99.14`, GPU `3`
- Config anchor: `pred_horizon=16`, `num_action_steps=8`
- Vision backbone: `attnres_resnet`
- Because batch size `80` OOMed on both the local 5090 and the remote L20, Phase-2 currently uses:
  - batch size: `40`
  - learning rate: `1.25e-4`
- Latest confirmed progress: **step 1300**
- The first rollout has **not happened yet** at this observation point.
- SwanLab: https://swanlab.cn/@game-loader/roboimi-vla/runs/xy7fjdmn0stdr19eu3gub

## Next action

Keep monitoring the Phase-2 full-AttnRes training; once it completes, compare it directly against the Phase-1 baseline of `610.8` to judge whether replacing the entire visual trunk with AttnRes beats using AttnRes only inside the IMF head.
@@ -0,0 +1,7 @@
rank,run_id,status,pred_horizon,num_action_steps,best_rollout_avg_reward,best_step,final_step,final_loss,host,run_dir,latest_step
1,ph16_ex8,running,16,8,610.8,21874,50000,0.0034315965604037046,100.73.14.65,/home/droid/roboimi_suite_20260404/runs/imf-p1-ph16-ex08-emb384-l12-ms50k-5880g1-20260404-131223,50000
2,ph16_ex16,running,16,16,561.2,48124,50000,0.004544622730463743,100.119.99.14,/home/droid/roboimi_suite_20260404/runs/imf-p1-ph16-ex16-emb384-l12-ms50k-l20g0-20260404-131223,50000
3,ph32_ex32,finished,32,32,513.2,43749,50000,0.003953303210437298,local,/home/droid/project/roboimi/.worktrees/feat-imf-attnres-policy/runs/imf-p1-ph32-ex32-emb384-l12-ms50k-5090-20260404-131223,49900
4,ph8_ex8,running,8,8,415.6,48124,50000,0.007008877582848072,100.73.14.65,/home/droid/roboimi_suite_20260404/runs/imf-p1-ph08-ex08-emb384-l12-ms50k-5880g0-20260404-131223,50000
5,ph32_ex8,running,32,8,361.6,43749,50000,0.004788532387465239,100.119.99.14,/home/droid/roboimi_suite_20260404/runs/imf-p1-ph32-ex08-emb384-l12-ms50k-l20g1-20260404-131223,50000
6,ph32_ex16,running,32,16,239.6,48124,50000,0.0038348555099219084,100.119.99.14,/home/droid/roboimi_suite_20260404/runs/imf-p1-ph32-ex16-emb384-l12-ms50k-l20g2-20260404-131223,50000
115
experiment_suites/2026-04-04-imf-horizon-grid/manifest.json
Normal file
@@ -0,0 +1,115 @@
{
  "suite_name": "2026-04-04-imf-horizon-grid",
  "created_at": "2026-04-04 13:19:52",
  "updated_at": "2026-04-04 13:19:52",
  "phase": "phase1_launching",
  "metric": "max_avg_reward",
  "baseline": {
    "agent": "resnet_imf_attnres",
    "batch_size": 80,
    "lr": 0.00025,
    "num_workers": 12,
    "max_steps": 50000,
    "rollout_val_freq_epochs": 5,
    "rollout_num_episodes": 5,
    "val_split": 0.0,
    "seed": 42,
    "scheduler_type": "cosine",
    "warmup_steps": 2000,
    "min_lr": 1e-06,
    "weight_decay": 1e-05,
    "grad_clip": 1.0,
    "inference_steps": 1,
    "embed_dim": 384,
    "n_layer": 12,
    "n_head": 1,
    "n_kv_head": 1,
    "freeze_backbone": false,
    "pretrained_backbone_weights": null,
    "camera_names": ["r_vis", "top", "front"]
  },
  "runs": [
    {
      "id": "ph8_ex8",
      "pred_horizon": 8,
      "num_action_steps": 8,
      "host": "100.73.14.65",
      "host_label": "tailnet-5880",
      "gpu": 0,
      "workdir": "/home/droid/roboimi_suite_20260404",
      "python": "/home/droid/miniforge3/envs/roboimi/bin/python",
      "dataset_dir": "/home/droid/sim_dataset/sim_transfer",
      "run_name": "imf-p1-ph08-ex08-emb384-l12-ms50k-5880g0-20260404-131223",
      "launch_state": "ready"
    },
    {
      "id": "ph16_ex8",
      "pred_horizon": 16,
      "num_action_steps": 8,
      "host": "100.73.14.65",
      "host_label": "tailnet-5880",
      "gpu": 1,
      "workdir": "/home/droid/roboimi_suite_20260404",
      "python": "/home/droid/miniforge3/envs/roboimi/bin/python",
      "dataset_dir": "/home/droid/sim_dataset/sim_transfer",
      "run_name": "imf-p1-ph16-ex08-emb384-l12-ms50k-5880g1-20260404-131223",
      "launch_state": "ready"
    },
    {
      "id": "ph16_ex16",
      "pred_horizon": 16,
      "num_action_steps": 16,
      "host": "100.119.99.14",
      "host_label": "tailnet-l20",
      "gpu": 0,
      "workdir": "/home/droid/roboimi_suite_20260404",
      "python": "/home/droid/miniforge3/envs/roboimi/bin/python",
      "dataset_dir": "/home/droid/sim_dataset/sim_transfer",
      "run_name": "imf-p1-ph16-ex16-emb384-l12-ms50k-l20g0-20260404-131223",
      "launch_state": "provisioning_required"
    },
    {
      "id": "ph32_ex8",
      "pred_horizon": 32,
      "num_action_steps": 8,
      "host": "100.119.99.14",
      "host_label": "tailnet-l20",
      "gpu": 1,
      "workdir": "/home/droid/roboimi_suite_20260404",
      "python": "/home/droid/miniforge3/envs/roboimi/bin/python",
      "dataset_dir": "/home/droid/sim_dataset/sim_transfer",
      "run_name": "imf-p1-ph32-ex08-emb384-l12-ms50k-l20g1-20260404-131223",
      "launch_state": "provisioning_required"
    },
    {
      "id": "ph32_ex16",
      "pred_horizon": 32,
      "num_action_steps": 16,
      "host": "100.119.99.14",
      "host_label": "tailnet-l20",
      "gpu": 2,
      "workdir": "/home/droid/roboimi_suite_20260404",
      "python": "/home/droid/miniforge3/envs/roboimi/bin/python",
      "dataset_dir": "/home/droid/sim_dataset/sim_transfer",
      "run_name": "imf-p1-ph32-ex16-emb384-l12-ms50k-l20g2-20260404-131223",
      "launch_state": "provisioning_required"
    },
    {
      "id": "ph32_ex32",
      "pred_horizon": 32,
      "num_action_steps": 32,
      "host": "local",
      "host_label": "local-5090",
      "gpu": 0,
      "workdir": "/home/droid/project/roboimi/.worktrees/feat-imf-attnres-policy",
      "python": "/home/droid/.conda/envs/roboimi/bin/python",
      "dataset_dir": "/home/droid/project/diana_sim/sim_transfer",
      "run_name": "imf-p1-ph32-ex32-emb384-l12-ms50k-5090-20260404-131223",
      "launch_state": "ready"
    }
  ]
}
20
experiment_suites/2026-04-04-imf-horizon-grid/notes.md
Normal file
@@ -0,0 +1,20 @@
# IMF Horizon Grid Suite Notes

- Created: 2026-04-04 13:19:52
- Phase-1 matrix: (8,8), (16,8), (16,16), (32,8), (32,16), (32,32)
- Fixed baseline: IMF AttnRes, n_emb=384, n_layer=12, batch_size=80, lr=2.5e-4, max_steps=50k, rollout every 5 epochs with 5 episodes.
- Host allocation:
  - local RTX 5090: ph32_ex32
  - 100.73.14.65 RTX 5880 GPU0: ph8_ex8
  - 100.73.14.65 RTX 5880 GPU1: ph16_ex8
  - 100.119.99.14 L20 GPU0: ph16_ex16
  - 100.119.99.14 L20 GPU1: ph32_ex8
  - 100.119.99.14 L20 GPU2: ph32_ex16
- 100.119.99.14 still needs env + dataset + swanlab credential copy before launch.

- 2026-04-04 13:23:43: launched local ph32_ex32 (pid 1437836), remote 100.73 ph8_ex8 (pid 931824), ph16_ex8 (pid 931826); started 100.119 bootstrap (local pid 1437837).
- 2026-04-04 13:25:43: first status sync — local ph32_ex32 step≈500; remote ph8_ex8 step≈400; remote ph16_ex8 step≈400.
- 2026-04-04 13:27:41: second status sync — 100.119 bootstrap finished env copy and entered dataset copy; local ph32_ex32 step≈900; remote ph8_ex8 step≈800; remote ph16_ex8 step≈800.
- 2026-04-04 13:35:31: 100.119 bootstrap data/env copy finished. The original validation command hit a quoting bug, so I manually revalidated torch+mujoco+swanlab and then launched ph16_ex16/ph32_ex8/ph32_ex16 with pids 81129/81130/81131.
- 2026-04-04 13:37:36: all 6 Phase-1 runs are now up. SwanLab links recorded in status.json; latest observed steps ≈ local 900 / 5880 runs 800 / L20 runs 100.
- 2026-04-04 14:41:08: diagnosed the remote first-rollout crash as an early mujoco import in eval_vla.py (via raw_action_trajectory_viewer) before MUJOCO_GL=egl was set. Added regression test tests/test_eval_vla_headless_import.py, made the import lazy, verified a 20-step headless eval on 5880 and L20, then resumed the 5 failed runs from step 4374. Current resumed pids: ph8_ex8=938714, ph16_ex8=938717, ph16_ex16=90169, ph32_ex8=90173, ph32_ex16=90175.
@@ -0,0 +1,38 @@
|
|||||||
|
# Phase-1 IMF Horizon Grid Summary
|
||||||
|
|
||||||
|
- Generated: 2026-04-04 23:43:38
- Fixed baseline: IMF AttnRes head, `n_emb=384`, `n_layer=12`, `batch_size=80`, `lr=2.5e-4`, `max_steps=50k`, rollout every 5 epochs with 5 episodes, 3 cameras `[r_vis, top, front]`.
- Primary metric: `checkpoints/vla_model_best.pt -> rollout_avg_reward` (max training-time rollout average reward).

## Ranked results

| Rank | Run ID | pred_horizon | num_action_steps | Best avg_reward | Best step | Final loss | Host |
|---:|---|---:|---:|---:|---:|---:|---|
| 1 | `ph16_ex8` | 16 | 8 | 610.8 | 21874 | 0.0034 | 100.73.14.65 |
| 2 | `ph16_ex16` | 16 | 16 | 561.2 | 48124 | 0.0045 | 100.119.99.14 |
| 3 | `ph32_ex32` | 32 | 32 | 513.2 | 43749 | 0.0040 | local |
| 4 | `ph8_ex8` | 8 | 8 | 415.6 | 48124 | 0.0070 | 100.73.14.65 |
| 5 | `ph32_ex8` | 32 | 8 | 361.6 | 43749 | 0.0048 | 100.119.99.14 |
| 6 | `ph32_ex16` | 32 | 16 | 239.6 | 48124 | 0.0038 | 100.119.99.14 |

## Main observations

- The best overall setting was **`pred_horizon=16`, `num_action_steps=8`**, with **max avg_reward = 610.8** at step **21874**.
- At horizon 16, executing 8 steps outperformed executing the full 16 (`ph16_ex8` > `ph16_ex16`).
- At horizon 32, executing the full 32-step chunk was much better than executing only 8 or 16 steps (`ph32_ex32` > `ph32_ex8` > `ph32_ex16`).
- Short horizon 8 with 8-step execution was competitive but clearly below the best 16/8 and 32/32 settings.
- In this sweep, a longer prediction horizon helped only when the executed chunk length matched a good control cadence; a mismatch could hurt badly (especially `ph32_ex16`).

## Raw results

- `ph16_ex8`: best avg_reward=610.8 @ step 21874, final_loss=0.0034, run_dir=`/home/droid/roboimi_suite_20260404/runs/imf-p1-ph16-ex08-emb384-l12-ms50k-5880g1-20260404-131223`
- `ph16_ex16`: best avg_reward=561.2 @ step 48124, final_loss=0.0045, run_dir=`/home/droid/roboimi_suite_20260404/runs/imf-p1-ph16-ex16-emb384-l12-ms50k-l20g0-20260404-131223`
- `ph32_ex32`: best avg_reward=513.2 @ step 43749, final_loss=0.0040, run_dir=`/home/droid/project/roboimi/.worktrees/feat-imf-attnres-policy/runs/imf-p1-ph32-ex32-emb384-l12-ms50k-5090-20260404-131223`
- `ph8_ex8`: best avg_reward=415.6 @ step 48124, final_loss=0.0070, run_dir=`/home/droid/roboimi_suite_20260404/runs/imf-p1-ph08-ex08-emb384-l12-ms50k-5880g0-20260404-131223`
- `ph32_ex8`: best avg_reward=361.6 @ step 43749, final_loss=0.0048, run_dir=`/home/droid/roboimi_suite_20260404/runs/imf-p1-ph32-ex08-emb384-l12-ms50k-l20g1-20260404-131223`
- `ph32_ex16`: best avg_reward=239.6 @ step 48124, final_loss=0.0038, run_dir=`/home/droid/roboimi_suite_20260404/runs/imf-p1-ph32-ex16-emb384-l12-ms50k-l20g2-20260404-131223`

## Recommendation for Phase-2 anchor

- Use **`pred_horizon=16`, `num_action_steps=8`** as the strongest Phase-1 baseline if the goal is purely maximizing rollout reward.
- If Phase-2 needs a more conservative action execution budget, `ph16_ex8` is also the strongest setting that does not execute a full 32-step chunk, so it may still serve as a good comparison anchor in that regime.
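The `pred_horizon` / `num_action_steps` interplay analyzed above is the standard receding-horizon action-chunking loop: predict a chunk of `pred_horizon` actions, execute only the first `num_action_steps`, then replan from the new observation. A minimal sketch of that loop — `policy.predict` and the `env` interface here are illustrative placeholders, not the repo's `eval_vla.py` API:

```python
def rollout(policy, env, pred_horizon=16, num_action_steps=8, max_steps=300):
    """Receding-horizon chunked execution: predict pred_horizon actions,
    execute the first num_action_steps, then replan from the new obs."""
    obs = env.reset()
    total_reward, t = 0.0, 0
    while t < max_steps:
        # One forward pass yields a whole chunk of future actions.
        chunk = policy.predict(obs, horizon=pred_horizon)
        # Only the head of the chunk is executed before replanning.
        for action in chunk[:num_action_steps]:
            obs, reward, done = env.step(action)
            total_reward += reward
            t += 1
            if done or t >= max_steps:
                return total_reward
    return total_reward
```

Smaller `num_action_steps` means more frequent replanning (more reactive, more inference cost); executing the full chunk amortizes inference but commits to stale plans — which is one way to read the `ph32_ex16` failure above.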
167
experiment_suites/2026-04-04-imf-horizon-grid/status.json
Normal file
@@ -0,0 +1,167 @@
{
  "suite_name": "2026-04-04-imf-horizon-grid",
  "updated_at": "2026-04-05 00:34:20",
  "phase": "phase1_completed",
  "provisioning": {
    "100.119.99.14": {
      "state": "completed_manual_launch",
      "bootstrap_pid_local": 1437837,
      "log_path": "experiment_suites/2026-04-04-imf-horizon-grid/provision_logs/100.119.99.14-bootstrap-20260404-131223.log",
      "env_copy": "completed",
      "dataset_copy": "completed",
      "launch_watcher_pid_local": null,
      "launch_watcher_log": "experiment_suites/2026-04-04-imf-horizon-grid/launch_logs/100.119.99.14-launch-watcher-20260404-131223.log",
      "swanlab_copy": "completed",
      "bootstrap_validation_note": "initial validation command had a quoting bug; manual validation passed and launches were started successfully"
    }
  },
  "runs": {
    "ph8_ex8": {
      "status": "finished",
      "host": "100.73.14.65",
      "gpu": 0,
      "run_name": "imf-p1-ph08-ex08-emb384-l12-ms50k-5880g0-20260404-131223",
      "workdir": "/home/droid/roboimi_suite_20260404",
      "dataset_dir": "/home/droid/sim_dataset/sim_transfer",
      "log_path": "/home/droid/roboimi_suite_20260404/runs/imf-p1-ph08-ex08-emb384-l12-ms50k-5880g0-20260404-131223/train_vla.log",
      "run_dir": "/home/droid/roboimi_suite_20260404/runs/imf-p1-ph08-ex08-emb384-l12-ms50k-5880g0-20260404-131223",
      "pred_horizon": 8,
      "num_action_steps": 8,
      "pid": 938714,
      "launch_log": "experiment_suite_launch_logs/imf-p1-ph08-ex08-emb384-l12-ms50k-5880g0-20260404-131223.restartfix-20260404-143827.log",
      "latest_step": 50000,
      "latest_log_sync": "2026-04-05 00:34:20",
      "swanlab_url": "https://swanlab.cn/@game-loader/roboimi-vla/runs/i5syc57b6zq7rbkrtqy7b",
      "process_running": false,
      "best_step": 48124,
      "best_rollout_avg_reward": 415.6,
      "final_loss": 0.007008877582848072
    },
    "ph16_ex8": {
      "status": "finished",
      "host": "100.73.14.65",
      "gpu": 1,
      "run_name": "imf-p1-ph16-ex08-emb384-l12-ms50k-5880g1-20260404-131223",
      "workdir": "/home/droid/roboimi_suite_20260404",
      "dataset_dir": "/home/droid/sim_dataset/sim_transfer",
      "log_path": "/home/droid/roboimi_suite_20260404/runs/imf-p1-ph16-ex08-emb384-l12-ms50k-5880g1-20260404-131223/train_vla.log",
      "run_dir": "/home/droid/roboimi_suite_20260404/runs/imf-p1-ph16-ex08-emb384-l12-ms50k-5880g1-20260404-131223",
      "pred_horizon": 16,
      "num_action_steps": 8,
      "pid": 938717,
      "launch_log": "experiment_suite_launch_logs/imf-p1-ph16-ex08-emb384-l12-ms50k-5880g1-20260404-131223.restartfix-20260404-143827.log",
      "latest_step": 50000,
      "latest_log_sync": "2026-04-05 00:34:20",
      "swanlab_url": "https://swanlab.cn/@game-loader/roboimi-vla/runs/4rusbrpfxmw4ffii1ul5w",
      "process_running": false,
      "best_step": 21874,
      "best_rollout_avg_reward": 610.8,
      "final_loss": 0.0034315965604037046
    },
    "ph16_ex16": {
      "status": "finished",
      "host": "100.119.99.14",
      "gpu": 0,
      "run_name": "imf-p1-ph16-ex16-emb384-l12-ms50k-l20g0-20260404-131223",
      "workdir": "/home/droid/roboimi_suite_20260404",
      "dataset_dir": "/home/droid/sim_dataset/sim_transfer",
      "log_path": "/home/droid/roboimi_suite_20260404/runs/imf-p1-ph16-ex16-emb384-l12-ms50k-l20g0-20260404-131223/train_vla.log",
      "run_dir": "/home/droid/roboimi_suite_20260404/runs/imf-p1-ph16-ex16-emb384-l12-ms50k-l20g0-20260404-131223",
      "pred_horizon": 16,
      "num_action_steps": 16,
      "pid": 90169,
      "launch_log": "experiment_suite_launch_logs/imf-p1-ph16-ex16-emb384-l12-ms50k-l20g0-20260404-131223.restartfix-20260404-143827.log",
      "latest_log_sync": "2026-04-05 00:34:20",
      "latest_step": 50000,
      "swanlab_url": "https://swanlab.cn/@game-loader/roboimi-vla/runs/wwm232k6190gexnze8mg6",
      "process_running": false,
      "best_step": 48124,
      "best_rollout_avg_reward": 561.2,
      "final_loss": 0.004544622730463743
    },
    "ph32_ex8": {
      "status": "finished",
      "host": "100.119.99.14",
      "gpu": 1,
      "run_name": "imf-p1-ph32-ex08-emb384-l12-ms50k-l20g1-20260404-131223",
      "workdir": "/home/droid/roboimi_suite_20260404",
      "dataset_dir": "/home/droid/sim_dataset/sim_transfer",
      "log_path": "/home/droid/roboimi_suite_20260404/runs/imf-p1-ph32-ex08-emb384-l12-ms50k-l20g1-20260404-131223/train_vla.log",
      "run_dir": "/home/droid/roboimi_suite_20260404/runs/imf-p1-ph32-ex08-emb384-l12-ms50k-l20g1-20260404-131223",
      "pred_horizon": 32,
      "num_action_steps": 8,
      "pid": 90173,
      "launch_log": "experiment_suite_launch_logs/imf-p1-ph32-ex08-emb384-l12-ms50k-l20g1-20260404-131223.restartfix-20260404-143827.log",
      "latest_log_sync": "2026-04-05 00:34:20",
      "latest_step": 50000,
      "swanlab_url": "https://swanlab.cn/@game-loader/roboimi-vla/runs/o5y2xjb2rsb3lmfcuhy4p",
      "process_running": false,
      "best_step": 43749,
      "best_rollout_avg_reward": 361.6,
      "final_loss": 0.004788532387465239
    },
    "ph32_ex16": {
      "status": "finished",
      "host": "100.119.99.14",
      "gpu": 2,
      "run_name": "imf-p1-ph32-ex16-emb384-l12-ms50k-l20g2-20260404-131223",
      "workdir": "/home/droid/roboimi_suite_20260404",
      "dataset_dir": "/home/droid/sim_dataset/sim_transfer",
      "log_path": "/home/droid/roboimi_suite_20260404/runs/imf-p1-ph32-ex16-emb384-l12-ms50k-l20g2-20260404-131223/train_vla.log",
      "run_dir": "/home/droid/roboimi_suite_20260404/runs/imf-p1-ph32-ex16-emb384-l12-ms50k-l20g2-20260404-131223",
      "pred_horizon": 32,
      "num_action_steps": 16,
      "pid": 90175,
      "launch_log": "experiment_suite_launch_logs/imf-p1-ph32-ex16-emb384-l12-ms50k-l20g2-20260404-131223.restartfix-20260404-143827.log",
      "latest_log_sync": "2026-04-05 00:34:20",
      "latest_step": 50000,
      "swanlab_url": "https://swanlab.cn/@game-loader/roboimi-vla/runs/54cjpgba9eqsopdm0l8d3",
      "process_running": false,
      "best_step": 48124,
      "best_rollout_avg_reward": 239.6,
      "final_loss": 0.0038348555099219084
    },
    "ph32_ex32": {
      "status": "finished",
      "host": "local",
      "gpu": 0,
      "run_name": "imf-p1-ph32-ex32-emb384-l12-ms50k-5090-20260404-131223",
      "workdir": "/home/droid/project/roboimi/.worktrees/feat-imf-attnres-policy",
      "dataset_dir": "/home/droid/project/diana_sim/sim_transfer",
      "log_path": "/home/droid/project/roboimi/.worktrees/feat-imf-attnres-policy/runs/imf-p1-ph32-ex32-emb384-l12-ms50k-5090-20260404-131223/train_vla.log",
      "run_dir": "/home/droid/project/roboimi/.worktrees/feat-imf-attnres-policy/runs/imf-p1-ph32-ex32-emb384-l12-ms50k-5090-20260404-131223",
      "pred_horizon": 32,
      "num_action_steps": 32,
      "pid": 1437836,
      "launch_log": "experiment_suites/2026-04-04-imf-horizon-grid/launch_logs/imf-p1-ph32-ex32-emb384-l12-ms50k-5090-20260404-131223.launch.log",
      "latest_step": 49900,
      "latest_log_sync": "2026-04-05 00:34:20",
      "swanlab_url": "https://swanlab.cn/@game-loader/roboimi-vla/runs/ajs2m218jd260hawhy5ns",
      "process_running": false,
      "latest_rollout_avg_reward": 513.2,
      "best_rollout_avg_reward": 513.2,
      "best_step": 43749,
      "final_loss": 0.003953303210437298
    }
  },
  "monitor": {
    "state": "stopped",
    "pid_local": null,
    "log_path": "experiment_suites/2026-04-04-imf-horizon-grid/monitor_logs/status-sync-20260404-131223.log",
    "interval_seconds": 300,
    "stopped_at": "2026-04-05 00:34:20",
    "stop_reason": "phase1 suite finalized after all six runs completed"
  },
  "debug": {
    "remote_rollout_failure_20260404": {
      "root_cause": "eval_vla.py imported raw_action_trajectory_viewer at module import time, which imported mujoco before MUJOCO_GL=egl was set; remote headless rollout then fell back to GLFW/X11 and crashed with mujoco.FatalError: gladLoadGL error during env.reset()->mj.Renderer(...)",
      "fixed_file": "roboimi/demos/vla_scripts/eval_vla.py",
      "verification": {
        "pytest": "tests/test_eval_vla_headless_import.py passed",
        "remote_eval_5880": "1 episode x 20 steps headless eval passed",
        "remote_eval_l20": "1 episode x 20 steps headless eval passed"
      }
    }
  },
  "phase1_summary_md": "/home/droid/project/roboimi/.worktrees/feat-imf-attnres-policy/experiment_suites/2026-04-04-imf-horizon-grid/phase1_summary.md"
}
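The `debug` entry above records an import-order pitfall worth keeping in mind: `mujoco` selects its GL backend the first time it is imported, so `MUJOCO_GL=egl` must already be in the environment at that point. A minimal sketch of the headless-safe pattern — the `make_renderer` helper is illustrative, not the actual `eval_vla.py` code:

```python
import os

# Force the EGL backend BEFORE anything imports mujoco; once mujoco is
# imported, changing MUJOCO_GL no longer affects the chosen GL backend.
os.environ["MUJOCO_GL"] = "egl"

def make_renderer(model, height=480, width=640):
    # Defer the mujoco import into the function body so that merely
    # importing this module (e.g. from a trajectory-viewer helper)
    # cannot trigger GL backend selection too early.
    import mujoco
    return mujoco.Renderer(model, height=height, width=width)
```

Deferring the heavy import like this is what prevents a helper module (such as the trajectory viewer mentioned in the root cause) from dragging in `mujoco` as a side effect of being imported.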
69
experiment_suites/2026-04-05-camera-ablation-summary.md
Normal file
@@ -0,0 +1,69 @@
# Camera Ablation Summary (`pred_horizon=16`, `num_action_steps=8`, ResNet IMF)

- Generated: 2026-04-05
- Common setup: original ResNet vision backbone, `n_emb=384`, `n_layer=12`, `batch_size=80`, `lr=2.5e-4`, `max_steps=50k`, rollout every 5 epochs with 5 episodes, headless eval.
- Metric for comparison: `checkpoints/vla_model_best.pt -> rollout_avg_reward`.

## Leaderboard

| Rank | Cameras | Best avg_reward | Best step | Final loss | Run name |
|---:|---|---:|---:|---:|---|
| 1 | `top + front` | **274.8** | 48124 | 0.0056 | `imf-resnet-topfront-2cam-ph16-ex08-emb384-l12-ms50k-5090-20260405-085023` |
| 2 | `top` | **271.2** | 43749 | 0.0052 | `imf-resnet-top-1cam-ph16-ex08-emb384-l12-ms50k-l20g4-20260405-125844` |
| 3 | `r_vis + front` | **244.0** | 21874 | 0.0043 | `imf-resnet-frontrvis-2cam-ph16-ex08-emb384-l12-ms50k-l20g1-20260405-102029` |
| 4 | `r_vis` | **6.4** | 17499 | 0.0047 | `imf-resnet-rvis-1cam-ph16-ex08-emb384-l12-ms50k-l20g3-20260405-125844` |
| 5 | `r_vis + top` | **1.2** | 4374 | 0.0047 | `imf-resnet-rvistop-2cam-ph16-ex08-emb384-l12-ms50k-l20g2-20260405-125844` |
| 6 | `front` | **0.0** | 4374 | 0.0074 | `imf-resnet-front-1cam-ph16-ex08-emb384-l12-ms50k-l20g0-20260405-095607` |

## Main takeaways

1. **`top` is the single most important camera view**: `top only = 271.2`, nearly matching `top + front = 274.8`.
2. **`front` alone is almost useless**: `front only = 0.0`.
3. **`r_vis` alone is also largely ineffective**: `r_vis only = 6.4`.
4. **`r_vis + front` is markedly better than either `front` or `r_vis` alone**, so the two views are somewhat complementary, but the pair is still clearly weaker than any well-behaved configuration that includes `top`.
5. **`r_vis + top` is anomalously bad**: only `1.2`, far below `top only = 271.2`. Simply adding `r_vis` therefore does not guarantee a gain and can even break learning in the current setup.
6. **Training loss and rollout reward clearly disagree**: for example, `r_vis + top` and `r_vis only` both reach low final losses yet very poor rewards, so model selection for this group must be based on rollout reward, not loss.

## Horizontal comparison views

### Single-camera comparison

- `top`: **271.2**
- `r_vis`: **6.4**
- `front`: **0.0**

Conclusion: **`top >>> r_vis > front`**.

### Two-camera comparison

- `top + front`: **274.8**
- `r_vis + front`: **244.0**
- `r_vis + top`: **1.2**

Conclusions:

- **The safest two-camera combination is `top + front`**.
- `r_vis + front` works, but not as well as `top + front`.
- `r_vis + top` is nearly useless in the current setup.

### Incremental effect of adding a second view

- Adding `front` on top of `top`: `271.2 -> 274.8`, **a very small gain**.
- Adding `r_vis` on top of `front`: `0.0 -> 244.0`, **a very large gain**.
- Adding `r_vis` on top of `top`: `271.2 -> 1.2`, **a severe regression**.

## Practical recommendation

If choosing only among these 6 experiments:

- **First choice**: `top + front`
- **Second choice**: `top only`
- If `top` must be excluded: `r_vis + front` is clearly better than `front only` / `r_vis only`
- **Not recommended**: `r_vis + top`

## Note relative to previous 3-camera baseline

The previous 3-camera `[r_vis, top, front]` setup reached a best reward of **610.8**.
The best result among these 6 camera ablations (`top + front = 274.8`) therefore shows:

- In this training batch, **dropping any single view falls far short of the earlier 3-camera optimum**;
- But if a view must be dropped, **`top` remains the most important one to keep**.
@@ -0,0 +1,8 @@
# CHECKLIST

- [x] Confirm remote free GPU
- [x] Create front-only run contract
- [x] Remote smoke test passes
- [x] Launch 50k run on remote GPU0
- [x] Record pid / log / SwanLab
- [x] Report status back to user
28
experiment_suites/2026-04-05-front-only-resnet-1cam/PLAN.md
Normal file
@@ -0,0 +1,28 @@
# PLAN

## Goal

Train a 50k-step IMF baseline with the original ResNet vision backbone, using only the `front` camera as image conditioning.

## Fixed comparison contract

- Same as the active `top/front` run except image input is reduced to `[front]`
- Agent: `resnet_imf_attnres`
- Vision backbone mode: `resnet`
- `pred_horizon=16`, `num_action_steps=8`
- `n_emb=384`, `n_layer=12`, `n_head=1`, `n_kv_head=1`
- `inference_steps=1`
- `batch_size=80`, `lr=2.5e-4`, cosine schedule, warmup=2000
- dataset: `/home/droid/sim_dataset/sim_transfer`
- cameras: `[front]` only
- rollout every 5 epochs with 5 episodes, headless

## Resource plan

- Host: `100.119.99.14`
- GPU: `0`

## Important dimension override

- Single-camera visual cond dim = `64 + 16 = 80`, so override `agent.head.cond_dim=80` and `agent.num_cams=1`.

## Execution path

1. 2-step smoke test on remote GPU0.
2. If the smoke test passes, launch the 50k main run with SwanLab.
3. Record pid / run_dir / log / URL locally.
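The dimension override above follows a simple rule implied by the plans' arithmetic (`64 + 16 = 80` for one camera, `64*2 + 16 = 144` for two): one 64-dim visual feature per camera concatenated with a 16-dim proprioceptive embedding. A minimal sketch — the per-camera and proprio widths are assumptions read off that arithmetic, not taken from the repo's code:

```python
def imf_cond_dim(num_cams: int, per_cam_dim: int = 64, proprio_dim: int = 16) -> int:
    """Conditioning width for the IMF head: per_cam_dim visual features
    per camera plus a proprio_dim proprioceptive embedding.
    The default widths are assumptions inferred from the plan documents."""
    return num_cams * per_cam_dim + proprio_dim

# front only  -> agent.head.cond_dim=80
# r_vis+front -> agent.head.cond_dim=144
print(imf_cond_dim(1), imf_cond_dim(2))
```

Deriving `cond_dim` from `num_cams` this way makes it obvious why both overrides (`agent.num_cams` and `agent.head.cond_dim`) must change together when the camera list changes.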
@@ -0,0 +1,6 @@
# Notes

- 2026-04-05 09:55:27: remote 2-step smoke passed on `100.119.99.14` GPU0 with `front` only, batch=80, no OOM.
- 2026-04-05 09:56:26: launched main run `imf-resnet-front-1cam-ph16-ex08-emb384-l12-ms50k-l20g0-20260405-095607`.
- 2026-04-05 09:57:36: confirmed training is stable through step 200, latest loss 0.2830.
- SwanLab: https://swanlab.cn/@game-loader/roboimi-vla/runs/7kdii8oc6tjkcyu5y0lwq
@@ -0,0 +1,51 @@
{
  "suite_name": "2026-04-05-front-only-resnet-1cam",
  "updated_at": "2026-04-05 09:57:36",
  "phase": "running",
  "baseline_reference": {
    "source_run": "imf-resnet-topfront-2cam-ph16-ex08-emb384-l12-ms50k-5090-20260405-085023",
    "notes": "Same hyperparameters as the active top/front run, but image input is reduced to [front] only."
  },
  "smoke_test": {
    "status": "passed",
    "host": "100.119.99.14",
    "gpu": 0,
    "run_dir": "/home/droid/roboimi_suite_20260404/runs/smoke-frontonly-resnet-ph16-ex08-20260405-095509",
    "batch_size": 80,
    "max_steps": 2,
    "note": "2-step remote CUDA smoke passed on L20 GPU0 without OOM."
  },
  "main_run": {
    "status": "running",
    "host": "100.119.99.14",
    "gpu": 0,
    "launch_pid": 158874,
    "pid": 158877,
    "run_name": "imf-resnet-front-1cam-ph16-ex08-emb384-l12-ms50k-l20g0-20260405-095607",
    "run_dir": "/home/droid/roboimi_suite_20260404/runs/imf-resnet-front-1cam-ph16-ex08-emb384-l12-ms50k-l20g0-20260405-095607",
    "log_path": "/home/droid/roboimi_suite_20260404/runs/imf-resnet-front-1cam-ph16-ex08-emb384-l12-ms50k-l20g0-20260405-095607/train_vla.log",
    "launch_log": "/home/droid/roboimi_suite_20260404/experiment_suite_launch_logs/imf-resnet-front-1cam-ph16-ex08-emb384-l12-ms50k-l20g0-20260405-095607.launch.log",
    "dataset_dir": "/home/droid/sim_dataset/sim_transfer",
    "camera_names": [
      "front"
    ],
    "pred_horizon": 16,
    "num_action_steps": 8,
    "head_cond_dim": 80,
    "head_n_emb": 384,
    "head_n_layer": 12,
    "vision_backbone_mode": "resnet",
    "pretrained_backbone_weights": null,
    "freeze_backbone": false,
    "batch_size": 80,
    "lr": 0.00025,
    "num_workers": 12,
    "max_steps": 50000,
    "rollout_val_freq_epochs": 5,
    "rollout_num_episodes": 5,
    "swanlab_url": "https://swanlab.cn/@game-loader/roboimi-vla/runs/7kdii8oc6tjkcyu5y0lwq",
    "latest_step": 200,
    "latest_loss": 0.283,
    "process_running": true
  }
}
@@ -0,0 +1,8 @@
# CHECKLIST

- [x] Confirm camera mapping (`right` -> `r_vis`)
- [x] Create front+r_vis run contract
- [x] Remote smoke test passes
- [x] Launch 50k run on remote GPU1
- [x] Record pid / log / SwanLab
- [x] Report status back to user
23
experiment_suites/2026-04-05-front-rvis-resnet-2cam/PLAN.md
Normal file
@@ -0,0 +1,23 @@
# PLAN

## Goal

Train a 50k-step IMF baseline with the original ResNet vision backbone, using the `front` + `r_vis` cameras only.

## Fixed comparison contract

- Same hyperparameters as the active top/front and front-only runs
- Agent: `resnet_imf_attnres`
- Vision backbone mode: `resnet`
- `pred_horizon=16`, `num_action_steps=8`
- `n_emb=384`, `n_layer=12`, `n_head=1`, `n_kv_head=1`
- `inference_steps=1`
- `batch_size=80`, `lr=2.5e-4`, cosine schedule, warmup=2000
- dataset: `/home/droid/sim_dataset/sim_transfer`
- cameras: `[r_vis, front]`
- rollout every 5 epochs with 5 episodes, headless

## Important dimension override

- Two-camera visual cond dim = `64*2 + 16 = 144`, so set `agent.num_cams=2`, `agent.head.cond_dim=144`.

## Resource plan

- Host: `100.119.99.14`
- GPU: `1`
@@ -0,0 +1,6 @@
# Notes

- 2026-04-05 10:20:09: remote 2-step smoke passed on `100.119.99.14` GPU1 with `r_vis + front`, batch=80, no OOM.
- 2026-04-05 10:20:49: launched main run `imf-resnet-frontrvis-2cam-ph16-ex08-emb384-l12-ms50k-l20g1-20260405-102029`.
- 2026-04-05 10:22:03: confirmed training is stable through step 200, latest loss 0.3321.
- SwanLab: https://swanlab.cn/@game-loader/roboimi-vla/runs/3fyzjfdcbiq7frtbqv6ss
@@ -0,0 +1,55 @@
{
  "suite_name": "2026-04-05-front-rvis-resnet-2cam",
  "updated_at": "2026-04-05 10:22:03",
  "phase": "running",
  "interpretation": {
    "right_camera_name": "r_vis"
  },
  "baseline_reference": {
    "source_run": "imf-resnet-topfront-2cam-ph16-ex08-emb384-l12-ms50k-5090-20260405-085023",
    "notes": "Same hyperparameters as the active top/front run, replacing top with r_vis."
  },
  "smoke_test": {
    "status": "passed",
    "host": "100.119.99.14",
    "gpu": 1,
    "run_dir": "/home/droid/roboimi_suite_20260404/runs/smoke-frontrvis-resnet-ph16-ex08-20260405-102001",
    "batch_size": 80,
    "max_steps": 2,
    "note": "2-step remote CUDA smoke passed on L20 GPU1 without OOM."
  },
  "main_run": {
    "status": "running",
    "host": "100.119.99.14",
    "gpu": 1,
    "launch_pid": 159910,
    "pid": 159913,
    "run_name": "imf-resnet-frontrvis-2cam-ph16-ex08-emb384-l12-ms50k-l20g1-20260405-102029",
    "run_dir": "/home/droid/roboimi_suite_20260404/runs/imf-resnet-frontrvis-2cam-ph16-ex08-emb384-l12-ms50k-l20g1-20260405-102029",
    "log_path": "/home/droid/roboimi_suite_20260404/runs/imf-resnet-frontrvis-2cam-ph16-ex08-emb384-l12-ms50k-l20g1-20260405-102029/train_vla.log",
    "launch_log": "/home/droid/roboimi_suite_20260404/experiment_suite_launch_logs/imf-resnet-frontrvis-2cam-ph16-ex08-emb384-l12-ms50k-l20g1-20260405-102029.launch.log",
    "dataset_dir": "/home/droid/sim_dataset/sim_transfer",
    "camera_names": [
      "r_vis",
      "front"
    ],
    "pred_horizon": 16,
    "num_action_steps": 8,
    "head_cond_dim": 144,
    "head_n_emb": 384,
    "head_n_layer": 12,
    "vision_backbone_mode": "resnet",
    "pretrained_backbone_weights": null,
    "freeze_backbone": false,
    "batch_size": 80,
    "lr": 0.00025,
    "num_workers": 12,
    "max_steps": 50000,
    "rollout_val_freq_epochs": 5,
    "rollout_num_episodes": 5,
    "swanlab_url": "https://swanlab.cn/@game-loader/roboimi-vla/runs/3fyzjfdcbiq7frtbqv6ss",
    "latest_step": 200,
    "latest_loss": 0.3321,
    "process_running": true
  }
}
@@ -0,0 +1,20 @@
{
  "suite_name": "2026-04-05-full-attnres-vision-phase2",
  "created_at": "2026-04-05 00:12:14",
  "phase": "phase2_running",
  "baseline_reference": {
    "run_id": "ph16_ex8",
    "best_rollout_avg_reward": 610.8,
    "best_step": 21874,
    "run_dir": "/home/droid/roboimi_suite_20260404/runs/imf-p1-ph16-ex08-emb384-l12-ms50k-5880g1-20260404-131223"
  },
  "candidate": {
    "run_name": "imf-p2-full-attnres-vision-ph16-ex08-emb384-l12-ms50k-20260405-001214",
    "host": "local",
    "gpu": 0,
    "pred_horizon": 16,
    "num_action_steps": 8,
    "vision_backbone_mode": "attnres_resnet",
    "notes": "Full-AttnRes vision backbone replacing ResNet residual units; IMF head unchanged."
  }
}
@@ -0,0 +1,9 @@
# Full-AttnRes Vision Phase-2

- Created: 2026-04-05 00:12:14
- Baseline reference: `ph16_ex8`, best avg_reward=610.8
- Candidate run: `imf-p2-full-attnres-vision-ph16-ex08-emb384-l12-ms50k-20260405-001214`
- 2026-04-05 00:23:03: batch=80 OOMs on both the 5090 and the L20; using the validated fallback batch=40, lr=1.25e-4 on remote L20 GPU3.
- 2026-04-05 00:24:24: launching candidate `imf-p2-full-attnres-vision-ph16-ex08-emb384-l12-b40-lr1p25e4-ms50k-l20g3-20260405-002424` on 100.119.99.14 GPU3 with batch=40, lr=1.25e-4.
- 2026-04-05 00:27:17: remote phase2 run is active on 100.119.99.14 GPU3, validated at least to step 200. SwanLab: https://swanlab.cn/@game-loader/roboimi-vla/runs/xy7fjdmn0stdr19eu3gub
- 2026-04-05 00:36:54: latest confirmed progress is step 1300 on 100.119.99.14 GPU3; the first rollout has not been reached yet.
@@ -0,0 +1,32 @@
{
  "suite_name": "2026-04-05-full-attnres-vision-phase2",
  "updated_at": "2026-04-05 00:36:54",
  "phase": "phase2_running",
  "baseline_reference": {
    "run_id": "ph16_ex8",
    "best_rollout_avg_reward": 610.8,
    "best_step": 21874,
    "run_dir": "/home/droid/roboimi_suite_20260404/runs/imf-p1-ph16-ex08-emb384-l12-ms50k-5880g1-20260404-131223"
  },
  "candidate": {
    "run_name": "imf-p2-full-attnres-vision-ph16-ex08-emb384-l12-b40-lr1p25e4-ms50k-l20g3-20260405-002424",
    "host": "100.119.99.14",
    "gpu": 3,
    "pred_horizon": 16,
    "num_action_steps": 8,
    "vision_backbone_mode": "attnres_resnet",
    "notes": "Full-AttnRes vision backbone replacing ResNet residual units; IMF head unchanged.",
    "status": "running",
    "run_dir": "/home/droid/roboimi_suite_20260404/runs/imf-p2-full-attnres-vision-ph16-ex08-emb384-l12-b40-lr1p25e4-ms50k-l20g3-20260405-002424",
    "log_path": "/home/droid/roboimi_suite_20260404/runs/imf-p2-full-attnres-vision-ph16-ex08-emb384-l12-b40-lr1p25e4-ms50k-l20g3-20260405-002424/train_vla.log",
    "pid": 151187,
    "batch_size": 40,
    "lr": 0.000125,
    "num_workers": 12,
    "launch_log": "/home/droid/roboimi_suite_20260404/experiment_suite_launch_logs/imf-p2-full-attnres-vision-ph16-ex08-emb384-l12-b40-lr1p25e4-ms50k-l20g3-20260405-002424.launch.log",
    "note": "Local 5090 and remote L20 both OOM at batch=80; switched to batch=40 and linearly scaled lr to 1.25e-4 after smoke validation on L20.",
    "swanlab_url": "https://swanlab.cn/@game-loader/roboimi-vla/runs/xy7fjdmn0stdr19eu3gub",
    "latest_step": 1300,
    "latest_log_sync": "2026-04-05 00:36:54"
  }
}
73
experiment_suites/2026-04-05-lewm-vit-transfer/manifest.json
Normal file
@@ -0,0 +1,73 @@
{
  "date": "2026-04-06",
  "branch": "feat-imf-attnres-policy",
  "worktree": "/home/droid/project/roboimi/.worktrees/feat-imf-attnres-policy",
  "model": "LEWM ViT frozen visual encoder + IMF AttnRes diffusion head",
  "checkpoint_path": "/home/droid/le-wm/lewm-sim-transfer/pa1w85md8jop6bvol8oxp/checkpoints/epoch=99-step=47800.ckpt",
  "visual_contract": {
    "input_camera_names": ["r_vis", "top", "front"],
    "fused_camera_names": ["front", "top", "r_vis"],
    "joint_output_dim": 192,
    "freeze_backbone": true,
    "dataset_image_resize_shape": null,
    "eval_image_resize_shape": [256, 256],
    "fused_short_side_resize": 224
  },
  "training_contract": {
    "pred_horizon": 16,
    "num_action_steps": 8,
    "max_steps": 50000,
    "rollout_val_freq_epochs": 5,
    "rollout_num_episodes": 10,
    "batch_size": 80,
    "lr": 0.00025,
    "num_workers": 12,
    "scheduler_type": "cosine",
    "warmup_steps": 2000,
    "min_lr": 1e-06,
    "weight_decay": 1e-05,
    "grad_clip": 1.0
  },
  "verification": {
    "local_tests": "38 passed",
    "remote_dataset_shape": [2, 3, 256, 256],
    "remote_eval_prepared_shape": [3, 256, 256],
    "remote_smoke_run": {
      "run_name": "smoke-lewm-imf-rawpath-emb384-20260406-002002",
      "result": "passed",
      "details": "2-step train + checkpoint-triggered 1-episode headless rollout succeeded with corrected raw256 path"
    }
  },
  "superseded_runs": [
    {
      "run_name": "lewm-vit-imf-sim-transfer-emb384-l12-ph16-ex08-step50k-roll10-5880g0-20260405-201914",
      "reason": "stopped due to incorrect early per-camera 224 resize"
    },
    {
      "run_name": "lewm-vit-imf-sim-transfer-emb256-l12-ph16-ex08-step50k-roll10-5880g1-20260405-201914",
      "reason": "stopped due to incorrect early per-camera 224 resize"
    }
  ],
  "full_runs": [
    {
      "host": "100.73.14.65",
      "gpu": 0,
      "run_name": "lewm-vit-imf-raw256fix-sim-transfer-emb384-l12-ph16-ex08-step50k-roll10-5880g0-20260406-002124",
      "pid": 1058589,
      "log_path": "/home/droid/roboimi_suite_20260404/experiment_suite_launch_logs/lewm-vit-imf-raw256fix-sim-transfer-emb384-l12-ph16-ex08-step50k-roll10-5880g0-20260406-002124.launch.log",
      "swanlab_url": "https://swanlab.cn/@game-loader/roboimi-vla/runs/y5tzgqe0u966w9ak41i31",
      "head_n_emb": 384,
      "head_n_layer": 12
    },
    {
      "host": "100.73.14.65",
      "gpu": 1,
      "run_name": "lewm-vit-imf-raw256fix-sim-transfer-emb256-l12-ph16-ex08-step50k-roll10-5880g1-20260406-002124",
      "pid": 1058590,
      "log_path": "/home/droid/roboimi_suite_20260404/experiment_suite_launch_logs/lewm-vit-imf-raw256fix-sim-transfer-emb256-l12-ph16-ex08-step50k-roll10-5880g1-20260406-002124.launch.log",
      "swanlab_url": "https://swanlab.cn/@game-loader/roboimi-vla/runs/2esr9y7t2dgesstgrn5i6",
      "head_n_emb": 256,
      "head_n_layer": 12
    }
  ]
}
25
experiment_suites/2026-04-05-lewm-vit-transfer/notes.md
Normal file
@@ -0,0 +1,25 @@
# 2026-04-06 LEWM ViT Transfer Notes

## Root-cause fix

The first LEWM runs were stopped because the data path still resized each camera view to `224x224` **before** multiview fusion. That preserved the final tensor shape but broke the original LEWM geometry.

The corrected path is now:

- **Training dataset**: keep stored per-view `256x256` images (`data.image_resize_shape=null` at launch; the dataset instantiate override is `None` for LEWM)
- **Eval rollout input**: resize live MuJoCo `480x640` camera images to `256x256` per view
- **Backbone**: fuse `front, top, r_vis` on the LEWM axis, then resize the fused short side to `224`

## Verification

- Local tests passed (`38 passed` across the focused suite)
- Remote check:
  - dataset sample image shape: `(2, 3, 256, 256)`
  - eval-prepared live frame shape: `(3, 256, 256)`
- Remote smoke passed with real checkpoint:
  - `smoke-lewm-imf-rawpath-emb384-20260406-002002`

## Current runs

- `lewm-vit-imf-raw256fix-sim-transfer-emb384-l12-ph16-ex08-step50k-roll10-5880g0-20260406-002124`
- `lewm-vit-imf-raw256fix-sim-transfer-emb256-l12-ph16-ex08-step50k-roll10-5880g1-20260406-002124`
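The corrected resize order in these notes can be sketched end-to-end. This is a minimal numpy sketch under stated assumptions: the helper names (`prepare_live_view`, `fuse_then_resize`, `resize_nn`) are illustrative, not repo functions, and a naive nearest-neighbor resize stands in for the real `cv2.resize`:

```python
import numpy as np

def resize_nn(img, out_hw):
    """Naive nearest-neighbor resize (stand-in for cv2.resize in the real path)."""
    h, w = img.shape[:2]
    oh, ow = out_hw
    rows = np.arange(oh) * h // oh
    cols = np.arange(ow) * w // ow
    return img[rows][:, cols]

def prepare_live_view(img, per_view_hw=(256, 256)):
    # Eval rollout input: a live 480x640 MuJoCo frame -> 256x256 per view.
    return resize_nn(img, per_view_hw)

def fuse_then_resize(views, short_side=224):
    # Fuse the views on one axis (width here, as an illustration of the
    # LEWM fusion axis), then resize so the fused short side becomes 224.
    fused = np.concatenate(views, axis=1)
    h, w = fused.shape[:2]
    scale = short_side / min(h, w)
    return resize_nn(fused, (round(h * scale), round(w * scale)))
```

The key point the notes make is the order: per-view resize to the stored `256x256` first, fusion second, and only then the short-side-224 resize of the fused image.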
19
experiment_suites/2026-04-05-lewm-vit-transfer/status.json
Normal file
@@ -0,0 +1,19 @@
{
  "status": "running",
  "updated_at": "2026-04-06T00:22:10+08:00",
  "remote_host": "100.73.14.65",
  "runs": [
    {
      "run_name": "lewm-vit-imf-raw256fix-sim-transfer-emb384-l12-ph16-ex08-step50k-roll10-5880g0-20260406-002124",
      "pid": 1058589,
      "gpu": 0,
      "state": "running"
    },
    {
      "run_name": "lewm-vit-imf-raw256fix-sim-transfer-emb256-l12-ph16-ex08-step50k-roll10-5880g1-20260406-002124",
      "pid": 1058590,
      "gpu": 1,
      "state": "running"
    }
  ]
}
@@ -0,0 +1,7 @@
# CHECKLIST

- [x] Create run contract
- [x] Remote smoke test passes
- [x] Launch 50k main run
- [x] Record pid / log / SwanLab
- [x] Report status back to user
12
experiment_suites/2026-04-05-rvis-only-resnet-1cam/PLAN.md
Normal file
@@ -0,0 +1,12 @@
# PLAN

## Goal

Train a 50k-step IMF baseline with the original ResNet vision backbone, using r_vis as the only image conditioning.

## Fixed comparison contract

- same hyperparameters as the active top/front run
- cameras: ['r_vis']
- num_cams=1
- head.cond_dim=80
- host: 100.119.99.14
- gpu: 3
@@ -0,0 +1,6 @@
# Notes

- 2026-04-05 12:58:22: smoke passed for ['r_vis'] on 100.119.99.14 GPU3.
- 2026-04-05 12:59:24: launched main run `imf-resnet-rvis-1cam-ph16-ex08-emb384-l12-ms50k-l20g3-20260405-125844`.
- 2026-04-05 13:01:20: latest confirmed progress step=400, loss=0.1165.
- SwanLab: https://swanlab.cn/@game-loader/roboimi-vla/runs/qnuh7vln9mqomxxldyecq
@@ -0,0 +1,47 @@
{
  "suite_name": "2026-04-05-rvis-only-resnet-1cam",
  "updated_at": "2026-04-05 13:01:20",
  "phase": "running",
  "smoke_test": {
    "status": "passed",
    "host": "100.119.99.14",
    "gpu": 3,
    "run_dir": "/home/droid/roboimi_suite_20260404/runs/smoke-rvisonly-resnet-ph16-ex08-20260405-125812",
    "batch_size": 80,
    "max_steps": 2,
    "note": "2-step remote CUDA smoke passed without OOM."
  },
  "main_run": {
    "status": "running",
    "host": "100.119.99.14",
    "gpu": 3,
    "launch_pid": 164812,
    "pid": 164816,
    "run_name": "imf-resnet-rvis-1cam-ph16-ex08-emb384-l12-ms50k-l20g3-20260405-125844",
    "run_dir": "/home/droid/roboimi_suite_20260404/runs/imf-resnet-rvis-1cam-ph16-ex08-emb384-l12-ms50k-l20g3-20260405-125844",
    "log_path": "/home/droid/roboimi_suite_20260404/runs/imf-resnet-rvis-1cam-ph16-ex08-emb384-l12-ms50k-l20g3-20260405-125844/train_vla.log",
    "launch_log": "/home/droid/roboimi_suite_20260404/experiment_suite_launch_logs/imf-resnet-rvis-1cam-ph16-ex08-emb384-l12-ms50k-l20g3-20260405-125844.launch.log",
    "dataset_dir": "/home/droid/sim_dataset/sim_transfer",
    "camera_names": ["r_vis"],
    "pred_horizon": 16,
    "num_action_steps": 8,
    "head_cond_dim": 80,
    "head_n_emb": 384,
    "head_n_layer": 12,
    "vision_backbone_mode": "resnet",
    "pretrained_backbone_weights": null,
    "freeze_backbone": false,
    "batch_size": 80,
    "lr": 0.00025,
    "num_workers": 12,
    "max_steps": 50000,
    "rollout_val_freq_epochs": 5,
    "rollout_num_episodes": 5,
    "swanlab_url": "https://swanlab.cn/@game-loader/roboimi-vla/runs/qnuh7vln9mqomxxldyecq",
    "latest_step": 400,
    "latest_loss": 0.1165,
    "process_running": true
  }
}
@@ -0,0 +1,7 @@
# CHECKLIST

- [x] Create run contract
- [x] Remote smoke test passes
- [x] Launch 50k main run
- [x] Record pid / log / SwanLab
- [x] Report status back to user
12
experiment_suites/2026-04-05-rvistop-resnet-2cam/PLAN.md
Normal file
@@ -0,0 +1,12 @@
# PLAN

## Goal

Train a 50k-step IMF baseline with the original ResNet vision backbone, using r_vis + top as the only image conditioning.

## Fixed comparison contract

- same hyperparameters as the active top/front run
- cameras: ['r_vis', 'top']
- num_cams=2
- head.cond_dim=144
- host: 100.119.99.14
- gpu: 2
@@ -0,0 +1,6 @@
# Notes

- 2026-04-05 12:58:22: smoke passed for ['r_vis', 'top'] on 100.119.99.14 GPU2.
- 2026-04-05 12:59:24: launched main run `imf-resnet-rvistop-2cam-ph16-ex08-emb384-l12-ms50k-l20g2-20260405-125844`.
- 2026-04-05 13:01:20: latest confirmed progress step=200, loss=0.2845.
- SwanLab: https://swanlab.cn/@game-loader/roboimi-vla/runs/umsm6402eb81et7wx7z4a
48
experiment_suites/2026-04-05-rvistop-resnet-2cam/status.json
Normal file
@@ -0,0 +1,48 @@
{
  "suite_name": "2026-04-05-rvistop-resnet-2cam",
  "updated_at": "2026-04-05 13:01:20",
  "phase": "running",
  "smoke_test": {
    "status": "passed",
    "host": "100.119.99.14",
    "gpu": 2,
    "run_dir": "/home/droid/roboimi_suite_20260404/runs/smoke-rvistop-resnet-ph16-ex08-20260405-125812",
    "batch_size": 80,
    "max_steps": 2,
    "note": "2-step remote CUDA smoke passed without OOM."
  },
  "main_run": {
    "status": "running",
    "host": "100.119.99.14",
    "gpu": 2,
    "launch_pid": 164745,
    "pid": 164749,
    "run_name": "imf-resnet-rvistop-2cam-ph16-ex08-emb384-l12-ms50k-l20g2-20260405-125844",
    "run_dir": "/home/droid/roboimi_suite_20260404/runs/imf-resnet-rvistop-2cam-ph16-ex08-emb384-l12-ms50k-l20g2-20260405-125844",
    "log_path": "/home/droid/roboimi_suite_20260404/runs/imf-resnet-rvistop-2cam-ph16-ex08-emb384-l12-ms50k-l20g2-20260405-125844/train_vla.log",
    "launch_log": "/home/droid/roboimi_suite_20260404/experiment_suite_launch_logs/imf-resnet-rvistop-2cam-ph16-ex08-emb384-l12-ms50k-l20g2-20260405-125844.launch.log",
    "dataset_dir": "/home/droid/sim_dataset/sim_transfer",
    "camera_names": ["r_vis", "top"],
    "pred_horizon": 16,
    "num_action_steps": 8,
    "head_cond_dim": 144,
    "head_n_emb": 384,
    "head_n_layer": 12,
    "vision_backbone_mode": "resnet",
    "pretrained_backbone_weights": null,
    "freeze_backbone": false,
    "batch_size": 80,
    "lr": 0.00025,
    "num_workers": 12,
    "max_steps": 50000,
    "rollout_val_freq_epochs": 5,
    "rollout_num_episodes": 5,
    "swanlab_url": "https://swanlab.cn/@game-loader/roboimi-vla/runs/umsm6402eb81et7wx7z4a",
    "latest_step": 200,
    "latest_loss": 0.2845,
    "process_running": true
  }
}
@@ -0,0 +1,8 @@
# CHECKLIST

- [x] Confirm baseline hyperparameters from trusted prior run
- [x] Confirm local GPU availability
- [x] Smoke test with `top/front` cameras only
- [x] Launch 50k run
- [x] Record pid / run dir / log path / SwanLab URL
- [x] Report status back to user
30
experiment_suites/2026-04-05-top-front-resnet-2cam/PLAN.md
Normal file
@@ -0,0 +1,30 @@
# PLAN

## Goal

Train a 50k-step IMF baseline with the original ResNet vision backbone (no full-AttnRes vision replacement), using only `top` and `front` cameras as image conditioning.

## Fixed comparison contract

- Agent: `resnet_imf_attnres`
- Vision backbone mode: `resnet`
- `pred_horizon=16`
- `num_action_steps=8`
- `n_emb=384`, `n_layer=12`, `n_head=1`, `n_kv_head=1`
- `inference_steps=1`
- `batch_size=80`, `lr=2.5e-4`, cosine scheduler, warmup 2000
- dataset: `/home/droid/project/diana_sim/sim_transfer`
- cameras: `[top, front]` only
- training budget: `max_steps=50000`
- rollout validation: every 5 epochs, 5 episodes, headless

## Resource plan

- Host: local
- GPU: RTX 5090 (GPU 0)

## Execution path

1. Run a short 2-step smoke test on GPU with the exact 2-camera config.
2. If smoke passes, launch the 50k main run with durable log redirection.
3. Record run name, pid, log path, and SwanLab URL into suite status.

## Fallbacks

- If batch 80 OOMs, fall back to batch 64 with scaled lr 2.0e-4.
- If dataloader startup is unstable, reduce num_workers from 12 to 8.
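The OOM fallback keeps the lr-per-sample ratio constant: `2.5e-4 / 80 = 2.0e-4 / 64`. A quick sketch of that linear scaling rule (`scaled_lr` is an illustrative helper, not a repo function):

```python
def scaled_lr(base_lr, base_batch, new_batch):
    """Linear learning-rate scaling: keep lr/batch constant across batch sizes."""
    return base_lr * new_batch / base_batch

# Fallback from batch 80 @ 2.5e-4 to batch 64:
print(scaled_lr(2.5e-4, 80, 64))  # 0.0002
```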
@@ -0,0 +1,5 @@
# Notes

- 2026-04-05 08:50:04: 2-step smoke test passed locally on RTX 5090 with `top/front` cameras, batch=80, no OOM.
- 2026-04-05 08:50:42: launched main run `imf-resnet-topfront-2cam-ph16-ex08-emb384-l12-ms50k-5090-20260405-085023` on local GPU0.
- SwanLab: https://swanlab.cn/@game-loader/roboimi-vla/runs/vi77mn5dwd19z4nttxab8
@@ -0,0 +1,51 @@
{
  "suite_name": "2026-04-05-top-front-resnet-2cam",
  "updated_at": "2026-04-05 08:52:12",
  "phase": "running",
  "baseline_reference": {
    "source_run": "imf-p1-ph16-ex08-emb384-l12-ms50k-5880g1-20260404-131223",
    "best_rollout_avg_reward": 610.8,
    "best_step": 21874,
    "notes": "Same IMF baseline as Phase-1 best, but switch cameras from [r_vis, top, front] to [top, front] and keep the original ResNet vision backbone."
  },
  "smoke_test": {
    "status": "passed",
    "run_dir": "/home/droid/project/roboimi/.worktrees/feat-imf-attnres-policy/runs/smoke-topfront-resnet-ph16-ex08-20260405-085000",
    "batch_size": 80,
    "num_workers": 4,
    "max_steps": 2,
    "note": "2-step local CUDA smoke passed without OOM using top/front only."
  },
  "main_run": {
    "status": "running",
    "host": "local",
    "gpu": 0,
    "pid": 1693348,
    "run_name": "imf-resnet-topfront-2cam-ph16-ex08-emb384-l12-ms50k-5090-20260405-085023",
    "run_dir": "/home/droid/project/roboimi/.worktrees/feat-imf-attnres-policy/runs/imf-resnet-topfront-2cam-ph16-ex08-emb384-l12-ms50k-5090-20260405-085023",
    "log_path": "/home/droid/project/roboimi/.worktrees/feat-imf-attnres-policy/runs/imf-resnet-topfront-2cam-ph16-ex08-emb384-l12-ms50k-5090-20260405-085023/train_vla.log",
    "launch_log": "/home/droid/project/roboimi/.worktrees/feat-imf-attnres-policy/experiment_suites/2026-04-05-top-front-resnet-2cam/launch_logs/imf-resnet-topfront-2cam-ph16-ex08-emb384-l12-ms50k-5090-20260405-085023.launch.log",
    "dataset_dir": "/home/droid/project/diana_sim/sim_transfer",
    "camera_names": ["top", "front"],
    "pred_horizon": 16,
    "num_action_steps": 8,
    "head_n_emb": 384,
    "head_n_layer": 12,
    "vision_backbone_mode": "resnet",
    "pretrained_backbone_weights": null,
    "freeze_backbone": false,
    "batch_size": 80,
    "lr": 0.00025,
    "num_workers": 12,
    "max_steps": 50000,
    "rollout_val_freq_epochs": 5,
    "rollout_num_episodes": 5,
    "swanlab_url": "https://swanlab.cn/@game-loader/roboimi-vla/runs/vi77mn5dwd19z4nttxab8",
    "latest_step": 500,
    "latest_loss": 0.0978,
    "process_running": true
  }
}
@@ -0,0 +1,7 @@
# CHECKLIST

- [x] Create run contract
- [x] Remote smoke test passes
- [x] Launch 50k main run
- [x] Record pid / log / SwanLab
- [x] Report status back to user
12
experiment_suites/2026-04-05-top-only-resnet-1cam/PLAN.md
Normal file
@@ -0,0 +1,12 @@
# PLAN

## Goal

Train a 50k-step IMF baseline with the original ResNet vision backbone, using top as the only image conditioning.

## Fixed comparison contract

- same hyperparameters as the active top/front run
- cameras: ['top']
- num_cams=1
- head.cond_dim=80
- host: 100.119.99.14
- gpu: 4
@@ -0,0 +1,6 @@
# Notes

- 2026-04-05 12:58:22: smoke passed for ['top'] on 100.119.99.14 GPU4.
- 2026-04-05 12:59:24: launched main run `imf-resnet-top-1cam-ph16-ex08-emb384-l12-ms50k-l20g4-20260405-125844`.
- 2026-04-05 13:01:20: latest confirmed progress step=400, loss=0.1233.
- SwanLab: https://swanlab.cn/@game-loader/roboimi-vla/runs/egzo29l3z9ftsaunhf025
@@ -0,0 +1,47 @@
{
  "suite_name": "2026-04-05-top-only-resnet-1cam",
  "updated_at": "2026-04-05 13:01:20",
  "phase": "running",
  "smoke_test": {
    "status": "passed",
    "host": "100.119.99.14",
    "gpu": 4,
    "run_dir": "/home/droid/roboimi_suite_20260404/runs/smoke-toponly-resnet-ph16-ex08-20260405-125812",
    "batch_size": 80,
    "max_steps": 2,
    "note": "2-step remote CUDA smoke passed without OOM."
  },
  "main_run": {
    "status": "running",
    "host": "100.119.99.14",
    "gpu": 4,
    "launch_pid": 164808,
    "pid": 164813,
    "run_name": "imf-resnet-top-1cam-ph16-ex08-emb384-l12-ms50k-l20g4-20260405-125844",
    "run_dir": "/home/droid/roboimi_suite_20260404/runs/imf-resnet-top-1cam-ph16-ex08-emb384-l12-ms50k-l20g4-20260405-125844",
    "log_path": "/home/droid/roboimi_suite_20260404/runs/imf-resnet-top-1cam-ph16-ex08-emb384-l12-ms50k-l20g4-20260405-125844/train_vla.log",
    "launch_log": "/home/droid/roboimi_suite_20260404/experiment_suite_launch_logs/imf-resnet-top-1cam-ph16-ex08-emb384-l12-ms50k-l20g4-20260405-125844.launch.log",
    "dataset_dir": "/home/droid/sim_dataset/sim_transfer",
    "camera_names": ["top"],
    "pred_horizon": 16,
    "num_action_steps": 8,
    "head_cond_dim": 80,
    "head_n_emb": 384,
    "head_n_layer": 12,
    "vision_backbone_mode": "resnet",
    "pretrained_backbone_weights": null,
    "freeze_backbone": false,
    "batch_size": 80,
    "lr": 0.00025,
    "num_workers": 12,
    "max_steps": 50000,
    "rollout_val_freq_epochs": 5,
    "rollout_num_episodes": 5,
    "swanlab_url": "https://swanlab.cn/@game-loader/roboimi-vla/runs/egzo29l3z9ftsaunhf025",
    "latest_step": 400,
    "latest_loss": 0.1233,
    "process_running": true
  }
}
@@ -25,7 +25,6 @@ from omegaconf import DictConfig, OmegaConf
 from hydra.utils import instantiate
 from einops import rearrange

-from roboimi.envs.double_pos_ctrl_env import make_sim_env
 from roboimi.utils.act_ex_utils import sample_transfer_pose
 from roboimi.vla.eval_utils import execute_policy_action

@@ -37,6 +36,20 @@ if not OmegaConf.has_resolver("len"):
     OmegaConf.register_new_resolver("len", lambda x: len(x))


+def _configure_headless_mujoco_gl(eval_cfg: DictConfig) -> None:
+    if not bool(eval_cfg.get('headless', False)):
+        return
+    if os.environ.get('MUJOCO_GL'):
+        return
+    os.environ['MUJOCO_GL'] = 'egl'
+    log.info('headless eval detected; set MUJOCO_GL=egl')
+
+
+def make_sim_env(task_name: str, headless: bool = False):
+    from roboimi.envs.double_pos_ctrl_env import make_sim_env as _make_sim_env_impl
+    return _make_sim_env_impl(task_name, headless=headless)
+
+
 def load_checkpoint(
     ckpt_path: str,
     agent_cfg: DictConfig,
@@ -93,7 +106,11 @@ def load_checkpoint(
     return agent, stats


-def prepare_observation(obs: Dict, camera_names: list) -> Dict:
+def prepare_observation(
+    obs: Dict,
+    camera_names: list,
+    image_resize_shape: Optional[tuple[int, int]] = (224, 224),
+) -> Dict:
     """
     Convert the environment observation into the agent's format.

@@ -104,14 +121,13 @@ def prepare_observation(obs: Dict, camera_names: list) -> Dict:
     Returns:
         The observation dict in agent format.
     """
-    import cv2
-
     # Convert images: numpy -> tensor, HWC -> CHW
     images = {}
     for cam_name in camera_names:
         img = obs['images'][cam_name]
-        # Resize to 224x224 (consistent with training)
-        img = cv2.resize(img, (224, 224), interpolation=cv2.INTER_LINEAR)
+        if image_resize_shape is not None:
+            import cv2
+            img = cv2.resize(img, tuple(image_resize_shape), interpolation=cv2.INTER_LINEAR)
         img = rearrange(img, 'h w c -> c h w')
         img = torch.from_numpy(img / 255.0).float()
         images[cam_name] = img
@@ -189,10 +205,12 @@ def _resolve_artifact_paths(eval_cfg: DictConfig) -> dict[str, Optional[str]]:
     save_trajectory = bool(
         eval_cfg.get('save_trajectory', False) or eval_cfg.get('save_trajectory_npz', False)
     )
+    save_trajectory_image = bool(eval_cfg.get('save_trajectory_image', False))
     wants_artifacts = any([
         bool(eval_cfg.get('save_artifacts', False)),
         save_timing,
         save_trajectory,
+        save_trajectory_image,
         bool(eval_cfg.get('record_video', False)),
     ])
     output_dir: Optional[Path] = None
@@ -218,6 +236,22 @@ def _resolve_artifact_paths(eval_cfg: DictConfig) -> dict[str, Optional[str]]:
         else:
             raise ValueError('record_video=true requires eval.video_camera_name or a non-empty eval.camera_names')

+    trajectory_image_camera_name = None
+    if save_trajectory_image:
+        configured_camera_name = eval_cfg.get('trajectory_image_camera_name', None)
+        if configured_camera_name is None:
+            configured_camera_name = eval_cfg.get('trajectory_image_camera', None)
+        if configured_camera_name is not None:
+            trajectory_image_camera_name = str(configured_camera_name)
+        elif eval_cfg.get('camera_names'):
+            camera_names = [str(name) for name in eval_cfg.camera_names]
+            trajectory_image_camera_name = 'front' if 'front' in camera_names else camera_names[0]
+        else:
+            raise ValueError(
+                'save_trajectory_image=true requires eval.trajectory_image_camera_name '
+                'or a non-empty eval.camera_names'
+            )
+
     return {
         'output_dir': str(output_dir) if output_dir is not None else None,
         'summary_json': (
@@ -242,6 +276,7 @@ def _resolve_artifact_paths(eval_cfg: DictConfig) -> dict[str, Optional[str]]:
             else None
         ),
         'video_camera_name': video_camera_name,
+        'trajectory_image_camera_name': trajectory_image_camera_name,
     }


@@ -270,6 +305,116 @@ def _open_video_writer(output_path: str, frame_size: tuple[int, int], fps: int):
     return writer


+def _episode_trajectory_image_path(
+    artifact_paths: dict[str, Optional[str]],
+    episode_idx: int,
+) -> Optional[str]:
+    output_dir = artifact_paths.get('output_dir')
+    camera_name = artifact_paths.get('trajectory_image_camera_name')
+    if output_dir is None or camera_name is None:
+        return None
+    return str(Path(output_dir) / f'rollout_{camera_name}_ep{episode_idx + 1:02d}_trajectory.png')
+
+
+def _build_action_trajectory_positions(raw_actions: list[np.ndarray]) -> dict[str, np.ndarray]:
+    if not raw_actions:
+        empty = np.zeros((0, 3), dtype=np.float32)
+        return {'left': empty, 'right': empty}
+    raw_action_array = np.asarray(raw_actions, dtype=np.float32)
+    return {
+        'left': raw_action_array[:, :3].astype(np.float32, copy=True),
+        'right': raw_action_array[:, 7:10].astype(np.float32, copy=True),
+    }
+
+
+def _append_capsule_markers_to_scene(scene, markers: list[dict]) -> None:
+    import mujoco
+
+    for marker in markers:
+        if scene.ngeom >= scene.maxgeom:
+            break
+        geom = scene.geoms[scene.ngeom]
+        mujoco.mjv_initGeom(
+            geom,
+            mujoco.mjtGeom.mjGEOM_CAPSULE,
+            np.zeros(3, dtype=np.float64),
+            np.zeros(3, dtype=np.float64),
+            np.eye(3, dtype=np.float64).reshape(-1),
+            np.asarray(marker['rgba'], dtype=np.float32),
+        )
+        mujoco.mjv_connector(
+            geom,
+            mujoco.mjtGeom.mjGEOM_CAPSULE,
+            float(marker['radius']),
+            np.asarray(marker['from'], dtype=np.float64),
+            np.asarray(marker['to'], dtype=np.float64),
+        )
+        scene.ngeom += 1
+
+
+def _save_rollout_trajectory_image(
+    env,
+    output_path: Optional[str],
+    raw_actions: list[np.ndarray],
+    camera_name: Optional[str],
+    *,
+    line_radius: float = 0.004,
+    max_markers: int = 1500,
+) -> Optional[str]:
+    if output_path is None or camera_name is None:
+        return None
+
+    # IMPORTANT:
+    # Keep this import lazy so headless rollout can set MUJOCO_GL=egl before
+    # anything imports mujoco. Importing this helper at module import time would
+    # pull in mujoco too early on remote headless hosts and make rollout fail
+    # with gladLoadGL / missing DISPLAY errors.
+    from roboimi.utils.raw_action_trajectory_viewer import build_trajectory_capsule_markers
+
+    output_path = str(output_path)
+    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
+
+    frame = None
+    owned_renderer = None
+    positions = _build_action_trajectory_positions(raw_actions)
+    markers = build_trajectory_capsule_markers(
+        positions,
+        max_markers=max_markers,
+        radius=line_radius,
+    )
+
+    try:
+        renderer = None
+        if callable(getattr(env, '_get_or_create_offscreen_renderer', None)):
+            renderer = env._get_or_create_offscreen_renderer()
+        elif hasattr(env, 'mj_model') and hasattr(env, 'mj_data'):
+            import mujoco
+
+            renderer = mujoco.Renderer(env.mj_model, height=480, width=640)
+            owned_renderer = renderer
+
+        if renderer is not None and hasattr(env, 'mj_data'):
+            renderer.update_scene(env.mj_data, camera=str(camera_name))
+            if markers:
+                _append_capsule_markers_to_scene(renderer.scene, markers)
+            frame = renderer.render()[:, :, ::-1]
+    finally:
+        if owned_renderer is not None:
+            owned_renderer.close()
+
+    if frame is None and callable(getattr(env, '_get_image_obs', None)):
+        obs = env._get_image_obs()
+        frame = _get_video_frame(obs, str(camera_name))
+
+    if frame is None:
+        return None
+
+    import cv2
+
+    cv2.imwrite(output_path, frame)
+    return output_path
+
+
 class _RolloutVideoRecorder:
     def __init__(self, output_path: Optional[str], fps: int):
         self.output_path = output_path
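The per-episode naming scheme used by the trajectory-image helper in this hunk (1-based, zero-padded episode index) can be illustrated standalone. This reproduces the same f-string format outside the diff for clarity; the function name here is illustrative:

```python
from pathlib import Path
from typing import Optional

def episode_trajectory_image_path(output_dir: Optional[str],
                                  camera_name: Optional[str],
                                  episode_idx: int) -> Optional[str]:
    """Mirror the per-episode image naming: rollout_<cam>_ep<NN>_trajectory.png."""
    if output_dir is None or camera_name is None:
        return None
    return str(Path(output_dir) / f'rollout_{camera_name}_ep{episode_idx + 1:02d}_trajectory.png')

print(episode_trajectory_image_path('/tmp/run', 'front', 0))
# /tmp/run/rollout_front_ep01_trajectory.png
```

Returning `None` when either the output directory or the camera name is unset mirrors the guard in the diff, so callers can pass the result straight through to the image saver.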
@@ -501,6 +646,7 @@ def _run_eval(cfg: DictConfig):
     print("=" * 80)

     eval_cfg = cfg.eval
+    _configure_headless_mujoco_gl(eval_cfg)
     device = eval_cfg.device
     camera_names = list(eval_cfg.camera_names)
     artifact_paths = _resolve_artifact_paths(eval_cfg)
@@ -525,6 +671,8 @@ def _run_eval(cfg: DictConfig):
         agent_cfg=cfg.agent,
         device=device
     )
+    vision_encoder = getattr(agent, 'vision_encoder', None)
+    image_resize_shape = getattr(vision_encoder, 'eval_image_resize_shape', (224, 224))

     # Reset the agent's action queues
     agent.reset()
@@ -566,6 +714,7 @@ def _run_eval(cfg: DictConfig):
         model_forward_flags = []
         episode_reward = 0.0
         episode_max_reward = float('-inf')
+        episode_raw_actions: list[np.ndarray] = []

         with torch.inference_mode():
             for t in tqdm(range(eval_cfg.max_timesteps), desc=f"Episode {episode_idx + 1}"):
@@ -581,7 +730,11 @@ def _run_eval(cfg: DictConfig):
                 video_recorder.write(video_frame)

                 # Prepare the observation for the agent
-                observation = prepare_observation(obs, camera_names)
+                observation = prepare_observation(
+                    obs,
+                    camera_names,
+                    image_resize_shape=image_resize_shape,
+                )
                 end_preprocess = time.perf_counter()

                 # Select an action (the agent manages its queues internally)
@@ -596,6 +749,7 @@ def _run_eval(cfg: DictConfig):
                 # Convert to numpy
                 raw_action = _to_numpy_action(action)
+                episode_raw_actions.append(raw_action.astype(np.float32, copy=True))

                 # Debug: print the action at the current timestep (config-controlled)
                 if eval_cfg.get('verbose_action', False):
@@ -680,6 +834,12 @@ def _run_eval(cfg: DictConfig):
         episode_artifact_paths = {
             'video': artifact_paths['video_mp4'],
             'trajectory': artifact_paths['trajectory_npz'],
+            'trajectory_image': _save_rollout_trajectory_image(
+                env,
+                _episode_trajectory_image_path(artifact_paths, episode_idx),
+                episode_raw_actions,
+                artifact_paths['trajectory_image_camera_name'],
+            ),
             'timing': artifact_paths['timing_json'] or artifact_paths['summary_json'],
         }
@@ -14,8 +14,58 @@ from torch.optim import AdamW
 from torch.optim.lr_scheduler import LambdaLR
 from pathlib import Path
 
-# Make sure the import path is correct
-sys.path.append(os.getcwd())
+# Make sure the import path is correct (cannot rely on cwd: Hydra switches cwd at runtime)
+def _ensure_repo_root_on_syspath():
+    repo_root = Path(__file__).resolve().parents[3]
+    repo_root_str = str(repo_root)
+    if repo_root_str in sys.path:
+        sys.path.remove(repo_root_str)
+    sys.path.insert(0, repo_root_str)
+    return repo_root
+
+
+_PROBLEMATIC_LD_PRELOAD_SUBSTRINGS = ('/usr/NX/lib/libnxegl.so', 'libnxegl.so')
+
+
+def _clean_ld_preload_value(value: str | None):
+    if not value:
+        return value, False
+    entries = [entry for entry in value.split() if entry]
+    filtered = [
+        entry for entry in entries
+        if not any(marker in entry for marker in _PROBLEMATIC_LD_PRELOAD_SUBSTRINGS)
+    ]
+    changed = filtered != entries
+    cleaned = ' '.join(filtered) if filtered else None
+    return cleaned, changed
+
+
+def _maybe_reexec_without_problematic_ld_preload():
+    if __name__ != '__main__':
+        return False
+    if os.environ.get('_ROBOIMI_LD_PRELOAD_SANITIZED') == '1':
+        return False
+
+    cleaned, changed = _clean_ld_preload_value(os.environ.get('LD_PRELOAD'))
+    if not changed:
+        return False
+
+    new_env = dict(os.environ)
+    new_env['_ROBOIMI_LD_PRELOAD_SANITIZED'] = '1'
+    if cleaned:
+        new_env['LD_PRELOAD'] = cleaned
+    else:
+        new_env.pop('LD_PRELOAD', None)
+
+    print(
+        'Detected problematic LD_PRELOAD entry for CUDA/cuDNN; re-executing train_vla.py without it.',
+        file=sys.stderr,
+        flush=True,
+    )
+    os.execvpe(sys.executable, [sys.executable, *sys.argv], new_env)
+
+
+_REPO_ROOT = _ensure_repo_root_on_syspath()
+
 from hydra.utils import instantiate
 
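The LD_PRELOAD sanitising added in this hunk is a pure string filter, so it can be checked in isolation. A minimal standalone sketch (logic copied from the hunk; the `/opt/lib/libfoo.so` path is made up for illustration):

```python
# Standalone sketch of the LD_PRELOAD filter from the hunk above.
# Entries matching a problematic substring are dropped; the second tuple
# element reports whether anything changed.
_PROBLEMATIC = ('/usr/NX/lib/libnxegl.so', 'libnxegl.so')

def clean_ld_preload(value):
    if not value:
        return value, False
    entries = [entry for entry in value.split() if entry]
    filtered = [
        entry for entry in entries
        if not any(marker in entry for marker in _PROBLEMATIC)
    ]
    changed = filtered != entries
    return (' '.join(filtered) if filtered else None), changed

print(clean_ld_preload('/usr/NX/lib/libnxegl.so /opt/lib/libfoo.so'))
# -> ('/opt/lib/libfoo.so', True)
```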
@@ -26,6 +76,28 @@ if not OmegaConf.has_resolver("len"):
 OmegaConf.register_new_resolver("len", lambda x: len(x))
 
 
+def _resolve_run_output_dir() -> Path:
+    try:
+        from hydra.core.hydra_config import HydraConfig
+        if HydraConfig.initialized():
+            output_dir = HydraConfig.get().runtime.output_dir
+            if output_dir:
+                return Path(output_dir).resolve()
+    except Exception:
+        pass
+    return Path.cwd().resolve()
+
+
+_maybe_reexec_without_problematic_ld_preload()
+
+
+def _configure_cuda_runtime(cfg):
+    """Apply process-level CUDA runtime switches required by this environment."""
+    if str(cfg.train.device).startswith('cuda') and bool(cfg.train.get('disable_cudnn', False)):
+        torch.backends.cudnn.enabled = False
+        log.warning('⚠️ cuDNN disabled per config; GPU convolutions will fall back to non-cuDNN implementations')
+
+
 def recursive_to_device(data, device):
     """
     Recursively move tensors inside nested dicts/lists to the given device.
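The behaviour of `_resolve_run_output_dir` can be sketched with the Hydra lookup factored out into a parameter (the `hydra_output_dir` argument below is a stand-in for `HydraConfig.get().runtime.output_dir`; `/tmp/hydra_run` is a made-up path):

```python
from pathlib import Path

# Stand-in for _resolve_run_output_dir with the Hydra lookup injected as a
# parameter: an initialised Hydra run wins, otherwise fall back to cwd.
def resolve_run_output_dir(hydra_output_dir=None):
    if hydra_output_dir:
        return Path(hydra_output_dir).resolve()
    return Path.cwd().resolve()

print(resolve_run_output_dir('/tmp/hydra_run'))
print(resolve_run_output_dir())
```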
@@ -113,14 +185,11 @@ def get_lr_schedule_with_warmup(optimizer, warmup_steps, max_steps, scheduler_ty
 
 
 def build_training_optimizer(agent, lr, weight_decay):
-    """Build the optimizer for the training script, preferring parameter groups supplied by the transformer head."""
+    """Build the optimizer for the training script, preferring parameter groups supplied by any head."""
     trainable_params = [param for param in agent.parameters() if param.requires_grad]
     noise_pred_net = getattr(agent, 'noise_pred_net', None)
     get_optim_groups = getattr(noise_pred_net, 'get_optim_groups', None)
-    use_head_groups = (
-        getattr(agent, 'head_type', None) == 'transformer'
-        and callable(get_optim_groups)
-    )
+    use_head_groups = callable(get_optim_groups)
 
     if not use_head_groups:
         return AdamW(trainable_params, lr=lr, weight_decay=weight_decay)
@@ -138,7 +207,7 @@ def build_training_optimizer(agent, lr, weight_decay):
     for param in params:
         param_id = id(param)
         if param_id in grouped_param_ids:
-            raise ValueError('Transformer optimizer groups contain duplicate parameters')
+            raise ValueError('Head optimizer groups contain duplicate parameters')
         grouped_param_ids.add(param_id)
 
     head_trainable_param_ids = {
@@ -146,7 +215,7 @@ def build_training_optimizer(agent, lr, weight_decay):
     }
     missing_head_param_ids = head_trainable_param_ids - grouped_param_ids
     if missing_head_param_ids:
-        raise ValueError('Transformer optimizer groups missed trainable head parameters')
+        raise ValueError('Head optimizer groups missed trainable head parameters')
 
     remaining_params = [
         param for param in trainable_params
@@ -230,6 +299,45 @@ def _log_to_swanlab(swanlab_module, payload, step=None):
         log.warning(f"SwanLab log failed at step {step}: {exc}")
 
 
+def _log_rollout_trajectory_images_to_swanlab(
+    swanlab_module,
+    rollout_stats,
+    step=None,
+    context_label: str = 'rollout',
+):
+    if swanlab_module is None or not rollout_stats:
+        return
+
+    image_factory = getattr(swanlab_module, 'Image', None)
+    if image_factory is None:
+        return
+
+    payload = {}
+    for fallback_episode_index, episode in enumerate(rollout_stats.get('episodes', [])):
+        if not isinstance(episode, dict):
+            continue
+        artifact_paths = episode.get('artifact_paths', {})
+        if not isinstance(artifact_paths, dict):
+            continue
+        trajectory_image = artifact_paths.get('trajectory_image')
+        if not trajectory_image:
+            continue
+        episode_index = int(episode.get('episode_index', fallback_episode_index))
+        caption = f'{context_label} trajectory image - episode {episode_index} (front)'
+        try:
+            payload[f'rollout/trajectory_image_episode_{episode_index}'] = image_factory(
+                str(trajectory_image),
+                caption=caption,
+            )
+        except Exception as exc:
+            log.warning(
+                f"SwanLab rollout trajectory image upload prep failed at step {step}: {exc}"
+            )
+
+    if payload:
+        _log_to_swanlab(swanlab_module, payload, step=step)
+
+
 def _finish_swanlab(swanlab_module):
     if swanlab_module is None:
         return
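Because the helper above only duck-types the SwanLab module (it just needs an `Image` attribute), the payload shape it builds can be exercised without SwanLab installed. A sketch of the payload-building loop with a fake image factory (`FakeImage` and the episode paths are hypothetical; key and caption formats mirror the hunk):

```python
class FakeImage:  # stands in for swanlab.Image
    def __init__(self, path, caption=None):
        self.path, self.caption = path, caption

rollout_stats = {
    'episodes': [
        {'episode_index': 0, 'artifact_paths': {'trajectory_image': '/tmp/ep0.png'}},
        {'episode_index': 1, 'artifact_paths': {}},  # no trajectory image -> skipped
    ],
}

payload = {}
for fallback_idx, episode in enumerate(rollout_stats['episodes']):
    trajectory_image = episode.get('artifact_paths', {}).get('trajectory_image')
    if not trajectory_image:
        continue
    idx = int(episode.get('episode_index', fallback_idx))
    payload[f'rollout/trajectory_image_episode_{idx}'] = FakeImage(
        str(trajectory_image),
        caption=f'rollout trajectory image - episode {idx} (front)',
    )

print(sorted(payload))  # -> ['rollout/trajectory_image_episode_0']
```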
@@ -258,11 +366,13 @@ def _run_training(cfg: DictConfig):
     print("=" * 80)
 
     log.info(f"🚀 Starting VLA training (device: {cfg.train.device})")
+    _configure_cuda_runtime(cfg)
     swanlab_module = _init_swanlab(cfg)
     try:
         # Create the checkpoint directory
-        checkpoint_dir = Path("checkpoints")
-        checkpoint_dir.mkdir(exist_ok=True)
+        run_output_dir = _resolve_run_output_dir()
+        checkpoint_dir = run_output_dir / "checkpoints"
+        checkpoint_dir.mkdir(parents=True, exist_ok=True)
         default_best_model_path = checkpoint_dir / "vla_model_best.pt"
 
         # =========================================================================
@@ -270,7 +380,14 @@ def _run_training(cfg: DictConfig):
         # =========================================================================
         log.info("📦 Loading dataset...")
         try:
-            dataset = instantiate(cfg.data)
+            dataset_image_resize_shape = cfg.data.get('image_resize_shape', (224, 224))
+            vision_backbone_cfg = cfg.agent.get('vision_backbone', None)
+            if vision_backbone_cfg is not None and 'dataset_image_resize_shape' in vision_backbone_cfg:
+                dataset_image_resize_shape = vision_backbone_cfg.get('dataset_image_resize_shape')
+            dataset = instantiate(
+                cfg.data,
+                image_resize_shape=dataset_image_resize_shape,
+            )
             log.info(f"✅ Dataset loaded. Total samples: {len(dataset)}")
         except Exception as e:
             log.error(f"❌ Dataset loading failed: {e}")
@@ -590,6 +707,13 @@ def _run_training(cfg: DictConfig):
         rollout_cfg.eval.headless = True
         rollout_cfg.eval.device = 'cpu'
         rollout_cfg.eval.verbose_action = False
+        rollout_cfg.eval.record_video = False
+        rollout_cfg.eval.save_trajectory_image = True
+        rollout_cfg.eval.trajectory_image_camera_name = 'front'
+        rollout_cfg.eval.save_summary_json = True
+        rollout_cfg.eval.artifact_dir = str(
+            (run_output_dir / 'rollout_artifacts' / checkpoint_path.stem).resolve()
+        )
 
         log.info(
             "🎯 Starting checkpoint rollout validation: %s (episodes=%s, headless=True)",
@@ -796,6 +920,12 @@ def _run_training(cfg: DictConfig):
             },
             step=step,
         )
+        _log_rollout_trajectory_images_to_swanlab(
+            swanlab_module,
+            rollout_stats,
+            step=step,
+            context_label=f'epoch {completed_epoch} rollout',
+        )
         if rollout_avg_reward > best_rollout_reward:
             best_rollout_reward = rollout_avg_reward
             best_model_path = default_best_model_path
|||||||
@@ -57,6 +57,7 @@ class DualDianaMed(MujocoEnv):
|
|||||||
self.obs = None
|
self.obs = None
|
||||||
|
|
||||||
self.rew = None
|
self.rew = None
|
||||||
|
self._offscreen_renderer = None
|
||||||
|
|
||||||
|
|
||||||
def actuate_J(self, q_target, qdot_target, Arm):
|
def actuate_J(self, q_target, qdot_target, Arm):
|
||||||
@@ -161,6 +162,8 @@ class DualDianaMed(MujocoEnv):
 
 
     def _get_obs(self):
+        if not self.is_render:
+            self._update_camera_images_sync()
         obs = collections.OrderedDict()
         obs['qpos'] = self.get_obs_qpos
         obs['action'] = self.compute_qpos
@@ -173,6 +176,8 @@ class DualDianaMed(MujocoEnv):
         return obs
 
     def _get_image_obs(self):
+        if not self.is_render:
+            self._update_camera_images_sync()
         obs = collections.OrderedDict()
         obs['images'] = dict()
         obs['images']['top'] = self.top
@@ -211,27 +216,36 @@ class DualDianaMed(MujocoEnv):
         raise AttributeError("please input right name")
 
 
+    def _get_or_create_offscreen_renderer(self):
+        renderer = getattr(self, '_offscreen_renderer', None)
+        if renderer is None:
+            renderer = mj.Renderer(self.mj_model, height=480, width=640)
+            self._offscreen_renderer = renderer
+        return renderer
+
+    def _render_camera_set(self, img_renderer):
+        img_renderer.update_scene(self.mj_data, camera="rs_cam_right")
+        self.r_vis = img_renderer.render()[:, :, ::-1]
+        img_renderer.update_scene(self.mj_data, camera="rs_cam_left")
+        self.l_vis = img_renderer.render()[:, :, ::-1]
+        img_renderer.update_scene(self.mj_data, camera="top")
+        self.top = img_renderer.render()[:, :, ::-1]
+        img_renderer.update_scene(self.mj_data, camera="angle")
+        self.angle = img_renderer.render()[:, :, ::-1]
+        img_renderer.update_scene(self.mj_data, camera="front")
+        self.front = img_renderer.render()[:, :, ::-1]
+
+    def _update_camera_images_sync(self):
+        img_renderer = self._get_or_create_offscreen_renderer()
+        self._render_camera_set(img_renderer)
+
     def camera_viewer(self):
-        img_renderer = mj.Renderer(self.mj_model,height=480,width=640)
+        img_renderer = self._get_or_create_offscreen_renderer()
         show_gui = self.is_render
         if show_gui:
             cv2.namedWindow('Cam view',cv2.WINDOW_NORMAL)
         while not self.exit_flag:
-            img_renderer.update_scene(self.mj_data,camera="rs_cam_right")
-            self.r_vis = img_renderer.render()
-            self.r_vis = self.r_vis[:, :, ::-1]
-            img_renderer.update_scene(self.mj_data,camera="rs_cam_left")
-            self.l_vis = img_renderer.render()
-            self.l_vis = self.l_vis[:, :, ::-1]
-            img_renderer.update_scene(self.mj_data,camera="top")
-            self.top = img_renderer.render()
-            self.top = self.top[:, :, ::-1]
-            img_renderer.update_scene(self.mj_data,camera="angle")
-            self.angle = img_renderer.render()
-            self.angle = self.angle[:, :, ::-1]
-            img_renderer.update_scene(self.mj_data,camera="front")
-            self.front = img_renderer.render()
-            self.front = self.front[:, :, ::-1]
+            self._render_camera_set(img_renderer)
             if show_gui:
                 if self.cam_view is not None:
                     cv2.imshow('Cam view', self.cam_view)
@@ -239,6 +253,9 @@ class DualDianaMed(MujocoEnv):
 
 
     def cam_start(self):
+        if not self.is_render:
+            self.cam_thread = None
+            return
         self.cam_thread = threading.Thread(target=self.camera_viewer,daemon=True)
         self.cam_thread.start()
 
|||||||
@@ -76,6 +76,9 @@ class DualDianaMed_Pos_Ctrl(DualDianaMed):
|
|||||||
self.angle = None
|
self.angle = None
|
||||||
self.r_vis = None
|
self.r_vis = None
|
||||||
self.front = None
|
self.front = None
|
||||||
|
if not self.is_render:
|
||||||
|
self._update_camera_images_sync()
|
||||||
|
return
|
||||||
self.cam_flage = True
|
self.cam_flage = True
|
||||||
t=0
|
t=0
|
||||||
while self.cam_flage:
|
while self.cam_flage:
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ class VLAAgent(nn.Module):
|
|||||||
normalization_type='min_max', # 归一化类型: 'gaussian' 或 'min_max'
|
normalization_type='min_max', # 归一化类型: 'gaussian' 或 'min_max'
|
||||||
num_action_steps=8, # 每次推理实际执行多少步动作
|
num_action_steps=8, # 每次推理实际执行多少步动作
|
||||||
head_type='unet', # Policy head类型: 'unet' 或 'transformer'
|
head_type='unet', # Policy head类型: 'unet' 或 'transformer'
|
||||||
|
cond_projector=None, # 可选:将视觉+状态条件投影到head期望维度
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
# 保存参数
|
# 保存参数
|
||||||
@@ -74,15 +75,32 @@ class VLAAgent(nn.Module):
         self.vision_encoder = vision_backbone
         if self.camera_names is not None:
             self.vision_encoder.camera_names = self.camera_names
-        single_cam_feat_dim = self.vision_encoder.output_dim
-        # global_cond_dim: flattened total dimension (for the UNet)
-        total_vision_dim = single_cam_feat_dim * num_cams * obs_horizon
-        total_prop_dim = obs_dim * obs_horizon
-        self.global_cond_dim = total_vision_dim + total_prop_dim
-
-        # per_step_cond_dim: per-step condition dimension (for the Transformer)
-        # note: obs_horizon is not multiplied in here, since the Transformer consumes a sequence
-        self.per_step_cond_dim = single_cam_feat_dim * num_cams + obs_dim
+        self.condition_tokens_per_step = int(getattr(self.vision_encoder, 'tokens_per_step', 1))
+        joint_vision_dim = getattr(self.vision_encoder, 'joint_output_dim', None)
+        if joint_vision_dim is not None:
+            per_token_vision_dim = int(joint_vision_dim)
+            self.condition_tokens_per_step = 1
+        else:
+            single_cam_feat_dim = self.vision_encoder.output_dim
+            if self.condition_tokens_per_step > 1:
+                per_token_vision_dim = int(single_cam_feat_dim)
+            else:
+                per_token_vision_dim = int(single_cam_feat_dim) * int(num_cams)
+
+        self.condition_sequence_length = self.obs_horizon * self.condition_tokens_per_step
+        self.raw_per_step_cond_dim = per_token_vision_dim + obs_dim
+        if cond_projector is None:
+            self.cond_projector = None
+            self.per_step_cond_dim = self.raw_per_step_cond_dim
+        else:
+            if isinstance(cond_projector, nn.Module):
+                self.cond_projector = cond_projector
+            else:
+                self.cond_projector = cond_projector(input_dim=self.raw_per_step_cond_dim)
+            self.per_step_cond_dim = self._projector_output_dim(self.cond_projector, self.raw_per_step_cond_dim)
+
+        # global_cond_dim: flattened total dimension (for the UNet)
+        self.global_cond_dim = self.per_step_cond_dim * self.condition_sequence_length
 
         self.noise_scheduler = DDPMScheduler(
             num_train_timesteps=diffusion_steps,
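The condition bookkeeping above reduces to simple arithmetic; a worked example with hypothetical sizes (3 cameras, one token per camera, a 64-dim per-camera visual feature, 16-dim proprioception, no `cond_projector`):

```python
# Hypothetical sizes; only the arithmetic mirrors the constructor above.
obs_horizon, obs_dim, num_cams = 2, 16, 3
single_cam_feat_dim = 64              # assumed per-camera feature size
condition_tokens_per_step = num_cams  # one token per camera

# tokens_per_step > 1 branch: each token carries one camera's features
per_token_vision_dim = single_cam_feat_dim
condition_sequence_length = obs_horizon * condition_tokens_per_step
raw_per_step_cond_dim = per_token_vision_dim + obs_dim

# no cond_projector -> per_step_cond_dim == raw_per_step_cond_dim
global_cond_dim = raw_per_step_cond_dim * condition_sequence_length
print(raw_per_step_cond_dim, condition_sequence_length, global_cond_dim)  # -> 80 6 480
```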
@@ -111,7 +129,7 @@ class VLAAgent(nn.Module):
                 input_dim=action_dim,
                 output_dim=action_dim,
                 horizon=pred_horizon,
-                n_obs_steps=obs_horizon,
+                n_obs_steps=self.condition_sequence_length,
                 cond_dim=self.per_step_cond_dim  # per-step condition dimension
             )
         else:  # 'unet' (default)
@@ -143,6 +161,20 @@ class VLAAgent(nn.Module):
             return tuple(self._move_to_device(v, device) for v in data)
         return data
 
+    @staticmethod
+    def _projector_output_dim(projector: nn.Module, fallback: int) -> int:
+        output_dim = getattr(projector, 'output_dim', None)
+        if output_dim is not None:
+            return int(output_dim)
+        out_features = getattr(projector, 'out_features', None)
+        if out_features is not None:
+            return int(out_features)
+        linear = getattr(projector, 'linear', None)
+        linear_out_features = getattr(linear, 'out_features', None)
+        if linear_out_features is not None:
+            return int(linear_out_features)
+        return int(fallback)
+
     def _order_images(self, images: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
         """Return the image dict in the explicitly configured camera order."""
         if self.camera_names is None:
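`_projector_output_dim` probes `output_dim`, then `out_features`, then a nested `linear.out_features`, before falling back. A torch-free sketch of the same probing order with duck-typed stand-ins (the three classes are hypothetical):

```python
def projector_output_dim(projector, fallback):
    # Same probing order as _projector_output_dim in the hunk above.
    for name in ('output_dim', 'out_features'):
        value = getattr(projector, name, None)
        if value is not None:
            return int(value)
    linear = getattr(projector, 'linear', None)
    value = getattr(linear, 'out_features', None)
    return int(value) if value is not None else int(fallback)

class CustomProjector:   # exposes output_dim directly
    output_dim = 384

class LinearLike:        # e.g. nn.Linear exposes out_features
    out_features = 256

class Opaque:            # nothing recognisable -> fallback wins
    pass

print(projector_output_dim(CustomProjector(), 80),
      projector_output_dim(LinearLike(), 80),
      projector_output_dim(Opaque(), 80))  # -> 384 256 80
```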
@@ -165,7 +197,43 @@ class VLAAgent(nn.Module):
         ordered_images = self._order_images(images)
         visual_features = self.vision_encoder(ordered_images)
         state_features = self.state_encoder(states)
+        if visual_features.ndim == 4:
+            batch_size, obs_steps, token_count, _ = visual_features.shape
+            if obs_steps != state_features.shape[1]:
+                raise RuntimeError(
+                    f"Observation time dimensions do not match: visual={obs_steps}, state={state_features.shape[1]}"
+                )
+            if token_count != self.condition_tokens_per_step:
+                raise RuntimeError(
+                    f"Condition token count mismatch: got {token_count}, expected {self.condition_tokens_per_step}"
+                )
+            state_features = state_features.unsqueeze(2).expand(-1, -1, token_count, -1)
+            cond = torch.cat([visual_features, state_features], dim=-1)
+            if cond.shape[-1] != self.raw_per_step_cond_dim:
+                raise RuntimeError(
+                    f"Raw condition dimension mismatch: got {cond.shape[-1]}, expected {self.raw_per_step_cond_dim}"
+                )
+            if self.cond_projector is not None:
+                cond = self.cond_projector(cond)
+            if cond.shape[-1] != self.per_step_cond_dim:
+                raise RuntimeError(
+                    f"Condition dimension mismatch: got {cond.shape[-1]}, expected {self.per_step_cond_dim}"
+                )
+            cond = cond.reshape(batch_size, obs_steps * token_count, self.per_step_cond_dim)
+            expected_length = self.condition_sequence_length
+            if cond.shape[1] != expected_length:
+                raise RuntimeError(
+                    f"Condition sequence length mismatch: got {cond.shape[1]}, expected {expected_length}"
+                )
+            return cond
+
         cond = torch.cat([visual_features, state_features], dim=-1)
+        if cond.shape[-1] != self.raw_per_step_cond_dim:
+            raise RuntimeError(
+                f"Raw condition dimension mismatch: got {cond.shape[-1]}, expected {self.raw_per_step_cond_dim}"
+            )
+        if self.cond_projector is not None:
+            cond = self.cond_projector(cond)
         if cond.shape[-1] != self.per_step_cond_dim:
             raise RuntimeError(
                 f"Condition dimension mismatch: got {cond.shape[-1]}, expected {self.per_step_cond_dim}"
roboimi/vla/agent_imf.py (new file, 161 lines)
@@ -0,0 +1,161 @@
+from __future__ import annotations
+
+from contextlib import nullcontext
+from typing import Dict, Optional
+
+import torch
+import torch.nn.functional as F
+
+from roboimi.vla.agent import VLAAgent
+
+try:
+    from torch.func import jvp as TORCH_FUNC_JVP
+except ImportError:  # pragma: no cover
+    TORCH_FUNC_JVP = None
+
+
+class IMFVLAAgent(VLAAgent):
+    def __init__(self, *args, inference_steps: int = 1, **kwargs):
+        if inference_steps != 1:
+            raise ValueError(
+                'IMFVLAAgent only supports one-step inference; '
+                f'inference_steps must be 1, got {inference_steps}.'
+            )
+        super().__init__(*args, inference_steps=inference_steps, **kwargs)
+        self.inference_steps = 1
+
+    @staticmethod
+    def _broadcast_batch_time(value: torch.Tensor, reference: torch.Tensor) -> torch.Tensor:
+        while value.ndim < reference.ndim:
+            value = value.unsqueeze(-1)
+        return value
+
+    @staticmethod
+    def _apply_conditioning(
+        trajectory: torch.Tensor,
+        condition_data: Optional[torch.Tensor] = None,
+        condition_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        if condition_data is None or condition_mask is None:
+            return trajectory
+        conditioned = trajectory.clone()
+        conditioned[condition_mask] = condition_data[condition_mask]
+        return conditioned
+
+    @staticmethod
+    def _jvp_math_sdp_context(z_t: torch.Tensor):
+        if z_t.is_cuda:
+            return torch.backends.cuda.sdp_kernel(
+                enable_flash=False,
+                enable_math=True,
+                enable_mem_efficient=False,
+                enable_cudnn=False,
+            )
+        return nullcontext()
+
+    @staticmethod
+    def _jvp_tangents(v: torch.Tensor, r: torch.Tensor, t: torch.Tensor):
+        return v.detach(), torch.zeros_like(r), torch.ones_like(t)
+
+    def fn(self, z: torch.Tensor, r: torch.Tensor, t: torch.Tensor, cond=None) -> torch.Tensor:
+        return self.noise_pred_net(z, r, t, cond=cond)
+
+    def _compute_u_and_du_dt(
+        self,
+        z_t: torch.Tensor,
+        r: torch.Tensor,
+        t: torch.Tensor,
+        cond,
+        v: torch.Tensor,
+        condition_data: Optional[torch.Tensor] = None,
+        condition_mask: Optional[torch.Tensor] = None,
+    ):
+        tangents = self._jvp_tangents(v, r, t)
+
+        def g(z, r_value, t_value):
+            conditioned_z = self._apply_conditioning(z, condition_data, condition_mask)
+            return self.fn(conditioned_z, r_value, t_value, cond=cond)
+
+        with self._jvp_math_sdp_context(z_t):
+            if TORCH_FUNC_JVP is not None:
+                try:
+                    return TORCH_FUNC_JVP(g, (z_t, r, t), tangents)
+                except (RuntimeError, TypeError, NotImplementedError):
+                    pass
+
+            u = g(z_t, r, t)
+            _, du_dt = torch.autograd.functional.jvp(
+                g,
+                (z_t, r, t),
+                tangents,
+                create_graph=False,
+                strict=False,
+            )
+            return u, du_dt
+
+    def _compound_velocity(
+        self,
+        u: torch.Tensor,
+        du_dt: torch.Tensor,
+        r: torch.Tensor,
+        t: torch.Tensor,
+    ) -> torch.Tensor:
+        delta = self._broadcast_batch_time(t - r, u)
+        return u + delta * du_dt.detach()
+
+    def _sample_one_step(
+        self,
+        z_t: torch.Tensor,
+        r: Optional[torch.Tensor] = None,
+        t: Optional[torch.Tensor] = None,
+        cond=None,
+    ) -> torch.Tensor:
+        batch_size = z_t.shape[0]
+        if t is None:
+            t = torch.ones(batch_size, device=z_t.device, dtype=z_t.dtype)
+        if r is None:
+            r = torch.zeros(batch_size, device=z_t.device, dtype=z_t.dtype)
+        u = self.fn(z_t, r, t, cond=cond)
+        delta = self._broadcast_batch_time(t - r, z_t)
+        return z_t - delta * u
+
+    def compute_loss(self, batch):
+        actions, states, images = batch['action'], batch['qpos'], batch['images']
+        action_is_pad = batch.get('action_is_pad', None)
+        batch_size = actions.shape[0]
+
+        states = self.normalization.normalize_qpos(states)
+        actions = self.normalization.normalize_action(actions)
+        cond = self._build_cond(images, states)
+
+        x = actions
+        e = torch.randn_like(x)
+        t = torch.rand(batch_size, device=x.device, dtype=x.dtype)
+        r = torch.rand(batch_size, device=x.device, dtype=x.dtype)
+        t, r = torch.maximum(t, r), torch.minimum(t, r)
+
+        t_broadcast = self._broadcast_batch_time(t, x)
+        z_t = (1 - t_broadcast) * x + t_broadcast * e
+
+        v = self.fn(z_t, t, t, cond=cond)
+        u, du_dt = self._compute_u_and_du_dt(z_t, r, t, cond=cond, v=v)
+        V = self._compound_velocity(u, du_dt, r, t)
+        target = e - x
+
+        loss = F.mse_loss(V, target, reduction='none')
+        if action_is_pad is not None:
+            mask = (~action_is_pad).unsqueeze(-1).to(loss.dtype)
+            valid_count = mask.sum() * loss.shape[-1]
+            loss = (loss * mask).sum() / valid_count.clamp_min(1.0)
+        else:
+            loss = loss.mean()
+        return loss
+
+    @torch.no_grad()
+    def predict_action(self, images, proprioception):
+        batch_size = proprioception.shape[0]
+        proprioception = self.normalization.normalize_qpos(proprioception)
+        cond = self._build_cond(images, proprioception)
+        z_t = torch.randn((batch_size, self.pred_horizon, self.action_dim), device=cond.device, dtype=cond.dtype)
+        action = self._sample_one_step(z_t, cond=cond)
+        return self.normalization.denormalize_action(action)
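A scalar sanity check of the one-step sampler above: under the interpolation `z_t = (1 - t)·x + t·e` and the training target `e - x`, sampling from pure noise at `(r, t) = (0, 1)` with a perfect velocity prediction recovers the clean action. (Scalars stand in for action tensors here; this is an illustration of the math, not the agent's code path.)

```python
x, e = 0.3, -1.2             # clean action and noise sample (illustrative scalars)
t, r = 1.0, 0.0
z_t = (1 - t) * x + t * e    # pure noise at t = 1
u = e - x                    # ideal velocity (the training target)
x_hat = z_t - (t - r) * u    # one-step sample, as in _sample_one_step
print(round(x_hat, 6))       # -> 0.3
```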
roboimi/vla/conf/agent/lewm_imf_attnres.yaml (new file, 41 lines)
@@ -0,0 +1,41 @@
+# @package agent
+defaults:
+  - /backbone@vision_backbone: lewm_vit_diffusion
+  - /modules@state_encoder: identity_state_encoder
+  - /modules@action_encoder: identity_action_encoder
+  - /head: imf_transformer1d
+  - _self_
+
+_target_: roboimi.vla.agent_imf.IMFVLAAgent
+
+action_dim: 16
+obs_dim: 16
+normalization_type: "min_max"
+pred_horizon: 16
+obs_horizon: 2
+num_action_steps: 8
+camera_names: ${data.camera_names}
+num_cams: 3
+
+vision_backbone:
+  num_cameras: ${agent.num_cams}
+  camera_names: ${agent.camera_names}
+  fused_camera_names: [front, top, r_vis]
+
+diffusion_steps: 100
+inference_steps: 1
+head_type: "transformer"
+
+head:
+  input_dim: ${agent.action_dim}
+  output_dim: ${agent.action_dim}
+  horizon: ${agent.pred_horizon}
+  n_obs_steps: ${agent.obs_horizon}
+  cond_dim: 208
+  causal_attn: false
+  time_as_cond: true
+  obs_as_cond: true
+  n_cond_layers: 0
+  backbone_type: attnres_full
+  n_head: 1
+  n_kv_head: 1
roboimi/vla/conf/agent/resnet_imf_attnres.yaml (new file, 40 lines)
@@ -0,0 +1,40 @@
+# @package agent
+defaults:
+  - /backbone@vision_backbone: resnet_diffusion
+  - /modules@state_encoder: identity_state_encoder
+  - /modules@action_encoder: identity_action_encoder
+  - /head: imf_transformer1d
+  - _self_
+
+_target_: roboimi.vla.agent_imf.IMFVLAAgent
+
+action_dim: 16
+obs_dim: 16
+normalization_type: "min_max"
+pred_horizon: 16
+obs_horizon: 2
+num_action_steps: 8
+camera_names: ${data.camera_names}
+num_cams: 3
+
+vision_backbone:
+  num_cameras: ${agent.num_cams}
+  camera_names: ${agent.camera_names}
+
+diffusion_steps: 100
+inference_steps: 1
+head_type: "transformer"
+
+head:
+  input_dim: ${agent.action_dim}
+  output_dim: ${agent.action_dim}
+  horizon: ${agent.pred_horizon}
+  n_obs_steps: ${agent.obs_horizon}
+  cond_dim: 208
+  causal_attn: false
+  time_as_cond: true
+  obs_as_cond: true
+  n_cond_layers: 0
+  backbone_type: attnres_full
+  n_head: 1
+  n_kv_head: 1
roboimi/vla/conf/agent/resnet_imf_attnres_multitoken.yaml (new file, 48 lines)
@@ -0,0 +1,48 @@
+# @package agent
+defaults:
+  - /backbone@vision_backbone: resnet_diffusion
+  - /modules@state_encoder: identity_state_encoder
+  - /modules@action_encoder: identity_action_encoder
+  - /modules@cond_projector: linear_condition_projector
+  - /head: imf_transformer1d
+  - _self_
+
+_target_: roboimi.vla.agent_imf.IMFVLAAgent
+
+action_dim: 16
+obs_dim: 16
+normalization_type: "min_max"
+pred_horizon: 16
+obs_horizon: 2
+num_action_steps: 8
+camera_names: ${data.camera_names}
+num_cams: ${len:${agent.camera_names}}
+
+vision_backbone:
+  num_cameras: ${agent.num_cams}
+  camera_names: ${agent.camera_names}
+  vision_backbone: "resnet18"
+  vision_backbone_mode: "resnet"
+  freeze_backbone: false
+  use_separate_rgb_encoder_per_camera: true
+  output_tokens_per_camera: true
+
+cond_projector:
+  output_dim: ${agent.head.n_emb}
+
+diffusion_steps: 100
+inference_steps: 1
+head_type: "transformer"
+
+head:
+  input_dim: ${agent.action_dim}
+  output_dim: ${agent.action_dim}
+  horizon: ${agent.pred_horizon}
+  cond_dim: ${agent.head.n_emb}
+  causal_attn: false
+  time_as_cond: true
+  obs_as_cond: true
+  n_cond_layers: 0
+  backbone_type: attnres_full
+  n_head: 1
+  n_kv_head: 1
roboimi/vla/conf/agent/siglip2_imf_attnres.yaml (new file, +44)
@@ -0,0 +1,44 @@
# @package agent
defaults:
  - /backbone@vision_backbone: siglip2_diffusion
  - /modules@state_encoder: identity_state_encoder
  - /modules@action_encoder: identity_action_encoder
  - /modules@cond_projector: linear_condition_projector
  - /head: imf_transformer1d
  - _self_

_target_: roboimi.vla.agent_imf.IMFVLAAgent

action_dim: 16
obs_dim: 16
normalization_type: "min_max"
pred_horizon: 16
obs_horizon: 2
num_action_steps: 8
camera_names: ${data.camera_names}
num_cams: ${len:${agent.camera_names}}

vision_backbone:
  num_cameras: ${agent.num_cams}
  camera_names: ${agent.camera_names}

cond_projector:
  output_dim: ${agent.head.cond_dim}

diffusion_steps: 100
inference_steps: 1
head_type: "transformer"

head:
  input_dim: ${agent.action_dim}
  output_dim: ${agent.action_dim}
  horizon: ${agent.pred_horizon}
  n_obs_steps: ${agent.obs_horizon}
  cond_dim: 384
  causal_attn: false
  time_as_cond: true
  obs_as_cond: true
  n_cond_layers: 0
  backbone_type: attnres_full
  n_head: 1
  n_kv_head: 1
roboimi/vla/conf/backbone/lewm_vit_diffusion.yaml (new file, +16)
@@ -0,0 +1,16 @@
_target_: roboimi.vla.models.backbones.lewm_vit_backbone.LEWMViTBackbone

# LEWM checkpoint path; override this on the target machine.
checkpoint_path: null

# Input camera contract for roboimi; internal LEWM fusion order stays front/top/r_vis.
num_cameras: 3
camera_names: [r_vis, top, front]
fused_camera_names: [front, top, r_vis]

freeze_backbone: true
joint_output_dim: 192
output_dim: 192
image_size: 224
dataset_image_resize_shape: null
eval_image_resize_shape: [256, 256]
@@ -5,6 +5,7 @@ _target_: roboimi.vla.models.backbones.resnet_diffusion.ResNetDiffusionBackbone
 # ====================
 vision_backbone: "resnet18"                   # torchvision model name: resnet18, resnet34, resnet50
 pretrained_backbone_weights: "IMAGENET1K_V1"  # ImageNet pretrained weights (torchvision>=0.13)
+vision_backbone_mode: "resnet"                # resnet | attnres_resnet

 # ====================
 # Freeze settings
@@ -30,4 +31,20 @@ spatial_softmax_num_keypoints: 32  # number of Spatial Softmax keypoints
 # false: shared encoder (all cameras share one ResNet; fewer parameters but limited capacity) recommended!
 # true: separate encoders (each camera gets its own ResNet; more parameters but more capacity)
 use_separate_rgb_encoder_per_camera: true
+# false: concatenate all camera features into a single condition token; true: one token per camera
+output_tokens_per_camera: false
 num_cameras: 3  # number of cameras
+
+# ====================
+# Full-AttnRes vision trunk (active when vision_backbone_mode=attnres_resnet)
+# ====================
+attnres_stem_dim: 64
+attnres_stage_dims: [64, 128, 256, 512]
+attnres_stage_depths: [2, 2, 2, 2]
+attnres_stage_heads: [4, 4, 8, 8]
+attnres_stage_kv_heads: [1, 1, 1, 1]
+attnres_stage_window_sizes: [7, 7, 7, 7]
+attnres_dropout: 0.0
+attnres_ffn_mult: 2.667
+attnres_eps: 1.0e-06
+attnres_rope_theta: 10000.0
roboimi/vla/conf/backbone/siglip2_diffusion.yaml (new file, +10)
@@ -0,0 +1,10 @@
_target_: roboimi.vla.models.backbones.siglip2_diffusion_backbone.SigLIP2DiffusionBackbone

model_name: google/siglip2-base-patch16-256
camera_names: [r_vis, top, front]
num_cameras: 3
per_view_output_dim: 96
freeze_backbone: true

dataset_image_resize_shape: null
eval_image_resize_shape: [256, 256]
@@ -13,6 +13,7 @@ train:
   lr: 1e-4           # learning rate
   max_steps: 100000  # maximum number of training steps
   device: "cuda"     # device: "cuda" or "cpu"
+  disable_cudnn: false  # set to true when hitting cuDNN compatibility issues on the current machine

   # data loading
   num_workers: 12    # number of DataLoader worker processes (set to 0 when debugging)
@@ -19,3 +19,6 @@ camera_names:
   - r_vis  # robot-view camera
   - top    # top camera
   - front  # front camera
+
+# per-view pre-resize size; null keeps the dataset's original resolution
+image_resize_shape: [224, 224]
@@ -41,6 +41,9 @@ save_timing: false  # whether to save timing.json (per-stage timings)
 save_trajectory: false      # whether to save trajectory.npz (raw EE actions + post-execution EE poses)
 save_summary_json: false    # whether to save a JSON-friendly rollout summary
 save_trajectory_npz: false  # whether to save per-step trajectories/timing/EE poses as NPZ
+save_trajectory_image: false        # whether to save a static PNG with the red EE trajectory overlay
+trajectory_image_camera: null       # alias for trajectory_image_camera_name
+trajectory_image_camera_name: null  # camera used for the trajectory image; defaults to camera_names[0] when null
 record_video: false      # whether to record a rollout mp4 from a single camera stream
 video_camera: null       # alias for video_camera_name
 video_camera_name: null  # camera used for video recording; defaults to camera_names[0] when null
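The alias/name pairs above imply a three-step fallback: the explicit `*_camera_name` wins, then the `*_camera` alias, then `camera_names[0]`. A pure-Python sketch of that order (the helper name is hypothetical, not from the repo, and the name-beats-alias precedence is an assumption read off the comments):

```python
def resolve_camera(camera_name, camera_alias, camera_names):
    """Hypothetical helper mirroring the config comments above:
    explicit name first, then the alias, then the first configured camera."""
    return camera_name or camera_alias or camera_names[0]

cams = ["r_vis", "top", "front"]
```

With no override, the trajectory image and video both fall back to the first configured camera stream.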
roboimi/vla/conf/head/imf_transformer1d.yaml (new file, +22)
@@ -0,0 +1,22 @@
_target_: roboimi.vla.models.heads.imf_transformer1d.IMFTransformer1D
_partial_: true

input_dim: ${agent.action_dim}
output_dim: ${agent.action_dim}
horizon: ${agent.pred_horizon}
n_obs_steps: ${agent.obs_horizon}
cond_dim: 208
n_layer: 12
n_head: 1
n_emb: 768
p_drop_emb: 0.1
p_drop_attn: 0.1
causal_attn: false
time_as_cond: true
obs_as_cond: true
n_cond_layers: 0
backbone_type: attnres_full
n_kv_head: 1
attn_res_ffn_mult: 2.667
attn_res_eps: 1.0e-6
attn_res_rope_theta: 10000.0
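`_partial_: true` makes `hydra.utils.instantiate` return a callable with the listed fields pre-bound instead of a constructed head, so the agent can supply the remaining arguments at build time. A pure-Python analogy with `functools.partial` (the factory below is illustrative and not the real `IMFTransformer1D` signature):

```python
from functools import partial

def make_head(input_dim, output_dim, horizon, cond_dim):
    # stand-in for a head constructor; returns a plain dict for clarity
    return {"input_dim": input_dim, "output_dim": output_dim,
            "horizon": horizon, "cond_dim": cond_dim}

# analogous to what instantiate(cfg.head) yields when _partial_ is true:
# config-declared fields are bound, the rest is left for the caller
head_factory = partial(make_head, input_dim=16, output_dim=16, horizon=16)
head = head_factory(cond_dim=208)
```

This is why the config can leave a field like `cond_dim` pinned while the agent remains free to compute others from the backbone at construction time.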
roboimi/vla/conf/modules/linear_condition_projector.yaml (new file, +5)
@@ -0,0 +1,5 @@
_target_: roboimi.vla.modules.projectors.LinearConditionProjector
_partial_: true

output_dim: 384
bias: true
@@ -1,7 +1,7 @@
 import torch
 import h5py
 from torch.utils.data import Dataset
-from typing import List, Dict, Union
+from typing import List, Dict, Union, Optional, Sequence
 from pathlib import Path
 from collections import OrderedDict

@@ -22,6 +22,7 @@ class SimpleRobotDataset(Dataset):
         obs_horizon: int = 2,
         pred_horizon: int = 8,
         camera_names: List[str] = None,
+        image_resize_shape: Optional[Sequence[int]] = (224, 224),
         max_open_files: int = 64,
     ):
         """
@@ -30,6 +31,7 @@ class SimpleRobotDataset(Dataset):
             obs_horizon: number of past frames to observe
             pred_horizon: number of future action frames to predict
             camera_names: list of camera names, e.g. ["r_vis", "top", "front"]
+            image_resize_shape: image resize size (W, H); None keeps the original resolution
             max_open_files: maximum number of HDF5 file handles cached per worker

         HDF5 file format:
@@ -40,6 +42,10 @@ class SimpleRobotDataset(Dataset):
         self.obs_horizon = obs_horizon
         self.pred_horizon = pred_horizon
         self.camera_names = camera_names or []
+        self.image_resize_shape = (
+            tuple(int(v) for v in image_resize_shape)
+            if image_resize_shape is not None else None
+        )
         self.max_open_files = max(1, int(max_open_files))
         self._file_cache: "OrderedDict[str, h5py.File]" = OrderedDict()

@@ -123,9 +129,9 @@ class SimpleRobotDataset(Dataset):
             h5_path = f'observations/images/{cam_name}'
             if h5_path in f:
                 img = f[h5_path][meta["frame_idx"]]
-                # resize the image to 224x224 (reduces memory and I/O load)
-                import cv2
-                img = cv2.resize(img, (224, 224), interpolation=cv2.INTER_LINEAR)
+                if self.image_resize_shape is not None:
+                    import cv2
+                    img = cv2.resize(img, self.image_resize_shape, interpolation=cv2.INTER_LINEAR)
                 # convert to float and normalize to [0, 1]
                 img = torch.from_numpy(img).float() / 255.0
                 frame[f"observation.{cam_name}"] = img.permute(2, 0, 1)  # HWC -> CHW
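Note that `cv2.resize` expects its `dsize` argument as `(width, height)` even though the HDF5 frames are `(H, W, C)`, which is why the docstring documents `image_resize_shape` as (W, H). The tuple/None normalization added in `__init__` can be sketched in isolation:

```python
def normalize_resize_shape(shape):
    """Mirrors the dataset's image_resize_shape handling: coerce any
    sequence to an int tuple, keep None to preserve the source resolution."""
    return tuple(int(v) for v in shape) if shape is not None else None
```

Coercing to `int` up front avoids passing float dimensions to `cv2.resize`, which rejects non-integer sizes.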
@@ -1,4 +1,15 @@
 # Backbone models
-from .resnet_diffusion import ResNetDiffusionBackbone
-
-__all__ = ["ResNetBackbone", "ResNetDiffusionBackbone"]
+__all__ = ["LEWMViTBackbone", "ResNetBackbone", "ResNetDiffusionBackbone", "SigLIP2DiffusionBackbone"]
+
+
+def __getattr__(name):
+    if name == "LEWMViTBackbone":
+        from .lewm_vit_backbone import LEWMViTBackbone
+        return LEWMViTBackbone
+    if name == "SigLIP2DiffusionBackbone":
+        from .siglip2_diffusion_backbone import SigLIP2DiffusionBackbone
+        return SigLIP2DiffusionBackbone
+    if name in {"ResNetBackbone", "ResNetDiffusionBackbone"}:
+        from .resnet_diffusion import ResNetDiffusionBackbone
+        return ResNetDiffusionBackbone
+    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
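The `__getattr__` above uses the PEP 562 module-level getattr hook, so heavy optional backbones (transformers-based ones, for instance) are only imported when first referenced. A self-contained sketch of the pattern using a synthetic module (names here are invented for illustration):

```python
import sys
import types

# Build a throwaway module to demonstrate the PEP 562 lazy-attribute pattern.
mod = types.ModuleType("fake_backbones")

def _module_getattr(name):
    if name == "Heavy":
        value = object()               # stand-in for the deferred import
        setattr(mod, "Heavy", value)   # cache so __getattr__ only runs once
        return value
    raise AttributeError(f"module 'fake_backbones' has no attribute {name!r}")

mod.__getattr__ = _module_getattr      # PEP 562: looked up in the module dict
sys.modules["fake_backbones"] = mod

import fake_backbones
h1 = fake_backbones.Heavy              # triggers _module_getattr
h2 = fake_backbones.Heavy              # served from the module dict directly
```

Caching via `setattr` matters: after the first access, ordinary attribute lookup finds the object and the hook is never consulted again.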
roboimi/vla/models/backbones/attnres_resnet2d.py (new file, +228)
@@ -0,0 +1,228 @@
from __future__ import annotations

from typing import Iterable, Sequence

import torch
import torch.nn as nn
import torch.nn.functional as F

from roboimi.vla.models.heads.attnres_transformer_components import AttnResTransformerBackbone


def _make_norm2d(num_channels: int, use_group_norm: bool) -> nn.Module:
    if use_group_norm:
        num_groups = max(1, num_channels // 16)
        return nn.GroupNorm(num_groups=num_groups, num_channels=num_channels)
    return nn.BatchNorm2d(num_channels)


class _ConvNormAct(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        *,
        kernel_size: int,
        stride: int,
        padding: int,
        use_group_norm: bool,
    ) -> None:
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                bias=False,
            ),
            _make_norm2d(out_channels, use_group_norm),
            nn.SiLU(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.block(x)


class AttnResImageBlock2D(nn.Module):
    def __init__(
        self,
        dim: int,
        *,
        window_size: int,
        n_heads: int,
        n_kv_heads: int,
        dropout: float,
        ffn_mult: float,
        eps: float,
        rope_theta: float,
    ) -> None:
        super().__init__()
        self.window_size = int(window_size)
        self.block = AttnResTransformerBackbone(
            d_model=dim,
            n_blocks=1,
            n_heads=n_heads,
            n_kv_heads=n_kv_heads,
            max_seq_len=self.window_size * self.window_size,
            dropout=dropout,
            ffn_mult=ffn_mult,
            eps=eps,
            rope_theta=rope_theta,
            causal_attn=False,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        bsz, channels, height, width = x.shape
        ws = self.window_size
        pad_h = (ws - height % ws) % ws
        pad_w = (ws - width % ws) % ws
        if pad_h or pad_w:
            x = F.pad(x, (0, pad_w, 0, pad_h))
        padded_height, padded_width = x.shape[-2:]
        num_h = padded_height // ws
        num_w = padded_width // ws

        windows = (
            x.permute(0, 2, 3, 1)
            .contiguous()
            .view(bsz, num_h, ws, num_w, ws, channels)
            .permute(0, 1, 3, 2, 4, 5)
            .contiguous()
            .view(bsz * num_h * num_w, ws * ws, channels)
        )
        windows = self.block(windows)
        x = (
            windows.view(bsz, num_h, num_w, ws, ws, channels)
            .permute(0, 1, 3, 2, 4, 5)
            .contiguous()
            .view(bsz, padded_height, padded_width, channels)
            .permute(0, 3, 1, 2)
            .contiguous()
        )
        return x[:, :, :height, :width]


class _AttnResStage2D(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        *,
        depth: int,
        downsample_stride: int,
        use_group_norm: bool,
        window_size: int,
        n_heads: int,
        n_kv_heads: int,
        dropout: float,
        ffn_mult: float,
        eps: float,
        rope_theta: float,
    ) -> None:
        super().__init__()
        self.downsample = None
        if in_channels != out_channels or downsample_stride != 1:
            kernel_size = 1 if downsample_stride == 1 else 3
            padding = 0 if downsample_stride == 1 else 1
            self.downsample = _ConvNormAct(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                stride=downsample_stride,
                padding=padding,
                use_group_norm=use_group_norm,
            )

        self.blocks = nn.ModuleList(
            [
                AttnResImageBlock2D(
                    out_channels,
                    window_size=window_size,
                    n_heads=n_heads,
                    n_kv_heads=n_kv_heads,
                    dropout=dropout,
                    ffn_mult=ffn_mult,
                    eps=eps,
                    rope_theta=rope_theta,
                )
                for _ in range(depth)
            ]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.downsample is not None:
            x = self.downsample(x)
        for block in self.blocks:
            x = block(x)
        return x


class AttnResResNetLikeBackbone2D(nn.Module):
    def __init__(
        self,
        *,
        input_channels: int = 3,
        stem_dim: int = 64,
        stage_dims: Sequence[int] = (64, 128, 256, 512),
        stage_depths: Sequence[int] = (2, 2, 2, 2),
        stage_heads: Sequence[int] = (4, 4, 8, 8),
        stage_kv_heads: Sequence[int] = (1, 1, 1, 1),
        stage_window_sizes: Sequence[int] = (7, 7, 7, 7),
        use_group_norm: bool = True,
        dropout: float = 0.0,
        ffn_mult: float = 2.667,
        eps: float = 1e-6,
        rope_theta: float = 10000.0,
    ) -> None:
        super().__init__()

        lengths = {
            len(stage_dims),
            len(stage_depths),
            len(stage_heads),
            len(stage_kv_heads),
            len(stage_window_sizes),
        }
        if len(lengths) != 1:
            raise ValueError('stage_dims/depths/heads/kv_heads/window_sizes must all have the same length')

        self.stem = nn.Sequential(
            nn.Conv2d(input_channels, stem_dim, kernel_size=7, stride=2, padding=3, bias=False),
            _make_norm2d(stem_dim, use_group_norm),
            nn.SiLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )

        in_channels = stem_dim
        stages = []
        for stage_idx, (out_channels, depth, n_heads, n_kv_heads, window_size) in enumerate(
            zip(stage_dims, stage_depths, stage_heads, stage_kv_heads, stage_window_sizes)
        ):
            stages.append(
                _AttnResStage2D(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    depth=int(depth),
                    downsample_stride=1 if stage_idx == 0 else 2,
                    use_group_norm=use_group_norm,
                    window_size=int(window_size),
                    n_heads=int(n_heads),
                    n_kv_heads=int(n_kv_heads),
                    dropout=float(dropout),
                    ffn_mult=float(ffn_mult),
                    eps=float(eps),
                    rope_theta=float(rope_theta),
                )
            )
            in_channels = out_channels

        self.stages = nn.ModuleList(stages)
        self.output_channels = in_channels

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.stem(x)
        for stage in self.stages:
            x = stage(x)
        return x
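`AttnResImageBlock2D.forward` pads the feature map up to a multiple of the window size, attends within each `ws x ws` window, and then crops back, so arbitrary input sizes round-trip losslessly. The padding arithmetic in isolation:

```python
def window_pad(size, ws):
    """Extra rows/cols needed so `size` becomes a multiple of the window
    size; same formula as pad_h/pad_w in AttnResImageBlock2D.forward."""
    return (ws - size % ws) % ws
```

The outer `% ws` is what makes already-aligned sizes need zero padding rather than a full extra window.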
roboimi/vla/models/backbones/lewm_vit_backbone.py (new file, +230)
@@ -0,0 +1,230 @@
from __future__ import annotations

from pathlib import Path
from typing import Any, Dict, Mapping, Sequence

import torch
import torch.nn as nn
import torch.nn.functional as F

from roboimi.vla.core.interfaces import VLABackbone


class _LEWMProjector(nn.Module):
    """LEWM projector MLP: 192 -> 2048 -> 192 with BatchNorm1d + GELU."""

    def __init__(self, input_dim: int = 192, hidden_dim: int = 2048, output_dim: int = 192) -> None:
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


class LEWMViTBackbone(VLABackbone):
    """Frozen LEWM joint-multiview ViT backbone.

    The backbone fuses the three camera views into a single LEWM-style image,
    runs a ViT-tiny encoder plus the LEWM projector, and returns one joint
    192-d embedding per timestep.
    """

    def __init__(
        self,
        checkpoint_path: str | Path | None = None,
        *,
        checkpoint: Mapping[str, Any] | None = None,
        camera_names: Sequence[str] = ("r_vis", "top", "front"),
        fused_camera_names: Sequence[str] = ("front", "top", "r_vis"),
        num_cameras: int | None = None,
        dataset_image_resize_shape: Sequence[int] | None = None,
        eval_image_resize_shape: Sequence[int] | None = (256, 256),
        freeze_backbone: bool = True,
        joint_output_dim: int = 192,
        image_size: int = 224,
        output_dim: int = 192,
    ) -> None:
        super().__init__()

        self.camera_names = tuple(camera_names)
        self.fused_camera_names = tuple(fused_camera_names)
        self.num_cameras = int(num_cameras) if num_cameras is not None else len(self.camera_names)
        self.freeze_backbone = bool(freeze_backbone)
        self.joint_output_dim = int(joint_output_dim)
        self.image_size = int(image_size)
        self._output_dim = int(output_dim)
        self.dataset_image_resize_shape = (
            tuple(int(v) for v in dataset_image_resize_shape)
            if dataset_image_resize_shape is not None else None
        )
        self.eval_image_resize_shape = (
            tuple(int(v) for v in eval_image_resize_shape)
            if eval_image_resize_shape is not None else None
        )
        if self.num_cameras != len(self.camera_names):
            raise ValueError(
                f"num_cameras({self.num_cameras}) must match len(camera_names)({len(self.camera_names)})"
            )
        if set(self.fused_camera_names) != set(self.camera_names):
            raise ValueError(
                "fused_camera_names must contain the same cameras as camera_names. "
                f"got camera_names={list(self.camera_names)}, fused_camera_names={list(self.fused_camera_names)}"
            )

        self.encoder = self._build_encoder(self.image_size)
        self.projector = _LEWMProjector(
            input_dim=self.encoder.config.hidden_size,
            hidden_dim=2048,
            output_dim=self.joint_output_dim,
        )

        self.register_buffer(
            "mean",
            torch.tensor([0.485, 0.456, 0.406], dtype=torch.float32).view(1, 3, 1, 1),
        )
        self.register_buffer(
            "std",
            torch.tensor([0.229, 0.224, 0.225], dtype=torch.float32).view(1, 3, 1, 1),
        )

        if checkpoint_path is not None and checkpoint is not None:
            raise ValueError("checkpoint_path and checkpoint cannot both be provided")
        if checkpoint_path is not None:
            self.load_lewm_checkpoint(checkpoint_path)
        elif checkpoint is not None:
            self.load_lewm_checkpoint(checkpoint)

        if self.freeze_backbone:
            self._freeze_encoder_and_projector()

    @staticmethod
    def _build_encoder_config(image_size: int):
        from transformers import ViTConfig

        return ViTConfig(
            image_size=image_size,
            patch_size=14,
            num_channels=3,
            hidden_size=192,
            intermediate_size=768,
            num_hidden_layers=12,
            num_attention_heads=3,
            qkv_bias=True,
            hidden_dropout_prob=0.0,
            attention_probs_dropout_prob=0.0,
        )

    @classmethod
    def _build_encoder(cls, image_size: int) -> nn.Module:
        from transformers import ViTModel

        return ViTModel(cls._build_encoder_config(image_size), add_pooling_layer=False)

    @staticmethod
    def _unwrap_state_dict(payload: Mapping[str, Any]) -> Mapping[str, torch.Tensor]:
        state_dict = payload.get("state_dict", payload)
        if not isinstance(state_dict, Mapping):
            raise TypeError("checkpoint payload must contain a mapping state_dict")
        return state_dict

    @staticmethod
    def _extract_prefixed_state_dict(
        state_dict: Mapping[str, torch.Tensor],
        prefix: str,
    ) -> Dict[str, torch.Tensor]:
        extracted = {
            key[len(prefix):]: value
            for key, value in state_dict.items()
            if key.startswith(prefix)
        }
        if not extracted:
            raise KeyError(f"checkpoint missing parameters with prefix {prefix!r}")
        return extracted

    def load_lewm_checkpoint(self, checkpoint_or_path: str | Path | Mapping[str, Any]) -> None:
        if isinstance(checkpoint_or_path, (str, Path)):
            payload = torch.load(Path(checkpoint_or_path), map_location="cpu", weights_only=False)
        else:
            payload = checkpoint_or_path

        state_dict = self._unwrap_state_dict(payload)
        encoder_state_dict = self._extract_prefixed_state_dict(state_dict, "model.encoder.")
        projector_state_dict = self._extract_prefixed_state_dict(state_dict, "model.projector.")

        self.encoder.load_state_dict(encoder_state_dict, strict=True)
        self.projector.load_state_dict(projector_state_dict, strict=True)

    def _freeze_encoder_and_projector(self) -> None:
        for module in (self.encoder, self.projector):
            module.eval()
            for parameter in module.parameters():
                parameter.requires_grad = False

    def train(self, mode: bool = True) -> "LEWMViTBackbone":
        super().train(mode)
        if self.freeze_backbone:
            self._freeze_encoder_and_projector()
        return self

    def _ordered_images(self, images: Dict[str, torch.Tensor]) -> list[torch.Tensor]:
        missing = [camera_name for camera_name in self.camera_names if camera_name not in images]
        if missing:
            raise ValueError(
                f"image input missing required cameras. missing={missing}, expected={list(self.camera_names)}"
            )

        ordered = [images[camera_name] for camera_name in self.camera_names]
        reference_shape = ordered[0].shape
        if len(reference_shape) != 5:
            raise ValueError(f"expected image tensors shaped (B, T, C, H, W), got {reference_shape}")

        for camera_name, image in zip(self.camera_names[1:], ordered[1:]):
            if image.shape != reference_shape:
                raise ValueError(
                    f"camera {camera_name!r} shape {tuple(image.shape)} does not match {tuple(reference_shape)}"
                )

        return ordered

    def _prepare_pixels(self, images: Dict[str, torch.Tensor]) -> tuple[torch.Tensor, int, int]:
        self._ordered_images(images)
        fused = torch.cat([images[camera_name] for camera_name in self.fused_camera_names], dim=-2)
        bsz, steps = fused.shape[:2]
        fused = fused.reshape(bsz * steps, *fused.shape[2:]).contiguous().float()

        fused = fused.clamp(0.0, 1.0)
        fused = (fused - self.mean) / self.std

        height, width = fused.shape[-2:]
        short_side = min(height, width)
        if short_side <= 0:
            raise ValueError(f"invalid fused image shape: {tuple(fused.shape)}")
        scale = self.image_size / float(short_side)
        resized_height = int(round(height * scale))
        resized_width = int(round(width * scale))
        if (resized_height, resized_width) != (height, width):
            fused = F.interpolate(
                fused,
                size=(resized_height, resized_width),
                mode="bilinear",
                align_corners=False,
                antialias=True,
            )
        return fused, bsz, steps

    def forward(self, images: Dict[str, torch.Tensor]) -> torch.Tensor:
        pixels, bsz, steps = self._prepare_pixels(images)
        with torch.set_grad_enabled(torch.is_grad_enabled() and not self.freeze_backbone):
            output = self.encoder(pixel_values=pixels, interpolate_pos_encoding=True)
        cls = output.last_hidden_state[:, 0]
        embedding = self.projector(cls)
        return embedding.view(bsz, steps, self.joint_output_dim)

    @property
    def output_dim(self) -> int:
        return self._output_dim
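`_prepare_pixels` stacks the camera views along the height axis and then scales the fused image so its short side equals `image_size`, preserving aspect ratio. The sizing math can be sketched in isolation: three 256x256 views fused vertically give a 768x256 image, which scales by 224/256 down to 672x224.

```python
def short_side_resize(height, width, target):
    """Mirrors the sizing step of LEWMViTBackbone._prepare_pixels:
    scale so min(H, W) == target, keeping the aspect ratio."""
    scale = target / float(min(height, width))
    return int(round(height * scale)), int(round(width * scale))
```

Because the long side is not cropped, the ViT relies on `interpolate_pos_encoding=True` to handle the non-square token grid.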
@@ -6,6 +6,8 @@ import torchvision
 import numpy as np
 from typing import Callable, Optional, Tuple, Union
+
+from .attnres_resnet2d import AttnResResNetLikeBackbone2D


 def _replace_submodules(
     root_module: nn.Module, predicate: Callable[[nn.Module], bool], func: Callable[[nn.Module], nn.Module]
 ) -> nn.Module:
@@ -103,6 +105,17 @@ class _SingleRgbEncoder(nn.Module):
         use_group_norm: bool,
         spatial_softmax_num_keypoints: int,
         freeze_backbone: bool = True,  # new: whether to freeze the backbone
+        vision_backbone_mode: str = "resnet",
+        attnres_stem_dim: int = 64,
+        attnres_stage_dims: Optional[Tuple[int, ...]] = None,
+        attnres_stage_depths: Optional[Tuple[int, ...]] = None,
+        attnres_stage_heads: Optional[Tuple[int, ...]] = None,
+        attnres_stage_kv_heads: Optional[Tuple[int, ...]] = None,
+        attnres_stage_window_sizes: Optional[Tuple[int, ...]] = None,
+        attnres_dropout: float = 0.0,
+        attnres_ffn_mult: float = 2.667,
+        attnres_eps: float = 1e-6,
+        attnres_rope_theta: float = 10000.0,
     ):
         super().__init__()
@@ -119,6 +132,7 @@ class _SingleRgbEncoder(nn.Module):
         self.do_crop = False
         crop_shape = input_shape[1:]

+        if vision_backbone_mode == "resnet":
             # set up the backbone network
             backbone_model = getattr(torchvision.models, vision_backbone)(
                 weights=pretrained_backbone_weights
@@ -131,8 +145,28 @@ class _SingleRgbEncoder(nn.Module):
             self.backbone = _replace_submodules(
                 root_module=self.backbone,
                 predicate=lambda x: isinstance(x, nn.BatchNorm2d),
-                func=lambda x: nn.GroupNorm(num_groups=x.num_features // 16, num_channels=x.num_features),
+                func=lambda x: nn.GroupNorm(
+                    num_groups=max(1, x.num_features // 16),
+                    num_channels=x.num_features,
+                ),
             )
+        elif vision_backbone_mode == "attnres_resnet":
+            self.backbone = AttnResResNetLikeBackbone2D(
+                input_channels=input_shape[0],
+                stem_dim=attnres_stem_dim,
+                stage_dims=tuple(attnres_stage_dims or (64, 128, 256, 512)),
+                stage_depths=tuple(attnres_stage_depths or (2, 2, 2, 2)),
+                stage_heads=tuple(attnres_stage_heads or (4, 4, 8, 8)),
+                stage_kv_heads=tuple(attnres_stage_kv_heads or (1, 1, 1, 1)),
+                stage_window_sizes=tuple(attnres_stage_window_sizes or (7, 7, 7, 7)),
+                use_group_norm=use_group_norm,
+                dropout=attnres_dropout,
+                ffn_mult=attnres_ffn_mult,
+                eps=attnres_eps,
+                rope_theta=attnres_rope_theta,
+            )
+        else:
+            raise ValueError(f"unsupported vision_backbone_mode: {vision_backbone_mode}")

         # freeze backbone parameters (optional)
         if freeze_backbone:
@@ -177,14 +211,28 @@ class ResNetDiffusionBackbone(VLABackbone):
         use_group_norm: bool = True,
         spatial_softmax_num_keypoints: int = 32,
         use_separate_rgb_encoder_per_camera: bool = False,  # new: whether to use a separate encoder per camera
+        output_tokens_per_camera: bool = False,  # whether to return one token per camera instead of concatenating into a single token
         num_cameras: int = 1,  # new: number of cameras (only used in separate-encoder mode)
         camera_names: Optional[Tuple[str, ...]] = None,  # explicit camera order
         freeze_backbone: bool = True,  # new: whether to freeze the ResNet backbone (True recommended)
+        vision_backbone_mode: str = "resnet",
+        attnres_stem_dim: int = 64,
+        attnres_stage_dims: Optional[Tuple[int, ...]] = None,
+        attnres_stage_depths: Optional[Tuple[int, ...]] = None,
+        attnres_stage_heads: Optional[Tuple[int, ...]] = None,
+        attnres_stage_kv_heads: Optional[Tuple[int, ...]] = None,
+        attnres_stage_window_sizes: Optional[Tuple[int, ...]] = None,
+        attnres_dropout: float = 0.0,
+        attnres_ffn_mult: float = 2.667,
+        attnres_eps: float = 1e-6,
+        attnres_rope_theta: float = 10000.0,
     ):
         super().__init__()

         self.use_separate_rgb_encoder_per_camera = use_separate_rgb_encoder_per_camera
+        self.output_tokens_per_camera = bool(output_tokens_per_camera)
         self.num_cameras = num_cameras
+        self.tokens_per_step = self.num_cameras if self.output_tokens_per_camera else 1
         self.camera_names = tuple(camera_names) if camera_names is not None else None
         if self.camera_names is not None and len(self.camera_names) != self.num_cameras:
             raise ValueError(
@@ -203,6 +251,17 @@ class ResNetDiffusionBackbone(VLABackbone):
                     use_group_norm=use_group_norm,
                     spatial_softmax_num_keypoints=spatial_softmax_num_keypoints,
                     freeze_backbone=freeze_backbone,
+                    vision_backbone_mode=vision_backbone_mode,
+                    attnres_stem_dim=attnres_stem_dim,
+                    attnres_stage_dims=attnres_stage_dims,
+                    attnres_stage_depths=attnres_stage_depths,
+                    attnres_stage_heads=attnres_stage_heads,
+                    attnres_stage_kv_heads=attnres_stage_kv_heads,
+                    attnres_stage_window_sizes=attnres_stage_window_sizes,
+                    attnres_dropout=attnres_dropout,
+                    attnres_ffn_mult=attnres_ffn_mult,
+                    attnres_eps=attnres_eps,
+                    attnres_rope_theta=attnres_rope_theta,
                 )
                 for _ in range(num_cameras)
             ]
@@ -220,6 +279,17 @@ class ResNetDiffusionBackbone(VLABackbone):
                 use_group_norm=use_group_norm,
                 spatial_softmax_num_keypoints=spatial_softmax_num_keypoints,
                 freeze_backbone=freeze_backbone,
+                vision_backbone_mode=vision_backbone_mode,
+                attnres_stem_dim=attnres_stem_dim,
+                attnres_stage_dims=attnres_stage_dims,
+                attnres_stage_depths=attnres_stage_depths,
+                attnres_stage_heads=attnres_stage_heads,
+                attnres_stage_kv_heads=attnres_stage_kv_heads,
+                attnres_stage_window_sizes=attnres_stage_window_sizes,
+                attnres_dropout=attnres_dropout,
+                attnres_ffn_mult=attnres_ffn_mult,
+                attnres_eps=attnres_eps,
+                attnres_rope_theta=attnres_rope_theta,
             )
         self.feature_dim = self.rgb_encoder.feature_dim

@@ -252,22 +322,24 @@ class ResNetDiffusionBackbone(VLABackbone):
         B, T = any_tensor.shape[:2]
         cam_names = self._ordered_camera_names(images)

+        features_all = []
         if self.use_separate_rgb_encoder_per_camera:
             # separate-encoder mode: each camera uses its own encoder
-            features_all = []
             for cam_idx, cam_name in enumerate(cam_names):
                 img = images[cam_name]
                 encoder = self.rgb_encoder[cam_idx]
                 features = encoder.forward_single_image(img.reshape(B * T, *img.shape[2:]))
                 features_all.append(features)
-            return torch.cat(features_all, dim=1).view(B, T, -1)
         else:
             # shared-encoder mode: all cameras share one encoder
-            features_all = []
             for cam_name in cam_names:
                 img = images[cam_name]
                 features = self.rgb_encoder.forward_single_image(img.reshape(B * T, *img.shape[2:]))
                 features_all.append(features)

+        if self.output_tokens_per_camera:
+            stacked = torch.stack(features_all, dim=1)  # (B*T, num_cams, feature_dim)
+            return stacked.view(B, T, len(cam_names), self.feature_dim)
         return torch.cat(features_all, dim=1).view(B, T, -1)

     @property
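The `output_tokens_per_camera` switch added above changes the feature layout from one fused token per timestep to one token per camera. A minimal standalone shape sketch (tensor sizes here are made up for illustration):

```python
import torch

B, T, num_cams, feat = 2, 4, 3, 32
# per-camera features as produced inside forward(): one (B*T, feat) tensor per camera
features_all = [torch.randn(B * T, feat) for _ in range(num_cams)]

# fused layout: concatenate along the feature axis -> one token per timestep
fused = torch.cat(features_all, dim=1).view(B, T, -1)
assert fused.shape == (B, T, num_cams * feat)

# per-camera layout: stack along a new camera axis -> one token per camera
stacked = torch.stack(features_all, dim=1)  # (B*T, num_cams, feat)
tokens = stacked.view(B, T, num_cams, feat)
assert tokens.shape == (B, T, num_cams, feat)
```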
124  roboimi/vla/models/backbones/siglip2_diffusion_backbone.py  Normal file
@@ -0,0 +1,124 @@
from __future__ import annotations

from typing import Dict, Optional, Sequence, Tuple

import torch
from torch import nn
from transformers import SiglipVisionModel

from roboimi.vla.core.interfaces import VLABackbone


class SigLIP2DiffusionBackbone(VLABackbone):
    """Shared SigLIP vision tower for multiview diffusion-policy conditioning.

    We intentionally load the checkpoint `google/siglip2-base-patch16-256` through
    `SiglipVisionModel.from_pretrained(...)` so each camera can be fed as a normal
    `(B, C, H, W)` image tensor and produce one pooled global feature vector.
    """

    def __init__(
        self,
        model_name: str = 'google/siglip2-base-patch16-256',
        *,
        model_name_or_path: str | None = None,
        vision_model: nn.Module | None = None,
        camera_names: Sequence[str] = ('r_vis', 'top', 'front'),
        num_cameras: Optional[int] = None,
        per_view_output_dim: int = 96,
        output_dim: int | None = None,
        freeze_backbone: bool = True,
        dataset_image_resize_shape: Sequence[int] | None = None,
        eval_image_resize_shape: Sequence[int] | None = (256, 256),
    ) -> None:
        super().__init__()
        if model_name_or_path is not None:
            model_name = model_name_or_path
        if output_dim is not None:
            per_view_output_dim = output_dim

        self.model_name = str(model_name)
        self.camera_names = tuple(camera_names)
        self.num_cameras = int(num_cameras) if num_cameras is not None else len(self.camera_names)
        if len(self.camera_names) != self.num_cameras:
            raise ValueError(
                f'camera_names length ({len(self.camera_names)}) must match num_cameras ({self.num_cameras})'
            )

        self._output_dim = int(per_view_output_dim)
        self.joint_output_dim = self._output_dim * self.num_cameras
        self.freeze_backbone = bool(freeze_backbone)
        self.dataset_image_resize_shape = self._normalize_resize_shape(dataset_image_resize_shape)
        self.eval_image_resize_shape = self._normalize_resize_shape(eval_image_resize_shape)

        self.encoder = vision_model if vision_model is not None else SiglipVisionModel.from_pretrained(self.model_name)
        hidden_size = int(getattr(self.encoder.config, 'hidden_size'))
        self.view_projector = nn.Linear(hidden_size, self._output_dim)
        self.projector = self.view_projector

        self.register_buffer('mean', torch.tensor([0.5, 0.5, 0.5], dtype=torch.float32).view(1, 3, 1, 1))
        self.register_buffer('std', torch.tensor([0.5, 0.5, 0.5], dtype=torch.float32).view(1, 3, 1, 1))

        if self.freeze_backbone:
            self._freeze_encoder()

    @staticmethod
    def _normalize_resize_shape(shape: Sequence[int] | None) -> tuple[int, int] | None:
        if shape is None:
            return None
        normalized = tuple(int(v) for v in shape)
        if len(normalized) != 2:
            raise ValueError(f'resize shape must contain exactly two values, got {normalized}')
        return normalized

    @property
    def output_dim(self) -> int:
        return self._output_dim

    def _freeze_encoder(self) -> None:
        self.encoder.eval()
        for param in self.encoder.parameters():
            param.requires_grad = False

    def train(self, mode: bool = True):
        super().train(mode)
        if self.freeze_backbone:
            self._freeze_encoder()
        return self

    def _ordered_camera_names(self, images: Dict[str, torch.Tensor]) -> Tuple[str, ...]:
        missing = [camera_name for camera_name in self.camera_names if camera_name not in images]
        if missing:
            raise ValueError(
                f'image input missing required cameras. missing={missing}, expected={list(self.camera_names)}'
            )
        return self.camera_names

    def _prepare_pixels(self, image: torch.Tensor) -> torch.Tensor:
        if image.ndim != 5:
            raise ValueError(f'expected image tensor shaped (B, T, C, H, W), got {tuple(image.shape)}')
        pixels = image.reshape(-1, *image.shape[2:]).contiguous().float()
        pixels = pixels.clamp(0.0, 1.0)
        return (pixels - self.mean) / self.std

    def forward(self, images: Dict[str, torch.Tensor]) -> torch.Tensor:
        camera_names = self._ordered_camera_names(images)
        reference_shape = images[camera_names[0]].shape
        batch_size, steps = reference_shape[:2]
        per_view_features = []
        for camera_name in camera_names:
            image = images[camera_name]
            if image.shape != reference_shape:
                raise ValueError(
                    f'camera {camera_name!r} shape {tuple(image.shape)} does not match {tuple(reference_shape)}'
                )
            pixels = self._prepare_pixels(image)
            with torch.set_grad_enabled(torch.is_grad_enabled() and not self.freeze_backbone):
                encoded = self.encoder(pixel_values=pixels)
            pooled = encoded.pooler_output
            per_view_features.append(self.view_projector(pooled))
        features = torch.cat(per_view_features, dim=-1)
        return features.view(batch_size, steps, self.joint_output_dim)


Siglip2DiffusionBackbone = SigLIP2DiffusionBackbone
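The `_prepare_pixels` step in the SigLIP backbone clamps inputs to `[0, 1]` and normalizes with mean = std = 0.5, which maps pixels into SigLIP's expected `[-1, 1]` range. A standalone sketch of that preprocessing, with made-up tensor sizes:

```python
import torch

# SigLIP-style preprocessing: clamp to [0, 1], then normalize with
# mean = std = 0.5 so the result lands in [-1, 1].
mean = torch.tensor([0.5, 0.5, 0.5]).view(1, 3, 1, 1)
std = torch.tensor([0.5, 0.5, 0.5]).view(1, 3, 1, 1)

image = torch.rand(2, 4, 3, 8, 8)  # (B, T, C, H, W)
pixels = image.reshape(-1, *image.shape[2:]).float().clamp(0.0, 1.0)
pixels = (pixels - mean) / std

assert pixels.shape == (8, 3, 8, 8)  # flattened to (B*T, C, H, W)
assert float(pixels.min()) >= -1.0 and float(pixels.max()) <= 1.0
```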
249  roboimi/vla/models/heads/attnres_transformer_components.py  Normal file
@@ -0,0 +1,249 @@
from __future__ import annotations

import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor


class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: Tensor) -> Tensor:
        rms = torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + self.eps)
        return (x.float() * rms).to(x.dtype) * self.weight


class RMSNormNoWeight(nn.Module):
    def __init__(self, eps: float = 1e-6) -> None:
        super().__init__()
        self.eps = eps

    def forward(self, x: Tensor) -> Tensor:
        rms = torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + self.eps)
        return (x.float() * rms).to(x.dtype)


def precompute_rope_freqs(
    dim: int,
    max_seq_len: int,
    theta: float = 10000.0,
    device: Optional[torch.device] = None,
) -> Tensor:
    if dim % 2 != 0:
        raise ValueError(f'RoPE requires an even head dimension, got {dim}.')
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, device=device).float() / dim))
    positions = torch.arange(max_seq_len, device=device).float()
    angles = torch.outer(positions, freqs)
    return torch.polar(torch.ones_like(angles), angles)


def apply_rope(x: Tensor, freqs: Tensor) -> Tensor:
    x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    freqs = freqs.unsqueeze(0).unsqueeze(2)
    x_rotated = x_complex * freqs
    return torch.view_as_real(x_rotated).reshape_as(x).to(x.dtype)


class GroupedQuerySelfAttention(nn.Module):
    def __init__(
        self,
        d_model: int,
        n_heads: int,
        n_kv_heads: int,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()
        if d_model % n_heads != 0:
            raise ValueError(f'd_model={d_model} must be divisible by n_heads={n_heads}.')
        if n_heads % n_kv_heads != 0:
            raise ValueError(f'n_heads={n_heads} must be divisible by n_kv_heads={n_kv_heads}.')

        self.d_model = d_model
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.n_kv_groups = n_heads // n_kv_heads
        self.d_head = d_model // n_heads
        self.attn_dropout = nn.Dropout(dropout)
        self.out_dropout = nn.Dropout(dropout)

        self.w_q = nn.Linear(d_model, n_heads * self.d_head, bias=False)
        self.w_k = nn.Linear(d_model, n_kv_heads * self.d_head, bias=False)
        self.w_v = nn.Linear(d_model, n_kv_heads * self.d_head, bias=False)
        self.w_o = nn.Linear(n_heads * self.d_head, d_model, bias=False)

    def forward(
        self,
        x: Tensor,
        rope_freqs: Tensor,
        mask: Optional[Tensor] = None,
    ) -> Tensor:
        batch_size, seq_len, _ = x.shape

        q = self.w_q(x).view(batch_size, seq_len, self.n_heads, self.d_head)
        k = self.w_k(x).view(batch_size, seq_len, self.n_kv_heads, self.d_head)
        v = self.w_v(x).view(batch_size, seq_len, self.n_kv_heads, self.d_head)

        q = apply_rope(q, rope_freqs)
        k = apply_rope(k, rope_freqs)

        if self.n_kv_heads != self.n_heads:
            k = k.unsqueeze(3).expand(
                batch_size, seq_len, self.n_kv_heads, self.n_kv_groups, self.d_head
            )
            k = k.reshape(batch_size, seq_len, self.n_heads, self.d_head)
            v = v.unsqueeze(3).expand(
                batch_size, seq_len, self.n_kv_heads, self.n_kv_groups, self.d_head
            )
            v = v.reshape(batch_size, seq_len, self.n_heads, self.d_head)

        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        scale = 1.0 / math.sqrt(self.d_head)
        attn_weights = torch.matmul(q, k.transpose(-2, -1)) * scale
        if mask is not None:
            attn_weights = attn_weights + mask
        attn_weights = F.softmax(attn_weights, dim=-1)
        attn_weights = self.attn_dropout(attn_weights)

        out = torch.matmul(attn_weights, v)
        out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        return self.out_dropout(self.w_o(out))


class SwiGLUFFN(nn.Module):
    def __init__(self, d_model: int, dropout: float = 0.0, mult: float = 2.667) -> None:
        super().__init__()
        raw = int(mult * d_model)
        d_ff = ((raw + 7) // 8) * 8
        self.w_gate = nn.Linear(d_model, d_ff, bias=False)
        self.w_up = nn.Linear(d_model, d_ff, bias=False)
        self.w_down = nn.Linear(d_ff, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: Tensor) -> Tensor:
        return self.dropout(self.w_down(F.silu(self.w_gate(x)) * self.w_up(x)))


class AttnResOperator(nn.Module):
    def __init__(self, d_model: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.pseudo_query = nn.Parameter(torch.zeros(d_model))
        self.key_norm = RMSNormNoWeight(eps=eps)

    def forward(self, sources: Tensor) -> Tensor:
        keys = self.key_norm(sources)
        logits = torch.einsum('d,nbtd->nbt', self.pseudo_query, keys)
        weights = F.softmax(logits, dim=0)
        return torch.einsum('nbt,nbtd->btd', weights, sources)


class AttnResSubLayer(nn.Module):
    def __init__(
        self,
        d_model: int,
        n_heads: int,
        n_kv_heads: int,
        dropout: float,
        ffn_mult: float,
        eps: float,
        is_attention: bool,
    ) -> None:
        super().__init__()
        self.norm = RMSNorm(d_model, eps=eps)
        self.attn_res = AttnResOperator(d_model, eps=eps)
        self.is_attention = is_attention
        if self.is_attention:
            self.fn = GroupedQuerySelfAttention(
                d_model=d_model,
                n_heads=n_heads,
                n_kv_heads=n_kv_heads,
                dropout=dropout,
            )
        else:
            self.fn = SwiGLUFFN(d_model=d_model, dropout=dropout, mult=ffn_mult)

    def forward(self, sources: Tensor, rope_freqs: Tensor, mask: Optional[Tensor] = None) -> Tensor:
        h = self.attn_res(sources)
        normed = self.norm(h)
        if self.is_attention:
            return self.fn(normed, rope_freqs, mask)
        return self.fn(normed)


class AttnResTransformerBackbone(nn.Module):
    def __init__(
        self,
        d_model: int,
        n_blocks: int,
        n_heads: int,
        n_kv_heads: int,
        max_seq_len: int,
        dropout: float = 0.0,
        ffn_mult: float = 2.667,
        eps: float = 1e-6,
        rope_theta: float = 10000.0,
        causal_attn: bool = False,
    ) -> None:
        super().__init__()
        self.causal_attn = causal_attn
        self.layers = nn.ModuleList()
        for _ in range(n_blocks):
            self.layers.append(
                AttnResSubLayer(
                    d_model=d_model,
                    n_heads=n_heads,
                    n_kv_heads=n_kv_heads,
                    dropout=dropout,
                    ffn_mult=ffn_mult,
                    eps=eps,
                    is_attention=True,
                )
            )
            self.layers.append(
                AttnResSubLayer(
                    d_model=d_model,
                    n_heads=n_heads,
                    n_kv_heads=n_kv_heads,
                    dropout=dropout,
                    ffn_mult=ffn_mult,
                    eps=eps,
                    is_attention=False,
                )
            )

        rope_freqs = precompute_rope_freqs(
            dim=d_model // n_heads,
            max_seq_len=max_seq_len,
            theta=rope_theta,
        )
        self.register_buffer('rope_freqs', rope_freqs, persistent=False)

    @staticmethod
    def _build_causal_mask(seq_len: int, device: torch.device) -> Tensor:
        mask = torch.full((seq_len, seq_len), float('-inf'), device=device)
        mask = torch.triu(mask, diagonal=1)
        return mask.unsqueeze(0).unsqueeze(0)

    def forward(self, x: Tensor) -> Tensor:
        seq_len = x.shape[1]
        rope_freqs = self.rope_freqs[:seq_len]
        mask = None
        if self.causal_attn:
            mask = self._build_causal_mask(seq_len, x.device)

        layer_outputs = [x]
        for layer in self.layers:
            sources = torch.stack(layer_outputs, dim=0)
            output = layer(sources, rope_freqs, mask)
            layer_outputs.append(output)

        return torch.stack(layer_outputs, dim=0).sum(dim=0)
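The unsqueeze/expand/reshape KV expansion in `GroupedQuerySelfAttention` is the standard grouped-query trick: each KV head is shared by a group of query heads. A standalone sketch (sizes made up) showing it is equivalent to `repeat_interleave` along the head axis:

```python
import torch

# Grouped-query KV expansion: repeat each KV head for its group of query heads.
B, S, n_kv, groups, d = 2, 5, 2, 3, 4  # n_heads = n_kv * groups = 6
k = torch.randn(B, S, n_kv, d)

expanded = k.unsqueeze(3).expand(B, S, n_kv, groups, d).reshape(B, S, n_kv * groups, d)
reference = k.repeat_interleave(groups, dim=2)

assert expanded.shape == (B, S, 6, d)
assert torch.equal(expanded, reference)
```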
379  roboimi/vla/models/heads/imf_transformer1d.py  Normal file
@@ -0,0 +1,379 @@
"""Local IMF-AttnRes transformer head aligned with diffusion_policy@185ed659."""

from __future__ import annotations

import logging
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn

from .attnres_transformer_components import (
    AttnResOperator,
    AttnResSubLayer,
    AttnResTransformerBackbone,
    GroupedQuerySelfAttention,
    RMSNorm,
    RMSNormNoWeight,
    SwiGLUFFN,
)
from .transformer1d import ModuleAttrMixin, SinusoidalPosEmb

logger = logging.getLogger(__name__)


class IMFTransformer1D(ModuleAttrMixin):
    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        horizon: int,
        n_obs_steps: Optional[int] = None,
        cond_dim: int = 0,
        n_layer: int = 12,
        n_head: int = 1,
        n_emb: int = 768,
        p_drop_emb: float = 0.1,
        p_drop_attn: float = 0.1,
        causal_attn: bool = False,
        time_as_cond: bool = True,
        obs_as_cond: bool = False,
        n_cond_layers: int = 0,
        backbone_type: str = 'attnres_full',
        n_kv_head: int = 1,
        attn_res_ffn_mult: float = 2.667,
        attn_res_eps: float = 1e-6,
        attn_res_rope_theta: float = 10000.0,
    ) -> None:
        super().__init__()

        if n_head != 1:
            raise AssertionError('IMFTransformer1D currently supports single-head attention only.')
        if n_obs_steps is None:
            n_obs_steps = horizon

        self.backbone_type = backbone_type

        T = horizon
        T_cond = 2
        if not time_as_cond:
            T += 2
            T_cond -= 2
        obs_as_cond = cond_dim > 0
        if obs_as_cond:
            assert time_as_cond
            T_cond += n_obs_steps

        self.input_emb = nn.Linear(input_dim, n_emb)
        self.drop = nn.Dropout(p_drop_emb)
        self.time_emb = SinusoidalPosEmb(n_emb)
        self.cond_obs_emb = nn.Linear(cond_dim, n_emb) if obs_as_cond else None
        self.time_token_proj = None
        self.cond_pos_emb = None
        self.pos_emb = None
        self.encoder = None
        self.decoder = None
        self.attnres_backbone = None
        encoder_only = False

        if backbone_type == 'attnres_full':
            if not time_as_cond:
                raise ValueError('attnres_full backbone requires time_as_cond=True.')
            if n_cond_layers != 0:
                raise ValueError('attnres_full backbone does not support n_cond_layers > 0.')

            self.time_token_proj = nn.Linear(n_emb, n_emb)
            self.attnres_backbone = AttnResTransformerBackbone(
                d_model=n_emb,
                n_blocks=n_layer,
                n_heads=n_head,
                n_kv_heads=n_kv_head,
                max_seq_len=T + T_cond,
                dropout=p_drop_attn,
                ffn_mult=attn_res_ffn_mult,
                eps=attn_res_eps,
                rope_theta=attn_res_rope_theta,
                causal_attn=causal_attn,
            )
            self.ln_f = RMSNorm(n_emb, eps=attn_res_eps)
        else:
            self.pos_emb = nn.Parameter(torch.zeros(1, T, n_emb))
            if T_cond > 0:
                self.cond_pos_emb = nn.Parameter(torch.zeros(1, T_cond, n_emb))
                if n_cond_layers > 0:
                    encoder_layer = nn.TransformerEncoderLayer(
                        d_model=n_emb,
                        nhead=n_head,
                        dim_feedforward=4 * n_emb,
                        dropout=p_drop_attn,
                        activation='gelu',
                        batch_first=True,
                        norm_first=True,
                    )
                    self.encoder = nn.TransformerEncoder(
                        encoder_layer=encoder_layer,
                        num_layers=n_cond_layers,
                    )
                else:
                    self.encoder = nn.Sequential(
                        nn.Linear(n_emb, 4 * n_emb),
                        nn.Mish(),
                        nn.Linear(4 * n_emb, n_emb),
                    )

                decoder_layer = nn.TransformerDecoderLayer(
                    d_model=n_emb,
                    nhead=n_head,
                    dim_feedforward=4 * n_emb,
                    dropout=p_drop_attn,
                    activation='gelu',
                    batch_first=True,
                    norm_first=True,
                )
                self.decoder = nn.TransformerDecoder(
                    decoder_layer=decoder_layer,
                    num_layers=n_layer,
                )
            else:
                encoder_only = True
                encoder_layer = nn.TransformerEncoderLayer(
                    d_model=n_emb,
                    nhead=n_head,
                    dim_feedforward=4 * n_emb,
                    dropout=p_drop_attn,
                    activation='gelu',
                    batch_first=True,
                    norm_first=True,
                )
                self.encoder = nn.TransformerEncoder(
                    encoder_layer=encoder_layer,
                    num_layers=n_layer,
                )

            self.ln_f = nn.LayerNorm(n_emb)

        if causal_attn and backbone_type != 'attnres_full':
            sz = T
            mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
            mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
            self.register_buffer('mask', mask)

            if time_as_cond and obs_as_cond:
                S = T_cond
                t_idx, s_idx = torch.meshgrid(
                    torch.arange(T),
                    torch.arange(S),
                    indexing='ij',
                )
                mask = t_idx >= (s_idx - 2)
                mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
                self.register_buffer('memory_mask', mask)
            else:
                self.memory_mask = None
        else:
            self.mask = None
            self.memory_mask = None

        self.head = nn.Linear(n_emb, output_dim)

        self.T = T
        self.T_cond = T_cond
        self.horizon = horizon
        self.time_as_cond = time_as_cond
        self.obs_as_cond = obs_as_cond
        self.encoder_only = encoder_only

        self.apply(self._init_weights)
        logger.info('number of parameters: %e', sum(p.numel() for p in self.parameters()))

    def _init_weights(self, module):
        ignore_types = (
            nn.Dropout,
            SinusoidalPosEmb,
            nn.TransformerEncoderLayer,
            nn.TransformerDecoderLayer,
            nn.TransformerEncoder,
            nn.TransformerDecoder,
            nn.ModuleList,
            nn.Mish,
            nn.Sequential,
            AttnResTransformerBackbone,
            AttnResSubLayer,
            GroupedQuerySelfAttention,
            SwiGLUFFN,
            RMSNormNoWeight,
        )
        if isinstance(module, (nn.Linear, nn.Embedding)):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.MultiheadAttention):
            for name in ('in_proj_weight', 'q_proj_weight', 'k_proj_weight', 'v_proj_weight'):
                weight = getattr(module, name)
                if weight is not None:
                    torch.nn.init.normal_(weight, mean=0.0, std=0.02)

            for name in ('in_proj_bias', 'bias_k', 'bias_v'):
                bias = getattr(module, name)
                if bias is not None:
                    torch.nn.init.zeros_(bias)
        elif isinstance(module, (nn.LayerNorm, RMSNorm)):
            if getattr(module, 'bias', None) is not None:
                torch.nn.init.zeros_(module.bias)
            torch.nn.init.ones_(module.weight)
        elif isinstance(module, AttnResOperator):
            torch.nn.init.zeros_(module.pseudo_query)
        elif isinstance(module, IMFTransformer1D):
            if module.pos_emb is not None:
                torch.nn.init.normal_(module.pos_emb, mean=0.0, std=0.02)
            if module.cond_pos_emb is not None:
                torch.nn.init.normal_(module.cond_pos_emb, mean=0.0, std=0.02)
        elif isinstance(module, ignore_types):
            pass
        else:
            raise RuntimeError(f'Unaccounted module {module}')

    def get_optim_groups(self, weight_decay: float = 1e-3):
        decay = set()
        no_decay = set()
        whitelist_weight_modules = (torch.nn.Linear, torch.nn.MultiheadAttention)
        blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding, RMSNorm)
        for mn, m in self.named_modules():
            for pn, _ in m.named_parameters(recurse=False):
                fpn = f'{mn}.{pn}' if mn else pn

                if pn.endswith('bias'):
                    no_decay.add(fpn)
                elif pn.startswith('bias'):
                    no_decay.add(fpn)
                elif pn == 'pseudo_query':
                    no_decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
                    decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
                    no_decay.add(fpn)

        if self.pos_emb is not None:
            no_decay.add('pos_emb')
        no_decay.add('_dummy_variable')
        if self.cond_pos_emb is not None:
            no_decay.add('cond_pos_emb')

        param_dict = {pn: p for pn, p in self.named_parameters()}
        inter_params = decay & no_decay
        union_params = decay | no_decay
        assert len(inter_params) == 0, f'parameters {inter_params} made it into both decay/no_decay sets!'
        assert len(param_dict.keys() - union_params) == 0, (
            f'parameters {param_dict.keys() - union_params} were not separated into either decay/no_decay sets!'
        )

        return [
            {
                'params': [param_dict[pn] for pn in sorted(list(decay))],
                'weight_decay': weight_decay,
            },
|
{
|
||||||
|
'params': [param_dict[pn] for pn in sorted(list(no_decay))],
|
||||||
|
'weight_decay': 0.0,
|
||||||
|
},
|
||||||
|
]
|
||||||
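The decay/no-decay split in `get_optim_groups` can be sketched in a dependency-free form. This is a simplified illustration, not the real method: it works on parameter names only, whereas the real code also dispatches on module type (Linear and MultiheadAttention weights decay, LayerNorm and Embedding weights do not).

```python
def split_optim_groups(param_names, weight_decay=1e-3):
    # Simplified sketch of IMFTransformer1D.get_optim_groups: biases,
    # positional embeddings, and the pseudo query skip weight decay;
    # everything else (linear/attention weights) receives it.
    no_decay_names = {'pos_emb', 'cond_pos_emb', 'pseudo_query', '_dummy_variable'}
    decay, no_decay = [], []
    for name in param_names:
        leaf = name.rsplit('.', 1)[-1]
        if leaf.endswith('bias') or leaf in no_decay_names:
            no_decay.append(name)
        else:
            decay.append(name)
    return [
        {'params': sorted(decay), 'weight_decay': weight_decay},
        {'params': sorted(no_decay), 'weight_decay': 0.0},
    ]
```

The two returned groups mirror what `configure_optimizers` hands to AdamW: one group with weight decay applied, one exempt.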
    def configure_optimizers(
        self,
        learning_rate: float = 1e-4,
        weight_decay: float = 1e-3,
        betas: Tuple[float, float] = (0.9, 0.95),
    ):
        optim_groups = self.get_optim_groups(weight_decay=weight_decay)
        return torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas)

    def _prepare_time_input(self, value: Union[torch.Tensor, float, int], sample: torch.Tensor) -> torch.Tensor:
        if not torch.is_tensor(value):
            value = torch.tensor([value], dtype=sample.dtype, device=sample.device)
        elif value.ndim == 0:
            value = value[None].to(device=sample.device, dtype=sample.dtype)
        else:
            value = value.to(device=sample.device, dtype=sample.dtype)
        return value.expand(sample.shape[0])

    def _forward_attnres_full(
        self,
        sample: torch.Tensor,
        r: torch.Tensor,
        t: torch.Tensor,
        cond: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        sample_tokens = self.input_emb(sample)
        token_parts = [
            self.time_token_proj(self.time_emb(r)).unsqueeze(1),
            self.time_token_proj(self.time_emb(t)).unsqueeze(1),
        ]
        if self.obs_as_cond:
            if cond is None:
                raise ValueError('cond is required when obs_as_cond=True for attnres_full backbone.')
            token_parts.append(self.cond_obs_emb(cond))
        token_parts.append(sample_tokens)
        x = torch.cat(token_parts, dim=1)
        x = self.drop(x)
        x = self.attnres_backbone(x)
        x = x[:, -sample_tokens.shape[1]:, :]
        return x

    def _forward_vanilla(
        self,
        sample: torch.Tensor,
        r: torch.Tensor,
        t: torch.Tensor,
        cond: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        r_emb = self.time_emb(r).unsqueeze(1)
        t_emb = self.time_emb(t).unsqueeze(1)
        input_emb = self.input_emb(sample)

        if self.encoder_only:
            token_embeddings = torch.cat([r_emb, t_emb, input_emb], dim=1)
            token_count = token_embeddings.shape[1]
            position_embeddings = self.pos_emb[:, :token_count, :]
            x = self.drop(token_embeddings + position_embeddings)
            x = self.encoder(src=x, mask=self.mask)
            x = x[:, 2:, :]
        else:
            cond_embeddings = torch.cat([r_emb, t_emb], dim=1)
            if self.obs_as_cond:
                cond_embeddings = torch.cat([cond_embeddings, self.cond_obs_emb(cond)], dim=1)
            token_count = cond_embeddings.shape[1]
            position_embeddings = self.cond_pos_emb[:, :token_count, :]
            x = self.drop(cond_embeddings + position_embeddings)
            x = self.encoder(x)
            memory = x

            token_embeddings = input_emb
            token_count = token_embeddings.shape[1]
            position_embeddings = self.pos_emb[:, :token_count, :]
            x = self.drop(token_embeddings + position_embeddings)
            x = self.decoder(
                tgt=x,
                memory=memory,
                tgt_mask=self.mask,
                memory_mask=self.memory_mask,
            )
        return x

    def forward(
        self,
        sample: torch.Tensor,
        r: Union[torch.Tensor, float, int],
        t: Union[torch.Tensor, float, int],
        cond: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        r = self._prepare_time_input(r, sample)
        t = self._prepare_time_input(t, sample)

        if self.backbone_type == 'attnres_full':
            x = self._forward_attnres_full(sample, r, t, cond=cond)
        else:
            x = self._forward_vanilla(sample, r, t, cond=cond)

        x = self.ln_f(x)
        x = self.head(x)
        return x
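The `_prepare_time_input` helper above normalizes the `r` and `t` time values so `forward` always sees one entry per batch element. A tensor-free sketch of the same broadcasting rule (accepting a scalar, a length-1 sequence, or a full per-batch sequence):

```python
def prepare_time_input(value, batch_size):
    # Pure-Python sketch of _prepare_time_input: return one float per
    # batch element, broadcasting scalars like tensor.expand() does.
    if isinstance(value, (int, float)):
        return [float(value)] * batch_size
    values = [float(v) for v in value]
    if len(values) == 1:  # size-1 inputs broadcast across the batch
        return values * batch_size
    if len(values) != batch_size:
        raise ValueError('time value must be scalar or match the batch size')
    return values
```

The real method additionally moves the value onto the sample's device and dtype before expanding it.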
roboimi/vla/modules/projectors.py (Normal file, 17 lines)
@@ -0,0 +1,17 @@
from __future__ import annotations

import torch
from torch import nn


class LinearConditionProjector(nn.Module):
    """Projects per-step visual+state conditioning to the head conditioning width."""

    def __init__(self, input_dim: int, output_dim: int, bias: bool = True):
        super().__init__()
        self.input_dim = int(input_dim)
        self.output_dim = int(output_dim)
        self.linear = nn.Linear(self.input_dim, self.output_dim, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)
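`LinearConditionProjector` is a single affine map. What it computes per conditioning vector can be written out without torch; this sketch uses plain lists in place of tensors:

```python
def linear_project(x, weight, bias):
    # What LinearConditionProjector computes for one conditioning
    # vector: y = W x + b, mapping input_dim features to output_dim.
    return [
        sum(w_ij * x_j for w_ij, x_j in zip(row, x)) + b_i
        for row, b_i in zip(weight, bias)
    ]
```

With an identity weight matrix this reduces to adding the bias, which is a quick sanity check on shapes.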
tests/test_attnres_resnet2d_backbone.py (Normal file, 26 lines)
@@ -0,0 +1,26 @@
import unittest

import torch


class AttnResResNet2DBackboneTest(unittest.TestCase):
    def test_backbone_preserves_resnet_like_stage_contract(self):
        from roboimi.vla.models.backbones.attnres_resnet2d import AttnResResNetLikeBackbone2D

        backbone = AttnResResNetLikeBackbone2D(
            input_channels=3,
            stem_dim=16,
            stage_dims=(16, 32, 64, 128),
            stage_depths=(1, 1, 1, 1),
            stage_heads=(2, 4, 4, 8),
            stage_kv_heads=(1, 1, 1, 1),
            stage_window_sizes=(7, 7, 7, 7),
            dropout=0.0,
        )
        x = torch.randn(2, 3, 56, 56)
        y = backbone(x)
        self.assertEqual(y.shape, (2, 128, 2, 2))


if __name__ == '__main__':
    unittest.main()
@@ -90,6 +90,36 @@ class _FakeRenderer:


class EvalVLAHeadlessTest(unittest.TestCase):
    def test_prepare_observation_skips_resize_when_image_resize_shape_is_none(self):
        obs = {
            "images": {
                "front": np.arange(8 * 8 * 3, dtype=np.uint8).reshape(8, 8, 3),
            },
            "qpos": np.zeros(16, dtype=np.float32),
        }

        with mock.patch("cv2.resize", side_effect=AssertionError("resize should be skipped")):
            prepared = eval_vla.prepare_observation(
                obs,
                ["front"],
                image_resize_shape=None,
            )

        self.assertEqual(tuple(prepared["images"]["front"].shape), (3, 8, 8))
        self.assertEqual(tuple(prepared["qpos"].shape), (16,))

    def test_headless_eval_sets_mujoco_gl_to_egl_when_display_missing(self):
        cfg = OmegaConf.create({"eval": {"headless": True}})
        with mock.patch.dict(eval_vla.os.environ, {}, clear=True):
            eval_vla._configure_headless_mujoco_gl(cfg.eval)
        self.assertEqual(eval_vla.os.environ.get("MUJOCO_GL"), "egl")

    def test_headless_eval_preserves_existing_mujoco_gl(self):
        cfg = OmegaConf.create({"eval": {"headless": True}})
        with mock.patch.dict(eval_vla.os.environ, {"MUJOCO_GL": "osmesa"}, clear=True):
            eval_vla._configure_headless_mujoco_gl(cfg.eval)
        self.assertEqual(eval_vla.os.environ.get("MUJOCO_GL"), "osmesa")

    def test_eval_config_exposes_headless_default(self):
        eval_cfg = OmegaConf.load(Path("roboimi/vla/conf/eval/eval.yaml"))
@@ -117,6 +147,49 @@ class EvalVLAHeadlessTest(unittest.TestCase):
            cam_view="angle",
        )

    def test_headless_sync_camera_capture_populates_images_without_gui_calls(self):
        env = DualDianaMed.__new__(DualDianaMed)
        env.mj_model = object()
        env.mj_data = object()
        env.exit_flag = False
        env.is_render = False
        env.cam = 'angle'
        env.r_vis = None
        env.l_vis = None
        env.top = None
        env.angle = None
        env.front = None
        env._offscreen_renderer = None

        with mock.patch(
            'roboimi.envs.double_base.mj.Renderer',
            side_effect=lambda *args, **kwargs: _FakeRenderer(env),
        ) as renderer_cls, mock.patch('roboimi.envs.double_base.cv2.namedWindow') as named_window, mock.patch(
            'roboimi.envs.double_base.cv2.imshow'
        ) as imshow, mock.patch('roboimi.envs.double_base.cv2.waitKey') as wait_key:
            env._update_camera_images_sync()

        renderer_cls.assert_called_once()
        named_window.assert_not_called()
        imshow.assert_not_called()
        wait_key.assert_not_called()
        self.assertIsNotNone(env.r_vis)
        self.assertIsNotNone(env.l_vis)
        self.assertIsNotNone(env.top)
        self.assertIsNotNone(env.angle)
        self.assertIsNotNone(env.front)

    def test_cam_start_skips_background_thread_when_headless(self):
        env = DualDianaMed.__new__(DualDianaMed)
        env.is_render = False
        env.cam_thread = None

        with mock.patch('roboimi.envs.double_base.threading.Thread') as thread_cls:
            env.cam_start()

        thread_cls.assert_not_called()
        self.assertIsNone(env.cam_thread)

    def test_camera_viewer_headless_updates_images_without_gui_calls(self):
        env = DualDianaMed.__new__(DualDianaMed)
        env.mj_model = object()
tests/test_eval_vla_headless_import.py (Normal file, 26 lines)
@@ -0,0 +1,26 @@
import json
import os
import subprocess
import sys


def test_eval_vla_import_does_not_import_mujoco_early_when_headless_backend_not_set():
    env = os.environ.copy()
    env.pop('MUJOCO_GL', None)
    proc = subprocess.run(
        [
            sys.executable,
            '-c',
            (
                'import json, sys; '
                'from roboimi.demos.vla_scripts import eval_vla; '
                'print(json.dumps({"mujoco_in_sys_modules": "mujoco" in sys.modules}))'
            ),
        ],
        capture_output=True,
        text=True,
        env=env,
        check=True,
    )
    payload = json.loads(proc.stdout.strip())
    assert payload['mujoco_in_sys_modules'] is False
@@ -102,8 +102,10 @@ class EvalVLARolloutArtifactsTest(unittest.TestCase):
        self.assertIn('artifact_dir', eval_cfg)
        self.assertFalse(eval_cfg.save_summary_json)
        self.assertFalse(eval_cfg.save_trajectory_npz)
        self.assertFalse(eval_cfg.save_trajectory_image)
        self.assertFalse(eval_cfg.record_video)
        self.assertIsNone(eval_cfg.artifact_dir)
        self.assertIsNone(eval_cfg.trajectory_image_camera_name)
        self.assertIsNone(eval_cfg.video_camera_name)
        self.assertEqual(eval_cfg.video_fps, 30)
@@ -133,6 +135,8 @@ class EvalVLARolloutArtifactsTest(unittest.TestCase):
                    'artifact_dir': tmpdir,
                    'save_summary_json': True,
                    'save_trajectory_npz': True,
                    'save_trajectory_image': True,
                    'trajectory_image_camera_name': 'front',
                    'record_video': True,
                    'video_camera_name': 'front',
                    'video_fps': 12,
@@ -176,12 +180,14 @@ class EvalVLARolloutArtifactsTest(unittest.TestCase):
        trajectory_path = Path(artifacts['trajectory_npz'])
        summary_path = Path(artifacts['summary_json'])
        video_path = Path(artifacts['video_mp4'])
        trajectory_image_path = Path(summary['episodes'][0]['artifact_paths']['trajectory_image'])

        self.assertEqual(Path(artifacts['output_dir']), Path(tmpdir))
        self.assertEqual(artifacts['video_camera_name'], 'front')
        self.assertTrue(trajectory_path.exists())
        self.assertTrue(summary_path.exists())
        self.assertTrue(video_path.exists())
        self.assertTrue(trajectory_image_path.exists())

        rollout_npz = np.load(trajectory_path)
        np.testing.assert_array_equal(rollout_npz['episode_index'], np.array([0, 0]))
@@ -218,11 +224,121 @@ class EvalVLARolloutArtifactsTest(unittest.TestCase):
            saved_summary = json.load(fh)
        self.assertEqual(saved_summary['artifacts']['trajectory_npz'], str(trajectory_path))
        self.assertEqual(saved_summary['artifacts']['video_mp4'], str(video_path))
        self.assertEqual(
            saved_summary['episodes'][0]['artifact_paths']['trajectory_image'],
            str(trajectory_image_path),
        )
        self.assertEqual(saved_summary['episode_rewards'], [3.0])
        self.assertAlmostEqual(summary['avg_reward'], 3.0)
        self.assertIn('avg_obs_read_time_ms', summary)
        self.assertIn('avg_env_step_time_ms', summary)

    def test_run_eval_exports_front_trajectory_images_without_video_dependency(self):
        actions = [
            np.arange(16, dtype=np.float32),
            np.arange(16, dtype=np.float32) + 10.0,
            np.arange(16, dtype=np.float32) + 100.0,
            np.arange(16, dtype=np.float32) + 110.0,
        ]
        fake_agent = _FakeAgent(actions)
        fake_env = _FakeEnv()

        with tempfile.TemporaryDirectory() as tmpdir:
            cfg = OmegaConf.create(
                {
                    'agent': {},
                    'eval': {
                        'ckpt_path': 'checkpoints/vla_model_best.pt',
                        'num_episodes': 2,
                        'max_timesteps': 2,
                        'device': 'cpu',
                        'task_name': 'sim_transfer',
                        'camera_names': ['top', 'front'],
                        'use_smoothing': True,
                        'smooth_alpha': 0.5,
                        'verbose_action': False,
                        'headless': True,
                        'artifact_dir': tmpdir,
                        'save_trajectory_image': True,
                        'record_video': False,
                    },
                }
            )

            trajectory_image_calls = []

            def fake_save_rollout_trajectory_image(
                env,
                output_path,
                raw_actions,
                camera_name,
                *,
                line_radius=0.004,
                max_markers=1500,
            ):
                del env, line_radius, max_markers
                trajectory_image_calls.append(
                    {
                        'output_path': output_path,
                        'camera_name': camera_name,
                        'raw_actions': [np.array(action, copy=True) for action in raw_actions],
                    }
                )
                if output_path is None:
                    return None
                output_path = Path(output_path)
                output_path.parent.mkdir(parents=True, exist_ok=True)
                output_path.write_bytes(b'fake-png')
                return str(output_path)

            with mock.patch.object(
                eval_vla,
                'load_checkpoint',
                return_value=(fake_agent, None),
            ), mock.patch.object(
                eval_vla,
                'make_sim_env',
                return_value=fake_env,
            ), mock.patch.object(
                eval_vla,
                'sample_transfer_pose',
                return_value=np.array([0.1, 0.2, 0.3], dtype=np.float32),
            ), mock.patch.object(
                eval_vla,
                'tqdm',
                side_effect=lambda iterable, **kwargs: iterable,
            ), mock.patch.object(
                eval_vla,
                '_save_rollout_trajectory_image',
                side_effect=fake_save_rollout_trajectory_image,
            ) as save_trajectory_image_mock, mock.patch.object(
                eval_vla,
                '_open_video_writer',
            ) as open_video_writer_mock:
                summary = eval_vla._run_eval(cfg)

            self.assertEqual(save_trajectory_image_mock.call_count, 2)
            open_video_writer_mock.assert_not_called()
            self.assertIsNone(summary['artifacts']['video_mp4'])
            self.assertEqual(summary['artifacts']['trajectory_image_camera_name'], 'front')
            self.assertEqual(
                [call['camera_name'] for call in trajectory_image_calls],
                ['front', 'front'],
            )

            first_episode_path = Path(summary['episodes'][0]['artifact_paths']['trajectory_image'])
            second_episode_path = Path(summary['episodes'][1]['artifact_paths']['trajectory_image'])
            self.assertTrue(first_episode_path.exists())
            self.assertTrue(second_episode_path.exists())
            self.assertNotEqual(first_episode_path, second_episode_path)
            self.assertEqual(first_episode_path.parent, Path(tmpdir))
            self.assertEqual(second_episode_path.parent, Path(tmpdir))

            np.testing.assert_array_equal(trajectory_image_calls[0]['raw_actions'][0], actions[0])
            np.testing.assert_array_equal(trajectory_image_calls[0]['raw_actions'][1], actions[1])
            np.testing.assert_array_equal(trajectory_image_calls[1]['raw_actions'][0], actions[2])
            np.testing.assert_array_equal(trajectory_image_calls[1]['raw_actions'][1], actions[3])


if __name__ == '__main__':
    unittest.main()
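The test above asserts the per-episode image contract: one trajectory image per episode, distinct paths, all directly inside `artifact_dir`. A hypothetical helper satisfying that contract could look like this; the exact filename pattern is an assumption, not taken from `eval_vla.py`:

```python
from pathlib import Path


def trajectory_image_path(artifact_dir, episode_index, camera_name):
    # Hypothetical naming scheme meeting the asserted contract:
    # distinct per episode, parent is exactly artifact_dir.
    return Path(artifact_dir) / f'episode_{episode_index:03d}_{camera_name}_trajectory.png'
```

Any scheme works as long as the episode index appears in the name and no subdirectories are introduced, since the test compares `path.parent` against `artifact_dir` directly.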
tests/test_imf_transformer1d_external_alignment.py (Normal file, 196 lines)
@@ -0,0 +1,196 @@
import contextlib
import importlib
import inspect
import subprocess
import sys
import types
import unittest
from pathlib import Path

import torch


_REPO_ROOT = Path(__file__).resolve().parents[1]
if str(_REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(_REPO_ROOT))

_EXTERNAL_COMMIT = '185ed659'
_LOCAL_MODULE_NAME = 'roboimi.vla.models.heads.imf_transformer1d'
_MISSING = object()


def _find_external_checkout_root() -> Path | None:
    for ancestor in (_REPO_ROOT, *_REPO_ROOT.parents):
        candidate = ancestor / 'diffusion_policy'
        if (candidate / '.git').exists():
            return candidate
    return None


_EXTERNAL_CHECKOUT_ROOT = _find_external_checkout_root()
_EXTERNAL_MODULE_PATHS = {
    'diffusion_policy.model.common.module_attr_mixin': 'diffusion_policy/model/common/module_attr_mixin.py',
    'diffusion_policy.model.diffusion.positional_embedding': 'diffusion_policy/model/diffusion/positional_embedding.py',
    'diffusion_policy.model.diffusion.attnres_transformer_components': 'diffusion_policy/model/diffusion/attnres_transformer_components.py',
    'diffusion_policy.model.diffusion.imf_transformer_for_diffusion': 'diffusion_policy/model/diffusion/imf_transformer_for_diffusion.py',
}


@contextlib.contextmanager
def _temporary_registered_modules():
    previous_modules = {}

    def remember(name: str) -> None:
        if name not in previous_modules:
            previous_modules[name] = sys.modules.get(name, _MISSING)

    def ensure_package(name: str) -> None:
        if not name or name in sys.modules:
            return
        remember(name)
        package = types.ModuleType(name)
        package.__path__ = []
        sys.modules[name] = package

    def load(name: str, source: str, origin: str):
        package_parts = name.split('.')[:-1]
        for idx in range(1, len(package_parts) + 1):
            ensure_package('.'.join(package_parts[:idx]))

        remember(name)
        module = types.ModuleType(name)
        module.__file__ = origin
        module.__package__ = name.rpartition('.')[0]
        sys.modules[name] = module
        exec(compile(source, origin, 'exec'), module.__dict__)
        return module

    try:
        yield load
    finally:
        for name, previous in reversed(list(previous_modules.items())):
            if previous is _MISSING:
                sys.modules.pop(name, None)
            else:
                sys.modules[name] = previous


def _git_show(repo_root: Path, commit: str, relative_path: str) -> str:
    result = subprocess.run(
        ['git', '-C', str(repo_root), 'show', f'{commit}:{relative_path}'],
        check=True,
        capture_output=True,
        text=True,
    )
    return result.stdout


@contextlib.contextmanager
def _load_external_module_or_skip(test_case: unittest.TestCase):
    if _EXTERNAL_CHECKOUT_ROOT is None:
        test_case.skipTest('external diffusion_policy checkout unavailable')

    try:
        sources = {
            name: _git_show(_EXTERNAL_CHECKOUT_ROOT, _EXTERNAL_COMMIT, relative_path)
            for name, relative_path in _EXTERNAL_MODULE_PATHS.items()
        }
    except subprocess.CalledProcessError as exc:
        test_case.skipTest(
            f'external diffusion_policy commit {_EXTERNAL_COMMIT} is unavailable: {exc.stderr.strip() or exc}'
        )

    with _temporary_registered_modules() as load_external:
        for name, relative_path in _EXTERNAL_MODULE_PATHS.items():
            load_external(
                name,
                sources[name],
                origin=f'{_EXTERNAL_CHECKOUT_ROOT}:{_EXTERNAL_COMMIT}:{relative_path}',
            )
        yield sys.modules['diffusion_policy.model.diffusion.imf_transformer_for_diffusion']


def _load_local_module():
    importlib.invalidate_caches()
    sys.modules.pop(_LOCAL_MODULE_NAME, None)
    return importlib.import_module(_LOCAL_MODULE_NAME)


class IMFTransformer1DExternalAlignmentTest(unittest.TestCase):
    def _optim_group_names(self, model, groups):
        names_by_param = {id(param): name for name, param in model.named_parameters()}
        return [
            {names_by_param[id(param)] for param in group['params']}
            for group in groups
        ]

    def test_local_defaults_preserve_supported_attnres_config(self):
        local_module = _load_local_module()
        ctor = inspect.signature(local_module.IMFTransformer1D.__init__).parameters

        self.assertEqual(ctor['backbone_type'].default, 'attnres_full')
        self.assertEqual(ctor['n_head'].default, 1)
        self.assertEqual(ctor['n_kv_head'].default, 1)
        self.assertEqual(ctor['n_cond_layers'].default, 0)
        self.assertTrue(ctor['time_as_cond'].default)
        self.assertFalse(ctor['causal_attn'].default)

    def test_attnres_full_state_dict_forward_and_optim_groups_match_external(self):
        local_module = _load_local_module()
        with _load_external_module_or_skip(self) as external_module:
            config = dict(
                input_dim=4,
                output_dim=4,
                horizon=6,
                n_obs_steps=3,
                cond_dim=5,
                n_layer=2,
                n_head=1,
                n_emb=16,
                p_drop_emb=0.0,
                p_drop_attn=0.0,
                causal_attn=False,
                time_as_cond=True,
                n_cond_layers=0,
                backbone_type='attnres_full',
                n_kv_head=1,
            )

            torch.manual_seed(7)
            external_model = external_module.IMFTransformerForDiffusion(**config)
            local_model = local_module.IMFTransformer1D(**config)
            external_model.eval()
            local_model.eval()

            external_state_dict = external_model.state_dict()
            self.assertEqual(set(local_model.state_dict().keys()), set(external_state_dict.keys()))
            local_model.load_state_dict(external_state_dict, strict=True)

            batch_size = 2
            sample = torch.randn(batch_size, config['horizon'], config['input_dim'])
            r = torch.tensor([0.1, 0.4], dtype=torch.float32)
            t = torch.tensor([0.7, 0.9], dtype=torch.float32)
            cond = torch.randn(batch_size, config['n_obs_steps'], config['cond_dim'])

            with torch.no_grad():
                external_out = external_model(sample=sample, r=r, t=t, cond=cond)
                local_out = local_model(sample=sample, r=r, t=t, cond=cond)

            self.assertEqual(local_out.shape, (batch_size, config['horizon'], config['output_dim']))
            self.assertEqual(local_out.shape, external_out.shape)
            self.assertTrue(torch.allclose(local_out, external_out, atol=1e-6, rtol=1e-5))

            weight_decay = 0.123
            external_groups = external_model.get_optim_groups(weight_decay=weight_decay)
            local_groups = local_model.get_optim_groups(weight_decay=weight_decay)

            self.assertEqual(len(local_groups), len(external_groups))
            self.assertEqual([group['weight_decay'] for group in local_groups], [weight_decay, 0.0])
            self.assertEqual(
                self._optim_group_names(local_model, local_groups),
                self._optim_group_names(external_model, external_groups),
            )


if __name__ == '__main__':
    unittest.main()
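The core trick in `_temporary_registered_modules` is compiling source text fetched via `git show` into a fresh module object. The same idea in isolation, without the `sys.modules` bookkeeping the test needs:

```python
import types


def module_from_source(name, source):
    # Minimal sketch of the loader above: compile source text (e.g.
    # fetched from another git commit) into a fresh module object,
    # without registering it in sys.modules.
    module = types.ModuleType(name)
    module.__file__ = f'<{name}>'
    exec(compile(source, module.__file__, 'exec'), module.__dict__)
    return module
```

The full helper additionally creates stub parent packages and restores any previously registered modules on exit, so two implementations of the same package can be compared in one process.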
tests/test_imf_vla_agent.py (Normal file, 889 lines)
@@ -0,0 +1,889 @@
import contextlib
import importlib
import importlib.machinery
import sys
import types
import unittest
from pathlib import Path
from unittest import mock

import torch
from hydra import compose, initialize_config_dir
from hydra.core.global_hydra import GlobalHydra
from hydra.utils import instantiate
from omegaconf import OmegaConf
from torch import nn


_REPO_ROOT = Path(__file__).resolve().parents[1]
_CONFIG_DIR = str((_REPO_ROOT / 'roboimi/vla/conf').resolve())
_MISSING = object()
_CAMERA_NAMES = ('r_vis', 'top', 'front')


class _FakeScheduler:
    def __init__(self, num_train_timesteps=100, **kwargs):
        self.config = types.SimpleNamespace(num_train_timesteps=num_train_timesteps)
        self.timesteps = []

    def add_noise(self, sample, noise, timestep):
        return sample + noise

    def set_timesteps(self, num_inference_steps):
        self.timesteps = list(range(num_inference_steps - 1, -1, -1))

    def step(self, noise_pred, timestep, sample):
        return types.SimpleNamespace(prev_sample=sample)


class _IdentityCrop:
    def __init__(self, size):
        self.size = size

    def __call__(self, x):
        return x


class _FakeResNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(8, 16, kernel_size=3, padding=1, stride=2)
        self.relu2 = nn.ReLU()
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(16, 16)

    def forward(self, x):
        x = self.relu1(self.conv1(x))
        x = self.relu2(self.conv2(x))
        x = self.avgpool(x)
        x = torch.flatten(x, start_dim=1)
        return self.fc(x)


class _FakeRearrange(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__()

    def forward(self, x):
        return x


class _FakeViTConfig:
    def __init__(self, **kwargs):
        for key, value in kwargs.items():
            setattr(self, key, value)


class _FakeViTModel(nn.Module):
    def __init__(self, config, add_pooling_layer=False):
        super().__init__()
        del add_pooling_layer
        self.config = config
        hidden_size = int(getattr(config, 'hidden_size', 192))
        self.proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, pixel_values=None, interpolate_pos_encoding=False, **kwargs):
        del interpolate_pos_encoding, kwargs
        batch_size = pixel_values.shape[0]
        hidden_size = int(getattr(self.config, 'hidden_size', 192))
        seq_len = 2
        last_hidden_state = torch.zeros(batch_size, seq_len, hidden_size, dtype=pixel_values.dtype, device=pixel_values.device)
        return types.SimpleNamespace(last_hidden_state=last_hidden_state)


class _FakeSiglipVisionOutput:
    def __init__(self, pooler_output):
        self.pooler_output = pooler_output


class _FakeSiglipVisionConfig:
    def __init__(self, hidden_size=768, image_size=256):
        self.hidden_size = hidden_size
        self.image_size = image_size


class _FakeSiglipVisionModel(nn.Module):
    load_calls = []

    def __init__(self, hidden_size=768):
        super().__init__()
        self.config = _FakeSiglipVisionConfig(hidden_size=hidden_size)
        self.scale = nn.Parameter(torch.tensor(1.0))
        self.forward_calls = []

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
|
||||||
|
model = cls()
|
||||||
|
cls.load_calls.append({
|
||||||
|
'pretrained_model_name_or_path': pretrained_model_name_or_path,
|
||||||
|
'args': args,
|
||||||
|
'kwargs': kwargs,
|
||||||
|
})
|
||||||
|
return model
|
||||||
|
|
||||||
|
def forward(self, pixel_values=None, **kwargs):
|
||||||
|
self.forward_calls.append({
|
||||||
|
'pixel_values': pixel_values.detach().clone(),
|
||||||
|
'kwargs': dict(kwargs),
|
||||||
|
})
|
||||||
|
pooled = pixel_values.mean(dim=(2, 3), keepdim=False) * self.scale
|
||||||
|
return _FakeSiglipVisionOutput(pooler_output=pooled)
|
||||||
|
|
||||||
|
|
||||||
|
class _StubIMFHead(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
input_dim,
|
||||||
|
output_dim,
|
||||||
|
horizon,
|
||||||
|
n_obs_steps,
|
||||||
|
cond_dim,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.constructor_kwargs = {
|
||||||
|
'input_dim': input_dim,
|
||||||
|
'output_dim': output_dim,
|
||||||
|
'horizon': horizon,
|
||||||
|
'n_obs_steps': n_obs_steps,
|
||||||
|
'cond_dim': cond_dim,
|
||||||
|
**kwargs,
|
||||||
|
}
|
||||||
|
self.proj = nn.Linear(input_dim, output_dim)
|
||||||
|
self.cond_obs_emb = nn.Linear(cond_dim, max(cond_dim, 1))
|
||||||
|
|
||||||
|
def forward(self, sample, r, t, cond=None):
|
||||||
|
return torch.zeros_like(sample)
|
||||||
|
|
||||||
|
def get_optim_groups(self, weight_decay):
|
||||||
|
return [
|
||||||
|
{'params': [self.proj.weight], 'weight_decay': weight_decay},
|
||||||
|
{'params': [self.proj.bias, self.cond_obs_emb.weight, self.cond_obs_emb.bias], 'weight_decay': 0.0},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@contextlib.contextmanager
|
||||||
|
def _stub_optional_modules(include_imf_head=False):
|
||||||
|
previous_modules = {}
|
||||||
|
|
||||||
|
def remember_and_remove(name):
|
||||||
|
if name not in previous_modules:
|
||||||
|
previous_modules[name] = sys.modules.get(name, _MISSING)
|
||||||
|
sys.modules.pop(name, None)
|
||||||
|
|
||||||
|
def inject(name, module):
|
||||||
|
if name not in previous_modules:
|
||||||
|
previous_modules[name] = sys.modules.get(name, _MISSING)
|
||||||
|
sys.modules[name] = module
|
||||||
|
|
||||||
|
diffusers_module = types.ModuleType('diffusers')
|
||||||
|
schedulers_module = types.ModuleType('diffusers.schedulers')
|
||||||
|
ddpm_module = types.ModuleType('diffusers.schedulers.scheduling_ddpm')
|
||||||
|
ddim_module = types.ModuleType('diffusers.schedulers.scheduling_ddim')
|
||||||
|
ddpm_module.DDPMScheduler = _FakeScheduler
|
||||||
|
ddim_module.DDIMScheduler = _FakeScheduler
|
||||||
|
diffusers_module.DDPMScheduler = _FakeScheduler
|
||||||
|
diffusers_module.DDIMScheduler = _FakeScheduler
|
||||||
|
diffusers_module.schedulers = schedulers_module
|
||||||
|
schedulers_module.scheduling_ddpm = ddpm_module
|
||||||
|
schedulers_module.scheduling_ddim = ddim_module
|
||||||
|
|
||||||
|
torchvision_module = types.ModuleType('torchvision')
|
||||||
|
models_module = types.ModuleType('torchvision.models')
|
||||||
|
transforms_module = types.ModuleType('torchvision.transforms')
|
||||||
|
torchvision_module.__spec__ = importlib.machinery.ModuleSpec('torchvision', loader=None)
|
||||||
|
models_module.__spec__ = importlib.machinery.ModuleSpec('torchvision.models', loader=None)
|
||||||
|
transforms_module.__spec__ = importlib.machinery.ModuleSpec('torchvision.transforms', loader=None)
|
||||||
|
models_module.resnet18 = lambda weights=None: _FakeResNet()
|
||||||
|
transforms_module.CenterCrop = _IdentityCrop
|
||||||
|
transforms_module.RandomCrop = _IdentityCrop
|
||||||
|
torchvision_module.models = models_module
|
||||||
|
torchvision_module.transforms = transforms_module
|
||||||
|
|
||||||
|
einops_module = types.ModuleType('einops')
|
||||||
|
einops_module.rearrange = lambda x, *args, **kwargs: x
|
||||||
|
einops_layers_module = types.ModuleType('einops.layers')
|
||||||
|
einops_layers_torch_module = types.ModuleType('einops.layers.torch')
|
||||||
|
einops_layers_torch_module.Rearrange = _FakeRearrange
|
||||||
|
einops_module.layers = einops_layers_module
|
||||||
|
einops_layers_module.torch = einops_layers_torch_module
|
||||||
|
|
||||||
|
transformers_module = types.ModuleType('transformers')
|
||||||
|
transformers_module.__spec__ = importlib.machinery.ModuleSpec('transformers', loader=None)
|
||||||
|
transformers_module.ViTConfig = _FakeViTConfig
|
||||||
|
transformers_module.ViTModel = _FakeViTModel
|
||||||
|
transformers_module.SiglipVisionModel = _FakeSiglipVisionModel
|
||||||
|
|
||||||
|
try:
|
||||||
|
remember_and_remove('roboimi.vla.models.backbones.siglip2_diffusion_backbone')
|
||||||
|
inject('diffusers', diffusers_module)
|
||||||
|
inject('diffusers.schedulers', schedulers_module)
|
||||||
|
inject('diffusers.schedulers.scheduling_ddpm', ddpm_module)
|
||||||
|
inject('diffusers.schedulers.scheduling_ddim', ddim_module)
|
||||||
|
inject('torchvision', torchvision_module)
|
||||||
|
inject('torchvision.models', models_module)
|
||||||
|
inject('torchvision.transforms', transforms_module)
|
||||||
|
inject('einops', einops_module)
|
||||||
|
inject('einops.layers', einops_layers_module)
|
||||||
|
inject('einops.layers.torch', einops_layers_torch_module)
|
||||||
|
inject('transformers', transformers_module)
|
||||||
|
|
||||||
|
if include_imf_head:
|
||||||
|
import roboimi.vla.models.heads as heads_package
|
||||||
|
|
||||||
|
imf_head_module = types.ModuleType('roboimi.vla.models.heads.imf_transformer1d')
|
||||||
|
imf_head_module.IMFTransformer1D = _StubIMFHead
|
||||||
|
inject('roboimi.vla.models.heads.imf_transformer1d', imf_head_module)
|
||||||
|
setattr(heads_package, 'imf_transformer1d', imf_head_module)
|
||||||
|
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
for name, previous in reversed(list(previous_modules.items())):
|
||||||
|
if previous is _MISSING:
|
||||||
|
sys.modules.pop(name, None)
|
||||||
|
else:
|
||||||
|
sys.modules[name] = previous
|
||||||
|
|
||||||
|
|
||||||
|
def _compose_cfg(overrides=None):
|
||||||
|
if not OmegaConf.has_resolver('len'):
|
||||||
|
OmegaConf.register_new_resolver('len', lambda x: len(x))
|
||||||
|
|
||||||
|
GlobalHydra.instance().clear()
|
||||||
|
with initialize_config_dir(version_base=None, config_dir=_CONFIG_DIR):
|
||||||
|
return compose(config_name='config', overrides=list(overrides or []))
|
||||||
|
|
||||||
|
|
||||||
|
def _load_imf_agent_class():
|
||||||
|
with _stub_optional_modules():
|
||||||
|
sys.modules.pop('roboimi.vla.agent_imf', None)
|
||||||
|
module = importlib.import_module('roboimi.vla.agent_imf')
|
||||||
|
return module.IMFVLAAgent, module
|
||||||
|
|
||||||
|
|
||||||
|
class _StubVisionBackbone(nn.Module):
|
||||||
|
output_dim = 1
|
||||||
|
|
||||||
|
def __init__(self, camera_names=_CAMERA_NAMES):
|
||||||
|
super().__init__()
|
||||||
|
self.camera_names = tuple(camera_names)
|
||||||
|
self.num_cameras = len(self.camera_names)
|
||||||
|
|
||||||
|
def forward(self, images):
|
||||||
|
per_camera_features = []
|
||||||
|
for camera_name in self.camera_names:
|
||||||
|
image_batch = images[camera_name]
|
||||||
|
per_camera_features.append(image_batch.mean(dim=(2, 3, 4), keepdim=False).unsqueeze(-1))
|
||||||
|
return torch.cat(per_camera_features, dim=-1)
|
||||||
|
|
||||||
|
|
||||||
|
class _StubJointVisionBackbone(nn.Module):
|
||||||
|
joint_output_dim = 5
|
||||||
|
output_dim = 5
|
||||||
|
|
||||||
|
def __init__(self, camera_names=_CAMERA_NAMES):
|
||||||
|
super().__init__()
|
||||||
|
self.camera_names = tuple(camera_names)
|
||||||
|
self.num_cameras = len(self.camera_names)
|
||||||
|
|
||||||
|
def forward(self, images):
|
||||||
|
batch_size, obs_horizon = next(iter(images.values())).shape[:2]
|
||||||
|
features = []
|
||||||
|
for camera_name in ('front', 'top', 'r_vis'):
|
||||||
|
image_batch = images[camera_name]
|
||||||
|
features.append(image_batch.mean(dim=(2, 3, 4), keepdim=False).unsqueeze(-1))
|
||||||
|
joint_features = torch.cat(features, dim=-1)
|
||||||
|
front_top_sum = joint_features[..., :2].sum(dim=-1, keepdim=True)
|
||||||
|
r_vis_minus_front = (joint_features[..., 2:] - joint_features[..., :1])
|
||||||
|
time_marker = torch.arange(obs_horizon, dtype=joint_features.dtype).view(1, obs_horizon, 1)
|
||||||
|
time_marker = time_marker.expand(batch_size, -1, -1)
|
||||||
|
return torch.cat([joint_features, front_top_sum, r_vis_minus_front + time_marker], dim=-1)
|
||||||
|
|
||||||
|
|
||||||
|
class _StubMultiTokenVisionBackbone(nn.Module):
    output_dim = 2
    tokens_per_step = 3

    def __init__(self, camera_names=_CAMERA_NAMES):
        super().__init__()
        self.camera_names = tuple(camera_names)
        self.num_cameras = len(self.camera_names)

    def forward(self, images):
        per_camera = []
        for camera_name in self.camera_names:
            image_batch = images[camera_name]
            base = image_batch.mean(dim=(2, 3, 4), keepdim=False)
            per_camera.append(torch.stack([base, base + 0.5], dim=-1))
        return torch.stack(per_camera, dim=2)


class _RecordingLinearIMFHead(nn.Module):
    def __init__(self):
        super().__init__()
        self.scale = nn.Parameter(torch.tensor(0.5))
        self.calls = []

    @staticmethod
    def _broadcast_batch_time(value, reference):
        while value.ndim < reference.ndim:
            value = value.unsqueeze(-1)
        return value

    def forward(self, sample, r, t, cond=None):
        record = {
            'sample': sample.detach().clone(),
            'r': r.detach().clone(),
            't': t.detach().clone(),
            'cond': None if cond is None else cond.detach().clone(),
        }
        self.calls.append(record)
        cond_term = 0.0
        if cond is not None:
            cond_term = cond.mean(dim=(1, 2), keepdim=True)
        r_b = self._broadcast_batch_time(r, sample)
        t_b = self._broadcast_batch_time(t, sample)
        return self.scale * sample + r_b + 2.0 * t_b + cond_term


class _ForbiddenScheduler:
    def set_timesteps(self, *args, **kwargs):  # pragma: no cover - only runs on regression
        raise AssertionError('IMF inference should not use DDIM scheduler set_timesteps')

    def step(self, *args, **kwargs):  # pragma: no cover - only runs on regression
        raise AssertionError('IMF inference should not use DDIM scheduler step')


def _make_images(batch_size, obs_horizon, per_camera_fill):
    return {
        name: torch.full((batch_size, obs_horizon, 1, 2, 2), fill_value=value, dtype=torch.float32)
        for name, value in per_camera_fill.items()
    }

class IMFVLAAgentTest(unittest.TestCase):
    def _make_agent(self, pred_horizon=3, obs_horizon=2, num_action_steps=2):
        agent_cls, agent_module = _load_imf_agent_class()
        head = _RecordingLinearIMFHead()
        agent = agent_cls(
            vision_backbone=_StubVisionBackbone(),
            state_encoder=nn.Identity(),
            action_encoder=nn.Identity(),
            head=head,
            action_dim=2,
            obs_dim=1,
            pred_horizon=pred_horizon,
            obs_horizon=obs_horizon,
            diffusion_steps=10,
            inference_steps=1,
            num_cams=len(_CAMERA_NAMES),
            camera_names=_CAMERA_NAMES,
            num_action_steps=num_action_steps,
            head_type='transformer',
        )
        return agent, head, agent_module

    def test_compute_loss_matches_imf_objective_and_masks_padded_actions(self):
        agent, head, agent_module = self._make_agent(pred_horizon=3, obs_horizon=2)
        images = _make_images(
            batch_size=1,
            obs_horizon=2,
            per_camera_fill={'r_vis': 1.0, 'top': 2.0, 'front': 3.0},
        )
        qpos = torch.tensor([[[0.25], [0.75]]], dtype=torch.float32)
        actions = torch.tensor(
            [[[1.0, -1.0], [0.5, 0.25], [-0.5, 1.5]]],
            dtype=torch.float32,
        )
        action_is_pad = torch.tensor([[False, False, True]])
        noise = torch.tensor(
            [[[0.2, -0.4], [0.1, 0.3], [0.5, -0.2]]],
            dtype=torch.float32,
        )
        t_sample = torch.tensor([0.8], dtype=torch.float32)
        r_sample = torch.tensor([0.25], dtype=torch.float32)

        with mock.patch.object(agent_module.torch, 'randn_like', return_value=noise), \
                mock.patch.object(agent_module.torch, 'rand', side_effect=[t_sample, r_sample]):
            loss = agent.compute_loss(
                {
                    'images': images,
                    'qpos': qpos,
                    'action': actions,
                    'action_is_pad': action_is_pad,
                }
            )

        cond = torch.tensor([[[1.0, 2.0, 3.0, 0.25], [1.0, 2.0, 3.0, 0.75]]], dtype=torch.float32)
        cond_term = cond.mean(dim=(1, 2), keepdim=True)
        t = t_sample
        r = r_sample
        z_t = (1 - t.view(1, 1, 1)) * actions + t.view(1, 1, 1) * noise
        scale = head.scale.detach()
        u = scale * z_t + r.view(1, 1, 1) + 2.0 * t.view(1, 1, 1) + cond_term
        v = scale * z_t + 3.0 * t.view(1, 1, 1) + cond_term
        du_dt = scale * v + 2.0
        compound_velocity = u + (t - r).view(1, 1, 1) * du_dt
        target = noise - actions
        elementwise_loss = (compound_velocity - target) ** 2
        mask = (~action_is_pad).unsqueeze(-1).to(elementwise_loss.dtype)
        expected_loss = (elementwise_loss * mask).sum() / (mask.sum() * elementwise_loss.shape[-1])

        self.assertAlmostEqual(loss.item(), expected_loss.item(), places=6)
        self.assertEqual(len(head.calls), 2)
        self.assertTrue(torch.allclose(head.calls[0]['r'], t_sample))
        self.assertTrue(torch.allclose(head.calls[0]['t'], t_sample))
        self.assertTrue(torch.allclose(head.calls[0]['cond'], cond))

    def test_predict_action_uses_one_step_imf_sampling_and_image_conditioning(self):
        agent, head, agent_module = self._make_agent(pred_horizon=3, obs_horizon=2)
        agent.infer_scheduler = _ForbiddenScheduler()

        images = _make_images(
            batch_size=2,
            obs_horizon=2,
            per_camera_fill={'r_vis': 10.0, 'top': 20.0, 'front': 30.0},
        )
        qpos = torch.tensor(
            [
                [[1.0], [2.0]],
                [[3.0], [4.0]],
            ],
            dtype=torch.float32,
        )
        initial_noise = torch.tensor(
            [
                [[1.0, -1.0], [0.0, 2.0], [3.0, -2.0]],
                [[-1.0, 1.0], [2.0, -3.0], [0.5, 0.25]],
            ],
            dtype=torch.float32,
        )

        with mock.patch.object(agent_module.torch, 'randn', return_value=initial_noise):
            predicted_actions = agent.predict_action(images, qpos)

        expected_cond = torch.tensor(
            [
                [[10.0, 20.0, 30.0, 1.0], [10.0, 20.0, 30.0, 2.0]],
                [[10.0, 20.0, 30.0, 3.0], [10.0, 20.0, 30.0, 4.0]],
            ],
            dtype=torch.float32,
        )
        cond_term = expected_cond.mean(dim=(1, 2), keepdim=True)
        expected_actions = 0.5 * initial_noise - 2.0 - cond_term

        self.assertEqual(predicted_actions.shape, (2, 3, 2))
        self.assertTrue(torch.allclose(predicted_actions, expected_actions))
        self.assertEqual(len(head.calls), 1)
        self.assertTrue(torch.allclose(head.calls[0]['r'], torch.zeros(2)))
        self.assertTrue(torch.allclose(head.calls[0]['t'], torch.ones(2)))
        self.assertTrue(torch.allclose(head.calls[0]['cond'], expected_cond))

    def test_select_action_only_regenerates_when_action_queue_is_empty(self):
        agent, _head, _agent_module = self._make_agent(pred_horizon=4, obs_horizon=2, num_action_steps=2)
        observation = {
            'qpos': torch.tensor([0.25], dtype=torch.float32),
            'images': {
                'front': torch.full((1, 2, 2), 3.0, dtype=torch.float32),
                'top': torch.full((1, 2, 2), 2.0, dtype=torch.float32),
                'r_vis': torch.full((1, 2, 2), 1.0, dtype=torch.float32),
            },
        }
        first_chunk = torch.tensor(
            [[[10.0, 11.0], [12.0, 13.0], [14.0, 15.0], [16.0, 17.0]]],
            dtype=torch.float32,
        )
        second_chunk = torch.tensor(
            [[[20.0, 21.0], [22.0, 23.0], [24.0, 25.0], [26.0, 27.0]]],
            dtype=torch.float32,
        )

        with mock.patch.object(agent, 'predict_action_chunk', side_effect=[first_chunk, second_chunk]) as mock_predict_chunk:
            first_action = agent.select_action(observation)
            second_action = agent.select_action(observation)
            third_action = agent.select_action(observation)

        self.assertTrue(torch.equal(first_action, first_chunk[0, 1]))
        self.assertTrue(torch.equal(second_action, first_chunk[0, 2]))
        self.assertTrue(torch.equal(third_action, second_chunk[0, 1]))
        self.assertEqual(mock_predict_chunk.call_count, 2)

    def test_joint_visual_backbone_uses_joint_output_dim_for_conditioning(self):
        agent_cls, _agent_module = _load_imf_agent_class()
        head = _RecordingLinearIMFHead()
        vision_backbone = _StubJointVisionBackbone()
        agent = agent_cls(
            vision_backbone=vision_backbone,
            state_encoder=nn.Identity(),
            action_encoder=nn.Identity(),
            head=head,
            action_dim=2,
            obs_dim=1,
            pred_horizon=3,
            obs_horizon=2,
            diffusion_steps=10,
            inference_steps=1,
            num_cams=len(_CAMERA_NAMES),
            camera_names=_CAMERA_NAMES,
            num_action_steps=2,
            head_type='transformer',
        )

        self.assertEqual(agent.per_step_cond_dim, vision_backbone.joint_output_dim + agent.obs_dim)
        self.assertEqual(
            agent.global_cond_dim,
            vision_backbone.joint_output_dim * agent.obs_horizon + agent.obs_dim * agent.obs_horizon,
        )

        images = _make_images(
            batch_size=1,
            obs_horizon=2,
            per_camera_fill={'r_vis': 10.0, 'top': 20.0, 'front': 30.0},
        )
        qpos = torch.tensor([[[1.0], [2.0]]], dtype=torch.float32)
        initial_noise = torch.tensor(
            [[[1.0, -1.0], [0.0, 2.0], [3.0, -2.0]]],
            dtype=torch.float32,
        )

        with mock.patch.object(torch, 'randn', return_value=initial_noise):
            predicted_actions = agent.predict_action(images, qpos)

        self.assertEqual(predicted_actions.shape, (1, 3, 2))
        self.assertEqual(len(head.calls), 1)
        expected_cond = torch.tensor(
            [[[30.0, 20.0, 10.0, 50.0, -20.0, 1.0], [30.0, 20.0, 10.0, 50.0, -19.0, 2.0]]],
            dtype=torch.float32,
        )
        self.assertEqual(head.calls[0]['cond'].shape[-1], 6)
        self.assertTrue(torch.allclose(head.calls[0]['cond'], expected_cond))

    def test_multitoken_visual_backbone_flattens_camera_tokens_and_projects_each_with_state(self):
        agent_cls, _agent_module = _load_imf_agent_class()
        head = _RecordingLinearIMFHead()
        projector = nn.Linear(3, 4, bias=False)
        with torch.no_grad():
            projector.weight.copy_(
                torch.tensor(
                    [
                        [1.0, 0.0, 0.0],
                        [0.0, 1.0, 0.0],
                        [0.0, 0.0, 1.0],
                        [1.0, 0.0, 1.0],
                    ],
                    dtype=torch.float32,
                )
            )
        agent = agent_cls(
            vision_backbone=_StubMultiTokenVisionBackbone(),
            state_encoder=nn.Identity(),
            action_encoder=nn.Identity(),
            head=head,
            action_dim=2,
            obs_dim=1,
            pred_horizon=3,
            obs_horizon=2,
            diffusion_steps=10,
            inference_steps=1,
            num_cams=len(_CAMERA_NAMES),
            camera_names=_CAMERA_NAMES,
            num_action_steps=2,
            head_type='transformer',
            cond_projector=projector,
        )

        self.assertEqual(agent.condition_tokens_per_step, 3)
        self.assertEqual(agent.condition_sequence_length, 6)
        self.assertEqual(agent.per_step_cond_dim, 4)
        self.assertEqual(agent.global_cond_dim, 24)

        images = _make_images(
            batch_size=1,
            obs_horizon=2,
            per_camera_fill={'r_vis': 10.0, 'top': 20.0, 'front': 30.0},
        )
        qpos = torch.tensor([[[1.0], [2.0]]], dtype=torch.float32)
        cond = agent._build_cond(images, qpos)

        expected = torch.tensor(
            [
                [
                    [10.0, 10.5, 1.0, 11.0],
                    [20.0, 20.5, 1.0, 21.0],
                    [30.0, 30.5, 1.0, 31.0],
                    [10.0, 10.5, 2.0, 12.0],
                    [20.0, 20.5, 2.0, 22.0],
                    [30.0, 30.5, 2.0, 32.0],
                ]
            ],
            dtype=torch.float32,
        )
        self.assertEqual(cond.shape, (1, 6, 4))
        self.assertTrue(torch.allclose(cond, expected))

    def test_multi_token_visual_backbone_pairs_state_per_camera_and_flattens_condition_sequence(self):
        agent_cls, agent_module = _load_imf_agent_class()
        head = _RecordingLinearIMFHead()
        cond_projector = nn.Linear(3, 4, bias=False)
        with torch.no_grad():
            cond_projector.weight.copy_(torch.tensor([
                [1.0, 0.0, 0.0],
                [0.0, 1.0, 0.0],
                [0.0, 0.0, 1.0],
                [1.0, 0.0, 1.0],
            ], dtype=torch.float32))

        agent = agent_cls(
            vision_backbone=_StubMultiTokenVisionBackbone(),
            state_encoder=nn.Identity(),
            action_encoder=nn.Identity(),
            head=head,
            action_dim=2,
            obs_dim=1,
            pred_horizon=3,
            obs_horizon=2,
            diffusion_steps=10,
            inference_steps=1,
            num_cams=len(_CAMERA_NAMES),
            camera_names=_CAMERA_NAMES,
            num_action_steps=2,
            head_type='transformer',
            cond_projector=cond_projector,
        )
        agent.infer_scheduler = _ForbiddenScheduler()

        images = _make_images(
            batch_size=1,
            obs_horizon=2,
            per_camera_fill={'r_vis': 10.0, 'top': 20.0, 'front': 30.0},
        )
        qpos = torch.tensor([[[1.0], [2.0]]], dtype=torch.float32)
        initial_noise = torch.tensor([[[1.0, -1.0], [0.0, 2.0], [3.0, -2.0]]], dtype=torch.float32)

        with mock.patch.object(agent_module.torch, 'randn', return_value=initial_noise):
            predicted_actions = agent.predict_action(images, qpos)

        expected_cond = torch.tensor([[[10.0, 10.5, 1.0, 11.0],
                                       [20.0, 20.5, 1.0, 21.0],
                                       [30.0, 30.5, 1.0, 31.0],
                                       [10.0, 10.5, 2.0, 12.0],
                                       [20.0, 20.5, 2.0, 22.0],
                                       [30.0, 30.5, 2.0, 32.0]]], dtype=torch.float32)

        self.assertEqual(agent.condition_tokens_per_step, 3)
        self.assertEqual(agent.condition_sequence_length, 6)
        self.assertEqual(agent.raw_per_step_cond_dim, 3)
        self.assertEqual(agent.per_step_cond_dim, 4)
        self.assertEqual(agent.global_cond_dim, 24)
        self.assertEqual(predicted_actions.shape, (1, 3, 2))
        self.assertEqual(len(head.calls), 1)
        self.assertEqual(head.calls[0]['cond'].shape, (1, 6, 4))
        self.assertTrue(torch.allclose(head.calls[0]['cond'], expected_cond))

    def test_hydra_config_instantiates_resnet_imf_attnres_with_stub_head(self):
        cfg = _compose_cfg(
            overrides=[
                'agent=resnet_imf_attnres',
                'agent.vision_backbone.pretrained_backbone_weights=null',
                'agent.vision_backbone.input_shape=[3,16,16]',
                'agent.vision_backbone.freeze_backbone=false',
                'agent.head.n_layer=1',
                'agent.head.n_emb=16',
            ]
        )

        self.assertEqual(cfg.agent._target_, 'roboimi.vla.agent_imf.IMFVLAAgent')
        self.assertEqual(cfg.agent.head._target_, 'roboimi.vla.models.heads.imf_transformer1d.IMFTransformer1D')
        self.assertEqual(cfg.agent.head.backbone_type, 'attnres_full')
        self.assertEqual(cfg.agent.head.n_head, 1)
        self.assertEqual(cfg.agent.head.n_kv_head, 1)
        self.assertEqual(cfg.agent.head.n_cond_layers, 0)
        self.assertTrue(cfg.agent.head.time_as_cond)
        self.assertFalse(cfg.agent.head.causal_attn)
        self.assertEqual(cfg.agent.inference_steps, 1)
        self.assertEqual(list(cfg.agent.camera_names), list(_CAMERA_NAMES))

        with _stub_optional_modules(include_imf_head=True):
            agent = instantiate(cfg.agent)

        self.assertEqual(agent.head_type, 'transformer')
        self.assertEqual(agent.per_step_cond_dim, agent.vision_encoder.output_dim * agent.num_cams + agent.obs_dim)
        self.assertIsInstance(agent.noise_pred_net, _StubIMFHead)
        self.assertEqual(agent.noise_pred_net.constructor_kwargs['cond_dim'], agent.per_step_cond_dim)
        self.assertEqual(agent.noise_pred_net.constructor_kwargs['backbone_type'], 'attnres_full')

    def test_hydra_config_instantiates_resnet_imf_attnres_with_full_attnres_vision_backbone(self):
        cfg = _compose_cfg(
            overrides=[
                'agent=resnet_imf_attnres',
                'agent.vision_backbone.vision_backbone_mode=attnres_resnet',
                'agent.vision_backbone.pretrained_backbone_weights=null',
                'agent.vision_backbone.input_shape=[3,56,56]',
                'agent.vision_backbone.freeze_backbone=false',
                'agent.vision_backbone.attnres_stem_dim=16',
                'agent.vision_backbone.attnres_stage_dims=[16,32,64,128]',
                'agent.vision_backbone.attnres_stage_depths=[1,1,1,1]',
                'agent.vision_backbone.attnres_stage_heads=[2,4,4,8]',
                'agent.vision_backbone.attnres_stage_kv_heads=[1,1,1,1]',
                'agent.vision_backbone.attnres_stage_window_sizes=[7,7,7,7]',
                'agent.head.n_layer=1',
                'agent.head.n_emb=16',
            ]
        )

        with _stub_optional_modules(include_imf_head=True):
            agent = instantiate(cfg.agent)

        self.assertEqual(agent.vision_encoder.output_dim, 64)
        self.assertEqual(agent.per_step_cond_dim, 64 * agent.num_cams + agent.obs_dim)
        self.assertEqual(agent.noise_pred_net.constructor_kwargs['cond_dim'], agent.per_step_cond_dim)

    def test_hydra_config_instantiates_lewm_imf_attnres_with_joint_visual_condition_dim(self):
        cfg = _compose_cfg(
            overrides=[
                'agent=lewm_imf_attnres',
                'agent.vision_backbone.checkpoint_path=null',
                'agent.head.n_layer=1',
                'agent.head.n_emb=16',
            ]
        )

        self.assertEqual(cfg.agent._target_, 'roboimi.vla.agent_imf.IMFVLAAgent')
        self.assertEqual(cfg.agent.vision_backbone._target_, 'roboimi.vla.models.backbones.lewm_vit_backbone.LEWMViTBackbone')
        self.assertEqual(list(cfg.agent.camera_names), list(_CAMERA_NAMES))
        self.assertEqual(list(cfg.agent.vision_backbone.camera_names), list(_CAMERA_NAMES))
        self.assertEqual(list(cfg.agent.vision_backbone.fused_camera_names), ['front', 'top', 'r_vis'])
        self.assertIsNone(cfg.agent.vision_backbone.dataset_image_resize_shape)
        self.assertEqual(list(cfg.agent.vision_backbone.eval_image_resize_shape), [256, 256])
        self.assertEqual(cfg.agent.head.cond_dim, 208)

        with _stub_optional_modules(include_imf_head=True):
            agent = instantiate(cfg.agent)

        self.assertEqual(agent.per_step_cond_dim, agent.vision_encoder.joint_output_dim + agent.obs_dim)
        self.assertEqual(agent.per_step_cond_dim, 208)
        self.assertEqual(agent.global_cond_dim, agent.obs_horizon * 208)
        self.assertIsNone(agent.vision_encoder.dataset_image_resize_shape)
        self.assertEqual(agent.vision_encoder.eval_image_resize_shape, (256, 256))
        self.assertIsInstance(agent.noise_pred_net, _StubIMFHead)
        self.assertEqual(agent.noise_pred_net.constructor_kwargs['cond_dim'], 208)

    def test_hydra_config_instantiates_resnet_imf_attnres_multitoken_with_projected_camera_tokens(self):
        cfg = _compose_cfg(
            overrides=[
                'agent=resnet_imf_attnres_multitoken',
                'agent.vision_backbone.pretrained_backbone_weights=null',
                'agent.vision_backbone.input_shape=[3,16,16]',
                'agent.head.n_layer=1',
                'agent.head.n_emb=32',
            ]
        )

        self.assertEqual(cfg.agent._target_, 'roboimi.vla.agent_imf.IMFVLAAgent')
        self.assertEqual(cfg.agent.vision_backbone.vision_backbone_mode, 'resnet')
        self.assertTrue(cfg.agent.vision_backbone.use_separate_rgb_encoder_per_camera)
        self.assertTrue(cfg.agent.vision_backbone.output_tokens_per_camera)
        self.assertEqual(cfg.agent.cond_projector.output_dim, 32)
        self.assertEqual(cfg.agent.head.cond_dim, 32)

        with _stub_optional_modules(include_imf_head=True):
            agent = instantiate(cfg.agent)

        self.assertEqual(agent.condition_tokens_per_step, 3)
        self.assertEqual(agent.condition_sequence_length, agent.obs_horizon * 3)
        self.assertEqual(agent.per_step_cond_dim, 32)
        self.assertEqual(agent.global_cond_dim, agent.condition_sequence_length * 32)
        self.assertIsInstance(agent.noise_pred_net, _StubIMFHead)
        self.assertEqual(agent.noise_pred_net.constructor_kwargs['cond_dim'], 32)
        self.assertEqual(agent.noise_pred_net.constructor_kwargs['n_obs_steps'], 6)

    def test_hydra_config_instantiates_siglip2_imf_attnres_with_condition_projection(self):
        cfg = _compose_cfg(
            overrides=[
                'agent=siglip2_imf_attnres',
                'agent.vision_backbone.per_view_output_dim=96',
                'agent.head.n_layer=1',
                'agent.head.n_emb=16',
                'agent.cond_projector.output_dim=384',
|
)
|
||||||
|
|
||||||
|
self.assertEqual(cfg.agent._target_, 'roboimi.vla.agent_imf.IMFVLAAgent')
|
||||||
|
self.assertEqual(
|
||||||
|
cfg.agent.vision_backbone._target_,
|
||||||
|
'roboimi.vla.models.backbones.siglip2_diffusion_backbone.SigLIP2DiffusionBackbone',
|
||||||
|
)
|
||||||
|
self.assertEqual(list(cfg.agent.camera_names), list(_CAMERA_NAMES))
|
||||||
|
self.assertIsNone(cfg.agent.vision_backbone.dataset_image_resize_shape)
|
||||||
|
self.assertEqual(list(cfg.agent.vision_backbone.eval_image_resize_shape), [256, 256])
|
||||||
|
self.assertEqual(cfg.agent.head.cond_dim, 384)
|
||||||
|
|
||||||
|
with _stub_optional_modules(include_imf_head=True):
|
||||||
|
agent = instantiate(cfg.agent)
|
||||||
|
|
||||||
|
self.assertEqual(agent.raw_per_step_cond_dim, 3 * 96 + agent.obs_dim)
|
||||||
|
self.assertEqual(agent.per_step_cond_dim, 384)
|
||||||
|
self.assertEqual(agent.global_cond_dim, agent.obs_horizon * 384)
|
||||||
|
self.assertEqual(agent.noise_pred_net.constructor_kwargs['cond_dim'], 384)
|
||||||
|
self.assertEqual(agent.vision_encoder.output_dim, 96)
|
||||||
|
self.assertEqual(agent.vision_encoder.eval_image_resize_shape, (256, 256))
|
||||||
|
|
||||||
|
|
||||||
|
def test_hydra_config_instantiates_resnet_imf_attnres_multitoken_with_sequence_length_three_times_obs_horizon(self):
|
||||||
|
cfg = _compose_cfg(
|
||||||
|
overrides=[
|
||||||
|
'agent=resnet_imf_attnres_multitoken',
|
||||||
|
'agent.vision_backbone.pretrained_backbone_weights=null',
|
||||||
|
'agent.vision_backbone.input_shape=[3,16,16]',
|
||||||
|
'agent.vision_backbone.freeze_backbone=false',
|
||||||
|
'agent.head.n_layer=1',
|
||||||
|
'agent.head.n_emb=16',
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(cfg.agent._target_, 'roboimi.vla.agent_imf.IMFVLAAgent')
|
||||||
|
self.assertEqual(list(cfg.agent.camera_names), list(_CAMERA_NAMES))
|
||||||
|
self.assertTrue(cfg.agent.vision_backbone.use_separate_rgb_encoder_per_camera)
|
||||||
|
self.assertTrue(cfg.agent.vision_backbone.output_tokens_per_camera)
|
||||||
|
self.assertEqual(cfg.agent.vision_backbone.vision_backbone_mode, 'resnet')
|
||||||
|
self.assertEqual(cfg.agent.cond_projector.output_dim, 16)
|
||||||
|
self.assertEqual(cfg.agent.head.cond_dim, 16)
|
||||||
|
|
||||||
|
with _stub_optional_modules(include_imf_head=True):
|
||||||
|
agent = instantiate(cfg.agent)
|
||||||
|
|
||||||
|
self.assertEqual(agent.condition_tokens_per_step, 3)
|
||||||
|
self.assertEqual(agent.condition_sequence_length, agent.obs_horizon * 3)
|
||||||
|
self.assertEqual(agent.per_step_cond_dim, 16)
|
||||||
|
self.assertEqual(agent.global_cond_dim, agent.condition_sequence_length * 16)
|
||||||
|
self.assertEqual(agent.vision_encoder.tokens_per_step, 3)
|
||||||
|
self.assertIsInstance(agent.noise_pred_net, _StubIMFHead)
|
||||||
|
self.assertEqual(agent.noise_pred_net.constructor_kwargs['cond_dim'], 16)
|
||||||
|
self.assertEqual(agent.noise_pred_net.constructor_kwargs['n_obs_steps'], agent.condition_sequence_length)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
unittest.main()
|
||||||
**tests/test_lewm_vit_backbone.py** (new file, 220 lines)
```python
import tempfile
import types
import unittest
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import ViTConfig, ViTModel


_INPUT_CAMERA_NAMES = ("r_vis", "top", "front")
_FUSED_CAMERA_NAMES = ("front", "top", "r_vis")


class _ReferenceProjector(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(192, 2048),
            nn.BatchNorm1d(2048),
            nn.GELU(),
            nn.Linear(2048, 192),
        )

    def forward(self, x):
        return self.net(x)


def _build_reference_encoder() -> ViTModel:
    return ViTModel(
        ViTConfig(
            image_size=224,
            patch_size=14,
            num_channels=3,
            hidden_size=192,
            intermediate_size=768,
            num_hidden_layers=12,
            num_attention_heads=3,
            qkv_bias=True,
        ),
        add_pooling_layer=False,
    )


def _write_synthetic_lightning_ckpt(path: Path):
    torch.manual_seed(7)
    encoder = _build_reference_encoder()
    projector = _ReferenceProjector()
    lightning_state_dict = {}
    for key, value in encoder.state_dict().items():
        lightning_state_dict[f"model.encoder.{key}"] = value.detach().clone()
    for key, value in projector.state_dict().items():
        lightning_state_dict[f"model.projector.{key}"] = value.detach().clone()
    torch.save({"state_dict": lightning_state_dict}, path)
    return encoder.state_dict(), projector.state_dict()


class LEWMViTBackboneTest(unittest.TestCase):
    def test_loads_lightning_encoder_and_projector_checkpoint_and_emits_joint_embedding(self):
        from roboimi.vla.models.backbones.lewm_vit_backbone import LEWMViTBackbone

        with tempfile.TemporaryDirectory() as tmpdir:
            ckpt_path = Path(tmpdir) / "synthetic-lewm.ckpt"
            reference_encoder_state, reference_projector_state = _write_synthetic_lightning_ckpt(
                ckpt_path
            )

            backbone = LEWMViTBackbone(
                checkpoint_path=ckpt_path,
                camera_names=_INPUT_CAMERA_NAMES,
                fused_camera_names=_FUSED_CAMERA_NAMES,
                freeze_backbone=True,
            )

        self.assertEqual(backbone.camera_names, _INPUT_CAMERA_NAMES)
        self.assertEqual(backbone.fused_camera_names, _FUSED_CAMERA_NAMES)
        self.assertEqual(backbone.num_cameras, 3)
        self.assertEqual(backbone.joint_output_dim, 192)
        self.assertEqual(backbone.output_dim, 192)
        self.assertEqual(backbone.encoder.config.hidden_size, 192)
        self.assertEqual(backbone.encoder.config.patch_size, 14)
        self.assertEqual(backbone.encoder.config.num_hidden_layers, 12)
        self.assertEqual(backbone.encoder.config.num_attention_heads, 3)

        for key, value in reference_encoder_state.items():
            self.assertTrue(torch.equal(backbone.encoder.state_dict()[key], value), key)
        for key, value in reference_projector_state.items():
            self.assertTrue(torch.equal(backbone.projector.state_dict()[key], value), key)

        images = {
            cam_name: torch.rand(1, 1, 3, 224, 224)
            for cam_name in _INPUT_CAMERA_NAMES
        }
        output = backbone(images)

        self.assertEqual(output.shape, (1, 1, 192))
        self.assertFalse(output.requires_grad)

    def test_forward_uses_front_top_rvis_fusion_order_and_exact_lewm_cwh_resize_path(self):
        from roboimi.vla.models.backbones.lewm_vit_backbone import LEWMViTBackbone

        with tempfile.TemporaryDirectory() as tmpdir:
            ckpt_path = Path(tmpdir) / "synthetic-lewm.ckpt"
            _write_synthetic_lightning_ckpt(ckpt_path)

            backbone = LEWMViTBackbone(
                checkpoint_path=ckpt_path,
                camera_names=_INPUT_CAMERA_NAMES,
                fused_camera_names=_FUSED_CAMERA_NAMES,
                freeze_backbone=True,
            )
        captured = {}

        def fake_encoder_forward(module, pixel_values, interpolate_pos_encoding=False, **kwargs):
            del module, kwargs
            captured["pixel_values"] = pixel_values.detach().clone()
            captured["interpolate_pos_encoding"] = interpolate_pos_encoding
            batch = pixel_values.shape[0]
            patch_tokens = (pixel_values.shape[-2] // 14) * (pixel_values.shape[-1] // 14)
            cls = (
                torch.arange(192, dtype=pixel_values.dtype, device=pixel_values.device)
                .unsqueeze(0)
                .expand(batch, -1)
            )
            last_hidden_state = torch.zeros(
                batch,
                patch_tokens + 1,
                192,
                dtype=pixel_values.dtype,
                device=pixel_values.device,
            )
            last_hidden_state[:, 0] = cls
            return types.SimpleNamespace(last_hidden_state=last_hidden_state)

        backbone.encoder.forward = types.MethodType(fake_encoder_forward, backbone.encoder)

        r_vis = torch.full((1, 1, 3, 256, 256), 0.30)
        top = torch.full((1, 1, 3, 256, 256), 0.20)
        front = torch.full((1, 1, 3, 256, 256), 0.10)
        bn = backbone.projector.net[1]
        running_mean_before = bn.running_mean.detach().clone()
        running_var_before = bn.running_var.detach().clone()

        backbone.train()
        self.assertFalse(backbone.encoder.training)
        self.assertFalse(backbone.projector.training)

        output = backbone({"r_vis": r_vis, "top": top, "front": front})

        self.assertEqual(output.shape, (1, 1, 192))
        self.assertEqual(captured["pixel_values"].shape, (1, 3, 672, 224))
        self.assertTrue(captured["interpolate_pos_encoding"])

        normalized_views = [
            ((view.reshape(-1, *view.shape[2:]).float()).clamp(0.0, 1.0) - backbone.mean) / backbone.std
            for view in (front, top, r_vis)
        ]
        expected_fuse_then_resize = F.interpolate(
            torch.cat(normalized_views, dim=-2),
            size=(672, 224),
            mode="bilinear",
            align_corners=False,
            antialias=True,
        )
        expected_pre_resize_then_fuse = torch.cat(
            [
                F.interpolate(
                    view,
                    size=(224, 224),
                    mode="bilinear",
                    align_corners=False,
                    antialias=True,
                )
                for view in normalized_views
            ],
            dim=-2,
        )

        self.assertTrue(
            torch.allclose(captured["pixel_values"], expected_fuse_then_resize, atol=1e-6, rtol=1e-6)
        )
        self.assertFalse(
            torch.allclose(
                expected_fuse_then_resize,
                expected_pre_resize_then_fuse,
                atol=1e-6,
                rtol=1e-6,
            )
        )
        self.assertFalse(
            torch.allclose(
                captured["pixel_values"],
                expected_pre_resize_then_fuse,
                atol=1e-6,
                rtol=1e-6,
            )
        )
        self.assertTrue(
            torch.allclose(
                captured["pixel_values"][0, :, 223, :],
                expected_fuse_then_resize[0, :, 223, :],
                atol=1e-6,
                rtol=1e-6,
            )
        )
        self.assertTrue(
            torch.allclose(
                captured["pixel_values"][0, :, 447, :],
                expected_fuse_then_resize[0, :, 447, :],
                atol=1e-6,
                rtol=1e-6,
            )
        )
        self.assertTrue(torch.equal(bn.running_mean, running_mean_before))
        self.assertTrue(torch.equal(bn.running_var, running_var_before))


if __name__ == "__main__":
    unittest.main()
```
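The resize-path test above distinguishes two preprocessing orders that yield the same output shape but different pixels. The following standalone sketch (plain PyTorch, independent of the repo's `LEWMViTBackbone`; all names here are illustrative) shows why the order matters: bilinear resampling of the fused strip blends pixels across the seam between camera tiles, which per-view resizing can never do.

```python
import torch
import torch.nn.functional as F


def fuse_then_resize(views, target_hw=(672, 224)):
    # Stack the camera views vertically, then resize the fused strip once.
    return F.interpolate(torch.cat(views, dim=-2), size=target_hw,
                         mode="bilinear", align_corners=False, antialias=True)


def resize_then_fuse(views, tile_hw=(224, 224)):
    # Resize each view to its tile size first, then stack vertically.
    return torch.cat([F.interpolate(v, size=tile_hw, mode="bilinear",
                                    align_corners=False, antialias=True)
                      for v in views], dim=-2)


# Three constant-valued "cameras", like the 0.10/0.20/0.30 fills in the test.
views = [torch.full((1, 3, 256, 256), fill) for fill in (0.10, 0.20, 0.30)]
a = fuse_then_resize(views)   # sampling kernel straddles the tile seams
b = resize_then_fuse(views)   # each constant tile stays exactly constant
print(a.shape == b.shape)     # True: both are (1, 3, 672, 224)
print(torch.allclose(a, b))   # False: rows near the seams are blended in `a`
```

This is why the test pins the exact rows 223 and 447: those sit at the tile boundaries, where the two orders disagree most.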
```diff
@@ -180,6 +180,14 @@ def _extract_camera_markers(cond, feature_dim, num_cams):
     return camera_block[:, 0]
 
 
+def _extract_token_camera_markers(tokens):
+    return tokens[0, 0, :, 0]
+
+
+def _extract_token_markers(token_sequence):
+    return token_sequence[0, 0, :, 0]
+
+
 class ResNetTransformerAgentWiringTest(unittest.TestCase):
     def test_hydra_wiring_uses_required_three_camera_transformer_conditioning_in_agent_order_and_ignores_extra_keys(self):
         cfg = _compose_cfg(
@@ -246,6 +254,36 @@ class ResNetTransformerAgentWiringTest(unittest.TestCase):
         with self.assertRaisesRegex(ValueError, 'missing=.*top'):
             agent.predict_action(missing_images, proprioception)
 
+    def test_multitoken_resnet_backbone_emits_one_token_per_camera_in_agent_order(self):
+        cfg = _compose_cfg(
+            overrides=[
+                'agent=resnet_imf_attnres_multitoken',
+                'agent.vision_backbone.pretrained_backbone_weights=null',
+                'agent.vision_backbone.input_shape=[3,16,16]',
+            ]
+        )
+
+        with _stub_optional_modules():
+            backbone = instantiate(cfg.agent.vision_backbone)
+            _patch_backbone_for_order_tracking(backbone)
+            images = _make_images(
+                batch_size=1,
+                obs_horizon=cfg.agent.obs_horizon,
+                image_shape=tuple(cfg.agent.vision_backbone.input_shape),
+                per_camera_fill={
+                    'front': 30.0,
+                    'top': 20.0,
+                    'r_vis': 10.0,
+                    'left_wrist': 99.0,
+                },
+            )
+            tokens = backbone(images)
+
+        self.assertEqual(tokens.shape, (1, cfg.agent.obs_horizon, 3, backbone.output_dim))
+        self.assertEqual(backbone.tokens_per_step, 3)
+        camera_markers = _extract_token_camera_markers(tokens)
+        self.assertTrue(torch.allclose(camera_markers, torch.tensor([10.0, 20.0, 30.0])))
+
     def test_agent_rejects_conflicting_explicit_backbone_camera_names(self):
         cfg = _compose_cfg(
             overrides=[
@@ -382,6 +420,36 @@ class ResNetTransformerAgentWiringTest(unittest.TestCase):
         with self.assertRaisesRegex(InstantiationException, 'num_cams'):
             instantiate(cfg.agent)
 
+    def test_multitoken_resnet_backbone_token_markers_follow_agent_order_with_head_overrides(self):
+        cfg = _compose_cfg(
+            overrides=[
+                'agent=resnet_imf_attnres_multitoken',
+                'agent.vision_backbone.pretrained_backbone_weights=null',
+                'agent.vision_backbone.input_shape=[3,16,16]',
+                'agent.head.n_layer=1',
+                'agent.head.n_emb=32',
+            ]
+        )
+
+        with _stub_optional_modules():
+            backbone = instantiate(cfg.agent.vision_backbone)
+            _patch_backbone_for_order_tracking(backbone)
+            images = _make_images(
+                batch_size=1,
+                obs_horizon=cfg.agent.obs_horizon,
+                image_shape=tuple(cfg.agent.vision_backbone.input_shape),
+                per_camera_fill={
+                    'front': 30.0,
+                    'top': 20.0,
+                    'r_vis': 10.0,
+                },
+            )
+            output = backbone(images)
+
+        self.assertEqual(output.shape, (1, cfg.agent.obs_horizon, 3, backbone.output_dim))
+        token_markers = _extract_token_markers(output)
+        self.assertTrue(torch.allclose(token_markers, torch.tensor([10.0, 20.0, 30.0])))
+
 
 if __name__ == '__main__':
     unittest.main()
```

Note: as extracted, the second added test reused the name `test_multitoken_resnet_backbone_emits_one_token_per_camera_in_agent_order`, which would silently shadow the first definition under `unittest`; it is renamed here to keep both tests runnable.
**tests/test_siglip2_diffusion_backbone.py** (new file, 121 lines)
```python
import types
import unittest
from unittest import mock

import torch
from torch import nn


_CAMERA_NAMES = ("r_vis", "top", "front")


class _FakeSiglipVisionOutput:
    def __init__(self, pooler_output):
        self.pooler_output = pooler_output


class _FakeSiglipVisionConfig:
    def __init__(self, hidden_size=768, image_size=256):
        self.hidden_size = hidden_size
        self.image_size = image_size


class _FakeSiglipVisionModel(nn.Module):
    def __init__(self, hidden_size=768):
        super().__init__()
        self.config = _FakeSiglipVisionConfig(hidden_size=hidden_size)
        self.forward_calls = []

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        del args, kwargs
        return cls()

    def forward(self, pixel_values=None, **kwargs):
        self.forward_calls.append({
            "pixel_values": pixel_values.detach().clone(),
            "kwargs": dict(kwargs),
        })
        pooled = pixel_values.mean(dim=(2, 3), keepdim=False)
        return _FakeSiglipVisionOutput(pooler_output=pooled)


class SigLIP2DiffusionBackboneTest(unittest.TestCase):
    def test_forward_encodes_each_view_independently_and_concatenates_projected_features(self):
        from roboimi.vla.models.backbones.siglip2_diffusion_backbone import SigLIP2DiffusionBackbone

        fake_model = _FakeSiglipVisionModel(hidden_size=3)
        with mock.patch(
            "roboimi.vla.models.backbones.siglip2_diffusion_backbone.SiglipVisionModel.from_pretrained",
            return_value=fake_model,
        ) as mock_from_pretrained:
            backbone = SigLIP2DiffusionBackbone(
                model_name="google/siglip2-base-patch16-256",
                camera_names=_CAMERA_NAMES,
                num_cameras=3,
                per_view_output_dim=2,
                freeze_backbone=True,
            )

        self.assertEqual(backbone.camera_names, _CAMERA_NAMES)
        self.assertEqual(backbone.num_cameras, 3)
        self.assertEqual(backbone.output_dim, 2)
        self.assertEqual(backbone.joint_output_dim, 6)
        self.assertIsNone(backbone.dataset_image_resize_shape)
        self.assertEqual(backbone.eval_image_resize_shape, (256, 256))
        mock_from_pretrained.assert_called_once_with("google/siglip2-base-patch16-256")
        self.assertTrue(all(not p.requires_grad for p in backbone.encoder.parameters()))
        self.assertFalse(backbone.encoder.training)

        with torch.no_grad():
            backbone.view_projector.weight.zero_()
            backbone.view_projector.bias.zero_()
            backbone.view_projector.weight[0, 0] = 1.0
            backbone.view_projector.weight[1, 1] = 1.0

        images = {
            "r_vis": torch.full((1, 2, 3, 256, 256), 0.25),
            "top": torch.full((1, 2, 3, 256, 256), 0.50),
            "front": torch.full((1, 2, 3, 256, 256), 0.75),
        }
        output = backbone(images)

        self.assertEqual(output.shape, (1, 2, 6))
        self.assertEqual(len(fake_model.forward_calls), 3)

        expected_per_camera = []
        for cam_name in _CAMERA_NAMES:
            img = images[cam_name].reshape(2, 3, 256, 256)
            normalized = (img - 0.5) / 0.5
            expected_per_camera.append(normalized.mean(dim=(2, 3))[:, :2])
        expected = torch.cat(expected_per_camera, dim=-1).view(1, 2, 6)
        self.assertTrue(torch.allclose(output, expected, atol=1e-6, rtol=1e-6))

        for call, cam_name in zip(fake_model.forward_calls, _CAMERA_NAMES):
            pixels = call["pixel_values"]
            self.assertEqual(tuple(pixels.shape), (2, 3, 256, 256))
            self.assertTrue(
                torch.allclose(
                    pixels,
                    (images[cam_name].reshape(2, 3, 256, 256) - 0.5) / 0.5,
                )
            )

    def test_forward_rejects_missing_required_camera(self):
        from roboimi.vla.models.backbones.siglip2_diffusion_backbone import SigLIP2DiffusionBackbone

        backbone = SigLIP2DiffusionBackbone(
            vision_model=_FakeSiglipVisionModel(hidden_size=4),
            camera_names=_CAMERA_NAMES,
            num_cameras=3,
        )

        with self.assertRaisesRegex(ValueError, "missing"):
            backbone({
                "r_vis": torch.rand(1, 1, 3, 256, 256),
                "top": torch.rand(1, 1, 3, 256, 256),
            })


if __name__ == "__main__":
    unittest.main()
```
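The SigLIP2 test above pins down a common multi-camera pattern: encode each view independently to a pooled feature, project it with one shared linear layer, and concatenate in a fixed camera order. A toy, self-contained sketch of that wiring (mean-pooling stands in for the real vision encoder; class and method names here are illustrative, not the repo's API):

```python
import torch
from torch import nn


class PerViewEncodeProject(nn.Module):
    """Pool each camera frame, project with a shared Linear, concat in camera order."""

    def __init__(self, camera_names, in_dim=3, out_dim=2):
        super().__init__()
        self.camera_names = tuple(camera_names)
        self.proj = nn.Linear(in_dim, out_dim)

    def forward(self, images):
        # Reject inputs that lack a required camera, as the backbone test expects.
        missing = [c for c in self.camera_names if c not in images]
        if missing:
            raise ValueError(f"missing cameras: {missing}")
        feats = []
        for cam in self.camera_names:
            x = images[cam]                              # (B, T, C, H, W)
            b, t = x.shape[:2]
            pooled = x.flatten(0, 1).mean(dim=(2, 3))    # (B*T, C): toy "pooler output"
            feats.append(self.proj(pooled).view(b, t, -1))
        return torch.cat(feats, dim=-1)                  # (B, T, num_cams * out_dim)


model = PerViewEncodeProject(("r_vis", "top", "front"))
out = model({c: torch.rand(1, 2, 3, 8, 8) for c in ("r_vis", "top", "front")})
print(tuple(out.shape))  # (1, 2, 6) — 3 cameras x out_dim=2 per observation step
```

Because concatenation follows `self.camera_names`, the joint feature layout is deterministic, which is what lets the test pre-compute the exact expected tensor.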
```diff
@@ -56,3 +56,26 @@ class SimpleRobotDatasetImageLoadingTest(unittest.TestCase):
 
         self.assertEqual(len(resize_calls), 2)
         self.assertEqual(tuple(sample["observation.front"].shape), (2, 3, 8, 8))
+
+    def test_getitem_skips_resize_when_image_resize_shape_is_none(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            dataset_dir = Path(tmpdir)
+            self._write_episode(dataset_dir)
+            dataset = SimpleRobotDataset(
+                dataset_dir,
+                obs_horizon=2,
+                pred_horizon=3,
+                camera_names=["front"],
+                image_resize_shape=None,
+            )
+
+            fake_cv2 = types.SimpleNamespace(
+                INTER_LINEAR=1,
+                resize=mock.Mock(side_effect=AssertionError("resize should be skipped when image_resize_shape=None")),
+            )
+
+            with mock.patch.dict(sys.modules, {"cv2": fake_cv2}):
+                sample = dataset[1]
+
+            fake_cv2.resize.assert_not_called()
+            self.assertEqual(tuple(sample["observation.front"].shape), (2, 3, 8, 8))
```
```diff
@@ -159,6 +159,92 @@ class TrainVLARolloutValidationTest(unittest.TestCase):
         self.assertGreater(cfg.train.num_workers, 8)
         self.assertEqual(cfg.train.rollout_val_freq_epochs, 50)
 
+    def test_training_passes_backbone_image_resize_override_to_dataset_instantiation(self):
+        cfg = OmegaConf.create(
+            {
+                'agent': {
+                    'vision_backbone': {
+                        'dataset_image_resize_shape': None,
+                    },
+                    'normalization_type': 'min_max',
+                },
+                'data': {
+                    'dataset_dir': 'unused',
+                    'camera_names': ['front'],
+                },
+                'train': {
+                    'batch_size': 2,
+                    'lr': 1e-4,
+                    'max_steps': 0,
+                    'device': 'cpu',
+                    'disable_cudnn': False,
+                    'num_workers': 0,
+                    'val_split': 0.0,
+                    'seed': 42,
+                    'log_freq': 1,
+                    'save_freq': 10,
+                    'use_swanlab': False,
+                    'rollout_val_freq_epochs': 0,
+                    'rollout_validate_on_checkpoint': False,
+                    'rollout_num_episodes': 1,
+                    'warmup_steps': 1,
+                    'scheduler_type': 'constant',
+                    'min_lr': 1e-6,
+                    'weight_decay': 1e-5,
+                    'grad_clip': 1.0,
+                    'pretrained_ckpt': None,
+                },
+                'eval': {
+                    'ckpt_path': 'unused.pt',
+                    'num_episodes': 1,
+                    'headless': True,
+                    'device': 'cpu',
+                    'verbose_action': False,
+                },
+                'experiment': {},
+            }
+        )
+        captured_dataset_kwargs = {}
+
+        def fake_instantiate(config_node, **kwargs):
+            if config_node is cfg.data:
+                captured_dataset_kwargs.update(kwargs)
+                return _FakeDataset()
+            if config_node is cfg.agent:
+                return _FakeAgent()
+            raise AssertionError(f'unexpected instantiate config: {config_node!r}')
+
+        def fake_dataloader(_dataset, *, shuffle, **_kwargs):
+            del shuffle, _kwargs
+            return _FakeLoader(
+                {
+                    'observation.front': torch.zeros(1, 3, 2, 2),
+                    'observation.state': torch.zeros(1, 4),
+                    'action': torch.zeros(1, 2),
+                    'action_is_pad': torch.zeros(1, 1, dtype=torch.bool),
+                },
+                length=1,
+            )
+
+        with tempfile.TemporaryDirectory() as tempdir:
+            previous_cwd = os.getcwd()
+            try:
+                os.chdir(tempdir)
+                with mock.patch.object(train_vla, 'instantiate', side_effect=fake_instantiate), \
+                     mock.patch.object(train_vla, 'DataLoader', side_effect=fake_dataloader), \
+                     mock.patch.object(train_vla, 'build_training_optimizer', return_value=_FakeOptimizer(cfg.train.lr)), \
+                     mock.patch.object(train_vla, 'get_lr_schedule_with_warmup', return_value=_FakeScheduler()), \
+                     mock.patch.object(train_vla, 'tqdm', side_effect=lambda iterable, **kwargs: _FakeProgressBar(iterable)), \
+                     mock.patch.object(train_vla, '_init_swanlab', return_value=None), \
+                     mock.patch.object(train_vla, '_finish_swanlab', return_value=None), \
+                     mock.patch.object(train_vla.torch, 'save', return_value=None):
+                    train_vla._run_training(cfg)
+            finally:
+                os.chdir(previous_cwd)
+
+        self.assertIn('image_resize_shape', captured_dataset_kwargs)
+        self.assertIsNone(captured_dataset_kwargs['image_resize_shape'])
+
     def test_eval_main_delegates_to_plain_run_eval_helper(self):
         cfg = OmegaConf.create(
             {
```
```diff
@@ -234,7 +320,28 @@ class TrainVLARolloutValidationTest(unittest.TestCase):
             }
         )
         agent = _FakeAgent()
-        rollout_mock = mock.Mock(side_effect=[{'avg_reward': 2.0}, {'avg_reward': 1.0}])
+        rollout_mock = mock.Mock(
+            side_effect=[
+                {
+                    'avg_reward': 2.0,
+                    'episodes': [
+                        {
+                            'episode_index': 0,
+                            'artifact_paths': {'trajectory_image': 'artifacts/epoch_49_front.png'},
+                        },
+                    ],
+                },
+                {
+                    'avg_reward': 1.0,
+                    'episodes': [
+                        {
+                            'episode_index': 0,
+                            'artifact_paths': {'trajectory_image': 'artifacts/epoch_99_front.png'},
+                        },
+                    ],
+                },
+            ]
+        )
         swanlab_log_mock = mock.Mock()
         saved_checkpoints = []
 
@@ -281,17 +388,22 @@ class TrainVLARolloutValidationTest(unittest.TestCase):
         self.assertEqual(rollout_mock.call_count, 2)
         first_rollout_cfg = rollout_mock.call_args_list[0].args[0]
         second_rollout_cfg = rollout_mock.call_args_list[1].args[0]
-        self.assertEqual(first_rollout_cfg.eval.ckpt_path, 'checkpoints/vla_model_step_49.pt')
-        self.assertEqual(second_rollout_cfg.eval.ckpt_path, 'checkpoints/vla_model_step_99.pt')
+        self.assertTrue(first_rollout_cfg.eval.ckpt_path.endswith('checkpoints/vla_model_step_49.pt'))
+        self.assertTrue(second_rollout_cfg.eval.ckpt_path.endswith('checkpoints/vla_model_step_99.pt'))
         self.assertEqual(first_rollout_cfg.eval.num_episodes, 3)
         self.assertTrue(first_rollout_cfg.eval.headless)
         self.assertEqual(first_rollout_cfg.eval.device, 'cpu')
         self.assertFalse(first_rollout_cfg.eval.verbose_action)
+        self.assertFalse(first_rollout_cfg.eval.record_video)
+        self.assertTrue(first_rollout_cfg.eval.save_trajectory_image)
+        self.assertEqual(first_rollout_cfg.eval.trajectory_image_camera_name, 'front')
         self.assertEqual(cfg.eval.ckpt_path, 'unused.pt')
         self.assertEqual(cfg.eval.num_episodes, 99)
         self.assertFalse(cfg.eval.headless)
         self.assertEqual(cfg.eval.device, 'cpu')
         self.assertFalse(cfg.eval.verbose_action)
+        self.assertNotIn('save_trajectory_image', cfg.eval)
+        self.assertNotIn('trajectory_image_camera_name', cfg.eval)
 
         rollout_reward_logs = [
             call.args[1]['rollout/avg_reward']
@@ -769,10 +881,8 @@ class TrainVLARolloutValidationTest(unittest.TestCase):
                 'dataset_len': 1,
             },
         )
-        self.assertEqual(
-            [path for path, _payload in saved_checkpoints],
-            ['checkpoints/vla_model_final.pt'],
-        )
+        self.assertEqual(len(saved_checkpoints), 1)
+        self.assertTrue(saved_checkpoints[0][0].endswith('checkpoints/vla_model_final.pt'))
 
 
 if __name__ == '__main__':
```
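The training-side tests above require per-episode `artifact_paths['trajectory_image']` entries from the rollout summary to reach SwanLab as media. The repo's actual `_log_rollout_trajectory_images_to_swanlab` is not shown in full in this diff; the following is one plausible, hypothetical sketch of such an upload helper, written against a minimal stub run object (only `.Image(path, caption=...)` and `.log(data)` are assumed, mirroring the stub surface the tests use):

```python
def log_rollout_trajectory_images(run, episodes):
    """Upload each episode's trajectory image; one failed upload must not abort training."""
    media = []
    for episode in episodes:
        path = (episode.get("artifact_paths") or {}).get("trajectory_image")
        if not path:
            continue  # this episode produced no trajectory image
        try:
            media.append(run.Image(path, caption=f"episode {episode.get('episode_index')}"))
        except Exception:
            continue  # swallow per-image upload errors so training keeps running
    if media:
        run.log({"rollout/trajectory_image": media})


class _StubRun:
    # Minimal stand-in recording calls, similar in spirit to the tests' FakeSwanLab.
    def __init__(self):
        self.images, self.logged = [], []

    def Image(self, path, caption=None):
        self.images.append((path, caption))
        return {"path": path, "caption": caption}

    def log(self, data):
        self.logged.append(data)


run = _StubRun()
log_rollout_trajectory_images(
    run,
    [{"episode_index": 0, "artifact_paths": {"trajectory_image": "artifacts/epoch_49_front.png"}},
     {"episode_index": 1, "artifact_paths": {}}],
)
print(len(run.images), len(run.logged))  # 1 1 — one image uploaded, one log call
```

The per-episode `try/except` matches the error-tolerant behavior the `image_errors` hook in the FakeSwanLab stub below is designed to exercise.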
@@ -115,13 +115,15 @@ class FakeAgent(nn.Module):


 class FakeSwanLab:
-    def __init__(self, init_error=None, log_errors=None, finish_error=None):
+    def __init__(self, init_error=None, log_errors=None, finish_error=None, image_errors=None):
         self.init_error = init_error
         self.log_errors = list(log_errors or [])
         self.finish_error = finish_error
+        self.image_errors = list(image_errors or [])
         self.init_calls = []
         self.log_calls = []
         self.finish_calls = 0
+        self.image_calls = []

     def init(self, project, experiment_name=None, config=None):
         self.init_calls.append({
@@ -138,6 +140,18 @@ class FakeSwanLab:
         if self.log_errors:
             raise self.log_errors.pop(0)

+    def Image(self, path, caption=None):
+        self.image_calls.append({
+            'path': path,
+            'caption': caption,
+        })
+        if self.image_errors:
+            raise self.image_errors.pop(0)
+        return {
+            'path': path,
+            'caption': caption,
+        }
+
     def finish(self):
         self.finish_calls += 1
         if self.finish_error is not None:
@@ -149,6 +163,119 @@ class TrainVLASwanLabLoggingTest(unittest.TestCase):
         config_text = _CONFIG_PATH.read_text(encoding='utf-8')
         self.assertIn('use_swanlab: false', config_text)

+    def test_log_rollout_trajectory_images_to_swanlab_uploads_episode_artifacts(self):
+        module = self._load_train_vla_module()
+        fake_swanlab = FakeSwanLab()
+
+        module._log_rollout_trajectory_images_to_swanlab(
+            fake_swanlab,
+            {
+                'episodes': [
+                    {
+                        'episode_index': 0,
+                        'artifact_paths': {'trajectory_image': 'artifacts/episode_0_front.png'},
+                    },
+                    {
+                        'episode_index': 3,
+                        'artifact_paths': {'trajectory_image': 'artifacts/episode_3_front.png'},
+                    },
+                    {
+                        'episode_index': 7,
+                        'artifact_paths': {'trajectory_image': None},
+                    },
+                    {
+                        'episode_index': 8,
+                        'artifact_paths': {},
+                    },
+                ],
+            },
+            step=12,
+            context_label='epoch 1 rollout',
+        )
+
+        self.assertEqual(
+            fake_swanlab.image_calls,
+            [
+                {
+                    'path': 'artifacts/episode_0_front.png',
+                    'caption': 'epoch 1 rollout trajectory image - episode 0 (front)',
+                },
+                {
+                    'path': 'artifacts/episode_3_front.png',
+                    'caption': 'epoch 1 rollout trajectory image - episode 3 (front)',
+                },
+            ],
+        )
+        self.assertIn(
+            (
+                {
+                    'rollout/trajectory_image_episode_0': {
+                        'path': 'artifacts/episode_0_front.png',
+                        'caption': 'epoch 1 rollout trajectory image - episode 0 (front)',
+                    },
+                    'rollout/trajectory_image_episode_3': {
+                        'path': 'artifacts/episode_3_front.png',
+                        'caption': 'epoch 1 rollout trajectory image - episode 3 (front)',
+                    },
+                },
+                12,
+            ),
+            fake_swanlab.log_calls,
+        )
+
+    def test_log_rollout_trajectory_images_to_swanlab_is_best_effort(self):
+        module = self._load_train_vla_module()
+        fake_swanlab = FakeSwanLab(image_errors=[RuntimeError('decode failed')])
+
+        with mock.patch.object(module.log, 'warning') as warning_mock:
+            module._log_rollout_trajectory_images_to_swanlab(
+                fake_swanlab,
+                {
+                    'episodes': [
+                        {
+                            'episode_index': 0,
+                            'artifact_paths': {'trajectory_image': 'artifacts/bad_episode.png'},
+                        },
+                        {
+                            'episode_index': 1,
+                            'artifact_paths': {'trajectory_image': 'artifacts/good_episode.png'},
+                        },
+                    ],
+                },
+                step=7,
+                context_label='checkpoint rollout',
+            )
+
+        self.assertEqual(
+            fake_swanlab.image_calls,
+            [
+                {
+                    'path': 'artifacts/bad_episode.png',
+                    'caption': 'checkpoint rollout trajectory image - episode 0 (front)',
+                },
+                {
+                    'path': 'artifacts/good_episode.png',
+                    'caption': 'checkpoint rollout trajectory image - episode 1 (front)',
+                },
+            ],
+        )
+        self.assertIn(
+            (
+                {
+                    'rollout/trajectory_image_episode_1': {
+                        'path': 'artifacts/good_episode.png',
+                        'caption': 'checkpoint rollout trajectory image - episode 1 (front)',
+                    },
+                },
+                7,
+            ),
+            fake_swanlab.log_calls,
+        )
+        warning_messages = [call.args[0] for call in warning_mock.call_args_list]
+        self.assertTrue(
+            any('SwanLab rollout trajectory image upload prep failed' in message for message in warning_messages)
+        )
+
     def _load_train_vla_module(self):
         hydra_module = types.ModuleType('hydra')
         hydra_utils_module = types.ModuleType('hydra.utils')
@@ -356,8 +483,8 @@ class TrainVLASwanLabLoggingTest(unittest.TestCase):

         final_payload, final_step = fake_swanlab.log_calls[-1]
         self.assertEqual(final_step, cfg.train.max_steps)
-        self.assertEqual(final_payload['final/checkpoint_path'], 'checkpoints/vla_model_final.pt')
-        self.assertEqual(final_payload['final/best_checkpoint_path'], 'checkpoints/vla_model_best.pt')
+        self.assertTrue(final_payload['final/checkpoint_path'].endswith('checkpoints/vla_model_final.pt'))
+        self.assertTrue(final_payload['final/best_checkpoint_path'].endswith('checkpoints/vla_model_best.pt'))
         self.assertEqual(fake_swanlab.finish_calls, 1)

     def test_run_training_skips_swanlab_when_disabled(self):
@@ -512,10 +639,10 @@ class TrainVLASwanLabLoggingTest(unittest.TestCase):

         def fake_torch_load(path, map_location=None):
             del map_location
-            path = Path(path)
-            if path == resume_path:
+            path = Path(path).resolve()
+            if path == resume_path.resolve():
                 return resume_checkpoint_state
-            if path == best_path:
+            if path == best_path.resolve():
                 return best_checkpoint_state
             raise AssertionError(f'unexpected load path: {path}')

@@ -538,8 +665,8 @@ class TrainVLASwanLabLoggingTest(unittest.TestCase):

         final_payload, final_step = fake_swanlab.log_calls[-1]
         self.assertEqual(final_step, cfg.train.max_steps)
-        self.assertEqual(final_payload['final/best_checkpoint_path'], 'checkpoints/vla_model_best.pt')
-        self.assertNotIn('checkpoints/vla_model_best.pt', saved_paths)
+        self.assertTrue(final_payload['final/best_checkpoint_path'].endswith('checkpoints/vla_model_best.pt'))
+        self.assertFalse(any(path.endswith('checkpoints/vla_model_best.pt') for path in saved_paths))

     def test_run_training_resume_ignores_best_checkpoint_without_rollout_metric(self):
         module = self._load_train_vla_module()
@@ -594,10 +721,10 @@ class TrainVLASwanLabLoggingTest(unittest.TestCase):

         def fake_torch_load(path, map_location=None):
             del map_location
-            path = Path(path)
-            if path == resume_path:
+            path = Path(path).resolve()
+            if path == resume_path.resolve():
                 return resume_checkpoint_state
-            if path == best_path:
+            if path == best_path.resolve():
                 return stale_best_checkpoint_state
             raise AssertionError(f'unexpected load path: {path}')

@@ -101,10 +101,19 @@ class RecordingTransformerHead(nn.Module):
     ]


-class FakeTransformerAgent(nn.Module):
+class FakeIMFAgent(nn.Module):
     def __init__(self):
         super().__init__()
-        self.head_type = 'transformer'
+        self.head_type = 'imf_transformer'
+        self.noise_pred_net = RecordingTransformerHead()
+        self.backbone = nn.Linear(4, 3)
+        self.adapter = nn.Linear(3, 2, bias=False)
+
+
+class FakeTransformerAgent(nn.Module):
+    def __init__(self, *, head_type='transformer'):
+        super().__init__()
+        self.head_type = head_type
         self.noise_pred_net = RecordingTransformerHead()
         self.backbone = nn.Linear(4, 3)
         self.adapter = nn.Linear(3, 2, bias=False)
@@ -205,6 +214,95 @@ class TrainVLATransformerOptimizerTest(unittest.TestCase):
             for group in optimizer.param_groups
         ]

+    def test_clean_ld_preload_value_removes_problematic_nxegl_entry(self):
+        module = self._load_train_vla_module()
+
+        cleaned, changed = module._clean_ld_preload_value(
+            '/usr/lib/libfoo.so /usr/NX/lib/libnxegl.so /usr/lib/libbar.so'
+        )
+
+        self.assertTrue(changed)
+        self.assertEqual(cleaned, '/usr/lib/libfoo.so /usr/lib/libbar.so')
+
+    def test_clean_ld_preload_value_leaves_safe_entries_unchanged(self):
+        module = self._load_train_vla_module()
+
+        cleaned, changed = module._clean_ld_preload_value('/usr/lib/libfoo.so /usr/lib/libbar.so')
+
+        self.assertFalse(changed)
+        self.assertEqual(cleaned, '/usr/lib/libfoo.so /usr/lib/libbar.so')
+
+    def test_configure_cuda_runtime_can_disable_cudnn_for_training(self):
+        module = self._load_train_vla_module()
+        cfg = AttrDict(train=AttrDict(device='cuda', disable_cudnn=True))
+
+        original = module.torch.backends.cudnn.enabled
+        try:
+            module.torch.backends.cudnn.enabled = True
+            module._configure_cuda_runtime(cfg)
+            self.assertFalse(module.torch.backends.cudnn.enabled)
+        finally:
+            module.torch.backends.cudnn.enabled = original
+
+    def test_resolve_run_output_dir_prefers_hydra_runtime_output_dir(self):
+        module = self._load_train_vla_module()
+        hydra_core_module = types.ModuleType('hydra.core')
+        hydra_hydra_config_module = types.ModuleType('hydra.core.hydra_config')
+
+        class _Runtime:
+            output_dir = '/tmp/hydra-output'
+
+        class _Cfg:
+            runtime = _Runtime()
+
+        class HydraConfigStub:
+            @staticmethod
+            def initialized():
+                return True
+
+            @staticmethod
+            def get():
+                return _Cfg()
+
+        hydra_hydra_config_module.HydraConfig = HydraConfigStub
+        with mock.patch.dict(sys.modules, {
+            'hydra.core': hydra_core_module,
+            'hydra.core.hydra_config': hydra_hydra_config_module,
+        }):
+            output_dir = module._resolve_run_output_dir()
+
+        self.assertEqual(Path(output_dir).resolve(), Path('/tmp/hydra-output').resolve())
+
+    def test_train_script_uses_file_based_repo_root_on_sys_path(self):
+        module = self._load_train_vla_module()
+
+        fake_sys_path = ['/tmp/site-packages', '/another/path']
+        with mock.patch.object(module.sys, 'path', fake_sys_path):
+            repo_root = module._ensure_repo_root_on_syspath()
+
+        self.assertEqual(Path(repo_root).resolve(), _REPO_ROOT.resolve())
+        self.assertEqual(Path(fake_sys_path[0]).resolve(), _REPO_ROOT.resolve())
+
+    def test_non_transformer_head_with_get_optim_groups_still_uses_custom_groups(self):
+        module = self._load_train_vla_module()
+        agent = FakeIMFAgent()
+
+        optimizer = module.build_training_optimizer(agent, lr=1e-4, weight_decay=0.123)
+
+        self.assertEqual(agent.noise_pred_net.optim_group_calls, [0.123])
+        group_names = self._group_names(agent, optimizer)
+        self.assertEqual(group_names[0], {'noise_pred_net.proj.weight'})
+        self.assertEqual(group_names[1], {
+            'noise_pred_net.proj.bias',
+            'noise_pred_net.norm.weight',
+            'noise_pred_net.norm.bias',
+        })
+        self.assertEqual(group_names[2], {'backbone.weight', 'backbone.bias', 'adapter.weight'})
+
     def test_transformer_training_prefers_head_optim_groups_and_keeps_remaining_trainable_params(self):
         module = self._load_train_vla_module()
         agent = FakeTransformerAgent()
@@ -268,6 +366,22 @@ class TrainVLATransformerOptimizerTest(unittest.TestCase):
         self.assertNotIn('frozen.weight', optimizer_names)
         self.assertNotIn('frozen.bias', optimizer_names)

+    def test_any_head_with_get_optim_groups_uses_custom_groups_even_without_transformer_head_type(self):
+        module = self._load_train_vla_module()
+        agent = FakeTransformerAgent(head_type='imf')
+
+        with mock.patch.object(module, 'AdamW', RecordingAdamW):
+            optimizer = module.build_training_optimizer(agent, lr=1e-4, weight_decay=0.123)
+
+        self.assertEqual(agent.noise_pred_net.optim_group_calls, [0.123])
+        grouped_names = self._group_names(agent, optimizer)
+        self.assertEqual(grouped_names[0], {'noise_pred_net.proj.weight'})
+        self.assertEqual(
+            grouped_names[1],
+            {'noise_pred_net.proj.bias', 'noise_pred_net.norm.weight', 'noise_pred_net.norm.bias'},
+        )
+        self.assertEqual(grouped_names[2], {'backbone.weight', 'backbone.bias', 'adapter.weight'})
+
     def test_transformer_optimizer_ignores_frozen_head_params_returned_by_head_groups(self):
         module = self._load_train_vla_module()
         agent = FakeTransformerAgent()