RamziBm committed on
Commit
bdb955e
·
0 Parent(s):
Files changed (46)
  1. .gitattributes +2 -0
  2. .gitignore +12 -0
  3. README.md +107 -0
  4. app.py +234 -0
  5. common/data/augmentation.py +55 -0
  6. common/data/calib.py +116 -0
  7. common/data/transforms.py +23 -0
  8. common/data/utils.py +72 -0
  9. common/infer/base.py +43 -0
  10. common/infer/module.py +24 -0
  11. common/infer/sink.py +42 -0
  12. common/loggers/homography_previewer.py +103 -0
  13. common/loggers/image_preview.py +93 -0
  14. main.py +199 -0
  15. pose_estimator.py +265 -0
  16. requirements.txt +22 -0
  17. tvcalib/cam_distr/tv_main_behind.py +77 -0
  18. tvcalib/cam_distr/tv_main_center.py +78 -0
  19. tvcalib/cam_distr/tv_main_left.py +77 -0
  20. tvcalib/cam_distr/tv_main_right.py +77 -0
  21. tvcalib/cam_distr/tv_main_tribune.py +77 -0
  22. tvcalib/cam_modules.py +583 -0
  23. tvcalib/data/dataset.py +142 -0
  24. tvcalib/data/utils.py +166 -0
  25. tvcalib/infer/module.py +518 -0
  26. tvcalib/models/segmentation.py +22 -0
  27. tvcalib/sn_segmentation/resources/mean.npy +0 -0
  28. tvcalib/sn_segmentation/resources/std.npy +0 -0
  29. tvcalib/sn_segmentation/src/baseline_extremities.py +311 -0
  30. tvcalib/sn_segmentation/src/custom_extremities.py +322 -0
  31. tvcalib/sn_segmentation/src/dataloader.py +122 -0
  32. tvcalib/sn_segmentation/src/evaluate_extremities.py +270 -0
  33. tvcalib/sn_segmentation/src/masks_gt2chen.py +217 -0
  34. tvcalib/sn_segmentation/src/masks_pred2chen.py +150 -0
  35. tvcalib/sn_segmentation/src/segmentation/README.md +23 -0
  36. tvcalib/sn_segmentation/src/segmentation/coco_utils.py +108 -0
  37. tvcalib/sn_segmentation/src/segmentation/presets.py +39 -0
  38. tvcalib/sn_segmentation/src/segmentation/soccerdata.py +164 -0
  39. tvcalib/sn_segmentation/src/segmentation/train.py +341 -0
  40. tvcalib/sn_segmentation/src/segmentation/transforms.py +100 -0
  41. tvcalib/sn_segmentation/src/segmentation/utils.py +304 -0
  42. tvcalib/utils/data_distr.py +44 -0
  43. tvcalib/utils/io.py +44 -0
  44. tvcalib/utils/linalg.py +106 -0
  45. tvcalib/utils/objects_3d.py +1674 -0
  46. visualizer.py +298 -0
.gitattributes ADDED
@@ -0,0 +1,2 @@
1
+ *.jpg filter=lfs diff=lfs merge=lfs -text
2
+ *.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,12 @@
1
+ venv
2
+ *.pyc
3
+ *.jpg
4
+ *.png
5
+ *.jpeg
6
+ *.gif
7
+ *.bmp
8
+ *.tiff
9
+ *.ico
10
+
11
+
12
+ *.pt
README.md ADDED
@@ -0,0 +1,107 @@
1
+ # Foot Calib Pos Image Processor
2
+
3
+ This project uses the TvCalib library to calculate the homography matrix for a football pitch image. This matrix allows mapping image points onto a standard 2D representation of the pitch (minimap).
4
+
5
+ The project also includes a pose estimation step (ViTPose) to detect players and calculate the average color of their torso.
6
+
7
+ **The main result is a minimap where each player is represented by their colored skeleton, drawn at a *dynamically* reduced scale around their projected position on the pitch.**
8
+ The position is determined by projecting a reference point (feet/bottom of bbox) using the homography. The skeleton is then drawn using its relative coordinates (original image), scaled, and translated. The scale used depends on the **player's vertical position (Y) on the minimap** (higher on the minimap = smaller) and a base factor adjustable via the `--target_avg_scale` option.
9
+
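+ As a quick illustration of that projection step, here is a minimal sketch (not the project's exact API; the file name and the assumption that the matrix maps image pixels to pitch coordinates are illustrative):
+
+ ```python
+ import cv2
+ import numpy as np
+
+ H = np.load("homography.npy")  # 3x3 matrix, e.g. saved via --output_homography
+ feet_px = np.array([[[640.0, 700.0]]], dtype=np.float32)  # reference point (x, y) in the 720p image
+ pitch_xy = cv2.perspectiveTransform(feet_px, H)[0, 0]
+ print(pitch_xy)  # position on the 2D pitch / minimap
+ ```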
10
+ It is also possible to visualize a minimap with the projection of the original image for comparison.
11
+
12
+ ## Features
13
+
14
+ * Homography calculation from a single image via TvCalib.
15
+ * Person detection (RT-DETR) and pose estimation (ViTPose).
16
+ * Calculation of the filtered average torso color for each player.
17
+ * Projection of each player's reference point (feet/bbox) onto the minimap.
18
+ * Generation of a minimap with the **original skeletons (colored, dynamically scaled based on projected Y position, and offset) drawn around the projected point**.
19
+ * (Optional) Generation of a minimap with the projected original image.
20
+ * Possibility to save the calculated homography matrix.
21
+
22
+ ## Project Structure
23
+
24
+ ```
25
+ .
26
+ ├── .git/ # Git metadata
27
+ ├── .venv/ # Python virtual environment (recommended)
28
+ ├── common/ # Common Python modules (potentially)
29
+ ├── data/ # Data (input images, etc.)
30
+ ├── models/
31
+ │ └── segmentation/
32
+ │ └── train_59.pt # Pre-trained segmentation model (TO DOWNLOAD)
33
+ ├── tvcalib/ # Source code of the TvCalib library (or a fork/adaptation)
34
+ │ └── infer/
35
+ │ └── module.py # Main module for TvCalib inference
36
+ ├── .gitignore # Files ignored by Git
37
+ ├── main.py # Main script entry point
38
+ ├── requirements.txt # Python dependencies file
39
+ ├── visualizer.py # Module for generating visualization minimaps
40
+ ├── pose_estimator.py # Module for pose estimation and player data extraction
41
+ └── README.md # This file
42
+ ```
43
+
44
+ ## Installation
45
+
46
+ 1. **Clone the repository:**
47
+ ```powershell
48
+ git clone <repository-url>
49
+ cd Foot_calib_pos_image_processor
50
+ ```
51
+
52
+ 2. **Create a virtual environment (recommended):**
53
+ ```powershell
54
+ python -m venv venv
55
+ .\venv\Scripts\Activate.ps1
56
+ ```
57
+
58
+ 3. **Install dependencies:**
59
+ ```powershell
60
+ pip install -r requirements.txt
61
+ ```
62
+ *(Make sure to install PyTorch with appropriate CUDA support if needed.)*
63
+
64
+ 4. **Download the segmentation model:**
65
+ Place `train_59.pt` in `models/segmentation/`.
66
+
67
+ 5. **(Automatic) Download detection/pose models:**
68
+ The RT-DETR and ViTPose models will be downloaded automatically.
69
+
70
+ ## Usage
71
+
72
+ Run the `main.py` script providing the path to the image:
73
+
74
+ ```powershell
75
+ python main.py path/to/your/image.jpg [OPTIONS]
76
+ ```
77
+
78
+ **Options:**
79
+
80
+ * `image_path`: Path to the input image (required).
81
+ * `--output_homography PATH.npy`: Saves the calculated homography matrix.
82
+ * `--optim_steps NUMBER`: Number of optimization steps for calibration (default: 500).
83
+ * `--target_avg_scale FLOAT`: **Target average** scale factor for drawing skeletons (default: 0.35). The script attempts to adjust the internal base scale so that the resulting average scale (after inverse dynamic modulation) is close to this value; see the sketch below.
84
+
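+ The internal base scale is estimated from the target by assuming the average player sits at the vertical middle of the minimap (normalized Y = 0.5), mirroring the logic in `main.py`. A minimal sketch (the two modulation constants are placeholders; the real values are `DYNAMIC_SCALE_MIN_MODULATION` / `DYNAMIC_SCALE_MAX_MODULATION` from `visualizer.py`):
+
+ ```python
+ # Placeholder modulation bounds (illustrative values only).
+ SCALE_MIN_MOD, SCALE_MAX_MOD = 0.6, 1.4
+
+ target_avg_scale = 0.35
+ # Expected modulation for a player at mid-height (norm_y = 0.5), inverted scaling:
+ avg_modulation = SCALE_MIN_MOD + (SCALE_MAX_MOD - SCALE_MIN_MOD) * (1.0 - 0.5)
+ base_scale = target_avg_scale / avg_modulation if avg_modulation else target_avg_scale
+ ```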
85
+ **Example:**
86
+
87
+ ```powershell
88
+ # Simple usage (target average size 0.35)
89
+ python main.py data/img3.png
90
+
91
+ # Aim for larger skeletons on average (target 0.5)
92
+ python main.py data/img2.png --target_avg_scale 0.5
93
+ ```
94
+
95
+ The script will display:
96
+ * Time taken and homography matrix.
97
+ * Estimated internal base scale.
98
+ * Requested TARGET average scale.
99
+ * ACTUALLY applied FINAL average scale.
100
+ * Window: **Minimap with Original Projection**.
101
+ * Window: **Minimap with Offset Skeletons** (dynamically scaled inversely, targeting the average scale).
102
+ * Press any key to close.
103
+
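+ If you saved the matrix with `--output_homography`, it can be reloaded later and applied to the full frame; a minimal sketch (the output size and the direction of the mapping are assumptions for illustration):
+
+ ```python
+ import cv2
+ import numpy as np
+
+ H = np.load("homography.npy")
+ frame = cv2.imread("data/img3.png")
+ # Warp the camera frame into the 2D pitch view (output size chosen arbitrarily here).
+ minimap = cv2.warpPerspective(frame, H, (1050, 680))
+ cv2.imwrite("minimap_projection.png", minimap)
+ ```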
104
+ ## Key Dependencies
105
+
106
+ * PyTorch, OpenCV, NumPy, PyTorch Lightning
107
+ * SoccerNet, Kornia, Hugging Face Transformers, Pillow
app.py ADDED
@@ -0,0 +1,234 @@
1
+ import gradio as gr
2
+ import cv2
3
+ import numpy as np
4
+ import torch
5
+ from pathlib import Path
6
+ import time
7
+ import traceback
8
+
9
+ # Import the required pieces from the project's other modules
10
+ try:
11
+ from tvcalib.infer.module import TvCalibInferModule
12
+ # We try to import the preprocessing function from main.py
13
+ # If main.py is not designed to be imported, this function may need to be copied here
14
+ from main import preprocess_image_tvcalib, IMAGE_SHAPE, SEGMENTATION_MODEL_PATH
15
+ from visualizer import (
16
+ create_minimap_view,
17
+ create_minimap_with_offset_skeletons,
18
+ DYNAMIC_SCALE_MIN_MODULATION,
19
+ DYNAMIC_SCALE_MAX_MODULATION
20
+ )
21
+ from pose_estimator import get_player_data
22
+ except ImportError as e:
23
+ print(f"Erreur d'importation : {e}")
24
+ print("Assurez-vous que les modules tvcalib, main, visualizer, pose_estimator sont accessibles.")
25
+ # On pourrait mettre des stubs ou lever une exception ici pour Gradio
26
+ raise e
27
+
28
+ # --- Global configuration (model, etc.) ---
29
+ # Loading the model once globally could improve performance,
30
+ # but beware of state handling in multi-user/multi-threaded Spaces environments.
31
+ # For now, it is loaded inside the processing function.
32
+ # MODEL = None # Optional: load here
33
+ # DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
34
+ DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
35
+
36
+ print(f"Utilisation du device : {DEVICE}")
37
+
38
+ if not SEGMENTATION_MODEL_PATH.exists():
39
+ print(f"AVERTISSEMENT : Modèle de segmentation introuvable : {SEGMENTATION_MODEL_PATH}")
40
+ print("L'application risque de ne pas fonctionner. Assurez-vous que le fichier est présent.")
41
+ # Gradio peut quand même démarrer, mais le traitement échouera.
42
+
43
+ # --- Main processing function ---
44
+ def process_image_and_generate_minimaps(input_image_bgr, optim_steps, target_avg_scale):
45
+ """
46
+ Takes a BGR image (NumPy), the number of optimization steps and the target scale,
47
+ and returns the two minimaps (NumPy BGR).
48
+ """
49
+ global DEVICE # Utiliser le device défini globalement
50
+
51
+ print("\n--- Nouvelle requête ---")
52
+ print(f"Paramètres: optim_steps={optim_steps}, target_avg_scale={target_avg_scale}")
53
+
54
+ # Vérifier si le modèle de segmentation existe (important car on ne peut pas l'afficher dans l'UI facilement)
55
+ if not SEGMENTATION_MODEL_PATH.exists():
56
+ # Retourner des images noires ou des messages d'erreur
57
+ error_msg = f"Erreur: Modèle {SEGMENTATION_MODEL_PATH} introuvable."
58
+ print(error_msg)
59
+ placeholder = np.zeros((300, 500, 3), dtype=np.uint8) # Placeholder noir
60
+ cv2.putText(placeholder, error_msg, (10, 150), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1, cv2.LINE_AA)
61
+ return placeholder, placeholder.copy() # Retourner deux placeholders
62
+
63
+ try:
64
+ # 1. Initialize the TvCalib model (can be slow if done on every request)
65
+ # Pourrait être optimisé en chargeant globalement (voir commentaire plus haut)
66
+ print("Initialisation de TvCalibInferModule...")
67
+ start_init = time.time()
68
+ model = TvCalibInferModule(
69
+ segmentation_checkpoint=SEGMENTATION_MODEL_PATH,
70
+ image_shape=IMAGE_SHAPE, # Utilise la constante importée
71
+ optim_steps=int(optim_steps), # Assurer que c'est un entier
72
+ lens_dist=False
73
+ )
74
+ # Déplacer le modèle sur le bon device ici explicitement si nécessaire
75
+ # model.to(DEVICE) # TvCalibInferModule devrait gérer ça en interne ? A vérifier.
76
+ print(f"✓ Modèle chargé sur {next(model.model_calib.parameters()).device} en {time.time() - start_init:.3f}s")
77
+ model_device = next(model.model_calib.parameters()).device # Vérifier le device réel
78
+
79
+ # 2. Preprocess the image
80
+ print("Prétraitement de l'image...")
81
+ start_preprocess = time.time()
82
+ # preprocess_image_tvcalib attend BGR, Gradio fournit BGR par défaut avec type="numpy"
83
+ # Assurez-vous que preprocess_image_tvcalib déplace bien le tenseur sur le bon device
84
+ image_tensor, image_bgr_resized, image_rgb_resized = preprocess_image_tvcalib(input_image_bgr)
85
+ # Vérifier/forcer le device du tenseur
86
+ image_tensor = image_tensor.to(model_device)
87
+ print(f"Temps de prétraitement TvCalib : {time.time() - start_preprocess:.3f}s")
88
+
89
+
90
+ # 3. Run the calibration (segmentation + optimization)
91
+ print("Exécution de la segmentation...")
92
+ start_segment = time.time()
93
+ with torch.no_grad():
94
+ keypoints = model._segment(image_tensor)
95
+ print(f"Temps de segmentation : {time.time() - start_segment:.3f}s")
96
+
97
+ print("Exécution de la calibration (optimisation)...")
98
+ start_calibrate = time.time()
99
+ homography = model._calibrate(keypoints)
100
+ print(f"Temps de calibration : {time.time() - start_calibrate:.3f}s")
101
+
102
+ if homography is None:
103
+ print("Aucune homographie n'a pu être calculée.")
104
+ # Retourner des placeholders avec message
105
+ placeholder = np.zeros((300, 500, 3), dtype=np.uint8)
106
+ cv2.putText(placeholder, "Homographie non calculee", (10, 150), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1, cv2.LINE_AA)
107
+ return placeholder, placeholder.copy()
108
+
109
+ if isinstance(homography, torch.Tensor):
110
+ homography_np = homography.detach().cpu().numpy()
111
+ else:
112
+ homography_np = np.array(homography) # Assurer que c'est un NumPy array
113
+ print("✓ Homographie calculée.")
114
+
115
+
116
+ # 4. Extract player data
117
+ print("Extraction des données joueurs (pose+couleur)...")
118
+ start_pose = time.time()
119
+ # get_player_data attend une image BGR
120
+ player_list = get_player_data(image_bgr_resized)
121
+ print(f"Temps d'extraction données joueurs : {time.time() - start_pose:.3f}s ({len(player_list)} joueurs trouvés)")
122
+
123
+ # 5. Compute the base scale
124
+ print("Calcul de l'échelle de base...")
125
+ # Reprend la logique de main.py pour estimer l'échelle de base
126
+ avg_modulation_expected = DYNAMIC_SCALE_MIN_MODULATION + \
127
+ (DYNAMIC_SCALE_MAX_MODULATION - DYNAMIC_SCALE_MIN_MODULATION) * (1.0 - 0.5)
128
+ estimated_base_scale = target_avg_scale
129
+ if avg_modulation_expected != 0:
130
+ estimated_base_scale = target_avg_scale / avg_modulation_expected
131
+ print(f" Échelle de base interne estimée pour cible {target_avg_scale:.3f} : {estimated_base_scale:.3f}")
132
+
133
+ # 6. Generate the minimaps
134
+ print("Génération des minimaps...")
135
+ start_viz = time.time()
136
+ # Minimap avec projection (image RGB attendue par la fonction)
137
+ minimap_original = create_minimap_view(image_rgb_resized, homography_np)
138
+
139
+ # Minimap avec squelettes (utilise l'échelle estimée)
140
+ minimap_offset_skeletons, actual_avg_scale = create_minimap_with_offset_skeletons(
141
+ player_list,
142
+ homography_np,
143
+ base_skeleton_scale=estimated_base_scale
144
+ )
145
+ print(f"Temps de génération des minimaps : {time.time() - start_viz:.3f}s")
146
+ if actual_avg_scale is not None:
147
+ print(f"Échelle moyenne CIBLE demandée : {target_avg_scale:.3f}")
148
+ print(f"Échelle moyenne FINALE RÉELLEMENT appliquée : {actual_avg_scale:.3f}")
149
+
150
+
151
+ # Vérifier si les minimaps ont été créées (peuvent être None en cas d'erreur interne)
152
+ if minimap_original is None:
153
+ print("Erreur: La minimap originale n'a pas pu être générée.")
154
+ minimap_original = np.zeros((300, 500, 3), dtype=np.uint8)
155
+ cv2.putText(minimap_original, "Erreur Minimap Originale", (10, 150), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1, cv2.LINE_AA)
156
+
157
+ if minimap_offset_skeletons is None:
158
+ print("Erreur: La minimap squelettes n'a pas pu être générée.")
159
+ minimap_offset_skeletons = np.zeros((300, 500, 3), dtype=np.uint8)
160
+ cv2.putText(minimap_offset_skeletons, "Erreur Minimap Squelettes", (10, 150), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1, cv2.LINE_AA)
161
+
162
+ # Gradio expects RGB images for display; our functions most likely return BGR (via OpenCV)
163
+ # Convert BGR -> RGB if needed
164
+ if minimap_original.shape[2] == 3: # Assurer que c'est une image couleur
165
+ minimap_original = cv2.cvtColor(minimap_original, cv2.COLOR_BGR2RGB)
166
+ if minimap_offset_skeletons.shape[2] == 3:
167
+ minimap_offset_skeletons = cv2.cvtColor(minimap_offset_skeletons, cv2.COLOR_BGR2RGB)
168
+
169
+ print("✓ Traitement terminé.")
170
+ return minimap_original, minimap_offset_skeletons
171
+
172
+ except Exception as e:
173
+ print(f"Erreur majeure lors du traitement : {e}")
174
+ traceback.print_exc()
175
+ # Retourner des placeholders avec message d'erreur général
176
+ placeholder = np.zeros((300, 500, 3), dtype=np.uint8)
177
+ cv2.putText(placeholder, f"Erreur: {e}", (10, 150), cv2.FONT_HERSHEY_SIMPLEX, 0.4, (0, 0, 255), 1, cv2.LINE_AA)
178
+ return placeholder, placeholder.copy()
179
+
180
+ # --- Gradio interface ---
181
+ with gr.Blocks() as demo:
182
+ gr.Markdown("# Foot Calib Pos Image Processor - Minimap Generator")
183
+ gr.Markdown(
184
+ "Upload a football pitch image to compute homography (TvCalib), "
185
+ "detect players (RT-DETR/ViTPose), and generate two minimap visualizations."
186
+ )
187
+
188
+ with gr.Row():
189
+ with gr.Column(scale=1):
190
+ input_image = gr.Image(type="numpy", label="Input Image (.jpg, .png)")
191
+ optim_steps_slider = gr.Slider(
192
+ minimum=100, maximum=2000, step=50, value=500,
193
+ label="TvCalib Optimization Steps",
194
+ info="Number of iterations to refine homography."
195
+ )
196
+ target_scale_slider = gr.Slider(
197
+ minimum=0.1, maximum=2.5, step=0.05, value=0.35,
198
+ label="Target Average Skeleton Scale",
199
+ info="Adjusts the desired average size of skeletons on the minimap."
200
+ )
201
+ submit_button = gr.Button("Generate Minimaps", variant="primary")
202
+
203
+ with gr.Column(scale=2):
204
+ output_minimap_orig = gr.Image(type="numpy", label="Minimap with Original Projection", interactive=False)
205
+ output_minimap_skel = gr.Image(type="numpy", label="Minimap with Offset Skeletons", interactive=False)
206
+
207
+ # Wire the button to the processing function
208
+ submit_button.click(
209
+ fn=process_image_and_generate_minimaps,
210
+ inputs=[input_image, optim_steps_slider, target_scale_slider],
211
+ outputs=[output_minimap_orig, output_minimap_skel]
212
+ )
213
+
214
+ # Add examples (optional but useful for Spaces)
215
+ gr.Examples(
216
+ examples=[
217
+ ["data/img1.png", 500, 1.35],
218
+ ["data/img2.png", 1000, 1.5],
219
+ ["data/img3.png", 500, 0.8],
220
+ ["data/7.jpg", 500, 1], # Add .jpg examples
221
+ ["data/15.jpg", 800, 1.35],
222
+ ],
223
+ inputs=[input_image, optim_steps_slider, target_scale_slider],
224
+ outputs=[output_minimap_orig, output_minimap_skel], # Outputs won't be pre-calculated here, just to populate inputs
225
+ fn=process_image_and_generate_minimaps, # Function will be called if example is clicked
226
+ cache_examples=False # Important if processing is long or depends on external models
227
+ )
228
+
229
+
230
+ # --- Application launch ---
231
+ if __name__ == "__main__":
232
+ # share=True creates a temporary public link (useful for testing outside localhost)
233
+ # debug=True shows more Gradio logs in the console
234
+ demo.launch(debug=True)
common/data/augmentation.py ADDED
@@ -0,0 +1,55 @@
1
+
2
+ import numpy as np
3
+ from typing import Callable, Any
4
+ import methods.common.data.utils as utils
5
+ import random
6
+
7
+
8
+
9
+ class WarpAugmentation(Callable):
10
+ def __init__(
11
+ self,
12
+ warp_function: Any,
13
+ mode: str="train",
14
+ noise_translate: float=0.0,
15
+ noise_rotate: float=0.0
16
+ ):
17
+ # Template
18
+ self.template_grid = utils.gen_template_grid()
19
+ self.warp_fn = warp_function
20
+ self.mode = mode
21
+ self.noise_translate = noise_translate
22
+ self.noise_rotate = noise_rotate
23
+
24
+ def __call__(
25
+ self,
26
+ image: np.ndarray,
27
+ homography: np.ndarray,
28
+ frame_idx: int
29
+ ):
30
+ warp_image, warp_grid, warp_homography = self.warp_fn(
31
+ mode=self.mode,
32
+ frame=image,
33
+ f_idx=frame_idx,
34
+ gt_homo=homography,
35
+ template=self.template_grid,
36
+ noise_trans=self.noise_translate,
37
+ noise_rotate=self.noise_rotate,
38
+ index=-1 # not really used ...
39
+ )
40
+ return warp_image, warp_grid, warp_homography
41
+
42
+
43
+
44
+
45
+ class LeftRightFlipAugmentation(Callable):
46
+ def __init__(self, enabled: bool=False):
47
+ self.enabled = enabled
48
+
49
+ def __call__(self, image, grid):
50
+ if (self.enabled):
51
+ if (random.random() < 0.5):
52
+ image, grid = utils.put_lrflip_augmentation(image, grid)
53
+ return image, grid, True
54
+
55
+ return image, grid, False
common/data/calib.py ADDED
@@ -0,0 +1,116 @@
1
+
2
+ import cv2
3
+ import numpy as np
4
+ import torch
5
+ import methods.common.data.utils as utils
6
+
7
+
8
+
9
+ class Calibrator:
10
+ def __init__(
11
+ self,
12
+ num_classes: int=92,
13
+ nms_threshold: float=0.995
14
+ ):
15
+ self.template_grid = utils.gen_template_grid()
16
+ self.num_classes = num_classes
17
+ self.nms_threshold = nms_threshold
18
+
19
+
20
+ def find_homography(
21
+ self,
22
+ heatmap_logits: torch.Tensor
23
+ ):
24
+ """ Extract keypoints from heatmap, and find homography matrix
25
+
26
+ heatmap_logits: torch.Tensor for individual frame (not a mini-batch!)
27
+ """
28
+
29
+ pred_rgb, pred_keypoints, scores_heatmap = self.decode_heatmap(heatmap_logits)
30
+ homography = None
31
+
32
+ # We need at least 4 point correspondences
33
+ if (pred_rgb.shape[0] >= 4):
34
+ src_pts, dst_pts = self.get_class_mapping(pred_rgb)
35
+
36
+ # Find homography from point correspondences
37
+ homography, _ = cv2.findHomography(
38
+ src_pts.reshape(-1, 1, 2),
39
+ dst_pts.reshape(-1, 1, 2),
40
+ cv2.RANSAC,
41
+ 10
42
+ )
43
+
44
+ return homography, pred_keypoints, scores_heatmap
45
+
46
+
47
+ def decode_heatmap(self, heatmap_logits: torch.Tensor):
48
+ """ Decode heatmap info keypoint set using non-maximum suppression
49
+ heatmap_logits: torc.Tensor with shape <NUM_CLASSES; H; W>
50
+ """
51
+
52
+ pred_heatmap = torch.softmax(heatmap_logits, dim=0)
53
+ arg = torch.argmax(pred_heatmap, dim=0).detach().cpu().numpy()
54
+ scores, pred_heatmap = torch.max(pred_heatmap, dim=0)
55
+
56
+ # Convert to Numpy & get keypoints locations
57
+ scores = scores.detach().cpu().numpy()
58
+ pred_heatmap = pred_heatmap.detach().cpu().numpy()
59
+ pred_class_dict = self.get_class_dict(scores, pred_heatmap)
60
+
61
+ # Colorize
62
+ num_classes = heatmap_logits.shape[0]
63
+ np_scores = np.clip(arg * 255.0 / num_classes, 0, 255).astype(np.uint8)
64
+ scores_heatmap = cv2.applyColorMap(np_scores, cv2.COLORMAP_HOT)
65
+ scores_heatmap = cv2.cvtColor(scores_heatmap, cv2.COLOR_BGR2RGB)
66
+
67
+ # Produce image with keypoints
68
+ pred_keypoints = np.zeros_like(pred_heatmap, dtype=np.uint8)
69
+ pred_rgb = []
70
+ for _, (pk, pv) in enumerate(pred_class_dict.items()):
71
+ if (pv):
72
+ pred_keypoints[pv[1][0], pv[1][1]] = pk # (H,W)
73
+ # camera view point sets (x, y, label) in rgb domain not heatmap domain
74
+ pred_rgb.append([pv[1][1] * 4, pv[1][0] * 4, pk])
75
+ pred_rgb = np.asarray(pred_rgb, dtype=np.float32) # (?, 3)
76
+
77
+ # Return list of point locations, and image of keypoints
78
+ return pred_rgb, pred_keypoints, scores_heatmap
79
+
80
+
81
+
82
+ def get_class_mapping(self, rgb):
83
+ src_pts = rgb.copy()
84
+ cls_map_pts = []
85
+
86
+ for ind, elem in enumerate(src_pts):
87
+ coords = np.where(elem[2] == self.template_grid[:, 2])[0] # find correspondence
88
+ cls_map_pts.append(self.template_grid[coords[0]])
89
+ dst_pts = np.array(cls_map_pts, dtype=np.float32)
90
+
91
+ return src_pts[:, :2], dst_pts[:, :2]
92
+
93
+
94
+ def get_class_dict(self, scores, pred):
95
+ # Decode
96
+ pred_cls_dict = {k: [] for k in range(1, self.num_classes)}
97
+ for cls in range(1, self.num_classes):
98
+ pred_inds = (pred == cls)
99
+
100
+ # implies the current class does not appear in this heatmaps
101
+ if not np.any(pred_inds):
102
+ continue
103
+
104
+ values = scores[pred_inds]
105
+ max_score = values.max()
106
+ max_index = values.argmax()
107
+
108
+ indices = np.where(pred_inds)
109
+ coords = list(zip(indices[0], indices[1]))
110
+
111
+ # the only keypoint with max confidence is greater than threshold or not
112
+ if max_score >= self.nms_threshold:
113
+ pred_cls_dict[cls].append(max_score)
114
+ pred_cls_dict[cls].append(coords[max_index])
115
+
116
+ return pred_cls_dict
common/data/transforms.py ADDED
@@ -0,0 +1,23 @@
1
+ from typing import Callable
2
+ import torch
3
+ import numpy as np
4
+
5
+
6
+
7
+ class UnNormalize(object):
8
+ def __init__(self, mean, std):
9
+ self.mean = mean
10
+ self.std = std
11
+
12
+ def __call__(self, tensor):
13
+ """
14
+ Args:
15
+ tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
16
+ Returns:
17
+ Tensor: Normalized image.
18
+ """
19
+ for t, m, s in zip(tensor, self.mean, self.std):
20
+ t.mul_(s).add_(m)
21
+ # The normalize code -> t.sub_(m).div_(s)
22
+ return tensor
23
+
common/data/utils.py ADDED
@@ -0,0 +1,72 @@
1
+ import cv2
2
+ import numpy as np
3
+ import random
4
+ import torch
5
+
6
+
7
+
8
+ def yards(x):
9
+ return x * 1.0936132983
10
+
11
+
12
+
13
+ def to_torch(ndarray):
14
+ if type(ndarray).__module__ == 'numpy':
15
+ return torch.from_numpy(ndarray.copy())
16
+ elif not torch.is_tensor(ndarray):
17
+ raise ValueError("Cannot convert {} to torch tensor"
18
+ .format(type(ndarray)))
19
+ return ndarray
20
+
21
+
22
+
23
+ def gen_template_grid():
24
+ # === set uniform grid ===
25
+ # field_dim_x, field_dim_y = 105.000552, 68.003928 # in meter
26
+ field_dim_x, field_dim_y = 114.83, 74.37 # in yard
27
+ # field_dim_x, field_dim_y = 115, 74 # in yard
28
+ nx, ny = (13, 7)
29
+ x = np.linspace(0, field_dim_x, nx)
30
+ y = np.linspace(0, field_dim_y, ny)
31
+ xv, yv = np.meshgrid(x, y, indexing='ij')
32
+ uniform_grid = np.stack((xv, yv), axis=2).reshape(-1, 2)
33
+ uniform_grid = np.concatenate((uniform_grid, np.ones(
34
+ (uniform_grid.shape[0], 1))), axis=1) # top2bottom, left2right
35
+ # TODO: class label in template; each keypoint is (x, y, c), where c is a label starting from 1
36
+ for idx, pts in enumerate(uniform_grid):
37
+ pts[2] = idx + 1 # keypoints label
38
+ return uniform_grid
39
+
40
+
41
+ def put_lrflip_augmentation(frame, unigrid):
42
+
43
+ frame_h, frame_w = frame.shape[0], frame.shape[1]
44
+ flipped_img = np.fliplr(frame)
45
+
46
+ # TODO: grid flipping and re-assigning pixel class labels (1-91)
47
+ for ind, pts in enumerate(unigrid):
48
+ pts[0] = frame_w - pts[0]
49
+ col = (pts[2] - 1) // 7 # get each column of uniform grid
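+ # labels run column-major over the 13x7 template grid, so mirroring column col (0-12)
+ # to column (12 - col) shifts the label by (12 - 2*col) * 7, i.e. minus (col - 6) * 2 * 7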
50
+ pts[2] = pts[2] - (col - 6) * 2 * 7 # keypoints label
51
+
52
+ return flipped_img, unigrid
53
+
54
+
55
+
56
+
57
+ def make_grid(images, nrow=4):
58
+ num_images = len(images)
59
+ ih, iw = images[0].shape[0], images[0].shape[1]
60
+ rows = min(num_images, nrow)
61
+ cols = (num_images + nrow-1) // nrow
62
+
63
+ result = np.zeros(shape=(cols*ih, rows*iw, 3), dtype=np.uint8)
64
+ for i in range(num_images):
65
+ cell_x = i%nrow
66
+ cell_y = i//nrow
67
+ result[
68
+ (cell_y+0)*ih:(cell_y+1)*ih,
69
+ (cell_x+0)*iw:(cell_x+1)*iw
70
+ ] = images[i]
71
+
72
+ return result
common/infer/base.py ADDED
@@ -0,0 +1,43 @@
1
+
2
+ from typing import Any, Dict
3
+ from torch.utils.data import Dataset
4
+
5
+ import numpy as np
6
+
7
+
8
+
9
+
10
+ class InferDataModule:
11
+ def __init__(self):
12
+ pass
13
+
14
+
15
+ def get_inference_dataset(self) -> Dataset:
16
+ """ Return the dataset to run inference on """
17
+ pass
18
+
19
+
20
+
21
+ class InferModule:
22
+ def __init__(self):
23
+ pass
24
+
25
+
26
+ def setup(
27
+ self,
28
+ datamodule: InferDataModule
29
+ ):
30
+ """ Initialize inference proces with the given datamodule """
31
+ pass
32
+
33
+
34
+ def predict(
35
+ self,
36
+ x: Any
37
+ ) -> Dict:
38
+ """ Predict the calibration information for the given dataset sample """
39
+ return None
40
+
41
+
42
+
43
+
common/infer/module.py ADDED
@@ -0,0 +1,24 @@
1
+ from typing import Any, Dict
2
+ from methods.common.infer.base import *
3
+
4
+
5
+
6
+ class LabelInferModule(InferModule):
7
+ def __init__(self):
8
+ pass
9
+
10
+ def setup(self, datamodule: InferDataModule):
11
+ pass
12
+
13
+
14
+ def predict(self, x: Any) -> Dict:
15
+ """
16
+ x - sample from dataset (including label)
17
+ """
18
+
19
+ # Extract homography matrix
20
+ result = {
21
+ "homography": x["homography"]
22
+ }
23
+
24
+ return result
common/infer/sink.py ADDED
@@ -0,0 +1,42 @@
1
+
2
+ from pathlib import Path
3
+ from collections import defaultdict
4
+ import pandas as pd
5
+ from typing import Dict
6
+ import cv2
7
+
8
+
9
+ class PredictionsSink:
10
+ def __init__(
11
+ self,
12
+ target_filepath: Path
13
+ ):
14
+ self.target_filepath = target_filepath
15
+ def new_item():
16
+ return []
17
+ self.data = defaultdict(new_item)
18
+ self.count = 0
19
+
20
+
21
+ def write(self, item: Dict):
22
+ self.data["item"].append(self.count)
23
+ for k,v in item.items():
24
+ if ("image" in k):
25
+ self.write_image(self.count, k, v)
26
+ else:
27
+ self.data[k].append(v)
28
+
29
+ self.count += 1
30
+
31
+
32
+ def flush(self):
33
+ df = pd.DataFrame(self.data)
34
+ self.target_filepath.parent.mkdir(parents=True, exist_ok=True)
35
+ df.to_csv(self.target_filepath)
36
+
37
+ def write_image(self, idx, name, image):
38
+ folder = self.target_filepath.parent / "images" / self.target_filepath.stem / name
39
+ filepath = folder / f"{idx:06d}.png"
40
+ filepath.parent.mkdir(parents=True, exist_ok=True)
41
+ image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
42
+ cv2.imwrite(filepath.as_posix(), image)
common/loggers/homography_previewer.py ADDED
@@ -0,0 +1,103 @@
1
+
2
+ import numpy as np
3
+ import skimage.segmentation as ss
4
+ import cv2
5
+
6
+ from torchvision import transforms
7
+
8
+ from project.data.transforms import TensorToNumpy
9
+
10
+
11
+ from methods.common.loggers.image_preview import ImagePreviewLogger
12
+ from methods.common.data.transforms import UnNormalize
13
+ from methods.common.data.calib import Calibrator
14
+
15
+
16
+ class HomographyPreviewerLogger(ImagePreviewLogger):
17
+ def __init__(
18
+ self,
19
+ experiment,
20
+ num_rows: int=3,
21
+ ):
22
+ super().__init__(experiment, num_rows)
23
+
24
+ self.calib = Calibrator()
25
+ self.to_image = transforms.Compose([
26
+ UnNormalize(
27
+ mean=[0.485, 0.456, 0.406],
28
+ std=[0.229, 0.224, 0.225]
29
+ ),
30
+ TensorToNumpy()
31
+ ])
32
+
33
+
34
+
35
+
36
+ def draw_keypoints(self, image, image_keypoints, color=(255,0,0)):
37
+ """ Upscales keypoints map into image resolution and
38
+ overlays it over the image
39
+ """
40
+
41
+ # Get keypoints image in target image resolution
42
+ a = ss.expand_labels(image_keypoints, distance=1)
43
+ a = cv2.resize(
44
+ a,
45
+ dsize=(image.shape[1], image.shape[0]),
46
+ interpolation=cv2.INTER_NEAREST
47
+ )
48
+ a = np.expand_dims(a, axis=2)
49
+
50
+ # Alpha of keypoints image
51
+ a = (a > 0)*1.0
52
+
53
+ # Color of keypoints image
54
+ c = np.concatenate([a*color[0], a*color[1], a*color[2]], axis=2)
55
+
56
+ # Superimpose the keypoints
57
+ result = (1.0-a)*image + a*c
58
+ result = np.clip(result, 0, 255).astype(np.uint8)
59
+ return result
60
+
61
+
62
+ def draw_playfield(
63
+ self,
64
+ image,
65
+ image_playfield,
66
+ homography,
67
+ color=(255,0,0),
68
+ alpha=1.0,
69
+ flip=False
70
+ ):
71
+ """ Draws the playfield image under the homography matrix
72
+ over the target image
73
+ """
74
+ if (homography is None):
75
+ return image
76
+
77
+ # Warp the playfield image
78
+ warp_field = cv2.warpPerspective(
79
+ image_playfield,
80
+ homography,
81
+ (image.shape[1], image.shape[0]),
82
+ cv2.INTER_LINEAR,
83
+ borderMode=cv2.BORDER_CONSTANT,
84
+ borderValue=(0)
85
+ )
86
+
87
+ if (flip):
88
+ warp_field = np.fliplr(warp_field)
89
+
90
+ # Get the alpha
91
+ a = np.expand_dims((warp_field / 255.0), axis=2)
92
+
93
+ # Color of playfield
94
+ c = np.concatenate([a*color[0], a*color[1], a*color[2]], axis=2)
95
+
96
+ # Draw with specified alpha
97
+ a = a * alpha
98
+
99
+ # Superimpose the playing field image
100
+ result = (1.0-a)*image + a*c
101
+ result = np.clip(result, 0, 255).astype(np.uint8)
102
+ return result
103
+
common/loggers/image_preview.py ADDED
@@ -0,0 +1,93 @@
1
+
2
+ from project.base.logging import Logger
3
+ import torch
4
+ import torch.nn as nn
5
+ import numpy as np
6
+
7
+ import methods.common.data.utils as utils
8
+
9
+
10
+ class ImagePreviewLogger(Logger):
11
+ def __init__(
12
+ self,
13
+ experiment,
14
+ num_rows: int=3
15
+ ):
16
+ self.experiment = experiment
17
+ if (experiment is not None):
18
+ self.tracker = experiment.tracker
19
+ self.num_rows = num_rows
20
+
21
+ # Will be sampled later
22
+ self.samples = None
23
+ self.images = None
24
+
25
+
26
+ def on_training_start(self):
27
+
28
+ datamodule = self.experiment.datamodule
29
+ module = self.experiment.module
30
+
31
+ # Get the images
32
+ self.samples, images = self.sample_images(datamodule, module)
33
+ if (len(images) > 0):
34
+ self.images = torch.concatenate(
35
+ images,
36
+ dim=0
37
+ ).to(module.device)
38
+
39
+
40
+ def on_epoch_end(self, epoch, stats):
41
+ self.make_preview()
42
+
43
+
44
+ def make_preview(self):
45
+
46
+ # Get the model
47
+ model = self.experiment.module.model
48
+ if (isinstance(model, nn.DataParallel)):
49
+ model = model.module
50
+
51
+ # Get the preview results
52
+ model.eval()
53
+ items = self.process_images(model, self.images)
54
+ model.train()
55
+
56
+ # Arrange into grid, and send to tracker
57
+ log_images = {}
58
+ # Unpack the list of dicts
59
+ for item in items:
60
+ for key,image in item.items():
61
+ if (not key in log_images):
62
+ log_images[key] = []
63
+ log_images[key].append(image)
64
+
65
+ # Arrange images into grids
66
+ result = {}
67
+ for key, images in log_images.items():
68
+ result[key] = utils.make_grid(images, nrow=self.num_rows)
69
+
70
+ # Send to tracker
71
+ self.tracker.write_images(result)
72
+
73
+
74
+ def sample_dataset(self, dataset, num_images):
75
+ idx = np.random.choice(
76
+ len(dataset),
77
+ size=(num_images,),
78
+ replace=False
79
+ )
80
+ samples = [ dataset[i] for i in idx ]
81
+ return samples
82
+
83
+
84
+
85
+
86
+ def sample_images(self, datamodule, module):
87
+ """ Sample our datasets """
88
+ return [], []
89
+
90
+
91
+ def process_images(self, model, images):
92
+ """ Returns list of dict[key,image] items """
93
+ return []
main.py ADDED
@@ -0,0 +1,199 @@
1
+ import argparse
2
+ import cv2
3
+ import numpy as np
4
+ import torch
5
+ from pathlib import Path
6
+ import time
7
+ import traceback
8
+
9
+ # Make sure the tvcalib directory is on the PYTHONPATH,
10
+ # or run from the tvcalib_image_processor directory
11
+ from tvcalib.infer.module import TvCalibInferModule
12
+ # Import the visualization functions and the modulation constants
13
+ from visualizer import (
14
+ create_minimap_view,
15
+ create_minimap_with_offset_skeletons,
16
+ DYNAMIC_SCALE_MIN_MODULATION, # Importer les constantes
17
+ DYNAMIC_SCALE_MAX_MODULATION
18
+ )
19
+ # Import the player data extraction function
20
+ from pose_estimator import get_player_data
21
+
22
+ # Constants
23
+ IMAGE_SHAPE = (720, 1280) # Height, Width
24
+ SEGMENTATION_MODEL_PATH = Path("models/segmentation/train_59.pt")
25
+
26
+ def preprocess_image_tvcalib(image_bgr):
27
+ """Prétraite l'image BGR pour TvCalib et retourne le tenseur et l'image RGB redimensionnée."""
28
+ if image_bgr is None:
29
+ raise ValueError("Impossible de charger l'image")
30
+
31
+ # 1. Resize to 720p if necessary
32
+ h, w = image_bgr.shape[:2]
33
+ if h != IMAGE_SHAPE[0] or w != IMAGE_SHAPE[1]:
34
+ print(f"Redimensionnement de l'image vers {IMAGE_SHAPE[1]}x{IMAGE_SHAPE[0]}")
35
+ image_bgr_resized = cv2.resize(image_bgr, (IMAGE_SHAPE[1], IMAGE_SHAPE[0]), interpolation=cv2.INTER_LINEAR)
36
+ else:
37
+ image_bgr_resized = image_bgr
38
+
39
+ # 2. Convert to RGB (for TvCalib AND for the original-image visualization)
40
+ image_rgb_resized = cv2.cvtColor(image_bgr_resized, cv2.COLOR_BGR2RGB)
41
+
42
+ # 3. Normalization specific to the pre-trained model (for TvCalib)
43
+ image_tensor = torch.from_numpy(image_rgb_resized).float()
44
+ image_tensor = image_tensor.permute(2, 0, 1) # HWC -> CHW
45
+ mean = torch.tensor([0.485, 0.456, 0.406]).view(-1, 1, 1)
46
+ std = torch.tensor([0.229, 0.224, 0.225]).view(-1, 1, 1)
47
+ image_tensor = (image_tensor / 255.0 - mean) / std
48
+
49
+ # 4. Move to the right device
50
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
51
+ image_tensor = image_tensor.to(device)
52
+
53
+ # Return the tensor for TvCalib plus the resized BGR and RGB images
54
+ return image_tensor, image_bgr_resized, image_rgb_resized
55
+
56
+ def main():
57
+ parser = argparse.ArgumentParser(description="Exécute la méthode TvCalib sur une seule image.")
58
+ parser.add_argument("image_path", type=str, help="Chemin vers l'image à traiter.")
59
+ parser.add_argument("--output_homography", type=str, default=None, help="Chemin optionnel pour sauvegarder la matrice d'homographie (.npy).")
60
+ parser.add_argument("--optim_steps", type=int, default=500, help="Nombre d'étapes d'optimisation pour la calibration (l'arrêt anticipé est désactivé).")
61
+ parser.add_argument("--target_avg_scale", type=float, default=1,
62
+ help="Facteur d'échelle MOYEN CIBLE pour dessiner les squelettes sur la minimap (défaut: 0.35). Le script ajuste l'échelle de base pour tenter d'atteindre cette moyenne.")
63
+
64
+ args = parser.parse_args()
65
+
66
+ if not Path(args.image_path).exists():
67
+ print(f"Erreur : Fichier image introuvable : {args.image_path}")
68
+ return
69
+
70
+ if not SEGMENTATION_MODEL_PATH.exists():
71
+ print(f"Erreur : Modèle de segmentation introuvable : {SEGMENTATION_MODEL_PATH}")
72
+ print("Assurez-vous d'avoir copié train_59.pt dans le dossier models/segmentation/")
73
+ return
74
+
75
+ print("Initialisation de TvCalibInferModule...")
76
+ try:
77
+ model = TvCalibInferModule(
78
+ segmentation_checkpoint=SEGMENTATION_MODEL_PATH,
79
+ image_shape=IMAGE_SHAPE,
80
+ optim_steps=args.optim_steps,
81
+ lens_dist=False # Gardons cela simple pour l'instant
82
+ )
83
+ print(f"✓ Modèle chargé sur {next(model.model_calib.parameters()).device}")
84
+ except Exception as e:
85
+ print(f"Erreur lors de l'initialisation du modèle : {e}")
86
+ return
87
+
88
+ print(f"Traitement de l'image : {args.image_path}")
89
+ try:
90
+ # Charger l'image (en BGR par défaut avec OpenCV)
91
+ image_bgr_orig = cv2.imread(args.image_path)
92
+ if image_bgr_orig is None:
93
+ raise FileNotFoundError(f"Impossible de lire le fichier image: {args.image_path}")
94
+
95
+ # Prétraiter l'image
96
+ start_preprocess = time.time()
97
+ image_tensor, image_bgr_resized, image_rgb_resized = preprocess_image_tvcalib(image_bgr_orig)
98
+ print(f"Temps de prétraitement TvCalib : {time.time() - start_preprocess:.3f}s")
99
+
100
+ # Exécuter la segmentation
101
+ print("Exécution de la segmentation...")
102
+ start_segment = time.time()
103
+ with torch.no_grad():
104
+ keypoints = model._segment(image_tensor)
105
+ print(f"Temps de segmentation : {time.time() - start_segment:.3f}s")
106
+
107
+ # Exécuter la calibration (optimisation)
108
+ print("Exécution de la calibration (optimisation)...")
109
+ start_calibrate = time.time()
110
+ homography = model._calibrate(keypoints)
111
+ print(f"Temps de calibration : {time.time() - start_calibrate:.3f}s")
112
+
113
+ if homography is not None:
114
+ print("\n--- Homographie Calculée ---")
115
+ if isinstance(homography, torch.Tensor):
116
+ homography_np = homography.detach().cpu().numpy()
117
+ else:
118
+ homography_np = homography
119
+ print(homography_np)
120
+
121
+ if args.output_homography:
122
+ try:
123
+ np.save(args.output_homography, homography_np)
124
+ print(f"\nHomographie sauvegardée dans : {args.output_homography}")
125
+ except Exception as e:
126
+ print(f"Erreur lors de la sauvegarde de l'homographie : {e}")
127
+
128
+ # --- Extraction des données joueurs ---
129
+ print("\nExtraction des données joueurs (pose+couleur)...")
130
+ start_pose = time.time()
131
+ player_list = get_player_data(image_bgr_resized)
132
+ print(f"Temps d'extraction données joueurs : {time.time() - start_pose:.3f}s ({len(player_list)} joueurs trouvés)")
133
+
134
+ # --- Calcul de l'échelle de base estimée ---
135
+ print("\nCalcul de l'échelle de base pour atteindre la cible...")
136
+ target_average_scale = args.target_avg_scale
137
+
138
+ # Compute the expected average modulation (assumption: average player at mid-height, Y = 0.5)
139
+ # Current inverted logic: MIN + (MAX - MIN) * (1.0 - norm_y)
140
+ avg_modulation_expected = DYNAMIC_SCALE_MIN_MODULATION + \
141
+ (DYNAMIC_SCALE_MAX_MODULATION - DYNAMIC_SCALE_MIN_MODULATION) * (1.0 - 0.5)
142
+
143
+ estimated_base_scale = target_average_scale # Valeur par défaut si modulation = 0
144
+ if avg_modulation_expected != 0:
145
+ estimated_base_scale = target_average_scale / avg_modulation_expected
146
+ else:
147
+ print("Avertissement : Modulation moyenne attendue nulle, impossible d'estimer l'échelle de base.")
148
+
149
+ print(f" Modulation dynamique moyenne attendue (pour Y=0.5) : {avg_modulation_expected:.3f}")
150
+ print(f" Échelle de base interne estimée pour cible {target_average_scale:.3f} : {estimated_base_scale:.3f}")
151
+
152
+ # --- Génération des DEUX minimaps ---
153
+ print("\nGénération des minimaps (Originale et Squelettes Décalés)...")
154
+
155
+ # 1. Minimap avec l'image originale (RGB)
156
+ minimap_original = create_minimap_view(image_rgb_resized, homography_np)
157
+
158
+ # 2. Minimap avec les squelettes
159
+ # Utiliser l'échelle de base ESTIMÉE
160
+ minimap_offset_skeletons, actual_avg_scale = create_minimap_with_offset_skeletons(
161
+ player_list,
162
+ homography_np,
163
+ base_skeleton_scale=estimated_base_scale # Utiliser l'estimation
164
+ )
165
+
166
+ # Afficher la cible et le résultat réel
167
+ if actual_avg_scale is not None:
168
+ print(f"\nÉchelle moyenne CIBLE demandée (--target_avg_scale) : {target_average_scale:.3f}")
169
+ print(f"Échelle moyenne FINALE RÉELLEMENT appliquée (basée sur joueurs réels) : {actual_avg_scale:.3f}")
170
+
171
+ # --- Affichage des résultats ---
172
+ print("\nAffichage des résultats. Appuyez sur une touche pour quitter.")
173
+
174
+ # Afficher la minimap originale
175
+ if minimap_original is not None:
176
+ cv2.imshow("Minimap avec Projection Originale", minimap_original)
177
+ else:
178
+ print("N'a pas pu générer la minimap originale.")
179
+
180
+ # Afficher la minimap avec les squelettes décalés
181
+ if minimap_offset_skeletons is not None:
182
+ cv2.imshow("Minimap avec Squelettes Decales", minimap_offset_skeletons)
183
+ else:
184
+ print("N'a pas pu générer la minimap squelettes décalés.")
185
+
186
+ cv2.waitKey(0) # Attend qu'une touche soit pressée
187
+
188
+ else:
189
+ print("\nAucune homographie n'a pu être calculée.")
190
+
191
+ except Exception as e:
192
+ print(f"Erreur lors du traitement de l'image : {e}")
193
+ traceback.print_exc()
194
+ finally:
195
+ print("Fermeture des fenêtres OpenCV.")
196
+ cv2.destroyAllWindows()
197
+
198
+ if __name__ == "__main__":
199
+ main()
pose_estimator.py ADDED
@@ -0,0 +1,265 @@
1
+ import torch
2
+ import numpy as np
3
+ import cv2
4
+ from PIL import Image
5
+ from transformers import AutoProcessor, RTDetrForObjectDetection, VitPoseForPoseEstimation
6
+ from pathlib import Path
7
+
8
+ # --- Global variables for models and processor (lazy loading) ---
9
+ person_processor = None
10
+ person_model = None
11
+ pose_processor = None
12
+ pose_model = None
13
+ device = "cuda" if torch.cuda.is_available() else "cpu"
14
+ print(f"Pose Estimator: Using device: {device}")
15
+
16
+ # --- Constants for color and drawing ---
17
+ # Colors are BGR tuples
18
+ DEFAULT_MARKER_COLOR = (255, 255, 255) # White
19
+ MIN_PIXELS_FOR_COLOR = 20 # Minimum number of valid pixels in the ROI to attempt a color estimate
20
+ CONFIDENCE_THRESHOLD_KEYPOINTS = 0.3 # Threshold for a keypoint to be considered reliable for the ROI and drawing
21
+ SKELETON_THICKNESS = 2
22
+
23
+ # Skeleton segment definitions (COCO indices 0-16)
24
+ # 0:Nose, 1:L_Eye, 2:R_Eye, 3:L_Ear, 4:R_Ear, 5:L_Shoulder, 6:R_Shoulder,
25
+ # 7:L_Elbow, 8:R_Elbow, 9:L_Wrist, 10:R_Wrist, 11:L_Hip, 12:R_Hip,
26
+ # 13:L_Knee, 14:R_Knee, 15:L_Ankle, 16:R_Ankle
27
+ SKELETON_EDGES = [
28
+ # Head
29
+ (0, 1), (0, 2), (1, 3), (2, 4),
30
+ # Torso
31
+ (5, 6), (5, 11), (6, 12), (11, 12),
32
+ # Left arm
33
+ (5, 7), (7, 9),
34
+ # Right arm
35
+ (6, 8), (8, 10),
36
+ # Left leg
37
+ (11, 13), (13, 15),
38
+ # Right leg
39
+ (12, 14), (14, 16)
40
+ ]
41
+
42
+ # Keypoint indices for the torso and the ankles
43
+ TORSO_KP_INDICES = [5, 6, 11, 12] # Shoulders, hips
44
+ LEFT_ANKLE_KP_INDEX = 15
45
+ RIGHT_ANKLE_KP_INDEX = 16
46
+
47
+ def _load_models():
48
+ """Loads the models if they haven't been loaded yet."""
49
+ global person_processor, person_model, pose_processor, pose_model
50
+
51
+ if person_processor is None:
52
+ print("Loading RTDetr person detector model...")
53
+ person_processor = AutoProcessor.from_pretrained("PekingU/rtdetr_r50vd_coco_o365")
54
+ person_model = RTDetrForObjectDetection.from_pretrained("PekingU/rtdetr_r50vd_coco_o365", device_map=device)
55
+ print("✓ RTDetr loaded.")
56
+
57
+ if pose_processor is None:
58
+ print("Loading ViTPose pose estimation model...")
59
+ pose_processor = AutoProcessor.from_pretrained("usyd-community/vitpose-base-simple")
60
+ pose_model = VitPoseForPoseEstimation.from_pretrained("usyd-community/vitpose-base-simple", device_map=device)
61
+ print("✓ ViTPose loaded.")
62
+
63
+ def _is_color_greenish(bgr_pixel, threshold=10):
64
+ b, g, r = bgr_pixel
65
+ return g > b + threshold and g > r + threshold
66
+
67
+ def _is_color_grayscale(bgr_pixel, tolerance=30):
68
+ b, g, r = bgr_pixel
69
+ min_val, max_val = min(b, g, r), max(b, g, r)
70
+ is_dark = max_val < 50
71
+ is_light = min_val > 200
72
+ is_low_saturation = (max_val - min_val) < tolerance
73
+ return is_dark or is_light or is_low_saturation
74
+
75
+ def _get_average_color(roi_bgr):
76
+ """Calcule la couleur moyenne d'une ROI après filtrage."""
77
+ if roi_bgr is None or roi_bgr.size == 0:
78
+ return DEFAULT_MARKER_COLOR
79
+
80
+ try:
81
+ pixels = roi_bgr.reshape(-1, 3)
82
+ valid_pixels = []
83
+ for pixel in pixels:
84
+ if not _is_color_greenish(pixel) and not _is_color_grayscale(pixel):
85
+ valid_pixels.append(pixel)
86
+
87
+ if len(valid_pixels) < MIN_PIXELS_FOR_COLOR:
88
+ return DEFAULT_MARKER_COLOR
89
+
90
+ avg_color = np.mean(valid_pixels, axis=0)
91
+ return tuple(map(int, avg_color))
92
+
93
+ except Exception as e:
94
+ print(f" Erreur calcul couleur moyenne: {e}. Utilisation couleur défaut.")
95
+ return DEFAULT_MARKER_COLOR
96
+
97
+ def get_player_data(image_bgr: np.ndarray) -> list:
98
+ """
99
+ Detects persons, estimates pose, calculates average torso color,
100
+ and returns a list of data for each player.
101
+
102
+ Args:
103
+ image_bgr: The input image in BGR format (NumPy array).
104
+
105
+ Returns:
106
+ A list of dictionaries, each containing:
107
+ {
108
+ 'keypoints': np.ndarray (17, 2),
109
+ 'scores': np.ndarray (17,),
110
+ 'bbox': np.ndarray (4,) [x1, y1, x2, y2],
111
+ 'avg_color': tuple (b, g, r)
112
+ }
113
+ Returns an empty list if no persons are detected or an error occurs.
114
+ """
115
+ _load_models()
116
+ player_list = []
117
+ height, width = image_bgr.shape[:2]
118
+
119
+ try:
120
+ image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
121
+ image_pil = Image.fromarray(image_rgb)
122
+
123
+ # --- Stage 1: Detect humans ---
124
+ inputs_det = person_processor(images=image_pil, return_tensors="pt").to(device)
125
+ with torch.no_grad():
126
+ outputs_det = person_model(**inputs_det)
127
+ results_det = person_processor.post_process_object_detection(
128
+ outputs_det, target_sizes=torch.tensor([(height, width)]), threshold=0.5
129
+ )
130
+ result_det = results_det[0]
131
+ person_boxes = result_det["boxes"][result_det["labels"] == 0].cpu().numpy()
132
+
133
+ if len(person_boxes) == 0:
134
+ print("No persons detected.")
135
+ return player_list
136
+
137
+ person_boxes_coco = person_boxes.copy()
138
+ person_boxes_coco[:, 2] = person_boxes_coco[:, 2] - person_boxes_coco[:, 0]
139
+ person_boxes_coco[:, 3] = person_boxes_coco[:, 3] - person_boxes_coco[:, 1]
140
+
141
+ # --- Stage 2: Detect keypoints ---
142
+ inputs_pose = pose_processor(image_pil, boxes=[person_boxes_coco], return_tensors="pt").to(device)
143
+ with torch.no_grad():
144
+ outputs_pose = pose_model(**inputs_pose)
145
+ pose_results = pose_processor.post_process_pose_estimation(outputs_pose, boxes=[person_boxes_coco])
146
+ image_pose_result = pose_results[0]
147
+
148
+ if not image_pose_result:
149
+ print("Pose estimation did not return results.")
150
+ return player_list
151
+
152
+ # --- Stage 3: Process each person ---
153
+ for i, person_box_xyxy in enumerate(person_boxes):
154
+ if i >= len(image_pose_result): continue
155
+
156
+ pose_result = image_pose_result[i]
157
+ xy = pose_result['keypoints'].cpu().numpy()
158
+ scores = pose_result['scores'].cpu().numpy()
159
+
160
+ # Ensure xy shape is correct before proceeding
161
+ if xy.shape != (17, 2):
162
+ print(f"Person {i}: Unexpected keypoints shape {xy.shape}, skipping.")
163
+ continue
164
+
165
+ # -- Define Torso ROI --
166
+ reliable_torso_keypoints = xy[TORSO_KP_INDICES][scores[TORSO_KP_INDICES] > CONFIDENCE_THRESHOLD_KEYPOINTS]
167
+ x1_box, y1_box, x2_box, y2_box = map(int, person_box_xyxy)
168
+ box_h = y2_box - y1_box
169
+ box_w = x2_box - x1_box
170
+ if len(reliable_torso_keypoints) >= 3:
171
+ min_x_kp = int(np.min(reliable_torso_keypoints[:, 0]))
172
+ max_x_kp = int(np.max(reliable_torso_keypoints[:, 0]))
173
+ min_y_kp = int(np.min(reliable_torso_keypoints[:, 1]))
174
+ max_y_kp = int(np.max(reliable_torso_keypoints[:, 1]))
175
+ roi_x1 = max(x1_box, min_x_kp - 5); roi_y1 = max(y1_box, min_y_kp - 5)
176
+ roi_x2 = min(x2_box, max_x_kp + 5); roi_y2 = min(y2_box, max_y_kp + 5)
177
+ else:
178
+ roi_x1 = x1_box; roi_y1 = y1_box + int(0.1 * box_h)
179
+ roi_x2 = x2_box; roi_y2 = y1_box + int(0.6 * box_h)
180
+ roi_x1 = max(0, roi_x1); roi_y1 = max(0, roi_y1)
181
+ roi_x2 = min(width, roi_x2); roi_y2 = min(height, roi_y2)
182
+
183
+ # -- Extract Average Color --
184
+ avg_color = DEFAULT_MARKER_COLOR
185
+ if roi_y2 > roi_y1 and roi_x2 > roi_x1:
186
+ torso_roi = image_bgr[roi_y1:roi_y2, roi_x1:roi_x2]
187
+ avg_color = _get_average_color(torso_roi)
188
+ # else: # Pas besoin de message si ROI invalide, couleur par défaut suffit
189
+ # print(f"Person {i}: Invalid ROI, using default color.")
190
+
191
+ # -- Store player data --
192
+ player_data = {
193
+ 'keypoints': xy,
194
+ 'scores': scores,
195
+ 'bbox': person_box_xyxy, # Utiliser la bbox originale xyxy
196
+ 'avg_color': avg_color
197
+ }
198
+ player_list.append(player_data)
199
+
200
+ except Exception as e:
201
+ print(f"Error during player data extraction: {e}")
202
+ import traceback
203
+ traceback.print_exc()
204
+ # Retourner une liste vide en cas d'erreur majeure
205
+ return []
206
+
207
+ return player_list
208
+
209
+ # Example usage (optional, for testing the module directly)
210
+ if __name__ == '__main__':
211
+ test_image_path = 'img3.png'
212
+
213
+ if not Path(test_image_path).exists():
214
+ print(f"Test image not found: {test_image_path}")
215
+ else:
216
+ print(f"Testing player data extraction with image: {test_image_path}")
217
+ input_img = cv2.imread(test_image_path)
218
+
219
+ if input_img is None:
220
+ print(f"Failed to load test image: {test_image_path}")
221
+ else:
222
+ print("Getting player data...")
223
+ players = get_player_data(input_img)
224
+ print(f"✓ Found data for {len(players)} players.")
225
+
226
+ # --- Draw markers and info on original image for testing ---
227
+ output_img_test = input_img.copy()
228
+ for idx, p_data in enumerate(players):
229
+ kps = p_data['keypoints']
230
+ scores = p_data['scores']
231
+ bbox = p_data['bbox']
232
+ color = p_data['avg_color']
233
+
234
+ # Determine reference point (ankles or bbox bottom mid)
235
+ l_ankle_pt = kps[LEFT_ANKLE_KP_INDEX]
236
+ r_ankle_pt = kps[RIGHT_ANKLE_KP_INDEX]
237
+ l_ankle_score = scores[LEFT_ANKLE_KP_INDEX]
238
+ r_ankle_score = scores[RIGHT_ANKLE_KP_INDEX]
239
+
240
+ ref_point = None
241
+ if l_ankle_score > CONFIDENCE_THRESHOLD_KEYPOINTS and r_ankle_score > CONFIDENCE_THRESHOLD_KEYPOINTS:
242
+ ref_point = tuple(map(int, (l_ankle_pt + r_ankle_pt) / 2))
243
+ elif l_ankle_score > CONFIDENCE_THRESHOLD_KEYPOINTS:
244
+ ref_point = tuple(map(int, l_ankle_pt))
245
+ elif r_ankle_score > CONFIDENCE_THRESHOLD_KEYPOINTS:
246
+ ref_point = tuple(map(int, r_ankle_pt))
247
+ else:
248
+ x1, y1, x2, y2 = map(int, bbox)
249
+ ref_point = (int((x1 + x2) / 2), y2)
250
+
251
+ # Draw marker at reference point
252
+ if ref_point:
253
+ cv2.circle(output_img_test, ref_point, 8, color, -1, cv2.LINE_AA)
254
+ cv2.circle(output_img_test, ref_point, 8, (0,0,0), 1, cv2.LINE_AA) # Black outline
255
+ # Draw player index
256
+ cv2.putText(output_img_test, str(idx), (ref_point[0]+5, ref_point[1]-5),
257
+ cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0,0,0), 2, cv2.LINE_AA)
258
+ cv2.putText(output_img_test, str(idx), (ref_point[0]+5, ref_point[1]-5),
259
+ cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255,255,255), 1, cv2.LINE_AA)
260
+
261
+ cv2.imshow("Original Image", input_img)
262
+ cv2.imshow("Player Markers Test", output_img_test)
263
+ print("Displaying test results. Press any key to exit.")
264
+ cv2.waitKey(0)
265
+ cv2.destroyAllWindows()
requirements.txt ADDED
@@ -0,0 +1,22 @@
1
+ # Install PyTorch with the specific command below if needed:
2
+ # pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu121
3
+
4
+ # Main dependencies
5
+ torch
6
+ torchvision
7
+ torchaudio
8
+ numpy
9
+ opencv-python
10
+ pytorch-lightning
11
+ soccernet
12
+ kornia
13
+
14
+ # Ajouté car requis par sn_segmentation
15
+
16
+ # Dépendances pour l'estimation de pose (ViTPose)
17
+ transformers
18
+ supervision
19
+ Pillow # Souvent une dépendance de transformers/supervision, mais explicite ici
20
+ accelerate
21
+ # scikit-learn # Removed because K-Means is no longer used
22
+ gradio
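
A quick optional check, assuming the CUDA wheel from the command above was installed, that PyTorch actually sees the GPU before running the pipeline:

```python
# Sketch only: verify the installed PyTorch build and CUDA availability.
import torch
print(torch.__version__)
print("CUDA available:", torch.cuda.is_available())
```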
tvcalib/cam_distr/tv_main_behind.py ADDED
@@ -0,0 +1,77 @@
1
+ from math import pi
2
+ from tvcalib.utils.data_distr import mean_std_with_confidence_interval
3
+
4
+
5
+ def get_cam_distr(sigma_scale: float, batch_dim: int, temporal_dim: int):
6
+ cam_distr = {
7
+ "pan": {
8
+ "minmax": (pi / 4, 3 * pi / 4), # in deg 45°, 135°
9
+ "dimension": (
10
+ batch_dim,
11
+ temporal_dim,
12
+ ),
13
+ },
14
+ "tilt": {
15
+ "minmax": (pi / 16, pi / 2), # in deg 11.25°, 90°
16
+ "dimension": (
17
+ batch_dim,
18
+ temporal_dim,
19
+ ),
20
+ },
21
+ "roll": {
22
+ "minmax": (-pi / 32, pi / 32), # in deg -5.6°, 5.6°
23
+ "dimension": (
24
+ batch_dim,
25
+ temporal_dim,
26
+ ),
27
+ },
28
+ "aov": {
29
+ "minmax": (pi / 22, pi / 2), # (8.2°, 90°)
30
+ "dimension": (
31
+ batch_dim,
32
+ temporal_dim,
33
+ ),
34
+ },
35
+ "c_x": {
36
+ "minmax": (-32.5, -52.5),
37
+ "dimension": (
38
+ batch_dim,
39
+ 1,
40
+ ),
41
+ },
42
+ "c_y": {
43
+ "minmax": (-5.0, 5.0),
44
+ "dimension": (
45
+ batch_dim,
46
+ 1,
47
+ ),
48
+ },
49
+ "c_z": {
50
+ "minmax": (-35.0, -1.0),
51
+ "dimension": (
52
+ batch_dim,
53
+ 1,
54
+ ),
55
+ },
56
+ }
57
+
58
+ for k, params in cam_distr.items():
59
+ cam_distr[k]["mean_std"] = mean_std_with_confidence_interval(
60
+ *params["minmax"], sigma_scale=sigma_scale
61
+ )
62
+ return cam_distr
63
+
64
+
65
+ def get_dist_distr(batch_dim: int, temporal_dim: int, _sigma_scale: float = 2.57):
66
+ return {
67
+ "k1": {
68
+ "minmax": [0.0, 0.5], # we clip min(0.0, x)
69
+ "mean_std": (0.0, _sigma_scale * 0.5),
70
+ "dimension": (batch_dim, temporal_dim),
71
+ },
72
+ "k2": {
73
+ "minmax": [-0.1, 0.1],
74
+ "mean_std": (0.0, _sigma_scale * 0.1),
75
+ "dimension": (batch_dim, temporal_dim),
76
+ },
77
+ }
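
A minimal usage sketch for these priors; how `mean_std` is interpreted downstream is an assumption based on the z-score scaling in `cam_modules.py`:

```python
# Sketch only: build the 'behind the goal' camera prior for a single image.
from tvcalib.cam_distr.tv_main_behind import get_cam_distr, get_dist_distr

cam_distr = get_cam_distr(sigma_scale=1.96, batch_dim=1, temporal_dim=1)
dist_distr = get_dist_distr(batch_dim=1, temporal_dim=1)

print("pan range (rad):", cam_distr["pan"]["minmax"])
print("pan mean/std used for z-score denormalization:", cam_distr["pan"]["mean_std"])
print("k1 clip range:", dist_distr["k1"]["minmax"])
```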
tvcalib/cam_distr/tv_main_center.py ADDED
@@ -0,0 +1,78 @@
1
+ import numpy as np
2
+ from math import pi
3
+ from ..utils.data_distr import mean_std_with_confidence_interval
4
+
5
+
6
+ def get_cam_distr(sigma_scale: float, batch_dim: int, temporal_dim: int):
7
+ cam_distr = {
8
+ "pan": {
9
+ "minmax": (-pi / 4, pi / 4), # in deg -45°, 45°
10
+ "dimension": (
11
+ batch_dim,
12
+ temporal_dim,
13
+ ),
14
+ },
15
+ "tilt": {
16
+ "minmax": (pi / 4, pi / 2), # in deg 45°, 90°
17
+ "dimension": (
18
+ batch_dim,
19
+ temporal_dim,
20
+ ),
21
+ },
22
+ "roll": {
23
+ "minmax": (-pi / 18, pi / 18), # in deg -10°, 10°
24
+ "dimension": (
25
+ batch_dim,
26
+ temporal_dim,
27
+ ),
28
+ },
29
+ "aov": {
30
+ "minmax": (pi / 22, pi / 2), # (8.2°, 90°)
31
+ "dimension": (
32
+ batch_dim,
33
+ temporal_dim,
34
+ ),
35
+ },
36
+ "c_x": {
37
+ "minmax": (-12.0, 12.0), # (-36, 36) for entire main tribune
38
+ "dimension": (
39
+ batch_dim,
40
+ 1,
41
+ ),
42
+ },
43
+ "c_y": {
44
+ "minmax": (40.0, 110.0),
45
+ "dimension": (
46
+ batch_dim,
47
+ 1,
48
+ ),
49
+ },
50
+ "c_z": {
51
+ "minmax": (-40.0, -5.0),
52
+ "dimension": (
53
+ batch_dim,
54
+ 1,
55
+ ),
56
+ },
57
+ }
58
+
59
+ for k, params in cam_distr.items():
60
+ cam_distr[k]["mean_std"] = mean_std_with_confidence_interval(
61
+ *params["minmax"], sigma_scale=sigma_scale
62
+ )
63
+ return cam_distr
64
+
65
+
66
+ def get_dist_distr(batch_dim: int, temporal_dim: int, _sigma_scale: float = 2.57):
67
+ return {
68
+ "k1": {
69
+ "minmax": [0.0, 0.5], # we clip min(0.0, x)
70
+ "mean_std": (0.0, _sigma_scale * 0.5),
71
+ "dimension": (batch_dim, temporal_dim),
72
+ },
73
+ "k2": {
74
+ "minmax": [-0.1, 0.1],
75
+ "mean_std": (0.0, _sigma_scale * 0.1),
76
+ "dimension": (batch_dim, temporal_dim),
77
+ },
78
+ }
tvcalib/cam_distr/tv_main_left.py ADDED
@@ -0,0 +1,77 @@
1
+ from math import pi
2
+ from tvcalib.utils.data_distr import mean_std_with_confidence_interval
3
+
4
+
5
+ def get_cam_distr(sigma_scale: float, batch_dim: int, temporal_dim: int):
6
+ cam_distr = {
7
+ "pan": {
8
+ "minmax": (-pi / 4, pi / 4), # in deg -45°, 45°
9
+ "dimension": (
10
+ batch_dim,
11
+ temporal_dim,
12
+ ),
13
+ },
14
+ "tilt": {
15
+ "minmax": (pi / 4, pi / 2), # in deg 45°, 90°
16
+ "dimension": (
17
+ batch_dim,
18
+ temporal_dim,
19
+ ),
20
+ },
21
+ "roll": {
22
+ "minmax": (-pi / 18, pi / 18), # in deg -10°, 10°
23
+ "dimension": (
24
+ batch_dim,
25
+ temporal_dim,
26
+ ),
27
+ },
28
+ "aov": {
29
+ "minmax": (pi / 22, pi / 2), # (8.2°, 90°)
30
+ "dimension": (
31
+ batch_dim,
32
+ temporal_dim,
33
+ ),
34
+ },
35
+ "c_x": {
36
+ "minmax": (-36 - 16.5, -36 + 16.5),
37
+ "dimension": (
38
+ batch_dim,
39
+ 1,
40
+ ),
41
+ },
42
+ "c_y": {
43
+ "minmax": (40.0, 110.0),
44
+ "dimension": (
45
+ batch_dim,
46
+ 1,
47
+ ),
48
+ },
49
+ "c_z": {
50
+ "minmax": (-40.0, -5.0),
51
+ "dimension": (
52
+ batch_dim,
53
+ 1,
54
+ ),
55
+ },
56
+ }
57
+
58
+ for k, params in cam_distr.items():
59
+ cam_distr[k]["mean_std"] = mean_std_with_confidence_interval(
60
+ *params["minmax"], sigma_scale=sigma_scale
61
+ )
62
+ return cam_distr
63
+
64
+
65
+ def get_dist_distr(batch_dim: int, temporal_dim: int, _sigma_scale: float = 2.57):
66
+ return {
67
+ "k1": {
68
+ "minmax": [0.0, 0.5], # we clip min(0.0, x)
69
+ "mean_std": (0.0, _sigma_scale * 0.5),
70
+ "dimension": (batch_dim, temporal_dim),
71
+ },
72
+ "k2": {
73
+ "minmax": [-0.1, 0.1],
74
+ "mean_std": (0.0, _sigma_scale * 0.1),
75
+ "dimension": (batch_dim, temporal_dim),
76
+ },
77
+ }
tvcalib/cam_distr/tv_main_right.py ADDED
@@ -0,0 +1,77 @@
1
+ from math import pi
2
+ from tvcalib.utils.data_distr import mean_std_with_confidence_interval
3
+
4
+
5
+ def get_cam_distr(sigma_scale: float, batch_dim: int, temporal_dim: int):
6
+ cam_distr = {
7
+ "pan": {
8
+ "minmax": (-pi / 4, pi / 4), # in deg -45°, 45°
9
+ "dimension": (
10
+ batch_dim,
11
+ temporal_dim,
12
+ ),
13
+ },
14
+ "tilt": {
15
+ "minmax": (pi / 4, pi / 2), # in deg 45°, 90°
16
+ "dimension": (
17
+ batch_dim,
18
+ temporal_dim,
19
+ ),
20
+ },
21
+ "roll": {
22
+ "minmax": (-pi / 18, pi / 18), # in deg -10°, 10°
23
+ "dimension": (
24
+ batch_dim,
25
+ temporal_dim,
26
+ ),
27
+ },
28
+ "aov": {
29
+ "minmax": (pi / 22, pi / 2), # (8.2°, 90°)
30
+ "dimension": (
31
+ batch_dim,
32
+ temporal_dim,
33
+ ),
34
+ },
35
+ "c_x": {
36
+ "minmax": (36 - 16.5, 36 + 16.5),
37
+ "dimension": (
38
+ batch_dim,
39
+ 1,
40
+ ),
41
+ },
42
+ "c_y": {
43
+ "minmax": (40.0, 110.0),
44
+ "dimension": (
45
+ batch_dim,
46
+ 1,
47
+ ),
48
+ },
49
+ "c_z": {
50
+ "minmax": (-40.0, -5.0),
51
+ "dimension": (
52
+ batch_dim,
53
+ 1,
54
+ ),
55
+ },
56
+ }
57
+
58
+ for k, params in cam_distr.items():
59
+ cam_distr[k]["mean_std"] = mean_std_with_confidence_interval(
60
+ *params["minmax"], sigma_scale=sigma_scale
61
+ )
62
+ return cam_distr
63
+
64
+
65
+ def get_dist_distr(batch_dim: int, temporal_dim: int, _sigma_scale: float = 2.57):
66
+ return {
67
+ "k1": {
68
+ "minmax": [0.0, 0.5], # we clip min(0.0, x)
69
+ "mean_std": (0.0, _sigma_scale * 0.5),
70
+ "dimension": (batch_dim, temporal_dim),
71
+ },
72
+ "k2": {
73
+ "minmax": [-0.1, 0.1],
74
+ "mean_std": (0.0, _sigma_scale * 0.1),
75
+ "dimension": (batch_dim, temporal_dim),
76
+ },
77
+ }
tvcalib/cam_distr/tv_main_tribune.py ADDED
@@ -0,0 +1,77 @@
1
+ from math import pi
2
+ from tvcalib.utils.data_distr import mean_std_with_confidence_interval
3
+
4
+
5
+ def get_cam_distr(sigma_scale: float, batch_dim: int, temporal_dim: int):
6
+ cam_distr = {
7
+ "pan": {
8
+ "minmax": (-pi / 4, pi / 4), # in deg -45°, 45°
9
+ "dimension": (
10
+ batch_dim,
11
+ temporal_dim,
12
+ ),
13
+ },
14
+ "tilt": {
15
+ "minmax": (pi / 4, pi / 2), # in deg 45°, 90°
16
+ "dimension": (
17
+ batch_dim,
18
+ temporal_dim,
19
+ ),
20
+ },
21
+ "roll": {
22
+ "minmax": (-pi / 18, pi / 18), # in deg -10°, 10°
23
+ "dimension": (
24
+ batch_dim,
25
+ temporal_dim,
26
+ ),
27
+ },
28
+ "aov": {
29
+ "minmax": (pi / 22, pi / 2), # (8.2°, 90°)
30
+ "dimension": (
31
+ batch_dim,
32
+ temporal_dim,
33
+ ),
34
+ },
35
+ "c_x": {
36
+ "minmax": (-40.0, 40.0), # entire main tribune
37
+ "dimension": (
38
+ batch_dim,
39
+ 1,
40
+ ),
41
+ },
42
+ "c_y": {
43
+ "minmax": (40.0, 110.0),
44
+ "dimension": (
45
+ batch_dim,
46
+ 1,
47
+ ),
48
+ },
49
+ "c_z": {
50
+ "minmax": (-40.0, -5.0),
51
+ "dimension": (
52
+ batch_dim,
53
+ 1,
54
+ ),
55
+ },
56
+ }
57
+
58
+ for k, params in cam_distr.items():
59
+ cam_distr[k]["mean_std"] = mean_std_with_confidence_interval(
60
+ *params["minmax"], sigma_scale=sigma_scale
61
+ )
62
+ return cam_distr
63
+
64
+
65
+ def get_dist_distr(batch_dim: int, temporal_dim: int, _sigma_scale: float = 2.57):
66
+ return {
67
+ "k1": {
68
+ "minmax": [0.0, 0.5], # we clip min(0.0, x)
69
+ "mean_std": (0.0, _sigma_scale * 0.5),
70
+ "dimension": (batch_dim, temporal_dim),
71
+ },
72
+ "k2": {
73
+ "minmax": [-0.1, 0.1],
74
+ "mean_std": (0.0, _sigma_scale * 0.1),
75
+ "dimension": (batch_dim, temporal_dim),
76
+ },
77
+ }
tvcalib/cam_modules.py ADDED
@@ -0,0 +1,583 @@
1
+ from typing import Tuple, Dict, Union
2
+
3
+ from pytorch_lightning import LightningModule
4
+ import torch
5
+ import kornia
6
+ import torch.nn as nn
7
+ from .utils.data_distr import FeatureScalerZScore
8
+
9
+
10
+ class CameraParameterWLensDistDictZScore(LightningModule):
11
+ """Holds the individual camera parameters, including lens distortion parameters, as an nn.Module"""
12
+
13
+ def __init__(self, cam_distr, dist_distr, device="cpu"):
14
+ super(CameraParameterWLensDistDictZScore, self).__init__()
15
+
16
+ self.cam_distr = cam_distr
17
+ self._device = device
18
+
19
+ # phi raw
20
+ self.param_dict = torch.nn.ParameterDict(
21
+ {
22
+ k: torch.nn.parameter.Parameter(
23
+ torch.zeros(
24
+ *cam_distr[k]["dimension"],
25
+ device=device,
26
+ ),
27
+ requires_grad=False
28
+ if ("no_grad" in cam_distr[k]) and (cam_distr[k]["no_grad"] == True)
29
+ else True,
30
+ )
31
+ for k in cam_distr.keys()
32
+ }
33
+ )
34
+
35
+ # denormalization module to get phi_target
36
+ self.feature_scaler = torch.nn.ModuleDict(
37
+ {k: FeatureScalerZScore(*cam_distr[k]["mean_std"]) for k in cam_distr.keys()}
38
+ )
39
+
40
+ self.dist_distr = dist_distr
41
+ if self.dist_distr is not None:
42
+ self.param_dict_dist = torch.nn.ParameterDict(
43
+ {
44
+ k: torch.nn.Parameter(torch.zeros(*dist_distr[k]["dimension"], device=device))
45
+ for k in dist_distr.keys()
46
+ }
47
+ )
48
+ # TODO: modify later to dynamically construct a tensor of shape (k_1,k_2,p_1,p_2[,k_3[,k_4,k_5,k_6[,s_1,s_2,s_3,s_4[,\tau_x,\tau_y]]]])
49
+ #
50
+
51
+ self.feature_scaler_dist_coeff = torch.nn.ModuleDict(
52
+ {k: FeatureScalerZScore(*dist_distr[k]["mean_std"]) for k in dist_distr.keys()}
53
+ )
54
+
55
+ def initialize(
56
+ self,
57
+ update_dict_cam: Union[Dict[str, Union[float, torch.tensor]], None],
58
+ update_dict_dist=None,
59
+ ):
60
+ """Initializes all camera parameters with zeros and replace specific values with provided values
61
+
62
+ Args:
63
+ update_dict_cam (Dict[str, Union[float, torch.tensor]]): Parameters to be updated
64
+ """
65
+
66
+ for k in self.param_dict.keys():
67
+ self.param_dict[k].data = torch.zeros(
68
+ *self.cam_distr[k]["dimension"], device=self._device
69
+ )
70
+ if self.dist_distr is not None:
71
+ for k in self.dist_distr.keys():
72
+ self.param_dict_dist[k].data = torch.zeros(
73
+ *self.dist_distr[k]["dimension"], device=self._device
74
+ )
75
+
76
+ if update_dict_cam is not None and len(update_dict_cam) > 0:
77
+ for k, v in update_dict_cam.items():
78
+ self.param_dict[k].data = (
79
+ torch.zeros(*self.cam_distr[k]["dimension"], device=self._device) + v
80
+ )
81
+ if update_dict_dist is not None:
82
+ raise NotImplementedError
83
+
84
+ def forward(self):
85
+ phi_dict = {}
86
+ for k, param in self.param_dict.items():
87
+ phi_dict[k] = self.feature_scaler[k](param)
88
+
89
+ if self.dist_distr is None:
90
+ return phi_dict, None
91
+
92
+ # This is a vector with 4, 5, 8, 12 or 14 elements with shape :math:`(*, n)` depending on the provided dict of coefficients
93
+ # assumes dict is ordered according (k_1,k_2,p_1,p_2[,k_3[,k_4,k_5,k_6[,s_1,s_2,s_3,s_4[,\tau_x,\tau_y]]]])
94
+ psi = torch.stack(
95
+ [
96
+ torch.clamp(
97
+ self.feature_scaler_dist_coeff[k](param),
98
+ min=self.dist_distr[k]["minmax"][0],
99
+ max=self.dist_distr[k]["minmax"][1],
100
+ )
101
+ for k, param in self.param_dict_dist.items()
102
+ ],
103
+ dim=-1, # stack individual features and not arbitrary leading dimensions
104
+ )
105
+
106
+ return phi_dict, psi
107
+
108
+
109
+ class SNProjectiveCamera:
110
+ def __init__(
111
+ self,
112
+ phi_dict: Dict[str, torch.tensor],
113
+ psi: torch.tensor,
114
+ principal_point: Tuple[float, float],
115
+ image_width: int,
116
+ image_height: int,
117
+ device: str = "cpu",
118
+ nan_check=True,
119
+ ) -> None:
120
+ """Projective camera defined as K @ R [I|-t] with lens distortion module and batch dimensions B,T.
121
+
122
+ Following Euler angles convention, we use a ZXZ succession of intrinsic rotations in order to describe
123
+ the orientation of the camera. Starting from the world reference axis system, we first apply a rotation
124
+ around the Z axis to pan the camera. Then the obtained axis system is rotated around its x axis in order to tilt the camera.
125
+ The last rotation, around the z axis of the new axis system, allows rolling the camera. Note that this z axis is the principal axis of the camera.
126
+
127
+ As T is not provided for the camera location and lens distortion, these parameters are assumed to be fixed across T.
128
+ phi_dict is a dict of parameters containing:
129
+ {
130
+ 'aov_x, torch.Size([B, T])',
131
+ 'pan, torch.Size([B, T])',
132
+ 'tilt, torch.Size([B, T])',
133
+ 'roll, torch.Size([B, T])',
134
+ 'c_x, torch.Size([B, 1])',
135
+ 'c_y, torch.Size([B, 1])',
136
+ 'c_z, torch.Size([B, 1])',
137
+ }
138
+
139
+ Internally fuses B and T dimension to pseudo batch dimension.
140
+ {
141
+ 'aov_x, torch.Size([B*T])',
142
+ 'pan, torch.Size([B*T])',
143
+ 'tilt, torch.Size([B*T])'
144
+ 'roll, torch.Size([B*T])',
145
+ 'c_x, torch.Size([B])',
146
+ 'c_y, torch.Size([B])',
147
+ 'c_z, torch.Size([B])',
148
+ }
149
+
150
+ aov_x, pan, tilt, roll are assumed in radian.
151
+
152
+ Note on lens distortion:
153
+ Lens distortion coefficients are independent of the image resolution!
154
+ i.e. I(dist_points(K_ndc, dist_coeff, points2d_ndc)) == I(dist_points(K_raster, dist_coeff, points2d_raster))
155
+
156
+ Args:
157
+ phi_dict (Dict[str, torch.tensor]): See example above
158
+ psi (Union[None, torch.Tensor]): distortion coefficients as a concatenated vector according to https://kornia.readthedocs.io/en/latest/geometry.calibration.html of shape (B, T, {2, 4, 5, 8, 12, 14})
159
+ principal_point (Tuple[float, float]): Principal point assumed to be fixed across all samples (B,T,)
160
+ image_width (int): assumed to be fixed across all samples (B,T,)
161
+ image_height (int): assumed to be fixed across all samples (B,T,)
162
+ """
163
+
164
+ # fuse B and T dimension
165
+ phi_dict_flat = {}
166
+ for k, v in phi_dict.items():
167
+ if len(v.shape) == 2:
168
+ phi_dict_flat[k] = v.view(v.shape[0] * v.shape[1])
169
+ elif len(v.shape) == 3:
170
+ phi_dict_flat[k] = v.view(v.shape[0] * v.shape[1], v.shape[-1])
171
+
172
+ self.batch_dim, self.temporal_dim = phi_dict["pan"].shape
173
+ self.pseudo_batch_size = phi_dict_flat["pan"].shape[0]
174
+ self.phi_dict_flat = phi_dict_flat
175
+
176
+ self.principal_point = principal_point
177
+ self.image_width = image_width
178
+ self.image_height = image_height
179
+ self.device = device
180
+
181
+ self.psi = psi
182
+ if self.psi is not None:
183
+ if self.psi.shape[-1] != 2:
184
+ raise NotImplementedError
185
+
186
+ # :math:`(k_1,k_2,p_1,p_2[,k_3[,k_4,k_5,k_6[,s_1,s_2,s_3,s_4[,\tau_x,\tau_y]]]])`.
187
+ # psi is a vector with 2, 4, 5, 8, 12 or 14 elements with shape :math:`(*, n)`.
188
+ if self.psi.shape[-1] == 2:
189
+ # assume zero tangential coefficients
190
+ psi_ext = torch.zeros(*list(self.psi.shape[:-1]), 4)
191
+ psi_ext[..., :2] = self.psi
192
+ self.psi = psi_ext
193
+ self.lens_dist_coeff = self.psi.view(self.pseudo_batch_size, self.psi.shape[-1]).to(
194
+ self.device
195
+ )
196
+
197
+ self.intrinsics_ndc = self.construct_intrinsics_ndc()
198
+ self.intrinsics_raster = self.construct_intrinsics_raster()
199
+
200
+ self.rotation = self.rotation_from_euler_angles(
201
+ *[phi_dict_flat[k] for k in ["pan", "tilt", "roll"]]
202
+ )
203
+ self.position = torch.stack([phi_dict_flat[k] for k in ["c_x", "c_y", "c_z"]], dim=-1)
204
+ self.position = self.position.repeat_interleave(
205
+ int(self.pseudo_batch_size / self.batch_dim), dim=0
206
+ ) # (B, 3) # TODO: probably needs modification if B > 0?
207
+ self.P_ndc = self.construct_projection_matrix(self.intrinsics_ndc)
208
+ self.P_raster = self.construct_projection_matrix(self.intrinsics_raster)
209
+ self.phi_dict = phi_dict
210
+
211
+ self.nan_check = nan_check
212
+ super().__init__()
213
+
214
+ def construct_projection_matrix(self, intrinsics):
215
+ It = torch.eye(4, device=self.device)[:-1].repeat(self.pseudo_batch_size, 1, 1)
216
+ It[:, :, -1] = -self.position # (B, 3, 4)
217
+ self.It = It
218
+ return intrinsics @ self.rotation @ It # # (B, 3, 4)
219
+
220
+ def construct_intrinsics_ndc(self):
221
+ # assume that the principal point is (0,0)
222
+ K = torch.eye(3, requires_grad=False, device=self.device)
223
+ K = K.reshape((1, 3, 3)).repeat(self.pseudo_batch_size, 1, 1)
224
+ K[:, 0, 0] = self.get_fl_from_aov_rad(self.phi_dict_flat["aov"], d=2)
225
+ K[:, 1, 1] = self.get_fl_from_aov_rad(
226
+ self.phi_dict_flat["aov"], d=2 * self.image_width / self.image_height
227
+ )
228
+ return K
229
+
230
+ def construct_intrinsics_raster(self):
231
+ # assume that the principal point is (W/2,H/2)
232
+ K = torch.eye(3, requires_grad=False, device=self.device)
233
+ K = K.reshape((1, 3, 3)).repeat(self.pseudo_batch_size, 1, 1)
234
+ K[:, 0, 0] = self.get_fl_from_aov_rad(self.phi_dict_flat["aov"], d=self.image_width)
235
+ K[:, 1, 1] = self.get_fl_from_aov_rad(self.phi_dict_flat["aov"], d=self.image_width)
236
+ K[:, 0, 2] = self.principal_point[0]
237
+ K[:, 1, 2] = self.principal_point[1]
238
+ return K
239
+
240
+ def __str__(self) -> str:
241
+ return f"aov_deg={torch.rad2deg(self.phi_dict['aov'])}, t={torch.stack([self.phi_dict[k] for k in ['c_x', 'c_y', 'c_z']], dim=-1)}, pan_deg={torch.rad2deg(self.phi_dict['pan'])} tilt_deg={torch.rad2deg(self.phi_dict['tilt'])} roll_deg={torch.rad2deg(self.phi_dict['roll'])}"
242
+
243
+ def str_pan_tilt_roll_fl(self, b, t):
244
+ r = f"FOV={torch.rad2deg(self.phi_dict['aov'][b, t]):.1f}°, pan={torch.rad2deg(self.phi_dict['pan'][b, t]):.1f}° tilt={torch.rad2deg(self.phi_dict['tilt'][b, t]):.1f}° roll={torch.rad2deg(self.phi_dict['roll'][b, t]):.1f}°"
245
+ return r
246
+
247
+ def str_lens_distortion_coeff(self, b):
248
+ # TODO: T! also need individual lens_dist_coeff for each t in T
249
+ # print(self.lens_dist_coeff.shape)
250
+ return f"lens dist coeff=" + " ".join(
251
+ [f"{x:.2f}" for x in self.lens_dist_coeff[b, :2]]
252
+ ) # print only radial lens dist. coeff
253
+
254
+ def __repr__(self) -> str:
255
+ return f"{self.__class__}:" + self.__str__()
256
+
257
+ def __len__(self):
258
+ return self.pseudo_batch_size # e.g. self.intrinsics.shape[0]
259
+
260
+ def project_point2pixel(self, points3d: torch.tensor, lens_distortion: bool) -> torch.tensor:
261
+ """Project world coordinates to pixel coordinates.
262
+
263
+ Args:
264
+ points3d (torch.tensor): of shape (N, 3) or (1, N, 3)
265
+
266
+ Returns:
267
+ torch.tensor: projected points of shape (B, T, N, 2)
268
+ """
269
+ position = self.position.view(self.pseudo_batch_size, 1, 3)
270
+ point = points3d - position
271
+ rotated_point = self.rotation @ point.transpose(1, 2) # (pseudo_batch_size, 3, N)
272
+ dist_point2cam = rotated_point[:, 2] # (B, N) distance pixel to world point
273
+ dist_point2cam = dist_point2cam.view(self.pseudo_batch_size, 1, rotated_point.shape[-1])
274
+ rotated_point = rotated_point / dist_point2cam # (B, 3, N) / (B, 1, N) -> (B, 3, N)
275
+
276
+ projected_points = self.intrinsics_raster @ rotated_point # (B, 3, N)
277
+ # transpose vs view? here
278
+ projected_points = projected_points.transpose(-1, -2) # cannot use view()
279
+ projected_points = kornia.geometry.convert_points_from_homogeneous(projected_points)
280
+ if lens_distortion:
281
+ if self.psi is None:
282
+ raise RuntimeError("Lens distortion requested, but deactivated in module")
283
+ projected_points = self.distort_points(projected_points, self.intrinsics_raster)
284
+
285
+ # reshape back from (pseudo_batch_size, N, 2) to (B, T, N, 2)
286
+ projected_points = projected_points.view(
287
+ self.batch_dim, self.temporal_dim, projected_points.shape[-2], 2
288
+ )
289
+ if self.nan_check:
290
+ if torch.isnan(projected_points).any().item():
291
+ print(self.phi_dict_flat)
292
+ print(projected_points)
293
+ raise RuntimeWarning("NaN in project_point2pixel")
294
+ return projected_points
295
+
296
+ def project_point2ndc(self, points3d: torch.tensor, lens_distortion: bool) -> torch.tensor:
297
+ """Project world coordinates to normalized device coordinates (NDC).
298
+
299
+ Args:
300
+ points3d (torch.tensor): of shape (N, 3) or (1, N, 3)
301
+
302
+ Returns:
303
+ torch.tensor: projected points of shape (B, T, N, 2)
304
+ """
305
+ position = self.position.view(self.pseudo_batch_size, 1, 3)
306
+ point = points3d - position
307
+ rotated_point = self.rotation @ point.transpose(1, 2) # (pseudo_batch_size, 3, N)
308
+ dist_point2cam = rotated_point[:, 2] # (B, N) distance pixel to world point
309
+ dist_point2cam = dist_point2cam.view(self.pseudo_batch_size, 1, rotated_point.shape[-1])
310
+ rotated_point = rotated_point / dist_point2cam # (B, 3, N) / (B, 1, N) -> (B, 3, N)
311
+
312
+ projected_points = self.intrinsics_ndc @ rotated_point # (B, 3, N)
313
+ # transpose vs view? here
314
+ projected_points = projected_points.transpose(-1, -2) # cannot use view()
315
+ projected_points = kornia.geometry.convert_points_from_homogeneous(projected_points)
316
+ if self.nan_check:
317
+ if torch.isnan(projected_points).any().item():
318
+ print(projected_points)
319
+ print(self.phi_dict_flat)
320
+ print("lens distortion", self.lens_dist_coeff)
321
+
322
+ raise RuntimeWarning("NaN in project_point2ndc before distort")
323
+ if lens_distortion:
324
+ if self.psi is None:
325
+ raise RuntimeError("Lens distortion requested, but deactivated in module")
326
+ projected_points = self.distort_points(projected_points, self.intrinsics_ndc)
327
+
328
+ # reshape back from (pseudo_batch_size, N, 2) to (B, T, N, 2)
329
+ projected_points = projected_points.view(
330
+ self.batch_dim, self.temporal_dim, projected_points.shape[-2], 2
331
+ )
332
+ if self.nan_check:
333
+ if torch.isnan(projected_points).any().item():
334
+ print(self.phi_dict_flat)
335
+ print(projected_points)
336
+ raise RuntimeWarning("NaN in project_point2ndc after distort")
337
+ return projected_points
338
+
339
+ def project_point2pixel_from_P(
340
+ self, points3d: torch.tensor, lens_distortion: bool
341
+ ) -> torch.tensor:
342
+ """Project world coordinates to pixel coordinates from the projection matrix.
343
+
344
+ Args:
345
+ points3d (torch.tensor): of shape (1, N, 3)
346
+
347
+ Returns:
348
+ torch.tensor: projected points of shape (B, T, N, 2)
349
+ """
350
+
351
+ points3d = kornia.geometry.conversions.convert_points_to_homogeneous(points3d).transpose(
352
+ 1, 2
353
+ ) # (B, 4, N)
354
+ projected_points = torch.bmm(self.P_raster, points3d.repeat(self.pseudo_batch_size, 1, 1))
355
+ normalize_by = projected_points[:, -1].view(
356
+ self.pseudo_batch_size, 1, projected_points.shape[-1]
357
+ )
358
+ projected_points /= normalize_by
359
+ projected_points = projected_points.transpose(-1, -2) # cannot use view()
360
+ projected_points = kornia.geometry.convert_points_from_homogeneous(projected_points)
361
+ if lens_distortion:
362
+ if self.psi is None:
363
+ raise RuntimeError("Lens distortion requested, but deactivated in module")
364
+ projected_points = self.distort_points(projected_points, self.intrinsics_raster)
365
+ # reshape back from (pseudo_batch_size, N, 2) to (B, T, N, 2)
366
+ projected_points = projected_points.view(
367
+ self.batch_dim, self.temporal_dim, projected_points.shape[-2], 2
368
+ )
369
+ return projected_points # (B, T, N, 2)
370
+
371
+ def project_point2ndc_from_P(
372
+ self, points3d: torch.tensor, lens_distortion: bool
373
+ ) -> torch.tensor:
374
+ """Project world coordinates to normalized device coordinates (NDC) from the projection matrix.
375
+
376
+ Args:
377
+ points3d (torch.tensor): of shape (1, N, 3)
378
+
379
+ Returns:
380
+ torch.tensor: projected points of shape (B, T, N, 2)
381
+ """
382
+
383
+ points3d = kornia.geometry.conversions.convert_points_to_homogeneous(points3d).transpose(
384
+ 1, 2
385
+ ) # (B, 4, N)
386
+ projected_points = torch.bmm(self.P_ndc, points3d.repeat(self.pseudo_batch_size, 1, 1))
387
+ normalize_by = projected_points[:, -1].view(
388
+ self.pseudo_batch_size, 1, projected_points.shape[-1]
389
+ )
390
+ projected_points /= normalize_by
391
+ projected_points = projected_points.transpose(-1, -2) # cannot use view()
392
+ projected_points = kornia.geometry.convert_points_from_homogeneous(projected_points)
393
+ if lens_distortion:
394
+ if self.psi is None:
395
+ raise RuntimeError("Lens distortion requested, but deactivated in module")
396
+ projected_points = self.distort_points(projected_points, self.intrinsics_ndc)
397
+ # reshape back from (pseudo_batch_size, N, 2) to (B, T, N, 2)
398
+ projected_points = projected_points.view(
399
+ self.batch_dim, self.temporal_dim, projected_points.shape[-2], 2
400
+ )
401
+ return projected_points # (B, T, N, 2)
402
+
403
+ def rotation_from_euler_angles(self, pan, tilt, roll):
404
+ # rotation matrices from a batch of pan tilt roll [rad] vectors of shape (?, )
405
+
406
+ mask = (
407
+ torch.eye(3, requires_grad=False, device=self.device)
408
+ .reshape((1, 3, 3))
409
+ .repeat(pan.shape[0], 1, 1)
410
+ )
411
+ mask[:, 0, 0] = -torch.sin(pan) * torch.sin(roll) * torch.cos(tilt) + torch.cos(
412
+ pan
413
+ ) * torch.cos(roll)
414
+ mask[:, 0, 1] = torch.sin(pan) * torch.cos(roll) + torch.sin(roll) * torch.cos(
415
+ pan
416
+ ) * torch.cos(tilt)
417
+ mask[:, 0, 2] = torch.sin(roll) * torch.sin(tilt)
418
+
419
+ mask[:, 1, 0] = -torch.sin(pan) * torch.cos(roll) * torch.cos(tilt) - torch.sin(
420
+ roll
421
+ ) * torch.cos(pan)
422
+ mask[:, 1, 1] = -torch.sin(pan) * torch.sin(roll) + torch.cos(pan) * torch.cos(
423
+ roll
424
+ ) * torch.cos(tilt)
425
+ mask[:, 1, 2] = torch.sin(tilt) * torch.cos(roll)
426
+
427
+ mask[:, 2, 0] = torch.sin(pan) * torch.sin(tilt)
428
+ mask[:, 2, 1] = -torch.sin(tilt) * torch.cos(pan)
429
+ mask[:, 2, 2] = torch.cos(tilt)
430
+
431
+ return mask
432
+
433
+ def get_homography_raster(self):
434
+ return self.P_raster[:, :, [0, 1, 3]].inverse()
435
+
436
+ def get_rays_world(self, x):
437
+ """Compute world-space viewing rays for the given homogeneous image points.
438
+
439
+ Args:
440
+ x (torch.tensor): homogeneous image points of shape (B, 3, N)
441
+
442
+ Returns:
443
+ LineCollection: ray support points and normalized direction vectors in world space
444
+ """
445
+ raise NotImplementedError
446
+ # TODO: verify
447
+ # ray_cam_trans = torch.bmm(self.rotation.inverse(), torch.bmm(self.intrinsics.inverse(), x))
448
+ # # unnormalized direction vector in euclidean points (x,y,z) based on camera origin (0,0,0)
449
+ # ray_cam_trans = torch.nn.functional.normalize(ray_cam_trans, p=2, dim=1) # (B, 3, N)
450
+
451
+ # # shift support vector to origin in world space, i.e. the translation vector
452
+ # support = self.position.unsqueeze(-1).repeat(
453
+ # ray_cam_trans.shape[0], 1, ray_cam_trans.shape[2]
454
+ # ) # (B, 3, N)
455
+ # return LineCollection(support=support, direction_norm=ray_cam_trans)
456
+
457
+ @staticmethod
458
+ def get_aov_rad(d: float, fl: torch.tensor):
459
+ # https://en.wikipedia.org/wiki/Angle_of_view#Calculating_a_camera's_angle_of_view
460
+ return 2 * torch.arctan(d / (2 * fl)) # in range [0.0, PI]
461
+
462
+ @staticmethod
463
+ def get_fl_from_aov_rad(aov_rad: torch.tensor, d: float):
464
+ return 0.5 * d * (1 / torch.tan(0.5 * aov_rad))
465
+
466
+ def undistort_points(self, points_pixel: torch.tensor, intrinsics, num_iters=5) -> torch.tensor:
467
+ """Compensate a set of 2D image points for lens distortion.
468
+
469
+ Wrapper for kornia.geometry.undistort_points()
470
+
471
+ Args:
472
+ points_pixel (torch.tensor): tensor of shape (B, N, 2)
473
+
474
+ Returns:
475
+ torch.tensor: undistorted points of shape (B, N, 2)
476
+ """
477
+ # print(points_pixel.shape, intrinsics.shape, self.lens_dist_coeff.shape)
478
+ batch_dim, temporal_dim, N, _ = points_pixel.shape
479
+ points_pixel = points_pixel.view(batch_dim * temporal_dim, N, 2)
480
+ true_batch_size = batch_dim
481
+
482
+ lens_dist_coeff = self.lens_dist_coeff
483
+ if true_batch_size < self.batch_dim:
484
+ intrinsics = intrinsics[:true_batch_size]
485
+ lens_dist_coeff = lens_dist_coeff[:true_batch_size]
486
+
487
+ return kornia.geometry.undistort_points(
488
+ points_pixel, intrinsics, dist=lens_dist_coeff, num_iters=num_iters
489
+ ).view(batch_dim, temporal_dim, N, 2)
490
+
491
+ def distort_points(self, points_pixel: torch.tensor, intrinsics) -> torch.tensor:
492
+ """Distortion of a set of 2D points based on the lens distortion model.
493
+
494
+ Wrapper for kornia.geometry.distort_points()
495
+
496
+ Args:
497
+ points_pixel (torch.tensor): tensor of shape (B, N, 2)
498
+
499
+ Returns:
500
+ torch.tensor: distorted points of shape (B, N, 2)
501
+ """
502
+ return kornia.geometry.distort_points(points_pixel, intrinsics, dist=self.lens_dist_coeff)
503
+
504
+ def undistort_images(self, images):
505
+ # images of shape (B, T, C, H, W)
506
+ true_batch_size, T = images.shape[:2]
507
+ images = images.view(true_batch_size * T, 3, self.image_height, self.image_width).to(
508
+ self.device
509
+ )
510
+ intrinsics = self.intrinsics_raster
511
+ lens_dist_coeff = self.lens_dist_coeff
512
+ if true_batch_size < self.batch_dim:
513
+ intrinsics = intrinsics[:true_batch_size]
514
+ lens_dist_coeff = lens_dist_coeff[:true_batch_size]
515
+
516
+ return kornia.geometry.calibration.undistort_image(
517
+ images, intrinsics, lens_dist_coeff
518
+ ).view(true_batch_size, self.temporal_dim, 3, self.image_height, self.image_width)
519
+
520
+ def get_parameters(self, true_batch_size=None):
521
+ """
522
+ Get dict of relevant camera parameters and homography matrix
523
+ :return: The dictionary
524
+ """
525
+ out_dict = {
526
+ "pan_degrees": torch.rad2deg(self.phi_dict["pan"]),
527
+ "tilt_degrees": torch.rad2deg(self.phi_dict["tilt"]),
528
+ "roll_degrees": torch.rad2deg(self.phi_dict["roll"]),
529
+ "position_meters": torch.stack([self.phi_dict[k] for k in ["c_x", "c_y", "c_z"]], dim=1)
530
+ .squeeze(-1)
531
+ .unsqueeze(-2)
532
+ .repeat(1, self.temporal_dim, 1),
533
+ "aov_radian": self.phi_dict["aov"],
534
+ "aov_degrees": torch.rad2deg(self.phi_dict["aov"]),
535
+ "x_focal_length": self.get_fl_from_aov_rad(self.phi_dict["aov"], d=self.image_width),
536
+ "y_focal_length": self.get_fl_from_aov_rad(self.phi_dict["aov"], d=self.image_width),
537
+ "principal_point": torch.tensor(
538
+ [[self.principal_point] * self.temporal_dim] * self.batch_dim
539
+ ),
540
+ }
541
+ out_dict["homography"] = self.get_homography_raster().unsqueeze(1) # (B, 1, 3, 3)
542
+
543
+ # expected for SN evaluation
544
+ out_dict["radial_distortion"] = torch.zeros(self.batch_dim, self.temporal_dim, 6)
545
+ out_dict["tangential_distortion"] = torch.zeros(self.batch_dim, self.temporal_dim, 2)
546
+ out_dict["thin_prism_distortion"] = torch.zeros(self.batch_dim, self.temporal_dim, 4)
547
+
548
+ if self.psi is not None:
549
+ # in case only k1 and k2 are provided
550
+ out_dict["radial_distortion"][..., :2] = self.psi[..., :2]
551
+
552
+ if true_batch_size is None or true_batch_size == self.batch_dim:
553
+ return out_dict
554
+
555
+ for k in out_dict.keys():
556
+ out_dict[k] = out_dict[k][:true_batch_size]
557
+
558
+ return out_dict
559
+
560
+ @staticmethod
561
+ def static_undistort_points(points, cam):
562
+
563
+ intrinsics = cam.intrinsics_raster
564
+ lens_dist_coeff = cam.lens_dist_coeff
565
+
566
+ true_batch_size = points.shape[0]
567
+ if true_batch_size < cam.batch_dim:
568
+ intrinsics = intrinsics[:true_batch_size]
569
+ lens_dist_coeff = lens_dist_coeff[:true_batch_size]
570
+ # points in homogeneous coordinates
571
+ # (B, T, 3, S, N) -> (T, 3, S*N) -> (T, S*N, 3)
572
+ batch_size, T, _, S, N = points.shape
573
+ points = points.view(batch_size, T, 3, S * N).transpose(2, 3)
574
+ points[..., :2] = kornia.geometry.undistort_points(
575
+ points[..., :2].view(batch_size * T, S * N, 2),
576
+ intrinsics,
577
+ dist=lens_dist_coeff,
578
+ num_iters=1,
579
+ ).view(batch_size, T, S * N, 2)
580
+
581
+ # (T, S*N, 3) -> (T, 3, S*N) -> (B, T, 3, S, N)
582
+ points = points.transpose(2, 3).view(batch_size, T, 3, S, N)
583
+ return points
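
The intrinsics in `SNProjectiveCamera` are parameterized by the angle of view; a minimal sketch of the two static conversion helpers above (the numeric values are illustrative, not taken from the repository):

```python
# Sketch only: angle of view <-> focal length, fl = 0.5 * d / tan(aov / 2).
import torch
from tvcalib.cam_modules import SNProjectiveCamera

aov = torch.deg2rad(torch.tensor([60.0]))                    # horizontal angle of view
fl = SNProjectiveCamera.get_fl_from_aov_rad(aov, d=1280.0)   # ~1108.5 px for a 1280 px wide image
aov_back = SNProjectiveCamera.get_aov_rad(d=1280.0, fl=fl)   # recovers the original angle
print(fl.item(), torch.rad2deg(aov_back).item())
```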
tvcalib/data/dataset.py ADDED
@@ -0,0 +1,142 @@
1
+ import kornia
2
+ import torch
3
+ import random
4
+ import numpy as np
5
+ from kornia.geometry.transform import resize
6
+ from .utils import split_circle_central
7
+
8
+
9
+
10
+ class InferenceDatasetCalibration(torch.utils.data.Dataset):
11
+ def __init__(self, keypoints_raw, image_width_source, image_height_source, object3d) -> None:
12
+ super().__init__()
13
+ self.keypoints_raw = keypoints_raw
14
+ self.w = image_width_source
15
+ self.h = image_height_source
16
+ self.object3d = object3d
17
+ self.split_circle_central = True
18
+
19
+ def __getitem__(self, idx):
20
+
21
+ keypoints_dict = self.keypoints_raw[idx]
22
+ if self.split_circle_central:
23
+ keypoints_dict = split_circle_central(keypoints_dict)
24
+ # add empty entries for non-visible segments
25
+ for l in self.object3d.segment_names:
26
+ if l not in keypoints_dict:
27
+ keypoints_dict[l] = []
28
+
29
+ per_sample_output = self.prepare_per_sample(
30
+ keypoints_dict, self.object3d, 4, 8, self.w, self.h, pad_pixel_position_xy=0.0
31
+ )
32
+ for k in per_sample_output.keys():
33
+ per_sample_output[k] = per_sample_output[k].unsqueeze(0)
34
+
35
+ return per_sample_output
36
+
37
+ def __len__(self):
38
+ return len(self.keypoints_raw)
39
+
40
+ @staticmethod
41
+ def prepare_per_sample(
42
+ keypoints_raw: dict,
43
+ model3d,
44
+ num_points_on_line_segments: int,
45
+ num_points_on_circle_segments: int,
46
+ image_width_source: int,
47
+ image_height_source: int,
48
+ pad_pixel_position_xy=0.0,
49
+ ):
50
+ r = {}
51
+ pixel_stacked = {}
52
+ for label, points in keypoints_raw.items():
53
+
54
+ num_points_selection = num_points_on_line_segments
55
+ if "Circle" in label:
56
+ num_points_selection = num_points_on_circle_segments
57
+
58
+ # rand select num_points_selection
59
+ if num_points_selection > len(points):
60
+ points_sel = points
61
+ else:
62
+ # random sample without replacement
63
+ points_sel = random.sample(points, k=num_points_selection)
64
+
65
+ if len(points_sel) > 0:
66
+ xx = torch.tensor([a["x"] for a in points_sel])
67
+ yy = torch.tensor([a["y"] for a in points_sel])
68
+ pixel_stacked[label] = torch.stack([xx, yy], dim=-1) # (?, 2)
69
+ # scale pixel annotations from [0, 1] range to source image resolution
70
+ # pixel indices range over [0, {image_width, image_height} - 1], hence the shift by one
71
+ pixel_stacked[label][:, 0] = pixel_stacked[label][:, 0] * (image_width_source - 1)
72
+ pixel_stacked[label][:, 1] = pixel_stacked[label][:, 1] * (image_height_source - 1)
73
+
74
+ for segment_type, num_segments, segment_names in [
75
+ ("lines", model3d.line_segments.shape[1], model3d.line_segments_names),
76
+ ("circles", model3d.circle_segments.shape[1], model3d.circle_segments_names),
77
+ ]:
78
+
79
+ num_points_selection = num_points_on_line_segments
80
+ if segment_type == "circles":
81
+ num_points_selection = num_points_on_circle_segments
82
+ px_projected_selection = (
83
+ torch.zeros((num_segments, num_points_selection, 2)) + pad_pixel_position_xy
84
+ )
85
+ for segment_index, label in enumerate(segment_names):
86
+ if label in pixel_stacked:
87
+ # set annotations to first positions
88
+ px_projected_selection[
89
+ segment_index, : pixel_stacked[label].shape[0], :
90
+ ] = pixel_stacked[label]
91
+
92
+ randperm = torch.randperm(num_points_selection)
93
+ px_projected_selection_shuffled = px_projected_selection.clone()
94
+ px_projected_selection_shuffled[:, :, 0] = px_projected_selection_shuffled[
95
+ :, randperm, 0
96
+ ]
97
+ px_projected_selection_shuffled[:, :, 1] = px_projected_selection_shuffled[
98
+ :, randperm, 1
99
+ ]
100
+
101
+ is_keypoint_mask = (
102
+ (0.0 <= px_projected_selection_shuffled[:, :, 0])
103
+ & (px_projected_selection_shuffled[:, :, 0] < image_width_source)
104
+ ) & (
105
+ (0 < px_projected_selection_shuffled[:, :, 1])
106
+ & (px_projected_selection_shuffled[:, :, 1] < image_height_source)
107
+ )
108
+
109
+ r[f"{segment_type}__is_keypoint_mask"] = is_keypoint_mask.unsqueeze(0)
110
+
111
+ # reshape from (num_segments, num_points_selection, 2) to (3, num_segments, num_points_selection)
112
+ px_projected_selection_shuffled = (
113
+ kornia.geometry.conversions.convert_points_to_homogeneous(
114
+ px_projected_selection_shuffled
115
+ )
116
+ )
117
+ px_projected_selection_shuffled = px_projected_selection_shuffled.view(
118
+ num_segments * num_points_selection, 3
119
+ )
120
+ px_projected_selection_shuffled = px_projected_selection_shuffled.transpose(0, 1)
121
+ px_projected_selection_shuffled = px_projected_selection_shuffled.view(
122
+ 3, num_segments, num_points_selection
123
+ )
124
+ # (3, num_segments, num_points_selection)
125
+ r[f"{segment_type}__px_projected_selection_shuffled"] = px_projected_selection_shuffled
126
+
127
+ ndc_projected_selection_shuffled = px_projected_selection_shuffled.clone()
128
+ ndc_projected_selection_shuffled[0] = (
129
+ ndc_projected_selection_shuffled[0] / image_width_source
130
+ )
131
+ ndc_projected_selection_shuffled[1] = (
132
+ ndc_projected_selection_shuffled[1] / image_height_source
133
+ )
134
+ ndc_projected_selection_shuffled[1] = ndc_projected_selection_shuffled[1] * 2.0 - 1
135
+ ndc_projected_selection_shuffled[0] = ndc_projected_selection_shuffled[0] * 2.0 - 1
136
+ r[
137
+ f"{segment_type}__ndc_projected_selection_shuffled"
138
+ ] = ndc_projected_selection_shuffled
139
+
140
+ return r
141
+
142
+
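
A minimal sketch of the keypoint layout `InferenceDatasetCalibration` expects: normalized [0, 1] coordinates grouped by segment name (the sample values below are made up):

```python
# Sketch only: wrap hand-written keypoints for a 1280x720 frame.
from tvcalib.data.dataset import InferenceDatasetCalibration
from tvcalib.utils.objects_3d import SoccerPitchLineCircleSegments, SoccerPitchSNCircleCentralSplit

object3d = SoccerPitchLineCircleSegments(device="cpu", base_field=SoccerPitchSNCircleCentralSplit())
keypoints = {
    "Middle line": [{"x": 0.48, "y": 0.10}, {"x": 0.52, "y": 0.90}],
    "Circle central": [{"x": 0.45, "y": 0.50}, {"x": 0.55, "y": 0.50}],
}
ds = InferenceDatasetCalibration([keypoints], 1280, 720, object3d)
sample = ds[0]  # dict of per-segment tensors, each with a leading batch dimension
print(sample["lines__px_projected_selection_shuffled"].shape)
```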
tvcalib/data/utils.py ADDED
@@ -0,0 +1,166 @@
1
+
2
+ from operator import itemgetter
3
+ import torch
4
+ import re
5
+ import collections
6
+
7
+
8
+ string_classes=str
9
+
10
+
11
+ def split_circle_central(keypoints_dict):
12
+ # split "Circle central" into "Circle central left" and "Circle central right"
13
+
14
+ # assume main camera --> TODO behind the goal camera
15
+ if "Circle central" in keypoints_dict:
16
+ points_circle_central_left = []
17
+ points_circle_central_right = []
18
+
19
+ if "Middle line" in keypoints_dict:
20
+ p_index_ymin, _ = min(
21
+ enumerate([p["y"] for p in keypoints_dict["Middle line"]]),
22
+ key=itemgetter(1),
23
+ )
24
+ p_index_ymax, _ = max(
25
+ enumerate([p["y"] for p in keypoints_dict["Middle line"]]),
26
+ key=itemgetter(1),
27
+ )
28
+ p_ymin = keypoints_dict["Middle line"][p_index_ymin]
29
+ p_ymax = keypoints_dict["Middle line"][p_index_ymax]
30
+ p_xmean = (p_ymin["x"] + p_ymax["x"]) / 2
31
+
32
+ points_circle_central = keypoints_dict["Circle central"]
33
+ for p in points_circle_central:
34
+ if p["x"] < p_xmean:
35
+ points_circle_central_left.append(p)
36
+ else:
37
+ points_circle_central_right.append(p)
38
+ else:
39
+ # circle is partly shown on the left or right side of the image
40
+ # if the mean position lies in the left part of the image --> label it right
41
+ circle_x = [p["x"] for p in keypoints_dict["Circle central"]]
42
+ mean_x_circle = sum(circle_x) / len(circle_x)
43
+ if mean_x_circle < 0.5:
44
+ points_circle_central_right = keypoints_dict["Circle central"]
45
+ else:
46
+ points_circle_central_left = keypoints_dict["Circle central"]
47
+
48
+ if len(points_circle_central_left) > 0:
49
+ keypoints_dict["Circle central left"] = points_circle_central_left
50
+ if len(points_circle_central_right) > 0:
51
+ keypoints_dict["Circle central right"] = points_circle_central_right
52
+ if len(points_circle_central_left) == 0 and len(points_circle_central_right) == 0:
53
+ raise RuntimeError
54
+ del keypoints_dict["Circle central"]
55
+ return keypoints_dict
56
+
57
+
58
+ def custom_list_collate(batch):
59
+ r"""
60
+ Function that takes in a batch of data and puts the elements within the batch
61
+ into a tensor with an additional outer dimension - batch size. The exact output type can be
62
+ a :class:`torch.Tensor`, a `Sequence` of :class:`torch.Tensor`, a
63
+ Collection of :class:`torch.Tensor`, or left unchanged, depending on the input type.
64
+ This is used as the default function for collation when
65
+ `batch_size` or `batch_sampler` is defined in :class:`~torch.utils.data.DataLoader`.
66
+ Here is the general input type (based on the type of the element within the batch) to output type mapping:
67
+ * :class:`torch.Tensor` -> :class:`torch.Tensor` (with an added outer dimension batch size)
68
+ * NumPy Arrays -> :class:`torch.Tensor`
69
+ * `float` -> :class:`torch.Tensor`
70
+ * `int` -> :class:`torch.Tensor`
71
+ * `str` -> `str` (unchanged)
72
+ * `bytes` -> `bytes` (unchanged)
73
+ * `Mapping[K, V_i]` -> `Mapping[K, default_collate([V_1, V_2, ...])]`
74
+ * `NamedTuple[V1_i, V2_i, ...]` -> `NamedTuple[default_collate([V1_1, V1_2, ...]), default_collate([V2_1, V2_2, ...]), ...]`
75
+ * `Sequence[V1_i, V2_i, ...]` -> `Sequence[default_collate([V1_1, V1_2, ...]), default_collate([V2_1, V2_2, ...]), ...]`
76
+ Args:
77
+ batch: a single batch to be collated
78
+ Examples:
79
+ >>> # Example with a batch of `int`s:
80
+ >>> default_collate([0, 1, 2, 3])
81
+ tensor([0, 1, 2, 3])
82
+ >>> # Example with a batch of `str`s:
83
+ >>> default_collate(['a', 'b', 'c'])
84
+ ['a', 'b', 'c']
85
+ >>> # Example with `Map` inside the batch:
86
+ >>> default_collate([{'A': 0, 'B': 1}, {'A': 100, 'B': 100}])
87
+ {'A': tensor([ 0, 100]), 'B': tensor([ 1, 100])}
88
+ >>> # Example with `NamedTuple` inside the batch:
89
+ >>> Point = namedtuple('Point', ['x', 'y'])
90
+ >>> default_collate([Point(0, 0), Point(1, 1)])
91
+ Point(x=tensor([0, 1]), y=tensor([0, 1]))
92
+ >>> # Example with `Tuple` inside the batch:
93
+ >>> default_collate([(0, 1), (2, 3)])
94
+ [tensor([0, 2]), tensor([1, 3])]
95
+
96
+ >>> # modification
97
+ >>> # Example with `List` inside the batch:
98
+ >>> default_collate([[0, 1, 2], [2, 3, 4]])
99
+ >>> [[0, 1, 2], [2, 3, 4]]
100
+ >>> # original behavior
101
+ >>> [[0, 2], [1, 3], [2, 4]]
102
+ """
103
+
104
+ np_str_obj_array_pattern = re.compile(r"[SaUO]")
105
+ default_collate_err_msg_format = "default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found {}"
106
+
107
+ elem = batch[0]
108
+ elem_type = type(elem)
109
+ if isinstance(elem, torch.Tensor):
110
+ out = None
111
+ if torch.utils.data.get_worker_info() is not None:
112
+ # If we're in a background process, concatenate directly into a
113
+ # shared memory tensor to avoid an extra copy
114
+ numel = sum(x.numel() for x in batch)
115
+ storage = elem.storage()._new_shared(numel)
116
+ out = elem.new(storage).resize_(len(batch), *list(elem.size()))
117
+ return torch.stack(batch, 0, out=out)
118
+ elif (
119
+ elem_type.__module__ == "numpy"
120
+ and elem_type.__name__ != "str_"
121
+ and elem_type.__name__ != "string_"
122
+ ):
123
+ if elem_type.__name__ == "ndarray" or elem_type.__name__ == "memmap":
124
+ # array of string classes and object
125
+ if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
126
+ raise TypeError(default_collate_err_msg_format.format(elem.dtype))
127
+
128
+ return [torch.as_tensor(b) for b in batch]
129
+ elif elem.shape == (): # scalars
130
+ return torch.as_tensor(batch)
131
+ elif isinstance(elem, float):
132
+ return torch.tensor(batch, dtype=torch.float64)
133
+ elif isinstance(elem, int):
134
+ return torch.tensor(batch)
135
+ elif isinstance(elem, string_classes):
136
+ return batch
137
+ elif isinstance(elem, collections.abc.Mapping):
138
+ try:
139
+ return elem_type({key: custom_list_collate([d[key] for d in batch]) for key in elem})
140
+ except TypeError:
141
+ # The mapping type may not support `__init__(iterable)`.
142
+ return {key: custom_list_collate([d[key] for d in batch]) for key in elem}
143
+ elif isinstance(elem, tuple) and hasattr(elem, "_fields"): # namedtuple
144
+ return elem_type(*(custom_list_collate(samples) for samples in zip(*batch)))
145
+ elif isinstance(elem, collections.abc.Sequence):
146
+ # check to make sure that the elements in batch have consistent size
147
+ it = iter(batch)
148
+ elem_size = len(next(it))
149
+ if not all(len(elem) == elem_size for elem in it):
150
+ raise RuntimeError("each element in list of batch should be of equal size")
151
+ # transposed = list(zip(*batch)) # It may be accessed twice, so we use a list.
152
+
153
+ return batch
154
+
155
+ # if isinstance(elem, tuple):
156
+ # return [
157
+ # custom_list_collate(samples) for samples in transposed
158
+ # ] # Backwards compatibility.
159
+ # else:
160
+ # try:
161
+ # return elem_type([custom_list_collate(samples) for samples in transposed])
162
+ # except TypeError:
163
+ # # The sequence type may not support `__init__(iterable)` (e.g., `range`).
164
+ # return [custom_list_collate(samples) for samples in transposed]
165
+
166
+ raise TypeError(default_collate_err_msg_format.format(elem_type))
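
The docstring above already shows the key difference from `default_collate`; a minimal sketch of using it as a `DataLoader` collate function:

```python
# Sketch only: per-sample Python lists are kept as lists instead of being transposed.
from torch.utils.data import DataLoader
from tvcalib.data.utils import custom_list_collate

samples = [{"id": 0, "points": [1, 2, 3]}, {"id": 1, "points": [4, 5, 6]}]
loader = DataLoader(samples, batch_size=2, collate_fn=custom_list_collate)
batch = next(iter(loader))
print(batch["id"])      # tensor([0, 1])
print(batch["points"])  # [[1, 2, 3], [4, 5, 6]]
```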
tvcalib/infer/module.py ADDED
@@ -0,0 +1,518 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from functools import partial
4
+
5
+ from pathlib import Path
6
+ from typing import Any, Dict, Tuple
7
+ # Imports from the common package (assumed to live at the same level as tvcalib)
8
+ from common.infer.base import *
9
+ # from common.registry import Registry # Still commented out because its source is unknown
10
+ # from common.utils import to_cuda # Still commented out because its source is unknown
11
+ # import project as p # Removed because it is probably tied to the full project
12
+
13
+ import torchvision.transforms as T
14
+ # Relative imports inside tvcalib (kept relative)
15
+ from ..sn_segmentation.src.custom_extremities import (
16
+ generate_class_synthesis, get_line_extremities
17
+ )
18
+ from ..models.segmentation import InferenceSegmentationModel
19
+ from ..data.dataset import InferenceDatasetCalibration
20
+ from ..data.utils import custom_list_collate
21
+ from ..cam_modules import CameraParameterWLensDistDictZScore, SNProjectiveCamera
22
+ from ..utils.linalg import distance_line_pointcloud_3d, distance_point_pointcloud
23
+ from ..utils.objects_3d import SoccerPitchLineCircleSegments, SoccerPitchSNCircleCentralSplit
24
+ from ..cam_distr.tv_main_center import get_cam_distr, get_dist_distr
25
+ from ..utils.io import detach_dict, tensor2list
26
+ # Import from the common package
27
+ from common.data.utils import yards
28
+
29
+ from kornia.geometry.conversions import convert_points_to_homogeneous
30
+ from tqdm.auto import tqdm
31
+
32
+ # Commented out: tied to the 'robust' method and may pull in extra dependencies
33
+ # from methods.robust.loggers.preview import RobustPreviewLogger
34
+
35
+ import numpy as np
36
+
37
+
38
+
39
+ class TvCalibInferModule(InferModule):
40
+ def __init__(
41
+ self,
42
+ segmentation_checkpoint: Path,
43
+ image_shape=(720,1280),
44
+ optim_steps=2000,
45
+ lens_dist: bool=False,
46
+ playfield_size=(105, 68),
47
+ make_images: bool=False
48
+
49
+ ):
50
+ self.image_shape = image_shape
51
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
52
+ self.make_images = make_images
53
+
54
+ # We use the logger to draw visualizations
55
+ # Commented out because the RobustPreviewLogger class is commented out
56
+ # self.previewer = RobustPreviewLogger(
57
+ # None, num_images=1
58
+ # )
59
+
60
+ self.fn_generate_class_synthesis = partial(
61
+ generate_class_synthesis,
62
+ radius=4
63
+ )
64
+ self.fn_get_line_extremities = partial(
65
+ get_line_extremities,
66
+ maxdist=30,
67
+ width=455,
68
+ height=256,
69
+ num_points_lines=4,
70
+ num_points_circles=8
71
+ )
72
+
73
+ # Segmentation model
74
+ self.model_seg = InferenceSegmentationModel(
75
+ segmentation_checkpoint,
76
+ self.device
77
+ )
78
+
79
+ self.object3d = SoccerPitchLineCircleSegments(
80
+ device=self.device,
81
+ base_field=SoccerPitchSNCircleCentralSplit()
82
+ )
83
+ self.object3dcpu = SoccerPitchLineCircleSegments(
84
+ device="cpu",
85
+ base_field=SoccerPitchSNCircleCentralSplit()
86
+ )
87
+
88
+ # Calibration module
89
+ batch_size_calib = 1
90
+ self.model_calib = TVCalibModule(
91
+ self.object3d,
92
+ get_cam_distr(1.96, batch_size_calib, 1),
93
+ get_dist_distr(batch_size_calib, 1) if lens_dist else None,
94
+ (image_shape[0], image_shape[1]),
95
+ optim_steps,
96
+ self.device,
97
+ log_per_step=False,
98
+ tqdm_kwqargs=None,
99
+ )
100
+ self.resize = T.Compose([
101
+ T.Resize(size=(256,455))
102
+ ])
103
+ self.offset = np.array([
104
+ [1, 0, playfield_size[0]/2.0 ],
105
+ [0, 1, playfield_size[1]/2.0 ],
106
+ [0, 0, 1]
107
+ ])
108
+
109
+
110
+
111
+ def setup(self, datamodule: InferDataModule):
112
+ pass
113
+
114
+
115
+ def predict(self, x: Any) -> Dict:
116
+
117
+ """
118
+ 1. Run segmentation & Pick keypoints
119
+ 2. Calibrate based on selected points
120
+ """
121
+
122
+ # Segment
123
+ image = x["image"]
124
+ keypoints = self._segment(x["image"])
125
+
126
+ # Calibrate
127
+ homo = self._calibrate(keypoints)
128
+
129
+ # Rescale to 720p
130
+ # image_720p = self.previewer.to_image(image.clone().detach().cpu())  # disabled: self.previewer is commented out above
131
+
132
+ # Draw predicted playing field
133
+ if (homo is not None):
134
+ # to yards
135
+ # Commented out because the previewer is commented out
136
+ # to_yards = np.array([
137
+ # [ yards(1.0), 0, 0 ],
138
+ # [ 0, yards(1.0), 0 ],
139
+ # [ 0, 0, 1]
140
+ # ])
141
+ #homo = to_yards @ homo
142
+
143
+ # Commented out because the previewer is commented out
144
+ # try:
145
+ # inv_homo = np.linalg.inv(homo) @ self.previewer.scale
146
+ # image_720p = self.previewer.draw_playfield(
147
+ # image_720p,
148
+ # self.previewer.image_playfield,
149
+ # inv_homo,
150
+ # color=(255,0,0), alpha=1.0,
151
+ # flip=False
152
+ # )
153
+ # except:
154
+ # # Homography might
155
+ # pass
156
+ pass # Placeholder: the homography exists but the previewer is commented out
157
+
158
+ result = {
159
+ "homography": homo
160
+ }
161
+
162
+ if (self.make_images):
163
+ # result["image_720p"] = image_720p # Commented out: image_720p is not modified without the previewer
164
+ pass # Placeholder when make_images is True
165
+
166
+ return result
167
+
168
+
169
+ def _segment(self, image):
170
+
171
+ # Image -> <1;3;256;455>
172
+ image = self.resize(image)
173
+ with torch.no_grad():
174
+ sem_lines = self.model_seg.inference(
175
+ image.unsqueeze(0).to(self.device)
176
+ )
177
+ # <B;256;455>
178
+ sem_lines = sem_lines.detach().cpu().numpy().astype(np.uint8)
179
+
180
+ # Point selection
181
+ skeletons_batch = self.fn_generate_class_synthesis(sem_lines[0])
182
+ keypoints_raw_batch = self.fn_get_line_extremities(skeletons_batch)
183
+
184
+ # Return the keypoints
185
+ return keypoints_raw_batch
186
+
187
+
188
+ def _calibrate(self, keypoints):
189
+
190
+ # Just wrap around the keypoints
191
+ ds = InferenceDatasetCalibration(
192
+ [keypoints],
193
+ self.image_shape[1], self.image_shape[0],
194
+ self.object3d
195
+ )
196
+
197
+ # Get the first item and optimize it
198
+ _batch_size = 1
199
+ x_dict = custom_list_collate([ds[0]])
200
+ try:
201
+ # previous_params handling is done inside self_optim_batch
202
+ per_sample_loss, cam, _ = self.model_calib.self_optim_batch(x_dict)
203
+ output_dict = tensor2list(
204
+ detach_dict({**cam.get_parameters(_batch_size), **per_sample_loss})
205
+ )
206
+
207
+ homo = output_dict["homography"][0]
208
+ if (len(homo) > 0):
209
+ homo = np.array(homo[0])
210
+
211
+ to_yards = np.array([
212
+ [ yards(1), 0, 0 ],
213
+ [ 0, yards(1), 0 ],
214
+ [ 0, 0, 1]
215
+ ])
216
+
217
+ # Shift the homography by half the playing field
218
+ homo = to_yards @ self.offset @ homo
219
+
220
+ else:
221
+ homo = None
222
+ except Exception as e:
223
+ print(f"Error during calibration: {str(e)}")
224
+ homo = None
225
+
226
+ return homo
227
+
228
+
229
+
230
+
231
+
232
+
233
+
234
+ class TVCalibModule(torch.nn.Module):
235
+ def __init__(
236
+ self,
237
+ model3d,
238
+ cam_distr,
239
+ dist_distr,
240
+ image_dim: Tuple[int, int],
241
+ optim_steps: int,
242
+ device="cpu",
243
+ tqdm_kwqargs=None,
244
+ log_per_step=False,
245
+ *args,
246
+ **kwargs,
247
+ ) -> None:
248
+ super().__init__(*args, **kwargs)
249
+ self.image_height, self.image_width = image_dim
250
+ self.principal_point = (self.image_width / 2, self.image_height / 2)
251
+ self.model3d = model3d
252
+ self.cam_param_dict = CameraParameterWLensDistDictZScore(
253
+ cam_distr, dist_distr, device=device
254
+ )
255
+
256
+ self.lens_distortion_active = False if dist_distr is None else True
257
+ self.optim_steps = optim_steps
258
+ self._device = device
259
+
260
+ # Attribute to store the previous parameters
261
+ self.previous_params = None
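+ # (warm-starting from the previous frame's solution assumes temporal coherence
+ # between consecutive frames of the same camera)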
262
+
263
+ self.optim = torch.optim.AdamW(
264
+ self.cam_param_dict.param_dict.parameters(), lr=0.1, weight_decay=0.01
265
+ )
266
+ self.Scheduler = partial(
267
+ torch.optim.lr_scheduler.OneCycleLR,
268
+ max_lr=0.05,
269
+ total_steps=self.optim_steps,
270
+ pct_start=0.5,
271
+ )
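+ # OneCycleLR ramps the learning rate up to max_lr over the first half of
+ # optim_steps (pct_start=0.5) and anneals it afterwards; a fresh scheduler
+ # instance is created for every batch in self_optim_batch.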
272
+
273
+ if self.lens_distortion_active:
274
+ self.optim_lens_distortion = torch.optim.AdamW(
275
+ self.cam_param_dict.param_dict_dist.parameters(), lr=1e-3, weight_decay=0.01
276
+ )
277
+ self.Scheduler_lens_distortion = partial(
278
+ torch.optim.lr_scheduler.OneCycleLR,
279
+ max_lr=1e-3,
280
+ total_steps=self.optim_steps,
281
+ pct_start=0.33,
282
+ optimizer=self.optim_lens_distortion,
283
+ )
284
+
285
+ self.tqdm_kwqargs = tqdm_kwqargs
286
+ if tqdm_kwqargs is None:
287
+ self.tqdm_kwqargs = {}
288
+
289
+ self.hparams = {"optim": str(self.optim), "scheduler": str(self.Scheduler)}
290
+ self.log_per_step = log_per_step
291
+
292
+ def forward(self, x):
293
+
294
+ # individual camera parameters & distortion parameters
295
+ phi_hat, psi_hat = self.cam_param_dict()
296
+
297
+ cam = SNProjectiveCamera(
298
+ phi_hat,
299
+ psi_hat,
300
+ self.principal_point,
301
+ self.image_width,
302
+ self.image_height,
303
+ device=self._device,
304
+ nan_check=False,
305
+ )
306
+
307
+ # (batch_size, num_views_per_cam, 3, num_segments, num_points)
308
+ points_px_lines_true = x["lines__ndc_projected_selection_shuffled"].to(self._device)
309
+ batch_size, T_l, _, S_l, N_l = points_px_lines_true.shape
310
+
311
+ # project circle points
312
+ points_px_circles_true = x["circles__ndc_projected_selection_shuffled"].to(self._device)
313
+ _, T_c, _, S_c, N_c = points_px_circles_true.shape
314
+ assert T_c == T_l
315
+
316
+ #################### line-to-point distance at pixel space ####################
317
+ # start and end point (in world coordinates) for each line segment
318
+ points3d_lines_keypoints = self.model3d.line_segments # (3, S_l, 2) to (S_l * 2, 3)
319
+ points3d_lines_keypoints = points3d_lines_keypoints.reshape(3, S_l * 2).transpose(0, 1)
320
+ points_px_lines_keypoints = convert_points_to_homogeneous(
321
+ cam.project_point2ndc(points3d_lines_keypoints, lens_distortion=False)
322
+ ) # (batch_size, t_l, S_l*2, 3)
323
+
324
+ if batch_size < cam.batch_dim: # actual batch_size smaller than expected, i.e. last batch
325
+ points_px_lines_keypoints = points_px_lines_keypoints[:batch_size]
326
+
327
+ points_px_lines_keypoints = points_px_lines_keypoints.view(batch_size, T_l, S_l, 2, 3)
328
+
329
+ lp1 = points_px_lines_keypoints[..., 0, :].unsqueeze(-2) # -> (batch_size, T_l, 1, S_l, 3)
330
+ lp2 = points_px_lines_keypoints[..., 1, :].unsqueeze(-2) # -> (batch_size, T_l, 1, S_l, 3)
331
+ # (batch_size, T, 3, S, N) -> (batch_size, T, 3, S*N) -> (batch_size, T, S*N, 3) -> (batch_size, T, S, N, 3)
332
+ pc = (
333
+ points_px_lines_true.view(batch_size, T_l, 3, S_l * N_l)
334
+ .transpose(2, 3)
335
+ .view(batch_size, T_l, S_l, N_l, 3)
336
+ )
337
+
338
+ if self.lens_distortion_active:
339
+ # undistort given points
340
+ pc = pc.view(batch_size, T_l, S_l * N_l, 3)
341
+ pc = pc.detach().clone()
342
+ pc[..., :2] = cam.undistort_points(
343
+ pc[..., :2], cam.intrinsics_ndc, num_iters=1
344
+ ) # num_iters=1 might be enough for a good approximation
345
+ pc = pc.view(batch_size, T_l, S_l, N_l, 3)
346
+
347
+ distances_px_lines_raw = distance_line_pointcloud_3d(
348
+ e1=lp2 - lp1, r1=lp1, pc=pc, reduce=None
349
+ ) # (batch_size, T_l, S_l, N_l)
350
+ distances_px_lines_raw = distances_px_lines_raw.unsqueeze(-3)
351
+ # (..., 1, S_l, N_l,), i.e. (batch_size, T, 1, S_l, N_l)
352
+ #################### circle-to-point distance at pixel space ####################
353
+
354
+ # circle segments are approximated as point clouds of size N_c_star
355
+ points3d_circles_pc = self.model3d.circle_segments
356
+ _, S_c, N_c_star = points3d_circles_pc.shape
357
+ points3d_circles_pc = points3d_circles_pc.reshape(3, S_c * N_c_star).transpose(0, 1)
358
+ points_px_circles_pc = cam.project_point2ndc(points3d_circles_pc, lens_distortion=False)
359
+
360
+ if batch_size < cam.batch_dim: # actual batch_size smaller than expected, i.e. last batch
361
+ points_px_circles_pc = points_px_circles_pc[:batch_size]
362
+
363
+ if self.lens_distortion_active:
364
+ # (batch_size, T_c, _, S_c, N_c)
365
+ points_px_circles_true = points_px_circles_true.view(
366
+ batch_size, T_c, 3, S_c * N_c
367
+ ).transpose(2, 3)
368
+ points_px_circles_true = points_px_circles_true.detach().clone()
369
+ points_px_circles_true[..., :2] = cam.undistort_points(
370
+ points_px_circles_true[..., :2], cam.intrinsics_ndc, num_iters=1
371
+ )
372
+ points_px_circles_true = points_px_circles_true.transpose(2, 3).view(
373
+ batch_size, T_c, 3, S_c, N_c
374
+ )
375
+
376
+ distances_px_circles_raw = distance_point_pointcloud(
377
+ points_px_circles_true, points_px_circles_pc.view(batch_size, T_c, S_c, N_c_star, 2)
378
+ )
379
+
380
+ distances_dict = {
381
+ "loss_ndc_lines": distances_px_lines_raw, # (batch_size, T_l, 1, S_l, N_l)
382
+ "loss_ndc_circles": distances_px_circles_raw, # (batch_size, T_c, 1, S_c, N_c)
383
+ }
384
+ return distances_dict, cam
385
+
386
+ def self_optim_batch(self, x, *args, **kwargs):
387
+
388
+ scheduler = self.Scheduler(self.optim) # re-initialize lr scheduler for every batch
389
+ if self.lens_distortion_active:
390
+ scheduler_lens_distortion = self.Scheduler_lens_distortion()
391
+
392
+ # Initialize from the previous frame's parameters when available
393
+ if self.previous_params is not None:
394
+ print("Using previous parameters for initialization")
395
+ update_dict = {}
396
+ for k, v in self.previous_params.items():
397
+ update_dict[k] = v.detach().clone()
398
+ self.cam_param_dict.initialize(update_dict)
399
+ else:
400
+ print("First frame: initializing from zero")
401
+ self.cam_param_dict.initialize(None)
402
+
403
+ self.optim.zero_grad()
404
+ if self.lens_distortion_active:
405
+ self.optim_lens_distortion.zero_grad()
406
+
407
+ keypoint_masks = {
408
+ "loss_ndc_lines": x["lines__is_keypoint_mask"].to(self._device),
409
+ "loss_ndc_circles": x["circles__is_keypoint_mask"].to(self._device),
410
+ }
411
+ num_actual_points = {
412
+ "loss_ndc_circles": keypoint_masks["loss_ndc_circles"].sum(dim=(-1, -2)),
413
+ "loss_ndc_lines": keypoint_masks["loss_ndc_lines"].sum(dim=(-1, -2)),
414
+ }
415
+
416
+ per_sample_loss = {}
417
+ per_sample_loss["mask_lines"] = keypoint_masks["loss_ndc_lines"]
418
+ per_sample_loss["mask_circles"] = keypoint_masks["loss_ndc_circles"]
419
+
420
+ per_step_info = {"loss": [], "lr": []}
421
+
422
+ # Early-stopping parameters
423
+ loss_target = 0.001 # Lowered for potentially better precision
424
+ loss_patience = 10 # Number of iterations used to check for stagnation
425
+ loss_tolerance = 1e-4 # Tolerance on the relative loss variation
426
+ loss_history = [] # History of loss values
427
+ best_loss = float('inf') # Best loss reached so far
428
+ steps_without_improvement = 0 # Counter of iterations without improvement
429
+
430
+ # with torch.autograd.detect_anomaly():
431
+ with tqdm(range(self.optim_steps), **self.tqdm_kwqargs) as pbar:
432
+ for step in pbar:
433
+ self.optim.zero_grad()
434
+ if self.lens_distortion_active:
435
+ self.optim_lens_distortion.zero_grad()
436
+
437
+ # forward pass
438
+ distances_dict, cam = self(x)
439
+
440
+ # distance calculate with masked input and output
441
+ losses = {}
442
+ for key_dist, distances in distances_dict.items():
443
+ distances[~keypoint_masks[key_dist]] = 0.0
444
+ per_sample_loss[f"{key_dist}_distances_raw"] = distances
445
+ distances_reduced = distances.sum(dim=(-1, -2))
446
+ distances_reduced = distances_reduced / num_actual_points[key_dist]
447
+ distances_reduced[num_actual_points[key_dist] == 0] = 0.0
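+ # segments without any annotated keypoints contribute zero loss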
448
+ distances_reduced = distances_reduced.squeeze(-1)
449
+ per_sample_loss[key_dist] = distances_reduced
450
+ loss = distances_reduced.mean(dim=-1)
451
+ loss = loss.sum()
452
+ losses[key_dist] = loss
453
+
454
+ loss_total_dist = losses["loss_ndc_lines"] + losses["loss_ndc_circles"]
455
+ loss_total = loss_total_dist
456
+ current_loss = loss_total.item()
457
+
458
+ # Update the loss history
459
+ loss_history.append(current_loss)
460
+
461
+ # Check whether this is a new best loss
462
+ if current_loss < best_loss:
463
+ best_loss = current_loss
464
+ steps_without_improvement = 0
465
+ else:
466
+ steps_without_improvement += 1
467
+
468
+ # Early-stopping criteria (commented out to force the full number of steps)
469
+ # if len(loss_history) >= loss_patience:
470
+ # # Compute the mean relative variation over the last iterations
471
+ # recent_losses = loss_history[-loss_patience:]
472
+ # # Handle the case where all recent losses are zero or close to zero
473
+ # max_recent_loss = max(max(recent_losses), 1e-9) # Avoids division by zero
474
+ # loss_variation = abs(max(recent_losses) - min(recent_losses)) / max_recent_loss
475
+ #
476
+ # # Stopping conditions
477
+ # if (current_loss <= loss_target or # Target value reached
478
+ # loss_variation < loss_tolerance or # Loss no longer varies significantly
479
+ # steps_without_improvement >= loss_patience): # No improvement for a while
480
+ # print(f"\nEarly stop at iteration {step+1}:")
481
+ # print(f"Final loss: {current_loss:.5f}")
482
+ # print(f"Best loss: {best_loss:.5f}")
483
+ # print(f"Relative variation: {loss_variation:.6f}")
484
+ # break
485
+
486
+ if self.log_per_step:
487
+ per_step_info["lr"].append(scheduler.get_last_lr())
488
+ per_step_info["loss"].append(distances_reduced)
489
+ if step % 50 == 0:
490
+ pbar.set_postfix(
491
+ loss=f"{loss_total_dist.detach().cpu().tolist():.5f}",
492
+ loss_lines=f'{losses["loss_ndc_lines"].detach().cpu().tolist():.3f}',
493
+ loss_circles=f'{losses["loss_ndc_circles"].detach().cpu().tolist():.3f}',
494
+ )
495
+
496
+ loss_total.backward()
497
+ self.optim.step()
498
+ scheduler.step()
499
+ if self.lens_distortion_active:
500
+ self.optim_lens_distortion.step()
501
+ scheduler_lens_distortion.step()
502
+
503
+ # Save the optimized parameters for the next frame
504
+ self.previous_params = {}
505
+ for k, v in self.cam_param_dict.param_dict.items():
506
+ self.previous_params[k] = v.detach().clone()
507
+
508
+ per_sample_loss["loss_ndc_total"] = torch.sum(
509
+ torch.stack([per_sample_loss[key_dist] for key_dist in distances_dict.keys()], dim=0),
510
+ dim=0,
511
+ )
512
+
513
+ if self.log_per_step:
514
+ per_step_info["loss"] = torch.stack(
515
+ per_step_info["loss"], dim=-1
516
+ )
517
+ per_step_info["lr"] = torch.tensor(per_step_info["lr"])
518
+ return per_sample_loss, cam, per_step_info
tvcalib/models/segmentation.py ADDED
@@ -0,0 +1,22 @@
1
+ from typing import Union
2
+ from pathlib import Path
3
+ import torch
4
+
5
+ from torchvision.models.segmentation import deeplabv3_resnet101
6
+ from SoccerNet.Evaluation.utils_calibration import SoccerPitch
7
+
8
+
9
+
10
+ class InferenceSegmentationModel:
11
+ def __init__(self, checkpoint: Union[str, Path], device) -> None:
12
+ self.device = device
13
+ self.model = deeplabv3_resnet101(
14
+ num_classes=len(SoccerPitch.lines_classes) + 1, aux_loss=True
15
+ )
16
+ checkpoint_data = torch.load(checkpoint, map_location=self.device, weights_only=False)
17
+ self.model.load_state_dict(checkpoint_data["model"], strict=False)
18
+ self.model.to(self.device)
19
+ self.model.eval()
20
+
21
+ def inference(self, img_batch):
22
+ return self.model(img_batch)["out"].argmax(1)
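+
+ # Minimal usage sketch (the checkpoint filename and device below are illustrative
+ # assumptions, not part of this repository's committed configuration):
+ #   model = InferenceSegmentationModel("train_59.pt", torch.device("cpu"))
+ #   class_ids = model.inference(img_batch)  # (B, H, W) tensor of per-pixel line-class ids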
tvcalib/sn_segmentation/resources/mean.npy ADDED
Binary file (152 Bytes). View file
 
tvcalib/sn_segmentation/resources/std.npy ADDED
Binary file (152 Bytes). View file
 
tvcalib/sn_segmentation/src/baseline_extremities.py ADDED
@@ -0,0 +1,311 @@
1
+ import argparse
2
+ import copy
3
+ import json
4
+ import os.path
5
+ import random
6
+ from collections import deque
7
+ from pathlib import Path
8
+
9
+ import cv2 as cv
10
+ import numpy as np
11
+ import torch
12
+ import torch.backends.cudnn
13
+ import torch.nn as nn
14
+ from PIL import Image
15
+ from torchvision.models.segmentation import deeplabv3_resnet50
16
+ from tqdm import tqdm
17
+
18
+ from SoccerNet.Evaluation.utils_calibration import SoccerPitch
19
+
20
+
21
+ def generate_class_synthesis(semantic_mask, radius):
22
+ """
23
+ This function selects for each class present in the semantic mask, a set of circles that cover most of the semantic
24
+ class blobs.
25
+ :param semantic_mask: an image containing the segmentation predictions
26
+ :param radius: circle radius
27
+ :return: a dictionary which associates with each detected class a list of points (the circle centers)
28
+ """
29
+ buckets = dict()
30
+ kernel = np.ones((5, 5), np.uint8)
31
+ semantic_mask = cv.erode(semantic_mask, kernel, iterations=1)
32
+ for k, class_name in enumerate(SoccerPitch.lines_classes):
33
+ mask = semantic_mask == k + 1
34
+ if mask.sum() > 0:
35
+ disk_list = synthesize_mask(mask, radius)
36
+ if len(disk_list):
37
+ buckets[class_name] = disk_list
38
+
39
+ return buckets
40
+
41
+
42
+ def join_points(point_list, maxdist):
43
+ """
44
+ Given a list of points that were extracted from the blobs belonging to a same semantic class, this function creates
45
+ polylines by linking close points together if their distance is below the maxdist threshold.
46
+ :param point_list: List of points of the same line class
47
+ :param maxdist: maximum distance below which two points are joined into the same polyline.
48
+ :return: a list of polylines
49
+ """
50
+ polylines = []
51
+
52
+ if not len(point_list):
53
+ return polylines
54
+ head = point_list[0]
55
+ tail = point_list[0]
56
+ polyline = deque()
57
+ polyline.append(point_list[0])
58
+ remaining_points = copy.deepcopy(point_list[1:])
59
+
60
+ while len(remaining_points) > 0:
61
+ min_dist_tail = 1000
62
+ min_dist_head = 1000
63
+ best_head = -1
64
+ best_tail = -1
65
+ for j, point in enumerate(remaining_points):
66
+ dist_tail = np.sqrt(np.sum(np.square(point - tail)))
67
+ dist_head = np.sqrt(np.sum(np.square(point - head)))
68
+ if dist_tail < min_dist_tail:
69
+ min_dist_tail = dist_tail
70
+ best_tail = j
71
+ if dist_head < min_dist_head:
72
+ min_dist_head = dist_head
73
+ best_head = j
74
+
75
+ if min_dist_head <= min_dist_tail and min_dist_head < maxdist:
76
+ polyline.appendleft(remaining_points[best_head])
77
+ head = polyline[0]
78
+ remaining_points.pop(best_head)
79
+ elif min_dist_tail < min_dist_head and min_dist_tail < maxdist:
80
+ polyline.append(remaining_points[best_tail])
81
+ tail = polyline[-1]
82
+ remaining_points.pop(best_tail)
83
+ else:
84
+ polylines.append(list(polyline.copy()))
85
+ head = remaining_points[0]
86
+ tail = remaining_points[0]
87
+ polyline = deque()
88
+ polyline.append(head)
89
+ remaining_points.pop(0)
90
+ polylines.append(list(polyline))
91
+ return polylines
92
+
93
+
94
+ def get_line_extremities(buckets, maxdist, width, height):
95
+ """
96
+ Given the dictionary {lines_class: points}, finds plausible extremities of each line, i.e the extremities
97
+ of the longest polyline that can be built on the class blobs, and normalize its coordinates
98
+ by the image size.
99
+ :param buckets: The dictionary associating line classes to the set of circle centers that covers best the class
100
+ prediction blobs in the segmentation mask
101
+ :param maxdist: the maximal distance between two circle centers belonging to the same blob (heuristic)
102
+ :param width: image width
103
+ :param height: image height
104
+ :return: a dictionary associating to each class its extremities
105
+ """
106
+ extremities = dict()
107
+ for class_name, disks_list in buckets.items():
108
+ polyline_list = join_points(disks_list, maxdist)
109
+ max_len = 0
110
+ longest_polyline = []
111
+ for polyline in polyline_list:
112
+ if len(polyline) > max_len:
113
+ max_len = len(polyline)
114
+ longest_polyline = polyline
115
+ extremities[class_name] = [
116
+ {'x': longest_polyline[0][1] / width, 'y': longest_polyline[0][0] / height},
117
+ {'x': longest_polyline[-1][1] / width, 'y': longest_polyline[-1][0] / height}
118
+ ]
119
+ return extremities
120
+
121
+
122
+ def get_support_center(mask, start, disk_radius, min_support=0.1):
123
+ """
124
+ Returns the barycenter of the True pixels under the area of the mask delimited by the circle of center start and
125
+ radius of disk_radius pixels.
126
+ :param mask: Boolean mask
127
+ :param start: A point located on a true pixel of the mask
128
+ :param disk_radius: the radius of the circles
129
+ :param min_support: proportion of the area under the circle area that should be True in order to get enough support
130
+ :return: A boolean indicating if there is enough support in the circle area, the barycenter of the True pixels under
131
+ the circle
132
+ """
133
+ x = int(start[0])
134
+ y = int(start[1])
135
+ support_pixels = 1
136
+ result = [x, y]
137
+ xstart = x - disk_radius
138
+ if xstart < 0:
139
+ xstart = 0
140
+ xend = x + disk_radius
141
+ if xend > mask.shape[0]:
142
+ xend = mask.shape[0] - 1
143
+
144
+ ystart = y - disk_radius
145
+ if ystart < 0:
146
+ ystart = 0
147
+ yend = y + disk_radius
148
+ if yend > mask.shape[1]:
149
+ yend = mask.shape[1] - 1
150
+
151
+ for i in range(xstart, xend + 1):
152
+ for j in range(ystart, yend + 1):
153
+ dist = np.sqrt(np.square(x - i) + np.square(y - j))
154
+ if dist < disk_radius and mask[i, j] > 0:
155
+ support_pixels += 1
156
+ result[0] += i
157
+ result[1] += j
158
+ support = True
159
+ if support_pixels < min_support * np.square(disk_radius) * np.pi:
160
+ support = False
161
+
162
+ result = np.array(result)
163
+ result = np.true_divide(result, support_pixels)
164
+
165
+ return support, result
166
+
167
+
168
+ def synthesize_mask(semantic_mask, disk_radius):
169
+ """
170
+ Fits circles on the True pixels of the mask and returns those which have enough support: meaning that the
171
+ proportion of the area of the circle covering True pixels is higher than a certain threshold, in order to avoid
172
+ fitting circles on isolated pixels.
173
+ :param semantic_mask: boolean mask
174
+ :param disk_radius: radius of the circles
175
+ :return: a list of disk centers, that have enough support
176
+ """
177
+ mask = semantic_mask.copy().astype(np.uint8)
178
+ points = np.transpose(np.nonzero(mask))
179
+ disks = []
180
+ while len(points):
181
+
182
+ start = random.choice(points)
183
+ dist = 10.
184
+ success = True
185
+ while dist > 1.:
186
+ enough_support, center = get_support_center(mask, start, disk_radius)
187
+ if not enough_support:
188
+ bad_point = np.round(center).astype(np.int32)
189
+ cv.circle(mask, (bad_point[1], bad_point[0]), disk_radius, (0), -1)
190
+ success = False
191
+ dist = np.sqrt(np.sum(np.square(center - start)))
192
+ start = center
193
+ if success:
194
+ disks.append(np.round(start).astype(np.int32))
195
+ cv.circle(mask, (disks[-1][1], disks[-1][0]), disk_radius, 0, -1)
196
+ points = np.transpose(np.nonzero(mask))
197
+
198
+ return disks
199
+
200
+
201
+ class SegmentationNetwork:
202
+ def __init__(self, model_file, mean_file, std_file, num_classes=29, width=640, height=360):
203
+ file_path = Path(model_file).resolve()
204
+ model = nn.DataParallel(deeplabv3_resnet50(pretrained=False, num_classes=num_classes))
205
+ self.init_weight(model, nn.init.kaiming_normal_,
206
+ nn.BatchNorm2d, 1e-3, 0.1,
207
+ mode='fan_in')
208
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
209
+ checkpoint = torch.load(str(file_path), map_location=self.device)
210
+ model.load_state_dict(checkpoint["model"])
211
+ model.eval()
212
+ self.model = model.to(self.device)
213
+ file_path = Path(mean_file).resolve()
214
+ self.mean = np.load(str(file_path))
215
+ file_path = Path(std_file).resolve()
216
+ self.std = np.load(str(file_path))
217
+ self.width = width
218
+ self.height = height
219
+
220
+ def init_weight(self, feature, conv_init, norm_layer, bn_eps, bn_momentum,
221
+ **kwargs):
222
+ for name, m in feature.named_modules():
223
+ if isinstance(m, (nn.Conv2d, nn.Conv3d)):
224
+ conv_init(m.weight, **kwargs)
225
+ elif isinstance(m, norm_layer):
226
+ m.eps = bn_eps
227
+ m.momentum = bn_momentum
228
+ nn.init.constant_(m.weight, 1)
229
+ nn.init.constant_(m.bias, 0)
230
+
231
+ def analyse_image(self, image):
232
+ """
233
+ Process image and perform inference, returns mask of detected classes
234
+ :param image: BGR image
235
+ :return: predicted classes mask
236
+ """
237
+ img = cv.resize(image, (self.width, self.height), interpolation=cv.INTER_LINEAR)
238
+ img = np.asarray(img, np.float32) / 255.
239
+ img = (img - self.mean) / self.std
240
+ img = img.transpose((2, 0, 1))
241
+ img = torch.from_numpy(img).to(self.device).unsqueeze(0)
242
+
243
+ cuda_result = self.model.forward(img.float())
244
+ output = cuda_result['out'].data[0].cpu().numpy()
245
+ output = output.transpose(1, 2, 0)
246
+ output = np.asarray(np.argmax(output, axis=2), dtype=np.uint8)
247
+
248
+ return output
249
+
250
+
251
+ if __name__ == "__main__":
252
+ parser = argparse.ArgumentParser(description='Test')
253
+
254
+ parser.add_argument('-s', '--soccernet', default="./annotations/", type=str,
255
+ help='Path to the SoccerNet-V3 dataset folder')
256
+ parser.add_argument('-p', '--prediction', default="./results_bis", required=False, type=str,
257
+ help="Path to the prediction folder")
258
+ parser.add_argument('--split', required=False, type=str, default="test", help='Select the split of data')
259
+ parser.add_argument('--masks', required=False, type=bool, default=False, help='Save masks in prediction directory')
260
+ parser.add_argument('--resolution_width', required=False, type=int, default=455,
261
+ help='width resolution of the images')
262
+ parser.add_argument('--resolution_height', required=False, type=int, default=256,
263
+ help='height resolution of the images')
264
+ parser.add_argument('--checkpoint_dir', default="resources")
265
+ args = parser.parse_args()
266
+
267
+ lines_palette = [0, 0, 0]
268
+ for line_class in SoccerPitch.lines_classes:
269
+ lines_palette.extend(SoccerPitch.palette[line_class])
270
+
271
+ calib_net = SegmentationNetwork(
272
+ os.path.join(args.checkpoint_dir, "soccer_pitch_segmentation.pth"),
273
+ os.path.join(args.checkpoint_dir, "mean.npy"),
274
+ os.path.join(args.checkpoint_dir, "std.npy")
275
+ )
276
+
277
+ dataset_dir = os.path.join(args.soccernet, args.split)
278
+ if not os.path.exists(dataset_dir):
279
+ print("Invalid dataset path !")
280
+ exit(-1)
281
+
282
+ frames = [f for f in os.listdir(dataset_dir) if ".jpg" in f]
283
+ with tqdm(enumerate(frames), total=len(frames), ncols=160) as t:
284
+ for i, frame in t:
285
+
286
+ output_prediction_folder = os.path.join(args.prediction, args.split)
287
+ if not os.path.exists(output_prediction_folder):
288
+ os.makedirs(output_prediction_folder)
289
+ prediction = dict()
290
+ count = 0
291
+
292
+ frame_path = os.path.join(dataset_dir, frame)
293
+
294
+ frame_index = frame.split(".")[0]
295
+
296
+ image = cv.imread(frame_path)
297
+ semlines = calib_net.analyse_image(image)
298
+ if args.masks:
299
+ mask = Image.fromarray(semlines.astype(np.uint8)).convert('P')
300
+ mask.putpalette(lines_palette)
301
+ mask_file = os.path.join(output_prediction_folder, frame)
302
+ mask.convert("RGB").save(mask_file)
303
+ skeletons = generate_class_synthesis(semlines, 6)
304
+ extremities = get_line_extremities(skeletons, 40, args.resolution_width, args.resolution_height)
305
+
306
+ prediction = extremities
307
+ count += 1
308
+
309
+ prediction_file = os.path.join(output_prediction_folder, f"extremities_{frame_index}.json")
310
+ with open(prediction_file, "w") as f:
311
+ json.dump(prediction, f, indent=4)
tvcalib/sn_segmentation/src/custom_extremities.py ADDED
@@ -0,0 +1,322 @@
1
+ import argparse
2
+ import copy
3
+ import itertools
4
+ import json
5
+ import os.path
6
+ import random
7
+ from collections import deque
8
+ from pathlib import Path
9
+
10
+ from pytorch_lightning import seed_everything
11
+
12
+ seed_everything(seed=10, workers=True)
13
+
14
+ import cv2 as cv
15
+ import numpy as np
16
+ import torch
17
+ import torch.backends.cudnn
18
+ import torch.nn as nn
19
+ import torchvision.transforms as T
20
+
21
+ from PIL import Image
22
+ from torchvision.models.segmentation import deeplabv3_resnet101
23
+ from tqdm import tqdm
24
+
25
+ from SoccerNet.Evaluation.utils_calibration import SoccerPitch
26
+
27
+
28
+ def generate_class_synthesis(semantic_mask, radius):
29
+ """
30
+ This function selects for each class present in the semantic mask, a set of circles that cover most of the semantic
31
+ class blobs.
32
+ :param semantic_mask: an image containing the segmentation predictions
33
+ :param radius: circle radius
34
+ :return: a dictionary which associates with each detected class a list of points (the circle centers)
35
+ """
36
+ buckets = dict()
37
+ kernel = np.ones((5, 5), np.uint8)
38
+ semantic_mask = cv.erode(semantic_mask, kernel, iterations=1)
39
+ for k, class_name in enumerate(SoccerPitch.lines_classes):
40
+ mask = semantic_mask == k + 1
41
+ if mask.sum() > 0:
42
+ disk_list = synthesize_mask(mask, radius)
43
+ if len(disk_list):
44
+ buckets[class_name] = disk_list
45
+
46
+ return buckets
47
+
48
+
49
+ def join_points(point_list, maxdist):
50
+ """
51
+ Given a list of points that were extracted from the blobs belonging to a same semantic class, this function creates
52
+ polylines by linking close points together if their distance is below the maxdist threshold.
53
+ :param point_list: List of points of the same line class
54
+ :param maxdist: maximum distance below which two points are joined into the same polyline.
55
+ :return: a list of polylines
56
+ """
57
+ polylines = []
58
+
59
+ if not len(point_list):
60
+ return polylines
61
+ head = point_list[0]
62
+ tail = point_list[0]
63
+ polyline = deque()
64
+ polyline.append(point_list[0])
65
+ remaining_points = copy.deepcopy(point_list[1:])
66
+
67
+ while len(remaining_points) > 0:
68
+ min_dist_tail = 1000
69
+ min_dist_head = 1000
70
+ best_head = -1
71
+ best_tail = -1
72
+ for j, point in enumerate(remaining_points):
73
+ dist_tail = np.sqrt(np.sum(np.square(point - tail)))
74
+ dist_head = np.sqrt(np.sum(np.square(point - head)))
75
+ if dist_tail < min_dist_tail:
76
+ min_dist_tail = dist_tail
77
+ best_tail = j
78
+ if dist_head < min_dist_head:
79
+ min_dist_head = dist_head
80
+ best_head = j
81
+
82
+ if min_dist_head <= min_dist_tail and min_dist_head < maxdist:
83
+ polyline.appendleft(remaining_points[best_head])
84
+ head = polyline[0]
85
+ remaining_points.pop(best_head)
86
+ elif min_dist_tail < min_dist_head and min_dist_tail < maxdist:
87
+ polyline.append(remaining_points[best_tail])
88
+ tail = polyline[-1]
89
+ remaining_points.pop(best_tail)
90
+ else:
91
+ polylines.append(list(polyline.copy()))
92
+ head = remaining_points[0]
93
+ tail = remaining_points[0]
94
+ polyline = deque()
95
+ polyline.append(head)
96
+ remaining_points.pop(0)
97
+ polylines.append(list(polyline))
98
+ return polylines
99
+
100
+
101
+ def get_line_extremities(buckets, maxdist, width, height, num_points_lines, num_points_circles):
102
+ """
103
+ Given the dictionary {lines_class: points}, finds plausible extremities of each line, i.e the extremities
104
+ of the longest polyline that can be built on the class blobs, and normalize its coordinates
105
+ by the image size.
106
+ :param buckets: The dictionary associating line classes to the set of circle centers that covers best the class
107
+ prediction blobs in the segmentation mask
108
+ :param maxdist: the maximal distance between two circle centers belonging to the same blob (heuristic)
109
+ :param width: image width
110
+ :param height: image height
111
+ :return: a dictionary associating to each class its extremities
112
+ """
113
+ extremities = dict()
114
+ for class_name, disks_list in buckets.items():
115
+ polyline_list = join_points(disks_list, maxdist)
116
+ max_len = 0
117
+ longest_polyline = []
118
+ for polyline in polyline_list:
119
+ if len(polyline) > max_len:
120
+ max_len = len(polyline)
121
+ longest_polyline = polyline
122
+ extremities[class_name] = [
123
+ {'x': longest_polyline[0][1] / width, 'y': longest_polyline[0][0] / height},
124
+ {'x': longest_polyline[-1][1] / width, 'y': longest_polyline[-1][0] / height},
125
+
126
+ ]
127
+ num_points = num_points_lines
128
+ if "Circle" in class_name:
129
+ num_points = num_points_circles
130
+ if num_points > 2:
131
+ # equally spaced points along the longest polyline
132
+ # skip first and last as they already exist
133
+ for i in range(1, num_points - 1):
134
+ extremities[class_name].insert(
135
+ len(extremities[class_name]) - 1,
136
+ {'x': longest_polyline[i * int(len(longest_polyline) / num_points)][1] / width, 'y': longest_polyline[i * int(len(longest_polyline) / num_points)][0] / height}
137
+ )
138
+
139
+ return extremities
140
+
141
+
142
+ def get_support_center(mask, start, disk_radius, min_support=0.1):
143
+ """
144
+ Returns the barycenter of the True pixels under the area of the mask delimited by the circle of center start and
145
+ radius of disk_radius pixels.
146
+ :param mask: Boolean mask
147
+ :param start: A point located on a true pixel of the mask
148
+ :param disk_radius: the radius of the circles
149
+ :param min_support: proportion of the area under the circle area that should be True in order to get enough support
150
+ :return: A boolean indicating if there is enough support in the circle area, the barycenter of the True pixels under
151
+ the circle
152
+ """
153
+ x = int(start[0])
154
+ y = int(start[1])
155
+ support_pixels = 1
156
+ result = [x, y]
157
+ xstart = x - disk_radius
158
+ if xstart < 0:
159
+ xstart = 0
160
+ xend = x + disk_radius
161
+ if xend > mask.shape[0]:
162
+ xend = mask.shape[0] - 1
163
+
164
+ ystart = y - disk_radius
165
+ if ystart < 0:
166
+ ystart = 0
167
+ yend = y + disk_radius
168
+ if yend > mask.shape[1]:
169
+ yend = mask.shape[1] - 1
170
+
171
+ for i in range(xstart, xend + 1):
172
+ for j in range(ystart, yend + 1):
173
+ dist = np.sqrt(np.square(x - i) + np.square(y - j))
174
+ if dist < disk_radius and mask[i, j] > 0:
175
+ support_pixels += 1
176
+ result[0] += i
177
+ result[1] += j
178
+ support = True
179
+ if support_pixels < min_support * np.square(disk_radius) * np.pi:
180
+ support = False
181
+
182
+ result = np.array(result)
183
+ result = np.true_divide(result, support_pixels)
184
+
185
+ return support, result
186
+
187
+
188
+ def synthesize_mask(semantic_mask, disk_radius):
189
+ """
190
+ Fits circles on the True pixels of the mask and returns those which have enough support: meaning that the
191
+ proportion of the area of the circle covering True pixels is higher than a certain threshold, in order to avoid
192
+ fitting circles on isolated pixels.
193
+ :param semantic_mask: boolean mask
194
+ :param disk_radius: radius of the circles
195
+ :return: a list of disk centers, that have enough support
196
+ """
197
+ mask = semantic_mask.copy().astype(np.uint8)
198
+ points = np.transpose(np.nonzero(mask))
199
+ disks = []
200
+ while len(points):
201
+
202
+ start = random.choice(points)
203
+ dist = 10.
204
+ success = True
205
+ while dist > 1.:
206
+ enough_support, center = get_support_center(mask, start, disk_radius)
207
+ if not enough_support:
208
+ bad_point = np.round(center).astype(np.int32)
209
+ cv.circle(mask, (bad_point[1], bad_point[0]), disk_radius, (0), -1)
210
+ success = False
211
+ dist = np.sqrt(np.sum(np.square(center - start)))
212
+ start = center
213
+ if success:
214
+ disks.append(np.round(start).astype(np.int32))
215
+ cv.circle(mask, (disks[-1][1], disks[-1][0]), disk_radius, 0, -1)
216
+ points = np.transpose(np.nonzero(mask))
217
+
218
+ return disks
219
+
220
+ class CustomNetwork:
221
+
222
+ def __init__(self, checkpoint):
223
+ print("Loading model " + checkpoint)
224
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
225
+ self.model = deeplabv3_resnet101(num_classes=len(SoccerPitch.lines_classes) + 1, aux_loss=True)
226
+ self.model.load_state_dict(torch.load(checkpoint)["model"], strict=False)
227
+ self.model.to(self.device)
228
+ self.model.eval()
229
+ print("using", self.device)
230
+
231
+ def forward(self, img):
232
+ trf = T.Compose(
233
+ [
234
+ T.Resize(256),
235
+ #T.CenterCrop(224),
236
+ T.ToTensor(),
237
+ T.Normalize(
238
+ mean = [0.485, 0.456, 0.406],
239
+ std = [0.229, 0.224, 0.225]
240
+ )
241
+ ]
242
+ )
243
+ img = trf(img).unsqueeze(0).to(self.device)
244
+ result = self.model(img)["out"].detach().squeeze(0).argmax(0)
245
+ result = result.cpu().numpy().astype(np.uint8)
246
+ #print(result)
247
+ return result
248
+
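+ # Note: forward() takes a PIL image and returns an HxW uint8 mask of per-pixel
+ # line-class ids at the post-Resize resolution (shorter side resized to 256).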
249
+ if __name__ == "__main__":
250
+ parser = argparse.ArgumentParser(description='Test')
251
+
252
+ parser.add_argument('-s', '--soccernet', default="/nfs/data/soccernet/calibration/", type=str,
253
+ help='Path to the SoccerNet-V3 dataset folder')
254
+ parser.add_argument('-p', '--prediction', default="sn-calib-test_endpoints", required=False, type=str,
255
+ help="Path to the prediction folder")
256
+ parser.add_argument('--split', required=False, type=str, default="challenge", help='Select the split of data')
257
+ parser.add_argument('--masks', required=False, type=bool, default=False, help='Save masks in prediction directory')
258
+ parser.add_argument('--resolution_width', required=False, type=int, default=455,
259
+ help='width resolution of the images')
260
+ parser.add_argument('--resolution_height', required=False, type=int, default=256,
261
+ help='height resolution of the images')
262
+ parser.add_argument('--checkpoint', required=False, type=str, help="Path to the custom model checkpoint.")
263
+ parser.add_argument('--pp_radius', required=False, type=int, default=4,
264
+ help='Post processing: Radius of circles that cover each segment.')
265
+ parser.add_argument('--pp_maxdists', required=False, type=int, default=30,
266
+ help='Post processing: Maximum distance of circles that are allowed within one segment.')
267
+ parser.add_argument('--num_points_lines', required=False, type=int, default=2, choices=range(2,10),
268
+ help='Post processing: Number of keypoints that represent a line segment')
269
+ parser.add_argument('--num_points_circles', required=False, type=int, default=2, choices=range(2,10),
270
+ help='Post processing: Number of keypoints that represent a circle segment')
271
+ args = parser.parse_args()
272
+
273
+ lines_palette = [0, 0, 0]
274
+ for line_class in SoccerPitch.lines_classes:
275
+ lines_palette.extend(SoccerPitch.palette[line_class])
276
+
277
+ model = CustomNetwork(args.checkpoint)
278
+
279
+ dataset_dir = os.path.join(args.soccernet, args.split)
280
+ if not os.path.exists(dataset_dir):
281
+ print("Invalid dataset path !")
282
+ exit(-1)
283
+
284
+ radius = args.pp_radius
285
+ maxdists = args.pp_maxdists
286
+
287
+ frames = [f for f in os.listdir(dataset_dir) if ".jpg" in f]
288
+ with tqdm(enumerate(frames), total=len(frames), ncols=160) as t:
289
+ for i, frame in t:
290
+
291
+ output_prediction_folder = os.path.join(str(args.prediction), f"np{args.num_points_lines}_nc{args.num_points_circles}_r{radius}_md{maxdists}", args.split)
292
+ if not os.path.exists(output_prediction_folder):
293
+ os.makedirs(output_prediction_folder)
294
+ prediction = dict()
295
+ count = 0
296
+
297
+ frame_path = os.path.join(dataset_dir, frame)
298
+
299
+ frame_index = frame.split(".")[0]
300
+
301
+ image = Image.open(frame_path)
302
+
303
+ semlines = model.forward(image)
304
+ #print(semlines.shape)
305
+ # print("\nsemlines", type(semlines), semlines.shape)
306
+ if args.masks:
307
+ mask = Image.fromarray(semlines.astype(np.uint8)).convert('P')
308
+ mask.putpalette(lines_palette)
309
+ mask_file = os.path.join(output_prediction_folder, frame)
310
+ mask.convert("RGB").save(mask_file)
311
+ skeletons = generate_class_synthesis(semlines, radius)
312
+
313
+ extremities = get_line_extremities(skeletons, maxdists, args.resolution_width, args.resolution_height, args.num_points_lines, args.num_points_circles)
314
+
315
+
316
+ prediction = extremities
317
+ count += 1
318
+
319
+ prediction_file = os.path.join(output_prediction_folder, f"extremities_{frame_index}.json")
320
+ with open(prediction_file, "w") as f:
321
+ json.dump(prediction, f, indent=4)
322
+
tvcalib/sn_segmentation/src/dataloader.py ADDED
@@ -0,0 +1,122 @@
1
+ """
2
+ DataLoader used to train the segmentation network used for the prediction of extremities.
3
+ """
4
+
5
+ import json
6
+ import os
7
+ import time
8
+ from argparse import ArgumentParser
9
+
10
+ import cv2 as cv
11
+ import numpy as np
12
+ from torch.utils.data import Dataset
13
+ from tqdm import tqdm
14
+
15
+ from SoccerNet.Evaluation.utils_calibration import SoccerPitch
16
+
17
+
18
+ class SoccerNetDataset(Dataset):
19
+ def __init__(self,
20
+ datasetpath,
21
+ split="test",
22
+ width=640,
23
+ height=360,
24
+ mean="../resources/mean.npy",
25
+ std="../resources/std.npy"):
26
+ self.mean = np.load(mean)
27
+ self.std = np.load(std)
28
+ self.width = width
29
+ self.height = height
30
+
31
+ dataset_dir = os.path.join(datasetpath, split)
32
+ if not os.path.exists(dataset_dir):
33
+ print("Invalid dataset path !")
34
+ exit(-1)
35
+
36
+ frames = [f for f in os.listdir(dataset_dir) if ".jpg" in f]
37
+
38
+ self.data = []
39
+ self.n_samples = 0
40
+ for frame in frames:
41
+
42
+ frame_index = frame.split(".")[0]
43
+ annotation_file = os.path.join(dataset_dir, f"{frame_index}.json")
44
+ if not os.path.exists(annotation_file):
45
+ continue
46
+ with open(annotation_file, "r") as f:
47
+ groundtruth_lines = json.load(f)
48
+ img_path = os.path.join(dataset_dir, frame)
49
+ if groundtruth_lines:
50
+ self.data.append({
51
+ "image_path": img_path,
52
+ "annotations": groundtruth_lines,
53
+ })
54
+
55
+ def __len__(self):
56
+ return len(self.data)
57
+
58
+ def __getitem__(self, index):
59
+ item = self.data[index]
60
+
61
+ img = cv.imread(item["image_path"])
62
+ img = cv.resize(img, (self.width, self.height), interpolation=cv.INTER_LINEAR)
63
+
64
+ mask = np.zeros(img.shape[:-1], dtype=np.uint8)
65
+ img = np.asarray(img, np.float32) / 255.
66
+ img -= self.mean
67
+ img /= self.std
68
+ img = img.transpose((2, 0, 1))
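+ # line classes are rasterized with value (class index + 1) so that 0 stays the background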
69
+ for class_number, class_ in enumerate(SoccerPitch.lines_classes):
70
+ if class_ in item["annotations"].keys():
71
+ key = class_
72
+ line = item["annotations"][key]
73
+ prev_point = line[0]
74
+ for i in range(1, len(line)):
75
+ next_point = line[i]
76
+ cv.line(mask,
77
+ (int(prev_point["x"] * mask.shape[1]), int(prev_point["y"] * mask.shape[0])),
78
+ (int(next_point["x"] * mask.shape[1]), int(next_point["y"] * mask.shape[0])),
79
+ class_number + 1,
80
+ 2)
81
+ prev_point = next_point
82
+ return img, mask
83
+
84
+
85
+ if __name__ == "__main__":
86
+
87
+ # Load the arguments
88
+ parser = ArgumentParser(description='dataloader')
89
+
90
+ parser.add_argument('--SoccerNet_path', default="./annotations/", type=str,
91
+ help='Path to the SoccerNet-V3 dataset folder')
92
+ parser.add_argument('--tiny', required=False, type=int, default=None, help='Select a subset of x games')
93
+ parser.add_argument('--split', required=False, type=str, default="test", help='Select the split of data')
94
+ parser.add_argument('--num_workers', required=False, type=int, default=4,
95
+ help='number of workers for the dataloader')
96
+ parser.add_argument('--resolution_width', required=False, type=int, default=1920,
97
+ help='width resolution of the images')
98
+ parser.add_argument('--resolution_height', required=False, type=int, default=1080,
99
+ help='height resolution of the images')
100
+ parser.add_argument('--preload_images', action='store_true',
101
+ help="Preload the images when constructing the dataset")
102
+ parser.add_argument('--zipped_images', action='store_true', help="Read images from zipped folder")
103
+
104
+ args = parser.parse_args()
105
+
106
+ start_time = time.time()
107
+ soccernet = SoccerNetDataset(args.SoccerNet_path, split=args.split)
108
+ with tqdm(enumerate(soccernet), total=len(soccernet), ncols=160) as t:
109
+ for i, data in t:
110
+ img = soccernet[i][0].astype(np.uint8).transpose((1, 2, 0))
111
+ print(img.shape)
112
+ print(img.dtype)
113
+ cv.imshow("Normalized image", img)
114
+ cv.waitKey(0)
115
+ cv.destroyAllWindows()
116
+ print(data[1].shape)
117
+ cv.imshow("Mask", soccernet[i][1].astype(np.uint8))
118
+ cv.waitKey(0)
119
+ cv.destroyAllWindows()
120
+ continue
121
+ end_time = time.time()
122
+ print(end_time - start_time)
tvcalib/sn_segmentation/src/evaluate_extremities.py ADDED
@@ -0,0 +1,270 @@
1
+ import argparse
2
+ import json
3
+ import os
4
+
5
+ import matplotlib.pyplot as plt
6
+ import numpy as np
7
+ from tqdm import tqdm
8
+
9
+ from SoccerNet.Evaluation.utils_calibration import SoccerPitch
10
+
11
+
12
+ def distance(point1, point2):
13
+ """
14
+ Computes the Euclidean distance between 2D points
15
+ :param point1
16
+ :param point2
17
+ :return: Euclidean distance between point1 and point2
18
+ """
19
+ diff = np.array([point1['x'], point1['y']]) - np.array([point2['x'], point2['y']])
20
+ sq_dist = np.square(diff)
21
+ return np.sqrt(sq_dist.sum())
22
+
23
+
24
+ def mirror_labels(lines_dict):
25
+ """
26
+ Replace each line class key of the dictionary with its opposite element according to a central projection by the
27
+ soccer pitch center
28
+ :param lines_dict: dictionary whose keys will be mirrored
29
+ :return: Dictionary with mirrored keys and same values
30
+ """
31
+ mirrored_dict = dict()
32
+ for line_class, value in lines_dict.items():
33
+ mirrored_dict[SoccerPitch.symetric_classes[line_class]] = value
34
+ return mirrored_dict
35
+
36
+
37
+ def evaluate_detection_prediction(detected_lines, groundtruth_lines, threshold=2.):
38
+ """
39
+ Evaluates the prediction of extremities. The extremities associated to a class are unordered. The extremities of the
40
+ "Circle central" element is not well-defined for this task, thus this class is ignored.
41
+ Computes confusion matrices for a level of precision specified by the threshold.
42
+ A groundtruth extremity point is correctly classified if it lies at less than threshold pixels from the
43
+ corresponding extremity point of the prediction of the same class.
44
+ Also computes the Euclidean distance between each predicted extremity and its closest groundtruth extremity, when
45
+ both the groundtruth and the prediction contain the element class.
46
+
47
+ :param detected_lines: dictionary of detected lines classes as keys and associated predicted extremities as values
48
+ :param groundtruth_lines: dictionary of annotated lines classes as keys and associated annotated points as values
49
+ :param threshold: distance in pixels that distinguishes good matches from bad ones
50
+ :return: confusion matrix, per class confusion matrix & per class localization errors
51
+ """
52
+ confusion_mat = np.zeros((2, 2), dtype=np.float32)
53
+ per_class_confusion = {}
54
+ errors_dict = {}
55
+ detected_classes = set(detected_lines.keys())
56
+ groundtruth_classes = set(groundtruth_lines.keys())
57
+
58
+ if "Circle central" in groundtruth_classes:
59
+ groundtruth_classes.remove("Circle central")
60
+ if "Circle central" in detected_classes:
61
+ detected_classes.remove("Circle central")
62
+
63
+ false_positives_classes = detected_classes - groundtruth_classes
64
+ for false_positive_class in false_positives_classes:
65
+ false_positives = len(detected_lines[false_positive_class])
66
+ confusion_mat[0, 1] += false_positives
67
+ per_class_confusion[false_positive_class] = np.array([[0., false_positives], [0., 0.]])
68
+
69
+ false_negatives_classes = groundtruth_classes - detected_classes
70
+ for false_negatives_class in false_negatives_classes:
71
+ false_negatives = len(groundtruth_lines[false_negatives_class])
72
+ confusion_mat[1, 0] += false_negatives
73
+ per_class_confusion[false_negatives_class] = np.array([[0., 0.], [false_negatives, 0.]])
74
+
75
+ common_classes = detected_classes - false_positives_classes
76
+
77
+ for detected_class in common_classes:
78
+
79
+ detected_points = detected_lines[detected_class]
80
+
81
+ groundtruth_points = groundtruth_lines[detected_class]
82
+
83
+ groundtruth_extremities = [groundtruth_points[0], groundtruth_points[-1]]
84
+ predicted_extremities = [detected_points[0], detected_points[-1]]
85
+ per_class_confusion[detected_class] = np.zeros((2, 2))
86
+
87
+ dist1 = distance(groundtruth_extremities[0], predicted_extremities[0])
88
+ dist1rev = distance(groundtruth_extremities[1], predicted_extremities[0])
89
+
90
+ dist2 = distance(groundtruth_extremities[1], predicted_extremities[1])
91
+ dist2rev = distance(groundtruth_extremities[0], predicted_extremities[1])
92
+ if dist1rev <= dist1 and dist2rev <= dist2:
93
+ # reverse order
94
+ dist1 = dist1rev
95
+ dist2 = dist2rev
96
+
97
+ errors_dict[detected_class] = [dist1, dist2]
98
+
99
+ if dist1 < threshold:
100
+ confusion_mat[0, 0] += 1
101
+ per_class_confusion[detected_class][0, 0] += 1
102
+ else:
103
+ # treat too far detections as false positives
104
+ confusion_mat[0, 1] += 1
105
+ per_class_confusion[detected_class][0, 1] += 1
106
+
107
+ if dist2 < threshold:
108
+ confusion_mat[0, 0] += 1
109
+ per_class_confusion[detected_class][0, 0] += 1
110
+
111
+ else:
112
+ # treat too far detections as false positives
113
+ confusion_mat[0, 1] += 1
114
+ per_class_confusion[detected_class][0, 1] += 1
115
+
116
+ return confusion_mat, per_class_confusion, errors_dict
117
+
118
+
119
+ def scale_points(points_dict, s_width, s_height):
120
+ """
121
+ Scale points by s_width and s_height factors
122
+ :param points_dict: dictionary of annotations/predictions with normalized point values
123
+ :param s_width: width scaling factor
124
+ :param s_height: height scaling factor
125
+ :return: dictionary with scaled points
126
+ """
127
+ line_dict = {}
128
+ for line_class, points in points_dict.items():
129
+ scaled_points = []
130
+ for point in points:
131
+ new_point = {'x': point['x'] * (s_width-1), 'y': point['y'] * (s_height-1)}
132
+ scaled_points.append(new_point)
133
+ if len(scaled_points):
134
+ line_dict[line_class] = scaled_points
135
+ return line_dict
136
+
137
+
138
+ if __name__ == "__main__":
139
+
140
+ parser = argparse.ArgumentParser(description='Test')
141
+
142
+ parser.add_argument('-s', '--soccernet', default="./annotations", type=str,
143
+ help='Path to the SoccerNet-V3 dataset folder')
144
+ parser.add_argument('-p', '--prediction', default="./results_bis",
145
+ required=False, type=str,
146
+ help="Path to the prediction folder")
147
+ parser.add_argument('-t', '--threshold', default=10, required=False, type=int,
148
+ help="Accuracy threshold in pixels")
149
+ parser.add_argument('--split', required=False, type=str, default="test", help='Select the split of data')
150
+ parser.add_argument('--resolution_width', required=False, type=int, default=960,
151
+ help='width resolution of the images')
152
+ parser.add_argument('--resolution_height', required=False, type=int, default=540,
153
+ help='height resolution of the images')
154
+ args = parser.parse_args()
155
+
156
+ accuracies = []
157
+ precisions = []
158
+ recalls = []
159
+ dict_errors = {}
160
+ per_class_confusion_dict = {}
161
+
162
+ dataset_dir = os.path.join(args.soccernet, args.split)
163
+ if not os.path.exists(dataset_dir):
164
+ print("Invalid dataset path !")
165
+ exit(-1)
166
+
167
+ annotation_files = [f for f in os.listdir(dataset_dir) if ".json" in f]
168
+
169
+ with tqdm(enumerate(annotation_files), total=len(annotation_files), ncols=160) as t:
170
+ for i, annotation_file in t:
171
+ frame_index = annotation_file.split(".")[0]
172
+ annotation_file = os.path.join(args.soccernet, args.split, annotation_file)
173
+ prediction_file = os.path.join(args.prediction, args.split, f"extremities_{frame_index}.json")
174
+
175
+ if not os.path.exists(prediction_file):
176
+ accuracies.append(0.)
177
+ precisions.append(0.)
178
+ recalls.append(0.)
179
+ continue
180
+
181
+ with open(annotation_file, 'r') as f:
182
+ line_annotations = json.load(f)
183
+
184
+ with open(prediction_file, 'r') as f:
185
+ predictions = json.load(f)
186
+
187
+ predictions = scale_points(predictions, args.resolution_width, args.resolution_height)
188
+ line_annotations = scale_points(line_annotations, args.resolution_width, args.resolution_height)
189
+
190
+ img_prediction = predictions
191
+ img_groundtruth = line_annotations
192
+ confusion1, per_class_conf1, reproj_errors1 = evaluate_detection_prediction(img_prediction,
193
+ img_groundtruth,
194
+ args.threshold)
195
+ confusion2, per_class_conf2, reproj_errors2 = evaluate_detection_prediction(img_prediction,
196
+ mirror_labels(
197
+ img_groundtruth),
198
+ args.threshold)
199
+
200
+ accuracy1, accuracy2 = 0., 0.
201
+ if confusion1.sum() > 0:
202
+ accuracy1 = confusion1[0, 0] / confusion1.sum()
203
+
204
+ if confusion2.sum() > 0:
205
+ accuracy2 = confusion2[0, 0] / confusion2.sum()
206
+
207
+ if accuracy1 > accuracy2:
208
+ accuracy = accuracy1
209
+ confusion = confusion1
210
+ per_class_conf = per_class_conf1
211
+ reproj_errors = reproj_errors1
212
+ else:
213
+ accuracy = accuracy2
214
+ confusion = confusion2
215
+ per_class_conf = per_class_conf2
216
+ reproj_errors = reproj_errors2
217
+
218
+ accuracies.append(accuracy)
219
+ if confusion[0, :].sum() > 0:
220
+ precision = confusion[0, 0] / (confusion[0, :].sum())
221
+ precisions.append(precision)
222
+ if (confusion[0, 0] + confusion[1, 0]) > 0:
223
+ recall = confusion[0, 0] / (confusion[0, 0] + confusion[1, 0])
224
+ recalls.append(recall)
225
+
226
+ for line_class, errors in reproj_errors.items():
227
+ if line_class in dict_errors.keys():
228
+ dict_errors[line_class].extend(errors)
229
+ else:
230
+ dict_errors[line_class] = errors
231
+
232
+ for line_class, confusion_mat in per_class_conf.items():
233
+ if line_class in per_class_confusion_dict.keys():
234
+ per_class_confusion_dict[line_class] += confusion_mat
235
+ else:
236
+ per_class_confusion_dict[line_class] = confusion_mat
237
+
238
+ mRecall = np.mean(recalls)
239
+ sRecall = np.std(recalls)
240
+ medianRecall = np.median(recalls)
241
+ print(
242
+ f" On SoccerNet {args.split} set, recall mean value : {mRecall * 100:2.2f}% with standard deviation of {sRecall * 100:2.2f}% and median of {medianRecall * 100:2.2f}%")
243
+
244
+ mPrecision = np.mean(precisions)
245
+ sPrecision = np.std(precisions)
246
+ medianPrecision = np.median(precisions)
247
+ print(
248
+ f" On SoccerNet {args.split} set, precision mean value : {mPrecision * 100:2.2f}% with standard deviation of {sPrecision * 100:2.2f}% and median of {medianPrecision * 100:2.2f}%")
249
+
250
+ mAccuracy = np.mean(accuracies)
251
+ sAccuracy = np.std(accuracies)
252
+ medianAccuracy = np.median(accuracies)
253
+ print(
254
+ f" On SoccerNet {args.split} set, accuracy mean value : {mAccuracy * 100:2.2f}% with standard deviation of {sAccuracy * 100:2.2f}% and median of {medianAccuracy * 100:2.2f}%")
255
+
256
+ for line_class, confusion_mat in per_class_confusion_dict.items():
257
+ class_accuracy = confusion_mat[0, 0] / confusion_mat.sum()
258
+ class_recall = confusion_mat[0, 0] / (confusion_mat[0, 0] + confusion_mat[1, 0])
259
+ class_precision = confusion_mat[0, 0] / (confusion_mat[0, 0] + confusion_mat[0, 1])
260
+ print(
261
+ f"For class {line_class}, accuracy of {class_accuracy * 100:2.2f}%, precision of {class_precision * 100:2.2f}% and recall of {class_recall * 100:2.2f}%")
262
+
263
+ for k, v in dict_errors.items():
264
+ fig, ax1 = plt.subplots(figsize=(11, 8))
265
+ ax1.hist(v, bins=30, range=(0, 60))
266
+ ax1.set_title(k)
267
+ ax1.set_xlabel("Errors in pixel")
268
+ os.makedirs(f"./results/", exist_ok=True)
269
+ plt.savefig(f"./results/{k}_detection_error.png")
270
+ plt.close(fig)
tvcalib/sn_segmentation/src/masks_gt2chen.py ADDED
@@ -0,0 +1,217 @@
1
+ import pickle
2
+ import torch
3
+ from torch.utils.data import Dataset
4
+ import torchvision
5
+ import torchvision.transforms as T
6
+ from PIL import Image
7
+ import cv2
8
+ import numpy as np
9
+ import matplotlib.pyplot as plt
10
+ import pandas as pd
11
+ import h5py
12
+ from tqdm.auto import tqdm
13
+ from collections import defaultdict
14
+ from pathlib import Path
15
+ import json
16
+ import os
17
+ from argparse import ArgumentParser
18
+
19
+ lines_classes = [
20
+ "Big rect. left bottom",
21
+ "Big rect. left main",
22
+ "Big rect. left top",
23
+ "Big rect. right bottom",
24
+ "Big rect. right main",
25
+ "Big rect. right top",
26
+ "Circle central",
27
+ "Circle left",
28
+ "Circle right",
29
+ # "Goal left crossbar",
30
+ # "Goal left post left ",
31
+ # "Goal left post right",
32
+ # "Goal right crossbar",
33
+ # "Goal right post left",
34
+ # "Goal right post right",
35
+ "Goal unknown",
36
+ "Line unknown",
37
+ "Middle line",
38
+ "Side line bottom",
39
+ "Side line left",
40
+ "Side line right",
41
+ "Side line top",
42
+ "Small rect. left bottom",
43
+ "Small rect. left main",
44
+ "Small rect. left top",
45
+ "Small rect. right bottom",
46
+ "Small rect. right main",
47
+ "Small rect. right top",
48
+ ]
49
+
50
+ # RGB values
51
+ palette = {
52
+ "Big rect. left bottom": (127, 0, 0),
53
+ "Big rect. left main": (102, 102, 102),
54
+ "Big rect. left top": (0, 0, 127),
55
+ "Big rect. right bottom": (86, 32, 39),
56
+ "Big rect. right main": (48, 77, 0),
57
+ "Big rect. right top": (14, 97, 100),
58
+ "Circle central": (0, 0, 255),
59
+ "Circle left": (255, 127, 0),
60
+ "Circle right": (0, 255, 255),
61
+ # "Goal left crossbar": (255, 255, 200),
62
+ # "Goal left post left ": (165, 255, 0),
63
+ # "Goal left post right": (155, 119, 45),
64
+ # "Goal right crossbar": (86, 32, 139),
65
+ # "Goal right post left": (196, 120, 153),
66
+ # "Goal right post right": (166, 36, 52),
67
+ "Goal unknown": (0, 0, 0),
68
+ "Line unknown": (0, 0, 0),
69
+ "Middle line": (255, 255, 0),
70
+ "Side line bottom": (255, 0, 255),
71
+ "Side line left": (0, 255, 150),
72
+ "Side line right": (0, 230, 0),
73
+ "Side line top": (230, 0, 0),
74
+ "Small rect. left bottom": (0, 150, 255),
75
+ "Small rect. left main": (254, 173, 225),
76
+ "Small rect. left top": (87, 72, 39),
77
+ "Small rect. right bottom": (122, 0, 255),
78
+ "Small rect. right main": (255, 255, 255),
79
+ "Small rect. right top": (153, 23, 153),
80
+ }
81
+
82
+
83
+ def create_target_from_annotation(width, height, annotation, classes, linewidth=4):
84
+ """Draw one-hot encoded segments according to the annotation.
85
+ Creates target that matches image size ([C+1]xHxW).
86
+ """
87
+ annotation_abs = defaultdict(list)
88
+ # unnormalize every point in every class k
89
+ for k in annotation.keys():
90
+ if k not in lines_classes:
91
+ continue
92
+ start = annotation[k][0].copy()
93
+ end = annotation[k][-1].copy()
94
+ for annotation_point in annotation[k]:
95
+ tup = annotation_point.copy()
96
+ tup["x"] *= width
97
+ tup["x"] = int(tup["x"])
98
+ tup["y"] *= height
99
+ tup["y"] = int(tup["y"])
100
+ annotation_abs[k].append(tup)
101
+
102
+ # draw lines between annotated points for each segment
103
+ # offset class +1 such that no classes detected will end in argmax 0
104
+ # otherwise argmax 0 will be another class
105
+ classes_segments = np.zeros(shape=(len(classes) + 1, height, width))
106
+ for cls, points in annotation_abs.items():
107
+ class_segments = np.zeros(shape=(height, width, 3))
108
+ for start, end in zip(points, points[1:]):
109
+ startxy = (start["x"], start["y"])
110
+ endxy = [end["x"], end["y"]]
111
+ class_segments = cv2.line(
112
+ class_segments, startxy, endxy, (1, 1, 1), linewidth
113
+ )
114
+ classes_segments[classes.index(cls) + 1] = class_segments[:, :, 1]
115
+
116
+ classes_segments = torch.Tensor(classes_segments)
117
+ return classes_segments
118
+
119
+
120
+ class ExtremitiesDataset(Dataset):
121
+ def __init__(
122
+ self, root, split, annotations, filter_cam=None, extremities_prefix="", classes=lines_classes, palette=palette
123
+ ):
124
+ self.data_root = Path(root)
125
+ self.split = split
126
+
127
+ self.annotations_path = annotations
128
+
129
+ if filter_cam is None:
130
+ files = os.listdir(self.data_root / self.split)
131
+ self.annotations = sorted([fn for fn in files if fn.endswith("json")])
132
+ self.images = sorted([fn for fn in files if fn.endswith("jpg")])
133
+ else:
134
+ df = pd.read_json(self.data_root / self.split / "match_info_cam_gt.json").T
135
+ df = df.loc[df.camera == filter_cam]
136
+ assert len(df.index) > 0
137
+ df["image_file"] = df.index
138
+ df = df.sort_values(by=["image_file"])
139
+ df["annotation_file"] = df["image_file"].apply(
140
+ lambda s: extremities_prefix + s.split(".jpg")[0] + ".json"
141
+ )
142
+ self.annotations = df["annotation_file"].tolist()
143
+ self.images = df["image_file"].tolist()
144
+
145
+ self.classes = classes
146
+
147
+ def __len__(self):
148
+ return len(self.images)
149
+
150
+ def __getitem__(self, idx):
151
+ # see https://learnopencv.com/pytorch-for-beginners-semantic-segmentation-using-torchvision/
152
+
153
+ impath = self.data_root / self.split / self.images[idx]
154
+ annotation_path = self.annotations_path / self.annotations[idx]
155
+ with open(annotation_path, "r") as f:
156
+ annotation = json.load(f)
157
+
158
+ img = Image.open(impath) # .resize((1280, 720))
159
+ trf = T.Compose(
160
+ [
161
+ T.ToTensor(),
162
+ T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
163
+ ]
164
+ )
165
+ # prepare batches
166
+ img = trf(img)
167
+
168
+ # see https://git.tib.eu/vid2pos/sccvsd/-/blob/master/utils/synthetic_util.py
169
+ # draw lines (linewidth=4 for 720p) -> hence we rescale first
170
+ target = create_target_from_annotation(1280, 720, annotation, self.classes)
171
+ target = target.long().argmax(dim=0).unsqueeze(0)
172
+ # to binary mask
173
+ target = target.bool().float()
174
+ # rescale target equivalent to cv2.resize() with default args (interpolation bilinear) -> same as in torchvision
175
+ # bilinear -> [0, 0.25, 0.5, 0.7, 1.0] are entries
176
+ target = torchvision.transforms.Resize((180, 320))(target)
177
+ # print(torch.unique(target))
178
+ # to uint8
179
+ target = (target * 255.0).to(torch.uint8)
180
+
181
+ return img, target, impath.name
182
+
183
+
184
+ if __name__ == "__main__":
185
+
186
+ args = ArgumentParser()
187
+ args.add_argument("--data_dir", type=Path)
188
+ args.add_argument("--annotations", type=Path)
189
+ args.add_argument("--output_dir", type=Path)
190
+ args.add_argument("--extremities_prefix", type=str, default="")
191
+ args = args.parse_args()
192
+
193
+ data_dir = args.data_dir.parent
194
+ split = args.data_dir.name
195
+ output_dir = args.output_dir
196
+ if not output_dir.exists():
197
+         raise FileNotFoundError(f"Output directory {output_dir} does not exist")
198
+
199
+ dataset = ExtremitiesDataset(data_dir, split, args.annotations, filter_cam="Main camera center", extremities_prefix=args.extremities_prefix)
200
+
201
+ # img, edge_map, img_id = dataset[0]
202
+ # edge_map = edge_map.squeeze(0).numpy()
203
+ # Image.fromarray(edge_map).show()
204
+
205
+ image_src = []
206
+ edge_maps = np.zeros((len(dataset), 1, 180, 320), dtype=np.uint8)
207
+ for i, (_, edge_map, img_id) in enumerate(tqdm(dataset)):
208
+ edge_map = edge_map.numpy()
209
+ # Image.fromarray(edge_map).show()
210
+ edge_maps[i] = edge_map
211
+ image_src.append(img_id)
212
+
213
+ with h5py.File(output_dir / "seg_edge_maps.h5", "w") as f:
214
+ f.create_dataset("edge_map", data=edge_maps)
215
+
216
+ with open(output_dir / "seg_image_paths.pkl", "wb") as f:
217
+ pickle.dump(image_src, f)
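
For orientation (not part of the commit), a sketch of how the two files written by `masks_gt2chen.py` can be read back; the dataset name `edge_map` and the pickled list of image names come from the script above, while the directory path is illustrative:

```python
import pickle
import h5py

with h5py.File("output_dir/seg_edge_maps.h5", "r") as f:
    edge_maps = f["edge_map"][:]          # uint8 array of shape (N, 1, 180, 320)

with open("output_dir/seg_image_paths.pkl", "rb") as f:
    image_names = pickle.load(f)          # list of N image file names

assert len(image_names) == edge_maps.shape[0]
```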
tvcalib/sn_segmentation/src/masks_pred2chen.py ADDED
@@ -0,0 +1,150 @@
1
+ import argparse
2
+ import os.path
3
+ import pickle
4
+
5
+ import h5py
6
+ import numpy as np
7
+ import pandas as pd
8
+ from PIL import Image
9
+ from tqdm import tqdm
10
+ import cv2
11
+
12
+ from SoccerNet.Evaluation.utils_calibration import SoccerPitch
13
+
14
+ from custom_extremities import CustomNetwork
15
+
16
+ if __name__ == "__main__":
17
+ parser = argparse.ArgumentParser(description="Test")
18
+
19
+ parser.add_argument(
20
+ "-s",
21
+ "--soccernet",
22
+ default="/nfs/data/soccernet/calibration/",
23
+ type=str,
24
+ help="Path to the SoccerNet-V3 dataset folder",
25
+ )
26
+ parser.add_argument(
27
+ "-p",
28
+ "--prediction",
29
+ default="/nfs/home/rhotertj/datasets/sn-calib-test_endpoints",
30
+ required=False,
31
+ type=str,
32
+ help="Path to the prediction folder",
33
+ )
34
+ parser.add_argument(
35
+ "--split",
36
+ required=False,
37
+ type=str,
38
+ default="challenge",
39
+ help="Select the split of data",
40
+ )
41
+ parser.add_argument(
42
+ "--resolution_width",
43
+ required=False,
44
+ type=int,
45
+ default=455,
46
+ help="width resolution of the images",
47
+ )
48
+ parser.add_argument(
49
+ "--resolution_height",
50
+ required=False,
51
+ type=int,
52
+ default=256,
53
+ help="height resolution of the images",
54
+ )
55
+ parser.add_argument(
56
+ "--checkpoint",
57
+ required=False,
58
+ type=str,
59
+ help="Path to the custom model checkpoint.",
60
+ )
61
+ parser.add_argument("--filter_cam", type=str, required=False)
62
+ args = parser.parse_args()
63
+
64
+ lines_palette = [0, 0, 0]
65
+ for line_class in SoccerPitch.lines_classes:
66
+ print(line_class, SoccerPitch.palette[line_class])
67
+ lines_palette.extend(SoccerPitch.palette[line_class])
68
+
69
+ print(lines_palette)
70
+
71
+ # exit(0)
72
+
73
+ dataset_dir = os.path.join(args.soccernet, args.split)
74
+ if not os.path.exists(dataset_dir):
75
+ print("Invalid dataset path !")
76
+ exit(-1)
77
+
78
+ match_info_file = os.path.join(args.soccernet, args.split, "match_info_cam_gt.json")
79
+ print(match_info_file)
80
+ if not os.path.exists(match_info_file):
81
+ exit(-1)
82
+ df = pd.read_json(match_info_file).T
83
+ if args.filter_cam:
84
+ df = df.loc[df.camera == args.filter_cam]
85
+ df["image_file"] = df.index
86
+ df = df.sort_values(by=["image_file"])
87
+ print(df)
88
+
89
+ frames = df["image_file"].tolist()
90
+
91
+ model = CustomNetwork(args.checkpoint)
92
+
93
+ image_src = []
94
+ edge_maps = np.zeros((len(frames), 1, 180, 320), dtype=np.uint8)
95
+
96
+ kernel = np.ones((4, 4), np.uint8)
97
+
98
+ with tqdm(enumerate(frames), total=len(frames), ncols=100) as t:
99
+ for i, frame in t:
100
+
101
+ output_prediction_folder = args.prediction
102
+ if not os.path.exists(output_prediction_folder):
103
+ os.makedirs(output_prediction_folder)
104
+
105
+ frame_path = os.path.join(dataset_dir, frame)
106
+
107
+ frame_index = frame.split(".")[0]
108
+
109
+ image = Image.open(frame_path)
110
+
111
+ semlines = model.forward(image)
112
+
113
+ # print(semlines.shape, np.unique(semlines))
114
+ # set class 9-15 (goal parts) to background
115
+ mask_goal = (semlines >= 9) & (semlines <= 15)
116
+ semlines[mask_goal] = 0
117
+
118
+ mask = Image.fromarray(semlines.astype(np.uint8)).convert("P")
119
+ mask.putpalette(lines_palette)
120
+
121
+ # to binary edge map
122
+             mask = np.array(mask.convert("L"))  # np.array (not np.asarray) so the array is writable for the in-place threshold below
123
+ mask[mask > 0] = 255
124
+
125
+ mask = Image.fromarray(mask)
126
+ mask = mask.resize((320, 180), resample=Image.NEAREST)
127
+             # expected linewidth @ 720p resolution -> 4px
128
+
129
+ mask = np.asarray(mask)
130
+ # print(mask.shape)
131
+
132
+ mask = cv2.erode(mask, kernel, iterations=1)
133
+
134
+ # assert len(np.unique(mask)) == 2 # [0, 255]
135
+
136
+ # mask_file = os.path.join(output_prediction_folder, frame)
137
+ # mask.save(mask_file)
138
+ # print(mask)
139
+ # exit(0)
140
+
141
+ edge_maps[i] = mask
142
+ image_src.append(frame)
143
+
144
+ with h5py.File(
145
+ os.path.join(output_prediction_folder, "seg_edge_maps.h5"), "w"
146
+ ) as f:
147
+ f.create_dataset("edge_map", data=edge_maps)
148
+
149
+ with open(os.path.join(output_prediction_folder, "seg_image_paths.pkl"), "wb") as f:
150
+ pickle.dump(image_src, f)
tvcalib/sn_segmentation/src/segmentation/README.md ADDED
@@ -0,0 +1,23 @@
1
+ # Semantic segmentation reference training scripts
2
+
3
+ This directory is an edited copy of the [torchvision](https://github.com/pytorch/vision/tree/main/references/segmentation) reference scripts.
4
+
5
+ # Usage
6
+
7
+ Train from scratch:
8
+
9
+ ```
10
+ python train.py -b 8 --model deeplabv3_resnet101 --aux-loss --weights-backbone ResNet101_Weights.IMAGENET1K_V1 --epochs 30 --output-dir "./checkpoints" --split train
11
+ ```
12
+
13
+ Resume training:
14
+
15
+ ```
16
+ python train.py -b 8 --model deeplabv3_resnet101 --aux-loss --weights-backbone ResNet101_Weights.IMAGENET1K_V1 --epochs 60 --start-epoch 30 --weights /path/to/checkpoints.pt
17
+ ```
18
+
19
+ Evaluate checkpoint:
20
+
21
+ ```
22
+ python train.py -b 4 --model deeplabv3_resnet101 --aux-loss --weights-backbone ResNet101_Weights.IMAGENET1K_V1 --test-only --weights /path/to/checkpoints.pt
23
+ ```
tvcalib/sn_segmentation/src/segmentation/coco_utils.py ADDED
@@ -0,0 +1,108 @@
1
+ import copy
2
+ import os
3
+
4
+ import torch
5
+ import torch.utils.data
6
+ import torchvision
7
+ from PIL import Image
8
+ from pycocotools import mask as coco_mask
9
+ from transforms import Compose
10
+
11
+
12
+ class FilterAndRemapCocoCategories:
13
+ def __init__(self, categories, remap=True):
14
+ self.categories = categories
15
+ self.remap = remap
16
+
17
+ def __call__(self, image, anno):
18
+ anno = [obj for obj in anno if obj["category_id"] in self.categories]
19
+ if not self.remap:
20
+ return image, anno
21
+ anno = copy.deepcopy(anno)
22
+ for obj in anno:
23
+ obj["category_id"] = self.categories.index(obj["category_id"])
24
+ return image, anno
25
+
26
+
27
+ def convert_coco_poly_to_mask(segmentations, height, width):
28
+ masks = []
29
+ for polygons in segmentations:
30
+ rles = coco_mask.frPyObjects(polygons, height, width)
31
+ mask = coco_mask.decode(rles)
32
+ if len(mask.shape) < 3:
33
+ mask = mask[..., None]
34
+ mask = torch.as_tensor(mask, dtype=torch.uint8)
35
+ mask = mask.any(dim=2)
36
+ masks.append(mask)
37
+ if masks:
38
+ masks = torch.stack(masks, dim=0)
39
+ else:
40
+ masks = torch.zeros((0, height, width), dtype=torch.uint8)
41
+ return masks
42
+
43
+
44
+ class ConvertCocoPolysToMask:
45
+ def __call__(self, image, anno):
46
+ w, h = image.size
47
+ segmentations = [obj["segmentation"] for obj in anno]
48
+ cats = [obj["category_id"] for obj in anno]
49
+ if segmentations:
50
+ masks = convert_coco_poly_to_mask(segmentations, h, w)
51
+ cats = torch.as_tensor(cats, dtype=masks.dtype)
52
+ # merge all instance masks into a single segmentation map
53
+ # with its corresponding categories
54
+ target, _ = (masks * cats[:, None, None]).max(dim=0)
55
+ # discard overlapping instances
56
+ target[masks.sum(0) > 1] = 255
57
+ else:
58
+ target = torch.zeros((h, w), dtype=torch.uint8)
59
+ target = Image.fromarray(target.numpy())
60
+ return image, target
61
+
62
+
63
+ def _coco_remove_images_without_annotations(dataset, cat_list=None):
64
+ def _has_valid_annotation(anno):
65
+ # if it's empty, there is no annotation
66
+ if len(anno) == 0:
67
+ return False
68
+ # if more than 1k pixels occupied in the image
69
+ return sum(obj["area"] for obj in anno) > 1000
70
+
71
+ if not isinstance(dataset, torchvision.datasets.CocoDetection):
72
+ raise TypeError(
73
+ f"This function expects dataset of type torchvision.datasets.CocoDetection, instead got {type(dataset)}"
74
+ )
75
+
76
+ ids = []
77
+ for ds_idx, img_id in enumerate(dataset.ids):
78
+ ann_ids = dataset.coco.getAnnIds(imgIds=img_id, iscrowd=None)
79
+ anno = dataset.coco.loadAnns(ann_ids)
80
+ if cat_list:
81
+ anno = [obj for obj in anno if obj["category_id"] in cat_list]
82
+ if _has_valid_annotation(anno):
83
+ ids.append(ds_idx)
84
+
85
+ dataset = torch.utils.data.Subset(dataset, ids)
86
+ return dataset
87
+
88
+
89
+ def get_coco(root, image_set, transforms):
90
+ PATHS = {
91
+ "train": ("train2017", os.path.join("annotations", "instances_train2017.json")),
92
+ "val": ("val2017", os.path.join("annotations", "instances_val2017.json")),
93
+ # "train": ("val2017", os.path.join("annotations", "instances_val2017.json"))
94
+ }
95
+ CAT_LIST = [0, 5, 2, 16, 9, 44, 6, 3, 17, 62, 21, 67, 18, 19, 4, 1, 64, 20, 63, 7, 72]
96
+
97
+ transforms = Compose([FilterAndRemapCocoCategories(CAT_LIST, remap=True), ConvertCocoPolysToMask(), transforms])
98
+
99
+ img_folder, ann_file = PATHS[image_set]
100
+ img_folder = os.path.join(root, img_folder)
101
+ ann_file = os.path.join(root, ann_file)
102
+
103
+ dataset = torchvision.datasets.CocoDetection(img_folder, ann_file, transforms=transforms)
104
+
105
+ if image_set == "train":
106
+ dataset = _coco_remove_images_without_annotations(dataset, CAT_LIST)
107
+
108
+ return dataset
tvcalib/sn_segmentation/src/segmentation/presets.py ADDED
@@ -0,0 +1,39 @@
1
+ import torch
2
+ import transforms as T
3
+
4
+
5
+ class SegmentationPresetTrain:
6
+ def __init__(self, *, base_size, crop_size, hflip_prob=0.5, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
7
+ min_size = int(0.5 * base_size)
8
+ max_size = int(2.0 * base_size)
9
+
10
+ trans = [T.RandomResize(min_size, max_size)]
11
+ if hflip_prob > 0:
12
+ trans.append(T.RandomHorizontalFlip(hflip_prob))
13
+ trans.extend(
14
+ [
15
+ T.RandomCrop(crop_size),
16
+ T.PILToTensor(),
17
+ T.ConvertImageDtype(torch.float),
18
+ T.Normalize(mean=mean, std=std),
19
+ ]
20
+ )
21
+ self.transforms = T.Compose(trans)
22
+
23
+ def __call__(self, img, target):
24
+ return self.transforms(img, target)
25
+
26
+
27
+ class SegmentationPresetEval:
28
+ def __init__(self, *, base_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
29
+ self.transforms = T.Compose(
30
+ [
31
+ T.RandomResize(base_size, base_size),
32
+ T.PILToTensor(),
33
+ T.ConvertImageDtype(torch.float),
34
+ T.Normalize(mean=mean, std=std),
35
+ ]
36
+ )
37
+
38
+ def __call__(self, img, target):
39
+ return self.transforms(img, target)
tvcalib/sn_segmentation/src/segmentation/soccerdata.py ADDED
@@ -0,0 +1,164 @@
1
+ import torch
2
+ from torch.utils.data import Dataset
3
+ import torchvision
4
+ import torchvision.transforms as T
5
+ from PIL import Image
6
+ import cv2
7
+ import numpy as np
8
+ import matplotlib.pyplot as plt
9
+
10
+ from collections import defaultdict
11
+ from pathlib import Path
12
+ import json
13
+ import os
14
+
15
+ lines_classes = [
16
+ 'Big rect. left bottom',
17
+ 'Big rect. left main',
18
+ 'Big rect. left top',
19
+ 'Big rect. right bottom',
20
+ 'Big rect. right main',
21
+ 'Big rect. right top',
22
+ 'Circle central',
23
+ 'Circle left',
24
+ 'Circle right',
25
+ 'Goal left crossbar',
26
+ 'Goal left post left ',
27
+ 'Goal left post right',
28
+ 'Goal right crossbar',
29
+ 'Goal right post left',
30
+ 'Goal right post right',
31
+ 'Goal unknown',
32
+ 'Line unknown',
33
+ 'Middle line',
34
+ 'Side line bottom',
35
+ 'Side line left',
36
+ 'Side line right',
37
+ 'Side line top',
38
+ 'Small rect. left bottom',
39
+ 'Small rect. left main',
40
+ 'Small rect. left top',
41
+ 'Small rect. right bottom',
42
+ 'Small rect. right main',
43
+ 'Small rect. right top'
44
+ ]
45
+
46
+ # RGB values
47
+ palette = {
48
+ 'Big rect. left bottom': (127, 0, 0),
49
+ 'Big rect. left main': (102, 102, 102),
50
+ 'Big rect. left top': (0, 0, 127),
51
+ 'Big rect. right bottom': (86, 32, 39),
52
+ 'Big rect. right main': (48, 77, 0),
53
+ 'Big rect. right top': (14, 97, 100),
54
+ 'Circle central': (0, 0, 255),
55
+ 'Circle left': (255, 127, 0),
56
+ 'Circle right': (0, 255, 255),
57
+ 'Goal left crossbar': (255, 255, 200),
58
+ 'Goal left post left ': (165, 255, 0),
59
+ 'Goal left post right': (155, 119, 45),
60
+ 'Goal right crossbar': (86, 32, 139),
61
+ 'Goal right post left': (196, 120, 153),
62
+ 'Goal right post right': (166, 36, 52),
63
+ 'Goal unknown': (0, 0, 0),
64
+ 'Line unknown': (0, 0, 0),
65
+ 'Middle line': (255, 255, 0),
66
+ 'Side line bottom': (255, 0, 255),
67
+ 'Side line left': (0, 255, 150),
68
+ 'Side line right': (0, 230, 0),
69
+ 'Side line top': (230, 0, 0),
70
+ 'Small rect. left bottom': (0, 150, 255),
71
+ 'Small rect. left main': (254, 173, 225),
72
+ 'Small rect. left top': (87, 72, 39),
73
+ 'Small rect. right bottom': (122, 0, 255),
74
+ 'Small rect. right main': (255, 255, 255),
75
+ 'Small rect. right top': (153, 23, 153)
76
+ }
77
+
78
+ data_dir = Path("data/datasets")
79
+
80
+ def create_target_from_annotation(width, height, annotation, classes):
81
+ """Draw one-hot encoded segments according to the annotation.
82
+ Creates target that matches image size ([C+1]xHxW).
83
+ """
84
+ annotation_abs = defaultdict(list)
85
+ # unnormalize every point in every class k
86
+ for k in annotation.keys():
87
+ start = annotation[k][0].copy()
88
+ end = annotation[k][-1].copy()
89
+ for annotation_point in annotation[k]:
90
+ tup = annotation_point.copy()
91
+ tup["x"] *= width
92
+ tup["x"] = int(tup["x"])
93
+ tup["y"] *= height
94
+ tup["y"] = int(tup["y"])
95
+ annotation_abs[k].append(tup)
96
+
97
+ # draw lines between annotated points for each segment
98
+ # offset class +1 such that no classes detected will end in argmax 0
99
+ # otherwise argmax 0 will be another class
100
+ classes_segments = np.zeros(shape=(len(classes) + 1, height, width))
101
+ for cls, points in annotation_abs.items():
102
+ class_segments = np.zeros(shape=(height, width, 3))
103
+ for start, end in zip(points, points[1:]):
104
+ startxy = (start["x"], start["y"])
105
+ endxy = [end["x"], end["y"]]
106
+ class_segments = cv2.line(class_segments, startxy, endxy, (1,1,1), 5)
107
+ classes_segments[classes.index(cls) + 1] = class_segments[:,:,1]
108
+
109
+ classes_segments = torch.Tensor(classes_segments)
110
+ return classes_segments
111
+
112
+ class ExtremitiesDataset(Dataset):
113
+
114
+ def __init__(self, root, split, classes=lines_classes, palette=palette):
115
+ self.data_root = Path(root)
116
+ self.split = split
117
+ files = os.listdir(self.data_root / self.split)
118
+ self.annotations = sorted([fn for fn in files if fn.endswith("json")])
119
+ self.images = sorted([fn for fn in files if fn.endswith("jpg")])
120
+ #self.height, self.width = 224, 224
121
+ self.classes = classes
122
+
123
+ def __len__(self):
124
+ return len(self.images)
125
+
126
+ def __getitem__(self, idx):
127
+ # see https://learnopencv.com/pytorch-for-beginners-semantic-segmentation-using-torchvision/
128
+
129
+ impath = self.data_root / self.split / self.images[idx]
130
+ annotation_path = self.data_root / self.split / self.annotations[idx]
131
+ #print(impath)
132
+ #print(annotation_path)
133
+ with open(annotation_path, "r") as f:
134
+ annotation = json.load(f)
135
+
136
+ # setup image, cast to device later in training
137
+ img = Image.open(impath)
138
+ trf = T.Compose(
139
+ [
140
+ T.Resize(256),
141
+ #T.CenterCrop(224),
142
+ T.ToTensor(),
143
+ T.Normalize(
144
+ mean = [0.485, 0.456, 0.406],
145
+ std = [0.229, 0.224, 0.225]
146
+ )
147
+ ]
148
+ )
149
+ # prepare batches
150
+ img = trf(img)#.unsqueeze(0)
151
+ new_height, new_width = img.shape[-2], img.shape[-1]
152
+
153
+
154
+ target = create_target_from_annotation(new_width, new_height, annotation, self.classes)
155
+ #target = torchvision.transforms.functional.center_crop(target, 224)
156
+ target = target.long().argmax(dim=0)
157
+
158
+ return img, target
159
+
160
+ if __name__ == "__main__":
161
+ data = ExtremitiesDataset(root=data_dir, split="test")
162
+ print(data[0][1])
163
+ target = data[0][1].unsqueeze(0).permute(1,2,0)
164
+ plt.imshow(target)
tvcalib/sn_segmentation/src/segmentation/train.py ADDED
@@ -0,0 +1,341 @@
1
+ import datetime
2
+ import os
3
+ import time
4
+ import warnings
5
+
6
+ import numpy as np
7
+
8
+ import presets
9
+ import torch
10
+ import torch.utils.data
11
+ import torchvision
12
+ import utils
13
+ from coco_utils import get_coco
14
+ from torch import nn
15
+ from torchvision.transforms import functional as F, InterpolationMode
16
+ import wandb
17
+
18
+ from soccerdata import ExtremitiesDataset
19
+
20
+
21
+ def get_dataset(dir_path, name, image_set, transform):
22
+ def sbd(*args, **kwargs):
23
+ kwargs["root"] = "."
24
+ print(kwargs)
25
+ return torchvision.datasets.SBDataset(*args, mode="segmentation", download=True, **kwargs)
26
+
27
+ paths = {
28
+ "voc": (dir_path, torchvision.datasets.VOCSegmentation, 21),
29
+ "voc_aug": (dir_path, sbd, 21),
30
+ "coco": (dir_path, get_coco, 21),
31
+ }
32
+ p, ds_fn, num_classes = paths[name]
33
+
34
+ ds = ds_fn(p, image_set=image_set, transforms=transform)
35
+ return ds, num_classes
36
+
37
+
38
+ def get_transform(train, args):
39
+ if train:
40
+ return presets.SegmentationPresetTrain(base_size=520, crop_size=480)
41
+ elif args.weights and args.test_only:
42
+ weights = torchvision.models.get_weight(args.weights)
43
+ trans = weights.transforms()
44
+
45
+ def preprocessing(img, target):
46
+ img = trans(img)
47
+ size = F.get_dimensions(img)[1:]
48
+ target = F.resize(target, size, interpolation=InterpolationMode.NEAREST)
49
+ return img, F.pil_to_tensor(target)
50
+
51
+ return preprocessing
52
+ else:
53
+ return presets.SegmentationPresetEval(base_size=520)
54
+
55
+
56
+ def criterion(inputs, target):
57
+ losses = {}
58
+ for name, x in inputs.items():
59
+ losses[name] = nn.functional.cross_entropy(x, target, ignore_index=255)
60
+
61
+ if len(losses) == 1:
62
+ return losses["out"]
63
+
64
+ return losses["out"] + 0.5 * losses["aux"]
65
+
66
+
67
+ def evaluate(model, data_loader, device, num_classes):
68
+ model.eval()
69
+ confmat = utils.ConfusionMatrix(num_classes)
70
+ metric_logger = utils.MetricLogger(delimiter=" ")
71
+ header = "Test:"
72
+ num_processed_samples = 0
73
+ losses = []
74
+ with torch.inference_mode():
75
+ for image, target in metric_logger.log_every(data_loader, 100, header):
76
+ image, target = image.to(device), target.to(device)
77
+ output = model(image)
78
+ loss = criterion(output, target).unsqueeze(0).detach().cpu()
79
+ losses.append(loss)
80
+ output = output["out"]
81
+
82
+ # 1xCx224x224
83
+
84
+ confmat.update(target.flatten(), output.argmax(1).flatten())
85
+ # FIXME need to take into account that the datasets
86
+ # could have been padded in distributed setup
87
+ num_processed_samples += image.shape[0]
88
+
89
+ confmat.reduce_from_all_processes()
90
+
91
+ print(losses[0])
92
+ #print(losses)
93
+ losses = torch.cat(losses)
94
+ loss_argsort = torch.argsort(losses, descending=True)
95
+ loss_argsort.numpy()
96
+ np.save("losses_argsort.npy", loss_argsort)
97
+
98
+ num_processed_samples = utils.reduce_across_processes(num_processed_samples)
99
+ if (
100
+ hasattr(data_loader.dataset, "__len__")
101
+ and len(data_loader.dataset) != num_processed_samples
102
+ and torch.distributed.get_rank() == 0
103
+ ):
104
+ # See FIXME above
105
+ warnings.warn(
106
+ f"It looks like the dataset has {len(data_loader.dataset)} samples, but {num_processed_samples} "
107
+ "samples were used for the validation, which might bias the results. "
108
+ "Try adjusting the batch size and / or the world size. "
109
+ "Setting the world size to 1 is always a safe bet."
110
+ )
111
+
112
+ return confmat
113
+
114
+
115
+ def train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, device, epoch, print_freq, scaler=None):
116
+ model.train()
117
+ metric_logger = utils.MetricLogger(delimiter=" ")
118
+ metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
119
+ header = f"Epoch: [{epoch}]"
120
+ for image, target in metric_logger.log_every(data_loader, print_freq, header):
121
+ image, target = image.to(device), target.to(device)
122
+ with torch.cuda.amp.autocast(enabled=scaler is not None):
123
+ output = model(image)
124
+ loss = criterion(output, target)
125
+ wandb.log({"loss": loss})
126
+
127
+ optimizer.zero_grad()
128
+ if scaler is not None:
129
+ scaler.scale(loss).backward()
130
+ scaler.step(optimizer)
131
+ scaler.update()
132
+ else:
133
+ loss.backward()
134
+ optimizer.step()
135
+
136
+ lr_scheduler.step()
137
+
138
+ metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])
139
+
140
+
141
+ def main(args):
142
+ wandb.init(project="sn-calibration", entity="rhotertj")
143
+ if args.output_dir:
144
+ utils.mkdir(args.output_dir)
145
+
146
+ utils.init_distributed_mode(args)
147
+ print(args)
148
+
149
+ device = torch.device(args.device)
150
+
151
+ if args.use_deterministic_algorithms:
152
+ torch.backends.cudnn.benchmark = False
153
+ torch.use_deterministic_algorithms(True)
154
+ else:
155
+ torch.backends.cudnn.benchmark = True
156
+
157
+ #dataset, num_classes = get_dataset(args.data_path, args.dataset, "train", get_transform(True, args))
158
+ dataset = ExtremitiesDataset(root="/nfs/data/soccernet/calibration", split="train")
159
+ num_classes = len(dataset.classes) + 1
160
+
161
+ dataset_test = ExtremitiesDataset(root="/nfs/data/soccernet/calibration", split="test")
162
+ #dataset_test, _ = get_dataset(args.data_path, args.dataset, "val", get_transform(False, args))
163
+
164
+ if args.distributed:
165
+ train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
166
+ test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test, shuffle=False)
167
+ else:
168
+ train_sampler = torch.utils.data.RandomSampler(dataset)
169
+ test_sampler = torch.utils.data.SequentialSampler(dataset_test)
170
+
171
+ data_loader = torch.utils.data.DataLoader(
172
+ dataset,
173
+ batch_size=args.batch_size,
174
+ sampler=train_sampler,
175
+ num_workers=args.workers,
176
+ collate_fn=utils.collate_fn,
177
+ drop_last=True,
178
+ )
179
+
180
+ data_loader_test = torch.utils.data.DataLoader(
181
+ dataset_test, batch_size=1, sampler=test_sampler, num_workers=args.workers, collate_fn=utils.collate_fn
182
+ )
183
+
184
+
185
+ model = torchvision.models.segmentation.deeplabv3_resnet101(num_classes=num_classes, aux_loss=args.aux_loss)
186
+ if args.test_only or args.resume:
187
+ model.load_state_dict(torch.load(args.weights)["model"], strict=False)
188
+
189
+ #model = torchvision.models.segmentation.__dict__[args.model](
190
+ # weights=args.weights, num_classes=num_classes, aux_loss=args.aux_loss
191
+ #)
192
+ model.to(device)
193
+ if args.distributed:
194
+ model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
195
+
196
+ model_without_ddp = model
197
+ if args.distributed:
198
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
199
+ model_without_ddp = model.module
200
+
201
+ params_to_optimize = [
202
+ {"params": [p for p in model_without_ddp.backbone.parameters() if p.requires_grad]},
203
+ {"params": [p for p in model_without_ddp.classifier.parameters() if p.requires_grad]},
204
+ ]
205
+ if args.aux_loss:
206
+ params = [p for p in model_without_ddp.aux_classifier.parameters() if p.requires_grad]
207
+ params_to_optimize.append({"params": params, "lr": args.lr * 10})
208
+ optimizer = torch.optim.SGD(params_to_optimize, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
209
+
210
+ scaler = torch.cuda.amp.GradScaler() if args.amp else None
211
+
212
+ iters_per_epoch = len(data_loader)
213
+ main_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
214
+ optimizer, lambda x: (1 - x / (iters_per_epoch * (args.epochs - args.lr_warmup_epochs))) ** 0.9
215
+ )
216
+
217
+ if args.lr_warmup_epochs > 0:
218
+ warmup_iters = iters_per_epoch * args.lr_warmup_epochs
219
+ args.lr_warmup_method = args.lr_warmup_method.lower()
220
+ if args.lr_warmup_method == "linear":
221
+ warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(
222
+ optimizer, start_factor=args.lr_warmup_decay, total_iters=warmup_iters
223
+ )
224
+ elif args.lr_warmup_method == "constant":
225
+ warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(
226
+ optimizer, factor=args.lr_warmup_decay, total_iters=warmup_iters
227
+ )
228
+ else:
229
+ raise RuntimeError(
230
+ f"Invalid warmup lr method '{args.lr_warmup_method}'. Only linear and constant are supported."
231
+ )
232
+ lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
233
+ optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler], milestones=[warmup_iters]
234
+ )
235
+ else:
236
+ lr_scheduler = main_lr_scheduler
237
+
238
+ if args.resume:
239
+         checkpoint = torch.load(args.resume, map_location="cpu")  # load here so the optimizer / lr_scheduler state below can be restored
240
+ #model_without_ddp.load_state_dict(checkpoint["model"], strict=not args.test_only)
241
+ if not args.test_only:
242
+ optimizer.load_state_dict(checkpoint["optimizer"])
243
+ lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
244
+ args.start_epoch = checkpoint["epoch"] + 1
245
+ if args.amp:
246
+ scaler.load_state_dict(checkpoint["scaler"])
247
+
248
+ if args.test_only:
249
+ # We disable the cudnn benchmarking because it can noticeably affect the accuracy
250
+ torch.backends.cudnn.benchmark = False
251
+ torch.backends.cudnn.deterministic = True
252
+ confmat = evaluate(model, data_loader_test, device=device, num_classes=num_classes)
253
+ print(confmat)
254
+ return
255
+
256
+ start_time = time.time()
257
+ for epoch in range(args.start_epoch, args.epochs):
258
+ if args.distributed:
259
+ train_sampler.set_epoch(epoch)
260
+ train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, device, epoch, args.print_freq, scaler)
261
+ confmat = evaluate(model, data_loader_test, device=device, num_classes=num_classes)
262
+ print(confmat)
263
+ checkpoint = {
264
+ "model": model_without_ddp.state_dict(),
265
+ "optimizer": optimizer.state_dict(),
266
+ "lr_scheduler": lr_scheduler.state_dict(),
267
+ "epoch": epoch,
268
+ "args": args,
269
+ }
270
+ if args.amp:
271
+ checkpoint["scaler"] = scaler.state_dict()
272
+ utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
273
+ utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))
274
+
275
+ total_time = time.time() - start_time
276
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
277
+ print(f"Training time {total_time_str}")
278
+
279
+
280
+ def get_args_parser(add_help=True):
281
+ import argparse
282
+
283
+ parser = argparse.ArgumentParser(description="PyTorch Segmentation Training", add_help=add_help)
284
+
285
+ parser.add_argument("--data-path", default="/datasets01/COCO/022719/", type=str, help="dataset path")
286
+ parser.add_argument("--dataset", default="coco", type=str, help="dataset name")
287
+ parser.add_argument("--model", default="fcn_resnet101", type=str, help="model name")
288
+ parser.add_argument("--aux-loss", action="store_true", help="auxiliar loss")
289
+ parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)")
290
+ parser.add_argument(
291
+ "-b", "--batch-size", default=8, type=int, help="images per gpu, the total batch size is $NGPU x batch_size"
292
+ )
293
+ parser.add_argument("--epochs", default=30, type=int, metavar="N", help="number of total epochs to run")
294
+
295
+ parser.add_argument(
296
+ "-j", "--workers", default=16, type=int, metavar="N", help="number of data loading workers (default: 16)"
297
+ )
298
+ parser.add_argument("--lr", default=0.01, type=float, help="initial learning rate")
299
+ parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
300
+ parser.add_argument(
301
+ "--wd",
302
+ "--weight-decay",
303
+ default=1e-4,
304
+ type=float,
305
+ metavar="W",
306
+ help="weight decay (default: 1e-4)",
307
+ dest="weight_decay",
308
+ )
309
+ parser.add_argument("--lr-warmup-epochs", default=0, type=int, help="the number of epochs to warmup (default: 0)")
310
+ parser.add_argument("--lr-warmup-method", default="linear", type=str, help="the warmup method (default: linear)")
311
+ parser.add_argument("--lr-warmup-decay", default=0.01, type=float, help="the decay for lr")
312
+ parser.add_argument("--print-freq", default=10, type=int, help="print frequency")
313
+ parser.add_argument("--output-dir", default=".", type=str, help="path to save outputs")
314
+ parser.add_argument("--resume", default="", type=str, help="path of checkpoint")
315
+ parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="start epoch")
316
+ parser.add_argument(
317
+ "--test-only",
318
+ dest="test_only",
319
+ help="Only test the model",
320
+ action="store_true",
321
+ )
322
+ parser.add_argument(
323
+ "--use-deterministic-algorithms", action="store_true", help="Forces the use of deterministic algorithms only."
324
+ )
325
+ # distributed training parameters
326
+ parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
327
+ parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")
328
+
329
+ parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
330
+ parser.add_argument("--weights-backbone", default=None, type=str, help="the backbone weights enum name to load")
331
+
332
+ # Mixed precision training parameters
333
+ parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")
334
+ parser.add_argument("--split", default="train", type=str, help="Dataset split to be used for training")
335
+
336
+ return parser
337
+
338
+
339
+ if __name__ == "__main__":
340
+ args = get_args_parser().parse_args()
341
+ main(args)
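
As a side note (not part of the commit), the `LambdaLR` defined in `main()` applies a polynomial decay with exponent 0.9 over all training iterations. A sketch of the resulting learning-rate factor, with illustrative values for the assumed parameters:

```python
# illustrative values; in train.py these come from len(data_loader) and the CLI args
iters_per_epoch = 100
epochs = 30
lr_warmup_epochs = 0
base_lr = 0.01

total_iters = iters_per_epoch * (epochs - lr_warmup_epochs)
for step in (0, total_iters // 2, total_iters - 1):
    factor = (1 - step / total_iters) ** 0.9   # same expression as the lambda above
    print(f"iteration {step}: lr = {base_lr * factor:.5f}")
```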
tvcalib/sn_segmentation/src/segmentation/transforms.py ADDED
@@ -0,0 +1,100 @@
1
+ import random
2
+
3
+ import numpy as np
4
+ import torch
5
+ from torchvision import transforms as T
6
+ from torchvision.transforms import functional as F
7
+
8
+
9
+ def pad_if_smaller(img, size, fill=0):
10
+ min_size = min(img.size)
11
+ if min_size < size:
12
+ ow, oh = img.size
13
+ padh = size - oh if oh < size else 0
14
+ padw = size - ow if ow < size else 0
15
+ img = F.pad(img, (0, 0, padw, padh), fill=fill)
16
+ return img
17
+
18
+
19
+ class Compose:
20
+ def __init__(self, transforms):
21
+ self.transforms = transforms
22
+
23
+ def __call__(self, image, target):
24
+ for t in self.transforms:
25
+ image, target = t(image, target)
26
+ return image, target
27
+
28
+
29
+ class RandomResize:
30
+ def __init__(self, min_size, max_size=None):
31
+ self.min_size = min_size
32
+ if max_size is None:
33
+ max_size = min_size
34
+ self.max_size = max_size
35
+
36
+ def __call__(self, image, target):
37
+ size = random.randint(self.min_size, self.max_size)
38
+ image = F.resize(image, size)
39
+ target = F.resize(target, size, interpolation=T.InterpolationMode.NEAREST)
40
+ return image, target
41
+
42
+
43
+ class RandomHorizontalFlip:
44
+ def __init__(self, flip_prob):
45
+ self.flip_prob = flip_prob
46
+
47
+ def __call__(self, image, target):
48
+ if random.random() < self.flip_prob:
49
+ image = F.hflip(image)
50
+ target = F.hflip(target)
51
+ return image, target
52
+
53
+
54
+ class RandomCrop:
55
+ def __init__(self, size):
56
+ self.size = size
57
+
58
+ def __call__(self, image, target):
59
+ image = pad_if_smaller(image, self.size)
60
+ target = pad_if_smaller(target, self.size, fill=255)
61
+ crop_params = T.RandomCrop.get_params(image, (self.size, self.size))
62
+ image = F.crop(image, *crop_params)
63
+ target = F.crop(target, *crop_params)
64
+ return image, target
65
+
66
+
67
+ class CenterCrop:
68
+ def __init__(self, size):
69
+ self.size = size
70
+
71
+ def __call__(self, image, target):
72
+ image = F.center_crop(image, self.size)
73
+ target = F.center_crop(target, self.size)
74
+ return image, target
75
+
76
+
77
+ class PILToTensor:
78
+ def __call__(self, image, target):
79
+ image = F.pil_to_tensor(image)
80
+ target = torch.as_tensor(np.array(target), dtype=torch.int64)
81
+ return image, target
82
+
83
+
84
+ class ConvertImageDtype:
85
+ def __init__(self, dtype):
86
+ self.dtype = dtype
87
+
88
+ def __call__(self, image, target):
89
+ image = F.convert_image_dtype(image, self.dtype)
90
+ return image, target
91
+
92
+
93
+ class Normalize:
94
+ def __init__(self, mean, std):
95
+ self.mean = mean
96
+ self.std = std
97
+
98
+ def __call__(self, image, target):
99
+ image = F.normalize(image, mean=self.mean, std=self.std)
100
+ return image, target
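
Unlike stock torchvision transforms, every transform above takes an `(image, target)` pair so that geometric operations stay in sync between image and mask. A hedged sketch of a composition mirroring `SegmentationPresetTrain`, assuming this `transforms.py` is importable from the working directory; the dummy image and mask are illustrative:

```python
import torch
from PIL import Image

import transforms as T  # the module defined above

paired = T.Compose([
    T.RandomResize(260, 1040),
    T.RandomHorizontalFlip(0.5),
    T.RandomCrop(480),
    T.PILToTensor(),
    T.ConvertImageDtype(torch.float),
    T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])

img = Image.new("RGB", (640, 360))
mask = Image.new("L", (640, 360))
img_t, mask_t = paired(img, mask)   # float CHW image, int64 HxW mask
```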
tvcalib/sn_segmentation/src/segmentation/utils.py ADDED
@@ -0,0 +1,304 @@
1
+ import datetime
2
+ import errno
3
+ import os
4
+ import time
5
+ from collections import defaultdict, deque
6
+
7
+ import torch
8
+ import torch.distributed as dist
9
+
10
+
11
+ class SmoothedValue:
12
+ """Track a series of values and provide access to smoothed values over a
13
+ window or the global series average.
14
+ """
15
+
16
+ def __init__(self, window_size=20, fmt=None):
17
+ if fmt is None:
18
+ fmt = "{median:.4f} ({global_avg:.4f})"
19
+ self.deque = deque(maxlen=window_size)
20
+ self.total = 0.0
21
+ self.count = 0
22
+ self.fmt = fmt
23
+
24
+ def update(self, value, n=1):
25
+ self.deque.append(value)
26
+ self.count += n
27
+ self.total += value * n
28
+
29
+ def synchronize_between_processes(self):
30
+ """
31
+ Warning: does not synchronize the deque!
32
+ """
33
+ t = reduce_across_processes([self.count, self.total])
34
+ t = t.tolist()
35
+ self.count = int(t[0])
36
+ self.total = t[1]
37
+
38
+ @property
39
+ def median(self):
40
+ d = torch.tensor(list(self.deque))
41
+ return d.median().item()
42
+
43
+ @property
44
+ def avg(self):
45
+ d = torch.tensor(list(self.deque), dtype=torch.float32)
46
+ return d.mean().item()
47
+
48
+ @property
49
+ def global_avg(self):
50
+ return self.total / self.count
51
+
52
+ @property
53
+ def max(self):
54
+ return max(self.deque)
55
+
56
+ @property
57
+ def value(self):
58
+ return self.deque[-1]
59
+
60
+ def __str__(self):
61
+ return self.fmt.format(
62
+ median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value
63
+ )
64
+
65
+
66
+ class ConfusionMatrix:
67
+ def __init__(self, num_classes):
68
+ self.num_classes = num_classes
69
+ self.mat = None
70
+
71
+ def update(self, a, b):
72
+ n = self.num_classes
73
+ if self.mat is None:
74
+ self.mat = torch.zeros((n, n), dtype=torch.int64, device=a.device)
75
+ with torch.inference_mode():
76
+ k = (a >= 0) & (a < n)
77
+ inds = n * a[k].to(torch.int64) + b[k]
78
+ self.mat += torch.bincount(inds, minlength=n ** 2).reshape(n, n)
79
+
80
+ def reset(self):
81
+ self.mat.zero_()
82
+
83
+ def compute(self):
84
+ h = self.mat.float()
85
+ acc_global = torch.diag(h).sum() / h.sum()
86
+ acc = torch.diag(h) / h.sum(1)
87
+ iu = torch.diag(h) / (h.sum(1) + h.sum(0) - torch.diag(h))
88
+ return acc_global, acc, iu
89
+
90
+ def reduce_from_all_processes(self):
91
+ reduce_across_processes(self.mat)
92
+
93
+ def __str__(self):
94
+ acc_global, acc, iu = self.compute()
95
+ return ("global correct: {:.1f}\naverage row correct: {}\nIoU: {}\nmean IoU: {:.1f}").format(
96
+ acc_global.item() * 100,
97
+ [f"{i:.1f}" for i in (acc * 100).tolist()],
98
+ [f"{i:.1f}" for i in (iu * 100).tolist()],
99
+ iu.mean().item() * 100,
100
+ )
101
+
102
+
103
+ class MetricLogger:
104
+ def __init__(self, delimiter="\t"):
105
+ self.meters = defaultdict(SmoothedValue)
106
+ self.delimiter = delimiter
107
+
108
+ def update(self, **kwargs):
109
+ for k, v in kwargs.items():
110
+ if isinstance(v, torch.Tensor):
111
+ v = v.item()
112
+ if not isinstance(v, (float, int)):
113
+ raise TypeError(
114
+ f"This method expects the value of the input arguments to be of type float or int, instead got {type(v)}"
115
+ )
116
+ self.meters[k].update(v)
117
+
118
+ def __getattr__(self, attr):
119
+ if attr in self.meters:
120
+ return self.meters[attr]
121
+ if attr in self.__dict__:
122
+ return self.__dict__[attr]
123
+ raise AttributeError(f"'{type(self).__name__}' object has no attribute '{attr}'")
124
+
125
+ def __str__(self):
126
+ loss_str = []
127
+ for name, meter in self.meters.items():
128
+ loss_str.append(f"{name}: {str(meter)}")
129
+ return self.delimiter.join(loss_str)
130
+
131
+ def synchronize_between_processes(self):
132
+ for meter in self.meters.values():
133
+ meter.synchronize_between_processes()
134
+
135
+ def add_meter(self, name, meter):
136
+ self.meters[name] = meter
137
+
138
+ def log_every(self, iterable, print_freq, header=None):
139
+ i = 0
140
+ if not header:
141
+ header = ""
142
+ start_time = time.time()
143
+ end = time.time()
144
+ iter_time = SmoothedValue(fmt="{avg:.4f}")
145
+ data_time = SmoothedValue(fmt="{avg:.4f}")
146
+ space_fmt = ":" + str(len(str(len(iterable)))) + "d"
147
+ if torch.cuda.is_available():
148
+ log_msg = self.delimiter.join(
149
+ [
150
+ header,
151
+ "[{0" + space_fmt + "}/{1}]",
152
+ "eta: {eta}",
153
+ "{meters}",
154
+ "time: {time}",
155
+ "data: {data}",
156
+ "max mem: {memory:.0f}",
157
+ ]
158
+ )
159
+ else:
160
+ log_msg = self.delimiter.join(
161
+ [header, "[{0" + space_fmt + "}/{1}]", "eta: {eta}", "{meters}", "time: {time}", "data: {data}"]
162
+ )
163
+ MB = 1024.0 * 1024.0
164
+ for obj in iterable:
165
+ data_time.update(time.time() - end)
166
+ yield obj
167
+ iter_time.update(time.time() - end)
168
+ if i % print_freq == 0:
169
+ eta_seconds = iter_time.global_avg * (len(iterable) - i)
170
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
171
+ if torch.cuda.is_available():
172
+ print(
173
+ log_msg.format(
174
+ i,
175
+ len(iterable),
176
+ eta=eta_string,
177
+ meters=str(self),
178
+ time=str(iter_time),
179
+ data=str(data_time),
180
+ memory=torch.cuda.max_memory_allocated() / MB,
181
+ )
182
+ )
183
+ else:
184
+ print(
185
+ log_msg.format(
186
+ i, len(iterable), eta=eta_string, meters=str(self), time=str(iter_time), data=str(data_time)
187
+ )
188
+ )
189
+ i += 1
190
+ end = time.time()
191
+ total_time = time.time() - start_time
192
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
193
+ print(f"{header} Total time: {total_time_str}")
194
+
195
+
196
+ def cat_list(images, fill_value=0):
197
+ max_size = tuple(max(s) for s in zip(*[img.shape for img in images]))
198
+ batch_shape = (len(images),) + max_size
199
+ batched_imgs = images[0].new(*batch_shape).fill_(fill_value)
200
+ for img, pad_img in zip(images, batched_imgs):
201
+ pad_img[..., : img.shape[-2], : img.shape[-1]].copy_(img)
202
+ return batched_imgs
203
+
204
+
205
+ def collate_fn(batch):
206
+ images, targets = list(zip(*batch))
207
+ batched_imgs = cat_list(images, fill_value=0)
208
+ batched_targets = cat_list(targets, fill_value=255)
209
+ return batched_imgs, batched_targets
210
+
211
+
212
+ def mkdir(path):
213
+ try:
214
+ os.makedirs(path)
215
+ except OSError as e:
216
+ if e.errno != errno.EEXIST:
217
+ raise
218
+
219
+
220
+ def setup_for_distributed(is_master):
221
+ """
222
+ This function disables printing when not in master process
223
+ """
224
+ import builtins as __builtin__
225
+
226
+ builtin_print = __builtin__.print
227
+
228
+ def print(*args, **kwargs):
229
+ force = kwargs.pop("force", False)
230
+ if is_master or force:
231
+ builtin_print(*args, **kwargs)
232
+
233
+ __builtin__.print = print
234
+
235
+
236
+ def is_dist_avail_and_initialized():
237
+ if not dist.is_available():
238
+ return False
239
+ if not dist.is_initialized():
240
+ return False
241
+ return True
242
+
243
+
244
+ def get_world_size():
245
+ if not is_dist_avail_and_initialized():
246
+ return 1
247
+ return dist.get_world_size()
248
+
249
+
250
+ def get_rank():
251
+ if not is_dist_avail_and_initialized():
252
+ return 0
253
+ return dist.get_rank()
254
+
255
+
256
+ def is_main_process():
257
+ return get_rank() == 0
258
+
259
+
260
+ def save_on_master(*args, **kwargs):
261
+ if is_main_process():
262
+ torch.save(*args, **kwargs)
263
+
264
+
265
+ def init_distributed_mode(args):
266
+ print("Not using distributed mode")
267
+ args.distributed = False
268
+ return
269
+
270
+ if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
271
+ args.rank = int(os.environ["RANK"])
272
+ args.world_size = int(os.environ["WORLD_SIZE"])
273
+ args.gpu = int(os.environ["LOCAL_RANK"])
274
+ elif "SLURM_PROCID" in os.environ:
275
+ args.rank = int(os.environ["SLURM_PROCID"])
276
+ args.gpu = args.rank % torch.cuda.device_count()
277
+ elif hasattr(args, "rank"):
278
+ pass
279
+ else:
280
+ print("Not using distributed mode")
281
+ args.distributed = False
282
+ return
283
+
284
+ args.distributed = True
285
+
286
+ torch.cuda.set_device(args.gpu)
287
+ args.dist_backend = "nccl"
288
+ print(f"| distributed init (rank {args.rank}): {args.dist_url}", flush=True)
289
+ torch.distributed.init_process_group(
290
+ backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank
291
+ )
292
+ torch.distributed.barrier()
293
+ setup_for_distributed(args.rank == 0)
294
+
295
+
296
+ def reduce_across_processes(val):
297
+ if not is_dist_avail_and_initialized():
298
+ # nothing to sync, but we still convert to tensor for consistency with the distributed case.
299
+ return torch.tensor(val)
300
+
301
+ t = torch.tensor(val, device="cuda")
302
+ dist.barrier()
303
+ dist.all_reduce(t)
304
+ return t
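
A small standalone sketch (not part of the commit) of the logging helpers above, assuming the module is importable as `utils` from this directory:

```python
import utils

logger = utils.MetricLogger(delimiter="  ")
logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))

for step in range(5):
    logger.update(loss=1.0 / (step + 1), lr=0.01)

print(logger)  # loss reported as "median (global_avg)", lr as its last value
```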
tvcalib/utils/data_distr.py ADDED
@@ -0,0 +1,44 @@
1
+ import torch
2
+
3
+
4
+ def mean_std_with_confidence_interval(
5
+ vmin, vmax, sigma_scale: float, _steps=1000, round_decimals=4
6
+ ):
7
+ """Computes mean and std given min,max values with respect a confidence interval (sigma_scale).
8
+
9
+ sigma_scale = 1.65 -> 90% of samples are in range [vmin, vmax]
10
+ sigma_scale = 1.96 -> 95% of samples are in range [vmin, vmax]
11
+ sigma_scale = 2.58 -> 99% of samples are in range [vmin, vmax]
12
+ """
13
+
14
+ # sample from uniform distribution
15
+ x = torch.linspace(vmin, vmax, _steps)
16
+ mu = x.mean(dim=-1)
17
+ sigma = x.std(dim=-1)
18
+ return (round(mu.item(), round_decimals), round((sigma * sigma_scale).item(), round_decimals))
19
+
20
+
21
+ class FeatureScalerZScore(torch.nn.Module):
22
+ def __init__(self, loc: float, scale: float) -> None:
23
+ # Transforms data from distribution parameterized by loc (mean) and scale (=sigma*scaling factor).
24
+ super(FeatureScalerZScore, self).__init__()
25
+
26
+ self.loc = loc
27
+ self.scale = scale
28
+
29
+ def forward(self, z):
30
+ """
31
+ Args:
32
+ z (Tensor): tensor of size (B, *) to be denormalized.
33
+ Returns:
34
+ x: tensor.
35
+ """
36
+ return self.denormalize(z)
37
+
38
+ def denormalize(self, z):
39
+ x = z * self.scale + self.loc
40
+ return x
41
+
42
+ def normalize(self, x):
43
+ z = (x - self.loc) / self.scale
44
+ return z
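
A hedged usage sketch (not part of the commit) combining the two helpers above, assuming they are importable from `tvcalib.utils.data_distr`; the value range is illustrative:

```python
from tvcalib.utils.data_distr import FeatureScalerZScore, mean_std_with_confidence_interval

# a parameter expected to lie in [10.0, 60.0] for ~95% of samples (sigma_scale=1.96)
loc, scale = mean_std_with_confidence_interval(10.0, 60.0, sigma_scale=1.96)

scaler = FeatureScalerZScore(loc, scale)
z = scaler.normalize(35.0)    # raw value -> z-score space
x = scaler.denormalize(z)     # and back again
```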
tvcalib/utils/io.py ADDED
@@ -0,0 +1,44 @@
1
+ import yaml
2
+ import json
3
+ from typing import List
4
+ import torch
5
+
6
+
7
+ def tensor2list(d: dict):
8
+ tensor2list_lambda = lambda x: x.detach().cpu().numpy().tolist()
9
+ for k in d.keys():
10
+ if isinstance(d[k], torch.Tensor):
11
+ d[k] = tensor2list_lambda(d[k])
12
+ if isinstance(d[k], List):
13
+ if isinstance(d[k][0], torch.Tensor):
14
+ d[k] = [tensor2list_lambda(x) for x in d[k]]
15
+ return d
16
+
17
+
18
+ def write_json(json_serializable_dict, fout, indent=2):
19
+ with open(fout, "w") as fw:
20
+ json.dump(json_serializable_dict, fw, indent=indent)
21
+
22
+
23
+ def write_yaml(json_serializable_dict, fout):
24
+ with open(fout, "w") as fw:
25
+ yaml.dump(json_serializable_dict, fw, default_flow_style=False)
26
+
27
+
28
+ def detach_dict(x_dict):
29
+ with torch.no_grad():
30
+ for k in x_dict.keys():
31
+ if isinstance(x_dict[k], torch.Tensor):
32
+ x_dict[k] = x_dict[k].detach().cpu()
33
+ elif isinstance(x_dict[k], dict):
34
+ x_dict[k] = detach_dict(x_dict[k])
35
+ return x_dict
36
+
37
+
38
+ def tensor2list(xdict):
39
+ for k in xdict.keys():
40
+ if isinstance(xdict[k], torch.Tensor):
41
+ xdict[k] = xdict[k].numpy().tolist()
42
+ elif isinstance(xdict[k], dict):
43
+ xdict[k] = tensor2list(xdict[k])
44
+ return xdict
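
A brief sketch (not part of the commit) of serializing a dict of tensors with the helpers above, assuming they are importable from `tvcalib.utils.io`; the dict contents are illustrative:

```python
import torch
from tvcalib.utils.io import detach_dict, tensor2list, write_json

result = {"homography": torch.eye(3), "meta": {"loss": torch.tensor(0.42)}}
result = detach_dict(result)    # detach and move tensors to CPU, recursively
result = tensor2list(result)    # tensors -> nested Python lists
write_json(result, "result.json")
```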
tvcalib/utils/linalg.py ADDED
@@ -0,0 +1,106 @@
1
+ from typing import Optional, Union
2
+ import torch
3
+ from kornia.geometry.conversions import convert_points_from_homogeneous
4
+
5
+
6
+ class LineCollection:
7
+ def __init__(
8
+ self,
9
+ support: torch.tensor,
10
+ direction_norm: torch.tensor,
11
+ direction: Optional[torch.tensor] = None,
12
+ ):
13
+ """Wrapper class to represent lines by support and direction vectors.
14
+
15
+ Args:
16
+ support (torch.tensor): with shape (*, {2,3})
17
+ direction_norm (torch.tensor): with shape (*, {2,3})
18
+ direction (Optional[torch.tensor], optional): Unnormalized direction vector. Defaults to None.
19
+ """
20
+ self.support = support
21
+ self.direction_norm = direction_norm
22
+ self.direction = direction
23
+
24
+ def __copy__(self):
25
+ return LineCollection(
26
+ self.support.clone(),
27
+ self.direction_norm.clone(),
28
+ self.direction.clone() if self.direction is not None else None,
29
+ )
30
+
31
+ def copy(self):
32
+ return self.__copy__()
33
+
34
+ def shape(self):
35
+ return f"support={self.support.shape} direction_norm={self.direction_norm.shape} direction={self.direction.shape if self.direction else None}"
36
+
37
+ def __repr__(self) -> str:
38
+ return f"{self.__class__} " + self.shape()
39
+
40
+
41
+ def distance_line_pointcloud_3d(
42
+ e1: torch.Tensor,
43
+ r1: torch.Tensor,
44
+ pc: torch.Tensor,
45
+ reduce: Union[None, str] = None,
46
+ ) -> torch.Tensor:
47
+ """
48
+ Line to point cloud distance with arbitrary leading dimensions.
49
+
50
+     TODO: if cross == (0, 0, 0) -> distance = 0, otherwise NaNs are returned
51
+
52
+ https://mathworld.wolfram.com/Point-LineDistance2-Dimensional.html
53
+ Args:
54
+ e1 (torch.Tensor): direction vector of shape (*, B, 1, 3)
55
+ r1 (torch.Tensor): support vector of shape (*, B, 1, 3)
56
+ pc (torch.Tensor): point cloud of shape (*, B, A, 3)
57
+ reduce (Union[None, str]): reduce distance for all points to one using 'mean' or 'min'
58
+ Returns:
59
+ distance of an infinite line to given points, (*, B, ) using reduce='mean' or reduce='min' or (*, B, A) if reduce=False
60
+ """
61
+
62
+ num_points = pc.shape[-2]
63
+ _sub = r1 - pc # (*, B, A, 3)
64
+
65
+ cross = torch.cross(e1.repeat_interleave(num_points, dim=-2), _sub, dim=-1) # (*, B, A, 3)
66
+
67
+ e1_norm = torch.linalg.norm(e1, dim=-1)
68
+ cross_norm = torch.linalg.norm(cross, dim=-1)
69
+
70
+ d = cross_norm / e1_norm
71
+ if reduce == "mean":
72
+ return d.mean(dim=-1) # (*, B, )
73
+ elif reduce == "min":
74
+ return d.min(dim=-1)[0] # (*, B, )
75
+
76
+ return d # (B, A)
77
+
78
+
79
+ def distance_point_pointcloud(points: torch.Tensor, pointcloud: torch.Tensor) -> torch.Tensor:
80
+ """Batched version for point-pointcloud distance calculation
81
+ Args:
82
+         points (torch.Tensor): N points in homogeneous coordinates; shape (B, T, 3, S, N)
83
+ pointcloud (torch.Tensor): N_star points for each pointcloud; shape (B, T, S, N_star, 2)
84
+
85
+ Returns:
86
+ torch.Tensor: Minimum distance for each point N to pointcloud; shape (B, T, 1, S, N)
87
+ """
88
+
89
+ batch_size, T, _, S, N = points.shape
90
+ batch_size, T, S, N_star, _ = pointcloud.shape
91
+
92
+ pointcloud = pointcloud.reshape(batch_size * T * S, N_star, 2)
93
+
94
+ points = convert_points_from_homogeneous(
95
+ points.permute(0, 1, 3, 4, 2).reshape(batch_size * T * S, N, 3)
96
+ )
97
+
98
+ # cdist signature: (B, P, M), (B, R, M) -> (B, P, R)
99
+ distances = torch.cdist(points, pointcloud, p=2) # (B*T*S, N, N_star)
100
+
101
+ distances = distances.view(batch_size, T, S, N, N_star)
102
+ distances = distances.unsqueeze(-4)
103
+
104
+ # distance to nearest point from point cloud (batch_size, T, 1, S, N, N_star)
105
+ distances = distances.min(dim=-1)[0]
106
+ return distances
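
A worked example (not part of the commit) of `distance_line_pointcloud_3d` with the shapes documented above, assuming the module is importable as `tvcalib.utils.linalg`; a single line along the x-axis and two test points:

```python
import torch
from tvcalib.utils.linalg import distance_line_pointcloud_3d

e1 = torch.tensor([[[1.0, 0.0, 0.0]]])    # direction, shape (B=1, 1, 3)
r1 = torch.tensor([[[0.0, 0.0, 0.0]]])    # support,   shape (B=1, 1, 3)
pc = torch.tensor([[[5.0, 3.0, 0.0],
                    [2.0, 0.0, 4.0]]])    # points,    shape (B=1, A=2, 3)

d = distance_line_pointcloud_3d(e1, r1, pc, reduce="min")
print(d)  # tensor([3.]) -> the closer point is 3 units from the line
```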
tvcalib/utils/objects_3d.py ADDED
@@ -0,0 +1,1674 @@
1
+ from abc import ABCMeta
2
+ from typing import List
3
+ import kornia
4
+ import torch
5
+ import numpy as np
6
+ import random
7
+ from .linalg import LineCollection
8
+ from torch.utils.data import Dataset, DataLoader
9
+ from pytorch_lightning import LightningDataModule
10
+
11
+
12
+ class SoccerPitchSN:
13
+ """Static class variables that are specified by the rules of the game"""
14
+
15
+ GOAL_LINE_TO_PENALTY_MARK = 11.0
16
+ PENALTY_AREA_WIDTH = 40.32
17
+ PENALTY_AREA_LENGTH = 16.5
18
+ GOAL_AREA_WIDTH = 18.32
19
+ GOAL_AREA_LENGTH = 5.5
20
+ CENTER_CIRCLE_RADIUS = 9.15
21
+ GOAL_HEIGHT = 2.44
22
+ GOAL_LENGTH = 7.32
23
+
24
+ lines_classes = [
25
+ "Big rect. left bottom",
26
+ "Big rect. left main",
27
+ "Big rect. left top",
28
+ "Big rect. right bottom",
29
+ "Big rect. right main",
30
+ "Big rect. right top",
31
+ "Circle central",
32
+ "Circle left",
33
+ "Circle right",
34
+ "Goal left crossbar",
35
+ "Goal left post left ",
36
+ "Goal left post right",
37
+ "Goal right crossbar",
38
+ "Goal right post left",
39
+ "Goal right post right",
40
+ "Goal unknown",
41
+ "Line unknown",
42
+ "Middle line",
43
+ "Side line bottom",
44
+ "Side line left",
45
+ "Side line right",
46
+ "Side line top",
47
+ "Small rect. left bottom",
48
+ "Small rect. left main",
49
+ "Small rect. left top",
50
+ "Small rect. right bottom",
51
+ "Small rect. right main",
52
+ "Small rect. right top",
53
+ ]
54
+
55
+ symetric_classes = {
56
+ "Side line top": "Side line bottom",
57
+ "Side line bottom": "Side line top",
58
+ "Side line left": "Side line right",
59
+ "Middle line": "Middle line",
60
+ "Side line right": "Side line left",
61
+ "Big rect. left top": "Big rect. right bottom",
62
+ "Big rect. left bottom": "Big rect. right top",
63
+ "Big rect. left main": "Big rect. right main",
64
+ "Big rect. right top": "Big rect. left bottom",
65
+ "Big rect. right bottom": "Big rect. left top",
66
+ "Big rect. right main": "Big rect. left main",
67
+ "Small rect. left top": "Small rect. right bottom",
68
+ "Small rect. left bottom": "Small rect. right top",
69
+ "Small rect. left main": "Small rect. right main",
70
+ "Small rect. right top": "Small rect. left bottom",
71
+ "Small rect. right bottom": "Small rect. left top",
72
+ "Small rect. right main": "Small rect. left main",
73
+ "Circle left": "Circle right",
74
+ "Circle central": "Circle central",
75
+ "Circle right": "Circle left",
76
+ "Goal left crossbar": "Goal right crossbar",
77
+ "Goal left post left ": "Goal right post right",
78
+ "Goal left post right": "Goal right post left",
79
+ "Goal right crossbar": "Goal left crossbar",
80
+ "Goal right post left": "Goal left post right",
81
+ "Goal right post right": "Goal left post left ",
82
+ "Goal unknown": "Goal unknown",
83
+ "Line unknown": "Line unknown",
84
+ }
85
+
86
+ # RGB values
87
+ palette = {
88
+ "Big rect. left bottom": (127, 0, 0),
89
+ "Big rect. left main": (102, 102, 102),
90
+ "Big rect. left top": (0, 0, 127),
91
+ "Big rect. right bottom": (86, 32, 39),
92
+ "Big rect. right main": (48, 77, 0),
93
+ "Big rect. right top": (14, 97, 100),
94
+ "Circle central": (0, 0, 255),
95
+ "Circle left": (255, 127, 0),
96
+ "Circle right": (0, 255, 255),
97
+ "Goal left crossbar": (255, 255, 200),
98
+ "Goal left post left ": (165, 255, 0),
99
+ "Goal left post right": (155, 119, 45),
100
+ "Goal right crossbar": (86, 32, 139),
101
+ "Goal right post left": (196, 120, 153),
102
+ "Goal right post right": (166, 36, 52),
103
+ "Goal unknown": (0, 0, 0),
104
+ "Line unknown": (0, 0, 0),
105
+ "Middle line": (255, 255, 0),
106
+ "Side line bottom": (255, 0, 255),
107
+ "Side line left": (0, 255, 150),
108
+ "Side line right": (0, 230, 0),
109
+ "Side line top": (230, 0, 0),
110
+ "Small rect. left bottom": (0, 150, 255),
111
+ "Small rect. left main": (254, 173, 225),
112
+ "Small rect. left top": (87, 72, 39),
113
+ "Small rect. right bottom": (122, 0, 255),
114
+ "Small rect. right main": (128, 128, 128), # (255, 255, 255)
115
+ "Small rect. right top": (153, 23, 153),
116
+ }
117
+
118
+ def __init__(self, pitch_length=105.0, pitch_width=68.0):
119
+ """
120
+ Initialize 3D coordinates of all elements of the soccer pitch.
121
+ :param pitch_length: According to FIFA rules, the length belongs to [90, 120] meters
122
+ :param pitch_width: According to FIFA rules, the width belongs to [45, 90] meters
123
+ """
124
+ self.PITCH_LENGTH = pitch_length
125
+ self.PITCH_WIDTH = pitch_width
126
+
127
+ self.center_mark = np.array([0, 0, 0], dtype="float")
128
+ self.halfway_and_bottom_touch_line_mark = np.array([0, pitch_width / 2.0, 0], dtype="float")
129
+ self.halfway_and_top_touch_line_mark = np.array([0, -pitch_width / 2.0, 0], dtype="float")
130
+ self.halfway_line_and_center_circle_top_mark = np.array(
131
+ [0, -SoccerPitchSN.CENTER_CIRCLE_RADIUS, 0], dtype="float"
132
+ )
133
+ self.halfway_line_and_center_circle_bottom_mark = np.array(
134
+ [0, SoccerPitchSN.CENTER_CIRCLE_RADIUS, 0], dtype="float"
135
+ )
136
+ self.bottom_right_corner = np.array(
137
+ [pitch_length / 2.0, pitch_width / 2.0, 0], dtype="float"
138
+ )
139
+ self.bottom_left_corner = np.array(
140
+ [-pitch_length / 2.0, pitch_width / 2.0, 0], dtype="float"
141
+ )
142
+ self.top_right_corner = np.array([pitch_length / 2.0, -pitch_width / 2.0, 0], dtype="float")
143
+ self.top_left_corner = np.array([-pitch_length / 2.0, -pitch_width / 2.0, 0], dtype="float")
144
+
145
+ self.left_goal_bottom_left_post = np.array(
146
+ [-pitch_length / 2.0, SoccerPitchSN.GOAL_LENGTH / 2.0, 0.0], dtype="float"
147
+ )
148
+ self.left_goal_top_left_post = np.array(
149
+ [-pitch_length / 2.0, SoccerPitchSN.GOAL_LENGTH / 2.0, -SoccerPitchSN.GOAL_HEIGHT],
150
+ dtype="float",
151
+ )
152
+ self.left_goal_bottom_right_post = np.array(
153
+ [-pitch_length / 2.0, -SoccerPitchSN.GOAL_LENGTH / 2.0, 0.0], dtype="float"
154
+ )
155
+ self.left_goal_top_right_post = np.array(
156
+ [-pitch_length / 2.0, -SoccerPitchSN.GOAL_LENGTH / 2.0, -SoccerPitchSN.GOAL_HEIGHT],
157
+ dtype="float",
158
+ )
159
+
160
+ self.right_goal_bottom_left_post = np.array(
161
+ [pitch_length / 2.0, -SoccerPitchSN.GOAL_LENGTH / 2.0, 0.0], dtype="float"
162
+ )
163
+ self.right_goal_top_left_post = np.array(
164
+ [pitch_length / 2.0, -SoccerPitchSN.GOAL_LENGTH / 2.0, -SoccerPitchSN.GOAL_HEIGHT],
165
+ dtype="float",
166
+ )
167
+ self.right_goal_bottom_right_post = np.array(
168
+ [pitch_length / 2.0, SoccerPitchSN.GOAL_LENGTH / 2.0, 0.0], dtype="float"
169
+ )
170
+ self.right_goal_top_right_post = np.array(
171
+ [pitch_length / 2.0, SoccerPitchSN.GOAL_LENGTH / 2.0, -SoccerPitchSN.GOAL_HEIGHT],
172
+ dtype="float",
173
+ )
174
+
175
+ self.left_penalty_mark = np.array(
176
+ [-pitch_length / 2.0 + SoccerPitchSN.GOAL_LINE_TO_PENALTY_MARK, 0, 0], dtype="float"
177
+ )
178
+ self.right_penalty_mark = np.array(
179
+ [pitch_length / 2.0 - SoccerPitchSN.GOAL_LINE_TO_PENALTY_MARK, 0, 0], dtype="float"
180
+ )
181
+
182
+ self.left_penalty_area_top_right_corner = np.array(
183
+ [
184
+ -pitch_length / 2.0 + SoccerPitchSN.PENALTY_AREA_LENGTH,
185
+ -SoccerPitchSN.PENALTY_AREA_WIDTH / 2.0,
186
+ 0,
187
+ ],
188
+ dtype="float",
189
+ )
190
+ self.left_penalty_area_top_left_corner = np.array(
191
+ [-pitch_length / 2.0, -SoccerPitchSN.PENALTY_AREA_WIDTH / 2.0, 0], dtype="float"
192
+ )
193
+ self.left_penalty_area_bottom_right_corner = np.array(
194
+ [
195
+ -pitch_length / 2.0 + SoccerPitchSN.PENALTY_AREA_LENGTH,
196
+ SoccerPitchSN.PENALTY_AREA_WIDTH / 2.0,
197
+ 0,
198
+ ],
199
+ dtype="float",
200
+ )
201
+ self.left_penalty_area_bottom_left_corner = np.array(
202
+ [-pitch_length / 2.0, SoccerPitchSN.PENALTY_AREA_WIDTH / 2.0, 0], dtype="float"
203
+ )
204
+ self.right_penalty_area_top_right_corner = np.array(
205
+ [pitch_length / 2.0, -SoccerPitchSN.PENALTY_AREA_WIDTH / 2.0, 0], dtype="float"
206
+ )
207
+ self.right_penalty_area_top_left_corner = np.array(
208
+ [
209
+ pitch_length / 2.0 - SoccerPitchSN.PENALTY_AREA_LENGTH,
210
+ -SoccerPitchSN.PENALTY_AREA_WIDTH / 2.0,
211
+ 0,
212
+ ],
213
+ dtype="float",
214
+ )
215
+ self.right_penalty_area_bottom_right_corner = np.array(
216
+ [pitch_length / 2.0, SoccerPitchSN.PENALTY_AREA_WIDTH / 2.0, 0], dtype="float"
217
+ )
218
+ self.right_penalty_area_bottom_left_corner = np.array(
219
+ [
220
+ pitch_length / 2.0 - SoccerPitchSN.PENALTY_AREA_LENGTH,
221
+ SoccerPitchSN.PENALTY_AREA_WIDTH / 2.0,
222
+ 0,
223
+ ],
224
+ dtype="float",
225
+ )
226
+
227
+ self.left_goal_area_top_right_corner = np.array(
228
+ [
229
+ -pitch_length / 2.0 + SoccerPitchSN.GOAL_AREA_LENGTH,
230
+ -SoccerPitchSN.GOAL_AREA_WIDTH / 2.0,
231
+ 0,
232
+ ],
233
+ dtype="float",
234
+ )
235
+ self.left_goal_area_top_left_corner = np.array(
236
+ [-pitch_length / 2.0, -SoccerPitchSN.GOAL_AREA_WIDTH / 2.0, 0], dtype="float"
237
+ )
238
+ self.left_goal_area_bottom_right_corner = np.array(
239
+ [
240
+ -pitch_length / 2.0 + SoccerPitchSN.GOAL_AREA_LENGTH,
241
+ SoccerPitchSN.GOAL_AREA_WIDTH / 2.0,
242
+ 0,
243
+ ],
244
+ dtype="float",
245
+ )
246
+ self.left_goal_area_bottom_left_corner = np.array(
247
+ [-pitch_length / 2.0, SoccerPitchSN.GOAL_AREA_WIDTH / 2.0, 0], dtype="float"
248
+ )
249
+ self.right_goal_area_top_right_corner = np.array(
250
+ [pitch_length / 2.0, -SoccerPitchSN.GOAL_AREA_WIDTH / 2.0, 0], dtype="float"
251
+ )
252
+ self.right_goal_area_top_left_corner = np.array(
253
+ [
254
+ pitch_length / 2.0 - SoccerPitchSN.GOAL_AREA_LENGTH,
255
+ -SoccerPitchSN.GOAL_AREA_WIDTH / 2.0,
256
+ 0,
257
+ ],
258
+ dtype="float",
259
+ )
260
+ self.right_goal_area_bottom_right_corner = np.array(
261
+ [pitch_length / 2.0, SoccerPitchSN.GOAL_AREA_WIDTH / 2.0, 0], dtype="float"
262
+ )
263
+ self.right_goal_area_bottom_left_corner = np.array(
264
+ [
265
+ pitch_length / 2.0 - SoccerPitchSN.GOAL_AREA_LENGTH,
266
+ SoccerPitchSN.GOAL_AREA_WIDTH / 2.0,
267
+ 0,
268
+ ],
269
+ dtype="float",
270
+ )
271
+
272
+ x = -pitch_length / 2.0 + SoccerPitchSN.PENALTY_AREA_LENGTH
273
+ dx = SoccerPitchSN.PENALTY_AREA_LENGTH - SoccerPitchSN.GOAL_LINE_TO_PENALTY_MARK
274
+ y = -np.sqrt(
275
+ SoccerPitchSN.CENTER_CIRCLE_RADIUS * SoccerPitchSN.CENTER_CIRCLE_RADIUS - dx * dx
276
+ )
277
+ self.top_left_16M_penalty_arc_mark = np.array([x, y, 0], dtype="float")
278
+
279
+ x = pitch_length / 2.0 - SoccerPitchSN.PENALTY_AREA_LENGTH
280
+ dx = SoccerPitchSN.PENALTY_AREA_LENGTH - SoccerPitchSN.GOAL_LINE_TO_PENALTY_MARK
281
+ y = -np.sqrt(
282
+ SoccerPitchSN.CENTER_CIRCLE_RADIUS * SoccerPitchSN.CENTER_CIRCLE_RADIUS - dx * dx
283
+ )
284
+ self.top_right_16M_penalty_arc_mark = np.array([x, y, 0], dtype="float")
285
+
286
+ x = -pitch_length / 2.0 + SoccerPitchSN.PENALTY_AREA_LENGTH
287
+ dx = SoccerPitchSN.PENALTY_AREA_LENGTH - SoccerPitchSN.GOAL_LINE_TO_PENALTY_MARK
288
+ y = np.sqrt(
289
+ SoccerPitchSN.CENTER_CIRCLE_RADIUS * SoccerPitchSN.CENTER_CIRCLE_RADIUS - dx * dx
290
+ )
291
+ self.bottom_left_16M_penalty_arc_mark = np.array([x, y, 0], dtype="float")
292
+
293
+ x = pitch_length / 2.0 - SoccerPitchSN.PENALTY_AREA_LENGTH
294
+ dx = SoccerPitchSN.PENALTY_AREA_LENGTH - SoccerPitchSN.GOAL_LINE_TO_PENALTY_MARK
295
+ y = np.sqrt(
296
+ SoccerPitchSN.CENTER_CIRCLE_RADIUS * SoccerPitchSN.CENTER_CIRCLE_RADIUS - dx * dx
297
+ )
298
+ self.bottom_right_16M_penalty_arc_mark = np.array([x, y, 0], dtype="float")
299
+
300
+ # self.set_elevations(elevation)
301
+
302
+ self.point_dict = {}
303
+ self.point_dict["CENTER_MARK"] = self.center_mark
304
+ self.point_dict["L_PENALTY_MARK"] = self.left_penalty_mark
305
+ self.point_dict["R_PENALTY_MARK"] = self.right_penalty_mark
306
+ self.point_dict["TL_PITCH_CORNER"] = self.top_left_corner
307
+ self.point_dict["BL_PITCH_CORNER"] = self.bottom_left_corner
308
+ self.point_dict["TR_PITCH_CORNER"] = self.top_right_corner
309
+ self.point_dict["BR_PITCH_CORNER"] = self.bottom_right_corner
310
+ self.point_dict["L_PENALTY_AREA_TL_CORNER"] = self.left_penalty_area_top_left_corner
311
+ self.point_dict["L_PENALTY_AREA_TR_CORNER"] = self.left_penalty_area_top_right_corner
312
+ self.point_dict["L_PENALTY_AREA_BL_CORNER"] = self.left_penalty_area_bottom_left_corner
313
+ self.point_dict["L_PENALTY_AREA_BR_CORNER"] = self.left_penalty_area_bottom_right_corner
314
+
315
+ self.point_dict["R_PENALTY_AREA_TL_CORNER"] = self.right_penalty_area_top_left_corner
316
+ self.point_dict["R_PENALTY_AREA_TR_CORNER"] = self.right_penalty_area_top_right_corner
317
+ self.point_dict["R_PENALTY_AREA_BL_CORNER"] = self.right_penalty_area_bottom_left_corner
318
+ self.point_dict["R_PENALTY_AREA_BR_CORNER"] = self.right_penalty_area_bottom_right_corner
319
+
320
+ self.point_dict["L_GOAL_AREA_TL_CORNER"] = self.left_goal_area_top_left_corner
321
+ self.point_dict["L_GOAL_AREA_TR_CORNER"] = self.left_goal_area_top_right_corner
322
+ self.point_dict["L_GOAL_AREA_BL_CORNER"] = self.left_goal_area_bottom_left_corner
323
+ self.point_dict["L_GOAL_AREA_BR_CORNER"] = self.left_goal_area_bottom_right_corner
324
+
325
+ self.point_dict["R_GOAL_AREA_TL_CORNER"] = self.right_goal_area_top_left_corner
326
+ self.point_dict["R_GOAL_AREA_TR_CORNER"] = self.right_goal_area_top_right_corner
327
+ self.point_dict["R_GOAL_AREA_BL_CORNER"] = self.right_goal_area_bottom_left_corner
328
+ self.point_dict["R_GOAL_AREA_BR_CORNER"] = self.right_goal_area_bottom_right_corner
329
+
330
+ self.point_dict["L_GOAL_TL_POST"] = self.left_goal_top_left_post
331
+ self.point_dict["L_GOAL_TR_POST"] = self.left_goal_top_right_post
332
+ self.point_dict["L_GOAL_BL_POST"] = self.left_goal_bottom_left_post
333
+ self.point_dict["L_GOAL_BR_POST"] = self.left_goal_bottom_right_post
334
+
335
+ self.point_dict["R_GOAL_TL_POST"] = self.right_goal_top_left_post
336
+ self.point_dict["R_GOAL_TR_POST"] = self.right_goal_top_right_post
337
+ self.point_dict["R_GOAL_BL_POST"] = self.right_goal_bottom_left_post
338
+ self.point_dict["R_GOAL_BR_POST"] = self.right_goal_bottom_right_post
339
+
340
+ self.point_dict[
341
+ "T_TOUCH_AND_HALFWAY_LINES_INTERSECTION"
342
+ ] = self.halfway_and_top_touch_line_mark
343
+ self.point_dict[
344
+ "B_TOUCH_AND_HALFWAY_LINES_INTERSECTION"
345
+ ] = self.halfway_and_bottom_touch_line_mark
346
+ self.point_dict[
347
+ "T_HALFWAY_LINE_AND_CENTER_CIRCLE_INTERSECTION"
348
+ ] = self.halfway_line_and_center_circle_top_mark
349
+ self.point_dict[
350
+ "B_HALFWAY_LINE_AND_CENTER_CIRCLE_INTERSECTION"
351
+ ] = self.halfway_line_and_center_circle_bottom_mark
352
+ self.point_dict[
353
+ "TL_16M_LINE_AND_PENALTY_ARC_INTERSECTION"
354
+ ] = self.top_left_16M_penalty_arc_mark
355
+ self.point_dict[
356
+ "BL_16M_LINE_AND_PENALTY_ARC_INTERSECTION"
357
+ ] = self.bottom_left_16M_penalty_arc_mark
358
+ self.point_dict[
359
+ "TR_16M_LINE_AND_PENALTY_ARC_INTERSECTION"
360
+ ] = self.top_right_16M_penalty_arc_mark
361
+ self.point_dict[
362
+ "BR_16M_LINE_AND_PENALTY_ARC_INTERSECTION"
363
+ ] = self.bottom_right_16M_penalty_arc_mark
364
+
365
+ self.line_extremities = dict()
366
+ self.line_extremities["Big rect. left bottom"] = (
367
+ self.point_dict["L_PENALTY_AREA_BL_CORNER"],
368
+ self.point_dict["L_PENALTY_AREA_BR_CORNER"],
369
+ )
370
+ self.line_extremities["Big rect. left top"] = (
371
+ self.point_dict["L_PENALTY_AREA_TL_CORNER"],
372
+ self.point_dict["L_PENALTY_AREA_TR_CORNER"],
373
+ )
374
+ self.line_extremities["Big rect. left main"] = (
375
+ self.point_dict["L_PENALTY_AREA_TR_CORNER"],
376
+ self.point_dict["L_PENALTY_AREA_BR_CORNER"],
377
+ )
378
+ self.line_extremities["Big rect. right bottom"] = (
379
+ self.point_dict["R_PENALTY_AREA_BL_CORNER"],
380
+ self.point_dict["R_PENALTY_AREA_BR_CORNER"],
381
+ )
382
+ self.line_extremities["Big rect. right top"] = (
383
+ self.point_dict["R_PENALTY_AREA_TL_CORNER"],
384
+ self.point_dict["R_PENALTY_AREA_TR_CORNER"],
385
+ )
386
+ self.line_extremities["Big rect. right main"] = (
387
+ self.point_dict["R_PENALTY_AREA_TL_CORNER"],
388
+ self.point_dict["R_PENALTY_AREA_BL_CORNER"],
389
+ )
390
+
391
+ self.line_extremities["Small rect. left bottom"] = (
392
+ self.point_dict["L_GOAL_AREA_BL_CORNER"],
393
+ self.point_dict["L_GOAL_AREA_BR_CORNER"],
394
+ )
395
+ self.line_extremities["Small rect. left top"] = (
396
+ self.point_dict["L_GOAL_AREA_TL_CORNER"],
397
+ self.point_dict["L_GOAL_AREA_TR_CORNER"],
398
+ )
399
+ self.line_extremities["Small rect. left main"] = (
400
+ self.point_dict["L_GOAL_AREA_TR_CORNER"],
401
+ self.point_dict["L_GOAL_AREA_BR_CORNER"],
402
+ )
403
+ self.line_extremities["Small rect. right bottom"] = (
404
+ self.point_dict["R_GOAL_AREA_BL_CORNER"],
405
+ self.point_dict["R_GOAL_AREA_BR_CORNER"],
406
+ )
407
+ self.line_extremities["Small rect. right top"] = (
408
+ self.point_dict["R_GOAL_AREA_TL_CORNER"],
409
+ self.point_dict["R_GOAL_AREA_TR_CORNER"],
410
+ )
411
+ self.line_extremities["Small rect. right main"] = (
412
+ self.point_dict["R_GOAL_AREA_TL_CORNER"],
413
+ self.point_dict["R_GOAL_AREA_BL_CORNER"],
414
+ )
415
+
416
+ self.line_extremities["Side line top"] = (
417
+ self.point_dict["TL_PITCH_CORNER"],
418
+ self.point_dict["TR_PITCH_CORNER"],
419
+ )
420
+ self.line_extremities["Side line bottom"] = (
421
+ self.point_dict["BL_PITCH_CORNER"],
422
+ self.point_dict["BR_PITCH_CORNER"],
423
+ )
424
+ self.line_extremities["Side line left"] = (
425
+ self.point_dict["TL_PITCH_CORNER"],
426
+ self.point_dict["BL_PITCH_CORNER"],
427
+ )
428
+ self.line_extremities["Side line right"] = (
429
+ self.point_dict["TR_PITCH_CORNER"],
430
+ self.point_dict["BR_PITCH_CORNER"],
431
+ )
432
+ self.line_extremities["Middle line"] = (
433
+ self.point_dict["T_TOUCH_AND_HALFWAY_LINES_INTERSECTION"],
434
+ self.point_dict["B_TOUCH_AND_HALFWAY_LINES_INTERSECTION"],
435
+ )
436
+
437
+ self.line_extremities["Goal left crossbar"] = (
438
+ self.point_dict["L_GOAL_TR_POST"],
439
+ self.point_dict["L_GOAL_TL_POST"],
440
+ )
441
+ self.line_extremities["Goal left post left "] = (
442
+ self.point_dict["L_GOAL_TL_POST"],
443
+ self.point_dict["L_GOAL_BL_POST"],
444
+ )
445
+ self.line_extremities["Goal left post right"] = (
446
+ self.point_dict["L_GOAL_TR_POST"],
447
+ self.point_dict["L_GOAL_BR_POST"],
448
+ )
449
+
450
+ self.line_extremities["Goal right crossbar"] = (
451
+ self.point_dict["R_GOAL_TL_POST"],
452
+ self.point_dict["R_GOAL_TR_POST"],
453
+ )
454
+ self.line_extremities["Goal right post left"] = (
455
+ self.point_dict["R_GOAL_TL_POST"],
456
+ self.point_dict["R_GOAL_BL_POST"],
457
+ )
458
+ self.line_extremities["Goal right post right"] = (
459
+ self.point_dict["R_GOAL_TR_POST"],
460
+ self.point_dict["R_GOAL_BR_POST"],
461
+ )
462
+ self.line_extremities["Circle right"] = (
463
+ self.point_dict["TR_16M_LINE_AND_PENALTY_ARC_INTERSECTION"],
464
+ self.point_dict["BR_16M_LINE_AND_PENALTY_ARC_INTERSECTION"],
465
+ )
466
+ self.line_extremities["Circle left"] = (
467
+ self.point_dict["TL_16M_LINE_AND_PENALTY_ARC_INTERSECTION"],
468
+ self.point_dict["BL_16M_LINE_AND_PENALTY_ARC_INTERSECTION"],
469
+ )
470
+
471
+ self.line_extremities_keys = dict()
472
+ self.line_extremities_keys["Big rect. left bottom"] = (
473
+ "L_PENALTY_AREA_BL_CORNER",
474
+ "L_PENALTY_AREA_BR_CORNER",
475
+ )
476
+ self.line_extremities_keys["Big rect. left top"] = (
477
+ "L_PENALTY_AREA_TL_CORNER",
478
+ "L_PENALTY_AREA_TR_CORNER",
479
+ )
480
+ self.line_extremities_keys["Big rect. left main"] = (
481
+ "L_PENALTY_AREA_TR_CORNER",
482
+ "L_PENALTY_AREA_BR_CORNER",
483
+ )
484
+ self.line_extremities_keys["Big rect. right bottom"] = (
485
+ "R_PENALTY_AREA_BL_CORNER",
486
+ "R_PENALTY_AREA_BR_CORNER",
487
+ )
488
+ self.line_extremities_keys["Big rect. right top"] = (
489
+ "R_PENALTY_AREA_TL_CORNER",
490
+ "R_PENALTY_AREA_TR_CORNER",
491
+ )
492
+ self.line_extremities_keys["Big rect. right main"] = (
493
+ "R_PENALTY_AREA_TL_CORNER",
494
+ "R_PENALTY_AREA_BL_CORNER",
495
+ )
496
+
497
+ self.line_extremities_keys["Small rect. left bottom"] = (
498
+ "L_GOAL_AREA_BL_CORNER",
499
+ "L_GOAL_AREA_BR_CORNER",
500
+ )
501
+ self.line_extremities_keys["Small rect. left top"] = (
502
+ "L_GOAL_AREA_TL_CORNER",
503
+ "L_GOAL_AREA_TR_CORNER",
504
+ )
505
+ self.line_extremities_keys["Small rect. left main"] = (
506
+ "L_GOAL_AREA_TR_CORNER",
507
+ "L_GOAL_AREA_BR_CORNER",
508
+ )
509
+ self.line_extremities_keys["Small rect. right bottom"] = (
510
+ "R_GOAL_AREA_BL_CORNER",
511
+ "R_GOAL_AREA_BR_CORNER",
512
+ )
513
+ self.line_extremities_keys["Small rect. right top"] = (
514
+ "R_GOAL_AREA_TL_CORNER",
515
+ "R_GOAL_AREA_TR_CORNER",
516
+ )
517
+ self.line_extremities_keys["Small rect. right main"] = (
518
+ "R_GOAL_AREA_TL_CORNER",
519
+ "R_GOAL_AREA_BL_CORNER",
520
+ )
521
+
522
+ self.line_extremities_keys["Side line top"] = ("TL_PITCH_CORNER", "TR_PITCH_CORNER")
523
+ self.line_extremities_keys["Side line bottom"] = ("BL_PITCH_CORNER", "BR_PITCH_CORNER")
524
+ self.line_extremities_keys["Side line left"] = ("TL_PITCH_CORNER", "BL_PITCH_CORNER")
525
+ self.line_extremities_keys["Side line right"] = ("TR_PITCH_CORNER", "BR_PITCH_CORNER")
526
+ self.line_extremities_keys["Middle line"] = (
527
+ "T_TOUCH_AND_HALFWAY_LINES_INTERSECTION",
528
+ "B_TOUCH_AND_HALFWAY_LINES_INTERSECTION",
529
+ )
530
+
531
+ self.line_extremities_keys["Goal left crossbar"] = ("L_GOAL_TR_POST", "L_GOAL_TL_POST")
532
+ self.line_extremities_keys["Goal left post left "] = ("L_GOAL_TL_POST", "L_GOAL_BL_POST")
533
+ self.line_extremities_keys["Goal left post right"] = ("L_GOAL_TR_POST", "L_GOAL_BR_POST")
534
+
535
+ self.line_extremities_keys["Goal right crossbar"] = ("R_GOAL_TL_POST", "R_GOAL_TR_POST")
536
+ self.line_extremities_keys["Goal right post left"] = ("R_GOAL_TL_POST", "R_GOAL_BL_POST")
537
+ self.line_extremities_keys["Goal right post right"] = ("R_GOAL_TR_POST", "R_GOAL_BR_POST")
538
+ self.line_extremities_keys["Circle right"] = (
539
+ "TR_16M_LINE_AND_PENALTY_ARC_INTERSECTION",
540
+ "BR_16M_LINE_AND_PENALTY_ARC_INTERSECTION",
541
+ )
542
+ self.line_extremities_keys["Circle left"] = (
543
+ "TL_16M_LINE_AND_PENALTY_ARC_INTERSECTION",
544
+ "BL_16M_LINE_AND_PENALTY_ARC_INTERSECTION",
545
+ )
546
+
547
+ def points(self):
548
+ return [
549
+ self.center_mark,
550
+ self.halfway_and_bottom_touch_line_mark,
551
+ self.halfway_and_top_touch_line_mark,
552
+ self.halfway_line_and_center_circle_top_mark,
553
+ self.halfway_line_and_center_circle_bottom_mark,
554
+ self.bottom_right_corner,
555
+ self.bottom_left_corner,
556
+ self.top_right_corner,
557
+ self.top_left_corner,
558
+ self.left_penalty_mark,
559
+ self.right_penalty_mark,
560
+ self.left_penalty_area_top_right_corner,
561
+ self.left_penalty_area_top_left_corner,
562
+ self.left_penalty_area_bottom_right_corner,
563
+ self.left_penalty_area_bottom_left_corner,
564
+ self.right_penalty_area_top_right_corner,
565
+ self.right_penalty_area_top_left_corner,
566
+ self.right_penalty_area_bottom_right_corner,
567
+ self.right_penalty_area_bottom_left_corner,
568
+ self.left_goal_area_top_right_corner,
569
+ self.left_goal_area_top_left_corner,
570
+ self.left_goal_area_bottom_right_corner,
571
+ self.left_goal_area_bottom_left_corner,
572
+ self.right_goal_area_top_right_corner,
573
+ self.right_goal_area_top_left_corner,
574
+ self.right_goal_area_bottom_right_corner,
575
+ self.right_goal_area_bottom_left_corner,
576
+ self.top_left_16M_penalty_arc_mark,
577
+ self.top_right_16M_penalty_arc_mark,
578
+ self.bottom_left_16M_penalty_arc_mark,
579
+ self.bottom_right_16M_penalty_arc_mark,
580
+ self.left_goal_top_left_post,
581
+ self.left_goal_top_right_post,
582
+ self.left_goal_bottom_left_post,
583
+ self.left_goal_bottom_right_post,
584
+ self.right_goal_top_left_post,
585
+ self.right_goal_top_right_post,
586
+ self.right_goal_bottom_left_post,
587
+ self.right_goal_bottom_right_post,
588
+ ]
589
+
590
+ def sample_field_points(self, dist=0.1, dist_circles=0.2):
591
+ """
592
+ Samples each pitch element every dist meters, returns a dictionary associating the class of the element with a list of points sampled along this element.
593
+ :param dist: the distance in meters between each point sampled
594
+ :param dist_circles: the distance in meters between each point sampled on circles
595
+ :return: a dictionary associating the class of the element with a list of points sampled along this element.
596
+ """
597
+ polylines = dict()
598
+ center = self.point_dict["CENTER_MARK"]
599
+ fromAngle = 0.0
600
+ toAngle = 2 * np.pi
601
+
602
+ if toAngle < fromAngle:
603
+ toAngle += 2 * np.pi
604
+ x1 = center[0] + np.cos(fromAngle) * SoccerPitchSN.CENTER_CIRCLE_RADIUS
605
+ y1 = center[1] + np.sin(fromAngle) * SoccerPitchSN.CENTER_CIRCLE_RADIUS
606
+ z1 = 0.0
607
+ point = np.array((x1, y1, z1))
608
+ polyline = [point]
609
+ length = SoccerPitchSN.CENTER_CIRCLE_RADIUS * (toAngle - fromAngle)
610
+ nb_pts = int(length / dist_circles)
611
+ dangle = dist_circles / SoccerPitchSN.CENTER_CIRCLE_RADIUS
612
+ for i in range(1, nb_pts):
613
+ angle = fromAngle + i * dangle
614
+ x = center[0] + np.cos(angle) * SoccerPitchSN.CENTER_CIRCLE_RADIUS
615
+ y = center[1] + np.sin(angle) * SoccerPitchSN.CENTER_CIRCLE_RADIUS
616
+ z = 0
617
+ point = np.array((x, y, z))
618
+ polyline.append(point)
619
+ polylines["Circle central"] = polyline
620
+ for key, line in self.line_extremities.items():
621
+
622
+ if "Circle" in key:
623
+ if key == "Circle right":
624
+ top = self.point_dict["TR_16M_LINE_AND_PENALTY_ARC_INTERSECTION"]
625
+ bottom = self.point_dict["BR_16M_LINE_AND_PENALTY_ARC_INTERSECTION"]
626
+ center = self.point_dict["R_PENALTY_MARK"]
627
+ toAngle = np.arctan2(top[1] - center[1], top[0] - center[0]) + 2 * np.pi
628
+ fromAngle = np.arctan2(bottom[1] - center[1], bottom[0] - center[0]) + 2 * np.pi
629
+ elif key == "Circle left":
630
+ top = self.point_dict["TL_16M_LINE_AND_PENALTY_ARC_INTERSECTION"]
631
+ bottom = self.point_dict["BL_16M_LINE_AND_PENALTY_ARC_INTERSECTION"]
632
+ center = self.point_dict["L_PENALTY_MARK"]
633
+ fromAngle = np.arctan2(top[1] - center[1], top[0] - center[0]) + 2 * np.pi
634
+ toAngle = np.arctan2(bottom[1] - center[1], bottom[0] - center[0]) + 2 * np.pi
635
+ if toAngle < fromAngle:
636
+ toAngle += 2 * np.pi
637
+ x1 = center[0] + np.cos(fromAngle) * SoccerPitchSN.CENTER_CIRCLE_RADIUS
638
+ y1 = center[1] + np.sin(fromAngle) * SoccerPitchSN.CENTER_CIRCLE_RADIUS
639
+ z1 = 0.0
640
+ xn = center[0] + np.cos(toAngle) * SoccerPitchSN.CENTER_CIRCLE_RADIUS
641
+ yn = center[1] + np.sin(toAngle) * SoccerPitchSN.CENTER_CIRCLE_RADIUS
642
+ zn = 0.0
643
+ start = np.array((x1, y1, z1))
644
+ end = np.array((xn, yn, zn))
645
+ polyline = [start]
646
+ length = SoccerPitchSN.CENTER_CIRCLE_RADIUS * (toAngle - fromAngle)
647
+ nb_pts = int(length / dist_circles)
648
+ dangle = dist_circles / SoccerPitchSN.CENTER_CIRCLE_RADIUS
649
+ for i in range(1, nb_pts + 1):
650
+ angle = fromAngle + i * dangle
651
+ x = center[0] + np.cos(angle) * SoccerPitchSN.CENTER_CIRCLE_RADIUS
652
+ y = center[1] + np.sin(angle) * SoccerPitchSN.CENTER_CIRCLE_RADIUS
653
+ z = 0
654
+ point = np.array((x, y, z))
655
+ polyline.append(point)
656
+ polyline.append(end)
657
+ polylines[key] = polyline
658
+ else:
659
+ start = line[0]
660
+ end = line[1]
661
+
662
+ polyline = [start]
663
+
664
+ total_dist = np.sqrt(np.sum(np.square(start - end)))
665
+ nb_pts = int(total_dist / dist - 1)
666
+
667
+ v = end - start
668
+ v /= np.linalg.norm(v)
669
+ prev_pt = start
670
+ for i in range(nb_pts):
671
+ pt = prev_pt + dist * v
672
+ prev_pt = pt
673
+ polyline.append(pt)
674
+ polyline.append(end)
675
+ polylines[key] = polyline
676
+ return polylines
677
+
678
+ def get_2d_homogeneous_line(self, line_name):
679
+ """
680
+ For lines belonging to the pitch lawn plane returns its 2D homogenous equation coefficients
681
+ :param line_name
682
+ :return: an array containing the three coefficients of the line
683
+ """
684
+ # ensure line in football pitch plane
685
+ if (
686
+ line_name in self.line_extremities.keys()
687
+ and "post" not in line_name
688
+ and "crossbar" not in line_name
689
+ and "Circle" not in line_name
690
+ ):
691
+ extremities = self.line_extremities[line_name]
692
+ p1 = np.array([extremities[0][0], extremities[0][1], 1], dtype="float")
693
+ p2 = np.array([extremities[1][0], extremities[1][1], 1], dtype="float")
694
+ line = np.cross(p1, p2)
695
+
696
+ return line
697
+ return None
698
+
699
+
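# A minimal usage sketch for SoccerPitchSN; the sampling distances and the queried
# line name below are illustrative only.
pitch = SoccerPitchSN(pitch_length=105.0, pitch_width=68.0)

# sample 3D points along every pitch element (keys are the line class names)
polylines = pitch.sample_field_points(dist=0.5, dist_circles=0.5)
print(len(polylines["Middle line"]))

# homogeneous 2D line equation [a, b, c] for a line lying in the pitch plane
line_eq = pitch.get_2d_homogeneous_line("Side line top")
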
700
+ class SoccerPitchSNCircleCentralSplit:
701
+ """Static class variables that are specified by the rules of the game"""
702
+
703
+ GOAL_LINE_TO_PENALTY_MARK = 11.0
704
+ PENALTY_AREA_WIDTH = 40.32
705
+ PENALTY_AREA_LENGTH = 16.5
706
+ GOAL_AREA_WIDTH = 18.32
707
+ GOAL_AREA_LENGTH = 5.5
708
+ CENTER_CIRCLE_RADIUS = 9.15
709
+ GOAL_HEIGHT = 2.44
710
+ GOAL_LENGTH = 7.32
711
+
712
+ lines_classes = [
713
+ "Big rect. left bottom",
714
+ "Big rect. left main",
715
+ "Big rect. left top",
716
+ "Big rect. right bottom",
717
+ "Big rect. right main",
718
+ "Big rect. right top",
719
+ "Circle central left",
720
+ "Circle central right",
721
+ "Circle left",
722
+ "Circle right",
723
+ "Goal left crossbar",
724
+ "Goal left post left ",
725
+ "Goal left post right",
726
+ "Goal right crossbar",
727
+ "Goal right post left",
728
+ "Goal right post right",
729
+ "Goal unknown",
730
+ "Line unknown",
731
+ "Middle line",
732
+ "Side line bottom",
733
+ "Side line left",
734
+ "Side line right",
735
+ "Side line top",
736
+ "Small rect. left bottom",
737
+ "Small rect. left main",
738
+ "Small rect. left top",
739
+ "Small rect. right bottom",
740
+ "Small rect. right main",
741
+ "Small rect. right top",
742
+ ]
743
+
744
+ symetric_classes = {
745
+ "Side line top": "Side line bottom",
746
+ "Side line bottom": "Side line top",
747
+ "Side line left": "Side line right",
748
+ "Middle line": "Middle line",
749
+ "Side line right": "Side line left",
750
+ "Big rect. left top": "Big rect. right bottom",
751
+ "Big rect. left bottom": "Big rect. right top",
752
+ "Big rect. left main": "Big rect. right main",
753
+ "Big rect. right top": "Big rect. left bottom",
754
+ "Big rect. right bottom": "Big rect. left top",
755
+ "Big rect. right main": "Big rect. left main",
756
+ "Small rect. left top": "Small rect. right bottom",
757
+ "Small rect. left bottom": "Small rect. right top",
758
+ "Small rect. left main": "Small rect. right main",
759
+ "Small rect. right top": "Small rect. left bottom",
760
+ "Small rect. right bottom": "Small rect. left top",
761
+ "Small rect. right main": "Small rect. left main",
762
+ "Circle left": "Circle right",
763
+ "Circle central left": "Circle central right",
764
+ "Circle central right": "Circle central left",
765
+ "Circle right": "Circle left",
766
+ "Goal left crossbar": "Goal right crossbar",
767
+ "Goal left post left ": "Goal right post right",
768
+ "Goal left post right": "Goal right post left",
769
+ "Goal right crossbar": "Goal left crossbar",
770
+ "Goal right post left": "Goal left post right",
771
+ "Goal right post right": "Goal left post left ",
772
+ "Goal unknown": "Goal unknown",
773
+ "Line unknown": "Line unknown",
774
+ }
775
+
776
+ # RGB values
777
+ palette = {
778
+ "Big rect. left bottom": (127, 0, 0),
779
+ "Big rect. left main": (102, 102, 102),
780
+ "Big rect. left top": (0, 0, 127),
781
+ "Big rect. right bottom": (86, 32, 39),
782
+ "Big rect. right main": (48, 77, 0),
783
+ "Big rect. right top": (14, 97, 100),
784
+ "Circle central left": (0, 0, 255),
785
+ "Circle central right": (0, 255, 0),
786
+ "Circle left": (255, 127, 0),
787
+ "Circle right": (0, 255, 255),
788
+ "Goal left crossbar": (255, 255, 200),
789
+ "Goal left post left ": (165, 255, 0),
790
+ "Goal left post right": (155, 119, 45),
791
+ "Goal right crossbar": (86, 32, 139),
792
+ "Goal right post left": (196, 120, 153),
793
+ "Goal right post right": (166, 36, 52),
794
+ "Goal unknown": (0, 0, 0),
795
+ "Line unknown": (0, 0, 0),
796
+ "Middle line": (255, 255, 0),
797
+ "Side line bottom": (255, 0, 255),
798
+ "Side line left": (0, 255, 150),
799
+ "Side line right": (0, 230, 0),
800
+ "Side line top": (230, 0, 0),
801
+ "Small rect. left bottom": (0, 150, 255),
802
+ "Small rect. left main": (254, 173, 225),
803
+ "Small rect. left top": (87, 72, 39),
804
+ "Small rect. right bottom": (122, 0, 255),
805
+ "Small rect. right main": (255, 255, 255),
806
+ "Small rect. right top": (153, 23, 153),
807
+ }
808
+
809
+ def __init__(self, pitch_length=105.0, pitch_width=68.0):
810
+ """
811
+ Initialize 3D coordinates of all elements of the soccer pitch.
812
+ :param pitch_length: According to FIFA rules, the length belongs to [90, 120] meters
813
+ :param pitch_width: According to FIFA rules, the width belongs to [45, 90] meters
814
+ """
815
+ self.PITCH_LENGTH = pitch_length
816
+ self.PITCH_WIDTH = pitch_width
817
+
818
+ self.center_mark = np.array([0, 0, 0], dtype="float")
819
+ self.halfway_and_bottom_touch_line_mark = np.array([0, pitch_width / 2.0, 0], dtype="float")
820
+ self.halfway_and_top_touch_line_mark = np.array([0, -pitch_width / 2.0, 0], dtype="float")
821
+ self.halfway_line_and_center_circle_top_mark = np.array(
822
+ [0, -SoccerPitchSNCircleCentralSplit.CENTER_CIRCLE_RADIUS, 0], dtype="float"
823
+ )
824
+ self.halfway_line_and_center_circle_bottom_mark = np.array(
825
+ [0, SoccerPitchSNCircleCentralSplit.CENTER_CIRCLE_RADIUS, 0], dtype="float"
826
+ )
827
+ self.bottom_right_corner = np.array(
828
+ [pitch_length / 2.0, pitch_width / 2.0, 0], dtype="float"
829
+ )
830
+ self.bottom_left_corner = np.array(
831
+ [-pitch_length / 2.0, pitch_width / 2.0, 0], dtype="float"
832
+ )
833
+ self.top_right_corner = np.array([pitch_length / 2.0, -pitch_width / 2.0, 0], dtype="float")
834
+ self.top_left_corner = np.array([-pitch_length / 2.0, -pitch_width / 2.0, 0], dtype="float")
835
+
836
+ self.left_goal_bottom_left_post = np.array(
837
+ [-pitch_length / 2.0, SoccerPitchSNCircleCentralSplit.GOAL_LENGTH / 2.0, 0.0],
838
+ dtype="float",
839
+ )
840
+ self.left_goal_top_left_post = np.array(
841
+ [
842
+ -pitch_length / 2.0,
843
+ SoccerPitchSNCircleCentralSplit.GOAL_LENGTH / 2.0,
844
+ -SoccerPitchSNCircleCentralSplit.GOAL_HEIGHT,
845
+ ],
846
+ dtype="float",
847
+ )
848
+ self.left_goal_bottom_right_post = np.array(
849
+ [-pitch_length / 2.0, -SoccerPitchSNCircleCentralSplit.GOAL_LENGTH / 2.0, 0.0],
850
+ dtype="float",
851
+ )
852
+ self.left_goal_top_right_post = np.array(
853
+ [
854
+ -pitch_length / 2.0,
855
+ -SoccerPitchSNCircleCentralSplit.GOAL_LENGTH / 2.0,
856
+ -SoccerPitchSNCircleCentralSplit.GOAL_HEIGHT,
857
+ ],
858
+ dtype="float",
859
+ )
860
+
861
+ self.right_goal_bottom_left_post = np.array(
862
+ [pitch_length / 2.0, -SoccerPitchSNCircleCentralSplit.GOAL_LENGTH / 2.0, 0.0],
863
+ dtype="float",
864
+ )
865
+ self.right_goal_top_left_post = np.array(
866
+ [
867
+ pitch_length / 2.0,
868
+ -SoccerPitchSNCircleCentralSplit.GOAL_LENGTH / 2.0,
869
+ -SoccerPitchSNCircleCentralSplit.GOAL_HEIGHT,
870
+ ],
871
+ dtype="float",
872
+ )
873
+ self.right_goal_bottom_right_post = np.array(
874
+ [pitch_length / 2.0, SoccerPitchSNCircleCentralSplit.GOAL_LENGTH / 2.0, 0.0],
875
+ dtype="float",
876
+ )
877
+ self.right_goal_top_right_post = np.array(
878
+ [
879
+ pitch_length / 2.0,
880
+ SoccerPitchSNCircleCentralSplit.GOAL_LENGTH / 2.0,
881
+ -SoccerPitchSNCircleCentralSplit.GOAL_HEIGHT,
882
+ ],
883
+ dtype="float",
884
+ )
885
+
886
+ self.left_penalty_mark = np.array(
887
+ [-pitch_length / 2.0 + SoccerPitchSNCircleCentralSplit.GOAL_LINE_TO_PENALTY_MARK, 0, 0],
888
+ dtype="float",
889
+ )
890
+ self.right_penalty_mark = np.array(
891
+ [pitch_length / 2.0 - SoccerPitchSNCircleCentralSplit.GOAL_LINE_TO_PENALTY_MARK, 0, 0],
892
+ dtype="float",
893
+ )
894
+
895
+ self.left_penalty_area_top_right_corner = np.array(
896
+ [
897
+ -pitch_length / 2.0 + SoccerPitchSNCircleCentralSplit.PENALTY_AREA_LENGTH,
898
+ -SoccerPitchSNCircleCentralSplit.PENALTY_AREA_WIDTH / 2.0,
899
+ 0,
900
+ ],
901
+ dtype="float",
902
+ )
903
+ self.left_penalty_area_top_left_corner = np.array(
904
+ [-pitch_length / 2.0, -SoccerPitchSNCircleCentralSplit.PENALTY_AREA_WIDTH / 2.0, 0],
905
+ dtype="float",
906
+ )
907
+ self.left_penalty_area_bottom_right_corner = np.array(
908
+ [
909
+ -pitch_length / 2.0 + SoccerPitchSNCircleCentralSplit.PENALTY_AREA_LENGTH,
910
+ SoccerPitchSNCircleCentralSplit.PENALTY_AREA_WIDTH / 2.0,
911
+ 0,
912
+ ],
913
+ dtype="float",
914
+ )
915
+ self.left_penalty_area_bottom_left_corner = np.array(
916
+ [-pitch_length / 2.0, SoccerPitchSNCircleCentralSplit.PENALTY_AREA_WIDTH / 2.0, 0],
917
+ dtype="float",
918
+ )
919
+ self.right_penalty_area_top_right_corner = np.array(
920
+ [pitch_length / 2.0, -SoccerPitchSNCircleCentralSplit.PENALTY_AREA_WIDTH / 2.0, 0],
921
+ dtype="float",
922
+ )
923
+ self.right_penalty_area_top_left_corner = np.array(
924
+ [
925
+ pitch_length / 2.0 - SoccerPitchSNCircleCentralSplit.PENALTY_AREA_LENGTH,
926
+ -SoccerPitchSNCircleCentralSplit.PENALTY_AREA_WIDTH / 2.0,
927
+ 0,
928
+ ],
929
+ dtype="float",
930
+ )
931
+ self.right_penalty_area_bottom_right_corner = np.array(
932
+ [pitch_length / 2.0, SoccerPitchSNCircleCentralSplit.PENALTY_AREA_WIDTH / 2.0, 0],
933
+ dtype="float",
934
+ )
935
+ self.right_penalty_area_bottom_left_corner = np.array(
936
+ [
937
+ pitch_length / 2.0 - SoccerPitchSNCircleCentralSplit.PENALTY_AREA_LENGTH,
938
+ SoccerPitchSNCircleCentralSplit.PENALTY_AREA_WIDTH / 2.0,
939
+ 0,
940
+ ],
941
+ dtype="float",
942
+ )
943
+
944
+ self.left_goal_area_top_right_corner = np.array(
945
+ [
946
+ -pitch_length / 2.0 + SoccerPitchSNCircleCentralSplit.GOAL_AREA_LENGTH,
947
+ -SoccerPitchSNCircleCentralSplit.GOAL_AREA_WIDTH / 2.0,
948
+ 0,
949
+ ],
950
+ dtype="float",
951
+ )
952
+ self.left_goal_area_top_left_corner = np.array(
953
+ [-pitch_length / 2.0, -SoccerPitchSNCircleCentralSplit.GOAL_AREA_WIDTH / 2.0, 0],
954
+ dtype="float",
955
+ )
956
+ self.left_goal_area_bottom_right_corner = np.array(
957
+ [
958
+ -pitch_length / 2.0 + SoccerPitchSNCircleCentralSplit.GOAL_AREA_LENGTH,
959
+ SoccerPitchSNCircleCentralSplit.GOAL_AREA_WIDTH / 2.0,
960
+ 0,
961
+ ],
962
+ dtype="float",
963
+ )
964
+ self.left_goal_area_bottom_left_corner = np.array(
965
+ [-pitch_length / 2.0, SoccerPitchSNCircleCentralSplit.GOAL_AREA_WIDTH / 2.0, 0],
966
+ dtype="float",
967
+ )
968
+ self.right_goal_area_top_right_corner = np.array(
969
+ [pitch_length / 2.0, -SoccerPitchSNCircleCentralSplit.GOAL_AREA_WIDTH / 2.0, 0],
970
+ dtype="float",
971
+ )
972
+ self.right_goal_area_top_left_corner = np.array(
973
+ [
974
+ pitch_length / 2.0 - SoccerPitchSNCircleCentralSplit.GOAL_AREA_LENGTH,
975
+ -SoccerPitchSNCircleCentralSplit.GOAL_AREA_WIDTH / 2.0,
976
+ 0,
977
+ ],
978
+ dtype="float",
979
+ )
980
+ self.right_goal_area_bottom_right_corner = np.array(
981
+ [pitch_length / 2.0, SoccerPitchSNCircleCentralSplit.GOAL_AREA_WIDTH / 2.0, 0],
982
+ dtype="float",
983
+ )
984
+ self.right_goal_area_bottom_left_corner = np.array(
985
+ [
986
+ pitch_length / 2.0 - SoccerPitchSNCircleCentralSplit.GOAL_AREA_LENGTH,
987
+ SoccerPitchSNCircleCentralSplit.GOAL_AREA_WIDTH / 2.0,
988
+ 0,
989
+ ],
990
+ dtype="float",
991
+ )
992
+
993
+ x = -pitch_length / 2.0 + SoccerPitchSNCircleCentralSplit.PENALTY_AREA_LENGTH
994
+ dx = (
995
+ SoccerPitchSNCircleCentralSplit.PENALTY_AREA_LENGTH
996
+ - SoccerPitchSNCircleCentralSplit.GOAL_LINE_TO_PENALTY_MARK
997
+ )
998
+ y = -np.sqrt(
999
+ SoccerPitchSNCircleCentralSplit.CENTER_CIRCLE_RADIUS
1000
+ * SoccerPitchSNCircleCentralSplit.CENTER_CIRCLE_RADIUS
1001
+ - dx * dx
1002
+ )
1003
+ self.top_left_16M_penalty_arc_mark = np.array([x, y, 0], dtype="float")
1004
+
1005
+ x = pitch_length / 2.0 - SoccerPitchSNCircleCentralSplit.PENALTY_AREA_LENGTH
1006
+ dx = (
1007
+ SoccerPitchSNCircleCentralSplit.PENALTY_AREA_LENGTH
1008
+ - SoccerPitchSNCircleCentralSplit.GOAL_LINE_TO_PENALTY_MARK
1009
+ )
1010
+ y = -np.sqrt(
1011
+ SoccerPitchSNCircleCentralSplit.CENTER_CIRCLE_RADIUS
1012
+ * SoccerPitchSNCircleCentralSplit.CENTER_CIRCLE_RADIUS
1013
+ - dx * dx
1014
+ )
1015
+ self.top_right_16M_penalty_arc_mark = np.array([x, y, 0], dtype="float")
1016
+
1017
+ x = -pitch_length / 2.0 + SoccerPitchSNCircleCentralSplit.PENALTY_AREA_LENGTH
1018
+ dx = (
1019
+ SoccerPitchSNCircleCentralSplit.PENALTY_AREA_LENGTH
1020
+ - SoccerPitchSNCircleCentralSplit.GOAL_LINE_TO_PENALTY_MARK
1021
+ )
1022
+ y = np.sqrt(
1023
+ SoccerPitchSNCircleCentralSplit.CENTER_CIRCLE_RADIUS
1024
+ * SoccerPitchSNCircleCentralSplit.CENTER_CIRCLE_RADIUS
1025
+ - dx * dx
1026
+ )
1027
+ self.bottom_left_16M_penalty_arc_mark = np.array([x, y, 0], dtype="float")
1028
+
1029
+ x = pitch_length / 2.0 - SoccerPitchSNCircleCentralSplit.PENALTY_AREA_LENGTH
1030
+ dx = (
1031
+ SoccerPitchSNCircleCentralSplit.PENALTY_AREA_LENGTH
1032
+ - SoccerPitchSNCircleCentralSplit.GOAL_LINE_TO_PENALTY_MARK
1033
+ )
1034
+ y = np.sqrt(
1035
+ SoccerPitchSNCircleCentralSplit.CENTER_CIRCLE_RADIUS
1036
+ * SoccerPitchSNCircleCentralSplit.CENTER_CIRCLE_RADIUS
1037
+ - dx * dx
1038
+ )
1039
+ self.bottom_right_16M_penalty_arc_mark = np.array([x, y, 0], dtype="float")
1040
+
1041
+ # self.set_elevations(elevation)
1042
+
1043
+ self.point_dict = {}
1044
+ self.point_dict["CENTER_MARK"] = self.center_mark
1045
+ self.point_dict["L_PENALTY_MARK"] = self.left_penalty_mark
1046
+ self.point_dict["R_PENALTY_MARK"] = self.right_penalty_mark
1047
+ self.point_dict["TL_PITCH_CORNER"] = self.top_left_corner
1048
+ self.point_dict["BL_PITCH_CORNER"] = self.bottom_left_corner
1049
+ self.point_dict["TR_PITCH_CORNER"] = self.top_right_corner
1050
+ self.point_dict["BR_PITCH_CORNER"] = self.bottom_right_corner
1051
+ self.point_dict["L_PENALTY_AREA_TL_CORNER"] = self.left_penalty_area_top_left_corner
1052
+ self.point_dict["L_PENALTY_AREA_TR_CORNER"] = self.left_penalty_area_top_right_corner
1053
+ self.point_dict["L_PENALTY_AREA_BL_CORNER"] = self.left_penalty_area_bottom_left_corner
1054
+ self.point_dict["L_PENALTY_AREA_BR_CORNER"] = self.left_penalty_area_bottom_right_corner
1055
+
1056
+ self.point_dict["R_PENALTY_AREA_TL_CORNER"] = self.right_penalty_area_top_left_corner
1057
+ self.point_dict["R_PENALTY_AREA_TR_CORNER"] = self.right_penalty_area_top_right_corner
1058
+ self.point_dict["R_PENALTY_AREA_BL_CORNER"] = self.right_penalty_area_bottom_left_corner
1059
+ self.point_dict["R_PENALTY_AREA_BR_CORNER"] = self.right_penalty_area_bottom_right_corner
1060
+
1061
+ self.point_dict["L_GOAL_AREA_TL_CORNER"] = self.left_goal_area_top_left_corner
1062
+ self.point_dict["L_GOAL_AREA_TR_CORNER"] = self.left_goal_area_top_right_corner
1063
+ self.point_dict["L_GOAL_AREA_BL_CORNER"] = self.left_goal_area_bottom_left_corner
1064
+ self.point_dict["L_GOAL_AREA_BR_CORNER"] = self.left_goal_area_bottom_right_corner
1065
+
1066
+ self.point_dict["R_GOAL_AREA_TL_CORNER"] = self.right_goal_area_top_left_corner
1067
+ self.point_dict["R_GOAL_AREA_TR_CORNER"] = self.right_goal_area_top_right_corner
1068
+ self.point_dict["R_GOAL_AREA_BL_CORNER"] = self.right_goal_area_bottom_left_corner
1069
+ self.point_dict["R_GOAL_AREA_BR_CORNER"] = self.right_goal_area_bottom_right_corner
1070
+
1071
+ self.point_dict["L_GOAL_TL_POST"] = self.left_goal_top_left_post
1072
+ self.point_dict["L_GOAL_TR_POST"] = self.left_goal_top_right_post
1073
+ self.point_dict["L_GOAL_BL_POST"] = self.left_goal_bottom_left_post
1074
+ self.point_dict["L_GOAL_BR_POST"] = self.left_goal_bottom_right_post
1075
+
1076
+ self.point_dict["R_GOAL_TL_POST"] = self.right_goal_top_left_post
1077
+ self.point_dict["R_GOAL_TR_POST"] = self.right_goal_top_right_post
1078
+ self.point_dict["R_GOAL_BL_POST"] = self.right_goal_bottom_left_post
1079
+ self.point_dict["R_GOAL_BR_POST"] = self.right_goal_bottom_right_post
1080
+
1081
+ self.point_dict[
1082
+ "T_TOUCH_AND_HALFWAY_LINES_INTERSECTION"
1083
+ ] = self.halfway_and_top_touch_line_mark
1084
+ self.point_dict[
1085
+ "B_TOUCH_AND_HALFWAY_LINES_INTERSECTION"
1086
+ ] = self.halfway_and_bottom_touch_line_mark
1087
+ self.point_dict[
1088
+ "T_HALFWAY_LINE_AND_CENTER_CIRCLE_INTERSECTION"
1089
+ ] = self.halfway_line_and_center_circle_top_mark
1090
+ self.point_dict[
1091
+ "B_HALFWAY_LINE_AND_CENTER_CIRCLE_INTERSECTION"
1092
+ ] = self.halfway_line_and_center_circle_bottom_mark
1093
+ self.point_dict[
1094
+ "TL_16M_LINE_AND_PENALTY_ARC_INTERSECTION"
1095
+ ] = self.top_left_16M_penalty_arc_mark
1096
+ self.point_dict[
1097
+ "BL_16M_LINE_AND_PENALTY_ARC_INTERSECTION"
1098
+ ] = self.bottom_left_16M_penalty_arc_mark
1099
+ self.point_dict[
1100
+ "TR_16M_LINE_AND_PENALTY_ARC_INTERSECTION"
1101
+ ] = self.top_right_16M_penalty_arc_mark
1102
+ self.point_dict[
1103
+ "BR_16M_LINE_AND_PENALTY_ARC_INTERSECTION"
1104
+ ] = self.bottom_right_16M_penalty_arc_mark
1105
+
1106
+ self.line_extremities = dict()
1107
+ self.line_extremities["Big rect. left bottom"] = (
1108
+ self.point_dict["L_PENALTY_AREA_BL_CORNER"],
1109
+ self.point_dict["L_PENALTY_AREA_BR_CORNER"],
1110
+ )
1111
+ self.line_extremities["Big rect. left top"] = (
1112
+ self.point_dict["L_PENALTY_AREA_TL_CORNER"],
1113
+ self.point_dict["L_PENALTY_AREA_TR_CORNER"],
1114
+ )
1115
+ self.line_extremities["Big rect. left main"] = (
1116
+ self.point_dict["L_PENALTY_AREA_TR_CORNER"],
1117
+ self.point_dict["L_PENALTY_AREA_BR_CORNER"],
1118
+ )
1119
+ self.line_extremities["Big rect. right bottom"] = (
1120
+ self.point_dict["R_PENALTY_AREA_BL_CORNER"],
1121
+ self.point_dict["R_PENALTY_AREA_BR_CORNER"],
1122
+ )
1123
+ self.line_extremities["Big rect. right top"] = (
1124
+ self.point_dict["R_PENALTY_AREA_TL_CORNER"],
1125
+ self.point_dict["R_PENALTY_AREA_TR_CORNER"],
1126
+ )
1127
+ self.line_extremities["Big rect. right main"] = (
1128
+ self.point_dict["R_PENALTY_AREA_TL_CORNER"],
1129
+ self.point_dict["R_PENALTY_AREA_BL_CORNER"],
1130
+ )
1131
+
1132
+ self.line_extremities["Small rect. left bottom"] = (
1133
+ self.point_dict["L_GOAL_AREA_BL_CORNER"],
1134
+ self.point_dict["L_GOAL_AREA_BR_CORNER"],
1135
+ )
1136
+ self.line_extremities["Small rect. left top"] = (
1137
+ self.point_dict["L_GOAL_AREA_TL_CORNER"],
1138
+ self.point_dict["L_GOAL_AREA_TR_CORNER"],
1139
+ )
1140
+ self.line_extremities["Small rect. left main"] = (
1141
+ self.point_dict["L_GOAL_AREA_TR_CORNER"],
1142
+ self.point_dict["L_GOAL_AREA_BR_CORNER"],
1143
+ )
1144
+ self.line_extremities["Small rect. right bottom"] = (
1145
+ self.point_dict["R_GOAL_AREA_BL_CORNER"],
1146
+ self.point_dict["R_GOAL_AREA_BR_CORNER"],
1147
+ )
1148
+ self.line_extremities["Small rect. right top"] = (
1149
+ self.point_dict["R_GOAL_AREA_TL_CORNER"],
1150
+ self.point_dict["R_GOAL_AREA_TR_CORNER"],
1151
+ )
1152
+ self.line_extremities["Small rect. right main"] = (
1153
+ self.point_dict["R_GOAL_AREA_TL_CORNER"],
1154
+ self.point_dict["R_GOAL_AREA_BL_CORNER"],
1155
+ )
1156
+
1157
+ self.line_extremities["Side line top"] = (
1158
+ self.point_dict["TL_PITCH_CORNER"],
1159
+ self.point_dict["TR_PITCH_CORNER"],
1160
+ )
1161
+ self.line_extremities["Side line bottom"] = (
1162
+ self.point_dict["BL_PITCH_CORNER"],
1163
+ self.point_dict["BR_PITCH_CORNER"],
1164
+ )
1165
+ self.line_extremities["Side line left"] = (
1166
+ self.point_dict["TL_PITCH_CORNER"],
1167
+ self.point_dict["BL_PITCH_CORNER"],
1168
+ )
1169
+ self.line_extremities["Side line right"] = (
1170
+ self.point_dict["TR_PITCH_CORNER"],
1171
+ self.point_dict["BR_PITCH_CORNER"],
1172
+ )
1173
+ self.line_extremities["Middle line"] = (
1174
+ self.point_dict["T_TOUCH_AND_HALFWAY_LINES_INTERSECTION"],
1175
+ self.point_dict["B_TOUCH_AND_HALFWAY_LINES_INTERSECTION"],
1176
+ )
1177
+
1178
+ self.line_extremities["Goal left crossbar"] = (
1179
+ self.point_dict["L_GOAL_TR_POST"],
1180
+ self.point_dict["L_GOAL_TL_POST"],
1181
+ )
1182
+ self.line_extremities["Goal left post left "] = (
1183
+ self.point_dict["L_GOAL_TL_POST"],
1184
+ self.point_dict["L_GOAL_BL_POST"],
1185
+ )
1186
+ self.line_extremities["Goal left post right"] = (
1187
+ self.point_dict["L_GOAL_TR_POST"],
1188
+ self.point_dict["L_GOAL_BR_POST"],
1189
+ )
1190
+
1191
+ self.line_extremities["Goal right crossbar"] = (
1192
+ self.point_dict["R_GOAL_TL_POST"],
1193
+ self.point_dict["R_GOAL_TR_POST"],
1194
+ )
1195
+ self.line_extremities["Goal right post left"] = (
1196
+ self.point_dict["R_GOAL_TL_POST"],
1197
+ self.point_dict["R_GOAL_BL_POST"],
1198
+ )
1199
+ self.line_extremities["Goal right post right"] = (
1200
+ self.point_dict["R_GOAL_TR_POST"],
1201
+ self.point_dict["R_GOAL_BR_POST"],
1202
+ )
1203
+ self.line_extremities["Circle right"] = (
1204
+ self.point_dict["TR_16M_LINE_AND_PENALTY_ARC_INTERSECTION"],
1205
+ self.point_dict["BR_16M_LINE_AND_PENALTY_ARC_INTERSECTION"],
1206
+ )
1207
+ self.line_extremities["Circle left"] = (
1208
+ self.point_dict["TL_16M_LINE_AND_PENALTY_ARC_INTERSECTION"],
1209
+ self.point_dict["BL_16M_LINE_AND_PENALTY_ARC_INTERSECTION"],
1210
+ )
1211
+
1212
+ self.line_extremities_keys = dict()
1213
+ self.line_extremities_keys["Big rect. left bottom"] = (
1214
+ "L_PENALTY_AREA_BL_CORNER",
1215
+ "L_PENALTY_AREA_BR_CORNER",
1216
+ )
1217
+ self.line_extremities_keys["Big rect. left top"] = (
1218
+ "L_PENALTY_AREA_TL_CORNER",
1219
+ "L_PENALTY_AREA_TR_CORNER",
1220
+ )
1221
+ self.line_extremities_keys["Big rect. left main"] = (
1222
+ "L_PENALTY_AREA_TR_CORNER",
1223
+ "L_PENALTY_AREA_BR_CORNER",
1224
+ )
1225
+ self.line_extremities_keys["Big rect. right bottom"] = (
1226
+ "R_PENALTY_AREA_BL_CORNER",
1227
+ "R_PENALTY_AREA_BR_CORNER",
1228
+ )
1229
+ self.line_extremities_keys["Big rect. right top"] = (
1230
+ "R_PENALTY_AREA_TL_CORNER",
1231
+ "R_PENALTY_AREA_TR_CORNER",
1232
+ )
1233
+ self.line_extremities_keys["Big rect. right main"] = (
1234
+ "R_PENALTY_AREA_TL_CORNER",
1235
+ "R_PENALTY_AREA_BL_CORNER",
1236
+ )
1237
+
1238
+ self.line_extremities_keys["Small rect. left bottom"] = (
1239
+ "L_GOAL_AREA_BL_CORNER",
1240
+ "L_GOAL_AREA_BR_CORNER",
1241
+ )
1242
+ self.line_extremities_keys["Small rect. left top"] = (
1243
+ "L_GOAL_AREA_TL_CORNER",
1244
+ "L_GOAL_AREA_TR_CORNER",
1245
+ )
1246
+ self.line_extremities_keys["Small rect. left main"] = (
1247
+ "L_GOAL_AREA_TR_CORNER",
1248
+ "L_GOAL_AREA_BR_CORNER",
1249
+ )
1250
+ self.line_extremities_keys["Small rect. right bottom"] = (
1251
+ "R_GOAL_AREA_BL_CORNER",
1252
+ "R_GOAL_AREA_BR_CORNER",
1253
+ )
1254
+ self.line_extremities_keys["Small rect. right top"] = (
1255
+ "R_GOAL_AREA_TL_CORNER",
1256
+ "R_GOAL_AREA_TR_CORNER",
1257
+ )
1258
+ self.line_extremities_keys["Small rect. right main"] = (
1259
+ "R_GOAL_AREA_TL_CORNER",
1260
+ "R_GOAL_AREA_BL_CORNER",
1261
+ )
1262
+
1263
+ self.line_extremities_keys["Side line top"] = ("TL_PITCH_CORNER", "TR_PITCH_CORNER")
1264
+ self.line_extremities_keys["Side line bottom"] = ("BL_PITCH_CORNER", "BR_PITCH_CORNER")
1265
+ self.line_extremities_keys["Side line left"] = ("TL_PITCH_CORNER", "BL_PITCH_CORNER")
1266
+ self.line_extremities_keys["Side line right"] = ("TR_PITCH_CORNER", "BR_PITCH_CORNER")
1267
+ self.line_extremities_keys["Middle line"] = (
1268
+ "T_TOUCH_AND_HALFWAY_LINES_INTERSECTION",
1269
+ "B_TOUCH_AND_HALFWAY_LINES_INTERSECTION",
1270
+ )
1271
+
1272
+ self.line_extremities_keys["Goal left crossbar"] = ("L_GOAL_TR_POST", "L_GOAL_TL_POST")
1273
+ self.line_extremities_keys["Goal left post left "] = ("L_GOAL_TL_POST", "L_GOAL_BL_POST")
1274
+ self.line_extremities_keys["Goal left post right"] = ("L_GOAL_TR_POST", "L_GOAL_BR_POST")
1275
+
1276
+ self.line_extremities_keys["Goal right crossbar"] = ("R_GOAL_TL_POST", "R_GOAL_TR_POST")
1277
+ self.line_extremities_keys["Goal right post left"] = ("R_GOAL_TL_POST", "R_GOAL_BL_POST")
1278
+ self.line_extremities_keys["Goal right post right"] = ("R_GOAL_TR_POST", "R_GOAL_BR_POST")
1279
+ self.line_extremities_keys["Circle right"] = (
1280
+ "TR_16M_LINE_AND_PENALTY_ARC_INTERSECTION",
1281
+ "BR_16M_LINE_AND_PENALTY_ARC_INTERSECTION",
1282
+ )
1283
+ self.line_extremities_keys["Circle left"] = (
1284
+ "TL_16M_LINE_AND_PENALTY_ARC_INTERSECTION",
1285
+ "BL_16M_LINE_AND_PENALTY_ARC_INTERSECTION",
1286
+ )
1287
+
1288
+ def points(self):
1289
+ return [
1290
+ self.center_mark,
1291
+ self.halfway_and_bottom_touch_line_mark,
1292
+ self.halfway_and_top_touch_line_mark,
1293
+ self.halfway_line_and_center_circle_top_mark,
1294
+ self.halfway_line_and_center_circle_bottom_mark,
1295
+ self.bottom_right_corner,
1296
+ self.bottom_left_corner,
1297
+ self.top_right_corner,
1298
+ self.top_left_corner,
1299
+ self.left_penalty_mark,
1300
+ self.right_penalty_mark,
1301
+ self.left_penalty_area_top_right_corner,
1302
+ self.left_penalty_area_top_left_corner,
1303
+ self.left_penalty_area_bottom_right_corner,
1304
+ self.left_penalty_area_bottom_left_corner,
1305
+ self.right_penalty_area_top_right_corner,
1306
+ self.right_penalty_area_top_left_corner,
1307
+ self.right_penalty_area_bottom_right_corner,
1308
+ self.right_penalty_area_bottom_left_corner,
1309
+ self.left_goal_area_top_right_corner,
1310
+ self.left_goal_area_top_left_corner,
1311
+ self.left_goal_area_bottom_right_corner,
1312
+ self.left_goal_area_bottom_left_corner,
1313
+ self.right_goal_area_top_right_corner,
1314
+ self.right_goal_area_top_left_corner,
1315
+ self.right_goal_area_bottom_right_corner,
1316
+ self.right_goal_area_bottom_left_corner,
1317
+ self.top_left_16M_penalty_arc_mark,
1318
+ self.top_right_16M_penalty_arc_mark,
1319
+ self.bottom_left_16M_penalty_arc_mark,
1320
+ self.bottom_right_16M_penalty_arc_mark,
1321
+ self.left_goal_top_left_post,
1322
+ self.left_goal_top_right_post,
1323
+ self.left_goal_bottom_left_post,
1324
+ self.left_goal_bottom_right_post,
1325
+ self.right_goal_top_left_post,
1326
+ self.right_goal_top_right_post,
1327
+ self.right_goal_bottom_left_post,
1328
+ self.right_goal_bottom_right_post,
1329
+ ]
1330
+
1331
+ def sample_field_points(self, dist=0.1, dist_circles=0.2):
1332
+ """
1333
+ Samples each pitch element every dist meters, returns a dictionary associating the class of the element with a list of points sampled along this element.
1334
+ :param dist: the distance in meters between each point sampled
1335
+ :param dist_circles: the distance in meters between each point sampled on circles
1336
+ :return: a dictionary associating the class of the element with a list of points sampled along this element.
1337
+ """
1338
+ polylines = dict()
1339
+ center = self.point_dict["CENTER_MARK"]
1340
+ fromAngle = 0.0
1341
+ toAngle = 2 * np.pi
1342
+
1343
+ if toAngle < fromAngle:
1344
+ toAngle += 2 * np.pi
1345
+ x1 = center[0] + np.cos(fromAngle) * SoccerPitchSNCircleCentralSplit.CENTER_CIRCLE_RADIUS
1346
+ y1 = center[1] + np.sin(fromAngle) * SoccerPitchSNCircleCentralSplit.CENTER_CIRCLE_RADIUS
1347
+ z1 = 0.0
1348
+ point = np.array((x1, y1, z1))
1349
+ polyline = [point]
1350
+ length = SoccerPitchSNCircleCentralSplit.CENTER_CIRCLE_RADIUS * (toAngle - fromAngle)
1351
+ nb_pts = int(length / dist_circles)
1352
+ dangle = dist_circles / SoccerPitchSNCircleCentralSplit.CENTER_CIRCLE_RADIUS
1353
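+ # An arc-length step of dist_circles corresponds to an angular step of dist_circles / radius.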
+ for i in range(1, nb_pts):
1354
+ angle = fromAngle + i * dangle
1355
+ x = center[0] + np.cos(angle) * SoccerPitchSNCircleCentralSplit.CENTER_CIRCLE_RADIUS
1356
+ y = center[1] + np.sin(angle) * SoccerPitchSNCircleCentralSplit.CENTER_CIRCLE_RADIUS
1357
+ z = 0
1358
+ point = np.array((x, y, z))
1359
+ polyline.append(point)
1360
+
1361
+ # split central circle in left and right
1362
+ polylines["Circle central left"] = [p for p in polyline if p[0] < 0.0]
1363
+ polylines["Circle central right"] = [p for p in polyline if p[0] >= 0.0]
1364
+ for key, line in self.line_extremities.items():
1365
+
1366
+ if "Circle" in key:
1367
+ if key == "Circle right":
1368
+ top = self.point_dict["TR_16M_LINE_AND_PENALTY_ARC_INTERSECTION"]
1369
+ bottom = self.point_dict["BR_16M_LINE_AND_PENALTY_ARC_INTERSECTION"]
1370
+ center = self.point_dict["R_PENALTY_MARK"]
1371
+ toAngle = np.arctan2(top[1] - center[1], top[0] - center[0]) + 2 * np.pi
1372
+ fromAngle = np.arctan2(bottom[1] - center[1], bottom[0] - center[0]) + 2 * np.pi
1373
+ elif key == "Circle left":
1374
+ top = self.point_dict["TL_16M_LINE_AND_PENALTY_ARC_INTERSECTION"]
1375
+ bottom = self.point_dict["BL_16M_LINE_AND_PENALTY_ARC_INTERSECTION"]
1376
+ center = self.point_dict["L_PENALTY_MARK"]
1377
+ fromAngle = np.arctan2(top[1] - center[1], top[0] - center[0]) + 2 * np.pi
1378
+ toAngle = np.arctan2(bottom[1] - center[1], bottom[0] - center[0]) + 2 * np.pi
1379
+ if toAngle < fromAngle:
1380
+ toAngle += 2 * np.pi
1381
+ x1 = (
1382
+ center[0]
1383
+ + np.cos(fromAngle) * SoccerPitchSNCircleCentralSplit.CENTER_CIRCLE_RADIUS
1384
+ )
1385
+ y1 = (
1386
+ center[1]
1387
+ + np.sin(fromAngle) * SoccerPitchSNCircleCentralSplit.CENTER_CIRCLE_RADIUS
1388
+ )
1389
+ z1 = 0.0
1390
+ xn = (
1391
+ center[0]
1392
+ + np.cos(toAngle) * SoccerPitchSNCircleCentralSplit.CENTER_CIRCLE_RADIUS
1393
+ )
1394
+ yn = (
1395
+ center[1]
1396
+ + np.sin(toAngle) * SoccerPitchSNCircleCentralSplit.CENTER_CIRCLE_RADIUS
1397
+ )
1398
+ zn = 0.0
1399
+ start = np.array((x1, y1, z1))
1400
+ end = np.array((xn, yn, zn))
1401
+ polyline = [start]
1402
+ length = SoccerPitchSNCircleCentralSplit.CENTER_CIRCLE_RADIUS * (
1403
+ toAngle - fromAngle
1404
+ )
1405
+ nb_pts = int(length / dist_circles)
1406
+ dangle = dist_circles / SoccerPitchSNCircleCentralSplit.CENTER_CIRCLE_RADIUS
1407
+ for i in range(1, nb_pts + 1):
1408
+ angle = fromAngle + i * dangle
1409
+ x = (
1410
+ center[0]
1411
+ + np.cos(angle) * SoccerPitchSNCircleCentralSplit.CENTER_CIRCLE_RADIUS
1412
+ )
1413
+ y = (
1414
+ center[1]
1415
+ + np.sin(angle) * SoccerPitchSNCircleCentralSplit.CENTER_CIRCLE_RADIUS
1416
+ )
1417
+ z = 0
1418
+ point = np.array((x, y, z))
1419
+ polyline.append(point)
1420
+ polyline.append(end)
1421
+ polylines[key] = polyline
1422
+ else:
1423
+ start = line[0]
1424
+ end = line[1]
1425
+
1426
+ polyline = [start]
1427
+
1428
+ total_dist = np.sqrt(np.sum(np.square(start - end)))
1429
+ nb_pts = int(total_dist / dist - 1)
1430
+
1431
+ v = end - start
1432
+ v /= np.linalg.norm(v)
1433
+ prev_pt = start
1434
+ for i in range(nb_pts):
1435
+ pt = prev_pt + dist * v
1436
+ prev_pt = pt
1437
+ polyline.append(pt)
1438
+ polyline.append(end)
1439
+ polylines[key] = polyline
1440
+ return polylines
1441
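+ # Usage sketch (assumes a pitch template instance named `pitch`, e.g. SoccerPitchSNCircleCentralSplit):
+ # polylines = pitch.sample_field_points(dist=0.1, dist_circles=0.2)
+ # for segment_name, pts in polylines.items():
+ #     print(segment_name, len(pts))  # each entry of pts is a 3D np.array (x, y, z)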
+
1442
+ def get_2d_homogeneous_line(self, line_name):
1443
+ """
1444
+ For lines lying in the pitch (ground) plane, returns the 2D homogeneous equation coefficients of the line
1445
+ :param line_name: name of the line segment
1446
+ :return: an array containing the three coefficients of the line
1447
+ """
1448
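+ # For homogeneous 2D points p1 and p2, the line through them is l = p1 x p2,
+ # and a point p lies on that line iff l . p == 0.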
+ # ensure line in football pitch plane
1449
+ if (
1450
+ line_name in self.line_extremities.keys()
1451
+ and "post" not in line_name
1452
+ and "crossbar" not in line_name
1453
+ and "Circle" not in line_name
1454
+ ):
1455
+ extremities = self.line_extremities[line_name]
1456
+ p1 = np.array([extremities[0][0], extremities[0][1], 1], dtype="float")
1457
+ p2 = np.array([extremities[1][0], extremities[1][1], 1], dtype="float")
1458
+ line = np.cross(p1, p2)
1459
+
1460
+ return line
1461
+ return None
1462
+
1463
+
1464
+ class Abstract3dModel(metaclass=ABCMeta):
1465
+ def __init__(self) -> None:
1466
+
1467
+ self.points = None # keypoints: tensor of shape (N, 3)
1468
+ self.points_sampled = (
1469
+ None # sampled points for each segment Dict[str: torch.tensor of shape (*, 3)]
1470
+ )
1471
+ self.points_sampled_palette = {}
1472
+ self.segment_names = set(self.points_sampled_palette.keys())
1473
+
1474
+ self.line_segments = [] # tensor of shape (3, S_l, 2) containing 2 points
1475
+ self.line_segments_names = [] # list of respective names for each s in S_l
1476
+ self.line_palette = [] # list of RGB tuples
1477
+
1478
+ self.circle_segments = None # tensor of shape (3, S_c, num_points_per_circle)
1479
+ self.circle_segments_names = [] # list of respective names for each s in S_c
1480
+ self.circle_palette = [] # list of RGB tuples
1481
+
1482
+
1483
+ class Meshgrid(Abstract3dModel):
1484
+ def __init__(self, height=68, width=105):
1485
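+ # Regular grid of (height + 1) x (width + 1) homogeneous 3D points on the pitch plane (z = 0),
+ # centered at the origin and spanning the full pitch extent.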
+ self.points = kornia.utils.create_meshgrid(
1486
+ height=height + 1, width=width + 1, normalized_coordinates=True
1487
+ )
1488
+ self.points = self.points.flatten(start_dim=-3, end_dim=-2)
1489
+ self.points[:, :, 0] = self.points[:, :, 0] * width / 2
1490
+ self.points[:, :, 1] = self.points[:, :, 1] * height / 2
1491
+ self.points = kornia.geometry.conversions.convert_points_to_homogeneous(self.points)
1492
+ self.points[:, :, -1] = 0.0 # set z=0
1493
+ self.points = self.points.squeeze(0)
1494
+ self.points_sampled = {"meshgrid": self.points}
1495
+
1496
+
1497
+ class SoccerPitchLineCircleSegments(Abstract3dModel):
1498
+ def __init__(
1499
+ self,
1500
+ base_field,
1501
+ device="cpu",
1502
+ N_cstar=128,
1503
+ sampling_factor_lines=0.2,
1504
+ sampling_factor_circles=0.8,
1505
+ ) -> None:
1506
+
1507
+ if not (
1508
+ isinstance(base_field, SoccerPitchSNCircleCentralSplit)
1509
+ or isinstance(base_field, SoccerPitchSN)
1510
+ ):
1511
+ raise NotImplementedError
1512
+
1513
+ self.sampling_factor_lines = sampling_factor_lines
1514
+ self.sampling_factor_circles = sampling_factor_circles
1515
+
1516
+ self._field_sncalib = base_field
1517
+
1518
+ self.device = device
1519
+
1520
+ # classical keypoints as single tensor
1521
+ self.points = torch.from_numpy(np.stack(self._field_sncalib.points())).float().to(device)
1522
+
1523
+ # sampled points for each segment Dict[str: torch.tensor of shape (*, 3)]
1524
+ self.points_sampled = self._field_sncalib.sample_field_points(
1525
+ self.sampling_factor_lines, self.sampling_factor_circles
1526
+ )
1527
+ self.points_sampled = {
1528
+ k: torch.from_numpy(np.stack(v)).float().to(device)
1529
+ for k, v in self.points_sampled.items()
1530
+ }
1531
+ self.points_sampled_palette = self._field_sncalib.palette
1532
+ self.segment_names = set(self.points_sampled_palette.keys())
1533
+ self.cmap_01 = {k: [c / 255.0 for c in v] for k, v in self.points_sampled_palette.items()}
1534
+
1535
+ self.line_collection: List[LineCollection] = []
1536
+ self.line_segments = [] # (3, S, 2)
1537
+ self.line_segments_names = []
1538
+
1539
+ for line_name, (p0, p1) in self._field_sncalib.line_extremities.items():
1540
+ if "Circle" not in line_name:
1541
+ p0 = torch.from_numpy(p0).float().to(device)
1542
+ p1 = torch.from_numpy(p1).float().to(device)
1543
+ direction = p1 - p0
1544
+ direction_norm = direction / torch.linalg.norm(direction)
1545
+ self.line_collection.append(
1546
+ LineCollection(
1547
+ support=p0,
1548
+ direction=direction,
1549
+ direction_norm=direction_norm,
1550
+ )
1551
+ )
1552
+ self.line_segments_names.append(line_name)
1553
+ self.line_segments.append(torch.stack([p0, p1], dim=1))
1554
+
1555
+ self.line_segments = torch.stack(self.line_segments, dim=-1).transpose(1, 2).to(device)
1556
+ self.line_palette = [
1557
+ self._field_sncalib.palette[self.line_segments_names[i]]
1558
+ for i in range(len(self.line_segments_names))
1559
+ ]
1560
+
1561
+ if isinstance(base_field, SoccerPitchSNCircleCentralSplit):
1562
+ self.circle_segments_names = [
1563
+ "Circle central left",
1564
+ "Circle central right",
1565
+ "Circle left",
1566
+ "Circle right",
1567
+ ]
1568
+ elif isinstance(base_field, SoccerPitchSN):
1569
+ self.circle_segments_names = [
1570
+ "Circle central",
1571
+ "Circle left",
1572
+ "Circle right",
1573
+ ]
1574
+ else:
1575
+ raise NotImplementedError
1576
+
1577
+ self.circle_segments = self._sample_points_from_circle_segments(
1578
+ m=N_cstar
1579
+ ) # (3, num_circles, num_points_per_circle)
1580
+
1581
+ self.circle_palette = [
1582
+ self._field_sncalib.palette[self.circle_segments_names[i]]
1583
+ for i in range(len(self.circle_segments_names))
1584
+ ]
1585
+
1586
+ def _sample_points_from_circle_segments(self, m: int):
1587
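+ # Densely sample the field (5 cm spacing on circles) and randomly keep m points per circle
+ # segment, returning a tensor of shape (3, num_circle_segments, m).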
+
1588
+ sampled_points = self._field_sncalib.sample_field_points(dist=1.0, dist_circles=0.05)
1589
+ for key in self.circle_segments_names:
1590
+ assert len(sampled_points[key]) >= m
1591
+ return (
1592
+ torch.stack(
1593
+ [
1594
+ torch.from_numpy(np.stack(random.sample(sampled_points[key], k=m), axis=-1))
1595
+ for key in self.circle_segments_names
1596
+ ],
1597
+ dim=1,
1598
+ )
1599
+ .float()
1600
+ .to(self.device)
1601
+ ) # (3, S, m)
1602
+
1603
+
1604
+ if __name__ == "__main__":
1605
+ from matplotlib import pyplot as plt
1606
+ from mpl_toolkits.mplot3d.art3d import Poly3DCollection
1607
+
1608
+ model3d = SoccerPitchLineCircleSegments(base_field=SoccerPitchSNCircleCentralSplit())
1609
+
1610
+ fig = plt.figure(figsize=(20, 20))
1611
+ ax = fig.add_subplot(projection="3d")
1612
+ for s in range(len(model3d.line_collection)):
1613
+ ax.quiver(
1614
+ model3d.line_collection[s].support[0],
1615
+ model3d.line_collection[s].support[1],
1616
+ model3d.line_collection[s].support[2],
1617
+ model3d.line_collection[s].direction[0],
1618
+ model3d.line_collection[s].direction[1],
1619
+ model3d.line_collection[s].direction[2],
1620
+ arrow_length_ratio=0.05,
1621
+ color=[x / 255.0 for x in model3d.line_palette[s]],
1622
+ zorder=2000,
1623
+ # length=68.0,
1624
+ linewidths=3,
1625
+ label=model3d.line_segments_names[s],
1626
+ alpha=0.5,
1627
+ )
1628
+
1629
+ plt.legend()
1630
+ ax.set_xlim([-105 / 2, 105 / 2])
1631
+ ax.set_ylim([-105 / 2, 105 / 2])
1632
+ ax.set_zlim([-105 / 2, 105 / 2])
1633
+ plt.show()
1634
+
1635
+ fig = plt.figure(figsize=(20, 20))
1636
+ ax = fig.add_subplot(projection="3d")
1637
+ for segment_name, sampled_points in model3d.points_sampled.items():
1638
+ ax.scatter(
1639
+ sampled_points[:, 0],
1640
+ sampled_points[:, 1],
1641
+ -sampled_points[:, 2],
1642
+ zorder=3000,
1643
+ color=[x / 255.0 for x in model3d.points_sampled_palette[segment_name]],
1644
+ marker="x",
1645
+ label=segment_name,
1646
+ )
1647
+
1648
+ plt.legend()
1649
+ ax.set_xlim([-105 / 2, 105 / 2])
1650
+ ax.set_ylim([-105 / 2, 105 / 2])
1651
+ ax.set_zlim([-105 / 2, 105 / 2])
1652
+ plt.show()
1653
+
1654
+ fig = plt.figure(figsize=(20, 20))
1655
+ ax = fig.add_subplot(projection="3d")
1656
+ for s in range(model3d.line_segments.shape[1]):
1657
+
1658
+ if "crossbar" in model3d.line_segments_names[s]:
1659
+ print(s, model3d.line_segments[:, s])
1660
+ ax.scatter(
1661
+ model3d.line_segments[0, s],
1662
+ model3d.line_segments[1, s],
1663
+ -model3d.line_segments[2, s],
1664
+ zorder=3000,
1665
+ color=[x / 255.0 for x in model3d.line_palette[s]],
1666
+ marker="x",
1667
+ label=model3d.line_segments_names[s],
1668
+ )
1669
+
1670
+ plt.legend()
1671
+ ax.set_xlim([-105 / 2, 105 / 2])
1672
+ ax.set_ylim([-105 / 2, 105 / 2])
1673
+ ax.set_zlim([-105 / 2, 105 / 2])
1674
+ plt.savefig("soccer_field_line_segments.pdf")
visualizer.py ADDED
@@ -0,0 +1,298 @@
1
+ import cv2
2
+ import numpy as np
3
+
4
+ # Pitch dimensions in yards
5
+ FIELD_LENGTH_YARDS = 114.83
6
+ FIELD_WIDTH_YARDS = 74.37
7
+
8
+ # Expected input image size constants
9
+ EXPECTED_H, EXPECTED_W = 720, 1280
10
+
11
+ # Import keypoint index constants from pose_estimator if needed (or redefine them here)
12
+ from pose_estimator import (LEFT_ANKLE_KP_INDEX, RIGHT_ANKLE_KP_INDEX,
13
+ CONFIDENCE_THRESHOLD_KEYPOINTS, DEFAULT_MARKER_COLOR, SKELETON_EDGES, SKELETON_THICKNESS)
14
+
15
+ # Marker constants
16
+ MARKER_RADIUS = 6
17
+ MARKER_BORDER_THICKNESS = 1
18
+ MARKER_BORDER_COLOR = (0, 0, 0) # Black
19
+
20
+ # Modulation range for the inverse dynamic scale
21
+ DYNAMIC_SCALE_MIN_MODULATION = 0.4 # For the farthest players (top of the minimap)
22
+ DYNAMIC_SCALE_MAX_MODULATION = 1.6 # For the closest players (bottom of the minimap)
23
+
24
+ def calculate_dynamic_scale(y_position, frame_height, min_scale=1.0, max_scale=2):
25
+ """Calcule le facteur d'échelle en fonction de la position verticale (non utilisé dans cette version simplifiée)."""
26
+ normalized_position = y_position / frame_height
27
+ return min_scale + (max_scale - min_scale) * normalized_position
28
+
29
+ def _prepare_minimap_base(minimap_size=(EXPECTED_W, EXPECTED_H)):
30
+ """Prépare le fond de la minimap (vert texturé verticalement dans la zone terrain) et calcule les métriques du terrain."""
31
+ minimap_h, minimap_w = minimap_size[1], minimap_size[0]
32
+
33
+ # Define the colors and the width of the vertical stripes
34
+ base_green = (0, 60, 0) # Dark green (background)
35
+ stripe_green = (0, 70, 0) # Slightly lighter green (stripes)
36
+ stripe_width = 5 # Width of each vertical stripe (pixels)
37
+
38
+ # Initialize the WHOLE minimap with the base color
39
+ minimap_bgr = np.full((minimap_h, minimap_w, 3), base_green, dtype=np.uint8)
40
+
41
+ # --- Compute the pitch metrics and bounds FIRST ---
42
+ scale_x = minimap_w / FIELD_LENGTH_YARDS
43
+ scale_y = minimap_h / FIELD_WIDTH_YARDS
44
+ scale = min(scale_x, scale_y) * 0.9 # Margin
45
+
46
+ field_width_px = int(FIELD_WIDTH_YARDS * scale)
47
+ field_length_px = int(FIELD_LENGTH_YARDS * scale)
48
+ offset_x = (minimap_w - field_length_px) // 2
49
+ offset_y = (minimap_h - field_width_px) // 2
50
+
51
+ # --- Draw the alternating VERTICAL stripes ONLY inside the pitch area ---
52
+ for x in range(offset_x, offset_x + field_length_px, stripe_width * 2):
53
+ # Coordinates of the vertical rectangle for the light stripe
54
+ start_x = x
55
+ end_x = min(x + stripe_width, offset_x + field_length_px) # Do not go past the right edge
56
+ start_y = offset_y
57
+ end_y = offset_y + field_width_px
58
+
59
+ cv2.rectangle(minimap_bgr, (start_x, start_y), (end_x, end_y), stripe_green, thickness=-1)
60
+ # The next dark stripe is already there (base color)
61
+
62
+ # --- Prepare the S matrix and the metrics to return ---
63
+ S = np.array([
64
+ [scale, 0, offset_x],
65
+ [0, scale, offset_y],
66
+ [0, 0, 1]
67
+ ], dtype=np.float32)
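+ # S maps pitch-template coordinates (assumed here to be in the same units as
+ # FIELD_LENGTH_YARDS / FIELD_WIDTH_YARDS, origin at the pitch's top-left corner)
+ # to minimap pixels: (x, y) -> (scale * x + offset_x, scale * y + offset_y).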
68
+
69
+ metrics = {
70
+ "scale": scale,
71
+ "offset_x": offset_x,
72
+ "offset_y": offset_y,
73
+ "field_width_px": field_width_px,
74
+ "field_length_px": field_length_px,
75
+ "S": S
76
+ }
77
+ return minimap_bgr, metrics
78
+
79
+ def _draw_field_lines(minimap_bgr, metrics):
80
+ """Dessine les lignes du terrain et les buts sur la minimap."""
81
+ scale = metrics["scale"]
82
+ offset_x = metrics["offset_x"]
83
+ offset_y = metrics["offset_y"]
84
+ field_width_px = metrics["field_width_px"]
85
+ field_length_px = metrics["field_length_px"]
86
+
87
+ line_color = (255, 255, 255) # White
88
+ line_thickness = 1
89
+ goal_thickness = 1 # Thickness for the goal posts
90
+ goal_width_yards = 8 # Standard goal width
91
+
92
+ center_x = offset_x + field_length_px // 2
93
+ center_y = offset_y + field_width_px // 2
94
+ penalty_area_width_px = int(SoccerPitchSN.PENALTY_AREA_WIDTH * scale)
95
+ penalty_area_length_px = int(SoccerPitchSN.PENALTY_AREA_LENGTH * scale)
96
+ goal_area_width_px = int(SoccerPitchSN.GOAL_AREA_WIDTH * scale)
97
+ goal_area_length_px = int(SoccerPitchSN.GOAL_AREA_LENGTH * scale)
98
+ center_circle_radius_px = int(SoccerPitchSN.CENTER_CIRCLE_RADIUS * scale)
99
+ goal_width_px = int(goal_width_yards * scale)
100
+
101
+ # Draw the main lines
102
+ cv2.rectangle(minimap_bgr, (offset_x, offset_y), (offset_x + field_length_px, offset_y + field_width_px), line_color, line_thickness)
103
+ cv2.line(minimap_bgr, (center_x, offset_y), (center_x, offset_y + field_width_px), line_color, line_thickness)
104
+ cv2.circle(minimap_bgr, (center_x, center_y), center_circle_radius_px, line_color, line_thickness)
105
+ cv2.circle(minimap_bgr, (center_x, center_y), 3, line_color, -1) # Center mark
106
+ cv2.rectangle(minimap_bgr, (offset_x, center_y - penalty_area_width_px//2), (offset_x + penalty_area_length_px, center_y + penalty_area_width_px//2), line_color, line_thickness)
107
+ cv2.rectangle(minimap_bgr, (offset_x + field_length_px - penalty_area_length_px, center_y - penalty_area_width_px//2), (offset_x + field_length_px, center_y + penalty_area_width_px//2), line_color, line_thickness)
108
+ cv2.rectangle(minimap_bgr, (offset_x, center_y - goal_area_width_px//2), (offset_x + goal_area_length_px, center_y + goal_area_width_px//2), line_color, line_thickness)
109
+ cv2.rectangle(minimap_bgr, (offset_x + field_length_px - goal_area_length_px, center_y - goal_area_width_px//2), (offset_x + field_length_px, center_y + goal_area_width_px//2), line_color, line_thickness)
110
+
111
+ # Draw the goals (thick rectangles on the goal lines)
112
+ goal_y_top = center_y - goal_width_px // 2
113
+ goal_y_bottom = center_y + goal_width_px // 2
114
+ # Left goal
115
+ cv2.rectangle(minimap_bgr, (offset_x-6 - goal_thickness // 2, goal_y_top), (offset_x + goal_thickness // 2, goal_y_bottom), line_color, thickness=goal_thickness)
116
+ # Right goal
117
+ cv2.rectangle(minimap_bgr, (offset_x + field_length_px - goal_thickness // 2, goal_y_top), (offset_x +6 + field_length_px + goal_thickness // 2, goal_y_bottom), line_color, thickness=goal_thickness)
118
+
119
+ def create_minimap_view(image_rgb, homography, minimap_size=(EXPECTED_W, EXPECTED_H)):
120
+ """Crée une vue minimap avec l'image RGB originale projetée et les lignes du terrain.
121
+
122
+ Args:
123
+ image_rgb: Source image in RGB format (720p expected).
124
+ homography: Homography matrix (numpy array) used to warp the image.
125
+ minimap_size: Size of the output minimap (width, height).
126
+
127
+ Returns:
128
+ The minimap image (BGR numpy array), or None if the homography is invalid.
129
+ """
130
+ if homography is None:
131
+ print("Avertissement : Homographie invalide, impossible de créer la minimap (vue originale).")
132
+ return None
133
+
134
+ h, w = image_rgb.shape[:2]
135
+ if h != EXPECTED_H or w != EXPECTED_W:
136
+ print(f"Avertissement : L'image RGB d'entrée n'est pas en {EXPECTED_W}x{EXPECTED_H}, redimensionnement...")
137
+ image_rgb = cv2.resize(image_rgb, (EXPECTED_W, EXPECTED_H), interpolation=cv2.INTER_LINEAR)
138
+
139
+ minimap_bgr, metrics = _prepare_minimap_base(minimap_size)
140
+ S = metrics["S"]
141
+
142
+ try:
143
+ overlay = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR)
144
+ overlay = cv2.convertScaleAbs(overlay, alpha=1.2, beta=10)
145
+ overlay = cv2.addWeighted(overlay, 0.5, np.zeros_like(overlay), 0.5, 0)
146
+
147
+ H_minimap = S @ homography
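+ # Composition: image pixels -> pitch coordinates (homography) -> minimap pixels (S)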
148
+ warped = cv2.warpPerspective(overlay, H_minimap, minimap_size, flags=cv2.INTER_LINEAR)
149
+
150
+ mask = cv2.cvtColor(warped, cv2.COLOR_BGR2GRAY)
151
+ _, mask = cv2.threshold(mask, 1, 255, cv2.THRESH_BINARY)
152
+
153
+ minimap_bgr = np.where(mask[..., None] > 0, warped, minimap_bgr)
154
+
155
+ contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
156
+ cv2.drawContours(minimap_bgr, contours, -1, (255, 255, 255), 2)
157
+
158
+ except Exception as e:
159
+ print(f"Erreur lors de la projection sur la mini-carte (vue originale) : {str(e)}")
160
+
161
+ _draw_field_lines(minimap_bgr, metrics)
162
+ return minimap_bgr
163
+
164
+ def create_minimap_with_offset_skeletons(player_data_list, homography,
165
+ base_skeleton_scale: float,
166
+ minimap_size=(EXPECTED_W, EXPECTED_H)) -> tuple[np.ndarray | None, float | None]:
167
+ """Crée une vue minimap en dessinant le squelette original (réduit/agrandi dynamiquement et inversé)
168
+ at the player's projected position, sorted by Y position.
169
+
170
+ Args:
171
+ player_data_list: List of dictionaries returned by get_player_data.
172
+ homography: Homography matrix (numpy array).
173
+ base_skeleton_scale: Base scale factor for drawing the skeletons.
174
+ minimap_size: Size of the output minimap (width, height).
175
+
176
+ Returns:
177
+ Tuple: (the minimap image (BGR numpy array) or None, the average applied scale or None)
178
+ """
179
+ if homography is None:
180
+ print("Avertissement : Homographie invalide, impossible de créer la minimap (squelettes décalés).")
181
+ return None, None # Return None for both the image and the scale
182
+
183
+ minimap_bgr, metrics = _prepare_minimap_base(minimap_size)
184
+
185
+ # --- Draw the pitch lines FIRST ---
186
+ _draw_field_lines(minimap_bgr, metrics)
187
+
188
+ S = metrics["S"]
189
+ H_minimap = S @ homography
190
+
191
+ players_to_draw = [] # Valid players together with their Y position
192
+
193
+ # --- Étape 1 & 2: Calculer la position projetée pour tous les joueurs valides ---
194
+ for p_data in player_data_list:
195
+ kps_img = p_data['keypoints']
196
+ scores = p_data['scores']
197
+ bbox = p_data['bbox']
198
+ color = p_data['avg_color']
199
+
200
+ # -- Compute the reference point on the image --
201
+ l_ankle_pt = kps_img[LEFT_ANKLE_KP_INDEX]
202
+ r_ankle_pt = kps_img[RIGHT_ANKLE_KP_INDEX]
203
+ l_ankle_score = scores[LEFT_ANKLE_KP_INDEX]
204
+ r_ankle_score = scores[RIGHT_ANKLE_KP_INDEX]
205
+ ref_point_img = None
206
+ if l_ankle_score > CONFIDENCE_THRESHOLD_KEYPOINTS and r_ankle_score > CONFIDENCE_THRESHOLD_KEYPOINTS:
207
+ ref_point_img = (l_ankle_pt + r_ankle_pt) / 2
208
+ elif l_ankle_score > CONFIDENCE_THRESHOLD_KEYPOINTS:
209
+ ref_point_img = l_ankle_pt
210
+ elif r_ankle_score > CONFIDENCE_THRESHOLD_KEYPOINTS:
211
+ ref_point_img = r_ankle_pt
212
+ else:
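+ # Fallback: use the bottom-center of the bounding box when neither ankle is confident enough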
213
+ x1, _, x2, y2 = bbox
214
+ ref_point_img = np.array([(x1 + x2) / 2, y2], dtype=np.float32)
215
+ if ref_point_img is None: continue
216
+
217
+ # -- Project this reference point onto the minimap --
218
+ try:
219
+ point_to_transform = np.array([[ref_point_img]], dtype=np.float32)
220
+ projected_point = cv2.perspectiveTransform(point_to_transform, H_minimap)
221
+ mx, my = map(int, projected_point[0, 0])
222
+ h_map, w_map = minimap_bgr.shape[:2]
223
+ if not (0 <= mx < w_map and 0 <= my < h_map):
224
+ continue # Skip players falling outside the minimap bounds
225
+ except Exception as e:
226
+ # print(f"Erreur lors de la projection du point de référence {ref_point_img}: {e}") # Optionnel: décommenter pour debug
227
+ continue
228
+
229
+ # Store the data needed for sorting and drawing
230
+ players_to_draw.append({
231
+ 'data': p_data,
232
+ 'mx': mx,
233
+ 'my': my,
234
+ 'ref_point': ref_point_img
235
+ })
236
+
237
+ # --- Step 3: sort players by Y position (ascending) ---
238
+ # Players with a smaller Y (higher up on the minimap) are drawn first
239
+ players_to_draw.sort(key=lambda p: p['my'])
240
+
241
+ # Accumulators for the average applied scale
242
+ total_applied_scale = 0.0
243
+ drawn_players_count = 0
244
+
245
+ # --- Step 4: draw the players in sorted order (NOW ON TOP OF THE LINES) ---
246
+ for player_info in players_to_draw:
247
+ p_data = player_info['data']
248
+ mx = player_info['mx']
249
+ my = player_info['my']
250
+ ref_point_img = player_info['ref_point']
251
+ kps_img = p_data['keypoints']
252
+ scores = p_data['scores']
253
+ # color = p_data['avg_color'] # Ignore the computed color
254
+ drawing_color = (0, 0, 0) # Use black for all players
255
+
256
+ # -- Compute the INVERTED dynamic scale for THIS player --
257
+ minimap_height = minimap_bgr.shape[0]
258
+ if minimap_height == 0: continue
259
+ ref_y_normalized = my / minimap_height
260
+ dynamic_modulation = DYNAMIC_SCALE_MIN_MODULATION + \
261
+ (DYNAMIC_SCALE_MAX_MODULATION - DYNAMIC_SCALE_MIN_MODULATION) * (1.0 - ref_y_normalized)
262
+ dynamic_modulation = np.clip(dynamic_modulation, DYNAMIC_SCALE_MIN_MODULATION * 0.8, DYNAMIC_SCALE_MAX_MODULATION * 1.2)
263
+ final_draw_scale = base_skeleton_scale * dynamic_modulation
264
+
265
+ # Add to the running sum for the average
266
+ total_applied_scale += final_draw_scale
267
+ drawn_players_count += 1
268
+
269
+ # -- Draw the skeleton --
270
+ kps_relative_to_ref = kps_img - ref_point_img
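+ # Each keypoint is redrawn relative to the projected reference point (mx, my), scaled by final_draw_scale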
271
+ for kp_idx1, kp_idx2 in SKELETON_EDGES:
272
+ if scores[kp_idx1] > CONFIDENCE_THRESHOLD_KEYPOINTS and scores[kp_idx2] > CONFIDENCE_THRESHOLD_KEYPOINTS:
273
+ pt1_map_offset = (mx, my) + kps_relative_to_ref[kp_idx1] * final_draw_scale
274
+ pt2_map_offset = (mx, my) + kps_relative_to_ref[kp_idx2] * final_draw_scale
275
+ pt1_draw = tuple(map(int, pt1_map_offset))
276
+ pt2_draw = tuple(map(int, pt2_map_offset))
277
+ # Check that both points are within bounds before drawing (safety)
278
+ h_map, w_map = minimap_bgr.shape[:2]
279
+ if (0 <= pt1_draw[0] < w_map and 0 <= pt1_draw[1] < h_map and
280
+ 0 <= pt2_draw[0] < w_map and 0 <= pt2_draw[1] < h_map):
281
+ cv2.line(minimap_bgr, pt1_draw, pt2_draw, drawing_color, SKELETON_THICKNESS, cv2.LINE_AA) # Use drawing_color (black)
282
+
283
+ # Compute the final average scale
284
+ average_draw_scale = base_skeleton_scale # Default value if no player was drawn
285
+ if drawn_players_count > 0:
286
+ average_draw_scale = total_applied_scale / drawn_players_count
287
+
288
+ return minimap_bgr, average_draw_scale # Also return the average applied scale
289
+
290
+ # Simplified SoccerPitchSN definition, kept only for the dimension constants
291
+ # (avoids importing the full, complex class)
292
+ class SoccerPitchSN:
293
+ GOAL_LINE_TO_PENALTY_MARK = 11.0
294
+ PENALTY_AREA_WIDTH = 42
295
+ PENALTY_AREA_LENGTH = 19
296
+ GOAL_AREA_WIDTH = 18.32
297
+ GOAL_AREA_LENGTH = 5.5
298
+ CENTER_CIRCLE_RADIUS = 10
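+
+ # Usage sketch (hypothetical inputs: `homography` from the TvCalib step, `players` from
+ # pose_estimator.get_player_data, `image_rgb` a 1280x720 RGB frame):
+ # minimap = create_minimap_view(image_rgb, homography)
+ # skel_map, avg_scale = create_minimap_with_offset_skeletons(players, homography, base_skeleton_scale=0.3)
+ # if skel_map is not None:
+ #     cv2.imwrite("minimap_skeletons.png", skel_map)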