Files changed (1)
  1. app.py +116 -83
app.py CHANGED
@@ -16,75 +16,99 @@ models: Dict[str, YOLO] = {}
 
 # Load all models
 for name, model_file in MODEL_OPTIONS.items():
-    model_path = hf_hub_download(
-        repo_id="johnlockejrr/medieval-manuscript-yolov11-seg",
-        filename=model_file
-    )
-    models[name] = YOLO(model_path)
+    try:
+        model_path = hf_hub_download(
+            repo_id="johnlockejrr/medieval-manuscript-yolov11-seg",
+            filename=model_file
+        )
+        models[name] = YOLO(model_path)
+    except Exception as e:
+        print(f"Error loading model {name}: {str(e)}")
 
 # Create annotators
 LABEL_ANNOTATOR = sv.LabelAnnotator(text_color=sv.Color.BLACK)
 MASK_ANNOTATOR = sv.MaskAnnotator()
 
+def process_masks(masks: np.ndarray, target_shape: Tuple[int, int]) -> np.ndarray:
+    """Process and resize masks to target shape"""
+    if masks is None:
+        return None
+
+    processed_masks = []
+    h, w = target_shape
+    for mask in masks:
+        # Resize mask to target dimensions
+        resized_mask = cv2.resize(mask.astype(float), (w, h), interpolation=cv2.INTER_LINEAR)
+        # Threshold to create binary mask
+        processed_masks.append(resized_mask > 0.5)
+
+    return np.array(processed_masks)
+
 def detect_and_annotate(
     image: np.ndarray,
     model_name: str,
     conf_threshold: float,
     iou_threshold: float
 ) -> np.ndarray:
-    # Get the selected model
-    model = models[model_name]
-
-    # Perform inference
-    results = model.predict(
-        image,
-        conf=conf_threshold,
-        iou=iou_threshold
-    )[0]
-
-    # Convert results to supervision Detections
-    boxes = results.boxes.xyxy.cpu().numpy()
-    confidence = results.boxes.conf.cpu().numpy()
-    class_ids = results.boxes.cls.cpu().numpy().astype(int)
-
-    # Handle masks if they exist
-    masks = None
-    if results.masks is not None:
-        masks = results.masks.data.cpu().numpy()
-        # Reshape masks to (num_masks, H, W)
-        masks = np.transpose(masks, (1, 2, 0))  # From (H, W, num_masks) to (num_masks, H, W)
+    try:
+        if image is None:
+            return None
+
+        model = models.get(model_name)
+        if model is None:
+            raise ValueError(f"Model {model_name} not loaded")
 
-        # Resize masks to match original image dimensions
-        h, w = image.shape[:2]
-        resized_masks = []
-        for mask in masks:
-            resized_mask = cv2.resize(mask.astype(float), (w, h), interpolation=cv2.INTER_LINEAR)
-            resized_masks.append(resized_mask)
-        masks = np.array(resized_masks)
-        masks = masks.astype(bool)
-
-    # Create Detections object
-    detections = sv.Detections(
-        xyxy=boxes,
-        confidence=confidence,
-        class_id=class_ids,
-        mask=masks
-    )
-
-    # Create labels with confidence scores
-    labels = [
-        f"{results.names[class_id]} ({conf:.2f})"
-        for class_id, conf
-        in zip(class_ids, confidence)
-    ]
+        # Perform inference
+        results = model.predict(
+            image,
+            conf=conf_threshold,
+            iou=iou_threshold
+        )[0]
+
+        # Convert results to supervision Detections
+        boxes = results.boxes.xyxy.cpu().numpy()
+        confidence = results.boxes.conf.cpu().numpy()
+        class_ids = results.boxes.cls.cpu().numpy().astype(int)
+
+        # Process masks
+        masks = None
+        if results.masks is not None:
+            masks = results.masks.data.cpu().numpy()
+            masks = np.moveaxis(masks, 0, -1)  # Change from (N,H,W) to (H,W,N)
+            masks = process_masks(masks, image.shape[:2])
+
+        # Create Detections object
+        detections = sv.Detections(
+            xyxy=boxes,
+            confidence=confidence,
+            class_id=class_ids,
+            mask=masks
+        )
+
+        # Create labels
+        labels = [
+            f"{results.names[class_id]} ({conf:.2f})"
+            for class_id, conf in zip(class_ids, confidence)
+        ]
 
-    # Annotate image
-    annotated_image = image.copy()
-    if masks is not None:
-        annotated_image = MASK_ANNOTATOR.annotate(scene=annotated_image, detections=detections)
-    annotated_image = LABEL_ANNOTATOR.annotate(scene=annotated_image, detections=detections, labels=labels)
-
-    return annotated_image
+        # Annotate image
+        annotated_image = image.copy()
+        if masks is not None:
+            annotated_image = MASK_ANNOTATOR.annotate(
+                scene=annotated_image,
+                detections=detections
+            )
+        annotated_image = LABEL_ANNOTATOR.annotate(
+            scene=annotated_image,
+            detections=detections,
+            labels=labels
+        )
+
+        return annotated_image
+
+    except Exception as e:
+        print(f"Error during detection: {str(e)}")
+        return image  # Return original image on error
 
 # Create Gradio interface
 with gr.Blocks() as demo:
@@ -97,37 +121,37 @@ with gr.Blocks() as demo:
             model_selector = gr.Dropdown(
                 choices=list(MODEL_OPTIONS.keys()),
                 value=list(MODEL_OPTIONS.keys())[0],
-                label="Model",
-                info="Select YOLO model variant"
+                label="Model"
             )
-            with gr.Row():
-                conf_threshold = gr.Slider(
-                    label="Confidence Threshold",
-                    minimum=0.0,
-                    maximum=1.0,
-                    step=0.05,
-                    value=0.25,
-                )
-                iou_threshold = gr.Slider(
-                    label="IoU Threshold",
-                    minimum=0.0,
-                    maximum=1.0,
-                    step=0.05,
-                    value=0.45,
-                    info="Decrease for stricter detection, increase for more overlapping boxes"
-                )
-            with gr.Row():
-                clear_btn = gr.Button("Clear")
-                detect_btn = gr.Button("Detect", variant="primary")
+            conf_threshold = gr.Slider(
+                label="Confidence Threshold",
+                minimum=0.0,
+                maximum=1.0,
+                step=0.05,
+                value=0.25
+            )
+            iou_threshold = gr.Slider(
+                label="IoU Threshold",
+                minimum=0.0,
+                maximum=1.0,
+                step=0.05,
+                value=0.45
+            )
+            detect_btn = gr.Button("Detect", variant="primary")
+            clear_btn = gr.Button("Clear")
 
         with gr.Column():
             output_image = gr.Image(label="Segmentation Result", type='numpy')
 
     def process_image(image, model_name, conf_threshold, iou_threshold):
-        if image is None:
-            return None, None
-        annotated_image = detect_and_annotate(image, model_name, conf_threshold, iou_threshold)
-        return image, annotated_image
+        try:
+            if image is None:
+                return None, None
+            annotated_image = detect_and_annotate(image, model_name, conf_threshold, iou_threshold)
+            return image, annotated_image
+        except Exception as e:
+            print(f"Error in process_image: {str(e)}")
+            return image, image  # Fallback to original image
 
     def clear():
        return None, None
@@ -137,7 +161,16 @@
         inputs=[input_image, model_selector, conf_threshold, iou_threshold],
         outputs=[input_image, output_image]
     )
-    clear_btn.click(clear, inputs=None, outputs=[input_image, output_image])
+    clear_btn.click(
+        clear,
+        inputs=None,
+        outputs=[input_image, output_image]
+    )
 
 if __name__ == "__main__":
-    demo.launch(debug=True, show_error=True)
+    demo.launch(
+        server_name="0.0.0.0",
+        server_port=7860,
+        show_error=True,
+        debug=True
+    )
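For reference, below is a minimal standalone sketch of the mask-resizing step, kept in (num_masks, H, W) order throughout. It assumes that supervision's sv.Detections accepts boolean masks shaped (num_masks, H, W) and that Ultralytics returns results.masks.data as (num_masks, h, w); the resize_masks helper and the dummy input are illustrative only and are not part of app.py.

import cv2
import numpy as np

def resize_masks(masks: np.ndarray, target_shape: tuple) -> np.ndarray:
    """Resize each (h, w) mask to the target (H, W) and binarize, keeping (num_masks, H, W) order."""
    h, w = target_shape
    resized = [
        cv2.resize(mask.astype(np.float32), (w, h), interpolation=cv2.INTER_LINEAR) > 0.5
        for mask in masks
    ]
    return np.stack(resized) if resized else np.zeros((0, h, w), dtype=bool)

# Illustrative dummy data only: three low-resolution masks upscaled to a 480x640 image.
dummy_masks = np.random.rand(3, 160, 160) > 0.5
resized = resize_masks(dummy_masks, (480, 640))
print(resized.shape, resized.dtype)  # (3, 480, 640) bool

With real model output, the input would instead be results.masks.data.cpu().numpy() and target_shape would be image.shape[:2], as in detect_and_annotate above.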