import spaces
import torch
import random
import numpy as np
from inspect import signature
from diffusers import (
    FluxPipeline,
    StableDiffusion3Pipeline,
    PixArtSigmaPipeline,
    SanaPipeline,
    AuraFlowPipeline,
    Kandinsky3Pipeline,
    HunyuanDiTPipeline,
    LuminaText2ImgPipeline,
    AutoPipelineForText2Image,
)
import gradio as gr
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from pathlib import Path
import time
import os
from datetime import datetime

MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024


class ProgressPipeline(DiffusionPipeline):
    def __init__(self, original_pipeline):
        super().__init__()
        self.original_pipeline = original_pipeline
        # Register all components from the original pipeline
        for attr_name, attr_value in vars(original_pipeline).items():
            setattr(self, attr_name, attr_value)

    @torch.no_grad()
    def __call__(
        self,
        prompt,
        num_inference_steps=30,
        generator=None,
        guidance_scale=7.5,
        callback=None,
        callback_steps=1,
        **kwargs,
    ):
        # Initialize progress tracking
        self._num_inference_steps = num_inference_steps
        self._step = 0

        def progress_callback(step_index, timestep, callback_kwargs):
            if callback and step_index % callback_steps == 0:
                # Pass self (the pipeline) to the callback
                callback(self, step_index, timestep, callback_kwargs)
            return callback_kwargs

        # Monkey-patch the scheduler's step() so every denoising step reports
        # progress, even for pipelines without native callback support
        original_step = self.original_pipeline.scheduler.step

        def wrapped_step(*args, **kwargs):
            self._step += 1
            progress_callback(self._step, None, {})
            return original_step(*args, **kwargs)

        self.original_pipeline.scheduler.step = wrapped_step

        try:
            # Call the original pipeline
            result = self.original_pipeline(
                prompt=prompt,
                num_inference_steps=num_inference_steps,
                generator=generator,
                guidance_scale=guidance_scale,
                **kwargs,
            )
            return result
        finally:
            # Restore the original step function
            self.original_pipeline.scheduler.step = original_step


cache_dir = "/workspace/hf_cache"

MODEL_CONFIGS = {
    "FLUX": {
        "repo_id": "black-forest-labs/FLUX.1-dev",
        "pipeline_class": FluxPipeline,
    },
    "Stable Diffusion 3.5": {
        "repo_id": "stabilityai/stable-diffusion-3.5-large",
        "pipeline_class": StableDiffusion3Pipeline,
    },
    "PixArt": {
        "repo_id": "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS",
        "pipeline_class": PixArtSigmaPipeline,
    },
    "SANA": {
        "repo_id": "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers",
        "pipeline_class": SanaPipeline,
    },
    "AuraFlow": {
        "repo_id": "fal/AuraFlow",
        "pipeline_class": AuraFlowPipeline,
    },
    "Kandinsky": {
        "repo_id": "kandinsky-community/kandinsky-3",
        "pipeline_class": Kandinsky3Pipeline,
    },
    "Hunyuan": {
        "repo_id": "Tencent-Hunyuan/HunyuanDiT-Diffusers",
        "pipeline_class": HunyuanDiTPipeline,
    },
    "Lumina": {
        "repo_id": "Alpha-VLLM/Lumina-Next-SFT-diffusers",
        "pipeline_class": LuminaText2ImgPipeline,
    },
}


def generate_image_with_progress(
    model_name,
    pipe,
    prompt,
    num_steps,
    guidance_scale=3.5,
    seed=None,
    negative_prompt=None,
    randomize_seed=False,
    width=1024,
    height=1024,
    progress=gr.Progress(track_tqdm=False),
):
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    if seed is not None:
        generator = torch.Generator("cuda").manual_seed(seed)
    else:
        generator = torch.Generator("cuda")

    def callback(pipe, step_index, timestep, callback_kwargs):
        print(f" callback => {step_index}, {timestep}")
        if step_index is None:
            step_index = 0
        cur_prg = step_index / num_steps
        progress(cur_prg, desc=f"Step {step_index}/{num_steps}")
        return callback_kwargs

    print("Starting generation")

    # Inspect the pipeline's call signature to see which arguments it supports
    pipe_signature = signature(pipe)
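    # Note: inspect.signature() on a pipeline instance resolves to its
    # __call__ method. Newer diffusers pipelines accept `callback_on_step_end`;
    # older ones only the deprecated `callback`/`callback_steps` pair, which
    # the ProgressPipeline wrapper emulates by patching scheduler.step().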
presence of "guidance_scale" and "callback_on_step_end" in the signature has_guidance_scale = "guidance_scale" in pipe_signature.parameters has_callback_on_step_end = "callback_on_step_end" in pipe_signature.parameters # Define common arguments common_args = { "prompt": prompt, "num_inference_steps": num_steps, "negative_prompt": negative_prompt, "width": width, "height": height, "generator": generator, } if has_guidance_scale: common_args["guidance_scale"] = guidance_scale if has_callback_on_step_end: print("has callback_on_step_end and", "has guidance_scale" if has_guidance_scale else "NO guidance_scale") common_args["callback_on_step_end"] = callback else: print("NO callback_on_step_end and", "has guidance_scale" if has_guidance_scale else "NO guidance_scale") common_args["callback"] = callback common_args["callback_steps"] = 1 # Generate image image = pipe(**common_args).images[0] filepath = save_generated_image(image, model_name, prompt) # Then, reload the gallery images, load_message = load_images_from_directory(model_name) print(f"Saved image to: {filepath}") return seed, image, images @spaces.GPU(duration=170) def create_pipeline_logic(prompt_text, model_name, negative_prompt="", seed=42, randomize_seed=False, width=1024, height=1024, guidance_scale=4.5, num_inference_steps=40,): print(f"starting {model_name}") progress = gr.Progress(track_tqdm=False) config = MODEL_CONFIGS[model_name] pipe_class = config["pipeline_class"] pipe = None b_pipe = AutoPipelineForText2Image.from_pretrained( config["repo_id"], #variant="fp16", #cache_dir=config["cache_dir"], torch_dtype=torch.bfloat16 ).to("cuda") pipe_signature = signature(b_pipe) # Check for the presence of "callback_on_step_end" in the signature has_callback_on_step_end = "callback_on_step_end" in pipe_signature.parameters if not has_callback_on_step_end: pipe = ProgressPipeline(b_pipe) print("ProgressPipeline specal") else: pipe = b_pipe gen_seed,image, images = generate_image_with_progress( model_name,pipe, prompt_text, num_steps=num_inference_steps, guidance_scale=guidance_scale, seed=seed,negative_prompt = negative_prompt, randomize_seed = randomize_seed, width = width, height = height, progress=progress ) return f"Seed: {gen_seed}", image, images def main(): with gr.Blocks() as app: gr.Markdown("# Dynamic Multiple Model Image Generation") prompt_text = gr.Textbox(label="Enter prompt") with gr.Accordion("Advanced Settings", open=False): negative_prompt = gr.Text( label="Negative prompt", max_lines=1, placeholder="Enter a negative prompt", ) seed = gr.Slider( label="Seed", minimum=0, maximum=MAX_SEED, step=100, value=0, ) randomize_seed = gr.Checkbox(label="Randomize seed", value=True) with gr.Row(): width = gr.Slider( label="Width", minimum=512, maximum=MAX_IMAGE_SIZE, step=32, value=1024, ) height = gr.Slider( label="Height", minimum=512, maximum=MAX_IMAGE_SIZE, step=32, value=1024, ) with gr.Row(): guidance_scale = gr.Slider( label="Guidance scale", minimum=0.0, maximum=7.5, step=0.1, value=4.5, ) num_inference_steps = gr.Slider( label="Number of inference steps", minimum=1, maximum=50, step=1, value=40, ) for model_name, config in MODEL_CONFIGS.items(): #global gallery with gr.Tab(model_name) as tab_model: button = gr.Button(f"Run {model_name}") output = gr.Textbox(label="Status") img = gr.Image(label=model_name, height=300) gallery = gr.Gallery( label="Image Gallery", show_label=True, columns=4, rows=3, height=600, object_fit="contain" ) tab_model.select( fn=load_images_from_directory, inputs=[gr.Text(value= 
                button.click(
                    fn=create_pipeline_logic,
                    inputs=[
                        prompt_text,
                        gr.Text(value=model_name, visible=False),
                        negative_prompt,
                        seed,
                        randomize_seed,
                        width,
                        height,
                        guidance_scale,
                        num_inference_steps,
                    ],
                    outputs=[output, img, gallery],
                )

    app.launch()


def save_generated_image(image, model_name, prompt):
    """Save a generated image named with a timestamp, model name, and prompt."""
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    # Create a sanitized filename from the first 30 characters of the prompt
    prompt_part = "".join(
        c for c in prompt[:30] if c.isalnum() or c in (" ", "-", "_")
    ).strip()
    filename = f"{timestamp}_{model_name}_{prompt_part}.png"

    path = Path(model_name)
    path.mkdir(parents=True, exist_ok=True)

    filepath = os.path.join(model_name, filename)
    image.save(filepath)
    return filepath


def load_images_from_directory(directory_path):
    """
    Load all images from the specified directory.
    Returns (list of image file paths, status message).
    """
    print(f"Loading images from {directory_path}")
    image_extensions = {".jpg", ".jpeg", ".png", ".gif", ".bmp", ".webp"}
    directory = Path(directory_path)

    if not directory.exists():
        print(f"No directory: {directory_path}")
        return [], f"Error: Directory '{directory_path}' does not exist"

    image_files = [
        str(f)
        for f in directory.iterdir()
        if f.suffix.lower() in image_extensions and f.is_file()
    ]

    if not image_files:
        print(f"No images in {directory_path}")
        return [], f"No images found in directory '{directory_path}'"

    print(f"Found {len(image_files)} images in {directory_path}")
    return image_files, f"Found {len(image_files)} images"


if __name__ == "__main__":
    main()
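# Usage note (a sketch; package names below are assumptions, not pinned deps):
#   pip install gradio diffusers transformers accelerate spaces
#   python <this_script>.py
# The `spaces` package targets Hugging Face ZeroGPU Spaces; outside a Space,
# the @spaces.GPU decorator should be a no-op, but a CUDA device is still
# required since pipelines and generators are placed on "cuda" explicitly.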