remove hardcode dino encoder logic

This commit is contained in:
wangshuai6
2025-04-11 11:07:59 +08:00
parent 093b0f47f8
commit 485818abf6
2 changed files with 2 additions and 12 deletions

View File

@@ -25,12 +25,7 @@ def constant(alpha, sigma):
class DINOv2(nn.Module): class DINOv2(nn.Module):
def __init__(self, weight_path:str): def __init__(self, weight_path:str):
super(DINOv2, self).__init__() super(DINOv2, self).__init__()
self.encoder = torch.hub.load( self.encoder = torch.hub.load('facebookresearch/dinov2', weight_path)
'/mnt/bn/wangshuai6/torch_hub/facebookresearch_dinov2_main',
weight_path,
source="local",
skip_validation=True
)
self.pos_embed = copy.deepcopy(self.encoder.pos_embed) self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
self.encoder.head = torch.nn.Identity() self.encoder.head = torch.nn.Identity()
self.patch_size = self.encoder.patch_embed.patch_size self.patch_size = self.encoder.patch_embed.patch_size

View File

@@ -25,12 +25,7 @@ def constant(alpha, sigma):
class DINOv2(nn.Module): class DINOv2(nn.Module):
def __init__(self, weight_path:str): def __init__(self, weight_path:str):
super(DINOv2, self).__init__() super(DINOv2, self).__init__()
self.encoder = torch.hub.load( self.encoder = torch.hub.load('facebookresearch/dinov2', weight_path)
'/mnt/bn/wangshuai6/torch_hub/facebookresearch_dinov2_main',
weight_path,
source="local",
skip_validation=True
)
self.pos_embed = copy.deepcopy(self.encoder.pos_embed) self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
self.encoder.head = torch.nn.Identity() self.encoder.head = torch.nn.Identity()
self.patch_size = self.encoder.patch_embed.patch_size self.patch_size = self.encoder.patch_embed.patch_size