import base64
import gzip
from io import BytesIO
from typing import Any, Dict

import numpy as np
import torch
from PIL import Image
from transformers import SamModel, SamProcessor


def pack_bits(boolean_tensor):
    """Pack a boolean mask into gzip-compressed bytes (1 bit per pixel)."""
    # Move to CPU and flatten into a 1-D numpy bool array
    flat = boolean_tensor.cpu().numpy().flatten()
    # Pad with zeros so the length is a multiple of 8
    if flat.size % 8 != 0:
        padding = np.zeros((8 - flat.size % 8,), dtype=bool)
        flat = np.concatenate([flat, padding])
    # Pack 8 booleans per byte, then gzip the resulting byte string
    packed = np.packbits(flat).tobytes()
    return gzip.compress(packed)


class EndpointHandler:
    def __init__(self, path=""):
        # Preload all the elements you are going to need at inference.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = SamModel.from_pretrained("facebook/sam-vit-base").to(self.device)
        self.processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        data args:
            inputs (:obj:`dict`): {"image": base64-encoded image bytes, "points": 2D input points}
        Return:
            A :obj:`dict` with base64-encoded bit-packed masks, IoU scores, and the mask shape;
            it will be serialized and returned.
        """
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", {"mode": "image"})  # currently unused

        # Decode the base64 image to a PIL RGB image
        image = Image.open(BytesIO(base64.b64decode(inputs["image"]))).convert("RGB")
        input_points = [inputs["points"]]  # 2D localization of a window

        model_inputs = self.processor(image, input_points=input_points, return_tensors="pt").to(self.device)
        with torch.no_grad():
            outputs = self.model(**model_inputs)

        # Upscale the predicted masks back to the original image resolution
        masks = self.processor.image_processor.post_process_masks(
            outputs.pred_masks.cpu(),
            model_inputs["original_sizes"].cpu(),
            model_inputs["reshaped_input_sizes"].cpu(),
        )
        scores = outputs.iou_scores

        # Bit-pack, gzip, and base64-encode each boolean mask for JSON transport
        packed = [
            base64.b64encode(pack_bits(masks[0][0][i])).decode()
            for i in range(masks[0].shape[1])
        ]
        shape = list(masks[0].shape)[2:]  # [height, width] of the masks
        return {"masks": packed, "scores": scores[0][0].tolist(), "shape": shape}
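

# --- Client-side decoding sketch (not part of the handler) ---
# A minimal sketch of the inverse of pack_bits, assuming the response format
# returned by __call__ above. `unpack_bits` is a hypothetical helper name,
# not part of this handler or the transformers API.
def unpack_bits(b64_mask: str, shape) -> np.ndarray:
    """Recover a boolean (H, W) mask from one base64-encoded entry in "masks"."""
    # Reverse the pipeline: base64 -> gzip -> packed bytes -> bits
    raw = gzip.decompress(base64.b64decode(b64_mask))
    bits = np.unpackbits(np.frombuffer(raw, dtype=np.uint8))
    # Drop the zero padding added by pack_bits, then restore the 2-D shape
    return bits[: shape[0] * shape[1]].reshape(shape).astype(bool)


# Example request payload (hypothetical values) matching what __call__ expects;
# the exact nesting of "points" follows SamProcessor's input_points format:
#   {"inputs": {"image": "<base64-encoded JPEG/PNG>", "points": [[450, 600]]}}
# A client would then rebuild each mask with:
#   mask = unpack_bits(response["masks"][0], response["shape"])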