# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
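# make the vendored UNO codebase under projects/UNO importable (uno.flux.pipeline is imported below)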
sys.path.append(os.path.abspath("projects/UNO"))
import dataclasses
import gradio as gr
import spaces
import json
import torch
from huggingface_hub import hf_hub_download
from pathlib import Path
from uno.flux.pipeline import UNOPipeline


def get_examples(examples_dir: str = "assets/examples") -> list:
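    """Collect example rows for the gr.Examples component.

    Each subdirectory of examples_dir must contain a config.json with at least
    "prompt", "seed", "image_result_uno" and "image_result"; "image_ref1" and
    "image_ref2" are optional (None when absent) and "width"/"height" default
    to 768. The row order matches the Examples inputs defined in create_demo.
    """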
examples = Path(examples_dir)
ans = []
for example in examples.iterdir():
if not example.is_dir():
continue
with open(example / "config.json") as f:
example_dict = json.load(f)
example_list = []
example_list.append(example_dict["prompt"]) # prompt
for key in ["image_ref1", "image_ref2"]:
if key in example_dict:
example_list.append(str(example / example_dict[key]))
else:
example_list.append(None)
example_list.append(example_dict.get("width", 768))
example_list.append(example_dict.get("height", 768))
example_list.append(example_dict["seed"])
example_list.append(str(example / example_dict["image_result_uno"]))
example_list.append(str(example / example_dict["image_result"]))
ans.append(example_list)
return ans
with open("assets/logo.svg", "r", encoding="utf-8") as svg_file:
logo_content = svg_file.read()
title = f"""
<div style="display: flex; align-items: center; justify-content: center;">
<span style="transform: scale(0.7);margin-right: -5px;">{logo_content}</span>
<span style="font-size: 1.8em;margin-left: -10px;font-weight: bold; font-family: Gill Sans;">UMO (based on UNO) by UXO Team</span>
</div>
""".strip()
badges_text = r"""
<div style="text-align: center; display: flex; justify-content: center; gap: 5px;">
<a href="https://github.com/bytedance/UMO"><img alt="Build" src="https://img.shields.io/github/stars/bytedance/UMO"></a>
<a href="https://bytedance.github.io/UMO/"><img alt="Build" src="https://img.shields.io/badge/Project%20Page-UMO-blue"></a>
<a href="https://huggingface.co/bytedance-research/UMO"><img src="https://img.shields.io/static/v1?label=%F0%9F%A4%97%20Hugging%20Face&message=Model&color=green"></a>
<a href="https://arxiv.org/abs/2509.06818"><img alt="Build" src="https://img.shields.io/badge/arXiv%20paper-UMO-b31b1b.svg"></a>
<a href="https://huggingface.co/spaces/bytedance-research/UMO_UNO"><img src="https://img.shields.io/static/v1?label=%F0%9F%A4%97%20Demo&message=UMO-UNO&color=orange"></a>
<a href="https://huggingface.co/spaces/bytedance-research/UMO_OmniGen2"><img src="https://img.shields.io/static/v1?label=%F0%9F%A4%97%20Demo&message=UMO-OmniGen2&color=orange"></a>
</div>
""".strip()
tips = """
📌 ***UMO*** is a **U**nified **M**ulti-identity **O**ptimization framework to *boost the multi-ID fidelity and mitigate confusion* for image customization model, and the latest addition to the UXO family (<a href='https://github.com/bytedance/UMO' target='_blank'> UMO</a>, <a href='https://github.com/bytedance/USO' target='_blank'> USO</a> and <a href='https://github.com/bytedance/UNO' target='_blank'> UNO</a>).
🎨 UMO in the demo is trained based on <a href='https://github.com/bytedance/UNO' target='_blank'> UNO</a>.
💡 We provide step-by-step instructions in our <a href='https://github.com/bytedance/UMO' target='_blank'> Github Repo</a>. Additionally, try the examples and comparison provided below the demo to quickly get familiar with UMO and spark your creativity!
⚡️ ***Tips for UMO based on UNO***
- Using description prompt instead of instruction one.
- Using resolution 768~1024 instead of 512.
- When reference identities are more than 2, the based model UNO becomes unstable.
""".strip()
article = """
```bibtex
@article{cheng2025umo,
title={UMO: Scaling Multi-Identity Consistency for Image Customization via Matching Reward},
author={Cheng, Yufeng and Wu, Wenxu and Wu, Shaojin and Huang, Mengqi and Ding, Fei and He, Qian},
journal={arXiv preprint arXiv:2509.06818},
year={2025}
}
```
""".strip()
star = f"""
If UMO is helpful, please help to ⭐ our <a href='https://github.com/bytedance/UMO' target='_blank'> Github Repo</a> or cite our paper. Thanks a lot!
{article}
"""


def create_demo(
model_type: str,
device: str = "cuda" if torch.cuda.is_available() else "cpu",
offload: bool = False,
    lora_path: str | None = None,
):
# create pipeline
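    # UMO is distributed as a rank-512 LoRA on top of UNO, hence only_lora=True below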
pipeline = UNOPipeline(model_type, device, offload, only_lora=True, lora_rank=512)
lora_path = hf_hub_download("bytedance-research/UMO", "UMO_UNO.safetensors") if lora_path is None else lora_path
pipeline.load_ckpt(lora_path)
pipeline.model.to(device)
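    # on Hugging Face ZeroGPU Spaces, each generation call gets a GPU allocation of up to 120 s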
    pipeline.gradio_generate = spaces.GPU(duration=120)(pipeline.gradio_generate)
# gradio
with gr.Blocks() as demo:
gr.Markdown(title)
gr.Markdown(badges_text)
gr.Markdown(tips)
with gr.Row():
with gr.Column():
prompt = gr.Textbox(label="Prompt", value="")
with gr.Row():
image_prompt1 = gr.Image(label="Ref Img1", visible=True, interactive=True, type="pil")
image_prompt2 = gr.Image(label="Ref Img2", visible=True, interactive=True, type="pil")
image_prompt3 = gr.Image(label="Ref Img3", visible=True, interactive=True, type="pil")
image_prompt4 = gr.Image(label="Ref Img4", visible=True, interactive=True, type="pil")
with gr.Row():
with gr.Column():
                        width = gr.Slider(512, 2048, 768, step=16, label="Generation Width")
                        height = gr.Slider(512, 2048, 768, step=16, label="Generation Height")
with gr.Accordion("Advanced Options", open=False):
with gr.Row():
num_steps = gr.Slider(1, 50, 25, step=1, label="Number of steps")
guidance = gr.Slider(1.0, 5.0, 4.0, step=0.1, label="Guidance", interactive=True)
seed = gr.Number(-1, label="Seed (-1 for random)")
generate_btn = gr.Button("Generate")
with gr.Column():
output_image = gr.Image(label="Generated Image")
with gr.Accordion("Examples Comparison with UNO", open=False):
output_image_uno = gr.Image(label="Generated Image (UNO)")
download_btn = gr.File(label="Download full-resolution", type="filepath", interactive=False)
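        # positional inputs passed to pipeline.gradio_generate; order must match its signature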
inputs = [
prompt, width, height, guidance, num_steps,
seed, image_prompt1, image_prompt2, image_prompt3, image_prompt4
]
generate_btn.click(
fn=pipeline.gradio_generate,
inputs=inputs,
outputs=[output_image, download_btn],
)
gr.Markdown(star)
examples = get_examples("./assets/examples/UNO")
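        # each row maps to (prompt, ref1, ref2, width, height, seed, UNO result, UMO result); see get_examples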
gr.Examples(
examples=examples,
inputs=[
prompt,
image_prompt1, image_prompt2,
width, height,
seed, output_image_uno, output_image
],
label="We provide examples for academic research. The vast majority of images used in this demo are either generated or from open-source datasets. If you have any concerns, please contact us, and we will promptly remove any inappropriate content.",
)
return demo


if __name__ == "__main__":
from typing import Literal
from transformers import HfArgumentParser
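    # HfArgumentParser derives CLI flags from the AppArgs fields,
    # e.g. python app.py --name flux-dev --offload --port 7860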
@dataclasses.dataclass
class AppArgs:
name: Literal["flux-dev", "flux-dev-fp8", "flux-schnell"] = "flux-dev"
        device: Literal["cuda", "mps", "cpu"] = (
            "cuda" if torch.cuda.is_available()
            else "mps" if torch.backends.mps.is_available()
            else "cpu"
        )
        offload: bool = dataclasses.field(
            default=False,
            metadata={"help": "If True, sequentially offload the models (ae, dit, text encoder) to CPU when not in use."},
        )
port: int = 7860
server_name: str | None = None
lora_path: str | None = None
parser = HfArgumentParser([AppArgs])
args_tuple = parser.parse_args_into_dataclasses() # type: tuple[AppArgs]
args = args_tuple[0]
demo = create_demo(args.name, args.device, args.offload, args.lora_path)
demo.launch(server_port=args.port, server_name=args.server_name, ssr_mode=False)