# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys

# Must run before the `uno` import below so the bundled UNO project resolves.
sys.path.append(os.path.abspath("projects/UNO"))

import dataclasses
import gradio as gr
import spaces
import json
import torch
from huggingface_hub import hf_hub_download
from pathlib import Path

from uno.flux.pipeline import UNOPipeline


def get_examples(examples_dir: str = "assets/examples") -> list:
    """Collect the rows shown in the Gradio examples table.

    Every sub-directory of ``examples_dir`` must contain a ``config.json``
    with the keys ``prompt``, ``seed``, ``image_result_uno`` and
    ``image_result``; ``image_ref1``, ``image_ref2``, ``width`` and ``height``
    are optional.

    Args:
        examples_dir: Directory holding one sub-directory per example.

    Returns:
        A list of rows ``[prompt, ref1, ref2, width, height, seed,
        result_uno, result]``. Image entries are path strings rooted at the
        example directory; a missing reference image is ``None`` and a
        missing width/height falls back to 768.

    Raises:
        KeyError: If a required key is absent from a ``config.json``.
    """
    rows = []
    # Sorted so the example order is deterministic; iterdir() order is arbitrary.
    for example in sorted(Path(examples_dir).iterdir()):
        if not example.is_dir():
            continue

        # JSON is UTF-8 by specification; do not rely on the locale default.
        with open(example / "config.json", encoding="utf-8") as f:
            config = json.load(f)

        reference_images = [
            str(example / config[key]) if key in config else None
            for key in ("image_ref1", "image_ref2")
        ]
        rows.append(
            [
                config["prompt"],
                *reference_images,
                config.get("width", 768),
                config.get("height", 768),
                config["seed"],
                str(example / config["image_result_uno"]),
                str(example / config["image_result"]),
            ]
        )
    return rows


with open("assets/logo.svg", "r", encoding="utf-8") as svg_file:
    logo_content = svg_file.read()

title = f"""
{logo_content} UMO (based on UNO) by UXO Team
""".strip()

badges_text = r"""
Build Build Build
""".strip()

tips = """ 📌 ***UMO*** is a **U**nified **M**ulti-identity **O**ptimization framework to *boost the multi-ID fidelity and mitigate confusion* for image customization model, and the latest addition to the UXO family ( UMO, USO and UNO). 🎨 UMO in the demo is trained based on UNO. 💡 We provide step-by-step instructions in our Github Repo. Additionally, try the examples and comparison provided below the demo to quickly get familiar with UMO and spark your creativity! ⚡️ ***Tips for UMO based on UNO*** - Using description prompt instead of instruction one. - Using resolution 768~1024 instead of 512. - When reference identities are more than 2, the based model UNO becomes unstable. """.strip()

article = """ ```bibtex @article{cheng2025umo, title={UMO: Scaling Multi-Identity Consistency for Image Customization via Matching Reward}, author={Cheng, Yufeng and Wu, Wenxu and Wu, Shaojin and Huang, Mengqi and Ding, Fei and He, Qian}, journal={arXiv preprint arXiv:2509.06818}, year={2025} } ``` """.strip()

star = f""" If UMO is helpful, please help to ⭐ our Github Repo or cite our paper. Thanks a lot! 
{article} """


