# Qwen-Qwen-Image Space — app.py
# Last commit b9acbf8 ("Update app.py") by hari7261 (verified); original size 12.1 kB.
import random
from time import perf_counter, time

import gradio as gr
import numpy as np
import torch
from diffusers import AutoPipelineForText2Image
device = "cuda" if torch.cuda.is_available() else "cpu"
model_repo_id = "stabilityai/sdxl-turbo"

# Load the text-to-image pipeline. On GPU use the fp16 weight variant plus
# memory/speed optimizations; on CPU load full-precision weights.
if device == "cuda":
    torch_dtype = torch.float16
    pipe = AutoPipelineForText2Image.from_pretrained(
        model_repo_id,
        torch_dtype=torch_dtype,
        variant="fp16",
        use_safetensors=True,
    )
    # Keep sub-models on the CPU until they are needed, to lower peak GPU memory.
    pipe.enable_model_cpu_offload()
    # xformers is an optional dependency and diffusers raises when it is missing
    # or unusable on this setup. Treat it as a best-effort speedup rather than a
    # reason to crash at startup; the default attention implementation still works.
    try:
        pipe.enable_xformers_memory_efficient_attention()
    except (ImportError, ValueError, RuntimeError) as err:
        print(f"xformers attention unavailable, using default attention: {err}")
else:
    torch_dtype = torch.float32
    pipe = AutoPipelineForText2Image.from_pretrained(model_repo_id, torch_dtype=torch_dtype)

