Gemini899 commited on
Commit
46d2fcf
·
verified ·
1 Parent(s): bd1ec8b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +246 -114
app.py CHANGED
@@ -2,52 +2,111 @@ import spaces
2
  import gradio as gr
3
  import re
4
  from PIL import Image
 
5
  import os
6
- import gc # Add garbage collection
7
- import psutil # Add for memory monitoring
8
-
9
- # Set memory optimization flags
10
- os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:64" # Reduced value
11
- os.environ["CUDA_VISIBLE_DEVICES"] = "0" # Ensure using single GPU
12
-
13
  import numpy as np
14
  import torch
15
  from diffusers import FluxImg2ImgPipeline
 
 
 
 
 
 
 
 
 
 
16
 
17
- # Use float16 instead of bfloat16 for T4 compatibility
18
- dtype = torch.float16
19
- device = "cuda" if torch.cuda.is_available() else "cpu"
20
 
21
- # Initialize the pipe directly during startup
22
- print("Loading model during startup...")
23
- torch.cuda.empty_cache()
24
- gc.collect() # Force garbage collection
 
 
25
 
26
- pipe = FluxImg2ImgPipeline.from_pretrained(
27
- "black-forest-labs/FLUX.1-schnell",
28
- torch_dtype=torch.float16,
29
- low_cpu_mem_usage=True,
30
- use_safetensors=True
31
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
- # Enable attention slicing to reduce memory footprint
34
- pipe.enable_attention_slicing(1)
35
 
36
- # Move to device immediately
37
- if torch.cuda.is_available():
38
- pipe = pipe.to("cuda:0")
39
- else:
40
- pipe = pipe.to("cpu")
41
 
42
- print("Model loaded successfully")
 
 
 
 
 
 
 
43
 
44
  def sanitize_prompt(prompt):
45
- # Allow only alphanumeric characters, spaces, and basic punctuation
46
- allowed_chars = re.compile(r"[^a-zA-Z0-9\s.,!?-]")
47
- sanitized_prompt = allowed_chars.sub("", prompt)
48
- return sanitized_prompt
49
 
50
- def convert_to_fit_size(original_width_and_height, maximum_size = 1024):
51
  width, height = original_width_and_height
52
  if width <= maximum_size and height <= maximum_size:
53
  return width, height
@@ -66,64 +125,72 @@ def adjust_to_multiple_of_32(width: int, height: int):
66
  height = height - (height % 32)
67
  return width, height
68
 
69
- def resize_image(image: Image.Image, max_dim: int = 384) -> Image.Image:
70
- """Resizes image to fit within max_dim while preserving aspect ratio"""
71
- w, h = image.size
72
- ratio = min(max_dim / w, max_dim / h)
73
- if ratio < 1.0:
74
- new_w = int(w * ratio)
75
- new_h = int(h * ratio)
76
- image = image.resize((new_w, new_h), Image.LANCZOS)
77
- return image
78
-
79
- # Increase the timeout to 4 minutes
80
- @spaces.GPU(duration=740)
81
- def process_images(image, prompt="a girl", strength=0.75, seed=0, inference_step=2, progress=gr.Progress(track_tqdm=True)):
82
- progress(0, desc="Starting and freeing memory")
83
- # Free memory before processing
84
- torch.cuda.empty_cache()
85
- gc.collect()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
 
87
- progress(15, desc="Processing")
 
88
 
89
  def process_img2img(image, prompt="a person", strength=0.75, seed=0, num_inference_steps=4):
 
 
 
90
  if image is None:
91
- print("empty input image returned")
92
  return None
93
 
94
- # Convert webp to PNG if needed
95
- if hasattr(image, 'format') and image.format == 'WEBP':
96
- print("Converting WEBP to PNG")
97
- rgb_im = image.convert('RGB')
98
- image = rgb_im
99
-
100
- # Resize image to reduce memory usage
101
- image = resize_image(image, max_dim=512)
102
-
103
  generator = torch.Generator(device).manual_seed(seed)
104
- fit_width, fit_height = convert_to_fit_size(image.size, maximum_size=512)
105
  width, height = adjust_to_multiple_of_32(fit_width, fit_height)
106
  image = image.resize((width, height), Image.LANCZOS)
107
-
108
- progress(30, desc="Processing image")
109
-
110
- # Use autocast for better memory efficiency
111
- with torch.cuda.amp.autocast(dtype=torch.float16):
112
- with torch.no_grad():
113
- output = pipe(
114
- prompt=prompt,
115
- image=image,
116
- generator=generator,
117
- strength=strength,
118
- width=width,
119
- height=height,
120
- guidance_scale=0,
121
- num_inference_steps=num_inference_steps,
122
- max_sequence_length=256
123
- )
124
-
125
- progress(90, desc="Finalizing")
126
-
127
  pil_image = output.images[0]
128
  new_width, new_height = pil_image.size
129
 
@@ -132,9 +199,20 @@ def process_images(image, prompt="a girl", strength=0.75, seed=0, inference_step
132
  return resized_image
133
  return pil_image
134
 
 
135
  output = process_img2img(image, prompt, strength, seed, inference_step)
136
- progress(100, desc="Done")
137
- return output
 
 
 
 
 
 
 
 
 
 
138
 
139
  def read_file(path: str) -> str:
140
  with open(path, 'r', encoding='utf-8') as f:
@@ -168,48 +246,102 @@ css="""
168
  }
169
  """
170
 
 
 
 
 
 
 
 
 
 
 
 
 
171
  with gr.Blocks(css=css, elem_id="demo-container") as demo:
172
  with gr.Column():
173
  gr.HTML(read_file("demo_header.html"))
174
  gr.HTML(read_file("demo_tools.html"))
175
  with gr.Row():
176
- with gr.Column():
177
- image = gr.Image(height=800,sources=['upload','clipboard'],image_mode='RGB', elem_id="image_upload", type="pil", label="Upload")
178
- with gr.Row(elem_id="prompt-container", equal_height=False):
179
- with gr.Row():
180
- prompt = gr.Textbox(label="Prompt",value="a women",placeholder="Your prompt (what you want in place of what is erased)", elem_id="prompt")
181
-
182
- btn = gr.Button("Img2Img", elem_id="run_button",variant="primary")
 
 
 
 
 
 
 
 
 
 
183
 
184
- with gr.Accordion(label="Advanced Settings", open=False):
185
- with gr.Row( equal_height=True):
186
- strength = gr.Number(value=0.65, minimum=0, maximum=0.75, step=0.01, label="strength")
187
- seed = gr.Number(value=100, minimum=0, step=1, label="seed")
188
- inference_step = gr.Number(value=2, minimum=1, maximum=4, step=1, label="inference_step")
189
- id_input=gr.Text(label="Name", visible=False)
190
-
191
- with gr.Column():
192
- image_out = gr.Image(height=800,sources=[],label="Output", elem_id="output-img",format="jpg")
193
-
 
 
 
 
 
 
 
 
194
  gr.Examples(
195
- examples=[
196
- ["examples/draw_input.jpg", "examples/draw_output.jpg","a women ,eyes closed,mouth opened"],
197
- ["examples/draw-gimp_input.jpg", "examples/draw-gimp_output.jpg","a women ,eyes closed,mouth opened"],
198
- ["examples/gimp_input.jpg", "examples/gimp_output.jpg","a women ,hand on neck"],
199
- ["examples/inpaint_input.jpg", "examples/inpaint_output.jpg","a women ,hand on neck"]
200
- ]
201
- ,
202
- inputs=[image,image_out,prompt],
203
  )
204
  gr.HTML(
205
- gr.HTML(read_file("demo_footer.html"))
206
  )
207
  gr.on(
208
  triggers=[btn.click, prompt.submit],
209
- fn = process_images,
210
- inputs = [image,prompt,strength,seed,inference_step],
211
- outputs = [image_out]
212
  )
213
 
 
 
 
 
214
  if __name__ == "__main__":
215
- demo.launch(share=True, show_error=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import gradio as gr
3
  import re
4
  from PIL import Image
5
+ import io
6
  import os
 
 
 
 
 
 
 
7
  import numpy as np
8
  import torch
9
  from diffusers import FluxImg2ImgPipeline
10
+ import tempfile
11
+ import secrets
12
+ import uuid
13
+ import shutil
14
+ import ssl
15
+ from cryptography.fernet import Fernet
16
+ import base64
17
+ import hashlib
18
+ import time
19
+ import threading
20
 
21
+ # Global encryption key for this session
22
+ ENCRYPTION_KEY = Fernet.generate_key()
23
+ cipher_suite = Fernet(ENCRYPTION_KEY)
24
 
25
+ # Configure SSL context for secure connections
26
+ ssl_context = ssl.create_default_context()
27
+ ssl_context.options |= ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1 # Disable older TLS protocols
28
+ ssl_context.set_ciphers('ECDHE+AESGCM:ECDHE+CHACHA20:DHE+AESGCM:DHE+CHACHA20')
29
+ ssl_context.check_hostname = True
30
+ ssl_context.verify_mode = ssl.CERT_REQUIRED
31
 
32
+ # Secure temporary directory manager
33
+ class SecureTempManager:
34
+ def __init__(self):
35
+ self.temp_dir = tempfile.mkdtemp(prefix='flux_secure_')
36
+ self.cleanup_timeout = 30 # seconds
37
+ self.file_registry = {}
38
+
39
+ def get_secure_path(self, prefix="img"):
40
+ """Generate a secure random filename in the temp directory"""
41
+ filename = f"{prefix}_{uuid.uuid4().hex}_{secrets.token_hex(8)}.png"
42
+ filepath = os.path.join(self.temp_dir, filename)
43
+
44
+ # Register file for cleanup
45
+ self.file_registry[filepath] = time.time()
46
+
47
+ return filepath
48
+
49
+ def cleanup_old_files(self):
50
+ """Clean up files older than the timeout"""
51
+ current_time = time.time()
52
+ for filepath, created_time in list(self.file_registry.items()):
53
+ if current_time - created_time > self.cleanup_timeout:
54
+ try:
55
+ if os.path.exists(filepath):
56
+ # Securely delete by overwriting with random data
57
+ file_size = os.path.getsize(filepath)
58
+ with open(filepath, 'wb') as f:
59
+ f.write(os.urandom(file_size))
60
+ # Then delete
61
+ os.remove(filepath)
62
+ # Remove from registry
63
+ del self.file_registry[filepath]
64
+ except Exception as e:
65
+ print(f"Error cleaning up file {filepath}: {e}")
66
+
67
+ def cleanup_all(self):
68
+ """Clean up all files and remove temp directory"""
69
+ # First clean up individual files
70
+ for filepath in list(self.file_registry.keys()):
71
+ try:
72
+ if os.path.exists(filepath):
73
+ os.remove(filepath)
74
+ del self.file_registry[filepath]
75
+ except:
76
+ pass
77
+
78
+ # Then remove the directory
79
+ try:
80
+ if os.path.exists(self.temp_dir):
81
+ shutil.rmtree(self.temp_dir)
82
+ except:
83
+ pass
84
 
85
+ # Initialize secure temp manager
86
+ secure_temp = SecureTempManager()
87
 
88
+ # Start a thread to periodically clean up old files
89
+ def cleanup_thread_function():
90
+ while True:
91
+ secure_temp.cleanup_old_files()
92
+ time.sleep(5) # Check every 5 seconds
93
 
94
+ cleanup_thread = threading.Thread(target=cleanup_thread_function)
95
+ cleanup_thread.daemon = True # Thread will exit when main program exits
96
+ cleanup_thread.start()
97
+
98
+ # Initialize model with proper settings
99
+ dtype = torch.bfloat16
100
+ device = "cuda" if torch.cuda.is_available() else "cpu"
101
+ pipe = FluxImg2ImgPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16).to(device)
102
 
103
  def sanitize_prompt(prompt):
104
+ # Allow only alphanumeric characters, spaces, and basic punctuation
105
+ allowed_chars = re.compile(r"[^a-zA-Z0-9\s.,!?-]")
106
+ sanitized_prompt = allowed_chars.sub("", prompt)
107
+ return sanitized_prompt
108
 
109
+ def convert_to_fit_size(original_width_and_height, maximum_size = 2048):
110
  width, height = original_width_and_height
111
  if width <= maximum_size and height <= maximum_size:
112
  return width, height
 
125
  height = height - (height % 32)
126
  return width, height
127
 
128
+ # Function to securely handle image data
129
+ def secure_image_handler(image):
130
+ """Process image securely without exposing it to the file system"""
131
+ if image is None:
132
+ return None
133
+
134
+ # If the image is already a PIL Image, use it directly
135
+ if isinstance(image, Image.Image):
136
+ return image
137
+
138
+ # Otherwise, assume it's a file path or binary data
139
+ try:
140
+ if isinstance(image, str) and os.path.exists(image):
141
+ # It's a file path, load it securely
142
+ with open(image, 'rb') as f:
143
+ img_data = f.read()
144
+
145
+ # Immediately delete the original file if it's in our temp directory
146
+ if image.startswith(secure_temp.temp_dir):
147
+ try:
148
+ os.remove(image)
149
+ except:
150
+ pass
151
+
152
+ # Create image from binary data
153
+ return Image.open(io.BytesIO(img_data))
154
+ elif isinstance(image, bytes):
155
+ # It's binary data
156
+ return Image.open(io.BytesIO(image))
157
+ except Exception as e:
158
+ print(f"Error processing image: {e}")
159
+ return None
160
+
161
+ @spaces.GPU(duration=120)
162
+ def process_images(image, prompt="a girl", strength=0.75, seed=0, inference_step=4, progress=gr.Progress(track_tqdm=True)):
163
+ progress(0, desc="Starting")
164
 
165
+ # Sanitize input
166
+ prompt = sanitize_prompt(prompt)
167
 
168
  def process_img2img(image, prompt="a person", strength=0.75, seed=0, num_inference_steps=4):
169
+ # Secure image handling
170
+ image = secure_image_handler(image)
171
+
172
  if image is None:
173
+ print("Empty input image returned")
174
  return None
175
 
 
 
 
 
 
 
 
 
 
176
  generator = torch.Generator(device).manual_seed(seed)
177
+ fit_width, fit_height = convert_to_fit_size(image.size)
178
  width, height = adjust_to_multiple_of_32(fit_width, fit_height)
179
  image = image.resize((width, height), Image.LANCZOS)
180
+
181
+ # Process the image
182
+ output = pipe(
183
+ prompt=prompt,
184
+ image=image,
185
+ generator=generator,
186
+ strength=strength,
187
+ width=width,
188
+ height=height,
189
+ guidance_scale=0,
190
+ num_inference_steps=num_inference_steps,
191
+ max_sequence_length=256
192
+ )
193
+
 
 
 
 
 
 
194
  pil_image = output.images[0]
195
  new_width, new_height = pil_image.size
196
 
 
199
  return resized_image
200
  return pil_image
201
 
202
+ # Process the image
203
  output = process_img2img(image, prompt, strength, seed, inference_step)
204
+
205
+ # Instead of returning image directly, save to secure temp location and return path
206
+ if output is not None:
207
+ # Convert output to secure in-memory format
208
+ output_buffer = io.BytesIO()
209
+ output.save(output_buffer, format="PNG")
210
+ output_buffer.seek(0)
211
+
212
+ # Return the image data directly (Gradio will handle it)
213
+ return output
214
+
215
+ return None
216
 
217
  def read_file(path: str) -> str:
218
  with open(path, 'r', encoding='utf-8') as f:
 
246
  }
247
  """
248
 
249
+ # Custom HTTP headers for security
250
+ custom_headers = {
251
+ "Strict-Transport-Security": "max-age=63072000; includeSubDomains; preload",
252
+ "X-Content-Type-Options": "nosniff",
253
+ "X-Frame-Options": "SAMEORIGIN",
254
+ "Content-Security-Policy": "default-src 'self'; img-src 'self' data:; style-src 'self' 'unsafe-inline';",
255
+ "Referrer-Policy": "strict-origin-when-cross-origin",
256
+ "Permissions-Policy": "camera=(), microphone=(), geolocation=()",
257
+ "Cache-Control": "no-store, max-age=0"
258
+ }
259
+
260
+ # Create Gradio app with enhanced security
261
  with gr.Blocks(css=css, elem_id="demo-container") as demo:
262
  with gr.Column():
263
  gr.HTML(read_file("demo_header.html"))
264
  gr.HTML(read_file("demo_tools.html"))
265
  with gr.Row():
266
+ with gr.Column():
267
+ image = gr.Image(
268
+ height=800,
269
+ sources=['upload','clipboard'],
270
+ image_mode='RGB',
271
+ elem_id="image_upload",
272
+ type="pil",
273
+ label="Upload"
274
+ )
275
+ with gr.Row(elem_id="prompt-container", equal_height=False):
276
+ with gr.Row():
277
+ prompt = gr.Textbox(
278
+ label="Prompt",
279
+ value="a women",
280
+ placeholder="Your prompt (what you want in place of what is erased)",
281
+ elem_id="prompt"
282
+ )
283
 
284
+ btn = gr.Button("Img2Img", elem_id="run_button", variant="primary")
285
+
286
+ with gr.Accordion(label="Advanced Settings", open=False):
287
+ with gr.Row(equal_height=True):
288
+ strength = gr.Number(value=0.75, minimum=0, maximum=0.75, step=0.01, label="strength")
289
+ seed = gr.Number(value=100, minimum=0, step=1, label="seed")
290
+ inference_step = gr.Number(value=4, minimum=1, step=4, label="inference_step")
291
+ id_input=gr.Text(label="Name", visible=False)
292
+
293
+ with gr.Column():
294
+ image_out = gr.Image(
295
+ height=800,
296
+ sources=[],
297
+ label="Output",
298
+ elem_id="output-img",
299
+ format="jpg"
300
+ )
301
+
302
  gr.Examples(
303
+ examples=[
304
+ ["examples/draw_input.jpg", "examples/draw_output.jpg", "a women ,eyes closed,mouth opened"],
305
+ ["examples/draw-gimp_input.jpg", "examples/draw-gimp_output.jpg", "a women ,eyes closed,mouth opened"],
306
+ ["examples/gimp_input.jpg", "examples/gimp_output.jpg", "a women ,hand on neck"],
307
+ ["examples/inpaint_input.jpg", "examples/inpaint_output.jpg", "a women ,hand on neck"]
308
+ ],
309
+ inputs=[image, image_out, prompt],
 
310
  )
311
  gr.HTML(
312
+ gr.HTML(read_file("demo_footer.html"))
313
  )
314
  gr.on(
315
  triggers=[btn.click, prompt.submit],
316
+ fn=process_images,
317
+ inputs=[image, prompt, strength, seed, inference_step],
318
+ outputs=[image_out]
319
  )
320
 
321
+ # Register shutdown handler to clean up
322
+ import atexit
323
+ atexit.register(secure_temp.cleanup_all)
324
+
325
  if __name__ == "__main__":
326
+ # Launch with security settings
327
+ demo.launch(
328
+ share=True,
329
+ show_error=True,
330
+ ssl_verify=True,
331
+ ssl_certfile="server.crt", # You'll need to generate these
332
+ ssl_keyfile="server.key", # for production use
333
+ ssl_keyfile_password=None,
334
+ ssl_client_cert_chain="chain.pem", # Optional for client cert verification
335
+ favicon_path=None,
336
+ server_name="0.0.0.0", # Listen on all interfaces
337
+ server_port=7860, # Default Gradio port
338
+ inbrowser=False,
339
+ debug=False, # Disable in production
340
+ quiet=True, # Less logging for security
341
+ height=900,
342
+ width=1600,
343
+ enable_queue=True,
344
+ max_threads=20,
345
+ auth=None, # Enable if you need authentication
346
+ root_path=""
347
+ )