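"""Gradio app that generates images with FLUX.1-dev and Stable Diffusion 3.5 Large,
reporting per-step progress to the UI through a diffusers step-end callback."""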
import spaces
import torch
from diffusers import FluxPipeline, StableDiffusion3Pipeline
from PIL import Image
from io import BytesIO
import gradio as gr

# Generate an image with either pipeline, reporting per-step progress to Gradio.
def generate_image_with_progress(pipe, prompt, num_steps, guidance_scale=None, seed=None, progress=gr.Progress()):
    generator = None
    if seed is not None:
        generator = torch.Generator("cpu").manual_seed(seed)

    print("Start generating")

    # diffusers invokes this after every denoising step; it must return the
    # callback_kwargs dict so the pipeline can continue using its tensors.
    def callback(pipe, step_index, timestep, callback_kwargs):
        cur_prg = (step_index + 1) / num_steps
        print(f"Progressing {cur_prg:.2f} Step {step_index + 1}/{num_steps}")
        progress(cur_prg, desc=f"Step {step_index + 1}/{num_steps}")
        return callback_kwargs

    if isinstance(pipe, StableDiffusion3Pipeline):
        image = pipe(
            prompt,
            num_inference_steps=num_steps,
            guidance_scale=guidance_scale,
            callback_on_step_end=callback,
        ).images[0]
    elif isinstance(pipe, FluxPipeline):
        image = pipe(
            prompt,
            num_inference_steps=num_steps,
            generator=generator,
            output_type="pil",
            callback_on_step_end=callback,
        ).images[0]
    else:
        raise ValueError(f"Unsupported pipeline type: {type(pipe)}")

    return image
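
# A minimal standalone usage sketch (assumes a local CUDA GPU; the model id and
# step count mirror tab1_logic below, and the prompt is just an example):
#
#   pipe = FluxPipeline.from_pretrained(
#       "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
#   ).to("cuda")
#   image = generate_image_with_progress(pipe, "a watercolor fox", num_steps=30, seed=42)
#   image.save("flux_output.png")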

# Gradio application
def main():
    # @spaces.GPU is the Hugging Face ZeroGPU decorator: it requests a GPU for
    # up to `duration` seconds per call, which is why each tab loads its
    # pipeline inside the decorated function rather than at import time.
    @spaces.GPU(duration=170)
    def tab1_logic(prompt_text):
        progress = gr.Progress()
        num_steps = 30
        seed = 42
        print(f"Start tab {prompt_text}")
        flux_pipe = FluxPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
        ).to("cuda")
        image = generate_image_with_progress(
            flux_pipe, prompt_text, num_steps=num_steps, seed=seed, progress=progress
        )
        return f"Seed: {seed}", image

    @spaces.GPU(duration=170)
    def tab2_logic(prompt_text):
        progress = gr.Progress()
        num_steps = 28
        guidance_scale = 3.5
        print(f"Start tab {prompt_text}")
        # Initialize the Stable Diffusion 3.5 pipeline
        stable_diffusion_pipe = StableDiffusion3Pipeline.from_pretrained(
            "stabilityai/stable-diffusion-3.5-large", torch_dtype=torch.bfloat16
        ).to("cuda")
        image = generate_image_with_progress(
            stable_diffusion_pipe, prompt_text, num_steps=num_steps, guidance_scale=guidance_scale, progress=progress
        )
        return "Seed: None", image

    with gr.Blocks() as app:
        gr.Markdown("# Multiple Model Image Generation with Progress Bar")
        prompt_text = gr.Textbox(label="Enter prompt")
        with gr.Tab("FLUX"):
            button_1 = gr.Button("Run FLUX")
            output_1 = gr.Textbox(label="Status")
            img_1 = gr.Image(label="FLUX", height=300)
            button_1.click(fn=tab1_logic, inputs=[prompt_text], outputs=[output_1, img_1])
        with gr.Tab("StableDiffusion3"):
            button_2 = gr.Button("Run StableDiffusion3")
            output_2 = gr.Textbox(label="Status")
            img_2 = gr.Image(label="StableDiffusion3", height=300)
            button_2.click(fn=tab2_logic, inputs=[prompt_text], outputs=[output_2, img_2])

    # gr.Progress updates are delivered over the request queue, which Gradio 4+
    # enables by default.
    app.launch()


if __name__ == "__main__":
    main()