code release
This commit is contained in:
51
diffusion_policy/dataset/base_dataset.py
Normal file
51
diffusion_policy/dataset/base_dataset.py
Normal file
@@ -0,0 +1,51 @@
|
||||
from typing import Dict
|
||||
|
||||
import torch
|
||||
import torch.nn
|
||||
from diffusion_policy.model.common.normalizer import LinearNormalizer
|
||||
|
||||
class BaseLowdimDataset(torch.utils.data.Dataset):
    """Abstract interface for low-dimensional (state-based) sequence datasets.

    The base class behaves as an empty dataset; subclasses provide the
    actual samples, the fitted normalizer, and the action accessor.
    """

    def __len__(self) -> int:
        # The base class holds no data.
        return 0

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Return one sequence sample.

        output:
            obs: T, Do
            action: T, Da
        """
        raise NotImplementedError()

    def get_validation_dataset(self) -> 'BaseLowdimDataset':
        """Return the held-out validation split (empty dataset by default)."""
        return BaseLowdimDataset()

    def get_normalizer(self, **kwargs) -> LinearNormalizer:
        """Return a normalizer fit to this dataset's data."""
        raise NotImplementedError()

    def get_all_actions(self) -> torch.Tensor:
        """Return every action in the dataset, concatenated along time."""
        raise NotImplementedError()
|
||||
|
||||
|
||||
class BaseImageDataset(torch.utils.data.Dataset):
    """Abstract interface for image-based sequence datasets.

    The base class behaves as an empty dataset; subclasses provide the
    actual samples, the fitted normalizer, and the action accessor.
    """

    def get_validation_dataset(self) -> 'BaseImageDataset':
        # BUG FIX: return annotation previously said 'BaseLowdimDataset'
        # (copy-paste from the class above) while the method returns a
        # BaseImageDataset.
        # return an empty dataset by default
        return BaseImageDataset()

    def get_normalizer(self, **kwargs) -> LinearNormalizer:
        """Return a normalizer fit to this dataset's data."""
        raise NotImplementedError()

    def get_all_actions(self) -> torch.Tensor:
        """Return every action in the dataset, concatenated along time."""
        raise NotImplementedError()

    def __len__(self) -> int:
        # The base class holds no data.
        return 0

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Return one sequence sample.

        output:
            obs:
                key: T, *
            action: T, Da
        """
        raise NotImplementedError()
|
||||
126
diffusion_policy/dataset/blockpush_lowdim_dataset.py
Normal file
126
diffusion_policy/dataset/blockpush_lowdim_dataset.py
Normal file
@@ -0,0 +1,126 @@
|
||||
from typing import Dict
|
||||
import torch
|
||||
import numpy as np
|
||||
import copy
|
||||
from diffusion_policy.common.pytorch_util import dict_apply
|
||||
from diffusion_policy.common.replay_buffer import ReplayBuffer
|
||||
from diffusion_policy.common.sampler import SequenceSampler, get_val_mask
|
||||
from diffusion_policy.model.common.normalizer import LinearNormalizer, SingleFieldLinearNormalizer
|
||||
from diffusion_policy.dataset.base_dataset import BaseLowdimDataset
|
||||
|
||||
class BlockPushLowdimDataset(BaseLowdimDataset):
    """Low-dim Block Pushing dataset backed by a zarr-stored ReplayBuffer.

    Samples fixed-length windows of (obs, action) from training episodes;
    a fraction of episodes (`val_ratio`) is held out for validation.
    """

    def __init__(self,
            zarr_path,
            horizon=1,
            pad_before=0,
            pad_after=0,
            obs_key='obs',
            action_key='action',
            obs_eef_target=True,
            use_manual_normalizer=False,
            seed=42,
            val_ratio=0.0
            ):
        """
        Args:
            zarr_path: path to the zarr store with the demonstrations.
            horizon: length T of each sampled sequence.
            pad_before/pad_after: steps of padding allowed at episode edges.
            obs_key/action_key: array names inside the replay buffer.
            obs_eef_target: if False, zero out obs dims 8:10
                (presumably the end-effector target position — TODO confirm).
            use_manual_normalizer: use hand-crafted per-dimension grouping
                instead of a directly fitted normalizer.
            seed: RNG seed for the train/val episode split.
            val_ratio: fraction of episodes held out for validation.
        """
        super().__init__()
        # Load only the required arrays fully into memory.
        self.replay_buffer = ReplayBuffer.copy_from_path(
            zarr_path, keys=[obs_key, action_key])

        # Episode-level train/validation split.
        val_mask = get_val_mask(
            n_episodes=self.replay_buffer.n_episodes,
            val_ratio=val_ratio,
            seed=seed)
        train_mask = ~val_mask
        self.sampler = SequenceSampler(
            replay_buffer=self.replay_buffer,
            sequence_length=horizon,
            pad_before=pad_before,
            pad_after=pad_after,
            episode_mask=train_mask)
        self.obs_key = obs_key
        self.action_key = action_key
        self.obs_eef_target = obs_eef_target
        self.use_manual_normalizer = use_manual_normalizer
        self.train_mask = train_mask
        self.horizon = horizon
        self.pad_before = pad_before
        self.pad_after = pad_after

    def get_validation_dataset(self):
        """Return a shallow copy that samples from the held-out episodes."""
        val_set = copy.copy(self)
        val_set.sampler = SequenceSampler(
            replay_buffer=self.replay_buffer,
            sequence_length=self.horizon,
            pad_before=self.pad_before,
            pad_after=self.pad_after,
            episode_mask=~self.train_mask
        )
        val_set.train_mask = ~self.train_mask
        return val_set

    def get_normalizer(self, mode='limits', **kwargs):
        """Fit a LinearNormalizer on the whole buffer.

        With `use_manual_normalizer`, obs dimensions are grouped into x
        coordinates, y coordinates, and the rest, and each group is scaled
        jointly to [-1, 1] so relative geometry within a group is preserved.
        """
        data = self._sample_to_data(self.replay_buffer)

        normalizer = LinearNormalizer()
        if not self.use_manual_normalizer:
            normalizer.fit(data=data, last_n_dims=1, mode=mode, **kwargs)
        else:
            x = data['obs']
            # Per-dimension statistics over all timesteps.
            stat = {
                'max': np.max(x, axis=0),
                'min': np.min(x, axis=0),
                'mean': np.mean(x, axis=0),
                'std': np.std(x, axis=0)
            }

            # Hard-coded obs layout: these indices select x / y coordinate
            # dims; everything else is treated as rotation.
            # NOTE(review): assumes the block-push obs layout — confirm
            # against the environment's observation spec.
            is_x = np.zeros(stat['max'].shape, dtype=bool)
            is_y = np.zeros_like(is_x)
            is_x[[0,3,6,8,10,13]] = True
            is_y[[1,4,7,9,11,14]] = True
            is_rot = ~(is_x|is_y)

            def normalizer_with_masks(stat, masks):
                # Scale each masked group of dims jointly into [-1, 1]
                # using the group's shared min/max.
                global_scale = np.ones_like(stat['max'])
                global_offset = np.zeros_like(stat['max'])
                for mask in masks:
                    output_max = 1
                    output_min = -1
                    input_max = stat['max'][mask].max()
                    input_min = stat['min'][mask].min()
                    input_range = input_max - input_min
                    scale = (output_max - output_min) / input_range
                    offset = output_min - scale * input_min
                    global_scale[mask] = scale
                    global_offset[mask] = offset
                return SingleFieldLinearNormalizer.create_manual(
                    scale=global_scale,
                    offset=global_offset,
                    input_stats_dict=stat
                )

            normalizer['obs'] = normalizer_with_masks(stat, [is_x, is_y, is_rot])
            normalizer['action'] = SingleFieldLinearNormalizer.create_fit(
                data['action'], last_n_dims=1, mode=mode, **kwargs)
        return normalizer

    def get_all_actions(self) -> torch.Tensor:
        """Return all actions in the buffer as one (N, Da) tensor."""
        return torch.from_numpy(self.replay_buffer['action'])

    def __len__(self) -> int:
        return len(self.sampler)

    def _sample_to_data(self, sample):
        # Convert a raw sampled window (or the whole buffer) to model inputs.
        obs = sample[self.obs_key] # T, D_o
        if not self.obs_eef_target:
            # NOTE(review): mutates `obs` in place; when called with the
            # whole replay buffer (see get_normalizer) this zeroes the
            # stored data as well — confirm that is intended.
            obs[:,8:10] = 0
        data = {
            'obs': obs,
            'action': sample[self.action_key], # T, D_a
        }
        return data

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Return one training window as torch tensors (obs: T,Do; action: T,Da)."""
        sample = self.sampler.sample_sequence(idx)
        data = self._sample_to_data(sample)

        torch_data = dict_apply(data, torch.from_numpy)
        return torch_data
