File size: 3,995 Bytes
37170d6
5873e33
9316eb4
37170d6
 
 
5873e33
085b115
1e43daa
37170d6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9316eb4
 
 
 
 
 
5873e33
9316eb4
5873e33
9316eb4
 
 
 
 
 
 
5873e33
2a190c0
37170d6
 
 
 
 
 
 
9316eb4
 
 
37170d6
 
 
1e43daa
 
 
37170d6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9316eb4
 
5873e33
9316eb4
37170d6
 
9316eb4
 
 
 
 
37170d6
9316eb4
 
 
 
 
37170d6
9316eb4
 
 
 
 
 
37170d6
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import gradio as gr
# import cv2
import os
from core.chessboard_detector import ChessboardDetector

# ONNX weights for the pipeline: board detection and keypoint ("pose") models,
# then a whole-board layout classifier. Paths are relative -- presumably resolved
# against the working directory by ChessboardDetector (confirm).
_DETECTOR_WEIGHTS = {
    "full_classifier_model_path": "onnx/layout_recognition/nano_v1.onnx",
    "pose_model_path": "onnx/pose/4_v2.onnx",
    "det_model_path": "onnx/det/v2.onnx",
}
detector = ChessboardDetector(**_DETECTOR_WEIGHTS)

# Piece-category name (Chinese) -> single-character board code.
# '.' is an ordinary (empty) intersection and 'x' an occluded position, per the
# UI notes; upper-case codes are red pieces, lower-case codes are black pieces.
dict_cate_names = {
    '.': '.',
    'x': 'x',
    '红帅': 'K',
    '红士': 'A',
    '红相': 'B',
    '红马': 'N',
    '红车': 'R',
    '红炮': 'C',
    '红兵': 'P',

    '黑将': 'k',
    '黑仕': 'a',
    '黑象': 'b',
    '黑傌': 'n',
    '黑車': 'r',
    '黑砲': 'c',
    '黑卒': 'p',
}

# Board code -> Chinese name; used to turn the detector's layout string into
# the Chinese labels shown in the results table.
dict_cate_names_reverse = {v: k for k, v in dict_cate_names.items()}

# Inverting a dict silently drops entries when two names share a code; fail
# fast at startup rather than mislabel pieces at prediction time.
if len(dict_cate_names_reverse) != len(dict_cate_names):
    raise ValueError("dict_cate_names must map each name to a distinct code")


### 构建 examples 

def build_examples(directory="examples", extensions=(".jpg", ".png")):
    """Collect example image paths in the row format ``gr.Examples`` expects.

    Args:
        directory: Folder to scan (non-recursive). Defaults to ``"examples"``,
            a relative path resolved against the current working directory.
        extensions: File suffix or tuple of suffixes to accept; matched
            case-insensitively so ``photo.JPG`` is picked up too.

    Returns:
        A list of single-element lists ``[[path], ...]`` (one row per example,
        one column for the single Examples input), sorted by file name so the
        gallery order is stable across runs. Returns ``[]`` when ``directory``
        does not exist.
    """
    try:
        names = os.listdir(directory)
    except FileNotFoundError:
        # A missing examples folder should not prevent the app from starting.
        return []

    if isinstance(extensions, str):
        extensions = (extensions,)
    suffixes = tuple(ext.lower() for ext in extensions)

    examples = []
    # os.listdir order is arbitrary; sort for a deterministic example list.
    for name in sorted(names):
        path = os.path.join(directory, name)
        # isfile() skips sub-directories that merely have an image-like name.
        if name.lower().endswith(suffixes) and os.path.isfile(path):
            examples.append([path])
    return examples


full_examples = build_examples()


with gr.Blocks(css="""
        .image img {
            max-height: 512px;
        }
    """
) as demo:
    # Intro / usage notes; user-facing Markdown text.
    gr.Markdown("""
                ## 棋盘检测, 棋子识别

                x 表示 有遮挡位置   
                . 表示 棋盘上的普通交叉点  

                步骤:  
                    1. 流程分成两步,第一步检测边缘  
                    2. 对整个棋盘画面进行棋子分类预测
                
                ### log
                2025-01-24 模型优化 200M -> 30M
                """
    )

    # Row 1: image upload, and the step-1 overlay of detected keypoints.
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(label="上传棋盘图片", type="numpy", elem_classes="image")

        with gr.Column():
            original_image_with_keypoints = gr.Image(
                label="step1: 原图带关键点",
                interactive=False,
                visible=True,
                elem_classes="image"
            )

    # Row 2: step-2 stretched (warped) board, plus timing and the piece table.
    with gr.Row():
        with gr.Column():
            transformed_image = gr.Image(
                label="step2: 拉伸棋盘",
                interactive=False,
                visible=True,
                elem_classes="image"
            )

        with gr.Column():
            use_time = gr.Textbox(
                label="用时",
                interactive=False,
                visible=True,
            )
            layout_pred_info = gr.Dataframe(
                label="棋子识别",
                interactive=False,
                visible=True,
            )

    # Row 3: example images that feed image_input.
    with gr.Row():
        with gr.Column():
            gr.Examples(full_examples, inputs=[image_input], label="示例视频、图片")

    def detect_chessboard(image):
        """Run board detection and piece classification on one uploaded image.

        Args:
            image: The uploaded picture as a numpy array (``type="numpy"``
                on ``image_input``), or ``None`` when the input is cleared.

        Returns:
            A 4-tuple matching the ``outputs`` wired below: (image with
            keypoints, stretched board image, 2-D list of Chinese piece names
            for the Dataframe, timing info). All four are ``None`` when there
            is no image or when detection/parsing fails.
        """
        if image is None:
            return None, None, None, None

        try:
            keypoints_image, warped_board, cells_labels_str, _scores, time_info = (
                detector.pred_detect_board_and_classifier(image)
            )

            # cells_labels_str has one line of single-character codes per board
            # row (presumably 10 lines x 9 chars -- TODO confirm with detector).
            # splitlines() does not turn a trailing newline into an empty row.
            # Each code is mapped back to its Chinese name for display; an
            # unknown code raises KeyError and is handled below.
            board_rows = [
                [dict_cate_names_reverse[code] for code in row]
                for row in cells_labels_str.splitlines()
            ]
        except Exception:
            # UI boundary: show a user-facing warning, but keep the traceback
            # in the server log instead of silently discarding it.
            import logging
            logging.getLogger(__name__).exception("chessboard detection failed")
            gr.Warning("检测失败 图片或者视频布局错误")
            return None, None, None, None

        return keypoints_image, warped_board, board_rows, time_info

    # Re-run the pipeline whenever the input changes; clearing the image also
    # triggers it, which resets every output to None.
    image_input.change(fn=detect_chessboard,
                       inputs=[image_input],
                       outputs=[original_image_with_keypoints, transformed_image, layout_pred_info, use_time])

def main():
    """Start the Gradio server for the chessboard demo."""
    demo.launch()


# Launch only when executed as a script, not when imported.
if __name__ == "__main__":
    main()