feat: 新增可选的 ResNet backbone 冻结开关(freeze_backbone,默认 True)

This commit is contained in:
gouhanke
2026-02-11 20:11:25 +08:00
parent 83d11ab640
commit eeb07cad15
4 changed files with 33 additions and 18 deletions

View File

@@ -102,6 +102,7 @@ class _SingleRgbEncoder(nn.Module):
crop_is_random: bool,
use_group_norm: bool,
spatial_softmax_num_keypoints: int,
freeze_backbone: bool = True, # 新增是否冻结backbone
):
super().__init__()
@@ -133,6 +134,11 @@ class _SingleRgbEncoder(nn.Module):
func=lambda x: nn.GroupNorm(num_groups=x.num_features // 16, num_channels=x.num_features),
)
# 冻结backbone参数可选
if freeze_backbone:
for param in self.backbone.parameters():
param.requires_grad = False
# 设置池化和最终层
# 使用试运行来获取特征图形状
dummy_shape = (1, input_shape[0], *crop_shape)
@@ -164,6 +170,7 @@ class ResNetDiffusionBackbone(VLABackbone):
spatial_softmax_num_keypoints: int = 32,
use_separate_rgb_encoder_per_camera: bool = False, # 新增:是否为每个摄像头使用独立编码器
num_cameras: int = 1, # 新增:摄像头数量(仅在独立编码器模式下使用)
freeze_backbone: bool = True, # 新增是否冻结ResNet backbone推荐True
):
super().__init__()
@@ -181,6 +188,7 @@ class ResNetDiffusionBackbone(VLABackbone):
crop_is_random=crop_is_random,
use_group_norm=use_group_norm,
spatial_softmax_num_keypoints=spatial_softmax_num_keypoints,
freeze_backbone=freeze_backbone,
)
for _ in range(num_cameras)
]
@@ -197,6 +205,7 @@ class ResNetDiffusionBackbone(VLABackbone):
crop_is_random=crop_is_random,
use_group_norm=use_group_norm,
spatial_softmax_num_keypoints=spatial_softmax_num_keypoints,
freeze_backbone=freeze_backbone,
)
self.feature_dim = self.rgb_encoder.feature_dim

View File

@@ -225,27 +225,15 @@ class ConditionalUnet1D(nn.Module):
def forward(self,
sample: torch.Tensor,
timestep: Union[torch.Tensor, float, int],
local_cond=None, global_cond=None,
visual_features=None, proprioception=None,
local_cond=None, global_cond=None,
**kwargs):
"""
x: (B,T,input_dim)
timestep: (B,) or int, diffusion step
local_cond: (B,T,local_cond_dim)
global_cond: (B,global_cond_dim)
visual_features: (B, T_obs, D_vis)
proprioception: (B, T_obs, D_prop)
output: (B,T,input_dim)
"""
if global_cond is None:
conds = []
if visual_features is not None:
conds.append(visual_features.flatten(start_dim=1))
if proprioception is not None:
conds.append(proprioception.flatten(start_dim=1))
if len(conds) > 0:
global_cond = torch.cat(conds, dim=-1)
sample = einops.rearrange(sample, 'b h t -> b t h')
# 1. time