|
import base64
import logging
import os

import evaluate
import gradio as gr
from PIL import Image
from transformers import TrOCRProcessor, VisionEncoderDecoderModel, AutoImageProcessor

# Silence verbose transformers warnings (e.g. generation config notices).
logging.getLogger("transformers").setLevel(logging.ERROR)

# Decoding/tokenization comes from the base TrOCR checkpoint; the image
# preprocessing and the fine-tuned weights come from the Bullinger general model.
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
image_processor = AutoImageProcessor.from_pretrained("pstroe/bullinger-general-model")
model = VisionEncoderDecoderModel.from_pretrained("pstroe/bullinger-general-model")

# Character Error Rate metric, computed whenever a ground truth is supplied.
cer_metric = evaluate.load("cer")
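
# For reference, the same pipeline can be driven outside Gradio. A minimal
# sketch, assuming a local text-line image "line.png" (hypothetical path,
# not used by this app):
#
#     img = Image.open("line.png").convert("RGB")
#     pixel_values = image_processor(img, return_tensors="pt").pixel_values
#     generated = model.generate(pixel_values)
#     print(processor.batch_decode(generated, skip_special_tokens=True)[0])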
|
def get_example_data(folder_path="./examples/"):
    """Pair each PNG in `folder_path` with its same-named .txt transcription."""
    example_data = []

    for file_name in os.listdir(folder_path):
        file_path = os.path.join(folder_path, file_name)

        if file_name.endswith(".png"):
            corresponding_text_file_name = file_name.replace(".png", ".txt")
            corresponding_text_file_path = os.path.join(folder_path, corresponding_text_file_name)

            # Fall back to a placeholder when no transcription file exists.
            transcription = "Transcription not found."
            try:
                with open(corresponding_text_file_path, "r") as f:
                    transcription = f.read().strip()
            except FileNotFoundError:
                pass

            example_data.append([file_path, transcription])

    return example_data

def process_image(image, ground_truth):
    """Transcribe a text-line image; compute CER against `ground_truth` if given."""
    # Preprocess the image into the pixel tensor the encoder expects.
    pixel_values = image_processor(image, return_tensors="pt").pixel_values

    # Generate token ids autoregressively, then decode them into a string.
    generated_ids = model.generate(pixel_values)
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

    if ground_truth is not None and ground_truth.strip() != "":
        # Debug output: these are character counts; CER does not require the
        # prediction and reference to have equal length.
        print("Prediction length (chars):", len(generated_text))
        print("Reference length (chars):", len(ground_truth))

        if len(generated_text) != len(ground_truth):
            print("Prediction and reference differ in length.")
            print("Prediction:", generated_text)
            print("Reference:", ground_truth)
            print("\n")

        cer = cer_metric.compute(predictions=[generated_text], references=[ground_truth])
    else:
        cer = "Ground truth not provided"

    return generated_text, cer
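
# CER is character-level edit distance divided by reference length, so 0.0 is
# a perfect transcription. An illustrative check (hypothetical strings):
#
#     cer_metric.compute(predictions=["hallo"], references=["hallo"])   # 0.0
#     cer_metric.compute(predictions=["hallo"], references=["halloo"])  # 1/6 ≈ 0.17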
|
with open("assets/uzh_logo_mod.png", "rb") as img_file: |
|
logo_html = base64.b64encode(img_file.read()).decode('utf-8') |
|
|
|
|
|
with open("assets/bullinger_logo.png", "rb") as img_file: |
|
footer_html = base64.b64encode(img_file.read()).decode('utf-8') |
|
|
|
|
|
title = """ |
|
<h1 style='text-align: center'> TrOCR: Bullinger Dataset</p> |
|
""" |
|
|
|
description = """ |
|
Use of Microsoft's [TrOCR](https://arxiv.org/abs/2109.10282), an encoder-decoder model consisting of an \ |
|
image Transformer encoder and a text Transformer decoder for state-of-the-art optical character recognition \ |
|
(OCR) and handwritten text recognition (HTR) on text line images. \ |
|
This particular model was fine-tuned on [Bullinger Dataset](https://github.com/pstroe/bullinger-htr) \ |
|
as part of the project [Bullinger Digital](https://www.bullinger-digital.ch) |
|
([References](https://www.cl.uzh.ch/de/people/team/compling/pstroebel.html#Publications)). |
|
* HF `model card`: [pstroe/bullinger-general-model](https://huggingface.co/pstroe/bullinger-general-model) | \ |
|
[Flexible Techniques for Automatic Text Recognition of Historical Documents](https://doi.org/10.5167/uzh-234886) |
|
""" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Example gallery entries are read once at startup.
examples = get_example_data()
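
# Assumed layout under ./examples/ (hypothetical file names): pairs such as
# line_001.png / line_001.txt; a PNG without a matching .txt falls back to the
# "Transcription not found." placeholder above.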
|
with gr.Blocks(
    theme=gr.themes.Soft(),
    title="TrOCR Bullinger",
) as demo:

    # Right-aligned university logo in the header.
    gr.HTML(
        f"""
        <div style='display: flex; justify-content: right; width: 100%;'>
            <img src='data:image/png;base64,{logo_html}' class='img-fluid' width='200px'>
        </div>
        """
    )

    gr.HTML(title)
    gr.Markdown(description)
|
|
|
    with gr.Row():

        with gr.Column(variant="panel"):
            input_image = gr.Image(type="pil", label="Input image:")

            with gr.Row():
                btn_clear = gr.Button(value="Clear")
                btn_submit = gr.Button(value="Submit")

        with gr.Column(variant="panel"):
            output = gr.Textbox(label="Generated text:")
            ground_truth = gr.Textbox(value="", placeholder="Provide the ground truth, if available.", label="Ground truth:")
            cer_output = gr.Textbox(label="CER:")

    with gr.Row():

        with gr.Accordion(label="Choose an example from the test set:", open=False):
            gr.Examples(
                examples=examples,
                inputs=[input_image, ground_truth],
                label=None,
            )
|
|
|
    with gr.Row():

        gr.HTML(
            f"""
            <div style="display: flex; align-items: center; justify-content: center">
                <img src="data:image/png;base64,{footer_html}" style="height: 40px; object-fit: contain; margin-right: 5px; margin-bottom: 5px">
                <p style="font-size: 13px">
                    <strong>Bullinger</strong><u>Digital</u> | Institut für Computerlinguistik, Universität Zürich, 2023
                </p>
            </div>
            """
        )
|
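    # Wire up the controls: Submit runs OCR (plus CER when a ground truth is
    # provided); Clear resets the image and all three text fields.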
    btn_submit.click(process_image, inputs=[input_image, ground_truth], outputs=[output, cer_output])
    btn_clear.click(lambda: [None, "", "", ""], outputs=[input_image, output, ground_truth, cer_output])

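# demo.launch() serves on http://127.0.0.1:7860 by default; standard Gradio
# options such as share=True or server_name="0.0.0.0" (not used here) would
# expose the app beyond localhost.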
if __name__ == "__main__":
    demo.launch(favicon_path="icon.png")
|
|