14 Commits

- Logic, 9169e4d7e0: feat(pusht): add dual-head uv transformer (2026-03-17 17:05:02 +08:00)
- gameloader, 42dc29a2cb: feat: align pmf transformer training and config defaults (2026-03-16 15:37:32 +08:00)
- Logic, 79f31940c4: feat(pusht): add pMF-style DiT image policy (2026-03-16 11:11:43 +08:00)
- Logic, 2aa06c8917: fix(pusht): stabilize DiT pusht training on current stack (2026-03-15 18:54:50 +08:00)
- Logic, 08c1950c6d: chore(pusht): add 5090 repro docs and uv setup (2026-03-14 12:25:44 +08:00)
- Yihuai Gao, 5ba07ac666: Done adapting mujoco image dataset (2024-12-24 11:47:59 -08:00)
- Cheng Chi, 548a52bbb1: Merge pull request #27 from pointW/main ("Fix typo in rotation_transformer.py") (2023-10-26 22:34:42 -07:00)
- Dian Wang, de4384e84a: fix typo in rotation_transformer.py (2023-10-25 10:43:25 -04:00)
- Cheng Chi, 7dd9dc417a: Merge pull request #21 from columbia-ai-robotics/cchi/fix_cpu_affinity ("pinned llvm-openmp version to avoid cpu affinity bug in pytorch") (2023-09-12 23:36:52 -07:00)
- Cheng Chi, 5aa9996fdc: pinned llvm-openmp version to avoid cpu affinity bug in pytorch (2023-09-13 02:36:26 -04:00)
- Cheng Chi, 5c3d54fca3: Merge pull request #20 from columbia-ai-robotics/cchi/eval_script ("added eval script and documentation") (2023-09-09 22:58:56 -07:00)
- Cheng Chi, a98e74873b: added eval script and documentation (2023-09-10 01:58:04 -04:00)
- Cheng Chi, 68eef44d3e: Merge pull request #19 from columbia-ai-robotics/cchi/fix_transformer_impainting ("fixed T->To based on suggestion from Dominique-Yiu") (2023-09-09 09:52:24 -07:00)
- Cheng Chi, c52bac42ee: fixed T->To based on suggestion from Dominique-Yiu (2023-09-09 12:51:49 -04:00)
20 changed files with 1621 additions and 16 deletions

AGENTS.md (new file, 68 lines)

@@ -0,0 +1,68 @@
# Agent Notes
## Purpose
`~/diffusion_policy` is the Diffusion Policy training repo. The main workflow here is Hydra-driven training via `train.py`, with the canonical PushT image experiment configured by `image_pusht_diffusion_policy_cnn.yaml`.
## Top Level
- `diffusion_policy/`: core code, configs, datasets, env runners, workspaces.
- `data/`: local datasets, outputs, checkpoints, run logs.
- `train.py`: main training entrypoint.
- `eval.py`: checkpoint evaluation entrypoint.
- `image_pusht_diffusion_policy_cnn.yaml`: canonical single-seed PushT image config from the README path.
- `.venv/`: local `uv`-managed virtualenv.
- `.uv-cache/`, `.uv-python/`: local `uv` cache and Python install state.
- `README.md`: upstream instructions and canonical commands.
## Canonical PushT Image Path
- Entrypoint: `python train.py --config-dir=. --config-name=image_pusht_diffusion_policy_cnn.yaml`
- Dataset path in config: `data/pusht/pusht_cchi_v7_replay.zarr`
- README canonical device override: `training.device=cuda:0`
## Data
- PushT archive currently present at `data/pusht.zip`
- Unpacked dataset used by training: `data/pusht/pusht_cchi_v7_replay.zarr`
## Local Compatibility Adjustments
- `diffusion_policy/env_runner/pusht_image_runner.py` now uses `SyncVectorEnv` instead of `AsyncVectorEnv`.
Reason: avoid shared-memory and semaphore failures on this host/session.
- `diffusion_policy/gym_util/sync_vector_env.py` has local compatibility changes:
- added `reset_async`
- seeded `reset_wait`
- updated `concatenate(...)` call order for the current `gym` API
## Environment Expectations
- Use the local `uv` env at `.venv`
- Verified local Python: `3.9.25`
- Verified local Torch stack: `torch 2.8.0+cu128`, `torchvision 0.23.0+cu128`
- Other key installed versions verified in `.venv`:
- `gym 0.23.1`
- `hydra-core 1.2.0`
- `diffusers 0.11.1`
- `huggingface_hub 0.10.1`
- `wandb 0.13.3`
- `zarr 2.12.0`
- `numcodecs 0.10.2`
- `av 14.0.1`
- Important note: this shell currently reports `torch.cuda.is_available() == False`, so always verify CUDA access in the current session before assuming GPU is usable.
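
To make that re-check concrete, a small session probe along these lines can gate the device override before launching (the helper name is ours, not part of the repo; it falls back to CPU when torch is absent or CUDA is hidden):

```python
import importlib.util

def pick_training_device() -> str:
    """Suggest a training.device override for the current session.

    Returns "cuda:0" only when torch is importable and CUDA is actually
    visible; otherwise "cpu", matching the re-check advice above.
    """
    if importlib.util.find_spec("torch") is None:
        return "cpu"
    import torch
    return "cuda:0" if torch.cuda.is_available() else "cpu"

print("training.device=" + pick_training_device())
```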
## Logging And Outputs
- Hydra run outputs: `data/outputs/...`
- Per-run files to check first:
- `.hydra/overrides.yaml`
- `logs.json.txt`
- `train.log`
- `checkpoints/latest.ckpt`
- Extra launcher logs may live under `data/run_logs/`
## Practical Guidance
- Inspect with `rg`, `sed`, and existing Hydra output folders before changing code.
- Prefer config overrides before code edits.
- On this host, start from these safety overrides unless revalidated:
- `logging.mode=offline`
- `dataloader.num_workers=0`
- `val_dataloader.num_workers=0`
- `task.env_runner.n_envs=1`
- `task.env_runner.n_test_vis=0`
- `task.env_runner.n_train_vis=0`
- If a run fails, inspect `.hydra/overrides.yaml`, then `logs.json.txt`, then `train.log`.
- Avoid driver or system changes unless the repo-local path is clearly blocked.

PUSHT_REPRO_5090.md (new file, 108 lines)

@@ -0,0 +1,108 @@
# PushT Repro On 5090
## Goal
Reproduce the canonical single-seed image PushT experiment from this repo in `~/diffusion_policy` using `image_pusht_diffusion_policy_cnn.yaml`.
## Current Verified Local Setup
- Virtualenv: `./.venv` managed with `uv`
- Python: `3.9.25`
- Torch stack: `torch 2.8.0+cu128`, `torchvision 0.23.0+cu128`
- Version strategy used here:
- newer Torch/CUDA stack for current 5090-class hardware support
- keep older repo-era packages where they are still required by the code
- Verified key pins in `.venv`:
- `numpy 1.26.4`
- `gym 0.23.1`
- `hydra-core 1.2.0`
- `diffusers 0.11.1`
- `huggingface_hub 0.10.1`
- `wandb 0.13.3`
- `zarr 2.12.0`
- `numcodecs 0.10.2`
- `av 14.0.1`
- `robomimic 0.2.0`
## Dataset
- README source: `https://diffusion-policy.cs.columbia.edu/data/training/pusht.zip`
- Local archive currently present: `data/pusht.zip`
- Unpacked dataset used by the config: `data/pusht/pusht_cchi_v7_replay.zarr`
## Repo-Local Code Adjustments
- `diffusion_policy/env_runner/pusht_image_runner.py`
- switched PushT image evaluation from `AsyncVectorEnv` to `SyncVectorEnv`
- `diffusion_policy/gym_util/sync_vector_env.py`
- added `reset_async`
- added seeded `reset_wait`
- updated `concatenate(...)` call order for current `gym`
These changes were needed to keep PushT evaluation working without the async shared-memory path.
## Validated GPU Smoke Command
This route is verified by `data/outputs/gpu_smoke2___pusht_gpu_smoke`, which contains `logs.json.txt` plus checkpoints:
```bash
.venv/bin/python train.py \
--config-dir=. \
--config-name=image_pusht_diffusion_policy_cnn.yaml \
training.seed=42 \
training.device=cuda:0 \
logging.mode=offline \
dataloader.num_workers=0 \
val_dataloader.num_workers=0 \
task.env_runner.n_envs=1 \
training.debug=true \
task.env_runner.n_test=2 \
task.env_runner.n_test_vis=0 \
task.env_runner.n_train=1 \
task.env_runner.n_train_vis=0 \
task.env_runner.max_steps=20
```
## Practical Full Training Command Used Here
This matches the longer GPU run under `data/outputs/2026.03.13/15.37.00_train_diffusion_unet_hybrid_pusht_image_gpu_seed42`:
```bash
.venv/bin/python train.py \
--config-dir=. \
--config-name=image_pusht_diffusion_policy_cnn.yaml \
training.seed=42 \
training.device=cuda:0 \
logging.mode=offline \
dataloader.num_workers=0 \
val_dataloader.num_workers=0 \
task.env_runner.n_envs=1 \
task.env_runner.n_test_vis=0 \
task.env_runner.n_train_vis=0 \
hydra.run.dir=data/outputs/2026.03.13/15.37.00_train_diffusion_unet_hybrid_pusht_image_gpu_seed42
```
## Why These Overrides Were Used
- `logging.mode=offline`
- avoids needing a W&B login and still leaves local run metadata in the output dir
- `dataloader.num_workers=0` and `val_dataloader.num_workers=0`
- avoids extra multiprocessing on this host
- `task.env_runner.n_envs=1`
- keeps PushT eval on the serial `SyncVectorEnv` path
- `task.env_runner.n_test_vis=0` and `task.env_runner.n_train_vis=0`
- avoids video-writing issues on this stack
- one earlier GPU run with default vis settings logged libav/libx264 `profile=high` errors in `data/outputs/_train_diffusion_unet_hybrid_pusht_image_gpu_seed42/train.log`
## Output Locations
- Smoke run:
- `data/outputs/gpu_smoke2___pusht_gpu_smoke`
- Longer GPU run:
- `data/outputs/2026.03.13/15.37.00_train_diffusion_unet_hybrid_pusht_image_gpu_seed42`
- Files to inspect inside a run:
- `.hydra/overrides.yaml`
- `logs.json.txt`
- `train.log`
- `checkpoints/latest.ckpt`
## Known Caveats
- The default config is still tuned for older assumptions:
- `logging.mode=online`
- `dataloader.num_workers=8`
- `task.env_runner.n_envs=null`
- `task.env_runner.n_test_vis=4`
- `task.env_runner.n_train_vis=2`
- In this shell, `torch.cuda.is_available()` currently reports `False` even though the repo contains validated GPU smoke/full run artifacts. Re-check device visibility in the current session before restarting a GPU run.


@@ -202,6 +202,41 @@ data/outputs/2023.03.01/22.13.58_train_diffusion_unet_hybrid_pusht_image
 7 directories, 16 files
 ```
+### 🆕 Evaluate Pre-trained Checkpoints
+Download a checkpoint from the published training log folders, such as [https://diffusion-policy.cs.columbia.edu/data/experiments/low_dim/pusht/diffusion_policy_cnn/train_0/checkpoints/epoch=0550-test_mean_score=0.969.ckpt](https://diffusion-policy.cs.columbia.edu/data/experiments/low_dim/pusht/diffusion_policy_cnn/train_0/checkpoints/epoch=0550-test_mean_score=0.969.ckpt).
+
+Run the evaluation script:
+```console
+(robodiff)[diffusion_policy]$ python eval.py --checkpoint data/0550-test_mean_score=0.969.ckpt --output_dir data/pusht_eval_output --device cuda:0
+```
+
+This will generate the following directory structure:
+```console
+(robodiff)[diffusion_policy]$ tree data/pusht_eval_output
+data/pusht_eval_output
+├── eval_log.json
+└── media
+    ├── 1fxtno84.mp4
+    ├── 224l7jqd.mp4
+    ├── 2fo4btlf.mp4
+    ├── 2in4cn7a.mp4
+    ├── 34b3o2qq.mp4
+    └── 3p7jqn32.mp4
+
+1 directory, 7 files
+```
+
+`eval_log.json` contains metrics that are logged to wandb during training:
+```console
+(robodiff)[diffusion_policy]$ cat data/pusht_eval_output/eval_log.json
+{
+  "test/mean_score": 0.9150393806777066,
+  "test/sim_max_reward_4300000": 1.0,
+  "test/sim_max_reward_4300001": 0.9872969750774386,
+...
+  "train/sim_video_1": "data/pusht_eval_output//media/2fo4btlf.mp4"
+}
+```
 ## 🦾 Demo, Training and Eval on a Real Robot
 Make sure your UR5 robot is running and accepting command from its network interface (emergency stop button within reach at all time), your RealSense cameras plugged in to your workstation (tested with `realsense-viewer`) and your SpaceMouse connected with the `spacenavd` daemon running (verify with `systemctl status spacenavd`).
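
A tiny helper in this spirit can pull the headline metric back out of such a log (the function name and path are illustrative, not part of the repo):

```python
import json

def read_mean_score(path: str) -> float:
    """Return the wandb-style test/mean_score entry from an eval_log.json."""
    with open(path) as f:
        log = json.load(f)
    return log["test/mean_score"]
```

For the log above, `read_mean_score('data/pusht_eval_output/eval_log.json')` would return the 0.915… score shown in `test/mean_score`.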


@@ -46,6 +46,8 @@ dependencies:
   - diffusers=0.11.1
   - av=10.0.0
   - cmake=3.24.3
+  # trick to avoid cpu affinity issue described in https://github.com/pytorch/pytorch/issues/99625
+  - llvm-openmp=14
   # trick to force reinstall imagecodecs via pip
   - imagecodecs==2022.8.8
   - pip:


@@ -46,6 +46,8 @@ dependencies:
   - diffusers=0.11.1
   - av=10.0.0
   - cmake=3.24.3
+  # trick to avoid cpu affinity issue described in https://github.com/pytorch/pytorch/issues/99625
+  - llvm-openmp=14
   # trick to force reinstall imagecodecs via pip
   - imagecodecs==2022.8.8
   - pip:


@@ -158,6 +158,8 @@ class ReplayBuffer:
         # numpy backend
         meta = dict()
         for key, value in src_root['meta'].items():
+            if isinstance(value, zarr.Group):
+                continue
             if len(value.shape) == 0:
                 meta[key] = np.array(value)
             else:
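
The added guard matters because iterating a zarr `meta` group can yield subgroups, which have no `.shape`; without the `continue`, the scalar/array branch below it raises. A self-contained sketch of the same pattern, using a stand-in class instead of real zarr groups:

```python
import numpy as np

class FakeGroup:
    """Stand-in for zarr.Group: iterable via items(), but has no .shape."""
    def __init__(self, members):
        self._members = members
    def items(self):
        return self._members.items()

src_meta = FakeGroup({
    "episode_ends": np.array([10, 20]),  # array-like, has .shape
    "nested": FakeGroup({}),             # subgroup: would break shape checks
})

meta = {}
for key, value in src_meta.items():
    if isinstance(value, FakeGroup):  # the real code tests zarr.Group here
        continue
    meta[key] = np.array(value) if len(value.shape) == 0 else value[:]
```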


@@ -0,0 +1,109 @@
from typing import Dict
import torch
import numpy as np
import copy
from diffusion_policy.common.pytorch_util import dict_apply
from diffusion_policy.common.replay_buffer import ReplayBuffer
from diffusion_policy.common.sampler import (
    SequenceSampler, get_val_mask, downsample_mask)
from diffusion_policy.model.common.normalizer import LinearNormalizer
from diffusion_policy.dataset.base_dataset import BaseImageDataset
from diffusion_policy.common.normalize_util import get_image_range_normalizer

class MujocoImageDataset(BaseImageDataset):
    def __init__(self,
            zarr_path,
            horizon=1,
            pad_before=0,
            pad_after=0,
            seed=42,
            val_ratio=0.0,
            max_train_episodes=None
            ):
        super().__init__()
        self.replay_buffer = ReplayBuffer.copy_from_path(
            # zarr_path, keys=['img', 'state', 'action'])
            zarr_path, keys=['robot_0_camera_images', 'robot_0_tcp_xyz_wxyz', 'robot_0_gripper_width', 'action_0_tcp_xyz_wxyz', 'action_0_gripper_width'])
        val_mask = get_val_mask(
            n_episodes=self.replay_buffer.n_episodes,
            val_ratio=val_ratio,
            seed=seed)
        train_mask = ~val_mask
        train_mask = downsample_mask(
            mask=train_mask,
            max_n=max_train_episodes,
            seed=seed)
        self.sampler = SequenceSampler(
            replay_buffer=self.replay_buffer,
            sequence_length=horizon,
            pad_before=pad_before,
            pad_after=pad_after,
            episode_mask=train_mask)
        self.train_mask = train_mask
        self.horizon = horizon
        self.pad_before = pad_before
        self.pad_after = pad_after

    def get_validation_dataset(self):
        val_set = copy.copy(self)
        val_set.sampler = SequenceSampler(
            replay_buffer=self.replay_buffer,
            sequence_length=self.horizon,
            pad_before=self.pad_before,
            pad_after=self.pad_after,
            episode_mask=~self.train_mask
        )
        val_set.train_mask = ~self.train_mask
        return val_set

    def get_normalizer(self, mode='limits', **kwargs):
        data = {
            'action': np.concatenate([self.replay_buffer['action_0_tcp_xyz_wxyz'], self.replay_buffer['action_0_gripper_width']], axis=-1),
            # agent_pos matches _sample_to_data: tcp pose (7) + gripper width (1)
            'agent_pos': np.concatenate([self.replay_buffer['robot_0_tcp_xyz_wxyz'], self.replay_buffer['robot_0_gripper_width']], axis=-1)
        }
        normalizer = LinearNormalizer()
        normalizer.fit(data=data, last_n_dims=1, mode=mode, **kwargs)
        normalizer['image'] = get_image_range_normalizer()
        return normalizer

    def __len__(self) -> int:
        return len(self.sampler)

    def _sample_to_data(self, sample):
        # agent_pos = sample['state'][:,:2].astype(np.float32) # (agent_posx2, block_posex3)
        agent_pos = np.concatenate([sample['robot_0_tcp_xyz_wxyz'], sample['robot_0_gripper_width']], axis=-1).astype(np.float32)
        agent_action = np.concatenate([sample['action_0_tcp_xyz_wxyz'], sample['action_0_gripper_width']], axis=-1).astype(np.float32)
        # image = np.moveaxis(sample['img'],-1,1)/255
        image = np.moveaxis(sample['robot_0_camera_images'].astype(np.float32).squeeze(1),-1,1)/255
        data = {
            'obs': {
                'image': image, # T, 3, 224, 224
                'agent_pos': agent_pos, # T, 8 (x,y,z,qw,qx,qy,qz,gripper_width)
            },
            'action': agent_action # T, 8 (x,y,z,qw,qx,qy,qz,gripper_width)
        }
        return data

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        sample = self.sampler.sample_sequence(idx)
        data = self._sample_to_data(sample)
        torch_data = dict_apply(data, torch.from_numpy)
        return torch_data

def test():
    import os
    zarr_path = os.path.expanduser('/home/yihuai/robotics/repositories/mujoco/mujoco-env/data/collect_heuristic_data/2024-12-24_11-36-15_100episodes/merged_data.zarr')
    dataset = MujocoImageDataset(zarr_path, horizon=16)
    print(dataset[0])
    # from matplotlib import pyplot as plt
    # normalizer = dataset.get_normalizer()
    # nactions = normalizer['action'].normalize(dataset.replay_buffer['action'])
    # diff = np.diff(nactions, axis=0)
    # dists = np.linalg.norm(np.diff(nactions, axis=0), axis=-1)

if __name__ == '__main__':
    test()


@@ -8,8 +8,7 @@ import dill
 import math
 import wandb.sdk.data_types.video as wv
 from diffusion_policy.env.pusht.pusht_image_env import PushTImageEnv
-from diffusion_policy.gym_util.async_vector_env import AsyncVectorEnv
-# from diffusion_policy.gym_util.sync_vector_env import SyncVectorEnv
+from diffusion_policy.gym_util.sync_vector_env import SyncVectorEnv
 from diffusion_policy.gym_util.multistep_wrapper import MultiStepWrapper
 from diffusion_policy.gym_util.video_recording_wrapper import VideoRecordingWrapper, VideoRecorder

@@ -121,7 +120,9 @@ class PushTImageRunner(BaseImageRunner):
             env_prefixs.append('test/')
             env_init_fn_dills.append(dill.dumps(init_fn))

-        env = AsyncVectorEnv(env_fns)
+        # This environment can run without multiprocessing, which avoids
+        # shared-memory and semaphore restrictions on some machines.
+        env = SyncVectorEnv(env_fns)

         # test env
         # env.reset(seed=env_seeds)


@@ -8,8 +8,7 @@ import dill
 import math
 import wandb.sdk.data_types.video as wv
 from diffusion_policy.env.pusht.pusht_keypoints_env import PushTKeypointsEnv
-from diffusion_policy.gym_util.async_vector_env import AsyncVectorEnv
-# from diffusion_policy.gym_util.sync_vector_env import SyncVectorEnv
+from diffusion_policy.gym_util.sync_vector_env import SyncVectorEnv
 from diffusion_policy.gym_util.multistep_wrapper import MultiStepWrapper
 from diffusion_policy.gym_util.video_recording_wrapper import VideoRecordingWrapper, VideoRecorder

@@ -133,7 +132,7 @@ class PushTKeypointsRunner(BaseLowdimRunner):
             env_prefixs.append('test/')
             env_init_fn_dills.append(dill.dumps(init_fn))

-        env = AsyncVectorEnv(env_fns)
+        env = SyncVectorEnv(env_fns)

         # test env
         # env.reset(seed=env_seeds)


@@ -60,17 +60,44 @@ class SyncVectorEnv(VectorEnv):
         for env, seed in zip(self.envs, seeds):
             env.seed(seed)

-    def reset_wait(self):
+    def reset_async(self, seed=None, return_info=False, options=None):
+        if seed is None:
+            seeds = [None for _ in range(self.num_envs)]
+        elif isinstance(seed, int):
+            seeds = [seed + i for i in range(self.num_envs)]
+        else:
+            seeds = list(seed)
+        assert len(seeds) == self.num_envs
+        self._reset_seeds = seeds
+        self._reset_return_info = return_info
+        self._reset_options = options
+
+    def reset_wait(self, seed=None, return_info=False, options=None):
+        seeds = getattr(self, '_reset_seeds', None)
+        if seeds is None:
+            if seed is None:
+                seeds = [None for _ in range(self.num_envs)]
+            elif isinstance(seed, int):
+                seeds = [seed + i for i in range(self.num_envs)]
+            else:
+                seeds = list(seed)
         self._dones[:] = False
         observations = []
-        for env in self.envs:
+        infos = []
+        for env, seed_i in zip(self.envs, seeds):
+            if seed_i is not None:
+                env.seed(seed_i)
             observation = env.reset()
             observations.append(observation)
+            infos.append({})
         self.observations = concatenate(
-            observations, self.observations, self.single_observation_space
+            self.single_observation_space, observations, self.observations
         )
-        return deepcopy(self.observations) if self.copy else self.observations
+        obs = deepcopy(self.observations) if self.copy else self.observations
+        if return_info:
+            return obs, infos
+        return obs

     def step_async(self, actions):
         self._actions = actions

@@ -84,7 +111,7 @@
             observations.append(observation)
             infos.append(info)
         self.observations = concatenate(
-            observations, self.observations, self.single_observation_space
+            self.single_observation_space, observations, self.observations
         )
         return (
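
The seeding rule added above (None stays per-env None, an int fans out to consecutive seeds, a sequence is used as-is) can be isolated as a pure function; this mirror of the diff's logic is illustrative, not repo code:

```python
def broadcast_seeds(seed, num_envs):
    """Expand a reset seed the way the patched reset_async/reset_wait do."""
    if seed is None:
        seeds = [None for _ in range(num_envs)]
    elif isinstance(seed, int):
        seeds = [seed + i for i in range(num_envs)]
    else:
        seeds = list(seed)
    assert len(seeds) == num_envs
    return seeds
```

For example, `broadcast_seeds(42, 3)` yields `[42, 43, 44]`, so each sub-environment gets a distinct but reproducible seed.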


@@ -40,7 +40,7 @@ class RotationTransformer:
             getattr(pt, f'matrix_to_{from_rep}')
         ]
         if from_convention is not None:
-            funcs = [functools.partial(func, convernsion=from_convention)
+            funcs = [functools.partial(func, convention=from_convention)
                 for func in funcs]
         forward_funcs.append(funcs[0])
         inverse_funcs.append(funcs[1])
@@ -51,7 +51,7 @@ class RotationTransformer:
             getattr(pt, f'{to_rep}_to_matrix')
         ]
         if to_convention is not None:
-            funcs = [functools.partial(func, convernsion=to_convention)
+            funcs = [functools.partial(func, convention=to_convention)
                 for func in funcs]
         forward_funcs.append(funcs[0])
         inverse_funcs.append(funcs[1])
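
Why the misspelled keyword survived until runtime: `functools.partial` does not validate keyword names at construction, so `convernsion=` only fails once the wrapped conversion is actually called. A minimal illustration (the conversion function here is a stand-in for the pytorch3d helpers):

```python
import functools

def matrix_to_euler(mat, convention="XYZ"):
    # stand-in for a pytorch3d.transforms-style conversion function
    return (mat, convention)

good = functools.partial(matrix_to_euler, convention="ZYX")
bad = functools.partial(matrix_to_euler, convernsion="ZYX")  # typo from the diff

assert good("M") == ("M", "ZYX")
try:
    bad("M")  # the typo only surfaces here, as an unexpected keyword argument
    raise AssertionError("expected TypeError")
except TypeError:
    pass
```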


@@ -0,0 +1,302 @@
from typing import Optional, Tuple, Union
import logging

import torch
import torch.nn as nn

from diffusion_policy.model.common.module_attr_mixin import ModuleAttrMixin
from diffusion_policy.model.diffusion.positional_embedding import SinusoidalPosEmb

logger = logging.getLogger(__name__)

class PMFTransformerForDiffusion(ModuleAttrMixin):
    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        horizon: int,
        n_obs_steps: Optional[int] = None,
        cond_dim: int = 0,
        n_layer: int = 12,
        n_head: int = 12,
        n_emb: int = 768,
        p_drop_emb: float = 0.1,
        p_drop_attn: float = 0.1,
        causal_attn: bool = False,
        obs_as_cond: bool = False,
        n_cond_layers: int = 0,
        n_time_tokens: int = 4,
        n_head_layers: int = 4,
    ) -> None:
        super().__init__()
        if n_obs_steps is None:
            n_obs_steps = horizon
        if n_time_tokens < 1:
            raise ValueError("n_time_tokens must be >= 1")
        if n_head_layers < 0:
            raise ValueError("n_head_layers must be >= 0")
        if n_head_layers >= n_layer:
            raise ValueError(
                "n_head_layers must be smaller than n_layer so shared trunk depth stays positive"
            )
        obs_as_cond = cond_dim > 0
        T = horizon
        n_global_cond_tokens = 2 * n_time_tokens
        T_cond = n_global_cond_tokens + (n_obs_steps if obs_as_cond else 0)
        n_shared_layers = n_layer - n_head_layers

        self.input_emb = nn.Linear(input_dim, n_emb)
        self.pos_emb = nn.Parameter(torch.zeros(1, T, n_emb))
        self.drop = nn.Dropout(p_drop_emb)
        self.t_emb = SinusoidalPosEmb(n_emb)
        self.r_emb = SinusoidalPosEmb(n_emb)
        self.t_tokens = nn.Parameter(torch.zeros(1, n_time_tokens, n_emb))
        self.r_tokens = nn.Parameter(torch.zeros(1, n_time_tokens, n_emb))
        self.cond_obs_emb = nn.Linear(cond_dim, n_emb) if obs_as_cond else None
        self.cond_pos_emb = nn.Parameter(torch.zeros(1, T_cond, n_emb))

        if n_cond_layers > 0:
            encoder_layer = nn.TransformerEncoderLayer(
                d_model=n_emb,
                nhead=n_head,
                dim_feedforward=4 * n_emb,
                dropout=p_drop_attn,
                activation="gelu",
                batch_first=True,
                norm_first=True,
            )
            self.encoder = nn.TransformerEncoder(
                encoder_layer=encoder_layer,
                num_layers=n_cond_layers,
            )
        else:
            self.encoder = nn.Sequential(
                nn.Linear(n_emb, 4 * n_emb),
                nn.Mish(),
                nn.Linear(4 * n_emb, n_emb),
            )

        decoder_layer = nn.TransformerDecoderLayer(
            d_model=n_emb,
            nhead=n_head,
            dim_feedforward=4 * n_emb,
            dropout=p_drop_attn,
            activation="gelu",
            batch_first=True,
            norm_first=True,
        )
        self.shared_decoder = nn.TransformerDecoder(
            decoder_layer=decoder_layer,
            num_layers=n_shared_layers,
        )
        self.u_decoder = nn.TransformerDecoder(
            decoder_layer=decoder_layer,
            num_layers=n_head_layers,
        )
        self.v_decoder = nn.TransformerDecoder(
            decoder_layer=decoder_layer,
            num_layers=n_head_layers,
        )

        if causal_attn:
            sz = T
            mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
            mask = mask.float().masked_fill(mask == 0, float("-inf")).masked_fill(mask == 1, float(0.0))
            self.register_buffer("mask", mask)
            if obs_as_cond:
                q_idx, c_idx = torch.meshgrid(
                    torch.arange(T),
                    torch.arange(T_cond),
                    indexing="ij",
                )
                obs_offset = n_global_cond_tokens
                visible = c_idx < obs_offset
                visible = visible | (q_idx >= (c_idx - obs_offset))
                memory_mask = visible.float().masked_fill(~visible, float("-inf")).masked_fill(visible, float(0.0))
                self.register_buffer("memory_mask", memory_mask)
            else:
                self.memory_mask = None
        else:
            self.mask = None
            self.memory_mask = None

        self.ln_u = nn.LayerNorm(n_emb)
        self.ln_v = nn.LayerNorm(n_emb)
        self.head_u = nn.Linear(n_emb, output_dim)
        self.head_v = nn.Linear(n_emb, output_dim)

        self.T = T
        self.T_cond = T_cond
        self.horizon = horizon
        self.n_obs_steps = n_obs_steps
        self.obs_as_cond = obs_as_cond
        self.n_global_cond_tokens = n_global_cond_tokens
        self.n_time_tokens = n_time_tokens
        self.n_layer = n_layer
        self.n_head_layers = n_head_layers
        self.n_shared_layers = n_shared_layers

        self.apply(self._init_weights)
        logger.info(
            "number of parameters: %e", sum(p.numel() for p in self.parameters())
        )
        logger.info(
            "PMFTransformerForDiffusion layers: shared=%d u_head=%d v_head=%d",
            self.n_shared_layers,
            self.n_head_layers,
            self.n_head_layers,
        )

    def _init_weights(self, module):
        ignore_types = (
            nn.Dropout,
            SinusoidalPosEmb,
            nn.TransformerEncoderLayer,
            nn.TransformerDecoderLayer,
            nn.TransformerEncoder,
            nn.TransformerDecoder,
            nn.ModuleList,
            nn.Mish,
            nn.Sequential,
        )
        if isinstance(module, (nn.Linear, nn.Embedding)):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.MultiheadAttention):
            for name in ("in_proj_weight", "q_proj_weight", "k_proj_weight", "v_proj_weight"):
                weight = getattr(module, name)
                if weight is not None:
                    torch.nn.init.normal_(weight, mean=0.0, std=0.02)
            for name in ("in_proj_bias", "bias_k", "bias_v"):
                bias = getattr(module, name)
                if bias is not None:
                    torch.nn.init.zeros_(bias)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.zeros_(module.bias)
            torch.nn.init.ones_(module.weight)
        elif isinstance(module, PMFTransformerForDiffusion):
            torch.nn.init.normal_(module.pos_emb, mean=0.0, std=0.02)
            torch.nn.init.normal_(module.cond_pos_emb, mean=0.0, std=0.02)
            torch.nn.init.normal_(module.t_tokens, mean=0.0, std=0.02)
            torch.nn.init.normal_(module.r_tokens, mean=0.0, std=0.02)
        elif isinstance(module, ignore_types):
            pass
        else:
            raise RuntimeError("Unaccounted module {}".format(module))

    def get_optim_groups(self, weight_decay: float = 1e-3):
        decay = set()
        no_decay = set()
        whitelist_weight_modules = (torch.nn.Linear, torch.nn.MultiheadAttention)
        blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
        for mn, m in self.named_modules():
            for pn, _ in m.named_parameters():
                fpn = "%s.%s" % (mn, pn) if mn else pn
                if pn.endswith("bias"):
                    no_decay.add(fpn)
                elif pn.startswith("bias"):
                    no_decay.add(fpn)
                elif pn.endswith("weight") and isinstance(m, whitelist_weight_modules):
                    decay.add(fpn)
                elif pn.endswith("weight") and isinstance(m, blacklist_weight_modules):
                    no_decay.add(fpn)
        no_decay.update(
            {
                "pos_emb",
                "cond_pos_emb",
                "t_tokens",
                "r_tokens",
                "_dummy_variable",
            }
        )
        param_dict = {pn: p for pn, p in self.named_parameters()}
        inter_params = decay & no_decay
        union_params = decay | no_decay
        assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params),)
        assert len(param_dict.keys() - union_params) == 0, (
            "parameters %s were not separated into either decay/no_decay set!" % (str(param_dict.keys() - union_params),)
        )
        return [
            {
                "params": [param_dict[pn] for pn in sorted(list(decay))],
                "weight_decay": weight_decay,
            },
            {
                "params": [param_dict[pn] for pn in sorted(list(no_decay))],
                "weight_decay": 0.0,
            },
        ]

    def configure_optimizers(
        self,
        learning_rate: float = 1e-4,
        weight_decay: float = 1e-3,
        betas: Tuple[float, float] = (0.9, 0.95),
    ):
        optim_groups = self.get_optim_groups(weight_decay=weight_decay)
        return torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas)

    def _broadcast_time(self, value: Union[torch.Tensor, float, int], batch_size: int, device: torch.device):
        if not torch.is_tensor(value):
            value = torch.tensor([value], dtype=torch.float32, device=device)
        elif value.ndim == 0:
            value = value[None].to(device=device, dtype=torch.float32)
        else:
            value = value.to(device=device, dtype=torch.float32)
        return value.expand(batch_size)

    def forward(
        self,
        sample: torch.Tensor,
        t: Union[torch.Tensor, float, int],
        r: Union[torch.Tensor, float, int],
        cond: Optional[torch.Tensor] = None,
    ):
        batch_size = sample.shape[0]
        device = sample.device
        t = self._broadcast_time(t, batch_size, device)
        r = self._broadcast_time(r, batch_size, device)

        input_emb = self.input_emb(sample)
        t_cond = self.t_tokens + self.t_emb(t).unsqueeze(1)
        r_cond = self.r_tokens + self.r_emb(r).unsqueeze(1)
        cond_embeddings = [t_cond, r_cond]
        if self.obs_as_cond:
            cond_embeddings.append(self.cond_obs_emb(cond))
        cond_embeddings = torch.cat(cond_embeddings, dim=1)
        cond_pos = self.cond_pos_emb[:, : cond_embeddings.shape[1], :]
        memory = self.drop(cond_embeddings + cond_pos)
        memory = self.encoder(memory)

        token_pos = self.pos_emb[:, : input_emb.shape[1], :]
        x = self.drop(input_emb + token_pos)
        shared_x = self.shared_decoder(
            tgt=x,
            memory=memory,
            tgt_mask=self.mask,
            memory_mask=self.memory_mask,
        )
        u_x = self.u_decoder(
            tgt=shared_x,
            memory=memory,
            tgt_mask=self.mask,
            memory_mask=self.memory_mask,
        )
        v_x = self.v_decoder(
            tgt=shared_x,
            memory=memory,
            tgt_mask=self.mask,
            memory_mask=self.memory_mask,
        )
        return self.head_u(self.ln_u(u_x)), self.head_v(self.ln_v(v_x))
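
The causal target mask built in `__init__` can be checked in isolation; this torch-free equivalent (our own sketch, not repo code) produces the same additive mask, with 0.0 where query position i may attend key position j <= i and -inf elsewhere:

```python
def causal_mask(sz):
    """Additive attention mask matching the torch.triu construction above."""
    neg_inf = float("-inf")
    return [[0.0 if j <= i else neg_inf for j in range(sz)] for i in range(sz)]
```

For `sz=3` the first row is `[0.0, -inf, -inf]`: the first action token attends only to itself.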


@@ -256,8 +256,8 @@ class DiffusionTransformerHybridImagePolicy(BaseImagePolicy):
         # condition through impainting
         this_nobs = dict_apply(nobs, lambda x: x[:,:To,...].reshape(-1,*x.shape[2:]))
         nobs_features = self.obs_encoder(this_nobs)
-        # reshape back to B, T, Do
-        nobs_features = nobs_features.reshape(B, T, -1)
+        # reshape back to B, To, Do
+        nobs_features = nobs_features.reshape(B, To, -1)
         shape = (B, T, Da+Do)
         cond_data = torch.zeros(size=shape, device=device, dtype=dtype)
         cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)


@@ -247,7 +247,7 @@ class DiffusionUnetHybridImagePolicy(BaseImagePolicy):
         # condition through impainting
         this_nobs = dict_apply(nobs, lambda x: x[:,:To,...].reshape(-1,*x.shape[2:]))
         nobs_features = self.obs_encoder(this_nobs)
-        # reshape back to B, T, Do
+        # reshape back to B, To, Do
         nobs_features = nobs_features.reshape(B, To, -1)
         cond_data = torch.zeros(size=(B, T, Da+Do), device=device, dtype=dtype)
         cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)


@@ -0,0 +1,455 @@
from typing import Dict, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import reduce
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler

import diffusion_policy.model.vision.crop_randomizer as dmvc
import robomimic.models.base_nets as rmbn
import robomimic.utils.obs_utils as ObsUtils
from diffusion_policy.common.pytorch_util import dict_apply, replace_submodules
from diffusion_policy.common.robomimic_config_util import get_robomimic_config
from diffusion_policy.model.common.normalizer import LinearNormalizer
from diffusion_policy.model.diffusion.mask_generator import LowdimMaskGenerator
from diffusion_policy.model.diffusion.pmf_transformer_for_diffusion import (
    PMFTransformerForDiffusion,
)
from diffusion_policy.policy.base_image_policy import BaseImagePolicy
from robomimic.algo import algo_factory
from robomimic.algo.algo import PolicyAlgo

class PMFTransformerHybridImagePolicy(BaseImagePolicy):
    def __init__(
        self,
        shape_meta: dict,
        noise_scheduler: DDPMScheduler,
        horizon,
        n_action_steps,
        n_obs_steps,
        num_inference_steps=None,
        crop_shape=(76, 76),
        obs_encoder_group_norm=False,
        eval_fixed_crop=False,
        n_layer=8,
        n_cond_layers=0,
        n_head=4,
        n_emb=256,
        p_drop_emb=0.0,
        p_drop_attn=0.0,
        causal_attn=True,
        obs_as_cond=True,
        pred_action_steps_only=False,
        n_time_tokens=4,
        n_head_layers=4,
        min_time=0.05,
        du_dt_epsilon=1.0e-3,
        pmf_u_loss_weight=1.0,
        pmf_v_loss_weight=1.0,
        noise_scale=1.0,
        adatloss_eps=0.01,
        p_mean=-0.4,
        p_std=1.0,
        tr_uniform=True,
        tr_uniform_prob=0.1,
        data_proportion=0.5,
        **kwargs,
    ):
        super().__init__()

        action_shape = shape_meta["action"]["shape"]
        assert len(action_shape) == 1
        action_dim = action_shape[0]

        obs_shape_meta = shape_meta["obs"]
        obs_config = {
            "low_dim": [],
            "rgb": [],
            "depth": [],
            "scan": [],
        }
        obs_key_shapes = dict()
        for key, attr in obs_shape_meta.items():
            shape = attr["shape"]
            obs_key_shapes[key] = list(shape)
            obs_type = attr.get("type", "low_dim")
            if obs_type == "rgb":
                obs_config["rgb"].append(key)
            elif obs_type == "low_dim":
                obs_config["low_dim"].append(key)
            else:
                raise RuntimeError(f"Unsupported obs type: {obs_type}")

        config = get_robomimic_config(
            algo_name="bc_rnn",
            hdf5_type="image",
            task_name="square",
            dataset_type="ph",
        )
        with config.unlocked():
            config.observation.modalities.obs = obs_config
            if crop_shape is None:
                for _, modality in config.observation.encoder.items():
                    if modality.obs_randomizer_class == "CropRandomizer":
                        modality["obs_randomizer_class"] = None
            else:
                crop_h, crop_w = crop_shape
                for _, modality in config.observation.encoder.items():
                    if modality.obs_randomizer_class == "CropRandomizer":
                        modality.obs_randomizer_kwargs.crop_height = crop_h
                        modality.obs_randomizer_kwargs.crop_width = crop_w

        ObsUtils.initialize_obs_utils_with_config(config)

        policy: PolicyAlgo = algo_factory(
            algo_name=config.algo_name,
config=config,
obs_key_shapes=obs_key_shapes,
ac_dim=action_dim,
device="cpu",
)
obs_encoder = policy.nets["policy"].nets["encoder"].nets["obs"]
if obs_encoder_group_norm:
replace_submodules(
root_module=obs_encoder,
predicate=lambda x: isinstance(x, nn.BatchNorm2d),
func=lambda x: nn.GroupNorm(
num_groups=x.num_features // 16,
num_channels=x.num_features,
),
)
if eval_fixed_crop:
replace_submodules(
root_module=obs_encoder,
predicate=lambda x: isinstance(x, rmbn.CropRandomizer),
func=lambda x: dmvc.CropRandomizer(
input_shape=x.input_shape,
crop_height=x.crop_height,
crop_width=x.crop_width,
num_crops=x.num_crops,
pos_enc=x.pos_enc,
),
)
obs_feature_dim = obs_encoder.output_shape()[0]
input_dim = action_dim if obs_as_cond else (obs_feature_dim + action_dim)
cond_dim = obs_feature_dim if obs_as_cond else 0
self.obs_encoder = obs_encoder
self.model = PMFTransformerForDiffusion(
input_dim=input_dim,
output_dim=input_dim,
horizon=horizon if not pred_action_steps_only else n_action_steps,
n_obs_steps=n_obs_steps,
cond_dim=cond_dim,
n_layer=n_layer,
n_head=n_head,
n_emb=n_emb,
p_drop_emb=p_drop_emb,
p_drop_attn=p_drop_attn,
causal_attn=causal_attn,
obs_as_cond=obs_as_cond,
n_cond_layers=n_cond_layers,
n_time_tokens=n_time_tokens,
n_head_layers=n_head_layers,
)
self.noise_scheduler = noise_scheduler
self.mask_generator = LowdimMaskGenerator(
action_dim=action_dim,
obs_dim=0 if obs_as_cond else obs_feature_dim,
max_n_obs_steps=n_obs_steps,
fix_obs_steps=True,
action_visible=False,
)
self.normalizer = LinearNormalizer()
self.horizon = horizon
self.obs_feature_dim = obs_feature_dim
self.action_dim = action_dim
self.n_action_steps = n_action_steps
self.n_obs_steps = n_obs_steps
self.obs_as_cond = obs_as_cond
self.pred_action_steps_only = pred_action_steps_only
self.min_time = min_time
self.du_dt_epsilon = du_dt_epsilon
self.pmf_u_loss_weight = pmf_u_loss_weight
self.pmf_v_loss_weight = pmf_v_loss_weight
self.noise_scale = noise_scale
self.adatloss_eps = adatloss_eps
self.p_mean = p_mean
self.p_std = p_std
self.tr_uniform = tr_uniform
self.tr_uniform_prob = tr_uniform_prob
self.data_proportion = data_proportion
self.kwargs = kwargs
if num_inference_steps is None:
num_inference_steps = noise_scheduler.config.num_train_timesteps
self.num_inference_steps = num_inference_steps
def _encode_obs(self, nobs: Dict[str, torch.Tensor], n_steps: int) -> torch.Tensor:
flat_nobs = dict_apply(nobs, lambda x: x[:, :n_steps, ...].reshape(-1, *x.shape[2:]))
nobs_features = self.obs_encoder(flat_nobs)
return nobs_features.reshape(next(iter(nobs.values())).shape[0], n_steps, -1)
def _time_view(self, value: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
return value.reshape(value.shape[0], *([1] * (ref.ndim - 1)))
def _adatloss(self, loss: torch.Tensor) -> torch.Tensor:
denom = loss.detach() + self.adatloss_eps
return loss / denom
def _sample_logit_normal(self, batch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
normal = torch.randn(batch_size, device=device, dtype=dtype)
return torch.sigmoid(normal * self.p_std + self.p_mean)
def _sample_tr(self, batch_size: int, device: torch.device, dtype: torch.dtype):
t = self._sample_logit_normal(batch_size, device, dtype)
r = self._sample_logit_normal(batch_size, device, dtype)
if self.tr_uniform:
uniform_mask = torch.rand(batch_size, device=device) < self.tr_uniform_prob
uniform_t = torch.rand(batch_size, device=device, dtype=dtype)
uniform_r = torch.rand(batch_size, device=device, dtype=dtype)
t = torch.where(uniform_mask, uniform_t, t)
r = torch.where(uniform_mask, uniform_r, r)
data_size = int(batch_size * self.data_proportion)
fm_mask = torch.arange(batch_size, device=device) < data_size
r = torch.where(fm_mask, t, r)
t_final = torch.maximum(t, r)
r_final = torch.minimum(t, r)
return t_final, r_final
def _trajectory_inputs(
self,
nobs: Dict[str, torch.Tensor],
nactions: torch.Tensor,
):
batch_size = nactions.shape[0]
horizon = nactions.shape[1]
cond = None
trajectory = nactions
if self.obs_as_cond:
cond = self._encode_obs(nobs, self.n_obs_steps)
if self.pred_action_steps_only:
start = self.n_obs_steps - 1
end = start + self.n_action_steps
trajectory = nactions[:, start:end]
else:
nobs_features = self._encode_obs(nobs, horizon)
trajectory = torch.cat([nactions, nobs_features], dim=-1).detach()
if self.pred_action_steps_only:
condition_mask = torch.zeros_like(trajectory, dtype=torch.bool)
else:
condition_mask = self.mask_generator(trajectory.shape)
return batch_size, trajectory, cond, condition_mask
def _apply_conditioning(
self,
sample: torch.Tensor,
condition_data: torch.Tensor,
condition_mask: torch.Tensor,
) -> torch.Tensor:
if not condition_mask.any():
return sample
return torch.where(condition_mask, condition_data, sample)
def _compute_u_v(
self,
sample: torch.Tensor,
t: torch.Tensor,
r: torch.Tensor,
cond: torch.Tensor,
):
x_hat_u, x_hat_v = self.model(sample, t, r, cond)
denom = self._time_view(t, sample)
u = (sample - x_hat_u) / denom
v = (sample - x_hat_v) / denom
return u, v
def _compute_du_dt(
self,
sample: torch.Tensor,
t: torch.Tensor,
r: torch.Tensor,
cond: torch.Tensor,
condition_data: torch.Tensor,
condition_mask: torch.Tensor,
tangent_v: torch.Tensor,
) -> torch.Tensor:
tangent_sample = tangent_v.detach()
tangent_r = torch.zeros_like(r)
tangent_t = torch.ones_like(t)
def u_fn(sample_input, r_input, t_input):
conditioned_sample = self._apply_conditioning(
sample_input, condition_data, condition_mask
)
u_value, _ = self._compute_u_v(conditioned_sample, t_input, r_input, cond)
return u_value
primals = (sample, r, t)
tangents = (tangent_sample, tangent_r, tangent_t)
try:
_, du_dt = torch.func.jvp(u_fn, primals, tangents)
except (AttributeError, NotImplementedError, RuntimeError):
_, du_dt = torch.autograd.functional.jvp(
u_fn,
primals,
tangents,
create_graph=False,
strict=False,
)
return du_dt
# ========= inference ============
def conditional_sample(
self,
condition_data,
condition_mask,
cond=None,
generator=None,
**kwargs,
):
del kwargs
trajectory = torch.randn(
size=condition_data.shape,
dtype=condition_data.dtype,
device=condition_data.device,
generator=generator,
) * self.noise_scale
time_steps = torch.linspace(
1.0,
0.0,
self.num_inference_steps + 1,
dtype=trajectory.dtype,
device=trajectory.device,
)
for step_idx in range(self.num_inference_steps):
trajectory = self._apply_conditioning(trajectory, condition_data, condition_mask)
t = time_steps[step_idx].expand(trajectory.shape[0])
r = time_steps[step_idx + 1].expand(trajectory.shape[0])
u, _ = self._compute_u_v(trajectory, t, r, cond)
delta = self._time_view(t - r, trajectory)
trajectory = trajectory - delta * u
trajectory = self._apply_conditioning(trajectory, condition_data, condition_mask)
return trajectory
def predict_action(self, obs_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
assert "past_action" not in obs_dict
nobs = self.normalizer.normalize(obs_dict)
value = next(iter(nobs.values()))
batch_size, to_steps = value.shape[:2]
horizon = self.horizon
action_dim = self.action_dim
device = self.device
dtype = self.dtype
cond = None
if self.obs_as_cond:
cond = self._encode_obs(nobs, self.n_obs_steps)
shape = (batch_size, horizon, action_dim)
if self.pred_action_steps_only:
shape = (batch_size, self.n_action_steps, action_dim)
cond_data = torch.zeros(size=shape, device=device, dtype=dtype)
cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)
else:
nobs_features = self._encode_obs(nobs, self.n_obs_steps)
shape = (batch_size, horizon, action_dim + self.obs_feature_dim)
cond_data = torch.zeros(size=shape, device=device, dtype=dtype)
cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)
cond_data[:, : self.n_obs_steps, action_dim:] = nobs_features
cond_mask[:, : self.n_obs_steps, action_dim:] = True
nsample = self.conditional_sample(
cond_data,
cond_mask,
cond=cond,
**self.kwargs,
)
naction_pred = nsample[..., :action_dim]
action_pred = self.normalizer["action"].unnormalize(naction_pred)
if self.pred_action_steps_only:
action = action_pred
else:
start = to_steps - 1
end = start + self.n_action_steps
action = action_pred[:, start:end]
return {
"action": action,
"action_pred": action_pred,
}
# ========= training ============
def set_normalizer(self, normalizer: LinearNormalizer):
self.normalizer.load_state_dict(normalizer.state_dict())
def get_optimizer(
self,
transformer_weight_decay: float,
obs_encoder_weight_decay: float,
learning_rate: float,
betas: Tuple[float, float],
) -> torch.optim.Optimizer:
optim_groups = self.model.get_optim_groups(weight_decay=transformer_weight_decay)
optim_groups.append(
{
"params": self.obs_encoder.parameters(),
"weight_decay": obs_encoder_weight_decay,
}
)
return torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas)
def compute_loss(self, batch):
assert "valid_mask" not in batch
nobs = self.normalizer.normalize(batch["obs"])
nactions = self.normalizer["action"].normalize(batch["action"])
_, trajectory, cond, condition_mask = self._trajectory_inputs(nobs, nactions)
noise = torch.randn_like(trajectory) * self.noise_scale
batch_size = trajectory.shape[0]
t, r = self._sample_tr(
batch_size, device=trajectory.device, dtype=trajectory.dtype
)
z_t = (1 - self._time_view(t, trajectory)) * trajectory + self._time_view(t, trajectory) * noise
z_t = self._apply_conditioning(z_t, trajectory, condition_mask)
loss_mask = ~condition_mask
target_v = noise - trajectory
u, v = self._compute_u_v(z_t, t, r, cond)
du_dt = self._compute_du_dt(
sample=z_t,
t=t,
r=r,
cond=cond,
condition_data=trajectory,
condition_mask=condition_mask,
tangent_v=v,
)
pmf_velocity = u + self._time_view(t - r, trajectory) * du_dt.detach()
loss_u = F.mse_loss(pmf_velocity, target_v, reduction="none")
loss_v = F.mse_loss(v, target_v, reduction="none")
loss_u = loss_u * loss_mask.type(loss_u.dtype)
loss_v = loss_v * loss_mask.type(loss_v.dtype)
loss_u = reduce(loss_u, "b ... -> b (...)", "mean").mean()
loss_v = reduce(loss_v, "b ... -> b (...)", "mean").mean()
loss_u = self._adatloss(loss_u)
loss_v = self._adatloss(loss_v)
return self.pmf_u_loss_weight * loss_u + self.pmf_v_loss_weight * loss_v
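The sampling loop in `conditional_sample` above is a plain Euler integration of the predicted velocity `u = (z - x_hat) / t` from `t = 1` (pure noise) down to `t = 0`. A minimal scalar sketch of that loop, using a hypothetical oracle denoiser that always returns the true clean value:

```python
def euler_pmf_sample(x_hat_fn, noise, n_steps=10):
    """Euler-integrate z from t=1 (noise) to t=0 (data), as in conditional_sample."""
    z = noise
    for i in range(n_steps):
        t = 1.0 - i / n_steps          # current time
        r = 1.0 - (i + 1) / n_steps    # next (smaller) time
        u = (z - x_hat_fn(z, t)) / t   # velocity implied by the predicted clean sample
        z = z - (t - r) * u            # step from t down to r
    return z

# Hypothetical oracle denoiser: always returns the true clean value 0.3.
recovered = euler_pmf_sample(lambda z, t: 0.3, noise=2.0)
```

With an exact denoiser the loop recovers the data at `t = 0` regardless of step count; in the policy, the prediction comes from the transformer's u-head and the trajectory is re-conditioned on the observed steps between updates.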

eval.py Normal file

@@ -0,0 +1,64 @@
"""
Usage:
python eval.py --checkpoint data/image/pusht/diffusion_policy_cnn/train_0/checkpoints/latest.ckpt -o data/pusht_eval_output
"""
import sys
# use line-buffering for both stdout and stderr
sys.stdout = open(sys.stdout.fileno(), mode='w', buffering=1)
sys.stderr = open(sys.stderr.fileno(), mode='w', buffering=1)
import os
import pathlib
import click
import hydra
import torch
import dill
import wandb
import json
from diffusion_policy.workspace.base_workspace import BaseWorkspace
@click.command()
@click.option('-c', '--checkpoint', required=True)
@click.option('-o', '--output_dir', required=True)
@click.option('-d', '--device', default='cuda:0')
def main(checkpoint, output_dir, device):
if os.path.exists(output_dir):
click.confirm(f"Output path {output_dir} already exists! Overwrite?", abort=True)
pathlib.Path(output_dir).mkdir(parents=True, exist_ok=True)
# load checkpoint
payload = torch.load(open(checkpoint, 'rb'), pickle_module=dill)
cfg = payload['cfg']
cls = hydra.utils.get_class(cfg._target_)
workspace = cls(cfg, output_dir=output_dir)
workspace: BaseWorkspace
workspace.load_payload(payload, exclude_keys=None, include_keys=None)
# get policy from workspace
policy = workspace.model
if cfg.training.use_ema:
policy = workspace.ema_model
device = torch.device(device)
policy.to(device)
policy.eval()
# run eval
env_runner = hydra.utils.instantiate(
cfg.task.env_runner,
output_dir=output_dir)
runner_log = env_runner.run(policy)
# dump log to json
json_log = dict()
for key, value in runner_log.items():
if isinstance(value, wandb.sdk.data_types.video.Video):
json_log[key] = value._path
else:
json_log[key] = value
out_path = os.path.join(output_dir, 'eval_log.json')
json.dump(json_log, open(out_path, 'w'), indent=2, sort_keys=True)
if __name__ == '__main__':
main()
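The `sys.stdout = open(sys.stdout.fileno(), mode='w', buffering=1)` lines at the top of `eval.py` force line buffering so progress output appears promptly even when piped. The same behavior can be sketched with an in-memory stream standing in for the real file descriptor:

```python
import io

raw = io.BytesIO()  # stand-in for the underlying stdout file descriptor
stream = io.TextIOWrapper(raw, line_buffering=True)  # equivalent of buffering=1

stream.write("rollout 1/50")   # no newline yet: may still sit in the buffer
stream.write(" done\n")        # the newline triggers a flush in line-buffered mode
```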


@@ -0,0 +1,182 @@
_target_: diffusion_policy.workspace.train_diffusion_unet_hybrid_workspace.TrainDiffusionUnetHybridWorkspace
checkpoint:
save_last_ckpt: true
save_last_snapshot: false
topk:
format_str: epoch={epoch:04d}-test_mean_score={test_mean_score:.3f}.ckpt
k: 5
mode: max
monitor_key: test_mean_score
dataloader:
batch_size: 64
num_workers: 8
persistent_workers: false
pin_memory: true
shuffle: true
dataset_obs_steps: 2
ema:
_target_: diffusion_policy.model.diffusion.ema_model.EMAModel
inv_gamma: 1.0
max_value: 0.9999
min_value: 0.0
power: 0.75
update_after_step: 0
exp_name: default
horizon: 16
keypoint_visible_rate: 1.0
logging:
group: null
id: null
mode: online
name: 2023.01.16-20.20.06_train_diffusion_unet_hybrid_pusht_image
project: diffusion_policy_debug
resume: true
tags:
- train_diffusion_unet_hybrid
- pusht_image
- default
multi_run:
run_dir: data/outputs/2023.01.16/20.20.06_train_diffusion_unet_hybrid_pusht_image
wandb_name_base: 2023.01.16-20.20.06_train_diffusion_unet_hybrid_pusht_image
n_action_steps: 8
n_latency_steps: 0
n_obs_steps: 2
name: train_diffusion_unet_hybrid
obs_as_global_cond: true
optimizer:
_target_: torch.optim.AdamW
betas:
- 0.95
- 0.999
eps: 1.0e-08
lr: 0.0001
weight_decay: 1.0e-06
past_action_visible: false
policy:
_target_: diffusion_policy.policy.diffusion_unet_hybrid_image_policy.DiffusionUnetHybridImagePolicy
cond_predict_scale: true
crop_shape:
- 84
- 84
diffusion_step_embed_dim: 128
down_dims:
- 512
- 1024
- 2048
eval_fixed_crop: true
horizon: 16
kernel_size: 5
n_action_steps: 8
n_groups: 8
n_obs_steps: 2
noise_scheduler:
_target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler
beta_end: 0.02
beta_schedule: squaredcos_cap_v2
beta_start: 0.0001
clip_sample: true
num_train_timesteps: 100
prediction_type: epsilon
variance_type: fixed_small
num_inference_steps: 100
obs_as_global_cond: true
obs_encoder_group_norm: true
shape_meta:
action:
shape:
- 2
obs:
agent_pos:
shape:
- 2
type: low_dim
image:
shape:
- 3
- 96
- 96
type: rgb
shape_meta:
action:
shape:
- 2
obs:
agent_pos:
shape:
- 2
type: low_dim
image:
shape:
- 3
- 96
- 96
type: rgb
task:
dataset:
_target_: diffusion_policy.dataset.pusht_image_dataset.PushTImageDataset
horizon: 16
max_train_episodes: 90
pad_after: 7
pad_before: 1
seed: 42
val_ratio: 0.02
zarr_path: data/pusht/pusht_cchi_v7_replay.zarr
env_runner:
_target_: diffusion_policy.env_runner.pusht_image_runner.PushTImageRunner
fps: 10
legacy_test: true
max_steps: 300
n_action_steps: 8
n_envs: null
n_obs_steps: 2
n_test: 50
n_test_vis: 4
n_train: 6
n_train_vis: 2
past_action: false
test_start_seed: 100000
train_start_seed: 0
image_shape:
- 3
- 96
- 96
name: pusht_image
shape_meta:
action:
shape:
- 2
obs:
agent_pos:
shape:
- 2
type: low_dim
image:
shape:
- 3
- 96
- 96
type: rgb
task_name: pusht_image
training:
checkpoint_every: 50
debug: false
device: cuda:0
gradient_accumulate_every: 1
lr_scheduler: cosine
lr_warmup_steps: 500
max_train_steps: null
max_val_steps: null
num_epochs: 3050
resume: true
rollout_every: 50
sample_every: 5
seed: 42
tqdm_interval_sec: 1.0
use_ema: true
val_every: 1
val_dataloader:
batch_size: 64
num_workers: 8
persistent_workers: false
pin_memory: true
shuffle: false
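Configs like the one above are consumed through Hydra, so any nested key can be overridden from the command line (e.g. `training.num_epochs=100`). A simplified stdlib-only sketch of how a dotted override resolves against the nested config; the function name is illustrative, not Hydra's API:

```python
def apply_override(cfg: dict, dotted_key: str, value):
    # Mimic a Hydra-style "a.b.c=value" command-line override (simplified sketch).
    *parents, leaf = dotted_key.split(".")
    node = cfg
    for key in parents:
        node = node[key]   # walk down to the parent of the leaf key
    node[leaf] = value

cfg = {"training": {"num_epochs": 3050, "device": "cuda:0"},
       "dataloader": {"batch_size": 64}}
apply_override(cfg, "training.num_epochs", 100)
apply_override(cfg, "dataloader.batch_size", 32)
```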


@@ -0,0 +1,190 @@
_target_: diffusion_policy.workspace.train_diffusion_transformer_hybrid_workspace.TrainDiffusionTransformerHybridWorkspace
checkpoint:
save_last_ckpt: true
save_last_snapshot: false
topk:
format_str: epoch={epoch:04d}-train_loss={train_loss:.3f}.ckpt
k: 5
mode: min
monitor_key: train_loss
dataloader:
batch_size: 64
num_workers: 8
persistent_workers: false
pin_memory: true
shuffle: true
dataset_obs_steps: 2
ema:
_target_: diffusion_policy.model.diffusion.ema_model.EMAModel
inv_gamma: 1.0
max_value: 0.9999
min_value: 0.0
power: 0.75
update_after_step: 0
exp_name: default
horizon: 16
keypoint_visible_rate: 1.0
logging:
group: null
id: null
mode: online
name: ${now:%Y.%m.%d-%H.%M.%S}_train_diffusion_transformer_hybrid_pmf_pusht_image
project: diffusion_policy_debug
resume: true
tags:
- train_diffusion_transformer_hybrid_pmf
- pusht_image
- default
multi_run:
run_dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_train_diffusion_transformer_hybrid_pmf_pusht_image
wandb_name_base: ${now:%Y.%m.%d-%H.%M.%S}_train_diffusion_transformer_hybrid_pmf_pusht_image
n_action_steps: 8
n_latency_steps: 0
n_obs_steps: 2
name: train_diffusion_transformer_hybrid_pmf
obs_as_cond: true
optimizer:
betas:
- 0.9
- 0.95
learning_rate: 0.0001
obs_encoder_weight_decay: 1.0e-06
transformer_weight_decay: 0.001
past_action_visible: false
policy:
_target_: diffusion_policy.policy.pmf_transformer_hybrid_image_policy.PMFTransformerHybridImagePolicy
crop_shape:
- 84
- 84
eval_fixed_crop: true
horizon: 16
n_action_steps: 8
n_cond_layers: 0
n_emb: 256
n_head: 4
n_layer: 12
n_head_layers: 4
n_obs_steps: 2
n_time_tokens: 4
noise_scale: 1.0
adatloss_eps: 0.01
p_mean: -0.4
p_std: 1.0
noise_scheduler:
_target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler
beta_end: 0.02
beta_schedule: squaredcos_cap_v2
beta_start: 0.0001
clip_sample: true
num_train_timesteps: 100
prediction_type: sample
variance_type: fixed_small
num_inference_steps: 1
obs_as_cond: true
obs_encoder_group_norm: true
p_drop_attn: 0.0
p_drop_emb: 0.0
pmf_u_loss_weight: 1.0
pmf_v_loss_weight: 1.0
tr_uniform: true
tr_uniform_prob: 0.1
data_proportion: 0.5
shape_meta:
action:
shape:
- 2
obs:
agent_pos:
shape:
- 2
type: low_dim
image:
shape:
- 3
- 96
- 96
type: rgb
shape_meta:
action:
shape:
- 2
obs:
agent_pos:
shape:
- 2
type: low_dim
image:
shape:
- 3
- 96
- 96
type: rgb
task:
dataset:
_target_: diffusion_policy.dataset.pusht_image_dataset.PushTImageDataset
horizon: 16
max_train_episodes: 90
pad_after: 7
pad_before: 1
seed: 42
val_ratio: 0.02
zarr_path: data/pusht/pusht_cchi_v7_replay.zarr
env_runner:
_target_: diffusion_policy.env_runner.pusht_image_runner.PushTImageRunner
fps: 10
legacy_test: true
max_steps: 300
n_action_steps: 8
n_envs: null
n_obs_steps: 2
n_test: 50
n_test_vis: 4
n_train: 6
n_train_vis: 2
past_action: false
test_start_seed: 100000
train_start_seed: 0
image_shape:
- 3
- 96
- 96
name: pusht_image
shape_meta:
action:
shape:
- 2
obs:
agent_pos:
shape:
- 2
type: low_dim
image:
shape:
- 3
- 96
- 96
type: rgb
task_name: pusht_image
training:
checkpoint_every: 50
debug: false
device: cuda:0
gradient_accumulate_every: 1
lr_scheduler: cosine
lr_warmup_steps: 500
max_train_steps: null
max_val_steps: null
num_epochs: 600
resume: true
rollout_every: 50
sample_every: 5
seed: 42
tqdm_interval_sec: 1.0
use_ema: true
val_every: 1
val_dataloader:
batch_size: 64
num_workers: 8
persistent_workers: false
pin_memory: true
shuffle: false
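Both configs share the same `ema` block (`inv_gamma: 1.0`, `power: 0.75`, `max_value: 0.9999`). In diffusers-style EMA helpers these parameterize a warmup schedule for the averaging decay; assuming the usual formula `1 - (1 + step/inv_gamma)^(-power)` clamped to `[min_value, max_value]`:

```python
def ema_decay(step, inv_gamma=1.0, power=0.75, min_value=0.0, max_value=0.9999):
    # Assumed diffusers-style EMA warmup: decay grows from 0 toward max_value.
    value = 1.0 - (1.0 + step / inv_gamma) ** (-power)
    return max(min_value, min(value, max_value))

early = ema_decay(1)      # weak averaging early in training (~0.405)
late = ema_decay(10**8)   # clamped at max_value once warmup saturates
```

The practical effect is that the EMA weights track the raw weights closely at the start of training and average over a much longer horizon later on.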


@@ -0,0 +1,39 @@
# Direct package pins for the canonical PushT image workflow on host 5090.
# Torch/TorchVision/Torchaudio are installed separately from the cu128 index in setup_uv_pusht_5090.sh.
numpy==1.26.4
scipy==1.11.4
numba==0.59.1
llvmlite==0.42.0
cffi==1.15.1
cython==0.29.32
h5py==3.8.0
pandas==2.2.3
zarr==2.12.0
numcodecs==0.10.2
hydra-core==1.2.0
einops==0.4.1
tqdm==4.64.1
dill==0.3.5.1
scikit-video==1.1.11
scikit-image==0.19.3
gym==0.23.1
pymunk==6.2.1
wandb==0.13.3
threadpoolctl==3.1.0
shapely==1.8.5.post1
matplotlib==3.6.1
imageio==2.22.0
imageio-ffmpeg==0.4.7
termcolor==2.0.1
tensorboard==2.10.1
tensorboardx==2.5.1
psutil==7.2.2
click==8.1.8
boto3==1.24.96
diffusers==0.11.1
huggingface-hub==0.10.1
av==14.0.1
pygame==2.5.2
robomimic==0.2.0
opencv-python-headless==4.10.0.84

setup_uv_pusht_5090.sh Executable file

@@ -0,0 +1,20 @@
#!/usr/bin/env bash
set -euo pipefail
ROOT_DIR="$(cd "$(dirname "$0")" && pwd)"
cd "$ROOT_DIR"
export UV_CACHE_DIR="${UV_CACHE_DIR:-$ROOT_DIR/.uv-cache}"
export UV_PYTHON_INSTALL_DIR="${UV_PYTHON_INSTALL_DIR:-$ROOT_DIR/.uv-python}"
uv venv --python 3.9 .venv
source .venv/bin/activate
uv pip install --upgrade pip wheel setuptools==80.9.0
uv pip install --python .venv/bin/python \
--index-url https://download.pytorch.org/whl/cu128 \
torch==2.8.0+cu128 torchvision==0.23.0+cu128 torchaudio==2.8.0+cu128
uv pip install --python .venv/bin/python -r requirements-pusht-5090.txt
uv pip install --python .venv/bin/python -e .
echo "uv environment ready at $ROOT_DIR/.venv"
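The `${VAR:-default}` expansions at the top of the script let callers redirect the uv caches: an exported value wins, otherwise the path falls back to a directory under the repo root. A small sketch of the idiom with an illustrative path (not the real repo location):

```shell
# ${VAR:-default}: use the caller's value when set, else fall back.
ROOT_DIR="/tmp/example-repo"                  # illustrative repo root
unset UV_CACHE_DIR
export UV_CACHE_DIR="${UV_CACHE_DIR:-$ROOT_DIR/.uv-cache}"

export UV_PYTHON_INSTALL_DIR="/opt/uv-python" # pre-set value survives the expansion
export UV_PYTHON_INSTALL_DIR="${UV_PYTHON_INSTALL_DIR:-$ROOT_DIR/.uv-python}"
echo "$UV_CACHE_DIR"
```

After the script finishes, activate the environment with `source .venv/bin/activate` before launching training.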