Spaces:
Sleeping
Sleeping
import os | |
import gradio as gr | |
import base64 | |
from random import randint | |
from all_models import models | |
from io import BytesIO | |
from PIL import Image | |
from fastapi import FastAPI, Request | |
from deep_translator import GoogleTranslator | |
import tempfile | |
import uuid | |
# Optional custom CSS for the UI, read from the environment (None when unset).
css_code = os.getenv("DazDinGo_CSS")

# Translator used to turn prompts in any language into English.
translator = GoogleTranslator(source='auto', target='en')

# Load every configured model. Loading is best-effort: a model that cannot be
# loaded is replaced by a placeholder interface that always yields no image,
# so the rest of the app keeps working.
models_load = {}
for model in models:
    try:
        models_load[model] = gr.load(f'models/{model}')
    except Exception as error:
        # Report the failure instead of hiding it; the placeholder alone gives
        # no hint about why a model never produces output.
        print(f"Could not load model '{model}': {error!r}")
        models_load[model] = gr.Interface(lambda txt: None, ['text'], ['image'])

app = FastAPI()
def convert_to_png(image):
    """Re-encode *image* as PNG in memory and return the reopened image.

    RGBA and RGB images are encoded as-is; any other mode is converted to RGB
    before encoding.
    """
    if image.mode not in ('RGBA', 'RGB'):
        image = image.convert('RGB')

    encoded = BytesIO()
    image.save(encoded, format='PNG', optimize=True)
    # Rewind so the PNG data can be read back from the start.
    encoded.seek(0)
    return Image.open(encoded)
def gen_image(model_str, prompt):
    """Generate an image for *prompt* with the selected model and save it as PNG.

    The prompt is translated to English, then a random seed and a fixed
    quality suffix are appended before the text is passed to the loaded model.

    Args:
        model_str: Key into ``models_load``; ``'NA'`` means "no model".
        prompt: The user's prompt text.

    Returns:
        ``(png_path, seed)`` on success, or ``(None, None)`` when no model is
        selected, the model is not loaded, or the model produced no image.
    """
    # The model dropdown falls back to None when the model list is empty, so
    # an unknown key must not raise KeyError here.
    if model_str == 'NA' or model_str not in models_load:
        return None, None

    # Skip translation when there is no text: nothing to translate, and it
    # avoids passing a non-string or blank value to the translator.
    has_text = isinstance(prompt, str) and prompt.strip()
    translated_prompt = translator.translate(prompt) if has_text else ''

    # randint() includes both endpoints, so the upper bound is 2**32 - 1 to
    # keep the seed within the unsigned 32-bit range.
    seed = randint(0, 2**32 - 1)
    noise = str(seed)
    klir = '| ultra detail, ultra elaboration, ultra quality, perfect'

    # Generate image
    generated_image = models_load[model_str](f'{translated_prompt} {noise} {klir}')
    if generated_image is None:
        return None, None

    # Normalize the result to a PIL image: a string is treated as a file path,
    # anything else that is not already a PIL image is treated as an array.
    if isinstance(generated_image, str):
        generated_image = Image.open(generated_image)
    elif not isinstance(generated_image, Image.Image):
        generated_image = Image.fromarray(generated_image)

    # Save under a unique name in a dedicated folder of the system temp dir.
    temp_dir = os.path.join(tempfile.gettempdir(), "gradio_images")
    os.makedirs(temp_dir, exist_ok=True)
    temp_path = os.path.join(temp_dir, f"{uuid.uuid4()}.png")
    generated_image.save(temp_path, format="PNG")

    return temp_path, seed
# Gradio Interface | |
def make_me():
    """Build the generator UI and wire its events.

    Must be called inside an active ``gr.Blocks`` context. The left column
    holds the prompt box, the Generate/Stop buttons and the model dropdown;
    the right column holds the result gallery and the seed that was used.
    """
    with gr.Row():
        with gr.Column(scale=3):
            txt_input = gr.Textbox(
                label='Your prompt:',
                lines=4,
                container=False,
                elem_id="custom_textbox",
                placeholder="Enter your prompt here..."
            )
            with gr.Row():
                gen_button = gr.Button('Generate', variant='primary', elem_id="custom_gen_button")
                stop_button = gr.Button('Stop', variant='secondary', interactive=False,
                                        elem_id="custom_stop_button")

            model_dropdown = gr.Dropdown(
                models,
                label="Select Model",
                value=models[0] if models else None
            )

        with gr.Column(scale=2):
            output_gallery = gr.Gallery(
                label="Generated PNG Images",
                columns=2,
                height="auto",
                elem_id="gallery"
            )
            seed_output = gr.Textbox(
                label="Seed used",
                interactive=False
            )

    def on_generate_click():
        """Switch the buttons to the "busy" state."""
        return gr.Button('Generating...', interactive=False), gr.Button('Stop', interactive=True)

    def on_stop_click():
        """Switch the buttons back to the idle state."""
        return gr.Button('Generate', interactive=True), gr.Button('Stop', interactive=False)

    def generate_wrapper(model_str, prompt):
        """Adapt ``gen_image`` output to the gallery and seed textbox."""
        image_path, seed = gen_image(model_str, prompt)
        if image_path is None:
            return None, ""
        return [image_path], str(seed)

    # Clicking Generate immediately flips the buttons to the busy state ...
    gen_button.click(on_generate_click, None, [gen_button, stop_button])

    # ... and starts the actual generation.
    gen_event = gen_button.click(
        generate_wrapper,
        [model_dropdown, txt_input],
        [output_gallery, seed_output]
    )

    # Restore the idle button state once generation has finished; without this
    # the Generate button stays disabled until the user presses Stop.
    gen_event.then(on_stop_click, None, [gen_button, stop_button])

    # Stop cancels a running generation and restores the idle button state.
    # Registered once: a second identical listener would only run the reset twice.
    stop_button.click(
        on_stop_click,
        None,
        [gen_button, stop_button],
        cancels=[gen_event]
    )
# Assemble the Gradio UI: a heading, a short instruction line and the
# generator controls built by make_me().
demo = gr.Blocks(css=css_code, title="Image Generator")
with demo:
    gr.Markdown("# Image Generation Tool")
    gr.Markdown("Enter your prompt and select a model to generate an image")
    make_me()

# The queue is enabled before the app is mounted.
demo.queue()

# Serve the Gradio UI from the root path of the FastAPI application.
app = gr.mount_gradio_app(app, demo, path="/")

if __name__ == "__main__":
    # Imported here because it is only needed when running as a script.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)