minh9972t12 commited on
Commit
af98df1
·
1 Parent(s): a88bcc9

Update src/detection.py

Browse files
Files changed (1) hide show
  1. src/detection.py +49 -155
src/detection.py CHANGED
@@ -3,12 +3,9 @@ from typing import List, Dict, Tuple
3
  import cv2
4
  from pathlib import Path
5
  import yaml
6
- import gc
7
- import torch
8
- from datetime import datetime
9
 
10
  class YOLOv11Detector:
11
- """YOLOv11 detector for car damage detection with memory management"""
12
 
13
  def __init__(self, config_path: str = "config.yaml"):
14
  """Initialize YOLOv11 detector with configuration"""
@@ -38,54 +35,12 @@ class YOLOv11Detector:
38
  self.confidence = self.config['model']['confidence']
39
  self.iou_threshold = self.config['model']['iou_threshold']
40
  self.classes = self.config['detection']['classes']
41
-
42
- # Thêm tracking cho model reload
43
- self.model = None
44
- self.net = None
45
- self.inference_count = 0
46
- self.last_reload_time = datetime.now()
47
- self.max_inferences_before_reload = 100 # Reload sau 100 lần inference
48
- self.max_hours_before_reload = 2 # Reload sau 2 giờ
49
-
50
- # Load model lần đầu
51
- self._load_model()
52
-
53
- def _load_model(self):
54
- """Load hoặc reload model"""
55
- try:
56
- # Clear existing model
57
- if hasattr(self, 'model') and self.model is not None:
58
- del self.model
59
- if hasattr(self, 'net') and self.net is not None:
60
- del self.net
61
-
62
- # Force garbage collection
63
- gc.collect()
64
- if torch.cuda.is_available():
65
- torch.cuda.empty_cache()
66
-
67
- # Load model based on format
68
- if self.model_path.endswith('.onnx'):
69
- self._load_onnx_model()
70
- else: # .pt format
71
- self._load_pytorch_model()
72
-
73
- # Reset counters
74
- self.inference_count = 0
75
- self.last_reload_time = datetime.now()
76
-
77
- print(f"Model (re)loaded successfully at {self.last_reload_time}")
78
-
79
- except Exception as e:
80
- print(f"Error loading model: {e}")
81
- raise
82
-
83
- def _should_reload_model(self) -> bool:
84
- """Kiểm tra xem có cần reload model không"""
85
- time_diff = (datetime.now() - self.last_reload_time).total_seconds() / 3600
86
-
87
- return (self.inference_count >= self.max_inferences_before_reload or
88
- time_diff >= self.max_hours_before_reload)
89
 
90
  def _load_pytorch_model(self):
91
  """Load PyTorch model using Ultralytics"""
@@ -93,27 +48,17 @@ class YOLOv11Detector:
93
  self.model = YOLO(self.model_path)
94
 
95
  # Set model to appropriate device
96
- if self.device == 'cuda:0' and torch.cuda.is_available():
97
  self.model.to('cuda')
98
- else:
99
- self.model.to('cpu')
100
-
101
- # QUAN TRỌNG: Set model to evaluation mode
102
- self.model.model.eval() # Ultralytics YOLO có nested model
103
-
104
- # Freeze BatchNorm và Dropout
105
- for module in self.model.model.modules():
106
- if isinstance(module, (torch.nn.BatchNorm2d, torch.nn.Dropout)):
107
- module.eval()
108
 
109
- print(f"Loaded PyTorch model: {self.model_path} (eval mode)")
110
 
111
  def _load_onnx_model(self):
112
  """Load ONNX model using OpenCV DNN"""
113
  self.net = cv2.dnn.readNet(self.model_path)
114
 
115
  # Set backend based on device
116
- if self.device == 'cuda:0' and cv2.cuda.getCudaEnabledDeviceCount() > 0:
117
  self.net.setPreferableBackend(cv2.dnn.DNN_BACKEND_CUDA)
118
  self.net.setPreferableTarget(cv2.dnn.DNN_TARGET_CUDA)
119
  else:
