PushT Image DiT iMF + SwanLab Design
Goal
Migrate the PushT image DiT experiment path from W&B to SwanLab online logging and suppress simulation video logging; then add an iMeanFlow (iMF) one-step transformer policy for PushT image experiments and run a controlled architecture sweep over embedding width and depth, with `test_mean_score` as the primary metric.
Context
- The implementation baseline is `main`.
- The experiment path is limited to the PushT image transformer workflow; unrelated workspaces and runners should remain unchanged.
- Environment management must use the repo-local `uv` workflow.
- The trusted remote machine alias `5880` refers to `droid-system-product-name` (`droid@100.73.14.65`) and can run two GPU jobs in parallel.
Architecture Overview
The work is split into two verified phases:
- Logging migration phase
  - Keep the existing PushT image DiT training behavior intact.
  - Replace W&B usage with SwanLab in the transformer hybrid workspace used by PushT image DiT experiments.
  - Preserve local `logs.json.txt` output.
  - Ensure rollout metrics such as `test_mean_score` and per-seed rewards are still logged.
  - Disable simulation video logging at both the config and runner/logging boundary.
- iMF migration phase
  - Keep the original diffusion-based transformer image policy available on `main`.
  - Add a parallel iMF-specific model/policy/config path rather than overwriting the baseline diffusion policy.
  - Reuse the existing observation encoder and training workspace where possible.
  - Replace diffusion training with the iMeanFlow training objective.
  - Use one-step inference for validation/rollout in the iMF path.
Logging Design
Scope
Only the PushT image DiT experiment chain is changed:
- `train_diffusion_transformer_hybrid_workspace.py`
- `pusht_image_runner.py`
- the new/updated PushT image transformer configs
Behavior
- SwanLab runs in `online` mode.
- Logged values are scalar metrics only, e.g.:
  - `train_loss`
  - `val_loss`
  - `train_action_mse_error`
  - `test_mean_score`
  - aggregate rollout metrics and optional per-seed scalar rewards
- No simulation videos are uploaded or wrapped as logging objects.
- Local JSON logging remains enabled for auditability and remote-job fallback debugging.
Operational safeguards
- Default PushT experiment configs set `task.env_runner.n_test_vis=0` and `task.env_runner.n_train_vis=0`.
- The PushT image runner will not emit video objects into `log_data`, preventing accidental uploads even if visualization counts are later changed (see the filtering sketch after this list).
- SwanLab credentials are provided through the environment at runtime, not committed into the repo.
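A minimal sketch of that boundary, assuming the existing convention of a flat `{metric_name: value}` dict for `log_data`; the helper name is hypothetical, and `swanlab.log` takes the same dict-of-metrics form as `wandb.log`:

```python
import numbers

import swanlab

def log_scalars_only(log_data: dict, step: int) -> None:
    # Forward only scalar metrics to SwanLab; drops video/image/array
    # objects so nothing non-scalar can be uploaded by accident.
    scalar_data = {
        k: float(v)
        for k, v in log_data.items()
        if isinstance(v, numbers.Number)
    }
    swanlab.log(scalar_data, step=step)
```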
iMF Model Design
Baseline reuse
The iMF path reuses:
- the existing image observation encoder
- the existing action/observation normalization path
- the existing training workspace skeleton
- the existing PushT image dataset and env runner
New files
- `diffusion_policy/model/diffusion/imf_transformer_for_diffusion.py`
- `diffusion_policy/policy/imf_transformer_hybrid_image_policy.py`
- `image_pusht_diffusion_policy_dit_imf.yaml`
Model structure
The iMF transformer mirrors the current transformer policy structure closely enough to reuse known-good conditioning patterns, but adds two prediction heads:
- `u`: average velocity field
- `v`: instantaneous velocity field
Inputs remain conditioned on encoded observations and action trajectory tokens.
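A hypothetical sketch of the head split, with the transformer trunk elided; `n_emb` and `action_dim` follow the repo's naming, and the class name is illustrative only:

```python
import torch
import torch.nn as nn

class IMFHeads(nn.Module):
    """Two output heads over a shared transformer trunk (trunk elided)."""
    def __init__(self, n_emb: int, action_dim: int):
        super().__init__()
        self.u_head = nn.Linear(n_emb, action_dim)  # average velocity field
        self.v_head = nn.Linear(n_emb, action_dim)  # instantaneous velocity field

    def forward(self, trunk_out: torch.Tensor):
        # trunk_out: (B, T, n_emb) action tokens conditioned on obs and time
        return self.u_head(trunk_out), self.v_head(trunk_out)
```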
iMF Training Objective
For a normalized action trajectory `x`:
- sample `t, r`
- sample Gaussian noise `e`
- form `z_t = (1 - t) * x + t * e`
- predict instantaneous velocity `v = fn(z_t, t, t)`, or equivalently the model's `v` head at time `t`
- compute `u` and `du/dt` with JVP using tangent `(v, 0, 1)` over `(z, r, t)`
- form compound velocity: `V = u + (t - r) * stopgrad(du_dt)`
- train against target average velocity: `target = e - x`
- optimize the iMF loss on unmasked action tokens, with any auxiliary `v`-head loss kept only if it helps preserve stability
The implementation should prefer `torch.func.jvp` and keep a safe fallback path if the local Torch stack requires it, as sketched below.
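A minimal sketch of the objective for one batch, assuming a model signature `fn(z, r, t) -> (u, v)`; masking, time embeddings, the exact `t`/`r` sampling schedule, and the detach points are assumptions to be validated against the reference implementation:

```python
import torch
import torch.nn.functional as F

def imf_loss(fn, x: torch.Tensor) -> torch.Tensor:
    # x: normalized action trajectory, shape (B, T, action_dim)
    B = x.shape[0]
    t = torch.rand(B, device=x.device)
    r = torch.rand(B, device=x.device) * t          # ensure r <= t
    e = torch.randn_like(x)                         # Gaussian noise
    t_ = t.view(B, *([1] * (x.dim() - 1)))
    z_t = (1 - t_) * x + t_ * e
    v_tgt = e - x                                   # target average velocity

    # instantaneous velocity from the v head (fn evaluated at r = t)
    _, v_pred = fn(z_t, t, t)

    def u_only(z, r_in, t_in):
        return fn(z, r_in, t_in)[0]

    # u and its total derivative du/dt via JVP with tangent (v, 0, 1)
    u, du_dt = torch.func.jvp(
        u_only,
        (z_t, r, t),
        (v_pred.detach(), torch.zeros_like(r), torch.ones_like(t)),
    )
    t_r = (t - r).view(B, *([1] * (x.dim() - 1)))
    V = u + t_r * du_dt.detach()                    # stopgrad on du/dt
    loss = F.mse_loss(V, v_tgt)
    loss = loss + F.mse_loss(v_pred, v_tgt)         # optional auxiliary v-head loss
    return loss
```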
iMF Inference Design
Inference uses a single step starting from noise:
- initialize `z_1 ~ N(0, I)`
- set `t = 1.0`, `r = 0.0`
- predict `u(z_1, t, r, cond)`
- produce the action sample with one update: `x_hat = z_1 - (t - r) * u`
This matches the time direction in the reference iMeanFlow sampling logic.
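A one-step sampler sketch matching those steps, with the same assumed `fn(z, r, t) -> (u, v)` signature and conditioning omitted for brevity:

```python
import torch

@torch.no_grad()
def sample_one_step(fn, shape: tuple, device) -> torch.Tensor:
    z_1 = torch.randn(shape, device=device)      # z_1 ~ N(0, I)
    B = shape[0]
    t = torch.ones(B, device=device)             # t = 1.0
    r = torch.zeros(B, device=device)            # r = 0.0
    u, _ = fn(z_1, r, t)                         # average velocity over [r, t]
    t_r = (t - r).view(B, *([1] * (len(shape) - 1)))
    return z_1 - t_r * u                         # x_hat in one update
```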
Testing Strategy
Phase 1: logging migration smoke test
- use the repo-local `uv` environment
- run a debug/smoke PushT image DiT training job on a single GPU with:
  - `training.debug=true`
  - `dataloader.num_workers=0`
  - `val_dataloader.num_workers=0`
  - `task.env_runner.n_envs=1`
  - `task.env_runner.n_test_vis=0`
  - `task.env_runner.n_train_vis=0`
- verify:
  - SwanLab initializes successfully
  - `logs.json.txt` is populated (see the check sketch after this list)
  - rollout metrics still include `test_mean_score`
  - no video logging is attempted
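A quick check sketch for the last items, assuming `logs.json.txt` is a JSON-lines file (one metrics dict per line) as the local JSON logger writes; the path is a placeholder:

```python
import json

def has_rollout_metric(log_path: str, key: str = "test_mean_score") -> bool:
    # scan the JSON-lines log for any row containing the rollout metric
    with open(log_path) as f:
        return any(key in json.loads(line) for line in f if line.strip())
```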
Phase 2: iMF smoke test
- run an equivalent debug PushT image iMF job
- verify:
  - forward/backward passes succeed
  - JVP path executes on the local Torch version
  - one-step inference returns correctly shaped actions
  - rollout produces scalar metrics including `test_mean_score`
Branch and Commit Strategy
- start from a `main`-based worktree branch
- commit the SwanLab/no-video migration after smoke verification
- continue with the iMF implementation
- once iMF smoke tests pass, create/preserve a dedicated feature branch for the experiment code and push it to Gitea
Experiment Plan
After the iMF path is smoke-tested:
- run a 3x3 grid over the following (see the sweep sketch after this list):
  - `n_emb ∈ {128, 256, 384}`
  - `n_layer ∈ {6, 12, 18}`
- keep the rest of the setup fixed
- run each experiment for 300 epochs
- primary comparison metric: `test_mean_score`
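A sketch of the sweep as hydra-style override strings; the `policy.model.*` override paths are assumptions and must match the new iMF config:

```python
from itertools import product

N_EMB = (128, 256, 384)
N_LAYER = (6, 12, 18)

# 9 override strings, one per run; schedule 3 at a time across hosts
runs = [
    f"policy.model.n_emb={n_emb} "
    f"policy.model.n_layer={n_layer} "
    f"training.num_epochs=300"
    for n_emb, n_layer in product(N_EMB, N_LAYER)
]
```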
Resource Allocation
Three concurrent runs should be scheduled continuously until the matrix is complete:
- local machine: 1 GPU
- `5880`: 2 GPUs
Each run uses the same `uv`-managed environment and the same pushed branch so the code path is consistent across hosts.
Risks and Mitigations
- Torch JVP compatibility risk: provide a fallback JVP implementation (see the sketch after this list) and smoke-test immediately.
- Logging regression risk: keep local JSON logging and verify scalar rollout metrics before moving to iMF.
- Video/logging side effects: disable visualizations in config and filter video objects out of runner logs.
- Cross-host drift: push the verified branch to Gitea before launching the experiment matrix on multiple machines.
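For the first risk, a fallback sketch: `torch.autograd.functional.jvp` computes the same product via the double-backward trick on Torch stacks where `torch.func.jvp` is unavailable or misbehaves (slower, but compatible):

```python
import torch

def jvp_with_fallback(func, primals: tuple, tangents: tuple):
    try:
        return torch.func.jvp(func, primals, tangents)
    except (AttributeError, RuntimeError):
        # create_graph=True keeps the result differentiable for the iMF loss
        return torch.autograd.functional.jvp(
            func, primals, tangents, create_graph=True
        )
```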