import gradio as gr import cv2 import numpy as np import torch from pathlib import Path import time import traceback # Import YOLO from ultralytics import YOLO # Importer les éléments nécessaires depuis les autres modules du projet try: from tvcalib.infer.module import TvCalibInferModule # On essaie d'importer la fonction de pré-traitement depuis main.py # Si main.py n'est pas conçu pour être importé, il faudra peut-être copier/coller cette fonction ici # Importer aussi les constantes YOLO depuis main.py (ou les redéfinir ici) from main import preprocess_image_tvcalib, IMAGE_SHAPE, SEGMENTATION_MODEL_PATH, YOLO_MODEL_PATH, BALL_CLASS_INDEX from visualizer import ( create_minimap_view, create_minimap_with_offset_skeletons, DYNAMIC_SCALE_MIN_MODULATION, DYNAMIC_SCALE_MAX_MODULATION ) from pose_estimator import get_player_data except ImportError as e: print(f"Erreur d'importation : {e}") print("Assurez-vous que les modules tvcalib, main, visualizer, pose_estimator sont accessibles.") # On pourrait mettre des stubs ou lever une exception ici pour Gradio raise e # --- Configuration Globale (Modèle, etc.) --- # Essayer de charger le modèle une seule fois globalement peut améliorer les performances # mais attention à la gestion de l'état dans les environnements multi-utilisateurs/threads de Spaces # Pour l'instant, on le chargera dans la fonction de traitement. # MODEL = None # Optionnel: Charger ici # DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu' print(f"Utilisation du device : {DEVICE}") if not SEGMENTATION_MODEL_PATH.exists(): print(f"AVERTISSEMENT : Modèle de segmentation introuvable : {SEGMENTATION_MODEL_PATH}") print("L'application risque de ne pas fonctionner. Assurez-vous que le fichier est présent.") # Gradio peut quand même démarrer, mais le traitement échouera. # Vérifier si le modèle YOLO existe aussi if not YOLO_MODEL_PATH.exists(): print(f"AVERTISSEMENT : Modèle YOLO introuvable : {YOLO_MODEL_PATH}") print("L'application risque de ne pas fonctionner. Assurez-vous que le fichier est présent.") # --- Fonction Principale de Traitement --- def process_image_and_generate_minimaps(input_image_bgr, optim_steps, target_avg_scale): """ Prend une image BGR (NumPy), les étapes d'optimisation et l'échelle cible, retourne les deux minimaps (NumPy BGR). """ global DEVICE # Utiliser le device défini globalement print("\n--- Nouvelle requête ---") print(f"Paramètres: optim_steps={optim_steps}, target_avg_scale={target_avg_scale}") # Vérifier si le modèle de segmentation existe (important car on ne peut pas l'afficher dans l'UI facilement) if not SEGMENTATION_MODEL_PATH.exists(): # Retourner des images noires ou des messages d'erreur error_msg = f"Erreur: Modèle {SEGMENTATION_MODEL_PATH.name} introuvable." print(error_msg) # Créer un placeholder plus informatif placeholder = np.zeros((300, 500, 3), dtype=np.uint8) cv2.putText(placeholder, "Model Error:", (10, 130), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 1, cv2.LINE_AA) cv2.putText(placeholder, error_msg, (10, 155), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1, cv2.LINE_AA) return placeholder, placeholder.copy() # Retourner deux placeholders # Vérifier aussi le modèle YOLO if not YOLO_MODEL_PATH.exists(): error_msg = f"Erreur: Modèle {YOLO_MODEL_PATH.name} introuvable." print(error_msg) placeholder = np.zeros((300, 500, 3), dtype=np.uint8) cv2.putText(placeholder, "Model Error:", (10, 130), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 1, cv2.LINE_AA) cv2.putText(placeholder, error_msg, (10, 155), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1, cv2.LINE_AA) return placeholder, placeholder.copy() try: # 1. Initialisation du modèle TvCalib (peut être lent si fait à chaque fois) # Pourrait être optimisé en chargeant globalement (voir commentaire plus haut) print("Initialisation de TvCalibInferModule...") start_init = time.time() model = TvCalibInferModule( segmentation_checkpoint=SEGMENTATION_MODEL_PATH, image_shape=IMAGE_SHAPE, # Utilise la constante importée optim_steps=int(optim_steps), # Assurer que c'est un entier lens_dist=False ) # Déplacer le modèle sur le bon device ici explicitement si nécessaire # model.to(DEVICE) # TvCalibInferModule devrait gérer ça en interne ? A vérifier. print(f"✓ Modèle chargé sur {next(model.model_calib.parameters()).device} en {time.time() - start_init:.3f}s") model_device = next(model.model_calib.parameters()).device # Vérifier le device réel # 2. Prétraitement de l'image print("Prétraitement de l'image...") start_preprocess = time.time() # preprocess_image_tvcalib attend BGR, Gradio fournit BGR par défaut avec type="numpy" # Assurez-vous que preprocess_image_tvcalib déplace bien le tenseur sur le bon device image_tensor, image_bgr_resized, image_rgb_resized = preprocess_image_tvcalib(input_image_bgr) # Vérifier/forcer le device du tenseur image_tensor = image_tensor.to(model_device) print(f"Temps de prétraitement TvCalib : {time.time() - start_preprocess:.3f}s") # --- Détection du ballon avec YOLO --- print("Chargement du modèle YOLO et détection du ballon...") start_yolo = time.time() ball_ref_point_img = None # Point de référence du ballon sur l'image originale redimensionnée try: # Charger le modèle YOLO (pourrait être chargé globalement pour la perf, mais attention) yolo_model = YOLO(YOLO_MODEL_PATH) # Utiliser l'image BGR redimensionnée pour YOLO results = yolo_model.predict(image_bgr_resized, classes=[BALL_CLASS_INDEX], verbose=False) if results and len(results[0].boxes) > 0: # Prendre la détection avec la plus haute confiance best_ball_box = results[0].boxes[results[0].boxes.conf.argmax()] x1, y1, x2, y2 = map(int, best_ball_box.xyxy[0].tolist()) conf = best_ball_box.conf[0].item() # Calculer le point de référence (centre bas de la bbox) ball_ref_point_img = np.array([(x1 + x2) / 2, y2], dtype=np.float32) print(f" ✓ Ballon trouvé (conf: {conf:.2f}) à la bbox [{x1},{y1},{x2},{y2}]. Point réf: {ball_ref_point_img}") else: print(" Aucun ballon détecté.") except Exception as e_yolo: print(f" Erreur pendant la détection YOLO : {e_yolo}") print(f"Temps de détection YOLO : {time.time() - start_yolo:.3f}s") # 3. Exécuter la calibration (Segmentation + Optimisation) print("Exécution de la segmentation...") start_segment = time.time() with torch.no_grad(): keypoints = model._segment(image_tensor) print(f"Temps de segmentation : {time.time() - start_segment:.3f}s") print("Exécution de la calibration (optimisation)...") start_calibrate = time.time() homography = model._calibrate(keypoints) print(f"Temps de calibration : {time.time() - start_calibrate:.3f}s") if homography is None: print("Aucune homographie n'a pu être calculée.") # Retourner des placeholders avec message placeholder = np.zeros((300, 500, 3), dtype=np.uint8) cv2.putText(placeholder, "Homographie non calculee", (10, 150), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1, cv2.LINE_AA) return placeholder, placeholder.copy() if isinstance(homography, torch.Tensor): homography_np = homography.detach().cpu().numpy() else: homography_np = np.array(homography) # Assurer que c'est un NumPy array print("✓ Homographie calculée.") # 4. Extraction des données joueurs print("Extraction des données joueurs (pose+couleur)...") start_pose = time.time() # get_player_data attend une image BGR player_list = get_player_data(image_bgr_resized) print(f"Temps d'extraction données joueurs : {time.time() - start_pose:.3f}s ({len(player_list)} joueurs trouvés)") # 5. Calcul de l'échelle de base print("Calcul de l'échelle de base...") # Reprend la logique de main.py pour estimer l'échelle de base avg_modulation_expected = DYNAMIC_SCALE_MIN_MODULATION + \ (DYNAMIC_SCALE_MAX_MODULATION - DYNAMIC_SCALE_MIN_MODULATION) * (1.0 - 0.5) estimated_base_scale = target_avg_scale if avg_modulation_expected != 0: estimated_base_scale = target_avg_scale / avg_modulation_expected print(f" Échelle de base interne estimée pour cible {target_avg_scale:.3f} : {estimated_base_scale:.3f}") # 6. Génération des minimaps print("Génération des minimaps...") start_viz = time.time() # Minimap avec projection (image RGB attendue par la fonction) minimap_original = create_minimap_view(image_rgb_resized, homography_np) # Minimap avec squelettes ET LE BALLON (utilise l'échelle estimée) minimap_offset_skeletons, actual_avg_scale = create_minimap_with_offset_skeletons( player_list, homography_np, base_skeleton_scale=estimated_base_scale, ball_ref_point_img=ball_ref_point_img # Passer le point de référence du ballon ) print(f"Temps de génération des minimaps : {time.time() - start_viz:.3f}s") if actual_avg_scale is not None: print(f"Échelle moyenne CIBLE demandée : {target_avg_scale:.3f}") print(f"Échelle moyenne FINALE RÉELLEMENT appliquée : {actual_avg_scale:.3f}") # Vérifier si les minimaps ont été créées (peuvent être None en cas d'erreur interne) if minimap_original is None: print("Erreur: La minimap originale n'a pas pu être générée.") minimap_original = np.zeros((300, 500, 3), dtype=np.uint8) cv2.putText(minimap_original, "Erreur Minimap Originale", (10, 150), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1, cv2.LINE_AA) if minimap_offset_skeletons is None: print("Erreur: La minimap squelettes n'a pas pu être générée.") minimap_offset_skeletons = np.zeros((300, 500, 3), dtype=np.uint8) cv2.putText(minimap_offset_skeletons, "Erreur Minimap Squelettes", (10, 150), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1, cv2.LINE_AA) # Gradio attend des images RGB pour l'affichage, nos fonctions retournent probablement BGR (via OpenCV) # Conversion BGR -> RGB si nécessaire if minimap_original.shape[2] == 3: # Assurer que c'est une image couleur minimap_original = cv2.cvtColor(minimap_original, cv2.COLOR_BGR2RGB) if minimap_offset_skeletons.shape[2] == 3: minimap_offset_skeletons = cv2.cvtColor(minimap_offset_skeletons, cv2.COLOR_BGR2RGB) print("✓ Traitement terminé.") return minimap_original, minimap_offset_skeletons except Exception as e: print(f"Erreur majeure lors du traitement : {e}") traceback.print_exc() # Retourner des placeholders avec message d'erreur général placeholder = np.zeros((300, 500, 3), dtype=np.uint8) cv2.putText(placeholder, f"Erreur: {e}", (10, 150), cv2.FONT_HERSHEY_SIMPLEX, 0.4, (0, 0, 255), 1, cv2.LINE_AA) return placeholder, placeholder.copy() # --- Interface Gradio --- with gr.Blocks() as demo: gr.Markdown("# Foot Calib Pos Image Processor - Minimap Generator") gr.Markdown( "Upload a football pitch image to compute homography (TvCalib), " "detect players (RT-DETR/ViTPose), and generate two minimap visualizations." ) with gr.Row(): with gr.Column(scale=1): input_image = gr.Image(type="numpy", label="Input Image (.jpg, .png)") optim_steps_slider = gr.Slider( minimum=100, maximum=2000, step=50, value=500, label="TvCalib Optimization Steps", info="Number of iterations to refine homography." ) target_scale_slider = gr.Slider( minimum=0.1, maximum=2.5, step=0.05, value=1, label="Target Average Skeleton Scale", info="Adjusts the desired average size of skeletons on the minimap." ) submit_button = gr.Button("Generate Minimaps", variant="primary") with gr.Column(scale=2): output_minimap_orig = gr.Image(type="numpy", label="Minimap with Original Projection", interactive=False) output_minimap_skel = gr.Image(type="numpy", label="Minimap with Offset Skeletons", interactive=False) # Connecter le bouton à la fonction de traitement submit_button.click( fn=process_image_and_generate_minimaps, inputs=[input_image, optim_steps_slider, target_scale_slider], outputs=[output_minimap_orig, output_minimap_skel] ) # Ajouter des exemples (optionnel mais utile pour Spaces) gr.Examples( examples=[ ["data/img1.png", 500, 1.35], ["data/img2.png", 1000, 1.5], ["data/img3.png", 500, 0.8], ["data/7.jpg", 500, 1], # Add .jpg examples ["data/15.jpg", 800, 1.35], ], inputs=[input_image, optim_steps_slider, target_scale_slider], outputs=[output_minimap_orig, output_minimap_skel], # Outputs won't be pre-calculated here, just to populate inputs fn=process_image_and_generate_minimaps, # Function will be called if example is clicked cache_examples=False # Important if processing is long or depends on external models ) # --- Lancement de l'application --- if __name__ == "__main__": # share=True creates a temporary public link (useful for testing outside localhost) # debug=True shows more Gradio logs in the console demo.launch(debug=True)