"""Gradio front-end for the Voost virtual try-on / try-off demo.

The UI collects a model (person) image, an optional hand-drawn mask, an
optional garment image, a garment type and a seed, then forwards them to a
remote HTTP API and displays the image the API returns.
"""
import os
from io import BytesIO
from pathlib import Path
from typing import Optional

import gradio as gr
import httpx
import numpy as np
from gradio_toggle import Toggle
from PIL import Image

# Backend location. A missing variable raises KeyError at import time.
api_server = os.environ["NXN_API_SERVER"]
tryon_endpoint = os.environ["NXN_TRYON_ENDPOINT"]
tryoff_endpoint = os.environ["NXN_TRYOFF_ENDPOINT"]

# Images whose sides fall outside [MIN_DIM, MAX_DIM] are rescaled before upload.
MAX_DIM = 2048
MIN_DIM = 500

# Radio-button label -> integer code sent to the API as `garment_type`.
_GARMENT_TYPE_IDS = {"Upper": 0, "Lower": 1, "Full": 2}

# Filename prefixes, in the order garment examples should be listed.
_GARMENT_PREFIX_ORDER = ("upper_", "lower_", "full_")


def encode_bytes(image: Image.Image, format="PNG"):
    """Serialize *image* into an in-memory buffer rewound to position 0."""
    buffered = BytesIO()
    image.save(buffered, format=format)
    buffered.seek(0)
    return buffered


def garment_type_to_int(garment_type: str):
    """Translate a garment label ("Upper", "Lower", "Full") to its int code.

    Raises:
        gr.Error: if *garment_type* is not one of the known labels.
    """
    # `.get` so an unknown label reaches the gr.Error below instead of
    # escaping as a bare KeyError.
    garment_id = _GARMENT_TYPE_IDS.get(garment_type)
    if garment_id is None:
        raise gr.Error("Unexpected garment condition error")
    return garment_id


def extract_image_from_input(image_data):
    """Return the RGB image from an editor value.

    Accepts either a dict carrying a "background" image or an image object
    directly.
    """
    if isinstance(image_data, dict) and "background" in image_data:
        return image_data["background"].convert("RGB")
    return image_data.convert("RGB")


def resize_image_if_needed(image: Image.Image):
    """Rescale *image* so its sides respect MAX_DIM / MIN_DIM.

    Returns:
        A ``(image, resized)`` tuple. ``image`` is ``None`` when the input is
        ``None``; ``resized`` tells whether a new image was produced.

    A Gradio warning is shown to the user whenever a resize happens. The
    aspect ratio is preserved; the "too large" check takes precedence over
    the "too small" one.
    """
    if image is None:
        return None, False

    width, height = image.size
    if width > MAX_DIM or height > MAX_DIM:
        gr.Warning("A provided image is too large and has been resized")
        scale_factor = min(MAX_DIM / width, MAX_DIM / height)
    elif width < MIN_DIM or height < MIN_DIM:
        gr.Warning("A provided image is too small and has been resized")
        scale_factor = max(MIN_DIM / width, MIN_DIM / height)
    else:
        return image, False

    # Clamp to 1px: truncation could otherwise yield a zero-sized side for
    # extreme aspect ratios, which PIL cannot resize to.
    new_size = (
        max(1, int(width * scale_factor)),
        max(1, int(height * scale_factor)),
    )
    return image.resize(new_size, Image.Resampling.LANCZOS), True


# API Helpers
async def _call_api(url: str, files: list, data: dict):
    """POST a multipart request to *url* and decode the response as an image.

    Args:
        url: Full endpoint URL.
        files: Multipart file entries, as built by the ``call_*_api`` helpers.
        data: Form fields sent alongside the files.

    Returns:
        The PIL image opened from the response body.

    Raises:
        gr.Error: on any failure, with a user-facing message. The original
            exception is chained as ``__cause__``.
    """
    try:
        async with httpx.AsyncClient(timeout=3600) as client:
            response = await client.post(url, data=data, files=files)
            response.raise_for_status()
            return Image.open(BytesIO(response.content))
    except httpx.RequestError as e:
        print(f"API request failed: {e}")
        raise gr.Error(
            "Network error: Could not connect to the model API. Please try again later."
        ) from e
    except Exception as e:
        # Also reached for non-2xx responses raised by raise_for_status()
        # and for bodies PIL cannot decode.
        print(f"An unexpected error occurred: {e}")
        raise gr.Error(
            "An unexpected error occurred. The model may have failed to process the images."
        ) from e


