feat: 性能更新
Browse files- HISTORY.md +5 -0
- app.py +7 -91
- core/runonnx/rtmpose.py +13 -4
- examples/demo001.png +0 -0
- examples/demo002.png +0 -0
- examples/demo003.png +0 -0
- examples/demo004.png +0 -0
- examples/demo005.png +0 -0
HISTORY.md
CHANGED
|
@@ -1,3 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
### 2025-01-05
|
| 2 |
|
| 3 |
1. 使用 swinv2 tiny 模型
|
|
|
|
| 1 |
+
### 2025-01-10
|
| 2 |
+
|
| 3 |
+
1. 移除 视频
|
| 4 |
+
2. 增加图片
|
| 5 |
+
|
| 6 |
### 2025-01-05
|
| 7 |
|
| 8 |
1. 使用 swinv2 tiny 模型
|
app.py
CHANGED
|
@@ -1,12 +1,12 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
-
import cv2
|
| 3 |
import os
|
| 4 |
from core.chessboard_detector import ChessboardDetector
|
| 5 |
|
| 6 |
detector = ChessboardDetector(
|
| 7 |
-
det_model_path="onnx/det/
|
| 8 |
pose_model_path="onnx/pose/4_v2.onnx",
|
| 9 |
-
full_classifier_model_path="onnx/layout_recognition/
|
| 10 |
)
|
| 11 |
|
| 12 |
# 数据集路径
|
|
@@ -39,13 +39,9 @@ def build_examples():
|
|
| 39 |
examples = []
|
| 40 |
# 读取 examples 目录下的所有图片
|
| 41 |
for file in os.listdir("examples"):
|
| 42 |
-
if file.endswith(".jpg"):
|
| 43 |
image_path = os.path.join("examples", file)
|
| 44 |
-
examples.append([image_path
|
| 45 |
-
|
| 46 |
-
elif file.endswith(".mp4"):
|
| 47 |
-
video_path = os.path.join("examples", file)
|
| 48 |
-
examples.append([None, video_path])
|
| 49 |
|
| 50 |
return examples
|
| 51 |
|
|
@@ -53,76 +49,7 @@ def build_examples():
|
|
| 53 |
full_examples = build_examples()
|
| 54 |
|
| 55 |
|
| 56 |
-
|
| 57 |
-
"""
|
| 58 |
-
获取视频指定位置的帧
|
| 59 |
-
"""
|
| 60 |
-
|
| 61 |
-
# 读取视频
|
| 62 |
-
cap = cv2.VideoCapture(video_data)
|
| 63 |
-
if not cap.isOpened():
|
| 64 |
-
gr.Warning("无法打开视频")
|
| 65 |
-
return None
|
| 66 |
-
|
| 67 |
-
# 获取视频的帧率
|
| 68 |
-
fps = cap.get(cv2.CAP_PROP_FPS)
|
| 69 |
-
|
| 70 |
-
# process 是 00:00
|
| 71 |
-
process_time = process.split(":")
|
| 72 |
-
minutes = int(process_time[0])
|
| 73 |
-
seconds = float(process_time[1])
|
| 74 |
-
|
| 75 |
-
# 计算总秒数
|
| 76 |
-
target_seconds = minutes * 60 + seconds
|
| 77 |
-
|
| 78 |
-
# 计算当前帧
|
| 79 |
-
current_frame = int(target_seconds * fps)
|
| 80 |
-
|
| 81 |
-
# 设置到指定帧
|
| 82 |
-
cap.set(cv2.CAP_PROP_POS_FRAMES, current_frame)
|
| 83 |
-
|
| 84 |
-
# 读取当前帧
|
| 85 |
-
ret, frame = cap.read()
|
| 86 |
-
cap.release()
|
| 87 |
-
|
| 88 |
-
if not ret:
|
| 89 |
-
gr.Warning("无法读取视频帧")
|
| 90 |
-
return None
|
| 91 |
-
|
| 92 |
-
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
| 93 |
-
|
| 94 |
-
return frame_rgb
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
with gr.Blocks(
|
| 100 |
-
js="""
|
| 101 |
-
async () => {
|
| 102 |
-
document.addEventListener('timeupdate', function(e) {
|
| 103 |
-
// 检查事件源是否是视频元素
|
| 104 |
-
if (e.target.matches('#video_player video')) {
|
| 105 |
-
const video = e.target;
|
| 106 |
-
const currentTime = video.currentTime;
|
| 107 |
-
// 转换成 00:00 格式
|
| 108 |
-
let minutes = Math.floor(currentTime / 60);
|
| 109 |
-
let seconds = Math.floor(currentTime % 60);
|
| 110 |
-
let formattedTime = `${minutes.toString().padStart(2,'0')}:${seconds.toString().padStart(2,'0')}`;
|
| 111 |
-
|
| 112 |
-
// 更新输入框值
|
| 113 |
-
let processInput = document.querySelector("#video_process textarea");
|
| 114 |
-
if(processInput) {
|
| 115 |
-
processInput.value = formattedTime;
|
| 116 |
-
processInput.text = formattedTime;
|
| 117 |
-
|
| 118 |
-
processInput.dispatchEvent(new Event("input"));
|
| 119 |
-
}
|
| 120 |
-
|
| 121 |
-
}
|
| 122 |
-
}, true); // 使用捕获阶段
|
| 123 |
-
}
|
| 124 |
-
""",
|
| 125 |
-
css="""
|
| 126 |
.image img {
|
| 127 |
max-height: 512px;
|
| 128 |
}
|
|
@@ -139,13 +66,6 @@ with gr.Blocks(
|
|
| 139 |
2. 对整个棋盘画面进行棋子分类预测
|
| 140 |
"""
|
| 141 |
)
|
| 142 |
-
|
| 143 |
-
with gr.Row():
|
| 144 |
-
with gr.Column():
|
| 145 |
-
video_input = gr.Video(label="上传视频", interactive=True, elem_id="video_player", height=356)
|
| 146 |
-
video_process = gr.Textbox(label="当前时间", interactive=True, elem_id="video_process", value="00:00")
|
| 147 |
-
extract_frame_btn = gr.Button("从视频提取当前帧")
|
| 148 |
-
|
| 149 |
with gr.Row():
|
| 150 |
with gr.Column():
|
| 151 |
image_input = gr.Image(label="上传棋盘图片", type="numpy", elem_classes="image")
|
|
@@ -182,7 +102,7 @@ with gr.Blocks(
|
|
| 182 |
|
| 183 |
with gr.Row():
|
| 184 |
with gr.Column():
|
| 185 |
-
gr.Examples(full_examples, inputs=[image_input
|
| 186 |
|
| 187 |
|
| 188 |
def detect_chessboard(image):
|
|
@@ -212,9 +132,5 @@ with gr.Blocks(
|
|
| 212 |
inputs=[image_input],
|
| 213 |
outputs=[original_image_with_keypoints, transformed_image, layout_pred_info, use_time])
|
| 214 |
|
| 215 |
-
extract_frame_btn.click(fn=get_video_frame_with_processs,
|
| 216 |
-
inputs=[video_input, video_process],
|
| 217 |
-
outputs=[image_input])
|
| 218 |
-
|
| 219 |
if __name__ == "__main__":
|
| 220 |
demo.launch()
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
+
# import cv2
|
| 3 |
import os
|
| 4 |
from core.chessboard_detector import ChessboardDetector
|
| 5 |
|
| 6 |
detector = ChessboardDetector(
|
| 7 |
+
det_model_path="onnx/det/v2.onnx",
|
| 8 |
pose_model_path="onnx/pose/4_v2.onnx",
|
| 9 |
+
full_classifier_model_path="onnx/layout_recognition/v5.onnx"
|
| 10 |
)
|
| 11 |
|
| 12 |
# 数据集路径
|
|
|
|
| 39 |
examples = []
|
| 40 |
# 读取 examples 目录下的所有图片
|
| 41 |
for file in os.listdir("examples"):
|
| 42 |
+
if file.endswith(".jpg") or file.endswith(".png"):
|
| 43 |
image_path = os.path.join("examples", file)
|
| 44 |
+
examples.append([image_path])
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
|
| 46 |
return examples
|
| 47 |
|
|
|
|
| 49 |
full_examples = build_examples()
|
| 50 |
|
| 51 |
|
| 52 |
+
with gr.Blocks(css="""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
.image img {
|
| 54 |
max-height: 512px;
|
| 55 |
}
|
|
|
|
| 66 |
2. 对整个棋盘画面进行棋子分类预测
|
| 67 |
"""
|
| 68 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
with gr.Row():
|
| 70 |
with gr.Column():
|
| 71 |
image_input = gr.Image(label="上传棋盘图片", type="numpy", elem_classes="image")
|
|
|
|
| 102 |
|
| 103 |
with gr.Row():
|
| 104 |
with gr.Column():
|
| 105 |
+
gr.Examples(full_examples, inputs=[image_input], label="示例视频、图片")
|
| 106 |
|
| 107 |
|
| 108 |
def detect_chessboard(image):
|
|
|
|
| 132 |
inputs=[image_input],
|
| 133 |
outputs=[original_image_with_keypoints, transformed_image, layout_pred_info, use_time])
|
| 134 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 135 |
if __name__ == "__main__":
|
| 136 |
demo.launch()
|
core/runonnx/rtmpose.py
CHANGED
|
@@ -350,7 +350,12 @@ class RTMPOSE_ONNX(BaseONNX):
|
|
| 350 |
|
| 351 |
return original_keypoints
|
| 352 |
|
| 353 |
-
def draw_pred(self,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 354 |
"""
|
| 355 |
Draw the keypoints results on the image.
|
| 356 |
"""
|
|
@@ -361,14 +366,18 @@ class RTMPOSE_ONNX(BaseONNX):
|
|
| 361 |
colors = self.bone_colors
|
| 362 |
|
| 363 |
for i, (point, score) in enumerate(zip(keypoints, scores)):
|
| 364 |
-
|
| 365 |
x, y = map(int, point)
|
| 366 |
# 使用不同颜色标注不同的关键点
|
| 367 |
color = colors[i]
|
| 368 |
|
| 369 |
cv2.circle(img, (x, y), 5, (int(color[0]), int(color[1]), int(color[2])), -1)
|
| 370 |
# 添加关键点索引标注
|
| 371 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 372 |
cv2.FONT_HERSHEY_SIMPLEX, 1.0, (int(color[0]), int(color[1]), int(color[2])), 1)
|
| 373 |
|
| 374 |
# 绘制 关节连接线
|
|
@@ -383,7 +392,7 @@ class RTMPOSE_ONNX(BaseONNX):
|
|
| 383 |
link_color = colors[start_index]
|
| 384 |
|
| 385 |
# 绘制连线
|
| 386 |
-
if scores[start_index] >
|
| 387 |
start_point = tuple(map(int, start_keypoint))
|
| 388 |
end_point = tuple(map(int, end_keypoint))
|
| 389 |
cv2.line(img, start_point, end_point,
|
|
|
|
| 350 |
|
| 351 |
return original_keypoints
|
| 352 |
|
| 353 |
+
def draw_pred(self,
|
| 354 |
+
img: cv2.UMat,
|
| 355 |
+
keypoints: np.ndarray,
|
| 356 |
+
scores: np.ndarray,
|
| 357 |
+
is_rgb: bool = True,
|
| 358 |
+
score_threshold: float = 0.6) -> cv2.UMat:
|
| 359 |
"""
|
| 360 |
Draw the keypoints results on the image.
|
| 361 |
"""
|
|
|
|
| 366 |
colors = self.bone_colors
|
| 367 |
|
| 368 |
for i, (point, score) in enumerate(zip(keypoints, scores)):
|
| 369 |
+
|
| 370 |
x, y = map(int, point)
|
| 371 |
# 使用不同颜色标注不同的关键点
|
| 372 |
color = colors[i]
|
| 373 |
|
| 374 |
cv2.circle(img, (x, y), 5, (int(color[0]), int(color[1]), int(color[2])), -1)
|
| 375 |
# 添加关键点索引标注
|
| 376 |
+
if score < score_threshold: # 设置置信度阈值
|
| 377 |
+
text = f"{self.bone_names[i]}: {score:.2f}"
|
| 378 |
+
else:
|
| 379 |
+
text = f"{self.bone_names[i]}"
|
| 380 |
+
cv2.putText(img, text, (x+5, y+5),
|
| 381 |
cv2.FONT_HERSHEY_SIMPLEX, 1.0, (int(color[0]), int(color[1]), int(color[2])), 1)
|
| 382 |
|
| 383 |
# 绘制 关节连接线
|
|
|
|
| 392 |
link_color = colors[start_index]
|
| 393 |
|
| 394 |
# 绘制连线
|
| 395 |
+
if scores[start_index] > score_threshold and scores[end_index] > score_threshold:
|
| 396 |
start_point = tuple(map(int, start_keypoint))
|
| 397 |
end_point = tuple(map(int, end_keypoint))
|
| 398 |
cv2.line(img, start_point, end_point,
|
examples/demo001.png
ADDED
|
examples/demo002.png
ADDED
|
examples/demo003.png
ADDED
|
examples/demo004.png
ADDED
|
examples/demo005.png
ADDED
|