Gemini899 commited on
Commit
7bdaae1
·
verified ·
1 Parent(s): 03ea29b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -51
app.py CHANGED
@@ -1,22 +1,28 @@
1
- import spaces
2
- import gradio as gr
3
- import re
4
- from PIL import Image
5
- import io
6
- import base64
7
  import os
 
8
  import json
 
 
 
9
  import numpy as np
10
  import torch
 
 
11
  from diffusers import FluxImg2ImgPipeline
12
  from cryptography.fernet import Fernet
13
  from cryptography.hazmat.primitives import hashes
14
  from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
15
 
 
 
 
16
  dtype = torch.bfloat16
17
  device = "cuda" if torch.cuda.is_available() else "cpu"
18
 
19
- pipe = FluxImg2ImgPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16).to(device)
 
 
 
20
 
21
  def generate_key(password, salt=None):
22
  if salt is None:
@@ -30,16 +36,19 @@ def generate_key(password, salt=None):
30
  key = base64.urlsafe_b64encode(kdf.derive(password.encode()))
31
  return key, salt
32
 
33
- def encrypt_image(image):
 
 
 
 
34
  # Convert PIL Image to bytes
35
  img_byte_arr = io.BytesIO()
36
  image.save(img_byte_arr, format='PNG')
37
  img_byte_arr = img_byte_arr.getvalue()
38
 
39
- # Generate key for encryption using the secret from the environment
40
- key, salt = generate_key(ENCRYPTION_KEY)
41
  cipher = Fernet(key)
42
-
43
  encrypted_data = cipher.encrypt(img_byte_arr)
44
 
45
  return {
@@ -49,15 +58,17 @@ def encrypt_image(image):
49
  'original_height': image.height
50
  }
51
 
52
- def decrypt_image(encrypted_data_dict):
 
 
 
53
  # Extract the encrypted data and salt
54
  encrypted_data = base64.b64decode(encrypted_data_dict['encrypted_data'])
55
  salt = base64.b64decode(encrypted_data_dict['salt'])
56
 
57
- # Regenerate the key using the same password and salt
58
- key, _ = generate_key(ENCRYPTION_KEY, salt)
59
  cipher = Fernet(key)
60
-
61
  decrypted_data = cipher.decrypt(encrypted_data)
62
  image = Image.open(io.BytesIO(decrypted_data))
63
  return image
@@ -72,7 +83,7 @@ def convert_to_fit_size(original_width_and_height, maximum_size=2048):
72
  width, height = original_width_and_height
73
  if width <= maximum_size and height <= maximum_size:
74
  return width, height
75
-
76
  if width > height:
77
  scaling_factor = maximum_size / width
78
  else:
@@ -88,19 +99,18 @@ def adjust_to_multiple_of_32(width: int, height: int):
88
  return width, height
89
 
90
  @spaces.GPU(duration=120)
91
- def process_images(image, prompt="a girl", strength=0.75, seed=0, inference_step=4,
92
- encrypt_password="default_password", progress=gr.Progress(track_tqdm=True)):
93
  progress(0, desc="Starting")
94
 
95
  def process_img2img(image, prompt="a person", strength=0.75, seed=0, num_inference_steps=4):
96
  if image is None:
97
- print("empty input image returned")
98
  return None
99
  generator = torch.Generator(device).manual_seed(seed)
100
  fit_width, fit_height = convert_to_fit_size(image.size)
101
  width, height = adjust_to_multiple_of_32(fit_width, fit_height)
102
  image = image.resize((width, height), Image.LANCZOS)
103
-
104
  output = pipe(
105
  prompt=prompt,
106
  image=image,
@@ -112,28 +122,26 @@ def process_images(image, prompt="a girl", strength=0.75, seed=0, inference_step
112
  num_inference_steps=num_inference_steps,
113
  max_sequence_length=256
114
  )
115
-
116
  pil_image = output.images[0]
117
  new_width, new_height = pil_image.size
118
-
119
  if (new_width != fit_width) or (new_height != fit_height):
120
  resized_image = pil_image.resize((fit_width, fit_height), Image.LANCZOS)
121
  return resized_image
122
  return pil_image
123
-
124
  output = process_img2img(image, prompt, strength, seed, inference_step)
125
-
126
- # Encrypt the output image
127
  if output is not None:
128
- encrypted_output = encrypt_image(output, encrypt_password)
129
- # Instead of returning a gray placeholder, show the real pipeline result:
130
  return {
131
- "display_image": output, # <--- Use the real pipeline output
132
  "encrypted_data": encrypted_output
133
  }
134
  return None
135
 
136
-
137
  def save_encrypted_image(encrypted_data, filename="encrypted_image.enc"):
138
  with open(filename, 'w') as f:
139
  json.dump(encrypted_data, f)
@@ -157,7 +165,7 @@ css = """
157
  display: flex;
158
  align-items: center;
159
  justify-content: center;
160
- gap:10px
161
  }
162
 
163
  .image {
@@ -182,7 +190,7 @@ css = """
182
  with gr.Blocks(css=css, elem_id="demo-container") as demo:
183
  # Store encrypted data in a state variable
184
  encrypted_output_state = gr.State(None)
185
-
186
  with gr.Column():
187
  gr.HTML(read_file("demo_header.html"))
188
  gr.HTML(read_file("demo_tools.html"))
@@ -204,9 +212,7 @@ with gr.Blocks(css=css, elem_id="demo-container") as demo:
204
  placeholder="Your prompt (what you want in place of what is erased)",
205
  elem_id="prompt"
206
  )
