Done adapting mujoco image dataset
This commit is contained in:
@@ -158,6 +158,8 @@ class ReplayBuffer:
|
|||||||
# numpy backend
|
# numpy backend
|
||||||
meta = dict()
|
meta = dict()
|
||||||
for key, value in src_root['meta'].items():
|
for key, value in src_root['meta'].items():
|
||||||
|
if isinstance(value, zarr.Group):
|
||||||
|
continue
|
||||||
if len(value.shape) == 0:
|
if len(value.shape) == 0:
|
||||||
meta[key] = np.array(value)
|
meta[key] = np.array(value)
|
||||||
else:
|
else:
|
||||||
|
|||||||
109
diffusion_policy/dataset/mujoco_image_dataset.py
Normal file
109
diffusion_policy/dataset/mujoco_image_dataset.py
Normal file
@@ -0,0 +1,109 @@
|
|||||||
|
from typing import Dict
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
import copy
|
||||||
|
from diffusion_policy.common.pytorch_util import dict_apply
|
||||||
|
from diffusion_policy.common.replay_buffer import ReplayBuffer
|
||||||
|
from diffusion_policy.common.sampler import (
|
||||||
|
SequenceSampler, get_val_mask, downsample_mask)
|
||||||
|
from diffusion_policy.model.common.normalizer import LinearNormalizer
|
||||||
|
from diffusion_policy.dataset.base_dataset import BaseImageDataset
|
||||||
|
from diffusion_policy.common.normalize_util import get_image_range_normalizer
|
||||||
|
|
||||||
|
class MujocoImageDataset(BaseImageDataset):
    """Sequence dataset over mujoco demonstrations stored in a zarr archive.

    Each sample is a horizon-length window containing the camera images,
    the robot end-effector pose + gripper width (observation `agent_pos`),
    and the commanded pose + gripper width (`action`).

    Args:
        zarr_path: path to the zarr store produced by data collection.
        horizon: length of each sampled sequence window.
        pad_before: steps the window may extend before an episode start.
        pad_after: steps the window may extend past an episode end.
        seed: RNG seed for the train/val split and episode downsampling.
        val_ratio: fraction of episodes held out for validation.
        max_train_episodes: optional cap on the number of training episodes.
    """

    def __init__(self,
            zarr_path,
            horizon=1,
            pad_before=0,
            pad_after=0,
            seed=42,
            val_ratio=0.0,
            max_train_episodes=None
            ):
        super().__init__()
        # Load only the keys this dataset consumes.
        self.replay_buffer = ReplayBuffer.copy_from_path(
            zarr_path, keys=[
                'robot_0_camera_images',
                'robot_0_tcp_xyz_wxyz',
                'robot_0_gripper_width',
                'action_0_tcp_xyz_wxyz',
                'action_0_gripper_width'])
        # Episode-level train/val split, then optional downsampling of the
        # training episodes; both are deterministic given `seed`.
        val_mask = get_val_mask(
            n_episodes=self.replay_buffer.n_episodes,
            val_ratio=val_ratio,
            seed=seed)
        train_mask = ~val_mask
        train_mask = downsample_mask(
            mask=train_mask,
            max_n=max_train_episodes,
            seed=seed)

        self.sampler = SequenceSampler(
            replay_buffer=self.replay_buffer,
            sequence_length=horizon,
            pad_before=pad_before,
            pad_after=pad_after,
            episode_mask=train_mask)
        self.train_mask = train_mask
        self.horizon = horizon
        self.pad_before = pad_before
        self.pad_after = pad_after

    def get_validation_dataset(self):
        """Return a shallow copy of this dataset that samples the held-out episodes."""
        val_set = copy.copy(self)
        val_set.sampler = SequenceSampler(
            replay_buffer=self.replay_buffer,
            sequence_length=self.horizon,
            pad_before=self.pad_before,
            pad_after=self.pad_after,
            episode_mask=~self.train_mask
        )
        val_set.train_mask = ~self.train_mask
        return val_set

    def get_normalizer(self, mode='limits', **kwargs):
        """Fit a LinearNormalizer on the full action / agent_pos arrays.

        The image normalizer is a fixed range mapping and is not fitted
        to the data.
        """
        data = {
            'action': np.concatenate([
                self.replay_buffer['action_0_tcp_xyz_wxyz'],
                self.replay_buffer['action_0_gripper_width']], axis=-1),
            # BUG FIX: the original concatenated the tcp pose with itself,
            # which disagrees with _sample_to_data (pose + gripper width)
            # and fits the normalizer on the wrong feature layout.
            'agent_pos': np.concatenate([
                self.replay_buffer['robot_0_tcp_xyz_wxyz'],
                self.replay_buffer['robot_0_gripper_width']], axis=-1)
        }
        normalizer = LinearNormalizer()
        normalizer.fit(data=data, last_n_dims=1, mode=mode, **kwargs)
        normalizer['image'] = get_image_range_normalizer()
        return normalizer

    def __len__(self) -> int:
        # Number of sampleable sequence windows in this split.
        return len(self.sampler)

    def _sample_to_data(self, sample):
        """Convert one raw sampled window into the nested (obs, action) dict."""
        agent_pos = np.concatenate(
            [sample['robot_0_tcp_xyz_wxyz'], sample['robot_0_gripper_width']],
            axis=-1).astype(np.float32)
        agent_action = np.concatenate(
            [sample['action_0_tcp_xyz_wxyz'], sample['action_0_gripper_width']],
            axis=-1).astype(np.float32)
        # Drop the singleton camera axis, move channels first, and scale
        # pixel values to [0, 1].
        image = np.moveaxis(
            sample['robot_0_camera_images'].astype(np.float32).squeeze(1),
            -1, 1) / 255

        data = {
            'obs': {
                'image': image,          # T, 3, 224, 224
                'agent_pos': agent_pos,  # T, 8 (x,y,z,qx,qy,qz,qw,gripper_width)
            },
            'action': agent_action       # T, 8 (x,y,z,qx,qy,qz,qw,gripper_width)
        }
        return data

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        sample = self.sampler.sample_sequence(idx)
        data = self._sample_to_data(sample)
        torch_data = dict_apply(data, torch.from_numpy)
        return torch_data
