y22ma commited on
Commit
e55fddf
·
1 Parent(s): 6ba2ab1

Create handler.py

Browse files
Files changed (1) hide show
  1. handler.py +129 -0
handler.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any
2
+ import base64
3
+ from PIL import Image
4
+ from io import BytesIO
5
+ from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
6
+ import torch
7
+
8
+
9
+ import numpy as np
10
+ import cv2
11
+ import controlnet_hinter
12
+
13
+ # set device
14
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
15
+ if device.type != 'cuda':
16
+ raise ValueError("need to run on GPU")
17
+ # set mixed precision dtype
18
+ dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] == 8 else torch.float16
19
+
20
+ # controlnet mapping for controlnet id and control hinter
21
+ CONTROLNET_MAPPING = {
22
+ "canny_edge": {
23
+ "model_id": "lllyasviel/sd-controlnet-canny",
24
+ "hinter": controlnet_hinter.hint_canny
25
+ },
26
+ "pose": {
27
+ "model_id": "lllyasviel/sd-controlnet-openpose",
28
+ "hinter": controlnet_hinter.hint_openpose
29
+ },
30
+ "depth": {
31
+ "model_id": "lllyasviel/sd-controlnet-depth",
32
+ "hinter": controlnet_hinter.hint_depth
33
+ },
34
+ "scribble": {
35
+ "model_id": "lllyasviel/sd-controlnet-scribble",
36
+ "hinter": controlnet_hinter.hint_scribble,
37
+ },
38
+ "segmentation": {
39
+ "model_id": "lllyasviel/sd-controlnet-seg",
40
+ "hinter": controlnet_hinter.hint_segmentation,
41
+ },
42
+ "normal": {
43
+ "model_id": "lllyasviel/sd-controlnet-normal",
44
+ "hinter": controlnet_hinter.hint_normal,
45
+ },
46
+ "hed": {
47
+ "model_id": "lllyasviel/sd-controlnet-hed",
48
+ "hinter": controlnet_hinter.hint_hed,
49
+ },
50
+ "hough": {
51
+ "model_id": "lllyasviel/sd-controlnet-mlsd",
52
+ "hinter": controlnet_hinter.hint_hough,
53
+ }
54
+ }
55
+
56
+
57
+ class EndpointHandler():
58
+ def __init__(self, path=""):
59
+ # define default controlnet id and load controlnet
60
+ self.control_type = "normal"
61
+ self.controlnet = ControlNetModel.from_pretrained(CONTROLNET_MAPPING[self.control_type]["model_id"],torch_dtype=dtype).to(device)
62
+
63
+ # Load StableDiffusionControlNetPipeline
64
+ self.stable_diffusion_id = "runwayml/stable-diffusion-v1-5"
65
+ self.pipe = StableDiffusionControlNetPipeline.from_pretrained(self.stable_diffusion_id,
66
+ controlnet=self.controlnet,
67
+ torch_dtype=dtype,
68
+ safety_checker=None).to(device)
69
+ # Define Generator with seed
70
+ self.generator = torch.Generator(device="cpu").manual_seed(3)
71
+
72
+ def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
73
+ """
74
+ :param data: A dictionary contains `inputs` and optional `image` field.
75
+ :return: A dictionary with `image` field contains image in base64.
76
+ """
77
+ prompt = data.pop("inputs", None)
78
+ image = data.pop("image", None)
79
+ controlnet_type = data.pop("controlnet_type", None)
80
+
81
+ # Check if neither prompt nor image is provided
82
+ if prompt is None and image is None:
83
+ return {"error": "Please provide a prompt and base64 encoded image."}
84
+
85
+ # Check if a new controlnet is provided
86
+ if controlnet_type is not None and controlnet_type != self.control_type:
87
+ print(f"changing controlnet from {self.control_type} to {controlnet_type} using {CONTROLNET_MAPPING[controlnet_type]['model_id']} model")
88
+ self.control_type = controlnet_type
89
+ self.controlnet = ControlNetModel.from_pretrained(CONTROLNET_MAPPING[self.control_type]["model_id"],
90
+ torch_dtype=dtype).to(device)
91
+ self.pipe.controlnet = self.controlnet
92
+
93
+
94
+ # hyperparamters
95
+ num_inference_steps = data.pop("num_inference_steps", 30)
96
+ guidance_scale = data.pop("guidance_scale", 7.5)
97
+ negative_prompt = data.pop("negative_prompt", None)
98
+ height = data.pop("height", None)
99
+ width = data.pop("width", None)
100
+ controlnet_conditioning_scale = data.pop("controlnet_conditioning_scale", 1.0)
101
+
102
+ # process image
103
+ image = self.decode_base64_image(image)
104
+ control_image = CONTROLNET_MAPPING[self.control_type]["hinter"](image)
105
+
106
+ # run inference pipeline
107
+ out = self.pipe(
108
+ prompt=prompt,
109
+ negative_prompt=negative_prompt,
110
+ image=control_image,
111
+ num_inference_steps=num_inference_steps,
112
+ guidance_scale=guidance_scale,
113
+ num_images_per_prompt=1,
114
+ height=height,
115
+ width=width,
116
+ controlnet_conditioning_scale=controlnet_conditioning_scale,
117
+ generator=self.generator
118
+ )
119
+
120
+
121
+ # return first generate PIL image
122
+ return out.images[0]
123
+
124
+ # helper to decode input image
125
+ def decode_base64_image(self, image_string):
126
+ base64_image = base64.b64decode(image_string)
127
+ buffer = BytesIO(base64_image)
128
+ image = Image.open(buffer)
129
+ return image