Spaces:
Runtime error
Runtime error
| import spaces | |
| import gradio as gr | |
| import numpy as np | |
| import random | |
| import torch | |
| from diffusers import ( | |
| DiffusionPipeline, StableDiffusion3Pipeline, FluxPipeline, PixArtSigmaPipeline, | |
| AuraFlowPipeline, Kandinsky3Pipeline, HunyuanDiTPipeline, | |
| LuminaText2ImgPipeline, SanaPipeline,AutoPipelineForText2Image | |
| ) | |
| import gc | |
| import os | |
| import psutil | |
| import threading | |
| from pathlib import Path | |
| import shutil | |
| import time | |
| import glob | |
| from datetime import datetime | |
| from PIL import Image | |
# Global constants shared by the loader, the generator and the UI.

# Largest seed the seed slider / random seed can take (max signed 32-bit int).
MAX_SEED = np.iinfo(np.int32).max
# Upper bound for the width and height sliders, in pixels.
MAX_IMAGE_SIZE = 1024

# Choose the compute device and a matching dtype together:
# bfloat16 on CUDA, float32 when falling back to the CPU.
DEVICE, TORCH_DTYPE = (
    ("cuda", torch.bfloat16) if torch.cuda.is_available() else ("cpu", torch.float32)
)

# Every generated image is written to this folder; create it up front so
# saving never fails on a missing directory.
OUTPUT_DIR = "generated_images"
os.makedirs(OUTPUT_DIR, exist_ok=True)
# Model configurations: display name -> Hugging Face repo id and the diffusers
# pipeline class used to load it. Insertion order determines the order of the
# result tabs in the UI and the order in which generate_all runs the models.
MODEL_CONFIGS = {
    display_name: {"repo_id": repo_id, "pipeline_class": pipeline_class}
    for display_name, repo_id, pipeline_class in (
        ("FLUX", "black-forest-labs/FLUX.1-dev", FluxPipeline),
        (
            "Stable Diffusion 3.5",
            "stabilityai/stable-diffusion-3.5-large",
            StableDiffusion3Pipeline,
        ),
        ("PixArt", "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS", PixArtSigmaPipeline),
        (
            "SANA",
            "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers",
            SanaPipeline,
        ),
        ("AuraFlow", "fal/AuraFlow", AuraFlowPipeline),
        # load_pipeline special-cases Kandinsky and does not use this
        # pipeline_class entry.
        ("Kandinsky", "kandinsky-community/kandinsky-3", Kandinsky3Pipeline),
        ("Hunyuan", "Tencent-Hunyuan/HunyuanDiT-Diffusers", HunyuanDiTPipeline),
        ("Lumina", "Alpha-VLLM/Lumina-Next-SFT-diffusers", LuminaText2ImgPipeline),
    )
}
# Cache of loaded pipelines keyed by model name. generate_image fills it on
# first use and deep_cleanup removes entries.
pipes = {}
# One lock per configured model, so the same pipeline is never loaded or run
# by two requests at the same time.
model_locks = {name: threading.Lock() for name in MODEL_CONFIGS}
def get_process_memory():
    """Return the resident set size (RSS) of this Python process, in GiB."""
    bytes_per_gib = 1024 ** 3
    rss_bytes = psutil.Process(os.getpid()).memory_info().rss
    return rss_bytes / bytes_per_gib
def clear_torch_cache():
    """Release PyTorch's cached CUDA memory and IPC handles.

    Does nothing when CUDA is not available.
    """
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()
def remove_cache_dir(model_name):
    """Delete the local cache folder for ``model_name`` if it exists.

    The folder is ``~/.cache/huggingface/diffusers/<org>--<name>``. The last
    component is the model's ``repo_id`` from MODEL_CONFIGS with ``/``
    replaced by ``--``. Errors during deletion are ignored.

    NOTE(review): recent huggingface_hub versions download into
    ``~/.cache/huggingface/hub/models--<org>--<name>``. Confirm this path is
    the one actually populated before relying on it to free disk space.
    """
    folder_name = MODEL_CONFIGS[model_name]["repo_id"].replace("/", "--")
    target = Path.home().joinpath(".cache", "huggingface", "diffusers", folder_name)
    if not target.exists():
        return
    shutil.rmtree(target, ignore_errors=True)
