greyeye124 committed on
Commit
3617214
1 Parent(s): 8d380e0

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +62 -0
app.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import gradio as gr
3
+ import torch
4
+ from PIL import Image
5
+
6
+ from donut import DonutModel
7
+
8
+
9
def demo_process_vqa(input_img, question):
    """Answer ``question`` about ``input_img`` using the document-VQA prompt.

    Args:
        input_img: The uploaded document image. Anything that is not already
            a PIL image (Gradio's ``"image"`` input passes a NumPy array by
            default) is converted with ``Image.fromarray`` first.
        question: The user's question; it replaces ``{user_input}`` in the
            module-level ``task_prompt`` template.

    Returns:
        The first entry of the ``"predictions"`` list returned by
        ``pretrained_model.inference``.
    """
    # `demo_process` converts its input with Image.fromarray before running
    # inference; do the same here so both handlers feed the model the same
    # kind of object. The isinstance guard keeps an already-PIL input working.
    if not isinstance(input_img, Image.Image):
        input_img = Image.fromarray(input_img)

    user_prompt = task_prompt.replace("{user_input}", question)
    result = pretrained_model.inference(input_img, prompt=user_prompt)
    return result["predictions"][0]
15
+
16
+
17
def demo_process(input_img):
    """Classify the document and extract its fields only if it is an invoice.

    The image is first run through ``security_layer`` with the
    ``<s_rvlcdip>`` prompt. When the predicted class is ``"invoice"`` the
    image is then run through ``pretrained_model`` with the ``<s_cord-v2>``
    prompt and that prediction is returned; otherwise the classifier's own
    prediction is returned unchanged.

    Args:
        input_img: Array-like image data accepted by ``Image.fromarray``.

    Returns:
        The first prediction of the extraction model for invoices, or the
        first prediction of the classifier for everything else.
    """
    input_img = Image.fromarray(input_img)

    sec = security_layer.inference(image=input_img, prompt="<s_rvlcdip>")["predictions"][0]
    print(sec)

    # Look the class up defensively: a prediction without a "class" key (or
    # one that is not a dict at all) is treated as "not an invoice" and
    # returned as-is instead of raising from the lookup.
    doc_class = sec.get("class") if isinstance(sec, dict) else None
    if doc_class == "invoice":
        extraction = pretrained_model.inference(image=input_img, prompt="<s_cord-v2>")
        return extraction["predictions"][0]
    return sec
26
+
27
# Task this demo is configured for. It selects the prompt template below and
# is read again when the Gradio interface is built.
task_name = "cord-v2"

if task_name == "docvqa":
    # The VQA handler substitutes the user's question for "{user_input}".
    task_prompt = "<s_docvqa><s_question>{user_input}</s_question><s_answer>"
else:  # rvlcdip, cord, ...
    task_prompt = f"<s_{task_name}>"
32
+
33
# Document classifier: `demo_process` compares its predicted "class" with
# "invoice" before running the extraction model.
security_layer = DonutModel.from_pretrained("naver-clova-ix/donut-base-finetuned-rvlcdip")

# Extraction model used by both request handlers.
pretrained_model = DonutModel.from_pretrained("naver-clova-ix/donut-base-finetuned-cord-v2")

if torch.cuda.is_available():
    # GPU path: cast both models to half precision, then move them to CUDA.
    pretrained_model.half()
    security_layer.half()
    device = torch.device("cuda")
    pretrained_model.to(device)
    security_layer.to(device)
else:
    # CPU path: only the encoders are cast, and to bfloat16 rather than fp16.
    # NOTE(review): presumably this matches the dtype DonutModel.inference
    # gives the image tensor on CPU -- confirm against the donut package.
    pretrained_model.encoder.to(torch.bfloat16)
    security_layer.encoder.to(torch.bfloat16)

# Put both models in evaluation mode before serving requests.
pretrained_model.eval()
security_layer.eval()
50
+
51
+
52
# Build the UI. The handler and its inputs depend on the configured task:
# the VQA handler takes an image plus a question, the default handler takes
# only an image. Results are rendered as JSON.
demo = gr.Interface(
    fn=demo_process_vqa if task_name == "docvqa" else demo_process,
    inputs=["image", "text"] if task_name == "docvqa" else "image",
    outputs="json",
    title=f"Donut 🍩 demonstration for `{task_name}` task",
    # NOTE(review): this explicit per-event limit (10) and the queue-wide
    # default below (2) disagree. Per the Gradio docs the default only applies
    # to events that do not set their own limit -- confirm which is intended.
    concurrency_limit=10,
    description="Get invoice details if invoice",
)

# Queue requests (at most 5 waiting), then start the server with a public
# share link and debug output enabled.
demo.queue(default_concurrency_limit=2, max_size=5)
demo.launch(debug=True, share=True, inline=False)