Compare commits
13 Commits
749db2ce9c
...
DiT-imageP
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
42dc29a2cb | ||
|
|
79f31940c4 | ||
|
|
2aa06c8917 | ||
|
|
08c1950c6d | ||
|
|
5ba07ac666 | ||
|
|
548a52bbb1 | ||
|
|
de4384e84a | ||
|
|
7dd9dc417a | ||
|
|
5aa9996fdc | ||
|
|
5c3d54fca3 | ||
|
|
a98e74873b | ||
|
|
68eef44d3e | ||
|
|
c52bac42ee |
68
AGENTS.md
Normal file
68
AGENTS.md
Normal file
@@ -0,0 +1,68 @@
|
||||
# Agent Notes
|
||||
|
||||
## Purpose
|
||||
`~/diffusion_policy` is the Diffusion Policy training repo. The main workflow here is Hydra-driven training via `train.py`, with the canonical PushT image experiment configured by `image_pusht_diffusion_policy_cnn.yaml`.
|
||||
|
||||
## Top Level
|
||||
- `diffusion_policy/`: core code, configs, datasets, env runners, workspaces.
|
||||
- `data/`: local datasets, outputs, checkpoints, run logs.
|
||||
- `train.py`: main training entrypoint.
|
||||
- `eval.py`: checkpoint evaluation entrypoint.
|
||||
- `image_pusht_diffusion_policy_cnn.yaml`: canonical single-seed PushT image config from the README path.
|
||||
- `.venv/`: local `uv`-managed virtualenv.
|
||||
- `.uv-cache/`, `.uv-python/`: local `uv` cache and Python install state.
|
||||
- `README.md`: upstream instructions and canonical commands.
|
||||
|
||||
## Canonical PushT Image Path
|
||||
- Entrypoint: `python train.py --config-dir=. --config-name=image_pusht_diffusion_policy_cnn.yaml`
|
||||
- Dataset path in config: `data/pusht/pusht_cchi_v7_replay.zarr`
|
||||
- README canonical device override: `training.device=cuda:0`
|
||||
|
||||
## Data
|
||||
- PushT archive currently present at `data/pusht.zip`
|
||||
- Unpacked dataset used by training: `data/pusht/pusht_cchi_v7_replay.zarr`
|
||||
|
||||
## Local Compatibility Adjustments
|
||||
- `diffusion_policy/env_runner/pusht_image_runner.py` now uses `SyncVectorEnv` instead of `AsyncVectorEnv`.
|
||||
Reason: avoid shared-memory and semaphore failures on this host/session.
|
||||
- `diffusion_policy/gym_util/sync_vector_env.py` has local compatibility changes:
|
||||
- added `reset_async`
|
||||
- seeded `reset_wait`
|
||||
- updated `concatenate(...)` call order for the current `gym` API
|
||||
|
||||
## Environment Expectations
|
||||
- Use the local `uv` env at `.venv`
|
||||
- Verified local Python: `3.9.25`
|
||||
- Verified local Torch stack: `torch 2.8.0+cu128`, `torchvision 0.23.0+cu128`
|
||||
- Other key installed versions verified in `.venv`:
|
||||
- `gym 0.23.1`
|
||||
- `hydra-core 1.2.0`
|
||||
- `diffusers 0.11.1`
|
||||
- `huggingface_hub 0.10.1`
|
||||
- `wandb 0.13.3`
|
||||
- `zarr 2.12.0`
|
||||
- `numcodecs 0.10.2`
|
||||
- `av 14.0.1`
|
||||
- Important note: this shell currently reports `torch.cuda.is_available() == False`, so always verify CUDA access in the current session before assuming GPU is usable.
|
||||
|
||||
## Logging And Outputs
|
||||
- Hydra run outputs: `data/outputs/...`
|
||||
- Per-run files to check first:
|
||||
- `.hydra/overrides.yaml`
|
||||
- `logs.json.txt`
|
||||
- `train.log`
|
||||
- `checkpoints/latest.ckpt`
|
||||
- Extra launcher logs may live under `data/run_logs/`
|
||||
|
||||
## Practical Guidance
|
||||
- Inspect with `rg`, `sed`, and existing Hydra output folders before changing code.
|
||||
- Prefer config overrides before code edits.
|
||||
- On this host, start from these safety overrides unless revalidated:
|
||||
- `logging.mode=offline`
|
||||
- `dataloader.num_workers=0`
|
||||
- `val_dataloader.num_workers=0`
|
||||
- `task.env_runner.n_envs=1`
|
||||
- `task.env_runner.n_test_vis=0`
|
||||
- `task.env_runner.n_train_vis=0`
|
||||
- If a run fails, inspect `.hydra/overrides.yaml`, then `logs.json.txt`, then `train.log`.
|
||||
- Avoid driver or system changes unless the repo-local path is clearly blocked.
|
||||
108
PUSHT_REPRO_5090.md
Normal file
108
PUSHT_REPRO_5090.md
Normal file
@@ -0,0 +1,108 @@
|
||||
# PushT Repro On 5090
|
||||
|
||||
## Goal
|
||||
Reproduce the canonical single-seed image PushT experiment from this repo in `~/diffusion_policy` using `image_pusht_diffusion_policy_cnn.yaml`.
|
||||
|
||||
## Current Verified Local Setup
|
||||
- Virtualenv: `./.venv` managed with `uv`
|
||||
- Python: `3.9.25`
|
||||
- Torch stack: `torch 2.8.0+cu128`, `torchvision 0.23.0+cu128`
|
||||
- Version strategy used here:
|
||||
- newer Torch/CUDA stack for current 5090-class hardware support
|
||||
- keep older repo-era packages where they are still required by the code
|
||||
- Verified key pins in `.venv`:
|
||||
- `numpy 1.26.4`
|
||||
- `gym 0.23.1`
|
||||
- `hydra-core 1.2.0`
|
||||
- `diffusers 0.11.1`
|
||||
- `huggingface_hub 0.10.1`
|
||||
- `wandb 0.13.3`
|
||||
- `zarr 2.12.0`
|
||||
- `numcodecs 0.10.2`
|
||||
- `av 14.0.1`
|
||||
- `robomimic 0.2.0`
|
||||
|
||||
## Dataset
|
||||
- README source: `https://diffusion-policy.cs.columbia.edu/data/training/pusht.zip`
|
||||
- Local archive currently present: `data/pusht.zip`
|
||||
- Unpacked dataset used by the config: `data/pusht/pusht_cchi_v7_replay.zarr`
|
||||
|
||||
## Repo-Local Code Adjustments
|
||||
- `diffusion_policy/env_runner/pusht_image_runner.py`
|
||||
- switched PushT image evaluation from `AsyncVectorEnv` to `SyncVectorEnv`
|
||||
- `diffusion_policy/gym_util/sync_vector_env.py`
|
||||
- added `reset_async`
|
||||
- added seeded `reset_wait`
|
||||
- updated `concatenate(...)` call order for current `gym`
|
||||
|
||||
These changes were needed to keep PushT evaluation working without the async shared-memory path.
|
||||
|
||||
## Validated GPU Smoke Command
|
||||
This route is verified by `data/outputs/gpu_smoke2___pusht_gpu_smoke`, which contains `logs.json.txt` plus checkpoints:
|
||||
|
||||
```bash
|
||||
.venv/bin/python train.py \
|
||||
--config-dir=. \
|
||||
--config-name=image_pusht_diffusion_policy_cnn.yaml \
|
||||
training.seed=42 \
|
||||
training.device=cuda:0 \
|
||||
logging.mode=offline \
|
||||
dataloader.num_workers=0 \
|
||||
val_dataloader.num_workers=0 \
|
||||
task.env_runner.n_envs=1 \
|
||||
training.debug=true \
|
||||
task.env_runner.n_test=2 \
|
||||
task.env_runner.n_test_vis=0 \
|
||||
task.env_runner.n_train=1 \
|
||||
task.env_runner.n_train_vis=0 \
|
||||
task.env_runner.max_steps=20
|
||||
```
|
||||
|
||||
## Practical Full Training Command Used Here
|
||||
This matches the longer GPU run under `data/outputs/2026.03.13/15.37.00_train_diffusion_unet_hybrid_pusht_image_gpu_seed42`:
|
||||
|
||||
```bash
|
||||
.venv/bin/python train.py \
|
||||
--config-dir=. \
|
||||
--config-name=image_pusht_diffusion_policy_cnn.yaml \
|
||||
training.seed=42 \
|
||||
training.device=cuda:0 \
|
||||
logging.mode=offline \
|
||||
dataloader.num_workers=0 \
|
||||
val_dataloader.num_workers=0 \
|
||||
task.env_runner.n_envs=1 \
|
||||
task.env_runner.n_test_vis=0 \
|
||||
task.env_runner.n_train_vis=0 \
|
||||
hydra.run.dir=data/outputs/2026.03.13/15.37.00_train_diffusion_unet_hybrid_pusht_image_gpu_seed42
|
||||
```
|
||||
|
||||
## Why These Overrides Were Used
|
||||
- `logging.mode=offline`
|
||||
- avoids needing a W&B login and still leaves local run metadata in the output dir
|
||||
- `dataloader.num_workers=0` and `val_dataloader.num_workers=0`
|
||||
- avoids extra multiprocessing on this host
|
||||
- `task.env_runner.n_envs=1`
|
||||
- keeps PushT eval on the serial `SyncVectorEnv` path
|
||||
- `task.env_runner.n_test_vis=0` and `task.env_runner.n_train_vis=0`
|
||||
- avoids video-writing issues on this stack
|
||||
- one earlier GPU run with default vis settings logged libav/libx264 `profile=high` errors in `data/outputs/_train_diffusion_unet_hybrid_pusht_image_gpu_seed42/train.log`
|
||||
|
||||
## Output Locations
|
||||
- Smoke run:
|
||||
- `data/outputs/gpu_smoke2___pusht_gpu_smoke`
|
||||
- Longer GPU run:
|
||||
- `data/outputs/2026.03.13/15.37.00_train_diffusion_unet_hybrid_pusht_image_gpu_seed42`
|
||||
- Files to inspect inside a run:
|
||||
- `.hydra/overrides.yaml`
|
||||
- `logs.json.txt`
|
||||
- `train.log`
|
||||
- `checkpoints/latest.ckpt`
|
||||
|
||||
## Known Caveats
|
||||
- The default config is still tuned for older assumptions:
|
||||
- `logging.mode=online`
|
||||
- `dataloader.num_workers=8`
|
||||
- `task.env_runner.n_envs=null`
|
||||
- `task.env_runner.n_test_vis=4`
|
||||
- `task.env_runner.n_train_vis=2`
|
||||
- In this shell, `torch.cuda.is_available()` currently reports `False` even though the repo contains validated GPU smoke/full run artifacts. Re-check device visibility in the current session before restarting a GPU run.
|
||||
35
README.md
35
README.md
@@ -202,6 +202,41 @@ data/outputs/2023.03.01/22.13.58_train_diffusion_unet_hybrid_pusht_image
|
||||
|
||||
7 directories, 16 files
|
||||
```
|
||||
### 🆕 Evaluate Pre-trained Checkpoints
|
||||
Download a checkpoint from the published training log folders, such as [https://diffusion-policy.cs.columbia.edu/data/experiments/low_dim/pusht/diffusion_policy_cnn/train_0/checkpoints/epoch=0550-test_mean_score=0.969.ckpt](https://diffusion-policy.cs.columbia.edu/data/experiments/low_dim/pusht/diffusion_policy_cnn/train_0/checkpoints/epoch=0550-test_mean_score=0.969.ckpt).
|
||||
|
||||
Run the evaluation script:
|
||||
```console
|
||||
(robodiff)[diffusion_policy]$ python eval.py --checkpoint data/0550-test_mean_score=0.969.ckpt --output_dir data/pusht_eval_output --device cuda:0
|
||||
```
|
||||
|
||||
This will generate the following directory structure:
|
||||
```console
|
||||
(robodiff)[diffusion_policy]$ tree data/pusht_eval_output
|
||||
data/pusht_eval_output
|
||||
├── eval_log.json
|
||||
└── media
|
||||
├── 1fxtno84.mp4
|
||||
├── 224l7jqd.mp4
|
||||
├── 2fo4btlf.mp4
|
||||
├── 2in4cn7a.mp4
|
||||
├── 34b3o2qq.mp4
|
||||
└── 3p7jqn32.mp4
|
||||
|
||||
1 directory, 7 files
|
||||
```
|
||||
|
||||
`eval_log.json` contains metrics that is logged to wandb during training:
|
||||
```console
|
||||
(robodiff)[diffusion_policy]$ cat data/pusht_eval_output/eval_log.json
|
||||
{
|
||||
"test/mean_score": 0.9150393806777066,
|
||||
"test/sim_max_reward_4300000": 1.0,
|
||||
"test/sim_max_reward_4300001": 0.9872969750774386,
|
||||
...
|
||||
"train/sim_video_1": "data/pusht_eval_output//media/2fo4btlf.mp4"
|
||||
}
|
||||
```
|
||||
|
||||
## 🦾 Demo, Training and Eval on a Real Robot
|
||||
Make sure your UR5 robot is running and accepting command from its network interface (emergency stop button within reach at all time), your RealSense cameras plugged in to your workstation (tested with `realsense-viewer`) and your SpaceMouse connected with the `spacenavd` daemon running (verify with `systemctl status spacenavd`).
|
||||
|
||||
@@ -46,6 +46,8 @@ dependencies:
|
||||
- diffusers=0.11.1
|
||||
- av=10.0.0
|
||||
- cmake=3.24.3
|
||||
# trick to avoid cpu affinity issue described in https://github.com/pytorch/pytorch/issues/99625
|
||||
- llvm-openmp=14
|
||||
# trick to force reinstall imagecodecs via pip
|
||||
- imagecodecs==2022.8.8
|
||||
- pip:
|
||||
|
||||
@@ -46,6 +46,8 @@ dependencies:
|
||||
- diffusers=0.11.1
|
||||
- av=10.0.0
|
||||
- cmake=3.24.3
|
||||
# trick to avoid cpu affinity issue described in https://github.com/pytorch/pytorch/issues/99625
|
||||
- llvm-openmp=14
|
||||
# trick to force reinstall imagecodecs via pip
|
||||
- imagecodecs==2022.8.8
|
||||
- pip:
|
||||
|
||||
@@ -158,6 +158,8 @@ class ReplayBuffer:
|
||||
# numpy backend
|
||||
meta = dict()
|
||||
for key, value in src_root['meta'].items():
|
||||
if isinstance(value, zarr.Group):
|
||||
continue
|
||||
if len(value.shape) == 0:
|
||||
meta[key] = np.array(value)
|
||||
else:
|
||||
|
||||
109
diffusion_policy/dataset/mujoco_image_dataset.py
Normal file
109
diffusion_policy/dataset/mujoco_image_dataset.py
Normal file
@@ -0,0 +1,109 @@
|
||||
from typing import Dict
|
||||
import torch
|
||||
import numpy as np
|
||||
import copy
|
||||
from diffusion_policy.common.pytorch_util import dict_apply
|
||||
from diffusion_policy.common.replay_buffer import ReplayBuffer
|
||||
from diffusion_policy.common.sampler import (
|
||||
SequenceSampler, get_val_mask, downsample_mask)
|
||||
from diffusion_policy.model.common.normalizer import LinearNormalizer
|
||||
from diffusion_policy.dataset.base_dataset import BaseImageDataset
|
||||
from diffusion_policy.common.normalize_util import get_image_range_normalizer
|
||||
|
||||
class MujocoImageDataset(BaseImageDataset):
|
||||
def __init__(self,
|
||||
zarr_path,
|
||||
horizon=1,
|
||||
pad_before=0,
|
||||
pad_after=0,
|
||||
seed=42,
|
||||
val_ratio=0.0,
|
||||
max_train_episodes=None
|
||||
):
|
||||
|
||||
super().__init__()
|
||||
self.replay_buffer = ReplayBuffer.copy_from_path(
|
||||
# zarr_path, keys=['img', 'state', 'action'])
|
||||
zarr_path, keys=['robot_0_camera_images', 'robot_0_tcp_xyz_wxyz', 'robot_0_gripper_width', 'action_0_tcp_xyz_wxyz', 'action_0_gripper_width'])
|
||||
val_mask = get_val_mask(
|
||||
n_episodes=self.replay_buffer.n_episodes,
|
||||
val_ratio=val_ratio,
|
||||
seed=seed)
|
||||
train_mask = ~val_mask
|
||||
train_mask = downsample_mask(
|
||||
mask=train_mask,
|
||||
max_n=max_train_episodes,
|
||||
seed=seed)
|
||||
|
||||
self.sampler = SequenceSampler(
|
||||
replay_buffer=self.replay_buffer,
|
||||
sequence_length=horizon,
|
||||
pad_before=pad_before,
|
||||
pad_after=pad_after,
|
||||
episode_mask=train_mask)
|
||||
self.train_mask = train_mask
|
||||
self.horizon = horizon
|
||||
self.pad_before = pad_before
|
||||
self.pad_after = pad_after
|
||||
|
||||
def get_validation_dataset(self):
|
||||
val_set = copy.copy(self)
|
||||
val_set.sampler = SequenceSampler(
|
||||
replay_buffer=self.replay_buffer,
|
||||
sequence_length=self.horizon,
|
||||
pad_before=self.pad_before,
|
||||
pad_after=self.pad_after,
|
||||
episode_mask=~self.train_mask
|
||||
)
|
||||
val_set.train_mask = ~self.train_mask
|
||||
return val_set
|
||||
|
||||
def get_normalizer(self, mode='limits', **kwargs):
|
||||
data = {
|
||||
'action': np.concatenate([self.replay_buffer['action_0_tcp_xyz_wxyz'], self.replay_buffer['action_0_gripper_width']], axis=-1),
|
||||
'agent_pos': np.concatenate([self.replay_buffer['robot_0_tcp_xyz_wxyz'], self.replay_buffer['robot_0_tcp_xyz_wxyz']], axis=-1)
|
||||
}
|
||||
normalizer = LinearNormalizer()
|
||||
normalizer.fit(data=data, last_n_dims=1, mode=mode, **kwargs)
|
||||
normalizer['image'] = get_image_range_normalizer()
|
||||
return normalizer
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.sampler)
|
||||
|
||||
def _sample_to_data(self, sample):
|
||||
# agent_pos = sample['state'][:,:2].astype(np.float32) # (agent_posx2, block_posex3)
|
||||
agent_pos = np.concatenate([sample['robot_0_tcp_xyz_wxyz'], sample['robot_0_gripper_width']], axis=-1).astype(np.float32)
|
||||
agent_action = np.concatenate([sample['action_0_tcp_xyz_wxyz'], sample['action_0_gripper_width']], axis=-1).astype(np.float32)
|
||||
# image = np.moveaxis(sample['img'],-1,1)/255
|
||||
image = np.moveaxis(sample['robot_0_camera_images'].astype(np.float32).squeeze(1),-1,1)/255
|
||||
|
||||
data = {
|
||||
'obs': {
|
||||
'image': image, # T, 3, 224, 224
|
||||
'agent_pos': agent_pos, # T, 8 (x,y,z,qx,qy,qz,qw,gripper_width)
|
||||
},
|
||||
'action': agent_action # T, 8 (x,y,z,qx,qy,qz,qw,gripper_width)
|
||||
}
|
||||
return data
|
||||
|
||||
def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
|
||||
sample = self.sampler.sample_sequence(idx)
|
||||
data = self._sample_to_data(sample)
|
||||
torch_data = dict_apply(data, torch.from_numpy)
|
||||
return torch_data
|
||||
|
||||
|
||||
def test():
|
||||
import os
|
||||
zarr_path = os.path.expanduser('/home/yihuai/robotics/repositories/mujoco/mujoco-env/data/collect_heuristic_data/2024-12-24_11-36-15_100episodes/merged_data.zarr')
|
||||
dataset = MujocoImageDataset(zarr_path, horizon=16)
|
||||
print(dataset[0])
|
||||
# from matplotlib import pyplot as plt
|
||||
# normalizer = dataset.get_normalizer()
|
||||
# nactions = normalizer['action'].normalize(dataset.replay_buffer['action'])
|
||||
# diff = np.diff(nactions, axis=0)
|
||||
# dists = np.linalg.norm(np.diff(nactions, axis=0), axis=-1)
|
||||
|
||||
if __name__ == '__main__':
|
||||
test()
|
||||
@@ -8,8 +8,7 @@ import dill
|
||||
import math
|
||||
import wandb.sdk.data_types.video as wv
|
||||
from diffusion_policy.env.pusht.pusht_image_env import PushTImageEnv
|
||||
from diffusion_policy.gym_util.async_vector_env import AsyncVectorEnv
|
||||
# from diffusion_policy.gym_util.sync_vector_env import SyncVectorEnv
|
||||
from diffusion_policy.gym_util.sync_vector_env import SyncVectorEnv
|
||||
from diffusion_policy.gym_util.multistep_wrapper import MultiStepWrapper
|
||||
from diffusion_policy.gym_util.video_recording_wrapper import VideoRecordingWrapper, VideoRecorder
|
||||
|
||||
@@ -121,7 +120,9 @@ class PushTImageRunner(BaseImageRunner):
|
||||
env_prefixs.append('test/')
|
||||
env_init_fn_dills.append(dill.dumps(init_fn))
|
||||
|
||||
env = AsyncVectorEnv(env_fns)
|
||||
# This environment can run without multiprocessing, which avoids
|
||||
# shared-memory and semaphore restrictions on some machines.
|
||||
env = SyncVectorEnv(env_fns)
|
||||
|
||||
# test env
|
||||
# env.reset(seed=env_seeds)
|
||||
|
||||
@@ -8,8 +8,7 @@ import dill
|
||||
import math
|
||||
import wandb.sdk.data_types.video as wv
|
||||
from diffusion_policy.env.pusht.pusht_keypoints_env import PushTKeypointsEnv
|
||||
from diffusion_policy.gym_util.async_vector_env import AsyncVectorEnv
|
||||
# from diffusion_policy.gym_util.sync_vector_env import SyncVectorEnv
|
||||
from diffusion_policy.gym_util.sync_vector_env import SyncVectorEnv
|
||||
from diffusion_policy.gym_util.multistep_wrapper import MultiStepWrapper
|
||||
from diffusion_policy.gym_util.video_recording_wrapper import VideoRecordingWrapper, VideoRecorder
|
||||
|
||||
@@ -133,7 +132,7 @@ class PushTKeypointsRunner(BaseLowdimRunner):
|
||||
env_prefixs.append('test/')
|
||||
env_init_fn_dills.append(dill.dumps(init_fn))
|
||||
|
||||
env = AsyncVectorEnv(env_fns)
|
||||
env = SyncVectorEnv(env_fns)
|
||||
|
||||
# test env
|
||||
# env.reset(seed=env_seeds)
|
||||
|
||||
@@ -60,17 +60,44 @@ class SyncVectorEnv(VectorEnv):
|
||||
for env, seed in zip(self.envs, seeds):
|
||||
env.seed(seed)
|
||||
|
||||
def reset_wait(self):
|
||||
def reset_async(self, seed=None, return_info=False, options=None):
|
||||
if seed is None:
|
||||
seeds = [None for _ in range(self.num_envs)]
|
||||
elif isinstance(seed, int):
|
||||
seeds = [seed + i for i in range(self.num_envs)]
|
||||
else:
|
||||
seeds = list(seed)
|
||||
assert len(seeds) == self.num_envs
|
||||
self._reset_seeds = seeds
|
||||
self._reset_return_info = return_info
|
||||
self._reset_options = options
|
||||
|
||||
def reset_wait(self, seed=None, return_info=False, options=None):
|
||||
seeds = getattr(self, '_reset_seeds', None)
|
||||
if seeds is None:
|
||||
if seed is None:
|
||||
seeds = [None for _ in range(self.num_envs)]
|
||||
elif isinstance(seed, int):
|
||||
seeds = [seed + i for i in range(self.num_envs)]
|
||||
else:
|
||||
seeds = list(seed)
|
||||
self._dones[:] = False
|
||||
observations = []
|
||||
for env in self.envs:
|
||||
infos = []
|
||||
for env, seed_i in zip(self.envs, seeds):
|
||||
if seed_i is not None:
|
||||
env.seed(seed_i)
|
||||
observation = env.reset()
|
||||
observations.append(observation)
|
||||
infos.append({})
|
||||
self.observations = concatenate(
|
||||
observations, self.observations, self.single_observation_space
|
||||
self.single_observation_space, observations, self.observations
|
||||
)
|
||||
|
||||
return deepcopy(self.observations) if self.copy else self.observations
|
||||
obs = deepcopy(self.observations) if self.copy else self.observations
|
||||
if return_info:
|
||||
return obs, infos
|
||||
return obs
|
||||
|
||||
def step_async(self, actions):
|
||||
self._actions = actions
|
||||
@@ -84,7 +111,7 @@ class SyncVectorEnv(VectorEnv):
|
||||
observations.append(observation)
|
||||
infos.append(info)
|
||||
self.observations = concatenate(
|
||||
observations, self.observations, self.single_observation_space
|
||||
self.single_observation_space, observations, self.observations
|
||||
)
|
||||
|
||||
return (
|
||||
|
||||
@@ -40,7 +40,7 @@ class RotationTransformer:
|
||||
getattr(pt, f'matrix_to_{from_rep}')
|
||||
]
|
||||
if from_convention is not None:
|
||||
funcs = [functools.partial(func, convernsion=from_convention)
|
||||
funcs = [functools.partial(func, convention=from_convention)
|
||||
for func in funcs]
|
||||
forward_funcs.append(funcs[0])
|
||||
inverse_funcs.append(funcs[1])
|
||||
@@ -51,7 +51,7 @@ class RotationTransformer:
|
||||
getattr(pt, f'{to_rep}_to_matrix')
|
||||
]
|
||||
if to_convention is not None:
|
||||
funcs = [functools.partial(func, convernsion=to_convention)
|
||||
funcs = [functools.partial(func, convention=to_convention)
|
||||
for func in funcs]
|
||||
forward_funcs.append(funcs[0])
|
||||
inverse_funcs.append(funcs[1])
|
||||
|
||||
@@ -0,0 +1,265 @@
|
||||
from typing import Optional, Tuple, Union
|
||||
import logging
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from diffusion_policy.model.common.module_attr_mixin import ModuleAttrMixin
|
||||
from diffusion_policy.model.diffusion.positional_embedding import SinusoidalPosEmb
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PMFTransformerForDiffusion(ModuleAttrMixin):
|
||||
def __init__(
|
||||
self,
|
||||
input_dim: int,
|
||||
output_dim: int,
|
||||
horizon: int,
|
||||
n_obs_steps: Optional[int] = None,
|
||||
cond_dim: int = 0,
|
||||
n_layer: int = 12,
|
||||
n_head: int = 12,
|
||||
n_emb: int = 768,
|
||||
p_drop_emb: float = 0.1,
|
||||
p_drop_attn: float = 0.1,
|
||||
causal_attn: bool = False,
|
||||
obs_as_cond: bool = False,
|
||||
n_cond_layers: int = 0,
|
||||
n_time_tokens: int = 4,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
if n_obs_steps is None:
|
||||
n_obs_steps = horizon
|
||||
if n_time_tokens < 1:
|
||||
raise ValueError("n_time_tokens must be >= 1")
|
||||
|
||||
obs_as_cond = cond_dim > 0
|
||||
T = horizon
|
||||
n_global_cond_tokens = 2 * n_time_tokens
|
||||
T_cond = n_global_cond_tokens + (n_obs_steps if obs_as_cond else 0)
|
||||
|
||||
self.input_emb = nn.Linear(input_dim, n_emb)
|
||||
self.pos_emb = nn.Parameter(torch.zeros(1, T, n_emb))
|
||||
self.drop = nn.Dropout(p_drop_emb)
|
||||
|
||||
self.t_emb = SinusoidalPosEmb(n_emb)
|
||||
self.r_emb = SinusoidalPosEmb(n_emb)
|
||||
self.t_tokens = nn.Parameter(torch.zeros(1, n_time_tokens, n_emb))
|
||||
self.r_tokens = nn.Parameter(torch.zeros(1, n_time_tokens, n_emb))
|
||||
self.cond_obs_emb = nn.Linear(cond_dim, n_emb) if obs_as_cond else None
|
||||
self.cond_pos_emb = nn.Parameter(torch.zeros(1, T_cond, n_emb))
|
||||
|
||||
if n_cond_layers > 0:
|
||||
encoder_layer = nn.TransformerEncoderLayer(
|
||||
d_model=n_emb,
|
||||
nhead=n_head,
|
||||
dim_feedforward=4 * n_emb,
|
||||
dropout=p_drop_attn,
|
||||
activation="gelu",
|
||||
batch_first=True,
|
||||
norm_first=True,
|
||||
)
|
||||
self.encoder = nn.TransformerEncoder(
|
||||
encoder_layer=encoder_layer,
|
||||
num_layers=n_cond_layers,
|
||||
)
|
||||
else:
|
||||
self.encoder = nn.Sequential(
|
||||
nn.Linear(n_emb, 4 * n_emb),
|
||||
nn.Mish(),
|
||||
nn.Linear(4 * n_emb, n_emb),
|
||||
)
|
||||
|
||||
decoder_layer = nn.TransformerDecoderLayer(
|
||||
d_model=n_emb,
|
||||
nhead=n_head,
|
||||
dim_feedforward=4 * n_emb,
|
||||
dropout=p_drop_attn,
|
||||
activation="gelu",
|
||||
batch_first=True,
|
||||
norm_first=True,
|
||||
)
|
||||
self.decoder = nn.TransformerDecoder(
|
||||
decoder_layer=decoder_layer,
|
||||
num_layers=n_layer,
|
||||
)
|
||||
|
||||
if causal_attn:
|
||||
sz = T
|
||||
mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
|
||||
mask = mask.float().masked_fill(mask == 0, float("-inf")).masked_fill(mask == 1, float(0.0))
|
||||
self.register_buffer("mask", mask)
|
||||
|
||||
if obs_as_cond:
|
||||
q_idx, c_idx = torch.meshgrid(
|
||||
torch.arange(T),
|
||||
torch.arange(T_cond),
|
||||
indexing="ij",
|
||||
)
|
||||
obs_offset = n_global_cond_tokens
|
||||
visible = c_idx < obs_offset
|
||||
visible = visible | (q_idx >= (c_idx - obs_offset))
|
||||
memory_mask = visible.float().masked_fill(~visible, float("-inf")).masked_fill(visible, float(0.0))
|
||||
self.register_buffer("memory_mask", memory_mask)
|
||||
else:
|
||||
self.memory_mask = None
|
||||
else:
|
||||
self.mask = None
|
||||
self.memory_mask = None
|
||||
|
||||
self.ln_f = nn.LayerNorm(n_emb)
|
||||
self.head_u = nn.Linear(n_emb, output_dim)
|
||||
self.head_v = nn.Linear(n_emb, output_dim)
|
||||
|
||||
self.T = T
|
||||
self.T_cond = T_cond
|
||||
self.horizon = horizon
|
||||
self.n_obs_steps = n_obs_steps
|
||||
self.obs_as_cond = obs_as_cond
|
||||
self.n_global_cond_tokens = n_global_cond_tokens
|
||||
self.n_time_tokens = n_time_tokens
|
||||
|
||||
self.apply(self._init_weights)
|
||||
logger.info(
|
||||
"number of parameters: %e", sum(p.numel() for p in self.parameters())
|
||||
)
|
||||
|
||||
def _init_weights(self, module):
|
||||
ignore_types = (
|
||||
nn.Dropout,
|
||||
SinusoidalPosEmb,
|
||||
nn.TransformerEncoderLayer,
|
||||
nn.TransformerDecoderLayer,
|
||||
nn.TransformerEncoder,
|
||||
nn.TransformerDecoder,
|
||||
nn.ModuleList,
|
||||
nn.Mish,
|
||||
nn.Sequential,
|
||||
)
|
||||
if isinstance(module, (nn.Linear, nn.Embedding)):
|
||||
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
||||
if isinstance(module, nn.Linear) and module.bias is not None:
|
||||
torch.nn.init.zeros_(module.bias)
|
||||
elif isinstance(module, nn.MultiheadAttention):
|
||||
for name in ("in_proj_weight", "q_proj_weight", "k_proj_weight", "v_proj_weight"):
|
||||
weight = getattr(module, name)
|
||||
if weight is not None:
|
||||
torch.nn.init.normal_(weight, mean=0.0, std=0.02)
|
||||
for name in ("in_proj_bias", "bias_k", "bias_v"):
|
||||
bias = getattr(module, name)
|
||||
if bias is not None:
|
||||
torch.nn.init.zeros_(bias)
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
torch.nn.init.zeros_(module.bias)
|
||||
torch.nn.init.ones_(module.weight)
|
||||
elif isinstance(module, PMFTransformerForDiffusion):
|
||||
torch.nn.init.normal_(module.pos_emb, mean=0.0, std=0.02)
|
||||
torch.nn.init.normal_(module.cond_pos_emb, mean=0.0, std=0.02)
|
||||
torch.nn.init.normal_(module.t_tokens, mean=0.0, std=0.02)
|
||||
torch.nn.init.normal_(module.r_tokens, mean=0.0, std=0.02)
|
||||
elif isinstance(module, ignore_types):
|
||||
pass
|
||||
else:
|
||||
raise RuntimeError("Unaccounted module {}".format(module))
|
||||
|
||||
def get_optim_groups(self, weight_decay: float = 1e-3):
|
||||
decay = set()
|
||||
no_decay = set()
|
||||
whitelist_weight_modules = (torch.nn.Linear, torch.nn.MultiheadAttention)
|
||||
blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
|
||||
for mn, m in self.named_modules():
|
||||
for pn, _ in m.named_parameters():
|
||||
fpn = "%s.%s" % (mn, pn) if mn else pn
|
||||
if pn.endswith("bias"):
|
||||
no_decay.add(fpn)
|
||||
elif pn.startswith("bias"):
|
||||
no_decay.add(fpn)
|
||||
elif pn.endswith("weight") and isinstance(m, whitelist_weight_modules):
|
||||
decay.add(fpn)
|
||||
elif pn.endswith("weight") and isinstance(m, blacklist_weight_modules):
|
||||
no_decay.add(fpn)
|
||||
|
||||
no_decay.update(
|
||||
{
|
||||
"pos_emb",
|
||||
"cond_pos_emb",
|
||||
"t_tokens",
|
||||
"r_tokens",
|
||||
"_dummy_variable",
|
||||
}
|
||||
)
|
||||
|
||||
param_dict = {pn: p for pn, p in self.named_parameters()}
|
||||
inter_params = decay & no_decay
|
||||
union_params = decay | no_decay
|
||||
assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params),)
|
||||
assert len(param_dict.keys() - union_params) == 0, (
|
||||
"parameters %s were not separated into either decay/no_decay set!" % (str(param_dict.keys() - union_params),)
|
||||
)
|
||||
|
||||
return [
|
||||
{
|
||||
"params": [param_dict[pn] for pn in sorted(list(decay))],
|
||||
"weight_decay": weight_decay,
|
||||
},
|
||||
{
|
||||
"params": [param_dict[pn] for pn in sorted(list(no_decay))],
|
||||
"weight_decay": 0.0,
|
||||
},
|
||||
]
|
||||
|
||||
def configure_optimizers(
|
||||
self,
|
||||
learning_rate: float = 1e-4,
|
||||
weight_decay: float = 1e-3,
|
||||
betas: Tuple[float, float] = (0.9, 0.95),
|
||||
):
|
||||
optim_groups = self.get_optim_groups(weight_decay=weight_decay)
|
||||
return torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas)
|
||||
|
||||
def _broadcast_time(self, value: Union[torch.Tensor, float, int], batch_size: int, device: torch.device):
|
||||
if not torch.is_tensor(value):
|
||||
value = torch.tensor([value], dtype=torch.float32, device=device)
|
||||
elif value.ndim == 0:
|
||||
value = value[None].to(device=device, dtype=torch.float32)
|
||||
else:
|
||||
value = value.to(device=device, dtype=torch.float32)
|
||||
return value.expand(batch_size)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.Tensor,
|
||||
t: Union[torch.Tensor, float, int],
|
||||
r: Union[torch.Tensor, float, int],
|
||||
cond: Optional[torch.Tensor] = None,
|
||||
):
|
||||
batch_size = sample.shape[0]
|
||||
device = sample.device
|
||||
t = self._broadcast_time(t, batch_size, device)
|
||||
r = self._broadcast_time(r, batch_size, device)
|
||||
|
||||
input_emb = self.input_emb(sample)
|
||||
|
||||
t_cond = self.t_tokens + self.t_emb(t).unsqueeze(1)
|
||||
r_cond = self.r_tokens + self.r_emb(r).unsqueeze(1)
|
||||
cond_embeddings = [t_cond, r_cond]
|
||||
if self.obs_as_cond:
|
||||
cond_embeddings.append(self.cond_obs_emb(cond))
|
||||
cond_embeddings = torch.cat(cond_embeddings, dim=1)
|
||||
|
||||
cond_pos = self.cond_pos_emb[:, : cond_embeddings.shape[1], :]
|
||||
memory = self.drop(cond_embeddings + cond_pos)
|
||||
memory = self.encoder(memory)
|
||||
|
||||
token_pos = self.pos_emb[:, : input_emb.shape[1], :]
|
||||
x = self.drop(input_emb + token_pos)
|
||||
x = self.decoder(
|
||||
tgt=x,
|
||||
memory=memory,
|
||||
tgt_mask=self.mask,
|
||||
memory_mask=self.memory_mask,
|
||||
)
|
||||
x = self.ln_f(x)
|
||||
return self.head_u(x), self.head_v(x)
|
||||
@@ -256,8 +256,8 @@ class DiffusionTransformerHybridImagePolicy(BaseImagePolicy):
|
||||
# condition through impainting
|
||||
this_nobs = dict_apply(nobs, lambda x: x[:,:To,...].reshape(-1,*x.shape[2:]))
|
||||
nobs_features = self.obs_encoder(this_nobs)
|
||||
# reshape back to B, T, Do
|
||||
nobs_features = nobs_features.reshape(B, T, -1)
|
||||
# reshape back to B, To, Do
|
||||
nobs_features = nobs_features.reshape(B, To, -1)
|
||||
shape = (B, T, Da+Do)
|
||||
cond_data = torch.zeros(size=shape, device=device, dtype=dtype)
|
||||
cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)
|
||||
|
||||
@@ -247,7 +247,7 @@ class DiffusionUnetHybridImagePolicy(BaseImagePolicy):
|
||||
# condition through impainting
|
||||
this_nobs = dict_apply(nobs, lambda x: x[:,:To,...].reshape(-1,*x.shape[2:]))
|
||||
nobs_features = self.obs_encoder(this_nobs)
|
||||
# reshape back to B, T, Do
|
||||
# reshape back to B, To, Do
|
||||
nobs_features = nobs_features.reshape(B, To, -1)
|
||||
cond_data = torch.zeros(size=(B, T, Da+Do), device=device, dtype=dtype)
|
||||
cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)
|
||||
|
||||
453
diffusion_policy/policy/pmf_transformer_hybrid_image_policy.py
Normal file
453
diffusion_policy/policy/pmf_transformer_hybrid_image_policy.py
Normal file
@@ -0,0 +1,453 @@
|
||||
from typing import Dict, Tuple
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import reduce
|
||||
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
|
||||
|
||||
import diffusion_policy.model.vision.crop_randomizer as dmvc
|
||||
import robomimic.models.base_nets as rmbn
|
||||
import robomimic.utils.obs_utils as ObsUtils
|
||||
from diffusion_policy.common.pytorch_util import dict_apply, replace_submodules
|
||||
from diffusion_policy.common.robomimic_config_util import get_robomimic_config
|
||||
from diffusion_policy.model.common.normalizer import LinearNormalizer
|
||||
from diffusion_policy.model.diffusion.mask_generator import LowdimMaskGenerator
|
||||
from diffusion_policy.model.diffusion.pmf_transformer_for_diffusion import (
|
||||
PMFTransformerForDiffusion,
|
||||
)
|
||||
from diffusion_policy.policy.base_image_policy import BaseImagePolicy
|
||||
from robomimic.algo import algo_factory
|
||||
from robomimic.algo.algo import PolicyAlgo
|
||||
|
||||
|
||||
class PMFTransformerHybridImagePolicy(BaseImagePolicy):
|
||||
def __init__(
|
||||
self,
|
||||
shape_meta: dict,
|
||||
noise_scheduler: DDPMScheduler,
|
||||
horizon,
|
||||
n_action_steps,
|
||||
n_obs_steps,
|
||||
num_inference_steps=None,
|
||||
crop_shape=(76, 76),
|
||||
obs_encoder_group_norm=False,
|
||||
eval_fixed_crop=False,
|
||||
n_layer=8,
|
||||
n_cond_layers=0,
|
||||
n_head=4,
|
||||
n_emb=256,
|
||||
p_drop_emb=0.0,
|
||||
p_drop_attn=0.0,
|
||||
causal_attn=True,
|
||||
obs_as_cond=True,
|
||||
pred_action_steps_only=False,
|
||||
n_time_tokens=4,
|
||||
min_time=0.05,
|
||||
du_dt_epsilon=1.0e-3,
|
||||
pmf_u_loss_weight=1.0,
|
||||
pmf_v_loss_weight=1.0,
|
||||
noise_scale=1.0,
|
||||
adatloss_eps=0.01,
|
||||
p_mean=-0.4,
|
||||
p_std=1.0,
|
||||
tr_uniform=True,
|
||||
tr_uniform_prob=0.1,
|
||||
data_proportion=0.5,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
action_shape = shape_meta["action"]["shape"]
|
||||
assert len(action_shape) == 1
|
||||
action_dim = action_shape[0]
|
||||
obs_shape_meta = shape_meta["obs"]
|
||||
obs_config = {
|
||||
"low_dim": [],
|
||||
"rgb": [],
|
||||
"depth": [],
|
||||
"scan": [],
|
||||
}
|
||||
obs_key_shapes = dict()
|
||||
for key, attr in obs_shape_meta.items():
|
||||
shape = attr["shape"]
|
||||
obs_key_shapes[key] = list(shape)
|
||||
|
||||
obs_type = attr.get("type", "low_dim")
|
||||
if obs_type == "rgb":
|
||||
obs_config["rgb"].append(key)
|
||||
elif obs_type == "low_dim":
|
||||
obs_config["low_dim"].append(key)
|
||||
else:
|
||||
raise RuntimeError(f"Unsupported obs type: {obs_type}")
|
||||
|
||||
config = get_robomimic_config(
|
||||
algo_name="bc_rnn",
|
||||
hdf5_type="image",
|
||||
task_name="square",
|
||||
dataset_type="ph",
|
||||
)
|
||||
|
||||
with config.unlocked():
|
||||
config.observation.modalities.obs = obs_config
|
||||
|
||||
if crop_shape is None:
|
||||
for _, modality in config.observation.encoder.items():
|
||||
if modality.obs_randomizer_class == "CropRandomizer":
|
||||
modality["obs_randomizer_class"] = None
|
||||
else:
|
||||
crop_h, crop_w = crop_shape
|
||||
for _, modality in config.observation.encoder.items():
|
||||
if modality.obs_randomizer_class == "CropRandomizer":
|
||||
modality.obs_randomizer_kwargs.crop_height = crop_h
|
||||
modality.obs_randomizer_kwargs.crop_width = crop_w
|
||||
|
||||
ObsUtils.initialize_obs_utils_with_config(config)
|
||||
|
||||
policy: PolicyAlgo = algo_factory(
|
||||
algo_name=config.algo_name,
|
||||
config=config,
|
||||
obs_key_shapes=obs_key_shapes,
|
||||
ac_dim=action_dim,
|
||||
device="cpu",
|
||||
)
|
||||
|
||||
obs_encoder = policy.nets["policy"].nets["encoder"].nets["obs"]
|
||||
if obs_encoder_group_norm:
|
||||
replace_submodules(
|
||||
root_module=obs_encoder,
|
||||
predicate=lambda x: isinstance(x, nn.BatchNorm2d),
|
||||
func=lambda x: nn.GroupNorm(
|
||||
num_groups=x.num_features // 16,
|
||||
num_channels=x.num_features,
|
||||
),
|
||||
)
|
||||
|
||||
if eval_fixed_crop:
|
||||
replace_submodules(
|
||||
root_module=obs_encoder,
|
||||
predicate=lambda x: isinstance(x, rmbn.CropRandomizer),
|
||||
func=lambda x: dmvc.CropRandomizer(
|
||||
input_shape=x.input_shape,
|
||||
crop_height=x.crop_height,
|
||||
crop_width=x.crop_width,
|
||||
num_crops=x.num_crops,
|
||||
pos_enc=x.pos_enc,
|
||||
),
|
||||
)
|
||||
|
||||
obs_feature_dim = obs_encoder.output_shape()[0]
|
||||
input_dim = action_dim if obs_as_cond else (obs_feature_dim + action_dim)
|
||||
cond_dim = obs_feature_dim if obs_as_cond else 0
|
||||
|
||||
self.obs_encoder = obs_encoder
|
||||
self.model = PMFTransformerForDiffusion(
|
||||
input_dim=input_dim,
|
||||
output_dim=input_dim,
|
||||
horizon=horizon if not pred_action_steps_only else n_action_steps,
|
||||
n_obs_steps=n_obs_steps,
|
||||
cond_dim=cond_dim,
|
||||
n_layer=n_layer,
|
||||
n_head=n_head,
|
||||
n_emb=n_emb,
|
||||
p_drop_emb=p_drop_emb,
|
||||
p_drop_attn=p_drop_attn,
|
||||
causal_attn=causal_attn,
|
||||
obs_as_cond=obs_as_cond,
|
||||
n_cond_layers=n_cond_layers,
|
||||
n_time_tokens=n_time_tokens,
|
||||
)
|
||||
self.noise_scheduler = noise_scheduler
|
||||
self.mask_generator = LowdimMaskGenerator(
|
||||
action_dim=action_dim,
|
||||
obs_dim=0 if obs_as_cond else obs_feature_dim,
|
||||
max_n_obs_steps=n_obs_steps,
|
||||
fix_obs_steps=True,
|
||||
action_visible=False,
|
||||
)
|
||||
self.normalizer = LinearNormalizer()
|
||||
self.horizon = horizon
|
||||
self.obs_feature_dim = obs_feature_dim
|
||||
self.action_dim = action_dim
|
||||
self.n_action_steps = n_action_steps
|
||||
self.n_obs_steps = n_obs_steps
|
||||
self.obs_as_cond = obs_as_cond
|
||||
self.pred_action_steps_only = pred_action_steps_only
|
||||
self.min_time = min_time
|
||||
self.du_dt_epsilon = du_dt_epsilon
|
||||
self.pmf_u_loss_weight = pmf_u_loss_weight
|
||||
self.pmf_v_loss_weight = pmf_v_loss_weight
|
||||
self.noise_scale = noise_scale
|
||||
self.adatloss_eps = adatloss_eps
|
||||
self.p_mean = p_mean
|
||||
self.p_std = p_std
|
||||
self.tr_uniform = tr_uniform
|
||||
self.tr_uniform_prob = tr_uniform_prob
|
||||
self.data_proportion = data_proportion
|
||||
self.kwargs = kwargs
|
||||
|
||||
if num_inference_steps is None:
|
||||
num_inference_steps = noise_scheduler.config.num_train_timesteps
|
||||
self.num_inference_steps = num_inference_steps
|
||||
|
||||
def _encode_obs(self, nobs: Dict[str, torch.Tensor], n_steps: int) -> torch.Tensor:
|
||||
flat_nobs = dict_apply(nobs, lambda x: x[:, :n_steps, ...].reshape(-1, *x.shape[2:]))
|
||||
nobs_features = self.obs_encoder(flat_nobs)
|
||||
return nobs_features.reshape(next(iter(nobs.values())).shape[0], n_steps, -1)
|
||||
|
||||
def _time_view(self, value: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
|
||||
return value.reshape(value.shape[0], *([1] * (ref.ndim - 1)))
|
||||
|
||||
def _adatloss(self, loss: torch.Tensor) -> torch.Tensor:
|
||||
denom = loss.detach() + self.adatloss_eps
|
||||
return loss / denom
|
||||
|
||||
def _sample_logit_normal(self, batch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
|
||||
normal = torch.randn(batch_size, device=device, dtype=dtype)
|
||||
return torch.sigmoid(normal * self.p_std + self.p_mean)
|
||||
|
||||
def _sample_tr(self, batch_size: int, device: torch.device, dtype: torch.dtype):
|
||||
t = self._sample_logit_normal(batch_size, device, dtype)
|
||||
r = self._sample_logit_normal(batch_size, device, dtype)
|
||||
|
||||
if self.tr_uniform:
|
||||
uniform_mask = torch.rand(batch_size, device=device) < self.tr_uniform_prob
|
||||
uniform_t = torch.rand(batch_size, device=device, dtype=dtype)
|
||||
uniform_r = torch.rand(batch_size, device=device, dtype=dtype)
|
||||
t = torch.where(uniform_mask, uniform_t, t)
|
||||
r = torch.where(uniform_mask, uniform_r, r)
|
||||
|
||||
data_size = int(batch_size * self.data_proportion)
|
||||
fm_mask = torch.arange(batch_size, device=device) < data_size
|
||||
r = torch.where(fm_mask, t, r)
|
||||
|
||||
t_final = torch.maximum(t, r)
|
||||
r_final = torch.minimum(t, r)
|
||||
return t_final, r_final
|
||||
|
||||
def _trajectory_inputs(
|
||||
self,
|
||||
nobs: Dict[str, torch.Tensor],
|
||||
nactions: torch.Tensor,
|
||||
):
|
||||
batch_size = nactions.shape[0]
|
||||
horizon = nactions.shape[1]
|
||||
cond = None
|
||||
trajectory = nactions
|
||||
if self.obs_as_cond:
|
||||
cond = self._encode_obs(nobs, self.n_obs_steps)
|
||||
if self.pred_action_steps_only:
|
||||
start = self.n_obs_steps - 1
|
||||
end = start + self.n_action_steps
|
||||
trajectory = nactions[:, start:end]
|
||||
else:
|
||||
nobs_features = self._encode_obs(nobs, horizon)
|
||||
trajectory = torch.cat([nactions, nobs_features], dim=-1).detach()
|
||||
|
||||
if self.pred_action_steps_only:
|
||||
condition_mask = torch.zeros_like(trajectory, dtype=torch.bool)
|
||||
else:
|
||||
condition_mask = self.mask_generator(trajectory.shape)
|
||||
|
||||
return batch_size, trajectory, cond, condition_mask
|
||||
|
||||
def _apply_conditioning(
|
||||
self,
|
||||
sample: torch.Tensor,
|
||||
condition_data: torch.Tensor,
|
||||
condition_mask: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
if not condition_mask.any():
|
||||
return sample
|
||||
return torch.where(condition_mask, condition_data, sample)
|
||||
|
||||
def _compute_u_v(
|
||||
self,
|
||||
sample: torch.Tensor,
|
||||
t: torch.Tensor,
|
||||
r: torch.Tensor,
|
||||
cond: torch.Tensor,
|
||||
):
|
||||
x_hat_u, x_hat_v = self.model(sample, t, r, cond)
|
||||
denom = self._time_view(t, sample)
|
||||
u = (sample - x_hat_u) / denom
|
||||
v = (sample - x_hat_v) / denom
|
||||
return u, v
|
||||
|
||||
def _compute_du_dt(
|
||||
self,
|
||||
sample: torch.Tensor,
|
||||
t: torch.Tensor,
|
||||
r: torch.Tensor,
|
||||
cond: torch.Tensor,
|
||||
condition_data: torch.Tensor,
|
||||
condition_mask: torch.Tensor,
|
||||
tangent_v: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
tangent_sample = tangent_v.detach()
|
||||
tangent_r = torch.zeros_like(r)
|
||||
tangent_t = torch.ones_like(t)
|
||||
|
||||
def u_fn(sample_input, r_input, t_input):
|
||||
conditioned_sample = self._apply_conditioning(
|
||||
sample_input, condition_data, condition_mask
|
||||
)
|
||||
u_value, _ = self._compute_u_v(conditioned_sample, t_input, r_input, cond)
|
||||
return u_value
|
||||
|
||||
primals = (sample, r, t)
|
||||
tangents = (tangent_sample, tangent_r, tangent_t)
|
||||
try:
|
||||
_, du_dt = torch.func.jvp(u_fn, primals, tangents)
|
||||
except (AttributeError, NotImplementedError, RuntimeError):
|
||||
_, du_dt = torch.autograd.functional.jvp(
|
||||
u_fn,
|
||||
primals,
|
||||
tangents,
|
||||
create_graph=False,
|
||||
strict=False,
|
||||
)
|
||||
|
||||
return du_dt
|
||||
|
||||
# ========= inference ============
|
||||
def conditional_sample(
|
||||
self,
|
||||
condition_data,
|
||||
condition_mask,
|
||||
cond=None,
|
||||
generator=None,
|
||||
**kwargs,
|
||||
):
|
||||
del kwargs
|
||||
|
||||
trajectory = torch.randn(
|
||||
size=condition_data.shape,
|
||||
dtype=condition_data.dtype,
|
||||
device=condition_data.device,
|
||||
generator=generator,
|
||||
) * self.noise_scale
|
||||
|
||||
time_steps = torch.linspace(
|
||||
1.0,
|
||||
0.0,
|
||||
self.num_inference_steps + 1,
|
||||
dtype=trajectory.dtype,
|
||||
device=trajectory.device,
|
||||
)
|
||||
|
||||
for step_idx in range(self.num_inference_steps):
|
||||
trajectory = self._apply_conditioning(trajectory, condition_data, condition_mask)
|
||||
t = time_steps[step_idx].expand(trajectory.shape[0])
|
||||
r = time_steps[step_idx + 1].expand(trajectory.shape[0])
|
||||
u, _ = self._compute_u_v(trajectory, t, r, cond)
|
||||
delta = self._time_view(t - r, trajectory)
|
||||
trajectory = trajectory - delta * u
|
||||
|
||||
trajectory = self._apply_conditioning(trajectory, condition_data, condition_mask)
|
||||
return trajectory
|
||||
|
||||
def predict_action(self, obs_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
||||
assert "past_action" not in obs_dict
|
||||
nobs = self.normalizer.normalize(obs_dict)
|
||||
value = next(iter(nobs.values()))
|
||||
batch_size, to_steps = value.shape[:2]
|
||||
horizon = self.horizon
|
||||
action_dim = self.action_dim
|
||||
|
||||
device = self.device
|
||||
dtype = self.dtype
|
||||
cond = None
|
||||
if self.obs_as_cond:
|
||||
cond = self._encode_obs(nobs, self.n_obs_steps)
|
||||
shape = (batch_size, horizon, action_dim)
|
||||
if self.pred_action_steps_only:
|
||||
shape = (batch_size, self.n_action_steps, action_dim)
|
||||
cond_data = torch.zeros(size=shape, device=device, dtype=dtype)
|
||||
cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)
|
||||
else:
|
||||
nobs_features = self._encode_obs(nobs, self.n_obs_steps)
|
||||
shape = (batch_size, horizon, action_dim + self.obs_feature_dim)
|
||||
cond_data = torch.zeros(size=shape, device=device, dtype=dtype)
|
||||
cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)
|
||||
cond_data[:, : self.n_obs_steps, action_dim:] = nobs_features
|
||||
cond_mask[:, : self.n_obs_steps, action_dim:] = True
|
||||
|
||||
nsample = self.conditional_sample(
|
||||
cond_data,
|
||||
cond_mask,
|
||||
cond=cond,
|
||||
**self.kwargs,
|
||||
)
|
||||
|
||||
naction_pred = nsample[..., :action_dim]
|
||||
action_pred = self.normalizer["action"].unnormalize(naction_pred)
|
||||
if self.pred_action_steps_only:
|
||||
action = action_pred
|
||||
else:
|
||||
start = to_steps - 1
|
||||
end = start + self.n_action_steps
|
||||
action = action_pred[:, start:end]
|
||||
return {
|
||||
"action": action,
|
||||
"action_pred": action_pred,
|
||||
}
|
||||
|
||||
# ========= training ============
|
||||
def set_normalizer(self, normalizer: LinearNormalizer):
|
||||
self.normalizer.load_state_dict(normalizer.state_dict())
|
||||
|
||||
def get_optimizer(
|
||||
self,
|
||||
transformer_weight_decay: float,
|
||||
obs_encoder_weight_decay: float,
|
||||
learning_rate: float,
|
||||
betas: Tuple[float, float],
|
||||
) -> torch.optim.Optimizer:
|
||||
optim_groups = self.model.get_optim_groups(weight_decay=transformer_weight_decay)
|
||||
optim_groups.append(
|
||||
{
|
||||
"params": self.obs_encoder.parameters(),
|
||||
"weight_decay": obs_encoder_weight_decay,
|
||||
}
|
||||
)
|
||||
return torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas)
|
||||
|
||||
def compute_loss(self, batch):
|
||||
assert "valid_mask" not in batch
|
||||
nobs = self.normalizer.normalize(batch["obs"])
|
||||
nactions = self.normalizer["action"].normalize(batch["action"])
|
||||
|
||||
_, trajectory, cond, condition_mask = self._trajectory_inputs(nobs, nactions)
|
||||
noise = torch.randn_like(trajectory) * self.noise_scale
|
||||
batch_size = trajectory.shape[0]
|
||||
|
||||
t, r = self._sample_tr(
|
||||
batch_size, device=trajectory.device, dtype=trajectory.dtype
|
||||
)
|
||||
z_t = (1 - self._time_view(t, trajectory)) * trajectory + self._time_view(t, trajectory) * noise
|
||||
z_t = self._apply_conditioning(z_t, trajectory, condition_mask)
|
||||
|
||||
loss_mask = ~condition_mask
|
||||
target_v = noise - trajectory
|
||||
|
||||
u, v = self._compute_u_v(z_t, t, r, cond)
|
||||
du_dt = self._compute_du_dt(
|
||||
sample=z_t,
|
||||
t=t,
|
||||
r=r,
|
||||
cond=cond,
|
||||
condition_data=trajectory,
|
||||
condition_mask=condition_mask,
|
||||
tangent_v=v,
|
||||
)
|
||||
pmf_velocity = u + self._time_view(t - r, trajectory) * du_dt.detach()
|
||||
|
||||
loss_u = F.mse_loss(pmf_velocity, target_v, reduction="none")
|
||||
loss_v = F.mse_loss(v, target_v, reduction="none")
|
||||
loss_u = loss_u * loss_mask.type(loss_u.dtype)
|
||||
loss_v = loss_v * loss_mask.type(loss_v.dtype)
|
||||
loss_u = reduce(loss_u, "b ... -> b (...)", "mean").mean()
|
||||
loss_v = reduce(loss_v, "b ... -> b (...)", "mean").mean()
|
||||
loss_u = self._adatloss(loss_u)
|
||||
loss_v = self._adatloss(loss_v)
|
||||
return self.pmf_u_loss_weight * loss_u + self.pmf_v_loss_weight * loss_v
|
||||
64
eval.py
Normal file
64
eval.py
Normal file
@@ -0,0 +1,64 @@
|
||||
"""
|
||||
Usage:
|
||||
python eval.py --checkpoint data/image/pusht/diffusion_policy_cnn/train_0/checkpoints/latest.ckpt -o data/pusht_eval_output
|
||||
"""
|
||||
|
||||
import sys
|
||||
# use line-buffering for both stdout and stderr
|
||||
sys.stdout = open(sys.stdout.fileno(), mode='w', buffering=1)
|
||||
sys.stderr = open(sys.stderr.fileno(), mode='w', buffering=1)
|
||||
|
||||
import os
|
||||
import pathlib
|
||||
import click
|
||||
import hydra
|
||||
import torch
|
||||
import dill
|
||||
import wandb
|
||||
import json
|
||||
from diffusion_policy.workspace.base_workspace import BaseWorkspace
|
||||
|
||||
@click.command()
|
||||
@click.option('-c', '--checkpoint', required=True)
|
||||
@click.option('-o', '--output_dir', required=True)
|
||||
@click.option('-d', '--device', default='cuda:0')
|
||||
def main(checkpoint, output_dir, device):
|
||||
if os.path.exists(output_dir):
|
||||
click.confirm(f"Output path {output_dir} already exists! Overwrite?", abort=True)
|
||||
pathlib.Path(output_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# load checkpoint
|
||||
payload = torch.load(open(checkpoint, 'rb'), pickle_module=dill)
|
||||
cfg = payload['cfg']
|
||||
cls = hydra.utils.get_class(cfg._target_)
|
||||
workspace = cls(cfg, output_dir=output_dir)
|
||||
workspace: BaseWorkspace
|
||||
workspace.load_payload(payload, exclude_keys=None, include_keys=None)
|
||||
|
||||
# get policy from workspace
|
||||
policy = workspace.model
|
||||
if cfg.training.use_ema:
|
||||
policy = workspace.ema_model
|
||||
|
||||
device = torch.device(device)
|
||||
policy.to(device)
|
||||
policy.eval()
|
||||
|
||||
# run eval
|
||||
env_runner = hydra.utils.instantiate(
|
||||
cfg.task.env_runner,
|
||||
output_dir=output_dir)
|
||||
runner_log = env_runner.run(policy)
|
||||
|
||||
# dump log to json
|
||||
json_log = dict()
|
||||
for key, value in runner_log.items():
|
||||
if isinstance(value, wandb.sdk.data_types.video.Video):
|
||||
json_log[key] = value._path
|
||||
else:
|
||||
json_log[key] = value
|
||||
out_path = os.path.join(output_dir, 'eval_log.json')
|
||||
json.dump(json_log, open(out_path, 'w'), indent=2, sort_keys=True)
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
182
image_pusht_diffusion_policy_cnn.yaml
Normal file
182
image_pusht_diffusion_policy_cnn.yaml
Normal file
@@ -0,0 +1,182 @@
|
||||
_target_: diffusion_policy.workspace.train_diffusion_unet_hybrid_workspace.TrainDiffusionUnetHybridWorkspace
|
||||
checkpoint:
|
||||
save_last_ckpt: true
|
||||
save_last_snapshot: false
|
||||
topk:
|
||||
format_str: epoch={epoch:04d}-test_mean_score={test_mean_score:.3f}.ckpt
|
||||
k: 5
|
||||
mode: max
|
||||
monitor_key: test_mean_score
|
||||
dataloader:
|
||||
batch_size: 64
|
||||
num_workers: 8
|
||||
persistent_workers: false
|
||||
pin_memory: true
|
||||
shuffle: true
|
||||
dataset_obs_steps: 2
|
||||
ema:
|
||||
_target_: diffusion_policy.model.diffusion.ema_model.EMAModel
|
||||
inv_gamma: 1.0
|
||||
max_value: 0.9999
|
||||
min_value: 0.0
|
||||
power: 0.75
|
||||
update_after_step: 0
|
||||
exp_name: default
|
||||
horizon: 16
|
||||
keypoint_visible_rate: 1.0
|
||||
logging:
|
||||
group: null
|
||||
id: null
|
||||
mode: online
|
||||
name: 2023.01.16-20.20.06_train_diffusion_unet_hybrid_pusht_image
|
||||
project: diffusion_policy_debug
|
||||
resume: true
|
||||
tags:
|
||||
- train_diffusion_unet_hybrid
|
||||
- pusht_image
|
||||
- default
|
||||
multi_run:
|
||||
run_dir: data/outputs/2023.01.16/20.20.06_train_diffusion_unet_hybrid_pusht_image
|
||||
wandb_name_base: 2023.01.16-20.20.06_train_diffusion_unet_hybrid_pusht_image
|
||||
n_action_steps: 8
|
||||
n_latency_steps: 0
|
||||
n_obs_steps: 2
|
||||
name: train_diffusion_unet_hybrid
|
||||
obs_as_global_cond: true
|
||||
optimizer:
|
||||
_target_: torch.optim.AdamW
|
||||
betas:
|
||||
- 0.95
|
||||
- 0.999
|
||||
eps: 1.0e-08
|
||||
lr: 0.0001
|
||||
weight_decay: 1.0e-06
|
||||
past_action_visible: false
|
||||
policy:
|
||||
_target_: diffusion_policy.policy.diffusion_unet_hybrid_image_policy.DiffusionUnetHybridImagePolicy
|
||||
cond_predict_scale: true
|
||||
crop_shape:
|
||||
- 84
|
||||
- 84
|
||||
diffusion_step_embed_dim: 128
|
||||
down_dims:
|
||||
- 512
|
||||
- 1024
|
||||
- 2048
|
||||
eval_fixed_crop: true
|
||||
horizon: 16
|
||||
kernel_size: 5
|
||||
n_action_steps: 8
|
||||
n_groups: 8
|
||||
n_obs_steps: 2
|
||||
noise_scheduler:
|
||||
_target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler
|
||||
beta_end: 0.02
|
||||
beta_schedule: squaredcos_cap_v2
|
||||
beta_start: 0.0001
|
||||
clip_sample: true
|
||||
num_train_timesteps: 100
|
||||
prediction_type: epsilon
|
||||
variance_type: fixed_small
|
||||
num_inference_steps: 100
|
||||
obs_as_global_cond: true
|
||||
obs_encoder_group_norm: true
|
||||
shape_meta:
|
||||
action:
|
||||
shape:
|
||||
- 2
|
||||
obs:
|
||||
agent_pos:
|
||||
shape:
|
||||
- 2
|
||||
type: low_dim
|
||||
image:
|
||||
shape:
|
||||
- 3
|
||||
- 96
|
||||
- 96
|
||||
type: rgb
|
||||
shape_meta:
|
||||
action:
|
||||
shape:
|
||||
- 2
|
||||
obs:
|
||||
agent_pos:
|
||||
shape:
|
||||
- 2
|
||||
type: low_dim
|
||||
image:
|
||||
shape:
|
||||
- 3
|
||||
- 96
|
||||
- 96
|
||||
type: rgb
|
||||
task:
|
||||
dataset:
|
||||
_target_: diffusion_policy.dataset.pusht_image_dataset.PushTImageDataset
|
||||
horizon: 16
|
||||
max_train_episodes: 90
|
||||
pad_after: 7
|
||||
pad_before: 1
|
||||
seed: 42
|
||||
val_ratio: 0.02
|
||||
zarr_path: data/pusht/pusht_cchi_v7_replay.zarr
|
||||
env_runner:
|
||||
_target_: diffusion_policy.env_runner.pusht_image_runner.PushTImageRunner
|
||||
fps: 10
|
||||
legacy_test: true
|
||||
max_steps: 300
|
||||
n_action_steps: 8
|
||||
n_envs: null
|
||||
n_obs_steps: 2
|
||||
n_test: 50
|
||||
n_test_vis: 4
|
||||
n_train: 6
|
||||
n_train_vis: 2
|
||||
past_action: false
|
||||
test_start_seed: 100000
|
||||
train_start_seed: 0
|
||||
image_shape:
|
||||
- 3
|
||||
- 96
|
||||
- 96
|
||||
name: pusht_image
|
||||
shape_meta:
|
||||
action:
|
||||
shape:
|
||||
- 2
|
||||
obs:
|
||||
agent_pos:
|
||||
shape:
|
||||
- 2
|
||||
type: low_dim
|
||||
image:
|
||||
shape:
|
||||
- 3
|
||||
- 96
|
||||
- 96
|
||||
type: rgb
|
||||
task_name: pusht_image
|
||||
training:
|
||||
checkpoint_every: 50
|
||||
debug: false
|
||||
device: cuda:0
|
||||
gradient_accumulate_every: 1
|
||||
lr_scheduler: cosine
|
||||
lr_warmup_steps: 500
|
||||
max_train_steps: null
|
||||
max_val_steps: null
|
||||
num_epochs: 3050
|
||||
resume: true
|
||||
rollout_every: 50
|
||||
sample_every: 5
|
||||
seed: 42
|
||||
tqdm_interval_sec: 1.0
|
||||
use_ema: true
|
||||
val_every: 1
|
||||
val_dataloader:
|
||||
batch_size: 64
|
||||
num_workers: 8
|
||||
persistent_workers: false
|
||||
pin_memory: true
|
||||
shuffle: false
|
||||
189
image_pusht_diffusion_policy_dit_pmf.yaml
Normal file
189
image_pusht_diffusion_policy_dit_pmf.yaml
Normal file
@@ -0,0 +1,189 @@
|
||||
_target_: diffusion_policy.workspace.train_diffusion_transformer_hybrid_workspace.TrainDiffusionTransformerHybridWorkspace
|
||||
checkpoint:
|
||||
save_last_ckpt: true
|
||||
save_last_snapshot: false
|
||||
topk:
|
||||
format_str: epoch={epoch:04d}-train_loss={train_loss:.3f}.ckpt
|
||||
k: 5
|
||||
mode: min
|
||||
monitor_key: train_loss
|
||||
dataloader:
|
||||
batch_size: 64
|
||||
num_workers: 8
|
||||
persistent_workers: false
|
||||
pin_memory: true
|
||||
shuffle: true
|
||||
dataset_obs_steps: 2
|
||||
ema:
|
||||
_target_: diffusion_policy.model.diffusion.ema_model.EMAModel
|
||||
inv_gamma: 1.0
|
||||
max_value: 0.9999
|
||||
min_value: 0.0
|
||||
power: 0.75
|
||||
update_after_step: 0
|
||||
exp_name: default
|
||||
horizon: 16
|
||||
keypoint_visible_rate: 1.0
|
||||
logging:
|
||||
group: null
|
||||
id: null
|
||||
mode: online
|
||||
name: ${now:%Y.%m.%d-%H.%M.%S}_train_diffusion_transformer_hybrid_pmf_pusht_image
|
||||
project: diffusion_policy_debug
|
||||
resume: true
|
||||
tags:
|
||||
- train_diffusion_transformer_hybrid_pmf
|
||||
- pusht_image
|
||||
- default
|
||||
multi_run:
|
||||
run_dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_train_diffusion_transformer_hybrid_pmf_pusht_image
|
||||
wandb_name_base: ${now:%Y.%m.%d-%H.%M.%S}_train_diffusion_transformer_hybrid_pmf_pusht_image
|
||||
n_action_steps: 8
|
||||
n_latency_steps: 0
|
||||
n_obs_steps: 2
|
||||
name: train_diffusion_transformer_hybrid_pmf
|
||||
obs_as_cond: true
|
||||
optimizer:
|
||||
betas:
|
||||
- 0.9
|
||||
- 0.95
|
||||
learning_rate: 0.0001
|
||||
obs_encoder_weight_decay: 1.0e-06
|
||||
transformer_weight_decay: 0.001
|
||||
past_action_visible: false
|
||||
policy:
|
||||
_target_: diffusion_policy.policy.pmf_transformer_hybrid_image_policy.PMFTransformerHybridImagePolicy
|
||||
crop_shape:
|
||||
- 84
|
||||
- 84
|
||||
eval_fixed_crop: true
|
||||
horizon: 16
|
||||
n_action_steps: 8
|
||||
n_cond_layers: 0
|
||||
n_emb: 256
|
||||
n_head: 4
|
||||
n_layer: 12
|
||||
n_obs_steps: 2
|
||||
n_time_tokens: 4
|
||||
noise_scale: 1.0
|
||||
adatloss_eps: 0.01
|
||||
p_mean: -0.4
|
||||
p_std: 1.0
|
||||
noise_scheduler:
|
||||
_target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler
|
||||
beta_end: 0.02
|
||||
beta_schedule: squaredcos_cap_v2
|
||||
beta_start: 0.0001
|
||||
clip_sample: true
|
||||
num_train_timesteps: 100
|
||||
prediction_type: sample
|
||||
variance_type: fixed_small
|
||||
num_inference_steps: 1
|
||||
obs_as_cond: true
|
||||
obs_encoder_group_norm: true
|
||||
p_drop_attn: 0.0
|
||||
p_drop_emb: 0.0
|
||||
pmf_u_loss_weight: 1.0
|
||||
pmf_v_loss_weight: 1.0
|
||||
tr_uniform: true
|
||||
tr_uniform_prob: 0.1
|
||||
data_proportion: 0.5
|
||||
shape_meta:
|
||||
action:
|
||||
shape:
|
||||
- 2
|
||||
obs:
|
||||
agent_pos:
|
||||
shape:
|
||||
- 2
|
||||
type: low_dim
|
||||
image:
|
||||
shape:
|
||||
- 3
|
||||
- 96
|
||||
- 96
|
||||
type: rgb
|
||||
shape_meta:
|
||||
action:
|
||||
shape:
|
||||
- 2
|
||||
obs:
|
||||
agent_pos:
|
||||
shape:
|
||||
- 2
|
||||
type: low_dim
|
||||
image:
|
||||
shape:
|
||||
- 3
|
||||
- 96
|
||||
- 96
|
||||
type: rgb
|
||||
task:
|
||||
dataset:
|
||||
_target_: diffusion_policy.dataset.pusht_image_dataset.PushTImageDataset
|
||||
horizon: 16
|
||||
max_train_episodes: 90
|
||||
pad_after: 7
|
||||
pad_before: 1
|
||||
seed: 42
|
||||
val_ratio: 0.02
|
||||
zarr_path: data/pusht/pusht_cchi_v7_replay.zarr
|
||||
env_runner:
|
||||
_target_: diffusion_policy.env_runner.pusht_image_runner.PushTImageRunner
|
||||
fps: 10
|
||||
legacy_test: true
|
||||
max_steps: 300
|
||||
n_action_steps: 8
|
||||
n_envs: null
|
||||
n_obs_steps: 2
|
||||
n_test: 50
|
||||
n_test_vis: 4
|
||||
n_train: 6
|
||||
n_train_vis: 2
|
||||
past_action: false
|
||||
test_start_seed: 100000
|
||||
train_start_seed: 0
|
||||
image_shape:
|
||||
- 3
|
||||
- 96
|
||||
- 96
|
||||
name: pusht_image
|
||||
shape_meta:
|
||||
action:
|
||||
shape:
|
||||
- 2
|
||||
obs:
|
||||
agent_pos:
|
||||
shape:
|
||||
- 2
|
||||
type: low_dim
|
||||
image:
|
||||
shape:
|
||||
- 3
|
||||
- 96
|
||||
- 96
|
||||
type: rgb
|
||||
task_name: pusht_image
|
||||
training:
|
||||
checkpoint_every: 50
|
||||
debug: false
|
||||
device: cuda:0
|
||||
gradient_accumulate_every: 1
|
||||
lr_scheduler: cosine
|
||||
lr_warmup_steps: 500
|
||||
max_train_steps: null
|
||||
max_val_steps: null
|
||||
num_epochs: 600
|
||||
resume: true
|
||||
rollout_every: 50
|
||||
sample_every: 5
|
||||
seed: 42
|
||||
tqdm_interval_sec: 1.0
|
||||
use_ema: true
|
||||
val_every: 1
|
||||
val_dataloader:
|
||||
batch_size: 64
|
||||
num_workers: 8
|
||||
persistent_workers: false
|
||||
pin_memory: true
|
||||
shuffle: false
|
||||
39
requirements-pusht-5090.txt
Normal file
39
requirements-pusht-5090.txt
Normal file
@@ -0,0 +1,39 @@
|
||||
# Direct package pins for the canonical PushT image workflow on host 5090.
|
||||
# Torch/TorchVision/Torchaudio are installed separately from the cu128 index in setup_uv_pusht_5090.sh.
|
||||
|
||||
numpy==1.26.4
|
||||
scipy==1.11.4
|
||||
numba==0.59.1
|
||||
llvmlite==0.42.0
|
||||
cffi==1.15.1
|
||||
cython==0.29.32
|
||||
h5py==3.8.0
|
||||
pandas==2.2.3
|
||||
zarr==2.12.0
|
||||
numcodecs==0.10.2
|
||||
hydra-core==1.2.0
|
||||
einops==0.4.1
|
||||
tqdm==4.64.1
|
||||
dill==0.3.5.1
|
||||
scikit-video==1.1.11
|
||||
scikit-image==0.19.3
|
||||
gym==0.23.1
|
||||
pymunk==6.2.1
|
||||
wandb==0.13.3
|
||||
threadpoolctl==3.1.0
|
||||
shapely==1.8.5.post1
|
||||
matplotlib==3.6.1
|
||||
imageio==2.22.0
|
||||
imageio-ffmpeg==0.4.7
|
||||
termcolor==2.0.1
|
||||
tensorboard==2.10.1
|
||||
tensorboardx==2.5.1
|
||||
psutil==7.2.2
|
||||
click==8.1.8
|
||||
boto3==1.24.96
|
||||
diffusers==0.11.1
|
||||
huggingface-hub==0.10.1
|
||||
av==14.0.1
|
||||
pygame==2.5.2
|
||||
robomimic==0.2.0
|
||||
opencv-python-headless==4.10.0.84
|
||||
20
setup_uv_pusht_5090.sh
Executable file
20
setup_uv_pusht_5090.sh
Executable file
@@ -0,0 +1,20 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
ROOT_DIR="$(cd "$(dirname "$0")" && pwd)"
|
||||
cd "$ROOT_DIR"
|
||||
|
||||
export UV_CACHE_DIR="${UV_CACHE_DIR:-$ROOT_DIR/.uv-cache}"
|
||||
export UV_PYTHON_INSTALL_DIR="${UV_PYTHON_INSTALL_DIR:-$ROOT_DIR/.uv-python}"
|
||||
|
||||
uv venv --python 3.9 .venv
|
||||
source .venv/bin/activate
|
||||
|
||||
uv pip install --upgrade pip wheel setuptools==80.9.0
|
||||
uv pip install --python .venv/bin/python \
|
||||
--index-url https://download.pytorch.org/whl/cu128 \
|
||||
torch==2.8.0+cu128 torchvision==0.23.0+cu128 torchaudio==2.8.0+cu128
|
||||
uv pip install --python .venv/bin/python -r requirements-pusht-5090.txt
|
||||
uv pip install --python .venv/bin/python -e .
|
||||
|
||||
echo "uv environment ready at $ROOT_DIR/.venv"
|
||||
Reference in New Issue
Block a user