OmniConsistency / app.py
yiren98's picture
small UI layout improvements (#3)
badf861 verified
import spaces
import os
import time
import torch
import gradio as gr
from PIL import Image
from huggingface_hub import hf_hub_download, list_repo_files
from src_inference.pipeline import FluxPipeline
from src_inference.lora_helper import set_single_lora
BASE_PATH = "black-forest-labs/FLUX.1-dev"
LOCAL_LORA_DIR = "./LoRAs"
CUSTOM_LORA_DIR = "./Custom_LoRAs"
os.makedirs(LOCAL_LORA_DIR, exist_ok=True)
os.makedirs(CUSTOM_LORA_DIR, exist_ok=True)
print("downloading OmniConsistency base LoRA …")
omni_consistency_path = hf_hub_download(
repo_id="showlab/OmniConsistency",
filename="OmniConsistency.safetensors",
local_dir="./Model"
)
print("loading base pipeline …")
pipe = FluxPipeline.from_pretrained(
BASE_PATH, torch_dtype=torch.bfloat16
).to("cuda")
set_single_lora(pipe.transformer, omni_consistency_path,
lora_weights=[1], cond_size=512)
def download_all_loras():
lora_names = [
"3D_Chibi", "American_Cartoon", "Chinese_Ink", "Clay_Toy",
"Fabric", "Ghibli", "Irasutoya", "Jojo", "LEGO", "Line",
"Macaron", "Oil_Painting", "Origami", "Paper_Cutting",
"Picasso", "Pixel", "Poly", "Pop_Art", "Rick_Morty",
"Snoopy", "Van_Gogh", "Vector"
]
for name in lora_names:
hf_hub_download(
repo_id="showlab/OmniConsistency",
filename=f"LoRAs/{name}_rank128_bf16.safetensors",
local_dir=LOCAL_LORA_DIR,
)
download_all_loras()
def clear_cache(transformer):
for _, attn_processor in transformer.attn_processors.items():
attn_processor.bank_kv.clear()
@spaces.GPU()
def generate_image(
lora_name,
custom_repo_id,
prompt,
uploaded_image,
width, height,
guidance_scale,
num_inference_steps,
seed
):
width, height = int(width), int(height)
generator = torch.Generator("cpu").manual_seed(seed)
if custom_repo_id and custom_repo_id.strip():
repo_id = custom_repo_id.strip()
try:
files = list_repo_files(repo_id)
print("using custom LoRA from:", repo_id)
safetensors_files = [f for f in files if f.endswith(".safetensors")]
print("found safetensors files:", safetensors_files)
if not safetensors_files:
raise ValueError("No .safetensors files were found in this repo")
fname = safetensors_files[0]
lora_path = hf_hub_download(
repo_id=repo_id,
filename=fname,
local_dir=CUSTOM_LORA_DIR,
)
except Exception as e:
raise gr.Error(f"Load custom LoRA failed: {e}")
else:
lora_path = os.path.join(
f"{LOCAL_LORA_DIR}/LoRAs", f"{lora_name}_rank128_bf16.safetensors"
)
pipe.unload_lora_weights()
try:
pipe.load_lora_weights(
os.path.dirname(lora_path),
weight_name=os.path.basename(lora_path)
)
except Exception as e:
raise gr.Error(f"Load LoRA failed: {e}")
spatial_image = [uploaded_image.convert("RGB")]
subject_images = []
start = time.time()
out_img = pipe(
prompt,
height=(height // 8) * 8,
width=(width // 8) * 8,
guidance_scale=guidance_scale,
num_inference_steps=num_inference_steps,
max_sequence_length=512,
generator=generator,
spatial_images=spatial_image,
subject_images=subject_images,
cond_size=512,
).images[0]
print(f"inference time: {time.time()-start:.2f}s")
clear_cache(pipe.transformer)
return uploaded_image, out_img
# =============== Gradio UI ===============
def create_interface():
demo_lora_names = [
"3D_Chibi", "American_Cartoon", "Chinese_Ink", "Clay_Toy",
"Fabric", "Ghibli", "Irasutoya", "Jojo", "LEGO", "Line",
"Macaron", "Oil_Painting", "Origami", "Paper_Cutting",
"Picasso", "Pixel", "Poly", "Pop_Art", "Rick_Morty",
"Snoopy", "Van_Gogh", "Vector"
]
def update_trigger_word(lora_name, prompt):
for name in demo_lora_names:
trigger = " ".join(name.split("_")) + " style,"
prompt = prompt.replace(trigger, "")
new_trigger = " ".join(lora_name.split("_"))+ " style,"
return new_trigger + prompt
# Example data
examples = [
["3D_Chibi", "", "3D Chibi style, Two smiling colleagues enthusiastically high-five in front of a whiteboard filled with technical notes about multimodal learning, reflecting a moment of success and collaboration at OpenAI.",
Image.open("./test_imgs/00.png"), 680, 1024, 3.5, 24, 42],
["Clay_Toy", "", "Clay Toy style, Three team members from OpenAI are gathered around a laptop in a cozy, festive setting, with holiday decorations in the background; one waves cheerfully while the others engage in light conversation, reflecting a relaxed and collaborative atmosphere.",
Image.open("./test_imgs/01.png"), 560, 1024, 3.5, 24, 42],
["American_Cartoon", "", "American Cartoon style, In a dramatic and comedic moment from a classic Chinese film, an intense elder with a white beard and red hat grips a younger man, declaring something with fervor, while the subtitle at the bottom reads 'I want them all' — capturing both tension and humor.",
Image.open("./test_imgs/02.png"), 568, 1024, 3.5, 24, 42],
["Origami", "", "Origami style, A thrilled fan wearing a Portugal football kit poses energetically with a smiling Cristiano Ronaldo, who gives a thumbs-up, as they stand side by side in a casual, cheerful moment—capturing the excitement of meeting a football legend.",
Image.open("./test_imgs/03.png"), 768, 672, 3.5, 24, 42],
["Vector", "", "Vector style, A man glances admiringly at a passing woman, while his girlfriend looks at him in disbelief, perfectly capturing the theme of shifting attention and misplaced priorities in a humorous, relatable way.",
Image.open("./test_imgs/04.png"), 512, 1024, 3.5, 24, 42]
]
header = """
<div style="text-align: center; display: flex; justify-content: left; gap: 5px;">
<a href="https://arxiv.org/abs/2505.18445"><img src="https://img.shields.io/badge/ariXv-2505.18445-A42C25.svg" alt="arXiv"></a>
<a href="https://huggingface.co/showlab/OmniConsistency"><img src="https://img.shields.io/badge/🤗_HuggingFace-Model-ffbd45.svg" alt="HuggingFace"></a>
<a href="https://github.com/showlab/OmniConsistency"><img src="https://img.shields.io/badge/GitHub-Code-blue.svg?logo=github&" alt="GitHub"></a>
</div>
"""
with gr.Blocks() as demo:
gr.Markdown("# OmniConsistency LoRA Image Generation")
gr.Markdown("Select a LoRA, enter a prompt, and upload an image to generate a new image with OmniConsistency.")
gr.HTML(header)
with gr.Row():
with gr.Column(scale=1):
image_input = gr.Image(type="pil", label="Upload Image")
prompt_box = gr.Textbox(label="Prompt",
value="3D Chibi style,",
info="Remember to include the necessary trigger words if you're using a custom LoRA."
)
lora_dropdown = gr.Dropdown(
demo_lora_names, label="Select built-in LoRA")
custom_repo_box = gr.Textbox(
label="Enter Custom LoRA",
placeholder="LoRA Hugging Face path (e.g., 'username/repo_name')",
info="If you want to use a custom LoRA, enter its Hugging Face repo ID here and built-in LoRA will be Overridden. Leave empty to use built-in LoRAs. [Check the list of FLUX LoRAs](https://huggingface.co/models?other=base_model:adapter:black-forest-labs/FLUX.1-dev)"
)
gen_btn = gr.Button("Generate")
with gr.Column(scale=1):
output_image = gr.ImageSlider(label="Generated Image")
with gr.Accordion("Advanced Options", open=False):
height_box = gr.Textbox(value="1024", label="Height")
width_box = gr.Textbox(value="1024", label="Width")
guidance_slider = gr.Slider(
0.1, 20, value=3.5, step=0.1, label="Guidance Scale")
steps_slider = gr.Slider(
1, 50, value=25, step=1, label="Inference Steps")
seed_slider = gr.Slider(
1, 2_147_483_647, value=42, step=1, label="Seed")
lora_dropdown.select(fn=update_trigger_word, inputs=[lora_dropdown,prompt_box],
outputs=prompt_box)
gr.Examples(
examples=examples,
inputs=[lora_dropdown, custom_repo_box, prompt_box, image_input,
height_box, width_box, guidance_slider, steps_slider, seed_slider],
outputs=output_image,
fn=generate_image,
cache_examples=False,
label="Examples"
)
gen_btn.click(
fn=generate_image,
inputs=[lora_dropdown, custom_repo_box, prompt_box, image_input,
width_box, height_box, guidance_slider, steps_slider, seed_slider],
outputs=output_image
)
return demo
if __name__ == "__main__":
demo = create_interface()
demo.launch(ssr_mode=False)