def deep_cleanup(model_name, pipe):
    """Tear down a pipeline and release as much host/GPU memory as possible.

    The cached ``pipes`` entry is dropped first, then the steps run in order:

    1. Move the pipeline to the CPU.
    2. Delete every instance attribute it holds.
    3. Synchronize CUDA (when available) and empty the CUDA cache.
    4. Run the garbage collector.
    5. Remove the model's cache directory via ``remove_cache_dir``.
    6. Sleep for one second.

    Errors are printed, not raised, so cleanup never masks the exception
    that triggered it. A final garbage-collection and cache-clear pass
    always runs.

    Args:
        model_name: Key of the pipeline in MODEL_CONFIGS / ``pipes``.
        pipe: The pipeline object to dismantle.
    """
    # Forget the pipeline before touching it. If a later step raises,
    # generate_image must not find (and reuse) a half-dismantled pipeline.
    pipes.pop(model_name, None)
    try:
        # Move weights off the GPU before dropping the references to them.
        if hasattr(pipe, "to"):
            pipe.to("cpu")
        # Drop every component reference stored directly on the instance.
        for attr_name in list(vars(pipe)):
            if hasattr(pipe, attr_name):
                delattr(pipe, attr_name)
        if torch.cuda.is_available():
            # Wait for outstanding kernels before returning memory to the driver.
            torch.cuda.synchronize()
        clear_torch_cache()
        for _ in range(3):
            gc.collect()
        remove_cache_dir(model_name)
        # Short pause so the release settles before the caller continues.
        time.sleep(1)
    except Exception as e:
        print(f"Error during cleanup of {model_name}: {str(e)}")
    finally:
        # Final garbage collection
        gc.collect()
        clear_torch_cache()
def load_pipeline(model_name):
    """Load the diffusers pipeline for ``model_name`` and place it on a device.

    Loading rules:

    - Kandinsky is loaded through ``AutoPipelineForText2Image`` with the
      ``fp16`` weight variant (float16 on CUDA, TORCH_DTYPE otherwise).
    - Every other model uses its configured ``pipeline_class`` with
      TORCH_DTYPE.

    On CUDA, pipelines that support it use model CPU offload. Otherwise the
    pipeline is moved to DEVICE. The change in process RSS during loading is
    printed.

    Args:
        model_name: Key into MODEL_CONFIGS.

    Returns:
        The pipeline, ready to be called.
    """
    initial_memory = get_process_memory()
    config = MODEL_CONFIGS[model_name]
    if model_name == "Kandinsky":
        print("Kandinsky Special")
        # Half precision is only safe on the GPU; many CPU kernels lack fp16.
        kandinsky_dtype = torch.float16 if DEVICE == "cuda" else TORCH_DTYPE
        pipe = AutoPipelineForText2Image.from_pretrained(
            config["repo_id"], variant="fp16", torch_dtype=kandinsky_dtype
        )
    else:
        pipe = config["pipeline_class"].from_pretrained(
            config["repo_id"],
            torch_dtype=TORCH_DTYPE,
        )

    # enable_model_cpu_offload() handles device placement itself. It keeps the
    # weights on the CPU and moves each sub-model to the GPU only while it
    # runs. Moving the whole pipeline to the GPU first would defeat that and
    # risk running out of GPU memory. Offload also needs CUDA, so skip it on
    # CPU-only hosts.
    if DEVICE == "cuda" and hasattr(pipe, "enable_model_cpu_offload"):
        pipe.enable_model_cpu_offload()
    else:
        pipe = pipe.to(DEVICE)

    final_memory = get_process_memory()
    print(f"Memory used by {model_name}: {final_memory - initial_memory:.2f} GB")
    return pipe
