from io import BytesIO import gradio as gr from PIL import Image import httpx from gradio_toggle import Toggle from pathlib import Path import numpy as np import os api_server = os.environ["NXN_API_SERVER"] tryon_endpoint = os.environ["NXN_TRYON_ENDPOINT"] tryoff_endpoint = os.environ["NXN_TRYOFF_ENDPOINT"] MAX_DIM = 2048 MIN_DIM = 500 def encode_bytes(image: Image.Image, format="PNG"): buffered = BytesIO() image.save(buffered, format=format) buffered.seek(0) return buffered # str to int def garment_type_to_int(garment_type: str): garment_dict = {"Upper": 0, "Lower": 1, "Full": 2} if garment_dict[garment_type] is None: raise gr.Error("Unexpected garment condition error") else: return garment_dict[garment_type] def extract_image_from_input(image_data): if isinstance(image_data, dict) and "background" in image_data: return image_data["background"].convert("RGB") else: return image_data.convert("RGB") def resize_image_if_needed(image: Image.Image): if image is None: return None, False original_width, original_height = image.size if original_width > MAX_DIM or original_height > MAX_DIM: gr.Warning("A provided image is too large and has been resized") scale_factor = min(MAX_DIM / original_width, MAX_DIM / original_height) new_width = int(original_width * scale_factor) new_height = int(original_height * scale_factor) return image.resize((new_width, new_height), Image.Resampling.LANCZOS), True elif original_width < MIN_DIM or original_height < MIN_DIM: gr.Warning("A provided image is too small and has been resized") scale_factor = max(MIN_DIM / original_width, MIN_DIM / original_height) new_width = int(original_width * scale_factor) new_height = int(original_height * scale_factor) return image.resize((new_width, new_height), Image.Resampling.LANCZOS), True return image, False # API Helpers async def _call_api(url: str, files: dict, data: dict): try: async with httpx.AsyncClient(timeout=3600) as client: response = await client.post(url, data=data, files=files) response.raise_for_status() return Image.open(BytesIO(response.content)) except httpx.RequestError as e: print(f"API request failed: {e}") raise gr.Error("Network error: Could not connect to the model API. Please try again later.") except Exception as e: print(f"An unexpected error occurred: {e}") raise raise gr.Error("An unexpected error occurred. The model may have failed to process the images.") async def call_tryon_api(model_image: Image.Image, garment_image: Image.Image, garment_type: int, mask: Image.Image=None, seed: int=1234): files = [ ("images", ("target.png", encode_bytes(model_image), "image/png")), ("images", ("garment.png", encode_bytes(garment_image), "image/png")) ] if mask: files.append(("images", ("mask.png", encode_bytes(mask, format="PNG"), "image/png"))) data = {'garment_type': garment_type, 'seed': seed} return await _call_api(f"{api_server}/{tryon_endpoint}", files=files, data=data) async def call_tryoff_api(model_image: Image.Image, garment_type: int, seed: int=1234): files = [ ("images", ("target.png", encode_bytes(model_image), "image/png")) ] data = {'garment_type': garment_type, 'seed': seed} return await _call_api(f"{api_server}/{tryoff_endpoint}", files=files, data=data) async def api_helper(model_image_dict: dict, garment_image: Image.Image, garment_type: str, is_tryoff: bool, seed: int): if model_image_dict is None or model_image_dict["background"] is None: raise gr.Error("Missing model image") elif not is_tryoff and garment_image is None: raise gr.Error("Missing garment image for Try-On") # Because Gradio ImageEditor can return a dict model_image = extract_image_from_input(model_image_dict) model_image, model_resized = resize_image_if_needed(model_image) garment_image, _ = resize_image_if_needed(garment_image) garment_type_int = garment_type_to_int(garment_type) if is_tryoff: return await call_tryoff_api(model_image, garment_type_int, seed) else: mask_image = None if isinstance(model_image_dict, dict) and model_image_dict.get("layers"): mask = model_image_dict["layers"][0] mask_array = np.array(mask) if not np.all(mask_array < 10): is_black = np.all(mask_array < 10, axis=2) mask_image = Image.fromarray(((~is_black) * 255).astype(np.uint8)) if model_resized: mask_image = mask_image.resize(model_image.size, Image.Resampling.NEAREST) else: gr.Info("No mask provided, using auto-generated mask") return await call_tryon_api(model_image, garment_image, garment_type_int, mask=mask_image, seed=seed) # Event handler functions def handle_toggle(toggle_value): """Handle toggle state changes - controls garment input visibility""" toggle_label = gr.update(value=toggle_value, label="Try-Off") if toggle_value else gr.update(value=toggle_value, label="Try-On") submit_btn_label = gr.update(value="Run Try-Off", elem_id="tryoff-color") if toggle_value else gr.update(value="Run Try-On", elem_id="tryon-color") if toggle_value: # Clear the image and disable the component return gr.update(value=None, elem_classes=["disabled-image"], interactive=False), toggle_label, submit_btn_label else: # Re-enable the component without clearing the image return gr.update(elem_classes=[], interactive=True), toggle_label, submit_btn_label def set_tryon(garment_img, model_img, output_img, garment_condition): garment_update, toggle_label, submit_btn_label = handle_toggle(False) return garment_update, toggle_label, submit_btn_label def set_tryoff(model_img, output_img, garment_condition): garment_update, toggle_label, submit_btn_label = handle_toggle(True) return garment_update, toggle_label, submit_btn_label def garment_sort_key(filename): if filename.startswith("upper_"): return (0, filename) elif filename.startswith("lower_"): return (1, filename) elif filename.startswith("full_"): return (2, filename) else: return (3, filename) # Get images for examples images_path = os.path.join(os.path.dirname(__file__),'images') garment_list = os.listdir(os.path.join(images_path, "garments")) garment_list_path = [ os.path.join(images_path, "garments", cloth) for cloth in sorted(garment_list, key=garment_sort_key) ] people_list = os.listdir(os.path.join(images_path, "persons")) people_list_path = [os.path.join(images_path, "persons", human) for human in sorted(people_list)] gr.set_static_paths(paths=[Path.cwd().absolute()/"images"]) # Create the Gradio interface with gr.Blocks(css_paths="styles.css", theme=gr.themes.Ocean(), title="Voost: Virtual Try-On/Off") as demo: with gr.Row(): gr.HTML("""
⚠️ Note: Errors may occur due to high concurrent requests or NSFW content detection. Please try again if needed.