cjwalch commited on
Commit
71b16c4
·
verified ·
1 Parent(s): c7cbf75

Upload 2 files

Browse files

added files for deploying to hf endpoints

Files changed (2) hide show
  1. handler.py +54 -0
  2. requirements.txt +5 -0
handler.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Any
2
+ import torch
3
+ import base64
4
+ import io
5
+ from PIL import Image
6
+ from diffusers import AutoPipelineForImage2Image
7
+
8
+ class EndpointHandler:
9
+ def __init__(self, path=""):
10
+ """Initialize the model from the given path."""
11
+ self.pipeline = AutoPipelineForImage2Image.from_pretrained(
12
+ "cjwalch/kandinsky-endpoint",
13
+ torch_dtype=torch.float16,
14
+ use_safetensors=True
15
+ )
16
+ self.pipeline.enable_model_cpu_offload()
17
+ if torch.cuda.is_available():
18
+ self.pipeline.to("cuda")
19
+
20
+ def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
21
+ """Run inference on the input image and return a base64-encoded result."""
22
+ try:
23
+ # Extract input parameters
24
+ prompt = data.get("inputs", "")
25
+ strength = float(data.get("strength", 0.6))
26
+ guidance_scale = float(data.get("guidance_scale", 7.0))
27
+ negative_prompt = data.get("negative_prompt", "blurry, ugly, deformed")
28
+
29
+ # Decode base64 image
30
+ init_image_b64 = data.get("init_image", None)
31
+ if not init_image_b64:
32
+ return {"error": "Missing 'init_image' in input data"}
33
+
34
+ image_bytes = base64.b64decode(init_image_b64)
35
+ init_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
36
+
37
+ # Generate output image
38
+ output_image = self.pipeline(
39
+ prompt=prompt,
40
+ image=init_image,
41
+ strength=strength,
42
+ guidance_scale=guidance_scale,
43
+ negative_prompt=negative_prompt
44
+ ).images[0]
45
+
46
+ # Convert to base64
47
+ buffered = io.BytesIO()
48
+ output_image.save(buffered, format="PNG")
49
+ img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
50
+
51
+ return {"generated_image": img_str}
52
+
53
+ except Exception as e:
54
+ return {"error": str(e)}
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ torch==2.5.1+cu124
2
+ torchvision==0.18.0+cu124
3
+ torchaudio==2.5.1+cu124
4
+ diffusers==0.17.0.dev0
5
+ Pillow==10.0.0