|
||||
91
diffusion_policy/dataset/kitchen_lowdim_dataset.py
Normal file
91
diffusion_policy/dataset/kitchen_lowdim_dataset.py
Normal file
@@ -0,0 +1,91 @@
|
||||
from typing import Dict
|
||||
import torch
|
||||
import numpy as np
|
||||
import copy
|
||||
import pathlib
|
||||
from diffusion_policy.common.pytorch_util import dict_apply
|
||||
from diffusion_policy.common.replay_buffer import ReplayBuffer
|
||||
from diffusion_policy.common.sampler import SequenceSampler, get_val_mask
|
||||
from diffusion_policy.model.common.normalizer import LinearNormalizer, SingleFieldLinearNormalizer
|
||||
from diffusion_policy.dataset.base_dataset import BaseLowdimDataset
|
||||
|
||||
class KitchenLowdimDataset(BaseLowdimDataset):
    """Low-dim Franka Kitchen dataset loaded from pre-exported .npy files.

    Expects `observations_seq.npy`, `actions_seq.npy` and
    `existence_mask.npy` in `dataset_dir`; the mask gives each padded
    episode's true length.
    """

    def __init__(self,
            dataset_dir,
            horizon=1,
            pad_before=0,
            pad_after=0,
            seed=42,
            val_ratio=0.0
            ):
        """
        Args:
            dataset_dir: directory containing the three .npy files.
            horizon: length T of each sampled sequence.
            pad_before/pad_after: steps of padding allowed at episode edges.
            seed: RNG seed for the train/val episode split.
            val_ratio: fraction of episodes held out for validation.
        """
        super().__init__()

        data_directory = pathlib.Path(dataset_dir)
        observations = np.load(data_directory / "observations_seq.npy")
        actions = np.load(data_directory / "actions_seq.npy")
        masks = np.load(data_directory / "existence_mask.npy")

        # Re-pack the padded arrays into variable-length episodes.
        self.replay_buffer = ReplayBuffer.create_empty_numpy()
        for i in range(len(masks)):
            # Number of valid timesteps in episode i.
            eps_len = int(masks[i].sum())
            obs = observations[i,:eps_len].astype(np.float32)
            action = actions[i,:eps_len].astype(np.float32)
            data = {
                'obs': obs,
                'action': action
            }
            self.replay_buffer.add_episode(data)

        # Episode-level train/validation split.
        val_mask = get_val_mask(
            n_episodes=self.replay_buffer.n_episodes,
            val_ratio=val_ratio,
            seed=seed)
        train_mask = ~val_mask
        self.sampler = SequenceSampler(
            replay_buffer=self.replay_buffer,
            sequence_length=horizon,
            pad_before=pad_before,
            pad_after=pad_after,
            episode_mask=train_mask)

        self.train_mask = train_mask
        self.horizon = horizon
        self.pad_before = pad_before
        self.pad_after = pad_after

    def get_validation_dataset(self):
        """Return a shallow copy that samples from the held-out episodes."""
        val_set = copy.copy(self)
        val_set.sampler = SequenceSampler(
            replay_buffer=self.replay_buffer,
            sequence_length=self.horizon,
            pad_before=self.pad_before,
            pad_after=self.pad_after,
            episode_mask=~self.train_mask
        )
        val_set.train_mask = ~self.train_mask
        return val_set

    def get_normalizer(self, mode='limits', **kwargs):
        """Fit a LinearNormalizer on the full obs/action arrays."""
        data = {
            'obs': self.replay_buffer['obs'],
            'action': self.replay_buffer['action']
        }
        if 'range_eps' not in kwargs:
            # to prevent blowing up dims that barely change
            kwargs['range_eps'] = 5e-2
        normalizer = LinearNormalizer()
        normalizer.fit(data=data, last_n_dims=1, mode=mode, **kwargs)
        return normalizer

    def get_all_actions(self) -> torch.Tensor:
        """Return all actions in the buffer as one (N, Da) tensor."""
        return torch.from_numpy(self.replay_buffer['action'])

    def __len__(self) -> int:
        return len(self.sampler)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Return one training window as torch tensors (keys: obs, action)."""
        sample = self.sampler.sample_sequence(idx)
        # Samples are already in the desired {'obs', 'action'} layout.
        data = sample

        torch_data = dict_apply(data, torch.from_numpy)
        return torch_data
|
||||
112
diffusion_policy/dataset/kitchen_mjl_lowdim_dataset.py
Normal file
112
diffusion_policy/dataset/kitchen_mjl_lowdim_dataset.py
Normal file
@@ -0,0 +1,112 @@
|
||||
from typing import Dict
|
||||
import torch
|
||||
import numpy as np
|
||||
import copy
|
||||
import pathlib
|
||||
from tqdm import tqdm
|
||||
from diffusion_policy.common.pytorch_util import dict_apply
|
||||
from diffusion_policy.common.replay_buffer import ReplayBuffer
|
||||
from diffusion_policy.common.sampler import SequenceSampler, get_val_mask
|
||||
from diffusion_policy.model.common.normalizer import LinearNormalizer, SingleFieldLinearNormalizer
|
||||
from diffusion_policy.dataset.base_dataset import BaseLowdimDataset
|
||||
from diffusion_policy.env.kitchen.kitchen_util import parse_mjl_logs
|
||||
|
||||
class KitchenMjlLowdimDataset(BaseLowdimDataset):
    """Low-dim Franka Kitchen dataset parsed from raw MuJoCo .mjl log files.

    Each `*/*.mjl` file under `dataset_dir` is parsed into one episode;
    unparseable files are skipped with a printed warning.
    """

    def __init__(self,
            dataset_dir,
            horizon=1,
            pad_before=0,
            pad_after=0,
            abs_action=True,
            robot_noise_ratio=0.0,
            seed=42,
            val_ratio=0.0
            ):
        """
        Args:
            dataset_dir: root directory containing `*/*.mjl` log files.
            horizon: length T of each sampled sequence.
            pad_before/pad_after: steps of padding allowed at episode edges.
            abs_action: must be True; relative actions are not implemented.
            robot_noise_ratio: scale of uniform observation noise added to
                the first 30 obs dims (to emulate real-robot sensing).
            seed: RNG seed for both the noise and the train/val split.
            val_ratio: fraction of episodes held out for validation.
        """
        super().__init__()

        # Only absolute-action logs are supported.
        if not abs_action:
            raise NotImplementedError()

        # Per-dimension noise amplitude for the first 30 obs dims.
        # NOTE(review): values presumably tuned to match real-robot sensor
        # noise per joint/object dim — confirm provenance.
        robot_pos_noise_amp = np.array([0.1 , 0.1 , 0.1 , 0.1 , 0.1 , 0.1 , 0.1 , 0.1 ,
            0.1 , 0.005 , 0.005 , 0.0005, 0.0005, 0.0005, 0.0005, 0.0005,
            0.0005, 0.005 , 0.005 , 0.005 , 0.1 , 0.1 , 0.1 , 0.005 ,
            0.005 , 0.005 , 0.1 , 0.1 , 0.1 , 0.005 ], dtype=np.float32)
        rng = np.random.default_rng(seed=seed)

        data_directory = pathlib.Path(dataset_dir)
        self.replay_buffer = ReplayBuffer.create_empty_numpy()
        for i, mjl_path in enumerate(tqdm(list(data_directory.glob('*/*.mjl')))):
            try:
                data = parse_mjl_logs(str(mjl_path.absolute()), skipamount=40)
                qpos = data['qpos'].astype(np.float32)
                # obs = [first 9 qpos dims, last 21 qpos dims, 30 zeros].
                # NOTE(review): the zero padding presumably stands in for
                # unused goal dims — confirm against the env's obs spec.
                obs = np.concatenate([
                    qpos[:,:9],
                    qpos[:,-21:],
                    np.zeros((len(qpos),30),dtype=np.float32)
                ], axis=-1)
                if robot_noise_ratio > 0:
                    # add observation noise to match real robot
                    noise = robot_noise_ratio * robot_pos_noise_amp * rng.uniform(
                        low=-1., high=1., size=(obs.shape[0], 30))
                    obs[:,:30] += noise
                episode = {
                    'obs': obs,
                    'action': data['ctrl'].astype(np.float32)
                }
                self.replay_buffer.add_episode(episode)
            except Exception as e:
                # Best-effort loading: skip bad logs but keep going.
                print(i, e)

        # Episode-level train/validation split.
        val_mask = get_val_mask(
            n_episodes=self.replay_buffer.n_episodes,
            val_ratio=val_ratio,
            seed=seed)
        train_mask = ~val_mask
        self.sampler = SequenceSampler(
            replay_buffer=self.replay_buffer,
            sequence_length=horizon,
            pad_before=pad_before,
            pad_after=pad_after,
            episode_mask=train_mask)

        self.train_mask = train_mask
        self.horizon = horizon
        self.pad_before = pad_before
        self.pad_after = pad_after

    def get_validation_dataset(self):
        """Return a shallow copy that samples from the held-out episodes."""
        val_set = copy.copy(self)
        val_set.sampler = SequenceSampler(
            replay_buffer=self.replay_buffer,
            sequence_length=self.horizon,
            pad_before=self.pad_before,
            pad_after=self.pad_after,
            episode_mask=~self.train_mask
        )
        val_set.train_mask = ~self.train_mask
        return val_set

    def get_normalizer(self, mode='limits', **kwargs):
        """Fit a LinearNormalizer on the full obs/action arrays."""
        data = {
            'obs': self.replay_buffer['obs'],
            'action': self.replay_buffer['action']
        }
        if 'range_eps' not in kwargs:
            # to prevent blowing up dims that barely change
            kwargs['range_eps'] = 5e-2
        normalizer = LinearNormalizer()
        normalizer.fit(data=data, last_n_dims=1, mode=mode, **kwargs)
        return normalizer

    def get_all_actions(self) -> torch.Tensor:
        """Return all actions in the buffer as one (N, Da) tensor."""
        return torch.from_numpy(self.replay_buffer['action'])

    def __len__(self) -> int:
        return len(self.sampler)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Return one training window as torch tensors (keys: obs, action)."""
        sample = self.sampler.sample_sequence(idx)
        # Samples are already in the desired {'obs', 'action'} layout.
        data = sample

        torch_data = dict_apply(data, torch.from_numpy)
        return torch_data
