kandinsky-endpoint / handler.py
cjwalch's picture
Upload handler.py
0ab38e4 verified
from typing import Dict, Any
import torch
import base64
import io
from PIL import Image
from diffusers import AutoPipelineForImage2Image
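# Dependency pins noted by the author (presumably mirrors requirements.txt — confirm):
# torch==2.5.1+cu124
# torchvision==0.18.0+cu124
# torchaudio==2.5.1+cu124
# diffusers==0.17.0.dev0
# Pillow==10.0.0
# fastapi
# pydantic
# uvicorn
# Same packages without the +cu124 local version tag:
# torchvision==0.18.0
# torchaudio==2.5.1
# NOTE(review): torchvision 0.18.0 does not look like the release paired with
# torch 2.5.1 — confirm these pins against the PyTorch compatibility matrix.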
class EndpointHandler:
    """Callable image-to-image inference handler backed by a diffusers pipeline.

    Calling an instance with a request payload runs the pipeline on the
    supplied base64-encoded image and returns the generated image as a
    base64-encoded PNG, or an ``{"error": ...}`` dict when anything fails.
    """

    # Model repository the pipeline weights are loaded from.
    MODEL_ID = "cjwalch/kandinsky-endpoint"

    def __init__(self, path=""):
        """Load the pipeline and configure device placement.

        Args:
            path: Accepted for interface compatibility. It is not used; the
                weights are always loaded from ``MODEL_ID``.
        """
        has_cuda = torch.cuda.is_available()
        self.pipeline = AutoPipelineForImage2Image.from_pretrained(
            self.MODEL_ID,
            # Half precision targets GPU execution; use full precision on CPU.
            torch_dtype=torch.float16 if has_cuda else torch.float32,
            use_safetensors=True,
        )
        if has_cuda:
            # Model CPU offload manages device placement on its own. The
            # pipeline must NOT additionally be moved with .to("cuda"):
            # doing so defeats the memory savings of offloading.
            self.pipeline.enable_model_cpu_offload()

    @staticmethod
    def _decode_image(encoded) -> "Image.Image":
        """Decode a base64 payload (optionally a ``data:`` URI) to an RGB image."""
        if isinstance(encoded, str) and encoded.startswith("data:"):
            # Drop the "data:<mime>;base64," header; a base64 payload never
            # contains a comma, so everything after the first one is data.
            encoded = encoded.partition(",")[2]
        raw = base64.b64decode(encoded)
        return Image.open(io.BytesIO(raw)).convert("RGB")

    @staticmethod
    def _encode_image(image) -> str:
        """Serialize an image to PNG and return it as a base64 string."""
        buffer = io.BytesIO()
        image.save(buffer, format="PNG")
        return base64.b64encode(buffer.getvalue()).decode("utf-8")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Run image-to-image inference for one request.

        Args:
            data: Request payload. Keys read here:
                ``inputs`` (prompt, default ``""``),
                ``strength`` (default ``0.6``),
                ``guidance_scale`` (default ``7.0``),
                ``negative_prompt`` (default ``"blurry, ugly, deformed"``),
                ``init_image`` (required, base64-encoded image; a
                ``data:`` URI prefix is tolerated).

        Returns:
            ``{"generated_image": <base64 PNG>}`` on success, otherwise
            ``{"error": <message>}``. Exceptions are never propagated.
        """
        try:
            prompt = data.get("inputs", "")
            strength = float(data.get("strength", 0.6))
            guidance_scale = float(data.get("guidance_scale", 7.0))
            negative_prompt = data.get("negative_prompt", "blurry, ugly, deformed")

            init_image_b64 = data.get("init_image", None)
            if not init_image_b64:
                return {"error": "Missing 'init_image' in input data"}

            init_image = self._decode_image(init_image_b64)

            try:
                output_image = self.pipeline(
                    prompt=prompt,
                    image=init_image,
                    strength=strength,
                    guidance_scale=guidance_scale,
                    negative_prompt=negative_prompt,
                ).images[0]
            finally:
                # Release cached GPU memory whether inference succeeded or
                # failed, so an error does not leave the cache filled.
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

            return {"generated_image": self._encode_image(output_image)}
        except Exception as e:
            # Endpoint boundary: report every failure to the caller as an
            # error payload instead of raising.
            return {"error": str(e)}