# img2img_test / app.py
# Source: Hugging Face Space by Gemini899 -- commit 1d3fc1d ("Update app.py", verified)
import os
import io
import json
import base64
import re
from PIL import Image
import numpy as np
import torch
import gradio as gr
import spaces
from diffusers import FluxImg2ImgPipeline
from cryptography.fernet import Fernet
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
# Retrieve the encryption key from the environment (set in Hugging Face Secrets Manager)
# NOTE(review): when the "key" environment variable is unset, the hard-coded
# fallback string below becomes the encryption password -- confirm this is
# acceptable outside local testing.
ENCRYPTION_KEY = os.environ.get("key", "FAKEFALLBACKKEY_FOR_LOCAL_TESTING")
# NOTE(review): `dtype` is not referenced again in the visible code; the
# pipeline call below passes torch.bfloat16 directly.
dtype = torch.bfloat16
# Use CUDA when torch reports it as available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load the FLUX.1-schnell img2img pipeline at import time and move it to `device`.
pipe = FluxImg2ImgPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    torch_dtype=torch.bfloat16
).to(device)
def generate_key(password, salt=None):
    """Derive a Fernet-compatible key from a password via PBKDF2-HMAC-SHA256.

    Args:
        password: Text password; it is encoded with ``str.encode()`` before
            being fed to the key-derivation function.
        salt: Salt bytes for the derivation. When ``None``, a fresh 16-byte
            salt is drawn from ``os.urandom``.

    Returns:
        A ``(key, salt)`` tuple. ``key`` is the URL-safe base64 encoding of the
        32 derived bytes; ``salt`` is the salt that was actually used, so the
        caller can store it and re-derive the same key later.
    """
    salt = os.urandom(16) if salt is None else salt
    derived_bytes = PBKDF2HMAC(
        algorithm=hashes.SHA256(),
        length=32,
        salt=salt,
        iterations=100000,
    ).derive(password.encode())
    return base64.urlsafe_b64encode(derived_bytes), salt
def encrypt_image(image, password=None):
    """Serialize an image as PNG and encrypt the bytes with Fernet.

    Args:
        image: Image object exposing ``save``, ``width`` and ``height``
            (a PIL image in this file).
        password: Optional password override. When ``None`` the module-level
            ``ENCRYPTION_KEY`` is used.

    Returns:
        A dict with the base64-encoded ciphertext (``'encrypted_data'``), the
        base64-encoded salt used for key derivation (``'salt'``), and the
        image's ``'original_width'`` / ``'original_height'``.
    """
    secret = ENCRYPTION_KEY if password is None else password

    # Render the image into an in-memory PNG.
    png_buffer = io.BytesIO()
    image.save(png_buffer, format='PNG')
    png_bytes = png_buffer.getvalue()

    # A fresh salt is generated for every call; it is returned with the payload.
    key, salt = generate_key(secret)
    token = Fernet(key).encrypt(png_bytes)

    payload = {}
    payload['encrypted_data'] = base64.b64encode(token).decode('utf-8')
    payload['salt'] = base64.b64encode(salt).decode('utf-8')
    payload['original_width'] = image.width
    payload['original_height'] = image.height
    return payload
def decrypt_image(encrypted_data_dict, password=None):
    """Decrypt a payload produced by ``encrypt_image`` back into an image.

    Args:
        encrypted_data_dict: Mapping with base64 text under the keys
            ``'encrypted_data'`` and ``'salt'``.
        password: Optional password override. When ``None`` the module-level
            ``ENCRYPTION_KEY`` is used.

    Returns:
        The image opened by ``Image.open`` from the decrypted bytes.
    """
    secret = password if password is not None else ENCRYPTION_KEY

    token = base64.b64decode(encrypted_data_dict['encrypted_data'])
    salt = base64.b64decode(encrypted_data_dict['salt'])

    # Re-derive the same key by reusing the stored salt.
    key, _salt_again = generate_key(secret, salt)
    plaintext = Fernet(key).decrypt(token)
    return Image.open(io.BytesIO(plaintext))
def sanitize_prompt(prompt):
    """Strip every character that is not on the prompt allow-list.

    Kept characters: ASCII letters and digits, whitespace, and the punctuation
    marks ``. , ! ? -``. Everything else is removed.

    Args:
        prompt: The raw prompt text.

    Returns:
        The prompt with all disallowed characters deleted.
    """
    return re.sub(r"[^a-zA-Z0-9\s.,!?-]", "", prompt)
