1 Commits

Author SHA1 Message Date
Logic
31925bbf39 feat: add pusht dit no-causal config 2026-03-27 17:06:16 +08:00
3 changed files with 96 additions and 0 deletions

View File

@@ -0,0 +1,53 @@
# PushT DiT No-Causal Compare Implementation Plan
> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking.
**Goal:** Add a PushT image DiT no-causal config, rerun the two prior DiT baselines for 350 epochs, and compare max `test_mean_score` plus batch-1 inference latency.
**Architecture:** Keep the existing causal DiT baselines unchanged and add a separate no-causal config that only flips `policy.causal_attn=false` while preserving the SwanLab naming safeguards. Launch the default DiT (`256x8`) locally and the `256x18` DiT on 5880 GPU0, then parse `logs.json.txt` and benchmark both checkpoints on the same hardware.
**Tech Stack:** Hydra, Diffusion Policy transformer image workspace, SwanLab, uv Python env, local 5090 + trusted remote 5880.
---
### Task 1: Add no-causal DiT config and config regression test
**Files:**
- Create: `image_pusht_diffusion_policy_dit_nocausal.yaml`
- Modify: `tests/test_pusht_swanlab_config.py`
- [ ] Write a failing test asserting the new no-causal DiT config uses SwanLab-safe naming and `policy.causal_attn == False`.
- [ ] Run the targeted pytest command and verify it fails because the config does not exist yet.
- [ ] Add the minimal new config by composing from the existing PushT DiT config and overriding only `policy.causal_attn=false`.
- [ ] Re-run the targeted pytest command and verify it passes.
### Task 2: Smoke-verify the new config
**Files:**
- Read: `image_pusht_diffusion_policy_dit_nocausal.yaml`
- [ ] Run `train.py --help` against the new config.
- [ ] Verify Hydra resolves the config without errors.
### Task 3: Launch the two 350-epoch no-causal reruns
**Files:**
- Write runtime scripts/logs under `data/run_logs/`
- Write outputs under `data/outputs/`
- [ ] Launch local run: `dit_nocausal_img_pusht_default_seed42_local` with 350 epochs.
- [ ] Launch remote run: `dit_nocausal_img_pusht_emb256_layer18_seed42_5880gpu0` with 350 epochs and `policy.n_layer=18`.
- [ ] Use explicit SwanLab overrides: unique `logging.name`, `logging.resume=false`, `logging.id=null`, shared group `dit_pusht_nocausal_compare`.
- [ ] Record pid files and launcher scripts.
### Task 4: Monitor and summarize
**Files:**
- Read: per-run `logs.json.txt`
- Read: checkpoints directories
- [ ] Monitor until both runs reach epoch 349 completion.
- [ ] Extract `max(test_mean_score)` and final logged `test_mean_score`.
- [ ] Identify the best checkpoint for each run.
- [ ] Benchmark batch-1 `policy.predict_action(obs)` latency on the same hardware.
- [ ] Report the final comparison table and short conclusion.

View File

@@ -0,0 +1,31 @@
defaults:
- diffusion_policy/config/train_diffusion_transformer_hybrid_workspace@_here_
- override /diffusion_policy/config/task@task: pusht_image
- _self_
exp_name: pusht_image_dit_nocausal
policy:
_target_: diffusion_policy.policy.diffusion_transformer_hybrid_image_policy.DiffusionTransformerHybridImagePolicy
causal_attn: false
logging:
backend: swanlab
mode: online
name: ${exp_name}
resume: false
tags: ["${name}", "${task_name}", "${exp_name}", "swanlab"]
id: null
group: ${exp_name}
dataloader:
num_workers: 0
val_dataloader:
num_workers: 0
task:
env_runner:
n_envs: 1
n_test_vis: 0
n_train_vis: 0

View File

@@ -30,3 +30,15 @@ def test_image_pusht_dit_imf_swanlab_config_uses_exp_name_and_no_resume_collisio
assert cfg.logging.resume is False
assert cfg.logging.id is None
assert cfg.logging.group == cfg.exp_name
def test_image_pusht_dit_nocausal_config_uses_exp_name_and_disables_causal_attention():
cfg = _load_cfg('image_pusht_diffusion_policy_dit_nocausal.yaml')
assert cfg.logging.backend == 'swanlab'
assert cfg.logging.mode == 'online'
assert cfg.logging.name == cfg.exp_name
assert cfg.logging.resume is False
assert cfg.logging.id is None
assert cfg.logging.group == cfg.exp_name
assert cfg.policy.causal_attn is False