From ce93cc61a8a52436ba546cae27733c096982298f Mon Sep 17 00:00:00 2001
From: wangshuai6
Date: Sun, 13 Apr 2025 09:14:35 +0800
Subject: [PATCH] fix bugs(rm encoder.py)

---
 src/models/encoder.py | 132 ------------------------------------------
 1 file changed, 132 deletions(-)
 delete mode 100644 src/models/encoder.py

diff --git a/src/models/encoder.py b/src/models/encoder.py
deleted file mode 100644
index 8b7f96a..0000000
--- a/src/models/encoder.py
+++ /dev/null
@@ -1,132 +0,0 @@
-import torch
-import copy
-import os
-import timm
-import transformers
-import torch.nn as nn
-from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.data import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
-from torchvision.transforms import Normalize
-
-class RandViT(nn.Module):
-    def __init__(self, model_id, weight_path:str=None):
-        super(RandViT, self).__init__()
-        self.encoder = timm.create_model(
-            model_id,
-            num_classes=0,
-        )
-        self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
-        self.encoder.head = torch.nn.Identity()
-        self.patch_size = self.encoder.patch_embed.patch_size
-        self.shifts = nn.Parameter(torch.tensor([0.0
-                                                 ]), requires_grad=False)
-        self.scales = nn.Parameter(torch.tensor([1.0
-                                                 ]), requires_grad=False)
-
-    def forward(self, x):
-        x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
-        x = torch.nn.functional.interpolate(x, (224, 224), mode='bicubic')
-        b, c, h, w = x.shape
-        patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1]
-        feature = self.encoder.forward_features(x)[:, self.encoder.num_prefix_tokens:]
-        feature = feature.transpose(1, 2)
-        feature = feature.view(b, -1, patch_num_h, patch_num_w).contiguous()
-        feature = (feature - self.shifts.view(1, -1, 1, 1)) / self.scales.view(1, -1, 1, 1)
-        return feature
-
-class MAE(nn.Module):
-    def __init__(self, model_id, weight_path:str):
-        super(MAE, self).__init__()
-        if os.path.isdir(weight_path):
-            weight_path = os.path.join(weight_path, "pytorch_model.bin")
-        self.encoder = timm.create_model(
-            model_id,
-            checkpoint_path=weight_path,
-            num_classes=0,
-        )
-        self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
-        self.encoder.head = torch.nn.Identity()
-        self.patch_size = self.encoder.patch_embed.patch_size
-        self.shifts = nn.Parameter(torch.tensor([0.0
-                                                 ]), requires_grad=False)
-        self.scales = nn.Parameter(torch.tensor([1.0
-                                                 ]), requires_grad=False)
-
-    def forward(self, x):
-        x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
-        x = torch.nn.functional.interpolate(x, (224, 224), mode='bicubic')
-        b, c, h, w = x.shape
-        patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1]
-        feature = self.encoder.forward_features(x)[:, self.encoder.num_prefix_tokens:]
-        feature = feature.transpose(1, 2)
-        feature = feature.view(b, -1, patch_num_h, patch_num_w).contiguous()
-        feature = (feature - self.shifts.view(1, -1, 1, 1)) / self.scales.view(1, -1, 1, 1)
-        return feature
-
-class DINO(nn.Module):
-    def __init__(self, model_id, weight_path:str):
-        super(DINO, self).__init__()
-        if os.path.isdir(weight_path):
-            weight_path = os.path.join(weight_path, "pytorch_model.bin")
-        self.encoder = timm.create_model(
-            model_id,
-            checkpoint_path=weight_path,
-            num_classes=0,
-        )
-        self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
-        self.encoder.head = torch.nn.Identity()
-        self.patch_size = self.encoder.patch_embed.patch_size
-        self.shifts = nn.Parameter(torch.tensor([ 0.0,
-                                                  ]), requires_grad=False)
-        self.scales = nn.Parameter(torch.tensor([ 1.0,
-                                                  ]), requires_grad=False)
-
-    def forward(self, x):
-        x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
-        x = torch.nn.functional.interpolate(x, (224, 224), mode='bicubic')
-        b, c, h, w = x.shape
-        patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1]
-        feature = self.encoder.forward_features(x)[:, self.encoder.num_prefix_tokens:]
-        feature = feature.transpose(1, 2)
-        feature = feature.view(b, -1, patch_num_h, patch_num_w).contiguous()
-        feature = (feature - self.shifts.view(1, -1, 1, 1)) / self.scales.view(1, -1, 1, 1)
-        return feature
-
-class CLIP(nn.Module):
-    def __init__(self, model_id, weight_path:str):
-        super(CLIP, self).__init__()
-        self.encoder = transformers.CLIPVisionModel.from_pretrained(weight_path)
-        self.patch_size = self.encoder.vision_model.embeddings.patch_embedding.kernel_size
-        self.shifts = nn.Parameter(torch.tensor([0.0,
-                                                 ]), requires_grad=False)
-        self.scales = nn.Parameter(torch.tensor([1.0,
-                                                 ]), requires_grad=False)
-
-    def forward(self, x):
-        x = Normalize(OPENAI_CLIP_MEAN, OPENAI_CLIP_STD)(x)
-        x = torch.nn.functional.interpolate(x, (224, 224), mode='bicubic')
-        b, c, h, w = x.shape
-        patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1]
-        feature = self.encoder(x)['last_hidden_state'][:, 1:]
-        feature = feature.transpose(1, 2)
-        feature = feature.view(b, -1, patch_num_h, patch_num_w).contiguous()
-        feature = (feature - self.shifts.view(1, -1, 1, 1)) / self.scales.view(1, -1, 1, 1)
-        return feature
-
-
-
-class DINOv2(nn.Module):
-    def __init__(self, model_id, weight_path:str):
-        super(DINOv2, self).__init__()
-        self.encoder = transformers.Dinov2Model.from_pretrained(weight_path)
-        self.patch_size = self.encoder.embeddings.patch_embeddings.projection.kernel_size
-
-    def forward(self, x):
-        x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
-        x = torch.nn.functional.interpolate(x, (224, 224), mode='bicubic')
-        b, c, h, w = x.shape
-        patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1]
-        feature = self.encoder.forward(x)['last_hidden_state'][:, 1:]
-        feature = feature.transpose(1, 2)
-        feature = feature.view(b, -1, patch_num_h, patch_num_w).contiguous()
-        return feature
\ No newline at end of file