|
||||
97
diffusion_policy/dataset/pusht_dataset.py
Normal file
97
diffusion_policy/dataset/pusht_dataset.py
Normal file
@@ -0,0 +1,97 @@
|
||||
from typing import Dict
|
||||
import torch
|
||||
import numpy as np
|
||||
import copy
|
||||
from diffusion_policy.common.pytorch_util import dict_apply
|
||||
from diffusion_policy.common.replay_buffer import ReplayBuffer
|
||||
from diffusion_policy.common.sampler import (
|
||||
SequenceSampler, get_val_mask, downsample_mask)
|
||||
from diffusion_policy.model.common.normalizer import LinearNormalizer
|
||||
from diffusion_policy.dataset.base_dataset import BaseLowdimDataset
|
||||
|
||||
class PushTLowdimDataset(BaseLowdimDataset):
    """Low-dim Push-T dataset backed by a zarr-stored ReplayBuffer.

    Observations are flattened keypoints concatenated with the 2D agent
    position extracted from the state array.
    """

    def __init__(self,
            zarr_path,
            horizon=1,
            pad_before=0,
            pad_after=0,
            obs_key='keypoint',
            state_key='state',
            action_key='action',
            seed=42,
            val_ratio=0.0,
            max_train_episodes=None
            ):
        """
        Args:
            zarr_path: path to the zarr store with the demonstrations.
            horizon: length T of each sampled sequence.
            pad_before/pad_after: steps of padding allowed at episode edges.
            obs_key/state_key/action_key: array names in the replay buffer.
            seed: RNG seed for splitting and downsampling.
            val_ratio: fraction of episodes held out for validation.
            max_train_episodes: if set, randomly keep at most this many
                training episodes (for data-efficiency experiments).
        """
        super().__init__()
        self.replay_buffer = ReplayBuffer.copy_from_path(
            zarr_path, keys=[obs_key, state_key, action_key])

        # Episode-level train/validation split, then optional downsampling
        # of the training episodes.
        val_mask = get_val_mask(
            n_episodes=self.replay_buffer.n_episodes,
            val_ratio=val_ratio,
            seed=seed)
        train_mask = ~val_mask
        train_mask = downsample_mask(
            mask=train_mask,
            max_n=max_train_episodes,
            seed=seed)

        self.sampler = SequenceSampler(
            replay_buffer=self.replay_buffer,
            sequence_length=horizon,
            pad_before=pad_before,
            pad_after=pad_after,
            episode_mask=train_mask
        )
        self.obs_key = obs_key
        self.state_key = state_key
        self.action_key = action_key
        self.train_mask = train_mask
        self.horizon = horizon
        self.pad_before = pad_before
        self.pad_after = pad_after

    def get_validation_dataset(self):
        """Return a shallow copy that samples from the held-out episodes."""
        val_set = copy.copy(self)
        val_set.sampler = SequenceSampler(
            replay_buffer=self.replay_buffer,
            sequence_length=self.horizon,
            pad_before=self.pad_before,
            pad_after=self.pad_after,
            episode_mask=~self.train_mask
        )
        val_set.train_mask = ~self.train_mask
        return val_set

    def get_normalizer(self, mode='limits', **kwargs):
        """Fit a LinearNormalizer on the full buffer (obs and action)."""
        data = self._sample_to_data(self.replay_buffer)
        normalizer = LinearNormalizer()
        normalizer.fit(data=data, last_n_dims=1, mode=mode, **kwargs)
        return normalizer

    def get_all_actions(self) -> torch.Tensor:
        """Return all actions in the buffer as one (N, Da) tensor."""
        return torch.from_numpy(self.replay_buffer[self.action_key])

    def __len__(self) -> int:
        return len(self.sampler)

    def _sample_to_data(self, sample):
        # Flatten per-step keypoints and append the agent's 2D position.
        keypoint = sample[self.obs_key]
        state = sample[self.state_key]
        # First two state dims are the agent position.
        agent_pos = state[:,:2]
        obs = np.concatenate([
            keypoint.reshape(keypoint.shape[0], -1),
            agent_pos], axis=-1)

        data = {
            'obs': obs, # T, D_o
            'action': sample[self.action_key], # T, D_a
        }
        return data

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Return one training window as torch tensors (obs: T,Do; action: T,Da)."""
        sample = self.sampler.sample_sequence(idx)
        data = self._sample_to_data(sample)

        torch_data = dict_apply(data, torch.from_numpy)
        return torch_data
|
||||
102
diffusion_policy/dataset/pusht_image_dataset.py
Normal file
102
diffusion_policy/dataset/pusht_image_dataset.py
Normal file
@@ -0,0 +1,102 @@
|
||||
from typing import Dict
|
||||
import torch
|
||||
import numpy as np
|
||||
import copy
|
||||
from diffusion_policy.common.pytorch_util import dict_apply
|
||||
from diffusion_policy.common.replay_buffer import ReplayBuffer
|
||||
from diffusion_policy.common.sampler import (
|
||||
SequenceSampler, get_val_mask, downsample_mask)
|
||||
from diffusion_policy.model.common.normalizer import LinearNormalizer
|
||||
from diffusion_policy.dataset.base_dataset import BaseImageDataset
|
||||
from diffusion_policy.common.normalize_util import get_image_range_normalizer
|
||||
|
||||
class PushTImageDataset(BaseImageDataset):
    """Image-based Push-T dataset backed by a zarr-stored ReplayBuffer.

    Each sample contains a CHW image sequence, the 2D agent position, and
    the 2D action.
    """

    def __init__(self,
            zarr_path,
            horizon=1,
            pad_before=0,
            pad_after=0,
            seed=42,
            val_ratio=0.0,
            max_train_episodes=None
            ):
        """
        Args:
            zarr_path: path to the zarr store with 'img', 'state', 'action'.
            horizon: length T of each sampled sequence.
            pad_before/pad_after: steps of padding allowed at episode edges.
            seed: RNG seed for splitting and downsampling.
            val_ratio: fraction of episodes held out for validation.
            max_train_episodes: if set, randomly keep at most this many
                training episodes.
        """

        super().__init__()
        self.replay_buffer = ReplayBuffer.copy_from_path(
            zarr_path, keys=['img', 'state', 'action'])
        # Episode-level train/validation split, then optional downsampling.
        val_mask = get_val_mask(
            n_episodes=self.replay_buffer.n_episodes,
            val_ratio=val_ratio,
            seed=seed)
        train_mask = ~val_mask
        train_mask = downsample_mask(
            mask=train_mask,
            max_n=max_train_episodes,
            seed=seed)

        self.sampler = SequenceSampler(
            replay_buffer=self.replay_buffer,
            sequence_length=horizon,
            pad_before=pad_before,
            pad_after=pad_after,
            episode_mask=train_mask)
        self.train_mask = train_mask
        self.horizon = horizon
        self.pad_before = pad_before
        self.pad_after = pad_after

    def get_validation_dataset(self):
        """Return a shallow copy that samples from the held-out episodes."""
        val_set = copy.copy(self)
        val_set.sampler = SequenceSampler(
            replay_buffer=self.replay_buffer,
            sequence_length=self.horizon,
            pad_before=self.pad_before,
            pad_after=self.pad_after,
            episode_mask=~self.train_mask
        )
        val_set.train_mask = ~self.train_mask
        return val_set

    def get_normalizer(self, mode='limits', **kwargs):
        """Fit normalizers for action and agent_pos; images use a fixed
        [0,1]-range normalizer."""
        data = {
            'action': self.replay_buffer['action'],
            'agent_pos': self.replay_buffer['state'][...,:2]
        }
        normalizer = LinearNormalizer()
        normalizer.fit(data=data, last_n_dims=1, mode=mode, **kwargs)
        normalizer['image'] = get_image_range_normalizer()
        return normalizer

    def __len__(self) -> int:
        return len(self.sampler)

    def _sample_to_data(self, sample):
        # First two state dims are the agent position.
        agent_pos = sample['state'][:,:2].astype(np.float32) # (agent_posx2, block_posex3)
        # HWC uint8 -> CHW in [0,1].
        # NOTE(review): /255 promotes to float64 here (cast to float32
        # happens later via torch) — confirm this is acceptable.
        image = np.moveaxis(sample['img'],-1,1)/255

        data = {
            'obs': {
                'image': image, # T, 3, 96, 96
                'agent_pos': agent_pos, # T, 2
            },
            'action': sample['action'].astype(np.float32) # T, 2
        }
        return data

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Return one training window as nested torch tensors."""
        sample = self.sampler.sample_sequence(idx)
        data = self._sample_to_data(sample)
        torch_data = dict_apply(data, torch.from_numpy)
        return torch_data