# Largest seed the UI offers and that `infer` draws when randomizing (int32 max).
MAX_SEED = np.iinfo(np.int32).max
# Upper bound for the width/height sliders, in pixels.
MAX_IMAGE_SIZE = 1024
def infer(
    prompt,
    negative_prompt,
    seed,
    randomize_seed,
    width,
    height,
    guidance_scale,
    num_inference_steps,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate one image from a text prompt with the module-level pipeline.

    Args:
        prompt: Text describing the desired image.
        negative_prompt: Text describing what should not appear.
        seed: Seed for the random generator; ignored when ``randomize_seed`` is true.
        randomize_seed: If true, draw a fresh seed uniformly from ``[0, MAX_SEED]``.
        width: Output width in pixels.
        height: Output height in pixels.
        guidance_scale: Guidance scale forwarded to the pipeline.
        num_inference_steps: Number of denoising steps forwarded to the pipeline.
        progress: Gradio progress tracker; ``track_tqdm=True`` lets Gradio follow
            tqdm progress bars created during the call.

    Returns:
        A tuple ``(image, seed_used, timing_message)``.
    """
    # perf_counter is monotonic, so the measured duration cannot be skewed by
    # system clock adjustments the way time.time() can.
    start_time = perf_counter()
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # Slider values and API callers may deliver numbers as floats, but
    # Generator.manual_seed and the pipeline's size/step arguments require ints.
    seed = int(seed)
    generator = torch.Generator(device=device).manual_seed(seed)
    with torch.no_grad():
        image = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            guidance_scale=float(guidance_scale),
            num_inference_steps=int(num_inference_steps),
            width=int(width),
            height=int(height),
            generator=generator,
        ).images[0]
    gen_time = perf_counter() - start_time
    return image, seed, f"Generated in {gen_time:.2f}s"
# Example rows as (prompt, width, height); each row is converted to a list so it
# can be handed directly to gr.Examples.
examples = [
    list(row)
    for row in (
        ("Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", 1024, 1024),
        ("An astronaut riding a green horse", 768, 768),
        ("A delicious ceviche cheesecake slice", 896, 896),
    )
]
# Custom stylesheet passed to gr.Blocks(css=...).
# - :root defines the light palette: the --primary-* hex scale, --surface-*
#   values written as space-separated RGB triplets (consumed via
#   rgb(var(--surface-N))), and the --text-* colors.
# - .dark reverses the --primary-* scale and swaps the --text-* colors; it
#   matches the `dark` class that the theme-toggle script toggles on <body>.
#   --surface-* is not redefined, so dark-mode rules pick darker surface
#   steps explicitly.
# - Remaining rules style the containers referenced via elem_id/elem_classes in
#   the UI (#col-container, .header, .prompt-container, .result-container,
#   .advanced-settings, .btn-primary, .examples) and the injected .theme-toggle.
# NOTE(review): .control-row, .example-prompt, .example-img and .loading-spinner
# are not referenced anywhere else in this file; they may be dead rules.
# Confirm before removing.
css = """
:root {
--primary-50: #f0f0ff;
--primary-100: #e0e0ff;
--primary-200: #c7c7fe;
--primary-300: #a5a5fc;
--primary-400: #8181f8;
--primary-500: #6e6af0;
--primary-600: #5a56e4;
--primary-700: #4a46c9;
--primary-800: #3e3ba3;
--primary-900: #383682;
--primary-950: #211f4d;
--surface-0: 255 255 255;
--surface-50: 248 250 252;
--surface-100: 241 245 249;
--surface-200: 226 232 240;
--surface-300: 203 213 225;
--surface-400: 148 163 184;
--surface-500: 100 116 139;
--surface-600: 71 85 105;
--surface-700: 45 55 72;
--surface-800: 30 41 59;
--surface-900: 15 23 42;
--surface-950: 3 6 23;
--text-primary: rgb(var(--surface-900));
--text-secondary: rgb(var(--surface-600));
}
.dark {
--primary-50: #211f4d;
--primary-100: #383682;
--primary-200: #3e3ba3;
--primary-300: #4a46c9;
--primary-400: #5a56e4;
--primary-500: #6e6af0;
--primary-600: #8181f8;
--primary-700: #a5a5fc;
--primary-800: #c7c7fe;
--primary-900: #e0e0ff;
--primary-950: #f0f0ff;
--text-primary: rgb(var(--surface-100));
--text-secondary: rgb(var(--surface-300));
}
#col-container {
margin: 0 auto;
max-width: 800px;
padding: 20px;
}
.header {
text-align: center;
margin-bottom: 20px;
}
.header h1 {
font-size: 2.5rem;
font-weight: 700;
color: var(--primary-500);
margin-bottom: 10px;
}
.header p {
color: var(--text-secondary);
}
.prompt-container, .result-container, .advanced-settings {
background-color: rgb(var(--surface-50));
border-radius: 12px;
padding: 20px;
box-shadow: 0 4px 12px rgba(0, 0, 0, 0.1);
margin-bottom: 20px;
border: 1px solid rgb(var(--surface-200));
}
.dark .prompt-container,
.dark .result-container,
.dark .advanced-settings {
background-color: rgb(var(--surface-800));
border-color: rgb(var(--surface-700));
}
.advanced-settings .form {
display: grid;
grid-template-columns: 1fr 1fr;
gap: 16px;
}
.advanced-settings .form > * {
margin-bottom: 0 !important;
}
.control-row {
display: flex;
gap: 10px;
align-items: center;
}
.btn-primary {
background: var(--primary-500) !important;
border: none !important;
color: white !important;
}
.btn-primary:hover {
background: var(--primary-600) !important;
}
.examples {
display: grid;
grid-template-columns: repeat(auto-fill, minmax(250px, 1fr));
gap: 12px;
margin-top: 20px;
}
.example-prompt {
background: rgb(var(--surface-100));
padding: 12px;
border-radius: 8px;
cursor: pointer;
transition: all 0.2s;
border: 1px solid rgb(var(--surface-200));
}
.dark .example-prompt {
background: rgb(var(--surface-700));
border-color: rgb(var(--surface-600));
}
.example-prompt:hover {
background: var(--primary-100);
transform: translateY(-2px);
border-color: var(--primary-300);
}
.dark .example-prompt:hover {
background: var(--primary-800);
border-color: var(--primary-600);
}
.example-img {
width: 100%;
height: 120px;
object-fit: cover;
border-radius: 6px;
margin-bottom: 8px;
}
/* Theme toggle button */
.theme-toggle {
position: absolute;
top: 20px;
right: 20px;
background: var(--primary-100);
border: none;
border-radius: 50%;
width: 40px;
height: 40px;
display: flex;
align-items: center;
justify-content: center;
cursor: pointer;
transition: all 0.2s;
}
.theme-toggle:hover {
background: var(--primary-200);
}
.dark .theme-toggle {
background: var(--primary-800);
}
.dark .theme-toggle:hover {
background: var(--primary-700);
}
@media (max-width: 768px) {
.advanced-settings .form {
grid-template-columns: 1fr;
}
.theme-toggle {
top: 10px;
right: 10px;
}
}
/* Loading animation */
@keyframes spin {
0% { transform: rotate(0deg); }
100% { transform: rotate(360deg); }
}
.loading-spinner {
display: inline-block;
width: 20px;
height: 20px;
border: 3px solid rgba(255, 255, 255, 0.3);
border-radius: 50%;
border-top-color: white;
animation: spin 1s ease-in-out infinite;
margin-right: 8px;
}
"""
# Client-side script passed to gr.Blocks(js=...). Gradio evaluates this string as
# a function and calls it once when the page loads. The whole script therefore
# has to be a single function: a function declaration followed by loose
# statements is a syntax error in that position. By the time it runs, the
# document has already loaded, so setup happens immediately. A DOMContentLoaded
# listener registered at that point would never fire.
js = """
function setupThemeToggle() {
    const body = document.body;

    // Re-apply the user's last explicit choice, if any; otherwise keep the
    // theme Gradio already chose.
    const savedTheme = localStorage.getItem('gradio-theme');
    if (savedTheme) {
        body.classList.toggle('dark', savedTheme === 'dark');
    }

    const themeToggle = document.createElement('button');
    themeToggle.className = 'theme-toggle';
    themeToggle.type = 'button';
    themeToggle.setAttribute('aria-label', 'Toggle light/dark theme');

    // Icon reflects the current theme: sun while dark, moon while light.
    const syncIcon = () => {
        themeToggle.innerHTML = body.classList.contains('dark') ? '☀️' : '🌙';
    };

    themeToggle.onclick = () => {
        body.classList.toggle('dark');
        localStorage.setItem('gradio-theme', body.classList.contains('dark') ? 'dark' : 'light');
        syncIcon();
    };

    syncIcon();
    body.appendChild(themeToggle);
}
"""
# Build the UI: a prompt bar, a result panel, collapsible advanced settings, and
# example prompts, all wired to `infer`.
with gr.Blocks(css=css, js=js, theme=gr.themes.Soft()) as demo:
    with gr.Column(elem_id="col-container"):
        with gr.Column(visible=True) as header:
            gr.Markdown(
                """
