Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import spaces | |
| import torch | |
| from PIL import Image | |
| import requests | |
| from transformers import DetrImageProcessor | |
| from transformers import DetrForObjectDetection | |
| from random import choice | |
| import matplotlib.pyplot as plt | |
| import io | |
# Identifier of the pretrained DETR checkpoint; the image processor and the
# detection model are both loaded from it so that they stay in sync.
_CHECKPOINT = "facebook/detr-resnet-50"

processor = DetrImageProcessor.from_pretrained(_CHECKPOINT)
model = DetrForObjectDetection.from_pretrained(_CHECKPOINT)

# Colour triples (components in the 0-1 range) that the plotting code passes
# as the ``color`` of the detection rectangles.
COLORS = [
    [0.000, 0.447, 0.741],
    [0.850, 0.325, 0.098],
    [0.929, 0.694, 0.125],
    [0.494, 0.184, 0.556],
    [0.466, 0.674, 0.188],
    [0.301, 0.745, 0.933],
]
def get_output_figure(pil_img, scores, labels, boxes):
    """Draw the detected boxes and their labels on top of the input image.

    Args:
        pil_img: Image displayed as the background of the plot.
        scores: Per-detection confidence values; must support ``.tolist()``.
        labels: Per-detection class ids, looked up in
            ``model.config.id2label``; must support ``.tolist()``.
        boxes: Per-detection ``(xmin, ymin, xmax, ymax)`` corners, in the
            pixel coordinates of ``pil_img``; must support ``.tolist()``.

    Returns:
        The matplotlib figure holding the annotated image. The figure stays
        registered with pyplot, so the caller should close it once rendered.
    """
    # Work on explicit figure/axes objects rather than pyplot's implicit
    # "current figure": plt.gca()/plt.gcf() would return whatever figure
    # happens to be current when they are called.
    fig, ax = plt.subplots(figsize=(16, 10))
    ax.imshow(pil_img)

    detections = zip(scores.tolist(), labels.tolist(), boxes.tolist())
    for i, (score, label, (xmin, ymin, xmax, ymax)) in enumerate(detections):
        # Wrap around the palette so that no detection is ever dropped for
        # lack of a colour.
        color = COLORS[i % len(COLORS)]
        ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                                   fill=False, color=color, linewidth=3))
        text = f'{model.config.id2label[label]}: {score:0.2f}'
        ax.text(xmin, ymin, text, fontsize=15,
                bbox=dict(facecolor='yellow', alpha=0.5))

    ax.axis('off')
    return fig
def get_output_attn_figure(image, encoding, results, outputs):
    """Plot each detection's decoder cross-attention map above its bounding box.

    The model is run a second time with ``output_attentions=True`` to obtain
    the cross-attention weights of the last decoder layer.

    Args:
        image: Image shown underneath each bounding box (bottom row).
        encoding: Processor output; unpacked as keyword arguments to ``model``.
        results: Post-processed detections; only ``results['boxes']`` is read.
        outputs: Model outputs of the detection pass; only ``outputs.logits``
            is read, to recover which queries passed the 0.9 threshold.

    Returns:
        A matplotlib figure with two rows and one column per detection
        (attention map on top, boxed image below). When there are no
        detections, a single-axes figure with a short notice is returned
        instead. The caller should close the figure once rendered.
    """
    # keep only predictions of queries with +0.9 confidence (excluding no-object class)
    probas = outputs.logits.softmax(-1)[0, :, :-1]
    keep = probas.max(-1).values > 0.9
    bboxes_scaled = results['boxes']

    # plt.subplots() rejects ncols=0, so handle "nothing detected" up front
    # (this also skips the extra forward pass).
    if len(bboxes_scaled) == 0:
        fig, ax = plt.subplots(figsize=(22, 7))
        ax.text(0.5, 0.5, 'No objects detected', ha='center', va='center',
                fontsize=20)
        ax.axis('off')
        return fig

    # capture the backbone output through a forward hook
    conv_features = []
    hook = model.model.backbone.conv_encoder.register_forward_hook(
        lambda module, inputs, output: conv_features.append(output)
    )
    try:
        # Inference only: no autograd graph is needed for plotting.
        with torch.no_grad():
            attn_outputs = model(**encoding, output_attentions=True)
    finally:
        # Always detach the hook, even if the forward pass fails, so it does
        # not stay attached to the shared model.
        hook.remove()

    # cross-attention weights of the last decoder layer, of shape
    # (batch_size, num_heads, num_queries, width*height), averaged over heads
    dec_attn_weights = attn_outputs.cross_attentions[-1].mean(dim=1)
    # spatial shape of the last feature map captured from the backbone
    h, w = conv_features[0][-1][0].shape[-2:]

    # squeeze=False keeps `axs` two-dimensional even for a single detection,
    # so every column of `axs.T` is a (top, bottom) pair of axes.
    fig, axs = plt.subplots(ncols=len(bboxes_scaled), nrows=2,
                            figsize=(22, 7), squeeze=False)
    query_ids = keep.nonzero().flatten().tolist()
    for idx, (ax_top, ax_bottom), box in zip(query_ids, axs.T, bboxes_scaled):
        # .tolist() works regardless of the device the tensor lives on
        xmin, ymin, xmax, ymax = box.tolist()

        ax_top.imshow(dec_attn_weights[0, idx].reshape(h, w).cpu())
        ax_top.axis('off')
        ax_top.set_title(f'query id: {idx}')

        ax_bottom.imshow(image)
        ax_bottom.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                                          fill=False, color='blue', linewidth=3))
        ax_bottom.axis('off')
        ax_bottom.set_title(model.config.id2label[probas[idx].argmax().item()])

    fig.tight_layout()
    return fig
def _figure_to_pil(fig):
    """Render a matplotlib figure to a PIL image and close the figure."""
    buf = io.BytesIO()
    try:
        fig.savefig(buf, bbox_inches='tight')
    finally:
        # pyplot keeps every figure alive until it is closed explicitly;
        # without this each call would leak the figure.
        plt.close(fig)
    buf.seek(0)
    rendered = Image.open(buf)
    # Image.open is lazy; decode now so the image is fully materialised.
    rendered.load()
    return rendered


def detect(image):
    """Run DETR on an image and render the detections and attention maps.

    Args:
        image: PIL image to analyse.

    Returns:
        A ``(prediction_image, attention_image)`` tuple of PIL images: the
        input annotated with the detected boxes, and the attention figure
        produced by ``get_output_attn_figure``.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        raise gr.Error("Please provide an input image.")

    encoding = processor(image, return_tensors='pt')
    with torch.no_grad():
        outputs = model(**encoding)

    # post-processing expects target sizes as (height, width)
    width, height = image.size
    results = processor.post_process_object_detection(
        outputs, target_sizes=[(height, width)], threshold=0.9
    )[0]

    output_pil_img = _figure_to_pil(
        get_output_figure(image, results['scores'], results['labels'], results['boxes'])
    )
    output_pil_img_attn = _figure_to_pil(
        get_output_attn_figure(image, encoding, results, outputs)
    )
    return output_pil_img, output_pil_img_attn
# UI definition: a Blocks container hosting a single Interface that passes the
# uploaded image to detect() and displays the two images it returns.
# (A Markdown title/description that had been disabled by wrapping it in a
# string literal was dead code and has been dropped.)
with gr.Blocks() as demo:
    gr.Interface(
        fn=detect,
        inputs=gr.Image(label="Input image", type="pil"),
        outputs=[
            gr.Image(label="Output prediction", type="pil"),
            gr.Image(label="Attention weights", type="pil"),
        ],
    )

demo.launch()