# Source: Hugging Face file page (author: viklofg), commit 93d3903 "Upload 3 files", 2.41 kB.
import streamlit as st
from transformers import AutoTokenizer, T5ForConditionalGeneration
import post_ocr
# Model: loaded once and shared via Streamlit's resource cache
@st.cache_resource
def load_model():
    """Return the Swedish OCR-correction T5 model.

    Decorated with ``st.cache_resource`` so the checkpoint is loaded once and
    reused across script reruns instead of being reloaded on every interaction.
    """
    checkpoint = 'viklofg/swedish-ocr-correction'
    return T5ForConditionalGeneration.from_pretrained(checkpoint)


model = load_model()
# Tokenizer: loaded once and shared via Streamlit's resource cache
@st.cache_resource
def load_tokenizer():
    """Return the ByT5-small tokenizer paired with the correction model.

    Decorated with ``st.cache_resource`` so it is created once and reused
    across script reruns.
    """
    tokenizer_name = 'google/byt5-small'
    return AutoTokenizer.from_pretrained(tokenizer_name)


tokenizer = load_tokenizer()
# Register the loaded model and tokenizer with the post_ocr helper module.
# NOTE(review): presumably post_ocr.process() (used in handle_input) relies on
# this registration — confirm in post_ocr.
post_ocr.set_model(model, tokenizer)
# Page title
st.title(':memo: Swedish OCR correction')

# Input modes: one tab for typed/pasted text, one for an uploaded file.
# Each tab renders its own output box through handle_input.
tab1, tab2 = st.tabs([
    "Text input",
    "From file",
])
def clean_inputs():
    """Forget the last processed input of every tab.

    With every entry reset to ``None``, the next call to ``handle_input`` sees
    its input as new and regenerates the correction. Also used as the Overlap
    widget's ``on_change`` callback, so a settings change forces regeneration.
    """
    st.session_state.inputs = dict.fromkeys(('tab1', 'tab2'))


# session_state persists across reruns, so initialise only on the first run
if 'inputs' not in st.session_state:
    clean_inputs()
def clean_outputs():
    """Reset the stored correction of every tab to ``None`` (nothing to show)."""
    st.session_state.outputs = dict.fromkeys(('tab1', 'tab2'))


# session_state persists across reruns, so initialise only on the first run
if 'outputs' not in st.session_state:
    clean_outputs()
# Sidebar: generation settings and output display options
with st.sidebar:
    st.header('Settings')
    # Passed to post_ocr.process; changing it clears the cached inputs
    # (clean_inputs) so the next rerun regenerates with the new value.
    n_candidates = st.number_input(
        'Overlap',
        min_value=1,
        max_value=7,
        value=1,
        step=2,
        help='A higher value may lead to better quality, but takes longer time',
        on_change=clean_inputs,
    )

    st.header('Output')
    # When on, the output box shows post_ocr.diff(input, output) instead of
    # the plain corrected text
    show_changes = st.toggle('Show changes')
def handle_input(input_, id_):
    """Correct ``input_`` with the OCR model and render the result in a bordered box.

    The latest input and its correction are kept in ``st.session_state`` per
    source, so the model only runs when the text for this source changes, or
    after ``clean_inputs`` has reset the cached inputs.

    Args:
        input_: Raw OCR text to correct. A falsy value (e.g. an empty text
            area) means there is nothing to correct.
        id_: Key of the input source ('tab1' or 'tab2'), used to index
            ``st.session_state.inputs`` and ``st.session_state.outputs``.
    """
    with st.container(border=True):
        st.caption('Output')

        if not input_:
            # Input was cleared: drop the previous result. Otherwise a stale
            # correction, and a diff against the wrong text, would be shown.
            st.session_state.inputs[id_] = None
            st.session_state.outputs[id_] = None
        elif st.session_state.inputs[id_] != input_:
            # Only regenerate when the input changed since the last run
            with st.spinner('Generating...'):
                output = post_ocr.process(input_, n_candidates)
            # Record the input only after processing succeeded, so a failed
            # run is retried on the next rerun instead of being treated as done
            st.session_state.inputs[id_] = input_
            st.session_state.outputs[id_] = output

        # Display output (either the plain correction or a diff of the changes)
        output = st.session_state.outputs[id_]
        if output is not None:
            st.write(post_ocr.diff(input_, output) if show_changes else output)
# Manual entry tab: text typed or pasted by the user
with tab1:
    typed_input = st.text_area(
        'Input OCR',
        label_visibility='collapsed',
        placeholder='Enter OCR generated text',
    )
    handle_input(typed_input, id_='tab1')
# File upload tab: correct the contents of an uploaded .txt file
with tab2:
    uploaded_file = st.file_uploader('Choose a file', type='.txt')
    if uploaded_file is not None:
        # 'utf-8-sig' decodes plain UTF-8 exactly like 'utf-8' but also strips
        # a leading byte-order mark, which would otherwise reach the model as
        # a stray '\ufeff' character.
        try:
            text = uploaded_file.getvalue().decode('utf-8-sig')
        except UnicodeDecodeError:
            # Report undecodable uploads in the UI instead of crashing the app
            st.error('Could not read the file: it does not appear to be UTF-8 encoded text.')
        else:
            # Display file content
            with st.container(border=True):
                st.caption('File content')
                st.write(text)
            handle_input(text, 'tab2')