chore(pusht): add 5090 repro docs and uv setup
This commit is contained in:
@@ -8,8 +8,7 @@ import dill
|
||||
import math
|
||||
import wandb.sdk.data_types.video as wv
|
||||
from diffusion_policy.env.pusht.pusht_image_env import PushTImageEnv
|
||||
from diffusion_policy.gym_util.async_vector_env import AsyncVectorEnv
|
||||
# from diffusion_policy.gym_util.sync_vector_env import SyncVectorEnv
|
||||
from diffusion_policy.gym_util.sync_vector_env import SyncVectorEnv
|
||||
from diffusion_policy.gym_util.multistep_wrapper import MultiStepWrapper
|
||||
from diffusion_policy.gym_util.video_recording_wrapper import VideoRecordingWrapper, VideoRecorder
|
||||
|
||||
@@ -121,7 +120,9 @@ class PushTImageRunner(BaseImageRunner):
|
||||
env_prefixs.append('test/')
|
||||
env_init_fn_dills.append(dill.dumps(init_fn))
|
||||
|
||||
env = AsyncVectorEnv(env_fns)
|
||||
# This environment can run without multiprocessing, which avoids
|
||||
# shared-memory and semaphore restrictions on some machines.
|
||||
env = SyncVectorEnv(env_fns)
|
||||
|
||||
# test env
|
||||
# env.reset(seed=env_seeds)
|
||||
|
||||
@@ -60,17 +60,44 @@ class SyncVectorEnv(VectorEnv):
|
||||
for env, seed in zip(self.envs, seeds):
|
||||
env.seed(seed)
|
||||
|
||||
def reset_wait(self):
|
||||
def reset_async(self, seed=None, return_info=False, options=None):
    """Stage the arguments for the next ``reset_wait`` call.

    Nothing is reset here; the per-environment seeds, the ``return_info``
    flag and ``options`` are only recorded on the instance.

    Args:
        seed: ``None`` (no re-seeding), a single integer (environment ``i``
            gets ``seed + i``), or an iterable with one seed per
            sub-environment.
        return_info: Stored as ``self._reset_return_info``.
        options: Stored unchanged as ``self._reset_options``.

    Raises:
        ValueError: If the number of seeds differs from ``self.num_envs``.
    """
    # Function-scope import so this method does not depend on the module's
    # import block.
    import numbers

    if seed is None:
        seeds = [None] * self.num_envs
    elif isinstance(seed, numbers.Integral):
        # numbers.Integral also matches NumPy integer scalars, which a plain
        # ``isinstance(seed, int)`` check would send down the iterable branch.
        base = int(seed)
        seeds = [base + offset for offset in range(self.num_envs)]
    else:
        seeds = list(seed)

    # Raise explicitly instead of asserting: asserts vanish under ``python -O``.
    if len(seeds) != self.num_envs:
        raise ValueError(
            f"expected {self.num_envs} seeds, got {len(seeds)}"
        )

    self._reset_seeds = seeds
    self._reset_return_info = return_info
    self._reset_options = options
|
||||
|
||||
def reset_wait(self, seed=None, return_info=False, options=None):
    """Reset every sub-environment sequentially and batch the observations.

    Seeds staged by a preceding ``reset_async`` call take precedence over
    the ``seed`` argument. Staged state is consumed exactly once, so it
    cannot leak into a later direct call to this method.

    Args:
        seed: Used only when no seeds were staged. ``None`` (no re-seeding),
            a single ``int`` (environment ``i`` gets ``seed + i``), or an
            iterable with one seed per sub-environment.
        return_info: When true (here or in the staged ``reset_async`` call),
            return ``(observations, infos)`` instead of just observations.
        options: Accepted for signature compatibility; not forwarded to the
            sub-environments.

    Returns:
        The batched observations (deep-copied when ``self.copy`` is set), or
        a tuple ``(observations, infos)`` where ``infos`` holds one empty
        dict per sub-environment.

    Raises:
        ValueError: If the number of seeds differs from ``self.num_envs``.
    """
    # Pull out whatever reset_async staged, then clear it so stale seeds are
    # never silently reused by a later reset.
    seeds = getattr(self, '_reset_seeds', None)
    staged_return_info = getattr(self, '_reset_return_info', False)
    self._reset_seeds = None
    self._reset_return_info = False
    self._reset_options = None

    if seeds is None:
        if seed is None:
            seeds = [None] * self.num_envs
        elif isinstance(seed, int):
            seeds = [seed + offset for offset in range(self.num_envs)]
        else:
            seeds = list(seed)

    # zip() below would silently skip environments on a length mismatch.
    if len(seeds) != self.num_envs:
        raise ValueError(
            f"expected {self.num_envs} seeds, got {len(seeds)}"
        )

    want_info = bool(return_info or staged_return_info)

    self._dones[:] = False
    observations = []
    infos = []
    for env, seed_i in zip(self.envs, seeds):
        if seed_i is not None:
            env.seed(seed_i)
        observations.append(env.reset())
        # env.reset() is called without return_info, so there is no
        # per-environment info to report; keep the list shape consistent.
        infos.append({})

    self.observations = concatenate(
        self.single_observation_space, observations, self.observations
    )

    obs = deepcopy(self.observations) if self.copy else self.observations
    if want_info:
        return obs, infos
    return obs
|
||||
|
||||
def step_async(self, actions):
|
||||
self._actions = actions
|
||||
@@ -84,7 +111,7 @@ class SyncVectorEnv(VectorEnv):
|
||||
observations.append(observation)
|
||||
infos.append(info)
|
||||
self.observations = concatenate(
|
||||
observations, self.observations, self.single_observation_space
|
||||
self.single_observation_space, observations, self.observations
|
||||
)
|
||||
|
||||
return (
|
||||
|
||||
Reference in New Issue
Block a user