Spaces:
Running
Running
from io import BytesIO | |
import gradio as gr | |
from PIL import Image | |
import httpx | |
from gradio_toggle import Toggle | |
from pathlib import Path | |
import numpy as np | |
import os | |
# Model API location: base URL plus the try-on and try-off endpoint paths.
# All three environment variables are required; the first missing one raises
# KeyError when the module is imported.
api_server, tryon_endpoint, tryoff_endpoint = (
    os.environ[variable]
    for variable in ("NXN_API_SERVER", "NXN_TRYON_ENDPOINT", "NXN_TRYOFF_ENDPOINT")
)

# Accepted pixel range for each side of an input image (see resize_image_if_needed).
MAX_DIM = 2048
MIN_DIM = 500
def encode_bytes(image: Image.Image, format="PNG"):
    """Encode *image* into an in-memory buffer that is ready to be read.

    Args:
        image: Image to serialize; its ``save`` method is called with the buffer.
        format: Format name forwarded to ``image.save`` (defaults to ``"PNG"``).

    Returns:
        A ``BytesIO`` holding the encoded bytes, rewound to offset 0 so it can be
        uploaded directly.
    """
    stream = BytesIO()
    image.save(stream, format=format)
    # Rewind so consumers read from the beginning instead of the end.
    stream.seek(0)
    return stream
# Map the UI garment-type label to the integer code sent to the model API.
def garment_type_to_int(garment_type: str):
    """Convert a garment-type label from the UI into the API's integer code.

    Args:
        garment_type: One of ``"Upper"``, ``"Lower"`` or ``"Full"``.

    Returns:
        ``0`` for ``"Upper"``, ``1`` for ``"Lower"``, ``2`` for ``"Full"``.

    Raises:
        gr.Error: if *garment_type* is not one of the known labels.
    """
    garment_dict = {"Upper": 0, "Lower": 1, "Full": 2}
    # Use .get(): indexing would raise a bare KeyError for an unknown label,
    # which made the user-facing gr.Error below unreachable.
    code = garment_dict.get(garment_type)
    if code is None:
        raise gr.Error("Unexpected garment condition error")
    return code
def extract_image_from_input(image_data):
    """Return an RGB copy of the image carried by a component value.

    When *image_data* is a dict with a ``"background"`` entry (the shape the
    Gradio ImageEditor can return), that entry is used. Otherwise *image_data*
    is assumed to be an image object itself. Either way the result of
    ``.convert("RGB")`` is returned.
    """
    source = image_data
    if isinstance(image_data, dict) and "background" in image_data:
        source = image_data["background"]
    return source.convert("RGB")
def resize_image_if_needed(image: Image.Image):
    """Bring an image's dimensions into the [MIN_DIM, MAX_DIM] pixel range.

    Images with any side above MAX_DIM are scaled down so they fit. Images with
    any side below MIN_DIM are scaled up so the short side reaches MIN_DIM. The
    upscale factor is capped so no side exceeds MAX_DIM. The aspect ratio is
    always preserved, so an extremely elongated image can still end up with a
    side below MIN_DIM. A ``gr.Warning`` is shown whenever the image is resized.

    Args:
        image: PIL image, or ``None``.

    Returns:
        ``(image, resized)``. *resized* is True only when a new image with
        different dimensions was produced. ``(None, False)`` if *image* is None.
    """
    if image is None:
        return None, False
    width, height = image.size
    if width > MAX_DIM or height > MAX_DIM:
        message = "A provided image is too large and has been resized"
        scale = min(MAX_DIM / width, MAX_DIM / height)
    elif width < MIN_DIM or height < MIN_DIM:
        message = "A provided image is too small and has been resized"
        # Grow until the short side reaches MIN_DIM, but cap the factor so the
        # long side of an elongated image (e.g. 400x1800 -> 500x2250) cannot
        # be pushed past MAX_DIM.
        scale = min(
            max(MIN_DIM / width, MIN_DIM / height),
            MAX_DIM / width,
            MAX_DIM / height,
        )
    else:
        return image, False
    # round() rather than int(): floating-point error can put a product just
    # below its integer target and int() would truncate it by a pixel.
    # max(1, ...) keeps very thin images from collapsing to a zero-sized side.
    new_size = (max(1, round(width * scale)), max(1, round(height * scale)))
    if new_size == (width, height):
        # The MAX_DIM cap left no room to grow; avoid a pointless resize.
        return image, False
    gr.Warning(message)
    return image.resize(new_size, Image.Resampling.LANCZOS), True
