import os
import io
import json
import base64
import re

from PIL import Image
import numpy as np
import torch
import gradio as gr
import spaces
from diffusers import FluxImg2ImgPipeline
from cryptography.fernet import Fernet
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC

# Retrieve the encryption key from the environment (set in Hugging Face Secrets Manager)
ENCRYPTION_KEY = os.environ.get("key", "FAKEFALLBACKKEY_FOR_LOCAL_TESTING")

dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"

pipe = FluxImg2ImgPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=dtype
).to(device)


def generate_key(password, salt=None):
    # Derive a 32-byte Fernet key from the password via PBKDF2-HMAC-SHA256
    if salt is None:
        salt = os.urandom(16)
    kdf = PBKDF2HMAC(
        algorithm=hashes.SHA256(),
        length=32,
        salt=salt,
        iterations=100000,
    )
    key = base64.urlsafe_b64encode(kdf.derive(password.encode()))
    return key, salt


def encrypt_image(image, password=None):
    # Use the secure key if no override is provided
    if password is None:
        password = ENCRYPTION_KEY
    # Convert PIL Image to bytes
    img_byte_arr = io.BytesIO()
    image.save(img_byte_arr, format='PNG')
    img_byte_arr = img_byte_arr.getvalue()
    # Generate key for encryption using the secure password
    key, salt = generate_key(password)
    cipher = Fernet(key)
    encrypted_data = cipher.encrypt(img_byte_arr)
    return {
        'encrypted_data': base64.b64encode(encrypted_data).decode('utf-8'),
        'salt': base64.b64encode(salt).decode('utf-8'),
        'original_width': image.width,
        'original_height': image.height,
    }


def decrypt_image(encrypted_data_dict, password=None):
    if password is None:
        password = ENCRYPTION_KEY
    # Extract the encrypted data and salt
    encrypted_data = base64.b64decode(encrypted_data_dict['encrypted_data'])
    salt = base64.b64decode(encrypted_data_dict['salt'])
    # Regenerate the key using the secure password and salt
    key, _ = generate_key(password, salt)
    cipher = Fernet(key)
    decrypted_data = cipher.decrypt(encrypted_data)
    image = Image.open(io.BytesIO(decrypted_data))
    return image
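# A minimal round-trip sketch for the two helpers above (an assumption for
# local testing with a throwaway password, not part of the app flow):
#
#     img = Image.new("RGB", (64, 64), color="red")
#     payload = encrypt_image(img, password="test-password")
#     restored = decrypt_image(payload, password="test-password")
#     assert restored.size == (payload['original_width'], payload['original_height'])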
def sanitize_prompt(prompt):
    # Allow only alphanumeric characters, spaces, and basic punctuation
    allowed_chars = re.compile(r"[^a-zA-Z0-9\s.,!?-]")
    sanitized_prompt = allowed_chars.sub("", prompt)
    return sanitized_prompt


def convert_to_fit_size(original_width_and_height, maximum_size=2048):
    # Downscale so the longer side fits within maximum_size, preserving aspect ratio
    width, height = original_width_and_height
    if width <= maximum_size and height <= maximum_size:
        return width, height
    if width > height:
        scaling_factor = maximum_size / width
    else:
        scaling_factor = maximum_size / height
    new_width = int(width * scaling_factor)
    new_height = int(height * scaling_factor)
    return new_width, new_height


def adjust_to_multiple_of_32(width: int, height: int):
    # Round down so both dimensions align with the model's latent grid
    width = width - (width % 32)
    height = height - (height % 32)
    return width, height


@spaces.GPU(duration=90)
def process_images(image, prompt="a girl", strength=0.75, seed=0, inference_step=4,
                   progress=gr.Progress(track_tqdm=True)):
    progress(0, desc="Starting")

    def process_img2img(image, prompt="a person", strength=0.75, seed=0, num_inference_steps=4):
        if image is None:
            print("Empty input image returned")
            return None
        generator = torch.Generator(device).manual_seed(seed)
        fit_width, fit_height = convert_to_fit_size(image.size)
        width, height = adjust_to_multiple_of_32(fit_width, fit_height)
        image = image.resize((width, height), Image.LANCZOS)
        output = pipe(
            prompt=prompt,
            image=image,
            generator=generator,
            strength=strength,
            width=width,
            height=height,
            guidance_scale=0,
            num_inference_steps=num_inference_steps,
            max_sequence_length=256,
        )
        pil_image = output.images[0]
        new_width, new_height = pil_image.size
        # Resize back to the pre-alignment dimensions if the model changed them
        if (new_width != fit_width) or (new_height != fit_height):
            resized_image = pil_image.resize((fit_width, fit_height), Image.LANCZOS)
            return resized_image
        return pil_image

    output = process_img2img(image, prompt, strength, seed, inference_step)
    # Encrypt the output image using the secure key
    if output is not None:
        encrypted_output = encrypt_image(output)
        return {
            "display_image": output,
            "encrypted_data": encrypted_output,
        }
    return None


def save_encrypted_image(encrypted_data, filename="encrypted_image.enc"):
    with open(filename, 'w') as f:
        json.dump(encrypted_data, f)
    return f"Encrypted image saved as {filename}"


def read_file(path: str) -> str:
    with open(path, 'r', encoding='utf-8') as f:
        content = f.read()
    return content


css = """
#col-left {
    margin: 0 auto;
    max-width: 640px;
}
#col-right {
    margin: 0 auto;
    max-width: 640px;
}
.grid-container {
    display: flex;
    align-items: center;
    justify-content: center;
    gap: 10px;
}
.image {
    width: 128px;
    height: 128px;
    object-fit: cover;
}
.text {
    font-size: 16px;
}
.encryption-notice {
    background-color: #f0f0f0;
    padding: 15px;
    border-radius: 5px;
    margin-top: 10px;
    text-align: center;
}
"""

with gr.Blocks(css=css, elem_id="demo-container") as demo:
    # Store encrypted data in a state variable
    encrypted_output_state = gr.State(None)
    with gr.Column():
        gr.HTML(read_file("demo_header.html"))
        gr.HTML(read_file("demo_tools.html"))
    with gr.Row():
        with gr.Column():
            image = gr.Image(
                height=800,
                sources=['upload', 'clipboard'],
                image_mode='RGB',
                elem_id="image_upload",
                type="pil",
                label="Upload",
            )
            with gr.Row(elem_id="prompt-container", equal_height=False):
                with gr.Row():
                    prompt = gr.Textbox(
                        label="Prompt",
                        value="a woman",
                        placeholder="Your prompt (what you want in place of what is erased)",
                        elem_id="prompt",
                    )
                    btn = gr.Button("Img2Img", elem_id="run_button", variant="primary")
            with gr.Accordion(label="Advanced Settings", open=False):
                with gr.Row(equal_height=True):
                    strength = gr.Number(
                        value=0.75, minimum=0, maximum=0.75, step=0.01,
                        label="Strength",
                    )
                    seed = gr.Number(
                        value=100, minimum=0, step=1,
                        label="Seed",
                    )
                    inference_step = gr.Number(
                        value=4, minimum=1, step=4,
                        label="Inference Steps",
                    )
                id_input = gr.Text(label="Name", visible=False)
        with gr.Column():
            image_out = gr.Image(
                height=800,
                sources=[],
                label="Output (Encrypted)",
                elem_id="output-img",
                format="jpg",
            )
            encryption_notice = gr.HTML(
                # Notice text is assumed; the original snippet truncates inside this call
                '<div class="encryption-notice">Output images are encrypted before storage.</div>'
            )
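# Offline decryption sketch for a file written by save_encrypted_image
# (an assumption: it presumes access to the same ENCRYPTION_KEY that produced
# the file; "encrypted_image.enc" is just that helper's default filename):
#
#     with open("encrypted_image.enc", "r") as f:
#         payload = json.load(f)
#     restored = decrypt_image(payload)
#     restored.save("restored.png")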