refactor: 重构resnet
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
# @package agent
|
||||
defaults:
|
||||
- /backbone@vision_backbone: resnet
|
||||
# - /backbone@vision_backbone: resnet
|
||||
- /backbone@vision_backbone: resnet_diffusion
|
||||
- /modules@state_encoder: identity_state_encoder
|
||||
- /modules@action_encoder: identity_action_encoder
|
||||
- /head: conditional_unet1d
|
||||
@@ -16,8 +17,6 @@ obs_dim: 16
|
||||
pred_horizon: 16
|
||||
obs_horizon: 2
|
||||
|
||||
# Diffusion Parameters
|
||||
# diffusion_steps: 100 (these parameters should move into the head config, or be passed through variables)
|
||||
|
||||
# Camera Configuration
|
||||
num_cams: ${len:${data.camera_names}} # automatically derived from the length of data.camera_names
|
||||
9
roboimi/vla/conf/backbone/resnet_diffusion.yaml
Normal file
9
roboimi/vla/conf/backbone/resnet_diffusion.yaml
Normal file
@@ -0,0 +1,9 @@
|
||||
_target_: roboimi.vla.models.backbones.resnet_diffusion.ResNetDiffusionBackbone
|
||||
vision_backbone: "resnet18"
|
||||
pretrained_backbone_weights: null
|
||||
input_shape: [3, 96, 96]
|
||||
crop_shape: [84, 84]
|
||||
crop_is_random: true
|
||||
use_group_norm: true
|
||||
spatial_softmax_num_keypoints: 32
|
||||
use_separate_rgb_encoder_per_camera: true
|
||||
289
roboimi/vla/models/backbones/resnet_diffusion.py
Normal file
289
roboimi/vla/models/backbones/resnet_diffusion.py
Normal file
@@ -0,0 +1,289 @@
|
||||
from roboimi.vla.core.interfaces import VLABackbone
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torchvision
|
||||
import numpy as np
|
||||
from typing import Callable, Optional, Tuple, Union
|
||||
|
||||
def _replace_submodules(
|
||||
root_module: nn.Module, predicate: Callable[[nn.Module], bool], func: Callable[[nn.Module], nn.Module]
|
||||
) -> nn.Module:
|
||||
"""
|
||||
Args:
|
||||
root_module: 需要替换子模块的根模块
|
||||
predicate: 接受一个模块作为参数,如果该模块需要被替换则返回 True。
|
||||
func: 接受一个模块作为参数,并返回一个新的模块来替换它。
|
||||
Returns:
|
||||
子模块已被替换的根模块。
|
||||
"""
|
||||
if predicate(root_module):
|
||||
return func(root_module)
|
||||
|
||||
replace_list = [k.split(".") for k, m in root_module.named_modules(remove_duplicate=True) if predicate(m)]
|
||||
for *parents, k in replace_list:
|
||||
parent_module = root_module
|
||||
if len(parents) > 0:
|
||||
parent_module = root_module.get_submodule(".".join(parents))
|
||||
if isinstance(parent_module, nn.Sequential):
|
||||
src_module = parent_module[int(k)]
|
||||
else:
|
||||
src_module = getattr(parent_module, k)
|
||||
tgt_module = func(src_module)
|
||||
if isinstance(parent_module, nn.Sequential):
|
||||
parent_module[int(k)] = tgt_module
|
||||
else:
|
||||
setattr(parent_module, k, tgt_module)
|
||||
# 验证所有 BN 是否已被替换
|
||||
assert not any(predicate(m) for _, m in root_module.named_modules(remove_duplicate=True))
|
||||
return root_module
|
||||
|
||||
class SpatialSoftmax(nn.Module):
    """Spatial soft-argmax pooling, as described by Finn et al. in
    "Deep Spatial Autoencoders for Visuomotor Learning"
    (https://huggingface.co/papers/1509.06113). Minimal port of the
    robomimic implementation.
    """

    def __init__(self, input_shape, num_kp=None):
        """
        Args:
            input_shape (list): (C, H, W) shape of the input feature map.
            num_kp (int): Number of keypoints in the output. If None, the
                output keeps the same number of channels as the input.
        """
        super().__init__()

        assert len(input_shape) == 3
        self._in_c, self._in_h, self._in_w = input_shape

        if num_kp is None:
            self.nets = None
            self._out_c = self._in_c
        else:
            # 1x1 conv projects the C input channels down to num_kp keypoint maps.
            self.nets = torch.nn.Conv2d(self._in_c, num_kp, kernel_size=1)
            self._out_c = num_kp

        # torch.linspace would also work, but its values differ slightly from
        # numpy's and that caused a small pc_success drop for pretrained models.
        pos_x, pos_y = np.meshgrid(np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h))
        pos_x = torch.from_numpy(pos_x.reshape(self._in_h * self._in_w, 1)).float()
        pos_y = torch.from_numpy(pos_y.reshape(self._in_h * self._in_w, 1)).float()
        # Registered as a buffer so the grid follows the module across devices.
        self.register_buffer("pos_grid", torch.cat([pos_x, pos_y], dim=1))

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        """
        Args:
            features: (B, C, H, W) input feature maps.
        Returns:
            (B, K, 2) image-space coordinates of the keypoints.
        """
        if self.nets is not None:
            features = self.nets(features)

        # [B, K, H, W] -> [B * K, H * W], where K is the number of keypoints.
        flat = features.reshape(-1, self._in_h * self._in_w)
        # Softmax over the spatial dimension turns each map into a distribution.
        weights = F.softmax(flat, dim=-1)
        # [B * K, H * W] x [H * W, 2] -> [B * K, 2]: expected (x, y) per map.
        expected_xy = weights @ self.pos_grid
        # Reshape back to [B, K, 2].
        return expected_xy.view(-1, self._out_c, 2)
