Upload 2 files
added files for deploying to hf endpoints
- handler.py +54 -0
- requirements.txt +5 -0
handler.py
ADDED
@@ -0,0 +1,54 @@
from typing import Dict, Any
import torch
import base64
import io
from PIL import Image
from diffusers import AutoPipelineForImage2Image


class EndpointHandler:
    def __init__(self, path=""):
        """Initialize the image-to-image pipeline from the given path."""
        self.pipeline = AutoPipelineForImage2Image.from_pretrained(
            "cjwalch/kandinsky-endpoint",
            torch_dtype=torch.float16,
            use_safetensors=True
        )
        if torch.cuda.is_available():
            # Offload submodules to CPU and move them to the GPU only when
            # needed; do not also call .to("cuda"), which conflicts with the
            # offload hooks.
            self.pipeline.enable_model_cpu_offload()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Run inference on the input image and return a base64-encoded result."""
        try:
            # Extract input parameters
            prompt = data.get("inputs", "")
            strength = float(data.get("strength", 0.6))
            guidance_scale = float(data.get("guidance_scale", 7.0))
            negative_prompt = data.get("negative_prompt", "blurry, ugly, deformed")

            # Decode the base64-encoded init image
            init_image_b64 = data.get("init_image", None)
            if not init_image_b64:
                return {"error": "Missing 'init_image' in input data"}

            image_bytes = base64.b64decode(init_image_b64)
            init_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")

            # Generate the output image
            output_image = self.pipeline(
                prompt=prompt,
                image=init_image,
                strength=strength,
                guidance_scale=guidance_scale,
                negative_prompt=negative_prompt
            ).images[0]

            # Encode the result as a base64 PNG string
            buffered = io.BytesIO()
            output_image.save(buffered, format="PNG")
            img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")

            return {"generated_image": img_str}

        except Exception as e:
            return {"error": str(e)}
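For a quick local check before deploying, the handler class can be exercised directly. The sketch below is an illustration only (the test prompt, image size, and output filename are arbitrary): it builds a solid-color image in memory, base64-encodes it, and calls the handler with the same keys that __call__ reads ("inputs", "init_image", "strength", "guidance_scale").

import base64
import io

from PIL import Image

from handler import EndpointHandler

# Build a small in-memory test image and base64-encode it,
# mirroring what a real client would send.
test_image = Image.new("RGB", (512, 512), color=(128, 128, 128))
buf = io.BytesIO()
test_image.save(buf, format="PNG")
init_image_b64 = base64.b64encode(buf.getvalue()).decode("utf-8")

handler = EndpointHandler()
result = handler({
    "inputs": "a watercolor painting of a lighthouse",
    "init_image": init_image_b64,
    "strength": 0.6,
    "guidance_scale": 7.0,
})

if "generated_image" in result:
    # Decode and save the returned base64 PNG.
    with open("output.png", "wb") as f:
        f.write(base64.b64decode(result["generated_image"]))
else:
    print(result.get("error"))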
requirements.txt
ADDED
@@ -0,0 +1,5 @@
# CUDA 12.4 (+cu124) wheels are published on the PyTorch index below
--extra-index-url https://download.pytorch.org/whl/cu124
torch==2.5.1+cu124
torchvision==0.20.1+cu124
torchaudio==2.5.1+cu124
# AutoPipelineForImage2Image requires diffusers 0.19.0 or newer
diffusers>=0.19.0
# needed by enable_model_cpu_offload and the pipeline's encoders
accelerate
transformers
Pillow==10.0.0
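Once deployed as a Hugging Face Inference Endpoint, the JSON request body is passed to the handler as its data argument. A minimal client-side sketch, assuming the endpoint URL, token, and input file below are placeholders rather than real values:

import base64

import requests

ENDPOINT_URL = "https://<your-endpoint>.endpoints.huggingface.cloud"  # placeholder
HF_TOKEN = "hf_..."  # placeholder access token

# Base64-encode a local image to send as the init image.
with open("input.png", "rb") as f:
    init_image_b64 = base64.b64encode(f.read()).decode("utf-8")

response = requests.post(
    ENDPOINT_URL,
    headers={
        "Authorization": f"Bearer {HF_TOKEN}",
        "Content-Type": "application/json",
    },
    json={
        "inputs": "a watercolor painting of a lighthouse",
        "init_image": init_image_b64,
        "strength": 0.6,
        "guidance_scale": 7.0,
    },
)
response.raise_for_status()

result = response.json()
if "generated_image" in result:
    # Decode and save the returned base64 PNG.
    with open("result.png", "wb") as f:
        f.write(base64.b64decode(result["generated_image"]))
else:
    print(result.get("error"))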