"""
Small demo application to explore Gradio.

Serves a Gradio web UI around the Document Image Enhancement (DIE) model:
the user uploads a degraded document image, chooses a number of DIE
iterations, and gets the enhanced image back.
"""
import argparse
import os
from functools import partial

import gradio as gr
from PIL import Image
from huggingface_hub import hf_hub_download

from die_model import UNetDIEModel
from utils import resize_image, make_image_square, cast_pil_image_to_torch_tensor_with_4_channel_dim, \
    remove_square_padding


# Hugging Face Hub repository holding the (private) demo assets.
_HUB_REPO_ID = "gabar92/DIE"

# Example image file names inside ``<example_image_path>/`` of the Hub repository.
_EXAMPLE_IMAGE_FILENAMES = (
    "_002.jpg",
    "_004.jpg",
    "_005.jpg",
    "_006.jpg",
    "_007.jpg",
    "_010.jpg",
    "_097.jpg",
)

# Local directory the model checkpoint is resolved against; presumably a checkout of
# the gabar92/DIE repository -- confirm for the deployment environment.
_LOCAL_MODEL_DIR = "./DIE"

# Iteration count used when the dropdown delivers no value (matches its default).
_DEFAULT_NUM_OF_DIE_ITERATIONS = 1

# Markdown shown above the interface.
_DESCRIPTION = (
    "Welcome to the Document Image Enhancement (DIE) model demo on Hugging Face!\n\n"
    "This interactive application showcases a specialized AI model developed by "
    "the [Artificial Intelligence group](https://ai.renyi.hu) at the [Alfréd Rényi Institute of Mathematics](https://renyi.hu).\n\n"
    "Our DIE model is designed to enhance and restore archival and aged document images "
    "by removing various types of degradation, thereby making historical documents more legible "
    "and suitable for Optical Character Recognition (OCR) processing.\n\n"
    "The model effectively tackles 20-30 types of domain-specific noise found in historical records, "
    "such as scribbles, bleed-through text, faded or worn text, blurriness, textured noise, "
    "and unwanted background elements. "
    "By applying deep learning techniques, specifically a U-Net-based architecture, "
    "the model accurately cleans and clarifies text while preserving original details. "
    "This improved clarity dramatically boosts OCR accuracy, making it an ideal "
    "pre-processing tool in digitization workflows.\n\n"
    "If you’re interested in learning more about the model’s capabilities or potential applications, "
    "please contact us at: gabar92@renyi.hu.\n\n"
)


def die_inference(
        image_raw,
        num_of_die_iterations,
        die_model,
        device
):
    """
    Run the DIE model on a single image.

    The image is resized (size argument 1500), padded to a square, converted to a
    tensor on ``device``, passed through ``die_model.enhance_document_image`` and
    finally un-padded and resized back to the geometry of ``image_raw``.

    :param image_raw: raw image (a PIL image from the ``gr.Image(type="pil")`` input);
        ``None`` when the user submits without uploading anything
    :param num_of_die_iterations: number of DIE iterations; any value accepted by ``int()``.
        ``None`` falls back to 1, the dropdown's default.
    :param die_model: DIE model exposing ``enhance_document_image``
    :param device: device the input tensor is moved to
    :return: cleaned image, as produced by ``remove_square_padding``
    :raises gr.Error: if no image was supplied, so the UI shows a readable message
        instead of a traceback from the preprocessing helpers
    """
    if image_raw is None:
        raise gr.Error("Please upload a document image first.")

    # The UI value may arrive as a string (or be empty); normalise it to an int.
    if num_of_die_iterations is None:
        num_of_die_iterations = _DEFAULT_NUM_OF_DIE_ITERATIONS
    num_of_die_iterations = int(num_of_die_iterations)

    # preprocess
    image_raw_resized = resize_image(image_raw, 1500)
    image_raw_resized_square = make_image_square(image_raw_resized)
    image_raw_resized_square_tensor = cast_pil_image_to_torch_tensor_with_4_channel_dim(image_raw_resized_square)
    image_raw_resized_square_tensor = image_raw_resized_square_tensor.to(device)

    # inference: the model works on a list of images; we send one and take the first result
    image_die = die_model.enhance_document_image(
        image_raw_list=[image_raw_resized_square_tensor],
        num_of_die_iterations=num_of_die_iterations
    )[0]

    # postprocess: strip the square padding and restore the original size
    image_die_resized = remove_square_padding(
        original_image=image_raw,
        square_image=image_die,
        resize_back_to_original=True
    )

    return image_die_resized


def _load_example_images(example_image_path, token):
    """
    Download the demo example images from the private Hub repository.

    :param example_image_path: folder inside the ``gabar92/DIE`` repository
    :param token: Hugging Face access token (may be None if the env var is unset)
    :return: list of single-element rows ``[[PIL.Image], ...]`` for ``gr.Interface(examples=...)``;
        only the image input receives an example value
    """
    example_rows = []
    for image_filename in _EXAMPLE_IMAGE_FILENAMES:
        image_path = hf_hub_download(
            repo_id=_HUB_REPO_ID,
            filename=f"{example_image_path}/{image_filename}",
            use_auth_token=token
        )
        example_rows.append([Image.open(image_path)])
    return example_rows


def main():
    """
    Build and launch the Gradio demo.

    Parses CLI arguments, downloads the example images, loads the DIE model and
    serves the interface on 0.0.0.0:7860.
    :return: None
    """
    args = parse_arguments()

    # TODO: Add a description for the Number of DIE iterations parameter!
    num_of_die_iterations_list = [1, 2, 3]

    die_token = os.getenv("DIE_TOKEN")

    example_image_list = _load_example_images(args.example_image_path, die_token)

    # Resolve the checkpoint inside the local model directory so --die_model_path is honoured;
    # with the default argument this is "./DIE/2024_08_09_model_epoch_89.pt".
    # To fetch it from the Hub instead, use:
    #   hf_hub_download(repo_id=_HUB_REPO_ID, filename=args.die_model_path, use_auth_token=die_token)
    model_path = os.path.join(_LOCAL_MODEL_DIR, args.die_model_path)
    die_model = UNetDIEModel(args=args, model_path=model_path)

    # Bind model and device so Gradio only sees the (image, iterations) inputs.
    partial_die_inference = partial(die_inference, device=args.device, die_model=die_model)

    demo = gr.Interface(
        fn=partial_die_inference,
        inputs=[
            gr.Image(type="pil", label="Degraded Document Image"),
            gr.Dropdown(num_of_die_iterations_list, label="Number of DIE iterations", value=1),
        ],
        outputs=gr.Image(type="pil", label="Clean Document Image"),
        title="Document Image Enhancement (DIE) model",
        description=_DESCRIPTION,
        examples=example_image_list
    )

    demo.launch(server_name="0.0.0.0", server_port=7860)


def parse_arguments():
    """
    Parse command-line arguments.

    :return: argument namespace with ``die_model_path``, ``device`` and ``example_image_path``
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--die_model_path",
        default="2024_08_09_model_epoch_89.pt",
        help="model checkpoint file, resolved relative to ./DIE",
    )
    parser.add_argument(
        "--device",
        default="cpu",
        help="device used for inference (e.g. 'cpu')",
    )
    parser.add_argument(
        "--example_image_path",
        default="example_images",
        help="folder of example images inside the gabar92/DIE Hub repository",
    )

    return parser.parse_args()


if __name__ == "__main__":
    main()