207
-
208
  btn = gr.Button("Img2Img", elem_id="run_button", variant="primary")
209
-
210
  with gr.Accordion(label="Advanced Settings", open=False):
211
  with gr.Row(equal_height=True):
212
  strength = gr.Number(
@@ -218,15 +224,8 @@ with gr.Blocks(css=css, elem_id="demo-container") as demo:
218
  inference_step = gr.Number(
219
  value=4, minimum=1, step=4, label="Inference Steps"
220
  )
221
- encrypt_password = gr.Textbox(
222
- label="Encryption Password",
223
- value="default_password",
224
- type="password"
225
- )
226
  id_input = gr.Text(label="Name", visible=False)
227
-
228
  with gr.Column():
229
- # Display placeholder image
230
  image_out = gr.Image(
231
  height=800,
232
  sources=[],
@@ -252,31 +251,29 @@ with gr.Blocks(css=css, elem_id="demo-container") as demo:
252
  ],
253
  inputs=[image, image_out, prompt],
254
  )
255
-
256
  gr.HTML(read_file("demo_footer.html"))
257
-
258
  # Process images and encrypt outputs
259
- def handle_image_generation(image, prompt, strength, seed, inference_step, encrypt_password):
260
- result = process_images(image, prompt, strength, seed, inference_step, encrypt_password)
261
  if result:
262
  return result["display_image"], result["encrypted_data"]
263
  return None, None
264
 
265
- # >>>> CHANGED: Use .click() and .submit() with api_name
266
  btn.click(
267
  fn=handle_image_generation,
268
- inputs=[image, prompt, strength, seed, inference_step, encrypt_password],
269
  outputs=[image_out, encrypted_output_state],
270
- api_name="/process_images" # Exposes handle_image_generation as /process_images
271
  )
272
 
273
  prompt.submit(
274
  fn=handle_image_generation,
275
- inputs=[image, prompt, strength, seed, inference_step, encrypt_password],
276
  outputs=[image_out, encrypted_output_state],
277
- api_name="/process_images" # Same endpoint
278
  )
279
- # <<<< END CHANGE
280
 
281
  def handle_save_encrypted(encrypted_data):
282
  if encrypted_data:
@@ -286,7 +283,7 @@ with gr.Blocks(css=css, elem_id="demo-container") as demo:
286
  json.dump(encrypted_data, f)
287
  return f"Encrypted image saved to {path}"
288
  return "No encrypted image to save"
