Spaces:
Build error
Build error
| # -------------------------------------------------------- | |
| # X-Decoder -- Generalized Decoding for Pixel, Image, and Language | |
| # Copyright (c) 2022 Microsoft | |
| # Licensed under The MIT License [see LICENSE for details] | |
| # Written by Xueyan Zou ([email protected]) | |
| # -------------------------------------------------------- | |
| import torch | |
| import numpy as np | |
| from PIL import Image | |
| from torchvision import transforms | |
| from utils.visualizer import Visualizer | |
| from detectron2.utils.colormap import random_color | |
| from detectron2.data import MetadataCatalog | |
# Preprocessing pipeline: bicubic resize to 512 (with an int size, torchvision's
# Resize matches the image's shorter edge to that value, keeping aspect ratio).
# InterpolationMode is used instead of the PIL integer constant, which
# torchvision has deprecated for the ``interpolation`` argument.
t = [transforms.Resize(512, interpolation=transforms.InterpolationMode.BICUBIC)]
transform = transforms.Compose(t)
# Metadata of the 'ade20k_panoptic_train' dataset; referring_segmentation
# attaches it to the model and hands it to the Visualizer.
metadata = MetadataCatalog.get('ade20k_panoptic_train')
def _parse_referring_texts(texts):
    """Turn a comma-separated prompt into the nested query list used for grounding.

    Each phrase is stripped of surrounding whitespace *before* the trailing-period
    check, then given a period if it lacks one, and wrapped in a single-element
    list: ``"cat, dog."`` -> ``[['cat.'], ['dog.']]``.
    """
    phrases = [phrase.strip() for phrase in texts.strip().split(',')]
    return [[phrase if phrase.endswith('.') else phrase + '.'] for phrase in phrases]


def referring_segmentation(model, image, texts, inpainting_text, *args, **kwargs):
    """Segment the regions of *image* described by comma-separated *texts*.

    Args:
        model: Wrapper whose ``.model`` provides ``evaluate_grounding`` and a
            writable ``metadata`` attribute (overwritten here with the
            module-level ``metadata``).
        image: PIL image; converted to RGB and resized with the module-level
            ``transform`` before inference.
        texts: Comma-separated referring phrases, e.g. ``"the dog, a red car"``.
        inpainting_text: Unused; kept for interface compatibility.
        *args, **kwargs: Ignored.

    Returns:
        ``(PIL.Image, '', None)``. The image is the resized input with one
        coloured, labelled mask drawn per row of
        ``outputs[0]['grounding_mask']``, or the plain resized image if there
        are no masks. The last two values are fixed placeholders.

    Note:
        Inference tensors are moved with ``.cuda()``, so a CUDA device is required.
    """
    model.model.metadata = metadata
    queries = _parse_referring_texts(texts)

    # NOTE(review): normalise to 3-channel RGB so RGBA / greyscale / palette
    # uploads still yield the (H, W, 3) array permute(2, 0, 1) expects; this is
    # a no-op for images that are already RGB.
    image_ori = transform(image.convert('RGB'))

    with torch.no_grad():
        width, height = image_ori.size
        image_ori_np = np.asarray(image_ori)
        # Copy first: the array backing a PIL image may be read-only, and
        # torch.from_numpy shares memory with its input.
        images = torch.from_numpy(image_ori_np.copy()).permute(2, 0, 1).cuda()

        batch_inputs = [{
            'image': images,
            'height': height,
            'width': width,
            'groundings': {'texts': queries},
        }]
        outputs = model.model.evaluate_grounding(batch_inputs, None)

        visual = Visualizer(image_ori_np, metadata=metadata)
        grd_mask = (outputs[0]['grounding_mask'] > 0).float().cpu().numpy()
        demo = None
        for idx, mask in enumerate(grd_mask):
            # Keep the colour components as floats in [0, 1] (maximum=1);
            # casting them to int would truncate each component to 0 or 1.
            color = random_color(rgb=True, maximum=1).tolist()
            # queries[idx] is a one-element list; pass the phrase itself so
            # the label reads "cat." rather than "['cat.']".
            demo = visual.draw_binary_mask(mask, color=color, text=queries[idx][0])
        # With no masks there is nothing to overlay: return the resized image
        # rather than failing on an unbound ``demo``.
        res = image_ori_np if demo is None else demo.get_image()

    torch.cuda.empty_cache()
    return Image.fromarray(res), '', None