gaur3009 commited on
Commit
0bda8ba
·
verified ·
1 Parent(s): 7da70bd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +65 -60
app.py CHANGED
@@ -4,10 +4,10 @@ import torch
4
  import cv2
5
  from PIL import Image
6
  from torchvision import transforms
7
- from cloth_segmentation.networks.u2net import U2NET
8
 
9
  # Load U²-Net model
10
- model_path = "cloth_segmentation/networks/u2net.pth"
11
  model = U2NET(3, 1)
12
  state_dict = torch.load(model_path, map_location=torch.device('cpu'))
13
  state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
@@ -15,14 +15,25 @@ model.load_state_dict(state_dict)
15
  model.eval()
16
 
17
  def refine_mask(mask):
18
- """Refines mask using morphological closing followed by Gaussian blur"""
19
- kernel = np.ones((7, 7), np.uint8)
20
- mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel) # Close holes inside dress
21
- mask = cv2.GaussianBlur(mask, (7, 7), 1.5)
 
 
 
 
 
 
 
 
 
 
 
22
  return mask
23
 
24
  def segment_dress(image_np):
25
- """Segment dress using U²-Net"""
26
  transform_pipeline = transforms.Compose([
27
  transforms.ToTensor(),
28
  transforms.Resize((320, 320))
@@ -34,70 +45,67 @@ def segment_dress(image_np):
34
  with torch.no_grad():
35
  output = model(input_tensor)[0][0].squeeze().cpu().numpy()
36
 
37
- output = (output - output.min()) / (output.max() - output.min() + 1e-8) # Normalize to [0, 1]
38
- dress_mask = (output > 0.5).astype(np.uint8) * 255
39
- dress_mask = cv2.resize(dress_mask, (image_np.shape[1], image_np.shape[0]), interpolation=cv2.INTER_LINEAR)
 
 
 
 
 
40
 
41
  return refine_mask(dress_mask)
42
 
43
  def apply_grabcut(image_np, dress_mask):
44
- """Refines the mask using GrabCut to avoid color bleeding"""
45
  bgd_model = np.zeros((1, 65), np.float64)
46
  fgd_model = np.zeros((1, 65), np.float64)
47
 
48
  mask = np.where(dress_mask > 0, cv2.GC_PR_FGD, cv2.GC_BGD).astype('uint8')
49
-
50
- # Get bounding box of the mask
51
  coords = cv2.findNonZero(dress_mask)
52
- x, y, w, h = cv2.boundingRect(coords)
53
- rect = (x, y, w, h)
54
-
55
- cv2.grabCut(image_np, mask, rect, bgd_model, fgd_model, 5, cv2.GC_INIT_WITH_MASK)
56
-
57
  refined_mask = np.where((mask == cv2.GC_FGD) | (mask == cv2.GC_PR_FGD), 255, 0).astype("uint8")
58
  return refine_mask(refined_mask)
59
 
60
  def recolor_dress(image_np, dress_mask, target_color):
61
- """Changes dress color while keeping texture & lighting intact"""
62
-
63
- # Convert target color to LAB
64
  target_color_lab = cv2.cvtColor(np.uint8([[target_color]]), cv2.COLOR_BGR2LAB)[0][0]
65
-
66
- # Convert image to LAB
67
  img_lab = cv2.cvtColor(image_np, cv2.COLOR_RGB2LAB)
68
 
69
- # Compute mean LAB values of dress pixels
70
  dress_pixels = img_lab[dress_mask > 0]
71
  if len(dress_pixels) == 0:
72
- return image_np # No dress detected
73
 
74
  mean_L, mean_A, mean_B = np.mean(dress_pixels, axis=0)
75
-
76
- # Apply LAB shift
77
  a_shift = target_color_lab[1] - mean_A
78
  b_shift = target_color_lab[2] - mean_B
 
 
79
  img_lab[..., 1] = np.clip(img_lab[..., 1] + (dress_mask / 255.0) * a_shift, 0, 255)
80
  img_lab[..., 2] = np.clip(img_lab[..., 2] + (dress_mask / 255.0) * b_shift, 0, 255)
81
 
82
- # Convert back to RGB
83
  img_recolored = cv2.cvtColor(img_lab.astype(np.uint8), cv2.COLOR_LAB2RGB)
84
-
85
- # Create feathered mask for smooth blending
86
- lightness_mask = (img_lab[..., 0] / 255.0)
87
- feathered_mask = cv2.GaussianBlur(dress_mask, (15, 15), 5)
88
  adaptive_feather = (feathered_mask * lightness_mask).astype(np.uint8)
89
 
90
- # Blend the recolored dress with the original image
91
- img_final = (image_np * (1 - adaptive_feather[..., None] / 255) + img_recolored * (adaptive_feather[..., None] / 255)).astype(np.uint8)
92
-
93
- return img_final
94
 
95
  def change_dress_color(img, color):
96
- """Main function to change dress color naturally"""
97
  if img is None:
98
  return None
99
 
100
- # Convert color name to BGR using a safer method
101
  color_map = {
102
  "Red": (0, 0, 255), "Blue": (255, 0, 0), "Green": (0, 255, 0),
103
  "Yellow": (0, 255, 255), "Purple": (128, 0, 128), "Orange": (0, 165, 255),
@@ -105,45 +113,42 @@ def change_dress_color(img, color):
105
  "Black": (0, 0, 0)
106
  }
107
 
108
- # Safely get color with fallback to red
109
  new_color_bgr = color_map.get(color, (0, 0, 255))
110
-
111
  img_np = np.array(img)
112
 
113
- # Get dress segmentation mask
114
- dress_mask = segment_dress(img_np)
115
-
116
- if dress_mask is None or np.sum(dress_mask) == 0:
117
- return img # Return original if no mask found
118
-
119
- # Further refine mask with GrabCut
120
- dress_mask = apply_grabcut(img_np, dress_mask)
121
-
122
- # Apply recoloring with blending
123
- img_recolored = recolor_dress(img_np, dress_mask, new_color_bgr)
124
-
125
- return Image.fromarray(img_recolored)
126
 
127
- # Create Gradio Blocks interface instead of simple Interface
128
  with gr.Blocks() as demo:
129
- gr.Markdown("# AI-Powered Dress Color Changer")
130
- gr.Markdown("Upload an image of a dress and select a new color. The AI will change the dress color naturally while keeping the fabric texture.")
131
 
132
  with gr.Row():
133
  with gr.Column():
134
- input_image = gr.Image(type="pil", label="Upload Dress Image")
135
  color_choice = gr.Dropdown(
136
  choices=["Red", "Blue", "Green", "Yellow", "Purple",
137
  "Orange", "Cyan", "Magenta", "White", "Black"],
138
  value="Red",
139
- label="Choose New Dress Color"
140
  )
141
- submit_btn = gr.Button("Change Color")
142
 
143
  with gr.Column():
144
  output_image = gr.Image(type="pil", label="Result")
145
 
146
- submit_btn.click(
147
  fn=change_dress_color,
148
  inputs=[input_image, color_choice],
149
  outputs=output_image
 
4
  import cv2
5
  from PIL import Image
6
  from torchvision import transforms
7
+ from cloth_segmentation.networks.u2net import U2NET
8
 
9
  # Load U²-Net model
10
+ model_path = "/kaggle/input/tygsgg/cloth_segmentation/networks/u2net.pth"
11
  model = U2NET(3, 1)
12
  state_dict = torch.load(model_path, map_location=torch.device('cpu'))
13
  state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
 
15
  model.eval()
16
 
17
  def refine_mask(mask):
18
+ """Enhanced mask refinement with erosion and morphological operations"""
19
+ # First closing to fill small holes
20
+ close_kernel = np.ones((5, 5), np.uint8)
21
+ mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, close_kernel)
22
+
23
+ # Erosion to remove small protrusions and extra areas
24
+ erode_kernel = np.ones((3, 3), np.uint8)
25
+ mask = cv2.erode(mask, erode_kernel, iterations=1)
26
+
27
+ # Second closing to refine edges after erosion
28
+ mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, close_kernel)
29
+
30
+ # Final blur to smooth edges while preserving shape
31
+ mask = cv2.GaussianBlur(mask, (5, 5), 1.5)
32
+
33
  return mask
34
 
35
  def segment_dress(image_np):
36
+ """Improved dress segmentation with adaptive thresholding"""
37
  transform_pipeline = transforms.Compose([
38
  transforms.ToTensor(),
39
  transforms.Resize((320, 320))
 
45
  with torch.no_grad():
46
  output = model(input_tensor)[0][0].squeeze().cpu().numpy()
47
 
48
+ # Adaptive threshold calculation
49
+ output = (output - output.min()) / (output.max() - output.min() + 1e-8)
50
+ adaptive_thresh = np.mean(output) + 0.2 # Increased threshold for tighter mask
51
+ dress_mask = (output > adaptive_thresh).astype(np.uint8) * 255
52
+
53
+ # Preserve hard edges during resize
54
+ dress_mask = cv2.resize(dress_mask, (image_np.shape[1], image_np.shape[0]),
55
+ interpolation=cv2.INTER_NEAREST)
56
 
57
  return refine_mask(dress_mask)
58
 
59
  def apply_grabcut(image_np, dress_mask):
60
+ """Mask refinement using GrabCut"""
61
  bgd_model = np.zeros((1, 65), np.float64)
62
  fgd_model = np.zeros((1, 65), np.float64)
63
 
64
  mask = np.where(dress_mask > 0, cv2.GC_PR_FGD, cv2.GC_BGD).astype('uint8')
65
+
66
+ # Get bounding box coordinates
67
  coords = cv2.findNonZero(dress_mask)
68
+ if coords is not None:
69
+ x, y, w, h = cv2.boundingRect(coords)
70
+ rect = (x, y, w, h)
71
+ cv2.grabCut(image_np, mask, rect, bgd_model, fgd_model, 3, cv2.GC_INIT_WITH_MASK)
72
+
73
  refined_mask = np.where((mask == cv2.GC_FGD) | (mask == cv2.GC_PR_FGD), 255, 0).astype("uint8")
74
  return refine_mask(refined_mask)
75
 
76
  def recolor_dress(image_np, dress_mask, target_color):
77
+ """Color transformation with improved blending"""
78
+ # Convert colors to LAB space
 
79
  target_color_lab = cv2.cvtColor(np.uint8([[target_color]]), cv2.COLOR_BGR2LAB)[0][0]
 
 
80
  img_lab = cv2.cvtColor(image_np, cv2.COLOR_RGB2LAB)
81
 
82
+ # Calculate color shifts
83
  dress_pixels = img_lab[dress_mask > 0]
84
  if len(dress_pixels) == 0:
85
+ return image_np
86
 
87
  mean_L, mean_A, mean_B = np.mean(dress_pixels, axis=0)
 
 
88
  a_shift = target_color_lab[1] - mean_A
89
  b_shift = target_color_lab[2] - mean_B
90
+
91
+ # Apply color transformation
92
  img_lab[..., 1] = np.clip(img_lab[..., 1] + (dress_mask / 255.0) * a_shift, 0, 255)
93
  img_lab[..., 2] = np.clip(img_lab[..., 2] + (dress_mask / 255.0) * b_shift, 0, 255)
94
 
95
+ # Create adaptive blending mask
96
  img_recolored = cv2.cvtColor(img_lab.astype(np.uint8), cv2.COLOR_LAB2RGB)
97
+ feathered_mask = cv2.GaussianBlur(dress_mask, (21, 21), 7)
98
+ lightness_mask = (img_lab[..., 0] / 255.0) ** 0.7
 
 
99
  adaptive_feather = (feathered_mask * lightness_mask).astype(np.uint8)
100
 
101
+ # Smooth blending
102
+ return (image_np * (1 - adaptive_feather[..., None]/255) + img_recolored * (adaptive_feather[..., None]/255)).astype(np.uint8)
 
 
103
 
104
  def change_dress_color(img, color):
105
+ """Main processing function with error handling"""
106
  if img is None:
107
  return None
108
 
 
109
  color_map = {
110
  "Red": (0, 0, 255), "Blue": (255, 0, 0), "Green": (0, 255, 0),
111
  "Yellow": (0, 255, 255), "Purple": (128, 0, 128), "Orange": (0, 165, 255),
 
113
  "Black": (0, 0, 0)
114
  }
115
 
 
116
  new_color_bgr = color_map.get(color, (0, 0, 255))
 
117
  img_np = np.array(img)
118
 
119
+ try:
120
+ dress_mask = segment_dress(img_np)
121
+ if np.sum(dress_mask) < 1000: # Minimum mask area threshold
122
+ return img
123
+
124
+ dress_mask = apply_grabcut(img_np, dress_mask)
125
+ img_recolored = recolor_dress(img_np, dress_mask, new_color_bgr)
126
+ return Image.fromarray(img_recolored)
127
+
128
+ except Exception as e:
129
+ print(f"Error processing image: {str(e)}")
130
+ return img
 
131
 
132
+ # Gradio Interface
133
  with gr.Blocks() as demo:
134
+ gr.Markdown("# AI Dress Color Changer")
135
+ gr.Markdown("Upload a dress image and select a new color for realistic recoloring")
136
 
137
  with gr.Row():
138
  with gr.Column():
139
+ input_image = gr.Image(type="pil", label="Input Image")
140
  color_choice = gr.Dropdown(
141
  choices=["Red", "Blue", "Green", "Yellow", "Purple",
142
  "Orange", "Cyan", "Magenta", "White", "Black"],
143
  value="Red",
144
+ label="Select New Color"
145
  )
146
+ process_btn = gr.Button("Recolor Dress")
147
 
148
  with gr.Column():
149
  output_image = gr.Image(type="pil", label="Result")
150
 
151
+ process_btn.click(
152
  fn=change_dress_color,
153
  inputs=[input_image, color_choice],
154
  outputs=output_image