refactor: 重构resnet
This commit is contained in:
@@ -1,6 +1,7 @@
|
|||||||
# @package agent
|
# @package agent
|
||||||
defaults:
|
defaults:
|
||||||
- /backbone@vision_backbone: resnet
|
# - /backbone@vision_backbone: resnet
|
||||||
|
- /backbone@vision_backbone: resnet_diffusion
|
||||||
- /modules@state_encoder: identity_state_encoder
|
- /modules@state_encoder: identity_state_encoder
|
||||||
- /modules@action_encoder: identity_action_encoder
|
- /modules@action_encoder: identity_action_encoder
|
||||||
- /head: conditional_unet1d
|
- /head: conditional_unet1d
|
||||||
@@ -16,8 +17,6 @@ obs_dim: 16
|
|||||||
pred_horizon: 16
|
pred_horizon: 16
|
||||||
obs_horizon: 2
|
obs_horizon: 2
|
||||||
|
|
||||||
# Diffusion Parameters
|
|
||||||
# diffusion_steps: 100 (这些参数应该移到 head 配置中,或者通过变量传递)
|
|
||||||
|
|
||||||
# Camera Configuration
|
# Camera Configuration
|
||||||
num_cams: ${len:${data.camera_names}} # 自动从 data.camera_names 列表长度获取
|
num_cams: ${len:${data.camera_names}} # 自动从 data.camera_names 列表长度获取
|
||||||
9
roboimi/vla/conf/backbone/resnet_diffusion.yaml
Normal file
9
roboimi/vla/conf/backbone/resnet_diffusion.yaml
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
# ResNet + SpatialSoftmax vision backbone for diffusion-style policies.
_target_: roboimi.vla.models.backbones.resnet_diffusion.ResNetDiffusionBackbone

# torchvision model name used as the visual trunk
vision_backbone: "resnet18"

# torchvision weights identifier; null means train from scratch
pretrained_backbone_weights: null

# expected image shape (C, H, W) before cropping
input_shape: [3, 96, 96]

# (H, W) crop applied before encoding: random while training, center at eval
crop_shape: [84, 84]
crop_is_random: true

# replace every BatchNorm2d in the trunk with GroupNorm
use_group_norm: true

# SpatialSoftmax keypoint count K; per-camera feature dim is 2 * K
spatial_softmax_num_keypoints: 32

