feat: add full attnres vision backbone

@@ -0,0 +1,64 @@
# Phase-2 Full-AttnRes Vision Implementation Plan

> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking.

**Goal:** Replace all ResNet residual units in the vision backbone with AttnRes-based image blocks while preserving the current IMF agent interfaces, and launch a Phase-2 experiment anchored on the best Phase-1 horizon setting.

**Architecture:** Keep the current multi-camera encoder shell and per-camera output contract, but introduce a new ResNet-like 2D AttnRes backbone that preserves stage-wise downsampling and final SpatialSoftmax conditioning. Wire it into the existing `ResNetDiffusionBackbone` via an opt-in mode and keep the agent/head/data interfaces unchanged.

**Tech Stack:** PyTorch, Hydra/OmegaConf, existing IMF AttnRes transformer components, pytest.

---

### Task 1: Add failing tests for the new full-AttnRes visual backbone

**Files:**

- Create: `tests/test_attnres_resnet2d_backbone.py`
- Update: `tests/test_imf_vla_agent.py`

- [ ] **Step 1: Write a failing backbone shape test**
- [ ] **Step 2: Run it to confirm the new backbone/config does not exist yet**
- [ ] **Step 3: Add a failing IMF agent wiring test for unchanged `cond_dim=208`**
- [ ] **Step 4: Run the targeted tests and capture the failure**

### Task 2: Implement a ResNet-like 2D AttnRes backbone

**Files:**

- Create: `roboimi/vla/models/backbones/attnres_resnet2d.py`
- Modify: `roboimi/vla/models/backbones/resnet_diffusion.py`

- [ ] **Step 1: Add minimal 2D tokenization helpers and positional encoding / bias handling**
- [ ] **Step 2: Implement `AttnResImageBlock2D` for feature maps**
- [ ] **Step 3: Implement `AttnResResNetLikeBackbone2D` with stage-wise downsampling**
- [ ] **Step 4: Wire `_SingleRgbEncoder` to choose between the original ResNet trunk and the new full-AttnRes trunk**
- [ ] **Step 5: Run the new backbone tests**

### Task 3: Expose config switches and agent wiring

**Files:**

- Modify: `roboimi/vla/conf/backbone/resnet_diffusion.yaml`
- Modify: `roboimi/vla/conf/agent/resnet_imf_attnres.yaml`

- [ ] **Step 1: Add a backbone mode/config flag for the full-AttnRes vision trunk**
- [ ] **Step 2: Add defaults for AttnRes image depth/heads/etc. if needed**
- [ ] **Step 3: Add a Phase-2 launch override path that enables the new visual trunk**
- [ ] **Step 4: Run the agent wiring tests again**

### Task 4: Smoke-verify the training path

**Files:**

- Reuse existing training scripts and configs

- [ ] **Step 1: Run a short CPU or tiny-step smoke instantiation / `compute_loss` test**
- [ ] **Step 2: If needed, run a very short training smoke launch**
- [ ] **Step 3: Verify no cond-dim or rollout-loading regressions**

### Task 5: Launch the Phase-2 experiment

**Files:**

- Update experiment tracking under `experiment_suites/`

- [ ] **Step 1: Use the Phase-1 best setting (`pred_horizon=16`, `num_action_steps=8`)**
- [ ] **Step 2: Launch a baseline reference or reuse the existing result**
- [ ] **Step 3: Launch the full-AttnRes vision experiment**
- [ ] **Step 4: Track rollout metrics and compare max `avg_reward`**
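Task 1, Step 3 pins the agent wiring to an unchanged `cond_dim=208`. The per-camera output of 64 and the 3-camera setup are stated in this commit; treating the state width as the remainder (16) is an inference for illustration, not a documented value:

```python
# Hedged sanity sketch of the conditioning-dimension arithmetic.
num_cameras = 3        # num_cameras: 3 in the backbone config
per_camera_dim = 64    # per-camera feature size after SpatialSoftmax -> Linear
cond_dim = 208         # cond_dim that must stay unchanged in Phase 2

vision_dim = num_cameras * per_camera_dim  # concatenated multi-camera features
state_dim = cond_dim - vision_dim          # inferred robot-state width (assumption)
print(vision_dim, state_dim)  # -> 192 16
```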
@@ -0,0 +1,81 @@
# Phase-2 Full-AttnRes Vision Design

## Goal

In the current roboimi IMF policy, replace every residual unit in the vision backbone that was previously provided by ResNet BasicBlock/Bottleneck with an AttnRes-style unit, while keeping the existing agent / cond / rollout / training-script interfaces unchanged as far as possible.

## User requirement interpretation

We follow the strictest reading:

- Not "append an AttnRes module after the ResNet"
- Not "mix AttnRes into only a few stages"
- Instead: everywhere the visual trunk relied on a ResNet residual block, use a block driven by an AttnRes residual operator
- The trunk must still emit the same per-camera feature interface as the existing `ResNetDiffusionBackbone`, so that `SpatialSoftmax -> Linear -> ReLU`, multi-camera concatenation, state concat, and the IMF head's conditioning input can all be reused

## Recommended design

### Option A (recommended)

Keep ResNet's macro stem/stage structure and its channel/stride plan, but replace each BasicBlock/Bottleneck inside a stage with a new `AttnResImageBlock2D`:

- The input is still a `(B, C, H, W)` feature map
- Inside the block, flatten the spatial dims into a token sequence `(B, H*W, C)`
- Apply 2D positional encoding / learnable positional bias + AttnRes self-attention + AttnRes FFN
- Reshape back to `(B, C, H, W)`
- Downsampling between stages is still done by an explicit stride/downsample path

Pros:

- Closest to the requirement that "all residuals in the ResNet are replaced by AttnRes"
- Preserves the current visual output interface and cond_dim, so the agent/head/data pipeline needs no changes
- The existing multi-camera encoder framework can still be reused

Cons:

- Requires writing a new 2D AttnRes image block rather than directly reusing the 1D IMF head block

### Option B

Remove the ResNet stages entirely and switch to patchify + a ViT/AttnRes image transformer, followed by SpatialSoftmax/MLP.

Pros: conceptually more uniform.

Cons: this no longer "replaces the residuals inside the ResNet" but swaps the whole backbone, which does not fully match the user's requirement.

### Option C

Keep the existing ResNet blocks and only add AttnRes mixing around them.

Not recommended, because it does not satisfy "all residuals replaced by AttnRes".

## Concrete architecture choice

Adopt Option A:

1. Keep the stem (conv/bn-or-gn/relu/maxpool) and the stage boundaries
2. Add `AttnResImageBlock2D`
3. Add `AttnResResNetLikeBackbone2D`, which stacks the stages/blocks
4. Add an optional backbone mode to `ResNetDiffusionBackbone`, e.g.:
   - `vision_backbone_mode: resnet`
   - `vision_backbone_mode: attnres_resnet`
5. Add a Phase-2 variant of the `resnet_imf_attnres` agent config that enables `attnres_resnet` by default
6. Still guarantee:
   - per-camera output of `64`
   - total multi-camera visual output of `3 * 64`
   - `cond_dim = 208` after concatenating with state

## Files likely to change

- `roboimi/vla/models/backbones/resnet_diffusion.py`
- `roboimi/vla/conf/backbone/resnet_diffusion.yaml`
- `roboimi/vla/conf/agent/resnet_imf_attnres.yaml`
- new: `roboimi/vla/models/backbones/attnres_resnet2d.py`
- tests:
  - new: `tests/test_attnres_resnet2d_backbone.py`
  - update/add a wiring test for the agent cond dims

## Test plan

1. The new backbone instantiates and forwards `(B,T,C,H,W)` multi-camera input
2. Output shape is unchanged vs. the current backbone
3. `output_dim == 64`
4. The 3-camera cond path still yields `208`
5. The Phase-2 config instantiates the full IMF agent successfully
6. One short CPU smoke forward for `compute_loss`

## Phase-2 experiment plan

Fix the best Phase-1 combination:

- `pred_horizon=16`
- `num_action_steps=8`

Compare:

1. baseline: current IMF head-only AttnRes + original ResNet vision backbone
2. phase2: IMF head AttnRes + full AttnRes-replaced vision backbone

Keep the training hyperparameters identical to the best Phase-1 setting and start with one 50k-step comparison.
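Option A's windowed attention pads each feature map up to a multiple of the window size before partitioning it into token windows, which is how the `AttnResImageBlock2D` in this commit handles non-divisible sizes. A minimal stdlib sketch of that layout arithmetic with the default `window_size=7`; the 24x24 example size is illustrative:

```python
def window_layout(height: int, width: int, ws: int = 7):
    """Padded size, number of windows, and tokens per window for window attention."""
    pad_h = (ws - height % ws) % ws        # pad H up to a multiple of ws
    pad_w = (ws - width % ws) % ws         # pad W up to a multiple of ws
    padded_h, padded_w = height + pad_h, width + pad_w
    num_windows = (padded_h // ws) * (padded_w // ws)
    return padded_h, padded_w, num_windows, ws * ws

# e.g. a 24x24 stage-1 feature map
print(window_layout(24, 24))  # -> (28, 28, 16, 49)
```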
@@ -1,7 +1,7 @@
-rank,run_id,status,pred_horizon,num_action_steps,best_rollout_avg_reward,best_step,final_step,final_loss,host,run_dir
+rank,run_id,status,pred_horizon,num_action_steps,best_rollout_avg_reward,best_step,final_step,final_loss,host,run_dir,latest_step
-1,ph16_ex8,finished,16,8,610.8,21874,50000,0.0034315965604037046,100.73.14.65,/home/droid/roboimi_suite_20260404/runs/imf-p1-ph16-ex08-emb384-l12-ms50k-5880g1-20260404-131223
+1,ph16_ex8,running,16,8,610.8,21874,50000,0.0034315965604037046,100.73.14.65,/home/droid/roboimi_suite_20260404/runs/imf-p1-ph16-ex08-emb384-l12-ms50k-5880g1-20260404-131223,50000
-2,ph16_ex16,finished,16,16,561.2,48124,50000,0.004544622730463743,100.119.99.14,/home/droid/roboimi_suite_20260404/runs/imf-p1-ph16-ex16-emb384-l12-ms50k-l20g0-20260404-131223
+2,ph16_ex16,running,16,16,561.2,48124,50000,0.004544622730463743,100.119.99.14,/home/droid/roboimi_suite_20260404/runs/imf-p1-ph16-ex16-emb384-l12-ms50k-l20g0-20260404-131223,50000
-3,ph32_ex32,finished,32,32,513.2,43749,50000,0.003953303210437298,local,/home/droid/project/roboimi/.worktrees/feat-imf-attnres-policy/runs/imf-p1-ph32-ex32-emb384-l12-ms50k-5090-20260404-131223
+3,ph32_ex32,finished,32,32,513.2,43749,50000,0.003953303210437298,local,/home/droid/project/roboimi/.worktrees/feat-imf-attnres-policy/runs/imf-p1-ph32-ex32-emb384-l12-ms50k-5090-20260404-131223,49900
-4,ph8_ex8,finished,8,8,415.6,48124,50000,0.007008877582848072,100.73.14.65,/home/droid/roboimi_suite_20260404/runs/imf-p1-ph08-ex08-emb384-l12-ms50k-5880g0-20260404-131223
+4,ph8_ex8,running,8,8,415.6,48124,50000,0.007008877582848072,100.73.14.65,/home/droid/roboimi_suite_20260404/runs/imf-p1-ph08-ex08-emb384-l12-ms50k-5880g0-20260404-131223,50000
-5,ph32_ex8,finished,32,8,361.6,43749,50000,0.004788532387465239,100.119.99.14,/home/droid/roboimi_suite_20260404/runs/imf-p1-ph32-ex08-emb384-l12-ms50k-l20g1-20260404-131223
+5,ph32_ex8,running,32,8,361.6,43749,50000,0.004788532387465239,100.119.99.14,/home/droid/roboimi_suite_20260404/runs/imf-p1-ph32-ex08-emb384-l12-ms50k-l20g1-20260404-131223,50000
-6,ph32_ex16,finished,32,16,239.6,48124,50000,0.0038348555099219084,100.119.99.14,/home/droid/roboimi_suite_20260404/runs/imf-p1-ph32-ex16-emb384-l12-ms50k-l20g2-20260404-131223
+6,ph32_ex16,running,32,16,239.6,48124,50000,0.0038348555099219084,100.119.99.14,/home/droid/roboimi_suite_20260404/runs/imf-p1-ph32-ex16-emb384-l12-ms50k-l20g2-20260404-131223,50000
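Task 5 of the plan compares runs by max `avg_reward`. Given the leaderboard CSV above, picking the best run is a short stdlib exercise; the inline data below is abbreviated to the two columns used:

```python
import csv
import io

# Columns trimmed to run_id,best_rollout_avg_reward from the leaderboard above.
data = """run_id,best_rollout_avg_reward
ph16_ex8,610.8
ph16_ex16,561.2
ph32_ex32,513.2
ph8_ex8,415.6
ph32_ex8,361.6
ph32_ex16,239.6
"""
rows = list(csv.DictReader(io.StringIO(data)))
best = max(rows, key=lambda r: float(r["best_rollout_avg_reward"]))
print(best["run_id"], best["best_rollout_avg_reward"])  # -> ph16_ex8 610.8
```

This is consistent with the plan anchoring Phase 2 on `pred_horizon=16`, `num_action_steps=8`.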
@@ -1,6 +1,6 @@
 {
   "suite_name": "2026-04-04-imf-horizon-grid",
-  "updated_at": "2026-04-04 23:46:01",
+  "updated_at": "2026-04-05 00:07:39",
   "phase": "phase1_completed",
   "provisioning": {
     "100.119.99.14": {
@@ -17,7 +17,7 @@
     },
   "runs": {
     "ph8_ex8": {
-      "status": "finished",
+      "status": "running",
       "host": "100.73.14.65",
       "gpu": 0,
       "run_name": "imf-p1-ph08-ex08-emb384-l12-ms50k-5880g0-20260404-131223",
@@ -30,15 +30,15 @@
       "pid": 938714,
       "launch_log": "experiment_suite_launch_logs/imf-p1-ph08-ex08-emb384-l12-ms50k-5880g0-20260404-131223.restartfix-20260404-143827.log",
       "latest_step": 50000,
-      "latest_log_sync": "2026-04-04 23:42:34",
+      "latest_log_sync": "2026-04-05 00:07:39",
       "swanlab_url": "https://swanlab.cn/@game-loader/roboimi-vla/runs/i5syc57b6zq7rbkrtqy7b",
-      "process_running": false,
+      "process_running": true,
       "best_step": 48124,
       "best_rollout_avg_reward": 415.6,
       "final_loss": 0.007008877582848072
     },
     "ph16_ex8": {
-      "status": "finished",
+      "status": "running",
       "host": "100.73.14.65",
       "gpu": 1,
       "run_name": "imf-p1-ph16-ex08-emb384-l12-ms50k-5880g1-20260404-131223",
@@ -51,15 +51,15 @@
       "pid": 938717,
       "launch_log": "experiment_suite_launch_logs/imf-p1-ph16-ex08-emb384-l12-ms50k-5880g1-20260404-131223.restartfix-20260404-143827.log",
       "latest_step": 50000,
-      "latest_log_sync": "2026-04-04 23:42:34",
+      "latest_log_sync": "2026-04-05 00:07:39",
       "swanlab_url": "https://swanlab.cn/@game-loader/roboimi-vla/runs/4rusbrpfxmw4ffii1ul5w",
-      "process_running": false,
+      "process_running": true,
       "best_step": 21874,
       "best_rollout_avg_reward": 610.8,
       "final_loss": 0.0034315965604037046
     },
     "ph16_ex16": {
-      "status": "finished",
+      "status": "running",
       "host": "100.119.99.14",
       "gpu": 0,
       "run_name": "imf-p1-ph16-ex16-emb384-l12-ms50k-l20g0-20260404-131223",
@@ -71,16 +71,16 @@
       "num_action_steps": 16,
       "pid": 90169,
       "launch_log": "experiment_suite_launch_logs/imf-p1-ph16-ex16-emb384-l12-ms50k-l20g0-20260404-131223.restartfix-20260404-143827.log",
-      "latest_log_sync": "2026-04-04 23:42:34",
+      "latest_log_sync": "2026-04-05 00:07:39",
       "latest_step": 50000,
       "swanlab_url": "https://swanlab.cn/@game-loader/roboimi-vla/runs/wwm232k6190gexnze8mg6",
-      "process_running": false,
+      "process_running": true,
       "best_step": 48124,
       "best_rollout_avg_reward": 561.2,
       "final_loss": 0.004544622730463743
     },
     "ph32_ex8": {
-      "status": "finished",
+      "status": "running",
       "host": "100.119.99.14",
       "gpu": 1,
       "run_name": "imf-p1-ph32-ex08-emb384-l12-ms50k-l20g1-20260404-131223",
@@ -92,16 +92,16 @@
       "num_action_steps": 8,
       "pid": 90173,
       "launch_log": "experiment_suite_launch_logs/imf-p1-ph32-ex08-emb384-l12-ms50k-l20g1-20260404-131223.restartfix-20260404-143827.log",
-      "latest_log_sync": "2026-04-04 23:42:34",
+      "latest_log_sync": "2026-04-05 00:07:39",
       "latest_step": 50000,
       "swanlab_url": "https://swanlab.cn/@game-loader/roboimi-vla/runs/o5y2xjb2rsb3lmfcuhy4p",
-      "process_running": false,
+      "process_running": true,
       "best_step": 43749,
       "best_rollout_avg_reward": 361.6,
       "final_loss": 0.004788532387465239
     },
     "ph32_ex16": {
-      "status": "finished",
+      "status": "running",
       "host": "100.119.99.14",
       "gpu": 2,
       "run_name": "imf-p1-ph32-ex16-emb384-l12-ms50k-l20g2-20260404-131223",
@@ -113,10 +113,10 @@
       "num_action_steps": 16,
       "pid": 90175,
       "launch_log": "experiment_suite_launch_logs/imf-p1-ph32-ex16-emb384-l12-ms50k-l20g2-20260404-131223.restartfix-20260404-143827.log",
-      "latest_log_sync": "2026-04-04 23:42:34",
+      "latest_log_sync": "2026-04-05 00:07:39",
       "latest_step": 50000,
       "swanlab_url": "https://swanlab.cn/@game-loader/roboimi-vla/runs/54cjpgba9eqsopdm0l8d3",
-      "process_running": false,
+      "process_running": true,
       "best_step": 48124,
       "best_rollout_avg_reward": 239.6,
       "final_loss": 0.0038348555099219084
@@ -134,8 +134,8 @@
       "num_action_steps": 32,
       "pid": 1437836,
       "launch_log": "experiment_suites/2026-04-04-imf-horizon-grid/launch_logs/imf-p1-ph32-ex32-emb384-l12-ms50k-5090-20260404-131223.launch.log",
-      "latest_step": 50000,
+      "latest_step": 49900,
-      "latest_log_sync": "2026-04-04 23:42:34",
+      "latest_log_sync": "2026-04-05 00:07:39",
       "swanlab_url": "https://swanlab.cn/@game-loader/roboimi-vla/runs/ajs2m218jd260hawhy5ns",
       "process_running": false,
       "latest_rollout_avg_reward": 513.2,
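The suite-state JSON above is what the experiment tracker polls. Filtering it for still-running processes is a one-liner over `runs`; the inline structure below is abbreviated to the keys shown in the diff:

```python
import json

# Abbreviated suite state mirroring keys from the diff above.
state = json.loads("""
{
  "suite_name": "2026-04-04-imf-horizon-grid",
  "runs": {
    "ph16_ex8": {"status": "running", "process_running": true, "latest_step": 50000},
    "ph32_ex32": {"status": "finished", "process_running": false, "latest_step": 49900}
  }
}
""")
running = [name for name, run in state["runs"].items() if run["process_running"]]
print(running)  # -> ['ph16_ex8']
```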
@@ -5,6 +5,7 @@ _target_: roboimi.vla.models.backbones.resnet_diffusion.ResNetDiffusionBackbone
 # ====================
 vision_backbone: "resnet18" # torchvision model name: resnet18, resnet34, resnet50
 pretrained_backbone_weights: "IMAGENET1K_V1" # use ImageNet pretrained weights (torchvision>=0.13)
+vision_backbone_mode: "resnet" # resnet | attnres_resnet

 # ====================
 # freeze settings
@@ -31,3 +32,17 @@ spatial_softmax_num_keypoints: 32 # number of SpatialSoftmax keypoints
 # true: separate encoders (each camera gets its own ResNet; more parameters, more capacity)
 use_separate_rgb_encoder_per_camera: true
 num_cameras: 3 # number of cameras
+
+# ====================
+# Full-AttnRes vision trunk (effective when vision_backbone_mode=attnres_resnet)
+# ====================
+attnres_stem_dim: 64
+attnres_stage_dims: [64, 128, 256, 512]
+attnres_stage_depths: [2, 2, 2, 2]
+attnres_stage_heads: [4, 4, 8, 8]
+attnres_stage_kv_heads: [1, 1, 1, 1]
+attnres_stage_window_sizes: [7, 7, 7, 7]
+attnres_dropout: 0.0
+attnres_ffn_mult: 2.667
+attnres_eps: 1.0e-06
+attnres_rope_theta: 10000.0
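The backbone validates that the per-stage lists agree in length (it raises a ValueError otherwise). A small sketch of the same check against the YAML defaults above:

```python
# Per-stage lists copied from the YAML defaults above.
cfg = {
    "attnres_stage_dims": [64, 128, 256, 512],
    "attnres_stage_depths": [2, 2, 2, 2],
    "attnres_stage_heads": [4, 4, 8, 8],
    "attnres_stage_kv_heads": [1, 1, 1, 1],
    "attnres_stage_window_sizes": [7, 7, 7, 7],
}
# Mirrors the backbone's check: every per-stage list must have the same length.
lengths = {len(v) for v in cfg.values()}
assert len(lengths) == 1, "stage lists must have equal length"
num_stages = lengths.pop()
print(num_stages)  # -> 4
```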
roboimi/vla/models/backbones/attnres_resnet2d.py (new file, 228 lines)
@@ -0,0 +1,228 @@
from __future__ import annotations

from typing import Sequence

import torch
import torch.nn as nn
import torch.nn.functional as F

from roboimi.vla.models.heads.attnres_transformer_components import AttnResTransformerBackbone


def _make_norm2d(num_channels: int, use_group_norm: bool) -> nn.Module:
    if use_group_norm:
        num_groups = max(1, num_channels // 16)
        return nn.GroupNorm(num_groups=num_groups, num_channels=num_channels)
    return nn.BatchNorm2d(num_channels)


class _ConvNormAct(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        *,
        kernel_size: int,
        stride: int,
        padding: int,
        use_group_norm: bool,
    ) -> None:
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                bias=False,
            ),
            _make_norm2d(out_channels, use_group_norm),
            nn.SiLU(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.block(x)


class AttnResImageBlock2D(nn.Module):
    def __init__(
        self,
        dim: int,
        *,
        window_size: int,
        n_heads: int,
        n_kv_heads: int,
        dropout: float,
        ffn_mult: float,
        eps: float,
        rope_theta: float,
    ) -> None:
        super().__init__()
        self.window_size = int(window_size)
        self.block = AttnResTransformerBackbone(
            d_model=dim,
            n_blocks=1,
            n_heads=n_heads,
            n_kv_heads=n_kv_heads,
            max_seq_len=self.window_size * self.window_size,
            dropout=dropout,
            ffn_mult=ffn_mult,
            eps=eps,
            rope_theta=rope_theta,
            causal_attn=False,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        bsz, channels, height, width = x.shape
        ws = self.window_size
        # Pad H/W up to multiples of the window size so windows tile exactly.
        pad_h = (ws - height % ws) % ws
        pad_w = (ws - width % ws) % ws
        if pad_h or pad_w:
            x = F.pad(x, (0, pad_w, 0, pad_h))
        padded_height, padded_width = x.shape[-2:]
        num_h = padded_height // ws
        num_w = padded_width // ws

        # (B, C, H, W) -> (B * num_windows, ws * ws, C) token windows.
        windows = (
            x.permute(0, 2, 3, 1)
            .contiguous()
            .view(bsz, num_h, ws, num_w, ws, channels)
            .permute(0, 1, 3, 2, 4, 5)
            .contiguous()
            .view(bsz * num_h * num_w, ws * ws, channels)
        )
        windows = self.block(windows)
        # Reverse the window partition and drop any padding.
        x = (
            windows.view(bsz, num_h, num_w, ws, ws, channels)
            .permute(0, 1, 3, 2, 4, 5)
            .contiguous()
            .view(bsz, padded_height, padded_width, channels)
            .permute(0, 3, 1, 2)
            .contiguous()
        )
        return x[:, :, :height, :width]


class _AttnResStage2D(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        *,
        depth: int,
        downsample_stride: int,
        use_group_norm: bool,
        window_size: int,
        n_heads: int,
        n_kv_heads: int,
        dropout: float,
        ffn_mult: float,
        eps: float,
        rope_theta: float,
    ) -> None:
        super().__init__()
        self.downsample = None
        if in_channels != out_channels or downsample_stride != 1:
            # 1x1 projection when only channels change; strided 3x3 when downsampling.
            kernel_size = 1 if downsample_stride == 1 else 3
            padding = 0 if downsample_stride == 1 else 1
            self.downsample = _ConvNormAct(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                stride=downsample_stride,
                padding=padding,
                use_group_norm=use_group_norm,
            )

        self.blocks = nn.ModuleList(
            [
                AttnResImageBlock2D(
                    out_channels,
                    window_size=window_size,
                    n_heads=n_heads,
                    n_kv_heads=n_kv_heads,
                    dropout=dropout,
                    ffn_mult=ffn_mult,
                    eps=eps,
                    rope_theta=rope_theta,
                )
                for _ in range(depth)
            ]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.downsample is not None:
            x = self.downsample(x)
        for block in self.blocks:
            x = block(x)
        return x


class AttnResResNetLikeBackbone2D(nn.Module):
    def __init__(
        self,
        *,
        input_channels: int = 3,
        stem_dim: int = 64,
        stage_dims: Sequence[int] = (64, 128, 256, 512),
        stage_depths: Sequence[int] = (2, 2, 2, 2),
        stage_heads: Sequence[int] = (4, 4, 8, 8),
        stage_kv_heads: Sequence[int] = (1, 1, 1, 1),
        stage_window_sizes: Sequence[int] = (7, 7, 7, 7),
        use_group_norm: bool = True,
        dropout: float = 0.0,
        ffn_mult: float = 2.667,
        eps: float = 1e-6,
        rope_theta: float = 10000.0,
    ) -> None:
        super().__init__()

        lengths = {
            len(stage_dims),
            len(stage_depths),
            len(stage_heads),
            len(stage_kv_heads),
            len(stage_window_sizes),
        }
        if len(lengths) != 1:
            raise ValueError("stage_dims/depths/heads/kv_heads/window_sizes must all have the same length")

        self.stem = nn.Sequential(
            nn.Conv2d(input_channels, stem_dim, kernel_size=7, stride=2, padding=3, bias=False),
            _make_norm2d(stem_dim, use_group_norm),
            nn.SiLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )

        in_channels = stem_dim
        stages = []
        for stage_idx, (out_channels, depth, n_heads, n_kv_heads, window_size) in enumerate(
            zip(stage_dims, stage_depths, stage_heads, stage_kv_heads, stage_window_sizes)
        ):
            stages.append(
                _AttnResStage2D(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    depth=int(depth),
                    downsample_stride=1 if stage_idx == 0 else 2,
                    use_group_norm=use_group_norm,
                    window_size=int(window_size),
                    n_heads=int(n_heads),
                    n_kv_heads=int(n_kv_heads),
                    dropout=float(dropout),
                    ffn_mult=float(ffn_mult),
                    eps=float(eps),
                    rope_theta=float(rope_theta),
                )
            )
            in_channels = out_channels

        self.stages = nn.ModuleList(stages)
        self.output_channels = in_channels

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.stem(x)
        for stage in self.stages:
            x = stage(x)
        return x
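Given the stem above (stride-2 conv, then stride-2 max-pool) and the per-stage downsample strides (1, 2, 2, 2), the output spatial size can be traced without instantiating the module. A stdlib sketch; the 96x96 input size is an illustrative assumption:

```python
def conv_out(size: int, kernel: int, stride: int, padding: int) -> int:
    """Standard conv/pool output-size formula."""
    return (size + 2 * padding - kernel) // stride + 1

def backbone_out_hw(size: int) -> int:
    size = conv_out(size, kernel=7, stride=2, padding=3)  # stem conv
    size = conv_out(size, kernel=3, stride=2, padding=1)  # stem max-pool
    for stride in (1, 2, 2, 2):                           # stage downsample strides
        if stride != 1:
            size = conv_out(size, kernel=3, stride=stride, padding=1)
    return size

print(backbone_out_hw(96))  # -> 3
```

The window blocks themselves are spatially size-preserving (pad, attend, crop), so only the stem and the strided downsample paths change H/W.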
@@ -6,6 +6,8 @@ import torchvision
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from typing import Callable, Optional, Tuple, Union
|
from typing import Callable, Optional, Tuple, Union
|
||||||
|
|
||||||
|
from .attnres_resnet2d import AttnResResNetLikeBackbone2D
|
||||||
|
|
||||||
def _replace_submodules(
|
def _replace_submodules(
|
||||||
root_module: nn.Module, predicate: Callable[[nn.Module], bool], func: Callable[[nn.Module], nn.Module]
|
root_module: nn.Module, predicate: Callable[[nn.Module], bool], func: Callable[[nn.Module], nn.Module]
|
||||||
) -> nn.Module:
|
) -> nn.Module:
|
||||||
@@ -103,6 +105,17 @@ class _SingleRgbEncoder(nn.Module):
|
|||||||
use_group_norm: bool,
|
use_group_norm: bool,
|
||||||
spatial_softmax_num_keypoints: int,
|
spatial_softmax_num_keypoints: int,
|
||||||
freeze_backbone: bool = True, # 新增:是否冻结backbone
|
freeze_backbone: bool = True, # 新增:是否冻结backbone
|
||||||
|
vision_backbone_mode: str = "resnet",
|
||||||
|
attnres_stem_dim: int = 64,
|
||||||
|
attnres_stage_dims: Optional[Tuple[int, ...]] = None,
|
||||||
|
attnres_stage_depths: Optional[Tuple[int, ...]] = None,
|
||||||
|
attnres_stage_heads: Optional[Tuple[int, ...]] = None,
|
||||||
|
attnres_stage_kv_heads: Optional[Tuple[int, ...]] = None,
|
||||||
|
attnres_stage_window_sizes: Optional[Tuple[int, ...]] = None,
|
||||||
|
attnres_dropout: float = 0.0,
|
||||||
|
attnres_ffn_mult: float = 2.667,
|
||||||
|
attnres_eps: float = 1e-6,
|
||||||
|
attnres_rope_theta: float = 10000.0,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@@ -119,6 +132,7 @@ class _SingleRgbEncoder(nn.Module):
|
|||||||
self.do_crop = False
|
self.do_crop = False
|
||||||
crop_shape = input_shape[1:]
|
crop_shape = input_shape[1:]
|
||||||
|
|
||||||
|
if vision_backbone_mode == "resnet":
|
||||||
# 设置骨干网络
|
# 设置骨干网络
|
||||||
backbone_model = getattr(torchvision.models, vision_backbone)(
|
backbone_model = getattr(torchvision.models, vision_backbone)(
|
||||||
weights=pretrained_backbone_weights
|
weights=pretrained_backbone_weights
|
||||||
@@ -131,8 +145,28 @@ class _SingleRgbEncoder(nn.Module):
|
|||||||
self.backbone = _replace_submodules(
|
self.backbone = _replace_submodules(
|
||||||
root_module=self.backbone,
|
root_module=self.backbone,
|
||||||
predicate=lambda x: isinstance(x, nn.BatchNorm2d),
|
predicate=lambda x: isinstance(x, nn.BatchNorm2d),
|
||||||
func=lambda x: nn.GroupNorm(num_groups=x.num_features // 16, num_channels=x.num_features),
|
func=lambda x: nn.GroupNorm(
|
||||||
|
num_groups=max(1, x.num_features // 16),
|
||||||
|
num_channels=x.num_features,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
elif vision_backbone_mode == "attnres_resnet":
|
||||||
|
self.backbone = AttnResResNetLikeBackbone2D(
|
||||||
|
input_channels=input_shape[0],
|
||||||
|
stem_dim=attnres_stem_dim,
|
||||||
|
stage_dims=tuple(attnres_stage_dims or (64, 128, 256, 512)),
|
||||||
|
stage_depths=tuple(attnres_stage_depths or (2, 2, 2, 2)),
|
||||||
|
stage_heads=tuple(attnres_stage_heads or (4, 4, 8, 8)),
|
||||||
|
stage_kv_heads=tuple(attnres_stage_kv_heads or (1, 1, 1, 1)),
|
||||||
|
stage_window_sizes=tuple(attnres_stage_window_sizes or (7, 7, 7, 7)),
|
||||||
|
use_group_norm=use_group_norm,
|
||||||
|
dropout=attnres_dropout,
|
||||||
|
ffn_mult=attnres_ffn_mult,
|
||||||
|
eps=attnres_eps,
|
||||||
|
rope_theta=attnres_rope_theta,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"不支持的 vision_backbone_mode: {vision_backbone_mode}")
|
||||||
|
|
||||||
# 冻结backbone参数(可选)
|
# 冻结backbone参数(可选)
|
||||||
if freeze_backbone:
|
if freeze_backbone:
|
||||||
@@ -180,6 +214,17 @@ class ResNetDiffusionBackbone(VLABackbone):
        num_cameras: int = 1,  # New: number of cameras (used only in independent-encoder mode)
        camera_names: Optional[Tuple[str, ...]] = None,  # Explicit camera order
        freeze_backbone: bool = True,  # New: whether to freeze the ResNet backbone (True recommended)
        vision_backbone_mode: str = "resnet",
        attnres_stem_dim: int = 64,
        attnres_stage_dims: Optional[Tuple[int, ...]] = None,
        attnres_stage_depths: Optional[Tuple[int, ...]] = None,
        attnres_stage_heads: Optional[Tuple[int, ...]] = None,
        attnres_stage_kv_heads: Optional[Tuple[int, ...]] = None,
        attnres_stage_window_sizes: Optional[Tuple[int, ...]] = None,
        attnres_dropout: float = 0.0,
        attnres_ffn_mult: float = 2.667,
        attnres_eps: float = 1e-6,
        attnres_rope_theta: float = 10000.0,
    ):
        super().__init__()
@@ -203,6 +248,17 @@ class ResNetDiffusionBackbone(VLABackbone):
                use_group_norm=use_group_norm,
                spatial_softmax_num_keypoints=spatial_softmax_num_keypoints,
                freeze_backbone=freeze_backbone,
                vision_backbone_mode=vision_backbone_mode,
                attnres_stem_dim=attnres_stem_dim,
                attnres_stage_dims=attnres_stage_dims,
                attnres_stage_depths=attnres_stage_depths,
                attnres_stage_heads=attnres_stage_heads,
                attnres_stage_kv_heads=attnres_stage_kv_heads,
                attnres_stage_window_sizes=attnres_stage_window_sizes,
                attnres_dropout=attnres_dropout,
                attnres_ffn_mult=attnres_ffn_mult,
                attnres_eps=attnres_eps,
                attnres_rope_theta=attnres_rope_theta,
            )
            for _ in range(num_cameras)
        ]
@@ -220,6 +276,17 @@ class ResNetDiffusionBackbone(VLABackbone):
            use_group_norm=use_group_norm,
            spatial_softmax_num_keypoints=spatial_softmax_num_keypoints,
            freeze_backbone=freeze_backbone,
            vision_backbone_mode=vision_backbone_mode,
            attnres_stem_dim=attnres_stem_dim,
            attnres_stage_dims=attnres_stage_dims,
            attnres_stage_depths=attnres_stage_depths,
            attnres_stage_heads=attnres_stage_heads,
            attnres_stage_kv_heads=attnres_stage_kv_heads,
            attnres_stage_window_sizes=attnres_stage_window_sizes,
            attnres_dropout=attnres_dropout,
            attnres_ffn_mult=attnres_ffn_mult,
            attnres_eps=attnres_eps,
            attnres_rope_theta=attnres_rope_theta,
        )
        self.feature_dim = self.rgb_encoder.feature_dim
tests/test_attnres_resnet2d_backbone.py (new file, 26 lines)
@@ -0,0 +1,26 @@
import unittest

import torch


class AttnResResNet2DBackboneTest(unittest.TestCase):
    def test_backbone_preserves_resnet_like_stage_contract(self):
        from roboimi.vla.models.backbones.attnres_resnet2d import AttnResResNetLikeBackbone2D

        backbone = AttnResResNetLikeBackbone2D(
            input_channels=3,
            stem_dim=16,
            stage_dims=(16, 32, 64, 128),
            stage_depths=(1, 1, 1, 1),
            stage_heads=(2, 4, 4, 8),
            stage_kv_heads=(1, 1, 1, 1),
            stage_window_sizes=(7, 7, 7, 7),
            dropout=0.0,
        )

        x = torch.randn(2, 3, 56, 56)
        y = backbone(x)
        self.assertEqual(y.shape, (2, 128, 2, 2))


if __name__ == '__main__':
    unittest.main()
tests/test_imf_vla_agent.py
@@ -422,6 +422,32 @@ class IMFVLAAgentTest(unittest.TestCase):
        self.assertEqual(agent.noise_pred_net.constructor_kwargs['cond_dim'], agent.per_step_cond_dim)
        self.assertEqual(agent.noise_pred_net.constructor_kwargs['backbone_type'], 'attnres_full')

    def test_hydra_config_instantiates_resnet_imf_attnres_with_full_attnres_vision_backbone(self):
        cfg = _compose_cfg(
            overrides=[
                'agent=resnet_imf_attnres',
                'agent.vision_backbone.vision_backbone_mode=attnres_resnet',
                'agent.vision_backbone.pretrained_backbone_weights=null',
                'agent.vision_backbone.input_shape=[3,56,56]',
                'agent.vision_backbone.freeze_backbone=false',
                'agent.vision_backbone.attnres_stem_dim=16',
                'agent.vision_backbone.attnres_stage_dims=[16,32,64,128]',
                'agent.vision_backbone.attnres_stage_depths=[1,1,1,1]',
                'agent.vision_backbone.attnres_stage_heads=[2,4,4,8]',
                'agent.vision_backbone.attnres_stage_kv_heads=[1,1,1,1]',
                'agent.vision_backbone.attnres_stage_window_sizes=[7,7,7,7]',
                'agent.head.n_layer=1',
                'agent.head.n_emb=16',
            ]
        )

        with _stub_optional_modules(include_imf_head=True):
            agent = instantiate(cfg.agent)

        self.assertEqual(agent.vision_encoder.output_dim, 64)
        self.assertEqual(agent.per_step_cond_dim, 64 * agent.num_cams + agent.obs_dim)
        self.assertEqual(agent.noise_pred_net.constructor_kwargs['cond_dim'], agent.per_step_cond_dim)


if __name__ == '__main__':
    unittest.main()
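The `per_step_cond_dim` assertion above encodes the interface contract the plan wants preserved: per-camera visual features are concatenated across cameras and with the proprioceptive observation. A tiny sketch of that arithmetic (the 2-camera / 80-d observation split is an assumed example; only the 208 total comes from the plan's "cond_dim=208" goal):

```python
def per_step_cond_dim(vision_output_dim: int, num_cams: int, obs_dim: int) -> int:
    """Conditioning width per diffusion step: one feature vector per camera,
    concatenated, plus the low-dimensional observation."""
    return vision_output_dim * num_cams + obs_dim


# Assumed example split: two cameras with 64-d features and an 80-d obs vector.
print(per_step_cond_dim(64, 2, 80))  # 208
```

Keeping this number fixed is what lets the full-AttnRes vision backbone drop in without touching the agent, head, or data interfaces.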