Spaces:
Runtime error
Runtime error
from pathlib import Path | |
import click | |
import tqdm | |
from carvekit.utils.image_utils import ALLOWED_SUFFIXES | |
from carvekit.utils.pool_utils import batch_generator, thread_pool_processing | |
from carvekit.web.schemas.config import MLConfig | |
from carvekit.web.utils.init_utils import init_interface | |
from carvekit.utils.fs_utils import save_file | |
def removebg(
    i: str,
    o: str,
    pre: str,
    post: str,
    net: str,
    recursive: bool,
    batch_size: int,
    batch_size_seg: int,
    batch_size_mat: int,
    seg_mask_size: int,
    matting_mask_size: int,
    device: str,
    fp16: bool,
    trimap_dilation: int,
    trimap_erosion: int,
    trimap_prob_threshold: int,
):
    """Remove the background from one image or from every image in a directory.

    Args:
        i: Path to a single input file or to a directory of images.
        o: Output path handed to ``save_file`` for every processed image.
        pre: Forwarded to ``MLConfig`` as ``preprocessing_method``.
        post: Forwarded to ``MLConfig`` as ``postprocessing_method``.
        net: Forwarded to ``MLConfig`` as ``segmentation_network``.
        recursive: When ``i`` is a directory, also descend into its
            sub-directories while collecting images.
        batch_size: Number of images passed to the interface per call.
        batch_size_seg: Forwarded to ``MLConfig`` as ``batch_size_seg``.
        batch_size_mat: Forwarded to ``MLConfig`` as ``batch_size_matting``.
        seg_mask_size: Forwarded to ``MLConfig`` as ``seg_mask_size``.
        matting_mask_size: Forwarded to ``MLConfig`` as ``matting_mask_size``.
        device: Forwarded to ``MLConfig`` as ``device``.
        fp16: Forwarded to ``MLConfig`` as ``fp16``.
        trimap_dilation: Forwarded to ``MLConfig`` as ``trimap_dilation``.
        trimap_erosion: Forwarded to ``MLConfig`` as ``trimap_erosion``.
        trimap_prob_threshold: Forwarded to ``MLConfig`` as
            ``trimap_prob_threshold``.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    # Fail before any model is initialised instead of hitting a
    # ZeroDivisionError (or a nonsensical negative total) later on.
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    out_path = Path(o)
    input_path = Path(i)

    if input_path.is_dir():
        candidates = (
            input_path.rglob("*.*") if recursive else input_path.glob("*.*")
        )
        # Keep only files with an allowed suffix, and skip anything whose name
        # contains "_bg_removed" so earlier results are not processed again.
        all_images = [
            candidate
            for candidate in candidates
            if candidate.suffix.lower() in ALLOWED_SUFFIXES
            and "_bg_removed" not in candidate.name
        ]
    else:
        # A single path is taken as-is; no suffix filtering is applied to it.
        all_images = [input_path]

    interface_config = MLConfig(
        segmentation_network=net,
        preprocessing_method=pre,
        postprocessing_method=post,
        device=device,
        batch_size_seg=batch_size_seg,
        batch_size_matting=batch_size_mat,
        seg_mask_size=seg_mask_size,
        matting_mask_size=matting_mask_size,
        fp16=fp16,
        trimap_dilation=trimap_dilation,
        trimap_erosion=trimap_erosion,
        trimap_prob_threshold=trimap_prob_threshold,
    )
    interface = init_interface(interface_config)

    # Ceiling division: a trailing partial batch still counts as one batch, so
    # the progress bar total matches the number of iterations.
    # NOTE(review): assumes batch_generator yields the final partial batch.
    total_batches = (len(all_images) + batch_size - 1) // batch_size

    for image_batch in tqdm.tqdm(
        batch_generator(all_images, n=batch_size),
        total=total_batches,
        desc="Removing background",
        unit=" image batch",
        colour="blue",
    ):
        images_without_background = interface(image_batch)  # Remove background
        thread_pool_processing(
            lambda x: save_file(out_path, image_batch[x], images_without_background[x]),
            range(len(image_batch)),
        )  # Drop images to fs
# Script entry point: only runs when this file is executed directly, not when
# it is imported as a module.
# NOTE(review): removebg declares 16 required parameters and no CLI decorator
# is visible in this view; a bare call raises TypeError unless the function is
# wrapped elsewhere (e.g. by click options) - confirm against the full file.
if __name__ == "__main__":
    removebg()