def save_generated_image(image, model_name, prompt):
    """Write ``image`` as a PNG into OUTPUT_DIR and return the file path.

    The file is named ``<YYYYmmdd_HHMMSS>_<model_name>_<slug>.png``. The slug
    is built from the first 30 characters of ``prompt``: it keeps only
    alphanumerics, spaces, ``-`` and ``_``, then strips surrounding whitespace.
    """
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    allowed_punctuation = (' ', '-', '_')
    slug_chars = [ch for ch in prompt[:30] if ch.isalnum() or ch in allowed_punctuation]
    slug = "".join(slug_chars).strip()
    destination = os.path.join(OUTPUT_DIR, f"{stamp}_{model_name}_{slug}.png")
    image.save(destination)
    return destination
def get_generated_images():
    """List the PNG files in OUTPUT_DIR, newest first by ``os.path.getctime``.

    Each entry is a dict with these keys:

    - ``path``: the file path.
    - ``name``: the basename.
    - ``date``: the ctime formatted as ``YYYY-mm-dd HH:MM:SS``.
    - ``size``: the size in KB with one decimal, e.g. ``"12.3 KB"``.
    """
    pattern = os.path.join(OUTPUT_DIR, "*.png")
    newest_first = sorted(glob.glob(pattern), key=os.path.getctime, reverse=True)

    entries = []
    for path in newest_first:
        created = datetime.fromtimestamp(os.path.getctime(path))
        entries.append({
            "path": path,
            "name": os.path.basename(path),
            "date": created.strftime("%Y-%m-%d %H:%M:%S"),
            "size": f"{os.path.getsize(path) / 1024:.1f} KB",
        })
    return entries
def _filter_call_kwargs(pipe, kwargs):
    """Return only the keyword arguments that ``pipe.__call__`` accepts.

    Pipelines may differ in which generation options they take, and an
    unsupported keyword raises TypeError. If the signature cannot be
    inspected, or it takes ``**kwargs``, ``kwargs`` is returned unchanged.
    """
    import inspect

    try:
        params = inspect.signature(pipe.__call__).parameters
    except (TypeError, ValueError):
        return kwargs
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return kwargs
    accepted = {k: v for k, v in kwargs.items() if k in params}
    dropped = sorted(set(kwargs) - set(accepted))
    if dropped:
        print(f"Ignoring arguments not supported by this pipeline: {', '.join(dropped)}")
    return accepted


def generate_image(
    model_name,
    prompt,
    negative_prompt="",
    seed=42,
    randomize_seed=False,
    width=1024,
    height=1024,
    guidance_scale=4.5,
    num_inference_steps=40,
    progress=gr.Progress(track_tqdm=True)
):
    """Generate one image with ``model_name`` and save it to OUTPUT_DIR.

    Runs under the model's lock. On first use the pipeline is loaded with
    load_pipeline and cached in ``pipes``. When ``randomize_seed`` is true, a
    fresh seed in ``[0, MAX_SEED]`` replaces ``seed``. Generation options the
    pipeline does not accept are dropped instead of causing a TypeError.

    Returns:
        A ``(image, seed)`` tuple: the first image of the pipeline output and
        the integer seed actually used.

    Raises:
        Re-raises any exception after printing it and tearing down the cached
        pipeline with deep_cleanup.
    """
    with model_locks[model_name]:
        try:
            if model_name not in pipes:
                pipes[model_name] = load_pipeline(model_name)
            pipe = pipes[model_name]
            if randomize_seed:
                seed = random.randint(0, MAX_SEED)
            # UI sliders may deliver floats; torch generators and pipelines need ints.
            seed = int(seed)
            generator = torch.Generator(DEVICE).manual_seed(seed)
            print(f"Generating image with {model_name}...")
            # NOTE: MODEL_CONFIGS has no "OneDiffusion" entry, so this branch is
            # currently unused.
            if model_name == "OneDiffusion":
                prompt = "[[text2image]] " + prompt
            call_kwargs = _filter_call_kwargs(pipe, {
                "prompt": prompt,
                "negative_prompt": negative_prompt,
                "guidance_scale": guidance_scale,
                "num_inference_steps": int(num_inference_steps),
                "width": int(width),
                "height": int(height),
                "generator": generator,
            })
            image = pipe(**call_kwargs).images[0]
            filepath = save_generated_image(image, model_name, prompt)
            print(f"Saved image to: {filepath}")
            return image, seed
        except Exception as e:
            print(f"Error with {model_name}: {str(e)}")
            if model_name in pipes:
                deep_cleanup(model_name, pipes[model_name])
            # Bare raise keeps the original traceback intact.
            raise
