chore(pusht): add 5090 repro docs and uv setup

This commit is contained in:
Logic
2026-03-14 12:25:44 +08:00
parent 5ba07ac666
commit 08c1950c6d
6 changed files with 270 additions and 8 deletions

View File

@@ -8,8 +8,7 @@ import dill
import math
import wandb.sdk.data_types.video as wv
from diffusion_policy.env.pusht.pusht_image_env import PushTImageEnv
from diffusion_policy.gym_util.async_vector_env import AsyncVectorEnv
# from diffusion_policy.gym_util.sync_vector_env import SyncVectorEnv
from diffusion_policy.gym_util.sync_vector_env import SyncVectorEnv
from diffusion_policy.gym_util.multistep_wrapper import MultiStepWrapper
from diffusion_policy.gym_util.video_recording_wrapper import VideoRecordingWrapper, VideoRecorder
@@ -121,7 +120,9 @@ class PushTImageRunner(BaseImageRunner):
env_prefixs.append('test/')
env_init_fn_dills.append(dill.dumps(init_fn))
env = AsyncVectorEnv(env_fns)
# This environment can run without multiprocessing, which avoids
# shared-memory and semaphore restrictions on some machines.
env = SyncVectorEnv(env_fns)
# test env
# env.reset(seed=env_seeds)

View File

@@ -60,17 +60,44 @@ class SyncVectorEnv(VectorEnv):
for env, seed in zip(self.envs, seeds):
env.seed(seed)
def reset_wait(self):
def reset_async(self, seed=None, return_info=False, options=None):
if seed is None:
seeds = [None for _ in range(self.num_envs)]
elif isinstance(seed, int):
seeds = [seed + i for i in range(self.num_envs)]
else:
seeds = list(seed)
assert len(seeds) == self.num_envs
self._reset_seeds = seeds
self._reset_return_info = return_info
self._reset_options = options
def reset_wait(self, seed=None, return_info=False, options=None):
seeds = getattr(self, '_reset_seeds', None)
if seeds is None:
if seed is None:
seeds = [None for _ in range(self.num_envs)]
elif isinstance(seed, int):
seeds = [seed + i for i in range(self.num_envs)]
else:
seeds = list(seed)
self._dones[:] = False
observations = []
for env in self.envs:
infos = []
for env, seed_i in zip(self.envs, seeds):
if seed_i is not None:
env.seed(seed_i)
observation = env.reset()
observations.append(observation)
infos.append({})
self.observations = concatenate(
observations, self.observations, self.single_observation_space
self.single_observation_space, observations, self.observations
)
return deepcopy(self.observations) if self.copy else self.observations
obs = deepcopy(self.observations) if self.copy else self.observations
if return_info:
return obs, infos
return obs
def step_async(self, actions):
self._actions = actions
@@ -84,7 +111,7 @@ class SyncVectorEnv(VectorEnv):
observations.append(observation)
infos.append(info)
self.observations = concatenate(
observations, self.observations, self.single_observation_space
self.single_observation_space, observations, self.observations
)
return (