|
||||
|
||||
class ResNetDiffusionBackbone(VLABackbone):
    """ResNet vision backbone with SpatialSoftmax pooling for diffusion policies.

    Each camera image is optionally cropped, run through a torchvision ResNet
    trunk (avgpool/fc removed, BatchNorm optionally swapped for GroupNorm),
    pooled with SpatialSoftmax and linearly projected to a per-camera feature
    vector. Features of all cameras are concatenated along the last dimension.
    """

    def __init__(
        self,
        vision_backbone: str = "resnet18",
        pretrained_backbone_weights: str | None = None,
        input_shape: Tuple[int, int, int] = (3, 84, 84),  # (C, H, W)
        crop_shape: Optional[Tuple[int, int]] = None,
        crop_is_random: bool = True,
        use_group_norm: bool = True,
        spatial_softmax_num_keypoints: int = 32,
        use_separate_rgb_encoder_per_camera: bool = True,
    ):
        """
        Args:
            vision_backbone: Name of a ``torchvision.models`` constructor.
            pretrained_backbone_weights: ``weights`` argument forwarded to
                torchvision; None means random initialization.
            input_shape: (C, H, W) of the incoming images.
            crop_shape: Optional (H, W) crop applied before encoding; None
                disables cropping.
            crop_is_random: Use a random crop during training; evaluation
                always uses a center crop.
            use_group_norm: Replace every BatchNorm2d in the trunk with a
                GroupNorm of ``num_features // 16`` groups.
            spatial_softmax_num_keypoints: Number of SpatialSoftmax keypoints;
                the per-camera feature dim is twice this value.
            use_separate_rgb_encoder_per_camera: Give every camera its own
                encoder. Encoders are created lazily on the first forward
                call, when the camera names are first known.
        """
        super().__init__()

        # Persist the configuration; the lazy per-camera path needs it later.
        self.vision_backbone = vision_backbone
        self.pretrained_backbone_weights = pretrained_backbone_weights
        self.input_shape = input_shape
        self.crop_shape = crop_shape
        self.crop_is_random = crop_is_random
        self.use_group_norm = use_group_norm
        self.spatial_softmax_num_keypoints = spatial_softmax_num_keypoints
        self.use_separate_rgb_encoder_per_camera = use_separate_rgb_encoder_per_camera

        # Optional cropping pre-processing.
        if crop_shape is not None:
            self.do_crop = True
            # Evaluation always uses a deterministic center crop.
            self.center_crop = torchvision.transforms.CenterCrop(crop_shape)
            if crop_is_random:
                self.maybe_random_crop = torchvision.transforms.RandomCrop(crop_shape)
            else:
                self.maybe_random_crop = self.center_crop
        else:
            self.do_crop = False
            # Without a crop the trunk sees the full (H, W) input.
            self.crop_shape = input_shape[1:]

        if self.use_separate_rgb_encoder_per_camera:
            # Per-camera encoders are created lazily in forward() because the
            # camera names are unknown here; only the output dim is fixed now.
            self.camera_encoders = None
            self.feature_dim = spatial_softmax_num_keypoints * 2
        else:
            # All cameras share a single encoder.
            self.backbone = self._build_trunk()
            feature_map_shape = self._probe_feature_map_shape(self.backbone)  # (C, H, W)
            self.pool = SpatialSoftmax(feature_map_shape, num_kp=spatial_softmax_num_keypoints)
            self.feature_dim = spatial_softmax_num_keypoints * 2
            self.out = nn.Linear(spatial_softmax_num_keypoints * 2, self.feature_dim)
            self.relu = nn.ReLU()

    def _build_trunk(self) -> nn.Sequential:
        """Build the ResNet trunk: drop avgpool/fc, optionally swap BN for GN."""
        backbone_model = getattr(torchvision.models, self.vision_backbone)(
            weights=self.pretrained_backbone_weights
        )
        # Remove AvgPool and FC (assumes layer4 is children()[-3]).
        trunk = nn.Sequential(*(list(backbone_model.children())[:-2]))
        if self.use_group_norm:
            trunk = _replace_submodules(
                root_module=trunk,
                predicate=lambda x: isinstance(x, nn.BatchNorm2d),
                func=lambda x: nn.GroupNorm(num_groups=x.num_features // 16, num_channels=x.num_features),
            )
        return trunk

    def _probe_feature_map_shape(self, trunk: nn.Module) -> torch.Size:
        """Run a zero image through ``trunk`` to get its (C, H, W) output shape."""
        dummy_shape = (1, self.input_shape[0], *self.crop_shape)
        with torch.no_grad():
            return trunk(torch.zeros(dummy_shape)).shape[1:]

    def _create_single_encoder(self) -> nn.ModuleList:
        """Create one camera encoder: trunk + SpatialSoftmax pool + linear head."""
        trunk = self._build_trunk()
        feature_map_shape = self._probe_feature_map_shape(trunk)
        pool = SpatialSoftmax(feature_map_shape, num_kp=self.spatial_softmax_num_keypoints)
        out = nn.Linear(self.spatial_softmax_num_keypoints * 2, self.feature_dim)
        relu = nn.ReLU()
        return nn.ModuleList([trunk, pool, out, relu])

    def forward_single_image(self, x: torch.Tensor, encoder: Optional[nn.ModuleList] = None) -> torch.Tensor:
        """Encode a flattened (B*T, C, H, W) image batch to (B*T, feature_dim).

        Args:
            x: Flattened image batch.
            encoder: The camera's [trunk, pool, out, relu] modules; required
                when ``use_separate_rgb_encoder_per_camera`` is True, ignored
                otherwise.
        """
        if self.do_crop:
            # Random crop only while training; deterministic crop at eval.
            x = self.maybe_random_crop(x) if self.training else self.center_crop(x)

        if self.use_separate_rgb_encoder_per_camera:
            # Use the per-camera encoder passed in by forward().
            backbone, pool, out, relu = encoder
            x = relu(out(torch.flatten(pool(backbone(x)), start_dim=1)))
        else:
            # Use the shared encoder.
            x = self.relu(self.out(torch.flatten(self.pool(self.backbone(x)), start_dim=1)))
        return x

    def forward(self, images):
        """Encode a dict of camera batches.

        Args:
            images: Mapping camera name -> (B, T, C, H, W) tensor.
        Returns:
            (B, T, num_cams * feature_dim) tensor; cameras are concatenated
            in sorted-name order.
        """
        any_tensor = next(iter(images.values()))
        B, T = any_tensor.shape[:2]
        features_all = []

        # Lazily build the per-camera encoders on the first call, once the
        # camera names are known.
        # NOTE(review): parameters created here are invisible to optimizers
        # built before the first forward pass — run one dummy forward before
        # constructing the optimizer.
        if self.use_separate_rgb_encoder_per_camera and self.camera_encoders is None:
            self.camera_encoders = nn.ModuleDict()
            for cam_name in sorted(images.keys()):
                self.camera_encoders[cam_name] = self._create_single_encoder()
            # Freshly created modules live on CPU; move them to the input's
            # device so the first forward on GPU does not crash.
            self.camera_encoders.to(any_tensor.device)

        for cam_name in sorted(images.keys()):
            img = images[cam_name]
            if self.use_separate_rgb_encoder_per_camera:
                # Use this camera's dedicated encoder.
                features = self.forward_single_image(
                    img.view(B * T, *img.shape[2:]),
                    self.camera_encoders[cam_name]
                )
            else:
                # Use the shared encoder.
                features = self.forward_single_image(img.view(B * T, *img.shape[2:]))
            features_all.append(features)

        return torch.cat(features_all, dim=1).view(B, T, -1)

    @property
    def output_dim(self):
        # Per-camera feature dimensionality (2 * spatial_softmax_num_keypoints).
        return self.feature_dim
|
||||
|
||||
if __name__ == "__main__":
    # Smoke test: build the backbone and push a two-camera batch through it.
    print("🚀 Testing ResNetDiffusionBackbone...")

    # Configuration
    batch, horizon = 2, 5
    channels, height, width = 3, 96, 96
    crop_hw = (84, 84)
    num_keypoints = 32
    per_cam_dim = num_keypoints * 2

    # Instantiate model
    backbone = ResNetDiffusionBackbone(
        vision_backbone="resnet18",
        pretrained_backbone_weights=None,  # Speed up test
        input_shape=(channels, height, width),
        crop_shape=crop_hw,
        crop_is_random=True,
        use_group_norm=True,
        spatial_softmax_num_keypoints=num_keypoints,
    )

    print(f"✅ Model instantiated. Output dim per camera: {backbone.output_dim}")

    # Dummy input: two cameras, (B, T, C, H, W) each.
    images = {
        "cam_high": torch.randn(batch, horizon, channels, height, width),
        "cam_wrist": torch.randn(batch, horizon, channels, height, width),
    }

    print("🔄 Running forward pass...")
    output = backbone(images)

    print(f"Input shapes: {[v.shape for v in images.values()]}")
    print(f"Output shape: {output.shape}")

    # Verification: features of both cameras concatenated per timestep.
    expected_dim = len(images) * per_cam_dim
    assert output.shape == (batch, horizon, expected_dim), f"Expected shape {(batch, horizon, expected_dim)}, got {output.shape}"

    print("✨ Test passed!")
|
||||
Reference in New Issue
Block a user