wjbmattingly committed
Commit d39c80d · verified · 1 Parent(s): b7fdefb

Create app.py

Files changed (1):
  app.py  +110 -0
app.py ADDED
@@ -0,0 +1,110 @@
+ import os
+ os.environ["GRADIO_TEMP_DIR"] = "./tmp"
+
+ import sys
+ import spaces
+ import torch
+ import torchvision
+ import gradio as gr
+ import numpy as np
+ from PIL import Image
+ from huggingface_hub import snapshot_download
+ from visualization import visualize_bbox
+
+ # == download weights ==
+ model_dir = snapshot_download('juliozhao/DocLayout-YOLO-DocStructBench', local_dir='./models/DocLayout-YOLO-DocStructBench')
+ # == select device ==
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+ id_to_names = {
+     0: 'title',
+     1: 'plain text',
+     2: 'abandon',
+     3: 'figure',
+     4: 'figure_caption',
+     5: 'table',
+     6: 'table_caption',
+     7: 'table_footnote',
+     8: 'isolate_formula',
+     9: 'formula_caption'
+ }
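+
+ # Detection wrapper. @spaces.GPU requests a GPU slice from Hugging Face ZeroGPU
+ # for the duration of each call.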
+ @spaces.GPU
+ def recognize_image(input_img, conf_threshold, iou_threshold):
+     det_res = model.predict(
+         input_img,
+         imgsz=1024,
+         conf=conf_threshold,
+         device=device,
+     )[0]
+     boxes = det_res.boxes.xyxy
+     classes = det_res.boxes.cls
+     scores = det_res.boxes.conf
+
+     # Class-agnostic NMS to drop overlapping detections across classes.
+     indices = torchvision.ops.nms(boxes=torch.Tensor(boxes), scores=torch.Tensor(scores), iou_threshold=iou_threshold)
+     boxes, scores, classes = boxes[indices], scores[indices], classes[indices]
+     # Defensive: keep everything 2-D when only a single box survives NMS.
+     if len(boxes.shape) == 1:
+         boxes = np.expand_dims(boxes, 0)
+         scores = np.expand_dims(scores, 0)
+         classes = np.expand_dims(classes, 0)
+
+     vis_result = visualize_bbox(input_img, boxes, classes, scores, id_to_names)
+     return vis_result
+
+ def gradio_reset():
+     # Clear both the input image and the rendered output.
+     return gr.update(value=None), gr.update(value=None)
+
+
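+ # Startup path: the model is loaded once here and read as a module-level
+ # global by recognize_image.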
+ if __name__ == "__main__":
+     root_path = os.path.abspath(os.getcwd())
+     # == load model ==
+     from doclayout_yolo import YOLOv10
+     print(f"Using device: {device}")
+     model = YOLOv10(os.path.join(os.path.dirname(__file__), "models", "DocLayout-YOLO-DocStructBench", "doclayout_yolo_docstructbench_imgsz1024.pt"))  # load an official model
+
+     with open("header.html", "r") as file:
+         header = file.read()
+     with gr.Blocks() as demo:
+         gr.HTML(header)
+
+         with gr.Row():
+             with gr.Column():
+
+                 input_img = gr.Image(label=" ", interactive=True)
+                 with gr.Row():
+                     clear = gr.Button(value="Clear")
+                     predict = gr.Button(value="Detect", interactive=True, variant="primary")
+
+                 with gr.Row():
+                     conf_threshold = gr.Slider(
+                         label="Confidence Threshold",
+                         minimum=0.0,
+                         maximum=1.0,
+                         step=0.05,
+                         value=0.25,
+                     )
+
+                 with gr.Row():
+                     iou_threshold = gr.Slider(
+                         label="NMS IOU Threshold",
+                         minimum=0.0,
+                         maximum=1.0,
+                         step=0.05,
+                         value=0.45,
+                     )
+
+                 with gr.Accordion("Examples:"):
+                     example_root = os.path.join(os.path.dirname(__file__), "assets", "example")
+                     gr.Examples(
+                         examples=[os.path.join(example_root, _) for _ in os.listdir(example_root) if _.endswith("jpg")],
+                         inputs=[input_img],
+                     )
+             with gr.Column():
+                 gr.Button(value="Predict Result:", interactive=False)
+                 output_img = gr.Image(label=" ", interactive=False)
+
+         clear.click(gradio_reset, inputs=None, outputs=[input_img, output_img])
+         predict.click(recognize_image, inputs=[input_img, conf_threshold, iou_threshold], outputs=[output_img])
+
+     demo.launch(server_name="0.0.0.0", server_port=7860, debug=True)
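To sanity-check the detection path without the Gradio UI, a minimal sketch under the same assumptions as the Space (weights snapshot-downloaded as above; `page.jpg` is a hypothetical local test image):

    import torch
    from huggingface_hub import snapshot_download
    from doclayout_yolo import YOLOv10

    # Fetch the same checkpoint the Space downloads at startup.
    snapshot_download('juliozhao/DocLayout-YOLO-DocStructBench', local_dir='./models/DocLayout-YOLO-DocStructBench')
    model = YOLOv10('./models/DocLayout-YOLO-DocStructBench/doclayout_yolo_docstructbench_imgsz1024.pt')

    # Same predict call as recognize_image; 'page.jpg' is a placeholder input.
    det = model.predict('page.jpg', imgsz=1024, conf=0.25,
                        device='cuda' if torch.cuda.is_available() else 'cpu')[0]
    print(det.boxes.xyxy, det.boxes.cls, det.boxes.conf)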