Commit af98df1
Parent(s): a88bcc9
Update src/detection.py
src/detection.py CHANGED (+49 -155)
@@ -3,12 +3,9 @@ from typing import List, Dict, Tuple
 import cv2
 from pathlib import Path
 import yaml
-import gc
-import torch
-from datetime import datetime
 
 class YOLOv11Detector:
-    """YOLOv11 detector for car damage detection
+    """YOLOv11 detector for car damage detection"""
 
     def __init__(self, config_path: str = "config.yaml"):
         """Initialize YOLOv11 detector with configuration"""
@@ -38,54 +35,12 @@ class YOLOv11Detector:
         self.confidence = self.config['model']['confidence']
         self.iou_threshold = self.config['model']['iou_threshold']
         self.classes = self.config['detection']['classes']
-
-        #
-
-
-
-
-        self.max_inferences_before_reload = 100  # Reload after 100 inferences
-        self.max_hours_before_reload = 2  # Reload after 2 hours
-
-        # Load model for the first time
-        self._load_model()
-
-    def _load_model(self):
-        """Load or reload the model"""
-        try:
-            # Clear existing model
-            if hasattr(self, 'model') and self.model is not None:
-                del self.model
-            if hasattr(self, 'net') and self.net is not None:
-                del self.net
-
-            # Force garbage collection
-            gc.collect()
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
-
-            # Load model based on format
-            if self.model_path.endswith('.onnx'):
-                self._load_onnx_model()
-            else:  # .pt format
-                self._load_pytorch_model()
-
-            # Reset counters
-            self.inference_count = 0
-            self.last_reload_time = datetime.now()
-
-            print(f"Model (re)loaded successfully at {self.last_reload_time}")
-
-        except Exception as e:
-            print(f"Error loading model: {e}")
-            raise
-
-    def _should_reload_model(self) -> bool:
-        """Check whether the model needs to be reloaded"""
-        time_diff = (datetime.now() - self.last_reload_time).total_seconds() / 3600
-
-        return (self.inference_count >= self.max_inferences_before_reload or
-                time_diff >= self.max_hours_before_reload)
+
+        # Load model based on format
+        if model_path.endswith('.onnx'):
+            self._load_onnx_model()
+        else:  # .pt format
+            self._load_pytorch_model()
 
     def _load_pytorch_model(self):
         """Load PyTorch model using Ultralytics"""
@@ -93,27 +48,17 @@ class YOLOv11Detector:
         self.model = YOLO(self.model_path)
 
         # Set model to appropriate device
-        if self.device == 'cuda:0'
+        if self.device == 'cuda:0':
             self.model.to('cuda')
-        else:
-            self.model.to('cpu')
-
-        # IMPORTANT: Set model to evaluation mode
-        self.model.model.eval()  # Ultralytics YOLO has a nested model
-
-        # Freeze BatchNorm and Dropout
-        for module in self.model.model.modules():
-            if isinstance(module, (torch.nn.BatchNorm2d, torch.nn.Dropout)):
-                module.eval()
 
-        print(f"Loaded PyTorch model: {self.model_path}
+        print(f"Loaded PyTorch model: {self.model_path}")
 
     def _load_onnx_model(self):
         """Load ONNX model using OpenCV DNN"""
         self.net = cv2.dnn.readNet(self.model_path)
 
         # Set backend based on device
-        if self.device == 'cuda:0'
+        if self.device == 'cuda:0':
             self.net.setPreferableBackend(cv2.dnn.DNN_BACKEND_CUDA)
             self.net.setPreferableTarget(cv2.dnn.DNN_TARGET_CUDA)
         else:
@@ -124,7 +69,7 @@ class YOLOv11Detector:
 
     def detect(self, image: np.ndarray) -> Dict:
         """
-        Perform detection on image
+        Perform detection on image
 
         Args:
            image: Input image as numpy array (BGR format)
@@ -132,84 +77,53 @@ class YOLOv11Detector:
         Returns:
            Dictionary containing detection results
         """
-
-
-
-        self.
-
-        # Increment counter
-        self.inference_count += 1
-
-        try:
-            if self.model_path.endswith('.onnx'):
-                return self._detect_onnx(image)
-            else:
-                return self._detect_pytorch(image)
-        except Exception as e:
-            print(f"Detection failed, attempting model reload: {e}")
-            # Try reloading the model and detecting again
-            self._load_model()
-            if self.model_path.endswith('.onnx'):
-                return self._detect_onnx(image)
-            else:
-                return self._detect_pytorch(image)
+        if self.model_path.endswith('.onnx'):
+            return self._detect_onnx(image)
+        else:
+            return self._detect_pytorch(image)
 
     def _detect_pytorch(self, image: np.ndarray) -> Dict:
         """Detection using PyTorch model"""
-
-
-
-
-
-
-
-
-            image,
-            conf=self.confidence,
-            iou=self.iou_threshold,
-            device=self.device,
-            verbose=False
-        )
-
-        # Parse results
-        detections = {
-            'boxes': [],
-            'confidences': [],
-            'classes': [],
-            'class_ids': []
-        }
+        # Run YOLO inference
+        results = self.model(
+            image,
+            conf=self.confidence,
+            iou=self.iou_threshold,
+            device=self.device,
+            verbose=False
+        )
 
-
-
+        # Parse results
+        detections = {
+            'boxes': [],
+            'confidences': [],
+            'classes': [],
+            'class_ids': []
+        }
 
-
-
-                x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
+        if len(results) > 0 and results[0].boxes is not None:
+            boxes = results[0].boxes
 
-
-
-
+            for box in boxes:
+                # Get box coordinates (xyxy format)
+                x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
 
-
-
-
-                else:
-                    cls_name = f"class_{cls_id}"
+                # Get confidence and class
+                conf = float(box.conf[0].cpu().numpy())
+                cls_id = int(box.cls[0].cpu().numpy())
 
-
-
-
-
+                # Map class ID to class name
+                if cls_id < len(self.classes):
+                    cls_name = self.classes[cls_id]
+                else:
+                    cls_name = f"class_{cls_id}"
 
-
-
-
+                detections['boxes'].append([int(x1), int(y1), int(x2), int(y2)])
+                detections['confidences'].append(conf)
+                detections['classes'].append(cls_name)
+                detections['class_ids'].append(cls_id)
 
-
-
-        except Exception as e:
-            print(f"PyTorch detection error: {e}")
-            raise
+        return detections
 
     def _detect_onnx(self, image: np.ndarray) -> Dict:
         """Detection using ONNX model (compatible with original code)"""
@@ -287,24 +201,4 @@ class YOLOv11Detector:
 
     def detect_batch(self, images: List[np.ndarray]) -> List[Dict]:
         """Detect on multiple images"""
-        return [self.detect(img) for img in images]
-
-    def get_model_stats(self) -> Dict:
-        """Return statistics about the model"""
-        time_diff = (datetime.now() - self.last_reload_time).total_seconds() / 3600
-
-        # Check model mode
-        model_mode = "unknown"
-        if hasattr(self, 'model') and self.model is not None:
-            if hasattr(self.model, 'model'):
-                model_mode = "eval" if not self.model.model.training else "train"
-
-        return {
-            "inference_count": self.inference_count,
-            "hours_since_reload": round(time_diff, 2),
-            "last_reload_time": self.last_reload_time.isoformat(),
-            "model_path": self.model_path,
-            "device": self.device,
-            "model_mode": model_mode,  # Include model mode info
-            "torch_no_grad_enabled": not torch.is_grad_enabled()
-        }
+        return [self.detect(img) for img in images]
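
For reference, a minimal usage sketch of the simplified detector after this commit. It assumes the class is importable as src.detection.YOLOv11Detector, that config.yaml supplies the model path, confidence, iou_threshold, device, and detection classes read in __init__, and that the image path below is hypothetical.

# Minimal usage sketch (assumptions: import path src.detection, a config.yaml with the
# model/detection keys read in __init__, and a hypothetical image file sample_car.jpg).
import cv2
from src.detection import YOLOv11Detector

detector = YOLOv11Detector(config_path="config.yaml")

# detect() expects a BGR numpy array, as produced by cv2.imread
image = cv2.imread("sample_car.jpg")
result = detector.detect(image)

# The result dict holds parallel lists, as built in _detect_pytorch:
#   'boxes'       -> [x1, y1, x2, y2] integer pixel coordinates
#   'confidences' -> float scores above the configured confidence threshold
#   'classes'     -> names from detection.classes, or "class_<id>" for unknown ids
#   'class_ids'   -> raw integer class indices
for box, conf, name in zip(result['boxes'], result['confidences'], result['classes']):
    print(name, conf, box)

# Batch helper kept by this commit
batch_results = detector.detect_batch([image, image])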