<div class="header">
<h1>✨ AI Image Generator</h1>
<p>Transform your text into stunning images with SDXL Turbo</p>
</div>
""",
                elem_classes="header",
            )

        with gr.Column(elem_classes="prompt-container"):
            with gr.Row():
                prompt = gr.Textbox(
                    label="",
                    show_label=False,
                    max_lines=2,
                    placeholder="Describe the image you want to generate...",
                    container=False,
                    scale=5,
                )
                run_button = gr.Button(
                    "Generate",
                    scale=1,
                    variant="primary",
                    elem_classes="btn-primary",
                )

        with gr.Column(elem_classes="result-container"):
            result = gr.Image(
                label="Generated Image",
                show_label=False,
                height=500,
                elem_id="output-image",
            )
            with gr.Row():
                seed_info = gr.Textbox(
                    label="Seed",
                    interactive=False,
                )
                time_info = gr.Textbox(
                    label="Generation Time",
                    interactive=False,
                )

        with gr.Accordion("🛠️ Advanced Settings", open=False, elem_classes="advanced-settings"):
            with gr.Column(elem_classes="form"):
                with gr.Row():
                    negative_prompt = gr.Textbox(
                        label="Negative Prompt",
                        max_lines=1,
                        placeholder="What you don't want to see in the image",
                    )
                with gr.Row():
                    with gr.Column():
                        seed = gr.Slider(
                            label="Seed",
                            minimum=0,
                            maximum=MAX_SEED,
                            step=1,
                            value=0,
                        )
                        randomize_seed = gr.Checkbox(
                            label="Randomize seed",
                            value=True,
                            interactive=True,
                        )
                    with gr.Column():
                        width = gr.Slider(
                            label="Width",
                            minimum=256,
                            maximum=MAX_IMAGE_SIZE,
                            step=32,
                            value=1024,
                        )
                        height = gr.Slider(
                            label="Height",
                            minimum=256,
                            maximum=MAX_IMAGE_SIZE,
                            step=32,
                            value=1024,
                        )
                with gr.Row():
                    guidance_scale = gr.Slider(
                        label="Guidance Scale",
                        minimum=0.0,
                        maximum=10.0,
                        step=0.1,
                        value=0.0,
                    )
                    num_inference_steps = gr.Slider(
                        label="Inference Steps",
                        minimum=1,
                        maximum=50,
                        step=1,
                        value=2,
                    )

        gr.Markdown("### Example Prompts")
        with gr.Row(elem_classes="examples"):
            for example_index, example in enumerate(examples):
                with gr.Column(min_width=200):
                    gr.Examples(
                        examples=[[example[0], example[1], example[2]]],
                        inputs=[prompt, width, height],
                        label="",
                        examples_per_page=20,
                        # HTML ids must not contain whitespace. The previous
                        # prompt-prefix ids ("Astronaut ", ...) did, so use the
                        # row index instead.
                        elem_id=f"example-{example_index}",
                    )

    # Clicking "Generate" and pressing Enter in the prompt box run the same
    # generation, so both listeners share one input/output wiring.
    generate_inputs = [
        prompt,
        negative_prompt,
        seed,
        randomize_seed,
        width,
        height,
        guidance_scale,
        num_inference_steps,
    ]
    generate_outputs = [result, seed_info, time_info]
    run_button.click(fn=infer, inputs=generate_inputs, outputs=generate_outputs)
    prompt.submit(fn=infer, inputs=generate_inputs, outputs=generate_outputs)
if __name__ == "__main__":
    # Turn on request queuing (keeping the auto-generated API closed), then
    # start the web server.
    queued_demo = demo.queue(api_open=False)
    queued_demo.launch()