import os
import gradio as gr
import base64
from random import randint
from all_models import models
from io import BytesIO
from PIL import Image
from fastapi import FastAPI, Request
from deep_translator import GoogleTranslator

css_code = """
#custom_textbox {
    width: 100%;
    min-height: 150px;
}
#custom_gen_button {
    background: #4CAF50 !important;
    color: white !important;
}
#custom_stop_button {
    background: #F44336 !important;
    color: white !important;
}
#custom_image {
    width: 100%;
    max-height: 768px;
}
"""

# Initialize translator
translator = GoogleTranslator(source='auto', target='en')

# Load models
models_load = {}
for model in models:
    try:
        models_load[model] = gr.load(f'models/{model}')
    except Exception as error:
        print(f"Could not load 'models/{model}': {error}")
        # Fall back to a stub interface so the rest of the app keeps working
        models_load[model] = gr.Interface(lambda txt: None, ['text'], ['image'])

app = FastAPI()

def gen_image(model_str, prompt):
    if model_str == 'NA':
        return None
    # Translate the prompt to English, then append quality keywords (klir) and
    # a random noise token so identical prompts still yield varied results
    translated_prompt = translator.translate(prompt)
    noise = str(randint(0, 4294967296))
    klir = '| ultra detail, ultra elaboration, ultra quality, perfect'
    return models_load[model_str](f'{translated_prompt} {klir} {noise}')

def image_to_base64(image):
    buffered = BytesIO()
    if isinstance(image, str):  # if it's a file path
        img = Image.open(image)
        img.convert("RGB").save(buffered, format="JPEG")  # JPEG cannot store alpha
    else:  # if it's a PIL Image
        image.convert("RGB").save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode()
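
# Illustrative counterpart (an assumption, not used by the app itself): a client
# holding the "image_base64" field from a response can reverse image_to_base64
# like this.
def base64_to_image(base64_str):
    """Decode a base64-encoded JPEG string back into a PIL Image."""
    return Image.open(BytesIO(base64.b64decode(base64_str)))
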
# API endpoint
# NOTE: the route path "/api/generate" is an assumption; adjust it to match
# your deployment. It is registered before the Gradio mount so it takes
# priority over the root mount below.
@app.post("/api/generate")
async def api_generate(request: Request):
    data = await request.json()
    model = data.get('model', models[0])
    prompt = data.get('prompt', '')
    if model not in models:
        return {"error": "Model not found"}
    # Translate the prompt to English for the response payload; gen_image
    # translates again internally, which is a no-op on English text
    translated_prompt = translator.translate(prompt)
    image = gen_image(model, translated_prompt)
    if image is None:
        return {"error": "Image generation failed"}
    base64_str = image_to_base64(image)
    return {
        "status": "success",
        "model": model,
        "original_prompt": prompt,
        "translated_prompt": translated_prompt,
        "image_base64": base64_str,
        "image_format": "jpeg"
    }
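
# Example request against the endpoint above (illustrative; assumes the server
# is running locally on port 7860 and the assumed "/api/generate" route, and
# uses a placeholder model name). Run it from a separate process, e.g. with
# the `requests` package:
#
#   import requests
#   resp = requests.post(
#       "http://localhost:7860/api/generate",
#       json={"model": "someuser/some-model", "prompt": "a watercolor fox"},
#   )
#   print(resp.json().get("status"))
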
# Gradio Interface
def make_me():
    with gr.Row():
        # Left Column (50% width)
        with gr.Column(scale=1, min_width=400):
            txt_input = gr.Textbox(
                label='Your prompt:',
                lines=4,
                container=False,
                elem_id="custom_textbox",
                placeholder="Enter your prompt here..."
            )
            model_dropdown = gr.Dropdown(
                models,
                label="Select LoRA Model",
                value=models[0] if models else None,
                container=False
            )
            with gr.Row():
                gen_button = gr.Button(
                    'Generate Image',
                    elem_id="custom_gen_button",
                    variant='primary'
                )
                stop_button = gr.Button(
                    'Stop',
                    variant='stop',
                    elem_id="custom_stop_button",
                    interactive=False
                )
        # Right Column (50% width)
        with gr.Column(scale=1, min_width=400):
            output_image = gr.Image(
                label="Generated Image",
                elem_id="custom_image",
                show_label=True,
                interactive=False
            )
            json_output = gr.JSON(
                label="API Response",
                container=False
            )
    # UI wrapper: same flow as the API endpoint, returning both the image and
    # the JSON response for display
    def generate_wrapper(model_str, prompt):
        # Translate prompt to English for the response payload
        translated_prompt = translator.translate(prompt)
        image = gen_image(model_str, translated_prompt)
        if image is None:
            return None, {"error": "Generation failed"}
        base64_str = image_to_base64(image)
        response = {
            "status": "success",
            "model": model_str,
            "original_prompt": prompt,
            "translated_prompt": translated_prompt,
            "image_base64": base64_str,
            "image_format": "jpeg"
        }
        return image, response
    def on_generate_click():
        return gr.Button(interactive=False), gr.Button(interactive=True)

    def on_stop_click():
        return gr.Button(interactive=True), gr.Button(interactive=False)

    gen_event = gen_button.click(
        on_generate_click,
        inputs=None,
        outputs=[gen_button, stop_button]
    ).then(
        generate_wrapper,
        [model_dropdown, txt_input],
        [output_image, json_output]
    ).then(
        on_stop_click,
        inputs=None,
        outputs=[gen_button, stop_button]
    )

    stop_button.click(
        on_stop_click,
        inputs=None,
        outputs=[gen_button, stop_button],
        cancels=[gen_event]
    )

# Create Gradio app
with gr.Blocks(css=css_code, title="Image Generation App") as demo:
    gr.Markdown("# Image Generation Tool")
    make_me()

# Enable queue before mounting
demo.queue()

# Mount Gradio app to FastAPI
app = gr.mount_gradio_app(app, demo, path="/")

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)