import argparse

import requests
import gradio as gr
import numpy as np
import cv2
import torch
import torch.nn as nn
from PIL import Image
from pathlib import Path
from torchvision import transforms
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.data import create_transform

from config import get_config
from model import build_model

# Download human-readable labels for ImageNet.
response = requests.get("https://git.io/JJkYN")
labels = response.text.split("\n")
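# Note: `labels` is not referenced again in this script; the zero-shot candidate
# classes come from the user-supplied, semicolon-separated text prompts instead.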

def parse_option():
    parser = argparse.ArgumentParser('UniCL demo script', add_help=False)
    parser.add_argument('--cfg', type=str, default="configs/unicl_swin_base.yaml",
                        metavar="FILE", help='path to config file')
    args, unparsed = parser.parse_known_args()
    config = get_config(args)
    return args, config

def build_transforms(img_size, center_crop=True):
    t = [transforms.ToPILImage()]
    if center_crop:
        size = int((256 / 224) * img_size)
        t.append(transforms.Resize(size))
        t.append(transforms.CenterCrop(img_size))
    else:
        t.append(transforms.Resize(img_size))
    t.append(transforms.ToTensor())
    t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD))
    return transforms.Compose(t)
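# For img_size=224 with center_crop=True (as used below), this resolves to:
#   ToPILImage -> Resize(256) -> CenterCrop(224) -> ToTensor
#   -> Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)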

def build_transforms4display(img_size, center_crop=True):
    t = [transforms.ToPILImage()]
    if center_crop:
        size = int((256 / 224) * img_size)
        t.append(transforms.Resize(size))
        t.append(transforms.CenterCrop(img_size))
    else:
        t.append(transforms.Resize(img_size))
    t.append(transforms.ToTensor())
    return transforms.Compose(t)
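# The display transform skips Normalize so pixel values stay in [0, 1],
# which is what show_cam_on_image() below expects for its base image.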

args, config = parse_option()

'''
build model
'''
model = build_model(config)

url = './in21k_yfcc14m_gcc15m_swin_base.pth'
checkpoint = torch.load(url, map_location="cpu")
model.load_state_dict(checkpoint["model"])
model.eval()
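# The checkpoint is read from a local path, so in21k_yfcc14m_gcc15m_swin_base.pth
# is assumed to sit next to this script (e.g. committed to the Space repo or
# downloaded in a build step).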

'''
build data transform
'''
eval_transforms = build_transforms(224, center_crop=True)
display_transforms = build_transforms4display(224, center_crop=True)

'''
build upsampler
'''
# upsampler = nn.Upsample(scale_factor=16, mode='bilinear')

'''
code borrowed from: https://github.com/jacobgil/pytorch-grad-cam/blob/master/pytorch_grad_cam/utils/image.py
'''

def show_cam_on_image(img: np.ndarray,
                      mask: np.ndarray,
                      use_rgb: bool = False,
                      colormap: int = cv2.COLORMAP_JET) -> np.ndarray:
    """ This function overlays the cam mask on the image as a heatmap.
    By default the heatmap is in BGR format.
    :param img: The base image in RGB or BGR format.
    :param mask: The cam mask.
    :param use_rgb: Whether to use an RGB or BGR heatmap; set this to True if 'img' is in RGB format.
    :param colormap: The OpenCV colormap to be used.
    :returns: The image with the cam overlay.
    """
    heatmap = cv2.applyColorMap(np.uint8(255 * mask), colormap)
    if use_rgb:
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    heatmap = np.float32(heatmap) / 255

    if np.max(img) > 1:
        raise Exception(
            "The input image should be np.float32 in the range [0, 1]")

    cam = 0.7 * heatmap + 0.3 * img
    # cam = cam / np.max(cam)
    return np.uint8(255 * cam)

def recognize_image(image, texts):
    img_t = eval_transforms(image)
    img_d = display_transforms(image).permute(1, 2, 0).numpy()

    text_embeddings = model.get_text_embeddings(texts.split(';'))

    # compute output
    feat_img, feat_map, H, W = model.encode_image(img_t.unsqueeze(0), output_map=True)
    output = model.logit_scale.exp() * feat_img @ text_embeddings.t()
    prediction = output.softmax(-1).flatten()

    # generate feat map given the top matched texts
    output_map = (feat_map * text_embeddings[prediction.argmax()].unsqueeze(-1)).sum(1).softmax(-1)
    output_map = output_map.view(1, 1, H, W)
    output_map = nn.Upsample(size=img_t.shape[1:], mode='bilinear')(output_map)
    output_map = output_map.squeeze(1).detach().permute(1, 2, 0).numpy()
    output_map = (output_map - output_map.min()) / (output_map.max() - output_map.min())

    heatmap = show_cam_on_image(img_d, output_map, use_rgb=True)
    show_img = np.concatenate((np.uint8(255 * img_d), heatmap), 1)

    return {texts.split(';')[i]: float(prediction[i]) for i in range(len(texts.split(';')))}, Image.fromarray(show_img)
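# Hypothetical local sanity check (not executed by the Space); assumes the
# elephants.png example image from the gallery below is available on disk:
#   img = np.array(Image.open("./elephants.png").convert("RGB"))
#   scores, overlay = recognize_image(img, "an elephant; a dog; a cat")
#   print(scores)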

image = gr.Image()
label = gr.Label(num_top_classes=100)

description = "UniCL for Zero-shot Image Recognition. Given an image, our model maps it to an arbitrary text in a candidate pool."

gr.Interface(
    description=description,
    fn=recognize_image,
    inputs=[image, "text"],
    outputs=[
        label,
        gr.Image(type="pil", label="crop input/heat map"),
    ],
    examples=[
        ["./elephants.png", "an elephant; an elephant walking in the river; four elephants walking in the river"],
        ["./apple_with_ipod.jpg", "an ipod; an apple with a written note 'ipod'; an apple"],
        ["./crowd2.jpg", "a street; a street with a woman walking in the middle; a street with a man walking in the middle"],
    ],
    article=Path("docs/intro.md").read_text()
).launch()