Update app.py
app.py CHANGED
@@ -3,6 +3,12 @@ from transformers import AutoTokenizer, VisionEncoderDecoderModel, AutoImageProc
 from PIL import Image
 from torchvision.transforms.functional import crop
 import gradio as gr
+import json
+import base64
+import io
+from huggingface_hub import hf_hub_download
+import zipfile
+import os
 
 # Global variables for models
 object_detection_model = None
@@ -17,8 +23,29 @@ def init():
     # Step 1: Load the YOLOv5 model from Hugging Face
     try:
         print("Loading YOLOv5 model...")
-
-
+        # Get Hugging Face auth token from environment variable
+        auth_token = os.getenv("HF_AUTH_TOKEN")
+        if not auth_token:
+            print("Error: HF_AUTH_TOKEN environment variable not set.")
+            object_detection_model = None
+        else:
+            # Download the zip file from Hugging Face
+            zip_path = hf_hub_download(repo_id='Mexbow/Yolov5_object_detection', filename='yolov5.zip', use_auth_token=auth_token)
+
+            # Extract the YOLOv5 model
+            extract_path = './yolov5_model' # Specify extraction path
+            with zipfile.ZipFile(zip_path, 'r') as zip_ref:
+                os.makedirs(extract_path, exist_ok=True)
+                zip_ref.extractall(extract_path)
+
+            # Load the YOLOv5 model
+            model_path = os.path.join(extract_path, 'yolov5/weights/best14.pt')
+            if not os.path.exists(model_path):
+                print(f"Error: YOLOv5 model file not found at {model_path}")
+                object_detection_model = None
+            else:
+                object_detection_model = torch.hub.load('ultralytics/yolov5', 'custom', path=model_path, trust_repo=True)
+                print("YOLOv5 model loaded successfully.")
     except Exception as e:
         print(f"Error loading YOLOv5 model: {e}")
         object_detection_model = None
@@ -38,7 +65,8 @@ def init():
 def crop_objects(image, boxes):
     cropped_images = []
     for box in boxes:
-
+        left, top, right, bottom = box
+        cropped_image = image.crop((left, top, right, bottom))
         cropped_images.append(cropped_image)
     return cropped_images
 
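Note (not part of this commit): crop_objects expects xyxy pixel boxes, matching PIL's (left, top, right, bottom) crop convention. A minimal, hypothetical sketch of how the detection step elsewhere in process_image presumably produces such boxes, assuming the standard ultralytics/yolov5 torch.hub results API:

# Hypothetical sketch, not from this diff; assumes the standard ultralytics/yolov5 results API.
results = object_detection_model(image)            # image is a PIL.Image
detections = results.xyxy[0]                       # rows of [x1, y1, x2, y2, confidence, class]
boxes = [row[:4].tolist() for row in detections]   # xyxy boxes in pixel coordinates
scores = [float(row[4]) for row in detections]     # confidence scores
crops = crop_objects(image, boxes)                 # one cropped PIL image per detection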
@@ -83,8 +111,11 @@ def process_image(image):
                 "confidence_score": float(score) # Convert to float
             })
 
+        # Render image with bounding boxes
+        result_image = results.render()[0]
+
         # Return the image with detections and the caption
-        return
+        return result_image, detection_results, original_caption
 
     except Exception as e:
         return None, {"error": str(e)}, None
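Note (not part of this commit): on ultralytics/yolov5 results objects, results.render() annotates the detections and returns a list of NumPy arrays, one per input image, so result_image here is an ndarray rather than a PIL image. If the installed Gradio version insists on PIL for an output declared with type="pil", a small conversion would be needed; a hypothetical one-liner:

# Hypothetical adjustment, only needed if the Gradio version in use rejects NumPy arrays here.
result_image = Image.fromarray(results.render()[0])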
@@ -96,9 +127,11 @@ init()
 interface = gr.Interface(
     fn=process_image, # Function to run
     inputs=gr.Image(type="pil"), # Input: Image upload
-    outputs=[
-
-
+    outputs=[
+        gr.Image(type="pil", label="Detected Objects"), # Output 1: Image with bounding boxes
+        gr.JSON(label="Object Captions & Bounding Boxes"), # Output 2: JSON results for each object
+        gr.Textbox(label="Whole Image Caption") # Output 3: Caption for the whole image
+    ],
     live=True
 )
 
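Operational note: with this change, init() requires the HF_AUTH_TOKEN environment variable and skips loading the YOLOv5 model when it is missing. A hypothetical local smoke test (the token value and image path are placeholders; on a Space the token would normally be configured as a repository secret):

# Hypothetical local smoke test, not part of this commit.
import os
os.environ["HF_AUTH_TOKEN"] = "hf_xxx"   # placeholder; must be set before app.py runs init()

from PIL import Image
import app                               # importing app.py runs init() at module level

annotated, objects_json, caption = app.process_image(Image.open("example.jpg"))
print(caption)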