debug
This commit is contained in:
@@ -101,20 +101,9 @@ class ResNetDiffusionBackbone(VLABackbone):
|
||||
crop_is_random: bool = True,
|
||||
use_group_norm: bool = True,
|
||||
spatial_softmax_num_keypoints: int = 32,
|
||||
use_separate_rgb_encoder_per_camera: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# 保存所有参数作为实例变量
|
||||
self.vision_backbone = vision_backbone
|
||||
self.pretrained_backbone_weights = pretrained_backbone_weights
|
||||
self.input_shape = input_shape
|
||||
self.crop_shape = crop_shape
|
||||
self.crop_is_random = crop_is_random
|
||||
self.use_group_norm = use_group_norm
|
||||
self.spatial_softmax_num_keypoints = spatial_softmax_num_keypoints
|
||||
self.use_separate_rgb_encoder_per_camera = use_separate_rgb_encoder_per_camera
|
||||
|
||||
|
||||
# 设置可选的预处理。
|
||||
if crop_shape is not None:
|
||||
self.do_crop = True
|
||||
@@ -126,120 +115,49 @@ class ResNetDiffusionBackbone(VLABackbone):
|
||||
self.maybe_random_crop = self.center_crop
|
||||
else:
|
||||
self.do_crop = False
|
||||
self.crop_shape = input_shape[1:]
|
||||
crop_shape = input_shape[1:]
|
||||
|
||||
# 创建骨干网络的内部函数
|
||||
def _create_backbone():
|
||||
backbone_model = getattr(torchvision.models, vision_backbone)(
|
||||
weights=pretrained_backbone_weights
|
||||
)
|
||||
# 移除 AvgPool 和 FC (假设 layer4 是 children()[-3])
|
||||
backbone = nn.Sequential(*(list(backbone_model.children())[:-2]))
|
||||
if use_group_norm:
|
||||
backbone = _replace_submodules(
|
||||
root_module=backbone,
|
||||
predicate=lambda x: isinstance(x, nn.BatchNorm2d),
|
||||
func=lambda x: nn.GroupNorm(num_groups=x.num_features // 16, num_channels=x.num_features),
|
||||
)
|
||||
return backbone
|
||||
|
||||
# 创建池化和最终层的内部函数
|
||||
def _create_head(feature_map_shape):
|
||||
pool = SpatialSoftmax(feature_map_shape, num_kp=spatial_softmax_num_keypoints)
|
||||
feature_dim = spatial_softmax_num_keypoints * 2
|
||||
out = nn.Linear(spatial_softmax_num_keypoints * 2, feature_dim)
|
||||
relu = nn.ReLU()
|
||||
return pool, feature_dim, out, relu
|
||||
|
||||
# 使用试运行来获取特征图形状
|
||||
dummy_shape = (1, input_shape[0], *self.crop_shape)
|
||||
|
||||
if self.use_separate_rgb_encoder_per_camera:
|
||||
# 每个相机使用独立的编码器,我们先创建一个临时骨干网络来获取特征图形状
|
||||
temp_backbone = _create_backbone()
|
||||
with torch.no_grad():
|
||||
dummy_out = temp_backbone(torch.zeros(dummy_shape))
|
||||
feature_map_shape = dummy_out.shape[1:] # (C, H, W)
|
||||
del temp_backbone
|
||||
|
||||
# 注意:我们在 forward 方法中动态创建编码器,或者在知道相机数量时创建
|
||||
# 这里我们先不创建具体的编码器实例,而是在 forward 时根据需要创建
|
||||
# 或者,我们可以要求用户提供相机数量参数
|
||||
self.camera_encoders = None
|
||||
self.feature_dim = spatial_softmax_num_keypoints * 2
|
||||
else:
|
||||
# 所有相机共享同一个编码器
|
||||
self.backbone = _create_backbone()
|
||||
with torch.no_grad():
|
||||
dummy_out = self.backbone(torch.zeros(dummy_shape))
|
||||
feature_map_shape = dummy_out.shape[1:] # (C, H, W)
|
||||
self.pool, self.feature_dim, self.out, self.relu = _create_head(feature_map_shape)
|
||||
|
||||
def _create_single_encoder(self):
|
||||
"""内部方法:创建单个编码器(骨干网络 + 池化 + 输出层)"""
|
||||
# 创建骨干网络
|
||||
backbone_model = getattr(torchvision.models, self.vision_backbone)(
|
||||
weights=self.pretrained_backbone_weights
|
||||
# 设置骨干网络。
|
||||
backbone_model = getattr(torchvision.models, vision_backbone)(
|
||||
weights=pretrained_backbone_weights
|
||||
)
|
||||
backbone = nn.Sequential(*(list(backbone_model.children())[:-2]))
|
||||
|
||||
if self.use_group_norm:
|
||||
backbone = _replace_submodules(
|
||||
root_module=backbone,
|
||||
|
||||
# 移除 AvgPool 和 FC (假设 layer4 是 children()[-3])
|
||||
self.backbone = nn.Sequential(*(list(backbone_model.children())[:-2]))
|
||||
|
||||
if use_group_norm:
|
||||
self.backbone = _replace_submodules(
|
||||
root_module=self.backbone,
|
||||
predicate=lambda x: isinstance(x, nn.BatchNorm2d),
|
||||
func=lambda x: nn.GroupNorm(num_groups=x.num_features // 16, num_channels=x.num_features),
|
||||
)
|
||||
|
||||
# 获取特征图形状
|
||||
dummy_shape = (1, self.input_shape[0], *self.crop_shape)
|
||||
# 设置池化和最终层。
|
||||
# 使用试运行来获取特征图形状。
|
||||
dummy_shape = (1, input_shape[0], *crop_shape)
|
||||
with torch.no_grad():
|
||||
dummy_out = backbone(torch.zeros(dummy_shape))
|
||||
feature_map_shape = dummy_out.shape[1:]
|
||||
dummy_out = self.backbone(torch.zeros(dummy_shape))
|
||||
feature_map_shape = dummy_out.shape[1:] # (C, H, W)
|
||||
|
||||
# 创建池化和输出层
|
||||
pool = SpatialSoftmax(feature_map_shape, num_kp=self.spatial_softmax_num_keypoints)
|
||||
out = nn.Linear(self.spatial_softmax_num_keypoints * 2, self.feature_dim)
|
||||
relu = nn.ReLU()
|
||||
self.pool = SpatialSoftmax(feature_map_shape, num_kp=spatial_softmax_num_keypoints)
|
||||
self.feature_dim = spatial_softmax_num_keypoints * 2
|
||||
self.out = nn.Linear(spatial_softmax_num_keypoints * 2, self.feature_dim)
|
||||
self.relu = nn.ReLU()
|
||||
|
||||
return nn.ModuleList([backbone, pool, out, relu])
|
||||
|
||||
def forward_single_image(self, x: torch.Tensor, encoder: nn.ModuleList = None) -> torch.Tensor:
|
||||
def forward_single_image(self, x: torch.Tensor) -> torch.Tensor:
|
||||
if self.do_crop:
|
||||
x = self.maybe_random_crop(x) if self.training else self.center_crop(x)
|
||||
|
||||
if self.use_separate_rgb_encoder_per_camera:
|
||||
# 使用独立编码器
|
||||
backbone, pool, out, relu = encoder
|
||||
x = relu(out(torch.flatten(pool(backbone(x)), start_dim=1)))
|
||||
else:
|
||||
# 使用共享编码器
|
||||
x = self.relu(self.out(torch.flatten(self.pool(self.backbone(x)), start_dim=1)))
|
||||
x = self.relu(self.out(torch.flatten(self.pool(self.backbone(x)), start_dim=1)))
|
||||
return x
|
||||
|
||||
def forward(self, images):
|
||||
any_tensor = next(iter(images.values()))
|
||||
B, T = any_tensor.shape[:2]
|
||||
features_all = []
|
||||
|
||||
# 检查是否需要初始化独立编码器
|
||||
if self.use_separate_rgb_encoder_per_camera and self.camera_encoders is None:
|
||||
self.camera_encoders = nn.ModuleDict()
|
||||
for cam_name in sorted(images.keys()):
|
||||
self.camera_encoders[cam_name] = self._create_single_encoder()
|
||||
|
||||
for cam_name in sorted(images.keys()):
|
||||
img = images[cam_name]
|
||||
if self.use_separate_rgb_encoder_per_camera:
|
||||
# 使用该相机对应的独立编码器
|
||||
features = self.forward_single_image(
|
||||
img.view(B * T, *img.shape[2:]),
|
||||
self.camera_encoders[cam_name]
|
||||
)
|
||||
else:
|
||||
# 使用共享编码器
|
||||
features = self.forward_single_image(img.view(B * T, *img.shape[2:]))
|
||||
features = self.forward_single_image(img.view(B * T, *img.shape[2:]))
|
||||
features_all.append(features)
|
||||
|
||||
return torch.cat(features_all, dim=1).view(B, T, -1)
|
||||
|
||||
@property
|
||||
|
||||
Reference in New Issue
Block a user