Update app.py
Browse files
app.py
CHANGED
@@ -8,7 +8,6 @@ from tqdm import tqdm
|
|
8 |
import gradio as gr
|
9 |
import base64
|
10 |
import io
|
11 |
-
from fastapi import Request
|
12 |
|
13 |
from safetensors.torch import save_file
|
14 |
from src.pipeline import FluxPipeline
|
@@ -20,7 +19,7 @@ base_path = "black-forest-labs/FLUX.1-dev"
|
|
20 |
lora_base_path = "./models"
|
21 |
|
22 |
# Environment variable for API token (set this in your Hugging Face space settings)
|
23 |
-
API_TOKEN = os.environ.get("
|
24 |
|
25 |
# Initialize the pipeline
|
26 |
pipe = FluxPipeline.from_pretrained(base_path, torch_dtype=torch.bfloat16)
|
@@ -39,21 +38,10 @@ def verify_token(token):
|
|
39 |
|
40 |
# Define the Gradio interface with token verification
|
41 |
@spaces.GPU()
|
42 |
-
def single_condition_generate_image(prompt, spatial_img, height, width, seed, control_type, api_token="", request
|
43 |
-
# Check authentication
|
44 |
-
if not
|
45 |
-
|
46 |
-
auth_header = request.headers.get("Authorization")
|
47 |
-
if not auth_header or not auth_header.startswith("Bearer "):
|
48 |
-
return {"error": "Unauthorized access. Invalid or missing token in Authorization header."}
|
49 |
-
|
50 |
-
token = auth_header.replace("Bearer ", "")
|
51 |
-
if not verify_token(token):
|
52 |
-
return {"error": "Unauthorized access. Invalid token."}
|
53 |
-
else:
|
54 |
-
# For UI requests, check the token input
|
55 |
-
if not verify_token(api_token):
|
56 |
-
return "Unauthorized: Please enter a valid API token"
|
57 |
|
58 |
try:
|
59 |
# Set the control type
|
@@ -77,26 +65,45 @@ def single_condition_generate_image(prompt, spatial_img, height, width, seed, co
|
|
77 |
).images[0]
|
78 |
clear_cache(pipe.transformer)
|
79 |
|
80 |
-
#
|
81 |
-
|
82 |
-
buffered = io.BytesIO()
|
83 |
-
image.save(buffered, format="PNG")
|
84 |
-
img_str = base64.b64encode(buffered.getvalue()).decode()
|
85 |
-
return {"meta": {"format": "png", "base64": img_str}}
|
86 |
-
|
87 |
return image
|
88 |
|
89 |
except Exception as e:
|
90 |
error_msg = f"Error during image generation: {str(e)}"
|
91 |
print(error_msg)
|
92 |
-
if not request.is_from_ui:
|
93 |
-
return {"error": error_msg}
|
94 |
return None
|
95 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
96 |
# Define the Gradio interface components
|
97 |
control_types = ["Ghibli"]
|
98 |
|
99 |
-
# Example data
|
100 |
single_examples = [
|
101 |
["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/00.png"), 680, 1024, 5, "Ghibli", API_TOKEN],
|
102 |
["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/02.png"), 560, 1024, 42, "Ghibli", API_TOKEN],
|
@@ -151,5 +158,14 @@ with gr.Blocks() as demo:
|
|
151 |
outputs=single_output_image
|
152 |
)
|
153 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
154 |
# Launch the Gradio app
|
155 |
-
demo.
|
|
|
8 |
import gradio as gr
|
9 |
import base64
|
10 |
import io
|
|
|
11 |
|
12 |
from safetensors.torch import save_file
|
13 |
from src.pipeline import FluxPipeline
|
|
|
19 |
lora_base_path = "./models"
|
20 |
|
21 |
# Environment variable for API token (set this in your Hugging Face space settings)
|
22 |
+
API_TOKEN = os.environ.get("HF_API_TOKEN")
|
23 |
|
24 |
# Initialize the pipeline
|
25 |
pipe = FluxPipeline.from_pretrained(base_path, torch_dtype=torch.bfloat16)
|
|
|
38 |
|
39 |
# Define the Gradio interface with token verification
|
40 |
@spaces.GPU()
|
41 |
+
def single_condition_generate_image(prompt, spatial_img, height, width, seed, control_type, api_token="", request=None):
|
42 |
+
# Check authentication
|
43 |
+
if not verify_token(api_token):
|
44 |
+
return "Unauthorized: Please enter a valid API token"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
45 |
|
46 |
try:
|
47 |
# Set the control type
|
|
|
65 |
).images[0]
|
66 |
clear_cache(pipe.transformer)
|
67 |
|
68 |
+
# We'll always return the PIL image for UI
|
69 |
+
# The API will extract base64 from the returned image
|
|
|
|
|
|
|
|
|
|
|
70 |
return image
|
71 |
|
72 |
except Exception as e:
|
73 |
error_msg = f"Error during image generation: {str(e)}"
|
74 |
print(error_msg)
|
|
|
|
|
75 |
return None
|
76 |
|
77 |
+
# Define an API endpoint that uses the main function but returns proper JSON
|
78 |
+
@spaces.GPU()
|
79 |
+
def api_generate_image(prompt, spatial_img, height, width, seed, control_type, api_token=""):
|
80 |
+
# Verify the API token
|
81 |
+
if not verify_token(api_token):
|
82 |
+
return {"error": "Unauthorized access. Invalid token."}
|
83 |
+
|
84 |
+
try:
|
85 |
+
# Use the same function as the UI
|
86 |
+
image = single_condition_generate_image(
|
87 |
+
prompt, spatial_img, height, width, seed, control_type, api_token
|
88 |
+
)
|
89 |
+
|
90 |
+
if image is None or isinstance(image, str):
|
91 |
+
# Error occurred
|
92 |
+
error_msg = image if isinstance(image, str) else "Image generation failed"
|
93 |
+
return {"error": error_msg}
|
94 |
+
|
95 |
+
# Return the image directly instead of converting to base64
|
96 |
+
return image
|
97 |
+
|
98 |
+
except Exception as e:
|
99 |
+
error_msg = f"API error: {str(e)}"
|
100 |
+
print(error_msg)
|
101 |
+
return {"error": error_msg}
|
102 |
+
|
103 |
# Define the Gradio interface components
|
104 |
control_types = ["Ghibli"]
|
105 |
|
106 |
+
# Example data - add the API token for convenience (assuming only you can see the examples)
|
107 |
single_examples = [
|
108 |
["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/00.png"), 680, 1024, 5, "Ghibli", API_TOKEN],
|
109 |
["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/02.png"), 560, 1024, 42, "Ghibli", API_TOKEN],
|
|
|
158 |
outputs=single_output_image
|
159 |
)
|
160 |
|
161 |
+
# Create an API endpoint that clients can use programmatically
|
162 |
+
demo.queue()
|
163 |
+
|
164 |
+
# Add the API endpoint
|
165 |
+
demo.load(api_generate_image,
|
166 |
+
inputs=[prompt, spatial_img, height, width, seed, control_type, api_token],
|
167 |
+
outputs=gr.JSON(),
|
168 |
+
api_name="generate")
|
169 |
+
|
170 |
# Launch the Gradio app
|
171 |
+
demo.launch()
|