# API Helpers
async def _call_api(url: str, files: list, data: dict):
    """POST a multipart request to the model API and decode the reply as an image.

    Args:
        url: Full endpoint URL.
        files: Multipart parts as ``(field_name, (filename, fileobj, content_type))``
            tuples, passed unchanged to ``httpx``.
        data: Extra form fields sent alongside the files.

    Returns:
        The response body opened as a PIL image.

    Raises:
        gr.Error: with a network-specific message when the request cannot be
            sent, or with a generic message for any other failure (HTTP error
            status, undecodable response body, ...).
    """
    try:
        # Timeout is in seconds. It is presumably this generous because model
        # inference can be slow.
        async with httpx.AsyncClient(timeout=3600) as client:
            response = await client.post(url, data=data, files=files)
            response.raise_for_status()
            return Image.open(BytesIO(response.content))
    except httpx.RequestError as e:
        print(f"API request failed: {e}")
        raise gr.Error("Network error: Could not connect to the model API. Please try again later.") from e
    except Exception as e:
        print(f"An unexpected error occurred: {e}")
        # Surface a user-facing gr.Error. A bare `raise` here previously
        # re-raised the raw exception and left this message unreachable.
        raise gr.Error("An unexpected error occurred. The model may have failed to process the images.") from e
async def call_tryon_api(model_image: Image.Image, garment_image: Image.Image, garment_type: int, mask: Image.Image=None, seed: int=1234):
    """Send a try-on request and return the image the API responds with.

    The images are uploaded as repeated ``images`` multipart parts, in the order
    target (model), garment, then the optional mask.

    Args:
        model_image: Image of the person.
        garment_image: Image of the garment to put on.
        garment_type: Integer garment code (see ``garment_type_to_int``).
        mask: Optional mask image. It is sent only when provided.
        seed: Seed forwarded to the API as a form field.
    """
    files = [
        ("images", ("target.png", encode_bytes(model_image), "image/png")),
        ("images", ("garment.png", encode_bytes(garment_image), "image/png")),
    ]
    # An explicit None check states "was a mask supplied?" without relying on
    # how the image type defines truthiness.
    if mask is not None:
        files.append(("images", ("mask.png", encode_bytes(mask), "image/png")))
    data = {'garment_type': garment_type, 'seed': seed}
    return await _call_api(f"{api_server}/{tryon_endpoint}", files=files, data=data)
async def call_tryoff_api(model_image: Image.Image, garment_type: int, seed: int=1234):
    """Send a try-off request and return the image the API responds with.

    Args:
        model_image: Image of the person, uploaded as the single ``images`` part.
        garment_type: Integer garment code (see ``garment_type_to_int``).
        seed: Seed forwarded to the API as a form field.
    """
    target_part = ("target.png", encode_bytes(model_image), "image/png")
    form_fields = {'garment_type': garment_type, 'seed': seed}
    endpoint_url = f"{api_server}/{tryoff_endpoint}"
    return await _call_api(endpoint_url, files=[("images", target_part)], data=form_fields)
def _mask_from_editor(model_image_dict, target_size):
    """Build a binary ("L" mode) mask from the first ImageEditor layer, or return None.

    A pixel counts as painted when any of its channels is >= 10. Returns None
    when there is no usable layer or nothing was painted. The mask is resized
    to *target_size* when the sizes differ.
    """
    if not isinstance(model_image_dict, dict) or not model_image_dict.get("layers"):
        return None
    layer = model_image_dict["layers"][0]
    if layer is None:
        return None
    painted = np.array(layer) >= 10
    if painted.ndim == 3:
        # Multi-channel layer: a pixel is painted if any channel is painted.
        painted = painted.any(axis=2)
    if not painted.any():
        return None
    mask = Image.fromarray((painted * 255).astype(np.uint8))
    if mask.size != target_size:
        # Keep the mask aligned with the (possibly resized) model image.
        # NEAREST keeps it strictly black/white.
        mask = mask.resize(target_size, Image.Resampling.NEAREST)
    return mask


