fix(pusht): stabilize DiT pusht training on current stack

This commit is contained in:
Logic
2026-03-15 18:54:50 +08:00
parent 08c1950c6d
commit 2aa06c8917
2 changed files with 3 additions and 3 deletions

View File

@@ -8,8 +8,7 @@ import dill
import math import math
import wandb.sdk.data_types.video as wv import wandb.sdk.data_types.video as wv
from diffusion_policy.env.pusht.pusht_keypoints_env import PushTKeypointsEnv from diffusion_policy.env.pusht.pusht_keypoints_env import PushTKeypointsEnv
from diffusion_policy.gym_util.async_vector_env import AsyncVectorEnv from diffusion_policy.gym_util.sync_vector_env import SyncVectorEnv
# from diffusion_policy.gym_util.sync_vector_env import SyncVectorEnv
from diffusion_policy.gym_util.multistep_wrapper import MultiStepWrapper from diffusion_policy.gym_util.multistep_wrapper import MultiStepWrapper
from diffusion_policy.gym_util.video_recording_wrapper import VideoRecordingWrapper, VideoRecorder from diffusion_policy.gym_util.video_recording_wrapper import VideoRecordingWrapper, VideoRecorder
@@ -133,7 +132,7 @@ class PushTKeypointsRunner(BaseLowdimRunner):
env_prefixs.append('test/') env_prefixs.append('test/')
env_init_fn_dills.append(dill.dumps(init_fn)) env_init_fn_dills.append(dill.dumps(init_fn))
env = AsyncVectorEnv(env_fns) env = SyncVectorEnv(env_fns)
# test env # test env
# env.reset(seed=env_seeds) # env.reset(seed=env_seeds)

View File

@@ -22,6 +22,7 @@ pymunk==6.2.1
wandb==0.13.3 wandb==0.13.3
threadpoolctl==3.1.0 threadpoolctl==3.1.0
shapely==1.8.5.post1 shapely==1.8.5.post1
matplotlib==3.6.1
imageio==2.22.0 imageio==2.22.0
imageio-ffmpeg==0.4.7 imageio-ffmpeg==0.4.7
termcolor==2.0.1 termcolor==2.0.1