wjm55 committed
Commit 052c825
Parent: d5d8604

Add YOLOv11 model integration and Gradio interface for text detection

Files changed (2):
  1. app.py +158 -95
  2. requirements.txt +2 -0
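
At its core, the commit swaps the hard-coded model.pt for YOLOv11 checkpoints pulled from the Hugging Face Hub. As a quick orientation before the full diff, here is a condensed sketch of that integration; the repo id, checkpoint filename, and threshold values are taken from the diff below, while page.jpg is a hypothetical test image:

# Condensed sketch of the new loading and inference path (not part of the commit)
from huggingface_hub import hf_hub_download
from ultralytics import YOLO

path = hf_hub_download(
    repo_id="biglam/medieval-manuscript-yolov11",
    filename="medieval-yolov11n.pt",  # one of the five MODEL_OPTIONS entries
)
model = YOLO(path)
# classes=0 restricts detection to the text-line class, as in detect_and_recognize
results = model.predict("page.jpg", conf=0.25, iou=0.45, classes=0)[0]
print(f"{len(results.boxes)} text lines detected")
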
app.py CHANGED
@@ -1,4 +1,3 @@
-import streamlit as st
 import warnings
 warnings.simplefilter("ignore", UserWarning)
 
@@ -18,13 +17,39 @@ import cv2
 import numpy as np
 import pandas as pd
 import logging
-from typing import List, Optional
+from typing import List, Optional, Tuple, Dict
+from huggingface_hub import hf_hub_download
+import gradio as gr
+import supervision as sv
+import os
+import spaces
+import torch
+
+# Define models
+MODEL_OPTIONS = {
+    "YOLOv11-Nano": "medieval-yolov11n.pt",
+    "YOLOv11-Small": "medieval-yolov11s.pt",
+    "YOLOv11-Medium": "medieval-yolov11m.pt",
+    "YOLOv11-Large": "medieval-yolov11l.pt",
+    "YOLOv11-XLarge": "medieval-yolov11x.pt"
+}
+
+# Dictionary to store loaded models
+models: Dict[str, YOLO] = {}
+
+# Load all models
+for name, model_file in MODEL_OPTIONS.items():
+    model_path = hf_hub_download(
+        repo_id="biglam/medieval-manuscript-yolov11",
+        filename=model_file
+    )
+    models[name] = YOLO(model_path)
 
 # Configure logging
 logging.getLogger("lightning.pytorch").setLevel(logging.ERROR)
 
 # Load YOLOv8 model
-model = YOLO('model.pt')
+model = YOLO(model_path)
 images = Path(mkdtemp())
 DEFAULT_HEIGHT = 128
 TEXT_DIRECTION = "LTR"
@@ -36,6 +61,13 @@ CONFIDENCE_PATTERN = r"(?P<confidence>[0-9.]+)" # For line
 TEXT_PATTERN = r"\s*(?P<text>.*)\s*"
 LINE_PREDICTION = re.compile(rf"{IMAGE_ID_PATTERN} {CONFIDENCE_PATTERN} {TEXT_PATTERN}")
 
+# Create annotators
+LABEL_ANNOTATOR = sv.LabelAnnotator(text_color=sv.Color.BLACK)
+BOX_ANNOTATOR = sv.BoxAnnotator()
+
+# Select device
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
 def get_width(image, height=DEFAULT_HEIGHT):
     aspect_ratio = image.width / image.height
     return height * aspect_ratio
@@ -65,7 +97,8 @@ def simplify_polygons(polygons: List[np.ndarray], approx_level: float = 0.01) ->
         result.append(approx.squeeze())
     return result
 
-def predict(model_name, input_img):
+def predict_text(input_img):
+    """PyLaia text recognition function"""
     model_dir = 'catmus-medieval'
     temperature = 2.0
     batch_size = 1
@@ -121,96 +154,126 @@ def predict(model_name, input_img):
     predictions = Path(pred_stdout.name).read_text().strip().splitlines()
 
     _, score, text = LINE_PREDICTION.match(predictions[0]).groups()
-    if TEXT_DIRECTION == "RTL":
-        return input_img, {"text": get_display(text), "score": score}
-    else:
-        return input_img, {"text": text, "score": score}
-
-def process_image(image):
-    # Perform inference on an image, select textline only
-    results = model(image, classes=0)
-
-    img_cv2 = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
-    masks = results[0].masks
-    polygons = []
-    texts = []
-
-    if masks is not None:
-        # Get masks data and original image dimensions
-        masks = masks.data.cpu().numpy()
-        img_height, img_width = img_cv2.shape[:2]
-
-        # Get bounding boxes in xyxy format
-        boxes = results[0].boxes.xyxy.cpu().numpy()
-
-        # Sort by y-coordinate of the top-left corner
-        sorted_indices = np.argsort(boxes[:, 1])
-        masks = masks[sorted_indices]
-        boxes = boxes[sorted_indices]
-
-        for i, (mask, box) in enumerate(zip(masks, boxes)):
-            # Scale the mask to original image size
-            mask = cv2.resize(mask.squeeze(), (img_width, img_height), interpolation=cv2.INTER_LINEAR)
-            mask = (mask > 0.5).astype(np.uint8) * 255  # Apply threshold
-
-            # Convert mask to polygon
-            contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
-
-            if contours:
-                # Get the largest contour
-                largest_contour = max(contours, key=cv2.contourArea)
-                simplified_polygon = simplify_polygons([largest_contour])[0]
-
-                if simplified_polygon is not None:
-                    # Crop the image using the bounding box for text recognition
-                    x1, y1, x2, y2 = map(int, box)
-                    crop_img = img_cv2[y1:y2, x1:x2]
-                    crop_pil = Image.fromarray(cv2.cvtColor(crop_img, cv2.COLOR_BGR2RGB))
-
-                    # Recognize text using PyLaia model
-                    predicted = predict('pylaia-samaritan_v1', crop_pil)
-                    texts.append(predicted[1]["text"])
-
-                    # Convert polygon to list of points for display
-                    poly_points = simplified_polygon.reshape(-1, 2).astype(int).tolist()
-                    polygons.append(f"Line {i+1}: {poly_points}")
-
-                    # Draw polygon on the image
-                    cv2.polylines(img_cv2, [simplified_polygon.reshape(-1, 1, 2).astype(int)],
-                                  True, (0, 255, 0), 2)
-
-    # Convert image back to RGB for display in Streamlit
-    img_result = cv2.cvtColor(img_cv2, cv2.COLOR_BGR2RGB)
-
-    # Combine polygons and texts into a DataFrame for table display
-    table_data = pd.DataFrame({"Polygons": polygons, "Recognized Text": texts})
-    return Image.fromarray(img_result), table_data
-
-def segment_and_recognize(image):
-    segmented_image, table_data = process_image(image)
-    return segmented_image, table_data
-
-# Streamlit app layout
-st.set_page_config(layout="wide")  # Use full page width
-st.title("YOLOv11 Text Line Segmentation & PyLaia Text Recognition on CATMuS/medieval")
-
-# File uploader
-uploaded_image = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])
-
-# Process the image if uploaded
-if uploaded_image is not None:
-    image = Image.open(uploaded_image)
-
-    if st.button("Segment and Recognize"):
-        # Perform segmentation and recognition
-        segmented_image, table_data = segment_and_recognize(image)
-
-        # Layout: Image on the left, Table on the right
-        col1, col2 = st.columns([2, 3])  # Adjust the ratio if needed
-
-        with col1:
-            st.image(segmented_image, caption="Segmented Image with Polygon Masks", use_container_width=True)
-
-        with col2:
-            st.table(table_data)
+    return text, float(score)
+
+@spaces.GPU
+def detect_and_recognize(image, model_name, conf_threshold, iou_threshold):
+    if image is None:
+        return None, ""
+
+    # Get model path
+    model_path = hf_hub_download(
+        repo_id="biglam/medieval-manuscript-yolov11",
+        filename=MODEL_OPTIONS[model_name]
+    )
+
+    # Load model
+    model = YOLO(model_path)
+
+    # Perform inference
+    results = model.predict(
+        image,
+        conf=conf_threshold,
+        iou=iou_threshold,
+        classes=0,
+        device=device
+    )[0]
+
+    # Convert results to supervision Detections
+    boxes = results.boxes.xyxy.cpu().numpy()
+    confidence = results.boxes.conf.cpu().numpy()
+    class_ids = results.boxes.cls.cpu().numpy().astype(int)
+
+    # Sort boxes by y-coordinate
+    sorted_indices = np.argsort(boxes[:, 1])
+    boxes = boxes[sorted_indices]
+    confidence = confidence[sorted_indices]
+
+    # Create Detections object
+    detections = sv.Detections(
+        xyxy=boxes,
+        confidence=confidence,
+        class_id=class_ids
+    )
+
+    # Create labels
+    labels = [
+        f"Line {i+1} ({conf:.2f})"
+        for i, conf in enumerate(confidence)
+    ]
+
+    # Annotate image
+    annotated_image = image.copy()
+    annotated_image = BOX_ANNOTATOR.annotate(scene=annotated_image, detections=detections)
+    annotated_image = LABEL_ANNOTATOR.annotate(scene=annotated_image, detections=detections, labels=labels)
+
+    # Create text summary
+    text_summary = "\n".join([f"Line {i+1}: Confidence {conf:.2f}" for i, conf in enumerate(confidence)])
+
+    return annotated_image, text_summary
+
+def gradio_reset():
+    return None, None, ""
+
+if __name__ == "__main__":
+    print(f"Using device: {device}")
+
+    with gr.Blocks() as demo:
+        gr.Markdown("# Medieval Manuscript Text Detection")
+
+        with gr.Row():
+            with gr.Column():
+                input_image = gr.Image(
+                    label="Input Image",
+                    type="numpy"
+                )
+                with gr.Accordion("Detection Settings", open=True):
+                    model_selector = gr.Dropdown(
+                        choices=list(MODEL_OPTIONS.keys()),
+                        value=list(MODEL_OPTIONS.keys())[0],
+                        label="Model",
+                        info="Select YOLO model variant"
+                    )
+                    with gr.Row():
+                        conf_threshold = gr.Slider(
+                            label="Confidence Threshold",
+                            minimum=0.0,
+                            maximum=1.0,
+                            step=0.05,
+                            value=0.25,
+                        )
+                        iou_threshold = gr.Slider(
+                            label="IoU Threshold",
+                            minimum=0.0,
+                            maximum=1.0,
+                            step=0.05,
+                            value=0.45,
+                        )
+                with gr.Row():
+                    clear_btn = gr.Button("Clear")
+                    detect_btn = gr.Button("Detect", variant="primary")
+
+            with gr.Column():
+                output_image = gr.Image(
+                    label="Detection Result",
+                    type="numpy"
+                )
+                text_output = gr.Textbox(
+                    label="Detection Summary",
+                    lines=10
+                )
+
+        # Connect buttons to functions
+        detect_btn.click(
+            detect_and_recognize,
+            inputs=[input_image, model_selector, conf_threshold, iou_threshold],
+            outputs=[output_image, text_output]
+        )
+        clear_btn.click(
+            gradio_reset,
+            inputs=None,
+            outputs=[input_image, output_image, text_output]
+        )
+
+    demo.launch(server_name="0.0.0.0", server_port=7860, debug=True)
 
 
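With the Streamlit front end removed, the new pipeline can also be smoke-tested without launching the Gradio UI. The sketch below is not part of the commit: it assumes app.py imports cleanly (the model downloads above run at import time) and uses a hypothetical local scan page.jpg; the argument values mirror the slider defaults. Note that detect_and_recognize re-downloads the selected checkpoint with hf_hub_download on every call instead of reusing the models dict populated at startup; this works because the Hub cache makes repeat downloads cheap, but it leaves the preloaded dictionary unused.

# Headless smoke test for detect_and_recognize (a sketch, not part of the commit)
import numpy as np
from PIL import Image

from app import detect_and_recognize

# gr.Image(type="numpy") hands the function an RGB array, so mimic that here
img = np.array(Image.open("page.jpg").convert("RGB"))

annotated, summary = detect_and_recognize(
    img,
    "YOLOv11-Nano",  # any key of MODEL_OPTIONS
    0.25,            # conf_threshold, the slider default
    0.45,            # iou_threshold, the slider default
)
print(summary)
Image.fromarray(annotated).save("annotated.png")
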
requirements.txt CHANGED
@@ -20,3 +20,5 @@ python-bidi==0.6.0
 streamlit==1.44.0
 transformers==4.50.3
 ultralytics==8.3.99
+gradio
+supervision
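
Unlike the pinned entries above, gradio and supervision are added without version constraints, so fresh builds of the Space will pick up whatever releases are current. If reproducibility matters, they could be pinned in the file's existing style; the versions below are illustrative, not values from this commit:

gradio==4.44.0
supervision==0.25.0

Since app.py no longer imports streamlit after this commit, the streamlit==1.44.0 pin may also be droppable in a follow-up.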