This commit is contained in:
gouhanke
2026-02-09 14:41:35 +08:00
parent f833c6d9f1
commit 8b700b6d99
10 changed files with 76 additions and 39 deletions

View File

@@ -1,4 +1,4 @@
# # Action Head models
from .diffusion import ConditionalUnet1D
from .conditional_unet1d import ConditionalUnet1D
__all__ = ["ConditionalUnet1D"]

View File

@@ -225,14 +225,27 @@ class ConditionalUnet1D(nn.Module):
def forward(self,
sample: torch.Tensor,
timestep: Union[torch.Tensor, float, int],
local_cond=None, global_cond=None, **kwargs):
local_cond=None, global_cond=None,
visual_features=None, proprioception=None,
**kwargs):
"""
x: (B,T,input_dim)
timestep: (B,) or int, diffusion step
local_cond: (B,T,local_cond_dim)
global_cond: (B,global_cond_dim)
visual_features: (B, T_obs, D_vis)
proprioception: (B, T_obs, D_prop)
output: (B,T,input_dim)
"""
if global_cond is None:
conds = []
if visual_features is not None:
conds.append(visual_features.flatten(start_dim=1))
if proprioception is not None:
conds.append(proprioception.flatten(start_dim=1))
if len(conds) > 0:
global_cond = torch.cat(conds, dim=-1)
sample = einops.rearrange(sample, 'b h t -> b t h')
# 1. time
@@ -291,4 +304,3 @@ class ConditionalUnet1D(nn.Module):
x = einops.rearrange(x, 'b t h -> b h t')
return x