pluniak committed (verified)
Commit 92cd9d3 · Parent(s): 111e72b

Update helpers.py

Files changed (1):
  1. helpers.py +594 -594
helpers.py CHANGED
@@ -1,595 +1,595 @@
1
- import cv2
2
- import numpy as np
3
- import json
4
- from PIL import Image, ImageDraw, ImageFont
5
- from transformers import pipeline
6
- from huggingface_hub import from_pretrained_keras
7
- import imageio
8
-
9
-
10
- def resize_image(img_in, input_height, input_width):
11
- return cv2.resize(img_in, (input_width, input_height), interpolation=cv2.INTER_NEAREST)
12
-
13
- def write_dict_to_json(dictionary, save_path, indent=4):
14
- with open(save_path, "w") as outfile:
15
- json.dump(dictionary, outfile, indent=indent)
16
-
17
- def load_json_to_dict(load_path):
18
- with open(load_path) as json_file:
19
- return json.load(json_file)
20
-
21
-
22
- class OCRD:
23
- """
24
- Optical Character Recognition and Document processing class that provides functionalities
25
- to preprocess images, detect text lines, perform OCR, and visualize the results.
26
-
27
- The class utilizes deep learning models for various tasks such as binarization and text
28
- line segmentation. It provides comprehensive methods to handle image scaling, prediction,
29
- text extraction, and overlaying recognized text on images.
30
-
31
- Attributes:
32
- image (ndarray): The image loaded into memory from the specified path. This image
33
- is used across various methods within the class.
34
-
35
- Methods:
36
- __init__(img_path: str):
37
- Initializes the OCRD class by loading an image from the specified file path.
38
-
39
- scale_image(img: ndarray) -> ndarray:
40
- Scales an image while maintaining its aspect ratio based on predefined width thresholds.
41
-
42
- predict(model, img: ndarray) -> ndarray:
43
- Uses a specified model to make predictions on the image. This function handles
44
- image resizing and segmenting for model input.
45
-
46
- binarize_image(img: ndarray, binarize_mode: str) -> ndarray:
47
- Applies binarization to the image based on the specified mode ('detailed', 'fast', or 'no').
48
-
49
- segment_textlines(img: ndarray) -> ndarray:
50
- Segments text lines from the binarized image using a pretrained model.
51
-
52
- extract_filter_and_deskew_textlines(img: ndarray, textline_mask: ndarray, min_pixel_sum: int, median_bounds: tuple) -> (dict, ndarray):
53
- Processes an image to extract and correct orientation of text lines based on the provided mask.
54
-
55
- ocr_on_textlines(textline_images: dict) -> dict:
56
- Performs OCR on the extracted text lines and returns the recognized text.
57
-
58
- create_text_overlay_image(textline_images: dict, textline_preds: dict, img_shape: tuple, font_size: int) -> Image:
59
- Creates an image overlay with the recognized text annotations.
60
-
61
- visualize_model_output(prediction: ndarray, img: ndarray) -> ndarray:
62
- Visualizes the model's prediction by overlaying it onto the original image with distinct colors.
63
- """
64
-
65
- def __init__(self, img_path):
66
- self.image = np.array(Image.open(img_path))
67
-
68
- def scale_image(self, img):
69
- """
70
- Scales an image to have dimensions suitable for neural network inference. Scaling is based on the
71
- input width parameter. The new width and height of the image are calculated to maintain the aspect
72
- ratio of the original image.
73
-
74
- Parameters:
75
- - img (ndarray): The image to be scaled, expected to be in the form of a numpy array where
76
- img.shape[0] is the height and img.shape[1] is the width.
77
-
78
- Behavior:
79
- - If image width is less than 1100, the new width is set to 2000 pixels. The height is adjusted
80
- to maintain the aspect ratio.
81
- - If image width is between 1100 (inclusive) and 2500 (exclusive), the width remains unchanged
82
- and the height is adjusted to maintain the aspect ratio.
83
- - If image width is 2500 or more, the width is set to 2000 pixels and the height is similarly
84
- adjusted to maintain the aspect ratio.
85
-
86
- Returns:
87
- - img_new (ndarray): A new image array that has been resized according to the specified rules.
88
- The aspect ratio of the original image is preserved.
89
-
90
- Note:
91
- - This function assumes that a function `resize_image(img, height, width)` is available and is
92
- used to resize the image where `img` is the original image array, `height` is the new height,
93
- and `width` is the new width.
94
- """
95
-
96
- width_early = img.shape[1]
97
-
98
- if width_early < 1100:
99
- img_w_new = 2000
100
- img_h_new = int(img.shape[0] / float(img.shape[1]) * 2000)
101
- elif width_early >= 1100 and width_early < 2500:
102
- img_w_new = width_early
103
- img_h_new = int(img.shape[0] / float(img.shape[1]) * width_early)
104
- else:
105
- img_w_new = 2000
106
- img_h_new = int(img.shape[0] / float(img.shape[1]) * 2000)
107
-
108
- img_new = resize_image(img, img_h_new, img_w_new)
109
-
110
- return img_new
111
-
112
- def predict(self, model, img):
113
- """
114
- Processes an image to predict segmentation outputs using a given model. The function handles image resizing
115
- to match the model's input dimensions and ensures that the entire image is processed by segmenting it into patches
116
- that the model can handle. The prediction from these patches is then reassembled into a single output image.
117
-
118
- Parameters:
119
- - model (keras.Model): The neural network model used for predicting the image segmentation. The model should have
120
- predefined input dimensions (height and width).
121
- - img (ndarray): The image to be processed, represented as a numpy array.
122
-
123
- Returns:
124
- - prediction_true (ndarray): An image of the same size as the input image, containing the segmentation prediction
125
- with each pixel labeled according to the model's output.
126
-
127
- Details:
128
- - The function first scales the input image according to the model's required input dimensions. If the scaled image
129
- is smaller than the model's height or width, it is resized to match exactly.
130
- - The function processes the image in overlapping patches to ensure smooth transitions between the segments. These
131
- patches are then processed individually through the model.
132
- - Predictions from these patches are then stitched together to form a complete output image, ensuring that edge
133
- artifacts are minimized by carefully blending the overlapping areas.
134
- - This method assumes the availability of the `resize_image` function for resizing
135
- operations.
136
- - The output is converted to an 8-bit image before returning, suitable for display or further processing.
137
- """
138
-
139
- # bitmap output
140
- img_height_model = model.layers[-1].output_shape[1]
141
- img_width_model = model.layers[-1].output_shape[2]
142
-
143
- img = self.scale_image(img)
144
-
145
- if img.shape[0] < img_height_model:
146
- img = resize_image(img, img_height_model, img.shape[1])
147
-
148
- if img.shape[1] < img_width_model:
149
- img = resize_image(img, img.shape[0], img_width_model)
150
-
151
- marginal_of_patch_percent = 0.1
152
- margin = int(marginal_of_patch_percent * img_height_model)
153
- width_mid = img_width_model - 2 * margin
154
- height_mid = img_height_model - 2 * margin
155
- img = img / float(255.0)
156
- img = img.astype(np.float16)
157
- img_h = img.shape[0]
158
- img_w = img.shape[1]
159
- prediction_true = np.zeros((img_h, img_w, 3))
160
- nxf = img_w / float(width_mid)
161
- nyf = img_h / float(height_mid)
162
- nxf = int(nxf) + 1 if nxf > int(nxf) else int(nxf)
163
- nyf = int(nyf) + 1 if nyf > int(nyf) else int(nyf)
164
-
165
- for i in range(nxf):
166
- for j in range(nyf):
167
- if i == 0:
168
- index_x_d = i * width_mid
169
- index_x_u = index_x_d + img_width_model
170
- else:
171
- index_x_d = i * width_mid
172
- index_x_u = index_x_d + img_width_model
173
- if j == 0:
174
- index_y_d = j * height_mid
175
- index_y_u = index_y_d + img_height_model
176
- else:
177
- index_y_d = j * height_mid
178
- index_y_u = index_y_d + img_height_model
179
- if index_x_u > img_w:
180
- index_x_u = img_w
181
- index_x_d = img_w - img_width_model
182
- if index_y_u > img_h:
183
- index_y_u = img_h
184
- index_y_d = img_h - img_height_model
185
-
186
- img_patch = img[index_y_d:index_y_u, index_x_d:index_x_u, :]
187
- label_p_pred = model.predict(img_patch.reshape(1, img_patch.shape[0], img_patch.shape[1], img_patch.shape[2]),
188
- verbose=0)
189
-
190
- seg = np.argmax(label_p_pred, axis=3)[0]
191
- seg_color = np.repeat(seg[:, :, np.newaxis], 3, axis=2)
192
-
193
- if i == 0 and j == 0:
194
- seg_color = seg_color[0 : seg_color.shape[0] - margin, 0 : seg_color.shape[1] - margin, :]
195
- prediction_true[index_y_d + 0 : index_y_u - margin, index_x_d + 0 : index_x_u - margin, :] = seg_color
196
- elif i == nxf - 1 and j == nyf - 1:
197
- seg_color = seg_color[margin : seg_color.shape[0] - 0, margin : seg_color.shape[1] - 0, :]
198
- prediction_true[index_y_d + margin : index_y_u - 0, index_x_d + margin : index_x_u - 0, :] = seg_color
199
- elif i == 0 and j == nyf - 1:
200
- seg_color = seg_color[margin : seg_color.shape[0] - 0, 0 : seg_color.shape[1] - margin, :]
201
- prediction_true[index_y_d + margin : index_y_u - 0, index_x_d + 0 : index_x_u - margin, :] = seg_color
202
- elif i == nxf - 1 and j == 0:
203
- seg_color = seg_color[0 : seg_color.shape[0] - margin, margin : seg_color.shape[1] - 0, :]
204
- prediction_true[index_y_d + 0 : index_y_u - margin, index_x_d + margin : index_x_u - 0, :] = seg_color
205
- elif i == 0 and j != 0 and j != nyf - 1:
206
- seg_color = seg_color[margin : seg_color.shape[0] - margin, 0 : seg_color.shape[1] - margin, :]
207
- prediction_true[index_y_d + margin : index_y_u - margin, index_x_d + 0 : index_x_u - margin, :] = seg_color
208
- elif i == nxf - 1 and j != 0 and j != nyf - 1:
209
- seg_color = seg_color[margin : seg_color.shape[0] - margin, margin : seg_color.shape[1] - 0, :]
210
- prediction_true[index_y_d + margin : index_y_u - margin, index_x_d + margin : index_x_u - 0, :] = seg_color
211
- elif i != 0 and i != nxf - 1 and j == 0:
212
- seg_color = seg_color[0 : seg_color.shape[0] - margin, margin : seg_color.shape[1] - margin, :]
213
- prediction_true[index_y_d + 0 : index_y_u - margin, index_x_d + margin : index_x_u - margin, :] = seg_color
214
- elif i != 0 and i != nxf - 1 and j == nyf - 1:
215
- seg_color = seg_color[margin : seg_color.shape[0] - 0, margin : seg_color.shape[1] - margin, :]
216
- prediction_true[index_y_d + margin : index_y_u - 0, index_x_d + margin : index_x_u - margin, :] = seg_color
217
- else:
218
- seg_color = seg_color[margin : seg_color.shape[0] - margin, margin : seg_color.shape[1] - margin, :]
219
- prediction_true[index_y_d + margin : index_y_u - margin, index_x_d + margin : index_x_u - margin, :] = seg_color
220
-
221
- prediction_true = prediction_true.astype(np.uint8)
222
-
223
- return prediction_true
224
-
225
- def binarize_image(self, img, binarize_mode='detailed'):
226
- """
227
- Binarizes an image according to the specified mode.
228
-
229
- Parameters:
230
- - img (ndarray): The input image to be binarized.
231
- - binarize_mode (str): The mode of binarization. Can be 'detailed', 'fast', or 'no'.
232
- - 'detailed': Uses a pre-trained deep learning model for binarization.
233
- - 'fast': Uses OpenCV for a quicker, threshold-based binarization.
234
- - 'no': Returns a copy of the original image.
235
-
236
- Returns:
237
- - ndarray: The binarized image.
238
-
239
- Raises:
240
- - ValueError: If an invalid binarize_mode is provided.
241
-
242
- Description:
243
- Depending on the 'binarize_mode', the function processes the image differently:
244
- - For 'detailed' mode, it loads a specific model and performs prediction to binarize the image.
245
- - For 'fast' mode, it quickly converts the image to grayscale and applies a threshold.
246
- - For 'no' mode, it simply returns the original image unchanged.
247
- If an unsupported mode is provided, the function raises a ValueError.
248
-
249
- Note:
250
- - The 'detailed' mode requires a pre-trained model from huggingface_hub.
251
- - This function depends on OpenCV (cv2) for image processing in 'fast' mode.
252
- """
253
-
254
- if binarize_mode == 'detailed':
255
- model_name = "SBB/eynollah-binarization"
256
- model = from_pretrained_keras(model_name)
257
- binarized = self.predict(model, img)
258
-
259
- # Convert from mask to image (letters black)
260
- binarized = binarized.astype(np.int8)
261
- binarized = -binarized + 1
262
- binarized = (binarized * 255).astype(np.uint8)
263
-
264
- elif binarize_mode == 'fast':
265
- binarized = self.scale_image(img)
266
- binarized = cv2.cvtColor(binarized, cv2.COLOR_BGR2GRAY)
267
- _, binarized = cv2.threshold(binarized, 0, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)
268
- binarized = np.repeat(binarized[:, :, np.newaxis], 3, axis=2)
269
-
270
- elif binarize_mode == 'no':
271
- binarized = img.copy()
272
-
273
- else:
274
- accepted_values = ['detailed', 'fast', 'no']
275
- raise ValueError(f"Invalid value provided: {binarize_mode}. Accepted values are: {accepted_values}")
276
-
277
- binarized = binarized.astype(np.uint8)
278
-
279
- return binarized
280
-
281
-
282
- def segment_textlines(self, img):
283
- '''
284
- Segments text lines in the given (binarized) page image using the pretrained
- "SBB/eynollah-textline" segmentation model (via `predict`) and returns the
- resulting per-pixel text-line mask.
285
- '''
286
- model_name = "SBB/eynollah-textline"
287
- model = from_pretrained_keras(model_name)
288
- textline_segments = self.predict(model, img)
289
-
290
- return textline_segments
291
-
292
-
293
- def extract_filter_and_deskew_textlines(self, img, textline_mask, min_pixel_sum=20, median_bounds=(.5, 20)):
294
-
295
- """
296
- Extracts and deskews text lines from an image based on a provided textline mask. This function identifies
297
- text lines, filters out those that do not meet size criteria, calculates their minimum area rectangles,
298
- performs perspective transformations to deskew each text line, and handles potential rotations to ensure
299
- text lines are presented horizontally.
300
-
301
- Parameters:
302
- - img (numpy.ndarray): The original image from which to extract and deskew text lines. It should be a 3D array.
303
- - textline_mask (numpy.ndarray): A binary mask where text lines have been segmented. It should be a 2D array.
304
- - min_pixel_sum (int, optional): The minimum number of pixels (area) a connected component must have to be considered
305
- a valid text line. If None, no filtering is applied.
306
- - median_bounds (tuple, optional): A tuple representing the lower and upper bounds as multipliers for filtering
307
- text lines based on the median size of identified text lines. If None, no filtering is applied.
308
-
309
- Returns:
310
- - tuple:
311
- - dict: A dictionary containing lists of the extracted and deskewed text line images along with their
312
- metadata (center, left side, height, width, and rotation angle of the bounding box).
313
- - numpy.ndarray: An image visualization of the filtered text line mask for debugging or analysis.
314
-
315
- Description:
316
- The function first uses connected components to identify potential text lines from the mask. It filters these
317
- based on absolute size (min_pixel_sum) and relative size (median_bounds). For each valid text line, it computes
318
- a minimum area rectangle, extracts and deskews the bounded region. This includes rotating the text line if it
319
- is detected as vertical (taller than wide). Finally, it aggregates the results and provides an image for
320
- visualization of the text lines retained after filtering.
321
-
322
- Notes:
323
- - This function assumes the textline_mask is properly segmented and binary (0s for background, 255 for text lines).
324
- - Errors in perspective transformation due to incorrect contour extraction or bounding box calculations are handled
325
- gracefully, reporting the error but continuing with other text lines.
326
- """
327
-
328
- num_labels, labels_im = cv2.connectedComponents(textline_mask)
329
-
330
- # Thresholds for filtering
331
- MIN_PIXEL_SUM = min_pixel_sum # absolute filtering
332
- MEDIAN_LOWER_BOUND = median_bounds[0] # relative filtering
333
- MEDIAN_UPPER_BOUND = median_bounds[1] # relative filtering
334
-
335
- # Gather masks and their sizes
336
- cc_sizes = []
337
- masks = []
338
- labels_im_filtered = labels_im > 0 # for visualizing filtering result
339
- for label in range(1, num_labels): # ignore background class
340
- mask = np.where(labels_im == label, True, False)
341
- if MIN_PIXEL_SUM is None:
342
- is_above_min_pixel_sum = True
343
- else:
344
- is_above_min_pixel_sum = mask.sum() > MIN_PIXEL_SUM
345
- if is_above_min_pixel_sum: # dismiss mini segmentations to avoid skewing of median
346
- cc_sizes.append(mask.sum())
347
- masks.append(mask)
348
-
349
- # filter masks by size in relation to median; then calculate contours and min area bounding box for remaining ones
350
- rectangles = []
351
- median = np.median(cc_sizes)
352
- for mask in masks:
353
- mask_sum = mask.sum()
354
- if MEDIAN_LOWER_BOUND is None:
355
- is_above_lower_median_bound = True
356
- else:
357
- is_above_lower_median_bound = mask_sum > median*MEDIAN_LOWER_BOUND
358
- if MEDIAN_UPPER_BOUND is None:
359
- is_below_upper_median_bound = True
360
- else:
361
- is_below_upper_median_bound = mask_sum < median*MEDIAN_UPPER_BOUND
362
- if is_above_lower_median_bound and is_below_upper_median_bound:
363
- labels_im_filtered[mask > 0] = False
364
- mask = (mask*255).astype(np.uint8)
365
- contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
366
- rect = cv2.minAreaRect(contours[0])
367
- if np.prod(rect[1]) > 0: # filter out if height or width = 0
368
- rectangles.append(rect)
369
-
370
- # Transform (rotated) bounding boxes to horizontal; store together with rotation angle for downstream process re-transform
371
- if rectangles:
372
- # Filter rectangles and de-skew images
373
- textline_images = []
374
- for rect in rectangles:
375
- width, height = rect[1]
376
- rotation_angle = rect[2] # clarify how to interpret and use rotation angle!
377
-
378
- # Convert dimensions to integer and ensure they are > 0
379
- width = int(width)
380
- height = int(height)
381
-
382
- # get source and destination points for image transform
383
- box = cv2.boxPoints(rect)
384
- box = np.intp(box)
385
- src_pts = box.astype("float32")
386
- dst_pts = np.array([[0, height-1],
387
- [0, 0],
388
- [width-1, 0],
389
- [width-1, height-1]], dtype="float32")
390
-
391
- try:
392
- M = cv2.getPerspectiveTransform(src_pts, dst_pts)
393
- warped = cv2.warpPerspective(img, M, (width, height))
394
- # Check and rotate if the text line is taller than wide
395
- if height > width:
396
- warped = cv2.rotate(warped, cv2.ROTATE_90_CLOCKWISE)
397
- temp = height
398
- height = width
399
- width = temp
400
- rotation_angle = 90-rotation_angle
401
- center = rect[0]
402
- left = center[0] - width//2
403
- textline_images.append((warped, center, left, height, width, rotation_angle))
404
- except cv2.error as e:
405
- print(f"Error with warpPerspective: {e}")
406
-
407
- # cast to dict
408
- keys = ['array', 'center', 'left', 'height', 'width', 'rotation_angle']
409
- textline_images = {key: [tup[i] for tup in textline_images] for i, key in enumerate(keys)}
410
- num_labels_filtered = len(textline_images['array'])
411
- labels_im_filtered = np.repeat(labels_im_filtered[:, :, np.newaxis], 3, axis=2).astype(np.uint8) # 3 color channels for plotting
412
- print(f'Kept {num_labels_filtered} of {num_labels} text segments after filtering.')
413
- print(f'All segments deleted smaller than {MIN_PIXEL_SUM} pixels (absolute min size).')
414
- if MEDIAN_LOWER_BOUND is not None:
415
- print(f'All segments deleted smaller than {median*MEDIAN_LOWER_BOUND} pixels (lower median bound).')
416
- if MEDIAN_UPPER_BOUND is not None:
417
- print(f'All segments deleted bigger than {median*MEDIAN_UPPER_BOUND} pixels (upper median bound).')
418
- if MEDIAN_LOWER_BOUND is not None or MEDIAN_UPPER_BOUND is not None:
419
- print(f'Median segment size (pixel sum) used for filtering: {int(median)}.')
420
-
421
- return textline_images, labels_im_filtered
422
-
423
-
424
- def ocr_on_textlines(self, textline_images, model_name="microsoft/trocr-base-handwritten"):
425
- """
426
- Processes a list of image arrays using a pre-trained OCR model to extract text.
427
-
428
- Parameters:
429
- - textline_images (dict): A dictionary with a key 'array' that contains a list of image arrays.
430
- Each image array represents a line of text that will be processed by the OCR model.
431
- - model_name (str): A huggingface model trained for OCR on single text lines
432
-
433
- Returns:
434
- - dict: A dictionary containing a list of extracted text under the key 'preds'.
435
-
436
- Description:
437
- The function initializes the OCR model 'microsoft/trocr-base-handwritten' using Hugging Face's
438
- `pipeline` API for image-to-text conversion. Each image in the input list is converted from an
439
- array format to a PIL Image, processed by the model, and the text prediction is collected.
440
- The progress of image processing is printed every 10 images. The final result is a dictionary
441
- with the key 'preds' that holds all text predictions as a list.
442
-
443
- Note:
444
- - This function requires the `transformers` library from Hugging Face and PIL library to run.
445
- - Ensure that the model 'microsoft/trocr-base-handwritten' is correctly loaded and the
446
- `transformers` library is updated to use the pipeline.
447
- """
448
-
449
- pipe = pipeline("image-to-text", model=model_name)
450
-
451
- # Model inference
452
- textline_preds = []
453
- len_array = len(textline_images['array'])
454
- for i, textline in enumerate(textline_images['array'][:]):
455
- if i % 10 == 1:
456
- print(f'Processing textline no. {i} of {len_array}')
457
- textline = Image.fromarray(textline)
458
- textline_preds.append(pipe(textline))
459
-
460
- # Convert to dict
461
- preds = [pred[0]['generated_text'] for pred in textline_preds]
462
- textline_preds_dict = {'preds': preds}
463
-
464
- return textline_preds_dict
465
-
466
-
467
- def adjust_font_size(self, draw, text, box_width):
468
- """
469
- Adjusts the font size to ensure the text fits within a specified width.
470
-
471
- Parameters:
472
- - draw (ImageDraw.Draw): An instance of ImageDraw.Draw used to render the text.
473
- - text (str): The text string to be rendered.
474
- - box_width (int): The maximum width in pixels that the text should occupy.
475
-
476
- Returns:
477
- - ImageFont: A font object with a size adjusted to fit the text within the specified width.
478
- """
479
-
480
- for font_size in range(1, 200): # Adjust the range as needed
481
- font = ImageFont.load_default(font_size)
482
- text_width = draw.textlength(text, font=font)
483
- if text_width > box_width:
484
- font_size = int(font_size - 10)
485
- return ImageFont.load_default(font_size) # Return the last fitting size
486
- return font # Return max size if none exceeded the box
487
-
488
-
489
- def create_text_overlay_image(self, textline_images, textline_preds, img_shape, font_size=-1):
490
- """
491
- Creates an image overlay with text annotations based on provided bounding box information and predictions.
492
-
493
- Parameters:
494
- - textline_images (dict): A dictionary containing the bounding box data for each text segment.
495
- It should have keys 'left', 'center', 'width', and optionally 'height'. Each key should have
496
- a list of values corresponding to each text segment's properties.
497
- - textline_preds (dict): A dictionary containing the predicted text segments. It should have
498
- a key 'preds' which holds a list of text predictions corresponding to the bounding boxes in
499
- textline_images.
500
- - img_shape (tuple): A tuple representing the shape of the image where the text is to be drawn.
501
- The format should be (height, width).
502
- - font_size (int, optional): Specifies the font size for the text. If set to -1 (default), the font size
503
- is dynamically adjusted to fit the text within its bounding box width using the `adjust_font_size`
504
- function. If a specific integer is provided, it uses that size for all text segments.
505
-
506
- Returns:
507
- - Image: An image object with text drawn over a blank white background.
508
-
509
- Raises:
510
- - AssertionError: If the lengths of the lists in `textline_images` and `textline_preds['preds']`
511
- do not correspond, indicating a mismatch in the number of bounding boxes and text predictions.
512
- """
513
-
514
- for key in textline_images.keys():
515
- assert len(textline_images[key]) == len(textline_preds['preds']), f'Length of {key} and preds does not correspond'
516
-
517
- # Create a blank white image
518
- img_gen = Image.new('RGB', (img_shape[1], img_shape[0]), color=(255, 255, 255))
519
- draw = ImageDraw.Draw(img_gen)
520
-
521
- # Draw each text segment within its bounding box
522
- for i in range(len(textline_preds['preds'])):
523
- left_x = textline_images['left'][i]
524
- center_y = textline_images['center'][i][1]
525
- #height = textline_images['height'][i]
526
- width = textline_images['width'][i]
527
- text = textline_preds['preds'][i]
528
-
529
- # dynamic or static text size
530
- if font_size==-1:
531
- font = self.adjust_font_size(draw, text, width)
532
- else:
533
- font = ImageFont.load_default(font_size)
534
- draw.text((left_x, center_y), text, fill=(0, 0, 0), font=font, align='left')
535
-
536
- return img_gen
537
-
538
-
539
- def visualize_model_output(self, prediction, img):
540
- """
541
- Visualizes the output of a model prediction by overlaying predicted classes with distinct colors onto the original image.
542
-
543
- Parameters:
544
- - prediction (ndarray): A 3D array where the first channel holds the class predictions.
545
- - img (ndarray): The original image to overlay predictions onto. This should be in the same dimensions or resized accordingly.
546
-
547
- Returns:
548
- - ndarray: An image where the model's predictions are overlaid on the original image using a predefined color map.
549
-
550
- Description:
551
- The function first identifies unique classes present in the prediction's first channel. Each class is assigned a specific color from a predefined dictionary `rgb_colors`. The function then creates an output image where each pixel's color corresponds to the class predicted at that location.
552
-
553
- The function resizes the original image to match the dimensions of the prediction if necessary. It then blends the original image and the colored prediction output using OpenCV's `addWeighted` method to produce a final image that highlights the model's predictions with transparency.
554
-
555
- Note:
556
- - This function relies on `numpy` for array manipulations and `cv2` for image processing.
557
- - Ensure the `rgb_colors` dictionary contains enough colors for all classes your model can predict.
558
- - The function assumes `prediction` array's shape is compatible with `img`.
559
- """
560
-
561
- unique_classes = np.unique(prediction[:,:,0])
562
- rgb_colors = {'0' : [255, 255, 255],
563
- '1' : [255, 0, 0],
564
- '2' : [255, 125, 0],
565
- '3' : [255, 0, 125],
566
- '4' : [125, 125, 125],
567
- '5' : [125, 125, 0],
568
- '6' : [0, 125, 255],
569
- '7' : [0, 125, 0],
570
- '8' : [125, 125, 125],
571
- '9' : [0, 125, 255],
572
- '10' : [125, 0, 125],
573
- '11' : [0, 255, 0],
574
- '12' : [0, 0, 255],
575
- '13' : [0, 255, 255],
576
- '14' : [255, 125, 125],
577
- '15' : [255, 0, 255]}
578
-
579
- output = np.zeros(prediction.shape)
580
-
581
- for unq_class in unique_classes:
582
- rgb_class_unique = rgb_colors[str(int(unq_class))]
583
- output[:,:,0][prediction[:,:,0]==unq_class] = rgb_class_unique[0]
584
- output[:,:,1][prediction[:,:,0]==unq_class] = rgb_class_unique[1]
585
- output[:,:,2][prediction[:,:,0]==unq_class] = rgb_class_unique[2]
586
-
587
- img = resize_image(img, output.shape[0], output.shape[1])
588
-
589
- output = output.astype(np.int32)
590
- img = img.astype(np.int32)
591
-
592
- #added_image = cv2.addWeighted(img,0.5,output,0.1,0) # orig by eynollah (gives dark image output)
593
- added_image = cv2.addWeighted(img,0.8,output,0.2,10)
594
-
595
  return added_image
 
1
+ import cv2
2
+ import numpy as np
3
+ import json
4
+ from PIL import Image, ImageDraw, ImageFont
5
+ from transformers import pipeline
6
+ from huggingface_hub import from_pretrained_keras
7
+ import imageio
8
+
9
+
10
+ def resize_image(img_in, input_height, input_width):
11
+ return cv2.resize(img_in, (input_width, input_height), interpolation=cv2.INTER_NEAREST)
12
+
13
+ def write_dict_to_json(dictionary, save_path, indent=4):
14
+ with open(save_path, "w") as outfile:
15
+ json.dump(dictionary, outfile, indent=indent)
16
+
17
+ def load_json_to_dict(load_path):
18
+ with open(load_path) as json_file:
19
+ return json.load(json_file)
20
+
21
+
22
+ class OCRD:
23
+ """
24
+ Optical Character Recognition and Document processing class that provides functionalities
25
+ to preprocess images, detect text lines, perform OCR, and visualize the results.
26
+
27
+ The class utilizes deep learning models for various tasks such as binarization and text
28
+ line segmentation. It provides comprehensive methods to handle image scaling, prediction,
29
+ text extraction, and overlaying recognized text on images.
30
+
31
+ Attributes:
32
+ image (ndarray): The image loaded into memory from the specified path. This image
33
+ is used across various methods within the class.
34
+
35
+ Methods:
36
+ __init__(img_path: str):
37
+ Initializes the OCRD class by loading an image from the specified file path.
38
+
39
+ scale_image(img: ndarray) -> ndarray:
40
+ Scales an image while maintaining its aspect ratio based on predefined width thresholds.
41
+
42
+ predict(model, img: ndarray) -> ndarray:
43
+ Uses a specified model to make predictions on the image. This function handles
44
+ image resizing and segmenting for model input.
45
+
46
+ binarize_image(img: ndarray, binarize_mode: str) -> ndarray:
47
+ Applies binarization to the image based on the specified mode ('detailed', 'fast', or 'no').
48
+
49
+ segment_textlines(img: ndarray) -> ndarray:
50
+ Segments text lines from the binarized image using a pretrained model.
51
+
52
+ extract_filter_and_deskew_textlines(img: ndarray, textline_mask: ndarray, min_pixel_sum: int, median_bounds: tuple) -> (dict, ndarray):
53
+ Processes an image to extract and correct orientation of text lines based on the provided mask.
54
+
55
+ ocr_on_textlines(textline_images: dict) -> dict:
56
+ Performs OCR on the extracted text lines and returns the recognized text.
57
+
58
+ create_text_overlay_image(textline_images: dict, textline_preds: dict, img_shape: tuple, font_size: int) -> Image:
59
+ Creates an image overlay with the recognized text annotations.
60
+
61
+ visualize_model_output(prediction: ndarray, img: ndarray) -> ndarray:
62
+ Visualizes the model's prediction by overlaying it onto the original image with distinct colors.
63
+ """
64
+
65
+ def __init__(self, img_path):
66
+ self.image = np.array(Image.open(img_path))
67
+
68
+ def scale_image(self, img):
69
+ """
70
+ Scales an image to have dimensions suitable for neural network inference. Scaling is based on the
71
+ input width parameter. The new width and height of the image are calculated to maintain the aspect
72
+ ratio of the original image.
73
+
74
+ Parameters:
75
+ - img (ndarray): The image to be scaled, expected to be in the form of a numpy array where
76
+ img.shape[0] is the height and img.shape[1] is the width.
77
+
78
+ Behavior:
79
+ - If image width is less than 1100, the new width is set to 2000 pixels. The height is adjusted
80
+ to maintain the aspect ratio.
81
+ - If image width is between 1100 (inclusive) and 2500 (exclusive), the width remains unchanged
82
+ and the height is adjusted to maintain the aspect ratio.
83
+ - If image width is 2500 or more, the width is set to 2000 pixels and the height is similarly
84
+ adjusted to maintain the aspect ratio.
85
+
86
+ Returns:
87
+ - img_new (ndarray): A new image array that has been resized according to the specified rules.
88
+ The aspect ratio of the original image is preserved.
89
+
90
+ Note:
91
+ - This function assumes that a function `resize_image(img, height, width)` is available and is
92
+ used to resize the image where `img` is the original image array, `height` is the new height,
93
+ and `width` is the new width.
94
+ """
95
+
96
+ width_early = img.shape[1]
97
+
98
+ if width_early < 1100:
99
+ img_w_new = 2000
100
+ img_h_new = int(img.shape[0] / float(img.shape[1]) * 2000)
101
+ elif width_early >= 1100 and width_early < 2500:
102
+ img_w_new = width_early
103
+ img_h_new = int(img.shape[0] / float(img.shape[1]) * width_early)
104
+ else:
105
+ img_w_new = 2000
106
+ img_h_new = int(img.shape[0] / float(img.shape[1]) * 2000)
107
+
108
+ img_new = resize_image(img, img_h_new, img_w_new)
109
+
110
+ return img_new
111
+
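+ # Usage sketch (illustrative only; assumes a page image loaded via `ocrd = OCRD("page.jpg")`):
+ # the scaled width always lands in [1100, 2500) while the aspect ratio is preserved.
+ #     scaled = ocrd.scale_image(ocrd.image)
+ #     assert 1100 <= scaled.shape[1] < 2500
+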
112
+ def predict(self, model, img):
113
+ """
114
+ Processes an image to predict segmentation outputs using a given model. The function handles image resizing
115
+ to match the model's input dimensions and ensures that the entire image is processed by segmenting it into patches
116
+ that the model can handle. The prediction from these patches is then reassembled into a single output image.
117
+
118
+ Parameters:
119
+ - model (keras.Model): The neural network model used for predicting the image segmentation. The model should have
120
+ predefined input dimensions (height and width).
121
+ - img (ndarray): The image to be processed, represented as a numpy array.
122
+
123
+ Returns:
124
+ - prediction_true (ndarray): An image of the same size as the input image, containing the segmentation prediction
125
+ with each pixel labeled according to the model's output.
126
+
127
+ Details:
128
+ - The function first scales the input image according to the model's required input dimensions. If the scaled image
129
+ is smaller than the model's height or width, it is resized to match exactly.
130
+ - The function processes the image in overlapping patches to ensure smooth transitions between the segments. These
131
+ patches are then processed individually through the model.
132
+ - Predictions from these patches are then stitched together to form a complete output image, ensuring that edge
133
+ artifacts are minimized by carefully blending the overlapping areas.
134
+ - This method assumes the availability of the `resize_image` function for resizing
135
+ operations.
136
+ - The output is converted to an 8-bit image before returning, suitable for display or further processing.
137
+ """
138
+
139
+ # bitmap output
140
+ img_height_model = model.layers[-1].output_shape[1]
141
+ img_width_model = model.layers[-1].output_shape[2]
142
+
143
+ img = self.scale_image(img)
144
+
145
+ if img.shape[0] < img_height_model:
146
+ img = resize_image(img, img_height_model, img.shape[1])
147
+
148
+ if img.shape[1] < img_width_model:
149
+ img = resize_image(img, img.shape[0], img_width_model)
150
+
151
+ marginal_of_patch_percent = 0.1
152
+ margin = int(marginal_of_patch_percent * img_height_model)
153
+ width_mid = img_width_model - 2 * margin
154
+ height_mid = img_height_model - 2 * margin
155
+ img = img / float(255.0)
156
+ img = img.astype(np.float16)
157
+ img_h = img.shape[0]
158
+ img_w = img.shape[1]
159
+ prediction_true = np.zeros((img_h, img_w, 3))
160
+ nxf = img_w / float(width_mid)
161
+ nyf = img_h / float(height_mid)
162
+ nxf = int(nxf) + 1 if nxf > int(nxf) else int(nxf)
163
+ nyf = int(nyf) + 1 if nyf > int(nyf) else int(nyf)
164
+
165
+ for i in range(nxf):
166
+ for j in range(nyf):
167
+ if i == 0:
168
+ index_x_d = i * width_mid
169
+ index_x_u = index_x_d + img_width_model
170
+ else:
171
+ index_x_d = i * width_mid
172
+ index_x_u = index_x_d + img_width_model
173
+ if j == 0:
174
+ index_y_d = j * height_mid
175
+ index_y_u = index_y_d + img_height_model
176
+ else:
177
+ index_y_d = j * height_mid
178
+ index_y_u = index_y_d + img_height_model
179
+ if index_x_u > img_w:
180
+ index_x_u = img_w
181
+ index_x_d = img_w - img_width_model
182
+ if index_y_u > img_h:
183
+ index_y_u = img_h
184
+ index_y_d = img_h - img_height_model
185
+
186
+ img_patch = img[index_y_d:index_y_u, index_x_d:index_x_u, :]
187
+ label_p_pred = model.predict(img_patch.reshape(1, img_patch.shape[0], img_patch.shape[1], img_patch.shape[2]),
188
+ verbose=0)
189
+
190
+ seg = np.argmax(label_p_pred, axis=3)[0]
191
+ seg_color = np.repeat(seg[:, :, np.newaxis], 3, axis=2)
192
+
193
+ if i == 0 and j == 0:
194
+ seg_color = seg_color[0 : seg_color.shape[0] - margin, 0 : seg_color.shape[1] - margin, :]
195
+ prediction_true[index_y_d + 0 : index_y_u - margin, index_x_d + 0 : index_x_u - margin, :] = seg_color
196
+ elif i == nxf - 1 and j == nyf - 1:
197
+ seg_color = seg_color[margin : seg_color.shape[0] - 0, margin : seg_color.shape[1] - 0, :]
198
+ prediction_true[index_y_d + margin : index_y_u - 0, index_x_d + margin : index_x_u - 0, :] = seg_color
199
+ elif i == 0 and j == nyf - 1:
200
+ seg_color = seg_color[margin : seg_color.shape[0] - 0, 0 : seg_color.shape[1] - margin, :]
201
+ prediction_true[index_y_d + margin : index_y_u - 0, index_x_d + 0 : index_x_u - margin, :] = seg_color
202
+ elif i == nxf - 1 and j == 0:
203
+ seg_color = seg_color[0 : seg_color.shape[0] - margin, margin : seg_color.shape[1] - 0, :]
204
+ prediction_true[index_y_d + 0 : index_y_u - margin, index_x_d + margin : index_x_u - 0, :] = seg_color
205
+ elif i == 0 and j != 0 and j != nyf - 1:
206
+ seg_color = seg_color[margin : seg_color.shape[0] - margin, 0 : seg_color.shape[1] - margin, :]
207
+ prediction_true[index_y_d + margin : index_y_u - margin, index_x_d + 0 : index_x_u - margin, :] = seg_color
208
+ elif i == nxf - 1 and j != 0 and j != nyf - 1:
209
+ seg_color = seg_color[margin : seg_color.shape[0] - margin, margin : seg_color.shape[1] - 0, :]
210
+ prediction_true[index_y_d + margin : index_y_u - margin, index_x_d + margin : index_x_u - 0, :] = seg_color
211
+ elif i != 0 and i != nxf - 1 and j == 0:
212
+ seg_color = seg_color[0 : seg_color.shape[0] - margin, margin : seg_color.shape[1] - margin, :]
213
+ prediction_true[index_y_d + 0 : index_y_u - margin, index_x_d + margin : index_x_u - margin, :] = seg_color
214
+ elif i != 0 and i != nxf - 1 and j == nyf - 1:
215
+ seg_color = seg_color[margin : seg_color.shape[0] - 0, margin : seg_color.shape[1] - margin, :]
216
+ prediction_true[index_y_d + margin : index_y_u - 0, index_x_d + margin : index_x_u - margin, :] = seg_color
217
+ else:
218
+ seg_color = seg_color[margin : seg_color.shape[0] - margin, margin : seg_color.shape[1] - margin, :]
219
+ prediction_true[index_y_d + margin : index_y_u - margin, index_x_d + margin : index_x_u - margin, :] = seg_color
220
+
221
+ prediction_true = prediction_true.astype(np.uint8)
222
+
223
+ return prediction_true
224
+
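+ # Usage sketch (illustrative only): `predict` works with any Keras segmentation model
+ # whose last layer fixes the patch size, e.g. the binarization model used below. The
+ # returned array matches the scaled input's height and width, with per-pixel class
+ # labels repeated across the three channels.
+ #     model = from_pretrained_keras("SBB/eynollah-binarization")
+ #     seg = ocrd.predict(model, ocrd.image)   # uint8 array of class ids
+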
225
+ def binarize_image(self, img, binarize_mode='detailed'):
226
+ """
227
+ Binarizes an image according to the specified mode.
228
+
229
+ Parameters:
230
+ - img (ndarray): The input image to be binarized.
231
+ - binarize_mode (str): The mode of binarization. Can be 'detailed', 'fast', or 'no'.
232
+ - 'detailed': Uses a pre-trained deep learning model for binarization.
233
+ - 'fast': Uses OpenCV for a quicker, threshold-based binarization.
234
+ - 'no': Returns a copy of the original image.
235
+
236
+ Returns:
237
+ - ndarray: The binarized image.
238
+
239
+ Raises:
240
+ - ValueError: If an invalid binarize_mode is provided.
241
+
242
+ Description:
243
+ Depending on the 'binarize_mode', the function processes the image differently:
244
+ - For 'detailed' mode, it loads a specific model and performs prediction to binarize the image.
245
+ - For 'fast' mode, it quickly converts the image to grayscale and applies a threshold.
246
+ - For 'no' mode, it simply returns the original image unchanged.
247
+ If an unsupported mode is provided, the function raises a ValueError.
248
+
249
+ Note:
250
+ - The 'detailed' mode requires a pre-trained model from huggingface_hub.
251
+ - This function depends on OpenCV (cv2) for image processing in 'fast' mode.
252
+ """
253
+
254
+ if binarize_mode == 'detailed':
255
+ model_name = "SBB/eynollah-binarization"
256
+ model = from_pretrained_keras(model_name)
257
+ binarized = self.predict(model, img)
258
+
259
+ # Convert from mask to image (letters black)
260
+ binarized = binarized.astype(np.int8)
261
+ binarized = -binarized + 1
262
+ binarized = (binarized * 255).astype(np.uint8)
263
+
264
+ elif binarize_mode == 'fast':
265
+ binarized = self.scale_image(img)
266
+ binarized = cv2.cvtColor(binarized, cv2.COLOR_BGR2GRAY)
267
+ _, binarized = cv2.threshold(binarized, 0, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)
268
+ binarized = np.repeat(binarized[:, :, np.newaxis], 3, axis=2)
269
+
270
+ elif binarize_mode == 'no':
271
+ binarized = img.copy()
272
+
273
+ else:
274
+ accepted_values = ['detailed', 'fast', 'no']
275
+ raise ValueError(f"Invalid value provided: {binarize_mode}. Accepted values are: {accepted_values}")
276
+
277
+ binarized = binarized.astype(np.uint8)
278
+
279
+ return binarized
280
+
281
+
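+ # Usage sketch (illustrative only): the three modes trade quality for speed.
+ #     bin_detailed = ocrd.binarize_image(ocrd.image, binarize_mode='detailed')  # DL model
+ #     bin_fast     = ocrd.binarize_image(ocrd.image, binarize_mode='fast')      # Otsu threshold
+ #     bin_no       = ocrd.binarize_image(ocrd.image, binarize_mode='no')        # unchanged copy
+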
282
+ def segment_textlines(self, img):
283
+ '''
284
+ Segments text lines in the given (binarized) page image using the pretrained
+ "SBB/eynollah-textline" segmentation model (via `predict`) and returns the
+ resulting per-pixel text-line mask.
285
+ '''
286
+ model_name = "SBB/eynollah-textline"
287
+ model = from_pretrained_keras(model_name)
288
+ textline_segments = self.predict(model, img)
289
+
290
+ return textline_segments
291
+
292
+
293
+ def extract_filter_and_deskew_textlines(self, img, textline_mask, min_pixel_sum=20, median_bounds=(.5, 20)):
294
+
295
+ """
296
+ Extracts and deskews text lines from an image based on a provided textline mask. This function identifies
297
+ text lines, filters out those that do not meet size criteria, calculates their minimum area rectangles,
298
+ performs perspective transformations to deskew each text line, and handles potential rotations to ensure
299
+ text lines are presented horizontally.
300
+
301
+ Parameters:
302
+ - img (numpy.ndarray): The original image from which to extract and deskew text lines. It should be a 3D array.
303
+ - textline_mask (numpy.ndarray): A binary mask where text lines have been segmented. It should be a 2D array.
304
+ - min_pixel_sum (int, optional): The minimum number of pixels (area) a connected component must have to be considered
305
+ a valid text line. If None, no filtering is applied.
306
+ - median_bounds (tuple, optional): A tuple representing the lower and upper bounds as multipliers for filtering
307
+ text lines based on the median size of identified text lines. If None, no filtering is applied.
308
+
309
+ Returns:
310
+ - tuple:
311
+ - dict: A dictionary containing lists of the extracted and deskewed text line images along with their
312
+ metadata (center, left side, height, width, and rotation angle of the bounding box).
313
+ - numpy.ndarray: An image visualization of the filtered text line mask for debugging or analysis.
314
+
315
+ Description:
316
+ The function first uses connected components to identify potential text lines from the mask. It filters these
317
+ based on absolute size (min_pixel_sum) and relative size (median_bounds). For each valid text line, it computes
318
+ a minimum area rectangle, extracts and deskews the bounded region. This includes rotating the text line if it
319
+ is detected as vertical (taller than wide). Finally, it aggregates the results and provides an image for
320
+ visualization of the text lines retained after filtering.
321
+
322
+ Notes:
323
+ - This function assumes the textline_mask is properly segmented and binary (0s for background, 255 for text lines).
324
+ - Errors in perspective transformation due to incorrect contour extraction or bounding box calculations are handled
325
+ gracefully, reporting the error but continuing with other text lines.
326
+ """
327
+
328
+ num_labels, labels_im = cv2.connectedComponents(textline_mask)
329
+
330
+ # Thresholds for filtering
331
+ MIN_PIXEL_SUM = min_pixel_sum # absolute filtering
332
+ MEDIAN_LOWER_BOUND = median_bounds[0] # relative filtering
333
+ MEDIAN_UPPER_BOUND = median_bounds[1] # relative filtering
334
+
335
+ # Gather masks and their sizes
336
+ cc_sizes = []
337
+ masks = []
338
+ labels_im_filtered = labels_im > 0 # for visualizing filtering result
339
+ for label in range(1, num_labels): # ignore background class
340
+ mask = np.where(labels_im == label, True, False)
341
+ if MIN_PIXEL_SUM is None:
342
+ is_above_min_pixel_sum = True
343
+ else:
344
+ is_above_min_pixel_sum = mask.sum() > MIN_PIXEL_SUM
345
+ if is_above_min_pixel_sum: # dismiss mini segmentations to avoid skewing of median
346
+ cc_sizes.append(mask.sum())
347
+ masks.append(mask)
348
+
349
+ # filter masks by size in relation to median; then calculate contours and min area bounding box for remaining ones
350
+ rectangles = []
351
+ median = np.median(cc_sizes)
352
+ for mask in masks:
353
+ mask_sum = mask.sum()
354
+ if MEDIAN_LOWER_BOUND is None:
355
+ is_above_lower_median_bound = True
356
+ else:
357
+ is_above_lower_median_bound = mask_sum > median*MEDIAN_LOWER_BOUND
358
+ if MEDIAN_UPPER_BOUND is None:
359
+ is_below_upper_median_bound = True
360
+ else:
361
+ is_below_upper_median_bound = mask_sum < median*MEDIAN_UPPER_BOUND
362
+ if is_above_lower_median_bound and is_below_upper_median_bound:
363
+ labels_im_filtered[mask > 0] = False
364
+ mask = (mask*255).astype(np.uint8)
365
+ contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
366
+ rect = cv2.minAreaRect(contours[0])
367
+ if np.prod(rect[1]) > 0: # filter out if height or width = 0
368
+ rectangles.append(rect)
369
+
370
+ # Transform (rotated) bounding boxes to horizontal; store together with rotation angle for downstream process re-transform
371
+ if rectangles:
372
+ # Filter rectangles and de-skew images
373
+ textline_images = []
374
+ for rect in rectangles:
375
+ width, height = rect[1]
376
+ rotation_angle = rect[2] # clarify how to interpret and use rotation angle!
377
+
378
+ # Convert dimensions to integer and ensure they are > 0
379
+ width = int(width)
380
+ height = int(height)
381
+
382
+ # get source and destination points for image transform
383
+ box = cv2.boxPoints(rect)
384
+ box = np.intp(box)
385
+ src_pts = box.astype("float32")
386
+ dst_pts = np.array([[0, height-1],
387
+ [0, 0],
388
+ [width-1, 0],
389
+ [width-1, height-1]], dtype="float32")
390
+
391
+ try:
392
+ M = cv2.getPerspectiveTransform(src_pts, dst_pts)
393
+ warped = cv2.warpPerspective(img, M, (width, height))
394
+ # Check and rotate if the text line is taller than wide
395
+ if height > width:
396
+ warped = cv2.rotate(warped, cv2.ROTATE_90_CLOCKWISE)
397
+ temp = height
398
+ height = width
399
+ width = temp
400
+ rotation_angle = 90-rotation_angle
401
+ center = rect[0]
402
+ left = center[0] - width//2
403
+ textline_images.append((warped, center, left, height, width, rotation_angle))
404
+ except cv2.error as e:
405
+ print(f"Error with warpPerspective: {e}")
406
+
407
+ # cast to dict
408
+ keys = ['array', 'center', 'left', 'height', 'width', 'rotation_angle']
409
+ textline_images = {key: [tup[i] for tup in textline_images] for i, key in enumerate(keys)}
410
+ num_labels_filtered = len(textline_images['array'])
411
+ labels_im_filtered = np.repeat(labels_im_filtered[:, :, np.newaxis], 3, axis=2).astype(np.uint8) # 3 color channels for plotting
412
+ print(f'Kept {num_labels_filtered} of {num_labels} text segments after filtering.')
413
+ print(f'All segments deleted smaller than {MIN_PIXEL_SUM} pixels (absolute min size).')
414
+ if MEDIAN_LOWER_BOUND is not None:
415
+ print(f'All segments deleted smaller than {median*MEDIAN_LOWER_BOUND} pixels (lower median bound).')
416
+ if MEDIAN_UPPER_BOUND is not None:
417
+ print(f'All segments deleted bigger than {median*MEDIAN_UPPER_BOUND} pixels (upper median bound).')
418
+ if MEDIAN_LOWER_BOUND is not None or MEDIAN_UPPER_BOUND is not None:
419
+ print(f'Median segment size (pixel sum) used for filtering: {int(median)}.')
420
+
421
+ return textline_images, labels_im_filtered
422
+
423
+
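+ # Usage sketch (illustrative only): cv2.connectedComponents expects a single-channel
+ # uint8 mask, so a 3-channel mask from `segment_textlines` is reduced first.
+ #     mask_2d = (textline_seg[:, :, 0] > 0).astype(np.uint8) * 255
+ #     textlines, kept_mask = ocrd.extract_filter_and_deskew_textlines(scaled, mask_2d)
+ #     textlines['array'][0]           # first deskewed text-line crop
+ #     textlines['rotation_angle'][0]  # its estimated rotation
+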
424
+ def ocr_on_textlines(self, textline_images, model_name="microsoft/trocr-base-handwritten"):
425
+ """
426
+ Processes a list of image arrays using a pre-trained OCR model to extract text.
427
+
428
+ Parameters:
429
+ - textline_images (dict): A dictionary with a key 'array' that contains a list of image arrays.
430
+ Each image array represents a line of text that will be processed by the OCR model.
431
+ - model_name (str): A huggingface model trained for OCR on single text lines
432
+
433
+ Returns:
434
+ - dict: A dictionary containing a list of extracted text under the key 'preds'.
435
+
436
+ Description:
437
+ The function initializes the OCR model 'microsoft/trocr-base-handwritten' using Hugging Face's
438
+ `pipeline` API for image-to-text conversion. Each image in the input list is converted from an
439
+ array format to a PIL Image, processed by the model, and the text prediction is collected.
440
+ The progress of image processing is printed every 10 images. The final result is a dictionary
441
+ with the key 'preds' that holds all text predictions as a list.
442
+
443
+ Note:
444
+ - This function requires the `transformers` library from Hugging Face and PIL library to run.
445
+ - Ensure that the model 'microsoft/trocr-base-handwritten' is correctly loaded and the
446
+ `transformers` library is updated to use the pipeline.
447
+ """
448
+
449
+ pipe = pipeline("image-to-text", model=model_name)
450
+
451
+ # Model inference
452
+ textline_preds = []
453
+ len_array = len(textline_images['array'])
454
+ for i, textline in enumerate(textline_images['array'][:]):
455
+ if i % 10 == 1:
456
+ print(f'Processing textline no. {i} of {len_array}')
457
+ textline = Image.fromarray(textline)
458
+ textline_preds.append(pipe(textline))
459
+
460
+ # Convert to dict
461
+ preds = [pred[0]['generated_text'] for pred in textline_preds]
462
+ textline_preds_dict = {'preds': preds}
463
+
464
+ return textline_preds_dict
465
+
466
+
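+ # Usage sketch (illustrative only): any Hugging Face image-to-text model trained on
+ # single text lines can be passed via `model_name`.
+ #     preds = ocrd.ocr_on_textlines(textlines)  # default: microsoft/trocr-base-handwritten
+ #     preds = ocrd.ocr_on_textlines(textlines, model_name="microsoft/trocr-base-printed")
+ #     preds['preds'][:3]                        # first recognized lines
+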
467
+ def adjust_font_size(self, draw, text, box_width):
468
+ """
469
+ Adjusts the font size to ensure the text fits within a specified width.
470
+
471
+ Parameters:
472
+ - draw (ImageDraw.Draw): An instance of ImageDraw.Draw used to render the text.
473
+ - text (str): The text string to be rendered.
474
+ - box_width (int): The maximum width in pixels that the text should occupy.
475
+
476
+ Returns:
477
+ - ImageFont: A font object with a size adjusted to fit the text within the specified width.
478
+ """
479
+
480
+ for font_size in range(1, 200): # Adjust the range as needed
481
+ font = ImageFont.load_default(font_size)
482
+ text_width = draw.textlength(text, font=font)
483
+ if text_width > box_width:
484
+ font_size = max(5, int(font_size - 10)) # min font size of 5
485
+ return ImageFont.load_default(font_size) # Return the last fitting size
486
+ return font # Return max size if none exceeded the box
487
+
488
+
489
+ def create_text_overlay_image(self, textline_images, textline_preds, img_shape, font_size=-1):
490
+ """
491
+ Creates an image overlay with text annotations based on provided bounding box information and predictions.
492
+
493
+ Parameters:
494
+ - textline_images (dict): A dictionary containing the bounding box data for each text segment.
495
+ It should have keys 'left', 'center', 'width', and optionally 'height'. Each key should have
496
+ a list of values corresponding to each text segment's properties.
497
+ - textline_preds (dict): A dictionary containing the predicted text segments. It should have
498
+ a key 'preds' which holds a list of text predictions corresponding to the bounding boxes in
499
+ textline_images.
500
+ - img_shape (tuple): A tuple representing the shape of the image where the text is to be drawn.
501
+ The format should be (height, width).
502
+ - font_size (int, optional): Specifies the font size for the text. If set to -1 (default), the font size
503
+ is dynamically adjusted to fit the text within its bounding box width using the `adjust_font_size`
504
+ function. If a specific integer is provided, it uses that size for all text segments.
505
+
506
+ Returns:
507
+ - Image: An image object with text drawn over a blank white background.
508
+
509
+ Raises:
510
+ - AssertionError: If the lengths of the lists in `textline_images` and `textline_preds['preds']`
511
+ do not correspond, indicating a mismatch in the number of bounding boxes and text predictions.
512
+ """
513
+
514
+ for key in textline_images.keys():
515
+ assert len(textline_images[key]) == len(textline_preds['preds']), f'Length of {key} and preds does not correspond'
516
+
517
+ # Create a blank white image
518
+ img_gen = Image.new('RGB', (img_shape[1], img_shape[0]), color=(255, 255, 255))
519
+ draw = ImageDraw.Draw(img_gen)
520
+
521
+ # Draw each text segment within its bounding box
522
+ for i in range(len(textline_preds['preds'])):
523
+ left_x = textline_images['left'][i]
524
+ center_y = textline_images['center'][i][1]
525
+ #height = textline_images['height'][i]
526
+ width = textline_images['width'][i]
527
+ text = textline_preds['preds'][i]
528
+
529
+ # dynamic or static text size
530
+ if font_size==-1:
531
+ font = self.adjust_font_size(draw, text, width)
532
+ else:
533
+ font = ImageFont.load_default(font_size)
534
+ draw.text((left_x, center_y), text, fill=(0, 0, 0), font=font, align='left')
535
+
536
+ return img_gen
537
+
538
+
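+ # Usage sketch (illustrative only): render the recognized text at the original line
+ # positions on a blank page and save it next to the scan.
+ #     overlay = ocrd.create_text_overlay_image(textlines, preds, scaled.shape[:2])
+ #     overlay.save("page_text_overlay.png")
+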
539
+ def visualize_model_output(self, prediction, img):
540
+ """
541
+ Visualizes the output of a model prediction by overlaying predicted classes with distinct colors onto the original image.
542
+
543
+ Parameters:
544
+ - prediction (ndarray): A 3D array where the first channel holds the class predictions.
545
+ - img (ndarray): The original image to overlay predictions onto. This should be in the same dimensions or resized accordingly.
546
+
547
+ Returns:
548
+ - ndarray: An image where the model's predictions are overlaid on the original image using a predefined color map.
549
+
550
+ Description:
551
+ The function first identifies unique classes present in the prediction's first channel. Each class is assigned a specific color from a predefined dictionary `rgb_colors`. The function then creates an output image where each pixel's color corresponds to the class predicted at that location.
552
+
553
+ The function resizes the original image to match the dimensions of the prediction if necessary. It then blends the original image and the colored prediction output using OpenCV's `addWeighted` method to produce a final image that highlights the model's predictions with transparency.
554
+
555
+ Note:
556
+ - This function relies on `numpy` for array manipulations and `cv2` for image processing.
557
+ - Ensure the `rgb_colors` dictionary contains enough colors for all classes your model can predict.
558
+ - The function assumes `prediction` array's shape is compatible with `img`.
559
+ """
560
+
561
+ unique_classes = np.unique(prediction[:,:,0])
562
+ rgb_colors = {'0' : [255, 255, 255],
563
+ '1' : [255, 0, 0],
564
+ '2' : [255, 125, 0],
565
+ '3' : [255, 0, 125],
566
+ '4' : [125, 125, 125],
567
+ '5' : [125, 125, 0],
568
+ '6' : [0, 125, 255],
569
+ '7' : [0, 125, 0],
570
+ '8' : [125, 125, 125],
571
+ '9' : [0, 125, 255],
572
+ '10' : [125, 0, 125],
573
+ '11' : [0, 255, 0],
574
+ '12' : [0, 0, 255],
575
+ '13' : [0, 255, 255],
576
+ '14' : [255, 125, 125],
577
+ '15' : [255, 0, 255]}
578
+
579
+ output = np.zeros(prediction.shape)
580
+
581
+ for unq_class in unique_classes:
582
+ rgb_class_unique = rgb_colors[str(int(unq_class))]
583
+ output[:,:,0][prediction[:,:,0]==unq_class] = rgb_class_unique[0]
584
+ output[:,:,1][prediction[:,:,0]==unq_class] = rgb_class_unique[1]
585
+ output[:,:,2][prediction[:,:,0]==unq_class] = rgb_class_unique[2]
586
+
587
+ img = resize_image(img, output.shape[0], output.shape[1])
588
+
589
+ output = output.astype(np.int32)
590
+ img = img.astype(np.int32)
591
+
592
+ #added_image = cv2.addWeighted(img,0.5,output,0.1,0) # orig by eynollah (gives dark image output)
593
+ added_image = cv2.addWeighted(img,0.8,output,0.2,10)
594
+
595
  return added_image
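
Example: end-to-end use of helpers.py (an illustrative sketch, not part of the commit; it assumes a local scan "page.jpg", that the referenced Hugging Face models can be downloaded, and that helpers.py is importable):

    from helpers import OCRD, write_dict_to_json

    ocrd = OCRD("page.jpg")
    scaled = ocrd.scale_image(ocrd.image)

    # Binarize and segment text lines (both run the patch-wise predict() internally)
    binarized = ocrd.binarize_image(ocrd.image, binarize_mode="detailed")
    textline_seg = ocrd.segment_textlines(binarized)
    seg_vis = ocrd.visualize_model_output(textline_seg, scaled)  # colored class overlay

    # Reduce the 3-channel segmentation to a single-channel mask and extract the lines
    mask_2d = (textline_seg[:, :, 0] > 0).astype("uint8") * 255
    textlines, kept_mask = ocrd.extract_filter_and_deskew_textlines(scaled, mask_2d)

    # OCR each deskewed line, render the text at its original position, and save results
    preds = ocrd.ocr_on_textlines(textlines)
    overlay = ocrd.create_text_overlay_image(textlines, preds, scaled.shape[:2])
    overlay.save("page_ocr_overlay.png")
    write_dict_to_json(preds, "page_ocr.json")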