190 lines
3.9 KiB
YAML
190 lines
3.9 KiB
YAML
# Dotted import path of the workspace class for this run
# (presumably instantiated through Hydra's `_target_` mechanism — confirm).
_target_: diffusion_policy.workspace.train_diffusion_transformer_hybrid_workspace.TrainDiffusionTransformerHybridWorkspace
# Checkpoint options.
checkpoint:
  save_last_ckpt: true
  save_last_snapshot: false
  # Ranked-checkpoint settings: `monitor_key` names the metric, `mode` the
  # direction, `k` the count (presumably the number of checkpoints kept —
  # confirm against the checkpoint manager).
  topk:
    # Filename template; the braces look like Python str.format fields.
    format_str: 'epoch={epoch:04d}-train_loss={train_loss:.3f}.ckpt'
    k: 5
    mode: min
    monitor_key: train_loss
# Training data loader options (shuffled; compare `val_dataloader`).
dataloader:
  batch_size: 64
  num_workers: 8
  persistent_workers: false
  pin_memory: true
  shuffle: true
# Same value as the top-level `n_obs_steps` in this file.
dataset_obs_steps: 2
# Constructor arguments for the EMA model class named by `_target_`.
ema:
  _target_: diffusion_policy.model.diffusion.ema_model.EMAModel
  inv_gamma: 1.0
  max_value: 0.9999
  min_value: 0.0
  power: 0.75
  update_after_step: 0
# Top-level run scalars.
exp_name: default
horizon: 16
keypoint_visible_rate: 1.0
# Experiment-tracker settings (the `wandb_name_base` key under `multi_run`
# suggests Weights & Biases — confirm).
logging:
  group: null
  id: null
  mode: online
  # `${now:...}` is an interpolation; quoted so the YAML parser always reads
  # it as a plain string.
  name: '${now:%Y.%m.%d-%H.%M.%S}_train_diffusion_transformer_hybrid_pmf_pusht_image'
  project: diffusion_policy_debug
  resume: true
  tags:
    - train_diffusion_transformer_hybrid_pmf
    - pusht_image
    - default
# Output directory and run-name templates; both embed `${now:...}`
# timestamp interpolations.
multi_run:
  run_dir: 'data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_train_diffusion_transformer_hybrid_pmf_pusht_image'
  wandb_name_base: '${now:%Y.%m.%d-%H.%M.%S}_train_diffusion_transformer_hybrid_pmf_pusht_image'
# Top-level step counts and experiment name; the same step values recur
# under `policy` and `task.env_runner`.
n_action_steps: 8
n_latency_steps: 0
n_obs_steps: 2
name: train_diffusion_transformer_hybrid_pmf
obs_as_cond: true
# Optimizer hyperparameters; weight decay is given separately for the
# observation encoder and the transformer.
optimizer:
  betas: [0.9, 0.95]
  learning_rate: 0.0001
  obs_encoder_weight_decay: 1.0e-06
  transformer_weight_decay: 0.001
past_action_visible: false
# Constructor arguments for the policy class named by `_target_`.
policy:
  _target_: diffusion_policy.policy.pmf_transformer_hybrid_image_policy.PMFTransformerHybridImagePolicy
  # Smaller than the 96x96 image declared in `shape_meta` below
  # (presumably a crop size in pixels — confirm).
  crop_shape: [84, 84]
  eval_fixed_crop: true
  horizon: 16
  n_action_steps: 8
  n_cond_layers: 0
  n_emb: 256
  n_head: 4
  n_layer: 12
  n_obs_steps: 2
  n_time_tokens: 4
  noise_scale: 1.0
  # NOTE(review): key is spelled `adatloss_eps` — confirm it matches the
  # parameter name expected by the policy constructor.
  adatloss_eps: 0.01
  p_mean: -0.4
  p_std: 1.0
  noise_scheduler:
    _target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler
    beta_end: 0.02
    beta_schedule: squaredcos_cap_v2
    beta_start: 0.0001
    clip_sample: true
    num_train_timesteps: 100
    prediction_type: sample
    variance_type: fixed_small
  num_inference_steps: 1
  obs_as_cond: true
  obs_encoder_group_norm: true
  p_drop_attn: 0.0
  p_drop_emb: 0.0
  pmf_u_loss_weight: 1.0
  pmf_v_loss_weight: 1.0
  tr_uniform: true
  tr_uniform_prob: 0.1
  data_proportion: 0.5
  # Same contents as the top-level `shape_meta` and `task.shape_meta`.
  shape_meta:
    action:
      shape: [2]
    obs:
      agent_pos:
        shape: [2]
        type: low_dim
      image:
        shape: [3, 96, 96]
        type: rgb
# Shapes and types of the action and of each observation entry.
shape_meta:
  action:
    shape: [2]
  obs:
    agent_pos:
      shape: [2]
      type: low_dim
    image:
      shape: [3, 96, 96]
      type: rgb
# Task definition: dataset, environment runner, and observation/action shapes.
task:
  dataset:
    _target_: diffusion_policy.dataset.pusht_image_dataset.PushTImageDataset
    horizon: 16
    max_train_episodes: 90
    # pad_after / pad_before equal n_action_steps - 1 and n_obs_steps - 1 as
    # set under `env_runner` below — presumably derived from them; confirm.
    pad_after: 7
    pad_before: 1
    seed: 42
    val_ratio: 0.02
    zarr_path: data/pusht/pusht_cchi_v7_replay.zarr
  env_runner:
    _target_: diffusion_policy.env_runner.pusht_image_runner.PushTImageRunner
    fps: 10
    legacy_test: true
    max_steps: 300
    n_action_steps: 8
    n_envs: null
    n_obs_steps: 2
    n_test: 50
    n_test_vis: 4
    n_train: 6
    n_train_vis: 2
    past_action: false
    test_start_seed: 100000
    train_start_seed: 0
  image_shape: [3, 96, 96]
  name: pusht_image
  shape_meta:
    action:
      shape: [2]
    obs:
      agent_pos:
        shape: [2]
        type: low_dim
      image:
        shape: [3, 96, 96]
        type: rgb
# Same value as `task.name`.
task_name: pusht_image
# Training-loop options. The `*_every` keys are intervals (units are
# presumably epochs — confirm against the workspace code).
training:
  checkpoint_every: 50
  debug: false
  device: cuda:0
  gradient_accumulate_every: 1
  lr_scheduler: cosine
  lr_warmup_steps: 500
  max_train_steps: null
  max_val_steps: null
  num_epochs: 600
  resume: true
  rollout_every: 50
  sample_every: 5
  seed: 42
  tqdm_interval_sec: 1.0
  use_ema: true
  val_every: 1
# Validation data loader options; same as `dataloader` except `shuffle`.
val_dataloader:
  batch_size: 64
  num_workers: 8
  persistent_workers: false
  pin_memory: true
  shuffle: false