# 1. Import the required dependencies
import gradio as gr
import torch 
import spaces # for GPU usage

from PIL import Image, ImageDraw, ImageFont 

from transformers import AutoImageProcessor, AutoModelForObjectDetection

# 2. Set up the image processor and model
model_save_path = "mrdbourke/rt_detrv2_finetuned_trashify_box_detector_v1"

# Load image processor
image_processor = AutoImageProcessor.from_pretrained(model_save_path)

# Default the input size to 640x640 for simplicity (this also handles oddly shaped images)
image_processor.size = {"height": 640, 
                        "width": 640}

# Load the model
model = AutoModelForObjectDetection.from_pretrained(model_save_path)

# Setup the target device (use GPU if it's accessible)
# Note: if you want to use a GPU in your Space, you can use ZeroGPU: https://huggingface.co/docs/hub/en/spaces-zerogpu 
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

# Get the id2label dictionary from the model
id2label = model.config.id2label
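# id2label maps integer class ids to human-readable class names,
# e.g. something like {0: "bin", 1: "hand", ...} (exact ids depend on the fine-tuned model's config)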

# Setup a color dictionary for pretty drawings
color_dict = {
    "bin": "green",
    "trash": "blue",
    "hand": "purple",
    "trash_arm": "yellow",
    "not_trash": "red",
    "not_bin": "red",
    "not_hand": "red"
}

# 3. Create a predict_on_image function
# The @spaces.GPU decorator requests a GPU for this function when running on a ZeroGPU Space
@spaces.GPU
def predict_on_image(image, conf_threshold):
  model.eval()

  # Make a prediction on target image
  with torch.no_grad():
    inputs = image_processor(images=[image], 
                             return_tensors="pt")
    model_outputs = model(**inputs.to(device))

    # Get original size of image
    # PIL.Image.size = width, height
    # But post_process_object_detection requires height, width
    target_sizes = torch.tensor([[image.size[1], image.size[0]]]) # -> [batch_size, height, width]

    print(target_sizes)

    # Post process the raw outputs from the model
    results = image_processor.post_process_object_detection(model_outputs,
                                                            threshold=conf_threshold,
                                                            target_sizes=target_sizes)[0]
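    # results is a dict of tensors, roughly of the form:
    # {"scores": tensor([0.97, ...]), "labels": tensor([0, ...]), "boxes": tensor([[x1, y1, x2, y2], ...])}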
  
  # Move all result tensors back to the CPU (if they aren't already there)
  for key, value in results.items():
    results[key] = value.cpu()
  
  ### 4. Draw the predictions on the target image ###
  draw = ImageDraw.Draw(image)

  # Get a font to write on our image
  font = ImageFont.load_default(size=20)

  # Keep a list of the detected class names
  detected_class_names_text_labels = []

  # Iterate through the predictions of the model and draw them on the target image
  for box, score, label in zip(results["boxes"], results["scores"], results["labels"]):
    # Create the coordinates
    x, y, x2, y2 = box.tolist() # box coordinates are in XYXY format

    # Get the text-based label
    label_name = id2label[label.item()]
    targ_color = color_dict[label_name]
    detected_class_names_text_labels.append(label_name)

    # Draw the rectangle
    draw.rectangle(xy=(x, y, x2, y2),
                   outline=targ_color,
                   width=3)
    
    # Create the text to display on the box
    text_string_to_show = f"{label_name} ({round(score.item(), 4)})"

    # Draw the text on the image
    draw.text(xy=(x, y),
              text=text_string_to_show,
              fill="white",
              font=font)

  # Delete the draw object once finished so it doesn't linger in memory
  del draw

  ### 5. Create the logic for the output information message ###
  
  # Setup set of target items to discover
  target_items = {"trash", "bin", "hand"}
  detected_items = set(detected_class_names_text_labels)

  # If none of the target items (bin, trash, hand) were detected, return a notification
  if not detected_items & target_items:
    return_string = (
        f"No trash, bin or hand detected at confidence threshold {conf_threshold}. "
        "Try another image or lowering the confidence threshold."
    )
    print(return_string)
    return image, return_string

  # If some target items are missing, report which ones are needed to get the +1 point
  missing_items = target_items - detected_items
  if missing_items:
    return_string =  (
        f"Detected the following items: {sorted(detected_items & target_items)}. "
        f"Missing the following: {missing_items}. "
        "In order to get +1 point, all target items must be detected."
    )
    print(return_string)
    return image, return_string
  
  # Final case, all items are detected
  return_string = f"+1! Found the following items: {sorted(detected_items)}, thank you for cleaning up your local area!"
  print(return_string)
  return image, return_string
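
# Optional local sanity check (a sketch; assumes an image file "test_image.jpeg" exists locally):
# out_image, out_text = predict_on_image(Image.open("test_image.jpeg"), conf_threshold=0.3)
# print(out_text)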

### 6. Set up the demo application to take in an image/confidence threshold, pass it through our function and show the output image/text

description = """
Help clean up your local area! Upload an image and get +1 if all of the following items are detected: trash, bin, hand.

Model is a fine-tuned version of [RT-DETRv2](https://huggingface.co/docs/transformers/main/en/model_doc/rt_detr_v2#transformers.RTDetrV2Config) on the [Trashify dataset](https://huggingface.co/datasets/mrdbourke/trashify_manual_labelled_images).

See the full data loading and training code on [learnhuggingface.com](https://www.learnhuggingface.com/notebooks/hugging_face_object_detection_tutorial).

This version is v4 because the first three versions used a different model and did not perform as well; see the [README](https://huggingface.co/spaces/mrdbourke/trashify_demo_v4/blob/main/README.md) for more details.
"""

# Create the Gradio interface
demo = gr.Interface(
    fn=predict_on_image,
    inputs=[
        gr.Image(type="pil", label="Target Input Image"),
        gr.Slider(minimum=0, maximum=1, value=0.3, label="Confidence Threshold (set higher for more confident boxes)")
    ],
    outputs=[
        gr.Image(type="pil", label="Target Image Output"),
        gr.Text(label="Text Output")
    ],
    description=description,
    title="🚮 Trashify Object Detection Demo V4 - Video", 
    examples=[
        ["trashify_examples/trashify_example_1.jpeg", 0.3],
        ["trashify_examples/trashify_example_2.jpeg", 0.3],
        ["trashify_examples/trashify_example_3.jpeg", 0.3],
    ],
    cache_examples=True
)
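# Note: cache_examples=True runs predict_on_image over each example when the Space builds and caches the outputs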

# Launch demo
# demo.launch(debug=True) # run with debug=True to see errors in Google Colab
demo.launch()
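# demo.launch(share=True) # when running locally, share=True creates a temporary public link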