feat: add full attnres vision backbone
@@ -0,0 +1,64 @@
# Phase-2 Full-AttnRes Vision Implementation Plan

> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking.

**Goal:** Replace all ResNet residual units in the vision backbone with AttnRes-based image blocks while preserving the current IMF agent interfaces, and launch a Phase-2 experiment anchored on the best Phase-1 horizon setting.

**Architecture:** Keep the current multi-camera encoder shell and per-camera output contract, but introduce a new ResNet-like 2D AttnRes backbone that preserves stage-wise downsampling and final SpatialSoftmax conditioning. Wire it into the existing `ResNetDiffusionBackbone` via an opt-in mode and keep the agent/head/data interfaces unchanged.

**Tech Stack:** PyTorch, Hydra/OmegaConf, existing IMF AttnRes transformer components, pytest.

---
### Task 1: Add failing tests for the new full-AttnRes visual backbone

**Files:**

- Create: `tests/test_attnres_resnet2d_backbone.py`
- Update: `tests/test_imf_vla_agent.py`

- [ ] **Step 1: Write a failing backbone shape test**
- [ ] **Step 2: Run it to confirm the new backbone/config does not exist yet**
- [ ] **Step 3: Add a failing IMF agent wiring test for unchanged cond_dim=208**
- [ ] **Step 4: Run the targeted tests and capture the failure**
### Task 2: Implement a ResNet-like 2D AttnRes backbone

**Files:**

- Create: `roboimi/vla/models/backbones/attnres_resnet2d.py`
- Modify: `roboimi/vla/models/backbones/resnet_diffusion.py`

- [ ] **Step 1: Add minimal 2D tokenization helpers and positional encoding / bias handling**
- [ ] **Step 2: Implement `AttnResImageBlock2D` for feature maps**
- [ ] **Step 3: Implement `AttnResResNetLikeBackbone2D` with stage-wise downsampling**
- [ ] **Step 4: Wire `_SingleRgbEncoder` to choose between the original ResNet trunk and the new full-AttnRes trunk**
- [ ] **Step 5: Run the new backbone tests**
### Task 3: Expose config switches and agent wiring

**Files:**

- Modify: `roboimi/vla/conf/backbone/resnet_diffusion.yaml`
- Modify: `roboimi/vla/conf/agent/resnet_imf_attnres.yaml`

- [ ] **Step 1: Add a backbone mode/config flag for the full-AttnRes vision trunk**
- [ ] **Step 2: Add defaults for attnres image depth/heads/etc. if needed**
- [ ] **Step 3: Add a Phase-2 launch override path that enables the new visual trunk**
- [ ] **Step 4: Run agent wiring tests again**
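As a usage sketch, a Phase-2 launch override enabling the new trunk might look like the fragment below. The `defaults` composition and exact group layout are assumptions; only `vision_backbone_mode` and the `attnres_*` key names come from this plan.

```yaml
# Hypothetical Phase-2 override sketch — config-group layout is assumed.
defaults:
  - resnet_diffusion

vision_backbone_mode: "attnres_resnet"  # switch away from the torchvision ResNet trunk
freeze_backbone: false                  # the AttnRes trunk is trained from scratch
attnres_stage_depths: [2, 2, 2, 2]
attnres_stage_heads: [4, 4, 8, 8]
```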
### Task 4: Smoke-verify training path

**Files:**

- Reuse existing training scripts and configs

- [ ] **Step 1: Run a short CPU or tiny-step smoke instantiation / `compute_loss` test**
- [ ] **Step 2: If needed, run a very short training smoke launch**
- [ ] **Step 3: Verify no cond-dim or rollout-loading regressions**
### Task 5: Launch the Phase-2 experiment

**Files:**

- Update experiment tracking under `experiment_suites/`

- [ ] **Step 1: Use Phase-1 best setting (`pred_horizon=16`, `num_action_steps=8`)**
- [ ] **Step 2: Launch baseline reference or reuse existing result**
- [ ] **Step 3: Launch full-AttnRes vision experiment**
- [ ] **Step 4: Track rollout metrics and compare max avg_reward**
@@ -0,0 +1,81 @@
# Phase-2 Full-AttnRes Vision Design

## Goal

In the current roboimi IMF policy, replace every residual unit in the vision backbone that is currently provided by ResNet BasicBlock/Bottleneck with AttnRes-style units, while keeping the existing agent / cond / rollout / training-script interfaces as unchanged as possible.

## User requirement interpretation

We follow the strictest reading of the requirement:

- Not "append an extra AttnRes module after the ResNet"
- Not "mix AttnRes into only a few stages"
- Instead: everywhere the vision trunk relied on a ResNet residual block, switch to a block driven by the AttnRes residual operator
- The trunk still emits the same per-camera feature interface as the current `ResNetDiffusionBackbone`, so `SpatialSoftmax -> Linear -> ReLU`, multi-camera concatenation, state concat, and the IMF head conditioning input can all be reused
## Recommended design

### Option A (recommended)

Keep ResNet's macro stem/stage structure and channel/stride plan, but replace each stage's BasicBlock/Bottleneck with a new `AttnResImageBlock2D`:

- Input is still a `(B, C, H, W)` feature map
- Inside the block, flatten the spatial dims into a token sequence `(B, H*W, C)`
- Apply 2D positional encoding / learnable positional bias + AttnRes self-attention + AttnRes FFN to transform the block
- Reshape back to `(B, C, H, W)`
- Downsampling between stages is still done by an explicit stride/downsample path

Pros:

- Closest match to "all residuals in the ResNet replaced by AttnRes"
- Preserves the current vision output interface and cond_dim, so the agent/head/data pipeline is untouched
- The existing multi-camera encoder framework can still be reused

Cons:

- Requires writing a new 2D AttnRes image block instead of directly reusing the 1D IMF head block
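The flatten → attend → reshape contract above can be sanity-checked with a small NumPy sketch of the window partition and its inverse (an identity transform stands in for the AttnRes block; shapes and the 7×7 window follow this design, the helper names are illustrative):

```python
import numpy as np

def window_partition(x, ws):
    # x: (B, C, H, W) with H, W divisible by ws -> (B*nh*nw, ws*ws, C)
    b, c, h, w = x.shape
    nh, nw = h // ws, w // ws
    t = x.transpose(0, 2, 3, 1)            # (B, H, W, C)
    t = t.reshape(b, nh, ws, nw, ws, c)
    t = t.transpose(0, 1, 3, 2, 4, 5)      # (B, nh, nw, ws, ws, C)
    return t.reshape(b * nh * nw, ws * ws, c)

def window_reverse(tokens, ws, b, c, h, w):
    nh, nw = h // ws, w // ws
    t = tokens.reshape(b, nh, nw, ws, ws, c)
    t = t.transpose(0, 1, 3, 2, 4, 5).reshape(b, h, w, c)
    return t.transpose(0, 3, 1, 2)         # back to (B, C, H, W)

x = np.random.rand(2, 8, 14, 14)
tokens = window_partition(x, 7)            # 2*2*2 windows of 49 tokens each
y = window_reverse(tokens, 7, 2, 8, 14, 14)
assert tokens.shape == (8, 49, 8)
assert np.array_equal(x, y)                # the partition is lossless
```

The round trip being exact is what lets the block slot into a residual position without disturbing the surrounding conv/stride plan.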
### Option B

Drop the ResNet stages entirely in favor of patchify + a ViT/AttnRes image transformer, followed by SpatialSoftmax/MLP.

Pros: conceptually more uniform.

Cons: this is no longer "replacing the residuals inside the ResNet" but swapping out the whole backbone, which does not fully match the user requirement.

### Option C

Keep the existing ResNet blocks and only add AttnRes mixing around them.

Not recommended: it does not satisfy "all residuals replaced by AttnRes".
## Concrete architecture choice

Adopt Option A:

1. Keep the stem (conv / bn-or-gn / relu / maxpool) and the stage boundaries
2. Add `AttnResImageBlock2D`
3. Add `AttnResResNetLikeBackbone2D`, which stacks the stages/blocks
4. Add an optional backbone mode to `ResNetDiffusionBackbone`, e.g.:
   - `vision_backbone_mode: resnet`
   - `vision_backbone_mode: attnres_resnet`
5. Add a Phase-2 variant of the `resnet_imf_attnres` agent config with `attnres_resnet` enabled by default
6. Keep the following invariants:
   - `64` output features per camera
   - `3 * 64` total visual output across cameras
   - `cond_dim = 208` after concatenating with state
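The invariants in item 6 are simple arithmetic and worth pinning down explicitly (the state dimension of 16 is inferred here from 208 − 3·64; it is not stated in this design):

```python
per_camera_dim = 64                        # SpatialSoftmax -> Linear output per camera
num_cameras = 3
state_dim = 16                             # inferred: 208 - 3 * 64
vision_dim = num_cameras * per_camera_dim  # 192 total visual features
cond_dim = vision_dim + state_dim
assert cond_dim == 208                     # unchanged IMF head conditioning width
```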
## Files likely to change

- `roboimi/vla/models/backbones/resnet_diffusion.py`
- `roboimi/vla/conf/backbone/resnet_diffusion.yaml`
- `roboimi/vla/conf/agent/resnet_imf_attnres.yaml`
- new: `roboimi/vla/models/backbones/attnres_resnet2d.py`
- tests:
  - new: `tests/test_attnres_resnet2d_backbone.py`
  - update/add wiring test for agent cond dims
## Test plan

1. New backbone instantiates and forwards `(B,T,C,H,W)` multi-camera input
2. Output shape unchanged vs current backbone
3. `output_dim == 64`
4. 3-camera cond path still yields `208`
5. Phase-2 config instantiates full IMF agent successfully
6. One short CPU smoke forward for `compute_loss`
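Item 2's shape expectations can be traced with the standard conv output formula: a stride-2 stem conv, a stride-2 maxpool, then a stride-2 3×3 transition before each stage after the first. The numbers below match the 56×56 unit test in this commit:

```python
def conv_out(size, kernel, stride, padding):
    # standard conv/pool output-size formula
    return (size + 2 * padding - kernel) // stride + 1

size = 56
size = conv_out(size, 7, 2, 3)    # stem conv    -> 28
size = conv_out(size, 3, 2, 1)    # stem maxpool -> 14
# stage 1 keeps stride 1; stages 2-4 each downsample with a stride-2 3x3 conv
for _ in range(3):
    size = conv_out(size, 3, 2, 1)  # 14 -> 7 -> 4 -> 2
assert size == 2                  # matches the expected (B, C, 2, 2) map
```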
## Phase-2 experiment plan

Fix the Phase-1 best combination:

- `pred_horizon=16`
- `num_action_steps=8`

Compare:

1. baseline: current IMF head-only AttnRes + original ResNet vision backbone
2. phase2: IMF head AttnRes + full AttnRes-replaced vision backbone

Training hyperparameters stay identical to the Phase-1 best setting; run a 50k-step comparison first.
@@ -1,7 +1,7 @@
-rank,run_id,status,pred_horizon,num_action_steps,best_rollout_avg_reward,best_step,final_step,final_loss,host,run_dir
-1,ph16_ex8,finished,16,8,610.8,21874,50000,0.0034315965604037046,100.73.14.65,/home/droid/roboimi_suite_20260404/runs/imf-p1-ph16-ex08-emb384-l12-ms50k-5880g1-20260404-131223
-2,ph16_ex16,finished,16,16,561.2,48124,50000,0.004544622730463743,100.119.99.14,/home/droid/roboimi_suite_20260404/runs/imf-p1-ph16-ex16-emb384-l12-ms50k-l20g0-20260404-131223
-3,ph32_ex32,finished,32,32,513.2,43749,50000,0.003953303210437298,local,/home/droid/project/roboimi/.worktrees/feat-imf-attnres-policy/runs/imf-p1-ph32-ex32-emb384-l12-ms50k-5090-20260404-131223
-4,ph8_ex8,finished,8,8,415.6,48124,50000,0.007008877582848072,100.73.14.65,/home/droid/roboimi_suite_20260404/runs/imf-p1-ph08-ex08-emb384-l12-ms50k-5880g0-20260404-131223
-5,ph32_ex8,finished,32,8,361.6,43749,50000,0.004788532387465239,100.119.99.14,/home/droid/roboimi_suite_20260404/runs/imf-p1-ph32-ex08-emb384-l12-ms50k-l20g1-20260404-131223
-6,ph32_ex16,finished,32,16,239.6,48124,50000,0.0038348555099219084,100.119.99.14,/home/droid/roboimi_suite_20260404/runs/imf-p1-ph32-ex16-emb384-l12-ms50k-l20g2-20260404-131223
+rank,run_id,status,pred_horizon,num_action_steps,best_rollout_avg_reward,best_step,final_step,final_loss,host,run_dir,latest_step
+1,ph16_ex8,running,16,8,610.8,21874,50000,0.0034315965604037046,100.73.14.65,/home/droid/roboimi_suite_20260404/runs/imf-p1-ph16-ex08-emb384-l12-ms50k-5880g1-20260404-131223,50000
+2,ph16_ex16,running,16,16,561.2,48124,50000,0.004544622730463743,100.119.99.14,/home/droid/roboimi_suite_20260404/runs/imf-p1-ph16-ex16-emb384-l12-ms50k-l20g0-20260404-131223,50000
+3,ph32_ex32,finished,32,32,513.2,43749,50000,0.003953303210437298,local,/home/droid/project/roboimi/.worktrees/feat-imf-attnres-policy/runs/imf-p1-ph32-ex32-emb384-l12-ms50k-5090-20260404-131223,49900
+4,ph8_ex8,running,8,8,415.6,48124,50000,0.007008877582848072,100.73.14.65,/home/droid/roboimi_suite_20260404/runs/imf-p1-ph08-ex08-emb384-l12-ms50k-5880g0-20260404-131223,50000
+5,ph32_ex8,running,32,8,361.6,43749,50000,0.004788532387465239,100.119.99.14,/home/droid/roboimi_suite_20260404/runs/imf-p1-ph32-ex08-emb384-l12-ms50k-l20g1-20260404-131223,50000
+6,ph32_ex16,running,32,16,239.6,48124,50000,0.0038348555099219084,100.119.99.14,/home/droid/roboimi_suite_20260404/runs/imf-p1-ph32-ex16-emb384-l12-ms50k-l20g2-20260404-131223,50000
@@ -1,6 +1,6 @@
 {
   "suite_name": "2026-04-04-imf-horizon-grid",
-  "updated_at": "2026-04-04 23:46:01",
+  "updated_at": "2026-04-05 00:07:39",
   "phase": "phase1_completed",
   "provisioning": {
     "100.119.99.14": {
@@ -17,7 +17,7 @@
   },
   "runs": {
     "ph8_ex8": {
-      "status": "finished",
+      "status": "running",
       "host": "100.73.14.65",
       "gpu": 0,
       "run_name": "imf-p1-ph08-ex08-emb384-l12-ms50k-5880g0-20260404-131223",
@@ -30,15 +30,15 @@
       "pid": 938714,
       "launch_log": "experiment_suite_launch_logs/imf-p1-ph08-ex08-emb384-l12-ms50k-5880g0-20260404-131223.restartfix-20260404-143827.log",
       "latest_step": 50000,
-      "latest_log_sync": "2026-04-04 23:42:34",
+      "latest_log_sync": "2026-04-05 00:07:39",
       "swanlab_url": "https://swanlab.cn/@game-loader/roboimi-vla/runs/i5syc57b6zq7rbkrtqy7b",
-      "process_running": false,
+      "process_running": true,
       "best_step": 48124,
       "best_rollout_avg_reward": 415.6,
       "final_loss": 0.007008877582848072
     },
     "ph16_ex8": {
-      "status": "finished",
+      "status": "running",
       "host": "100.73.14.65",
       "gpu": 1,
       "run_name": "imf-p1-ph16-ex08-emb384-l12-ms50k-5880g1-20260404-131223",
@@ -51,15 +51,15 @@
       "pid": 938717,
       "launch_log": "experiment_suite_launch_logs/imf-p1-ph16-ex08-emb384-l12-ms50k-5880g1-20260404-131223.restartfix-20260404-143827.log",
       "latest_step": 50000,
-      "latest_log_sync": "2026-04-04 23:42:34",
+      "latest_log_sync": "2026-04-05 00:07:39",
       "swanlab_url": "https://swanlab.cn/@game-loader/roboimi-vla/runs/4rusbrpfxmw4ffii1ul5w",
-      "process_running": false,
+      "process_running": true,
       "best_step": 21874,
       "best_rollout_avg_reward": 610.8,
       "final_loss": 0.0034315965604037046
     },
     "ph16_ex16": {
-      "status": "finished",
+      "status": "running",
       "host": "100.119.99.14",
       "gpu": 0,
       "run_name": "imf-p1-ph16-ex16-emb384-l12-ms50k-l20g0-20260404-131223",
@@ -71,16 +71,16 @@
       "num_action_steps": 16,
       "pid": 90169,
       "launch_log": "experiment_suite_launch_logs/imf-p1-ph16-ex16-emb384-l12-ms50k-l20g0-20260404-131223.restartfix-20260404-143827.log",
-      "latest_log_sync": "2026-04-04 23:42:34",
+      "latest_log_sync": "2026-04-05 00:07:39",
       "latest_step": 50000,
       "swanlab_url": "https://swanlab.cn/@game-loader/roboimi-vla/runs/wwm232k6190gexnze8mg6",
-      "process_running": false,
+      "process_running": true,
       "best_step": 48124,
       "best_rollout_avg_reward": 561.2,
       "final_loss": 0.004544622730463743
     },
     "ph32_ex8": {
-      "status": "finished",
+      "status": "running",
       "host": "100.119.99.14",
       "gpu": 1,
       "run_name": "imf-p1-ph32-ex08-emb384-l12-ms50k-l20g1-20260404-131223",
@@ -92,16 +92,16 @@
       "num_action_steps": 8,
       "pid": 90173,
       "launch_log": "experiment_suite_launch_logs/imf-p1-ph32-ex08-emb384-l12-ms50k-l20g1-20260404-131223.restartfix-20260404-143827.log",
-      "latest_log_sync": "2026-04-04 23:42:34",
+      "latest_log_sync": "2026-04-05 00:07:39",
       "latest_step": 50000,
       "swanlab_url": "https://swanlab.cn/@game-loader/roboimi-vla/runs/o5y2xjb2rsb3lmfcuhy4p",
-      "process_running": false,
+      "process_running": true,
       "best_step": 43749,
       "best_rollout_avg_reward": 361.6,
       "final_loss": 0.004788532387465239
     },
     "ph32_ex16": {
-      "status": "finished",
+      "status": "running",
       "host": "100.119.99.14",
       "gpu": 2,
       "run_name": "imf-p1-ph32-ex16-emb384-l12-ms50k-l20g2-20260404-131223",
@@ -113,10 +113,10 @@
       "num_action_steps": 16,
       "pid": 90175,
       "launch_log": "experiment_suite_launch_logs/imf-p1-ph32-ex16-emb384-l12-ms50k-l20g2-20260404-131223.restartfix-20260404-143827.log",
-      "latest_log_sync": "2026-04-04 23:42:34",
+      "latest_log_sync": "2026-04-05 00:07:39",
       "latest_step": 50000,
       "swanlab_url": "https://swanlab.cn/@game-loader/roboimi-vla/runs/54cjpgba9eqsopdm0l8d3",
-      "process_running": false,
+      "process_running": true,
       "best_step": 48124,
       "best_rollout_avg_reward": 239.6,
       "final_loss": 0.0038348555099219084
@@ -134,8 +134,8 @@
       "num_action_steps": 32,
       "pid": 1437836,
       "launch_log": "experiment_suites/2026-04-04-imf-horizon-grid/launch_logs/imf-p1-ph32-ex32-emb384-l12-ms50k-5090-20260404-131223.launch.log",
-      "latest_step": 50000,
-      "latest_log_sync": "2026-04-04 23:42:34",
+      "latest_step": 49900,
+      "latest_log_sync": "2026-04-05 00:07:39",
       "swanlab_url": "https://swanlab.cn/@game-loader/roboimi-vla/runs/ajs2m218jd260hawhy5ns",
       "process_running": false,
       "latest_rollout_avg_reward": 513.2,
@@ -5,6 +5,7 @@ _target_: roboimi.vla.models.backbones.resnet_diffusion.ResNetDiffusionBackbone
 # ====================
 vision_backbone: "resnet18"  # torchvision model name: resnet18, resnet34, resnet50
 pretrained_backbone_weights: "IMAGENET1K_V1"  # ImageNet pretrained weights (torchvision>=0.13)
+vision_backbone_mode: "resnet"  # resnet | attnres_resnet
 
 # ====================
 # Freeze settings
@@ -31,3 +32,17 @@ spatial_softmax_num_keypoints: 32  # number of Spatial Softmax keypoints
 # true: separate encoders (each camera gets its own ResNet; more parameters but more capacity)
 use_separate_rgb_encoder_per_camera: true
 num_cameras: 3  # number of cameras
+
+# ====================
+# Full-AttnRes vision trunk (active when vision_backbone_mode=attnres_resnet)
+# ====================
+attnres_stem_dim: 64
+attnres_stage_dims: [64, 128, 256, 512]
+attnres_stage_depths: [2, 2, 2, 2]
+attnres_stage_heads: [4, 4, 8, 8]
+attnres_stage_kv_heads: [1, 1, 1, 1]
+attnres_stage_window_sizes: [7, 7, 7, 7]
+attnres_dropout: 0.0
+attnres_ffn_mult: 2.667
+attnres_eps: 1.0e-06
+attnres_rope_theta: 10000.0
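Since the per-stage lists above must stay in lockstep, a config sanity check mirroring the backbone's ValueError guard might look like this (key names come from the YAML; the head-divides-dim constraint is an assumption typical of multi-head attention, not stated in this commit):

```python
cfg = {
    "attnres_stage_dims": [64, 128, 256, 512],
    "attnres_stage_depths": [2, 2, 2, 2],
    "attnres_stage_heads": [4, 4, 8, 8],
    "attnres_stage_kv_heads": [1, 1, 1, 1],
    "attnres_stage_window_sizes": [7, 7, 7, 7],
}
# all five stage lists must have the same number of entries
lengths = {len(v) for v in cfg.values()}
assert len(lengths) == 1, "stage lists must all have the same length"
# assumed constraint: every head count divides its stage dim
for dim, heads in zip(cfg["attnres_stage_dims"], cfg["attnres_stage_heads"]):
    assert dim % heads == 0
```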
228
roboimi/vla/models/backbones/attnres_resnet2d.py
Normal file
@@ -0,0 +1,228 @@
from __future__ import annotations

from typing import Sequence

import torch
import torch.nn as nn
import torch.nn.functional as F

from roboimi.vla.models.heads.attnres_transformer_components import AttnResTransformerBackbone


def _make_norm2d(num_channels: int, use_group_norm: bool) -> nn.Module:
    if use_group_norm:
        num_groups = max(1, num_channels // 16)
        return nn.GroupNorm(num_groups=num_groups, num_channels=num_channels)
    return nn.BatchNorm2d(num_channels)


class _ConvNormAct(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        *,
        kernel_size: int,
        stride: int,
        padding: int,
        use_group_norm: bool,
    ) -> None:
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                bias=False,
            ),
            _make_norm2d(out_channels, use_group_norm),
            nn.SiLU(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.block(x)
class AttnResImageBlock2D(nn.Module):
    def __init__(
        self,
        dim: int,
        *,
        window_size: int,
        n_heads: int,
        n_kv_heads: int,
        dropout: float,
        ffn_mult: float,
        eps: float,
        rope_theta: float,
    ) -> None:
        super().__init__()
        self.window_size = int(window_size)
        self.block = AttnResTransformerBackbone(
            d_model=dim,
            n_blocks=1,
            n_heads=n_heads,
            n_kv_heads=n_kv_heads,
            max_seq_len=self.window_size * self.window_size,
            dropout=dropout,
            ffn_mult=ffn_mult,
            eps=eps,
            rope_theta=rope_theta,
            causal_attn=False,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        bsz, channels, height, width = x.shape
        ws = self.window_size
        # Pad so H and W are divisible by the window size.
        pad_h = (ws - height % ws) % ws
        pad_w = (ws - width % ws) % ws
        if pad_h or pad_w:
            x = F.pad(x, (0, pad_w, 0, pad_h))
        padded_height, padded_width = x.shape[-2:]
        num_h = padded_height // ws
        num_w = padded_width // ws

        # Partition (B, C, H, W) into (B*num_windows, ws*ws, C) token windows.
        windows = (
            x.permute(0, 2, 3, 1)
            .contiguous()
            .view(bsz, num_h, ws, num_w, ws, channels)
            .permute(0, 1, 3, 2, 4, 5)
            .contiguous()
            .view(bsz * num_h * num_w, ws * ws, channels)
        )
        windows = self.block(windows)
        # Reverse the window partition back to (B, C, H, W).
        x = (
            windows.view(bsz, num_h, num_w, ws, ws, channels)
            .permute(0, 1, 3, 2, 4, 5)
            .contiguous()
            .view(bsz, padded_height, padded_width, channels)
            .permute(0, 3, 1, 2)
            .contiguous()
        )
        # Crop any padding back off.
        return x[:, :, :height, :width]
class _AttnResStage2D(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        *,
        depth: int,
        downsample_stride: int,
        use_group_norm: bool,
        window_size: int,
        n_heads: int,
        n_kv_heads: int,
        dropout: float,
        ffn_mult: float,
        eps: float,
        rope_theta: float,
    ) -> None:
        super().__init__()
        self.downsample = None
        if in_channels != out_channels or downsample_stride != 1:
            kernel_size = 1 if downsample_stride == 1 else 3
            padding = 0 if downsample_stride == 1 else 1
            self.downsample = _ConvNormAct(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                stride=downsample_stride,
                padding=padding,
                use_group_norm=use_group_norm,
            )

        self.blocks = nn.ModuleList(
            [
                AttnResImageBlock2D(
                    out_channels,
                    window_size=window_size,
                    n_heads=n_heads,
                    n_kv_heads=n_kv_heads,
                    dropout=dropout,
                    ffn_mult=ffn_mult,
                    eps=eps,
                    rope_theta=rope_theta,
                )
                for _ in range(depth)
            ]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.downsample is not None:
            x = self.downsample(x)
        for block in self.blocks:
            x = block(x)
        return x
class AttnResResNetLikeBackbone2D(nn.Module):
    def __init__(
        self,
        *,
        input_channels: int = 3,
        stem_dim: int = 64,
        stage_dims: Sequence[int] = (64, 128, 256, 512),
        stage_depths: Sequence[int] = (2, 2, 2, 2),
        stage_heads: Sequence[int] = (4, 4, 8, 8),
        stage_kv_heads: Sequence[int] = (1, 1, 1, 1),
        stage_window_sizes: Sequence[int] = (7, 7, 7, 7),
        use_group_norm: bool = True,
        dropout: float = 0.0,
        ffn_mult: float = 2.667,
        eps: float = 1e-6,
        rope_theta: float = 10000.0,
    ) -> None:
        super().__init__()

        lengths = {
            len(stage_dims),
            len(stage_depths),
            len(stage_heads),
            len(stage_kv_heads),
            len(stage_window_sizes),
        }
        if len(lengths) != 1:
            raise ValueError('stage_dims/depths/heads/kv_heads/window_sizes must all have the same length')

        self.stem = nn.Sequential(
            nn.Conv2d(input_channels, stem_dim, kernel_size=7, stride=2, padding=3, bias=False),
            _make_norm2d(stem_dim, use_group_norm),
            nn.SiLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )

        in_channels = stem_dim
        stages = []
        for stage_idx, (out_channels, depth, n_heads, n_kv_heads, window_size) in enumerate(
            zip(stage_dims, stage_depths, stage_heads, stage_kv_heads, stage_window_sizes)
        ):
            stages.append(
                _AttnResStage2D(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    depth=int(depth),
                    downsample_stride=1 if stage_idx == 0 else 2,
                    use_group_norm=use_group_norm,
                    window_size=int(window_size),
                    n_heads=int(n_heads),
                    n_kv_heads=int(n_kv_heads),
                    dropout=float(dropout),
                    ffn_mult=float(ffn_mult),
                    eps=float(eps),
                    rope_theta=float(rope_theta),
                )
            )
            in_channels = out_channels

        self.stages = nn.ModuleList(stages)
        self.output_channels = in_channels

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.stem(x)
        for stage in self.stages:
            x = stage(x)
        return x
@@ -6,6 +6,8 @@ import torchvision
 import numpy as np
 from typing import Callable, Optional, Tuple, Union
 
+from .attnres_resnet2d import AttnResResNetLikeBackbone2D
+
 def _replace_submodules(
     root_module: nn.Module, predicate: Callable[[nn.Module], bool], func: Callable[[nn.Module], nn.Module]
 ) -> nn.Module:
@@ -103,6 +105,17 @@ class _SingleRgbEncoder(nn.Module):
         use_group_norm: bool,
         spatial_softmax_num_keypoints: int,
         freeze_backbone: bool = True,  # new: whether to freeze the backbone
+        vision_backbone_mode: str = "resnet",
+        attnres_stem_dim: int = 64,
+        attnres_stage_dims: Optional[Tuple[int, ...]] = None,
+        attnres_stage_depths: Optional[Tuple[int, ...]] = None,
+        attnres_stage_heads: Optional[Tuple[int, ...]] = None,
+        attnres_stage_kv_heads: Optional[Tuple[int, ...]] = None,
+        attnres_stage_window_sizes: Optional[Tuple[int, ...]] = None,
+        attnres_dropout: float = 0.0,
+        attnres_ffn_mult: float = 2.667,
+        attnres_eps: float = 1e-6,
+        attnres_rope_theta: float = 10000.0,
     ):
         super().__init__()
 
@@ -119,21 +132,42 @@ class _SingleRgbEncoder(nn.Module):
             self.do_crop = False
             crop_shape = input_shape[1:]
 
-        # set up the backbone network
-        backbone_model = getattr(torchvision.models, vision_backbone)(
-            weights=pretrained_backbone_weights
-        )
-
-        # drop AvgPool and FC (assumes layer4 is children()[-3])
-        self.backbone = nn.Sequential(*(list(backbone_model.children())[:-2]))
-
-        if use_group_norm:
-            self.backbone = _replace_submodules(
-                root_module=self.backbone,
-                predicate=lambda x: isinstance(x, nn.BatchNorm2d),
-                func=lambda x: nn.GroupNorm(num_groups=x.num_features // 16, num_channels=x.num_features),
-            )
+        if vision_backbone_mode == "resnet":
+            # set up the torchvision backbone
+            backbone_model = getattr(torchvision.models, vision_backbone)(
+                weights=pretrained_backbone_weights
+            )
+
+            # drop AvgPool and FC (assumes layer4 is children()[-3])
+            self.backbone = nn.Sequential(*(list(backbone_model.children())[:-2]))
+
+            if use_group_norm:
+                self.backbone = _replace_submodules(
+                    root_module=self.backbone,
+                    predicate=lambda x: isinstance(x, nn.BatchNorm2d),
+                    func=lambda x: nn.GroupNorm(
+                        num_groups=max(1, x.num_features // 16),
+                        num_channels=x.num_features,
+                    ),
+                )
+        elif vision_backbone_mode == "attnres_resnet":
+            self.backbone = AttnResResNetLikeBackbone2D(
+                input_channels=input_shape[0],
+                stem_dim=attnres_stem_dim,
+                stage_dims=tuple(attnres_stage_dims or (64, 128, 256, 512)),
+                stage_depths=tuple(attnres_stage_depths or (2, 2, 2, 2)),
+                stage_heads=tuple(attnres_stage_heads or (4, 4, 8, 8)),
+                stage_kv_heads=tuple(attnres_stage_kv_heads or (1, 1, 1, 1)),
+                stage_window_sizes=tuple(attnres_stage_window_sizes or (7, 7, 7, 7)),
+                use_group_norm=use_group_norm,
+                dropout=attnres_dropout,
+                ffn_mult=attnres_ffn_mult,
+                eps=attnres_eps,
+                rope_theta=attnres_rope_theta,
+            )
+        else:
+            raise ValueError(f"unsupported vision_backbone_mode: {vision_backbone_mode}")
 
         # freeze backbone parameters (optional)
         if freeze_backbone:
             for param in self.backbone.parameters():
@@ -180,6 +214,17 @@ class ResNetDiffusionBackbone(VLABackbone):
         num_cameras: int = 1,  # new: number of cameras (used only with separate encoders)
         camera_names: Optional[Tuple[str, ...]] = None,  # explicit camera order
         freeze_backbone: bool = True,  # new: whether to freeze the ResNet backbone (True recommended)
+        vision_backbone_mode: str = "resnet",
+        attnres_stem_dim: int = 64,
+        attnres_stage_dims: Optional[Tuple[int, ...]] = None,
+        attnres_stage_depths: Optional[Tuple[int, ...]] = None,
+        attnres_stage_heads: Optional[Tuple[int, ...]] = None,
+        attnres_stage_kv_heads: Optional[Tuple[int, ...]] = None,
+        attnres_stage_window_sizes: Optional[Tuple[int, ...]] = None,
+        attnres_dropout: float = 0.0,
+        attnres_ffn_mult: float = 2.667,
+        attnres_eps: float = 1e-6,
+        attnres_rope_theta: float = 10000.0,
     ):
         super().__init__()
 
@@ -203,6 +248,17 @@ class ResNetDiffusionBackbone(VLABackbone):
                     use_group_norm=use_group_norm,
                     spatial_softmax_num_keypoints=spatial_softmax_num_keypoints,
                     freeze_backbone=freeze_backbone,
+                    vision_backbone_mode=vision_backbone_mode,
+                    attnres_stem_dim=attnres_stem_dim,
+                    attnres_stage_dims=attnres_stage_dims,
+                    attnres_stage_depths=attnres_stage_depths,
+                    attnres_stage_heads=attnres_stage_heads,
+                    attnres_stage_kv_heads=attnres_stage_kv_heads,
+                    attnres_stage_window_sizes=attnres_stage_window_sizes,
+                    attnres_dropout=attnres_dropout,
+                    attnres_ffn_mult=attnres_ffn_mult,
+                    attnres_eps=attnres_eps,
+                    attnres_rope_theta=attnres_rope_theta,
                 )
                 for _ in range(num_cameras)
             ]
@@ -220,6 +276,17 @@ class ResNetDiffusionBackbone(VLABackbone):
                 use_group_norm=use_group_norm,
                 spatial_softmax_num_keypoints=spatial_softmax_num_keypoints,
                 freeze_backbone=freeze_backbone,
+                vision_backbone_mode=vision_backbone_mode,
+                attnres_stem_dim=attnres_stem_dim,
+                attnres_stage_dims=attnres_stage_dims,
+                attnres_stage_depths=attnres_stage_depths,
+                attnres_stage_heads=attnres_stage_heads,
+                attnres_stage_kv_heads=attnres_stage_kv_heads,
+                attnres_stage_window_sizes=attnres_stage_window_sizes,
+                attnres_dropout=attnres_dropout,
+                attnres_ffn_mult=attnres_ffn_mult,
+                attnres_eps=attnres_eps,
+                attnres_rope_theta=attnres_rope_theta,
             )
             self.feature_dim = self.rgb_encoder.feature_dim
 
26
tests/test_attnres_resnet2d_backbone.py
Normal file
@@ -0,0 +1,26 @@
import unittest

import torch


class AttnResResNet2DBackboneTest(unittest.TestCase):
    def test_backbone_preserves_resnet_like_stage_contract(self):
        from roboimi.vla.models.backbones.attnres_resnet2d import AttnResResNetLikeBackbone2D

        backbone = AttnResResNetLikeBackbone2D(
            input_channels=3,
            stem_dim=16,
            stage_dims=(16, 32, 64, 128),
            stage_depths=(1, 1, 1, 1),
            stage_heads=(2, 4, 4, 8),
            stage_kv_heads=(1, 1, 1, 1),
            stage_window_sizes=(7, 7, 7, 7),
            dropout=0.0,
        )
        x = torch.randn(2, 3, 56, 56)
        y = backbone(x)
        self.assertEqual(y.shape, (2, 128, 2, 2))


if __name__ == '__main__':
    unittest.main()
@@ -422,6 +422,32 @@ class IMFVLAAgentTest(unittest.TestCase):
         self.assertEqual(agent.noise_pred_net.constructor_kwargs['cond_dim'], agent.per_step_cond_dim)
         self.assertEqual(agent.noise_pred_net.constructor_kwargs['backbone_type'], 'attnres_full')
 
+    def test_hydra_config_instantiates_resnet_imf_attnres_with_full_attnres_vision_backbone(self):
+        cfg = _compose_cfg(
+            overrides=[
+                'agent=resnet_imf_attnres',
+                'agent.vision_backbone.vision_backbone_mode=attnres_resnet',
+                'agent.vision_backbone.pretrained_backbone_weights=null',
+                'agent.vision_backbone.input_shape=[3,56,56]',
+                'agent.vision_backbone.freeze_backbone=false',
+                'agent.vision_backbone.attnres_stem_dim=16',
+                'agent.vision_backbone.attnres_stage_dims=[16,32,64,128]',
+                'agent.vision_backbone.attnres_stage_depths=[1,1,1,1]',
+                'agent.vision_backbone.attnres_stage_heads=[2,4,4,8]',
+                'agent.vision_backbone.attnres_stage_kv_heads=[1,1,1,1]',
+                'agent.vision_backbone.attnres_stage_window_sizes=[7,7,7,7]',
+                'agent.head.n_layer=1',
+                'agent.head.n_emb=16',
+            ]
+        )
+
+        with _stub_optional_modules(include_imf_head=True):
+            agent = instantiate(cfg.agent)
+
+        self.assertEqual(agent.vision_encoder.output_dim, 64)
+        self.assertEqual(agent.per_step_cond_dim, 64 * agent.num_cams + agent.obs_dim)
+        self.assertEqual(agent.noise_pred_net.constructor_kwargs['cond_dim'], agent.per_step_cond_dim)
+
 
 if __name__ == '__main__':
     unittest.main()