Commit 1e95d40bf9 by gouhanke, 2026-02-10 15:56:05 +08:00 (parent 3c27d6d793): 4 changed files with 790 additions and 109 deletions.

View File

# NOTE(review): this span is a rendered unified diff (`+`/`-` markers and indentation
# stripped), so removed and added lines are interleaved. Code lines are kept verbatim;
# only comments were translated to English.
@@ -101,20 +101,9 @@ class ResNetDiffusionBackbone(VLABackbone):
crop_is_random: bool = True,
use_group_norm: bool = True,
spatial_softmax_num_keypoints: int = 32,
use_separate_rgb_encoder_per_camera: bool = True,
):
super().__init__()
# Save all parameters as instance variables
self.vision_backbone = vision_backbone
self.pretrained_backbone_weights = pretrained_backbone_weights
self.input_shape = input_shape
self.crop_shape = crop_shape
self.crop_is_random = crop_is_random
self.use_group_norm = use_group_norm
self.spatial_softmax_num_keypoints = spatial_softmax_num_keypoints
self.use_separate_rgb_encoder_per_camera = use_separate_rgb_encoder_per_camera
# Set up optional preprocessing.
if crop_shape is not None:
self.do_crop = True
@@ -126,120 +115,49 @@ class ResNetDiffusionBackbone(VLABackbone):
self.maybe_random_crop = self.center_crop
else:
self.do_crop = False
# NOTE(review): the next two lines are the old/new pair of the same statement —
# the commit switches from storing crop_shape on self to a local rebinding.
self.crop_shape = input_shape[1:]
crop_shape = input_shape[1:]
# NOTE(review): this region is the OLD (removed) half of the diff — the per-camera
# encoder machinery deleted by this commit. Code lines kept verbatim; comments translated.
# Inner function that creates the backbone network
def _create_backbone():
backbone_model = getattr(torchvision.models, vision_backbone)(
weights=pretrained_backbone_weights
)
# Remove AvgPool and FC (assumes layer4 is children()[-3])
backbone = nn.Sequential(*(list(backbone_model.children())[:-2]))
if use_group_norm:
backbone = _replace_submodules(
root_module=backbone,
predicate=lambda x: isinstance(x, nn.BatchNorm2d),
func=lambda x: nn.GroupNorm(num_groups=x.num_features // 16, num_channels=x.num_features),
)
return backbone
# Inner function that creates the pooling and final layers
def _create_head(feature_map_shape):
pool = SpatialSoftmax(feature_map_shape, num_kp=spatial_softmax_num_keypoints)
feature_dim = spatial_softmax_num_keypoints * 2
out = nn.Linear(spatial_softmax_num_keypoints * 2, feature_dim)
relu = nn.ReLU()
return pool, feature_dim, out, relu
# Use a dry run to get the feature map shape
dummy_shape = (1, input_shape[0], *self.crop_shape)
if self.use_separate_rgb_encoder_per_camera:
# Each camera uses its own encoder; first build a temporary backbone to probe the feature map shape
temp_backbone = _create_backbone()
with torch.no_grad():
dummy_out = temp_backbone(torch.zeros(dummy_shape))
feature_map_shape = dummy_out.shape[1:] # (C, H, W)
del temp_backbone
# Note: we create the encoders dynamically in forward, or once the camera count is known
# Here no concrete encoder instances are created yet; they are built on demand in forward
# Alternatively, the user could be required to supply a camera-count parameter
self.camera_encoders = None
self.feature_dim = spatial_softmax_num_keypoints * 2
else:
# All cameras share a single encoder
self.backbone = _create_backbone()
with torch.no_grad():
dummy_out = self.backbone(torch.zeros(dummy_shape))
feature_map_shape = dummy_out.shape[1:] # (C, H, W)
self.pool, self.feature_dim, self.out, self.relu = _create_head(feature_map_shape)
# NOTE(review): here the deleted `_create_single_encoder` method (the lines reading
# `self.vision_backbone`, `self.use_group_norm`, etc.) is interleaved with the NEW
# shared-encoder tail of `__init__` (the lines using the bare local names). The two
# versions cannot be separated without the diff markers; code lines kept verbatim.
def _create_single_encoder(self):
"""Internal method: create a single encoder (backbone + pooling + output layer)."""
# Create the backbone network
backbone_model = getattr(torchvision.models, self.vision_backbone)(
weights=self.pretrained_backbone_weights
# Set up the backbone network
backbone_model = getattr(torchvision.models, vision_backbone)(
weights=pretrained_backbone_weights
)
backbone = nn.Sequential(*(list(backbone_model.children())[:-2]))
if self.use_group_norm:
backbone = _replace_submodules(
root_module=backbone,
# Remove AvgPool and FC (assumes layer4 is children()[-3])
self.backbone = nn.Sequential(*(list(backbone_model.children())[:-2]))
if use_group_norm:
self.backbone = _replace_submodules(
root_module=self.backbone,
predicate=lambda x: isinstance(x, nn.BatchNorm2d),
func=lambda x: nn.GroupNorm(num_groups=x.num_features // 16, num_channels=x.num_features),
)
# Get the feature map shape
dummy_shape = (1, self.input_shape[0], *self.crop_shape)
# Set up pooling and final layers.
# Use a dry run to get the feature map shape.
dummy_shape = (1, input_shape[0], *crop_shape)
with torch.no_grad():
dummy_out = backbone(torch.zeros(dummy_shape))
feature_map_shape = dummy_out.shape[1:]
dummy_out = self.backbone(torch.zeros(dummy_shape))
feature_map_shape = dummy_out.shape[1:] # (C, H, W)
# Create pooling and output layers
pool = SpatialSoftmax(feature_map_shape, num_kp=self.spatial_softmax_num_keypoints)
out = nn.Linear(self.spatial_softmax_num_keypoints * 2, self.feature_dim)
relu = nn.ReLU()
self.pool = SpatialSoftmax(feature_map_shape, num_kp=spatial_softmax_num_keypoints)
self.feature_dim = spatial_softmax_num_keypoints * 2
self.out = nn.Linear(spatial_softmax_num_keypoints * 2, self.feature_dim)
self.relu = nn.ReLU()
return nn.ModuleList([backbone, pool, out, relu])
# NOTE(review): old and new versions of this method are interleaved by the diff
# rendering — the first `def` line (with the `encoder` parameter) and the branch on
# `use_separate_rgb_encoder_per_camera` are the removed version; the second `def` line
# and the final shared-encoder statement are the new version. Code kept verbatim.
def forward_single_image(self, x: torch.Tensor, encoder: nn.ModuleList = None) -> torch.Tensor:
def forward_single_image(self, x: torch.Tensor) -> torch.Tensor:
if self.do_crop:
x = self.maybe_random_crop(x) if self.training else self.center_crop(x)
if self.use_separate_rgb_encoder_per_camera:
# Use the per-camera encoder
backbone, pool, out, relu = encoder
x = relu(out(torch.flatten(pool(backbone(x)), start_dim=1)))
else:
# Use the shared encoder
x = self.relu(self.out(torch.flatten(self.pool(self.backbone(x)), start_dim=1)))
x = self.relu(self.out(torch.flatten(self.pool(self.backbone(x)), start_dim=1)))
return x
# NOTE(review): diff-interleaved method — the lazy `camera_encoders` initialization and
# the per-camera branch are the removed (old) version; the single unconditional
# `forward_single_image` call is the new version. Code lines kept verbatim.
def forward(self, images):
any_tensor = next(iter(images.values()))
B, T = any_tensor.shape[:2]
features_all = []
# Check whether the per-camera encoders still need to be initialized
if self.use_separate_rgb_encoder_per_camera and self.camera_encoders is None:
self.camera_encoders = nn.ModuleDict()
for cam_name in sorted(images.keys()):
self.camera_encoders[cam_name] = self._create_single_encoder()
for cam_name in sorted(images.keys()):
img = images[cam_name]
if self.use_separate_rgb_encoder_per_camera:
# Use the encoder dedicated to this camera
features = self.forward_single_image(
img.view(B * T, *img.shape[2:]),
self.camera_encoders[cam_name]
)
else:
# Use the shared encoder
features = self.forward_single_image(img.view(B * T, *img.shape[2:]))
features = self.forward_single_image(img.view(B * T, *img.shape[2:]))
features_all.append(features)
return torch.cat(features_all, dim=1).view(B, T, -1)
@property