|
import torch |
|
import numpy as np |
|
import cv2 |
|
from PIL import Image |
|
from transformers import AutoProcessor, RTDetrForObjectDetection, VitPoseForPoseEstimation |
|
from pathlib import Path |
|
|
|
|
|
# Model handles start out empty; _load_models() fills them on first use so that
# importing this module does not trigger any model download.
person_processor = None
person_model = None
pose_processor = None
pose_model = None

# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print(f"Pose Estimator: Using device: {device}")
|
|
|
|
|
|
|
# Colour (BGR) used whenever a player's torso colour cannot be estimated.
DEFAULT_MARKER_COLOR = (255, 255, 255)
# Fewest surviving (non-green, non-grey) pixels needed to trust an average colour.
MIN_PIXELS_FOR_COLOR = 20
# Keypoint scores at or below this value are treated as unreliable.
CONFIDENCE_THRESHOLD_KEYPOINTS = 0.3
# Line thickness for skeleton drawing (not used inside this module itself).
SKELETON_THICKNESS = 2

# Pairs of keypoint indices to connect when drawing a skeleton.
# NOTE(review): indices appear to follow the 17-point COCO layout that the
# ankle constants below also assume; confirm if the pose model changes.
SKELETON_EDGES = [
    (0, 1), (0, 2), (1, 3), (2, 4),        # head
    (5, 6), (5, 11), (6, 12), (11, 12),    # torso quadrilateral
    (5, 7), (7, 9),                        # one arm
    (6, 8), (8, 10),                       # other arm
    (11, 13), (13, 15),                    # one leg
    (12, 14), (14, 16),                    # other leg
]

# Keypoints whose bounding rectangle defines the torso colour-sampling region.
TORSO_KP_INDICES = [5, 6, 11, 12]
# Ankle keypoints used to place a player's ground-contact marker.
LEFT_ANKLE_KP_INDEX = 15
RIGHT_ANKLE_KP_INDEX = 16
|
|
|
def _load_models():
    """Load the RT-DETR person detector and the ViTPose estimator on first use.

    Each processor/model pair is loaded into locals and published to the module
    globals only after both halves succeed. If a download fails part-way, the
    exception propagates and the next call retries. The module is never left
    with a processor but no model, a state the old ``processor is None`` guard
    would never repair.
    """
    global person_processor, person_model, pose_processor, pose_model

    detector_checkpoint = "PekingU/rtdetr_r50vd_coco_o365"
    pose_checkpoint = "usyd-community/vitpose-base-simple"

    if person_processor is None or person_model is None:
        print("Loading RTDetr person detector model...")
        processor = AutoProcessor.from_pretrained(detector_checkpoint)
        model = RTDetrForObjectDetection.from_pretrained(detector_checkpoint, device_map=device)
        person_processor, person_model = processor, model
        print("✓ RTDetr loaded.")

    if pose_processor is None or pose_model is None:
        print("Loading ViTPose pose estimation model...")
        processor = AutoProcessor.from_pretrained(pose_checkpoint)
        model = VitPoseForPoseEstimation.from_pretrained(pose_checkpoint, device_map=device)
        pose_processor, pose_model = processor, model
        print("✓ ViTPose loaded.")
|
|
|
def _is_color_greenish(bgr_pixel, threshold=10): |
|
b, g, r = bgr_pixel |
|
return g > b + threshold and g > r + threshold |
|
|
|
def _is_color_grayscale(bgr_pixel, tolerance=30): |
|
b, g, r = bgr_pixel |
|
min_val, max_val = min(b, g, r), max(b, g, r) |
|
is_dark = max_val < 50 |
|
is_light = min_val > 200 |
|
is_low_saturation = (max_val - min_val) < tolerance |
|
return is_dark or is_light or is_low_saturation |
|
|
|
def _get_average_color(roi_bgr, *, green_threshold=10, gray_tolerance=30):
    """Compute the mean BGR colour of a region after filtering out some pixels.

    Two kinds of pixel are discarded:
    - greenish: green exceeds both blue and red by more than *green_threshold*;
    - colourless: all channels below 50, all above 200, or a channel spread
      below *gray_tolerance*.
    These defaults match ``_is_color_greenish`` and ``_is_color_grayscale``.

    Args:
        roi_bgr: Image region whose last axis holds (b, g, r) channels.
        green_threshold: Green-dominance margin for the greenish filter.
        gray_tolerance: Maximum channel spread still counted as colourless.

    Returns:
        A ``(b, g, r)`` tuple of ints (means truncated toward zero), or
        DEFAULT_MARKER_COLOR when the region is empty, too few pixels survive
        the filters, or the computation fails.
    """
    if roi_bgr is None or roi_bgr.size == 0:
        return DEFAULT_MARKER_COLOR

    try:
        # Vectorised filter over all pixels at once. Widening to float64 keeps
        # `channel + threshold` from wrapping around the uint8 range.
        pixels = roi_bgr.reshape(-1, 3).astype(np.float64)
        blue, green, red = pixels[:, 0], pixels[:, 1], pixels[:, 2]
        greenish = (green > blue + green_threshold) & (green > red + green_threshold)

        lo = pixels.min(axis=1)
        hi = pixels.max(axis=1)
        colourless = (hi < 50) | (lo > 200) | ((hi - lo) < gray_tolerance)

        valid_pixels = pixels[~(greenish | colourless)]
        if len(valid_pixels) < MIN_PIXELS_FOR_COLOR:
            return DEFAULT_MARKER_COLOR

        avg_color = valid_pixels.mean(axis=0)
        return tuple(int(channel) for channel in avg_color)

    except Exception as e:
        # Best effort: a colour failure should never abort player extraction.
        print(f" Erreur calcul couleur moyenne: {e}. Utilisation couleur défaut.")
        return DEFAULT_MARKER_COLOR
