HimanshuGoyal2004 committed on
Commit 3c8962b · verified · 1 Parent(s): 81a03c8

Uploading Trashify box detection model app.py

.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+ trashify_examples/trashify_example_2.jpeg filter=lfs diff=lfs merge=lfs -text
+ trashify_examples/trashify_example_3.jpeg filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,12 +1,36 @@
---
- title: Trashify Demo V4
- emoji: 📚
- colorFrom: yellow
- colorTo: red
+ title: Trashify Demo V4 🚮
+ emoji: 🗑️
+ colorFrom: purple
+ colorTo: blue
sdk: gradio
- sdk_version: 5.35.0
+ sdk_version: 5.34.0
app_file: app.py
pinned: false
+ license: apache-2.0
---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # 🚮 Trashify Object Detector V4
+
+ Object detection demo to detect `trash`, `bin`, `hand`, `trash_arm`, `not_trash`, `not_bin`, `not_hand`.
+
+ Used as an example to encourage people to clean up their local area.
+
+ If `trash`, `hand` and `bin` are all detected = +1 point.
+
+ ## Dataset
+
+ All Trashify models are trained on a custom hand-labelled dataset of people picking up trash and placing it in a bin.
+
+ The dataset can be found on Hugging Face as [`HimanshuGoyal2004/trashify_manual_labelled_images`](https://huggingface.co/datasets/HimanshuGoyal2004/trashify_manual_labelled_images).
+
+ ## Demos
+
+ * [V1](https://huggingface.co/spaces/HimanshuGoyal2004/trashify_demo_v1) = Fine-tuned [Conditional DETR](https://huggingface.co/docs/transformers/en/model_doc/conditional_detr) model trained *without* data augmentation.
+ * [V2](https://huggingface.co/spaces/HimanshuGoyal2004/trashify_demo_v2) = Fine-tuned Conditional DETR model trained *with* data augmentation.
+ * [V3](https://huggingface.co/spaces/HimanshuGoyal2004/trashify_demo_v3) = Fine-tuned Conditional DETR model trained *with* data augmentation (same as V2) plus an NMS (Non-Maximum Suppression) post-processing step.
+ * [V4](https://huggingface.co/spaces/HimanshuGoyal2004/trashify_demo_v4) = Fine-tuned [RT-DETRv2](https://huggingface.co/docs/transformers/main/en/model_doc/rt_detr_v2) model trained *without* data augmentation or NMS post-processing (current best mAP).
+
+ ## Learn more
+
+ See the full end-to-end code of how this demo was built at [learnhuggingface.com](https://www.learnhuggingface.com/notebooks/hugging_face_object_detection_tutorial).
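The V3 entry above mentions an NMS (Non-Maximum Suppression) post-processing step. As a minimal sketch of that technique (illustrative values, not the V3 Space's actual code), `torchvision.ops.nms` keeps only the highest-scoring box among heavily overlapping detections:

```python
import torch
from torchvision.ops import nms

# Illustrative detections: two heavily overlapping boxes plus one separate box,
# each in [x1, y1, x2, y2] format with a matching confidence score
boxes = torch.tensor([[10.0, 10.0, 100.0, 100.0],
                      [12.0, 12.0, 102.0, 98.0],
                      [200.0, 200.0, 300.0, 300.0]])
scores = torch.tensor([0.90, 0.75, 0.80])

# Discard any box whose IoU with a higher-scoring box exceeds 0.5
keep_indices = nms(boxes=boxes, scores=scores, iou_threshold=0.5)
print(keep_indices)  # tensor([0, 2]) -> the lower-scoring duplicate is suppressed
```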
app.py ADDED
@@ -0,0 +1,176 @@
+
+ # 1. Import the required libraries and packages
+ import gradio as gr
+ import torch
+ from PIL import Image, ImageDraw, ImageFont # could also use torch utilities for drawing
+
+ from transformers import AutoImageProcessor
+ from transformers import AutoModelForObjectDetection
+
+ ### 2. Setup preprocessing and helper functions ###
+
+ # Setup target model path to load
+ # Note: can load from the Hugging Face Hub or from a local path
+ model_save_path = "HimanshuGoyal2004/rt_detrv2_finetuned_trashify_box_detector_v1"
+
+ # Load the model and preprocessor
+ # Because this app.py file runs directly on Hugging Face Spaces, the model is loaded from the Hugging Face Hub
+ image_processor = AutoImageProcessor.from_pretrained(model_save_path)
+ model = AutoModelForObjectDetection.from_pretrained(model_save_path)
+
+ # Set the target device (use CUDA/GPU if it is available)
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ model = model.to(device)
+
+ # Get the id2label dictionary from the model
+ id2label = model.config.id2label
+
+ # Set up a colour dictionary for plotting boxes with different colours
+ color_dict = {
+     "bin": "green",
+     "trash": "blue",
+     "hand": "purple",
+     "trash_arm": "yellow",
+     "not_trash": "red",
+     "not_bin": "red",
+     "not_hand": "red",
+ }
+
+ # Create helper functions for seeing if items from one list are in another
+ def any_in_list(list_a, list_b):
+     "Returns True if *any* item from list_a is in list_b, otherwise False."
+     return any(item in list_b for item in list_a)
+
+ def all_in_list(list_a, list_b):
+     "Returns True if *all* items from list_a are in list_b, otherwise False."
+     return all(item in list_b for item in list_a)
+
+ ### 3. Create function to predict on a given image with a given confidence threshold ###
+ def predict_on_image(image, conf_threshold):
+     # Make sure the model is in eval mode
+     model.eval()
+
+     # Make a prediction on the target image
+     with torch.no_grad():
+         inputs = image_processor(images=[image], return_tensors="pt")
+         model_outputs = model(**inputs.to(device))
+
+     target_sizes = torch.tensor([[image.size[1], image.size[0]]]) # -> [batch_size, height, width]
+
+     # Post-process the raw outputs from the model
+     results = image_processor.post_process_object_detection(model_outputs,
+                                                             threshold=conf_threshold,
+                                                             target_sizes=target_sizes)[0]
+
+     # Move all values in results to the CPU (we'll want this for drawing the outputs on the image)
+     for key, value in results.items():
+         try:
+             results[key] = value.cpu() # tensors can be moved to the CPU
+         except AttributeError:
+             results[key] = value # non-tensor values have no .cpu(), keep them as-is
+
+     ### 4. Draw the predictions on the target image ###
+
+     # Can return results as plotted on a PIL image (then display the image)
+     draw = ImageDraw.Draw(image)
+
+     # Get a font from ImageFont
+     font = ImageFont.load_default(size=20)
+
+     # Get class names as text for print out
+     class_name_text_labels = []
+
+     # Iterate through the predictions of the model and draw them on the target image
+     for box, score, label in zip(results["boxes"], results["scores"], results["labels"]):
+         # Create coordinates
+         x, y, x2, y2 = tuple(box.tolist())
+
+         # Get the label name
+         label_name = id2label[label.item()]
+         targ_color = color_dict[label_name]
+         class_name_text_labels.append(label_name)
+
+         # Draw the rectangle
+         draw.rectangle(xy=(x, y, x2, y2),
+                        outline=targ_color,
+                        width=3)
+
+         # Create a text string to display
+         text_string_to_show = f"{label_name} ({round(score.item(), 3)})"
+
+         # Draw the text on the image
+         draw.text(xy=(x, y),
+                   text=text_string_to_show,
+                   fill="white",
+                   font=font)
+
+     # Remove the draw each time
+     del draw
+
+     # Setup a blank string to print out
+     return_string = ""
+
+     # Setup a list of target items to discover
+     target_items = ["trash", "bin", "hand"]
+
+     ### 5. Create logic for outputting an information message ###
+
+     # If no items were detected, or none of trash, bin, hand are in the list, return a notification
+     if (len(class_name_text_labels) == 0) or not (any_in_list(list_a=target_items, list_b=class_name_text_labels)):
+         return_string = f"No trash, bin or hand detected at confidence threshold {conf_threshold}. Try another image or lowering the confidence threshold."
+         return image, return_string
+
+     # If some are missing, print the ones which are missing
+     elif not all_in_list(list_a=target_items, list_b=class_name_text_labels):
+         missing_items = []
+         for item in target_items:
+             if item not in class_name_text_labels:
+                 missing_items.append(item)
+         return_string = f"Detected the following items: {class_name_text_labels}. But missing the following in order to get +1: {missing_items}. If this is an error, try another image or altering the confidence threshold. Otherwise, the model may need to be updated with better data."
+
+     # If all three of trash, bin, hand occur = +1
+     if all_in_list(list_a=target_items, list_b=class_name_text_labels):
+         return_string = f"+1! Found the following items: {class_name_text_labels}, thank you for cleaning up the area!"
+
+     print(return_string)
+
+     return image, return_string
+
+ ### 6. Setup the demo application to take in an image, make a prediction with our model and return the image with drawn predictions ###
+
+ # Write a description for our demo application
+ description = """
+ Help clean up your local area! Upload an image and get +1 if all of the following items are detected: trash, bin, hand.
+
+ Model is a fine-tuned version of [RT-DETRv2](https://huggingface.co/docs/transformers/main/en/model_doc/rt_detr_v2#transformers.RTDetrV2Config) on the [Trashify dataset](https://huggingface.co/datasets/HimanshuGoyal2004/trashify_manual_labelled_images).
+
+ See the full data loading and training code on [learnhuggingface.com](https://www.learnhuggingface.com/notebooks/hugging_face_object_detection_tutorial).
+
+ This version is v4 because the first three versions used a different model and did not perform as well, see the [README](https://huggingface.co/spaces/HimanshuGoyal2004/trashify_demo_v4/blob/main/README.md) for more.
+ """
+
+ # Create the Gradio interface to accept an image and a confidence threshold and return an image with drawn prediction boxes
+ demo = gr.Interface(
+     fn=predict_on_image,
+     inputs=[
+         gr.Image(type="pil", label="Target Image"),
+         gr.Slider(minimum=0, maximum=1, value=0.3, label="Confidence Threshold")
+     ],
+     outputs=[
+         gr.Image(type="pil", label="Image Output"),
+         gr.Text(label="Text Output")
+     ],
+     title="🚮 Trashify Object Detection Demo V4",
+     description=description,
+     # Examples come as a list of lists, where each inner list contains elements to prefill the `inputs` parameter with
+     # See where the examples originate from here: https://huggingface.co/datasets/HimanshuGoyal2004/trashify_examples/
+     examples=[
+         ["trashify_examples/trashify_example_1.jpeg", 0.3],
+         ["trashify_examples/trashify_example_2.jpeg", 0.3],
+         ["trashify_examples/trashify_example_3.jpeg", 0.3],
+     ],
+     cache_examples=True
+ )
+
+ # Launch the demo
+ demo.launch()
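For reference, a hypothetical local smoke test of `predict_on_image` (assuming the repository is cloned with its example images, and that `demo.launch()` is temporarily commented out so importing `app` doesn't start the server):

```python
from PIL import Image

from app import predict_on_image  # hypothetical local import of the app.py above

# Run a prediction on one of the bundled example images
image = Image.open("trashify_examples/trashify_example_1.jpeg")
annotated_image, message = predict_on_image(image=image, conf_threshold=0.3)

annotated_image.save("annotated_example.jpeg")
print(message)  # e.g. "+1! Found the following items: ..." when trash, bin and hand are all detected
```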
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ timm
+ gradio
+ torch
+ transformers
trashify_examples/trashify_example_1.jpeg ADDED
trashify_examples/trashify_example_2.jpeg ADDED

Git LFS Details

  • SHA256: e1c170311bdc358d5158049f42aa38fba3794c91bcb2d11578f7eb92d924c55c
  • Pointer size: 131 Bytes
  • Size of remote file: 361 kB
trashify_examples/trashify_example_3.jpeg ADDED

Git LFS Details

  • SHA256: 666068a4e4e92384bce54c5f9fa533ccef96da46df065e8760f03d49a04e3fd3
  • Pointer size: 131 Bytes
  • Size of remote file: 278 kB