From 5ba07ac6661db573af695b419a7947ecb704690f Mon Sep 17 00:00:00 2001
From: Yihuai Gao
Date: Tue, 24 Dec 2024 11:47:59 -0800
Subject: [PATCH] Done adapting mujoco image dataset

---
 diffusion_policy/common/replay_buffer.py      |   2 +
 .../dataset/mujoco_image_dataset.py           | 143 ++++++++++++++
 image_pusht_diffusion_policy_cnn.yaml         | 182 ++++++++++++++++++
 3 files changed, 327 insertions(+)
 create mode 100644 diffusion_policy/dataset/mujoco_image_dataset.py
 create mode 100644 image_pusht_diffusion_policy_cnn.yaml

diff --git a/diffusion_policy/common/replay_buffer.py b/diffusion_policy/common/replay_buffer.py
index 022a704..57b5aca 100644
--- a/diffusion_policy/common/replay_buffer.py
+++ b/diffusion_policy/common/replay_buffer.py
@@ -158,6 +158,8 @@ class ReplayBuffer:
             # numpy backend
             meta = dict()
             for key, value in src_root['meta'].items():
+                if isinstance(value, zarr.Group):
+                    continue
                 if len(value.shape) == 0:
                     meta[key] = np.array(value)
                 else:
diff --git a/diffusion_policy/dataset/mujoco_image_dataset.py b/diffusion_policy/dataset/mujoco_image_dataset.py
new file mode 100644
index 0000000..f51aa9d
--- /dev/null
+++ b/diffusion_policy/dataset/mujoco_image_dataset.py
@@ -0,0 +1,143 @@
+from typing import Dict
+import torch
+import numpy as np
+import copy
+from diffusion_policy.common.pytorch_util import dict_apply
+from diffusion_policy.common.replay_buffer import ReplayBuffer
+from diffusion_policy.common.sampler import (
+    SequenceSampler, get_val_mask, downsample_mask)
+from diffusion_policy.model.common.normalizer import LinearNormalizer
+from diffusion_policy.dataset.base_dataset import BaseImageDataset
+from diffusion_policy.common.normalize_util import get_image_range_normalizer
+
+class MujocoImageDataset(BaseImageDataset):
+    """Sequence dataset over a zarr replay buffer of MuJoCo episodes.
+
+    Each item holds an image and a low-dimensional ``agent_pos`` observation
+    plus the matching ``action``. ``agent_pos`` and ``action`` are each a
+    ``*_tcp_xyz_wxyz`` array concatenated with the corresponding
+    ``*_gripper_width`` array along the last axis.
+    """
+
+    def __init__(self,
+            zarr_path,
+            horizon=1,
+            pad_before=0,
+            pad_after=0,
+            seed=42,
+            val_ratio=0.0,
+            max_train_episodes=None
+            ):
+        """Load the replay buffer and build the training sampler.
+
+        ``val_ratio`` and ``seed`` are passed to ``get_val_mask``; the
+        complement of that mask goes through ``downsample_mask`` with
+        ``max_n=max_train_episodes``. ``horizon``, ``pad_before`` and
+        ``pad_after`` configure the ``SequenceSampler``.
+        """
+        super().__init__()
+        self.replay_buffer = ReplayBuffer.copy_from_path(
+            zarr_path, keys=[
+                'robot_0_camera_images',
+                'robot_0_tcp_xyz_wxyz',
+                'robot_0_gripper_width',
+                'action_0_tcp_xyz_wxyz',
+                'action_0_gripper_width'])
+        val_mask = get_val_mask(
+            n_episodes=self.replay_buffer.n_episodes,
+            val_ratio=val_ratio,
+            seed=seed)
+        train_mask = ~val_mask
+        train_mask = downsample_mask(
+            mask=train_mask,
+            max_n=max_train_episodes,
+            seed=seed)
+
+        self.sampler = SequenceSampler(
+            replay_buffer=self.replay_buffer,
+            sequence_length=horizon,
+            pad_before=pad_before,
+            pad_after=pad_after,
+            episode_mask=train_mask)
+        self.train_mask = train_mask
+        self.horizon = horizon
+        self.pad_before = pad_before
+        self.pad_after = pad_after
+
+    def get_validation_dataset(self):
+        """Return a shallow copy that samples the complement of train_mask."""
+        val_set = copy.copy(self)
+        val_set.sampler = SequenceSampler(
+            replay_buffer=self.replay_buffer,
+            sequence_length=self.horizon,
+            pad_before=self.pad_before,
+            pad_after=self.pad_after,
+            episode_mask=~self.train_mask
+            )
+        val_set.train_mask = ~self.train_mask
+        return val_set
+
+    def get_normalizer(self, mode='limits', **kwargs):
+        """Fit a LinearNormalizer on the full-buffer action and agent_pos.
+
+        Both arrays are assembled exactly as in ``_sample_to_data`` so the
+        fitted statistics match the dimensions of the sampled items.
+        """
+        data = {
+            'action': np.concatenate([
+                self.replay_buffer['action_0_tcp_xyz_wxyz'],
+                self.replay_buffer['action_0_gripper_width']], axis=-1),
+            # Pose + gripper width; must mirror agent_pos in _sample_to_data.
+            'agent_pos': np.concatenate([
+                self.replay_buffer['robot_0_tcp_xyz_wxyz'],
+                self.replay_buffer['robot_0_gripper_width']], axis=-1)
+        }
+        normalizer = LinearNormalizer()
+        normalizer.fit(data=data, last_n_dims=1, mode=mode, **kwargs)
+        normalizer['image'] = get_image_range_normalizer()
+        return normalizer
+
+    def __len__(self) -> int:
+        return len(self.sampler)
+
+    def _sample_to_data(self, sample):
+        """Convert a sampled sequence dict into the obs/action training dict."""
+        agent_pos = np.concatenate([
+            sample['robot_0_tcp_xyz_wxyz'],
+            sample['robot_0_gripper_width']], axis=-1).astype(np.float32)
+        agent_action = np.concatenate([
+            sample['action_0_tcp_xyz_wxyz'],
+            sample['action_0_gripper_width']], axis=-1).astype(np.float32)
+        # squeeze(1) drops a size-1 axis after time (presumably the camera
+        # index -- TODO confirm); then move channels to axis 1 and divide by 255.
+        image = np.moveaxis(sample['robot_0_camera_images'].astype(np.float32).squeeze(1),-1,1)/255
+
+        data = {
+            'obs': {
+                'image': image, # T, 3, 224, 224
+                'agent_pos': agent_pos, # T, 8 (tcp pose, gripper_width); pose layout per key name xyz_wxyz -- TODO confirm
+            },
+            'action': agent_action # T, 8 (tcp pose, gripper_width); pose layout per key name xyz_wxyz -- TODO confirm
+        }
+        return data
+
+    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
+        sample = self.sampler.sample_sequence(idx)
+        data = self._sample_to_data(sample)
+        torch_data = dict_apply(data, torch.from_numpy)
+        return torch_data
+
+
+def test():
+    import os
+    zarr_path = os.path.expanduser('/home/yihuai/robotics/repositories/mujoco/mujoco-env/data/collect_heuristic_data/2024-12-24_11-36-15_100episodes/merged_data.zarr')
+    dataset = MujocoImageDataset(zarr_path, horizon=16)
+    print(dataset[0])
+    # from matplotlib import pyplot as plt
+    # normalizer = dataset.get_normalizer()
+    # nactions = normalizer['action'].normalize(dataset.replay_buffer['action'])
+    # diff = np.diff(nactions, axis=0)
+    # dists = np.linalg.norm(np.diff(nactions, axis=0), axis=-1)
+
+if __name__ == '__main__':
+    test()
\ No newline at end of file
diff --git a/image_pusht_diffusion_policy_cnn.yaml b/image_pusht_diffusion_policy_cnn.yaml
new file mode 100644
index 0000000..dd88408
--- /dev/null
+++ b/image_pusht_diffusion_policy_cnn.yaml
@@ -0,0 +1,182 @@
+_target_: diffusion_policy.workspace.train_diffusion_unet_hybrid_workspace.TrainDiffusionUnetHybridWorkspace
+checkpoint:
+  save_last_ckpt: true
+  save_last_snapshot: false
+  topk:
+    format_str: epoch={epoch:04d}-test_mean_score={test_mean_score:.3f}.ckpt
+    k: 5
+    mode: max
+    monitor_key: test_mean_score
+dataloader:
+  batch_size: 64
+  num_workers: 8
+  persistent_workers: false
+  pin_memory: true
+  shuffle: true
+dataset_obs_steps: 2
+ema:
+  _target_: diffusion_policy.model.diffusion.ema_model.EMAModel
+  inv_gamma: 1.0
+  max_value: 0.9999
+  min_value: 0.0
+  power: 0.75
+  update_after_step: 0
+exp_name: default
+horizon: 16
+keypoint_visible_rate: 1.0
+logging:
+  group: null
+  id: null
+  mode: online
+  name: 2023.01.16-20.20.06_train_diffusion_unet_hybrid_pusht_image
+  project: diffusion_policy_debug
+  resume: true
+  tags:
+  - train_diffusion_unet_hybrid
+  - pusht_image
+  - default
+multi_run:
+  run_dir: data/outputs/2023.01.16/20.20.06_train_diffusion_unet_hybrid_pusht_image
+  wandb_name_base: 2023.01.16-20.20.06_train_diffusion_unet_hybrid_pusht_image
+n_action_steps: 8
+n_latency_steps: 0
+n_obs_steps: 2
+name: train_diffusion_unet_hybrid
+obs_as_global_cond: true
+optimizer:
+  _target_: torch.optim.AdamW
+  betas:
+  - 0.95
+  - 0.999
+  eps: 1.0e-08
+  lr: 0.0001
+  weight_decay: 1.0e-06
+past_action_visible: false
+policy:
+  _target_: diffusion_policy.policy.diffusion_unet_hybrid_image_policy.DiffusionUnetHybridImagePolicy
+  cond_predict_scale: true
+  crop_shape:
+  - 84
+  - 84
+  diffusion_step_embed_dim: 128
+  down_dims:
+  - 512
+  - 1024
+  - 2048
+  eval_fixed_crop: true
+  horizon: 16
+  kernel_size: 5
+  n_action_steps: 8
+  n_groups: 8
+  n_obs_steps: 2
+  noise_scheduler:
+    _target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler
+    beta_end: 0.02
+    beta_schedule: squaredcos_cap_v2
+    beta_start: 0.0001
+    clip_sample: true
+    num_train_timesteps: 100
+    prediction_type: epsilon
+    variance_type: fixed_small
+  num_inference_steps: 100
+  obs_as_global_cond: true
+  obs_encoder_group_norm: true
+  shape_meta:
+    action:
+      shape:
+      - 2
+    obs:
+      agent_pos:
+        shape:
+        - 2
+        type: low_dim
+      image:
+        shape:
+        - 3
+        - 96
+        - 96
+        type: rgb
+shape_meta:
+  action:
+    shape:
+    - 2
+  obs:
+    agent_pos:
+      shape:
+      - 2
+      type: low_dim
+    image:
+      shape:
+      - 3
+      - 96
+      - 96
+      type: rgb
+task:
+  dataset:
+    _target_: diffusion_policy.dataset.pusht_image_dataset.PushTImageDataset
+    horizon: 16
+    max_train_episodes: 90
+    pad_after: 7
+    pad_before: 1
+    seed: 42
+    val_ratio: 0.02
+    zarr_path: data/pusht/pusht_cchi_v7_replay.zarr
+  env_runner:
+    _target_: diffusion_policy.env_runner.pusht_image_runner.PushTImageRunner
+    fps: 10
+    legacy_test: true
+    max_steps: 300
+    n_action_steps: 8
+    n_envs: null
+    n_obs_steps: 2
+    n_test: 50
+    n_test_vis: 4
+    n_train: 6
+    n_train_vis: 2
+    past_action: false
+    test_start_seed: 100000
+    train_start_seed: 0
+  image_shape:
+  - 3
+  - 96
+  - 96
+  name: pusht_image
+  shape_meta:
+    action:
+      shape:
+      - 2
+    obs:
+      agent_pos:
+        shape:
+        - 2
+        type: low_dim
+      image:
+        shape:
+        - 3
+        - 96
+        - 96
+        type: rgb
+task_name: pusht_image
+training:
+  checkpoint_every: 50
+  debug: false
+  device: cuda:0
+  gradient_accumulate_every: 1
+  lr_scheduler: cosine
+  lr_warmup_steps: 500
+  max_train_steps: null
+  max_val_steps: null
+  num_epochs: 3050
+  resume: true
+  rollout_every: 50
+  sample_every: 5
+  seed: 42
+  tqdm_interval_sec: 1.0
+  use_ema: true
+  val_every: 1
+val_dataloader:
+  batch_size: 64
+  num_workers: 8
+  persistent_workers: false
+  pin_memory: true
+  shuffle: false