Spaces:
Runtime error
Runtime error
""" | |
Source url: https://github.com/OPHoperHPO/image-background-remove-tool | |
Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. | |
License: Apache License 2.0 | |
""" | |
from PIL import Image | |
from carvekit.trimap.cv_gen import CV2TrimapGenerator | |
from carvekit.trimap.add_ops import prob_filter, prob_as_unknown_area, post_erosion | |
class TrimapGenerator(CV2TrimapGenerator):
    """
    Trimap generator layered on top of CV2TrimapGenerator.

    The predicted mask is first passed through prob_filter, and the parent
    (cv2-based) generator builds a trimap from that filtered mask. The
    trimap is then refined by prob_as_unknown_area, which receives the
    original, unfiltered mask. Finally post_erosion is applied. The parent
    generator is constructed with erosion_iters=0, so all erosion configured
    here happens in post_erosion at the end of __call__.
    """

    def __init__(
        self, prob_threshold: int = 231, kernel_size: int = 30, erosion_iters: int = 5
    ):
        """
        Initialize a TrimapGenerator instance

        Args:
            prob_threshold: Threshold forwarded to both prob_filter and
                prob_as_unknown_area. NOTE(review): the default of 231
                suggests an 8-bit (0-255) mask scale; confirm against
                carvekit.trimap.add_ops.
            kernel_size: Forwarded to CV2TrimapGenerator. Per the original
                docs, it is the offset (in pixels) from the object mask
                used when forming the unknown area of the trimap.
            erosion_iters: Number of erosion iterations passed to
                post_erosion. That erosion is applied to the trimap at the
                end of __call__, not by the parent generator.
        """
        # Disable the parent's own erosion; it is deferred to post_erosion.
        super().__init__(kernel_size, erosion_iters=0)
        self.prob_threshold = prob_threshold
        # Double-underscore name is mangled to _TrimapGenerator__erosion_iters,
        # which keeps it distinct from any same-named parent attribute.
        self.__erosion_iters = erosion_iters

    def __call__(self, original_image: Image.Image, mask: Image.Image) -> Image.Image:
        """
        Build a trimap used to refine the borders of a predicted object mask.

        Steps:
        1. prob_filter on the mask.
        2. CV2TrimapGenerator on the filtered mask.
        3. prob_as_unknown_area, using the original mask.
        4. post_erosion.

        Args:
            original_image: Original image, forwarded to the parent generator.
            mask: Predicted object mask.

        Returns:
            Generated trimap for the image.
        """
        confident_mask = prob_filter(mask=mask, prob_threshold=self.prob_threshold)
        base_trimap = super().__call__(original_image, confident_mask)
        # NOTE: the unfiltered `mask` is passed here, not `confident_mask`.
        refined_trimap = prob_as_unknown_area(
            trimap=base_trimap, mask=mask, prob_threshold=self.prob_threshold
        )
        return post_erosion(refined_trimap, self.__erosion_iters)