diff --git a/diffusion_policy/model/diffusion/pmf_transformer_for_diffusion.py b/diffusion_policy/model/diffusion/pmf_transformer_for_diffusion.py
new file mode 100644
index 0000000..96d16d2
--- /dev/null
+++ b/diffusion_policy/model/diffusion/pmf_transformer_for_diffusion.py
@@ -0,0 +1,265 @@
+from typing import Optional, Tuple, Union
+import logging
+import torch
+import torch.nn as nn
+
+from diffusion_policy.model.common.module_attr_mixin import ModuleAttrMixin
+from diffusion_policy.model.diffusion.positional_embedding import SinusoidalPosEmb
+
+
+logger = logging.getLogger(__name__)
+
+
+class PMFTransformerForDiffusion(ModuleAttrMixin):
+    def __init__(
+        self,
+        input_dim: int,
+        output_dim: int,
+        horizon: int,
+        n_obs_steps: Optional[int] = None,
+        cond_dim: int = 0,
+        n_layer: int = 12,
+        n_head: int = 12,
+        n_emb: int = 768,
+        p_drop_emb: float = 0.1,
+        p_drop_attn: float = 0.1,
+        causal_attn: bool = False,
+        obs_as_cond: bool = False,
+        n_cond_layers: int = 0,
+        n_time_tokens: int = 4,
+    ) -> None:
+        super().__init__()
+
+        if n_obs_steps is None:
+            n_obs_steps = horizon
+        if n_time_tokens < 1:
+            raise ValueError("n_time_tokens must be >= 1")
+
+        # conditioning is inferred from cond_dim; the obs_as_cond argument is
+        # kept for config compatibility
+        obs_as_cond = cond_dim > 0
+        T = horizon
+        n_global_cond_tokens = 2 * n_time_tokens
+        T_cond = n_global_cond_tokens + (n_obs_steps if obs_as_cond else 0)
+
+        self.input_emb = nn.Linear(input_dim, n_emb)
+        self.pos_emb = nn.Parameter(torch.zeros(1, T, n_emb))
+        self.drop = nn.Dropout(p_drop_emb)
+
+        self.t_emb = SinusoidalPosEmb(n_emb)
+        self.r_emb = SinusoidalPosEmb(n_emb)
+        self.t_tokens = nn.Parameter(torch.zeros(1, n_time_tokens, n_emb))
+        self.r_tokens = nn.Parameter(torch.zeros(1, n_time_tokens, n_emb))
+        self.cond_obs_emb = nn.Linear(cond_dim, n_emb) if obs_as_cond else None
+        self.cond_pos_emb = nn.Parameter(torch.zeros(1, T_cond, n_emb))
+
+        if n_cond_layers > 0:
+            encoder_layer = nn.TransformerEncoderLayer(
+                d_model=n_emb,
+                nhead=n_head,
+                dim_feedforward=4 * n_emb,
+                dropout=p_drop_attn,
+                activation="gelu",
+                batch_first=True,
+                norm_first=True,
+            )
+            self.encoder = nn.TransformerEncoder(
+                encoder_layer=encoder_layer,
+                num_layers=n_cond_layers,
+            )
+        else:
+            self.encoder = nn.Sequential(
+                nn.Linear(n_emb, 4 * n_emb),
+                nn.Mish(),
+                nn.Linear(4 * n_emb, n_emb),
+            )
+
+        decoder_layer = nn.TransformerDecoderLayer(
+            d_model=n_emb,
+            nhead=n_head,
+            dim_feedforward=4 * n_emb,
+            dropout=p_drop_attn,
+            activation="gelu",
+            batch_first=True,
+            norm_first=True,
+        )
+        self.decoder = nn.TransformerDecoder(
+            decoder_layer=decoder_layer,
+            num_layers=n_layer,
+        )
+
+        if causal_attn:
+            # standard causal mask over the action tokens
+            sz = T
+            mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
+            mask = mask.float().masked_fill(mask == 0, float("-inf")).masked_fill(mask == 1, float(0.0))
+            self.register_buffer("mask", mask)
+
+            if obs_as_cond:
+                # time tokens are always visible; obs token i is visible to
+                # action query q iff q >= i
+                q_idx, c_idx = torch.meshgrid(
+                    torch.arange(T),
+                    torch.arange(T_cond),
+                    indexing="ij",
+                )
+                obs_offset = n_global_cond_tokens
+                visible = c_idx < obs_offset
+                visible = visible | (q_idx >= (c_idx - obs_offset))
+                memory_mask = visible.float().masked_fill(~visible, float("-inf")).masked_fill(visible, float(0.0))
+                self.register_buffer("memory_mask", memory_mask)
+            else:
+                self.memory_mask = None
+        else:
+            self.mask = None
+            self.memory_mask = None
+
+        self.ln_f = nn.LayerNorm(n_emb)
+        self.head_u = nn.Linear(n_emb, output_dim)
+        self.head_v = nn.Linear(n_emb, output_dim)
+
+        self.T = T
+        self.T_cond = T_cond
+        self.horizon = horizon
+        self.n_obs_steps = n_obs_steps
+        self.obs_as_cond = obs_as_cond
+        self.n_global_cond_tokens = n_global_cond_tokens
+        self.n_time_tokens = n_time_tokens
+
+        self.apply(self._init_weights)
+        logger.info(
+            "number of parameters: %e", sum(p.numel() for p in self.parameters())
+        )
+
+    def _init_weights(self, module):
+        ignore_types = (
+            nn.Dropout,
+            SinusoidalPosEmb,
+            nn.TransformerEncoderLayer,
+            nn.TransformerDecoderLayer,
+            nn.TransformerEncoder,
+            nn.TransformerDecoder,
+            nn.ModuleList,
+            nn.Mish,
+            nn.Sequential,
+        )
+        if isinstance(module, (nn.Linear, nn.Embedding)):
+            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+            if isinstance(module, nn.Linear) and module.bias is not None:
+                torch.nn.init.zeros_(module.bias)
+        elif isinstance(module, nn.MultiheadAttention):
+            for name in ("in_proj_weight", "q_proj_weight", "k_proj_weight", "v_proj_weight"):
+                weight = getattr(module, name)
+                if weight is not None:
+                    torch.nn.init.normal_(weight, mean=0.0, std=0.02)
+            for name in ("in_proj_bias", "bias_k", "bias_v"):
+                bias = getattr(module, name)
+                if bias is not None:
+                    torch.nn.init.zeros_(bias)
+        elif isinstance(module, nn.LayerNorm):
+            torch.nn.init.zeros_(module.bias)
+            torch.nn.init.ones_(module.weight)
+        elif isinstance(module, PMFTransformerForDiffusion):
+            torch.nn.init.normal_(module.pos_emb, mean=0.0, std=0.02)
+            torch.nn.init.normal_(module.cond_pos_emb, mean=0.0, std=0.02)
+            torch.nn.init.normal_(module.t_tokens, mean=0.0, std=0.02)
+            torch.nn.init.normal_(module.r_tokens, mean=0.0, std=0.02)
+        elif isinstance(module, ignore_types):
+            pass
+        else:
+            raise RuntimeError("Unaccounted module {}".format(module))
+
+    def get_optim_groups(self, weight_decay: float = 1e-3):
+        decay = set()
+        no_decay = set()
+        whitelist_weight_modules = (torch.nn.Linear, torch.nn.MultiheadAttention)
+        blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
+        for mn, m in self.named_modules():
+            for pn, _ in m.named_parameters():
+                fpn = "%s.%s" % (mn, pn) if mn else pn
+                if pn.endswith("bias"):
+                    no_decay.add(fpn)
+                elif pn.startswith("bias"):
+                    # MultiheadAttention bias_k / bias_v start with "bias"
+                    no_decay.add(fpn)
+                elif pn.endswith("weight") and isinstance(m, whitelist_weight_modules):
+                    decay.add(fpn)
+                elif pn.endswith("weight") and isinstance(m, blacklist_weight_modules):
+                    no_decay.add(fpn)
+
+        no_decay.update(
+            {
+                "pos_emb",
+                "cond_pos_emb",
+                "t_tokens",
+                "r_tokens",
+                "_dummy_variable",
+            }
+        )
+
+        param_dict = {pn: p for pn, p in self.named_parameters()}
+        inter_params = decay & no_decay
+        union_params = decay | no_decay
+        assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params),)
+        assert len(param_dict.keys() - union_params) == 0, (
+            "parameters %s were not separated into either decay/no_decay set!"
+            % (str(param_dict.keys() - union_params),)
+        )
+
+        return [
+            {
+                "params": [param_dict[pn] for pn in sorted(list(decay))],
+                "weight_decay": weight_decay,
+            },
+            {
+                "params": [param_dict[pn] for pn in sorted(list(no_decay))],
+                "weight_decay": 0.0,
+            },
+        ]
+
+    def configure_optimizers(
+        self,
+        learning_rate: float = 1e-4,
+        weight_decay: float = 1e-3,
+        betas: Tuple[float, float] = (0.9, 0.95),
+    ):
+        optim_groups = self.get_optim_groups(weight_decay=weight_decay)
+        return torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas)
+
+    def _broadcast_time(self, value: Union[torch.Tensor, float, int], batch_size: int, device: torch.device):
+        if not torch.is_tensor(value):
+            value = torch.tensor([value], dtype=torch.float32, device=device)
+        elif value.ndim == 0:
+            value = value[None].to(device=device, dtype=torch.float32)
+        else:
+            value = value.to(device=device, dtype=torch.float32)
+        return value.expand(batch_size)
+
+    def forward(
+        self,
+        sample: torch.Tensor,
+        t: Union[torch.Tensor, float, int],
+        r: Union[torch.Tensor, float, int],
+        cond: Optional[torch.Tensor] = None,
+    ):
+        """
+        sample: (B, T, input_dim)
+        t, r: scalar or (B,) flow times, with r <= t
+        cond: (B, n_obs_steps, cond_dim), required when obs_as_cond
+        returns two (B, T, output_dim) predictions; the policy converts
+        them into the velocities u and v
+        """
+        batch_size = sample.shape[0]
+        device = sample.device
+        t = self._broadcast_time(t, batch_size, device)
+        r = self._broadcast_time(r, batch_size, device)
+
+        input_emb = self.input_emb(sample)
+
+        t_cond = self.t_tokens + self.t_emb(t).unsqueeze(1)
+        r_cond = self.r_tokens + self.r_emb(r).unsqueeze(1)
+        cond_embeddings = [t_cond, r_cond]
+        if self.obs_as_cond:
+            cond_embeddings.append(self.cond_obs_emb(cond))
+        cond_embeddings = torch.cat(cond_embeddings, dim=1)
+
+        cond_pos = self.cond_pos_emb[:, : cond_embeddings.shape[1], :]
+        memory = self.drop(cond_embeddings + cond_pos)
+        memory = self.encoder(memory)
+
+        token_pos = self.pos_emb[:, : input_emb.shape[1], :]
+        x = self.drop(input_emb + token_pos)
+        x = self.decoder(
+            tgt=x,
+            memory=memory,
+            tgt_mask=self.mask,
+            memory_mask=self.memory_mask,
+        )
+        x = self.ln_f(x)
+        return self.head_u(x), self.head_v(x)
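+
+
+def test():
+    # Minimal shape-check sketch (hypothetical sizes; not used by the training
+    # pipeline): the model maps (sample, t, r, cond) to a pair of (B, T,
+    # output_dim) tensors, one per head.
+    transformer = PMFTransformerForDiffusion(
+        input_dim=16,
+        output_dim=16,
+        horizon=8,
+        n_obs_steps=2,
+        cond_dim=10,
+        n_layer=2,
+        n_head=4,
+        n_emb=64,
+        causal_attn=True,
+    )
+    sample = torch.zeros((4, 8, 16))
+    cond = torch.zeros((4, 2, 10))
+    t = torch.rand(4)
+    r = torch.rand(4) * t  # r <= t, matching how training samples times
+    u_hat, v_hat = transformer(sample, t, r, cond)
+    assert u_hat.shape == (4, 8, 16)
+    assert v_hat.shape == (4, 8, 16)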
diff --git a/diffusion_policy/policy/pmf_transformer_hybrid_image_policy.py b/diffusion_policy/policy/pmf_transformer_hybrid_image_policy.py
new file mode 100644
index 0000000..20e7a1a
--- /dev/null
+++ b/diffusion_policy/policy/pmf_transformer_hybrid_image_policy.py
@@ -0,0 +1,398 @@
+from typing import Dict, Tuple
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import reduce
+from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
+
+import diffusion_policy.model.vision.crop_randomizer as dmvc
+import robomimic.models.base_nets as rmbn
+import robomimic.utils.obs_utils as ObsUtils
+from diffusion_policy.common.pytorch_util import dict_apply, replace_submodules
+from diffusion_policy.common.robomimic_config_util import get_robomimic_config
+from diffusion_policy.model.common.normalizer import LinearNormalizer
+from diffusion_policy.model.diffusion.mask_generator import LowdimMaskGenerator
+from diffusion_policy.model.diffusion.pmf_transformer_for_diffusion import (
+    PMFTransformerForDiffusion,
+)
+from diffusion_policy.policy.base_image_policy import BaseImagePolicy
+from robomimic.algo import algo_factory
+from robomimic.algo.algo import PolicyAlgo
+
+
+class PMFTransformerHybridImagePolicy(BaseImagePolicy):
+    def __init__(
+        self,
+        shape_meta: dict,
+        noise_scheduler: DDPMScheduler,
+        horizon,
+        n_action_steps,
+        n_obs_steps,
+        num_inference_steps=None,
+        crop_shape=(76, 76),
+        obs_encoder_group_norm=False,
+        eval_fixed_crop=False,
+        n_layer=8,
+        n_cond_layers=0,
+        n_head=4,
+        n_emb=256,
+        p_drop_emb=0.0,
+        p_drop_attn=0.0,
+        causal_attn=True,
+        obs_as_cond=True,
+        pred_action_steps_only=False,
+        n_time_tokens=4,
+        min_time=0.05,
+        du_dt_epsilon=1.0e-3,
+        pmf_u_loss_weight=1.0,
+        pmf_v_loss_weight=1.0,
+        noise_scale=1.0,
+        **kwargs,
+    ):
+        super().__init__()
+
+        # parse shape_meta
+        action_shape = shape_meta["action"]["shape"]
+        assert len(action_shape) == 1
+        action_dim = action_shape[0]
+        obs_shape_meta = shape_meta["obs"]
+        obs_config = {
+            "low_dim": [],
+            "rgb": [],
+            "depth": [],
+            "scan": [],
+        }
+        obs_key_shapes = dict()
+        for key, attr in obs_shape_meta.items():
+            shape = attr["shape"]
+            obs_key_shapes[key] = list(shape)
+
+            obs_type = attr.get("type", "low_dim")
+            if obs_type == "rgb":
+                obs_config["rgb"].append(key)
+            elif obs_type == "low_dim":
+                obs_config["low_dim"].append(key)
+            else:
+                raise RuntimeError(f"Unsupported obs type: {obs_type}")
+
+        # get raw robomimic config
+        config = get_robomimic_config(
+            algo_name="bc_rnn",
+            hdf5_type="image",
+            task_name="square",
+            dataset_type="ph",
+        )
+
+        with config.unlocked():
+            config.observation.modalities.obs = obs_config
+
+            if crop_shape is None:
+                for _, modality in config.observation.encoder.items():
+                    if modality.obs_randomizer_class == "CropRandomizer":
+                        modality["obs_randomizer_class"] = None
+            else:
+                crop_h, crop_w = crop_shape
+                for _, modality in config.observation.encoder.items():
+                    if modality.obs_randomizer_class == "CropRandomizer":
+                        modality.obs_randomizer_kwargs.crop_height = crop_h
+                        modality.obs_randomizer_kwargs.crop_width = crop_w
+
+        # init global state
+        ObsUtils.initialize_obs_utils_with_config(config)
+
+        # load model
+        policy: PolicyAlgo = algo_factory(
+            algo_name=config.algo_name,
+            config=config,
+            obs_key_shapes=obs_key_shapes,
+            ac_dim=action_dim,
+            device="cpu",
+        )
+
+        obs_encoder = policy.nets["policy"].nets["encoder"].nets["obs"]
+        if obs_encoder_group_norm:
+            # replace batch norm with group norm
+            replace_submodules(
+                root_module=obs_encoder,
+                predicate=lambda x: isinstance(x, nn.BatchNorm2d),
+                func=lambda x: nn.GroupNorm(
+                    num_groups=x.num_features // 16,
+                    num_channels=x.num_features,
+                ),
+            )
+
+        if eval_fixed_crop:
+            replace_submodules(
+                root_module=obs_encoder,
+                predicate=lambda x: isinstance(x, rmbn.CropRandomizer),
+                func=lambda x: dmvc.CropRandomizer(
+                    input_shape=x.input_shape,
+                    crop_height=x.crop_height,
+                    crop_width=x.crop_width,
+                    num_crops=x.num_crops,
+                    pos_enc=x.pos_enc,
+                ),
+            )
+
+        obs_feature_dim = obs_encoder.output_shape()[0]
+        input_dim = action_dim if obs_as_cond else (obs_feature_dim + action_dim)
+        cond_dim = obs_feature_dim if obs_as_cond else 0
+
+        self.obs_encoder = obs_encoder
+        self.model = PMFTransformerForDiffusion(
+            input_dim=input_dim,
+            output_dim=input_dim,
+            horizon=horizon if not pred_action_steps_only else n_action_steps,
+            n_obs_steps=n_obs_steps,
+            cond_dim=cond_dim,
+            n_layer=n_layer,
+            n_head=n_head,
+            n_emb=n_emb,
+            p_drop_emb=p_drop_emb,
+            p_drop_attn=p_drop_attn,
+            causal_attn=causal_attn,
+            obs_as_cond=obs_as_cond,
+            n_cond_layers=n_cond_layers,
+            n_time_tokens=n_time_tokens,
+        )
+        # kept for config compatibility; only num_train_timesteps is read below
+        # (as the default step count) -- sampling uses the PMF Euler loop
+        self.noise_scheduler = noise_scheduler
+        self.mask_generator = LowdimMaskGenerator(
+            action_dim=action_dim,
+            obs_dim=0 if obs_as_cond else obs_feature_dim,
+            max_n_obs_steps=n_obs_steps,
+            fix_obs_steps=True,
+            action_visible=False,
+        )
+        self.normalizer = LinearNormalizer()
+        self.horizon = horizon
+        self.obs_feature_dim = obs_feature_dim
+        self.action_dim = action_dim
+        self.n_action_steps = n_action_steps
+        self.n_obs_steps = n_obs_steps
+        self.obs_as_cond = obs_as_cond
+        self.pred_action_steps_only = pred_action_steps_only
+        self.min_time = min_time
+        self.du_dt_epsilon = du_dt_epsilon
+        self.pmf_u_loss_weight = pmf_u_loss_weight
+        self.pmf_v_loss_weight = pmf_v_loss_weight
+        self.noise_scale = noise_scale
+        self.kwargs = kwargs
+
+        if num_inference_steps is None:
+            num_inference_steps = noise_scheduler.config.num_train_timesteps
+        self.num_inference_steps = num_inference_steps
+
+    def _encode_obs(self, nobs: Dict[str, torch.Tensor], n_steps: int) -> torch.Tensor:
+        # flatten (B, T, ...) -> (B * T, ...) for the encoder, then restore time
+        flat_nobs = dict_apply(nobs, lambda x: x[:, :n_steps, ...].reshape(-1, *x.shape[2:]))
+        nobs_features = self.obs_encoder(flat_nobs)
+        return nobs_features.reshape(next(iter(nobs.values())).shape[0], n_steps, -1)
+
+    def _time_view(self, value: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
+        # reshape (B,) times to (B, 1, ..., 1) for broadcasting against ref
+        return value.reshape(value.shape[0], *([1] * (ref.ndim - 1)))
+
+    def _trajectory_inputs(
+        self,
+        nobs: Dict[str, torch.Tensor],
+        nactions: torch.Tensor,
+    ):
+        batch_size = nactions.shape[0]
+        horizon = nactions.shape[1]
+        cond = None
+        trajectory = nactions
+        if self.obs_as_cond:
+            cond = self._encode_obs(nobs, self.n_obs_steps)
+            if self.pred_action_steps_only:
+                start = self.n_obs_steps - 1
+                end = start + self.n_action_steps
+                trajectory = nactions[:, start:end]
+        else:
+            nobs_features = self._encode_obs(nobs, horizon)
+            trajectory = torch.cat([nactions, nobs_features], dim=-1).detach()
+
+        if self.pred_action_steps_only:
+            condition_mask = torch.zeros_like(trajectory, dtype=torch.bool)
+        else:
+            condition_mask = self.mask_generator(trajectory.shape)
+
+        return batch_size, trajectory, cond, condition_mask
+
+    def _apply_conditioning(
+        self,
+        sample: torch.Tensor,
+        condition_data: torch.Tensor,
+        condition_mask: torch.Tensor,
+    ) -> torch.Tensor:
+        if condition_mask.any():
+            sample = sample.clone()
+            sample[condition_mask] = condition_data[condition_mask]
+        return sample
+
+    def _compute_u_v(
+        self,
+        sample: torch.Tensor,
+        t: torch.Tensor,
+        r: torch.Tensor,
+        cond: torch.Tensor,
+    ):
+        # both heads predict x_hat; convert to velocities via (x_t - x_hat) / t,
+        # with t clamped to min_time to keep the division stable near t = 0
+        x_hat_u, x_hat_v = self.model(sample, t, r, cond)
+        denom = self._time_view(t.clamp_min(self.min_time), sample)
+        u = (sample - x_hat_u) / denom
+        v = (sample - x_hat_v) / denom
+        return u, v
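+
+    # The total derivative du/dt along the flow is approximated with a forward
+    # finite difference rather than torch.func.jvp: the state is nudged by
+    # step * v (the v-head serving as the tangent) while t advances by the same
+    # step. This costs one extra forward pass but avoids differentiating
+    # through the obs-conditioned transformer.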
+    def _compute_du_dt(
+        self,
+        sample: torch.Tensor,
+        t: torch.Tensor,
+        r: torch.Tensor,
+        cond: torch.Tensor,
+        condition_data: torch.Tensor,
+        condition_mask: torch.Tensor,
+        tangent_v: torch.Tensor,
+    ) -> torch.Tensor:
+        step = torch.full_like(t, self.du_dt_epsilon)
+        sample_base = self._apply_conditioning(sample, condition_data, condition_mask)
+        u_base, _ = self._compute_u_v(sample_base, t, r, cond)
+
+        t_next = (t + step).clamp(max=1.0)
+        # use the post-clamp step so the difference quotient stays consistent
+        # when t is within du_dt_epsilon of 1
+        step = (t_next - t).clamp_min(torch.finfo(t.dtype).eps)
+        sample_next = sample + self._time_view(step, sample) * tangent_v.detach()
+        sample_next = self._apply_conditioning(sample_next, condition_data, condition_mask)
+        u_next, _ = self._compute_u_v(sample_next, t_next, r, cond)
+
+        return (u_next - u_base) / self._time_view(step, sample)
+
+    # ========= inference ============
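+    # Sampling integrates from t = 1 (noise) to t = 0 on a uniform grid.
+    # Because u is trained as an average velocity over [r, t], the update
+    # x_r = x_t - (t - r) * u is self-consistent for any interval size; the
+    # loop below simply applies it over num_inference_steps sub-intervals.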
+    def conditional_sample(
+        self,
+        condition_data,
+        condition_mask,
+        cond=None,
+        generator=None,
+        **kwargs,
+    ):
+        del kwargs
+
+        trajectory = torch.randn(
+            size=condition_data.shape,
+            dtype=condition_data.dtype,
+            device=condition_data.device,
+            generator=generator,
+        ) * self.noise_scale
+
+        time_steps = torch.linspace(
+            1.0,
+            0.0,
+            self.num_inference_steps + 1,
+            dtype=trajectory.dtype,
+            device=trajectory.device,
+        )
+
+        for step_idx in range(self.num_inference_steps):
+            trajectory = self._apply_conditioning(trajectory, condition_data, condition_mask)
+            t = time_steps[step_idx].expand(trajectory.shape[0])
+            r = time_steps[step_idx + 1].expand(trajectory.shape[0])
+            u, _ = self._compute_u_v(trajectory, t, r, cond)
+            delta = self._time_view(t - r, trajectory)
+            trajectory = trajectory - delta * u
+
+        trajectory = self._apply_conditioning(trajectory, condition_data, condition_mask)
+        return trajectory
+
+    def predict_action(self, obs_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+        assert "past_action" not in obs_dict
+        nobs = self.normalizer.normalize(obs_dict)
+        value = next(iter(nobs.values()))
+        batch_size = value.shape[0]
+        horizon = self.horizon
+        action_dim = self.action_dim
+
+        device = self.device
+        dtype = self.dtype
+        cond = None
+        if self.obs_as_cond:
+            cond = self._encode_obs(nobs, self.n_obs_steps)
+            shape = (batch_size, horizon, action_dim)
+            if self.pred_action_steps_only:
+                shape = (batch_size, self.n_action_steps, action_dim)
+            cond_data = torch.zeros(size=shape, device=device, dtype=dtype)
+            cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)
+        else:
+            nobs_features = self._encode_obs(nobs, self.n_obs_steps)
+            shape = (batch_size, horizon, action_dim + self.obs_feature_dim)
+            cond_data = torch.zeros(size=shape, device=device, dtype=dtype)
+            cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)
+            cond_data[:, : self.n_obs_steps, action_dim:] = nobs_features
+            cond_mask[:, : self.n_obs_steps, action_dim:] = True
+
+        nsample = self.conditional_sample(
+            cond_data,
+            cond_mask,
+            cond=cond,
+            **self.kwargs,
+        )
+
+        naction_pred = nsample[..., :action_dim]
+        action_pred = self.normalizer["action"].unnormalize(naction_pred)
+        if self.pred_action_steps_only:
+            action = action_pred
+        else:
+            start = self.n_obs_steps - 1
+            end = start + self.n_action_steps
+            action = action_pred[:, start:end]
+        return {
+            "action": action,
+            "action_pred": action_pred,
+        }
+
+    # ========= training ============
+    def set_normalizer(self, normalizer: LinearNormalizer):
+        self.normalizer.load_state_dict(normalizer.state_dict())
+
+    def get_optimizer(
+        self,
+        transformer_weight_decay: float,
+        obs_encoder_weight_decay: float,
+        learning_rate: float,
+        betas: Tuple[float, float],
+    ) -> torch.optim.Optimizer:
+        optim_groups = self.model.get_optim_groups(weight_decay=transformer_weight_decay)
+        optim_groups.append(
+            {
+                "params": self.obs_encoder.parameters(),
+                "weight_decay": obs_encoder_weight_decay,
+            }
+        )
+        return torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas)
+
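+    # Training objective (MeanFlow-style): sample t ~ U[0, 1], r ~ U[0, t],
+    # form z_t = (1 - t) * x + t * eps, and regress the v-head onto the
+    # instantaneous velocity (z_t - x) / t = eps - x. The u-head is trained
+    # through the average-velocity identity u + (t - r) * du/dt = v, with
+    # du/dt estimated by _compute_du_dt and detached so it acts as a fixed
+    # target.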
+    def compute_loss(self, batch):
+        assert "valid_mask" not in batch
+        nobs = self.normalizer.normalize(batch["obs"])
+        nactions = self.normalizer["action"].normalize(batch["action"])
+
+        _, trajectory, cond, condition_mask = self._trajectory_inputs(nobs, nactions)
+        noise = torch.randn_like(trajectory) * self.noise_scale
+        batch_size = trajectory.shape[0]
+
+        t = torch.rand(batch_size, device=trajectory.device, dtype=trajectory.dtype)
+        r = torch.rand(batch_size, device=trajectory.device, dtype=trajectory.dtype) * t  # r <= t
+        z_t = (1 - self._time_view(t, trajectory)) * trajectory + self._time_view(t, trajectory) * noise
+        z_t = self._apply_conditioning(z_t, trajectory, condition_mask)
+
+        loss_mask = ~condition_mask
+        denom = self._time_view(t.clamp_min(self.min_time), trajectory)
+        target_v = (z_t - trajectory) / denom
+
+        u, v = self._compute_u_v(z_t, t, r, cond)
+        du_dt = self._compute_du_dt(
+            sample=z_t,
+            t=t,
+            r=r,
+            cond=cond,
+            condition_data=trajectory,
+            condition_mask=condition_mask,
+            tangent_v=v,
+        )
+        pmf_velocity = u + self._time_view(t - r, trajectory) * du_dt.detach()
+
+        loss_u = F.mse_loss(pmf_velocity, target_v, reduction="none")
+        loss_v = F.mse_loss(v, target_v, reduction="none")
+        loss_u = loss_u * loss_mask.type(loss_u.dtype)
+        loss_v = loss_v * loss_mask.type(loss_v.dtype)
+        loss_u = reduce(loss_u, "b ... -> b (...)", "mean").mean()
+        loss_v = reduce(loss_v, "b ... -> b (...)", "mean").mean()
+        return self.pmf_u_loss_weight * loss_u + self.pmf_v_loss_weight * loss_v
diff --git a/image_pusht_diffusion_policy_dit_pmf.yaml b/image_pusht_diffusion_policy_dit_pmf.yaml
new file mode 100644
index 0000000..3123932
--- /dev/null
+++ b/image_pusht_diffusion_policy_dit_pmf.yaml
@@ -0,0 +1,184 @@
+_target_: diffusion_policy.workspace.train_diffusion_transformer_hybrid_workspace.TrainDiffusionTransformerHybridWorkspace
+checkpoint:
+  save_last_ckpt: true
+  save_last_snapshot: false
+  topk:
+    format_str: epoch={epoch:04d}-train_loss={train_loss:.3f}.ckpt
+    k: 5
+    mode: min
+    monitor_key: train_loss
+dataloader:
+  batch_size: 64
+  num_workers: 8
+  persistent_workers: false
+  pin_memory: true
+  shuffle: true
+dataset_obs_steps: 2
+ema:
+  _target_: diffusion_policy.model.diffusion.ema_model.EMAModel
+  inv_gamma: 1.0
+  max_value: 0.9999
+  min_value: 0.0
+  power: 0.75
+  update_after_step: 0
+exp_name: default
+horizon: 16
+keypoint_visible_rate: 1.0
+logging:
+  group: null
+  id: null
+  mode: online
+  name: ${now:%Y.%m.%d-%H.%M.%S}_train_diffusion_transformer_hybrid_pmf_pusht_image
+  project: diffusion_policy_debug
+  resume: true
+  tags:
+  - train_diffusion_transformer_hybrid_pmf
+  - pusht_image
+  - default
+multi_run:
+  run_dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_train_diffusion_transformer_hybrid_pmf_pusht_image
+  wandb_name_base: ${now:%Y.%m.%d-%H.%M.%S}_train_diffusion_transformer_hybrid_pmf_pusht_image
+n_action_steps: 8
+n_latency_steps: 0
+n_obs_steps: 2
+name: train_diffusion_transformer_hybrid_pmf
+obs_as_cond: true
+optimizer:
+  betas:
+  - 0.9
+  - 0.95
+  learning_rate: 0.0001
+  obs_encoder_weight_decay: 1.0e-06
+  transformer_weight_decay: 0.001
+past_action_visible: false
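+# PMF-specific keys below: min_time clamps t in the (x_t - x_hat) / t velocity
+# conversion, n_time_tokens sets how many learned tokens encode each of t and
+# r, noise_scale is the std of the Gaussian prior at t = 1, and
+# num_inference_steps is the size of the Euler grid used by
+# conditional_sample. du_dt_epsilon keeps its code default (1.0e-3).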
+policy:
+  _target_: diffusion_policy.policy.pmf_transformer_hybrid_image_policy.PMFTransformerHybridImagePolicy
+  crop_shape:
+  - 84
+  - 84
+  eval_fixed_crop: true
+  horizon: 16
+  min_time: 0.05
+  n_action_steps: 8
+  n_cond_layers: 0
+  n_emb: 256
+  n_head: 4
+  n_layer: 8
+  n_obs_steps: 2
+  n_time_tokens: 4
+  noise_scale: 1.0
+  noise_scheduler:
+    _target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler
+    beta_end: 0.02
+    beta_schedule: squaredcos_cap_v2
+    beta_start: 0.0001
+    clip_sample: true
+    num_train_timesteps: 100
+    prediction_type: sample
+    variance_type: fixed_small
+  num_inference_steps: 32
+  obs_as_cond: true
+  obs_encoder_group_norm: true
+  p_drop_attn: 0.0
+  p_drop_emb: 0.0
+  pmf_u_loss_weight: 1.0
+  pmf_v_loss_weight: 1.0
+  shape_meta:
+    action:
+      shape:
+      - 2
+    obs:
+      agent_pos:
+        shape:
+        - 2
+        type: low_dim
+      image:
+        shape:
+        - 3
+        - 96
+        - 96
+        type: rgb
+shape_meta:
+  action:
+    shape:
+    - 2
+  obs:
+    agent_pos:
+      shape:
+      - 2
+      type: low_dim
+    image:
+      shape:
+      - 3
+      - 96
+      - 96
+      type: rgb
+task:
+  dataset:
+    _target_: diffusion_policy.dataset.pusht_image_dataset.PushTImageDataset
+    horizon: 16
+    max_train_episodes: 90
+    pad_after: 7
+    pad_before: 1
+    seed: 42
+    val_ratio: 0.02
+    zarr_path: data/pusht/pusht_cchi_v7_replay.zarr
+  env_runner:
+    _target_: diffusion_policy.env_runner.pusht_image_runner.PushTImageRunner
+    fps: 10
+    legacy_test: true
+    max_steps: 300
+    n_action_steps: 8
+    n_envs: null
+    n_obs_steps: 2
+    n_test: 50
+    n_test_vis: 4
+    n_train: 6
+    n_train_vis: 2
+    past_action: false
+    test_start_seed: 100000
+    train_start_seed: 0
+  image_shape:
+  - 3
+  - 96
+  - 96
+  name: pusht_image
+  shape_meta:
+    action:
+      shape:
+      - 2
+    obs:
+      agent_pos:
+        shape:
+        - 2
+        type: low_dim
+      image:
+        shape:
+        - 3
+        - 96
+        - 96
+        type: rgb
+task_name: pusht_image
+training:
+  checkpoint_every: 50
+  debug: false
+  device: cuda:0
+  gradient_accumulate_every: 1
+  lr_scheduler: cosine
+  lr_warmup_steps: 500
+  max_train_steps: null
+  max_val_steps: null
+  num_epochs: 600
+  resume: true
+  rollout_every: 50
+  sample_every: 5
+  seed: 42
+  tqdm_interval_sec: 1.0
+  use_ema: true
+  val_every: 1
+val_dataloader:
+  batch_size: 64
+  num_workers: 8
+  persistent_workers: false
+  pin_memory: true
+  shuffle: false
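+# Launch sketch (assumes the standard diffusion_policy train.py entrypoint and
+# that this file sits in a directory Hydra can see; adjust paths as needed):
+#   python train.py --config-dir=. \
+#     --config-name=image_pusht_diffusion_policy_dit_pmf.yaml \
+#     training.seed=42 training.device=cuda:0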