import pathlib
import tempfile
from collections import OrderedDict
from typing import Tuple, Union
import logging
import os

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn

from timm.models.layers import DropPath, trunc_normal_

from .image_encoder import build_image_encoder
from .text_encoder import build_text_encoder
from .text_encoder import build_tokenizer
from .templates import DEFAULT_TEMPLATES

logger = logging.getLogger(__name__)


class UniCLModel(nn.Module):
    """Two-tower image/text model projecting both modalities into one space.

    The model owns a tokenizer, a text encoder and an image encoder (all built
    from ``config['MODEL']``), one learnable linear projection per modality,
    and a learnable scalar ``logit_scale`` whose exponential is returned by
    :meth:`forward` as the temperature.
    """

    # Key prefixes used by older checkpoints, mapped to the attribute names
    # this class registers. ``lang_encoder.`` is accepted as a legacy alias
    # because the text tower is registered here as ``text_encoder``; a key
    # under any other prefix can never match ``state_dict()``.
    _LEGACY_PREFIXES = (
        ('visual.', 'image_encoder.'),
        ('text.', 'text_encoder.'),
        ('lang_encoder.', 'text_encoder.'),
    )
    # Whole-key renames used by older checkpoints.
    _LEGACY_KEYS = {'vision_projection': 'image_projection'}

    def __init__(self, config: dict,):
        """Build the encoders and projection parameters from ``config``.

        Args:
            config: Nested dict. Reads ``config['MODEL']['TEXT_ENCODER']``,
                ``config['MODEL']['IMAGE_ENCODER']``,
                ``config['MODEL']['DIM_PROJECTION']`` and ``config['VERBOSE']``.
        """
        super().__init__()

        self.conf_lang_encoder = config['MODEL']['TEXT_ENCODER']
        self.tokenizer = build_tokenizer(self.conf_lang_encoder)
        self.text_encoder = build_text_encoder(
            self.conf_lang_encoder, self.tokenizer, config['VERBOSE']
        )

        dim_projection = config['MODEL']['DIM_PROJECTION']
        if hasattr(self.text_encoder, 'dim_out'):
            dim_out = self.text_encoder.dim_out
        else:
            # The encoder does not advertise its width: probe it with a single
            # dummy token and read the last dimension of the hidden state.
            with torch.no_grad():
                dim_out = self.text_encoder(
                    torch.zeros(1, 1, dtype=torch.long)
                )['last_hidden_state'].size(2)

        self.text_projection = nn.Parameter(torch.empty(dim_out, dim_projection))

        self.conf_image_encoder = config['MODEL']['IMAGE_ENCODER']
        self.image_encoder = build_image_encoder(self.conf_image_encoder)
        self.image_projection = nn.Parameter(
            torch.empty(self.image_encoder.dim_out, dim_projection)
        )

        self.logit_scale = nn.Parameter(torch.ones([]))

        trunc_normal_(self.text_projection, std=.02)
        trunc_normal_(self.image_projection, std=.02)

    def _convert_old_weights(self, model_dict):
        """Return a copy of ``model_dict`` with legacy key names modernised.

        Keys that match neither a legacy prefix nor a legacy name are copied
        through unchanged.
        """
        model_dict_updated = {}
        for key, value in model_dict.items():
            new_key = self._LEGACY_KEYS.get(key, key)
            for old_prefix, new_prefix in self._LEGACY_PREFIXES:
                if key.startswith(old_prefix):
                    new_key = new_prefix + key[len(old_prefix):]
                    break
            model_dict_updated[new_key] = value
        return model_dict_updated

    def from_pretrained(self, pretrained='', pretrained_layers=(), verbose=True):
        """Initialise weights from a checkpoint file on disk.

        Args:
            pretrained: Path to the checkpoint. If it is not an existing file
                a warning is logged and nothing is loaded.
            pretrained_layers: Top-level names (the part of a state-dict key
                before the first ``.``) to load. If the first entry is ``'*'``
                every matching key is loaded. Empty or ``None`` loads nothing.
            verbose: Log each non-image-encoder key that is loaded; also
                forwarded to ``image_encoder.from_state_dict``.
        """
        if not os.path.isfile(pretrained):
            logger.warning(f'=> Pretrained model ({pretrained}) is not a file, skip init weight')
            return

        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code. Only load checkpoints from trusted sources.
        pretrained_dict = torch.load(pretrained, map_location='cpu')
        logger.info(f'=> Loading pretrained model {pretrained}')
        pretrained_dict = self._convert_old_weights(pretrained_dict)

        # Drop everything this model has no slot for.
        model_dict = self.state_dict()
        pretrained_dict = {
            k: v for k, v in pretrained_dict.items() if k in model_dict
        }

        pretrained_layers = pretrained_layers or ()
        # Guarded so an empty selection means "load nothing" instead of
        # raising IndexError on pretrained_layers[0].
        load_all = len(pretrained_layers) > 0 and pretrained_layers[0] == '*'

        need_init_state_dict = {}
        image_encoder_state_dict = {}
        for k, v in pretrained_dict.items():
            if not (load_all or k.split('.')[0] in pretrained_layers):
                continue
            if k.startswith('image_encoder.'):
                image_encoder_state_dict[k] = v
            else:
                if verbose:
                    logger.info(f'=> init {k} from {pretrained}')
                need_init_state_dict[k] = v

        # NOTE(review): keys are handed over with their 'image_encoder.'
        # prefix intact; presumably from_state_dict expects that - confirm.
        self.image_encoder.from_state_dict(image_encoder_state_dict, ['*'], verbose)
        self.load_state_dict(need_init_state_dict, strict=False)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return the set of parameter names to exclude from weight decay.

        Always contains ``'logit_scale'``; names reported by either encoder's
        own ``no_weight_decay()`` are added, prefixed with the attribute name
        the encoder is registered under on this module.
        """
        no_weight_decay = {'logit_scale'}

        if hasattr(self.text_encoder, 'no_weight_decay'):
            for k in self.text_encoder.no_weight_decay():
                no_weight_decay.add('text_encoder.' + k)

        if hasattr(self.image_encoder, 'no_weight_decay'):
            for k in self.image_encoder.no_weight_decay():
                no_weight_decay.add('image_encoder.' + k)

        return no_weight_decay

    @property
    def dtype(self):
        """Dtype of the ``logit_scale`` parameter."""
        return self.logit_scale.dtype

    def _embed_class_names(self, class_names, templates):
        """Embed each class name as the normalised mean of its prompts.

        Every name is formatted into each template, tokenised (padded and
        truncated to 77 tokens), encoded with :meth:`encode_text`, averaged
        over the templates and divided by its norm.

        Returns:
            Tensor with one row per class name; an empty
            ``(0, dim_projection)`` tensor when ``class_names`` is empty.
        """
        # Follow the model's actual device instead of the current CUDA device.
        device = next(self.parameters()).device

        embeddings = []
        for name in class_names:
            prompts = [template.format(name) for template in templates]
            tokens = self.tokenizer(
                prompts,
                padding='max_length',
                truncation=True,
                max_length=77,
                return_tensors='pt'
            )
            tokens = {key: val.to(device) for key, val in tokens.items()}

            embedding = self.encode_text(tokens).mean(dim=0)
            embedding = embedding / embedding.norm()
            embeddings.append(embedding)

        if not embeddings:
            # torch.stack rejects an empty list; return an empty result.
            return self.text_projection.new_zeros((0, self.text_projection.size(1)))
        return torch.stack(embeddings, dim=0)

    def get_imnet_embeddings(self, classes=None, templates=None):
        """Return text embeddings for a list of ImageNet class names.

        Args:
            classes: Class names to embed. When omitted, a module-level
                ``IMAGENET_CLASSES`` is used if one exists.
            templates: Prompt templates (``str.format`` strings). When omitted,
                the first entry of a module-level ``IMAGENET_DEFAULT_TEMPLATES``
                is used if one exists, otherwise the first entry of
                ``DEFAULT_TEMPLATES``. Explicitly passed templates are all
                used and averaged.

        Raises:
            NameError: ``classes`` is omitted and this module defines no
                ``IMAGENET_CLASSES``.
        """
        module_globals = globals()
        if classes is None:
            # This module does not import IMAGENET_CLASSES itself, so resolve
            # it lazily and fail with an actionable message if it is absent.
            classes = module_globals.get('IMAGENET_CLASSES')
            if classes is None:
                raise NameError(
                    "name 'IMAGENET_CLASSES' is not defined in this module; "
                    "pass the class names via the 'classes' argument"
                )
        if templates is None:
            templates = module_globals.get(
                'IMAGENET_DEFAULT_TEMPLATES', DEFAULT_TEMPLATES
            )[:1]
        return self._embed_class_names(classes, templates)

    def get_text_embeddings(self, texts):
        """Return one normalised embedding per entry of ``texts``.

        Each entry is formatted into the first template of
        ``DEFAULT_TEMPLATES`` before being encoded.
        """
        return self._embed_class_names(texts, DEFAULT_TEMPLATES[:1])

    def encode_image(self, image, norm=True):
        """Encode ``image`` and project it with ``image_projection``.

        Args:
            image: Input passed to ``image_encoder.forward_features``.
            norm: Divide the result by its norm along the last dimension.
        """
        x = self.image_encoder.forward_features(image)
        x = x @ self.image_projection

        if norm:
            x = x / x.norm(dim=-1, keepdim=True)

        return x

    def encode_text(self, text, norm=True):
        """Encode tokenised ``text`` and project it with ``text_projection``.

        Args:
            text: Mapping of keyword arguments for the text encoder; must
                contain ``'input_ids'`` when the configured tokenizer is
                ``'clip'``.
            norm: Divide the result by its norm along the last dimension.
        """
        x = self.text_encoder(**text)
        x = x['last_hidden_state']

        if self.conf_lang_encoder['TOKENIZER'] == 'clip':
            # Pool at the position holding the largest token id per sequence
            # (presumably the end-of-text token for CLIP - confirm).
            x = x[torch.arange(x.size(0)), text['input_ids'].argmax(dim=-1)]
        else:
            # Otherwise pool the hidden state of the first token.
            x = x[:, 0]

        x = x @ self.text_projection

        if norm:
            x = x / x.norm(dim=-1, keepdim=True)

        return x

    def forward(self, image, text):
        """Return ``(image_features, text_features, temperature)``.

        Both feature tensors are normalised; the temperature is
        ``exp(logit_scale)``. Computing similarity logits from them is left
        to the caller.
        """
        features_image = self.encode_image(image)
        features_text = self.encode_text(text)

        T = self.logit_scale.exp()

        return features_image, features_text, T


def build_unicl_model(config, **kwargs):
    """Construct a :class:`UniCLModel` and optionally load pretrained weights.

    If ``config['MODEL']['PRETRAINED']`` is a non-empty string it is treated
    either as a URL (downloaded into a temporary directory first) or as a
    local path, and passed to ``UniCLModel.from_pretrained`` together with
    ``config['MODEL']['PRETRAINED_LAYERS']`` and ``config['VERBOSE']``.

    Args:
        config: Nested configuration dict.
        **kwargs: Accepted for call-site compatibility; unused.

    Returns:
        The constructed model.
    """
    model = UniCLModel(config)

    if config['MODEL']['PRETRAINED'] != '':
        pretrained_path = config['MODEL']['PRETRAINED']
        pretrained_layers = config['MODEL']['PRETRAINED_LAYERS']
        verbose = config['VERBOSE']

        from ..Utils.Utils import is_valid_url, download_file
        if is_valid_url(pretrained_path):
            # Load inside the ``with`` so the file still exists when read.
            with tempfile.TemporaryDirectory() as tmp_path:
                file_local_path = pathlib.Path(tmp_path) / 'base_model.pt'
                download_file(pretrained_path, file_local_path)
                model.from_pretrained(str(file_local_path), pretrained_layers, verbose)
        else:
            model.from_pretrained(pretrained_path, pretrained_layers, verbose)

    return model