Léo Bourrel commited on
Commit
bc49d5d
·
unverified ·
2 Parent(s): 4ef2d19 345d537

Merge pull request #3 from ia-labo/feat/prediction

Browse files
Files changed (1) hide show
  1. src/predict.py +95 -61
src/predict.py CHANGED
@@ -1,81 +1,115 @@
1
- """Predict the class of a list of images, and return a json file with the prediction"""
2
-
3
- import json
4
 
5
  from ultralytics import YOLO
6
 
7
- # [{
8
- # "image": "images/ChIJ_yEAweFlXj4Reo-x-AghMRM",
9
- # "prediction": [
10
- # {
11
- # "class": "person",
12
- # "confidence": 0.99,
13
- # "box": [0.0, 0.0, 0.0, 0.0]
14
- # }
15
- # ]
16
- # }]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
 
19
- def add_prediction_to_json(predictions) -> None:
20
- """Add the prediction to the json file.
 
21
 
22
  Args:
23
- predictions (generator): generator of predictions from the YOLO model
 
 
 
24
  Returns:
25
- list: list of predictions to add to the json result
26
  """
27
- formatted_predictions = []
28
- for result in predictions:
29
- for box in result.boxes:
30
- formatted_predictions.append(
31
- {
32
- "image": result.path.split("/")[-1],
33
- "class": int(box.cls[0].numpy().item()),
34
- "confidence": box.conf[0].numpy().item(),
35
- "box": box.xyxy[0].numpy().tolist(),
36
- }
 
37
  )
38
- return formatted_predictions
 
 
39
 
40
 
41
- def get_prediction(list_folder, model_file, output_file) -> None:
42
- """Get the prediction for a list of images
 
43
 
44
  Args:
45
- list_path (list): list of images
46
- model_file (str): path to the model
47
- output_file (str): path to the output file
 
 
48
  """
 
 
 
 
 
49
  model = YOLO(model_file)
50
 
51
- with open(output_file, "w+", encoding="utf8") as file:
52
- dict_predictions = []
53
- for folder in list_folder[:2]:
54
- # stream=True to get the prediction for each image
55
- # instead of trying to get all the predictions at once
56
- # show=False to not show the image when getting the prediction
57
- predictions = model(folder, stream=True, show=False)
58
- try:
59
- dict_predictions.append(
60
- {
61
- "park": folder.split("/")[-1],
62
- "prediction": add_prediction_to_json(predictions),
63
- }
64
- )
65
- except AttributeError as att_error:
66
- dict_predictions.append(
67
- {
68
- "status": "error",
69
- "error_name": str(att_error.__class__.__name__),
70
- "message": str(att_error),
71
- }
72
- )
73
- json.dump(dict_predictions, file, indent=4)
 
74
 
75
 
76
  if __name__ == "__main__":
77
- get_prediction(
78
- ["../images/Al Khaldiyah Park", "../images/Family Park"],
79
- "models/yolov5s.pt",
80
- "./predictions.json",
81
- )
 
1
+ from datetime import datetime
 
 
2
 
3
  from ultralytics import YOLO
4
 
5
+ from models.bbox_manager import BoundingBoxManager
6
+ from models.image_manager import ImageManager
7
+ from models.park_manager import ParkManager
8
+
9
+
10
+ def add_image_to_db(result, image_manager, park_id):
11
+ """
12
+ Adds an image to the database.
13
+
14
+ Args:
15
+ result: YOLO result object containing the image details.
16
+ image_manager (ImageManager): Instance of ImageManager to interact with the database.
17
+ park_id (int): The ID of the associated park.
18
+
19
+ Returns:
20
+ int: The ID of the added image.
21
+ """
22
+ image_name = result.path.split("/")[-1]
23
+ current_datetime = datetime.now()
24
+ created_at = current_datetime.strftime("%Y-%m-%d %H:%M:%S")
25
+
26
+ try:
27
+ print(f"Adding image '{image_name}' to the database...")
28
+ image = image_manager.add_image(
29
+ name=image_name, created_at=created_at, park_id=park_id
30
+ )
31
+ image_id = image_manager.get_image_id(image_name)
32
+ print(f"Image '{image_name}' added with ID: {image_id}")
33
+ return image_id
34
+ except Exception as e:
35
+ print(f"Error adding image '{image_name}' to the database: {e}")
36
+ return None
37
 
38
 
39
+ def add_bboxes_to_db(result, bbox_manager, image_id):
40
+ """
41
+ Adds bounding boxes from a YOLO result to the database.
42
 
43
  Args:
44
+ result: YOLO result object containing the bounding box details.
45
+ bbox_manager (BoundingBoxManager): Instance of BoundingBoxManager to interact with the database.
46
+ image_id (int): The ID of the associated image.
47
+
48
  Returns:
49
+ None
50
  """
51
+ for box in result.boxes:
52
+ try:
53
+ print(f"Adding bounding box for image ID: {image_id}...")
54
+ bbox_manager.add_bbox(
55
+ confidence=box.conf[0].numpy().item(),
56
+ class_id=int(box.cls[0].numpy().item()),
57
+ img_id=image_id,
58
+ x_min=box.xyxy[0][0].numpy().item(),
59
+ y_min=box.xyxy[0][1].numpy().item(),
60
+ x_max=box.xyxy[0][2].numpy().item(),
61
+ y_max=box.xyxy[0][3].numpy().item(),
62
  )
63
+ print(f"Bounding box added to DB for image ID: {image_id}")
64
+ except Exception as e:
65
+ print(f"Error inserting bounding box into DB: {e}")
66
 
67
 
68
+ def process_predictions(list_folder, model_file):
69
+ """
70
+ Process predictions for a list of image folders, adding images and bounding boxes to the database.
71
 
72
  Args:
73
+ list_folder (list): List of image folders.
74
+ model_file (str): Path to the YOLO model.
75
+
76
+ Raises:
77
+ Exception: If an error occurs during processing.
78
  """
79
+ print(f"Initializing processing of predictions...")
80
+ bbox_manager = BoundingBoxManager()
81
+ image_manager = ImageManager()
82
+ park_manager = ParkManager()
83
+
84
  model = YOLO(model_file)
85
 
86
+ for folder in list_folder:
87
+ print(f"Processing folder: {folder}...")
88
+ park_name = folder.split("/")[-1]
89
+
90
+ print(f"Retrieving park ID for '{park_name}'...")
91
+ park_id = park_manager.get_park_id(park_name)
92
+ if not park_id:
93
+ print(f"Park '{park_name}' not found in the database. Skipping folder.")
94
+ continue
95
+
96
+ predictions = model(folder, stream=True, show=False)
97
+
98
+ for result in predictions:
99
+ print(f"Processing result for image: {result.path.split('/')[-1]}...")
100
+ image_id = add_image_to_db(result, image_manager, park_id)
101
+
102
+ if image_id:
103
+ print(f"Adding bounding boxes to the database...")
104
+ add_bboxes_to_db(result, bbox_manager, image_id)
105
+
106
+ print(f"Processing complete. Closing database connections.")
107
+ bbox_manager.close_connection()
108
+ image_manager.close_connection()
109
+ park_manager.close_connection()
110
 
111
 
112
  if __name__ == "__main__":
113
+ list_folder = ["../images/Al Khaldiyah Park", "../images/Family Park"]
114
+ model_file = "yolov5s.pt"
115
+ process_predictions(list_folder, model_file)