toby007 commited on
Commit
5eb1745
·
1 Parent(s): 92fc27b

update test script and ignore file

Browse files
Files changed (2) hide show
  1. .gitignore +1 -0
  2. huggingface_inference.py +152 -0
.gitignore CHANGED
@@ -1 +1,2 @@
1
  venv
 
 
1
  venv
2
+ __pycache__
huggingface_inference.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ import base64
3
+ import os
4
+ from PIL import Image
5
+ import io
6
+ import time
7
+ import datetime
8
+
9
+ def encode_image_to_base64(image_path):
10
+ """Convert an image file to base64 string."""
11
+ with open(image_path, "rb") as image_file:
12
+ return base64.b64encode(image_file.read()).decode('utf-8')
13
+
14
+ def process_image(
15
+ image_path,
16
+ mask_path=None,
17
+ prompt="",
18
+ height=1632,
19
+ width=1232,
20
+ guidance_scale=30,
21
+ num_inference_steps=50,
22
+ max_sequence_length=512,
23
+ api_token=None,
24
+ output_path="output_image.jpg"
25
+ ):
26
+ """
27
+ Send a request to the Hugging Face Inference Endpoint.
28
+
29
+ Args:
30
+ image_path (str): Path to the input image
31
+ mask_path (str, optional): Path to the mask image
32
+ prompt (str): Text prompt to guide the model
33
+ height (int): Output image height
34
+ width (int): Output image width
35
+ guidance_scale (float): Guidance scale for the model
36
+ num_inference_steps (int): Number of inference steps
37
+ max_sequence_length (int): Maximum sequence length
38
+ api_token (str): Hugging Face API token
39
+ output_path (str): Path to save the output image
40
+
41
+ Returns:
42
+ The response from the API or the path to the saved image
43
+ """
44
+ # Log start time
45
+ start_time = time.time()
46
+ start_timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")
47
+ print(f"Request started at: {start_timestamp}")
48
+
49
+ # Endpoint URL
50
+ url = "https://pewtn8bsankvfriv.us-east-1.aws.endpoints.huggingface.cloud"
51
+
52
+ # Get API token (from environment variable or parameter)
53
+ if api_token is None:
54
+ api_token = os.environ.get("HF_API_TOKEN")
55
+ if api_token is None:
56
+ raise ValueError("API token not provided. Please set HF_API_TOKEN environment variable or pass it as a parameter.")
57
+
58
+ # Encode image to base64
59
+ base64_image = encode_image_to_base64(image_path)
60
+
61
+ # Encode mask to base64 if provided
62
+ base64_mask = None
63
+ if mask_path:
64
+ base64_mask = encode_image_to_base64(mask_path)
65
+
66
+ # Prepare payload
67
+ payload = {
68
+ "inputs": {
69
+ "image": base64_image,
70
+ "prompt": prompt
71
+ },
72
+ "parameters": {
73
+ "height": height,
74
+ "width": width,
75
+ "guidance_scale": guidance_scale,
76
+ "num_inference_steps": num_inference_steps,
77
+ "max_sequence_length": max_sequence_length,
78
+ }
79
+ }
80
+
81
+ # Add mask if provided
82
+ if base64_mask:
83
+ payload["inputs"]["mask"] = base64_mask
84
+
85
+ # Set up headers
86
+ headers = {
87
+ "Authorization": f"Bearer {api_token}",
88
+ "Content-Type": "application/json"
89
+ }
90
+
91
+ # Make the request
92
+ print(f"Sending request to {url}...")
93
+ response = requests.post(url, json=payload, headers=headers)
94
+
95
+ # Log end time
96
+ end_time = time.time()
97
+ end_timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")
98
+ elapsed_time = end_time - start_time
99
+ print(f"Request completed at: {end_timestamp}")
100
+ print(f"Total request time: {elapsed_time:.2f} seconds ({elapsed_time/60:.2f} minutes)")
101
+
102
+ # Handle the response
103
+ if response.status_code == 200:
104
+ try:
105
+ response_data = response.json()
106
+
107
+ # Check if response has the expected format
108
+ if "image" in response_data and "status" in response_data and response_data["status"] == "success":
109
+ # Decode the base64 image
110
+ image_data = base64.b64decode(response_data["image"])
111
+
112
+ # Convert to PIL Image
113
+ image = Image.open(io.BytesIO(image_data))
114
+
115
+ # Save the image to the specified output path
116
+ image.save(output_path)
117
+ print(f"Image successfully saved to {output_path}")
118
+
119
+ return output_path
120
+ else:
121
+ print("Unexpected response format:", response_data)
122
+ return response_data
123
+ except Exception as e:
124
+ print(f"Error processing response: {e}")
125
+ return response.json()
126
+ else:
127
+ print(f"Error: {response.status_code}")
128
+ print(response.text)
129
+ return None
130
+
131
+ if __name__ == "__main__":
132
+ # Example usage
133
+ output_file = "generated_image1.jpg"
134
+
135
+ # 从环境变量获取API令牌,或者提示用户输入
136
+ api_token = os.environ.get("HF_API_TOKEN")
137
+ if not api_token:
138
+ api_token = input("请输入您的Hugging Face API令牌: ")
139
+
140
+ print("Starting image processing...")
141
+ result = process_image(
142
+ image_path="cup.png",
143
+ mask_path="cup_mask.png", # Optional
144
+ prompt="a blue paper cup",
145
+ api_token=api_token,
146
+ output_path=output_file
147
+ )
148
+
149
+ if result == output_file:
150
+ print(f"Processing completed successfully. Image saved to {output_file}")
151
+ else:
152
+ print("Processing completed with unexpected result:", result)