83 lines
3.0 KiB
Python
83 lines
3.0 KiB
Python
from roboimi.vla.core.interfaces import VLABackbone
|
|
from transformers import ResNetModel
|
|
from torchvision import transforms
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
class ResNetBackbone(VLABackbone):
    """ResNet image encoder that turns multi-camera image sequences into
    spatial-softmax keypoint features.

    Each camera stream of shape (B, T, C, H, W) is resized/normalized,
    passed through a (optionally frozen) HuggingFace ResNet, reduced to
    per-channel expected 2-D coordinates via SpatialSoftmax, and the
    per-camera features are concatenated into (B, T, num_cams * D * 2).
    """

    def __init__(
        self,
        model_name: str = "microsoft/resnet-18",
        freeze: bool = True,
    ):
        """
        Args:
            model_name: HuggingFace model id to load pretrained weights from.
            freeze: if True, disable gradients for the backbone and keep it
                permanently in eval mode (see ``train`` override below).
        """
        super().__init__()
        self.model = ResNetModel.from_pretrained(model_name)
        # Channel count of the last ResNet stage (e.g. 512 for resnet-18).
        self.out_channels = self.model.config.hidden_sizes[-1]
        # 384x384 input -> 12x12 final feature map for a stride-32 ResNet,
        # matching the SpatialSoftmax grid below.
        self.transform = transforms.Compose([
            transforms.Resize((384, 384)),
            # ImageNet statistics; assumes inputs are float tensors already
            # scaled to [0, 1] — TODO confirm against the data pipeline.
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        self.spatial_softmax = SpatialSoftmax(num_rows=12, num_cols=12)
        # Remember the freeze request so train() can keep honoring it.
        self._frozen = freeze
        if freeze:
            self._freeze_parameters()

    def _freeze_parameters(self):
        """Disable gradients for all backbone parameters and set eval mode."""
        print("❄️ Freezing ResNet Backbone parameters")
        for param in self.model.parameters():
            param.requires_grad = False
        self.model.eval()

    def train(self, mode: bool = True):
        """Set training mode, but keep a frozen backbone in eval mode.

        Fix: without this override, calling ``.train()`` on a parent module
        recursively put the frozen backbone back into training mode, so its
        BatchNorm layers kept updating running statistics even though all
        parameters had ``requires_grad=False``.
        """
        super().train(mode)
        # getattr guard: train() can be called before __init__ finishes.
        if getattr(self, "_frozen", False):
            self.model.eval()
        return self

    def forward_single_image(self, image):
        """Encode one camera stream.

        Args:
            image: tensor of shape (B, T, C, H, W).

        Returns:
            Keypoint features of shape (B*T, D*2) where D = out_channels.
        """
        B, T, C, H, W = image.shape
        # Fold time into the batch dimension so the CNN sees 4-D input.
        image = image.view(B * T, C, H, W)
        image = self.transform(image)
        feature_map = self.model(image).last_hidden_state  # (B*T, D, H', W')
        features = self.spatial_softmax(feature_map)  # (B*T, D*2)
        return features

    def forward(self, images):
        """Encode a dict of camera streams into a single feature tensor.

        Args:
            images: mapping from camera name to a (B, T, C, H, W) tensor.
                All cameras must share the same B and T.

        Returns:
            Tensor of shape (B, T, num_cams * D * 2); cameras are
            concatenated in sorted-name order for a deterministic layout.
        """
        any_tensor = next(iter(images.values()))
        B, T = any_tensor.shape[:2]
        features_all = []
        # Sort camera names so feature ordering is stable across calls.
        sorted_cam_names = sorted(images.keys())
        for cam_name in sorted_cam_names:
            img = images[cam_name]
            features = self.forward_single_image(img)  # (B*T, D*2)
            features_all.append(features)
        combined_features = torch.cat(features_all, dim=1)  # (B*T, Num_Cams*D*2)
        return combined_features.view(B, T, -1)

    @property
    def output_dim(self):
        """Output dimension per camera after spatial softmax: out_channels * 2."""
        return self.out_channels * 2
|
|
|
|
class SpatialSoftmax(nn.Module):
    """Convert a feature map (N, C, H, W) into per-channel expected 2-D
    coordinates, flattened to (N, C*2).

    Fix: the ``temperature`` constructor argument was previously accepted
    but silently ignored — a learnable parameter was always created. Now a
    provided value is registered as a fixed (non-learnable) buffer, while
    ``None`` preserves the original learnable-parameter behavior.
    """

    def __init__(self, num_rows, num_cols, temperature=None):
        """
        Args:
            num_rows: expected height H of incoming feature maps.
            num_cols: expected width W of incoming feature maps.
            temperature: fixed softmax temperature; ``None`` creates a
                learnable temperature initialized to 1 (default behavior).
        """
        super().__init__()
        if temperature is None:
            # Learnable temperature, initialized to 1 (original behavior).
            self.temperature = nn.Parameter(torch.ones(1))
        else:
            # Caller-supplied fixed temperature; a buffer so it moves with
            # the module across devices but receives no gradient updates.
            self.register_buffer('temperature', torch.full((1,), float(temperature)))

        # Coordinate grids in [-1, 1]. NOTE(review): with indexing='ij',
        # `pos_x` varies along the row (vertical) axis and `pos_y` along the
        # column axis — the names look swapped, but the layout is kept as-is
        # so downstream consumers see an unchanged feature ordering.
        pos_x, pos_y = torch.meshgrid(
            torch.linspace(-1, 1, num_rows),
            torch.linspace(-1, 1, num_cols),
            indexing='ij'
        )
        self.register_buffer('pos_x', pos_x.reshape(-1))
        self.register_buffer('pos_y', pos_y.reshape(-1))

    def forward(self, x):
        """Compute expected coordinates under a spatial softmax.

        Args:
            x: feature map of shape (N, C, H, W) with H*W == num_rows*num_cols.

        Returns:
            Tensor of shape (N, C*2), interleaved per channel as
            (x_0, y_0, x_1, y_1, ...).
        """
        N, C, H, W = x.shape
        # Fail fast with a clear message instead of a cryptic broadcast error.
        if H * W != self.pos_x.numel():
            raise ValueError(
                f"SpatialSoftmax expected {self.pos_x.numel()} spatial "
                f"positions, got H*W={H * W}"
            )
        x = x.view(N, C, -1)  # (N, C, H*W)

        # Softmax attention over spatial positions.
        softmax_attention = torch.nn.functional.softmax(x / self.temperature, dim=2)

        # Expected (x, y) coordinates under the attention distribution.
        expected_x = torch.sum(softmax_attention * self.pos_x, dim=2, keepdim=True)
        expected_y = torch.sum(softmax_attention * self.pos_y, dim=2, keepdim=True)

        # Concatenate and flatten -> (N, C*2)
        return torch.cat([expected_x, expected_y], dim=2).reshape(N, -1)