import base64
import logging
import os

import evaluate
import gradio as gr
from PIL import Image
from transformers import TrOCRProcessor, VisionEncoderDecoderModel, AutoImageProcessor
# import utils
# from datasets import load_metric

# Only show log messages that are at the ERROR level or above,
# effectively filtering out any warnings
logging.getLogger('transformers').setLevel(logging.ERROR)

# The tokenizer/decoder side comes from the base TrOCR checkpoint, while the
# image preprocessing and the weights come from the Bullinger model.
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
image_processor = AutoImageProcessor.from_pretrained("pstroe/bullinger-general-model")
model = VisionEncoderDecoderModel.from_pretrained("pstroe/bullinger-general-model")

_IMAGE_SUFFIX = ".png"
_TEXT_SUFFIX = ".txt"


# Create examples
def get_example_data(folder_path="./examples/"):
    """Collect example images and their transcriptions from a folder.

    Every file in ``folder_path`` whose name ends in ``.png`` is paired with
    the ``.txt`` file of the same base name.

    Args:
        folder_path: Directory containing the ``.png`` / ``.txt`` pairs.

    Returns:
        A list of ``[image_path, transcription]`` pairs, ordered by file
        name. When the ``.txt`` file is missing, the transcription is the
        placeholder ``"Transcription not found."``.

    Raises:
        FileNotFoundError: If ``folder_path`` itself does not exist
            (propagated from ``os.listdir``).
    """
    example_data = []

    # Sorted so the example order is stable across platforms and runs;
    # os.listdir returns entries in arbitrary order.
    for file_name in sorted(os.listdir(folder_path)):
        if not file_name.endswith(_IMAGE_SUFFIX):
            continue

        file_path = os.path.join(folder_path, file_name)

        # Swap only the trailing suffix. str.replace would also rewrite a
        # ".png" occurring in the middle of the name and miss the text file.
        text_file_name = file_name[: -len(_IMAGE_SUFFIX)] + _TEXT_SUFFIX
        text_file_path = os.path.join(folder_path, text_file_name)

        try:
            # Explicit encoding: the platform default is locale-dependent and
            # can fail on non-ASCII transcriptions.
            # NOTE(review): assumes the transcription files are UTF-8 - confirm.
            with open(text_file_path, "r", encoding="utf-8") as f:
                transcription = f.read().strip()
        except FileNotFoundError:
            # A missing transcription is not an error; show a placeholder.
            transcription = "Transcription not found."

        example_data.append([file_path, transcription])

    return example_data


# From pstroe's script
# def compute_metrics(pred):
#     labels_ids = pred.label_ids
#     pred_ids = pred.predictions
#     pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
#     labels_ids[labels_ids == -100] = processor.tokenizer.pad_token_id
#     label_str = processor.batch_decode(labels_ids, skip_special_tokens=True)
#     cer = cer_metric.compute(predictions=pred_str, references=label_str)
#     return {"cer": cer}


def process_image(image, ground_truth):
    """Transcribe an image and optionally score it against a reference.

    Args:
        image: Image handed to ``image_processor``.
        ground_truth: Reference transcription, or ``None`` / blank to skip
            scoring.

    Returns:
        A ``(generated_text, cer)`` tuple. ``cer`` is whatever
        ``cer_metric.compute`` returns when a non-blank ground truth is
        given, otherwise the string ``"Ground truth not provided"``.
    """
    # prepare image
    pixel_values = image_processor(image, return_tensors="pt").pixel_values

    # generate with the model's default generation settings
    # (the original note here reads "no beam search")
    generated_ids = model.generate(pixel_values)

    # decode; a single image goes in, so only the first result is used
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

    if ground_truth is None or ground_truth.strip() == "":
        return generated_text, "Ground truth not provided"

    # Debug output. Both values are strings, so these are character counts
    # and will usually differ; the metric call below does not depend on them.
    print("Number of predictions:", len(generated_text))
    print("Number of references:", len(ground_truth))

    # Check if lengths match
    if len(generated_text) != len(ground_truth):
        print("Mismatch in number of predictions and references.")
        print("Predictions:", generated_text)
        print("References:", ground_truth)
        print("\n")

    # NOTE(review): cer_metric is not defined in this part of the file;
    # presumably created elsewhere via evaluate.load("cer") - confirm.
    cer = cer_metric.compute(predictions=[generated_text], references=[ground_truth])
    # cer = f"{cer:.3f}"

    return generated_text, cer


# One way to use .svg files
# logo_url = "https://www.bullinger-digital.ch/bullinger-digital.svg"
# logo_url = "https://www.cl.uzh.ch/docroot/logos/uzh_logo_e_pos.svg"
# header_html = "".format(
#     utils.img_to_bytes(".uzh_logo_e_pos.svg")
# )

# Encode images
# Read the header logo and keep it as base64 text.
with open("assets/uzh_logo_mod.png", "rb") as img_file:
    logo_html = base64.b64encode(img_file.read()).decode('utf-8')
# with open("assets/bullinger-digital.png", "rb") as img_file: with open("assets/bullinger_logo.png", "rb") as img_file: footer_html = base64.b64encode(img_file.read()).decode('utf-8') # App header title = """
TrOCR: Transformer-based Optical Character Recognition with Pre-trained Models
# Flexible Techniques for Automatic Text Recognition of Historical Documents
# Bullingers Briefwechsel zugänglich machen: Stand der Handschriftenerkennung
# Bullinger Digital | Institut für Computerlinguistik, Universität Zürich, 2023 #
#BullingerDigital | Institut für Computerlinguistik, Universität Zürich, 2023