feat: add pusht image imf transformer
diffusion_policy/model/diffusion/imf_transformer_for_diffusion.py (new file, 298 lines)
@@ -0,0 +1,298 @@
import logging
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn

from diffusion_policy.model.common.module_attr_mixin import ModuleAttrMixin
from diffusion_policy.model.diffusion.positional_embedding import SinusoidalPosEmb

logger = logging.getLogger(__name__)


class IMFTransformerForDiffusion(ModuleAttrMixin):
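    """Transformer backbone conditioned on an interval (r, t) of flow times.

    My reading of this commit (not author documentation): relative to the
    single-timestep TransformerForDiffusion, this variant embeds two flow
    times r and t as the first two conditioning tokens (hence T_cond starts
    at 2), so it can represent an average velocity over the interval [r, t].
    """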
    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        horizon: int,
        n_obs_steps: Optional[int] = None,
        cond_dim: int = 0,
        n_layer: int = 12,
        n_head: int = 1,
        n_emb: int = 768,
        p_drop_emb: float = 0.1,
        p_drop_attn: float = 0.1,
        causal_attn: bool = False,
        time_as_cond: bool = True,
        obs_as_cond: bool = False,
        n_cond_layers: int = 0,
    ) -> None:
        super().__init__()

        assert n_head == 1, 'IMFTransformerForDiffusion currently supports single-head attention only.'

        if n_obs_steps is None:
            n_obs_steps = horizon

        T = horizon
        T_cond = 2
        if not time_as_cond:
            T += 2
            T_cond -= 2
        obs_as_cond = cond_dim > 0
        if obs_as_cond:
            assert time_as_cond
            T_cond += n_obs_steps

        self.input_emb = nn.Linear(input_dim, n_emb)
        self.pos_emb = nn.Parameter(torch.zeros(1, T, n_emb))
        self.drop = nn.Dropout(p_drop_emb)

        self.time_emb = SinusoidalPosEmb(n_emb)
        self.cond_obs_emb = None
        if obs_as_cond:
            self.cond_obs_emb = nn.Linear(cond_dim, n_emb)

        self.cond_pos_emb = None
        self.encoder = None
        self.decoder = None
        encoder_only = False
        if T_cond > 0:
            self.cond_pos_emb = nn.Parameter(torch.zeros(1, T_cond, n_emb))
            if n_cond_layers > 0:
                encoder_layer = nn.TransformerEncoderLayer(
                    d_model=n_emb,
                    nhead=n_head,
                    dim_feedforward=4 * n_emb,
                    dropout=p_drop_attn,
                    activation='gelu',
                    batch_first=True,
                    norm_first=True,
                )
                self.encoder = nn.TransformerEncoder(
                    encoder_layer=encoder_layer,
                    num_layers=n_cond_layers,
                )
            else:
                self.encoder = nn.Sequential(
                    nn.Linear(n_emb, 4 * n_emb),
                    nn.Mish(),
                    nn.Linear(4 * n_emb, n_emb),
                )

            decoder_layer = nn.TransformerDecoderLayer(
                d_model=n_emb,
                nhead=n_head,
                dim_feedforward=4 * n_emb,
                dropout=p_drop_attn,
                activation='gelu',
                batch_first=True,
                norm_first=True,
            )
            self.decoder = nn.TransformerDecoder(
                decoder_layer=decoder_layer,
                num_layers=n_layer,
            )
        else:
            encoder_only = True
            encoder_layer = nn.TransformerEncoderLayer(
                d_model=n_emb,
                nhead=n_head,
                dim_feedforward=4 * n_emb,
                dropout=p_drop_attn,
                activation='gelu',
                batch_first=True,
                norm_first=True,
            )
            self.encoder = nn.TransformerEncoder(
                encoder_layer=encoder_layer,
                num_layers=n_layer,
            )

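        # Annotation (mine): a lower-triangular causal mask over the T action
        # tokens, plus a memory mask letting target position i attend to
        # conditioning tokens 0..i+2 -- the offset of 2 reserves the r and t
        # time tokens at the front of the conditioning sequence.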
        if causal_attn:
            sz = T
            mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
            mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
            self.register_buffer('mask', mask)

            if time_as_cond and obs_as_cond:
                S = T_cond
                t_idx, s_idx = torch.meshgrid(
                    torch.arange(T),
                    torch.arange(S),
                    indexing='ij',
                )
                mask = t_idx >= (s_idx - 2)
                mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
                self.register_buffer('memory_mask', mask)
            else:
                self.memory_mask = None
        else:
            self.mask = None
            self.memory_mask = None

        self.ln_f = nn.LayerNorm(n_emb)
        self.head = nn.Linear(n_emb, output_dim)

        self.T = T
        self.T_cond = T_cond
        self.horizon = horizon
        self.time_as_cond = time_as_cond
        self.obs_as_cond = obs_as_cond
        self.encoder_only = encoder_only

        self.apply(self._init_weights)
        logger.info(
            'number of parameters: %e',
            sum(p.numel() for p in self.parameters()),
        )

    def _init_weights(self, module):
        ignore_types = (
            nn.Dropout,
            SinusoidalPosEmb,
            nn.TransformerEncoderLayer,
            nn.TransformerDecoderLayer,
            nn.TransformerEncoder,
            nn.TransformerDecoder,
            nn.ModuleList,
            nn.Mish,
            nn.Sequential,
        )
        if isinstance(module, (nn.Linear, nn.Embedding)):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.MultiheadAttention):
            weight_names = ['in_proj_weight', 'q_proj_weight', 'k_proj_weight', 'v_proj_weight']
            for name in weight_names:
                weight = getattr(module, name)
                if weight is not None:
                    torch.nn.init.normal_(weight, mean=0.0, std=0.02)

            bias_names = ['in_proj_bias', 'bias_k', 'bias_v']
            for name in bias_names:
                bias = getattr(module, name)
                if bias is not None:
                    torch.nn.init.zeros_(bias)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.zeros_(module.bias)
            torch.nn.init.ones_(module.weight)
        elif isinstance(module, IMFTransformerForDiffusion):
            torch.nn.init.normal_(module.pos_emb, mean=0.0, std=0.02)
            if module.cond_pos_emb is not None:
                torch.nn.init.normal_(module.cond_pos_emb, mean=0.0, std=0.02)
        elif isinstance(module, ignore_types):
            pass
        else:
            raise RuntimeError(f'Unaccounted module {module}')

    def get_optim_groups(self, weight_decay: float = 1e-3):
        decay = set()
        no_decay = set()
        whitelist_weight_modules = (torch.nn.Linear, torch.nn.MultiheadAttention)
        blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
        for mn, m in self.named_modules():
            for pn, _ in m.named_parameters():
                fpn = '%s.%s' % (mn, pn) if mn else pn

                if pn.endswith('bias'):
                    no_decay.add(fpn)
                elif pn.startswith('bias'):
                    no_decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
                    decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
                    no_decay.add(fpn)

        no_decay.add('pos_emb')
        no_decay.add('_dummy_variable')
        if self.cond_pos_emb is not None:
            no_decay.add('cond_pos_emb')

        param_dict = {pn: p for pn, p in self.named_parameters()}
        inter_params = decay & no_decay
        union_params = decay | no_decay
        assert len(inter_params) == 0, f'parameters {inter_params} made it into both decay/no_decay sets!'
        assert len(param_dict.keys() - union_params) == 0, (
            f'parameters {param_dict.keys() - union_params} were not separated into either decay/no_decay sets!'
        )

        optim_groups = [
            {
                'params': [param_dict[pn] for pn in sorted(list(decay))],
                'weight_decay': weight_decay,
            },
            {
                'params': [param_dict[pn] for pn in sorted(list(no_decay))],
                'weight_decay': 0.0,
            },
        ]
        return optim_groups

    def configure_optimizers(
        self,
        learning_rate: float = 1e-4,
        weight_decay: float = 1e-3,
        betas: Tuple[float, float] = (0.9, 0.95),
    ):
        optim_groups = self.get_optim_groups(weight_decay=weight_decay)
        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas)
        return optimizer

    def _prepare_time_input(self, value: Union[torch.Tensor, float, int], sample: torch.Tensor) -> torch.Tensor:
        if not torch.is_tensor(value):
            value = torch.tensor([value], dtype=sample.dtype, device=sample.device)
        elif value.ndim == 0:
            value = value[None].to(device=sample.device, dtype=sample.dtype)
        else:
            value = value.to(device=sample.device, dtype=sample.dtype)
        return value.expand(sample.shape[0])

    def forward(
        self,
        sample: torch.Tensor,
        r: Union[torch.Tensor, float, int],
        t: Union[torch.Tensor, float, int],
        cond: Optional[torch.Tensor] = None,
    ):
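        """Predict the velocity field for a batch of noisy action tokens.

        Shapes inferred from call sites in this commit (an annotation, not
        author documentation): ``sample`` is (B, T, input_dim), ``r`` and
        ``t`` are scalars or (B,) interval endpoints, and ``cond`` is
        (B, n_obs_steps, cond_dim) when observations act as conditioning.
        """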
        r = self._prepare_time_input(r, sample)
        t = self._prepare_time_input(t, sample)
        r_emb = self.time_emb(r).unsqueeze(1)
        t_emb = self.time_emb(t).unsqueeze(1)

        input_emb = self.input_emb(sample)

        if self.encoder_only:
            token_embeddings = torch.cat([r_emb, t_emb, input_emb], dim=1)
            token_count = token_embeddings.shape[1]
            position_embeddings = self.pos_emb[:, :token_count, :]
            x = self.drop(token_embeddings + position_embeddings)
            x = self.encoder(src=x, mask=self.mask)
            x = x[:, 2:, :]
        else:
            cond_embeddings = torch.cat([r_emb, t_emb], dim=1)
            if self.obs_as_cond:
                cond_obs_emb = self.cond_obs_emb(cond)
                cond_embeddings = torch.cat([cond_embeddings, cond_obs_emb], dim=1)
            token_count = cond_embeddings.shape[1]
            position_embeddings = self.cond_pos_emb[:, :token_count, :]
            x = self.drop(cond_embeddings + position_embeddings)
            x = self.encoder(x)
            memory = x

            token_embeddings = input_emb
            token_count = token_embeddings.shape[1]
            position_embeddings = self.pos_emb[:, :token_count, :]
            x = self.drop(token_embeddings + position_embeddings)
            x = self.decoder(
                tgt=x,
                memory=memory,
                tgt_mask=self.mask,
                memory_mask=self.memory_mask,
            )

        x = self.ln_f(x)
        x = self.head(x)
        return x

diffusion_policy/policy/imf_transformer_hybrid_image_policy.py (new file, 273 lines)
@@ -0,0 +1,273 @@
from contextlib import nullcontext
from typing import Dict, Optional

import torch
import torch.nn.functional as F
from einops import reduce

from diffusion_policy.common.pytorch_util import dict_apply
from diffusion_policy.model.diffusion.imf_transformer_for_diffusion import IMFTransformerForDiffusion
from diffusion_policy.policy.diffusion_transformer_hybrid_image_policy import (
    DiffusionTransformerHybridImagePolicy,
)

try:
    from torch.func import jvp as TORCH_FUNC_JVP
except ImportError:  # pragma: no cover - depends on torch version
    TORCH_FUNC_JVP = None


class IMFTransformerHybridImagePolicy(DiffusionTransformerHybridImagePolicy):
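    """Hybrid image policy trained with a one-step, MeanFlow-style objective.

    A sketch of my reading of this commit (not author documentation): the
    network predicts an average velocity u(z_t, r, t) over [r, t]; training
    enforces u + (t - r) * du/dt ~= v via a JVP with du/dt under stop-grad,
    and inference is a single Euler step x = z_1 - (1 - 0) * u(z_1, r=0, t=1).
    """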
    def __init__(
        self,
        shape_meta: dict,
        noise_scheduler,
        horizon,
        n_action_steps,
        n_obs_steps,
        num_inference_steps=None,
        crop_shape=(76, 76),
        obs_encoder_group_norm=False,
        eval_fixed_crop=False,
        n_layer=8,
        n_cond_layers=0,
        n_head=1,
        n_emb=256,
        p_drop_emb=0.0,
        p_drop_attn=0.3,
        causal_attn=True,
        time_as_cond=True,
        obs_as_cond=True,
        pred_action_steps_only=False,
        **kwargs,
    ):
        if num_inference_steps is None:
            num_inference_steps = 1
        elif num_inference_steps != 1:
            raise ValueError(
                'IMFTransformerHybridImagePolicy only supports one-step inference; '
                f'num_inference_steps must be 1, got {num_inference_steps}.'
            )

        super().__init__(
            shape_meta=shape_meta,
            noise_scheduler=noise_scheduler,
            horizon=horizon,
            n_action_steps=n_action_steps,
            n_obs_steps=n_obs_steps,
            num_inference_steps=num_inference_steps,
            crop_shape=crop_shape,
            obs_encoder_group_norm=obs_encoder_group_norm,
            eval_fixed_crop=eval_fixed_crop,
            n_layer=n_layer,
            n_cond_layers=n_cond_layers,
            n_head=n_head,
            n_emb=n_emb,
            p_drop_emb=p_drop_emb,
            p_drop_attn=p_drop_attn,
            causal_attn=causal_attn,
            time_as_cond=time_as_cond,
            obs_as_cond=obs_as_cond,
            pred_action_steps_only=pred_action_steps_only,
            **kwargs,
        )

        input_dim = self.action_dim if self.obs_as_cond else (self.obs_feature_dim + self.action_dim)
        cond_dim = self.obs_feature_dim if self.obs_as_cond else 0
        model_horizon = self.n_action_steps if self.pred_action_steps_only else horizon
        self.model = IMFTransformerForDiffusion(
            input_dim=input_dim,
            output_dim=input_dim,
            horizon=model_horizon,
            n_obs_steps=n_obs_steps,
            cond_dim=cond_dim,
            n_layer=n_layer,
            n_head=n_head,
            n_emb=n_emb,
            p_drop_emb=p_drop_emb,
            p_drop_attn=p_drop_attn,
            causal_attn=causal_attn,
            time_as_cond=time_as_cond,
            obs_as_cond=obs_as_cond,
            n_cond_layers=n_cond_layers,
        )
        self.num_inference_steps = 1

    def fn(self, z: torch.Tensor, r: torch.Tensor, t: torch.Tensor, cond=None) -> torch.Tensor:
        return self.model(z, r, t, cond=cond)

    @staticmethod
    def _broadcast_batch_time(value: torch.Tensor, reference: torch.Tensor) -> torch.Tensor:
        while value.ndim < reference.ndim:
            value = value.unsqueeze(-1)
        return value

    @staticmethod
    def _apply_conditioning(
        trajectory: torch.Tensor,
        condition_data: Optional[torch.Tensor] = None,
        condition_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if condition_data is None or condition_mask is None:
            return trajectory
        conditioned = trajectory.clone()
        conditioned[condition_mask] = condition_data[condition_mask]
        return conditioned

    @staticmethod
    def _jvp_math_sdp_context(z_t: torch.Tensor):
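        # Annotation (mine): the fused flash/memory-efficient SDPA kernels do
        # not support forward-mode autograd, so the jvp below needs the math
        # backend on CUDA. torch.backends.cuda.sdp_kernel is the older context
        # manager API; newer PyTorch releases deprecate it in favor of
        # torch.nn.attention.sdpa_kernel.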
        if z_t.is_cuda:
            return torch.backends.cuda.sdp_kernel(
                enable_flash=False,
                enable_math=True,
                enable_mem_efficient=False,
                enable_cudnn=False,
            )
        return nullcontext()

    @staticmethod
    def _jvp_tangents(v: torch.Tensor, r: torch.Tensor, t: torch.Tensor):
        return v.detach(), torch.zeros_like(r), torch.ones_like(t)

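    # Annotation (mine): with tangents (v, 0, 1) the JVP returns the total
    # derivative du/dt = (du/dz) . v + du/dt|_partial along a path where
    # dz/dt = v and r is held fixed, which is the term the compound-velocity
    # identity below consumes.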
    def _compute_u_and_du_dt(
        self,
        z_t: torch.Tensor,
        r: torch.Tensor,
        t: torch.Tensor,
        cond,
        v: torch.Tensor,
        condition_data: Optional[torch.Tensor] = None,
        condition_mask: Optional[torch.Tensor] = None,
    ):
        tangents = self._jvp_tangents(v, r, t)

        def g(z, r_value, t_value):
            conditioned_z = self._apply_conditioning(z, condition_data, condition_mask)
            return self.fn(conditioned_z, r_value, t_value, cond=cond)

        with self._jvp_math_sdp_context(z_t):
            if TORCH_FUNC_JVP is not None:
                try:
                    return TORCH_FUNC_JVP(g, (z_t, r, t), tangents)
                except (RuntimeError, TypeError, NotImplementedError):
                    pass

            u = g(z_t, r, t)
            _, du_dt = torch.autograd.functional.jvp(
                g,
                (z_t, r, t),
                tangents,
                create_graph=False,
                strict=False,
            )
        return u, du_dt

    def _compound_velocity(
        self,
        u: torch.Tensor,
        du_dt: torch.Tensor,
        r: torch.Tensor,
        t: torch.Tensor,
    ) -> torch.Tensor:
        delta = self._broadcast_batch_time(t - r, u)
        return u + delta * du_dt.detach()

    def _sample_one_step(
        self,
        z_t: torch.Tensor,
        r: Optional[torch.Tensor] = None,
        t: Optional[torch.Tensor] = None,
        cond=None,
    ) -> torch.Tensor:
        batch_size = z_t.shape[0]
        if t is None:
            t = torch.ones(batch_size, device=z_t.device, dtype=z_t.dtype)
        if r is None:
            r = torch.zeros(batch_size, device=z_t.device, dtype=z_t.dtype)
        u = self.fn(z_t, r, t, cond=cond)
        delta = self._broadcast_batch_time(t - r, z_t)
        return z_t - delta * u

    def conditional_sample(
        self,
        condition_data,
        condition_mask,
        cond=None,
        generator=None,
        **kwargs,
    ):
        trajectory = torch.randn(
            size=condition_data.shape,
            dtype=condition_data.dtype,
            device=condition_data.device,
            generator=generator,
        )
        trajectory = self._apply_conditioning(trajectory, condition_data, condition_mask)
        trajectory = self._sample_one_step(trajectory, cond=cond)
        trajectory = self._apply_conditioning(trajectory, condition_data, condition_mask)
        return trajectory

    def compute_loss(self, batch):
        assert 'valid_mask' not in batch
        nobs = self.normalizer.normalize(batch['obs'])
        nactions = self.normalizer['action'].normalize(batch['action'])
        batch_size = nactions.shape[0]
        horizon = nactions.shape[1]
        To = self.n_obs_steps

        cond = None
        trajectory = nactions
        if self.obs_as_cond:
            this_nobs = dict_apply(
                nobs,
                lambda x: x[:, :To, ...].reshape(-1, *x.shape[2:]),
            )
            nobs_features = self.obs_encoder(this_nobs)
            cond = nobs_features.reshape(batch_size, To, -1)
            if self.pred_action_steps_only:
                start = To - 1
                end = start + self.n_action_steps
                trajectory = nactions[:, start:end]
        else:
            this_nobs = dict_apply(nobs, lambda x: x.reshape(-1, *x.shape[2:]))
            nobs_features = self.obs_encoder(this_nobs)
            nobs_features = nobs_features.reshape(batch_size, horizon, -1)
            trajectory = torch.cat([nactions, nobs_features], dim=-1).detach()

        if self.pred_action_steps_only:
            condition_mask = torch.zeros_like(trajectory, dtype=torch.bool)
        else:
            condition_mask = self.mask_generator(trajectory.shape)

        loss_mask = torch.zeros_like(trajectory, dtype=torch.bool)
        loss_mask[..., : self.action_dim] = True
        loss_mask = loss_mask & ~condition_mask

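        # Annotation (mine): with z_t = (1 - t) * x + t * e the instantaneous
        # velocity is v = e - x, so the loss below regresses the compound
        # velocity u + (t - r) * stopgrad(du/dt) onto e - x over random
        # intervals r <= t, in the style of MeanFlow-type one-step training.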
        x = trajectory
        e = torch.randn_like(x)
        t = torch.rand(batch_size, device=x.device, dtype=x.dtype)
        r = torch.rand(batch_size, device=x.device, dtype=x.dtype)
        t, r = torch.maximum(t, r), torch.minimum(t, r)

        t_broadcast = self._broadcast_batch_time(t, x)
        z_t = (1 - t_broadcast) * x + t_broadcast * e
        z_t = self._apply_conditioning(z_t, x, condition_mask)

        v = self.fn(z_t, t, t, cond=cond)
        u, du_dt = self._compute_u_and_du_dt(
            z_t,
            r,
            t,
            cond=cond,
            v=v,
            condition_data=x,
            condition_mask=condition_mask,
        )
        V = self._compound_velocity(u, du_dt, r, t)
        target = e - x

        loss = F.mse_loss(V, target, reduction='none')
        loss = loss * loss_mask.type(loss.dtype)
        loss = reduce(loss, 'b ... -> b (...)', 'mean')
        loss = loss.mean()
        return loss

image_pusht_diffusion_policy_dit_imf.yaml (new file, 30 lines)
@@ -0,0 +1,30 @@
defaults:
  - diffusion_policy/config/train_diffusion_transformer_hybrid_workspace@_here_
  - override /diffusion_policy/config/task@task: pusht_image
  - _self_

exp_name: pusht_image_dit_imf

policy:
  _target_: diffusion_policy.policy.imf_transformer_hybrid_image_policy.IMFTransformerHybridImagePolicy
  num_inference_steps: 1
  n_head: 1

logging:
  backend: swanlab
  mode: online
  tags: ["${name}", "${task_name}", "${exp_name}", "swanlab"]
  id: ${now:%Y%m%d%H%M%S}_${name}_${task_name}
  group: ${exp_name}

dataloader:
  num_workers: 0

val_dataloader:
  num_workers: 0

task:
  env_runner:
    n_envs: 1
    n_test_vis: 0
    n_train_vis: 0
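
A hypothetical launch command, assuming the repository's usual Hydra-based train.py entrypoint (the entrypoint itself is not part of this commit):

python train.py --config-dir=. --config-name=image_pusht_diffusion_policy_dit_imf.yaml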

tests/test_imf_transformer_for_diffusion.py (new file, 46 lines)
@@ -0,0 +1,46 @@
import inspect
import pathlib
import sys

import torch

ROOT_DIR = pathlib.Path(__file__).resolve().parents[1]
if str(ROOT_DIR) not in sys.path:
    sys.path.append(str(ROOT_DIR))

from diffusion_policy.model.diffusion.imf_transformer_for_diffusion import (  # noqa: E402
    IMFTransformerForDiffusion,
)


def test_imf_transformer_forward_signature_and_shape_single_head():
    signature = inspect.signature(IMFTransformerForDiffusion.forward)
    assert list(signature.parameters)[:5] == ['self', 'sample', 'r', 't', 'cond']
    assert signature.parameters['cond'].default is None

    model = IMFTransformerForDiffusion(
        input_dim=3,
        output_dim=3,
        horizon=5,
        n_obs_steps=2,
        cond_dim=4,
        n_layer=1,
        n_head=1,
        n_emb=16,
        p_drop_emb=0.0,
        p_drop_attn=0.0,
        causal_attn=True,
        time_as_cond=True,
        obs_as_cond=True,
        n_cond_layers=0,
    )
    model.configure_optimizers()

    sample = torch.randn(2, 5, 3)
    r = torch.rand(2)
    t = torch.rand(2)
    cond = torch.randn(2, 2, 4)

    pred_u = model(sample, r, t, cond=cond)

    assert pred_u.shape == sample.shape

tests/test_imf_transformer_hybrid_image_policy.py (new file, 313 lines)
@@ -0,0 +1,313 @@
@@ -0,0 +1,313 @@
|
|||||||
|
import pathlib
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
ROOT_DIR = pathlib.Path(__file__).resolve().parents[1]
|
||||||
|
if str(ROOT_DIR) not in sys.path:
|
||||||
|
sys.path.append(str(ROOT_DIR))
|
||||||
|
|
||||||
|
import diffusion_policy.policy.imf_transformer_hybrid_image_policy as policy_module # noqa: E402
|
||||||
|
from diffusion_policy.model.common.module_attr_mixin import ModuleAttrMixin # noqa: E402
|
||||||
|
from diffusion_policy.policy.imf_transformer_hybrid_image_policy import ( # noqa: E402
|
||||||
|
IMFTransformerHybridImagePolicy,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
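# Lightweight stand-ins for the transformer backbone: they let the policy's
# flow arithmetic be tested in isolation (comment added for clarity).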
class ConstantModel(nn.Module):
    def __init__(self, value):
        super().__init__()
        self.value = value

    def forward(self, sample, r, t, cond=None):
        return torch.full_like(sample, self.value)


class AffineModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.tensor(2.0))

    def forward(self, sample, r, t, cond=None):
        return sample * self.weight + (r + t).view(-1, 1, 1)


class SumMixModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.tensor(2.0))

    def forward(self, sample, r, t, cond=None):
        mixed = sample.sum(dim=-1, keepdim=True).expand_as(sample)
        return mixed * self.weight + t.view(-1, 1, 1)


class TrackingContext:
    def __init__(self):
        self.active = False
        self.enter_count = 0

    def __enter__(self):
        self.active = True
        self.enter_count += 1
        return self

    def __exit__(self, exc_type, exc, tb):
        self.active = False
        return False


def make_policy(model):
    policy = IMFTransformerHybridImagePolicy.__new__(IMFTransformerHybridImagePolicy)
    ModuleAttrMixin.__init__(policy)
    policy.model = model
    return policy


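# Stub for the parent __init__: skips the full obs-encoder/scheduler setup and
# only assigns the attributes the subclass constructor reads.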
def fake_parent_init(
    self,
    shape_meta,
    noise_scheduler,
    horizon,
    n_action_steps,
    n_obs_steps,
    num_inference_steps=None,
    crop_shape=(76, 76),
    obs_encoder_group_norm=False,
    eval_fixed_crop=False,
    n_layer=8,
    n_cond_layers=0,
    n_head=1,
    n_emb=256,
    p_drop_emb=0.0,
    p_drop_attn=0.3,
    causal_attn=True,
    time_as_cond=True,
    obs_as_cond=True,
    pred_action_steps_only=False,
    **kwargs,
):
    ModuleAttrMixin.__init__(self)
    self.action_dim = shape_meta['action']['shape'][0]
    self.obs_feature_dim = 4
    self.obs_as_cond = obs_as_cond
    self.pred_action_steps_only = pred_action_steps_only
    self.n_action_steps = n_action_steps
    self.n_obs_steps = n_obs_steps
    self.horizon = horizon


@pytest.fixture
def shape_meta():
    return {
        'action': {'shape': [2]},
        'obs': {},
    }


def test_sample_one_step_uses_imf_update_formula():
    policy = make_policy(ConstantModel(0.25))
    z_1 = torch.tensor([
        [[1.0, -1.0], [0.5, 0.0]],
        [[2.0, 3.0], [-2.0, 4.0]],
    ])
    r = torch.zeros(z_1.shape[0])
    t = torch.ones(z_1.shape[0])

    x_hat = policy._sample_one_step(z_1, r=r, t=t, cond=None)

    expected = z_1 - (t - r).view(-1, 1, 1) * 0.25
    assert torch.allclose(x_hat, expected)


def test_compound_velocity_uses_detached_du_dt_term():
    policy = make_policy(ConstantModel(0.0))
    u = torch.tensor([[[1.0], [2.0]]], requires_grad=True)
    du_dt = torch.tensor([[[3.0], [4.0]]], requires_grad=True)
    r = torch.tensor([0.2])
    t = torch.tensor([0.8])

    compound = policy._compound_velocity(u, du_dt, r, t)
    expected = u + (t - r).view(-1, 1, 1) * du_dt.detach()

    assert torch.allclose(compound, expected)

    compound.sum().backward()
    assert u.grad is not None
    assert du_dt.grad is None


def test_compute_u_and_du_dt_uses_math_sdpa_context_for_torch_func_jvp(monkeypatch):
    tracker = TrackingContext()

    def fake_jvp(fn, primals, tangents):
        assert tracker.active is True
        return fn(*primals), torch.zeros_like(primals[0])

    monkeypatch.setattr(policy_module, 'TORCH_FUNC_JVP', fake_jvp)

    policy = make_policy(ConstantModel(0.5))
    policy._jvp_math_sdp_context = lambda tensor: tracker
    z_t = torch.randn(2, 3, 4)
    r = torch.rand(2, requires_grad=True)
    t = torch.rand(2, requires_grad=True)
    v = torch.randn_like(z_t, requires_grad=True)

    policy._compute_u_and_du_dt(z_t, r, t, cond=None, v=v)

    assert tracker.enter_count == 1


def test_compute_u_and_du_dt_uses_math_sdpa_context_for_autograd_fallback(monkeypatch):
    tracker = TrackingContext()

    def fake_autograd_jvp(fn, primals, tangents, create_graph=False, strict=False):
        assert tracker.active is True
        return fn(*primals), torch.zeros_like(primals[0])

    monkeypatch.setattr(policy_module, 'TORCH_FUNC_JVP', None)
    monkeypatch.setattr(policy_module.torch.autograd.functional, 'jvp', fake_autograd_jvp)

    policy = make_policy(ConstantModel(0.5))
    policy._jvp_math_sdp_context = lambda tensor: tracker
    z_t = torch.randn(2, 3, 4)
    r = torch.rand(2, requires_grad=True)
    t = torch.rand(2, requires_grad=True)
    v = torch.randn_like(z_t, requires_grad=True)

    policy._compute_u_and_du_dt(z_t, r, t, cond=None, v=v)

    assert tracker.enter_count == 1


def test_compute_u_and_du_dt_uses_detached_v_zero_r_unit_t_and_reapplies_conditioning(monkeypatch):
    captured = {}

    def fake_jvp(fn, primals, tangents):
        captured['tangents'] = tangents
        captured['primal_output'] = fn(*primals)
        return captured['primal_output'], torch.zeros_like(primals[0])

    monkeypatch.setattr(policy_module, 'TORCH_FUNC_JVP', fake_jvp)

    policy = make_policy(SumMixModel())
    z_t = torch.tensor([[[1.0, 2.0, 3.0]]])
    r = torch.rand(1, requires_grad=True)
    t = torch.rand(1, requires_grad=True)
    v = torch.tensor([[[10.0, 20.0, 30.0]]], requires_grad=True)
    condition_mask = torch.tensor([[[False, True, False]]])
    condition_data = torch.tensor([[[0.0, 7.0, 0.0]]])

    policy._compute_u_and_du_dt(
        z_t,
        r,
        t,
        cond=None,
        v=v,
        condition_data=condition_data,
        condition_mask=condition_mask,
    )

    tangent_v, tangent_r, tangent_t = captured['tangents']
    assert torch.equal(tangent_v, v.detach())
    assert tangent_v.requires_grad is False
    assert torch.equal(tangent_r, torch.zeros_like(r))
    assert torch.equal(tangent_t, torch.ones_like(t))

    conditioned = z_t.clone()
    conditioned[condition_mask] = condition_data[condition_mask]
    expected_primal = policy.model(conditioned, r, t, cond=None)
    assert torch.allclose(captured['primal_output'], expected_primal)


def test_compute_u_and_du_dt_fallback_blocks_conditioned_tangent_leakage_and_keeps_primal_gradients(monkeypatch):
    monkeypatch.setattr(policy_module, 'TORCH_FUNC_JVP', None)

    policy = make_policy(SumMixModel())
    z_t = torch.tensor([[[1.0, 2.0, 3.0]]], requires_grad=True)
    r = torch.rand(1, requires_grad=True)
    t = torch.rand(1, requires_grad=True)
    v = torch.tensor([[[1.0, 10.0, 100.0]]], requires_grad=True)
    condition_mask = torch.tensor([[[False, True, False]]])
    condition_data = torch.tensor([[[0.0, 7.0, 0.0]]])

    u, du_dt = policy._compute_u_and_du_dt(
        z_t,
        r,
        t,
        cond=None,
        v=v,
        condition_data=condition_data,
        condition_mask=condition_mask,
    )

    conditioned = z_t.detach().clone()
    conditioned[condition_mask] = condition_data[condition_mask]
    expected_u = policy.model(conditioned, r, t, cond=None)
    expected_du_dt_scalar = policy.model.weight.detach() * torch.tensor(101.0) + 1.0
    expected_du_dt = torch.full_like(z_t, expected_du_dt_scalar)

    assert u.shape == z_t.shape
    assert du_dt.shape == z_t.shape
    assert torch.allclose(u, expected_u)
    assert torch.allclose(du_dt, expected_du_dt)

    u.sum().backward()
    assert policy.model.weight.grad is not None
    assert torch.count_nonzero(policy.model.weight.grad) > 0


def test_init_uses_action_step_horizon_when_pred_action_steps_only(monkeypatch, shape_meta):
    monkeypatch.setattr(
        policy_module.DiffusionTransformerHybridImagePolicy,
        '__init__',
        fake_parent_init,
    )

    policy = IMFTransformerHybridImagePolicy(
        shape_meta=shape_meta,
        noise_scheduler=None,
        horizon=10,
        n_action_steps=4,
        n_obs_steps=2,
        num_inference_steps=1,
        n_layer=1,
        n_head=1,
        n_emb=16,
        p_drop_emb=0.0,
        p_drop_attn=0.0,
        causal_attn=True,
        obs_as_cond=True,
        pred_action_steps_only=True,
    )

    assert policy.model.horizon == 4
    assert policy.num_inference_steps == 1


def test_init_rejects_non_one_step_inference(monkeypatch, shape_meta):
    monkeypatch.setattr(
        policy_module.DiffusionTransformerHybridImagePolicy,
        '__init__',
        fake_parent_init,
    )

    with pytest.raises(ValueError, match='num_inference_steps'):
        IMFTransformerHybridImagePolicy(
            shape_meta=shape_meta,
            noise_scheduler=None,
            horizon=10,
            n_action_steps=4,
            n_obs_steps=2,
            num_inference_steps=2,
            n_layer=1,
            n_head=1,
            n_emb=16,
            p_drop_emb=0.0,
            p_drop_attn=0.0,
            causal_attn=True,
            obs_as_cond=True,
            pred_action_steps_only=False,
        )