toby007
commited on
Commit
·
5eb1745
1
Parent(s):
92fc27b
update test script and ignore file
Browse files- .gitignore +1 -0
- huggingface_inference.py +152 -0
.gitignore
CHANGED
@@ -1 +1,2 @@
|
|
1 |
venv
|
|
|
|
1 |
venv
|
2 |
+
__pycache__
|
huggingface_inference.py
ADDED
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import requests
|
2 |
+
import base64
|
3 |
+
import os
|
4 |
+
from PIL import Image
|
5 |
+
import io
|
6 |
+
import time
|
7 |
+
import datetime
|
8 |
+
|
9 |
+
def encode_image_to_base64(image_path):
|
10 |
+
"""Convert an image file to base64 string."""
|
11 |
+
with open(image_path, "rb") as image_file:
|
12 |
+
return base64.b64encode(image_file.read()).decode('utf-8')
|
13 |
+
|
14 |
+
def process_image(
|
15 |
+
image_path,
|
16 |
+
mask_path=None,
|
17 |
+
prompt="",
|
18 |
+
height=1632,
|
19 |
+
width=1232,
|
20 |
+
guidance_scale=30,
|
21 |
+
num_inference_steps=50,
|
22 |
+
max_sequence_length=512,
|
23 |
+
api_token=None,
|
24 |
+
output_path="output_image.jpg"
|
25 |
+
):
|
26 |
+
"""
|
27 |
+
Send a request to the Hugging Face Inference Endpoint.
|
28 |
+
|
29 |
+
Args:
|
30 |
+
image_path (str): Path to the input image
|
31 |
+
mask_path (str, optional): Path to the mask image
|
32 |
+
prompt (str): Text prompt to guide the model
|
33 |
+
height (int): Output image height
|
34 |
+
width (int): Output image width
|
35 |
+
guidance_scale (float): Guidance scale for the model
|
36 |
+
num_inference_steps (int): Number of inference steps
|
37 |
+
max_sequence_length (int): Maximum sequence length
|
38 |
+
api_token (str): Hugging Face API token
|
39 |
+
output_path (str): Path to save the output image
|
40 |
+
|
41 |
+
Returns:
|
42 |
+
The response from the API or the path to the saved image
|
43 |
+
"""
|
44 |
+
# Log start time
|
45 |
+
start_time = time.time()
|
46 |
+
start_timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")
|
47 |
+
print(f"Request started at: {start_timestamp}")
|
48 |
+
|
49 |
+
# Endpoint URL
|
50 |
+
url = "https://pewtn8bsankvfriv.us-east-1.aws.endpoints.huggingface.cloud"
|
51 |
+
|
52 |
+
# Get API token (from environment variable or parameter)
|
53 |
+
if api_token is None:
|
54 |
+
api_token = os.environ.get("HF_API_TOKEN")
|
55 |
+
if api_token is None:
|
56 |
+
raise ValueError("API token not provided. Please set HF_API_TOKEN environment variable or pass it as a parameter.")
|
57 |
+
|
58 |
+
# Encode image to base64
|
59 |
+
base64_image = encode_image_to_base64(image_path)
|
60 |
+
|
61 |
+
# Encode mask to base64 if provided
|
62 |
+
base64_mask = None
|
63 |
+
if mask_path:
|
64 |
+
base64_mask = encode_image_to_base64(mask_path)
|
65 |
+
|
66 |
+
# Prepare payload
|
67 |
+
payload = {
|
68 |
+
"inputs": {
|
69 |
+
"image": base64_image,
|
70 |
+
"prompt": prompt
|
71 |
+
},
|
72 |
+
"parameters": {
|
73 |
+
"height": height,
|
74 |
+
"width": width,
|
75 |
+
"guidance_scale": guidance_scale,
|
76 |
+
"num_inference_steps": num_inference_steps,
|
77 |
+
"max_sequence_length": max_sequence_length,
|
78 |
+
}
|
79 |
+
}
|
80 |
+
|
81 |
+
# Add mask if provided
|
82 |
+
if base64_mask:
|
83 |
+
payload["inputs"]["mask"] = base64_mask
|
84 |
+
|
85 |
+
# Set up headers
|
86 |
+
headers = {
|
87 |
+
"Authorization": f"Bearer {api_token}",
|
88 |
+
"Content-Type": "application/json"
|
89 |
+
}
|
90 |
+
|
91 |
+
# Make the request
|
92 |
+
print(f"Sending request to {url}...")
|
93 |
+
response = requests.post(url, json=payload, headers=headers)
|
94 |
+
|
95 |
+
# Log end time
|
96 |
+
end_time = time.time()
|
97 |
+
end_timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")
|
98 |
+
elapsed_time = end_time - start_time
|
99 |
+
print(f"Request completed at: {end_timestamp}")
|
100 |
+
print(f"Total request time: {elapsed_time:.2f} seconds ({elapsed_time/60:.2f} minutes)")
|
101 |
+
|
102 |
+
# Handle the response
|
103 |
+
if response.status_code == 200:
|
104 |
+
try:
|
105 |
+
response_data = response.json()
|
106 |
+
|
107 |
+
# Check if response has the expected format
|
108 |
+
if "image" in response_data and "status" in response_data and response_data["status"] == "success":
|
109 |
+
# Decode the base64 image
|
110 |
+
image_data = base64.b64decode(response_data["image"])
|
111 |
+
|
112 |
+
# Convert to PIL Image
|
113 |
+
image = Image.open(io.BytesIO(image_data))
|
114 |
+
|
115 |
+
# Save the image to the specified output path
|
116 |
+
image.save(output_path)
|
117 |
+
print(f"Image successfully saved to {output_path}")
|
118 |
+
|
119 |
+
return output_path
|
120 |
+
else:
|
121 |
+
print("Unexpected response format:", response_data)
|
122 |
+
return response_data
|
123 |
+
except Exception as e:
|
124 |
+
print(f"Error processing response: {e}")
|
125 |
+
return response.json()
|
126 |
+
else:
|
127 |
+
print(f"Error: {response.status_code}")
|
128 |
+
print(response.text)
|
129 |
+
return None
|
130 |
+
|
131 |
+
if __name__ == "__main__":
|
132 |
+
# Example usage
|
133 |
+
output_file = "generated_image1.jpg"
|
134 |
+
|
135 |
+
# 从环境变量获取API令牌,或者提示用户输入
|
136 |
+
api_token = os.environ.get("HF_API_TOKEN")
|
137 |
+
if not api_token:
|
138 |
+
api_token = input("请输入您的Hugging Face API令牌: ")
|
139 |
+
|
140 |
+
print("Starting image processing...")
|
141 |
+
result = process_image(
|
142 |
+
image_path="cup.png",
|
143 |
+
mask_path="cup_mask.png", # Optional
|
144 |
+
prompt="a blue paper cup",
|
145 |
+
api_token=api_token,
|
146 |
+
output_path=output_file
|
147 |
+
)
|
148 |
+
|
149 |
+
if result == output_file:
|
150 |
+
print(f"Processing completed successfully. Image saved to {output_file}")
|
151 |
+
else:
|
152 |
+
print("Processing completed with unexpected result:", result)
|