Files
diffusion_policy/image_pusht_diffusion_policy_dit_pmf.yaml

190 lines
3.9 KiB
YAML

# Dotted path of the workspace class this configuration is written for.
# NOTE(review): `_target_` is presumably consumed by Hydra-style
# instantiation — confirm against the training entry point.
_target_: diffusion_policy.workspace.train_diffusion_transformer_hybrid_workspace.TrainDiffusionTransformerHybridWorkspace
# Checkpoint-saving options.
# NOTE(review): exact semantics live in the workspace class; presumably
# `topk` keeps the `k` best checkpoints ranked by `monitor_key` — confirm.
checkpoint:
  save_last_ckpt: true
  save_last_snapshot: false
  topk:
    # Filename template; its {epoch} and {train_loss} placeholders use the
    # same metric name as `monitor_key` below.
    format_str: epoch={epoch:04d}-train_loss={train_loss:.3f}.ckpt
    k: 5
    # `min` together with `monitor_key: train_loss` — presumably lower is
    # better; confirm.
    mode: min
    monitor_key: train_loss
# Options for the training data loader.
# NOTE(review): the key names match keyword arguments of
# torch.utils.data.DataLoader — confirm that is the consumer.
dataloader:
  batch_size: 64
  num_workers: 8
  persistent_workers: false
  pin_memory: true
  shuffle: true
# Top-level setting (not a dataloader option). Has the same value as the
# top-level `n_obs_steps` in this file.
dataset_obs_steps: 2
# Exponential-moving-average helper. `_target_` names the class; the other
# keys are presumably its constructor arguments — confirm against EMAModel.
ema:
  _target_: diffusion_policy.model.diffusion.ema_model.EMAModel
  inv_gamma: 1.0
  max_value: 0.9999
  min_value: 0.0
  power: 0.75
  update_after_step: 0
# Top-level run settings.
exp_name: default
# Same value is repeated under `policy` and `task.dataset`; keep in sync.
horizon: 16
keypoint_visible_rate: 1.0
# Experiment-logger settings.
# NOTE(review): keys look like wandb.init arguments (the file also defines
# `wandb_name_base`) — confirm the consumer.
logging:
  group: null
  id: null
  mode: online
  # `${now:...}` is a timestamp interpolation resolved when the config is
  # loaded (assumed Hydra `now` resolver — confirm).
  name: ${now:%Y.%m.%d-%H.%M.%S}_train_diffusion_transformer_hybrid_pmf_pusht_image
  project: diffusion_policy_debug
  resume: true
  # The tags repeat the top-level `name`, `task_name` and `exp_name` values.
  tags:
    - train_diffusion_transformer_hybrid_pmf
    - pusht_image
    - default
# Output directory and logger run-name base for multi-run launches.
# NOTE(review): how these two values are consumed is not visible here.
multi_run:
  run_dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_train_diffusion_transformer_hybrid_pmf_pusht_image
  wandb_name_base: ${now:%Y.%m.%d-%H.%M.%S}_train_diffusion_transformer_hybrid_pmf_pusht_image
# Top-level step counts. The same values are repeated under `policy` and
# `task.env_runner`; keep the copies in sync.
n_action_steps: 8
n_latency_steps: 0
n_obs_steps: 2
# Top-level config name; also the first entry of `logging.tags`.
name: train_diffusion_transformer_hybrid_pmf
obs_as_cond: true
# Optimizer hyper-parameters.
# NOTE(review): the two weight-decay keys suggest separate parameter groups
# for the observation encoder and the transformer — confirm in the policy's
# optimizer factory.
optimizer:
  betas:
    - 0.9
    - 0.95
  learning_rate: 0.0001
  obs_encoder_weight_decay: 1.0e-06
  transformer_weight_decay: 0.001
