refactor: 重构resnet

This commit is contained in:
gouhanke
2026-02-10 15:26:10 +08:00
parent 88b9c10a75
commit 3c27d6d793
3 changed files with 300 additions and 3 deletions

View File

@@ -1,6 +1,7 @@
 # @package agent
 defaults:
-  - /backbone@vision_backbone: resnet
+  - /backbone@vision_backbone: resnet_diffusion
   - /modules@state_encoder: identity_state_encoder
   - /modules@action_encoder: identity_action_encoder
   - /head: conditional_unet1d
@@ -16,8 +17,6 @@ obs_dim: 16
 pred_horizon: 16
 obs_horizon: 2
-# Diffusion Parameters
-# diffusion_steps: 100 (这些参数应该移到 head 配置中,或者通过变量传递)
 # Camera Configuration
 num_cams: ${len:${data.camera_names}} # 自动从 data.camera_names 列表长度获取

View File

@@ -0,0 +1,9 @@
_target_: roboimi.vla.models.backbones.resnet_diffusion.ResNetDiffusionBackbone
vision_backbone: "resnet18"
pretrained_backbone_weights: null
input_shape: [3, 96, 96]
crop_shape: [84, 84]
crop_is_random: true
use_group_norm: true
spatial_softmax_num_keypoints: 32
use_separate_rgb_encoder_per_camera: true

View File

@@ -0,0 +1,289 @@
from roboimi.vla.core.interfaces import VLABackbone
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import numpy as np
from typing import Callable, Optional, Tuple, Union
def _replace_submodules(
root_module: nn.Module, predicate: Callable[[nn.Module], bool], func: Callable[[nn.Module], nn.Module]
) -> nn.Module:
"""
Args:
root_module: 需要替换子模块的根模块
predicate: 接受一个模块作为参数,如果该模块需要被替换则返回 True。
func: 接受一个模块作为参数,并返回一个新的模块来替换它。
Returns:
子模块已被替换的根模块。
"""
if predicate(root_module):
return func(root_module)
replace_list = [k.split(".") for k, m in root_module.named_modules(remove_duplicate=True) if predicate(m)]
for *parents, k in replace_list:
parent_module = root_module
if len(parents) > 0:
parent_module = root_module.get_submodule(".".join(parents))
if isinstance(parent_module, nn.Sequential):
src_module = parent_module[int(k)]
else:
src_module = getattr(parent_module, k)
tgt_module = func(src_module)
if isinstance(parent_module, nn.Sequential):
parent_module[int(k)] = tgt_module
else:
setattr(parent_module, k, tgt_module)
# 验证所有 BN 是否已被替换
assert not any(predicate(m) for _, m in root_module.named_modules(remove_duplicate=True))
return root_module
class SpatialSoftmax(nn.Module):
    """Spatial soft-argmax pooling.

    Described by Finn et al. in "Deep Spatial Autoencoders for Visuomotor
    Learning" (https://huggingface.co/papers/1509.06113). A minimal port of
    the robomimic implementation.
    """

    def __init__(self, input_shape, num_kp=None):
        """
        Args:
            input_shape (list): (C, H, W) shape of the input feature map.
            num_kp (int): Number of keypoints in the output. When None, the
                output keeps the same channel count as the input.
        """
        super().__init__()
        assert len(input_shape) == 3
        self._in_c, self._in_h, self._in_w = input_shape

        if num_kp is None:
            self.nets = None
            self._out_c = self._in_c
        else:
            self.nets = torch.nn.Conv2d(self._in_c, num_kp, kernel_size=1)
            self._out_c = num_kp

        # torch.linspace could be used here, but it seems to behave slightly
        # differently from numpy's and caused a small pc_success drop with
        # pretrained models.
        pos_x, pos_y = np.meshgrid(np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h))
        pos_x = torch.from_numpy(pos_x.reshape(self._in_h * self._in_w, 1)).float()
        pos_y = torch.from_numpy(pos_y.reshape(self._in_h * self._in_w, 1)).float()
        # Registered as a buffer so it follows the module across devices.
        self.register_buffer("pos_grid", torch.cat([pos_x, pos_y], dim=1))

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        """
        Args:
            features: (B, C, H, W) input feature map.
        Returns:
            (B, K, 2) image-space coordinates of the keypoints.
        """
        if self.nets is not None:
            features = self.nets(features)
        # Flatten spatial dims: [B, K, H, W] -> [B * K, H * W], K = keypoints.
        flat = features.reshape(-1, self._in_h * self._in_w)
        # Softmax over spatial positions gives a per-keypoint attention map.
        attention = F.softmax(flat, dim=-1)
        # Expected (x, y) under the attention distribution:
        # [B * K, H * W] @ [H * W, 2] -> [B * K, 2].
        expected_xy = attention @ self.pos_grid
        # Back to [B, K, 2].
        return expected_xy.view(-1, self._out_c, 2)
