# detecting_dress / app.py
# Author: gaur3009 — last commit eb85959 ("Update app.py", verified)
import gradio as gr
import numpy as np
import cv2
from PIL import Image
import torch
import torch.nn as nn
import torchvision.transforms as T
from skimage import color
from sklearn.cluster import KMeans
from cloth_segmentation.networks.u2net import U2NET
def _load_u2net(weights_path, device="cpu"):
    """Construct ``U2NET(3, 1)``, load *weights_path* onto *device*, return it in eval mode.

    NOTE(review): the (3, 1) arguments presumably mean 3 input channels and a
    single output map — confirm against cloth_segmentation.networks.u2net.
    """
    net = U2NET(3, 1)
    state = torch.load(weights_path, map_location=torch.device(device))
    net.load_state_dict(state)
    # nn.Module.eval() switches to inference mode and returns the module itself.
    return net.eval()


# Segmentation network, loaded once at import time on the CPU.
model = _load_u2net("cloth_segmentation/networks/u2net.pth")

# Preprocessing: resize to the 320x320 network input, convert to a tensor and
# normalise with the standard ImageNet per-channel mean/std.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]
transform = T.Compose(
    [
        T.Resize((320, 320)),
        T.ToTensor(),
        T.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ]
)
# Segmentation mask
@torch.no_grad()
def get_dress_mask(image_pil):
img = transform(image_pil).unsqueeze(0)
pred = model(img)[0]
pred = pred.squeeze().cpu().numpy()
mask = (pred > 0.5).astype(np.uint8)
mask = cv2.resize(mask, image_pil.size[::-1])
return mask
# Color parsing (extract target color from prompt)
def extract_target_color(prompt):
    """Return the first supported colour name mentioned in *prompt*.

    Matching is case-insensitive and anchored at the start of a word. This
    means "greenish" still matches "green", but "shredded" no longer matches
    "red". "sky blue" also accepts "sky-blue" and "skyblue".

    Args:
        prompt: free-text prompt; ``None`` is treated as an empty string.

    Returns:
        One of the supported colour names, or ``"red"`` when none is found.
    """
    # Basic keyword matching (can be replaced with NLP-based color detection)
    import re

    # "sky blue" must come before "blue": any text containing "sky blue" also
    # contains "blue", so with the old order "sky blue" could never be returned.
    colors = ["red", "sky blue", "blue", "green", "yellow", "pink", "black", "white", "purple"]
    text = (prompt or "").lower()
    for c in colors:
        # Leading word boundary; multi-word names allow space/hyphen/no separator.
        pattern = r"\b" + r"[\s-]*".join(re.escape(word) for word in c.split())
        if re.search(pattern, text):
            return c
    return "red"  # default fallback
# Recoloring function
# Target CIELAB values per colour name as (a*, b*, L*). L* is None for chromatic
# colours, so the garment keeps its own lightness. Black and white share neutral
# chroma (0, 0) and differ only in their heuristic L* target; without it they
# rendered identically.
_TARGET_LAB = {
    "red": (60, 40, None),
    "blue": (20, -60, None),
    "green": (-60, 60, None),
    "yellow": (10, 70, None),
    "pink": (50, 10, None),
    "purple": (40, -40, None),
    "black": (0, 0, 15.0),
    "white": (0, 0, 95.0),
    "sky blue": (0, -50, None),
}


def recolor_dress(image_pil, prompt):
    """Recolor the segmented dress region of *image_pil* to the colour named in *prompt*.

    The dress mask comes from ``get_dress_mask`` and the colour name from
    ``extract_target_color``; unknown names fall back to red. Inside the mask,
    the mean a*/b* (and L* for black/white) is shifted to the palette target.
    Shifting the mean rather than overwriting keeps the garment's texture and
    shading. Pixels outside the mask are copied unchanged from the input.

    Args:
        image_pil: PIL image of any mode; it is converted to RGB.
        prompt: free-text prompt naming the desired colour.

    Returns:
        A new RGB ``PIL.Image``. When the mask is empty, the RGB-converted input
        is returned unchanged.
    """
    rgb_image = image_pil.convert("RGB")
    rgb = np.array(rgb_image)  # (H, W, 3) uint8, writable copy
    # Segment the RGB-converted image so RGBA/greyscale uploads also work.
    region = get_dress_mask(rgb_image) > 0
    height, width = rgb.shape[:2]
    if region.shape != (height, width):
        # Defensive: bring a mask of the wrong size back to the image size.
        region = cv2.resize(
            region.astype(np.uint8), (width, height), interpolation=cv2.INTER_NEAREST
        ) > 0
    if not region.any():
        # Nothing segmented: the mean over an empty selection would be NaN.
        return rgb_image

    target_a, target_b, target_l = _TARGET_LAB.get(
        extract_target_color(prompt), _TARGET_LAB["red"]
    )

    # Convert only the dress pixels, shaped (1, N, 3) so skimage treats them as an image.
    lab = color.rgb2lab(rgb[region][np.newaxis] / 255.0)
    lab[..., 1] += target_a - lab[..., 1].mean()
    lab[..., 2] += target_b - lab[..., 2].mean()
    if target_l is not None:
        lab[..., 0] = np.clip(lab[..., 0] + (target_l - lab[..., 0].mean()), 0.0, 100.0)

    # Round (not truncate) back to 8-bit and write only the dress pixels.
    recolored = color.lab2rgb(lab)[0]
    rgb[region] = np.clip(np.rint(recolored * 255.0), 0, 255).astype(np.uint8)
    return Image.fromarray(rgb)
# Gradio UI
def interface_fn(image, prompt):
    """Gradio callback: recolor the uploaded image according to *prompt*.

    Raises:
        gr.Error: when no image was uploaded. Gradio passes ``None`` in that
            case, and the error is shown to the user instead of an
            AttributeError traceback.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")
    return recolor_dress(image, prompt or "")
# Single-image, single-prompt interface: the output is the recolored image.
interface = gr.Interface(
    fn=interface_fn,
    inputs=[
        gr.Image(label="Upload Image", type="pil"),
        gr.Textbox(label="Prompt", placeholder="e.g. make the dress green"),
    ],
    outputs=gr.Image(label="Edited Image"),
    title="Image Editor",
    # The previous text claimed a Hugging Face image-editing model; the app
    # actually segments the dress and shifts its colour in CIELAB space.
    description=(
        "Segments the dress with a U2NET cloth-segmentation model and recolors it "
        "to the color named in the prompt (e.g. red, blue, green, pink)."
    ),
)

# Launch only when run as a script (as Hugging Face Spaces does with app.py),
# so importing this module does not start a server.
if __name__ == "__main__":
    interface.launch(show_error=True)