@@ -124,7 +69,7 @@ class YOLOv11Detector:
124
 
125
  def detect(self, image: np.ndarray) -> Dict:
126
  """
127
- Perform detection on image with automatic model reload
128
 
129
  Args:
130
  image: Input image as numpy array (BGR format)
@@ -132,84 +77,53 @@ class YOLOv11Detector:
132
  Returns:
133
  Dictionary containing detection results
134
  """
135
- # Kiểm tra và reload model nếu cần
136
- if self._should_reload_model():
137
- print("Reloading model due to usage limit/time limit...")
138
- self._load_model()
139
-
140
- # Increment counter
141
- self.inference_count += 1
142
-
143
- try:
144
- if self.model_path.endswith('.onnx'):
145
- return self._detect_onnx(image)
146
- else:
147
- return self._detect_pytorch(image)
148
- except Exception as e:
149
- print(f"Detection failed, attempting model reload: {e}")
150
- # Thử reload model và detect lại
151
- self._load_model()
152
- if self.model_path.endswith('.onnx'):
153
- return self._detect_onnx(image)
154
- else:
155
- return self._detect_pytorch(image)
156
 
157
  def _detect_pytorch(self, image: np.ndarray) -> Dict:
158
  """Detection using PyTorch model"""
159
- try:
160
- # QUAN TRỌNG: Sử dụng no_grad để tắt gradient computation
161
- with torch.no_grad():
162
- # Đảm bảo model ở eval mode
163
- self.model.model.eval()
164
-
165
- # Run YOLO inference
166
- results = self.model(
167
- image,
168
- conf=self.confidence,
169
- iou=self.iou_threshold,
170
- device=self.device,
171
- verbose=False
172
- )
173
-
174
- # Parse results
175
- detections = {
176
- 'boxes': [],
177
- 'confidences': [],
178
- 'classes': [],
179
- 'class_ids': []
180
- }
181
 
182
- if len(results) > 0 and results[0].boxes is not None:
183
- boxes = results[0].boxes
 
 
 
 
 
184
 
185
- for box in boxes:
186
- # Get box coordinates (xyxy format)
187
- x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
188
 
189
- # Get confidence and class
190
- conf = float(box.conf[0].cpu().numpy())
191
- cls_id = int(box.cls[0].cpu().numpy())
192
 
193
- # Map class ID to class name
194
- if cls_id < len(self.classes):
195
- cls_name = self.classes[cls_id]
196
- else:
197
- cls_name = f"class_{cls_id}"
198
 
199
- detections['boxes'].append([int(x1), int(y1), int(x2), int(y2)])
200
- detections['confidences'].append(conf)
201
- detections['classes'].append(cls_name)
202
- detections['class_ids'].append(cls_id)
 
203
 
204
- # Clear GPU cache sau mỗi inference
205
- if torch.cuda.is_available():
206
- torch.cuda.empty_cache()
 
207
 
208
- return detections
209
-
210
- except Exception as e:
211
- print(f"PyTorch detection error: {e}")
212
- raise
213
 
214
  def _detect_onnx(self, image: np.ndarray) -> Dict:
215
  """Detection using ONNX model (compatible with original code)"""
@@ -287,24 +201,4 @@ class YOLOv11Detector:
287
 
288
  def detect_batch(self, images: List[np.ndarray]) -> List[Dict]:
289
  """Detect on multiple images"""
290
- return [self.detect(img) for img in images]
291
-
292
- def get_model_stats(self) -> Dict:
293
- """Trả về thống kê về model"""
294
- time_diff = (datetime.now() - self.last_reload_time).total_seconds() / 3600
295
-
296
- # Kiểm tra model mode
297
- model_mode = "unknown"
298
- if hasattr(self, 'model') and self.model is not None:
299
- if hasattr(self.model, 'model'):
300
- model_mode = "eval" if not self.model.model.training else "train"
301
-
302
- return {
303
- "inference_count": self.inference_count,
304
- "hours_since_reload": round(time_diff, 2),
305
- "last_reload_time": self.last_reload_time.isoformat(),
306
- "model_path": self.model_path,
307
- "device": self.device,
308
- "model_mode": model_mode, # Thêm thông tin model mode
309
- "torch_no_grad_enabled": not torch.is_grad_enabled()
310
- }
 