async def api_helper(model_image_dict: dict, garment_image: Image.Image, garment_type: str, is_tryoff: bool, seed: int):
    """Validate the UI inputs and run a try-on or try-off request.

    Args:
        model_image_dict: Value of the model ImageEditor. This is a dict with a
            ``"background"`` image and optional drawn ``"layers"``. A bare image
            is also accepted.
        garment_image: Garment image. Required for try-on, unused for try-off.
        garment_type: ``"Upper"``, ``"Lower"`` or ``"Full"``.
        is_tryoff: True to run try-off, False to run try-on.
        seed: Seed forwarded to the API.

    Returns:
        The image returned by the model API.

    Raises:
        gr.Error: if a required image is missing. Errors from garment-type
            validation and from the API call propagate from those helpers.
    """
    # .get() so a dict without "background" is reported as a missing image
    # rather than crashing with KeyError.
    if model_image_dict is None or (
        isinstance(model_image_dict, dict) and model_image_dict.get("background") is None
    ):
        raise gr.Error("Missing model image")
    if not is_tryoff and garment_image is None:
        raise gr.Error("Missing garment image for Try-On")

    # Because Gradio ImageEditor can return a dict
    model_image = extract_image_from_input(model_image_dict)
    model_image, _ = resize_image_if_needed(model_image)
    garment_type_int = garment_type_to_int(garment_type)

    if is_tryoff:
        return await call_tryoff_api(model_image, garment_type_int, seed)

    # The garment is only sent for try-on, so resize (and warn about) it here.
    garment_image, _ = resize_image_if_needed(garment_image)
    mask_image = _mask_from_editor(model_image_dict, model_image.size)
    if mask_image is None:
        # Shown whenever no mask is sent, including when the editor has no layers.
        gr.Info("No mask provided, using auto-generated mask")
    return await call_tryon_api(model_image, garment_image, garment_type_int, mask=mask_image, seed=seed)
# Event handler functions
def handle_toggle(toggle_value):
    """Sync the UI with the Try-On / Try-Off toggle.

    Args:
        toggle_value: Truthy for Try-Off mode, falsy for Try-On mode.

    Returns:
        A 3-tuple of ``gr.update`` values for (garment input, toggle, submit
        button).
        - Try-Off: the garment input is cleared, given the ``disabled-image``
          CSS class and made non-interactive.
        - Try-On: the garment input is made interactive again and keeps its
          current image.
        The toggle label and the button text/elem_id follow the mode.
    """
    if not toggle_value:
        return (
            gr.update(elem_classes=[], interactive=True),
            gr.update(value=toggle_value, label="Try-On"),
            gr.update(value="Run Try-On", elem_id="tryon-color"),
        )
    return (
        gr.update(value=None, elem_classes=["disabled-image"], interactive=False),
        gr.update(value=toggle_value, label="Try-Off"),
        gr.update(value="Run Try-Off", elem_id="tryoff-color"),
    )
def set_tryon(garment_img, model_img, output_img, garment_condition):
    """Put the UI into Try-On mode when a try-on example row is clicked.

    Returns the three updates produced by ``handle_toggle(False)``, for
    (garment input, toggle, submit button).

    The arguments are the example row's values and are not used. NOTE(review):
    the example wiring lists its inputs as (condition, model, garment, output).
    If those values are passed positionally, the parameter names here do not
    describe what each one receives.
    """
    return handle_toggle(False)
def set_tryoff(model_img, output_img, garment_condition):
    """Put the UI into Try-Off mode when a try-off example row is clicked.

    Returns the three updates produced by ``handle_toggle(True)``, for
    (garment input, toggle, submit button).

    The arguments are the example row's values and are not used. NOTE(review):
    the example wiring lists its inputs as (condition, model, output). If those
    values are passed positionally, the parameter names here do not describe
    what each one receives.
    """
    return handle_toggle(True)
