arrafaqat commited on
Commit
c917731
·
verified ·
1 Parent(s): a49b272

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +70 -99
app.py CHANGED
@@ -6,8 +6,6 @@ import torch
6
  from PIL import Image
7
  from tqdm import tqdm
8
  import gradio as gr
9
- import base64
10
- import io
11
 
12
  from safetensors.torch import save_file
13
  from src.pipeline import FluxPipeline
@@ -19,9 +17,8 @@ base_path = "black-forest-labs/FLUX.1-dev"
19
  lora_base_path = "./models"
20
 
21
  # Environment variable for API token (set this in your Hugging Face space settings)
22
- API_TOKEN = os.environ.get("HF_API_TOKEN")
23
 
24
- # Initialize the pipeline
25
  pipe = FluxPipeline.from_pretrained(base_path, torch_dtype=torch.bfloat16)
26
  transformer = FluxTransformer2DModel.from_pretrained(base_path, subfolder="transformer", torch_dtype=torch.bfloat16)
27
  pipe.transformer = transformer
@@ -31,100 +28,52 @@ def clear_cache(transformer):
31
  for name, attn_processor in transformer.attn_processors.items():
32
  attn_processor.bank_kv.clear()
33
 
34
- # Token verification function
35
- def verify_token(token):
36
- """Verify if the provided token matches the API token"""
37
- return API_TOKEN and token == API_TOKEN
38
-
39
  # Define the Gradio interface with token verification
40
  @spaces.GPU()
41
- def single_condition_generate_image(prompt, spatial_img, height, width, seed, control_type, api_token="", request=None):
42
- # Check authentication
43
- if not verify_token(api_token):
44
- return "Unauthorized: Please enter a valid API token"
45
 
46
- try:
47
- # Set the control type
48
- if control_type == "Ghibli":
49
- lora_path = os.path.join(lora_base_path, "Ghibli.safetensors")
50
- set_single_lora(pipe.transformer, lora_path, lora_weights=[1], cond_size=512)
51
-
52
- # Process the image
53
- spatial_imgs = [spatial_img] if spatial_img else []
54
- image = pipe(
55
- prompt,
56
- height=int(height),
57
- width=int(width),
58
- guidance_scale=3.5,
59
- num_inference_steps=25,
60
- max_sequence_length=512,
61
- generator=torch.Generator("cpu").manual_seed(seed),
62
- subject_images=[],
63
- spatial_images=spatial_imgs,
64
- cond_size=512,
65
- ).images[0]
66
- clear_cache(pipe.transformer)
67
-
68
- # We'll always return the PIL image for UI
69
- # The API will extract base64 from the returned image
70
- return image
71
 
72
- except Exception as e:
73
- error_msg = f"Error during image generation: {str(e)}"
74
- print(error_msg)
75
- return None
76
-
77
- # Define an API endpoint that uses the main function but returns proper JSON
78
- @spaces.GPU()
79
- def api_generate_image(prompt, spatial_img, height, width, seed, control_type, api_token=""):
80
- # Verify the API token
81
- if not verify_token(api_token):
82
- return {"error": "Unauthorized access. Invalid token."}
83
-
84
- try:
85
- # Use the same function as the UI
86
- image = single_condition_generate_image(
87
- prompt, spatial_img, height, width, seed, control_type, api_token
88
- )
89
-
90
- if image is None or isinstance(image, str):
91
- # Error occurred
92
- error_msg = image if isinstance(image, str) else "Image generation failed"
93
- return {"error": error_msg}
94
-
95
- # Return the image directly instead of converting to base64
96
- return image
97
-
98
- except Exception as e:
99
- error_msg = f"API error: {str(e)}"
100
- print(error_msg)
101
- return {"error": error_msg}
102
 
103
  # Define the Gradio interface components
104
  control_types = ["Ghibli"]
105
 
106
- # Example data - add the API token for convenience (assuming only you can see the examples)
107
- single_examples = [
108
- ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/00.png"), 680, 1024, 5, "Ghibli", API_TOKEN],
109
- ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/02.png"), 560, 1024, 42, "Ghibli", API_TOKEN],
110
- ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/03.png"), 568, 1024, 1, "Ghibli", API_TOKEN],
111
- ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/04.png"), 768, 672, 1, "Ghibli", API_TOKEN],
112
- ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/06.png"), 896, 1024, 1, "Ghibli", API_TOKEN],
113
- ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/07.png"), 528, 800, 1, "Ghibli", API_TOKEN],
114
- ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/08.png"), 696, 1024, 1, "Ghibli", API_TOKEN],
115
- ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/09.png"), 896, 1024, 1, "Ghibli", API_TOKEN],
116
- ]
117
-
118
  # Create the Gradio Blocks interface
