
ResNet Multitoken IMF Design

Status: user-specified architecture, treated as approved on 2026-04-06.

Goal

Keep a standard ResNet-18 visual trunk (no AttnRes in the vision path), but change IMF conditioning from one concatenated multiview token per obs step to three camera-specific condition tokens per obs step.

Approved architecture

  • Vision trunk: standard ResNet-18 residual network
  • Cameras: front, top, r_vis
  • Each camera uses its own ResNet-18 weights (use_separate_rgb_encoder_per_camera=true)
  • Each camera produces one visual token
  • For each obs step and each camera:
    1. take that camera's visual token
    2. concatenate the robot state for that obs step
    3. project the result to one condition token
  • IMF input should receive 3 condition tokens per obs step, not one concatenated token
  • With obs_horizon=2, IMF cond sequence length becomes 2 * 3 = 6
  • IMF head remains on the existing IMF/AttnRes implementation path
  • Vision trunk remains standard ResNet; no AttnRes vision replacement
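The per-camera conditioning path above can be sketched at the shape level as follows. This is a NumPy illustration under stated assumptions, not the repository's code. Each ResNet-18 is replaced by a random linear map, with separate weights per camera, and the shared condition projector is a single matrix. The sizes B, IMG_FEAT, D_STATE and N_EMB are made up. D_VIS = 512 is the pooled feature width of a standard ResNet-18, but the real backbone may post-process it differently.

```python
import numpy as np

rng = np.random.default_rng(0)

# Illustrative sizes (assumptions, not the real config values).
B, T = 4, 2                       # batch size, obs_horizon
CAMERAS = ("front", "top", "r_vis")
IMG_FEAT = 32                     # stand-in input size instead of real images
D_VIS = 512                       # pooled ResNet-18 feature width
D_STATE = 14                      # hypothetical robot-state width
N_EMB = 256                       # hypothetical head n_emb = condition-token width

# Stand-in for "one ResNet-18 per camera": a separate random linear map per camera,
# mirroring use_separate_rgb_encoder_per_camera=true (no weight sharing).
encoders = {cam: rng.standard_normal((IMG_FEAT, D_VIS)) * 0.01 for cam in CAMERAS}

# Stand-in for the shared condition projector, applied to every camera token.
W_cond = rng.standard_normal((D_VIS + D_STATE, N_EMB)) * 0.01


def build_cond_tokens(images, state):
    """images: camera name -> (B, T, IMG_FEAT); state: (B, T, D_STATE).

    Returns the IMF condition sequence, shape (B, T * num_cams, N_EMB).
    """
    per_cam = []
    for cam in CAMERAS:
        vis_tok = images[cam] @ encoders[cam]              # (B, T, D_VIS): one token per camera
        fused = np.concatenate([vis_tok, state], axis=-1)  # (B, T, D_VIS + D_STATE)
        per_cam.append(fused @ W_cond)                     # (B, T, N_EMB)
    tokens = np.stack(per_cam, axis=2)                     # (B, T, num_cams, N_EMB)
    b, t, n_cams, d = tokens.shape
    return tokens.reshape(b, t * n_cams, d)                # time-major flatten


images = {cam: rng.standard_normal((B, T, IMG_FEAT)) for cam in CAMERAS}
state = rng.standard_normal((B, T, D_STATE))
cond = build_cond_tokens(images, state)
print(cond.shape)  # (4, 6, 256)
```

A plain row-major reshape of (T, num_cams) keeps the sequence time-major. The three camera tokens for obs step 0 come first, followed by the three for obs step 1. This gives the 2 * 3 = 6 tokens expected for obs_horizon=2.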

Design choices

  • Extend ResNetDiffusionBackbone with an opt-in mode that returns per-camera tokens shaped (B, T, num_cams, D) instead of concatenating camera features into (B, T, num_cams * D).
  • Teach VLAAgent to detect multi-token visual features, broadcast the robot state to each camera token, apply the existing condition projector to each token, then flatten (T, num_cams) into one cond sequence for the IMF head.
  • Keep per_step_cond_dim as the width of a single condition token, and add explicit token-count metadata (condition tokens per obs step) so transformer heads get the correct cond-sequence length, i.e. obs_horizon * tokens per step.
  • For the new experiments, set the condition-token width equal to n_emb via cond_projector.output_dim=${agent.head.n_emb}.
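The backbone opt-in mode, the agent-side branch and the token-count metadata can be sketched together. Everything here is hypothetical: `per_camera_tokens`, `backbone_output`, `build_cond`, `CondSpec` and `cond_tokens_per_step` are illustrative names, not the real ResNetDiffusionBackbone or VLAAgent API. NumPy stands in for the actual tensor library. The agent detects multitoken mode from the extra camera axis (4-D features). That is one possible detection rule, not necessarily the one the implementation will use.

```python
from dataclasses import dataclass

import numpy as np


def backbone_output(cam_feats, per_camera_tokens=False):
    """cam_feats: list of num_cams arrays, each (B, T, D).

    `per_camera_tokens` is a hypothetical name for the new opt-in mode.
    """
    if per_camera_tokens:
        return np.stack(cam_feats, axis=2)         # (B, T, num_cams, D)
    return np.concatenate(cam_feats, axis=-1)      # legacy: (B, T, num_cams * D)


def build_cond(visual, state, project):
    """Agent-side sketch: a 4-D feature tensor means one token per camera."""
    if visual.ndim == 4:
        b, t, n_cams, _ = visual.shape
        # Repeat the per-step state for every camera token.
        state_b = np.broadcast_to(state[:, :, None, :], (b, t, n_cams, state.shape[-1]))
        cond = project(np.concatenate([visual, state_b], axis=-1))  # (B, T, num_cams, per_step_cond_dim)
        return cond.reshape(b, t * n_cams, cond.shape[-1])         # (B, T * num_cams, per_step_cond_dim)
    return project(np.concatenate([visual, state], axis=-1))        # legacy: (B, T, per_step_cond_dim)


@dataclass(frozen=True)
class CondSpec:
    per_step_cond_dim: int     # width of ONE condition token
    cond_tokens_per_step: int  # hypothetical metadata: 1 (legacy) or num_cams (multitoken)
    obs_horizon: int

    @property
    def cond_seq_len(self) -> int:
        return self.obs_horizon * self.cond_tokens_per_step


rng = np.random.default_rng(0)
B, T, D, S, N_EMB = 2, 2, 512, 14, 256            # illustrative sizes only
feats = [rng.standard_normal((B, T, D)) for _ in range(3)]
state = rng.standard_normal((B, T, S))

W_tok = rng.standard_normal((D + S, N_EMB)) * 0.01          # per-token projector
multi = build_cond(backbone_output(feats, per_camera_tokens=True), state, lambda x: x @ W_tok)

W_legacy = rng.standard_normal((3 * D + S, N_EMB)) * 0.01   # legacy projector
legacy = build_cond(backbone_output(feats), state, lambda x: x @ W_legacy)

spec = CondSpec(per_step_cond_dim=N_EMB, cond_tokens_per_step=3, obs_horizon=T)
print(multi.shape, legacy.shape, spec.cond_seq_len)  # (2, 6, 256) (2, 2, 256) 6
```

In this sketch the token width is N_EMB on both paths, which matches the cond_projector.output_dim=${agent.head.n_emb} setting. The head therefore only needs the sequence length from the metadata: 6 condition tokens instead of 2 for obs_horizon=2.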

Files expected to change

  • roboimi/vla/models/backbones/resnet_diffusion.py
  • roboimi/vla/agent.py
  • new Hydra agent config for the multitoken ResNet IMF variant
  • focused tests in tests/test_imf_vla_agent.py and/or tests/test_resnet_transformer_agent_wiring.py