import spaces
import gradio as gr
import numpy as np
import random
import torch
from diffusers import (
DiffusionPipeline, StableDiffusion3Pipeline, FluxPipeline, PixArtSigmaPipeline,
AuraFlowPipeline, Kandinsky3Pipeline, HunyuanDiTPipeline,
    LuminaText2ImgPipeline, SanaPipeline, AutoPipelineForText2Image
)
import gc
import os
import psutil
import threading
from pathlib import Path
import shutil
import time
import glob
from datetime import datetime
from PIL import Image
#import os
#cache_dir = '/workspace/hf_cache'
# Constants
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
TORCH_DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32
OUTPUT_DIR = "generated_images"
os.makedirs(OUTPUT_DIR, exist_ok=True)
# Model configurations
MODEL_CONFIGS = {
"FLUX": {
"repo_id": "black-forest-labs/FLUX.1-dev",
"pipeline_class": FluxPipeline,
# "cache_dir" : cache_dir
},
"Stable Diffusion 3.5": {
"repo_id": "stabilityai/stable-diffusion-3.5-large",
"pipeline_class": StableDiffusion3Pipeline,
#"cache_dir" : cache_dir
},
"PixArt": {
"repo_id": "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS",
"pipeline_class": PixArtSigmaPipeline,
#"cache_dir" : cache_dir
},
"SANA": {
"repo_id": "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers",
"pipeline_class": SanaPipeline,
# "cache_dir" : cache_dir
},
"AuraFlow": {
"repo_id": "fal/AuraFlow",
"pipeline_class": AuraFlowPipeline,
# "cache_dir" : cache_dir
},
"Kandinsky": {
"repo_id": "kandinsky-community/kandinsky-3",
"pipeline_class": Kandinsky3Pipeline,
#"cache_dir" : cache_dir
},
"Hunyuan": {
"repo_id": "Tencent-Hunyuan/HunyuanDiT-Diffusers",
"pipeline_class": HunyuanDiTPipeline,
#"cache_dir" : cache_dir
},
"Lumina": {
"repo_id": "Alpha-VLLM/Lumina-Next-SFT-diffusers",
"pipeline_class": LuminaText2ImgPipeline,
#"cache_dir" : cache_dir
},
}
# Dictionary to store model pipelines
pipes = {}
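# One lock per model so concurrent requests cannot load or run the same pipeline at the same time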
model_locks = {model_name: threading.Lock() for model_name in MODEL_CONFIGS.keys()}
def get_process_memory():
"""Get memory usage of current process in GB"""
process = psutil.Process(os.getpid())
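    # memory_info().rss is reported in bytes; convert to GiB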
return process.memory_info().rss / 1024 / 1024 / 1024
def clear_torch_cache():
"""Clear PyTorch's CUDA cache"""
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
def remove_cache_dir(model_name):
"""Remove the model's cache directory"""
    cache_dir = Path.home() / '.cache' / 'huggingface' / 'diffusers' / MODEL_CONFIGS[model_name]['repo_id'].replace('/', '--')
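    # Note: recent diffusers/huggingface_hub releases cache weights under ~/.cache/huggingface/hub,
    # so this legacy 'diffusers' path may not match where the files actually live.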
if cache_dir.exists():
shutil.rmtree(cache_dir, ignore_errors=True)
def deep_cleanup(model_name, pipe):
"""Perform deep cleanup of model resources"""
try:
# 1. Move model to CPU first (helps prevent CUDA memory fragmentation)
if hasattr(pipe, 'to'):
pipe.to('cpu')
# 2. Delete all model components explicitly
for attr_name in list(pipe.__dict__.keys()):
if hasattr(pipe, attr_name):
delattr(pipe, attr_name)
# 3. Remove from pipes dictionary
if model_name in pipes:
del pipes[model_name]
# 4. Clear CUDA cache
clear_torch_cache()
# 5. Run garbage collection multiple times
for _ in range(3):
gc.collect()
# 6. Remove cached files
remove_cache_dir(model_name)
# 7. Additional CUDA cleanup if available
if torch.cuda.is_available():
torch.cuda.synchronize()
# 8. Wait a small amount of time to ensure cleanup
time.sleep(1)
except Exception as e:
print(f"Error during cleanup of {model_name}: {str(e)}")
finally:
# Final garbage collection
gc.collect()
clear_torch_cache()
def load_pipeline(model_name):
"""Load model pipeline with memory tracking"""
initial_memory = get_process_memory()
config = MODEL_CONFIGS[model_name]
pipe = None
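    # Kandinsky 3 loads via AutoPipelineForText2Image with its fp16 variant; every other
    # model uses its dedicated pipeline class with the shared TORCH_DTYPE.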
if model_name == "Kandinsky":
print("Kandinsky Special")
pipe = AutoPipelineForText2Image.from_pretrained(
"kandinsky-community/kandinsky-3", variant="fp16", torch_dtype=torch.float16
)
else:
pipe = config["pipeline_class"].from_pretrained(
config["repo_id"],
torch_dtype=TORCH_DTYPE,
# cache_dir=cache_dir
)
    # Prefer model CPU offload (keeps only the active sub-module on the GPU);
    # otherwise move the whole pipeline to the device.
    if hasattr(pipe, 'enable_model_cpu_offload'):
        pipe.enable_model_cpu_offload()
    else:
        pipe = pipe.to(DEVICE)
final_memory = get_process_memory()
print(f"Memory used by {model_name}: {final_memory - initial_memory:.2f} GB")
return pipe
def save_generated_image(image, model_name, prompt):
"""Save generated image with timestamp and model name"""
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
# Create sanitized filename from prompt (first 30 chars)
prompt_part = "".join(c for c in prompt[:30] if c.isalnum() or c in (' ', '-', '_')).strip()
filename = f"{timestamp}_{model_name}_{prompt_part}.png"
filepath = os.path.join(OUTPUT_DIR, filename)
image.save(filepath)
return filepath
def get_generated_images():
"""Get list of generated images with their details"""
files = glob.glob(os.path.join(OUTPUT_DIR, "*.png"))
files.sort(key=os.path.getctime, reverse=True) # Sort by creation time
return [
{
"path": f,
"name": os.path.basename(f),
"date": datetime.fromtimestamp(os.path.getctime(f)).strftime("%Y-%m-%d %H:%M:%S"),
"size": f"{os.path.getsize(f) / 1024:.1f} KB"
}
for f in files
]
def generate_image(
model_name,
prompt,
negative_prompt="",
seed=42,
randomize_seed=False,
width=1024,
height=1024,
guidance_scale=4.5,
num_inference_steps=40,
progress=gr.Progress(track_tqdm=True)
):
with model_locks[model_name]:
try:
# progress(0, desc=f"Loading {model_name} model...")
if model_name not in pipes:
pipes[model_name] = load_pipeline(model_name)
pipe = pipes[model_name]
if randomize_seed:
seed = random.randint(0, MAX_SEED)
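            # Seed an explicit generator so results are reproducible for a given seed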
generator = torch.Generator(DEVICE).manual_seed(seed)
print(f"Generating image with {model_name}...")
# progress(0.3, desc=f"Generating image with {model_name}...")
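            # "OneDiffusion" is not in MODEL_CONFIGS, so this branch never runs; it is kept
            # over from an earlier model list (that pipeline expects a task-tag prefix).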
if model_name == "OneDiffusion":
prompt = "[[text2image]] " + prompt
image = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
guidance_scale=guidance_scale,
num_inference_steps=num_inference_steps,
width=width,
height=height,
generator=generator,
).images[0]
filepath = save_generated_image(image, model_name, prompt)
print(f"Saved image to: {filepath}")
# progress(0.9, desc=f"Cleaning up {model_name} resources...")
# deep_cleanup(model_name, pipe)
# progress(1.0, desc=f"Generation complete with {model_name}")
return image, seed
except Exception as e:
print(f"Error with {model_name}: {str(e)}")
if model_name in pipes:
deep_cleanup(model_name, pipes[model_name])
raise e
# Gradio Interface
css = """
#col-container {
margin: 0 auto;
max-width: 1024px;
}
"""
with gr.Blocks(css=css) as demo:
with gr.Column(elem_id="col-container"):
gr.Markdown("# Multi-Model Image Generation")
with gr.Row():
prompt = gr.Text(
label="Prompt",
show_label=False,
max_lines=1,
placeholder="Enter your prompt",
container=False,
)
run_button = gr.Button("Generate", scale=0, variant="primary")
with gr.Accordion("Advanced Settings", open=False):
negative_prompt = gr.Text(
label="Negative prompt",
max_lines=1,
placeholder="Enter a negative prompt",
)
seed = gr.Slider(
label="Seed",
minimum=0,
maximum=MAX_SEED,
step=1,
value=0,
)
randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
with gr.Row():
width = gr.Slider(
label="Width",
minimum=512,
maximum=MAX_IMAGE_SIZE,
step=32,
value=1024,
)
height = gr.Slider(
label="Height",
minimum=512,
maximum=MAX_IMAGE_SIZE,
step=32,
value=1024,
)
with gr.Row():
guidance_scale = gr.Slider(
label="Guidance scale",
minimum=0.0,
maximum=7.5,
step=0.1,
value=4.5,
)
num_inference_steps = gr.Slider(
label="Number of inference steps",
minimum=1,
maximum=50,
step=1,
value=40,
)
memory_indicator = gr.Markdown("Current memory usage: 0 GB")
with gr.Row():
with gr.Column(scale=2):
with gr.Tabs() as tabs:
results = {}
seeds = {}
for model_name in MODEL_CONFIGS.keys():
with gr.Tab(model_name):
results[model_name] = gr.Image(label=f"{model_name} Result")
seeds[model_name] = gr.Number(label="Seed used", visible=True)
with gr.Column(scale=1):
gr.Markdown("### Generated Images")
file_gallery = gr.Gallery(
label="Generated Images",
show_label=False,
elem_id="file_gallery",
columns=3,
height=800,
visible=True
)
refresh_button = gr.Button("Refresh Gallery")
def update_gallery():
"""Update the file gallery"""
files = get_generated_images()
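        # gr.Gallery accepts a list of (image_path, caption) tuples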
return [
(f["path"], f"{f['name']}\n{f['date']}")
for f in files
]
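    # Reserve the ZeroGPU slot long enough (400 s) to run all models back to back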
@spaces.GPU(duration=400)
def generate_all(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps,
progress=gr.Progress()):
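        # Flat list of [image, seed] slots, one pair per model, in the same order as output_components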
outputs = [None] * (len(MODEL_CONFIGS) * 2)
for idx, model_name in enumerate(MODEL_CONFIGS.keys()):
try:
# Display progress for the specific model
# progress(0, desc=f"Starting generation for {model_name}...")
print(f"IMAGE GENERATING {model_name} ")
image, used_seed = generate_image(
model_name, prompt, negative_prompt, seed,
randomize_seed, width, height, guidance_scale,
num_inference_steps, progress
)
                print(f"IMAGE GENERATED {model_name}")
# Update the respective model's tab with the generated image
# results[model_name].update(image)
# seeds[model_name].update(used_seed)
                outputs[idx * 2] = image  # Image slot
                outputs[idx * 2 + 1] = used_seed  # Seed slot
                # Yield intermediate results so each model's tab updates as soon as its image is ready
                print(f"YIELDING {model_name}")
                yield outputs
except Exception as e:
print(f"Error generating with {model_name}: {str(e)}")
outputs[idx * 2] = None
outputs[idx * 2 + 1] = None
        # The gallery refreshes via the Refresh button and demo.load; emit the final outputs
        yield outputs
output_components = []
for model_name in MODEL_CONFIGS.keys():
output_components.extend([results[model_name], seeds[model_name]])
run_button.click(
fn=generate_all,
inputs=[
prompt,
negative_prompt,
seed,
randomize_seed,
width,
height,
guidance_scale,
num_inference_steps,
],
outputs=output_components,
)
refresh_button.click(
fn=update_gallery,
inputs=[],
outputs=[file_gallery],
)
demo.load(
fn=update_gallery,
inputs=[],
outputs=[file_gallery],
)
if __name__ == "__main__":
demo.launch(server_name='0.0.0.0')