feat: 更新框架,新增数据及定义和backbone

This commit is contained in:
gouhanke
2026-02-05 01:37:55 +08:00
parent 92660562fb
commit dd2749cb12
10 changed files with 224 additions and 134 deletions

View File

@@ -0,0 +1,37 @@
from transformers import SiglipVisionModel
from roboimi.vla.core.interfaces import VLABackbone
from torchvision import transforms
class SigLIP2(VLABackbone):
    """Vision backbone wrapping a pretrained SigLIP-2 vision encoder.

    Resizes and re-normalizes input images, runs them through the
    HuggingFace ``SiglipVisionModel``, and returns the encoder's
    last hidden state (per-patch token features).
    """

    def __init__(
        self,
        model_name: str = "google/siglip2-base-patch16-384",
        freeze: bool = True,
        image_size: int = 384,
    ):
        """
        Args:
            model_name: HuggingFace model id of the SigLIP vision encoder
                to load with ``from_pretrained``.
            freeze: If True, disable gradients for all encoder parameters
                (pure feature-extraction mode).
            image_size: Side length images are resized to before encoding.
                Must match the resolution the pretrained checkpoint expects
                (384 for the default ``patch16-384`` model).
        """
        super().__init__()
        self.vision_model = SiglipVisionModel.from_pretrained(model_name)
        # Map inputs assumed to be in [0, 1] to the [-1, 1] range SigLIP
        # was trained on (mean/std of 0.5 per channel).
        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size), antialias=True),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])
        if freeze:
            self._freeze_parameters()

    def _freeze_parameters(self) -> None:
        """Disable gradients on every encoder parameter and switch to eval mode."""
        print("❄️ Freezing Vision Backbone parameters")
        for param in self.vision_model.parameters():
            param.requires_grad = False
        # NOTE(review): a later .train() call on the parent module will put
        # the encoder back into train mode (re-enabling dropout etc.) even
        # though its parameters stay frozen; override train() here if strict
        # eval-mode behavior must be guaranteed.
        self.vision_model.eval()

    def forward(
        self,
        images,
    ):
        """Encode a batch of images.

        Args:
            images: (B, C, H, W) float tensor with values normalized to
                [0, 1] — assumed, TODO confirm against callers.

        Returns:
            The encoder's ``last_hidden_state`` — per-patch token features
            (presumably shape (B, num_patches, hidden_dim); verify against
            the SigLIP config).
        """
        images = self.transform(images)  # now normalized to [-1, 1]
        outputs = self.vision_model(pixel_values=images)
        return outputs.last_hidden_state