import numpy as np
import torch
from torch import nn
import streamlit as st
import os

from PIL import Image
from io import BytesIO
import transformers
from transformers import VisionEncoderDecoderModel, DonutProcessor

from logits_ngrams import NoRepeatNGramLogitsProcessor, get_table_token_ids
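# logits_ngrams is a local helper module: NoRepeatNGramLogitsProcessor is
# assumed to mirror transformers' built-in n-gram blocking while skipping the
# table-markup token ids returned by get_table_token_ids, so legitimately
# repetitive table structure is not suppressed.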

def run_prediction(sample, model, processor, mode):
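    """Run Chipper generation on a single PIL image and return the decoded string.

    `mode` selects the task prompt: "OCR" for plain text extraction, anything
    else for hierarchical element annotation. Relies on the module-level
    `device`, which is defined further down but set before this function runs.
    """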

    # Token ids the n-gram blocker must ignore (table markup repeats legitimately).
    skip_tokens = get_table_token_ids(processor)
    no_repeat_ngram_size = 15

    if mode == "OCR":
        prompt = "<s><s_pretraining>"
    else:
        prompt = "<s><s_hierarchical>"


    print("prompt:", prompt)
    print("no_repeat_ngram_size:", no_repeat_ngram_size)

    # Preprocess the image into the pixel tensor expected by the vision encoder.
    pixel_values = processor(
        np.array(sample, dtype=np.float32), return_tensors="pt"
    ).pixel_values

    transformers.set_seed(42)  # fix the sampling seed so outputs are reproducible
    with torch.no_grad():
        outputs = model.generate(
            pixel_values.to(device),
            decoder_input_ids=processor.tokenizer(prompt, add_special_tokens=False, return_tensors="pt").input_ids.to(device),
            # Custom blocker from logits_ngrams; skips the table-markup tokens.
            logits_processor=[NoRepeatNGramLogitsProcessor(no_repeat_ngram_size, skip_tokens)],
            do_sample=True,
            top_p=0.92,
            top_k=5,
            no_repeat_ngram_size=no_repeat_ngram_size,
            num_beams=3,
            output_attentions=False,
            output_hidden_states=False,
        )

    # Decode the generated ids back to a string.
    prediction = processor.batch_decode(outputs)[0]
    print(prediction)

    return prediction


logo = Image.open("./rsz_unstructured_logo.png")
st.image(logo)

st.markdown('''
### Chipper
Chipper is an OCR-free Document Understanding Transformer. It was pre-trained on over 1M documents from public sources and fine-tuned on a wide range of document types.

At [Unstructured.io](https://github.com/Unstructured-IO/unstructured) we are on a mission to build custom preprocessing pipelines for labeling, training, and production ML workflows.
Come join us in our public repos and contribute! Every contribution and piece of feedback is valuable to the community.
''')

image_upload = None
with st.sidebar:
    # file upload
    uploaded_file = st.file_uploader("Upload a document")
    if uploaded_file is not None:
        # To read file as bytes:
        image_bytes_data = uploaded_file.getvalue()
        image_upload = Image.open(BytesIO(image_bytes_data))

    mode = st.selectbox('Mode', ('OCR', 'Element annotation'), index=1)

if image_upload:
    image = image_upload
else:
    image = Image.open("./document.png")

st.image(image, caption='Your target document')

with st.spinner('Processing the document ...'):
        pre_trained_model = "unstructuredio/chipper-v3"
        processor = DonutProcessor.from_pretrained(pre_trained_model, token=os.environ['HF_TOKEN'])
        
        # Prefer GPU when available; CPU generation is markedly slower.
        device = "cuda" if torch.cuda.is_available() else "cpu"

        # Cache the model in st.session_state so Streamlit reruns skip reloading it.
        if 'model' in st.session_state:
            model = st.session_state['model']
        else:
            model = VisionEncoderDecoderModel.from_pretrained(pre_trained_model, token=os.environ['HF_TOKEN'])

            from huggingface_hub import hf_hub_download

            # Chipper v3 stores its language-model head as a separate checkpoint;
            # download it and rebuild a matching module before loading the weights.
            lm_head_file = hf_hub_download(
                repo_id=pre_trained_model, filename="lm_head.pth", token=os.environ['HF_TOKEN']
            )

            # Replace the dense lm_head with the rank-128 factorized head that
            # matches the shape of the downloaded checkpoint.
            rank = 128
            model.decoder.lm_head = nn.Sequential(
                nn.Linear(model.decoder.lm_head.weight.shape[1], rank, bias=False),
                nn.Linear(rank, rank, bias=False),
                nn.Linear(rank, model.decoder.lm_head.weight.shape[0], bias=True),
            )

            # Load on CPU first; the whole model is moved to `device` below.
            model.decoder.lm_head.load_state_dict(torch.load(lm_head_file, map_location="cpu"))

            model.eval()
            model.to(device)
            st.session_state['model'] = model

st.info('Parsing document')
parsed_info = run_prediction(image.convert("RGB"), model, processor, mode)
st.text('\nDocument:')
st.text_area('Output text', value=parsed_info, height=800)
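
# To run locally (assuming this script is saved as app.py and a valid
# Hugging Face token is exported as HF_TOKEN):
#   streamlit run app.py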