feat: add pusht dit no-causal config

This commit is contained in:
Logic
2026-03-27 17:06:16 +08:00
parent 36fbf2a6b7
commit 31925bbf39
3 changed files with 96 additions and 0 deletions

View File

@@ -0,0 +1,53 @@
# PushT DiT No-Causal Compare Implementation Plan
> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking.
**Goal:** Add a PushT image DiT no-causal config, rerun the two prior DiT baselines for 350 epochs, and compare max `test_mean_score` plus batch-1 inference latency.
**Architecture:** Keep the existing causal DiT baselines unchanged and add a separate no-causal config that only flips `policy.causal_attn=false` while preserving the SwanLab naming safeguards. Launch the default DiT (`256x8`) locally and the `256x18` DiT on 5880 GPU0, then parse `logs.json.txt` and benchmark both checkpoints on the same hardware.
**Tech Stack:** Hydra, Diffusion Policy transformer image workspace, SwanLab, uv Python env, local 5090 + trusted remote 5880.
---
### Task 1: Add no-causal DiT config and config regression test
**Files:**
- Create: `image_pusht_diffusion_policy_dit_nocausal.yaml`
- Modify: `tests/test_pusht_swanlab_config.py`
- [ ] Write a failing test asserting the new no-causal DiT config uses SwanLab-safe naming and `policy.causal_attn == False`.
- [ ] Run the targeted pytest command and verify it fails because the config does not exist yet.
- [ ] Add the minimal new config by composing from the existing PushT DiT config and overriding only `policy.causal_attn=false`.
- [ ] Re-run the targeted pytest command and verify it passes.
### Task 2: Smoke-verify the new config
**Files:**
- Read: `image_pusht_diffusion_policy_dit_nocausal.yaml`
- [ ] Run `train.py --help` against the new config.
- [ ] Verify Hydra resolves the config without errors.
### Task 3: Launch the two 350-epoch no-causal reruns
**Files:**
- Write runtime scripts/logs under `data/run_logs/`
- Write outputs under `data/outputs/`
- [ ] Launch local run: `dit_nocausal_img_pusht_default_seed42_local` with 350 epochs.
- [ ] Launch remote run: `dit_nocausal_img_pusht_emb256_layer18_seed42_5880gpu0` with 350 epochs and `policy.n_layer=18`.
- [ ] Use explicit SwanLab overrides: unique `logging.name`, `logging.resume=false`, `logging.id=null`, shared group `dit_pusht_nocausal_compare`.
- [ ] Record pid files and launcher scripts.
### Task 4: Monitor and summarize
**Files:**
- Read: per-run `logs.json.txt`
- Read: checkpoints directories
- [ ] Monitor until both runs reach epoch 349 completion.
- [ ] Extract `max(test_mean_score)` and final logged `test_mean_score`.
- [ ] Identify the best checkpoint for each run.
- [ ] Benchmark batch-1 `policy.predict_action(obs)` latency on the same hardware.
- [ ] Report the final comparison table and short conclusion.