# detecting_dress/app.py
import re

import gradio as gr
import numpy as np
import cv2
import torch
import torchvision.transforms as T
from PIL import Image
from skimage import color

from cloth_segmentation.networks.u2net import U2NET

# Load the U2NET cloth-segmentation model (3 input channels, 1 output channel) on CPU.
model = U2NET(3, 1)
model.load_state_dict(torch.load("cloth_segmentation/networks/u2net.pth", map_location=torch.device("cpu")))
model.eval()
# Preprocessing: resize to the network's input size and normalize with ImageNet statistics.
transform = T.Compose([
    T.Resize((320, 320)),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# Segmentation mask
@torch.no_grad()
def get_dress_mask(image_pil):
    img = transform(image_pil).unsqueeze(0)
    pred = model(img)[0]                  # main (fused) U2NET output
    pred = pred.squeeze().cpu().numpy()
    mask = (pred > 0.5).astype(np.uint8)
    # cv2.resize takes (width, height), which matches PIL's .size order directly.
    mask = cv2.resize(mask, image_pil.size, interpolation=cv2.INTER_NEAREST)
    return mask
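# Hedged sketch (not part of the original app): a quick sanity check of the mask,
# assuming a local "sample.jpg" exists. It visualizes the binary dress mask so the
# 0.5 threshold and the resize back to input resolution can be verified by eye.
#
#     sample = Image.open("sample.jpg").convert("RGB")
#     m = get_dress_mask(sample)           # uint8 array, shape (H, W), values {0, 1}
#     Image.fromarray(m * 255).show()      # white = predicted dress pixels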
# Color parsing (extract the target color from the prompt)
def extract_target_color(prompt):
    # Basic keyword matching (could be replaced with NLP-based color detection).
    # Multi-word colors come first so that "sky blue" is not shadowed by "blue".
    colors = ["sky blue", "red", "blue", "green", "yellow", "pink", "black", "white", "purple"]
    for c in colors:
        if re.search(c, prompt.lower()):
            return c
    return "red"  # default fallback
# Recoloring function
def recolor_dress(image_pil, prompt):
    image_np = np.array(image_pil.convert("RGB")) / 255.0
    lab = color.rgb2lab(image_np)
    mask = get_dress_mask(image_pil)
    if mask.sum() == 0:
        # Nothing segmented; return the input unchanged.
        return image_pil
    # Mean a, b values in the masked region
    a_mean = lab[:, :, 1][mask == 1].mean()
    b_mean = lab[:, :, 2][mask == 1].mean()
    # Target a, b values (from a small predefined palette)
    target_color_map = {
        "red": [60, 40],
        "blue": [20, -60],
        "green": [-60, 60],
        "yellow": [10, 70],
        "pink": [50, 10],
        "purple": [40, -40],
        "black": [0, 0],
        "white": [0, 0],
        "sky blue": [0, -50],
    }
    target = extract_target_color(prompt)
    target_a, target_b = target_color_map.get(target, [60, 40])
    # Shift the a/b channels only inside the dress region, preserving lightness (L).
    lab_new = lab.copy()
    delta_a = target_a - a_mean
    delta_b = target_b - b_mean
    lab_new[:, :, 1][mask == 1] += delta_a
    lab_new[:, :, 2][mask == 1] += delta_b
    rgb_new = np.clip(color.lab2rgb(lab_new), 0, 1)
    rgb_new = (rgb_new * 255).astype(np.uint8)
    return Image.fromarray(rgb_new)
# Gradio UI
def interface_fn(image, prompt):
    return recolor_dress(image, prompt)

interface = gr.Interface(
    fn=interface_fn,
    inputs=[
        gr.Image(label="Upload Image", type="pil"),
        gr.Textbox(label="Prompt", placeholder="e.g. make the dress sky blue"),
    ],
    outputs=gr.Image(label="Edited Image"),
    title="Dress Recoloring",
    description="Segments the dress with U2NET and recolors it in LAB space based on the color named in the prompt.",
)

interface.launch(show_error=True)
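# Hedged sketch (assumed filenames): calling the pipeline directly without the Gradio UI,
# e.g. for quick batch testing. "dress.jpg" and "dress_blue.jpg" are placeholder paths.
#
#     out = recolor_dress(Image.open("dress.jpg"), "make the dress sky blue")
#     out.save("dress_blue.jpg")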