From eeb07cad15e282dc8445d2c06030353780580b57 Mon Sep 17 00:00:00 2001
From: gouhanke <12219217+gouhanke@user.noreply.gitee.com>
Date: Wed, 11 Feb 2026 20:11:25 +0800
Subject: [PATCH] feat: freeze resnet

---
 roboimi/vla/agent.py                          | 21 +++++++++++++++----
 .../vla/conf/backbone/resnet_diffusion.yaml   |  7 ++++++-
 .../vla/models/backbones/resnet_diffusion.py  |  9 ++++++++
 .../vla/models/heads/conditional_unet1d.py    | 14 +------------
 4 files changed, 33 insertions(+), 18 deletions(-)

diff --git a/roboimi/vla/agent.py b/roboimi/vla/agent.py
index eba3caf..1172f9e 100644
--- a/roboimi/vla/agent.py
+++ b/roboimi/vla/agent.py
@@ -116,12 +116,19 @@ class VLAAgent(nn.Module):
             action_features, noise, timesteps
         )
 
+        # Concatenate the global conditioning and flatten it
+        # visual_features: (B, obs_horizon, vision_dim)
+        # state_features: (B, obs_horizon, obs_dim)
+        # After concatenation, flattened to (B, obs_horizon * (vision_dim + obs_dim))
+        global_cond = torch.cat([visual_features, state_features], dim=-1)
+        global_cond = global_cond.flatten(start_dim=1)
+
+        # 5. Predict the noise with the network
         pred_noise = self.noise_pred_net(
             sample=noisy_actions,
             timestep=timesteps,
-            visual_features=visual_features,
-            proprioception=state_features
+            global_cond=global_cond
        )
 
         # 6. Compute the loss (MSE)
@@ -314,12 +321,18 @@ class VLAAgent(nn.Module):
+        # Concatenate the global conditioning and flatten it; the result does
+        # not depend on the timestep, so compute it once before the loop.
+        # visual_features: (B, obs_horizon, vision_dim); state_features: (B, obs_horizon, obs_dim)
+        # After concatenation, flattened to (B, obs_horizon * (vision_dim + obs_dim))
+        global_cond = torch.cat([visual_features, state_features], dim=-1)
+        global_cond = global_cond.flatten(start_dim=1)
+
         for t in self.infer_scheduler.timesteps:
             model_input = current_actions
 
             # Predict the noise
             noise_pred = self.noise_pred_net(
                 sample=model_input,
                 timestep=t,
-                visual_features=visual_features,
-                proprioception=state_features
+                global_cond=global_cond
             )
 
             # Remove the noise and update current_actions
diff --git a/roboimi/vla/conf/backbone/resnet_diffusion.yaml b/roboimi/vla/conf/backbone/resnet_diffusion.yaml
index 0b985d1..2055ca7 100644
--- a/roboimi/vla/conf/backbone/resnet_diffusion.yaml
+++ b/roboimi/vla/conf/backbone/resnet_diffusion.yaml
@@ -4,7 +4,12 @@ _target_: roboimi.vla.models.backbones.resnet_diffusion.ResNetDiffusionBackbone
 # Backbone selection
 # ====================
 vision_backbone: "resnet18"  # torchvision model name: resnet18, resnet34, resnet50
-pretrained_backbone_weights: null  # path to pretrained weights, or null (ImageNet weights)
+pretrained_backbone_weights: "IMAGENET1K_V1"  # ImageNet pretrained weights (torchvision >= 0.13)
+
+# ====================
+# Freeze settings
+# ====================
+freeze_backbone: true  # freeze the ResNet parameters and train only the downstream pool and out layers (recommended: true)
 
 # ====================
 # Input configuration
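
A note on what freeze_backbone: true implies for training: the requires_grad loop added in resnet_diffusion.py below stops gradient updates, but it does not stop BatchNorm running statistics from updating while the module is in train() mode, so when the GroupNorm replacement (use_group_norm) is off, putting the frozen backbone in eval() is a common companion step, and only the still-trainable parameters need to be handed to the optimizer. A minimal, self-contained sketch of that pattern (the layer sizes and learning rate are illustrative assumptions, not values from this repo):

    import torch
    import torch.nn as nn

    # Stand-in modules: a "backbone" to freeze and a trainable head
    # (illustrative shapes; the real encoder lives in roboimi.vla).
    backbone = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
    head = nn.Linear(8, 4)

    # What freeze_backbone=true does: stop gradients through the backbone.
    for param in backbone.parameters():
        param.requires_grad = False
    # Also stop BatchNorm running-stat updates while the rest trains.
    backbone.eval()

    model = nn.ModuleDict({"backbone": backbone, "head": head})
    trainable = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(trainable, lr=1e-4)

    n_frozen = sum(p.numel() for p in model.parameters() if not p.requires_grad)
    n_train = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"frozen params: {n_frozen:,}  trainable params: {n_train:,}")
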
diff --git a/roboimi/vla/models/backbones/resnet_diffusion.py b/roboimi/vla/models/backbones/resnet_diffusion.py
index 7416fec..695496d 100644
--- a/roboimi/vla/models/backbones/resnet_diffusion.py
+++ b/roboimi/vla/models/backbones/resnet_diffusion.py
@@ -102,6 +102,7 @@ class _SingleRgbEncoder(nn.Module):
         crop_is_random: bool,
         use_group_norm: bool,
         spatial_softmax_num_keypoints: int,
+        freeze_backbone: bool = True,  # new: whether to freeze the backbone
     ):
 
         super().__init__()
@@ -133,6 +134,11 @@ class _SingleRgbEncoder(nn.Module):
             func=lambda x: nn.GroupNorm(num_groups=x.num_features // 16, num_channels=x.num_features),
         )
 
+        # Freeze the backbone parameters (optional)
+        if freeze_backbone:
+            for param in self.backbone.parameters():
+                param.requires_grad = False
+
         # Set up the pooling and final layers
         # Use a dry run to get the feature-map shape
         dummy_shape = (1, input_shape[0], *crop_shape)
@@ -164,6 +170,7 @@ class ResNetDiffusionBackbone(VLABackbone):
         spatial_softmax_num_keypoints: int = 32,
         use_separate_rgb_encoder_per_camera: bool = False,  # new: use a separate encoder per camera
         num_cameras: int = 1,  # new: number of cameras (only used in separate-encoder mode)
+        freeze_backbone: bool = True,  # new: whether to freeze the ResNet backbone (recommended: True)
     ):
 
         super().__init__()
@@ -181,6 +188,7 @@ class ResNetDiffusionBackbone(VLABackbone):
                     crop_is_random=crop_is_random,
                     use_group_norm=use_group_norm,
                     spatial_softmax_num_keypoints=spatial_softmax_num_keypoints,
+                    freeze_backbone=freeze_backbone,
                 )
                 for _ in range(num_cameras)
             ]
@@ -197,6 +205,7 @@ class ResNetDiffusionBackbone(VLABackbone):
                 crop_is_random=crop_is_random,
                 use_group_norm=use_group_norm,
                 spatial_softmax_num_keypoints=spatial_softmax_num_keypoints,
+                freeze_backbone=freeze_backbone,
             )
 
             self.feature_dim = self.rgb_encoder.feature_dim
diff --git a/roboimi/vla/models/heads/conditional_unet1d.py b/roboimi/vla/models/heads/conditional_unet1d.py
index f468120..dae7eb8 100644
--- a/roboimi/vla/models/heads/conditional_unet1d.py
+++ b/roboimi/vla/models/heads/conditional_unet1d.py
@@ -225,27 +225,15 @@ class ConditionalUnet1D(nn.Module):
     def forward(self,
                 sample: torch.Tensor,
                 timestep: Union[torch.Tensor, float, int],
-                local_cond=None, global_cond=None,
-                visual_features=None, proprioception=None,
+                local_cond=None, global_cond=None,
                 **kwargs):
         """
         x: (B,T,input_dim)
         timestep: (B,) or int, diffusion step
         local_cond: (B,T,local_cond_dim)
         global_cond: (B,global_cond_dim)
-        visual_features: (B, T_obs, D_vis)
-        proprioception: (B, T_obs, D_prop)
         output: (B,T,input_dim)
         """
-        if global_cond is None:
-            conds = []
-            if visual_features is not None:
-                conds.append(visual_features.flatten(start_dim=1))
-            if proprioception is not None:
-                conds.append(proprioception.flatten(start_dim=1))
-            if len(conds) > 0:
-                global_cond = torch.cat(conds, dim=-1)
-
         sample = einops.rearrange(sample, 'b h t -> b t h')
 
         # 1. time
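
Since ConditionalUnet1D.forward no longer assembles global_cond from visual_features and proprioception, both call sites in agent.py must flatten the conditioning to exactly the global_cond_dim the U-Net was constructed with, i.e. obs_horizon * (vision_dim + obs_dim). A small self-contained shape check (the dimensions are illustrative assumptions, not values from this repo's config):

    import torch

    B, obs_horizon = 8, 2
    vision_dim, obs_dim = 64, 14   # e.g. 32 spatial-softmax keypoints * 2 coords; proprioception width

    visual_features = torch.randn(B, obs_horizon, vision_dim)
    state_features = torch.randn(B, obs_horizon, obs_dim)

    # Mirrors the patch: concatenate per-step features, then flatten time into channels.
    global_cond = torch.cat([visual_features, state_features], dim=-1)  # (B, obs_horizon, vision_dim + obs_dim)
    global_cond = global_cond.flatten(start_dim=1)                      # (B, obs_horizon * (vision_dim + obs_dim))

    assert global_cond.shape == (B, obs_horizon * (vision_dim + obs_dim))
    print(global_cond.shape)  # torch.Size([8, 156])

If the flattened width disagrees with the head's global_cond_dim, the first linear layer in the conditioning path fails with a shape mismatch, so this is worth rechecking whenever obs_horizon or the camera setup changes.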