Spaces:
Runtime error
Runtime error
""" | |
Source url: https://github.com/OPHoperHPO/image-background-remove-tool | |
Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. | |
License: Apache License 2.0 | |
""" | |
from pathlib import Path | |
from typing import Union, List, Optional | |
from PIL import Image | |
from carvekit.ml.wrap.basnet import BASNET | |
from carvekit.ml.wrap.deeplab_v3 import DeepLabV3 | |
from carvekit.ml.wrap.u2net import U2NET | |
from carvekit.ml.wrap.tracer_b7 import TracerUniversalB7 | |
from carvekit.pipelines.preprocessing import PreprocessingStub | |
from carvekit.pipelines.postprocessing import MattingMethod | |
from carvekit.utils.image_utils import load_image | |
from carvekit.utils.mask_utils import apply_mask | |
from carvekit.utils.pool_utils import thread_pool_processing | |
class Interface:
    """
    High-level entry point that chains CarveKit pipelines to remove image backgrounds.

    Holds an optional pre-processing pipeline, a required segmentation network
    and an optional post-processing pipeline, and runs them in that order when
    the instance is called with a list of images.
    """

    def __init__(
        self,
        seg_pipe: Union[U2NET, BASNET, DeepLabV3, TracerUniversalB7],
        pre_pipe: Optional[Union[PreprocessingStub]] = None,
        post_pipe: Optional[Union[MattingMethod]] = None,
        device="cpu",
    ):
        """
        Initializes an object for interacting with pipelines and other components of the CarveKit framework.

        Args:
            seg_pipe: Initialized segmentation network object.
            pre_pipe: Initialized pre-processing pipeline object, or ``None`` to
                call the segmentation network directly.
            post_pipe: Initialized post-processing pipeline object, or ``None`` to
                apply the masks with ``apply_mask``.
            device: The processing device passed to ``apply_mask`` when no
                post-processing pipeline is configured.
        """
        self.device = device
        self.preprocessing_pipeline = pre_pipe
        self.segmentation_pipeline = seg_pipe
        self.postprocessing_pipeline = post_pipe

    def __call__(
        self, images: List[Union[str, Path, Image.Image]]
    ) -> List[Image.Image]:
        """
        Removes the background from the specified images.

        Args:
            images: list of input images (file paths or PIL images); each one is
                loaded with ``load_image``.

        Returns:
            List of images without background as PIL.Image.Image instances,
            in the same order as ``images``.

        Raises:
            ValueError: if no post-processing pipeline is configured and the
                segmentation step returns a different number of masks than
                there are input images.
        """
        # Use a separate name so the loaded images do not re-bind the parameter
        # under a different type.
        loaded: List[Image.Image] = thread_pool_processing(load_image, images)

        if self.preprocessing_pipeline is not None:
            # The pre-processing pipeline is handed the interface itself rather
            # than the segmentation network; presumably it calls back into
            # ``self.segmentation_pipeline`` (confirm against PreprocessingStub).
            masks: List[Image.Image] = self.preprocessing_pipeline(
                interface=self, images=loaded
            )
        else:
            masks = self.segmentation_pipeline(images=loaded)

        if self.postprocessing_pipeline is not None:
            return self.postprocessing_pipeline(images=loaded, masks=masks)

        # Each image must be paired with exactly one mask. Fail loudly instead of
        # raising an opaque IndexError (too few masks) or silently dropping the
        # extras (too many masks).
        if len(masks) != len(loaded):
            raise ValueError(
                f"Segmentation produced {len(masks)} masks "
                f"for {len(loaded)} images"
            )
        return [
            apply_mask(image=image, mask=mask, device=self.device)
            for image, mask in zip(loaded, masks)
        ]