""" U-Net based DIE model for cleaning document. """ import os from typing import Callable import torch import torch.nn as nn import torch.nn.functional as F import torchvision.transforms as T from PIL import Image class DoubleConv(nn.Module): """(convolution => [BN] => ReLU) * 2""" def __init__(self, in_channels, out_channels, mid_channels=None): super().__init__() if not mid_channels: mid_channels = out_channels self.double_conv = nn.Sequential( nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(mid_channels), nn.ReLU(inplace=True), nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True) ) def forward(self, x): return self.double_conv(x) class Down(nn.Module): """Downscaling with maxpool then double conv""" def __init__(self, in_channels, out_channels): super().__init__() self.maxpool_conv = nn.Sequential( nn.MaxPool2d(2), DoubleConv(in_channels, out_channels) ) def forward(self, x): return self.maxpool_conv(x) class Up(nn.Module): """Upscaling then double conv""" def __init__(self, in_channels, out_channels, bilinear=True): super().__init__() # if bilinear, use the normal convolutions to reduce the number of channels if bilinear: self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) self.conv = DoubleConv(in_channels, out_channels, in_channels // 2) else: self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2) self.conv = DoubleConv(in_channels, out_channels) def forward(self, x1, x2): x1 = self.up(x1) # input is CHW diffY = x2.size()[2] - x1.size()[2] diffX = x2.size()[3] - x1.size()[3] x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2]) # if you have padding issues, see # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd x = torch.cat([x2, x1], dim=1) return self.conv(x) class OutConv(nn.Module): def __init__(self, in_channels, out_channels): super(OutConv, self).__init__() self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) def forward(self, x): x = self.conv(x) x = torch.sigmoid(x) return x class UNet(nn.Module): def __init__(self, n_channels, output_channel_dim=1, bilinear=False): super(UNet, self).__init__() self.n_channels = n_channels self.n_classes = output_channel_dim self.bilinear = bilinear self.inc = DoubleConv(n_channels, 64) self.down1 = Down(64, 128) self.down2 = Down(128, 256) self.down3 = Down(256, 512) factor = 2 if bilinear else 1 self.down4 = Down(512, 1024 // factor) self.up1 = Up(1024, 512 // factor, bilinear) self.up2 = Up(512, 256 // factor, bilinear) self.up3 = Up(256, 128 // factor, bilinear) self.up4 = Up(128, 64, bilinear) self.outc = OutConv(64, output_channel_dim) def forward(self, x): x1 = self.inc(x) x2 = self.down1(x1) x3 = self.down2(x2) x4 = self.down3(x3) x5 = self.down4(x4) x = self.up1(x5, x4) x = self.up2(x, x3) x = self.up3(x, x2) x = self.up4(x, x1) logits = self.outc(x) return logits def add_gaussian_noise( data: torch.Tensor ) -> torch.Tensor: """ Adding gaussian noise to torch tensor. :param data: torch tensor :return: noise perturbed tensor """ data_with_noise = data.clone() data_with_noise += torch.normal(mean=0, std=0.05, size=data_with_noise.shape).to(data_with_noise.device) data_with_noise = data_with_noise.clip(min=0, max=1) return data_with_noise def inference_model( model: Callable, model_input: torch.Tensor, device: str | torch.device, num_of_iterations: int = 1 ) -> list[torch.Tensor, ...]: """ Performing model inference. :param model: image pre-processing model :param model_input: data to model :param device: cuda device :param num_of_iterations: defines how many times feed the network (recursively) :return: predictions """ # inference model with torch.no_grad(): prediction_list = [] model_input = model_input.to(device) if len(model_input.shape) == 3: model_input = model_input.unsqueeze(dim=0) model_input_original_part = model_input[:, 0:3, ...] for i in range(num_of_iterations): if i == 0: model_input = add_gaussian_noise(model_input) prediction = model(model_input) prediction_list.append(prediction) model_input_new = torch.cat((model_input_original_part, prediction.detach()), dim=1) else: model_input_perturbed = add_gaussian_noise(model_input_new) prediction = model(model_input_perturbed) prediction_list.append(prediction) model_input_new = torch.cat((model_input_original_part, prediction.detach()), dim=1) return prediction_list def load_unet( model_path: str, device: str = 'cpu', eval_mode: bool = False, n_channels: int = 4, bilinear: bool = False, output_channel_dim: int = 1 ): print("Loading UNet model...") # image preprocessing model model = UNet( n_channels=n_channels, bilinear=bilinear, output_channel_dim=output_channel_dim ) # this hack is required due to distributed data parallel training state_dict = torch.load(os.path.join(model_path), map_location=device) new_state_dict = {key.replace('module.', ''): value for key, value in state_dict.items()} model.load_state_dict(new_state_dict) model.to(device) if eval_mode: model.eval() return model class UNetDIEModel: """ Class for Document Image Enhancement with U-Net. """ def __init__( self, *args, **kwargs ): """ Initialization. """ self.args = kwargs['args'] # loading text detector model self.die = load_unet( model_path=self.args.die_model_path, device=self.args.device, eval_mode=True, ) def enhance_document_image( self, image_raw_list: list[Image.Image], num_of_die_iterations: int = 1, ) -> list[Image.Image]: """" Enhance document image by removing noise. :param image_raw_list: original document page to process :param num_of_die_iterations: number of DIE iterations :return: cleaned document page to process """ with torch.no_grad(): # image_die = torch.stack(image_die_list, dim=0) image_die = torch.stack(image_raw_list, dim=0) # document image enhancement prediction_list = inference_model( model=self.die, model_input=image_die, num_of_iterations=num_of_die_iterations, device=self.args.device ) # transform DIE model output to image and apply post-processing last_prediction = prediction_list[-1] batch_size = last_prediction.size(0) image_die_list = [T.ToPILImage()(last_prediction[idx, ...]).convert('RGB') for idx in range(batch_size)] return image_die_list