import os
import torch
import gradio as gr
import time
import clip
import requests
import csv
import json
import wget
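# Precomputed CLIP ViT-L/14 zero-shot classifier weights (OpenImages, Places365,
# Tencent ML-Images), downloaded from the socratic-models-demo repository at startup.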
url_dict = {'clip_ViTL14_openimage_classifier_weights.pt': 'https://raw.githubusercontent.com/geonm/socratic-models-demo/master/prompts/clip_ViTL14_openimage_classifier_weights.pt',
            'clip_ViTL14_place365_classifier_weights.pt': 'https://raw.githubusercontent.com/geonm/socratic-models-demo/master/prompts/clip_ViTL14_place365_classifier_weights.pt',
            'clip_ViTL14_tencentml_classifier_weights.pt': 'https://raw.githubusercontent.com/geonm/socratic-models-demo/master/prompts/clip_ViTL14_tencentml_classifier_weights.pt'}

os.makedirs('./prompts', exist_ok=True)
for k, v in url_dict.items():
    wget.download(v, out='./prompts')

os.environ['CUDA_VISIBLE_DEVICES'] = ''

API_URL = "https://api-inference.huggingface.co/models/bigscience/bloom"
HF_TOKEN = os.environ["HF_TOKEN"]
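# Map row index -> class name; the OpenImages CSV keeps the display name in the last column of each row.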
def load_openimage_classnames(csv_path):
    with open(csv_path) as csv_data:
        csv_reader = csv.reader(csv_data)
        classnames = {idx: row[-1] for idx, row in enumerate(csv_reader)}
    return classnames
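# Load class names from a plain-text file, one class per line (used for both Tencent ML-Images and Places365).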
def load_tencentml_classnames(txt_path):
    with open(txt_path) as txt_data:
        lines = txt_data.readlines()
    classnames = {idx: line.strip() for idx, line in enumerate(lines)}
    return classnames
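# Encode one templated prompt per class with the CLIP text encoder and L2-normalize the
# embeddings, yielding a small zero-shot classifier plus an index -> class-name mapping.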
def build_simple_classifier(clip_model, text_list, template, device):
    with torch.no_grad():
        texts = [template(text) for text in text_list]
        text_inputs = clip.tokenize(texts).to(device)
        text_features = clip_model.encode_text(text_inputs)
        text_features /= text_features.norm(dim=-1, keepdim=True)
    return text_features, {idx: text for idx, text in enumerate(text_list)}
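# Load CLIP ViT-L/14 and the precomputed zero-shot classifiers, build the small
# image-type / people classifiers on the fly, and collect everything in one dict.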
def load_models():
    # build CLIP model and zero-shot classifiers
    model_dict = {}

    device = "cuda" if torch.cuda.is_available() else "cpu"

    print('\tLoading CLIP ViT-L/14')
    clip_model, clip_preprocess = clip.load("ViT-L/14", device=device)

    print('\tLoading precomputed zeroshot classifier')
    openimage_classifier_weights = torch.load('./prompts/clip_ViTL14_openimage_classifier_weights.pt', map_location=device).type(torch.FloatTensor)
    openimage_classnames = load_openimage_classnames('./prompts/openimage-classnames.csv')

    tencentml_classifier_weights = torch.load('./prompts/clip_ViTL14_tencentml_classifier_weights.pt', map_location=device).type(torch.FloatTensor)
    tencentml_classnames = load_tencentml_classnames('./prompts/tencent-ml-classnames.txt')

    place365_classifier_weights = torch.load('./prompts/clip_ViTL14_place365_classifier_weights.pt', map_location=device).type(torch.FloatTensor)
    place365_classnames = load_tencentml_classnames('./prompts/place365-classnames.txt')

    print('\tBuilding simple zeroshot classifier')
    img_types = ['photo', 'cartoon', 'sketch', 'painting']
    ppl_texts = ['no people', 'people']
    ifppl_texts = ['is one person', 'are two people', 'are three people', 'are several people', 'are many people']

    imgtype_classifier_weights, imgtype_classnames = build_simple_classifier(clip_model, img_types, lambda c: f'This is a {c}.', device)
    ppl_classifier_weights, ppl_classnames = build_simple_classifier(clip_model, ppl_texts, lambda c: f'There are {c} in this photo.', device)
    ifppl_classifier_weights, ifppl_classnames = build_simple_classifier(clip_model, ifppl_texts, lambda c: f'There {c} in this photo.', device)

    model_dict['clip_model'] = clip_model
    model_dict['clip_preprocess'] = clip_preprocess
    model_dict['openimage_classifier_weights'] = openimage_classifier_weights
    model_dict['openimage_classnames'] = openimage_classnames
    model_dict['tencentml_classifier_weights'] = tencentml_classifier_weights
    model_dict['tencentml_classnames'] = tencentml_classnames
    model_dict['place365_classifier_weights'] = place365_classifier_weights
    model_dict['place365_classnames'] = place365_classnames
    model_dict['imgtype_classifier_weights'] = imgtype_classifier_weights
    model_dict['imgtype_classnames'] = imgtype_classnames
    model_dict['ppl_classifier_weights'] = ppl_classifier_weights
    model_dict['ppl_classnames'] = ppl_classnames
    model_dict['ifppl_classifier_weights'] = ifppl_classifier_weights
    model_dict['ifppl_classnames'] = ifppl_classnames
    model_dict['device'] = device

    return model_dict
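# Convert a tensor to a NumPy array, moving it off the GPU first when CUDA is available.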
def drop_gpu(tensor):
    if torch.cuda.is_available():
        return tensor.cpu().numpy()
    else:
        return tensor.numpy()
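# CLIP zero-shot classification: encode the image once, then score it against every
# classifier (OpenImages, Tencent ML-Images, Places365, image type, people, people count).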
def zeroshot_classifier(image):
    image_input = model_dict['clip_preprocess'](image).unsqueeze(0).to(model_dict['device'])

    with torch.no_grad():
        image_features = model_dict['clip_model'].encode_image(image_input)
        image_features /= image_features.norm(dim=-1, keepdim=True)

        sim = (100.0 * image_features @ model_dict['openimage_classifier_weights'].T).softmax(dim=-1)
        openimage_scores, indices = [drop_gpu(tensor) for tensor in sim[0].topk(10)]
        openimage_classes = [model_dict['openimage_classnames'][idx] for idx in indices]

        sim = (100.0 * image_features @ model_dict['tencentml_classifier_weights'].T).softmax(dim=-1)
        tencentml_scores, indices = [drop_gpu(tensor) for tensor in sim[0].topk(10)]
        tencentml_classes = [model_dict['tencentml_classnames'][idx] for idx in indices]

        sim = (100.0 * image_features @ model_dict['place365_classifier_weights'].T).softmax(dim=-1)
        place365_scores, indices = [drop_gpu(tensor) for tensor in sim[0].topk(10)]
        place365_classes = [model_dict['place365_classnames'][idx] for idx in indices]

        sim = (100.0 * image_features @ model_dict['imgtype_classifier_weights'].T).softmax(dim=-1)
        imgtype_scores, indices = [drop_gpu(tensor) for tensor in sim[0].topk(len(model_dict['imgtype_classnames']))]
        imgtype_classes = [model_dict['imgtype_classnames'][idx] for idx in indices]

        sim = (100.0 * image_features @ model_dict['ppl_classifier_weights'].T).softmax(dim=-1)
        ppl_scores, indices = [drop_gpu(tensor) for tensor in sim[0].topk(len(model_dict['ppl_classnames']))]
        ppl_classes = [model_dict['ppl_classnames'][idx] for idx in indices]

        sim = (100.0 * image_features @ model_dict['ifppl_classifier_weights'].T).softmax(dim=-1)
        ifppl_scores, indices = [drop_gpu(tensor) for tensor in sim[0].topk(len(model_dict['ifppl_classnames']))]
        ifppl_classes = [model_dict['ifppl_classnames'][idx] for idx in indices]

    return image_features, openimage_scores, openimage_classes, tencentml_scores, tencentml_classes, \
        place365_scores, place365_classes, imgtype_scores, imgtype_classes, \
        ppl_scores, ppl_classes, ifppl_scores, ifppl_classes
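# Assemble the Socratic prompt from the zero-shot results: image type, people description,
# top-3 candidate places, and a comma-separated list of likely objects.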
def generate_prompt(openimage_classes, tencentml_classes, place365_classes, imgtype_classes, ppl_classes, ifppl_classes):
    img_type = imgtype_classes[0]
    ppl_result = ppl_classes[0]
    if ppl_result == 'people':
        ppl_result = ifppl_classes[0]
    else:
        ppl_result = 'are %s' % ppl_result

    sorted_places = place365_classes

    object_list = ''
    for cls in tencentml_classes:
        object_list += f'{cls}, '
    for cls in openimage_classes[:2]:
        object_list += f'{cls}, '
    object_list = object_list[:-2]

    prompt_caption = f'''I am an intelligent image captioning bot.
This image is a {img_type}. There {ppl_result}.
I think this photo was taken at a {sorted_places[0]}, {sorted_places[1]}, or {sorted_places[2]}.
I think there might be a {object_list} in this {img_type}.
A creative short caption I can generate to describe this image is:'''

    #prompt_search = f'''Let's list keywords that include the following description.
    #This image is a {img_type}. There {ppl_result}.
    #I think this photo was taken at a {sorted_places[0]}, {sorted_places[1]}, or {sorted_places[2]}.
    #I think there might be a {object_list} in this {img_type}.
    #Relevant keywords which we can list and are separated with commas are:'''

    return prompt_caption
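# Query BLOOM through the Hugging Face Inference API; one request per caption,
# keeping only the first sentence of each continuation.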
def generate_captions(prompt, num_captions=3):
    headers = {"Authorization": f"Bearer {HF_TOKEN}"}

    max_length = 16
    seed = 42
    sample_or_greedy = 'Greedy'

    input_sentence = prompt
    if sample_or_greedy == "Sample":
        parameters = {
            "max_new_tokens": max_length,
            "top_p": 0.7,
            "do_sample": True,
            "seed": seed,
            "early_stopping": False,
            "length_penalty": 0.0,
            "eos_token_id": None,
        }
    else:
        parameters = {
            "max_new_tokens": max_length,
            "do_sample": False,
            "seed": seed,
            "early_stopping": False,
            "length_penalty": 0.0,
            "eos_token_id": None,
        }

    payload = {"inputs": input_sentence, "parameters": parameters, "options": {"use_cache": False}}

    bloom_results = []
    for _ in range(num_captions):
        response = requests.post(API_URL, headers=headers, json=payload)
        output = response.json()
        generated_text = output[0]['generated_text'].replace(prompt, '').split('.')[0] + '.'
        bloom_results.append(generated_text)

    return bloom_results
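# Re-rank candidate captions by CLIP image-text similarity, highest score first.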
def sorting_texts(image_features, captions):
    with torch.no_grad():
        text_inputs = clip.tokenize(captions).to(model_dict['device'])
        text_features = model_dict['clip_model'].encode_text(text_inputs)
        text_features /= text_features.norm(dim=-1, keepdim=True)

        sim = (100.0 * image_features @ text_features.T).softmax(dim=-1)
        scores, indices = [drop_gpu(tensor) for tensor in sim[0].topk(len(captions))]
        sorted_captions = [captions[idx] for idx in indices]

    return scores, sorted_captions
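# Pair each rounded score with its class name or caption for the JSON output.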
def postprocess_results(scores, classes):
    scores = [float('%.4f' % float(val)) for val in scores]

    outputs = []
    for score, cls in zip(scores, classes):
        outputs.append({'score': score, 'output': cls})
    return outputs
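# Full pipeline for one image: CLIP zero-shot reasoning -> prompt -> BLOOM -> CLIP re-ranking,
# with per-stage timings recorded in the output.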
def image_captioning(image):
    start_time = time.time()
    image_features, openimage_scores, openimage_classes, tencentml_scores, tencentml_classes, place365_scores, place365_classes, imgtype_scores, imgtype_classes, ppl_scores, ppl_classes, ifppl_scores, ifppl_classes = zeroshot_classifier(image)
    end_zeroshot = time.time()

    prompt_caption = generate_prompt(openimage_classes, tencentml_classes, place365_classes, imgtype_classes, ppl_classes, ifppl_classes)
    generated_captions = generate_captions(prompt_caption, num_captions=1)
    end_bloom = time.time()

    caption_scores, sorted_captions = sorting_texts(image_features, generated_captions)

    output_dict = {}
    output_dict['inference_time'] = {'CLIP inference': end_zeroshot - start_time,
                                     'BLOOM request': end_bloom - end_zeroshot}
    output_dict['generated_captions'] = postprocess_results(caption_scores, sorted_captions)
    output_dict['reasoning'] = {'openimage_results': postprocess_results(openimage_scores, openimage_classes),
                                'tencentml_results': postprocess_results(tencentml_scores, tencentml_classes),
                                'place365_results': postprocess_results(place365_scores, place365_classes),
                                'imgtype_results': postprocess_results(imgtype_scores, imgtype_classes),
                                'ppl_results': postprocess_results(ppl_scores, ppl_classes),
                                'ifppl_results': postprocess_results(ifppl_scores, ifppl_classes)}
    return output_dict
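# Entry point: load the models once, then define and launch the Gradio demo.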
if __name__ == '__main__':
    print('\tinit models')
    global model_dict
    model_dict = load_models()

    # define gradio demo
    inputs = [gr.inputs.Image(type="pil", label="Image")]
    outputs = gr.outputs.JSON()

    title = "Socratic models for image captioning with BLOOM"
    description = """
## Details
**Without any fine-tuning**, we can do image captioning by combining visual-language models (e.g., CLIP, SLIP, ...) with large language models (e.g., GPT, BLOOM, ...).
In this demo, I choose BLOOM as the language model and CLIP ViT-L/14 as the visual-language model.
The image captioning pipeline is as follows:
1. Use the visual-language model to classify whether there are people in the input image, where it was taken, and which objects it contains.
2. Build a prompt from the classification results.
3. Query the BLOOM API with the prompt.
This demo differs slightly from the original method proposed in the Socratic Models paper:
I use not only the Tencent ML-Images class names but also the OpenImages class names, and I adopt BLOOM as the large language model.
If you want a demo that uses GPT-3 from OpenAI, check https://github.com/geonm/socratic-models-demo.
This demo runs on CPU.
"""
    article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2204.00598'>Socratic Models: Composing Zero-Shot Multimodal Reasoning with Language</a></p>"

    examples = ['k21-1.jpg']

    gr.Interface(image_captioning,
                 inputs,
                 outputs,
                 title=title,
                 description=description,
                 article=article,
                 examples=examples,
                 #examples_per_page=50,
                 ).launch()