async def call_tryon_api(
    model_image: Image.Image,
    garment_image: Image.Image,
    garment_type: int,
    mask: Optional[Image.Image] = None,
    seed: int = 1234,
):
    """Send model + garment (+ optional mask) images to the try-on endpoint."""
    files = [
        ("images", ("target.png", encode_bytes(model_image), "image/png")),
        ("images", ("garment.png", encode_bytes(garment_image), "image/png")),
    ]
    # Explicit None check rather than relying on image truthiness.
    if mask is not None:
        files.append(("images", ("mask.png", encode_bytes(mask, format="PNG"), "image/png")))
    data = {'garment_type': garment_type, 'seed': seed}
    return await _call_api(f"{api_server}/{tryon_endpoint}", files=files, data=data)


async def call_tryoff_api(model_image: Image.Image, garment_type: int, seed: int = 1234):
    """Send the model image to the try-off endpoint."""
    files = [
        ("images", ("target.png", encode_bytes(model_image), "image/png")),
    ]
    data = {'garment_type': garment_type, 'seed': seed}
    return await _call_api(f"{api_server}/{tryoff_endpoint}", files=files, data=data)


def _build_mask(model_image_dict, target_size, model_resized: bool):
    """Derive a binary mask from the first drawn editor layer, if any.

    Pixels whose every channel is below 10 count as unpainted; everything
    else becomes 255 in the returned single-channel image. Returns ``None``
    when there is no layer or the layer is entirely unpainted.
    """
    if not (isinstance(model_image_dict, dict) and model_image_dict.get("layers")):
        return None

    mask_array = np.array(model_image_dict["layers"][0])
    if np.all(mask_array < 10):
        return None

    # Collapse the channel axis when there is one; a 2-D layer is used as is.
    if mask_array.ndim == 3:
        is_black = np.all(mask_array < 10, axis=2)
    else:
        is_black = mask_array < 10
    mask_image = Image.fromarray(((~is_black) * 255).astype(np.uint8))

    # Keep the mask aligned with the model image after the latter was rescaled.
    if model_resized:
        mask_image = mask_image.resize(target_size, Image.Resampling.NEAREST)
    return mask_image


async def api_helper(model_image_dict: dict, garment_image: Image.Image, garment_type: str, is_tryoff: bool, seed: int):
    """Validate the UI inputs and dispatch to the try-on or try-off API.

    Args:
        model_image_dict: Value of the model image editor (a dict with
            "background" and optionally "layers"); a plain image is tolerated.
        garment_image: Garment image; required only for try-on.
        garment_type: One of "Upper", "Lower", "Full".
        is_tryoff: ``True`` to run try-off, ``False`` to run try-on.
        seed: Seed forwarded to the API.

    Returns:
        The image returned by the API.

    Raises:
        gr.Error: on missing inputs, unknown garment type or API failure.
    """
    if isinstance(model_image_dict, dict):
        background = model_image_dict.get("background")
    else:
        background = model_image_dict
    if background is None:
        raise gr.Error("Missing model image")
    if not is_tryoff and garment_image is None:
        raise gr.Error("Missing garment image for Try-On")

    # Because Gradio ImageEditor can return a dict
    model_image = extract_image_from_input(model_image_dict)
    model_image, model_resized = resize_image_if_needed(model_image)
    garment_type_int = garment_type_to_int(garment_type)

    if is_tryoff:
        return await call_tryoff_api(model_image, garment_type_int, seed)

    # The garment is only used for try-on, so only resize (and warn) here.
    garment_image, _ = resize_image_if_needed(garment_image)

    mask_image = _build_mask(model_image_dict, model_image.size, model_resized)
    if mask_image is None:
        gr.Info("No mask provided, using auto-generated mask")
    return await call_tryon_api(model_image, garment_image, garment_type_int, mask=mask_image, seed=seed)


