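# app.py for a Hugging Face Space: zero-shot image segmentation with CLIPSeg.
# The user uploads an image and types comma-separated class names; the app
# renders one predicted heatmap per class next to the original image.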
import io

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
# Load the CLIPSeg processor and segmentation model once at startup.
processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
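# Optional sketch (an assumption, not in the original app): run on GPU when one
# is available. If enabled, the tensors built in segment() below would also
# need to be moved with .to(device).
# device = "cuda" if torch.cuda.is_available() else "cpu"
# model.to(device)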
def visualize_segmentation(image, prompts, preds):
    """Render the input image next to one sigmoid heatmap per prompt."""
    fig, ax = plt.subplots(1, len(prompts) + 1, figsize=(3 * (len(prompts) + 1), 4))
    for a in ax.flatten():
        a.axis('off')
    ax[0].imshow(image)
    for i, prompt in enumerate(prompts):
        # Logits -> probabilities; preds[i][0] is the (H, W) map for prompt i.
        ax[i + 1].imshow(torch.sigmoid(preds[i][0]))
        ax[i + 1].text(0, -15, prompt)
    buf = io.BytesIO()
    fig.savefig(buf, format='png')
    plt.close(fig)
    buf.seek(0)  # rewind the buffer so PIL reads the PNG from the start
    return Image.open(buf)
def segment(img, clases):
    """Gradio callback: segment `img` once for each comma-separated class name."""
    image = Image.fromarray(img, 'RGB')
    prompts = [p.strip() for p in clases.split(',')]
    # One (image, prompt) pair per class; pad the text to the model's max length.
    inputs = processor(text=prompts, images=[image] * len(prompts), padding="max_length", return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    preds = outputs.logits.unsqueeze(1)
    return visualize_segmentation(image, prompts, preds)
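# Quick local test without the UI (a sketch, not part of the original app;
# assumes desayuno.jpg sits next to this file, as in the example below):
# img = np.asarray(Image.open("desayuno.jpg").convert("RGB"))
# segment(img, "cutlery, pancakes").save("preview.png")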
demo = gr.Interface(
    fn=segment,
    inputs=["image", gr.Textbox(label='Enter classes separated by ","')],
    outputs="image",
    examples=[['desayuno.jpg', 'cutlery, pancakes, blueberries, orange juice']],
)
demo.launch()
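# On Spaces, launch() is enough; when running locally, Gradio's standard
# demo.launch(share=True) would serve a temporary public link instead.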