|
|
|
def _detect_person_boxes(image_pil, height, width, threshold=0.5):
    """Run the RT-DETR detector on *image_pil* and return the person boxes.

    Args:
        image_pil: RGB PIL image to run detection on.
        height: Height of the original image, in pixels.
        width: Width of the original image, in pixels. Both are passed as
            ``target_sizes`` so the boxes are rescaled to the original image.
        threshold: Score threshold forwarded to ``post_process_object_detection``.

    Returns:
        NumPy array of boxes in [x1, y1, x2, y2] format, one row per person
        detection (possibly zero rows).
    """
    inputs_det = person_processor(images=image_pil, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs_det = person_model(**inputs_det)
    result_det = person_processor.post_process_object_detection(
        outputs_det, target_sizes=torch.tensor([(height, width)]), threshold=threshold
    )[0]
    # Keep only label id 0, which this code treats as the COCO "person" class.
    return result_det["boxes"][result_det["labels"] == 0].cpu().numpy()


def _xyxy_to_xywh(boxes_xyxy):
    """Return a copy of (N, 4) [x1, y1, x2, y2] boxes converted to [x, y, w, h]."""
    boxes_xywh = boxes_xyxy.copy()
    boxes_xywh[:, 2] -= boxes_xywh[:, 0]
    boxes_xywh[:, 3] -= boxes_xywh[:, 1]
    return boxes_xywh


def _estimate_poses(image_pil, boxes_xywh):
    """Run ViTPose on the given person boxes and return this image's results.

    Returns the first element of ``post_process_pose_estimation``'s output.
    That element is a sequence of per-person dicts, each with ``'keypoints'``
    and ``'scores'`` tensors (presumably one entry per box, in box order).
    """
    inputs_pose = pose_processor(image_pil, boxes=[boxes_xywh], return_tensors="pt").to(device)
    with torch.no_grad():
        outputs_pose = pose_model(**inputs_pose)
    pose_results = pose_processor.post_process_pose_estimation(outputs_pose, boxes=[boxes_xywh])
    return pose_results[0]


def _torso_roi(keypoints, scores, box_xyxy, width, height):
    """Return the torso sampling rectangle ``(x1, y1, x2, y2)`` for one person.

    If at least 3 of the TORSO_KP_INDICES keypoints score above
    CONFIDENCE_THRESHOLD_KEYPOINTS, the ROI is their bounding rectangle padded
    by 5 px and clamped to the person box. Otherwise it is a fixed band of the
    person box: full width, from 10% to 60% of the box height. The result is
    clipped to the image but may be empty (x2 <= x1 or y2 <= y1).
    """
    torso_mask = scores[TORSO_KP_INDICES] > CONFIDENCE_THRESHOLD_KEYPOINTS
    reliable = keypoints[TORSO_KP_INDICES][torso_mask]
    x1_box, y1_box, x2_box, y2_box = map(int, box_xyxy)

    if len(reliable) >= 3:
        pad = 5
        roi_x1 = max(x1_box, int(np.min(reliable[:, 0])) - pad)
        roi_y1 = max(y1_box, int(np.min(reliable[:, 1])) - pad)
        roi_x2 = min(x2_box, int(np.max(reliable[:, 0])) + pad)
        roi_y2 = min(y2_box, int(np.max(reliable[:, 1])) + pad)
    else:
        box_h = y2_box - y1_box
        roi_x1, roi_x2 = x1_box, x2_box
        roi_y1 = y1_box + int(0.1 * box_h)
        roi_y2 = y1_box + int(0.6 * box_h)

    return max(0, roi_x1), max(0, roi_y1), min(width, roi_x2), min(height, roi_y2)


def get_player_data(image_bgr: np.ndarray, *, detection_threshold: float = 0.5) -> list:
    """
    Detects persons, estimates pose, calculates average torso color,
    and returns a list of data for each player.

    Args:
        image_bgr: The input image in BGR format (NumPy array, H x W x 3).
        detection_threshold: Person-detection score threshold (keyword-only).

    Returns:
        A list of dictionaries, each containing:
        {
            'keypoints': np.ndarray (17, 2),
            'scores': np.ndarray (17,),
            'bbox': np.ndarray (4,) [x1, y1, x2, y2],
            'avg_color': tuple (b, g, r)
        }
        Returns an empty list if no persons are detected or an error occurs
        during inference. Model-loading errors, and errors reading the image
        shape, are raised to the caller.
    """
    _load_models()
    player_list = []
    height, width = image_bgr.shape[:2]

    try:
        image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
        image_pil = Image.fromarray(image_rgb)

        person_boxes = _detect_person_boxes(image_pil, height, width, detection_threshold)
        if len(person_boxes) == 0:
            print("No persons detected.")
            return player_list

        # ViTPose expects [x, y, w, h] boxes.
        image_pose_result = _estimate_poses(image_pil, _xyxy_to_xywh(person_boxes))
        if not image_pose_result:
            print("Pose estimation did not return results.")
            return player_list

        # zip() stops at the shorter sequence, so boxes without a pose result are skipped.
        for i, (person_box_xyxy, pose_result) in enumerate(zip(person_boxes, image_pose_result)):
            xy = pose_result['keypoints'].cpu().numpy()
            scores = pose_result['scores'].cpu().numpy()

            if xy.shape != (17, 2):
                print(f"Person {i}: Unexpected keypoints shape {xy.shape}, skipping.")
                continue

            roi_x1, roi_y1, roi_x2, roi_y2 = _torso_roi(xy, scores, person_box_xyxy, width, height)
            avg_color = DEFAULT_MARKER_COLOR
            if roi_y2 > roi_y1 and roi_x2 > roi_x1:
                avg_color = _get_average_color(image_bgr[roi_y1:roi_y2, roi_x1:roi_x2])

            player_list.append({
                'keypoints': xy,
                'scores': scores,
                'bbox': person_box_xyxy,
                'avg_color': avg_color,
            })

    except Exception as e:
        # Best effort: log the failure with a traceback and report "no players".
        print(f"Error during player data extraction: {e}")
        import traceback
        traceback.print_exc()
        return []

    return player_list
|
|
|
|
|
if __name__ == '__main__':
    # Manual smoke test: annotate each detected player with a marker filled
    # with its torso colour, placed between the ankles when they are visible.
    test_image_path = 'img3.png'
    source_image = None

    if Path(test_image_path).exists():
        print(f"Testing player data extraction with image: {test_image_path}")
        source_image = cv2.imread(test_image_path)
        if source_image is None:
            print(f"Failed to load test image: {test_image_path}")
    else:
        print(f"Test image not found: {test_image_path}")

    if source_image is not None:
        print("Getting player data...")
        players = get_player_data(source_image)
        print(f"✓ Found data for {len(players)} players.")

        annotated = source_image.copy()
        ankle_indices = (LEFT_ANKLE_KP_INDEX, RIGHT_ANKLE_KP_INDEX)

        for player_idx, player in enumerate(players):
            keypoints = player['keypoints']
            kp_scores = player['scores']
            bbox = player['bbox']
            marker_color = player['avg_color']

            # Confident ankles, left before right.
            ankles = [keypoints[k] for k in ankle_indices
                      if kp_scores[k] > CONFIDENCE_THRESHOLD_KEYPOINTS]
            if len(ankles) == 2:
                anchor = tuple(map(int, (ankles[0] + ankles[1]) / 2))
            elif ankles:
                anchor = tuple(map(int, ankles[0]))
            else:
                # No usable ankle: fall back to the bottom-centre of the box.
                bx1, _, bx2, by2 = map(int, bbox)
                anchor = (int((bx1 + bx2) / 2), by2)

            cv2.circle(annotated, anchor, 8, marker_color, -1, cv2.LINE_AA)
            cv2.circle(annotated, anchor, 8, (0,0,0), 1, cv2.LINE_AA)

            # Thick black pass, then thin white pass, so the index is readable anywhere.
            label = str(player_idx)
            text_origin = (anchor[0]+5, anchor[1]-5)
            for text_color, text_thickness in (((0,0,0), 2), ((255,255,255), 1)):
                cv2.putText(annotated, label, text_origin,
                            cv2.FONT_HERSHEY_SIMPLEX, 0.5, text_color, text_thickness, cv2.LINE_AA)

        cv2.imshow("Original Image", source_image)
        cv2.imshow("Player Markers Test", annotated)
        print("Displaying test results. Press any key to exit.")
        cv2.waitKey(0)
        cv2.destroyAllWindows()