docs: clarify pusht imf spec
The iMF path reuses:

- `diffusion_policy/policy/imf_transformer_hybrid_image_policy.py`
- `image_pusht_diffusion_policy_dit_imf.yaml`

### Existing files changed for the iMF path

- `diffusion_policy/workspace/train_diffusion_transformer_hybrid_workspace.py`
  - logging migration to SwanLab for this experiment chain
  - no structural training-loop fork beyond instantiating the configured policy and logging scalar metrics
- `diffusion_policy/env_runner/pusht_image_runner.py`
  - suppress video objects in returned logs

### Model structure

The iMF transformer mirrors the current transformer policy structure closely enough to reuse known-good conditioning patterns, but it remains a **single-head model** that predicts only:

- `u`: average velocity field

The same function is reused at two evaluation points:

- `fn(z_t, r, t, cond)` predicts the average velocity `u`
- `fn(z_t, t, t, cond)` predicts the instantaneous velocity surrogate `v`

Inputs remain conditioned on encoded observations and action trajectory tokens.
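
As a minimal sketch of this dual-evaluation pattern (the toy module, dimensions, and conditioning handling below are illustrative stand-ins, not the actual policy class):

```python
import torch
import torch.nn as nn

class TinyIMFNet(nn.Module):
    """Toy stand-in for the single-head iMF network: one output head,
    with both time-like inputs r and t fed alongside the state z."""
    def __init__(self, dim=8):
        super().__init__()
        self.net = nn.Linear(dim + 2, dim)  # z plus scalar r and t

    def forward(self, z, r, t, cond=None):
        # cond is ignored in this toy; the real policy injects obs tokens here
        rt = torch.stack([r, t], dim=-1).expand(z.shape[0], 2)
        return self.net(torch.cat([z, rt], dim=-1))

fn = TinyIMFNet()
z = torch.randn(4, 8)
r = torch.tensor(0.2)
t = torch.tensor(0.8)
u = fn(z, r, t)   # evaluated at (z_t, r, t): average velocity over [r, t]
v = fn(z, t, t)   # evaluated at (z_t, t, t): instantaneous velocity surrogate
```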

## iMF Training Objective

For a normalized action trajectory `x`, the initial implementation follows the user-provided Algorithm 1 exactly:

1. sample `t, r`
2. sample Gaussian noise `e`
3. form `z_t = (1 - t) * x + t * e`
4. predict the instantaneous velocity surrogate with the same network:
   - `v = fn(z_t, t, t, cond)`
5. define the JVP function exactly as:
   - `g(z, r, t) = fn(z, r, t, cond)`
6. compute the primal output and JVP with tangent:
   - `u, du_dt = jvp(g, (z_t, r, t), (v.detach(), 0, 1))`
7. form the compound velocity:
   - `V = u + (t - r) * stopgrad(du_dt)`
8. train against the average-velocity target:
   - `target = e - x`
9. optimize only the masked iMF loss:
   - `loss = metric(V - target)`
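
The steps above can be sketched as one training step. This is a hedged sketch, not the actual policy code: the `r <= t` sampling convention, plain MSE as the metric, and the omission of action-token masking are all assumptions; the scalar tangents `(0, 1)` from the spec are realized as `zeros_like`/`ones_like` tensors, as `torch.func.jvp` requires.

```python
import torch
from torch.func import jvp

def imf_training_loss(fn, x, cond):
    """One iMF training step for a normalized action trajectory x of shape (B, D).

    fn(z, r, t, cond) is the single-head network; metric here is plain MSE.
    """
    B = x.shape[0]
    # 1. sample t, r (here with r <= t, a common MeanFlow convention)
    t = torch.rand(B, 1)
    r = torch.rand(B, 1) * t
    # 2. sample Gaussian noise e
    e = torch.randn_like(x)
    # 3. interpolate: t = 0 is data, t = 1 is noise
    z_t = (1 - t) * x + t * e
    # 4. instantaneous velocity surrogate from the same network
    v = fn(z_t, t, t, cond)
    # 5-6. primal u and du/dt via JVP with tangent (v.detach(), 0, 1) over (z, r, t)
    g = lambda z, r_, t_: fn(z, r_, t_, cond)
    u, du_dt = jvp(g, (z_t, r, t),
                   (v.detach(), torch.zeros_like(r), torch.ones_like(t)))
    # 7. compound velocity with stop-gradient on du/dt
    V = u + (t - r) * du_dt.detach()
    # 8-9. average-velocity target and loss (masking omitted in this sketch)
    target = e - x
    return torch.mean((V - target) ** 2)
```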

There is **no auxiliary `v` loss** in the initial implementation, which should prefer `torch.func.jvp` and keep a safe fallback path if the local Torch stack needs it.
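
One way to structure that fallback, as a sketch: prefer forward-mode `torch.func.jvp` where available and fall back to the older double-backward implementation in `torch.autograd.functional.jvp`. The wrapper name is an assumption; both library calls exist with these signatures.

```python
import torch

def jvp_with_fallback(g, primals, tangents):
    """Prefer forward-mode torch.func.jvp; fall back to the double-backward
    implementation in torch.autograd.functional.jvp on older Torch stacks.

    Both return the pair (primal_output, jvp_output)."""
    try:
        from torch.func import jvp as func_jvp
        return func_jvp(g, primals, tangents)
    except ImportError:
        # Same result, computed via two backward passes instead of forward-mode AD
        return torch.autograd.functional.jvp(g, primals, tangents, create_graph=True)
```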
## iMF Inference Design
Inference uses a single step starting from noise:
[…]

This matches the time direction in the reference iMeanFlow sampling logic.
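
Under the interpolation convention above (`z_t = (1 - t) * x + t * e`, so `t = 1` is pure noise and the average-velocity target is `e - x`), single-step sampling can be sketched as follows; the helper name and shapes are assumptions, not the actual runner code:

```python
import torch

@torch.no_grad()
def one_step_sample(fn, cond, shape):
    """One-step iMF sampling: start at pure noise (t = 1) and jump to data (r = 0)."""
    e = torch.randn(shape)       # z_1: pure noise
    t = torch.ones(shape[0], 1)
    r = torch.zeros(shape[0], 1)
    u = fn(e, r, t, cond)        # average velocity over [r, t] = [0, 1]
    return e - (t - r) * u       # z_0 = z_1 - (t - r) * u
```

Since the trained `u` approximates `e - x`, this single step recovers the action trajectory `x` directly.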
4. once iMF smoke tests pass, create/preserve a dedicated feature branch for the experiment code and push it to Gitea
## Experiment Plan
After the iMF path is smoke-tested and pushed:
- run a 3x3 grid over:
- `n_emb ∈ {128, 256, 384}`
- `n_layer ∈ {6, 12, 18}`
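
The 3x3 sweep can be enumerated with a small helper; the run-name pattern and config keys beyond `n_emb`/`n_layer` are illustrative:

```python
from itertools import product

n_embs = [128, 256, 384]
n_layers = [6, 12, 18]

# Cartesian product gives the 9 grid cells in a stable order
runs = [
    {"n_emb": n_emb, "n_layer": n_layer, "name": f"imf_e{n_emb}_l{n_layer}"}
    for n_emb, n_layer in product(n_embs, n_layers)
]
```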