import argparse
import os
from functools import partial

import torch
import gradio as gr
from PIL import Image
from huggingface_hub import hf_hub_download

from die_model import UNetDIEModel
from utils import resize_image, make_image_square, cast_pil_image_to_torch_tensor_with_4_channel_dim, remove_square_padding


# Markdown shown at the top of the demo page.
DESCRIPTION = (
    "Welcome to the Document Image Enhancement (DIE) model demo on Hugging Face!\n\n"
    "This interactive application showcases a specialized AI model developed by "
    "the [Artificial Intelligence group](https://ai.renyi.hu) at the [Alfréd Rényi Institute of Mathematics](https://renyi.hu).\n\n"
    "Our DIE model is designed to enhance and restore archival and aged document images "
    "by removing various types of degradation, thereby making historical documents more legible "
    "and suitable for Optical Character Recognition (OCR) processing.\n\n"
    "The model effectively tackles 20-30 types of domain-specific noise found in historical records, "
    "such as scribbles, bleed-through text, faded or worn text, blurriness, textured noise, "
    "and unwanted background elements. "
    "By applying deep learning techniques, specifically a U-Net-based architecture, "
    "the model accurately cleans and clarifies text while preserving original details. "
    "This improved clarity dramatically boosts OCR accuracy, making it an ideal "
    "pre-processing tool in digitization workflows.\n\n"
    "If you’re interested in learning more about the model’s capabilities or potential applications, "
    "please contact us at: gabar92@renyi.hu.\n"
)


def die_inference(image_raw, num_of_die_iterations, die_model, device):
    """
    Applies the DIE model for document enhancement on a provided image.

    :param image_raw: input PIL image from the Gradio ``Image`` component
        (``None`` when the user clicked the button without uploading anything)
    :param num_of_die_iterations: number of enhancement passes; cast with ``int()``
        because the dropdown value may arrive as a string
    :param die_model: model object exposing ``enhance_document_image``
    :param device: torch device string the input tensor is moved to
    :return: the enhanced image with the square padding removed and resized back
        to the original image's size (as returned by ``remove_square_padding``)
    :raises gr.Error: if no image was provided
    """
    if image_raw is None:
        # Without this guard the preprocessing helpers fail with an opaque error;
        # gr.Error surfaces a readable message in the UI instead.
        raise gr.Error("Please upload a document image first.")

    # preprocess: resize (target size 1500, see utils.resize_image for exact semantics),
    # pad to a square, then convert to a tensor on the target device
    image_raw_resized = resize_image(image_raw, 1500)
    image_raw_resized_square = make_image_square(image_raw_resized)
    image_raw_resized_square_tensor = cast_pil_image_to_torch_tensor_with_4_channel_dim(
        image_raw_resized_square
    ).to(device)

    # convert string to int
    num_of_die_iterations = int(num_of_die_iterations)

    # inference: the model takes a list of images and returns a list; we pass exactly one
    image_die = die_model.enhance_document_image(
        image_raw_list=[image_raw_resized_square_tensor],
        num_of_die_iterations=num_of_die_iterations,
    )[0]

    # postprocess: strip the square padding and restore the original size
    return remove_square_padding(
        original_image=image_raw,
        square_image=image_die,
        resize_back_to_original=True,
    )


def _list_example_images(example_image_path):
    """
    Collect example inputs for ``gr.Examples``.

    Returns one single-element row per regular, non-hidden file in
    *example_image_path*, sorted by file name so the gallery order is stable.
    File paths are returned instead of opened PIL images so that no file handles
    are left open; the ``pil``-typed input component loads the file when an
    example is selected.

    :param example_image_path: directory containing the example images
    :return: list of ``[path]`` rows
    """
    return [
        [os.path.join(example_image_path, file_name)]
        for file_name in sorted(os.listdir(example_image_path))
        # skip dot-files such as .DS_Store/.gitkeep and any subdirectories
        if not file_name.startswith(".")
        and os.path.isfile(os.path.join(example_image_path, file_name))
    ]


def _build_demo(inference_fn, example_image_list):
    """
    Build the Gradio Blocks UI.

    :param inference_fn: callable taking ``(image, num_of_die_iterations)`` and
        returning the enhanced image
    :param example_image_list: rows for ``gr.Examples`` (see ``_list_example_images``)
    :return: the (not yet launched) ``gr.Blocks`` demo
    """
    # resize() returns a new image, so the source file can be closed right away.
    with Image.open("logo/qr-code.png") as qr_code:
        qr_code_image = qr_code.resize((400, 400))

    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                gr.Markdown("## Document Image Enhancement (DIE) model")

        with gr.Row():
            with gr.Column():
                gr.Markdown(DESCRIPTION)
            with gr.Column():
                # Display QR code as an image in Gradio
                gr.Image(value=qr_code_image, label="QR Code")

        with gr.Row():
            with gr.Column():
                input_image = gr.Image(type="pil", label="Upload Degraded Document Image")
                num_iterations = gr.Dropdown([1, 2, 3], label="Number of DIE Iterations", value=1)
                run_button = gr.Button("Enhance Image")
            with gr.Column():
                output_image = gr.Image(type="pil", label="Enhanced Document Image")

        # Display example images
        gr.Examples(
            examples=example_image_list,
            inputs=[input_image],
            label="Example Images - Source: National Archives of Hungary and Budapest City Archives",
        )

        # Button trigger for inference
        run_button.click(inference_fn, [input_image, num_iterations], output_image)

    return demo


def main():
    """
    Main function to set up and run the Gradio demo.

    Downloads the model checkpoint from the Hugging Face Hub, builds the model,
    assembles the UI and launches it.
    """
    args = parse_arguments()
    args.device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Set up model. DIE_TOKEN (if set) is forwarded to the Hub for authenticated
    # access; `token` replaces the deprecated `use_auth_token` keyword.
    die_token = os.getenv("DIE_TOKEN")
    args.die_model_path = hf_hub_download(
        repo_id="gabar92/die",
        filename=args.die_model_path,
        token=die_token,
    )
    die_model = UNetDIEModel(args=args)

    # Prepare example images
    example_image_list = _list_example_images(args.example_image_path)

    # Partial function for inference with model and device arguments
    partial_die_inference = partial(die_inference, die_model=die_model, device=args.device)

    demo = _build_demo(partial_die_inference, example_image_list)
    demo.launch()


def parse_arguments():
    """
    Parses command-line arguments.

    :return: argument namespace with ``die_model_path`` (checkpoint filename inside
        the Hub repo) and ``example_image_path`` (directory of example images)
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--die_model_path", default="2024_08_09_model_epoch_89.pt")
    parser.add_argument("--example_image_path", default="example_images")
    return parser.parse_args()


if __name__ == "__main__":
    main()