Update app.py
Browse files
app.py
CHANGED
|
@@ -8,7 +8,6 @@ from tqdm import tqdm
|
|
| 8 |
import gradio as gr
|
| 9 |
import base64
|
| 10 |
import io
|
| 11 |
-
from fastapi import Request
|
| 12 |
|
| 13 |
from safetensors.torch import save_file
|
| 14 |
from src.pipeline import FluxPipeline
|
|
@@ -20,7 +19,7 @@ base_path = "black-forest-labs/FLUX.1-dev"
|
|
| 20 |
lora_base_path = "./models"
|
| 21 |
|
| 22 |
# Environment variable for API token (set this in your Hugging Face space settings)
|
| 23 |
-
API_TOKEN = os.environ.get("
|
| 24 |
|
| 25 |
# Initialize the pipeline
|
| 26 |
pipe = FluxPipeline.from_pretrained(base_path, torch_dtype=torch.bfloat16)
|
|
@@ -39,21 +38,10 @@ def verify_token(token):
|
|
| 39 |
|
| 40 |
# Define the Gradio interface with token verification
|
| 41 |
@spaces.GPU()
|
| 42 |
-
def single_condition_generate_image(prompt, spatial_img, height, width, seed, control_type, api_token="", request
|
| 43 |
-
# Check authentication
|
| 44 |
-
if not
|
| 45 |
-
|
| 46 |
-
auth_header = request.headers.get("Authorization")
|
| 47 |
-
if not auth_header or not auth_header.startswith("Bearer "):
|
| 48 |
-
return {"error": "Unauthorized access. Invalid or missing token in Authorization header."}
|
| 49 |
-
|
| 50 |
-
token = auth_header.replace("Bearer ", "")
|
| 51 |
-
if not verify_token(token):
|
| 52 |
-
return {"error": "Unauthorized access. Invalid token."}
|
| 53 |
-
else:
|
| 54 |
-
# For UI requests, check the token input
|
| 55 |
-
if not verify_token(api_token):
|
| 56 |
-
return "Unauthorized: Please enter a valid API token"
|
| 57 |
|
| 58 |
try:
|
| 59 |
# Set the control type
|
|
@@ -77,26 +65,45 @@ def single_condition_generate_image(prompt, spatial_img, height, width, seed, co
|
|
| 77 |
).images[0]
|
| 78 |
clear_cache(pipe.transformer)
|
| 79 |
|
| 80 |
-
#
|
| 81 |
-
|
| 82 |
-
buffered = io.BytesIO()
|
| 83 |
-
image.save(buffered, format="PNG")
|
| 84 |
-
img_str = base64.b64encode(buffered.getvalue()).decode()
|
| 85 |
-
return {"meta": {"format": "png", "base64": img_str}}
|
| 86 |
-
|
| 87 |
return image
|
| 88 |
|
| 89 |
except Exception as e:
|
| 90 |
error_msg = f"Error during image generation: {str(e)}"
|
| 91 |
print(error_msg)
|
| 92 |
-
if not request.is_from_ui:
|
| 93 |
-
return {"error": error_msg}
|
| 94 |
return None
|
| 95 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 96 |
# Define the Gradio interface components
|
| 97 |
control_types = ["Ghibli"]
|
| 98 |
|
| 99 |
-
# Example data
|
| 100 |
single_examples = [
|
| 101 |
["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/00.png"), 680, 1024, 5, "Ghibli", API_TOKEN],
|
| 102 |
["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/02.png"), 560, 1024, 42, "Ghibli", API_TOKEN],
|
|
@@ -151,5 +158,14 @@ with gr.Blocks() as demo:
|
|
| 151 |
outputs=single_output_image
|
| 152 |
)
|
| 153 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 154 |
# Launch the Gradio app
|
| 155 |
-
demo.
|
|
|
|
| 8 |
import gradio as gr
|
| 9 |
import base64
|
| 10 |
import io
|
|
|
|
| 11 |
|
| 12 |
from safetensors.torch import save_file
|
| 13 |
from src.pipeline import FluxPipeline
|
|
|
|
| 19 |
lora_base_path = "./models"
|
| 20 |
|
| 21 |
# Environment variable for API token (set this in your Hugging Face space settings)
|
| 22 |
+
API_TOKEN = os.environ.get("HF_API_TOKEN")
|
| 23 |
|
| 24 |
# Initialize the pipeline
|
| 25 |
pipe = FluxPipeline.from_pretrained(base_path, torch_dtype=torch.bfloat16)
|
|
|
|
| 38 |
|
| 39 |
# Define the Gradio interface with token verification
|
| 40 |
@spaces.GPU()
|
| 41 |
+
def single_condition_generate_image(prompt, spatial_img, height, width, seed, control_type, api_token="", request=None):
|
| 42 |
+
# Check authentication
|
| 43 |
+
if not verify_token(api_token):
|
| 44 |
+
return "Unauthorized: Please enter a valid API token"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
|
| 46 |
try:
|
| 47 |
# Set the control type
|
|
|
|
| 65 |
).images[0]
|
| 66 |
clear_cache(pipe.transformer)
|
| 67 |
|
| 68 |
+
# We'll always return the PIL image for UI
|
| 69 |
+
# The API will extract base64 from the returned image
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
return image
|
| 71 |
|
| 72 |
except Exception as e:
|
| 73 |
error_msg = f"Error during image generation: {str(e)}"
|
| 74 |
print(error_msg)
|
|
|
|
|
|
|
| 75 |
return None
|
| 76 |
|
| 77 |
+
# Define an API endpoint that uses the main function but returns proper JSON
|
| 78 |
+
@spaces.GPU()
|
| 79 |
+
def api_generate_image(prompt, spatial_img, height, width, seed, control_type, api_token=""):
|
| 80 |
+
# Verify the API token
|
| 81 |
+
if not verify_token(api_token):
|
| 82 |
+
return {"error": "Unauthorized access. Invalid token."}
|
| 83 |
+
|
| 84 |
+
try:
|
| 85 |
+
# Use the same function as the UI
|
| 86 |
+
image = single_condition_generate_image(
|
| 87 |
+
prompt, spatial_img, height, width, seed, control_type, api_token
|
| 88 |
+
)
|
| 89 |
+
|
| 90 |
+
if image is None or isinstance(image, str):
|
| 91 |
+
# Error occurred
|
| 92 |
+
error_msg = image if isinstance(image, str) else "Image generation failed"
|
| 93 |
+
return {"error": error_msg}
|
| 94 |
+
|
| 95 |
+
# Return the image directly instead of converting to base64
|
| 96 |
+
return image
|
| 97 |
+
|
| 98 |
+
except Exception as e:
|
| 99 |
+
error_msg = f"API error: {str(e)}"
|
| 100 |
+
print(error_msg)
|
| 101 |
+
return {"error": error_msg}
|
| 102 |
+
|
| 103 |
# Define the Gradio interface components
|
| 104 |
control_types = ["Ghibli"]
|
| 105 |
|
| 106 |
+
# Example data - add the API token for convenience (assuming only you can see the examples)
|
| 107 |
single_examples = [
|
| 108 |
["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/00.png"), 680, 1024, 5, "Ghibli", API_TOKEN],
|
| 109 |
["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/02.png"), 560, 1024, 42, "Ghibli", API_TOKEN],
|
|
|
|
| 158 |
outputs=single_output_image
|
| 159 |
)
|
| 160 |
|
| 161 |
+
# Create an API endpoint that clients can use programmatically
|
| 162 |
+
demo.queue()
|
| 163 |
+
|
| 164 |
+
# Add the API endpoint
|
| 165 |
+
demo.load(api_generate_image,
|
| 166 |
+
inputs=[prompt, spatial_img, height, width, seed, control_type, api_token],
|
| 167 |
+
outputs=gr.JSON(),
|
| 168 |
+
api_name="generate")
|
| 169 |
+
|
| 170 |
# Launch the Gradio app
|
| 171 |
+
demo.launch()
|