|
||||
|
||||
|
||||
def test():
    """Ad-hoc smoke test: load the Push-T image dataset from a local path.

    Requires the pusht zarr data to exist on disk; not run automatically.
    """
    import os
    zarr_path = os.path.expanduser('~/dev/diffusion_policy/data/pusht/pusht_cchi_v7_replay.zarr')
    dataset = PushTImageDataset(zarr_path, horizon=16)

    # Left for interactive debugging: visualize normalized action velocity.
    # from matplotlib import pyplot as plt
    # normalizer = dataset.get_normalizer()
    # nactions = normalizer['action'].normalize(dataset.replay_buffer['action'])
    # diff = np.diff(nactions, axis=0)
    # dists = np.linalg.norm(np.diff(nactions, axis=0), axis=-1)
|
||||
291
diffusion_policy/dataset/real_pusht_image_dataset.py
Normal file
291
diffusion_policy/dataset/real_pusht_image_dataset.py
Normal file
@@ -0,0 +1,291 @@
|
||||
from typing import Dict, List
|
||||
import torch
|
||||
import numpy as np
|
||||
import zarr
|
||||
import os
|
||||
import shutil
|
||||
from filelock import FileLock
|
||||
from threadpoolctl import threadpool_limits
|
||||
from omegaconf import OmegaConf
|
||||
import cv2
|
||||
import json
|
||||
import hashlib
|
||||
import copy
|
||||
from diffusion_policy.common.pytorch_util import dict_apply
|
||||
from diffusion_policy.dataset.base_dataset import BaseImageDataset
|
||||
from diffusion_policy.model.common.normalizer import LinearNormalizer, SingleFieldLinearNormalizer
|
||||
from diffusion_policy.common.replay_buffer import ReplayBuffer
|
||||
from diffusion_policy.common.sampler import (
|
||||
SequenceSampler, get_val_mask, downsample_mask)
|
||||
from diffusion_policy.real_world.real_data_conversion import real_data_to_replay_buffer
|
||||
from diffusion_policy.common.normalize_util import (
|
||||
get_range_normalizer_from_stat,
|
||||
get_image_range_normalizer,
|
||||
get_identity_normalizer_from_stat,
|
||||
array_to_stats
|
||||
)
|
||||
|
||||
class RealPushTImageDataset(BaseImageDataset):
    """Image dataset for real-robot Push-T demonstrations.

    Converts raw on-disk demonstration data into an in-memory ReplayBuffer
    (optionally cached as a zarr zip keyed by a hash of `shape_meta`), and
    serves fixed-length windows of image + low-dim observations and actions.
    """

    def __init__(self,
            shape_meta: dict,
            dataset_path: str,
            horizon=1,
            pad_before=0,
            pad_after=0,
            n_obs_steps=None,
            n_latency_steps=0,
            use_cache=False,
            seed=42,
            val_ratio=0.0,
            max_train_episodes=None,
            delta_action=False,
            ):
        """
        Args:
            shape_meta: config dict describing obs/action keys, types, shapes.
            dataset_path: directory with the raw demonstration data.
            horizon: length T of each action sequence.
            pad_before/pad_after: steps of padding allowed at episode edges.
            n_obs_steps: if set, only the first k obs steps of each window
                are returned (saves RAM).
            n_latency_steps: extra steps sampled and then dropped from the
                front of the action sequence to emulate real-robot latency.
            use_cache: cache the converted ReplayBuffer as a zarr zip.
            seed: RNG seed for splitting and downsampling.
            val_ratio: fraction of episodes held out for validation.
            max_train_episodes: optional cap on training episodes.
            delta_action: store actions as per-step differences instead of
                absolute positions.
        """
        assert os.path.isdir(dataset_path)

        replay_buffer = None
        if use_cache:
            # Fingerprint shape_meta so different obs/action configs get
            # separate cache files.
            shape_meta_json = json.dumps(OmegaConf.to_container(shape_meta), sort_keys=True)
            shape_meta_hash = hashlib.md5(shape_meta_json.encode('utf-8')).hexdigest()
            cache_zarr_path = os.path.join(dataset_path, shape_meta_hash + '.zarr.zip')
            cache_lock_path = cache_zarr_path + '.lock'
            print('Acquiring lock on cache.')
            # File lock serializes cache creation across processes.
            with FileLock(cache_lock_path):
                if not os.path.exists(cache_zarr_path):
                    # Cache miss: build the buffer in memory, then persist.
                    try:
                        print('Cache does not exist. Creating!')
                        replay_buffer = _get_replay_buffer(
                            dataset_path=dataset_path,
                            shape_meta=shape_meta,
                            store=zarr.MemoryStore()
                        )
                        print('Saving cache to disk.')
                        with zarr.ZipStore(cache_zarr_path) as zip_store:
                            replay_buffer.save_to_store(
                                store=zip_store
                            )
                    except Exception as e:
                        # Remove the partial cache before re-raising.
                        # NOTE(review): rmtree expects a directory; the
                        # cache is a .zip file — confirm cleanup works.
                        shutil.rmtree(cache_zarr_path)
                        raise e
                else:
                    print('Loading cached ReplayBuffer from Disk.')
                    with zarr.ZipStore(cache_zarr_path, mode='r') as zip_store:
                        replay_buffer = ReplayBuffer.copy_from_store(
                            src_store=zip_store, store=zarr.MemoryStore())
                    print('Loaded!')
        else:
            replay_buffer = _get_replay_buffer(
                dataset_path=dataset_path,
                shape_meta=shape_meta,
                store=zarr.MemoryStore()
            )

        if delta_action:
            # Replace actions with per-step deltas, computed within each
            # episode so no delta crosses an episode boundary.
            actions = replay_buffer['action'][:]
            # Support positions only at this time.
            assert actions.shape[1] <= 3
            actions_diff = np.zeros_like(actions)
            episode_ends = replay_buffer.episode_ends[:]
            for i in range(len(episode_ends)):
                start = 0
                if i > 0:
                    start = episode_ends[i-1]
                end = episode_ends[i]
                # Delta action is the difference between the previous
                # desired position and the current one; it is scheduled at
                # the previous timestep for the current timestep, to stay
                # consistent with positional mode. First step of each
                # episode stays zero.
                actions_diff[start+1:end] = np.diff(actions[start:end], axis=0)
            replay_buffer['action'][:] = actions_diff

        # Partition obs keys by modality from shape_meta.
        rgb_keys = list()
        lowdim_keys = list()
        obs_shape_meta = shape_meta['obs']
        for key, attr in obs_shape_meta.items():
            type = attr.get('type', 'low_dim')
            if type == 'rgb':
                rgb_keys.append(key)
            elif type == 'low_dim':
                lowdim_keys.append(key)

        key_first_k = dict()
        if n_obs_steps is not None:
            # Only load the first k steps of each obs key per window.
            for key in rgb_keys + lowdim_keys:
                key_first_k[key] = n_obs_steps

        # Episode-level train/validation split, then optional downsampling.
        val_mask = get_val_mask(
            n_episodes=replay_buffer.n_episodes,
            val_ratio=val_ratio,
            seed=seed)
        train_mask = ~val_mask
        train_mask = downsample_mask(
            mask=train_mask,
            max_n=max_train_episodes,
            seed=seed)

        # Sample extra steps to cover the latency window.
        sampler = SequenceSampler(
            replay_buffer=replay_buffer,
            sequence_length=horizon+n_latency_steps,
            pad_before=pad_before,
            pad_after=pad_after,
            episode_mask=train_mask,
            key_first_k=key_first_k)

        self.replay_buffer = replay_buffer
        self.sampler = sampler
        self.shape_meta = shape_meta
        self.rgb_keys = rgb_keys
        self.lowdim_keys = lowdim_keys
        self.n_obs_steps = n_obs_steps
        self.val_mask = val_mask
        self.horizon = horizon
        self.n_latency_steps = n_latency_steps
        self.pad_before = pad_before
        self.pad_after = pad_after

    def get_validation_dataset(self):
        """Return a shallow copy that samples from the held-out episodes."""
        val_set = copy.copy(self)
        val_set.sampler = SequenceSampler(
            replay_buffer=self.replay_buffer,
            sequence_length=self.horizon+self.n_latency_steps,
            pad_before=self.pad_before,
            pad_after=self.pad_after,
            episode_mask=self.val_mask
        )
        # NOTE(review): this flips val_mask (unlike the train_mask flip in
        # the other dataset classes) — confirm the inversion is intended.
        val_set.val_mask = ~self.val_mask
        return val_set

    def get_normalizer(self, **kwargs) -> LinearNormalizer:
        """Fit per-field normalizers; images use a fixed range normalizer."""
        normalizer = LinearNormalizer()

        # action
        normalizer['action'] = SingleFieldLinearNormalizer.create_fit(
            self.replay_buffer['action'])

        # obs
        for key in self.lowdim_keys:
            normalizer[key] = SingleFieldLinearNormalizer.create_fit(
                self.replay_buffer[key])

        # image
        for key in self.rgb_keys:
            normalizer[key] = get_image_range_normalizer()
        return normalizer

    def get_all_actions(self) -> torch.Tensor:
        """Return all actions in the buffer as one (N, Da) tensor."""
        return torch.from_numpy(self.replay_buffer['action'])

    def __len__(self):
        return len(self.sampler)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Return one window: nested obs tensors plus the action sequence."""
        # Limit BLAS threads inside DataLoader workers.
        threadpool_limits(1)
        data = self.sampler.sample_sequence(idx)

        # to save RAM, only return first n_obs_steps of OBS
        # since the rest will be discarded anyway.
        # when self.n_obs_steps is None
        # this slice does nothing (takes all)
        T_slice = slice(self.n_obs_steps)

        obs_dict = dict()
        for key in self.rgb_keys:
            # move channel last to channel first
            # T,H,W,C
            # convert uint8 image to float32
            obs_dict[key] = np.moveaxis(data[key][T_slice],-1,1
                ).astype(np.float32) / 255.
            # T,C,H,W
            # save ram
            del data[key]
        for key in self.lowdim_keys:
            obs_dict[key] = data[key][T_slice].astype(np.float32)
            # save ram
            del data[key]

        action = data['action'].astype(np.float32)
        # handle latency by dropping first n_latency_steps action
        # observations are already taken care of by T_slice
        if self.n_latency_steps > 0:
            action = action[self.n_latency_steps:]

        torch_data = {
            'obs': dict_apply(obs_dict, torch.from_numpy),
            'action': torch.from_numpy(action)
        }
        return torch_data
|
||||
|
||||
def zarr_resize_index_last_dim(zarr_arr, idxs):
    """Shrink a zarr array along its last dimension to the given indices.

    Reads the whole array into memory, keeps only the entries of the last
    axis listed in ``idxs`` (in that order), resizes the backing array to
    the new width, and writes the selection back.

    Returns the same (now resized) array for convenience.
    """
    # Materialize, then select the requested entries of the last axis.
    selected = zarr_arr[:][..., idxs]
    # Resize the backing store before writing the narrower data back.
    new_shape = zarr_arr.shape[:-1] + (len(idxs),)
    zarr_arr.resize(new_shape)
    zarr_arr[:] = selected
    return zarr_arr
|
||||
|
||||
def _get_replay_buffer(dataset_path, shape_meta, store):
    """Convert raw real-robot data into a ReplayBuffer in `store`.

    Uses `shape_meta` to decide which keys are images (and their output
    resolution) vs low-dim, then trims 6-D poses/actions down to (x, y)
    when the configured shape is 2-D.
    """
    # parse shape meta
    rgb_keys = list()
    lowdim_keys = list()
    out_resolutions = dict()
    lowdim_shapes = dict()
    obs_shape_meta = shape_meta['obs']
    for key, attr in obs_shape_meta.items():
        type = attr.get('type', 'low_dim')
        shape = tuple(attr.get('shape'))
        if type == 'rgb':
            rgb_keys.append(key)
            # shape_meta stores CHW; the converter wants (width, height).
            c,h,w = shape
            out_resolutions[key] = (w,h)
        elif type == 'low_dim':
            lowdim_keys.append(key)
            lowdim_shapes[key] = tuple(shape)
            if 'pose' in key:
                # Only planar (x, y) or full 6-DoF poses are supported.
                assert tuple(shape) in [(2,),(6,)]

    action_shape = tuple(shape_meta['action']['shape'])
    assert action_shape in [(2,),(6,)]

    # load data
    # Restrict OpenCV/BLAS threading during the heavy conversion.
    cv2.setNumThreads(1)
    with threadpool_limits(1):
        replay_buffer = real_data_to_replay_buffer(
            dataset_path=dataset_path,
            out_store=store,
            out_resolutions=out_resolutions,
            lowdim_keys=lowdim_keys + ['action'],
            image_keys=rgb_keys
        )

    # transform lowdim dimensions
    if action_shape == (2,):
        # 2D action space, only controls X and Y
        zarr_arr = replay_buffer['action']
        zarr_resize_index_last_dim(zarr_arr, idxs=[0,1])

    for key, shape in lowdim_shapes.items():
        if 'pose' in key and shape == (2,):
            # only take X and Y
            zarr_arr = replay_buffer[key]
            zarr_resize_index_last_dim(zarr_arr, idxs=[0,1])

    return replay_buffer
|
||||
|
||||
|
||||
def test():
    """Ad-hoc smoke test: build the real dataset from the hydra config and
    plot the normalized action velocity distribution.

    Requires the real-robot data and hydra config on disk; not run
    automatically.
    """
    import hydra
    from omegaconf import OmegaConf
    OmegaConf.register_new_resolver("eval", eval, replace=True)

    with hydra.initialize('../diffusion_policy/config'):
        cfg = hydra.compose('train_robomimic_real_image_workspace')
        OmegaConf.resolve(cfg)
        dataset = hydra.utils.instantiate(cfg.task.dataset)

    from matplotlib import pyplot as plt
    normalizer = dataset.get_normalizer()
    nactions = normalizer['action'].normalize(dataset.replay_buffer['action'][:])
    diff = np.diff(nactions, axis=0)
    dists = np.linalg.norm(np.diff(nactions, axis=0), axis=-1)
    _ = plt.hist(dists, bins=100); plt.title('real action velocity')
|
||||
373
diffusion_policy/dataset/robomimic_replay_image_dataset.py
Normal file
373
diffusion_policy/dataset/robomimic_replay_image_dataset.py
Normal file
@@ -0,0 +1,373 @@
|
||||
from typing import Dict, List
|
||||
import torch
|
||||
import numpy as np
|
||||
import h5py
|
||||
from tqdm import tqdm
|
||||
import zarr
|
||||
import os
|
||||
import shutil
|
||||
import copy
|
||||
import json
|
||||
import hashlib
|
||||
from filelock import FileLock
|
||||
from threadpoolctl import threadpool_limits
|
||||
import concurrent.futures
|
||||
import multiprocessing
|
||||
from omegaconf import OmegaConf
|
||||
from diffusion_policy.common.pytorch_util import dict_apply
|
||||
from diffusion_policy.dataset.base_dataset import BaseImageDataset, LinearNormalizer
|
||||
from diffusion_policy.model.common.normalizer import LinearNormalizer, SingleFieldLinearNormalizer
|
||||
from diffusion_policy.model.common.rotation_transformer import RotationTransformer
|
||||
from diffusion_policy.codecs.imagecodecs_numcodecs import register_codecs, Jpeg2k
|
||||
from diffusion_policy.common.replay_buffer import ReplayBuffer
|
||||
from diffusion_policy.common.sampler import SequenceSampler, get_val_mask
|
||||
from diffusion_policy.common.normalize_util import (
|
||||
robomimic_abs_action_only_normalizer_from_stat,
|
||||
robomimic_abs_action_only_dual_arm_normalizer_from_stat,
|
||||
get_range_normalizer_from_stat,
|
||||
get_image_range_normalizer,
|
||||
get_identity_normalizer_from_stat,
|
||||
array_to_stats
|
||||
)
|
||||
# Register the custom numcodecs codecs (e.g. Jpeg2k) at import time so the
# zarr arrays created below can encode/decode compressed image chunks.
register_codecs()
|
||||
|
||||
class RobomimicReplayImageDataset(BaseImageDataset):
    """Image-observation dataset backed by a robomimic hdf5 file.

    The hdf5 data is converted once into an in-memory zarr ReplayBuffer
    (optionally cached on disk as a ``<dataset_path>.zarr.zip``) and then
    sampled as fixed-length, optionally padded sequences.
    """

    def __init__(self,
            shape_meta: dict,
            dataset_path: str,
            horizon=1,
            pad_before=0,
            pad_after=0,
            n_obs_steps=None,
            abs_action=False,
            rotation_rep='rotation_6d', # ignored when abs_action=False
            use_legacy_normalizer=False,
            use_cache=False,
            seed=42,
            val_ratio=0.0
        ):
        """
        Args:
            shape_meta: dict with 'obs' (per-key shape and 'rgb'/'low_dim'
                type) and 'action' shape entries.
            dataset_path: path to the robomimic hdf5 file.
            horizon: sampled sequence length.
            pad_before / pad_after: episode boundary padding.
            n_obs_steps: if given, only the first k observation steps of each
                sample are loaded/returned (saves RAM and decode time).
            abs_action: actions are absolute poses; rotations are re-encoded
                with rotation_rep.
            use_legacy_normalizer: use the single-scale legacy action
                normalizer (abs_action only).
            use_cache: cache the converted ReplayBuffer as a .zarr.zip file.
            seed / val_ratio: episode-level train/val split parameters.
        """
        rotation_transformer = RotationTransformer(
            from_rep='axis_angle', to_rep=rotation_rep)

        replay_buffer = None
        if use_cache:
            cache_zarr_path = dataset_path + '.zarr.zip'
            cache_lock_path = cache_zarr_path + '.lock'
            print('Acquiring lock on cache.')
            # File lock so concurrent training runs don't build the cache twice.
            with FileLock(cache_lock_path):
                if not os.path.exists(cache_zarr_path):
                    # cache does not exist yet: convert hdf5 and save the cache
                    try:
                        print('Cache does not exist. Creating!')
                        replay_buffer = _convert_robomimic_to_replay(
                            store=zarr.MemoryStore(),
                            shape_meta=shape_meta,
                            dataset_path=dataset_path,
                            abs_action=abs_action,
                            rotation_transformer=rotation_transformer)
                        print('Saving cache to disk.')
                        with zarr.ZipStore(cache_zarr_path) as zip_store:
                            replay_buffer.save_to_store(
                                store=zip_store
                            )
                    except Exception:
                        # BUGFIX: the cache is a single zip *file*, so
                        # shutil.rmtree would raise NotADirectoryError here
                        # (and raise FileNotFoundError if the failure happened
                        # before the file was written), masking the original
                        # error. Remove the partial file if present and
                        # re-raise with the original traceback.
                        if os.path.exists(cache_zarr_path):
                            os.remove(cache_zarr_path)
                        raise
                else:
                    print('Loading cached ReplayBuffer from Disk.')
                    with zarr.ZipStore(cache_zarr_path, mode='r') as zip_store:
                        replay_buffer = ReplayBuffer.copy_from_store(
                            src_store=zip_store, store=zarr.MemoryStore())
                    print('Loaded!')
        else:
            replay_buffer = _convert_robomimic_to_replay(
                store=zarr.MemoryStore(),
                shape_meta=shape_meta,
                dataset_path=dataset_path,
                abs_action=abs_action,
                rotation_transformer=rotation_transformer)

        # Partition observation keys by modality.
        rgb_keys = list()
        lowdim_keys = list()
        obs_shape_meta = shape_meta['obs']
        for key, attr in obs_shape_meta.items():
            # renamed from `type` to avoid shadowing the builtin
            obs_type = attr.get('type', 'low_dim')
            if obs_type == 'rgb':
                rgb_keys.append(key)
            elif obs_type == 'low_dim':
                lowdim_keys.append(key)

        key_first_k = dict()
        if n_obs_steps is not None:
            # only take first k obs from images
            for key in rgb_keys + lowdim_keys:
                key_first_k[key] = n_obs_steps

        # Episode-level train/val split.
        val_mask = get_val_mask(
            n_episodes=replay_buffer.n_episodes,
            val_ratio=val_ratio,
            seed=seed)
        train_mask = ~val_mask
        sampler = SequenceSampler(
            replay_buffer=replay_buffer,
            sequence_length=horizon,
            pad_before=pad_before,
            pad_after=pad_after,
            episode_mask=train_mask,
            key_first_k=key_first_k)

        self.replay_buffer = replay_buffer
        self.sampler = sampler
        self.shape_meta = shape_meta
        self.rgb_keys = rgb_keys
        self.lowdim_keys = lowdim_keys
        self.abs_action = abs_action
        self.n_obs_steps = n_obs_steps
        self.train_mask = train_mask
        self.horizon = horizon
        self.pad_before = pad_before
        self.pad_after = pad_after
        self.use_legacy_normalizer = use_legacy_normalizer

    def get_validation_dataset(self):
        """Return a shallow copy that samples only the held-out val episodes."""
        val_set = copy.copy(self)
        val_set.sampler = SequenceSampler(
            replay_buffer=self.replay_buffer,
            sequence_length=self.horizon,
            pad_before=self.pad_before,
            pad_after=self.pad_after,
            episode_mask=~self.train_mask
        )
        val_set.train_mask = ~self.train_mask
        return val_set

    def get_normalizer(self, **kwargs) -> LinearNormalizer:
        """Fit per-key linear normalizers from replay-buffer statistics."""
        normalizer = LinearNormalizer()

        # action
        stat = array_to_stats(self.replay_buffer['action'])
        if self.abs_action:
            if stat['mean'].shape[-1] > 10:
                # dual arm
                this_normalizer = robomimic_abs_action_only_dual_arm_normalizer_from_stat(stat)
            else:
                this_normalizer = robomimic_abs_action_only_normalizer_from_stat(stat)

            if self.use_legacy_normalizer:
                this_normalizer = normalizer_from_stat(stat)
        else:
            # delta actions are already normalized
            this_normalizer = get_identity_normalizer_from_stat(stat)
        normalizer['action'] = this_normalizer

        # obs
        for key in self.lowdim_keys:
            stat = array_to_stats(self.replay_buffer[key])

            # NOTE: the 'qpos' branch below is shadowed by 'pos' (every key
            # ending in 'qpos' also ends in 'pos'); both branches use the
            # same range normalizer, so behavior is unaffected.
            if key.endswith('pos'):
                this_normalizer = get_range_normalizer_from_stat(stat)
            elif key.endswith('quat'):
                # quaternion is in [-1,1] already
                this_normalizer = get_identity_normalizer_from_stat(stat)
            elif key.endswith('qpos'):
                this_normalizer = get_range_normalizer_from_stat(stat)
            else:
                raise RuntimeError('unsupported')
            normalizer[key] = this_normalizer

        # image: fixed [0,1] -> [-1,1] style range normalizer
        for key in self.rgb_keys:
            normalizer[key] = get_image_range_normalizer()
        return normalizer

    def get_all_actions(self) -> torch.Tensor:
        """Return every action in the buffer as a single tensor."""
        return torch.from_numpy(self.replay_buffer['action'])

    def __len__(self):
        return len(self.sampler)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Sample one padded sequence and convert it to torch tensors."""
        # Avoid BLAS thread oversubscription inside DataLoader workers.
        threadpool_limits(1)
        data = self.sampler.sample_sequence(idx)

        # to save RAM, only return first n_obs_steps of OBS
        # since the rest will be discarded anyway.
        # when self.n_obs_steps is None
        # this slice does nothing (takes all)
        T_slice = slice(self.n_obs_steps)

        obs_dict = dict()
        for key in self.rgb_keys:
            # T,H,W,C uint8 -> T,C,H,W float32 in [0,1]
            obs_dict[key] = np.moveaxis(data[key][T_slice], -1, 1
                ).astype(np.float32) / 255.
            # free the raw frames as early as possible
            del data[key]
        for key in self.lowdim_keys:
            obs_dict[key] = data[key][T_slice].astype(np.float32)
            del data[key]

        torch_data = {
            'obs': dict_apply(obs_dict, torch.from_numpy),
            'action': torch.from_numpy(data['action'].astype(np.float32))
        }
        return torch_data
