chore: 删除多余脚本
This commit is contained in:
@@ -1,10 +1,4 @@
|
||||
# Backbone models
|
||||
from .siglip import SigLIPBackbone
|
||||
from .resnet import ResNetBackbone
|
||||
# from .clip import CLIPBackbone
|
||||
# from .dinov2 import DinoV2Backbone
|
||||
|
||||
__all__ = ["SigLIPBackbone", "ResNetBackbone"]
|
||||
|
||||
# from .debug import DebugBackbone
|
||||
# __all__ = ["DebugBackbone"]
|
||||
__all__ = ["ResNetBackbone"]
|
||||
|
||||
@@ -1,62 +0,0 @@
|
||||
# SigLIP Backbone implementation
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import AutoModel, AutoProcessor, SiglipVisionModel
|
||||
from typing import Dict, Optional
|
||||
from roboimi.vla.core.interfaces import VLABackbone
|
||||
|
||||
class SigLIPBackbone(VLABackbone):
    """Vision backbone wrapping Google's SigLIP vision encoder.

    Only the vision tower is loaded (no text tower): SigLIP's image/text
    alignment means the vision features alone are usable downstream.
    HuggingFace ID example: "google/siglip-so400m-patch14-384"
    """

    def __init__(
        self,
        model_name: str = "google/siglip-so400m-patch14-384",
        freeze: bool = True,
        embed_dim: Optional[int] = None
    ):
        """Load the pretrained vision tower and optionally freeze it.

        Args:
            model_name: HuggingFace model identifier to load.
            freeze: when True, disable gradients on every backbone weight.
            embed_dim: explicit feature dimension; when None it is read
                from the loaded model config (1152 for the so400m variant).
        """
        super().__init__()
        print(f"Loading SigLIP: {model_name} ...")

        # Vision tower only -- the text tower is never used here.
        self.vision_model = SiglipVisionModel.from_pretrained(model_name)

        # An explicitly configured embed_dim takes priority over the
        # value auto-detected from the model config.
        if embed_dim is None:
            self._embed_dim = self.vision_model.config.hidden_size
            print(f"✓ Auto-detected embed_dim: {self._embed_dim}")
        else:
            self._embed_dim = embed_dim
            print(f"✓ Using configured embed_dim: {embed_dim}")

        if freeze:
            self._freeze_parameters()

    def _freeze_parameters(self) -> None:
        # Stop gradient flow and switch to eval mode so norm/dropout
        # layers behave deterministically.
        print("❄️ Freezing Vision Backbone parameters")
        for p in self.vision_model.parameters():
            p.requires_grad = False
        self.vision_model.eval()

    def forward(self, obs: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Encode a batch of images into patch-token features.

        Args:
            obs: observation dict; obs['image'] is a (B, C, H, W)
                normalized tensor.

        Returns:
            Patch features of shape (B, Seq_Len, Embed_Dim) -- the
            HuggingFace last_hidden_state.
        """
        # SiglipVisionModel returns a BaseModelOutputWithPooling whose
        # last_hidden_state is (B, Num_Patches, Embed_Dim).
        return self.vision_model(pixel_values=obs['image']).last_hidden_state

    @property
    def embed_dim(self) -> int:
        """Feature dimension of the returned patch embeddings."""
        return self._embed_dim
|
||||
@@ -1,8 +1,4 @@
|
||||
# # Action Head models
|
||||
from .diffusion import ConditionalUnet1D
|
||||
# from .act import ACTHead
|
||||
|
||||
__all__ = ["ConditionalUnet1D"]
|
||||
|
||||
# from .debug import DebugHead
|
||||
# __all__ = ["DebugHead"]
|
||||
Reference in New Issue
Block a user