暂存
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
# # Action Head models
|
||||
from .diffusion import ConditionalUnet1D
|
||||
from .conditional_unet1d import ConditionalUnet1D
|
||||
|
||||
__all__ = ["ConditionalUnet1D"]
|
||||
|
||||
@@ -225,14 +225,27 @@ class ConditionalUnet1D(nn.Module):
|
||||
def forward(self,
|
||||
sample: torch.Tensor,
|
||||
timestep: Union[torch.Tensor, float, int],
|
||||
local_cond=None, global_cond=None, **kwargs):
|
||||
local_cond=None, global_cond=None,
|
||||
visual_features=None, proprioception=None,
|
||||
**kwargs):
|
||||
"""
|
||||
x: (B,T,input_dim)
|
||||
timestep: (B,) or int, diffusion step
|
||||
local_cond: (B,T,local_cond_dim)
|
||||
global_cond: (B,global_cond_dim)
|
||||
visual_features: (B, T_obs, D_vis)
|
||||
proprioception: (B, T_obs, D_prop)
|
||||
output: (B,T,input_dim)
|
||||
"""
|
||||
if global_cond is None:
|
||||
conds = []
|
||||
if visual_features is not None:
|
||||
conds.append(visual_features.flatten(start_dim=1))
|
||||
if proprioception is not None:
|
||||
conds.append(proprioception.flatten(start_dim=1))
|
||||
if len(conds) > 0:
|
||||
global_cond = torch.cat(conds, dim=-1)
|
||||
|
||||
sample = einops.rearrange(sample, 'b h t -> b t h')
|
||||
|
||||
# 1. time
|
||||
@@ -291,4 +304,3 @@ class ConditionalUnet1D(nn.Module):
|
||||
|
||||
x = einops.rearrange(x, 'b t h -> b h t')
|
||||
return x
|
||||
|
||||
Reference in New Issue
Block a user