From abb4f501e38e84ffc87b984cb67938c4833c7840 Mon Sep 17 00:00:00 2001 From: gouhanke <12219217+gouhanke@user.noreply.gitee.com> Date: Sat, 28 Feb 2026 10:42:16 +0800 Subject: [PATCH] =?UTF-8?q?chore:=20=E5=88=A0=E9=99=A4unet=E9=87=8C?= =?UTF-8?q?=E7=9A=84local=5Fcond(=E6=9C=AA=E4=BD=BF=E7=94=A8)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../vla/models/heads/conditional_unet1d.py | 50 +++---------------- 1 file changed, 6 insertions(+), 44 deletions(-) diff --git a/roboimi/vla/models/heads/conditional_unet1d.py b/roboimi/vla/models/heads/conditional_unet1d.py index dae7eb8..b9cc11e 100644 --- a/roboimi/vla/models/heads/conditional_unet1d.py +++ b/roboimi/vla/models/heads/conditional_unet1d.py @@ -122,9 +122,8 @@ class ConditionalResidualBlock1D(nn.Module): class ConditionalUnet1D(nn.Module): - def __init__(self, + def __init__(self, input_dim, - local_cond_dim=None, global_cond_dim=None, diffusion_step_embed_dim=256, down_dims=[256,512,1024], @@ -149,23 +148,6 @@ class ConditionalUnet1D(nn.Module): in_out = list(zip(all_dims[:-1], all_dims[1:])) - local_cond_encoder = None - if local_cond_dim is not None: - _, dim_out = in_out[0] - dim_in = local_cond_dim - local_cond_encoder = nn.ModuleList([ - # down encoder - ConditionalResidualBlock1D( - dim_in, dim_out, cond_dim=cond_dim, - kernel_size=kernel_size, n_groups=n_groups, - cond_predict_scale=cond_predict_scale), - # up encoder - ConditionalResidualBlock1D( - dim_in, dim_out, cond_dim=cond_dim, - kernel_size=kernel_size, n_groups=n_groups, - cond_predict_scale=cond_predict_scale) - ]) - mid_dim = all_dims[-1] self.mid_modules = nn.ModuleList([ ConditionalResidualBlock1D( @@ -216,21 +198,19 @@ class ConditionalUnet1D(nn.Module): ) self.diffusion_step_encoder = diffusion_step_encoder - self.local_cond_encoder = local_cond_encoder self.up_modules = up_modules self.down_modules = down_modules self.final_conv = final_conv - def forward(self, - sample: torch.Tensor, - timestep: Union[torch.Tensor, float, int], - local_cond=None, global_cond=None, + def forward(self, + sample: torch.Tensor, + timestep: Union[torch.Tensor, float, int], + global_cond=None, **kwargs): """ x: (B,T,input_dim) timestep: (B,) or int, diffusion step - local_cond: (B,T,local_cond_dim) global_cond: (B,global_cond_dim) output: (B,T,input_dim) """ @@ -252,23 +232,11 @@ class ConditionalUnet1D(nn.Module): global_feature = torch.cat([ global_feature, global_cond ], axis=-1) - - # encode local features - h_local = list() - if local_cond is not None: - local_cond = einops.rearrange(local_cond, 'b h t -> b t h') - resnet, resnet2 = self.local_cond_encoder - x = resnet(local_cond, global_feature) - h_local.append(x) - x = resnet2(local_cond, global_feature) - h_local.append(x) - + x = sample h = [] for idx, (resnet, resnet2, downsample) in enumerate(self.down_modules): x = resnet(x, global_feature) - if idx == 0 and len(h_local) > 0: - x = x + h_local[0] x = resnet2(x, global_feature) h.append(x) x = downsample(x) @@ -279,12 +247,6 @@ class ConditionalUnet1D(nn.Module): for idx, (resnet, resnet2, upsample) in enumerate(self.up_modules): x = torch.cat((x, h.pop()), dim=1) x = resnet(x, global_feature) - # The correct condition should be: - # if idx == (len(self.up_modules)-1) and len(h_local) > 0: - # However this change will break compatibility with published checkpoints. - # Therefore it is left as a comment. - if idx == len(self.up_modules) and len(h_local) > 0: - x = x + h_local[1] x = resnet2(x, global_feature) x = upsample(x)