odulcy-mindee commited on
Commit
93d0893
·
verified ·
1 Parent(s): 92af12f

Upload folder using huggingface_hub

Browse files
Files changed (5) hide show
  1. README.md +10 -2
  2. app.py +33 -4
  3. backend/pytorch.py +12 -0
  4. backend/tensorflow.py +101 -0
  5. packages.txt +1 -1
README.md CHANGED
@@ -4,13 +4,13 @@ emoji: 📑
4
  colorFrom: purple
5
  colorTo: pink
6
  sdk: streamlit
7
- sdk_version: 1.30.0
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
11
  ---
12
 
13
- # Configuration
14
 
15
  `title`: _string_
16
  Display title for the Space
@@ -37,3 +37,11 @@ Path is relative to the root of the repository.
37
 
38
  `pinned`: _boolean_
39
  Whether the Space stays on top of your list.
 
 
 
 
 
 
 
 
 
4
  colorFrom: purple
5
  colorTo: pink
6
  sdk: streamlit
7
+ sdk_version: 1.39.0
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
11
  ---
12
 
13
+ ## Configuration
14
 
15
  `title`: _string_
16
  Display title for the Space
 
37
 
38
  `pinned`: _boolean_
39
  Whether the Space stays on top of your list.
40
+
41
+ ## Run the demo locally
42
+
43
+ ```bash
44
+ cd demo
45
+ pip install -r pt-requirements.txt
46
+ streamlit run app.py
47
+ ```
app.py CHANGED
@@ -7,14 +7,25 @@ import cv2
7
  import matplotlib.pyplot as plt
8
  import numpy as np
9
  import streamlit as st
10
- import torch
11
 
 
12
  from doctr.io import DocumentFile
13
  from doctr.utils.visualization import visualize_page
14
 
15
- from backend.pytorch import DET_ARCHS, RECO_ARCHS, forward_image, load_predictor
 
 
16
 
17
- forward_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
 
 
 
 
 
 
 
 
18
 
19
 
20
  def main(det_archs, reco_archs):
@@ -51,6 +62,7 @@ def main(det_archs, reco_archs):
51
 
52
  # Model selection
53
  st.sidebar.title("Model selection")
 
54
  det_arch = st.sidebar.selectbox("Text detection model", det_archs)
55
  reco_arch = st.sidebar.selectbox("Text recognition model", reco_archs)
56
 
@@ -60,12 +72,21 @@ def main(det_archs, reco_archs):
60
  st.sidebar.title("Parameters")
61
  assume_straight_pages = st.sidebar.checkbox("Assume straight pages", value=True)
62
  st.sidebar.write("\n")
 
 
 
 
 
 
63
  # Straighten pages
64
  straighten_pages = st.sidebar.checkbox("Straighten pages", value=False)
65
  st.sidebar.write("\n")
66
  # Binarization threshold
67
  bin_thresh = st.sidebar.slider("Binarization threshold", min_value=0.1, max_value=0.9, value=0.3, step=0.1)
68
  st.sidebar.write("\n")
 
 
 
69
 
70
  if st.sidebar.button("Analyze page"):
71
  if uploaded_file is None:
@@ -74,7 +95,15 @@ def main(det_archs, reco_archs):
74
  else:
75
  with st.spinner("Loading model..."):
76
  predictor = load_predictor(
77
- det_arch, reco_arch, assume_straight_pages, straighten_pages, bin_thresh, forward_device
 
 
 
 
 
 
 
 
78
  )
79
 
80
  with st.spinner("Analyzing..."):
 
7
  import matplotlib.pyplot as plt
8
  import numpy as np
9
  import streamlit as st
 
10
 
11
+ from doctr.file_utils import is_tf_available
12
  from doctr.io import DocumentFile
13
  from doctr.utils.visualization import visualize_page
14
 
15
+ if is_tf_available():
16
+ import tensorflow as tf
17
+ from backend.tensorflow import DET_ARCHS, RECO_ARCHS, forward_image, load_predictor
18
 
19
+ if any(tf.config.experimental.list_physical_devices("gpu")):
20
+ forward_device = tf.device("/gpu:0")
21
+ else:
22
+ forward_device = tf.device("/cpu:0")
23
+
24
+ else:
25
+ import torch
26
+ from backend.pytorch import DET_ARCHS, RECO_ARCHS, forward_image, load_predictor
27
+
28
+ forward_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
29
 
30
 
31
  def main(det_archs, reco_archs):
 
62
 
