"""
Small Gradio demo application for the Document Image Enhancement (DIE) model.

The script downloads the model checkpoint from the Hugging Face Hub, wraps the
inference function in a ``gr.Interface`` and serves it on port 7860.
"""

import argparse
import os
from functools import partial

import gradio as gr
from PIL import Image
from huggingface_hub import hf_hub_download

from die_model import UNetDIEModel
from utils import (
    resize_image,
    make_image_square,
    cast_pil_image_to_torch_tensor_with_4_channel_dim,
    remove_square_padding,
)


def die_inference(
        image_raw,
        num_of_die_iterations,
        die_model,
        device
):
    """
    Function to run the DIE model on a single image.

    :param image_raw: raw image as delivered by the Gradio image input (a PIL image,
        or None when the user submits without providing an image)
    :param num_of_die_iterations: number of DIE iterations; may arrive as a string
        or an int and is cast to int before use
    :param die_model: DIE model exposing ``enhance_document_image``
    :param device: device the input tensor is moved to (e.g. "cpu")
    :return: cleaned image, resized back to the size of the original input
    :raises gr.Error: if no input image was provided
    """

    # Gradio passes None when the form is submitted without an image; fail with
    # a readable message instead of crashing inside the preprocessing helpers.
    if image_raw is None:
        raise gr.Error("Please provide a document image before running the model.")

    # preprocess
    # NOTE(review): 1500 is presumably the target/maximum side length used by
    # resize_image -- confirm against utils.resize_image.
    image_raw_resized = resize_image(image_raw, 1500)
    image_raw_resized_square = make_image_square(image_raw_resized)
    image_raw_resized_square_tensor = cast_pil_image_to_torch_tensor_with_4_channel_dim(image_raw_resized_square)
    image_raw_resized_square_tensor = image_raw_resized_square_tensor.to(device)

    # the dropdown value may be delivered as a string, so convert it to int
    num_of_die_iterations = int(num_of_die_iterations)

    # inference: the model works on a list of images, we pass one and take the single result
    image_die = die_model.enhance_document_image(
        image_raw_list=[image_raw_resized_square_tensor],
        num_of_die_iterations=num_of_die_iterations
    )[0]

    # postprocess: undo the square padding and restore the original size
    image_die_resized = remove_square_padding(
        original_image=image_raw,
        square_image=image_die,
        resize_back_to_original=True
    )

    return image_die_resized


def _load_example_images(example_image_dir):
    """
    Load the example images shown below the Gradio interface.

    Entries are processed in sorted order so the example gallery is stable across
    runs. Sub-directories and files that PIL cannot read are skipped, and a missing
    directory yields an empty list, so a broken example folder never prevents the
    demo from starting.

    :param example_image_dir: directory that contains the example images
    :return: list of single-element lists, each holding one PIL image
        (one row per example, one column per input component)
    """

    if not os.path.isdir(example_image_dir):
        print(f"Example image directory not found, starting without examples: {example_image_dir}")
        return []

    example_image_list = []
    for file_name in sorted(os.listdir(example_image_dir)):
        file_path = os.path.join(example_image_dir, file_name)

        if not os.path.isfile(file_path):
            continue

        try:
            # Image.open is lazy and keeps the file handle open; copying inside the
            # context manager reads the pixels and lets the file be closed right away.
            with Image.open(file_path) as image_file:
                example_image_list.append([image_file.copy()])
        except OSError as error:
            # PIL's UnidentifiedImageError is a subclass of OSError, so this also
            # covers files that are not images at all.
            print(f"Skipping unreadable example file {file_path}: {error}")

    return example_image_list


def main():
    """
    Main function to run the Gradio demo.

    Parses the command line arguments, downloads the model checkpoint from the
    Hugging Face Hub, builds the Gradio interface and launches the server.

    :return: None
    """

    args = parse_arguments()

    description = (
        "Welcome to the Document Image Enhancement (DIE) model demo on Hugging Face!\n\n"

        "This interactive application showcases a specialized AI model developed by "
        "the [Artificial Intelligence group](https://ai.renyi.hu) at the [Alfréd Rényi Institute of Mathematics](https://renyi.hu).\n\n"

        "Our DIE model is designed to enhance and restore archival and aged document images "
        "by removing various types of degradation, thereby making historical documents more legible "
        "and suitable for Optical Character Recognition (OCR) processing.\n\n"

        "The model effectively tackles 20-30 types of domain-specific noise found in historical records, "
        "such as scribbles, bleed-through text, faded or worn text, blurriness, textured noise, "
        "and unwanted background elements. "
        "By applying deep learning techniques, specifically a U-Net-based architecture, "
        "the model accurately cleans and clarifies text while preserving original details. "
        "This improved clarity dramatically boosts OCR accuracy, making it an ideal "
        "pre-processing tool in digitization workflows.\n\n"

        "If you’re interested in learning more about the model’s capabilities or potential applications, "
        "please contact us at: gabar92@renyi.hu.\n\n"
    )

    # TODO: Add a description for the Number of DIE iterations parameter!
    num_of_die_iterations_list = [1, 2, 3]

    # Access token for the model repository; None when the variable is not set
    die_token = os.getenv("DIE_TOKEN")

    # Provide images alone for example display
    example_image_list = _load_example_images(args.example_image_path)

    # Load DIE model: replace the checkpoint file name with the local path of the downloaded file
    # NOTE(review): newer huggingface_hub releases call this parameter `token`;
    # confirm against the pinned library version before renaming it.
    args.die_model_path = hf_hub_download(
        repo_id="gabar92/die",
        filename=args.die_model_path,
        use_auth_token=die_token
    )
    die_model = UNetDIEModel(args=args)

    # Partially apply the model and device arguments to die_inference, so Gradio
    # only has to supply the image and the number of iterations
    partial_die_inference = partial(die_inference, device=args.device, die_model=die_model)

    demo = gr.Interface(
        fn=partial_die_inference,
        inputs=[
            gr.Image(type="pil", label="Degraded Document Image"),
            gr.Dropdown(num_of_die_iterations_list, label="Number of DIE iterations", value=1),
        ],
        outputs=gr.Image(type="pil", label="Clean Document Image"),
        title="Document Image Enhancement (DIE) model",
        description=description,
        # pass None instead of an empty list when no example could be loaded
        examples=example_image_list or None
    )

    # listen on all interfaces so the demo is reachable from outside the container
    demo.launch(server_name="0.0.0.0", server_port=7860)


def parse_arguments(argv=None):
    """
    Parse arguments.

    :param argv: optional list of argument strings; when None (the default) the
        arguments are read from ``sys.argv``
    :return: argument namespace
    """

    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--die_model_path",
        default="2024_08_09_model_epoch_89.pt",
        help="file name of the model checkpoint inside the Hugging Face repository"
    )
    parser.add_argument(
        "--device",
        default="cpu",
        help="device used for inference"
    )
    parser.add_argument(
        "--example_image_path",
        default="example_images",
        help="directory that contains the example images"
    )

    return parser.parse_args(argv)


if __name__ == "__main__":
    main()