|
|
|
|
|
import gradio as gr |
|
import torch |
|
from PIL import Image, ImageDraw, ImageFont |
|
|
|
from transformers import AutoImageProcessor |
|
from transformers import AutoModelForObjectDetection |
|
|
|
|
|
|
|
|
|
|
|
# Identifier (Hub repo id or local path) passed to ``from_pretrained`` for both
# the image processor and the detection model, so the two are loaded together.
model_save_path = "HimanshuGoyal2004/rt_detrv2_finetuned_trashify_box_detector_v1"

# Prefer the GPU when CUDA is available; otherwise run on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Preprocessing pipeline and detector from the same checkpoint; the model is
# moved to ``device`` immediately after loading.
image_processor = AutoImageProcessor.from_pretrained(model_save_path)
model = AutoModelForObjectDetection.from_pretrained(model_save_path).to(device)

# Class-id -> class-name mapping stored in the model's config; used to turn
# predicted label ids into readable names.
id2label = model.config.id2label
|
|
|
|
|
# Box outline colour for each class name the detector can emit.
# All of the negative ("not_*") classes share the same red outline.
color_dict = dict(
    bin="green",
    trash="blue",
    hand="purple",
    trash_arm="yellow",
    not_trash="red",
    not_bin="red",
    not_hand="red",
)
|
|
|
|
|
def any_in_list(list_a, list_b):
    """Return True if at least one element of ``list_a`` is found in ``list_b``.

    Elements are checked in order with the ``in`` operator and the search
    stops at the first hit. An empty ``list_a`` gives False.
    """
    for candidate in list_a:
        if candidate in list_b:
            return True
    return False
|
|
|
def all_in_list(list_a, list_b):
    """Return True if every element of ``list_a`` is found in ``list_b``.

    Elements are checked in order with the ``in`` operator and the search
    stops at the first miss. An empty ``list_a`` gives True (vacuous truth).
    """
    for candidate in list_a:
        if candidate not in list_b:
            return False
    return True
|
|
|
|
|
def _run_detector(image, conf_threshold):
    """Run the detection model on one PIL image and post-process its output.

    Args:
        image: PIL image to run inference on.
        conf_threshold: Minimum score a detection needs in order to be kept;
            forwarded as ``threshold`` to ``post_process_object_detection``.

    Returns:
        The single result dict returned by ``post_process_object_detection``
        for this image (downstream code reads its "boxes", "scores" and
        "labels" entries), with every tensor value moved to the CPU.
    """
    model.eval()
    with torch.no_grad():
        inputs = image_processor(images=[image], return_tensors="pt")
        model_outputs = model(**inputs.to(device))

    # PIL reports size as (width, height); target_sizes is built as (height, width).
    target_sizes = torch.tensor([[image.size[1], image.size[0]]])
    results = image_processor.post_process_object_detection(
        model_outputs,
        threshold=conf_threshold,
        target_sizes=target_sizes,
    )[0]

    # Bring tensors back to the CPU so they can be read with .tolist()/.item();
    # any non-tensor values are passed through unchanged.
    return {
        key: value.cpu() if isinstance(value, torch.Tensor) else value
        for key, value in results.items()
    }


def _draw_detections(image, results):
    """Draw a coloured box and a "name (score)" caption for each detection.

    ``image`` is modified in place.

    Args:
        image: PIL image to draw on.
        results: Dict with parallel "boxes", "scores" and "labels" tensors;
            each box is unpacked as (x, y, x2, y2) corner coordinates.

    Returns:
        The class name of every drawn detection, in ``results`` order.
    """
    draw = ImageDraw.Draw(image)
    font = ImageFont.load_default(size=20)
    label_names = []

    for box, score, label in zip(results["boxes"], results["scores"], results["labels"]):
        x, y, x2, y2 = box.tolist()
        label_name = id2label[label.item()]
        label_names.append(label_name)

        # A class name missing from color_dict gets a fallback colour instead
        # of aborting the whole prediction with a KeyError.
        box_color = color_dict.get(label_name, "white")
        draw.rectangle(xy=(x, y, x2, y2), outline=box_color, width=3)
        draw.text(
            xy=(x, y),
            text=f"{label_name} ({round(score.item(), 3)})",
            fill="white",
            font=font,
        )

    return label_names


