phitran committed on
Commit 37ae657 · 1 Parent(s): 609a8ea

edit UI, add logic to detect keypoints

Files changed (3):
  1. app.py +79 -39
  2. example/cloth/test.txt +0 -0
  3. example/human/test.txt +0 -0
app.py CHANGED

@@ -2,6 +2,14 @@ import gradio as gr
 import cv2
 import numpy as np
 import mediapipe as mp
+import os
+
+example_path = os.path.join(os.path.dirname(__file__), 'example')
+garm_list = os.listdir(os.path.join(example_path, "cloth"))
+garm_list_path = [os.path.join(example_path, "cloth", garm) for garm in garm_list]
+
+human_list = os.listdir(os.path.join(example_path, "human"))
+human_list_path = [os.path.join(example_path, "human", human) for human in human_list]
 
 # Initialize MediaPipe Pose
 mp_pose = mp.solutions.pose
@@ -9,52 +17,84 @@ pose = mp_pose.Pose(static_image_mode=True)
 mp_drawing = mp.solutions.drawing_utils
 mp_pose_landmark = mp_pose.PoseLandmark
 
-def detect_pose(image):
-    # Convert to RGB
-    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-
-    # Run pose detection
+
+def align_clothing(body_img, clothing_img):
+    image_rgb = cv2.cvtColor(body_img, cv2.COLOR_BGR2RGB)
     result = pose.process(image_rgb)
-
+    output = body_img.copy()
     keypoints = {}
 
     if result.pose_landmarks:
-        # Draw landmarks on image
-        mp_drawing.draw_landmarks(image, result.pose_landmarks, mp_pose.POSE_CONNECTIONS)
-
-        # Get image dimensions
-        height, width, _ = image.shape
+        height, width, _ = output.shape
 
-        # Extract specific landmarks
-        landmark_indices = {
+        # Extract body keypoints
+        points = {
             'left_shoulder': mp_pose_landmark.LEFT_SHOULDER,
             'right_shoulder': mp_pose_landmark.RIGHT_SHOULDER,
-            'left_hip': mp_pose_landmark.LEFT_HIP,
-            'right_hip': mp_pose_landmark.RIGHT_HIP
+            'left_hip': mp_pose_landmark.LEFT_HIP
         }
 
-        for name, index in landmark_indices.items():
-            lm = result.pose_landmarks.landmark[index]
-            x, y = int(lm.x * width), int(lm.y * height)
-            keypoints[name] = (x, y)
-
-            # Draw a circle + label for debug
-            cv2.circle(image, (x, y), 5, (0, 255, 0), -1)
-            cv2.putText(image, name, (x + 5, y - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
-
-    return image, keypoints
-
-# Gradio interface
-iface = gr.Interface(
-    fn=detect_pose,
-    inputs=gr.Image(type="numpy", label="Upload Full-Body Image"),
-    outputs=[
-        gr.Image(type="numpy", label="Pose Visualization"),
-        gr.JSON(label="Extracted Keypoints")
-    ],
-    title="Virtual Try-On - Pose Detection",
-    description="Detects body keypoints using MediaPipe Pose and visualizes them. Shoulders and hips are labeled."
-)
-
-if __name__ == "__main__":
-    iface.launch()
+        for name, idx in points.items():
+            lm = result.pose_landmarks.landmark[idx]
+            keypoints[name] = (int(lm.x * width), int(lm.y * height))
+
+        # Draw for debug
+        for name, (x, y) in keypoints.items():
+            cv2.circle(output, (x, y), 5, (0, 255, 0), -1)
+            cv2.putText(output, name, (x + 5, y - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
+
+        # Affine transform: map the garment's corners onto the detected triangle
+        if all(k in keypoints for k in ['left_shoulder', 'right_shoulder', 'left_hip']):
+            src_tri = np.array([
+                [0, 0],
+                [clothing_img.shape[1], 0],
+                [0, clothing_img.shape[0]]
+            ], dtype=np.float32)
+
+            dst_tri = np.array([
+                keypoints['left_shoulder'],
+                keypoints['right_shoulder'],
+                keypoints['left_hip']
+            ], dtype=np.float32)
+
+            # Compute warp matrix and apply it
+            warp_mat = cv2.getAffineTransform(src_tri, dst_tri)
+            warped_clothing = cv2.warpAffine(clothing_img, warp_mat, (width, height), flags=cv2.INTER_LINEAR,
+                                             borderMode=cv2.BORDER_TRANSPARENT)
+
+            # Blend clothing over body
+            if clothing_img.shape[2] == 4:  # has alpha channel
+                alpha = warped_clothing[:, :, 3] / 255.0
+                for c in range(3):
+                    output[:, :, c] = (1 - alpha) * output[:, :, c] + alpha * warped_clothing[:, :, c]
+            else:
+                output = cv2.addWeighted(output, 0.8, warped_clothing, 0.5, 0)
+
+    return output
+
+
+image_blocks = gr.Blocks(theme="Nymbo/Alyx_Theme").queue()
+with image_blocks as demo:
+    gr.HTML("<center><h1>Virtual Try-On</h1></center>")
+    gr.HTML("<center><p>Upload an image of a person and an image of a garment ✨</p></center>")
+    with gr.Row():
+        with gr.Column():
+            imgs = gr.Image(type="pil", label='Human', interactive=True)
+            example = gr.Examples(
+                inputs=imgs,
+                examples_per_page=10,
+                examples=human_list_path
+            )
+
+        with gr.Column():
+            garm_img = gr.Image(label="Garment", type="pil", interactive=True)
+            example = gr.Examples(
+                inputs=garm_img,
+                examples_per_page=8,
+                examples=garm_list_path)
+        with gr.Column():
+            image_out = gr.Image(label="Processed image", type="pil")
+
+    with gr.Row():
+        try_button = gr.Button(value="Try-on")
+image_blocks.launch()
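
Editor's note: the core of the new logic is the three-point affine warp. As a quick illustration of what cv2.getAffineTransform does in this commit, the sketch below maps a dummy garment's top-left, top-right, and bottom-left corners onto made-up shoulder/hip coordinates, then verifies that the resulting 2x3 matrix sends each source corner exactly onto its destination point. Every image size and keypoint here is an illustrative stand-in, not a value from the commit.

import cv2
import numpy as np

# Illustrative stand-ins: a 200x150 dummy garment and hand-picked keypoints.
clothing_img = np.zeros((200, 150, 3), dtype=np.uint8)
keypoints = {
    'left_shoulder': (120, 180),
    'right_shoulder': (260, 185),
    'left_hip': (130, 420),
}

# Source triangle: garment corners (top-left, top-right, bottom-left).
src_tri = np.array([
    [0, 0],
    [clothing_img.shape[1], 0],
    [0, clothing_img.shape[0]],
], dtype=np.float32)

# Destination triangle: the detected body points.
dst_tri = np.array([
    keypoints['left_shoulder'],
    keypoints['right_shoulder'],
    keypoints['left_hip'],
], dtype=np.float32)

# getAffineTransform returns the unique 2x3 matrix taking src_tri to dst_tri.
warp_mat = cv2.getAffineTransform(src_tri, dst_tri)

# Sanity check: multiplying each source corner (in homogeneous coordinates)
# by the matrix lands it on the matching destination point.
src_h = np.hstack([src_tri, np.ones((3, 1), dtype=np.float32)])
assert np.allclose(src_h @ warp_mat.T, dst_tri)

# warpAffine applies the same mapping to every garment pixel, rendering the
# garment into the body image's coordinate frame (dsize is width, height).
warped = cv2.warpAffine(clothing_img, warp_mat, (500, 600))

Because an affine map is fully determined by three point pairs, a fourth landmark would be redundant here, which is presumably why this commit drops right_hip from the landmark dictionary.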
example/cloth/test.txt ADDED
File without changes
example/human/test.txt ADDED
File without changes
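
Editor's note: one loose end in this commit is that try_button is created but never connected, and align_clothing expects OpenCV-style BGR numpy arrays while the Gradio components use type="pil". A minimal sketch of how the wiring might look follows; the try_on helper and the PIL-to-OpenCV conversions are assumptions, not part of the commit, and the .click(...) call would need to sit inside the `with image_blocks` context before launch().

import cv2
import numpy as np
from PIL import Image

def try_on(human_pil, garment_pil):
    # Hypothetical glue: convert Gradio's PIL inputs to the BGR/BGRA arrays
    # align_clothing expects, run the alignment, and convert back for display.
    body_bgr = cv2.cvtColor(np.array(human_pil.convert("RGB")), cv2.COLOR_RGB2BGR)
    garment_bgra = cv2.cvtColor(np.array(garment_pil.convert("RGBA")), cv2.COLOR_RGBA2BGRA)
    result_bgr = align_clothing(body_bgr, garment_bgra)
    return Image.fromarray(cv2.cvtColor(result_bgr, cv2.COLOR_BGR2RGB))

# Inside the Blocks context, before image_blocks.launch():
try_button.click(fn=try_on, inputs=[imgs, garm_img], outputs=image_out)

Until wiring along these lines lands, clicking Try-on has no effect; the commit message scopes this change to the UI and the keypoint-detection logic.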