phitran committed
Commit 3286ec0 · 1 Parent(s): fd15f8d

implement UI and basic functions

Files changed (1)
app.py  +47 -53
app.py CHANGED
@@ -5,11 +5,12 @@ import mediapipe as mp
 import os
 
 example_path = os.path.join(os.path.dirname(__file__), 'example')
+
 garm_list = os.listdir(os.path.join(example_path, "cloth"))
 garm_list_path = [os.path.join(example_path, "cloth", garm) for garm in garm_list]
 
-human_list = os.listdir(os.path.join(example_path, "cloth"))
-human_list_path = [os.path.join(example_path, "cloth", garm) for garm in garm_list]
+human_list = os.listdir(os.path.join(example_path, "human"))
+human_list_path = [os.path.join(example_path, "human", human) for human in human_list]
 
 # Initialize MediaPipe Pose
 mp_pose = mp.solutions.pose
@@ -18,76 +19,65 @@ mp_drawing = mp.solutions.drawing_utils
 mp_pose_landmark = mp_pose.PoseLandmark
 
 
-def align_clothing(body_img, clothing_img):
-    image_rgb = cv2.cvtColor(body_img, cv2.COLOR_BGR2RGB)
+def detect_pose(image):
+    # Convert to RGB
+    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+
+    # Run pose detection
     result = pose.process(image_rgb)
-    output = body_img.copy()
+
     keypoints = {}
 
     if result.pose_landmarks:
-        height, width, _ = output.shape
+        # Draw landmarks on image
+        mp_drawing.draw_landmarks(image, result.pose_landmarks, mp_pose.POSE_CONNECTIONS)
+
+        # Get image dimensions
+        height, width, _ = image.shape
 
-        # Extract body keypoints
-        points = {
+        # Extract specific landmarks
+        landmark_indices = {
             'left_shoulder': mp_pose_landmark.LEFT_SHOULDER,
             'right_shoulder': mp_pose_landmark.RIGHT_SHOULDER,
-            'left_hip': mp_pose_landmark.LEFT_HIP
+            'left_hip': mp_pose_landmark.LEFT_HIP,
+            'right_hip': mp_pose_landmark.RIGHT_HIP
         }
 
-        for name, idx in points.items():
-            lm = result.pose_landmarks.landmark[idx]
-            keypoints[name] = (int(lm.x * width), int(lm.y * height))
-
-        # Draw for debug
-        for name, (x, y) in keypoints.items():
-            cv2.circle(output, (x, y), 5, (0, 255, 0), -1)
-            cv2.putText(output, name, (x + 5, y - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
-
-        # Affine Transform
-        if all(k in keypoints for k in ['left_shoulder', 'right_shoulder', 'left_hip']):
-            src_tri = np.array([
-                [0, 0],
-                [clothing_img.shape[1], 0],
-                [0, clothing_img.shape[0]]
-            ], dtype=np.float32)
-
-            dst_tri = np.array([
-                keypoints['left_shoulder'],
-                keypoints['right_shoulder'],
-                keypoints['left_hip']
-            ], dtype=np.float32)
-
-            # Compute warp matrix and apply it
-            warp_mat = cv2.getAffineTransform(src_tri, dst_tri)
-            warped_clothing = cv2.warpAffine(clothing_img, warp_mat, (width, height), flags=cv2.INTER_LINEAR,
-                                             borderMode=cv2.BORDER_TRANSPARENT)
-
-            # Blend clothing over body
-            if clothing_img.shape[2] == 4:  # has alpha
-                alpha = warped_clothing[:, :, 3] / 255.0
-                for c in range(3):
-                    output[:, :, c] = (1 - alpha) * output[:, :, c] + alpha * warped_clothing[:, :, c]
-            else:
-                output = cv2.addWeighted(output, 0.8, warped_clothing, 0.5, 0)
-
-    return output
-
-
-image_blocks = gr.Blocks(theme="Nymbo/Alyx_Theme").queue()
+        for name, index in landmark_indices.items():
+            lm = result.pose_landmarks.landmark[index]
+            x, y = int(lm.x * width), int(lm.y * height)
+            keypoints[name] = (x, y)
+
+            # Draw a circle + label for debug
+            cv2.circle(image, (x, y), 5, (0, 255, 0), -1)
+            cv2.putText(image, name, (x + 5, y - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
+
+    return image
+
+
+def process_image(human_img):
+    # Convert PIL image to NumPy array
+    human_img = np.array(human_img)
+
+    processed_image = detect_pose(human_img)
+    return processed_image
+
+
+image_blocks = gr.Blocks().queue()
 with image_blocks as demo:
     gr.HTML("<center><h1>Virtual Try-On</h1></center>")
     gr.HTML("<center><p>Upload an image of a person and an image of a garment ✨</p></center>")
     with gr.Row():
         with gr.Column():
-            imgs = gr.Image(type="pil", label='Human', interactive=True)
+            human_img = gr.Image(type="pil", label='Human', interactive=True)
             example = gr.Examples(
-                inputs=imgs,
+                inputs=human_img,
                 examples_per_page=10,
                 examples=human_list_path
            )
 
         with gr.Column():
-            garm_img = gr.Image(label="Garment", type="pil",interactive=True)
+            garm_img = gr.Image(label="Garment", type="pil", interactive=True)
             example = gr.Examples(
                 inputs=garm_img,
                 examples_per_page=8,
@@ -96,5 +86,9 @@ with image_blocks as demo:
     image_out = gr.Image(label="Processed image", type="pil")
 
     with gr.Row():
-        try_button = gr.Button(value="Try-on")
+        try_button = gr.Button(value="Try-on", variant='primary')
+
+    # Linking the button to the processing function
+    try_button.click(fn=process_image, inputs=human_img, outputs=image_out)
+
 image_blocks.launch()
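
Note on the pose setup: the new detect_pose function calls pose.process(image_rgb), but the pose object itself is created in unchanged lines of app.py that this diff does not show, so its exact configuration is not visible here. The snippet below is a minimal standalone sketch of how such a MediaPipe Pose setup and the same process-then-scale-landmarks flow typically look; the constructor arguments and the image path are assumptions for illustration, not values taken from the commit.

import cv2
import mediapipe as mp

mp_pose = mp.solutions.pose
mp_pose_landmark = mp_pose.PoseLandmark

# Assumed configuration: app.py builds `pose` in lines outside the diff.
# static_image_mode=True is the usual choice for single uploaded photos.
pose = mp_pose.Pose(static_image_mode=True, min_detection_confidence=0.5)

# Hypothetical example image; any photo of a person works.
image_bgr = cv2.imread("example/human/person.jpg")
image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)

result = pose.process(image_rgb)
if result.pose_landmarks:
    h, w, _ = image_bgr.shape
    lm = result.pose_landmarks.landmark[mp_pose_landmark.LEFT_SHOULDER]
    # Landmarks are normalized to [0, 1]; scale to pixels as detect_pose does.
    print("left shoulder at", int(lm.x * w), int(lm.y * h))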