119
  with gr.Blocks() as demo:
120
  gr.Markdown("# Ghibli Studio Control Image Generation with EasyControl")
121
- gr.Markdown("⚠️ **AUTHENTICATION REQUIRED**: You must enter a valid API token to use this interface.")
 
 
 
 
 
 
 
122
  gr.Markdown("The model is trained on **only 100 real Asian faces** paired with **GPT-4o-generated Ghibli-style counterparts**, and it preserves facial features while applying the iconic anime aesthetic.")
123
  gr.Markdown("Generate images using EasyControl with Ghibli control LoRAs.(Due to hardware constraints, only low-resolution images can be generated. For high-resolution (1024+), please set up your own environment.)")
124
 
125
- # Authentication input - visible at the top of the interface
126
- api_token = gr.Textbox(label="API Token (Required)", type="password", value="")
127
-
128
  gr.Markdown("**[Attention!!]**:The recommended prompts for using Ghibli Control LoRA should include the trigger words: `Ghibli Studio style, Charming hand-drawn anime-style illustration`")
129
  gr.Markdown("😊😊If you like this demo, please give us a star (github: [EasyControl](https://github.com/Xiaojiu-z/EasyControl))")
130
 
@@ -141,31 +90,53 @@ with gr.Blocks() as demo:
141
  with gr.Column():
142
  single_output_image = gr.Image(label="Generated Image")
143
 
144
- # Add examples for Single Condition Generation (including the token)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
  gr.Examples(
146
- examples=single_examples,
147
- inputs=[prompt, spatial_img, height, width, seed, control_type, api_token],
148
  outputs=single_output_image,
149
  fn=single_condition_generate_image,
150
  cache_examples=False,
151
  label="Single Condition Examples"
152
  )
153
 
154
- # Link the buttons to the functions, including the token
 
 
 
 
155
  single_generate_btn.click(
156
  single_condition_generate_image,
157
- inputs=[prompt, spatial_img, height, width, seed, control_type, api_token],
158
  outputs=single_output_image
159
  )
160
 
161
- # Create an API endpoint that clients can use programmatically
162
- demo.queue()
163
-
164
- # Add the API endpoint
165
- demo.load(api_generate_image,
166
- inputs=[prompt, spatial_img, height, width, seed, control_type, api_token],
167
- outputs=gr.JSON(),
168
- api_name="generate")
169
-
170
  # Launch the Gradio app
171
- demo.launch()
 
6
  from PIL import Image
7
  from tqdm import tqdm
8
  import gradio as gr
 
 
9
 
10
  from safetensors.torch import save_file
11
  from src.pipeline import FluxPipeline
 
17
  lora_base_path = "./models"
18
 
19
  # Environment variable for API token (set this in your Hugging Face space settings)
20
+ API_TOKEN = os.environ.get("HF_TOKEN")
21
 
 
22
  pipe = FluxPipeline.from_pretrained(base_path, torch_dtype=torch.bfloat16)
23
  transformer = FluxTransformer2DModel.from_pretrained(base_path, subfolder="transformer", torch_dtype=torch.bfloat16)
24
  pipe.transformer = transformer
 
28
  for name, attn_processor in transformer.attn_processors.items():
29
  attn_processor.bank_kv.clear()
30
 
 
 
 
 
 
31
  # Define the Gradio interface with token verification
32
  @spaces.GPU()
33
+ def single_condition_generate_image(prompt, spatial_img, height, width, seed, control_type, api_token=""):
34
+ # Check if API token is required and valid
35
+ if API_TOKEN and api_token != API_TOKEN:
36
+ return "ERROR: Invalid API token. Please provide a valid token to generate images."
37
 
38
+ # Set the control type
39
+ if control_type == "Ghibli":
40
+ lora_path = os.path.join(lora_base_path, "Ghibli.safetensors")
41
+ set_single_lora(pipe.transformer, lora_path, lora_weights=[1], cond_size=512)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
+ # Process the image
44
+ spatial_imgs = [spatial_img] if spatial_img else []
45
+ image = pipe(
46
+ prompt,
47
+ height=int(height),
48
+ width=int(width),
49
+ guidance_scale=3.5,
50
+ num_inference_steps=25,
51
+ max_sequence_length=512,
52
+ generator=torch.Generator("cpu").manual_seed(seed),
53
+ subject_images=[],
54
+ spatial_images=spatial_imgs,
55
+ cond_size=512,
56
+ ).images[0]
57
+ clear_cache(pipe.transformer)
58
+ return image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
  # Define the Gradio interface components