class ResNetDiffusionBackbone(VLABackbone):
    """ResNet vision backbone for diffusion-policy style agents.

    Optionally crops the input (random crop in training, center crop in
    eval), runs it through a torchvision ResNet trunk (BatchNorm optionally
    replaced by GroupNorm), pools the feature map with SpatialSoftmax, and
    projects to a fixed per-camera feature dimension. Encoders can be shared
    across cameras or created one-per-camera.
    """

    def __init__(
        self,
        vision_backbone: str = "resnet18",
        pretrained_backbone_weights: str | None = None,
        input_shape: Tuple[int, int, int] = (3, 84, 84),  # (C, H, W)
        crop_shape: Optional[Tuple[int, int]] = None,
        crop_is_random: bool = True,
        use_group_norm: bool = True,
        spatial_softmax_num_keypoints: int = 32,
        use_separate_rgb_encoder_per_camera: bool = True,
    ):
        """
        Args:
            vision_backbone: Name of a torchvision.models constructor (e.g. "resnet18").
            pretrained_backbone_weights: Passed as ``weights=`` to the constructor; None for random init.
            input_shape: (C, H, W) of incoming images.
            crop_shape: Optional (H, W) crop applied before the trunk; None disables cropping.
            crop_is_random: Random crop during training (center crop is always used in eval).
            use_group_norm: Replace every BatchNorm2d in the trunk with GroupNorm.
            spatial_softmax_num_keypoints: Keypoint count for SpatialSoftmax pooling.
            use_separate_rgb_encoder_per_camera: One independent encoder per camera vs. one shared encoder.
        """
        super().__init__()
        # Store configuration for lazy encoder construction in forward().
        self.vision_backbone = vision_backbone
        self.pretrained_backbone_weights = pretrained_backbone_weights
        self.input_shape = input_shape
        self.crop_shape = crop_shape
        self.crop_is_random = crop_is_random
        self.use_group_norm = use_group_norm
        self.spatial_softmax_num_keypoints = spatial_softmax_num_keypoints
        self.use_separate_rgb_encoder_per_camera = use_separate_rgb_encoder_per_camera

        # Optional cropping preprocessing.
        if crop_shape is not None:
            self.do_crop = True
            # Evaluation always uses a deterministic center crop.
            self.center_crop = torchvision.transforms.CenterCrop(crop_shape)
            if crop_is_random:
                self.maybe_random_crop = torchvision.transforms.RandomCrop(crop_shape)
            else:
                self.maybe_random_crop = self.center_crop
        else:
            self.do_crop = False
            self.crop_shape = input_shape[1:]

        # Per-camera output feature dimension (SpatialSoftmax yields 2 coords per keypoint).
        self.feature_dim = spatial_softmax_num_keypoints * 2

        if self.use_separate_rgb_encoder_per_camera:
            # Camera names are only known at the first forward() call, so the
            # per-camera encoders are created lazily there.
            # (Previously a throwaway backbone was built and dry-run here only
            # to compute a shape that was never used — removed.)
            self.camera_encoders = None
        else:
            # All cameras share a single encoder.
            self.backbone = self._build_backbone()
            feature_map_shape = self._feature_map_shape(self.backbone)  # (C, H, W)
            self.pool = SpatialSoftmax(feature_map_shape, num_kp=spatial_softmax_num_keypoints)
            self.out = nn.Linear(self.feature_dim, self.feature_dim)
            self.relu = nn.ReLU()

    def _build_backbone(self) -> nn.Sequential:
        """Build the ResNet trunk: drop avgpool/fc, optionally swap BN for GN.

        Shared by the shared-encoder path and ``_create_single_encoder`` so the
        two construction paths cannot drift apart.
        """
        backbone_model = getattr(torchvision.models, self.vision_backbone)(
            weights=self.pretrained_backbone_weights
        )
        # Remove the trailing AvgPool and FC layers, keeping the conv trunk.
        backbone = nn.Sequential(*(list(backbone_model.children())[:-2]))
        if self.use_group_norm:
            backbone = _replace_submodules(
                root_module=backbone,
                predicate=lambda x: isinstance(x, nn.BatchNorm2d),
                func=lambda x: nn.GroupNorm(num_groups=x.num_features // 16, num_channels=x.num_features),
            )
        return backbone

    def _feature_map_shape(self, backbone: nn.Module) -> torch.Size:
        """Dry-run the trunk on a zero image to discover its (C, H, W) output shape."""
        dummy_shape = (1, self.input_shape[0], *self.crop_shape)
        with torch.no_grad():
            return backbone(torch.zeros(dummy_shape)).shape[1:]

    def _create_single_encoder(self) -> nn.ModuleList:
        """Build one full encoder: trunk + SpatialSoftmax pool + linear + ReLU."""
        backbone = self._build_backbone()
        feature_map_shape = self._feature_map_shape(backbone)
        pool = SpatialSoftmax(feature_map_shape, num_kp=self.spatial_softmax_num_keypoints)
        out = nn.Linear(self.spatial_softmax_num_keypoints * 2, self.feature_dim)
        relu = nn.ReLU()
        return nn.ModuleList([backbone, pool, out, relu])

    def forward_single_image(self, x: torch.Tensor, encoder: Optional[nn.ModuleList] = None) -> torch.Tensor:
        """Encode a batch of single images (B*T, C, H, W) -> (B*T, feature_dim).

        Args:
            x: Flattened image batch.
            encoder: Per-camera encoder (required when
                ``use_separate_rgb_encoder_per_camera`` is True; ignored otherwise).
        """
        if self.do_crop:
            # Random crop only while training; deterministic center crop in eval.
            x = self.maybe_random_crop(x) if self.training else self.center_crop(x)
        if self.use_separate_rgb_encoder_per_camera:
            backbone, pool, out, relu = encoder
            x = relu(out(torch.flatten(pool(backbone(x)), start_dim=1)))
        else:
            x = self.relu(self.out(torch.flatten(self.pool(self.backbone(x)), start_dim=1)))
        return x

    def forward(self, images):
        """Encode a dict of per-camera image sequences.

        Args:
            images: Mapping camera name -> tensor of shape (B, T, C, H, W).
        Returns:
            (B, T, num_cams * feature_dim) features; cameras are processed in
            sorted-name order so the layout is deterministic.
        """
        any_tensor = next(iter(images.values()))
        B, T = any_tensor.shape[:2]

        # Lazily build one encoder per camera on the first call.
        if self.use_separate_rgb_encoder_per_camera and self.camera_encoders is None:
            self.camera_encoders = nn.ModuleDict(
                {cam_name: self._create_single_encoder() for cam_name in sorted(images.keys())}
            )
            # Bug fix: modules created inside forward() are born on CPU; move
            # them to the input's device or the first CUDA step would crash.
            self.camera_encoders.to(any_tensor.device)
            # NOTE(review): encoders created after the optimizer is built are
            # not registered with it — consider constructing them eagerly by
            # passing camera names to __init__.

        features_all = []
        for cam_name in sorted(images.keys()):
            img = images[cam_name]
            flat = img.view(B * T, *img.shape[2:])
            if self.use_separate_rgb_encoder_per_camera:
                # Use this camera's dedicated encoder.
                features = self.forward_single_image(flat, self.camera_encoders[cam_name])
            else:
                # Use the shared encoder.
                features = self.forward_single_image(flat)
            features_all.append(features)
        # Concatenate per-camera features, then restore the (B, T, ...) layout.
        return torch.cat(features_all, dim=1).view(B, T, -1)

    @property
    def output_dim(self):
        """Per-camera output feature dimension (spatial_softmax_num_keypoints * 2)."""
        return self.feature_dim
if __name__ == "__main__":
    print("🚀 Testing ResNetDiffusionBackbone...")

    # Test configuration.
    batch, horizon = 2, 5
    channels, height, width = 3, 96, 96
    crop_height, crop_width = 84, 84
    n_keypoints = 32
    per_cam_dim = n_keypoints * 2

    # Build the model under test.
    model = ResNetDiffusionBackbone(
        vision_backbone="resnet18",
        pretrained_backbone_weights=None,  # Speed up test
        input_shape=(channels, height, width),
        crop_shape=(crop_height, crop_width),
        crop_is_random=True,
        use_group_norm=True,
        spatial_softmax_num_keypoints=n_keypoints,
    )
    print(f"✅ Model instantiated. Output dim per camera: {model.output_dim}")

    # Dummy two-camera input batch.
    images = {
        "cam_high": torch.randn(batch, horizon, channels, height, width),
        "cam_wrist": torch.randn(batch, horizon, channels, height, width),
    }

    print("🔄 Running forward pass...")
    output = model(images)
    print(f"Input shapes: {[v.shape for v in images.values()]}")
    print(f"Output shape: {output.shape}")

    # Verify the concatenated output shape: one feature block per camera.
    expected_dim = len(images) * per_cam_dim
    assert output.shape == (batch, horizon, expected_dim), f"Expected shape {(batch, horizon, expected_dim)}, got {output.shape}"
    print("✨ Test passed!")