# Hugging Face Space: multi-model text-to-image comparison (runs on ZeroGPU)
import spaces
import gradio as gr
import numpy as np
import random
import torch
from diffusers import (
    DiffusionPipeline, StableDiffusion3Pipeline, FluxPipeline, PixArtSigmaPipeline,
    AuraFlowPipeline, Kandinsky3Pipeline, HunyuanDiTPipeline,
    LuminaText2ImgPipeline, SanaPipeline, AutoPipelineForText2Image
)
import gc
import os
import psutil
import threading
from pathlib import Path
import shutil
import time
import glob
from datetime import datetime
from PIL import Image

# cache_dir = '/workspace/hf_cache'  # optional custom Hugging Face cache location (currently unused)

# Constants
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
TORCH_DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32
OUTPUT_DIR = "generated_images"
os.makedirs(OUTPUT_DIR, exist_ok=True)
# Model configurations
MODEL_CONFIGS = {
    "FLUX": {
        "repo_id": "black-forest-labs/FLUX.1-dev",
        "pipeline_class": FluxPipeline,
        # "cache_dir": cache_dir,
    },
    "Stable Diffusion 3.5": {
        "repo_id": "stabilityai/stable-diffusion-3.5-large",
        "pipeline_class": StableDiffusion3Pipeline,
        # "cache_dir": cache_dir,
    },
    "PixArt": {
        "repo_id": "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS",
        "pipeline_class": PixArtSigmaPipeline,
        # "cache_dir": cache_dir,
    },
    "SANA": {
        "repo_id": "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers",
        "pipeline_class": SanaPipeline,
        # "cache_dir": cache_dir,
    },
    "AuraFlow": {
        "repo_id": "fal/AuraFlow",
        "pipeline_class": AuraFlowPipeline,
        # "cache_dir": cache_dir,
    },
    "Kandinsky": {
        "repo_id": "kandinsky-community/kandinsky-3",
        "pipeline_class": Kandinsky3Pipeline,
        # "cache_dir": cache_dir,
    },
    "Hunyuan": {
        "repo_id": "Tencent-Hunyuan/HunyuanDiT-Diffusers",
        "pipeline_class": HunyuanDiTPipeline,
        # "cache_dir": cache_dir,
    },
    "Lumina": {
        "repo_id": "Alpha-VLLM/Lumina-Next-SFT-diffusers",
        "pipeline_class": LuminaText2ImgPipeline,
        # "cache_dir": cache_dir,
    },
}
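# Illustrative sketch (not part of the original configuration): adding another diffusers
# text-to-image model only requires its repo id and pipeline class. The SDXL entry below
# is an example; AutoPipelineForText2Image resolves the concrete pipeline automatically.
# MODEL_CONFIGS["SDXL"] = {
#     "repo_id": "stabilityai/stable-diffusion-xl-base-1.0",
#     "pipeline_class": AutoPipelineForText2Image,
# }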
# Dictionary to store loaded model pipelines, with one lock per model
pipes = {}
model_locks = {model_name: threading.Lock() for model_name in MODEL_CONFIGS.keys()}


def get_process_memory():
    """Get memory usage of the current process in GB."""
    process = psutil.Process(os.getpid())
    return process.memory_info().rss / 1024 / 1024 / 1024


def clear_torch_cache():
    """Clear PyTorch's CUDA cache."""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()


def remove_cache_dir(model_name):
    """Remove the model's cache directory.

    Note: depending on the huggingface_hub version, downloads may instead live under
    ~/.cache/huggingface/hub/models--<org>--<name>; this targets the legacy diffusers path.
    """
    cache_dir = (
        Path.home() / '.cache' / 'huggingface' / 'diffusers'
        / MODEL_CONFIGS[model_name]['repo_id'].replace('/', '--')
    )
    if cache_dir.exists():
        shutil.rmtree(cache_dir, ignore_errors=True)
def deep_cleanup(model_name, pipe):
    """Perform deep cleanup of model resources."""
    try:
        # 1. Move model to CPU first (helps prevent CUDA memory fragmentation)
        if hasattr(pipe, 'to'):
            pipe.to('cpu')

        # 2. Delete all model components explicitly
        for attr_name in list(pipe.__dict__.keys()):
            if hasattr(pipe, attr_name):
                delattr(pipe, attr_name)

        # 3. Remove from pipes dictionary
        if model_name in pipes:
            del pipes[model_name]

        # 4. Clear CUDA cache
        clear_torch_cache()

        # 5. Run garbage collection multiple times
        for _ in range(3):
            gc.collect()

        # 6. Remove cached files
        remove_cache_dir(model_name)

        # 7. Additional CUDA cleanup if available
        if torch.cuda.is_available():
            torch.cuda.synchronize()

        # 8. Wait a small amount of time to ensure cleanup
        time.sleep(1)
    except Exception as e:
        print(f"Error during cleanup of {model_name}: {str(e)}")
    finally:
        # Final garbage collection
        gc.collect()
        clear_torch_cache()
def load_pipeline(model_name):
    """Load a model pipeline with memory tracking."""
    initial_memory = get_process_memory()
    config = MODEL_CONFIGS[model_name]

    if model_name == "Kandinsky":
        # Kandinsky 3 is loaded via AutoPipelineForText2Image with its fp16 variant
        pipe = AutoPipelineForText2Image.from_pretrained(
            "kandinsky-community/kandinsky-3", variant="fp16", torch_dtype=torch.float16
        )
    else:
        pipe = config["pipeline_class"].from_pretrained(
            config["repo_id"],
            torch_dtype=TORCH_DTYPE,
            # cache_dir=cache_dir,
        )

    pipe = pipe.to(DEVICE)
    if hasattr(pipe, 'enable_model_cpu_offload'):
        pipe.enable_model_cpu_offload()

    final_memory = get_process_memory()
    print(f"Memory used by {model_name}: {final_memory - initial_memory:.2f} GB")
    return pipe
def save_generated_image(image, model_name, prompt):
    """Save a generated image with timestamp and model name."""
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    # Create a sanitized filename from the prompt (first 30 chars)
    prompt_part = "".join(c for c in prompt[:30] if c.isalnum() or c in (' ', '-', '_')).strip()
    filename = f"{timestamp}_{model_name}_{prompt_part}.png"
    filepath = os.path.join(OUTPUT_DIR, filename)
    image.save(filepath)
    return filepath


def get_generated_images():
    """Get the list of generated images with their details."""
    files = glob.glob(os.path.join(OUTPUT_DIR, "*.png"))
    files.sort(key=os.path.getctime, reverse=True)  # newest first
    return [
        {
            "path": f,
            "name": os.path.basename(f),
            "date": datetime.fromtimestamp(os.path.getctime(f)).strftime("%Y-%m-%d %H:%M:%S"),
            "size": f"{os.path.getsize(f) / 1024:.1f} KB",
        }
        for f in files
    ]
def generate_image(
    model_name,
    prompt,
    negative_prompt="",
    seed=42,
    randomize_seed=False,
    width=1024,
    height=1024,
    guidance_scale=4.5,
    num_inference_steps=40,
    progress=gr.Progress(track_tqdm=True),
):
    with model_locks[model_name]:
        try:
            # Lazily load the pipeline the first time this model is requested
            if model_name not in pipes:
                pipes[model_name] = load_pipeline(model_name)
            pipe = pipes[model_name]

            if randomize_seed:
                seed = random.randint(0, MAX_SEED)
            generator = torch.Generator(DEVICE).manual_seed(seed)

            print(f"Generating image with {model_name}...")
            # OneDiffusion expects a task token prefix (no such entry exists in MODEL_CONFIGS yet)
            if model_name == "OneDiffusion":
                prompt = "[[text2image]] " + prompt

            image = pipe(
                prompt=prompt,
                negative_prompt=negative_prompt,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps,
                width=width,
                height=height,
                generator=generator,
            ).images[0]

            filepath = save_generated_image(image, model_name, prompt)
            print(f"Saved image to: {filepath}")

            # deep_cleanup(model_name, pipe)  # optional: free the model after each generation
            return image, seed
        except Exception as e:
            print(f"Error with {model_name}: {str(e)}")
            if model_name in pipes:
                deep_cleanup(model_name, pipes[model_name])
            raise e
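# Minimal usage sketch (illustrative only; the prompt and step count below are example
# values): generate_image can be called directly for a quick smoke test, and the chosen
# model is loaded lazily on first use, e.g.
#
#     image, used_seed = generate_image("FLUX", "a watercolor fox", num_inference_steps=20)
#     image.save("smoke_test.png")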
# Gradio Interface
css = """
#col-container {
    margin: 0 auto;
    max-width: 1024px;
}
"""

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# Multi-Model Image Generation")

        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
            run_button = gr.Button("Generate", scale=0, variant="primary")

        with gr.Accordion("Advanced Settings", open=False):
            negative_prompt = gr.Text(
                label="Negative prompt",
                max_lines=1,
                placeholder="Enter a negative prompt",
            )
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    minimum=512,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )
                height = gr.Slider(
                    label="Height",
                    minimum=512,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )
            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=0.0,
                    maximum=7.5,
                    step=0.1,
                    value=4.5,
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=40,
                )

        memory_indicator = gr.Markdown("Current memory usage: 0 GB")

        with gr.Row():
            with gr.Column(scale=2):
                with gr.Tabs() as tabs:
                    results = {}
                    seeds = {}
                    for model_name in MODEL_CONFIGS.keys():
                        with gr.Tab(model_name):
                            results[model_name] = gr.Image(label=f"{model_name} Result")
                            seeds[model_name] = gr.Number(label="Seed used", visible=True)
            with gr.Column(scale=1):
                gr.Markdown("### Generated Images")
                file_gallery = gr.Gallery(
                    label="Generated Images",
                    show_label=False,
                    elem_id="file_gallery",
                    columns=3,
                    height=800,
                    visible=True,
                )
                refresh_button = gr.Button("Refresh Gallery")

    def update_gallery():
        """Update the file gallery."""
        files = get_generated_images()
        return [
            (f["path"], f"{f['name']}\n{f['date']}")
            for f in files
        ]
    def generate_all(prompt, negative_prompt, seed, randomize_seed, width, height,
                     guidance_scale, num_inference_steps, progress=gr.Progress()):
        """Run every configured model on the same prompt, streaming results tab by tab."""
        outputs = [None] * (len(MODEL_CONFIGS) * 2)
        for idx, model_name in enumerate(MODEL_CONFIGS.keys()):
            try:
                print(f"Generating with {model_name}...")
                image, used_seed = generate_image(
                    model_name, prompt, negative_prompt, seed,
                    randomize_seed, width, height, guidance_scale,
                    num_inference_steps, progress
                )
                print(f"Finished generating with {model_name}")
                outputs[idx * 2] = image          # image slot for this model's tab
                outputs[idx * 2 + 1] = used_seed  # seed that was actually used
            except Exception as e:
                print(f"Error generating with {model_name}: {str(e)}")
                outputs[idx * 2] = None
                outputs[idx * 2 + 1] = None
            # Stream the intermediate results; the yielded list must match
            # output_components in length (one image and one seed per model).
            yield list(outputs)
        # The file gallery is refreshed separately via the Refresh button and demo.load.
    output_components = []
    for model_name in MODEL_CONFIGS.keys():
        output_components.extend([results[model_name], seeds[model_name]])

    run_button.click(
        fn=generate_all,
        inputs=[
            prompt,
            negative_prompt,
            seed,
            randomize_seed,
            width,
            height,
            guidance_scale,
            num_inference_steps,
        ],
        outputs=output_components,
    )

    refresh_button.click(
        fn=update_gallery,
        inputs=[],
        outputs=[file_gallery],
    )

    demo.load(
        fn=update_gallery,
        inputs=[],
        outputs=[file_gallery],
    )
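# Note (assumption about the installed Gradio version): generate_all is a generator, so
# on older Gradio releases streaming its intermediate results requires enabling the queue,
# e.g. demo.queue() before launch; recent Gradio versions queue events by default.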
if __name__ == "__main__":
    demo.launch(server_name='0.0.0.0')