"""Inference endpoint handler for a Kandinsky image-to-image pipeline."""

# Standard library
import base64
import io
from typing import Any, Dict

# Third-party
import torch
from diffusers import AutoPipelineForImage2Image
from PIL import Image
class EndpointHandler:
    """Serve Kandinsky image-to-image inference behind an HTTP endpoint.

    The pipeline is loaded once at startup; each call decodes a base64
    source image, runs img2img, and returns a base64-encoded PNG.
    """

    def __init__(self, path=""):
        """Initialize the model from the given path.

        Args:
            path: Model directory supplied by the endpoint runtime (unused;
                the pipeline is pulled from the Hub by repo id).
        """
        # fp16 is only useful on GPU; many CPU ops lack half-precision
        # kernels, so fall back to fp32 on CPU-only hosts.
        use_cuda = torch.cuda.is_available()
        self.pipeline = AutoPipelineForImage2Image.from_pretrained(
            "cjwalch/kandinsky-endpoint",
            torch_dtype=torch.float16 if use_cuda else torch.float32,
            use_safetensors=True,
        )
        if use_cuda:
            # CPU offload manages device placement itself; calling
            # .to("cuda") afterwards would detach the offload hooks
            # (the original code did both, which diffusers warns against).
            self.pipeline.enable_model_cpu_offload()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Run inference on the input image and return a base64-encoded result.

        Expected keys in ``data``:
            inputs (str): text prompt (default "").
            init_image (str): base64-encoded source image — required.
            strength (float): denoising strength (default 0.6).
            guidance_scale (float): CFG scale (default 7.0).
            negative_prompt (str): default "blurry, ugly, deformed".

        Returns:
            ``{"generated_image": <base64 PNG>}`` on success, or
            ``{"error": <message>}`` on any failure.
        """
        try:
            prompt = data.get("inputs", "")
            strength = float(data.get("strength", 0.6))
            guidance_scale = float(data.get("guidance_scale", 7.0))
            negative_prompt = data.get("negative_prompt", "blurry, ugly, deformed")

            init_image_b64 = data.get("init_image", None)
            if not init_image_b64:
                return {"error": "Missing 'init_image' in input data"}

            image_bytes = base64.b64decode(init_image_b64)
            init_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")

            output_image = self.pipeline(
                prompt=prompt,
                image=init_image,
                strength=strength,
                guidance_scale=guidance_scale,
                negative_prompt=negative_prompt,
            ).images[0]

            buffered = io.BytesIO()
            output_image.save(buffered, format="PNG")
            img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")

            # Drop references BEFORE emptying the cache so the allocator can
            # actually release their memory (the original emptied first);
            # guard the CUDA call for CPU-only hosts.
            del output_image
            del init_image
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            return {"generated_image": img_str}

        except Exception as e:
            # Endpoint boundary: report the failure to the caller rather
            # than crashing the serving worker.
            return {"error": str(e)}