||||||
|
def test():
    """Smoke test: load the merged-data zarr and print one sample."""
    import os
    data_root = '/home/yihuai/robotics/repositories/mujoco/mujoco-env/data/'
    run_dir = 'collect_heuristic_data/2024-12-24_11-36-15_100episodes/'
    zarr_path = os.path.expanduser(data_root + run_dir + 'merged_data.zarr')
    dataset = MujocoImageDataset(zarr_path, horizon=16)
    print(dataset[0])
|
# Run the smoke test only when executed as a script, not on import.
if __name__ == '__main__':
    test()
||||||
182
image_pusht_diffusion_policy_cnn.yaml
Normal file
182
image_pusht_diffusion_policy_cnn.yaml
Normal file
@@ -0,0 +1,182 @@
|
|||||||
|
_target_: diffusion_policy.workspace.train_diffusion_unet_hybrid_workspace.TrainDiffusionUnetHybridWorkspace
|
||||||
|
checkpoint:
|
||||||
|
save_last_ckpt: true
|
||||||
|
save_last_snapshot: false
|
||||||
|
topk:
|
||||||
|
format_str: epoch={epoch:04d}-test_mean_score={test_mean_score:.3f}.ckpt
|
||||||
|
k: 5
|
||||||
|
mode: max
|
||||||
|
monitor_key: test_mean_score
|
||||||
|
dataloader:
|
||||||
|
batch_size: 64
|
||||||
|
num_workers: 8
|
||||||
|
persistent_workers: false
|
||||||
|
pin_memory: true
|
||||||
|
shuffle: true
|
||||||
|
dataset_obs_steps: 2
|
||||||
|
ema:
|
||||||
|
_target_: diffusion_policy.model.diffusion.ema_model.EMAModel
|
||||||
|
inv_gamma: 1.0
|
||||||
|
max_value: 0.9999
|
||||||
|
min_value: 0.0
|
||||||
|
power: 0.75
|
||||||
|
update_after_step: 0
|
||||||
|
exp_name: default
|
||||||
|
horizon: 16
|
||||||
|
keypoint_visible_rate: 1.0
|
||||||
|
logging:
|
||||||
|
group: null
|
||||||
|
id: null
|
||||||
|
mode: online
|
||||||
|
name: 2023.01.16-20.20.06_train_diffusion_unet_hybrid_pusht_image
|
||||||
|
project: diffusion_policy_debug
|
||||||
|
resume: true
|
||||||
|
tags:
|
||||||
|
- train_diffusion_unet_hybrid
|
||||||
|
- pusht_image
|
||||||
|
- default
|
||||||
|
multi_run:
|
||||||
|
run_dir: data/outputs/2023.01.16/20.20.06_train_diffusion_unet_hybrid_pusht_image
|
||||||
|
wandb_name_base: 2023.01.16-20.20.06_train_diffusion_unet_hybrid_pusht_image
|
||||||
|
n_action_steps: 8
|
||||||
|
n_latency_steps: 0
|
||||||
|
n_obs_steps: 2
|
||||||
|
name: train_diffusion_unet_hybrid
|
||||||
|
obs_as_global_cond: true
|
||||||
|
optimizer:
|
||||||
|
_target_: torch.optim.AdamW
|
||||||
|
betas:
|
||||||
|
- 0.95
|
||||||
|
- 0.999
|
||||||
|
eps: 1.0e-08
|
||||||
|
lr: 0.0001
|
||||||
|
weight_decay: 1.0e-06
|
||||||
|
past_action_visible: false
|
||||||
|
policy:
|
||||||
|
_target_: diffusion_policy.policy.diffusion_unet_hybrid_image_policy.DiffusionUnetHybridImagePolicy
|
||||||
|
cond_predict_scale: true
|
||||||
|
crop_shape:
|
||||||
|
- 84
|
||||||
|
- 84
|
||||||
|
diffusion_step_embed_dim: 128
|
||||||
|
down_dims:
|
||||||
|
- 512
|
||||||
|
- 1024
|
||||||
|
- 2048
|
||||||
|
eval_fixed_crop: true
|
||||||
|
horizon: 16
|
||||||
|
kernel_size: 5
|
||||||
|
n_action_steps: 8
|
||||||
|
n_groups: 8
|
||||||
|
n_obs_steps: 2
|
||||||
|
noise_scheduler:
|
||||||
|
_target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler
|
||||||
|
beta_end: 0.02
|
||||||
|
beta_schedule: squaredcos_cap_v2
|
||||||
|
beta_start: 0.0001
|
||||||
|
clip_sample: true
|
||||||
|
num_train_timesteps: 100
|
||||||
|
prediction_type: epsilon
|
||||||
|
variance_type: fixed_small
|
||||||
|
num_inference_steps: 100
|
||||||
|
obs_as_global_cond: true
|
||||||
|
obs_encoder_group_norm: true
|
||||||
|
shape_meta:
|
||||||
|
action:
|
||||||
|
shape:
|
||||||
|
- 2
|
||||||
|
obs:
|
||||||
|
agent_pos:
|
||||||
|
shape:
|
||||||
|
- 2
|
||||||
|
type: low_dim
|
||||||
|
image:
|
||||||
|
shape:
|
||||||
|
- 3
|
||||||
|
- 96
|
||||||
|
- 96
|
||||||
|
type: rgb
|
||||||
|
shape_meta:
|
||||||
|
action:
|
||||||
|
shape:
|
||||||
|
- 2
|
||||||
|
obs:
|
||||||
|
agent_pos:
|
||||||
|
shape:
|
||||||
|
- 2
|
||||||
|
type: low_dim
|
||||||
|
image:
|
||||||
|
shape:
|
||||||
|
- 3
|
||||||
|
- 96
|
||||||
|
- 96
|
||||||
|
type: rgb
|
||||||
|
task:
|
||||||
|
dataset:
|
||||||
|
_target_: diffusion_policy.dataset.pusht_image_dataset.PushTImageDataset
|
||||||
|
horizon: 16
|
||||||
|
max_train_episodes: 90
|
||||||
|
pad_after: 7
|
||||||
|
pad_before: 1
|
||||||
|
seed: 42
|
||||||
|
val_ratio: 0.02
|
||||||
|
zarr_path: data/pusht/pusht_cchi_v7_replay.zarr
|
||||||
|
env_runner:
|
||||||
|
_target_: diffusion_policy.env_runner.pusht_image_runner.PushTImageRunner
|
||||||
|
fps: 10
|
||||||
|
legacy_test: true
|
||||||
|
max_steps: 300
|
||||||
|
n_action_steps: 8
|
||||||
|
n_envs: null
|
||||||
|
n_obs_steps: 2
|
||||||
|
n_test: 50
|
||||||
|
n_test_vis: 4
|
||||||
|
n_train: 6
|
||||||
|
n_train_vis: 2
|
||||||
|
past_action: false
|
||||||
|
test_start_seed: 100000
|
||||||
|
train_start_seed: 0
|
||||||
|
image_shape:
|
||||||
|
- 3
|
||||||
|
- 96
|
||||||
|
- 96
|
||||||
|
name: pusht_image
|
||||||
|
shape_meta:
|
||||||
|
action:
|
||||||
|
shape:
|
||||||
|
- 2
|
||||||
|
obs:
|
||||||
|
agent_pos:
|
||||||
|
shape:
|
||||||
|
- 2
|
||||||
|
type: low_dim
|
||||||
|
image:
|
||||||
|
shape:
|
||||||
|
- 3
|
||||||
|
- 96
|
||||||
|
- 96
|
||||||
|
type: rgb
|
||||||
|
task_name: pusht_image
|
||||||
|
training:
|
||||||
|
checkpoint_every: 50
|
||||||
|
debug: false
|
||||||
|
device: cuda:0
|
||||||
|
gradient_accumulate_every: 1
|
||||||
|
lr_scheduler: cosine
|
||||||
|
lr_warmup_steps: 500
|
||||||
|
max_train_steps: null
|
||||||
|
max_val_steps: null
|
||||||
|
num_epochs: 3050
|
||||||
|
resume: true
|
||||||
|
rollout_every: 50
|
||||||
|
sample_every: 5
|
||||||
|
seed: 42
|
||||||
|
tqdm_interval_sec: 1.0
|
||||||
|
use_ema: true
|
||||||
|
val_every: 1
|
||||||
|
val_dataloader:
|
||||||
|
batch_size: 64
|
||||||
|
num_workers: 8
|
||||||
|
persistent_workers: false
|
||||||
|
pin_memory: true
|
||||||
|
shuffle: false
|
||||||
Reference in New Issue
Block a user