def create_demo(
    model_type: str,
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
    offload: bool = False,
    lora_path: str | None = None,
):
    """Build the Gradio Blocks UI for UMO running on top of the UNO pipeline.

    Args:
        model_type: Model name forwarded to ``UNOPipeline``.
        device: Device string forwarded to ``UNOPipeline`` and used for
            ``pipeline.model.to``.
        offload: Offload flag forwarded to ``UNOPipeline``.
        lora_path: Path to the LoRA checkpoint. When ``None`` the file
            ``UMO_UNO.safetensors`` is downloaded from the
            ``bytedance-research/UMO`` Hugging Face repository.

    Returns:
        The constructed ``gr.Blocks`` demo (not yet launched).
    """
    # create pipeline
    pipeline = UNOPipeline(model_type, device, offload, only_lora=True, lora_rank=512)
    if lora_path is None:
        lora_path = hf_hub_download("bytedance-research/UMO", "UMO_UNO.safetensors")
    pipeline.load_ckpt(lora_path)
    pipeline.model.to(device)
    # The keyword is `duration`; the previous misspelling meant the 120 s
    # request never took effect.
    pipeline.gradio_generate = spaces.GPU(duration=120)(pipeline.gradio_generate)

    # gradio
    # NOTE(review): the component nesting below was reconstructed from
    # flattened source; confirm the layout against the rendered UI.
    with gr.Blocks() as demo:
        gr.Markdown(title)
        gr.Markdown(badges_text)
        gr.Markdown(tips)
        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(label="Prompt", value="")
                with gr.Row():
                    image_prompt1 = gr.Image(label="Ref Img1", visible=True, interactive=True, type="pil")
                    image_prompt2 = gr.Image(label="Ref Img2", visible=True, interactive=True, type="pil")
                    image_prompt3 = gr.Image(label="Ref Img3", visible=True, interactive=True, type="pil")
                    image_prompt4 = gr.Image(label="Ref Img4", visible=True, interactive=True, type="pil")

                with gr.Row():
                    with gr.Column():
                        width = gr.Slider(512, 2048, 768, step=16, label="Generation Width")
                        height = gr.Slider(512, 2048, 768, step=16, label="Generation Height")

                with gr.Accordion("Advanced Options", open=False):
                    with gr.Row():
                        num_steps = gr.Slider(1, 50, 25, step=1, label="Number of steps")
                        guidance = gr.Slider(1.0, 5.0, 4.0, step=0.1, label="Guidance", interactive=True)
                        seed = gr.Number(-1, label="Seed (-1 for random)")

                generate_btn = gr.Button("Generate")

            with gr.Column():
                output_image = gr.Image(label="Generated Image")
                with gr.Accordion("Examples Comparison with UNO", open=False):
                    output_image_uno = gr.Image(label="Generated Image (UNO)")
                download_btn = gr.File(label="Download full-resolution", type="filepath", interactive=False)

            # Order must match the positional parameters of
            # pipeline.gradio_generate.
            inputs = [
                prompt, width, height, guidance, num_steps,
                seed, image_prompt1, image_prompt2, image_prompt3, image_prompt4
            ]
            generate_btn.click(
                fn=pipeline.gradio_generate,
                inputs=inputs,
                outputs=[output_image, download_btn],
            )

        gr.Markdown(star)

        examples = get_examples("./assets/examples/UNO")

        # The two output images are listed as example "inputs" so that
        # selecting a row fills them with the stored results.
        gr.Examples(
            examples=examples,
            inputs=[
                prompt,
                image_prompt1, image_prompt2,
                width, height,
                seed,
                output_image_uno, output_image
            ],
            label="We provide examples for academic research. The vast majority of images used in this demo are either generated or from open-source datasets. If you have any concerns, please contact us, and we will promptly remove any inappropriate content.",
        )

    return demo


if __name__ == "__main__":
    from typing import Literal

    from transformers import HfArgumentParser

    @dataclasses.dataclass
    class AppArgs:
        # Command-line options of the demo; parsed by HfArgumentParser.
        name: Literal["flux-dev", "flux-dev-fp8", "flux-schnell"] = "flux-dev"
        # "mps" is listed because the default below can evaluate to it.
        device: Literal["cuda", "mps", "cpu"] = (
            "cuda" if torch.cuda.is_available()
            else "mps" if torch.backends.mps.is_available()
            else "cpu"
        )
        offload: bool = dataclasses.field(
            default=False,
            metadata={"help": "If True, sequentially offload the models(ae, dit, text encoder) to CPU if not used."}
        )
        port: int = 7860
        server_name: str | None = None
        lora_path: str | None = None

    parser = HfArgumentParser([AppArgs])
    args_tuple = parser.parse_args_into_dataclasses()  # type: tuple[AppArgs]
    args = args_tuple[0]

    demo = create_demo(args.name, args.device, args.offload, args.lora_path)
    demo.launch(server_port=args.port, server_name=args.server_name, ssr_mode=False)