shrey14 commited on
Commit
ba7630e
·
verified ·
1 Parent(s): ad011c4

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +98 -0
app.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from inference_sdk import InferenceHTTPClient
3
+ from PIL import Image, ImageDraw, ImageFont
4
+ import os
5
+ from collections import defaultdict
6
+
7
+ API_KEY = os.getenv("ROBOFLOW_API_KEY")
8
+
9
+ CLIENT = InferenceHTTPClient(
10
+ api_url="https://detect.roboflow.com",
11
+ api_key=API_KEY
12
+ )
13
+
14
+
15
+ # Model settings
16
+ MODEL_ID = "hvacsym/5"
17
+ CONFIDENCE_THRESHOLD = 20 # Confidence threshold for filtering predictions
18
+ GRID_SIZE = (4, 4) # 4x4 segmentation
19
+
20
+ def process_image(image_path):
21
+ """Processes an uploaded image and returns the final image with bounding boxes & symbol counts."""
22
+ original_image = Image.open(image_path)
23
+ width, height = original_image.size
24
+ seg_w, seg_h = width // GRID_SIZE[1], height // GRID_SIZE[0]
25
+
26
+ # Create a copy of the image for bounding boxes
27
+ final_image = original_image.copy()
28
+ draw_final = ImageDraw.Draw(final_image)
29
+
30
+ # Load font
31
+ try:
32
+ font = ImageFont.truetype("arial.ttf", 10)
33
+ except:
34
+ font = ImageFont.load_default()
35
+
36
+ # Dictionary for total counts
37
+ total_counts = defaultdict(int)
38
+
39
+ # Process each segment
40
+ for row in range(GRID_SIZE[0]):
41
+ for col in range(GRID_SIZE[1]):
42
+ x1, y1 = col * seg_w, row * seg_h
43
+ x2, y2 = (col + 1) * seg_w, (row + 1) * seg_h
44
+
45
+ segment = original_image.crop((x1, y1, x2, y2))
46
+ segment_path = f"segment_{row}_{col}.png"
47
+ segment.save(segment_path)
48
+
49
+ # Run inference
50
+ result = CLIENT.infer(segment_path, model_id=MODEL_ID)
51
+ filtered_predictions = [
52
+ pred for pred in result["predictions"] if pred["confidence"] * 100 >= CONFIDENCE_THRESHOLD
53
+ ]
54
+
55
+ # Draw bounding boxes and update counts
56
+ for obj in filtered_predictions:
57
+ sx, sy, sw, sh = obj["x"], obj["y"], obj["width"], obj["height"]
58
+ class_name = obj["class"]
59
+ confidence = obj["confidence"]
60
+ total_counts[class_name] += 1
61
+
62
+ # Adjust coordinates for final image
63
+ x_min, y_min = x1 + (sx - sw // 2), y1 + (sy - sh // 2)
64
+ x_max, y_max = x1 + (sx + sw // 2), y1 + (sy + sh // 2)
65
+
66
+ # Draw bounding box
67
+ draw_final.rectangle([x_min, y_min, x_max, y_max], outline="green", width=2)
68
+
69
+ # Draw label
70
+ text = f"{class_name} {confidence:.2f}"
71
+ draw_final.text((x_min + 2, y_min - 10), text, fill="white", font=font)
72
+
73
+ # Save final image with bounding boxes
74
+ final_image_path = "final_detected_image.png"
75
+ final_image.save(final_image_path)
76
+
77
+ return final_image_path, total_counts
78
+
79
+ def process_uploaded_image(image):
80
+ """Gradio wrapper function that calls `process_image` and formats the output."""
81
+ final_image_path, total_counts = process_image(image)
82
+
83
+ # Convert count dictionary to readable text
84
+ count_text = "\n".join([f"{label}: {count}" for label, count in total_counts.items()])
85
+
86
+ return final_image_path, count_text
87
+
88
+ # Deploy with Gradio
89
+ iface = gr.Interface(
90
+ fn=process_uploaded_image,
91
+ inputs=gr.Image(type="filepath"), # Gradio input expects a file path
92
+ outputs=[gr.Image(type="filepath"), gr.Text()],
93
+ title="HVAC Symbol Detector",
94
+ description="Upload an HVAC blueprint image. The model will detect symbols and return the final image with bounding boxes along with symbol counts."
95
+ )
96
+
97
+ # Launch the Gradio app
98
+ iface.launch()