# app.py: Gradio text-to-image demo for the "Qwen-Qwen-Image" Space (serves SDXL Turbo).
import random
from time import time

import gradio as gr
import numpy as np
import torch
from diffusers import DiffusionPipeline

device = "cuda" if torch.cuda.is_available() else "cpu"
model_repo_id = "stabilityai/sdxl-turbo"
# Load the pipeline once at startup: fp16 weights on GPU, fp32 on CPU.
try:
    if torch.cuda.is_available():
        torch_dtype = torch.float16
        pipe = DiffusionPipeline.from_pretrained(
            model_repo_id,
            torch_dtype=torch_dtype,
            variant="fp16",
            use_safetensors=True,
        ).to(device)
    else:
        torch_dtype = torch.float32
        pipe = DiffusionPipeline.from_pretrained(
            model_repo_id,
            torch_dtype=torch_dtype,
        ).to(device)
except Exception as e:
    print(f"Error loading model: {e}")
    raise
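
# Optional memory tweak (a sketch, not part of the original app): diffusers
# pipelines expose enable_attention_slicing(), which trades a little speed
# for a lower peak-VRAM footprint. Worth uncommenting on small GPUs.
#
# if device == "cuda":
#     pipe.enable_attention_slicing()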
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024

def infer(
    prompt,
    negative_prompt,
    seed,
    randomize_seed,
    width,
    height,
    guidance_scale,
    num_inference_steps,
    progress=gr.Progress(track_tqdm=True),
):
    try:
        start_time = time()
        if randomize_seed:
            seed = random.randint(0, MAX_SEED)
        # Slider values can arrive as floats; manual_seed and the size/step
        # arguments below need ints.
        seed = int(seed)
        generator = torch.Generator(device=device).manual_seed(seed)
        # Generate the image.
        image = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=int(num_inference_steps),
            width=int(width),
            height=int(height),
            generator=generator,
        ).images[0]
        gen_time = time() - start_time
        return image, seed, f"Generated in {gen_time:.2f}s"
    except Exception as e:
        print(f"Error during inference: {e}")
        return None, seed, f"Error: {str(e)}"
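
# Quick smoke test (illustrative sketch; the env-var gate and the prompt are
# assumptions, not part of the original app). Kept commented out so a normal
# Space launch is unaffected.
#
# import os
# if os.environ.get("RUN_SMOKE_TEST") == "1":
#     img, used_seed, msg = infer(
#         "A lighthouse at dusk", "", 0, True, 512, 512, 0.0, 2
#     )
#     print(msg, "| seed:", used_seed)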

# [prompt, width, height] triples surfaced as one-click examples in the UI.
examples = [
    ["Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", 1024, 1024],
    ["An astronaut riding a green horse", 768, 768],
    ["A delicious ceviche cheesecake slice", 896, 896],
]
css = """
:root {
--primary: #6e6af0;
--secondary: #f5f5f7;
--accent: #f5f5f7;
--text: #1e1e1e;
--shadow: 0 4px 12px rgba(0, 0, 0, 0.1);
}
.dark {
--primary: #a5a5fc;
--secondary: #2d3748;
--accent: #4a5568;
--text: #f7fafc;
--shadow: 0 4px 12px rgba(0, 0, 0, 0.3);
}
#col-container {
margin: 0 auto;
max-width: 800px;
padding: 20px;
}
.header {
text-align: center;
margin-bottom: 20px;
}
.header h1 {
font-size: 2.5rem;
font-weight: 700;
color: var(--primary);
margin-bottom: 10px;
}
.header p {
color: var(--text);
opacity: 0.8;
}
.prompt-container, .result-container, .advanced-settings {
background: var(--secondary);
border-radius: 12px;
padding: 20px;
box-shadow: var(--shadow);
margin-bottom: 20px;
}
.advanced-settings .form {
display: grid;
grid-template-columns: 1fr 1fr;
gap: 16px;
}
.control-row {
display: flex;
gap: 10px;
align-items: center;
}
.btn-primary {
background: var(--primary) !important;
border: none !important;
color: white !important;
}
.btn-primary:hover {
opacity: 0.9 !important;
}
.examples {
display: grid;
grid-template-columns: repeat(auto-fill, minmax(250px, 1fr));
gap: 12px;
margin-top: 20px;
}
.example-prompt {
background: var(--secondary);
padding: 12px;
border-radius: 8px;
cursor: pointer;
transition: all 0.2s;
}
.example-prompt:hover {
transform: translateY(-2px);
box-shadow: var(--shadow);
}
.theme-toggle {
position: absolute;
top: 20px;
right: 20px;
background: var(--secondary);
border: none;
border-radius: 50%;
width: 40px;
height: 40px;
display: flex;
align-items: center;
justify-content: center;
cursor: pointer;
}
@media (max-width: 768px) {
.advanced-settings .form {
grid-template-columns: 1fr;
}
}
"""
js = """
function toggleTheme() {
const body = document.body;
body.classList.toggle('dark');
localStorage.setItem('gradio-theme', body.classList.contains('dark') ? 'dark' : 'light');
}
document.addEventListener('DOMContentLoaded', () => {
const savedTheme = localStorage.getItem('gradio-theme') || 'light';
if (savedTheme === 'dark') {
document.body.classList.add('dark');
}
const themeToggle = document.createElement('button');
themeToggle.className = 'theme-toggle';
themeToggle.innerHTML = savedTheme === 'dark' ? '☀️' : '🌙';
themeToggle.onclick = toggleTheme;
document.body.appendChild(themeToggle);
});
"""

with gr.Blocks(css=css, js=js, theme=gr.themes.Soft()) as demo:
    with gr.Column(elem_id="col-container"):
        with gr.Column(visible=True) as header:
            gr.Markdown(
                """
                <div class="header">
                    <h1>✨ AI Image Generator</h1>
                    <p>Transform your text into stunning images with SDXL Turbo</p>
                </div>
                """,
                elem_classes="header",
            )
        with gr.Column(elem_classes="prompt-container"):
            with gr.Row():
                prompt = gr.Textbox(
                    label="",
                    show_label=False,
                    max_lines=2,
                    placeholder="Describe the image you want to generate...",
                    container=False,
                    scale=5,
                )
                run_button = gr.Button(
                    "Generate",
                    scale=1,
                    variant="primary",
                    elem_classes="btn-primary",
                )
        with gr.Column(elem_classes="result-container"):
            result = gr.Image(
                label="Generated Image",
                show_label=False,
                height=500,
            )
            with gr.Row():
                seed_info = gr.Textbox(label="Seed", interactive=False)
                time_info = gr.Textbox(label="Generation Time", interactive=False)
        with gr.Accordion("🛠️ Advanced Settings", open=False, elem_classes="advanced-settings"):
            with gr.Column(elem_classes="form"):
                with gr.Row():
                    negative_prompt = gr.Textbox(
                        label="Negative Prompt",
                        max_lines=1,
                        placeholder="What you don't want to see in the image",
                    )
                with gr.Row():
                    with gr.Column():
                        seed = gr.Slider(
                            label="Seed",
                            minimum=0,
                            maximum=MAX_SEED,
                            step=1,
                            value=0,
                        )
                        randomize_seed = gr.Checkbox(
                            label="Randomize seed",
                            value=True,
                        )
                    with gr.Column():
                        width = gr.Slider(
                            label="Width",
                            minimum=256,
                            maximum=MAX_IMAGE_SIZE,
                            step=32,
                            value=1024,
                        )
                        height = gr.Slider(
                            label="Height",
                            minimum=256,
                            maximum=MAX_IMAGE_SIZE,
                            step=32,
                            value=1024,
                        )
                with gr.Row():
                    guidance_scale = gr.Slider(
                        label="Guidance Scale",
                        minimum=0.0,
                        maximum=10.0,
                        step=0.1,
                        value=0.0,
                    )
                    num_inference_steps = gr.Slider(
                        label="Inference Steps",
                        minimum=1,
                        maximum=50,
                        step=1,
                        value=2,
                    )
gr.Markdown("### Example Prompts")
with gr.Row(elem_classes="examples"):
for example in examples:
with gr.Column(min_width=200):
gr.Examples(
examples=[[example[0], example[1], example[2]]],
inputs=[prompt, width, height],
label="",
examples_per_page=20
)

    # One handler serves both the Generate button and Enter in the prompt box.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[
            prompt,
            negative_prompt,
            seed,
            randomize_seed,
            width,
            height,
            guidance_scale,
            num_inference_steps,
        ],
        outputs=[result, seed_info, time_info],
    )

if __name__ == "__main__":
    demo.queue(api_open=False).launch()