3
  import cv2
4
  from pathlib import Path
5
  import yaml
 
 
 
6
 
7
  class YOLOv11Detector:
8
+ """YOLOv11 detector for car damage detection"""
9
 
10
  def __init__(self, config_path: str = "config.yaml"):
11
  """Initialize YOLOv11 detector with configuration"""
 
35
  self.confidence = self.config['model']['confidence']
36
  self.iou_threshold = self.config['model']['iou_threshold']
37
  self.classes = self.config['detection']['classes']
38
+
39
+ # Load model based on format
40
+ if model_path.endswith('.onnx'):
41
+ self._load_onnx_model()
42
+ else: # .pt format
43
+ self._load_pytorch_model()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
  def _load_pytorch_model(self):
46
  """Load PyTorch model using Ultralytics"""
 
48
  self.model = YOLO(self.model_path)
49
 
50
  # Set model to appropriate device
51
+ if self.device == 'cuda:0':
52
  self.model.to('cuda')
 
 
 
 
 
 
 
 
 
 
53
 
54
+ print(f"Loaded PyTorch model: {self.model_path}")
55
 
56
  def _load_onnx_model(self):
57
  """Load ONNX model using OpenCV DNN"""
58
  self.net = cv2.dnn.readNet(self.model_path)
59
 
60
  # Set backend based on device
61
+ if self.device == 'cuda:0':
62
  self.net.setPreferableBackend(cv2.dnn.DNN_BACKEND_CUDA)
63
  self.net.setPreferableTarget(cv2.dnn.DNN_TARGET_CUDA)
64
  else:
 
69
 
70
  def detect(self, image: np.ndarray) -> Dict:
71
  """
72
+ Perform detection on image
73
 
74
  Args:
75
  image: Input image as numpy array (BGR format)
 
77
  Returns:
78
  Dictionary containing detection results
79
  """
80
+ if self.model_path.endswith('.onnx'):
81
+ return self._detect_onnx(image)
82
+ else:
83
+ return self._detect_pytorch(image)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
 
85
  def _detect_pytorch(self, image: np.ndarray) -> Dict:
86
  """Detection using PyTorch model"""
87
+ # Run YOLO inference
88
+ results = self.model(
89
+ image,
90
+ conf=self.confidence,
91
+ iou=self.iou_threshold,
92
+ device=self.device,
93
+ verbose=False
94
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
 
96
+ # Parse results
97
+ detections = {
98
+ 'boxes': [],
99
+ 'confidences': [],
100
+ 'classes': [],
101
+ 'class_ids': []
102
+ }
103
 
104
+ if len(results) > 0 and results[0].boxes is not None:
105
+ boxes = results[0].boxes
 
106
 
107
+ for box in boxes:
108
+ # Get box coordinates (xyxy format)
109
+ x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
110
 
111
+ # Get confidence and class
112
+ conf = float(box.conf[0].cpu().numpy())
113
+ cls_id = int(box.cls[0].cpu().numpy())
 
 
114
 
115
+ # Map class ID to class name
116
+ if cls_id < len(self.classes):
117
+ cls_name = self.classes[cls_id]
118
+ else:
119
+ cls_name = f"class_{cls_id}"
120
 
121
+ detections['boxes'].append([int(x1), int(y1), int(x2), int(y2)])
122
+ detections['confidences'].append(conf)
123
+ detections['classes'].append(cls_name)
124
+ detections['class_ids'].append(cls_id)
125
 
126
+ return detections
 
 
 
 
127
 
128
  def _detect_onnx(self, image: np.ndarray) -> Dict:
129
  """Detection using ONNX model (compatible with original code)"""
 
201
 
202
  def detect_batch(self, images: List[np.ndarray]) -> List[Dict]:
203
  """Detect on multiple images"""
204
+ return [self.detect(img) for img in images]