File size: 4,892 Bytes
a9d81c5
 
 
 
 
 
 
 
 
 
698149b
a9d81c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7ea1aa1
 
04f8aab
 
 
 
d07ad97
 
a9d81c5
8430bc8
04f8aab
 
 
 
8430bc8
 
a9d81c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a3e894a
a9d81c5
 
d07ad97
a9d81c5
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
"""
Small demo application to explore Gradio.
"""

import argparse
import os
from functools import partial

import gradio as gr
from PIL import Image
from huggingface_hub import hf_hub_download

from die_model import UNetDIEModel
from utils import resize_image, make_image_square, cast_pil_image_to_torch_tensor_with_4_channel_dim, \
    remove_square_padding


def die_inference(
    image_raw,
    num_of_die_iterations,
    die_model,
    device
):
    """
    Run the DIE model on a single document image.

    Pipeline: resize the input with ``resize_image(image_raw, 1500)``, pad it to a
    square, convert it to a tensor on ``device``, run
    ``die_model.enhance_document_image`` and finally strip the square padding,
    resizing the result back to the original image (``resize_back_to_original=True``).

    :param image_raw: raw input image as delivered by ``gr.Image(type="pil")``;
        ``None`` when the user submits without uploading an image
    :param num_of_die_iterations: number of DIE iterations; cast with ``int()`` so
        that string values coming from the UI are accepted as well
    :param die_model: DIE model exposing ``enhance_document_image``
    :param device: device the input tensor is moved to (e.g. "cpu")
    :return: cleaned image
    :raises gr.Error: if no input image was provided, so that Gradio shows a
        readable message to the user instead of a generic error
    """
    # Gradio passes None when the image input is empty; fail early with a clear
    # message rather than an obscure exception deep inside the preprocessing.
    if image_raw is None:
        raise gr.Error("Please upload a document image first.")

    # preprocess
    # NOTE(review): 1500 presumably caps the image size in pixels - confirm
    # against utils.resize_image.
    image_raw_resized = resize_image(image_raw, 1500)
    image_raw_resized_square = make_image_square(image_raw_resized)
    image_raw_resized_square_tensor = cast_pil_image_to_torch_tensor_with_4_channel_dim(image_raw_resized_square)
    image_raw_resized_square_tensor = image_raw_resized_square_tensor.to(device)

    # the UI value may arrive as a string; the model expects an int
    num_of_die_iterations = int(num_of_die_iterations)

    # inference: the model API is batched (list in, list out); we send a single image
    image_die = die_model.enhance_document_image(
        image_raw_list=[image_raw_resized_square_tensor],
        num_of_die_iterations=num_of_die_iterations
    )[0]

    # postprocess: remove the square padding and restore the original size
    image_die_resized = remove_square_padding(
        original_image=image_raw,
        square_image=image_die,
        resize_back_to_original=True
    )

    return image_die_resized


def _load_example_images(example_image_path):
    """
    Load the example images shown below the Gradio interface.

    Entries are sorted by name so the example order is deterministic
    (``os.listdir`` returns them in arbitrary order). Sub-directories and files
    Pillow cannot identify are skipped, so that a stray README or ``.gitkeep``
    does not crash the demo at start-up.

    :param example_image_path: directory containing the example images
    :return: list of one-element lists (one per example); only the image input
        receives example values
    """
    example_image_list = []
    for file_name in sorted(os.listdir(example_image_path)):
        file_path = os.path.join(example_image_path, file_name)
        if not os.path.isfile(file_path):
            continue
        try:
            example_image_list.append([Image.open(file_path)])
        except OSError:
            # Pillow raises UnidentifiedImageError (an OSError subclass) for
            # files that are not images; skip them with a warning.
            print(f"Skipping non-image example file: {file_path}")
    return example_image_list