63
  # Model selection
64
  st.sidebar.title("Model selection")
65
+ st.sidebar.markdown("**Backend**: " + ("TensorFlow" if is_tf_available() else "PyTorch"))
66
  det_arch = st.sidebar.selectbox("Text detection model", det_archs)
67
  reco_arch = st.sidebar.selectbox("Text recognition model", reco_archs)
68
 
 
72
  st.sidebar.title("Parameters")
73
  assume_straight_pages = st.sidebar.checkbox("Assume straight pages", value=True)
74
  st.sidebar.write("\n")
75
+ # Disable page orientation detection
76
+ disable_page_orientation = st.sidebar.checkbox("Disable page orientation detection", value=False)
77
+ st.sidebar.write("\n")
78
+ # Disable crop orientation detection
79
+ disable_crop_orientation = st.sidebar.checkbox("Disable crop orientation detection", value=False)
80
+ st.sidebar.write("\n")
81
  # Straighten pages
82
  straighten_pages = st.sidebar.checkbox("Straighten pages", value=False)
83
  st.sidebar.write("\n")
84
  # Binarization threshold
85
  bin_thresh = st.sidebar.slider("Binarization threshold", min_value=0.1, max_value=0.9, value=0.3, step=0.1)
86
  st.sidebar.write("\n")
87
+ # Box threshold
88
+ box_thresh = st.sidebar.slider("Box threshold", min_value=0.1, max_value=0.9, value=0.1, step=0.1)
89
+ st.sidebar.write("\n")
90
 
91
  if st.sidebar.button("Analyze page"):
92
  if uploaded_file is None:
 
95
  else:
96
  with st.spinner("Loading model..."):
97
  predictor = load_predictor(
98
+ det_arch,
99
+ reco_arch,
100
+ assume_straight_pages,
101
+ straighten_pages,
102
+ disable_page_orientation,
103
+ disable_crop_orientation,
104
+ bin_thresh,
105
+ box_thresh,
106
+ forward_device,
107
  )
108
 
109
  with st.spinner("Analyzing..."):
backend/pytorch.py CHANGED
@@ -10,6 +10,9 @@ from doctr.models import ocr_predictor
10
  from doctr.models.predictor import OCRPredictor
11
 
12
  DET_ARCHS = [
 
 
 
13
  "db_resnet50",
14
  "db_resnet34",
15
  "db_mobilenet_v3_large",
@@ -34,7 +37,10 @@ def load_predictor(
34
  reco_arch: str,
35
  assume_straight_pages: bool,
36
  straighten_pages: bool,
 
 
37
  bin_thresh: float,
 
38
  device: torch.device,
39
  ) -> OCRPredictor:
40
  """Load a predictor from doctr.models
@@ -45,7 +51,10 @@ def load_predictor(
45
  reco_arch: recognition architecture
46
  assume_straight_pages: whether to assume straight pages or not
47
  straighten_pages: whether to straighten rotated pages or not
 
 
48
  bin_thresh: binarization threshold for the segmentation map
 
49
  device: torch.device, the device to load the predictor on
50
 
51
  Returns:
@@ -60,8 +69,11 @@ def load_predictor(
60
  straighten_pages=straighten_pages,
61
  export_as_straight_boxes=straighten_pages,
62
  detect_orientation=not assume_straight_pages,
 
 
63
  ).to(device)
64
  predictor.det_predictor.model.postprocessor.bin_thresh = bin_thresh
 
65
  return predictor
66
 
67
 
 
10
  from doctr.models.predictor import OCRPredictor
11
 
12
  DET_ARCHS = [
13
+ "fast_base",
14
+ "fast_small",
15
+ "fast_tiny",
16
  "db_resnet50",
17
  "db_resnet34",
18
  "db_mobilenet_v3_large",
 
37
  reco_arch: str,
38
  assume_straight_pages: bool,
39
  straighten_pages: bool,
40
+ disable_page_orientation: bool,
41
+ disable_crop_orientation: bool,
42
  bin_thresh: float,
43
+ box_thresh: float,
44
  device: torch.device,
45
  ) -> OCRPredictor:
46
  """Load a predictor from doctr.models
 
51
  reco_arch: recognition architecture
52
  assume_straight_pages: whether to assume straight pages or not
53
  straighten_pages: whether to straighten rotated pages or not
54
+ disable_page_orientation: whether to disable page orientation or not
55
+ disable_crop_orientation: whether to disable crop orientation or not
56
  bin_thresh: binarization threshold for the segmentation map
57
+ box_thresh: minimal objectness score to consider a box
58
  device: torch.device, the device to load the predictor on
59
 
60
  Returns:
 
69
  straighten_pages=straighten_pages,
70
  export_as_straight_boxes=straighten_pages,
71
  detect_orientation=not assume_straight_pages,
72
+ disable_page_orientation=disable_page_orientation,
73
+ disable_crop_orientation=disable_crop_orientation,
74
  ).to(device)
