toby007 commited on
Commit
e51059c
·
1 Parent(s): 7dbdbc9

update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +50 -11
handler.py CHANGED
@@ -4,13 +4,18 @@ import torch
4
  import base64
5
  from io import BytesIO
6
 
 
7
  pipe = StableDiffusionInpaintPipeline.from_pretrained(
8
- "./", torch_dtype=torch.float16
 
9
  ).to("cuda" if torch.cuda.is_available() else "cpu")
10
 
11
  def decode_image(b64_string):
12
- image_data = base64.b64decode(b64_string)
13
- return Image.open(BytesIO(image_data)).convert("RGB")
 
 
 
14
 
15
  def encode_image(image):
16
  buffer = BytesIO()
@@ -18,13 +23,47 @@ def encode_image(image):
18
  return base64.b64encode(buffer.getvalue()).decode("utf-8")
19
 
20
  def handler(data):
21
- inputs = data.get("inputs", {})
22
- prompt = inputs.get("prompt", "商务风格形象照,高清写实,蓝色西装")
23
- image_b64 = inputs.get("image")
24
- mask_b64 = inputs.get("mask")
25
 
26
- image = decode_image(image_b64)
27
- mask = decode_image(mask_b64)
 
 
28
 
29
- result = pipe(prompt=prompt, image=image, mask_image=mask).images[0]
30
- return {"image": encode_image(result)}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  import base64
5
  from io import BytesIO
6
 
7
+ # ✅ 使用 huggingface repo id 形式加载(确保路径正确)
8
  pipe = StableDiffusionInpaintPipeline.from_pretrained(
9
+ "shangguanyanyan/flux1-fill-dev-custom", # 请确认仓库ID无误
10
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
11
  ).to("cuda" if torch.cuda.is_available() else "cpu")
12
 
13
  def decode_image(b64_string):
14
+ try:
15
+ image_data = base64.b64decode(b64_string)
16
+ return Image.open(BytesIO(image_data)).convert("RGB")
17
+ except Exception as e:
18
+ raise ValueError(f"解码图像失败: {str(e)}")
19
 
20
  def encode_image(image):
21
  buffer = BytesIO()
 
23
  return base64.b64encode(buffer.getvalue()).decode("utf-8")
24
 
25
  def handler(data):
26
+ try:
27
+ inputs = data.get("inputs", {})
28
+ prompt = inputs.get("prompt", "高清写实风格人物形象照")
 
29
 
30
+ image_b64 = inputs.get("image")
31
+ mask_b64 = inputs.get("mask")
32
+ if not image_b64 or not mask_b64:
33
+ raise ValueError("缺少必要的 image 或 mask 参数")
34
 
35
+ image = decode_image(image_b64)
36
+ mask = decode_image(mask_b64)
37
+
38
+ # 默认参数(支持调整)
39
+ height = int(inputs.get("height", image.height))
40
+ width = int(inputs.get("width", image.width))
41
+ steps = int(inputs.get("num_inference_steps", 30))
42
+ cfg_scale = float(inputs.get("guidance_scale", 7.5))
43
+
44
+ image = image.resize((width, height))
45
+ mask = mask.resize((width, height))
46
+
47
+ result = pipe(
48
+ prompt=prompt,
49
+ image=image,
50
+ mask_image=mask,
51
+ height=height,
52
+ width=width,
53
+ num_inference_steps=steps,
54
+ guidance_scale=cfg_scale
55
+ ).images[0]
56
+
57
+ return {
58
+ "status": "success",
59
+ "image": encode_image(result),
60
+ "meta": {
61
+ "prompt": prompt,
62
+ "size": f"{width}x{height}",
63
+ "steps": steps,
64
+ "cfg_scale": cfg_scale
65
+ }
66
+ }
67
+
68
+ except Exception as e:
69
+ return {"status": "failed", "error": str(e)}