Spaces:
Running
Running
import re

import cv2
import gradio as gr
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as T
from PIL import Image
from skimage import color
from sklearn.cluster import KMeans

from cloth_segmentation.networks.u2net import U2NET
# Build the segmentation network and restore its weights on the CPU.
_WEIGHTS_PATH = "cloth_segmentation/networks/u2net.pth"
_CPU = torch.device('cpu')

# NOTE(review): the two positional arguments are presumably the input and
# output channel counts -- confirm against the U2NET constructor.
model = U2NET(3, 1)
_state_dict = torch.load(_WEIGHTS_PATH, map_location=_CPU)
model.load_state_dict(_state_dict)
model.eval()  # the network is only used for prediction, never trained here
# Preprocessing
# Every image is resized to a fixed square, converted to a tensor and then
# normalised per channel before being handed to the network.
_INPUT_SIZE = (320, 320)
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

transform = T.Compose([
    T.Resize(_INPUT_SIZE),
    T.ToTensor(),
    T.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
])
# Segmentation mask
def get_dress_mask(image_pil):
    """Return a binary mask of the segmented region, sized like the input.

    Args:
        image_pil: PIL image to segment. Any mode is accepted; it is
            converted to RGB because the preprocessing normalises exactly
            three channels.

    Returns:
        A ``uint8`` NumPy array of shape ``(height, width)`` holding 1 where
        the model's score exceeds 0.5 and 0 elsewhere.
    """
    rgb = image_pil.convert("RGB")
    batch = transform(rgb).unsqueeze(0)

    # Inference only: without no_grad() the output tracks gradients and
    # Tensor.numpy() refuses to convert it.
    with torch.no_grad():
        # NOTE(review): assumes the first element returned by the model is
        # the final single-channel map and that it is already in [0, 1] --
        # confirm against the U2NET forward().
        pred = model(batch)[0]
    pred = pred.squeeze().cpu().numpy()

    mask = (pred > 0.5).astype(np.uint8)

    # PIL's size is (width, height), which is exactly the dsize order that
    # cv2.resize expects, so it must not be reversed. Nearest-neighbour
    # keeps the mask strictly binary.
    mask = cv2.resize(mask, rgb.size, interpolation=cv2.INTER_NEAREST)
    return mask
# Color parsing (extract target color from prompt)
def extract_target_color(prompt):
    """Pick the target color name mentioned in a free-text prompt.

    Matching is case-insensitive and on whole words, so "red" is not found
    inside "colored". When several colors are mentioned, the one that comes
    first in the internal priority list wins.

    Args:
        prompt: User-supplied text. ``None`` or an empty string is allowed.

    Returns:
        One of the known color names, or ``"red"`` when none is mentioned.
    """
    # Basic keyword matching (can be replaced with NLP-based color detection).
    # "sky blue" must be checked before "blue", otherwise the shorter name
    # always wins and "sky blue" can never be returned.
    colors = ["red", "sky blue", "blue", "green", "yellow", "pink", "black", "white", "purple"]

    # Lower-case and collapse whitespace so "Sky   Blue" still matches.
    text = " ".join((prompt or "").lower().split())

    for name in colors:
        if re.search(rf"\b{re.escape(name)}\b", text):
            return name
    return "red"  # default fallback
# Recoloring function
def recolor_dress(image_pil, prompt):
    """Shift the color of the segmented region toward the color in *prompt*.

    The image is converted to CIELAB; inside the mask the a and b channels
    are translated so that their mean equals the target color's (a, b)
    pair. Lightness (L) is never modified, so shading is preserved.

    Args:
        image_pil: PIL image to edit.
        prompt: Free-text prompt naming the desired color.

    Returns:
        A new RGB PIL image. If the mask selects no pixels, the RGB version
        of the input is returned unchanged.
    """
    rgb_image = image_pil.convert("RGB")
    image_np = np.array(rgb_image) / 255.0
    lab = color.rgb2lab(image_np)

    mask = get_dress_mask(rgb_image)

    # Guard against a mask that does not line up with the image; boolean
    # indexing below would otherwise raise an IndexError.
    height, width = lab.shape[:2]
    if mask.shape[:2] != (height, width):
        mask = cv2.resize(mask, (width, height), interpolation=cv2.INTER_NEAREST)

    dress = mask == 1
    if not dress.any():
        # Nothing to recolor; also avoids taking the mean of an empty slice.
        return rgb_image

    # Target a, b (from a small predefined palette)
    # NOTE: "black" and "white" share [0, 0] and L is left untouched, so
    # both of them only remove color from the region.
    target_color_map = {
        "red": [60, 40],
        "blue": [20, -60],
        "green": [-60, 60],
        "yellow": [10, 70],
        "pink": [50, 10],
        "purple": [40, -40],
        "black": [0, 0],
        "white": [0, 0],
        "sky blue": [0, -50],
    }
    target = extract_target_color(prompt)
    target_a, target_b = target_color_map.get(target, [60, 40])

    # Apply color shift only to dress region: move the region's mean a/b
    # onto the target while keeping each pixel's offset from that mean.
    lab_new = lab.copy()
    lab_new[dress, 1] += target_a - lab[dress, 1].mean()
    lab_new[dress, 2] += target_b - lab[dress, 2].mean()

    rgb_new = color.lab2rgb(lab_new)
    # Round rather than truncate so that untouched pixels survive the
    # RGB -> Lab -> RGB round trip without being darkened by one level.
    rgb_new = np.clip(np.rint(rgb_new * 255), 0, 255).astype(np.uint8)
    return Image.fromarray(rgb_new)
# Gradio UI
def interface_fn(image, prompt):
    """Gradio callback: recolor the uploaded image according to the prompt.

    Args:
        image: The uploaded PIL image, or ``None`` when nothing was uploaded.
        prompt: Text from the prompt box; ``None`` is treated as empty.

    Returns:
        The edited PIL image produced by ``recolor_dress``.

    Raises:
        gr.Error: If no image was provided, so the user sees a clear message
            instead of an internal AttributeError.
    """
    if image is None:
        raise gr.Error("Please upload an image before submitting.")
    return recolor_dress(image, prompt or "")
# Declare the UI components first, then wire them to the callback.
_image_input = gr.Image(label="Upload Image", type="pil")
_prompt_input = gr.Textbox(label="Prompt", placeholder="Describe what to edit")
_edited_output = gr.Image(label="Edited Image")

interface = gr.Interface(
    fn=interface_fn,
    inputs=[_image_input, _prompt_input],
    outputs=_edited_output,
    title="Image Editor",
    description="Uses Hugging Face model for real image editing based on prompt.",
)

# show_error surfaces exceptions raised by the callback in the browser.
interface.launch(show_error=True)