# Event handler functions
def handle_toggle(toggle_value):
    """Handle toggle state changes - controls garment input availability.

    Returns:
        Updates for ``(garment_input, input_toggle, submit_btn)``.
    """
    if toggle_value:
        toggle_label = gr.update(value=toggle_value, label="Try-Off")
        submit_btn_label = gr.update(value="Run Try-Off", elem_id="tryoff-color")
        # Clear the image and disable the component
        garment_update = gr.update(value=None, elem_classes=["disabled-image"], interactive=False)
    else:
        toggle_label = gr.update(value=toggle_value, label="Try-On")
        submit_btn_label = gr.update(value="Run Try-On", elem_id="tryon-color")
        # Re-enable the component without clearing the image
        garment_update = gr.update(elem_classes=[], interactive=True)
    return garment_update, toggle_label, submit_btn_label


def set_tryon(garment_img, model_img, output_img, garment_condition):
    """Switch the UI to Try-On mode when a try-on example is clicked.

    The arguments are the example's input values and are not used; Gradio
    passes them positionally in the order of the example's ``inputs`` list.
    """
    return handle_toggle(False)


def set_tryoff(model_img, output_img, garment_condition):
    """Switch the UI to Try-Off mode when a try-off example is clicked.

    The arguments are the example's input values and are not used; Gradio
    passes them positionally in the order of the example's ``inputs`` list.
    """
    return handle_toggle(True)


def garment_sort_key(filename):
    """Sort key grouping garment files as upper_, lower_, full_, then the rest."""
    for rank, prefix in enumerate(_GARMENT_PREFIX_ORDER):
        if filename.startswith(prefix):
            return (rank, filename)
    return (len(_GARMENT_PREFIX_ORDER), filename)


# Get images for examples
images_path = os.path.join(os.path.dirname(__file__), 'images')

garment_list = os.listdir(os.path.join(images_path, "garments"))
garment_list_path = [
    os.path.join(images_path, "garments", cloth)
    for cloth in sorted(garment_list, key=garment_sort_key)
]

people_list = os.listdir(os.path.join(images_path, "persons"))
people_list_path = [
    os.path.join(images_path, "persons", human)
    for human in sorted(people_list)
]

gr.set_static_paths(paths=[Path.cwd().absolute() / "images"])

