chore: 删除多余文件
This commit is contained in:
@@ -1,37 +0,0 @@
|
||||
from transformers import SiglipVisionModel
|
||||
from roboimi.vla.core.interfaces import VLABackbone
|
||||
from torchvision import transforms
|
||||
|
||||
class SigLIP2(VLABackbone):
    """SigLIP2 vision backbone wrapping a pretrained ``SiglipVisionModel``.

    Resizes and re-normalizes incoming images, runs them through the
    pretrained vision tower, and exposes the last hidden state as
    per-patch features. Parameters are frozen by default.
    """

    def __init__(
        self,
        model_name="google/siglip2-base-patch16-384",
        freeze: bool = True,
    ):
        super().__init__()

        self.vision_model = SiglipVisionModel.from_pretrained(model_name)

        # Resize to the model's native resolution and map [0, 1] inputs
        # to the [-1, 1] range the SigLIP checkpoint was trained on.
        self.transform = transforms.Compose([
            transforms.Resize((384, 384), antialias=True),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

        if freeze:
            self._freeze_parameters()

    def _freeze_parameters(self):
        """Disable gradients on the vision tower and switch it to eval mode."""
        print("❄️ Freezing Vision Backbone parameters")
        for p in self.vision_model.parameters():
            p.requires_grad = False
        self.vision_model.eval()

    def forward(self, images):
        """Encode a batch of images into patch embeddings.

        Args:
            images: (B, C, H, W) tensor with values normalized to [0, 1].

        Returns:
            The vision model's ``last_hidden_state`` (patch-token features).
        """
        pixel_values = self.transform(images)  # now in [-1, 1]
        outputs = self.vision_model(pixel_values=pixel_values)
        return outputs.last_hidden_state
|
||||
@@ -1,9 +0,0 @@
|
||||
# Projector models: public API of the projector subpackage.
# NOTE: PerceiverResampler is not implemented yet; it will be re-exported
# here once `.perceiver` lands.
from .mlp import MLPProjector

__all__ = ["MLPProjector"]
|
||||
@@ -1,19 +0,0 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from roboimi.vla.core.interfaces import VLAProjector
|
||||
|
||||
class MLPProjector(VLAProjector):
    """A simple Linear Projection layer.

    First-class citizen: Adapts Backbone dim -> Head dim.
    """

    def __init__(self, input_dim: int, output_dim: int):
        super().__init__()
        # Two linear layers with a GELU nonlinearity in between; the second
        # layer maps output_dim -> output_dim to add a little capacity.
        layers = [
            nn.Linear(input_dim, output_dim),
            nn.GELU(),
            nn.Linear(output_dim, output_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project *x* from the backbone dimension to the head dimension."""
        return self.net(x)
|
||||
@@ -1 +0,0 @@
|
||||
# Perceiver Resampler implementation (placeholder)
|
||||
Reference in New Issue
Block a user