mrcuddle commited on
Commit
66a2908
·
verified ·
1 Parent(s): 074f00a

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +37 -23
handler.py CHANGED
@@ -23,23 +23,32 @@ class EndpointHandler:
23
  def __call__(self, data: dict):
24
  """Custom call function for Hugging Face Inference Endpoints."""
25
  try:
26
- inputs = data.pop("inputs", data)
27
- encoded_image = data.pop("image", None)
28
- encoded_mask_image = data.pop("mask_image", None)
29
-
30
- num_inference_steps = data.pop("num_inference_steps", 25)
31
- guidance_scale = data.pop("guidance_scale", 7.5)
32
- negative_prompt = data.pop("negative_prompt", None)
33
- height = data.pop("height", None)
34
- width = data.pop("width", None)
35
-
36
- # Process images
37
- if encoded_image and encoded_mask_image:
38
- image = self.decode_base64_image(encoded_image)
39
- mask_image = self.decode_base64_image(encoded_mask_image)
40
- else:
41
- raise ValueError("Both image and mask_image are required")
42
-
 
 
 
 
 
 
 
 
 
43
  # Run inference
44
  output_image = self.pipeline(
45
  prompt=inputs,
@@ -52,17 +61,22 @@ class EndpointHandler:
52
  height=height,
53
  width=width
54
  ).images[0]
55
-
 
56
  return json.dumps({"output": self.encode_base64_image(output_image)})
 
57
  except Exception as e:
58
  return json.dumps({"error": str(e)})
59
 
60
  def decode_base64_image(self, image_string):
61
- """Decode base64 encoded image."""
62
- base64_image = base64.b64decode(image_string)
63
- buffer = io.BytesIO(base64_image)
64
- return Image.open(buffer).convert("RGB")
65
-
 
 
 
66
  def encode_base64_image(self, image):
67
  """Encode PIL image to base64."""
68
  buffered = io.BytesIO()
 
23
  def __call__(self, data: dict):
24
  """Custom call function for Hugging Face Inference Endpoints."""
25
  try:
26
+ # Extract inputs from JSON payload
27
+ inputs = data.get("inputs", "")
28
+ encoded_image = data.get("image", None)
29
+ encoded_mask_image = data.get("mask_image", None)
30
+
31
+ # Extract optional parameters with default values
32
+ num_inference_steps = data.get("num_inference_steps", 25)
33
+ guidance_scale = data.get("guidance_scale", 7.5)
34
+ negative_prompt = data.get("negative_prompt", None)
35
+ height = data.get("height", None)
36
+ width = data.get("width", None)
37
+
38
+ # Ensure both images are provided
39
+ if not encoded_image or not encoded_mask_image:
40
+ raise ValueError("Both 'image' and 'mask_image' are required in base64 format.")
41
+
42
+ # Decode base64 images
43
+ image = self.decode_base64_image(encoded_image)
44
+ mask_image = self.decode_base64_image(encoded_mask_image)
45
+
46
+ print("\n--- Running Inference ---")
47
+ print(f"Prompt: {inputs}")
48
+ print(f"Steps: {num_inference_steps}, Guidance Scale: {guidance_scale}")
49
+ print(f"Negative Prompt: {negative_prompt}")
50
+ print(f"Image Size: {image.size}, Mask Size: {mask_image.size}")
51
+
52
  # Run inference
53
  output_image = self.pipeline(
54
  prompt=inputs,
 
61
  height=height,
62
  width=width
63
  ).images[0]
64
+
65
+ # Return base64-encoded image
66
  return json.dumps({"output": self.encode_base64_image(output_image)})
67
+
68
  except Exception as e:
69
  return json.dumps({"error": str(e)})
70
 
71
  def decode_base64_image(self, image_string):
72
+ """Decode base64-encoded image to a PIL Image."""
73
+ try:
74
+ base64_image = base64.b64decode(image_string)
75
+ buffer = io.BytesIO(base64_image)
76
+ return Image.open(buffer).convert("RGB")
77
+ except Exception as e:
78
+ raise ValueError(f"Failed to decode base64 image: {e}")
79
+
80
  def encode_base64_image(self, image):
81
  """Encode PIL image to base64."""
82
  buffered = io.BytesIO()