gabar92 committed
Commit a9d81c5 · 1 Parent(s): bc329bc

add implementation scripts

Files changed (4)
  1. app.py +135 -0
  2. die_model.py +264 -0
  3. requirements.txt +4 -0
  4. utils.py +139 -0
app.py ADDED
@@ -0,0 +1,135 @@
+ """
+ Small demo application to explore Gradio.
+ """
+
+ import argparse
+ import os
+ from functools import partial
+
+ import gradio as gr
+ from PIL import Image
+
+ from die_model import UNetDIEModel
+ from utils import resize_image, make_image_square, cast_pil_image_to_torch_tensor_with_4_channel_dim, \
+     remove_square_padding
+
+
+ def die_inference(
+         image_raw,
+         num_of_die_iterations,
+         die_model,
+         device
+ ):
+     """
+     Function to run the DIE model.
+     :param image_raw: raw image
+     :param num_of_die_iterations: number of DIE iterations
+     :param die_model: DIE model
+     :param device: device
+     :return: cleaned image
+     """
+
+     # preprocess
+     image_raw_resized = resize_image(image_raw, 1500)
+     image_raw_resized_square = make_image_square(image_raw_resized)
+     image_raw_resized_square_tensor = cast_pil_image_to_torch_tensor_with_4_channel_dim(image_raw_resized_square)
+     image_raw_resized_square_tensor = image_raw_resized_square_tensor.to(device)
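+     # the cast produces a 4-channel tensor (RGB plus a grayscale copy), matching the
+     # n_channels=4 input that load_unet configures in die_model.py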
+
+     # convert string to int
+     num_of_die_iterations = int(num_of_die_iterations)
+
+     # inference
+     image_die = die_model.enhance_document_image(
+         image_raw_list=[image_raw_resized_square_tensor],
+         num_of_die_iterations=num_of_die_iterations
+     )[0]
+
+     # postprocess
+     image_die_resized = remove_square_padding(
+         original_image=image_raw,
+         square_image=image_die,
+         resize_back_to_original=True
+     )
+
+     return image_die_resized
+
+
+ def main():
+     """
+     Main function to run the Gradio demo.
+     :return:
+     """
+
+     args = parse_arguments()
+
+     description = "Welcome to the Document Image Enhancement (DIE) model demo on Hugging Face!\n\n" \
+         "" \
+         "This interactive application showcases a specialized AI model developed by " \
+         "the [Artificial Intelligence group](https://ai.renyi.hu) at the [Alfréd Rényi Institute of Mathematics](https://renyi.hu).\n\n" \
+         "" \
+         "Our DIE model is designed to enhance and restore archival and aged document images " \
+         "by removing various types of degradation, thereby making historical documents more legible " \
+         "and suitable for Optical Character Recognition (OCR) processing.\n\n" \
+         "" \
+         "The model effectively tackles 20-30 types of domain-specific noise found in historical records, " \
+         "such as scribbles, bleed-through text, faded or worn text, blurriness, textured noise, " \
+         "and unwanted background elements. " \
+         "By applying deep learning techniques, specifically a U-Net-based architecture, " \
+         "the model accurately cleans and clarifies text while preserving original details. " \
+         "This improved clarity dramatically boosts OCR accuracy, making it an ideal " \
+         "pre-processing tool in digitization workflows.\n\n" \
+         "" \
+         "If you’re interested in learning more about the model’s capabilities or potential applications, " \
+         "please contact us at: [email protected].\n\n"
+
+     # TODO: Add a description for the Number of DIE iterations parameter!
+
+     num_of_die_iterations_list = [1, 2, 3]
+
+     # Provide images alone for example display
+     example_image_list = [
+         [Image.open(os.path.join(args.example_image_path, image_path))]
+         for image_path in os.listdir(args.example_image_path)
+     ]
+
+     # Load DIE model
+     die_model = UNetDIEModel(args=args)
+
+     # Partially apply the model and device arguments to die_inference
+     partial_die_inference = partial(die_inference, device=args.device, die_model=die_model)
+
+     demo = gr.Interface(
+         fn=partial_die_inference,
+         inputs=[
+             gr.Image(type="pil", label="Degraded Document Image"),
+             gr.Dropdown(num_of_die_iterations_list, label="Number of DIE iterations", value=1),
+         ],
+         outputs=gr.Image(type="pil", label="Clean Document Image"),
+         title="Document Image Enhancement (DIE) model",
+         description=description,
+         examples=example_image_list
+     )
+
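+     # bind to all interfaces on port 7860, the default port Hugging Face Spaces expects for Gradio apps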
+     demo.launch(server_name="0.0.0.0", server_port=7860)
+
+
+ def parse_arguments():
+     """
+     Parse arguments.
+     :return: argument namespace
+     """
+
+     parser = argparse.ArgumentParser()
+
+     parser.add_argument("--die_model_path", default="./2024_08_09_model_epoch_89.pt")
+     parser.add_argument("--device", default="cpu")
+
+     parser.add_argument("--example_image_path", default="./example_images")
+
+     return parser.parse_args()
+
+
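+ # usage sketch (flags and defaults match parse_arguments above):
+ #   python app.py --die_model_path ./2024_08_09_model_epoch_89.pt --example_image_path ./example_images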
+ if __name__ == "__main__":
+
+     main()
die_model.py ADDED
@@ -0,0 +1,264 @@
+ """
+ U-Net-based DIE model for cleaning document images.
+ """
+
+ import os
+ from typing import Callable
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torchvision.transforms as T
+ from PIL import Image
+
+
+ class DoubleConv(nn.Module):
+     """(convolution => [BN] => ReLU) * 2"""
+
+     def __init__(self, in_channels, out_channels, mid_channels=None):
+         super().__init__()
+         if not mid_channels:
+             mid_channels = out_channels
+         self.double_conv = nn.Sequential(
+             nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
+             nn.BatchNorm2d(mid_channels),
+             nn.ReLU(inplace=True),
+             nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
+             nn.BatchNorm2d(out_channels),
+             nn.ReLU(inplace=True)
+         )
+
+     def forward(self, x):
+         return self.double_conv(x)
+
+
+ class Down(nn.Module):
+     """Downscaling with maxpool then double conv"""
+
+     def __init__(self, in_channels, out_channels):
+         super().__init__()
+         self.maxpool_conv = nn.Sequential(
+             nn.MaxPool2d(2),
+             DoubleConv(in_channels, out_channels)
+         )
+
+     def forward(self, x):
+         return self.maxpool_conv(x)
+
+
+ class Up(nn.Module):
+     """Upscaling then double conv"""
+
+     def __init__(self, in_channels, out_channels, bilinear=True):
+         super().__init__()
+
+         # if bilinear, use the normal convolutions to reduce the number of channels
+         if bilinear:
+             self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
+             self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
+         else:
+             self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
+             self.conv = DoubleConv(in_channels, out_channels)
+
+     def forward(self, x1, x2):
+         x1 = self.up(x1)
+         # input is CHW
+         diffY = x2.size()[2] - x1.size()[2]
+         diffX = x2.size()[3] - x1.size()[3]
+
+         x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
+                         diffY // 2, diffY - diffY // 2])
+         # if you have padding issues, see
+         # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
+         # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
+         x = torch.cat([x2, x1], dim=1)
+         return self.conv(x)
+
+
+ class OutConv(nn.Module):
+     def __init__(self, in_channels, out_channels):
+         super(OutConv, self).__init__()
+         self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
+
+     def forward(self, x):
+         x = self.conv(x)
+         x = torch.sigmoid(x)
+         return x
+
+
+ class UNet(nn.Module):
+     def __init__(self, n_channels, output_channel_dim=1, bilinear=False):
+         super(UNet, self).__init__()
+         self.n_channels = n_channels
+         self.n_classes = output_channel_dim
+         self.bilinear = bilinear
+
+         self.inc = DoubleConv(n_channels, 64)
+         self.down1 = Down(64, 128)
+         self.down2 = Down(128, 256)
+         self.down3 = Down(256, 512)
+         factor = 2 if bilinear else 1
+         self.down4 = Down(512, 1024 // factor)
+         self.up1 = Up(1024, 512 // factor, bilinear)
+         self.up2 = Up(512, 256 // factor, bilinear)
+         self.up3 = Up(256, 128 // factor, bilinear)
+         self.up4 = Up(128, 64, bilinear)
+         self.outc = OutConv(64, output_channel_dim)
+
+     def forward(self, x):
+         x1 = self.inc(x)
+         x2 = self.down1(x1)
+         x3 = self.down2(x2)
+         x4 = self.down3(x3)
+         x5 = self.down4(x4)
+         x = self.up1(x5, x4)
+         x = self.up2(x, x3)
+         x = self.up3(x, x2)
+         x = self.up4(x, x1)
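+         # note: OutConv ends in a sigmoid, so this is a [0, 1] probability map rather than raw logits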
+         logits = self.outc(x)
+         return logits
+
+
+ def add_gaussian_noise(
+         data: torch.Tensor
+ ) -> torch.Tensor:
+     """
+     Add Gaussian noise to a torch tensor.
+     :param data: torch tensor
+     :return: noise-perturbed tensor
+     """
+
+     data_with_noise = data.clone()
+     data_with_noise += torch.normal(mean=0, std=0.05, size=data_with_noise.shape).to(data_with_noise.device)
+     data_with_noise = data_with_noise.clip(min=0, max=1)
+
+     return data_with_noise
+
+
+ def inference_model(
+         model: Callable,
+         model_input: torch.Tensor,
+         device: str | torch.device,
+         num_of_iterations: int = 1
+ ) -> list[torch.Tensor]:
+     """
+     Perform model inference.
+     :param model: image pre-processing model
+     :param model_input: data to feed the model
+     :param device: target device
+     :param num_of_iterations: how many times the network is applied recursively
+     :return: predictions
+     """
+
+     # inference model
+     with torch.no_grad():
+
+         prediction_list = []
+
+         model_input = model_input.to(device)
+
+         if len(model_input.shape) == 3:
+             model_input = model_input.unsqueeze(dim=0)
+
+         model_input_original_part = model_input[:, 0:3, ...]
+
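+         # recursive refinement: the first 3 (RGB) channels stay fixed while the previous
+         # prediction replaces the 4th channel; fresh Gaussian noise is added before each pass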
+         for i in range(num_of_iterations):
+
+             if i == 0:
+                 model_input = add_gaussian_noise(model_input)
+                 prediction = model(model_input)
+                 prediction_list.append(prediction)
+                 model_input_new = torch.cat((model_input_original_part, prediction.detach()), dim=1)
+             else:
+                 model_input_perturbed = add_gaussian_noise(model_input_new)
+                 prediction = model(model_input_perturbed)
+                 prediction_list.append(prediction)
+                 model_input_new = torch.cat((model_input_original_part, prediction.detach()), dim=1)
+
+     return prediction_list
+
+
+ def load_unet(
+         model_path: str,
+         device: str = 'cpu',
+         eval_mode: bool = False,
+         n_channels: int = 4,
+         bilinear: bool = False,
+         output_channel_dim: int = 1
+ ):
+
+     print("Loading UNet model...")
+
+     # image preprocessing model
+     model = UNet(
+         n_channels=n_channels,
+         bilinear=bilinear,
+         output_channel_dim=output_channel_dim
+     )
+
+     # strip the 'module.' prefix left over from distributed data parallel training
+     state_dict = torch.load(model_path, map_location=device)
+     new_state_dict = {key.replace('module.', ''): value for key, value in state_dict.items()}
+     model.load_state_dict(new_state_dict)
+     model.to(device)
+
+     if eval_mode:
+         model.eval()
+
+     return model
+
+
+ class UNetDIEModel:
+     """
+     Class for Document Image Enhancement with U-Net.
+     """
+
+     def __init__(
+             self,
+             *args,
+             **kwargs
+     ):
+         """
+         Initialization.
+         """
+
+         self.args = kwargs['args']
+
+         # loading the DIE model
+         self.die = load_unet(
+             model_path=self.args.die_model_path,
+             device=self.args.device,
+             eval_mode=True,
+         )
+
+     def enhance_document_image(
+             self,
+             image_raw_list: list[torch.Tensor],
+             num_of_die_iterations: int = 1,
+     ) -> list[Image.Image]:
+         """
+         Enhance document images by removing noise.
+         :param image_raw_list: preprocessed 4-channel page tensors to process
+         :param num_of_die_iterations: number of DIE iterations
+         :return: cleaned document pages as PIL images
+         """
+
+         with torch.no_grad():
+
+             image_die = torch.stack(image_raw_list, dim=0)
+
+             # document image enhancement
+             prediction_list = inference_model(
+                 model=self.die,
+                 model_input=image_die,
+                 num_of_iterations=num_of_die_iterations,
+                 device=self.args.device
+             )
+
+             # transform the DIE model output (last iteration only) to images and apply post-processing
+             last_prediction = prediction_list[-1]
+             batch_size = last_prediction.size(0)
+             image_die_list = [T.ToPILImage()(last_prediction[idx, ...]).convert('RGB') for idx in range(batch_size)]
+
+         return image_die_list
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ gradio
+ pillow
+ torch
+ torchvision
utils.py ADDED
@@ -0,0 +1,139 @@
+ """
+ Utility functions for the DIE demo.
+ """
+
+
+ import torch
+ from PIL import Image
+ from torch import Tensor
+ from torchvision import transforms
+
+
+ def resize_image(
+         image: Image.Image,
+         max_size: int = 1024
+ ) -> Image.Image:
+     """
+     Resize an image while keeping its aspect ratio.
+     :param image: PIL image
+     :param max_size: size of the new image's larger side
+     :return: the resized PIL image
+     """
+
+     # extracting size
+     width, height = image.size
+
+     # checking which side is larger
+     height_larger = height >= width
+
+     # reshaping based on the larger side
+     if height_larger:
+         height_new = max_size
+         width_new = round((height_new / height) * width)
+     else:
+         width_new = max_size
+         height_new = round((width_new / width) * height)
+
+     return image.resize((width_new, height_new))
+
+
+ def make_image_square(
+         image: Image.Image,
+         image_size: int = 1024
+ ) -> Image.Image:
+     """
+     Make the input image square by padding it with white.
+     :param image: PIL image
+     :param image_size: defines the size of the square image
+     :return: the square-sized PIL image
+     """
+
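+     # if the input is already larger than the requested square, grow the square instead of shrinking the image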
+     if max(image.size) > image_size:
+         image_size = max(image.size)
+     # creating a new square image
+     if image.mode == 'L':
+         image_square = Image.new(image.mode, (image_size, image_size), (255,))
+     elif image.mode == 'RGB':
+         image_square = Image.new(image.mode, (image_size, image_size), (255, 255, 255))
+     else:
+         raise NotImplementedError("Not implemented image mode.")
+     # copying the original content onto the blank image
+     image_square.paste(image, (0, 0))
+
+     return image_square
+
+
+ def cast_pil_image_to_torch_tensor_with_4_channel_dim(
+         image: Image.Image,
+         device: str | None = None
+ ) -> Tensor:
+     """
+     Cast a PIL image to a torch tensor.
+     The grayscale version of the original RGB image is added as a 4th channel.
+     :param image: input image
+     :param device: target device (optional)
+     :return: torch tensor (4 channel dims)
+     """
+
+     # PIL image to torch tensor transformation
+     transform = transforms.Compose([transforms.PILToTensor()])
+
+     # creating gray image
+     image_gray = image.convert('L')
+
+     # casting PIL images to torch tensors with normalization
+     image_tensor = transform(image.convert('RGB')).to(torch.float32) / 255.0
+     image_gray_tensor = transform(image_gray).to(torch.float32) / 255.0
+
+     # concatenating the gray channel to the RGB channels
+     final_image_tensor = torch.cat((image_tensor, image_gray_tensor), dim=0)
+
+     # moving tensor to gpu if required
+     if device is not None:
+         final_image_tensor = final_image_tensor.to(device)
+
+     return final_image_tensor
+
+
+ def remove_square_padding(
+         original_image: Image.Image | Tensor,
+         square_image: Image.Image | Tensor,
+         resize_back_to_original: bool = False
+ ):
+     """
+     Remove the square padding that was added to make the original image square.
+     :param original_image: the image with the original size
+     :param square_image: the image with the square size
+     :param resize_back_to_original: defines if we want to resize the square image back to the original size
+     :return: square image cropped back to the original aspect ratio
+     """
+
+     if isinstance(original_image, Image.Image):
+         original_width, original_height = original_image.size
+     else:
+         original_height, original_width = original_image.shape[:2]
+
+     if isinstance(square_image, Image.Image):
+         square_width, square_height = square_image.size
+     else:
+         square_height, square_width = square_image.shape[:2]
+
+     if original_width > original_height:
+         ratio = square_width / original_width
+         new_width = square_width
+         new_height = int(ratio * original_height)
+     else:
+         ratio = square_height / original_height
+         new_height = square_height
+         new_width = int(ratio * original_width)
+
+     # cutting the square image down to the original aspect ratio
+     if isinstance(square_image, Image.Image):
+         square_image_with_original_ratio = square_image.crop((0, 0, new_width, new_height))
+     else:
+         square_image_with_original_ratio = square_image[:new_height, :new_width]
+
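+     # note: .resize() is a PIL method, so the resize-back path assumes square_image is a PIL image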
+     if resize_back_to_original:
+         square_image_with_original_ratio = square_image_with_original_ratio.resize((original_width, original_height))
+
+     return square_image_with_original_ratio