Ankur Goyal committed
Commit · 225fcc2 · Parent: 588673f

Support Donut

Files changed:
- app.py +43 -26
- requirements.txt +1 -0
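In brief: the commit introduces a CHECKPOINTS map pairing each display name with its Hugging Face checkpoint, threads the selected model through construct_pipeline and run_pipeline, adds a model picker next to the input picker, wraps pipeline construction and question processing in spinners, and adds sentencepiece to requirements.txt for the new Donut checkpoint.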
app.py
CHANGED
@@ -19,16 +19,23 @@ def ensure_list(x):
     return [x]


-@st.experimental_singleton
-def construct_pipeline():
+CHECKPOINTS = {
+    "LayoutLMv1 🦉": "impira/layoutlm-document-qa",
+    "Donut 🍩": "naver-clova-ix/donut-base-finetuned-docvqa",
+}
+
+
+@st.experimental_singleton(show_spinner=False)
+def construct_pipeline(model):
     device = "cuda" if torch.cuda.is_available() else "cpu"
-    ret = get_pipeline(device=device)
+    ret = get_pipeline(checkpoint=CHECKPOINTS[model], device=device)
     return ret


-@st.cache
-def run_pipeline(question, document, top_k):
-    return construct_pipeline()(question=question, **document.context, top_k=top_k)
+@st.cache(show_spinner=False)
+def run_pipeline(model, question, document, top_k):
+    pipeline = construct_pipeline(model)
+    return pipeline(question=question, **document.context, top_k=top_k)


 # TODO: Move into docquery

@@ -56,13 +63,14 @@ st.markdown("# DocQuery: Query Documents w/ NLP")
 if "document" not in st.session_state:
     st.session_state["document"] = None

-input_col, model_col = st.columns([2,1])
+input_col, model_col = st.columns([2, 1])

 with input_col:
     input_type = st.radio("Pick an input type", ["Upload", "URL"], horizontal=True)

 with model_col:
-    model_type = st.radio("Pick a model",
+    model_type = st.radio("Pick a model", list(CHECKPOINTS.keys()), horizontal=True)
+

 def load_file_cb():
     if st.session_state.file_input is None:

@@ -109,30 +117,39 @@ if document is not None:

 colors = ["blue", "red", "green"]
 if document is not None and question is not None and len(question) > 0:
-    col2.header("Answers")
+    col2.header(f"Answers ({model_type})")
     with col2:
         answers_placeholder = st.empty()
         answers_loading_placeholder = st.empty()

-    (… 16 removed lines not captured in this page snapshot …)
+    with answers_loading_placeholder:
+        # Run this (one-time) expensive operation outside of the processing
+        # question placeholder
+        with st.spinner("Constructing pipeline..."):
+            construct_pipeline(model_type)
+
+        with st.spinner("Processing question..."):
+            predictions = run_pipeline(
+                model=model_type, question=question, document=document, top_k=1
+            )
+
+    with answers_placeholder:
+        image = image.copy()
+        draw = ImageDraw.Draw(image)
+        for i, p in enumerate(ensure_list(predictions)):
+            col2.markdown(f"#### { p['answer'] }: ({round(p['score'] * 100, 1)}%)")
+            if "start" in p and "end" in p:
+                x1, y1, x2, y2 = normalize_bbox(
+                    expand_bbox(
+                        lift_word_boxes(document)[p["start"] : p["end"] + 1]
+                    ),
+                    image.width,
+                    image.height,
+                )
+                draw.rectangle(((x1, y1), (x2, y2)), outline=colors[i], width=3)

 if document is not None:
-    col1.image(image, use_column_width=
+    col1.image(image, use_column_width="auto")

 "DocQuery uses LayoutLMv1 fine-tuned on DocVQA, a document visual question answering dataset, as well as SQuAD, which boosts its English-language comprehension. To use it, simply upload an image or PDF, type a question, and click 'submit', or click one of the examples to load them."
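Taken together, the new pipeline code can also be exercised outside of Streamlit. The sketch below mirrors construct_pipeline and run_pipeline from the diff above; the diff only shows that get_pipeline accepts checkpoint and device and that a loaded document exposes .context, so the import paths and the load_document call are assumptions, not confirmed API.

import torch
from docquery import get_pipeline, load_document  # hypothetical import paths

CHECKPOINTS = {
    "LayoutLMv1 🦉": "impira/layoutlm-document-qa",
    "Donut 🍩": "naver-clova-ix/donut-base-finetuned-docvqa",
}

# Mirrors construct_pipeline: pick a device and build a pipeline for the
# chosen checkpoint. In the app this call is memoized with
# @st.experimental_singleton so each model is constructed only once.
device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = get_pipeline(checkpoint=CHECKPOINTS["Donut 🍩"], device=device)

# Mirrors run_pipeline: unpack the loaded document's context into the call.
doc = load_document("invoice.png")  # hypothetical helper and file name
predictions = pipe(question="What is the invoice total?", **doc.context, top_k=1)
print(predictions)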
requirements.txt
CHANGED
@@ -1,3 +1,4 @@
 torch
 git+https://github.com/huggingface/transformers.git@21f6f58721dd9154357576be6de54eefef1f1818
 git+https://github.com/impira/docquery.git@43683e0dae72cadf8e8b4927191978109153458c
+sentencepiece
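The new sentencepiece dependency is presumably required by the Donut checkpoint, whose tokenizer is sentencepiece-based; the LayoutLMv1 checkpoint did not need it.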