# if true, each camera stream gets its own encoder (created lazily on first forward)
use_separate_rgb_encoder_per_camera: true
|
||||||
289
roboimi/vla/models/backbones/resnet_diffusion.py
Normal file
289
roboimi/vla/models/backbones/resnet_diffusion.py
Normal file
@@ -0,0 +1,289 @@
|
|||||||
|
from roboimi.vla.core.interfaces import VLABackbone
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import torchvision
|
||||||
|
import numpy as np
|
||||||
|
from typing import Callable, Optional, Tuple, Union
|
||||||
|
|
||||||
|
def _replace_submodules(
|
||||||
|
root_module: nn.Module, predicate: Callable[[nn.Module], bool], func: Callable[[nn.Module], nn.Module]
|
||||||
|
) -> nn.Module:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
root_module: 需要替换子模块的根模块
|
||||||
|
predicate: 接受一个模块作为参数,如果该模块需要被替换则返回 True。
|
||||||
|
func: 接受一个模块作为参数,并返回一个新的模块来替换它。
|
||||||
|
Returns:
|
||||||
|
子模块已被替换的根模块。
|
||||||
|
"""
|
||||||
|
if predicate(root_module):
|
||||||
|
return func(root_module)
|
||||||
|
|
||||||
|
replace_list = [k.split(".") for k, m in root_module.named_modules(remove_duplicate=True) if predicate(m)]
|
||||||
|
for *parents, k in replace_list:
|
||||||
|
parent_module = root_module
|
||||||
|
if len(parents) > 0:
|
||||||
|
parent_module = root_module.get_submodule(".".join(parents))
|
||||||
|
if isinstance(parent_module, nn.Sequential):
|
||||||
|
src_module = parent_module[int(k)]
|
||||||
|
else:
|
||||||
|
src_module = getattr(parent_module, k)
|
||||||
|
tgt_module = func(src_module)
|
||||||
|
if isinstance(parent_module, nn.Sequential):
|
||||||
|
parent_module[int(k)] = tgt_module
|
||||||
|
else:
|
||||||
|
setattr(parent_module, k, tgt_module)
|
||||||
|
# 验证所有 BN 是否已被替换
|
||||||
|
assert not any(predicate(m) for _, m in root_module.named_modules(remove_duplicate=True))
|
||||||
|
return root_module
|
||||||
|
|
||||||
|
class SpatialSoftmax(nn.Module):
    """Spatial soft-argmax from "Deep Spatial Autoencoders for Visuomotor
    Learning" by Finn et al. (https://huggingface.co/papers/1509.06113).

    Minimal port of the robomimic implementation.
    """

    def __init__(self, input_shape, num_kp=None):
        """
        Args:
            input_shape (list): (C, H, W) shape of the incoming feature map.
            num_kp (int): Number of keypoints to emit. When None, the output
                keeps the input channel count.
        """
        super().__init__()

        assert len(input_shape) == 3
        self._in_c, self._in_h, self._in_w = input_shape

        # Optional 1x1 conv projecting the channel count down to num_kp.
        if num_kp is None:
            self.nets = None
            self._out_c = self._in_c
        else:
            self.nets = torch.nn.Conv2d(self._in_c, num_kp, kernel_size=1)
            self._out_c = num_kp

        # np.linspace is used instead of torch.linspace on purpose: the two
        # differ very slightly and the numpy variant matches pretrained models
        # (switching causes a small pc_success drop).
        pos_x, pos_y = np.meshgrid(np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h))
        n_pix = self._in_h * self._in_w
        grid = torch.cat(
            [
                torch.from_numpy(pos_x.reshape(n_pix, 1)).float(),
                torch.from_numpy(pos_y.reshape(n_pix, 1)).float(),
            ],
            dim=1,
        )
        # Registered as a buffer so it follows the module across devices.
        self.register_buffer("pos_grid", grid)

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        """
        Args:
            features: (B, C, H, W) input feature map.

        Returns:
            (B, K, 2) expected image-space coordinates of the keypoints.
        """
        if self.nets is not None:
            features = self.nets(features)

        # Flatten spatial dims: [B, K, H, W] -> [B * K, H * W].
        flat = features.reshape(-1, self._in_h * self._in_w)
        # Softmax over pixels gives one attention map per keypoint.
        attention = flat.softmax(dim=-1)
        # Expected (x, y): [B * K, H * W] @ [H * W, 2] -> [B * K, 2].
        expected_xy = attention @ self.pos_grid
        # Reshape back to [B, K, 2].
        return expected_xy.view(-1, self._out_c, 2)
|
||||||
|
|
||||||
|
class ResNetDiffusionBackbone(VLABackbone):
    """ResNet + SpatialSoftmax image encoder for diffusion-style policies.

    Each camera image is (optionally) cropped, passed through a torchvision
    ResNet trunk (with BatchNorm2d optionally converted to GroupNorm), pooled
    with SpatialSoftmax, and projected to a fixed-size feature vector.
    Cameras either share one encoder or each get their own.
    """

    def __init__(
        self,
        vision_backbone: str = "resnet18",
        pretrained_backbone_weights: str | None = None,
        input_shape: Tuple[int, int, int] = (3, 84, 84),  # (C, H, W)
        crop_shape: Optional[Tuple[int, int]] = None,
        crop_is_random: bool = True,
        use_group_norm: bool = True,
        spatial_softmax_num_keypoints: int = 32,
        use_separate_rgb_encoder_per_camera: bool = True,
    ):
        """
        Args:
            vision_backbone: torchvision model name (e.g. "resnet18").
            pretrained_backbone_weights: torchvision weights id; None trains
                from scratch.
            input_shape: (C, H, W) of incoming images.
            crop_shape: (H, W) crop applied before encoding; None disables it.
            crop_is_random: use RandomCrop while training (a CenterCrop is
                always used in eval mode).
            use_group_norm: replace every BatchNorm2d in the trunk with
                GroupNorm.
            spatial_softmax_num_keypoints: keypoint count K; per-camera feature
                dimension is 2 * K.
            use_separate_rgb_encoder_per_camera: give each camera its own
                encoder, created lazily on the first forward pass.
        """
        super().__init__()

        self.vision_backbone = vision_backbone
        self.pretrained_backbone_weights = pretrained_backbone_weights
        self.input_shape = input_shape
        self.crop_shape = crop_shape
        self.crop_is_random = crop_is_random
        self.use_group_norm = use_group_norm
        self.spatial_softmax_num_keypoints = spatial_softmax_num_keypoints
        self.use_separate_rgb_encoder_per_camera = use_separate_rgb_encoder_per_camera

        # Optional pre-processing: random crop while training, center crop at eval.
        if crop_shape is not None:
            self.do_crop = True
            self.center_crop = torchvision.transforms.CenterCrop(crop_shape)
            self.maybe_random_crop = (
                torchvision.transforms.RandomCrop(crop_shape) if crop_is_random else self.center_crop
            )
        else:
            self.do_crop = False
            self.crop_shape = input_shape[1:]

        # Per-camera feature dimension (x and y per keypoint).
        self.feature_dim = spatial_softmax_num_keypoints * 2

        if self.use_separate_rgb_encoder_per_camera:
            # Encoders are created lazily in forward(), once the camera names
            # are known.
            # NOTE(review): parameters created after the optimizer is built
            # will not be trained — when using this mode, build the optimizer
            # after the first forward pass (or pre-build the encoders).
            self.camera_encoders = None
        else:
            # All cameras share one trunk + pooling + projection.
            self.backbone = self._build_backbone()
            self.pool = SpatialSoftmax(
                self._feature_map_shape(self.backbone), num_kp=spatial_softmax_num_keypoints
            )
            self.out = nn.Linear(spatial_softmax_num_keypoints * 2, self.feature_dim)
            self.relu = nn.ReLU()

    def _build_backbone(self) -> nn.Sequential:
        """Build the ResNet trunk: avgpool/fc removed, optional GroupNorm swap."""
        model = getattr(torchvision.models, self.vision_backbone)(
            weights=self.pretrained_backbone_weights
        )
        # Drop the final AvgPool and FC; keep the conv trunk through layer4.
        backbone = nn.Sequential(*list(model.children())[:-2])
        if self.use_group_norm:
            backbone = _replace_submodules(
                root_module=backbone,
                predicate=lambda x: isinstance(x, nn.BatchNorm2d),
                func=lambda x: nn.GroupNorm(num_groups=x.num_features // 16, num_channels=x.num_features),
            )
        return backbone

    def _feature_map_shape(self, backbone: nn.Module) -> torch.Size:
        """Dry-run the trunk on a zero image to discover its output (C, H, W)."""
        with torch.no_grad():
            dummy = torch.zeros(1, self.input_shape[0], *self.crop_shape)
            return backbone(dummy).shape[1:]

    def _create_single_encoder(self) -> nn.ModuleList:
        """Create one full encoder: trunk + SpatialSoftmax + projection + ReLU."""
        backbone = self._build_backbone()
        pool = SpatialSoftmax(
            self._feature_map_shape(backbone), num_kp=self.spatial_softmax_num_keypoints
        )
        out = nn.Linear(self.spatial_softmax_num_keypoints * 2, self.feature_dim)
        return nn.ModuleList([backbone, pool, out, nn.ReLU()])

    def forward_single_image(self, x: torch.Tensor, encoder: Optional[nn.ModuleList] = None) -> torch.Tensor:
        """Encode a (N, C, H, W) image batch into (N, feature_dim) features.

        Args:
            x: image batch, shape (N, C, H, W).
            encoder: per-camera encoder; required when
                ``use_separate_rgb_encoder_per_camera`` is True, ignored
                otherwise.
        """
        if self.do_crop:
            # Random crop only in training mode; deterministic crop at eval.
            x = self.maybe_random_crop(x) if self.training else self.center_crop(x)

        if self.use_separate_rgb_encoder_per_camera:
            backbone, pool, out, relu = encoder
            x = relu(out(torch.flatten(pool(backbone(x)), start_dim=1)))
        else:
            x = self.relu(self.out(torch.flatten(self.pool(self.backbone(x)), start_dim=1)))
        return x

    def forward(self, images):
        """Encode a dict of camera streams.

        Args:
            images: mapping of camera name -> tensor of shape (B, T, C, H, W).

        Returns:
            (B, T, num_cameras * feature_dim) tensor; camera features are
            concatenated in sorted camera-name order.
        """
        any_tensor = next(iter(images.values()))
        B, T = any_tensor.shape[:2]

        # Lazily build per-camera encoders on first use, on the input's device
        # (creating them on CPU would crash when the inputs live on GPU).
        if self.use_separate_rgb_encoder_per_camera and self.camera_encoders is None:
            self.camera_encoders = nn.ModuleDict()
            for cam_name in sorted(images.keys()):
                self.camera_encoders[cam_name] = self._create_single_encoder().to(any_tensor.device)

        features_all = []
        for cam_name in sorted(images.keys()):
            img = images[cam_name]
            flat = img.view(B * T, *img.shape[2:])  # fold time into the batch dim
            encoder = (
                self.camera_encoders[cam_name] if self.use_separate_rgb_encoder_per_camera else None
            )
            features_all.append(self.forward_single_image(flat, encoder))

        return torch.cat(features_all, dim=1).view(B, T, -1)

    @property
    def output_dim(self):
        """Per-camera feature dimension (2 * spatial_softmax_num_keypoints)."""
        return self.feature_dim
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Smoke test: build the backbone and push a dummy two-camera batch through it.
    print("🚀 Testing ResNetDiffusionBackbone...")

    # Configuration
    batch, horizon = 2, 5
    channels, height, width = 3, 96, 96
    crop = (84, 84)
    n_keypoints = 32
    per_cam_dim = n_keypoints * 2

    # Instantiate model (no pretrained weights, to keep the test fast)
    model = ResNetDiffusionBackbone(
        vision_backbone="resnet18",
        pretrained_backbone_weights=None,
        input_shape=(channels, height, width),
        crop_shape=crop,
        crop_is_random=True,
        use_group_norm=True,
        spatial_softmax_num_keypoints=n_keypoints,
    )

    print(f"✅ Model instantiated. Output dim per camera: {model.output_dim}")

    # Dummy input: two camera streams of shape (B, T, C, H, W)
    dummy_images = {
        "cam_high": torch.randn(batch, horizon, channels, height, width),
        "cam_wrist": torch.randn(batch, horizon, channels, height, width),
    }

    print("🔄 Running forward pass...")
    result = model(dummy_images)

    print(f"Input shapes: {[v.shape for v in dummy_images.values()]}")
    print(f"Output shape: {result.shape}")

    # Verify the concatenated per-camera features have the expected width.
    want = len(dummy_images) * per_cam_dim
    assert result.shape == (batch, horizon, want), f"Expected shape {(batch, horizon, want)}, got {result.shape}"

    print("✨ Test passed!")
|
||||||
Reference in New Issue
Block a user