# Voost: Virtual Try-On/Off demo (app.py)
from io import BytesIO
import gradio as gr
from PIL import Image
import httpx
from gradio_toggle import Toggle
from pathlib import Path
import numpy as np
import os
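
# Inference backend configuration, supplied via environment variables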
api_server = os.environ["NXN_API_SERVER"]
tryon_endpoint = os.environ["NXN_TRYON_ENDPOINT"]
tryoff_endpoint = os.environ["NXN_TRYOFF_ENDPOINT"]
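
# Per-side size limits; inputs outside this range are rescaled before being sent to the API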
MAX_DIM = 2048
MIN_DIM = 500
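
# Serialize a PIL image into an in-memory buffer for multipart upload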
def encode_bytes(image: Image.Image, format="PNG"):
buffered = BytesIO()
image.save(buffered, format=format)
buffered.seek(0)
return buffered
# Map the garment type label from the UI to the integer code the API expects
def garment_type_to_int(garment_type: str) -> int:
    garment_dict = {"Upper": 0, "Lower": 1, "Full": 2}
    if garment_type not in garment_dict:
        raise gr.Error("Unexpected garment condition error")
    return garment_dict[garment_type]
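
# The ImageEditor may return a dict; pull out the background image and convert to RGB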
def extract_image_from_input(image_data):
if isinstance(image_data, dict) and "background" in image_data:
return image_data["background"].convert("RGB")
else:
return image_data.convert("RGB")
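
# Scale an image into the [MIN_DIM, MAX_DIM] range; returns the image and whether it was resized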
def resize_image_if_needed(image: Image.Image):
if image is None:
return None, False
original_width, original_height = image.size
if original_width > MAX_DIM or original_height > MAX_DIM:
gr.Warning("A provided image is too large and has been resized")
scale_factor = min(MAX_DIM / original_width, MAX_DIM / original_height)
new_width = int(original_width * scale_factor)
new_height = int(original_height * scale_factor)
return image.resize((new_width, new_height), Image.Resampling.LANCZOS), True
elif original_width < MIN_DIM or original_height < MIN_DIM:
gr.Warning("A provided image is too small and has been resized")
scale_factor = max(MIN_DIM / original_width, MIN_DIM / original_height)
new_width = int(original_width * scale_factor)
new_height = int(original_height * scale_factor)
return image.resize((new_width, new_height), Image.Resampling.LANCZOS), True
return image, False
# API Helpers
async def _call_api(url: str, files: list, data: dict):
    try:
        async with httpx.AsyncClient(timeout=3600) as client:
            response = await client.post(url, data=data, files=files)
            response.raise_for_status()
            return Image.open(BytesIO(response.content))
    except httpx.RequestError as e:
        print(f"API request failed: {e}")
        raise gr.Error("Network error: Could not connect to the model API. Please try again later.")
    except Exception as e:
        print(f"An unexpected error occurred: {e}")
        raise gr.Error("An unexpected error occurred. The model may have failed to process the images.")
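
# Send the model image, garment image, and optional user-drawn mask to the try-on endpoint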
async def call_tryon_api(model_image: Image.Image, garment_image: Image.Image, garment_type: int, mask: Image.Image=None, seed: int=1234):
files = [
("images", ("target.png", encode_bytes(model_image), "image/png")),
("images", ("garment.png", encode_bytes(garment_image), "image/png"))
]
if mask:
files.append(("images", ("mask.png", encode_bytes(mask, format="PNG"), "image/png")))
data = {'garment_type': garment_type, 'seed': seed}
return await _call_api(f"{api_server}/{tryon_endpoint}", files=files, data=data)
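
# Send only the model image to the try-off endpoint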
async def call_tryoff_api(model_image: Image.Image, garment_type: int, seed: int=1234):
files = [ ("images", ("target.png", encode_bytes(model_image), "image/png")) ]
data = {'garment_type': garment_type, 'seed': seed}
return await _call_api(f"{api_server}/{tryoff_endpoint}", files=files, data=data)
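
# Entry point for the submit button: validate inputs, normalize sizes, build the optional mask, and call the API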
async def api_helper(model_image_dict: dict, garment_image: Image.Image, garment_type: str, is_tryoff: bool, seed: int):
    if model_image_dict is None or (isinstance(model_image_dict, dict) and model_image_dict.get("background") is None):
        raise gr.Error("Missing model image")
    elif not is_tryoff and garment_image is None:
        raise gr.Error("Missing garment image for Try-On")
    # The Gradio ImageEditor returns a dict with the background image and any drawn layers
model_image = extract_image_from_input(model_image_dict)
model_image, model_resized = resize_image_if_needed(model_image)
garment_image, _ = resize_image_if_needed(garment_image)
garment_type_int = garment_type_to_int(garment_type)
if is_tryoff:
return await call_tryoff_api(model_image, garment_type_int, seed)
else:
mask_image = None
if isinstance(model_image_dict, dict) and model_image_dict.get("layers"):
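            # The first editor layer holds the user's brush strokes as an RGBA image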
mask = model_image_dict["layers"][0]
mask_array = np.array(mask)
if not np.all(mask_array < 10):
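                # Painted pixels (any channel above the near-black threshold) become white; untouched pixels stay black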
is_black = np.all(mask_array < 10, axis=2)
mask_image = Image.fromarray(((~is_black) * 255).astype(np.uint8))
if model_resized:
mask_image = mask_image.resize(model_image.size, Image.Resampling.NEAREST)
else:
gr.Info("No mask provided, using auto-generated mask")
return await call_tryon_api(model_image, garment_image, garment_type_int, mask=mask_image, seed=seed)
# Event handler functions
def handle_toggle(toggle_value):
"""Handle toggle state changes - controls garment input visibility"""
toggle_label = gr.update(value=toggle_value, label="Try-Off") if toggle_value else gr.update(value=toggle_value, label="Try-On")
submit_btn_label = gr.update(value="Run Try-Off", elem_id="tryoff-color") if toggle_value else gr.update(value="Run Try-On", elem_id="tryon-color")
if toggle_value:
# Clear the image and disable the component
return gr.update(value=None, elem_classes=["disabled-image"], interactive=False), toggle_label, submit_btn_label
else:
# Re-enable the component without clearing the image
return gr.update(elem_classes=[], interactive=True), toggle_label, submit_btn_label
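
# Wrappers used by the example galleries; the image arguments only populate the inputs and are not used here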
def set_tryon(garment_img, model_img, output_img, garment_condition):
garment_update, toggle_label, submit_btn_label = handle_toggle(False)
return garment_update, toggle_label, submit_btn_label
def set_tryoff(model_img, output_img, garment_condition):
garment_update, toggle_label, submit_btn_label = handle_toggle(True)
return garment_update, toggle_label, submit_btn_label
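
# Sort example garments as upper, then lower, then full-body, then everything else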
def garment_sort_key(filename):
if filename.startswith("upper_"):
return (0, filename)
elif filename.startswith("lower_"):
return (1, filename)
elif filename.startswith("full_"):
return (2, filename)
else:
return (3, filename)
# Get images for examples
images_path = os.path.join(os.path.dirname(__file__), "images")
garment_list = os.listdir(os.path.join(images_path, "garments"))
garment_list_path = [
os.path.join(images_path, "garments", cloth)
for cloth in sorted(garment_list, key=garment_sort_key)
]
people_list = os.listdir(os.path.join(images_path, "persons"))
people_list_path = [os.path.join(images_path, "persons", human) for human in sorted(people_list)]
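# Serve the local images directory directly so the header logo and example files resolve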
gr.set_static_paths(paths=[Path.cwd().absolute()/"images"])
# Create the Gradio interface
with gr.Blocks(css_paths="styles.css", theme=gr.themes.Ocean(), title="Voost: Virtual Try-On/Off") as demo:
with gr.Row():
gr.HTML("""
<div class="header-container">
<div class="logo-container">
<a href="https://nxn.ai/">
<picture>
<source media="(prefers-color-scheme: dark)" srcset="/gradio_api/file=images/dark_mode_logo.png"/>
<img src='/gradio_api/file=images/nxn_logo_transparent.png' style="height: 120px; width: 150px;"/>
</picture>
</a>
</div>
<div style="display: flex; flex-direction: column; align-items: center; text-align: center;">
<div style="font-size: 45px; margin-bottom: 10px;">
<b>Voost: Virtual Try-On/Off</b>
</div>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<a href="https://arxiv.org/abs/2508.04825">
<img src='https://img.shields.io/badge/arXiv-2508.04825-red?style=flat&logo=arXiv&logoColor=red' alt='arxiv'>
</a> &ensp;
<a href='https://nxnai.github.io/Voost/'>
<img src='https://img.shields.io/badge/Webpage-Project-silver?style=flat&logo=&logoColor=orange' alt='webpage'>
</a> &ensp;
<a href="https://github.com/nxnai/Voost">
<img src='https://img.shields.io/badge/GitHub-Repo-blue?style=flat&logo=GitHub' alt='GitHub'>
</a> &ensp;
<a href="https://github.com/nxnai/Voost/blob/main/LICENSE">
<img src='https://img.shields.io/badge/License-CC BY--NC--SA--4.0-lightgreen?style=flat&logo=Lisence' alt='License'>
</a>
</div>
<div style="font-size: 14px; color: #666; margin-top: 5px;">
Website: <a href="https://nxn.ai" target="_blank">https://nxn.ai</a> &nbsp; &nbsp; &nbsp; Inquiries: <a href="mailto:[email protected]">[email protected]</a>
</div>
</div>
</div>
""")
gr.Markdown("---")
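    # Steps 1 and 2: mode toggle and garment-type selector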
with gr.Row():
with gr.Column(scale=1):
gr.HTML("<center><h4>Step 1: Select <em>Try-On</em> or <em>Try-Off</em> mode. </h4></center>")
input_toggle = Toggle(
label="Try-On",
value=False,
interactive=True,
elem_classes=["button-container"],
color="rgba(177, 162, 239, .5)",
elem_id="toggle-modify"
)
with gr.Column(scale=1):
gr.HTML("<center><h4>Step 2: Select your desired garment type.</h4></center>")
garment_condition = gr.Radio(
choices=["Upper", "Lower", "Full"],
value="Upper",
interactive=True,
elem_classes=["center-item"],
show_label=False,
label="Garment Type"
)
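    # Steps 3-5: model image editor, garment image, and output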
with gr.Row():
with gr.Column(scale=1, elem_id="col-left"):
gr.HTML("<center><h4>Step 3: Upload a model image. <br> (Optional) Use the draw tool to create the mask. ⬇️</h4></center>")
model_image = gr.ImageEditor(
label="Model Image",
type="pil",
height=450,
width=600,
interactive=True,
brush=gr.Brush(
default_color=f"rgba(255, 255, 255, 0.5)",
colors=["rgb(255, 255, 255)"]
),
eraser=gr.Eraser(),
placeholder="Upload an image\n or\n select the draw tool on the left\n to start editing mask"
)
model_examples = gr.Examples(
examples=people_list_path,
inputs=[model_image],
label="Model Examples",
examples_per_page=12,
)
with gr.Column(scale=1, elem_id="col-mid"):
gr.HTML("<center><h4>Step 4: Upload a garment image. ⬇️ <br><br></h4></center>")
garment_input = gr.Image(
label="Garment Image",
type="pil",
height=450,
width=350,
visible=True,
interactive=True,
)
garment_examples = gr.Examples(
examples=garment_list_path,
inputs=[garment_input],
label="Garment Examples",
examples_per_page=12
)
with gr.Column(scale=1, elem_id="col-right"):
gr.HTML("<center><h4>Step 5: Click the button below to run the model! ⬇️ <br><br></h4></center>")
output_image = gr.Image(
format="png",
label="Output Image",
type="pil",
height=450,
width=550,
interactive=False,
)
submit_btn = gr.Button(
value="Run Try-On",
elem_id="tryon-color"
)
seed_input = gr.Slider(
label="Seed",
value=1234,
minimum=0,
                maximum=2**16 - 1,  # slider capped at 2**16 - 1 rather than the full 2**32 - 1 seed range
step=1,
interactive=True,
elem_id="seed-input",
)
gr.HTML("""
<div style="margin-top: 15px; padding: 10px; background-color: #f8f9fa; border-radius: 8px; border-left: 4px solid #ffc107;">
<p style="margin: 0; font-size: 16px; color: #856404;">
<strong>⚠️ Note:</strong> Errors may occur due to high concurrent requests or NSFW content detection. Please try again if needed.
</p>
</div>
""")
gr.Markdown("---")
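    # Clickable end-to-end examples that also switch the UI into the matching mode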
with gr.Row():
tryon_examples = gr.Examples(
examples=[
["Upper", "images/examples/tryon/persons/1.jpg", "images/examples/tryon/garments/1.jpg", "images/examples/tryon/outputs/1.webp"],
["Lower", "images/examples/tryon/persons/2.jpg", "images/examples/tryon/garments/2.jpg", "images/examples/tryon/outputs/2.webp"],
["Full", "images/examples/tryon/persons/3.jpg", "images/examples/tryon/garments/3.jpg", "images/examples/tryon/outputs/3.webp"],
],
inputs=[garment_condition, model_image, garment_input, output_image],
fn=set_tryon,
outputs=[garment_input, input_toggle, submit_btn],
label="Try-on Examples",
run_on_click=True
)
tryoff_examples = gr.Examples(
examples=[
["Upper", "images/examples/tryoff/persons/1.jpg", "images/examples/tryoff/outputs/1.webp"],
["Lower", "images/examples/tryoff/persons/2.jpg", "images/examples/tryoff/outputs/2.webp"],
["Full", "images/examples/tryoff/persons/3.jpg", "images/examples/tryoff/outputs/3.webp"],
],
inputs=[garment_condition, model_image, output_image],
fn=set_tryoff,
outputs=[garment_input, input_toggle, submit_btn],
label="Try-Off Examples",
run_on_click=True
)
gr.Markdown("---")
gr.HTML("""
<div class="footer-container">
<div class="footer-col footer-logo">
</div>
<div class="footer-col footer-main">
<h3>AI Studio Shaping the New Architecture of Fashion Imagery</h3>
<p>We’re a team of researchers from <b>Stanford</b>, <b>NYU</b>, <b>Seoul National University</b>, and <b>KAIST</b>. At <b>NXN Labs</b>, we’re developing an <b>image-to-image virtual try-on/try-off diffusion model</b>, designed to push the boundaries of digital production in the fashion industry.
This demo is <b>not the full version</b> of our model - it is based on our recent research work, <a href="https://arxiv.org/abs/2508.04825">Voost</a> - but it reflects the underlying research direction.
We’re headquartered in <b>San Francisco</b> and <b>Seoul</b>. If you’re a <b>brand or retailer</b> interested in using our full model API, please sign up at <a href="https://nxn.ai" target="_blank">https://nxn.ai</a> with your business name, and we’ll get back to you within 1–2 business days.
For part-time or full-time research roles, contact <a href="mailto:[email protected]">[email protected]</a>.
</p>
<p>©2025 NXN Labs ——— Copyright.</p>
</div>
<div class="footer-col footer-credits">
<h3>Special Thanks to NXN Labs Summer Interns:</h3>
<p>
<a href="https://www.linkedin.com/in/james-fu-74a16524b/" target="_blank">James Fu</a>,
<a href="https://www.linkedin.com/in/wing-lai-7a8987271/" target="_blank">Wing Lai</a>,
<a href="https://www.linkedin.com/in/stephen-park-53640332b/" target="_blank">Stephen Park</a>
<br><small>for their valuable contributions to this demo space</small>
</p>
</div>
</div>
""")
# Connect toggle to control garment input visibility
input_toggle.change(
fn=handle_toggle,
inputs=[input_toggle],
outputs=[garment_input, input_toggle, submit_btn],
api_name=False
)
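    # Run try-on or try-off depending on the toggle state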
submit_btn.click(
fn=api_helper,
inputs=[model_image, garment_input, garment_condition, input_toggle, seed_input],
outputs=[output_image],
concurrency_limit=7,
api_name=False
)
if __name__ == "__main__":
demo.launch(allowed_paths=["/gradio_api/images/examples"], share=True)