# football-minimap-generator / pose_estimator.py

import torch
import numpy as np
import cv2
from PIL import Image
from transformers import AutoProcessor, RTDetrForObjectDetection, VitPoseForPoseEstimation
from pathlib import Path
# --- Global variables for models and processor (lazy loading) ---
person_processor = None
person_model = None
pose_processor = None
pose_model = None
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Pose Estimator: Using device: {device}")

# --- Constants for color and drawing ---
# Colors are BGR tuples.
DEFAULT_MARKER_COLOR = (255, 255, 255)  # White
MIN_PIXELS_FOR_COLOR = 20  # Minimum number of valid pixels in the ROI to attempt computing a color
CONFIDENCE_THRESHOLD_KEYPOINTS = 0.3  # Threshold above which a keypoint is considered reliable for the ROI and for drawing
SKELETON_THICKNESS = 2

# Skeleton segment definitions (COCO indices 0-16)
# 0:Nose, 1:L_Eye, 2:R_Eye, 3:L_Ear, 4:R_Ear, 5:L_Shoulder, 6:R_Shoulder,
# 7:L_Elbow, 8:R_Elbow, 9:L_Wrist, 10:R_Wrist, 11:L_Hip, 12:R_Hip,
# 13:L_Knee, 14:R_Knee, 15:L_Ankle, 16:R_Ankle
SKELETON_EDGES = [
    # Head
    (0, 1), (0, 2), (1, 3), (2, 4),
    # Torso
    (5, 6), (5, 11), (6, 12), (11, 12),
    # Left arm
    (5, 7), (7, 9),
    # Right arm
    (6, 8), (8, 10),
    # Left leg
    (11, 13), (13, 15),
    # Right leg
    (12, 14), (14, 16)
]

# Keypoint indices for the torso and the ankles
TORSO_KP_INDICES = [5, 6, 11, 12]  # Shoulders, hips
LEFT_ANKLE_KP_INDEX = 15
RIGHT_ANKLE_KP_INDEX = 16
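
# NOTE: SKELETON_EDGES and SKELETON_THICKNESS are not referenced elsewhere in
# this module. The helper below is a minimal, hypothetical sketch of how a
# caller might use them to overlay a skeleton on a BGR frame; it is
# illustrative only, not the project's actual drawing code.
def draw_skeleton_example(frame_bgr, keypoints, scores, color=(0, 255, 0)):
    """Draws skeleton segments whose two endpoint keypoints are both reliable."""
    for a, b in SKELETON_EDGES:
        if scores[a] > CONFIDENCE_THRESHOLD_KEYPOINTS and scores[b] > CONFIDENCE_THRESHOLD_KEYPOINTS:
            pt_a = tuple(map(int, keypoints[a]))
            pt_b = tuple(map(int, keypoints[b]))
            cv2.line(frame_bgr, pt_a, pt_b, color, SKELETON_THICKNESS, cv2.LINE_AA)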

def _load_models():
    """Loads the models if they haven't been loaded yet."""
    global person_processor, person_model, pose_processor, pose_model
    if person_processor is None:
        print("Loading RTDetr person detector model...")
        person_processor = AutoProcessor.from_pretrained("PekingU/rtdetr_r50vd_coco_o365")
        person_model = RTDetrForObjectDetection.from_pretrained("PekingU/rtdetr_r50vd_coco_o365", device_map=device)
        print("✓ RTDetr loaded.")
    if pose_processor is None:
        print("Loading ViTPose pose estimation model...")
        pose_processor = AutoProcessor.from_pretrained("usyd-community/vitpose-base-simple")
        pose_model = VitPoseForPoseEstimation.from_pretrained("usyd-community/vitpose-base-simple", device_map=device)
        print("✓ ViTPose loaded.")

def _is_color_greenish(bgr_pixel, threshold=10):
    """Returns True if the pixel's green channel dominates blue and red (e.g. grass)."""
    # Cast to int first: pixels arrive as uint8, and uint8 arithmetic would
    # wrap around in the additions below (e.g. 250 + 10 -> 4).
    b, g, r = (int(c) for c in bgr_pixel)
    return g > b + threshold and g > r + threshold

def _is_color_grayscale(bgr_pixel, tolerance=30):
    """Returns True if the pixel is very dark, very light, or has low saturation."""
    b, g, r = (int(c) for c in bgr_pixel)
    min_val, max_val = min(b, g, r), max(b, g, r)
    is_dark = max_val < 50
    is_light = min_val > 200
    is_low_saturation = (max_val - min_val) < tolerance
    return is_dark or is_light or is_low_saturation
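
# Illustrative examples of what the two filters above reject or keep (BGR):
#   (60, 160, 60)   -> greenish (pitch grass): rejected
#   (30, 30, 30)    -> dark (shadows, black lines): rejected
#   (230, 230, 230) -> light (white lines, glare): rejected
#   (40, 40, 200)   -> e.g. a red shirt: kept for the average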

def _get_average_color(roi_bgr):
    """Computes the average color of an ROI after filtering out pitch and grayscale pixels."""
    if roi_bgr is None or roi_bgr.size == 0:
        return DEFAULT_MARKER_COLOR
    try:
        pixels = roi_bgr.reshape(-1, 3)
        valid_pixels = []
        for pixel in pixels:
            if not _is_color_greenish(pixel) and not _is_color_grayscale(pixel):
                valid_pixels.append(pixel)
        if len(valid_pixels) < MIN_PIXELS_FOR_COLOR:
            return DEFAULT_MARKER_COLOR
        avg_color = np.mean(valid_pixels, axis=0)
        return tuple(map(int, avg_color))
    except Exception as e:
        print(f"  Error computing average color: {e}. Falling back to default color.")
        return DEFAULT_MARKER_COLOR

def get_player_data(image_bgr: np.ndarray) -> list:
    """
    Detects persons, estimates pose, calculates average torso color,
    and returns a list of data for each player.

    Args:
        image_bgr: The input image in BGR format (NumPy array).

    Returns:
        A list of dictionaries, each containing:
        {
            'keypoints': np.ndarray (17, 2),
            'scores': np.ndarray (17,),
            'bbox': np.ndarray (4,) [x1, y1, x2, y2],
            'avg_color': tuple (b, g, r)
        }
        Returns an empty list if no persons are detected or an error occurs.
    """
    _load_models()
    player_list = []
    height, width = image_bgr.shape[:2]
    try:
        image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
        image_pil = Image.fromarray(image_rgb)

        # --- Stage 1: Detect humans ---
        inputs_det = person_processor(images=image_pil, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs_det = person_model(**inputs_det)
        results_det = person_processor.post_process_object_detection(
            outputs_det, target_sizes=torch.tensor([(height, width)]), threshold=0.5
        )
        result_det = results_det[0]
        person_boxes = result_det["boxes"][result_det["labels"] == 0].cpu().numpy()
        if len(person_boxes) == 0:
            print("No persons detected.")
            return player_list
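
        # The ViTPose processor expects COCO-format boxes (x, y, w, h), while
        # the detector returns corner format (x1, y1, x2, y2), so convert:
        # e.g. [100, 50, 180, 250] -> [100, 50, 80, 200].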
        person_boxes_coco = person_boxes.copy()
        person_boxes_coco[:, 2] = person_boxes_coco[:, 2] - person_boxes_coco[:, 0]
        person_boxes_coco[:, 3] = person_boxes_coco[:, 3] - person_boxes_coco[:, 1]

        # --- Stage 2: Detect keypoints ---
        inputs_pose = pose_processor(image_pil, boxes=[person_boxes_coco], return_tensors="pt").to(device)
        with torch.no_grad():
            outputs_pose = pose_model(**inputs_pose)
        pose_results = pose_processor.post_process_pose_estimation(outputs_pose, boxes=[person_boxes_coco])
        image_pose_result = pose_results[0]
        if not image_pose_result:
            print("Pose estimation did not return results.")
            return player_list

        # --- Stage 3: Process each person ---
        for i, person_box_xyxy in enumerate(person_boxes):
            if i >= len(image_pose_result):
                continue
            pose_result = image_pose_result[i]
            xy = pose_result['keypoints'].cpu().numpy()
            scores = pose_result['scores'].cpu().numpy()

            # Ensure xy shape is correct before proceeding
            if xy.shape != (17, 2):
                print(f"Person {i}: Unexpected keypoints shape {xy.shape}, skipping.")
                continue

            # -- Define Torso ROI --
            reliable_torso_keypoints = xy[TORSO_KP_INDICES][scores[TORSO_KP_INDICES] > CONFIDENCE_THRESHOLD_KEYPOINTS]
            x1_box, y1_box, x2_box, y2_box = map(int, person_box_xyxy)
            box_h = y2_box - y1_box
            box_w = x2_box - x1_box
            if len(reliable_torso_keypoints) >= 3:
                # Tight ROI around the reliable torso keypoints, padded by 5 px.
                min_x_kp = int(np.min(reliable_torso_keypoints[:, 0]))
                max_x_kp = int(np.max(reliable_torso_keypoints[:, 0]))
                min_y_kp = int(np.min(reliable_torso_keypoints[:, 1]))
                max_y_kp = int(np.max(reliable_torso_keypoints[:, 1]))
                roi_x1 = max(x1_box, min_x_kp - 5)
                roi_y1 = max(y1_box, min_y_kp - 5)
                roi_x2 = min(x2_box, max_x_kp + 5)
                roi_y2 = min(y2_box, max_y_kp + 5)
            else:
                # Fallback: use the 10%-60% height band of the detection box.
                roi_x1 = x1_box
                roi_y1 = y1_box + int(0.1 * box_h)
                roi_x2 = x2_box
                roi_y2 = y1_box + int(0.6 * box_h)
            roi_x1 = max(0, roi_x1)
            roi_y1 = max(0, roi_y1)
            roi_x2 = min(width, roi_x2)
            roi_y2 = min(height, roi_y2)

            # -- Extract Average Color --
            avg_color = DEFAULT_MARKER_COLOR
            if roi_y2 > roi_y1 and roi_x2 > roi_x1:
                torso_roi = image_bgr[roi_y1:roi_y2, roi_x1:roi_x2]
                avg_color = _get_average_color(torso_roi)
            # An invalid ROI silently falls back to the default color.

            # -- Store player data --
            player_data = {
                'keypoints': xy,
                'scores': scores,
                'bbox': person_box_xyxy,  # Keep the original xyxy bbox
                'avg_color': avg_color
            }
            player_list.append(player_data)

    except Exception as e:
        print(f"Error during player data extraction: {e}")
        import traceback
        traceback.print_exc()
        # Return an empty list on any unrecoverable error
        return []

    return player_list
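
# --- Hypothetical downstream usage (illustrative sketch, not used below) ---
# A minimap step would typically map each player's ground reference point
# through an image-to-pitch homography. The helper below is a minimal sketch
# under that assumption; `homography` (a 3x3 matrix, e.g. obtained from
# cv2.findHomography on known pitch landmarks) is NOT computed in this module.
def project_players_to_pitch(players: list, homography: np.ndarray) -> np.ndarray:
    """Projects each player's bbox bottom-center into pitch/minimap coordinates."""
    if not players:
        return np.empty((0, 2), dtype=np.float32)
    points = []
    for p in players:
        x1, y1, x2, y2 = p['bbox']
        points.append([(x1 + x2) / 2.0, float(y2)])  # feet position in image space
    pts = np.array(points, dtype=np.float32).reshape(-1, 1, 2)
    # cv2.perspectiveTransform applies the 3x3 homography with perspective division.
    return cv2.perspectiveTransform(pts, homography).reshape(-1, 2)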

# Example usage (optional, for testing the module directly)
if __name__ == '__main__':
    test_image_path = 'img3.png'
    if not Path(test_image_path).exists():
        print(f"Test image not found: {test_image_path}")
    else:
        print(f"Testing player data extraction with image: {test_image_path}")
        input_img = cv2.imread(test_image_path)
        if input_img is None:
            print(f"Failed to load test image: {test_image_path}")
        else:
            print("Getting player data...")
            players = get_player_data(input_img)
            print(f"✓ Found data for {len(players)} players.")

            # --- Draw markers and info on original image for testing ---
            output_img_test = input_img.copy()
            for idx, p_data in enumerate(players):
                kps = p_data['keypoints']
                scores = p_data['scores']
                bbox = p_data['bbox']
                color = p_data['avg_color']

                # Determine reference point (ankles or bbox bottom mid)
                l_ankle_pt = kps[LEFT_ANKLE_KP_INDEX]
                r_ankle_pt = kps[RIGHT_ANKLE_KP_INDEX]
                l_ankle_score = scores[LEFT_ANKLE_KP_INDEX]
                r_ankle_score = scores[RIGHT_ANKLE_KP_INDEX]
                ref_point = None
                if l_ankle_score > CONFIDENCE_THRESHOLD_KEYPOINTS and r_ankle_score > CONFIDENCE_THRESHOLD_KEYPOINTS:
                    ref_point = tuple(map(int, (l_ankle_pt + r_ankle_pt) / 2))
                elif l_ankle_score > CONFIDENCE_THRESHOLD_KEYPOINTS:
                    ref_point = tuple(map(int, l_ankle_pt))
                elif r_ankle_score > CONFIDENCE_THRESHOLD_KEYPOINTS:
                    ref_point = tuple(map(int, r_ankle_pt))
                else:
                    x1, y1, x2, y2 = map(int, bbox)
                    ref_point = (int((x1 + x2) / 2), y2)

                # Draw marker at reference point
                if ref_point:
                    cv2.circle(output_img_test, ref_point, 8, color, -1, cv2.LINE_AA)
                    cv2.circle(output_img_test, ref_point, 8, (0, 0, 0), 1, cv2.LINE_AA)  # Black outline
                    # Draw player index (black halo under white text for legibility)
                    cv2.putText(output_img_test, str(idx), (ref_point[0] + 5, ref_point[1] - 5),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 2, cv2.LINE_AA)
                    cv2.putText(output_img_test, str(idx), (ref_point[0] + 5, ref_point[1] - 5),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1, cv2.LINE_AA)
cv2.imshow("Original Image", input_img)
cv2.imshow("Player Markers Test", output_img_test)
print("Displaying test results. Press any key to exit.")
cv2.waitKey(0)
cv2.destroyAllWindows()