File size: 4,892 Bytes
a9d81c5 698149b a9d81c5 7ea1aa1 04f8aab d07ad97 a9d81c5 8430bc8 04f8aab 8430bc8 a9d81c5 a3e894a a9d81c5 d07ad97 a9d81c5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 |
"""
Small demo application to explore Gradio.
"""
import argparse
import os
from functools import partial
import gradio as gr
from PIL import Image
from huggingface_hub import hf_hub_download
from die_model import UNetDIEModel
from utils import resize_image, make_image_square, cast_pil_image_to_torch_tensor_with_4_channel_dim, \
remove_square_padding
def die_inference(
        image_raw,
        num_of_die_iterations,
        die_model,
        device
):
    """
    Function to run the DIE model on a single document image.

    The image is resized (``resize_image`` with a size argument of 1500), padded to a
    square, converted to a tensor, moved to ``device`` and handed to
    ``die_model.enhance_document_image``. The square padding is removed afterwards and
    the helper is asked to resize the result back to the original image.

    :param image_raw: raw image; may be None when the UI submits without an uploaded image
    :param num_of_die_iterations: number of DIE iterations (anything ``int()`` accepts, e.g. "2")
    :param die_model: DIE model exposing ``enhance_document_image``
    :param device: device the input tensor is moved to
    :return: cleaned image
    :raises ValueError: if ``image_raw`` is None
    """
    # Fail early with a clear message instead of an obscure error inside the
    # preprocessing helpers when no image was provided.
    if image_raw is None:
        raise ValueError("No input image was provided. Please upload a document image first.")

    # The UI may deliver the iteration count as a string; convert it before the
    # (comparatively expensive) preprocessing so that bad values fail fast.
    num_of_die_iterations = int(num_of_die_iterations)

    # preprocess
    image_raw_resized = resize_image(image_raw, 1500)
    image_raw_resized_square = make_image_square(image_raw_resized)
    image_raw_resized_square_tensor = cast_pil_image_to_torch_tensor_with_4_channel_dim(image_raw_resized_square)
    image_raw_resized_square_tensor = image_raw_resized_square_tensor.to(device)

    # inference: the model takes a list of images, so wrap the single image
    # and take the first (only) result
    image_die = die_model.enhance_document_image(
        image_raw_list=[image_raw_resized_square_tensor],
        num_of_die_iterations=num_of_die_iterations
    )[0]

    # postprocess
    image_die_resized = remove_square_padding(
        original_image=image_raw,
        square_image=image_die,
        resize_back_to_original=True
    )

    return image_die_resized
def main():
    """
    Main function to run the Gradio demo.

    Parses the command-line arguments, loads the example images, downloads the DIE model
    checkpoint from the Hugging Face Hub (authenticated with the ``DIE_TOKEN`` environment
    variable), builds the Gradio interface and launches it on 0.0.0.0:7860.

    :return: None
    """
    args = parse_arguments()

    # NOTE(review): the contact address at the end of this text reads "[email protected]",
    # which looks like an e-mail obfuscation placeholder - confirm the real address.
    description = (
        "Welcome to the Document Image Enhancement (DIE) model demo on Hugging Face!\n\n"
        "This interactive application showcases a specialized AI model developed by "
        "the [Artificial Intelligence group](https://ai.renyi.hu) at the [Alfréd Rényi Institute of Mathematics](https://renyi.hu).\n\n"
        "Our DIE model is designed to enhance and restore archival and aged document images "
        "by removing various types of degradation, thereby making historical documents more legible "
        "and suitable for Optical Character Recognition (OCR) processing.\n\n"
        "The model effectively tackles 20-30 types of domain-specific noise found in historical records, "
        "such as scribbles, bleed-through text, faded or worn text, blurriness, textured noise, "
        "and unwanted background elements. "
        "By applying deep learning techniques, specifically a U-Net-based architecture, "
        "the model accurately cleans and clarifies text while preserving original details. "
        "This improved clarity dramatically boosts OCR accuracy, making it an ideal "
        "pre-processing tool in digitization workflows.\n\n"
        "If you’re interested in learning more about the model’s capabilities or potential applications, "
        "please contact us at: [email protected].\n\n"
    )

    # TODO: Add a description for the Number of DIE iterations parameter!
    num_of_die_iterations_list = [1, 2, 3]

    # Access token for the model repository, read from the environment
    die_token = os.getenv("DIE_TOKEN")

    # Provide images alone for example display.
    # os.listdir returns entries in arbitrary order, so sort them to keep the
    # example gallery order stable between runs and platforms.
    example_image_list = [
        [Image.open(os.path.join(args.example_image_path, image_name))]
        for image_name in sorted(os.listdir(args.example_image_path))
    ]

    # Load DIE model: replace the checkpoint filename in args with the path
    # returned by the Hub download, because the model is constructed from args.
    args.die_model_path = hf_hub_download(
        repo_id="gabar92/die",
        filename=args.die_model_path,
        token=die_token  # `token` is the current name of the legacy `use_auth_token` argument
    )
    die_model = UNetDIEModel(args=args)

    # Partially apply the model and device arguments to die_inference, so the
    # interface only has to supply the image and the iteration count
    partial_die_inference = partial(die_inference, device=args.device, die_model=die_model)

    demo = gr.Interface(
        fn=partial_die_inference,
        inputs=[
            gr.Image(type="pil", label="Degraded Document Image"),
            gr.Dropdown(num_of_die_iterations_list, label="Number of DIE iterations", value=1),
        ],
        outputs=gr.Image(type="pil", label="Clean Document Image"),
        title="Document Image Enhancement (DIE) model",
        description=description,
        examples=example_image_list
    )

    demo.launch(server_name="0.0.0.0", server_port=7860)
def parse_arguments(argv=None):
    """
    Parse command-line arguments.

    :param argv: optional list of argument strings to parse; when None (the default),
        argparse reads the arguments from ``sys.argv``
    :return: argument namespace with ``die_model_path``, ``device`` and ``example_image_path``
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--die_model_path",
        default="2024_08_09_model_epoch_89.pt",
        help="filename of the DIE model checkpoint to download from the model repository"
    )
    parser.add_argument(
        "--device",
        default="cpu",
        help="device the input tensor is moved to before inference"
    )
    parser.add_argument(
        "--example_image_path",
        default="example_images",
        help="directory containing the example images shown in the demo"
    )

    return parser.parse_args(argv)
# Run the demo only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
|