From 485818abf66d9a1cb4faf6be5c90df70352dace2 Mon Sep 17 00:00:00 2001 From: wangshuai6 Date: Fri, 11 Apr 2025 11:07:59 +0800 Subject: [PATCH] remove hardcode dino encoder logic --- src/diffusion/flow_matching/training_repa.py | 7 +------ src/diffusion/stateful_flow_matching/training_repa.py | 7 +------ 2 files changed, 2 insertions(+), 12 deletions(-) diff --git a/src/diffusion/flow_matching/training_repa.py b/src/diffusion/flow_matching/training_repa.py index e9a6788..40d80a6 100644 --- a/src/diffusion/flow_matching/training_repa.py +++ b/src/diffusion/flow_matching/training_repa.py @@ -25,12 +25,7 @@ def constant(alpha, sigma): class DINOv2(nn.Module): def __init__(self, weight_path:str): super(DINOv2, self).__init__() - self.encoder = torch.hub.load( - '/mnt/bn/wangshuai6/torch_hub/facebookresearch_dinov2_main', - weight_path, - source="local", - skip_validation=True - ) + self.encoder = torch.hub.load('facebookresearch/dinov2', weight_path) self.pos_embed = copy.deepcopy(self.encoder.pos_embed) self.encoder.head = torch.nn.Identity() self.patch_size = self.encoder.patch_embed.patch_size diff --git a/src/diffusion/stateful_flow_matching/training_repa.py b/src/diffusion/stateful_flow_matching/training_repa.py index a5a28db..4846d5d 100644 --- a/src/diffusion/stateful_flow_matching/training_repa.py +++ b/src/diffusion/stateful_flow_matching/training_repa.py @@ -25,12 +25,7 @@ def constant(alpha, sigma): class DINOv2(nn.Module): def __init__(self, weight_path:str): super(DINOv2, self).__init__() - self.encoder = torch.hub.load( - '/mnt/bn/wangshuai6/torch_hub/facebookresearch_dinov2_main', - weight_path, - source="local", - skip_validation=True - ) + self.encoder = torch.hub.load('facebookresearch/dinov2', weight_path) self.pos_embed = copy.deepcopy(self.encoder.pos_embed) self.encoder.head = torch.nn.Identity() self.patch_size = self.encoder.patch_embed.patch_size