10 KiB
IMF-AttnRes Policy Migration Design
Date: 2026-04-01 Status: Approved in chat, written spec pending review
Goal
将 /home/droid/project/diffusion_policy 中提交 185ed659 的 IMF-AttnRes diffusion policy 迁移到当前 roboimi 仓库,作为当前 DiT / Transformer diffusion policy 的替代训练选项;同时迁移其训练目标与一步推理机制,并保持 RoboIMI 现有的仿真环境、三相机视觉输入、数据集格式、训练脚本和 rollout 验证工作流可继续使用。
Non-Goals
- 不迁移 external repo 中与当前任务无关的 obs encoder、dataset、env wrapper、PushT 专用逻辑。
- 不强行复刻 external repo 中全部目录结构;仅迁移当前 RoboIMI 训练所必需的模型、loss、inference 语义。
- 不在本次工作中同时保留旧 DiT 为默认训练目标;旧配置继续可用,但新模型单独提供 config 入口。
User-Confirmed Requirements
- 迁移对象是
185ed659中的 IMF-AttnRes 模型相关代码。 - 不只是迁移骨架,还要迁移:
- 训练目标
- 一步推理机制
- 视觉输入与当前 RoboIMI diffusion policy 一致:
- 使用三个相机图像作为条件输入
- 图像观测必须作为条件,而不是拼进输出预测目标
- 当前任务里,IMF policy 用来替代现有 DiT/Transformer diffusion policy 训练。
- 训练参数沿用最近一次训练的大体设置(后续由训练命令显式覆盖),但推理方式改为 IMF 的 one-step 机制。
- 用户接受 IMF 中“全注意力 / 非因果注意力”的实现约束。
External Source of Truth
迁移语义以 external repo 的以下文件为准:
diffusion_policy/model/diffusion/attnres_transformer_components.pydiffusion_policy/model/diffusion/imf_transformer_for_diffusion.pydiffusion_policy/policy/imf_transformer_hybrid_image_policy.py- 参考配置:
image_pusht_diffusion_policy_dit_imf_attnres_full.yaml
其中最关键的差异是:该策略并非 DDPM/DDIM 多步去噪,而是 IMF 训练目标 + one-step 推理。
Current RoboIMI Baseline
当前 RoboIMI 中与该任务直接相关的基线如下:
- 视觉编码:
ResNetDiffusionBackbone- 三相机:
r_vis,top,front - 每个时间步将相机特征与
qpos拼接为 per-step condition
- 三相机:
- 策略主体:
VLAAgentcompute_loss()使用 DDPM 噪声预测损失predict_action()使用 DDIM 多步采样- 在线控制通过动作队列机制在
select_action()中按 chunk 触发预测
- 训练脚本:
roboimi/demos/vla_scripts/train_vla.py- 支持 GPU 训练、SwanLab 日志、headless rollout 验证
因此,本次迁移的核心不是换视觉 backbone,而是替换 head + loss + inference semantics。
Recommended Integration Approach
采用 最小侵入式集成:
- 保留当前 RoboIMI 的视觉编码、数据读取、rollout/eval、训练脚本主框架。
- 新增 IMF 专用 head 模块,在 RoboIMI 内本地实现:
- AttnRes 组件
- IMF transformer 主体
- 新增 IMF 专用 agent,复用当前
VLAAgent的:- 归一化逻辑
- 相机顺序管理
- 观测缓存 / 动作 chunk 缓存
- rollout 接口 但覆盖:
compute_loss()predict_action()
- 新增独立 Hydra config,让 IMF policy 作为新的 agent 选项,不破坏已有 resnet_transformer / gr00t_dit 配置。
这样做的原因:
- 迁移 IMF 语义时不必把当前 DDPM agent 搅乱;
- rollout / eval / checkpoint 逻辑仍然可复用;
- 便于和现有 Transformer / DiT 直接做 A/B 对比训练。
Architecture
1. Observation / Conditioning Path
沿用当前 RoboIMI 的视觉路径:
- 输入观测:
images={r_vis, top, front}+qpos ResNetDiffusionBackbone对每个相机编码,得到 per-camera featurestate_encoder编码qpos- 将三相机特征与 state feature 按时间步拼接,形成
per_step_cond
这里不迁移 external repo 的 obs_encoder 实现;我们只对齐 “图像作为条件 token 输入 transformer” 这一语义。
2. Condition Tokenization
对齐 external IMF transformer 的 token 使用方式:
- action trajectory token:由
(B, pred_horizon, action_dim)通过线性层映射到n_emb - time token:两个标量
r与t,分别通过 sinusoidal embedding + linear projection 得到 token - observation token:
per_step_cond通过线性层映射到n_emb - 最终 token 序列为:
[r_token, t_token, obs_cond_tokens..., action_tokens...]
在当前任务中,obs token 数量等于 obs_horizon,且图像观测始终作为条件输入。
3. IMF-AttnRes Backbone
在 RoboIMI 内新增 AttnRes backbone 实现,保持 external commit 的关键语义:
RMSNorm/RMSNormNoWeight- RoPE
- Grouped Query Self-Attention
- SwiGLU FFN
- AttnRes operator / residual source aggregation
AttnResTransformerBackbone
并保持:
- full attention(不使用因果注意力)
backbone_type='attnres_full'- 输出仅切回 action token 部分,再经过最终 norm + head 得到 velocity-like 输出
4. Training Objective
训练目标从当前 DDPM epsilon prediction 改为 external IMF 目标:
给定真实轨迹 x 与随机噪声 e:
- 采样
t ~ U(0,1)、r ~ U(0,1),并排序为t >= r - 构造插值状态:
z_t = (1 - t) x + t e
- 用模型计算:
v = f(z_t, t, t, cond)
- 对
g(z, r, t) = f(z, r, t, cond)做 JVP,得到:u, du_dt
- 构造 compound velocity:
V = u + (t - r) * du_dt
- 目标为:
target = e - x
- 用 action 维度上的 MSE 作为最终损失
RoboIMI 现有 batch 中的 action_is_pad 仍要保留支持;如果存在 padding,只在有效 action 上计算损失。
5. One-Step Inference
推理改为 external IMF 的一步采样语义:
- 从标准高斯初始化 action trajectory
z_t - 计算
u = f(z_t, r=0, t=1, cond) - 一步更新:
x_hat = z_t - (t-r) * u = z_t - u
- 反归一化得到动作序列
这意味着:
num_inference_steps对 IMF policy 固定为1- 不再调用 DDIM scheduler 的多步
step() - 在线控制中仍沿用当前 chunk 机制:
- 动作队列为空时触发一次
predict_action_chunk() - 取预测序列中
[obs_horizon-1 : obs_horizon-1+num_action_steps]这一段入队
- 动作队列为空时触发一次
也就是说,触发模型前向的规则不变,改变的是每次触发后的动作序列生成方式。
API / Code Structure
计划中的主要代码边界如下:
roboimi/vla/models/heads/attnres_transformer_components.py- IMF AttnRes 基础组件
roboimi/vla/models/heads/imf_transformer1d.py- RoboIMI 版本 IMF transformer head
- 对外暴露
forward(sample, r, t, cond=None) - 暴露
get_optim_groups()供 AdamW 分组使用
roboimi/vla/agent_imf.py- 复用
VLAAgent的观测处理 / normalization / queue 基础设施 - 覆盖 IMF 的训练损失与 one-step 预测逻辑
- 复用
- Hydra config
roboimi/vla/conf/head/imf_transformer1d.yamlroboimi/vla/conf/agent/resnet_imf_attnres.yaml
训练脚本主流程尽量不改;只要求它能 instantiate 新 agent 并继续使用当前 rollout / checkpoint / swanlab 逻辑。
Compatibility Decisions
Initial Config Defaults To Preserve
为避免迁移时语义漂移,首版 IMF 配置默认值明确固定为:
backbone_type: attnres_fulln_head: 1n_kv_head: 1n_cond_layers: 0time_as_cond: truecausal_attn: falsenum_inference_steps: 1
这些默认值与 external 185ed659 的 IMF-AttnRes 使用方式保持一致;后续调参可以覆盖,但首版迁移必须先以该语义跑通。
Reuse From RoboIMI
保留:
- 三相机数据读取方式
- ResNet visual backbone
- qpos / action normalization
- 训练循环、优化器、scheduler、SwanLab、headless rollout
select_action()的在线 chunk 执行方式
Replace With External IMF Semantics
替换:
- transformer head 实现
- diffusion training objective
- inference sampling semantics
Intentionally Not Mirrored 1:1
不强行与 external repo 一致的部分:
- external repo 的整体 policy 基类继承体系
- external repo 的 obs encoder 模块树
- external repo 的 normalizer / mask generator 框架
原因是当前 RoboIMI 已有稳定的数据接口和 rollout 流程,直接嫁接进去更稳。
Testing / Verification Strategy
迁移完成后至少验证以下内容:
- 单元 / 冒烟验证
- IMF head 前向 shape 正确
- IMF agent
compute_loss()在真实 batch 上可前向、反向 - IMF agent
predict_action()能输出(B, pred_horizon, action_dim)
- 训练链路验证
- 使用 GPU 跑一个短训练任务,确认:
- dataloader 正常
- optimizer / lr scheduler 正常
- SwanLab 正常记录配置和训练指标
- 使用 GPU 跑一个短训练任务,确认:
- rollout 验证
- 训练中周期性 headless rollout 能跑通
- 环境仍按 EE-style
step()接收动作
- 最终交付
- 用用户指定的同类超参数启动正式训练
Risks and Mitigations
Risk 1: JVP 在 CUDA 注意力内核上不稳定
缓解:沿用 external repo 的策略,在 JVP 路径上切换到 math SDP kernel,必要时 fallback 到 torch.autograd.functional.jvp。同时,JVP 的切线构造与 u, du_dt 计算流程必须严格对齐 external source,不在本次迁移中自行改写其数学语义。
Risk 2: Optimizer 参数分组遗漏新模块
缓解:IMF head 提供 get_optim_groups(),并在训练脚本中按“只要 head 提供该接口就使用”的策略统一处理,而不是绑定旧 head_type。
Risk 3: 现有 rollout 逻辑假定 DDIM 多步采样
缓解:保持 select_action() / predict_action_chunk() 接口不变,只替换 predict_action() 内部实现,确保 eval 代码无需理解 IMF 细节。
Risk 4: 训练命令参数与新 config 不一致
缓解:新增独立 agent config,并保留此前训练参数作为显式 CLI override 模板。
Success Criteria
以下条件全部满足,视为本次迁移成功:
- RoboIMI 中新增 IMF-AttnRes policy,可通过 Hydra config 单独启用。
- 训练时使用 external IMF 的 loss,而不是当前 DDPM epsilon loss。
- 推理时使用 one-step IMF 采样,而不是 DDIM 多步采样。
- 三相机图像始终作为条件输入参与模型前向。
- 在线 rollout 能在 headless 仿真环境中跑通。
- 能按最近一次实验参数模板成功启动训练。