def main():
    """
    Build and launch the Gradio demo for the DIE model.

    Parses the CLI arguments, loads the example images, downloads the model
    checkpoint from the ``gabar92/die`` Hugging Face repository, builds the
    ``UNetDIEModel`` and serves a ``gr.Interface`` on 0.0.0.0:7860.
    :return: None
    """

    args = parse_arguments()

    # NOTE(review): "[email protected]" looks like a web-scraper / email-obfuscation
    # placeholder rather than a real address - confirm the intended contact email.
    description = (
        "Welcome to the Document Image Enhancement (DIE) model demo on Hugging Face!\n\n"
        "This interactive application showcases a specialized AI model developed by "
        "the [Artificial Intelligence group](https://ai.renyi.hu) at the [Alfréd Rényi Institute of Mathematics](https://renyi.hu).\n\n"
        "Our DIE model is designed to enhance and restore archival and aged document images "
        "by removing various types of degradation, thereby making historical documents more legible "
        "and suitable for Optical Character Recognition (OCR) processing.\n\n"
        "The model effectively tackles 20-30 types of domain-specific noise found in historical records, "
        "such as scribbles, bleed-through text, faded or worn text, blurriness, textured noise, "
        "and unwanted background elements. "
        "By applying deep learning techniques, specifically a U-Net-based architecture, "
        "the model accurately cleans and clarifies text while preserving original details. "
        "This improved clarity dramatically boosts OCR accuracy, making it an ideal "
        "pre-processing tool in digitization workflows.\n\n"
        "If you’re interested in learning more about the model’s capabilities or potential applications, "
        "please contact us at: [email protected].\n\n"
    )

    # TODO: Add a description for the Number of DIE iterations parameter!

    num_of_die_iterations_list = [1, 2, 3]

    # Optional access token for the model repository (None -> huggingface_hub default auth)
    die_token = os.getenv("DIE_TOKEN")

    # Provide images alone for example display
    example_image_list = _load_example_images(args.example_image_path)

    # Load DIE model: download the checkpoint and replace the repo filename in
    # args with the local path, presumably because UNetDIEModel reads it from args.
    args.die_model_path = hf_hub_download(
        repo_id="gabar92/die",
        filename=args.die_model_path,
        token=die_token  # `use_auth_token` is deprecated in huggingface_hub
    )

    die_model = UNetDIEModel(args=args)

    # Bind model and device so Gradio only supplies (image, num_of_die_iterations)
    partial_die_inference = partial(die_inference, device=args.device, die_model=die_model)

    demo = gr.Interface(
        fn=partial_die_inference,
        inputs=[
            gr.Image(type="pil", label="Degraded Document Image"),
            gr.Dropdown(num_of_die_iterations_list, label="Number of DIE iterations", value=1),
        ],
        outputs=gr.Image(type="pil", label="Clean Document Image"),
        title="Document Image Enhancement (DIE) model",
        description=description,
        examples=example_image_list
    )

    # Listen on all interfaces so the demo is reachable from outside a container
    demo.launch(server_name="0.0.0.0", server_port=7860)


def parse_arguments(argv=None):
    """
    Parse command-line arguments.

    :param argv: list of argument strings to parse; ``None`` (the default) parses
        ``sys.argv[1:]`` exactly as before. An explicit list makes the function
        usable from tests and other scripts.
    :return: argument namespace with ``die_model_path``, ``device`` and
        ``example_image_path``
    """

    parser = argparse.ArgumentParser(
        description="Gradio demo for the Document Image Enhancement (DIE) model."
    )

    parser.add_argument(
        "--die_model_path",
        default="2024_08_09_model_epoch_89.pt",
        help="Checkpoint filename inside the 'gabar92/die' Hugging Face repository; "
             "downloaded at start-up.",
    )
    parser.add_argument(
        "--device",
        default="cpu",
        help="Device the model input is moved to (e.g. 'cpu').",
    )

    parser.add_argument(
        "--example_image_path",
        default="example_images",
        help="Directory whose images are shown as clickable examples in the demo.",
    )

    return parser.parse_args(argv)


if __name__ == "__main__":
    # Launch the demo only when this file is executed directly, not on import.
    main()