fix(pusht): stabilize DiT pusht training on current stack
This commit is contained in:
@@ -8,8 +8,7 @@ import dill
|
|||||||
import math
|
import math
|
||||||
import wandb.sdk.data_types.video as wv
|
import wandb.sdk.data_types.video as wv
|
||||||
from diffusion_policy.env.pusht.pusht_keypoints_env import PushTKeypointsEnv
|
from diffusion_policy.env.pusht.pusht_keypoints_env import PushTKeypointsEnv
|
||||||
from diffusion_policy.gym_util.async_vector_env import AsyncVectorEnv
|
from diffusion_policy.gym_util.sync_vector_env import SyncVectorEnv
|
||||||
# from diffusion_policy.gym_util.sync_vector_env import SyncVectorEnv
|
|
||||||
from diffusion_policy.gym_util.multistep_wrapper import MultiStepWrapper
|
from diffusion_policy.gym_util.multistep_wrapper import MultiStepWrapper
|
||||||
from diffusion_policy.gym_util.video_recording_wrapper import VideoRecordingWrapper, VideoRecorder
|
from diffusion_policy.gym_util.video_recording_wrapper import VideoRecordingWrapper, VideoRecorder
|
||||||
|
|
||||||
@@ -133,7 +132,7 @@ class PushTKeypointsRunner(BaseLowdimRunner):
|
|||||||
env_prefixs.append('test/')
|
env_prefixs.append('test/')
|
||||||
env_init_fn_dills.append(dill.dumps(init_fn))
|
env_init_fn_dills.append(dill.dumps(init_fn))
|
||||||
|
|
||||||
env = AsyncVectorEnv(env_fns)
|
env = SyncVectorEnv(env_fns)
|
||||||
|
|
||||||
# test env
|
# test env
|
||||||
# env.reset(seed=env_seeds)
|
# env.reset(seed=env_seeds)
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ pymunk==6.2.1
|
|||||||
wandb==0.13.3
|
wandb==0.13.3
|
||||||
threadpoolctl==3.1.0
|
threadpoolctl==3.1.0
|
||||||
shapely==1.8.5.post1
|
shapely==1.8.5.post1
|
||||||
|
matplotlib==3.6.1
|
||||||
imageio==2.22.0
|
imageio==2.22.0
|
||||||
imageio-ffmpeg==0.4.7
|
imageio-ffmpeg==0.4.7
|
||||||
termcolor==2.0.1
|
termcolor==2.0.1
|
||||||
|
|||||||
Reference in New Issue
Block a user