# Top-level flag (not an optimizer option).
past_action_visible: false
# Policy definition. `_target_` names the policy class; the remaining keys
# are presumably its constructor arguments.
# NOTE(review): nesting is reconstructed from key order. In particular,
# `n_time_tokens`, `noise_scale`, `adatloss_eps`, `p_mean`, `p_std`,
# `pmf_*_loss_weight`, `tr_uniform*` and `data_proportion` are placed
# directly under `policy` — confirm against the constructor signature of
# PMFTransformerHybridImagePolicy.
policy:
  _target_: diffusion_policy.policy.pmf_transformer_hybrid_image_policy.PMFTransformerHybridImagePolicy
  # 84x84 crop; the image observation declared in `shape_meta` is 96x96.
  crop_shape:
    - 84
    - 84
  eval_fixed_crop: true
  # horizon / n_action_steps / n_obs_steps repeat the top-level values.
  horizon: 16
  n_action_steps: 8
  n_cond_layers: 0
  n_emb: 256
  n_head: 4
  n_layer: 12
  n_obs_steps: 2
  n_time_tokens: 4
  noise_scale: 1.0
  adatloss_eps: 0.01
  p_mean: -0.4
  p_std: 1.0
  # Scheduler sub-config. The keys listed are DDPMScheduler constructor
  # arguments; `num_inference_steps` below is not, so it belongs to the
  # policy.
  noise_scheduler:
    _target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler
    beta_end: 0.02
    beta_schedule: squaredcos_cap_v2
    beta_start: 0.0001
    clip_sample: true
    num_train_timesteps: 100
    prediction_type: sample
    variance_type: fixed_small
  num_inference_steps: 1
  obs_as_cond: true
  obs_encoder_group_norm: true
  p_drop_attn: 0.0
  p_drop_emb: 0.0
  pmf_u_loss_weight: 1.0
  pmf_v_loss_weight: 1.0
  tr_uniform: true
  tr_uniform_prob: 0.1
  data_proportion: 0.5
  # Observation/action shapes. The same mapping is repeated at top level and
  # under `task`; keep all copies in sync.
  shape_meta:
    action:
      shape:
        - 2
    obs:
      agent_pos:
        shape:
          - 2
        type: low_dim
      image:
        shape:
          - 3
          - 96
          - 96
        type: rgb
# Top-level description of action and observation shapes: a 2-element
# action, a 2-element `low_dim` observation (`agent_pos`) and a 3x96x96
# `rgb` observation (`image`).
# This mapping is repeated verbatim elsewhere in the file; keep the copies
# in sync.
shape_meta:
  action:
    shape:
      - 2
  obs:
    agent_pos:
      shape:
        - 2
      type: low_dim
    image:
      shape:
        - 3
        - 96
        - 96
      type: rgb
# Task definition: dataset, evaluation environment runner, and shapes.
task:
  # Dataset sub-config; `_target_` names the dataset class.
  dataset:
    _target_: diffusion_policy.dataset.pusht_image_dataset.PushTImageDataset
    horizon: 16
    max_train_episodes: 90
    # With the values in this file, pad_after == n_action_steps - 1 and
    # pad_before == n_obs_steps - 1.
    # NOTE(review): presumably derived from those settings — confirm before
    # changing either one.
    pad_after: 7
    pad_before: 1
    seed: 42
    val_ratio: 0.02
    zarr_path: data/pusht/pusht_cchi_v7_replay.zarr
  # Rollout/evaluation runner sub-config; `_target_` names the runner class.
  env_runner:
    _target_: diffusion_policy.env_runner.pusht_image_runner.PushTImageRunner
    fps: 10
    legacy_test: true
    max_steps: 300
    # n_action_steps / n_obs_steps repeat the top-level values.
    n_action_steps: 8
    n_envs: null
    n_obs_steps: 2
    n_test: 50
    n_test_vis: 4
    n_train: 6
    n_train_vis: 2
    past_action: false
    test_start_seed: 100000
    train_start_seed: 0
  # Same 3x96x96 shape as `shape_meta.obs.image.shape` below.
  image_shape:
    - 3
    - 96
    - 96
  name: pusht_image
  # Repeats the top-level `shape_meta`; keep in sync.
  shape_meta:
    action:
      shape:
        - 2
    obs:
      agent_pos:
        shape:
          - 2
        type: low_dim
      image:
        shape:
          - 3
          - 96
          - 96
        type: rgb
# Top-level task identifier; same value as `task.name`.
task_name: pusht_image
# Training-loop settings.
# NOTE(review): the `*_every` values are presumably intervals counted in
# epochs, and `null` for the `max_*_steps` keys presumably means "no cap" —
# confirm in the workspace's run loop.
training:
  checkpoint_every: 50
  debug: false
  device: cuda:0
  gradient_accumulate_every: 1
  lr_scheduler: cosine
  lr_warmup_steps: 500
  max_train_steps: null
  max_val_steps: null
  num_epochs: 600
  resume: true
  rollout_every: 50
  sample_every: 5
  seed: 42
  tqdm_interval_sec: 1.0
  use_ema: true
  val_every: 1
# Options for the validation data loader. Same keys and values as
# `dataloader`, except `shuffle` is false.
val_dataloader:
  batch_size: 64
  num_workers: 8
  persistent_workers: false
  pin_memory: true
  shuffle: false