# Gradio Interface
css = """
#col-container {
    margin: 0 auto;
    max-width: 1024px;
}
"""


def _memory_status():
    """Format this process's current RSS for the memory indicator."""
    return f"Current memory usage: {get_process_memory():.2f} GB"


with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# Multi-Model Image Generation")
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
            run_button = gr.Button("Generate", scale=0, variant="primary")
        with gr.Accordion("Advanced Settings", open=False):
            negative_prompt = gr.Text(
                label="Negative prompt",
                max_lines=1,
                placeholder="Enter a negative prompt",
            )
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    minimum=512,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )
                height = gr.Slider(
                    label="Height",
                    minimum=512,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )
            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=0.0,
                    maximum=7.5,
                    step=0.1,
                    value=4.5,
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=40,
                )
        # Show the real figure at start-up instead of a hard-coded "0 GB".
        memory_indicator = gr.Markdown(_memory_status())
        with gr.Row():
            with gr.Column(scale=2):
                with gr.Tabs() as tabs:
                    results = {}
                    seeds = {}
                    for model_name in MODEL_CONFIGS.keys():
                        with gr.Tab(model_name):
                            results[model_name] = gr.Image(label=f"{model_name} Result")
                            seeds[model_name] = gr.Number(label="Seed used", visible=True)
            with gr.Column(scale=1):
                gr.Markdown("### Generated Images")
                file_gallery = gr.Gallery(
                    label="Generated Images",
                    show_label=False,
                    elem_id="file_gallery",
                    columns=3,
                    height=800,
                    visible=True
                )
                refresh_button = gr.Button("Refresh Gallery")

    def update_gallery():
        """Return ``(path, caption)`` pairs for the gallery, newest first."""
        files = get_generated_images()
        return [
            (f["path"], f"{f['name']}\n{f['date']}")
            for f in files
        ]

    def generate_all(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale,
                     num_inference_steps, progress=gr.Progress()):
        """Run every configured model in turn, streaming results to the UI.

        After each model (success or failure) this yields exactly one value
        per component in ``output_components``:

        - an ``(image, seed used)`` pair per model tab, with None for models
          that have not run yet or that failed;
        - the refreshed gallery;
        - the memory indicator text.
        """
        outputs = [None] * (len(MODEL_CONFIGS) * 2)
        for idx, model_name in enumerate(MODEL_CONFIGS):
            print(f"IMAGE GENERATING {model_name} ")
            try:
                image, used_seed = generate_image(
                    model_name, prompt, negative_prompt, seed,
                    randomize_seed, width, height, guidance_scale,
                    num_inference_steps, progress
                )
            except Exception as e:
                # generate_image already cleaned up; keep going with the next model.
                print(f"Error generating with {model_name}: {str(e)}")
                image, used_seed = None, None
            else:
                print(f"IMAGE GENERATIED {model_name} ")
            outputs[idx * 2] = image
            # Report the seed actually used (differs from the slider when randomized).
            outputs[idx * 2 + 1] = used_seed
            # Gradio requires one value per output component. The final state is
            # also yielded, because a generator's return value is ignored.
            yield [*outputs, update_gallery(), _memory_status()]

    output_components = []
    for model_name in MODEL_CONFIGS.keys():
        output_components.extend([results[model_name], seeds[model_name]])
    # Trailing components refreshed on every yield of generate_all.
    output_components.extend([file_gallery, memory_indicator])

    run_button.click(
        fn=generate_all,
        inputs=[
            prompt,
            negative_prompt,
            seed,
            randomize_seed,
            width,
            height,
            guidance_scale,
            num_inference_steps,
        ],
        outputs=output_components,
    )
    refresh_button.click(
        fn=update_gallery,
        inputs=[],
        outputs=[file_gallery],
    )
    demo.load(
        fn=update_gallery,
        inputs=[],
        outputs=[file_gallery],
    )

if __name__ == "__main__":
    # generate_all streams through the queue; enabling it explicitly keeps
    # streaming working on Gradio versions where the queue is off by default.
    demo.queue().launch(server_name='0.0.0.0')