diff --git a/roboimi/vla/conf/agent/resnet_diffusion.yaml b/roboimi/vla/conf/agent/resnet_diffusion.yaml
index 2874672..b9ab4e4 100644
--- a/roboimi/vla/conf/agent/resnet_diffusion.yaml
+++ b/roboimi/vla/conf/agent/resnet_diffusion.yaml
@@ -1,6 +1,7 @@
 # @package agent
 defaults:
-  - /backbone@vision_backbone: resnet
+  # - /backbone@vision_backbone: resnet
+  - /backbone@vision_backbone: resnet_diffusion
   - /modules@state_encoder: identity_state_encoder
   - /modules@action_encoder: identity_action_encoder
   - /head: conditional_unet1d
@@ -16,8 +17,6 @@
 obs_dim: 16
 pred_horizon: 16
 obs_horizon: 2
 
-# Diffusion Parameters
-# diffusion_steps: 100 (这些参数应该移到 head 配置中,或者通过变量传递)
 # Camera Configuration
 num_cams: ${len:${data.camera_names}} # 自动从 data.camera_names 列表长度获取
\ No newline at end of file
diff --git a/roboimi/vla/conf/backbone/resnet_diffusion.yaml b/roboimi/vla/conf/backbone/resnet_diffusion.yaml
new file mode 100644
index 0000000..d8fd5b2
--- /dev/null
+++ b/roboimi/vla/conf/backbone/resnet_diffusion.yaml
@@ -0,0 +1,9 @@
+_target_: roboimi.vla.models.backbones.resnet_diffusion.ResNetDiffusionBackbone
+vision_backbone: "resnet18"
+pretrained_backbone_weights: null
+input_shape: [3, 96, 96]
+crop_shape: [84, 84]
+crop_is_random: true
+use_group_norm: true
+spatial_softmax_num_keypoints: 32
+use_separate_rgb_encoder_per_camera: true
\ No newline at end of file
diff --git a/roboimi/vla/models/backbones/resnet_diffusion.py b/roboimi/vla/models/backbones/resnet_diffusion.py
new file mode 100644
index 0000000..afb7c65
--- /dev/null
+++ b/roboimi/vla/models/backbones/resnet_diffusion.py
@@ -0,0 +1,298 @@
+from roboimi.vla.core.interfaces import VLABackbone
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchvision
+import numpy as np
+from typing import Callable, Optional, Tuple
+
+
+def _replace_submodules(
+    root_module: nn.Module,
+    predicate: Callable[[nn.Module], bool],
+    func: Callable[[nn.Module], nn.Module],
+) -> nn.Module:
+    """Replace every submodule matching ``predicate`` with ``func(submodule)``.
+
+    Args:
+        root_module: Root module whose submodules may be replaced.
+        predicate: Returns True for a module that should be replaced.
+        func: Maps a matched module to its replacement.
+    Returns:
+        The root module with all matching submodules replaced in place.
+    """
+    if predicate(root_module):
+        return func(root_module)
+
+    replace_list = [k.split(".") for k, m in root_module.named_modules(remove_duplicate=True) if predicate(m)]
+    for *parents, k in replace_list:
+        parent_module = root_module
+        if len(parents) > 0:
+            parent_module = root_module.get_submodule(".".join(parents))
+        if isinstance(parent_module, nn.Sequential):
+            src_module = parent_module[int(k)]
+        else:
+            src_module = getattr(parent_module, k)
+        tgt_module = func(src_module)
+        if isinstance(parent_module, nn.Sequential):
+            parent_module[int(k)] = tgt_module
+        else:
+            setattr(parent_module, k, tgt_module)
+    # Verify that nothing matching the predicate (e.g. BatchNorm) remains.
+    assert not any(predicate(m) for _, m in root_module.named_modules(remove_duplicate=True))
+    return root_module
+
+
+class SpatialSoftmax(nn.Module):
+    """Spatial soft-argmax as described by Finn et al. in "Deep Spatial
+    Autoencoders for Visuomotor Learning" (https://huggingface.co/papers/1509.06113).
+    This is a minimal port of the robomimic implementation.
+    """
+
+    def __init__(self, input_shape, num_kp=None):
+        """
+        Args:
+            input_shape (list): (C, H, W) shape of the input feature maps.
+            num_kp (int): number of keypoints in the output. If None, the
+                output keeps the same number of channels as the input.
+        """
+        super().__init__()
+
+        assert len(input_shape) == 3
+        self._in_c, self._in_h, self._in_w = input_shape
+
+        if num_kp is not None:
+            self.nets = torch.nn.Conv2d(self._in_c, num_kp, kernel_size=1)
+            self._out_c = num_kp
+        else:
+            self.nets = None
+            self._out_c = self._in_c
+
+        # torch.linspace would work here, but it behaves slightly differently
+        # from numpy and causes a small pc_success drop with pretrained models.
+        pos_x, pos_y = np.meshgrid(np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h))
+        pos_x = torch.from_numpy(pos_x.reshape(self._in_h * self._in_w, 1)).float()
+        pos_y = torch.from_numpy(pos_y.reshape(self._in_h * self._in_w, 1)).float()
+        # Registered as a buffer so it moves to the right device with the module.
+        self.register_buffer("pos_grid", torch.cat([pos_x, pos_y], dim=1))
+
+    def forward(self, features: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            features: (B, C, H, W) input feature maps.
+        Returns:
+            (B, K, 2) image-space coordinates of the keypoints.
+        """
+        if self.nets is not None:
+            features = self.nets(features)
+
+        # [B, K, H, W] -> [B * K, H * W], where K is the number of keypoints
+        features = features.reshape(-1, self._in_h * self._in_w)
+        # 2d softmax normalization
+        attention = F.softmax(features, dim=-1)
+        # [B * K, H * W] x [H * W, 2] -> [B * K, 2] expected spatial coordinates
+        expected_xy = attention @ self.pos_grid
+        # reshape to [B, K, 2]
+        feature_keypoints = expected_xy.view(-1, self._out_c, 2)
+
+        return feature_keypoints
+
+
+class ResNetDiffusionBackbone(VLABackbone):
+    """Multi-camera ResNet encoder with SpatialSoftmax pooling (Diffusion
+    Policy style). Each (B, T, C, H, W) camera stream is optionally cropped,
+    encoded by a ResNet trunk (BatchNorm optionally swapped for GroupNorm),
+    pooled with SpatialSoftmax and linearly projected; per-camera features
+    are concatenated along the last dimension in sorted camera-name order.
+    """
+
+    def __init__(
+        self,
+        vision_backbone: str = "resnet18",
+        pretrained_backbone_weights: str | None = None,
+        input_shape: Tuple[int, int, int] = (3, 84, 84),  # (C, H, W)
+        crop_shape: Optional[Tuple[int, int]] = None,
+        crop_is_random: bool = True,
+        use_group_norm: bool = True,
+        spatial_softmax_num_keypoints: int = 32,
+        use_separate_rgb_encoder_per_camera: bool = True,
+        camera_names: Optional[Tuple[str, ...]] = None,
+    ):
+        """
+        Args:
+            vision_backbone: torchvision model name, e.g. "resnet18".
+            pretrained_backbone_weights: torchvision weights identifier, or None.
+            input_shape: (C, H, W) of the incoming images.
+            crop_shape: (H, W) crop applied before the trunk; None disables cropping.
+            crop_is_random: use a random crop during training (eval always
+                uses a center crop).
+            use_group_norm: replace every BatchNorm2d in the trunk with GroupNorm.
+            spatial_softmax_num_keypoints: K; the per-camera feature dim is 2 * K.
+            use_separate_rgb_encoder_per_camera: one independent encoder per
+                camera instead of a single shared one.
+            camera_names: optional camera names used to build the per-camera
+                encoders eagerly; if None they are built lazily on the first
+                forward() call (NOTE: lazily created parameters are invisible
+                to optimizers constructed before that call).
+        """
+        super().__init__()
+
+        self.vision_backbone = vision_backbone
+        self.pretrained_backbone_weights = pretrained_backbone_weights
+        self.input_shape = input_shape
+        self.crop_shape = crop_shape
+        self.crop_is_random = crop_is_random
+        self.use_group_norm = use_group_norm
+        self.spatial_softmax_num_keypoints = spatial_softmax_num_keypoints
+        self.use_separate_rgb_encoder_per_camera = use_separate_rgb_encoder_per_camera
+
+        # Optional pre-processing: random crop while training, center crop at eval.
+        if crop_shape is not None:
+            self.do_crop = True
+            self.center_crop = torchvision.transforms.CenterCrop(crop_shape)
+            if crop_is_random:
+                self.maybe_random_crop = torchvision.transforms.RandomCrop(crop_shape)
+            else:
+                self.maybe_random_crop = self.center_crop
+        else:
+            self.do_crop = False
+            self.crop_shape = input_shape[1:]
+
+        # Per-camera output dimension: an (x, y) pair per keypoint.
+        self.feature_dim = spatial_softmax_num_keypoints * 2
+
+        if self.use_separate_rgb_encoder_per_camera:
+            # One encoder per camera; built eagerly when camera_names is
+            # given, otherwise lazily on the first forward() call.
+            self.camera_encoders = None
+            if camera_names is not None:
+                self.camera_encoders = nn.ModuleDict(
+                    {name: self._create_single_encoder() for name in sorted(camera_names)}
+                )
+        else:
+            # A single encoder shared by every camera.
+            self.backbone, self.pool, self.out, self.relu = self._create_single_encoder()
+
+    def _create_single_encoder(self) -> nn.ModuleList:
+        """Build one encoder: ResNet trunk + SpatialSoftmax + linear head + ReLU."""
+        backbone_model = getattr(torchvision.models, self.vision_backbone)(
+            weights=self.pretrained_backbone_weights
+        )
+        # Drop the final AvgPool and FC layers, keeping only the conv trunk.
+        backbone = nn.Sequential(*(list(backbone_model.children())[:-2]))
+
+        if self.use_group_norm:
+            backbone = _replace_submodules(
+                root_module=backbone,
+                predicate=lambda x: isinstance(x, nn.BatchNorm2d),
+                func=lambda x: nn.GroupNorm(num_groups=x.num_features // 16, num_channels=x.num_features),
+            )
+
+        # Dry run to discover the trunk's output feature-map shape.
+        dummy_shape = (1, self.input_shape[0], *self.crop_shape)
+        with torch.no_grad():
+            feature_map_shape = backbone(torch.zeros(dummy_shape)).shape[1:]
+
+        pool = SpatialSoftmax(feature_map_shape, num_kp=self.spatial_softmax_num_keypoints)
+        out = nn.Linear(self.spatial_softmax_num_keypoints * 2, self.feature_dim)
+        relu = nn.ReLU()
+        return nn.ModuleList([backbone, pool, out, relu])
+
+    def forward_single_image(self, x: torch.Tensor, encoder: Optional[nn.ModuleList] = None) -> torch.Tensor:
+        """Encode a flat (N, C, H, W) image batch into (N, feature_dim) features.
+
+        Args:
+            x: (N, C, H, W) image batch.
+            encoder: per-camera [backbone, pool, out, relu] modules; required
+                when use_separate_rgb_encoder_per_camera is True.
+        """
+        if self.do_crop:
+            x = self.maybe_random_crop(x) if self.training else self.center_crop(x)
+
+        if self.use_separate_rgb_encoder_per_camera:
+            backbone, pool, out, relu = encoder
+            x = relu(out(torch.flatten(pool(backbone(x)), start_dim=1)))
+        else:
+            x = self.relu(self.out(torch.flatten(self.pool(self.backbone(x)), start_dim=1)))
+        return x
+
+    def forward(self, images):
+        """Encode a dict of camera image sequences.
+
+        Args:
+            images: mapping from camera name to a (B, T, C, H, W) tensor.
+        Returns:
+            (B, T, num_cams * feature_dim) tensor; cameras are concatenated
+            in sorted camera-name order.
+        """
+        any_tensor = next(iter(images.values()))
+        B, T = any_tensor.shape[:2]
+
+        if self.use_separate_rgb_encoder_per_camera and self.camera_encoders is None:
+            # Lazy initialization on the first forward() call. Move the new
+            # modules onto the input device: freshly built modules live on
+            # the CPU and would otherwise crash on GPU inputs.
+            encoders = nn.ModuleDict(
+                {name: self._create_single_encoder() for name in sorted(images.keys())}
+            )
+            self.camera_encoders = encoders.to(any_tensor.device)
+
+        features_all = []
+        for cam_name in sorted(images.keys()):
+            img = images[cam_name]
+            flat = img.reshape(B * T, *img.shape[2:])  # fold time into the batch
+            if self.use_separate_rgb_encoder_per_camera:
+                features = self.forward_single_image(flat, self.camera_encoders[cam_name])
+            else:
+                features = self.forward_single_image(flat)
+            features_all.append(features)
+
+        return torch.cat(features_all, dim=1).view(B, T, -1)
+
+    @property
+    def output_dim(self):
+        """Per-camera feature dimension (2 * spatial_softmax_num_keypoints)."""
+        return self.feature_dim
+
+
+if __name__ == "__main__":
+    print("🚀 Testing ResNetDiffusionBackbone...")
+
+    # Configuration
+    B, T = 2, 5
+    C, H, W = 3, 96, 96
+    crop_h, crop_w = 84, 84
+    num_keypoints = 32
+    feature_dim_per_cam = num_keypoints * 2
+
+    # Instantiate model
+    backbone = ResNetDiffusionBackbone(
+        vision_backbone="resnet18",
+        pretrained_backbone_weights=None,  # Speed up test
+        input_shape=(C, H, W),
+        crop_shape=(crop_h, crop_w),
+        crop_is_random=True,
+        use_group_norm=True,
+        spatial_softmax_num_keypoints=num_keypoints
+    )
+
+    print(f"✅ Model instantiated. Output dim per camera: {backbone.output_dim}")
+
+    # Create dummy input
+    images = {
+        "cam_high": torch.randn(B, T, C, H, W),
+        "cam_wrist": torch.randn(B, T, C, H, W)
+    }
+
+    # Forward pass
+    print("🔄 Running forward pass...")
+    output = backbone(images)
+
+    print(f"Input shapes: {[v.shape for v in images.values()]}")
+    print(f"Output shape: {output.shape}")
+
+    # Verification
+    expected_dim = len(images) * feature_dim_per_cam
+    assert output.shape == (B, T, expected_dim), f"Expected shape {(B, T, expected_dim)}, got {output.shape}"
+
+    print("✨ Test passed!")