|
||||
|
||||
|
||||
def _convert_actions(raw_actions, abs_action, rotation_transformer):
|
||||
actions = raw_actions
|
||||
if abs_action:
|
||||
is_dual_arm = False
|
||||
if raw_actions.shape[-1] == 14:
|
||||
# dual arm
|
||||
raw_actions = raw_actions.reshape(-1,2,7)
|
||||
is_dual_arm = True
|
||||
|
||||
pos = raw_actions[...,:3]
|
||||
rot = raw_actions[...,3:6]
|
||||
gripper = raw_actions[...,6:]
|
||||
rot = rotation_transformer.forward(rot)
|
||||
raw_actions = np.concatenate([
|
||||
pos, rot, gripper
|
||||
], axis=-1).astype(np.float32)
|
||||
|
||||
if is_dual_arm:
|
||||
raw_actions = raw_actions.reshape(-1,20)
|
||||
actions = raw_actions
|
||||
return actions
|
||||
|
||||
|
||||
def _convert_robomimic_to_replay(store, shape_meta, dataset_path, abs_action, rotation_transformer,
        n_workers=None, max_inflight_tasks=None):
    """Convert a robomimic hdf5 dataset into a zarr-backed ReplayBuffer.

    Low-dim keys and actions are concatenated across episodes and stored
    uncompressed; image keys are copied frame-by-frame into Jpeg2k-compressed
    zarr arrays using a thread pool (one chunk per frame, so workers never
    write the same chunk).

    Args:
        store: zarr store that receives the converted data.
        shape_meta: dict with per-key 'obs' shapes/types and 'action' shape.
        dataset_path: path to the robomimic hdf5 file.
        abs_action: whether actions are absolute poses (see _convert_actions).
        rotation_transformer: rotation re-encoder used when abs_action.
        n_workers: image-copy thread count (default: cpu count).
        max_inflight_tasks: bound on queued image-copy futures
            (default: n_workers * 5).

    Raises:
        RuntimeError: if any image frame fails to encode/decode.
    """
    if n_workers is None:
        n_workers = multiprocessing.cpu_count()
    if max_inflight_tasks is None:
        max_inflight_tasks = n_workers * 5

    # parse shape_meta: partition observation keys by modality
    rgb_keys = list()
    lowdim_keys = list()
    # construct compressors and chunks
    obs_shape_meta = shape_meta['obs']
    for key, attr in obs_shape_meta.items():
        shape = attr['shape']
        type = attr.get('type', 'low_dim')
        if type == 'rgb':
            rgb_keys.append(key)
        elif type == 'low_dim':
            lowdim_keys.append(key)

    root = zarr.group(store)
    data_group = root.require_group('data', overwrite=True)
    meta_group = root.require_group('meta', overwrite=True)

    # NOTE(review): no explicit mode is passed to h5py.File; presumably
    # read-only access is intended (the default differs across h5py
    # versions) -- confirm.
    with h5py.File(dataset_path) as file:
        # count total steps and record per-episode end indices
        demos = file['data']
        episode_ends = list()
        prev_end = 0
        for i in range(len(demos)):
            demo = demos[f'demo_{i}']
            episode_length = demo['actions'].shape[0]
            episode_end = prev_end + episode_length
            prev_end = episode_end
            episode_ends.append(episode_end)
        n_steps = episode_ends[-1]
        episode_starts = [0] + episode_ends[:-1]
        _ = meta_group.array('episode_ends', episode_ends,
            dtype=np.int64, compressor=None, overwrite=True)

        # save lowdim data (and actions) uncompressed, one array per key
        for key in tqdm(lowdim_keys + ['action'], desc="Loading lowdim data"):
            data_key = 'obs/' + key
            if key == 'action':
                # actions live under 'actions' in robomimic files
                data_key = 'actions'
            this_data = list()
            for i in range(len(demos)):
                demo = demos[f'demo_{i}']
                this_data.append(demo[data_key][:].astype(np.float32))
            this_data = np.concatenate(this_data, axis=0)
            if key == 'action':
                this_data = _convert_actions(
                    raw_actions=this_data,
                    abs_action=abs_action,
                    rotation_transformer=rotation_transformer
                )
                assert this_data.shape == (n_steps,) + tuple(shape_meta['action']['shape'])
            else:
                assert this_data.shape == (n_steps,) + tuple(shape_meta['obs'][key]['shape'])
            _ = data_group.array(
                name=key,
                data=this_data,
                shape=this_data.shape,
                chunks=this_data.shape,
                compressor=None,
                dtype=this_data.dtype
            )

        def img_copy(zarr_arr, zarr_idx, hdf5_arr, hdf5_idx):
            # Copy a single frame and verify it survives a codec round-trip;
            # returns False on any failure so the caller can raise once.
            try:
                zarr_arr[zarr_idx] = hdf5_arr[hdf5_idx]
                # make sure we can successfully decode
                _ = zarr_arr[zarr_idx]
                return True
            except Exception as e:
                return False

        with tqdm(total=n_steps*len(rgb_keys), desc="Loading image data", mininterval=1.0) as pbar:
            # one chunk per thread, therefore no synchronization needed
            with concurrent.futures.ThreadPoolExecutor(max_workers=n_workers) as executor:
                futures = set()
                for key in rgb_keys:
                    data_key = 'obs/' + key
                    # shape_meta stores images channel-first (c,h,w) ...
                    shape = tuple(shape_meta['obs'][key]['shape'])
                    c,h,w = shape
                    this_compressor = Jpeg2k(level=50)
                    # ... but frames are stored channel-last, one chunk per frame
                    img_arr = data_group.require_dataset(
                        name=key,
                        shape=(n_steps,h,w,c),
                        chunks=(1,h,w,c),
                        compressor=this_compressor,
                        dtype=np.uint8
                    )
                    for episode_idx in range(len(demos)):
                        demo = demos[f'demo_{episode_idx}']
                        hdf5_arr = demo['obs'][key]
                        for hdf5_idx in range(hdf5_arr.shape[0]):
                            if len(futures) >= max_inflight_tasks:
                                # limit number of inflight tasks
                                completed, futures = concurrent.futures.wait(futures,
                                    return_when=concurrent.futures.FIRST_COMPLETED)
                                for f in completed:
                                    if not f.result():
                                        raise RuntimeError('Failed to encode image!')
                                pbar.update(len(completed))

                            zarr_idx = episode_starts[episode_idx] + hdf5_idx
                            futures.add(
                                executor.submit(img_copy,
                                    img_arr, zarr_idx, hdf5_arr, hdf5_idx))
                # drain the remaining in-flight copies
                completed, futures = concurrent.futures.wait(futures)
                for f in completed:
                    if not f.result():
                        raise RuntimeError('Failed to encode image!')
                pbar.update(len(completed))

    replay_buffer = ReplayBuffer(root)
    return replay_buffer
