arrafaqat commited on
Commit
a49b272
·
verified ·
1 Parent(s): 9300754

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -28
app.py CHANGED
@@ -8,7 +8,6 @@ from tqdm import tqdm
8
  import gradio as gr
9
  import base64
10
  import io
11
- from fastapi import Request
12
 
13
  from safetensors.torch import save_file
14
  from src.pipeline import FluxPipeline
@@ -20,7 +19,7 @@ base_path = "black-forest-labs/FLUX.1-dev"
20
  lora_base_path = "./models"
21
 
22
  # Environment variable for API token (set this in your Hugging Face space settings)
23
- API_TOKEN = os.environ.get("HF_TOKEN")
24
 
25
  # Initialize the pipeline
26
  pipe = FluxPipeline.from_pretrained(base_path, torch_dtype=torch.bfloat16)
@@ -39,21 +38,10 @@ def verify_token(token):
39
 
40
  # Define the Gradio interface with token verification
41
  @spaces.GPU()
42
- def single_condition_generate_image(prompt, spatial_img, height, width, seed, control_type, api_token="", request: gr.Request = None):
43
- # Check authentication for API requests
44
- if not request.is_from_ui:
45
- # For API requests, check Authorization header
46
- auth_header = request.headers.get("Authorization")
47
- if not auth_header or not auth_header.startswith("Bearer "):
48
- return {"error": "Unauthorized access. Invalid or missing token in Authorization header."}
49
-
50
- token = auth_header.replace("Bearer ", "")
51
- if not verify_token(token):
52
- return {"error": "Unauthorized access. Invalid token."}
53
- else:
54
- # For UI requests, check the token input
55
- if not verify_token(api_token):
56
- return "Unauthorized: Please enter a valid API token"
57
 
58
  try:
59
  # Set the control type
@@ -77,26 +65,45 @@ def single_condition_generate_image(prompt, spatial_img, height, width, seed, co
77
  ).images[0]
78
  clear_cache(pipe.transformer)
79
 
80
- # Convert to base64 for API responses
81
- if not request.is_from_ui:
82
- buffered = io.BytesIO()
83
- image.save(buffered, format="PNG")
84
- img_str = base64.b64encode(buffered.getvalue()).decode()
85
- return {"meta": {"format": "png", "base64": img_str}}
86
-
87
  return image
88
 
89
  except Exception as e:
90
  error_msg = f"Error during image generation: {str(e)}"
91
  print(error_msg)
92
- if not request.is_from_ui:
93
- return {"error": error_msg}
94
  return None
95
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
  # Define the Gradio interface components
97
  control_types = ["Ghibli"]
98
 
99
- # Example data
100
  single_examples = [
101
  ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/00.png"), 680, 1024, 5, "Ghibli", API_TOKEN],
102
  ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/02.png"), 560, 1024, 42, "Ghibli", API_TOKEN],
@@ -151,5 +158,14 @@ with gr.Blocks() as demo:
151
  outputs=single_output_image
152
  )
153
 
 
 
 
 
 
 
 
 
 
154
  # Launch the Gradio app
155
- demo.queue().launch()
 
8
  import gradio as gr
9
  import base64
10
  import io
 
11
 
12
  from safetensors.torch import save_file
13
  from src.pipeline import FluxPipeline
 
19
  lora_base_path = "./models"
20
 
21
  # Environment variable for API token (set this in your Hugging Face space settings)
22
+ API_TOKEN = os.environ.get("HF_API_TOKEN")
23
 
24
  # Initialize the pipeline
25
  pipe = FluxPipeline.from_pretrained(base_path, torch_dtype=torch.bfloat16)
 
38
 
39
  # Define the Gradio interface with token verification
40
  @spaces.GPU()
41
+ def single_condition_generate_image(prompt, spatial_img, height, width, seed, control_type, api_token="", request=None):
42
+ # Check authentication
43
+ if not verify_token(api_token):
44
+ return "Unauthorized: Please enter a valid API token"
 
 
 
 
 
 
 
 
 
 
 
45
 
46
  try:
47
  # Set the control type
 
65
  ).images[0]
66
  clear_cache(pipe.transformer)
67
 
68
+ # We'll always return the PIL image for UI
69
+ # The API will extract base64 from the returned image
 
 
 
 
 
70
  return image
71
 
72
  except Exception as e:
73
  error_msg = f"Error during image generation: {str(e)}"
74
  print(error_msg)
 
 
75
  return None
76
 
77
+ # Define an API endpoint that uses the main function but returns proper JSON
78
+ @spaces.GPU()
79
+ def api_generate_image(prompt, spatial_img, height, width, seed, control_type, api_token=""):
80
+ # Verify the API token
81
+ if not verify_token(api_token):
82
+ return {"error": "Unauthorized access. Invalid token."}
83
+
84
+ try:
85
+ # Use the same function as the UI
86
+ image = single_condition_generate_image(
87
+ prompt, spatial_img, height, width, seed, control_type, api_token
88
+ )
89
+
90
+ if image is None or isinstance(image, str):
91
+ # Error occurred
92
+ error_msg = image if isinstance(image, str) else "Image generation failed"
93
+ return {"error": error_msg}
94
+
95
+ # Return the image directly instead of converting to base64
96
+ return image
97
+
98
+ except Exception as e:
99
+ error_msg = f"API error: {str(e)}"
100
+ print(error_msg)
101
+ return {"error": error_msg}
102
+
103
  # Define the Gradio interface components
104
  control_types = ["Ghibli"]
105
 
106
+ # Example data - add the API token for convenience (assuming only you can see the examples)
107
  single_examples = [
108
  ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/00.png"), 680, 1024, 5, "Ghibli", API_TOKEN],
109
  ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/02.png"), 560, 1024, 42, "Ghibli", API_TOKEN],
 
158
  outputs=single_output_image
159
  )
160
 
161
+ # Create an API endpoint that clients can use programmatically
162
+ demo.queue()
163
+
164
+ # Add the API endpoint
165
+ demo.load(api_generate_image,
166
+ inputs=[prompt, spatial_img, height, width, seed, control_type, api_token],
167
+ outputs=gr.JSON(),
168
+ api_name="generate")
169
+
170
  # Launch the Gradio app
171
+ demo.launch()