def _summarize_detections(label_names, conf_threshold):
    """Build the user-facing message for the detected ``label_names``.

    A "+1" is awarded only when trash, bin and hand are all present.

    Args:
        label_names: Class names of the kept detections.
        conf_threshold: Threshold used, echoed back when nothing relevant
            was found.

    Returns:
        The message string shown in the text output.
    """
    target_items = ["trash", "bin", "hand"]

    if not any_in_list(list_a=target_items, list_b=label_names):
        return (
            f"No trash, bin or hand detected at confidence threshold {conf_threshold}. "
            "Try another image or lowering the confidence threshold."
        )

    if all_in_list(list_a=target_items, list_b=label_names):
        return f"+1! Found the following items: {label_names}, thank you for cleaning up the area!"

    missing_items = [item for item in target_items if item not in label_names]
    return (
        f"Detected the following items: {label_names}. "
        f"But missing the following in order to get +1: {missing_items}. "
        "If this is an error, try another image or altering the confidence threshold. "
        "Otherwise, the model may need to be updated with better data."
    )


def predict_on_image(image, conf_threshold):
    """Detect trash, bins and hands in ``image`` and report whether it earns a +1.

    Args:
        image: PIL image to analyse, or ``None`` when no image was supplied.
        conf_threshold: Minimum detection score to keep (0-1 in the UI slider).

    Returns:
        ``(annotated_image, message)``: an RGB copy of ``image`` with the
        kept detections drawn on it (``None`` if no image was given), and a
        human-readable summary string.
    """
    if image is None:
        # Guard against a missing image (e.g. submitting with no upload)
        # instead of failing later on ``image.size``.
        return None, "No image provided. Please upload an image and try again."

    # convert() returns a new image: normalising to RGB here also means the
    # drawing below never modifies the caller's image object.
    image = image.convert("RGB")

    results = _run_detector(image, conf_threshold)
    label_names = _draw_detections(image, results)
    return_string = _summarize_detections(label_names, conf_threshold)

    print(return_string)

    return image, return_string
|
|
|
|
|
|
|
|
|
# Markdown shown under the title of the Gradio interface. Fixes the
# ungrammatical "+1" rule sentence and capitalisation; all links are unchanged.
description = """
Help clean up your local area! Upload an image and get +1 if all of the following items are detected: trash, bin, hand.

The model is a fine-tuned version of [RT-DETRv2](https://huggingface.co/docs/transformers/main/en/model_doc/rt_detr_v2#transformers.RTDetrV2Config) on the [Trashify dataset](https://huggingface.co/datasets/mrdbourke/trashify_manual_labelled_images).

See the full data loading and training code in this [Google Colab notebook](https://colab.research.google.com/drive/1BBMETl2eSEhcj0oTvuOq4mg4Doocibth?usp=sharing).

See the [README](https://huggingface.co/spaces/HimanshuGoyal2004/trashify_demo_v4/blob/main/README.md) for more.
"""
|
|
|
|
|
def _build_demo():
    """Assemble the Gradio interface around ``predict_on_image``.

    Inputs are a PIL image and a 0-1 confidence slider (default 0.3). Outputs
    are the annotated image and the summary text. The three bundled example
    images are passed with ``cache_examples=True``.
    """
    input_components = [
        gr.Image(type="pil", label="Target Image"),
        gr.Slider(minimum=0, maximum=1, value=0.3, label="Confidence Threshold"),
    ]
    output_components = [
        gr.Image(type="pil", label="Image Output"),
        gr.Text(label="Text Output"),
    ]
    example_rows = [
        ["trashify_examples/trashify_example_1.jpeg", 0.3],
        ["trashify_examples/trashify_example_2.jpeg", 0.3],
        ["trashify_examples/trashify_example_3.jpeg", 0.3],
    ]
    return gr.Interface(
        fn=predict_on_image,
        inputs=input_components,
        outputs=output_components,
        title="🚮 Trashify Object Detection Demo",
        description=description,
        examples=example_rows,
        cache_examples=True,
    )


demo = _build_demo()

demo.launch()
|
|