feat(pusht): add pMF-style DiT image policy
This commit is contained in:
184
image_pusht_diffusion_policy_dit_pmf.yaml
Normal file
184
image_pusht_diffusion_policy_dit_pmf.yaml
Normal file
@@ -0,0 +1,184 @@
|
||||
_target_: diffusion_policy.workspace.train_diffusion_transformer_hybrid_workspace.TrainDiffusionTransformerHybridWorkspace
|
||||
checkpoint:
|
||||
save_last_ckpt: true
|
||||
save_last_snapshot: false
|
||||
topk:
|
||||
format_str: epoch={epoch:04d}-train_loss={train_loss:.3f}.ckpt
|
||||
k: 5
|
||||
mode: min
|
||||
monitor_key: train_loss
|
||||
dataloader:
|
||||
batch_size: 64
|
||||
num_workers: 8
|
||||
persistent_workers: false
|
||||
pin_memory: true
|
||||
shuffle: true
|
||||
dataset_obs_steps: 2
|
||||
ema:
|
||||
_target_: diffusion_policy.model.diffusion.ema_model.EMAModel
|
||||
inv_gamma: 1.0
|
||||
max_value: 0.9999
|
||||
min_value: 0.0
|
||||
power: 0.75
|
||||
update_after_step: 0
|
||||
exp_name: default
|
||||
horizon: 16
|
||||
keypoint_visible_rate: 1.0
|
||||
logging:
|
||||
group: null
|
||||
id: null
|
||||
mode: online
|
||||
name: ${now:%Y.%m.%d-%H.%M.%S}_train_diffusion_transformer_hybrid_pmf_pusht_image
|
||||
project: diffusion_policy_debug
|
||||
resume: true
|
||||
tags:
|
||||
- train_diffusion_transformer_hybrid_pmf
|
||||
- pusht_image
|
||||
- default
|
||||
multi_run:
|
||||
run_dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_train_diffusion_transformer_hybrid_pmf_pusht_image
|
||||
wandb_name_base: ${now:%Y.%m.%d-%H.%M.%S}_train_diffusion_transformer_hybrid_pmf_pusht_image
|
||||
n_action_steps: 8
|
||||
n_latency_steps: 0
|
||||
n_obs_steps: 2
|
||||
name: train_diffusion_transformer_hybrid_pmf
|
||||
obs_as_cond: true
|
||||
optimizer:
|
||||
betas:
|
||||
- 0.9
|
||||
- 0.95
|
||||
learning_rate: 0.0001
|
||||
obs_encoder_weight_decay: 1.0e-06
|
||||
transformer_weight_decay: 0.001
|
||||
past_action_visible: false
|
||||
policy:
|
||||
_target_: diffusion_policy.policy.pmf_transformer_hybrid_image_policy.PMFTransformerHybridImagePolicy
|
||||
crop_shape:
|
||||
- 84
|
||||
- 84
|
||||
eval_fixed_crop: true
|
||||
horizon: 16
|
||||
min_time: 0.05
|
||||
n_action_steps: 8
|
||||
n_cond_layers: 0
|
||||
n_emb: 256
|
||||
n_head: 4
|
||||
n_layer: 8
|
||||
n_obs_steps: 2
|
||||
n_time_tokens: 4
|
||||
noise_scale: 1.0
|
||||
noise_scheduler:
|
||||
_target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler
|
||||
beta_end: 0.02
|
||||
beta_schedule: squaredcos_cap_v2
|
||||
beta_start: 0.0001
|
||||
clip_sample: true
|
||||
num_train_timesteps: 100
|
||||
prediction_type: sample
|
||||
variance_type: fixed_small
|
||||
num_inference_steps: 32
|
||||
obs_as_cond: true
|
||||
obs_encoder_group_norm: true
|
||||
p_drop_attn: 0.0
|
||||
p_drop_emb: 0.0
|
||||
pmf_u_loss_weight: 1.0
|
||||
pmf_v_loss_weight: 1.0
|
||||
shape_meta:
|
||||
action:
|
||||
shape:
|
||||
- 2
|
||||
obs:
|
||||
agent_pos:
|
||||
shape:
|
||||
- 2
|
||||
type: low_dim
|
||||
image:
|
||||
shape:
|
||||
- 3
|
||||
- 96
|
||||
- 96
|
||||
type: rgb
|
||||
shape_meta:
|
||||
action:
|
||||
shape:
|
||||
- 2
|
||||
obs:
|
||||
agent_pos:
|
||||
shape:
|
||||
- 2
|
||||
type: low_dim
|
||||
image:
|
||||
shape:
|
||||
- 3
|
||||
- 96
|
||||
- 96
|
||||
type: rgb
|
||||
task:
|
||||
dataset:
|
||||
_target_: diffusion_policy.dataset.pusht_image_dataset.PushTImageDataset
|
||||
horizon: 16
|
||||
max_train_episodes: 90
|
||||
pad_after: 7
|
||||
pad_before: 1
|
||||
seed: 42
|
||||
val_ratio: 0.02
|
||||
zarr_path: data/pusht/pusht_cchi_v7_replay.zarr
|
||||
env_runner:
|
||||
_target_: diffusion_policy.env_runner.pusht_image_runner.PushTImageRunner
|
||||
fps: 10
|
||||
legacy_test: true
|
||||
max_steps: 300
|
||||
n_action_steps: 8
|
||||
n_envs: null
|
||||
n_obs_steps: 2
|
||||
n_test: 50
|
||||
n_test_vis: 4
|
||||
n_train: 6
|
||||
n_train_vis: 2
|
||||
past_action: false
|
||||
test_start_seed: 100000
|
||||
train_start_seed: 0
|
||||
image_shape:
|
||||
- 3
|
||||
- 96
|
||||
- 96
|
||||
name: pusht_image
|
||||
shape_meta:
|
||||
action:
|
||||
shape:
|
||||
- 2
|
||||
obs:
|
||||
agent_pos:
|
||||
shape:
|
||||
- 2
|
||||
type: low_dim
|
||||
image:
|
||||
shape:
|
||||
- 3
|
||||
- 96
|
||||
- 96
|
||||
type: rgb
|
||||
task_name: pusht_image
|
||||
training:
|
||||
checkpoint_every: 50
|
||||
debug: false
|
||||
device: cuda:0
|
||||
gradient_accumulate_every: 1
|
||||
lr_scheduler: cosine
|
||||
lr_warmup_steps: 500
|
||||
max_train_steps: null
|
||||
max_val_steps: null
|
||||
num_epochs: 600
|
||||
resume: true
|
||||
rollout_every: 50
|
||||
sample_every: 5
|
||||
seed: 42
|
||||
tqdm_interval_sec: 1.0
|
||||
use_ema: true
|
||||
val_every: 1
|
||||
val_dataloader:
|
||||
batch_size: 64
|
||||
num_workers: 8
|
||||
persistent_workers: false
|
||||
pin_memory: true
|
||||
shuffle: false
|
||||
Reference in New Issue
Block a user