Spaces:
No application file
No application file
| from typing import List, Tuple, Dict | |
| import os | |
| import time | |
| from tqdm import tqdm | |
| import torch | |
| import numpy as np | |
| from numpy import ndarray | |
| from PIL import Image | |
| from transformers import BertForSequenceClassification, BertTokenizer, CLIPProcessor, CLIPModel | |
class TextFeatureExtractor(object):
    """Turn text queries into L2-normalised feature vectors.

    A ``BertTokenizer`` converts the queries to token ids and a
    ``BertForSequenceClassification`` model produces ``logits`` that are
    used as the text features.
    """

    def __init__(self, language_model_path: str, local_file: bool=True, device: str='cpu'):
        """Load the tokenizer and the text encoder onto the chosen device.

        Args:
            language_model_path: Accepted for interface compatibility but
                currently NOT used: the checkpoint is chosen from
                ``local_file`` alone (see the note in the body).
            local_file: When True, load from the local directory
                ``Taiyi-CLIP-Roberta-large-326M-Chinese`` with
                ``local_files_only=True``; otherwise load the
                ``IDEA-CCNL/Taiyi-CLIP-Roberta-large-326M-Chinese`` id.
            device: Torch device string. A falsy value (e.g. ``''`` or
                ``None``) selects ``'cuda'`` when available, else ``'cpu'``.
        """
        if device:
            self.device = device
        else:
            self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        # NOTE(review): the caller-supplied path is overwritten here, so the
        # ``language_model_path`` argument has no effect. Kept as-is because
        # existing callers may rely on ``local_file`` picking the checkpoint.
        language_model_path = "Taiyi-CLIP-Roberta-large-326M-Chinese" if local_file else "IDEA-CCNL/Taiyi-CLIP-Roberta-large-326M-Chinese"
        self.text_tokenizer = BertTokenizer.from_pretrained(language_model_path, local_files_only=local_file)
        self.text_encoder = BertForSequenceClassification.from_pretrained(language_model_path, local_files_only=local_file).eval().to(self.device)

    def text(self, query_texts: List[str]) -> ndarray:
        """Encode ``query_texts`` and return unit-length feature vectors.

        Args:
            query_texts: The text queries to encode as one padded batch.

        Returns:
            A NumPy array of the encoder logits, each row divided by its L2
            norm. Size-1 dimensions are squeezed away, so a single query
            yields a 1-D vector rather than a one-row matrix.
        """
        # NOTE(review): truncation length comes from ``config.max_length``;
        # confirm this is the intended limit for the tokenizer.
        text = self.text_tokenizer(
            query_texts,
            return_tensors='pt',
            padding=True,
            truncation=True,
            max_length=self.text_encoder.config.max_length,
        )['input_ids']
        text = text.to(self.device)
        with torch.no_grad():
            text_features = self.text_encoder(text).logits
            # Normalise each row to unit L2 norm.
            text_features = text_features / text_features.norm(dim=1, keepdim=True)
        # Fix: ``squeeze`` must be called. The original assigned the bound
        # method itself, which made the following ``.detach()`` raise
        # AttributeError.
        text_features = text_features.squeeze()
        return text_features.detach().cpu().numpy()
class TaiyiFeatureExtractor(TextFeatureExtractor):
    """Text feature extractor preconfigured for the Taiyi-CLIP Chinese text model."""

    def __init__(self, language_model_path: str="Taiyi-CLIP-Roberta-large-326M-Chinese", local_file: bool = True, device: str = 'cpu'):
        """Forward the configuration unchanged to ``TextFeatureExtractor``.

        Args:
            language_model_path (str, optional): Either the local directory
                name ``Taiyi-CLIP-Roberta-large-326M-Chinese`` or the id
                ``IDEA-CCNL/Taiyi-CLIP-Roberta-large-326M-Chinese``.
                Defaults to "Taiyi-CLIP-Roberta-large-326M-Chinese".
            local_file (bool, optional): Passed through to the base class,
                which uses it to decide between local files and the remote
                model id. Defaults to True.
            device (str, optional): Torch device string passed through to
                the base class. Defaults to 'cpu'.
        """
        super().__init__(
            language_model_path=language_model_path,
            local_file=local_file,
            device=device,
        )