|
import streamlit as st |
|
from transformers import AutoTokenizer, T5ForConditionalGeneration |
|
import post_ocr |
|
|
|
|
|
|
|
@st.cache_resource
def load_model():
    """Load the T5 model from the 'viklofg/swedish-ocr-correction' checkpoint.

    The function is wrapped in ``st.cache_resource`` so that Streamlit can
    reuse the loaded object instead of calling ``from_pretrained`` again.

    Returns:
        The ``T5ForConditionalGeneration`` instance returned by
        ``from_pretrained``.
    """
    checkpoint = 'viklofg/swedish-ocr-correction'
    return T5ForConditionalGeneration.from_pretrained(checkpoint)


model = load_model()
|
|
|
|
|
|
|
@st.cache_resource
def load_tokenizer():
    """Load the tokenizer from the 'google/byt5-small' checkpoint.

    The function is wrapped in ``st.cache_resource`` so that Streamlit can
    reuse the loaded object instead of calling ``from_pretrained`` again.

    Returns:
        The tokenizer instance returned by ``AutoTokenizer.from_pretrained``.
    """
    checkpoint = 'google/byt5-small'
    return AutoTokenizer.from_pretrained(checkpoint)


tokenizer = load_tokenizer()
|
|
|
|
|
|
|
# Hand the loaded model and tokenizer over to the post_ocr module.
post_ocr.set_model(model, tokenizer)

# Page title.
st.title(':memo: Swedish OCR correction')

# One tab for typed text, one for an uploaded file.
tab_labels = ['Text input', 'From file']
tab1, tab2 = st.tabs(tab_labels)
|
|
|
|
|
def clean_inputs():
    """Reset the remembered input of both tabs to ``None``.

    Replaces ``st.session_state.inputs`` with a fresh dict keyed by
    'tab1' and 'tab2'.
    """
    st.session_state.inputs = dict.fromkeys(('tab1', 'tab2'))


# Initialise the entry only when it is not already in the session state.
if 'inputs' not in st.session_state:
    clean_inputs()
|
|
|
|
|
def clean_outputs():
    """Reset the stored output of both tabs to ``None``.

    Replaces ``st.session_state.outputs`` with a fresh dict keyed by
    'tab1' and 'tab2'.
    """
    st.session_state.outputs = dict.fromkeys(('tab1', 'tab2'))


# Initialise the entry only when it is not already in the session state.
if 'outputs' not in st.session_state:
    clean_outputs()
|
|
|
|
|
|
|
with st.sidebar:
    st.header('Settings')

    # Changing this value calls clean_inputs, which forgets the remembered
    # inputs so that handle_input regenerates the output with the new value.
    n_candidates = st.number_input(
        'Overlap',
        min_value=1,
        max_value=7,
        value=1,
        step=2,
        help='A higher value may lead to better quality, but takes longer time',
        on_change=clean_inputs,
    )

    st.header('Output')

    # When enabled, handle_input renders post_ocr.diff(...) instead of the
    # plain corrected text.
    show_changes = st.toggle('Show changes')
|
|
|
|
|
def handle_input(input_, id_):
    """Correct ``input_`` and render the result in a bordered 'Output' box.

    The corrected text is kept in ``st.session_state.outputs[id_]`` and is
    regenerated with ``post_ocr.process`` only when ``input_`` differs from
    the text remembered in ``st.session_state.inputs[id_]``. Reads the
    module-level ``n_candidates`` and ``show_changes`` settings.

    Args:
        input_: The text to correct. When it is empty (or ``None``) only the
            empty 'Output' box is drawn.
        id_: Key identifying the calling tab in the ``inputs`` / ``outputs``
            session-state dicts ('tab1' or 'tab2').
    """
    with st.container(border=True):
        st.caption('Output')

        if not input_:
            # Nothing to correct. Returning here keeps a result that was
            # generated for an earlier, non-empty input from being shown
            # (and diffed) against an empty input.
            return

        state = st.session_state

        if state.inputs[id_] != input_:
            with st.spinner('Generating...'):
                corrected = post_ocr.process(input_, n_candidates)

            # Remember the input only once processing has succeeded, so a
            # failed run is retried on the next rerun instead of leaving the
            # previous output paired with the new input.
            state.outputs[id_] = corrected
            state.inputs[id_] = input_

        output = state.outputs[id_]
        if output is not None:
            st.write(post_ocr.diff(input_, output) if show_changes else output)
|
|
|
|
|
|
|
with tab1:
    # The label is collapsed, so only the placeholder is visible in the box.
    typed_input = st.text_area(
        'Input OCR',
        placeholder='Enter OCR generated text',
        label_visibility='collapsed',
    )

    handle_input(typed_input, 'tab1')
|
|
|
|
|
|
|
with tab2:
    uploaded_file = st.file_uploader('Choose a file', type='.txt')

    if uploaded_file is not None:
        try:
            # 'utf-8-sig' decodes plain UTF-8 as well, and additionally drops
            # a leading byte-order mark so it does not end up in the text
            # that is displayed and corrected.
            text = uploaded_file.getvalue().decode('utf-8-sig')
        except UnicodeDecodeError:
            # Report an unreadable file in the page instead of letting the
            # script fail with an unhandled exception.
            st.error('Could not read the file: it is not valid UTF-8 text.')
        else:
            # Show the uploaded text, then its corrected version below it.
            with st.container(border=True):
                st.caption('File content')
                st.write(text)

            handle_input(text, 'tab2')
|
|