import os
import io
import json
import base64
import re
import tempfile
import warnings
from PIL import Image
import numpy as np
import torch
import gradio as gr
import spaces
from diffusers import FluxImg2ImgPipeline
from cryptography.fernet import Fernet
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC

# Retrieve the encryption key from the environment (set in Hugging Face Secrets Manager)
_FALLBACK_ENCRYPTION_KEY = "FAKEFALLBACKKEY_FOR_LOCAL_TESTING"
ENCRYPTION_KEY = os.environ.get("key", _FALLBACK_ENCRYPTION_KEY)
if ENCRYPTION_KEY == _FALLBACK_ENCRYPTION_KEY:
    # The fallback is a publicly visible string; anything encrypted with it is
    # effectively unprotected, so make the misconfiguration loud.
    warnings.warn(
        "Environment variable 'key' is not set; using the insecure local-testing "
        "fallback encryption key.",
        RuntimeWarning,
    )

dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"

pipe = FluxImg2ImgPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=dtype
).to(device)


def generate_key(password, salt=None):
    """Derive a Fernet key from ``password`` with PBKDF2-HMAC-SHA256.

    Args:
        password: Secret string; it is UTF-8 encoded before derivation.
        salt: 16-byte salt. When None, a fresh random salt is generated.

    Returns:
        ``(key, salt)``: the urlsafe-base64 encoded 32-byte key and the salt
        used, which must be stored to re-derive the same key later.
    """
    if salt is None:
        salt = os.urandom(16)
    # The iteration count is not stored with the ciphertext, so changing it
    # would make previously encrypted files undecryptable.
    kdf = PBKDF2HMAC(
        algorithm=hashes.SHA256(),
        length=32,
        salt=salt,
        iterations=100000,
    )
    key = base64.urlsafe_b64encode(kdf.derive(password.encode()))
    return key, salt


def encrypt_image(image, password=None):
    """Encrypt a PIL image (serialized as PNG) with a password-derived Fernet key.

    Args:
        image: PIL image to encrypt.
        password: Override password; defaults to ``ENCRYPTION_KEY``.

    Returns:
        A JSON-serializable dict with base64 ``encrypted_data`` and ``salt``
        plus the image's ``original_width`` / ``original_height``.
    """
    # Use the secure key if no override is provided
    if password is None:
        password = ENCRYPTION_KEY

    # Convert PIL Image to bytes
    img_byte_arr = io.BytesIO()
    image.save(img_byte_arr, format='PNG')
    img_byte_arr = img_byte_arr.getvalue()

    # Generate key for encryption using the secure password
    key, salt = generate_key(password)
    cipher = Fernet(key)
    encrypted_data = cipher.encrypt(img_byte_arr)

    return {
        'encrypted_data': base64.b64encode(encrypted_data).decode('utf-8'),
        'salt': base64.b64encode(salt).decode('utf-8'),
        'original_width': image.width,
        'original_height': image.height
    }


def decrypt_image(encrypted_data_dict, password=None):
    """Reverse :func:`encrypt_image` and return the decoded PIL image.

    Args:
        encrypted_data_dict: Dict with base64 ``encrypted_data`` and ``salt``.
        password: Override password; defaults to ``ENCRYPTION_KEY``.

    Raises:
        cryptography.fernet.InvalidToken: wrong password or tampered data.
    """
    if password is None:
        password = ENCRYPTION_KEY

    # Extract the encrypted data and salt
    encrypted_data = base64.b64decode(encrypted_data_dict['encrypted_data'])
    salt = base64.b64decode(encrypted_data_dict['salt'])

    # Regenerate the key using the secure password and salt
    key, _ = generate_key(password, salt)
    cipher = Fernet(key)

    decrypted_data = cipher.decrypt(encrypted_data)
    image = Image.open(io.BytesIO(decrypted_data))

    return image


def sanitize_prompt(prompt):
    """Strip every character except ASCII letters, digits, whitespace and ``.,!?-``."""
    # NOTE(review): not called anywhere in this file; prompts reach the pipeline unsanitized.
    # Allow only alphanumeric characters, spaces, and basic punctuation
    allowed_chars = re.compile(r"[^a-zA-Z0-9\s.,!?-]")
    sanitized_prompt = allowed_chars.sub("", prompt)
    return sanitized_prompt