def convert_to_fit_size(original_width_and_height, maximum_size=2048):
    """Scale a (width, height) pair down so neither side exceeds ``maximum_size``.

    The aspect ratio is preserved (the shorter side is rounded down). Sizes
    that already fit are returned unchanged.

    Args:
        original_width_and_height: ``(width, height)`` pair, e.g. ``image.size``.
        maximum_size: Largest allowed value for either side.

    Returns:
        ``(new_width, new_height)``. When scaling happens, the longer side is
        exactly ``int(maximum_size)`` and the shorter side is at least 1.
    """
    width, height = original_width_and_height
    if width <= maximum_size and height <= maximum_size:
        return width, height

    limit = int(maximum_size)
    # Integer arithmetic avoids the float round-trip of
    # ``int(side * (maximum / side))``, which can truncate the long side to
    # ``maximum - 1`` (e.g. a width of 3136 with the default maximum).
    # ``max(1, ...)`` keeps very elongated images from collapsing to a
    # zero-pixel side.
    if width > height:
        return limit, max(1, int(height * limit // width))
    return max(1, int(width * limit // height)), limit
def adjust_to_multiple_of_32(width: int, height: int):
    """Round both dimensions down to a multiple of 32, never below 32.

    Args:
        width: Width in pixels.
        height: Height in pixels.

    Returns:
        ``(width, height)`` where each value is the largest multiple of 32 not
        exceeding the input, or 32 when the input is smaller than 32.
    """
    def _snap(value):
        # Plain round-down maps anything below 32 to 0, which is not a usable
        # image dimension; clamp to the smallest valid multiple instead.
        return max(32, value - (value % 32))

    return _snap(width), _snap(height)
@spaces.GPU(duration=90)
def process_images(image, prompt="a girl", strength=0.75, seed=0, inference_step=4, progress=gr.Progress(track_tqdm=True)):
    """Run FLUX img2img on ``image`` and return the result plus an encrypted copy.

    Args:
        image: Input PIL image, or ``None`` when nothing was uploaded.
        prompt: Text prompt passed to the pipeline.
        strength: Img2img strength passed to the pipeline.
        seed: Seed for the torch random generator (converted to ``int``).
        inference_step: Number of inference steps (converted to ``int``).
        progress: Gradio progress tracker.

    Returns:
        ``{"display_image": <PIL image>, "encrypted_data": <dict from
        encrypt_image>}``, or ``None`` when ``image`` is ``None``.
    """
    progress(0, desc="Starting")

    def process_img2img(image, prompt="a person", strength=0.75, seed=0, num_inference_steps=4):
        if image is None:
            print("Empty input image returned")
            return None
        # Values coming from gr.Number inputs may arrive as floats, while
        # manual_seed() and the step count need integers.
        generator = torch.Generator(device).manual_seed(int(seed))
        fit_width, fit_height = convert_to_fit_size(image.size)
        # The pipeline is run at a size snapped to a multiple of 32.
        width, height = adjust_to_multiple_of_32(fit_width, fit_height)
        image = image.resize((width, height), Image.LANCZOS)
        output = pipe(
            prompt=prompt,
            image=image,
            generator=generator,
            strength=strength,
            width=width,
            height=height,
            guidance_scale=0,
            num_inference_steps=int(num_inference_steps),
            max_sequence_length=256
        )
        pil_image = output.images[0]
        new_width, new_height = pil_image.size
        # Undo the multiple-of-32 snapping so the output matches the fitted size.
        if (new_width != fit_width) or (new_height != fit_height):
            resized_image = pil_image.resize((fit_width, fit_height), Image.LANCZOS)
            return resized_image
        return pil_image

    output = process_img2img(image, prompt, strength, seed, inference_step)
    # Encrypt the output image using the secure key
    if output is not None:
        encrypted_output = encrypt_image(output)
        return {
            "display_image": output,
            "encrypted_data": encrypted_output
        }
    return None
def save_encrypted_image(encrypted_data, filename="encrypted_image.enc"):
    """Write ``encrypted_data`` to ``filename`` as JSON.

    Args:
        encrypted_data: JSON-serializable payload (e.g. the dict returned by
            ``encrypt_image``).
        filename: Destination path; an existing file is overwritten.

    Returns:
        A status message naming the file that was written.
    """
    message = f"Encrypted image saved as {filename}"
    with open(filename, 'w') as out_file:
        json.dump(encrypted_data, out_file)
    return message
def read_file(path: str) -> str:
    """Return the entire contents of the UTF-8 text file at ``path``."""
    with open(path, 'r', encoding='utf-8') as source:
        return source.read()
# Custom stylesheet handed to gr.Blocks(css=...): constrains the two layout
# columns (#col-left / #col-right), defines the example grid and thumbnail
# classes, and styles the ".encryption-notice" banner used in the output column.
css = """
#col-left {
margin: 0 auto;
max-width: 640px;
}
#col-right {
margin: 0 auto;
max-width: 640px;
}
.grid-container {
display: flex;
align-items: center;
justify-content: center;
gap:10px;
}
.image {
width: 128px;
height: 128px;
object-fit: cover;
}
.text {
font-size: 16px;
}
.encryption-notice {
background-color: #f0f0f0;
padding: 15px;
border-radius: 5px;
margin-top: 10px;
text-align: center;
}
"""
# Build the Gradio UI: an input column (image, prompt, advanced settings), an
# output column (generated image, encryption notice, save controls), an
# examples table, and the event wiring between them.
# NOTE(review): the nesting of the layout containers below was reconstructed
# from a copy of this file that had lost its indentation -- confirm the
# placement of `btn` and `id_input` against the deployed layout.
with gr.Blocks(css=css, elem_id="demo-container") as demo:
    # Store encrypted data in a state variable
    encrypted_output_state = gr.State(None)
    # Header and tools sections are loaded from HTML files next to the app.
    with gr.Column():
        gr.HTML(read_file("demo_header.html"))
        gr.HTML(read_file("demo_tools.html"))
    with gr.Row():
        # Left column: inputs.
        with gr.Column():
            image = gr.Image(
                height=800,
                sources=['upload', 'clipboard'],
                image_mode='RGB',
                elem_id="image_upload",
                type="pil",
                label="Upload"
            )
            with gr.Row(elem_id="prompt-container", equal_height=False):
                with gr.Row():
                    prompt = gr.Textbox(
                        label="Prompt",
                        value="a women",
                        placeholder="Your prompt (what you want in place of what is erased)",
                        elem_id="prompt"
                    )
            btn = gr.Button("Img2Img", elem_id="run_button", variant="primary")
            with gr.Accordion(label="Advanced Settings", open=False):
                with gr.Row(equal_height=True):
                    # The UI caps strength at 0.75, which is also its default.
                    strength = gr.Number(
                        value=0.75, minimum=0, maximum=0.75, step=0.01, label="Strength"
                    )
                    seed = gr.Number(
                        value=100, minimum=0, step=1, label="Seed"
                    )
                    inference_step = gr.Number(
                        value=4, minimum=1, step=4, label="Inference Steps"
                    )
                # Hidden field; not wired to any handler in the visible code.
                id_input = gr.Text(label="Name", visible=False)
        # Right column: outputs.
        with gr.Column():
            image_out = gr.Image(
                height=800,
                sources=[],
                label="Output (Encrypted)",
                elem_id="output-img",
                format="jpg"
            )
            encryption_notice = gr.HTML(
                '<div class="encryption-notice">'
                'The output image is encrypted. Use the Save button to download the encrypted file.'
                '</div>'
            )
            save_btn = gr.Button("Save Encrypted Image")
            save_result = gr.Text(label="Save Result")
    # Examples section
    # Each example row supplies, in order: input image, output image, prompt.
    gr.Examples(
        examples=[
            ["examples/draw_input.jpg", "examples/draw_output.jpg", "a women, eyes closed, mouth opened"],
            ["examples/draw-gimp_input.jpg", "examples/draw-gimp_output.jpg", "a women, eyes closed, mouth opened"],
            ["examples/gimp_input.jpg", "examples/gimp_output.jpg", "a women, hand on neck"],
            ["examples/inpaint_input.jpg", "examples/inpaint_output.jpg", "a women, hand on neck"]
        ],
        inputs=[image, image_out, prompt],
    )
    gr.HTML(read_file("demo_footer.html"))
    # Process images and encrypt outputs
    def handle_image_generation(image, prompt, strength, seed, inference_step):
        # Unpack the dict from process_images into the two event outputs:
        # the image to display and the encrypted payload kept in state.
        result = process_images(image, prompt, strength, seed, inference_step)
        if result:
            return result["display_image"], result["encrypted_data"]
        return None, None
    btn.click(
        fn=handle_image_generation,
        inputs=[image, prompt, strength, seed, inference_step],
        outputs=[image_out, encrypted_output_state],
        api_name="/process_images"
    )
    # Pressing Enter in the prompt box runs the same handler as the button.
    # NOTE(review): this reuses the api_name already given to btn.click --
    # confirm how Gradio resolves the duplicate endpoint name.
    prompt.submit(
        fn=handle_image_generation,
        inputs=[image, prompt, strength, seed, inference_step],
        outputs=[image_out, encrypted_output_state],
        api_name="/process_images"
    )
    def handle_save_encrypted(encrypted_data):
        # Dump the stored payload as JSON into a new temp file (suffix
        # ".encimg") and report its path; returns a fallback message when no
        # payload is stored.
        # NOTE(review): only the path string is returned to the UI -- confirm
        # whether a downloadable file was intended, as the notice text suggests.
        if encrypted_data:
            import tempfile
            fd, path = tempfile.mkstemp(suffix='.encimg')
            with os.fdopen(fd, 'w') as f:
                json.dump(encrypted_data, f)
            return f"Encrypted image saved to {path}"
        return "No encrypted image to save"
    save_btn.click(
        fn=handle_save_encrypted,
        inputs=[encrypted_output_state],
        outputs=[save_result]
    )
# Start the Gradio app only when this file is executed as a script; launch()
# is called with share=True and show_error=True.
if __name__ == "__main__":
    demo.launch(share=True, show_error=True)