75
  predictor.det_predictor.model.postprocessor.bin_thresh = bin_thresh
76
+ predictor.det_predictor.model.postprocessor.box_thresh = box_thresh
77
  return predictor
78
 
79
 
backend/tensorflow.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2021-2024, Mindee.
2
+
3
+ # This program is licensed under the Apache License 2.0.
4
+ # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
+
6
+ import numpy as np
7
+ import tensorflow as tf
8
+
9
+ from doctr.models import ocr_predictor
10
+ from doctr.models.predictor import OCRPredictor
11
+
12
+ DET_ARCHS = [
13
+ "fast_base",
14
+ "fast_small",
15
+ "fast_tiny",
16
+ "db_resnet50",
17
+ "db_mobilenet_v3_large",
18
+ "linknet_resnet18",
19
+ "linknet_resnet34",
20
+ "linknet_resnet50",
21
+ ]
22
+ RECO_ARCHS = [
23
+ "crnn_vgg16_bn",
24
+ "crnn_mobilenet_v3_small",
25
+ "crnn_mobilenet_v3_large",
26
+ "master",
27
+ "sar_resnet31",
28
+ "vitstr_small",
29
+ "vitstr_base",
30
+ "parseq",
31
+ ]
32
+
33
+
34
+ def load_predictor(
35
+ det_arch: str,
36
+ reco_arch: str,
37
+ assume_straight_pages: bool,
38
+ straighten_pages: bool,
39
+ disable_page_orientation: bool,
40
+ disable_crop_orientation: bool,
41
+ bin_thresh: float,
42
+ box_thresh: float,
43
+ device: tf.device,
44
+ ) -> OCRPredictor:
45
+ """Load a predictor from doctr.models
46
+
47
+ Args:
48
+ ----
49
+ det_arch: detection architecture
50
+ reco_arch: recognition architecture
51
+ assume_straight_pages: whether to assume straight pages or not
52
+ straighten_pages: whether to straighten rotated pages or not
53
+ disable_page_orientation: whether to disable page orientation or not
54
+ disable_crop_orientation: whether to disable crop orientation or not
55
+ bin_thresh: binarization threshold for the segmentation map
56
+ box_thresh: threshold for the detection boxes
57
+ device: tf.device, the device to load the predictor on
58
+
59
+ Returns:
60
+ -------
61
+ instance of OCRPredictor
62
+ """
63
+ with device:
64
+ predictor = ocr_predictor(
65
+ det_arch,
66
+ reco_arch,
67
+ pretrained=True,
68
+ assume_straight_pages=assume_straight_pages,
69
+ straighten_pages=straighten_pages,
70
+ export_as_straight_boxes=straighten_pages,
71
+ detect_orientation=not assume_straight_pages,
72
+ disable_page_orientation=disable_page_orientation,
73
+ disable_crop_orientation=disable_crop_orientation,
74
+ )
75
+ predictor.det_predictor.model.postprocessor.bin_thresh = bin_thresh
76
+ predictor.det_predictor.model.postprocessor.box_thresh = box_thresh
77
+ return predictor
78
+
79
+
80
+ def forward_image(predictor: OCRPredictor, image: np.ndarray, device: tf.device) -> np.ndarray:
81
+ """Forward an image through the predictor
82
+
83
+ Args:
84
+ ----
85
+ predictor: instance of OCRPredictor
86
+ image: image to process as numpy array
87
+ device: tf.device, the device to process the image on
88
+
89
+ Returns:
90
+ -------
91
+ segmentation map
92
+ """
93
+ with device:
94
+ processed_batches = predictor.det_predictor.pre_processor([image])
95
+ out = predictor.det_predictor.model(processed_batches[0], return_model_output=True)
96
+ seg_map = out["out_map"]
97
+
98
+ with tf.device("/cpu:0"):
99
+ seg_map = tf.identity(seg_map).numpy()
100
+
101
+ return seg_map
packages.txt CHANGED
@@ -1 +1 @@
1
- python3-opencv
 
1
+ python3-opencv