Compare commits
20 Commits
749db2ce9c
...
feat/pusht
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
211abbb87f | ||
|
|
185ed6596c | ||
|
|
78ab18e8f3 | ||
|
|
484d008997 | ||
|
|
36fbf2a6b7 | ||
|
|
4cd5085b33 | ||
|
|
5e7ae6cfa5 | ||
|
|
23374a4cd2 | ||
|
|
15a0c41cbf | ||
|
|
ba6ede9425 | ||
|
|
08c1950c6d | ||
|
|
5ba07ac666 | ||
|
|
548a52bbb1 | ||
|
|
de4384e84a | ||
|
|
7dd9dc417a | ||
|
|
5aa9996fdc | ||
|
|
5c3d54fca3 | ||
|
|
a98e74873b | ||
|
|
68eef44d3e | ||
|
|
c52bac42ee |
68
AGENTS.md
Normal file
68
AGENTS.md
Normal file
@@ -0,0 +1,68 @@
|
||||
# Agent Notes
|
||||
|
||||
## Purpose
|
||||
`~/diffusion_policy` is the Diffusion Policy training repo. The main workflow here is Hydra-driven training via `train.py`, with the canonical PushT image experiment configured by `image_pusht_diffusion_policy_cnn.yaml`.
|
||||
|
||||
## Top Level
|
||||
- `diffusion_policy/`: core code, configs, datasets, env runners, workspaces.
|
||||
- `data/`: local datasets, outputs, checkpoints, run logs.
|
||||
- `train.py`: main training entrypoint.
|
||||
- `eval.py`: checkpoint evaluation entrypoint.
|
||||
- `image_pusht_diffusion_policy_cnn.yaml`: canonical single-seed PushT image config from the README path.
|
||||
- `.venv/`: local `uv`-managed virtualenv.
|
||||
- `.uv-cache/`, `.uv-python/`: local `uv` cache and Python install state.
|
||||
- `README.md`: upstream instructions and canonical commands.
|
||||
|
||||
## Canonical PushT Image Path
|
||||
- Entrypoint: `python train.py --config-dir=. --config-name=image_pusht_diffusion_policy_cnn.yaml`
|
||||
- Dataset path in config: `data/pusht/pusht_cchi_v7_replay.zarr`
|
||||
- README canonical device override: `training.device=cuda:0`
|
||||
|
||||
## Data
|
||||
- PushT archive currently present at `data/pusht.zip`
|
||||
- Unpacked dataset used by training: `data/pusht/pusht_cchi_v7_replay.zarr`
|
||||
|
||||
## Local Compatibility Adjustments
|
||||
- `diffusion_policy/env_runner/pusht_image_runner.py` now uses `SyncVectorEnv` instead of `AsyncVectorEnv`.
|
||||
Reason: avoid shared-memory and semaphore failures on this host/session.
|
||||
- `diffusion_policy/gym_util/sync_vector_env.py` has local compatibility changes:
|
||||
- added `reset_async`
|
||||
- seeded `reset_wait`
|
||||
- updated `concatenate(...)` call order for the current `gym` API
|
||||
|
||||
## Environment Expectations
|
||||
- Use the local `uv` env at `.venv`
|
||||
- Verified local Python: `3.9.25`
|
||||
- Verified local Torch stack: `torch 2.8.0+cu128`, `torchvision 0.23.0+cu128`
|
||||
- Other key installed versions verified in `.venv`:
|
||||
- `gym 0.23.1`
|
||||
- `hydra-core 1.2.0`
|
||||
- `diffusers 0.11.1`
|
||||
- `huggingface_hub 0.10.1`
|
||||
- `wandb 0.13.3`
|
||||
- `zarr 2.12.0`
|
||||
- `numcodecs 0.10.2`
|
||||
- `av 14.0.1`
|
||||
- Important note: this shell currently reports `torch.cuda.is_available() == False`, so always verify CUDA access in the current session before assuming GPU is usable.
|
||||
|
||||
## Logging And Outputs
|
||||
- Hydra run outputs: `data/outputs/...`
|
||||
- Per-run files to check first:
|
||||
- `.hydra/overrides.yaml`
|
||||
- `logs.json.txt`
|
||||
- `train.log`
|
||||
- `checkpoints/latest.ckpt`
|
||||
- Extra launcher logs may live under `data/run_logs/`
|
||||
|
||||
## Practical Guidance
|
||||
- Inspect with `rg`, `sed`, and existing Hydra output folders before changing code.
|
||||
- Prefer config overrides before code edits.
|
||||
- On this host, start from these safety overrides unless revalidated:
|
||||
- `logging.mode=offline`
|
||||
- `dataloader.num_workers=0`
|
||||
- `val_dataloader.num_workers=0`
|
||||
- `task.env_runner.n_envs=1`
|
||||
- `task.env_runner.n_test_vis=0`
|
||||
- `task.env_runner.n_train_vis=0`
|
||||
- If a run fails, inspect `.hydra/overrides.yaml`, then `logs.json.txt`, then `train.log`.
|
||||
- Avoid driver or system changes unless the repo-local path is clearly blocked.
|
||||
108
PUSHT_REPRO_5090.md
Normal file
108
PUSHT_REPRO_5090.md
Normal file
@@ -0,0 +1,108 @@
|
||||
# PushT Repro On 5090
|
||||
|
||||
## Goal
|
||||
Reproduce the canonical single-seed image PushT experiment from this repo in `~/diffusion_policy` using `image_pusht_diffusion_policy_cnn.yaml`.
|
||||
|
||||
## Current Verified Local Setup
|
||||
- Virtualenv: `./.venv` managed with `uv`
|
||||
- Python: `3.9.25`
|
||||
- Torch stack: `torch 2.8.0+cu128`, `torchvision 0.23.0+cu128`
|
||||
- Version strategy used here:
|
||||
- newer Torch/CUDA stack for current 5090-class hardware support
|
||||
- keep older repo-era packages where they are still required by the code
|
||||
- Verified key pins in `.venv`:
|
||||
- `numpy 1.26.4`
|
||||
- `gym 0.23.1`
|
||||
- `hydra-core 1.2.0`
|
||||
- `diffusers 0.11.1`
|
||||
- `huggingface_hub 0.10.1`
|
||||
- `wandb 0.13.3`
|
||||
- `zarr 2.12.0`
|
||||
- `numcodecs 0.10.2`
|
||||
- `av 14.0.1`
|
||||
- `robomimic 0.2.0`
|
||||
|
||||
## Dataset
|
||||
- README source: `https://diffusion-policy.cs.columbia.edu/data/training/pusht.zip`
|
||||
- Local archive currently present: `data/pusht.zip`
|
||||
- Unpacked dataset used by the config: `data/pusht/pusht_cchi_v7_replay.zarr`
|
||||
|
||||
## Repo-Local Code Adjustments
|
||||
- `diffusion_policy/env_runner/pusht_image_runner.py`
|
||||
- switched PushT image evaluation from `AsyncVectorEnv` to `SyncVectorEnv`
|
||||
- `diffusion_policy/gym_util/sync_vector_env.py`
|
||||
- added `reset_async`
|
||||
- added seeded `reset_wait`
|
||||
- updated `concatenate(...)` call order for current `gym`
|
||||
|
||||
These changes were needed to keep PushT evaluation working without the async shared-memory path.
|
||||
|
||||
## Validated GPU Smoke Command
|
||||
This route is verified by `data/outputs/gpu_smoke2___pusht_gpu_smoke`, which contains `logs.json.txt` plus checkpoints:
|
||||
|
||||
```bash
|
||||
.venv/bin/python train.py \
|
||||
--config-dir=. \
|
||||
--config-name=image_pusht_diffusion_policy_cnn.yaml \
|
||||
training.seed=42 \
|
||||
training.device=cuda:0 \
|
||||
logging.mode=offline \
|
||||
dataloader.num_workers=0 \
|
||||
val_dataloader.num_workers=0 \
|
||||
task.env_runner.n_envs=1 \
|
||||
training.debug=true \
|
||||
task.env_runner.n_test=2 \
|
||||
task.env_runner.n_test_vis=0 \
|
||||
task.env_runner.n_train=1 \
|
||||
task.env_runner.n_train_vis=0 \
|
||||
task.env_runner.max_steps=20
|
||||
```
|
||||
|
||||
## Practical Full Training Command Used Here
|
||||
This matches the longer GPU run under `data/outputs/2026.03.13/15.37.00_train_diffusion_unet_hybrid_pusht_image_gpu_seed42`:
|
||||
|
||||
```bash
|
||||
.venv/bin/python train.py \
|
||||
--config-dir=. \
|
||||
--config-name=image_pusht_diffusion_policy_cnn.yaml \
|
||||
training.seed=42 \
|
||||
training.device=cuda:0 \
|
||||
logging.mode=offline \
|
||||
dataloader.num_workers=0 \
|
||||
val_dataloader.num_workers=0 \
|
||||
task.env_runner.n_envs=1 \
|
||||
task.env_runner.n_test_vis=0 \
|
||||
task.env_runner.n_train_vis=0 \
|
||||
hydra.run.dir=data/outputs/2026.03.13/15.37.00_train_diffusion_unet_hybrid_pusht_image_gpu_seed42
|
||||
```
|
||||
|
||||
## Why These Overrides Were Used
|
||||
- `logging.mode=offline`
|
||||
- avoids needing a W&B login and still leaves local run metadata in the output dir
|
||||
- `dataloader.num_workers=0` and `val_dataloader.num_workers=0`
|
||||
- avoids extra multiprocessing on this host
|
||||
- `task.env_runner.n_envs=1`
|
||||
- keeps PushT eval on the serial `SyncVectorEnv` path
|
||||
- `task.env_runner.n_test_vis=0` and `task.env_runner.n_train_vis=0`
|
||||
- avoids video-writing issues on this stack
|
||||
- one earlier GPU run with default vis settings logged libav/libx264 `profile=high` errors in `data/outputs/_train_diffusion_unet_hybrid_pusht_image_gpu_seed42/train.log`
|
||||
|
||||
## Output Locations
|
||||
- Smoke run:
|
||||
- `data/outputs/gpu_smoke2___pusht_gpu_smoke`
|
||||
- Longer GPU run:
|
||||
- `data/outputs/2026.03.13/15.37.00_train_diffusion_unet_hybrid_pusht_image_gpu_seed42`
|
||||
- Files to inspect inside a run:
|
||||
- `.hydra/overrides.yaml`
|
||||
- `logs.json.txt`
|
||||
- `train.log`
|
||||
- `checkpoints/latest.ckpt`
|
||||
|
||||
## Known Caveats
|
||||
- The default config is still tuned for older assumptions:
|
||||
- `logging.mode=online`
|
||||
- `dataloader.num_workers=8`
|
||||
- `task.env_runner.n_envs=null`
|
||||
- `task.env_runner.n_test_vis=4`
|
||||
- `task.env_runner.n_train_vis=2`
|
||||
- In this shell, `torch.cuda.is_available()` currently reports `False` even though the repo contains validated GPU smoke/full run artifacts. Re-check device visibility in the current session before restarting a GPU run.
|
||||
35
README.md
35
README.md
@@ -202,6 +202,41 @@ data/outputs/2023.03.01/22.13.58_train_diffusion_unet_hybrid_pusht_image
|
||||
|
||||
7 directories, 16 files
|
||||
```
|
||||
### 🆕 Evaluate Pre-trained Checkpoints
|
||||
Download a checkpoint from the published training log folders, such as [https://diffusion-policy.cs.columbia.edu/data/experiments/low_dim/pusht/diffusion_policy_cnn/train_0/checkpoints/epoch=0550-test_mean_score=0.969.ckpt](https://diffusion-policy.cs.columbia.edu/data/experiments/low_dim/pusht/diffusion_policy_cnn/train_0/checkpoints/epoch=0550-test_mean_score=0.969.ckpt).
|
||||
|
||||
Run the evaluation script:
|
||||
```console
|
||||
(robodiff)[diffusion_policy]$ python eval.py --checkpoint data/0550-test_mean_score=0.969.ckpt --output_dir data/pusht_eval_output --device cuda:0
|
||||
```
|
||||
|
||||
This will generate the following directory structure:
|
||||
```console
|
||||
(robodiff)[diffusion_policy]$ tree data/pusht_eval_output
|
||||
data/pusht_eval_output
|
||||
├── eval_log.json
|
||||
└── media
|
||||
├── 1fxtno84.mp4
|
||||
├── 224l7jqd.mp4
|
||||
├── 2fo4btlf.mp4
|
||||
├── 2in4cn7a.mp4
|
||||
├── 34b3o2qq.mp4
|
||||
└── 3p7jqn32.mp4
|
||||
|
||||
1 directory, 7 files
|
||||
```
|
||||
|
||||
`eval_log.json` contains metrics that is logged to wandb during training:
|
||||
```console
|
||||
(robodiff)[diffusion_policy]$ cat data/pusht_eval_output/eval_log.json
|
||||
{
|
||||
"test/mean_score": 0.9150393806777066,
|
||||
"test/sim_max_reward_4300000": 1.0,
|
||||
"test/sim_max_reward_4300001": 0.9872969750774386,
|
||||
...
|
||||
"train/sim_video_1": "data/pusht_eval_output//media/2fo4btlf.mp4"
|
||||
}
|
||||
```
|
||||
|
||||
## 🦾 Demo, Training and Eval on a Real Robot
|
||||
Make sure your UR5 robot is running and accepting command from its network interface (emergency stop button within reach at all time), your RealSense cameras plugged in to your workstation (tested with `realsense-viewer`) and your SpaceMouse connected with the `spacenavd` daemon running (verify with `systemctl status spacenavd`).
|
||||
|
||||
@@ -46,6 +46,8 @@ dependencies:
|
||||
- diffusers=0.11.1
|
||||
- av=10.0.0
|
||||
- cmake=3.24.3
|
||||
# trick to avoid cpu affinity issue described in https://github.com/pytorch/pytorch/issues/99625
|
||||
- llvm-openmp=14
|
||||
# trick to force reinstall imagecodecs via pip
|
||||
- imagecodecs==2022.8.8
|
||||
- pip:
|
||||
|
||||
@@ -46,6 +46,8 @@ dependencies:
|
||||
- diffusers=0.11.1
|
||||
- av=10.0.0
|
||||
- cmake=3.24.3
|
||||
# trick to avoid cpu affinity issue described in https://github.com/pytorch/pytorch/issues/99625
|
||||
- llvm-openmp=14
|
||||
# trick to force reinstall imagecodecs via pip
|
||||
- imagecodecs==2022.8.8
|
||||
- pip:
|
||||
|
||||
@@ -158,6 +158,8 @@ class ReplayBuffer:
|
||||
# numpy backend
|
||||
meta = dict()
|
||||
for key, value in src_root['meta'].items():
|
||||
if isinstance(value, zarr.Group):
|
||||
continue
|
||||
if len(value.shape) == 0:
|
||||
meta[key] = np.array(value)
|
||||
else:
|
||||
|
||||
109
diffusion_policy/dataset/mujoco_image_dataset.py
Normal file
109
diffusion_policy/dataset/mujoco_image_dataset.py
Normal file
@@ -0,0 +1,109 @@
|
||||
from typing import Dict
|
||||
import torch
|
||||
import numpy as np
|
||||
import copy
|
||||
from diffusion_policy.common.pytorch_util import dict_apply
|
||||
from diffusion_policy.common.replay_buffer import ReplayBuffer
|
||||
from diffusion_policy.common.sampler import (
|
||||
SequenceSampler, get_val_mask, downsample_mask)
|
||||
from diffusion_policy.model.common.normalizer import LinearNormalizer
|
||||
from diffusion_policy.dataset.base_dataset import BaseImageDataset
|
||||
from diffusion_policy.common.normalize_util import get_image_range_normalizer
|
||||
|
||||
class MujocoImageDataset(BaseImageDataset):
|
||||
def __init__(self,
|
||||
zarr_path,
|
||||
horizon=1,
|
||||
pad_before=0,
|
||||
pad_after=0,
|
||||
seed=42,
|
||||
val_ratio=0.0,
|
||||
max_train_episodes=None
|
||||
):
|
||||
|
||||
super().__init__()
|
||||
self.replay_buffer = ReplayBuffer.copy_from_path(
|
||||
# zarr_path, keys=['img', 'state', 'action'])
|
||||
zarr_path, keys=['robot_0_camera_images', 'robot_0_tcp_xyz_wxyz', 'robot_0_gripper_width', 'action_0_tcp_xyz_wxyz', 'action_0_gripper_width'])
|
||||
val_mask = get_val_mask(
|
||||
n_episodes=self.replay_buffer.n_episodes,
|
||||
val_ratio=val_ratio,
|
||||
seed=seed)
|
||||
train_mask = ~val_mask
|
||||
train_mask = downsample_mask(
|
||||
mask=train_mask,
|
||||
max_n=max_train_episodes,
|
||||
seed=seed)
|
||||
|
||||
self.sampler = SequenceSampler(
|
||||
replay_buffer=self.replay_buffer,
|
||||
sequence_length=horizon,
|
||||
pad_before=pad_before,
|
||||
pad_after=pad_after,
|
||||
episode_mask=train_mask)
|
||||
self.train_mask = train_mask
|
||||
self.horizon = horizon
|
||||
self.pad_before = pad_before
|
||||
self.pad_after = pad_after
|
||||
|
||||
def get_validation_dataset(self):
|
||||
val_set = copy.copy(self)
|
||||
val_set.sampler = SequenceSampler(
|
||||
replay_buffer=self.replay_buffer,
|
||||
sequence_length=self.horizon,
|
||||
pad_before=self.pad_before,
|
||||
pad_after=self.pad_after,
|
||||
episode_mask=~self.train_mask
|
||||
)
|
||||
val_set.train_mask = ~self.train_mask
|
||||
return val_set
|
||||
|
||||
def get_normalizer(self, mode='limits', **kwargs):
|
||||
data = {
|
||||
'action': np.concatenate([self.replay_buffer['action_0_tcp_xyz_wxyz'], self.replay_buffer['action_0_gripper_width']], axis=-1),
|
||||
'agent_pos': np.concatenate([self.replay_buffer['robot_0_tcp_xyz_wxyz'], self.replay_buffer['robot_0_tcp_xyz_wxyz']], axis=-1)
|
||||
}
|
||||
normalizer = LinearNormalizer()
|
||||
normalizer.fit(data=data, last_n_dims=1, mode=mode, **kwargs)
|
||||
normalizer['image'] = get_image_range_normalizer()
|
||||
return normalizer
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.sampler)
|
||||
|
||||
def _sample_to_data(self, sample):
|
||||
# agent_pos = sample['state'][:,:2].astype(np.float32) # (agent_posx2, block_posex3)
|
||||
agent_pos = np.concatenate([sample['robot_0_tcp_xyz_wxyz'], sample['robot_0_gripper_width']], axis=-1).astype(np.float32)
|
||||
agent_action = np.concatenate([sample['action_0_tcp_xyz_wxyz'], sample['action_0_gripper_width']], axis=-1).astype(np.float32)
|
||||
# image = np.moveaxis(sample['img'],-1,1)/255
|
||||
image = np.moveaxis(sample['robot_0_camera_images'].astype(np.float32).squeeze(1),-1,1)/255
|
||||
|
||||
data = {
|
||||
'obs': {
|
||||
'image': image, # T, 3, 224, 224
|
||||
'agent_pos': agent_pos, # T, 8 (x,y,z,qx,qy,qz,qw,gripper_width)
|
||||
},
|
||||
'action': agent_action # T, 8 (x,y,z,qx,qy,qz,qw,gripper_width)
|
||||
}
|
||||
return data
|
||||
|
||||
def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
|
||||
sample = self.sampler.sample_sequence(idx)
|
||||
data = self._sample_to_data(sample)
|
||||
torch_data = dict_apply(data, torch.from_numpy)
|
||||
return torch_data
|
||||
|
||||
|
||||
def test():
|
||||
import os
|
||||
zarr_path = os.path.expanduser('/home/yihuai/robotics/repositories/mujoco/mujoco-env/data/collect_heuristic_data/2024-12-24_11-36-15_100episodes/merged_data.zarr')
|
||||
dataset = MujocoImageDataset(zarr_path, horizon=16)
|
||||
print(dataset[0])
|
||||
# from matplotlib import pyplot as plt
|
||||
# normalizer = dataset.get_normalizer()
|
||||
# nactions = normalizer['action'].normalize(dataset.replay_buffer['action'])
|
||||
# diff = np.diff(nactions, axis=0)
|
||||
# dists = np.linalg.norm(np.diff(nactions, axis=0), axis=-1)
|
||||
|
||||
if __name__ == '__main__':
|
||||
test()
|
||||
@@ -1,22 +1,37 @@
|
||||
import wandb
|
||||
import numpy as np
|
||||
import torch
|
||||
import collections
|
||||
import pathlib
|
||||
import tqdm
|
||||
import dill
|
||||
import math
|
||||
import wandb.sdk.data_types.video as wv
|
||||
from diffusion_policy.env.pusht.pusht_image_env import PushTImageEnv
|
||||
from diffusion_policy.gym_util.async_vector_env import AsyncVectorEnv
|
||||
# from diffusion_policy.gym_util.sync_vector_env import SyncVectorEnv
|
||||
from diffusion_policy.gym_util.sync_vector_env import SyncVectorEnv
|
||||
from diffusion_policy.gym_util.multistep_wrapper import MultiStepWrapper
|
||||
from diffusion_policy.gym_util.video_recording_wrapper import VideoRecordingWrapper, VideoRecorder
|
||||
|
||||
from diffusion_policy.policy.base_image_policy import BaseImagePolicy
|
||||
from diffusion_policy.common.pytorch_util import dict_apply
|
||||
from diffusion_policy.env_runner.base_image_runner import BaseImageRunner
|
||||
|
||||
|
||||
def summarize_rollout_metrics(env_seeds, env_prefixs, all_rewards, all_video_paths=None):
|
||||
del all_video_paths
|
||||
|
||||
max_rewards = collections.defaultdict(list)
|
||||
log_data = dict()
|
||||
for seed, prefix, rewards in zip(env_seeds, env_prefixs, all_rewards):
|
||||
max_reward = np.max(rewards)
|
||||
max_rewards[prefix].append(max_reward)
|
||||
log_data[prefix + f'sim_max_reward_{seed}'] = max_reward
|
||||
|
||||
aggregate_key_map = {
|
||||
'train/': 'train_mean_score',
|
||||
'test/': 'test_mean_score',
|
||||
}
|
||||
for prefix, value in max_rewards.items():
|
||||
log_data[aggregate_key_map.get(prefix, prefix + 'mean_score')] = np.mean(value)
|
||||
|
||||
return log_data
|
||||
|
||||
class PushTImageRunner(BaseImageRunner):
|
||||
def __init__(self,
|
||||
output_dir,
|
||||
@@ -41,25 +56,12 @@ class PushTImageRunner(BaseImageRunner):
|
||||
if n_envs is None:
|
||||
n_envs = n_train + n_test
|
||||
|
||||
steps_per_render = max(10 // fps, 1)
|
||||
def env_fn():
|
||||
return MultiStepWrapper(
|
||||
VideoRecordingWrapper(
|
||||
PushTImageEnv(
|
||||
legacy=legacy_test,
|
||||
render_size=render_size
|
||||
),
|
||||
video_recoder=VideoRecorder.create_h264(
|
||||
fps=fps,
|
||||
codec='h264',
|
||||
input_pix_fmt='rgb24',
|
||||
crf=crf,
|
||||
thread_type='FRAME',
|
||||
thread_count=1
|
||||
),
|
||||
file_path=None,
|
||||
steps_per_render=steps_per_render
|
||||
),
|
||||
n_obs_steps=n_obs_steps,
|
||||
n_action_steps=n_action_steps,
|
||||
max_episode_steps=max_steps
|
||||
@@ -72,21 +74,8 @@ class PushTImageRunner(BaseImageRunner):
|
||||
# train
|
||||
for i in range(n_train):
|
||||
seed = train_start_seed + i
|
||||
enable_render = i < n_train_vis
|
||||
|
||||
def init_fn(env, seed=seed, enable_render=enable_render):
|
||||
# setup rendering
|
||||
# video_wrapper
|
||||
assert isinstance(env.env, VideoRecordingWrapper)
|
||||
env.env.video_recoder.stop()
|
||||
env.env.file_path = None
|
||||
if enable_render:
|
||||
filename = pathlib.Path(output_dir).joinpath(
|
||||
'media', wv.util.generate_id() + ".mp4")
|
||||
filename.parent.mkdir(parents=False, exist_ok=True)
|
||||
filename = str(filename)
|
||||
env.env.file_path = filename
|
||||
|
||||
def init_fn(env, seed=seed):
|
||||
# set seed
|
||||
assert isinstance(env, MultiStepWrapper)
|
||||
env.seed(seed)
|
||||
@@ -98,21 +87,8 @@ class PushTImageRunner(BaseImageRunner):
|
||||
# test
|
||||
for i in range(n_test):
|
||||
seed = test_start_seed + i
|
||||
enable_render = i < n_test_vis
|
||||
|
||||
def init_fn(env, seed=seed, enable_render=enable_render):
|
||||
# setup rendering
|
||||
# video_wrapper
|
||||
assert isinstance(env.env, VideoRecordingWrapper)
|
||||
env.env.video_recoder.stop()
|
||||
env.env.file_path = None
|
||||
if enable_render:
|
||||
filename = pathlib.Path(output_dir).joinpath(
|
||||
'media', wv.util.generate_id() + ".mp4")
|
||||
filename.parent.mkdir(parents=False, exist_ok=True)
|
||||
filename = str(filename)
|
||||
env.env.file_path = filename
|
||||
|
||||
def init_fn(env, seed=seed):
|
||||
# set seed
|
||||
assert isinstance(env, MultiStepWrapper)
|
||||
env.seed(seed)
|
||||
@@ -121,7 +97,9 @@ class PushTImageRunner(BaseImageRunner):
|
||||
env_prefixs.append('test/')
|
||||
env_init_fn_dills.append(dill.dumps(init_fn))
|
||||
|
||||
env = AsyncVectorEnv(env_fns)
|
||||
# This environment can run without multiprocessing, which avoids
|
||||
# shared-memory and semaphore restrictions on some machines.
|
||||
env = SyncVectorEnv(env_fns)
|
||||
|
||||
# test env
|
||||
# env.reset(seed=env_seeds)
|
||||
@@ -153,7 +131,6 @@ class PushTImageRunner(BaseImageRunner):
|
||||
n_chunks = math.ceil(n_inits / n_envs)
|
||||
|
||||
# allocate data
|
||||
all_video_paths = [None] * n_inits
|
||||
all_rewards = [None] * n_inits
|
||||
|
||||
for chunk_idx in range(n_chunks):
|
||||
@@ -213,39 +190,16 @@ class PushTImageRunner(BaseImageRunner):
|
||||
pbar.update(action.shape[1])
|
||||
pbar.close()
|
||||
|
||||
all_video_paths[this_global_slice] = env.render()[this_local_slice]
|
||||
all_rewards[this_global_slice] = env.call('get_attr', 'reward')[this_local_slice]
|
||||
# clear out video buffer
|
||||
# reset env state between evaluation calls
|
||||
_ = env.reset()
|
||||
|
||||
# log
|
||||
max_rewards = collections.defaultdict(list)
|
||||
log_data = dict()
|
||||
# results reported in the paper are generated using the commented out line below
|
||||
# which will only report and average metrics from first n_envs initial condition and seeds
|
||||
# fortunately this won't invalidate our conclusion since
|
||||
# 1. This bug only affects the variance of metrics, not their mean
|
||||
# 2. All baseline methods are evaluated using the same code
|
||||
# to completely reproduce reported numbers, uncomment this line:
|
||||
# for i in range(len(self.env_fns)):
|
||||
# and comment out this line
|
||||
for i in range(n_inits):
|
||||
seed = self.env_seeds[i]
|
||||
prefix = self.env_prefixs[i]
|
||||
max_reward = np.max(all_rewards[i])
|
||||
max_rewards[prefix].append(max_reward)
|
||||
log_data[prefix+f'sim_max_reward_{seed}'] = max_reward
|
||||
|
||||
# visualize sim
|
||||
video_path = all_video_paths[i]
|
||||
if video_path is not None:
|
||||
sim_video = wandb.Video(video_path)
|
||||
log_data[prefix+f'sim_video_{seed}'] = sim_video
|
||||
|
||||
# log aggregate metrics
|
||||
for prefix, value in max_rewards.items():
|
||||
name = prefix+'mean_score'
|
||||
value = np.mean(value)
|
||||
log_data[name] = value
|
||||
|
||||
return log_data
|
||||
# results reported in the paper are generated using the commented out
|
||||
# line below, which would only report and average metrics from the
|
||||
# first n_envs initial conditions and seeds. We keep the full n_inits
|
||||
# behavior here.
|
||||
return summarize_rollout_metrics(
|
||||
env_seeds=self.env_seeds[:n_inits],
|
||||
env_prefixs=self.env_prefixs[:n_inits],
|
||||
all_rewards=all_rewards[:n_inits],
|
||||
)
|
||||
|
||||
@@ -60,17 +60,44 @@ class SyncVectorEnv(VectorEnv):
|
||||
for env, seed in zip(self.envs, seeds):
|
||||
env.seed(seed)
|
||||
|
||||
def reset_wait(self):
|
||||
def reset_async(self, seed=None, return_info=False, options=None):
|
||||
if seed is None:
|
||||
seeds = [None for _ in range(self.num_envs)]
|
||||
elif isinstance(seed, int):
|
||||
seeds = [seed + i for i in range(self.num_envs)]
|
||||
else:
|
||||
seeds = list(seed)
|
||||
assert len(seeds) == self.num_envs
|
||||
self._reset_seeds = seeds
|
||||
self._reset_return_info = return_info
|
||||
self._reset_options = options
|
||||
|
||||
def reset_wait(self, seed=None, return_info=False, options=None):
|
||||
seeds = getattr(self, '_reset_seeds', None)
|
||||
if seeds is None:
|
||||
if seed is None:
|
||||
seeds = [None for _ in range(self.num_envs)]
|
||||
elif isinstance(seed, int):
|
||||
seeds = [seed + i for i in range(self.num_envs)]
|
||||
else:
|
||||
seeds = list(seed)
|
||||
self._dones[:] = False
|
||||
observations = []
|
||||
for env in self.envs:
|
||||
infos = []
|
||||
for env, seed_i in zip(self.envs, seeds):
|
||||
if seed_i is not None:
|
||||
env.seed(seed_i)
|
||||
observation = env.reset()
|
||||
observations.append(observation)
|
||||
infos.append({})
|
||||
self.observations = concatenate(
|
||||
observations, self.observations, self.single_observation_space
|
||||
self.single_observation_space, observations, self.observations
|
||||
)
|
||||
|
||||
return deepcopy(self.observations) if self.copy else self.observations
|
||||
obs = deepcopy(self.observations) if self.copy else self.observations
|
||||
if return_info:
|
||||
return obs, infos
|
||||
return obs
|
||||
|
||||
def step_async(self, actions):
|
||||
self._actions = actions
|
||||
@@ -84,7 +111,7 @@ class SyncVectorEnv(VectorEnv):
|
||||
observations.append(observation)
|
||||
infos.append(info)
|
||||
self.observations = concatenate(
|
||||
observations, self.observations, self.single_observation_space
|
||||
self.single_observation_space, observations, self.observations
|
||||
)
|
||||
|
||||
return (
|
||||
|
||||
@@ -40,7 +40,7 @@ class RotationTransformer:
|
||||
getattr(pt, f'matrix_to_{from_rep}')
|
||||
]
|
||||
if from_convention is not None:
|
||||
funcs = [functools.partial(func, convernsion=from_convention)
|
||||
funcs = [functools.partial(func, convention=from_convention)
|
||||
for func in funcs]
|
||||
forward_funcs.append(funcs[0])
|
||||
inverse_funcs.append(funcs[1])
|
||||
@@ -51,7 +51,7 @@ class RotationTransformer:
|
||||
getattr(pt, f'{to_rep}_to_matrix')
|
||||
]
|
||||
if to_convention is not None:
|
||||
funcs = [functools.partial(func, convernsion=to_convention)
|
||||
funcs = [functools.partial(func, convention=to_convention)
|
||||
for func in funcs]
|
||||
forward_funcs.append(funcs[0])
|
||||
inverse_funcs.append(funcs[1])
|
||||
|
||||
@@ -0,0 +1,247 @@
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, dim: int, eps: float = 1e-6) -> None:
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.weight = nn.Parameter(torch.ones(dim))
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
rms = torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + self.eps)
|
||||
return (x.float() * rms).to(x.dtype) * self.weight
|
||||
|
||||
|
||||
class RMSNormNoWeight(nn.Module):
|
||||
def __init__(self, eps: float = 1e-6) -> None:
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
rms = torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + self.eps)
|
||||
return (x.float() * rms).to(x.dtype)
|
||||
|
||||
|
||||
def precompute_rope_freqs(
|
||||
dim: int,
|
||||
max_seq_len: int,
|
||||
theta: float = 10000.0,
|
||||
device: Optional[torch.device] = None,
|
||||
) -> Tensor:
|
||||
if dim % 2 != 0:
|
||||
raise ValueError(f'RoPE requires an even head dimension, got {dim}.')
|
||||
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, device=device).float() / dim))
|
||||
positions = torch.arange(max_seq_len, device=device).float()
|
||||
angles = torch.outer(positions, freqs)
|
||||
return torch.polar(torch.ones_like(angles), angles)
|
||||
|
||||
|
||||
def apply_rope(x: Tensor, freqs: Tensor) -> Tensor:
|
||||
x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
|
||||
freqs = freqs.unsqueeze(0).unsqueeze(2)
|
||||
x_rotated = x_complex * freqs
|
||||
return torch.view_as_real(x_rotated).reshape_as(x).to(x.dtype)
|
||||
|
||||
|
||||
class GroupedQuerySelfAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
d_model: int,
|
||||
n_heads: int,
|
||||
n_kv_heads: int,
|
||||
dropout: float = 0.0,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
if d_model % n_heads != 0:
|
||||
raise ValueError(f'd_model={d_model} must be divisible by n_heads={n_heads}.')
|
||||
if n_heads % n_kv_heads != 0:
|
||||
raise ValueError(f'n_heads={n_heads} must be divisible by n_kv_heads={n_kv_heads}.')
|
||||
|
||||
self.d_model = d_model
|
||||
self.n_heads = n_heads
|
||||
self.n_kv_heads = n_kv_heads
|
||||
self.n_kv_groups = n_heads // n_kv_heads
|
||||
self.d_head = d_model // n_heads
|
||||
self.attn_dropout = nn.Dropout(dropout)
|
||||
self.out_dropout = nn.Dropout(dropout)
|
||||
|
||||
self.w_q = nn.Linear(d_model, n_heads * self.d_head, bias=False)
|
||||
self.w_k = nn.Linear(d_model, n_kv_heads * self.d_head, bias=False)
|
||||
self.w_v = nn.Linear(d_model, n_kv_heads * self.d_head, bias=False)
|
||||
self.w_o = nn.Linear(n_heads * self.d_head, d_model, bias=False)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: Tensor,
|
||||
rope_freqs: Tensor,
|
||||
mask: Optional[Tensor] = None,
|
||||
) -> Tensor:
|
||||
batch_size, seq_len, _ = x.shape
|
||||
|
||||
q = self.w_q(x).view(batch_size, seq_len, self.n_heads, self.d_head)
|
||||
k = self.w_k(x).view(batch_size, seq_len, self.n_kv_heads, self.d_head)
|
||||
v = self.w_v(x).view(batch_size, seq_len, self.n_kv_heads, self.d_head)
|
||||
|
||||
q = apply_rope(q, rope_freqs)
|
||||
k = apply_rope(k, rope_freqs)
|
||||
|
||||
if self.n_kv_heads != self.n_heads:
|
||||
k = k.unsqueeze(3).expand(
|
||||
batch_size, seq_len, self.n_kv_heads, self.n_kv_groups, self.d_head
|
||||
)
|
||||
k = k.reshape(batch_size, seq_len, self.n_heads, self.d_head)
|
||||
v = v.unsqueeze(3).expand(
|
||||
batch_size, seq_len, self.n_kv_heads, self.n_kv_groups, self.d_head
|
||||
)
|
||||
v = v.reshape(batch_size, seq_len, self.n_heads, self.d_head)
|
||||
|
||||
q = q.transpose(1, 2)
|
||||
k = k.transpose(1, 2)
|
||||
v = v.transpose(1, 2)
|
||||
|
||||
scale = 1.0 / math.sqrt(self.d_head)
|
||||
attn_weights = torch.matmul(q, k.transpose(-2, -1)) * scale
|
||||
if mask is not None:
|
||||
attn_weights = attn_weights + mask
|
||||
attn_weights = F.softmax(attn_weights, dim=-1)
|
||||
attn_weights = self.attn_dropout(attn_weights)
|
||||
|
||||
out = torch.matmul(attn_weights, v)
|
||||
out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
|
||||
return self.out_dropout(self.w_o(out))
|
||||
|
||||
|
||||
class SwiGLUFFN(nn.Module):
|
||||
def __init__(self, d_model: int, dropout: float = 0.0, mult: float = 2.667) -> None:
|
||||
super().__init__()
|
||||
raw = int(mult * d_model)
|
||||
d_ff = ((raw + 7) // 8) * 8
|
||||
self.w_gate = nn.Linear(d_model, d_ff, bias=False)
|
||||
self.w_up = nn.Linear(d_model, d_ff, bias=False)
|
||||
self.w_down = nn.Linear(d_ff, d_model, bias=False)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return self.dropout(self.w_down(F.silu(self.w_gate(x)) * self.w_up(x)))
|
||||
|
||||
|
||||
class AttnResOperator(nn.Module):
|
||||
def __init__(self, d_model: int, eps: float = 1e-6) -> None:
|
||||
super().__init__()
|
||||
self.pseudo_query = nn.Parameter(torch.zeros(d_model))
|
||||
self.key_norm = RMSNormNoWeight(eps=eps)
|
||||
|
||||
def forward(self, sources: Tensor) -> Tensor:
|
||||
keys = self.key_norm(sources)
|
||||
logits = torch.einsum('d,nbtd->nbt', self.pseudo_query, keys)
|
||||
weights = F.softmax(logits, dim=0)
|
||||
return torch.einsum('nbt,nbtd->btd', weights, sources)
|
||||
|
||||
|
||||
class AttnResSubLayer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
d_model: int,
|
||||
n_heads: int,
|
||||
n_kv_heads: int,
|
||||
dropout: float,
|
||||
ffn_mult: float,
|
||||
eps: float,
|
||||
is_attention: bool,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.norm = RMSNorm(d_model, eps=eps)
|
||||
self.attn_res = AttnResOperator(d_model, eps=eps)
|
||||
self.is_attention = is_attention
|
||||
if is_attention:
|
||||
self.fn = GroupedQuerySelfAttention(
|
||||
d_model=d_model,
|
||||
n_heads=n_heads,
|
||||
n_kv_heads=n_kv_heads,
|
||||
dropout=dropout,
|
||||
)
|
||||
else:
|
||||
self.fn = SwiGLUFFN(d_model=d_model, dropout=dropout, mult=ffn_mult)
|
||||
|
||||
def forward(self, sources: Tensor, rope_freqs: Tensor, mask: Optional[Tensor] = None) -> Tensor:
|
||||
h = self.attn_res(sources)
|
||||
normed = self.norm(h)
|
||||
if self.is_attention:
|
||||
return self.fn(normed, rope_freqs, mask)
|
||||
return self.fn(normed)
|
||||
|
||||
|
||||
class AttnResTransformerBackbone(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
d_model: int,
|
||||
n_blocks: int,
|
||||
n_heads: int,
|
||||
n_kv_heads: int,
|
||||
max_seq_len: int,
|
||||
dropout: float = 0.0,
|
||||
ffn_mult: float = 2.667,
|
||||
eps: float = 1e-6,
|
||||
rope_theta: float = 10000.0,
|
||||
causal_attn: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.causal_attn = causal_attn
|
||||
self.layers = nn.ModuleList()
|
||||
for _ in range(n_blocks):
|
||||
self.layers.append(
|
||||
AttnResSubLayer(
|
||||
d_model=d_model,
|
||||
n_heads=n_heads,
|
||||
n_kv_heads=n_kv_heads,
|
||||
dropout=dropout,
|
||||
ffn_mult=ffn_mult,
|
||||
eps=eps,
|
||||
is_attention=True,
|
||||
)
|
||||
)
|
||||
self.layers.append(
|
||||
AttnResSubLayer(
|
||||
d_model=d_model,
|
||||
n_heads=n_heads,
|
||||
n_kv_heads=n_kv_heads,
|
||||
dropout=dropout,
|
||||
ffn_mult=ffn_mult,
|
||||
eps=eps,
|
||||
is_attention=False,
|
||||
)
|
||||
)
|
||||
|
||||
rope_freqs = precompute_rope_freqs(
|
||||
dim=d_model // n_heads,
|
||||
max_seq_len=max_seq_len,
|
||||
theta=rope_theta,
|
||||
)
|
||||
self.register_buffer('rope_freqs', rope_freqs, persistent=False)
|
||||
|
||||
@staticmethod
|
||||
def _build_causal_mask(seq_len: int, device: torch.device) -> Tensor:
|
||||
mask = torch.full((seq_len, seq_len), float('-inf'), device=device)
|
||||
mask = torch.triu(mask, diagonal=1)
|
||||
return mask.unsqueeze(0).unsqueeze(0)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
seq_len = x.shape[1]
|
||||
rope_freqs = self.rope_freqs[:seq_len]
|
||||
mask = None
|
||||
if self.causal_attn:
|
||||
mask = self._build_causal_mask(seq_len, x.device)
|
||||
|
||||
layer_outputs = [x]
|
||||
for layer in self.layers:
|
||||
sources = torch.stack(layer_outputs, dim=0)
|
||||
output = layer(sources, rope_freqs, mask)
|
||||
layer_outputs.append(output)
|
||||
|
||||
return torch.stack(layer_outputs, dim=0).sum(dim=0)
|
||||
@@ -0,0 +1,383 @@
|
||||
from typing import Optional, Tuple, Union
|
||||
import logging
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from diffusion_policy.model.common.module_attr_mixin import ModuleAttrMixin
|
||||
from diffusion_policy.model.diffusion.attnres_transformer_components import (
|
||||
AttnResOperator,
|
||||
AttnResSubLayer,
|
||||
AttnResTransformerBackbone,
|
||||
GroupedQuerySelfAttention,
|
||||
RMSNorm,
|
||||
RMSNormNoWeight,
|
||||
SwiGLUFFN,
|
||||
)
|
||||
from diffusion_policy.model.diffusion.positional_embedding import SinusoidalPosEmb
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class IMFTransformerForDiffusion(ModuleAttrMixin):
|
||||
def __init__(
|
||||
self,
|
||||
input_dim: int,
|
||||
output_dim: int,
|
||||
horizon: int,
|
||||
n_obs_steps: int = None,
|
||||
cond_dim: int = 0,
|
||||
n_layer: int = 12,
|
||||
n_head: int = 1,
|
||||
n_emb: int = 768,
|
||||
p_drop_emb: float = 0.1,
|
||||
p_drop_attn: float = 0.1,
|
||||
causal_attn: bool = False,
|
||||
time_as_cond: bool = True,
|
||||
obs_as_cond: bool = False,
|
||||
n_cond_layers: int = 0,
|
||||
backbone_type: str = 'vanilla',
|
||||
n_kv_head: int = 1,
|
||||
attn_res_ffn_mult: float = 2.667,
|
||||
attn_res_eps: float = 1e-6,
|
||||
attn_res_rope_theta: float = 10000.0,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
assert n_head == 1, 'IMFTransformerForDiffusion currently supports single-head attention only.'
|
||||
if n_obs_steps is None:
|
||||
n_obs_steps = horizon
|
||||
|
||||
self.backbone_type = backbone_type
|
||||
|
||||
T = horizon
|
||||
T_cond = 2
|
||||
if not time_as_cond:
|
||||
T += 2
|
||||
T_cond -= 2
|
||||
obs_as_cond = cond_dim > 0
|
||||
if obs_as_cond:
|
||||
assert time_as_cond
|
||||
T_cond += n_obs_steps
|
||||
|
||||
self.input_emb = nn.Linear(input_dim, n_emb)
|
||||
self.drop = nn.Dropout(p_drop_emb)
|
||||
self.time_emb = SinusoidalPosEmb(n_emb)
|
||||
self.cond_obs_emb = nn.Linear(cond_dim, n_emb) if obs_as_cond else None
|
||||
self.time_token_proj = None
|
||||
self.cond_pos_emb = None
|
||||
self.pos_emb = None
|
||||
self.encoder = None
|
||||
self.decoder = None
|
||||
self.attnres_backbone = None
|
||||
encoder_only = False
|
||||
|
||||
if backbone_type == 'attnres_full':
|
||||
if not time_as_cond:
|
||||
raise ValueError('attnres_full backbone requires time_as_cond=True.')
|
||||
if n_cond_layers != 0:
|
||||
raise ValueError('attnres_full backbone does not support n_cond_layers > 0.')
|
||||
|
||||
self.time_token_proj = nn.Linear(n_emb, n_emb)
|
||||
self.attnres_backbone = AttnResTransformerBackbone(
|
||||
d_model=n_emb,
|
||||
n_blocks=n_layer,
|
||||
n_heads=n_head,
|
||||
n_kv_heads=n_kv_head,
|
||||
max_seq_len=T + T_cond,
|
||||
dropout=p_drop_attn,
|
||||
ffn_mult=attn_res_ffn_mult,
|
||||
eps=attn_res_eps,
|
||||
rope_theta=attn_res_rope_theta,
|
||||
causal_attn=causal_attn,
|
||||
)
|
||||
self.ln_f = RMSNorm(n_emb, eps=attn_res_eps)
|
||||
else:
|
||||
self.pos_emb = nn.Parameter(torch.zeros(1, T, n_emb))
|
||||
if T_cond > 0:
|
||||
self.cond_pos_emb = nn.Parameter(torch.zeros(1, T_cond, n_emb))
|
||||
if n_cond_layers > 0:
|
||||
encoder_layer = nn.TransformerEncoderLayer(
|
||||
d_model=n_emb,
|
||||
nhead=n_head,
|
||||
dim_feedforward=4 * n_emb,
|
||||
dropout=p_drop_attn,
|
||||
activation='gelu',
|
||||
batch_first=True,
|
||||
norm_first=True,
|
||||
)
|
||||
self.encoder = nn.TransformerEncoder(
|
||||
encoder_layer=encoder_layer,
|
||||
num_layers=n_cond_layers,
|
||||
)
|
||||
else:
|
||||
self.encoder = nn.Sequential(
|
||||
nn.Linear(n_emb, 4 * n_emb),
|
||||
nn.Mish(),
|
||||
nn.Linear(4 * n_emb, n_emb),
|
||||
)
|
||||
|
||||
decoder_layer = nn.TransformerDecoderLayer(
|
||||
d_model=n_emb,
|
||||
nhead=n_head,
|
||||
dim_feedforward=4 * n_emb,
|
||||
dropout=p_drop_attn,
|
||||
activation='gelu',
|
||||
batch_first=True,
|
||||
norm_first=True,
|
||||
)
|
||||
self.decoder = nn.TransformerDecoder(
|
||||
decoder_layer=decoder_layer,
|
||||
num_layers=n_layer,
|
||||
)
|
||||
else:
|
||||
encoder_only = True
|
||||
encoder_layer = nn.TransformerEncoderLayer(
|
||||
d_model=n_emb,
|
||||
nhead=n_head,
|
||||
dim_feedforward=4 * n_emb,
|
||||
dropout=p_drop_attn,
|
||||
activation='gelu',
|
||||
batch_first=True,
|
||||
norm_first=True,
|
||||
)
|
||||
self.encoder = nn.TransformerEncoder(
|
||||
encoder_layer=encoder_layer,
|
||||
num_layers=n_layer,
|
||||
)
|
||||
|
||||
self.ln_f = nn.LayerNorm(n_emb)
|
||||
|
||||
if causal_attn and backbone_type != 'attnres_full':
|
||||
sz = T
|
||||
mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
|
||||
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
|
||||
self.register_buffer('mask', mask)
|
||||
|
||||
if time_as_cond and obs_as_cond:
|
||||
S = T_cond
|
||||
t_idx, s_idx = torch.meshgrid(
|
||||
torch.arange(T),
|
||||
torch.arange(S),
|
||||
indexing='ij',
|
||||
)
|
||||
mask = t_idx >= (s_idx - 2)
|
||||
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
|
||||
self.register_buffer('memory_mask', mask)
|
||||
else:
|
||||
self.memory_mask = None
|
||||
else:
|
||||
self.mask = None
|
||||
self.memory_mask = None
|
||||
|
||||
self.head = nn.Linear(n_emb, output_dim)
|
||||
|
||||
self.T = T
|
||||
self.T_cond = T_cond
|
||||
self.horizon = horizon
|
||||
self.time_as_cond = time_as_cond
|
||||
self.obs_as_cond = obs_as_cond
|
||||
self.encoder_only = encoder_only
|
||||
|
||||
self.apply(self._init_weights)
|
||||
logger.info(
|
||||
'number of parameters: %e',
|
||||
sum(p.numel() for p in self.parameters()),
|
||||
)
|
||||
|
||||
def _init_weights(self, module):
|
||||
ignore_types = (
|
||||
nn.Dropout,
|
||||
SinusoidalPosEmb,
|
||||
nn.TransformerEncoderLayer,
|
||||
nn.TransformerDecoderLayer,
|
||||
nn.TransformerEncoder,
|
||||
nn.TransformerDecoder,
|
||||
nn.ModuleList,
|
||||
nn.Mish,
|
||||
nn.Sequential,
|
||||
AttnResTransformerBackbone,
|
||||
AttnResSubLayer,
|
||||
GroupedQuerySelfAttention,
|
||||
SwiGLUFFN,
|
||||
RMSNormNoWeight,
|
||||
)
|
||||
if isinstance(module, (nn.Linear, nn.Embedding)):
|
||||
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
||||
if isinstance(module, nn.Linear) and module.bias is not None:
|
||||
torch.nn.init.zeros_(module.bias)
|
||||
elif isinstance(module, nn.MultiheadAttention):
|
||||
weight_names = ['in_proj_weight', 'q_proj_weight', 'k_proj_weight', 'v_proj_weight']
|
||||
for name in weight_names:
|
||||
weight = getattr(module, name)
|
||||
if weight is not None:
|
||||
torch.nn.init.normal_(weight, mean=0.0, std=0.02)
|
||||
|
||||
bias_names = ['in_proj_bias', 'bias_k', 'bias_v']
|
||||
for name in bias_names:
|
||||
bias = getattr(module, name)
|
||||
if bias is not None:
|
||||
torch.nn.init.zeros_(bias)
|
||||
elif isinstance(module, (nn.LayerNorm, RMSNorm)):
|
||||
if getattr(module, 'bias', None) is not None:
|
||||
torch.nn.init.zeros_(module.bias)
|
||||
torch.nn.init.ones_(module.weight)
|
||||
elif isinstance(module, AttnResOperator):
|
||||
torch.nn.init.zeros_(module.pseudo_query)
|
||||
elif isinstance(module, IMFTransformerForDiffusion):
|
||||
if module.pos_emb is not None:
|
||||
torch.nn.init.normal_(module.pos_emb, mean=0.0, std=0.02)
|
||||
if module.cond_pos_emb is not None:
|
||||
torch.nn.init.normal_(module.cond_pos_emb, mean=0.0, std=0.02)
|
||||
elif isinstance(module, ignore_types):
|
||||
pass
|
||||
else:
|
||||
raise RuntimeError(f'Unaccounted module {module}')
|
||||
|
||||
def get_optim_groups(self, weight_decay: float = 1e-3):
|
||||
decay = set()
|
||||
no_decay = set()
|
||||
whitelist_weight_modules = (torch.nn.Linear, torch.nn.MultiheadAttention)
|
||||
blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding, RMSNorm)
|
||||
for mn, m in self.named_modules():
|
||||
for pn, _ in m.named_parameters(recurse=False):
|
||||
fpn = '%s.%s' % (mn, pn) if mn else pn
|
||||
|
||||
if pn.endswith('bias'):
|
||||
no_decay.add(fpn)
|
||||
elif pn.startswith('bias'):
|
||||
no_decay.add(fpn)
|
||||
elif pn == 'pseudo_query':
|
||||
no_decay.add(fpn)
|
||||
elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
|
||||
decay.add(fpn)
|
||||
elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
|
||||
no_decay.add(fpn)
|
||||
|
||||
if self.pos_emb is not None:
|
||||
no_decay.add('pos_emb')
|
||||
no_decay.add('_dummy_variable')
|
||||
if self.cond_pos_emb is not None:
|
||||
no_decay.add('cond_pos_emb')
|
||||
|
||||
param_dict = {pn: p for pn, p in self.named_parameters()}
|
||||
inter_params = decay & no_decay
|
||||
union_params = decay | no_decay
|
||||
assert len(inter_params) == 0, f'parameters {inter_params} made it into both decay/no_decay sets!'
|
||||
assert len(param_dict.keys() - union_params) == 0, (
|
||||
f'parameters {param_dict.keys() - union_params} were not separated into either decay/no_decay sets!'
|
||||
)
|
||||
|
||||
optim_groups = [
|
||||
{
|
||||
'params': [param_dict[pn] for pn in sorted(list(decay))],
|
||||
'weight_decay': weight_decay,
|
||||
},
|
||||
{
|
||||
'params': [param_dict[pn] for pn in sorted(list(no_decay))],
|
||||
'weight_decay': 0.0,
|
||||
},
|
||||
]
|
||||
return optim_groups
|
||||
|
||||
def configure_optimizers(
|
||||
self,
|
||||
learning_rate: float = 1e-4,
|
||||
weight_decay: float = 1e-3,
|
||||
betas: Tuple[float, float] = (0.9, 0.95),
|
||||
):
|
||||
optim_groups = self.get_optim_groups(weight_decay=weight_decay)
|
||||
optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas)
|
||||
return optimizer
|
||||
|
||||
def _prepare_time_input(self, value: Union[torch.Tensor, float, int], sample: torch.Tensor) -> torch.Tensor:
|
||||
if not torch.is_tensor(value):
|
||||
value = torch.tensor([value], dtype=sample.dtype, device=sample.device)
|
||||
elif value.ndim == 0:
|
||||
value = value[None].to(device=sample.device, dtype=sample.dtype)
|
||||
else:
|
||||
value = value.to(device=sample.device, dtype=sample.dtype)
|
||||
return value.expand(sample.shape[0])
|
||||
|
||||
def _forward_attnres_full(
|
||||
self,
|
||||
sample: torch.Tensor,
|
||||
r: torch.Tensor,
|
||||
t: torch.Tensor,
|
||||
cond: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
sample_tokens = self.input_emb(sample)
|
||||
token_parts = [
|
||||
self.time_token_proj(self.time_emb(r)).unsqueeze(1),
|
||||
self.time_token_proj(self.time_emb(t)).unsqueeze(1),
|
||||
]
|
||||
if self.obs_as_cond:
|
||||
if cond is None:
|
||||
raise ValueError('cond is required when obs_as_cond=True for attnres_full backbone.')
|
||||
token_parts.append(self.cond_obs_emb(cond))
|
||||
token_parts.append(sample_tokens)
|
||||
x = torch.cat(token_parts, dim=1)
|
||||
x = self.drop(x)
|
||||
x = self.attnres_backbone(x)
|
||||
x = x[:, -sample_tokens.shape[1] :, :]
|
||||
return x
|
||||
|
||||
def _forward_vanilla(
|
||||
self,
|
||||
sample: torch.Tensor,
|
||||
r: torch.Tensor,
|
||||
t: torch.Tensor,
|
||||
cond: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
r_emb = self.time_emb(r).unsqueeze(1)
|
||||
t_emb = self.time_emb(t).unsqueeze(1)
|
||||
input_emb = self.input_emb(sample)
|
||||
|
||||
if self.encoder_only:
|
||||
token_embeddings = torch.cat([r_emb, t_emb, input_emb], dim=1)
|
||||
token_count = token_embeddings.shape[1]
|
||||
position_embeddings = self.pos_emb[:, :token_count, :]
|
||||
x = self.drop(token_embeddings + position_embeddings)
|
||||
x = self.encoder(src=x, mask=self.mask)
|
||||
x = x[:, 2:, :]
|
||||
else:
|
||||
cond_embeddings = torch.cat([r_emb, t_emb], dim=1)
|
||||
if self.obs_as_cond:
|
||||
cond_obs_emb = self.cond_obs_emb(cond)
|
||||
cond_embeddings = torch.cat([cond_embeddings, cond_obs_emb], dim=1)
|
||||
token_count = cond_embeddings.shape[1]
|
||||
position_embeddings = self.cond_pos_emb[:, :token_count, :]
|
||||
x = self.drop(cond_embeddings + position_embeddings)
|
||||
x = self.encoder(x)
|
||||
memory = x
|
||||
|
||||
token_embeddings = input_emb
|
||||
token_count = token_embeddings.shape[1]
|
||||
position_embeddings = self.pos_emb[:, :token_count, :]
|
||||
x = self.drop(token_embeddings + position_embeddings)
|
||||
x = self.decoder(
|
||||
tgt=x,
|
||||
memory=memory,
|
||||
tgt_mask=self.mask,
|
||||
memory_mask=self.memory_mask,
|
||||
)
|
||||
return x
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.Tensor,
|
||||
r: Union[torch.Tensor, float, int],
|
||||
t: Union[torch.Tensor, float, int],
|
||||
cond: Optional[torch.Tensor] = None,
|
||||
):
|
||||
r = self._prepare_time_input(r, sample)
|
||||
t = self._prepare_time_input(t, sample)
|
||||
|
||||
if self.backbone_type == 'attnres_full':
|
||||
x = self._forward_attnres_full(sample, r, t, cond=cond)
|
||||
else:
|
||||
x = self._forward_vanilla(sample, r, t, cond=cond)
|
||||
|
||||
x = self.ln_f(x)
|
||||
x = self.head(x)
|
||||
return x
|
||||
@@ -256,8 +256,8 @@ class DiffusionTransformerHybridImagePolicy(BaseImagePolicy):
|
||||
# condition through impainting
|
||||
this_nobs = dict_apply(nobs, lambda x: x[:,:To,...].reshape(-1,*x.shape[2:]))
|
||||
nobs_features = self.obs_encoder(this_nobs)
|
||||
# reshape back to B, T, Do
|
||||
nobs_features = nobs_features.reshape(B, T, -1)
|
||||
# reshape back to B, To, Do
|
||||
nobs_features = nobs_features.reshape(B, To, -1)
|
||||
shape = (B, T, Da+Do)
|
||||
cond_data = torch.zeros(size=shape, device=device, dtype=dtype)
|
||||
cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)
|
||||
|
||||
@@ -247,7 +247,7 @@ class DiffusionUnetHybridImagePolicy(BaseImagePolicy):
|
||||
# condition through impainting
|
||||
this_nobs = dict_apply(nobs, lambda x: x[:,:To,...].reshape(-1,*x.shape[2:]))
|
||||
nobs_features = self.obs_encoder(this_nobs)
|
||||
# reshape back to B, T, Do
|
||||
# reshape back to B, To, Do
|
||||
nobs_features = nobs_features.reshape(B, To, -1)
|
||||
cond_data = torch.zeros(size=(B, T, Da+Do), device=device, dtype=dtype)
|
||||
cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)
|
||||
|
||||
283
diffusion_policy/policy/imf_transformer_hybrid_image_policy.py
Normal file
283
diffusion_policy/policy/imf_transformer_hybrid_image_policy.py
Normal file
@@ -0,0 +1,283 @@
|
||||
from contextlib import nullcontext
|
||||
from typing import Dict, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from einops import reduce
|
||||
|
||||
from diffusion_policy.common.pytorch_util import dict_apply
|
||||
from diffusion_policy.model.diffusion.imf_transformer_for_diffusion import IMFTransformerForDiffusion
|
||||
from diffusion_policy.policy.diffusion_transformer_hybrid_image_policy import (
|
||||
DiffusionTransformerHybridImagePolicy,
|
||||
)
|
||||
|
||||
try:
|
||||
from torch.func import jvp as TORCH_FUNC_JVP
|
||||
except ImportError: # pragma: no cover - depends on torch version
|
||||
TORCH_FUNC_JVP = None
|
||||
|
||||
|
||||
class IMFTransformerHybridImagePolicy(DiffusionTransformerHybridImagePolicy):
|
||||
def __init__(
|
||||
self,
|
||||
shape_meta: dict,
|
||||
noise_scheduler,
|
||||
horizon,
|
||||
n_action_steps,
|
||||
n_obs_steps,
|
||||
num_inference_steps=None,
|
||||
crop_shape=(76, 76),
|
||||
obs_encoder_group_norm=False,
|
||||
eval_fixed_crop=False,
|
||||
n_layer=8,
|
||||
n_cond_layers=0,
|
||||
n_head=1,
|
||||
n_emb=256,
|
||||
p_drop_emb=0.0,
|
||||
p_drop_attn=0.3,
|
||||
causal_attn=True,
|
||||
time_as_cond=True,
|
||||
obs_as_cond=True,
|
||||
pred_action_steps_only=False,
|
||||
backbone_type='vanilla',
|
||||
n_kv_head=1,
|
||||
attn_res_ffn_mult=2.667,
|
||||
attn_res_eps=1e-6,
|
||||
attn_res_rope_theta=10000.0,
|
||||
**kwargs,
|
||||
):
|
||||
if num_inference_steps is None:
|
||||
num_inference_steps = 1
|
||||
elif num_inference_steps != 1:
|
||||
raise ValueError(
|
||||
'IMFTransformerHybridImagePolicy only supports one-step inference; '
|
||||
f'num_inference_steps must be 1, got {num_inference_steps}.'
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
shape_meta=shape_meta,
|
||||
noise_scheduler=noise_scheduler,
|
||||
horizon=horizon,
|
||||
n_action_steps=n_action_steps,
|
||||
n_obs_steps=n_obs_steps,
|
||||
num_inference_steps=num_inference_steps,
|
||||
crop_shape=crop_shape,
|
||||
obs_encoder_group_norm=obs_encoder_group_norm,
|
||||
eval_fixed_crop=eval_fixed_crop,
|
||||
n_layer=n_layer,
|
||||
n_cond_layers=n_cond_layers,
|
||||
n_head=n_head,
|
||||
n_emb=n_emb,
|
||||
p_drop_emb=p_drop_emb,
|
||||
p_drop_attn=p_drop_attn,
|
||||
causal_attn=causal_attn,
|
||||
time_as_cond=time_as_cond,
|
||||
obs_as_cond=obs_as_cond,
|
||||
pred_action_steps_only=pred_action_steps_only,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
input_dim = self.action_dim if self.obs_as_cond else (self.obs_feature_dim + self.action_dim)
|
||||
cond_dim = self.obs_feature_dim if self.obs_as_cond else 0
|
||||
model_horizon = self.n_action_steps if self.pred_action_steps_only else horizon
|
||||
self.model = IMFTransformerForDiffusion(
|
||||
input_dim=input_dim,
|
||||
output_dim=input_dim,
|
||||
horizon=model_horizon,
|
||||
n_obs_steps=n_obs_steps,
|
||||
cond_dim=cond_dim,
|
||||
n_layer=n_layer,
|
||||
n_head=n_head,
|
||||
n_emb=n_emb,
|
||||
p_drop_emb=p_drop_emb,
|
||||
p_drop_attn=p_drop_attn,
|
||||
causal_attn=causal_attn,
|
||||
time_as_cond=time_as_cond,
|
||||
obs_as_cond=obs_as_cond,
|
||||
n_cond_layers=n_cond_layers,
|
||||
backbone_type=backbone_type,
|
||||
n_kv_head=n_kv_head,
|
||||
attn_res_ffn_mult=attn_res_ffn_mult,
|
||||
attn_res_eps=attn_res_eps,
|
||||
attn_res_rope_theta=attn_res_rope_theta,
|
||||
)
|
||||
self.num_inference_steps = 1
|
||||
|
||||
def fn(self, z: torch.Tensor, r: torch.Tensor, t: torch.Tensor, cond=None) -> torch.Tensor:
|
||||
return self.model(z, r, t, cond=cond)
|
||||
|
||||
@staticmethod
|
||||
def _broadcast_batch_time(value: torch.Tensor, reference: torch.Tensor) -> torch.Tensor:
|
||||
while value.ndim < reference.ndim:
|
||||
value = value.unsqueeze(-1)
|
||||
return value
|
||||
|
||||
@staticmethod
|
||||
def _apply_conditioning(
|
||||
trajectory: torch.Tensor,
|
||||
condition_data: Optional[torch.Tensor] = None,
|
||||
condition_mask: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if condition_data is None or condition_mask is None:
|
||||
return trajectory
|
||||
conditioned = trajectory.clone()
|
||||
conditioned[condition_mask] = condition_data[condition_mask]
|
||||
return conditioned
|
||||
|
||||
@staticmethod
|
||||
def _jvp_math_sdp_context(z_t: torch.Tensor):
|
||||
if z_t.is_cuda:
|
||||
return torch.backends.cuda.sdp_kernel(
|
||||
enable_flash=False,
|
||||
enable_math=True,
|
||||
enable_mem_efficient=False,
|
||||
enable_cudnn=False,
|
||||
)
|
||||
return nullcontext()
|
||||
|
||||
@staticmethod
|
||||
def _jvp_tangents(v: torch.Tensor, r: torch.Tensor, t: torch.Tensor):
|
||||
return v.detach(), torch.zeros_like(r), torch.ones_like(t)
|
||||
|
||||
def _compute_u_and_du_dt(
|
||||
self,
|
||||
z_t: torch.Tensor,
|
||||
r: torch.Tensor,
|
||||
t: torch.Tensor,
|
||||
cond,
|
||||
v: torch.Tensor,
|
||||
condition_data: Optional[torch.Tensor] = None,
|
||||
condition_mask: Optional[torch.Tensor] = None,
|
||||
):
|
||||
tangents = self._jvp_tangents(v, r, t)
|
||||
|
||||
def g(z, r_value, t_value):
|
||||
conditioned_z = self._apply_conditioning(z, condition_data, condition_mask)
|
||||
return self.fn(conditioned_z, r_value, t_value, cond=cond)
|
||||
|
||||
with self._jvp_math_sdp_context(z_t):
|
||||
if TORCH_FUNC_JVP is not None:
|
||||
try:
|
||||
return TORCH_FUNC_JVP(g, (z_t, r, t), tangents)
|
||||
except (RuntimeError, TypeError, NotImplementedError):
|
||||
pass
|
||||
|
||||
u = g(z_t, r, t)
|
||||
_, du_dt = torch.autograd.functional.jvp(
|
||||
g,
|
||||
(z_t, r, t),
|
||||
tangents,
|
||||
create_graph=False,
|
||||
strict=False,
|
||||
)
|
||||
return u, du_dt
|
||||
|
||||
def _compound_velocity(
|
||||
self,
|
||||
u: torch.Tensor,
|
||||
du_dt: torch.Tensor,
|
||||
r: torch.Tensor,
|
||||
t: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
delta = self._broadcast_batch_time(t - r, u)
|
||||
return u + delta * du_dt.detach()
|
||||
|
||||
def _sample_one_step(
|
||||
self,
|
||||
z_t: torch.Tensor,
|
||||
r: torch.Tensor = None,
|
||||
t: torch.Tensor = None,
|
||||
cond=None,
|
||||
) -> torch.Tensor:
|
||||
batch_size = z_t.shape[0]
|
||||
if t is None:
|
||||
t = torch.ones(batch_size, device=z_t.device, dtype=z_t.dtype)
|
||||
if r is None:
|
||||
r = torch.zeros(batch_size, device=z_t.device, dtype=z_t.dtype)
|
||||
u = self.fn(z_t, r, t, cond=cond)
|
||||
delta = self._broadcast_batch_time(t - r, z_t)
|
||||
return z_t - delta * u
|
||||
|
||||
def conditional_sample(
|
||||
self,
|
||||
condition_data,
|
||||
condition_mask,
|
||||
cond=None,
|
||||
generator=None,
|
||||
**kwargs,
|
||||
):
|
||||
trajectory = torch.randn(
|
||||
size=condition_data.shape,
|
||||
dtype=condition_data.dtype,
|
||||
device=condition_data.device,
|
||||
generator=generator,
|
||||
)
|
||||
trajectory = self._apply_conditioning(trajectory, condition_data, condition_mask)
|
||||
trajectory = self._sample_one_step(trajectory, cond=cond)
|
||||
trajectory = self._apply_conditioning(trajectory, condition_data, condition_mask)
|
||||
return trajectory
|
||||
|
||||
def compute_loss(self, batch):
|
||||
assert 'valid_mask' not in batch
|
||||
nobs = self.normalizer.normalize(batch['obs'])
|
||||
nactions = self.normalizer['action'].normalize(batch['action'])
|
||||
batch_size = nactions.shape[0]
|
||||
horizon = nactions.shape[1]
|
||||
To = self.n_obs_steps
|
||||
|
||||
cond = None
|
||||
trajectory = nactions
|
||||
if self.obs_as_cond:
|
||||
this_nobs = dict_apply(
|
||||
nobs,
|
||||
lambda x: x[:, :To, ...].reshape(-1, *x.shape[2:]),
|
||||
)
|
||||
nobs_features = self.obs_encoder(this_nobs)
|
||||
cond = nobs_features.reshape(batch_size, To, -1)
|
||||
if self.pred_action_steps_only:
|
||||
start = To - 1
|
||||
end = start + self.n_action_steps
|
||||
trajectory = nactions[:, start:end]
|
||||
else:
|
||||
this_nobs = dict_apply(nobs, lambda x: x.reshape(-1, *x.shape[2:]))
|
||||
nobs_features = self.obs_encoder(this_nobs)
|
||||
nobs_features = nobs_features.reshape(batch_size, horizon, -1)
|
||||
trajectory = torch.cat([nactions, nobs_features], dim=-1).detach()
|
||||
|
||||
if self.pred_action_steps_only:
|
||||
condition_mask = torch.zeros_like(trajectory, dtype=torch.bool)
|
||||
else:
|
||||
condition_mask = self.mask_generator(trajectory.shape)
|
||||
|
||||
loss_mask = torch.zeros_like(trajectory, dtype=torch.bool)
|
||||
loss_mask[..., : self.action_dim] = True
|
||||
loss_mask = loss_mask & ~condition_mask
|
||||
|
||||
x = trajectory
|
||||
e = torch.randn_like(x)
|
||||
t = torch.rand(batch_size, device=x.device, dtype=x.dtype)
|
||||
r = torch.rand(batch_size, device=x.device, dtype=x.dtype)
|
||||
t, r = torch.maximum(t, r), torch.minimum(t, r)
|
||||
|
||||
t_broadcast = self._broadcast_batch_time(t, x)
|
||||
z_t = (1 - t_broadcast) * x + t_broadcast * e
|
||||
z_t = self._apply_conditioning(z_t, x, condition_mask)
|
||||
|
||||
v = self.fn(z_t, t, t, cond=cond)
|
||||
u, du_dt = self._compute_u_and_du_dt(
|
||||
z_t,
|
||||
r,
|
||||
t,
|
||||
cond=cond,
|
||||
v=v,
|
||||
condition_data=x,
|
||||
condition_mask=condition_mask,
|
||||
)
|
||||
V = self._compound_velocity(u, du_dt, r, t)
|
||||
target = e - x
|
||||
|
||||
loss = F.mse_loss(V, target, reduction='none')
|
||||
loss = loss * loss_mask.type(loss.dtype)
|
||||
loss = reduce(loss, 'b ... -> b (...)', 'mean')
|
||||
loss = loss.mean()
|
||||
return loss
|
||||
@@ -8,6 +8,8 @@ if __name__ == "__main__":
|
||||
os.chdir(ROOT_DIR)
|
||||
|
||||
import os
|
||||
import contextlib
|
||||
import importlib
|
||||
import hydra
|
||||
import torch
|
||||
from omegaconf import OmegaConf
|
||||
@@ -15,7 +17,6 @@ import pathlib
|
||||
from torch.utils.data import DataLoader
|
||||
import copy
|
||||
import random
|
||||
import wandb
|
||||
import tqdm
|
||||
import numpy as np
|
||||
import shutil
|
||||
@@ -31,6 +32,111 @@ from diffusion_policy.model.common.lr_scheduler import get_scheduler
|
||||
|
||||
OmegaConf.register_new_resolver("eval", eval, replace=True)
|
||||
|
||||
|
||||
class _LoggingBackend:
|
||||
def log(self, payload, step=None):
|
||||
raise NotImplementedError
|
||||
|
||||
def finish(self):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class _WandbLoggingBackend(_LoggingBackend):
|
||||
def __init__(self, run):
|
||||
self.run = run
|
||||
|
||||
def log(self, payload, step=None):
|
||||
self.run.log(payload, step=step)
|
||||
|
||||
def finish(self):
|
||||
self.run.finish()
|
||||
|
||||
|
||||
class _SwanLabLoggingBackend(_LoggingBackend):
|
||||
def __init__(self, run):
|
||||
self.run = run
|
||||
|
||||
def log(self, payload, step=None):
|
||||
self.run.log(payload, step=step)
|
||||
|
||||
def finish(self):
|
||||
self.run.finish()
|
||||
|
||||
|
||||
def _load_wandb():
|
||||
try:
|
||||
return importlib.import_module('wandb')
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"wandb is required when cfg.logging.backend == 'wandb' or missing"
|
||||
) from exc
|
||||
|
||||
|
||||
def _load_swanlab():
|
||||
try:
|
||||
return importlib.import_module('swanlab')
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
|
||||
def init_logging_backend(cfg: OmegaConf, output_dir):
|
||||
backend = OmegaConf.select(cfg, 'logging.backend', default='wandb')
|
||||
if backend == 'swanlab':
|
||||
swanlab = _load_swanlab()
|
||||
if swanlab is None:
|
||||
raise ImportError("swanlab is required when cfg.logging.backend == 'swanlab'")
|
||||
logging_cfg = cfg.logging
|
||||
mode = logging_cfg.mode
|
||||
if mode == 'online':
|
||||
mode = 'cloud'
|
||||
run = swanlab.init(
|
||||
project=logging_cfg.project,
|
||||
experiment_name=logging_cfg.name,
|
||||
group=logging_cfg.group,
|
||||
tags=logging_cfg.tags,
|
||||
id=logging_cfg.id,
|
||||
resume=logging_cfg.resume,
|
||||
mode=mode,
|
||||
logdir=str(pathlib.Path(output_dir) / 'swanlog'),
|
||||
config=OmegaConf.to_container(cfg, resolve=True),
|
||||
)
|
||||
return _SwanLabLoggingBackend(run)
|
||||
|
||||
if backend not in (None, 'wandb'):
|
||||
raise ValueError(f"Unknown logging backend: {backend}")
|
||||
|
||||
wandb = _load_wandb()
|
||||
logging_kwargs = OmegaConf.to_container(cfg.logging, resolve=True)
|
||||
logging_kwargs.pop('backend', None)
|
||||
run = wandb.init(
|
||||
dir=str(output_dir),
|
||||
config=OmegaConf.to_container(cfg, resolve=True),
|
||||
**logging_kwargs
|
||||
)
|
||||
wandb.config.update(
|
||||
{
|
||||
"output_dir": str(output_dir),
|
||||
}
|
||||
)
|
||||
return _WandbLoggingBackend(run)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def logging_backend_session(cfg: OmegaConf, output_dir):
|
||||
logging_backend = init_logging_backend(cfg=cfg, output_dir=output_dir)
|
||||
primary_error = None
|
||||
try:
|
||||
yield logging_backend
|
||||
except BaseException as exc:
|
||||
primary_error = exc
|
||||
raise
|
||||
finally:
|
||||
try:
|
||||
logging_backend.finish()
|
||||
except BaseException:
|
||||
if primary_error is None:
|
||||
raise
|
||||
|
||||
class TrainDiffusionTransformerHybridWorkspace(BaseWorkspace):
|
||||
include_keys = ['global_step', 'epoch']
|
||||
|
||||
@@ -109,18 +215,6 @@ class TrainDiffusionTransformerHybridWorkspace(BaseWorkspace):
|
||||
output_dir=self.output_dir)
|
||||
assert isinstance(env_runner, BaseImageRunner)
|
||||
|
||||
# configure logging
|
||||
wandb_run = wandb.init(
|
||||
dir=str(self.output_dir),
|
||||
config=OmegaConf.to_container(cfg, resolve=True),
|
||||
**cfg.logging
|
||||
)
|
||||
wandb.config.update(
|
||||
{
|
||||
"output_dir": self.output_dir,
|
||||
}
|
||||
)
|
||||
|
||||
# configure checkpoint
|
||||
topk_manager = TopKCheckpointManager(
|
||||
save_dir=os.path.join(self.output_dir, 'checkpoints'),
|
||||
@@ -148,6 +242,7 @@ class TrainDiffusionTransformerHybridWorkspace(BaseWorkspace):
|
||||
|
||||
# training loop
|
||||
log_path = os.path.join(self.output_dir, 'logs.json.txt')
|
||||
with logging_backend_session(cfg=cfg, output_dir=self.output_dir) as logging_backend:
|
||||
with JsonLogger(log_path) as json_logger:
|
||||
for local_epoch_idx in range(cfg.training.num_epochs):
|
||||
step_log = dict()
|
||||
@@ -190,7 +285,7 @@ class TrainDiffusionTransformerHybridWorkspace(BaseWorkspace):
|
||||
is_last_batch = (batch_idx == (len(train_dataloader)-1))
|
||||
if not is_last_batch:
|
||||
# log of last step is combined with validation and rollout
|
||||
wandb_run.log(step_log, step=self.global_step)
|
||||
logging_backend.log(step_log, step=self.global_step)
|
||||
json_logger.log(step_log)
|
||||
self.global_step += 1
|
||||
|
||||
@@ -278,7 +373,7 @@ class TrainDiffusionTransformerHybridWorkspace(BaseWorkspace):
|
||||
|
||||
# end of epoch
|
||||
# log of last step is combined with validation and rollout
|
||||
wandb_run.log(step_log, step=self.global_step)
|
||||
logging_backend.log(step_log, step=self.global_step)
|
||||
json_logger.log(step_log)
|
||||
self.global_step += 1
|
||||
self.epoch += 1
|
||||
|
||||
@@ -0,0 +1,60 @@
|
||||
# PushT iMF Full-Attention Implementation Plan
|
||||
|
||||
> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking.
|
||||
|
||||
**Goal:** Add a separate full-attention PushT image iMF config, commit/push it on a new branch, and launch the 9-run 350-epoch architecture sweep across 3 GPUs.
|
||||
|
||||
**Architecture:** Keep the existing causal iMF path untouched and add a standalone full-attention config that only flips `policy.causal_attn=false` while retaining one-step iMF inference and SwanLab-safe naming. Reuse the previous 9-run architecture matrix and balanced 3-queue scheduling across local 5090 plus 5880 GPU0/GPU1.
|
||||
|
||||
**Tech Stack:** Hydra, Diffusion Policy iMF image workspace, SwanLab, uv env, local shell + trusted remote 5880 over SSH.
|
||||
|
||||
---
|
||||
|
||||
### Task 1: Add full-attention iMF config with TDD
|
||||
|
||||
**Files:**
|
||||
- Create: `image_pusht_diffusion_policy_dit_imf_fullattn.yaml`
|
||||
- Modify: `tests/test_pusht_swanlab_config.py`
|
||||
|
||||
- [ ] Write a failing config regression test asserting the new config uses SwanLab-safe naming and `policy.causal_attn == False`.
|
||||
- [ ] Run the targeted pytest command and verify it fails because the config does not exist yet.
|
||||
- [ ] Add the minimal full-attention config by composing from the existing PushT image iMF config and overriding only `exp_name` and `policy.causal_attn=false`.
|
||||
- [ ] Re-run the targeted pytest and verify it passes.
|
||||
|
||||
### Task 2: Verify the new config
|
||||
|
||||
**Files:**
|
||||
- Read: `image_pusht_diffusion_policy_dit_imf_fullattn.yaml`
|
||||
|
||||
- [ ] Run `train.py --help` for the new config.
|
||||
- [ ] Run a real `training.debug=true` smoke test locally to confirm the training path is valid.
|
||||
|
||||
### Task 3: Commit and push the new branch
|
||||
|
||||
**Files:**
|
||||
- Commit only the new config/test/plan files needed for the full-attention experiment chain.
|
||||
|
||||
- [ ] Run verification commands again before commit.
|
||||
- [ ] Commit with a focused message.
|
||||
- [ ] Push `feat/pusht-imf-fullattn` to origin.
|
||||
|
||||
### Task 4: Launch the 9-run sweep
|
||||
|
||||
**Files:**
|
||||
- Write queue scripts and logs under `data/run_logs/` locally and on 5880.
|
||||
- Write outputs under `data/outputs/` locally and on 5880.
|
||||
|
||||
- [ ] Use the same matrix as the prior iMF sweep: `n_emb ∈ {128,256,384}`, `n_layer ∈ {6,12,18}`, `seed=42`.
|
||||
- [ ] Set `training.num_epochs=350` for all 9 runs.
|
||||
- [ ] Encode `fullattn` in every `exp_name`, `logging.name`, and run directory to avoid collisions.
|
||||
- [ ] Balance the 9 runs across local 5090, 5880 GPU0, and 5880 GPU1 as three serial queues.
|
||||
- [ ] Sync the new config to the remote smoke repo before launching remote queues.
|
||||
|
||||
### Task 5: Monitor and auto-summarize
|
||||
|
||||
**Files:**
|
||||
- Read local and remote pid files, logs, outputs, checkpoints.
|
||||
|
||||
- [ ] Start an xhigh monitoring agent that polls all three queues.
|
||||
- [ ] On completion, parse all 9 `logs.json.txt` files and rank by max `test_mean_score`.
|
||||
- [ ] Report embedding/layer trends and the best configuration.
|
||||
@@ -0,0 +1,57 @@
|
||||
# PushT Image iMF AttnRes Implementation Plan
|
||||
|
||||
> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking.
|
||||
|
||||
**Goal:** Add an AttnRes-backed full-attention iMF backbone for the PushT image experiment path, verify it with tests/smoke runs, then launch the 9-run 350-epoch architecture sweep across the local 5090 and remote 5880 GPUs.
|
||||
|
||||
**Architecture:** Extend `IMFTransformerForDiffusion` with a selectable `attnres_full` backbone that keeps the current iMF training/inference API unchanged while replacing the transformer internals with RMSNorm + RoPE self-attention + SwiGLU + Full AttnRes depth-wise residual routing. Add one standalone Hydra config for the PushT image sweep and reuse queue-style launch scripts with unique SwanLab names.
|
||||
|
||||
**Tech Stack:** Python 3.9 via uv, PyTorch 2.8 CUDA, Hydra, SwanLab online logging, local shell + SSH to trusted 5880 host.
|
||||
|
||||
---
|
||||
|
||||
### Task 1: Add regression tests for the new AttnRes path
|
||||
|
||||
**Files:**
|
||||
- Modify: `tests/test_imf_transformer_for_diffusion.py`
|
||||
- Modify: `tests/test_pusht_swanlab_config.py`
|
||||
|
||||
- [ ] Add a failing model test that instantiates `IMFTransformerForDiffusion(backbone_type='attnres_full', causal_attn=False, ...)`, runs a forward pass with conditional observations, and asserts output shape plus optimizer construction.
|
||||
- [ ] Run the targeted pytest selection and confirm the new test fails for the expected missing-backbone reason.
|
||||
- [ ] Add a failing config regression test for `image_pusht_diffusion_policy_dit_imf_attnres_full.yaml` asserting SwanLab naming fields and `policy.causal_attn == False`.
|
||||
- [ ] Re-run the targeted pytest selection and confirm the config test fails before implementation.
|
||||
|
||||
### Task 2: Implement the AttnRes-backed iMF backbone
|
||||
|
||||
**Files:**
|
||||
- Create: `diffusion_policy/model/diffusion/attnres_transformer_components.py`
|
||||
- Modify: `diffusion_policy/model/diffusion/imf_transformer_for_diffusion.py`
|
||||
|
||||
- [ ] Add focused reusable modules for `RMSNorm`, RoPE helpers, grouped-query self-attention, SwiGLU FFN, and the Full AttnRes operator.
|
||||
- [ ] Extend `IMFTransformerForDiffusion` with a `backbone_type` switch that preserves the existing vanilla path and adds an `attnres_full` path using concatenated `[r, t, obs, sample]` tokens.
|
||||
- [ ] Ensure the AttnRes path slices condition tokens away before the output head so the returned tensor still matches the sample/action horizon.
|
||||
- [ ] Update optimizer parameter grouping to treat RMSNorm weights like LayerNorm weights (no decay) and include any new positional/conditioning parameters.
|
||||
- [ ] Run the targeted tests and get them green.
|
||||
|
||||
### Task 3: Add the new PushT config and smoke-test path
|
||||
|
||||
**Files:**
|
||||
- Create: `image_pusht_diffusion_policy_dit_imf_attnres_full.yaml`
|
||||
- Modify: `tests/test_pusht_swanlab_config.py`
|
||||
|
||||
- [ ] Add a standalone PushT image config for the AttnRes iMF variant with SwanLab online logging, `policy.backbone_type=attnres_full`, and `policy.causal_attn=false`.
|
||||
- [ ] Verify `uv run python train.py --config-dir=. --config-name=image_pusht_diffusion_policy_dit_imf_attnres_full.yaml --help` succeeds.
|
||||
- [ ] Run a real smoke training command with `training.debug=true`, `training.device=cuda:0`, safety overrides (`dataloader.num_workers=0`, `task.env_runner.n_envs=1`, no vis), and confirm it reaches the training loop and writes a run directory.
|
||||
|
||||
### Task 4: Prepare launch scripts and start the 9-run sweep
|
||||
|
||||
**Files:**
|
||||
- Create or modify: `data/run_logs/imf_attnres_local_queue.sh`
|
||||
- Create or modify locally before copy: `data/run_logs/imf_attnres_remote_gpu0_queue.sh`
|
||||
- Create or modify locally before copy: `data/run_logs/imf_attnres_remote_gpu1_queue.sh`
|
||||
|
||||
- [ ] Write queue command templates for the 9 runs using config `image_pusht_diffusion_policy_dit_imf_attnres_full.yaml`, `training.num_epochs=350`, unique `exp_name/logging.name`, and shared `logging.group=imf_pusht_attnres_arch_sweep`.
|
||||
- [ ] Sync the necessary config/model files plus remote queue scripts to `droid@100.73.14.65:~/project/diffusion_policy-smoke`.
|
||||
- [ ] Start the local queue under `nohup`, record PID, and verify the first run log is advancing.
|
||||
- [ ] Start the two remote queues under `nohup`, record PIDs, and verify both first-run logs are advancing.
|
||||
- [ ] Confirm all three GPUs have officially entered training for the new sweep.
|
||||
168
docs/superpowers/specs/2026-03-26-pusht-imf-swanlab-design.md
Normal file
168
docs/superpowers/specs/2026-03-26-pusht-imf-swanlab-design.md
Normal file
@@ -0,0 +1,168 @@
|
||||
# PushT Image DiT iMF + SwanLab Design
|
||||
|
||||
## Goal
|
||||
Migrate the PushT image DiT experiment path from W&B to SwanLab online logging, suppress simulation video logging, then add an iMeanFlow-based one-step transformer policy for PushT image experiments and run a controlled architecture sweep over embedding width and depth using `test_mean_score` as the primary metric.
|
||||
|
||||
## Context
|
||||
- The implementation baseline is `main`.
|
||||
- The experiment path is limited to the PushT image transformer workflow; unrelated workspaces and runners should remain unchanged.
|
||||
- Environment management must use the repo-local `uv` workflow.
|
||||
- The trusted remote machine alias `5880` refers to `droid-system-product-name` (`droid@100.73.14.65`) and can run two GPU jobs in parallel.
|
||||
|
||||
## Architecture Overview
|
||||
The work is split into two verified phases:
|
||||
|
||||
1. **Logging migration phase**
|
||||
- Keep the existing PushT image DiT training behavior intact.
|
||||
- Replace W&B usage with SwanLab in the transformer hybrid workspace used by PushT image DiT experiments.
|
||||
- Preserve local `logs.json.txt` output.
|
||||
- Ensure rollout metrics such as `test_mean_score` and per-seed rewards are still logged.
|
||||
- Disable simulation video logging at both the config and runner/logging boundary.
|
||||
|
||||
2. **iMF migration phase**
|
||||
- Keep the original diffusion-based transformer image policy available on `main`.
|
||||
- Add a parallel iMF-specific model/policy/config path rather than overwriting the baseline diffusion policy.
|
||||
- Reuse the existing observation encoder and training workspace where possible.
|
||||
- Replace diffusion training with the iMeanFlow training objective.
|
||||
- Use one-step inference for validation/rollout in the iMF path.
|
||||
|
||||
The implementation planning boundary for this spec is:
|
||||
- code changes through a smoke-tested, pushed iMF branch
|
||||
- not the full 3x3 sweep execution/monitoring workflow, which should be planned separately after the code path is verified and pushed
|
||||
|
||||
## Logging Design
|
||||
### Scope
|
||||
Only the PushT image DiT experiment chain is changed:
|
||||
- `train_diffusion_transformer_hybrid_workspace.py`
|
||||
- `pusht_image_runner.py`
|
||||
- the new/updated PushT image transformer configs
|
||||
|
||||
### Behavior
|
||||
- SwanLab runs in `online` mode.
|
||||
- Logged values are scalar metrics only, e.g.:
|
||||
- `train_loss`
|
||||
- `val_loss`
|
||||
- `train_action_mse_error`
|
||||
- `test_mean_score`
|
||||
- aggregate rollout metrics and optional per-seed scalar rewards
|
||||
- No simulation videos are uploaded or wrapped as logging objects.
|
||||
- Local JSON logging remains enabled for auditability and remote-job fallback debugging.
|
||||
|
||||
### Operational safeguards
|
||||
- Default PushT experiment configs set `task.env_runner.n_test_vis=0` and `task.env_runner.n_train_vis=0`.
|
||||
- The PushT image runner will not emit video objects into `log_data`, preventing accidental uploads even if visualization counts are later changed.
|
||||
- SwanLab credentials are provided through the environment at runtime, not committed into the repo.
|
||||
|
||||
## iMF Model Design
|
||||
### Baseline reuse
|
||||
The iMF path reuses:
|
||||
- the existing image observation encoder
|
||||
- the existing action/observation normalization path
|
||||
- the existing training workspace skeleton
|
||||
- the existing PushT image dataset and env runner
|
||||
|
||||
### New files
|
||||
- `diffusion_policy/model/diffusion/imf_transformer_for_diffusion.py`
|
||||
- `diffusion_policy/policy/imf_transformer_hybrid_image_policy.py`
|
||||
- `image_pusht_diffusion_policy_dit_imf.yaml`
|
||||
|
||||
### Existing files changed for the iMF path
|
||||
- `diffusion_policy/workspace/train_diffusion_transformer_hybrid_workspace.py`
|
||||
- logging migration to SwanLab for this experiment chain
|
||||
- no structural training-loop fork beyond instantiating the configured policy and logging scalar metrics
|
||||
- `diffusion_policy/env_runner/pusht_image_runner.py`
|
||||
- suppress video objects in returned logs
|
||||
|
||||
### Model structure
|
||||
The iMF transformer mirrors the current transformer policy structure closely enough to reuse known-good conditioning patterns, but it remains a **single-head model** that predicts only:
|
||||
- `u`: average velocity field
|
||||
|
||||
The same function is reused at two evaluation points:
|
||||
- canonical signature: `fn(z, r, t, cond)`
|
||||
- `fn(z_t, r, t, cond)` predicts average velocity `u`
|
||||
- `fn(z_t, t, t, cond)` predicts the instantaneous velocity surrogate `v`
|
||||
|
||||
Inputs remain conditioned on encoded observations and action trajectory tokens.
|
||||
|
||||
## iMF Training Objective
|
||||
For a normalized action trajectory `x`, the initial implementation follows the user-provided Algorithm 1 exactly:
|
||||
1. sample `t, r`
|
||||
2. sample Gaussian noise `e`
|
||||
3. form `z_t = (1 - t) * x + t * e`
|
||||
4. predict instantaneous velocity surrogate with the same network:
|
||||
- `v = fn(z_t, t, t, cond)`
|
||||
5. define the JVP function exactly as:
|
||||
- `g(z, r, t) = fn(z, r, t, cond)`
|
||||
6. compute the primal output and JVP with tangent:
|
||||
- `u, du_dt = jvp(g, (z_t, r, t), (v.detach(), 0, 1))`
|
||||
7. form compound velocity:
|
||||
- `V = u + (t - r) * stopgrad(du_dt)`
|
||||
8. train against the average-velocity target:
|
||||
- `target = e - x`
|
||||
9. optimize only the masked iMF loss:
|
||||
- `loss = metric(V - target)`
|
||||
|
||||
There is **no auxiliary `v` loss** in the initial implementation. The implementation should prefer `torch.func.jvp` and keep a safe fallback path if the local Torch stack needs it.
|
||||
|
||||
## iMF Inference Design
|
||||
Inference uses a single step starting from noise:
|
||||
- initialize `z_1 ~ N(0, I)`
|
||||
- set `t = 1.0`, `r = 0.0`
|
||||
- predict `u = fn(z_1, r, t, cond)`
|
||||
- produce the action sample with one update:
|
||||
- `x_hat = z_1 - (t - r) * u`
|
||||
|
||||
This matches the time direction in the reference iMeanFlow sampling logic.
|
||||
|
||||
## Testing Strategy
|
||||
### Phase 1: logging migration smoke test
|
||||
- use the repo-local `uv` environment
|
||||
- run a debug/smoke PushT image DiT training job on a single GPU with:
|
||||
- `training.debug=true`
|
||||
- `dataloader.num_workers=0`
|
||||
- `val_dataloader.num_workers=0`
|
||||
- `task.env_runner.n_envs=1`
|
||||
- `task.env_runner.n_test_vis=0`
|
||||
- `task.env_runner.n_train_vis=0`
|
||||
- verify:
|
||||
- SwanLab initializes successfully
|
||||
- `logs.json.txt` is populated
|
||||
- rollout metrics still include `test_mean_score`
|
||||
- no video logging is attempted
|
||||
|
||||
### Phase 2: iMF smoke test
|
||||
- run an equivalent debug PushT image iMF job
|
||||
- verify:
|
||||
- forward/backward passes succeed
|
||||
- JVP path executes on the local Torch version
|
||||
- one-step inference returns correctly shaped actions
|
||||
- rollout produces scalar metrics including `test_mean_score`
|
||||
|
||||
## Branch and Commit Strategy
|
||||
1. start from a `main`-based worktree branch
|
||||
2. commit the SwanLab/no-video migration after smoke verification
|
||||
3. continue with the iMF implementation
|
||||
4. once iMF smoke tests pass, create/preserve a dedicated feature branch for the experiment code and push it to Gitea
|
||||
|
||||
## Post-Implementation Experiment Plan
|
||||
After the iMF path is smoke-tested and pushed, a separate experiment-execution plan should launch:
|
||||
- run a 3x3 grid over:
|
||||
- `n_emb ∈ {128, 256, 384}`
|
||||
- `n_layer ∈ {6, 12, 18}`
|
||||
- keep the rest of the setup fixed
|
||||
- use a fixed single-seed setting for comparability unless a later explicit experiment plan expands that scope
|
||||
- run each experiment for 300 epochs
|
||||
- primary comparison metric: `test_mean_score`
|
||||
|
||||
## Post-Implementation Resource Allocation
|
||||
The separate experiment-execution plan should schedule three concurrent runs until the matrix is complete:
|
||||
- local machine: 1 GPU
|
||||
- `5880`: 2 GPUs
|
||||
|
||||
Each run uses the same uv-managed environment and the same pushed branch so the code path is consistent across hosts.
|
||||
|
||||
## Risks and Mitigations
|
||||
- **Torch JVP compatibility risk**: provide a fallback JVP implementation and smoke-test immediately.
|
||||
- **Logging regression risk**: keep local JSON logging and verify scalar rollout metrics before moving to iMF.
|
||||
- **Video/logging side effects**: disable visualizations in config and filter video objects out of runner logs.
|
||||
- **Cross-host drift**: push the verified branch to Gitea before launching the experiment matrix on multiple machines.
|
||||
107
docs/superpowers/specs/2026-03-27-pusht-imf-fullattn-design.md
Normal file
107
docs/superpowers/specs/2026-03-27-pusht-imf-fullattn-design.md
Normal file
@@ -0,0 +1,107 @@
|
||||
# PushT Image iMF Full-Attention Sweep Design
|
||||
|
||||
## Goal
|
||||
在一个独立新分支上,为 PushT 图像 iMF 路线新增 **full-attention** 变体(关闭因果注意力),并按与之前相同的架构扫描网格运行 **9 组实验**,每组训练 **350 epochs**。所有实验完成后,提取每组 **`max(test_mean_score)`** 并输出完整排名和趋势总结。
|
||||
|
||||
## Scope
|
||||
本次工作仅覆盖:
|
||||
1. 在不影响现有因果版 iMF 路线的前提下,新增 full-attention 实验链路;
|
||||
2. 对 `n_emb ∈ {128, 256, 384}` 与 `n_layer ∈ {6, 12, 18}` 的 9 组组合做 350-epoch 扫描;
|
||||
3. 在本机 5090 与 5880 双卡上做三路并行调度;
|
||||
4. 在全部实验完成后自动汇总结果并直接向用户汇报。
|
||||
|
||||
不在本次范围内:
|
||||
- 不替换或删除现有因果版 iMF 配置;
|
||||
- 不改动已有 DiT baseline 实现;
|
||||
- 不做多 seed 扩展;
|
||||
- 不额外增加视频记录。
|
||||
|
||||
## Design Choice
|
||||
采用“**新增独立配置 + 新分支**”的方式,而不是覆盖现有 iMF 默认配置。
|
||||
|
||||
原因:
|
||||
- 现有因果版 iMF 已完成实验与结果记录,保持不动更利于对照;
|
||||
- full-attention 作为新的实验链路,使用独立配置更易复现;
|
||||
- 运行时只需要通过配置切换 `policy.causal_attn=false`,不需要重新设计 iMF 算法本身。
|
||||
|
||||
## Configuration Design
|
||||
新增一个独立配置文件,例如:
|
||||
- `image_pusht_diffusion_policy_dit_imf_fullattn.yaml`
|
||||
|
||||
其职责:
|
||||
- 继承当前 PushT image iMF 配置链路;
|
||||
- 保持 iMF 单步推理、SwanLab 标量记录、无视频记录;
|
||||
- 显式设置:
|
||||
- `policy.causal_attn=false`
|
||||
- `policy.n_head=1`
|
||||
- 保持其余 iMF 训练语义不变。
|
||||
|
||||
SwanLab 命名延续当前修复后的策略:
|
||||
- `logging.name=${exp_name}`
|
||||
- `logging.resume=false`
|
||||
- `logging.id=null`
|
||||
- `logging.group=${exp_name}` 或统一 sweep group override
|
||||
|
||||
## Code Change Strategy
|
||||
优先最小改动:
|
||||
- 若当前 `IMFTransformerForDiffusion` 已支持 `causal_attn=False` 分支,则不改核心算法,仅通过新配置关闭因果 mask;
|
||||
- 如需补充回归验证,则新增针对 full-attention 配置/掩码行为的最小测试;
|
||||
- 不改变已有因果版实验配置和已有测试语义。
|
||||
|
||||
## Experiment Matrix
|
||||
实验网格固定为:
|
||||
|
||||
- `n_emb=128, n_layer=6`
|
||||
- `n_emb=128, n_layer=12`
|
||||
- `n_emb=128, n_layer=18`
|
||||
- `n_emb=256, n_layer=6`
|
||||
- `n_emb=256, n_layer=12`
|
||||
- `n_emb=256, n_layer=18`
|
||||
- `n_emb=384, n_layer=6`
|
||||
- `n_emb=384, n_layer=12`
|
||||
- `n_emb=384, n_layer=18`
|
||||
|
||||
统一设置:
|
||||
- `training.num_epochs=350`
|
||||
- `training.resume=false`
|
||||
- `seed=42`
|
||||
- PushT image 数据路径不变
|
||||
- 指标以 **`logs.json.txt` 中 `test_mean_score` 的最大值** 为准
|
||||
|
||||
## Scheduling Design
|
||||
使用三路串行队列并行执行 9 个实验:
|
||||
|
||||
- 本机 5090:1 个顺序队列
|
||||
- 5880 GPU0:1 个顺序队列
|
||||
- 5880 GPU1:1 个顺序队列
|
||||
|
||||
分配原则:
|
||||
- 延续按 `n_emb × n_layer` 近似平衡工作量;
|
||||
- 每张卡同一时刻只跑 1 个实验;
|
||||
- 队列脚本负责“前一个结束后自动启动下一个”。
|
||||
|
||||
## Monitoring Design
|
||||
继续采用“**训练队列脚本 + 监控 agent**”双层机制:
|
||||
|
||||
1. **实际调度**由本地/远端队列脚本负责;
|
||||
2. **监控**由一个 xhigh 子 agent 轮询:
|
||||
- 读取 pid 状态
|
||||
- 检查 master log
|
||||
- 检查每个 run 的 `logs.json.txt`
|
||||
- 判断是否卡死/失败/全部完成
|
||||
3. 一旦全部完成,监控 agent 直接返回:
|
||||
- 9 组实验的最终 epoch
|
||||
- 每组 `max(test_mean_score)`
|
||||
- 排名表
|
||||
- embedding / layer 趋势总结
|
||||
|
||||
本次要求下,agent 在收到全部完成信号后应直接向主会话回报结果,不等待用户再次提醒。
|
||||
|
||||
## Success Criteria
|
||||
满足以下条件即视为完成:
|
||||
1. full-attention iMF 配置在新分支上可运行;
|
||||
2. 9 组 350-epoch 实验全部完成;
|
||||
3. 不记录仿真视频,只记录标量;
|
||||
4. SwanLab 运行命名不冲突;
|
||||
5. 输出 9 组实验 `max(test_mean_score)` 的完整汇总与结论;
|
||||
6. 全部实验结束后主会话可直接给用户最终总结。
|
||||
108
docs/superpowers/specs/2026-03-29-pusht-imf-attnres-design.md
Normal file
108
docs/superpowers/specs/2026-03-29-pusht-imf-attnres-design.md
Normal file
@@ -0,0 +1,108 @@
|
||||
# PushT Image iMF AttnRes Design
|
||||
|
||||
## Goal
|
||||
在现有 PushT 图像 iMF full-attention 路线之上,引入 `attn_res` 仓库中的 **Full AttnRes** 残差聚合形式,并同步使用与其匹配的 **RMSNorm + 自注意力 + SwiGLU FFN** 模块,保持 iMF 训练目标与一步推理语义不变,仅作用于本次实验链路。实现完成并验证后,启动与此前相同的 9 组 `n_emb × n_layer` 扫描(350 epochs, seed=42, SwanLab online, 无视频记录)。
|
||||
|
||||
## Scope
|
||||
本次工作仅覆盖:
|
||||
1. 为 `IMFTransformerForDiffusion` 增加一个 AttnRes-backed backbone 变体;
|
||||
2. 保持 `forward(sample, r, t, cond=None)`、iMF loss、一步推理策略接口不变;
|
||||
3. 新增独立 PushT 图像配置用于该变体;
|
||||
4. 复用本地 5090 + 远端 5880 双卡三路并行调度 9 组实验。
|
||||
|
||||
不在范围内:
|
||||
- 不替换已有 vanilla iMF/full-attn 配置;
|
||||
- 不修改 DiT baseline;
|
||||
- 不增加视频日志;
|
||||
- 不扩大到多 seed。
|
||||
|
||||
## Recommended Approach
|
||||
采用“**在当前 iMF 模型内增加可选 AttnRes backbone**”的方式,而不是新建独立 policy 链路。
|
||||
|
||||
理由:
|
||||
- policy / workspace / loss / sampling 路径已经被验证,保留这些路径可最大程度缩小变动面;
|
||||
- 仅在模型内部切换 backbone,可以让新实验与既有 iMF 结果保持可比;
|
||||
- 配置上只需显式打开 `backbone_type=attnres_full`、`causal_attn=false` 等开关,复现实验更直接。
|
||||
|
||||
## Architecture
|
||||
### 1. Backbone split
|
||||
`IMFTransformerForDiffusion` 保留现有 vanilla encoder/decoder 实现为默认路径,并新增 `attnres_full` 路径:
|
||||
- **vanilla**:保持当前实现不变;
|
||||
- **attnres_full**:使用单栈式全注意力 Transformer,输入 token 序列为
|
||||
`[r token, t token, obs cond tokens..., action/sample tokens...]`。
|
||||
|
||||
模型只对末尾的 action/sample token 位置输出 `u` 预测,前置条件 token 仅参与上下文建模。
|
||||
|
||||
### 2. AttnRes stack
|
||||
新 backbone 使用以下模块:
|
||||
- `RMSNorm`
|
||||
- `Rotary Position Embedding`(用于自注意力 q/k)
|
||||
- `GroupedQueryAttention`(本实验默认 `n_kv_head=1`,与单头配置兼容)
|
||||
- `SwiGLU` FFN
|
||||
- `AttnResOperator`(每个子层一个 pseudo-query,执行 full depth-wise residual aggregation)
|
||||
|
||||
每个 transformer block 由两个子层组成:
|
||||
1. self-attention 子层
|
||||
2. FFN 子层
|
||||
|
||||
每个子层的输入不再是简单 `x + f(x)`,而是从 embedding 与全部历史子层输出中通过 Full AttnRes 聚合得到 `h_l`,再执行 `RMSNorm(h_l) -> sublayer_fn(...)`。
|
||||
|
||||
### 3. Conditioning and token flow
|
||||
- `sample` 先经 `input_emb` 映射为 action tokens;
|
||||
- `r` 和 `t` 各自经 `SinusoidalPosEmb + linear` 映射为两个条件 token;
|
||||
- 图像观测编码后的 `cond` 通过 `cond_obs_emb` 映射为 obs tokens;
|
||||
- 拼接后的完整 token 序列进入 AttnRes stack;
|
||||
- 输出时切掉前置条件 token,仅保留 action/sample token 段,随后经 `RMSNorm + head` 得到最终 `u`。
|
||||
|
||||
### 4. Attention mode
|
||||
本次实验链路固定为 **non-causal full attention**:
|
||||
- `causal_attn=false`
|
||||
- 不构造 causal mask
|
||||
- 所有 token 可彼此双向可见
|
||||
|
||||
这与用户指定的“训练过程仍然使用全注意力(不加因果注意)”一致。
|
||||
|
||||
## Config and Logging
|
||||
新增独立配置文件,例如:
|
||||
- `image_pusht_diffusion_policy_dit_imf_attnres_full.yaml`
|
||||
|
||||
该配置需要:
|
||||
- 指向现有 `IMFTransformerHybridImagePolicy`
|
||||
- 显式开启 AttnRes backbone 相关参数
|
||||
- 设置 `policy.causal_attn=false`
|
||||
- 保持 `logging.backend=swanlab`、`logging.mode=online`
|
||||
- 运行时通过覆盖保证:
|
||||
- `logging.name=<unique_run_name>`
|
||||
- `logging.group=imf_pusht_attnres_arch_sweep`
|
||||
- `exp_name=<unique_run_name>`
|
||||
- 保持 `task.env_runner.n_test_vis=0` 与 `n_train_vis=0`,仅记录标量
|
||||
|
||||
## Experiment Matrix
|
||||
固定 9 组:
|
||||
- `n_emb ∈ {128, 256, 384}`
|
||||
- `n_layer ∈ {6, 12, 18}`
|
||||
- `seed=42`
|
||||
- `training.num_epochs=350`
|
||||
|
||||
## Scheduling
|
||||
沿用之前验证过的三队列分配:
|
||||
- 本机 5090:`384x18`, `256x6`, `128x6`
|
||||
- 5880 GPU0:`384x12`, `256x12`, `128x12`
|
||||
- 5880 GPU1:`384x6`, `256x18`, `128x18`
|
||||
|
||||
每个 run name 编码 backbone 与结构,例如:
|
||||
`imf_attnres_emb256_layer12_seed42_5880gpu0`
|
||||
|
||||
## Verification
|
||||
实现阶段至少验证:
|
||||
1. 新配置的 SwanLab 命名与 `causal_attn=false` 正确;
|
||||
2. 新 backbone 的 forward shape 与 `configure_optimizers()` 可用;
|
||||
3. 旧 vanilla 路径测试不回归;
|
||||
4. `training.debug=true` smoke run 可以完整通过。
|
||||
|
||||
## Success Criteria
|
||||
1. 新 AttnRes iMF 变体在本分支可训练、可一步推理;
|
||||
2. 不影响已有 vanilla iMF/full-attn 链路;
|
||||
3. 9 组实验成功在三张卡上正式启动;
|
||||
4. SwanLab run 名称唯一,无冲突;
|
||||
5. 不记录视频,仅记录标量。
|
||||
64
eval.py
Normal file
64
eval.py
Normal file
@@ -0,0 +1,64 @@
|
||||
"""
|
||||
Usage:
|
||||
python eval.py --checkpoint data/image/pusht/diffusion_policy_cnn/train_0/checkpoints/latest.ckpt -o data/pusht_eval_output
|
||||
"""
|
||||
|
||||
import sys
|
||||
# use line-buffering for both stdout and stderr
|
||||
sys.stdout = open(sys.stdout.fileno(), mode='w', buffering=1)
|
||||
sys.stderr = open(sys.stderr.fileno(), mode='w', buffering=1)
|
||||
|
||||
import os
|
||||
import pathlib
|
||||
import click
|
||||
import hydra
|
||||
import torch
|
||||
import dill
|
||||
import wandb
|
||||
import json
|
||||
from diffusion_policy.workspace.base_workspace import BaseWorkspace
|
||||
|
||||
@click.command()
|
||||
@click.option('-c', '--checkpoint', required=True)
|
||||
@click.option('-o', '--output_dir', required=True)
|
||||
@click.option('-d', '--device', default='cuda:0')
|
||||
def main(checkpoint, output_dir, device):
|
||||
if os.path.exists(output_dir):
|
||||
click.confirm(f"Output path {output_dir} already exists! Overwrite?", abort=True)
|
||||
pathlib.Path(output_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# load checkpoint
|
||||
payload = torch.load(open(checkpoint, 'rb'), pickle_module=dill)
|
||||
cfg = payload['cfg']
|
||||
cls = hydra.utils.get_class(cfg._target_)
|
||||
workspace = cls(cfg, output_dir=output_dir)
|
||||
workspace: BaseWorkspace
|
||||
workspace.load_payload(payload, exclude_keys=None, include_keys=None)
|
||||
|
||||
# get policy from workspace
|
||||
policy = workspace.model
|
||||
if cfg.training.use_ema:
|
||||
policy = workspace.ema_model
|
||||
|
||||
device = torch.device(device)
|
||||
policy.to(device)
|
||||
policy.eval()
|
||||
|
||||
# run eval
|
||||
env_runner = hydra.utils.instantiate(
|
||||
cfg.task.env_runner,
|
||||
output_dir=output_dir)
|
||||
runner_log = env_runner.run(policy)
|
||||
|
||||
# dump log to json
|
||||
json_log = dict()
|
||||
for key, value in runner_log.items():
|
||||
if isinstance(value, wandb.sdk.data_types.video.Video):
|
||||
json_log[key] = value._path
|
||||
else:
|
||||
json_log[key] = value
|
||||
out_path = os.path.join(output_dir, 'eval_log.json')
|
||||
json.dump(json_log, open(out_path, 'w'), indent=2, sort_keys=True)
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
182
image_pusht_diffusion_policy_cnn.yaml
Normal file
182
image_pusht_diffusion_policy_cnn.yaml
Normal file
@@ -0,0 +1,182 @@
|
||||
_target_: diffusion_policy.workspace.train_diffusion_unet_hybrid_workspace.TrainDiffusionUnetHybridWorkspace
|
||||
checkpoint:
|
||||
save_last_ckpt: true
|
||||
save_last_snapshot: false
|
||||
topk:
|
||||
format_str: epoch={epoch:04d}-test_mean_score={test_mean_score:.3f}.ckpt
|
||||
k: 5
|
||||
mode: max
|
||||
monitor_key: test_mean_score
|
||||
dataloader:
|
||||
batch_size: 64
|
||||
num_workers: 8
|
||||
persistent_workers: false
|
||||
pin_memory: true
|
||||
shuffle: true
|
||||
dataset_obs_steps: 2
|
||||
ema:
|
||||
_target_: diffusion_policy.model.diffusion.ema_model.EMAModel
|
||||
inv_gamma: 1.0
|
||||
max_value: 0.9999
|
||||
min_value: 0.0
|
||||
power: 0.75
|
||||
update_after_step: 0
|
||||
exp_name: default
|
||||
horizon: 16
|
||||
keypoint_visible_rate: 1.0
|
||||
logging:
|
||||
group: null
|
||||
id: null
|
||||
mode: online
|
||||
name: 2023.01.16-20.20.06_train_diffusion_unet_hybrid_pusht_image
|
||||
project: diffusion_policy_debug
|
||||
resume: true
|
||||
tags:
|
||||
- train_diffusion_unet_hybrid
|
||||
- pusht_image
|
||||
- default
|
||||
multi_run:
|
||||
run_dir: data/outputs/2023.01.16/20.20.06_train_diffusion_unet_hybrid_pusht_image
|
||||
wandb_name_base: 2023.01.16-20.20.06_train_diffusion_unet_hybrid_pusht_image
|
||||
n_action_steps: 8
|
||||
n_latency_steps: 0
|
||||
n_obs_steps: 2
|
||||
name: train_diffusion_unet_hybrid
|
||||
obs_as_global_cond: true
|
||||
optimizer:
|
||||
_target_: torch.optim.AdamW
|
||||
betas:
|
||||
- 0.95
|
||||
- 0.999
|
||||
eps: 1.0e-08
|
||||
lr: 0.0001
|
||||
weight_decay: 1.0e-06
|
||||
past_action_visible: false
|
||||
policy:
|
||||
_target_: diffusion_policy.policy.diffusion_unet_hybrid_image_policy.DiffusionUnetHybridImagePolicy
|
||||
cond_predict_scale: true
|
||||
crop_shape:
|
||||
- 84
|
||||
- 84
|
||||
diffusion_step_embed_dim: 128
|
||||
down_dims:
|
||||
- 512
|
||||
- 1024
|
||||
- 2048
|
||||
eval_fixed_crop: true
|
||||
horizon: 16
|
||||
kernel_size: 5
|
||||
n_action_steps: 8
|
||||
n_groups: 8
|
||||
n_obs_steps: 2
|
||||
noise_scheduler:
|
||||
_target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler
|
||||
beta_end: 0.02
|
||||
beta_schedule: squaredcos_cap_v2
|
||||
beta_start: 0.0001
|
||||
clip_sample: true
|
||||
num_train_timesteps: 100
|
||||
prediction_type: epsilon
|
||||
variance_type: fixed_small
|
||||
num_inference_steps: 100
|
||||
obs_as_global_cond: true
|
||||
obs_encoder_group_norm: true
|
||||
shape_meta:
|
||||
action:
|
||||
shape:
|
||||
- 2
|
||||
obs:
|
||||
agent_pos:
|
||||
shape:
|
||||
- 2
|
||||
type: low_dim
|
||||
image:
|
||||
shape:
|
||||
- 3
|
||||
- 96
|
||||
- 96
|
||||
type: rgb
|
||||
shape_meta:
|
||||
action:
|
||||
shape:
|
||||
- 2
|
||||
obs:
|
||||
agent_pos:
|
||||
shape:
|
||||
- 2
|
||||
type: low_dim
|
||||
image:
|
||||
shape:
|
||||
- 3
|
||||
- 96
|
||||
- 96
|
||||
type: rgb
|
||||
task:
|
||||
dataset:
|
||||
_target_: diffusion_policy.dataset.pusht_image_dataset.PushTImageDataset
|
||||
horizon: 16
|
||||
max_train_episodes: 90
|
||||
pad_after: 7
|
||||
pad_before: 1
|
||||
seed: 42
|
||||
val_ratio: 0.02
|
||||
zarr_path: data/pusht/pusht_cchi_v7_replay.zarr
|
||||
env_runner:
|
||||
_target_: diffusion_policy.env_runner.pusht_image_runner.PushTImageRunner
|
||||
fps: 10
|
||||
legacy_test: true
|
||||
max_steps: 300
|
||||
n_action_steps: 8
|
||||
n_envs: null
|
||||
n_obs_steps: 2
|
||||
n_test: 50
|
||||
n_test_vis: 4
|
||||
n_train: 6
|
||||
n_train_vis: 2
|
||||
past_action: false
|
||||
test_start_seed: 100000
|
||||
train_start_seed: 0
|
||||
image_shape:
|
||||
- 3
|
||||
- 96
|
||||
- 96
|
||||
name: pusht_image
|
||||
shape_meta:
|
||||
action:
|
||||
shape:
|
||||
- 2
|
||||
obs:
|
||||
agent_pos:
|
||||
shape:
|
||||
- 2
|
||||
type: low_dim
|
||||
image:
|
||||
shape:
|
||||
- 3
|
||||
- 96
|
||||
- 96
|
||||
type: rgb
|
||||
task_name: pusht_image
|
||||
training:
|
||||
checkpoint_every: 50
|
||||
debug: false
|
||||
device: cuda:0
|
||||
gradient_accumulate_every: 1
|
||||
lr_scheduler: cosine
|
||||
lr_warmup_steps: 500
|
||||
max_train_steps: null
|
||||
max_val_steps: null
|
||||
num_epochs: 3050
|
||||
resume: true
|
||||
rollout_every: 50
|
||||
sample_every: 5
|
||||
seed: 42
|
||||
tqdm_interval_sec: 1.0
|
||||
use_ema: true
|
||||
val_every: 1
|
||||
val_dataloader:
|
||||
batch_size: 64
|
||||
num_workers: 8
|
||||
persistent_workers: false
|
||||
pin_memory: true
|
||||
shuffle: false
|
||||
30
image_pusht_diffusion_policy_dit.yaml
Normal file
30
image_pusht_diffusion_policy_dit.yaml
Normal file
@@ -0,0 +1,30 @@
|
||||
defaults:
|
||||
- diffusion_policy/config/train_diffusion_transformer_hybrid_workspace@_here_
|
||||
- override /diffusion_policy/config/task@task: pusht_image
|
||||
- _self_
|
||||
|
||||
exp_name: pusht_image_dit
|
||||
|
||||
policy:
|
||||
_target_: diffusion_policy.policy.diffusion_transformer_hybrid_image_policy.DiffusionTransformerHybridImagePolicy
|
||||
|
||||
logging:
|
||||
backend: swanlab
|
||||
mode: online
|
||||
name: ${exp_name}
|
||||
resume: false
|
||||
tags: ["${name}", "${task_name}", "${exp_name}", "swanlab"]
|
||||
id: null
|
||||
group: ${exp_name}
|
||||
|
||||
dataloader:
|
||||
num_workers: 0
|
||||
|
||||
val_dataloader:
|
||||
num_workers: 0
|
||||
|
||||
task:
|
||||
env_runner:
|
||||
n_envs: 1
|
||||
n_test_vis: 0
|
||||
n_train_vis: 0
|
||||
32
image_pusht_diffusion_policy_dit_imf.yaml
Normal file
32
image_pusht_diffusion_policy_dit_imf.yaml
Normal file
@@ -0,0 +1,32 @@
|
||||
defaults:
|
||||
- diffusion_policy/config/train_diffusion_transformer_hybrid_workspace@_here_
|
||||
- override /diffusion_policy/config/task@task: pusht_image
|
||||
- _self_
|
||||
|
||||
exp_name: pusht_image_dit_imf
|
||||
|
||||
policy:
|
||||
_target_: diffusion_policy.policy.imf_transformer_hybrid_image_policy.IMFTransformerHybridImagePolicy
|
||||
num_inference_steps: 1
|
||||
n_head: 1
|
||||
|
||||
logging:
|
||||
backend: swanlab
|
||||
mode: online
|
||||
name: ${exp_name}
|
||||
resume: false
|
||||
tags: ["${name}", "${task_name}", "${exp_name}", "swanlab"]
|
||||
id: null
|
||||
group: ${exp_name}
|
||||
|
||||
dataloader:
|
||||
num_workers: 0
|
||||
|
||||
val_dataloader:
|
||||
num_workers: 0
|
||||
|
||||
task:
|
||||
env_runner:
|
||||
n_envs: 1
|
||||
n_test_vis: 0
|
||||
n_train_vis: 0
|
||||
35
image_pusht_diffusion_policy_dit_imf_attnres_full.yaml
Normal file
35
image_pusht_diffusion_policy_dit_imf_attnres_full.yaml
Normal file
@@ -0,0 +1,35 @@
|
||||
defaults:
|
||||
- diffusion_policy/config/train_diffusion_transformer_hybrid_workspace@_here_
|
||||
- override /diffusion_policy/config/task@task: pusht_image
|
||||
- _self_
|
||||
|
||||
exp_name: pusht_image_dit_imf_attnres_full
|
||||
|
||||
policy:
|
||||
_target_: diffusion_policy.policy.imf_transformer_hybrid_image_policy.IMFTransformerHybridImagePolicy
|
||||
num_inference_steps: 1
|
||||
n_head: 1
|
||||
n_kv_head: 1
|
||||
causal_attn: false
|
||||
backbone_type: attnres_full
|
||||
|
||||
logging:
|
||||
backend: swanlab
|
||||
mode: online
|
||||
name: ${exp_name}
|
||||
resume: false
|
||||
tags: ["${name}", "${task_name}", "${exp_name}", "swanlab"]
|
||||
id: null
|
||||
group: ${exp_name}
|
||||
|
||||
dataloader:
|
||||
num_workers: 0
|
||||
|
||||
val_dataloader:
|
||||
num_workers: 0
|
||||
|
||||
task:
|
||||
env_runner:
|
||||
n_envs: 1
|
||||
n_test_vis: 0
|
||||
n_train_vis: 0
|
||||
33
image_pusht_diffusion_policy_dit_imf_fullattn.yaml
Normal file
33
image_pusht_diffusion_policy_dit_imf_fullattn.yaml
Normal file
@@ -0,0 +1,33 @@
|
||||
defaults:
|
||||
- diffusion_policy/config/train_diffusion_transformer_hybrid_workspace@_here_
|
||||
- override /diffusion_policy/config/task@task: pusht_image
|
||||
- _self_
|
||||
|
||||
exp_name: pusht_image_dit_imf_fullattn
|
||||
|
||||
policy:
|
||||
_target_: diffusion_policy.policy.imf_transformer_hybrid_image_policy.IMFTransformerHybridImagePolicy
|
||||
num_inference_steps: 1
|
||||
n_head: 1
|
||||
causal_attn: false
|
||||
|
||||
logging:
|
||||
backend: swanlab
|
||||
mode: online
|
||||
name: ${exp_name}
|
||||
resume: false
|
||||
tags: ["${name}", "${task_name}", "${exp_name}", "swanlab"]
|
||||
id: null
|
||||
group: ${exp_name}
|
||||
|
||||
dataloader:
|
||||
num_workers: 0
|
||||
|
||||
val_dataloader:
|
||||
num_workers: 0
|
||||
|
||||
task:
|
||||
env_runner:
|
||||
n_envs: 1
|
||||
n_test_vis: 0
|
||||
n_train_vis: 0
|
||||
39
requirements-pusht-5090.txt
Normal file
39
requirements-pusht-5090.txt
Normal file
@@ -0,0 +1,39 @@
|
||||
# Direct package pins for the canonical PushT image workflow on host 5090.
|
||||
# Torch/TorchVision/Torchaudio are installed separately from the cu128 index in setup_uv_pusht_5090.sh.
|
||||
|
||||
numpy==1.26.4
|
||||
scipy==1.11.4
|
||||
numba==0.59.1
|
||||
llvmlite==0.42.0
|
||||
cffi==1.15.1
|
||||
cython==0.29.32
|
||||
h5py==3.8.0
|
||||
pandas==2.2.3
|
||||
zarr==2.12.0
|
||||
numcodecs==0.10.2
|
||||
hydra-core==1.2.0
|
||||
einops==0.4.1
|
||||
tqdm==4.64.1
|
||||
dill==0.3.5.1
|
||||
scikit-video==1.1.11
|
||||
scikit-image==0.19.3
|
||||
gym==0.23.1
|
||||
pymunk==6.2.1
|
||||
wandb==0.13.3
|
||||
threadpoolctl==3.1.0
|
||||
shapely==1.8.5.post1
|
||||
imageio==2.22.0
|
||||
imageio-ffmpeg==0.4.7
|
||||
termcolor==2.0.1
|
||||
tensorboard==2.10.1
|
||||
tensorboardx==2.5.1
|
||||
psutil==7.2.2
|
||||
click==8.1.8
|
||||
boto3==1.24.96
|
||||
diffusers==0.11.1
|
||||
huggingface-hub==0.10.1
|
||||
av==14.0.1
|
||||
pygame==2.5.2
|
||||
robomimic==0.2.0
|
||||
opencv-python-headless==4.10.0.84
|
||||
swanlab
|
||||
29
scripts/pusht/imf_attnres_local_queue.sh
Executable file
29
scripts/pusht/imf_attnres_local_queue.sh
Executable file
@@ -0,0 +1,29 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
cd /home/droid/project/diffusion_policy/.worktrees/feat-pusht-imf-attnres
|
||||
export PYTHONUNBUFFERED=1
|
||||
export SWANLAB_API_KEY='PSZrBMLx1XAjDjvmhUcNz'
|
||||
export LD_LIBRARY_PATH="$(printf '%s:' .venv/lib/python3.9/site-packages/nvidia/*/lib | sed 's/:$//')"
|
||||
run_exp() {
|
||||
local name="$1" emb="$2" layer="$3"
|
||||
echo "[$(date '+%F %T')] START $name emb=$emb layer=$layer"
|
||||
.venv/bin/python train.py \
|
||||
--config-dir=. \
|
||||
--config-name=image_pusht_diffusion_policy_dit_imf_attnres_full.yaml \
|
||||
training.device=cuda:0 \
|
||||
training.num_epochs=350 \
|
||||
training.resume=false \
|
||||
exp_name="$name" \
|
||||
logging.group=imf_pusht_attnres_arch_sweep \
|
||||
logging.name="$name" \
|
||||
logging.resume=false \
|
||||
logging.id=null \
|
||||
hydra.run.dir="data/outputs/$name" \
|
||||
policy.n_emb="$emb" \
|
||||
policy.n_layer="$layer" \
|
||||
> "data/run_logs/${name}.log" 2>&1
|
||||
echo "[$(date '+%F %T')] END $name"
|
||||
}
|
||||
run_exp imf_attnres_emb384_layer18_seed42_local 384 18
|
||||
run_exp imf_attnres_emb256_layer6_seed42_local 256 6
|
||||
run_exp imf_attnres_emb128_layer6_seed42_local 128 6
|
||||
29
scripts/pusht/imf_attnres_remote_gpu0_queue.sh
Executable file
29
scripts/pusht/imf_attnres_remote_gpu0_queue.sh
Executable file
@@ -0,0 +1,29 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
cd /home/droid/project/diffusion_policy-smoke
|
||||
export PYTHONUNBUFFERED=1
|
||||
export SWANLAB_API_KEY='PSZrBMLx1XAjDjvmhUcNz'
|
||||
export LD_LIBRARY_PATH="$(printf '%s:' .venv/lib/python3.9/site-packages/nvidia/*/lib | sed 's/:$//')"
|
||||
run_exp() {
|
||||
local name="$1" emb="$2" layer="$3"
|
||||
echo "[$(date '+%F %T')] START $name emb=$emb layer=$layer"
|
||||
.venv/bin/python train.py \
|
||||
--config-dir=. \
|
||||
--config-name=image_pusht_diffusion_policy_dit_imf_attnres_full.yaml \
|
||||
training.device=cuda:0 \
|
||||
training.num_epochs=350 \
|
||||
training.resume=false \
|
||||
exp_name="$name" \
|
||||
logging.group=imf_pusht_attnres_arch_sweep \
|
||||
logging.name="$name" \
|
||||
logging.resume=false \
|
||||
logging.id=null \
|
||||
hydra.run.dir="data/outputs/$name" \
|
||||
policy.n_emb="$emb" \
|
||||
policy.n_layer="$layer" \
|
||||
> "data/run_logs/${name}.log" 2>&1
|
||||
echo "[$(date '+%F %T')] END $name"
|
||||
}
|
||||
run_exp imf_attnres_emb384_layer12_seed42_5880gpu0 384 12
|
||||
run_exp imf_attnres_emb256_layer12_seed42_5880gpu0 256 12
|
||||
run_exp imf_attnres_emb128_layer12_seed42_5880gpu0 128 12
|
||||
29
scripts/pusht/imf_attnres_remote_gpu1_queue.sh
Executable file
29
scripts/pusht/imf_attnres_remote_gpu1_queue.sh
Executable file
@@ -0,0 +1,29 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
cd /home/droid/project/diffusion_policy-smoke
|
||||
export PYTHONUNBUFFERED=1
|
||||
export SWANLAB_API_KEY='PSZrBMLx1XAjDjvmhUcNz'
|
||||
export LD_LIBRARY_PATH="$(printf '%s:' .venv/lib/python3.9/site-packages/nvidia/*/lib | sed 's/:$//')"
|
||||
run_exp() {
|
||||
local name="$1" emb="$2" layer="$3"
|
||||
echo "[$(date '+%F %T')] START $name emb=$emb layer=$layer"
|
||||
.venv/bin/python train.py \
|
||||
--config-dir=. \
|
||||
--config-name=image_pusht_diffusion_policy_dit_imf_attnres_full.yaml \
|
||||
training.device=cuda:1 \
|
||||
training.num_epochs=350 \
|
||||
training.resume=false \
|
||||
exp_name="$name" \
|
||||
logging.group=imf_pusht_attnres_arch_sweep \
|
||||
logging.name="$name" \
|
||||
logging.resume=false \
|
||||
logging.id=null \
|
||||
hydra.run.dir="data/outputs/$name" \
|
||||
policy.n_emb="$emb" \
|
||||
policy.n_layer="$layer" \
|
||||
> "data/run_logs/${name}.log" 2>&1
|
||||
echo "[$(date '+%F %T')] END $name"
|
||||
}
|
||||
run_exp imf_attnres_emb384_layer6_seed42_5880gpu1 384 6
|
||||
run_exp imf_attnres_emb256_layer18_seed42_5880gpu1 256 18
|
||||
run_exp imf_attnres_emb128_layer18_seed42_5880gpu1 128 18
|
||||
20
setup_uv_pusht_5090.sh
Executable file
20
setup_uv_pusht_5090.sh
Executable file
@@ -0,0 +1,20 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
ROOT_DIR="$(cd "$(dirname "$0")" && pwd)"
|
||||
cd "$ROOT_DIR"
|
||||
|
||||
export UV_CACHE_DIR="${UV_CACHE_DIR:-$ROOT_DIR/.uv-cache}"
|
||||
export UV_PYTHON_INSTALL_DIR="${UV_PYTHON_INSTALL_DIR:-$ROOT_DIR/.uv-python}"
|
||||
|
||||
uv venv --python 3.9 .venv
|
||||
source .venv/bin/activate
|
||||
|
||||
uv pip install --upgrade pip wheel setuptools==80.9.0
|
||||
uv pip install --python .venv/bin/python \
|
||||
--index-url https://download.pytorch.org/whl/cu128 \
|
||||
torch==2.8.0+cu128 torchvision==0.23.0+cu128 torchaudio==2.8.0+cu128
|
||||
uv pip install --python .venv/bin/python -r requirements-pusht-5090.txt
|
||||
uv pip install --python .venv/bin/python -e .
|
||||
|
||||
echo "uv environment ready at $ROOT_DIR/.venv"
|
||||
77
tests/test_imf_transformer_for_diffusion.py
Normal file
77
tests/test_imf_transformer_for_diffusion.py
Normal file
@@ -0,0 +1,77 @@
|
||||
import inspect
|
||||
import pathlib
|
||||
import sys
|
||||
|
||||
import torch
|
||||
|
||||
ROOT_DIR = pathlib.Path(__file__).resolve().parents[1]
|
||||
if str(ROOT_DIR) not in sys.path:
|
||||
sys.path.append(str(ROOT_DIR))
|
||||
|
||||
from diffusion_policy.model.diffusion.imf_transformer_for_diffusion import ( # noqa: E402
|
||||
IMFTransformerForDiffusion,
|
||||
)
|
||||
|
||||
|
||||
def test_imf_transformer_forward_signature_and_shape_single_head():
|
||||
signature = inspect.signature(IMFTransformerForDiffusion.forward)
|
||||
assert list(signature.parameters)[:5] == ['self', 'sample', 'r', 't', 'cond']
|
||||
assert signature.parameters['cond'].default is None
|
||||
|
||||
model = IMFTransformerForDiffusion(
|
||||
input_dim=3,
|
||||
output_dim=3,
|
||||
horizon=5,
|
||||
n_obs_steps=2,
|
||||
cond_dim=4,
|
||||
n_layer=1,
|
||||
n_head=1,
|
||||
n_emb=16,
|
||||
p_drop_emb=0.0,
|
||||
p_drop_attn=0.0,
|
||||
causal_attn=True,
|
||||
time_as_cond=True,
|
||||
obs_as_cond=True,
|
||||
n_cond_layers=0,
|
||||
)
|
||||
model.configure_optimizers()
|
||||
|
||||
sample = torch.randn(2, 5, 3)
|
||||
r = torch.rand(2)
|
||||
t = torch.rand(2)
|
||||
cond = torch.randn(2, 2, 4)
|
||||
|
||||
pred_u = model(sample, r, t, cond=cond)
|
||||
|
||||
assert pred_u.shape == sample.shape
|
||||
|
||||
|
||||
def test_imf_transformer_attnres_full_backbone_forward_shape_and_optimizer():
|
||||
model = IMFTransformerForDiffusion(
|
||||
input_dim=3,
|
||||
output_dim=3,
|
||||
horizon=5,
|
||||
n_obs_steps=2,
|
||||
cond_dim=4,
|
||||
n_layer=2,
|
||||
n_head=1,
|
||||
n_emb=16,
|
||||
p_drop_emb=0.0,
|
||||
p_drop_attn=0.0,
|
||||
causal_attn=False,
|
||||
time_as_cond=True,
|
||||
obs_as_cond=True,
|
||||
n_cond_layers=0,
|
||||
backbone_type='attnres_full',
|
||||
)
|
||||
optimizer = model.configure_optimizers()
|
||||
|
||||
sample = torch.randn(2, 5, 3)
|
||||
r = torch.rand(2)
|
||||
t = torch.rand(2)
|
||||
cond = torch.randn(2, 2, 4)
|
||||
|
||||
pred_u = model(sample, r, t, cond=cond)
|
||||
|
||||
assert pred_u.shape == sample.shape
|
||||
assert optimizer is not None
|
||||
313
tests/test_imf_transformer_hybrid_image_policy.py
Normal file
313
tests/test_imf_transformer_hybrid_image_policy.py
Normal file
@@ -0,0 +1,313 @@
|
||||
import pathlib
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
ROOT_DIR = pathlib.Path(__file__).resolve().parents[1]
|
||||
if str(ROOT_DIR) not in sys.path:
|
||||
sys.path.append(str(ROOT_DIR))
|
||||
|
||||
import diffusion_policy.policy.imf_transformer_hybrid_image_policy as policy_module # noqa: E402
|
||||
from diffusion_policy.model.common.module_attr_mixin import ModuleAttrMixin # noqa: E402
|
||||
from diffusion_policy.policy.imf_transformer_hybrid_image_policy import ( # noqa: E402
|
||||
IMFTransformerHybridImagePolicy,
|
||||
)
|
||||
|
||||
|
||||
class ConstantModel(nn.Module):
|
||||
def __init__(self, value):
|
||||
super().__init__()
|
||||
self.value = value
|
||||
|
||||
def forward(self, sample, r, t, cond=None):
|
||||
return torch.full_like(sample, self.value)
|
||||
|
||||
|
||||
class AffineModel(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.tensor(2.0))
|
||||
|
||||
def forward(self, sample, r, t, cond=None):
|
||||
return sample * self.weight + (r + t).view(-1, 1, 1)
|
||||
|
||||
|
||||
class SumMixModel(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.tensor(2.0))
|
||||
|
||||
def forward(self, sample, r, t, cond=None):
|
||||
mixed = sample.sum(dim=-1, keepdim=True).expand_as(sample)
|
||||
return mixed * self.weight + t.view(-1, 1, 1)
|
||||
|
||||
|
||||
class TrackingContext:
|
||||
def __init__(self):
|
||||
self.active = False
|
||||
self.enter_count = 0
|
||||
|
||||
def __enter__(self):
|
||||
self.active = True
|
||||
self.enter_count += 1
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
self.active = False
|
||||
return False
|
||||
|
||||
|
||||
def make_policy(model):
|
||||
policy = IMFTransformerHybridImagePolicy.__new__(IMFTransformerHybridImagePolicy)
|
||||
ModuleAttrMixin.__init__(policy)
|
||||
policy.model = model
|
||||
return policy
|
||||
|
||||
|
||||
def fake_parent_init(
|
||||
self,
|
||||
shape_meta,
|
||||
noise_scheduler,
|
||||
horizon,
|
||||
n_action_steps,
|
||||
n_obs_steps,
|
||||
num_inference_steps=None,
|
||||
crop_shape=(76, 76),
|
||||
obs_encoder_group_norm=False,
|
||||
eval_fixed_crop=False,
|
||||
n_layer=8,
|
||||
n_cond_layers=0,
|
||||
n_head=1,
|
||||
n_emb=256,
|
||||
p_drop_emb=0.0,
|
||||
p_drop_attn=0.3,
|
||||
causal_attn=True,
|
||||
time_as_cond=True,
|
||||
obs_as_cond=True,
|
||||
pred_action_steps_only=False,
|
||||
**kwargs,
|
||||
):
|
||||
ModuleAttrMixin.__init__(self)
|
||||
self.action_dim = shape_meta['action']['shape'][0]
|
||||
self.obs_feature_dim = 4
|
||||
self.obs_as_cond = obs_as_cond
|
||||
self.pred_action_steps_only = pred_action_steps_only
|
||||
self.n_action_steps = n_action_steps
|
||||
self.n_obs_steps = n_obs_steps
|
||||
self.horizon = horizon
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def shape_meta():
|
||||
return {
|
||||
'action': {'shape': [2]},
|
||||
'obs': {},
|
||||
}
|
||||
|
||||
|
||||
def test_sample_one_step_uses_imf_update_formula():
|
||||
policy = make_policy(ConstantModel(0.25))
|
||||
z_1 = torch.tensor([
|
||||
[[1.0, -1.0], [0.5, 0.0]],
|
||||
[[2.0, 3.0], [-2.0, 4.0]],
|
||||
])
|
||||
r = torch.zeros(z_1.shape[0])
|
||||
t = torch.ones(z_1.shape[0])
|
||||
|
||||
x_hat = policy._sample_one_step(z_1, r=r, t=t, cond=None)
|
||||
|
||||
expected = z_1 - (t - r).view(-1, 1, 1) * 0.25
|
||||
assert torch.allclose(x_hat, expected)
|
||||
|
||||
|
||||
def test_compound_velocity_uses_detached_du_dt_term():
|
||||
policy = make_policy(ConstantModel(0.0))
|
||||
u = torch.tensor([[[1.0], [2.0]]], requires_grad=True)
|
||||
du_dt = torch.tensor([[[3.0], [4.0]]], requires_grad=True)
|
||||
r = torch.tensor([0.2])
|
||||
t = torch.tensor([0.8])
|
||||
|
||||
compound = policy._compound_velocity(u, du_dt, r, t)
|
||||
expected = u + (t - r).view(-1, 1, 1) * du_dt.detach()
|
||||
|
||||
assert torch.allclose(compound, expected)
|
||||
|
||||
compound.sum().backward()
|
||||
assert u.grad is not None
|
||||
assert du_dt.grad is None
|
||||
|
||||
|
||||
def test_compute_u_and_du_dt_uses_math_sdpa_context_for_torch_func_jvp(monkeypatch):
|
||||
tracker = TrackingContext()
|
||||
|
||||
def fake_jvp(fn, primals, tangents):
|
||||
assert tracker.active is True
|
||||
return fn(*primals), torch.zeros_like(primals[0])
|
||||
|
||||
monkeypatch.setattr(policy_module, 'TORCH_FUNC_JVP', fake_jvp)
|
||||
|
||||
policy = make_policy(ConstantModel(0.5))
|
||||
policy._jvp_math_sdp_context = lambda tensor: tracker
|
||||
z_t = torch.randn(2, 3, 4)
|
||||
r = torch.rand(2, requires_grad=True)
|
||||
t = torch.rand(2, requires_grad=True)
|
||||
v = torch.randn_like(z_t, requires_grad=True)
|
||||
|
||||
policy._compute_u_and_du_dt(z_t, r, t, cond=None, v=v)
|
||||
|
||||
assert tracker.enter_count == 1
|
||||
|
||||
|
||||
def test_compute_u_and_du_dt_uses_math_sdpa_context_for_autograd_fallback(monkeypatch):
|
||||
tracker = TrackingContext()
|
||||
|
||||
def fake_autograd_jvp(fn, primals, tangents, create_graph=False, strict=False):
|
||||
assert tracker.active is True
|
||||
return fn(*primals), torch.zeros_like(primals[0])
|
||||
|
||||
monkeypatch.setattr(policy_module, 'TORCH_FUNC_JVP', None)
|
||||
monkeypatch.setattr(policy_module.torch.autograd.functional, 'jvp', fake_autograd_jvp)
|
||||
|
||||
policy = make_policy(ConstantModel(0.5))
|
||||
policy._jvp_math_sdp_context = lambda tensor: tracker
|
||||
z_t = torch.randn(2, 3, 4)
|
||||
r = torch.rand(2, requires_grad=True)
|
||||
t = torch.rand(2, requires_grad=True)
|
||||
v = torch.randn_like(z_t, requires_grad=True)
|
||||
|
||||
policy._compute_u_and_du_dt(z_t, r, t, cond=None, v=v)
|
||||
|
||||
assert tracker.enter_count == 1
|
||||
|
||||
|
||||
def test_compute_u_and_du_dt_uses_detached_v_zero_r_unit_t_and_reapplies_conditioning(monkeypatch):
|
||||
captured = {}
|
||||
|
||||
def fake_jvp(fn, primals, tangents):
|
||||
captured['tangents'] = tangents
|
||||
captured['primal_output'] = fn(*primals)
|
||||
return captured['primal_output'], torch.zeros_like(primals[0])
|
||||
|
||||
monkeypatch.setattr(policy_module, 'TORCH_FUNC_JVP', fake_jvp)
|
||||
|
||||
policy = make_policy(SumMixModel())
|
||||
z_t = torch.tensor([[[1.0, 2.0, 3.0]]])
|
||||
r = torch.rand(1, requires_grad=True)
|
||||
t = torch.rand(1, requires_grad=True)
|
||||
v = torch.tensor([[[10.0, 20.0, 30.0]]], requires_grad=True)
|
||||
condition_mask = torch.tensor([[[False, True, False]]])
|
||||
condition_data = torch.tensor([[[0.0, 7.0, 0.0]]])
|
||||
|
||||
policy._compute_u_and_du_dt(
|
||||
z_t,
|
||||
r,
|
||||
t,
|
||||
cond=None,
|
||||
v=v,
|
||||
condition_data=condition_data,
|
||||
condition_mask=condition_mask,
|
||||
)
|
||||
|
||||
tangent_v, tangent_r, tangent_t = captured['tangents']
|
||||
assert torch.equal(tangent_v, v.detach())
|
||||
assert tangent_v.requires_grad is False
|
||||
assert torch.equal(tangent_r, torch.zeros_like(r))
|
||||
assert torch.equal(tangent_t, torch.ones_like(t))
|
||||
|
||||
conditioned = z_t.clone()
|
||||
conditioned[condition_mask] = condition_data[condition_mask]
|
||||
expected_primal = policy.model(conditioned, r, t, cond=None)
|
||||
assert torch.allclose(captured['primal_output'], expected_primal)
|
||||
|
||||
|
||||
def test_compute_u_and_du_dt_fallback_blocks_conditioned_tangent_leakage_and_keeps_primal_gradients(monkeypatch):
|
||||
monkeypatch.setattr(policy_module, 'TORCH_FUNC_JVP', None)
|
||||
|
||||
policy = make_policy(SumMixModel())
|
||||
z_t = torch.tensor([[[1.0, 2.0, 3.0]]], requires_grad=True)
|
||||
r = torch.rand(1, requires_grad=True)
|
||||
t = torch.rand(1, requires_grad=True)
|
||||
v = torch.tensor([[[1.0, 10.0, 100.0]]], requires_grad=True)
|
||||
condition_mask = torch.tensor([[[False, True, False]]])
|
||||
condition_data = torch.tensor([[[0.0, 7.0, 0.0]]])
|
||||
|
||||
u, du_dt = policy._compute_u_and_du_dt(
|
||||
z_t,
|
||||
r,
|
||||
t,
|
||||
cond=None,
|
||||
v=v,
|
||||
condition_data=condition_data,
|
||||
condition_mask=condition_mask,
|
||||
)
|
||||
|
||||
conditioned = z_t.detach().clone()
|
||||
conditioned[condition_mask] = condition_data[condition_mask]
|
||||
expected_u = policy.model(conditioned, r, t, cond=None)
|
||||
expected_du_dt_scalar = policy.model.weight.detach() * torch.tensor(101.0) + 1.0
|
||||
expected_du_dt = torch.full_like(z_t, expected_du_dt_scalar)
|
||||
|
||||
assert u.shape == z_t.shape
|
||||
assert du_dt.shape == z_t.shape
|
||||
assert torch.allclose(u, expected_u)
|
||||
assert torch.allclose(du_dt, expected_du_dt)
|
||||
|
||||
u.sum().backward()
|
||||
assert policy.model.weight.grad is not None
|
||||
assert torch.count_nonzero(policy.model.weight.grad) > 0
|
||||
|
||||
|
||||
def test_init_uses_action_step_horizon_when_pred_action_steps_only(monkeypatch, shape_meta):
|
||||
monkeypatch.setattr(
|
||||
policy_module.DiffusionTransformerHybridImagePolicy,
|
||||
'__init__',
|
||||
fake_parent_init,
|
||||
)
|
||||
|
||||
policy = IMFTransformerHybridImagePolicy(
|
||||
shape_meta=shape_meta,
|
||||
noise_scheduler=None,
|
||||
horizon=10,
|
||||
n_action_steps=4,
|
||||
n_obs_steps=2,
|
||||
num_inference_steps=1,
|
||||
n_layer=1,
|
||||
n_head=1,
|
||||
n_emb=16,
|
||||
p_drop_emb=0.0,
|
||||
p_drop_attn=0.0,
|
||||
causal_attn=True,
|
||||
obs_as_cond=True,
|
||||
pred_action_steps_only=True,
|
||||
)
|
||||
|
||||
assert policy.model.horizon == 4
|
||||
assert policy.num_inference_steps == 1
|
||||
|
||||
|
||||
def test_init_rejects_non_one_step_inference(monkeypatch, shape_meta):
|
||||
monkeypatch.setattr(
|
||||
policy_module.DiffusionTransformerHybridImagePolicy,
|
||||
'__init__',
|
||||
fake_parent_init,
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match='num_inference_steps'):
|
||||
IMFTransformerHybridImagePolicy(
|
||||
shape_meta=shape_meta,
|
||||
noise_scheduler=None,
|
||||
horizon=10,
|
||||
n_action_steps=4,
|
||||
n_obs_steps=2,
|
||||
num_inference_steps=2,
|
||||
n_layer=1,
|
||||
n_head=1,
|
||||
n_emb=16,
|
||||
p_drop_emb=0.0,
|
||||
p_drop_attn=0.0,
|
||||
causal_attn=True,
|
||||
obs_as_cond=True,
|
||||
pred_action_steps_only=False,
|
||||
)
|
||||
110
tests/test_pusht_image_runner_metrics.py
Normal file
110
tests/test_pusht_image_runner_metrics.py
Normal file
@@ -0,0 +1,110 @@
|
||||
import pathlib
|
||||
import sys
|
||||
|
||||
import gym
|
||||
from gym import spaces
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
ROOT_DIR = pathlib.Path(__file__).resolve().parents[1]
|
||||
if str(ROOT_DIR) not in sys.path:
|
||||
sys.path.append(str(ROOT_DIR))
|
||||
|
||||
import diffusion_policy.env_runner.pusht_image_runner as runner_module
|
||||
from diffusion_policy.env_runner.pusht_image_runner import summarize_rollout_metrics
|
||||
|
||||
|
||||
class FakePushTImageEnv(gym.Env):
|
||||
metadata = {'render.modes': ['rgb_array']}
|
||||
|
||||
def __init__(self, legacy=False, render_size=96):
|
||||
del legacy, render_size
|
||||
self.observation_space = spaces.Dict({
|
||||
'image': spaces.Box(low=0, high=255, shape=(3, 4, 4), dtype=np.uint8),
|
||||
})
|
||||
self.action_space = spaces.Box(low=-1.0, high=1.0, shape=(2,), dtype=np.float32)
|
||||
self.seed_value = 0
|
||||
self.step_count = 0
|
||||
|
||||
def seed(self, seed=None):
|
||||
self.seed_value = 0 if seed is None else seed
|
||||
|
||||
def reset(self):
|
||||
self.step_count = 0
|
||||
return {'image': np.zeros((3, 4, 4), dtype=np.uint8)}
|
||||
|
||||
def step(self, action):
|
||||
del action
|
||||
self.step_count += 1
|
||||
reward = 0.1 if self.seed_value < 10000 else 0.9
|
||||
done = self.step_count >= 1
|
||||
obs = {'image': np.full((3, 4, 4), self.step_count, dtype=np.uint8)}
|
||||
return obs, reward, done, {}
|
||||
|
||||
def render(self, *args, **kwargs):
|
||||
raise AssertionError('render should not be called for scalar-only PushT image rollouts')
|
||||
|
||||
|
||||
class FakePolicy:
|
||||
device = torch.device('cpu')
|
||||
dtype = torch.float32
|
||||
|
||||
def reset(self):
|
||||
return None
|
||||
|
||||
def predict_action(self, obs_dict):
|
||||
n_envs = next(iter(obs_dict.values())).shape[0]
|
||||
return {
|
||||
'action': torch.zeros((n_envs, 2, 2), dtype=torch.float32),
|
||||
}
|
||||
|
||||
|
||||
def test_summarize_rollout_metrics_keeps_scalar_rewards_renames_means_and_omits_videos():
|
||||
log_data = summarize_rollout_metrics(
|
||||
env_seeds=[11, 12, 101],
|
||||
env_prefixs=['train/', 'train/', 'test/'],
|
||||
all_rewards=[
|
||||
[0.2, 0.8],
|
||||
[0.1, 0.4],
|
||||
[0.5, 0.9],
|
||||
],
|
||||
all_video_paths=[
|
||||
'/tmp/train-11.mp4',
|
||||
'/tmp/train-12.mp4',
|
||||
'/tmp/test-101.mp4',
|
||||
],
|
||||
)
|
||||
|
||||
assert log_data['train/sim_max_reward_11'] == 0.8
|
||||
assert log_data['train/sim_max_reward_12'] == 0.4
|
||||
assert log_data['test/sim_max_reward_101'] == 0.9
|
||||
assert log_data['train_mean_score'] == pytest.approx(0.6)
|
||||
assert log_data['test_mean_score'] == pytest.approx(0.9)
|
||||
assert not any(key.startswith('train/sim_video_') for key in log_data)
|
||||
assert not any(key.startswith('test/sim_video_') for key in log_data)
|
||||
|
||||
|
||||
def test_runner_ignores_vis_flags_and_never_emits_sim_videos(tmp_path, monkeypatch):
|
||||
monkeypatch.setattr(runner_module, 'PushTImageEnv', FakePushTImageEnv)
|
||||
|
||||
runner = runner_module.PushTImageRunner(
|
||||
output_dir=tmp_path,
|
||||
n_train=1,
|
||||
n_train_vis=1,
|
||||
n_test=1,
|
||||
n_test_vis=1,
|
||||
n_envs=2,
|
||||
max_steps=2,
|
||||
n_obs_steps=2,
|
||||
n_action_steps=2,
|
||||
tqdm_interval_sec=0.0,
|
||||
)
|
||||
|
||||
log_data = runner.run(FakePolicy())
|
||||
|
||||
assert log_data['train/sim_max_reward_0'] == pytest.approx(0.1)
|
||||
assert log_data['test/sim_max_reward_10000'] == pytest.approx(0.9)
|
||||
assert log_data['train_mean_score'] == pytest.approx(0.1)
|
||||
assert log_data['test_mean_score'] == pytest.approx(0.9)
|
||||
assert not any('sim_video' in key for key in log_data)
|
||||
57
tests/test_pusht_swanlab_config.py
Normal file
57
tests/test_pusht_swanlab_config.py
Normal file
@@ -0,0 +1,57 @@
|
||||
import pathlib
|
||||
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
|
||||
ROOT_DIR = pathlib.Path(__file__).resolve().parents[1]
|
||||
|
||||
|
||||
def _load_cfg(name: str):
|
||||
return OmegaConf.load(ROOT_DIR / name)
|
||||
|
||||
|
||||
def test_image_pusht_dit_swanlab_config_uses_exp_name_and_no_resume_collision():
|
||||
cfg = _load_cfg('image_pusht_diffusion_policy_dit.yaml')
|
||||
|
||||
assert cfg.logging.backend == 'swanlab'
|
||||
assert cfg.logging.mode == 'online'
|
||||
assert cfg.logging.name == cfg.exp_name
|
||||
assert cfg.logging.resume is False
|
||||
assert cfg.logging.id is None
|
||||
assert cfg.logging.group == cfg.exp_name
|
||||
|
||||
|
||||
def test_image_pusht_dit_imf_swanlab_config_uses_exp_name_and_no_resume_collision():
|
||||
cfg = _load_cfg('image_pusht_diffusion_policy_dit_imf.yaml')
|
||||
|
||||
assert cfg.logging.backend == 'swanlab'
|
||||
assert cfg.logging.mode == 'online'
|
||||
assert cfg.logging.name == cfg.exp_name
|
||||
assert cfg.logging.resume is False
|
||||
assert cfg.logging.id is None
|
||||
assert cfg.logging.group == cfg.exp_name
|
||||
|
||||
|
||||
def test_image_pusht_dit_imf_fullattn_config_uses_exp_name_and_disables_causal_attention():
|
||||
cfg = _load_cfg('image_pusht_diffusion_policy_dit_imf_fullattn.yaml')
|
||||
|
||||
assert cfg.logging.backend == 'swanlab'
|
||||
assert cfg.logging.mode == 'online'
|
||||
assert cfg.logging.name == cfg.exp_name
|
||||
assert cfg.logging.resume is False
|
||||
assert cfg.logging.id is None
|
||||
assert cfg.logging.group == cfg.exp_name
|
||||
assert cfg.policy.causal_attn is False
|
||||
|
||||
|
||||
def test_image_pusht_dit_imf_attnres_full_config_uses_exp_name_and_disables_causal_attention():
|
||||
cfg = _load_cfg('image_pusht_diffusion_policy_dit_imf_attnres_full.yaml')
|
||||
|
||||
assert cfg.logging.backend == 'swanlab'
|
||||
assert cfg.logging.mode == 'online'
|
||||
assert cfg.logging.name == cfg.exp_name
|
||||
assert cfg.logging.resume is False
|
||||
assert cfg.logging.id is None
|
||||
assert cfg.logging.group == cfg.exp_name
|
||||
assert cfg.policy.causal_attn is False
|
||||
assert cfg.policy.backbone_type == 'attnres_full'
|
||||
198
tests/test_train_diffusion_transformer_workspace_logging.py
Normal file
198
tests/test_train_diffusion_transformer_workspace_logging.py
Normal file
@@ -0,0 +1,198 @@
|
||||
import importlib
|
||||
import pathlib
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
ROOT_DIR = pathlib.Path(__file__).resolve().parents[1]
|
||||
if str(ROOT_DIR) not in sys.path:
|
||||
sys.path.append(str(ROOT_DIR))
|
||||
|
||||
MODULE_NAME = 'diffusion_policy.workspace.train_diffusion_transformer_hybrid_workspace'
|
||||
|
||||
|
||||
def load_workspace_module(monkeypatch, *, wandb_missing=False):
|
||||
sys.modules.pop(MODULE_NAME, None)
|
||||
if wandb_missing:
|
||||
monkeypatch.setitem(sys.modules, 'wandb', None)
|
||||
return importlib.import_module(MODULE_NAME)
|
||||
|
||||
|
||||
def test_init_logger_uses_swanlab_backend_mapping_without_loading_wandb(tmp_path, monkeypatch):
|
||||
workspace_module = load_workspace_module(monkeypatch, wandb_missing=True)
|
||||
events = []
|
||||
|
||||
class FakeRun:
|
||||
def log(self, payload, step=None):
|
||||
events.append(('log', payload, step))
|
||||
|
||||
def finish(self):
|
||||
events.append(('finish',))
|
||||
|
||||
class FakeSwanLab:
|
||||
def init(self, **kwargs):
|
||||
events.append(('init', kwargs))
|
||||
return FakeRun()
|
||||
|
||||
monkeypatch.setattr(workspace_module, '_load_swanlab', lambda: FakeSwanLab())
|
||||
monkeypatch.setattr(
|
||||
workspace_module,
|
||||
'_load_wandb',
|
||||
lambda: pytest.fail('wandb should not be loaded for the SwanLab backend'),
|
||||
)
|
||||
|
||||
cfg = OmegaConf.create({
|
||||
'logging': {
|
||||
'backend': 'swanlab',
|
||||
'project': 'demo-project',
|
||||
'name': 'demo-run',
|
||||
'group': 'demo-group',
|
||||
'tags': ['pusht', 'dit'],
|
||||
'id': 'run-123',
|
||||
'resume': True,
|
||||
'mode': 'online',
|
||||
}
|
||||
})
|
||||
|
||||
logger = workspace_module.init_logging_backend(cfg=cfg, output_dir=tmp_path)
|
||||
logger.log({'metric': 1.0}, step=7)
|
||||
logger.finish()
|
||||
|
||||
assert events[0][0] == 'init'
|
||||
init_kwargs = events[0][1]
|
||||
assert init_kwargs['project'] == 'demo-project'
|
||||
assert init_kwargs['experiment_name'] == 'demo-run'
|
||||
assert init_kwargs['group'] == 'demo-group'
|
||||
assert init_kwargs['tags'] == ['pusht', 'dit']
|
||||
assert init_kwargs['id'] == 'run-123'
|
||||
assert init_kwargs['resume'] is True
|
||||
assert init_kwargs['mode'] == 'cloud'
|
||||
assert init_kwargs['logdir'] == str(tmp_path / 'swanlog')
|
||||
assert ('log', {'metric': 1.0}, 7) in events
|
||||
assert events.count(('finish',)) == 1
|
||||
|
||||
|
||||
def test_init_logger_defaults_to_legacy_wandb_path_when_backend_missing(tmp_path, monkeypatch):
|
||||
workspace_module = load_workspace_module(monkeypatch)
|
||||
events = []
|
||||
|
||||
class FakeRun:
|
||||
def log(self, payload, step=None):
|
||||
events.append(('log', payload, step))
|
||||
|
||||
def finish(self):
|
||||
events.append(('finish',))
|
||||
|
||||
class FakeConfig:
|
||||
def update(self, payload):
|
||||
events.append(('config.update', payload))
|
||||
|
||||
class FakeWandb:
|
||||
def __init__(self):
|
||||
self.config = FakeConfig()
|
||||
|
||||
def init(self, **kwargs):
|
||||
events.append(('init', kwargs))
|
||||
return FakeRun()
|
||||
|
||||
monkeypatch.setattr(workspace_module, '_load_wandb', lambda: FakeWandb())
|
||||
|
||||
cfg = OmegaConf.create({
|
||||
'logging': {
|
||||
'project': 'demo-project',
|
||||
'name': 'demo-run',
|
||||
'group': None,
|
||||
'tags': ['shared'],
|
||||
'id': None,
|
||||
'resume': True,
|
||||
'mode': 'online',
|
||||
}
|
||||
})
|
||||
|
||||
logger = workspace_module.init_logging_backend(cfg=cfg, output_dir=tmp_path)
|
||||
logger.log({'metric': 2.0}, step=3)
|
||||
logger.finish()
|
||||
|
||||
assert events[0][0] == 'init'
|
||||
init_kwargs = events[0][1]
|
||||
assert init_kwargs['dir'] == str(tmp_path)
|
||||
assert init_kwargs['project'] == 'demo-project'
|
||||
assert init_kwargs['name'] == 'demo-run'
|
||||
assert init_kwargs['mode'] == 'online'
|
||||
assert ('config.update', {'output_dir': str(tmp_path)}) in events
|
||||
assert ('log', {'metric': 2.0}, 3) in events
|
||||
assert events.count(('finish',)) == 1
|
||||
|
||||
|
||||
def test_init_logger_rejects_unknown_backends(tmp_path, monkeypatch):
|
||||
workspace_module = load_workspace_module(monkeypatch)
|
||||
cfg = OmegaConf.create({
|
||||
'logging': {
|
||||
'backend': 'tensorboard',
|
||||
'project': 'demo-project',
|
||||
'name': 'demo-run',
|
||||
'mode': 'offline',
|
||||
}
|
||||
})
|
||||
|
||||
with pytest.raises(ValueError, match='Unknown logging backend'):
|
||||
workspace_module.init_logging_backend(cfg=cfg, output_dir=tmp_path)
|
||||
|
||||
|
||||
|
||||
|
||||
def test_logging_backend_session_preserves_primary_exception_when_finish_fails(tmp_path, monkeypatch):
|
||||
workspace_module = load_workspace_module(monkeypatch)
|
||||
events = []
|
||||
|
||||
class FakeBackend:
|
||||
def log(self, payload, step=None):
|
||||
events.append(('log', payload, step))
|
||||
|
||||
def finish(self):
|
||||
events.append(('finish',))
|
||||
raise RuntimeError('finish boom')
|
||||
|
||||
monkeypatch.setattr(
|
||||
workspace_module,
|
||||
'init_logging_backend',
|
||||
lambda cfg, output_dir: FakeBackend(),
|
||||
)
|
||||
|
||||
cfg = OmegaConf.create({'logging': {'mode': 'offline'}})
|
||||
|
||||
with pytest.raises(ValueError, match='primary boom'):
|
||||
with workspace_module.logging_backend_session(cfg=cfg, output_dir=tmp_path) as logger:
|
||||
logger.log({'metric': 6.0}, step=12)
|
||||
raise ValueError('primary boom')
|
||||
|
||||
assert ('log', {'metric': 6.0}, 12) in events
|
||||
assert events.count(('finish',)) == 1
|
||||
|
||||
def test_logging_backend_session_finishes_on_exception(tmp_path, monkeypatch):
|
||||
workspace_module = load_workspace_module(monkeypatch)
|
||||
events = []
|
||||
|
||||
class FakeBackend:
|
||||
def log(self, payload, step=None):
|
||||
events.append(('log', payload, step))
|
||||
|
||||
def finish(self):
|
||||
events.append(('finish',))
|
||||
|
||||
monkeypatch.setattr(
|
||||
workspace_module,
|
||||
'init_logging_backend',
|
||||
lambda cfg, output_dir: FakeBackend(),
|
||||
)
|
||||
|
||||
cfg = OmegaConf.create({'logging': {'mode': 'offline'}})
|
||||
|
||||
with pytest.raises(RuntimeError, match='boom'):
|
||||
with workspace_module.logging_backend_session(cfg=cfg, output_dir=tmp_path) as logger:
|
||||
logger.log({'metric': 5.0}, step=11)
|
||||
raise RuntimeError('boom')
|
||||
|
||||
assert ('log', {'metric': 5.0}, 11) in events
|
||||
assert events.count(('finish',)) == 1
|
||||
Reference in New Issue
Block a user