289
-
290
  save_btn.click(
291
  fn=handle_save_encrypted,
292
  inputs=[encrypted_output_state],
 
 
 
 
 
 
 
1
  import os
2
+ import io
3
  import json
4
+ import base64
5
+ import re
6
+ from PIL import Image
7
  import numpy as np
8
  import torch
9
+ import gradio as gr
10
+ import spaces
11
  from diffusers import FluxImg2ImgPipeline
12
  from cryptography.fernet import Fernet
13
  from cryptography.hazmat.primitives import hashes
14
  from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
15
 
16
+ # Retrieve the encryption key from the environment (set in Hugging Face Secrets Manager)
17
+ ENCRYPTION_KEY = os.environ.get("key", "FAKEFALLBACKKEY_FOR_LOCAL_TESTING")
18
+
19
  dtype = torch.bfloat16
20
  device = "cuda" if torch.cuda.is_available() else "cpu"
21
 
22
+ pipe = FluxImg2ImgPipeline.from_pretrained(
23
+ "black-forest-labs/FLUX.1-schnell",
24
+ torch_dtype=torch.bfloat16
25
+ ).to(device)
26
 
27
  def generate_key(password, salt=None):
28
  if salt is None:
 
36
  key = base64.urlsafe_b64encode(kdf.derive(password.encode()))
37
  return key, salt
38
 
39
+ def encrypt_image(image, password=None):
40
+ # Use the secure key if no override is provided
41
+ if password is None:
42
+ password = ENCRYPTION_KEY
43
+
44
  # Convert PIL Image to bytes
45
  img_byte_arr = io.BytesIO()
46
  image.save(img_byte_arr, format='PNG')
47
  img_byte_arr = img_byte_arr.getvalue()
48
 
49
+ # Generate key for encryption using the secure password
50
+ key, salt = generate_key(password)
51
  cipher = Fernet(key)
 
52
  encrypted_data = cipher.encrypt(img_byte_arr)
53
 
54
  return {
 
58
  'original_height': image.height
59
  }
60
 
61
+ def decrypt_image(encrypted_data_dict, password=None):
62
+ if password is None:
63
+ password = ENCRYPTION_KEY
64
+
65
  # Extract the encrypted data and salt
66
  encrypted_data = base64.b64decode(encrypted_data_dict['encrypted_data'])
67
  salt = base64.b64decode(encrypted_data_dict['salt'])
68
 
69
+ # Regenerate the key using the secure password and salt
70
+ key, _ = generate_key(password, salt)
71
  cipher = Fernet(key)
 
72
  decrypted_data = cipher.decrypt(encrypted_data)
73
  image = Image.open(io.BytesIO(decrypted_data))
74
  return image
 
83
  width, height = original_width_and_height
84
  if width <= maximum_size and height <= maximum_size:
85
  return width, height
86
+
87
  if width > height:
88
  scaling_factor = maximum_size / width
89
  else:
 
99
  return width, height
100
 
101
  @spaces.GPU(duration=120)
102
+ def process_images(image, prompt="a girl", strength=0.75, seed=0, inference_step=4, progress=gr.Progress(track_tqdm=True)):
 
103
  progress(0, desc="Starting")
104
 
105
  def process_img2img(image, prompt="a person", strength=0.75, seed=0, num_inference_steps=4):
106
  if image is None:
107
+ print("Empty input image returned")
108
  return None
109
  generator = torch.Generator(device).manual_seed(seed)
110
  fit_width, fit_height = convert_to_fit_size(image.size)
111
  width, height = adjust_to_multiple_of_32(fit_width, fit_height)
112
  image = image.resize((width, height), Image.LANCZOS)
113
+
114
  output = pipe(
115
  prompt=prompt,
116
  image=image,
 
122
  num_inference_steps=num_inference_steps,
123
  max_sequence_length=256
124
  )
125
+
126
  pil_image = output.images[0]
127
  new_width, new_height = pil_image.size
128
+
129
  if (new_width != fit_width) or (new_height != fit_height):
130
  resized_image = pil_image.resize((fit_width, fit_height), Image.LANCZOS)
131
  return resized_image
132
  return pil_image
133
+
134
  output = process_img2img(image, prompt, strength, seed, inference_step)
135
+
136
+ # Encrypt the output image using the secure key
137
  if output is not None:
138
+ encrypted_output = encrypt_image(output)
 
139
  return {
140
+ "display_image": output,
141
  "encrypted_data": encrypted_output
142
  }
143
  return None
144
 
 
145
  def save_encrypted_image(encrypted_data, filename="encrypted_image.enc"):
146
  with open(filename, 'w') as f:
147
  json.dump(encrypted_data, f)
 
165
  display: flex;
166
  align-items: center;
167
  justify-content: center;
168
+ gap:10px;
169
  }
170
 
171
  .image {
 
190
  with gr.Blocks(css=css, elem_id="demo-container") as demo:
191
  # Store encrypted data in a state variable
192
  encrypted_output_state = gr.State(None)
193
+
194
  with gr.Column():
195
  gr.HTML(read_file("demo_header.html"))
196
  gr.HTML(read_file("demo_tools.html"))
 
212
  placeholder="Your prompt (what you want in place of what is erased)",
213
  elem_id="prompt"
214
  )
 
215
  btn = gr.Button("Img2Img", elem_id="run_button", variant="primary")
 
216
  with gr.Accordion(label="Advanced Settings", open=False):
217
  with gr.Row(equal_height=True):
218
  strength = gr.Number(
 
224
  inference_step = gr.Number(
225
  value=4, minimum=1, step=4, label="Inference Steps"
226
  )
 
 
 
 
 
227
  id_input = gr.Text(label="Name", visible=False)
 
228
  with gr.Column():
 
229
  image_out = gr.Image(
230
  height=800,
231
  sources=[],
 
251
  ],
252
  inputs=[image, image_out, prompt],
253
  )
254
+
255
  gr.HTML(read_file("demo_footer.html"))
256
+
257
  # Process images and encrypt outputs
258
+ def handle_image_generation(image, prompt, strength, seed, inference_step):
259
+ result = process_images(image, prompt, strength, seed, inference_step)
260
  if result:
261
  return result["display_image"], result["encrypted_data"]
262
  return None, None
263
 
 
264
  btn.click(
265
  fn=handle_image_generation,
266
+ inputs=[image, prompt, strength, seed, inference_step],
267
  outputs=[image_out, encrypted_output_state],
268
+ api_name="/process_images"
269
  )
270
 
271
  prompt.submit(
272
  fn=handle_image_generation,
273
+ inputs=[image, prompt, strength, seed, inference_step],
274
  outputs=[image_out, encrypted_output_state],
275
+ api_name="/process_images"
276
  )
 
277
 
278
  def handle_save_encrypted(encrypted_data):
279
  if encrypted_data:
 
283
  json.dump(encrypted_data, f)
284
  return f"Encrypted image saved to {path}"
285
  return "No encrypted image to save"
286
+
287
  save_btn.click(
288
  fn=handle_save_encrypted,
289
  inputs=[encrypted_output_state],