die_demo / app.py
gabar92's picture
update model path parameter in argument parser
8430bc8
"""
Small Gradio demo application showcasing the Document Image Enhancement (DIE) model.
"""
import argparse
import os
from functools import partial
import gradio as gr
from PIL import Image
from huggingface_hub import hf_hub_download
from die_model import UNetDIEModel
from utils import resize_image, make_image_square, cast_pil_image_to_torch_tensor_with_4_channel_dim, \
remove_square_padding
def die_inference(
        image_raw,
        num_of_die_iterations,
        die_model,
        device,
        resize_size=1500
):
    """
    Function to run the DIE model on a single document image.

    The image is resized, made square, converted to a tensor and moved to `device`;
    the model is run for `num_of_die_iterations` iterations, and the square padding is
    then removed from the result, which is resized back to the original image.

    :param image_raw: raw (degraded) image -- a PIL image when called from the Gradio UI,
        or None when the form is submitted without an image
    :param num_of_die_iterations: number of DIE iterations (int, or anything int() accepts)
    :param die_model: DIE model exposing `enhance_document_image`
    :param device: device the input tensor is moved to (e.g. "cpu")
    :param resize_size: size passed to `resize_image` before the image is made square;
        defaults to 1500, the value this function always used
    :return: cleaned image, or None if no input image was given
    """
    # Nothing to enhance: leave the output empty instead of failing deep inside resize_image.
    if image_raw is None:
        return None

    # preprocess: resize, make square (presumably by padding -- see remove_square_padding
    # below), convert to a tensor and move it to the target device
    image_raw_resized = resize_image(image_raw, resize_size)
    image_raw_resized_square = make_image_square(image_raw_resized)
    image_raw_resized_square_tensor = cast_pil_image_to_torch_tensor_with_4_channel_dim(image_raw_resized_square)
    image_raw_resized_square_tensor = image_raw_resized_square_tensor.to(device)

    # normalise the dropdown value to an int (int() also accepts numeric strings)
    num_of_die_iterations = int(num_of_die_iterations)

    # inference: the model takes a list of images; only one is sent, so take the first result
    image_die = die_model.enhance_document_image(
        image_raw_list=[image_raw_resized_square_tensor],
        num_of_die_iterations=num_of_die_iterations
    )[0]

    # postprocess: strip the square padding and resize back to the original image's size
    image_die_resized = remove_square_padding(
        original_image=image_raw,
        square_image=image_die,
        resize_back_to_original=True
    )
    return image_die_resized
_DESCRIPTION = (
    "Welcome to the Document Image Enhancement (DIE) model demo on Hugging Face!\n\n"
    "This interactive application showcases a specialized AI model developed by "
    "the [Artificial Intelligence group](https://ai.renyi.hu) at the [Alfréd Rényi Institute of Mathematics](https://renyi.hu).\n\n"
    "Our DIE model is designed to enhance and restore archival and aged document images "
    "by removing various types of degradation, thereby making historical documents more legible "
    "and suitable for Optical Character Recognition (OCR) processing.\n\n"
    "The model effectively tackles 20-30 types of domain-specific noise found in historical records, "
    "such as scribbles, bleed-through text, faded or worn text, blurriness, textured noise, "
    "and unwanted background elements. "
    "By applying deep learning techniques, specifically a U-Net-based architecture, "
    "the model accurately cleans and clarifies text while preserving original details. "
    "This improved clarity dramatically boosts OCR accuracy, making it an ideal "
    "pre-processing tool in digitization workflows.\n\n"
    "If you’re interested in learning more about the model’s capabilities or potential applications, "
    "please contact us at: [email protected].\n\n"
)


def _list_example_images(example_image_dir):
    """
    Collect the example inputs shown under the Gradio interface.

    Returns one single-element row ([path]) per image file in `example_image_dir`,
    sorted by filename so the order is stable (os.listdir order is arbitrary).
    Sub-directories and files whose extension Pillow does not recognise (e.g. hidden
    files such as .DS_Store) are skipped. Paths are returned instead of opened images:
    Gradio loads examples on demand, and no lazily-opened file handles are kept alive.
    :param example_image_dir: directory containing the example images
    :return: list of [path] rows (possibly empty)
    """
    supported_extensions = set(Image.registered_extensions())
    example_rows = []
    for file_name in sorted(os.listdir(example_image_dir)):
        file_path = os.path.join(example_image_dir, file_name)
        extension = os.path.splitext(file_name)[1].lower()
        if os.path.isfile(file_path) and extension in supported_extensions:
            example_rows.append([file_path])
    return example_rows


def main():
    """
    Main function to run the Gradio demo.

    Downloads the DIE checkpoint from the "gabar92/die" Hugging Face Hub repository
    (passing the DIE_TOKEN environment variable, or None if unset, as the token), builds
    the model and serves the Gradio interface on 0.0.0.0:7860.
    :return: None
    """
    args = parse_arguments()

    # TODO: Add a description for the Number of DIE iterations parameter!
    num_of_die_iterations_list = [1, 2, 3]

    # Provide images alone (no iteration count) for example display
    example_image_list = _list_example_images(args.example_image_path)

    # Load DIE model. die_model_path is a filename inside the Hub repository; it is
    # overwritten with the local path of the downloaded file because UNetDIEModel receives
    # the whole argument namespace (presumably it loads the checkpoint from
    # args.die_model_path -- verify in die_model).
    args.die_model_path = hf_hub_download(
        repo_id="gabar92/die",
        filename=args.die_model_path,
        token=os.getenv("DIE_TOKEN"),
    )
    die_model = UNetDIEModel(args=args)

    # Partially apply the model and device arguments so Gradio only supplies the two inputs
    partial_die_inference = partial(die_inference, device=args.device, die_model=die_model)

    demo = gr.Interface(
        fn=partial_die_inference,
        inputs=[
            gr.Image(type="pil", label="Degraded Document Image"),
            gr.Dropdown(num_of_die_iterations_list, label="Number of DIE iterations", value=1),
        ],
        outputs=gr.Image(type="pil", label="Clean Document Image"),
        title="Document Image Enhancement (DIE) model",
        description=_DESCRIPTION,
        # an empty example directory means "no examples" rather than an empty gallery
        examples=example_image_list or None,
    )
    demo.launch(server_name="0.0.0.0", server_port=7860)
def parse_arguments(argv=None):
    """
    Parse the command-line arguments of the demo.
    :param argv: list of argument strings to parse; None (the default) parses sys.argv[1:]
    :return: argument namespace with die_model_path, device and example_image_path
    """
    parser = argparse.ArgumentParser(
        description="Gradio demo for the Document Image Enhancement (DIE) model."
    )
    parser.add_argument(
        "--die_model_path",
        default="2024_08_09_model_epoch_89.pt",
        help="Checkpoint filename inside the gabar92/die Hugging Face Hub repository "
             "(not a local path; the file is downloaded at start-up).",
    )
    parser.add_argument(
        "--device",
        default="cpu",
        help="Device passed to the model and used for input tensors (e.g. cpu, cuda).",
    )
    parser.add_argument(
        "--example_image_path",
        default="example_images",
        help="Directory containing the example images shown in the demo.",
    )
    return parser.parse_args(argv)
# Start the demo only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()