Compare commits
4 Commits
4cd5085b33
...
dualhead_u
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9169e4d7e0 | ||
|
|
42dc29a2cb | ||
|
|
79f31940c4 | ||
|
|
2aa06c8917 |
@@ -8,8 +8,7 @@ import dill
|
|||||||
import math
|
import math
|
||||||
import wandb.sdk.data_types.video as wv
|
import wandb.sdk.data_types.video as wv
|
||||||
from diffusion_policy.env.pusht.pusht_keypoints_env import PushTKeypointsEnv
|
from diffusion_policy.env.pusht.pusht_keypoints_env import PushTKeypointsEnv
|
||||||
from diffusion_policy.gym_util.async_vector_env import AsyncVectorEnv
|
from diffusion_policy.gym_util.sync_vector_env import SyncVectorEnv
|
||||||
# from diffusion_policy.gym_util.sync_vector_env import SyncVectorEnv
|
|
||||||
from diffusion_policy.gym_util.multistep_wrapper import MultiStepWrapper
|
from diffusion_policy.gym_util.multistep_wrapper import MultiStepWrapper
|
||||||
from diffusion_policy.gym_util.video_recording_wrapper import VideoRecordingWrapper, VideoRecorder
|
from diffusion_policy.gym_util.video_recording_wrapper import VideoRecordingWrapper, VideoRecorder
|
||||||
|
|
||||||
@@ -133,7 +132,7 @@ class PushTKeypointsRunner(BaseLowdimRunner):
|
|||||||
env_prefixs.append('test/')
|
env_prefixs.append('test/')
|
||||||
env_init_fn_dills.append(dill.dumps(init_fn))
|
env_init_fn_dills.append(dill.dumps(init_fn))
|
||||||
|
|
||||||
env = AsyncVectorEnv(env_fns)
|
env = SyncVectorEnv(env_fns)
|
||||||
|
|
||||||
# test env
|
# test env
|
||||||
# env.reset(seed=env_seeds)
|
# env.reset(seed=env_seeds)
|
||||||
|
|||||||
@@ -0,0 +1,302 @@
|
|||||||
|
from typing import Optional, Tuple, Union
|
||||||
|
import logging
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from diffusion_policy.model.common.module_attr_mixin import ModuleAttrMixin
|
||||||
|
from diffusion_policy.model.diffusion.positional_embedding import SinusoidalPosEmb
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class PMFTransformerForDiffusion(ModuleAttrMixin):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
input_dim: int,
|
||||||
|
output_dim: int,
|
||||||
|
horizon: int,
|
||||||
|
n_obs_steps: Optional[int] = None,
|
||||||
|
cond_dim: int = 0,
|
||||||
|
n_layer: int = 12,
|
||||||
|
n_head: int = 12,
|
||||||
|
n_emb: int = 768,
|
||||||
|
p_drop_emb: float = 0.1,
|
||||||
|
p_drop_attn: float = 0.1,
|
||||||
|
causal_attn: bool = False,
|
||||||
|
obs_as_cond: bool = False,
|
||||||
|
n_cond_layers: int = 0,
|
||||||
|
n_time_tokens: int = 4,
|
||||||
|
n_head_layers: int = 4,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
if n_obs_steps is None:
|
||||||
|
n_obs_steps = horizon
|
||||||
|
if n_time_tokens < 1:
|
||||||
|
raise ValueError("n_time_tokens must be >= 1")
|
||||||
|
if n_head_layers < 0:
|
||||||
|
raise ValueError("n_head_layers must be >= 0")
|
||||||
|
if n_head_layers >= n_layer:
|
||||||
|
raise ValueError(
|
||||||
|
"n_head_layers must be smaller than n_layer so shared trunk depth stays positive"
|
||||||
|
)
|
||||||
|
|
||||||
|
obs_as_cond = cond_dim > 0
|
||||||
|
T = horizon
|
||||||
|
n_global_cond_tokens = 2 * n_time_tokens
|
||||||
|
T_cond = n_global_cond_tokens + (n_obs_steps if obs_as_cond else 0)
|
||||||
|
n_shared_layers = n_layer - n_head_layers
|
||||||
|
|
||||||
|
self.input_emb = nn.Linear(input_dim, n_emb)
|
||||||
|
self.pos_emb = nn.Parameter(torch.zeros(1, T, n_emb))
|
||||||
|
self.drop = nn.Dropout(p_drop_emb)
|
||||||
|
|
||||||
|
self.t_emb = SinusoidalPosEmb(n_emb)
|
||||||
|
self.r_emb = SinusoidalPosEmb(n_emb)
|
||||||
|
self.t_tokens = nn.Parameter(torch.zeros(1, n_time_tokens, n_emb))
|
||||||
|
self.r_tokens = nn.Parameter(torch.zeros(1, n_time_tokens, n_emb))
|
||||||
|
self.cond_obs_emb = nn.Linear(cond_dim, n_emb) if obs_as_cond else None
|
||||||
|
self.cond_pos_emb = nn.Parameter(torch.zeros(1, T_cond, n_emb))
|
||||||
|
|
||||||
|
if n_cond_layers > 0:
|
||||||
|
encoder_layer = nn.TransformerEncoderLayer(
|
||||||
|
d_model=n_emb,
|
||||||
|
nhead=n_head,
|
||||||
|
dim_feedforward=4 * n_emb,
|
||||||
|
dropout=p_drop_attn,
|
||||||
|
activation="gelu",
|
||||||
|
batch_first=True,
|
||||||
|
norm_first=True,
|
||||||
|
)
|
||||||
|
self.encoder = nn.TransformerEncoder(
|
||||||
|
encoder_layer=encoder_layer,
|
||||||
|
num_layers=n_cond_layers,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.encoder = nn.Sequential(
|
||||||
|
nn.Linear(n_emb, 4 * n_emb),
|
||||||
|
nn.Mish(),
|
||||||
|
nn.Linear(4 * n_emb, n_emb),
|
||||||
|
)
|
||||||
|
|
||||||
|
decoder_layer = nn.TransformerDecoderLayer(
|
||||||
|
d_model=n_emb,
|
||||||
|
nhead=n_head,
|
||||||
|
dim_feedforward=4 * n_emb,
|
||||||
|
dropout=p_drop_attn,
|
||||||
|
activation="gelu",
|
||||||
|
batch_first=True,
|
||||||
|
norm_first=True,
|
||||||
|
)
|
||||||
|
self.shared_decoder = nn.TransformerDecoder(
|
||||||
|
decoder_layer=decoder_layer,
|
||||||
|
num_layers=n_shared_layers,
|
||||||
|
)
|
||||||
|
self.u_decoder = nn.TransformerDecoder(
|
||||||
|
decoder_layer=decoder_layer,
|
||||||
|
num_layers=n_head_layers,
|
||||||
|
)
|
||||||
|
self.v_decoder = nn.TransformerDecoder(
|
||||||
|
decoder_layer=decoder_layer,
|
||||||
|
num_layers=n_head_layers,
|
||||||
|
)
|
||||||
|
|
||||||
|
if causal_attn:
|
||||||
|
sz = T
|
||||||
|
mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
|
||||||
|
mask = mask.float().masked_fill(mask == 0, float("-inf")).masked_fill(mask == 1, float(0.0))
|
||||||
|
self.register_buffer("mask", mask)
|
||||||
|
|
||||||
|
if obs_as_cond:
|
||||||
|
q_idx, c_idx = torch.meshgrid(
|
||||||
|
torch.arange(T),
|
||||||
|
torch.arange(T_cond),
|
||||||
|
indexing="ij",
|
||||||
|
)
|
||||||
|
obs_offset = n_global_cond_tokens
|
||||||
|
visible = c_idx < obs_offset
|
||||||
|
visible = visible | (q_idx >= (c_idx - obs_offset))
|
||||||
|
memory_mask = visible.float().masked_fill(~visible, float("-inf")).masked_fill(visible, float(0.0))
|
||||||
|
self.register_buffer("memory_mask", memory_mask)
|
||||||
|
else:
|
||||||
|
self.memory_mask = None
|
||||||
|
else:
|
||||||
|
self.mask = None
|
||||||
|
self.memory_mask = None
|
||||||
|
|
||||||
|
self.ln_u = nn.LayerNorm(n_emb)
|
||||||
|
self.ln_v = nn.LayerNorm(n_emb)
|
||||||
|
self.head_u = nn.Linear(n_emb, output_dim)
|
||||||
|
self.head_v = nn.Linear(n_emb, output_dim)
|
||||||
|
|
||||||
|
self.T = T
|
||||||
|
self.T_cond = T_cond
|
||||||
|
self.horizon = horizon
|
||||||
|
self.n_obs_steps = n_obs_steps
|
||||||
|
self.obs_as_cond = obs_as_cond
|
||||||
|
self.n_global_cond_tokens = n_global_cond_tokens
|
||||||
|
self.n_time_tokens = n_time_tokens
|
||||||
|
self.n_layer = n_layer
|
||||||
|
self.n_head_layers = n_head_layers
|
||||||
|
self.n_shared_layers = n_shared_layers
|
||||||
|
|
||||||
|
self.apply(self._init_weights)
|
||||||
|
logger.info(
|
||||||
|
"number of parameters: %e", sum(p.numel() for p in self.parameters())
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
"PMFTransformerForDiffusion layers: shared=%d u_head=%d v_head=%d",
|
||||||
|
self.n_shared_layers,
|
||||||
|
self.n_head_layers,
|
||||||
|
self.n_head_layers,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _init_weights(self, module):
|
||||||
|
ignore_types = (
|
||||||
|
nn.Dropout,
|
||||||
|
SinusoidalPosEmb,
|
||||||
|
nn.TransformerEncoderLayer,
|
||||||
|
nn.TransformerDecoderLayer,
|
||||||
|
nn.TransformerEncoder,
|
||||||
|
nn.TransformerDecoder,
|
||||||
|
nn.ModuleList,
|
||||||
|
nn.Mish,
|
||||||
|
nn.Sequential,
|
||||||
|
)
|
||||||
|
if isinstance(module, (nn.Linear, nn.Embedding)):
|
||||||
|
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
||||||
|
if isinstance(module, nn.Linear) and module.bias is not None:
|
||||||
|
torch.nn.init.zeros_(module.bias)
|
||||||
|
elif isinstance(module, nn.MultiheadAttention):
|
||||||
|
for name in ("in_proj_weight", "q_proj_weight", "k_proj_weight", "v_proj_weight"):
|
||||||
|
weight = getattr(module, name)
|
||||||
|
if weight is not None:
|
||||||
|
torch.nn.init.normal_(weight, mean=0.0, std=0.02)
|
||||||
|
for name in ("in_proj_bias", "bias_k", "bias_v"):
|
||||||
|
bias = getattr(module, name)
|
||||||
|
if bias is not None:
|
||||||
|
torch.nn.init.zeros_(bias)
|
||||||
|
elif isinstance(module, nn.LayerNorm):
|
||||||
|
torch.nn.init.zeros_(module.bias)
|
||||||
|
torch.nn.init.ones_(module.weight)
|
||||||
|
elif isinstance(module, PMFTransformerForDiffusion):
|
||||||
|
torch.nn.init.normal_(module.pos_emb, mean=0.0, std=0.02)
|
||||||
|
torch.nn.init.normal_(module.cond_pos_emb, mean=0.0, std=0.02)
|
||||||
|
torch.nn.init.normal_(module.t_tokens, mean=0.0, std=0.02)
|
||||||
|
torch.nn.init.normal_(module.r_tokens, mean=0.0, std=0.02)
|
||||||
|
elif isinstance(module, ignore_types):
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
raise RuntimeError("Unaccounted module {}".format(module))
|
||||||
|
|
||||||
|
def get_optim_groups(self, weight_decay: float = 1e-3):
|
||||||
|
decay = set()
|
||||||
|
no_decay = set()
|
||||||
|
whitelist_weight_modules = (torch.nn.Linear, torch.nn.MultiheadAttention)
|
||||||
|
blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
|
||||||
|
for mn, m in self.named_modules():
|
||||||
|
for pn, _ in m.named_parameters():
|
||||||
|
fpn = "%s.%s" % (mn, pn) if mn else pn
|
||||||
|
if pn.endswith("bias"):
|
||||||
|
no_decay.add(fpn)
|
||||||
|
elif pn.startswith("bias"):
|
||||||
|
no_decay.add(fpn)
|
||||||
|
elif pn.endswith("weight") and isinstance(m, whitelist_weight_modules):
|
||||||
|
decay.add(fpn)
|
||||||
|
elif pn.endswith("weight") and isinstance(m, blacklist_weight_modules):
|
||||||
|
no_decay.add(fpn)
|
||||||
|
|
||||||
|
no_decay.update(
|
||||||
|
{
|
||||||
|
"pos_emb",
|
||||||
|
"cond_pos_emb",
|
||||||
|
"t_tokens",
|
||||||
|
"r_tokens",
|
||||||
|
"_dummy_variable",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
param_dict = {pn: p for pn, p in self.named_parameters()}
|
||||||
|
inter_params = decay & no_decay
|
||||||
|
union_params = decay | no_decay
|
||||||
|
assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params),)
|
||||||
|
assert len(param_dict.keys() - union_params) == 0, (
|
||||||
|
"parameters %s were not separated into either decay/no_decay set!" % (str(param_dict.keys() - union_params),)
|
||||||
|
)
|
||||||
|
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"params": [param_dict[pn] for pn in sorted(list(decay))],
|
||||||
|
"weight_decay": weight_decay,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"params": [param_dict[pn] for pn in sorted(list(no_decay))],
|
||||||
|
"weight_decay": 0.0,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
def configure_optimizers(
|
||||||
|
self,
|
||||||
|
learning_rate: float = 1e-4,
|
||||||
|
weight_decay: float = 1e-3,
|
||||||
|
betas: Tuple[float, float] = (0.9, 0.95),
|
||||||
|
):
|
||||||
|
optim_groups = self.get_optim_groups(weight_decay=weight_decay)
|
||||||
|
return torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas)
|
||||||
|
|
||||||
|
def _broadcast_time(self, value: Union[torch.Tensor, float, int], batch_size: int, device: torch.device):
|
||||||
|
if not torch.is_tensor(value):
|
||||||
|
value = torch.tensor([value], dtype=torch.float32, device=device)
|
||||||
|
elif value.ndim == 0:
|
||||||
|
value = value[None].to(device=device, dtype=torch.float32)
|
||||||
|
else:
|
||||||
|
value = value.to(device=device, dtype=torch.float32)
|
||||||
|
return value.expand(batch_size)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
sample: torch.Tensor,
|
||||||
|
t: Union[torch.Tensor, float, int],
|
||||||
|
r: Union[torch.Tensor, float, int],
|
||||||
|
cond: Optional[torch.Tensor] = None,
|
||||||
|
):
|
||||||
|
batch_size = sample.shape[0]
|
||||||
|
device = sample.device
|
||||||
|
t = self._broadcast_time(t, batch_size, device)
|
||||||
|
r = self._broadcast_time(r, batch_size, device)
|
||||||
|
|
||||||
|
input_emb = self.input_emb(sample)
|
||||||
|
|
||||||
|
t_cond = self.t_tokens + self.t_emb(t).unsqueeze(1)
|
||||||
|
r_cond = self.r_tokens + self.r_emb(r).unsqueeze(1)
|
||||||
|
cond_embeddings = [t_cond, r_cond]
|
||||||
|
if self.obs_as_cond:
|
||||||
|
cond_embeddings.append(self.cond_obs_emb(cond))
|
||||||
|
cond_embeddings = torch.cat(cond_embeddings, dim=1)
|
||||||
|
|
||||||
|
cond_pos = self.cond_pos_emb[:, : cond_embeddings.shape[1], :]
|
||||||
|
memory = self.drop(cond_embeddings + cond_pos)
|
||||||
|
memory = self.encoder(memory)
|
||||||
|
|
||||||
|
token_pos = self.pos_emb[:, : input_emb.shape[1], :]
|
||||||
|
x = self.drop(input_emb + token_pos)
|
||||||
|
shared_x = self.shared_decoder(
|
||||||
|
tgt=x,
|
||||||
|
memory=memory,
|
||||||
|
tgt_mask=self.mask,
|
||||||
|
memory_mask=self.memory_mask,
|
||||||
|
)
|
||||||
|
u_x = self.u_decoder(
|
||||||
|
tgt=shared_x,
|
||||||
|
memory=memory,
|
||||||
|
tgt_mask=self.mask,
|
||||||
|
memory_mask=self.memory_mask,
|
||||||
|
)
|
||||||
|
v_x = self.v_decoder(
|
||||||
|
tgt=shared_x,
|
||||||
|
memory=memory,
|
||||||
|
tgt_mask=self.mask,
|
||||||
|
memory_mask=self.memory_mask,
|
||||||
|
)
|
||||||
|
return self.head_u(self.ln_u(u_x)), self.head_v(self.ln_v(v_x))
|
||||||
455
diffusion_policy/policy/pmf_transformer_hybrid_image_policy.py
Normal file
455
diffusion_policy/policy/pmf_transformer_hybrid_image_policy.py
Normal file
@@ -0,0 +1,455 @@
|
|||||||
|
from typing import Dict, Tuple
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from einops import reduce
|
||||||
|
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
|
||||||
|
|
||||||
|
import diffusion_policy.model.vision.crop_randomizer as dmvc
|
||||||
|
import robomimic.models.base_nets as rmbn
|
||||||
|
import robomimic.utils.obs_utils as ObsUtils
|
||||||
|
from diffusion_policy.common.pytorch_util import dict_apply, replace_submodules
|
||||||
|
from diffusion_policy.common.robomimic_config_util import get_robomimic_config
|
||||||
|
from diffusion_policy.model.common.normalizer import LinearNormalizer
|
||||||
|
from diffusion_policy.model.diffusion.mask_generator import LowdimMaskGenerator
|
||||||
|
from diffusion_policy.model.diffusion.pmf_transformer_for_diffusion import (
|
||||||
|
PMFTransformerForDiffusion,
|
||||||
|
)
|
||||||
|
from diffusion_policy.policy.base_image_policy import BaseImagePolicy
|
||||||
|
from robomimic.algo import algo_factory
|
||||||
|
from robomimic.algo.algo import PolicyAlgo
|
||||||
|
|
||||||
|
|
||||||
|
class PMFTransformerHybridImagePolicy(BaseImagePolicy):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
shape_meta: dict,
|
||||||
|
noise_scheduler: DDPMScheduler,
|
||||||
|
horizon,
|
||||||
|
n_action_steps,
|
||||||
|
n_obs_steps,
|
||||||
|
num_inference_steps=None,
|
||||||
|
crop_shape=(76, 76),
|
||||||
|
obs_encoder_group_norm=False,
|
||||||
|
eval_fixed_crop=False,
|
||||||
|
n_layer=8,
|
||||||
|
n_cond_layers=0,
|
||||||
|
n_head=4,
|
||||||
|
n_emb=256,
|
||||||
|
p_drop_emb=0.0,
|
||||||
|
p_drop_attn=0.0,
|
||||||
|
causal_attn=True,
|
||||||
|
obs_as_cond=True,
|
||||||
|
pred_action_steps_only=False,
|
||||||
|
n_time_tokens=4,
|
||||||
|
n_head_layers=4,
|
||||||
|
min_time=0.05,
|
||||||
|
du_dt_epsilon=1.0e-3,
|
||||||
|
pmf_u_loss_weight=1.0,
|
||||||
|
pmf_v_loss_weight=1.0,
|
||||||
|
noise_scale=1.0,
|
||||||
|
adatloss_eps=0.01,
|
||||||
|
p_mean=-0.4,
|
||||||
|
p_std=1.0,
|
||||||
|
tr_uniform=True,
|
||||||
|
tr_uniform_prob=0.1,
|
||||||
|
data_proportion=0.5,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
action_shape = shape_meta["action"]["shape"]
|
||||||
|
assert len(action_shape) == 1
|
||||||
|
action_dim = action_shape[0]
|
||||||
|
obs_shape_meta = shape_meta["obs"]
|
||||||
|
obs_config = {
|
||||||
|
"low_dim": [],
|
||||||
|
"rgb": [],
|
||||||
|
"depth": [],
|
||||||
|
"scan": [],
|
||||||
|
}
|
||||||
|
obs_key_shapes = dict()
|
||||||
|
for key, attr in obs_shape_meta.items():
|
||||||
|
shape = attr["shape"]
|
||||||
|
obs_key_shapes[key] = list(shape)
|
||||||
|
|
||||||
|
obs_type = attr.get("type", "low_dim")
|
||||||
|
if obs_type == "rgb":
|
||||||
|
obs_config["rgb"].append(key)
|
||||||
|
elif obs_type == "low_dim":
|
||||||
|
obs_config["low_dim"].append(key)
|
||||||
|
else:
|
||||||
|
raise RuntimeError(f"Unsupported obs type: {obs_type}")
|
||||||
|
|
||||||
|
config = get_robomimic_config(
|
||||||
|
algo_name="bc_rnn",
|
||||||
|
hdf5_type="image",
|
||||||
|
task_name="square",
|
||||||
|
dataset_type="ph",
|
||||||
|
)
|
||||||
|
|
||||||
|
with config.unlocked():
|
||||||
|
config.observation.modalities.obs = obs_config
|
||||||
|
|
||||||
|
if crop_shape is None:
|
||||||
|
for _, modality in config.observation.encoder.items():
|
||||||
|
if modality.obs_randomizer_class == "CropRandomizer":
|
||||||
|
modality["obs_randomizer_class"] = None
|
||||||
|
else:
|
||||||
|
crop_h, crop_w = crop_shape
|
||||||
|
for _, modality in config.observation.encoder.items():
|
||||||
|
if modality.obs_randomizer_class == "CropRandomizer":
|
||||||
|
modality.obs_randomizer_kwargs.crop_height = crop_h
|
||||||
|
modality.obs_randomizer_kwargs.crop_width = crop_w
|
||||||
|
|
||||||
|
ObsUtils.initialize_obs_utils_with_config(config)
|
||||||
|
|
||||||
|
policy: PolicyAlgo = algo_factory(
|
||||||
|
algo_name=config.algo_name,
|
||||||
|
config=config,
|
||||||
|
obs_key_shapes=obs_key_shapes,
|
||||||
|
ac_dim=action_dim,
|
||||||
|
device="cpu",
|
||||||
|
)
|
||||||
|
|
||||||
|
obs_encoder = policy.nets["policy"].nets["encoder"].nets["obs"]
|
||||||
|
if obs_encoder_group_norm:
|
||||||
|
replace_submodules(
|
||||||
|
root_module=obs_encoder,
|
||||||
|
predicate=lambda x: isinstance(x, nn.BatchNorm2d),
|
||||||
|
func=lambda x: nn.GroupNorm(
|
||||||
|
num_groups=x.num_features // 16,
|
||||||
|
num_channels=x.num_features,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
if eval_fixed_crop:
|
||||||
|
replace_submodules(
|
||||||
|
root_module=obs_encoder,
|
||||||
|
predicate=lambda x: isinstance(x, rmbn.CropRandomizer),
|
||||||
|
func=lambda x: dmvc.CropRandomizer(
|
||||||
|
input_shape=x.input_shape,
|
||||||
|
crop_height=x.crop_height,
|
||||||
|
crop_width=x.crop_width,
|
||||||
|
num_crops=x.num_crops,
|
||||||
|
pos_enc=x.pos_enc,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
obs_feature_dim = obs_encoder.output_shape()[0]
|
||||||
|
input_dim = action_dim if obs_as_cond else (obs_feature_dim + action_dim)
|
||||||
|
cond_dim = obs_feature_dim if obs_as_cond else 0
|
||||||
|
|
||||||
|
self.obs_encoder = obs_encoder
|
||||||
|
self.model = PMFTransformerForDiffusion(
|
||||||
|
input_dim=input_dim,
|
||||||
|
output_dim=input_dim,
|
||||||
|
horizon=horizon if not pred_action_steps_only else n_action_steps,
|
||||||
|
n_obs_steps=n_obs_steps,
|
||||||
|
cond_dim=cond_dim,
|
||||||
|
n_layer=n_layer,
|
||||||
|
n_head=n_head,
|
||||||
|
n_emb=n_emb,
|
||||||
|
p_drop_emb=p_drop_emb,
|
||||||
|
p_drop_attn=p_drop_attn,
|
||||||
|
causal_attn=causal_attn,
|
||||||
|
obs_as_cond=obs_as_cond,
|
||||||
|
n_cond_layers=n_cond_layers,
|
||||||
|
n_time_tokens=n_time_tokens,
|
||||||
|
n_head_layers=n_head_layers,
|
||||||
|
)
|
||||||
|
self.noise_scheduler = noise_scheduler
|
||||||
|
self.mask_generator = LowdimMaskGenerator(
|
||||||
|
action_dim=action_dim,
|
||||||
|
obs_dim=0 if obs_as_cond else obs_feature_dim,
|
||||||
|
max_n_obs_steps=n_obs_steps,
|
||||||
|
fix_obs_steps=True,
|
||||||
|
action_visible=False,
|
||||||
|
)
|
||||||
|
self.normalizer = LinearNormalizer()
|
||||||
|
self.horizon = horizon
|
||||||
|
self.obs_feature_dim = obs_feature_dim
|
||||||
|
self.action_dim = action_dim
|
||||||
|
self.n_action_steps = n_action_steps
|
||||||
|
self.n_obs_steps = n_obs_steps
|
||||||
|
self.obs_as_cond = obs_as_cond
|
||||||
|
self.pred_action_steps_only = pred_action_steps_only
|
||||||
|
self.min_time = min_time
|
||||||
|
self.du_dt_epsilon = du_dt_epsilon
|
||||||
|
self.pmf_u_loss_weight = pmf_u_loss_weight
|
||||||
|
self.pmf_v_loss_weight = pmf_v_loss_weight
|
||||||
|
self.noise_scale = noise_scale
|
||||||
|
self.adatloss_eps = adatloss_eps
|
||||||
|
self.p_mean = p_mean
|
||||||
|
self.p_std = p_std
|
||||||
|
self.tr_uniform = tr_uniform
|
||||||
|
self.tr_uniform_prob = tr_uniform_prob
|
||||||
|
self.data_proportion = data_proportion
|
||||||
|
self.kwargs = kwargs
|
||||||
|
|
||||||
|
if num_inference_steps is None:
|
||||||
|
num_inference_steps = noise_scheduler.config.num_train_timesteps
|
||||||
|
self.num_inference_steps = num_inference_steps
|
||||||
|
|
||||||
|
def _encode_obs(self, nobs: Dict[str, torch.Tensor], n_steps: int) -> torch.Tensor:
|
||||||
|
flat_nobs = dict_apply(nobs, lambda x: x[:, :n_steps, ...].reshape(-1, *x.shape[2:]))
|
||||||
|
nobs_features = self.obs_encoder(flat_nobs)
|
||||||
|
return nobs_features.reshape(next(iter(nobs.values())).shape[0], n_steps, -1)
|
||||||
|
|
||||||
|
def _time_view(self, value: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
|
||||||
|
return value.reshape(value.shape[0], *([1] * (ref.ndim - 1)))
|
||||||
|
|
||||||
|
def _adatloss(self, loss: torch.Tensor) -> torch.Tensor:
|
||||||
|
denom = loss.detach() + self.adatloss_eps
|
||||||
|
return loss / denom
|
||||||
|
|
||||||
|
def _sample_logit_normal(self, batch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
|
||||||
|
normal = torch.randn(batch_size, device=device, dtype=dtype)
|
||||||
|
return torch.sigmoid(normal * self.p_std + self.p_mean)
|
||||||
|
|
||||||
|
def _sample_tr(self, batch_size: int, device: torch.device, dtype: torch.dtype):
|
||||||
|
t = self._sample_logit_normal(batch_size, device, dtype)
|
||||||
|
r = self._sample_logit_normal(batch_size, device, dtype)
|
||||||
|
|
||||||
|
if self.tr_uniform:
|
||||||
|
uniform_mask = torch.rand(batch_size, device=device) < self.tr_uniform_prob
|
||||||
|
uniform_t = torch.rand(batch_size, device=device, dtype=dtype)
|
||||||
|
uniform_r = torch.rand(batch_size, device=device, dtype=dtype)
|
||||||
|
t = torch.where(uniform_mask, uniform_t, t)
|
||||||
|
r = torch.where(uniform_mask, uniform_r, r)
|
||||||
|
|
||||||
|
data_size = int(batch_size * self.data_proportion)
|
||||||
|
fm_mask = torch.arange(batch_size, device=device) < data_size
|
||||||
|
r = torch.where(fm_mask, t, r)
|
||||||
|
|
||||||
|
t_final = torch.maximum(t, r)
|
||||||
|
r_final = torch.minimum(t, r)
|
||||||
|
return t_final, r_final
|
||||||
|
|
||||||
|
def _trajectory_inputs(
|
||||||
|
self,
|
||||||
|
nobs: Dict[str, torch.Tensor],
|
||||||
|
nactions: torch.Tensor,
|
||||||
|
):
|
||||||
|
batch_size = nactions.shape[0]
|
||||||
|
horizon = nactions.shape[1]
|
||||||
|
cond = None
|
||||||
|
trajectory = nactions
|
||||||
|
if self.obs_as_cond:
|
||||||
|
cond = self._encode_obs(nobs, self.n_obs_steps)
|
||||||
|
if self.pred_action_steps_only:
|
||||||
|
start = self.n_obs_steps - 1
|
||||||
|
end = start + self.n_action_steps
|
||||||
|
trajectory = nactions[:, start:end]
|
||||||
|
else:
|
||||||
|
nobs_features = self._encode_obs(nobs, horizon)
|
||||||
|
trajectory = torch.cat([nactions, nobs_features], dim=-1).detach()
|
||||||
|
|
||||||
|
if self.pred_action_steps_only:
|
||||||
|
condition_mask = torch.zeros_like(trajectory, dtype=torch.bool)
|
||||||
|
else:
|
||||||
|
condition_mask = self.mask_generator(trajectory.shape)
|
||||||
|
|
||||||
|
return batch_size, trajectory, cond, condition_mask
|
||||||
|
|
||||||
|
def _apply_conditioning(
|
||||||
|
self,
|
||||||
|
sample: torch.Tensor,
|
||||||
|
condition_data: torch.Tensor,
|
||||||
|
condition_mask: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
if not condition_mask.any():
|
||||||
|
return sample
|
||||||
|
return torch.where(condition_mask, condition_data, sample)
|
||||||
|
|
||||||
|
def _compute_u_v(
|
||||||
|
self,
|
||||||
|
sample: torch.Tensor,
|
||||||
|
t: torch.Tensor,
|
||||||
|
r: torch.Tensor,
|
||||||
|
cond: torch.Tensor,
|
||||||
|
):
|
||||||
|
x_hat_u, x_hat_v = self.model(sample, t, r, cond)
|
||||||
|
denom = self._time_view(t, sample)
|
||||||
|
u = (sample - x_hat_u) / denom
|
||||||
|
v = (sample - x_hat_v) / denom
|
||||||
|
return u, v
|
||||||
|
|
||||||
|
def _compute_du_dt(
|
||||||
|
self,
|
||||||
|
sample: torch.Tensor,
|
||||||
|
t: torch.Tensor,
|
||||||
|
r: torch.Tensor,
|
||||||
|
cond: torch.Tensor,
|
||||||
|
condition_data: torch.Tensor,
|
||||||
|
condition_mask: torch.Tensor,
|
||||||
|
tangent_v: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
tangent_sample = tangent_v.detach()
|
||||||
|
tangent_r = torch.zeros_like(r)
|
||||||
|
tangent_t = torch.ones_like(t)
|
||||||
|
|
||||||
|
def u_fn(sample_input, r_input, t_input):
|
||||||
|
conditioned_sample = self._apply_conditioning(
|
||||||
|
sample_input, condition_data, condition_mask
|
||||||
|
)
|
||||||
|
u_value, _ = self._compute_u_v(conditioned_sample, t_input, r_input, cond)
|
||||||
|
return u_value
|
||||||
|
|
||||||
|
primals = (sample, r, t)
|
||||||
|
tangents = (tangent_sample, tangent_r, tangent_t)
|
||||||
|
try:
|
||||||
|
_, du_dt = torch.func.jvp(u_fn, primals, tangents)
|
||||||
|
except (AttributeError, NotImplementedError, RuntimeError):
|
||||||
|
_, du_dt = torch.autograd.functional.jvp(
|
||||||
|
u_fn,
|
||||||
|
primals,
|
||||||
|
tangents,
|
||||||
|
create_graph=False,
|
||||||
|
strict=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
return du_dt
|
||||||
|
|
||||||
|
# ========= inference ============
|
||||||
|
def conditional_sample(
|
||||||
|
self,
|
||||||
|
condition_data,
|
||||||
|
condition_mask,
|
||||||
|
cond=None,
|
||||||
|
generator=None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
del kwargs
|
||||||
|
|
||||||
|
trajectory = torch.randn(
|
||||||
|
size=condition_data.shape,
|
||||||
|
dtype=condition_data.dtype,
|
||||||
|
device=condition_data.device,
|
||||||
|
generator=generator,
|
||||||
|
) * self.noise_scale
|
||||||
|
|
||||||
|
time_steps = torch.linspace(
|
||||||
|
1.0,
|
||||||
|
0.0,
|
||||||
|
self.num_inference_steps + 1,
|
||||||
|
dtype=trajectory.dtype,
|
||||||
|
device=trajectory.device,
|
||||||
|
)
|
||||||
|
|
||||||
|
for step_idx in range(self.num_inference_steps):
|
||||||
|
trajectory = self._apply_conditioning(trajectory, condition_data, condition_mask)
|
||||||
|
t = time_steps[step_idx].expand(trajectory.shape[0])
|
||||||
|
r = time_steps[step_idx + 1].expand(trajectory.shape[0])
|
||||||
|
u, _ = self._compute_u_v(trajectory, t, r, cond)
|
||||||
|
delta = self._time_view(t - r, trajectory)
|
||||||
|
trajectory = trajectory - delta * u
|
||||||
|
|
||||||
|
trajectory = self._apply_conditioning(trajectory, condition_data, condition_mask)
|
||||||
|
return trajectory
|
||||||
|
|
||||||
|
def predict_action(self, obs_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
||||||
|
assert "past_action" not in obs_dict
|
||||||
|
nobs = self.normalizer.normalize(obs_dict)
|
||||||
|
value = next(iter(nobs.values()))
|
||||||
|
batch_size, to_steps = value.shape[:2]
|
||||||
|
horizon = self.horizon
|
||||||
|
action_dim = self.action_dim
|
||||||
|
|
||||||
|
device = self.device
|
||||||
|
dtype = self.dtype
|
||||||
|
cond = None
|
||||||
|
if self.obs_as_cond:
|
||||||
|
cond = self._encode_obs(nobs, self.n_obs_steps)
|
||||||
|
shape = (batch_size, horizon, action_dim)
|
||||||
|
if self.pred_action_steps_only:
|
||||||
|
shape = (batch_size, self.n_action_steps, action_dim)
|
||||||
|
cond_data = torch.zeros(size=shape, device=device, dtype=dtype)
|
||||||
|
cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)
|
||||||
|
else:
|
||||||
|
nobs_features = self._encode_obs(nobs, self.n_obs_steps)
|
||||||
|
shape = (batch_size, horizon, action_dim + self.obs_feature_dim)
|
||||||
|
cond_data = torch.zeros(size=shape, device=device, dtype=dtype)
|
||||||
|
cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)
|
||||||
|
cond_data[:, : self.n_obs_steps, action_dim:] = nobs_features
|
||||||
|
cond_mask[:, : self.n_obs_steps, action_dim:] = True
|
||||||
|
|
||||||
|
nsample = self.conditional_sample(
|
||||||
|
cond_data,
|
||||||
|
cond_mask,
|
||||||
|
cond=cond,
|
||||||
|
**self.kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
naction_pred = nsample[..., :action_dim]
|
||||||
|
action_pred = self.normalizer["action"].unnormalize(naction_pred)
|
||||||
|
if self.pred_action_steps_only:
|
||||||
|
action = action_pred
|
||||||
|
else:
|
||||||
|
start = to_steps - 1
|
||||||
|
end = start + self.n_action_steps
|
||||||
|
action = action_pred[:, start:end]
|
||||||
|
return {
|
||||||
|
"action": action,
|
||||||
|
"action_pred": action_pred,
|
||||||
|
}
|
||||||
|
|
||||||
|
# ========= training ============
|
||||||
|
def set_normalizer(self, normalizer: LinearNormalizer):
|
||||||
|
self.normalizer.load_state_dict(normalizer.state_dict())
|
||||||
|
|
||||||
|
def get_optimizer(
|
||||||
|
self,
|
||||||
|
transformer_weight_decay: float,
|
||||||
|
obs_encoder_weight_decay: float,
|
||||||
|
learning_rate: float,
|
||||||
|
betas: Tuple[float, float],
|
||||||
|
) -> torch.optim.Optimizer:
|
||||||
|
optim_groups = self.model.get_optim_groups(weight_decay=transformer_weight_decay)
|
||||||
|
optim_groups.append(
|
||||||
|
{
|
||||||
|
"params": self.obs_encoder.parameters(),
|
||||||
|
"weight_decay": obs_encoder_weight_decay,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas)
|
||||||
|
|
||||||
|
def compute_loss(self, batch):
|
||||||
|
assert "valid_mask" not in batch
|
||||||
|
nobs = self.normalizer.normalize(batch["obs"])
|
||||||
|
nactions = self.normalizer["action"].normalize(batch["action"])
|
||||||
|
|
||||||
|
_, trajectory, cond, condition_mask = self._trajectory_inputs(nobs, nactions)
|
||||||
|
noise = torch.randn_like(trajectory) * self.noise_scale
|
||||||
|
batch_size = trajectory.shape[0]
|
||||||
|
|
||||||
|
t, r = self._sample_tr(
|
||||||
|
batch_size, device=trajectory.device, dtype=trajectory.dtype
|
||||||
|
)
|
||||||
|
z_t = (1 - self._time_view(t, trajectory)) * trajectory + self._time_view(t, trajectory) * noise
|
||||||
|
z_t = self._apply_conditioning(z_t, trajectory, condition_mask)
|
||||||
|
|
||||||
|
loss_mask = ~condition_mask
|
||||||
|
target_v = noise - trajectory
|
||||||
|
|
||||||
|
u, v = self._compute_u_v(z_t, t, r, cond)
|
||||||
|
du_dt = self._compute_du_dt(
|
||||||
|
sample=z_t,
|
||||||
|
t=t,
|
||||||
|
r=r,
|
||||||
|
cond=cond,
|
||||||
|
condition_data=trajectory,
|
||||||
|
condition_mask=condition_mask,
|
||||||
|
tangent_v=v,
|
||||||
|
)
|
||||||
|
pmf_velocity = u + self._time_view(t - r, trajectory) * du_dt.detach()
|
||||||
|
|
||||||
|
loss_u = F.mse_loss(pmf_velocity, target_v, reduction="none")
|
||||||
|
loss_v = F.mse_loss(v, target_v, reduction="none")
|
||||||
|
loss_u = loss_u * loss_mask.type(loss_u.dtype)
|
||||||
|
loss_v = loss_v * loss_mask.type(loss_v.dtype)
|
||||||
|
loss_u = reduce(loss_u, "b ... -> b (...)", "mean").mean()
|
||||||
|
loss_v = reduce(loss_v, "b ... -> b (...)", "mean").mean()
|
||||||
|
loss_u = self._adatloss(loss_u)
|
||||||
|
loss_v = self._adatloss(loss_v)
|
||||||
|
return self.pmf_u_loss_weight * loss_u + self.pmf_v_loss_weight * loss_v
|
||||||
190
image_pusht_diffusion_policy_dit_pmf.yaml
Normal file
190
image_pusht_diffusion_policy_dit_pmf.yaml
Normal file
@@ -0,0 +1,190 @@
|
|||||||
|
_target_: diffusion_policy.workspace.train_diffusion_transformer_hybrid_workspace.TrainDiffusionTransformerHybridWorkspace
|
||||||
|
checkpoint:
|
||||||
|
save_last_ckpt: true
|
||||||
|
save_last_snapshot: false
|
||||||
|
topk:
|
||||||
|
format_str: epoch={epoch:04d}-train_loss={train_loss:.3f}.ckpt
|
||||||
|
k: 5
|
||||||
|
mode: min
|
||||||
|
monitor_key: train_loss
|
||||||
|
dataloader:
|
||||||
|
batch_size: 64
|
||||||
|
num_workers: 8
|
||||||
|
persistent_workers: false
|
||||||
|
pin_memory: true
|
||||||
|
shuffle: true
|
||||||
|
dataset_obs_steps: 2
|
||||||
|
ema:
|
||||||
|
_target_: diffusion_policy.model.diffusion.ema_model.EMAModel
|
||||||
|
inv_gamma: 1.0
|
||||||
|
max_value: 0.9999
|
||||||
|
min_value: 0.0
|
||||||
|
power: 0.75
|
||||||
|
update_after_step: 0
|
||||||
|
exp_name: default
|
||||||
|
horizon: 16
|
||||||
|
keypoint_visible_rate: 1.0
|
||||||
|
logging:
|
||||||
|
group: null
|
||||||
|
id: null
|
||||||
|
mode: online
|
||||||
|
name: ${now:%Y.%m.%d-%H.%M.%S}_train_diffusion_transformer_hybrid_pmf_pusht_image
|
||||||
|
project: diffusion_policy_debug
|
||||||
|
resume: true
|
||||||
|
tags:
|
||||||
|
- train_diffusion_transformer_hybrid_pmf
|
||||||
|
- pusht_image
|
||||||
|
- default
|
||||||
|
multi_run:
|
||||||
|
run_dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_train_diffusion_transformer_hybrid_pmf_pusht_image
|
||||||
|
wandb_name_base: ${now:%Y.%m.%d-%H.%M.%S}_train_diffusion_transformer_hybrid_pmf_pusht_image
|
||||||
|
n_action_steps: 8
|
||||||
|
n_latency_steps: 0
|
||||||
|
n_obs_steps: 2
|
||||||
|
name: train_diffusion_transformer_hybrid_pmf
|
||||||
|
obs_as_cond: true
|
||||||
|
optimizer:
|
||||||
|
betas:
|
||||||
|
- 0.9
|
||||||
|
- 0.95
|
||||||
|
learning_rate: 0.0001
|
||||||
|
obs_encoder_weight_decay: 1.0e-06
|
||||||
|
transformer_weight_decay: 0.001
|
||||||
|
past_action_visible: false
|
||||||
|
policy:
|
||||||
|
_target_: diffusion_policy.policy.pmf_transformer_hybrid_image_policy.PMFTransformerHybridImagePolicy
|
||||||
|
crop_shape:
|
||||||
|
- 84
|
||||||
|
- 84
|
||||||
|
eval_fixed_crop: true
|
||||||
|
horizon: 16
|
||||||
|
n_action_steps: 8
|
||||||
|
n_cond_layers: 0
|
||||||
|
n_emb: 256
|
||||||
|
n_head: 4
|
||||||
|
n_layer: 12
|
||||||
|
n_head_layers: 4
|
||||||
|
n_obs_steps: 2
|
||||||
|
n_time_tokens: 4
|
||||||
|
noise_scale: 1.0
|
||||||
|
adatloss_eps: 0.01
|
||||||
|
p_mean: -0.4
|
||||||
|
p_std: 1.0
|
||||||
|
noise_scheduler:
|
||||||
|
_target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler
|
||||||
|
beta_end: 0.02
|
||||||
|
beta_schedule: squaredcos_cap_v2
|
||||||
|
beta_start: 0.0001
|
||||||
|
clip_sample: true
|
||||||
|
num_train_timesteps: 100
|
||||||
|
prediction_type: sample
|
||||||
|
variance_type: fixed_small
|
||||||
|
num_inference_steps: 1
|
||||||
|
obs_as_cond: true
|
||||||
|
obs_encoder_group_norm: true
|
||||||
|
p_drop_attn: 0.0
|
||||||
|
p_drop_emb: 0.0
|
||||||
|
pmf_u_loss_weight: 1.0
|
||||||
|
pmf_v_loss_weight: 1.0
|
||||||
|
tr_uniform: true
|
||||||
|
tr_uniform_prob: 0.1
|
||||||
|
data_proportion: 0.5
|
||||||
|
shape_meta:
|
||||||
|
action:
|
||||||
|
shape:
|
||||||
|
- 2
|
||||||
|
obs:
|
||||||
|
agent_pos:
|
||||||
|
shape:
|
||||||
|
- 2
|
||||||
|
type: low_dim
|
||||||
|
image:
|
||||||
|
shape:
|
||||||
|
- 3
|
||||||
|
- 96
|
||||||
|
- 96
|
||||||
|
type: rgb
|
||||||
|
shape_meta:
|
||||||
|
action:
|
||||||
|
shape:
|
||||||
|
- 2
|
||||||
|
obs:
|
||||||
|
agent_pos:
|
||||||
|
shape:
|
||||||
|
- 2
|
||||||
|
type: low_dim
|
||||||
|
image:
|
||||||
|
shape:
|
||||||
|
- 3
|
||||||
|
- 96
|
||||||
|
- 96
|
||||||
|
type: rgb
|
||||||
|
task:
|
||||||
|
dataset:
|
||||||
|
_target_: diffusion_policy.dataset.pusht_image_dataset.PushTImageDataset
|
||||||
|
horizon: 16
|
||||||
|
max_train_episodes: 90
|
||||||
|
pad_after: 7
|
||||||
|
pad_before: 1
|
||||||
|
seed: 42
|
||||||
|
val_ratio: 0.02
|
||||||
|
zarr_path: data/pusht/pusht_cchi_v7_replay.zarr
|
||||||
|
env_runner:
|
||||||
|
_target_: diffusion_policy.env_runner.pusht_image_runner.PushTImageRunner
|
||||||
|
fps: 10
|
||||||
|
legacy_test: true
|
||||||
|
max_steps: 300
|
||||||
|
n_action_steps: 8
|
||||||
|
n_envs: null
|
||||||
|
n_obs_steps: 2
|
||||||
|
n_test: 50
|
||||||
|
n_test_vis: 4
|
||||||
|
n_train: 6
|
||||||
|
n_train_vis: 2
|
||||||
|
past_action: false
|
||||||
|
test_start_seed: 100000
|
||||||
|
train_start_seed: 0
|
||||||
|
image_shape:
|
||||||
|
- 3
|
||||||
|
- 96
|
||||||
|
- 96
|
||||||
|
name: pusht_image
|
||||||
|
shape_meta:
|
||||||
|
action:
|
||||||
|
shape:
|
||||||
|
- 2
|
||||||
|
obs:
|
||||||
|
agent_pos:
|
||||||
|
shape:
|
||||||
|
- 2
|
||||||
|
type: low_dim
|
||||||
|
image:
|
||||||
|
shape:
|
||||||
|
- 3
|
||||||
|
- 96
|
||||||
|
- 96
|
||||||
|
type: rgb
|
||||||
|
task_name: pusht_image
|
||||||
|
training:
|
||||||
|
checkpoint_every: 50
|
||||||
|
debug: false
|
||||||
|
device: cuda:0
|
||||||
|
gradient_accumulate_every: 1
|
||||||
|
lr_scheduler: cosine
|
||||||
|
lr_warmup_steps: 500
|
||||||
|
max_train_steps: null
|
||||||
|
max_val_steps: null
|
||||||
|
num_epochs: 600
|
||||||
|
resume: true
|
||||||
|
rollout_every: 50
|
||||||
|
sample_every: 5
|
||||||
|
seed: 42
|
||||||
|
tqdm_interval_sec: 1.0
|
||||||
|
use_ema: true
|
||||||
|
val_every: 1
|
||||||
|
val_dataloader:
|
||||||
|
batch_size: 64
|
||||||
|
num_workers: 8
|
||||||
|
persistent_workers: false
|
||||||
|
pin_memory: true
|
||||||
|
shuffle: false
|
||||||
@@ -22,6 +22,7 @@ pymunk==6.2.1
|
|||||||
wandb==0.13.3
|
wandb==0.13.3
|
||||||
threadpoolctl==3.1.0
|
threadpoolctl==3.1.0
|
||||||
shapely==1.8.5.post1
|
shapely==1.8.5.post1
|
||||||
|
matplotlib==3.6.1
|
||||||
imageio==2.22.0
|
imageio==2.22.0
|
||||||
imageio-ffmpeg==0.4.7
|
imageio-ffmpeg==0.4.7
|
||||||
termcolor==2.0.1
|
termcolor==2.0.1
|
||||||
|
|||||||
Reference in New Issue
Block a user