Commit bdb955e
Parent(s): init
Files changed:
- .gitattributes +2 -0
- .gitignore +12 -0
- README.md +107 -0
- app.py +234 -0
- common/data/augmentation.py +55 -0
- common/data/calib.py +116 -0
- common/data/transforms.py +23 -0
- common/data/utils.py +72 -0
- common/infer/base.py +43 -0
- common/infer/module.py +24 -0
- common/infer/sink.py +42 -0
- common/loggers/homography_previewer.py +103 -0
- common/loggers/image_preview.py +93 -0
- main.py +199 -0
- pose_estimator.py +265 -0
- requirements.txt +22 -0
- tvcalib/cam_distr/tv_main_behind.py +77 -0
- tvcalib/cam_distr/tv_main_center.py +78 -0
- tvcalib/cam_distr/tv_main_left.py +77 -0
- tvcalib/cam_distr/tv_main_right.py +77 -0
- tvcalib/cam_distr/tv_main_tribune.py +77 -0
- tvcalib/cam_modules.py +583 -0
- tvcalib/data/dataset.py +142 -0
- tvcalib/data/utils.py +166 -0
- tvcalib/infer/module.py +518 -0
- tvcalib/models/segmentation.py +22 -0
- tvcalib/sn_segmentation/resources/mean.npy +0 -0
- tvcalib/sn_segmentation/resources/std.npy +0 -0
- tvcalib/sn_segmentation/src/baseline_extremities.py +311 -0
- tvcalib/sn_segmentation/src/custom_extremities.py +322 -0
- tvcalib/sn_segmentation/src/dataloader.py +122 -0
- tvcalib/sn_segmentation/src/evaluate_extremities.py +270 -0
- tvcalib/sn_segmentation/src/masks_gt2chen.py +217 -0
- tvcalib/sn_segmentation/src/masks_pred2chen.py +150 -0
- tvcalib/sn_segmentation/src/segmentation/README.md +23 -0
- tvcalib/sn_segmentation/src/segmentation/coco_utils.py +108 -0
- tvcalib/sn_segmentation/src/segmentation/presets.py +39 -0
- tvcalib/sn_segmentation/src/segmentation/soccerdata.py +164 -0
- tvcalib/sn_segmentation/src/segmentation/train.py +341 -0
- tvcalib/sn_segmentation/src/segmentation/transforms.py +100 -0
- tvcalib/sn_segmentation/src/segmentation/utils.py +304 -0
- tvcalib/utils/data_distr.py +44 -0
- tvcalib/utils/io.py +44 -0
- tvcalib/utils/linalg.py +106 -0
- tvcalib/utils/objects_3d.py +1674 -0
- visualizer.py +298 -0
.gitattributes
ADDED
@@ -0,0 +1,2 @@
*.jpg filter=lfs diff=lfs merge=lfs -text
*.png filter=lfs diff=lfs merge=lfs -text

.gitignore
ADDED
@@ -0,0 +1,12 @@
venv
*.pyc
*.jpg
*.png
*.jpeg
*.gif
*.bmp
*.tiff
*.ico

*.pt

README.md
ADDED
@@ -0,0 +1,107 @@
# Foot Calib Pos Image Processor

This project uses the TvCalib library to compute the homography matrix for a football pitch image. This matrix maps image points onto a standard 2D representation of the pitch (minimap).

The project also includes a pose estimation step (ViTPose) to detect players and compute the average color of their torso.

**The main result is a minimap where each player is represented by their colored skeleton, drawn at a *dynamically* reduced scale around their projected position on the pitch.**
The position is determined by projecting a reference point (the feet, i.e. the bottom of the bounding box) with the homography. The skeleton is then drawn using its coordinates relative to the original image, scaled, and translated. The scale depends on the **player's vertical position (Y) on the minimap** (higher on the minimap = smaller) and on a base factor adjustable via the `--target_avg_scale` option.
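
A minimal sketch of this step, assuming illustrative values for the modulation constants (the real ones live in `visualizer.py`) and a hypothetical minimap height:

```python
import cv2
import numpy as np

DYNAMIC_SCALE_MIN_MODULATION = 0.5  # assumed value; defined in visualizer.py
DYNAMIC_SCALE_MAX_MODULATION = 1.5  # assumed value; defined in visualizer.py
MINIMAP_HEIGHT = 680                # hypothetical minimap height in pixels

def project_and_scale(feet_xy, homography, base_scale):
    """Project a player's feet point onto the minimap and derive their skeleton scale."""
    pt = np.array([[feet_xy]], dtype=np.float32)              # shape (1, 1, 2)
    minimap_x, minimap_y = cv2.perspectiveTransform(pt, homography)[0, 0]
    norm_y = float(np.clip(minimap_y / MINIMAP_HEIGHT, 0.0, 1.0))
    # Modulation formula used in main.py: MIN + (MAX - MIN) * (1.0 - norm_y)
    modulation = DYNAMIC_SCALE_MIN_MODULATION + \
        (DYNAMIC_SCALE_MAX_MODULATION - DYNAMIC_SCALE_MIN_MODULATION) * (1.0 - norm_y)
    return (minimap_x, minimap_y), base_scale * modulation

pos, scale = project_and_scale((640.0, 700.0), np.eye(3), base_scale=0.35)
```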

It is also possible to visualize a minimap with the projection of the original image for comparison.

## Features

* Homography calculation from a single image via TvCalib.
* Person detection (RT-DETR) and pose estimation (ViTPose).
* Calculation of the filtered average torso color for each player.
* Projection of each player's reference point (feet/bbox) onto the minimap.
* Generation of a minimap with the **original skeletons (colored, dynamically scaled based on projected Y position, and offset) drawn around the projected point**.
* (Optional) Generation of a minimap with the projected original image.
* Option to save the calculated homography matrix.

## Project Structure

```
.
├── .git/                  # Git metadata
├── .venv/                 # Python virtual environment (recommended)
├── common/                # Shared Python modules
├── data/                  # Data (input images, etc.)
├── models/
│   └── segmentation/
│       └── train_59.pt    # Pre-trained segmentation model (TO DOWNLOAD)
├── tvcalib/               # Source code of the TvCalib library (or a fork/adaptation)
│   └── infer/
│       └── module.py      # Main module for TvCalib inference
├── .gitignore             # Files ignored by Git
├── main.py                # Main script entry point
├── requirements.txt       # Python dependencies
├── visualizer.py          # Module for generating visualization minimaps
├── pose_estimator.py      # Module for pose estimation and player data extraction
└── README.md              # This file
```

## Installation

1. **Clone the repository:**
   ```powershell
   git clone <repository-url>
   cd Foot_calib_pos_image_processor
   ```

2. **Create a virtual environment (recommended):**
   ```powershell
   python -m venv venv
   .\venv\Scripts\Activate.ps1
   ```

3. **Install dependencies:**
   ```powershell
   pip install -r requirements.txt
   ```
   *(Make sure to install PyTorch with appropriate CUDA support if needed.)*

4. **Download the segmentation model:**
   Place `train_59.pt` in `models/segmentation/`.

5. **(Automatic) Download detection/pose models:**
   The RT-DETR and ViTPose models are downloaded automatically on first run.

## Usage

Run the `main.py` script, providing the path to the image:

```powershell
python main.py path/to/your/image.jpg [OPTIONS]
```

**Options:**

* `image_path`: Path to the input image (required).
* `--output_homography PATH.npy`: Saves the calculated homography matrix.
* `--optim_steps NUMBER`: Number of optimization steps for calibration (default: 500; an earlier README example used 1000).
* `--target_avg_scale FLOAT`: **Target average** scale factor for drawing skeletons (default: 0.35). The script adjusts the internal base scale so that the resulting average scale (after inverse dynamic modulation) approaches this value; a worked example follows.

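A worked example of that estimation, mirroring the logic in `main.py` (the modulation constant values here are assumptions; the real ones live in `visualizer.py`):

```python
DYNAMIC_SCALE_MIN_MODULATION = 0.5  # assumed value
DYNAMIC_SCALE_MAX_MODULATION = 1.5  # assumed value

target_avg_scale = 0.35
# Expected modulation for an "average" player at normalized Y = 0.5
avg_modulation = DYNAMIC_SCALE_MIN_MODULATION + \
    (DYNAMIC_SCALE_MAX_MODULATION - DYNAMIC_SCALE_MIN_MODULATION) * (1.0 - 0.5)
base_scale = target_avg_scale / avg_modulation if avg_modulation else target_avg_scale
print(f"internal base scale {base_scale:.3f} targets an average scale of {target_avg_scale}")
```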

**Example:**

```powershell
# Simple usage (target average size 0.35)
python main.py data/img3.png

# Aim for larger skeletons on average (target 0.5)
python main.py data/img2.png --target_avg_scale 0.5
```

The script will display:
* The time taken and the homography matrix.
* The estimated internal base scale.
* The requested TARGET average scale.
* The FINAL average scale actually applied.
* A window: **Minimap with Original Projection**.
* A window: **Minimap with Offset Skeletons** (dynamically and inversely scaled, targeting the average scale).
* Press any key to close the windows.

## Key Dependencies

* PyTorch, OpenCV, NumPy, PyTorch Lightning
* SoccerNet, Kornia, Hugging Face Transformers, Pillow

app.py
ADDED
@@ -0,0 +1,234 @@
import gradio as gr
import cv2
import numpy as np
import torch
from pathlib import Path
import time
import traceback

# Import the required pieces from the project's other modules
try:
    from tvcalib.infer.module import TvCalibInferModule
    # Try to import the preprocessing function from main.py.
    # If main.py is not designed to be imported, this function may need to be copied here.
    from main import preprocess_image_tvcalib, IMAGE_SHAPE, SEGMENTATION_MODEL_PATH
    from visualizer import (
        create_minimap_view,
        create_minimap_with_offset_skeletons,
        DYNAMIC_SCALE_MIN_MODULATION,
        DYNAMIC_SCALE_MAX_MODULATION
    )
    from pose_estimator import get_player_data
except ImportError as e:
    print(f"Import error: {e}")
    print("Make sure the tvcalib, main, visualizer, and pose_estimator modules are importable.")
    # Stubs could be used here; raising lets Gradio fail fast instead
    raise e

# --- Global configuration (model, etc.) ---
# Loading the model once globally could improve performance, but beware of
# state management in the multi-user/threaded environment of Spaces.
# For now, the model is loaded inside the processing function.
# MODEL = None  # Optional: load here
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

print(f"Using device: {DEVICE}")

if not SEGMENTATION_MODEL_PATH.exists():
    print(f"WARNING: Segmentation model not found: {SEGMENTATION_MODEL_PATH}")
    print("The application will likely not work. Make sure the file is present.")
    # Gradio can still start, but processing will fail.

# --- Main processing function ---
def process_image_and_generate_minimaps(input_image_bgr, optim_steps, target_avg_scale):
    """
    Takes a BGR image (NumPy), the optimization steps, and the target scale;
    returns the two minimaps (NumPy BGR).
    """
    global DEVICE  # Use the globally defined device

    print("\n--- New request ---")
    print(f"Parameters: optim_steps={optim_steps}, target_avg_scale={target_avg_scale}")

    # Check that the segmentation model exists (important, since it cannot easily be surfaced in the UI)
    if not SEGMENTATION_MODEL_PATH.exists():
        # Return black placeholder images carrying the error message
        error_msg = f"Error: model {SEGMENTATION_MODEL_PATH} not found."
        print(error_msg)
        placeholder = np.zeros((300, 500, 3), dtype=np.uint8)  # black placeholder
        cv2.putText(placeholder, error_msg, (10, 150), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1, cv2.LINE_AA)
        return placeholder, placeholder.copy()  # return two placeholders

    try:
        # 1. Initialize the TvCalib model (can be slow if done on every request).
        #    Could be optimized by loading globally (see comment above).
        print("Initializing TvCalibInferModule...")
        start_init = time.time()
        model = TvCalibInferModule(
            segmentation_checkpoint=SEGMENTATION_MODEL_PATH,
            image_shape=IMAGE_SHAPE,       # use the imported constant
            optim_steps=int(optim_steps),  # ensure this is an integer
            lens_dist=False
        )
        # Move the model to the right device here explicitly if necessary.
        # model.to(DEVICE)  # TvCalibInferModule should handle this internally? To verify.
        print(f"✓ Model loaded on {next(model.model_calib.parameters()).device} in {time.time() - start_init:.3f}s")
        model_device = next(model.model_calib.parameters()).device  # check the actual device

        # 2. Preprocess the image
        print("Preprocessing the image...")
        start_preprocess = time.time()
        # preprocess_image_tvcalib expects BGR; note that Gradio's gr.Image with
        # type="numpy" actually provides RGB, so if colors look swapped, convert
        # with cv2.cvtColor before preprocessing.
        image_tensor, image_bgr_resized, image_rgb_resized = preprocess_image_tvcalib(input_image_bgr)
        # Check/force the tensor's device
        image_tensor = image_tensor.to(model_device)
        print(f"TvCalib preprocessing time: {time.time() - start_preprocess:.3f}s")

        # 3. Run calibration (segmentation + optimization)
        print("Running segmentation...")
        start_segment = time.time()
        with torch.no_grad():
            keypoints = model._segment(image_tensor)
        print(f"Segmentation time: {time.time() - start_segment:.3f}s")

        print("Running calibration (optimization)...")
        start_calibrate = time.time()
        homography = model._calibrate(keypoints)
        print(f"Calibration time: {time.time() - start_calibrate:.3f}s")

        if homography is None:
            print("No homography could be computed.")
            # Return placeholders with a message
            placeholder = np.zeros((300, 500, 3), dtype=np.uint8)
            cv2.putText(placeholder, "Homography not computed", (10, 150), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1, cv2.LINE_AA)
            return placeholder, placeholder.copy()

        if isinstance(homography, torch.Tensor):
            homography_np = homography.detach().cpu().numpy()
        else:
            homography_np = np.array(homography)  # ensure it is a NumPy array
        print("✓ Homography computed.")

        # 4. Extract player data
        print("Extracting player data (pose + color)...")
        start_pose = time.time()
        # get_player_data expects a BGR image
        player_list = get_player_data(image_bgr_resized)
        print(f"Player data extraction time: {time.time() - start_pose:.3f}s ({len(player_list)} players found)")

        # 5. Compute the base scale
        print("Computing the base scale...")
        # Mirrors the logic in main.py to estimate the base scale
        avg_modulation_expected = DYNAMIC_SCALE_MIN_MODULATION + \
            (DYNAMIC_SCALE_MAX_MODULATION - DYNAMIC_SCALE_MIN_MODULATION) * (1.0 - 0.5)
        estimated_base_scale = target_avg_scale
        if avg_modulation_expected != 0:
            estimated_base_scale = target_avg_scale / avg_modulation_expected
        print(f"  Estimated internal base scale for target {target_avg_scale:.3f}: {estimated_base_scale:.3f}")

        # 6. Generate the minimaps
        print("Generating the minimaps...")
        start_viz = time.time()
        # Minimap with the projection (the function expects an RGB image)
        minimap_original = create_minimap_view(image_rgb_resized, homography_np)

        # Minimap with skeletons (uses the estimated scale)
        minimap_offset_skeletons, actual_avg_scale = create_minimap_with_offset_skeletons(
            player_list,
            homography_np,
            base_skeleton_scale=estimated_base_scale
        )
        print(f"Minimap generation time: {time.time() - start_viz:.3f}s")
        if actual_avg_scale is not None:
            print(f"Requested TARGET average scale: {target_avg_scale:.3f}")
            print(f"FINAL average scale actually applied: {actual_avg_scale:.3f}")

        # Check that the minimaps were created (they can be None on internal errors)
        if minimap_original is None:
            print("Error: the original minimap could not be generated.")
            minimap_original = np.zeros((300, 500, 3), dtype=np.uint8)
            cv2.putText(minimap_original, "Original Minimap Error", (10, 150), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1, cv2.LINE_AA)

        if minimap_offset_skeletons is None:
            print("Error: the skeleton minimap could not be generated.")
            minimap_offset_skeletons = np.zeros((300, 500, 3), dtype=np.uint8)
            cv2.putText(minimap_offset_skeletons, "Skeleton Minimap Error", (10, 150), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1, cv2.LINE_AA)

        # Gradio expects RGB images for display; our functions probably return BGR (via OpenCV).
        # Convert BGR -> RGB if necessary.
        if minimap_original.shape[2] == 3:  # ensure it is a color image
            minimap_original = cv2.cvtColor(minimap_original, cv2.COLOR_BGR2RGB)
        if minimap_offset_skeletons.shape[2] == 3:
            minimap_offset_skeletons = cv2.cvtColor(minimap_offset_skeletons, cv2.COLOR_BGR2RGB)

        print("✓ Processing finished.")
        return minimap_original, minimap_offset_skeletons

    except Exception as e:
        print(f"Fatal error during processing: {e}")
        traceback.print_exc()
        # Return placeholders with a generic error message
        placeholder = np.zeros((300, 500, 3), dtype=np.uint8)
        cv2.putText(placeholder, f"Error: {e}", (10, 150), cv2.FONT_HERSHEY_SIMPLEX, 0.4, (0, 0, 255), 1, cv2.LINE_AA)
        return placeholder, placeholder.copy()

# --- Gradio interface ---
with gr.Blocks() as demo:
    gr.Markdown("# Foot Calib Pos Image Processor - Minimap Generator")
    gr.Markdown(
        "Upload a football pitch image to compute homography (TvCalib), "
        "detect players (RT-DETR/ViTPose), and generate two minimap visualizations."
    )

    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(type="numpy", label="Input Image (.jpg, .png)")
            optim_steps_slider = gr.Slider(
                minimum=100, maximum=2000, step=50, value=500,
                label="TvCalib Optimization Steps",
                info="Number of iterations to refine homography."
            )
            target_scale_slider = gr.Slider(
                minimum=0.1, maximum=2.5, step=0.05, value=0.35,
                label="Target Average Skeleton Scale",
                info="Adjusts the desired average size of skeletons on the minimap."
            )
            submit_button = gr.Button("Generate Minimaps", variant="primary")

        with gr.Column(scale=2):
            output_minimap_orig = gr.Image(type="numpy", label="Minimap with Original Projection", interactive=False)
            output_minimap_skel = gr.Image(type="numpy", label="Minimap with Offset Skeletons", interactive=False)

    # Wire the button to the processing function
    submit_button.click(
        fn=process_image_and_generate_minimaps,
        inputs=[input_image, optim_steps_slider, target_scale_slider],
        outputs=[output_minimap_orig, output_minimap_skel]
    )

    # Add examples (optional but useful for Spaces)
    gr.Examples(
        examples=[
            ["data/img1.png", 500, 1.35],
            ["data/img2.png", 1000, 1.5],
            ["data/img3.png", 500, 0.8],
            ["data/7.jpg", 500, 1],   # .jpg examples
            ["data/15.jpg", 800, 1.35],
        ],
        inputs=[input_image, optim_steps_slider, target_scale_slider],
        outputs=[output_minimap_orig, output_minimap_skel],  # outputs are not pre-computed here; this just populates the inputs
        fn=process_image_and_generate_minimaps,  # called when an example is clicked
        cache_examples=False  # important when processing is long or depends on external models
    )


# --- Application launch ---
if __name__ == "__main__":
    # share=True creates a temporary public link (useful for testing outside localhost)
    # debug=True shows more Gradio logs in the console
    demo.launch(debug=True)
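
A possible launch variant for serving beyond localhost, sketched under the assumption of a recent Gradio API; it is an alternative to the `demo.launch(debug=True)` call above, not part of the original file:

```python
# Assumed Gradio >= 3 API: queue() serializes long-running requests such as
# the optimization loop; server_name="0.0.0.0" exposes the app on the network.
demo.queue().launch(server_name="0.0.0.0", debug=True)
```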
common/data/augmentation.py
ADDED
@@ -0,0 +1,55 @@
import numpy as np
from typing import Callable, Any
import methods.common.data.utils as utils
import random


class WarpAugmentation(Callable):
    def __init__(
        self,
        warp_function: Any,
        mode: str = "train",
        noise_translate: float = 0.0,
        noise_rotate: float = 0.0
    ):
        # Template
        self.template_grid = utils.gen_template_grid()
        self.warp_fn = warp_function
        self.mode = mode
        self.noise_translate = noise_translate
        self.noise_rotate = noise_rotate

    def __call__(
        self,
        image: np.ndarray,
        homography: np.ndarray,
        frame_idx: int
    ):
        warp_image, warp_grid, warp_homography = self.warp_fn(
            mode=self.mode,
            frame=image,
            f_idx=frame_idx,
            gt_homo=homography,
            template=self.template_grid,
            noise_trans=self.noise_translate,
            noise_rotate=self.noise_rotate,
            index=-1  # not really used
        )
        return warp_image, warp_grid, warp_homography


class LeftRightFlipAugmentation(Callable):
    def __init__(self, enabled: bool = False):
        self.enabled = enabled

    def __call__(self, image, grid):
        if self.enabled and random.random() < 0.5:
            image, grid = utils.put_lrflip_augmentation(image, grid)
            return image, grid, True

        return image, grid, False
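
A hypothetical usage sketch for `LeftRightFlipAugmentation`, assuming the `methods` package root implied by the imports above:

```python
# Flip a dummy frame and the 13x7 template grid with probability 0.5.
import numpy as np
import methods.common.data.utils as utils                       # assumed package root
from methods.common.data.augmentation import LeftRightFlipAugmentation

aug = LeftRightFlipAugmentation(enabled=True)
frame = np.zeros((720, 1280, 3), dtype=np.uint8)
grid = utils.gen_template_grid()                                # (91, 3): (x, y, label)
frame, grid, was_flipped = aug(frame, grid)
```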
common/data/calib.py
ADDED
@@ -0,0 +1,116 @@
import cv2
import numpy as np
import torch
import methods.common.data.utils as utils


class Calibrator:
    def __init__(
        self,
        num_classes: int = 92,
        nms_threshold: float = 0.995
    ):
        self.template_grid = utils.gen_template_grid()
        self.num_classes = num_classes
        self.nms_threshold = nms_threshold

    def find_homography(
        self,
        heatmap_logits: torch.Tensor
    ):
        """ Extract keypoints from the heatmap and find the homography matrix.

        heatmap_logits: torch.Tensor for an individual frame (not a mini-batch!)
        """

        pred_rgb, pred_keypoints, scores_heatmap = self.decode_heatmap(heatmap_logits)
        homography = None

        # We need at least 4 point correspondences
        if pred_rgb.shape[0] >= 4:
            src_pts, dst_pts = self.get_class_mapping(pred_rgb)

            # Find homography from point correspondences
            homography, _ = cv2.findHomography(
                src_pts.reshape(-1, 1, 2),
                dst_pts.reshape(-1, 1, 2),
                cv2.RANSAC,
                10
            )

        return homography, pred_keypoints, scores_heatmap

    def decode_heatmap(self, heatmap_logits: torch.Tensor):
        """ Decode the heatmap into a keypoint set using non-maximum suppression.

        heatmap_logits: torch.Tensor with shape (NUM_CLASSES, H, W)
        """

        pred_heatmap = torch.softmax(heatmap_logits, dim=0)
        arg = torch.argmax(pred_heatmap, dim=0).detach().cpu().numpy()
        scores, pred_heatmap = torch.max(pred_heatmap, dim=0)

        # Convert to NumPy & get keypoint locations
        scores = scores.detach().cpu().numpy()
        pred_heatmap = pred_heatmap.detach().cpu().numpy()
        pred_class_dict = self.get_class_dict(scores, pred_heatmap)

        # Colorize
        num_classes = heatmap_logits.shape[0]
        np_scores = np.clip(arg * 255.0 / num_classes, 0, 255).astype(np.uint8)
        scores_heatmap = cv2.applyColorMap(np_scores, cv2.COLORMAP_HOT)
        scores_heatmap = cv2.cvtColor(scores_heatmap, cv2.COLOR_BGR2RGB)

        # Produce an image with the keypoints
        pred_keypoints = np.zeros_like(pred_heatmap, dtype=np.uint8)
        pred_rgb = []
        for _, (pk, pv) in enumerate(pred_class_dict.items()):
            if pv:
                pred_keypoints[pv[1][0], pv[1][1]] = pk  # (H, W)
                # Camera-view point sets (x, y, label) in the RGB domain, not the heatmap domain
                pred_rgb.append([pv[1][1] * 4, pv[1][0] * 4, pk])
        pred_rgb = np.asarray(pred_rgb, dtype=np.float32)  # (?, 3)

        # Return the list of point locations and the keypoints image
        return pred_rgb, pred_keypoints, scores_heatmap

    def get_class_mapping(self, rgb):
        src_pts = rgb.copy()
        cls_map_pts = []

        for ind, elem in enumerate(src_pts):
            coords = np.where(elem[2] == self.template_grid[:, 2])[0]  # find correspondence
            cls_map_pts.append(self.template_grid[coords[0]])
        dst_pts = np.array(cls_map_pts, dtype=np.float32)

        return src_pts[:, :2], dst_pts[:, :2]

    def get_class_dict(self, scores, pred):
        # Decode
        pred_cls_dict = {k: [] for k in range(1, self.num_classes)}
        for cls in range(1, self.num_classes):
            pred_inds = (pred == cls)

            # The current class does not appear in this heatmap
            if not np.any(pred_inds):
                continue

            values = scores[pred_inds]
            max_score = values.max()
            max_index = values.argmax()

            indices = np.where(pred_inds)
            coords = list(zip(indices[0], indices[1]))

            # Keep the class's best keypoint only if its confidence exceeds the threshold
            if max_score >= self.nms_threshold:
                pred_cls_dict[cls].append(max_score)
                pred_cls_dict[cls].append(coords[max_index])

        return pred_cls_dict
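
A hypothetical smoke test for `Calibrator.find_homography`; the random logits stand in for a real segmentation output of shape `(num_classes, H, W)`, so with the high NMS threshold the call will usually return `None`:

```python
import torch
from methods.common.data.calib import Calibrator  # assumed import path

calib = Calibrator(num_classes=92, nms_threshold=0.995)
heatmap_logits = torch.randn(92, 180, 320)        # dummy per-frame logits
homography, keypoints_img, heatmap_vis = calib.find_homography(heatmap_logits)
if homography is None:
    print("fewer than 4 confident keypoints; no homography found")
```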
common/data/transforms.py
ADDED
@@ -0,0 +1,23 @@
from typing import Callable
import torch
import numpy as np


class UnNormalize(object):
    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        """
        Args:
            tensor (Tensor): Tensor image of size (C, H, W) to be un-normalized.
        Returns:
            Tensor: Un-normalized image.
        """
        for t, m, s in zip(tensor, self.mean, self.std):
            t.mul_(s).add_(m)
        # The normalize code would be: t.sub_(m).div_(s)
        return tensor
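
A quick round-trip sanity check for `UnNormalize` against torchvision's `Normalize`, using the ImageNet statistics that appear elsewhere in this repository:

```python
import torch
from torchvision import transforms
from methods.common.data.transforms import UnNormalize  # assumed import path

mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
img = torch.rand(3, 8, 8)
normalized = transforms.Normalize(mean, std)(img.clone())
restored = UnNormalize(mean, std)(normalized)
assert torch.allclose(restored, img, atol=1e-5)  # in-place un-normalization round-trips
```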
common/data/utils.py
ADDED
@@ -0,0 +1,72 @@
import cv2
import numpy as np
import random
import torch


def yards(x):
    """Convert meters to yards."""
    return x * 1.0936132983


def to_torch(ndarray):
    if type(ndarray).__module__ == 'numpy':
        return torch.from_numpy(ndarray.copy())
    elif not torch.is_tensor(ndarray):
        raise ValueError("Cannot convert {} to torch tensor".format(type(ndarray)))
    return ndarray


def gen_template_grid():
    # === Set a uniform grid ===
    # field_dim_x, field_dim_y = 105.000552, 68.003928  # in meters
    field_dim_x, field_dim_y = 114.83, 74.37  # in yards
    # field_dim_x, field_dim_y = 115, 74  # in yards
    nx, ny = (13, 7)
    x = np.linspace(0, field_dim_x, nx)
    y = np.linspace(0, field_dim_y, ny)
    xv, yv = np.meshgrid(x, y, indexing='ij')
    uniform_grid = np.stack((xv, yv), axis=2).reshape(-1, 2)
    uniform_grid = np.concatenate((uniform_grid, np.ones(
        (uniform_grid.shape[0], 1))), axis=1)  # top-to-bottom, left-to-right
    # Class label in the template: each keypoint is (x, y, c), where the label c starts from 1
    for idx, pts in enumerate(uniform_grid):
        pts[2] = idx + 1  # keypoint label
    return uniform_grid


def put_lrflip_augmentation(frame, unigrid):
    frame_h, frame_w = frame.shape[0], frame.shape[1]
    flipped_img = np.fliplr(frame)

    # Flip the grid and re-assign the pixel class labels, 1-91
    for ind, pts in enumerate(unigrid):
        pts[0] = frame_w - pts[0]
        col = (pts[2] - 1) // 7  # column of the uniform grid
        pts[2] = pts[2] - (col - 6) * 2 * 7  # keypoint label

    return flipped_img, unigrid


def make_grid(images, nrow=4):
    num_images = len(images)
    ih, iw = images[0].shape[0], images[0].shape[1]
    rows = min(num_images, nrow)
    cols = (num_images + nrow - 1) // nrow

    result = np.zeros(shape=(cols * ih, rows * iw, 3), dtype=np.uint8)
    for i in range(num_images):
        cell_x = i % nrow
        cell_y = i // nrow
        result[
            (cell_y + 0) * ih:(cell_y + 1) * ih,
            (cell_x + 0) * iw:(cell_x + 1) * iw
        ] = images[i]

    return result
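
A quick check of the template grid produced by `gen_template_grid`: 13 × 7 = 91 keypoints, each row being `(x_yards, y_yards, label)` with labels running 1..91:

```python
from methods.common.data.utils import gen_template_grid  # assumed import path

grid = gen_template_grid()
print(grid.shape)          # (91, 3)
print(grid[0], grid[-1])   # [0. 0. 1.] ... [114.83 74.37 91.]
```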
common/infer/base.py
ADDED
@@ -0,0 +1,43 @@
from typing import Any, Dict
from torch.utils.data import Dataset

import numpy as np


class InferDataModule:
    def __init__(self):
        pass

    def get_inference_dataset(self) -> Dataset:
        """ Return the dataset to run inference on """
        pass


class InferModule:
    def __init__(self):
        pass

    def setup(
        self,
        datamodule: InferDataModule
    ):
        """ Initialize the inference process with the given datamodule """
        pass

    def predict(
        self,
        x: Any
    ) -> Dict:
        """ Predict the calibration information for the given dataset sample """
        return None
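
A minimal sketch of implementing this interface with a trivial in-memory dataset; `ListDataset`, `IdentityDataModule`, and `IdentityModule` are hypothetical names:

```python
from typing import Any, Dict
from torch.utils.data import Dataset
from methods.common.infer.base import InferDataModule, InferModule  # assumed path

class ListDataset(Dataset):                      # hypothetical helper
    def __init__(self, items):
        self.items = items
    def __len__(self):
        return len(self.items)
    def __getitem__(self, idx):
        return self.items[idx]

class IdentityDataModule(InferDataModule):
    def get_inference_dataset(self) -> Dataset:
        return ListDataset([{"homography": None}])

class IdentityModule(InferModule):
    def predict(self, x: Any) -> Dict:
        return {"homography": x["homography"]}
```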
common/infer/module.py
ADDED
@@ -0,0 +1,24 @@
from typing import Any, Dict
from methods.common.infer.base import *


class LabelInferModule(InferModule):
    def __init__(self):
        pass

    def setup(self, datamodule: InferDataModule):
        pass

    def predict(self, x: Any) -> Dict:
        """
        x - sample from the dataset (including the label)
        """

        # Extract the homography matrix
        result = {
            "homography": x["homography"]
        }

        return result
common/infer/sink.py
ADDED
@@ -0,0 +1,42 @@
from pathlib import Path
from collections import defaultdict
import pandas as pd
from typing import Dict
import cv2


class PredictionsSink:
    def __init__(
        self,
        target_filepath: Path
    ):
        self.target_filepath = target_filepath

        def new_item():
            return []

        self.data = defaultdict(new_item)
        self.count = 0

    def write(self, item: Dict):
        self.data["item"].append(self.count)
        for k, v in item.items():
            if "image" in k:
                self.write_image(self.count, k, v)
            else:
                self.data[k].append(v)

        self.count += 1

    def flush(self):
        df = pd.DataFrame(self.data)
        self.target_filepath.parent.mkdir(parents=True, exist_ok=True)
        df.to_csv(self.target_filepath)

    def write_image(self, idx, name, image):
        folder = self.target_filepath.parent / "images" / self.target_filepath.stem / name
        filepath = folder / f"{idx:06d}.png"
        filepath.parent.mkdir(parents=True, exist_ok=True)
        image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
        cv2.imwrite(filepath.as_posix(), image)
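
A hypothetical usage of `PredictionsSink`: scalar fields accumulate into the CSV, while any key containing "image" is written out as a PNG:

```python
import numpy as np
from pathlib import Path
from methods.common.infer.sink import PredictionsSink  # assumed import path

sink = PredictionsSink(Path("out/predictions.csv"))
sink.write({
    "homography_ok": True,
    "preview_image": np.zeros((64, 64, 3), dtype=np.uint8),  # RGB frame
})
sink.flush()
# -> out/predictions.csv and out/images/predictions/preview_image/000000.png
```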
common/loggers/homography_previewer.py
ADDED
@@ -0,0 +1,103 @@
import numpy as np
import skimage.segmentation as ss
import cv2

from torchvision import transforms

from project.data.transforms import TensorToNumpy

from methods.common.loggers.image_preview import ImagePreviewLogger
from methods.common.data.transforms import UnNormalize
from methods.common.data.calib import Calibrator


class HomographyPreviewerLogger(ImagePreviewLogger):
    def __init__(
        self,
        experiment,
        num_rows: int = 3,
    ):
        super().__init__(experiment, num_rows)

        self.calib = Calibrator()
        self.to_image = transforms.Compose([
            UnNormalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            ),
            TensorToNumpy()
        ])

    def draw_keypoints(self, image, image_keypoints, color=(255, 0, 0)):
        """ Upscale the keypoints map to the image resolution and
        overlay it on the image.
        """

        # Get the keypoints image at the target image resolution
        a = ss.expand_labels(image_keypoints, distance=1)
        a = cv2.resize(
            a,
            dsize=(image.shape[1], image.shape[0]),
            interpolation=cv2.INTER_NEAREST
        )
        a = np.expand_dims(a, axis=2)

        # Alpha of the keypoints image
        a = (a > 0) * 1.0

        # Color of the keypoints image
        c = np.concatenate([a * color[0], a * color[1], a * color[2]], axis=2)

        # Superimpose the keypoints
        result = (1.0 - a) * image + a * c
        result = np.clip(result, 0, 255).astype(np.uint8)
        return result

    def draw_playfield(
        self,
        image,
        image_playfield,
        homography,
        color=(255, 0, 0),
        alpha=1.0,
        flip=False
    ):
        """ Draw the playfield image, warped by the homography matrix,
        over the target image.
        """
        if homography is None:
            return image

        # Warp the playfield image
        warp_field = cv2.warpPerspective(
            image_playfield,
            homography,
            (image.shape[1], image.shape[0]),
            cv2.INTER_LINEAR,
            borderMode=cv2.BORDER_CONSTANT,
            borderValue=(0)
        )

        if flip:
            warp_field = np.fliplr(warp_field)

        # Get the alpha channel
        a = np.expand_dims(warp_field / 255.0, axis=2)

        # Color of the playfield
        c = np.concatenate([a * color[0], a * color[1], a * color[2]], axis=2)

        # Draw with the specified alpha
        a = a * alpha

        # Superimpose the playing field image
        result = (1.0 - a) * image + a * c
        result = np.clip(result, 0, 255).astype(np.uint8)
        return result
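
A standalone sketch of the compositing `draw_playfield` performs: warp a grayscale playfield mask with a homography, then alpha-blend a solid color over the frame (all sizes and the identity homography are illustrative):

```python
import cv2
import numpy as np

image = np.zeros((720, 1280, 3), dtype=np.uint8)         # camera frame
playfield = np.zeros((680, 1050), dtype=np.uint8)
cv2.rectangle(playfield, (40, 40), (1010, 640), 255, 5)  # pitch outline mask
H = np.eye(3)                                            # identity homography for the demo
warped = cv2.warpPerspective(playfield, H, (image.shape[1], image.shape[0]))
a = np.expand_dims(warped / 255.0, axis=2)               # per-pixel alpha in [0, 1]
c = np.concatenate([a * 255, a * 0, a * 0], axis=2)      # blue overlay (BGR), as in draw_playfield
a = a * 0.8                                              # global alpha
result = np.clip((1.0 - a) * image + a * c, 0, 255).astype(np.uint8)
```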
common/loggers/image_preview.py
ADDED
@@ -0,0 +1,93 @@
from project.base.logging import Logger
import torch
import torch.nn as nn
import numpy as np

import methods.common.data.utils as utils


class ImagePreviewLogger(Logger):
    def __init__(
        self,
        experiment,
        num_rows: int = 3
    ):
        self.experiment = experiment
        if experiment is not None:
            self.tracker = experiment.tracker
        self.num_rows = num_rows

        # Will be sampled later
        self.samples = None
        self.images = None

    def on_training_start(self):
        datamodule = self.experiment.datamodule
        module = self.experiment.module

        # Get the images
        self.samples, images = self.sample_images(datamodule, module)
        if len(images) > 0:
            self.images = torch.concatenate(
                images,
                dim=0
            ).to(module.device)

    def on_epoch_end(self, epoch, stats):
        self.make_preview()

    def make_preview(self):
        # Get the model
        model = self.experiment.module.model
        if isinstance(model, nn.DataParallel):
            model = model.module

        # Get the preview results
        model.eval()
        items = self.process_images(model, self.images)
        model.train()

        # Arrange into a grid and send to the tracker
        log_images = {}
        # Unpack the list of dicts
        for item in items:
            for key, image in item.items():
                if key not in log_images:
                    log_images[key] = []
                log_images[key].append(image)

        # Arrange images into grids
        result = {}
        for key, images in log_images.items():
            result[key] = utils.make_grid(images, nrow=self.num_rows)

        # Send to the tracker
        self.tracker.write_images(result)

    def sample_dataset(self, dataset, num_images):
        idx = np.random.choice(
            len(dataset),
            size=(num_images,),
            replace=False
        )
        samples = [dataset[i] for i in idx]
        return samples

    def sample_images(self, datamodule, module):
        """ Sample our datasets """
        return [], []

    def process_images(self, model, images):
        """ Return a list of dict[key, image] items """
        return []
main.py
ADDED
@@ -0,0 +1,199 @@
import argparse
import cv2
import numpy as np
import torch
from pathlib import Path
import time
import traceback

# Make sure the tvcalib directory is on the PYTHONPATH,
# or run from the tvcalib_image_processor directory.
from tvcalib.infer.module import TvCalibInferModule
# Import the visualization functions and the modulation constants
from visualizer import (
    create_minimap_view,
    create_minimap_with_offset_skeletons,
    DYNAMIC_SCALE_MIN_MODULATION,
    DYNAMIC_SCALE_MAX_MODULATION
)
# Import the player data extraction function
from pose_estimator import get_player_data

# Constants
IMAGE_SHAPE = (720, 1280)  # height, width
SEGMENTATION_MODEL_PATH = Path("models/segmentation/train_59.pt")

def preprocess_image_tvcalib(image_bgr):
    """Preprocess a BGR image for TvCalib; return the tensor plus the resized BGR and RGB images."""
    if image_bgr is None:
        raise ValueError("Unable to load the image")

    # 1. Resize to 720p if necessary
    h, w = image_bgr.shape[:2]
    if h != IMAGE_SHAPE[0] or w != IMAGE_SHAPE[1]:
        print(f"Resizing the image to {IMAGE_SHAPE[1]}x{IMAGE_SHAPE[0]}")
        image_bgr_resized = cv2.resize(image_bgr, (IMAGE_SHAPE[1], IMAGE_SHAPE[0]), interpolation=cv2.INTER_LINEAR)
    else:
        image_bgr_resized = image_bgr

    # 2. Convert to RGB (for TvCalib AND for the original-image visualization)
    image_rgb_resized = cv2.cvtColor(image_bgr_resized, cv2.COLOR_BGR2RGB)

    # 3. Normalization specific to the pre-trained model (for TvCalib)
    image_tensor = torch.from_numpy(image_rgb_resized).float()
    image_tensor = image_tensor.permute(2, 0, 1)  # HWC -> CHW
    mean = torch.tensor([0.485, 0.456, 0.406]).view(-1, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(-1, 1, 1)
    image_tensor = (image_tensor / 255.0 - mean) / std

    # 4. Move to the right device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    image_tensor = image_tensor.to(device)

    # Return the tensor for TvCalib plus the resized BGR and RGB images
    return image_tensor, image_bgr_resized, image_rgb_resized

def main():
    parser = argparse.ArgumentParser(description="Run the TvCalib method on a single image.")
    parser.add_argument("image_path", type=str, help="Path to the image to process.")
    parser.add_argument("--output_homography", type=str, default=None, help="Optional path to save the homography matrix (.npy).")
    parser.add_argument("--optim_steps", type=int, default=500, help="Number of optimization steps for calibration (early stopping is disabled).")
    parser.add_argument("--target_avg_scale", type=float, default=1,
                        help="TARGET AVERAGE scale factor for drawing skeletons on the minimap. The script adjusts the base scale to try to reach this average.")

    args = parser.parse_args()

    if not Path(args.image_path).exists():
        print(f"Error: image file not found: {args.image_path}")
        return

    if not SEGMENTATION_MODEL_PATH.exists():
        print(f"Error: segmentation model not found: {SEGMENTATION_MODEL_PATH}")
        print("Make sure you copied train_59.pt into the models/segmentation/ folder.")
        return

    print("Initializing TvCalibInferModule...")
    try:
        model = TvCalibInferModule(
            segmentation_checkpoint=SEGMENTATION_MODEL_PATH,
            image_shape=IMAGE_SHAPE,
            optim_steps=args.optim_steps,
            lens_dist=False  # keep things simple for now
        )
        print(f"✓ Model loaded on {next(model.model_calib.parameters()).device}")
    except Exception as e:
        print(f"Error while initializing the model: {e}")
        return

    print(f"Processing image: {args.image_path}")
    try:
        # Load the image (BGR by default with OpenCV)
        image_bgr_orig = cv2.imread(args.image_path)
        if image_bgr_orig is None:
            raise FileNotFoundError(f"Unable to read the image file: {args.image_path}")

        # Preprocess the image
        start_preprocess = time.time()
        image_tensor, image_bgr_resized, image_rgb_resized = preprocess_image_tvcalib(image_bgr_orig)
        print(f"TvCalib preprocessing time: {time.time() - start_preprocess:.3f}s")

        # Run segmentation
        print("Running segmentation...")
        start_segment = time.time()
        with torch.no_grad():
            keypoints = model._segment(image_tensor)
        print(f"Segmentation time: {time.time() - start_segment:.3f}s")

        # Run calibration (optimization)
        print("Running calibration (optimization)...")
        start_calibrate = time.time()
        homography = model._calibrate(keypoints)
        print(f"Calibration time: {time.time() - start_calibrate:.3f}s")

        if homography is not None:
            print("\n--- Computed Homography ---")
            if isinstance(homography, torch.Tensor):
                homography_np = homography.detach().cpu().numpy()
            else:
                homography_np = homography
            print(homography_np)

            if args.output_homography:
                try:
                    np.save(args.output_homography, homography_np)
                    print(f"\nHomography saved to: {args.output_homography}")
                except Exception as e:
                    print(f"Error while saving the homography: {e}")

            # --- Player data extraction ---
            print("\nExtracting player data (pose + color)...")
            start_pose = time.time()
            player_list = get_player_data(image_bgr_resized)
            print(f"Player data extraction time: {time.time() - start_pose:.3f}s ({len(player_list)} players found)")

            # --- Estimated base scale computation ---
            print("\nComputing the base scale to reach the target...")
            target_average_scale = args.target_avg_scale

            # Compute the expected average modulation (assumption: average player at center, Y = 0.5)
            # Current inverted logic: MIN + (MAX - MIN) * (1.0 - norm_y)
            avg_modulation_expected = DYNAMIC_SCALE_MIN_MODULATION + \
                (DYNAMIC_SCALE_MAX_MODULATION - DYNAMIC_SCALE_MIN_MODULATION) * (1.0 - 0.5)

            estimated_base_scale = target_average_scale  # default value if modulation == 0
            if avg_modulation_expected != 0:
                estimated_base_scale = target_average_scale / avg_modulation_expected
            else:
                print("Warning: expected average modulation is zero; cannot estimate the base scale.")

            print(f"  Expected average dynamic modulation (for Y=0.5): {avg_modulation_expected:.3f}")
            print(f"  Estimated internal base scale for target {target_average_scale:.3f}: {estimated_base_scale:.3f}")

            # --- Generate BOTH minimaps ---
            print("\nGenerating the minimaps (original and offset skeletons)...")

            # 1. Minimap with the original image (RGB)
            minimap_original = create_minimap_view(image_rgb_resized, homography_np)

            # 2. Minimap with the skeletons, using the ESTIMATED base scale
            minimap_offset_skeletons, actual_avg_scale = create_minimap_with_offset_skeletons(
                player_list,
                homography_np,
                base_skeleton_scale=estimated_base_scale
            )

            # Show the target and the actual result
            if actual_avg_scale is not None:
                print(f"\nRequested TARGET average scale (--target_avg_scale): {target_average_scale:.3f}")
                print(f"FINAL average scale actually applied (based on real players): {actual_avg_scale:.3f}")

            # --- Display the results ---
            print("\nShowing results. Press any key to quit.")

            # Show the original minimap
            if minimap_original is not None:
                cv2.imshow("Minimap with Original Projection", minimap_original)
            else:
                print("Could not generate the original minimap.")

            # Show the minimap with the offset skeletons
            if minimap_offset_skeletons is not None:
                cv2.imshow("Minimap with Offset Skeletons", minimap_offset_skeletons)
            else:
                print("Could not generate the offset-skeleton minimap.")

            cv2.waitKey(0)  # wait for a key press

        else:
            print("\nNo homography could be computed.")

    except Exception as e:
        print(f"Error while processing the image: {e}")
        traceback.print_exc()
    finally:
        print("Closing OpenCV windows.")
        cv2.destroyAllWindows()

if __name__ == "__main__":
    main()
pose_estimator.py
ADDED
@@ -0,0 +1,265 @@
import torch
import numpy as np
import cv2
from PIL import Image
from transformers import AutoProcessor, RTDetrForObjectDetection, VitPoseForPoseEstimation
from pathlib import Path

# --- Global variables for models and processor (lazy loading) ---
person_processor = None
person_model = None
pose_processor = None
pose_model = None
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Pose Estimator: Using device: {device}")

# --- Constants for color and drawing ---
# BGR tuples are used for colors
DEFAULT_MARKER_COLOR = (255, 255, 255)  # white
MIN_PIXELS_FOR_COLOR = 20  # minimum number of valid pixels in the ROI to attempt a color estimate
CONFIDENCE_THRESHOLD_KEYPOINTS = 0.3  # threshold for a keypoint to be considered reliable for the ROI and drawing
SKELETON_THICKNESS = 2

# Skeleton segment definitions (COCO indices 0-16)
# 0:Nose, 1:L_Eye, 2:R_Eye, 3:L_Ear, 4:R_Ear, 5:L_Shoulder, 6:R_Shoulder,
# 7:L_Elbow, 8:R_Elbow, 9:L_Wrist, 10:R_Wrist, 11:L_Hip, 12:R_Hip,
# 13:L_Knee, 14:R_Knee, 15:L_Ankle, 16:R_Ankle
SKELETON_EDGES = [
    # Head
    (0, 1), (0, 2), (1, 3), (2, 4),
    # Torso
    (5, 6), (5, 11), (6, 12), (11, 12),
    # Left arm
    (5, 7), (7, 9),
    # Right arm
    (6, 8), (8, 10),
    # Left leg
    (11, 13), (13, 15),
    # Right leg
    (12, 14), (14, 16)
]

# Keypoint indices for the torso and the ankles
TORSO_KP_INDICES = [5, 6, 11, 12]  # shoulders, hips
LEFT_ANKLE_KP_INDEX = 15
RIGHT_ANKLE_KP_INDEX = 16

def _load_models():
    """Loads the models if they haven't been loaded yet."""
    global person_processor, person_model, pose_processor, pose_model

    if person_processor is None:
        print("Loading RTDetr person detector model...")
        person_processor = AutoProcessor.from_pretrained("PekingU/rtdetr_r50vd_coco_o365")
        person_model = RTDetrForObjectDetection.from_pretrained("PekingU/rtdetr_r50vd_coco_o365", device_map=device)
        print("✓ RTDetr loaded.")

    if pose_processor is None:
        print("Loading ViTPose pose estimation model...")
        pose_processor = AutoProcessor.from_pretrained("usyd-community/vitpose-base-simple")
        pose_model = VitPoseForPoseEstimation.from_pretrained("usyd-community/vitpose-base-simple", device_map=device)
        print("✓ ViTPose loaded.")

def _is_color_greenish(bgr_pixel, threshold=10):
    b, g, r = bgr_pixel
    return g > b + threshold and g > r + threshold

def _is_color_grayscale(bgr_pixel, tolerance=30):
    b, g, r = bgr_pixel
    min_val, max_val = min(b, g, r), max(b, g, r)
    is_dark = max_val < 50
    is_light = min_val > 200
    is_low_saturation = (max_val - min_val) < tolerance
    return is_dark or is_light or is_low_saturation

def _get_average_color(roi_bgr):
    """Computes the average color of an ROI after filtering out pitch-green and grayscale pixels."""
    if roi_bgr is None or roi_bgr.size == 0:
        return DEFAULT_MARKER_COLOR

    try:
        pixels = roi_bgr.reshape(-1, 3)
        valid_pixels = []
        for pixel in pixels:
            if not _is_color_greenish(pixel) and not _is_color_grayscale(pixel):
                valid_pixels.append(pixel)

        if len(valid_pixels) < MIN_PIXELS_FOR_COLOR:
            return DEFAULT_MARKER_COLOR

        avg_color = np.mean(valid_pixels, axis=0)
        return tuple(map(int, avg_color))

    except Exception as e:
        print(f"  Error computing the average color: {e}. Falling back to the default color.")
        return DEFAULT_MARKER_COLOR

def get_player_data(image_bgr: np.ndarray) -> list:
    """
    Detects persons, estimates pose, calculates average torso color,
    and returns a list of data for each player.

    Args:
        image_bgr: The input image in BGR format (NumPy array).

    Returns:
        A list of dictionaries, each containing:
        {
            'keypoints': np.ndarray (17, 2),
            'scores': np.ndarray (17,),
            'bbox': np.ndarray (4,) [x1, y1, x2, y2],
            'avg_color': tuple (b, g, r)
        }
        Returns an empty list if no persons are detected or an error occurs.
    """
    _load_models()
    player_list = []
    height, width = image_bgr.shape[:2]

    try:
        image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
        image_pil = Image.fromarray(image_rgb)

        # --- Stage 1: Detect humans ---
        inputs_det = person_processor(images=image_pil, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs_det = person_model(**inputs_det)
        results_det = person_processor.post_process_object_detection(
            outputs_det, target_sizes=torch.tensor([(height, width)]), threshold=0.5
        )
        result_det = results_det[0]
        person_boxes = result_det["boxes"][result_det["labels"] == 0].cpu().numpy()

        if len(person_boxes) == 0:
            print("No persons detected.")
            return player_list

        # Convert [x1, y1, x2, y2] boxes to COCO format [x, y, w, h]
        person_boxes_coco = person_boxes.copy()
        person_boxes_coco[:, 2] = person_boxes_coco[:, 2] - person_boxes_coco[:, 0]
        person_boxes_coco[:, 3] = person_boxes_coco[:, 3] - person_boxes_coco[:, 1]

        # --- Stage 2: Detect keypoints ---
        inputs_pose = pose_processor(image_pil, boxes=[person_boxes_coco], return_tensors="pt").to(device)
        with torch.no_grad():
            outputs_pose = pose_model(**inputs_pose)
        pose_results = pose_processor.post_process_pose_estimation(outputs_pose, boxes=[person_boxes_coco])
        image_pose_result = pose_results[0]

        if not image_pose_result:
            print("Pose estimation did not return results.")
            return player_list

        # --- Stage 3: Process each person ---
        for i, person_box_xyxy in enumerate(person_boxes):
            if i >= len(image_pose_result):
                continue

            pose_result = image_pose_result[i]
            xy = pose_result['keypoints'].cpu().numpy()
            scores = pose_result['scores'].cpu().numpy()

            # Ensure the keypoints shape is correct before proceeding
            if xy.shape != (17, 2):
                print(f"Person {i}: Unexpected keypoints shape {xy.shape}, skipping.")
                continue

            # -- Define the torso ROI --
            reliable_torso_keypoints = xy[TORSO_KP_INDICES][scores[TORSO_KP_INDICES] > CONFIDENCE_THRESHOLD_KEYPOINTS]
            x1_box, y1_box, x2_box, y2_box = map(int, person_box_xyxy)
            box_h = y2_box - y1_box
            box_w = x2_box - x1_box
            if len(reliable_torso_keypoints) >= 3:
                min_x_kp = int(np.min(reliable_torso_keypoints[:, 0]))
                max_x_kp = int(np.max(reliable_torso_keypoints[:, 0]))
                min_y_kp = int(np.min(reliable_torso_keypoints[:, 1]))
                max_y_kp = int(np.max(reliable_torso_keypoints[:, 1]))
                roi_x1 = max(x1_box, min_x_kp - 5); roi_y1 = max(y1_box, min_y_kp - 5)
                roi_x2 = min(x2_box, max_x_kp + 5); roi_y2 = min(y2_box, max_y_kp + 5)
|
177 |
+
else:
|
178 |
+
roi_x1 = x1_box; roi_y1 = y1_box + int(0.1 * box_h)
|
179 |
+
roi_x2 = x2_box; roi_y2 = y1_box + int(0.6 * box_h)
|
180 |
+
roi_x1 = max(0, roi_x1); roi_y1 = max(0, roi_y1)
|
181 |
+
roi_x2 = min(width, roi_x2); roi_y2 = min(height, roi_y2)
|
182 |
+
|
183 |
+
# -- Extract Average Color --
|
184 |
+
avg_color = DEFAULT_MARKER_COLOR
|
185 |
+
if roi_y2 > roi_y1 and roi_x2 > roi_x1:
|
186 |
+
torso_roi = image_bgr[roi_y1:roi_y2, roi_x1:roi_x2]
|
187 |
+
avg_color = _get_average_color(torso_roi)
|
188 |
+
# else: # Pas besoin de message si ROI invalide, couleur par défaut suffit
|
189 |
+
# print(f"Person {i}: Invalid ROI, using default color.")
|
190 |
+
|
191 |
+
# -- Store player data --
|
192 |
+
player_data = {
|
193 |
+
'keypoints': xy,
|
194 |
+
'scores': scores,
|
195 |
+
'bbox': person_box_xyxy, # Utiliser la bbox originale xyxy
|
196 |
+
'avg_color': avg_color
|
197 |
+
}
|
198 |
+
player_list.append(player_data)
|
199 |
+
|
200 |
+
except Exception as e:
|
201 |
+
print(f"Error during player data extraction: {e}")
|
202 |
+
import traceback
|
203 |
+
traceback.print_exc()
|
204 |
+
# Retourner une liste vide en cas d'erreur majeure
|
205 |
+
return []
|
206 |
+
|
207 |
+
return player_list
|
208 |
+
|
209 |
+
# Example usage (optional, for testing the module directly)
|
210 |
+
if __name__ == '__main__':
|
211 |
+
test_image_path = 'img3.png'
|
212 |
+
|
213 |
+
if not Path(test_image_path).exists():
|
214 |
+
print(f"Test image not found: {test_image_path}")
|
215 |
+
else:
|
216 |
+
print(f"Testing player data extraction with image: {test_image_path}")
|
217 |
+
input_img = cv2.imread(test_image_path)
|
218 |
+
|
219 |
+
if input_img is None:
|
220 |
+
print(f"Failed to load test image: {test_image_path}")
|
221 |
+
else:
|
222 |
+
print("Getting player data...")
|
223 |
+
players = get_player_data(input_img)
|
224 |
+
print(f"✓ Found data for {len(players)} players.")
|
225 |
+
|
226 |
+
# --- Draw markers and info on original image for testing ---
|
227 |
+
output_img_test = input_img.copy()
|
228 |
+
for idx, p_data in enumerate(players):
|
229 |
+
kps = p_data['keypoints']
|
230 |
+
scores = p_data['scores']
|
231 |
+
bbox = p_data['bbox']
|
232 |
+
color = p_data['avg_color']
|
233 |
+
|
234 |
+
# Determine reference point (ankles or bbox bottom mid)
|
235 |
+
l_ankle_pt = kps[LEFT_ANKLE_KP_INDEX]
|
236 |
+
r_ankle_pt = kps[RIGHT_ANKLE_KP_INDEX]
|
237 |
+
l_ankle_score = scores[LEFT_ANKLE_KP_INDEX]
|
238 |
+
r_ankle_score = scores[RIGHT_ANKLE_KP_INDEX]
|
239 |
+
|
240 |
+
ref_point = None
|
241 |
+
if l_ankle_score > CONFIDENCE_THRESHOLD_KEYPOINTS and r_ankle_score > CONFIDENCE_THRESHOLD_KEYPOINTS:
|
242 |
+
ref_point = tuple(map(int, (l_ankle_pt + r_ankle_pt) / 2))
|
243 |
+
elif l_ankle_score > CONFIDENCE_THRESHOLD_KEYPOINTS:
|
244 |
+
ref_point = tuple(map(int, l_ankle_pt))
|
245 |
+
elif r_ankle_score > CONFIDENCE_THRESHOLD_KEYPOINTS:
|
246 |
+
ref_point = tuple(map(int, r_ankle_pt))
|
247 |
+
else:
|
248 |
+
x1, y1, x2, y2 = map(int, bbox)
|
249 |
+
ref_point = (int((x1 + x2) / 2), y2)
|
250 |
+
|
251 |
+
# Draw marker at reference point
|
252 |
+
if ref_point:
|
253 |
+
cv2.circle(output_img_test, ref_point, 8, color, -1, cv2.LINE_AA)
|
254 |
+
cv2.circle(output_img_test, ref_point, 8, (0,0,0), 1, cv2.LINE_AA) # Black outline
|
255 |
+
# Draw player index
|
256 |
+
cv2.putText(output_img_test, str(idx), (ref_point[0]+5, ref_point[1]-5),
|
257 |
+
cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0,0,0), 2, cv2.LINE_AA)
|
258 |
+
cv2.putText(output_img_test, str(idx), (ref_point[0]+5, ref_point[1]-5),
|
259 |
+
cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255,255,255), 1, cv2.LINE_AA)
|
260 |
+
|
261 |
+
cv2.imshow("Original Image", input_img)
|
262 |
+
cv2.imshow("Player Markers Test", output_img_test)
|
263 |
+
print("Displaying test results. Press any key to exit.")
|
264 |
+
cv2.waitKey(0)
|
265 |
+
cv2.destroyAllWindows()
|
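For reference, a minimal sketch of how `get_player_data` plugs into the minimap pipeline: each player's reference point (bottom middle of the bbox, as in the test above) is projected with the image-to-pitch homography. The homography `H` and the input path below are placeholders standing in for the TvCalib output elsewhere in this repository, not values defined in this module.

```python
# Hypothetical usage sketch: project each player's reference point onto the
# 2D pitch with a homography estimated elsewhere (TvCalib).
import cv2
import numpy as np
from pose_estimator import get_player_data  # this module

image = cv2.imread("frame.jpg")  # placeholder input frame
H = np.eye(3, dtype=np.float32)  # placeholder: image -> pitch homography

for player in get_player_data(image):
    x1, y1, x2, y2 = player["bbox"]
    ref = np.array([[[(x1 + x2) / 2, y2]]], dtype=np.float32)  # bottom-middle of bbox
    pitch_xy = cv2.perspectiveTransform(ref, H)[0, 0]          # position on the minimap
    print(f"player at pitch coords {pitch_xy}, torso color {player['avg_color']}")
```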
requirements.txt
ADDED
@@ -0,0 +1,22 @@
# Install PyTorch with the specific command if needed:
# pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu121

# Core dependencies
torch
torchvision
torchaudio
numpy
opencv-python
pytorch-lightning
soccernet
kornia

# Added because required by sn_segmentation

# Dependencies for pose estimation (ViTPose)
transformers
supervision
Pillow  # often a dependency of transformers/supervision, but listed explicitly here
accelerate
# scikit-learn  # removed since K-Means is no longer used
gradio
tvcalib/cam_distr/tv_main_behind.py
ADDED
@@ -0,0 +1,77 @@
from math import pi
from tvcalib.utils.data_distr import mean_std_with_confidence_interval


def get_cam_distr(sigma_scale: float, batch_dim: int, temporal_dim: int):
    cam_distr = {
        "pan": {
            "minmax": (pi / 4, 3 * pi / 4),  # in deg 45°, 135°
            "dimension": (batch_dim, temporal_dim),
        },
        "tilt": {
            "minmax": (pi / 16, pi / 2),  # in deg 11.25°, 90°
            "dimension": (batch_dim, temporal_dim),
        },
        "roll": {
            "minmax": (-pi / 32, pi / 32),  # in deg -5.6°, 5.6°
            "dimension": (batch_dim, temporal_dim),
        },
        "aov": {
            "minmax": (pi / 22, pi / 2),  # (8.2°, 90°)
            "dimension": (batch_dim, temporal_dim),
        },
        "c_x": {
            "minmax": (-32.5, -52.5),
            "dimension": (batch_dim, 1),
        },
        "c_y": {
            "minmax": (-5.0, 5.0),
            "dimension": (batch_dim, 1),
        },
        "c_z": {
            "minmax": (-35.0, -1.0),
            "dimension": (batch_dim, 1),
        },
    }

    for k, params in cam_distr.items():
        cam_distr[k]["mean_std"] = mean_std_with_confidence_interval(
            *params["minmax"], sigma_scale=sigma_scale
        )
    return cam_distr


def get_dist_distr(batch_dim: int, temporal_dim: int, _sigma_scale: float = 2.57):
    return {
        "k1": {
            "minmax": [0.0, 0.5],  # we clip min(0.0, x)
            "mean_std": (0.0, _sigma_scale * 0.5),
            "dimension": (batch_dim, temporal_dim),
        },
        "k2": {
            "minmax": [-0.1, 0.1],
            "mean_std": (0.0, _sigma_scale * 0.1),
            "dimension": (batch_dim, temporal_dim),
        },
    }
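Each prior file maps a plausible (min, max) interval per camera parameter to a Gaussian via `mean_std_with_confidence_interval`. Its implementation lives in `tvcalib/utils/data_distr.py` and is not shown in this commit view, so the following is a hedged sketch of the presumed mapping: the interval midpoint becomes the mean, and the half-width divided by `sigma_scale` becomes the standard deviation, so that `mean ± sigma_scale·std` spans the interval.

```python
# Hedged sketch (assumption): how a (min, max) prior could map to (mean, std)
# such that mean ± sigma_scale * std covers the interval. The actual
# mean_std_with_confidence_interval in tvcalib/utils/data_distr.py may differ.
from math import pi

def mean_std_sketch(vmin: float, vmax: float, sigma_scale: float = 2.57):
    mean = (vmin + vmax) / 2.0
    std = (vmax - vmin) / (2.0 * sigma_scale)
    return mean, std

# e.g. the pan prior of tv_main_behind, 45°..135°:
print(mean_std_sketch(pi / 4, 3 * pi / 4))  # mean = pi/2, std ≈ 0.306 rad
```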
tvcalib/cam_distr/tv_main_center.py
ADDED
@@ -0,0 +1,78 @@
import numpy as np
from math import pi
from ..utils.data_distr import mean_std_with_confidence_interval


def get_cam_distr(sigma_scale: float, batch_dim: int, temporal_dim: int):
    cam_distr = {
        "pan": {
            "minmax": (-pi / 4, pi / 4),  # in deg -45°, 45°
            "dimension": (batch_dim, temporal_dim),
        },
        "tilt": {
            "minmax": (pi / 4, pi / 2),  # in deg 45°, 90°
            "dimension": (batch_dim, temporal_dim),
        },
        "roll": {
            "minmax": (-pi / 18, pi / 18),  # in deg -10°, 10°
            "dimension": (batch_dim, temporal_dim),
        },
        "aov": {
            "minmax": (pi / 22, pi / 2),  # (8.2°, 90°)
            "dimension": (batch_dim, temporal_dim),
        },
        "c_x": {
            "minmax": (-12.0, 12.0),  # (-36, 36) for entire main tribune
            "dimension": (batch_dim, 1),
        },
        "c_y": {
            "minmax": (40.0, 110.0),
            "dimension": (batch_dim, 1),
        },
        "c_z": {
            "minmax": (-40.0, -5.0),
            "dimension": (batch_dim, 1),
        },
    }

    for k, params in cam_distr.items():
        cam_distr[k]["mean_std"] = mean_std_with_confidence_interval(
            *params["minmax"], sigma_scale=sigma_scale
        )
    return cam_distr


def get_dist_distr(batch_dim: int, temporal_dim: int, _sigma_scale: float = 2.57):
    return {
        "k1": {
            "minmax": [0.0, 0.5],  # we clip min(0.0, x)
            "mean_std": (0.0, _sigma_scale * 0.5),
            "dimension": (batch_dim, temporal_dim),
        },
        "k2": {
            "minmax": [-0.1, 0.1],
            "mean_std": (0.0, _sigma_scale * 0.1),
            "dimension": (batch_dim, temporal_dim),
        },
    }
tvcalib/cam_distr/tv_main_left.py
ADDED
@@ -0,0 +1,77 @@
from math import pi
from tvcalib.utils.data_distr import mean_std_with_confidence_interval


def get_cam_distr(sigma_scale: float, batch_dim: int, temporal_dim: int):
    cam_distr = {
        "pan": {
            "minmax": (-pi / 4, pi / 4),  # in deg -45°, 45°
            "dimension": (batch_dim, temporal_dim),
        },
        "tilt": {
            "minmax": (pi / 4, pi / 2),  # in deg 45°, 90°
            "dimension": (batch_dim, temporal_dim),
        },
        "roll": {
            "minmax": (-pi / 18, pi / 18),  # in deg -10°, 10°
            "dimension": (batch_dim, temporal_dim),
        },
        "aov": {
            "minmax": (pi / 22, pi / 2),  # (8.2°, 90°)
            "dimension": (batch_dim, temporal_dim),
        },
        "c_x": {
            "minmax": (-36 - 16.5, -36 + 16.5),
            "dimension": (batch_dim, 1),
        },
        "c_y": {
            "minmax": (40.0, 110.0),
            "dimension": (batch_dim, 1),
        },
        "c_z": {
            "minmax": (-40.0, -5.0),
            "dimension": (batch_dim, 1),
        },
    }

    for k, params in cam_distr.items():
        cam_distr[k]["mean_std"] = mean_std_with_confidence_interval(
            *params["minmax"], sigma_scale=sigma_scale
        )
    return cam_distr


def get_dist_distr(batch_dim: int, temporal_dim: int, _sigma_scale: float = 2.57):
    return {
        "k1": {
            "minmax": [0.0, 0.5],  # we clip min(0.0, x)
            "mean_std": (0.0, _sigma_scale * 0.5),
            "dimension": (batch_dim, temporal_dim),
        },
        "k2": {
            "minmax": [-0.1, 0.1],
            "mean_std": (0.0, _sigma_scale * 0.1),
            "dimension": (batch_dim, temporal_dim),
        },
    }
tvcalib/cam_distr/tv_main_right.py
ADDED
@@ -0,0 +1,77 @@
from math import pi
from tvcalib.utils.data_distr import mean_std_with_confidence_interval


def get_cam_distr(sigma_scale: float, batch_dim: int, temporal_dim: int):
    cam_distr = {
        "pan": {
            "minmax": (-pi / 4, pi / 4),  # in deg -45°, 45°
            "dimension": (batch_dim, temporal_dim),
        },
        "tilt": {
            "minmax": (pi / 4, pi / 2),  # in deg 45°, 90°
            "dimension": (batch_dim, temporal_dim),
        },
        "roll": {
            "minmax": (-pi / 18, pi / 18),  # in deg -10°, 10°
            "dimension": (batch_dim, temporal_dim),
        },
        "aov": {
            "minmax": (pi / 22, pi / 2),  # (8.2°, 90°)
            "dimension": (batch_dim, temporal_dim),
        },
        "c_x": {
            "minmax": (36 - 16.5, 36 + 16.5),
            "dimension": (batch_dim, 1),
        },
        "c_y": {
            "minmax": (40.0, 110.0),
            "dimension": (batch_dim, 1),
        },
        "c_z": {
            "minmax": (-40.0, -5.0),
            "dimension": (batch_dim, 1),
        },
    }

    for k, params in cam_distr.items():
        cam_distr[k]["mean_std"] = mean_std_with_confidence_interval(
            *params["minmax"], sigma_scale=sigma_scale
        )
    return cam_distr


def get_dist_distr(batch_dim: int, temporal_dim: int, _sigma_scale: float = 2.57):
    return {
        "k1": {
            "minmax": [0.0, 0.5],  # we clip min(0.0, x)
            "mean_std": (0.0, _sigma_scale * 0.5),
            "dimension": (batch_dim, temporal_dim),
        },
        "k2": {
            "minmax": [-0.1, 0.1],
            "mean_std": (0.0, _sigma_scale * 0.1),
            "dimension": (batch_dim, temporal_dim),
        },
    }
tvcalib/cam_distr/tv_main_tribune.py
ADDED
@@ -0,0 +1,77 @@
from math import pi
from tvcalib.utils.data_distr import mean_std_with_confidence_interval


def get_cam_distr(sigma_scale: float, batch_dim: int, temporal_dim: int):
    cam_distr = {
        "pan": {
            "minmax": (-pi / 4, pi / 4),  # in deg -45°, 45°
            "dimension": (batch_dim, temporal_dim),
        },
        "tilt": {
            "minmax": (pi / 4, pi / 2),  # in deg 45°, 90°
            "dimension": (batch_dim, temporal_dim),
        },
        "roll": {
            "minmax": (-pi / 18, pi / 18),  # in deg -10°, 10°
            "dimension": (batch_dim, temporal_dim),
        },
        "aov": {
            "minmax": (pi / 22, pi / 2),  # (8.2°, 90°)
            "dimension": (batch_dim, temporal_dim),
        },
        "c_x": {
            "minmax": (-40.0, 40.0),  # entire main tribune
            "dimension": (batch_dim, 1),
        },
        "c_y": {
            "minmax": (40.0, 110.0),
            "dimension": (batch_dim, 1),
        },
        "c_z": {
            "minmax": (-40.0, -5.0),
            "dimension": (batch_dim, 1),
        },
    }

    for k, params in cam_distr.items():
        cam_distr[k]["mean_std"] = mean_std_with_confidence_interval(
            *params["minmax"], sigma_scale=sigma_scale
        )
    return cam_distr


def get_dist_distr(batch_dim: int, temporal_dim: int, _sigma_scale: float = 2.57):
    return {
        "k1": {
            "minmax": [0.0, 0.5],  # we clip min(0.0, x)
            "mean_std": (0.0, _sigma_scale * 0.5),
            "dimension": (batch_dim, temporal_dim),
        },
        "k2": {
            "minmax": [-0.1, 0.1],
            "mean_std": (0.0, _sigma_scale * 0.1),
            "dimension": (batch_dim, temporal_dim),
        },
    }
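Taken together, the five `cam_distr` presets differ only in their interval bounds; the dimensions, the `aov` prior, and `get_dist_distr` are identical across files. For quick comparison (all values copied from the modules above; angles as coded, positions in meters):

| preset | pan | tilt | roll | c_x | c_y | c_z |
|---|---|---|---|---|---|---|
| tv_main_behind | (pi/4, 3pi/4) | (pi/16, pi/2) | (-pi/32, pi/32) | (-32.5, -52.5) | (-5, 5) | (-35, -1) |
| tv_main_center | (-pi/4, pi/4) | (pi/4, pi/2) | (-pi/18, pi/18) | (-12, 12) | (40, 110) | (-40, -5) |
| tv_main_left | (-pi/4, pi/4) | (pi/4, pi/2) | (-pi/18, pi/18) | (-52.5, -19.5) | (40, 110) | (-40, -5) |
| tv_main_right | (-pi/4, pi/4) | (pi/4, pi/2) | (-pi/18, pi/18) | (19.5, 52.5) | (40, 110) | (-40, -5) |
| tv_main_tribune | (-pi/4, pi/4) | (pi/4, pi/2) | (-pi/18, pi/18) | (-40, 40) | (40, 110) | (-40, -5) |

Note that `tv_main_behind` writes its `c_x` interval as `(-32.5, -52.5)`, i.e. with min and max apparently swapped relative to the other files; the value is kept verbatim here.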
tvcalib/cam_modules.py
ADDED
@@ -0,0 +1,583 @@
from typing import Tuple, Dict, Union

from pytorch_lightning import LightningModule
import torch
import kornia
import torch.nn as nn
from .utils.data_distr import FeatureScalerZScore


class CameraParameterWLensDistDictZScore(LightningModule):
    """Holds individual camera parameters, including lens distortion parameters, as an nn.Module."""

    def __init__(self, cam_distr, dist_distr, device="cpu"):
        super(CameraParameterWLensDistDictZScore, self).__init__()

        self.cam_distr = cam_distr
        self._device = device

        # phi raw
        self.param_dict = torch.nn.ParameterDict(
            {
                k: torch.nn.parameter.Parameter(
                    torch.zeros(
                        *cam_distr[k]["dimension"],
                        device=device,
                    ),
                    requires_grad=False
                    if ("no_grad" in cam_distr[k]) and (cam_distr[k]["no_grad"] == True)
                    else True,
                )
                for k in cam_distr.keys()
            }
        )

        # denormalization module to get phi_target
        self.feature_scaler = torch.nn.ModuleDict(
            {k: FeatureScalerZScore(*cam_distr[k]["mean_std"]) for k in cam_distr.keys()}
        )

        self.dist_distr = dist_distr
        if self.dist_distr is not None:
            self.param_dict_dist = torch.nn.ParameterDict(
                {
                    k: torch.nn.Parameter(torch.zeros(*dist_distr[k]["dimension"], device=device))
                    for k in dist_distr.keys()
                }
            )
            # TODO: modify later to dynamically construct a tensor of shape (k_1,k_2,p_1,p_2[,k_3[,k_4,k_5,k_6[,s_1,s_2,s_3,s_4[,\tau_x,\tau_y]]]])

            self.feature_scaler_dist_coeff = torch.nn.ModuleDict(
                {k: FeatureScalerZScore(*dist_distr[k]["mean_std"]) for k in dist_distr.keys()}
            )

    def initialize(
        self,
        update_dict_cam: Union[Dict[str, Union[float, torch.tensor]], None],
        update_dict_dist=None,
    ):
        """Initializes all camera parameters with zeros and replaces specific values with the provided ones.

        Args:
            update_dict_cam (Dict[str, Union[float, torch.tensor]]): Parameters to be updated
        """

        for k in self.param_dict.keys():
            self.param_dict[k].data = torch.zeros(
                *self.cam_distr[k]["dimension"], device=self._device
            )
        if self.dist_distr is not None:
            for k in self.dist_distr.keys():
                self.param_dict_dist[k].data = torch.zeros(
                    *self.dist_distr[k]["dimension"], device=self._device
                )

        if update_dict_cam is not None and len(update_dict_cam) > 0:
            for k, v in update_dict_cam.items():
                self.param_dict[k].data = (
                    torch.zeros(*self.cam_distr[k]["dimension"], device=self._device) + v
                )
        if update_dict_dist is not None:
            raise NotImplementedError

    def forward(self):
        phi_dict = {}
        for k, param in self.param_dict.items():
            phi_dict[k] = self.feature_scaler[k](param)

        if self.dist_distr is None:
            return phi_dict, None

        # This is a vector with 4, 5, 8, 12 or 14 elements with shape :math:`(*, n)` depending on the provided dict of coefficients.
        # Assumes the dict is ordered according to (k_1,k_2,p_1,p_2[,k_3[,k_4,k_5,k_6[,s_1,s_2,s_3,s_4[,\tau_x,\tau_y]]]])
        psi = torch.stack(
            [
                torch.clamp(
                    self.feature_scaler_dist_coeff[k](param),
                    min=self.dist_distr[k]["minmax"][0],
                    max=self.dist_distr[k]["minmax"][1],
                )
                for k, param in self.param_dict_dist.items()
            ],
            dim=-1,  # stack individual features and not arbitrary leading dimensions
        )

        return phi_dict, psi


class SNProjectiveCamera:
    def __init__(
        self,
        phi_dict: Dict[str, torch.tensor],
        psi: torch.tensor,
        principal_point: Tuple[float, float],
        image_width: int,
        image_height: int,
        device: str = "cpu",
        nan_check=True,
    ) -> None:
        """Projective camera defined as K @ R [I|-t] with lens distortion module and batch dimensions B,T.

        Following the Euler angles convention, we use a ZXZ succession of intrinsic rotations to describe
        the orientation of the camera. Starting from the world reference axis system, we first apply a rotation
        around the Z axis to pan the camera. The resulting axis system is then rotated around its x axis to tilt the camera.
        A final rotation around the z axis of the new axis system rolls the camera. Note that this z axis is the principal axis of the camera.

        As T is not provided for camera location and lens distortion, these parameters are assumed to be fixed across T.
        phi_dict is a dict of parameters containing:
        {
            'aov_x, torch.Size([B, T])',
            'pan, torch.Size([B, T])',
            'tilt, torch.Size([B, T])',
            'roll, torch.Size([B, T])',
            'c_x, torch.Size([B, 1])',
            'c_y, torch.Size([B, 1])',
            'c_z, torch.Size([B, 1])',
        }

        Internally fuses the B and T dimensions into a pseudo batch dimension:
        {
            'aov_x, torch.Size([B*T])',
            'pan, torch.Size([B*T])',
            'tilt, torch.Size([B*T])',
            'roll, torch.Size([B*T])',
            'c_x, torch.Size([B])',
            'c_y, torch.Size([B])',
            'c_z, torch.Size([B])',
        }

        aov_x, pan, tilt, roll are assumed to be in radians.

        Note on lens distortion:
        Lens distortion coefficients are independent of image resolution, i.e.
        I(dist_points(K_ndc, dist_coeff, points2d_ndc)) == I(dist_points(K_raster, dist_coeff, points2d_raster)).

        Args:
            phi_dict (Dict[str, torch.tensor]): See example above
            psi (Union[None, torch.Tensor]): distortion coefficients as a concatenated vector according to https://kornia.readthedocs.io/en/latest/geometry.calibration.html of shape (B, T, {2, 4, 5, 8, 12, 14})
            principal_point (Tuple[float, float]): Principal point assumed to be fixed across all samples (B,T,)
            image_width (int): assumed to be fixed across all samples (B,T,)
            image_height (int): assumed to be fixed across all samples (B,T,)
        """

        # fuse B and T dimension
        phi_dict_flat = {}
        for k, v in phi_dict.items():
            if len(v.shape) == 2:
                phi_dict_flat[k] = v.view(v.shape[0] * v.shape[1])
            elif len(v.shape) == 3:
                phi_dict_flat[k] = v.view(v.shape[0] * v.shape[1], v.shape[-1])

        self.batch_dim, self.temporal_dim = phi_dict["pan"].shape
        self.pseudo_batch_size = phi_dict_flat["pan"].shape[0]
        self.phi_dict_flat = phi_dict_flat

        self.principal_point = principal_point
        self.image_width = image_width
        self.image_height = image_height
        self.device = device

        self.psi = psi
        if self.psi is not None:
            if self.psi.shape[-1] != 2:
                raise NotImplementedError

            # :math:`(k_1,k_2,p_1,p_2[,k_3[,k_4,k_5,k_6[,s_1,s_2,s_3,s_4[,\tau_x,\tau_y]]]])`.
            # psi is a vector with 2, 4, 5, 8, 12 or 14 elements with shape :math:`(*, n)`.
            if self.psi.shape[-1] == 2:
                # assume zero tangential coefficients
                psi_ext = torch.zeros(*list(self.psi.shape[:-1]), 4)
                psi_ext[..., :2] = self.psi
                self.psi = psi_ext
            self.lens_dist_coeff = self.psi.view(self.pseudo_batch_size, self.psi.shape[-1]).to(
                self.device
            )

        self.intrinsics_ndc = self.construct_intrinsics_ndc()
        self.intrinsics_raster = self.construct_intrinsics_raster()

        self.rotation = self.rotation_from_euler_angles(
            *[phi_dict_flat[k] for k in ["pan", "tilt", "roll"]]
        )
        self.position = torch.stack([phi_dict_flat[k] for k in ["c_x", "c_y", "c_z"]], dim=-1)
        self.position = self.position.repeat_interleave(
            int(self.pseudo_batch_size / self.batch_dim), dim=0
        )  # (B, 3) # TODO: probably needs modification if B > 0?
        self.P_ndc = self.construct_projection_matrix(self.intrinsics_ndc)
        self.P_raster = self.construct_projection_matrix(self.intrinsics_raster)
        self.phi_dict = phi_dict

        self.nan_check = nan_check
        super().__init__()

    def construct_projection_matrix(self, intrinsics):
        It = torch.eye(4, device=self.device)[:-1].repeat(self.pseudo_batch_size, 1, 1)
        It[:, :, -1] = -self.position  # (B, 3, 4)
        self.It = It
        return intrinsics @ self.rotation @ It  # (B, 3, 4)

    def construct_intrinsics_ndc(self):
        # assume that the principal point is (0,0)
        K = torch.eye(3, requires_grad=False, device=self.device)
        K = K.reshape((1, 3, 3)).repeat(self.pseudo_batch_size, 1, 1)
        K[:, 0, 0] = self.get_fl_from_aov_rad(self.phi_dict_flat["aov"], d=2)
        K[:, 1, 1] = self.get_fl_from_aov_rad(
            self.phi_dict_flat["aov"], d=2 * self.image_width / self.image_height
        )
        return K

    def construct_intrinsics_raster(self):
        # assume that the principal point is (W/2,H/2)
        K = torch.eye(3, requires_grad=False, device=self.device)
        K = K.reshape((1, 3, 3)).repeat(self.pseudo_batch_size, 1, 1)
        K[:, 0, 0] = self.get_fl_from_aov_rad(self.phi_dict_flat["aov"], d=self.image_width)
        K[:, 1, 1] = self.get_fl_from_aov_rad(self.phi_dict_flat["aov"], d=self.image_width)
        K[:, 0, 2] = self.principal_point[0]
        K[:, 1, 2] = self.principal_point[1]
        return K

    def __str__(self) -> str:
        return f"aov_deg={torch.rad2deg(self.phi_dict['aov'])}, t={torch.stack([self.phi_dict[k] for k in ['c_x', 'c_y', 'c_z']], dim=-1)}, pan_deg={torch.rad2deg(self.phi_dict['pan'])} tilt_deg={torch.rad2deg(self.phi_dict['tilt'])} roll_deg={torch.rad2deg(self.phi_dict['roll'])}"

    def str_pan_tilt_roll_fl(self, b, t):
        r = f"FOV={torch.rad2deg(self.phi_dict['aov'][b, t]):.1f}°, pan={torch.rad2deg(self.phi_dict['pan'][b, t]):.1f}° tilt={torch.rad2deg(self.phi_dict['tilt'][b, t]):.1f}° roll={torch.rad2deg(self.phi_dict['roll'][b, t]):.1f}°"
        return r

    def str_lens_distortion_coeff(self, b):
        # TODO: T! also need individual lens_dist_coeff for each t in T
        # print(self.lens_dist_coeff.shape)
        return f"lens dist coeff=" + " ".join(
            [f"{x:.2f}" for x in self.lens_dist_coeff[b, :2]]
        )  # print only radial lens dist. coeff

    def __repr__(self) -> str:
        return f"{self.__class__}:" + self.__str__()

    def __len__(self):
        return self.pseudo_batch_size  # e.g. self.intrinsics.shape[0]

    def project_point2pixel(self, points3d: torch.tensor, lens_distortion: bool) -> torch.tensor:
        """Project world coordinates to pixel coordinates.

        Args:
            points3d (torch.tensor): of shape (N, 3) or (1, N, 3)

        Returns:
            torch.tensor: projected points of shape (B, T, N, 2)
        """
        position = self.position.view(self.pseudo_batch_size, 1, 3)
        point = points3d - position
        rotated_point = self.rotation @ point.transpose(1, 2)  # (pseudo_batch_size, 3, N)
        dist_point2cam = rotated_point[:, 2]  # (B, N) distance pixel to world point
        dist_point2cam = dist_point2cam.view(self.pseudo_batch_size, 1, rotated_point.shape[-1])
        rotated_point = rotated_point / dist_point2cam  # (B, 3, N) / (B, 1, N) -> (B, 3, N)

        projected_points = self.intrinsics_raster @ rotated_point  # (B, 3, N)
        # transpose vs view? here
        projected_points = projected_points.transpose(-1, -2)  # cannot use view()
        projected_points = kornia.geometry.convert_points_from_homogeneous(projected_points)
        if lens_distortion:
            if self.psi is None:
                raise RuntimeError("Lens distortion requested, but deactivated in module")
            projected_points = self.distort_points(projected_points, self.intrinsics_raster)

        # reshape back from (pseudo_batch_size, N, 2) to (B, T, N, 2)
        projected_points = projected_points.view(
            self.batch_dim, self.temporal_dim, projected_points.shape[-2], 2
        )
        if self.nan_check:
            if torch.isnan(projected_points).any().item():
                print(self.phi_dict_flat)
                print(projected_points)
                raise RuntimeWarning("NaN in project_point2pixel")
        return projected_points

    def project_point2ndc(self, points3d: torch.tensor, lens_distortion: bool) -> torch.tensor:
        """Project world coordinates to normalized device coordinates (NDC).

        Args:
            points3d (torch.tensor): of shape (N, 3) or (1, N, 3)

        Returns:
            torch.tensor: projected points of shape (B, T, N, 2)
        """
        position = self.position.view(self.pseudo_batch_size, 1, 3)
        point = points3d - position
        rotated_point = self.rotation @ point.transpose(1, 2)  # (pseudo_batch_size, 3, N)
        dist_point2cam = rotated_point[:, 2]  # (B, N) distance pixel to world point
        dist_point2cam = dist_point2cam.view(self.pseudo_batch_size, 1, rotated_point.shape[-1])
        rotated_point = rotated_point / dist_point2cam  # (B, 3, N) / (B, 1, N) -> (B, 3, N)

        projected_points = self.intrinsics_ndc @ rotated_point  # (B, 3, N)
        # transpose vs view? here
        projected_points = projected_points.transpose(-1, -2)  # cannot use view()
        projected_points = kornia.geometry.convert_points_from_homogeneous(projected_points)
        if self.nan_check:
            if torch.isnan(projected_points).any().item():
                print(projected_points)
                print(self.phi_dict_flat)
                print("lens distortion", self.lens_dist_coeff)

                raise RuntimeWarning("NaN in project_point2ndc before distort")
        if lens_distortion:
            if self.psi is None:
                raise RuntimeError("Lens distortion requested, but deactivated in module")
            projected_points = self.distort_points(projected_points, self.intrinsics_ndc)

        # reshape back from (pseudo_batch_size, N, 2) to (B, T, N, 2)
        projected_points = projected_points.view(
            self.batch_dim, self.temporal_dim, projected_points.shape[-2], 2
        )
        if self.nan_check:
            if torch.isnan(projected_points).any().item():
                print(self.phi_dict_flat)
                print(projected_points)
                raise RuntimeWarning("NaN in project_point2ndc after distort")
        return projected_points

    def project_point2pixel_from_P(
        self, points3d: torch.tensor, lens_distortion: bool
    ) -> torch.tensor:
        """Project world coordinates to pixel coordinates from the projection matrix.

        Args:
            points3d (torch.tensor): of shape (1, N, 3)

        Returns:
            torch.tensor: projected points of shape (B, T, N, 2)
        """

        points3d = kornia.geometry.conversions.convert_points_to_homogeneous(points3d).transpose(
            1, 2
        )  # (B, 4, N)
        projected_points = torch.bmm(self.P_raster, points3d.repeat(self.pseudo_batch_size, 1, 1))
        normalize_by = projected_points[:, -1].view(
            self.pseudo_batch_size, 1, projected_points.shape[-1]
        )
        projected_points /= normalize_by
        projected_points = projected_points.transpose(-1, -2)  # cannot use view()
        projected_points = kornia.geometry.convert_points_from_homogeneous(projected_points)
        if lens_distortion:
            if self.psi is None:
                raise RuntimeError("Lens distortion requested, but deactivated in module")
            projected_points = self.distort_points(projected_points, self.intrinsics_raster)
        # reshape back from (pseudo_batch_size, N, 2) to (B, T, N, 2)
        projected_points = projected_points.view(
            self.batch_dim, self.temporal_dim, projected_points.shape[-2], 2
        )
        return projected_points  # (B, T, N, 2)

    def project_point2ndc_from_P(
        self, points3d: torch.tensor, lens_distortion: bool
    ) -> torch.tensor:
        """Project world coordinates to NDC from the projection matrix.

        Args:
            points3d (torch.tensor): of shape (1, N, 3)

        Returns:
            torch.tensor: projected points of shape (B, T, N, 2)
        """

        points3d = kornia.geometry.conversions.convert_points_to_homogeneous(points3d).transpose(
            1, 2
        )  # (B, 4, N)
        projected_points = torch.bmm(self.P_ndc, points3d.repeat(self.pseudo_batch_size, 1, 1))
        normalize_by = projected_points[:, -1].view(
            self.pseudo_batch_size, 1, projected_points.shape[-1]
        )
        projected_points /= normalize_by
        projected_points = projected_points.transpose(-1, -2)  # cannot use view()
        projected_points = kornia.geometry.convert_points_from_homogeneous(projected_points)
        if lens_distortion:
            if self.psi is None:
                raise RuntimeError("Lens distortion requested, but deactivated in module")
            projected_points = self.distort_points(projected_points, self.intrinsics_ndc)
        # reshape back from (pseudo_batch_size, N, 2) to (B, T, N, 2)
        projected_points = projected_points.view(
            self.batch_dim, self.temporal_dim, projected_points.shape[-2], 2
        )
        return projected_points  # (B, T, N, 2)

    def rotation_from_euler_angles(self, pan, tilt, roll):
        # rotation matrices from a batch of pan tilt roll [rad] vectors of shape (?, )

        mask = (
            torch.eye(3, requires_grad=False, device=self.device)
            .reshape((1, 3, 3))
            .repeat(pan.shape[0], 1, 1)
        )
        mask[:, 0, 0] = -torch.sin(pan) * torch.sin(roll) * torch.cos(tilt) + torch.cos(
            pan
        ) * torch.cos(roll)
        mask[:, 0, 1] = torch.sin(pan) * torch.cos(roll) + torch.sin(roll) * torch.cos(
            pan
        ) * torch.cos(tilt)
        mask[:, 0, 2] = torch.sin(roll) * torch.sin(tilt)

        mask[:, 1, 0] = -torch.sin(pan) * torch.cos(roll) * torch.cos(tilt) - torch.sin(
            roll
        ) * torch.cos(pan)
        mask[:, 1, 1] = -torch.sin(pan) * torch.sin(roll) + torch.cos(pan) * torch.cos(
            roll
        ) * torch.cos(tilt)
        mask[:, 1, 2] = torch.sin(tilt) * torch.cos(roll)

        mask[:, 2, 0] = torch.sin(pan) * torch.sin(tilt)
        mask[:, 2, 1] = -torch.sin(tilt) * torch.cos(pan)
        mask[:, 2, 2] = torch.cos(tilt)

        return mask

    def get_homography_raster(self):
        return self.P_raster[:, :, [0, 1, 3]].inverse()

    def get_rays_world(self, x):
        """_summary_

        Args:
            x (_type_): x of shape (B, 3, N)

        Returns:
            LineCollection: _description_
        """
        raise NotImplementedError
        # TODO: verify
        # ray_cam_trans = torch.bmm(self.rotation.inverse(), torch.bmm(self.intrinsics.inverse(), x))
        # # unnormalized direction vector in euclidean points (x,y,z) based on camera origin (0,0,0)
        # ray_cam_trans = torch.nn.functional.normalize(ray_cam_trans, p=2, dim=1)  # (B, 3, N)

        # # shift support vector to origin in world space, i.e. the translation vector
        # support = self.position.unsqueeze(-1).repeat(
        #     ray_cam_trans.shape[0], 1, ray_cam_trans.shape[2]
        # )  # (B, 3, N)
        # return LineCollection(support=support, direction_norm=ray_cam_trans)

    @staticmethod
    def get_aov_rad(d: float, fl: torch.tensor):
        # https://en.wikipedia.org/wiki/Angle_of_view#Calculating_a_camera's_angle_of_view
        return 2 * torch.arctan(d / (2 * fl))  # in range [0.0, PI]

    @staticmethod
    def get_fl_from_aov_rad(aov_rad: torch.tensor, d: float):
        return 0.5 * d * (1 / torch.tan(0.5 * aov_rad))

    def undistort_points(self, points_pixel: torch.tensor, intrinsics, num_iters=5) -> torch.tensor:
        """Compensate a set of 2D image points for lens distortion.

        Wrapper for kornia.geometry.undistort_points()

        Args:
            points_pixel (torch.tensor): tensor of shape (B, N, 2)

        Returns:
            torch.tensor: undistorted points of shape (B, N, 2)
        """
        # print(points_pixel.shape, intrinsics.shape, self.lens_dist_coeff.shape)
        batch_dim, temporal_dim, N, _ = points_pixel.shape
        points_pixel = points_pixel.view(batch_dim * temporal_dim, N, 2)
        true_batch_size = batch_dim

        lens_dist_coeff = self.lens_dist_coeff
        if true_batch_size < self.batch_dim:
            intrinsics = intrinsics[:true_batch_size]
            lens_dist_coeff = lens_dist_coeff[:true_batch_size]

        return kornia.geometry.undistort_points(
            points_pixel, intrinsics, dist=lens_dist_coeff, num_iters=num_iters
        ).view(batch_dim, temporal_dim, N, 2)

    def distort_points(self, points_pixel: torch.tensor, intrinsics) -> torch.tensor:
        """Distortion of a set of 2D points based on the lens distortion model.

        Wrapper for kornia.geometry.distort_points()

        Args:
            points_pixel (torch.tensor): tensor of shape (B, N, 2)

        Returns:
            torch.tensor: distorted points of shape (B, N, 2)
        """
        return kornia.geometry.distort_points(points_pixel, intrinsics, dist=self.lens_dist_coeff)

    def undistort_images(self, images):
        # images of shape (B, T, C, H, W)
        true_batch_size, T = images.shape[:2]
        images = images.view(true_batch_size * T, 3, self.image_height, self.image_width).to(
            self.device
        )
        intrinsics = self.intrinsics_raster
        lens_dist_coeff = self.lens_dist_coeff
        if true_batch_size < self.batch_dim:
            intrinsics = intrinsics[:true_batch_size]
            lens_dist_coeff = lens_dist_coeff[:true_batch_size]

        return kornia.geometry.calibration.undistort_image(
            images, intrinsics, lens_dist_coeff
        ).view(true_batch_size, self.temporal_dim, 3, self.image_height, self.image_width)

    def get_parameters(self, true_batch_size=None):
        """
        Get a dict of the relevant camera parameters and the homography matrix.
        :return: The dictionary
        """
        out_dict = {
            "pan_degrees": torch.rad2deg(self.phi_dict["pan"]),
            "tilt_degrees": torch.rad2deg(self.phi_dict["tilt"]),
            "roll_degrees": torch.rad2deg(self.phi_dict["roll"]),
            "position_meters": torch.stack([self.phi_dict[k] for k in ["c_x", "c_y", "c_z"]], dim=1)
            .squeeze(-1)
            .unsqueeze(-2)
            .repeat(1, self.temporal_dim, 1),
            "aov_radian": self.phi_dict["aov"],
            "aov_degrees": torch.rad2deg(self.phi_dict["aov"]),
            "x_focal_length": self.get_fl_from_aov_rad(self.phi_dict["aov"], d=self.image_width),
            "y_focal_length": self.get_fl_from_aov_rad(self.phi_dict["aov"], d=self.image_width),
            "principal_point": torch.tensor(
                [[self.principal_point] * self.temporal_dim] * self.batch_dim
            ),
        }
        out_dict["homography"] = self.get_homography_raster().unsqueeze(1)  # (B, 1, 3, 3)

        # expected for SN evaluation
        out_dict["radial_distortion"] = torch.zeros(self.batch_dim, self.temporal_dim, 6)
        out_dict["tangential_distortion"] = torch.zeros(self.batch_dim, self.temporal_dim, 2)
        out_dict["thin_prism_distortion"] = torch.zeros(self.batch_dim, self.temporal_dim, 4)

        if self.psi is not None:
            # in case only k1 and k2 are provided
            out_dict["radial_distortion"][..., :2] = self.psi[..., :2]

        if true_batch_size is None or true_batch_size == self.batch_dim:
            return out_dict

        for k in out_dict.keys():
            out_dict[k] = out_dict[k][:true_batch_size]

        return out_dict

    @staticmethod
    def static_undistort_points(points, cam):
        intrinsics = cam.intrinsics_raster
        lens_dist_coeff = cam.lens_dist_coeff

        true_batch_size = points.shape[0]
        if true_batch_size < cam.batch_dim:
            intrinsics = intrinsics[:true_batch_size]
            lens_dist_coeff = lens_dist_coeff[:true_batch_size]
        # points in homogeneous coordinates
        # (B, T, 3, S, N) -> (T, 3, S*N) -> (T, S*N, 3)
        batch_size, T, _, S, N = points.shape
        points = points.view(batch_size, T, 3, S * N).transpose(2, 3)
        points[..., :2] = kornia.geometry.undistort_points(
            points[..., :2].view(batch_size * T, S * N, 2),
            intrinsics,
            dist=lens_dist_coeff,
            num_iters=1,
        ).view(batch_size, T, S * N, 2)

        # (T, S*N, 3) -> (T, 3, S*N) -> (B, T, 3, S, N)
        points = points.transpose(2, 3).view(batch_size, T, 3, S, N)
        return points
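To make the convention above concrete, here is a minimal numeric sketch mirroring the steps of `project_point2pixel` (subtract the camera position, rotate, perspective-divide, apply intrinsics built from the angle of view). All values are illustrative, not taken from the repository.

```python
# Sketch of the pinhole convention SNProjectiveCamera implements:
# x ~ K @ R @ [I | -t] X, with K built from the angle of view.
import math
import torch

W, H = 1280, 720
aov = 0.5                                   # horizontal angle of view [rad], illustrative
fl = 0.5 * W / math.tan(0.5 * aov)          # focal length from aov (cf. get_fl_from_aov_rad)
K = torch.tensor([[fl, 0.0, W / 2],
                  [0.0, fl, H / 2],
                  [0.0, 0.0, 1.0]])
R = torch.eye(3)                            # identity rotation, for simplicity
t = torch.tensor([0.0, 0.0, -10.0])         # camera position in world coordinates

X = torch.tensor([1.0, 1.0, 0.0])           # a world point on the pitch plane
x_cam = R @ (X - t)                         # rotate into camera coordinates
x_px = K @ (x_cam / x_cam[2])               # perspective divide, then intrinsics
print(x_px[:2])                             # pixel coordinates, roughly (890.6, 610.6)
```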
tvcalib/data/dataset.py
ADDED
@@ -0,0 +1,142 @@
import kornia
import torch
import random
import numpy as np
from kornia.geometry.transform import resize
from .utils import split_circle_central


class InferenceDatasetCalibration(torch.utils.data.Dataset):
    def __init__(self, keypoints_raw, image_width_source, image_height_source, object3d) -> None:
        super().__init__()
        self.keypoints_raw = keypoints_raw
        self.w = image_width_source
        self.h = image_height_source
        self.object3d = object3d
        self.split_circle_central = True

    def __getitem__(self, idx):
        keypoints_dict = self.keypoints_raw[idx]
        if self.split_circle_central:
            keypoints_dict = split_circle_central(keypoints_dict)
        # add empty entries for non-visible segments
        for l in self.object3d.segment_names:
            if l not in keypoints_dict:
                keypoints_dict[l] = []

        per_sample_output = self.prepare_per_sample(
            keypoints_dict, self.object3d, 4, 8, self.w, self.h, pad_pixel_position_xy=0.0
        )
        for k in per_sample_output.keys():
            per_sample_output[k] = per_sample_output[k].unsqueeze(0)

        return per_sample_output

    def __len__(self):
        return len(self.keypoints_raw)

    @staticmethod
    def prepare_per_sample(
        keypoints_raw: dict,
        model3d,
        num_points_on_line_segments: int,
        num_points_on_circle_segments: int,
        image_width_source: int,
        image_height_source: int,
        pad_pixel_position_xy=0.0,
    ):
        r = {}
        pixel_stacked = {}
        for label, points in keypoints_raw.items():
            num_points_selection = num_points_on_line_segments
            if "Circle" in label:
                num_points_selection = num_points_on_circle_segments

            # randomly select num_points_selection points
            if num_points_selection > len(points):
                points_sel = points
            else:
                # random sample without replacement
                points_sel = random.sample(points, k=num_points_selection)

            if len(points_sel) > 0:
                xx = torch.tensor([a["x"] for a in points_sel])
                yy = torch.tensor([a["y"] for a in points_sel])
                pixel_stacked[label] = torch.stack([xx, yy], dim=-1)  # (?, 2)
                # scale pixel annotations from the [0, 1] range to the source image resolution;
                # as this would range over [1, {image_height, image_width}], shift pixels one to the left
                pixel_stacked[label][:, 0] = pixel_stacked[label][:, 0] * (image_width_source - 1)
                pixel_stacked[label][:, 1] = pixel_stacked[label][:, 1] * (image_height_source - 1)

        for segment_type, num_segments, segment_names in [
            ("lines", model3d.line_segments.shape[1], model3d.line_segments_names),
            ("circles", model3d.circle_segments.shape[1], model3d.circle_segments_names),
        ]:
            num_points_selection = num_points_on_line_segments
            if segment_type == "circles":
                num_points_selection = num_points_on_circle_segments
            px_projected_selection = (
                torch.zeros((num_segments, num_points_selection, 2)) + pad_pixel_position_xy
            )
            for segment_index, label in enumerate(segment_names):
                if label in pixel_stacked:
                    # set annotations to the first positions
                    px_projected_selection[
                        segment_index, : pixel_stacked[label].shape[0], :
                    ] = pixel_stacked[label]

            randperm = torch.randperm(num_points_selection)
            px_projected_selection_shuffled = px_projected_selection.clone()
            px_projected_selection_shuffled[:, :, 0] = px_projected_selection_shuffled[
                :, randperm, 0
            ]
            px_projected_selection_shuffled[:, :, 1] = px_projected_selection_shuffled[
                :, randperm, 1
            ]

            is_keypoint_mask = (
                (0.0 <= px_projected_selection_shuffled[:, :, 0])
                & (px_projected_selection_shuffled[:, :, 0] < image_width_source)
            ) & (
                (0 < px_projected_selection_shuffled[:, :, 1])
                & (px_projected_selection_shuffled[:, :, 1] < image_height_source)
            )

            r[f"{segment_type}__is_keypoint_mask"] = is_keypoint_mask.unsqueeze(0)

            # reshape from (num_segments, num_points_selection, 2) to (3, num_segments, num_points_selection)
            px_projected_selection_shuffled = (
                kornia.geometry.conversions.convert_points_to_homogeneous(
                    px_projected_selection_shuffled
                )
            )
            px_projected_selection_shuffled = px_projected_selection_shuffled.view(
                num_segments * num_points_selection, 3
            )
            px_projected_selection_shuffled = px_projected_selection_shuffled.transpose(0, 1)
            px_projected_selection_shuffled = px_projected_selection_shuffled.view(
                3, num_segments, num_points_selection
            )
            # (3, num_segments, num_points_selection)
            r[f"{segment_type}__px_projected_selection_shuffled"] = px_projected_selection_shuffled

            ndc_projected_selection_shuffled = px_projected_selection_shuffled.clone()
            ndc_projected_selection_shuffled[0] = (
                ndc_projected_selection_shuffled[0] / image_width_source
            )
            ndc_projected_selection_shuffled[1] = (
                ndc_projected_selection_shuffled[1] / image_height_source
            )
            ndc_projected_selection_shuffled[1] = ndc_projected_selection_shuffled[1] * 2.0 - 1
            ndc_projected_selection_shuffled[0] = ndc_projected_selection_shuffled[0] * 2.0 - 1
            r[
                f"{segment_type}__ndc_projected_selection_shuffled"
            ] = ndc_projected_selection_shuffled

        return r
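For orientation, the last lines of `prepare_per_sample` above map pixel coordinates to NDC per axis with `ndc = px / size * 2 - 1`, so 0 lands on -1 and `size` on +1. A minimal self-contained sketch of that mapping (values illustrative only):

```python
# Sketch of the pixel -> NDC normalization used in prepare_per_sample.
def px_to_ndc(px: float, size: int) -> float:
    return px / size * 2.0 - 1.0

W, H = 1280, 720
print(px_to_ndc(0, W))      # -1.0 (left edge)
print(px_to_ndc(W / 2, W))  #  0.0 (image center)
print(px_to_ndc(H, H))      #  1.0 (bottom edge)
```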
tvcalib/data/utils.py
ADDED
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from operator import itemgetter
import torch
import re
import collections


string_classes = str


def split_circle_central(keypoints_dict):
    # split "Circle central" in "Circle central left" and "Circle central right"

    # assume main camera --> TODO behind the goal camera
    if "Circle central" in keypoints_dict:
        points_circle_central_left = []
        points_circle_central_right = []

        if "Middle line" in keypoints_dict:
            p_index_ymin, _ = min(
                enumerate([p["y"] for p in keypoints_dict["Middle line"]]),
                key=itemgetter(1),
            )
            p_index_ymax, _ = max(
                enumerate([p["y"] for p in keypoints_dict["Middle line"]]),
                key=itemgetter(1),
            )
            p_ymin = keypoints_dict["Middle line"][p_index_ymin]
            p_ymax = keypoints_dict["Middle line"][p_index_ymax]
            p_xmean = (p_ymin["x"] + p_ymax["x"]) / 2

            points_circle_central = keypoints_dict["Circle central"]
            for p in points_circle_central:
                if p["x"] < p_xmean:
                    points_circle_central_left.append(p)
                else:
                    points_circle_central_right.append(p)
        else:
            # circle is partly shown on the left or right side of the image
            # mean position is shown on the left part of the image --> label right
            circle_x = [p["x"] for p in keypoints_dict["Circle central"]]
            mean_x_circle = sum(circle_x) / len(circle_x)
            if mean_x_circle < 0.5:
                points_circle_central_right = keypoints_dict["Circle central"]
            else:
                points_circle_central_left = keypoints_dict["Circle central"]

        if len(points_circle_central_left) > 0:
            keypoints_dict["Circle central left"] = points_circle_central_left
        if len(points_circle_central_right) > 0:
            keypoints_dict["Circle central right"] = points_circle_central_right
        if len(points_circle_central_left) == 0 and len(points_circle_central_right) == 0:
            raise RuntimeError
        del keypoints_dict["Circle central"]
    return keypoints_dict


def custom_list_collate(batch):
    r"""
    Function that takes in a batch of data and puts the elements within the batch
    into a tensor with an additional outer dimension - batch size. The exact output type can be
    a :class:`torch.Tensor`, a `Sequence` of :class:`torch.Tensor`, a
    Collection of :class:`torch.Tensor`, or left unchanged, depending on the input type.
    This is used as the default function for collation when
    `batch_size` or `batch_sampler` is defined in :class:`~torch.utils.data.DataLoader`.
    Here is the general input type (based on the type of the element within the batch) to output type mapping:
    * :class:`torch.Tensor` -> :class:`torch.Tensor` (with an added outer dimension batch size)
    * NumPy Arrays -> :class:`torch.Tensor`
    * `float` -> :class:`torch.Tensor`
    * `int` -> :class:`torch.Tensor`
    * `str` -> `str` (unchanged)
    * `bytes` -> `bytes` (unchanged)
    * `Mapping[K, V_i]` -> `Mapping[K, default_collate([V_1, V_2, ...])]`
    * `NamedTuple[V1_i, V2_i, ...]` -> `NamedTuple[default_collate([V1_1, V1_2, ...]), default_collate([V2_1, V2_2, ...]), ...]`
    * `Sequence[V1_i, V2_i, ...]` -> `Sequence[default_collate([V1_1, V1_2, ...]), default_collate([V2_1, V2_2, ...]), ...]`
    Args:
        batch: a single batch to be collated
    Examples:
        >>> # Example with a batch of `int`s:
        >>> default_collate([0, 1, 2, 3])
        tensor([0, 1, 2, 3])
        >>> # Example with a batch of `str`s:
        >>> default_collate(['a', 'b', 'c'])
        ['a', 'b', 'c']
        >>> # Example with `Map` inside the batch:
        >>> default_collate([{'A': 0, 'B': 1}, {'A': 100, 'B': 100}])
        {'A': tensor([  0, 100]), 'B': tensor([  1, 100])}
        >>> # Example with `NamedTuple` inside the batch:
        >>> Point = namedtuple('Point', ['x', 'y'])
        >>> default_collate([Point(0, 0), Point(1, 1)])
        Point(x=tensor([0, 1]), y=tensor([0, 1]))
        >>> # Example with `Tuple` inside the batch:
        >>> default_collate([(0, 1), (2, 3)])
        [tensor([0, 2]), tensor([1, 3])]

        >>> # modification
        >>> # Example with `List` inside the batch:
        >>> default_collate([[0, 1, 2], [2, 3, 4]])
        >>> [[0, 1, 2], [2, 3, 4]]
        >>> # original behavior
        >>> [[0, 2], [1, 3], [2, 4]]
    """

    np_str_obj_array_pattern = re.compile(r"[SaUO]")
    default_collate_err_msg_format = "default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found {}"

    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        out = None
        if torch.utils.data.get_worker_info() is not None:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum(x.numel() for x in batch)
            storage = elem.storage()._new_shared(numel)
            out = elem.new(storage).resize_(len(batch), *list(elem.size()))
        return torch.stack(batch, 0, out=out)
    elif (
        elem_type.__module__ == "numpy"
        and elem_type.__name__ != "str_"
        and elem_type.__name__ != "string_"
    ):
        if elem_type.__name__ == "ndarray" or elem_type.__name__ == "memmap":
            # array of string classes and object
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                raise TypeError(default_collate_err_msg_format.format(elem.dtype))

            return [torch.as_tensor(b) for b in batch]
        elif elem.shape == ():  # scalars
            return torch.as_tensor(batch)
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(elem, int):
        return torch.tensor(batch)
    elif isinstance(elem, string_classes):
        return batch
    elif isinstance(elem, collections.abc.Mapping):
        try:
            return elem_type({key: custom_list_collate([d[key] for d in batch]) for key in elem})
        except TypeError:
            # The mapping type may not support `__init__(iterable)`.
            return {key: custom_list_collate([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, tuple) and hasattr(elem, "_fields"):  # namedtuple
        return elem_type(*(custom_list_collate(samples) for samples in zip(*batch)))
    elif isinstance(elem, collections.abc.Sequence):
        # check to make sure that the elements in batch have consistent size
        it = iter(batch)
        elem_size = len(next(it))
        if not all(len(elem) == elem_size for elem in it):
            raise RuntimeError("each element in list of batch should be of equal size")
        # transposed = list(zip(*batch))  # It may be accessed twice, so we use a list.

        return batch

        # if isinstance(elem, tuple):
        #     return [
        #         custom_list_collate(samples) for samples in transposed
        #     ]  # Backwards compatibility.
        # else:
        #     try:
        #         return elem_type([custom_list_collate(samples) for samples in transposed])
        #     except TypeError:
        #         # The sequence type may not support `__init__(iterable)` (e.g., `range`).
        #         return [custom_list_collate(samples) for samples in transposed]

    raise TypeError(default_collate_err_msg_format.format(elem_type))
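The key deviation from torch's default_collate, documented in the docstring above, is the Sequence branch: lists of equal length are returned as-is instead of being transposed and collated per position. A minimal usage sketch with a fabricated batch (illustration only):

import torch

batch = [
    {"pts": torch.zeros(2, 3), "ids": [0, 1, 2], "name": "a"},
    {"pts": torch.ones(2, 3), "ids": [3, 4, 5], "name": "b"},
]
out = custom_list_collate(batch)
print(out["pts"].shape)  # torch.Size([2, 2, 3]) -- tensors gain a batch dimension
print(out["ids"])        # [[0, 1, 2], [3, 4, 5]] -- lists are kept unchanged
print(out["name"])       # ['a', 'b']             -- strings pass through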
tvcalib/infer/module.py
ADDED
@@ -0,0 +1,518 @@
import torch
import torch.nn as nn
from functools import partial

from pathlib import Path
from typing import Any, Dict, Tuple
# Imports from the `common` package (assumed to sit at the same level as `tvcalib`)
from common.infer.base import *
# from common.registry import Registry  # still commented out: source unknown
# from common.utils import to_cuda  # still commented out: source unknown
# import project as p  # removed: probably tied to the full project

import torchvision.transforms as T
# Relative imports within tvcalib (kept relative)
from ..sn_segmentation.src.custom_extremities import (
    generate_class_synthesis, get_line_extremities
)
from ..models.segmentation import InferenceSegmentationModel
from ..data.dataset import InferenceDatasetCalibration
from ..data.utils import custom_list_collate
from ..cam_modules import CameraParameterWLensDistDictZScore, SNProjectiveCamera
from ..utils.linalg import distance_line_pointcloud_3d, distance_point_pointcloud
from ..utils.objects_3d import SoccerPitchLineCircleSegments, SoccerPitchSNCircleCentralSplit
from ..cam_distr.tv_main_center import get_cam_distr, get_dist_distr
from ..utils.io import detach_dict, tensor2list
# Import from the `common` package
from common.data.utils import yards

from kornia.geometry.conversions import convert_points_to_homogeneous
from tqdm.auto import tqdm

# Commented out: tied to the 'robust' method and may introduce extra dependencies
# from methods.robust.loggers.preview import RobustPreviewLogger

import numpy as np


class TvCalibInferModule(InferModule):
    def __init__(
        self,
        segmentation_checkpoint: Path,
        image_shape=(720, 1280),
        optim_steps=2000,
        lens_dist: bool = False,
        playfield_size=(105, 68),
        make_images: bool = False
    ):
        self.image_shape = image_shape
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.make_images = make_images

        # We use the logger to draw visualizations.
        # Commented out because the RobustPreviewLogger import is commented out above.
        # self.previewer = RobustPreviewLogger(
        #     None, num_images=1
        # )

        self.fn_generate_class_synthesis = partial(
            generate_class_synthesis,
            radius=4
        )
        self.fn_get_line_extremities = partial(
            get_line_extremities,
            maxdist=30,
            width=455,
            height=256,
            num_points_lines=4,
            num_points_circles=8
        )

        # Segmentation model
        self.model_seg = InferenceSegmentationModel(
            segmentation_checkpoint,
            self.device
        )

        self.object3d = SoccerPitchLineCircleSegments(
            device=self.device,
            base_field=SoccerPitchSNCircleCentralSplit()
        )
        self.object3dcpu = SoccerPitchLineCircleSegments(
            device="cpu",
            base_field=SoccerPitchSNCircleCentralSplit()
        )

        # Calibration module
        batch_size_calib = 1
        self.model_calib = TVCalibModule(
            self.object3d,
            get_cam_distr(1.96, batch_size_calib, 1),
            get_dist_distr(batch_size_calib, 1) if lens_dist else None,
            (image_shape[0], image_shape[1]),
            optim_steps,
            self.device,
            log_per_step=False,
            tqdm_kwqargs=None,
        )
        self.resize = T.Compose([
            T.Resize(size=(256, 455))
        ])
        self.offset = np.array([
            [1, 0, playfield_size[0] / 2.0],
            [0, 1, playfield_size[1] / 2.0],
            [0, 0, 1]
        ])

    def setup(self, datamodule: InferDataModule):
        pass

    def predict(self, x: Any) -> Dict:
        """
        1. Run segmentation & pick keypoints
        2. Calibrate based on the selected points
        """

        # Segment
        image = x["image"]
        keypoints = self._segment(x["image"])

        # Calibrate
        homo = self._calibrate(keypoints)

        # Rescale to 720p
        # Commented out because the previewer is disabled; calling it would raise AttributeError.
        # image_720p = self.previewer.to_image(image.clone().detach().cpu())

        # Draw the predicted playing field
        if homo is not None:
            # to yards
            # Commented out because the previewer is commented out
            # to_yards = np.array([
            #     [yards(1.0), 0, 0],
            #     [0, yards(1.0), 0],
            #     [0, 0, 1]
            # ])
            # homo = to_yards @ homo

            # Commented out because the previewer is commented out
            # try:
            #     inv_homo = np.linalg.inv(homo) @ self.previewer.scale
            #     image_720p = self.previewer.draw_playfield(
            #         image_720p,
            #         self.previewer.image_playfield,
            #         inv_homo,
            #         color=(255, 0, 0), alpha=1.0,
            #         flip=False
            #     )
            # except:
            #     # The homography might not be invertible
            #     pass
            pass  # placeholder: the homography exists but the previewer is disabled

        result = {
            "homography": homo
        }

        if self.make_images:
            # result["image_720p"] = image_720p  # commented out: image_720p is not produced without the previewer
            pass  # placeholder when make_images is True

        return result

    def _segment(self, image):

        # Image -> <1;3;256;455>
        image = self.resize(image)
        with torch.no_grad():
            sem_lines = self.model_seg.inference(
                image.unsqueeze(0).to(self.device)
            )
        # <B;256;455>
        sem_lines = sem_lines.detach().cpu().numpy().astype(np.uint8)

        # Point selection
        skeletons_batch = self.fn_generate_class_synthesis(sem_lines[0])
        keypoints_raw_batch = self.fn_get_line_extremities(skeletons_batch)

        # Return the keypoints
        return keypoints_raw_batch

    def _calibrate(self, keypoints):

        # Just wrap around the keypoints
        ds = InferenceDatasetCalibration(
            [keypoints],
            self.image_shape[1], self.image_shape[0],
            self.object3d
        )

        # Get the first item and optimize it
        _batch_size = 1
        x_dict = custom_list_collate([ds[0]])
        try:
            # previous_params handling is done inside self_optim_batch
            per_sample_loss, cam, _ = self.model_calib.self_optim_batch(x_dict)
            output_dict = tensor2list(
                detach_dict({**cam.get_parameters(_batch_size), **per_sample_loss})
            )

            homo = output_dict["homography"][0]
            if len(homo) > 0:
                homo = np.array(homo[0])

                to_yards = np.array([
                    [yards(1), 0, 0],
                    [0, yards(1), 0],
                    [0, 0, 1]
                ])

                # Shift the homography by half the playing field
                homo = to_yards @ self.offset @ homo

            else:
                homo = None
        except Exception as e:
            print(f"Calibration error: {str(e)}")
            homo = None

        return homo


class TVCalibModule(torch.nn.Module):
    def __init__(
        self,
        model3d,
        cam_distr,
        dist_distr,
        image_dim: Tuple[int, int],
        optim_steps: int,
        device="cpu",
        tqdm_kwqargs=None,
        log_per_step=False,
        *args,
        **kwargs,
    ) -> None:
        super().__init__(*args, **kwargs)
        self.image_height, self.image_width = image_dim
        self.principal_point = (self.image_width / 2, self.image_height / 2)
        self.model3d = model3d
        self.cam_param_dict = CameraParameterWLensDistDictZScore(
            cam_distr, dist_distr, device=device
        )

        self.lens_distortion_active = False if dist_distr is None else True
        self.optim_steps = optim_steps
        self._device = device

        # Attribute storing the optimized parameters from the previous frame
        self.previous_params = None

        self.optim = torch.optim.AdamW(
            self.cam_param_dict.param_dict.parameters(), lr=0.1, weight_decay=0.01
        )
        self.Scheduler = partial(
            torch.optim.lr_scheduler.OneCycleLR,
            max_lr=0.05,
            total_steps=self.optim_steps,
            pct_start=0.5,
        )

        if self.lens_distortion_active:
            self.optim_lens_distortion = torch.optim.AdamW(
                self.cam_param_dict.param_dict_dist.parameters(), lr=1e-3, weight_decay=0.01
            )
            self.Scheduler_lens_distortion = partial(
                torch.optim.lr_scheduler.OneCycleLR,
                max_lr=1e-3,
                total_steps=self.optim_steps,
                pct_start=0.33,
                optimizer=self.optim_lens_distortion,
            )

        self.tqdm_kwqargs = tqdm_kwqargs
        if tqdm_kwqargs is None:
            self.tqdm_kwqargs = {}

        self.hparams = {"optim": str(self.optim), "scheduler": str(self.Scheduler)}
        self.log_per_step = log_per_step

    def forward(self, x):

        # individual camera parameters & distortion parameters
        phi_hat, psi_hat = self.cam_param_dict()

        cam = SNProjectiveCamera(
            phi_hat,
            psi_hat,
            self.principal_point,
            self.image_width,
            self.image_height,
            device=self._device,
            nan_check=False,
        )

        # (batch_size, num_views_per_cam, 3, num_segments, num_points)
        points_px_lines_true = x["lines__ndc_projected_selection_shuffled"].to(self._device)
        batch_size, T_l, _, S_l, N_l = points_px_lines_true.shape

        # project circle points
        points_px_circles_true = x["circles__ndc_projected_selection_shuffled"].to(self._device)
        _, T_c, _, S_c, N_c = points_px_circles_true.shape
        assert T_c == T_l

        #################### line-to-point distance in pixel space ####################
        # start and end point (in world coordinates) for each line segment
        points3d_lines_keypoints = self.model3d.line_segments  # (3, S_l, 2) to (S_l * 2, 3)
        points3d_lines_keypoints = points3d_lines_keypoints.reshape(3, S_l * 2).transpose(0, 1)
        points_px_lines_keypoints = convert_points_to_homogeneous(
            cam.project_point2ndc(points3d_lines_keypoints, lens_distortion=False)
        )  # (batch_size, T_l, S_l*2, 3)

        if batch_size < cam.batch_dim:  # actual batch_size smaller than expected, i.e. last batch
            points_px_lines_keypoints = points_px_lines_keypoints[:batch_size]

        points_px_lines_keypoints = points_px_lines_keypoints.view(batch_size, T_l, S_l, 2, 3)

        lp1 = points_px_lines_keypoints[..., 0, :].unsqueeze(-2)  # -> (batch_size, T_l, 1, S_l, 3)
        lp2 = points_px_lines_keypoints[..., 1, :].unsqueeze(-2)  # -> (batch_size, T_l, 1, S_l, 3)
        # (batch_size, T, 3, S, N) -> (batch_size, T, 3, S*N) -> (batch_size, T, S*N, 3) -> (batch_size, T, S, N, 3)
        pc = (
            points_px_lines_true.view(batch_size, T_l, 3, S_l * N_l)
            .transpose(2, 3)
            .view(batch_size, T_l, S_l, N_l, 3)
        )

        if self.lens_distortion_active:
            # undistort given points
            pc = pc.view(batch_size, T_l, S_l * N_l, 3)
            pc = pc.detach().clone()
            pc[..., :2] = cam.undistort_points(
                pc[..., :2], cam.intrinsics_ndc, num_iters=1
            )  # num_iters=1 might be enough for a good approximation
            pc = pc.view(batch_size, T_l, S_l, N_l, 3)

        distances_px_lines_raw = distance_line_pointcloud_3d(
            e1=lp2 - lp1, r1=lp1, pc=pc, reduce=None
        )  # (batch_size, T_l, S_l, N_l)
        distances_px_lines_raw = distances_px_lines_raw.unsqueeze(-3)
        # (..., 1, S_l, N_l,), i.e. (batch_size, T, 1, S_l, N_l)
        #################### circle-to-point distance in pixel space ####################

        # circle segments are approximated as point clouds of size N_c_star
        points3d_circles_pc = self.model3d.circle_segments
        _, S_c, N_c_star = points3d_circles_pc.shape
        points3d_circles_pc = points3d_circles_pc.reshape(3, S_c * N_c_star).transpose(0, 1)
        points_px_circles_pc = cam.project_point2ndc(points3d_circles_pc, lens_distortion=False)

        if batch_size < cam.batch_dim:  # actual batch_size smaller than expected, i.e. last batch
            points_px_circles_pc = points_px_circles_pc[:batch_size]

        if self.lens_distortion_active:
            # (batch_size, T_c, _, S_c, N_c)
            points_px_circles_true = points_px_circles_true.view(
                batch_size, T_c, 3, S_c * N_c
            ).transpose(2, 3)
            points_px_circles_true = points_px_circles_true.detach().clone()
            points_px_circles_true[..., :2] = cam.undistort_points(
                points_px_circles_true[..., :2], cam.intrinsics_ndc, num_iters=1
            )
            points_px_circles_true = points_px_circles_true.transpose(2, 3).view(
                batch_size, T_c, 3, S_c, N_c
            )

        distances_px_circles_raw = distance_point_pointcloud(
            points_px_circles_true, points_px_circles_pc.view(batch_size, T_c, S_c, N_c_star, 2)
        )

        distances_dict = {
            "loss_ndc_lines": distances_px_lines_raw,  # (batch_size, T_l, 1, S_l, N_l)
            "loss_ndc_circles": distances_px_circles_raw,  # (batch_size, T_c, 1, S_c, N_c)
        }
        return distances_dict, cam

    def self_optim_batch(self, x, *args, **kwargs):

        scheduler = self.Scheduler(self.optim)  # re-initialize lr scheduler for every batch
        if self.lens_distortion_active:
            scheduler_lens_distortion = self.Scheduler_lens_distortion()

        # Initialize from the previous frame's parameters when available
        if self.previous_params is not None:
            print("Initializing from the previous frame's parameters")
            update_dict = {}
            for k, v in self.previous_params.items():
                update_dict[k] = v.detach().clone()
            self.cam_param_dict.initialize(update_dict)
        else:
            print("First frame: initializing from zero")
            self.cam_param_dict.initialize(None)

        self.optim.zero_grad()
        if self.lens_distortion_active:
            self.optim_lens_distortion.zero_grad()

        keypoint_masks = {
            "loss_ndc_lines": x["lines__is_keypoint_mask"].to(self._device),
            "loss_ndc_circles": x["circles__is_keypoint_mask"].to(self._device),
        }
        num_actual_points = {
            "loss_ndc_circles": keypoint_masks["loss_ndc_circles"].sum(dim=(-1, -2)),
            "loss_ndc_lines": keypoint_masks["loss_ndc_lines"].sum(dim=(-1, -2)),
        }

        per_sample_loss = {}
        per_sample_loss["mask_lines"] = keypoint_masks["loss_ndc_lines"]
        per_sample_loss["mask_circles"] = keypoint_masks["loss_ndc_circles"]

        per_step_info = {"loss": [], "lr": []}

        # Early-stopping parameters
        loss_target = 0.001  # lowered for potentially better accuracy
        loss_patience = 10  # number of iterations used to detect stagnation
        loss_tolerance = 1e-4  # tolerance on the relative loss variation
        loss_history = []  # history of loss values
        best_loss = float('inf')  # best loss obtained so far
        steps_without_improvement = 0  # counter of iterations without improvement

        # with torch.autograd.detect_anomaly():
        with tqdm(range(self.optim_steps), **self.tqdm_kwqargs) as pbar:
            for step in pbar:
                self.optim.zero_grad()
                if self.lens_distortion_active:
                    self.optim_lens_distortion.zero_grad()

                # forward pass
                distances_dict, cam = self(x)

                # distance calculated with masked input and output
                losses = {}
                for key_dist, distances in distances_dict.items():
                    distances[~keypoint_masks[key_dist]] = 0.0
                    per_sample_loss[f"{key_dist}_distances_raw"] = distances
                    distances_reduced = distances.sum(dim=(-1, -2))
                    distances_reduced = distances_reduced / num_actual_points[key_dist]
                    distances_reduced[num_actual_points[key_dist] == 0] = 0.0
                    distances_reduced = distances_reduced.squeeze(-1)
                    per_sample_loss[key_dist] = distances_reduced
                    loss = distances_reduced.mean(dim=-1)
                    loss = loss.sum()
                    losses[key_dist] = loss

                loss_total_dist = losses["loss_ndc_lines"] + losses["loss_ndc_circles"]
                loss_total = loss_total_dist
                current_loss = loss_total.item()

                # Update the loss history
                loss_history.append(current_loss)

                # Check whether we have a new best loss
                if current_loss < best_loss:
                    best_loss = current_loss
                    steps_without_improvement = 0
                else:
                    steps_without_improvement += 1

                # Early-stopping criteria (commented out to force the full number of steps)
                # if len(loss_history) >= loss_patience:
                #     # Compute the mean relative variation over the last iterations
                #     recent_losses = loss_history[-loss_patience:]
                #     # Handle the case where all recent losses are zero or close to zero
                #     max_recent_loss = max(max(recent_losses), 1e-9)  # avoids division by zero
                #     loss_variation = abs(max(recent_losses) - min(recent_losses)) / max_recent_loss
                #
                #     # Stopping conditions
                #     if (current_loss <= loss_target or  # target value reached
                #             loss_variation < loss_tolerance or  # loss no longer varies significantly
                #             steps_without_improvement >= loss_patience):  # no improvement for a while
                #         print(f"\nEarly stop at iteration {step+1}:")
                #         print(f"Final loss: {current_loss:.5f}")
                #         print(f"Best loss: {best_loss:.5f}")
                #         print(f"Relative variation: {loss_variation:.6f}")
                #         break

                if self.log_per_step:
                    per_step_info["lr"].append(scheduler.get_last_lr())
                    per_step_info["loss"].append(distances_reduced)
                if step % 50 == 0:
                    pbar.set_postfix(
                        loss=f"{loss_total_dist.detach().cpu().tolist():.5f}",
                        loss_lines=f'{losses["loss_ndc_lines"].detach().cpu().tolist():.3f}',
                        loss_circles=f'{losses["loss_ndc_circles"].detach().cpu().tolist():.3f}',
                    )

                loss_total.backward()
                self.optim.step()
                scheduler.step()
                if self.lens_distortion_active:
                    self.optim_lens_distortion.step()
                    scheduler_lens_distortion.step()

        # Save the optimized parameters for the next frame
        self.previous_params = {}
        for k, v in self.cam_param_dict.param_dict.items():
            self.previous_params[k] = v.detach().clone()

        per_sample_loss["loss_ndc_total"] = torch.sum(
            torch.stack([per_sample_loss[key_dist] for key_dist in distances_dict.keys()], dim=0),
            dim=0,
        )

        if self.log_per_step:
            per_step_info["loss"] = torch.stack(
                per_step_info["loss"], dim=-1
            )
            per_step_info["lr"] = torch.tensor(per_step_info["lr"])
        return per_sample_loss, cam, per_step_info
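A minimal usage sketch for the module above. `module` and `frame` are hypothetical, and the image-pixels-to-pitch direction of the returned homography is inferred from the commented-out previewer code (which inverts it to draw the pitch onto the image), so treat the convention as an assumption:

import numpy as np

def project_to_pitch(H: np.ndarray, x_px: float, y_px: float) -> np.ndarray:
    """Map a homogeneous image pixel through the homography returned by _calibrate."""
    p = H @ np.array([x_px, y_px, 1.0])
    return p[:2] / p[2]  # dehomogenize -> (x, y) pitch coordinates in yards

# e.g., with result = module.predict({"image": frame}) and H = result["homography"]:
# if H is not None: print(project_to_pitch(H, 640.0, 700.0))  # a point near the image bottom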
tvcalib/models/segmentation.py
ADDED
@@ -0,0 +1,22 @@
from typing import Union
from pathlib import Path
import torch

from torchvision.models.segmentation import deeplabv3_resnet101
from SoccerNet.Evaluation.utils_calibration import SoccerPitch


class InferenceSegmentationModel:
    def __init__(self, checkpoint: Union[str, Path], device) -> None:
        self.device = device
        self.model = deeplabv3_resnet101(
            num_classes=len(SoccerPitch.lines_classes) + 1, aux_loss=True
        )
        checkpoint_data = torch.load(checkpoint, map_location=self.device, weights_only=False)
        self.model.load_state_dict(checkpoint_data["model"], strict=False)
        self.model.to(self.device)
        self.model.eval()

    def inference(self, img_batch):
        return self.model(img_batch)["out"].argmax(1)
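A minimal usage sketch for the class above, assuming a hypothetical checkpoint file "lines_seg.pt" trained with the same class layout (random input used only to show shapes):

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
seg = InferenceSegmentationModel("lines_seg.pt", device)  # hypothetical checkpoint path
img_batch = torch.rand(1, 3, 256, 455, device=device)     # one normalized RGB image
with torch.no_grad():
    classes = seg.inference(img_batch)                     # (1, 256, 455) class indices, 0 = background
print(classes.shape)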
tvcalib/sn_segmentation/resources/mean.npy
ADDED
Binary file (152 Bytes).
tvcalib/sn_segmentation/resources/std.npy
ADDED
Binary file (152 Bytes).
tvcalib/sn_segmentation/src/baseline_extremities.py
ADDED
@@ -0,0 +1,311 @@
import argparse
import copy
import json
import os.path
import random
from collections import deque
from pathlib import Path

import cv2 as cv
import numpy as np
import torch
import torch.backends.cudnn
import torch.nn as nn
from PIL import Image
from torchvision.models.segmentation import deeplabv3_resnet50
from tqdm import tqdm

from SoccerNet.Evaluation.utils_calibration import SoccerPitch


def generate_class_synthesis(semantic_mask, radius):
    """
    This function selects, for each class present in the semantic mask, a set of circles that cover most of the
    semantic class blobs.
    :param semantic_mask: an image containing the segmentation predictions
    :param radius: circle radius
    :return: a dictionary which associates with each detected class a list of points (the circle centers)
    """
    buckets = dict()
    kernel = np.ones((5, 5), np.uint8)
    semantic_mask = cv.erode(semantic_mask, kernel, iterations=1)
    for k, class_name in enumerate(SoccerPitch.lines_classes):
        mask = semantic_mask == k + 1
        if mask.sum() > 0:
            disk_list = synthesize_mask(mask, radius)
            if len(disk_list):
                buckets[class_name] = disk_list

    return buckets


def join_points(point_list, maxdist):
    """
    Given a list of points that were extracted from the blobs belonging to the same semantic class, this function
    creates polylines by linking close points together if their distance is below the maxdist threshold.
    :param point_list: list of points of the same line class
    :param maxdist: minimal distance between two polylines.
    :return: a list of polylines
    """
    polylines = []

    if not len(point_list):
        return polylines
    head = point_list[0]
    tail = point_list[0]
    polyline = deque()
    polyline.append(point_list[0])
    remaining_points = copy.deepcopy(point_list[1:])

    while len(remaining_points) > 0:
        min_dist_tail = 1000
        min_dist_head = 1000
        best_head = -1
        best_tail = -1
        for j, point in enumerate(remaining_points):
            dist_tail = np.sqrt(np.sum(np.square(point - tail)))
            dist_head = np.sqrt(np.sum(np.square(point - head)))
            if dist_tail < min_dist_tail:
                min_dist_tail = dist_tail
                best_tail = j
            if dist_head < min_dist_head:
                min_dist_head = dist_head
                best_head = j

        if min_dist_head <= min_dist_tail and min_dist_head < maxdist:
            polyline.appendleft(remaining_points[best_head])
            head = polyline[0]
            remaining_points.pop(best_head)
        elif min_dist_tail < min_dist_head and min_dist_tail < maxdist:
            polyline.append(remaining_points[best_tail])
            tail = polyline[-1]
            remaining_points.pop(best_tail)
        else:
            polylines.append(list(polyline.copy()))
            head = remaining_points[0]
            tail = remaining_points[0]
            polyline = deque()
            polyline.append(head)
            remaining_points.pop(0)
    polylines.append(list(polyline))
    return polylines


def get_line_extremities(buckets, maxdist, width, height):
    """
    Given the dictionary {line_class: points}, finds plausible extremities of each line, i.e. the extremities
    of the longest polyline that can be built on the class blobs, and normalizes their coordinates
    by the image size.
    :param buckets: the dictionary associating line classes with the set of circle centers that best covers the class
    prediction blobs in the segmentation mask
    :param maxdist: the maximal distance between two circle centers belonging to the same blob (heuristic)
    :param width: image width
    :param height: image height
    :return: a dictionary associating each class with its extremities
    """
    extremities = dict()
    for class_name, disks_list in buckets.items():
        polyline_list = join_points(disks_list, maxdist)
        max_len = 0
        longest_polyline = []
        for polyline in polyline_list:
            if len(polyline) > max_len:
                max_len = len(polyline)
                longest_polyline = polyline
        extremities[class_name] = [
            {'x': longest_polyline[0][1] / width, 'y': longest_polyline[0][0] / height},
            {'x': longest_polyline[-1][1] / width, 'y': longest_polyline[-1][0] / height}
        ]
    return extremities


def get_support_center(mask, start, disk_radius, min_support=0.1):
    """
    Returns the barycenter of the True pixels under the area of the mask delimited by the circle of center start and
    radius disk_radius pixels.
    :param mask: boolean mask
    :param start: a point located on a True pixel of the mask
    :param disk_radius: the radius of the circles
    :param min_support: proportion of the circle area that should be True in order to get enough support
    :return: a boolean indicating whether there is enough support in the circle area, and the barycenter of the
    True pixels under the circle
    """
    x = int(start[0])
    y = int(start[1])
    support_pixels = 1
    result = [x, y]
    xstart = x - disk_radius
    if xstart < 0:
        xstart = 0
    xend = x + disk_radius
    if xend > mask.shape[0]:
        xend = mask.shape[0] - 1

    ystart = y - disk_radius
    if ystart < 0:
        ystart = 0
    yend = y + disk_radius
    if yend > mask.shape[1]:
        yend = mask.shape[1] - 1

    for i in range(xstart, xend + 1):
        for j in range(ystart, yend + 1):
            dist = np.sqrt(np.square(x - i) + np.square(y - j))
            if dist < disk_radius and mask[i, j] > 0:
                support_pixels += 1
                result[0] += i
                result[1] += j
    support = True
    if support_pixels < min_support * np.square(disk_radius) * np.pi:
        support = False

    result = np.array(result)
    result = np.true_divide(result, support_pixels)

    return support, result


def synthesize_mask(semantic_mask, disk_radius):
    """
    Fits circles on the True pixels of the mask and returns those which have enough support, meaning that the
    proportion of the circle area covering True pixels is higher than a certain threshold, in order to avoid
    fitting circles on isolated pixels.
    :param semantic_mask: boolean mask
    :param disk_radius: radius of the circles
    :return: a list of disk centers that have enough support
    """
    mask = semantic_mask.copy().astype(np.uint8)
    points = np.transpose(np.nonzero(mask))
    disks = []
    while len(points):

        start = random.choice(points)
        dist = 10.
        success = True
        while dist > 1.:
            enough_support, center = get_support_center(mask, start, disk_radius)
            if not enough_support:
                bad_point = np.round(center).astype(np.int32)
                cv.circle(mask, (bad_point[1], bad_point[0]), disk_radius, (0), -1)
                success = False
            dist = np.sqrt(np.sum(np.square(center - start)))
            start = center
        if success:
            disks.append(np.round(start).astype(np.int32))
            cv.circle(mask, (disks[-1][1], disks[-1][0]), disk_radius, 0, -1)
        points = np.transpose(np.nonzero(mask))

    return disks


class SegmentationNetwork:
    def __init__(self, model_file, mean_file, std_file, num_classes=29, width=640, height=360):
        file_path = Path(model_file).resolve()
        model = nn.DataParallel(deeplabv3_resnet50(pretrained=False, num_classes=num_classes))
        self.init_weight(model, nn.init.kaiming_normal_,
                         nn.BatchNorm2d, 1e-3, 0.1,
                         mode='fan_in')
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        checkpoint = torch.load(str(file_path), map_location=self.device)
        model.load_state_dict(checkpoint["model"])
        model.eval()
        self.model = model.to(self.device)
        file_path = Path(mean_file).resolve()
        self.mean = np.load(str(file_path))
        file_path = Path(std_file).resolve()
        self.std = np.load(str(file_path))
        self.width = width
        self.height = height

    def init_weight(self, feature, conv_init, norm_layer, bn_eps, bn_momentum,
                    **kwargs):
        for name, m in feature.named_modules():
            if isinstance(m, (nn.Conv2d, nn.Conv3d)):
                conv_init(m.weight, **kwargs)
            elif isinstance(m, norm_layer):
                m.eps = bn_eps
                m.momentum = bn_momentum
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def analyse_image(self, image):
        """
        Processes the image and performs inference, returns the mask of detected classes.
        :param image: BGR image
        :return: predicted class mask
        """
        img = cv.resize(image, (self.width, self.height), interpolation=cv.INTER_LINEAR)
        img = np.asarray(img, np.float32) / 255.
        img = (img - self.mean) / self.std
        img = img.transpose((2, 0, 1))
        img = torch.from_numpy(img).to(self.device).unsqueeze(0)

        cuda_result = self.model.forward(img.float())
        output = cuda_result['out'].data[0].cpu().numpy()
        output = output.transpose(1, 2, 0)
        output = np.asarray(np.argmax(output, axis=2), dtype=np.uint8)

        return output


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Test')

    parser.add_argument('-s', '--soccernet', default="./annotations/", type=str,
                        help='Path to the SoccerNet-V3 dataset folder')
    parser.add_argument('-p', '--prediction', default="./results_bis", required=False, type=str,
                        help="Path to the prediction folder")
    parser.add_argument('--split', required=False, type=str, default="test", help='Select the split of data')
    parser.add_argument('--masks', required=False, type=bool, default=False, help='Save masks in prediction directory')
    parser.add_argument('--resolution_width', required=False, type=int, default=455,
                        help='width resolution of the images')
    parser.add_argument('--resolution_height', required=False, type=int, default=256,
                        help='height resolution of the images')
    parser.add_argument('--checkpoint_dir', default="resources")
    args = parser.parse_args()

    lines_palette = [0, 0, 0]
    for line_class in SoccerPitch.lines_classes:
        lines_palette.extend(SoccerPitch.palette[line_class])

    calib_net = SegmentationNetwork(
        os.path.join(args.checkpoint_dir, "soccer_pitch_segmentation.pth"),
        os.path.join(args.checkpoint_dir, "mean.npy"),
        os.path.join(args.checkpoint_dir, "std.npy")
    )

    dataset_dir = os.path.join(args.soccernet, args.split)
    if not os.path.exists(dataset_dir):
        print("Invalid dataset path!")
        exit(-1)

    frames = [f for f in os.listdir(dataset_dir) if ".jpg" in f]
    with tqdm(enumerate(frames), total=len(frames), ncols=160) as t:
        for i, frame in t:

            output_prediction_folder = os.path.join(args.prediction, args.split)
            if not os.path.exists(output_prediction_folder):
                os.makedirs(output_prediction_folder)
            prediction = dict()
            count = 0

            frame_path = os.path.join(dataset_dir, frame)

            frame_index = frame.split(".")[0]

            image = cv.imread(frame_path)
            semlines = calib_net.analyse_image(image)
            if args.masks:
                mask = Image.fromarray(semlines.astype(np.uint8)).convert('P')
                mask.putpalette(lines_palette)
                mask_file = os.path.join(output_prediction_folder, frame)
                mask.convert("RGB").save(mask_file)
            skeletons = generate_class_synthesis(semlines, 6)
            extremities = get_line_extremities(skeletons, 40, args.resolution_width, args.resolution_height)

            prediction = extremities
            count += 1

            prediction_file = os.path.join(output_prediction_folder, f"extremities_{frame_index}.json")
            with open(prediction_file, "w") as f:
                json.dump(prediction, f, indent=4)
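A minimal end-to-end sketch of the post-processing chain above on a fabricated mask (requires the SoccerNet package; the blob thickness and approximate output values are assumptions chosen so the blob survives the 5x5 erosion in generate_class_synthesis):

import numpy as np

semantic_mask = np.zeros((360, 640), dtype=np.uint8)
semantic_mask[97:104, 50:300] = 1  # a thick horizontal blob for class index 1

buckets = generate_class_synthesis(semantic_mask, radius=6)
extremities = get_line_extremities(buckets, maxdist=40, width=640, height=360)
# -> {<first entry of SoccerPitch.lines_classes>: [{'x': ~0.08, 'y': ~0.28}, {'x': ~0.46, 'y': ~0.28}]}
print(extremities)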
tvcalib/sn_segmentation/src/custom_extremities.py
ADDED
@@ -0,0 +1,322 @@
import argparse
import copy
import itertools
import json
import os.path
import random
from collections import deque
from pathlib import Path

from pytorch_lightning import seed_everything

seed_everything(seed=10, workers=True)

import cv2 as cv
import numpy as np
import torch
import torch.backends.cudnn
import torch.nn as nn
import torchvision.transforms as T

from PIL import Image
from torchvision.models.segmentation import deeplabv3_resnet101
from tqdm import tqdm

from SoccerNet.Evaluation.utils_calibration import SoccerPitch


def generate_class_synthesis(semantic_mask, radius):
    """
    This function selects, for each class present in the semantic mask, a set of circles that cover most of the
    semantic class blobs.
    :param semantic_mask: an image containing the segmentation predictions
    :param radius: circle radius
    :return: a dictionary which associates with each detected class a list of points (the circle centers)
    """
    buckets = dict()
    kernel = np.ones((5, 5), np.uint8)
    semantic_mask = cv.erode(semantic_mask, kernel, iterations=1)
    for k, class_name in enumerate(SoccerPitch.lines_classes):
        mask = semantic_mask == k + 1
        if mask.sum() > 0:
            disk_list = synthesize_mask(mask, radius)
            if len(disk_list):
                buckets[class_name] = disk_list

    return buckets


def join_points(point_list, maxdist):
    """
    Given a list of points that were extracted from the blobs belonging to the same semantic class, this function
    creates polylines by linking close points together if their distance is below the maxdist threshold.
    :param point_list: list of points of the same line class
    :param maxdist: minimal distance between two polylines.
    :return: a list of polylines
    """
    polylines = []

    if not len(point_list):
        return polylines
    head = point_list[0]
    tail = point_list[0]
    polyline = deque()
    polyline.append(point_list[0])
    remaining_points = copy.deepcopy(point_list[1:])

    while len(remaining_points) > 0:
        min_dist_tail = 1000
        min_dist_head = 1000
        best_head = -1
        best_tail = -1
        for j, point in enumerate(remaining_points):
            dist_tail = np.sqrt(np.sum(np.square(point - tail)))
            dist_head = np.sqrt(np.sum(np.square(point - head)))
            if dist_tail < min_dist_tail:
                min_dist_tail = dist_tail
                best_tail = j
            if dist_head < min_dist_head:
                min_dist_head = dist_head
                best_head = j

        if min_dist_head <= min_dist_tail and min_dist_head < maxdist:
            polyline.appendleft(remaining_points[best_head])
            head = polyline[0]
            remaining_points.pop(best_head)
        elif min_dist_tail < min_dist_head and min_dist_tail < maxdist:
            polyline.append(remaining_points[best_tail])
            tail = polyline[-1]
            remaining_points.pop(best_tail)
        else:
            polylines.append(list(polyline.copy()))
            head = remaining_points[0]
            tail = remaining_points[0]
            polyline = deque()
            polyline.append(head)
            remaining_points.pop(0)
    polylines.append(list(polyline))
    return polylines


def get_line_extremities(buckets, maxdist, width, height, num_points_lines, num_points_circles):
    """
    Given the dictionary {line_class: points}, finds plausible extremities of each line, i.e. the extremities
    of the longest polyline that can be built on the class blobs, and normalizes their coordinates
    by the image size.
    :param buckets: the dictionary associating line classes with the set of circle centers that best covers the class
    prediction blobs in the segmentation mask
    :param maxdist: the maximal distance between two circle centers belonging to the same blob (heuristic)
    :param width: image width
    :param height: image height
    :param num_points_lines: number of keypoints used to represent a line segment
    :param num_points_circles: number of keypoints used to represent a circle segment
    :return: a dictionary associating each class with its extremities
    """
    extremities = dict()
    for class_name, disks_list in buckets.items():
        polyline_list = join_points(disks_list, maxdist)
        max_len = 0
        longest_polyline = []
        for polyline in polyline_list:
            if len(polyline) > max_len:
                max_len = len(polyline)
                longest_polyline = polyline
        extremities[class_name] = [
            {'x': longest_polyline[0][1] / width, 'y': longest_polyline[0][0] / height},
            {'x': longest_polyline[-1][1] / width, 'y': longest_polyline[-1][0] / height},
        ]
        num_points = num_points_lines
        if "Circle" in class_name:
            num_points = num_points_circles
        if num_points > 2:
            # equally spaced points along the longest polyline
            # skip first and last as they already exist
            for i in range(1, num_points - 1):
                extremities[class_name].insert(
                    len(extremities[class_name]) - 1,
                    {'x': longest_polyline[i * int(len(longest_polyline) / num_points)][1] / width,
                     'y': longest_polyline[i * int(len(longest_polyline) / num_points)][0] / height}
                )

    return extremities


def get_support_center(mask, start, disk_radius, min_support=0.1):
    """
    Returns the barycenter of the True pixels under the area of the mask delimited by the circle of center start and
    radius disk_radius pixels.
    :param mask: boolean mask
    :param start: a point located on a True pixel of the mask
    :param disk_radius: the radius of the circles
    :param min_support: proportion of the circle area that should be True in order to get enough support
    :return: a boolean indicating whether there is enough support in the circle area, and the barycenter of the
    True pixels under the circle
    """
    x = int(start[0])
    y = int(start[1])
    support_pixels = 1
    result = [x, y]
    xstart = x - disk_radius
    if xstart < 0:
        xstart = 0
    xend = x + disk_radius
    if xend > mask.shape[0]:
        xend = mask.shape[0] - 1

    ystart = y - disk_radius
    if ystart < 0:
        ystart = 0
    yend = y + disk_radius
    if yend > mask.shape[1]:
        yend = mask.shape[1] - 1

    for i in range(xstart, xend + 1):
        for j in range(ystart, yend + 1):
            dist = np.sqrt(np.square(x - i) + np.square(y - j))
            if dist < disk_radius and mask[i, j] > 0:
                support_pixels += 1
                result[0] += i
                result[1] += j
    support = True
    if support_pixels < min_support * np.square(disk_radius) * np.pi:
        support = False

    result = np.array(result)
    result = np.true_divide(result, support_pixels)

    return support, result


def synthesize_mask(semantic_mask, disk_radius):
    """
    Fits circles on the True pixels of the mask and returns those which have enough support, meaning that the
    proportion of the circle area covering True pixels is higher than a certain threshold, in order to avoid
    fitting circles on isolated pixels.
    :param semantic_mask: boolean mask
    :param disk_radius: radius of the circles
    :return: a list of disk centers that have enough support
    """
    mask = semantic_mask.copy().astype(np.uint8)
    points = np.transpose(np.nonzero(mask))
    disks = []
    while len(points):

        start = random.choice(points)
        dist = 10.
        success = True
        while dist > 1.:
            enough_support, center = get_support_center(mask, start, disk_radius)
            if not enough_support:
                bad_point = np.round(center).astype(np.int32)
                cv.circle(mask, (bad_point[1], bad_point[0]), disk_radius, (0), -1)
                success = False
            dist = np.sqrt(np.sum(np.square(center - start)))
            start = center
        if success:
            disks.append(np.round(start).astype(np.int32))
            cv.circle(mask, (disks[-1][1], disks[-1][0]), disk_radius, 0, -1)
        points = np.transpose(np.nonzero(mask))

    return disks


class CustomNetwork:

    def __init__(self, checkpoint):
        print("Loading model " + checkpoint)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = deeplabv3_resnet101(num_classes=len(SoccerPitch.lines_classes) + 1, aux_loss=True)
        self.model.load_state_dict(torch.load(checkpoint)["model"], strict=False)
        self.model.to(self.device)
        self.model.eval()
        print("using", self.device)

    def forward(self, img):
        trf = T.Compose(
            [
                T.Resize(256),
                # T.CenterCrop(224),
                T.ToTensor(),
                T.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225]
                )
            ]
        )
        img = trf(img).unsqueeze(0).to(self.device)
        result = self.model(img)["out"].detach().squeeze(0).argmax(0)
        result = result.cpu().numpy().astype(np.uint8)
        # print(result)
        return result


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Test')

    parser.add_argument('-s', '--soccernet', default="/nfs/data/soccernet/calibration/", type=str,
                        help='Path to the SoccerNet-V3 dataset folder')
    parser.add_argument('-p', '--prediction', default="sn-calib-test_endpoints", required=False, type=str,
                        help="Path to the prediction folder")
    parser.add_argument('--split', required=False, type=str, default="challenge", help='Select the split of data')
    parser.add_argument('--masks', required=False, type=bool, default=False, help='Save masks in prediction directory')
    parser.add_argument('--resolution_width', required=False, type=int, default=455,
                        help='width resolution of the images')
    parser.add_argument('--resolution_height', required=False, type=int, default=256,
                        help='height resolution of the images')
    parser.add_argument('--checkpoint', required=False, type=str, help="Path to the custom model checkpoint.")
    parser.add_argument('--pp_radius', required=False, type=int, default=4,
                        help='Post processing: Radius of circles that cover each segment.')
    parser.add_argument('--pp_maxdists', required=False, type=int, default=30,
                        help='Post processing: Maximum distance of circles that are allowed within one segment.')
    parser.add_argument('--num_points_lines', required=False, type=int, default=2, choices=range(2, 10),
                        help='Post processing: Number of keypoints that represent a line segment')
    parser.add_argument('--num_points_circles', required=False, type=int, default=2, choices=range(2, 10),
                        help='Post processing: Number of keypoints that represent a circle segment')
    args = parser.parse_args()

    lines_palette = [0, 0, 0]
    for line_class in SoccerPitch.lines_classes:
        lines_palette.extend(SoccerPitch.palette[line_class])

    model = CustomNetwork(args.checkpoint)

    dataset_dir = os.path.join(args.soccernet, args.split)
    if not os.path.exists(dataset_dir):
        print("Invalid dataset path!")
        exit(-1)

    radius = args.pp_radius
    maxdists = args.pp_maxdists

    frames = [f for f in os.listdir(dataset_dir) if ".jpg" in f]
    with tqdm(enumerate(frames), total=len(frames), ncols=160) as t:
        for i, frame in t:

            output_prediction_folder = os.path.join(
                str(args.prediction),
                f"np{args.num_points_lines}_nc{args.num_points_circles}_r{radius}_md{maxdists}",
                args.split
            )
            if not os.path.exists(output_prediction_folder):
                os.makedirs(output_prediction_folder)
            prediction = dict()
            count = 0

            frame_path = os.path.join(dataset_dir, frame)

            frame_index = frame.split(".")[0]

            image = Image.open(frame_path)

            semlines = model.forward(image)
            # print(semlines.shape)
            # print("\nsemlines", type(semlines), semlines.shape)
            if args.masks:
                mask = Image.fromarray(semlines.astype(np.uint8)).convert('P')
                mask.putpalette(lines_palette)
                mask_file = os.path.join(output_prediction_folder, frame)
|
310 |
+
mask.convert("RGB").save(mask_file)
|
311 |
+
skeletons = generate_class_synthesis(semlines, radius)
|
312 |
+
|
313 |
+
extremities = get_line_extremities(skeletons, maxdists, args.resolution_width, args.resolution_height, args.num_points_lines, args.num_points_circles)
|
314 |
+
|
315 |
+
|
316 |
+
prediction = extremities
|
317 |
+
count += 1
|
318 |
+
|
319 |
+
prediction_file = os.path.join(output_prediction_folder, f"extremities_{frame_index}.json")
|
320 |
+
with open(prediction_file, "w") as f:
|
321 |
+
json.dump(prediction, f, indent=4)
|
322 |
+
|
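For reference, a minimal sketch of how the circle-fitting post-processing above can be exercised on a synthetic mask (assumes `custom_extremities.py` and its dependencies are importable; the toy mask and radius are illustrative, not from the original script):

```python
import numpy as np
from custom_extremities import synthesize_mask

# A 2 px thick horizontal segment on an otherwise empty boolean mask.
toy_mask = np.zeros((64, 64), dtype=bool)
toy_mask[30:32, 10:50] = True

# Fit disks of radius 4 on the True pixels; only disks with enough
# support (see get_support_center above) are kept.
centers = synthesize_mask(toy_mask, disk_radius=4)
print(len(centers), "disk centers, first:", centers[:1])
```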
tvcalib/sn_segmentation/src/dataloader.py
ADDED
@@ -0,0 +1,122 @@
"""
DataLoader used to train the segmentation network used for the prediction of extremities.
"""

import json
import os
import time
from argparse import ArgumentParser

import cv2 as cv
import numpy as np
from torch.utils.data import Dataset
from tqdm import tqdm

from SoccerNet.Evaluation.utils_calibration import SoccerPitch


class SoccerNetDataset(Dataset):
    def __init__(self,
                 datasetpath,
                 split="test",
                 width=640,
                 height=360,
                 mean="../resources/mean.npy",
                 std="../resources/std.npy"):
        self.mean = np.load(mean)
        self.std = np.load(std)
        self.width = width
        self.height = height

        dataset_dir = os.path.join(datasetpath, split)
        if not os.path.exists(dataset_dir):
            print("Invalid dataset path!")
            exit(-1)

        frames = [f for f in os.listdir(dataset_dir) if ".jpg" in f]

        self.data = []
        self.n_samples = 0
        for frame in frames:

            frame_index = frame.split(".")[0]
            annotation_file = os.path.join(dataset_dir, f"{frame_index}.json")
            if not os.path.exists(annotation_file):
                continue
            with open(annotation_file, "r") as f:
                groundtruth_lines = json.load(f)
            img_path = os.path.join(dataset_dir, frame)
            if groundtruth_lines:
                self.data.append({
                    "image_path": img_path,
                    "annotations": groundtruth_lines,
                })

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        item = self.data[index]

        img = cv.imread(item["image_path"])
        img = cv.resize(img, (self.width, self.height), interpolation=cv.INTER_LINEAR)

        mask = np.zeros(img.shape[:-1], dtype=np.uint8)
        img = np.asarray(img, np.float32) / 255.
        img -= self.mean
        img /= self.std
        img = img.transpose((2, 0, 1))
        for class_number, class_ in enumerate(SoccerPitch.lines_classes):
            if class_ in item["annotations"].keys():
                key = class_
                line = item["annotations"][key]
                prev_point = line[0]
                for i in range(1, len(line)):
                    next_point = line[i]
                    cv.line(mask,
                            (int(prev_point["x"] * mask.shape[1]), int(prev_point["y"] * mask.shape[0])),
                            (int(next_point["x"] * mask.shape[1]), int(next_point["y"] * mask.shape[0])),
                            class_number + 1,
                            2)
                    prev_point = next_point
        return img, mask


if __name__ == "__main__":

    # Load the arguments
    parser = ArgumentParser(description='dataloader')

    parser.add_argument('--SoccerNet_path', default="./annotations/", type=str,
                        help='Path to the SoccerNet-V3 dataset folder')
    parser.add_argument('--tiny', required=False, type=int, default=None, help='Select a subset of x games')
    parser.add_argument('--split', required=False, type=str, default="test", help='Select the split of data')
    parser.add_argument('--num_workers', required=False, type=int, default=4,
                        help='number of workers for the dataloader')
    parser.add_argument('--resolution_width', required=False, type=int, default=1920,
                        help='width resolution of the images')
    parser.add_argument('--resolution_height', required=False, type=int, default=1080,
                        help='height resolution of the images')
    parser.add_argument('--preload_images', action='store_true',
                        help="Preload the images when constructing the dataset")
    parser.add_argument('--zipped_images', action='store_true', help="Read images from zipped folder")

    args = parser.parse_args()

    start_time = time.time()
    soccernet = SoccerNetDataset(args.SoccerNet_path, split=args.split)
    with tqdm(enumerate(soccernet), total=len(soccernet), ncols=160) as t:
        for i, data in t:
            img = soccernet[i][0].astype(np.uint8).transpose((1, 2, 0))
            print(img.shape)
            print(img.dtype)
            cv.imshow("Normalized image", img)
            cv.waitKey(0)
            cv.destroyAllWindows()
            print(data[1].shape)
            cv.imshow("Mask", soccernet[i][1].astype(np.uint8))
            cv.waitKey(0)
            cv.destroyAllWindows()
    end_time = time.time()
    print(end_time - start_time)
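A minimal usage sketch for the dataset above (the dataset root is a placeholder; run it from `src/` so the default `../resources/mean.npy` and `std.npy` paths resolve):

```python
from torch.utils.data import DataLoader
from dataloader import SoccerNetDataset

dataset = SoccerNetDataset("/path/to/SoccerNet/calibration", split="train",
                           width=640, height=360)
loader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=4)

# Each batch holds normalized images and dense class-index masks.
images, masks = next(iter(loader))
print(images.shape, masks.shape)  # e.g. (8, 3, 360, 640) and (8, 360, 640)
```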
tvcalib/sn_segmentation/src/evaluate_extremities.py
ADDED
@@ -0,0 +1,270 @@
import argparse
import json
import os

import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

from SoccerNet.Evaluation.utils_calibration import SoccerPitch


def distance(point1, point2):
    """
    Computes the euclidean distance between 2D points
    :param point1
    :param point2
    :return: euclidean distance between point1 and point2
    """
    diff = np.array([point1['x'], point1['y']]) - np.array([point2['x'], point2['y']])
    sq_dist = np.square(diff)
    return np.sqrt(sq_dist.sum())


def mirror_labels(lines_dict):
    """
    Replaces each line class key of the dictionary with its opposite element according to a central projection by the
    soccer pitch center
    :param lines_dict: dictionary whose keys will be mirrored
    :return: dictionary with mirrored keys and the same values
    """
    mirrored_dict = dict()
    for line_class, value in lines_dict.items():
        mirrored_dict[SoccerPitch.symetric_classes[line_class]] = value
    return mirrored_dict


def evaluate_detection_prediction(detected_lines, groundtruth_lines, threshold=2.):
    """
    Evaluates the prediction of extremities. The extremities associated to a class are unordered. The extremities of
    the "Circle central" element are not well-defined for this task, thus this class is ignored.
    Computes confusion matrices for a level of precision specified by the threshold.
    A groundtruth extremity point is correctly classified if it lies at less than threshold pixels from the
    corresponding extremity point of the prediction of the same class.
    Also computes the euclidean distance between each predicted extremity and its closest groundtruth extremity, when
    both the groundtruth and the prediction contain the element class.

    :param detected_lines: dictionary of detected line classes as keys and associated predicted extremities as values
    :param groundtruth_lines: dictionary of annotated line classes as keys and associated annotated points as values
    :param threshold: distance in pixels that distinguishes good matches from bad ones
    :return: confusion matrix, per-class confusion matrix & per-class localization errors
    """
    confusion_mat = np.zeros((2, 2), dtype=np.float32)
    per_class_confusion = {}
    errors_dict = {}
    detected_classes = set(detected_lines.keys())
    groundtruth_classes = set(groundtruth_lines.keys())

    if "Circle central" in groundtruth_classes:
        groundtruth_classes.remove("Circle central")
    if "Circle central" in detected_classes:
        detected_classes.remove("Circle central")

    false_positives_classes = detected_classes - groundtruth_classes
    for false_positive_class in false_positives_classes:
        false_positives = len(detected_lines[false_positive_class])
        confusion_mat[0, 1] += false_positives
        per_class_confusion[false_positive_class] = np.array([[0., false_positives], [0., 0.]])

    false_negatives_classes = groundtruth_classes - detected_classes
    for false_negatives_class in false_negatives_classes:
        false_negatives = len(groundtruth_lines[false_negatives_class])
        confusion_mat[1, 0] += false_negatives
        per_class_confusion[false_negatives_class] = np.array([[0., 0.], [false_negatives, 0.]])

    common_classes = detected_classes - false_positives_classes

    for detected_class in common_classes:

        detected_points = detected_lines[detected_class]

        groundtruth_points = groundtruth_lines[detected_class]

        groundtruth_extremities = [groundtruth_points[0], groundtruth_points[-1]]
        predicted_extremities = [detected_points[0], detected_points[-1]]
        per_class_confusion[detected_class] = np.zeros((2, 2))

        dist1 = distance(groundtruth_extremities[0], predicted_extremities[0])
        dist1rev = distance(groundtruth_extremities[1], predicted_extremities[0])

        dist2 = distance(groundtruth_extremities[1], predicted_extremities[1])
        dist2rev = distance(groundtruth_extremities[0], predicted_extremities[1])
        if dist1rev <= dist1 and dist2rev <= dist2:
            # reverse order
            dist1 = dist1rev
            dist2 = dist2rev

        errors_dict[detected_class] = [dist1, dist2]

        if dist1 < threshold:
            confusion_mat[0, 0] += 1
            per_class_confusion[detected_class][0, 0] += 1
        else:
            # treat too far detections as false positives
            confusion_mat[0, 1] += 1
            per_class_confusion[detected_class][0, 1] += 1

        if dist2 < threshold:
            confusion_mat[0, 0] += 1
            per_class_confusion[detected_class][0, 0] += 1
        else:
            # treat too far detections as false positives
            confusion_mat[0, 1] += 1
            per_class_confusion[detected_class][0, 1] += 1

    return confusion_mat, per_class_confusion, errors_dict


def scale_points(points_dict, s_width, s_height):
    """
    Scales points by the s_width and s_height factors
    :param points_dict: dictionary of annotations/predictions with normalized point values
    :param s_width: width scaling factor
    :param s_height: height scaling factor
    :return: dictionary with scaled points
    """
    line_dict = {}
    for line_class, points in points_dict.items():
        scaled_points = []
        for point in points:
            new_point = {'x': point['x'] * (s_width - 1), 'y': point['y'] * (s_height - 1)}
            scaled_points.append(new_point)
        if len(scaled_points):
            line_dict[line_class] = scaled_points
    return line_dict


if __name__ == "__main__":

    parser = argparse.ArgumentParser(description='Test')

    parser.add_argument('-s', '--soccernet', default="./annotations", type=str,
                        help='Path to the SoccerNet-V3 dataset folder')
    parser.add_argument('-p', '--prediction', default="./results_bis",
                        required=False, type=str,
                        help="Path to the prediction folder")
    parser.add_argument('-t', '--threshold', default=10, required=False, type=int,
                        help="Accuracy threshold in pixels")
    parser.add_argument('--split', required=False, type=str, default="test", help='Select the split of data')
    parser.add_argument('--resolution_width', required=False, type=int, default=960,
                        help='width resolution of the images')
    parser.add_argument('--resolution_height', required=False, type=int, default=540,
                        help='height resolution of the images')
    args = parser.parse_args()

    accuracies = []
    precisions = []
    recalls = []
    dict_errors = {}
    per_class_confusion_dict = {}

    dataset_dir = os.path.join(args.soccernet, args.split)
    if not os.path.exists(dataset_dir):
        print("Invalid dataset path!")
        exit(-1)

    annotation_files = [f for f in os.listdir(dataset_dir) if ".json" in f]

    with tqdm(enumerate(annotation_files), total=len(annotation_files), ncols=160) as t:
        for i, annotation_file in t:
            frame_index = annotation_file.split(".")[0]
            annotation_file = os.path.join(args.soccernet, args.split, annotation_file)
            prediction_file = os.path.join(args.prediction, args.split, f"extremities_{frame_index}.json")

            if not os.path.exists(prediction_file):
                accuracies.append(0.)
                precisions.append(0.)
                recalls.append(0.)
                continue

            with open(annotation_file, 'r') as f:
                line_annotations = json.load(f)

            with open(prediction_file, 'r') as f:
                predictions = json.load(f)

            predictions = scale_points(predictions, args.resolution_width, args.resolution_height)
            line_annotations = scale_points(line_annotations, args.resolution_width, args.resolution_height)

            img_prediction = predictions
            img_groundtruth = line_annotations
            confusion1, per_class_conf1, reproj_errors1 = evaluate_detection_prediction(img_prediction,
                                                                                        img_groundtruth,
                                                                                        args.threshold)
            confusion2, per_class_conf2, reproj_errors2 = evaluate_detection_prediction(img_prediction,
                                                                                        mirror_labels(img_groundtruth),
                                                                                        args.threshold)

            accuracy1, accuracy2 = 0., 0.
            if confusion1.sum() > 0:
                accuracy1 = confusion1[0, 0] / confusion1.sum()

            if confusion2.sum() > 0:
                accuracy2 = confusion2[0, 0] / confusion2.sum()

            if accuracy1 > accuracy2:
                accuracy = accuracy1
                confusion = confusion1
                per_class_conf = per_class_conf1
                reproj_errors = reproj_errors1
            else:
                accuracy = accuracy2
                confusion = confusion2
                per_class_conf = per_class_conf2
                reproj_errors = reproj_errors2

            accuracies.append(accuracy)
            if confusion[0, :].sum() > 0:
                precision = confusion[0, 0] / (confusion[0, :].sum())
                precisions.append(precision)
            if (confusion[0, 0] + confusion[1, 0]) > 0:
                recall = confusion[0, 0] / (confusion[0, 0] + confusion[1, 0])
                recalls.append(recall)

            for line_class, errors in reproj_errors.items():
                if line_class in dict_errors.keys():
                    dict_errors[line_class].extend(errors)
                else:
                    dict_errors[line_class] = errors

            for line_class, confusion_mat in per_class_conf.items():
                if line_class in per_class_confusion_dict.keys():
                    per_class_confusion_dict[line_class] += confusion_mat
                else:
                    per_class_confusion_dict[line_class] = confusion_mat

    mRecall = np.mean(recalls)
    sRecall = np.std(recalls)
    medianRecall = np.median(recalls)
    print(
        f" On SoccerNet {args.split} set, recall mean value: {mRecall * 100:2.2f}% with standard deviation of {sRecall * 100:2.2f}% and median of {medianRecall * 100:2.2f}%")

    mPrecision = np.mean(precisions)
    sPrecision = np.std(precisions)
    medianPrecision = np.median(precisions)
    print(
        f" On SoccerNet {args.split} set, precision mean value: {mPrecision * 100:2.2f}% with standard deviation of {sPrecision * 100:2.2f}% and median of {medianPrecision * 100:2.2f}%")

    mAccuracy = np.mean(accuracies)
    sAccuracy = np.std(accuracies)
    medianAccuracy = np.median(accuracies)
    print(
        f" On SoccerNet {args.split} set, accuracy mean value: {mAccuracy * 100:2.2f}% with standard deviation of {sAccuracy * 100:2.2f}% and median of {medianAccuracy * 100:2.2f}%")

    for line_class, confusion_mat in per_class_confusion_dict.items():
        class_accuracy = confusion_mat[0, 0] / confusion_mat.sum()
        class_recall = confusion_mat[0, 0] / (confusion_mat[0, 0] + confusion_mat[1, 0])
        class_precision = confusion_mat[0, 0] / (confusion_mat[0, 0] + confusion_mat[0, 1])
        print(
            f"For class {line_class}, accuracy of {class_accuracy * 100:2.2f}%, precision of {class_precision * 100:2.2f}% and recall of {class_recall * 100:2.2f}%")

    for k, v in dict_errors.items():
        fig, ax1 = plt.subplots(figsize=(11, 8))
        ax1.hist(v, bins=30, range=(0, 60))
        ax1.set_title(k)
        ax1.set_xlabel("Errors in pixel")
        os.makedirs("./results/", exist_ok=True)
        plt.savefig(f"./results/{k}_detection_error.png")
        plt.close(fig)
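A minimal sketch of the matching logic implemented by `evaluate_detection_prediction` on a toy prediction (assumes the module and the SoccerNet package are importable; the toy points are illustrative):

```python
from evaluate_extremities import evaluate_detection_prediction

gt = {"Middle line": [{"x": 10.0, "y": 5.0}, {"x": 10.0, "y": 95.0}]}
pred = {"Middle line": [{"x": 11.0, "y": 6.0}, {"x": 10.0, "y": 94.0}]}

confusion, per_class, errors = evaluate_detection_prediction(pred, gt, threshold=5.0)
print(confusion)              # [[2. 0.] [0. 0.]]: both extremities matched
print(errors["Middle line"])  # per-extremity localization errors in pixels
```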
tvcalib/sn_segmentation/src/masks_gt2chen.py
ADDED
@@ -0,0 +1,217 @@
import pickle
import torch
from torch.utils.data import Dataset
import torchvision
import torchvision.transforms as T
from PIL import Image
import cv2
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import h5py
from tqdm.auto import tqdm
from collections import defaultdict
from pathlib import Path
import json
import os
from argparse import ArgumentParser

lines_classes = [
    "Big rect. left bottom",
    "Big rect. left main",
    "Big rect. left top",
    "Big rect. right bottom",
    "Big rect. right main",
    "Big rect. right top",
    "Circle central",
    "Circle left",
    "Circle right",
    # "Goal left crossbar",
    # "Goal left post left ",
    # "Goal left post right",
    # "Goal right crossbar",
    # "Goal right post left",
    # "Goal right post right",
    "Goal unknown",
    "Line unknown",
    "Middle line",
    "Side line bottom",
    "Side line left",
    "Side line right",
    "Side line top",
    "Small rect. left bottom",
    "Small rect. left main",
    "Small rect. left top",
    "Small rect. right bottom",
    "Small rect. right main",
    "Small rect. right top",
]

# RGB values
palette = {
    "Big rect. left bottom": (127, 0, 0),
    "Big rect. left main": (102, 102, 102),
    "Big rect. left top": (0, 0, 127),
    "Big rect. right bottom": (86, 32, 39),
    "Big rect. right main": (48, 77, 0),
    "Big rect. right top": (14, 97, 100),
    "Circle central": (0, 0, 255),
    "Circle left": (255, 127, 0),
    "Circle right": (0, 255, 255),
    # "Goal left crossbar": (255, 255, 200),
    # "Goal left post left ": (165, 255, 0),
    # "Goal left post right": (155, 119, 45),
    # "Goal right crossbar": (86, 32, 139),
    # "Goal right post left": (196, 120, 153),
    # "Goal right post right": (166, 36, 52),
    "Goal unknown": (0, 0, 0),
    "Line unknown": (0, 0, 0),
    "Middle line": (255, 255, 0),
    "Side line bottom": (255, 0, 255),
    "Side line left": (0, 255, 150),
    "Side line right": (0, 230, 0),
    "Side line top": (230, 0, 0),
    "Small rect. left bottom": (0, 150, 255),
    "Small rect. left main": (254, 173, 225),
    "Small rect. left top": (87, 72, 39),
    "Small rect. right bottom": (122, 0, 255),
    "Small rect. right main": (255, 255, 255),
    "Small rect. right top": (153, 23, 153),
}


def create_target_from_annotation(width, height, annotation, classes, linewidth=4):
    """Draw one-hot encoded segments according to the annotation.
    Creates a target that matches the image size ([C+1]xHxW).
    """
    annotation_abs = defaultdict(list)
    # unnormalize every point in every class k
    for k in annotation.keys():
        if k not in lines_classes:
            continue
        for annotation_point in annotation[k]:
            tup = annotation_point.copy()
            tup["x"] *= width
            tup["x"] = int(tup["x"])
            tup["y"] *= height
            tup["y"] = int(tup["y"])
            annotation_abs[k].append(tup)

    # draw lines between annotated points for each segment
    # offset classes by +1 so that "no class detected" ends up in argmax 0,
    # otherwise argmax 0 would collide with the first class
    classes_segments = np.zeros(shape=(len(classes) + 1, height, width))
    for cls, points in annotation_abs.items():
        class_segments = np.zeros(shape=(height, width, 3))
        for start, end in zip(points, points[1:]):
            startxy = (start["x"], start["y"])
            endxy = (end["x"], end["y"])
            class_segments = cv2.line(
                class_segments, startxy, endxy, (1, 1, 1), linewidth
            )
        classes_segments[classes.index(cls) + 1] = class_segments[:, :, 1]

    classes_segments = torch.Tensor(classes_segments)
    return classes_segments


class ExtremitiesDataset(Dataset):
    def __init__(
        self, root, split, annotations, filter_cam=None, extremities_prefix="", classes=lines_classes, palette=palette
    ):
        self.data_root = Path(root)
        self.split = split

        self.annotations_path = annotations

        if filter_cam is None:
            files = os.listdir(self.data_root / self.split)
            self.annotations = sorted([fn for fn in files if fn.endswith("json")])
            self.images = sorted([fn for fn in files if fn.endswith("jpg")])
        else:
            df = pd.read_json(self.data_root / self.split / "match_info_cam_gt.json").T
            df = df.loc[df.camera == filter_cam]
            assert len(df.index) > 0
            df["image_file"] = df.index
            df = df.sort_values(by=["image_file"])
            df["annotation_file"] = df["image_file"].apply(
                lambda s: extremities_prefix + s.split(".jpg")[0] + ".json"
            )
            self.annotations = df["annotation_file"].tolist()
            self.images = df["image_file"].tolist()

        self.classes = classes

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # see https://learnopencv.com/pytorch-for-beginners-semantic-segmentation-using-torchvision/

        impath = self.data_root / self.split / self.images[idx]
        annotation_path = self.annotations_path / self.annotations[idx]
        with open(annotation_path, "r") as f:
            annotation = json.load(f)

        img = Image.open(impath)
        trf = T.Compose(
            [
                T.ToTensor(),
                T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ]
        )
        # prepare batches
        img = trf(img)

        # see https://git.tib.eu/vid2pos/sccvsd/-/blob/master/utils/synthetic_util.py
        # draw lines (linewidth=4 for 720p) -> hence we rescale first
        target = create_target_from_annotation(1280, 720, annotation, self.classes)
        target = target.long().argmax(dim=0).unsqueeze(0)
        # to binary mask
        target = target.bool().float()
        # rescale target equivalent to cv2.resize() with default args (bilinear interpolation) -> same as in torchvision
        # bilinear -> [0, 0.25, 0.5, 0.7, 1.0] are entries
        target = torchvision.transforms.Resize((180, 320))(target)
        # to uint8
        target = (target * 255.0).to(torch.uint8)

        return img, target, impath.name


if __name__ == "__main__":

    parser = ArgumentParser()
    parser.add_argument("--data_dir", type=Path)
    parser.add_argument("--annotations", type=Path)
    parser.add_argument("--output_dir", type=Path)
    parser.add_argument("--extremities_prefix", type=str, default="")
    args = parser.parse_args()

    data_dir = args.data_dir.parent
    split = args.data_dir.name
    output_dir = args.output_dir
    if not output_dir.exists():
        raise FileNotFoundError(output_dir)

    dataset = ExtremitiesDataset(data_dir, split, args.annotations, filter_cam="Main camera center", extremities_prefix=args.extremities_prefix)

    image_src = []
    edge_maps = np.zeros((len(dataset), 1, 180, 320), dtype=np.uint8)
    for i, (_, edge_map, img_id) in enumerate(tqdm(dataset)):
        edge_maps[i] = edge_map.numpy()
        image_src.append(img_id)

    with h5py.File(output_dir / "seg_edge_maps.h5", "w") as f:
        f.create_dataset("edge_map", data=edge_maps)

    with open(output_dir / "seg_image_paths.pkl", "wb") as f:
        pickle.dump(image_src, f)
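A minimal sketch of `create_target_from_annotation` on a toy annotation (a single vertical middle line in normalized coordinates; assumes the module's dependencies are installed):

```python
import torch
from masks_gt2chen import create_target_from_annotation, lines_classes

annotation = {"Middle line": [{"x": 0.5, "y": 0.1}, {"x": 0.5, "y": 0.9}]}
target = create_target_from_annotation(320, 180, annotation, lines_classes)

print(target.shape)  # torch.Size([23, 180, 320]): background + 22 classes
# Only the "Middle line" channel (class index 11, so channel 12) is drawn.
print(torch.nonzero(target.sum(dim=(1, 2))).flatten())  # tensor([12])
```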
tvcalib/sn_segmentation/src/masks_pred2chen.py
ADDED
@@ -0,0 +1,150 @@
import argparse
import os.path
import pickle

import h5py
import numpy as np
import pandas as pd
from PIL import Image
from tqdm import tqdm
import cv2

from SoccerNet.Evaluation.utils_calibration import SoccerPitch

from custom_extremities import CustomNetwork

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Test")

    parser.add_argument(
        "-s",
        "--soccernet",
        default="/nfs/data/soccernet/calibration/",
        type=str,
        help="Path to the SoccerNet-V3 dataset folder",
    )
    parser.add_argument(
        "-p",
        "--prediction",
        default="/nfs/home/rhotertj/datasets/sn-calib-test_endpoints",
        required=False,
        type=str,
        help="Path to the prediction folder",
    )
    parser.add_argument(
        "--split",
        required=False,
        type=str,
        default="challenge",
        help="Select the split of data",
    )
    parser.add_argument(
        "--resolution_width",
        required=False,
        type=int,
        default=455,
        help="width resolution of the images",
    )
    parser.add_argument(
        "--resolution_height",
        required=False,
        type=int,
        default=256,
        help="height resolution of the images",
    )
    parser.add_argument(
        "--checkpoint",
        required=False,
        type=str,
        help="Path to the custom model checkpoint.",
    )
    parser.add_argument("--filter_cam", type=str, required=False)
    args = parser.parse_args()

    lines_palette = [0, 0, 0]
    for line_class in SoccerPitch.lines_classes:
        lines_palette.extend(SoccerPitch.palette[line_class])

    dataset_dir = os.path.join(args.soccernet, args.split)
    if not os.path.exists(dataset_dir):
        print("Invalid dataset path!")
        exit(-1)

    match_info_file = os.path.join(args.soccernet, args.split, "match_info_cam_gt.json")
    if not os.path.exists(match_info_file):
        print("Missing match_info_cam_gt.json in", dataset_dir)
        exit(-1)
    df = pd.read_json(match_info_file).T
    if args.filter_cam:
        df = df.loc[df.camera == args.filter_cam]
    df["image_file"] = df.index
    df = df.sort_values(by=["image_file"])

    frames = df["image_file"].tolist()

    model = CustomNetwork(args.checkpoint)

    output_prediction_folder = args.prediction
    if not os.path.exists(output_prediction_folder):
        os.makedirs(output_prediction_folder)

    image_src = []
    edge_maps = np.zeros((len(frames), 1, 180, 320), dtype=np.uint8)

    kernel = np.ones((4, 4), np.uint8)

    with tqdm(enumerate(frames), total=len(frames), ncols=100) as t:
        for i, frame in t:

            frame_path = os.path.join(dataset_dir, frame)

            frame_index = frame.split(".")[0]

            image = Image.open(frame_path)

            semlines = model.forward(image)

            # set classes 9-15 (goal parts) to background
            mask_goal = (semlines >= 9) & (semlines <= 15)
            semlines[mask_goal] = 0

            mask = Image.fromarray(semlines.astype(np.uint8)).convert("P")
            mask.putpalette(lines_palette)

            # to binary edge map (np.array gives a writable copy)
            mask = np.array(mask.convert("L"))
            mask[mask > 0] = 255

            mask = Image.fromarray(mask)
            mask = mask.resize((320, 180), resample=Image.NEAREST)
            # expected linewidth @ 720p resolution -> 4 px

            mask = np.asarray(mask)

            mask = cv2.erode(mask, kernel, iterations=1)

            edge_maps[i] = mask
            image_src.append(frame)

    with h5py.File(
        os.path.join(output_prediction_folder, "seg_edge_maps.h5"), "w"
    ) as f:
        f.create_dataset("edge_map", data=edge_maps)

    with open(os.path.join(output_prediction_folder, "seg_image_paths.pkl"), "wb") as f:
        pickle.dump(image_src, f)
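The two output artifacts pair up by index; a minimal read-back sketch (the prediction folder path is a placeholder):

```python
import pickle
import h5py

with h5py.File("/path/to/prediction_folder/seg_edge_maps.h5", "r") as f:
    edge_maps = f["edge_map"][:]      # uint8 array of shape (N, 1, 180, 320)
with open("/path/to/prediction_folder/seg_image_paths.pkl", "rb") as f:
    image_src = pickle.load(f)        # list of N frame file names

print(edge_maps.shape, len(image_src))
```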
tvcalib/sn_segmentation/src/segmentation/README.md
ADDED
@@ -0,0 +1,23 @@
# Semantic segmentation reference training scripts

This directory is an edited copy of the [torchvision](https://github.com/pytorch/vision/tree/main/references/segmentation) reference scripts.

## Usage

Train from scratch:

```
python train.py -b 8 --model deeplabv3_resnet101 --aux-loss --weights-backbone ResNet101_Weights.IMAGENET1K_V1 --epochs 30 --output-dir "./checkpoints" --split train
```

Resume training:

```
python train.py -b 8 --model deeplabv3_resnet101 --aux-loss --weights-backbone ResNet101_Weights.IMAGENET1K_V1 --epochs 60 --start-epoch 30 --weights /path/to/checkpoints.pt
```

Evaluate a checkpoint:

```
python train.py -b 4 --model deeplabv3_resnet101 --aux-loss --weights-backbone ResNet101_Weights.IMAGENET1K_V1 --test-only --weights /path/to/checkpoints.pt
```
tvcalib/sn_segmentation/src/segmentation/coco_utils.py
ADDED
@@ -0,0 +1,108 @@
import copy
import os

import torch
import torch.utils.data
import torchvision
from PIL import Image
from pycocotools import mask as coco_mask
from transforms import Compose


class FilterAndRemapCocoCategories:
    def __init__(self, categories, remap=True):
        self.categories = categories
        self.remap = remap

    def __call__(self, image, anno):
        anno = [obj for obj in anno if obj["category_id"] in self.categories]
        if not self.remap:
            return image, anno
        anno = copy.deepcopy(anno)
        for obj in anno:
            obj["category_id"] = self.categories.index(obj["category_id"])
        return image, anno


def convert_coco_poly_to_mask(segmentations, height, width):
    masks = []
    for polygons in segmentations:
        rles = coco_mask.frPyObjects(polygons, height, width)
        mask = coco_mask.decode(rles)
        if len(mask.shape) < 3:
            mask = mask[..., None]
        mask = torch.as_tensor(mask, dtype=torch.uint8)
        mask = mask.any(dim=2)
        masks.append(mask)
    if masks:
        masks = torch.stack(masks, dim=0)
    else:
        masks = torch.zeros((0, height, width), dtype=torch.uint8)
    return masks


class ConvertCocoPolysToMask:
    def __call__(self, image, anno):
        w, h = image.size
        segmentations = [obj["segmentation"] for obj in anno]
        cats = [obj["category_id"] for obj in anno]
        if segmentations:
            masks = convert_coco_poly_to_mask(segmentations, h, w)
            cats = torch.as_tensor(cats, dtype=masks.dtype)
            # merge all instance masks into a single segmentation map
            # with its corresponding categories
            target, _ = (masks * cats[:, None, None]).max(dim=0)
            # discard overlapping instances
            target[masks.sum(0) > 1] = 255
        else:
            target = torch.zeros((h, w), dtype=torch.uint8)
        target = Image.fromarray(target.numpy())
        return image, target


def _coco_remove_images_without_annotations(dataset, cat_list=None):
    def _has_valid_annotation(anno):
        # if it's empty, there is no annotation
        if len(anno) == 0:
            return False
        # keep the image if more than 1k pixels are occupied by annotations
        return sum(obj["area"] for obj in anno) > 1000

    if not isinstance(dataset, torchvision.datasets.CocoDetection):
        raise TypeError(
            f"This function expects dataset of type torchvision.datasets.CocoDetection, instead got {type(dataset)}"
        )

    ids = []
    for ds_idx, img_id in enumerate(dataset.ids):
        ann_ids = dataset.coco.getAnnIds(imgIds=img_id, iscrowd=None)
        anno = dataset.coco.loadAnns(ann_ids)
        if cat_list:
            anno = [obj for obj in anno if obj["category_id"] in cat_list]
        if _has_valid_annotation(anno):
            ids.append(ds_idx)

    dataset = torch.utils.data.Subset(dataset, ids)
    return dataset


def get_coco(root, image_set, transforms):
    PATHS = {
        "train": ("train2017", os.path.join("annotations", "instances_train2017.json")),
        "val": ("val2017", os.path.join("annotations", "instances_val2017.json")),
    }
    CAT_LIST = [0, 5, 2, 16, 9, 44, 6, 3, 17, 62, 21, 67, 18, 19, 4, 1, 64, 20, 63, 7, 72]

    transforms = Compose([FilterAndRemapCocoCategories(CAT_LIST, remap=True), ConvertCocoPolysToMask(), transforms])

    img_folder, ann_file = PATHS[image_set]
    img_folder = os.path.join(root, img_folder)
    ann_file = os.path.join(root, ann_file)

    dataset = torchvision.datasets.CocoDetection(img_folder, ann_file, transforms=transforms)

    if image_set == "train":
        dataset = _coco_remove_images_without_annotations(dataset, CAT_LIST)

    return dataset
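A minimal sketch of `convert_coco_poly_to_mask` on a single toy polygon (run from this directory so the local `transforms` module resolves; requires pycocotools):

```python
from coco_utils import convert_coco_poly_to_mask

# One rectangle, given as an xy-interleaved polygon.
polygon = [[10, 10, 60, 10, 60, 40, 10, 40]]
masks = convert_coco_poly_to_mask([polygon], height=64, width=64)

print(masks.shape)         # torch.Size([1, 64, 64])
print(masks.sum().item())  # roughly 50 * 30 pixels inside the rectangle
```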
tvcalib/sn_segmentation/src/segmentation/presets.py
ADDED
@@ -0,0 +1,39 @@
import torch
import transforms as T


class SegmentationPresetTrain:
    def __init__(self, *, base_size, crop_size, hflip_prob=0.5, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        min_size = int(0.5 * base_size)
        max_size = int(2.0 * base_size)

        trans = [T.RandomResize(min_size, max_size)]
        if hflip_prob > 0:
            trans.append(T.RandomHorizontalFlip(hflip_prob))
        trans.extend(
            [
                T.RandomCrop(crop_size),
                T.PILToTensor(),
                T.ConvertImageDtype(torch.float),
                T.Normalize(mean=mean, std=std),
            ]
        )
        self.transforms = T.Compose(trans)

    def __call__(self, img, target):
        return self.transforms(img, target)


class SegmentationPresetEval:
    def __init__(self, *, base_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        self.transforms = T.Compose(
            [
                T.RandomResize(base_size, base_size),
                T.PILToTensor(),
                T.ConvertImageDtype(torch.float),
                T.Normalize(mean=mean, std=std),
            ]
        )

    def __call__(self, img, target):
        return self.transforms(img, target)
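A minimal sketch of the training preset on a dummy image/target pair (assumes the local `transforms` module from this directory is importable; the printed shapes are indicative only):

```python
from PIL import Image
from presets import SegmentationPresetTrain

preset = SegmentationPresetTrain(base_size=520, crop_size=480)
img = Image.new("RGB", (960, 540))
target = Image.new("L", (960, 540))

# Random resize + flip + crop is applied jointly to image and target.
img_t, target_t = preset(img, target)
print(img_t.shape, target_t.shape)  # e.g. torch.Size([3, 480, 480]) and torch.Size([480, 480])
```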
tvcalib/sn_segmentation/src/segmentation/soccerdata.py
ADDED
@@ -0,0 +1,164 @@
import torch
from torch.utils.data import Dataset
import torchvision
import torchvision.transforms as T
from PIL import Image
import cv2
import numpy as np
import matplotlib.pyplot as plt

from collections import defaultdict
from pathlib import Path
import json
import os

lines_classes = [
    'Big rect. left bottom',
    'Big rect. left main',
    'Big rect. left top',
    'Big rect. right bottom',
    'Big rect. right main',
    'Big rect. right top',
    'Circle central',
    'Circle left',
    'Circle right',
    'Goal left crossbar',
    'Goal left post left ',
    'Goal left post right',
    'Goal right crossbar',
    'Goal right post left',
    'Goal right post right',
    'Goal unknown',
    'Line unknown',
    'Middle line',
    'Side line bottom',
    'Side line left',
    'Side line right',
    'Side line top',
    'Small rect. left bottom',
    'Small rect. left main',
    'Small rect. left top',
    'Small rect. right bottom',
    'Small rect. right main',
    'Small rect. right top'
]

# RGB values
palette = {
    'Big rect. left bottom': (127, 0, 0),
    'Big rect. left main': (102, 102, 102),
    'Big rect. left top': (0, 0, 127),
    'Big rect. right bottom': (86, 32, 39),
    'Big rect. right main': (48, 77, 0),
    'Big rect. right top': (14, 97, 100),
    'Circle central': (0, 0, 255),
    'Circle left': (255, 127, 0),
    'Circle right': (0, 255, 255),
    'Goal left crossbar': (255, 255, 200),
    'Goal left post left ': (165, 255, 0),
    'Goal left post right': (155, 119, 45),
    'Goal right crossbar': (86, 32, 139),
    'Goal right post left': (196, 120, 153),
    'Goal right post right': (166, 36, 52),
    'Goal unknown': (0, 0, 0),
    'Line unknown': (0, 0, 0),
    'Middle line': (255, 255, 0),
    'Side line bottom': (255, 0, 255),
    'Side line left': (0, 255, 150),
    'Side line right': (0, 230, 0),
    'Side line top': (230, 0, 0),
    'Small rect. left bottom': (0, 150, 255),
    'Small rect. left main': (254, 173, 225),
    'Small rect. left top': (87, 72, 39),
    'Small rect. right bottom': (122, 0, 255),
    'Small rect. right main': (255, 255, 255),
    'Small rect. right top': (153, 23, 153)
}

data_dir = Path("data/datasets")


def create_target_from_annotation(width, height, annotation, classes):
    """Draw one-hot encoded segments according to the annotation.
    Creates a target that matches the image size ([C+1]xHxW).
    """
    annotation_abs = defaultdict(list)
    # unnormalize every point in every class k
    for k in annotation.keys():
        if k not in classes:
            # ignore annotation classes outside the label set
            continue
        for annotation_point in annotation[k]:
            tup = annotation_point.copy()
            tup["x"] *= width
            tup["x"] = int(tup["x"])
            tup["y"] *= height
            tup["y"] = int(tup["y"])
            annotation_abs[k].append(tup)

    # draw lines between annotated points for each segment
    # offset classes by +1 so that "no class detected" ends up in argmax 0,
    # otherwise argmax 0 would collide with the first class
    classes_segments = np.zeros(shape=(len(classes) + 1, height, width))
    for cls, points in annotation_abs.items():
        class_segments = np.zeros(shape=(height, width, 3))
        for start, end in zip(points, points[1:]):
            startxy = (start["x"], start["y"])
            endxy = (end["x"], end["y"])
            class_segments = cv2.line(class_segments, startxy, endxy, (1, 1, 1), 5)
        classes_segments[classes.index(cls) + 1] = class_segments[:, :, 1]

    classes_segments = torch.Tensor(classes_segments)
    return classes_segments


class ExtremitiesDataset(Dataset):

    def __init__(self, root, split, classes=lines_classes, palette=palette):
        self.data_root = Path(root)
        self.split = split
        files = os.listdir(self.data_root / self.split)
        self.annotations = sorted([fn for fn in files if fn.endswith("json")])
        self.images = sorted([fn for fn in files if fn.endswith("jpg")])
        self.classes = classes

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # see https://learnopencv.com/pytorch-for-beginners-semantic-segmentation-using-torchvision/

        impath = self.data_root / self.split / self.images[idx]
        annotation_path = self.data_root / self.split / self.annotations[idx]
        with open(annotation_path, "r") as f:
            annotation = json.load(f)

        # set up the image, cast to device later in training
        img = Image.open(impath)
        trf = T.Compose(
            [
                T.Resize(256),
                T.ToTensor(),
                T.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225]
                )
            ]
        )
        # prepare batches
        img = trf(img)
        new_height, new_width = img.shape[-2], img.shape[-1]

        target = create_target_from_annotation(new_width, new_height, annotation, self.classes)
        target = target.long().argmax(dim=0)

        return img, target


if __name__ == "__main__":
    data = ExtremitiesDataset(root=data_dir, split="test")
    target = data[0][1]
    print(target)
    plt.imshow(target)
    plt.show()
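A minimal sketch of this dataset variant (the root is a placeholder): the target is a dense class-index map suitable for the cross-entropy criterion in train.py below:

```python
import torch
from soccerdata import ExtremitiesDataset

dataset = ExtremitiesDataset(root="/path/to/calibration", split="train")
img, target = dataset[0]

print(img.shape)             # e.g. torch.Size([3, 256, 455]) after Resize(256)
print(torch.unique(target))  # 0 = background, 1..28 = line class indices
```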
tvcalib/sn_segmentation/src/segmentation/train.py
ADDED
@@ -0,0 +1,341 @@
import datetime
import os
import time
import warnings

import numpy as np

import presets
import torch
import torch.utils.data
import torchvision
import utils
from coco_utils import get_coco
from torch import nn
from torchvision.transforms import functional as F, InterpolationMode
import wandb

from soccerdata import ExtremitiesDataset


def get_dataset(dir_path, name, image_set, transform):
    def sbd(*args, **kwargs):
        kwargs["root"] = "."
        print(kwargs)
        return torchvision.datasets.SBDataset(*args, mode="segmentation", download=True, **kwargs)

    paths = {
        "voc": (dir_path, torchvision.datasets.VOCSegmentation, 21),
        "voc_aug": (dir_path, sbd, 21),
        "coco": (dir_path, get_coco, 21),
    }
    p, ds_fn, num_classes = paths[name]

    ds = ds_fn(p, image_set=image_set, transforms=transform)
    return ds, num_classes


def get_transform(train, args):
    if train:
        return presets.SegmentationPresetTrain(base_size=520, crop_size=480)
    elif args.weights and args.test_only:
        weights = torchvision.models.get_weight(args.weights)
        trans = weights.transforms()

        def preprocessing(img, target):
            img = trans(img)
            size = F.get_dimensions(img)[1:]
            target = F.resize(target, size, interpolation=InterpolationMode.NEAREST)
            return img, F.pil_to_tensor(target)

        return preprocessing
    else:
        return presets.SegmentationPresetEval(base_size=520)


def criterion(inputs, target):
    losses = {}
    for name, x in inputs.items():
        losses[name] = nn.functional.cross_entropy(x, target, ignore_index=255)

    if len(losses) == 1:
        return losses["out"]

    return losses["out"] + 0.5 * losses["aux"]


def evaluate(model, data_loader, device, num_classes):
    model.eval()
    confmat = utils.ConfusionMatrix(num_classes)
    metric_logger = utils.MetricLogger(delimiter="  ")
    header = "Test:"
    num_processed_samples = 0
    losses = []
    with torch.inference_mode():
        for image, target in metric_logger.log_every(data_loader, 100, header):
            image, target = image.to(device), target.to(device)
            output = model(image)
            loss = criterion(output, target).unsqueeze(0).detach().cpu()
            losses.append(loss)
            output = output["out"]

            # 1xCx224x224

            confmat.update(target.flatten(), output.argmax(1).flatten())
            # FIXME need to take into account that the datasets
            # could have been padded in distributed setup
            num_processed_samples += image.shape[0]

    confmat.reduce_from_all_processes()

    print(losses[0])
    # print(losses)
    losses = torch.cat(losses)
    loss_argsort = torch.argsort(losses, descending=True)
    loss_argsort = loss_argsort.numpy()  # convert to numpy before saving
    np.save("losses_argsort.npy", loss_argsort)

    num_processed_samples = utils.reduce_across_processes(num_processed_samples)
    if (
        hasattr(data_loader.dataset, "__len__")
        and len(data_loader.dataset) != num_processed_samples
        and torch.distributed.get_rank() == 0
    ):
        # See FIXME above
        warnings.warn(
            f"It looks like the dataset has {len(data_loader.dataset)} samples, but {num_processed_samples} "
            "samples were used for the validation, which might bias the results. "
            "Try adjusting the batch size and / or the world size. "
            "Setting the world size to 1 is always a safe bet."
        )

    return confmat


def train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, device, epoch, print_freq, scaler=None):
    model.train()
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
    header = f"Epoch: [{epoch}]"
    for image, target in metric_logger.log_every(data_loader, print_freq, header):
        image, target = image.to(device), target.to(device)
        with torch.cuda.amp.autocast(enabled=scaler is not None):
            output = model(image)
            loss = criterion(output, target)
        wandb.log({"loss": loss})

        optimizer.zero_grad()
        if scaler is not None:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()

        lr_scheduler.step()

        metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])


def main(args):
    wandb.init(project="sn-calibration", entity="rhotertj")
    if args.output_dir:
        utils.mkdir(args.output_dir)

    utils.init_distributed_mode(args)
    print(args)

    device = torch.device(args.device)

    if args.use_deterministic_algorithms:
        torch.backends.cudnn.benchmark = False
        torch.use_deterministic_algorithms(True)
    else:
        torch.backends.cudnn.benchmark = True

    # dataset, num_classes = get_dataset(args.data_path, args.dataset, "train", get_transform(True, args))
    dataset = ExtremitiesDataset(root="/nfs/data/soccernet/calibration", split="train")
    num_classes = len(dataset.classes) + 1

    dataset_test = ExtremitiesDataset(root="/nfs/data/soccernet/calibration", split="test")
    # dataset_test, _ = get_dataset(args.data_path, args.dataset, "val", get_transform(False, args))

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
        test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test, shuffle=False)
    else:
        train_sampler = torch.utils.data.RandomSampler(dataset)
        test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        sampler=train_sampler,
        num_workers=args.workers,
        collate_fn=utils.collate_fn,
        drop_last=True,
    )

    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=1, sampler=test_sampler, num_workers=args.workers, collate_fn=utils.collate_fn
    )

    model = torchvision.models.segmentation.deeplabv3_resnet101(num_classes=num_classes, aux_loss=args.aux_loss)
    if args.test_only or args.resume:
        model.load_state_dict(torch.load(args.weights)["model"], strict=False)

    # model = torchvision.models.segmentation.__dict__[args.model](
    #     weights=args.weights, num_classes=num_classes, aux_loss=args.aux_loss
    # )
    model.to(device)
    if args.distributed:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    params_to_optimize = [
        {"params": [p for p in model_without_ddp.backbone.parameters() if p.requires_grad]},
        {"params": [p for p in model_without_ddp.classifier.parameters() if p.requires_grad]},
    ]
    if args.aux_loss:
        params = [p for p in model_without_ddp.aux_classifier.parameters() if p.requires_grad]
        params_to_optimize.append({"params": params, "lr": args.lr * 10})
    optimizer = torch.optim.SGD(params_to_optimize, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

    scaler = torch.cuda.amp.GradScaler() if args.amp else None

    iters_per_epoch = len(data_loader)
    main_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, lambda x: (1 - x / (iters_per_epoch * (args.epochs - args.lr_warmup_epochs))) ** 0.9
    )

    if args.lr_warmup_epochs > 0:
        warmup_iters = iters_per_epoch * args.lr_warmup_epochs
        args.lr_warmup_method = args.lr_warmup_method.lower()
        if args.lr_warmup_method == "linear":
            warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(
                optimizer, start_factor=args.lr_warmup_decay, total_iters=warmup_iters
            )
        elif args.lr_warmup_method == "constant":
            warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(
                optimizer, factor=args.lr_warmup_decay, total_iters=warmup_iters
            )
        else:
            raise RuntimeError(
                f"Invalid warmup lr method '{args.lr_warmup_method}'. Only linear and constant are supported."
            )
        lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
            optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler], milestones=[warmup_iters]
        )
    else:
        lr_scheduler = main_lr_scheduler

    if args.resume:
        checkpoint = torch.load(args.resume, map_location="cpu")  # needed below to restore optimizer/scheduler state
        # model_without_ddp.load_state_dict(checkpoint["model"], strict=not args.test_only)
        if not args.test_only:
            optimizer.load_state_dict(checkpoint["optimizer"])
            lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
            args.start_epoch = checkpoint["epoch"] + 1
            if args.amp:
                scaler.load_state_dict(checkpoint["scaler"])

    if args.test_only:
        # We disable the cudnn benchmarking because it can noticeably affect the accuracy
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        confmat = evaluate(model, data_loader_test, device=device, num_classes=num_classes)
        print(confmat)
        return

    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, device, epoch, args.print_freq, scaler)
        confmat = evaluate(model, data_loader_test, device=device, num_classes=num_classes)
        print(confmat)
        checkpoint = {
            "model": model_without_ddp.state_dict(),
            "optimizer": optimizer.state_dict(),
            "lr_scheduler": lr_scheduler.state_dict(),
            "epoch": epoch,
            "args": args,
        }
        if args.amp:
            checkpoint["scaler"] = scaler.state_dict()
        utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
        utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print(f"Training time {total_time_str}")


def get_args_parser(add_help=True):
    import argparse

    parser = argparse.ArgumentParser(description="PyTorch Segmentation Training", add_help=add_help)

    parser.add_argument("--data-path", default="/datasets01/COCO/022719/", type=str, help="dataset path")
    parser.add_argument("--dataset", default="coco", type=str, help="dataset name")
    parser.add_argument("--model", default="fcn_resnet101", type=str, help="model name")
    parser.add_argument("--aux-loss", action="store_true", help="auxiliary loss")
    parser.add_argument("--device", default="cuda", type=str, help="device (use cuda or cpu, default: cuda)")
    parser.add_argument(
        "-b", "--batch-size", default=8, type=int, help="images per gpu, the total batch size is $NGPU x batch_size"
    )
    parser.add_argument("--epochs", default=30, type=int, metavar="N", help="number of total epochs to run")

    parser.add_argument(
        "-j", "--workers", default=16, type=int, metavar="N", help="number of data loading workers (default: 16)"
    )
    parser.add_argument("--lr", default=0.01, type=float, help="initial learning rate")
    parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
    parser.add_argument(
        "--wd",
        "--weight-decay",
        default=1e-4,
        type=float,
        metavar="W",
        help="weight decay (default: 1e-4)",
        dest="weight_decay",
    )
    parser.add_argument("--lr-warmup-epochs", default=0, type=int, help="the number of epochs to warmup (default: 0)")
    parser.add_argument("--lr-warmup-method", default="linear", type=str, help="the warmup method (default: linear)")
    parser.add_argument("--lr-warmup-decay", default=0.01, type=float, help="the decay for lr")
    parser.add_argument("--print-freq", default=10, type=int, help="print frequency")
    parser.add_argument("--output-dir", default=".", type=str, help="path to save outputs")
    parser.add_argument("--resume", default="", type=str, help="path of checkpoint")
    parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="start epoch")
    parser.add_argument(
        "--test-only",
        dest="test_only",
        help="Only test the model",
        action="store_true",
    )
    parser.add_argument(
        "--use-deterministic-algorithms", action="store_true", help="Forces the use of deterministic algorithms only."
    )
    # distributed training parameters
    parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
    parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")

    parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
    parser.add_argument("--weights-backbone", default=None, type=str, help="the backbone weights enum name to load")

    # Mixed precision training parameters
    parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")
    parser.add_argument("--split", default="train", type=str, help="Dataset split to be used for training")

    return parser


if __name__ == "__main__":
    args = get_args_parser().parse_args()
    main(args)
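The scheduler setup in `main` above chains an optional warmup scheduler with polynomial decay through `SequentialLR`, stepped once per iteration rather than once per epoch. A minimal, self-contained sketch of that composition with made-up sizes (`iters_per_epoch=10`, `epochs=3`, one warmup epoch, `start_factor=0.01` are illustrative, not repo values), just to show the resulting schedule shape:

```python
import torch

param = torch.nn.Parameter(torch.zeros(1))  # dummy parameter
optimizer = torch.optim.SGD([param], lr=0.01)

iters_per_epoch, epochs, lr_warmup_epochs = 10, 3, 1  # illustrative values
warmup_iters = iters_per_epoch * lr_warmup_epochs

# Polynomial ("poly") decay with exponent 0.9 over the post-warmup iterations, as in main().
main_sched = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lambda x: (1 - x / (iters_per_epoch * (epochs - lr_warmup_epochs))) ** 0.9
)
warmup_sched = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.01, total_iters=warmup_iters)
scheduler = torch.optim.lr_scheduler.SequentialLR(
    optimizer, schedulers=[warmup_sched, main_sched], milestones=[warmup_iters]
)

for step in range(iters_per_epoch * epochs):
    optimizer.step()
    scheduler.step()  # stepped per iteration, as in train_one_epoch
    if step % 5 == 0:
        print(step, scheduler.get_last_lr())  # ramps up for 10 steps, then decays toward 0
```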
tvcalib/sn_segmentation/src/segmentation/transforms.py
ADDED
@@ -0,0 +1,100 @@
import random

import numpy as np
import torch
from torchvision import transforms as T
from torchvision.transforms import functional as F


def pad_if_smaller(img, size, fill=0):
    min_size = min(img.size)
    if min_size < size:
        ow, oh = img.size
        padh = size - oh if oh < size else 0
        padw = size - ow if ow < size else 0
        img = F.pad(img, (0, 0, padw, padh), fill=fill)
    return img


class Compose:
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, target):
        for t in self.transforms:
            image, target = t(image, target)
        return image, target


class RandomResize:
    def __init__(self, min_size, max_size=None):
        self.min_size = min_size
        if max_size is None:
            max_size = min_size
        self.max_size = max_size

    def __call__(self, image, target):
        size = random.randint(self.min_size, self.max_size)
        image = F.resize(image, size)
        target = F.resize(target, size, interpolation=T.InterpolationMode.NEAREST)
        return image, target


class RandomHorizontalFlip:
    def __init__(self, flip_prob):
        self.flip_prob = flip_prob

    def __call__(self, image, target):
        if random.random() < self.flip_prob:
            image = F.hflip(image)
            target = F.hflip(target)
        return image, target


class RandomCrop:
    def __init__(self, size):
        self.size = size

    def __call__(self, image, target):
        image = pad_if_smaller(image, self.size)
        target = pad_if_smaller(target, self.size, fill=255)
        crop_params = T.RandomCrop.get_params(image, (self.size, self.size))
        image = F.crop(image, *crop_params)
        target = F.crop(target, *crop_params)
        return image, target


class CenterCrop:
    def __init__(self, size):
        self.size = size

    def __call__(self, image, target):
        image = F.center_crop(image, self.size)
        target = F.center_crop(target, self.size)
        return image, target


class PILToTensor:
    def __call__(self, image, target):
        image = F.pil_to_tensor(image)
        target = torch.as_tensor(np.array(target), dtype=torch.int64)
        return image, target


class ConvertImageDtype:
    def __init__(self, dtype):
        self.dtype = dtype

    def __call__(self, image, target):
        image = F.convert_image_dtype(image, self.dtype)
        return image, target


class Normalize:
    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, image, target):
        image = F.normalize(image, mean=self.mean, std=self.std)
        return image, target
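These transforms differ from the stock torchvision ones in that each `__call__` takes an `(image, target)` pair, so random parameters (size, flip, crop window) are shared between the input and the label mask. A small usage sketch on synthetic PIL images; it assumes this `transforms.py` is importable as `transforms`, and the sizes and probabilities are illustrative:

```python
import numpy as np
from PIL import Image
from transforms import Compose, PILToTensor, RandomCrop, RandomHorizontalFlip, RandomResize

# Synthetic 300x200 RGB image and a matching single-channel label mask.
img = Image.fromarray(np.random.randint(0, 255, (200, 300, 3), dtype=np.uint8))
mask = Image.fromarray(np.random.randint(0, 2, (200, 300), dtype=np.uint8))

trans = Compose([
    RandomResize(260, 300),     # one random size, applied to image and mask alike
    RandomHorizontalFlip(0.5),  # flips both or neither
    RandomCrop(256),            # pads if needed (mask with ignore value 255), then crops identically
    PILToTensor(),
])
image_t, target_t = trans(img, mask)
print(image_t.shape, target_t.shape)  # e.g. torch.Size([3, 256, 256]) torch.Size([256, 256])
```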
tvcalib/sn_segmentation/src/segmentation/utils.py
ADDED
@@ -0,0 +1,304 @@
import datetime
import errno
import os
import time
from collections import defaultdict, deque

import torch
import torch.distributed as dist


class SmoothedValue:
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.
    """

    def __init__(self, window_size=20, fmt=None):
        if fmt is None:
            fmt = "{median:.4f} ({global_avg:.4f})"
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = fmt

    def update(self, value, n=1):
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Warning: does not synchronize the deque!
        """
        t = reduce_across_processes([self.count, self.total])
        t = t.tolist()
        self.count = int(t[0])
        self.total = t[1]

    @property
    def median(self):
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        d = torch.tensor(list(self.deque), dtype=torch.float32)
        return d.mean().item()

    @property
    def global_avg(self):
        return self.total / self.count

    @property
    def max(self):
        return max(self.deque)

    @property
    def value(self):
        return self.deque[-1]

    def __str__(self):
        return self.fmt.format(
            median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value
        )


class ConfusionMatrix:
    def __init__(self, num_classes):
        self.num_classes = num_classes
        self.mat = None

    def update(self, a, b):
        n = self.num_classes
        if self.mat is None:
            self.mat = torch.zeros((n, n), dtype=torch.int64, device=a.device)
        with torch.inference_mode():
            k = (a >= 0) & (a < n)
            inds = n * a[k].to(torch.int64) + b[k]
            self.mat += torch.bincount(inds, minlength=n ** 2).reshape(n, n)

    def reset(self):
        self.mat.zero_()

    def compute(self):
        h = self.mat.float()
        acc_global = torch.diag(h).sum() / h.sum()
        acc = torch.diag(h) / h.sum(1)
        iu = torch.diag(h) / (h.sum(1) + h.sum(0) - torch.diag(h))
        return acc_global, acc, iu

    def reduce_from_all_processes(self):
        reduce_across_processes(self.mat)

    def __str__(self):
        acc_global, acc, iu = self.compute()
        return ("global correct: {:.1f}\naverage row correct: {}\nIoU: {}\nmean IoU: {:.1f}").format(
            acc_global.item() * 100,
            [f"{i:.1f}" for i in (acc * 100).tolist()],
            [f"{i:.1f}" for i in (iu * 100).tolist()],
            iu.mean().item() * 100,
        )


class MetricLogger:
    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            if not isinstance(v, (float, int)):
                raise TypeError(
                    f"This method expects the value of the input arguments to be of type float or int, instead got {type(v)}"
                )
            self.meters[k].update(v)

    def __getattr__(self, attr):
        if attr in self.meters:
            return self.meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{attr}'")

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append(f"{name}: {str(meter)}")
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        i = 0
        if not header:
            header = ""
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt="{avg:.4f}")
        data_time = SmoothedValue(fmt="{avg:.4f}")
        space_fmt = ":" + str(len(str(len(iterable)))) + "d"
        if torch.cuda.is_available():
            log_msg = self.delimiter.join(
                [
                    header,
                    "[{0" + space_fmt + "}/{1}]",
                    "eta: {eta}",
                    "{meters}",
                    "time: {time}",
                    "data: {data}",
                    "max mem: {memory:.0f}",
                ]
            )
        else:
            log_msg = self.delimiter.join(
                [header, "[{0" + space_fmt + "}/{1}]", "eta: {eta}", "{meters}", "time: {time}", "data: {data}"]
            )
        MB = 1024.0 * 1024.0
        for obj in iterable:
            data_time.update(time.time() - end)
            yield obj
            iter_time.update(time.time() - end)
            if i % print_freq == 0:
                eta_seconds = iter_time.global_avg * (len(iterable) - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                if torch.cuda.is_available():
                    print(
                        log_msg.format(
                            i,
                            len(iterable),
                            eta=eta_string,
                            meters=str(self),
                            time=str(iter_time),
                            data=str(data_time),
                            memory=torch.cuda.max_memory_allocated() / MB,
                        )
                    )
                else:
                    print(
                        log_msg.format(
                            i, len(iterable), eta=eta_string, meters=str(self), time=str(iter_time), data=str(data_time)
                        )
                    )
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print(f"{header} Total time: {total_time_str}")


def cat_list(images, fill_value=0):
    max_size = tuple(max(s) for s in zip(*[img.shape for img in images]))
    batch_shape = (len(images),) + max_size
    batched_imgs = images[0].new(*batch_shape).fill_(fill_value)
    for img, pad_img in zip(images, batched_imgs):
        pad_img[..., : img.shape[-2], : img.shape[-1]].copy_(img)
    return batched_imgs


def collate_fn(batch):
    images, targets = list(zip(*batch))
    batched_imgs = cat_list(images, fill_value=0)
    batched_targets = cat_list(targets, fill_value=255)
    return batched_imgs, batched_targets


def mkdir(path):
    try:
        os.makedirs(path)
    except OSError as e:
        if e.errno != errno.EEXIST:
            raise


def setup_for_distributed(is_master):
    """
    This function disables printing when not in master process
    """
    import builtins as __builtin__

    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        force = kwargs.pop("force", False)
        if is_master or force:
            builtin_print(*args, **kwargs)

    __builtin__.print = print


def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_world_size():
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()


def get_rank():
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


def is_main_process():
    return get_rank() == 0


def save_on_master(*args, **kwargs):
    if is_main_process():
        torch.save(*args, **kwargs)


def init_distributed_mode(args):
    # Distributed mode is intentionally short-circuited here: the early return
    # below disables it unconditionally, so the environment-variable handling
    # further down is never reached.
    print("Not using distributed mode")
    args.distributed = False
    return

    if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ["WORLD_SIZE"])
        args.gpu = int(os.environ["LOCAL_RANK"])
    elif "SLURM_PROCID" in os.environ:
        args.rank = int(os.environ["SLURM_PROCID"])
        args.gpu = args.rank % torch.cuda.device_count()
    elif hasattr(args, "rank"):
        pass
    else:
        print("Not using distributed mode")
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = "nccl"
    print(f"| distributed init (rank {args.rank}): {args.dist_url}", flush=True)
    torch.distributed.init_process_group(
        backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank
    )
    torch.distributed.barrier()
    setup_for_distributed(args.rank == 0)


def reduce_across_processes(val):
    if not is_dist_avail_and_initialized():
        # nothing to sync, but we still convert to tensor for consistency with the distributed case.
        return torch.tensor(val)

    t = torch.tensor(val, device="cuda")
    dist.barrier()
    dist.all_reduce(t)
    return t
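`ConfusionMatrix` accumulates an `num_classes x num_classes` matrix via `bincount` over flattened target/prediction pairs, and `compute` derives global accuracy, per-class accuracy, and per-class IoU from it. A toy, single-process check (labels invented for illustration; without an initialized process group the cross-process reduction is a no-op), assuming this module is importable as `utils`:

```python
import torch
from utils import ConfusionMatrix  # assumes this file is on the import path

cm = ConfusionMatrix(num_classes=3)
target = torch.tensor([0, 0, 1, 1, 2, 2])
pred = torch.tensor([0, 1, 1, 1, 2, 0])
cm.update(target.flatten(), pred.flatten())

acc_global, acc, iu = cm.compute()
print(acc_global.item())  # 4 of 6 pixels correct -> ~0.667
print(iu.tolist())        # per-class IoU: [1/3, 2/3, 1/2]
print(cm)                 # formatted summary via __str__
```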
tvcalib/utils/data_distr.py
ADDED
@@ -0,0 +1,44 @@
import torch


def mean_std_with_confidence_interval(
    vmin, vmax, sigma_scale: float, _steps=1000, round_decimals=4
):
    """Computes mean and std given min/max values with respect to a confidence interval (sigma_scale).

    sigma_scale = 1.65 -> 90% of samples are in range [vmin, vmax]
    sigma_scale = 1.96 -> 95% of samples are in range [vmin, vmax]
    sigma_scale = 2.58 -> 99% of samples are in range [vmin, vmax]
    """

    # sample from a uniform distribution over [vmin, vmax]
    x = torch.linspace(vmin, vmax, _steps)
    mu = x.mean(dim=-1)
    sigma = x.std(dim=-1)
    return (round(mu.item(), round_decimals), round((sigma * sigma_scale).item(), round_decimals))


class FeatureScalerZScore(torch.nn.Module):
    def __init__(self, loc: float, scale: float) -> None:
        # Transforms data from a distribution parameterized by loc (mean) and scale (= sigma * scaling factor).
        super(FeatureScalerZScore, self).__init__()

        self.loc = loc
        self.scale = scale

    def forward(self, z):
        """
        Args:
            z (Tensor): tensor of size (B, *) to be denormalized.
        Returns:
            x: tensor.
        """
        return self.denormalize(z)

    def denormalize(self, z):
        x = z * self.scale + self.loc
        return x

    def normalize(self, x):
        z = (x - self.loc) / self.scale
        return z
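`mean_std_with_confidence_interval` maps a `[vmin, vmax]` range to a `(mean, sigma * sigma_scale)` pair, and `FeatureScalerZScore` converts between that range and z-scores. A round-trip sketch (the 0 to 105 range is an arbitrary example, not a repo constant):

```python
import torch
from tvcalib.utils.data_distr import FeatureScalerZScore, mean_std_with_confidence_interval

loc, scale = mean_std_with_confidence_interval(vmin=0.0, vmax=105.0, sigma_scale=1.96)
print(loc, scale)  # 52.5 and roughly (105 / sqrt(12)) * 1.96

scaler = FeatureScalerZScore(loc, scale)
x = torch.tensor([0.0, 52.5, 105.0])
z = scaler.normalize(x)        # map into z-score space
print(scaler.denormalize(z))   # recovers approximately [0.0, 52.5, 105.0]
```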
tvcalib/utils/io.py
ADDED
@@ -0,0 +1,44 @@
import yaml
import json
from typing import List
import torch


def tensor2list(d: dict):
    """Recursively convert tensors, lists of tensors, and nested dicts of tensors
    into plain Python lists so the dict becomes JSON/YAML serializable."""
    tensor2list_lambda = lambda x: x.detach().cpu().numpy().tolist()
    for k in d.keys():
        if isinstance(d[k], torch.Tensor):
            d[k] = tensor2list_lambda(d[k])
        elif isinstance(d[k], List) and len(d[k]) > 0 and isinstance(d[k][0], torch.Tensor):
            d[k] = [tensor2list_lambda(x) for x in d[k]]
        elif isinstance(d[k], dict):
            d[k] = tensor2list(d[k])
    return d


def write_json(json_serializable_dict, fout, indent=2):
    with open(fout, "w") as fw:
        json.dump(json_serializable_dict, fw, indent=indent)


def write_yaml(json_serializable_dict, fout):
    with open(fout, "w") as fw:
        yaml.dump(json_serializable_dict, fw, default_flow_style=False)


def detach_dict(x_dict):
    with torch.no_grad():
        for k in x_dict.keys():
            if isinstance(x_dict[k], torch.Tensor):
                x_dict[k] = x_dict[k].detach().cpu()
            elif isinstance(x_dict[k], dict):
                x_dict[k] = detach_dict(x_dict[k])
    return x_dict
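A typical use of these helpers is to detach a dict of optimization results and dump it to disk. A short sketch (the dict contents are invented, and the files are written to the working directory):

```python
import torch
from tvcalib.utils.io import detach_dict, tensor2list, write_json, write_yaml

result = {"homography": torch.eye(3, requires_grad=True), "meta": {"loss": torch.tensor(0.12)}}
result = detach_dict(result)   # detach tensors and move them to CPU
result = tensor2list(result)   # convert (nested) tensors to plain Python lists
write_json(result, "result.json")
write_yaml(result, "result.yaml")
```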
tvcalib/utils/linalg.py
ADDED
@@ -0,0 +1,106 @@
from typing import Optional, Union
import torch
from kornia.geometry.conversions import convert_points_from_homogeneous


class LineCollection:
    def __init__(
        self,
        support: torch.tensor,
        direction_norm: torch.tensor,
        direction: Optional[torch.tensor] = None,
    ):
        """Wrapper class to represent lines by support and direction vectors.

        Args:
            support (torch.tensor): with shape (*, {2,3})
            direction_norm (torch.tensor): with shape (*, {2,3})
            direction (Optional[torch.tensor], optional): Unnormalized direction vector. Defaults to None.
        """
        self.support = support
        self.direction_norm = direction_norm
        self.direction = direction

    def __copy__(self):
        return LineCollection(
            self.support.clone(),
            self.direction_norm.clone(),
            self.direction.clone() if self.direction is not None else None,
        )

    def copy(self):
        return self.__copy__()

    def shape(self):
        return f"support={self.support.shape} direction_norm={self.direction_norm.shape} direction={self.direction.shape if self.direction is not None else None}"

    def __repr__(self) -> str:
        return f"{self.__class__} " + self.shape()


def distance_line_pointcloud_3d(
    e1: torch.Tensor,
    r1: torch.Tensor,
    pc: torch.Tensor,
    reduce: Union[None, str] = None,
) -> torch.Tensor:
    """
    Line to point cloud distance with arbitrary leading dimensions.

    TODO: if cross == (0, 0, 0) -> distance = 0, otherwise NaNs are returned

    https://mathworld.wolfram.com/Point-LineDistance2-Dimensional.html
    Args:
        e1 (torch.Tensor): direction vector of shape (*, B, 1, 3)
        r1 (torch.Tensor): support vector of shape (*, B, 1, 3)
        pc (torch.Tensor): point cloud of shape (*, B, A, 3)
        reduce (Union[None, str]): reduce the distance over all points to a single value using 'mean' or 'min'
    Returns:
        distance of an infinite line to the given points: (*, B) with reduce='mean' or reduce='min', or (*, B, A) if reduce=None
    """

    num_points = pc.shape[-2]
    _sub = r1 - pc  # (*, B, A, 3)

    cross = torch.cross(e1.repeat_interleave(num_points, dim=-2), _sub, dim=-1)  # (*, B, A, 3)

    e1_norm = torch.linalg.norm(e1, dim=-1)
    cross_norm = torch.linalg.norm(cross, dim=-1)

    d = cross_norm / e1_norm
    if reduce == "mean":
        return d.mean(dim=-1)  # (*, B)
    elif reduce == "min":
        return d.min(dim=-1)[0]  # (*, B)

    return d  # (*, B, A)


def distance_point_pointcloud(points: torch.Tensor, pointcloud: torch.Tensor) -> torch.Tensor:
    """Batched version of point-to-pointcloud distance calculation.
    Args:
        points (torch.Tensor): N points in homogeneous coordinates; shape (B, T, 3, S, N)
        pointcloud (torch.Tensor): N_star points for each point cloud; shape (B, T, S, N_star, 2)

    Returns:
        torch.Tensor: minimum distance from each point N to the point cloud; shape (B, T, 1, S, N)
    """

    batch_size, T, _, S, N = points.shape
    batch_size, T, S, N_star, _ = pointcloud.shape

    pointcloud = pointcloud.reshape(batch_size * T * S, N_star, 2)

    points = convert_points_from_homogeneous(
        points.permute(0, 1, 3, 4, 2).reshape(batch_size * T * S, N, 3)
    )

    # cdist signature: (B, P, M), (B, R, M) -> (B, P, R)
    distances = torch.cdist(points, pointcloud, p=2)  # (B*T*S, N, N_star)

    distances = distances.view(batch_size, T, S, N, N_star)
    distances = distances.unsqueeze(-4)

    # distance to the nearest point of the point cloud -> (batch_size, T, 1, S, N)
    distances = distances.min(dim=-1)[0]
    return distances
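`distance_line_pointcloud_3d` computes the standard point-to-line distance `||e x (r - p)|| / ||e||`, batched over leading dimensions. A quick sanity check with the x-axis as the line, where points offset by 1 and 2 in y/z must be at distances 1 and 2 (shapes follow the `(*, B, 1, 3)` / `(*, B, A, 3)` convention from the docstring):

```python
import torch
from tvcalib.utils.linalg import distance_line_pointcloud_3d

e1 = torch.tensor([[[1.0, 0.0, 0.0]]])   # (B=1, 1, 3) direction: the x-axis
r1 = torch.tensor([[[0.0, 0.0, 0.0]]])   # (B=1, 1, 3) support: the origin
pc = torch.tensor([[[5.0, 1.0, 0.0],     # 1 away from the line
                    [-3.0, 0.0, 2.0]]])  # 2 away from the line; (B=1, A=2, 3)

print(distance_line_pointcloud_3d(e1, r1, pc))                 # tensor([[1., 2.]])
print(distance_line_pointcloud_3d(e1, r1, pc, reduce="mean"))  # tensor([1.5000])
```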
tvcalib/utils/objects_3d.py
ADDED
@@ -0,0 +1,1674 @@
1 |
+
from abc import ABCMeta
|
2 |
+
from typing import List
|
3 |
+
import kornia
|
4 |
+
import torch
|
5 |
+
import numpy as np
|
6 |
+
import random
|
7 |
+
from .linalg import LineCollection
|
8 |
+
from torch.utils.data import Dataset, DataLoader
|
9 |
+
from pytorch_lightning import LightningDataModule
|
10 |
+
|
11 |
+
|
12 |
+
class SoccerPitchSN:
|
13 |
+
"""Static class variables that are specified by the rules of the game"""
|
14 |
+
|
15 |
+
GOAL_LINE_TO_PENALTY_MARK = 11.0
|
16 |
+
PENALTY_AREA_WIDTH = 40.32
|
17 |
+
PENALTY_AREA_LENGTH = 16.5
|
18 |
+
GOAL_AREA_WIDTH = 18.32
|
19 |
+
GOAL_AREA_LENGTH = 5.5
|
20 |
+
CENTER_CIRCLE_RADIUS = 9.15
|
21 |
+
GOAL_HEIGHT = 2.44
|
22 |
+
GOAL_LENGTH = 7.32
|
23 |
+
|
24 |
+
lines_classes = [
|
25 |
+
"Big rect. left bottom",
|
26 |
+
"Big rect. left main",
|
27 |
+
"Big rect. left top",
|
28 |
+
"Big rect. right bottom",
|
29 |
+
"Big rect. right main",
|
30 |
+
"Big rect. right top",
|
31 |
+
"Circle central",
|
32 |
+
"Circle left",
|
33 |
+
"Circle right",
|
34 |
+
"Goal left crossbar",
|
35 |
+
"Goal left post left ",
|
36 |
+
"Goal left post right",
|
37 |
+
"Goal right crossbar",
|
38 |
+
"Goal right post left",
|
39 |
+
"Goal right post right",
|
40 |
+
"Goal unknown",
|
41 |
+
"Line unknown",
|
42 |
+
"Middle line",
|
43 |
+
"Side line bottom",
|
44 |
+
"Side line left",
|
45 |
+
"Side line right",
|
46 |
+
"Side line top",
|
47 |
+
"Small rect. left bottom",
|
48 |
+
"Small rect. left main",
|
49 |
+
"Small rect. left top",
|
50 |
+
"Small rect. right bottom",
|
51 |
+
"Small rect. right main",
|
52 |
+
"Small rect. right top",
|
53 |
+
]
|
54 |
+
|
55 |
+
symetric_classes = {
|
56 |
+
"Side line top": "Side line bottom",
|
57 |
+
"Side line bottom": "Side line top",
|
58 |
+
"Side line left": "Side line right",
|
59 |
+
"Middle line": "Middle line",
|
60 |
+
"Side line right": "Side line left",
|
61 |
+
"Big rect. left top": "Big rect. right bottom",
|
62 |
+
"Big rect. left bottom": "Big rect. right top",
|
63 |
+
"Big rect. left main": "Big rect. right main",
|
64 |
+
"Big rect. right top": "Big rect. left bottom",
|
65 |
+
"Big rect. right bottom": "Big rect. left top",
|
66 |
+
"Big rect. right main": "Big rect. left main",
|
67 |
+
"Small rect. left top": "Small rect. right bottom",
|
68 |
+
"Small rect. left bottom": "Small rect. right top",
|
69 |
+
"Small rect. left main": "Small rect. right main",
|
70 |
+
"Small rect. right top": "Small rect. left bottom",
|
71 |
+
"Small rect. right bottom": "Small rect. left top",
|
72 |
+
"Small rect. right main": "Small rect. left main",
|
73 |
+
"Circle left": "Circle right",
|
74 |
+
"Circle central": "Circle central",
|
75 |
+
"Circle right": "Circle left",
|
76 |
+
"Goal left crossbar": "Goal right crossbar",
|
77 |
+
"Goal left post left ": "Goal right post right",
|
78 |
+
"Goal left post right": "Goal right post left",
|
79 |
+
"Goal right crossbar": "Goal left crossbar",
|
80 |
+
"Goal right post left": "Goal left post right",
|
81 |
+
"Goal right post right": "Goal left post left ",
|
82 |
+
"Goal unknown": "Goal unknown",
|
83 |
+
"Line unknown": "Line unknown",
|
84 |
+
}
|
85 |
+
|
86 |
+
# RGB values
|
87 |
+
palette = {
|
88 |
+
"Big rect. left bottom": (127, 0, 0),
|
89 |
+
"Big rect. left main": (102, 102, 102),
|
90 |
+
"Big rect. left top": (0, 0, 127),
|
91 |
+
"Big rect. right bottom": (86, 32, 39),
|
92 |
+
"Big rect. right main": (48, 77, 0),
|
93 |
+
"Big rect. right top": (14, 97, 100),
|
94 |
+
"Circle central": (0, 0, 255),
|
95 |
+
"Circle left": (255, 127, 0),
|
96 |
+
"Circle right": (0, 255, 255),
|
97 |
+
"Goal left crossbar": (255, 255, 200),
|
98 |
+
"Goal left post left ": (165, 255, 0),
|
99 |
+
"Goal left post right": (155, 119, 45),
|
100 |
+
"Goal right crossbar": (86, 32, 139),
|
101 |
+
"Goal right post left": (196, 120, 153),
|
102 |
+
"Goal right post right": (166, 36, 52),
|
103 |
+
"Goal unknown": (0, 0, 0),
|
104 |
+
"Line unknown": (0, 0, 0),
|
105 |
+
"Middle line": (255, 255, 0),
|
106 |
+
"Side line bottom": (255, 0, 255),
|
107 |
+
"Side line left": (0, 255, 150),
|
108 |
+
"Side line right": (0, 230, 0),
|
109 |
+
"Side line top": (230, 0, 0),
|
110 |
+
"Small rect. left bottom": (0, 150, 255),
|
111 |
+
"Small rect. left main": (254, 173, 225),
|
112 |
+
"Small rect. left top": (87, 72, 39),
|
113 |
+
"Small rect. right bottom": (122, 0, 255),
|
114 |
+
"Small rect. right main": (128, 128, 128), # (255, 255, 255)
|
115 |
+
"Small rect. right top": (153, 23, 153),
|
116 |
+
}
|
117 |
+
|
118 |
+
def __init__(self, pitch_length=105.0, pitch_width=68.0):
|
119 |
+
"""
|
120 |
+
Initialize 3D coordinates of all elements of the soccer pitch.
|
121 |
+
:param pitch_length: According to FIFA rules, length belong to [90,120] meters
|
122 |
+
:param pitch_width: According to FIFA rules, length belong to [45,90] meters
|
123 |
+
"""
|
124 |
+
self.PITCH_LENGTH = pitch_length
|
125 |
+
self.PITCH_WIDTH = pitch_width
|
126 |
+
|
127 |
+
self.center_mark = np.array([0, 0, 0], dtype="float")
|
128 |
+
self.halfway_and_bottom_touch_line_mark = np.array([0, pitch_width / 2.0, 0], dtype="float")
|
129 |
+
self.halfway_and_top_touch_line_mark = np.array([0, -pitch_width / 2.0, 0], dtype="float")
|
130 |
+
self.halfway_line_and_center_circle_top_mark = np.array(
|
131 |
+
[0, -SoccerPitchSN.CENTER_CIRCLE_RADIUS, 0], dtype="float"
|
132 |
+
)
|
133 |
+
self.halfway_line_and_center_circle_bottom_mark = np.array(
|
134 |
+
[0, SoccerPitchSN.CENTER_CIRCLE_RADIUS, 0], dtype="float"
|
135 |
+
)
|
136 |
+
self.bottom_right_corner = np.array(
|
137 |
+
[pitch_length / 2.0, pitch_width / 2.0, 0], dtype="float"
|
138 |
+
)
|
139 |
+
self.bottom_left_corner = np.array(
|
140 |
+
[-pitch_length / 2.0, pitch_width / 2.0, 0], dtype="float"
|
141 |
+
)
|
142 |
+
self.top_right_corner = np.array([pitch_length / 2.0, -pitch_width / 2.0, 0], dtype="float")
|
143 |
+
self.top_left_corner = np.array([-pitch_length / 2.0, -34, 0], dtype="float")
|
144 |
+
|
145 |
+
self.left_goal_bottom_left_post = np.array(
|
146 |
+
[-pitch_length / 2.0, SoccerPitchSN.GOAL_LENGTH / 2.0, 0.0], dtype="float"
|
147 |
+
)
|
148 |
+
self.left_goal_top_left_post = np.array(
|
149 |
+
[-pitch_length / 2.0, SoccerPitchSN.GOAL_LENGTH / 2.0, -SoccerPitchSN.GOAL_HEIGHT],
|
150 |
+
dtype="float",
|
151 |
+
)
|
152 |
+
self.left_goal_bottom_right_post = np.array(
|
153 |
+
[-pitch_length / 2.0, -SoccerPitchSN.GOAL_LENGTH / 2.0, 0.0], dtype="float"
|
154 |
+
)
|
155 |
+
self.left_goal_top_right_post = np.array(
|
156 |
+
[-pitch_length / 2.0, -SoccerPitchSN.GOAL_LENGTH / 2.0, -SoccerPitchSN.GOAL_HEIGHT],
|
157 |
+
dtype="float",
|
158 |
+
)
|
159 |
+
|
160 |
+
self.right_goal_bottom_left_post = np.array(
|
161 |
+
[pitch_length / 2.0, -SoccerPitchSN.GOAL_LENGTH / 2.0, 0.0], dtype="float"
|
162 |
+
)
|
163 |
+
self.right_goal_top_left_post = np.array(
|
164 |
+
[pitch_length / 2.0, -SoccerPitchSN.GOAL_LENGTH / 2.0, -SoccerPitchSN.GOAL_HEIGHT],
|
165 |
+
dtype="float",
|
166 |
+
)
|
167 |
+
self.right_goal_bottom_right_post = np.array(
|
168 |
+
[pitch_length / 2.0, SoccerPitchSN.GOAL_LENGTH / 2.0, 0.0], dtype="float"
|
169 |
+
)
|
170 |
+
self.right_goal_top_right_post = np.array(
|
171 |
+
[pitch_length / 2.0, SoccerPitchSN.GOAL_LENGTH / 2.0, -SoccerPitchSN.GOAL_HEIGHT],
|
172 |
+
dtype="float",
|
173 |
+
)
|
174 |
+
|
175 |
+
self.left_penalty_mark = np.array(
|
176 |
+
[-pitch_length / 2.0 + SoccerPitchSN.GOAL_LINE_TO_PENALTY_MARK, 0, 0], dtype="float"
|
177 |
+
)
|
178 |
+
self.right_penalty_mark = np.array(
|
179 |
+
[pitch_length / 2.0 - SoccerPitchSN.GOAL_LINE_TO_PENALTY_MARK, 0, 0], dtype="float"
|
180 |
+
)
|
181 |
+
|
182 |
+
self.left_penalty_area_top_right_corner = np.array(
|
183 |
+
[
|
184 |
+
-pitch_length / 2.0 + SoccerPitchSN.PENALTY_AREA_LENGTH,
|
185 |
+
-SoccerPitchSN.PENALTY_AREA_WIDTH / 2.0,
|
186 |
+
0,
|
187 |
+
],
|
188 |
+
dtype="float",
|
189 |
+
)
|
190 |
+
self.left_penalty_area_top_left_corner = np.array(
|
191 |
+
[-pitch_length / 2.0, -SoccerPitchSN.PENALTY_AREA_WIDTH / 2.0, 0], dtype="float"
|
192 |
+
)
|
193 |
+
self.left_penalty_area_bottom_right_corner = np.array(
|
194 |
+
[
|
195 |
+
-pitch_length / 2.0 + SoccerPitchSN.PENALTY_AREA_LENGTH,
|
196 |
+
SoccerPitchSN.PENALTY_AREA_WIDTH / 2.0,
|
197 |
+
0,
|
198 |
+
],
|
199 |
+
dtype="float",
|
200 |
+
)
|
201 |
+
self.left_penalty_area_bottom_left_corner = np.array(
|
202 |
+
[-pitch_length / 2.0, SoccerPitchSN.PENALTY_AREA_WIDTH / 2.0, 0], dtype="float"
|
203 |
+
)
|
204 |
+
self.right_penalty_area_top_right_corner = np.array(
|
205 |
+
[pitch_length / 2.0, -SoccerPitchSN.PENALTY_AREA_WIDTH / 2.0, 0], dtype="float"
|
206 |
+
)
|
207 |
+
self.right_penalty_area_top_left_corner = np.array(
|
208 |
+
[
|
209 |
+
pitch_length / 2.0 - SoccerPitchSN.PENALTY_AREA_LENGTH,
|
210 |
+
-SoccerPitchSN.PENALTY_AREA_WIDTH / 2.0,
|
211 |
+
0,
|
212 |
+
],
|
213 |
+
dtype="float",
|
214 |
+
)
|
215 |
+
self.right_penalty_area_bottom_right_corner = np.array(
|
216 |
+
[pitch_length / 2.0, SoccerPitchSN.PENALTY_AREA_WIDTH / 2.0, 0], dtype="float"
|
217 |
+
)
|
218 |
+
self.right_penalty_area_bottom_left_corner = np.array(
|
219 |
+
[
|
220 |
+
pitch_length / 2.0 - SoccerPitchSN.PENALTY_AREA_LENGTH,
|
221 |
+
SoccerPitchSN.PENALTY_AREA_WIDTH / 2.0,
|
222 |
+
0,
|
223 |
+
],
|
224 |
+
dtype="float",
|
225 |
+
)
|
226 |
+
|
227 |
+
self.left_goal_area_top_right_corner = np.array(
|
228 |
+
[
|
229 |
+
-pitch_length / 2.0 + SoccerPitchSN.GOAL_AREA_LENGTH,
|
230 |
+
-SoccerPitchSN.GOAL_AREA_WIDTH / 2.0,
|
231 |
+
0,
|
232 |
+
],
|
233 |
+
dtype="float",
|
234 |
+
)
|
235 |
+
self.left_goal_area_top_left_corner = np.array(
|
236 |
+
[-pitch_length / 2.0, -SoccerPitchSN.GOAL_AREA_WIDTH / 2.0, 0], dtype="float"
|
237 |
+
)
|
238 |
+
self.left_goal_area_bottom_right_corner = np.array(
|
239 |
+
[
|
240 |
+
-pitch_length / 2.0 + SoccerPitchSN.GOAL_AREA_LENGTH,
|
241 |
+
SoccerPitchSN.GOAL_AREA_WIDTH / 2.0,
|
242 |
+
0,
|
243 |
+
],
|
244 |
+
dtype="float",
|
245 |
+
)
|
246 |
+
self.left_goal_area_bottom_left_corner = np.array(
|
247 |
+
[-pitch_length / 2.0, SoccerPitchSN.GOAL_AREA_WIDTH / 2.0, 0], dtype="float"
|
248 |
+
)
|
249 |
+
self.right_goal_area_top_right_corner = np.array(
|
250 |
+
[pitch_length / 2.0, -SoccerPitchSN.GOAL_AREA_WIDTH / 2.0, 0], dtype="float"
|
251 |
+
)
|
252 |
+
self.right_goal_area_top_left_corner = np.array(
|
253 |
+
[
|
254 |
+
pitch_length / 2.0 - SoccerPitchSN.GOAL_AREA_LENGTH,
|
255 |
+
-SoccerPitchSN.GOAL_AREA_WIDTH / 2.0,
|
256 |
+
0,
|
257 |
+
],
|
258 |
+
dtype="float",
|
259 |
+
)
|
260 |
+
self.right_goal_area_bottom_right_corner = np.array(
|
261 |
+
[pitch_length / 2.0, SoccerPitchSN.GOAL_AREA_WIDTH / 2.0, 0], dtype="float"
|
262 |
+
)
|
263 |
+
self.right_goal_area_bottom_left_corner = np.array(
|
264 |
+
[
|
265 |
+
pitch_length / 2.0 - SoccerPitchSN.GOAL_AREA_LENGTH,
|
266 |
+
SoccerPitchSN.GOAL_AREA_WIDTH / 2.0,
|
267 |
+
0,
|
268 |
+
],
|
269 |
+
dtype="float",
|
270 |
+
)
|
271 |
+
|
272 |
+
x = -pitch_length / 2.0 + SoccerPitchSN.PENALTY_AREA_LENGTH
|
273 |
+
dx = SoccerPitchSN.PENALTY_AREA_LENGTH - SoccerPitchSN.GOAL_LINE_TO_PENALTY_MARK
|
274 |
+
y = -np.sqrt(
|
275 |
+
SoccerPitchSN.CENTER_CIRCLE_RADIUS * SoccerPitchSN.CENTER_CIRCLE_RADIUS - dx * dx
|
276 |
+
)
|
277 |
+
self.top_left_16M_penalty_arc_mark = np.array([x, y, 0], dtype="float")
|
278 |
+
|
279 |
+
x = pitch_length / 2.0 - SoccerPitchSN.PENALTY_AREA_LENGTH
|
280 |
+
dx = SoccerPitchSN.PENALTY_AREA_LENGTH - SoccerPitchSN.GOAL_LINE_TO_PENALTY_MARK
|
281 |
+
y = -np.sqrt(
|
282 |
+
SoccerPitchSN.CENTER_CIRCLE_RADIUS * SoccerPitchSN.CENTER_CIRCLE_RADIUS - dx * dx
|
283 |
+
)
|
284 |
+
self.top_right_16M_penalty_arc_mark = np.array([x, y, 0], dtype="float")
|
285 |
+
|
286 |
+
x = -pitch_length / 2.0 + SoccerPitchSN.PENALTY_AREA_LENGTH
|
287 |
+
dx = SoccerPitchSN.PENALTY_AREA_LENGTH - SoccerPitchSN.GOAL_LINE_TO_PENALTY_MARK
|
288 |
+
y = np.sqrt(
|
289 |
+
SoccerPitchSN.CENTER_CIRCLE_RADIUS * SoccerPitchSN.CENTER_CIRCLE_RADIUS - dx * dx
|
290 |
+
)
|
291 |
+
self.bottom_left_16M_penalty_arc_mark = np.array([x, y, 0], dtype="float")
|
292 |
+
|
293 |
+
x = pitch_length / 2.0 - SoccerPitchSN.PENALTY_AREA_LENGTH
|
294 |
+
dx = SoccerPitchSN.PENALTY_AREA_LENGTH - SoccerPitchSN.GOAL_LINE_TO_PENALTY_MARK
|
295 |
+
y = np.sqrt(
|
296 |
+
SoccerPitchSN.CENTER_CIRCLE_RADIUS * SoccerPitchSN.CENTER_CIRCLE_RADIUS - dx * dx
|
297 |
+
)
|
298 |
+
self.bottom_right_16M_penalty_arc_mark = np.array([x, y, 0], dtype="float")
|
299 |
+
|
300 |
+
        # self.set_elevations(elevation)

        self.point_dict = {}
        self.point_dict["CENTER_MARK"] = self.center_mark
        self.point_dict["L_PENALTY_MARK"] = self.left_penalty_mark
        self.point_dict["R_PENALTY_MARK"] = self.right_penalty_mark
        self.point_dict["TL_PITCH_CORNER"] = self.top_left_corner
        self.point_dict["BL_PITCH_CORNER"] = self.bottom_left_corner
        self.point_dict["TR_PITCH_CORNER"] = self.top_right_corner
        self.point_dict["BR_PITCH_CORNER"] = self.bottom_right_corner
        self.point_dict["L_PENALTY_AREA_TL_CORNER"] = self.left_penalty_area_top_left_corner
        self.point_dict["L_PENALTY_AREA_TR_CORNER"] = self.left_penalty_area_top_right_corner
        self.point_dict["L_PENALTY_AREA_BL_CORNER"] = self.left_penalty_area_bottom_left_corner
        self.point_dict["L_PENALTY_AREA_BR_CORNER"] = self.left_penalty_area_bottom_right_corner

        self.point_dict["R_PENALTY_AREA_TL_CORNER"] = self.right_penalty_area_top_left_corner
        self.point_dict["R_PENALTY_AREA_TR_CORNER"] = self.right_penalty_area_top_right_corner
        self.point_dict["R_PENALTY_AREA_BL_CORNER"] = self.right_penalty_area_bottom_left_corner
        self.point_dict["R_PENALTY_AREA_BR_CORNER"] = self.right_penalty_area_bottom_right_corner

        self.point_dict["L_GOAL_AREA_TL_CORNER"] = self.left_goal_area_top_left_corner
        self.point_dict["L_GOAL_AREA_TR_CORNER"] = self.left_goal_area_top_right_corner
        self.point_dict["L_GOAL_AREA_BL_CORNER"] = self.left_goal_area_bottom_left_corner
        self.point_dict["L_GOAL_AREA_BR_CORNER"] = self.left_goal_area_bottom_right_corner

        self.point_dict["R_GOAL_AREA_TL_CORNER"] = self.right_goal_area_top_left_corner
        self.point_dict["R_GOAL_AREA_TR_CORNER"] = self.right_goal_area_top_right_corner
        self.point_dict["R_GOAL_AREA_BL_CORNER"] = self.right_goal_area_bottom_left_corner
        self.point_dict["R_GOAL_AREA_BR_CORNER"] = self.right_goal_area_bottom_right_corner

        self.point_dict["L_GOAL_TL_POST"] = self.left_goal_top_left_post
        self.point_dict["L_GOAL_TR_POST"] = self.left_goal_top_right_post
        self.point_dict["L_GOAL_BL_POST"] = self.left_goal_bottom_left_post
        self.point_dict["L_GOAL_BR_POST"] = self.left_goal_bottom_right_post

        self.point_dict["R_GOAL_TL_POST"] = self.right_goal_top_left_post
        self.point_dict["R_GOAL_TR_POST"] = self.right_goal_top_right_post
        self.point_dict["R_GOAL_BL_POST"] = self.right_goal_bottom_left_post
        self.point_dict["R_GOAL_BR_POST"] = self.right_goal_bottom_right_post

        self.point_dict[
            "T_TOUCH_AND_HALFWAY_LINES_INTERSECTION"
        ] = self.halfway_and_top_touch_line_mark
        self.point_dict[
            "B_TOUCH_AND_HALFWAY_LINES_INTERSECTION"
        ] = self.halfway_and_bottom_touch_line_mark
        self.point_dict[
            "T_HALFWAY_LINE_AND_CENTER_CIRCLE_INTERSECTION"
        ] = self.halfway_line_and_center_circle_top_mark
        self.point_dict[
            "B_HALFWAY_LINE_AND_CENTER_CIRCLE_INTERSECTION"
        ] = self.halfway_line_and_center_circle_bottom_mark
        self.point_dict[
            "TL_16M_LINE_AND_PENALTY_ARC_INTERSECTION"
        ] = self.top_left_16M_penalty_arc_mark
        self.point_dict[
            "BL_16M_LINE_AND_PENALTY_ARC_INTERSECTION"
        ] = self.bottom_left_16M_penalty_arc_mark
        self.point_dict[
            "TR_16M_LINE_AND_PENALTY_ARC_INTERSECTION"
        ] = self.top_right_16M_penalty_arc_mark
        self.point_dict[
            "BR_16M_LINE_AND_PENALTY_ARC_INTERSECTION"
        ] = self.bottom_right_16M_penalty_arc_mark

        self.line_extremities = dict()
        self.line_extremities["Big rect. left bottom"] = (
            self.point_dict["L_PENALTY_AREA_BL_CORNER"],
            self.point_dict["L_PENALTY_AREA_BR_CORNER"],
        )
        self.line_extremities["Big rect. left top"] = (
            self.point_dict["L_PENALTY_AREA_TL_CORNER"],
            self.point_dict["L_PENALTY_AREA_TR_CORNER"],
        )
        self.line_extremities["Big rect. left main"] = (
            self.point_dict["L_PENALTY_AREA_TR_CORNER"],
            self.point_dict["L_PENALTY_AREA_BR_CORNER"],
        )
        self.line_extremities["Big rect. right bottom"] = (
            self.point_dict["R_PENALTY_AREA_BL_CORNER"],
            self.point_dict["R_PENALTY_AREA_BR_CORNER"],
        )
        self.line_extremities["Big rect. right top"] = (
            self.point_dict["R_PENALTY_AREA_TL_CORNER"],
            self.point_dict["R_PENALTY_AREA_TR_CORNER"],
        )
        self.line_extremities["Big rect. right main"] = (
            self.point_dict["R_PENALTY_AREA_TL_CORNER"],
            self.point_dict["R_PENALTY_AREA_BL_CORNER"],
        )

        self.line_extremities["Small rect. left bottom"] = (
            self.point_dict["L_GOAL_AREA_BL_CORNER"],
            self.point_dict["L_GOAL_AREA_BR_CORNER"],
        )
        self.line_extremities["Small rect. left top"] = (
            self.point_dict["L_GOAL_AREA_TL_CORNER"],
            self.point_dict["L_GOAL_AREA_TR_CORNER"],
        )
        self.line_extremities["Small rect. left main"] = (
            self.point_dict["L_GOAL_AREA_TR_CORNER"],
            self.point_dict["L_GOAL_AREA_BR_CORNER"],
        )
        self.line_extremities["Small rect. right bottom"] = (
            self.point_dict["R_GOAL_AREA_BL_CORNER"],
            self.point_dict["R_GOAL_AREA_BR_CORNER"],
        )
        self.line_extremities["Small rect. right top"] = (
            self.point_dict["R_GOAL_AREA_TL_CORNER"],
            self.point_dict["R_GOAL_AREA_TR_CORNER"],
        )
        self.line_extremities["Small rect. right main"] = (
            self.point_dict["R_GOAL_AREA_TL_CORNER"],
            self.point_dict["R_GOAL_AREA_BL_CORNER"],
        )

        self.line_extremities["Side line top"] = (
            self.point_dict["TL_PITCH_CORNER"],
            self.point_dict["TR_PITCH_CORNER"],
        )
        self.line_extremities["Side line bottom"] = (
            self.point_dict["BL_PITCH_CORNER"],
            self.point_dict["BR_PITCH_CORNER"],
        )
        self.line_extremities["Side line left"] = (
            self.point_dict["TL_PITCH_CORNER"],
            self.point_dict["BL_PITCH_CORNER"],
        )
        self.line_extremities["Side line right"] = (
            self.point_dict["TR_PITCH_CORNER"],
            self.point_dict["BR_PITCH_CORNER"],
        )
        self.line_extremities["Middle line"] = (
            self.point_dict["T_TOUCH_AND_HALFWAY_LINES_INTERSECTION"],
            self.point_dict["B_TOUCH_AND_HALFWAY_LINES_INTERSECTION"],
        )

        self.line_extremities["Goal left crossbar"] = (
            self.point_dict["L_GOAL_TR_POST"],
            self.point_dict["L_GOAL_TL_POST"],
        )
        self.line_extremities["Goal left post left "] = (
            self.point_dict["L_GOAL_TL_POST"],
            self.point_dict["L_GOAL_BL_POST"],
        )
        self.line_extremities["Goal left post right"] = (
            self.point_dict["L_GOAL_TR_POST"],
            self.point_dict["L_GOAL_BR_POST"],
        )

        self.line_extremities["Goal right crossbar"] = (
            self.point_dict["R_GOAL_TL_POST"],
            self.point_dict["R_GOAL_TR_POST"],
        )
        self.line_extremities["Goal right post left"] = (
            self.point_dict["R_GOAL_TL_POST"],
            self.point_dict["R_GOAL_BL_POST"],
        )
        self.line_extremities["Goal right post right"] = (
            self.point_dict["R_GOAL_TR_POST"],
            self.point_dict["R_GOAL_BR_POST"],
        )
        self.line_extremities["Circle right"] = (
            self.point_dict["TR_16M_LINE_AND_PENALTY_ARC_INTERSECTION"],
            self.point_dict["BR_16M_LINE_AND_PENALTY_ARC_INTERSECTION"],
        )
        self.line_extremities["Circle left"] = (
            self.point_dict["TL_16M_LINE_AND_PENALTY_ARC_INTERSECTION"],
            self.point_dict["BL_16M_LINE_AND_PENALTY_ARC_INTERSECTION"],
        )

        self.line_extremities_keys = dict()
        self.line_extremities_keys["Big rect. left bottom"] = (
            "L_PENALTY_AREA_BL_CORNER",
            "L_PENALTY_AREA_BR_CORNER",
        )
        self.line_extremities_keys["Big rect. left top"] = (
            "L_PENALTY_AREA_TL_CORNER",
            "L_PENALTY_AREA_TR_CORNER",
        )
        self.line_extremities_keys["Big rect. left main"] = (
            "L_PENALTY_AREA_TR_CORNER",
            "L_PENALTY_AREA_BR_CORNER",
        )
        self.line_extremities_keys["Big rect. right bottom"] = (
            "R_PENALTY_AREA_BL_CORNER",
            "R_PENALTY_AREA_BR_CORNER",
        )
        self.line_extremities_keys["Big rect. right top"] = (
            "R_PENALTY_AREA_TL_CORNER",
            "R_PENALTY_AREA_TR_CORNER",
        )
        self.line_extremities_keys["Big rect. right main"] = (
            "R_PENALTY_AREA_TL_CORNER",
            "R_PENALTY_AREA_BL_CORNER",
        )

        self.line_extremities_keys["Small rect. left bottom"] = (
            "L_GOAL_AREA_BL_CORNER",
            "L_GOAL_AREA_BR_CORNER",
        )
        self.line_extremities_keys["Small rect. left top"] = (
            "L_GOAL_AREA_TL_CORNER",
            "L_GOAL_AREA_TR_CORNER",
        )
        self.line_extremities_keys["Small rect. left main"] = (
            "L_GOAL_AREA_TR_CORNER",
            "L_GOAL_AREA_BR_CORNER",
        )
        self.line_extremities_keys["Small rect. right bottom"] = (
            "R_GOAL_AREA_BL_CORNER",
            "R_GOAL_AREA_BR_CORNER",
        )
        self.line_extremities_keys["Small rect. right top"] = (
            "R_GOAL_AREA_TL_CORNER",
            "R_GOAL_AREA_TR_CORNER",
        )
        self.line_extremities_keys["Small rect. right main"] = (
            "R_GOAL_AREA_TL_CORNER",
            "R_GOAL_AREA_BL_CORNER",
        )

        self.line_extremities_keys["Side line top"] = ("TL_PITCH_CORNER", "TR_PITCH_CORNER")
        self.line_extremities_keys["Side line bottom"] = ("BL_PITCH_CORNER", "BR_PITCH_CORNER")
        self.line_extremities_keys["Side line left"] = ("TL_PITCH_CORNER", "BL_PITCH_CORNER")
        self.line_extremities_keys["Side line right"] = ("TR_PITCH_CORNER", "BR_PITCH_CORNER")
        self.line_extremities_keys["Middle line"] = (
            "T_TOUCH_AND_HALFWAY_LINES_INTERSECTION",
            "B_TOUCH_AND_HALFWAY_LINES_INTERSECTION",
        )

        self.line_extremities_keys["Goal left crossbar"] = ("L_GOAL_TR_POST", "L_GOAL_TL_POST")
        self.line_extremities_keys["Goal left post left "] = ("L_GOAL_TL_POST", "L_GOAL_BL_POST")
        self.line_extremities_keys["Goal left post right"] = ("L_GOAL_TR_POST", "L_GOAL_BR_POST")

        self.line_extremities_keys["Goal right crossbar"] = ("R_GOAL_TL_POST", "R_GOAL_TR_POST")
        self.line_extremities_keys["Goal right post left"] = ("R_GOAL_TL_POST", "R_GOAL_BL_POST")
        self.line_extremities_keys["Goal right post right"] = ("R_GOAL_TR_POST", "R_GOAL_BR_POST")
        self.line_extremities_keys["Circle right"] = (
            "TR_16M_LINE_AND_PENALTY_ARC_INTERSECTION",
            "BR_16M_LINE_AND_PENALTY_ARC_INTERSECTION",
        )
        self.line_extremities_keys["Circle left"] = (
            "TL_16M_LINE_AND_PENALTY_ARC_INTERSECTION",
            "BL_16M_LINE_AND_PENALTY_ARC_INTERSECTION",
        )

    def points(self):
        return [
            self.center_mark,
            self.halfway_and_bottom_touch_line_mark,
            self.halfway_and_top_touch_line_mark,
            self.halfway_line_and_center_circle_top_mark,
            self.halfway_line_and_center_circle_bottom_mark,
            self.bottom_right_corner,
            self.bottom_left_corner,
            self.top_right_corner,
            self.top_left_corner,
            self.left_penalty_mark,
            self.right_penalty_mark,
            self.left_penalty_area_top_right_corner,
            self.left_penalty_area_top_left_corner,
            self.left_penalty_area_bottom_right_corner,
            self.left_penalty_area_bottom_left_corner,
            self.right_penalty_area_top_right_corner,
            self.right_penalty_area_top_left_corner,
            self.right_penalty_area_bottom_right_corner,
            self.right_penalty_area_bottom_left_corner,
            self.left_goal_area_top_right_corner,
            self.left_goal_area_top_left_corner,
            self.left_goal_area_bottom_right_corner,
            self.left_goal_area_bottom_left_corner,
            self.right_goal_area_top_right_corner,
            self.right_goal_area_top_left_corner,
            self.right_goal_area_bottom_right_corner,
            self.right_goal_area_bottom_left_corner,
            self.top_left_16M_penalty_arc_mark,
            self.top_right_16M_penalty_arc_mark,
            self.bottom_left_16M_penalty_arc_mark,
            self.bottom_right_16M_penalty_arc_mark,
            self.left_goal_top_left_post,
            self.left_goal_top_right_post,
            self.left_goal_bottom_left_post,
            self.left_goal_bottom_right_post,
            self.right_goal_top_left_post,
            self.right_goal_top_right_post,
            self.right_goal_bottom_left_post,
            self.right_goal_bottom_right_post,
        ]

    def sample_field_points(self, dist=0.1, dist_circles=0.2):
        """
        Samples each pitch element every `dist` meters and returns a dictionary
        associating the class of the element with the list of points sampled along it.
        :param dist: the distance in meters between two consecutive sampled points
        :param dist_circles: the distance in meters between two consecutive sampled points on circles
        :return: a dictionary associating the class of the element with a list of sampled points
        """
        polylines = dict()
        center = self.point_dict["CENTER_MARK"]
        fromAngle = 0.0
        toAngle = 2 * np.pi

        if toAngle < fromAngle:
            toAngle += 2 * np.pi
        x1 = center[0] + np.cos(fromAngle) * SoccerPitchSN.CENTER_CIRCLE_RADIUS
        y1 = center[1] + np.sin(fromAngle) * SoccerPitchSN.CENTER_CIRCLE_RADIUS
        z1 = 0.0
        point = np.array((x1, y1, z1))
        polyline = [point]
        length = SoccerPitchSN.CENTER_CIRCLE_RADIUS * (toAngle - fromAngle)
        nb_pts = int(length / dist_circles)
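        # arc length s = r * theta, so a spacing of dist_circles meters along the
        # circle corresponds to an angular step of dist_circles / r radians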
        dangle = dist_circles / SoccerPitchSN.CENTER_CIRCLE_RADIUS
        for i in range(1, nb_pts):
            angle = fromAngle + i * dangle
            x = center[0] + np.cos(angle) * SoccerPitchSN.CENTER_CIRCLE_RADIUS
            y = center[1] + np.sin(angle) * SoccerPitchSN.CENTER_CIRCLE_RADIUS
            z = 0
            point = np.array((x, y, z))
            polyline.append(point)
        polylines["Circle central"] = polyline
        for key, line in self.line_extremities.items():

            if "Circle" in key:
                if key == "Circle right":
                    top = self.point_dict["TR_16M_LINE_AND_PENALTY_ARC_INTERSECTION"]
                    bottom = self.point_dict["BR_16M_LINE_AND_PENALTY_ARC_INTERSECTION"]
                    center = self.point_dict["R_PENALTY_MARK"]
                    toAngle = np.arctan2(top[1] - center[1], top[0] - center[0]) + 2 * np.pi
                    fromAngle = np.arctan2(bottom[1] - center[1], bottom[0] - center[0]) + 2 * np.pi
                elif key == "Circle left":
                    top = self.point_dict["TL_16M_LINE_AND_PENALTY_ARC_INTERSECTION"]
                    bottom = self.point_dict["BL_16M_LINE_AND_PENALTY_ARC_INTERSECTION"]
                    center = self.point_dict["L_PENALTY_MARK"]
                    fromAngle = np.arctan2(top[1] - center[1], top[0] - center[0]) + 2 * np.pi
                    toAngle = np.arctan2(bottom[1] - center[1], bottom[0] - center[0]) + 2 * np.pi
                if toAngle < fromAngle:
                    toAngle += 2 * np.pi
                x1 = center[0] + np.cos(fromAngle) * SoccerPitchSN.CENTER_CIRCLE_RADIUS
                y1 = center[1] + np.sin(fromAngle) * SoccerPitchSN.CENTER_CIRCLE_RADIUS
                z1 = 0.0
                xn = center[0] + np.cos(toAngle) * SoccerPitchSN.CENTER_CIRCLE_RADIUS
                yn = center[1] + np.sin(toAngle) * SoccerPitchSN.CENTER_CIRCLE_RADIUS
                zn = 0.0
                start = np.array((x1, y1, z1))
                end = np.array((xn, yn, zn))
                polyline = [start]
                length = SoccerPitchSN.CENTER_CIRCLE_RADIUS * (toAngle - fromAngle)
                nb_pts = int(length / dist_circles)
                dangle = dist_circles / SoccerPitchSN.CENTER_CIRCLE_RADIUS
                for i in range(1, nb_pts + 1):
                    angle = fromAngle + i * dangle
                    x = center[0] + np.cos(angle) * SoccerPitchSN.CENTER_CIRCLE_RADIUS
                    y = center[1] + np.sin(angle) * SoccerPitchSN.CENTER_CIRCLE_RADIUS
                    z = 0
                    point = np.array((x, y, z))
                    polyline.append(point)
                polyline.append(end)
                polylines[key] = polyline
            else:
                start = line[0]
                end = line[1]

                polyline = [start]

                total_dist = np.sqrt(np.sum(np.square(start - end)))
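                # number of evenly spaced interior points between the two extremities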
                nb_pts = int(total_dist / dist - 1)

                v = end - start
                v /= np.linalg.norm(v)
                prev_pt = start
                for i in range(nb_pts):
                    pt = prev_pt + dist * v
                    prev_pt = pt
                    polyline.append(pt)
                polyline.append(end)
                polylines[key] = polyline
        return polylines

    def get_2d_homogeneous_line(self, line_name):
        """
        For lines belonging to the pitch lawn plane, returns the coefficients of its 2D homogeneous line equation.
        :param line_name: name of the line
        :return: an array containing the three coefficients of the line
        """
        # ensure line in football pitch plane
        if (
            line_name in self.line_extremities.keys()
            and "post" not in line_name
            and "crossbar" not in line_name
            and "Circle" not in line_name
        ):
            extremities = self.line_extremities[line_name]
            p1 = np.array([extremities[0][0], extremities[0][1], 1], dtype="float")
            p2 = np.array([extremities[1][0], extremities[1][1], 1], dtype="float")
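            # in homogeneous coordinates, the line through two points is their cross product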
            line = np.cross(p1, p2)

            return line
        return None


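# Minimal usage sketch (illustrative, not part of the original file):
#   pitch = SoccerPitchSN()
#   polylines = pitch.sample_field_points(dist=0.5)        # Dict[str, List[np.ndarray]]
#   coeffs = pitch.get_2d_homogeneous_line("Middle line")  # (a, b, c) with ax + by + c = 0
#   keypoints = pitch.points()                             # list of 3D landmark coordinates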
class SoccerPitchSNCircleCentralSplit:
    """Static class variables that are specified by the rules of the game"""

    GOAL_LINE_TO_PENALTY_MARK = 11.0
    PENALTY_AREA_WIDTH = 40.32
    PENALTY_AREA_LENGTH = 16.5
    GOAL_AREA_WIDTH = 18.32
    GOAL_AREA_LENGTH = 5.5
    CENTER_CIRCLE_RADIUS = 9.15
    GOAL_HEIGHT = 2.44
    GOAL_LENGTH = 7.32

    lines_classes = [
        "Big rect. left bottom",
        "Big rect. left main",
        "Big rect. left top",
        "Big rect. right bottom",
        "Big rect. right main",
        "Big rect. right top",
        "Circle central left",
        "Circle central right",
        "Circle left",
        "Circle right",
        "Goal left crossbar",
        "Goal left post left ",
        "Goal left post right",
        "Goal right crossbar",
        "Goal right post left",
        "Goal right post right",
        "Goal unknown",
        "Line unknown",
        "Middle line",
        "Side line bottom",
        "Side line left",
        "Side line right",
        "Side line top",
        "Small rect. left bottom",
        "Small rect. left main",
        "Small rect. left top",
        "Small rect. right bottom",
        "Small rect. right main",
        "Small rect. right top",
    ]

    symetric_classes = {
        "Side line top": "Side line bottom",
        "Side line bottom": "Side line top",
        "Side line left": "Side line right",
        "Middle line": "Middle line",
        "Side line right": "Side line left",
        "Big rect. left top": "Big rect. right bottom",
        "Big rect. left bottom": "Big rect. right top",
        "Big rect. left main": "Big rect. right main",
        "Big rect. right top": "Big rect. left bottom",
        "Big rect. right bottom": "Big rect. left top",
        "Big rect. right main": "Big rect. left main",
        "Small rect. left top": "Small rect. right bottom",
        "Small rect. left bottom": "Small rect. right top",
        "Small rect. left main": "Small rect. right main",
        "Small rect. right top": "Small rect. left bottom",
        "Small rect. right bottom": "Small rect. left top",
        "Small rect. right main": "Small rect. left main",
        "Circle left": "Circle right",
        "Circle central left": "Circle central right",
        "Circle central right": "Circle central left",
        "Circle right": "Circle left",
        "Goal left crossbar": "Goal right crossbar",
        "Goal left post left ": "Goal right post right",
        "Goal left post right": "Goal right post left",
        "Goal right crossbar": "Goal left crossbar",
        "Goal right post left": "Goal left post right",
        "Goal right post right": "Goal left post left ",
        "Goal unknown": "Goal unknown",
        "Line unknown": "Line unknown",
    }

    # RGB values
    palette = {
        "Big rect. left bottom": (127, 0, 0),
        "Big rect. left main": (102, 102, 102),
        "Big rect. left top": (0, 0, 127),
        "Big rect. right bottom": (86, 32, 39),
        "Big rect. right main": (48, 77, 0),
        "Big rect. right top": (14, 97, 100),
        "Circle central left": (0, 0, 255),
        "Circle central right": (0, 255, 0),
        "Circle left": (255, 127, 0),
        "Circle right": (0, 255, 255),
        "Goal left crossbar": (255, 255, 200),
        "Goal left post left ": (165, 255, 0),
        "Goal left post right": (155, 119, 45),
        "Goal right crossbar": (86, 32, 139),
        "Goal right post left": (196, 120, 153),
        "Goal right post right": (166, 36, 52),
        "Goal unknown": (0, 0, 0),
        "Line unknown": (0, 0, 0),
        "Middle line": (255, 255, 0),
        "Side line bottom": (255, 0, 255),
        "Side line left": (0, 255, 150),
        "Side line right": (0, 230, 0),
        "Side line top": (230, 0, 0),
        "Small rect. left bottom": (0, 150, 255),
        "Small rect. left main": (254, 173, 225),
        "Small rect. left top": (87, 72, 39),
        "Small rect. right bottom": (122, 0, 255),
        "Small rect. right main": (255, 255, 255),
        "Small rect. right top": (153, 23, 153),
    }

    def __init__(self, pitch_length=105.0, pitch_width=68.0):
        """
        Initialize 3D coordinates of all elements of the soccer pitch.
        :param pitch_length: According to FIFA rules, the length belongs to [90, 120] meters
        :param pitch_width: According to FIFA rules, the width belongs to [45, 90] meters
        """
        self.PITCH_LENGTH = pitch_length
        self.PITCH_WIDTH = pitch_width

        self.center_mark = np.array([0, 0, 0], dtype="float")
        self.halfway_and_bottom_touch_line_mark = np.array([0, pitch_width / 2.0, 0], dtype="float")
        self.halfway_and_top_touch_line_mark = np.array([0, -pitch_width / 2.0, 0], dtype="float")
        self.halfway_line_and_center_circle_top_mark = np.array(
            [0, -SoccerPitchSNCircleCentralSplit.CENTER_CIRCLE_RADIUS, 0], dtype="float"
        )
        self.halfway_line_and_center_circle_bottom_mark = np.array(
            [0, SoccerPitchSNCircleCentralSplit.CENTER_CIRCLE_RADIUS, 0], dtype="float"
        )
        self.bottom_right_corner = np.array(
            [pitch_length / 2.0, pitch_width / 2.0, 0], dtype="float"
        )
        self.bottom_left_corner = np.array(
            [-pitch_length / 2.0, pitch_width / 2.0, 0], dtype="float"
        )
        self.top_right_corner = np.array([pitch_length / 2.0, -pitch_width / 2.0, 0], dtype="float")
        self.top_left_corner = np.array([-pitch_length / 2.0, -pitch_width / 2.0, 0], dtype="float")

        self.left_goal_bottom_left_post = np.array(
            [-pitch_length / 2.0, SoccerPitchSNCircleCentralSplit.GOAL_LENGTH / 2.0, 0.0],
            dtype="float",
        )
        self.left_goal_top_left_post = np.array(
            [
                -pitch_length / 2.0,
                SoccerPitchSNCircleCentralSplit.GOAL_LENGTH / 2.0,
                -SoccerPitchSNCircleCentralSplit.GOAL_HEIGHT,
            ],
            dtype="float",
        )
        self.left_goal_bottom_right_post = np.array(
            [-pitch_length / 2.0, -SoccerPitchSNCircleCentralSplit.GOAL_LENGTH / 2.0, 0.0],
            dtype="float",
        )
        self.left_goal_top_right_post = np.array(
            [
                -pitch_length / 2.0,
                -SoccerPitchSNCircleCentralSplit.GOAL_LENGTH / 2.0,
                -SoccerPitchSNCircleCentralSplit.GOAL_HEIGHT,
            ],
            dtype="float",
        )

        self.right_goal_bottom_left_post = np.array(
            [pitch_length / 2.0, -SoccerPitchSNCircleCentralSplit.GOAL_LENGTH / 2.0, 0.0],
            dtype="float",
        )
        self.right_goal_top_left_post = np.array(
            [
                pitch_length / 2.0,
                -SoccerPitchSNCircleCentralSplit.GOAL_LENGTH / 2.0,
                -SoccerPitchSNCircleCentralSplit.GOAL_HEIGHT,
            ],
            dtype="float",
        )
        self.right_goal_bottom_right_post = np.array(
            [pitch_length / 2.0, SoccerPitchSNCircleCentralSplit.GOAL_LENGTH / 2.0, 0.0],
            dtype="float",
        )
        self.right_goal_top_right_post = np.array(
            [
                pitch_length / 2.0,
                SoccerPitchSNCircleCentralSplit.GOAL_LENGTH / 2.0,
                -SoccerPitchSNCircleCentralSplit.GOAL_HEIGHT,
            ],
            dtype="float",
        )

        self.left_penalty_mark = np.array(
            [-pitch_length / 2.0 + SoccerPitchSNCircleCentralSplit.GOAL_LINE_TO_PENALTY_MARK, 0, 0],
            dtype="float",
        )
        self.right_penalty_mark = np.array(
            [pitch_length / 2.0 - SoccerPitchSNCircleCentralSplit.GOAL_LINE_TO_PENALTY_MARK, 0, 0],
            dtype="float",
        )

        self.left_penalty_area_top_right_corner = np.array(
            [
                -pitch_length / 2.0 + SoccerPitchSNCircleCentralSplit.PENALTY_AREA_LENGTH,
                -SoccerPitchSNCircleCentralSplit.PENALTY_AREA_WIDTH / 2.0,
                0,
            ],
            dtype="float",
        )
        self.left_penalty_area_top_left_corner = np.array(
            [-pitch_length / 2.0, -SoccerPitchSNCircleCentralSplit.PENALTY_AREA_WIDTH / 2.0, 0],
            dtype="float",
        )
        self.left_penalty_area_bottom_right_corner = np.array(
            [
                -pitch_length / 2.0 + SoccerPitchSNCircleCentralSplit.PENALTY_AREA_LENGTH,
                SoccerPitchSNCircleCentralSplit.PENALTY_AREA_WIDTH / 2.0,
                0,
            ],
            dtype="float",
        )
        self.left_penalty_area_bottom_left_corner = np.array(
            [-pitch_length / 2.0, SoccerPitchSNCircleCentralSplit.PENALTY_AREA_WIDTH / 2.0, 0],
            dtype="float",
        )
        self.right_penalty_area_top_right_corner = np.array(
            [pitch_length / 2.0, -SoccerPitchSNCircleCentralSplit.PENALTY_AREA_WIDTH / 2.0, 0],
            dtype="float",
        )
        self.right_penalty_area_top_left_corner = np.array(
            [
                pitch_length / 2.0 - SoccerPitchSNCircleCentralSplit.PENALTY_AREA_LENGTH,
                -SoccerPitchSNCircleCentralSplit.PENALTY_AREA_WIDTH / 2.0,
                0,
            ],
            dtype="float",
        )
        self.right_penalty_area_bottom_right_corner = np.array(
            [pitch_length / 2.0, SoccerPitchSNCircleCentralSplit.PENALTY_AREA_WIDTH / 2.0, 0],
            dtype="float",
        )
        self.right_penalty_area_bottom_left_corner = np.array(
            [
                pitch_length / 2.0 - SoccerPitchSNCircleCentralSplit.PENALTY_AREA_LENGTH,
                SoccerPitchSNCircleCentralSplit.PENALTY_AREA_WIDTH / 2.0,
                0,
            ],
            dtype="float",
        )

        self.left_goal_area_top_right_corner = np.array(
            [
                -pitch_length / 2.0 + SoccerPitchSNCircleCentralSplit.GOAL_AREA_LENGTH,
                -SoccerPitchSNCircleCentralSplit.GOAL_AREA_WIDTH / 2.0,
                0,
            ],
            dtype="float",
        )
        self.left_goal_area_top_left_corner = np.array(
            [-pitch_length / 2.0, -SoccerPitchSNCircleCentralSplit.GOAL_AREA_WIDTH / 2.0, 0],
            dtype="float",
        )
        self.left_goal_area_bottom_right_corner = np.array(
            [
                -pitch_length / 2.0 + SoccerPitchSNCircleCentralSplit.GOAL_AREA_LENGTH,
                SoccerPitchSNCircleCentralSplit.GOAL_AREA_WIDTH / 2.0,
                0,
            ],
            dtype="float",
        )
        self.left_goal_area_bottom_left_corner = np.array(
            [-pitch_length / 2.0, SoccerPitchSNCircleCentralSplit.GOAL_AREA_WIDTH / 2.0, 0],
            dtype="float",
        )
        self.right_goal_area_top_right_corner = np.array(
            [pitch_length / 2.0, -SoccerPitchSNCircleCentralSplit.GOAL_AREA_WIDTH / 2.0, 0],
            dtype="float",
        )
        self.right_goal_area_top_left_corner = np.array(
            [
                pitch_length / 2.0 - SoccerPitchSNCircleCentralSplit.GOAL_AREA_LENGTH,
                -SoccerPitchSNCircleCentralSplit.GOAL_AREA_WIDTH / 2.0,
                0,
            ],
            dtype="float",
        )
        self.right_goal_area_bottom_right_corner = np.array(
            [pitch_length / 2.0, SoccerPitchSNCircleCentralSplit.GOAL_AREA_WIDTH / 2.0, 0],
            dtype="float",
        )
        self.right_goal_area_bottom_left_corner = np.array(
            [
                pitch_length / 2.0 - SoccerPitchSNCircleCentralSplit.GOAL_AREA_LENGTH,
                SoccerPitchSNCircleCentralSplit.GOAL_AREA_WIDTH / 2.0,
                0,
            ],
            dtype="float",
        )

        x = -pitch_length / 2.0 + SoccerPitchSNCircleCentralSplit.PENALTY_AREA_LENGTH
        dx = (
            SoccerPitchSNCircleCentralSplit.PENALTY_AREA_LENGTH
            - SoccerPitchSNCircleCentralSplit.GOAL_LINE_TO_PENALTY_MARK
        )
        y = -np.sqrt(
            SoccerPitchSNCircleCentralSplit.CENTER_CIRCLE_RADIUS
            * SoccerPitchSNCircleCentralSplit.CENTER_CIRCLE_RADIUS
            - dx * dx
        )
        self.top_left_16M_penalty_arc_mark = np.array([x, y, 0], dtype="float")

        x = pitch_length / 2.0 - SoccerPitchSNCircleCentralSplit.PENALTY_AREA_LENGTH
        dx = (
            SoccerPitchSNCircleCentralSplit.PENALTY_AREA_LENGTH
            - SoccerPitchSNCircleCentralSplit.GOAL_LINE_TO_PENALTY_MARK
        )
        y = -np.sqrt(
            SoccerPitchSNCircleCentralSplit.CENTER_CIRCLE_RADIUS
            * SoccerPitchSNCircleCentralSplit.CENTER_CIRCLE_RADIUS
            - dx * dx
        )
        self.top_right_16M_penalty_arc_mark = np.array([x, y, 0], dtype="float")

        x = -pitch_length / 2.0 + SoccerPitchSNCircleCentralSplit.PENALTY_AREA_LENGTH
        dx = (
            SoccerPitchSNCircleCentralSplit.PENALTY_AREA_LENGTH
            - SoccerPitchSNCircleCentralSplit.GOAL_LINE_TO_PENALTY_MARK
        )
        y = np.sqrt(
            SoccerPitchSNCircleCentralSplit.CENTER_CIRCLE_RADIUS
            * SoccerPitchSNCircleCentralSplit.CENTER_CIRCLE_RADIUS
            - dx * dx
        )
        self.bottom_left_16M_penalty_arc_mark = np.array([x, y, 0], dtype="float")

        x = pitch_length / 2.0 - SoccerPitchSNCircleCentralSplit.PENALTY_AREA_LENGTH
        dx = (
            SoccerPitchSNCircleCentralSplit.PENALTY_AREA_LENGTH
            - SoccerPitchSNCircleCentralSplit.GOAL_LINE_TO_PENALTY_MARK
        )
        y = np.sqrt(
            SoccerPitchSNCircleCentralSplit.CENTER_CIRCLE_RADIUS
            * SoccerPitchSNCircleCentralSplit.CENTER_CIRCLE_RADIUS
            - dx * dx
        )
        self.bottom_right_16M_penalty_arc_mark = np.array([x, y, 0], dtype="float")

        # self.set_elevations(elevation)

        self.point_dict = {}
        self.point_dict["CENTER_MARK"] = self.center_mark
        self.point_dict["L_PENALTY_MARK"] = self.left_penalty_mark
        self.point_dict["R_PENALTY_MARK"] = self.right_penalty_mark
        self.point_dict["TL_PITCH_CORNER"] = self.top_left_corner
        self.point_dict["BL_PITCH_CORNER"] = self.bottom_left_corner
        self.point_dict["TR_PITCH_CORNER"] = self.top_right_corner
        self.point_dict["BR_PITCH_CORNER"] = self.bottom_right_corner
        self.point_dict["L_PENALTY_AREA_TL_CORNER"] = self.left_penalty_area_top_left_corner
        self.point_dict["L_PENALTY_AREA_TR_CORNER"] = self.left_penalty_area_top_right_corner
        self.point_dict["L_PENALTY_AREA_BL_CORNER"] = self.left_penalty_area_bottom_left_corner
        self.point_dict["L_PENALTY_AREA_BR_CORNER"] = self.left_penalty_area_bottom_right_corner

        self.point_dict["R_PENALTY_AREA_TL_CORNER"] = self.right_penalty_area_top_left_corner
        self.point_dict["R_PENALTY_AREA_TR_CORNER"] = self.right_penalty_area_top_right_corner
        self.point_dict["R_PENALTY_AREA_BL_CORNER"] = self.right_penalty_area_bottom_left_corner
        self.point_dict["R_PENALTY_AREA_BR_CORNER"] = self.right_penalty_area_bottom_right_corner

        self.point_dict["L_GOAL_AREA_TL_CORNER"] = self.left_goal_area_top_left_corner
        self.point_dict["L_GOAL_AREA_TR_CORNER"] = self.left_goal_area_top_right_corner
        self.point_dict["L_GOAL_AREA_BL_CORNER"] = self.left_goal_area_bottom_left_corner
        self.point_dict["L_GOAL_AREA_BR_CORNER"] = self.left_goal_area_bottom_right_corner

        self.point_dict["R_GOAL_AREA_TL_CORNER"] = self.right_goal_area_top_left_corner
        self.point_dict["R_GOAL_AREA_TR_CORNER"] = self.right_goal_area_top_right_corner
        self.point_dict["R_GOAL_AREA_BL_CORNER"] = self.right_goal_area_bottom_left_corner
        self.point_dict["R_GOAL_AREA_BR_CORNER"] = self.right_goal_area_bottom_right_corner

        self.point_dict["L_GOAL_TL_POST"] = self.left_goal_top_left_post
        self.point_dict["L_GOAL_TR_POST"] = self.left_goal_top_right_post
        self.point_dict["L_GOAL_BL_POST"] = self.left_goal_bottom_left_post
        self.point_dict["L_GOAL_BR_POST"] = self.left_goal_bottom_right_post

        self.point_dict["R_GOAL_TL_POST"] = self.right_goal_top_left_post
        self.point_dict["R_GOAL_TR_POST"] = self.right_goal_top_right_post
        self.point_dict["R_GOAL_BL_POST"] = self.right_goal_bottom_left_post
        self.point_dict["R_GOAL_BR_POST"] = self.right_goal_bottom_right_post

        self.point_dict[
            "T_TOUCH_AND_HALFWAY_LINES_INTERSECTION"
        ] = self.halfway_and_top_touch_line_mark
        self.point_dict[
            "B_TOUCH_AND_HALFWAY_LINES_INTERSECTION"
        ] = self.halfway_and_bottom_touch_line_mark
        self.point_dict[
            "T_HALFWAY_LINE_AND_CENTER_CIRCLE_INTERSECTION"
        ] = self.halfway_line_and_center_circle_top_mark
        self.point_dict[
            "B_HALFWAY_LINE_AND_CENTER_CIRCLE_INTERSECTION"
        ] = self.halfway_line_and_center_circle_bottom_mark
        self.point_dict[
            "TL_16M_LINE_AND_PENALTY_ARC_INTERSECTION"
        ] = self.top_left_16M_penalty_arc_mark
        self.point_dict[
            "BL_16M_LINE_AND_PENALTY_ARC_INTERSECTION"
        ] = self.bottom_left_16M_penalty_arc_mark
        self.point_dict[
            "TR_16M_LINE_AND_PENALTY_ARC_INTERSECTION"
        ] = self.top_right_16M_penalty_arc_mark
        self.point_dict[
            "BR_16M_LINE_AND_PENALTY_ARC_INTERSECTION"
        ] = self.bottom_right_16M_penalty_arc_mark

        self.line_extremities = dict()
        self.line_extremities["Big rect. left bottom"] = (
            self.point_dict["L_PENALTY_AREA_BL_CORNER"],
            self.point_dict["L_PENALTY_AREA_BR_CORNER"],
        )
        self.line_extremities["Big rect. left top"] = (
            self.point_dict["L_PENALTY_AREA_TL_CORNER"],
            self.point_dict["L_PENALTY_AREA_TR_CORNER"],
        )
        self.line_extremities["Big rect. left main"] = (
            self.point_dict["L_PENALTY_AREA_TR_CORNER"],
            self.point_dict["L_PENALTY_AREA_BR_CORNER"],
        )
        self.line_extremities["Big rect. right bottom"] = (
            self.point_dict["R_PENALTY_AREA_BL_CORNER"],
            self.point_dict["R_PENALTY_AREA_BR_CORNER"],
        )
        self.line_extremities["Big rect. right top"] = (
            self.point_dict["R_PENALTY_AREA_TL_CORNER"],
            self.point_dict["R_PENALTY_AREA_TR_CORNER"],
        )
        self.line_extremities["Big rect. right main"] = (
            self.point_dict["R_PENALTY_AREA_TL_CORNER"],
            self.point_dict["R_PENALTY_AREA_BL_CORNER"],
        )

        self.line_extremities["Small rect. left bottom"] = (
            self.point_dict["L_GOAL_AREA_BL_CORNER"],
            self.point_dict["L_GOAL_AREA_BR_CORNER"],
        )
        self.line_extremities["Small rect. left top"] = (
            self.point_dict["L_GOAL_AREA_TL_CORNER"],
            self.point_dict["L_GOAL_AREA_TR_CORNER"],
        )
        self.line_extremities["Small rect. left main"] = (
            self.point_dict["L_GOAL_AREA_TR_CORNER"],
            self.point_dict["L_GOAL_AREA_BR_CORNER"],
        )
        self.line_extremities["Small rect. right bottom"] = (
            self.point_dict["R_GOAL_AREA_BL_CORNER"],
            self.point_dict["R_GOAL_AREA_BR_CORNER"],
        )
        self.line_extremities["Small rect. right top"] = (
            self.point_dict["R_GOAL_AREA_TL_CORNER"],
            self.point_dict["R_GOAL_AREA_TR_CORNER"],
        )
        self.line_extremities["Small rect. right main"] = (
            self.point_dict["R_GOAL_AREA_TL_CORNER"],
            self.point_dict["R_GOAL_AREA_BL_CORNER"],
        )

        self.line_extremities["Side line top"] = (
            self.point_dict["TL_PITCH_CORNER"],
            self.point_dict["TR_PITCH_CORNER"],
        )
        self.line_extremities["Side line bottom"] = (
            self.point_dict["BL_PITCH_CORNER"],
            self.point_dict["BR_PITCH_CORNER"],
        )
        self.line_extremities["Side line left"] = (
            self.point_dict["TL_PITCH_CORNER"],
            self.point_dict["BL_PITCH_CORNER"],
        )
        self.line_extremities["Side line right"] = (
            self.point_dict["TR_PITCH_CORNER"],
            self.point_dict["BR_PITCH_CORNER"],
        )
        self.line_extremities["Middle line"] = (
            self.point_dict["T_TOUCH_AND_HALFWAY_LINES_INTERSECTION"],
            self.point_dict["B_TOUCH_AND_HALFWAY_LINES_INTERSECTION"],
        )

        self.line_extremities["Goal left crossbar"] = (
            self.point_dict["L_GOAL_TR_POST"],
            self.point_dict["L_GOAL_TL_POST"],
        )
        self.line_extremities["Goal left post left "] = (
            self.point_dict["L_GOAL_TL_POST"],
            self.point_dict["L_GOAL_BL_POST"],
        )
        self.line_extremities["Goal left post right"] = (
            self.point_dict["L_GOAL_TR_POST"],
            self.point_dict["L_GOAL_BR_POST"],
        )

        self.line_extremities["Goal right crossbar"] = (
            self.point_dict["R_GOAL_TL_POST"],
            self.point_dict["R_GOAL_TR_POST"],
        )
        self.line_extremities["Goal right post left"] = (
            self.point_dict["R_GOAL_TL_POST"],
            self.point_dict["R_GOAL_BL_POST"],
        )
        self.line_extremities["Goal right post right"] = (
            self.point_dict["R_GOAL_TR_POST"],
            self.point_dict["R_GOAL_BR_POST"],
        )
        self.line_extremities["Circle right"] = (
            self.point_dict["TR_16M_LINE_AND_PENALTY_ARC_INTERSECTION"],
            self.point_dict["BR_16M_LINE_AND_PENALTY_ARC_INTERSECTION"],
        )
        self.line_extremities["Circle left"] = (
            self.point_dict["TL_16M_LINE_AND_PENALTY_ARC_INTERSECTION"],
            self.point_dict["BL_16M_LINE_AND_PENALTY_ARC_INTERSECTION"],
        )

        self.line_extremities_keys = dict()
        self.line_extremities_keys["Big rect. left bottom"] = (
            "L_PENALTY_AREA_BL_CORNER",
            "L_PENALTY_AREA_BR_CORNER",
        )
        self.line_extremities_keys["Big rect. left top"] = (
            "L_PENALTY_AREA_TL_CORNER",
            "L_PENALTY_AREA_TR_CORNER",
        )
        self.line_extremities_keys["Big rect. left main"] = (
            "L_PENALTY_AREA_TR_CORNER",
            "L_PENALTY_AREA_BR_CORNER",
        )
        self.line_extremities_keys["Big rect. right bottom"] = (
            "R_PENALTY_AREA_BL_CORNER",
            "R_PENALTY_AREA_BR_CORNER",
        )
        self.line_extremities_keys["Big rect. right top"] = (
            "R_PENALTY_AREA_TL_CORNER",
            "R_PENALTY_AREA_TR_CORNER",
        )
        self.line_extremities_keys["Big rect. right main"] = (
            "R_PENALTY_AREA_TL_CORNER",
            "R_PENALTY_AREA_BL_CORNER",
        )

        self.line_extremities_keys["Small rect. left bottom"] = (
            "L_GOAL_AREA_BL_CORNER",
            "L_GOAL_AREA_BR_CORNER",
        )
        self.line_extremities_keys["Small rect. left top"] = (
            "L_GOAL_AREA_TL_CORNER",
            "L_GOAL_AREA_TR_CORNER",
        )
        self.line_extremities_keys["Small rect. left main"] = (
            "L_GOAL_AREA_TR_CORNER",
            "L_GOAL_AREA_BR_CORNER",
        )
        self.line_extremities_keys["Small rect. right bottom"] = (
            "R_GOAL_AREA_BL_CORNER",
            "R_GOAL_AREA_BR_CORNER",
        )
        self.line_extremities_keys["Small rect. right top"] = (
            "R_GOAL_AREA_TL_CORNER",
            "R_GOAL_AREA_TR_CORNER",
        )
        self.line_extremities_keys["Small rect. right main"] = (
            "R_GOAL_AREA_TL_CORNER",
            "R_GOAL_AREA_BL_CORNER",
        )

        self.line_extremities_keys["Side line top"] = ("TL_PITCH_CORNER", "TR_PITCH_CORNER")
        self.line_extremities_keys["Side line bottom"] = ("BL_PITCH_CORNER", "BR_PITCH_CORNER")
        self.line_extremities_keys["Side line left"] = ("TL_PITCH_CORNER", "BL_PITCH_CORNER")
        self.line_extremities_keys["Side line right"] = ("TR_PITCH_CORNER", "BR_PITCH_CORNER")
        self.line_extremities_keys["Middle line"] = (
            "T_TOUCH_AND_HALFWAY_LINES_INTERSECTION",
            "B_TOUCH_AND_HALFWAY_LINES_INTERSECTION",
        )

        self.line_extremities_keys["Goal left crossbar"] = ("L_GOAL_TR_POST", "L_GOAL_TL_POST")
        self.line_extremities_keys["Goal left post left "] = ("L_GOAL_TL_POST", "L_GOAL_BL_POST")
        self.line_extremities_keys["Goal left post right"] = ("L_GOAL_TR_POST", "L_GOAL_BR_POST")

        self.line_extremities_keys["Goal right crossbar"] = ("R_GOAL_TL_POST", "R_GOAL_TR_POST")
        self.line_extremities_keys["Goal right post left"] = ("R_GOAL_TL_POST", "R_GOAL_BL_POST")
        self.line_extremities_keys["Goal right post right"] = ("R_GOAL_TR_POST", "R_GOAL_BR_POST")
        self.line_extremities_keys["Circle right"] = (
            "TR_16M_LINE_AND_PENALTY_ARC_INTERSECTION",
            "BR_16M_LINE_AND_PENALTY_ARC_INTERSECTION",
        )
        self.line_extremities_keys["Circle left"] = (
            "TL_16M_LINE_AND_PENALTY_ARC_INTERSECTION",
            "BL_16M_LINE_AND_PENALTY_ARC_INTERSECTION",
        )

    def points(self):
        return [
            self.center_mark,
            self.halfway_and_bottom_touch_line_mark,
            self.halfway_and_top_touch_line_mark,
            self.halfway_line_and_center_circle_top_mark,
            self.halfway_line_and_center_circle_bottom_mark,
            self.bottom_right_corner,
            self.bottom_left_corner,
            self.top_right_corner,
            self.top_left_corner,
            self.left_penalty_mark,
            self.right_penalty_mark,
            self.left_penalty_area_top_right_corner,
            self.left_penalty_area_top_left_corner,
            self.left_penalty_area_bottom_right_corner,
            self.left_penalty_area_bottom_left_corner,
            self.right_penalty_area_top_right_corner,
            self.right_penalty_area_top_left_corner,
            self.right_penalty_area_bottom_right_corner,
            self.right_penalty_area_bottom_left_corner,
            self.left_goal_area_top_right_corner,
            self.left_goal_area_top_left_corner,
            self.left_goal_area_bottom_right_corner,
            self.left_goal_area_bottom_left_corner,
            self.right_goal_area_top_right_corner,
            self.right_goal_area_top_left_corner,
            self.right_goal_area_bottom_right_corner,
            self.right_goal_area_bottom_left_corner,
            self.top_left_16M_penalty_arc_mark,
            self.top_right_16M_penalty_arc_mark,
            self.bottom_left_16M_penalty_arc_mark,
            self.bottom_right_16M_penalty_arc_mark,
            self.left_goal_top_left_post,
            self.left_goal_top_right_post,
            self.left_goal_bottom_left_post,
            self.left_goal_bottom_right_post,
            self.right_goal_top_left_post,
            self.right_goal_top_right_post,
            self.right_goal_bottom_left_post,
            self.right_goal_bottom_right_post,
        ]

    def sample_field_points(self, dist=0.1, dist_circles=0.2):
        """
        Samples each pitch element every `dist` meters and returns a dictionary
        associating the class of the element with the list of points sampled along it.
        :param dist: the distance in meters between two consecutive sampled points
        :param dist_circles: the distance in meters between two consecutive sampled points on circles
        :return: a dictionary associating the class of the element with a list of sampled points
        """
        polylines = dict()
        center = self.point_dict["CENTER_MARK"]
        fromAngle = 0.0
        toAngle = 2 * np.pi

        if toAngle < fromAngle:
            toAngle += 2 * np.pi
        x1 = center[0] + np.cos(fromAngle) * SoccerPitchSNCircleCentralSplit.CENTER_CIRCLE_RADIUS
        y1 = center[1] + np.sin(fromAngle) * SoccerPitchSNCircleCentralSplit.CENTER_CIRCLE_RADIUS
        z1 = 0.0
        point = np.array((x1, y1, z1))
        polyline = [point]
        length = SoccerPitchSNCircleCentralSplit.CENTER_CIRCLE_RADIUS * (toAngle - fromAngle)
        nb_pts = int(length / dist_circles)
        dangle = dist_circles / SoccerPitchSNCircleCentralSplit.CENTER_CIRCLE_RADIUS
        for i in range(1, nb_pts):
            angle = fromAngle + i * dangle
            x = center[0] + np.cos(angle) * SoccerPitchSNCircleCentralSplit.CENTER_CIRCLE_RADIUS
            y = center[1] + np.sin(angle) * SoccerPitchSNCircleCentralSplit.CENTER_CIRCLE_RADIUS
            z = 0
            point = np.array((x, y, z))
            polyline.append(point)

        # split central circle in left and right
        polylines["Circle central left"] = [p for p in polyline if p[0] < 0.0]
        polylines["Circle central right"] = [p for p in polyline if p[0] >= 0.0]
        for key, line in self.line_extremities.items():

            if "Circle" in key:
                if key == "Circle right":
                    top = self.point_dict["TR_16M_LINE_AND_PENALTY_ARC_INTERSECTION"]
                    bottom = self.point_dict["BR_16M_LINE_AND_PENALTY_ARC_INTERSECTION"]
                    center = self.point_dict["R_PENALTY_MARK"]
                    toAngle = np.arctan2(top[1] - center[1], top[0] - center[0]) + 2 * np.pi
                    fromAngle = np.arctan2(bottom[1] - center[1], bottom[0] - center[0]) + 2 * np.pi
                elif key == "Circle left":
                    top = self.point_dict["TL_16M_LINE_AND_PENALTY_ARC_INTERSECTION"]
                    bottom = self.point_dict["BL_16M_LINE_AND_PENALTY_ARC_INTERSECTION"]
                    center = self.point_dict["L_PENALTY_MARK"]
                    fromAngle = np.arctan2(top[1] - center[1], top[0] - center[0]) + 2 * np.pi
                    toAngle = np.arctan2(bottom[1] - center[1], bottom[0] - center[0]) + 2 * np.pi
                if toAngle < fromAngle:
                    toAngle += 2 * np.pi
                x1 = (
                    center[0]
                    + np.cos(fromAngle) * SoccerPitchSNCircleCentralSplit.CENTER_CIRCLE_RADIUS
                )
                y1 = (
                    center[1]
                    + np.sin(fromAngle) * SoccerPitchSNCircleCentralSplit.CENTER_CIRCLE_RADIUS
                )
                z1 = 0.0
                xn = (
                    center[0]
                    + np.cos(toAngle) * SoccerPitchSNCircleCentralSplit.CENTER_CIRCLE_RADIUS
                )
                yn = (
                    center[1]
                    + np.sin(toAngle) * SoccerPitchSNCircleCentralSplit.CENTER_CIRCLE_RADIUS
                )
                zn = 0.0
                start = np.array((x1, y1, z1))
                end = np.array((xn, yn, zn))
                polyline = [start]
                length = SoccerPitchSNCircleCentralSplit.CENTER_CIRCLE_RADIUS * (
                    toAngle - fromAngle
                )
                nb_pts = int(length / dist_circles)
                dangle = dist_circles / SoccerPitchSNCircleCentralSplit.CENTER_CIRCLE_RADIUS
                for i in range(1, nb_pts + 1):
                    angle = fromAngle + i * dangle
                    x = (
                        center[0]
                        + np.cos(angle) * SoccerPitchSNCircleCentralSplit.CENTER_CIRCLE_RADIUS
                    )
                    y = (
                        center[1]
                        + np.sin(angle) * SoccerPitchSNCircleCentralSplit.CENTER_CIRCLE_RADIUS
                    )
                    z = 0
                    point = np.array((x, y, z))
                    polyline.append(point)
                polyline.append(end)
                polylines[key] = polyline
            else:
                start = line[0]
                end = line[1]

                polyline = [start]

                total_dist = np.sqrt(np.sum(np.square(start - end)))
                nb_pts = int(total_dist / dist - 1)

                v = end - start
                v /= np.linalg.norm(v)
                prev_pt = start
                for i in range(nb_pts):
                    pt = prev_pt + dist * v
                    prev_pt = pt
                    polyline.append(pt)
                polyline.append(end)
                polylines[key] = polyline
        return polylines

    def get_2d_homogeneous_line(self, line_name):
        """
        For lines belonging to the pitch lawn plane, returns the coefficients of its 2D homogeneous line equation.
        :param line_name: name of the line
        :return: an array containing the three coefficients of the line
        """
        # ensure line in football pitch plane
        if (
            line_name in self.line_extremities.keys()
            and "post" not in line_name
            and "crossbar" not in line_name
            and "Circle" not in line_name
        ):
            extremities = self.line_extremities[line_name]
            p1 = np.array([extremities[0][0], extremities[0][1], 1], dtype="float")
            p2 = np.array([extremities[1][0], extremities[1][1], 1], dtype="float")
            line = np.cross(p1, p2)

            return line
        return None


class Abstract3dModel(metaclass=ABCMeta):
    def __init__(self) -> None:

        self.points = None  # keypoints: tensor of shape (N, 3)
        self.points_sampled = (
            None  # sampled points for each segment Dict[str: torch.tensor of shape (*, 3)]
        )
        self.points_sampled_palette = {}
        self.segment_names = set(self.points_sampled_palette.keys())

        self.line_segments = []  # tensor of shape (3, S_l, 2) containing 2 points
        self.line_segments_names = []  # list of respective names for each s in S_l
        self.line_palette = []  # list of RGB tuples

        self.circle_segments = None  # tensor of shape (3, S_c, num_points_per_circle)
        self.circle_segments_names = []  # list of respective names for each s in S_c
        self.circle_palette = []  # list of RGB tuples


class Meshgrid(Abstract3dModel):
    def __init__(self, height=68, width=105):
        self.points = kornia.utils.create_meshgrid(
            height=height + 1, width=width + 1, normalized_coordinates=True
        )
        self.points = self.points.flatten(start_dim=-3, end_dim=-2)
        self.points[:, :, 0] = self.points[:, :, 0] * width / 2
        self.points[:, :, 1] = self.points[:, :, 1] * height / 2
        self.points = kornia.geometry.conversions.convert_points_to_homogeneous(self.points)
        self.points[:, :, -1] = 0.0  # set z = 0
        self.points = self.points.squeeze(0)
        self.points_sampled = {"meshgrid": self.points}


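# Shape sketch for Meshgrid (assuming the kornia meshgrid API used above): with the
# defaults height=68, width=105 it yields (height + 1) * (width + 1) homogeneous
# points on the z = 0 plane, i.e.
#   grid = Meshgrid()
#   grid.points.shape  # torch.Size([69 * 106, 3])
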
class SoccerPitchLineCircleSegments(Abstract3dModel):
    def __init__(
        self,
        base_field,
        device="cpu",
        N_cstar=128,
        sampling_factor_lines=0.2,
        sampling_factor_circles=0.8,
    ) -> None:

        if not (
            isinstance(base_field, SoccerPitchSNCircleCentralSplit)
            or isinstance(base_field, SoccerPitchSN)
        ):
            raise NotImplementedError

        self.sampling_factor_lines = sampling_factor_lines
        self.sampling_factor_circles = sampling_factor_circles

        self._field_sncalib = base_field

        self.device = device

        # classical keypoints as single tensor
        self.points = torch.from_numpy(np.stack(self._field_sncalib.points())).float().to(device)

        # sampled points for each segment Dict[str: torch.tensor of shape (*, 3)]
        self.points_sampled = self._field_sncalib.sample_field_points(
            self.sampling_factor_lines, self.sampling_factor_circles
        )
        self.points_sampled = {
            k: torch.from_numpy(np.stack(v)).float().to(device)
            for k, v in self.points_sampled.items()
        }
        self.points_sampled_palette = self._field_sncalib.palette
        self.segment_names = set(self.points_sampled_palette.keys())
        self.cmap_01 = {k: [c / 255.0 for c in v] for k, v in self.points_sampled_palette.items()}

        self.line_collection: List[LineCollection] = []
        self.line_segments = []  # (3, S, 2)
        self.line_segments_names = []

        for line_name, (p0, p1) in self._field_sncalib.line_extremities.items():
            if "Circle" not in line_name:
                p0 = torch.from_numpy(p0).float().to(device)
                p1 = torch.from_numpy(p1).float().to(device)
                direction = p1 - p0
                direction_norm = direction / torch.linalg.norm(direction)
                self.line_collection.append(
                    LineCollection(
                        support=p0,
                        direction=direction,
                        direction_norm=direction_norm,
                    )
                )
                self.line_segments_names.append(line_name)
                self.line_segments.append(torch.stack([p0, p1], dim=1))

        self.line_segments = torch.stack(self.line_segments, dim=-1).transpose(1, 2).to(device)
        self.line_palette = [
            self._field_sncalib.palette[self.line_segments_names[i]]
            for i in range(len(self.line_segments_names))
        ]

        if isinstance(base_field, SoccerPitchSNCircleCentralSplit):
            self.circle_segments_names = [
                "Circle central left",
                "Circle central right",
                "Circle left",
                "Circle right",
            ]
        elif isinstance(base_field, SoccerPitchSN):
            self.circle_segments_names = [
                "Circle central",
                "Circle left",
                "Circle right",
            ]
        else:
            raise NotImplementedError

        self.circle_segments = self._sample_points_from_circle_segments(
            m=N_cstar
        )  # (3, num_circles, num_points_per_circle)

        self.circle_palette = [
            self._field_sncalib.palette[self.circle_segments_names[i]]
            for i in range(len(self.circle_segments_names))
        ]

    def _sample_points_from_circle_segments(self, m: int):

        sampled_points = self._field_sncalib.sample_field_points(dist=1.0, dist_circles=0.05)
        for key in self.circle_segments_names:
            assert len(sampled_points[key]) >= m
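        # randomly subsample exactly m points per circle segment so that all
        # circles can be stacked into a single fixed-size tensor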
        return (
            torch.stack(
                [
                    torch.from_numpy(np.stack(random.sample(sampled_points[key], k=m), axis=-1))
                    for key in self.circle_segments_names
                ],
                dim=1,
            )
            .float()
            .to(self.device)
        )  # (3, S, m)


if __name__ == "__main__":
    from matplotlib import pyplot as plt

    model3d = SoccerPitchLineCircleSegments(base_field=SoccerPitchSNCircleCentralSplit())

    fig = plt.figure(figsize=(20, 20))
    ax = fig.add_subplot(projection="3d")
    for s in range(len(model3d.line_collection)):
        ax.quiver(
            model3d.line_collection[s].support[0],
            model3d.line_collection[s].support[1],
            model3d.line_collection[s].support[2],
            model3d.line_collection[s].direction[0],
            model3d.line_collection[s].direction[1],
            model3d.line_collection[s].direction[2],
            arrow_length_ratio=0.05,
            color=[x / 255.0 for x in model3d.line_palette[s]],
            zorder=2000,
            # length=68.0,
            linewidths=3,
            label=model3d.line_segments_names[s],
            alpha=0.5,
        )

    plt.legend()
    ax.set_xlim([-105 / 2, 105 / 2])
    ax.set_ylim([-105 / 2, 105 / 2])
    ax.set_zlim([-105 / 2, 105 / 2])
    plt.show()

    fig = plt.figure(figsize=(20, 20))
    ax = fig.add_subplot(projection="3d")
    for segment_name, sampled_points in model3d.points_sampled.items():
ax.scatter(
|
1639 |
+
sampled_points[:, 0],
|
1640 |
+
sampled_points[:, 1],
|
1641 |
+
-sampled_points[:, 2],
|
1642 |
+
zorder=3000,
|
1643 |
+
color=[x / 255.0 for x in model3d.points_sampled_palette[segment_name]],
|
1644 |
+
marker="x",
|
1645 |
+
label=segment_name,
|
1646 |
+
)
|
1647 |
+
|
1648 |
+
plt.legend()
|
1649 |
+
ax.set_xlim([-105 / 2, 105 / 2])
|
1650 |
+
ax.set_ylim([-105 / 2, 105 / 2])
|
1651 |
+
ax.set_zlim([-105 / 2, 105 / 2])
|
1652 |
+
plt.show()
|
1653 |
+
|
1654 |
+
fig = plt.figure(figsize=(20, 20))
|
1655 |
+
ax = fig.add_subplot(projection="3d")
|
1656 |
+
for s in range(model3d.line_segments.shape[1]):
|
1657 |
+
|
1658 |
+
if "crossbar" in model3d.line_segments_names[s]:
|
1659 |
+
print(s, model3d.line_segments[:, s])
|
1660 |
+
ax.scatter(
|
1661 |
+
model3d.line_segments[0, s],
|
1662 |
+
model3d.line_segments[1, s],
|
1663 |
+
-model3d.line_segments[2, s],
|
1664 |
+
zorder=3000,
|
1665 |
+
color=[x / 255.0 for x in model3d.line_palette[s]],
|
1666 |
+
marker="x",
|
1667 |
+
label=model3d.line_segments_names[s],
|
1668 |
+
)
|
1669 |
+
|
1670 |
+
plt.legend()
|
1671 |
+
ax.set_xlim([-105 / 2, 105 / 2])
|
1672 |
+
ax.set_ylim([-105 / 2, 105 / 2])
|
1673 |
+
ax.set_zlim([-105 / 2, 105 / 2])
|
1674 |
+
plt.savefig("soccer_field_line_segments.pdf")
|
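A minimal shape-inspection sketch (assuming the module is importable as `tvcalib.utils.objects_3d`, as the repo layout suggests; the commented shapes follow the annotations above):

```python
# Sketch: inspect the pitch-template tensors this class exposes.
from tvcalib.utils.objects_3d import SoccerPitchLineCircleSegments

model3d = SoccerPitchLineCircleSegments()
print(model3d.line_segments.shape)    # (3, num_line_segments, 2): both 3D endpoints per line
print(model3d.circle_segments.shape)  # (3, num_circles, N_cstar): points sampled along each circle
print(len(model3d.line_palette), len(model3d.circle_palette))  # one color per segment name
```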
visualizer.py
ADDED
@@ -0,0 +1,298 @@
import cv2
import numpy as np

# Field dimensions in yards
FIELD_LENGTH_YARDS = 114.83
FIELD_WIDTH_YARDS = 74.37

# Expected input image size
EXPECTED_H, EXPECTED_W = 720, 1280

# Keypoint/drawing constants imported from pose_estimator (or redefine them here if needed)
from pose_estimator import (LEFT_ANKLE_KP_INDEX, RIGHT_ANKLE_KP_INDEX,
                            CONFIDENCE_THRESHOLD_KEYPOINTS, DEFAULT_MARKER_COLOR, SKELETON_EDGES, SKELETON_THICKNESS)

# Marker constants
MARKER_RADIUS = 6
MARKER_BORDER_THICKNESS = 1
MARKER_BORDER_COLOR = (0, 0, 0)  # Black

# Modulation range for the inverted dynamic scale.
# NB: skeleton offsets stay in image pixels, where distant players are already small,
# so the code below applies the MAX modulation at the top of the minimap and the MIN at the bottom.
DYNAMIC_SCALE_MIN_MODULATION = 0.4  # Applied at the bottom of the minimap (closest players)
DYNAMIC_SCALE_MAX_MODULATION = 1.6  # Applied at the top of the minimap (farthest players)

def calculate_dynamic_scale(y_position, frame_height, min_scale=1.0, max_scale=2.0):
    """Compute a scale factor from the vertical position (unused in this simplified version)."""
    normalized_position = y_position / frame_height
    return min_scale + (max_scale - min_scale) * normalized_position

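Although `calculate_dynamic_scale` is unused in this simplified version, its linear mapping is easy to verify with the defaults shown above:

```python
from visualizer import calculate_dynamic_scale

# Quick check of the linear mapping y -> scale (defaults: min_scale=1.0, max_scale=2.0).
assert calculate_dynamic_scale(0, 720) == 1.0    # top of the frame -> min_scale
assert calculate_dynamic_scale(720, 720) == 2.0  # bottom of the frame -> max_scale
assert calculate_dynamic_scale(360, 720) == 1.5  # halfway -> midpoint
```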
def _prepare_minimap_base(minimap_size=(EXPECTED_W, EXPECTED_H)):
    """Prepare the minimap background (green with vertical stripes inside the field area) and compute the field metrics."""
    minimap_h, minimap_w = minimap_size[1], minimap_size[0]

    # Colors and width of the vertical stripes
    base_green = (0, 60, 0)    # Dark green (background)
    stripe_green = (0, 70, 0)  # Slightly lighter green (stripes)
    stripe_width = 5  # Width of each vertical stripe (pixels)

    # Fill the WHOLE minimap with the base color
    minimap_bgr = np.full((minimap_h, minimap_w, 3), base_green, dtype=np.uint8)

    # --- Compute the field metrics and bounds FIRST ---
    scale_x = minimap_w / FIELD_LENGTH_YARDS
    scale_y = minimap_h / FIELD_WIDTH_YARDS
    scale = min(scale_x, scale_y) * 0.9  # Margin around the field

    field_width_px = int(FIELD_WIDTH_YARDS * scale)
    field_length_px = int(FIELD_LENGTH_YARDS * scale)
    offset_x = (minimap_w - field_length_px) // 2
    offset_y = (minimap_h - field_width_px) // 2

    # --- Draw the alternating VERTICAL stripes only inside the field area ---
    for x in range(offset_x, offset_x + field_length_px, stripe_width * 2):
        # Rectangle coordinates of the light stripe
        start_x = x
        end_x = min(x + stripe_width, offset_x + field_length_px)  # Do not cross the right edge
        start_y = offset_y
        end_y = offset_y + field_width_px

        cv2.rectangle(minimap_bgr, (start_x, start_y), (end_x, end_y), stripe_green, thickness=-1)
        # The next dark stripe is already there (base color)

    # --- Build the transform S and the metrics to return ---
    S = np.array([
        [scale, 0, offset_x],
        [0, scale, offset_y],
        [0, 0, 1]
    ], dtype=np.float32)

    metrics = {
        "scale": scale,
        "offset_x": offset_x,
        "offset_y": offset_y,
        "field_width_px": field_width_px,
        "field_length_px": field_length_px,
        "S": S
    }
    return minimap_bgr, metrics

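The returned `S` is the similarity that places pitch coordinates (in yards, origin at the top-left field corner, as the offsets imply) onto minimap pixels; both drawing functions below compose it with the TvCalib homography as `S @ homography`, so image pixels map straight to minimap pixels. A small sanity sketch:

```python
import numpy as np
from visualizer import _prepare_minimap_base

# Sketch: the pitch origin (0, 0) lands exactly at the top-left corner of the drawn field.
_, metrics = _prepare_minimap_base()
S = metrics["S"]
ox, oy, w = S @ np.array([0.0, 0.0, 1.0])
assert (ox / w, oy / w) == (metrics["offset_x"], metrics["offset_y"])
```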
def _draw_field_lines(minimap_bgr, metrics):
    """Draw the field lines and goals on the minimap."""
    scale = metrics["scale"]
    offset_x = metrics["offset_x"]
    offset_y = metrics["offset_y"]
    field_width_px = metrics["field_width_px"]
    field_length_px = metrics["field_length_px"]

    line_color = (255, 255, 255)  # White
    line_thickness = 1
    goal_thickness = 1  # Thickness of the goal posts
    goal_width_yards = 8  # Standard goal width

    # NB: SoccerPitchSN here is the lightweight constants stub defined at the end of this file
    center_x = offset_x + field_length_px // 2
    center_y = offset_y + field_width_px // 2
    penalty_area_width_px = int(SoccerPitchSN.PENALTY_AREA_WIDTH * scale)
    penalty_area_length_px = int(SoccerPitchSN.PENALTY_AREA_LENGTH * scale)
    goal_area_width_px = int(SoccerPitchSN.GOAL_AREA_WIDTH * scale)
    goal_area_length_px = int(SoccerPitchSN.GOAL_AREA_LENGTH * scale)
    center_circle_radius_px = int(SoccerPitchSN.CENTER_CIRCLE_RADIUS * scale)
    goal_width_px = int(goal_width_yards * scale)

    # Draw the main lines
    cv2.rectangle(minimap_bgr, (offset_x, offset_y), (offset_x + field_length_px, offset_y + field_width_px), line_color, line_thickness)
    cv2.line(minimap_bgr, (center_x, offset_y), (center_x, offset_y + field_width_px), line_color, line_thickness)
    cv2.circle(minimap_bgr, (center_x, center_y), center_circle_radius_px, line_color, line_thickness)
    cv2.circle(minimap_bgr, (center_x, center_y), 3, line_color, -1)  # Center spot
    cv2.rectangle(minimap_bgr, (offset_x, center_y - penalty_area_width_px//2), (offset_x + penalty_area_length_px, center_y + penalty_area_width_px//2), line_color, line_thickness)
    cv2.rectangle(minimap_bgr, (offset_x + field_length_px - penalty_area_length_px, center_y - penalty_area_width_px//2), (offset_x + field_length_px, center_y + penalty_area_width_px//2), line_color, line_thickness)
    cv2.rectangle(minimap_bgr, (offset_x, center_y - goal_area_width_px//2), (offset_x + goal_area_length_px, center_y + goal_area_width_px//2), line_color, line_thickness)
    cv2.rectangle(minimap_bgr, (offset_x + field_length_px - goal_area_length_px, center_y - goal_area_width_px//2), (offset_x + field_length_px, center_y + goal_area_width_px//2), line_color, line_thickness)

    # Draw the goals (thick rectangles on the goal lines)
    goal_y_top = center_y - goal_width_px // 2
    goal_y_bottom = center_y + goal_width_px // 2
    # Left goal
    cv2.rectangle(minimap_bgr, (offset_x - 6 - goal_thickness // 2, goal_y_top), (offset_x + goal_thickness // 2, goal_y_bottom), line_color, thickness=goal_thickness)
    # Right goal
    cv2.rectangle(minimap_bgr, (offset_x + field_length_px - goal_thickness // 2, goal_y_top), (offset_x + 6 + field_length_px + goal_thickness // 2, goal_y_bottom), line_color, thickness=goal_thickness)

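A quick way to eyeball the line layout (including the stub constants at the bottom of this file) is to render the empty pitch on its own:

```python
import cv2
from visualizer import _prepare_minimap_base, _draw_field_lines

# Sketch: render the striped background plus field lines, no frame or players.
base, metrics = _prepare_minimap_base()
_draw_field_lines(base, metrics)
cv2.imwrite("pitch_preview.png", base)  # illustrative output path
```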
def create_minimap_view(image_rgb, homography, minimap_size=(EXPECTED_W, EXPECTED_H)):
    """Create a minimap view with the original RGB image projected onto it, plus the field lines.

    Args:
        image_rgb: Source image in RGB format (720p expected).
        homography: Homography matrix (numpy array) used to project the image.
        minimap_size: Output minimap size (width, height).

    Returns:
        The minimap image (BGR numpy array), or None if the homography is invalid.
    """
    if homography is None:
        print("Warning: invalid homography, cannot create the minimap (original view).")
        return None

    h, w = image_rgb.shape[:2]
    if h != EXPECTED_H or w != EXPECTED_W:
        print(f"Warning: the input RGB image is not {EXPECTED_W}x{EXPECTED_H}, resizing...")
        image_rgb = cv2.resize(image_rgb, (EXPECTED_W, EXPECTED_H), interpolation=cv2.INTER_LINEAR)

    minimap_bgr, metrics = _prepare_minimap_base(minimap_size)
    S = metrics["S"]

    try:
        # Convert to BGR and darken the frame so it blends with the minimap background
        overlay = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR)
        overlay = cv2.convertScaleAbs(overlay, alpha=1.2, beta=10)
        overlay = cv2.addWeighted(overlay, 0.5, np.zeros_like(overlay), 0.5, 0)

        H_minimap = S @ homography
        warped = cv2.warpPerspective(overlay, H_minimap, minimap_size, flags=cv2.INTER_LINEAR)

        # Paste the warped frame where it has content and outline its footprint
        mask = cv2.cvtColor(warped, cv2.COLOR_BGR2GRAY)
        _, mask = cv2.threshold(mask, 1, 255, cv2.THRESH_BINARY)

        minimap_bgr = np.where(mask[..., None] > 0, warped, minimap_bgr)

        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        cv2.drawContours(minimap_bgr, contours, -1, (255, 255, 255), 2)

    except Exception as e:
        print(f"Error while projecting onto the minimap (original view): {str(e)}")

    _draw_field_lines(minimap_bgr, metrics)
    return minimap_bgr

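A usage sketch for the projected-frame view; the homography is assumed to come from the TvCalib step elsewhere in this repo, and the file paths are illustrative:

```python
import cv2
import numpy as np
from visualizer import create_minimap_view

# Sketch: project a 720p frame onto the minimap.
frame_bgr = cv2.imread("frame.jpg")                    # illustrative input
frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
H = np.load("homography.npy")                          # assumed 3x3, image px -> pitch yards
minimap = create_minimap_view(frame_rgb, H)
if minimap is not None:
    cv2.imwrite("minimap_view.png", minimap)
```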
def create_minimap_with_offset_skeletons(player_data_list, homography,
                                         base_skeleton_scale: float,
                                         minimap_size=(EXPECTED_W, EXPECTED_H)) -> tuple[np.ndarray | None, float | None]:
    """Create a minimap view by drawing each player's original skeleton (dynamically
    shrunk/enlarged, with inverted modulation along Y) at the player's projected
    position, sorted by Y position.

    Args:
        player_data_list: List of dicts as returned by get_player_data.
        homography: Homography matrix (numpy array).
        base_skeleton_scale: Base scale factor for drawing the skeletons.
        minimap_size: Output minimap size (width, height).

    Returns:
        Tuple: (the minimap image (BGR numpy array) or None, the average applied scale or None)
    """
    if homography is None:
        print("Warning: invalid homography, cannot create the minimap (offset skeletons).")
        return None, None  # Return None for both the image and the scale

    minimap_bgr, metrics = _prepare_minimap_base(minimap_size)

    # --- Draw the field lines FIRST ---
    _draw_field_lines(minimap_bgr, metrics)

    S = metrics["S"]
    H_minimap = S @ homography

    players_to_draw = []  # Valid players together with their projected position

    # --- Steps 1 & 2: compute the projected position of every valid player ---
    for p_data in player_data_list:
        kps_img = p_data['keypoints']
        scores = p_data['scores']
        bbox = p_data['bbox']
        color = p_data['avg_color']

        # -- Compute the reference point on the image --
        l_ankle_pt = kps_img[LEFT_ANKLE_KP_INDEX]
        r_ankle_pt = kps_img[RIGHT_ANKLE_KP_INDEX]
        l_ankle_score = scores[LEFT_ANKLE_KP_INDEX]
        r_ankle_score = scores[RIGHT_ANKLE_KP_INDEX]
        ref_point_img = None
        if l_ankle_score > CONFIDENCE_THRESHOLD_KEYPOINTS and r_ankle_score > CONFIDENCE_THRESHOLD_KEYPOINTS:
            ref_point_img = (l_ankle_pt + r_ankle_pt) / 2
        elif l_ankle_score > CONFIDENCE_THRESHOLD_KEYPOINTS:
            ref_point_img = l_ankle_pt
        elif r_ankle_score > CONFIDENCE_THRESHOLD_KEYPOINTS:
            ref_point_img = r_ankle_pt
        else:
            x1, _, x2, y2 = bbox
            ref_point_img = np.array([(x1 + x2) / 2, y2], dtype=np.float32)
        if ref_point_img is None: continue

        # -- Project the reference point onto the minimap --
        try:
            point_to_transform = np.array([[ref_point_img]], dtype=np.float32)
            projected_point = cv2.perspectiveTransform(point_to_transform, H_minimap)
            mx, my = map(int, projected_point[0, 0])
            h_map, w_map = minimap_bgr.shape[:2]
            if not (0 <= mx < w_map and 0 <= my < h_map):
                continue  # Skip players that fall outside the minimap
        except Exception as e:
            # print(f"Error while projecting the reference point {ref_point_img}: {e}")  # Uncomment for debugging
            continue

        # Store what is needed for sorting and drawing
        players_to_draw.append({
            'data': p_data,
            'mx': mx,
            'my': my,
            'ref_point': ref_point_img
        })

    # --- Step 3: sort the players by Y position (ascending) ---
    # Players with a smaller Y (higher up) are drawn first
    players_to_draw.sort(key=lambda p: p['my'])

    # Accumulators for the average applied scale
    total_applied_scale = 0.0
    drawn_players_count = 0

    # --- Step 4: draw the players in sorted order (NOW ON TOP OF THE LINES) ---
    for player_info in players_to_draw:
        p_data = player_info['data']
        mx = player_info['mx']
        my = player_info['my']
        ref_point_img = player_info['ref_point']
        kps_img = p_data['keypoints']
        scores = p_data['scores']
        # color = p_data['avg_color']  # The computed color is ignored here
        drawing_color = (0, 0, 0)  # Draw every player in black

        # -- Compute the INVERTED dynamic scale for THIS player --
        minimap_height = minimap_bgr.shape[0]
        if minimap_height == 0: continue
        ref_y_normalized = my / minimap_height
        dynamic_modulation = DYNAMIC_SCALE_MIN_MODULATION + \
                             (DYNAMIC_SCALE_MAX_MODULATION - DYNAMIC_SCALE_MIN_MODULATION) * (1.0 - ref_y_normalized)
        dynamic_modulation = np.clip(dynamic_modulation, DYNAMIC_SCALE_MIN_MODULATION * 0.8, DYNAMIC_SCALE_MAX_MODULATION * 1.2)
        final_draw_scale = base_skeleton_scale * dynamic_modulation

        # Add to the running sum for the average
        total_applied_scale += final_draw_scale
        drawn_players_count += 1

        # -- Draw the skeleton --
        kps_relative_to_ref = kps_img - ref_point_img
        for kp_idx1, kp_idx2 in SKELETON_EDGES:
            if scores[kp_idx1] > CONFIDENCE_THRESHOLD_KEYPOINTS and scores[kp_idx2] > CONFIDENCE_THRESHOLD_KEYPOINTS:
                pt1_map_offset = (mx, my) + kps_relative_to_ref[kp_idx1] * final_draw_scale
                pt2_map_offset = (mx, my) + kps_relative_to_ref[kp_idx2] * final_draw_scale
                pt1_draw = tuple(map(int, pt1_map_offset))
                pt2_draw = tuple(map(int, pt2_map_offset))
                # Safety check: only draw when both endpoints are within bounds
                h_map, w_map = minimap_bgr.shape[:2]
                if (0 <= pt1_draw[0] < w_map and 0 <= pt1_draw[1] < h_map and
                    0 <= pt2_draw[0] < w_map and 0 <= pt2_draw[1] < h_map):
                    cv2.line(minimap_bgr, pt1_draw, pt2_draw, drawing_color, SKELETON_THICKNESS, cv2.LINE_AA)  # drawing_color (black)

    # Compute the final average scale
    average_draw_scale = base_skeleton_scale  # Default when no player is drawn
    if drawn_players_count > 0:
        average_draw_scale = total_applied_scale / drawn_players_count

    return minimap_bgr, average_draw_scale  # Also return the average scale

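And the skeleton view end to end. The real `player_data_list` comes from `get_player_data` in pose_estimator and the homography from TvCalib; the dummy inputs below merely exercise the call and its documented fallback:

```python
import numpy as np
from visualizer import create_minimap_with_offset_skeletons

# Dummy inputs just to exercise the call; real data comes from pose_estimator/TvCalib.
H = np.eye(3, dtype=np.float32)  # identity homography (illustrative only)
player_data_list = []            # normally the output of get_player_data
minimap, avg_scale = create_minimap_with_offset_skeletons(player_data_list, H, base_skeleton_scale=0.5)
print(avg_scale)  # equals base_skeleton_scale when no player is drawn
```

The returned average scale is presumably what the `--target_avg_scale` option mentioned in the README is tuned against.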
# Simplified SoccerPitchSN definition, kept here only for the dimension constants
# (avoids importing the full, complex class).
# NB: the values mix metres and yards (e.g. GOAL_AREA_WIDTH is the metric 18.32 m)
# and are approximate; they only drive the drawing proportions.
class SoccerPitchSN:
    GOAL_LINE_TO_PENALTY_MARK = 11.0
    PENALTY_AREA_WIDTH = 42
    PENALTY_AREA_LENGTH = 19
    GOAL_AREA_WIDTH = 18.32
    GOAL_AREA_LENGTH = 5.5
    CENTER_CIRCLE_RADIUS = 10