feat: add pusht imf attnres backbone
This commit is contained in:
108
docs/superpowers/specs/2026-03-29-pusht-imf-attnres-design.md
Normal file
108
docs/superpowers/specs/2026-03-29-pusht-imf-attnres-design.md
Normal file
@@ -0,0 +1,108 @@
|
||||
# PushT Image iMF AttnRes Design
|
||||
|
||||
## Goal
|
||||
在现有 PushT 图像 iMF full-attention 路线之上,引入 `attn_res` 仓库中的 **Full AttnRes** 残差聚合形式,并同步使用与其匹配的 **RMSNorm + 自注意力 + SwiGLU FFN** 模块,保持 iMF 训练目标与一步推理语义不变,仅作用于本次实验链路。实现完成并验证后,启动与此前相同的 9 组 `n_emb × n_layer` 扫描(350 epochs, seed=42, SwanLab online, 无视频记录)。
|
||||
|
||||
## Scope
|
||||
本次工作仅覆盖:
|
||||
1. 为 `IMFTransformerForDiffusion` 增加一个 AttnRes-backed backbone 变体;
|
||||
2. 保持 `forward(sample, r, t, cond=None)`、iMF loss、一步推理策略接口不变;
|
||||
3. 新增独立 PushT 图像配置用于该变体;
|
||||
4. 复用本地 5090 + 远端 5880 双卡三路并行调度 9 组实验。
|
||||
|
||||
不在范围内:
|
||||
- 不替换已有 vanilla iMF/full-attn 配置;
|
||||
- 不修改 DiT baseline;
|
||||
- 不增加视频日志;
|
||||
- 不扩大到多 seed。
|
||||
|
||||
## Recommended Approach
|
||||
采用“**在当前 iMF 模型内增加可选 AttnRes backbone**”的方式,而不是新建独立 policy 链路。
|
||||
|
||||
理由:
|
||||
- policy / workspace / loss / sampling 路径已经被验证,保留这些路径可最大程度缩小变动面;
|
||||
- 仅在模型内部切换 backbone,可以让新实验与既有 iMF 结果保持可比;
|
||||
- 配置上只需显式打开 `backbone_type=attnres_full`、`causal_attn=false` 等开关,复现实验更直接。
|
||||
|
||||
## Architecture
|
||||
### 1. Backbone split
|
||||
`IMFTransformerForDiffusion` 保留现有 vanilla encoder/decoder 实现为默认路径,并新增 `attnres_full` 路径:
|
||||
- **vanilla**:保持当前实现不变;
|
||||
- **attnres_full**:使用单栈式全注意力 Transformer,输入 token 序列为
|
||||
`[r token, t token, obs cond tokens..., action/sample tokens...]`。
|
||||
|
||||
模型只对末尾的 action/sample token 位置输出 `u` 预测,前置条件 token 仅参与上下文建模。
|
||||
|
||||
### 2. AttnRes stack
|
||||
新 backbone 使用以下模块:
|
||||
- `RMSNorm`
|
||||
- `Rotary Position Embedding`(用于自注意力 q/k)
|
||||
- `GroupedQueryAttention`(本实验默认 `n_kv_head=1`,与单头配置兼容)
|
||||
- `SwiGLU` FFN
|
||||
- `AttnResOperator`(每个子层一个 pseudo-query,执行 full depth-wise residual aggregation)
|
||||
|
||||
每个 transformer block 由两个子层组成:
|
||||
1. self-attention 子层
|
||||
2. FFN 子层
|
||||
|
||||
每个子层的输入不再是简单 `x + f(x)`,而是从 embedding 与全部历史子层输出中通过 Full AttnRes 聚合得到 `h_l`,再执行 `RMSNorm(h_l) -> sublayer_fn(...)`。
|
||||
|
||||
### 3. Conditioning and token flow
|
||||
- `sample` 先经 `input_emb` 映射为 action tokens;
|
||||
- `r` 和 `t` 各自经 `SinusoidalPosEmb + linear` 映射为两个条件 token;
|
||||
- 图像观测编码后的 `cond` 通过 `cond_obs_emb` 映射为 obs tokens;
|
||||
- 拼接后的完整 token 序列进入 AttnRes stack;
|
||||
- 输出时切掉前置条件 token,仅保留 action/sample token 段,随后经 `RMSNorm + head` 得到最终 `u`。
|
||||
|
||||
### 4. Attention mode
|
||||
本次实验链路固定为 **non-causal full attention**:
|
||||
- `causal_attn=false`
|
||||
- 不构造 causal mask
|
||||
- 所有 token 可彼此双向可见
|
||||
|
||||
这与用户指定的“训练过程仍然使用全注意力(不加因果注意)”一致。
|
||||
|
||||
## Config and Logging
|
||||
新增独立配置文件,例如:
|
||||
- `image_pusht_diffusion_policy_dit_imf_attnres_full.yaml`
|
||||
|
||||
该配置需要:
|
||||
- 指向现有 `IMFTransformerHybridImagePolicy`
|
||||
- 显式开启 AttnRes backbone 相关参数
|
||||
- 设置 `policy.causal_attn=false`
|
||||
- 保持 `logging.backend=swanlab`、`logging.mode=online`
|
||||
- 运行时通过覆盖保证:
|
||||
- `logging.name=<unique_run_name>`
|
||||
- `logging.group=imf_pusht_attnres_arch_sweep`
|
||||
- `exp_name=<unique_run_name>`
|
||||
- 保持 `task.env_runner.n_test_vis=0` 与 `n_train_vis=0`,仅记录标量
|
||||
|
||||
## Experiment Matrix
|
||||
固定 9 组:
|
||||
- `n_emb ∈ {128, 256, 384}`
|
||||
- `n_layer ∈ {6, 12, 18}`
|
||||
- `seed=42`
|
||||
- `training.num_epochs=350`
|
||||
|
||||
## Scheduling
|
||||
沿用之前验证过的三队列分配:
|
||||
- 本机 5090:`384x18`, `256x6`, `128x6`
|
||||
- 5880 GPU0:`384x12`, `256x12`, `128x12`
|
||||
- 5880 GPU1:`384x6`, `256x18`, `128x18`
|
||||
|
||||
每个 run name 编码 backbone 与结构,例如:
|
||||
`imf_attnres_emb256_layer12_seed42_5880gpu0`
|
||||
|
||||
## Verification
|
||||
实现阶段至少验证:
|
||||
1. 新配置的 SwanLab 命名与 `causal_attn=false` 正确;
|
||||
2. 新 backbone 的 forward shape 与 `configure_optimizers()` 可用;
|
||||
3. 旧 vanilla 路径测试不回归;
|
||||
4. `training.debug=true` smoke run 可以完整通过。
|
||||
|
||||
## Success Criteria
|
||||
1. 新 AttnRes iMF 变体在本分支可训练、可一步推理;
|
||||
2. 不影响已有 vanilla iMF/full-attn 链路;
|
||||
3. 9 组实验成功在三张卡上正式启动;
|
||||
4. SwanLab run 名称唯一,无冲突;
|
||||
5. 不记录视频,仅记录标量。
|
||||
Reference in New Issue
Block a user