|
||||
|
||||
def normalizer_from_stat(stat):
    """Legacy normalizer: scale every dimension by one shared factor.

    The factor is 1 / (largest absolute value across all dimensions), so the
    data maps into [-1, 1] without any offset.
    """
    peak = np.maximum(stat['max'].max(), np.abs(stat['min']).max())
    return SingleFieldLinearNormalizer.create_manual(
        scale=np.full_like(stat['max'], fill_value=1/peak),
        offset=np.zeros_like(stat['max']),
        input_stats_dict=stat
    )
|
||||
168
diffusion_policy/dataset/robomimic_replay_lowdim_dataset.py
Normal file
168
diffusion_policy/dataset/robomimic_replay_lowdim_dataset.py
Normal file
@@ -0,0 +1,168 @@
|
||||
from typing import Dict, List
|
||||
import torch
|
||||
import numpy as np
|
||||
import h5py
|
||||
from tqdm import tqdm
|
||||
import copy
|
||||
from diffusion_policy.common.pytorch_util import dict_apply
|
||||
from diffusion_policy.dataset.base_dataset import BaseLowdimDataset, LinearNormalizer
|
||||
from diffusion_policy.model.common.normalizer import LinearNormalizer, SingleFieldLinearNormalizer
|
||||
from diffusion_policy.model.common.rotation_transformer import RotationTransformer
|
||||
from diffusion_policy.common.replay_buffer import ReplayBuffer
|
||||
from diffusion_policy.common.sampler import (
|
||||
SequenceSampler, get_val_mask, downsample_mask)
|
||||
from diffusion_policy.common.normalize_util import (
|
||||
robomimic_abs_action_only_normalizer_from_stat,
|
||||
robomimic_abs_action_only_dual_arm_normalizer_from_stat,
|
||||
get_identity_normalizer_from_stat,
|
||||
array_to_stats
|
||||
)
|
||||
|
||||
class RobomimicReplayLowdimDataset(BaseLowdimDataset):
    """Low-dimensional robomimic dataset loaded fully into memory.

    Each episode's observation is the float32 concatenation of ``obs_keys``
    along the last axis; fixed-length, optionally padded sequences are
    sampled from an in-memory ReplayBuffer.
    """

    def __init__(self,
            dataset_path: str,
            horizon=1,
            pad_before=0,
            pad_after=0,
            obs_keys: List[str]=[
                'object',
                'robot0_eef_pos',
                'robot0_eef_quat',
                'robot0_gripper_qpos'],
            abs_action=False,
            rotation_rep='rotation_6d',
            use_legacy_normalizer=False,
            seed=42,
            val_ratio=0.0,
            max_train_episodes=None
        ):
        """
        Args:
            dataset_path: path to the robomimic hdf5 file.
            horizon: sampled sequence length.
            pad_before / pad_after: episode boundary padding.
            obs_keys: observation keys concatenated (in order) into 'obs'.
            abs_action: actions are absolute poses; rotations re-encoded
                with rotation_rep.
            rotation_rep: target rotation representation for abs actions.
            use_legacy_normalizer: use the single-scale legacy action
                normalizer (abs_action only).
            seed / val_ratio: episode-level train/val split parameters.
            max_train_episodes: optionally subsample the train split.
        """
        # Copy immediately so the (mutable) default list is never aliased.
        obs_keys = list(obs_keys)
        rotation_transformer = RotationTransformer(
            from_rep='axis_angle', to_rep=rotation_rep)

        replay_buffer = ReplayBuffer.create_empty_numpy()
        # BUGFIX: open read-only explicitly; h5py.File's default mode has
        # varied across h5py versions and could otherwise open the dataset
        # writable.
        with h5py.File(dataset_path, 'r') as file:
            demos = file['data']
            for i in tqdm(range(len(demos)), desc="Loading hdf5 to ReplayBuffer"):
                demo = demos[f'demo_{i}']
                episode = _data_to_obs(
                    raw_obs=demo['obs'],
                    raw_actions=demo['actions'][:].astype(np.float32),
                    obs_keys=obs_keys,
                    abs_action=abs_action,
                    rotation_transformer=rotation_transformer)
                replay_buffer.add_episode(episode)

        # Episode-level train/val split, optionally subsampling train.
        val_mask = get_val_mask(
            n_episodes=replay_buffer.n_episodes,
            val_ratio=val_ratio,
            seed=seed)
        train_mask = ~val_mask
        train_mask = downsample_mask(
            mask=train_mask,
            max_n=max_train_episodes,
            seed=seed)

        sampler = SequenceSampler(
            replay_buffer=replay_buffer,
            sequence_length=horizon,
            pad_before=pad_before,
            pad_after=pad_after,
            episode_mask=train_mask)

        self.replay_buffer = replay_buffer
        self.sampler = sampler
        self.abs_action = abs_action
        self.train_mask = train_mask
        self.horizon = horizon
        self.pad_before = pad_before
        self.pad_after = pad_after
        self.use_legacy_normalizer = use_legacy_normalizer

    def get_validation_dataset(self):
        """Return a shallow copy that samples only the held-out val episodes."""
        val_set = copy.copy(self)
        val_set.sampler = SequenceSampler(
            replay_buffer=self.replay_buffer,
            sequence_length=self.horizon,
            pad_before=self.pad_before,
            pad_after=self.pad_after,
            episode_mask=~self.train_mask
        )
        val_set.train_mask = ~self.train_mask
        return val_set

    def get_normalizer(self, **kwargs) -> LinearNormalizer:
        """Fit linear normalizers for 'action' and the concatenated 'obs'."""
        normalizer = LinearNormalizer()

        # action
        stat = array_to_stats(self.replay_buffer['action'])
        if self.abs_action:
            if stat['mean'].shape[-1] > 10:
                # dual arm
                this_normalizer = robomimic_abs_action_only_dual_arm_normalizer_from_stat(stat)
            else:
                this_normalizer = robomimic_abs_action_only_normalizer_from_stat(stat)

            if self.use_legacy_normalizer:
                this_normalizer = normalizer_from_stat(stat)
        else:
            # delta actions are already normalized
            this_normalizer = get_identity_normalizer_from_stat(stat)
        normalizer['action'] = this_normalizer

        # aggregate obs stats: one shared scale across all obs dimensions
        obs_stat = array_to_stats(self.replay_buffer['obs'])
        normalizer['obs'] = normalizer_from_stat(obs_stat)
        return normalizer

    def get_all_actions(self) -> torch.Tensor:
        """Return every action in the buffer as a single tensor."""
        return torch.from_numpy(self.replay_buffer['action'])

    def __len__(self):
        return len(self.sampler)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Sample one padded sequence and convert all arrays to tensors."""
        data = self.sampler.sample_sequence(idx)
        torch_data = dict_apply(data, torch.from_numpy)
        return torch_data
|
||||
|
||||
def normalizer_from_stat(stat):
    """Single-scale normalizer used for aggregate obs (and legacy actions).

    Every dimension is divided by the same factor, the largest absolute
    value seen anywhere in the stats, mapping data into [-1, 1] with no
    offset.
    """
    peak = np.maximum(stat['max'].max(), np.abs(stat['min']).max())
    return SingleFieldLinearNormalizer.create_manual(
        scale=np.full_like(stat['max'], fill_value=1/peak),
        offset=np.zeros_like(stat['max']),
        input_stats_dict=stat
    )
|
||||
|
||||
def _data_to_obs(raw_obs, raw_actions, obs_keys, abs_action, rotation_transformer):
|
||||
obs = np.concatenate([
|
||||
raw_obs[key] for key in obs_keys
|
||||
], axis=-1).astype(np.float32)
|
||||
|
||||
if abs_action:
|
||||
is_dual_arm = False
|
||||
if raw_actions.shape[-1] == 14:
|
||||
# dual arm
|
||||
raw_actions = raw_actions.reshape(-1,2,7)
|
||||
is_dual_arm = True
|
||||
|
||||
pos = raw_actions[...,:3]
|
||||
rot = raw_actions[...,3:6]
|
||||
gripper = raw_actions[...,6:]
|
||||
rot = rotation_transformer.forward(rot)
|
||||
raw_actions = np.concatenate([
|
||||
pos, rot, gripper
|
||||
], axis=-1).astype(np.float32)
|
||||
|
||||
if is_dual_arm:
|
||||
raw_actions = raw_actions.reshape(-1,20)
|
||||
|
||||
data = {
|
||||
'obs': obs,
|
||||
'action': raw_actions
|
||||
}
|
||||
return data
|
||||
Reference in New Issue
Block a user