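# Last commit author: Flavio de Oliveira
# Last commit: 8235d5d ("Cosmetic changes")
# File size: 8.56 kB (header copied from the repository's file viewer; "raw",
# "history" and "blame" were links in that viewer).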
import gradio as gr
import os
from PIL import Image
from transformers import TrOCRProcessor, VisionEncoderDecoderModel, AutoImageProcessor
# import utils
import base64
# from datasets import load_metric
import evaluate
import logging
# Silence transformers' warnings: the "transformers" logger now only lets
# through records at ERROR level or above.
logging.getLogger('transformers').setLevel(logging.ERROR)

# `processor` is only used to decode generated token ids back to text; image
# preprocessing and the fine-tuned encoder-decoder weights both come from the
# Bullinger checkpoint.
_BULLINGER_MODEL_ID = "pstroe/bullinger-general-model"
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
image_processor = AutoImageProcessor.from_pretrained(_BULLINGER_MODEL_ID)
model = VisionEncoderDecoderModel.from_pretrained(_BULLINGER_MODEL_ID)
# Create examples
# Get images and respective transcriptions from the examples directory
def get_example_data(folder_path="./examples/"):
    """Collect ``[image_path, transcription]`` pairs for the Gradio examples panel.

    Every file in *folder_path* whose name ends with ``.png`` becomes one example.
    Its transcription is read from the ``.txt`` file with the same base name
    (``foo.png`` -> ``foo.txt``) and stripped of surrounding whitespace. If that
    file does not exist, the placeholder ``"Transcription not found."`` is used.

    Args:
        folder_path: Directory containing the example images and text files.

    Returns:
        A list of ``[image_path, transcription]`` lists, ordered by file name.

    Raises:
        FileNotFoundError: If *folder_path* itself does not exist (raised by
            ``os.listdir``).
    """
    image_suffix = ".png"
    example_data = []
    # os.listdir() returns entries in arbitrary order; sort so the examples
    # panel shows them in the same order on every start.
    for file_name in sorted(os.listdir(folder_path)):
        if not file_name.endswith(image_suffix):
            continue
        file_path = os.path.join(folder_path, file_name)
        # Swap only the trailing extension; str.replace() would also rewrite a
        # ".png" occurring earlier in the name (e.g. "a.png.png").
        text_file_path = os.path.join(
            folder_path, file_name[: -len(image_suffix)] + ".txt"
        )
        try:
            # Explicit UTF-8 so decoding does not depend on the platform locale
            # (transcriptions may contain non-ASCII characters).
            with open(text_file_path, "r", encoding="utf-8") as f:
                transcription = f.read().strip()
        except FileNotFoundError:
            # No matching .txt file: keep the example with a placeholder text.
            transcription = "Transcription not found."
        example_data.append([file_path, transcription])
    return example_data
# From pstroe's script
# def compute_metrics(pred):
# labels_ids = pred.label_ids
# pred_ids = pred.predictions
# pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
# labels_ids[labels_ids == -100] = processor.tokenizer.pad_token_id
# label_str = processor.batch_decode(labels_ids, skip_special_tokens=True)
# cer = cer_metric.compute(predictions=pred_str, references=label_str)
# return {"cer": cer}
def process_image(image, ground_truth):
    """Transcribe one text-line image and optionally score it against a reference.

    Args:
        image: The image delivered by the ``gr.components.Image(type="pil")``
            input (a PIL image), or ``None`` when nothing was uploaded.
        ground_truth: Optional reference transcription, typed by the user or
            filled in from an example. ``None`` or a blank string disables scoring.

    Returns:
        A ``(generated_text, cer)`` tuple. ``cer`` is the value returned by
        ``cer_metric.compute`` when a reference is given. Otherwise it is the
        string ``"Ground truth not provided"``.

    Raises:
        gr.Error: If no image was provided. Gradio shows a readable message
            instead of a failure inside the image processor.
    """
    if image is None:
        raise gr.Error("Please upload an image (or choose an example) first.")

    # Prepare the image as a PyTorch tensor batch for the encoder.
    pixel_values = image_processor(image, return_tensors="pt").pixel_values
    # Generate with the model's default settings (no beam-search arguments passed).
    generated_ids = model.generate(pixel_values)
    # A batch of one image decodes to a single string.
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

    # Ignore stray leading/trailing whitespace in the reference so it is not
    # counted as character errors.
    reference = ground_truth.strip() if ground_truth else ""
    if not reference:
        return generated_text, "Ground truth not provided"

    logger = logging.getLogger(__name__)
    logger.debug("Prediction (%d chars): %s", len(generated_text), generated_text)
    logger.debug("Reference (%d chars): %s", len(reference), reference)

    # The metric takes parallel lists of predictions and references.
    cer = cer_metric.compute(predictions=[generated_text], references=[reference])
    return generated_text, cer
# One way to use .svg files
# logo_url = "https://www.bullinger-digital.ch/bullinger-digital.svg"
# logo_url = "https://www.cl.uzh.ch/docroot/logos/uzh_logo_e_pos.svg"
# header_html = "<img src='data:image/png;base64,{}' class='img-fluid' width='180px'>".format(
# utils.img_to_bytes(".uzh_logo_e_pos.svg")
# )
# Encode images
def _png_as_base64(path):
    """Read the file at *path* in binary mode and return it base64-encoded as text."""
    with open(path, "rb") as handle:
        raw_bytes = handle.read()
    return base64.b64encode(raw_bytes).decode('utf-8')


# Base64 payloads, inlined as data URIs in the header and footer HTML below.
logo_html = _png_as_base64("assets/uzh_logo_mod.png")
# Alternative footer logo: "assets/bullinger-digital.png"
footer_html = _png_as_base64("assets/bullinger_logo.png")
# App header: `title` is rendered with gr.HTML and `description` with
# gr.Markdown inside the Blocks layout below.
title = """
<h1 style='text-align: center'> TrOCR: Bullinger Dataset</h1>
"""
description = """
Use of Microsoft's [TrOCR](https://arxiv.org/abs/2109.10282), an encoder-decoder model consisting of an \
image Transformer encoder and a text Transformer decoder for state-of-the-art optical character recognition \
(OCR) and handwritten text recognition (HTR) on text line images. \
This particular model was fine-tuned on [Bullinger Dataset](https://github.com/pstroe/bullinger-htr) \
as part of the project [Bullinger Digital](https://www.bullinger-digital.ch)
([References](https://www.cl.uzh.ch/de/people/team/compling/pstroebel.html#Publications)).
* HF `model card`: [pstroe/bullinger-general-model](https://huggingface.co/pstroe/bullinger-general-model) | \
[Flexible Techniques for Automatic Text Recognition of Historical Documents](https://doi.org/10.5167/uzh-234886)
"""
# articles = """
# <p style='text-align: center'><a href='https://arxiv.org/abs/2109.10282'>TrOCR: Transformer-based Optical Character Recognition with Pre-trained Models</a><br>
# <a href='https://doi.org/10.5167/uzh-234886'>Flexible Techniques for Automatic Text Recognition of Historical Documents</a><br>
# <a href='https://zenodo.org/record/7715357'>Bullingers Briefwechsel zugänglich machen: Stand der Handschriftenerkennung</a></p>
# """
# Example gallery: one [image path, transcription] pair per .png file found
# in ./examples/ (transcriptions come from the matching .txt files).
examples = get_example_data(folder_path="./examples/")

# Character error rate (CER) metric loaded through the `evaluate` library,
# which replaces the deprecated `datasets.load_metric("cer")`. Requires
# `pip install evaluate jiwer`.
cer_metric = evaluate.load("cer")
# Gradio UI: header logo, title/description, input and output panels, an
# example gallery, a footer, and the button event wiring.
with gr.Blocks(
    theme=gr.themes.Soft(),
    title="TrOCR Bullinger",
) as demo:
    gr.HTML(
        f"""
<div style='display: flex; justify-content: right; width: 100%;'>
<img src='data:image/png;base64,{logo_html}' class='img-fluid' width='200px'>
</div>
"""
    )
    # 174x60 (presumably the header logo's pixel size)
    # Render the module-level strings without rebinding `title`/`description`
    # to component objects.
    gr.HTML(title)
    gr.Markdown(description)
    with gr.Row():
        with gr.Column(variant="panel"):
            # Named `input_image` to avoid shadowing the builtin `input`.
            input_image = gr.components.Image(type="pil", label="Input image:")
            with gr.Row():
                btn_clear = gr.Button(value="Clear")
                btn_submit = gr.Button(value="Submit")
        with gr.Column(variant="panel"):
            output = gr.components.Textbox(label="Generated text:")
            ground_truth = gr.components.Textbox(value="", placeholder="Provide the ground truth, if available.", label="Ground truth:")
            cer_output = gr.components.Textbox(label="CER:")
    with gr.Row():
        with gr.Accordion(label="Choose an example from test set:", open=False):
            # Clicking an example fills both the image and the ground-truth box.
            gr.Examples(
                examples=examples,
                inputs=[input_image, ground_truth],
                label=None,
            )
    with gr.Row():
        # gr.HTML(
        #     f"""
        #     <div style="display: flex; align-items: center; justify-content: center">
        #     <img src="data:image/png;base64,{footer_html}" style="width: 150px; height: 60px; object-fit: contain; margin-right: 5px; margin-bottom: 5px">
        #     <p style="font-size: 13px">
        #     Bullinger Digital | Institut für Computerlinguistik, Universität Zürich, 2023
        #     </p>
        #     </div>
        #     """
        # )
        gr.HTML(
            f"""
<div style="display: flex; align-items: center; justify-content: center">
<img src="data:image/png;base64,{footer_html}" style="height: 40px; object-fit: contain; margin-right: 5px; margin-bottom: 5px">
<p style="font-size: 13px">
<strong>Bullinger</strong><u>Digital</u> | Institut für Computerlinguistik, Universität Zürich, 2023
</p>
</div>
"""
        )
    # 383x85 (presumably the footer logo's pixel size)
    btn_submit.click(process_image, inputs=[input_image, ground_truth], outputs=[output, cer_output])
    # Reset the image and all three text boxes.
    btn_clear.click(lambda: [None, "", "", ""], outputs=[input_image, output, ground_truth, cer_output])
# # Try to force light mode
# js = """
# function () {
# gradioURL = window.location.href
# if (!gradioURL.endsWith('?__theme=light')) {
# window.location.replace(gradioURL + '?__theme=light');
# }
# }"""
# demo.load(_js=js)
if __name__ == "__main__":
    # Start the Gradio server only when run as a script (not on import),
    # using icon.png as the browser favicon.
    demo.launch(favicon_path="icon.png")