Compare commits: 749db2ce9c...main (10 commits: 08c1950c6d, 5ba07ac666, 548a52bbb1, de4384e84a, 7dd9dc417a, 5aa9996fdc, 5c3d54fca3, a98e74873b, 68eef44d3e, c52bac42ee)

AGENTS.md (new file, 68 lines)
@@ -0,0 +1,68 @@
# Agent Notes

## Purpose

`~/diffusion_policy` is the Diffusion Policy training repo. The main workflow here is Hydra-driven training via `train.py`, with the canonical PushT image experiment configured by `image_pusht_diffusion_policy_cnn.yaml`.

## Top Level

- `diffusion_policy/`: core code, configs, datasets, env runners, workspaces.
- `data/`: local datasets, outputs, checkpoints, run logs.
- `train.py`: main training entrypoint.
- `eval.py`: checkpoint evaluation entrypoint.
- `image_pusht_diffusion_policy_cnn.yaml`: canonical single-seed PushT image config from the README path.
- `.venv/`: local `uv`-managed virtualenv.
- `.uv-cache/`, `.uv-python/`: local `uv` cache and Python install state.
- `README.md`: upstream instructions and canonical commands.

## Canonical PushT Image Path

- Entrypoint: `python train.py --config-dir=. --config-name=image_pusht_diffusion_policy_cnn.yaml`
- Dataset path in config: `data/pusht/pusht_cchi_v7_replay.zarr`
- README canonical device override: `training.device=cuda:0`

## Data

- PushT archive currently present at `data/pusht.zip`
- Unpacked dataset used by training: `data/pusht/pusht_cchi_v7_replay.zarr`

## Local Compatibility Adjustments

- `diffusion_policy/env_runner/pusht_image_runner.py` now uses `SyncVectorEnv` instead of `AsyncVectorEnv`.
  Reason: avoid shared-memory and semaphore failures on this host/session.
- `diffusion_policy/gym_util/sync_vector_env.py` has local compatibility changes:
  - added `reset_async`
  - added a seeded `reset_wait`
  - updated the `concatenate(...)` argument order for the current `gym` API

## Environment Expectations

- Use the local `uv` env at `.venv`
- Verified local Python: `3.9.25`
- Verified local Torch stack: `torch 2.8.0+cu128`, `torchvision 0.23.0+cu128`
- Other key installed versions verified in `.venv`:
  - `gym 0.23.1`
  - `hydra-core 1.2.0`
  - `diffusers 0.11.1`
  - `huggingface_hub 0.10.1`
  - `wandb 0.13.3`
  - `zarr 2.12.0`
  - `numcodecs 0.10.2`
  - `av 14.0.1`
- Important note: this shell currently reports `torch.cuda.is_available() == False`, so always verify CUDA access in the current session before assuming the GPU is usable.
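A quick way to run that verification in whatever interpreter is active (a minimal sketch; it degrades gracefully when `torch` is absent, so it can run outside `.venv` too):

```python
def cuda_status():
    """Report whether torch sees a CUDA device, without assuming torch is importable."""
    try:
        import torch
    except ImportError:
        return "torch not installed in this interpreter"
    return f"cuda available: {torch.cuda.is_available()}"

print(cuda_status())
```

Run it with `.venv/bin/python` so the answer reflects the training environment rather than the system Python.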

## Logging And Outputs

- Hydra run outputs: `data/outputs/...`
- Per-run files to check first:
  - `.hydra/overrides.yaml`
  - `logs.json.txt`
  - `train.log`
  - `checkpoints/latest.ckpt`
- Extra launcher logs may live under `data/run_logs/`
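`logs.json.txt` is one JSON object per line, so the final record can be pulled with the stdlib alone (a minimal sketch; the key names below are illustrative, check a real run for the actual metric keys):

```python
import json

def last_log_record(path):
    """Return the final JSON record from a JSON-lines training log."""
    last = None
    with open(path) as f:
        for line in f:
            line = line.strip()
            if line:
                last = json.loads(line)
    return last

# Illustrative: write a tiny fake log, then read it back.
with open("logs_demo.json.txt", "w") as f:
    f.write('{"epoch": 0, "train_loss": 0.91}\n')
    f.write('{"epoch": 1, "train_loss": 0.44}\n')

print(last_log_record("logs_demo.json.txt"))  # {'epoch': 1, 'train_loss': 0.44}
```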

## Practical Guidance

- Inspect with `rg`, `sed`, and existing Hydra output folders before changing code.
- Prefer config overrides before code edits.
- On this host, start from these safety overrides unless revalidated:
  - `logging.mode=offline`
  - `dataloader.num_workers=0`
  - `val_dataloader.num_workers=0`
  - `task.env_runner.n_envs=1`
  - `task.env_runner.n_test_vis=0`
  - `task.env_runner.n_train_vis=0`
- If a run fails, inspect `.hydra/overrides.yaml`, then `logs.json.txt`, then `train.log`.
- Avoid driver or system changes unless the repo-local path is clearly blocked.

PUSHT_REPRO_5090.md (new file, 108 lines)
@@ -0,0 +1,108 @@
# PushT Repro On 5090

## Goal

Reproduce the canonical single-seed image PushT experiment from this repo in `~/diffusion_policy` using `image_pusht_diffusion_policy_cnn.yaml`.

## Current Verified Local Setup

- Virtualenv: `./.venv` managed with `uv`
- Python: `3.9.25`
- Torch stack: `torch 2.8.0+cu128`, `torchvision 0.23.0+cu128`
- Version strategy used here:
  - newer Torch/CUDA stack for current 5090-class hardware support
  - keep older repo-era packages where they are still required by the code
- Verified key pins in `.venv`:
  - `numpy 1.26.4`
  - `gym 0.23.1`
  - `hydra-core 1.2.0`
  - `diffusers 0.11.1`
  - `huggingface_hub 0.10.1`
  - `wandb 0.13.3`
  - `zarr 2.12.0`
  - `numcodecs 0.10.2`
  - `av 14.0.1`
  - `robomimic 0.2.0`

## Dataset

- README source: `https://diffusion-policy.cs.columbia.edu/data/training/pusht.zip`
- Local archive currently present: `data/pusht.zip`
- Unpacked dataset used by the config: `data/pusht/pusht_cchi_v7_replay.zarr`

## Repo-Local Code Adjustments

- `diffusion_policy/env_runner/pusht_image_runner.py`
  - switched PushT image evaluation from `AsyncVectorEnv` to `SyncVectorEnv`
- `diffusion_policy/gym_util/sync_vector_env.py`
  - added `reset_async`
  - added a seeded `reset_wait`
  - updated the `concatenate(...)` argument order for current `gym`

These changes were needed to keep PushT evaluation working without the async shared-memory path.

## Validated GPU Smoke Command

This route is verified by `data/outputs/gpu_smoke2___pusht_gpu_smoke`, which contains `logs.json.txt` plus checkpoints:

```bash
.venv/bin/python train.py \
  --config-dir=. \
  --config-name=image_pusht_diffusion_policy_cnn.yaml \
  training.seed=42 \
  training.device=cuda:0 \
  logging.mode=offline \
  dataloader.num_workers=0 \
  val_dataloader.num_workers=0 \
  task.env_runner.n_envs=1 \
  training.debug=true \
  task.env_runner.n_test=2 \
  task.env_runner.n_test_vis=0 \
  task.env_runner.n_train=1 \
  task.env_runner.n_train_vis=0 \
  task.env_runner.max_steps=20
```

## Practical Full Training Command Used Here

This matches the longer GPU run under `data/outputs/2026.03.13/15.37.00_train_diffusion_unet_hybrid_pusht_image_gpu_seed42`:

```bash
.venv/bin/python train.py \
  --config-dir=. \
  --config-name=image_pusht_diffusion_policy_cnn.yaml \
  training.seed=42 \
  training.device=cuda:0 \
  logging.mode=offline \
  dataloader.num_workers=0 \
  val_dataloader.num_workers=0 \
  task.env_runner.n_envs=1 \
  task.env_runner.n_test_vis=0 \
  task.env_runner.n_train_vis=0 \
  hydra.run.dir=data/outputs/2026.03.13/15.37.00_train_diffusion_unet_hybrid_pusht_image_gpu_seed42
```

## Why These Overrides Were Used

- `logging.mode=offline`
  - avoids needing a W&B login and still leaves local run metadata in the output dir
- `dataloader.num_workers=0` and `val_dataloader.num_workers=0`
  - avoids extra multiprocessing on this host
- `task.env_runner.n_envs=1`
  - keeps PushT eval on the serial `SyncVectorEnv` path
- `task.env_runner.n_test_vis=0` and `task.env_runner.n_train_vis=0`
  - avoids video-writing issues on this stack
  - one earlier GPU run with default vis settings logged libav/libx264 `profile=high` errors in `data/outputs/_train_diffusion_unet_hybrid_pusht_image_gpu_seed42/train.log`

## Output Locations

- Smoke run:
  - `data/outputs/gpu_smoke2___pusht_gpu_smoke`
- Longer GPU run:
  - `data/outputs/2026.03.13/15.37.00_train_diffusion_unet_hybrid_pusht_image_gpu_seed42`
- Files to inspect inside a run:
  - `.hydra/overrides.yaml`
  - `logs.json.txt`
  - `train.log`
  - `checkpoints/latest.ckpt`

## Known Caveats

- The default config is still tuned for older assumptions:
  - `logging.mode=online`
  - `dataloader.num_workers=8`
  - `task.env_runner.n_envs=null`
  - `task.env_runner.n_test_vis=4`
  - `task.env_runner.n_train_vis=2`
- In this shell, `torch.cuda.is_available()` currently reports `False` even though the repo contains validated GPU smoke/full-run artifacts. Re-check device visibility in the current session before restarting a GPU run.

README.md (35 lines changed)
@@ -202,6 +202,41 @@ data/outputs/2023.03.01/22.13.58_train_diffusion_unet_hybrid_pusht_image

7 directories, 16 files
```
### 🆕 Evaluate Pre-trained Checkpoints
Download a checkpoint from the published training log folders, such as [https://diffusion-policy.cs.columbia.edu/data/experiments/low_dim/pusht/diffusion_policy_cnn/train_0/checkpoints/epoch=0550-test_mean_score=0.969.ckpt](https://diffusion-policy.cs.columbia.edu/data/experiments/low_dim/pusht/diffusion_policy_cnn/train_0/checkpoints/epoch=0550-test_mean_score=0.969.ckpt).

Run the evaluation script:
```console
(robodiff)[diffusion_policy]$ python eval.py --checkpoint data/0550-test_mean_score=0.969.ckpt --output_dir data/pusht_eval_output --device cuda:0
```

This will generate the following directory structure:
```console
(robodiff)[diffusion_policy]$ tree data/pusht_eval_output
data/pusht_eval_output
├── eval_log.json
└── media
    ├── 1fxtno84.mp4
    ├── 224l7jqd.mp4
    ├── 2fo4btlf.mp4
    ├── 2in4cn7a.mp4
    ├── 34b3o2qq.mp4
    └── 3p7jqn32.mp4

1 directory, 7 files
```

`eval_log.json` contains metrics that are logged to wandb during training:
```console
(robodiff)[diffusion_policy]$ cat data/pusht_eval_output/eval_log.json
{
  "test/mean_score": 0.9150393806777066,
  "test/sim_max_reward_4300000": 1.0,
  "test/sim_max_reward_4300001": 0.9872969750774386,
  ...
  "train/sim_video_1": "data/pusht_eval_output//media/2fo4btlf.mp4"
}
```

## 🦾 Demo, Training and Eval on a Real Robot
Make sure your UR5 robot is running and accepting commands from its network interface (emergency stop button within reach at all times), your RealSense cameras are plugged in to your workstation (tested with `realsense-viewer`), and your SpaceMouse is connected with the `spacenavd` daemon running (verify with `systemctl status spacenavd`).

@@ -46,6 +46,8 @@ dependencies:
   - diffusers=0.11.1
   - av=10.0.0
   - cmake=3.24.3
+  # trick to avoid cpu affinity issue described in https://github.com/pytorch/pytorch/issues/99625
+  - llvm-openmp=14
   # trick to force reinstall imagecodecs via pip
   - imagecodecs==2022.8.8
   - pip:

@@ -46,6 +46,8 @@ dependencies:
   - diffusers=0.11.1
   - av=10.0.0
   - cmake=3.24.3
+  # trick to avoid cpu affinity issue described in https://github.com/pytorch/pytorch/issues/99625
+  - llvm-openmp=14
   # trick to force reinstall imagecodecs via pip
   - imagecodecs==2022.8.8
   - pip:

@@ -158,6 +158,8 @@ class ReplayBuffer:
         # numpy backend
         meta = dict()
         for key, value in src_root['meta'].items():
+            if isinstance(value, zarr.Group):
+                continue
             if len(value.shape) == 0:
                 meta[key] = np.array(value)
             else:

diffusion_policy/dataset/mujoco_image_dataset.py (new file, 109 lines)
@@ -0,0 +1,109 @@
from typing import Dict
import torch
import numpy as np
import copy
from diffusion_policy.common.pytorch_util import dict_apply
from diffusion_policy.common.replay_buffer import ReplayBuffer
from diffusion_policy.common.sampler import (
    SequenceSampler, get_val_mask, downsample_mask)
from diffusion_policy.model.common.normalizer import LinearNormalizer
from diffusion_policy.dataset.base_dataset import BaseImageDataset
from diffusion_policy.common.normalize_util import get_image_range_normalizer

class MujocoImageDataset(BaseImageDataset):
    def __init__(self,
            zarr_path,
            horizon=1,
            pad_before=0,
            pad_after=0,
            seed=42,
            val_ratio=0.0,
            max_train_episodes=None
            ):

        super().__init__()
        self.replay_buffer = ReplayBuffer.copy_from_path(
            # zarr_path, keys=['img', 'state', 'action'])
            zarr_path, keys=['robot_0_camera_images', 'robot_0_tcp_xyz_wxyz', 'robot_0_gripper_width', 'action_0_tcp_xyz_wxyz', 'action_0_gripper_width'])
        val_mask = get_val_mask(
            n_episodes=self.replay_buffer.n_episodes,
            val_ratio=val_ratio,
            seed=seed)
        train_mask = ~val_mask
        train_mask = downsample_mask(
            mask=train_mask,
            max_n=max_train_episodes,
            seed=seed)

        self.sampler = SequenceSampler(
            replay_buffer=self.replay_buffer,
            sequence_length=horizon,
            pad_before=pad_before,
            pad_after=pad_after,
            episode_mask=train_mask)
        self.train_mask = train_mask
        self.horizon = horizon
        self.pad_before = pad_before
        self.pad_after = pad_after

    def get_validation_dataset(self):
        val_set = copy.copy(self)
        val_set.sampler = SequenceSampler(
            replay_buffer=self.replay_buffer,
            sequence_length=self.horizon,
            pad_before=self.pad_before,
            pad_after=self.pad_after,
            episode_mask=~self.train_mask
        )
        val_set.train_mask = ~self.train_mask
        return val_set

    def get_normalizer(self, mode='limits', **kwargs):
        data = {
            'action': np.concatenate([self.replay_buffer['action_0_tcp_xyz_wxyz'], self.replay_buffer['action_0_gripper_width']], axis=-1),
            # pair tcp pose with gripper width, matching _sample_to_data's agent_pos layout
            'agent_pos': np.concatenate([self.replay_buffer['robot_0_tcp_xyz_wxyz'], self.replay_buffer['robot_0_gripper_width']], axis=-1)
        }
        normalizer = LinearNormalizer()
        normalizer.fit(data=data, last_n_dims=1, mode=mode, **kwargs)
        normalizer['image'] = get_image_range_normalizer()
        return normalizer

    def __len__(self) -> int:
        return len(self.sampler)

    def _sample_to_data(self, sample):
        # agent_pos = sample['state'][:,:2].astype(np.float32) # (agent_posx2, block_posex3)
        agent_pos = np.concatenate([sample['robot_0_tcp_xyz_wxyz'], sample['robot_0_gripper_width']], axis=-1).astype(np.float32)
        agent_action = np.concatenate([sample['action_0_tcp_xyz_wxyz'], sample['action_0_gripper_width']], axis=-1).astype(np.float32)
        # image = np.moveaxis(sample['img'],-1,1)/255
        image = np.moveaxis(sample['robot_0_camera_images'].astype(np.float32).squeeze(1), -1, 1) / 255

        data = {
            'obs': {
                'image': image, # T, 3, 224, 224
                'agent_pos': agent_pos, # T, 8 (x,y,z,qx,qy,qz,qw,gripper_width)
            },
            'action': agent_action # T, 8 (x,y,z,qx,qy,qz,qw,gripper_width)
        }
        return data

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        sample = self.sampler.sample_sequence(idx)
        data = self._sample_to_data(sample)
        torch_data = dict_apply(data, torch.from_numpy)
        return torch_data


def test():
    import os
    zarr_path = os.path.expanduser('/home/yihuai/robotics/repositories/mujoco/mujoco-env/data/collect_heuristic_data/2024-12-24_11-36-15_100episodes/merged_data.zarr')
    dataset = MujocoImageDataset(zarr_path, horizon=16)
    print(dataset[0])
    # from matplotlib import pyplot as plt
    # normalizer = dataset.get_normalizer()
    # nactions = normalizer['action'].normalize(dataset.replay_buffer['action'])
    # diff = np.diff(nactions, axis=0)
    # dists = np.linalg.norm(np.diff(nactions, axis=0), axis=-1)

if __name__ == '__main__':
    test()
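The `_sample_to_data` image path above can be checked on dummy data without the real zarr store; a minimal sketch (the 224x224 size is just what the comment claims, any H and W behave the same):

```python
import numpy as np

# Dummy batch shaped like sample['robot_0_camera_images']: (T, 1, H, W, C) uint8.
T, H, W = 16, 224, 224
frames = np.random.randint(0, 256, size=(T, 1, H, W, 3), dtype=np.uint8)

# Same transform as _sample_to_data: float32, drop the camera axis, HWC -> CHW, scale to [0, 1].
image = np.moveaxis(frames.astype(np.float32).squeeze(1), -1, 1) / 255

print(image.shape)  # (16, 3, 224, 224)
print(image.dtype, float(image.min()) >= 0.0, float(image.max()) <= 1.0)
```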

`diffusion_policy/env_runner/pusht_image_runner.py`:

@@ -8,8 +8,7 @@ import dill
 import math
 import wandb.sdk.data_types.video as wv
 from diffusion_policy.env.pusht.pusht_image_env import PushTImageEnv
-from diffusion_policy.gym_util.async_vector_env import AsyncVectorEnv
-# from diffusion_policy.gym_util.sync_vector_env import SyncVectorEnv
+from diffusion_policy.gym_util.sync_vector_env import SyncVectorEnv
 from diffusion_policy.gym_util.multistep_wrapper import MultiStepWrapper
 from diffusion_policy.gym_util.video_recording_wrapper import VideoRecordingWrapper, VideoRecorder

@@ -121,7 +120,9 @@ class PushTImageRunner(BaseImageRunner):
             env_prefixs.append('test/')
             env_init_fn_dills.append(dill.dumps(init_fn))

-        env = AsyncVectorEnv(env_fns)
+        # This environment can run without multiprocessing, which avoids
+        # shared-memory and semaphore restrictions on some machines.
+        env = SyncVectorEnv(env_fns)

         # test env
         # env.reset(seed=env_seeds)
`diffusion_policy/gym_util/sync_vector_env.py`:

@@ -60,17 +60,44 @@ class SyncVectorEnv(VectorEnv):
         for env, seed in zip(self.envs, seeds):
             env.seed(seed)

-    def reset_wait(self):
+    def reset_async(self, seed=None, return_info=False, options=None):
+        if seed is None:
+            seeds = [None for _ in range(self.num_envs)]
+        elif isinstance(seed, int):
+            seeds = [seed + i for i in range(self.num_envs)]
+        else:
+            seeds = list(seed)
+        assert len(seeds) == self.num_envs
+        self._reset_seeds = seeds
+        self._reset_return_info = return_info
+        self._reset_options = options
+
+    def reset_wait(self, seed=None, return_info=False, options=None):
+        seeds = getattr(self, '_reset_seeds', None)
+        if seeds is None:
+            if seed is None:
+                seeds = [None for _ in range(self.num_envs)]
+            elif isinstance(seed, int):
+                seeds = [seed + i for i in range(self.num_envs)]
+            else:
+                seeds = list(seed)
         self._dones[:] = False
         observations = []
-        for env in self.envs:
+        infos = []
+        for env, seed_i in zip(self.envs, seeds):
+            if seed_i is not None:
+                env.seed(seed_i)
             observation = env.reset()
             observations.append(observation)
+            infos.append({})
         self.observations = concatenate(
-            observations, self.observations, self.single_observation_space
+            self.single_observation_space, observations, self.observations
         )

-        return deepcopy(self.observations) if self.copy else self.observations
+        obs = deepcopy(self.observations) if self.copy else self.observations
+        if return_info:
+            return obs, infos
+        return obs

     def step_async(self, actions):
         self._actions = actions

@@ -84,7 +111,7 @@ class SyncVectorEnv(VectorEnv):
             observations.append(observation)
             infos.append(info)
         self.observations = concatenate(
-            observations, self.observations, self.single_observation_space
+            self.single_observation_space, observations, self.observations
         )

         return (
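The seed-expansion rule shared by `reset_async` and `reset_wait` above (None expands to all-None, an int to consecutive per-env seeds, a sequence is taken as given) is small enough to sketch standalone:

```python
def expand_seeds(seed, num_envs):
    """Mirror the per-env seed expansion used in the patched SyncVectorEnv."""
    if seed is None:
        seeds = [None for _ in range(num_envs)]
    elif isinstance(seed, int):
        seeds = [seed + i for i in range(num_envs)]
    else:
        seeds = list(seed)
    assert len(seeds) == num_envs
    return seeds

print(expand_seeds(None, 3))       # [None, None, None]
print(expand_seeds(100000, 3))     # [100000, 100001, 100002]
print(expand_seeds([7, 8, 9], 3))  # [7, 8, 9]
```

The int case matches how the runner's `test_start_seed=100000` fans out into distinct but reproducible per-env seeds.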

@@ -40,7 +40,7 @@ class RotationTransformer:
             getattr(pt, f'matrix_to_{from_rep}')
         ]
         if from_convention is not None:
-            funcs = [functools.partial(func, convernsion=from_convention)
+            funcs = [functools.partial(func, convention=from_convention)
                 for func in funcs]
         forward_funcs.append(funcs[0])
         inverse_funcs.append(funcs[1])

@@ -51,7 +51,7 @@ class RotationTransformer:
             getattr(pt, f'{to_rep}_to_matrix')
         ]
         if to_convention is not None:
-            funcs = [functools.partial(func, convernsion=to_convention)
+            funcs = [functools.partial(func, convention=to_convention)
                 for func in funcs]
         forward_funcs.append(funcs[0])
         inverse_funcs.append(funcs[1])
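The `convernsion` to `convention` fix above matters because `functools.partial` accepts any keyword at construction time; a misspelled keyword only fails when the partial is eventually called. A minimal sketch (the function here is a stand-in, not the real rotation-conversion API):

```python
import functools

def matrix_to_euler(matrix, convention="XYZ"):
    # Stand-in for a rotation conversion that takes a `convention` keyword.
    return (matrix, convention)

good = functools.partial(matrix_to_euler, convention="ZYX")
print(good("M"))  # ('M', 'ZYX')

bad = functools.partial(matrix_to_euler, convernsion="ZYX")  # typo is accepted silently here
try:
    bad("M")
except TypeError as e:
    print("fails only at call time:", e)
```

That deferred failure is why the typo survived until a conversion with an explicit convention was actually exercised.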

@@ -256,8 +256,8 @@ class DiffusionTransformerHybridImagePolicy(BaseImagePolicy):
             # condition through impainting
             this_nobs = dict_apply(nobs, lambda x: x[:,:To,...].reshape(-1,*x.shape[2:]))
             nobs_features = self.obs_encoder(this_nobs)
-            # reshape back to B, T, Do
-            nobs_features = nobs_features.reshape(B, T, -1)
+            # reshape back to B, To, Do
+            nobs_features = nobs_features.reshape(B, To, -1)
             shape = (B, T, Da+Do)
             cond_data = torch.zeros(size=shape, device=device, dtype=dtype)
             cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)

@@ -247,7 +247,7 @@ class DiffusionUnetHybridImagePolicy(BaseImagePolicy):
             # condition through impainting
             this_nobs = dict_apply(nobs, lambda x: x[:,:To,...].reshape(-1,*x.shape[2:]))
             nobs_features = self.obs_encoder(this_nobs)
-            # reshape back to B, T, Do
+            # reshape back to B, To, Do
             nobs_features = nobs_features.reshape(B, To, -1)
             cond_data = torch.zeros(size=(B, T, Da+Do), device=device, dtype=dtype)
             cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)

eval.py (new file, 64 lines)
@@ -0,0 +1,64 @@
"""
Usage:
python eval.py --checkpoint data/image/pusht/diffusion_policy_cnn/train_0/checkpoints/latest.ckpt -o data/pusht_eval_output
"""

import sys
# use line-buffering for both stdout and stderr
sys.stdout = open(sys.stdout.fileno(), mode='w', buffering=1)
sys.stderr = open(sys.stderr.fileno(), mode='w', buffering=1)

import os
import pathlib
import click
import hydra
import torch
import dill
import wandb
import json
from diffusion_policy.workspace.base_workspace import BaseWorkspace

@click.command()
@click.option('-c', '--checkpoint', required=True)
@click.option('-o', '--output_dir', required=True)
@click.option('-d', '--device', default='cuda:0')
def main(checkpoint, output_dir, device):
    if os.path.exists(output_dir):
        click.confirm(f"Output path {output_dir} already exists! Overwrite?", abort=True)
    pathlib.Path(output_dir).mkdir(parents=True, exist_ok=True)

    # load checkpoint
    payload = torch.load(open(checkpoint, 'rb'), pickle_module=dill)
    cfg = payload['cfg']
    cls = hydra.utils.get_class(cfg._target_)
    workspace = cls(cfg, output_dir=output_dir)
    workspace: BaseWorkspace
    workspace.load_payload(payload, exclude_keys=None, include_keys=None)

    # get policy from workspace
    policy = workspace.model
    if cfg.training.use_ema:
        policy = workspace.ema_model

    device = torch.device(device)
    policy.to(device)
    policy.eval()

    # run eval
    env_runner = hydra.utils.instantiate(
        cfg.task.env_runner,
        output_dir=output_dir)
    runner_log = env_runner.run(policy)

    # dump log to json
    json_log = dict()
    for key, value in runner_log.items():
        if isinstance(value, wandb.sdk.data_types.video.Video):
            json_log[key] = value._path
        else:
            json_log[key] = value
    out_path = os.path.join(output_dir, 'eval_log.json')
    json.dump(json_log, open(out_path, 'w'), indent=2, sort_keys=True)

if __name__ == '__main__':
    main()

image_pusht_diffusion_policy_cnn.yaml (new file, 182 lines)
@@ -0,0 +1,182 @@
_target_: diffusion_policy.workspace.train_diffusion_unet_hybrid_workspace.TrainDiffusionUnetHybridWorkspace
checkpoint:
  save_last_ckpt: true
  save_last_snapshot: false
  topk:
    format_str: epoch={epoch:04d}-test_mean_score={test_mean_score:.3f}.ckpt
    k: 5
    mode: max
    monitor_key: test_mean_score
dataloader:
  batch_size: 64
  num_workers: 8
  persistent_workers: false
  pin_memory: true
  shuffle: true
dataset_obs_steps: 2
ema:
  _target_: diffusion_policy.model.diffusion.ema_model.EMAModel
  inv_gamma: 1.0
  max_value: 0.9999
  min_value: 0.0
  power: 0.75
  update_after_step: 0
exp_name: default
horizon: 16
keypoint_visible_rate: 1.0
logging:
  group: null
  id: null
  mode: online
  name: 2023.01.16-20.20.06_train_diffusion_unet_hybrid_pusht_image
  project: diffusion_policy_debug
  resume: true
  tags:
  - train_diffusion_unet_hybrid
  - pusht_image
  - default
multi_run:
  run_dir: data/outputs/2023.01.16/20.20.06_train_diffusion_unet_hybrid_pusht_image
  wandb_name_base: 2023.01.16-20.20.06_train_diffusion_unet_hybrid_pusht_image
n_action_steps: 8
n_latency_steps: 0
n_obs_steps: 2
name: train_diffusion_unet_hybrid
obs_as_global_cond: true
optimizer:
  _target_: torch.optim.AdamW
  betas:
  - 0.95
  - 0.999
  eps: 1.0e-08
  lr: 0.0001
  weight_decay: 1.0e-06
past_action_visible: false
policy:
  _target_: diffusion_policy.policy.diffusion_unet_hybrid_image_policy.DiffusionUnetHybridImagePolicy
  cond_predict_scale: true
  crop_shape:
  - 84
  - 84
  diffusion_step_embed_dim: 128
  down_dims:
  - 512
  - 1024
  - 2048
  eval_fixed_crop: true
  horizon: 16
  kernel_size: 5
  n_action_steps: 8
  n_groups: 8
  n_obs_steps: 2
  noise_scheduler:
    _target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler
    beta_end: 0.02
    beta_schedule: squaredcos_cap_v2
    beta_start: 0.0001
    clip_sample: true
    num_train_timesteps: 100
    prediction_type: epsilon
    variance_type: fixed_small
  num_inference_steps: 100
  obs_as_global_cond: true
  obs_encoder_group_norm: true
  shape_meta:
    action:
      shape:
      - 2
    obs:
      agent_pos:
        shape:
        - 2
        type: low_dim
      image:
        shape:
        - 3
        - 96
        - 96
        type: rgb
shape_meta:
  action:
    shape:
    - 2
  obs:
    agent_pos:
      shape:
      - 2
      type: low_dim
    image:
      shape:
      - 3
      - 96
      - 96
      type: rgb
task:
  dataset:
    _target_: diffusion_policy.dataset.pusht_image_dataset.PushTImageDataset
    horizon: 16
    max_train_episodes: 90
    pad_after: 7
    pad_before: 1
    seed: 42
    val_ratio: 0.02
    zarr_path: data/pusht/pusht_cchi_v7_replay.zarr
  env_runner:
    _target_: diffusion_policy.env_runner.pusht_image_runner.PushTImageRunner
    fps: 10
    legacy_test: true
    max_steps: 300
    n_action_steps: 8
    n_envs: null
    n_obs_steps: 2
    n_test: 50
    n_test_vis: 4
    n_train: 6
    n_train_vis: 2
    past_action: false
    test_start_seed: 100000
    train_start_seed: 0
  image_shape:
  - 3
  - 96
  - 96
  name: pusht_image
  shape_meta:
    action:
      shape:
      - 2
    obs:
      agent_pos:
        shape:
        - 2
        type: low_dim
      image:
        shape:
        - 3
        - 96
        - 96
        type: rgb
task_name: pusht_image
training:
  checkpoint_every: 50
  debug: false
  device: cuda:0
  gradient_accumulate_every: 1
  lr_scheduler: cosine
  lr_warmup_steps: 500
  max_train_steps: null
  max_val_steps: null
  num_epochs: 3050
  resume: true
  rollout_every: 50
  sample_every: 5
  seed: 42
  tqdm_interval_sec: 1.0
  use_ema: true
  val_every: 1
val_dataloader:
  batch_size: 64
  num_workers: 8
  persistent_workers: false
  pin_memory: true
  shuffle: false

requirements-pusht-5090.txt (new file, 38 lines)
@@ -0,0 +1,38 @@
# Direct package pins for the canonical PushT image workflow on host 5090.
# Torch/TorchVision/Torchaudio are installed separately from the cu128 index in setup_uv_pusht_5090.sh.

numpy==1.26.4
scipy==1.11.4
numba==0.59.1
llvmlite==0.42.0
cffi==1.15.1
cython==0.29.32
h5py==3.8.0
pandas==2.2.3
zarr==2.12.0
numcodecs==0.10.2
hydra-core==1.2.0
einops==0.4.1
tqdm==4.64.1
dill==0.3.5.1
scikit-video==1.1.11
scikit-image==0.19.3
gym==0.23.1
pymunk==6.2.1
wandb==0.13.3
threadpoolctl==3.1.0
shapely==1.8.5.post1
imageio==2.22.0
imageio-ffmpeg==0.4.7
termcolor==2.0.1
tensorboard==2.10.1
tensorboardx==2.5.1
psutil==7.2.2
click==8.1.8
boto3==1.24.96
diffusers==0.11.1
huggingface-hub==0.10.1
av==14.0.1
pygame==2.5.2
robomimic==0.2.0
opencv-python-headless==4.10.0.84

setup_uv_pusht_5090.sh (new executable file, 20 lines)
@@ -0,0 +1,20 @@
#!/usr/bin/env bash
set -euo pipefail

ROOT_DIR="$(cd "$(dirname "$0")" && pwd)"
cd "$ROOT_DIR"

export UV_CACHE_DIR="${UV_CACHE_DIR:-$ROOT_DIR/.uv-cache}"
export UV_PYTHON_INSTALL_DIR="${UV_PYTHON_INSTALL_DIR:-$ROOT_DIR/.uv-python}"

uv venv --python 3.9 .venv
source .venv/bin/activate

uv pip install --upgrade pip wheel setuptools==80.9.0
uv pip install --python .venv/bin/python \
  --index-url https://download.pytorch.org/whl/cu128 \
  torch==2.8.0+cu128 torchvision==0.23.0+cu128 torchaudio==2.8.0+cu128
uv pip install --python .venv/bin/python -r requirements-pusht-5090.txt
uv pip install --python .venv/bin/python -e .

echo "uv environment ready at $ROOT_DIR/.venv"