def convert_to_fit_size(original_width_and_height, maximum_size=2048):
    """Scale ``(width, height)`` down so neither side exceeds ``maximum_size``.

    Sizes already within the limit are returned unchanged. Otherwise the
    longer side becomes exactly ``maximum_size`` and the shorter side is
    scaled proportionally (floored, but never below 1).
    """
    width, height = original_width_and_height
    if width <= maximum_size and height <= maximum_size:
        return width, height

    # Integer floor arithmetic: a float scale factor followed by int() can
    # land on e.g. 2047.9999 and silently drop a pixel from the long side.
    if width > height:
        return maximum_size, max(1, height * maximum_size // width)
    return max(1, width * maximum_size // height), maximum_size


def adjust_to_multiple_of_32(width: int, height: int):
    """Round each dimension down to a multiple of 32, with a floor of 32.

    The floor prevents a zero-sized request for very thin images; callers
    resize the generated image back to the fit size afterwards.
    """
    width = max(32, width - (width % 32))
    height = max(32, height - (height % 32))
    return width, height


@spaces.GPU(duration=90)
def process_images(image, prompt="a girl", strength=0.75, seed=0, inference_step=4, progress=gr.Progress(track_tqdm=True)):
    """Run FLUX img2img on ``image`` and return the result with an encrypted copy.

    Returns:
        ``{"display_image": PIL image, "encrypted_data": dict from encrypt_image}``,
        or None when no input image was supplied.
    """
    progress(0, desc="Starting")

    def process_img2img(image, prompt="a person", strength=0.75, seed=0, num_inference_steps=4):
        if image is None:
            print("Empty input image returned")
            return None

        # gr.Number delivers floats (or None when cleared); Generator.manual_seed
        # and the scheduler's step count both require ints.
        seed = 0 if seed is None else int(seed)
        num_inference_steps = int(num_inference_steps)

        generator = torch.Generator(device).manual_seed(seed)
        fit_width, fit_height = convert_to_fit_size(image.size)
        width, height = adjust_to_multiple_of_32(fit_width, fit_height)
        image = image.resize((width, height), Image.LANCZOS)

        output = pipe(
            prompt=prompt,
            image=image,
            generator=generator,
            strength=strength,
            width=width,
            height=height,
            guidance_scale=0,
            num_inference_steps=num_inference_steps,
            max_sequence_length=256
        )

        pil_image = output.images[0]
        # Undo the multiple-of-32 rounding so the output matches the fit size.
        new_width, new_height = pil_image.size
        if (new_width != fit_width) or (new_height != fit_height):
            resized_image = pil_image.resize((fit_width, fit_height), Image.LANCZOS)
            return resized_image
        return pil_image

    output = process_img2img(image, prompt, strength, seed, inference_step)

    # Encrypt the output image using the secure key
    if output is not None:
        encrypted_output = encrypt_image(output)
        return {
            "display_image": output,
            "encrypted_data": encrypted_output
        }
    return None


def save_encrypted_image(encrypted_data, filename="encrypted_image.enc"):
    """Write the encrypted-image dict to ``filename`` as JSON and return a status message."""
    with open(filename, 'w') as f:
        json.dump(encrypted_data, f)
    return f"Encrypted image saved as {filename}"


def read_file(path: str) -> str:
    """Return the UTF-8 text content of ``path``."""
    with open(path, 'r', encoding='utf-8') as f:
        content = f.read()
    return content


css = """
#col-left {
    margin: 0 auto;
    max-width: 640px;
}
#col-right {
    margin: 0 auto;
    max-width: 640px;
}
.grid-container {
    display: flex;
    align-items: center;
    justify-content: center;
    gap:10px;
}
.image {
    width: 128px;
    height: 128px;
    object-fit: cover;
}
.text {
    font-size: 16px;
}
.encryption-notice {
    background-color: #f0f0f0;
    padding: 15px;
    border-radius: 5px;
    margin-top: 10px;
    text-align: center;
}
"""

with gr.Blocks(css=css, elem_id="demo-container") as demo:
    # Store encrypted data in a state variable
    encrypted_output_state = gr.State(None)

    with gr.Column():
        gr.HTML(read_file("demo_header.html"))
        gr.HTML(read_file("demo_tools.html"))

    with gr.Row():
        with gr.Column():
            image = gr.Image(
                height=800,
                sources=['upload', 'clipboard'],
                image_mode='RGB',
                elem_id="image_upload",
                type="pil",
                label="Upload"
            )
            with gr.Row(elem_id="prompt-container", equal_height=False):
                with gr.Row():
                    prompt = gr.Textbox(
                        label="Prompt",
                        value="a women",
                        placeholder="Your prompt (what you want in place of what is erased)",
                        elem_id="prompt"
                    )
                    btn = gr.Button("Img2Img", elem_id="run_button", variant="primary")

            with gr.Accordion(label="Advanced Settings", open=False):
                with gr.Row(equal_height=True):
                    strength = gr.Number(
                        value=0.75, minimum=0, maximum=0.75, step=0.01, label="Strength"
                    )
                    seed = gr.Number(
                        value=100, minimum=0, step=1, label="Seed"
                    )
                    inference_step = gr.Number(
                        value=4, minimum=1, step=4, label="Inference Steps"
                    )
                id_input = gr.Text(label="Name", visible=False)

        with gr.Column():
            # NOTE(review): this component shows the plain (unencrypted) result;
            # only the State/saved file is encrypted despite the label.
            image_out = gr.Image(
                height=800,
                sources=[],
                label="Output (Encrypted)",
                elem_id="output-img",
                format="jpg"
            )
            # Wrapper div uses the .encryption-notice style defined in css above.
            encryption_notice = gr.HTML(
                '<div class="encryption-notice">'
                'The output image is encrypted. Use the Save button to download the encrypted file.'
                '</div>'
            )
            save_btn = gr.Button("Save Encrypted Image")
            save_result = gr.Text(label="Save Result")

    # Examples section
    gr.Examples(
        examples=[
            ["examples/draw_input.jpg", "examples/draw_output.jpg", "a women, eyes closed, mouth opened"],
            ["examples/draw-gimp_input.jpg", "examples/draw-gimp_output.jpg", "a women, eyes closed, mouth opened"],
            ["examples/gimp_input.jpg", "examples/gimp_output.jpg", "a women, hand on neck"],
            ["examples/inpaint_input.jpg", "examples/inpaint_output.jpg", "a women, hand on neck"]
        ],
        inputs=[image, image_out, prompt],
    )
    gr.HTML(read_file("demo_footer.html"))

    # Process images and encrypt outputs
    def handle_image_generation(image, prompt, strength, seed, inference_step):
        """Run generation and split the result into (display image, encrypted dict)."""
        result = process_images(image, prompt, strength, seed, inference_step)
        if result:
            return result["display_image"], result["encrypted_data"]
        return None, None

    btn.click(
        fn=handle_image_generation,
        inputs=[image, prompt, strength, seed, inference_step],
        outputs=[image_out, encrypted_output_state],
        api_name="/process_images"
    )
    # NOTE(review): reusing the same api_name makes Gradio register a second,
    # suffixed endpoint for this trigger; consider api_name=False here.
    prompt.submit(
        fn=handle_image_generation,
        inputs=[image, prompt, strength, seed, inference_step],
        outputs=[image_out, encrypted_output_state],
        api_name="/process_images"
    )

    def handle_save_encrypted(encrypted_data):
        """Dump the encrypted dict to a new temp file and report its path.

        NOTE(review): the file is written on the server; the returned text is
        only a path, not a browser download.
        """
        if encrypted_data:
            fd, path = tempfile.mkstemp(suffix='.encimg')
            with os.fdopen(fd, 'w') as f:
                json.dump(encrypted_data, f)
            return f"Encrypted image saved to {path}"
        return "No encrypted image to save"

    save_btn.click(
        fn=handle_save_encrypted,
        inputs=[encrypted_output_state],
        outputs=[save_result]
    )

if __name__ == "__main__":
    demo.launch(share=True, show_error=True)