import spaces
import gradio as gr
import numpy as np
import random
import torch
from diffusers import (
    DiffusionPipeline,
    StableDiffusion3Pipeline,
    FluxPipeline,
    PixArtSigmaPipeline,
    AuraFlowPipeline,
    Kandinsky3Pipeline,
    HunyuanDiTPipeline,
    LuminaText2ImgPipeline,
    SanaPipeline,
    AutoPipelineForText2Image,
)
import gc
import os
import psutil
import threading
from pathlib import Path
import shutil
import time
import glob
from datetime import datetime
from PIL import Image

# cache_dir = '/workspace/hf_cache'

# Constants
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
TORCH_DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32

OUTPUT_DIR = "generated_images"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Model configurations
MODEL_CONFIGS = {
    "FLUX": {
        "repo_id": "black-forest-labs/FLUX.1-dev",
        "pipeline_class": FluxPipeline,
        # "cache_dir": cache_dir,
    },
    "Stable Diffusion 3.5": {
        "repo_id": "stabilityai/stable-diffusion-3.5-large",
        "pipeline_class": StableDiffusion3Pipeline,
        # "cache_dir": cache_dir,
    },
    "PixArt": {
        "repo_id": "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS",
        "pipeline_class": PixArtSigmaPipeline,
        # "cache_dir": cache_dir,
    },
    "SANA": {
        "repo_id": "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers",
        "pipeline_class": SanaPipeline,
        # "cache_dir": cache_dir,
    },
    "AuraFlow": {
        "repo_id": "fal/AuraFlow",
        "pipeline_class": AuraFlowPipeline,
        # "cache_dir": cache_dir,
    },
    "Kandinsky": {
        "repo_id": "kandinsky-community/kandinsky-3",
        "pipeline_class": Kandinsky3Pipeline,
        # "cache_dir": cache_dir,
    },
    "Hunyuan": {
        "repo_id": "Tencent-Hunyuan/HunyuanDiT-Diffusers",
        "pipeline_class": HunyuanDiTPipeline,
        # "cache_dir": cache_dir,
    },
    "Lumina": {
        "repo_id": "Alpha-VLLM/Lumina-Next-SFT-diffusers",
        "pipeline_class": LuminaText2ImgPipeline,
        # "cache_dir": cache_dir,
    },
}

# Dictionary to store loaded model pipelines, with a per-model lock
pipes = {}
model_locks = {model_name: threading.Lock() for model_name in MODEL_CONFIGS.keys()}


def get_process_memory():
    """Get memory usage of the current process in GB."""
    process = psutil.Process(os.getpid())
    return process.memory_info().rss / 1024 / 1024 / 1024


def clear_torch_cache():
    """Clear PyTorch's CUDA cache."""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()


def remove_cache_dir(model_name):
    """Remove the model's cache directory."""
    cache_dir = (
        Path.home()
        / ".cache"
        / "huggingface"
        / "diffusers"
        / MODEL_CONFIGS[model_name]["repo_id"].replace("/", "--")
    )
    if cache_dir.exists():
        shutil.rmtree(cache_dir, ignore_errors=True)


def deep_cleanup(model_name, pipe):
    """Perform deep cleanup of model resources."""
    try:
        # 1. Move model to CPU first (helps prevent CUDA memory fragmentation)
        if hasattr(pipe, "to"):
            pipe.to("cpu")

        # 2. Delete all model components explicitly
        for attr_name in list(pipe.__dict__.keys()):
            if hasattr(pipe, attr_name):
                delattr(pipe, attr_name)

        # 3. Remove from pipes dictionary
        if model_name in pipes:
            del pipes[model_name]

        # 4. Clear CUDA cache
        clear_torch_cache()

        # 5. Run garbage collection multiple times
        for _ in range(3):
            gc.collect()

        # 6. Remove cached files
        remove_cache_dir(model_name)

        # 7. Additional CUDA cleanup if available
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        # 8. Wait a small amount of time to ensure cleanup
        time.sleep(1)

    except Exception as e:
        print(f"Error during cleanup of {model_name}: {str(e)}")
    finally:
        # Final garbage collection
        gc.collect()
        clear_torch_cache()


def load_pipeline(model_name):
    """Load model pipeline with memory tracking."""
    initial_memory = get_process_memory()
    config = MODEL_CONFIGS[model_name]
    pipe = None

    if model_name == "Kandinsky":
        print("Kandinsky Special")
        pipe = AutoPipelineForText2Image.from_pretrained(
            "kandinsky-community/kandinsky-3",
            variant="fp16",
            torch_dtype=torch.float16,
        )
    else:
        pipe = config["pipeline_class"].from_pretrained(
            config["repo_id"],
            torch_dtype=TORCH_DTYPE,
            # cache_dir=cache_dir,
        )

    pipe = pipe.to(DEVICE)
    if hasattr(pipe, "enable_model_cpu_offload"):
        pipe.enable_model_cpu_offload()

    final_memory = get_process_memory()
    print(f"Memory used by {model_name}: {final_memory - initial_memory:.2f} GB")
    return pipe


def save_generated_image(image, model_name, prompt):
    """Save a generated image with timestamp and model name."""
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    # Create sanitized filename from prompt (first 30 chars)
    prompt_part = "".join(
        c for c in prompt[:30] if c.isalnum() or c in (" ", "-", "_")
    ).strip()
    filename = f"{timestamp}_{model_name}_{prompt_part}.png"
    filepath = os.path.join(OUTPUT_DIR, filename)
    image.save(filepath)
    return filepath


