"""Streamlit front end for Swedish OCR post-correction.

The page offers two tabs: one for typed/pasted text and one for an uploaded
``.txt`` file.  The text of either tab is passed to ``post_ocr.process`` and
the result is displayed, optionally rendered through ``post_ocr.diff``.
"""

import streamlit as st
from transformers import AutoTokenizer, T5ForConditionalGeneration
import post_ocr


# Load model
@st.cache_resource
def load_model():
    """Return the pretrained correction model (cached by Streamlit)."""
    return T5ForConditionalGeneration.from_pretrained('viklofg/swedish-ocr-correction')


model = load_model()


# Load tokenizer
@st.cache_resource
def load_tokenizer():
    """Return the tokenizer for 'google/byt5-small' (cached by Streamlit)."""
    return AutoTokenizer.from_pretrained('google/byt5-small')


tokenizer = load_tokenizer()

# Set model and tokenizer
post_ocr.set_model(model, tokenizer)

# Title
st.title(':memo: Swedish OCR correction')

# Input and output areas
tab1, tab2 = st.tabs(["Text input", "From file"])


def clean_inputs() -> None:
    """Forget the last processed input of every tab.

    Also used as the ``on_change`` callback of the 'Overlap' setting, so that
    the next run sees the current input as new and regenerates the output.
    """
    st.session_state.inputs = {'tab1': None, 'tab2': None}


if 'inputs' not in st.session_state:
    clean_inputs()


def clean_outputs() -> None:
    """Forget the stored output of every tab."""
    st.session_state.outputs = {'tab1': None, 'tab2': None}


if 'outputs' not in st.session_state:
    clean_outputs()

# Sidebar (settings and stuff)
with st.sidebar:
    st.header('Settings')
    # Passed as the second argument to post_ocr.process; its exact meaning is
    # defined there (NOTE(review): presumably a number of overlapping
    # candidates -- confirm against post_ocr).
    n_candidates = st.number_input(
        'Overlap',
        help='A higher value may lead to better quality, but takes longer time',
        value=1,
        min_value=1,
        max_value=7,
        step=2,
        on_change=clean_inputs,
    )

    st.header('Output')
    show_changes = st.toggle('Show changes')


def handle_input(input_, id_) -> None:
    """Render the output box for one tab, regenerating the output if needed.

    Args:
        input_: Text to correct.  An empty/falsy value clears the tab's
            stored input and output so that no stale result is displayed.
        id_: Key of the tab in ``st.session_state.inputs`` / ``.outputs``
            (``'tab1'`` or ``'tab2'``).

    Reads the module-level ``n_candidates`` and ``show_changes`` settings.
    """
    with st.container(border=True):
        st.caption('Output')

        if not input_:
            # The input was cleared: drop the cached result, otherwise the
            # previous output would remain visible (and be diffed against
            # an empty input).
            st.session_state.inputs[id_] = None
            st.session_state.outputs[id_] = None
        elif st.session_state.inputs[id_] != input_:
            # Only update the output if the input has been updated
            with st.spinner('Generating...'):
                output = post_ocr.process(input_, n_candidates)
            # Record input and output together, and only after a successful
            # run, so a failed generation is retried on the next rerun
            # instead of leaving a stale output paired with the new input.
            st.session_state.inputs[id_] = input_
            st.session_state.outputs[id_] = output

        # Display output
        output = st.session_state.outputs[id_]
        if output is not None:
            st.write(post_ocr.diff(input_, output) if show_changes else output)


# Manual entry tab
with tab1:
    typed_input = st.text_area(
        'Input OCR',
        placeholder='Enter OCR generated text',
        label_visibility='collapsed',
    )
    handle_input(typed_input, 'tab1')

# File upload tab
with tab2:
    uploaded_file = st.file_uploader('Choose a file', type='.txt')

    if uploaded_file is not None:
        try:
            # 'utf-8-sig' decodes plain UTF-8 unchanged and additionally
            # strips a leading byte-order mark, which would otherwise be
            # passed on to the model as part of the text.
            text = uploaded_file.getvalue().decode('utf-8-sig')
        except UnicodeDecodeError:
            # Report the problem in the page instead of crashing the script.
            st.error('Could not read the file: it is not valid UTF-8 text.')
        else:
            # Display file content
            with st.container(border=True):
                st.caption('File content')
                st.write(text)

            handle_input(text, 'tab2')