Spaces:
Sleeping
Sleeping
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from inference_sdk import InferenceHTTPClient
|
3 |
+
from PIL import Image, ImageDraw, ImageFont
|
4 |
+
import os
|
5 |
+
from collections import defaultdict
|
6 |
+
|
7 |
+
API_KEY = os.getenv("ROBOFLOW_API_KEY")
|
8 |
+
|
9 |
+
CLIENT = InferenceHTTPClient(
|
10 |
+
api_url="https://detect.roboflow.com",
|
11 |
+
api_key=API_KEY
|
12 |
+
)
|
13 |
+
|
14 |
+
|
15 |
+
# Model settings
|
16 |
+
MODEL_ID = "hvacsym/5"
|
17 |
+
CONFIDENCE_THRESHOLD = 20 # Confidence threshold for filtering predictions
|
18 |
+
GRID_SIZE = (4, 4) # 4x4 segmentation
|
19 |
+
|
20 |
+
def process_image(image_path):
|
21 |
+
"""Processes an uploaded image and returns the final image with bounding boxes & symbol counts."""
|
22 |
+
original_image = Image.open(image_path)
|
23 |
+
width, height = original_image.size
|
24 |
+
seg_w, seg_h = width // GRID_SIZE[1], height // GRID_SIZE[0]
|
25 |
+
|
26 |
+
# Create a copy of the image for bounding boxes
|
27 |
+
final_image = original_image.copy()
|
28 |
+
draw_final = ImageDraw.Draw(final_image)
|
29 |
+
|
30 |
+
# Load font
|
31 |
+
try:
|
32 |
+
font = ImageFont.truetype("arial.ttf", 10)
|
33 |
+
except:
|
34 |
+
font = ImageFont.load_default()
|
35 |
+
|
36 |
+
# Dictionary for total counts
|
37 |
+
total_counts = defaultdict(int)
|
38 |
+
|
39 |
+
# Process each segment
|
40 |
+
for row in range(GRID_SIZE[0]):
|
41 |
+
for col in range(GRID_SIZE[1]):
|
42 |
+
x1, y1 = col * seg_w, row * seg_h
|
43 |
+
x2, y2 = (col + 1) * seg_w, (row + 1) * seg_h
|
44 |
+
|
45 |
+
segment = original_image.crop((x1, y1, x2, y2))
|
46 |
+
segment_path = f"segment_{row}_{col}.png"
|
47 |
+
segment.save(segment_path)
|
48 |
+
|
49 |
+
# Run inference
|
50 |
+
result = CLIENT.infer(segment_path, model_id=MODEL_ID)
|
51 |
+
filtered_predictions = [
|
52 |
+
pred for pred in result["predictions"] if pred["confidence"] * 100 >= CONFIDENCE_THRESHOLD
|
53 |
+
]
|
54 |
+
|
55 |
+
# Draw bounding boxes and update counts
|
56 |
+
for obj in filtered_predictions:
|
57 |
+
sx, sy, sw, sh = obj["x"], obj["y"], obj["width"], obj["height"]
|
58 |
+
class_name = obj["class"]
|
59 |
+
confidence = obj["confidence"]
|
60 |
+
total_counts[class_name] += 1
|
61 |
+
|
62 |
+
# Adjust coordinates for final image
|
63 |
+
x_min, y_min = x1 + (sx - sw // 2), y1 + (sy - sh // 2)
|
64 |
+
x_max, y_max = x1 + (sx + sw // 2), y1 + (sy + sh // 2)
|
65 |
+
|
66 |
+
# Draw bounding box
|
67 |
+
draw_final.rectangle([x_min, y_min, x_max, y_max], outline="green", width=2)
|
68 |
+
|
69 |
+
# Draw label
|
70 |
+
text = f"{class_name} {confidence:.2f}"
|
71 |
+
draw_final.text((x_min + 2, y_min - 10), text, fill="white", font=font)
|
72 |
+
|
73 |
+
# Save final image with bounding boxes
|
74 |
+
final_image_path = "final_detected_image.png"
|
75 |
+
final_image.save(final_image_path)
|
76 |
+
|
77 |
+
return final_image_path, total_counts
|
78 |
+
|
79 |
+
def process_uploaded_image(image):
|
80 |
+
"""Gradio wrapper function that calls `process_image` and formats the output."""
|
81 |
+
final_image_path, total_counts = process_image(image)
|
82 |
+
|
83 |
+
# Convert count dictionary to readable text
|
84 |
+
count_text = "\n".join([f"{label}: {count}" for label, count in total_counts.items()])
|
85 |
+
|
86 |
+
return final_image_path, count_text
|
87 |
+
|
88 |
+
# Deploy with Gradio
|
89 |
+
iface = gr.Interface(
|
90 |
+
fn=process_uploaded_image,
|
91 |
+
inputs=gr.Image(type="filepath"), # Gradio input expects a file path
|
92 |
+
outputs=[gr.Image(type="filepath"), gr.Text()],
|
93 |
+
title="HVAC Symbol Detector",
|
94 |
+
description="Upload an HVAC blueprint image. The model will detect symbols and return the final image with bounding boxes along with symbol counts."
|
95 |
+
)
|
96 |
+
|
97 |
+
# Launch the Gradio app
|
98 |
+
iface.launch()
|