From 20331698402e569e4de3ea2c25fc0c688c0a7ba5 Mon Sep 17 00:00:00 2001 From: Logic Date: Sun, 5 Apr 2026 00:07:59 +0800 Subject: [PATCH] feat: add full attnres vision backbone --- ...6-04-05-phase2-full-attnres-vision-plan.md | 64 +++++ ...04-05-phase2-full-attnres-vision-design.md | 81 +++++++ .../leaderboard.csv | 14 +- .../2026-04-04-imf-horizon-grid/status.json | 36 +-- .../vla/conf/backbone/resnet_diffusion.yaml | 17 +- .../vla/models/backbones/attnres_resnet2d.py | 228 ++++++++++++++++++ .../vla/models/backbones/resnet_diffusion.py | 93 ++++++- tests/test_attnres_resnet2d_backbone.py | 26 ++ tests/test_imf_vla_agent.py | 26 ++ 9 files changed, 546 insertions(+), 39 deletions(-) create mode 100644 docs/superpowers/plans/2026-04-05-phase2-full-attnres-vision-plan.md create mode 100644 docs/superpowers/specs/2026-04-05-phase2-full-attnres-vision-design.md create mode 100644 roboimi/vla/models/backbones/attnres_resnet2d.py create mode 100644 tests/test_attnres_resnet2d_backbone.py diff --git a/docs/superpowers/plans/2026-04-05-phase2-full-attnres-vision-plan.md b/docs/superpowers/plans/2026-04-05-phase2-full-attnres-vision-plan.md new file mode 100644 index 0000000..8414f81 --- /dev/null +++ b/docs/superpowers/plans/2026-04-05-phase2-full-attnres-vision-plan.md @@ -0,0 +1,64 @@ +# Phase-2 Full-AttnRes Vision Implementation Plan + +> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Replace all ResNet residual units in the vision backbone with AttnRes-based image blocks while preserving the current IMF agent interfaces and launch a Phase-2 experiment anchored on the best Phase-1 horizon setting. 
+ +**Architecture:** Keep the current multi-camera encoder shell and per-camera output contract, but introduce a new ResNet-like 2D AttnRes backbone that preserves stage-wise downsampling and final SpatialSoftmax conditioning. Wire it into the existing `ResNetDiffusionBackbone` via an opt-in mode and keep the agent/head/data interfaces unchanged. + +**Tech Stack:** PyTorch, Hydra/OmegaConf, existing IMF AttnRes transformer components, pytest. + +--- + +### Task 1: Add failing tests for the new full-AttnRes visual backbone + +**Files:** +- Create: `tests/test_attnres_resnet2d_backbone.py` +- Update: `tests/test_imf_vla_agent.py` + +- [ ] **Step 1: Write a failing backbone shape test** +- [ ] **Step 2: Run it to confirm the new backbone/config does not exist yet** +- [ ] **Step 3: Add a failing IMF agent wiring test for unchanged cond_dim=208** +- [ ] **Step 4: Run the targeted tests and capture the failure** + +### Task 2: Implement a ResNet-like 2D AttnRes backbone + +**Files:** +- Create: `roboimi/vla/models/backbones/attnres_resnet2d.py` +- Modify: `roboimi/vla/models/backbones/resnet_diffusion.py` + +- [ ] **Step 1: Add minimal 2D tokenization helpers and positional encoding / bias handling** +- [ ] **Step 2: Implement `AttnResImageBlock2D` for feature maps** +- [ ] **Step 3: Implement `AttnResResNetLikeBackbone2D` with stage-wise downsampling** +- [ ] **Step 4: Wire `_SingleRgbEncoder` to choose between original ResNet trunk and the new full-AttnRes trunk** +- [ ] **Step 5: Run the new backbone tests** + +### Task 3: Expose config switches and agent wiring + +**Files:** +- Modify: `roboimi/vla/conf/backbone/resnet_diffusion.yaml` +- Modify: `roboimi/vla/conf/agent/resnet_imf_attnres.yaml` + +- [ ] **Step 1: Add a backbone mode/config flag for the full-AttnRes vision trunk** +- [ ] **Step 2: Add defaults for attnres image depth/heads/etc. 
if needed** +- [ ] **Step 3: Add a Phase-2 launch override path that enables the new visual trunk** +- [ ] **Step 4: Run agent wiring tests again** + +### Task 4: Smoke-verify training path + +**Files:** +- Reuse existing training scripts and configs + +- [ ] **Step 1: Run a short CPU or tiny-step smoke instantiation / `compute_loss` test** +- [ ] **Step 2: If needed, run a very short training smoke launch** +- [ ] **Step 3: Verify no cond-dim or rollout-loading regressions** + +### Task 5: Launch the Phase-2 experiment + +**Files:** +- Update experiment tracking under `experiment_suites/` + +- [ ] **Step 1: Use Phase-1 best setting (`pred_horizon=16`, `num_action_steps=8`)** +- [ ] **Step 2: Launch baseline reference or reuse existing result** +- [ ] **Step 3: Launch full-AttnRes vision experiment** +- [ ] **Step 4: Track rollout metrics and compare max avg_reward** diff --git a/docs/superpowers/specs/2026-04-05-phase2-full-attnres-vision-design.md b/docs/superpowers/specs/2026-04-05-phase2-full-attnres-vision-design.md new file mode 100644 index 0000000..b1c2f0c --- /dev/null +++ b/docs/superpowers/specs/2026-04-05-phase2-full-attnres-vision-design.md @@ -0,0 +1,81 @@ +# Phase-2 Full-AttnRes Vision Design + +## Goal +在当前 roboimi IMF policy 中,把视觉 backbone 里原先由 ResNet BasicBlock/Bottleneck 提供的残差单元全部替换为 AttnRes 风格单元,同时尽量保持现有 agent / cond / rollout / 训练脚本接口不变。 + +## User requirement interpretation +这里按最严格解释执行: +- 不是“在 ResNet 后面再加一个 AttnRes 模块” +- 也不是“只在某几个 stage 加 AttnRes 混合” +- 而是:视觉主干网络中原本依赖 ResNet residual block 的地方,统一改成 AttnRes residual operator 驱动的 block +- 最终仍然输出与现有 `ResNetDiffusionBackbone` 相同的每相机特征接口,以便复用 `SpatialSoftmax -> Linear -> ReLU`、多相机拼接、state concat、IMF head 条件输入 + +## Recommended design +### Option A (recommended) +保留 ResNet 的宏观 stage/stem 结构与通道/步幅规划,但把每个 stage 内的 BasicBlock/Bottleneck 替换为新的 `AttnResImageBlock2D`: +- 输入仍是 `(B, C, H, W)` feature map +- block 内先把空间维 flatten 成 token 序列 `(B, H*W, C)` +- 用二维位置编码 / 可学习位置偏置 + AttnRes self-attention + 
AttnRes FFN 完成 block 变换 +- 再 reshape 回 `(B, C, H, W)` +- stage 间下采样仍由显式 stride/downsample path 完成 + +优点: +- 最接近“ResNet 中所有残差都由 AttnRes 代替”的要求 +- 保留现有视觉输出接口和 cond_dim,不用改 agent/head/data pipeline +- 仍可沿用现有多相机编码器框架 + +缺点: +- 需要新写 2D 版 AttnRes image block,而不是直接复用 1D IMF head block + +### Option B +完全移除 ResNet stage,换成 patchify + ViT/AttnRes 图像 transformer,再接 SpatialSoftmax/MLP。 + +优点:实现概念更统一。 +缺点:已经不算“把 ResNet 中残差替换掉”,而是直接换 backbone,和用户要求不完全一致。 + +### Option C +保留现有 ResNet block,只在 block 外层加 AttnRes mixing。 + +不推荐,因为不满足“所有残差均由 AttnRes 替代”。 + +## Concrete architecture choice +采用 Option A: +1. 保留 stem(conv/bn-or-gn/relu/maxpool)与 stage 边界 +2. 新增 `AttnResImageBlock2D` +3. 新增 `AttnResResNetLikeBackbone2D`,负责堆叠 stage/block +4. 在 `ResNetDiffusionBackbone` 中增加可选 backbone mode,例如: + - `vision_backbone_mode: resnet` + - `vision_backbone_mode: attnres_resnet` +5. `resnet_imf_attnres` agent 配置新增一个 Phase-2 变体,默认打开 `attnres_resnet` +6. 仍保持: + - 每相机输出 `64` + - 多相机总视觉输出 `3 * 64` + - 与 state 拼接后 `cond_dim = 208` + +## Files likely to change +- `roboimi/vla/models/backbones/resnet_diffusion.py` +- `roboimi/vla/conf/backbone/resnet_diffusion.yaml` +- `roboimi/vla/conf/agent/resnet_imf_attnres.yaml` +- new: `roboimi/vla/models/backbones/attnres_resnet2d.py` +- tests: + - new: `tests/test_attnres_resnet2d_backbone.py` + - update/add wiring test for agent cond dims + +## Test plan +1. New backbone instantiates and forwards `(B,T,C,H,W)` multi-camera input +2. Output shape unchanged vs current backbone +3. `output_dim == 64` +4. 3-camera cond path still yields `208` +5. Phase-2 config instantiates full IMF agent successfully +6. One short CPU smoke forward for `compute_loss` + +## Phase-2 experiment plan +固定使用 Phase-1 最优组合: +- `pred_horizon=16` +- `num_action_steps=8` + +比较: +1. baseline: current IMF head-only AttnRes + original ResNet vision backbone +2. 
phase2: IMF head AttnRes + full AttnRes-replaced vision backbone + +训练超参保持与 Phase-1 最优设置一致,先跑一组 50k step 对比。 diff --git a/experiment_suites/2026-04-04-imf-horizon-grid/leaderboard.csv b/experiment_suites/2026-04-04-imf-horizon-grid/leaderboard.csv index 908eff0..5031b26 100644 --- a/experiment_suites/2026-04-04-imf-horizon-grid/leaderboard.csv +++ b/experiment_suites/2026-04-04-imf-horizon-grid/leaderboard.csv @@ -1,7 +1,7 @@ -rank,run_id,status,pred_horizon,num_action_steps,best_rollout_avg_reward,best_step,final_step,final_loss,host,run_dir -1,ph16_ex8,finished,16,8,610.8,21874,50000,0.0034315965604037046,100.73.14.65,/home/droid/roboimi_suite_20260404/runs/imf-p1-ph16-ex08-emb384-l12-ms50k-5880g1-20260404-131223 -2,ph16_ex16,finished,16,16,561.2,48124,50000,0.004544622730463743,100.119.99.14,/home/droid/roboimi_suite_20260404/runs/imf-p1-ph16-ex16-emb384-l12-ms50k-l20g0-20260404-131223 -3,ph32_ex32,finished,32,32,513.2,43749,50000,0.003953303210437298,local,/home/droid/project/roboimi/.worktrees/feat-imf-attnres-policy/runs/imf-p1-ph32-ex32-emb384-l12-ms50k-5090-20260404-131223 -4,ph8_ex8,finished,8,8,415.6,48124,50000,0.007008877582848072,100.73.14.65,/home/droid/roboimi_suite_20260404/runs/imf-p1-ph08-ex08-emb384-l12-ms50k-5880g0-20260404-131223 -5,ph32_ex8,finished,32,8,361.6,43749,50000,0.004788532387465239,100.119.99.14,/home/droid/roboimi_suite_20260404/runs/imf-p1-ph32-ex08-emb384-l12-ms50k-l20g1-20260404-131223 -6,ph32_ex16,finished,32,16,239.6,48124,50000,0.0038348555099219084,100.119.99.14,/home/droid/roboimi_suite_20260404/runs/imf-p1-ph32-ex16-emb384-l12-ms50k-l20g2-20260404-131223 +rank,run_id,status,pred_horizon,num_action_steps,best_rollout_avg_reward,best_step,final_step,final_loss,host,run_dir,latest_step +1,ph16_ex8,running,16,8,610.8,21874,50000,0.0034315965604037046,100.73.14.65,/home/droid/roboimi_suite_20260404/runs/imf-p1-ph16-ex08-emb384-l12-ms50k-5880g1-20260404-131223,50000 
+2,ph16_ex16,running,16,16,561.2,48124,50000,0.004544622730463743,100.119.99.14,/home/droid/roboimi_suite_20260404/runs/imf-p1-ph16-ex16-emb384-l12-ms50k-l20g0-20260404-131223,50000 +3,ph32_ex32,finished,32,32,513.2,43749,50000,0.003953303210437298,local,/home/droid/project/roboimi/.worktrees/feat-imf-attnres-policy/runs/imf-p1-ph32-ex32-emb384-l12-ms50k-5090-20260404-131223,49900 +4,ph8_ex8,running,8,8,415.6,48124,50000,0.007008877582848072,100.73.14.65,/home/droid/roboimi_suite_20260404/runs/imf-p1-ph08-ex08-emb384-l12-ms50k-5880g0-20260404-131223,50000 +5,ph32_ex8,running,32,8,361.6,43749,50000,0.004788532387465239,100.119.99.14,/home/droid/roboimi_suite_20260404/runs/imf-p1-ph32-ex08-emb384-l12-ms50k-l20g1-20260404-131223,50000 +6,ph32_ex16,running,32,16,239.6,48124,50000,0.0038348555099219084,100.119.99.14,/home/droid/roboimi_suite_20260404/runs/imf-p1-ph32-ex16-emb384-l12-ms50k-l20g2-20260404-131223,50000 diff --git a/experiment_suites/2026-04-04-imf-horizon-grid/status.json b/experiment_suites/2026-04-04-imf-horizon-grid/status.json index 0cae42c..3034e43 100644 --- a/experiment_suites/2026-04-04-imf-horizon-grid/status.json +++ b/experiment_suites/2026-04-04-imf-horizon-grid/status.json @@ -1,6 +1,6 @@ { "suite_name": "2026-04-04-imf-horizon-grid", - "updated_at": "2026-04-04 23:46:01", + "updated_at": "2026-04-05 00:07:39", "phase": "phase1_completed", "provisioning": { "100.119.99.14": { @@ -17,7 +17,7 @@ }, "runs": { "ph8_ex8": { - "status": "finished", + "status": "running", "host": "100.73.14.65", "gpu": 0, "run_name": "imf-p1-ph08-ex08-emb384-l12-ms50k-5880g0-20260404-131223", @@ -30,15 +30,15 @@ "pid": 938714, "launch_log": "experiment_suite_launch_logs/imf-p1-ph08-ex08-emb384-l12-ms50k-5880g0-20260404-131223.restartfix-20260404-143827.log", "latest_step": 50000, - "latest_log_sync": "2026-04-04 23:42:34", + "latest_log_sync": "2026-04-05 00:07:39", "swanlab_url": "https://swanlab.cn/@game-loader/roboimi-vla/runs/i5syc57b6zq7rbkrtqy7b", - 
"process_running": false, + "process_running": true, "best_step": 48124, "best_rollout_avg_reward": 415.6, "final_loss": 0.007008877582848072 }, "ph16_ex8": { - "status": "finished", + "status": "running", "host": "100.73.14.65", "gpu": 1, "run_name": "imf-p1-ph16-ex08-emb384-l12-ms50k-5880g1-20260404-131223", @@ -51,15 +51,15 @@ "pid": 938717, "launch_log": "experiment_suite_launch_logs/imf-p1-ph16-ex08-emb384-l12-ms50k-5880g1-20260404-131223.restartfix-20260404-143827.log", "latest_step": 50000, - "latest_log_sync": "2026-04-04 23:42:34", + "latest_log_sync": "2026-04-05 00:07:39", "swanlab_url": "https://swanlab.cn/@game-loader/roboimi-vla/runs/4rusbrpfxmw4ffii1ul5w", - "process_running": false, + "process_running": true, "best_step": 21874, "best_rollout_avg_reward": 610.8, "final_loss": 0.0034315965604037046 }, "ph16_ex16": { - "status": "finished", + "status": "running", "host": "100.119.99.14", "gpu": 0, "run_name": "imf-p1-ph16-ex16-emb384-l12-ms50k-l20g0-20260404-131223", @@ -71,16 +71,16 @@ "num_action_steps": 16, "pid": 90169, "launch_log": "experiment_suite_launch_logs/imf-p1-ph16-ex16-emb384-l12-ms50k-l20g0-20260404-131223.restartfix-20260404-143827.log", - "latest_log_sync": "2026-04-04 23:42:34", + "latest_log_sync": "2026-04-05 00:07:39", "latest_step": 50000, "swanlab_url": "https://swanlab.cn/@game-loader/roboimi-vla/runs/wwm232k6190gexnze8mg6", - "process_running": false, + "process_running": true, "best_step": 48124, "best_rollout_avg_reward": 561.2, "final_loss": 0.004544622730463743 }, "ph32_ex8": { - "status": "finished", + "status": "running", "host": "100.119.99.14", "gpu": 1, "run_name": "imf-p1-ph32-ex08-emb384-l12-ms50k-l20g1-20260404-131223", @@ -92,16 +92,16 @@ "num_action_steps": 8, "pid": 90173, "launch_log": "experiment_suite_launch_logs/imf-p1-ph32-ex08-emb384-l12-ms50k-l20g1-20260404-131223.restartfix-20260404-143827.log", - "latest_log_sync": "2026-04-04 23:42:34", + "latest_log_sync": "2026-04-05 00:07:39", "latest_step": 50000, 
"swanlab_url": "https://swanlab.cn/@game-loader/roboimi-vla/runs/o5y2xjb2rsb3lmfcuhy4p", - "process_running": false, + "process_running": true, "best_step": 43749, "best_rollout_avg_reward": 361.6, "final_loss": 0.004788532387465239 }, "ph32_ex16": { - "status": "finished", + "status": "running", "host": "100.119.99.14", "gpu": 2, "run_name": "imf-p1-ph32-ex16-emb384-l12-ms50k-l20g2-20260404-131223", @@ -113,10 +113,10 @@ "num_action_steps": 16, "pid": 90175, "launch_log": "experiment_suite_launch_logs/imf-p1-ph32-ex16-emb384-l12-ms50k-l20g2-20260404-131223.restartfix-20260404-143827.log", - "latest_log_sync": "2026-04-04 23:42:34", + "latest_log_sync": "2026-04-05 00:07:39", "latest_step": 50000, "swanlab_url": "https://swanlab.cn/@game-loader/roboimi-vla/runs/54cjpgba9eqsopdm0l8d3", - "process_running": false, + "process_running": true, "best_step": 48124, "best_rollout_avg_reward": 239.6, "final_loss": 0.0038348555099219084 @@ -134,8 +134,8 @@ "num_action_steps": 32, "pid": 1437836, "launch_log": "experiment_suites/2026-04-04-imf-horizon-grid/launch_logs/imf-p1-ph32-ex32-emb384-l12-ms50k-5090-20260404-131223.launch.log", - "latest_step": 50000, - "latest_log_sync": "2026-04-04 23:42:34", + "latest_step": 49900, + "latest_log_sync": "2026-04-05 00:07:39", "swanlab_url": "https://swanlab.cn/@game-loader/roboimi-vla/runs/ajs2m218jd260hawhy5ns", "process_running": false, "latest_rollout_avg_reward": 513.2, diff --git a/roboimi/vla/conf/backbone/resnet_diffusion.yaml b/roboimi/vla/conf/backbone/resnet_diffusion.yaml index 6f8a11a..ca08799 100644 --- a/roboimi/vla/conf/backbone/resnet_diffusion.yaml +++ b/roboimi/vla/conf/backbone/resnet_diffusion.yaml @@ -5,6 +5,7 @@ _target_: roboimi.vla.models.backbones.resnet_diffusion.ResNetDiffusionBackbone # ==================== vision_backbone: "resnet18" # torchvision 模型名称: resnet18, resnet34, resnet50 pretrained_backbone_weights: "IMAGENET1K_V1" # 使用ImageNet预训练权重(torchvision>=0.13) +vision_backbone_mode: "resnet" # resnet | 
attnres_resnet # ==================== # 冻结设置 @@ -30,4 +31,18 @@ spatial_softmax_num_keypoints: 32 # Spatial Softmax 关键点数量 # false: 共享编码器(所有摄像头共享一个 ResNet,参数少但容量受限)推荐! # true: 独立编码器(每个摄像头有独立的 ResNet,参数多但容量大) use_separate_rgb_encoder_per_camera: true -num_cameras: 3 # 摄像头数量 \ No newline at end of file +num_cameras: 3 # 摄像头数量 + +# ==================== +# Full-AttnRes vision trunk(当 vision_backbone_mode=attnres_resnet 时生效) +# ==================== +attnres_stem_dim: 64 +attnres_stage_dims: [64, 128, 256, 512] +attnres_stage_depths: [2, 2, 2, 2] +attnres_stage_heads: [4, 4, 8, 8] +attnres_stage_kv_heads: [1, 1, 1, 1] +attnres_stage_window_sizes: [7, 7, 7, 7] +attnres_dropout: 0.0 +attnres_ffn_mult: 2.667 +attnres_eps: 1.0e-06 +attnres_rope_theta: 10000.0 diff --git a/roboimi/vla/models/backbones/attnres_resnet2d.py b/roboimi/vla/models/backbones/attnres_resnet2d.py new file mode 100644 index 0000000..1ef144f --- /dev/null +++ b/roboimi/vla/models/backbones/attnres_resnet2d.py @@ -0,0 +1,228 @@ +from __future__ import annotations + +from typing import Iterable, Sequence + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from roboimi.vla.models.heads.attnres_transformer_components import AttnResTransformerBackbone + + +def _make_norm2d(num_channels: int, use_group_norm: bool) -> nn.Module: + if use_group_norm: + num_groups = max(1, num_channels // 16) + return nn.GroupNorm(num_groups=num_groups, num_channels=num_channels) + return nn.BatchNorm2d(num_channels) + + +class _ConvNormAct(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + *, + kernel_size: int, + stride: int, + padding: int, + use_group_norm: bool, + ) -> None: + super().__init__() + self.block = nn.Sequential( + nn.Conv2d( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + bias=False, + ), + _make_norm2d(out_channels, use_group_norm), + nn.SiLU(), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return 
class AttnResImageBlock2D(nn.Module):
    """AttnRes residual unit for 2D feature maps (ResNet BasicBlock replacement).

    The (B, C, H, W) input is right/bottom zero-padded to a multiple of
    ``window_size``, partitioned into non-overlapping windows, each window is
    flattened to a (ws*ws, C) token sequence and transformed by one non-causal
    AttnRes transformer block, then folded back and cropped to the original
    spatial size. Channel count and spatial extent are preserved, so the block
    is a drop-in replacement for a stride-1 residual unit.
    """

    def __init__(
        self,
        dim: int,
        *,
        window_size: int,
        n_heads: int,
        n_kv_heads: int,
        dropout: float,
        ffn_mult: float,
        eps: float,
        rope_theta: float,
    ) -> None:
        super().__init__()
        # Fail fast on head configurations that would otherwise only blow up
        # deep inside the attention kernel with a far less actionable error.
        if n_heads <= 0 or dim % n_heads != 0:
            raise ValueError(f'dim={dim} 必须能被 n_heads={n_heads} 整除')
        if n_kv_heads <= 0 or n_heads % n_kv_heads != 0:
            raise ValueError(f'n_heads={n_heads} 必须能被 n_kv_heads={n_kv_heads} 整除')
        self.window_size = int(window_size)
        # One transformer block per image block keeps stage depth semantics
        # aligned with the ResNet stage_depths this backbone replaces.
        self.block = AttnResTransformerBackbone(
            d_model=dim,
            n_blocks=1,
            n_heads=n_heads,
            n_kv_heads=n_kv_heads,
            max_seq_len=self.window_size * self.window_size,
            dropout=dropout,
            ffn_mult=ffn_mult,
            eps=eps,
            rope_theta=rope_theta,
            causal_attn=False,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply windowed AttnRes attention; output shape equals input shape."""
        bsz, channels, height, width = x.shape
        ws = self.window_size
        # Pad on the right/bottom only, so the crop below is an exact inverse.
        pad_h = (ws - height % ws) % ws
        pad_w = (ws - width % ws) % ws
        if pad_h or pad_w:
            # NOTE(review): padded zero tokens participate in attention for
            # edge windows (no attention mask is applied) — confirm this is
            # acceptable for the small late-stage feature maps.
            x = F.pad(x, (0, pad_w, 0, pad_h))
        padded_height, padded_width = x.shape[-2:]
        num_h = padded_height // ws
        num_w = padded_width // ws

        # Window partition: (B, C, H, W) -> (B*nH*nW, ws*ws, C).
        windows = (
            x.permute(0, 2, 3, 1)
            .contiguous()
            .view(bsz, num_h, ws, num_w, ws, channels)
            .permute(0, 1, 3, 2, 4, 5)
            .contiguous()
            .view(bsz * num_h * num_w, ws * ws, channels)
        )
        windows = self.block(windows)
        # Inverse of the partition above, then crop the padding off.
        x = (
            windows.view(bsz, num_h, num_w, ws, ws, channels)
            .permute(0, 1, 3, 2, 4, 5)
            .contiguous()
            .view(bsz, padded_height, padded_width, channels)
            .permute(0, 3, 1, 2)
            .contiguous()
        )
        return x[:, :, :height, :width]


class _AttnResStage2D(nn.Module):
    """One ResNet-like stage: optional conv downsample + a pile of AttnRes blocks.

    Mirrors a torchvision ResNet ``layer{1..4}``: the first op of the stage may
    change channels and/or stride; the ``depth`` AttnRes blocks that follow are
    shape-preserving.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        *,
        depth: int,
        downsample_stride: int,
        use_group_norm: bool,
        window_size: int,
        n_heads: int,
        n_kv_heads: int,
        dropout: float,
        ffn_mult: float,
        eps: float,
        rope_theta: float,
    ) -> None:
        super().__init__()
        # Stage transitions (channel change and/or spatial stride) stay as an
        # explicit conv path, matching the ResNet macro-structure; only the
        # residual units themselves are replaced by AttnRes blocks.
        self.downsample = None
        if in_channels != out_channels or downsample_stride != 1:
            # 1x1 for pure channel projection, 3x3 when also striding.
            kernel_size = 1 if downsample_stride == 1 else 3
            padding = 0 if downsample_stride == 1 else 1
            self.downsample = _ConvNormAct(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                stride=downsample_stride,
                padding=padding,
                use_group_norm=use_group_norm,
            )

        self.blocks = nn.ModuleList(
            [
                AttnResImageBlock2D(
                    out_channels,
                    window_size=window_size,
                    n_heads=n_heads,
                    n_kv_heads=n_kv_heads,
                    dropout=dropout,
                    ffn_mult=ffn_mult,
                    eps=eps,
                    rope_theta=rope_theta,
                )
                for _ in range(depth)
            ]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Downsample (if configured), then run the stage's AttnRes blocks."""
        if self.downsample is not None:
            x = self.downsample(x)
        for block in self.blocks:
            x = block(x)
        return x


class AttnResResNetLikeBackbone2D(nn.Module):
    """ResNet-shaped vision trunk whose residual units are all AttnRes blocks.

    Keeps the ResNet-18 macro layout — a 7x7/s2 stem + 3x3/s2 max-pool, then
    four stages with strides (1, 2, 2, 2) — so downstream SpatialSoftmax
    conditioning sees the same output contract as the original trunk. For a
    224x224 input the final feature map is (B, stage_dims[-1], 7, 7); for the
    56x56 test input it is (B, stage_dims[-1], 2, 2).

    Args:
        input_channels: channels of the incoming image tensor.
        stem_dim: output channels of the conv stem.
        stage_dims / stage_depths / stage_heads / stage_kv_heads /
        stage_window_sizes: per-stage configuration; all must share one length.
        use_group_norm: GroupNorm (rollout/small-batch friendly) vs BatchNorm.
        dropout, ffn_mult, eps, rope_theta: forwarded to every AttnRes block.
    """

    def __init__(
        self,
        *,
        input_channels: int = 3,
        stem_dim: int = 64,
        stage_dims: Sequence[int] = (64, 128, 256, 512),
        stage_depths: Sequence[int] = (2, 2, 2, 2),
        stage_heads: Sequence[int] = (4, 4, 8, 8),
        stage_kv_heads: Sequence[int] = (1, 1, 1, 1),
        stage_window_sizes: Sequence[int] = (7, 7, 7, 7),
        use_group_norm: bool = True,
        dropout: float = 0.0,
        ffn_mult: float = 2.667,
        eps: float = 1e-6,
        rope_theta: float = 10000.0,
    ) -> None:
        super().__init__()

        # All per-stage sequences must be parallel; a single mismatched list
        # in the Hydra config should fail loudly at construction time.
        lengths = {
            len(stage_dims),
            len(stage_depths),
            len(stage_heads),
            len(stage_kv_heads),
            len(stage_window_sizes),
        }
        if len(lengths) != 1:
            raise ValueError('stage_dims/depths/heads/kv_heads/window_sizes 长度必须一致')

        # ResNet-style stem: overall 4x spatial reduction before stage 0.
        self.stem = nn.Sequential(
            nn.Conv2d(input_channels, stem_dim, kernel_size=7, stride=2, padding=3, bias=False),
            _make_norm2d(stem_dim, use_group_norm),
            nn.SiLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )

        in_channels = stem_dim
        stages = []
        for stage_idx, (out_channels, depth, n_heads, n_kv_heads, window_size) in enumerate(
            zip(stage_dims, stage_depths, stage_heads, stage_kv_heads, stage_window_sizes)
        ):
            stages.append(
                _AttnResStage2D(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    depth=int(depth),
                    # Stage 0 keeps resolution (as in ResNet layer1); every
                    # later stage halves it.
                    downsample_stride=1 if stage_idx == 0 else 2,
                    use_group_norm=use_group_norm,
                    window_size=int(window_size),
                    n_heads=int(n_heads),
                    n_kv_heads=int(n_kv_heads),
                    dropout=float(dropout),
                    ffn_mult=float(ffn_mult),
                    eps=float(eps),
                    rope_theta=float(rope_theta),
                )
            )
            in_channels = out_channels

        self.stages = nn.ModuleList(stages)
        # Exposed so the encoder shell can size SpatialSoftmax / projections.
        self.output_channels = in_channels

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run stem then all stages on a (B, C, H, W) image batch."""
        x = self.stem(x)
        for stage in self.stages:
            x = stage(x)
        return x
_replace_submodules( - root_module=self.backbone, - predicate=lambda x: isinstance(x, nn.BatchNorm2d), - func=lambda x: nn.GroupNorm(num_groups=x.num_features // 16, num_channels=x.num_features), + if vision_backbone_mode == "resnet": + # 设置骨干网络 + backbone_model = getattr(torchvision.models, vision_backbone)( + weights=pretrained_backbone_weights ) + # 移除 AvgPool 和 FC (假设 layer4 是 children()[-3]) + self.backbone = nn.Sequential(*(list(backbone_model.children())[:-2])) + + if use_group_norm: + self.backbone = _replace_submodules( + root_module=self.backbone, + predicate=lambda x: isinstance(x, nn.BatchNorm2d), + func=lambda x: nn.GroupNorm( + num_groups=max(1, x.num_features // 16), + num_channels=x.num_features, + ), + ) + elif vision_backbone_mode == "attnres_resnet": + self.backbone = AttnResResNetLikeBackbone2D( + input_channels=input_shape[0], + stem_dim=attnres_stem_dim, + stage_dims=tuple(attnres_stage_dims or (64, 128, 256, 512)), + stage_depths=tuple(attnres_stage_depths or (2, 2, 2, 2)), + stage_heads=tuple(attnres_stage_heads or (4, 4, 8, 8)), + stage_kv_heads=tuple(attnres_stage_kv_heads or (1, 1, 1, 1)), + stage_window_sizes=tuple(attnres_stage_window_sizes or (7, 7, 7, 7)), + use_group_norm=use_group_norm, + dropout=attnres_dropout, + ffn_mult=attnres_ffn_mult, + eps=attnres_eps, + rope_theta=attnres_rope_theta, + ) + else: + raise ValueError(f"不支持的 vision_backbone_mode: {vision_backbone_mode}") + # 冻结backbone参数(可选) if freeze_backbone: for param in self.backbone.parameters(): @@ -180,6 +214,17 @@ class ResNetDiffusionBackbone(VLABackbone): num_cameras: int = 1, # 新增:摄像头数量(仅在独立编码器模式下使用) camera_names: Optional[Tuple[str, ...]] = None, # 显式相机顺序 freeze_backbone: bool = True, # 新增:是否冻结ResNet backbone(推荐True) + vision_backbone_mode: str = "resnet", + attnres_stem_dim: int = 64, + attnres_stage_dims: Optional[Tuple[int, ...]] = None, + attnres_stage_depths: Optional[Tuple[int, ...]] = None, + attnres_stage_heads: Optional[Tuple[int, ...]] = None, + 
attnres_stage_kv_heads: Optional[Tuple[int, ...]] = None, + attnres_stage_window_sizes: Optional[Tuple[int, ...]] = None, + attnres_dropout: float = 0.0, + attnres_ffn_mult: float = 2.667, + attnres_eps: float = 1e-6, + attnres_rope_theta: float = 10000.0, ): super().__init__() @@ -203,6 +248,17 @@ class ResNetDiffusionBackbone(VLABackbone): use_group_norm=use_group_norm, spatial_softmax_num_keypoints=spatial_softmax_num_keypoints, freeze_backbone=freeze_backbone, + vision_backbone_mode=vision_backbone_mode, + attnres_stem_dim=attnres_stem_dim, + attnres_stage_dims=attnres_stage_dims, + attnres_stage_depths=attnres_stage_depths, + attnres_stage_heads=attnres_stage_heads, + attnres_stage_kv_heads=attnres_stage_kv_heads, + attnres_stage_window_sizes=attnres_stage_window_sizes, + attnres_dropout=attnres_dropout, + attnres_ffn_mult=attnres_ffn_mult, + attnres_eps=attnres_eps, + attnres_rope_theta=attnres_rope_theta, ) for _ in range(num_cameras) ] @@ -220,6 +276,17 @@ class ResNetDiffusionBackbone(VLABackbone): use_group_norm=use_group_norm, spatial_softmax_num_keypoints=spatial_softmax_num_keypoints, freeze_backbone=freeze_backbone, + vision_backbone_mode=vision_backbone_mode, + attnres_stem_dim=attnres_stem_dim, + attnres_stage_dims=attnres_stage_dims, + attnres_stage_depths=attnres_stage_depths, + attnres_stage_heads=attnres_stage_heads, + attnres_stage_kv_heads=attnres_stage_kv_heads, + attnres_stage_window_sizes=attnres_stage_window_sizes, + attnres_dropout=attnres_dropout, + attnres_ffn_mult=attnres_ffn_mult, + attnres_eps=attnres_eps, + attnres_rope_theta=attnres_rope_theta, ) self.feature_dim = self.rgb_encoder.feature_dim diff --git a/tests/test_attnres_resnet2d_backbone.py b/tests/test_attnres_resnet2d_backbone.py new file mode 100644 index 0000000..3bb1829 --- /dev/null +++ b/tests/test_attnres_resnet2d_backbone.py @@ -0,0 +1,26 @@ +import unittest + +import torch + + +class AttnResResNet2DBackboneTest(unittest.TestCase): + def 
test_backbone_preserves_resnet_like_stage_contract(self): + from roboimi.vla.models.backbones.attnres_resnet2d import AttnResResNetLikeBackbone2D + + backbone = AttnResResNetLikeBackbone2D( + input_channels=3, + stem_dim=16, + stage_dims=(16, 32, 64, 128), + stage_depths=(1, 1, 1, 1), + stage_heads=(2, 4, 4, 8), + stage_kv_heads=(1, 1, 1, 1), + stage_window_sizes=(7, 7, 7, 7), + dropout=0.0, + ) + x = torch.randn(2, 3, 56, 56) + y = backbone(x) + self.assertEqual(y.shape, (2, 128, 2, 2)) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_imf_vla_agent.py b/tests/test_imf_vla_agent.py index 0050c9c..1d3cf79 100644 --- a/tests/test_imf_vla_agent.py +++ b/tests/test_imf_vla_agent.py @@ -422,6 +422,32 @@ class IMFVLAAgentTest(unittest.TestCase): self.assertEqual(agent.noise_pred_net.constructor_kwargs['cond_dim'], agent.per_step_cond_dim) self.assertEqual(agent.noise_pred_net.constructor_kwargs['backbone_type'], 'attnres_full') + def test_hydra_config_instantiates_resnet_imf_attnres_with_full_attnres_vision_backbone(self): + cfg = _compose_cfg( + overrides=[ + 'agent=resnet_imf_attnres', + 'agent.vision_backbone.vision_backbone_mode=attnres_resnet', + 'agent.vision_backbone.pretrained_backbone_weights=null', + 'agent.vision_backbone.input_shape=[3,56,56]', + 'agent.vision_backbone.freeze_backbone=false', + 'agent.vision_backbone.attnres_stem_dim=16', + 'agent.vision_backbone.attnres_stage_dims=[16,32,64,128]', + 'agent.vision_backbone.attnres_stage_depths=[1,1,1,1]', + 'agent.vision_backbone.attnres_stage_heads=[2,4,4,8]', + 'agent.vision_backbone.attnres_stage_kv_heads=[1,1,1,1]', + 'agent.vision_backbone.attnres_stage_window_sizes=[7,7,7,7]', + 'agent.head.n_layer=1', + 'agent.head.n_emb=16', + ] + ) + + with _stub_optional_modules(include_imf_head=True): + agent = instantiate(cfg.agent) + + self.assertEqual(agent.vision_encoder.output_dim, 64) + self.assertEqual(agent.per_step_cond_dim, 64 * agent.num_cams + agent.obs_dim) + 
self.assertEqual(agent.noise_pred_net.constructor_kwargs['cond_dim'], agent.per_step_cond_dim) + if __name__ == '__main__': unittest.main()