# Create the Gradio interface
# NOTE(review): container nesting below was reconstructed from a source whose
# indentation had been lost - confirm component placement against the
# intended layout.
with gr.Blocks(css_paths="styles.css", theme=gr.themes.Ocean(), title="Voost: Virtual Try-On/Off") as demo:
    with gr.Row():
        gr.HTML(
            "\n"
            "Voost: Virtual Try-On/Off\n"
            "arxiv webpage GitHub License\n"
            "Website: https://nxn.ai       Inquiries: hello@nxn.ai\n"
        )
    gr.Markdown("---")

    with gr.Row():
        with gr.Column(scale=1):
            gr.HTML("\n\nStep 1: Select Try-On or Try-Off mode.\n\n")
            input_toggle = Toggle(
                label="Try-On",
                value=False,
                interactive=True,
                elem_classes=["button-container"],
                color="rgba(177, 162, 239, .5)",
                elem_id="toggle-modify",
            )
        with gr.Column(scale=1):
            gr.HTML("\n\nStep 2: Select your desired garment type.\n\n")
            garment_condition = gr.Radio(
                choices=["Upper", "Lower", "Full"],
                value="Upper",
                interactive=True,
                elem_classes=["center-item"],
                show_label=False,
                label="Garment Type",
            )

    with gr.Row():
        with gr.Column(scale=1, elem_id="col-left"):
            gr.HTML(
                "\n\nStep 3: Upload a model image.\n"
                "(Optional) Use the draw tool to create the mask. ⬇️\n\n"
            )
            model_image = gr.ImageEditor(
                label="Model Image",
                type="pil",
                height=450,
                width=600,
                interactive=True,
                brush=gr.Brush(
                    default_color="rgba(255, 255, 255, 0.5)",
                    colors=["rgb(255, 255, 255)"],
                ),
                eraser=gr.Eraser(),
                placeholder="Upload an image\n or\n select the draw tool on the left\n to start editing mask",
            )
            model_examples = gr.Examples(
                examples=people_list_path,
                inputs=[model_image],
                label="Model Examples",
                examples_per_page=12,
            )

        with gr.Column(scale=1, elem_id="col-mid"):
            gr.HTML("\n\nStep 4: Upload a garment image. ⬇️\n\n")
            garment_input = gr.Image(
                label="Garment Image",
                type="pil",
                height=450,
                width=350,
                visible=True,
                interactive=True,
            )
            garment_examples = gr.Examples(
                examples=garment_list_path,
                inputs=[garment_input],
                label="Garment Examples",
                examples_per_page=12,
            )

        with gr.Column(scale=1, elem_id="col-right"):
            gr.HTML("\n\nStep 5: Click the button below to run the model! ⬇️\n\n")
            output_image = gr.Image(
                format="png",
                label="Output Image",
                type="pil",
                height=450,
                width=550,
                interactive=False,
            )
            submit_btn = gr.Button(
                value="Run Try-On",
                elem_id="tryon-color",
            )
            seed_input = gr.Slider(
                label="Seed",
                value=1234,
                minimum=0,
                maximum=2**16 - 1,  # 2**32 - 1
                step=1,
                interactive=True,
                elem_id="seed-input",
            )

    gr.HTML(
        "\n\n⚠️ Note: Errors may occur due to high concurrent requests or NSFW content detection. Please try again if needed.\n\n"
    )
    gr.Markdown("---")

    with gr.Row():
        # Clicking an example fills the listed inputs, then runs `fn` to put
        # the toggle / garment input / submit button into the matching mode.
        tryon_examples = gr.Examples(
            examples=[
                ["Upper", "images/examples/tryon/persons/1.jpg", "images/examples/tryon/garments/1.jpg", "images/examples/tryon/outputs/1.webp"],
                ["Lower", "images/examples/tryon/persons/2.jpg", "images/examples/tryon/garments/2.jpg", "images/examples/tryon/outputs/2.webp"],
                ["Full", "images/examples/tryon/persons/3.jpg", "images/examples/tryon/garments/3.jpg", "images/examples/tryon/outputs/3.webp"],
            ],
            inputs=[garment_condition, model_image, garment_input, output_image],
            fn=set_tryon,
            outputs=[garment_input, input_toggle, submit_btn],
            label="Try-on Examples",
            run_on_click=True,
        )
        tryoff_examples = gr.Examples(
            examples=[
                ["Upper", "images/examples/tryoff/persons/1.jpg", "images/examples/tryoff/outputs/1.webp"],
                ["Lower", "images/examples/tryoff/persons/2.jpg", "images/examples/tryoff/outputs/2.webp"],
                ["Full", "images/examples/tryoff/persons/3.jpg", "images/examples/tryoff/outputs/3.webp"],
            ],
            inputs=[garment_condition, model_image, output_image],
            fn=set_tryoff,
            outputs=[garment_input, input_toggle, submit_btn],
            label="Try-Off Examples",
            run_on_click=True,
        )

    gr.Markdown("---")
    gr.HTML(" ")

    # Connect toggle to control garment input visibility
    input_toggle.change(
        fn=handle_toggle,
        inputs=[input_toggle],
        outputs=[garment_input, input_toggle, submit_btn],
        api_name=False,
    )

    submit_btn.click(
        fn=api_helper,
        inputs=[model_image, garment_input, garment_condition, input_toggle, seed_input],
        outputs=[output_image],
        concurrency_limit=7,
        api_name=False,
    )

if __name__ == "__main__":
    demo.launch(allowed_paths=["/gradio_api/images/examples"], share=True)