def get_generated_images():
    """Get list of generated images with their details."""
    files = glob.glob(os.path.join(OUTPUT_DIR, "*.png"))
    files.sort(key=os.path.getctime, reverse=True)  # Sort by creation time
    return [
        {
            "path": f,
            "name": os.path.basename(f),
            "date": datetime.fromtimestamp(os.path.getctime(f)).strftime("%Y-%m-%d %H:%M:%S"),
            "size": f"{os.path.getsize(f) / 1024:.1f} KB",
        }
        for f in files
    ]


def generate_image(
    model_name,
    prompt,
    negative_prompt="",
    seed=42,
    randomize_seed=False,
    width=1024,
    height=1024,
    guidance_scale=4.5,
    num_inference_steps=40,
    progress=gr.Progress(track_tqdm=True),
):
    with model_locks[model_name]:
        try:
            # progress(0, desc=f"Loading {model_name} model...")
            if model_name not in pipes:
                pipes[model_name] = load_pipeline(model_name)
            pipe = pipes[model_name]

            if randomize_seed:
                seed = random.randint(0, MAX_SEED)
            generator = torch.Generator(DEVICE).manual_seed(seed)

            print(f"Generating image with {model_name}...")
            # progress(0.3, desc=f"Generating image with {model_name}...")
            if model_name == "OneDiffusion":
                prompt = "[[text2image]] " + prompt

            image = pipe(
                prompt=prompt,
                negative_prompt=negative_prompt,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps,
                width=width,
                height=height,
                generator=generator,
            ).images[0]

            filepath = save_generated_image(image, model_name, prompt)
            print(f"Saved image to: {filepath}")

            # progress(0.9, desc=f"Cleaning up {model_name} resources...")
            # deep_cleanup(model_name, pipe)
            # progress(1.0, desc=f"Generation complete with {model_name}")
            return image, seed
        except Exception as e:
            print(f"Error with {model_name}: {str(e)}")
            if model_name in pipes:
                deep_cleanup(model_name, pipes[model_name])
            raise e


# Gradio Interface
css = """
#col-container {
    margin: 0 auto;
    max-width: 1024px;
}
"""

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# Multi-Model Image Generation")

        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
            run_button = gr.Button("Generate", scale=0, variant="primary")

        with gr.Accordion("Advanced Settings", open=False):
            negative_prompt = gr.Text(
                label="Negative prompt",
                max_lines=1,
                placeholder="Enter a negative prompt",
            )
prompt", max_lines=1, placeholder="Enter a negative prompt", ) seed = gr.Slider( label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0, ) randomize_seed = gr.Checkbox(label="Randomize seed", value=True) with gr.Row(): width = gr.Slider( label="Width", minimum=512, maximum=MAX_IMAGE_SIZE, step=32, value=1024, ) height = gr.Slider( label="Height", minimum=512, maximum=MAX_IMAGE_SIZE, step=32, value=1024, ) with gr.Row(): guidance_scale = gr.Slider( label="Guidance scale", minimum=0.0, maximum=7.5, step=0.1, value=4.5, ) num_inference_steps = gr.Slider( label="Number of inference steps", minimum=1, maximum=50, step=1, value=40, ) memory_indicator = gr.Markdown("Current memory usage: 0 GB") with gr.Row(): with gr.Column(scale=2): with gr.Tabs() as tabs: results = {} seeds = {} for model_name in MODEL_CONFIGS.keys(): with gr.Tab(model_name): results[model_name] = gr.Image(label=f"{model_name} Result") seeds[model_name] = gr.Number(label="Seed used", visible=True) with gr.Column(scale=1): gr.Markdown("### Generated Images") file_gallery = gr.Gallery( label="Generated Images", show_label=False, elem_id="file_gallery", columns=3, height=800, visible=True ) refresh_button = gr.Button("Refresh Gallery") def update_gallery(): """Update the file gallery""" files = get_generated_images() return [ (f["path"], f"{f['name']}\n{f['date']}") for f in files ] @spaces.GPU(duration=400) def generate_all(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps, progress=gr.Progress()): outputs = [None] * (len(MODEL_CONFIGS) * 2) for idx, model_name in enumerate(MODEL_CONFIGS.keys()): try: # Display progress for the specific model # progress(0, desc=f"Starting generation for {model_name}...") print(f"IMAGE GENERATING {model_name} ") image, used_seed = generate_image( model_name, prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps, progress ) print(f"IMAGE GENERATIED {model_name} ") # Update the respective model's tab with the generated image # results[model_name].update(image) # seeds[model_name].update(used_seed) outputs[idx * 2] = image # Image slot outputs[idx * 2 + 1] = seed # Seed slot # outputs.extend([image, used_seed]) # Add intermediate results to progress * (len(all_outputs) - len(all_outputs)) print("YELID") yield outputs + [None] except Exception as e: print(f"Error generating with {model_name}: {str(e)}") outputs[idx * 2] = None outputs[idx * 2 + 1] = None # Update the gallery after generation gallery_images = update_gallery() # file_gallery.update(value=gallery_images) return outputs output_components = [] for model_name in MODEL_CONFIGS.keys(): output_components.extend([results[model_name], seeds[model_name]]) run_button.click( fn=generate_all, inputs=[ prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps, ], outputs=output_components, ) refresh_button.click( fn=update_gallery, inputs=[], outputs=[file_gallery], ) demo.load( fn=update_gallery, inputs=[], outputs=[file_gallery], ) if __name__ == "__main__": demo.launch(server_name='0.0.0.0')