Done adapting mujoco image dataset
This commit is contained in:
182
image_pusht_diffusion_policy_cnn.yaml
Normal file
182
image_pusht_diffusion_policy_cnn.yaml
Normal file
@@ -0,0 +1,182 @@
|
||||
_target_: diffusion_policy.workspace.train_diffusion_unet_hybrid_workspace.TrainDiffusionUnetHybridWorkspace
|
||||
checkpoint:
|
||||
save_last_ckpt: true
|
||||
save_last_snapshot: false
|
||||
topk:
|
||||
format_str: epoch={epoch:04d}-test_mean_score={test_mean_score:.3f}.ckpt
|
||||
k: 5
|
||||
mode: max
|
||||
monitor_key: test_mean_score
|
||||
dataloader:
|
||||
batch_size: 64
|
||||
num_workers: 8
|
||||
persistent_workers: false
|
||||
pin_memory: true
|
||||
shuffle: true
|
||||
dataset_obs_steps: 2
|
||||
ema:
|
||||
_target_: diffusion_policy.model.diffusion.ema_model.EMAModel
|
||||
inv_gamma: 1.0
|
||||
max_value: 0.9999
|
||||
min_value: 0.0
|
||||
power: 0.75
|
||||
update_after_step: 0
|
||||
exp_name: default
|
||||
horizon: 16
|
||||
keypoint_visible_rate: 1.0
|
||||
logging:
|
||||
group: null
|
||||
id: null
|
||||
mode: online
|
||||
name: 2023.01.16-20.20.06_train_diffusion_unet_hybrid_pusht_image
|
||||
project: diffusion_policy_debug
|
||||
resume: true
|
||||
tags:
|
||||
- train_diffusion_unet_hybrid
|
||||
- pusht_image
|
||||
- default
|
||||
multi_run:
|
||||
run_dir: data/outputs/2023.01.16/20.20.06_train_diffusion_unet_hybrid_pusht_image
|
||||
wandb_name_base: 2023.01.16-20.20.06_train_diffusion_unet_hybrid_pusht_image
|
||||
n_action_steps: 8
|
||||
n_latency_steps: 0
|
||||
n_obs_steps: 2
|
||||
name: train_diffusion_unet_hybrid
|
||||
obs_as_global_cond: true
|
||||
optimizer:
|
||||
_target_: torch.optim.AdamW
|
||||
betas:
|
||||
- 0.95
|
||||
- 0.999
|
||||
eps: 1.0e-08
|
||||
lr: 0.0001
|
||||
weight_decay: 1.0e-06
|
||||
past_action_visible: false
|
||||
policy:
|
||||
_target_: diffusion_policy.policy.diffusion_unet_hybrid_image_policy.DiffusionUnetHybridImagePolicy
|
||||
cond_predict_scale: true
|
||||
crop_shape:
|
||||
- 84
|
||||
- 84
|
||||
diffusion_step_embed_dim: 128
|
||||
down_dims:
|
||||
- 512
|
||||
- 1024
|
||||
- 2048
|
||||
eval_fixed_crop: true
|
||||
horizon: 16
|
||||
kernel_size: 5
|
||||
n_action_steps: 8
|
||||
n_groups: 8
|
||||
n_obs_steps: 2
|
||||
noise_scheduler:
|
||||
_target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler
|
||||
beta_end: 0.02
|
||||
beta_schedule: squaredcos_cap_v2
|
||||
beta_start: 0.0001
|
||||
clip_sample: true
|
||||
num_train_timesteps: 100
|
||||
prediction_type: epsilon
|
||||
variance_type: fixed_small
|
||||
num_inference_steps: 100
|
||||
obs_as_global_cond: true
|
||||
obs_encoder_group_norm: true
|
||||
shape_meta:
|
||||
action:
|
||||
shape:
|
||||
- 2
|
||||
obs:
|
||||
agent_pos:
|
||||
shape:
|
||||
- 2
|
||||
type: low_dim
|
||||
image:
|
||||
shape:
|
||||
- 3
|
||||
- 96
|
||||
- 96
|
||||
type: rgb
|
||||
shape_meta:
|
||||
action:
|
||||
shape:
|
||||
- 2
|
||||
obs:
|
||||
agent_pos:
|
||||
shape:
|
||||
- 2
|
||||
type: low_dim
|
||||
image:
|
||||
shape:
|
||||
- 3
|
||||
- 96
|
||||
- 96
|
||||
type: rgb
|
||||
task:
|
||||
dataset:
|
||||
_target_: diffusion_policy.dataset.pusht_image_dataset.PushTImageDataset
|
||||
horizon: 16
|
||||
max_train_episodes: 90
|
||||
pad_after: 7
|
||||
pad_before: 1
|
||||
seed: 42
|
||||
val_ratio: 0.02
|
||||
zarr_path: data/pusht/pusht_cchi_v7_replay.zarr
|
||||
env_runner:
|
||||
_target_: diffusion_policy.env_runner.pusht_image_runner.PushTImageRunner
|
||||
fps: 10
|
||||
legacy_test: true
|
||||
max_steps: 300
|
||||
n_action_steps: 8
|
||||
n_envs: null
|
||||
n_obs_steps: 2
|
||||
n_test: 50
|
||||
n_test_vis: 4
|
||||
n_train: 6
|
||||
n_train_vis: 2
|
||||
past_action: false
|
||||
test_start_seed: 100000
|
||||
train_start_seed: 0
|
||||
image_shape:
|
||||
- 3
|
||||
- 96
|
||||
- 96
|
||||
name: pusht_image
|
||||
shape_meta:
|
||||
action:
|
||||
shape:
|
||||
- 2
|
||||
obs:
|
||||
agent_pos:
|
||||
shape:
|
||||
- 2
|
||||
type: low_dim
|
||||
image:
|
||||
shape:
|
||||
- 3
|
||||
- 96
|
||||
- 96
|
||||
type: rgb
|
||||
task_name: pusht_image
|
||||
training:
|
||||
checkpoint_every: 50
|
||||
debug: false
|
||||
device: cuda:0
|
||||
gradient_accumulate_every: 1
|
||||
lr_scheduler: cosine
|
||||
lr_warmup_steps: 500
|
||||
max_train_steps: null
|
||||
max_val_steps: null
|
||||
num_epochs: 3050
|
||||
resume: true
|
||||
rollout_every: 50
|
||||
sample_every: 5
|
||||
seed: 42
|
||||
tqdm_interval_sec: 1.0
|
||||
use_ema: true
|
||||
val_every: 1
|
||||
val_dataloader:
|
||||
batch_size: 64
|
||||
num_workers: 8
|
||||
persistent_workers: false
|
||||
pin_memory: true
|
||||
shuffle: false
|
||||
Reference in New Issue
Block a user