toby007
committed on
Commit
·
3c71e14
1
Parent(s):
d74247d
update handler
Browse files- handler.py +55 -33
- local_test.py +38 -0
- model_index.json +29 -11
- requirements.txt +2 -0
handler.py
CHANGED
@@ -1,56 +1,78 @@
|
|
1 |
-
from diffusers import DiffusionPipeline
|
2 |
-
from diffusers.utils import load_image
|
3 |
-
from PIL import Image
|
4 |
-
import torch
|
5 |
import base64
|
6 |
from io import BytesIO
|
7 |
-
import
|
8 |
-
import os
|
9 |
-
from pathlib import Path
|
10 |
-
|
11 |
-
model_dir = Path(__file__).parent.resolve() # 获取handler.py所在目录的绝对路径
|
12 |
|
|
|
|
|
|
|
13 |
|
14 |
-
# 关键:注册 FluxFillPipeline 类(自动从 model_index.json 解析)
|
15 |
-
pipe = DiffusionPipeline.from_pretrained(
|
16 |
-
str(model_dir),
|
17 |
-
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
|
18 |
-
).to("cuda" if torch.cuda.is_available() else "cpu")
|
19 |
|
20 |
def decode_image(b64_string):
    """Decode a base64 string into a PIL image converted to RGB mode."""
    raw_bytes = base64.b64decode(b64_string)
    stream = BytesIO(raw_bytes)
    picture = Image.open(stream)
    return picture.convert("RGB")
|
23 |
|
|
|
24 |
def encode_image(image):
    """Save *image* as PNG in memory and return the bytes as base64 text."""
    png_stream = BytesIO()
    image.save(png_stream, format="PNG")
    png_bytes = png_stream.getvalue()
    return base64.b64encode(png_bytes).decode("utf-8")
