|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch.nn as nn |
|
from transformers import ViTImageProcessor, ViTModel, AutoImageProcessor, AutoModel, Dinov2Model |
|
|
|
class DinoWrapper(nn.Module):
    """
    DINOv2 feature-extractor wrapper using the Hugging Face transformers implementation.

    Although this class was historically described as a "Dino v1" wrapper, the
    backbone loaded by ``_build_dino`` is the ``facebook/dinov2-base`` checkpoint
    (see ``_CHECKPOINT``).
    """

    # Checkpoint that ``_build_dino`` loads for both the image processor and the model.
    # NOTE(review): the ``model_name`` argument is currently ignored and this checkpoint
    # is always used. Confirm whether callers/configs rely on that before wiring
    # ``model_name`` through (a different checkpoint may change patch size/token count).
    _CHECKPOINT = "facebook/dinov2-base"

    def __init__(self, model_name: str, freeze: bool = True):
        """
        Build the backbone and its image processor, optionally freezing the backbone.

        Args:
            model_name: Accepted for interface compatibility; it is forwarded to
                ``_build_dino`` but does not currently select the checkpoint.
            freeze: If True, put the backbone in eval mode and disable gradients
                for all of its parameters.
        """
        super().__init__()
        self.model, self.processor = self._build_dino(model_name)
        if freeze:
            self._freeze()

    def forward(self, image):
        """
        Run the backbone on ``image`` and return its final-layer hidden states.

        The processor is invoked with ``do_rescale=False`` and ``do_resize=False``
        (and center cropping is disabled in ``_build_dino``), so only its
        normalization step applies. ``image`` is therefore presumably already
        resized and scaled to the range the normalization expects (likely [0, 1]).
        TODO: confirm with callers.

        Args:
            image: Image batch accepted by the processor; it is cast with ``.float()``
                first, so a torch tensor is expected.

        Returns:
            ``outputs.last_hidden_state`` from the Hugging Face model, i.e. per-token
            hidden states of the last layer (presumably (batch, tokens, hidden)).
        """
        inputs = self.processor(
            images=image.float(),
            return_tensors="pt",
            do_rescale=False,
            do_resize=False,
        ).to(self.model.device)
        outputs = self.model(**inputs)
        return outputs.last_hidden_state

    def _freeze(self):
        """
        Put the backbone in eval mode and stop gradients through its parameters.

        NOTE: ``nn.Module.train()`` recurses into child modules, so calling
        ``.train()`` on this wrapper (or on a parent module) afterwards switches
        ``self.model`` back to training mode. ``requires_grad`` stays False.
        """
        print("======== Freezing DinoWrapper ========")
        self.model.eval()
        for param in self.model.parameters():
            param.requires_grad = False

    @staticmethod
    def _build_dino(model_name: str, proxy_error_retries: int = 3, proxy_error_cooldown: int = 5):
        """
        Load the image processor and model from the Hugging Face hub.

        Args:
            model_name: Currently unused; the checkpoint is ``DinoWrapper._CHECKPOINT``.
            proxy_error_retries: How many extra attempts to make after a
                ``requests.exceptions.ProxyError`` (total attempts = retries + 1).
                A value <= 0 means the first ProxyError is re-raised immediately.
            proxy_error_cooldown: Seconds to sleep before each retry.

        Returns:
            ``(model, processor)``, with center cropping disabled on the processor.

        Raises:
            requests.exceptions.ProxyError: if every attempt fails with a proxy error.
        """
        import time

        import requests

        checkpoint = DinoWrapper._CHECKPOINT
        retries_left = proxy_error_retries
        while True:
            try:
                processor = AutoImageProcessor.from_pretrained(checkpoint)
                # Disable the processor's center crop so the caller's image framing is kept.
                processor.do_center_crop = False
                model = AutoModel.from_pretrained(checkpoint)
                return model, processor
            except requests.exceptions.ProxyError:
                if retries_left <= 0:
                    raise
                print(f"Huggingface ProxyError: Retrying in {proxy_error_cooldown} seconds...")
                time.sleep(proxy_error_cooldown)
                retries_left -= 1