diff --git a/roboimi/vla/models/heads/conditional_unet1d.py b/roboimi/vla/models/heads/conditional_unet1d.py index dae7eb8..b9cc11e 100644 --- a/roboimi/vla/models/heads/conditional_unet1d.py +++ b/roboimi/vla/models/heads/conditional_unet1d.py @@ -122,9 +122,8 @@ class ConditionalResidualBlock1D(nn.Module): class ConditionalUnet1D(nn.Module): - def __init__(self, + def __init__(self, input_dim, - local_cond_dim=None, global_cond_dim=None, diffusion_step_embed_dim=256, down_dims=[256,512,1024], @@ -149,23 +148,6 @@ class ConditionalUnet1D(nn.Module): in_out = list(zip(all_dims[:-1], all_dims[1:])) - local_cond_encoder = None - if local_cond_dim is not None: - _, dim_out = in_out[0] - dim_in = local_cond_dim - local_cond_encoder = nn.ModuleList([ - # down encoder - ConditionalResidualBlock1D( - dim_in, dim_out, cond_dim=cond_dim, - kernel_size=kernel_size, n_groups=n_groups, - cond_predict_scale=cond_predict_scale), - # up encoder - ConditionalResidualBlock1D( - dim_in, dim_out, cond_dim=cond_dim, - kernel_size=kernel_size, n_groups=n_groups, - cond_predict_scale=cond_predict_scale) - ]) - mid_dim = all_dims[-1] self.mid_modules = nn.ModuleList([ ConditionalResidualBlock1D( @@ -216,21 +198,19 @@ class ConditionalUnet1D(nn.Module): ) self.diffusion_step_encoder = diffusion_step_encoder - self.local_cond_encoder = local_cond_encoder self.up_modules = up_modules self.down_modules = down_modules self.final_conv = final_conv - def forward(self, - sample: torch.Tensor, - timestep: Union[torch.Tensor, float, int], - local_cond=None, global_cond=None, + def forward(self, + sample: torch.Tensor, + timestep: Union[torch.Tensor, float, int], + global_cond=None, **kwargs): """ x: (B,T,input_dim) timestep: (B,) or int, diffusion step - local_cond: (B,T,local_cond_dim) global_cond: (B,global_cond_dim) output: (B,T,input_dim) """ @@ -252,23 +232,11 @@ class ConditionalUnet1D(nn.Module): global_feature = torch.cat([ global_feature, global_cond ], axis=-1) - - # encode local features - h_local = list() - if local_cond is not None: - local_cond = einops.rearrange(local_cond, 'b h t -> b t h') - resnet, resnet2 = self.local_cond_encoder - x = resnet(local_cond, global_feature) - h_local.append(x) - x = resnet2(local_cond, global_feature) - h_local.append(x) - + x = sample h = [] for idx, (resnet, resnet2, downsample) in enumerate(self.down_modules): x = resnet(x, global_feature) - if idx == 0 and len(h_local) > 0: - x = x + h_local[0] x = resnet2(x, global_feature) h.append(x) x = downsample(x) @@ -279,12 +247,6 @@ class ConditionalUnet1D(nn.Module): for idx, (resnet, resnet2, upsample) in enumerate(self.up_modules): x = torch.cat((x, h.pop()), dim=1) x = resnet(x, global_feature) - # The correct condition should be: - # if idx == (len(self.up_modules)-1) and len(h_local) > 0: - # However this change will break compatibility with published checkpoints. - # Therefore it is left as a comment. - if idx == len(self.up_modules) and len(h_local) > 0: - x = x + h_local[1] x = resnet2(x, global_feature) x = upsample(x)