|
28 |
|
29 |
-
def handler(data):
|
30 |
-
try:
|
31 |
-
inputs = data.get("inputs", {})
|
32 |
-
prompt = inputs.get("prompt", "写实风格形象照")
|
33 |
-
image_b64 = inputs.get("image")
|
34 |
-
mask_b64 = inputs.get("mask")
|
35 |
|
36 |
-
|
37 |
-
|
|
|
|
|
|
|
38 |
|
39 |
-
|
40 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
41 |
|
42 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
43 |
prompt=prompt,
|
44 |
image=image,
|
45 |
mask_image=mask,
|
46 |
-
|
47 |
-
width=int(inputs.get("width", 1024)),
|
48 |
-
guidance_scale=float(inputs.get("guidance_scale", 7.5)),
|
49 |
-
num_inference_steps=int(inputs.get("steps", 30))
|
50 |
).images[0]
|
51 |
|
52 |
-
return {"image": encode_image(
|
53 |
-
|
54 |
-
except Exception as e:
|
55 |
-
return {"error": str(e), "status": "failed"}
|
56 |
|
|
|
|
|
|
|
|
|
|
|
1 |
import base64
|
2 |
from io import BytesIO
|
3 |
+
from typing import Any, Dict
|
|
|
|
|
|
|
|
|
4 |
|
5 |
+
import torch
|
6 |
+
from diffusers import FluxFillPipeline
|
7 |
+
from PIL import Image
|
8 |
|
|
|
|
|
|
|
|
|
|
|
9 |
|
10 |
def decode_image(b64_string):
    """Decode a base64 string into a PIL image converted to RGB mode."""
    raw_bytes = base64.b64decode(b64_string)
    stream = BytesIO(raw_bytes)
    picture = Image.open(stream)
    return picture.convert("RGB")
|
13 |
|
14 |
+
|
15 |
def encode_image(image):
    """Save *image* as PNG in memory and return the bytes as base64 text."""
    png_stream = BytesIO()
    image.save(png_stream, format="PNG")
    png_bytes = png_stream.getvalue()
    return base64.b64encode(png_bytes).decode("utf-8")
|
19 |
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
|
21 |
+
class EndpointHandler:
    """Endpoint wrapper that runs a FLUX fill pipeline on base64 inputs.

    The pipeline is loaded once in ``__init__``. Each call decodes a base64
    image/mask pair, runs the pipeline with the merged parameters, and
    returns the first output image as base64-encoded PNG text.
    """

    def __init__(self, path="shangguanyanyan/flux1-fill-dev-custom"):
        """Load the pipeline and set the default call parameters.

        Args:
            path: Model location handed to
                ``FluxFillPipeline.from_pretrained``.
        """
        self.pipe = FluxFillPipeline.from_pretrained(
            path, torch_dtype=torch.bfloat16
        ).to("cuda" if torch.cuda.is_available() else "cpu")

        # Defaults applied to every parameter a request does not supply.
        # NOTE(review): the generator is created once and shared by all
        # calls, so its state carries over from one request to the next;
        # confirm whether per-request reseeding is wanted.
        self.parameters = {
            "height": 1632,
            "width": 1232,
            "guidance_scale": 30,
            "num_inference_steps": 50,
            "max_sequence_length": 512,
            "generator": torch.Generator("cpu").manual_seed(0),
        }

    def _merge_parameters(self, overrides):
        """Return the default parameters overlaid with *overrides*.

        Request values win over defaults. Neither ``self.parameters`` nor
        *overrides* is modified. A falsy *overrides* (``None``, ``{}``)
        yields a copy of the defaults.
        """
        merged = dict(self.parameters)
        if overrides:
            merged.update(overrides)
        return merged

    def __call__(self, data: Any) -> Dict[str, Any]:
        """
        data: {
            "inputs": {
                "image": base64_image,
                "mask": base64_mask,
                "prompt": prompt
            },
            "parameters": {
                "height": 1632,
                "width": 1232,
                "guidance_scale": 30,
                "num_inference_steps": 50,
                "max_sequence_length": 512,
            }
        }

        Returns ``{"image": <base64 PNG>, "status": "success"}`` on success,
        or ``{"error": <message>, "status": "failed"}`` when an input is
        missing or cannot be decoded. Values in ``parameters`` override the
        handler defaults. The caller's ``data`` is left unmodified.
        """
        # Read with .get so the caller's payload is not consumed.
        inputs = data.get("inputs", data)
        if not isinstance(inputs, dict):
            # Anything but a mapping cannot carry image/mask/prompt; fall
            # through to the missing-input error below.
            inputs = {}

        # Defaults first, request values on top, so callers can override.
        parameters = self._merge_parameters(data.get("parameters"))

        base64_image = inputs.get("image", "")
        base64_mask = inputs.get("mask", "")
        prompt = inputs.get("prompt", "")

        if not base64_image or not base64_mask or not prompt:
            return {
                "error": "Please provide image, mask and prompt",
                "status": "failed",
            }

        try:
            image = decode_image(base64_image)
            mask = decode_image(base64_mask)
        except (ValueError, OSError) as exc:
            # ValueError covers malformed base64 (binascii.Error subclasses
            # it); OSError is presumably what PIL raises for unreadable
            # image data -- verify against the installed Pillow version.
            return {
                "error": f"Invalid image or mask: {exc}",
                "status": "failed",
            }

        image = self.pipe(
            prompt=prompt,
            image=image,
            mask_image=mask,
            **parameters,
        ).images[0]

        return {"image": encode_image(image), "status": "success"}
|
|
|
|
|
|
|
78 |
|
local_test.py
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from diffusers.utils import load_image
|
2 |
+
|
3 |
+
from handler import EndpointHandler, encode_image
|
4 |
+
|
5 |
+
# Build the handler from the model files in the current directory.
# Alternative: pass path="black-forest-labs/FLUX.1-Fill-dev" to load the
# upstream weights instead.
my_handler = EndpointHandler(path=".")

image = load_image("./cup.png")
mask = load_image("./cup_mask.png")
prompt = "a white paper cup"

# Request payload in the shape EndpointHandler.__call__ documents.
data = {
    "inputs": {
        "image": encode_image(image),
        "mask": encode_image(mask),
        "prompt": prompt,
    },
    "parameters": {
        "height": 1632,
        "width": 1232,
        "guidance_scale": 30,
        "num_inference_steps": 50,
        "max_sequence_length": 512,
    },
}

# Run the handler. The earlier debug `print("out")` / `exit()` pair ended
# the script before this call, so the handler was never exercised.
result = my_handler(data=data)

print("result:", result)
|
38 |
+
|
model_index.json
CHANGED
@@ -1,14 +1,32 @@
|
|
1 |
{
|
2 |
"_class_name": "FluxFillPipeline",
|
3 |
"_diffusers_version": "0.32.0.dev0",
|
4 |
-
"scheduler": [
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
"
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
"
|
13 |
-
|
14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
{
|
2 |
"_class_name": "FluxFillPipeline",
|
3 |
"_diffusers_version": "0.32.0.dev0",
|
4 |
+
"scheduler": [
|
5 |
+
"diffusers",
|
6 |
+
"FlowMatchEulerDiscreteScheduler"
|
7 |
+
],
|
8 |
+
"text_encoder": [
|
9 |
+
"transformers",
|
10 |
+
"CLIPTextModel"
|
11 |
+
],
|
12 |
+
"text_encoder_2": [
|
13 |
+
"transformers",
|
14 |
+
"T5EncoderModel"
|
15 |
+
],
|
16 |
+
"tokenizer": [
|
17 |
+
"transformers",
|
18 |
+
"CLIPTokenizer"
|
19 |
+
],
|
20 |
+
"tokenizer_2": [
|
21 |
+
"transformers",
|
22 |
+
"T5TokenizerFast"
|
23 |
+
],
|
24 |
+
"transformer": [
|
25 |
+
"diffusers",
|
26 |
+
"FluxTransformer2DModel"
|
27 |
+
],
|
28 |
+
"vae": [
|
29 |
+
"diffusers",
|
30 |
+
"AutoencoderKL"
|
31 |
+
]
|
32 |
+
}
|
requirements.txt
CHANGED
@@ -3,3 +3,5 @@ transformers
|
|
3 |
torch
|
4 |
accelerate
|
5 |
safetensors
|
|
|
|
|
|
3 |
torch
|
4 |
accelerate
|
5 |
safetensors
|
6 |
+
protobuf
|
7 |
+
sentencepiece
|