""" Source url: https://github.com/OPHoperHPO/image-background-remove-tool Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. License: Apache License 2.0 """ import pathlib from typing import Union, List, Tuple import PIL import cv2 import numpy as np import torch from PIL import Image from carvekit.ml.arch.fba_matting.models import FBA from carvekit.ml.arch.fba_matting.transforms import ( trimap_transform, groupnorm_normalise_image, ) from carvekit.ml.files.models_loc import fba_pretrained from carvekit.utils.image_utils import convert_image, load_image from carvekit.utils.models_utils import get_precision_autocast, cast_network from carvekit.utils.pool_utils import batch_generator, thread_pool_processing __all__ = ["FBAMatting"] class FBAMatting(FBA): """ FBA Matting Neural Network to improve edges on image. """ def __init__( self, device="cpu", input_tensor_size: Union[List[int], int] = 2048, batch_size: int = 2, encoder="resnet50_GN_WS", load_pretrained: bool = True, fp16: bool = False, ): """ Initialize the FBAMatting model Args: device: processing device input_tensor_size: input image size batch_size: the number of images that the neural network processes in one run encoder: neural network encoder head load_pretrained: loading pretrained model fp16: use half precision """ super(FBAMatting, self).__init__(encoder=encoder) self.fp16 = fp16 self.device = device self.batch_size = batch_size if isinstance(input_tensor_size, list): self.input_image_size = input_tensor_size[:2] else: self.input_image_size = (input_tensor_size, input_tensor_size) self.to(device) if load_pretrained: self.load_state_dict(torch.load(fba_pretrained(), map_location=self.device)) self.eval() def data_preprocessing( self, data: Union[PIL.Image.Image, np.ndarray] ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: """ Transform input image to suitable data format for neural network Args: data: input image Returns: input for neural network """ resized = data.copy() if self.batch_size == 1: resized.thumbnail(self.input_image_size, resample=3) else: resized = resized.resize(self.input_image_size, resample=3) # noinspection PyTypeChecker image = np.array(resized, dtype=np.float64) image = image / 255.0 # Normalize image to [0, 1] values range if resized.mode == "RGB": image = image[:, :, ::-1] elif resized.mode == "L": image2 = np.copy(image) h, w = image2.shape image = np.zeros((h, w, 2)) # Transform trimap to binary data format image[image2 == 1, 1] = 1 image[image2 == 0, 0] = 1 else: raise ValueError("Incorrect color mode for image") h, w = image.shape[:2] # Scale input mlt to 8 h1 = int(np.ceil(1.0 * h / 8) * 8) w1 = int(np.ceil(1.0 * w / 8) * 8) x_scale = cv2.resize(image, (w1, h1), interpolation=cv2.INTER_LANCZOS4) image_tensor = torch.from_numpy(x_scale).permute(2, 0, 1)[None, :, :, :].float() if resized.mode == "RGB": return image_tensor, groupnorm_normalise_image( image_tensor.clone(), format="nchw" ) else: return ( image_tensor, torch.from_numpy(trimap_transform(x_scale)) .permute(2, 0, 1)[None, :, :, :] .float(), ) @staticmethod def data_postprocessing( data: torch.tensor, trimap: PIL.Image.Image ) -> PIL.Image.Image: """ Transforms output data from neural network to suitable data format for using with other components of this framework. 
    @staticmethod
    def data_postprocessing(
        data: torch.Tensor, trimap: PIL.Image.Image
    ) -> PIL.Image.Image:
        """
        Transforms the neural network output into a data format suitable
        for use with other components of this framework.

        Args:
            data: output data from neural network
            trimap: Map with the area we need to refine

        Returns:
            Segmentation mask as PIL Image instance
        """
        if trimap.mode != "L":
            raise ValueError("Incorrect color mode for trimap")
        pred = data.numpy().transpose((1, 2, 0))
        pred = cv2.resize(pred, trimap.size, interpolation=cv2.INTER_LANCZOS4)[:, :, 0]
        # noinspection PyTypeChecker
        trimap_arr = np.array(trimap.copy())
        # Clean the mask by removing all false predictions outside the trimap
        # and the already known area
        pred[trimap_arr[:, :] == 0] = 0
        # pred[trimap_arr[:, :] == 255] = 1
        pred[pred < 0.3] = 0
        return Image.fromarray(pred * 255).convert("L")

    def __call__(
        self,
        images: List[Union[str, pathlib.Path, PIL.Image.Image]],
        trimaps: List[Union[str, pathlib.Path, PIL.Image.Image]],
    ) -> List[PIL.Image.Image]:
        """
        Passes input images through the neural network and returns
        segmentation masks as PIL.Image.Image instances.

        Args:
            images: input images
            trimaps: Maps with the areas we need to refine

        Returns:
            segmentation masks for the input images, as PIL.Image.Image instances
        """
        if len(images) != len(trimaps):
            raise ValueError(
                "Lengths of the images and trimaps lists must be equal!"
            )
        collect_masks = []
        autocast, dtype = get_precision_autocast(device=self.device, fp16=self.fp16)
        with autocast:
            cast_network(self, dtype)
            for idx_batch in batch_generator(range(len(images)), self.batch_size):
                inpt_images = thread_pool_processing(
                    lambda x: convert_image(load_image(images[x])), idx_batch
                )
                inpt_trimaps = thread_pool_processing(
                    lambda x: convert_image(load_image(trimaps[x]), mode="L"), idx_batch
                )
                inpt_img_batches = thread_pool_processing(
                    self.data_preprocessing, inpt_images
                )
                inpt_trimaps_batches = thread_pool_processing(
                    self.data_preprocessing, inpt_trimaps
                )
                inpt_img_batches_transformed = torch.vstack(
                    [i[1] for i in inpt_img_batches]
                )
                inpt_img_batches = torch.vstack([i[0] for i in inpt_img_batches])
                inpt_trimaps_transformed = torch.vstack(
                    [i[1] for i in inpt_trimaps_batches]
                )
                inpt_trimaps_batches = torch.vstack(
                    [i[0] for i in inpt_trimaps_batches]
                )
                with torch.no_grad():
                    inpt_img_batches = inpt_img_batches.to(self.device)
                    inpt_trimaps_batches = inpt_trimaps_batches.to(self.device)
                    inpt_img_batches_transformed = inpt_img_batches_transformed.to(
                        self.device
                    )
                    inpt_trimaps_transformed = inpt_trimaps_transformed.to(self.device)
                    output = super(FBAMatting, self).__call__(
                        inpt_img_batches,
                        inpt_trimaps_batches,
                        inpt_img_batches_transformed,
                        inpt_trimaps_transformed,
                    )
                    output_cpu = output.cpu()
                    # Free device memory as soon as the outputs are on the CPU
                    del (
                        inpt_img_batches,
                        inpt_trimaps_batches,
                        inpt_img_batches_transformed,
                        inpt_trimaps_transformed,
                        output,
                    )
                masks = thread_pool_processing(
                    lambda x: self.data_postprocessing(output_cpu[x], inpt_trimaps[x]),
                    range(len(inpt_images)),
                )
                collect_masks += masks
        return collect_masks
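

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of refining a mask with FBAMatting. The file paths are
# hypothetical placeholders; the trimap is assumed to be a grayscale image in
# which 0 marks background, 255 marks foreground, and any other value marks
# the unknown region to refine.
if __name__ == "__main__":
    fba = FBAMatting(device="cpu", batch_size=1, load_pretrained=True)
    # images and trimaps must be lists of equal length
    masks = fba(["./image.png"], ["./trimap.png"])
    masks[0].save("./refined_mask.png")  # refined mask as a PIL "L" image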