remove hardcode dino encoder logic
This commit is contained in:
@@ -25,12 +25,7 @@ def constant(alpha, sigma):
|
|||||||
class DINOv2(nn.Module):
|
class DINOv2(nn.Module):
|
||||||
def __init__(self, weight_path:str):
|
def __init__(self, weight_path:str):
|
||||||
super(DINOv2, self).__init__()
|
super(DINOv2, self).__init__()
|
||||||
self.encoder = torch.hub.load(
|
self.encoder = torch.hub.load('facebookresearch/dinov2', weight_path)
|
||||||
'/mnt/bn/wangshuai6/torch_hub/facebookresearch_dinov2_main',
|
|
||||||
weight_path,
|
|
||||||
source="local",
|
|
||||||
skip_validation=True
|
|
||||||
)
|
|
||||||
self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
|
self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
|
||||||
self.encoder.head = torch.nn.Identity()
|
self.encoder.head = torch.nn.Identity()
|
||||||
self.patch_size = self.encoder.patch_embed.patch_size
|
self.patch_size = self.encoder.patch_embed.patch_size
|
||||||
|
|||||||
@@ -25,12 +25,7 @@ def constant(alpha, sigma):
|
|||||||
class DINOv2(nn.Module):
|
class DINOv2(nn.Module):
|
||||||
def __init__(self, weight_path:str):
|
def __init__(self, weight_path:str):
|
||||||
super(DINOv2, self).__init__()
|
super(DINOv2, self).__init__()
|
||||||
self.encoder = torch.hub.load(
|
self.encoder = torch.hub.load('facebookresearch/dinov2', weight_path)
|
||||||
'/mnt/bn/wangshuai6/torch_hub/facebookresearch_dinov2_main',
|
|
||||||
weight_path,
|
|
||||||
source="local",
|
|
||||||
skip_validation=True
|
|
||||||
)
|
|
||||||
self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
|
self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
|
||||||
self.encoder.head = torch.nn.Identity()
|
self.encoder.head = torch.nn.Identity()
|
||||||
self.patch_size = self.encoder.patch_embed.patch_size
|
self.patch_size = self.encoder.patch_embed.patch_size
|
||||||
|
|||||||
Reference in New Issue
Block a user