yolo12138 commited on
Commit
efafe9b
·
1 Parent(s): d26f281
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ __pycache__/
2
+ coverage/
3
+ .DS_Store
app.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ # import cv2
3
+ import os
4
+ import base64
5
+ from pathlib import Path
6
+
7
+ from core.poker_detector import PokerDetector
8
+
9
+ detector = PokerDetector(
10
+ model_path="onnx/poker_detection_v4_rank.onnx"
11
+ )
12
+
13
+
14
+ ### 构建 examples
15
+
16
+ def build_examples():
17
+ examples = []
18
+ # 读取 examples 目录下的所有图片
19
+ for file in os.listdir("examples"):
20
+ if file.endswith(".jpg") or file.endswith(".png"):
21
+ image_path = os.path.join("examples", file)
22
+ examples.append([image_path])
23
+
24
+ return examples
25
+
26
+
27
+ full_examples = build_examples()
28
+
29
+
30
+ with gr.Blocks(css="""
31
+ .image img {
32
+ max-height: 512px;
33
+ }
34
+ """
35
+ ) as demo:
36
+ gr.Markdown("""
37
+ ## 扑克牌检测
38
+ """
39
+ )
40
+ with gr.Row():
41
+ with gr.Column():
42
+ image_input = gr.Image(label="上传扑克牌图片", type="numpy", elem_classes="image")
43
+
44
+ with gr.Column():
45
+ with gr.Column():
46
+ result_image = gr.Image(
47
+ label="检测结果",
48
+ interactive=False,
49
+ visible=True,
50
+ elem_classes="image"
51
+ )
52
+
53
+ with gr.Column():
54
+ use_time = gr.Textbox(
55
+ label="用时",
56
+ interactive=False,
57
+ visible=True,
58
+ )
59
+
60
+ with gr.Row():
61
+ with gr.Column():
62
+ gr.Examples(
63
+ full_examples[:10], inputs=[image_input], label="示例图片", examples_per_page=10,)
64
+
65
+
66
+ def detect_poker(image):
67
+ if image is None:
68
+ return None, ""
69
+
70
+ try:
71
+ image_rgb_with_pred, time_info = detector.pred_and_draw(image)
72
+
73
+ except Exception as e:
74
+ gr.Warning(f"检测失败: {e}")
75
+ return None, "检测失败"
76
+
77
+
78
+ return image_rgb_with_pred, time_info
79
+
80
+ image_input.change(fn=detect_poker,
81
+ inputs=[image_input],
82
+ outputs=[result_image, use_time])
83
+
84
+ if __name__ == "__main__":
85
+ demo.launch()
core/poker_detector.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import time
3
+ import numpy as np
4
+ import cv2
5
+ from typing import List, Tuple, Union
6
+ from .runonnx.common_detection import COMMON_DETECTION_ONNX
7
+
8
+ class PokerDetector:
9
+ def __init__(self,
10
+ model_path: str,
11
+ ):
12
+
13
+ self.poker_detection = COMMON_DETECTION_ONNX(
14
+ model_path=model_path,
15
+ labels=['A', '2', '3', '4', '5', '6', '7', '8', '9', '10', 'J', 'Q', 'K', 'R', 'B'],
16
+ )
17
+
18
+ # 检测棋盘 detect board
19
+ def pred_and_draw(self, image_rgb: Union[np.ndarray, None] = None) -> Tuple[Union[np.ndarray, None], str]:
20
+
21
+ if image_rgb is None:
22
+ return None, ""
23
+
24
+ start_time = time.time()
25
+
26
+ try:
27
+ image_bgr = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR)
28
+ origin_boxes, filtered_scores, label_names = self.poker_detection.pred(image=image_bgr)
29
+
30
+ # draw
31
+ image_rgb_with_pred = self.poker_detection.draw_pred(image_rgb, boxes=origin_boxes, scores=filtered_scores, labels=label_names)
32
+
33
+ except Exception as e:
34
+ print("检测失败2", e)
35
+ return None, "检测失败2"
36
+
37
+ use_time = time.time() - start_time
38
+
39
+ time_info = f"推理用时: {use_time:.2f}s"
40
+
41
+ return image_rgb_with_pred, time_info
42
+
core/runonnx/base_onnx.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import onnxruntime
2
+ import numpy as np
3
+ import cv2
4
+ from abc import ABC, abstractmethod
5
+ from typing import Any, Tuple, Union, List
6
+
7
+ class BaseONNX(ABC):
8
+ def __init__(self, model_path: str, input_size: Tuple[int, int]):
9
+ """初始化ONNX模型基类
10
+
11
+ Args:
12
+ model_path (str): ONNX模型路径
13
+ input_size (tuple): 模型输入尺寸 (width, height)
14
+ """
15
+ self.session = onnxruntime.InferenceSession(model_path)
16
+ self.input_name = self.session.get_inputs()[0].name
17
+ self.input_size = input_size
18
+
19
+ def load_image(self, image: Union[cv2.UMat, str]) -> cv2.UMat:
20
+ """加载图像
21
+
22
+ Args:
23
+ image (Union[cv2.UMat, str]): 图像路径或cv2图像对象
24
+
25
+ Returns:
26
+ cv2.UMat: 加载的图像
27
+ """
28
+ if isinstance(image, str):
29
+ return cv2.imread(image)
30
+ return image.copy()
31
+
32
+ @abstractmethod
33
+ def preprocess_image(self, img_bgr: cv2.UMat, *args, **kwargs) -> np.ndarray:
34
+ """图像预处理抽象方法
35
+
36
+ Args:
37
+ img_bgr (cv2.UMat): BGR格式的输入图像
38
+
39
+ Returns:
40
+ np.ndarray: 预处理后的图像
41
+ """
42
+ pass
43
+
44
+ @abstractmethod
45
+ def run_inference(self, image: np.ndarray) -> Any:
46
+ """运行推理的抽象方法
47
+
48
+ Args:
49
+ image (np.ndarray): 预处理后的输入图像
50
+
51
+ Returns:
52
+ Any: 模型输出结果
53
+ """
54
+ pass
55
+
56
+ @abstractmethod
57
+ def pred(self, image: Union[cv2.UMat, str], *args, **kwargs) -> Any:
58
+ """预测的抽象方法
59
+
60
+ Args:
61
+ image (Union[cv2.UMat, str]): 输入图像或图像路径
62
+
63
+ Returns:
64
+ Any: 预测结果
65
+ """
66
+ pass
67
+
68
+ @abstractmethod
69
+ def draw_pred(self, img: cv2.UMat, *args, **kwargs) -> cv2.UMat:
70
+ """绘制预测结果的抽象方法
71
+
72
+ Args:
73
+ img (cv2.UMat): 要绘制的图像
74
+
75
+ Returns:
76
+ cv2.UMat: 绘制结果后的图像
77
+ """
78
+ pass
79
+
80
+
81
+ def check_images_list(self, images: List[Union[cv2.UMat, str, np.ndarray]]):
82
+ """
83
+ 检查图像列表是否有效
84
+ """
85
+ for image in images:
86
+ if not isinstance(image, cv2.UMat) and not isinstance(image, str) and not isinstance(image, np.ndarray):
87
+ raise ValueError("The images must be a list of cv2.UMat or str or np.ndarray.")
88
+
core/runonnx/common_detection.py ADDED
@@ -0,0 +1,244 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # import onnxruntime
2
+ import numpy as np
3
+ import cv2
4
+
5
+ from typing import Tuple, List, Union
6
+ from .base_onnx import BaseONNX
7
+
8
+ class COMMON_DETECTION_ONNX(BaseONNX):
9
+
10
+ def __init__(self,
11
+ model_path,
12
+ labels: List[str],
13
+ # 输入图片大小
14
+ input_size=(640, 640), # (w, h)
15
+ iou_threshold: float = 0.5,
16
+ score_threshold: float = 0.2,
17
+ ):
18
+ super().__init__(model_path, input_size)
19
+
20
+ self.labels = labels
21
+ self.label_colors = []
22
+ for i in range(len(labels)):
23
+ self.label_colors.append((np.random.randint(0, 255), np.random.randint(0, 255), np.random.randint(0, 255)))
24
+
25
+ self.iou_threshold = iou_threshold
26
+ self.score_threshold = score_threshold
27
+
28
+ def preprocess_image(self, image: cv2.UMat, to_rgb: bool = True) -> Tuple[np.ndarray, float, Tuple[int, int]]:
29
+
30
+ if to_rgb:
31
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
32
+
33
+ target_size = self.input_size
34
+ ori_shape = image.shape[:2]
35
+
36
+ # 1. Resize with keep_ratio=True
37
+ h, w = image.shape[:2]
38
+ scale = min(target_size[0] / h, target_size[1] / w)
39
+ new_h, new_w = int(h * scale), int(w * scale)
40
+ resized = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
41
+
42
+ # 2. Pad to 640x640
43
+ pad_h = target_size[0] - new_h
44
+ pad_w = target_size[1] - new_w
45
+ top, bottom = 0, pad_h
46
+ left, right = 0, pad_w
47
+
48
+ padded = cv2.copyMakeBorder(
49
+ resized, top, bottom, left, right,
50
+ cv2.BORDER_CONSTANT, value=(114, 114, 114)
51
+ )
52
+
53
+ # img = img.astype(np.float32)
54
+
55
+ # 3. Normalize (BGR format, matching mmdet pipeline)
56
+ mean = np.array([103.53, 116.28, 123.675], dtype=np.float32)
57
+ std = np.array([57.375, 57.12, 58.395], dtype=np.float32)
58
+
59
+ normalized = (padded.astype(np.float32) - mean) / std
60
+
61
+ # 4. Convert to (C, H, W) and add batch dimension
62
+ input_tensor = normalized.transpose(2, 0, 1)[np.newaxis, ...]
63
+
64
+ return input_tensor, scale, ori_shape
65
+
66
+ def post_bbox(self, boxes, origin_shape, scale):
67
+ """
68
+ 将onnx的输出结果转换为mmdet的输出结果, 与 preprocess_image 中 的预处理相反
69
+ boxes: (N, 4) x1, y1, x2, y2
70
+ origin_shape: (H, W)
71
+ scale: 缩放因子,从 preprocess_image 获取
72
+ return: (N, 4) x1, y1, x2, y2
73
+ """
74
+ if boxes is None or len(boxes) == 0:
75
+ return boxes
76
+
77
+ boxes = boxes.copy()
78
+
79
+ # 如果没有提供scale,假设是640x640输入,根据origin_shape计算scale
80
+ if scale is None:
81
+ target_size = 640
82
+ h, w = origin_shape
83
+ scale = min(target_size / h, target_size / w)
84
+
85
+ # 将坐标从缩放后的图像空间转换回原始图像空间
86
+ boxes /= scale
87
+
88
+ # 裁剪到原始图像边界内
89
+ h, w = origin_shape
90
+ boxes[:, 0] = np.clip(boxes[:, 0], 0, w) # x1
91
+ boxes[:, 1] = np.clip(boxes[:, 1], 0, h) # y1
92
+ boxes[:, 2] = np.clip(boxes[:, 2], 0, w) # x2
93
+ boxes[:, 3] = np.clip(boxes[:, 3], 0, h) # y2
94
+
95
+ return boxes
96
+
97
+
98
+ def filter_results(self, boxes: np.ndarray, scores: np.ndarray, labels: np.ndarray, iou_threshold: float, score_threshold: float) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
99
+ """
100
+ Filter the boxes based on the iou_threshold and score_threshold.
101
+ """
102
+ mask_score = scores >= score_threshold
103
+
104
+
105
+ # 1. 过滤掉 score 小于 score_threshold 的 boxes
106
+ target_boxes = boxes[mask_score]
107
+ target_scores = scores[mask_score]
108
+ target_labels = labels[mask_score]
109
+
110
+ # 2. 过滤掉 iou 小于 iou_threshold 的 boxes
111
+ mask_iou = self.nms(target_boxes, target_scores, iou_threshold)
112
+
113
+ target_boxes = target_boxes[mask_iou]
114
+ target_scores = target_scores[mask_iou]
115
+ target_labels = target_labels[mask_iou]
116
+
117
+ return target_boxes, target_scores, target_labels
118
+
119
+ def nms(self, boxes: np.ndarray, scores: np.ndarray, iou_threshold: float) -> np.ndarray:
120
+ """
121
+ Non-maximum suppression.
122
+ 当 iou 大于 iou_threshold 时,保留 score 最大的 box
123
+
124
+ """
125
+ if len(boxes) == 0:
126
+ return np.array([], dtype=np.int32)
127
+
128
+ # 获取坐标
129
+ x1 = boxes[:, 0]
130
+ y1 = boxes[:, 1]
131
+ x2 = boxes[:, 2]
132
+ y2 = boxes[:, 3]
133
+
134
+ # 计算面积
135
+ areas = (x2 - x1 + 1) * (y2 - y1 + 1)
136
+
137
+ # 按分数排序,从高到低
138
+ order = np.argsort(scores)[::-1]
139
+
140
+ keep = []
141
+ while order.size > 0:
142
+ i = order[0]
143
+ keep.append(i)
144
+
145
+ # 计算当前框与其他框的交集
146
+ xx1 = np.maximum(x1[i], x1[order[1:]])
147
+ yy1 = np.maximum(y1[i], y1[order[1:]])
148
+ xx2 = np.minimum(x2[i], x2[order[1:]])
149
+ yy2 = np.minimum(y2[i], y2[order[1:]])
150
+
151
+ # 计算交集面积
152
+ w = np.maximum(0.0, xx2 - xx1 + 1)
153
+ h = np.maximum(0.0, yy2 - yy1 + 1)
154
+ inter = w * h
155
+
156
+ # 计算IoU
157
+ iou = inter / (areas[i] + areas[order[1:]] - inter)
158
+
159
+ # 保留IoU小于阈值的框
160
+ inds = np.where(iou <= iou_threshold)[0]
161
+ order = order[inds + 1]
162
+
163
+ return np.array(keep, dtype=np.int32)
164
+
165
+ def run_inference(self, image: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
166
+ """
167
+ Run inference on the image.
168
+
169
+ Args:
170
+ image (np.ndarray): The image to run inference on.
171
+
172
+ Returns:
173
+ boxes: (N, 4) x1, y1, x2, y2
174
+ scores: (N,)
175
+ labels: (N,)
176
+ """
177
+ # 运行推理
178
+ ort_outs = self.session.run(None, {self.input_name: image})
179
+
180
+ boxes_scores, labels = ort_outs[0], ort_outs[1] # RTMDet outputs cls_scores and bbox_preds
181
+ boxes = boxes_scores[0, :, :4]
182
+ scores = boxes_scores[0, :, 4]
183
+ labels = labels[0]
184
+
185
+ return boxes, scores, labels
186
+
187
+ def pred(self, image: Union[cv2.UMat, str], to_rgb: bool = False) -> Tuple[np.ndarray, np.ndarray, List[str]]:
188
+ """
189
+ Predict the detection results of the image.
190
+
191
+ Args:
192
+ image (cv2.UMat, str): The image to predict.
193
+
194
+ Returns:
195
+
196
+ """
197
+ if isinstance(image, str):
198
+ img = cv2.imread(image)
199
+ else:
200
+ img = image.copy()
201
+
202
+ image, scale, ori_shape = self.preprocess_image(img, to_rgb)
203
+
204
+ boxes, scores, labels = self.run_inference(image)
205
+
206
+
207
+ # 过滤结果
208
+ filtered_boxes, filtered_scores, filtered_labels = self.filter_results(boxes, scores, labels, self.iou_threshold, self.score_threshold)
209
+
210
+
211
+
212
+ # to origin bbox
213
+ origin_boxes = self.post_bbox(filtered_boxes, ori_shape, scale)
214
+
215
+ # label_names
216
+ label_names = [self.labels[label] for label in filtered_labels]
217
+
218
+
219
+ return origin_boxes, filtered_scores, label_names
220
+
221
+ def draw_pred(self, image: cv2.UMat, boxes: np.ndarray, scores: np.ndarray, labels: List[str]) -> cv2.UMat:
222
+
223
+ # 不同label 对应不同颜色,一共
224
+ colors = self.label_colors
225
+
226
+ # 在图像上绘制预测 bboxes 和 labels
227
+ # boxes = boxes.tolist()
228
+ # scores = scores.tolist()
229
+
230
+ for box, score, label in zip(boxes, scores, labels):
231
+ x1, y1, x2, y2 = box
232
+
233
+ x1 = int(x1)
234
+ y1 = int(y1)
235
+ x2 = int(x2)
236
+ y2 = int(y2)
237
+ label_index = self.labels.index(label)
238
+
239
+ cv2.rectangle(image, (x1, y1), (x2, y2), colors[label_index], 2)
240
+ cv2.putText(image, f"{label} {score:.2f}", (x1, y1), cv2.FONT_HERSHEY_SIMPLEX, 0.5, colors[label_index], 2)
241
+
242
+ return image
243
+
244
+
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ opencv-python
2
+ onnxruntime