Spaces:
Running
Running
File size: 3,008 Bytes
93d9ee0 4eb7c20 93d0893 4eb7c20 62e4e64 93d0893 4eb7c20 93d0893 4eb7c20 62e4e64 93d0893 4eb7c20 93d0893 4eb7c20 62e4e64 4eb7c20 93d0893 4eb7c20 93d0893 4eb7c20 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 |
# Copyright (C) 2021-2025, Mindee.
# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
import numpy as np
import torch
from doctr.models import ocr_predictor
from doctr.models.predictor import OCRPredictor
# Names of the text-detection architectures, grouped by model family.
# Presumably these are the valid choices for `det_arch` in `load_predictor`
# (verify against callers).
DET_ARCHS = [
    # FAST family
    "fast_base",
    "fast_small",
    "fast_tiny",
    # DBNet family
    "db_resnet50",
    "db_resnet34",
    "db_mobilenet_v3_large",
    # LinkNet family
    "linknet_resnet18",
    "linknet_resnet34",
    "linknet_resnet50",
]
# Names of the text-recognition architectures, grouped by model family.
# Presumably these are the valid choices for `reco_arch` in `load_predictor`
# (verify against callers).
RECO_ARCHS = [
    # CRNN family
    "crnn_vgg16_bn",
    "crnn_mobilenet_v3_small",
    "crnn_mobilenet_v3_large",
    # MASTER
    "master",
    # SAR
    "sar_resnet31",
    # ViTSTR family
    "vitstr_small",
    "vitstr_base",
    # PARSeq
    "parseq",
]
def load_predictor(
    det_arch: str,
    reco_arch: str,
    assume_straight_pages: bool,
    straighten_pages: bool,
    export_as_straight_boxes: bool,
    disable_page_orientation: bool,
    disable_crop_orientation: bool,
    bin_thresh: float,
    box_thresh: float,
    device: torch.device,
) -> OCRPredictor:
    """Build a pretrained OCR predictor, move it to a device and set its detection thresholds

    Args:
        det_arch: name of the detection architecture
        reco_arch: name of the recognition architecture
        assume_straight_pages: whether pages are assumed to be straight; orientation
            detection is switched on exactly when this is False
        straighten_pages: whether to straighten rotated pages or not
        export_as_straight_boxes: whether to export boxes as straight or not
        disable_page_orientation: whether to disable page orientation or not
        disable_crop_orientation: whether to disable crop orientation or not
        bin_thresh: binarization threshold for the segmentation map
        box_thresh: minimal objectness score to consider a box
        device: torch.device, the device to load the predictor on

    Returns:
        the configured OCRPredictor instance, located on `device`
    """
    # Keyword options forwarded verbatim to `ocr_predictor`.
    options = {
        "pretrained": True,
        "assume_straight_pages": assume_straight_pages,
        "straighten_pages": straighten_pages,
        "export_as_straight_boxes": export_as_straight_boxes,
        # Orientation detection is only needed when pages may be rotated.
        "detect_orientation": not assume_straight_pages,
        "disable_page_orientation": disable_page_orientation,
        "disable_crop_orientation": disable_crop_orientation,
    }
    predictor = ocr_predictor(det_arch, reco_arch, **options).to(device)

    # Override the thresholds on the detection model's post-processor in place.
    postprocessor = predictor.det_predictor.model.postprocessor
    postprocessor.bin_thresh = bin_thresh
    postprocessor.box_thresh = box_thresh

    return predictor
def forward_image(predictor: OCRPredictor, image: np.ndarray, device: torch.device) -> np.ndarray:
    """Run the detection model of a predictor on a single image

    Only the detection stage is executed; the recognition model is not called.
    The image is wrapped in a one-element list for the pre-processor and only
    the first resulting batch is fed to the model.

    Args:
        predictor: instance of OCRPredictor
        image: image to process
        device: torch.device, the device to process the image on

    Returns:
        segmentation map: the model's "out_map" output, copied to CPU as a numpy array
    """
    detector = predictor.det_predictor

    # Gradients are not needed for inference.
    with torch.no_grad():
        batches = detector.pre_processor([image])
        first_batch = batches[0].to(device)
        # `return_model_output=True` asks the model to include "out_map" in its output.
        model_output = detector.model(first_batch, return_model_output=True)
        return model_output["out_map"].to("cpu").numpy()
|