# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
# Copyright (c) VectorSpaceLab and its affiliates. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import dotenv
dotenv.load_dotenv(override=True)
import gradio as gr
import spaces
import argparse
import json
import random
from datetime import datetime
from glob import glob
from typing import Literal
import torch
from torchvision.transforms.functional import to_tensor
from accelerate import Accelerator
from huggingface_hub import hf_hub_download
from peft import LoraConfig
from safetensors.torch import load_file
from omnigen2.pipelines.omnigen2.pipeline_omnigen2 import OmniGen2Pipeline
from omnigen2.models.transformers.transformer_omnigen2 import OmniGen2Transformer2DModel
from omnigen2.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
from omnigen2.schedulers.scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler
from omnigen2.utils.img_util import create_collage
NEGATIVE_PROMPT = "(((deformed))), blurry, over saturation, bad anatomy, disfigured, poorly drawn face, mutation, mutated, (extra_limb), (ugly), (poorly drawn hands), fused fingers, messy drawing, broken legs censor, censored, censor_bar"
SAVE_DIR = "output/gradio"
pipeline = None
accelerator = None
save_images = False
enable_taylorseer = False
enable_teacache = False
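# Load the base OmniGen2 pipeline and merge the UMO LoRA weights into its transformer.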
def load_pipeline(accelerator, weight_dtype, args):
pipeline = OmniGen2Pipeline.from_pretrained(
args.model_path,
torch_dtype=weight_dtype,
trust_remote_code=True,
)
pipeline.transformer = OmniGen2Transformer2DModel.from_pretrained(
args.model_path,
subfolder="transformer",
torch_dtype=weight_dtype,
)
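    # Fetch the UMO LoRA weights from the Hugging Face Hub unless a local --lora_path was given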
lora_path = hf_hub_download("bytedance-research/UMO", "UMO_OmniGen2.safetensors") if args.lora_path is None else args.lora_path
target_modules = ["to_k", "to_q", "to_v", "to_out.0"]
lora_config = LoraConfig(
r=512,
lora_alpha=512,
lora_dropout=0,
init_lora_weights="gaussian",
target_modules=target_modules,
)
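    # Attach a LoRA adapter, load the UMO weights into it, then fuse and drop the adapter so inference runs on the merged transformer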
pipeline.transformer.add_adapter(lora_config)
    lora_state_dict = load_file(lora_path, device=str(accelerator.device))
pipeline.transformer.load_state_dict(lora_state_dict, strict=False)
pipeline.transformer.fuse_lora(lora_scale=1, safe_fusing=False, adapter_names=["default"])
pipeline.transformer.unload_lora()
if args.enable_sequential_cpu_offload:
pipeline.enable_sequential_cpu_offload()
elif args.enable_model_cpu_offload:
pipeline.enable_model_cpu_offload()
else:
pipeline = pipeline.to(accelerator.device)
return pipeline
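# ZeroGPU on Hugging Face Spaces: each call may hold a GPU for up to 120 seconds.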
@spaces.GPU(duration=120)
def run(
instruction,
width_input,
height_input,
image_input_1,
image_input_2,
image_input_3,
scheduler: Literal["euler", "dpmsolver++"] = "euler",
num_inference_steps: int = 50,
negative_prompt: str = NEGATIVE_PROMPT,
guidance_scale_input: float = 5.0,
img_guidance_scale_input: float = 2.0,
cfg_range_start: float = 0.0,
cfg_range_end: float = 1.0,
num_images_per_prompt: int = 1,
max_input_image_side_length: int = 2048,
max_pixels: int = 1024 * 1024,
seed_input: int = -1,
    align_res: bool = True,
    enable_taylorseer: bool = True,
    enable_teacache: bool = False,
):
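    # Optional speed-ups selected in the UI; TaylorSeer takes precedence over TeaCache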
if enable_taylorseer:
pipeline.enable_taylorseer = True
elif enable_teacache:
pipeline.transformer.enable_teacache = True
pipeline.transformer.teacache_rel_l1_thresh = 0.05
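    # Keep only the reference images that were actually provided; None means pure text-to-image generation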
input_images = [image_input_1, image_input_2, image_input_3]
input_images = [img for img in input_images if img is not None]
if len(input_images) == 0:
input_images = None
if seed_input == -1:
seed_input = random.randint(0, 2**16 - 1)
    generator = torch.Generator(device="cpu").manual_seed(seed_input)  # use a CPU generator so results are reproducible across different GPUs
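    # Rebuild the scheduler only when the requested type differs from the one currently attached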
if scheduler == 'euler' and not isinstance(pipeline.scheduler, FlowMatchEulerDiscreteScheduler):
pipeline.scheduler = FlowMatchEulerDiscreteScheduler()
elif scheduler == 'dpmsolver++' and not isinstance(pipeline.scheduler, DPMSolverMultistepScheduler):
pipeline.scheduler = DPMSolverMultistepScheduler(
algorithm_type="dpmsolver++",
solver_type="midpoint",
solver_order=2,
prediction_type="flow_prediction",
)
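    # Run the diffusion pipeline with both text and image guidance over the configured CFG range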
results = pipeline(
prompt=instruction,
input_images=input_images,
width=width_input,
height=height_input,
align_res=align_res,
max_input_image_side_length=max_input_image_side_length,
max_pixels=max_pixels,
num_inference_steps=num_inference_steps,
max_sequence_length=1024,
text_guidance_scale=guidance_scale_input,
image_guidance_scale=img_guidance_scale_input,
cfg_range=(cfg_range_start, cfg_range_end),
negative_prompt=negative_prompt,
num_images_per_prompt=num_images_per_prompt,
generator=generator,
output_type="pil",
)
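    # Convert the PIL outputs to tensors in [-1, 1] and tile them into a single collage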
vis_images = [to_tensor(image) * 2 - 1 for image in results.images]
output_image = create_collage(vis_images)
output_path = ""
if save_images:
# Create outputs directory if it doesn't exist
output_dir = SAVE_DIR
os.makedirs(output_dir, exist_ok=True)
        # Build a unique filename from the timestamp, seed, and a prompt prefix
        timestamp = datetime.now().strftime("%Y_%m_%d-%H_%M_%S")
output_path = os.path.join(output_dir, f"{timestamp}_seed{seed_input}_{instruction[:20]}.png")
# Save the image
output_image.save(output_path)
        # Also save each individual image when more than one was generated
if len(results.images) > 1:
for i, image in enumerate(results.images):
image_name, ext = os.path.splitext(output_path)
image.save(f"{image_name}_{i}{ext}")
return output_image, output_path
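# Each example directory under base_dir contains a config.json whose keys mirror the Gradio inputs wired up in main().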
def get_examples(base_dir="assets/examples/OmniGen2"):
example_keys = ["instruction", "width_input", "height_input", "image_input_1", "image_input_2", "image_input_3", "seed_input", "align_res", "output_image", "output_image_OmniGen2"]
examples = []
example_configs = glob(os.path.join(base_dir, "*", "config.json"))
for config_path in example_configs:
with open(config_path, "r", encoding="utf-8") as f:
config = json.load(f)
_example = [config.get(k, None) for k in example_keys]
examples.append(_example)
return examples
with open("assets/logo.svg", "r", encoding="utf-8") as svg_file:
logo_content = svg_file.read()
title = f"""
<div style="display: flex; align-items: center; justify-content: center;">
<span style="transform: scale(0.7);margin-right: -5px;">{logo_content}</span>
<span style="font-size: 1.8em;margin-left: -10px;font-weight: bold; font-family: Gill Sans;">UMO (based on OmniGen2) by UXO Team</span>
</div>
""".strip()
badges_text = r"""
<div style="text-align: center; display: flex; justify-content: center; gap: 5px;">
<a href="https://github.com/bytedance/UMO"><img alt="Build" src="https://img.shields.io/github/stars/bytedance/UMO"></a>
<a href="https://bytedance.github.io/UMO/"><img alt="Build" src="https://img.shields.io/badge/Project%20Page-UMO-blue"></a>
<a href="https://huggingface.co/bytedance-research/UMO"><img src="https://img.shields.io/static/v1?label=%F0%9F%A4%97%20Hugging%20Face&message=Model&color=green"></a>
<a href="https://arxiv.org/abs/2509.06818"><img alt="Build" src="https://img.shields.io/badge/arXiv%20paper-UMO-b31b1b.svg"></a>
<a href="https://huggingface.co/spaces/bytedance-research/UMO_UNO"><img src="https://img.shields.io/static/v1?label=%F0%9F%A4%97%20Demo&message=UMO-UNO&color=orange"></a>
<a href="https://huggingface.co/spaces/bytedance-research/UMO_OmniGen2"><img src="https://img.shields.io/static/v1?label=%F0%9F%A4%97%20Demo&message=UMO-OmniGen2&color=orange"></a>
</div>
""".strip()
tips = """
📌 ***UMO*** is a **U**nified **M**ulti-identity **O**ptimization framework that *boosts multi-ID fidelity and mitigates identity confusion* for image customization models, and the latest addition to the UXO family (<a href='https://github.com/bytedance/UMO' target='_blank'> UMO</a>, <a href='https://github.com/bytedance/USO' target='_blank'> USO</a> and <a href='https://github.com/bytedance/UNO' target='_blank'> UNO</a>).
🎨 The UMO model in this demo is trained on top of <a href='https://github.com/VectorSpaceLab/OmniGen2' target='_blank'> OmniGen2</a>.
💡 We provide step-by-step instructions in our <a href='https://github.com/bytedance/UMO' target='_blank'> GitHub repo</a>. You can also try the examples and comparisons below the demo to quickly get familiar with UMO and spark your creativity!
<details>
<summary style="cursor: pointer; color: #d34c0e; font-weight: 500;"> ⚡️ Tips from the underlying OmniGen2</summary>
- Image Quality: Use high-resolution images (**at least 512x512 recommended**).
- Be Specific: Instead of "Add bird to desk", try "Add the bird from image 1 to the desk in image 2".
- Use English: English prompts currently yield better results.
- Increase image_guidance_scale for better consistency with the reference image:
- Image Editing: 1.3 - 2.0
- In-context Generation: 2.0 - 3.0
- For in-context editing (editing based on multiple images), we recommend the following prompt format: "Edit the first image: add/replace (the [object] with) the [object] from the second image. [description of your target image]."
- For example: "Edit the first image: add the man from the second image. The man is talking with a woman in the kitchen."
""".strip()
article = """
```bibtex
@article{cheng2025umo,
title={UMO: Scaling Multi-Identity Consistency for Image Customization via Matching Reward},
author={Cheng, Yufeng and Wu, Wenxu and Wu, Shaojin and Huang, Mengqi and Ding, Fei and He, Qian},
journal={arXiv preprint arXiv:2509.06818},
year={2025}
}
```
""".strip()
star = f"""
If you find UMO helpful, please ⭐ our <a href='https://github.com/bytedance/UMO' target='_blank'> GitHub repo</a> or cite our paper. Thanks a lot!
{article}
"""
def main(args):
# Gradio
with gr.Blocks() as demo:
gr.Markdown(title)
gr.Markdown(badges_text)
gr.Markdown(tips)
with gr.Row():
with gr.Column():
# text prompt
instruction = gr.Textbox(
label='Enter your prompt',
info='Use "first/second image" or “第一张图/第二张图” as reference.',
placeholder="Type your prompt here...",
)
with gr.Row(equal_height=True):
# input images
image_input_1 = gr.Image(label="First Image", type="pil")
image_input_2 = gr.Image(label="Second Image", type="pil")
image_input_3 = gr.Image(label="Third Image", type="pil")
generate_button = gr.Button("Generate Image")
negative_prompt = gr.Textbox(
label="Enter your negative prompt",
placeholder="Type your negative prompt here...",
value=NEGATIVE_PROMPT,
)
# slider
with gr.Row(equal_height=True):
height_input = gr.Slider(
label="Height", minimum=256, maximum=2048, value=1024, step=128
)
width_input = gr.Slider(
label="Width", minimum=256, maximum=2048, value=1024, step=128
)
with gr.Accordion("Speed Up Options", open=True):
with gr.Row(equal_height=True):
                        # Speed-up toggles; their values are forwarded to run() via the click inputs below
enable_taylorseer = gr.Checkbox(label="Using TaylorSeer to speed up", value=True)
enable_teacache = gr.Checkbox(label="Using TeaCache to speed up", value=False)
with gr.Row(equal_height=True):
scheduler_input = gr.Dropdown(
label="Scheduler",
choices=["euler", "dpmsolver++"],
value="euler",
info="The scheduler to use for the model.",
)
num_inference_steps = gr.Slider(
label="Inference Steps", minimum=20, maximum=100, value=50, step=1
)
with gr.Accordion("Advanced Options", open=False):
with gr.Row(equal_height=True):
align_res = gr.Checkbox(
label="Align Resolution",
info="Align output's resolution with the first reference image. Only valid when there is only one reference image.",
value=True
)
with gr.Row(equal_height=True):
text_guidance_scale_input = gr.Slider(
label="Text Guidance Scale",
minimum=1.0,
maximum=8.0,
value=5.0,
step=0.1,
)
image_guidance_scale_input = gr.Slider(
label="Image Guidance Scale",
minimum=1.0,
maximum=3.0,
value=2.0,
step=0.1,
)
with gr.Row(equal_height=True):
cfg_range_start = gr.Slider(
label="CFG Range Start",
minimum=0.0,
maximum=1.0,
value=0.0,
step=0.1,
)
cfg_range_end = gr.Slider(
label="CFG Range End",
minimum=0.0,
maximum=1.0,
value=1.0,
step=0.1,
)
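                    # Keep the CFG range consistent: its start can never exceed its end.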
def adjust_end_slider(start_val, end_val):
return max(start_val, end_val)
def adjust_start_slider(end_val, start_val):
return min(end_val, start_val)
cfg_range_start.input(
fn=adjust_end_slider,
inputs=[cfg_range_start, cfg_range_end],
outputs=[cfg_range_end]
)
cfg_range_end.input(
fn=adjust_start_slider,
inputs=[cfg_range_end, cfg_range_start],
outputs=[cfg_range_start]
)
with gr.Row(equal_height=True):
num_images_per_prompt = gr.Slider(
label="Number of images per prompt",
minimum=1,
maximum=4,
value=1,
step=1,
)
seed_input = gr.Slider(
label="Seed", minimum=-1, maximum=2147483647, value=-1, step=1
)
with gr.Row(equal_height=True):
max_input_image_side_length = gr.Slider(
label="max_input_image_side_length",
minimum=256,
maximum=2048,
value=2048,
step=256,
)
max_pixels = gr.Slider(
label="max_pixels",
minimum=256 * 256,
maximum=1536 * 1536,
value=1024 * 1024,
step=256 * 256,
)
with gr.Column():
with gr.Column():
# output image
output_image = gr.Image(label="Output Image")
global save_images
# save_images = gr.Checkbox(label="Save generated images", value=True)
save_images = True
with gr.Accordion("Examples Comparison with OmniGen2", open=False):
output_image_omnigen2 = gr.Image(label="Generated Image (OmniGen2)")
download_btn = gr.File(label="Download full-resolution", type="filepath", interactive=False)
gr.Markdown(star)
global accelerator
global pipeline
bf16 = True
accelerator = Accelerator(mixed_precision="bf16" if bf16 else "no")
weight_dtype = torch.bfloat16 if bf16 else torch.float32
pipeline = load_pipeline(accelerator, weight_dtype, args)
# click
generate_button.click(
run,
inputs=[
instruction,
width_input,
height_input,
image_input_1,
image_input_2,
image_input_3,
scheduler_input,
num_inference_steps,
negative_prompt,
text_guidance_scale_input,
image_guidance_scale_input,
cfg_range_start,
cfg_range_end,
num_images_per_prompt,
max_input_image_side_length,
max_pixels,
seed_input,
                align_res,
                enable_taylorseer,
                enable_teacache,
],
outputs=[output_image, download_btn],
)
gr.Examples(
examples=get_examples("assets/examples/OmniGen2"),
inputs=[
instruction,
width_input,
height_input,
image_input_1,
image_input_2,
image_input_3,
seed_input,
align_res,
output_image,
output_image_omnigen2,
],
label="We provide examples for academic research. The vast majority of images used in this demo are either generated or from open-source datasets. If you have any concerns, please contact us, and we will promptly remove any inappropriate content.",
examples_per_page=15
)
# launch
demo.launch(share=args.share, server_port=args.port, server_name=args.server_name, ssr_mode=False)
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--share", action="store_true", help="Share the Gradio app")
parser.add_argument(
"--port", type=int, default=7860, help="Port to use for the Gradio app"
)
    parser.add_argument(
        "--server_name", type=str, default=None, help="Host name or IP address to bind the Gradio server to (e.g. 0.0.0.0)"
    )
parser.add_argument(
"--model_path",
type=str,
default="OmniGen2/OmniGen2",
help="Path or HuggingFace name of the model to load."
)
parser.add_argument(
"--enable_model_cpu_offload",
action="store_true",
help="Enable model CPU offload."
)
parser.add_argument(
"--enable_sequential_cpu_offload",
action="store_true",
help="Enable sequential CPU offload."
)
parser.add_argument(
"--lora_path",
type=str,
default=None,
help="Path to the LoRA checkpoint to load."
)
args = parser.parse_args()
return args
if __name__ == "__main__":
args = parse_args()
main(args)
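
# Example local launches (assuming the omnigen2 package and model weights are available):
#   python app.py
#   python app.py --share --enable_model_cpu_offload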