Spaces:
Running
on
Zero
Running
on
Zero
# 1. Import the required dependencies | |
import gradio as gr | |
import torch | |
import spaces # for GPU usage | |
from PIL import Image, ImageDraw, ImageFont | |
from transformers import AutoImageProcessor, AutoModelForObjectDetection | |
# 2. Setup preprocessing and model functions
# Hugging Face Hub repo id of the fine-tuned RT-DETRv2 Trashify box detector
model_save_path = "mrdbourke/rt_detrv2_finetuned_trashify_box_detector_v1"

# Image processor, forced to a fixed 640x640 input size for simplicity
# (this also copes with strangely shaped images)
image_processor = AutoImageProcessor.from_pretrained(model_save_path)
image_processor.size = dict(height=640, width=640)

# Choose where the model runs: GPU if one is visible, otherwise CPU.
# For a GPU in a Space, see ZeroGPU: https://huggingface.co/docs/hub/en/spaces-zerogpu
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the detection model and move it straight onto the chosen device
model = AutoModelForObjectDetection.from_pretrained(model_save_path).to(device)

# Class-id -> class-name mapping stored in the model config
# (looked up later with integer ids taken from the predicted label tensor)
id2label = model.config.id2label

# Box outline colour for each class name; every "not_*" class shares red
color_dict = dict(
    bin="green",
    trash="blue",
    hand="purple",
    trash_arm="yellow",
    not_trash="red",
    not_bin="red",
    not_hand="red",
)
# 3. Create a function predict_on_image (plus small drawing/message helpers)


def _draw_detections(image, results):
    """Draw a labelled box for every detection onto ``image`` in place.

    Args:
        image: PIL image to annotate (modified in place).
        results: Post-processed detection dict whose "boxes" (XYXY),
            "scores" and "labels" entries are CPU tensors.

    Returns:
        List of the class name of every drawn box, in detection order.
    """
    draw = ImageDraw.Draw(image)
    font = ImageFont.load_default(size=20)
    detected_labels = []
    for box, score, label in zip(results["boxes"], results["scores"], results["labels"]):
        x, y, x2, y2 = box.tolist()  # XYXY
        label_name = id2label[label.item()]
        detected_labels.append(label_name)
        # Unknown labels fall back to red instead of failing the whole request
        box_color = color_dict.get(label_name, "red")
        draw.rectangle(xy=(x, y, x2, y2), outline=box_color, width=3)
        draw.text(
            xy=(x, y),
            text=f"{label_name} ({round(score.item(), 4)})",
            fill="white",
            font=font,
        )
    # Drop the drawing context explicitly so it isn't kept alive in memory
    del draw
    return detected_labels


def _summarize_detections(detected_labels, conf_threshold):
    """Build the status message for the detected class names.

    Args:
        detected_labels: Iterable of detected class names (duplicates allowed).
        conf_threshold: Confidence threshold used; echoed when nothing is found.

    Returns:
        A "nothing found" hint, a list of missing target items, or a "+1!"
        message once trash, bin and hand are all detected.
    """
    target_items = {"trash", "bin", "hand"}
    detected_items = set(detected_labels)
    found_targets = detected_items & target_items

    if not found_targets:
        return (
            f"No trash, bin or hand detected at confidence threshold {conf_threshold}. "
            "Try another image or lowering the confidence threshold."
        )

    missing_items = target_items - detected_items
    if missing_items:
        # Sorted so the message is stable; raw set ordering can vary between runs
        return (
            f"Detected the following items: {sorted(found_targets)}. "
            f"Missing the following: {sorted(missing_items)}. "
            "In order to get +1 point, all target items must be detected."
        )

    return f"+1! Found the following items: {sorted(detected_items)}, thank you for cleaning up your local area!"


# Request a GPU for this function on ZeroGPU Spaces; the decorator has no effect elsewhere
@spaces.GPU
def predict_on_image(image, conf_threshold):
    """Detect trash, bins and hands in an image and annotate the result.

    Args:
        image: PIL image from the Gradio input, or None if nothing was
            uploaded. Boxes and labels are drawn onto it in place.
        conf_threshold: Minimum detection score (0-1) for a box to be kept.

    Returns:
        Tuple of (annotated image, status message). The image is None when
        no input image was provided.
    """
    if image is None:
        return None, "Please upload an image first."

    model.eval()
    with torch.no_grad():
        inputs = image_processor(images=[image], return_tensors="pt")
        model_outputs = model(**inputs.to(device))

    # PIL's Image.size is (width, height), but post_process_object_detection
    # wants (height, width) per image -> shape [batch_size, 2]
    target_sizes = torch.tensor([[image.size[1], image.size[0]]])

    results = image_processor.post_process_object_detection(
        model_outputs,
        threshold=conf_threshold,
        target_sizes=target_sizes,
    )[0]

    # Bring every tensor back to the CPU for drawing; pass non-tensors through
    results = {
        key: value.cpu() if isinstance(value, torch.Tensor) else value
        for key, value in results.items()
    }

    detected_labels = _draw_detections(image, results)
    return_string = _summarize_detections(detected_labels, conf_threshold)
    print(return_string)
    return image, return_string
### 6. Setup the demo application to take in image/conf threshold, pass it through our function, show the output image/text
description = """
Help clean up your local area! Upload an image and get +1 if all of the following items are detected: trash, bin, hand.
Model is a fine-tuned version of [RT-DETRv2](https://huggingface.co/docs/transformers/main/en/model_doc/rt_detr_v2#transformers.RTDetrV2Config) on the [Trashify dataset](https://huggingface.co/datasets/mrdbourke/trashify_manual_labelled_images).
See the full data loading and training code on [learnhuggingface.com](https://www.learnhuggingface.com/notebooks/hugging_face_object_detection_tutorial).
This version is v4 because the first three versions were using a different model and did not perform as well, see the [README](https://huggingface.co/spaces/mrdbourke/trashify_demo_v4/blob/main/README.md) for more.
"""

# Inputs: the image to analyse plus the confidence threshold for kept boxes
demo_inputs = [
    gr.Image(type="pil", label="Target Input Image"),
    gr.Slider(minimum=0, maximum=1, value=0.3, label="Confidence Threshold (set higher for more confident boxes)"),
]

# Outputs: the annotated image and the status message from predict_on_image
demo_outputs = [
    gr.Image(type="pil", label="Target Image Output"),
    gr.Text(label="Text Output"),
]

# Bundled example images, each run at the default 0.3 threshold
demo_examples = [
    [f"trashify_examples/trashify_example_{i}.jpeg", 0.3]
    for i in range(1, 4)
]

# Create the Gradio interface (example outputs are precomputed via cache_examples)
demo = gr.Interface(
    fn=predict_on_image,
    inputs=demo_inputs,
    outputs=demo_outputs,
    description=description,
    title="🚮 Trashify Object Detection Demo V4 - Video",
    examples=demo_examples,
    cache_examples=True,
)

# Launch demo (use demo.launch(debug=True) to see errors in Google Colab)
demo.launch()