# flux1-fill-dev-custom / handler.py
import base64
from io import BytesIO
from typing import Any, Dict
import torch
from diffusers import FluxFillPipeline
from PIL import Image


def decode_image(b64_string: str) -> Image.Image:
    """Decode a base64-encoded image string into an RGB PIL image."""
    image_data = base64.b64decode(b64_string)
    return Image.open(BytesIO(image_data)).convert("RGB")


def encode_image(image: Image.Image) -> str:
    """Encode a PIL image as a base64 PNG string."""
    buffer = BytesIO()
    image.save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")


class EndpointHandler:
    def __init__(self, path="shangguanyanyan/flux1-fill-dev-custom"):
        self.pipe = FluxFillPipeline.from_pretrained(
            path, torch_dtype=torch.bfloat16
        ).to("cuda" if torch.cuda.is_available() else "cpu")
        # Default generation parameters; request-level "parameters" override them.
        self.parameters = {
            "height": 1632,
            "width": 1232,
            "guidance_scale": 30,
            "num_inference_steps": 50,
            "max_sequence_length": 512,
            "generator": torch.Generator("cpu").manual_seed(0),
        }

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        data: {
            "inputs": {
                "image": base64_image,
                "mask": base64_mask,
                "prompt": prompt
            },
            "parameters": {
                "height": 1632,
                "width": 1232,
                "guidance_scale": 30,
                "num_inference_steps": 50,
                "max_sequence_length": 512,
            }
        }
        """
        inputs = data.pop("inputs", data)
        # Merge defaults with request parameters; request values take precedence.
        parameters = {**self.parameters, **data.pop("parameters", {})}

        base64_image = inputs.pop("image", "")
        base64_mask = inputs.pop("mask", "")
        prompt = inputs.pop("prompt", "")
        if not base64_image or not base64_mask or not prompt:
            return {
                "error": "Please provide image, mask and prompt",
                "status": "failed",
            }

        image = decode_image(base64_image)
        mask = decode_image(base64_mask)

        # Run Flux fill inpainting: regenerate the masked region from the prompt.
        image = self.pipe(
            prompt=prompt,
            image=image,
            mask_image=mask,
            **parameters,
        ).images[0]
        return {"image": encode_image(image), "status": "success"}
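

# Minimal local smoke-test sketch, not part of the deployed handler: it builds a
# payload matching the schema documented in __call__ and saves the returned image.
# The file names "image.png", "mask.png", "output.png" and the example prompt are
# placeholders; running this assumes the model weights and a CUDA GPU are
# available locally.
if __name__ == "__main__":
    def _load_b64(path: str) -> str:
        with open(path, "rb") as f:
            return base64.b64encode(f.read()).decode("utf-8")

    handler = EndpointHandler()
    payload = {
        "inputs": {
            "image": _load_b64("image.png"),
            "mask": _load_b64("mask.png"),
            "prompt": "a white paper cup",
        },
        "parameters": {"num_inference_steps": 28},
    }
    result = handler(payload)
    if result.get("status") == "success":
        decode_image(result["image"]).save("output.png")
    else:
        print(result)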