Spaces:
Sleeping
Sleeping
import requests | |
from typing import Optional, Dict, Any, Union | |
import os | |
import rembg | |
import numpy as np | |
from PIL import Image | |
import io | |
import base64 | |
import re | |
# Per-process cache of rembg sessions keyed by model name, so each model is
# loaded once instead of on every call. Only successfully created sessions are
# stored (a failing new_session raises before insertion).
_SESSION_CACHE: Dict[str, Any] = {}


def _get_session(model_name: str) -> Any:
    """Return a cached rembg session for *model_name*, creating it on first use."""
    session = _SESSION_CACHE.get(model_name)
    if session is None:
        session = rembg.new_session(model_name=model_name)
        _SESSION_CACHE[model_name] = session
    return session


def _failure(message: str) -> Dict[str, Any]:
    """Build the failure result dictionary returned by remove_background."""
    return {
        "success": False,
        "error": message,
        "image_data": None
    }


def _encode_png(image: Image.Image) -> bytes:
    """Serialize a PIL image to PNG-encoded bytes."""
    buffer = io.BytesIO()
    image.save(buffer, format="PNG")
    return buffer.getvalue()


def remove_background(
    image_input: Union[str, bytes, np.ndarray, Image.Image],
    model_name: str = "u2net"
) -> Dict[str, Any]:
    """
    Remove background from an image.

    Args:
        image_input: Can be one of:
            - URL string (http:// or https://, scheme matched case-insensitively)
            - Data URL string with base64 encoding ("data:<mime>;base64,<payload>")
            - Image bytes (bytes, bytearray or memoryview)
            - NumPy array (converted via PIL.Image.fromarray)
            - PIL Image
        model_name: Background removal model to use, passed to
            rembg.new_session. The session is created once per model name
            and reused on later calls.

    Returns:
        On success: {"success": True, "message": str, "image_data": <output of
        rembg.remove for bytes input>, "model_used": model_name}.
        On failure: {"success": False, "error": str, "image_data": None}.
        This function does not raise; download errors and processing errors
        are reported through the "error" key.
    """
    try:
        if isinstance(image_input, str):
            lowered = image_input.lower()
            if lowered.startswith(("http://", "https://")):
                # If input is a URL, download the image
                response = requests.get(image_input, timeout=30)
                response.raise_for_status()
                input_data = response.content
                source_info = f"URL: {image_input}"
            elif lowered.startswith("data:"):
                # Data URL: everything before the first comma is the header,
                # everything after is the payload (base64 never contains ',').
                header, separator, payload = image_input.partition(",")
                if not separator or not header.lower().endswith(";base64"):
                    return _failure(
                        "Unsupported data URL: expected a base64-encoded "
                        "'data:<mime>;base64,<payload>' string"
                    )
                input_data = base64.b64decode(payload)
                source_info = "data URL"
            else:
                return _failure(
                    f"Unsupported string input format: {image_input[:30]}..."
                )
        elif isinstance(image_input, (bytes, bytearray, memoryview)):
            # Normalize any bytes-like input to immutable bytes
            input_data = bytes(image_input)
            source_info = "image bytes"
        elif isinstance(image_input, np.ndarray):
            input_data = _encode_png(Image.fromarray(image_input))
            source_info = "numpy array"
        elif isinstance(image_input, Image.Image):
            input_data = _encode_png(image_input)
            source_info = "PIL Image"
        else:
            return _failure(f"Unsupported input type: {type(image_input)}")

        # Reject empty payloads before paying for a model load.
        if not input_data:
            return _failure("Input image data is empty")

        # Load (or reuse) the model only once the input is known to be usable.
        session = _get_session(model_name)
        output_data = rembg.remove(input_data, session=session)
    except requests.RequestException as e:
        return _failure(f"Failed to download image: {str(e)}")
    except Exception as e:
        # Top-level boundary: report any decoding/model error in the result dict.
        return _failure(f"Failed to process image: {str(e)}")

    return {
        "success": True,
        "message": f"Background removed from {source_info} using {model_name} model",
        "image_data": output_data,
        "model_used": model_name
    }