feat: add full attnres vision backbone

Logic
2026-04-05 00:07:59 +08:00
parent a78006808a
commit 2033169840
9 changed files with 546 additions and 39 deletions


@@ -0,0 +1,64 @@
# Phase-2 Full-AttnRes Vision Implementation Plan
> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking.
**Goal:** Replace all ResNet residual units in the vision backbone with AttnRes-based image blocks while preserving the current IMF agent interfaces and launch a Phase-2 experiment anchored on the best Phase-1 horizon setting.
**Architecture:** Keep the current multi-camera encoder shell and per-camera output contract, but introduce a new ResNet-like 2D AttnRes backbone that preserves stage-wise downsampling and final SpatialSoftmax conditioning. Wire it into the existing `ResNetDiffusionBackbone` via an opt-in mode and keep the agent/head/data interfaces unchanged.
**Tech Stack:** PyTorch, Hydra/OmegaConf, existing IMF AttnRes transformer components, pytest.
---
### Task 1: Add failing tests for the new full-AttnRes visual backbone
**Files:**
- Create: `tests/test_attnres_resnet2d_backbone.py`
- Update: `tests/test_imf_vla_agent.py`
- [ ] **Step 1: Write a failing backbone shape test**
- [ ] **Step 2: Run it to confirm the new backbone/config does not exist yet**
- [ ] **Step 3: Add a failing IMF agent wiring test for unchanged cond_dim=208**
- [ ] **Step 4: Run the targeted tests and capture the failure**
### Task 2: Implement a ResNet-like 2D AttnRes backbone
**Files:**
- Create: `roboimi/vla/models/backbones/attnres_resnet2d.py`
- Modify: `roboimi/vla/models/backbones/resnet_diffusion.py`
- [ ] **Step 1: Add minimal 2D tokenization helpers and positional encoding / bias handling**
- [ ] **Step 2: Implement `AttnResImageBlock2D` for feature maps**
- [ ] **Step 3: Implement `AttnResResNetLikeBackbone2D` with stage-wise downsampling**
- [ ] **Step 4: Wire `_SingleRgbEncoder` to choose between original ResNet trunk and the new full-AttnRes trunk**
- [ ] **Step 5: Run the new backbone tests**
### Task 3: Expose config switches and agent wiring
**Files:**
- Modify: `roboimi/vla/conf/backbone/resnet_diffusion.yaml`
- Modify: `roboimi/vla/conf/agent/resnet_imf_attnres.yaml`
- [ ] **Step 1: Add a backbone mode/config flag for the full-AttnRes vision trunk**
- [ ] **Step 2: Add defaults for attnres image depth/heads/etc. if needed**
- [ ] **Step 3: Add a Phase-2 launch override path that enables the new visual trunk**
- [ ] **Step 4: Run agent wiring tests again**
### Task 4: Smoke-verify training path
**Files:**
- Reuse existing training scripts and configs
- [ ] **Step 1: Run a short CPU or tiny-step smoke instantiation / `compute_loss` test**
- [ ] **Step 2: If needed, run a very short training smoke launch**
- [ ] **Step 3: Verify no cond-dim or rollout-loading regressions**
### Task 5: Launch the Phase-2 experiment
**Files:**
- Update experiment tracking under `experiment_suites/`
- [ ] **Step 1: Use Phase-1 best setting (`pred_horizon=16`, `num_action_steps=8`)**
- [ ] **Step 2: Launch baseline reference or reuse existing result**
- [ ] **Step 3: Launch full-AttnRes vision experiment**
- [ ] **Step 4: Track rollout metrics and compare max avg_reward**


@@ -0,0 +1,81 @@
# Phase-2 Full-AttnRes Vision Design
## Goal
In the current roboimi IMF policy, replace every residual unit in the vision backbone that was previously provided by a ResNet BasicBlock/Bottleneck with an AttnRes-style unit, while keeping the existing agent / cond / rollout / training-script interfaces unchanged as far as possible.
## User requirement interpretation
We follow the strictest interpretation of the requirement:
- It is not "append an extra AttnRes module after the ResNet"
- Nor is it "add AttnRes mixing to only a few stages"
- Instead: every place in the visual trunk that previously relied on a ResNet residual block is replaced with a block driven by the AttnRes residual operator
- The final output still matches the per-camera feature interface of the existing `ResNetDiffusionBackbone`, so that `SpatialSoftmax -> Linear -> ReLU`, multi-camera concatenation, state concat, and the IMF head conditioning input can all be reused
## Recommended design
### Option A (recommended)
Keep ResNet's macro stem/stage structure and its channel/stride plan, but replace every BasicBlock/Bottleneck inside each stage with a new `AttnResImageBlock2D`:
- The input is still a `(B, C, H, W)` feature map
- Inside the block, the spatial dims are first flattened into a token sequence `(B, H*W, C)`
- A 2D positional encoding / learnable positional bias plus AttnRes self-attention and an AttnRes FFN perform the block transform
- The result is reshaped back to `(B, C, H, W)`
- Downsampling between stages is still done by an explicit stride/downsample path
Pros:
- Closest to the requirement that "all residuals in the ResNet are replaced by AttnRes"
- Preserves the current visual output interface and cond_dim, so the agent/head/data pipeline needs no changes
- The existing multi-camera encoder framework can still be reused
Cons:
- Requires writing a new 2D AttnRes image block instead of directly reusing the 1D IMF head block
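The flatten → attend → reshape roundtrip behind Option A can be sketched as follows. This is a hypothetical minimal version: the AttnRes operator is replaced by an identity stand-in (`attend`), so only the window-partition bookkeeping is shown, mirroring the shapes described above.

```python
import torch
import torch.nn.functional as F


def windowed_block(x: torch.Tensor, ws: int, attend=lambda t: t) -> torch.Tensor:
    """Partition a (B, C, H, W) map into ws*ws windows, run `attend` on each
    (B*nWin, ws*ws, C) token sequence, then stitch the map back together."""
    b, c, h, w = x.shape
    pad_h, pad_w = (ws - h % ws) % ws, (ws - w % ws) % ws
    x = F.pad(x, (0, pad_w, 0, pad_h))            # pad so H and W divide evenly
    ph, pw = x.shape[-2:]
    nh, nw = ph // ws, pw // ws
    tokens = (x.permute(0, 2, 3, 1).contiguous()  # (B, pH, pW, C), channels-last
               .view(b, nh, ws, nw, ws, c)
               .permute(0, 1, 3, 2, 4, 5).contiguous()
               .view(b * nh * nw, ws * ws, c))    # one token sequence per window
    tokens = attend(tokens)                       # AttnRes self-attn + FFN would go here
    out = (tokens.view(b, nh, nw, ws, ws, c)      # reverse the partition
                 .permute(0, 1, 3, 2, 4, 5).contiguous()
                 .view(b, ph, pw, c)
                 .permute(0, 3, 1, 2).contiguous())
    return out[:, :, :h, :w]                      # drop the padding again


x = torch.randn(2, 8, 10, 13)                     # non-multiple spatial dims
assert torch.equal(windowed_block(x, 7), x)       # identity attend -> exact roundtrip
```

With an identity `attend`, the roundtrip is exact, which is the invariant the real block relies on: only the per-window attention changes values, never the spatial layout.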
### Option B
Remove the ResNet stages entirely and switch to patchify + a ViT/AttnRes image transformer, followed by SpatialSoftmax/MLP.
Pros: conceptually more uniform.
Cons: this no longer "replaces the residuals inside ResNet" but swaps the backbone outright, which does not fully match the user requirement.
### Option C
Keep the existing ResNet blocks and only add AttnRes mixing around them.
Not recommended: it fails the requirement that all residuals be replaced by AttnRes.
## Concrete architecture choice
Adopt Option A:
1. Keep the stem (conv / bn-or-gn / relu / maxpool) and the stage boundaries
2. Add `AttnResImageBlock2D`
3. Add `AttnResResNetLikeBackbone2D`, which stacks the stages/blocks
4. Add an optional backbone mode to `ResNetDiffusionBackbone`, e.g.:
- `vision_backbone_mode: resnet`
- `vision_backbone_mode: attnres_resnet`
5. Add a Phase-2 variant of the `resnet_imf_attnres` agent config that enables `attnres_resnet` by default
6. Still keep:
- per-camera output `64`
- total multi-camera visual output `3 * 64`
- `cond_dim = 208` after concatenation with the state
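The cond-dim bookkeeping in item 6 can be sanity-checked with plain arithmetic (note the 16-dim state is inferred here from the 208 total, not stated explicitly in the design):

```python
num_cameras = 3
per_camera_feat = 64                        # SpatialSoftmax -> Linear output per camera
vision_dim = num_cameras * per_camera_feat  # 3 * 64 = 192
state_dim = 208 - vision_dim                # implies a 16-dim proprioceptive state
cond_dim = vision_dim + state_dim

assert (vision_dim, state_dim, cond_dim) == (192, 16, 208)
```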
## Files likely to change
- `roboimi/vla/models/backbones/resnet_diffusion.py`
- `roboimi/vla/conf/backbone/resnet_diffusion.yaml`
- `roboimi/vla/conf/agent/resnet_imf_attnres.yaml`
- new: `roboimi/vla/models/backbones/attnres_resnet2d.py`
- tests:
- new: `tests/test_attnres_resnet2d_backbone.py`
- update/add wiring test for agent cond dims
## Test plan
1. New backbone instantiates and forwards `(B,T,C,H,W)` multi-camera input
2. Output shape unchanged vs current backbone
3. `output_dim == 64`
4. 3-camera cond path still yields `208`
5. Phase-2 config instantiates full IMF agent successfully
6. One short CPU smoke forward for `compute_loss`
## Phase-2 experiment plan
Fix the Phase-1 best combination:
- `pred_horizon=16`
- `num_action_steps=8`
Compare:
1. baseline: current IMF head-only AttnRes + the original ResNet vision backbone
2. phase2: IMF head AttnRes + the fully AttnRes-replaced vision backbone
Training hyperparameters stay identical to the Phase-1 best setting; run one 50k-step comparison first.


@@ -1,7 +1,7 @@
-rank,run_id,status,pred_horizon,num_action_steps,best_rollout_avg_reward,best_step,final_step,final_loss,host,run_dir
+rank,run_id,status,pred_horizon,num_action_steps,best_rollout_avg_reward,best_step,final_step,final_loss,host,run_dir,latest_step
-1,ph16_ex8,finished,16,8,610.8,21874,50000,0.0034315965604037046,100.73.14.65,/home/droid/roboimi_suite_20260404/runs/imf-p1-ph16-ex08-emb384-l12-ms50k-5880g1-20260404-131223
+1,ph16_ex8,running,16,8,610.8,21874,50000,0.0034315965604037046,100.73.14.65,/home/droid/roboimi_suite_20260404/runs/imf-p1-ph16-ex08-emb384-l12-ms50k-5880g1-20260404-131223,50000
-2,ph16_ex16,finished,16,16,561.2,48124,50000,0.004544622730463743,100.119.99.14,/home/droid/roboimi_suite_20260404/runs/imf-p1-ph16-ex16-emb384-l12-ms50k-l20g0-20260404-131223
+2,ph16_ex16,running,16,16,561.2,48124,50000,0.004544622730463743,100.119.99.14,/home/droid/roboimi_suite_20260404/runs/imf-p1-ph16-ex16-emb384-l12-ms50k-l20g0-20260404-131223,50000
-3,ph32_ex32,finished,32,32,513.2,43749,50000,0.003953303210437298,local,/home/droid/project/roboimi/.worktrees/feat-imf-attnres-policy/runs/imf-p1-ph32-ex32-emb384-l12-ms50k-5090-20260404-131223
+3,ph32_ex32,finished,32,32,513.2,43749,50000,0.003953303210437298,local,/home/droid/project/roboimi/.worktrees/feat-imf-attnres-policy/runs/imf-p1-ph32-ex32-emb384-l12-ms50k-5090-20260404-131223,49900
-4,ph8_ex8,finished,8,8,415.6,48124,50000,0.007008877582848072,100.73.14.65,/home/droid/roboimi_suite_20260404/runs/imf-p1-ph08-ex08-emb384-l12-ms50k-5880g0-20260404-131223
+4,ph8_ex8,running,8,8,415.6,48124,50000,0.007008877582848072,100.73.14.65,/home/droid/roboimi_suite_20260404/runs/imf-p1-ph08-ex08-emb384-l12-ms50k-5880g0-20260404-131223,50000
-5,ph32_ex8,finished,32,8,361.6,43749,50000,0.004788532387465239,100.119.99.14,/home/droid/roboimi_suite_20260404/runs/imf-p1-ph32-ex08-emb384-l12-ms50k-l20g1-20260404-131223
+5,ph32_ex8,running,32,8,361.6,43749,50000,0.004788532387465239,100.119.99.14,/home/droid/roboimi_suite_20260404/runs/imf-p1-ph32-ex08-emb384-l12-ms50k-l20g1-20260404-131223,50000
-6,ph32_ex16,finished,32,16,239.6,48124,50000,0.0038348555099219084,100.119.99.14,/home/droid/roboimi_suite_20260404/runs/imf-p1-ph32-ex16-emb384-l12-ms50k-l20g2-20260404-131223
+6,ph32_ex16,running,32,16,239.6,48124,50000,0.0038348555099219084,100.119.99.14,/home/droid/roboimi_suite_20260404/runs/imf-p1-ph32-ex16-emb384-l12-ms50k-l20g2-20260404-131223,50000


@@ -1,6 +1,6 @@
 {
   "suite_name": "2026-04-04-imf-horizon-grid",
-  "updated_at": "2026-04-04 23:46:01",
+  "updated_at": "2026-04-05 00:07:39",
   "phase": "phase1_completed",
   "provisioning": {
     "100.119.99.14": {
@@ -17,7 +17,7 @@
   },
   "runs": {
     "ph8_ex8": {
-      "status": "finished",
+      "status": "running",
       "host": "100.73.14.65",
       "gpu": 0,
       "run_name": "imf-p1-ph08-ex08-emb384-l12-ms50k-5880g0-20260404-131223",
@@ -30,15 +30,15 @@
       "pid": 938714,
       "launch_log": "experiment_suite_launch_logs/imf-p1-ph08-ex08-emb384-l12-ms50k-5880g0-20260404-131223.restartfix-20260404-143827.log",
       "latest_step": 50000,
-      "latest_log_sync": "2026-04-04 23:42:34",
+      "latest_log_sync": "2026-04-05 00:07:39",
       "swanlab_url": "https://swanlab.cn/@game-loader/roboimi-vla/runs/i5syc57b6zq7rbkrtqy7b",
-      "process_running": false,
+      "process_running": true,
       "best_step": 48124,
       "best_rollout_avg_reward": 415.6,
       "final_loss": 0.007008877582848072
     },
     "ph16_ex8": {
-      "status": "finished",
+      "status": "running",
       "host": "100.73.14.65",
       "gpu": 1,
       "run_name": "imf-p1-ph16-ex08-emb384-l12-ms50k-5880g1-20260404-131223",
@@ -51,15 +51,15 @@
       "pid": 938717,
       "launch_log": "experiment_suite_launch_logs/imf-p1-ph16-ex08-emb384-l12-ms50k-5880g1-20260404-131223.restartfix-20260404-143827.log",
       "latest_step": 50000,
-      "latest_log_sync": "2026-04-04 23:42:34",
+      "latest_log_sync": "2026-04-05 00:07:39",
       "swanlab_url": "https://swanlab.cn/@game-loader/roboimi-vla/runs/4rusbrpfxmw4ffii1ul5w",
-      "process_running": false,
+      "process_running": true,
       "best_step": 21874,
       "best_rollout_avg_reward": 610.8,
       "final_loss": 0.0034315965604037046
     },
     "ph16_ex16": {
-      "status": "finished",
+      "status": "running",
       "host": "100.119.99.14",
       "gpu": 0,
       "run_name": "imf-p1-ph16-ex16-emb384-l12-ms50k-l20g0-20260404-131223",
@@ -71,16 +71,16 @@
       "num_action_steps": 16,
       "pid": 90169,
       "launch_log": "experiment_suite_launch_logs/imf-p1-ph16-ex16-emb384-l12-ms50k-l20g0-20260404-131223.restartfix-20260404-143827.log",
-      "latest_log_sync": "2026-04-04 23:42:34",
+      "latest_log_sync": "2026-04-05 00:07:39",
       "latest_step": 50000,
       "swanlab_url": "https://swanlab.cn/@game-loader/roboimi-vla/runs/wwm232k6190gexnze8mg6",
-      "process_running": false,
+      "process_running": true,
       "best_step": 48124,
       "best_rollout_avg_reward": 561.2,
       "final_loss": 0.004544622730463743
     },
     "ph32_ex8": {
-      "status": "finished",
+      "status": "running",
       "host": "100.119.99.14",
       "gpu": 1,
       "run_name": "imf-p1-ph32-ex08-emb384-l12-ms50k-l20g1-20260404-131223",
@@ -92,16 +92,16 @@
       "num_action_steps": 8,
       "pid": 90173,
       "launch_log": "experiment_suite_launch_logs/imf-p1-ph32-ex08-emb384-l12-ms50k-l20g1-20260404-131223.restartfix-20260404-143827.log",
-      "latest_log_sync": "2026-04-04 23:42:34",
+      "latest_log_sync": "2026-04-05 00:07:39",
       "latest_step": 50000,
       "swanlab_url": "https://swanlab.cn/@game-loader/roboimi-vla/runs/o5y2xjb2rsb3lmfcuhy4p",
-      "process_running": false,
+      "process_running": true,
       "best_step": 43749,
       "best_rollout_avg_reward": 361.6,
       "final_loss": 0.004788532387465239
     },
     "ph32_ex16": {
-      "status": "finished",
+      "status": "running",
       "host": "100.119.99.14",
       "gpu": 2,
       "run_name": "imf-p1-ph32-ex16-emb384-l12-ms50k-l20g2-20260404-131223",
@@ -113,10 +113,10 @@
       "num_action_steps": 16,
       "pid": 90175,
       "launch_log": "experiment_suite_launch_logs/imf-p1-ph32-ex16-emb384-l12-ms50k-l20g2-20260404-131223.restartfix-20260404-143827.log",
-      "latest_log_sync": "2026-04-04 23:42:34",
+      "latest_log_sync": "2026-04-05 00:07:39",
       "latest_step": 50000,
       "swanlab_url": "https://swanlab.cn/@game-loader/roboimi-vla/runs/54cjpgba9eqsopdm0l8d3",
-      "process_running": false,
+      "process_running": true,
       "best_step": 48124,
       "best_rollout_avg_reward": 239.6,
       "final_loss": 0.0038348555099219084
@@ -134,8 +134,8 @@
       "num_action_steps": 32,
       "pid": 1437836,
       "launch_log": "experiment_suites/2026-04-04-imf-horizon-grid/launch_logs/imf-p1-ph32-ex32-emb384-l12-ms50k-5090-20260404-131223.launch.log",
-      "latest_step": 50000,
-      "latest_log_sync": "2026-04-04 23:42:34",
+      "latest_step": 49900,
+      "latest_log_sync": "2026-04-05 00:07:39",
       "swanlab_url": "https://swanlab.cn/@game-loader/roboimi-vla/runs/ajs2m218jd260hawhy5ns",
       "process_running": false,
       "latest_rollout_avg_reward": 513.2,


@@ -5,6 +5,7 @@ _target_: roboimi.vla.models.backbones.resnet_diffusion.ResNetDiffusionBackbone
 # ====================
 vision_backbone: "resnet18"  # torchvision model name: resnet18, resnet34, resnet50
 pretrained_backbone_weights: "IMAGENET1K_V1"  # ImageNet pretrained weights (torchvision >= 0.13)
+vision_backbone_mode: "resnet"  # resnet | attnres_resnet
 
 # ====================
 # Freeze settings
@@ -31,3 +32,17 @@ spatial_softmax_num_keypoints: 32  # number of SpatialSoftmax keypoints
 # true: separate encoders (one ResNet per camera; more parameters, more capacity)
 use_separate_rgb_encoder_per_camera: true
 num_cameras: 3  # number of cameras
+
+# ====================
+# Full-AttnRes vision trunk (active when vision_backbone_mode=attnres_resnet)
+# ====================
+attnres_stem_dim: 64
+attnres_stage_dims: [64, 128, 256, 512]
+attnres_stage_depths: [2, 2, 2, 2]
+attnres_stage_heads: [4, 4, 8, 8]
+attnres_stage_kv_heads: [1, 1, 1, 1]
+attnres_stage_window_sizes: [7, 7, 7, 7]
+attnres_dropout: 0.0
+attnres_ffn_mult: 2.667
+attnres_eps: 1.0e-06
+attnres_rope_theta: 10000.0
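For the Phase-2 launch override mentioned in the plan (Task 3), a variant agent config could look like the following sketch. The file name `resnet_imf_attnres_phase2.yaml` and the `defaults` wiring are assumptions; the keys mirror the flags above.

```yaml
# hypothetical: roboimi/vla/conf/agent/resnet_imf_attnres_phase2.yaml
defaults:
  - resnet_imf_attnres   # inherit the Phase-1 agent config (assumed layout)
  - _self_

vision_backbone:
  vision_backbone_mode: attnres_resnet
  pretrained_backbone_weights: null  # custom trunk has no torchvision weights
  freeze_backbone: false             # nothing pretrained to protect
```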


@@ -0,0 +1,228 @@
from __future__ import annotations

from typing import Sequence

import torch
import torch.nn as nn
import torch.nn.functional as F

from roboimi.vla.models.heads.attnres_transformer_components import AttnResTransformerBackbone


def _make_norm2d(num_channels: int, use_group_norm: bool) -> nn.Module:
    if use_group_norm:
        num_groups = max(1, num_channels // 16)
        return nn.GroupNorm(num_groups=num_groups, num_channels=num_channels)
    return nn.BatchNorm2d(num_channels)


class _ConvNormAct(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        *,
        kernel_size: int,
        stride: int,
        padding: int,
        use_group_norm: bool,
    ) -> None:
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                bias=False,
            ),
            _make_norm2d(out_channels, use_group_norm),
            nn.SiLU(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.block(x)


class AttnResImageBlock2D(nn.Module):
    def __init__(
        self,
        dim: int,
        *,
        window_size: int,
        n_heads: int,
        n_kv_heads: int,
        dropout: float,
        ffn_mult: float,
        eps: float,
        rope_theta: float,
    ) -> None:
        super().__init__()
        self.window_size = int(window_size)
        self.block = AttnResTransformerBackbone(
            d_model=dim,
            n_blocks=1,
            n_heads=n_heads,
            n_kv_heads=n_kv_heads,
            max_seq_len=self.window_size * self.window_size,
            dropout=dropout,
            ffn_mult=ffn_mult,
            eps=eps,
            rope_theta=rope_theta,
            causal_attn=False,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        bsz, channels, height, width = x.shape
        ws = self.window_size
        # Pad so the spatial dims divide evenly into ws x ws windows.
        pad_h = (ws - height % ws) % ws
        pad_w = (ws - width % ws) % ws
        if pad_h or pad_w:
            x = F.pad(x, (0, pad_w, 0, pad_h))
        padded_height, padded_width = x.shape[-2:]
        num_h = padded_height // ws
        num_w = padded_width // ws
        # Partition into non-overlapping windows: one token sequence per window.
        windows = (
            x.permute(0, 2, 3, 1)
            .contiguous()
            .view(bsz, num_h, ws, num_w, ws, channels)
            .permute(0, 1, 3, 2, 4, 5)
            .contiguous()
            .view(bsz * num_h * num_w, ws * ws, channels)
        )
        windows = self.block(windows)
        # Reverse the partition and crop away the padding.
        x = (
            windows.view(bsz, num_h, num_w, ws, ws, channels)
            .permute(0, 1, 3, 2, 4, 5)
            .contiguous()
            .view(bsz, padded_height, padded_width, channels)
            .permute(0, 3, 1, 2)
            .contiguous()
        )
        return x[:, :, :height, :width]


class _AttnResStage2D(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        *,
        depth: int,
        downsample_stride: int,
        use_group_norm: bool,
        window_size: int,
        n_heads: int,
        n_kv_heads: int,
        dropout: float,
        ffn_mult: float,
        eps: float,
        rope_theta: float,
    ) -> None:
        super().__init__()
        self.downsample = None
        if in_channels != out_channels or downsample_stride != 1:
            kernel_size = 1 if downsample_stride == 1 else 3
            padding = 0 if downsample_stride == 1 else 1
            self.downsample = _ConvNormAct(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                stride=downsample_stride,
                padding=padding,
                use_group_norm=use_group_norm,
            )
        self.blocks = nn.ModuleList(
            [
                AttnResImageBlock2D(
                    out_channels,
                    window_size=window_size,
                    n_heads=n_heads,
                    n_kv_heads=n_kv_heads,
                    dropout=dropout,
                    ffn_mult=ffn_mult,
                    eps=eps,
                    rope_theta=rope_theta,
                )
                for _ in range(depth)
            ]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.downsample is not None:
            x = self.downsample(x)
        for block in self.blocks:
            x = block(x)
        return x


class AttnResResNetLikeBackbone2D(nn.Module):
    def __init__(
        self,
        *,
        input_channels: int = 3,
        stem_dim: int = 64,
        stage_dims: Sequence[int] = (64, 128, 256, 512),
        stage_depths: Sequence[int] = (2, 2, 2, 2),
        stage_heads: Sequence[int] = (4, 4, 8, 8),
        stage_kv_heads: Sequence[int] = (1, 1, 1, 1),
        stage_window_sizes: Sequence[int] = (7, 7, 7, 7),
        use_group_norm: bool = True,
        dropout: float = 0.0,
        ffn_mult: float = 2.667,
        eps: float = 1e-6,
        rope_theta: float = 10000.0,
    ) -> None:
        super().__init__()
        lengths = {
            len(stage_dims),
            len(stage_depths),
            len(stage_heads),
            len(stage_kv_heads),
            len(stage_window_sizes),
        }
        if len(lengths) != 1:
            raise ValueError('stage_dims/depths/heads/kv_heads/window_sizes must all have the same length')
        self.stem = nn.Sequential(
            nn.Conv2d(input_channels, stem_dim, kernel_size=7, stride=2, padding=3, bias=False),
            _make_norm2d(stem_dim, use_group_norm),
            nn.SiLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )
        in_channels = stem_dim
        stages = []
        for stage_idx, (out_channels, depth, n_heads, n_kv_heads, window_size) in enumerate(
            zip(stage_dims, stage_depths, stage_heads, stage_kv_heads, stage_window_sizes)
        ):
            stages.append(
                _AttnResStage2D(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    depth=int(depth),
                    downsample_stride=1 if stage_idx == 0 else 2,
                    use_group_norm=use_group_norm,
                    window_size=int(window_size),
                    n_heads=int(n_heads),
                    n_kv_heads=int(n_kv_heads),
                    dropout=float(dropout),
                    ffn_mult=float(ffn_mult),
                    eps=float(eps),
                    rope_theta=float(rope_theta),
                )
            )
            in_channels = out_channels
        self.stages = nn.ModuleList(stages)
        self.output_channels = in_channels

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.stem(x)
        for stage in self.stages:
            x = stage(x)
        return x


@@ -6,6 +6,8 @@ import torchvision
 import numpy as np
 from typing import Callable, Optional, Tuple, Union
 
+from .attnres_resnet2d import AttnResResNetLikeBackbone2D
+
 def _replace_submodules(
     root_module: nn.Module, predicate: Callable[[nn.Module], bool], func: Callable[[nn.Module], nn.Module]
 ) -> nn.Module:
@@ -103,6 +105,17 @@ class _SingleRgbEncoder(nn.Module):
         use_group_norm: bool,
         spatial_softmax_num_keypoints: int,
         freeze_backbone: bool = True,  # new: whether to freeze the backbone
+        vision_backbone_mode: str = "resnet",
+        attnres_stem_dim: int = 64,
+        attnres_stage_dims: Optional[Tuple[int, ...]] = None,
+        attnres_stage_depths: Optional[Tuple[int, ...]] = None,
+        attnres_stage_heads: Optional[Tuple[int, ...]] = None,
+        attnres_stage_kv_heads: Optional[Tuple[int, ...]] = None,
+        attnres_stage_window_sizes: Optional[Tuple[int, ...]] = None,
+        attnres_dropout: float = 0.0,
+        attnres_ffn_mult: float = 2.667,
+        attnres_eps: float = 1e-6,
+        attnres_rope_theta: float = 10000.0,
     ):
         super().__init__()
@@ -119,6 +132,7 @@ class _SingleRgbEncoder(nn.Module):
             self.do_crop = False
             crop_shape = input_shape[1:]
 
+        if vision_backbone_mode == "resnet":
             # set up the backbone
             backbone_model = getattr(torchvision.models, vision_backbone)(
                 weights=pretrained_backbone_weights
@@ -131,8 +145,28 @@ class _SingleRgbEncoder(nn.Module):
             self.backbone = _replace_submodules(
                 root_module=self.backbone,
                 predicate=lambda x: isinstance(x, nn.BatchNorm2d),
-                func=lambda x: nn.GroupNorm(num_groups=x.num_features // 16, num_channels=x.num_features),
+                func=lambda x: nn.GroupNorm(
+                    num_groups=max(1, x.num_features // 16),
+                    num_channels=x.num_features,
+                ),
             )
+        elif vision_backbone_mode == "attnres_resnet":
+            self.backbone = AttnResResNetLikeBackbone2D(
+                input_channels=input_shape[0],
+                stem_dim=attnres_stem_dim,
+                stage_dims=tuple(attnres_stage_dims or (64, 128, 256, 512)),
+                stage_depths=tuple(attnres_stage_depths or (2, 2, 2, 2)),
+                stage_heads=tuple(attnres_stage_heads or (4, 4, 8, 8)),
+                stage_kv_heads=tuple(attnres_stage_kv_heads or (1, 1, 1, 1)),
+                stage_window_sizes=tuple(attnres_stage_window_sizes or (7, 7, 7, 7)),
+                use_group_norm=use_group_norm,
+                dropout=attnres_dropout,
+                ffn_mult=attnres_ffn_mult,
+                eps=attnres_eps,
+                rope_theta=attnres_rope_theta,
+            )
+        else:
+            raise ValueError(f"unsupported vision_backbone_mode: {vision_backbone_mode}")
 
         # freeze backbone parameters (optional)
         if freeze_backbone:
@@ -180,6 +214,17 @@ class ResNetDiffusionBackbone(VLABackbone):
         num_cameras: int = 1,  # new: number of cameras (only used in separate-encoder mode)
         camera_names: Optional[Tuple[str, ...]] = None,  # explicit camera order
         freeze_backbone: bool = True,  # new: whether to freeze the ResNet backbone (True recommended)
+        vision_backbone_mode: str = "resnet",
+        attnres_stem_dim: int = 64,
+        attnres_stage_dims: Optional[Tuple[int, ...]] = None,
+        attnres_stage_depths: Optional[Tuple[int, ...]] = None,
+        attnres_stage_heads: Optional[Tuple[int, ...]] = None,
+        attnres_stage_kv_heads: Optional[Tuple[int, ...]] = None,
+        attnres_stage_window_sizes: Optional[Tuple[int, ...]] = None,
+        attnres_dropout: float = 0.0,
+        attnres_ffn_mult: float = 2.667,
+        attnres_eps: float = 1e-6,
+        attnres_rope_theta: float = 10000.0,
     ):
         super().__init__()
@@ -203,6 +248,17 @@ class ResNetDiffusionBackbone(VLABackbone):
                     use_group_norm=use_group_norm,
                     spatial_softmax_num_keypoints=spatial_softmax_num_keypoints,
                     freeze_backbone=freeze_backbone,
+                    vision_backbone_mode=vision_backbone_mode,
+                    attnres_stem_dim=attnres_stem_dim,
+                    attnres_stage_dims=attnres_stage_dims,
+                    attnres_stage_depths=attnres_stage_depths,
+                    attnres_stage_heads=attnres_stage_heads,
+                    attnres_stage_kv_heads=attnres_stage_kv_heads,
+                    attnres_stage_window_sizes=attnres_stage_window_sizes,
+                    attnres_dropout=attnres_dropout,
+                    attnres_ffn_mult=attnres_ffn_mult,
+                    attnres_eps=attnres_eps,
+                    attnres_rope_theta=attnres_rope_theta,
                 )
                 for _ in range(num_cameras)
             ]
@@ -220,6 +276,17 @@ class ResNetDiffusionBackbone(VLABackbone):
                 use_group_norm=use_group_norm,
                 spatial_softmax_num_keypoints=spatial_softmax_num_keypoints,
                 freeze_backbone=freeze_backbone,
+                vision_backbone_mode=vision_backbone_mode,
+                attnres_stem_dim=attnres_stem_dim,
+                attnres_stage_dims=attnres_stage_dims,
+                attnres_stage_depths=attnres_stage_depths,
+                attnres_stage_heads=attnres_stage_heads,
+                attnres_stage_kv_heads=attnres_stage_kv_heads,
+                attnres_stage_window_sizes=attnres_stage_window_sizes,
+                attnres_dropout=attnres_dropout,
+                attnres_ffn_mult=attnres_ffn_mult,
+                attnres_eps=attnres_eps,
+                attnres_rope_theta=attnres_rope_theta,
             )
             self.feature_dim = self.rgb_encoder.feature_dim


@@ -0,0 +1,26 @@
import unittest

import torch


class AttnResResNet2DBackboneTest(unittest.TestCase):
    def test_backbone_preserves_resnet_like_stage_contract(self):
        from roboimi.vla.models.backbones.attnres_resnet2d import AttnResResNetLikeBackbone2D

        backbone = AttnResResNetLikeBackbone2D(
            input_channels=3,
            stem_dim=16,
            stage_dims=(16, 32, 64, 128),
            stage_depths=(1, 1, 1, 1),
            stage_heads=(2, 4, 4, 8),
            stage_kv_heads=(1, 1, 1, 1),
            stage_window_sizes=(7, 7, 7, 7),
            dropout=0.0,
        )
        x = torch.randn(2, 3, 56, 56)
        y = backbone(x)
        self.assertEqual(y.shape, (2, 128, 2, 2))


if __name__ == '__main__':
    unittest.main()


@@ -422,6 +422,32 @@ class IMFVLAAgentTest(unittest.TestCase):
         self.assertEqual(agent.noise_pred_net.constructor_kwargs['cond_dim'], agent.per_step_cond_dim)
         self.assertEqual(agent.noise_pred_net.constructor_kwargs['backbone_type'], 'attnres_full')
 
+    def test_hydra_config_instantiates_resnet_imf_attnres_with_full_attnres_vision_backbone(self):
+        cfg = _compose_cfg(
+            overrides=[
+                'agent=resnet_imf_attnres',
+                'agent.vision_backbone.vision_backbone_mode=attnres_resnet',
+                'agent.vision_backbone.pretrained_backbone_weights=null',
+                'agent.vision_backbone.input_shape=[3,56,56]',
+                'agent.vision_backbone.freeze_backbone=false',
+                'agent.vision_backbone.attnres_stem_dim=16',
+                'agent.vision_backbone.attnres_stage_dims=[16,32,64,128]',
+                'agent.vision_backbone.attnres_stage_depths=[1,1,1,1]',
+                'agent.vision_backbone.attnres_stage_heads=[2,4,4,8]',
+                'agent.vision_backbone.attnres_stage_kv_heads=[1,1,1,1]',
+                'agent.vision_backbone.attnres_stage_window_sizes=[7,7,7,7]',
+                'agent.head.n_layer=1',
+                'agent.head.n_emb=16',
+            ]
+        )
+        with _stub_optional_modules(include_imf_head=True):
+            agent = instantiate(cfg.agent)
+        self.assertEqual(agent.vision_encoder.output_dim, 64)
+        self.assertEqual(agent.per_step_cond_dim, 64 * agent.num_cams + agent.obs_dim)
+        self.assertEqual(agent.noise_pred_net.constructor_kwargs['cond_dim'], agent.per_step_cond_dim)
+
 
 if __name__ == '__main__':
     unittest.main()