leo-bourrel commited on
Commit
631ca76
·
1 Parent(s): 58ddda3

feat: make image loop in predict.py

Browse files
Files changed (2) hide show
  1. src/predict.py +49 -1
  2. src/yolo_predictions.py +12 -22
src/predict.py CHANGED
@@ -1,10 +1,58 @@
1
  import os
2
  from tqdm import tqdm
3
  from yolo_predictions import process_YOLO_predictions
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
 
6
  if __name__ == "__main__":
7
  list_folder = ["./images/Al Khaldiyah Park", "./images/Family Park"]
8
  model_file = "yolo11s.pt"
9
 
10
- process_YOLO_predictions(list_folder, model_file)
 
1
  import os
2
  from tqdm import tqdm
3
  from yolo_predictions import process_YOLO_predictions
4
+ from openai_predictions import process_agent_predictions
5
+ from data_models.park_manager import ParkManager
6
+
7
+
8
+ def load_images_from_folder(folder):
9
+ """
10
+ Load images from a folder.
11
+
12
+ Args:
13
+ folder (str): Path to the folder.
14
+
15
+ Returns:
16
+ list: List of image paths.
17
+ """
18
+ for file in os.listdir(folder):
19
+ if file.endswith(".jpg") or file.endswith(".png"):
20
+ yield os.path.join(folder, file)
21
+
22
+
23
+ def process_predictions(list_folder, model_file):
24
+ """
25
+ Process predictions for a list of image folders, adding images and bounding boxes to the database.
26
+
27
+ Args:
28
+ list_folder (list): List of image folders.
29
+ model_file (str): Path to the YOLO model.
30
+
31
+ Raises:
32
+ Exception: If an error occurs during processing.
33
+ """
34
+ park_manager = ParkManager()
35
+
36
+ for folder in list_folder:
37
+ print(f"Loading images from folder: {folder}...")
38
+
39
+ park_name = folder.split("/")[-1]
40
+
41
+ print(f"Retrieving park ID for '{park_name}'...")
42
+ park_id = park_manager.get_park_id(park_name)
43
+ if not park_id:
44
+ print(f"Park '{park_name}' not found in the database. Skipping folder.")
45
+ continue
46
+
47
+ for image in tqdm(load_images_from_folder(folder)):
48
+ process_YOLO_predictions(park_id, image, model_file)
49
+ # process_agent_predictions(park_id, image)
50
+ print(f"Folders {folder} processed successfully!")
51
+ park_manager.close_connection()
52
 
53
 
54
  if __name__ == "__main__":
55
  list_folder = ["./images/Al Khaldiyah Park", "./images/Family Park"]
56
  model_file = "yolo11s.pt"
57
 
58
+ process_predictions(list_folder, model_file)
src/yolo_predictions.py CHANGED
@@ -49,6 +49,10 @@ def load_model(model_file):
49
  Returns:
50
  YOLO: The YOLO model.
51
  """
 
 
 
 
52
  try:
53
  print(f"Loading YOLO model from file: {model_file}...")
54
  model = YOLO(model_file)
@@ -59,7 +63,7 @@ def load_model(model_file):
59
  return None
60
 
61
 
62
- def process_YOLO_predictions(list_folder: list[str], model_file: str):
63
  """
64
  Process predictions for a list of image folders, adding images and bounding boxes to the database using the YOLO model.
65
 
@@ -71,35 +75,21 @@ def process_YOLO_predictions(list_folder: list[str], model_file: str):
71
  Exception: If an error occurs during processing.
72
  """
73
  print(f"Initializing processing of predictions...")
74
- bbox_manager = BoundingBoxManager()
75
- image_manager = ImageManager()
76
- park_manager = ParkManager()
77
 
78
  model = load_model(model_file)
79
  if not model:
80
  raise Exception("Error loading YOLO model. Aborting processing.")
81
 
82
- for folder in list_folder:
83
- print(f"Processing folder: {folder}...")
84
- park_name = folder.split("/")[-1]
85
-
86
- print(f"Retrieving park ID for '{park_name}'...")
87
- park_id = park_manager.get_park_id(park_name)
88
- if not park_id:
89
- print(f"Park '{park_name}' not found in the database. Skipping folder.")
90
- continue
91
-
92
- predictions = model(folder, stream=True, show=False)
93
 
94
- for result in predictions:
95
- print(f"Processing result for image: {result.path.split('/')[-1]}...")
96
- image_id = add_image_to_db(result, image_manager, park_id)
97
 
98
- if image_id:
99
- print(f"Adding bounding boxes to the database...")
100
- add_bboxes_to_db(result, bbox_manager, image_id)
101
 
102
  print(f"Processing complete. Closing database connections.")
103
  bbox_manager.close_connection()
104
  image_manager.close_connection()
105
- park_manager.close_connection()
 
49
  Returns:
50
  YOLO: The YOLO model.
51
  """
52
+ global model
53
+ if model:
54
+ return model
55
+
56
  try:
57
  print(f"Loading YOLO model from file: {model_file}...")
58
  model = YOLO(model_file)
 
63
  return None
64
 
65
 
66
+ def process_YOLO_predictions(park_id: int, image: str, model_file: str):
67
  """
68
  Process predictions for a list of image folders, adding images and bounding boxes to the database using the YOLO model.
69
 
 
75
  Exception: If an error occurs during processing.
76
  """
77
  print(f"Initializing processing of predictions...")
 
 
 
78
 
79
  model = load_model(model_file)
80
  if not model:
81
  raise Exception("Error loading YOLO model. Aborting processing.")
82
 
83
+ predictions = model(image, stream=True, show=False)
 
 
 
 
 
 
 
 
 
 
84
 
85
+ for result in predictions:
86
+ print(f"Processing result for image: {result.path.split('/')[-1]}...")
87
+ image_id = add_image_to_db(result, image_manager, park_id)
88
 
89
+ if image_id:
90
+ print(f"Adding bounding boxes to the database...")
91
+ add_bboxes_to_db(result, bbox_manager, image_id)
92
 
93
  print(f"Processing complete. Closing database connections.")
94
  bbox_manager.close_connection()
95
  image_manager.close_connection()