61
  control_types = ["Ghibli"]
62
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  # Create the Gradio Blocks interface
64
  with gr.Blocks() as demo:
65
  gr.Markdown("# Ghibli Studio Control Image Generation with EasyControl")
66
+
67
+ # Only show token field if API token is required
68
+ if API_TOKEN:
69
+ gr.Markdown("⚠️ **AUTHENTICATION REQUIRED**: Please enter your API token to use this service.")
70
+ api_token = gr.Textbox(label="API Token", type="password", value="")
71
+ else:
72
+ api_token = gr.Textbox(visible=False, value="") # Hidden field with empty value
73
+
74
  gr.Markdown("The model is trained on **only 100 real Asian faces** paired with **GPT-4o-generated Ghibli-style counterparts**, and it preserves facial features while applying the iconic anime aesthetic.")
75
  gr.Markdown("Generate images using EasyControl with Ghibli control LoRAs.(Due to hardware constraints, only low-resolution images can be generated. For high-resolution (1024+), please set up your own environment.)")
76
 
 
 
 
77
  gr.Markdown("**[Attention!!]**:The recommended prompts for using Ghibli Control LoRA should include the trigger words: `Ghibli Studio style, Charming hand-drawn anime-style illustration`")
78
  gr.Markdown("😊😊If you like this demo, please give us a star (github: [EasyControl](https://github.com/Xiaojiu-z/EasyControl))")
79
 
 
90
  with gr.Column():
91
  single_output_image = gr.Image(label="Generated Image")
92
 
93
+ # Set up examples (with token automatically added if present)
94
+ example_inputs = [prompt, spatial_img, height, width, seed, control_type]
95
+ if API_TOKEN:
96
+ # Add token to examples for convenience
97
+ example_data = [
98
+ ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/00.png"), 680, 1024, 5, "Ghibli", API_TOKEN],
99
+ ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/02.png"), 560, 1024, 42, "Ghibli", API_TOKEN],
100
+ ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/03.png"), 568, 1024, 1, "Ghibli", API_TOKEN],
101
+ ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/04.png"), 768, 672, 1, "Ghibli", API_TOKEN],
102
+ ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/06.png"), 896, 1024, 1, "Ghibli", API_TOKEN],
103
+ ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/07.png"), 528, 800, 1, "Ghibli", API_TOKEN],
104
+ ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/08.png"), 696, 1024, 1, "Ghibli", API_TOKEN],
105
+ ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/09.png"), 896, 1024, 1, "Ghibli", API_TOKEN],
106
+ ]
107
+ example_inputs.append(api_token)
108
+ else:
109
+ # Use examples without token
110
+ example_data = [
111
+ ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/00.png"), 680, 1024, 5, "Ghibli"],
112
+ ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/02.png"), 560, 1024, 42, "Ghibli"],
113
+ ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/03.png"), 568, 1024, 1, "Ghibli"],
114
+ ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/04.png"), 768, 672, 1, "Ghibli"],
115
+ ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/06.png"), 896, 1024, 1, "Ghibli"],
116
+ ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/07.png"), 528, 800, 1, "Ghibli"],
117
+ ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/08.png"), 696, 1024, 1, "Ghibli"],
118
+ ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/09.png"), 896, 1024, 1, "Ghibli"],
119
+ ]
120
+
121
  gr.Examples(
122
+ examples=example_data,
123
+ inputs=example_inputs,
124
  outputs=single_output_image,
125
  fn=single_condition_generate_image,
126
  cache_examples=False,
127
  label="Single Condition Examples"
128
  )
129
 
130
+ # Link the buttons to the functions with API token included
131
+ inputs = [prompt, spatial_img, height, width, seed, control_type]
132
+ if API_TOKEN:
133
+ inputs.append(api_token)
134
+
135
  single_generate_btn.click(
136
  single_condition_generate_image,
137
+ inputs=inputs,
138
  outputs=single_output_image
139
  )
140
 
 
 
 
 
 
 
 
 
 
141
  # Launch the Gradio app
142
+ demo.queue().launch()