feat: add pusht dit no-causal config
@@ -0,0 +1,53 @@
# PushT DiT No-Causal Compare Implementation Plan

> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking.

**Goal:** Add a PushT image DiT no-causal config, rerun the two prior DiT baselines for 350 epochs, and compare max `test_mean_score` plus batch-1 inference latency.

**Architecture:** Keep the existing causal DiT baselines unchanged and add a separate no-causal config that only flips `policy.causal_attn=false` while preserving the SwanLab naming safeguards. Launch the default DiT (`256x8`) locally and the `256x18` DiT on 5880 GPU0, then parse `logs.json.txt` and benchmark both checkpoints on the same hardware.

**Tech Stack:** Hydra, Diffusion Policy transformer image workspace, SwanLab, uv Python env, local 5090 + trusted remote 5880.
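
For context on what the `policy.causal_attn` switch changes, the sketch below builds the kind of lower-triangular mask that causal self-attention typically applies and contrasts it with the unmasked case. It is a pure-Python illustration only: it assumes that `causal_attn=true` restricts each token to itself and earlier tokens, the names `attention_mask` and `render` are hypothetical, and the policy's actual mask construction may differ.

```python
def attention_mask(seq_len: int, causal: bool) -> list[list[bool]]:
    """Return a seq_len x seq_len mask; True means query i may attend to key j."""
    if causal:
        # Causal case (assumed semantics): token i sees only positions j <= i.
        return [[j <= i for j in range(seq_len)] for i in range(seq_len)]
    # Non-causal case: every token may attend to every other token.
    return [[True] * seq_len for _ in range(seq_len)]


def render(mask: list[list[bool]]) -> str:
    """Draw the mask as text, 'x' for allowed and '.' for blocked."""
    return "\n".join("".join("x" if ok else "." for ok in row) for row in mask)


if __name__ == "__main__":
    print(render(attention_mask(4, causal=True)))
    print()
    print(render(attention_mask(4, causal=False)))
```

With the mask removed, every action token can use information from the whole horizon, which is the behavioral difference the two reruns are meant to measure.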

---

### Task 1: Add no-causal DiT config and config regression test

**Files:**
- Create: `image_pusht_diffusion_policy_dit_nocausal.yaml`
- Modify: `tests/test_pusht_swanlab_config.py`

- [ ] Write a failing test asserting the new no-causal DiT config uses SwanLab-safe naming and `policy.causal_attn == False`.
- [ ] Run the targeted pytest command and verify it fails because the config does not exist yet.
- [ ] Add the minimal new config by composing from the existing PushT DiT config and overriding only `policy.causal_attn=false`.
- [ ] Re-run the targeted pytest command and verify it passes.

### Task 2: Smoke-verify the new config

**Files:**
- Read: `image_pusht_diffusion_policy_dit_nocausal.yaml`

- [ ] Run `train.py --help` against the new config.
- [ ] Verify Hydra resolves the config without errors.

### Task 3: Launch the two 350-epoch no-causal reruns

**Files:**
- Write runtime scripts/logs under `data/run_logs/`
- Write outputs under `data/outputs/`

- [ ] Launch local run: `dit_nocausal_img_pusht_default_seed42_local` with 350 epochs.
- [ ] Launch remote run: `dit_nocausal_img_pusht_emb256_layer18_seed42_5880gpu0` with 350 epochs and `policy.n_layer=18`.
- [ ] Use explicit SwanLab overrides: unique `logging.name`, `logging.resume=false`, `logging.id=null`, shared group `dit_pusht_nocausal_compare`.
- [ ] Record pid files and launcher scripts.
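
The launch steps above can be sketched as a small helper that assembles the Hydra overrides for each run, so both launcher scripts share one source of truth. This is an illustration under stated assumptions, not the launcher used here: the override key `training.num_epochs`, the `python train.py` entry point, and the helper names `build_overrides` and `build_command` are assumptions, and a `--config-dir` argument may also be needed depending on where the config lives.

```python
import shlex

CONFIG_NAME = "image_pusht_diffusion_policy_dit_nocausal.yaml"
GROUP = "dit_pusht_nocausal_compare"
NUM_EPOCHS = 350


def build_overrides(run_name, n_layer=None):
    """Return Hydra override strings for one run (hypothetical helper)."""
    overrides = [
        f"training.num_epochs={NUM_EPOCHS}",  # ASSUMPTION: key name for the epoch count
        f"logging.name={run_name}",           # unique per run
        "logging.resume=false",
        "logging.id=null",
        f"logging.group={GROUP}",             # shared across both runs
    ]
    if n_layer is not None:
        overrides.append(f"policy.n_layer={n_layer}")
    return overrides


def build_command(run_name, n_layer=None):
    """Return a shell-safe command line (ASSUMPTION: `python train.py` entry point)."""
    argv = ["python", "train.py", f"--config-name={CONFIG_NAME}"]
    argv += build_overrides(run_name, n_layer)
    return shlex.join(argv)


if __name__ == "__main__":
    print(build_command("dit_nocausal_img_pusht_default_seed42_local"))
    print(build_command("dit_nocausal_img_pusht_emb256_layer18_seed42_5880gpu0", n_layer=18))
```

Generating both command lines from the same function makes it harder for the two runs to diverge in anything other than the run name and `policy.n_layer`.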

### Task 4: Monitor and summarize

**Files:**
- Read: per-run `logs.json.txt`
- Read: checkpoints directories

- [ ] Monitor until both runs reach epoch 349 completion.
- [ ] Extract `max(test_mean_score)` and final logged `test_mean_score`.
- [ ] Identify the best checkpoint for each run.
- [ ] Benchmark batch-1 `policy.predict_action(obs)` latency on the same hardware.
- [ ] Report the final comparison table and short conclusion.

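
The extraction and benchmarking steps above could look roughly like the sketch below. It assumes `logs.json.txt` holds one JSON object per line with an `epoch` field and, on evaluation epochs, a `test_mean_score` field; the helper names `summarize_scores` and `time_calls` are hypothetical. The timing helper is generic: for a GPU policy, something like `torch.cuda.synchronize` would be passed as `sync` so that queued kernels are counted.

```python
import json
import time


def summarize_scores(lines, metric="test_mean_score"):
    """Return (best_score, best_epoch, final_score) over records that log `metric`."""
    entries = []
    for line in lines:
        line = line.strip()
        if not line:
            continue
        record = json.loads(line)  # ASSUMPTION: one JSON object per line
        if metric in record:
            entries.append((record.get("epoch"), record[metric]))
    if not entries:
        raise ValueError(f"no {metric!r} entries found")
    best_epoch, best_score = max(entries, key=lambda e: e[1])
    final_score = entries[-1][1]
    return best_score, best_epoch, final_score


def time_calls(fn, n_warmup=10, n_iters=100, sync=None):
    """Return mean wall-clock seconds per call of fn() after a warm-up phase."""
    for _ in range(n_warmup):
        fn()
    if sync is not None:
        sync()  # drain pending async work before starting the clock
    start = time.perf_counter()
    for _ in range(n_iters):
        fn()
    if sync is not None:
        sync()  # make sure the timed work has actually finished
    return (time.perf_counter() - start) / n_iters


if __name__ == "__main__":
    demo = [
        '{"epoch": 0, "train_loss": 1.0}',
        '{"epoch": 49, "test_mean_score": 0.5}',
        '{"epoch": 99, "test_mean_score": 0.9}',
    ]
    print(summarize_scores(demo))
    print(time_calls(lambda: sum(range(1000)), n_warmup=2, n_iters=10))
```

In practice `fn` would wrap a single batch-1 `policy.predict_action(obs)` call, and the epoch returned alongside the best score points to which checkpoint to look for.
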
31
image_pusht_diffusion_policy_dit_nocausal.yaml
Normal file
@@ -0,0 +1,31 @@
defaults:
  - diffusion_policy/config/train_diffusion_transformer_hybrid_workspace@_here_
  - override /diffusion_policy/config/task@task: pusht_image
  - _self_

exp_name: pusht_image_dit_nocausal

policy:
  _target_: diffusion_policy.policy.diffusion_transformer_hybrid_image_policy.DiffusionTransformerHybridImagePolicy
  causal_attn: false

logging:
  backend: swanlab
  mode: online
  name: ${exp_name}
  resume: false
  tags: ["${name}", "${task_name}", "${exp_name}", "swanlab"]
  id: null
  group: ${exp_name}

dataloader:
  num_workers: 0

val_dataloader:
  num_workers: 0

task:
  env_runner:
    n_envs: 1
    n_test_vis: 0
    n_train_vis: 0
@@ -30,3 +30,15 @@ def test_image_pusht_dit_imf_swanlab_config_uses_exp_name_and_no_resume_collisio
    assert cfg.logging.resume is False
    assert cfg.logging.id is None
    assert cfg.logging.group == cfg.exp_name


def test_image_pusht_dit_nocausal_config_uses_exp_name_and_disables_causal_attention():
    cfg = _load_cfg('image_pusht_diffusion_policy_dit_nocausal.yaml')

    assert cfg.logging.backend == 'swanlab'
    assert cfg.logging.mode == 'online'
    assert cfg.logging.name == cfg.exp_name
    assert cfg.logging.resume is False
    assert cfg.logging.id is None
    assert cfg.logging.group == cfg.exp_name
    assert cfg.policy.causal_attn is False