gaur3009 commited on
Commit
d00e30a
·
verified ·
1 Parent(s): 9bbf95a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -33
app.py CHANGED
@@ -4,18 +4,38 @@ import torch
4
  import cv2
5
  from PIL import Image
6
  from torchvision import transforms
7
- from cloth_segmentation.networks.u2net import U2NET
8
 
9
- # Load U²-Net Model
10
  model_path = "cloth_segmentation/networks/u2net.pth"
11
  model = U2NET(3, 1)
12
  state_dict = torch.load(model_path, map_location=torch.device('cpu'))
13
- state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
14
  model.load_state_dict(state_dict)
15
  model.eval()
16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  def segment_dress(image_np):
18
- """Segment the dress using U²-Net and GrabCut."""
19
  transform_pipeline = transforms.Compose([
20
  transforms.ToTensor(),
21
  transforms.Resize((320, 320))
@@ -26,42 +46,36 @@ def segment_dress(image_np):
26
 
27
  with torch.no_grad():
28
  output = model(input_tensor)[0][0].squeeze().cpu().numpy()
 
 
 
 
 
 
 
 
29
 
30
- u2net_mask = (output > 0.5).astype(np.uint8) * 255
31
- u2net_mask = cv2.resize(u2net_mask, (image_np.shape[1], image_np.shape[0]), interpolation=cv2.INTER_NEAREST)
32
-
33
- # Apply GrabCut to refine the mask
34
- mask = np.zeros(image_np.shape[:2], np.uint8)
35
- mask[u2net_mask > 128] = cv2.GC_FGD
36
- mask[u2net_mask <= 128] = cv2.GC_BGD
37
- bg_model = np.zeros((1, 65), np.float64)
38
- fg_model = np.zeros((1, 65), np.float64)
39
-
40
- cv2.grabCut(image_np, mask, None, bg_model, fg_model, 5, cv2.GC_INIT_WITH_MASK)
41
- mask = np.where((mask == 2) | (mask == 0), 0, 255).astype(np.uint8)
42
-
43
- return mask
44
 
45
- def recolor_dress(image_np, mask, target_color):
46
- """Recolor the dress while keeping texture, shadows, and designs."""
47
 
48
- # Convert to LAB color space
49
  img_lab = cv2.cvtColor(image_np, cv2.COLOR_RGB2LAB)
50
-
51
- # Target color in LAB
52
  target_color_lab = cv2.cvtColor(np.uint8([[target_color]]), cv2.COLOR_BGR2LAB)[0][0]
53
-
54
- # Preserve lightness (L) and change only chromatic channels (A & B)
 
 
 
55
  blend_factor = 0.8
56
- img_lab[..., 1] = np.where(mask > 128, img_lab[..., 1] * (1 - blend_factor) + target_color_lab[1] * blend_factor, img_lab[..., 1])
57
- img_lab[..., 2] = np.where(mask > 128, img_lab[..., 2] * (1 - blend_factor) + target_color_lab[2] * blend_factor, img_lab[..., 2])
58
 
59
- # Convert back to RGB
60
  img_recolored = cv2.cvtColor(img_lab, cv2.COLOR_LAB2RGB)
61
  return img_recolored
62
 
63
  def change_dress_color(image_path, color):
64
- """Change the dress color while preserving texture and design details."""
65
  if image_path is None:
66
  return None
67
 
@@ -69,11 +83,14 @@ def change_dress_color(image_path, color):
69
  img_np = np.array(img)
70
 
71
  # Get dress segmentation mask
72
- mask = segment_dress(img_np)
73
 
74
- if mask is None:
75
  return img # No dress detected
76
-
 
 
 
77
  # Convert the selected color to BGR
78
  color_map = {
79
  "Red": (0, 0, 255), "Blue": (255, 0, 0), "Green": (0, 255, 0), "Yellow": (0, 255, 255),
@@ -83,7 +100,7 @@ def change_dress_color(image_path, color):
83
  new_color_bgr = np.array(color_map.get(color, (0, 0, 255)), dtype=np.uint8) # Default to Red
84
 
85
  # Apply recoloring logic
86
- img_recolored = recolor_dress(img_np, mask, new_color_bgr)
87
 
88
  return Image.fromarray(img_recolored)
89
 
 
4
  import cv2
5
  from PIL import Image
6
  from torchvision import transforms
7
+ from cloth_segmentation.networks.u2net import U2NET # Import U²-Net
8
 
9
+ # Load U²-Net model
10
  model_path = "cloth_segmentation/networks/u2net.pth"
11
  model = U2NET(3, 1)
12
  state_dict = torch.load(model_path, map_location=torch.device('cpu'))
13
+ state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()} # Remove 'module.' prefix
14
  model.load_state_dict(state_dict)
15
  model.eval()
16
 
17
+ def detect_design(image_np):
18
+ """Detects the design on the dress using edge detection and adaptive thresholding."""
19
+ gray = cv2.cvtColor(image_np, cv2.COLOR_RGB2GRAY)
20
+
21
+ # Use adaptive thresholding to segment the design
22
+ adaptive_thresh = cv2.adaptiveThreshold(gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
23
+ cv2.THRESH_BINARY_INV, 11, 2)
24
+
25
+ # Detect edges using Canny
26
+ edges = cv2.Canny(gray, 50, 150)
27
+
28
+ # Combine both masks
29
+ design_mask = cv2.bitwise_or(adaptive_thresh, edges)
30
+
31
+ # Morphological operations to remove noise
32
+ kernel = np.ones((3, 3), np.uint8)
33
+ design_mask = cv2.morphologyEx(design_mask, cv2.MORPH_CLOSE, kernel)
34
+
35
+ return design_mask
36
+
37
  def segment_dress(image_np):
38
+ """Segment the dress using U²-Net"""
39
  transform_pipeline = transforms.Compose([
40
  transforms.ToTensor(),
41
  transforms.Resize((320, 320))
 
46
 
47
  with torch.no_grad():
48
  output = model(input_tensor)[0][0].squeeze().cpu().numpy()
49
+
50
+ # Convert output to mask
51
+ dress_mask = (output > 0.5).astype(np.uint8) * 255
52
+ dress_mask = cv2.resize(dress_mask, (image_np.shape[1], image_np.shape[0]), interpolation=cv2.INTER_NEAREST)
53
+
54
+ # Morphological operations for smoothness
55
+ kernel = np.ones((5, 5), np.uint8)
56
+ dress_mask = cv2.morphologyEx(dress_mask, cv2.MORPH_CLOSE, kernel)
57
 
58
+ return dress_mask
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
+ def recolor_dress(image_np, dress_mask, design_mask, target_color):
61
+ """Change dress color while preserving designs"""
62
 
 
63
  img_lab = cv2.cvtColor(image_np, cv2.COLOR_RGB2LAB)
 
 
64
  target_color_lab = cv2.cvtColor(np.uint8([[target_color]]), cv2.COLOR_BGR2LAB)[0][0]
65
+
66
+ # Ensure the design areas are NOT recolored
67
+ recolor_mask = cv2.bitwise_and(dress_mask, cv2.bitwise_not(design_mask))
68
+
69
+ # Apply color change only to the non-design dress areas
70
  blend_factor = 0.8
71
+ img_lab[..., 1] = np.where(recolor_mask > 128, img_lab[..., 1] * (1 - blend_factor) + target_color_lab[1] * blend_factor, img_lab[..., 1])
72
+ img_lab[..., 2] = np.where(recolor_mask > 128, img_lab[..., 2] * (1 - blend_factor) + target_color_lab[2] * blend_factor, img_lab[..., 2])
73
 
 
74
  img_recolored = cv2.cvtColor(img_lab, cv2.COLOR_LAB2RGB)
75
  return img_recolored
76
 
77
  def change_dress_color(image_path, color):
78
+ """Change the dress color naturally while keeping designs intact."""
79
  if image_path is None:
80
  return None
81
 
 
83
  img_np = np.array(img)
84
 
85
  # Get dress segmentation mask
86
+ dress_mask = segment_dress(img_np)
87
 
88
+ if dress_mask is None:
89
  return img # No dress detected
90
+
91
+ # Detect design on the dress
92
+ design_mask = detect_design(img_np)
93
+
94
  # Convert the selected color to BGR
95
  color_map = {
96
  "Red": (0, 0, 255), "Blue": (255, 0, 0), "Green": (0, 255, 0), "Yellow": (0, 255, 255),
 
100
  new_color_bgr = np.array(color_map.get(color, (0, 0, 255)), dtype=np.uint8) # Default to Red
101
 
102
  # Apply recoloring logic
103
+ img_recolored = recolor_dress(img_np, dress_mask, design_mask, new_color_bgr)
104
 
105
  return Image.fromarray(img_recolored)
106