def garment_sort_key(filename):
    """Sort key that groups garment files by name prefix.

    Files starting with ``upper_`` come first, then ``lower_``, then ``full_``,
    then everything else. Ties within a group fall back to the filename itself.

    Returns:
        ``(group_index, filename)``.
    """
    for group, prefix in enumerate(("upper_", "lower_", "full_")):
        if filename.startswith(prefix):
            return group, filename
    return 3, filename
# Get images for examples
images_path = os.path.join(os.path.dirname(__file__), 'images')


def _visible_files(directory):
    """Return the names of the non-hidden regular files in *directory* (unsorted).

    Hidden entries (e.g. ``.DS_Store``, ``.gitkeep``) and sub-directories are
    skipped so they are never handed to ``gr.Examples`` as example images.
    """
    return [
        name
        for name in os.listdir(directory)
        if not name.startswith(".") and os.path.isfile(os.path.join(directory, name))
    ]


garment_list = _visible_files(os.path.join(images_path, "garments"))
garment_list_path = [
    os.path.join(images_path, "garments", cloth)
    for cloth in sorted(garment_list, key=garment_sort_key)
]
people_list = _visible_files(os.path.join(images_path, "persons"))
people_list_path = [os.path.join(images_path, "persons", human) for human in sorted(people_list)]
# NOTE(review): this static path is built from the current working directory,
# while images_path above is built from this file's directory. The two only
# coincide when the app is launched from its own directory.
gr.set_static_paths(paths=[Path.cwd().absolute() / "images"])
# Create the Gradio interface | |
# Build the whole page. Components render in the order they are created inside
# these Row/Column context managers, so statement order here is the layout.
with gr.Blocks(css_paths="styles.css", theme=gr.themes.Ocean(), title="Voost: Virtual Try-On/Off") as demo:
    # Header: logo (with a dark-mode variant), title, paper/project/code/license
    # badges and a contact line.
    # NOTE(review): the License badge URL passes logo=Lisence, which looks like a
    # typo; confirm whether a logo was intended.
    # NOTE(review): the mailto link reads "[email protected]", which is not a
    # deliverable address (looks like an email-obfuscation artifact). Restore
    # the real address.
    with gr.Row():
        gr.HTML("""
        <div class="header-container">
            <div class="logo-container">
                <a href="https://nxn.ai/">
                    <picture>
                        <source media="(prefers-color-scheme: dark)" srcset="/gradio_api/file=images/dark_mode_logo.png"/>
                        <img src='/gradio_api/file=images/nxn_logo_transparent.png' style="height: 120px; width: 150px;"/>
                    </picture>
                </a>
            </div>
            <div style="display: flex; flex-direction: column; align-items: center; text-align: center;">
                <div style="font-size: 45px; margin-bottom: 10px;">
                    <b>Voost: Virtual Try-On/Off</b>
                </div>
                <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
                    <a href="https://arxiv.org/abs/2508.04825">
                        <img src='https://img.shields.io/badge/arXiv-2508.04825-red?style=flat&logo=arXiv&logoColor=red' alt='arxiv'>
                    </a>  
                    <a href='https://nxnai.github.io/Voost/'>
                        <img src='https://img.shields.io/badge/Webpage-Project-silver?style=flat&logo=&logoColor=orange' alt='webpage'>
                    </a>  
                    <a href="https://github.com/nxnai/Voost">
                        <img src='https://img.shields.io/badge/GitHub-Repo-blue?style=flat&logo=GitHub' alt='GitHub'>
                    </a>  
                    <a href="https://github.com/nxnai/Voost/blob/main/LICENSE">
                        <img src='https://img.shields.io/badge/License-CC BY--NC--SA--4.0-lightgreen?style=flat&logo=Lisence' alt='License'>
                    </a>
                </div>
                <div style="font-size: 14px; color: #666; margin-top: 5px;">
                    Website: <a href="https://nxn.ai" target="_blank">https://nxn.ai</a> Inquiries: <a href="mailto:[email protected]">[email protected]</a>
                </div>
            </div>
        </div>
        """)
    gr.Markdown("---")
    # Steps 1-2: mode toggle and garment type selector.
    with gr.Row():
        with gr.Column(scale=1):
            gr.HTML("<center><h4>Step 1: Select <em>Try-On</em> or <em>Try-Off</em> mode. </h4></center>")
            # False = Try-On, True = Try-Off. handle_toggle relabels this
            # component and the submit button when the value changes.
            input_toggle = Toggle(
                label="Try-On",
                value=False,
                interactive=True,
                elem_classes=["button-container"],
                color="rgba(177, 162, 239, .5)",
                elem_id="toggle-modify"
            )
        with gr.Column(scale=1):
            gr.HTML("<center><h4>Step 2: Select your desired garment type.</h4></center>")
            # The label string is converted to the API's integer code by
            # garment_type_to_int when the request is sent.
            garment_condition = gr.Radio(
                choices=["Upper", "Lower", "Full"],
                value="Upper",
                interactive=True,
                elem_classes=["center-item"],
                show_label=False,
                label="Garment Type"
            )
    # Steps 3-5: model image (+ optional mask), garment image, output and run button.
    with gr.Row():
        with gr.Column(scale=1, elem_id="col-left"):
            gr.HTML("<center><h4>Step 3: Upload a model image. <br> (Optional) Use the draw tool to create the mask. ⬇️</h4></center>")
            # The first drawn layer is turned into the optional try-on mask in
            # api_helper. The brush is limited to semi-transparent white.
            model_image = gr.ImageEditor(
                label="Model Image",
                type="pil",
                height=450,
                width=600,
                interactive=True,
                brush=gr.Brush(
                    default_color=f"rgba(255, 255, 255, 0.5)",
                    colors=["rgb(255, 255, 255)"]
                ),
                eraser=gr.Eraser(),
                placeholder="Upload an image\n or\n select the draw tool on the left\n to start editing mask"
            )
            model_examples = gr.Examples(
                examples=people_list_path,
                inputs=[model_image],
                label="Model Examples",
                examples_per_page=12,
            )
        with gr.Column(scale=1, elem_id="col-mid"):
            gr.HTML("<center><h4>Step 4: Upload a garment image. ⬇️ <br><br></h4></center>")
            # Cleared and made non-interactive by handle_toggle in Try-Off mode.
            garment_input = gr.Image(
                label="Garment Image",
                type="pil",
                height=450,
                width=350,
                visible=True,
                interactive=True,
            )
            garment_examples = gr.Examples(
                examples=garment_list_path,
                inputs=[garment_input],
                label="Garment Examples",
                examples_per_page=12
            )
        with gr.Column(scale=1, elem_id="col-right"):
            gr.HTML("<center><h4>Step 5: Click the button below to run the model! ⬇️ <br><br></h4></center>")
            output_image = gr.Image(
                format="png",
                label="Output Image",
                type="pil",
                height=450,
                width=550,
                interactive=False,
            )
            # Text and elem_id are swapped between the Try-On/Try-Off variants
            # by handle_toggle.
            submit_btn = gr.Button(
                value="Run Try-On",
                elem_id="tryon-color"
            )
            seed_input = gr.Slider(
                label="Seed",
                value=1234,
                minimum=0,
                maximum=2**16 - 1,  # NOTE(review): a 2**32 - 1 bound was left commented out here; confirm the range the API accepts
                step=1,
                interactive=True,
                elem_id="seed-input",
            )
            gr.HTML("""
            <div style="margin-top: 15px; padding: 10px; background-color: #f8f9fa; border-radius: 8px; border-left: 4px solid #ffc107;">
                <p style="margin: 0; font-size: 16px; color: #856404;">
                    <strong>⚠️ Note:</strong> Errors may occur due to high concurrent requests or NSFW content detection. Please try again if needed.
                </p>
            </div>
            """)
    gr.Markdown("---")
    # Full example rows. With run_on_click=True, clicking a row fills the
    # listed inputs (including a precomputed output image) and runs
    # set_tryon/set_tryoff to switch the UI into the matching mode.
    with gr.Row():
        tryon_examples = gr.Examples(
            examples=[
                ["Upper", "images/examples/tryon/persons/1.jpg", "images/examples/tryon/garments/1.jpg", "images/examples/tryon/outputs/1.webp"],
                ["Lower", "images/examples/tryon/persons/2.jpg", "images/examples/tryon/garments/2.jpg", "images/examples/tryon/outputs/2.webp"],
                ["Full", "images/examples/tryon/persons/3.jpg", "images/examples/tryon/garments/3.jpg", "images/examples/tryon/outputs/3.webp"],
            ],
            inputs=[garment_condition, model_image, garment_input, output_image],
            fn=set_tryon,
            outputs=[garment_input, input_toggle, submit_btn],
            label="Try-on Examples",
            run_on_click=True
        )
        tryoff_examples = gr.Examples(
            examples=[
                ["Upper", "images/examples/tryoff/persons/1.jpg", "images/examples/tryoff/outputs/1.webp"],
                ["Lower", "images/examples/tryoff/persons/2.jpg", "images/examples/tryoff/outputs/2.webp"],
                ["Full", "images/examples/tryoff/persons/3.jpg", "images/examples/tryoff/outputs/3.webp"],
            ],
            inputs=[garment_condition, model_image, output_image],
            fn=set_tryoff,
            outputs=[garment_input, input_toggle, submit_btn],
            label="Try-Off Examples",
            run_on_click=True
        )
    gr.Markdown("---")
    # Footer: company blurb and credits.
    # NOTE(review): the careers mailto below also reads "[email protected]";
    # restore the real address.
    gr.HTML("""
    <div class="footer-container">
        <div class="footer-col footer-logo">
        </div>
        <div class="footer-col footer-main">
            <h3>AI Studio Shaping the New Architecture of Fashion Imagery</h3>
            <p>We’re a team of researchers from <b>Stanford</b>, <b>NYU</b>, <b>Seoul National University</b>, and <b>KAIST</b>. At <b>NXN Labs</b>, we’re developing an <b>image-to-image virtual try-on/try-off diffusion model</b>, designed to push the boundaries of digital production in the fashion industry.
            This demo is <b>not the full version</b> of our model - it is based on our recent research work, <a href="https://arxiv.org/abs/2508.04825">Voost</a> - but it reflects the underlying research direction.
            We’re headquartered in <b>San Francisco</b> and <b>Seoul</b>. If you’re a <b>brand or retailer</b> interested in using our full model API, please sign up at <a href="https://nxn.ai" target="_blank">https://nxn.ai</a> with your business name, and we’ll get back to you within 1–2 business days.
            For part-time or full-time research roles, contact <a href="mailto:[email protected]">[email protected]</a>.
            </p>
            <p>©2025 NXN Labs ——— Copyright.</p>
        </div>
        <div class="footer-col footer-credits">
            <h3>Special Thanks to NXN Labs Summer Interns:</h3>
            <p>
                <a href="https://www.linkedin.com/in/james-fu-74a16524b/" target="_blank">James Fu</a>,
                <a href="https://www.linkedin.com/in/wing-lai-7a8987271/" target="_blank">Wing Lai</a>,
                <a href="https://www.linkedin.com/in/stephen-park-53640332b/" target="_blank">Stephen Park</a>
                <br><small>for their valuable contributions to this demo space</small>
            </p>
        </div>
    </div>
    """)
    # Connect toggle to control garment input visibility
    # (handle_toggle actually clears/disables the garment input and relabels
    # the toggle and the submit button; the input stays visible).
    input_toggle.change(
        fn=handle_toggle,
        inputs=[input_toggle],
        outputs=[garment_input, input_toggle, submit_btn],
        api_name=False
    )
    # Run the model. api_helper picks try-on or try-off from the toggle value.
    # concurrency_limit caps simultaneous runs of this event, and api_name=False
    # keeps it from being exposed as a named API endpoint.
    submit_btn.click(
        fn=api_helper,
        inputs=[model_image, garment_input, garment_condition, input_toggle, seed_input],
        outputs=[output_image],
        concurrency_limit=7,
        api_name=False
    )
if __name__ == "__main__":
    # allowed_paths takes filesystem paths. The previous value,
    # "/gradio_api/images/examples", is the app's URL route prefix rather than
    # a directory on disk, so it granted nothing. Point it at the real examples
    # folder, built the same way as the static path registered above
    # (cwd / "images").
    demo.launch(allowed_paths=[str(Path.cwd().absolute() / "images" / "examples")], share=True)