import cv2
import numpy as np
import torch
from PIL import Image
from rembg import remove
from segment_anything import SamPredictor, sam_model_registry

def sam_init(sam_checkpoint, device_id=0):
    """Load a ViT-H SAM checkpoint and wrap it in a SamPredictor."""
    model_type = "vit_h"
    device = "cuda:{}".format(device_id) if torch.cuda.is_available() else "cpu"
    sam = sam_model_registry[model_type](checkpoint=sam_checkpoint).to(device=device)
    predictor = SamPredictor(sam)
    return predictor

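
# Usage sketch (hedged): the checkpoint path below is an assumption, not part
# of this module; point it at wherever the official ViT-H SAM weights live.
# predictor = sam_init("./sam_vit_h_4b8939.pth", device_id=0)
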
def sam_out_nosave(predictor, input_image, *bbox_sliders):
    """Segment the object inside a bounding box and return it as an RGBA image.

    `input_image` is an RGB PIL image; `bbox_sliders` is
    (x_min, y_min, x_max, y_max) in pixel coordinates.
    """
    bbox = np.array(bbox_sliders)
    image = np.asarray(input_image)
    predictor.set_image(image)
    masks_bbox, scores_bbox, logits_bbox = predictor.predict(
        box=bbox, multimask_output=True
    )

    # Keep the RGB channels and write the predicted mask into the alpha channel.
    out_image = np.zeros((image.shape[0], image.shape[1], 4), dtype=np.uint8)
    out_image[:, :, :3] = image
    out_image_bbox = out_image.copy()
    # Take the last of SAM's candidate masks; picking the highest-scoring one
    # via np.argmax(scores_bbox) would be a reasonable alternative.
    out_image_bbox[:, :, 3] = masks_bbox[-1].astype(np.uint8) * 255
    torch.cuda.empty_cache()
    return Image.fromarray(out_image_bbox, mode="RGBA")

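
# Usage sketch (hedged): the box coordinates below are made-up placeholders.
# rgba = sam_out_nosave(predictor, img, 50, 50, 400, 400)
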
# contrast correction, rescale and recenter
def image_preprocess(input_image, save_path, lower_contrast=True, rescale=True):
    """Optionally lower contrast, crop the foreground, recenter it on a square
    canvas, resize to 256x256 and save as RGBA."""
    image_arr = np.array(input_image)
    in_h, in_w = image_arr.shape[:2]

    if lower_contrast:
        alpha = 0.8  # contrast control (1.0-3.0)
        beta = 0  # brightness control (0-100)
        # Apply the contrast adjustment, then re-saturate near-opaque alpha.
        image_arr = cv2.convertScaleAbs(image_arr, alpha=alpha, beta=beta)
        image_arr[image_arr[..., -1] > 200, -1] = 255

    # Bounding box of the non-transparent foreground.
    ret, mask = cv2.threshold(
        np.array(input_image.split()[-1]), 0, 255, cv2.THRESH_BINARY
    )
    x, y, w, h = cv2.boundingRect(mask)
    max_size = max(w, h)
    ratio = 0.75  # foreground fills 75% of the output side length
    if rescale:
        side_len = int(max_size / ratio)
    else:
        side_len = in_h

    # Paste the cropped foreground into the center of a square canvas.
    padded_image = np.zeros((side_len, side_len, 4), dtype=np.uint8)
    center = side_len // 2
    padded_image[
        center - h // 2 : center - h // 2 + h, center - w // 2 : center - w // 2 + w
    ] = image_arr[y : y + h, x : x + w]
    rgba = Image.fromarray(padded_image).resize((256, 256), Image.LANCZOS)
    rgba.save(save_path)

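
# Usage sketch (hedged): "object_rgba" and "object_256.png" are placeholder
# names for an RGBA segmentation result and an output path.
# image_preprocess(object_rgba, "object_256.png", lower_contrast=True, rescale=True)
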
def pred_bbox(image):
    """Predict a foreground bounding box by removing the background with rembg
    and taking the extent of the non-zero alpha region."""
    image_nobg = remove(image.convert("RGBA"), alpha_matting=True)
    alpha = np.asarray(image_nobg)[:, :, -1]
    x_nonzero = np.nonzero(alpha.sum(axis=0))
    y_nonzero = np.nonzero(alpha.sum(axis=1))
    x_min = int(x_nonzero[0].min())
    y_min = int(y_nonzero[0].min())
    x_max = int(x_nonzero[0].max())
    y_max = int(y_nonzero[0].max())
    return x_min, y_min, x_max, y_max

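
# End-to-end sketch (hedged): the file names below are assumptions. Predict a
# box with rembg, segment with SAM, then recenter and save the 256x256 crop.
# predictor = sam_init("./sam_vit_h_4b8939.pth")
# img = Image.open("input.png").convert("RGB")
# bbox = pred_bbox(img)
# segmented = sam_out_nosave(predictor, img, *bbox)
# image_preprocess(segmented, "processed.png", lower_contrast=False)
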
# convert a function into a recursive one that maps over nested dict/list/tuple values
def make_recursive_func(func):
    def wrapper(vars, *args, **kwargs):
        if isinstance(vars, list):
            return [wrapper(x, *args, **kwargs) for x in vars]
        elif isinstance(vars, tuple):
            return tuple(wrapper(x, *args, **kwargs) for x in vars)
        elif isinstance(vars, dict):
            return {k: wrapper(v, *args, **kwargs) for k, v in vars.items()}
        else:
            return func(vars, *args, **kwargs)

    return wrapper

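
# Example (hedged): `double` is a hypothetical helper, shown only to
# illustrate how the decorator maps a leaf function over nested containers.
# @make_recursive_func
# def double(x):
#     return x * 2
# double({"a": [1, 2], "b": (3, 4)})  # -> {"a": [2, 4], "b": (6, 8)}
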
@make_recursive_func
def todevice(vars, device="cuda"):
    """Move every tensor leaf to `device`; pass plain Python leaves through."""
    if isinstance(vars, torch.Tensor):
        return vars.to(device)
    elif isinstance(vars, (str, bool, float, int)):
        return vars
    else:
        raise NotImplementedError("invalid input type {} for todevice".format(type(vars)))
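
# Usage sketch (hedged): with @make_recursive_func applied, todevice recurses
# through nested containers; the batch layout below is a made-up example.
# batch = {"img": torch.rand(1, 3, 256, 256), "meta": {"id": "0001", "scale": 1.0}}
# batch = todevice(batch, device="cpu")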