"""Gradio demo that recolours the garment region of an uploaded photo.

Pipeline: a U2NET network produces a binary garment mask, the prompt is
scanned for a known colour name, and the masked pixels are shifted in CIELAB
chroma (a/b) towards that colour while lightness (L) is left untouched.
"""

import re

import cv2
import gradio as gr
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as T
from PIL import Image
from skimage import color
from sklearn.cluster import KMeans

from cloth_segmentation.networks.u2net import U2NET

model = U2NET(3, 1)
model.load_state_dict(
    torch.load(
        "cloth_segmentation/networks/u2net.pth",
        map_location=torch.device('cpu'),
    )
)
model.eval()

# Preprocessing
transform = T.Compose([
    T.Resize((320, 320)),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Target (a, b) chroma for each supported colour name.
# NOTE: "black" and "white" both map to a = b = 0, which only desaturates the
# garment; lightness is never modified, so the result is grey in both cases.
TARGET_COLOR_AB = {
    "red": [60, 40],
    "blue": [20, -60],
    "green": [-60, 60],
    "yellow": [10, 70],
    "pink": [50, 10],
    "purple": [40, -40],
    "black": [0, 0],
    "white": [0, 0],
    "sky blue": [0, -50],
}
DEFAULT_COLOR = "red"

# Search order for prompt parsing. Multi-word names come first so that
# "sky blue" is not shadowed by its substring "blue"; the remaining names
# keep their original priority.
_COLOR_SEARCH_ORDER = [
    "sky blue", "red", "blue", "green", "yellow", "pink", "black", "white",
    "purple",
]


# Segmentation mask
@torch.no_grad()
def get_dress_mask(image_pil: Image.Image) -> np.ndarray:
    """Return a binary (0/1) uint8 mask with the same height/width as the image.

    Args:
        image_pil: Input image in any PIL mode; it is converted to RGB before
            being fed to the network.

    Returns:
        A 2-D ``uint8`` array of shape ``(height, width)`` where 1 marks
        pixels whose network output exceeds 0.5.
    """
    rgb_image = image_pil.convert("RGB")
    batch = transform(rgb_image).unsqueeze(0)
    # assumes the first element of the model output is a probability-like
    # map that can be thresholded at 0.5 - TODO confirm against the U2NET
    # variant in cloth_segmentation.
    pred = model(batch)[0]
    pred = pred.squeeze().cpu().numpy()
    mask = (pred > 0.5).astype(np.uint8)
    # cv2.resize takes dsize as (width, height), which is exactly PIL's
    # ``size`` order; nearest-neighbour keeps the mask strictly binary.
    return cv2.resize(mask, rgb_image.size, interpolation=cv2.INTER_NEAREST)


# Color parsing (extract target color from prompt)
def extract_target_color(prompt: str) -> str:
    """Return the first supported colour name mentioned in ``prompt``.

    Matching is case-insensitive and anchored on word boundaries, so "red"
    does not match inside "colored". Falls back to ``DEFAULT_COLOR`` when no
    known colour is present or the prompt is empty/``None``.
    """
    # Basic keyword matching (can be replaced with NLP-based color detection)
    text = (prompt or "").lower()
    for name in _COLOR_SEARCH_ORDER:
        if re.search(rf"\b{re.escape(name)}\b", text):
            return name
    return DEFAULT_COLOR  # default fallback


# Recoloring function
def recolor_dress(image_pil: Image.Image, prompt: str) -> Image.Image:
    """Shift the garment region of ``image_pil`` towards the prompted colour.

    Args:
        image_pil: Source image.
        prompt: Free text; the target colour is taken from it via
            ``extract_target_color``.

    Returns:
        A new RGB image. Pixels outside the garment mask are copied from the
        source unchanged; if the mask is empty the RGB source is returned.
    """
    rgb_image = image_pil.convert("RGB")
    source_u8 = np.array(rgb_image)
    lab = color.rgb2lab(source_u8 / 255.0)

    region = get_dress_mask(rgb_image).astype(bool)
    if not region.any():
        # Nothing segmented: avoid mean-of-empty (NaN) and return as-is.
        return rgb_image

    # Get mean a, b values in masked region
    a_mean = lab[region, 1].mean()
    b_mean = lab[region, 2].mean()

    target = extract_target_color(prompt)
    target_a, target_b = TARGET_COLOR_AB.get(
        target, TARGET_COLOR_AB[DEFAULT_COLOR]
    )

    # Apply color shift only to dress region; shifting by the mean offset
    # preserves the garment's internal shading variation.
    lab_new = lab.copy()
    lab_new[region, 1] += target_a - a_mean
    lab_new[region, 2] += target_b - b_mean

    rgb_new = np.clip(color.lab2rgb(lab_new), 0.0, 1.0)
    # Round rather than truncate so values are not biased downward.
    recolored_u8 = np.rint(rgb_new * 255).astype(np.uint8)

    # Write back only the masked pixels so the background is byte-identical
    # to the input instead of going through a lossy LAB round-trip.
    result = source_u8.copy()
    result[region] = recolored_u8[region]
    return Image.fromarray(result)


# Gradio UI
def interface_fn(image, prompt):
    """Gradio callback: validate the inputs and run the recolouring.

    Raises:
        gr.Error: If no image was uploaded.
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    return recolor_dress(image, prompt)


interface = gr.Interface(
    fn=interface_fn,
    inputs=[
        gr.Image(label="Upload Image", type="pil"),
        gr.Textbox(label="Prompt", placeholder="Describe what to edit")
    ],
    outputs=gr.Image(label="Edited Image"),
    title="Image Editor",
    description="Uses Hugging Face model for real image editing based on prompt."
)

interface.launch(show_error = True)