import spaces
import gradio as gr
import numpy as np
import random
import torch
from diffusers import (
    DiffusionPipeline,
    StableDiffusion3Pipeline,
    FluxPipeline,
    PixArtSigmaPipeline,
    AuraFlowPipeline,
    Kandinsky3Pipeline,
    HunyuanDiTPipeline,
    LuminaText2ImgPipeline,
    SanaPipeline,
    AutoPipelineForText2Image,
)
import gc
import os
import psutil
import threading
from pathlib import Path
import shutil
import time
import glob
from datetime import datetime
from PIL import Image
from onediffusion.diffusion.pipelines.onediffusion import OneDiffusionPipeline
from onediffusion.models.denoiser.nextdit import NextDiT
from onediffusion.dataset.utils import get_closest_ratio, ASPECT_RATIO_512

# cache_dir = '/workspace/hf_cache'

# Constants
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
TORCH_DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32

OUTPUT_DIR = "generated_images"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Model configurations
MODEL_CONFIGS = {
    "OneDiffusion": {
        "repo_id": "lehduong/OneDiffusion",
        "pipeline_class": OneDiffusionPipeline,
        # "cache_dir": cache_dir,
    },
}

# Dictionary of loaded pipelines, plus one lock per model so concurrent
# requests never run the same pipeline from two threads at once
pipes = {}
model_locks = {model_name: threading.Lock() for model_name in MODEL_CONFIGS.keys()}


def get_process_memory():
    """Get memory usage of the current process in GB."""
    process = psutil.Process(os.getpid())
    return process.memory_info().rss / 1024 / 1024 / 1024


def clear_torch_cache():
    """Clear PyTorch's CUDA cache."""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()


def remove_cache_dir(model_name):
    """Remove the model's Hugging Face cache directory."""
    cache_dir = (
        Path.home()
        / ".cache"
        / "huggingface"
        / "diffusers"
        / MODEL_CONFIGS[model_name]["repo_id"].replace("/", "--")
    )
    if cache_dir.exists():
        shutil.rmtree(cache_dir, ignore_errors=True)


def deep_cleanup(model_name, pipe):
    """Perform deep cleanup of model resources."""
    try:
        # 1. Move model to CPU first (helps prevent CUDA memory fragmentation)
        if hasattr(pipe, 'to'):
            pipe.to('cpu')

        # 2. Delete all model components explicitly
        for attr_name in list(pipe.__dict__.keys()):
            if hasattr(pipe, attr_name):
                delattr(pipe, attr_name)

        # 3. Remove from pipes dictionary
        if model_name in pipes:
            del pipes[model_name]

        # 4. Clear CUDA cache
        clear_torch_cache()

        # 5. Run garbage collection multiple times
        for _ in range(3):
            gc.collect()

        # 6. Remove cached files
        remove_cache_dir(model_name)

        # 7. Additional CUDA cleanup if available
        if torch.cuda.is_available():
            torch.cuda.synchronize()

        # 8. Wait a short moment to let the cleanup settle
        time.sleep(1)
    except Exception as e:
        print(f"Error during cleanup of {model_name}: {str(e)}")
    finally:
        # Final garbage collection
        gc.collect()
        clear_torch_cache()


def load_pipeline(model_name):
    """Load a model pipeline with memory tracking."""
    initial_memory = get_process_memory()
    config = MODEL_CONFIGS[model_name]
    pipe = None

    if model_name == "Kandinsky":
        # Special case: Kandinsky 3 is loaded via AutoPipeline with its fp16 variant
        print("Kandinsky Special")
        pipe = AutoPipelineForText2Image.from_pretrained(
            "kandinsky-community/kandinsky-3",
            variant="fp16",
            torch_dtype=torch.float16,
        )
    else:
        pipe = config["pipeline_class"].from_pretrained(
            config["repo_id"],
            torch_dtype=TORCH_DTYPE,
            # cache_dir=cache_dir,
        )
    pipe = pipe.to(DEVICE)

    if hasattr(pipe, 'enable_model_cpu_offload'):
        pipe.enable_model_cpu_offload()

    final_memory = get_process_memory()
    print(f"Memory used by {model_name}: {final_memory - initial_memory:.2f} GB")
    return pipe


def save_generated_image(image, model_name, prompt):
    """Save a generated image, named with a timestamp, the model name, and the prompt."""
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    # Create a sanitized filename from the first 30 characters of the prompt
    prompt_part = "".join(
        c for c in prompt[:30] if c.isalnum() or c in (' ', '-', '_')
    ).strip()
    filename = f"{timestamp}_{model_name}_{prompt_part}.png"
    filepath = os.path.join(OUTPUT_DIR, filename)
    image.save(filepath)
    return filepath


def get_generated_images():
    """Get the list of generated images with their details."""
    files = glob.glob(os.path.join(OUTPUT_DIR, "*.png"))
    files.sort(key=os.path.getctime, reverse=True)  # Sort by creation time, newest first
    return [
        {
            "path": f,
            "name": os.path.basename(f),
            "date": datetime.fromtimestamp(os.path.getctime(f)).strftime("%Y-%m-%d %H:%M:%S"),
            "size": f"{os.path.getsize(f) / 1024:.1f} KB",
        }
        for f in files
    ]


def generate_image(
    model_name,
    prompt,
    negative_prompt="",
    seed=42,
    randomize_seed=False,
    width=1024,
    height=1024,
    guidance_scale=4.5,
    num_inference_steps=40,
    progress=gr.Progress(track_tqdm=True),
):
    with model_locks[model_name]:
        try:
            # progress(0, desc=f"Loading {model_name} model...")
            if model_name not in pipes:
                pipes[model_name] = load_pipeline(model_name)
            pipe = pipes[model_name]

            if randomize_seed:
                seed = random.randint(0, MAX_SEED)
            generator = torch.Generator(DEVICE).manual_seed(seed)

            print(f"Generating image with {model_name}...")
            # progress(0.3, desc=f"Generating image with {model_name}...")

            if model_name == "OneDiffusion":
                # OneDiffusion expects a task token prefixed to the prompt
                prompt = "[[text2image]] " + prompt

            image = pipe(
                prompt=prompt,
                negative_prompt=negative_prompt,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps,
                width=width,
                height=height,
                generator=generator,
            ).images[0]

            filepath = save_generated_image(image, model_name, prompt)
            print(f"Saved image to: {filepath}")

            # progress(0.9, desc=f"Cleaning up {model_name} resources...")
            # deep_cleanup(model_name, pipe)
            # progress(1.0, desc=f"Generation complete with {model_name}")
            return image, seed
        except Exception as e:
            print(f"Error with {model_name}: {str(e)}")
            if model_name in pipes:
                deep_cleanup(model_name, pipes[model_name])
            raise e
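
# A minimal headless usage sketch (an illustration, not part of the app: it
# assumes the OneDiffusion weights can be downloaded, that enough VRAM is
# available, and the prompt is made up):
#
#     image, used_seed = generate_image(
#         "OneDiffusion",
#         "a cat wearing a space suit",
#         randomize_seed=True,
#     )
#     print(f"Used seed: {used_seed}")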

# Gradio Interface
css = """
#col-container {
    margin: 0 auto;
    max-width: 1024px;
}
"""

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# Multi-Model Image Generation")

        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
            run_button = gr.Button("Generate", scale=0, variant="primary")

        with gr.Accordion("Advanced Settings", open=False):
            negative_prompt = gr.Text(
                label="Negative prompt",
                max_lines=1,
                placeholder="Enter a negative prompt",
            )
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    minimum=512,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )
                height = gr.Slider(
                    label="Height",
                    minimum=512,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )

            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=0.0,
                    maximum=7.5,
                    step=0.1,
                    value=4.5,
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=40,
                )

        memory_indicator = gr.Markdown("Current memory usage: 0 GB")

        with gr.Row():
            with gr.Column(scale=2):
                with gr.Tabs() as tabs:
                    results = {}
                    seeds = {}
                    for model_name in MODEL_CONFIGS.keys():
                        with gr.Tab(model_name):
                            results[model_name] = gr.Image(label=f"{model_name} Result")
                            seeds[model_name] = gr.Number(label="Seed used", visible=True)

            with gr.Column(scale=1):
                gr.Markdown("### Generated Images")
                file_gallery = gr.Gallery(
                    label="Generated Images",
                    show_label=False,
                    elem_id="file_gallery",
                    columns=3,
                    height=800,
                    visible=True,
                )
                refresh_button = gr.Button("Refresh Gallery")

    def update_gallery():
        """Update the file gallery."""
        files = get_generated_images()
        return [(f["path"], f"{f['name']}\n{f['date']}") for f in files]

    @spaces.GPU(duration=400)
    def generate_all(
        prompt,
        negative_prompt,
        seed,
        randomize_seed,
        width,
        height,
        guidance_scale,
        num_inference_steps,
        progress=gr.Progress(),
    ):
        # One image slot and one seed slot per model, matching output_components
        outputs = [None] * (len(MODEL_CONFIGS) * 2)
        for idx, model_name in enumerate(MODEL_CONFIGS.keys()):
            try:
                # progress(0, desc=f"Starting generation for {model_name}...")
                print(f"IMAGE GENERATING {model_name}")
                image, used_seed = generate_image(
                    model_name,
                    prompt,
                    negative_prompt,
                    seed,
                    randomize_seed,
                    width,
                    height,
                    guidance_scale,
                    num_inference_steps,
                    progress,
                )
                print(f"IMAGE GENERATED {model_name}")

                # Update the respective model's tab with the generated image
                outputs[idx * 2] = image          # Image slot
                outputs[idx * 2 + 1] = used_seed  # Seed slot (may have been randomized)

                # Yield intermediate results so each tab updates as soon as its
                # model finishes, instead of waiting for the whole loop
                yield outputs
            except Exception as e:
                print(f"Error generating with {model_name}: {str(e)}")
                outputs[idx * 2] = None
                outputs[idx * 2 + 1] = None

        # Update the gallery after generation
        gallery_images = update_gallery()
        # file_gallery.update(value=gallery_images)

        yield outputs

    output_components = []
    for model_name in MODEL_CONFIGS.keys():
        output_components.extend([results[model_name], seeds[model_name]])

    run_button.click(
        fn=generate_all,
        inputs=[
            prompt,
            negative_prompt,
            seed,
            randomize_seed,
            width,
            height,
            guidance_scale,
            num_inference_steps,
        ],
        outputs=output_components,
    )

    refresh_button.click(
        fn=update_gallery,
        inputs=[],
        outputs=[file_gallery],
    )

    demo.load(
        fn=update_gallery,
        inputs=[],
        outputs=[file_gallery],
    )

if __name__ == "__main__":
    demo.launch(server_name='0.0.0.0')
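
# Adding another backend is a config change rather than a code change. A sketch
# follows (the Sana repo id below is an assumption; SanaPipeline is already
# imported at the top of this file). The entry must be added before the
# gr.Blocks context is built, since the tabs and model_locks are derived from
# MODEL_CONFIGS at that point:
#
#     MODEL_CONFIGS["Sana"] = {
#         "repo_id": "Efficient-Large-Model/Sana_1600M_1024px_diffusers",
#         "pipeline_class": SanaPipeline,
#     }
#
# generate_all() iterates over MODEL_CONFIGS, so a new entry automatically gets
# its own result tab, thread lock, and a slot in every generation run.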